#!/usr/bin/env python3 from __future__ import annotations import argparse import json import statistics from pathlib import Path from transformers import AutoTokenizer def render_messages(tokenizer, messages): return tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=False, ) def analyze_file(tokenizer, path: Path, label: str) -> None: prompt_lens = [] chosen_lens = [] rejected_lens = [] total_lens = [] with path.open("r", encoding="utf-8") as f: for line in f: if not line.strip(): continue obj = json.loads(line) prompt_text = render_messages(tokenizer, obj["prompt"]) chosen_text = render_messages(tokenizer, obj["prompt"] + obj["chosen"]) rejected_text = render_messages(tokenizer, obj["prompt"] + obj["rejected"]) prompt_len = len(tokenizer(prompt_text, add_special_tokens=False)["input_ids"]) chosen_len = len(tokenizer(chosen_text, add_special_tokens=False)["input_ids"]) rejected_len = len(tokenizer(rejected_text, add_special_tokens=False)["input_ids"]) prompt_lens.append(prompt_len) chosen_lens.append(chosen_len) rejected_lens.append(rejected_len) total_lens.append(max(chosen_len, rejected_len)) def pct(values, q): idx = min(len(values) - 1, int(len(values) * q)) return values[idx] prompt_lens.sort() chosen_lens.sort() rejected_lens.sort() total_lens.sort() print(f"=== {label} ===") print(f"count={len(prompt_lens)}") print(f"prompt_tokens: min={prompt_lens[0]} mean={round(statistics.mean(prompt_lens),1)} p50={pct(prompt_lens,0.5)} p95={pct(prompt_lens,0.95)} max={prompt_lens[-1]}") print(f"chosen_tokens: min={chosen_lens[0]} mean={round(statistics.mean(chosen_lens),1)} p50={pct(chosen_lens,0.5)} p95={pct(chosen_lens,0.95)} max={chosen_lens[-1]}") print(f"rejected_tokens: min={rejected_lens[0]} mean={round(statistics.mean(rejected_lens),1)} p50={pct(rejected_lens,0.5)} p95={pct(rejected_lens,0.95)} max={rejected_lens[-1]}") print(f"max_pair_tokens: min={total_lens[0]} mean={round(statistics.mean(total_lens),1)} p50={pct(total_lens,0.5)} p95={pct(total_lens,0.95)} max={total_lens[-1]}") print() def main() -> None: parser = argparse.ArgumentParser(description="Analyze token lengths for conversational DPO datasets.") parser.add_argument("--model-name", required=True) parser.add_argument("--train-file", required=True) parser.add_argument("--dev-file", required=True) args = parser.parse_args() tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True, use_fast=False) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token analyze_file(tokenizer, Path(args.train_file), "train") analyze_file(tokenizer, Path(args.dev_file), "dev") if __name__ == "__main__": main()