# File: llmiotsafe/DPODataGen/prompt_builder.py
# 2026-05-12 17:01:39 +08:00
#
# 274 lines
# 10 KiB
# Python

from __future__ import annotations
from collections import Counter, defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Sequence, Set, Tuple
from src.evaluation.prompt_builder import RESPONSE_FORMAT_BASELINE, SYSTEM_PROMPT_BASELINE
# Maximum number of events kept when selecting log lines for a prompt, keyed
# by query family. `_select_events` looks the family up here and falls back
# to 120 for families that are not listed.
PROMPT_EVENT_LIMITS = {
    "sq1": 140,
    "sq2": 80,
    "sq3": 140,
    "sq4": 150,
    "sq5": 100,
}
# Time-window size in minutes, keyed by query family. `_select_events` uses
# the value as the radius around each injected event, doubles it for events
# on focus devices / in focus rooms, and doubles it for the trailing window
# taken when an episode has no injected events. Unlisted families fall back
# to 60 there.
WINDOW_MINUTES = {
    "sq1": 180,
    "sq2": 35,
    "sq3": 50,
    "sq4": 60,
    "sq5": 45,
}
def _parse_ts(ts: str) -> datetime:
return datetime.fromisoformat(ts)
def _compact_home_summary(home_state: Dict, focus_rooms: Sequence[str]) -> str:
focus_rooms = [room for room in focus_rooms if room]
lines = [f"家庭布局: {home_state.get('layout_name', '未知布局')}"]
devices_by_room = defaultdict(list)
for dev_id, info in home_state.get("devices", {}).items():
room = info.get("room_id", "unknown")
devices_by_room[room].append((info.get("display_name", dev_id), info.get("device_type", "")))
ordered_rooms = focus_rooms + [room for room in sorted(devices_by_room) if room not in focus_rooms]
for room in ordered_rooms:
if room not in devices_by_room:
continue
names = [f"{name}[{dtype}]" for name, dtype in devices_by_room[room][:8]]
if len(devices_by_room[room]) > 8:
names.append(f"... 共{len(devices_by_room[room])}个设备")
prefix = "重点房间" if room in focus_rooms else "房间"
lines.append(f"- {prefix} {room}: " + ", ".join(names))
return "\n".join(lines)
def _event_room(event: Dict, home_state: Dict) -> str:
device = home_state.get("devices", {}).get(event.get("device_id", ""), {})
return device.get("room_id", "unknown")
def _format_event(event: Dict) -> str:
ts = event.get("timestamp", "")
dev = event.get("device_id", "")
if event.get("event_type") == "attribute_change":
return f"[{ts}] {dev} | {event.get('cluster','')}.{event.get('attribute','')} = {event.get('value')}"
if event.get("event_type") == "device_event":
fields = ", ".join(f"{k}={v}" for k, v in event.get("fields", {}).items())
return f"[{ts}] {dev} | Event: {event.get('event_name','')}({fields})"
if event.get("event_type") == "command":
return f"[{ts}] {dev} | Command: {event.get('command','')}"
return f"[{ts}] {dev}"
def _focus_devices(episode: Dict) -> Set[str]:
focus = set(episode.get("metadata", {}).get("focus_devices", []))
for event in episode.get("event_sequence", []):
if event.get("_is_injected"):
focus.add(event.get("device_id", ""))
focus.discard("")
return focus
def _focus_rooms(episode: Dict) -> List[str]:
meta = episode.get("metadata", {})
gt = episode.get("ground_truth", {})
rooms = []
for room in [meta.get("focus_room"), gt.get("target_room")]:
if room and room not in rooms:
rooms.append(room)
return rooms
def _spread_sample(items: Sequence[int], count: int) -> List[int]:
    """Pick ``count`` entries spread evenly over ``items``, both ends included."""
    total = len(items)
    if count <= 0:
        return []
    if count >= total:
        return list(items)
    if count == 1:
        return [items[0]]
    last = total - 1
    # total > count makes the step larger than 1, so the picked positions
    # are strictly increasing and therefore distinct.
    return [items[step * last // (count - 1)] for step in range(count)]


def _select_events(episode: Dict) -> List[Dict]:
    """Choose the subset of an episode's events that goes into the prompt.

    Selection happens in four stages:

    1. Time window. With injected events present, every event within the
       query family's window radius of an injected event is taken, plus
       events on focus devices / in focus rooms within twice that radius.
       Without injected events, the trailing window (twice the radius,
       measured back from the last event) is taken instead.
    2. Sparse history. For every focus device roughly every sixth-of-total
       event, plus its final event, is added to preserve trends.
    3. Budget. If more events were selected than the family's limit allows,
       all injected events are kept and the remaining slots are filled with
       an even sample across the whole non-injected range.
    4. Compression via ``_compress_events``.

    Args:
        episode: Episode dict; ``event_sequence``, ``metadata``,
            ``home_state`` and ``ground_truth`` are read.

    Returns:
        The selected events in their original order, after compression.

    Raises:
        KeyError: If an event has no ``timestamp`` key.
        ValueError: If a timestamp cannot be parsed.
    """
    events = episode.get("event_sequence", [])
    query_family = episode.get("metadata", {}).get("query_family", "sq3")
    focus_devices = _focus_devices(episode)
    home_state = episode.get("home_state", {})
    focus_rooms = set(_focus_rooms(episode))
    window_minutes = WINDOW_MINUTES.get(query_family, 60)

    # Parse each timestamp once instead of once per injected event.
    event_times = [_parse_ts(event["timestamp"]) for event in events]

    injected_indices = [idx for idx, event in enumerate(events) if event.get("_is_injected")]
    selected_indices: Set[int] = set(injected_indices)

    if injected_indices:
        radius = timedelta(minutes=window_minutes)
        wide_radius = radius * 2
        # Whether an event belongs to a focus device/room does not depend on
        # the injected event it is compared against, so compute it up front.
        in_focus = [
            event.get("device_id") in focus_devices
            or _event_room(event, home_state) in focus_rooms
            for event in events
        ]
        for center_idx in injected_indices:
            center_dt = event_times[center_idx]
            for idx, event_dt in enumerate(event_times):
                distance = abs(event_dt - center_dt)
                if distance <= radius or (in_focus[idx] and distance <= wide_radius):
                    selected_indices.add(idx)
    elif events:
        last_dt = event_times[-1]
        window = timedelta(minutes=window_minutes * 2)
        selected_indices.update(
            idx for idx, event_dt in enumerate(event_times) if last_dt - event_dt <= window
        )

    # Bring in a sparse history for focus devices to preserve trends.
    if focus_devices:
        device_buckets: Dict[str, List[int]] = defaultdict(list)
        for idx, event in enumerate(events):
            device_id = event.get("device_id")
            if device_id in focus_devices:
                device_buckets[device_id].append(idx)
        for indices in device_buckets.values():
            stride = max(1, len(indices) // 6)
            selected_indices.update(indices[::stride])
            selected_indices.add(indices[-1])

    ordered = sorted(selected_indices)
    limit = PROMPT_EVENT_LIMITS.get(query_family, 120)
    if len(ordered) > limit:
        injected_set = set(injected_indices)
        kept = [idx for idx in ordered if idx in injected_set]
        remaining = limit - len(kept)
        if remaining <= 0:
            # Injected events alone exhaust the budget; keep the earliest ones.
            ordered = kept[:limit]
        else:
            background = [idx for idx in ordered if idx not in injected_set]
            # Sample across the full range: a floored stride followed by a
            # prefix cut would silently drop the most recent events.
            ordered = sorted(kept + _spread_sample(background, remaining))

    return _compress_events([events[idx] for idx in ordered], episode)
def _compress_events(events: List[Dict], episode: Dict) -> List[Dict]:
    """Thin out repetitive attribute-change telemetry from ``events``.

    Injected events and all non-``attribute_change`` events pass through
    untouched. For attribute changes:

    * in the ``sq2`` / ``sq5`` families, ``TemperatureMeasurement`` readings
      from outside the focus devices and focus rooms are dropped;
    * a reading repeating the same (device, cluster, attribute, value, room)
      is dropped when it arrives sooner than the cluster's minimum gap after
      the last *kept* reading with that key (20 min for ``OccupancySensing``,
      30 min for ``TemperatureMeasurement``, 15 min otherwise).

    Returns:
        The surviving events in their original order.
    """
    family = episode.get("metadata", {}).get("query_family", "sq3")
    focus_devices = _focus_devices(episode)
    home_state = episode.get("home_state", {})
    focus_rooms = set(_focus_rooms(episode))
    drops_unrelated_temperature = family in {"sq2", "sq5"}
    gap_minutes_by_cluster = {"OccupancySensing": 20, "TemperatureMeasurement": 30}

    kept: List[Dict] = []
    last_kept_at: Dict[Tuple[str, str, str, str, str], datetime] = {}

    for event in events:
        if event.get("_is_injected") or event.get("event_type") != "attribute_change":
            kept.append(event)
            continue

        cluster = event.get("cluster", "")
        attribute = event.get("attribute", "")
        value_text = str(event.get("value"))
        room = _event_room(event, home_state)
        in_focus = event.get("device_id") in focus_devices or room in focus_rooms

        # For local decision tasks, drop unrelated temperature telemetry.
        if drops_unrelated_temperature and cluster == "TemperatureMeasurement" and not in_focus:
            continue

        signature = (event.get("device_id", ""), cluster, attribute, value_text, room)
        seen_at = _parse_ts(event["timestamp"])
        min_gap = timedelta(minutes=gap_minutes_by_cluster.get(cluster, 15))

        earlier = last_kept_at.get(signature)
        if earlier is not None and seen_at - earlier < min_gap:
            # Too close to the last kept duplicate; the reference time is
            # deliberately left unchanged.
            continue

        last_kept_at[signature] = seen_at
        kept.append(event)

    return kept
def _daily_activity_summary(episode: Dict) -> str:
query_family = episode.get("metadata", {}).get("query_family", "sq3")
if query_family not in {"sq3", "sq4", "sq5"}:
return ""
home_state = episode.get("home_state", {})
room_day_counter: Dict[Tuple[str, str], Counter] = defaultdict(Counter)
for event in episode.get("event_sequence", []):
room = _event_room(event, home_state)
day = event.get("timestamp", "")[:10]
if event.get("event_type") == "attribute_change" and event.get("cluster") == "TemperatureMeasurement":
continue
room_day_counter[day][room] += 1
focus_rooms = _focus_rooms(episode)
lines = ["## 跨天活动摘要"]
for day in sorted(room_day_counter):
room_counts = room_day_counter[day]
ordered_rooms = focus_rooms + [room for room, _ in room_counts.most_common() if room not in focus_rooms]
parts = []
seen = set()
for room in ordered_rooms:
if room in seen:
continue
seen.add(room)
parts.append(f"{room}={room_counts.get(room, 0)}")
if len(parts) >= 6:
break
lines.append(f"- {day}: " + ", ".join(parts))
return "\n".join(lines)
def _build_focus_note(episode: Dict, selected_events: Sequence[Dict]) -> str:
    """Build the prompt section explaining how the log excerpt was chosen.

    The note states the task type (upper-cased query family), the focus
    rooms and up to eight focus devices (sorted) when there are any, and the
    number of selected events.
    """
    family = episode.get("metadata", {}).get("query_family", "sq3")
    device_ids = sorted(_focus_devices(episode))
    rooms = _focus_rooms(episode)

    note = [
        "## 已选日志说明",
        "以下日志不是整天原始全量心跳,而是围绕关键设备、关键时间窗和跨天模式筛出的训练片段。",
        f"- 任务类型: {family.upper()}",
    ]
    if rooms:
        note.append(f"- 重点房间: {', '.join(rooms)}")
    if device_ids:
        note.append(f"- 重点设备: {', '.join(device_ids[:8])}")
    note.append(f"- 已选事件数: {len(selected_events)}")
    return "\n".join(note)
def build_dpo_prompt(episode: Dict) -> Dict[str, str]:
    """Assemble the system and user prompt for one episode.

    The user prompt consists of, in order: the home summary, the note on
    how events were selected, the cross-day activity summary (only when it
    is non-empty), the formatted event log, the episode's query and the
    baseline response-format instructions.

    Returns:
        A dict with the keys ``"system"`` and ``"user"``.
    """
    chosen_events = _select_events(episode)
    rooms = _focus_rooms(episode)

    sections = [
        "## 家庭环境信息",
        _compact_home_summary(episode.get("home_state", {}), rooms),
        "",
        _build_focus_note(episode, chosen_events),
        "",
    ]

    cross_day = _daily_activity_summary(episode)
    if cross_day:
        sections += [cross_day, ""]

    sections.append("## 设备事件日志")
    sections += [_format_event(event) for event in chosen_events]
    sections += ["", "## 任务", episode.get("query", ""), "", RESPONSE_FORMAT_BASELINE]

    return {"system": SYSTEM_PROMPT_BASELINE, "user": "\n".join(sections)}