Files
llmiotsafe/eval_local.py
2026-05-12 17:01:39 +08:00

317 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
eval_local.py — 本地模型评测脚本。
专门用于评测通过 vLLM / Ollama / llama.cpp 等部署的本地模型。
这些框架都提供 OpenAI 兼容的 API 接口。
使用方式:
# vLLM 部署后(默认端口 8000)
python eval_local.py --model Qwen/Qwen2.5-7B-Instruct --api-base http://localhost:8000/v1
# Ollama 部署后(默认端口 11434)
python eval_local.py --model qwen2.5:7b --api-base http://localhost:11434/v1
# 指定端口和 GPU 服务器
python eval_local.py --model Qwen/Qwen2.5-7B-Instruct --api-base http://192.168.1.100:8000/v1
# 快速测试(只跑 5 个 episode)
python eval_local.py --model Qwen/Qwen2.5-7B-Instruct --api-base http://localhost:8000/v1 --max-episodes 5
# 只跑 SQ1 的 TP 看看效果
python eval_local.py --model Qwen/Qwen2.5-7B-Instruct --api-base http://localhost:8000/v1 --sq SQ1 --variant TP --max-episodes 10
"""
import json
import os
import sys
import time
import argparse
import concurrent.futures
from pathlib import Path
from datetime import datetime
from collections import Counter
# Put the directory containing this script at the front of sys.path so that
# the `src.evaluation` package below resolves even when the script is run
# from another working directory. Must happen before the project imports.
PROJECT_ROOT = Path(__file__).resolve().parent
sys.path.insert(0, str(PROJECT_ROOT))
from src.evaluation.prompt_builder import build_prompt
from src.evaluation.scorer import parse_model_response, score_episode, score_batch
from src.evaluation.metrics import format_results_table, compute_error_taxonomy
# Data layout relative to the project root: episodes live under
# data/benchmark/<sq>/, and results/ is the default output root.
DATA_DIR = PROJECT_ROOT.joinpath("data")
BENCHMARK_DIR = DATA_DIR.joinpath("benchmark")
RESULTS_DIR = PROJECT_ROOT.joinpath("results")
def check_server(api_base: str) -> bool:
    """Probe an OpenAI-compatible server via its ``/v1/models`` endpoint.

    ``api_base`` may be given with or without a trailing ``/v1`` (and with or
    without a trailing slash); ``/v1/models`` is appended accordingly.

    Returns:
        False only when the host cannot be reached at the transport level
        (connection refused, DNS failure, ...). True when the server lists its
        models, and also when the probe is inconclusive -- the server answered
        with an HTTP error status or an unexpected payload -- because some
        servers lack a ``/models`` endpoint yet serve chat completions fine.
    """
    import urllib.request
    import urllib.error

    url = api_base.rstrip("/")
    if not url.endswith("/v1"):
        url += "/v1"
    models_url = url + "/models"
    try:
        # `with` closes the HTTP response instead of leaking the socket.
        with urllib.request.urlopen(models_url, timeout=5) as resp:
            data = json.loads(resp.read().decode())
        models = [m.get("id", "") for m in data.get("data", [])]
        print(f"Server online. Available models: {models}")
        return True
    except urllib.error.HTTPError as e:
        # HTTPError is a subclass of URLError, so it must be handled first: an
        # HTTP status proves the server is reachable, it just has no usable
        # /models route.
        print(f"Server check inconclusive ({e}), proceeding anyway...")
        return True
    except urllib.error.URLError as e:
        print(f"Cannot connect to {models_url}: {e}")
        return False
    except Exception as e:
        # Malformed JSON, non-dict payload, read timeout, etc.
        print(f"Server check inconclusive ({e}), proceeding anyway...")
        return True
def call_local_model(
    system: str,
    user: str,
    model: str,
    api_base: str,
    temperature: float = 0.0,
    max_tokens: int = 2048,
    timeout: int = 300,
) -> str:
    """Send one chat-completion request to an OpenAI-compatible endpoint.

    Args:
        system: System prompt; left out of the message list when empty.
        user: User message content.
        model: Model name, as registered with the serving framework.
        api_base: Base URL passed to the OpenAI client (e.g.
            ``http://localhost:8000/v1``).
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate.
        timeout: Per-request timeout in seconds.

    Returns:
        The text of the first choice's message, or ``""`` when the server
        returned no text content.

    Raises:
        ImportError: If the ``openai`` package is not installed.
        Any exception raised by the ``openai`` client (connection errors,
        timeouts, HTTP errors) propagates unchanged.
    """
    try:
        import openai
    except ImportError as err:
        raise ImportError("pip install openai") from err

    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": user})

    # One client per call; the `with` block closes its HTTP connection pool
    # when the request is done instead of leaving it to the garbage collector.
    with openai.OpenAI(
        api_key="not-needed",  # placeholder: local deployments do not need a key
        base_url=api_base,
    ) as client:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout,
        )
    # `content` may be None (e.g. an empty or non-text reply); honor the
    # declared `str` return type.
    return response.choices[0].message.content or ""
def evaluate_episode(episode_path: str, model: str, api_base: str, temperature: float) -> dict:
    """Evaluate a single benchmark episode against the local model.

    Loads the episode JSON, builds the prompt with ``build_prompt``, queries
    the model, and parses and scores the reply. Any exception raised by the
    model call is captured in ``api_error`` (with an empty ``raw_response``)
    rather than propagated. Errors while reading the episode file, building
    the prompt, or parsing/scoring still propagate to the caller.

    Args:
        episode_path: Path to the episode JSON file. The file must contain
            ``metadata`` and ``ground_truth`` keys.
        model: Model name to request.
        api_base: Base URL of the OpenAI-compatible API.
        temperature: Sampling temperature.

    Returns:
        A JSON-serializable dict with the episode id and path, metadata,
        ground truth, raw and parsed model output, scores, latency in seconds,
        and ``api_error`` (``None`` on success).
    """
    with open(episode_path, "r", encoding="utf-8") as f:
        episode = json.load(f)
    meta = episode["metadata"]
    gt = episode["ground_truth"]
    prompt = build_prompt(episode)

    raw = ""
    error = None
    # perf_counter is monotonic, so the latency is immune to wall-clock jumps.
    start = time.perf_counter()
    try:
        raw = call_local_model(
            system=prompt.get("system", ""),
            user=prompt["user"],
            model=model,
            api_base=api_base,
            temperature=temperature,
        )
    except Exception as e:
        # Some exceptions stringify to ""; fall back to repr so the failure is
        # still truthy and therefore counted and reported as an API error.
        error = str(e) or repr(e)
    latency = time.perf_counter() - start

    parsed = parse_model_response(raw) if raw else {"is_anomaly": None, "_parse_failed": True}
    scores = score_episode(gt, parsed, meta.get("variant", ""))
    return {
        # Fall back to the file stem so every record carries a usable id; the
        # resume logic in main() identifies finished episodes by file stem.
        "episode_id": episode.get("episode_id") or Path(episode_path).stem,
        "episode_path": episode_path,
        "metadata": meta,
        "ground_truth": gt,
        "raw_response": raw,
        "model_response": parsed,
        "scores": scores,
        "latency": latency,
        "api_error": error,
    }
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="SafeHome 本地模型评测",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
# vLLM
python eval_local.py --model Qwen/Qwen2.5-7B-Instruct --api-base http://localhost:8000/v1
# Ollama
python eval_local.py --model qwen2.5:7b --api-base http://localhost:11434/v1
# 快速测试
python eval_local.py --model Qwen/Qwen2.5-7B-Instruct --api-base http://localhost:8000/v1 --max-episodes 5
""",
    )
    parser.add_argument("--model", required=True, help="模型名(要跟部署时的名字一致)")
    parser.add_argument("--api-base", required=True, help="本地 API 地址(如 http://localhost:8000/v1")
    parser.add_argument("--sq", nargs="*", default=None, help="评测哪些 SQ 类型(默认全部)")
    parser.add_argument("--variant", nargs="*", default=None, help="评测哪些变体(默认全部)")
    parser.add_argument("--max-episodes", type=int, default=None, help="最大评测数量")
    parser.add_argument("--workers", type=int, default=1, help="并行数(本地模型建议 1除非 GPU 够多)")
    parser.add_argument("--temperature", type=float, default=0.0, help="采样温度")
    parser.add_argument("--output-dir", default=None, help="结果保存目录")
    parser.add_argument("--no-check", action="store_true", help="跳过服务器连通性检查")
    parser.add_argument("--no-resume", action="store_true", help="不断点续跑,重新开始")
    return parser


def _collect_episode_paths(sq_types, variants):
    """Return sorted episode files under BENCHMARK_DIR/<sq>/ whose stem contains `_<VARIANT>_`.

    SQ directories that do not exist are skipped silently.
    """
    paths = []
    for sq in sq_types:
        sq_dir = BENCHMARK_DIR / sq.lower()
        if not sq_dir.exists():
            continue
        for f in sorted(sq_dir.glob("*.json")):
            if any(f"_{v}_" in f.stem for v in variants):
                paths.append(str(f))
    return paths


def _load_previous_results(results_file):
    """Load records written to results.jsonl by an earlier run.

    Returns ``(results, completed)``. ``completed`` holds both the
    ``episode_id`` and the episode-file stem of every record, so an episode is
    recognized as done even when its id differs from its file name. Blank
    lines and lines that fail to parse (e.g. a half-written last line after an
    interrupted run) are skipped with a warning instead of aborting the resume.
    """
    results = []
    completed = set()
    with open(results_file, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Skipping unreadable line {lineno} in {results_file}: {e}")
                continue
            results.append(record)
            if record.get("episode_id"):
                completed.add(record["episode_id"])
            if record.get("episode_path"):
                completed.add(Path(record["episode_path"]).stem)
    return results, completed


def _failed_result(ep_path, exc):
    """Build a placeholder record (detection_correct=False) for an episode whose evaluation raised."""
    return {
        "episode_id": Path(ep_path).stem,
        "episode_path": ep_path,
        "metadata": {}, "ground_truth": {},
        "raw_response": "", "model_response": {},
        "scores": {"detection_correct": False, "parse_success": False},
        "latency": 0, "api_error": str(exc) or repr(exc),
    }


def _iter_results(paths, args):
    """Yield one result per episode path, serially or via a thread pool.

    Exceptions raised by evaluate_episode become placeholder records in both
    modes, so one bad episode file never aborts the whole run. Results are
    yielded in completion order.
    """
    def run(ep_path):
        try:
            return evaluate_episode(ep_path, args.model, args.api_base, args.temperature)
        except Exception as e:
            return _failed_result(ep_path, e)

    if args.workers <= 1:
        for ep_path in paths:
            yield run(ep_path)
        return
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        futures = [executor.submit(run, ep_path) for ep_path in paths]
        for future in concurrent.futures.as_completed(futures):
            yield future.result()


def _print_progress(done, total, result):
    """Print one progress line, including a truncated error message if any."""
    label = result["scores"].get("detection_label", "?")
    err = f" ERROR: {result['api_error'][:50]}" if result.get("api_error") else ""
    print(f" [{done}/{total}] {result.get('episode_id', '?')}: {label} ({result.get('latency', 0):.1f}s){err}")


def main():
    """Run the SafeHome local-model evaluation from the command line.

    Steps:
    1. Optionally probe the server.
    2. Collect benchmark episodes matching --sq/--variant.
    3. Evaluate those not already recorded in <output_dir>/results.jsonl
       (all of them with --no-resume), appending each result as it finishes.
    4. Write <output_dir>/summary.json and print a results table.
    """
    args = _build_arg_parser().parse_args()

    # Server connectivity check.
    if not args.no_check and not check_server(args.api_base):
        print("\n服务器不可达。请确认:")
        print(" 1. 模型已部署并在运行vLLM/Ollama")
        print(f" 2. API 地址正确: {args.api_base}")
        print(" 3. 防火墙/端口未被阻止")
        print("\n如果确定服务器正常,加 --no-check 跳过检查")
        return

    # Collect episodes.
    sq_types = [s.upper() for s in args.sq] if args.sq else ["SQ1", "SQ2", "SQ3", "SQ4", "SQ5"]
    variants = [v.upper() for v in args.variant] if args.variant else ["TP", "FP", "TN"]
    episode_paths = _collect_episode_paths(sq_types, variants)
    if args.max_episodes:
        episode_paths = episode_paths[:args.max_episodes]
    if not episode_paths:
        # Usually a mistyped --sq/--variant or a missing benchmark directory;
        # stop before creating output and writing a summary for nothing.
        print(f"No episodes found in {BENCHMARK_DIR} for SQ={sq_types}, variant={variants}.")
        return

    # Output directory.
    model_safe = args.model.replace("/", "_").replace(":", "_")
    output_dir = Path(args.output_dir) if args.output_dir else RESULTS_DIR / model_safe
    output_dir.mkdir(parents=True, exist_ok=True)

    # Resume from a previous run.
    results_file = output_dir / "results.jsonl"
    results = []
    completed = set()
    if not args.no_resume and results_file.exists():
        results, completed = _load_previous_results(results_file)
        print(f"已完成 {len(results)} 个,继续剩余部分")
    remaining = [p for p in episode_paths if Path(p).stem not in completed]

    print(f"\n{'='*60}")
    print(" SafeHome 本地模型评测")
    print(f"{'='*60}")
    print(f" 模型: {args.model}")
    print(f" API: {args.api_base}")
    print(f" 待评: {len(remaining)} / {len(episode_paths)} episodes")
    print(f" 并行: {args.workers}")
    print(f" 输出: {output_dir}")
    print(f"{'='*60}\n")

    if not remaining:
        print("所有 episode 已评测完毕,直接输出结果。\n")
    else:
        # --no-resume means "start over": truncate the old file instead of
        # appending, otherwise a later resumed run would load both generations.
        mode = "w" if args.no_resume else "a"
        with open(results_file, mode, encoding="utf-8") as out_f:
            for done, result in enumerate(_iter_results(remaining, args), start=1):
                results.append(result)
                out_f.write(json.dumps(result, ensure_ascii=False) + "\n")
                # Flush per result so an interrupted run can resume from here.
                out_f.flush()
                _print_progress(done, len(remaining), result)

    # Summary.
    summary = score_batch(results)
    errors = compute_error_taxonomy(results)
    latencies = [r["latency"] for r in results if r.get("latency")]
    avg_latency = sum(latencies) / len(latencies) if latencies else 0
    api_errors = sum(1 for r in results if r.get("api_error"))
    summary_data = {
        "model": args.model,
        "api_base": args.api_base,
        "timestamp": datetime.now().isoformat(),
        "total_evaluated": len(results),
        "avg_latency_seconds": round(avg_latency, 2),
        "api_errors": api_errors,
        "summary": summary,
        "errors": errors,
    }
    with open(output_dir / "summary.json", "w", encoding="utf-8") as f:
        json.dump(summary_data, f, ensure_ascii=False, indent=2)

    # Report.
    print()
    print(format_results_table(summary, args.model))
    print(f" Average Latency: {avg_latency:.1f}s per episode")
    print(f" API Errors: {api_errors}")
    if errors["total_errors"] > 0:
        print("\nError Distribution:")
        for etype, count in errors["error_distribution"].items():
            print(f" {etype}: {count}")
    print(f"\n结果已保存到: {output_dir}")


if __name__ == "__main__":
    main()