Compare commits
3 Commits
097d2ce472
...
31812a72ee
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
31812a72ee | ||
|
|
74e6bde13a | ||
|
|
fde96c7d9f |
231
README.md
231
README.md
@@ -2,43 +2,114 @@
|
||||
|
||||
Multi-Agent System for Digital Forensics — 基于大语言模型的多智能体电子取证系统。
|
||||
|
||||
系统通过 6 个专业化 Agent 协同工作,对磁盘镜像进行自动化取证分析,最终生成结构化的取证报告。
|
||||
系统通过 7 个专业化 Agent 协同工作,对磁盘镜像进行自动化取证分析,最终生成结构化的取证报告。Agent 之间不直接通信,通过共享的 **EvidenceGraph**(证据知识图)协作。
|
||||
|
||||
## 架构
|
||||
|
||||
```
|
||||
main.py 入口:配置加载、恢复检测、运行管理
|
||||
main.py 入口:配置加载、镜像选择、断连恢复
|
||||
│
|
||||
├── Orchestrator 四阶段流水线调度
|
||||
├── Orchestrator 五阶段流水线调度
|
||||
│ │
|
||||
│ ├── FileSystemAgent 磁盘结构、文件系统、删除文件、Prefetch
|
||||
│ ├── RegistryAgent 注册表分析(系统/用户/网络/软件)
|
||||
│ ├── CommunicationAgent 邮件、IRC 聊天记录
|
||||
│ ├── FileSystemAgent 分区/文件系统、目录、删除文件、Prefetch
|
||||
│ ├── HypothesisAgent 生成假设,链接已有证据
|
||||
│ ├── RegistryAgent 注册表分析(SYSTEM/SOFTWARE/SAM/NTUSER.DAT)
|
||||
│ ├── CommunicationAgent 邮件、IRC/mIRC 聊天记录
|
||||
│ ├── NetworkAgent 浏览器历史、PCAP 抓包
|
||||
│ ├── TimelineAgent 跨类别时间线关联
|
||||
│ └── ReportAgent 综合报告生成
|
||||
│
|
||||
├── Blackboard 共享知识库(Evidence + Lead)
|
||||
└── LLMClient Claude API 调用(ReAct 模式)
|
||||
├── EvidenceGraph 带类型边的证据知识图(自动持久化)
|
||||
├── AgentFactory 角色模板 + 动态 Agent 组合
|
||||
├── ToolRegistry 工具目录 + 结果缓存
|
||||
└── LLMClient Claude API 客户端(异步、tool-use)
|
||||
```
|
||||
|
||||
Agent 之间不直接通信,通过 **Blackboard(黑板)** 共享发现(Evidence)和线索(Lead)。
|
||||
## EvidenceGraph:证据知识图
|
||||
|
||||
## 调查流程
|
||||
三类节点 + 类型化加权边:
|
||||
|
||||
| 节点 | 前缀 | 含义 |
|
||||
|---|---|---|
|
||||
| `Phenomenon` | `ph-*` | 可观测的取证产物(一条具体发现) |
|
||||
| `Hypothesis` | `hyp-*` | 解释性假设(待验证的论断) |
|
||||
| `Entity` | `ent-*` | 人、程序、主机、IP 等可复现的实体 |
|
||||
|
||||
Phenomenon → Hypothesis 的边类型与权重写死在 `HYPOTHESIS_EDGE_WEIGHTS`:
|
||||
|
||||
| 边类型 | 权重 | 语义 |
|
||||
|---|---:|---|
|
||||
| `direct_evidence` | +0.25 | 现象就是假设所述行为本身 |
|
||||
| `supports` | +0.15 | 与假设一致但非决定性 |
|
||||
| `consequence_observed` | +0.15 | 观察到假设预期的结果 |
|
||||
| `prerequisite_met` | +0.10 | 满足假设的前置条件 |
|
||||
| `weakens` | −0.10 | 降低假设可能性 |
|
||||
| `contradicts` | −0.20 | 直接反驳假设 |
|
||||
|
||||
置信度更新公式(收敛于 [0, 1]):
|
||||
|
||||
- 正向边:`delta = weight * (1 - old_conf)`
|
||||
- 负向边:`delta = weight * old_conf`
|
||||
|
||||
跨阈值自动转状态:≥ 0.8 → `supported`,≤ 0.2 → `refuted`,跑完仍 active → `inconclusive`。LLM 只负责挑边类型(分类任务),权重表与状态转移由代码裁决,避免数值幻觉。
|
||||
|
||||
新增 Phenomenon 时通过 Jaccard 相似度合并(title > 0.6 且 description > 0.4 即视为重复,合并后提升置信度并追加 `corroborating_agents`),避免同一发现被重复入图。
|
||||
|
||||
## 五阶段流水线
|
||||
|
||||
| 阶段 | 说明 |
|
||||
|------|------|
|
||||
| **Phase 1** | FileSystemAgent 勘查磁盘镜像,识别分区、目录结构、关键文件,产出初始 Lead |
|
||||
| **Phase 2** | 多轮线索追踪 — Lead 按 Agent 类型分组并行派发,最多 10 轮迭代 |
|
||||
| **Phase 2.5** | 覆盖率缺口分析 — 对照 config.yaml 中的 10 个调查领域,自动补漏 |
|
||||
| **Phase 3** | TimelineAgent 综合所有 evidence 建立事件时间线 |
|
||||
| **Phase 4** | ReportAgent 生成 Markdown 格式取证报告 |
|
||||
| **Phase 1** | FileSystemAgent 初勘镜像,识别分区/文件系统/关键路径,产出首批 Phenomenon |
|
||||
| **Phase 2** | 假设生成 — 优先读 `config.yaml:hypotheses`;未配置则由 HypothesisAgent 从 Phase 1 现象自动生成 3-7 个 |
|
||||
| **Phase 3** | 假设驱动调查(默认 5 轮迭代)。每轮:一次性为所有 active 假设产出 leads → 按 agent 类型并发派发(信号量 = 3)→ 一次性判定新现象与各假设的关系。所有假设收敛即提前退出。末尾:失败 lead 重试一次 + Gap Analysis |
|
||||
| **Phase 4** | TimelineAgent 用 `build_filesystem_timeline` 生成 MAC 时间线,与 Phenomenon 时间戳关联 |
|
||||
| **Phase 5** | ReportAgent 综合假设、证据、实体,生成 Markdown 报告 |
|
||||
|
||||
## 取证工具链
|
||||
### Gap Analysis(Phase 3 末)
|
||||
|
||||
### Sleuth Kit(磁盘取证)
|
||||
`config.yaml:investigation_areas` 列出必须覆盖的调查领域(系统信息、用户账户、网络配置、邮件配置、IRC 日志、PCAP、删除文件、Prefetch 等)。Orchestrator 两层判定覆盖情况:
|
||||
|
||||
通过异步子进程调用 TSK 命令行工具:
|
||||
1. **关键词匹配**(`_AREA_KEYWORDS`)— 扫现有 Phenomenon 标题/描述
|
||||
2. **工具命中**(`_AREA_TOOLS`)— 检查是否调用过该领域的关键工具(如 `enumerate_users`、`parse_pcap_strings`)
|
||||
|
||||
未覆盖的领域自动派发 lead,最多 3 轮补漏。
|
||||
|
||||
## Agent 体系
|
||||
|
||||
`AgentFactory` 维护 7 个角色模板(`ROLE_TEMPLATES`),每个模板指定默认工具集。`HypothesisAgent` 和 `ReportAgent` 是 `BaseAgent` 的子类(额外注册专用工具),其余 5 个 Agent 直接由 `BaseAgent` + 工具列表生成。
|
||||
|
||||
### Agent 工作流
|
||||
|
||||
`BaseAgent.run` 在 system prompt 中强制四阶段:
|
||||
|
||||
```
|
||||
A. INVESTIGATE 先查图状态 / Asset Library,再调取证工具
|
||||
B. RECORD 每条发现写 add_phenomenon
|
||||
C. LINK 按需 link_to_entity,但禁止凭记忆引用 ph-id,必须先 list_phenomena
|
||||
D. ANSWER 以上完成后再给最终答复
|
||||
```
|
||||
|
||||
prompt 内置**反幻觉规则**:只允许记录工具输出中逐字出现的内容;时间戳/路径/inode 必须来自工具返回;输出被截断须标 `[truncated]`。
|
||||
|
||||
### 动态 Agent 组合
|
||||
|
||||
`AgentFactory.create_specialized_agent()` 应对能力缺口:将工具目录与假设描述喂给 LLM,由其挑 3-8 个工具并写角色描述,工厂据此实例化新 Agent 并缓存。
|
||||
|
||||
## 工具系统
|
||||
|
||||
`tool_registry.py` 启动时调用 `register_all_tools(image_path, partition_offset, graph)`,将所有工具一次性注册到全局 `TOOL_CATALOG`。
|
||||
|
||||
### 工具结果缓存
|
||||
|
||||
`CACHEABLE_TOOLS` 集合标记纯读取/确定性工具(partition_info、list_directory、parse_registry_key …)。镜像只读,同 args 调用产出固定,命中缓存直接复用,错误结果不入缓存。
|
||||
|
||||
### Asset Library
|
||||
|
||||
`EvidenceGraph.asset_library` 按 inode 索引所有已提取文件,避免重复 extract。Agent 通过 `list_assets` / `find_extracted_file` 工具查询。新文件按文件名自动归类到 `registry_hive` / `chat_log` / `prefetch` / `network_capture` / `recycle_bin` 等十类之一。
|
||||
|
||||
### 取证工具链
|
||||
|
||||
**Sleuth Kit(磁盘取证)** — 异步子进程调用 TSK:
|
||||
|
||||
| 工具 | 用途 |
|
||||
|------|------|
|
||||
@@ -49,47 +120,43 @@ Agent 之间不直接通信,通过 **Blackboard(黑板)** 共享发现(E
|
||||
| `srch_strings` | 磁盘字符串搜索 |
|
||||
| `fls -m` | MAC 时间线生成 |
|
||||
|
||||
### regipy(注册表解析)
|
||||
**regipy(注册表解析)** — 直接读 SYSTEM / SOFTWARE / SAM / NTUSER.DAT 二进制,提取系统信息、用户账户、网络配置、已安装软件、邮件账户、关机时间等。
|
||||
|
||||
直接解析 Windows 注册表 hive 二进制文件(SYSTEM、SOFTWARE、SAM、NTUSER.DAT),提取系统信息、用户账户、网络配置、已安装软件、邮件账户、关机时间等。
|
||||
**文件解析器** — Prefetch 二进制(`.pf`)、PCAP 字符串提取(HTTP 请求 / Host / Cookie / UA)、通用文本与二进制读取、正则搜索、Hex dump。
|
||||
|
||||
### 文件解析器
|
||||
## 断连恢复与运行归档
|
||||
|
||||
- **Prefetch** — 二进制解析 Windows XP .pf 文件(运行次数、最后执行时间)
|
||||
- **PCAP** — 从抓包文件提取 HTTP 请求、Host、Cookie、User-Agent
|
||||
- **通用文本/二进制** — 按偏移读取、正则搜索、Hex dump
|
||||
三层防护:
|
||||
|
||||
## 断连恢复与数据归档
|
||||
1. **EvidenceGraph 自动持久化** — 每次 `add_phenomenon` / `add_hypothesis` / `add_edge` / `add_lead` 等写操作均自动落盘(原子写 `.tmp` 后 rename)
|
||||
2. **Agent 级容错** — 单 Agent 失败 → 该 lead 标 `failed`,连续 3 次失败触发 `AnalysisAborted` 优雅退出;Phase 3 末尾对失败 lead 重试一次(`retry=True` 防无限循环)
|
||||
3. **续跑** — `main.py` 启动时扫 `runs/*/graph_state.json`,发现存在但缺 `run_metadata.json` 的目录即提示恢复,并按 graph 当前状态决定从哪一阶段续起
|
||||
|
||||
系统设计了三层防护,应对长时间运行中的网络中断:
|
||||
|
||||
1. **Blackboard 自动持久化** — 每次 add_evidence / add_lead 自动写盘(原子写入)
|
||||
2. **Agent 级容错** — 单个 Agent 失败标记 Lead 为 failed,不影响其他 Agent,自动重试一次
|
||||
3. **优雅退出** — 连续 3 次 Agent 失败后保存现有成果并干净退出
|
||||
|
||||
每次运行自动创建带时间戳的归档目录:
|
||||
### 运行归档目录
|
||||
|
||||
```
|
||||
runs/
|
||||
2026-04-02T14-30-00/
|
||||
config.yaml 配置快照
|
||||
blackboard_state.json 实时状态(用于恢复)
|
||||
evidence.json 结构化证据导出
|
||||
leads.json 线索及最终状态
|
||||
report.md 取证报告
|
||||
run_metadata.json 运行元数据(时长、统计、错误)
|
||||
masforensics.log 运行日志
|
||||
config.yaml 配置快照
|
||||
graph_state.json 实时图状态(续跑用)
|
||||
phenomena.json 现象导出
|
||||
hypotheses.json 假设 + 置信度日志
|
||||
entities.json 实体
|
||||
edges.json 边
|
||||
leads.json 线索及最终状态
|
||||
extracted/ 从镜像提取的文件
|
||||
<image>_forensic_report.md 取证报告
|
||||
run_metadata.json 运行元数据(时长、统计、错误)
|
||||
masforensics.log 运行日志
|
||||
```
|
||||
|
||||
中断后再次运行 `python main.py`,系统自动检测未完成的运行并提示恢复。
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 环境要求
|
||||
|
||||
- Python >= 3.14
|
||||
- The Sleuth Kit(系统安装,提供 `mmls`、`fls`、`icat` 等命令)
|
||||
- 磁盘镜像文件置于 `image/` 目录
|
||||
- 磁盘镜像文件
|
||||
|
||||
### 安装
|
||||
|
||||
@@ -99,50 +166,76 @@ uv sync
|
||||
|
||||
### 配置
|
||||
|
||||
编辑 `config.yaml`,填入 LLM API 地址和密钥:
|
||||
编辑 `config.yaml`:
|
||||
|
||||
```yaml
|
||||
agent:
|
||||
base_url: "https://your-api-proxy.com"
|
||||
api_key: "sk-your-key"
|
||||
model: "claude-sonnet-4-6"
|
||||
max_tokens: 4096
|
||||
max_tokens: 16384
|
||||
|
||||
max_investigation_rounds: 5 # Phase 3 最大迭代轮数
|
||||
|
||||
# hypotheses: # 可选:手动指定初始假设
|
||||
# - title: "嫌疑人主动实施网络嗅探"
|
||||
# description: "..."
|
||||
|
||||
investigation_areas: # Gap Analysis 必须覆盖的领域
|
||||
- area: system_info
|
||||
agent: registry
|
||||
task: "..."
|
||||
# ...
|
||||
```
|
||||
|
||||
`investigation_areas` 部分定义了必须覆盖的调查领域,可按需增减。
|
||||
未配置 `hypotheses` 时由 HypothesisAgent 自动生成。
|
||||
|
||||
### 运行
|
||||
|
||||
```bash
|
||||
python main.py
|
||||
python main.py # 交互式选镜像与分区
|
||||
python main.py /path/to/image/dir # 指定镜像目录
|
||||
```
|
||||
|
||||
报告和所有结构化数据将保存在 `runs/<timestamp>/` 目录下。
|
||||
中断后再次运行会自动检测未完成的 run 并提示是否续跑。
|
||||
|
||||
### 仅重生成报告
|
||||
|
||||
跑完一次后若只想换提示词或修复报告:
|
||||
|
||||
```bash
|
||||
python regenerate_report.py runs/<timestamp>
|
||||
```
|
||||
|
||||
跳过 Phase 1-4,直接从已有 `graph_state.json` 重跑 ReportAgent。
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
MASForensics/
|
||||
├── main.py 入口
|
||||
├── orchestrator.py 流水线调度
|
||||
├── blackboard.py 共享知识库
|
||||
├── llm_client.py LLM API 客户端
|
||||
├── base_agent.py Agent 基类
|
||||
├── config.yaml 配置文件
|
||||
├── main.py 入口、镜像选择、断连恢复
|
||||
├── orchestrator.py 五阶段流水线调度
|
||||
├── evidence_graph.py 证据知识图 + 边权重表 + 持久化
|
||||
├── base_agent.py Agent 基类 + 内建 graph 工具
|
||||
├── agent_factory.py 角色模板 + 动态 Agent 组合
|
||||
├── tool_registry.py 工具目录 + 结果缓存 + 自动归类
|
||||
├── llm_client.py LLM API 客户端
|
||||
├── log_config.py 彩色终端日志 + 文件日志
|
||||
├── regenerate_report.py 从已有 graph_state 重生成报告
|
||||
├── config.yaml 配置 + 调查领域 + 可选假设
|
||||
├── agents/
|
||||
│ ├── filesystem.py 文件系统 Agent
|
||||
│ ├── registry.py 注册表 Agent
|
||||
│ ├── communication.py 通信 Agent
|
||||
│ ├── network.py 网络 Agent
|
||||
│ ├── timeline.py 时间线 Agent
|
||||
│ └── report.py 报告 Agent
|
||||
│ ├── hypothesis.py HypothesisAgent(add_hypothesis、link)
|
||||
│ ├── report.py ReportAgent(综合报告,自带读取工具)
|
||||
│ ├── timeline.py TimelineAgent(保留以备扩展)
|
||||
│ └── ... filesystem/registry/communication/network(同上)
|
||||
├── tools/
|
||||
│ ├── sleuthkit.py Sleuth Kit 封装
|
||||
│ ├── registry.py 注册表解析(regipy)
|
||||
│ └── parsers.py 文件格式解析器
|
||||
├── image/ 磁盘镜像
|
||||
├── extracted/ 提取的文件(运行时生成)
|
||||
└── runs/ 运行归档
|
||||
│ ├── sleuthkit.py TSK 异步封装
|
||||
│ ├── registry.py regipy 解析
|
||||
│ └── parsers.py Prefetch / PCAP / 通用文件解析
|
||||
├── image/ 磁盘镜像(用户放)
|
||||
├── runs/ 运行归档
|
||||
└── tests/
|
||||
└── test_optimizations.py
|
||||
```
|
||||
|
||||
## 依赖
|
||||
@@ -152,14 +245,16 @@ MASForensics/
|
||||
| `httpx[socks]` | 异步 HTTP 客户端(支持 SOCKS 代理) |
|
||||
| `pyyaml` | 配置文件解析 |
|
||||
| `regipy` | Windows 注册表 hive 解析 |
|
||||
| `pytest` / `pytest-asyncio` | 测试 |
|
||||
|
||||
## 当前案例
|
||||
## 默认案例
|
||||
|
||||
默认配置分析 **CFReDS Hacking Case**(NIST 标准取证教学镜像):
|
||||
**CFReDS Hacking Case**(NIST 标准取证教学镜像):
|
||||
|
||||
- 镜像:SCHARDT.001(~4.6GB,IBM 硬盘,8 个分段)
|
||||
- 镜像:SCHARDT.001(~4.6 GB,IBM 硬盘,8 个分段)
|
||||
- 系统:Windows XP
|
||||
- 场景:涉嫌黑客入侵的计算机取证分析
|
||||
- 完整镜像 MD5:`AEE4FCD9301C03B3B054623CA261959A`(`config.yaml` 含各分段 MD5 用于校验)
|
||||
|
||||
## 测试
|
||||
|
||||
|
||||
179
agent_factory.py
179
agent_factory.py
@@ -1,150 +1,50 @@
|
||||
"""Agent Factory — composes agents from tool registry and role templates.
|
||||
"""Agent Factory — instantiates agents from registered classes.
|
||||
|
||||
Provides both pre-defined agent templates (filesystem, registry, etc.)
|
||||
and LLM-driven dynamic agent composition for capability gaps.
|
||||
Each agent type has a dedicated subclass under agents/ that owns its name,
|
||||
role description, and tool list (single source of truth). The factory just
|
||||
maps agent_type → class. Also supports LLM-driven dynamic composition for
|
||||
capability gaps via create_specialized_agent().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from base_agent import BaseAgent
|
||||
from evidence_graph import EvidenceGraph
|
||||
from llm_client import LLMClient
|
||||
from tool_registry import TOOL_CATALOG, ToolDefinition
|
||||
from tool_registry import TOOL_CATALOG
|
||||
|
||||
# Agent classes with custom tools — keyed by template name
|
||||
_AGENT_CLASSES: dict[str, type] = {}
|
||||
# Agent classes keyed by name. Populated lazily to avoid circular imports.
|
||||
_AGENT_CLASSES: dict[str, type[BaseAgent]] = {}
|
||||
|
||||
|
||||
def _load_agent_classes() -> None:
|
||||
"""Lazy-import custom agent classes to avoid circular imports."""
|
||||
"""Lazy-import agent classes to avoid circular imports."""
|
||||
if _AGENT_CLASSES:
|
||||
return
|
||||
from agents.communication import CommunicationAgent
|
||||
from agents.filesystem import FileSystemAgent
|
||||
from agents.hypothesis import HypothesisAgent
|
||||
from agents.network import NetworkAgent
|
||||
from agents.registry import RegistryAgent
|
||||
from agents.report import ReportAgent
|
||||
from agents.timeline import TimelineAgent
|
||||
_AGENT_CLASSES["filesystem"] = FileSystemAgent
|
||||
_AGENT_CLASSES["registry"] = RegistryAgent
|
||||
_AGENT_CLASSES["communication"] = CommunicationAgent
|
||||
_AGENT_CLASSES["network"] = NetworkAgent
|
||||
_AGENT_CLASSES["timeline"] = TimelineAgent
|
||||
_AGENT_CLASSES["hypothesis"] = HypothesisAgent
|
||||
_AGENT_CLASSES["report"] = ReportAgent
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoleTemplate:
|
||||
"""Pre-defined agent archetype."""
|
||||
|
||||
name: str
|
||||
role: str
|
||||
default_tools: list[str] # tool names from TOOL_CATALOG
|
||||
tags: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# Pre-defined templates matching the original 6 agents + hypothesis agent.
|
||||
ROLE_TEMPLATES: dict[str, RoleTemplate] = {
|
||||
"filesystem": RoleTemplate(
|
||||
name="filesystem",
|
||||
role=(
|
||||
"File system forensic analyst. You examine disk image partition layouts, "
|
||||
"directory structures, file metadata, and recover deleted files. "
|
||||
"You identify suspicious files, installed programs, and user data locations. "
|
||||
"You also handle Recycle Bin forensics and Prefetch execution evidence."
|
||||
),
|
||||
default_tools=[
|
||||
"partition_info", "filesystem_info", "list_directory",
|
||||
"extract_file", "find_file", "search_strings",
|
||||
"parse_prefetch", "count_deleted_files",
|
||||
"read_text_file", "search_text_file", "read_binary_preview",
|
||||
],
|
||||
tags=["filesystem", "disk", "files", "deleted", "prefetch"],
|
||||
),
|
||||
"registry": RoleTemplate(
|
||||
name="registry",
|
||||
role=(
|
||||
"Windows registry forensic analyst. You parse registry hive files "
|
||||
"(SYSTEM, SOFTWARE, SAM, NTUSER.DAT) to extract system configuration, "
|
||||
"user accounts, installed software, network settings, email accounts, "
|
||||
"and other Windows artifacts."
|
||||
),
|
||||
default_tools=[
|
||||
"extract_file", "list_directory",
|
||||
"parse_registry_key", "list_installed_software",
|
||||
"get_user_activity", "search_registry",
|
||||
"get_system_info", "get_timezone_info", "get_computer_name",
|
||||
"get_shutdown_time", "enumerate_users",
|
||||
"get_network_interfaces", "get_email_config",
|
||||
],
|
||||
tags=["registry", "windows", "system", "user", "software"],
|
||||
),
|
||||
"communication": RoleTemplate(
|
||||
name="communication",
|
||||
role=(
|
||||
"Communication forensic analyst. You analyze email files (.dbx, .pst), "
|
||||
"IRC/mIRC chat logs, newsgroup data, and other messaging artifacts "
|
||||
"to identify communication patterns and contacts."
|
||||
),
|
||||
default_tools=[
|
||||
"list_directory", "extract_file",
|
||||
"read_text_file", "read_binary_preview",
|
||||
"list_extracted_dir", "search_strings",
|
||||
"search_text_file", "read_text_file_section",
|
||||
],
|
||||
tags=["email", "chat", "irc", "messaging", "communication"],
|
||||
),
|
||||
"network": RoleTemplate(
|
||||
name="network",
|
||||
role=(
|
||||
"Network forensic analyst. You analyze browser history, cookies, "
|
||||
"network captures (PCAP), wireless artifacts, and other network-related "
|
||||
"evidence to reconstruct online activities."
|
||||
),
|
||||
default_tools=[
|
||||
"list_directory", "extract_file",
|
||||
"read_text_file", "read_binary_preview",
|
||||
"list_extracted_dir", "search_strings",
|
||||
"search_text_file", "read_text_file_section",
|
||||
"parse_pcap_strings",
|
||||
],
|
||||
tags=["network", "browser", "pcap", "http", "internet"],
|
||||
),
|
||||
"timeline": RoleTemplate(
|
||||
name="timeline",
|
||||
role=(
|
||||
"Timeline correlation analyst. You build chronological timelines "
|
||||
"by combining filesystem MAC times with evidence from other agents. "
|
||||
"You identify temporal patterns and correlate events across categories."
|
||||
),
|
||||
default_tools=[
|
||||
"build_filesystem_timeline",
|
||||
],
|
||||
tags=["timeline", "correlation", "temporal"],
|
||||
),
|
||||
"report": RoleTemplate(
|
||||
name="report",
|
||||
role=(
|
||||
"Forensic report writer. You synthesize all evidence and hypotheses "
|
||||
"into a comprehensive forensic analysis report with executive summary, "
|
||||
"detailed findings organized by hypothesis, timeline of events, and conclusions."
|
||||
),
|
||||
default_tools=[], # Report agent uses only graph query tools
|
||||
tags=["report", "summary", "writing"],
|
||||
),
|
||||
"hypothesis": RoleTemplate(
|
||||
name="hypothesis",
|
||||
role=(
|
||||
"Hypothesis analyst. You review all phenomena discovered so far "
|
||||
"and formulate investigative hypotheses about what happened on the system. "
|
||||
"For each hypothesis, identify which existing phenomena support or contradict it."
|
||||
),
|
||||
default_tools=[], # Uses only graph query + hypothesis tools
|
||||
tags=["hypothesis", "analysis", "reasoning"],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class AgentFactory:
|
||||
"""Creates agents from templates or dynamically via LLM composition."""
|
||||
"""Creates agents from registered classes or dynamically via LLM composition."""
|
||||
|
||||
def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None:
|
||||
self.llm = llm
|
||||
@@ -152,40 +52,20 @@ class AgentFactory:
|
||||
self._cache: dict[str, BaseAgent] = {}
|
||||
|
||||
def get_or_create_agent(self, agent_type: str) -> BaseAgent | None:
|
||||
"""Get a cached agent or create one from a template."""
|
||||
"""Get a cached agent or instantiate one from its registered class."""
|
||||
if agent_type in self._cache:
|
||||
return self._cache[agent_type]
|
||||
|
||||
template = ROLE_TEMPLATES.get(agent_type)
|
||||
if template is None:
|
||||
logger.warning("No template for agent type: %s", agent_type)
|
||||
return None
|
||||
|
||||
# Use custom agent class if one exists, otherwise BaseAgent
|
||||
_load_agent_classes()
|
||||
agent_cls = _AGENT_CLASSES.get(agent_type)
|
||||
if agent_cls is not None:
|
||||
agent = agent_cls(self.llm, self.graph)
|
||||
else:
|
||||
agent = self._instantiate_from_template(template)
|
||||
if agent_cls is None:
|
||||
logger.warning("No agent class for type: %s", agent_type)
|
||||
return None
|
||||
|
||||
agent = agent_cls(self.llm, self.graph)
|
||||
self._cache[agent_type] = agent
|
||||
return agent
|
||||
|
||||
def _instantiate_from_template(self, template: RoleTemplate) -> BaseAgent:
|
||||
"""Create a BaseAgent from a role template, registering tools from the catalog."""
|
||||
agent = BaseAgent(self.llm, self.graph)
|
||||
agent.name = template.name
|
||||
agent.role = template.role
|
||||
|
||||
for tool_name in template.default_tools:
|
||||
td = TOOL_CATALOG.get(tool_name)
|
||||
if td is None:
|
||||
logger.warning("Tool '%s' not in catalog (template: %s)", tool_name, template.name)
|
||||
continue
|
||||
agent.register_tool(td.name, td.description, td.input_schema, td.executor)
|
||||
|
||||
return agent
|
||||
|
||||
async def create_specialized_agent(
|
||||
self,
|
||||
hypothesis_title: str,
|
||||
@@ -220,18 +100,15 @@ class AgentFactory:
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
)
|
||||
|
||||
# Parse response — try to extract JSON
|
||||
try:
|
||||
config = json.loads(response)
|
||||
except json.JSONDecodeError:
|
||||
# Try to find JSON in the response
|
||||
import re
|
||||
match = re.search(r'\{.*\}', response, re.DOTALL)
|
||||
if match:
|
||||
config = json.loads(match.group())
|
||||
else:
|
||||
logger.error("Failed to parse agent composition response: %s", response[:300])
|
||||
# Fallback: create a generic agent with all tools
|
||||
return self._create_fallback_agent(capability_gap)
|
||||
|
||||
agent_name = config.get("agent_name", "specialized")
|
||||
@@ -239,13 +116,11 @@ class AgentFactory:
|
||||
strategy = config.get("strategy", "")
|
||||
tool_names = config.get("tools", [])
|
||||
|
||||
# Validate tool names against catalog
|
||||
valid_tools = [t for t in tool_names if t in TOOL_CATALOG]
|
||||
if not valid_tools:
|
||||
logger.warning("No valid tools selected by LLM, using fallback")
|
||||
return self._create_fallback_agent(capability_gap)
|
||||
|
||||
# Build agent
|
||||
agent = BaseAgent(self.llm, self.graph)
|
||||
agent.name = agent_name
|
||||
agent.role = f"{role_text}\n\nInvestigation Strategy:\n{strategy}"
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
"""Hypothesis Agent — analyzes phenomena and generates investigative hypotheses."""
|
||||
"""Hypothesis Agent — generates investigative hypotheses from phenomena.
|
||||
|
||||
Generates hypotheses only. Phenomenon→Hypothesis linking is handled centrally
|
||||
by Orchestrator._judge_new_phenomena, so all link logic lives in one place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from base_agent import BaseAgent
|
||||
from evidence_graph import EvidenceGraph, HYPOTHESIS_EDGE_WEIGHTS
|
||||
from evidence_graph import EvidenceGraph
|
||||
from llm_client import LLMClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,8 +20,7 @@ class HypothesisAgent(BaseAgent):
|
||||
role = (
|
||||
"Hypothesis analyst. You review all phenomena discovered so far "
|
||||
"and formulate investigative hypotheses about what happened on this system. "
|
||||
"Your ultimate goal: build the most complete picture of events that occurred. "
|
||||
"For each hypothesis, identify which existing phenomena support or contradict it."
|
||||
"Your ultimate goal: build the most complete picture of events that occurred."
|
||||
)
|
||||
|
||||
def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None:
|
||||
@@ -26,10 +28,6 @@ class HypothesisAgent(BaseAgent):
|
||||
self._register_hypothesis_tools()
|
||||
|
||||
def _register_hypothesis_tools(self) -> None:
|
||||
"""Register hypothesis-specific tools."""
|
||||
|
||||
valid_edge_types = list(HYPOTHESIS_EDGE_WEIGHTS.keys())
|
||||
|
||||
self.register_tool(
|
||||
name="add_hypothesis",
|
||||
description=(
|
||||
@@ -53,44 +51,6 @@ class HypothesisAgent(BaseAgent):
|
||||
executor=self._add_hypothesis,
|
||||
)
|
||||
|
||||
self.register_tool(
|
||||
name="link_phenomenon_to_hypothesis",
|
||||
description=(
|
||||
"Link an existing phenomenon to a hypothesis with a relationship type. "
|
||||
f"Valid relationship types: {', '.join(valid_edge_types)}. "
|
||||
"direct_evidence = the phenomenon IS the hypothesis. "
|
||||
"supports = consistent with the hypothesis. "
|
||||
"prerequisite_met = a necessary condition is satisfied. "
|
||||
"consequence_observed = an expected result of the hypothesis is found. "
|
||||
"contradicts = directly contradicts the hypothesis. "
|
||||
"weakens = makes the hypothesis less likely."
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"phenomenon_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the phenomenon (e.g. 'ph-a1b2c3d4').",
|
||||
},
|
||||
"hypothesis_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the hypothesis (e.g. 'hyp-e5f6g7h8').",
|
||||
},
|
||||
"edge_type": {
|
||||
"type": "string",
|
||||
"enum": valid_edge_types,
|
||||
"description": "The edge_type of the relationship.",
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "The reason this relationship holds (1-2 sentences).",
|
||||
},
|
||||
},
|
||||
"required": ["phenomenon_id", "hypothesis_id", "edge_type", "reason"],
|
||||
},
|
||||
executor=self._link_phenomenon_to_hypothesis,
|
||||
)
|
||||
|
||||
async def _add_hypothesis(self, title: str, description: str) -> str:
|
||||
hid = await self.graph.add_hypothesis(
|
||||
title=title,
|
||||
@@ -98,33 +58,3 @@ class HypothesisAgent(BaseAgent):
|
||||
created_by=self.name,
|
||||
)
|
||||
return f"Hypothesis created: {hid} — {title} (confidence: 0.50)"
|
||||
|
||||
async def _link_phenomenon_to_hypothesis(
|
||||
self,
|
||||
phenomenon_id: str,
|
||||
hypothesis_id: str,
|
||||
edge_type: str = "",
|
||||
reason: str = "",
|
||||
# Common LLM misnaming — accept as fallbacks
|
||||
relationship: str = "",
|
||||
note: str = "",
|
||||
) -> str:
|
||||
edge_type = edge_type or relationship
|
||||
reason = reason or note
|
||||
if not edge_type:
|
||||
return "Error: edge_type is required."
|
||||
try:
|
||||
new_conf = await self.graph.update_hypothesis_confidence(
|
||||
hyp_id=hypothesis_id,
|
||||
phenomenon_id=phenomenon_id,
|
||||
edge_type=edge_type,
|
||||
reason=reason,
|
||||
)
|
||||
weight = HYPOTHESIS_EDGE_WEIGHTS[edge_type]
|
||||
direction = "+" if weight > 0 else ""
|
||||
return (
|
||||
f"Linked: {phenomenon_id} —[{edge_type}]→ {hypothesis_id} "
|
||||
f"(weight: {direction}{weight}, new confidence: {new_conf:.3f})"
|
||||
)
|
||||
except ValueError as e:
|
||||
return f"Error linking: {e}"
|
||||
|
||||
@@ -37,6 +37,7 @@ class BaseAgent:
|
||||
self._tools: dict[str, dict] = {} # name -> schema
|
||||
self._executors: dict[str, Any] = {} # name -> async callable
|
||||
self._work_log: list[str] = []
|
||||
self._current_lead_id: str | None = None
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
@@ -107,11 +108,12 @@ class BaseAgent:
|
||||
f"- Do NOT fabricate execution timestamps — only report timestamps returned by tools"
|
||||
)
|
||||
|
||||
async def run(self, task: str) -> str:
|
||||
async def run(self, task: str, lead_id: str | None = None) -> str:
|
||||
"""Run this agent with a specific task."""
|
||||
_log(task, event="agent_start", agent=self.name)
|
||||
self.graph.agent_status[self.name] = "running"
|
||||
self.graph._current_agent = self.name
|
||||
self._current_lead_id = lead_id
|
||||
|
||||
self._register_graph_tools()
|
||||
|
||||
@@ -375,6 +377,7 @@ class BaseAgent:
|
||||
raw_data=raw_data,
|
||||
timestamp=timestamp,
|
||||
source_tool=source_tool,
|
||||
from_lead_id=self._current_lead_id,
|
||||
)
|
||||
if merged:
|
||||
return f"Phenomenon merged into existing: {pid} — {title} (corroboration boost)"
|
||||
|
||||
@@ -18,10 +18,12 @@ from pathlib import Path
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Predefined edge weights for Phenomenon → Hypothesis relationships.
|
||||
# Default edge weights for Phenomenon → Hypothesis relationships.
|
||||
# LLM only picks the edge type (categorical); the weight is looked up here.
|
||||
# Override per-graph via EvidenceGraph(edge_weights=...) or config.yaml's
|
||||
# `hypothesis_edge_weights` section.
|
||||
# ---------------------------------------------------------------------------
|
||||
HYPOTHESIS_EDGE_WEIGHTS: dict[str, float] = {
|
||||
_DEFAULT_EDGE_WEIGHTS: dict[str, float] = {
|
||||
"direct_evidence": +0.25,
|
||||
"supports": +0.15,
|
||||
"prerequisite_met": +0.10,
|
||||
@@ -94,6 +96,7 @@ class Phenomenon:
|
||||
confidence: float = 1.0
|
||||
source_tool: str = ""
|
||||
corroborating_agents: list[str] = field(default_factory=list)
|
||||
from_lead_id: str | None = None
|
||||
created_at: str = ""
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
@@ -239,8 +242,12 @@ class EvidenceGraph:
|
||||
self,
|
||||
case_info: dict | None = None,
|
||||
persist_path: Path | None = None,
|
||||
edge_weights: dict[str, float] | None = None,
|
||||
) -> None:
|
||||
self.case_info: dict = case_info or {}
|
||||
self.edge_weights: dict[str, float] = (
|
||||
dict(edge_weights) if edge_weights else dict(_DEFAULT_EDGE_WEIGHTS)
|
||||
)
|
||||
self.image_path: str = ""
|
||||
self.partition_offset: int = 0
|
||||
self.extracted_dir: str = "extracted"
|
||||
@@ -304,12 +311,17 @@ class EvidenceGraph:
|
||||
self._persist_path = old
|
||||
|
||||
@classmethod
|
||||
def load_state(cls, path: Path) -> EvidenceGraph:
|
||||
def load_state(
|
||||
cls,
|
||||
path: Path,
|
||||
edge_weights: dict[str, float] | None = None,
|
||||
) -> EvidenceGraph:
|
||||
"""Restore an EvidenceGraph from a saved JSON state file."""
|
||||
data = json.loads(path.read_text())
|
||||
graph = cls(
|
||||
case_info=data.get("case_info", {}),
|
||||
persist_path=path,
|
||||
edge_weights=edge_weights,
|
||||
)
|
||||
graph.image_path = data.get("image_path", "")
|
||||
graph.partition_offset = data.get("partition_offset", 0)
|
||||
@@ -403,6 +415,7 @@ class EvidenceGraph:
|
||||
raw_data: dict | None = None,
|
||||
timestamp: str | None = None,
|
||||
source_tool: str = "",
|
||||
from_lead_id: str | None = None,
|
||||
) -> tuple[str, bool]:
|
||||
"""Add a phenomenon. Returns (id, was_merged).
|
||||
|
||||
@@ -419,6 +432,8 @@ class EvidenceGraph:
|
||||
for k, v in raw_data.items():
|
||||
if k not in similar.raw_data:
|
||||
similar.raw_data[k] = v
|
||||
if from_lead_id and similar.from_lead_id is None:
|
||||
similar.from_lead_id = from_lead_id
|
||||
self._auto_save()
|
||||
return similar.id, True
|
||||
|
||||
@@ -437,6 +452,7 @@ class EvidenceGraph:
|
||||
timestamp=timestamp,
|
||||
confidence=confidence,
|
||||
source_tool=source_tool,
|
||||
from_lead_id=from_lead_id,
|
||||
created_at=datetime.now().isoformat(),
|
||||
)
|
||||
self.phenomena[pid] = ph
|
||||
@@ -532,14 +548,14 @@ class EvidenceGraph:
|
||||
) -> float:
|
||||
"""Update hypothesis confidence based on a phenomenon linkage.
|
||||
|
||||
The edge_type must be one of HYPOTHESIS_EDGE_WEIGHTS keys.
|
||||
Weight is looked up from the predefined table, NOT judged by LLM.
|
||||
The edge_type must be one of self.edge_weights keys.
|
||||
Weight is looked up from the configured table, NOT judged by LLM.
|
||||
Returns the new confidence value.
|
||||
"""
|
||||
if edge_type not in HYPOTHESIS_EDGE_WEIGHTS:
|
||||
if edge_type not in self.edge_weights:
|
||||
raise ValueError(
|
||||
f"Invalid hypothesis edge type: {edge_type}. "
|
||||
f"Must be one of: {list(HYPOTHESIS_EDGE_WEIGHTS.keys())}"
|
||||
f"Must be one of: {list(self.edge_weights.keys())}"
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
@@ -549,7 +565,7 @@ class EvidenceGraph:
|
||||
if hyp is None:
|
||||
raise ValueError(f"Hypothesis not found: {hyp_id}")
|
||||
|
||||
weight = HYPOTHESIS_EDGE_WEIGHTS[edge_type]
|
||||
weight = self.edge_weights[edge_type]
|
||||
old_conf = hyp.confidence
|
||||
|
||||
if weight > 0:
|
||||
|
||||
1
main.py
1
main.py
@@ -229,6 +229,7 @@ async def async_main() -> None:
|
||||
graph = EvidenceGraph(
|
||||
case_info=config.get("cfreds_hacking_case", {}),
|
||||
persist_path=run_dir / "graph_state.json",
|
||||
edge_weights=config.get("hypothesis_edge_weights"),
|
||||
)
|
||||
graph.image_path = image_path
|
||||
graph.partition_offset = partition_offset
|
||||
|
||||
@@ -11,7 +11,7 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from agent_factory import AgentFactory
|
||||
from evidence_graph import EvidenceGraph, HYPOTHESIS_EDGE_WEIGHTS
|
||||
from evidence_graph import EvidenceGraph
|
||||
from llm_client import LLMClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -149,7 +149,8 @@ class Orchestrator:
|
||||
await agent.run(
|
||||
f"Investigate this lead: {lead.description}\n"
|
||||
f"{hyp_line}"
|
||||
f"Focus area: {lead.target_agent}"
|
||||
f"Focus area: {lead.target_agent}",
|
||||
lead_id=lead.id,
|
||||
)
|
||||
await self.graph.mark_lead_completed(lead.id)
|
||||
self._failure_count = 0
|
||||
@@ -209,11 +210,9 @@ class Orchestrator:
|
||||
"1. Specific and testable\n"
|
||||
"2. About a distinct aspect of activity (e.g., hacking tools, communication, "
|
||||
"network attacks, data theft)\n\n"
|
||||
"For each hypothesis:\n"
|
||||
"- Call add_hypothesis to create it\n"
|
||||
"- Then call link_phenomenon_to_hypothesis to link relevant existing phenomena\n"
|
||||
"- Choose the relationship type carefully: direct_evidence, supports, "
|
||||
"prerequisite_met, consequence_observed, contradicts, or weakens\n\n"
|
||||
"Call add_hypothesis for each. The orchestrator will automatically link "
|
||||
"relevant existing phenomena to each hypothesis after you finish — you do "
|
||||
"not need to (and cannot) create those links yourself.\n\n"
|
||||
"The ultimate goal is to reconstruct a detailed timeline of what happened on this host."
|
||||
)
|
||||
|
||||
@@ -333,7 +332,7 @@ class Orchestrator:
|
||||
if not unlinked:
|
||||
return
|
||||
|
||||
valid_types = list(HYPOTHESIS_EDGE_WEIGHTS.keys())
|
||||
valid_types = list(self.graph.edge_weights.keys())
|
||||
|
||||
hyp_section = "\n".join(
|
||||
f" [{h.id}] {h.title}: {h.description}" for h in active
|
||||
@@ -370,7 +369,7 @@ class Orchestrator:
|
||||
if (
|
||||
hyp_id in self.graph.hypotheses
|
||||
and ph_id in self.graph.phenomena
|
||||
and edge_type in HYPOTHESIS_EDGE_WEIGHTS
|
||||
and edge_type in self.graph.edge_weights
|
||||
):
|
||||
await self.graph.update_hypothesis_confidence(
|
||||
hyp_id=hyp_id,
|
||||
@@ -413,7 +412,7 @@ class Orchestrator:
|
||||
ph_id = j.get("phenomenon_id", "")
|
||||
edge_type = j.get("edge_type", "")
|
||||
reason = j.get("reason", "")
|
||||
if ph_id in self.graph.phenomena and edge_type in HYPOTHESIS_EDGE_WEIGHTS:
|
||||
if ph_id in self.graph.phenomena and edge_type in self.graph.edge_weights:
|
||||
await self.graph.update_hypothesis_confidence(
|
||||
hyp_id=hyp.id,
|
||||
phenomenon_id=ph_id,
|
||||
@@ -505,6 +504,7 @@ class Orchestrator:
|
||||
break
|
||||
_log(f"Gap fill round {round_num}: {len(pending)} leads", event="dispatch")
|
||||
await self._dispatch_leads_parallel(pending)
|
||||
await self._judge_new_phenomena()
|
||||
|
||||
# ---- Run archiving -------------------------------------------------------
|
||||
|
||||
@@ -604,11 +604,13 @@ class Orchestrator:
|
||||
manual_hypotheses = self.config.get("hypotheses", [])
|
||||
if manual_hypotheses:
|
||||
await self._generate_hypotheses_manual(manual_hypotheses)
|
||||
if self.graph.phenomena:
|
||||
await self._judge_new_phenomena()
|
||||
else:
|
||||
await self._generate_hypotheses_auto()
|
||||
|
||||
# Unified judge step — link Phase 1 phenomena to newly-created hypotheses
|
||||
if self.graph.phenomena and self.graph.hypotheses:
|
||||
await self._judge_new_phenomena()
|
||||
|
||||
for h in self.graph.hypotheses.values():
|
||||
_log(f" {h.summary()}", event="hypothesis")
|
||||
_log(
|
||||
|
||||
@@ -13,8 +13,16 @@ from tool_registry import register_all_tools
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Find the run to regenerate from
|
||||
run_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("runs/2026-04-02T15-11-25")
|
||||
# Find the run: CLI arg, or latest run with a graph_state.json
|
||||
if len(sys.argv) > 1:
|
||||
run_dir = Path(sys.argv[1])
|
||||
else:
|
||||
states = sorted(Path("runs").glob("*/graph_state.json"), reverse=True)
|
||||
if not states:
|
||||
print("No runs found in runs/")
|
||||
return
|
||||
run_dir = states[0].parent
|
||||
print(f"Using latest run: {run_dir.name}")
|
||||
state_path = run_dir / "graph_state.json"
|
||||
|
||||
if not state_path.exists():
|
||||
@@ -24,8 +32,11 @@ async def main() -> None:
|
||||
config = yaml.safe_load(open("config.yaml"))
|
||||
agent_cfg = config["agent"]
|
||||
|
||||
# Load graph
|
||||
graph = EvidenceGraph.load_state(state_path)
|
||||
# Load graph (edge_weights from config — applied to the loaded graph)
|
||||
graph = EvidenceGraph.load_state(
|
||||
state_path,
|
||||
edge_weights=config.get("hypothesis_edge_weights"),
|
||||
)
|
||||
print(f"Loaded: {graph.stats_summary()}")
|
||||
|
||||
# LLM client with larger max_tokens for report
|
||||
|
||||
957
tests/test_optimizations.py
Normal file
957
tests/test_optimizations.py
Normal file
@@ -0,0 +1,957 @@
|
||||
"""Tests for dedup/quality scoring, context management, parallel dispatch, and evidence graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from evidence_graph import (
|
||||
EvidenceGraph, Phenomenon, Lead,
|
||||
_compute_quality_score, _jaccard_similarity,
|
||||
)
|
||||
from llm_client import (
|
||||
_truncate_tool_result, _partition_tool_calls, _ToolBatch, READ_ONLY_TOOLS,
|
||||
_apply_progressive_decay, _FOLD_THRESHOLD, _FOLD_KEEP_RECENT,
|
||||
TOOL_RESULT_TAG, TOOL_RESULT_END,
|
||||
)
|
||||
from tool_registry import (
|
||||
_tool_result_cache, _cache_key, _make_cached, CACHEABLE_TOOLS,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quality scoring & similarity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQualityScore:
|
||||
def test_full_score(self):
|
||||
score = _compute_quality_score(
|
||||
source_tool="list_dir",
|
||||
timestamp="2024-01-01",
|
||||
raw_data={"key": "val"},
|
||||
description="A" * 50,
|
||||
related_ids=["ph-0"],
|
||||
)
|
||||
assert score == pytest.approx(1.0)
|
||||
|
||||
def test_minimal_score(self):
|
||||
score = _compute_quality_score(
|
||||
source_tool="",
|
||||
timestamp=None,
|
||||
raw_data={},
|
||||
description="short",
|
||||
related_ids=[],
|
||||
)
|
||||
assert score == pytest.approx(0.0)
|
||||
|
||||
def test_partial_score(self):
|
||||
score = _compute_quality_score(
|
||||
source_tool="parse_registry_key",
|
||||
timestamp=None,
|
||||
raw_data={},
|
||||
description="A" * 50,
|
||||
related_ids=[],
|
||||
)
|
||||
assert score == pytest.approx(0.40)
|
||||
|
||||
|
||||
class TestJaccardSimilarity:
|
||||
def test_identical(self):
|
||||
assert _jaccard_similarity("hello world", "hello world") == pytest.approx(1.0)
|
||||
|
||||
def test_no_overlap(self):
|
||||
assert _jaccard_similarity("hello world", "foo bar") == pytest.approx(0.0)
|
||||
|
||||
def test_partial_overlap(self):
|
||||
sim = _jaccard_similarity("User account Mr. Evil", "User account Mr. Evil enumeration")
|
||||
assert 0.5 < sim < 1.0
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _jaccard_similarity("", "hello") == pytest.approx(0.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Evidence graph: phenomenon dedup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPhenomenonDedup:
|
||||
@pytest.fixture
|
||||
def graph(self):
|
||||
return EvidenceGraph()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_phenomenon_gets_quality_score(self, graph):
|
||||
pid, merged = await graph.add_phenomenon(
|
||||
source_agent="fs", category="filesystem",
|
||||
title="Directory listing",
|
||||
description="Found important files in the root directory of the disk image",
|
||||
source_tool="list_directory",
|
||||
)
|
||||
assert not merged
|
||||
ph = graph.phenomena[pid]
|
||||
assert ph.confidence == pytest.approx(0.40)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_identical_phenomenon_merges(self, graph):
|
||||
pid1, m1 = await graph.add_phenomenon(
|
||||
source_agent="fs", category="registry",
|
||||
title="User account Mr. Evil",
|
||||
description="Found user account Mr. Evil with RID 1003 in the SAM hive",
|
||||
source_tool="enumerate_users",
|
||||
)
|
||||
pid2, m2 = await graph.add_phenomenon(
|
||||
source_agent="reg", category="registry",
|
||||
title="User account Mr. Evil",
|
||||
description="User account Mr. Evil discovered with RID 1003 in SAM registry hive",
|
||||
source_tool="enumerate_users",
|
||||
)
|
||||
assert not m1
|
||||
assert m2
|
||||
assert pid1 == pid2
|
||||
ph = graph.phenomena[pid1]
|
||||
assert "reg" in ph.corroborating_agents
|
||||
assert ph.confidence > 0.40
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_category_no_merge(self, graph):
|
||||
pid1, _ = await graph.add_phenomenon(
|
||||
source_agent="fs", category="filesystem",
|
||||
title="User account Mr. Evil",
|
||||
description="Found user account Mr. Evil with RID 1003",
|
||||
source_tool="list_directory",
|
||||
)
|
||||
pid2, m2 = await graph.add_phenomenon(
|
||||
source_agent="reg", category="registry",
|
||||
title="User account Mr. Evil",
|
||||
description="Found user account Mr. Evil with RID 1003",
|
||||
source_tool="enumerate_users",
|
||||
)
|
||||
assert not m2
|
||||
assert pid1 != pid2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dissimilar_titles_no_merge(self, graph):
|
||||
await graph.add_phenomenon(
|
||||
source_agent="fs", category="registry",
|
||||
title="Installed software list",
|
||||
description="Found 25 installed programs",
|
||||
source_tool="list_installed_software",
|
||||
)
|
||||
pid2, m2 = await graph.add_phenomenon(
|
||||
source_agent="reg", category="registry",
|
||||
title="Last shutdown time",
|
||||
description="System last shut down at 2004-08-27",
|
||||
source_tool="get_shutdown_time",
|
||||
)
|
||||
assert not m2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_adds_raw_data(self, graph):
|
||||
pid1, _ = await graph.add_phenomenon(
|
||||
source_agent="fs", category="registry",
|
||||
title="Network config",
|
||||
description="Network interface configuration with IP address and MAC address details",
|
||||
raw_data={"ip": "192.168.1.1"},
|
||||
source_tool="get_network_interfaces",
|
||||
)
|
||||
pid2, m2 = await graph.add_phenomenon(
|
||||
source_agent="reg", category="registry",
|
||||
title="Network config",
|
||||
description="Network interface configuration including DHCP and MAC address info",
|
||||
raw_data={"mac": "00:11:22:33:44:55"},
|
||||
source_tool="get_network_interfaces",
|
||||
)
|
||||
assert m2
|
||||
ph = graph.phenomena[pid1]
|
||||
assert "ip" in ph.raw_data
|
||||
assert "mac" in ph.raw_data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hypothesis confidence updates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHypothesisConfidence:
|
||||
@pytest.fixture
|
||||
def graph(self):
|
||||
return EvidenceGraph()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supports_increases_confidence(self, graph):
|
||||
pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t")
|
||||
hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
|
||||
old = graph.hypotheses[hid].confidence
|
||||
new = await graph.update_hypothesis_confidence(hid, pid, "supports", "reason")
|
||||
assert new > old
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_contradicts_decreases_confidence(self, graph):
|
||||
pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t")
|
||||
hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
|
||||
old = graph.hypotheses[hid].confidence
|
||||
new = await graph.update_hypothesis_confidence(hid, pid, "contradicts", "reason")
|
||||
assert new < old
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_evidence_has_largest_effect(self, graph):
|
||||
pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t")
|
||||
hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
|
||||
conf = await graph.update_hypothesis_confidence(hid, pid, "direct_evidence", "reason")
|
||||
# direct_evidence weight is +0.25 * (1-0.5) = +0.125
|
||||
assert conf == pytest.approx(0.625)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confidence_log_tracked(self, graph):
|
||||
pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t")
|
||||
hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
|
||||
await graph.update_hypothesis_confidence(hid, pid, "supports", "because reasons")
|
||||
log = graph.hypotheses[hid].confidence_log
|
||||
assert len(log) == 1
|
||||
assert log[0]["edge_type"] == "supports"
|
||||
assert log[0]["reason"] == "because reasons"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supported_status_at_threshold(self, graph):
|
||||
hid = await graph.add_hypothesis("test", "desc", created_by="test")
|
||||
# Push confidence up with multiple direct_evidence links
|
||||
for i in range(5):
|
||||
pid, _ = await graph.add_phenomenon("fs", "filesystem", f"test {i}", f"desc {i}", source_tool="t")
|
||||
await graph.update_hypothesis_confidence(hid, pid, "direct_evidence", "proof")
|
||||
assert graph.hypotheses[hid].status == "supported"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context window management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTruncateToolResult:
|
||||
def test_short_text_unchanged(self):
|
||||
text = "short result"
|
||||
assert _truncate_tool_result(text) == text
|
||||
|
||||
def test_long_text_truncated(self):
|
||||
text = "x" * 5000
|
||||
result = _truncate_tool_result(text, max_chars=3000)
|
||||
assert len(result) < 3100
|
||||
assert "truncated" in result
|
||||
assert "5000" in result
|
||||
|
||||
def test_exact_boundary(self):
|
||||
text = "x" * 3000
|
||||
assert _truncate_tool_result(text, max_chars=3000) == text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parallel dispatch & async safety
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphAsync:
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_add_phenomenon(self):
|
||||
graph = EvidenceGraph()
|
||||
|
||||
async def add_batch(agent_name: str, count: int):
|
||||
for i in range(count):
|
||||
await graph.add_phenomenon(
|
||||
source_agent=agent_name,
|
||||
category=f"cat-{agent_name}",
|
||||
title=f"Phenomenon {agent_name}-{i}",
|
||||
description=f"Description from {agent_name} number {i} with enough detail to be meaningful",
|
||||
source_tool="test_tool",
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
add_batch("agent_a", 10),
|
||||
add_batch("agent_b", 10),
|
||||
add_batch("agent_c", 10),
|
||||
)
|
||||
assert len(graph.phenomena) == 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_add_lead(self):
|
||||
graph = EvidenceGraph()
|
||||
|
||||
async def add_leads(count: int):
|
||||
for i in range(count):
|
||||
await graph.add_lead(target_agent="network", description=f"Lead {i}")
|
||||
|
||||
await asyncio.gather(add_leads(10), add_leads(10))
|
||||
pending = await graph.get_pending_leads()
|
||||
assert len(pending) == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_lead_completed(self):
|
||||
graph = EvidenceGraph()
|
||||
lid = await graph.add_lead(target_agent="fs", description="test lead")
|
||||
pending = await graph.get_pending_leads()
|
||||
assert len(pending) == 1
|
||||
|
||||
await graph.mark_lead_completed(lid)
|
||||
pending = await graph.get_pending_leads()
|
||||
assert len(pending) == 0
|
||||
|
||||
|
||||
class TestParallelDispatch:
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_agent_types_run_concurrently(self):
|
||||
from unittest.mock import AsyncMock
|
||||
from orchestrator import Orchestrator
|
||||
from agent_factory import AgentFactory
|
||||
|
||||
graph = EvidenceGraph()
|
||||
llm = AsyncMock()
|
||||
factory = AgentFactory(llm, graph)
|
||||
|
||||
orch = Orchestrator(llm, graph, factory)
|
||||
|
||||
execution_log: list[tuple[str, float, float]] = []
|
||||
|
||||
class FakeAgent:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
async def run(self, task, lead_id=None):
|
||||
start = time.monotonic()
|
||||
await asyncio.sleep(0.05)
|
||||
end = time.monotonic()
|
||||
execution_log.append((self.name, start, end))
|
||||
|
||||
factory._cache = {
|
||||
"filesystem": FakeAgent("filesystem"),
|
||||
"registry": FakeAgent("registry"),
|
||||
"network": FakeAgent("network"),
|
||||
}
|
||||
|
||||
leads = [
|
||||
Lead(id="l1", target_agent="filesystem", description="fs task"),
|
||||
Lead(id="l2", target_agent="registry", description="reg task"),
|
||||
Lead(id="l3", target_agent="network", description="net task"),
|
||||
]
|
||||
graph.leads = leads
|
||||
|
||||
start_time = time.monotonic()
|
||||
await orch._dispatch_leads_parallel(leads)
|
||||
total_time = time.monotonic() - start_time
|
||||
|
||||
assert total_time < 0.12, f"Expected parallel execution but took {total_time:.3f}s"
|
||||
assert len(execution_log) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_agent_type_runs_serially(self):
|
||||
from unittest.mock import AsyncMock
|
||||
from orchestrator import Orchestrator
|
||||
from agent_factory import AgentFactory
|
||||
|
||||
graph = EvidenceGraph()
|
||||
llm = AsyncMock()
|
||||
factory = AgentFactory(llm, graph)
|
||||
orch = Orchestrator(llm, graph, factory)
|
||||
|
||||
execution_order: list[str] = []
|
||||
|
||||
class FakeAgent:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
async def run(self, task, lead_id=None):
|
||||
execution_order.append(task)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
factory._cache = {"filesystem": FakeAgent("filesystem")}
|
||||
|
||||
leads = [
|
||||
Lead(id="l1", target_agent="filesystem", description="task 1"),
|
||||
Lead(id="l2", target_agent="filesystem", description="task 2"),
|
||||
]
|
||||
graph.leads = leads
|
||||
|
||||
await orch._dispatch_leads_parallel(leads)
|
||||
|
||||
assert len(execution_order) == 2
|
||||
assert "task 1" in execution_order[0]
|
||||
assert "task 2" in execution_order[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Asset library
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAssetLibrary:
|
||||
@pytest.fixture
|
||||
def graph(self):
|
||||
return EvidenceGraph()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_asset(self, graph):
|
||||
aid, existed = await graph.register_asset(
|
||||
inode="334-128-4",
|
||||
original_path="WINDOWS/system32/config/SYSTEM",
|
||||
local_path="extracted/system",
|
||||
category="registry_hive",
|
||||
filename="system",
|
||||
size_bytes=262144,
|
||||
extracted_by="filesystem",
|
||||
)
|
||||
assert not existed
|
||||
assert aid.startswith("asset-")
|
||||
assert len(graph.asset_library) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedup_by_inode(self, graph):
|
||||
aid1, _ = await graph.register_asset(
|
||||
inode="334-128-4",
|
||||
original_path="WINDOWS/system32/config/SYSTEM",
|
||||
local_path="extracted/system",
|
||||
category="registry_hive",
|
||||
filename="system",
|
||||
size_bytes=262144,
|
||||
extracted_by="filesystem",
|
||||
)
|
||||
aid2, existed = await graph.register_asset(
|
||||
inode="334-128-4",
|
||||
original_path="WINDOWS/system32/config/SYSTEM",
|
||||
local_path="extracted/system",
|
||||
category="registry_hive",
|
||||
filename="system",
|
||||
size_bytes=262144,
|
||||
extracted_by="registry",
|
||||
)
|
||||
assert existed
|
||||
assert aid1 == aid2
|
||||
assert len(graph.asset_library) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_by_category(self, graph):
|
||||
await graph.register_asset(
|
||||
inode="334-128-4", original_path="config/SYSTEM",
|
||||
local_path="extracted/system", category="registry_hive",
|
||||
filename="system", size_bytes=100, extracted_by="fs",
|
||||
)
|
||||
await graph.register_asset(
|
||||
inode="10090-128-1", original_path="mIRC/logs/hackers.log",
|
||||
local_path="extracted/hackers.log", category="chat_log",
|
||||
filename="hackers.log", size_bytes=200, extracted_by="comm",
|
||||
)
|
||||
results = graph.query_assets(category="registry_hive")
|
||||
assert len(results) == 1
|
||||
assert results[0].filename == "system"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_by_inode(self, graph):
|
||||
await graph.register_asset(
|
||||
inode="334-128-4", original_path="config/SYSTEM",
|
||||
local_path="extracted/system", category="registry_hive",
|
||||
filename="system", size_bytes=100, extracted_by="fs",
|
||||
)
|
||||
asset = graph.lookup_asset_by_inode("334-128-4")
|
||||
assert asset is not None
|
||||
assert asset.local_path == "extracted/system"
|
||||
|
||||
assert graph.lookup_asset_by_inode("999-128-4") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence_round_trip(self, graph, tmp_path):
|
||||
await graph.register_asset(
|
||||
inode="334-128-4",
|
||||
original_path="WINDOWS/system32/config/SYSTEM",
|
||||
local_path="extracted/system",
|
||||
category="registry_hive",
|
||||
filename="system",
|
||||
size_bytes=262144,
|
||||
extracted_by="filesystem",
|
||||
)
|
||||
path = tmp_path / "test_state.json"
|
||||
graph.save_state(path)
|
||||
|
||||
loaded = EvidenceGraph.load_state(path)
|
||||
assert len(loaded.asset_library) == 1
|
||||
assert loaded.lookup_asset_by_inode("334-128-4") is not None
|
||||
assert loaded.lookup_asset_by_inode("334-128-4").category == "registry_hive"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets(self, graph):
|
||||
await graph.register_asset(
|
||||
inode="334-128-4", original_path="config/SYSTEM",
|
||||
local_path="extracted/system", category="registry_hive",
|
||||
filename="system", size_bytes=8192, extracted_by="fs",
|
||||
)
|
||||
summaries = graph.list_assets()
|
||||
assert len(summaries) == 1
|
||||
assert "system" in summaries[0]
|
||||
assert "registry_hive" in summaries[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_by_filename(self, graph):
|
||||
await graph.register_asset(
|
||||
inode="10080-128-3", original_path="mIRC/mirc.ini",
|
||||
local_path="extracted/mirc.ini", category="config_file",
|
||||
filename="mirc.ini", size_bytes=500, extracted_by="comm",
|
||||
)
|
||||
await graph.register_asset(
|
||||
inode="334-128-4", original_path="config/SYSTEM",
|
||||
local_path="extracted/system", category="registry_hive",
|
||||
filename="system", size_bytes=100, extracted_by="fs",
|
||||
)
|
||||
results = graph.query_assets(filename_pattern="mirc")
|
||||
assert len(results) == 1
|
||||
assert results[0].filename == "mirc.ini"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool call partitioning (P0: parallel tool execution)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPartitionToolCalls:
|
||||
def test_consecutive_read_only_grouped(self):
|
||||
calls = [
|
||||
{"name": "list_directory", "arguments": {"inode": "33"}},
|
||||
{"name": "list_directory", "arguments": {"inode": "46"}},
|
||||
{"name": "list_directory", "arguments": {"inode": "50"}},
|
||||
]
|
||||
batches = _partition_tool_calls(calls)
|
||||
assert len(batches) == 1
|
||||
assert batches[0].is_read_only
|
||||
assert len(batches[0].calls) == 3
|
||||
|
||||
def test_write_tools_isolated(self):
|
||||
calls = [
|
||||
{"name": "extract_file", "arguments": {"inode": "40"}},
|
||||
{"name": "extract_file", "arguments": {"inode": "41"}},
|
||||
]
|
||||
batches = _partition_tool_calls(calls)
|
||||
assert len(batches) == 2
|
||||
assert not batches[0].is_read_only
|
||||
assert not batches[1].is_read_only
|
||||
|
||||
def test_mixed_read_write_partitioned(self):
|
||||
calls = [
|
||||
{"name": "list_directory", "arguments": {}},
|
||||
{"name": "search_graph", "arguments": {"keyword": "test"}},
|
||||
{"name": "add_phenomenon", "arguments": {"title": "x"}},
|
||||
{"name": "list_phenomena", "arguments": {}},
|
||||
]
|
||||
batches = _partition_tool_calls(calls)
|
||||
assert len(batches) == 3
|
||||
assert batches[0].is_read_only and len(batches[0].calls) == 2
|
||||
assert not batches[1].is_read_only and len(batches[1].calls) == 1
|
||||
assert batches[2].is_read_only and len(batches[2].calls) == 1
|
||||
|
||||
def test_empty_list(self):
|
||||
assert _partition_tool_calls([]) == []
|
||||
|
||||
def test_single_read_only(self):
|
||||
calls = [{"name": "partition_info", "arguments": {}}]
|
||||
batches = _partition_tool_calls(calls)
|
||||
assert len(batches) == 1
|
||||
assert batches[0].is_read_only
|
||||
|
||||
def test_single_write(self):
|
||||
calls = [{"name": "add_lead", "arguments": {}}]
|
||||
batches = _partition_tool_calls(calls)
|
||||
assert len(batches) == 1
|
||||
assert not batches[0].is_read_only
|
||||
|
||||
def test_unknown_tool_treated_as_write(self):
|
||||
calls = [{"name": "unknown_tool", "arguments": {}}]
|
||||
batches = _partition_tool_calls(calls)
|
||||
assert len(batches) == 1
|
||||
assert not batches[0].is_read_only
|
||||
|
||||
def test_custom_read_only_set(self):
|
||||
calls = [
|
||||
{"name": "foo", "arguments": {}},
|
||||
{"name": "bar", "arguments": {}},
|
||||
]
|
||||
batches = _partition_tool_calls(calls, read_only={"foo", "bar"})
|
||||
assert len(batches) == 1
|
||||
assert batches[0].is_read_only
|
||||
|
||||
|
||||
class TestParallelToolExecution:
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_faster_than_serial(self):
|
||||
"""Read-only tools should execute concurrently."""
|
||||
from llm_client import LLMClient
|
||||
|
||||
async def slow_tool(**kwargs):
|
||||
await asyncio.sleep(0.05)
|
||||
return "ok"
|
||||
|
||||
tool_executor = {
|
||||
"list_directory": slow_tool,
|
||||
"search_graph": slow_tool,
|
||||
"list_phenomena": slow_tool,
|
||||
}
|
||||
|
||||
client = LLMClient.__new__(LLMClient) # skip __init__
|
||||
|
||||
calls = [
|
||||
{"name": "list_directory", "arguments": {}},
|
||||
{"name": "search_graph", "arguments": {"keyword": "x"}},
|
||||
{"name": "list_phenomena", "arguments": {}},
|
||||
]
|
||||
|
||||
start = time.monotonic()
|
||||
results = await client._execute_tool_batch_parallel(calls, tool_executor)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
assert len(results) == 3
|
||||
assert all("ok" in r for r in results)
|
||||
# 3 tasks × 50ms each should take ~50ms parallel, not ~150ms serial
|
||||
assert elapsed < 0.12, f"Expected parallel execution but took {elapsed:.3f}s"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_in_parallel_batch_does_not_crash(self):
|
||||
"""A failing tool in a parallel batch should produce an error result, not crash."""
|
||||
from llm_client import LLMClient
|
||||
|
||||
async def ok_tool(**kwargs):
|
||||
return "success"
|
||||
|
||||
async def bad_tool(**kwargs):
|
||||
raise ValueError("broken")
|
||||
|
||||
tool_executor = {
|
||||
"list_directory": ok_tool,
|
||||
"search_graph": bad_tool,
|
||||
}
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
|
||||
calls = [
|
||||
{"name": "list_directory", "arguments": {}},
|
||||
{"name": "search_graph", "arguments": {"keyword": "x"}},
|
||||
]
|
||||
|
||||
results = await client._execute_tool_batch_parallel(calls, tool_executor)
|
||||
assert len(results) == 2
|
||||
assert "success" in results[0]
|
||||
assert "Error" in results[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batched orchestrator LLM calls (P0: batch lead gen + judging)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBatchLeadGeneration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_batched_leads_created_with_hypothesis_id(self):
|
||||
from unittest.mock import AsyncMock
|
||||
from orchestrator import Orchestrator
|
||||
from agent_factory import AgentFactory
|
||||
|
||||
graph = EvidenceGraph()
|
||||
llm = AsyncMock()
|
||||
factory = AgentFactory(llm, graph)
|
||||
orch = Orchestrator(llm, graph, factory)
|
||||
|
||||
# Set up 2 hypotheses
|
||||
hid1 = await graph.add_hypothesis("Wardriving", "Wireless scanning", created_by="test")
|
||||
hid2 = await graph.add_hypothesis("Data theft", "Stolen files", created_by="test")
|
||||
|
||||
# Mock LLM to return batched JSON
|
||||
llm.chat.return_value = (
|
||||
f'[{{"hypothesis_id": "{hid1}", "agent": "network", "task": "analyze pcap", "priority": 3}},'
|
||||
f' {{"hypothesis_id": "{hid2}", "agent": "filesystem", "task": "check deleted files", "priority": 5}}]'
|
||||
)
|
||||
|
||||
await orch._generate_hypothesis_leads()
|
||||
|
||||
# Should have made exactly 1 LLM call
|
||||
assert llm.chat.call_count == 1
|
||||
# Should have created 2 leads
|
||||
pending = await graph.get_pending_leads()
|
||||
assert len(pending) == 2
|
||||
agents = {l.target_agent for l in pending}
|
||||
assert "network" in agents
|
||||
assert "filesystem" in agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_parse_error(self):
|
||||
from unittest.mock import AsyncMock
|
||||
from orchestrator import Orchestrator
|
||||
from agent_factory import AgentFactory
|
||||
|
||||
graph = EvidenceGraph()
|
||||
llm = AsyncMock()
|
||||
factory = AgentFactory(llm, graph)
|
||||
orch = Orchestrator(llm, graph, factory)
|
||||
|
||||
hid = await graph.add_hypothesis("Test", "Test desc", created_by="test")
|
||||
|
||||
# First call (batched) returns invalid JSON; fallback calls succeed
|
||||
llm.chat.side_effect = [
|
||||
"not valid json at all", # batched call fails
|
||||
'[{"agent": "registry", "task": "check SAM", "priority": 5}]', # fallback
|
||||
]
|
||||
|
||||
await orch._generate_hypothesis_leads()
|
||||
|
||||
# Should have called LLM twice: 1 batched (failed) + 1 fallback
|
||||
assert llm.chat.call_count == 2
|
||||
pending = await graph.get_pending_leads()
|
||||
assert len(pending) == 1
|
||||
|
||||
|
||||
class TestBatchJudging:
|
||||
@pytest.mark.asyncio
|
||||
async def test_batched_judgments_applied(self):
|
||||
from unittest.mock import AsyncMock
|
||||
from orchestrator import Orchestrator
|
||||
from agent_factory import AgentFactory
|
||||
|
||||
graph = EvidenceGraph()
|
||||
llm = AsyncMock()
|
||||
factory = AgentFactory(llm, graph)
|
||||
orch = Orchestrator(llm, graph, factory)
|
||||
|
||||
hid1 = await graph.add_hypothesis("Hyp A", "Desc A", created_by="test")
|
||||
hid2 = await graph.add_hypothesis("Hyp B", "Desc B", created_by="test")
|
||||
pid1, _ = await graph.add_phenomenon("fs", "filesystem", "Finding 1", "details 1", source_tool="t")
|
||||
pid2, _ = await graph.add_phenomenon("fs", "filesystem", "Finding 2", "details 2", source_tool="t")
|
||||
|
||||
old_conf1 = graph.hypotheses[hid1].confidence
|
||||
old_conf2 = graph.hypotheses[hid2].confidence
|
||||
|
||||
llm.chat.return_value = (
|
||||
f'[{{"hypothesis_id": "{hid1}", "phenomenon_id": "{pid1}", "edge_type": "supports", "reason": "evidence for A"}},'
|
||||
f' {{"hypothesis_id": "{hid2}", "phenomenon_id": "{pid2}", "edge_type": "contradicts", "reason": "disproves B"}}]'
|
||||
)
|
||||
|
||||
await orch._judge_new_phenomena()
|
||||
|
||||
assert llm.chat.call_count == 1
|
||||
assert graph.hypotheses[hid1].confidence > old_conf1
|
||||
assert graph.hypotheses[hid2].confidence < old_conf2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_ids_skipped(self):
|
||||
from unittest.mock import AsyncMock
|
||||
from orchestrator import Orchestrator
|
||||
from agent_factory import AgentFactory
|
||||
|
||||
graph = EvidenceGraph()
|
||||
llm = AsyncMock()
|
||||
factory = AgentFactory(llm, graph)
|
||||
orch = Orchestrator(llm, graph, factory)
|
||||
|
||||
hid = await graph.add_hypothesis("Hyp", "Desc", created_by="test")
|
||||
pid, _ = await graph.add_phenomenon("fs", "filesystem", "Finding", "details", source_tool="t")
|
||||
|
||||
old_conf = graph.hypotheses[hid].confidence
|
||||
|
||||
llm.chat.return_value = (
|
||||
f'[{{"hypothesis_id": "hyp-nonexistent", "phenomenon_id": "{pid}", "edge_type": "supports", "reason": "bad ref"}},'
|
||||
f' {{"hypothesis_id": "{hid}", "phenomenon_id": "ph-nonexistent", "edge_type": "supports", "reason": "bad ref"}},'
|
||||
f' {{"hypothesis_id": "{hid}", "phenomenon_id": "{pid}", "edge_type": "supports", "reason": "good"}}]'
|
||||
)
|
||||
|
||||
await orch._judge_new_phenomena()
|
||||
|
||||
# Only the valid judgment should have been applied
|
||||
assert graph.hypotheses[hid].confidence > old_conf
|
||||
# Only 1 edge created (the valid one)
|
||||
hyp_edges = [e for e in graph.edges if e.target_id == hid and e.source_id == pid]
|
||||
assert len(hyp_edges) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool result caching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToolResultCache:
|
||||
def setup_method(self):
|
||||
_tool_result_cache.clear()
|
||||
|
||||
def test_cache_key_deterministic(self):
|
||||
k1 = _cache_key("list_directory", {"inode": "33"})
|
||||
k2 = _cache_key("list_directory", {"inode": "33"})
|
||||
assert k1 == k2
|
||||
|
||||
def test_cache_key_differs_by_args(self):
|
||||
k1 = _cache_key("list_directory", {"inode": "33"})
|
||||
k2 = _cache_key("list_directory", {"inode": "46"})
|
||||
assert k1 != k2
|
||||
|
||||
def test_cache_key_differs_by_tool(self):
|
||||
k1 = _cache_key("list_directory", {"inode": "33"})
|
||||
k2 = _cache_key("find_file", {"inode": "33"})
|
||||
assert k1 != k2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cached_executor_returns_cached_result(self):
|
||||
call_count = 0
|
||||
|
||||
async def real_executor(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"result for {kwargs}"
|
||||
|
||||
cached = _make_cached("test_tool", real_executor)
|
||||
|
||||
r1 = await cached(inode="33")
|
||||
r2 = await cached(inode="33")
|
||||
assert r1 == r2
|
||||
assert call_count == 1 # only called once
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_args_not_cached(self):
|
||||
call_count = 0
|
||||
|
||||
async def real_executor(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"result-{call_count}"
|
||||
|
||||
cached = _make_cached("test_tool", real_executor)
|
||||
|
||||
r1 = await cached(inode="33")
|
||||
r2 = await cached(inode="46")
|
||||
assert r1 != r2
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_results_not_cached(self):
|
||||
call_count = 0
|
||||
|
||||
async def flaky(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return "Error: something went wrong"
|
||||
return "success"
|
||||
|
||||
cached = _make_cached("test_tool", flaky)
|
||||
|
||||
r1 = await cached(x=1)
|
||||
assert "Error" in r1
|
||||
r2 = await cached(x=1)
|
||||
assert r2 == "success"
|
||||
assert call_count == 2 # called twice because error wasn't cached
|
||||
|
||||
def test_cacheable_tools_are_read_only(self):
|
||||
"""Every cacheable tool should also be in READ_ONLY_TOOLS."""
|
||||
for t in CACHEABLE_TOOLS:
|
||||
assert t in READ_ONLY_TOOLS, f"{t} is cacheable but not in READ_ONLY_TOOLS"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Progressive context decay (Stage A)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProgressiveDecay:
|
||||
def _make_messages(self, n_rounds: int) -> list[dict]:
|
||||
"""Build a synthetic message list with n_rounds of (assistant, user) pairs."""
|
||||
messages = [{"role": "user", "content": "Start task"}]
|
||||
for i in range(n_rounds):
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": f"<tool_call>{{'name': 'tool_{i}'}}</tool_call>",
|
||||
})
|
||||
# Tool result message with substantial content
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"{TOOL_RESULT_TAG}\n"
|
||||
f"[tool_{i}] {'x' * 2500}\n"
|
||||
f"{TOOL_RESULT_END}"
|
||||
),
|
||||
})
|
||||
return messages
|
||||
|
||||
def test_short_conversation_unchanged(self):
|
||||
msgs = self._make_messages(3)
|
||||
result = _apply_progressive_decay(msgs)
|
||||
assert len(result) == len(msgs)
|
||||
# Content should be identical for short conversations
|
||||
for orig, decayed in zip(msgs, result):
|
||||
assert orig["content"] == decayed["content"]
|
||||
|
||||
def test_old_messages_truncated(self):
|
||||
msgs = self._make_messages(20)
|
||||
result = _apply_progressive_decay(msgs)
|
||||
|
||||
# Recent tool results should be full length
|
||||
last_tool_result = [m for m in result if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]][-1]
|
||||
assert len(last_tool_result["content"]) > 2000
|
||||
|
||||
# Oldest tool results should be truncated
|
||||
first_tool_result = [m for m in result if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]][0]
|
||||
assert len(first_tool_result["content"]) < 500
|
||||
|
||||
def test_message_count_preserved(self):
|
||||
msgs = self._make_messages(20)
|
||||
result = _apply_progressive_decay(msgs)
|
||||
assert len(result) == len(msgs)
|
||||
|
||||
def test_non_tool_messages_unchanged(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
result = _apply_progressive_decay(msgs)
|
||||
assert result[0]["content"] == "Hello"
|
||||
assert result[1]["content"] == "Hi there"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM message folding (Stage B)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMessageFolding:
|
||||
@pytest.mark.asyncio
|
||||
async def test_fold_replaces_old_messages_with_summary(self):
|
||||
from llm_client import LLMClient
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
client.chat = AsyncMock(return_value="Summary: found important files")
|
||||
|
||||
# Build messages exceeding fold threshold
|
||||
messages = [{"role": "user", "content": "Start task"}]
|
||||
for i in range(30):
|
||||
messages.append({"role": "assistant", "content": f"thinking step {i}"})
|
||||
messages.append({"role": "user", "content": f"tool result {i}: {'data ' * 50}"})
|
||||
|
||||
result = await client._fold_old_messages(messages, "system prompt")
|
||||
|
||||
# Should be significantly shorter
|
||||
assert len(result) < len(messages)
|
||||
# First message should be the summary
|
||||
assert "Context summary" in result[0]["content"]
|
||||
# Recent messages preserved
|
||||
assert len(result) == _FOLD_KEEP_RECENT + 1 # +1 for summary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fold_not_triggered_below_threshold(self):
|
||||
from llm_client import LLMClient
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
client.chat = AsyncMock()
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
]
|
||||
|
||||
result = await client._fold_old_messages(messages, "system")
|
||||
# Should return original (n_to_fold = 2 - 10 = negative, so no folding)
|
||||
assert result == messages
|
||||
client.chat.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fold_graceful_on_llm_failure(self):
|
||||
from llm_client import LLMClient
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
client.chat = AsyncMock(side_effect=Exception("API error"))
|
||||
|
||||
messages = [{"role": "user", "content": f"msg {i}"} for i in range(40)]
|
||||
result = await client._fold_old_messages(messages, "system")
|
||||
|
||||
# On failure, should return original messages unchanged
|
||||
assert len(result) == 40
|
||||
Reference in New Issue
Block a user