596 lines
24 KiB
Python
596 lines
24 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import math
|
|
import random
|
|
from collections import Counter, defaultdict
|
|
from copy import deepcopy
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import Dict, Iterable, List, Optional, Tuple
|
|
|
|
from src.episode_gen.behavior_engine import generate_daily_events
|
|
from src.episode_gen.home_state import HomeState
|
|
|
|
from .prompt_builder import build_dpo_prompt
|
|
from .scenario_bank import SCENARIOS, get_scenarios_for_split
|
|
|
|
|
|
# Resident profile ids available for each home layout id.
# choose_profile() draws its candidates from the list for the chosen layout.
PROFILE_BY_LAYOUT = {
    "A": ["young_professional"],
    "B": ["young_professional", "family_with_children"],
    "C": ["family_with_children"],
    "D": ["elderly_living_alone"],
}
|
|
|
|
# Sampling weight per layout id, consumed by choose_weighted_layout();
# a layout id missing from this table is weighted 0.25 there.
LAYOUT_WEIGHTS = {"A": 0.22, "B": 0.28, "C": 0.28, "D": 0.22}
|
|
|
|
# User-facing query wordings, keyed by query family ("sq1".."sq5").
# build_episode() picks one wording per episode; for "sq3" on multi-day
# episodes it rewrites the time-span phrases in place, so those phrases
# must stay in sync with the replace() calls in build_episode().
QUERY_TEMPLATES = {
    "sq1": [
        "根据以下传感器与设备记录,评估目标设备当前工作状态是否正常。如异常,请给出故障判断依据。",
        "请结合下面的 Matter 设备日志,判断是否存在设备健康问题,并说明可能的故障类型。",
    ],
    "sq2": [
        "系统刚记录到以下一段设备状态变化。请评估这些变化是否已经构成安全风险,并说明原因。",
        "请分析下面的实时设备事件,判断是否需要将其视为异常安全事件。",
    ],
    "sq3": [
        "以下是过去一段时间的设备活动记录。请识别其中是否存在可疑行为模式,并解释你的判断。",
        "请从下面的时序日志中分析住户或环境是否出现异常模式,不要只看单个事件。",
    ],
    "sq4": [
        "请综合以下多个设备与环境状态,评估家庭整体安全状况,指出任何潜在风险。",
        "你将看到一组跨设备日志。请做综合安全分析,而不是单点判断。",
    ],
    "sq5": [
        "监控系统标记了潜在异常。请分析下面的日志,判断是否需要采取响应措施,并给出建议。",
        "请根据以下设备记录评估当前是否需要安全响应;若需要,请说明原因与处置建议。",
    ],
}
|
|
|
|
# Episode counts per category and variant (TP / FP / TN) across all splits.
# category_plan_for_split() subtracts DEV_PLAN from this table to obtain
# the plan for every split other than "dev".
TOTAL_PLAN = {
    "intrusion": {"TP": 400, "FP": 300},
    "fire_gas": {"TP": 250, "FP": 200},
    "water_damage": {"TP": 150, "FP": 100},
    "device_fault": {"TP": 350, "FP": 250},
    "elderly_specific": {"TP": 150, "FP": 100},
    "child_specific": {"TP": 100, "FP": 80},
    "behavioral_anomaly": {"TP": 100, "FP": 70},
    "none": {"TN": 200},
}
|
|
|
|
# Episode counts per category and variant for the "dev" split.
# Every (category, variant) key here must also exist in TOTAL_PLAN, because
# category_plan_for_split() subtracts these counts from TOTAL_PLAN in place.
DEV_PLAN = {
    "intrusion": {"TP": 45, "FP": 30},
    "fire_gas": {"TP": 28, "FP": 22},
    "water_damage": {"TP": 15, "FP": 10},
    "device_fault": {"TP": 36, "FP": 29},
    "elderly_specific": {"TP": 15, "FP": 10},
    "child_specific": {"TP": 11, "FP": 9},
    "behavioral_anomaly": {"TP": 10, "FP": 10},
    "none": {"TN": 20},
}
|
|
|
|
# Threat-type labels grouped by scenario category.
# choose_wrong_threat() walks this table in insertion order to find a label
# from a *different* category, so the ordering of entries affects its result.
THREAT_FAMILY = {
    "intrusion": ["intrusion", "credential_theft", "tailgating"],
    "fire_gas": ["fire_risk", "carbon_monoxide", "unattended_cooking", "safety_device_failure"],
    "water_damage": ["water_leak"],
    "device_fault": ["sensor_stuck", "sensor_drift", "sensor_malfunction", "actuator_stuck", "lock_malfunction"],
    "elderly_specific": ["health_concern", "abnormal_inactivity", "possible_fall", "behavioral_anomaly"],
    "child_specific": ["child_safety"],
    "behavioral_anomaly": ["behavioral_anomaly"],
}
|
|
|
|
|
|
@dataclass
class SplitOutput:
    """Summary of one generated split, as returned by generate_split()."""

    # Split name exactly as passed to generate_split().
    split: str
    # Number of episodes generated (the "episodes_total" counter).
    episode_count: int
    # Number of preference-pair records written to pairs.jsonl.
    pair_count: int
    # Counters keyed "<category>_<variant>", "layout_<id>", "query_<family>"
    # and "episodes_total".
    stats: Dict[str, int]
    # Directory "<output_root>/<split>_pref_v1" holding the generated files.
    output_dir: Path
|
|
|
|
|
|
def parse_event_time(time_str: str, base_date: datetime) -> datetime:
    """Turn a scenario time label into an absolute datetime.

    Two label shapes are understood:

    * ``"HH:MM"`` - a clock time on ``base_date`` itself.
    * ``"DayN HH:MM"`` - a clock time on the N-th day, where ``Day1`` is
      ``base_date``.

    Args:
        time_str: The label taken from a scenario event template.
        base_date: The first day of the episode.

    Returns:
        ``base_date`` shifted by the day offset, with hour and minute taken
        from the label and seconds/microseconds zeroed.

    Raises:
        ValueError: If the label does not split into the expected parts or
            contains non-numeric components.
    """
    day_offset = 0
    clock = time_str
    if "Day" in time_str:
        # "Day3 08:15" -> day token "Day3", clock "08:15"; Day1 means offset 0.
        day_token, clock = time_str.split()
        day_offset = int(day_token.replace("Day", "")) - 1

    hour, minute = map(int, clock.split(":"))
    target_day = base_date + timedelta(days=day_offset)
    return target_day.replace(hour=hour, minute=minute, second=0, microsecond=0)
|
|
|
|
|
|
def scenario_num_days(scenario: Optional[Dict]) -> int:
    """Return how many days a scenario's anomaly events span.

    Only the ``"anomaly_events"`` pool is inspected. Each entry whose
    ``"time"`` label contains ``"Day"`` contributes its day number; labels
    without a day prefix count as day 1.

    Args:
        scenario: Scenario definition, or ``None`` / empty for a baseline.

    Returns:
        The highest day number found, never less than 1.
    """
    if not scenario:
        return 1

    day_numbers = [1]
    for pool_name in ("anomaly_events",):
        labels = (entry.get("time", "") for entry in scenario.get(pool_name, []))
        day_numbers.extend(
            int(label.split()[0].replace("Day", ""))
            for label in labels
            if "Day" in label
        )
    return max(day_numbers)
|
|
|
|
|
|
def choose_weighted_layout(layouts: Iterable[str], rng: random.Random) -> str:
    """Draw one layout id at random, biased by ``LAYOUT_WEIGHTS``.

    Args:
        layouts: Candidate layout ids; consumed once and materialised.
        rng: Random source; exactly one ``rng.choices`` call is made.

    Returns:
        The selected layout id. Ids absent from ``LAYOUT_WEIGHTS`` are
        weighted 0.25.
    """
    options = list(layouts)
    option_weights = []
    for layout_id in options:
        option_weights.append(LAYOUT_WEIGHTS.get(layout_id, 0.25))
    (picked,) = rng.choices(options, weights=option_weights, k=1)
    return picked
|
|
|
|
|
|
def choose_profile(layout_id: str, scenario: Optional[Dict], rng: random.Random) -> str:
    """Pick a resident profile that fits both the layout and the scenario.

    Args:
        layout_id: Key into ``PROFILE_BY_LAYOUT``.
        scenario: Scenario definition; its optional ``"applicable_profiles"``
            list narrows the layout's profiles. ``None`` or an empty dict
            applies no narrowing.
        rng: Random source; exactly one ``rng.choice`` call is made.

    Returns:
        The chosen profile id.

    Raises:
        KeyError: If ``layout_id`` is not in ``PROFILE_BY_LAYOUT``.
        IndexError: If no layout profile is permitted by the scenario.
    """
    permitted = None
    if scenario:
        permitted = scenario.get("applicable_profiles")

    pool = PROFILE_BY_LAYOUT[layout_id]
    if not permitted:
        return rng.choice(pool)
    return rng.choice([profile for profile in pool if profile in permitted])
|
|
|
|
|
|
def pick_room(home: HomeState, selector: str, rng: random.Random) -> Optional[str]:
    """Map a scenario room selector onto a concrete room id of ``home``.

    Resolution rules, by selector:

    * ``entrance`` / ``kitchen`` / ``balcony``: that room if present, else ``None``.
    * ``bathroom``: the first room whose id contains ``"bathroom"`` and which
      holds a ``water_leak_detector``; otherwise ``"bathroom"`` if present,
      else ``None``.
    * ``bathroom_pair``: ``"master_bathroom"`` if present, else ``None``.
    * ``child``: ``"kids_room"`` if present, else ``None``.
    * ``elderly`` / ``paired_temp_rooms`` / ``contact_room_pair``: always
      ``"living_room"`` (returned without checking that the room exists).
    * ``living_area`` and the generic room selectors: the first present room
      from a fixed preference list.
    * Anything else, or a preference list with no match: the first room of
      the home, or ``None`` when the home has no rooms.

    Args:
        home: Home whose rooms and devices are inspected.
        selector: Room selector from the scenario definition.
        rng: Accepted for signature compatibility; not consulted.

    Returns:
        A room id, or ``None`` when the selector cannot be satisfied.
    """
    rooms = home.get_room_ids()

    # Selectors that name exactly one room and yield None when it is absent.
    exact_room = {
        "entrance": "entrance",
        "kitchen": "kitchen",
        "balcony": "balcony",
        "bathroom_pair": "master_bathroom",
        "child": "kids_room",
    }
    if selector in exact_room:
        wanted = exact_room[selector]
        return wanted if wanted in rooms else None

    if selector == "bathroom":
        for room in rooms:
            if "bathroom" not in room:
                continue
            if any(dev.device_type == "water_leak_detector" for dev in home.get_devices_in_room(room)):
                return room
        return "bathroom" if "bathroom" in rooms else None

    if selector in {"elderly", "paired_temp_rooms", "contact_room_pair"}:
        return "living_room"

    # Selectors resolved by trying rooms in a fixed order of preference.
    generic_order = ("living_room", "bedroom", "master_bedroom", "second_bedroom")
    preference = {"living_area": ("living_room", "dining_room")}
    for generic in ("behavioral", "light_room", "window_room", "occupancy_room", "temp_room", "ac_room"):
        preference[generic] = generic_order

    for candidate in preference.get(selector, ()):
        if candidate in rooms:
            return candidate

    # Unknown selector, or no preferred room exists: fall back to the first room.
    return rooms[0] if rooms else None
|
|
|
|
|
|
def choose_device_by_type(home: HomeState, room_id: Optional[str], device_type: str) -> Optional[str]:
    """Find a device of the given type, preferring the given room.

    Args:
        home: Home whose devices are searched.
        room_id: Room searched first; skipped when falsy.
        device_type: Device type to match exactly.

    Returns:
        The id of the first matching device in the room, otherwise of the
        first matching device anywhere in ``home.devices``, otherwise ``None``.
    """

    def first_match(devices):
        # First device of the requested type, or None when there is none.
        return next((dev for dev in devices if dev.device_type == device_type), None)

    if room_id:
        local = first_match(home.get_devices_in_room(room_id))
        if local is not None:
            return local.device_id

    anywhere = first_match(home.devices.values())
    return None if anywhere is None else anywhere.device_id
|
|
|
|
|
|
def resolve_device_id(home: HomeState, device_template: str, room_id: Optional[str]) -> Optional[str]:
    """Resolve a scenario device template to a device id that exists in ``home``.

    Resolution is attempted in three stages:

    1. Substitute ``{room}`` with ``room_id`` (when given) and look the result
       up directly via ``home.get_device``.
    2. If the raw template is a known alias, pick a device by type, preferring
       ``room_id``.
    3. Otherwise return the first device whose id ends with, or contains, the
       raw template.

    Args:
        home: Home whose devices are searched.
        device_template: Device id or id template from the scenario.
        room_id: Target room used for ``{room}`` substitution and as the
            preferred room for alias lookups; may be ``None``.

    Returns:
        A device id, or ``None`` if nothing matches.
    """
    concrete_id = device_template.replace("{room}", room_id) if room_id else device_template
    if home.get_device(concrete_id):
        return concrete_id

    # Hard-coded ids used by scenario templates, mapped to the device type
    # that should stand in for them on layouts lacking that exact id.
    alias_to_type = {
        "living_room_window_contact": "contact_sensor",
        "bedroom_window_contact": "contact_sensor",
        "living_room_window": "window_covering",
        "bedroom_light": "dimmable_light",
        "bedroom_temp_sensor": "temperature_sensor",
        "living_room_light_1": "dimmable_light",
    }
    aliased_type = alias_to_type.get(device_template)
    if aliased_type is not None:
        return choose_device_by_type(home, room_id, aliased_type)

    # Last resort: loose match of the raw template against known device ids.
    loose_hit = next(
        (
            dev
            for dev in home.devices.values()
            if dev.device_id.endswith(device_template) or device_template in dev.device_id
        ),
        None,
    )
    return None if loose_hit is None else loose_hit.device_id
|
|
|
|
|
|
def materialize_events(
    home: HomeState,
    templates: List[Dict],
    base_date: datetime,
    room_id: Optional[str],
) -> List[Dict]:
    """Expand scenario event templates into concrete injected events.

    Each template is resolved to a real device and an absolute timestamp.
    Templates whose device cannot be resolved, or which describe neither an
    attribute change (``cluster`` + ``attribute``) nor a device event
    (``event_name``), are skipped silently.

    Args:
        home: Home used to resolve device templates.
        templates: Event templates with at least ``"device"`` and ``"time"``.
        base_date: First day of the episode, used to anchor time labels.
        room_id: Target room for ``{room}`` substitution; may be ``None``.

    Returns:
        Event dicts carrying ``timestamp`` (ISO string), ``device_id``,
        ``_is_injected`` set to ``True``, the event payload, and ``_note``
        when the template has a ``"note"``.
    """
    materialized = []
    for template in templates:
        resolved = resolve_device_id(home, template["device"], room_id)
        if not resolved:
            continue

        # The timestamp is parsed before the payload check, so a malformed
        # time label raises even for templates that would be skipped below.
        header = {
            "timestamp": parse_event_time(template["time"], base_date).isoformat(),
            "device_id": resolved,
            "_is_injected": True,
        }

        if "cluster" in template and "attribute" in template:
            payload = {
                "event_type": "attribute_change",
                "cluster": template["cluster"],
                "attribute": template["attribute"],
                "value": template["value"],
            }
        elif "event_name" in template:
            payload = {
                "event_type": "device_event",
                "event_name": template["event_name"],
                # Copied so later mutation of the event cannot leak into the scenario bank.
                "fields": deepcopy(template.get("fields", {})),
            }
        else:
            continue

        event = {**header, **payload}
        if "note" in template:
            event["_note"] = template["note"]
        materialized.append(event)
    return materialized
|
|
|
|
|
|
def format_event_for_display(event: Dict) -> str:
    """Render one event as a single human-readable log line.

    Args:
        event: Event dict; every key is optional.

    Returns:
        ``"[<timestamp>] <device_id> — <detail>"`` where the detail is
        ``cluster.attribute = value`` for attribute changes,
        ``Event: name(k=v, ...)`` for device events, and otherwise the raw
        event type (``unknown`` when the key is absent).
    """
    stamp = event.get("timestamp", "")
    device = event.get("device_id", "")
    lead = f"[{stamp}] {device} — "

    kind = event.get("event_type")
    if kind == "attribute_change":
        return lead + f"{event.get('cluster','')}.{event.get('attribute','')} = {event.get('value')}"

    if kind == "device_event":
        rendered_fields = ", ".join(f"{k}={v}" for k, v in event.get("fields", {}).items())
        return lead + f"Event: {event.get('event_name')}({rendered_fields})"

    return lead + f"{event.get('event_type','unknown')}"
|
|
|
|
|
|
def build_normal_events(home: HomeState, profile_id: str, base_date: datetime, num_days: int, seed: int) -> List[Dict]:
    """Generate routine background events for consecutive days.

    Args:
        home: Home passed through to ``generate_daily_events``.
        profile_id: Resident profile passed through to ``generate_daily_events``.
        base_date: First day to generate.
        num_days: Number of consecutive days to generate.
        seed: Base seed; day ``i`` (0-based) is generated with ``seed + i``.

    Returns:
        The per-day event lists concatenated in day order.
    """
    timeline: List[Dict] = []
    for offset in range(num_days):
        current_day = base_date + timedelta(days=offset)
        # Monday..Friday are weekday() 0..4.
        if current_day.weekday() < 5:
            day_kind = "weekday"
        else:
            day_kind = "weekend"
        timeline += generate_daily_events(home, profile_id, day_kind, current_day, seed=seed + offset)
    return timeline
|
|
|
|
|
|
def default_tn_query_family(rng: random.Random) -> str:
    """Pick a query family uniformly at random for a baseline (TN) episode.

    Args:
        rng: Random source; exactly one ``rng.choice`` call is made.

    Returns:
        One of ``"sq1"`` .. ``"sq5"``.
    """
    families = ("sq1", "sq2", "sq3", "sq4", "sq5")
    return rng.choice(families)
|
|
|
|
|
|
def build_episode(
    episode_id: str,
    variant: str,
    scenario: Optional[Dict],
    layout_id: str,
    profile_id: str,
    base_date: datetime,
    query_family: str,
    seed: int,
) -> Dict:
    """Assemble one complete episode: home, event log, query and ground truth.

    Background events are generated for as many days as the scenario's
    anomaly events span, then the variant-specific events are injected:

    * ``"TP"``: the scenario's ``anomaly_events``; ground truth marks an anomaly.
    * ``"FP"``: the scenario's ``false_positive_variant`` events; ground truth
      marks a normal situation and carries the ``key_difference``.
    * anything else: no injected events and a fixed "Normal baseline"
      ground truth.

    Args:
        episode_id: Identifier stored on the episode.
        variant: ``"TP"``, ``"FP"`` or a baseline label such as ``"TN"``.
        scenario: Scenario definition; required for ``"TP"`` and ``"FP"``.
        layout_id: Layout used to construct the ``HomeState``.
        profile_id: Resident profile for background events.
        base_date: First day of the episode.
        query_family: Key into ``QUERY_TEMPLATES``.
        seed: Seeds background events, room picking and (``seed + 7``) the
            query wording.

    Returns:
        The episode dict with ``episode_id``, ``metadata``, ``home_state``,
        ``event_sequence``, ``event_log``, ``query`` and ``ground_truth``.
    """
    home = HomeState(layout_id)
    home.randomize_temperatures()

    num_days = max(1, scenario_num_days(scenario))
    background = build_normal_events(home, profile_id, base_date, num_days, seed)
    room_id = None
    if scenario:
        room_id = pick_room(home, scenario["room_selector"], random.Random(seed))

    if variant in ("TP", "FP"):
        is_true_positive = variant == "TP"
        if is_true_positive:
            templates = scenario["anomaly_events"]
        else:
            templates = scenario["false_positive_variant"]["events"]
        injected = materialize_events(home, templates, base_date, room_id)

        rated = scenario["ground_truth"]
        ground_truth = {
            "scenario_id": scenario["scenario_id"],
            "scenario_name": scenario["name"],
            "category": scenario["category"],
            "is_anomaly": is_true_positive,
        }
        if is_true_positive:
            ground_truth["threat_type"] = scenario["threat_type"]
            ground_truth["key_evidence"] = rated["key_evidence"]
            ground_truth["expected_response"] = rated["expected_response"]
        else:
            key_difference = scenario["false_positive_variant"]["key_difference"]
            ground_truth["threat_type"] = "none"
            ground_truth["key_evidence"] = [key_difference]
            ground_truth["expected_response"] = "判定为正常模式,不需要报警。"
        # Difficulty rating is shared by both variants of a scenario.
        for rating_key in ("difficulty_score", "difficulty_level", "difficulty_label", "difficulty_dimensions"):
            ground_truth[rating_key] = rated[rating_key]
        ground_truth["variant"] = variant
        ground_truth["target_room"] = room_id
        if not is_true_positive:
            ground_truth["key_difference"] = key_difference
    else:
        injected = []
        ground_truth = {
            "scenario_id": "TN",
            "scenario_name": "Normal baseline",
            "category": "none",
            "is_anomaly": False,
            "threat_type": "none",
            "key_evidence": ["日志中未出现持续矛盾、直接告警或高风险时序链路。"],
            "expected_response": "判定正常,不需要安全响应。",
            "difficulty_score": 5,
            "difficulty_level": 1,
            "difficulty_label": "L1",
            "difficulty_dimensions": {
                "D1_evidence_count": 1,
                "D2_signal_directness": 1,
                "D3_cross_device": 1,
                "D4_temporal_span": 1,
                "D5_fp_similarity": 1,
            },
            "variant": "TN",
            "target_room": room_id,
        }

    # Stable sort: background events precede injected ones on equal timestamps.
    timeline = background + injected
    timeline.sort(key=lambda item: item["timestamp"])

    query = random.Random(seed + 7).choice(QUERY_TEMPLATES[query_family])
    if num_days > 1 and query_family == "sq3":
        # Make the sq3 wording state the actual number of days covered.
        query = query.replace("过去一段时间", f"过去 {num_days} 天")
        query = query.replace("下面的时序日志", f"下面 {num_days} 天的时序日志")

    metadata = {
        "sq_type": query_family.upper(),
        "sq_name": f"DPO_{query_family}",
        "variant": variant,
        "layout_id": layout_id,
        "layout_name": home.layout_data.get("_meta", {}).get("layout_name", ""),
        "profile_id": profile_id,
        "base_date": base_date.isoformat(),
        "seed": seed,
        "scenario_id": ground_truth["scenario_id"],
        "query_family": query_family,
        "total_events": len(timeline),
        "anomaly_events_count": len(injected),
        "split_family": "train_pref_v1",
        "focus_room": room_id,
        "focus_devices": sorted({item["device_id"] for item in injected}),
    }
    return {
        "episode_id": episode_id,
        "metadata": metadata,
        "home_state": home.snapshot(),
        "event_sequence": timeline,
        "event_log": "\n".join(format_event_for_display(item) for item in timeline),
        "query": query,
        "ground_truth": ground_truth,
    }
|
|
|
|
|
|
def build_rule_chosen_answer(episode: Dict) -> str:
    """Build the rule-based "chosen" answer for an episode as a JSON string.

    The answer is derived purely from the episode's ground truth.

    * Anomalous episodes report the ground-truth threat, up to three pieces
      of key evidence (one reasoning step each) and the expected response.
      Confidence is ``"high"`` only when ``difficulty_level`` is below 2.
    * Normal episodes report ``threat_type`` ``"none"`` with a single
      reasoning step taken from ``key_difference``, else from the first key
      evidence entry, else a generic sentence.

    A missing, ``None`` or empty ``key_evidence`` is treated as "no evidence"
    rather than raising, so the generic fallback sentences are used.

    Args:
        episode: Episode dict containing a ``"ground_truth"`` mapping with at
            least ``"is_anomaly"`` (plus ``"threat_type"`` and
            ``"scenario_name"`` for anomalous episodes).

    Returns:
        The answer serialised with ``ensure_ascii=False`` and ``indent=2``.
    """
    gt = episode["ground_truth"]
    # Normalise once so absent / None / empty evidence all behave the same.
    evidence = list(gt.get("key_evidence") or [])

    if gt["is_anomaly"]:
        reasoning = [f"推理步骤{idx}: {item}" for idx, item in enumerate(evidence[:3], start=1)]
        if not reasoning:
            reasoning = ["推理步骤1: 未发现足以支持异常的直接证据。"]
        answer = {
            "is_anomaly": True,
            "confidence": "medium" if gt.get("difficulty_level", 2) >= 2 else "high",
            "threat_type": gt["threat_type"],
            "threat_description": gt["scenario_name"],
            "reasoning": reasoning,
            "key_evidence": evidence[:3],
            "recommended_actions": [gt.get("expected_response", "建议人工复核。")],
        }
    else:
        # Prefer the explicit FP explanation, then the first evidence entry.
        fallback_reason = evidence[0] if evidence else "日志可由正常行为解释。"
        normal_reason = gt.get("key_difference") or fallback_reason
        answer = {
            "is_anomaly": False,
            "confidence": "medium",
            "threat_type": "none",
            "threat_description": "当前日志更符合正常模式,不建议报警。",
            "reasoning": [f"推理步骤1: {normal_reason}"],
            "key_evidence": evidence[:2],
            "recommended_actions": ["继续监测,无需立即处置。"],
        }
    return json.dumps(answer, ensure_ascii=False, indent=2)
|
|
|
|
|
|
def choose_wrong_threat(category: str, correct: str) -> str:
    """Return a plausible but wrong threat label for a rejected answer.

    Labels of every ``THREAT_FAMILY`` category other than ``category`` are
    scanned in table order; the first one that differs from ``correct`` and
    from ``"none"`` is returned. The choice is deterministic.

    Args:
        category: Category of the episode, whose own labels are excluded.
        correct: The correct threat label, which must not be returned.

    Returns:
        A wrong threat label, or ``"behavioral_anomaly"`` if none qualifies.
    """
    foreign_labels = [
        label
        for family, labels in THREAT_FAMILY.items()
        if family != category
        for label in labels
    ]
    return next(
        (label for label in foreign_labels if label != correct and label != "none"),
        "behavioral_anomaly",
    )
|
|
|
|
|
|
def build_rejected_answer(episode: Dict, rng: random.Random) -> Tuple[str, str]:
    """Construct a deliberately wrong "rejected" answer for an episode.

    For ``"TP"`` episodes one of three error types is drawn:

    * ``"miss"``: the anomaly is denied.
    * ``"type_confusion"``: an anomaly is reported under a wrong threat type.
    * ``"fabricated_evidence"``: the right threat is reported but justified
      with evidence that is not in the log.

    For every other variant a false alarm is produced. The error label is
    drawn from ``"false_alarm"`` / ``"false_alarm_hard"``, but the answer
    body is the same for both labels.

    Args:
        episode: Episode dict whose ``"ground_truth"`` provides ``category``,
            ``variant`` and, for TP episodes, ``threat_type`` and
            ``scenario_name``.
        rng: Random source; exactly one ``rng.choice`` call is made.

    Returns:
        ``(answer_json, error_type)`` where the JSON is serialised with
        ``ensure_ascii=False`` and ``indent=2``.
    """
    gt = episode["ground_truth"]
    category = gt["category"]
    variant = gt["variant"]

    def serialise(payload: Dict) -> str:
        # Same JSON settings for every rejected answer.
        return json.dumps(payload, ensure_ascii=False, indent=2)

    if variant != "TP":
        error_type = rng.choice(["false_alarm", "false_alarm_hard"])
        # Baseline episodes have category "none"; exclude a real family instead.
        excluded_family = category if category != "none" else "behavioral_anomaly"
        false_alarm = {
            "is_anomaly": True,
            "confidence": "medium",
            "threat_type": choose_wrong_threat(excluded_family, "none"),
            "threat_description": "这些变化表明存在安全异常。",
            "reasoning": ["推理步骤1: 我把本可由正常模式解释的事件误判成了异常。"],
            "key_evidence": gt.get("key_evidence", [])[:1] or ["一些可疑线索。"],
            "recommended_actions": ["建议立即检查并触发安全响应。"],
        }
        return serialise(false_alarm), error_type

    error_type = rng.choice(["miss", "type_confusion", "fabricated_evidence"])
    if error_type == "miss":
        return serialise({
            "is_anomaly": False,
            "confidence": "medium",
            "threat_type": "none",
            "threat_description": "这些变化仍可视为正常波动。",
            "reasoning": ["推理步骤1: 没有看到必须报警的直接证据。"],
            "key_evidence": ["设备状态变化幅度看起来仍在可接受范围内。"],
            "recommended_actions": ["保持观察即可。"],
        }), error_type

    if error_type == "type_confusion":
        return serialise({
            "is_anomaly": True,
            "confidence": "low",
            "threat_type": choose_wrong_threat(category, gt["threat_type"]),
            "threat_description": "存在异常,但原因更可能属于其他类别。",
            "reasoning": ["推理步骤1: 日志里有异常,但我把它归因为另一类威胁。"],
            "key_evidence": gt.get("key_evidence", [])[:1],
            "recommended_actions": ["建议按该错误类别处理。"],
        }), error_type

    # Remaining case: "fabricated_evidence".
    return serialise({
        "is_anomaly": True,
        "confidence": "high",
        "threat_type": gt["threat_type"],
        "threat_description": gt["scenario_name"],
        "reasoning": ["推理步骤1: 日志中出现了不存在的额外报警事件。"],
        "key_evidence": ["我假设看到了一个未在日志中出现的 critical 告警。"],
        "recommended_actions": [gt.get("expected_response", "立即处理。")],
    }), error_type
|
|
|
|
|
|
def category_plan_for_split(split: str) -> Dict[str, Dict[str, int]]:
    """Return the per-category, per-variant episode counts for a split.

    Args:
        split: ``"dev"`` selects ``DEV_PLAN``; any other value selects what
            is left of ``TOTAL_PLAN`` after removing the dev counts.

    Returns:
        A fresh nested dict (the module-level plans are never mutated).

    Raises:
        KeyError: If ``DEV_PLAN`` names a category or variant that
            ``TOTAL_PLAN`` lacks.
    """
    if split == "dev":
        return deepcopy(DEV_PLAN)

    remaining = deepcopy(TOTAL_PLAN)
    for category in DEV_PLAN:
        for variant in DEV_PLAN[category]:
            remaining[category][variant] = remaining[category][variant] - DEV_PLAN[category][variant]
    return remaining
|
|
|
|
|
|
def scenario_ids_by_category(split: str) -> Dict[str, List[str]]:
    """Group the scenario ids available to a split by scenario category.

    Args:
        split: Split name forwarded to ``get_scenarios_for_split``.

    Returns:
        A ``defaultdict(list)`` mapping category to scenario ids, in the
        iteration order of the scenario mapping. Looking up an unknown
        category yields (and inserts) an empty list.
    """
    by_category: Dict[str, List[str]] = defaultdict(list)
    for scenario_id, spec in get_scenarios_for_split(split).items():
        by_category[spec["category"]].append(scenario_id)
    return by_category
|
|
|
|
|
|
def round_robin_ids(ids: List[str], total: int) -> List[str]:
    """Repeat ``ids`` cyclically until ``total`` entries have been produced.

    Args:
        ids: Ids to cycle through, in order.
        total: Desired length of the result.

    Returns:
        A list of length ``total`` (empty when ``ids`` is empty or ``total``
        is not positive).
    """
    if not ids:
        return []
    cycle_length = len(ids)
    return [ids[position % cycle_length] for position in range(total)]
|
|
|
|
|
|
def build_prompt_record(episode: Dict) -> Dict:
    """Reduce the DPO prompt for an episode to its system and user parts.

    Args:
        episode: Episode dict forwarded to ``build_dpo_prompt``.

    Returns:
        ``{"system": ..., "user": ...}``; ``system`` defaults to an empty
        string when the built prompt has none.

    Raises:
        KeyError: If the built prompt has no ``"user"`` entry.
    """
    built = build_dpo_prompt(episode)
    system_text = built.get("system", "")
    user_text = built["user"]
    return {"system": system_text, "user": user_text}
|
|
|
|
|
|
def generate_split(
    output_root: Path,
    split: str,
    seed: int = 20260506,
    max_episodes: Optional[int] = None,
) -> SplitOutput:
    """Generate all episodes and preference pairs for one split and write them to disk.

    Files written under ``<output_root>/<split>_pref_v1``:

    * ``episodes/<episode_id>.json`` - one file per episode.
    * ``pairs.jsonl`` - one chosen/rejected record per episode.
    * ``summary.json`` - split name, counts and counters.

    Args:
        output_root: Directory under which the split directory is created.
        split: Split name; ``"train"`` seeds the sampler with ``seed``, any
            other value with ``seed + 999``.
        seed: Base seed. Episode ``i`` is built with ``seed + i``.
        max_episodes: Optional cap on the number of episodes generated.

    Returns:
        A ``SplitOutput`` describing what was generated.
    """
    rng_offset = 0 if split == "train" else 999
    rng = random.Random(seed + rng_offset)

    split_dir = output_root / f"{split}_pref_v1"
    episodes_dir = split_dir / "episodes"
    episodes_dir.mkdir(parents=True, exist_ok=True)

    plan = category_plan_for_split(split)
    scenarios = get_scenarios_for_split(split)
    grouped = scenario_ids_by_category(split)
    counters = Counter()
    pair_records = []

    episode_index = 0

    def limit_reached() -> bool:
        # True once the optional max_episodes cap has been hit.
        return max_episodes is not None and episode_index >= max_episodes

    # The order of rng calls below determines the generated dataset; keep it.
    for category, variants in plan.items():
        for variant, count in variants.items():
            if limit_reached():
                break

            if category == "none":
                schedule = [None] * count
            else:
                schedule = round_robin_ids(grouped[category], count)
                # NOTE(review): indentation was lost in the copy this was
                # reconstructed from; the shuffle is assumed to belong to this
                # branch only. Confirm against the original file, since moving
                # it changes the rng stream for baseline episodes.
                rng.shuffle(schedule)

            for scenario_id in schedule:
                if limit_reached():
                    break

                scenario = scenarios.get(scenario_id) if scenario_id else None
                if scenario:
                    allowed_layouts = scenario["applicable_layouts"]
                else:
                    allowed_layouts = ["A", "B", "C", "D"]
                layout_id = choose_weighted_layout(allowed_layouts, rng)
                profile_id = choose_profile(layout_id, scenario or {}, rng)
                if scenario:
                    query_family = rng.choice(scenario["query_families"])
                else:
                    query_family = default_tn_query_family(rng)
                day_shift = rng.randint(0, 59)
                base_date = datetime(2026, 7, 1) + timedelta(days=day_shift)

                episode_id = f"DPO_{split.upper()}_{variant}_{episode_index:05d}"
                episode = build_episode(
                    episode_id=episode_id,
                    variant=variant,
                    scenario=scenario,
                    layout_id=layout_id,
                    profile_id=profile_id,
                    base_date=base_date,
                    query_family=query_family,
                    seed=seed + episode_index,
                )

                episode_path = episodes_dir / f"{episode_id}.json"
                with open(episode_path, "w", encoding="utf-8") as handle:
                    json.dump(episode, handle, ensure_ascii=False, indent=2)

                prompt_record = build_prompt_record(episode)
                chosen = build_rule_chosen_answer(episode)
                rejected, rejected_error = build_rejected_answer(episode, rng)
                truth = episode["ground_truth"]

                pair_records.append({
                    "pair_id": f"{episode_id}_pair",
                    "split": split,
                    "episode_id": episode_id,
                    "episode_path": str(episode_path),
                    "scenario_id": truth["scenario_id"],
                    "category": truth["category"],
                    "variant": variant,
                    "prompt": prompt_record,
                    "chosen": chosen,
                    "rejected": rejected,
                    "chosen_source": "rule",
                    "rejected_source": "constructed",
                    "rejected_error_type": rejected_error,
                    "metadata": {
                        "layout_id": layout_id,
                        "profile_id": profile_id,
                        "query_family": query_family,
                        "difficulty_level": truth["difficulty_level"],
                    },
                })

                counters.update((
                    f"{category}_{variant}",
                    f"layout_{layout_id}",
                    f"query_{query_family}",
                    "episodes_total",
                ))
                episode_index += 1

        if limit_reached():
            break

    with open(split_dir / "pairs.jsonl", "w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in pair_records)

    episode_count = counters["episodes_total"]
    pair_count = len(pair_records)

    summary = {
        "split": split,
        "episode_count": episode_count,
        "pair_count": pair_count,
        "stats": dict(counters),
    }
    with open(split_dir / "summary.json", "w", encoding="utf-8") as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=2)

    return SplitOutput(
        split=split,
        episode_count=episode_count,
        pair_count=pair_count,
        stats=dict(counters),
        output_dir=split_dir,
    )
|