#!/usr/bin/env python3
"""QLoRA + DPO training for Qwen3.5-9B.

Loads JSON preference data (``prompt`` / ``chosen`` / ``rejected`` plus a few
metadata columns), wraps a 4-bit NF4-quantized causal LM with LoRA adapters and
trains it with TRL's ``DPOTrainer``. Reference-model log-probabilities can be
precomputed once and cached on disk (``--ref-logps-cache-dir``) so later runs
reuse them instead of recomputing.

Keyword arguments are probed with ``inspect.signature`` so the script only
passes options that the installed TRL version's ``DPOConfig`` / ``DPOTrainer``
accept.
"""

from __future__ import annotations

import argparse
import inspect
import json
import os
from pathlib import Path

import torch
from datasets import DatasetDict, load_dataset, load_from_disk
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer

# Columns kept from the raw JSON files; every other column is dropped.
KEEP_COLUMNS = frozenset(
    {
        "prompt",
        "chosen",
        "rejected",
        "pair_id",
        "chosen_source",
        "rejected_source",
        "rejected_error_type",
    }
)
# Columns that must be present for a dataset to count as carrying precomputed
# reference log-probabilities.
REF_LOGP_COLUMNS = frozenset({"ref_chosen_logps", "ref_rejected_logps"})
# Run settings recorded in cache_summary.json. A cache is reused only when every
# one of these that the summary records matches the current run.
CACHE_IDENTITY_KEYS = (
    "model_name",
    "train_file",
    "dev_file",
    "max_length",
    "max_prompt_length",
    "max_completion_length",
    "truncation_mode",
)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a training (or precompute-only) run."""
    parser = argparse.ArgumentParser(description="QLoRA + DPO training for Qwen3.5-9B.")
    parser.add_argument("--model-name", required=True)
    parser.add_argument("--train-file", required=True)
    parser.add_argument("--dev-file", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--ref-logps-cache-dir", default=None)
    parser.add_argument("--precompute-only", action="store_true")
    parser.add_argument("--learning-rate", type=float, default=5e-7)
    parser.add_argument("--num-train-epochs", type=float, default=1.0)
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--per-device-eval-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=16)
    parser.add_argument("--max-length", type=int, default=6144)
    parser.add_argument("--max-prompt-length", type=int, default=5632)
    parser.add_argument("--max-completion-length", type=int, default=None)
    parser.add_argument("--truncation-mode", default="keep_end")
    parser.add_argument("--padding-free", action="store_true")
    parser.add_argument("--use-logits-to-keep", action="store_true")
    parser.add_argument("--precompute-ref-log-probs", action="store_true")
    parser.add_argument("--reference-free", action="store_true")
    parser.add_argument("--beta", type=float, default=0.1)
    parser.add_argument("--eval-steps", type=int, default=100)
    parser.add_argument("--save-steps", type=int, default=100)
    parser.add_argument("--logging-steps", type=int, default=10)
    parser.add_argument("--torch-empty-cache-steps", type=int, default=10)
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--optim", default="paged_adamw_8bit")
    parser.add_argument("--load-best-model-at-end", action="store_true")
    parser.add_argument("--lora-r", type=int, default=64)
    parser.add_argument("--lora-alpha", type=int, default=128)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    parser.add_argument(
        "--attn-implementation",
        default="sdpa",
        choices=["sdpa", "eager", "flash_attention_2"],
    )
    parser.add_argument("--resume-from-checkpoint", default=None)
    parser.add_argument("--bf16", action="store_true")
    return parser.parse_args()


def load_preference_dataset(train_file: str, dev_file: str) -> DatasetDict:
    """Load the train/dev JSON files and keep only the columns in ``KEEP_COLUMNS``.

    Returns a ``DatasetDict`` with splits ``"train"`` and ``"dev"``. A kept column
    that is absent from a file is simply not present in that split.
    """
    dataset = load_dataset("json", data_files={"train": train_file, "dev": dev_file})
    # ``Dataset.map`` merges a returned dict into the existing row and cannot drop
    # columns, so the extra columns are removed explicitly, split by split.
    return DatasetDict(
        {
            split_name: split.remove_columns(
                [col for col in split.column_names if col not in KEEP_COLUMNS]
            )
            for split_name, split in dataset.items()
        }
    )


def cache_paths(cache_dir: str | None):
    """Return ``(root, train_dir, dev_dir)`` for the ref-log-prob cache, or three ``None``s."""
    if not cache_dir:
        return None, None, None
    root = Path(cache_dir)
    return root, root / "train_dataset", root / "dev_dataset"


def _cache_mismatches(summary_path: Path, expected: dict) -> list[str]:
    """Return the identity keys whose cached value differs from ``expected``.

    Only keys present in both the cached summary and ``expected`` are compared.
    Caches written before a key was recorded therefore remain usable. A missing
    summary file yields no mismatches, because there is nothing to compare against.
    """
    if not summary_path.exists():
        return []
    try:
        with open(summary_path, encoding="utf-8") as f:
            cached = json.load(f)
    except (OSError, json.JSONDecodeError):
        return [f"{summary_path.name} (unreadable)"]
    return [
        key
        for key in CACHE_IDENTITY_KEYS
        if key in cached and key in expected and cached[key] != expected[key]
    ]


def try_load_cached_ref_datasets(cache_dir: str | None, expected: dict | None = None):
    """Load cached train/dev datasets that already carry reference log-prob columns.

    Args:
        cache_dir: Cache root directory, or ``None`` to disable caching.
        expected: Optional current run settings (see ``CACHE_IDENTITY_KEYS``). When
            given, a cache whose ``cache_summary.json`` disagrees is ignored rather
            than silently reused.

    Returns:
        ``{"train": Dataset, "dev": Dataset}``, or ``None`` when there is no usable cache.
    """
    root, train_dir, dev_dir = cache_paths(cache_dir)
    if root is None or not train_dir.exists() or not dev_dir.exists():
        return None
    if expected is not None:
        mismatches = _cache_mismatches(root / "cache_summary.json", expected)
        if mismatches:
            print(
                f"Ignoring reference log-prob cache at {root}; "
                f"settings differ from this run: {', '.join(mismatches)}"
            )
            return None
    train_dataset = load_from_disk(str(train_dir))
    dev_dataset = load_from_disk(str(dev_dir))
    if REF_LOGP_COLUMNS.issubset(train_dataset.column_names) and REF_LOGP_COLUMNS.issubset(
        dev_dataset.column_names
    ):
        print(f"Loaded cached reference log-prob datasets from: {root}")
        return {"train": train_dataset, "dev": dev_dataset}
    return None


def save_cached_ref_datasets(cache_dir: str | None, train_dataset, dev_dataset, summary: dict) -> None:
    """Write both datasets and a JSON summary under ``cache_dir``; a no-op when ``cache_dir`` is falsy."""
    root, train_dir, dev_dir = cache_paths(cache_dir)
    if root is None:
        return
    root.mkdir(parents=True, exist_ok=True)
    train_dataset.save_to_disk(str(train_dir))
    dev_dataset.save_to_disk(str(dev_dir))
    with open(root / "cache_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print(f"Saved cached reference log-prob datasets to: {root}")


def build_model_and_tokenizer(args: argparse.Namespace):
    """Load the tokenizer and a 4-bit NF4 model on this process's ``LOCAL_RANK`` device.

    The model is prepared for k-bit training with non-reentrant gradient
    checkpointing, and the KV cache is disabled.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    compute_dtype = torch.bfloat16 if args.bf16 else torch.float16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=compute_dtype,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        quantization_config=bnb_config,
        torch_dtype=compute_dtype,
        attn_implementation=args.attn_implementation,
        device_map={"": local_rank},
    )
    model.config.use_cache = False
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    # Re-enable checkpointing explicitly so the non-reentrant variant is used.
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    return model, tokenizer


def _cache_identity(args: argparse.Namespace) -> dict:
    """Return the run settings stored in, and checked against, the ref-log-prob cache."""
    return {
        "model_name": args.model_name,
        "train_file": args.train_file,
        "dev_file": args.dev_file,
        "max_length": args.max_length,
        "max_prompt_length": args.max_prompt_length,
        "max_completion_length": args.max_completion_length,
        "truncation_mode": args.truncation_mode,
    }


def _build_dpo_config(
    args: argparse.Namespace,
    output_dir: Path,
    enable_best_model: bool,
    precompute_ref: bool,
) -> DPOConfig:
    """Build a ``DPOConfig``, passing only the fields this TRL version's config accepts."""
    dpo_sig = inspect.signature(DPOConfig.__init__).parameters
    dpo_kwargs = dict(
        output_dir=str(output_dir),
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        lr_scheduler_type="cosine",
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        optim=args.optim,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_strategy="steps",
        save_total_limit=2,
        bf16=args.bf16,
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=2,
        beta=args.beta,
        truncation_mode=args.truncation_mode,
        padding_free=args.padding_free,
        use_logits_to_keep=args.use_logits_to_keep,
        torch_empty_cache_steps=args.torch_empty_cache_steps,
        load_best_model_at_end=enable_best_model,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
    )
    # Older releases name this option ``evaluation_strategy``.
    if "eval_strategy" in dpo_sig:
        dpo_kwargs["eval_strategy"] = "steps"
    elif "evaluation_strategy" in dpo_sig:
        dpo_kwargs["evaluation_strategy"] = "steps"
    if args.max_completion_length is not None:
        dpo_kwargs["max_completion_length"] = args.max_completion_length
    if precompute_ref:
        dpo_kwargs["precompute_ref_log_probs"] = True
    if args.reference_free:
        dpo_kwargs["reference_free"] = True
    # Drop options that this DPOConfig does not define.
    return DPOConfig(**{k: v for k, v in dpo_kwargs.items() if k in dpo_sig})


def _build_trainer(
    args: argparse.Namespace,
    model,
    tokenizer,
    train_args: DPOConfig,
    dataset,
    peft_config: LoraConfig,
    precompute_ref: bool,
) -> DPOTrainer:
    """Construct the ``DPOTrainer``, adding legacy constructor kwargs only if they are accepted."""
    trainer_sig = inspect.signature(DPOTrainer.__init__).parameters
    trainer_kwargs = dict(
        model=model,
        ref_model=None,
        args=train_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["dev"],
        peft_config=peft_config,
    )
    # The tokenizer is passed exactly once, under whichever name this version uses.
    # Supplying ``tokenizer`` through two kwarg dicts would raise TypeError.
    if "processing_class" in trainer_sig:
        trainer_kwargs["processing_class"] = tokenizer
    elif "tokenizer" in trainer_sig:
        trainer_kwargs["tokenizer"] = tokenizer

    # Some TRL versions take these on the trainer rather than (or as well as) on the config.
    legacy_kwargs = {
        "max_length": args.max_length,
        "max_prompt_length": args.max_prompt_length,
        "truncation_mode": args.truncation_mode,
    }
    if args.max_completion_length is not None:
        legacy_kwargs["max_completion_length"] = args.max_completion_length
    if precompute_ref:
        legacy_kwargs["precompute_ref_log_probs"] = True
    if args.reference_free:
        legacy_kwargs["reference_free"] = True
    if args.padding_free:
        legacy_kwargs["padding_free"] = True
    if args.use_logits_to_keep:
        legacy_kwargs["use_logits_to_keep"] = True
    trainer_kwargs.update({k: v for k, v in legacy_kwargs.items() if k in trainer_sig})
    return DPOTrainer(**trainer_kwargs)


def _precompute_and_cache_ref_logps(args: argparse.Namespace, trainer: DPOTrainer) -> None:
    """Trigger the reference log-prob precompute now and persist the resulting datasets.

    The dataloaders are built on every process. Only the main process writes to
    disk, so ranks do not race on the same cache directory.
    """
    # With precompute enabled, building the dataloaders is expected to add the
    # ref log-prob columns that are checked below.
    trainer.get_train_dataloader()
    trainer.get_eval_dataloader()
    train_cols = set(trainer.train_dataset.column_names)
    dev_cols = set(trainer.eval_dataset.column_names)
    if not (REF_LOGP_COLUMNS <= train_cols and REF_LOGP_COLUMNS <= dev_cols):
        print("Warning: ref log-prob columns were not found after precompute; skipping cache save.")
        return
    if not trainer.is_world_process_zero():
        return
    summary = {
        **_cache_identity(args),
        "num_train_examples": len(trainer.train_dataset),
        "num_dev_examples": len(trainer.eval_dataset),
    }
    save_cached_ref_datasets(
        args.ref_logps_cache_dir,
        trainer.train_dataset,
        trainer.eval_dataset,
        summary,
    )


def _write_run_summary(
    args: argparse.Namespace,
    output_dir: Path,
    dataset,
    using_cached_ref_logps: bool,
) -> None:
    """Write ``run_summary.json`` describing the finished run into ``output_dir``."""
    summary = {
        "model_name": args.model_name,
        "train_file": args.train_file,
        "dev_file": args.dev_file,
        "output_dir": str(output_dir),
        "ref_logps_cache_dir": args.ref_logps_cache_dir,
        "used_cached_ref_logps": using_cached_ref_logps,
        "reference_free": args.reference_free,
        "num_train_examples": len(dataset["train"]),
        "num_dev_examples": len(dataset["dev"]),
        "max_length": args.max_length,
        "max_prompt_length": args.max_prompt_length,
        "max_completion_length": args.max_completion_length,
        "truncation_mode": args.truncation_mode,
        "padding_free": args.padding_free,
        "use_logits_to_keep": args.use_logits_to_keep,
        "precompute_ref_log_probs": args.precompute_ref_log_probs,
        "beta": args.beta,
        "learning_rate": args.learning_rate,
        "epochs": args.num_train_epochs,
    }
    with open(output_dir / "run_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)


def main() -> None:
    """Build the data, model and trainer; optionally cache reference log-probs; then train."""
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    use_reference_side = not args.reference_free
    cached_dataset = (
        try_load_cached_ref_datasets(args.ref_logps_cache_dir, expected=_cache_identity(args))
        if use_reference_side
        else None
    )
    using_cached_ref_logps = cached_dataset is not None
    dataset = (
        cached_dataset
        if using_cached_ref_logps
        else load_preference_dataset(args.train_file, args.dev_file)
    )
    # A single flag drives both the config and the trainer paths, so they cannot
    # disagree. Precompute only when a reference model is used and nothing is cached.
    precompute_ref = (
        args.precompute_ref_log_probs and use_reference_side and not using_cached_ref_logps
    )
    if args.precompute_only and not using_cached_ref_logps and not (
        args.ref_logps_cache_dir and precompute_ref
    ):
        print(
            "Warning: --precompute-only needs --ref-logps-cache-dir and "
            "--precompute-ref-log-probs (without --reference-free); nothing will be cached."
        )

    model, tokenizer = build_model_and_tokenizer(args)
    enable_best_model = (
        args.load_best_model_at_end
        and args.eval_steps > 0
        and args.save_steps % args.eval_steps == 0
    )
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )
    train_args = _build_dpo_config(args, output_dir, enable_best_model, precompute_ref)
    trainer = _build_trainer(
        args, model, tokenizer, train_args, dataset, peft_config, precompute_ref
    )

    if args.ref_logps_cache_dir and precompute_ref:
        _precompute_and_cache_ref_logps(args, trainer)

    if args.precompute_only:
        print("Precompute-only mode finished.")
        return

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    trainer.save_model()
    # Tokenizer and summary are plain file writes; do them once, not once per rank.
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(output_dir)
        _write_run_summary(args, output_dir, dataset, using_cached_ref_logps)


if __name__ == "__main__":
    main()