48 lines
1.3 KiB
Python
48 lines
1.3 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import statistics
|
|
from pathlib import Path
|
|
|
|
|
|
def audit_split(split_dir: Path) -> None:
    """Print prompt-length statistics for one split directory.

    Reads ``<split_dir>/pairs.jsonl`` (one JSON object per line). For each
    record, the measured text is the system prompt and the user prompt joined
    by a newline; its length is counted in characters. Prints the split name,
    the number of pairs, a per-variant count, and min / mean / p50 / p95 / max
    of the measured lengths.

    Blank lines in the file are skipped, and a file with no records prints a
    "no pairs" line instead of failing on the empty statistics.

    Args:
        split_dir: Directory that contains ``pairs.jsonl``.

    Raises:
        FileNotFoundError: If ``pairs.jsonl`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``prompt.system``, ``prompt.user`` or
            ``variant``.
    """
    pair_file = split_dir / "pairs.jsonl"
    lens: list[int] = []
    variants = {}
    with pair_file.open("r", encoding="utf-8") as f:
        for line in f:
            # A trailing or stray blank line is not a record; json.loads
            # would raise on it.
            if not line.strip():
                continue
            obj = json.loads(line)
            text = obj["prompt"]["system"] + "\n" + obj["prompt"]["user"]
            lens.append(len(text))
            variants[obj["variant"]] = variants.get(obj["variant"], 0) + 1

    lens.sort()
    print(f"=== {split_dir.name} ===")
    print(f"pairs={len(lens)} variants={variants}")

    if not lens:
        # Indexing lens[0] and statistics.mean([]) both raise on empty data.
        print("prompt_chars: n/a (no pairs)")
        return

    count = len(lens)
    print(
        "prompt_chars:",
        f"min={lens[0]}",
        f"mean={round(statistics.mean(lens), 1)}",
        f"p50={lens[count // 2]}",
        # int(count * 0.95) is always < count, so the index stays in range.
        f"p95={lens[int(count * 0.95)]}",
        f"max={lens[-1]}",
    )
|
|
|
|
|
|
def main() -> None:
    """Parse ``--root`` and audit each known split directory found under it.

    The two split names are fixed: ``train_pref_v1`` and ``dev_pref_v1``.
    A split whose directory does not exist is skipped without any output.
    """
    cli = argparse.ArgumentParser(description="Audit generated DPO prompt lengths.")
    cli.add_argument("--root", default="data", help="root containing train_pref_v1/dev_pref_v1")
    base = Path(cli.parse_args().root)

    for split_name in ("train_pref_v1", "dev_pref_v1"):
        candidate = base / split_name
        if not candidate.exists():
            continue
        audit_split(candidate)
|
|
|
|
|
|
# Run the audit only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|