59 lines
1.8 KiB
Python
59 lines
1.8 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import statistics
|
|
from pathlib import Path
|
|
|
|
|
|
def summarize(values):
    """Return a one-line textual summary of a collection of numbers.

    The summary lists the minimum, the mean (one decimal place), two
    index-based percentile picks and the maximum. An empty collection
    yields the placeholder string ``"n/a"``.

    The percentiles are not interpolated: ``p50`` is the element at index
    ``len(values) // 2`` of the sorted data (the upper of the two middle
    elements when the count is even), and ``p95`` is the element at index
    ``int(len(values) * 0.95)``, clamped to the last valid index.
    """
    if not values:
        return "n/a"

    count = len(values)
    ranked = sorted(values)
    median_index = count // 2
    # Clamp so the index can never run past the end of the sorted list.
    p95_index = min(count - 1, int(count * 0.95))

    fields = (
        f"min={min(values)}",
        f"mean={statistics.mean(values):.1f}",
        f"p50={ranked[median_index]}",
        f"p95={ranked[p95_index]}",
        f"max={max(values)}",
    )
    return " ".join(fields)
|
|
|
|
|
|
def audit_split(split_dir: Path) -> None:
    """Print character-length statistics for one preference split.

    Reads ``pairs.jsonl`` inside *split_dir* as UTF-8, parsing every
    non-blank line as JSON. Each record is indexed with the keys
    ``"prompt"``, ``"chosen"``, ``"rejected"`` and ``"variant"``; the prompt
    value must provide ``"user"`` and may provide ``"system"``.

    The report printed to stdout contains the split directory name, the
    number of pairs, a per-variant count, and a ``summarize`` line for the
    prompt, chosen and rejected lengths.
    """
    prompt_sizes = []
    chosen_sizes = []
    rejected_sizes = []
    variant_counts = {}

    with (split_dir / "pairs.jsonl").open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                # Blank lines carry no record.
                continue

            record = json.loads(raw_line)
            prompt = record["prompt"]
            system_part = prompt.get("system", "")
            user_part = prompt["user"]
            # Prompt length is measured on "system\nuser", so the joining
            # newline is counted even when no system text is present.
            combined = system_part + "\n" + user_part

            prompt_sizes.append(len(combined))
            chosen_sizes.append(len(record["chosen"]))
            rejected_sizes.append(len(record["rejected"]))

            tag = record["variant"]
            if tag in variant_counts:
                variant_counts[tag] += 1
            else:
                variant_counts[tag] = 1

    print(f"=== {split_dir.name} ===")
    print(f"pairs={len(prompt_sizes)} variants={variant_counts}")
    report_rows = (
        ("prompt_chars", prompt_sizes),
        ("chosen_chars", chosen_sizes),
        ("rejected_chars", rejected_sizes),
    )
    for label, sizes in report_rows:
        print(f"{label}: {summarize(sizes)}")
|
|
|
|
|
|
def main() -> None:
    """Parse the command line and audit every split found under the root.

    ``--root`` names the dataset directory (default
    ``data_dpo_full_log_v1``). The ``train_pref_v1`` and ``dev_pref_v1``
    sub-directories are audited in that order; a split whose directory does
    not exist is skipped without any message.
    """
    cli = argparse.ArgumentParser(description="Audit full-log DPO prompt lengths.")
    cli.add_argument("--root", default="data_dpo_full_log_v1")
    options = cli.parse_args()

    dataset_root = Path(options.root)
    for split_dir in (dataset_root / name for name in ("train_pref_v1", "dev_pref_v1")):
        if not split_dir.exists():
            continue
        audit_split(split_dir)
|
|
|
|
|
|
# Run the audit only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|