Files
llmiotsafe/SFT_FULLLOG_QWEN2B_V2_SERVER_BUNDLE/scripts/train_sft.py
2026-05-12 17:01:39 +08:00

306 lines
12 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import inspect
import json
import os
import statistics
from dataclasses import dataclass
from pathlib import Path
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Trainer, TrainingArguments
def parse_args() -> argparse.Namespace:
    """Define the command-line interface for QLoRA SFT and parse ``sys.argv``.

    Returns:
        Namespace holding model/data paths, optimisation hyper-parameters,
        sequence-length budgets, LoRA settings and runtime flags.
    """
    cli = argparse.ArgumentParser(description="QLoRA SFT training for Qwen3.5-2B on full-log SafeHome data.")

    # Mandatory model identifier and file/directory paths.
    for flag in ("--model-name", "--train-file", "--dev-file", "--output-dir"):
        cli.add_argument(flag, required=True)

    # Optional settings as (flag, add_argument kwargs); registration order
    # matches the order shown by ``--help``.
    option_specs = [
        ("--learning-rate", dict(type=float, default=1e-5)),
        ("--num-train-epochs", dict(type=float, default=1.0)),
        ("--per-device-train-batch-size", dict(type=int, default=1)),
        ("--per-device-eval-batch-size", dict(type=int, default=1)),
        ("--gradient-accumulation-steps", dict(type=int, default=16)),
        ("--max-length", dict(type=int, default=4096)),
        ("--max-prompt-length", dict(type=int, default=3584)),
        ("--max-completion-length", dict(type=int, default=512)),
        ("--eval-steps", dict(type=int, default=50)),
        ("--save-steps", dict(type=int, default=50)),
        ("--logging-steps", dict(type=int, default=5)),
        ("--warmup-ratio", dict(type=float, default=0.03)),
        ("--weight-decay", dict(type=float, default=0.0)),
        ("--optim", dict(default="paged_adamw_8bit")),
        ("--seed", dict(type=int, default=42)),
        ("--load-best-model-at-end", dict(action="store_true")),
        ("--lora-r", dict(type=int, default=16)),
        ("--lora-alpha", dict(type=int, default=32)),
        ("--lora-dropout", dict(type=float, default=0.05)),
        ("--attn-implementation", dict(default="sdpa", choices=["sdpa", "eager", "flash_attention_2"])),
        ("--resume-from-checkpoint", dict(default=None)),
        ("--bf16", dict(action="store_true")),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def load_raw_dataset(train_file: str, dev_file: str):
    """Load the train/dev JSON files with Hugging Face ``datasets``.

    The datasets cache directory is ``.hf_datasets_cache`` inside the parent of
    the directory that contains ``train_file`` (after resolving the path); it is
    created if missing. ``download_mode="force_redownload"`` makes ``datasets``
    re-prepare the data on every call rather than reuse a previously built cache.

    Returns:
        The ``load_dataset`` result with ``"train"`` and ``"dev"`` splits.
    """
    split_files = {"train": train_file, "dev": dev_file}
    cache_dir = Path(train_file).resolve().parent.parent / ".hf_datasets_cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return load_dataset("json", data_files=split_files, cache_dir=str(cache_dir), download_mode="force_redownload")
def build_model_and_tokenizer(args: argparse.Namespace):
    """Load the tokenizer and a 4-bit NF4-quantised causal LM wrapped with LoRA.

    The whole model is mapped onto device index ``LOCAL_RANK`` (``0`` when the
    environment variable is unset). It is prepared for k-bit training with
    non-reentrant gradient checkpointing, and LoRA adapters are attached to the
    q/k/v/o and gate/up/down projection modules. The compute dtype is bfloat16
    with ``--bf16`` and float16 otherwise.

    Returns:
        ``(model, tokenizer)``: the PEFT-wrapped model, and a tokenizer that has
        a pad token (EOS is reused if none exists) and pads on the right.
    """
    rank = int(os.environ.get("LOCAL_RANK", "0"))
    if args.bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16

    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=dtype,
    )

    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True, use_fast=False)
    if tok.pad_token is None:
        # No dedicated pad token: reuse EOS so a pad id always exists.
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"

    lm = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        quantization_config=quant_cfg,
        torch_dtype=dtype,
        attn_implementation=args.attn_implementation,
        device_map={"": rank},
    )
    lm.config.use_cache = False
    lm = prepare_model_for_kbit_training(lm, use_gradient_checkpointing=True)
    # Re-enable checkpointing explicitly so the non-reentrant variant is used.
    lm.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    adapter_cfg = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )
    lm = get_peft_model(lm, adapter_cfg)
    lm.print_trainable_parameters()
    return lm, tok
def _normalize_messages(messages):
    """Coerce a row's ``messages`` value into a list of ``{"role", "content"}`` dicts.

    Accepted shapes:

    * a dict with optional ``"system"`` / ``"user"`` keys -> up to two messages.
      Missing, ``None`` or blank values are skipped. ``None`` is checked
      explicitly because ``str(None)`` would inject the literal text "None".
      NOTE(review): this shape never yields an assistant turn, so such rows are
      always dropped downstream; confirm whether that is intended.
    * an iterable of message dicts -> entries without a truthy ``"role"`` are
      dropped, and a missing/``None`` ``"content"`` becomes ``""``. Bare
      strings (malformed legacy rows like ``["system", "user"]``) are skipped
      instead of crashing; they should be fixed at dataset-build time.
    * ``None`` -> ``[]``.
    """
    if isinstance(messages, dict):
        normalized = []
        for role in ("system", "user"):
            value = messages.get(role)
            text = "" if value is None else str(value).strip()
            if text:
                normalized.append({"role": role, "content": text})
        return normalized

    normalized = []
    for item in messages or ():
        if not isinstance(item, dict):
            continue
        role = item.get("role")
        if role:
            content = item.get("content")
            normalized.append({"role": role, "content": "" if content is None else content})
    return normalized


def _as_token_ids(encoded):
    """Return a plain list of token ids from ``apply_chat_template`` output.

    Depending on the tokenizer/transformers version the result may be a
    mapping such as ``BatchEncoding`` rather than a list. In that case its
    ``"input_ids"`` entry is used.
    """
    from collections.abc import Mapping

    if isinstance(encoded, Mapping):
        encoded = encoded["input_ids"]
    return list(encoded)


def _truncate_pair(prompt_ids, response_ids, max_prompt_length, max_completion_length, max_length):
    """Fit a prompt/response token pair into the configured length budgets.

    The prompt keeps its *last* ``max_prompt_length`` tokens and the response
    keeps its *first* ``max_completion_length`` tokens. If the pair still
    exceeds ``max_length``, more prompt tokens are dropped from the left. The
    response is cut from the right only when the prompt alone cannot absorb the
    overflow. Non-positive limits keep nothing; a plain ``ids[-0:]`` slice would
    keep everything.

    Returns:
        ``(prompt_ids, response_ids)`` as new lists.
    """
    max_prompt_length = max(0, max_prompt_length)
    max_completion_length = max(0, max_completion_length)
    max_length = max(0, max_length)

    prompt_ids = prompt_ids[max(0, len(prompt_ids) - max_prompt_length):]
    response_ids = response_ids[:max_completion_length]
    overflow = len(prompt_ids) + len(response_ids) - max_length
    if overflow > 0:
        prompt_ids = prompt_ids[overflow:]
        # A no-op unless the prompt was fully consumed and the response alone is too long.
        response_ids = response_ids[: max(0, max_length - len(prompt_ids))]
    return prompt_ids, response_ids


def _summarize_lengths(values):
    """Return count/min/mean/p50/max for a non-empty list of lengths (p50 is the upper median)."""
    ordered = sorted(values)
    return {
        "count": len(ordered),
        "min": ordered[0],
        "mean": round(statistics.mean(ordered), 1),
        "p50": ordered[len(ordered) // 2],
        "max": ordered[-1],
    }


def tokenize_datasets(dataset, tokenizer, args: argparse.Namespace):
    """Tokenize every split into prompt-masked SFT examples and collect length stats.

    Each row's ``messages`` is normalized with ``_normalize_messages``. A row is
    dropped if its last message is not an assistant turn, if it has no user turn
    before that, or if no response tokens survive truncation. The prompt (all
    messages but the last, rendered with the generation prompt) is labelled
    ``-100`` so that only the assistant response contributes to the loss.

    Args:
        dataset: the ``DatasetDict``-like result of ``load_raw_dataset``; every
            split must have a ``messages`` column.
        tokenizer: tokenizer providing ``apply_chat_template`` and ``__call__``.
        args: parsed CLI args. Reads ``max_prompt_length``,
            ``max_completion_length`` and ``max_length``.

    Returns:
        ``(tokenized, summary)``:

        * ``tokenized`` has only ``input_ids``, ``attention_mask`` and ``labels`` columns.
        * ``summary`` maps each length statistic that has values to a
          count/min/mean/p50/max dict.

        Stats are pooled over all splits, and include rows later removed because
        their response was truncated to nothing.
    """
    from datasets import DatasetDict

    stats = {
        "prompt_before": [],
        "prompt_after": [],
        "completion_before": [],
        "completion_after": [],
        "full_after": [],
    }

    def empty_example():
        return {"input_ids": [], "attention_mask": [], "labels": []}

    def tokenize_example(example):
        messages = _normalize_messages(example["messages"])
        if len(messages) < 2 or messages[-1].get("role") != "assistant":
            return empty_example()
        if not any(m.get("role") == "user" for m in messages[:-1]):
            return empty_example()

        prompt_ids = _as_token_ids(
            tokenizer.apply_chat_template(messages[:-1], tokenize=True, add_generation_prompt=True)
        )
        full_ids = _as_token_ids(
            tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False)
        )
        if full_ids[: len(prompt_ids)] != prompt_ids:
            # Token-level prefix mismatch: fall back to tokenizing the rendered text.
            # NOTE(review): the prefix property is not re-checked after this fallback.
            prompt_text = tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
            full_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            prompt_ids = list(tokenizer(prompt_text, add_special_tokens=False)["input_ids"])
            full_ids = list(tokenizer(full_text, add_special_tokens=False)["input_ids"])
        response_ids = full_ids[len(prompt_ids):]

        stats["prompt_before"].append(len(prompt_ids))
        stats["completion_before"].append(len(response_ids))

        prompt_ids, response_ids = _truncate_pair(
            prompt_ids,
            response_ids,
            args.max_prompt_length,
            args.max_completion_length,
            args.max_length,
        )
        input_ids = prompt_ids + response_ids
        labels = [-100] * len(prompt_ids) + response_ids

        stats["prompt_after"].append(len(prompt_ids))
        stats["completion_after"].append(len(response_ids))
        stats["full_after"].append(len(input_ids))
        return {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            "labels": labels,
        }

    # Stats are gathered as a side effect of ``tokenize_example``, which only
    # works while ``map`` runs it in this process (``num_proc`` is unset) and
    # does not short-circuit via a cache hit, hence ``load_from_cache_file=False``.
    # Each split drops its own original columns, so train/dev schemas may differ.
    tokenized = DatasetDict(
        {
            name: split.map(tokenize_example, remove_columns=split.column_names, load_from_cache_file=False)
            for name, split in dataset.items()
        }
    )
    tokenized = tokenized.filter(
        lambda row: len(row["input_ids"]) > 0 and any(v != -100 for v in row["labels"])
    )

    summary = {key: _summarize_lengths(values) for key, values in stats.items() if values}
    return tokenized, summary
@dataclass
class SupervisedDataCollator:
    """Right-pad a batch of pre-tokenized SFT features into ``torch.long`` tensors.

    Every sequence is padded to the longest ``input_ids`` in the batch:
    ``input_ids`` with ``tokenizer.pad_token_id``, ``attention_mask`` with ``0``
    and ``labels`` with ``-100`` so that padding is ignored by the loss.
    """

    tokenizer: AutoTokenizer  # only ``pad_token_id`` is read

    def __call__(self, features):
        fill_id = self.tokenizer.pad_token_id
        longest = max(len(feature["input_ids"]) for feature in features)
        # Per-key fill value; the key order also fixes the output dict order.
        fill_values = {"input_ids": fill_id, "attention_mask": 0, "labels": -100}
        columns = {key: [] for key in fill_values}
        for feature in features:
            gap = longest - len(feature["input_ids"])
            for key, fill in fill_values.items():
                columns[key].append(feature[key] + [fill] * gap)
        return {key: torch.tensor(rows, dtype=torch.long) for key, rows in columns.items()}
def _build_training_arguments(args: argparse.Namespace, output_dir: Path) -> TrainingArguments:
    """Create ``TrainingArguments`` that work across transformers versions.

    The evaluation-strategy keyword is chosen by name: ``eval_strategy`` if
    present, otherwise ``evaluation_strategy``. All keywords are then filtered
    against the installed ``TrainingArguments.__init__`` signature.
    ``load_best_model_at_end`` is enabled only when requested, ``eval_steps > 0``,
    and ``save_steps`` is a multiple of ``eval_steps``. The last condition
    presumably satisfies transformers' save/eval step alignment check.
    """
    kwargs = dict(
        output_dir=str(output_dir),
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_strategy="steps",
        save_total_limit=2,
        bf16=args.bf16,
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=2,
        lr_scheduler_type="cosine",
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        optim=args.optim,
        seed=args.seed,
        data_seed=args.seed,
        greater_is_better=False,
        metric_for_best_model="eval_loss",
        ddp_find_unused_parameters=False,
    )
    accepted = inspect.signature(TrainingArguments.__init__).parameters
    if "eval_strategy" in accepted:
        kwargs["eval_strategy"] = "steps"
    elif "evaluation_strategy" in accepted:
        kwargs["evaluation_strategy"] = "steps"
    if "load_best_model_at_end" in accepted:
        kwargs["load_best_model_at_end"] = (
            args.load_best_model_at_end and args.eval_steps > 0 and args.save_steps % args.eval_steps == 0
        )
    # Drop keywords this transformers version does not know about.
    kwargs = {key: value for key, value in kwargs.items() if key in accepted}
    return TrainingArguments(**kwargs)


def main() -> None:
    """Run QLoRA SFT end to end: parse args, load data and model, tokenize, train, save.

    Artifacts written to ``--output-dir``:

    * the trained model, via ``Trainer.save_model``;
    * the tokenizer;
    * ``run_summary.json`` with the run configuration and token-length stats.

    ``save_model`` is called on every process. The tokenizer and summary are
    written only by the global rank-0 process, so that multi-process (DDP)
    runs do not write the same files concurrently.
    """
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    raw_dataset = load_raw_dataset(args.train_file, args.dev_file)
    model, tokenizer = build_model_and_tokenizer(args)
    tokenized_dataset, token_stats = tokenize_datasets(raw_dataset, tokenizer, args)

    trainer = Trainer(
        model=model,
        args=_build_training_arguments(args, output_dir),
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["dev"],
        data_collator=SupervisedDataCollator(tokenizer),
    )
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    # Called on all ranks; the Trainer decides internally which process writes.
    trainer.save_model()

    if not trainer.is_world_process_zero():
        return

    tokenizer.save_pretrained(output_dir)
    summary = {
        "model_name": args.model_name,
        "train_file": args.train_file,
        "dev_file": args.dev_file,
        "output_dir": str(output_dir),
        "num_train_examples": len(tokenized_dataset["train"]),
        "num_dev_examples": len(tokenized_dataset["dev"]),
        "max_length": args.max_length,
        "max_prompt_length": args.max_prompt_length,
        "max_completion_length": args.max_completion_length,
        "learning_rate": args.learning_rate,
        "epochs": args.num_train_epochs,
        "token_stats": token_stats,
    }
    with (output_dir / "run_summary.json").open("w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()