#!/usr/bin/env python3 from __future__ import annotations import argparse import inspect import json import os import statistics from dataclasses import dataclass from pathlib import Path import torch from datasets import load_dataset from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Trainer, TrainingArguments def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="QLoRA SFT training for Qwen3.5-2B on full-log SafeHome data.") parser.add_argument("--model-name", required=True) parser.add_argument("--train-file", required=True) parser.add_argument("--dev-file", required=True) parser.add_argument("--output-dir", required=True) parser.add_argument("--learning-rate", type=float, default=1e-5) parser.add_argument("--num-train-epochs", type=float, default=1.0) parser.add_argument("--per-device-train-batch-size", type=int, default=1) parser.add_argument("--per-device-eval-batch-size", type=int, default=1) parser.add_argument("--gradient-accumulation-steps", type=int, default=16) parser.add_argument("--max-length", type=int, default=4096) parser.add_argument("--max-prompt-length", type=int, default=3584) parser.add_argument("--max-completion-length", type=int, default=512) parser.add_argument("--eval-steps", type=int, default=50) parser.add_argument("--save-steps", type=int, default=50) parser.add_argument("--logging-steps", type=int, default=5) parser.add_argument("--warmup-ratio", type=float, default=0.03) parser.add_argument("--weight-decay", type=float, default=0.0) parser.add_argument("--optim", default="paged_adamw_8bit") parser.add_argument("--seed", type=int, default=42) parser.add_argument("--load-best-model-at-end", action="store_true") parser.add_argument("--lora-r", type=int, default=16) parser.add_argument("--lora-alpha", type=int, default=32) parser.add_argument("--lora-dropout", type=float, default=0.05) parser.add_argument("--attn-implementation", default="sdpa", choices=["sdpa", "eager", "flash_attention_2"]) parser.add_argument("--resume-from-checkpoint", default=None) parser.add_argument("--bf16", action="store_true") return parser.parse_args() def load_raw_dataset(train_file: str, dev_file: str): return load_dataset("json", data_files={"train": train_file, "dev": dev_file}) def build_model_and_tokenizer(args: argparse.Namespace): local_rank = int(os.environ.get("LOCAL_RANK", "0")) compute_dtype = torch.bfloat16 if args.bf16 else torch.float16 bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=compute_dtype, ) tokenizer = AutoTokenizer.from_pretrained( args.model_name, trust_remote_code=True, use_fast=False, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" model = AutoModelForCausalLM.from_pretrained( args.model_name, trust_remote_code=True, quantization_config=bnb_config, torch_dtype=compute_dtype, attn_implementation=args.attn_implementation, device_map={"": local_rank}, ) model.config.use_cache = False model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) peft_config = LoraConfig( r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, bias="none", task_type="CAUSAL_LM", target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], ) model = get_peft_model(model, peft_config) model.print_trainable_parameters() return model, tokenizer def tokenize_datasets(dataset, tokenizer, args: argparse.Namespace): stats = {"prompt_before": [], "prompt_after": [], "completion_before": [], "completion_after": [], "full_after": []} def tokenize_example(example): messages = example["messages"] prompt_ids = tokenizer.apply_chat_template( messages[:-1], tokenize=True, add_generation_prompt=True, ) full_ids = tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=False, ) if len(full_ids) < len(prompt_ids) or full_ids[:len(prompt_ids)] != prompt_ids: prompt_text = tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True) full_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"] full_ids = tokenizer(full_text, add_special_tokens=False)["input_ids"] response_ids = full_ids[len(prompt_ids):] stats["prompt_before"].append(len(prompt_ids)) stats["completion_before"].append(len(response_ids)) prompt_ids = prompt_ids[-args.max_prompt_length:] response_ids = response_ids[:args.max_completion_length] if len(prompt_ids) + len(response_ids) > args.max_length: overflow = len(prompt_ids) + len(response_ids) - args.max_length prompt_ids = prompt_ids[overflow:] if len(prompt_ids) + len(response_ids) > args.max_length: response_ids = response_ids[: max(0, args.max_length - len(prompt_ids))] input_ids = prompt_ids + response_ids labels = ([-100] * len(prompt_ids)) + response_ids attention_mask = [1] * len(input_ids) stats["prompt_after"].append(len(prompt_ids)) stats["completion_after"].append(len(response_ids)) stats["full_after"].append(len(input_ids)) return { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, } remove_columns = dataset["train"].column_names tokenized = dataset.map(tokenize_example, remove_columns=remove_columns) tokenized = tokenized.filter(lambda x: len(x["input_ids"]) > 0 and any(v != -100 for v in x["labels"])) summary = {} for key, values in stats.items(): if values: summary[key] = { "count": len(values), "min": min(values), "mean": round(statistics.mean(values), 1), "p50": sorted(values)[len(values) // 2], "max": max(values), } return tokenized, summary @dataclass class SupervisedDataCollator: tokenizer: AutoTokenizer def __call__(self, features): pad_id = self.tokenizer.pad_token_id max_len = max(len(f["input_ids"]) for f in features) input_ids = [] attention_mask = [] labels = [] for f in features: pad_len = max_len - len(f["input_ids"]) input_ids.append(f["input_ids"] + ([pad_id] * pad_len)) attention_mask.append(f["attention_mask"] + ([0] * pad_len)) labels.append(f["labels"] + ([-100] * pad_len)) return { "input_ids": torch.tensor(input_ids, dtype=torch.long), "attention_mask": torch.tensor(attention_mask, dtype=torch.long), "labels": torch.tensor(labels, dtype=torch.long), } def main() -> None: args = parse_args() output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) raw_dataset = load_raw_dataset(args.train_file, args.dev_file) model, tokenizer = build_model_and_tokenizer(args) tokenized_dataset, token_stats = tokenize_datasets(raw_dataset, tokenizer, args) train_args_kwargs = dict( output_dir=str(output_dir), learning_rate=args.learning_rate, num_train_epochs=args.num_train_epochs, per_device_train_batch_size=args.per_device_train_batch_size, per_device_eval_batch_size=args.per_device_eval_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}, logging_steps=args.logging_steps, save_steps=args.save_steps, eval_steps=args.eval_steps, save_strategy="steps", save_total_limit=2, bf16=args.bf16, report_to="none", remove_unused_columns=False, dataloader_num_workers=2, lr_scheduler_type="cosine", warmup_ratio=args.warmup_ratio, weight_decay=args.weight_decay, optim=args.optim, seed=args.seed, data_seed=args.seed, greater_is_better=False, metric_for_best_model="eval_loss", ddp_find_unused_parameters=False, ) sig = inspect.signature(TrainingArguments.__init__).parameters if "eval_strategy" in sig: train_args_kwargs["eval_strategy"] = "steps" elif "evaluation_strategy" in sig: train_args_kwargs["evaluation_strategy"] = "steps" enable_best = args.load_best_model_at_end and args.eval_steps > 0 and args.save_steps % args.eval_steps == 0 if "load_best_model_at_end" in sig: train_args_kwargs["load_best_model_at_end"] = enable_best train_args_kwargs = {k: v for k, v in train_args_kwargs.items() if k in sig} training_args = TrainingArguments(**train_args_kwargs) trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_dataset["train"], eval_dataset=tokenized_dataset["dev"], data_collator=SupervisedDataCollator(tokenizer), ) trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) trainer.save_model() tokenizer.save_pretrained(output_dir) summary = { "model_name": args.model_name, "train_file": args.train_file, "dev_file": args.dev_file, "output_dir": str(output_dir), "num_train_examples": len(tokenized_dataset["train"]), "num_dev_examples": len(tokenized_dataset["dev"]), "max_length": args.max_length, "max_prompt_length": args.max_prompt_length, "max_completion_length": args.max_completion_length, "learning_rate": args.learning_rate, "epochs": args.num_train_epochs, "token_stats": token_stats, } with (output_dir / "run_summary.json").open("w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) if __name__ == "__main__": main()