259 lines
9.2 KiB
Python
259 lines
9.2 KiB
Python
#!/usr/bin/env python3
"""
prepare_annotation.py — prepare a dataset for human annotation.

Generates (all under the ``annotation/`` directory next to this script):
1. annotation_tasks/ — one txt file per episode, containing only the
   information annotators are meant to see
2. annotation_sheet.csv — the annotation sheet in which annotators record
   their judgements
3. annotation_key.json — internal use only; contains the ground truth
   (must not be shown to annotators)
4. 标注说明.txt — the annotation guide handed to annotators

Usage: python prepare_annotation.py --sample-size 100
"""
|
||
|
||
import json
|
||
import os
|
||
import sys
|
||
import csv
|
||
import random
|
||
from pathlib import Path
|
||
|
||
# Directory that contains this script; every other path is derived from it.
PROJECT_ROOT = Path(__file__).resolve().parent
# Prepend it to sys.path — presumably so the `src` package imported below
# resolves when this file is executed directly as a script.
sys.path.insert(0, str(PROJECT_ROOT))

from src.evaluation.prompt_builder import summarize_home_layout, format_event_log

# Episodes are read from data/benchmark/<sq>/*.json (see stratified_sample).
BENCHMARK_DIR = PROJECT_ROOT / "data" / "benchmark"
# All generated annotation artifacts are written below this directory.
OUTPUT_DIR = PROJECT_ROOT / "annotation"
|
||
|
||
|
||
def stratified_sample(sample_size: int, seed: int = 42, benchmark_dir=None) -> list:
    """Draw a reproducible stratified sample of benchmark episode files.

    Episode files are read from ``<benchmark_dir>/sq1`` … ``sq5`` and grouped
    into strata keyed by ``"<SQ>_<VARIANT>"``, where the variant is the first
    of ``TP`` / ``FP`` / ``TN`` that occurs as ``_<variant>_`` in the file
    stem. Files matching no variant are ignored.

    Each stratum contributes ``max(2, sample_size // n_strata)`` files (or all
    of its files if it has fewer). If that is still short of ``sample_size``,
    the sample is topped up at random from the files not yet chosen. The
    result is shuffled and truncated to ``sample_size``.

    Args:
        sample_size: Maximum number of episode paths to return.
        seed: Seed for the private random generator; the same seed and the
            same set of files always yield the same sample.
        benchmark_dir: Directory holding the ``sq*`` sub-directories.
            Defaults to the module-level ``BENCHMARK_DIR``.

    Returns:
        A list of file paths (as strings), at most ``sample_size`` long.
        Empty when ``sample_size`` is not positive or no episode files exist.
    """
    root = BENCHMARK_DIR if benchmark_dir is None else Path(benchmark_dir)
    # A private generator keeps the draw reproducible without reseeding the
    # process-wide ``random`` state as a side effect.
    rng = random.Random(seed)

    strata = {}
    for sq in ("sq1", "sq2", "sq3", "sq4", "sq5"):
        sq_dir = root / sq
        if not sq_dir.is_dir():
            continue
        # glob() order depends on the filesystem; sorting makes the sample
        # identical across machines for a given seed.
        for path in sorted(sq_dir.glob("*.json")):
            variant = next(
                (v for v in ("TP", "FP", "TN") if f"_{v}_" in path.stem), None
            )
            if variant is not None:
                strata.setdefault(f"{sq.upper()}_{variant}", []).append(str(path))

    # Without this guard an empty benchmark directory divides by zero below,
    # and a negative size would slice off the tail of the sample instead.
    if not strata or sample_size <= 0:
        return []

    per_stratum = max(2, sample_size // len(strata))
    sampled = []
    for key in sorted(strata):
        files = strata[key]
        sampled.extend(rng.sample(files, min(per_stratum, len(files))))

    # Top up from the files not yet chosen; set membership keeps this linear.
    chosen = set(sampled)
    leftovers = [p for files in strata.values() for p in files if p not in chosen]
    rng.shuffle(leftovers)
    while len(sampled) < sample_size and leftovers:
        sampled.append(leftovers.pop())

    rng.shuffle(sampled)
    return sampled[:sample_size]
|
||
|
||
|
||
def _strip_internal_fields(evt: dict) -> dict:
    """Return a copy of *evt* holding only the fields annotators may see."""
    event_type = evt.get("event_type")
    clean = {
        "timestamp": evt.get("timestamp"),
        "device_id": evt.get("device_id"),
        "event_type": event_type,
    }
    if event_type == "attribute_change":
        clean["cluster"] = evt.get("cluster", "")
        clean["attribute"] = evt.get("attribute", "")
        clean["value"] = evt.get("value")
    elif event_type == "device_event":
        clean["event_name"] = evt.get("event_name", "")
        clean["fields"] = evt.get("fields", {})
    elif event_type == "command":
        clean["command"] = evt.get("command", "")
    return clean


def _apply_query_time_window(events: list, query: str) -> list:
    """Keep only events inside the "past N hours" window named in *query*.

    The window ends at the timestamp of the last event in *events*. The list
    is returned unchanged when the query names no window, when there are no
    events, or when the last timestamp cannot be parsed. Events whose own
    timestamp is missing or unparseable are kept rather than silently dropped.
    """
    import re
    from datetime import datetime, timedelta

    window = re.search(r"过去\s*(\d+)\s*小时", query)
    if not window or not events:
        return events

    def parse(ts):
        # None / non-string -> TypeError, malformed string -> ValueError.
        try:
            return datetime.fromisoformat(ts)
        except (TypeError, ValueError):
            return None

    last_dt = parse(events[-1].get("timestamp"))
    if last_dt is None:
        return events
    try:
        cutoff = last_dt - timedelta(hours=int(window.group(1)))
    except OverflowError:
        # Absurdly large hour count: treat as "no window".
        return events

    kept = []
    for evt in events:
        evt_dt = parse(evt.get("timestamp"))
        try:
            # Compare parsed datetimes, not strings: lexical order is only
            # chronological when every timestamp has the exact same format.
            inside = evt_dt is None or evt_dt >= cutoff
        except TypeError:
            # Naive vs. timezone-aware timestamps cannot be ordered.
            inside = True
        if inside:
            kept.append(evt)
    return kept


def build_annotation_text(episode: dict, task_id: int) -> str:
    """Build the text shown to an annotator for one episode (no answers).

    Only the home state, a reduced view of each event and the query are used;
    the episode's ground truth and metadata are never read here.

    Args:
        episode: Episode dict; the keys ``home_state``, ``event_sequence``
            and ``query`` are read, each being optional.
        task_id: Task number, rendered zero-padded to three digits.

    Returns:
        The complete task text, newline-joined and ending with a newline.
    """
    home_state = episode.get("home_state", {})
    events = episode.get("event_sequence") or []
    query = episode.get("query") or ""

    # Strip internal fields, then apply the SQ3-style time window, if any.
    clean_events = [_strip_internal_fields(evt) for evt in events]
    clean_events = _apply_query_time_window(clean_events, query)

    rule = "=" * 70
    lines = [
        rule,
        f" 标注任务 #{task_id:03d}",
        rule,
        "",
        "## 家庭环境",
        "",
        summarize_home_layout(home_state),
        "",
        f"## 设备事件日志(共 {len(clean_events)} 条)",
        "",
        format_event_log(clean_events),
        "",
        "## 任务",
        "",
        query,
        "",
        "## Matter 协议说明",
        "- 温度单位: 0.01°C(如 2350 = 23.50°C)",
        "- Occupancy: 1=有人, 0=无人",
        "- BooleanState(门窗): True=关闭, False=打开",
        "- DoorLock.LockState: 1=已锁, 2=已解锁",
        "",
    ]

    return "\n".join(lines)
|
||
|
||
|
||
def _format_effort_estimate(task_count: int) -> str:
    """Return the total-effort phrase for the guide at 3-5 minutes per task."""
    low_minutes, high_minutes = 3 * task_count, 5 * task_count
    # Below three hours, whole-hour rounding can collapse the range to a
    # single number, so small batches are quoted in minutes instead.
    if high_minutes < 180:
        return f"{low_minutes}-{high_minutes} 分钟"
    return f"{round(low_minutes / 60)}-{round(high_minutes / 60)} 小时"


def _build_guide_text(task_count: int) -> str:
    """Return the annotator guide; only the effort estimate depends on *task_count*."""
    static_part = """SafeHome Benchmark 人工标注说明
================================

一、你的任务
阅读 annotation_tasks/ 目录下的每个 txt 文件,根据智能家居设备日志判断是否存在安全异常。

二、标注方式
打开 annotation_sheet.csv,逐行填写以下字段:

1. is_anomaly: 填 yes 或 no
- yes = 日志中存在安全异常/威胁
- no = 一切正常,无异常

2. threat_type: 如果 is_anomaly=yes,填写具体类型,从以下选择:
intrusion / tailgating / credential_theft / fire_risk / unattended_cooking /
carbon_monoxide / sensor_stuck / sensor_drift / sensor_malfunction /
actuator_stuck / lock_malfunction / safety_device_failure / water_leak /
possible_fall / abnormal_inactivity / health_concern / child_safety /
behavioral_anomaly / none
如果 is_anomaly=no,填 none

3. confidence: 填 high / medium / low
- high = 非常确定
- medium = 比较确定但有一点犹豫
- low = 不太确定,猜的

4. brief_reason: 一句话解释你的判断理由

三、注意事项
- 每个 task 独立判断,不要参考其他 task 的结果
- Matter 协议温度单位是 0.01°C,比如 2350 = 23.50°C
- 正常做饭时厨房温度上升 3-5°C 是正常的
- 一天内温度波动 ±2°C 是正常的日间变化
- Occupancy 传感器有人时每5分钟心跳一次、无人时每30分钟心跳一次,这是正常的
- 不要过度解读正常的日常活动

四、参考时间
"""
    effort = _format_effort_estimate(task_count)
    return (
        static_part
        + f"预计每个 task 需要 3-5 分钟,{task_count} 个 task 约 {effort}。\n"
        + "建议分 2-3 次完成,避免疲劳影响判断质量。\n"
    )


def main():
    """Command-line entry point: sample episodes and write the annotation kit.

    Writes, below ``OUTPUT_DIR``: one ``annotation_tasks/task_NNN.txt`` per
    sampled episode, ``annotation_sheet.csv`` (blank sheet for annotators),
    ``annotation_key.json`` (ground truth, internal only) and ``标注说明.txt``
    (the annotator guide). Task files left over from an earlier run are
    removed first so the task directory always matches the sheet and the key.

    Exits with an error when ``--sample-size`` is not positive or when no
    episodes could be sampled.
    """
    import argparse

    parser = argparse.ArgumentParser(description="准备人工标注数据集")
    parser.add_argument("--sample-size", type=int, default=100)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    if args.sample_size < 1:
        parser.error("--sample-size must be a positive integer")

    output_dir = OUTPUT_DIR
    tasks_dir = output_dir / "annotation_tasks"
    tasks_dir.mkdir(parents=True, exist_ok=True)

    print(f"Sampling {args.sample_size} episodes...")
    sampled = stratified_sample(args.sample_size, args.seed)
    print(f" Sampled {len(sampled)} episodes")
    if not sampled:
        parser.exit(1, f"No benchmark episodes found under {BENCHMARK_DIR}\n")

    # The whole task directory is handed to annotators, so task files from a
    # previous, larger run must not survive without a sheet row or key entry.
    for stale in tasks_dir.glob("task_*.txt"):
        stale.unlink()

    csv_rows = []
    key_data = []

    for task_id, ep_path in enumerate(sampled, start=1):
        with open(ep_path, "r", encoding="utf-8") as f:
            episode = json.load(f)

        episode_id = episode.get("episode_id", "")
        meta = episode.get("metadata", {})
        gt = episode.get("ground_truth", {})

        # Annotator-facing task text.
        task_file = tasks_dir / f"task_{task_id:03d}.txt"
        task_file.write_text(build_annotation_text(episode, task_id), encoding="utf-8")

        # Sheet row; every field except task_id is filled in by the annotator.
        csv_rows.append({
            "task_id": task_id,
            "is_anomaly": "",      # yes / no
            "threat_type": "",     # a specific type, or none
            "confidence": "",      # high / medium / low
            "brief_reason": "",    # one-sentence rationale
        })

        # Answer key (never distributed to annotators).
        key_data.append({
            "task_id": task_id,
            "episode_id": episode_id,
            "episode_path": ep_path,
            "sq_type": meta.get("sq_type", ""),
            "variant": meta.get("variant", ""),
            "layout_id": meta.get("layout_id", ""),
            "gt_is_anomaly": gt.get("is_anomaly"),
            "gt_threat_type": gt.get("threat_type", "none"),
            "gt_category": gt.get("category", ""),
            "gt_scenario_id": gt.get("scenario_id", ""),
            "gt_difficulty_label": gt.get("difficulty_label", ""),
            "gt_difficulty_score": gt.get("difficulty_score", 0),
        })

    # utf-8-sig writes a BOM; newline="" is what the csv module requires.
    csv_path = output_dir / "annotation_sheet.csv"
    with open(csv_path, "w", encoding="utf-8-sig", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=["task_id", "is_anomaly", "threat_type", "confidence", "brief_reason"],
        )
        writer.writeheader()
        writer.writerows(csv_rows)

    key_path = output_dir / "annotation_key.json"
    with open(key_path, "w", encoding="utf-8") as f:
        json.dump(key_data, f, ensure_ascii=False, indent=2)

    guide_path = output_dir / "标注说明.txt"
    guide_path.write_text(_build_guide_text(len(sampled)), encoding="utf-8")

    print("\n生成完成:")
    print(f" 标注任务: {tasks_dir} ({len(sampled)} 个 txt 文件)")
    print(f" 标注表格: {csv_path}")
    print(f" 标注说明: {guide_path}")
    print(f" 答案 (不给标注员): {key_path}")
    print("\n分发给标注员的文件:")
    print(" 1. annotation_tasks/ 目录(全部 txt 文件)")
    print(" 2. annotation_sheet.csv")
    print(" 3. 标注说明.txt")
    print("\n不要给标注员的文件:")
    print(" - annotation_key.json(包含答案)")
|
||
|
||
|
||
# Run the generator only when this file is executed as a script, so that
# importing the module does not trigger sampling or file writes.
if __name__ == "__main__":
    main()
|