Files
llmiotsafe/DPO_FULLLOG_QWEN2B_SERVER_BUNDLE/scripts/analyze_token_lengths.py
2026-05-12 17:01:39 +08:00

79 lines
3.0 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import math
import statistics
from pathlib import Path

from transformers import AutoTokenizer
def render_messages(tokenizer, messages):
    """Render a list of chat messages to a plain string via the chat template.

    The template is applied untokenized (``tokenize=False``) and without a
    trailing assistant-turn prompt (``add_generation_prompt=False``), so the
    returned text covers only the turns passed in.
    """
    template_options = {"tokenize": False, "add_generation_prompt": False}
    return tokenizer.apply_chat_template(messages, **template_options)
def _percentile(sorted_values: list[int], q: float) -> int:
    """Return the nearest-rank ``q`` quantile of a non-empty ascending list.

    Nearest-rank selects the value at 1-based rank ``ceil(q * n)``, clamped to
    ``[1, n]``. Plain ``int(n * q)`` indexing would land one element too high
    whenever ``n * q`` is a whole number (e.g. p50 of ``[1, 2, 3, 4]`` -> 3).
    """
    n = len(sorted_values)
    rank = min(n, max(1, math.ceil(q * n)))
    return sorted_values[rank - 1]


def _format_stats(name: str, values: list[int]) -> str:
    """Sort ``values`` in place and format min/mean/p50/p95/max as one line.

    ``values`` must be non-empty.
    """
    values.sort()
    mean = round(statistics.mean(values), 1)
    return (
        f"{name}: min={values[0]} mean={mean} "
        f"p50={_percentile(values, 0.5)} p95={_percentile(values, 0.95)} "
        f"max={values[-1]}"
    )


def analyze_file(tokenizer, path: Path, label: str) -> None:
    """Print token-length statistics for one conversational DPO JSONL file.

    Each non-blank line must be a JSON object whose ``prompt``, ``chosen`` and
    ``rejected`` values are chat-message sequences that can be concatenated
    with ``+``. The prompt is rendered on its own. Each completion is rendered
    appended to the prompt, so ``chosen_tokens`` and ``rejected_tokens`` are
    full-sequence lengths. ``max_pair_tokens`` is the longer of the two per
    record.

    Args:
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``__call__``
            returning a mapping with ``"input_ids"``.
        path: JSONL file to read (UTF-8).
        label: Heading printed above the statistics.

    Raises:
        json.JSONDecodeError: A non-blank line is not valid JSON; the message
            is prefixed with ``<path>:<line>``.
        KeyError: A record lacks ``prompt``, ``chosen`` or ``rejected``.
    """
    prompt_lens: list[int] = []
    chosen_lens: list[int] = []
    rejected_lens: list[int] = []
    total_lens: list[int] = []

    def count_tokens(text: str) -> int:
        # Special tokens are expected to come from the rendered template text;
        # presumably add_special_tokens=False avoids adding BOS/EOS twice.
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    with path.open("r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as err:
                # Keep the exception type, but point at the offending line.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            prompt = obj["prompt"]
            prompt_len = count_tokens(render_messages(tokenizer, prompt))
            chosen_len = count_tokens(render_messages(tokenizer, prompt + obj["chosen"]))
            rejected_len = count_tokens(render_messages(tokenizer, prompt + obj["rejected"]))
            prompt_lens.append(prompt_len)
            chosen_lens.append(chosen_len)
            rejected_lens.append(rejected_len)
            total_lens.append(max(chosen_len, rejected_len))

    print(f"=== {label} ===")
    print(f"count={len(prompt_lens)}")
    # Guard: indexing and statistics.mean both raise on an empty list, which
    # is what an empty or all-blank file produces.
    if prompt_lens:
        for name, values in (
            ("prompt_tokens", prompt_lens),
            ("chosen_tokens", chosen_lens),
            ("rejected_tokens", rejected_lens),
            ("max_pair_tokens", total_lens),
        ):
            print(_format_stats(name, values))
    print()
def main() -> None:
    """Parse CLI options, load the tokenizer, and report train then dev stats."""
    cli = argparse.ArgumentParser(description="Analyze token lengths for conversational DPO datasets.")
    for flag in ("--model-name", "--train-file", "--dev-file"):
        cli.add_argument(flag, required=True)
    opts = cli.parse_args()

    # trust_remote_code=True allows tokenizer code shipped with the model repo
    # to run; only point --model-name at repositories you trust.
    tok = AutoTokenizer.from_pretrained(
        opts.model_name,
        trust_remote_code=True,
        use_fast=False,
    )
    # Fall back to EOS when no pad token is defined. Nothing in this script
    # pads, so this does not affect the reported lengths.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    for split, file_name in (("train", opts.train_file), ("dev", opts.dev_file)):
        analyze_file(tok, Path(file_name), split)
# Run the CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()