Files
llmiotsafe/DPO_FULLLOG_QWEN2B_SERVER_BUNDLE/scripts/train_dpo.py
2026-05-12 17:01:39 +08:00

311 lines
13 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import inspect
import json
import os
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for QLoRA + DPO training.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; an explicit list lets tests
            and wrapper scripts drive the parser without touching ``sys.argv``.

    Returns:
        The parsed options namespace.
    """
    parser = argparse.ArgumentParser(description="QLoRA + DPO training for Qwen3.5-9B.")
    # Model and data locations.
    parser.add_argument("--model-name", required=True)
    parser.add_argument("--train-file", required=True)
    parser.add_argument("--dev-file", required=True)
    parser.add_argument("--output-dir", required=True)
    # Reference log-prob caching.
    parser.add_argument("--ref-logps-cache-dir", default=None)
    parser.add_argument("--precompute-only", action="store_true")
    # Optimization schedule and batch sizes.
    parser.add_argument("--learning-rate", type=float, default=5e-7)
    parser.add_argument("--num-train-epochs", type=float, default=1.0)
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--per-device-eval-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=16)
    # Sequence lengths and truncation.
    parser.add_argument("--max-length", type=int, default=6144)
    parser.add_argument("--max-prompt-length", type=int, default=5632)
    parser.add_argument("--max-completion-length", type=int, default=None)
    # TRL's DPO implementation only understands these two modes; rejecting
    # anything else here fails fast instead of after the (slow) model load.
    parser.add_argument("--truncation-mode", default="keep_end", choices=["keep_end", "keep_start"])
    # Memory / compute toggles forwarded to TRL when the installed version supports them.
    parser.add_argument("--padding-free", action="store_true")
    parser.add_argument("--use-logits-to-keep", action="store_true")
    parser.add_argument("--precompute-ref-log-probs", action="store_true")
    parser.add_argument("--reference-free", action="store_true")
    parser.add_argument("--beta", type=float, default=0.1)
    # Logging / evaluation / checkpoint cadence.
    parser.add_argument("--eval-steps", type=int, default=100)
    parser.add_argument("--save-steps", type=int, default=100)
    parser.add_argument("--logging-steps", type=int, default=10)
    parser.add_argument("--torch-empty-cache-steps", type=int, default=10)
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--optim", default="paged_adamw_8bit")
    parser.add_argument("--load-best-model-at-end", action="store_true")
    # LoRA adapter hyper-parameters.
    parser.add_argument("--lora-r", type=int, default=64)
    parser.add_argument("--lora-alpha", type=int, default=128)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    # Runtime.
    parser.add_argument("--attn-implementation", default="sdpa", choices=["sdpa", "eager", "flash_attention_2"])
    parser.add_argument("--resume-from-checkpoint", default=None)
    parser.add_argument("--bf16", action="store_true")
    return parser.parse_args(argv)
def load_preference_dataset(train_file: str, dev_file: str):
    """Load train/dev preference pairs from JSON files and drop unused columns.

    Only the DPO triplet (``prompt``/``chosen``/``rejected``) and the
    bookkeeping columns listed in ``keep_cols`` are retained; any other column
    present in the source files is removed from every split.

    Args:
        train_file: Path to the training data, loaded as the ``"train"`` split.
        dev_file: Path to the evaluation data, loaded as the ``"dev"`` split.

    Returns:
        The ``datasets`` dict-of-splits with ``"train"`` and ``"dev"`` keys.
    """
    files = {"train": train_file, "dev": dev_file}
    dataset = load_dataset("json", data_files=files)
    keep_cols = {"prompt", "chosen", "rejected", "pair_id", "chosen_source", "rejected_source", "rejected_error_type"}
    # `Dataset.map` merges the returned dict into each example instead of
    # replacing it, so a map that returns a subset of keys keeps every column.
    # Drop the unwanted columns explicitly and per split, because the train
    # and dev files are not guaranteed to share the same schema.
    for split_name in list(dataset.keys()):
        split = dataset[split_name]
        drop = [col for col in split.column_names if col not in keep_cols]
        if drop:
            dataset[split_name] = split.remove_columns(drop)
    return dataset
def cache_paths(cache_dir: str | None):
    """Resolve the on-disk layout of the reference-log-prob cache.

    Returns ``(root, train_dir, dev_dir)`` as ``Path`` objects, where the two
    split directories are ``train_dataset`` and ``dev_dataset`` under
    ``root``. When no cache directory is configured (``None`` or an empty
    string) all three entries are ``None``.
    """
    if cache_dir:
        base = Path(cache_dir)
        return base, base.joinpath("train_dataset"), base.joinpath("dev_dataset")
    return None, None, None
def try_load_cached_ref_datasets(cache_dir: str | None):
    """Load cached train/dev datasets that carry precomputed reference log-probs.

    Args:
        cache_dir: Cache root (see ``cache_paths``); ``None``/empty disables caching.

    Returns:
        ``{"train": ..., "dev": ...}`` when both ``train_dataset`` and
        ``dev_dataset`` sub-directories exist and both datasets contain
        ``ref_chosen_logps`` and ``ref_rejected_logps``; otherwise ``None`` so
        the caller falls back to the raw JSON files.
    """
    root, train_dir, dev_dir = cache_paths(cache_dir)
    if root is None or not train_dir.exists() or not dev_dir.exists():
        return None
    # NOTE(review): the cache is not checked against the current model or data
    # files (cache_summary.json records them but is not read here) — a stale
    # cache will be reused silently; confirm this is acceptable.
    train_dataset = load_from_disk(str(train_dir))
    dev_dataset = load_from_disk(str(dev_dir))
    needed = {"ref_chosen_logps", "ref_rejected_logps"}
    missing = {
        split_name: sorted(needed - set(split.column_names))
        for split_name, split in (("train", train_dataset), ("dev", dev_dataset))
    }
    if not any(missing.values()):
        print(f"Loaded cached reference log-prob datasets from: {root}")
        return {"train": train_dataset, "dev": dev_dataset}
    # Make the fallback visible: ignoring the cache means recomputing
    # reference log-probs, which is expensive.
    print(f"Warning: cache at {root} is missing reference log-prob columns {missing}; ignoring it.")
    return None
def save_cached_ref_datasets(cache_dir: str | None, train_dataset, dev_dataset, summary: dict) -> None:
    """Persist datasets carrying reference log-probs, plus a JSON summary.

    Does nothing when ``cache_dir`` is unset. Otherwise creates the cache root
    if needed, writes ``train_dataset`` and ``dev_dataset`` with
    ``save_to_disk`` into the split directories given by ``cache_paths``, and
    finally dumps ``summary`` to ``cache_summary.json`` in the root.
    """
    root, train_dir, dev_dir = cache_paths(cache_dir)
    if root is not None:
        root.mkdir(parents=True, exist_ok=True)
        for split, target in ((train_dataset, train_dir), (dev_dataset, dev_dir)):
            split.save_to_disk(str(target))
        summary_path = root / "cache_summary.json"
        with summary_path.open("w", encoding="utf-8") as fh:
            json.dump(summary, fh, ensure_ascii=False, indent=2)
        print(f"Saved cached reference log-prob datasets to: {root}")
def build_model_and_tokenizer(args: argparse.Namespace):
    """Load the tokenizer and a 4-bit quantized base model prepared for QLoRA.

    Quantization is NF4 with double quantization. The compute dtype (also
    passed as ``torch_dtype``) is bfloat16 with ``--bf16`` and float16
    otherwise. The whole model is mapped onto device index ``LOCAL_RANK``
    (default 0), so each launched process gets its own device.

    Returns:
        ``(model, tokenizer)``.
    """
    device_index = int(os.environ.get("LOCAL_RANK", "0"))
    if args.bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=dtype,
    )

    tok = AutoTokenizer.from_pretrained(args.model_name, use_fast=False, trust_remote_code=True)
    # Reuse EOS as the padding token when the checkpoint does not define one.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"

    base = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        device_map={"": device_index},
        attn_implementation=args.attn_implementation,
        torch_dtype=dtype,
        quantization_config=quant_cfg,
        trust_remote_code=True,
    )
    # KV caching is not needed during training (and is commonly disabled
    # alongside gradient checkpointing).
    base.config.use_cache = False
    base = prepare_model_for_kbit_training(base, use_gradient_checkpointing=True)
    # Re-enable checkpointing explicitly with non-reentrant autograd; the PEFT
    # call above presumably enabled it with its own default kwargs.
    base.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    return base, tok
def _build_dpo_config(args: argparse.Namespace, output_dir: Path, precompute_ref_log_probs: bool):
    """Build a ``DPOConfig`` from CLI options, keeping only kwargs the installed TRL defines.

    Args:
        args: Parsed CLI options.
        output_dir: Directory for checkpoints and final artifacts.
        precompute_ref_log_probs: Whether TRL itself should precompute reference
            log-probs (already reconciled with caching and reference-free mode).

    Returns:
        The constructed ``DPOConfig``.
    """
    # load_best_model_at_end needs every evaluation to coincide with a saved checkpoint.
    enable_best_model = args.load_best_model_at_end and args.eval_steps > 0 and args.save_steps % args.eval_steps == 0
    if args.load_best_model_at_end and not enable_best_model:
        print("Warning: --load-best-model-at-end disabled; it needs --eval-steps > 0 and --save-steps divisible by --eval-steps.")
    dpo_kwargs = dict(
        output_dir=str(output_dir),
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        lr_scheduler_type="cosine",
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        optim=args.optim,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_strategy="steps",
        save_total_limit=2,
        bf16=args.bf16,
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=2,
        beta=args.beta,
        truncation_mode=args.truncation_mode,
        padding_free=args.padding_free,
        use_logits_to_keep=args.use_logits_to_keep,
        torch_empty_cache_steps=args.torch_empty_cache_steps,
        load_best_model_at_end=enable_best_model,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
    )
    dpo_sig = inspect.signature(DPOConfig.__init__).parameters
    # The evaluation-strategy field was renamed across transformers versions.
    if "eval_strategy" in dpo_sig:
        dpo_kwargs["eval_strategy"] = "steps"
    elif "evaluation_strategy" in dpo_sig:
        dpo_kwargs["evaluation_strategy"] = "steps"
    if args.max_completion_length is not None:
        dpo_kwargs["max_completion_length"] = args.max_completion_length
    if precompute_ref_log_probs:
        dpo_kwargs["precompute_ref_log_probs"] = True
    if args.reference_free:
        dpo_kwargs["reference_free"] = True
    # Drop anything this TRL version's DPOConfig does not accept.
    return DPOConfig(**{k: v for k, v in dpo_kwargs.items() if k in dpo_sig})


def _build_trainer_kwargs(args: argparse.Namespace, tokenizer, precompute_ref_log_probs: bool) -> dict:
    """Collect version-dependent ``DPOTrainer`` kwargs (older TRL took these on the trainer).

    The tokenizer is passed exactly once: as ``processing_class`` when the
    trainer accepts it, otherwise as ``tokenizer``. Adding it in two places
    would make ``DPOTrainer(...)`` raise a duplicate-keyword ``TypeError``.

    Returns:
        Only the kwargs present in the installed ``DPOTrainer.__init__`` signature.
    """
    trainer_sig = inspect.signature(DPOTrainer.__init__).parameters
    kwargs = {
        "max_length": args.max_length,
        "max_prompt_length": args.max_prompt_length,
        "truncation_mode": args.truncation_mode,
    }
    if args.max_completion_length is not None:
        kwargs["max_completion_length"] = args.max_completion_length
    # Boolean switches are only forwarded when enabled.
    for flag, enabled in (
        ("precompute_ref_log_probs", precompute_ref_log_probs),
        ("reference_free", args.reference_free),
        ("padding_free", args.padding_free),
        ("use_logits_to_keep", args.use_logits_to_keep),
    ):
        if enabled:
            kwargs[flag] = True
    if "processing_class" in trainer_sig:
        kwargs["processing_class"] = tokenizer
    elif "tokenizer" in trainer_sig:
        kwargs["tokenizer"] = tokenizer
    return {k: v for k, v in kwargs.items() if k in trainer_sig}


def _save_ref_logps_cache(args: argparse.Namespace, trainer) -> None:
    """Trigger TRL's reference log-prob precompute and cache the augmented datasets.

    Call this on every process. Building the dataloaders is what makes TRL run
    the reference pass, which presumably involves cross-rank gathers under a
    distributed launch. Only the main process writes the cache, so concurrent
    ranks do not write the same directory at once.
    """
    trainer.get_train_dataloader()
    trainer.get_eval_dataloader()
    if not trainer.is_world_process_zero():
        return
    needed = {"ref_chosen_logps", "ref_rejected_logps"}
    train_cols = set(trainer.train_dataset.column_names)
    dev_cols = set(trainer.eval_dataset.column_names)
    if not (needed.issubset(train_cols) and needed.issubset(dev_cols)):
        print("Warning: ref log-prob columns were not found after precompute; skipping cache save.")
        return
    save_cached_ref_datasets(
        args.ref_logps_cache_dir,
        trainer.train_dataset,
        trainer.eval_dataset,
        {
            "model_name": args.model_name,
            "train_file": args.train_file,
            "dev_file": args.dev_file,
            "num_train_examples": len(trainer.train_dataset),
            "num_dev_examples": len(trainer.eval_dataset),
            "max_length": args.max_length,
            "max_prompt_length": args.max_prompt_length,
            "max_completion_length": args.max_completion_length,
        },
    )


def _write_run_summary(args: argparse.Namespace, output_dir: Path, dataset, using_cached_ref_logps: bool) -> None:
    """Write ``run_summary.json`` describing the finished training run into ``output_dir``."""
    summary = {
        "model_name": args.model_name,
        "train_file": args.train_file,
        "dev_file": args.dev_file,
        "output_dir": str(output_dir),
        "ref_logps_cache_dir": args.ref_logps_cache_dir,
        "used_cached_ref_logps": using_cached_ref_logps,
        "reference_free": args.reference_free,
        "num_train_examples": len(dataset["train"]),
        "num_dev_examples": len(dataset["dev"]),
        "max_length": args.max_length,
        "max_prompt_length": args.max_prompt_length,
        "max_completion_length": args.max_completion_length,
        "truncation_mode": args.truncation_mode,
        "padding_free": args.padding_free,
        "use_logits_to_keep": args.use_logits_to_keep,
        "precompute_ref_log_probs": args.precompute_ref_log_probs,
        "beta": args.beta,
        "learning_rate": args.learning_rate,
        "epochs": args.num_train_epochs,
    }
    with open(output_dir / "run_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)


def main() -> None:
    """Run QLoRA + DPO training end to end.

    Steps:
      1. Load data, preferring a cached dataset with reference log-probs.
      2. Build the quantized model.
      3. Configure TRL (filtering options by the installed version's signatures).
      4. Optionally precompute and cache reference log-probs.
      5. Train, then save the adapter, tokenizer and a run summary.
    """
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    use_reference_side = not args.reference_free
    cached_dataset = try_load_cached_ref_datasets(args.ref_logps_cache_dir) if use_reference_side else None
    using_cached_ref_logps = cached_dataset is not None
    dataset = cached_dataset if using_cached_ref_logps else load_preference_dataset(args.train_file, args.dev_file)

    model, tokenizer = build_model_and_tokenizer(args)

    # Precompute only when a reference model is in play and the log-probs are
    # not already in the (cached) dataset. The same decision feeds both the
    # config and the trainer, so old and new TRL versions behave alike.
    precompute_ref = args.precompute_ref_log_probs and use_reference_side and not using_cached_ref_logps

    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )
    train_args = _build_dpo_config(args, output_dir, precompute_ref)
    trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=train_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["dev"],
        peft_config=peft_config,
        **_build_trainer_kwargs(args, tokenizer, precompute_ref),
    )

    will_cache = bool(args.ref_logps_cache_dir) and precompute_ref
    if will_cache:
        _save_ref_logps_cache(args, trainer)

    if args.precompute_only:
        if not will_cache and not using_cached_ref_logps:
            print("Warning: --precompute-only built no reference log-prob cache; pass --precompute-ref-log-probs and --ref-logps-cache-dir without --reference-free.")
        print("Precompute-only mode finished.")
        return

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    trainer.save_model()
    # Under a multi-process launch every rank reaches this point; only the main
    # process writes shared files, to avoid concurrent writes to one path.
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(output_dir)
        _write_run_summary(args, output_dir, dataset, using_cached_ref_logps)
if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    main()