#!/usr/bin/env python3
r"""
eval_api.py -- generic API evaluation script (pure ``requests``; does not
depend on the openai/anthropic client libraries).

Works with any OpenAI-compatible endpoint: OpenAI, Claude (compatibility
mode), OpenRouter, vLLM, Ollama, etc.

Usage:

    # Claude API (OpenAI-compatible format)
    python eval_api.py --model claude-sonnet-4-6 \
        --api-base https://<your-api-host>/v1 \
        --api-key sk-xxx \
        --max-episodes 10

    # OpenRouter
    python eval_api.py --model anthropic/claude-sonnet-4-6 \
        --api-base https://openrouter.ai/api/v1 \
        --api-key sk-or-xxx \
        --max-episodes 10

    # Local vLLM
    python eval_api.py --model Qwen/Qwen2.5-7B-Instruct \
        --api-base http://localhost:8000/v1 \
        --api-key not-needed \
        --max-episodes 10

    # Full run (all 1200 episodes)
    python eval_api.py --model Qwen/Qwen2.5-7B-Instruct \
        --api-base http://localhost:8000/v1 \
        --api-key not-needed
"""

import argparse
import concurrent.futures
import json
import os
import sys
import time
from collections import Counter
from datetime import datetime
from pathlib import Path

import requests

PROJECT_ROOT = Path(__file__).resolve().parent
# Put the script's own directory first on sys.path so the ``src`` package
# imported below resolves regardless of the current working directory.
sys.path.insert(0, str(PROJECT_ROOT))

from src.evaluation.prompt_builder import build_prompt  # noqa: E402
from src.evaluation.scorer import parse_model_response, score_episode, score_batch  # noqa: E402
from src.evaluation.metrics import format_results_table, compute_error_taxonomy  # noqa: E402

# Episode JSON files are read from data/benchmark/<sq>/; results are written
# under results/<model>/ unless --output-dir overrides it.
BENCHMARK_DIR = PROJECT_ROOT / "data" / "benchmark"
RESULTS_DIR = PROJECT_ROOT / "results"
def call_chat_api(
    system: str,
    user: str,
    model: str,
    api_base: str,
    api_key: str,
    temperature: float = 0.0,
    max_tokens: int = 2048,
    timeout: int = 300,
    extra_body: dict = None,
) -> str:
    """Call an OpenAI-compatible ``chat/completions`` endpoint via ``requests``.

    Args:
        system: System prompt; when empty, no system message is sent.
        user: User message content.
        model: Model name placed in the request body.
        api_base: Base URL of the API. ``/chat/completions`` is appended
            unless the URL already ends with it.
        api_key: Sent as a Bearer token, unless it is empty or the literal
            ``"not-needed"``, in which case no Authorization header is sent.
        temperature: Sampling temperature placed in the request body.
        max_tokens: Output token limit placed in the request body.
        timeout: Timeout in seconds passed to ``requests.post``.
        extra_body: Optional extra fields merged into the request body last,
            so they override the standard fields above on key collisions.

    Returns:
        The ``message.content`` text of the first choice. A ``null`` content
        is returned as ``""`` so the declared ``str`` return type holds.

    Raises:
        ValueError: If the response JSON has no usable first choice.
        Exception: Whatever ``requests`` raises for transport failures, and
            the error raised by ``raise_for_status()`` on an HTTP error status.
    """
    url = api_base.rstrip("/")
    if not url.endswith("/chat/completions"):
        url += "/chat/completions"

    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": user})

    headers = {"Content-Type": "application/json"}
    if api_key and api_key != "not-needed":
        headers["Authorization"] = f"Bearer {api_key}"

    payload = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    if extra_body:
        payload.update(extra_body)

    resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()

    # OpenAI response format: choices[0].message.content
    choices = data.get("choices") if isinstance(data, dict) else None
    if choices:
        try:
            content = choices[0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            # Malformed choice: fall through to the common format error below
            # instead of leaking a bare KeyError/TypeError to the caller.
            pass
        else:
            # ``content`` may legitimately be null in the response JSON.
            return "" if content is None else content

    # Fallback: anything else is an unrecognised response shape.
    raise ValueError(f"Unexpected response format: {json.dumps(data)[:500]}")
def evaluate_episode(ep_path, model, api_base, api_key, temperature, extra_body=None, prompt_mode="baseline"):
    """Evaluate a single episode file and return its result record.

    The episode JSON is loaded from ``ep_path``, turned into a prompt with
    ``build_prompt``, sent to the model through ``call_chat_api``, and the
    reply is parsed and scored against the episode's ground truth.

    Args:
        ep_path: Path of the episode JSON file.
        model, api_base, api_key, temperature, extra_body: Forwarded to
            ``call_chat_api``.
        prompt_mode: Passed as ``mode`` to ``build_prompt`` and
            ``parse_model_response``.

    Returns:
        A dict with the keys ``episode_id``, ``episode_path``, ``metadata``,
        ``ground_truth``, ``raw_response``, ``model_response``, ``scores``,
        ``latency`` (seconds spent in the API call) and ``api_error``
        (``None`` on success, otherwise the exception text).

    Raises:
        OSError, json.JSONDecodeError, KeyError: If the episode file cannot
            be read or lacks ``metadata``/``ground_truth``. API failures are
            not raised; they are reported through ``api_error``.
    """
    with open(ep_path, "r", encoding="utf-8") as f:
        episode = json.load(f)

    meta = episode["metadata"]
    gt = episode["ground_truth"]
    prompt = build_prompt(episode, mode=prompt_mode)

    raw = ""
    error = None
    # perf_counter is monotonic, so the measured latency cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    start = time.perf_counter()
    try:
        # ``or ""`` keeps raw_response a string even if the API returned null.
        raw = call_chat_api(
            system=prompt.get("system", ""),
            user=prompt["user"],
            model=model,
            api_base=api_base,
            api_key=api_key,
            temperature=temperature,
            extra_body=extra_body,
        ) or ""
    except Exception as e:
        # Best effort: a failed call is recorded and scored as a parse
        # failure rather than aborting the whole run.
        error = str(e)
    latency = time.perf_counter() - start

    if raw:
        parsed = parse_model_response(raw, mode=prompt_mode)
    else:
        parsed = {"is_anomaly": None, "_parse_failed": True}
    scores = score_episode(gt, parsed, meta.get("variant", ""))

    return {
        # Fall back to the file stem: main() identifies episodes by stem when
        # resuming, so an empty id would never be recognised as completed.
        "episode_id": episode.get("episode_id") or Path(ep_path).stem,
        "episode_path": ep_path,
        "metadata": meta,
        "ground_truth": gt,
        "raw_response": raw,
        "model_response": parsed,
        "scores": scores,
        "latency": latency,
        "api_error": error,
    }
def _build_extra_body(parser, args):
    """Translate CLI flags into extra request-body fields (``None`` if there are none).

    Side effect: with ``--no_thinking`` and the default temperature of 0.0,
    ``args.temperature`` is raised to 0.7.
    """
    extra_body = {}
    if args.thinking:
        extra_body["chat_template_kwargs"] = {"thinking": True}
    if args.disable_thinking_only:
        extra_body["chat_template_kwargs"] = {"enable_thinking": False}
    if args.no_thinking:
        extra_body["chat_template_kwargs"] = {"enable_thinking": False}
        extra_body["top_p"] = 0.95
        extra_body["top_k"] = 20
        extra_body["presence_penalty"] = 1.5
        if args.temperature == 0.0:
            args.temperature = 0.7
    if args.extra_json:
        # Report bad input as a normal CLI usage error instead of a traceback.
        try:
            extra = json.loads(args.extra_json)
        except json.JSONDecodeError as e:
            parser.error(f"--extra-json 不是合法的 JSON: {e}")
        if not isinstance(extra, dict):
            parser.error("--extra-json 必须是 JSON 对象")
        extra_body.update(extra)
    if args.max_tokens != 2048:
        # evaluate_episode has no max_tokens parameter, so a non-default
        # value travels through extra_body and overrides the payload field.
        extra_body["max_tokens"] = args.max_tokens
    return extra_body or None


def _collect_episode_paths(sq_types, variants, max_episodes):
    """List benchmark episode files whose stem contains ``_<variant>_`` for a requested variant."""
    paths = []
    for sq in sq_types:
        sq_dir = BENCHMARK_DIR / sq.lower()
        if not sq_dir.exists():
            continue
        paths.extend(
            str(f)
            for f in sorted(sq_dir.glob("*.json"))
            if any(f"_{v}_" in f.stem for v in variants)
        )
    if max_episodes:
        paths = paths[:max_episodes]
    return paths


def _load_previous_results(results_file):
    """Read an existing results.jsonl; return ``(records, completed_keys)``.

    Unparseable lines (e.g. a record cut off by an interrupted run) are
    skipped with a warning, so those episodes are simply evaluated again.
    """
    records = []
    completed = set()
    with open(results_file, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
                completed.add(record["episode_id"])
            except (json.JSONDecodeError, KeyError, TypeError):
                print(f"警告: 跳过 {results_file} 第 {line_no} 行(无法解析)")
                continue
            # Pending episodes are matched by file stem, so also remember the
            # stem in case the stored episode_id differs from the file name.
            if record.get("episode_path"):
                completed.add(Path(record["episode_path"]).stem)
            records.append(record)
    return records, completed


def _missing_trailing_newline(path):
    """Return True when *path* is a non-empty file whose last byte is not a newline."""
    if not path.exists() or path.stat().st_size == 0:
        return False
    with open(path, "rb") as f:
        f.seek(-1, os.SEEK_END)
        return f.read(1) != b"\n"


def _evaluate_safely(ep_path, args, extra_body):
    """Run ``evaluate_episode``; turn any exception into a failure record."""
    try:
        return evaluate_episode(ep_path, args.model, args.api_base, args.api_key, args.temperature, extra_body, args.mode)
    except Exception as e:
        return {
            "episode_id": Path(ep_path).stem,
            "episode_path": ep_path,
            "metadata": {}, "ground_truth": {},
            "raw_response": "", "model_response": {},
            "scores": {"detection_correct": False, "parse_success": False},
            "latency": 0, "api_error": str(e),
        }


def _iter_results(remaining, args, extra_body):
    """Yield one result per pending episode, sequentially or in completion order from a thread pool."""
    if args.workers <= 1:
        for ep_path in remaining:
            yield _evaluate_safely(ep_path, args, extra_body)
        return
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        futures = [executor.submit(_evaluate_safely, ep, args, extra_body) for ep in remaining]
        for future in concurrent.futures.as_completed(futures):
            yield future.result()


def _report_result(result, out_f, done, total):
    """Append one result to the JSONL file (flushed immediately) and print a progress line."""
    out_f.write(json.dumps(result, ensure_ascii=False) + "\n")
    out_f.flush()
    label = result["scores"].get("detection_label", "?")
    err_str = ""
    if result.get("api_error"):
        err_str = f" ERR: {result['api_error'][:40]}"
    print(f" [{done}/{total}] {result.get('episode_id', '?')}: {label} ({result.get('latency', 0):.1f}s){err_str}")


def main():
    """CLI entry point: evaluate benchmark episodes against a chat API and write results.

    Parses the command line, collects the episode files for the requested
    SQ types and variants, optionally resumes from an existing
    ``results.jsonl``, runs a connectivity check, evaluates the pending
    episodes (appending each result to ``results.jsonl`` as it arrives),
    then writes ``summary.json`` and prints a results table.
    """
    parser = argparse.ArgumentParser(
        description="SafeHome 通用 API 评测(纯 requests,不依赖 openai 库)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--model", required=True, help="模型名")
    parser.add_argument("--api-base", required=True, help="API 地址(如 https://xxx/v1)")
    parser.add_argument("--api-key", required=True, help="API Key")
    parser.add_argument("--sq", nargs="*", default=None, help="评测哪些 SQ(默认全部)")
    parser.add_argument("--variant", nargs="*", default=None, help="评测哪些 variant(默认全部)")
    parser.add_argument("--max-episodes", type=int, default=None, help="最大评测数")
    parser.add_argument("--workers", type=int, default=1, help="并行数")
    parser.add_argument("--temperature", type=float, default=0.0, help="采样温度")
    parser.add_argument("--output-dir", default=None, help="输出目录")
    parser.add_argument("--no-resume", action="store_true", help="不续跑,重新开始")
    parser.add_argument("--mode", default="baseline", choices=["baseline", "edrc", "cot"],
                        help="Prompt 模式: baseline(直接判断), edrc(6步结构化推理), cot(思维链)")
    parser.add_argument("--thinking", action="store_true", help="启用深度思考模式(DeepSeek 等支持的模型)")
    # ``--no-thinking`` is accepted as an alias; dest stays ``no_thinking``.
    parser.add_argument("--no_thinking", "--no-thinking", action="store_true", help="关闭 thinking,并使用 no-thinking 的推荐采样参数")
    # Previously stripped from sys.argv by hand; now a regular flag.
    parser.add_argument("--disable-thinking-only", action="store_true", help="仅关闭 thinking,不修改采样参数")
    parser.add_argument("--extra-json", default=None, help="额外的 JSON 参数,会合并到请求体中(如 '{\"top_p\": 0.9}')")
    parser.add_argument("--max-tokens", type=int, default=2048, help="最大输出 token 数(思考模式建议调大到 8192)")

    args = parser.parse_args()
    extra_body = _build_extra_body(parser, args)

    # Collect episodes
    sq_types = [s.upper() for s in args.sq] if args.sq else ["SQ1", "SQ2", "SQ3", "SQ4", "SQ5"]
    variants = [v.upper() for v in args.variant] if args.variant else ["TP", "FP", "TN"]
    episode_paths = _collect_episode_paths(sq_types, variants, args.max_episodes)

    # Output directory
    model_safe = args.model.replace("/", "_").replace(":", "_")
    output_dir = Path(args.output_dir) if args.output_dir else RESULTS_DIR / model_safe
    output_dir.mkdir(parents=True, exist_ok=True)

    # Resume from a previous run
    results = []
    completed = set()
    results_file = output_dir / "results.jsonl"
    if not args.no_resume and results_file.exists():
        results, completed = _load_previous_results(results_file)
        print(f"已完成 {len({r['episode_id'] for r in results})} 个,继续剩余部分")

    remaining = [p for p in episode_paths if Path(p).stem not in completed]

    print(f"\n{'='*60}")
    print(" SafeHome API 评测")
    print(f"{'='*60}")
    print(f" 模型: {args.model}")
    print(f" API: {args.api_base}")
    print(f" 待评: {len(remaining)} / {len(episode_paths)} episodes")
    print(f" 并行: {args.workers}")
    print(f" 输出: {output_dir}")
    print(f"{'='*60}\n")

    # Send one small request first to confirm the endpoint is reachable.
    if remaining:
        print("测试连通性...")
        try:
            test_resp = call_chat_api(
                system="You are a test assistant.",
                user="Say OK.",
                model=args.model,
                api_base=args.api_base,
                api_key=args.api_key,
                max_tokens=100,
                timeout=60,
                extra_body=extra_body,
            )
            # Guard against a null reply so it is not reported as a failure.
            print(f"连通成功: {(test_resp or '')[:50]}\n")
        except Exception as e:
            print(f"连通失败: {e}")
            print("请检查 --api-base 和 --api-key 是否正确")
            return

    if not remaining:
        print("所有 episode 已评测完毕。\n")
    else:
        # --no-resume starts over: truncate instead of appending, otherwise
        # the file keeps stale records and a later resume loads duplicates.
        write_mode = "w" if args.no_resume else "a"
        # A previous run may have died mid-write; terminate that partial line
        # so the first new record does not get glued onto it.
        needs_newline = write_mode == "a" and _missing_trailing_newline(results_file)
        with open(results_file, write_mode, encoding="utf-8") as out_f:
            if needs_newline:
                out_f.write("\n")
            for done, result in enumerate(_iter_results(remaining, args, extra_body), 1):
                results.append(result)
                _report_result(result, out_f, done, len(remaining))

    # Summary
    summary = score_batch(results)
    errors = compute_error_taxonomy(results)

    # Average latency only counts successful calls with a non-zero latency.
    latencies = [r["latency"] for r in results if r.get("latency") and not r.get("api_error")]
    avg_latency = sum(latencies) / len(latencies) if latencies else 0
    api_errors = sum(1 for r in results if r.get("api_error"))

    summary_data = {
        "model": args.model,
        "api_base": args.api_base,
        "timestamp": datetime.now().isoformat(),
        "total_evaluated": len(results),
        "avg_latency_seconds": round(avg_latency, 2),
        "api_errors": api_errors,
        "summary": summary,
        "errors": errors,
    }
    with open(output_dir / "summary.json", "w", encoding="utf-8") as f:
        json.dump(summary_data, f, ensure_ascii=False, indent=2)

    print()
    print(format_results_table(summary, args.model))
    print(f" Average Latency: {avg_latency:.1f}s per episode")
    print(f" API Errors: {api_errors}")

    if errors["total_errors"] > 0:
        print("\nError Distribution:")
        for etype, count in errors["error_distribution"].items():
            print(f" {etype}: {count}")

    print(f"\n结果已保存到: {output_dir}")
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()