Files
llmiotsafe/eval_api.py
2026-05-12 17:01:39 +08:00

351 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
eval_api.py — 通用 API 评测脚本(纯 requests 实现,不依赖 openai/anthropic 库)。
支持任何 OpenAI 兼容接口:OpenAI、Claude(兼容模式)、OpenRouter、vLLM、Ollama 等。
使用方式:
# Claude API(兼容 OpenAI 格式)
python eval_api.py --model claude-sonnet-4-6 \
--api-base https://你的api地址/v1 \
--api-key sk-xxx \
--max-episodes 10
# OpenRouter
python eval_api.py --model anthropic/claude-sonnet-4-6 \
--api-base https://openrouter.ai/api/v1 \
--api-key sk-or-xxx \
--max-episodes 10
# 本地 vLLM
python eval_api.py --model Qwen/Qwen2.5-7B-Instruct \
--api-base http://localhost:8000/v1 \
--api-key not-needed \
--max-episodes 10
# 全量 1200
python eval_api.py --model Qwen/Qwen2.5-7B-Instruct \
--api-base http://localhost:8000/v1 \
--api-key not-needed
"""
import json
import os
import sys
import time
import argparse
import requests
import concurrent.futures
from pathlib import Path
from datetime import datetime
from collections import Counter
# Directory that contains this script; every other path below is derived from it.
PROJECT_ROOT = Path(__file__).resolve().parent
# Put that directory first on sys.path before importing the `src.*` modules
# below, so the import order here matters.
sys.path.insert(0, str(PROJECT_ROOT))
from src.evaluation.prompt_builder import build_prompt
from src.evaluation.scorer import parse_model_response, score_episode, score_batch
from src.evaluation.metrics import format_results_table, compute_error_taxonomy
# Episode JSON files are looked up in one sub-directory per SQ type under this folder.
BENCHMARK_DIR = PROJECT_ROOT / "data" / "benchmark"
# Default parent of the per-model output directory (used when --output-dir is not given).
RESULTS_DIR = PROJECT_ROOT / "results"
def call_chat_api(
    system: str,
    user: str,
    model: str,
    api_base: str,
    api_key: str,
    temperature: float = 0.0,
    max_tokens: int = 2048,
    timeout: int = 300,
    extra_body: dict = None,
) -> str:
    """Call an OpenAI-compatible ``chat/completions`` endpoint using ``requests``.

    Args:
        system: System prompt; the system message is omitted when this is empty.
        user: User prompt, always sent as the last message.
        model: Model name placed in the request body.
        api_base: Base URL of the API. ``/chat/completions`` is appended unless
            the URL already ends with it.
        api_key: Bearer token. No ``Authorization`` header is sent when this is
            empty or the literal ``"not-needed"``.
        temperature: Sampling temperature placed in the request body.
        max_tokens: Output token limit placed in the request body.
        timeout: Timeout in seconds handed to ``requests.post``.
        extra_body: Optional mapping merged into the request body last, so its
            keys override the defaults above (including ``max_tokens``).

    Returns:
        The ``content`` of the first choice's message. A ``null`` content is
        returned as an empty string so the declared ``str`` contract holds.

    Raises:
        requests.RequestException: On transport failures or a non-2xx status
            (via ``raise_for_status``).
        ValueError: When the response body does not contain
            ``choices[0].message.content``.
    """
    url = api_base.rstrip("/")
    if not url.endswith("/chat/completions"):
        url += "/chat/completions"

    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": user})

    headers = {"Content-Type": "application/json"}
    if api_key and api_key != "not-needed":
        headers["Authorization"] = f"Bearer {api_key}"

    payload = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    if extra_body:
        payload.update(extra_body)

    resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()

    # OpenAI format. Any structural mismatch (missing key, empty ``choices``,
    # non-dict body) is reported uniformly as ValueError with a body excerpt.
    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ValueError(
            f"Unexpected response format: {json.dumps(data)[:500]}"
        ) from exc

    # The content field can be JSON null; normalise that to "" for callers
    # that slice or parse the returned text.
    return "" if content is None else content
def evaluate_episode(ep_path, model, api_base, api_key, temperature, extra_body=None, prompt_mode="baseline"):
    """Evaluate a single episode file and return its result record.

    The episode JSON is loaded, turned into a prompt with ``build_prompt``,
    sent to the model through ``call_chat_api``, and the reply is parsed and
    scored. A failing API call does not raise: it is recorded in ``api_error``
    and the episode is scored as a parse failure.

    Args:
        ep_path: Path of the episode JSON file. It must contain the keys
            ``metadata`` and ``ground_truth``.
        model: Model name forwarded to ``call_chat_api``.
        api_base: API base URL forwarded to ``call_chat_api``.
        api_key: API key forwarded to ``call_chat_api``.
        temperature: Sampling temperature forwarded to ``call_chat_api``.
        extra_body: Optional extra request-body fields forwarded to
            ``call_chat_api``.
        prompt_mode: Mode passed to both ``build_prompt`` and
            ``parse_model_response``.

    Returns:
        A dict with the keys ``episode_id``, ``episode_path``, ``metadata``,
        ``ground_truth``, ``raw_response``, ``model_response``, ``scores``,
        ``latency`` (seconds) and ``api_error`` (``None`` on success).

    Raises:
        OSError, json.JSONDecodeError, KeyError: When the episode file cannot
            be read or lacks the required keys. Errors raised by the prompt
            builder, parser or scorer also propagate.
    """
    with open(ep_path, "r", encoding="utf-8") as f:
        episode = json.load(f)
    meta = episode["metadata"]
    gt = episode["ground_truth"]
    prompt = build_prompt(episode, mode=prompt_mode)

    raw = ""
    error = None
    # perf_counter is monotonic, so the latency cannot be distorted by
    # wall-clock adjustments during a long request.
    start = time.perf_counter()
    try:
        raw = call_chat_api(
            system=prompt.get("system", ""),
            user=prompt["user"],
            model=model,
            api_base=api_base,
            api_key=api_key,
            temperature=temperature,
            extra_body=extra_body,
        ) or ""
    except Exception as e:
        # Deliberate best-effort boundary: one failed request must not abort
        # the whole run. Some exceptions stringify to "", which would make
        # the failure invisible to the `api_error` checks, hence the repr().
        error = str(e) or repr(e)
    latency = time.perf_counter() - start

    # An empty reply cannot be parsed; mark it as a parse failure directly.
    if raw:
        parsed = parse_model_response(raw, mode=prompt_mode)
    else:
        parsed = {"is_anomaly": None, "_parse_failed": True}
    scores = score_episode(gt, parsed, meta.get("variant", ""))

    return {
        # Fall back to the file stem: the resume logic in main() identifies
        # finished episodes by comparing stored ids with file stems.
        "episode_id": episode.get("episode_id") or Path(ep_path).stem,
        "episode_path": str(ep_path),
        "metadata": meta,
        "ground_truth": gt,
        "raw_response": raw,
        "model_response": parsed,
        "scores": scores,
        "latency": latency,
        "api_error": error,
    }
def _build_extra_body(args, extra_json=None):
    """Translate CLI flags into extra request-body fields; None when there are none.

    Side effect: with ``--no_thinking`` and the default temperature of 0.0,
    ``args.temperature`` is raised to 0.7.
    """
    extra_body = {}
    if args.thinking:
        extra_body["chat_template_kwargs"] = {"thinking": True}
    if args.disable_thinking_only:
        extra_body["chat_template_kwargs"] = {"enable_thinking": False}
    if args.no_thinking:
        extra_body["chat_template_kwargs"] = {"enable_thinking": False}
        extra_body["top_p"] = 0.95
        extra_body["top_k"] = 20
        extra_body["presence_penalty"] = 1.5
        if args.temperature == 0.0:
            args.temperature = 0.7
    if extra_json:
        # User-supplied JSON is merged after the flag-derived fields and
        # therefore overrides them.
        extra_body.update(extra_json)
    if args.max_tokens != 2048:
        # call_chat_api merges extra_body last, so this overrides its default.
        extra_body["max_tokens"] = args.max_tokens
    return extra_body or None


def _collect_episode_paths(sq_types, variants, max_episodes=None):
    """Return sorted episode file paths (as strings) matching the SQ/variant filters."""
    paths = []
    for sq in sq_types:
        sq_dir = BENCHMARK_DIR / sq.lower()
        if not sq_dir.exists():
            continue
        for ep_file in sorted(sq_dir.glob("*.json")):
            # A file is selected when its stem contains "_<VARIANT>_".
            if any(f"_{v}_" in ep_file.stem for v in variants):
                paths.append(str(ep_file))
    # A value of 0 (or None) means "no limit".
    if max_episodes:
        paths = paths[:max_episodes]
    return paths


def _load_previous_results(results_file):
    """Read an existing results.jsonl and return ``(records, completed_keys)``.

    ``completed_keys`` holds both the stored ``episode_id`` and the stem of the
    stored ``episode_path`` of each record, because the caller matches against
    file stems. Unparseable lines (e.g. a line truncated by a killed run) are
    skipped with a warning instead of aborting the resume.
    """
    records = []
    completed = set()
    with open(results_file, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                record = None
            if not isinstance(record, dict):
                print(f"警告: {results_file} 第 {line_no} 行无法解析,已跳过")
                continue
            if record.get("episode_id"):
                completed.add(record["episode_id"])
            if record.get("episode_path"):
                completed.add(Path(record["episode_path"]).stem)
            records.append(record)
    return records, completed


def _missing_trailing_newline(path):
    """Return True when *path* is a non-empty file whose last byte is not a newline."""
    if not path.exists() or path.stat().st_size == 0:
        return False
    with open(path, "rb") as f:
        f.seek(-1, os.SEEK_END)
        return f.read(1) != b"\n"


def _evaluate_or_error_record(ep_path, *eval_args):
    """Run evaluate_episode; turn any exception into a failure record."""
    try:
        return evaluate_episode(ep_path, *eval_args)
    except Exception as e:
        # Best-effort boundary shared by serial and parallel mode: one broken
        # episode must not abort the run.
        return {
            "episode_id": Path(ep_path).stem,
            "episode_path": ep_path,
            "metadata": {}, "ground_truth": {},
            "raw_response": "", "model_response": {},
            "scores": {"detection_correct": False, "parse_success": False},
            "latency": 0, "api_error": str(e) or repr(e),
        }


def _record_result(result, out_f, done, total):
    """Append one result to the JSONL file (flushed) and print a progress line."""
    out_f.write(json.dumps(result, ensure_ascii=False) + "\n")
    # Flush after every record so an interrupted run can be resumed.
    out_f.flush()
    label = result["scores"].get("detection_label", "?")
    err_str = ""
    if result.get("api_error"):
        err_str = f" ERR: {str(result['api_error'])[:40]}"
    print(f" [{done}/{total}] {result.get('episode_id', '?')}: {label} ({result.get('latency', 0):.1f}s){err_str}")


def main():
    """Command-line entry point: evaluate benchmark episodes against a chat API.

    Steps: parse the CLI arguments, build the extra request-body fields,
    collect the episode files, optionally resume from an existing
    ``results.jsonl``, run a connectivity check, evaluate the remaining
    episodes (serially or in a thread pool), and write ``summary.json`` plus a
    printed results table.

    Results are appended to ``<output_dir>/results.jsonl`` one record per
    line; with ``--no-resume`` the file is truncated first so that records
    from earlier runs are not mixed with the new ones.
    """
    parser = argparse.ArgumentParser(
        description="SafeHome 通用 API 评测(纯 requests不依赖 openai 库)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--model", required=True, help="模型名")
    parser.add_argument("--api-base", required=True, help="API 地址(如 https://xxx/v1)")
    parser.add_argument("--api-key", required=True, help="API Key")
    parser.add_argument("--sq", nargs="*", default=None, help="评测哪些 SQ默认全部")
    parser.add_argument("--variant", nargs="*", default=None, help="评测哪些 variant默认全部")
    parser.add_argument("--max-episodes", type=int, default=None, help="最大评测数")
    parser.add_argument("--workers", type=int, default=1, help="并行数")
    parser.add_argument("--temperature", type=float, default=0.0, help="采样温度")
    parser.add_argument("--output-dir", default=None, help="输出目录")
    parser.add_argument("--no-resume", action="store_true", help="不续跑,重新开始")
    parser.add_argument("--mode", default="baseline", choices=["baseline", "edrc", "cot"],
                        help="Prompt 模式: baseline(直接判断), edrc(6步结构化推理), cot(思维链)")
    parser.add_argument("--thinking", action="store_true", help="启用深度思考模式(DeepSeek 等支持的模型)")
    # The first option string determines the attribute name (args.no_thinking);
    # the dashed spelling is an alias matching the other flags.
    parser.add_argument("--no_thinking", "--no-thinking", action="store_true",
                        help="关闭 thinking并使用 no-thinking 的推荐采样参数")
    # Registered with argparse (instead of being stripped from sys.argv by
    # hand) so that it shows up in --help.
    parser.add_argument("--disable-thinking-only", action="store_true",
                        help="仅关闭 thinking(enable_thinking=False),不改动采样参数")
    parser.add_argument("--extra-json", default=None, help="额外的 JSON 参数,会合并到请求体中(如 '{\"top_p\": 0.9}')")
    parser.add_argument("--max-tokens", type=int, default=2048, help="最大输出 token 数(思考模式建议调大到 8192)")
    args = parser.parse_args()

    # Validate --extra-json up front so a typo yields a usage error rather
    # than a traceback.
    extra_json = None
    if args.extra_json:
        try:
            extra_json = json.loads(args.extra_json)
        except json.JSONDecodeError as exc:
            parser.error(f"--extra-json 不是合法的 JSON: {exc}")
        if not isinstance(extra_json, dict):
            parser.error("--extra-json 必须是 JSON 对象")

    # Build extra_body
    extra_body = _build_extra_body(args, extra_json)

    # Collect episodes
    sq_types = [s.upper() for s in args.sq] if args.sq else ["SQ1", "SQ2", "SQ3", "SQ4", "SQ5"]
    variants = [v.upper() for v in args.variant] if args.variant else ["TP", "FP", "TN"]
    episode_paths = _collect_episode_paths(sq_types, variants, args.max_episodes)

    # Output directory
    model_safe = args.model.replace("/", "_").replace(":", "_")
    output_dir = Path(args.output_dir) if args.output_dir else RESULTS_DIR / model_safe
    output_dir.mkdir(parents=True, exist_ok=True)

    # Resume from a previous run.
    # NOTE(review): records that carry an api_error also count as completed,
    # so failed calls are not retried on resume — confirm this is intended.
    results = []
    completed = set()
    results_file = output_dir / "results.jsonl"
    if not args.no_resume and results_file.exists():
        results, completed = _load_previous_results(results_file)
        print(f"已完成 {len(results)} 个,继续剩余部分")
    remaining = [p for p in episode_paths if Path(p).stem not in completed]

    print(f"\n{'='*60}")
    print(f" SafeHome API 评测")
    print(f"{'='*60}")
    print(f" 模型: {args.model}")
    print(f" API: {args.api_base}")
    print(f" 待评: {len(remaining)} / {len(episode_paths)} episodes")
    print(f" 并行: {args.workers}")
    print(f" 输出: {output_dir}")
    print(f"{'='*60}\n")

    # Send one small request first to confirm connectivity
    if remaining:
        print("测试连通性...")
        try:
            test_resp = call_chat_api(
                system="You are a test assistant.",
                user="Say OK.",
                model=args.model,
                api_base=args.api_base,
                api_key=args.api_key,
                max_tokens=100,
                timeout=60,
                extra_body=extra_body,
            )
        except Exception as e:
            print(f"连通失败: {e}")
            print("请检查 --api-base 和 --api-key 是否正确")
            return
        # A reachable server may still return empty/null content; that must
        # not be reported as a connection failure.
        print(f"连通成功: {(test_resp or '')[:50]}\n")

    if not remaining:
        print("所有 episode 已评测完毕。\n")
    else:
        total = len(remaining)
        eval_args = (args.model, args.api_base, args.api_key, args.temperature, extra_body, args.mode)
        # --no-resume starts over: truncate instead of appending to stale records.
        file_mode = "w" if args.no_resume else "a"
        needs_newline = file_mode == "a" and _missing_trailing_newline(results_file)
        with open(results_file, file_mode, encoding="utf-8") as out_f:
            if needs_newline:
                # Keep a truncated last line from swallowing the next record.
                out_f.write("\n")
            if args.workers <= 1:
                for done, ep_path in enumerate(remaining, 1):
                    result = _evaluate_or_error_record(ep_path, *eval_args)
                    results.append(result)
                    _record_result(result, out_f, done, total)
            else:
                with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
                    futures = [
                        executor.submit(_evaluate_or_error_record, ep, *eval_args)
                        for ep in remaining
                    ]
                    # Results are written in completion order, from this
                    # thread only, so the output file needs no locking.
                    for done, future in enumerate(concurrent.futures.as_completed(futures), 1):
                        result = future.result()
                        results.append(result)
                        _record_result(result, out_f, done, total)

    # Summary
    summary = score_batch(results)
    errors = compute_error_taxonomy(results)
    # Average latency only over calls that succeeded and reported a latency.
    latencies = [r["latency"] for r in results if r.get("latency") and not r.get("api_error")]
    avg_latency = sum(latencies) / len(latencies) if latencies else 0
    api_errors = sum(1 for r in results if r.get("api_error"))
    summary_data = {
        "model": args.model,
        "api_base": args.api_base,
        "timestamp": datetime.now().isoformat(),
        "total_evaluated": len(results),
        "avg_latency_seconds": round(avg_latency, 2),
        "api_errors": api_errors,
        "summary": summary,
        "errors": errors,
    }
    with open(output_dir / "summary.json", "w", encoding="utf-8") as f:
        json.dump(summary_data, f, ensure_ascii=False, indent=2)

    print()
    print(format_results_table(summary, args.model))
    print(f" Average Latency: {avg_latency:.1f}s per episode")
    print(f" API Errors: {api_errors}")
    if errors["total_errors"] > 0:
        print(f"\nError Distribution:")
        for etype, count in errors["error_distribution"].items():
            print(f" {etype}: {count}")
    print(f"\n结果已保存到: {output_dir}")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()