270 lines
10 KiB
Python
270 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import inspect
|
|
import json
|
|
import os
|
|
import statistics
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from datasets import load_dataset
|
|
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Trainer, TrainingArguments
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for QLoRA SFT training.

    Options are parsed from ``sys.argv``.

    Returns:
        The parsed namespace; dashed option names become underscored
        attributes (e.g. ``--max-prompt-length`` -> ``max_prompt_length``).
    """
    cli = argparse.ArgumentParser(
        description="QLoRA SFT training for Qwen3.5-2B on full-log SafeHome data."
    )

    # Mandatory locations: base model, data splits and output directory.
    for flag in ("--model-name", "--train-file", "--dev-file", "--output-dir"):
        cli.add_argument(flag, required=True)

    # Optimisation, length and schedule knobs as (flag, type, default).
    typed_defaults = (
        ("--learning-rate", float, 1e-5),
        ("--num-train-epochs", float, 1.0),
        ("--per-device-train-batch-size", int, 1),
        ("--per-device-eval-batch-size", int, 1),
        ("--gradient-accumulation-steps", int, 16),
        ("--max-length", int, 4096),
        ("--max-prompt-length", int, 3584),
        ("--max-completion-length", int, 512),
        ("--eval-steps", int, 50),
        ("--save-steps", int, 50),
        ("--logging-steps", int, 5),
        ("--warmup-ratio", float, 0.03),
        ("--weight-decay", float, 0.0),
    )
    for flag, caster, default in typed_defaults:
        cli.add_argument(flag, type=caster, default=default)

    cli.add_argument("--optim", default="paged_adamw_8bit")
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--load-best-model-at-end", action="store_true")

    # LoRA adapter hyper-parameters.
    lora_defaults = (
        ("--lora-r", int, 16),
        ("--lora-alpha", int, 32),
        ("--lora-dropout", float, 0.05),
    )
    for flag, caster, default in lora_defaults:
        cli.add_argument(flag, type=caster, default=default)

    cli.add_argument(
        "--attn-implementation",
        default="sdpa",
        choices=["sdpa", "eager", "flash_attention_2"],
    )
    cli.add_argument("--resume-from-checkpoint", default=None)
    cli.add_argument("--bf16", action="store_true")

    return cli.parse_args()
|
|
|
|
|
|
def load_raw_dataset(train_file: str, dev_file: str):
    """Load the train and dev splits with the ``datasets`` JSON builder.

    Args:
        train_file: path to the training data file.
        dev_file: path to the development (evaluation) data file.

    Returns:
        Whatever ``load_dataset("json", data_files=...)`` returns for a
        split->file mapping, with splits named ``"train"`` and ``"dev"``.
    """
    split_to_file = {"train": train_file, "dev": dev_file}
    return load_dataset("json", data_files=split_to_file)
|
|
|
|
|
|
def build_model_and_tokenizer(args: argparse.Namespace):
    """Load the tokenizer and a 4-bit NF4 base model wrapped with LoRA adapters.

    Args:
        args: parsed CLI namespace; reads ``model_name``, ``bf16``,
            ``attn_implementation``, ``lora_r``, ``lora_alpha`` and
            ``lora_dropout``.

    Returns:
        ``(model, tokenizer)``: the PEFT-wrapped model and its tokenizer.
    """
    # Each process places the whole model on the device indexed by LOCAL_RANK.
    device_index = int(os.environ.get("LOCAL_RANK", "0"))
    dtype = torch.bfloat16 if args.bf16 else torch.float16

    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=dtype,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        use_fast=False,
    )
    # Guarantee a pad token for the collator; reuse EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        quantization_config=quantization,
        torch_dtype=dtype,
        attn_implementation=args.attn_implementation,
        device_map={"": device_index},
    )
    # No KV cache during training.
    base_model.config.use_cache = False
    base_model = prepare_model_for_kbit_training(base_model, use_gradient_checkpointing=True)
    # Re-enable checkpointing explicitly with the non-reentrant implementation.
    base_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    # NOTE(review): confirm these projection names exist in the target architecture's layers.
    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp_projections = ["gate_proj", "up_proj", "down_proj"]
    lora = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=attention_projections + mlp_projections,
    )
    peft_model = get_peft_model(base_model, lora)
    peft_model.print_trainable_parameters()
    return peft_model, tokenizer
|
|
|
|
|
|
def tokenize_datasets(dataset, tokenizer, args: argparse.Namespace):
    """Turn chat-formatted splits into prompt-masked causal-LM features.

    Every example must carry a ``messages`` list. All messages except the last
    form the prompt (rendered with the generation prompt appended). The last
    message is the supervised completion. Prompt positions get label ``-100``,
    so only completion tokens contribute to the loss.

    Truncation proceeds in this order:
      1. Keep the last ``args.max_prompt_length`` prompt tokens.
      2. Keep the first ``args.max_completion_length`` completion tokens.
      3. If the pair still exceeds ``args.max_length``, drop tokens from the
         start of the prompt.
      4. Only once the prompt is empty, drop tokens from the end of the
         completion.

    Args:
        dataset: mapping of split name -> ``datasets.Dataset`` (e.g. the
            ``DatasetDict`` returned by ``load_raw_dataset``).
        tokenizer: tokenizer providing ``apply_chat_template``.
        args: parsed CLI namespace; reads ``max_prompt_length``,
            ``max_completion_length`` and ``max_length``.

    Returns:
        ``(tokenized, summary)``.

        ``tokenized`` is a ``DatasetDict`` with ``input_ids``,
        ``attention_mask`` and ``labels`` columns. Examples with no input
        tokens or no supervised token are removed.

        ``summary`` maps each length statistic (``prompt_before``,
        ``prompt_after``, ``completion_before``, ``completion_after``,
        ``full_after``) to ``count``/``min``/``mean``/``p50``/``max``. The
        statistics are pooled over all splits and measured before that filter.
    """
    from collections.abc import Mapping

    from datasets import DatasetDict

    stat_keys = ("prompt_before", "prompt_after", "completion_before", "completion_after", "full_after")
    # Stats travel as temporary output columns rather than being appended to a
    # closure-level list. A cached ``map`` (or a worker process) never runs the
    # function in this process, which silently left the old lists empty.
    stat_prefix = "_stat_"

    def as_id_list(encoded):
        # Some transformers versions return a BatchEncoding (a Mapping) from
        # apply_chat_template(tokenize=True) instead of a plain list of ids.
        if isinstance(encoded, Mapping):
            encoded = encoded["input_ids"]
        return list(encoded)

    def common_prefix_len(first, second):
        length = 0
        for a, b in zip(first, second):
            if a != b:
                break
            length += 1
        return length

    def tokenize_example(example):
        messages = example["messages"]
        prompt_ids = as_id_list(
            tokenizer.apply_chat_template(messages[:-1], tokenize=True, add_generation_prompt=True)
        )
        full_ids = as_id_list(
            tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False)
        )

        if full_ids[: len(prompt_ids)] != prompt_ids:
            # The prompt is not a token prefix of the full conversation. Retry
            # by tokenizing the rendered text without extra special tokens.
            prompt_text = tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
            full_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            prompt_ids = list(tokenizer(prompt_text, add_special_tokens=False)["input_ids"])
            full_ids = list(tokenizer(full_text, add_special_tokens=False)["input_ids"])

        if full_ids[: len(prompt_ids)] != prompt_ids:
            # Still not a prefix. Splicing prompt_ids onto a tail of full_ids
            # would build a sequence neither rendering produces. Instead, train
            # on the full rendering and mask only the shared prefix.
            prompt_ids = full_ids[: common_prefix_len(prompt_ids, full_ids)]

        response_ids = full_ids[len(prompt_ids):]
        prompt_before = len(prompt_ids)
        completion_before = len(response_ids)

        # Keep the most recent prompt tokens. Slice from an explicit start
        # index, because prompt_ids[-0:] would keep the whole prompt.
        max_prompt = max(0, args.max_prompt_length)
        prompt_ids = prompt_ids[max(0, len(prompt_ids) - max_prompt):]
        response_ids = response_ids[: max(0, args.max_completion_length)]

        overflow = len(prompt_ids) + len(response_ids) - args.max_length
        if overflow > 0:
            prompt_ids = prompt_ids[overflow:]
            # A no-op unless the prompt alone could not absorb the overflow.
            response_ids = response_ids[: max(0, args.max_length - len(prompt_ids))]

        input_ids = prompt_ids + response_ids
        return {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            "labels": ([-100] * len(prompt_ids)) + response_ids,
            stat_prefix + "prompt_before": prompt_before,
            stat_prefix + "prompt_after": len(prompt_ids),
            stat_prefix + "completion_before": completion_before,
            stat_prefix + "completion_after": len(response_ids),
            stat_prefix + "full_after": len(input_ids),
        }

    # Map each split with its own column list, so splits with different raw
    # schemas do not break remove_columns.
    tokenized = DatasetDict(
        {
            split_name: split.map(tokenize_example, remove_columns=split.column_names)
            for split_name, split in dataset.items()
        }
    )

    summary = {}
    for key in stat_keys:
        values = [value for split in tokenized.values() for value in split[stat_prefix + key]]
        if values:
            summary[key] = {
                "count": len(values),
                "min": min(values),
                "mean": round(statistics.mean(values), 1),
                # Upper median, i.e. sorted(values)[len(values) // 2].
                "p50": statistics.median_high(values),
                "max": max(values),
            }

    tokenized = tokenized.remove_columns([stat_prefix + key for key in stat_keys])
    tokenized = tokenized.filter(lambda x: len(x["input_ids"]) > 0 and any(v != -100 for v in x["labels"]))
    return tokenized, summary
|
|
|
|
|
|
@dataclass
class SupervisedDataCollator:
    """Right-pad variable-length features into a batch of ``torch.long`` tensors.

    Every sequence is padded to the longest ``input_ids`` in the batch. The pad
    values are ``tokenizer.pad_token_id`` for ``input_ids``, ``0`` for
    ``attention_mask`` and ``-100`` for ``labels``.
    """

    # Only ``pad_token_id`` is read from the tokenizer.
    tokenizer: AutoTokenizer

    def __call__(self, features):
        fill_values = {
            "input_ids": self.tokenizer.pad_token_id,
            "attention_mask": 0,
            "labels": -100,
        }
        longest = max(len(feature["input_ids"]) for feature in features)
        # Padding amount per row is driven by input_ids for all three keys.
        shortfalls = [longest - len(feature["input_ids"]) for feature in features]

        batch = {}
        for key, fill in fill_values.items():
            padded_rows = [
                feature[key] + ([fill] * missing)
                for feature, missing in zip(features, shortfalls)
            ]
            batch[key] = torch.tensor(padded_rows, dtype=torch.long)
        return batch
|
|
|
|
|
|
def _build_training_arguments(args: argparse.Namespace, output_dir: Path) -> TrainingArguments:
    """Build ``TrainingArguments`` that work across transformers versions.

    The evaluation-strategy keyword is chosen by name from the installed
    ``TrainingArguments.__init__`` signature. Any keyword that signature does
    not accept is dropped rather than raising.
    """
    train_args_kwargs = dict(
        output_dir=str(output_dir),
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_strategy="steps",
        save_total_limit=2,
        bf16=args.bf16,
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=2,
        lr_scheduler_type="cosine",
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        optim=args.optim,
        seed=args.seed,
        data_seed=args.seed,
        greater_is_better=False,
        metric_for_best_model="eval_loss",
        ddp_find_unused_parameters=False,
    )

    sig = inspect.signature(TrainingArguments.__init__).parameters
    # Newer releases name this ``eval_strategy``; older ones ``evaluation_strategy``.
    if "eval_strategy" in sig:
        train_args_kwargs["eval_strategy"] = "steps"
    elif "evaluation_strategy" in sig:
        train_args_kwargs["evaluation_strategy"] = "steps"

    # Restoring the best checkpoint needs a checkpoint at every evaluation,
    # so it is only enabled when save_steps is a multiple of eval_steps.
    enable_best = args.load_best_model_at_end and args.eval_steps > 0 and args.save_steps % args.eval_steps == 0
    if "load_best_model_at_end" in sig:
        train_args_kwargs["load_best_model_at_end"] = enable_best

    train_args_kwargs = {k: v for k, v in train_args_kwargs.items() if k in sig}
    return TrainingArguments(**train_args_kwargs)


def _write_run_summary(args: argparse.Namespace, output_dir: Path, tokenized_dataset, token_stats) -> None:
    """Write ``run_summary.json`` (config, split sizes, token stats) into ``output_dir``."""
    summary = {
        "model_name": args.model_name,
        "train_file": args.train_file,
        "dev_file": args.dev_file,
        "output_dir": str(output_dir),
        "num_train_examples": len(tokenized_dataset["train"]),
        "num_dev_examples": len(tokenized_dataset["dev"]),
        "max_length": args.max_length,
        "max_prompt_length": args.max_prompt_length,
        "max_completion_length": args.max_completion_length,
        "learning_rate": args.learning_rate,
        "epochs": args.num_train_epochs,
        "token_stats": token_stats,
    }
    with (output_dir / "run_summary.json").open("w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)


def main() -> None:
    """Run QLoRA SFT end to end.

    Steps: parse args, load and tokenize data, train, save the adapter and
    tokenizer, and write ``run_summary.json`` to the output directory.
    """
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    raw_dataset = load_raw_dataset(args.train_file, args.dev_file)
    model, tokenizer = build_model_and_tokenizer(args)
    tokenized_dataset, token_stats = tokenize_datasets(raw_dataset, tokenizer, args)

    training_args = _build_training_arguments(args, output_dir)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["dev"],
        data_collator=SupervisedDataCollator(tokenizer),
    )

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    # Called on every process; Trainer.save_model decides internally which
    # process writes.
    trainer.save_model()

    # Under distributed launch every process runs this script. Only the
    # process Trainer designates for saving writes the tokenizer and summary,
    # so processes don't write the same files concurrently.
    if not training_args.should_save:
        return
    tokenizer.save_pretrained(output_dir)
    _write_run_summary(args, output_dir, tokenized_dataset, token_stats)
|
|
|
|
|
|
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
|