Files
llmiotsafe/DPODataGenFullLog/audit_dataset.py
2026-05-12 17:01:39 +08:00

59 lines
1.8 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import statistics
from pathlib import Path
def summarize(values):
    """Return a one-line min/mean/p50/p95/max summary of *values*.

    An empty (falsy) input yields the literal string ``"n/a"``.

    The percentiles are picked by index from the sorted values rather than
    interpolated: p50 is the element at ``len(values) // 2`` (the upper of
    the two middle elements for even lengths) and p95 is the element at
    ``int(len(values) * 0.95)``, clamped to the last index.
    """
    if not values:
        return "n/a"

    size = len(values)
    ranked = sorted(values)
    median_value = ranked[size // 2]
    # Clamp so the index never runs past the end of the sorted list.
    p95_index = min(size - 1, int(size * 0.95))
    p95_value = ranked[p95_index]

    fields = (
        f"min={min(values)}",
        f"mean={statistics.mean(values):.1f}",
        f"p50={median_value}",
        f"p95={p95_value}",
        f"max={max(values)}",
    )
    return " ".join(fields)
def audit_split(split_dir: Path) -> None:
    """Print character-length statistics for the pairs in one split.

    Reads ``pairs.jsonl`` inside *split_dir* as UTF-8, one JSON object per
    line; blank lines are skipped. Each object must provide:

    * ``prompt`` -- a mapping with a ``user`` string and an optional
      ``system`` string,
    * ``chosen`` and ``rejected`` -- values whose ``len()`` is recorded,
    * ``variant`` -- a hashable label that is tallied.

    The prompt length is measured on the system text, a newline, and the
    user text joined together. The report goes to stdout.

    Args:
        split_dir: Directory containing ``pairs.jsonl``.

    Raises:
        FileNotFoundError: If ``pairs.jsonl`` does not exist in *split_dir*.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required keys.
    """
    pair_file = split_dir / "pairs.jsonl"
    prompt_chars = []
    chosen_chars = []
    rejected_chars = []
    variants = {}

    with pair_file.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            prompt = obj["prompt"]
            # A JSON ``null`` system prompt decodes to None; ``.get`` with a
            # default only covers a missing key, so normalise None to "" to
            # avoid a TypeError on concatenation.
            system_text = prompt.get("system") or ""
            prompt_text = system_text + "\n" + prompt["user"]
            prompt_chars.append(len(prompt_text))
            chosen_chars.append(len(obj["chosen"]))
            rejected_chars.append(len(obj["rejected"]))
            variant = obj["variant"]
            variants[variant] = variants.get(variant, 0) + 1

    print(f"=== {split_dir.name} ===")
    print(f"pairs={len(prompt_chars)} variants={variants}")
    print(f"prompt_chars: {summarize(prompt_chars)}")
    print(f"chosen_chars: {summarize(chosen_chars)}")
    print(f"rejected_chars: {summarize(rejected_chars)}")
def main() -> None:
    """Audit each known preference split found under the ``--root`` directory.

    The split directory names are fixed (``train_pref_v1`` and
    ``dev_pref_v1``). A split whose directory does not exist is skipped
    without any output.
    """
    cli = argparse.ArgumentParser(description="Audit full-log DPO prompt lengths.")
    cli.add_argument("--root", default="data_dpo_full_log_v1")
    dataset_root = Path(cli.parse_args().root)

    candidates = [dataset_root / name for name in ("train_pref_v1", "dev_pref_v1")]
    for candidate in candidates:
        if not candidate.exists():
            continue
        audit_split(candidate)


if __name__ == "__main__":
    main()