70 lines
2.3 KiB
Python
70 lines
2.3 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
from collections import Counter
|
|
from pathlib import Path
|
|
|
|
|
|
def convert_row(row: dict) -> dict:
    """Flatten one preference pair into a chosen-only SFT record.

    The prompt turns are followed by the chosen completion turns to form a
    single ``messages`` list. Identifying metadata is carried over; optional
    metadata (``scenario_id``, ``variant``, ``chosen_source``) becomes
    ``None`` when absent. The rejected side of the pair is ignored.

    Args:
        row: one decoded preference-pair object.

    Returns:
        A new dict with keys, in order: ``example_id``, ``episode_id``,
        ``scenario_id``, ``variant``, ``assistant_source``, ``messages``.

    Raises:
        KeyError: if ``prompt``, ``chosen``, ``pair_id`` or ``episode_id``
            is missing from ``row``.
    """
    # Fresh list: the output never aliases the input's prompt/chosen lists.
    conversation = [*row["prompt"], *row["chosen"]]
    return dict(
        example_id=row["pair_id"],
        episode_id=row["episode_id"],
        scenario_id=row.get("scenario_id"),
        variant=row.get("variant"),
        assistant_source=row.get("chosen_source"),
        messages=conversation,
    )
|
|
|
|
|
|
def convert_split(input_path: Path, output_path: Path) -> dict:
    """Convert one JSONL file of preference pairs into chosen-only SFT JSONL.

    Every non-blank line of ``input_path`` is parsed as JSON, converted with
    ``convert_row`` and written as a single JSON line to ``output_path``.
    The output's parent directory is created if needed, and an existing
    output file is overwritten.

    Args:
        input_path: UTF-8 JSONL file with one preference pair per line.
        output_path: destination JSONL file (UTF-8, non-ASCII kept as-is).

    Returns:
        Counts keyed ``"total"``, ``"variant_<value>"`` and
        ``"assistant_source_<value>"``. A missing or null value is counted
        under ``?`` (e.g. ``"variant_?"``).

    Raises:
        json.JSONDecodeError: if a line is not valid JSON; the message is
            prefixed with ``<input_path>:<line number>``.
    """
    counters = Counter()
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with input_path.open("r", encoding="utf-8") as src, output_path.open("w", encoding="utf-8") as dst:
        for lineno, line in enumerate(src, start=1):
            if not line.strip():
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as err:
                # Same exception type, but say which file/line was malformed.
                raise json.JSONDecodeError(f"{input_path}:{lineno}: {err.msg}", err.doc, err.pos) from err
            out = convert_row(row)
            dst.write(json.dumps(out, ensure_ascii=False) + "\n")

            counters["total"] += 1
            # A key that is present with value None bypasses a
            # ``.get(key, "?")`` default, so map None to "?" explicitly.
            variant = out.get("variant")
            source = out.get("assistant_source")
            variant_label = "?" if variant is None else variant
            source_label = "?" if source is None else source
            counters[f"variant_{variant_label}"] += 1
            counters[f"assistant_source_{source_label}"] += 1

    return dict(counters)
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: convert the train and dev splits and write a manifest.

    Reads ``--train-input`` and ``--dev-input`` (JSONL preference pairs) and
    writes ``train_sft.jsonl`` and ``dev_sft.jsonl`` under ``--output-root``.
    It then writes ``manifest.json`` (per-split counts from ``convert_split``
    plus format notes) into the *parent* directory of ``--output-root`` and
    prints both locations.
    """
    cli = argparse.ArgumentParser(description="Build chosen-only conversational SFT datasets from full-log DPO pairs.")
    cli.add_argument("--train-input", type=Path, required=True)
    cli.add_argument("--dev-input", type=Path, required=True)
    cli.add_argument("--output-root", type=Path, required=True)
    opts = cli.parse_args()

    out_dir = opts.output_root
    out_dir.mkdir(parents=True, exist_ok=True)

    # Train is converted before dev; a failure in dev leaves train written
    # but no manifest.
    split_stats = {}
    split_stats["train"] = convert_split(opts.train_input, out_dir / "train_sft.jsonl")
    split_stats["dev"] = convert_split(opts.dev_input, out_dir / "dev_sft.jsonl")

    manifest = {
        "splits": split_stats,
        "notes": {
            "format": "conversational SFT jsonl with a single `messages` field",
            "source": "chosen responses only from full-log preference pairs",
        },
    }

    # NOTE(review): the manifest lands one level ABOVE the output root (not
    # inside it), so sibling output roots share one manifest path — confirm
    # this is intended.
    manifest_path = out_dir.parent / "manifest.json"
    with manifest_path.open("w", encoding="utf-8") as fh:
        json.dump(manifest, fh, ensure_ascii=False, indent=2)

    print(f"wrote SFT datasets to: {out_dir}")
    print(f"manifest: {manifest_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported.
    main()
|