Files
llmiotsafe/prepare_annotation.py
2026-05-12 17:01:39 +08:00

259 lines
9.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
prepare_annotation.py — prepare the human-annotation dataset.

Generates:
    1. annotation_tasks/     — one txt file per sampled episode, containing only
                               the information annotators are allowed to see
    2. annotation_sheet.csv  — the sheet annotators fill in with their judgements
    3. annotation_key.json   — internal use only: contains the ground truth and
                               must NOT be shown to annotators

Usage: python prepare_annotation.py --sample-size 100
"""
import json
import os
import sys
import csv
import random
from pathlib import Path

# Directory containing this script; treated as the project root.
PROJECT_ROOT = Path(__file__).resolve().parent
# Put the project root first on sys.path so the `src` package import below
# resolves when the script is run directly. This line must stay before that
# import.
sys.path.insert(0, str(PROJECT_ROOT))
from src.evaluation.prompt_builder import summarize_home_layout, format_event_log

# Root of the benchmark episodes; stratified_sample reads <this>/sqN/*.json.
BENCHMARK_DIR = PROJECT_ROOT / "data" / "benchmark"
# Directory that receives every generated annotation artefact.
OUTPUT_DIR = PROJECT_ROOT / "annotation"
def stratified_sample(
    sample_size: int, seed: int = 42, benchmark_dir: "str | Path | None" = None
) -> list:
    """Draw a reproducible, stratified sample of benchmark episode files.

    Episodes are ``*.json`` files in ``<benchmark_dir>/sq1`` … ``sq5`` whose
    file stem contains ``_TP_``, ``_FP_`` or ``_TN_``. Each (sub-question,
    variant) pair forms one stratum, e.g. ``"SQ1_TP"``. Files whose stem
    carries none of these tags are ignored.

    Every stratum first contributes up to ``max(2, sample_size // n_strata)``
    files. Any shortfall is then topped up with randomly chosen leftovers.
    The combined list is shuffled and truncated to ``sample_size``. When the
    per-stratum quota overshoots, some strata may therefore end up
    under-represented.

    Args:
        sample_size: Maximum number of episodes to return.
        seed: Sampler seed. The same files and seed give the same sample on
            any machine, because directory listings are sorted first.
        benchmark_dir: Root directory holding the ``sqN`` folders. Defaults
            to the module-level ``BENCHMARK_DIR``.

    Returns:
        Episode file paths (as ``str``), at most ``sample_size`` of them.
        The list is empty if ``sample_size <= 0`` or no tagged episode
        files exist.
    """
    if sample_size <= 0:
        return []
    root = BENCHMARK_DIR if benchmark_dir is None else Path(benchmark_dir)
    # A private RNG keeps the sample reproducible without reseeding the global
    # `random` module for the rest of the program. It produces the same draws
    # as `random.seed(seed)` followed by module-level calls.
    rng = random.Random(seed)

    strata = {}
    for sq in ["sq1", "sq2", "sq3", "sq4", "sq5"]:
        sq_dir = root / sq
        if not sq_dir.exists():
            continue
        # glob() yields entries in filesystem order. Sort them so the seed
        # alone determines which files are drawn.
        for f in sorted(sq_dir.glob("*.json")):
            for v in ["TP", "FP", "TN"]:
                if f"_{v}_" in f.stem:
                    strata.setdefault(f"{sq.upper()}_{v}", []).append(str(f))
                    break

    if not strata:
        # Nothing to sample; this also avoids dividing by zero below.
        return []

    per_strata = max(2, sample_size // len(strata))
    sampled = []
    for key in sorted(strata):
        files = strata[key]
        sampled.extend(rng.sample(files, min(per_strata, len(files))))

    # Top up from files not chosen yet. The set keeps this linear instead of
    # doing a list scan per file.
    chosen = set(sampled)
    remaining = [f for files in strata.values() for f in files if f not in chosen]
    rng.shuffle(remaining)
    while len(sampled) < sample_size and remaining:
        sampled.append(remaining.pop())

    rng.shuffle(sampled)
    return sampled[:sample_size]
def _clean_event(evt: dict) -> dict:
    """Return a new dict holding only the fields of *evt* annotators may see."""
    event_type = evt.get("event_type")
    clean = {
        "timestamp": evt.get("timestamp"),
        "device_id": evt.get("device_id"),
        "event_type": event_type,
    }
    if event_type == "attribute_change":
        clean["cluster"] = evt.get("cluster", "")
        clean["attribute"] = evt.get("attribute", "")
        clean["value"] = evt.get("value")
    elif event_type == "device_event":
        clean["event_name"] = evt.get("event_name", "")
        clean["fields"] = evt.get("fields", {})
    elif event_type == "command":
        clean["command"] = evt.get("command", "")
    return clean


def _parse_timestamp(value):
    """Parse an ISO-8601 timestamp string; return None if it is not one."""
    from datetime import datetime

    if not isinstance(value, str) or not value:
        return None
    # datetime.fromisoformat() accepts a trailing "Z" only from Python 3.11 on.
    if value.endswith("Z"):
        value = value[:-1] + "+00:00"
    try:
        return datetime.fromisoformat(value)
    except ValueError:
        return None


def _apply_time_window(events: list, query: str) -> list:
    """Keep only the events inside the "past N hours" window the query names.

    If the query names no window, the events are returned unchanged. The same
    happens when the newest timestamp cannot be parsed. Events whose own
    timestamp is missing or unparseable cannot be placed in the window and
    are dropped.
    """
    import re
    from datetime import timedelta

    # The query phrases the window in Chinese: 过去 N 小时 = "past N hours".
    hours_match = re.search(r"过去\s*(\d+)\s*小时", query)
    if not hours_match or not events:
        return events
    # The window is anchored on the last event. This assumes the log is in
    # chronological order.
    last_dt = _parse_timestamp(events[-1].get("timestamp"))
    if last_dt is None:
        return events
    cutoff = last_dt - timedelta(hours=int(hours_match.group(1)))

    kept = []
    for evt in events:
        ts = _parse_timestamp(evt.get("timestamp"))
        if ts is None:
            continue
        try:
            in_window = ts >= cutoff
        except TypeError:
            # Timezone-aware and naive timestamps mixed; cannot be compared.
            in_window = False
        if in_window:
            kept.append(evt)
    return kept


def build_annotation_text(episode: dict, task_id: int) -> str:
    """Render the text an annotator sees for one episode, without any answers.

    Reads ``home_state``, ``event_sequence`` and ``query`` from *episode*.
    The ground truth is never read. Each event is reduced to its
    annotator-visible fields. If the query asks about the "past N hours"
    (SQ3-style), the log is cut to that window using parsed datetimes rather
    than string comparison.

    Args:
        episode: Parsed episode JSON.
        task_id: Task number shown in the header, zero-padded to 3 digits.

    Returns:
        The full task text, with lines joined by ``"\\n"``.
    """
    home_state = episode.get("home_state", {})
    query = episode.get("query", "")
    clean_events = [_clean_event(evt) for evt in episode.get("event_sequence", [])]
    clean_events = _apply_time_window(clean_events, query)

    lines = [
        f"{'='*70}",
        f" 标注任务 #{task_id:03d}",
        f"{'='*70}",
        "",
        "## 家庭环境",
        "",
        summarize_home_layout(home_state),
        "",
        f"## 设备事件日志(共 {len(clean_events)} 条)",
        "",
        format_event_log(clean_events),
        "",
        "## 任务",
        "",
        query,
        "",
        # Fixed reference notes on Matter value encodings for annotators.
        "## Matter 协议说明",
        "- 温度单位: 0.01°C如 2350 = 23.50°C",
        "- Occupancy: 1=有人, 0=无人",
        "- BooleanState门窗: True=关闭, False=打开",
        "- DoorLock.LockState: 1=已锁, 2=已解锁",
        "",
    ]
    return "\n".join(lines)
# Columns of annotation_sheet.csv, in output order.
_SHEET_FIELDS = ["task_id", "is_anomaly", "threat_type", "confidence", "brief_reason"]


def _key_record(task_id: int, ep_path: str, episode: dict) -> dict:
    """Build one annotation_key.json entry (ground truth; never shown to annotators)."""
    meta = episode.get("metadata", {})
    gt = episode.get("ground_truth", {})
    return {
        "task_id": task_id,
        "episode_id": episode.get("episode_id", ""),
        "episode_path": ep_path,
        "sq_type": meta.get("sq_type", ""),
        "variant": meta.get("variant", ""),
        "layout_id": meta.get("layout_id", ""),
        "gt_is_anomaly": gt.get("is_anomaly"),
        "gt_threat_type": gt.get("threat_type", "none"),
        "gt_category": gt.get("category", ""),
        "gt_scenario_id": gt.get("scenario_id", ""),
        "gt_difficulty_label": gt.get("difficulty_label", ""),
        "gt_difficulty_score": gt.get("difficulty_score", 0),
    }


def _write_guide(guide_path) -> None:
    """Write the fixed annotator instructions to *guide_path* (UTF-8)."""
    with open(guide_path, "w", encoding="utf-8") as f:
        f.write("""SafeHome Benchmark 人工标注说明
================================
一、你的任务
阅读 annotation_tasks/ 目录下的每个 txt 文件,根据智能家居设备日志判断是否存在安全异常。
二、标注方式
打开 annotation_sheet.csv逐行填写以下字段
1. is_anomaly: 填 yes 或 no
- yes = 日志中存在安全异常/威胁
- no = 一切正常,无异常
2. threat_type: 如果 is_anomaly=yes填写具体类型从以下选择
intrusion / tailgating / credential_theft / fire_risk / unattended_cooking /
carbon_monoxide / sensor_stuck / sensor_drift / sensor_malfunction /
actuator_stuck / lock_malfunction / safety_device_failure / water_leak /
possible_fall / abnormal_inactivity / health_concern / child_safety /
behavioral_anomaly / none
如果 is_anomaly=no填 none
3. confidence: 填 high / medium / low
- high = 非常确定
- medium = 比较确定但有一点犹豫
- low = 不太确定,猜的
4. brief_reason: 一句话解释你的判断理由
三、注意事项
- 每个 task 独立判断,不要参考其他 task 的结果
- Matter 协议温度单位是 0.01°C比如 2350 = 23.50°C
- 正常做饭时厨房温度上升 3-5°C 是正常的
- 一天内温度波动 ±2°C 是正常的日间变化
- Occupancy 传感器有人时每5分钟心跳一次、无人时每30分钟心跳一次这是正常的
- 不要过度解读正常的日常活动
四、参考时间
预计每个 task 需要 3-5 分钟100 个 task 约 5-8 小时。
建议分 2-3 次完成,避免疲劳影响判断质量。
""")


def main():
    """Command-line entry point: sample episodes and write the annotation package.

    Writes into OUTPUT_DIR:
      - annotation_tasks/task_NNN.txt: one annotator-facing file per episode
      - annotation_sheet.csv: one blank row per task for annotators to fill in
      - annotation_key.json: the ground truth per task (keep from annotators)
      - 标注说明.txt: the annotator instructions

    Task files from a previous run are removed first, so the task directory
    always matches the freshly written key.
    """
    import argparse

    parser = argparse.ArgumentParser(description="准备人工标注数据集")
    parser.add_argument("--sample-size", type=int, default=100)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    if args.sample_size < 1:
        parser.error("--sample-size must be a positive integer")

    output_dir = OUTPUT_DIR
    tasks_dir = output_dir / "annotation_tasks"
    tasks_dir.mkdir(parents=True, exist_ok=True)

    print(f"Sampling {args.sample_size} episodes...")
    sampled = stratified_sample(args.sample_size, args.seed)
    print(f" Sampled {len(sampled)} episodes")

    # Leftover task files from an earlier, larger run would not match the new
    # key (e.g. task_080.txt after re-running with 50 samples), so drop them.
    for stale in tasks_dir.glob("task_*.txt"):
        stale.unlink()

    csv_rows = []
    key_data = []
    for task_id, ep_path in enumerate(sampled, start=1):
        with open(ep_path, "r", encoding="utf-8") as f:
            episode = json.load(f)

        # Annotator-facing task text (contains no ground truth).
        task_file = tasks_dir / f"task_{task_id:03d}.txt"
        with open(task_file, "w", encoding="utf-8") as f:
            f.write(build_annotation_text(episode, task_id))

        # Sheet row: every field except task_id is left for the annotator.
        csv_rows.append({
            "task_id": task_id,
            "is_anomaly": "",    # annotator fills: yes / no
            "threat_type": "",   # annotator fills: a specific type, or none
            "confidence": "",    # annotator fills: high / medium / low
            "brief_reason": "",  # annotator fills: a one-sentence reason
        })
        key_data.append(_key_record(task_id, ep_path, episode))

    # utf-8-sig adds a BOM so spreadsheet apps detect UTF-8 for the CSV.
    csv_path = output_dir / "annotation_sheet.csv"
    with open(csv_path, "w", encoding="utf-8-sig", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=_SHEET_FIELDS)
        writer.writeheader()
        writer.writerows(csv_rows)

    key_path = output_dir / "annotation_key.json"
    with open(key_path, "w", encoding="utf-8") as f:
        json.dump(key_data, f, ensure_ascii=False, indent=2)

    guide_path = output_dir / "标注说明.txt"
    _write_guide(guide_path)

    print(f"\n生成完成:")
    print(f" 标注任务: {tasks_dir} ({len(sampled)} 个 txt 文件)")
    print(f" 标注表格: {csv_path}")
    print(f" 标注说明: {guide_path}")
    print(f" 答案 (不给标注员): {key_path}")
    print(f"\n分发给标注员的文件:")
    print(f" 1. annotation_tasks/ 目录(全部 txt 文件)")
    print(f" 2. annotation_sheet.csv")
    print(f" 3. 标注说明.txt")
    print(f"\n不要给标注员的文件:")
    print(f" - annotation_key.json包含答案")


if __name__ == "__main__":
    main()