#!/usr/bin/env python3
"""Report token-length statistics for conversational SFT datasets.

Each input file is JSON Lines: one object per line with a ``messages`` list.
The last message is treated as the completion and everything before it as the
prompt. Lengths are measured by rendering through the tokenizer's chat
template.
"""
from __future__ import annotations

import argparse
import json
import statistics
from pathlib import Path


def pct(values, q: float) -> int:
    """Return the *q*-quantile of an already-sorted, non-empty sequence.

    The element at index ``int(len(values) * q)`` is returned, clamped to the
    valid index range, so ``q=0.5`` picks the upper median for even lengths,
    ``q >= 1`` returns the maximum and ``q <= 0`` returns the minimum.

    Raises:
        IndexError: if *values* is empty.
    """
    # Clamp at both ends: without the lower bound a negative q yields a
    # negative index, which silently wraps around to the top of the list.
    idx = max(0, min(len(values) - 1, int(len(values) * q)))
    return values[idx]


def _describe(name: str, sorted_values: list[int]) -> str:
    """Format min/mean/p50/p95/max for a sorted, non-empty list of counts."""
    mean = round(statistics.mean(sorted_values), 1)
    return (
        f"{name}: min={sorted_values[0]} mean={mean} "
        f"p50={pct(sorted_values, 0.5)} p95={pct(sorted_values, 0.95)} "
        f"max={sorted_values[-1]}"
    )


def _iter_conversations(path: Path):
    """Yield the ``messages`` list of each non-blank JSONL line in *path*.

    Raises:
        ValueError: with ``path:line`` context if a line is not valid JSON or
            does not hold a non-empty ``messages`` list.
    """
    with path.open("r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}:{lineno}: invalid JSON: {err}") from err
            messages = obj.get("messages") if isinstance(obj, dict) else None
            if not isinstance(messages, list) or not messages:
                raise ValueError(
                    f"{path}:{lineno}: expected a non-empty 'messages' list"
                )
            yield messages


def analyze_file(tokenizer, path: Path, label: str) -> None:
    """Print prompt/completion/full token-length statistics for one JSONL file.

    For each conversation the prompt is ``messages[:-1]`` rendered with the
    generation prompt appended, the full example is all of ``messages``
    rendered without it, and the completion length is their difference.

    Args:
        tokenizer: object providing ``apply_chat_template`` (in this script, a
            tokenizer loaded via ``AutoTokenizer``).
        path: JSON Lines file; blank lines are skipped.
        label: heading printed above the statistics.

    Raises:
        ValueError: if a line is not valid JSON or has no ``messages`` list.
    """
    prompt_lens: list[int] = []
    full_lens: list[int] = []
    completion_lens: list[int] = []

    for messages in _iter_conversations(path):
        prompt_ids = tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=True,
            add_generation_prompt=True,
        )
        full_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
        )
        # The difference is exact only when the prompt rendering is a prefix of
        # the full rendering; clamp so templates that break that never go negative.
        completion_lens.append(max(0, len(full_ids) - len(prompt_ids)))
        prompt_lens.append(len(prompt_ids))
        full_lens.append(len(full_ids))

    print(f"=== {label} ===")
    print(f"count={len(full_lens)}")
    if not full_lens:
        # min/mean/percentiles are undefined for an empty dataset.
        print("no examples found; skipping statistics")
        print()
        return

    print(_describe("prompt_tokens", sorted(prompt_lens)))
    print(_describe("completion_tokens", sorted(completion_lens)))
    print(_describe("full_tokens", sorted(full_lens)))
    print()


def main() -> None:
    """Parse CLI arguments, load the tokenizer and analyze the train/dev files."""
    parser = argparse.ArgumentParser(
        description="Analyze token lengths for conversational SFT datasets."
    )
    parser.add_argument("--model-name", required=True)
    parser.add_argument("--train-file", required=True)
    parser.add_argument("--dev-file", required=True)
    parser.add_argument(
        "--use-fast",
        action="store_true",
        help="Load the fast tokenizer implementation (default: slow tokenizer).",
    )
    args = parser.parse_args()

    # Imported here so --help works, and the helpers above stay importable,
    # without transformers installed.
    from transformers import AutoTokenizer

    # NOTE(review): trust_remote_code=True runs code shipped with the model
    # repository; only point --model-name at trusted models.
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, trust_remote_code=True, use_fast=args.use_fast
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    analyze_file(tokenizer, Path(args.train_file), "train")
    analyze_file(tokenizer, Path(args.dev_file), "dev")


if __name__ == "__main__":
    main()