Files
llmiotsafe/DPODataGen/audit_dataset.py
2026-05-12 17:01:39 +08:00

48 lines
1.3 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import statistics
from pathlib import Path
def audit_split(split_dir: Path) -> None:
    """Print prompt-length statistics for the pairs stored in one split.

    Reads ``split_dir / "pairs.jsonl"`` (UTF-8, one JSON object per line).
    For every record the prompt length is the number of characters in
    ``prompt.system + "\\n" + prompt.user``; records are also tallied by
    their ``variant`` value.

    Three lines are written to stdout: a header with the directory name,
    the pair count with the per-variant tally, and the min / mean / p50 /
    p95 / max of the prompt lengths. When the file holds no records the
    statistics line is replaced by ``prompt_chars: n/a``.

    Args:
        split_dir: Directory that contains ``pairs.jsonl``.

    Raises:
        FileNotFoundError: If ``pairs.jsonl`` does not exist in *split_dir*.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``prompt.system``, ``prompt.user`` or
            ``variant``.
    """
    pair_file = split_dir / "pairs.jsonl"
    lens = []
    variants = {}
    with pair_file.open("r", encoding="utf-8") as f:
        for line in f:
            # Blank lines (e.g. a trailing empty line) are not records;
            # feeding them to json.loads would raise JSONDecodeError.
            if not line.strip():
                continue
            obj = json.loads(line)
            prompt = obj["prompt"]
            text = prompt["system"] + "\n" + prompt["user"]
            lens.append(len(text))
            variant = obj["variant"]
            variants[variant] = variants.get(variant, 0) + 1
    lens.sort()

    print(f"=== {split_dir.name} ===")
    print(f"pairs={len(lens)} variants={variants}")
    if not lens:
        # Indexing and statistics.mean both fail on an empty list; report
        # the absence of data instead of aborting the whole audit run.
        print("prompt_chars: n/a")
        return

    print(
        "prompt_chars:",
        f"min={lens[0]}",
        f"mean={round(statistics.mean(lens), 1)}",
        # Percentiles are read straight from the sorted list by truncated
        # index, so p50 is the upper median for even-sized inputs.
        f"p50={lens[len(lens)//2]}",
        f"p95={lens[int(len(lens)*0.95)]}",
        f"max={lens[-1]}",
    )
def main() -> None:
    """Command-line entry point: audit each known split found under ``--root``.

    Parses ``--root`` (default ``"data"``) and calls ``audit_split`` for
    ``train_pref_v1`` and then ``dev_pref_v1``. A split whose directory does
    not exist is skipped without producing any output.
    """
    cli = argparse.ArgumentParser(description="Audit generated DPO prompt lengths.")
    cli.add_argument("--root", default="data", help="root containing train_pref_v1/dev_pref_v1")
    base = Path(cli.parse_args().root)

    # The train split is reported before the dev split.
    for split_name in ("train_pref_v1", "dev_pref_v1"):
        candidate = base / split_name
        if not candidate.exists():
            continue
        audit_split(candidate)


if __name__ == "__main__":
    main()