79 lines
3.0 KiB
Python
79 lines
3.0 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import statistics
|
|
from pathlib import Path
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
def render_messages(tokenizer, messages):
    """Render chat *messages* to text using the tokenizer's chat template.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        messages: The conversation to render; passed through unchanged.

    Returns:
        Whatever ``apply_chat_template`` returns when called with
        ``tokenize=False`` and ``add_generation_prompt=False``.
    """
    # Ask for untokenized output and no trailing generation prompt.
    template_options = {"tokenize": False, "add_generation_prompt": False}
    rendered = tokenizer.apply_chat_template(messages, **template_options)
    return rendered
|
|
|
|
|
|
def _percentile(sorted_values, q):
    """Return the element at fraction *q* of an ascending, non-empty list."""
    # Clamp so that q == 1.0 maps to the last element instead of overflowing.
    idx = min(len(sorted_values) - 1, int(len(sorted_values) * q))
    return sorted_values[idx]


def _format_stats(name, sorted_values):
    """Build one ``name: min=... mean=... p50=... p95=... max=...`` summary line."""
    mean = round(statistics.mean(sorted_values), 1)
    return (
        f"{name}: min={sorted_values[0]} mean={mean} "
        f"p50={_percentile(sorted_values, 0.5)} "
        f"p95={_percentile(sorted_values, 0.95)} "
        f"max={sorted_values[-1]}"
    )


def _token_count(tokenizer, text):
    """Count the input ids the tokenizer produces for *text* without special tokens."""
    return len(tokenizer(text, add_special_tokens=False)["input_ids"])


def analyze_file(tokenizer, path: Path, label: str) -> None:
    """Print token-length statistics for a JSON-lines preference dataset.

    Each non-blank line of *path* is parsed as a JSON object with the keys
    ``"prompt"``, ``"chosen"`` and ``"rejected"``. The prompt is measured on
    its own; chosen and rejected are measured as ``prompt + chosen`` and
    ``prompt + rejected`` rendered through the chat template, so those two
    figures include the prompt tokens. The values are concatenated with
    ``+`` and are presumably lists of chat messages (verify against the
    dataset).

    Args:
        tokenizer: Tokenizer that is callable on text and provides
            ``apply_chat_template``.
        path: Location of the JSON-lines file, read as UTF-8.
        label: Name printed in the ``=== label ===`` header.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the expected keys.
    """
    prompt_lens = []
    chosen_lens = []
    rejected_lens = []
    max_pair_lens = []

    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            prompt = record["prompt"]

            prompt_len = _token_count(tokenizer, render_messages(tokenizer, prompt))
            chosen_len = _token_count(
                tokenizer, render_messages(tokenizer, prompt + record["chosen"])
            )
            rejected_len = _token_count(
                tokenizer, render_messages(tokenizer, prompt + record["rejected"])
            )

            prompt_lens.append(prompt_len)
            chosen_lens.append(chosen_len)
            rejected_lens.append(rejected_len)
            # The longer of the two completions for this pair.
            max_pair_lens.append(max(chosen_len, rejected_len))

    print(f"=== {label} ===")
    print(f"count={len(prompt_lens)}")

    # With no records there is nothing to index or average; min/mean/max
    # would raise, so only the header and the zero count are reported.
    if prompt_lens:
        summaries = (
            ("prompt_tokens", prompt_lens),
            ("chosen_tokens", chosen_lens),
            ("rejected_tokens", rejected_lens),
            ("max_pair_tokens", max_pair_lens),
        )
        for name, values in summaries:
            print(_format_stats(name, sorted(values)))
    print()
|
|
|
|
|
|
def main() -> None:
    """Parse the command line, load the tokenizer, and report on both datasets.

    Requires ``--model-name``, ``--train-file`` and ``--dev-file``. The train
    file is analyzed first under the label ``train``, then the dev file under
    the label ``dev``.
    """
    cli = argparse.ArgumentParser(description="Analyze token lengths for conversational DPO datasets.")
    for flag in ("--model-name", "--train-file", "--dev-file"):
        cli.add_argument(flag, required=True)
    opts = cli.parse_args()

    # NOTE(review): trust_remote_code=True permits code shipped with the model
    # repository to run locally; only point --model-name at trusted sources.
    tokenizer = AutoTokenizer.from_pretrained(
        opts.model_name,
        trust_remote_code=True,
        use_fast=False,
    )
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    datasets = ((opts.train_file, "train"), (opts.dev_file, "dev"))
    for dataset_path, label in datasets:
        analyze_file(tokenizer, Path(dataset_path), label)
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|