Files
llmiotsafe/SFT_FULLLOG_QWEN2B_SERVER_BUNDLE/scripts/analyze_sft_lengths.py
2026-05-12 17:01:39 +08:00

72 lines
2.5 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import statistics
from pathlib import Path
from transformers import AutoTokenizer
def pct(values: list[int], q: float) -> int:
    """Return the *q*-quantile of *values* by direct index lookup (no interpolation).

    The element at index ``int(len(values) * q)`` is returned, capped at the
    last index, so ``q=1.0`` yields the maximum.

    Args:
        values: Non-empty sequence already sorted in ascending order; it is
            not sorted here.
        q: Quantile in ``[0, 1]`` (e.g. ``0.95`` for p95).

    Raises:
        ValueError: If *values* is empty or *q* lies outside ``[0, 1]``.
    """
    if not values:
        raise ValueError("pct() requires at least one value")
    if not 0.0 <= q <= 1.0:
        # A negative q can produce a negative index, which Python would
        # silently wrap around to the end of the list.
        raise ValueError(f"quantile must be within [0, 1], got {q!r}")
    idx = min(len(values) - 1, int(len(values) * q))
    return values[idx]
def _load_messages(line: str, path: Path, lineno: int) -> list:
    """Decode one JSONL record and return its ``messages`` list.

    Raises:
        ValueError: Naming ``path:lineno`` when the line is not valid JSON,
            is not an object holding a ``messages`` list, or that list has
            fewer than two entries (splitting off the final message would
            leave no prompt to render).
    """
    where = f"{path}:{lineno}"
    try:
        obj = json.loads(line)
    except json.JSONDecodeError as err:
        raise ValueError(f"{where}: invalid JSON ({err.msg})") from err
    messages = obj.get("messages") if isinstance(obj, dict) else None
    if not isinstance(messages, list) or len(messages) < 2:
        raise ValueError(f"{where}: expected a 'messages' list with at least 2 entries")
    return messages


def _format_stats(name: str, values: list[int]) -> str:
    """Return a ``name: min=.. mean=.. p50=.. p95=.. max=..`` summary line.

    *values* must be non-empty and sorted ascending: ``pct`` and the
    ``[0]`` / ``[-1]`` min/max lookups rely on that ordering.
    """
    return (
        f"{name}: min={values[0]} mean={round(statistics.mean(values), 1)} "
        f"p50={pct(values, 0.5)} p95={pct(values, 0.95)} max={values[-1]}"
    )


def analyze_file(tokenizer, path: Path, label: str) -> None:
    """Print prompt / completion / full token-length statistics for one JSONL file.

    Every non-blank line of *path* (read as UTF-8) must be a JSON object whose
    ``messages`` list is rendered with ``tokenizer.apply_chat_template``.
    All messages except the last form the prompt, rendered with the generation
    prompt appended. The whole list forms the full sequence. The completion
    length is the difference between the two, clamped at zero.

    Args:
        tokenizer: Object providing ``apply_chat_template`` (in this script,
            a Hugging Face tokenizer).
        path: JSONL dataset to read.
        label: Heading printed above the statistics.

    Raises:
        ValueError: If a record is malformed (see ``_load_messages``).
    """
    prompt_lens: list[int] = []
    full_lens: list[int] = []
    completion_lens: list[int] = []
    with path.open("r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            messages = _load_messages(line, path, lineno)
            # return_dict=False is passed explicitly so each call yields a plain
            # list of token ids and len() counts tokens.
            # NOTE(review): newer transformers releases may default to a
            # dict-like BatchEncoding; confirm against the installed version.
            prompt_ids = tokenizer.apply_chat_template(
                messages[:-1],
                tokenize=True,
                add_generation_prompt=True,
                return_dict=False,
            )
            full_ids = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=False,
                return_dict=False,
            )
            # Assumes the rendered prompt is a token prefix of the rendered full
            # conversation. If a chat template rewrites earlier turns, this
            # difference is only approximate; max() keeps it non-negative.
            prompt_lens.append(len(prompt_ids))
            full_lens.append(len(full_ids))
            completion_lens.append(max(0, len(full_ids) - len(prompt_ids)))

    print(f"=== {label} ===")
    print(f"count={len(full_lens)}")
    if not full_lens:
        # Empty / blank-only file: min, max and mean are undefined, so report
        # that instead of crashing with IndexError or StatisticsError.
        print("no records found; token statistics skipped")
        print()
        return
    for name, values in (
        ("prompt_tokens", prompt_lens),
        ("completion_tokens", completion_lens),
        ("full_tokens", full_lens),
    ):
        values.sort()
        print(_format_stats(name, values))
    print()
def main() -> None:
parser = argparse.ArgumentParser(description="Analyze token lengths for conversational SFT datasets.")
parser.add_argument("--model-name", required=True)
parser.add_argument("--train-file", required=True)
parser.add_argument("--dev-file", required=True)
args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True, use_fast=False)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
analyze_file(tokenizer, Path(args.train_file), "train")
analyze_file(tokenizer, Path(args.dev_file), "dev")
if __name__ == "__main__":
main()