274 lines
10 KiB
Python
274 lines
10 KiB
Python
from __future__ import annotations
|
|
|
|
from collections import Counter, defaultdict
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Sequence, Set, Tuple
|
|
|
|
from src.evaluation.prompt_builder import RESPONSE_FORMAT_BASELINE, SYSTEM_PROMPT_BASELINE
|
|
|
|
|
|
# Per query family: the most events `_select_events` keeps (the cap is applied
# before `_compress_events` thins them further). Unknown families use 120 there.
PROMPT_EVENT_LIMITS = dict(sq1=140, sq2=80, sq3=140, sq4=150, sq5=100)

# Per query family: half-width, in minutes, of the window kept around each
# injected event (focus-device/focus-room events are kept out to twice this).
# With no injected events, the trailing window before the last event is twice
# this value. Unknown families use 60.
WINDOW_MINUTES = dict(sq1=180, sq2=35, sq3=50, sq4=60, sq5=45)
|
|
|
|
|
|
def _parse_ts(ts: str) -> datetime:
|
|
return datetime.fromisoformat(ts)
|
|
|
|
|
|
def _compact_home_summary(home_state: Dict, focus_rooms: Sequence[str]) -> str:
|
|
focus_rooms = [room for room in focus_rooms if room]
|
|
lines = [f"家庭布局: {home_state.get('layout_name', '未知布局')}"]
|
|
devices_by_room = defaultdict(list)
|
|
for dev_id, info in home_state.get("devices", {}).items():
|
|
room = info.get("room_id", "unknown")
|
|
devices_by_room[room].append((info.get("display_name", dev_id), info.get("device_type", "")))
|
|
|
|
ordered_rooms = focus_rooms + [room for room in sorted(devices_by_room) if room not in focus_rooms]
|
|
for room in ordered_rooms:
|
|
if room not in devices_by_room:
|
|
continue
|
|
names = [f"{name}[{dtype}]" for name, dtype in devices_by_room[room][:8]]
|
|
if len(devices_by_room[room]) > 8:
|
|
names.append(f"... 共{len(devices_by_room[room])}个设备")
|
|
prefix = "重点房间" if room in focus_rooms else "房间"
|
|
lines.append(f"- {prefix} {room}: " + ", ".join(names))
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _event_room(event: Dict, home_state: Dict) -> str:
|
|
device = home_state.get("devices", {}).get(event.get("device_id", ""), {})
|
|
return device.get("room_id", "unknown")
|
|
|
|
|
|
def _format_event(event: Dict) -> str:
|
|
ts = event.get("timestamp", "")
|
|
dev = event.get("device_id", "")
|
|
if event.get("event_type") == "attribute_change":
|
|
return f"[{ts}] {dev} | {event.get('cluster','')}.{event.get('attribute','')} = {event.get('value')}"
|
|
if event.get("event_type") == "device_event":
|
|
fields = ", ".join(f"{k}={v}" for k, v in event.get("fields", {}).items())
|
|
return f"[{ts}] {dev} | Event: {event.get('event_name','')}({fields})"
|
|
if event.get("event_type") == "command":
|
|
return f"[{ts}] {dev} | Command: {event.get('command','')}"
|
|
return f"[{ts}] {dev}"
|
|
|
|
|
|
def _focus_devices(episode: Dict) -> Set[str]:
|
|
focus = set(episode.get("metadata", {}).get("focus_devices", []))
|
|
for event in episode.get("event_sequence", []):
|
|
if event.get("_is_injected"):
|
|
focus.add(event.get("device_id", ""))
|
|
focus.discard("")
|
|
return focus
|
|
|
|
|
|
def _focus_rooms(episode: Dict) -> List[str]:
|
|
meta = episode.get("metadata", {})
|
|
gt = episode.get("ground_truth", {})
|
|
rooms = []
|
|
for room in [meta.get("focus_room"), gt.get("target_room")]:
|
|
if room and room not in rooms:
|
|
rooms.append(room)
|
|
return rooms
|
|
|
|
|
|
def _select_events(episode: Dict) -> List[Dict]:
    """Pick the subset of ``episode["event_sequence"]`` that goes into the prompt.

    Selection steps:

    1. Every injected event (``_is_injected``) is kept.
    2. If there are injected events, keep every event within
       ``WINDOW_MINUTES[family]`` (default 60) of the nearest injected event.
       Also keep focus-device / focus-room events within twice that distance.
       Otherwise, keep every event within twice the window before the last event.
    3. For each focus device, add roughly every ``len/6``-th of its events plus
       its last one, to preserve a sparse trend history.
    4. If more than ``PROMPT_EVENT_LIMITS[family]`` (default 120) indices remain,
       cap them with ``_cap_indices``.

    The capped events, in original order, are passed through ``_compress_events``.
    """
    events = episode.get("event_sequence", [])
    query_family = episode.get("metadata", {}).get("query_family", "sq3")
    focus_devices = _focus_devices(episode)
    home_state = episode.get("home_state", {})
    focus_rooms = _focus_rooms(episode)

    # Parse each timestamp once. Previously every event was re-parsed once per
    # injected event.
    times = [_parse_ts(event["timestamp"]) for event in events]
    radius = timedelta(minutes=WINDOW_MINUTES.get(query_family, 60))

    injected_indices = [idx for idx, event in enumerate(events) if event.get("_is_injected")]
    selected_indices: Set[int] = set(injected_indices)

    if injected_indices:
        centers = [times[idx] for idx in injected_indices]
        for idx, event in enumerate(events):
            # Both window tests are monotone in the distance to the nearest injected event.
            nearest = min(abs(times[idx] - center) for center in centers)
            if nearest <= radius:
                selected_indices.add(idx)
            elif nearest <= radius * 2 and (
                event.get("device_id") in focus_devices
                or _event_room(event, home_state) in focus_rooms
            ):
                selected_indices.add(idx)
    elif events:
        # No injected events: keep the trailing 2x window before the last event.
        last_dt = times[-1]
        selected_indices.update(
            idx for idx, event_dt in enumerate(times) if last_dt - event_dt <= radius * 2
        )

    # Bring in a sparse history for focus devices to preserve trends.
    if focus_devices:
        device_buckets: Dict[str, List[int]] = defaultdict(list)
        for idx, event in enumerate(events):
            if event.get("device_id") in focus_devices:
                device_buckets[event["device_id"]].append(idx)
        for indices in device_buckets.values():
            stride = max(1, len(indices) // 6)
            selected_indices.update(indices[::stride])
            selected_indices.add(indices[-1])

    ordered = sorted(selected_indices)
    limit = PROMPT_EVENT_LIMITS.get(query_family, 120)
    if len(ordered) > limit:
        ordered = _cap_indices(ordered, set(injected_indices), limit)
    return _compress_events([events[idx] for idx in ordered], episode)


def _cap_indices(ordered: List[int], injected: Set[int], limit: int) -> List[int]:
    """Reduce sorted event indices to at most ``limit``, preferring injected ones.

    Injected indices are kept first (the earliest ``limit`` of them if they alone
    exceed the budget). The remaining budget is filled with non-injected indices
    picked at even spacing across their whole span. The result is sorted.
    """
    kept = [idx for idx in ordered if idx in injected]
    if len(kept) >= limit:
        return kept[:limit]
    others = [idx for idx in ordered if idx not in injected]
    remaining = limit - len(kept)
    if len(others) <= remaining:
        return sorted(kept + others)
    # The old `others[::len(others) // remaining][:remaining]` degenerated to
    # "keep only the earliest `remaining`" whenever len(others) < 2 * remaining,
    # silently dropping the end of the window. Integer even spacing picks
    # `remaining` distinct indices spread over the whole list.
    sampled = [others[i * len(others) // remaining] for i in range(remaining)]
    return sorted(kept + sampled)
|
|
|
|
|
|
def _compress_events(events: List[Dict], episode: Dict) -> List[Dict]:
    """Thin out repetitive attribute telemetry from an already-selected event list.

    Injected events and non-``attribute_change`` events always pass through.
    An ``attribute_change`` event is dropped when either of these holds:

    * the episode is ``sq2``/``sq5``, the cluster is ``TemperatureMeasurement``,
      and neither the device nor its room is a focus device/room;
    * the same (device, cluster, attribute, str(value), room) combination was
      last *kept* less than N minutes earlier. N is 20 for
      ``OccupancySensing``, 30 for ``TemperatureMeasurement`` and 15 otherwise.

    Surviving events keep their relative order.
    """
    family = episode.get("metadata", {}).get("query_family", "sq3")
    focus_devices = _focus_devices(episode)
    home_state = episode.get("home_state", {})
    focus_rooms = set(_focus_rooms(episode))

    gap_minutes = {"OccupancySensing": 20, "TemperatureMeasurement": 30}
    drop_remote_temperature = family in {"sq2", "sq5"}

    kept: List[Dict] = []
    last_kept_at: Dict[Tuple[str, str, str, str, str], datetime] = {}

    for event in events:
        if event.get("_is_injected") or event.get("event_type") != "attribute_change":
            kept.append(event)
            continue

        cluster = event.get("cluster", "")
        attribute = event.get("attribute", "")
        value_text = str(event.get("value"))
        room = _event_room(event, home_state)

        # For local decision tasks, drop unrelated temperature telemetry.
        outside_focus = event.get("device_id") not in focus_devices and room not in focus_rooms
        if drop_remote_temperature and cluster == "TemperatureMeasurement" and outside_focus:
            continue

        signature = (event.get("device_id", ""), cluster, attribute, value_text, room)
        stamp = _parse_ts(event["timestamp"])
        previous = last_kept_at.get(signature)
        if previous is not None and stamp - previous < timedelta(minutes=gap_minutes.get(cluster, 15)):
            continue

        last_kept_at[signature] = stamp
        kept.append(event)

    return kept
|
|
|
|
|
|
def _daily_activity_summary(episode: Dict) -> str:
|
|
query_family = episode.get("metadata", {}).get("query_family", "sq3")
|
|
if query_family not in {"sq3", "sq4", "sq5"}:
|
|
return ""
|
|
|
|
home_state = episode.get("home_state", {})
|
|
room_day_counter: Dict[Tuple[str, str], Counter] = defaultdict(Counter)
|
|
for event in episode.get("event_sequence", []):
|
|
room = _event_room(event, home_state)
|
|
day = event.get("timestamp", "")[:10]
|
|
if event.get("event_type") == "attribute_change" and event.get("cluster") == "TemperatureMeasurement":
|
|
continue
|
|
room_day_counter[day][room] += 1
|
|
|
|
focus_rooms = _focus_rooms(episode)
|
|
lines = ["## 跨天活动摘要"]
|
|
for day in sorted(room_day_counter):
|
|
room_counts = room_day_counter[day]
|
|
ordered_rooms = focus_rooms + [room for room, _ in room_counts.most_common() if room not in focus_rooms]
|
|
parts = []
|
|
seen = set()
|
|
for room in ordered_rooms:
|
|
if room in seen:
|
|
continue
|
|
seen.add(room)
|
|
parts.append(f"{room}={room_counts.get(room, 0)}")
|
|
if len(parts) >= 6:
|
|
break
|
|
lines.append(f"- {day}: " + ", ".join(parts))
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _build_focus_note(episode: Dict, selected_events: Sequence[Dict]) -> str:
    """Describe, for the model, how the event log that follows was selected.

    The note lists:

    * the upper-cased query family;
    * the focus rooms, if any;
    * at most the first eight focus devices in sorted order, if any;
    * the number of selected events.
    """
    family = episode.get("metadata", {}).get("query_family", "sq3")
    devices = sorted(_focus_devices(episode))
    rooms = _focus_rooms(episode)

    note = [
        "## 已选日志说明",
        "以下日志不是整天原始全量心跳,而是围绕关键设备、关键时间窗和跨天模式筛出的训练片段。",
        f"- 任务类型: {family.upper()}",
    ]
    if rooms:
        note.append(f"- 重点房间: {', '.join(rooms)}")
    if devices:
        note.append(f"- 重点设备: {', '.join(devices[:8])}")
    note.append(f"- 已选事件数: {len(selected_events)}")
    return "\n".join(note)
|
|
|
|
|
|
def build_dpo_prompt(episode: Dict) -> Dict[str, str]:
    """Assemble the system/user prompt pair for one episode.

    The user message is built from these sections, joined by newlines:

    1. the compact home layout (focus rooms first);
    2. a note describing how the logs were selected;
    3. an optional cross-day activity summary;
    4. the selected device event log, one line per event;
    5. the task query;
    6. ``RESPONSE_FORMAT_BASELINE``.

    Returns:
        ``{"system": SYSTEM_PROMPT_BASELINE, "user": <assembled text>}``.
    """
    selected = _select_events(episode)

    sections: List[str] = [
        "## 家庭环境信息",
        _compact_home_summary(episode.get("home_state", {}), _focus_rooms(episode)),
        "",
        _build_focus_note(episode, selected),
        "",
    ]

    summary = _daily_activity_summary(episode)
    if summary:
        sections += [summary, ""]

    sections.append("## 设备事件日志")
    sections += [_format_event(event) for event in selected]
    sections += ["", "## 任务", episode.get("query", ""), "", RESPONSE_FORMAT_BASELINE]

    return {"system": SYSTEM_PROMPT_BASELINE, "user": "\n".join(sections)}
|