72 lines
2.5 KiB
Python
72 lines
2.5 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import statistics
|
|
from pathlib import Path
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
def pct(values: list[int], q: float) -> int:
    """Return the ``q``-quantile element of an ascending-sorted list.

    The element at index ``floor(len(values) * q)`` is selected, clamped to
    the last index so that ``q=1.0`` yields the maximum.

    Args:
        values: Values sorted in ascending order (callers sort before calling).
        q: Quantile in the closed interval [0, 1], e.g. 0.5 for the median.

    Returns:
        The selected element of ``values``.

    Raises:
        ValueError: If ``values`` is empty or ``q`` lies outside [0, 1].
    """
    if not values:
        # Without this guard the index would be -1 and indexing an empty list
        # raises an IndexError that does not explain the real problem.
        raise ValueError("pct() requires at least one value")
    if not 0.0 <= q <= 1.0:
        # A negative q would yield a negative index, silently wrapping around
        # to the end of the list and returning a wrong statistic.
        raise ValueError(f"quantile must be in [0, 1], got {q!r}")
    idx = min(len(values) - 1, int(len(values) * q))
    return values[idx]
|
|
|
|
|
|
def _encoded_length(encoded) -> int:
    """Return the number of token ids in an ``apply_chat_template`` result.

    With ``tokenize=True`` the call normally returns a flat list of ids. If it
    instead returns a mapping (a ``BatchEncoding``, as with
    ``return_dict=True``), ``len()`` would count its keys, so the
    ``input_ids`` entry is measured instead.
    """
    if hasattr(encoded, "keys"):
        encoded = encoded["input_ids"]
    return len(encoded)


def _print_stats(name: str, values: list[int]) -> None:
    """Print min/mean/p50/p95/max for an ascending-sorted, non-empty list."""
    print(
        f"{name}: min={values[0]} mean={round(statistics.mean(values),1)} "
        f"p50={pct(values,0.5)} p95={pct(values,0.95)} max={values[-1]}"
    )


def analyze_file(tokenizer, path: Path, label: str) -> None:
    """Tokenize every conversation in a JSONL file and print length statistics.

    Each non-blank line must be a JSON object with a non-empty ``"messages"``
    list. The chat template is applied twice per record:

    * to every message except the last, with a generation prompt appended;
    * to the full conversation, without a generation prompt.

    The difference between the two lengths is reported as the completion
    length. This presumes the final message is the assistant turn being
    trained on — confirm against the dataset.

    Args:
        tokenizer: A Hugging Face tokenizer exposing ``apply_chat_template``.
        path: UTF-8 JSONL dataset file.
        label: Heading printed above the statistics.

    Raises:
        ValueError: If a line is not valid JSON or lacks a non-empty
            ``"messages"`` list; the message names the file and line number.
    """
    prompt_lens: list[int] = []
    full_lens: list[int] = []
    completion_lens: list[int] = []

    with path.open("r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                messages = json.loads(line)["messages"]
            except (json.JSONDecodeError, KeyError, TypeError) as err:
                raise ValueError(
                    f"{path}:{lineno}: expected a JSON object with a 'messages' list"
                ) from err
            if not isinstance(messages, list) or not messages:
                raise ValueError(f"{path}:{lineno}: 'messages' must be a non-empty list")

            prompt_len = _encoded_length(
                tokenizer.apply_chat_template(
                    messages[:-1],
                    tokenize=True,
                    add_generation_prompt=True,
                )
            )
            full_len = _encoded_length(
                tokenizer.apply_chat_template(
                    messages,
                    tokenize=True,
                    add_generation_prompt=False,
                )
            )

            prompt_lens.append(prompt_len)
            full_lens.append(full_len)
            # Clamp at 0 in case a template's prompt render is not a strict
            # prefix of the full render and ends up longer.
            completion_lens.append(max(0, full_len - prompt_len))

    print(f"=== {label} ===")
    print(f"count={len(full_lens)}")
    # An empty (or all-blank) file has no records; min/mean/percentiles are
    # undefined, so only the count is reported.
    if full_lens:
        _print_stats("prompt_tokens", sorted(prompt_lens))
        _print_stats("completion_tokens", sorted(completion_lens))
        _print_stats("full_tokens", sorted(full_lens))
    print()
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point.

    Parses ``--model-name``, ``--train-file`` and ``--dev-file`` (all
    required), loads the model's tokenizer, and prints token-length
    statistics for the train split followed by the dev split.
    """
    cli = argparse.ArgumentParser(description="Analyze token lengths for conversational SFT datasets.")
    for flag in ("--model-name", "--train-file", "--dev-file"):
        cli.add_argument(flag, required=True)
    opts = cli.parse_args()

    # NOTE: trust_remote_code=True lets the model repository run its own
    # tokenizer code; only point this at checkpoints you trust.
    tok = AutoTokenizer.from_pretrained(
        opts.model_name,
        trust_remote_code=True,
        use_fast=False,
    )
    # Some checkpoints ship without a pad token; reuse EOS in that case.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    splits = (("train", opts.train_file), ("dev", opts.dev_file))
    for split_label, split_path in splits:
        analyze_file(tok, Path(split_path), split_label)
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns None, so SystemExit(None) exits with status 0 on success;
    # argparse errors and uncaught exceptions propagate exactly as before.
    raise SystemExit(main())
|