311 lines
13 KiB
Python
311 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import inspect
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from datasets import load_dataset, load_from_disk
|
|
from peft import LoraConfig, prepare_model_for_kbit_training
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
|
from trl import DPOConfig, DPOTrainer
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the QLoRA + DPO training script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script did before this
            parameter existed; an explicit list allows the parser to be driven
            from tests or from other Python code.

    Returns:
        The populated namespace. ``--model-name``, ``--train-file``,
        ``--dev-file`` and ``--output-dir`` are required; every other option
        has a default.
    """
    parser = argparse.ArgumentParser(description="QLoRA + DPO training for Qwen3.5-9B.")

    # Inputs / outputs.
    parser.add_argument("--model-name", required=True)
    parser.add_argument("--train-file", required=True)
    parser.add_argument("--dev-file", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--ref-logps-cache-dir", default=None)
    parser.add_argument("--precompute-only", action="store_true")

    # Optimisation schedule.
    parser.add_argument("--learning-rate", type=float, default=5e-7)
    parser.add_argument("--num-train-epochs", type=float, default=1.0)
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--per-device-eval-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=16)

    # Sequence-length handling.
    parser.add_argument("--max-length", type=int, default=6144)
    parser.add_argument("--max-prompt-length", type=int, default=5632)
    parser.add_argument("--max-completion-length", type=int, default=None)
    parser.add_argument("--truncation-mode", default="keep_end")

    # DPO-specific switches.
    parser.add_argument("--padding-free", action="store_true")
    parser.add_argument("--use-logits-to-keep", action="store_true")
    parser.add_argument("--precompute-ref-log-probs", action="store_true")
    parser.add_argument("--reference-free", action="store_true")
    parser.add_argument("--beta", type=float, default=0.1)

    # Logging, evaluation and checkpointing cadence.
    parser.add_argument("--eval-steps", type=int, default=100)
    parser.add_argument("--save-steps", type=int, default=100)
    parser.add_argument("--logging-steps", type=int, default=10)
    parser.add_argument("--torch-empty-cache-steps", type=int, default=10)

    # Optimiser settings.
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--optim", default="paged_adamw_8bit")
    parser.add_argument("--load-best-model-at-end", action="store_true")

    # LoRA adapter shape.
    parser.add_argument("--lora-r", type=int, default=64)
    parser.add_argument("--lora-alpha", type=int, default=128)
    parser.add_argument("--lora-dropout", type=float, default=0.05)

    # Model loading / runtime.
    parser.add_argument("--attn-implementation", default="sdpa", choices=["sdpa", "eager", "flash_attention_2"])
    parser.add_argument("--resume-from-checkpoint", default=None)
    parser.add_argument("--bf16", action="store_true")

    return parser.parse_args(argv)
|
|
|
|
|
|
def load_preference_dataset(train_file: str, dev_file: str):
    """Load the train/dev preference pairs from JSON files and drop unused columns.

    Args:
        train_file: Path of the JSON/JSONL file loaded as the ``"train"`` split.
        dev_file: Path of the JSON/JSONL file loaded as the ``"dev"`` split.

    Returns:
        The object returned by ``load_dataset("json", ...)`` with the splits
        ``"train"`` and ``"dev"``, each restricted to the columns listed in
        ``keep_cols`` (columns from that list that a file does not contain are
        simply absent).
    """
    files = {"train": train_file, "dev": dev_file}
    dataset = load_dataset("json", data_files=files)
    keep_cols = {"prompt", "chosen", "rejected", "pair_id", "chosen_source", "rejected_source", "rejected_error_type"}

    # ``map`` merges the dict returned by the callback into each example, so a
    # callback that returns a subset of the keys does not remove anything.
    # Dropping columns has to be requested explicitly; do it per split so a
    # column that exists in only one file cannot cause a lookup failure.
    for split_name in list(dataset.keys()):
        split = dataset[split_name]
        extra_cols = [col for col in split.column_names if col not in keep_cols]
        if extra_cols:
            dataset[split_name] = split.remove_columns(extra_cols)

    return dataset
|
|
|
|
|
|
def cache_paths(cache_dir: str | None):
    """Resolve the on-disk layout of the reference log-prob cache.

    Args:
        cache_dir: Root directory of the cache, or a falsy value (``None`` or
            an empty string) when caching is disabled.

    Returns:
        A ``(root, train_dir, dev_dir)`` tuple of ``Path`` objects, where the
        two sub-directories are ``root/"train_dataset"`` and
        ``root/"dev_dataset"``; ``(None, None, None)`` when ``cache_dir`` is
        falsy.
    """
    if cache_dir:
        base = Path(cache_dir)
        return base, base.joinpath("train_dataset"), base.joinpath("dev_dataset")
    return None, None, None
|
|
|
|
|
|
def try_load_cached_ref_datasets(cache_dir: str | None):
    """Load previously saved train/dev datasets that carry reference log-probs.

    Args:
        cache_dir: Root directory of the cache, or ``None`` when caching is
            disabled.

    Returns:
        ``{"train": ..., "dev": ...}`` when both cached datasets exist on disk
        and each contains the ``ref_chosen_logps`` and ``ref_rejected_logps``
        columns; otherwise ``None``.
    """
    root, train_dir, dev_dir = cache_paths(cache_dir)
    if root is None:
        return None
    if not (train_dir.exists() and dev_dir.exists()):
        return None

    splits = {
        "train": load_from_disk(str(train_dir)),
        "dev": load_from_disk(str(dev_dir)),
    }

    # A cache is only usable when every split already has both log-prob columns.
    required = {"ref_chosen_logps", "ref_rejected_logps"}
    if any(not required.issubset(split.column_names) for split in splits.values()):
        return None

    print(f"Loaded cached reference log-prob datasets from: {root}")
    return splits
|
|
|
|
|
|
def save_cached_ref_datasets(cache_dir: str | None, train_dataset, dev_dataset, summary: dict) -> None:
    """Persist the datasets and a JSON summary under the cache directory.

    Does nothing when ``cache_dir`` is falsy. Otherwise the directory is
    created if needed, both datasets are written with ``save_to_disk`` into the
    sub-directories given by ``cache_paths``, and ``summary`` is dumped to
    ``cache_summary.json`` in the cache root.

    Args:
        cache_dir: Root directory of the cache, or ``None`` to skip saving.
        train_dataset: Object exposing ``save_to_disk`` for the train split.
        dev_dataset: Object exposing ``save_to_disk`` for the dev split.
        summary: JSON-serialisable description stored next to the datasets.
    """
    root, train_dir, dev_dir = cache_paths(cache_dir)
    if root is not None:
        root.mkdir(parents=True, exist_ok=True)

        for split, target in ((train_dataset, train_dir), (dev_dataset, dev_dir)):
            split.save_to_disk(str(target))

        summary_path = root / "cache_summary.json"
        with summary_path.open("w", encoding="utf-8") as handle:
            json.dump(summary, handle, ensure_ascii=False, indent=2)

        print(f"Saved cached reference log-prob datasets to: {root}")
|
|
|
|
|
|
def build_model_and_tokenizer(args: argparse.Namespace):
    """Load the 4-bit quantised causal LM and its tokenizer for training.

    Args:
        args: Parsed options; this function reads ``model_name``, ``bf16`` and
            ``attn_implementation``.

    Returns:
        A ``(model, tokenizer)`` pair. The model has been passed through
        ``prepare_model_for_kbit_training`` and has gradient checkpointing
        enabled with ``use_reentrant=False``.
    """
    # bfloat16 only when requested on the command line; float16 otherwise.
    dtype = torch.float16
    if args.bf16:
        dtype = torch.bfloat16

    # The whole model is placed on the device index taken from LOCAL_RANK
    # (0 when the variable is not set).
    device_index = int(os.environ.get("LOCAL_RANK", "0"))

    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=dtype,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True, use_fast=False)
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    load_options = {
        "trust_remote_code": True,
        "quantization_config": quantization,
        "torch_dtype": dtype,
        "attn_implementation": args.attn_implementation,
        "device_map": {"": device_index},
    }
    model = AutoModelForCausalLM.from_pretrained(args.model_name, **load_options)

    model.config.use_cache = False
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    return model, tokenizer
|
|
|
|
|
|
def _build_dpo_config(args, output_dir, *, enable_best_model, precompute_ref_log_probs):
    """Create the DPOConfig, passing only keywords the installed TRL accepts."""
    accepted = inspect.signature(DPOConfig.__init__).parameters

    config_kwargs = dict(
        output_dir=str(output_dir),
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        lr_scheduler_type="cosine",
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        optim=args.optim,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_strategy="steps",
        save_total_limit=2,
        bf16=args.bf16,
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=2,
        beta=args.beta,
        truncation_mode=args.truncation_mode,
        padding_free=args.padding_free,
        use_logits_to_keep=args.use_logits_to_keep,
        torch_empty_cache_steps=args.torch_empty_cache_steps,
        load_best_model_at_end=enable_best_model,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
    )

    # The evaluation-strategy keyword is spelled differently across versions;
    # use whichever spelling the signature exposes.
    if "eval_strategy" in accepted:
        config_kwargs["eval_strategy"] = "steps"
    elif "evaluation_strategy" in accepted:
        config_kwargs["evaluation_strategy"] = "steps"

    # Optional settings are only passed when explicitly requested so the
    # library defaults stay in effect otherwise.
    if args.max_completion_length is not None:
        config_kwargs["max_completion_length"] = args.max_completion_length
    if precompute_ref_log_probs:
        config_kwargs["precompute_ref_log_probs"] = True
    if args.reference_free:
        config_kwargs["reference_free"] = True

    # Drop every keyword the installed DPOConfig does not know about.
    return DPOConfig(**{key: value for key, value in config_kwargs.items() if key in accepted})


def _build_trainer(args, model, tokenizer, dataset, train_args, peft_config, *, precompute_ref_log_probs):
    """Create the DPOTrainer, passing only keywords its signature accepts."""
    accepted = inspect.signature(DPOTrainer.__init__).parameters

    # Settings that some DPOTrainer signatures take directly (in addition to
    # DPOConfig); each one is forwarded only if the signature lists it.
    optional_kwargs = {
        "max_length": args.max_length,
        "max_prompt_length": args.max_prompt_length,
        "truncation_mode": args.truncation_mode,
    }
    if args.max_completion_length is not None:
        optional_kwargs["max_completion_length"] = args.max_completion_length
    switches = (
        ("precompute_ref_log_probs", precompute_ref_log_probs),
        ("reference_free", args.reference_free),
        ("padding_free", args.padding_free),
        ("use_logits_to_keep", args.use_logits_to_keep),
    )
    for name, enabled in switches:
        if enabled:
            optional_kwargs[name] = True

    trainer_kwargs = {key: value for key, value in optional_kwargs.items() if key in accepted}
    trainer_kwargs.update(
        model=model,
        ref_model=None,
        args=train_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["dev"],
        peft_config=peft_config,
    )

    # The tokenizer is supplied under exactly one keyword. Previously it was
    # added to two separate kwargs dicts when only ``tokenizer`` was accepted,
    # which made the constructor call fail with a duplicate-keyword TypeError.
    if "processing_class" in accepted:
        trainer_kwargs["processing_class"] = tokenizer
    elif "tokenizer" in accepted:
        trainer_kwargs["tokenizer"] = tokenizer

    return DPOTrainer(**trainer_kwargs)


def main() -> None:
    """Run the QLoRA + DPO pipeline end to end.

    Steps: parse the command line, load the preference data (or a cached copy
    that already carries reference log-probs), build the quantised model and
    the trainer, optionally save freshly precomputed reference log-probs to the
    cache directory, then train, save the model and tokenizer, and write
    ``run_summary.json`` into the output directory. With ``--precompute-only``
    the function returns before training.
    """
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Reference log-probs (and therefore their cache) are irrelevant in
    # reference-free mode.
    use_reference_side = not args.reference_free
    cached_dataset = try_load_cached_ref_datasets(args.ref_logps_cache_dir) if use_reference_side else None
    using_cached_ref_logps = cached_dataset is not None
    dataset = cached_dataset if using_cached_ref_logps else load_preference_dataset(args.train_file, args.dev_file)

    model, tokenizer = build_model_and_tokenizer(args)

    # Best-model loading is only enabled when save_steps is a multiple of a
    # positive eval_steps.
    enable_best_model = args.load_best_model_at_end and args.eval_steps > 0 and args.save_steps % args.eval_steps == 0

    # Single source of truth for "reference log-probs must be precomputed in
    # this run"; used for the config, the trainer keywords and the cache save
    # so the three can no longer disagree.
    should_precompute = args.precompute_ref_log_probs and use_reference_side and not using_cached_ref_logps

    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )

    train_args = _build_dpo_config(
        args,
        output_dir,
        enable_best_model=enable_best_model,
        precompute_ref_log_probs=should_precompute,
    )
    trainer = _build_trainer(
        args,
        model,
        tokenizer,
        dataset,
        train_args,
        peft_config,
        precompute_ref_log_probs=should_precompute,
    )

    if args.ref_logps_cache_dir and should_precompute:
        # Building the dataloaders is what is relied on here to populate the
        # reference log-prob columns; the check below verifies that it did.
        trainer.get_train_dataloader()
        trainer.get_eval_dataloader()
        needed = {"ref_chosen_logps", "ref_rejected_logps"}
        train_ready = needed.issubset(trainer.train_dataset.column_names)
        dev_ready = needed.issubset(trainer.eval_dataset.column_names)
        if train_ready and dev_ready:
            save_cached_ref_datasets(
                args.ref_logps_cache_dir,
                trainer.train_dataset,
                trainer.eval_dataset,
                {
                    "model_name": args.model_name,
                    "train_file": args.train_file,
                    "dev_file": args.dev_file,
                    "num_train_examples": len(trainer.train_dataset),
                    "num_dev_examples": len(trainer.eval_dataset),
                    "max_length": args.max_length,
                    "max_prompt_length": args.max_prompt_length,
                    "max_completion_length": args.max_completion_length,
                },
            )
        else:
            print("Warning: ref log-prob columns were not found after precompute; skipping cache save.")

    if args.precompute_only:
        print("Precompute-only mode finished.")
        return

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    trainer.save_model()
    tokenizer.save_pretrained(output_dir)

    # Each key appears once; the original literal listed "reference_free" twice.
    summary = {
        "model_name": args.model_name,
        "train_file": args.train_file,
        "dev_file": args.dev_file,
        "output_dir": str(output_dir),
        "ref_logps_cache_dir": args.ref_logps_cache_dir,
        "used_cached_ref_logps": using_cached_ref_logps,
        "reference_free": args.reference_free,
        "num_train_examples": len(dataset["train"]),
        "num_dev_examples": len(dataset["dev"]),
        "max_length": args.max_length,
        "max_prompt_length": args.max_prompt_length,
        "max_completion_length": args.max_completion_length,
        "truncation_mode": args.truncation_mode,
        "padding_free": args.padding_free,
        "use_logits_to_keep": args.use_logits_to_keep,
        "precompute_ref_log_probs": args.precompute_ref_log_probs,
        "beta": args.beta,
        "learning_rate": args.learning_rate,
        "epochs": args.num_train_epochs,
    }
    with open(output_dir / "run_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
# Start the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|