317 lines
12 KiB
Python
317 lines
12 KiB
Python
#!/usr/bin/env python3
|
||
"""
eval_local.py — evaluation script for locally deployed models.

Evaluates models served through vLLM / Ollama / llama.cpp and similar
frameworks, all of which expose an OpenAI-compatible API.

Usage:

    # After deploying with vLLM (default port 8000)
    python eval_local.py --model Qwen/Qwen2.5-7B-Instruct --api-base http://localhost:8000/v1

    # After deploying with Ollama (default port 11434)
    python eval_local.py --model qwen2.5:7b --api-base http://localhost:11434/v1

    # Remote GPU server with an explicit host and port
    python eval_local.py --model Qwen/Qwen2.5-7B-Instruct --api-base http://192.168.1.100:8000/v1

    # Quick smoke test (only 5 episodes)
    python eval_local.py --model Qwen/Qwen2.5-7B-Instruct --api-base http://localhost:8000/v1 --max-episodes 5

    # Run only the TP variant of SQ1 to see how the model does
    python eval_local.py --model Qwen/Qwen2.5-7B-Instruct --api-base http://localhost:8000/v1 --sq SQ1 --variant TP --max-episodes 10
"""
|
||
|
||
import json
|
||
import os
|
||
import sys
|
||
import time
|
||
import argparse
|
||
import concurrent.futures
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
from collections import Counter
|
||
|
||
# Put the directory that holds this script at the front of the import path,
# so the project-local `src.*` packages resolve regardless of the current
# working directory the script is launched from.
PROJECT_ROOT = Path(__file__).resolve().parents[0]
sys.path[0:0] = [str(PROJECT_ROOT)]
|
||
|
||
from src.evaluation.prompt_builder import build_prompt
|
||
from src.evaluation.scorer import parse_model_response, score_episode, score_batch
|
||
from src.evaluation.metrics import format_results_table, compute_error_taxonomy
|
||
|
||
# Input and output locations, all anchored at the project root:
#   data/benchmark/<sq>/*.json  - benchmark episodes
#   results/<model>/            - default per-model output directory
DATA_DIR = PROJECT_ROOT.joinpath("data")
BENCHMARK_DIR = DATA_DIR.joinpath("benchmark")
RESULTS_DIR = PROJECT_ROOT.joinpath("results")
|
||
|
||
|
||
def check_server(api_base: str) -> bool:
    """Probe an OpenAI-compatible server by fetching its ``/v1/models`` listing.

    ``/v1`` is appended to ``api_base`` when it is not already the last path
    segment.

    Args:
        api_base: Base URL of the local server, e.g. ``http://localhost:8000/v1``.

    Returns:
        False only when no connection could be made at all. True when the
        model list was fetched, and also when the server answered but the
        probe was inconclusive (HTTP error status, unexpected payload, read
        error), because some servers do not implement ``/models`` yet still
        serve chat completions.
    """
    import urllib.error
    import urllib.request

    url = api_base.rstrip("/")
    if not url.endswith("/v1"):
        url += "/v1"
    models_url = url + "/models"

    try:
        # Use the response as a context manager so the socket is released.
        with urllib.request.urlopen(models_url, timeout=5) as resp:
            data = json.loads(resp.read().decode())
        models = [m.get("id", "") for m in data.get("data", [])]
        print(f"Server online. Available models: {models}")
        return True
    except urllib.error.HTTPError as e:
        # HTTPError subclasses URLError, so it must be caught first: any HTTP
        # status (e.g. 404 for a missing /models route) proves the server is up.
        print(f"Server check inconclusive ({e}), proceeding anyway...")
        return True
    except urllib.error.URLError as e:
        # No HTTP response at all: refused connection, DNS failure, connect timeout.
        print(f"Cannot connect to {models_url}: {e}")
        return False
    except Exception as e:
        # Malformed JSON, read timeouts, unexpected payload shape: the server
        # may still serve chat completions, so do not block the run.
        print(f"Server check inconclusive ({e}), proceeding anyway...")
        return True
|
||
|
||
|
||
def call_local_model(
    system: str,
    user: str,
    model: str,
    api_base: str,
    temperature: float = 0.0,
    max_tokens: int = 2048,
    timeout: int = 300,
) -> str:
    """Send one chat-completion request to a local OpenAI-compatible server.

    Args:
        system: System prompt; left out of the request when empty.
        user: User message content.
        model: Model name exactly as registered with the server.
        api_base: Base URL handed to the OpenAI client
            (e.g. ``http://localhost:8000/v1``).
        temperature: Sampling temperature.
        max_tokens: Upper bound on generated tokens.
        timeout: Per-request timeout in seconds.

    Returns:
        The assistant message text, or ``""`` when the server returned no
        content (``message.content`` may be ``None``).

    Raises:
        ImportError: If the ``openai`` package is not installed.
        Exception: Connection, HTTP and timeout errors from the OpenAI client
            propagate unchanged.
    """
    try:
        import openai
    except ImportError as exc:
        raise ImportError("pip install openai") from exc

    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": user})

    # Local deployments ignore the key, but the client requires one. The
    # context manager closes the client's HTTP connection pool after the call.
    with openai.OpenAI(api_key="not-needed", base_url=api_base) as client:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout,
        )

    content = response.choices[0].message.content
    # Honour the declared ``-> str`` return type even when content is None.
    return content if content is not None else ""
|
||
|
||
|
||
def evaluate_episode(episode_path: str, model: str, api_base: str, temperature: float) -> dict:
    """Run one benchmark episode through the model and score its answer.

    Args:
        episode_path: Path to the episode JSON file. The file must contain
            ``metadata`` and ``ground_truth`` keys.
        model: Model name passed through to ``call_local_model``.
        api_base: Base URL of the OpenAI-compatible server.
        temperature: Sampling temperature.

    Returns:
        A result record containing the episode id, path, metadata, ground
        truth, raw and parsed model responses, scores, and elapsed seconds.
        It also contains ``api_error``, which is ``None`` on success and
        otherwise a non-empty ``"ExcType: message"`` string.

    Raises:
        Exceptions from reading/decoding the episode file, missing keys,
        ``build_prompt`` or scoring propagate to the caller. Failures of the
        model call itself are captured in ``api_error`` instead.
    """
    with open(episode_path, "r", encoding="utf-8") as f:
        episode = json.load(f)

    meta = episode["metadata"]
    gt = episode["ground_truth"]
    prompt = build_prompt(episode)

    # perf_counter is monotonic, so latency is unaffected by wall-clock changes.
    start = time.perf_counter()
    try:
        raw = call_local_model(
            system=prompt.get("system", ""),
            user=prompt["user"],
            model=model,
            api_base=api_base,
            temperature=temperature,
        )
        error = None
    except Exception as e:
        raw = ""
        # str(e) is empty for message-less exceptions (e.g. a bare TimeoutError()).
        # An empty string would make the failure look like a success to callers
        # that test `if result["api_error"]`, so always include the type name.
        error = f"{type(e).__name__}: {e}"
    latency = time.perf_counter() - start

    parsed = parse_model_response(raw) if raw else {"is_anomaly": None, "_parse_failed": True}
    scores = score_episode(gt, parsed, meta.get("variant", ""))

    return {
        # Fall back to the file stem so the record matches the stem-based
        # resume filter in main() even when the JSON has no episode_id.
        "episode_id": episode.get("episode_id") or Path(episode_path).stem,
        "episode_path": episode_path,
        "metadata": meta,
        "ground_truth": gt,
        "raw_response": raw,
        "model_response": parsed,
        "scores": scores,
        "latency": latency,
        "api_error": error,
    }
|
||
|
||
|
||
def _parse_args() -> argparse.Namespace:
    """Parse the command line for a local-model evaluation run."""
    parser = argparse.ArgumentParser(
        description="SafeHome 本地模型评测",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # vLLM
  python eval_local.py --model Qwen/Qwen2.5-7B-Instruct --api-base http://localhost:8000/v1

  # Ollama
  python eval_local.py --model qwen2.5:7b --api-base http://localhost:11434/v1

  # 快速测试
  python eval_local.py --model Qwen/Qwen2.5-7B-Instruct --api-base http://localhost:8000/v1 --max-episodes 5
""",
    )
    parser.add_argument("--model", required=True, help="模型名(要跟部署时的名字一致)")
    parser.add_argument("--api-base", required=True, help="本地 API 地址(如 http://localhost:8000/v1)")
    parser.add_argument("--sq", nargs="*", default=None, help="评测哪些 SQ 类型(默认全部)")
    parser.add_argument("--variant", nargs="*", default=None, help="评测哪些变体(默认全部)")
    parser.add_argument("--max-episodes", type=int, default=None, help="最大评测数量")
    parser.add_argument("--workers", type=int, default=1, help="并行数(本地模型建议 1,除非 GPU 够多)")
    parser.add_argument("--temperature", type=float, default=0.0, help="采样温度")
    parser.add_argument("--output-dir", default=None, help="结果保存目录")
    parser.add_argument("--no-check", action="store_true", help="跳过服务器连通性检查")
    parser.add_argument("--no-resume", action="store_true", help="不断点续跑,重新开始")
    return parser.parse_args()


def _collect_episode_paths(sq_types: list, variants: list) -> list:
    """Return episode files under BENCHMARK_DIR/<sq> whose stem contains ``_<VARIANT>_``.

    Files are returned per SQ directory in sorted order. Each file is listed
    at most once even if several variants match it.
    """
    paths = []
    for sq in sq_types:
        sq_dir = BENCHMARK_DIR / sq.lower()
        if not sq_dir.exists():
            print(f"Warning: benchmark directory not found, skipping: {sq_dir}")
            continue
        for f in sorted(sq_dir.glob("*.json")):
            if any(f"_{v}_" in f.stem for v in variants):
                paths.append(str(f))
    return paths


def _record_stem(record: dict) -> str:
    """Key identifying a result record: its episode file stem, else its episode_id."""
    return Path(record.get("episode_path") or "").stem or record.get("episode_id", "")


def _load_successful_results(results_file: Path) -> dict:
    """Load reusable records from an earlier run, keyed by episode file stem.

    - The file is append-only, so repeated runs leave duplicates. A later
      line overrides an earlier one for the same episode.
    - Records carrying an ``api_error`` are dropped so resuming retries them.
    - Blank or malformed lines (e.g. one truncated by a crash mid-write) are
      skipped with a warning instead of aborting the resume.
    """
    latest = {}
    with open(results_file, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print(f"Warning: skipping malformed line {lineno} in {results_file}")
                continue
            if not isinstance(record, dict):
                print(f"Warning: skipping malformed line {lineno} in {results_file}")
                continue
            latest[_record_stem(record)] = record
    return {stem: r for stem, r in latest.items() if not r.get("api_error")}


def _safe_evaluate(ep_path: str, model: str, api_base: str, temperature: float) -> dict:
    """Run evaluate_episode, turning any exception into an error record.

    The error record is shaped like a normal result. Used by both the
    sequential and the threaded path, so one unreadable episode never aborts
    the run.
    """
    try:
        return evaluate_episode(ep_path, model, api_base, temperature)
    except Exception as e:
        return {
            "episode_id": Path(ep_path).stem,
            "episode_path": ep_path,
            "metadata": {}, "ground_truth": {},
            "raw_response": "", "model_response": {},
            "scores": {"detection_correct": False, "parse_success": False},
            "latency": 0,
            "api_error": f"{type(e).__name__}: {e}",
        }


def _print_progress(done: int, total: int, result: dict) -> None:
    """Print one progress line, including a truncated error message if present."""
    label = result["scores"].get("detection_label", "?")
    err = f" ERROR: {result['api_error'][:50]}" if result.get("api_error") else ""
    print(f" [{done}/{total}] {result.get('episode_id', '?')}: {label} ({result.get('latency', 0):.1f}s){err}")


def _run_evaluations(remaining: list, model: str, api_base: str, temperature: float,
                     workers: int, results_file: Path) -> list:
    """Evaluate ``remaining`` episodes and return the new records.

    Each record is appended to ``results_file`` and flushed as soon as it
    completes, so an interrupted run can be resumed.
    """
    new_results = []
    total = len(remaining)
    with open(results_file, "a", encoding="utf-8") as out_f:

        def record(result: dict) -> None:
            # Only ever called from this (main) thread, so writes never interleave.
            new_results.append(result)
            out_f.write(json.dumps(result, ensure_ascii=False) + "\n")
            out_f.flush()
            _print_progress(len(new_results), total, result)

        if workers <= 1:
            for ep_path in remaining:
                record(_safe_evaluate(ep_path, model, api_base, temperature))
        else:
            with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
                futures = [
                    executor.submit(_safe_evaluate, ep, model, api_base, temperature)
                    for ep in remaining
                ]
                for future in concurrent.futures.as_completed(futures):
                    record(future.result())
    return new_results


def _write_and_print_summary(results: list, model: str, api_base: str, output_dir: Path) -> None:
    """Score all records, write ``summary.json`` to ``output_dir`` and print a report."""
    summary = score_batch(results)
    errors = compute_error_taxonomy(results)

    # Failed calls are excluded: their latency measures a timeout or a refused
    # connection, not model speed.
    latencies = [r["latency"] for r in results if r.get("latency") and not r.get("api_error")]
    avg_latency = sum(latencies) / len(latencies) if latencies else 0
    api_errors = sum(1 for r in results if r.get("api_error"))

    summary_data = {
        "model": model,
        "api_base": api_base,
        "timestamp": datetime.now().isoformat(),
        "total_evaluated": len(results),
        "avg_latency_seconds": round(avg_latency, 2),
        "api_errors": api_errors,
        "summary": summary,
        "errors": errors,
    }
    with open(output_dir / "summary.json", "w", encoding="utf-8") as f:
        json.dump(summary_data, f, ensure_ascii=False, indent=2)

    print()
    print(format_results_table(summary, model))
    print(f" Average Latency: {avg_latency:.1f}s per episode")
    print(f" API Errors: {api_errors}")

    if errors["total_errors"] > 0:
        print("\nError Distribution:")
        for etype, count in errors["error_distribution"].items():
            print(f" {etype}: {count}")

    print(f"\n结果已保存到: {output_dir}")


def main() -> int:
    """Evaluate the selected benchmark episodes against a local model.

    Writes ``results.jsonl`` (append-only, one record per evaluated episode)
    and ``summary.json`` into the output directory. Unless ``--no-resume``
    is given, earlier successful records for the selected episodes are
    reused, and episodes that previously failed with an API error are retried.

    Returns:
        Process exit status: 0 on success, 1 if the server check fails or no
        benchmark episode matches the selection.
    """
    args = _parse_args()

    if not args.no_check and not check_server(args.api_base):
        print("\n服务器不可达。请确认:")
        print(" 1. 模型已部署并在运行(vLLM/Ollama)")
        print(f" 2. API 地址正确: {args.api_base}")
        print(" 3. 防火墙/端口未被阻止")
        print("\n如果确定服务器正常,加 --no-check 跳过检查")
        return 1

    sq_types = [s.upper() for s in args.sq] if args.sq else ["SQ1", "SQ2", "SQ3", "SQ4", "SQ5"]
    variants = [v.upper() for v in args.variant] if args.variant else ["TP", "FP", "TN"]

    episode_paths = _collect_episode_paths(sq_types, variants)
    # `is not None` so that an explicit --max-episodes 0 is honoured, not read as "all".
    if args.max_episodes is not None:
        episode_paths = episode_paths[:args.max_episodes]
    if not episode_paths:
        print(f"No benchmark episodes matched under {BENCHMARK_DIR} (SQ={sq_types}, variants={variants}).")
        return 1

    model_safe = args.model.replace("/", "_").replace(":", "_")
    output_dir = Path(args.output_dir) if args.output_dir else RESULTS_DIR / model_safe
    output_dir.mkdir(parents=True, exist_ok=True)
    results_file = output_dir / "results.jsonl"

    # Resume: reuse only successful records belonging to the current selection,
    # so a narrower re-run (--sq/--variant) does not report unrelated episodes.
    previous = {}
    if not args.no_resume and results_file.exists():
        previous = _load_successful_results(results_file)
    results = [previous[Path(p).stem] for p in episode_paths if Path(p).stem in previous]
    if not args.no_resume and results_file.exists():
        print(f"已完成 {len(results)} 个,继续剩余部分")

    remaining = [p for p in episode_paths if Path(p).stem not in previous]

    print(f"\n{'='*60}")
    print(" SafeHome 本地模型评测")
    print(f"{'='*60}")
    print(f" 模型: {args.model}")
    print(f" API: {args.api_base}")
    print(f" 待评: {len(remaining)} / {len(episode_paths)} episodes")
    print(f" 并行: {args.workers}")
    print(f" 输出: {output_dir}")
    print(f"{'='*60}\n")

    if not remaining:
        print("所有 episode 已评测完毕,直接输出结果。\n")
    else:
        results.extend(_run_evaluations(
            remaining, args.model, args.api_base, args.temperature, args.workers, results_file,
        ))

    _write_and_print_summary(results, args.model, args.api_base, output_dir)
    return 0


if __name__ == "__main__":
    sys.exit(main())
|