Compare commits

...

3 Commits

Author SHA1 Message Date
BattleTag
31812a72ee test: track tests/ directory in version control
tests/test_optimizations.py — 60 pytest cases covering:
- EvidenceGraph: quality scoring, Jaccard merge, async safety,
  hypothesis confidence updates, asset library
- llm_client: tool-result truncation, parallel batch execution,
  progressive context decay, message folding
- orchestrator: parallel dispatch, batched lead generation,
  batched judging
- tool_registry: result cache key derivation

FakeAgent.run signatures updated to BaseAgent.run(task, lead_id=None).

Previously listed in .gitignore (which is itself untracked, so the
ignore rule lives only locally).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 14:10:31 +08:00
BattleTag
74e6bde13a refactor: lead provenance, unified link path, SSOT cleanup, configurable weights
Five interrelated cleanups:

1. Lead -> Phenomenon provenance
   - Phenomenon.from_lead_id field on the dataclass
   - BaseAgent.run(lead_id=...) writes self._current_lead_id
   - _add_phenomenon auto-injects from agent state (LLM unaware)
   - Orchestrator dispatch passes lead.id; Phase 1/2-auto/4/5 stay None
   - Merge path preserves the first non-None lead_id on collision

2. Unified Phenomenon <-> Hypothesis link path
   - HypothesisAgent only adds hypotheses, never links
   - link_phenomenon_to_hypothesis tool + executor removed
   - All links go through Orchestrator._judge_new_phenomena
   - Phase 2 unconditionally judges after hypothesis generation
   - Gap Analysis judges after each dispatch round
   (Three previously-missing judge calls now in place.)

3. SSOT in agent subclasses
   - Remove RoleTemplate dataclass, ROLE_TEMPLATES dict,
     _instantiate_from_template method
   - Each agent subclass owns name, role, and tool list
   - agent_factory.py shrinks from 299 to 153 lines
   - All 7 agents now route through _AGENT_CLASSES (filesystem,
     registry, communication, network, timeline were previously dead
     subclasses overridden by templates)

4. Configurable edge weights
   - HYPOTHESIS_EDGE_WEIGHTS -> _DEFAULT_EDGE_WEIGHTS (private default)
   - EvidenceGraph(edge_weights=...) override via config.yaml
   - hypothesis_edge_weights section in config.yaml (commented example)
   - main.py and regenerate_report.py read and pass through

5. regenerate_report.py auto-picks the latest run/*/graph_state.json
   when no CLI arg is given (was a hardcoded date path)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 14:10:15 +08:00
BattleTag
fde96c7d9f docs: rewrite README for EvidenceGraph + 5-phase + 7-agent architecture
Previous README described a Blackboard-based 4-phase, 6-agent system.
The actual code uses:
- EvidenceGraph with typed weighted edges (Phenomenon/Hypothesis/Entity)
- 5 phases (explicit Hypothesis Generation between survey and investigation)
- 7 agents (added HypothesisAgent)

Documents the confidence update formula, Phenomenon Jaccard merging,
Asset Library inode dedup, tool-result caching, Gap Analysis coverage
check, auto-persistence, and the resume mechanism.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 14:09:59 +08:00
9 changed files with 1212 additions and 322 deletions

231
README.md
View File

@@ -2,43 +2,114 @@
Multi-Agent System for Digital Forensics — 基于大语言模型的多智能体电子取证系统。
系统通过 6 个专业化 Agent 协同工作,对磁盘镜像进行自动化取证分析,最终生成结构化的取证报告。
系统通过 7 个专业化 Agent 协同工作,对磁盘镜像进行自动化取证分析,最终生成结构化的取证报告。Agent 之间不直接通信,通过共享的 **EvidenceGraph**(证据知识图)协作。
## 架构
```
main.py 入口:配置加载、恢复检测、运行管理
main.py 入口:配置加载、镜像选择、断连恢复
├── Orchestrator 阶段流水线调度
├── Orchestrator 阶段流水线调度
│ │
│ ├── FileSystemAgent 磁盘结构、文件系统、删除文件、Prefetch
│ ├── RegistryAgent 注册表分析(系统/用户/网络/软件)
│ ├── CommunicationAgent 邮件、IRC 聊天记录
│ ├── FileSystemAgent 分区/文件系统、目录、删除文件、Prefetch
│ ├── HypothesisAgent 生成假设,链接已有证据
│ ├── RegistryAgent 注册表分析SYSTEM/SOFTWARE/SAM/NTUSER.DAT
│ ├── CommunicationAgent 邮件、IRC/mIRC 聊天记录
│ ├── NetworkAgent 浏览器历史、PCAP 抓包
│ ├── TimelineAgent 跨类别时间线关联
│ └── ReportAgent 综合报告生成
├── Blackboard 共享知识库Evidence + Lead
── LLMClient Claude API 调用ReAct 模式)
├── EvidenceGraph 带类型边的证据知识图(自动持久化
── AgentFactory 角色模板 + 动态 Agent 组合
├── ToolRegistry 工具目录 + 结果缓存
└── LLMClient Claude API 客户端异步、tool-use
```
Agent 之间不直接通信,通过 **Blackboard黑板** 共享发现Evidence和线索Lead
## EvidenceGraph证据知识图
## 调查流程
三类节点 + 类型化加权边:
| 节点 | 前缀 | 含义 |
|---|---|---|
| `Phenomenon` | `ph-*` | 可观测的取证产物(一条具体发现) |
| `Hypothesis` | `hyp-*` | 解释性假设(待验证的论断) |
| `Entity` | `ent-*` | 人、程序、主机、IP 等可复现的实体 |
Phenomenon → Hypothesis 的边类型与权重写死在 `HYPOTHESIS_EDGE_WEIGHTS`
| 边类型 | 权重 | 语义 |
|---|---:|---|
| `direct_evidence` | +0.25 | 现象就是假设所述行为本身 |
| `supports` | +0.15 | 与假设一致但非决定性 |
| `consequence_observed` | +0.15 | 观察到假设预期的结果 |
| `prerequisite_met` | +0.10 | 满足假设的前置条件 |
| `weakens` | 0.10 | 降低假设可能性 |
| `contradicts` | 0.20 | 直接反驳假设 |
置信度更新公式(收敛于 [0, 1]
- 正向边:`delta = weight * (1 - old_conf)`
- 负向边:`delta = weight * old_conf`
跨阈值自动转状态:≥ 0.8 → `supported`,≤ 0.2 → `refuted`,跑完仍 active → `inconclusive`。LLM 只负责挑边类型(分类任务),权重表与状态转移由代码裁决,避免数值幻觉。
新增 Phenomenon 时通过 Jaccard 相似度合并title > 0.6 且 description > 0.4 即视为重复,合并后提升置信度并追加 `corroborating_agents`),避免同一发现被重复入图。
## 五阶段流水线
| 阶段 | 说明 |
|------|------|
| **Phase 1** | FileSystemAgent 勘查磁盘镜像,识别分区、目录结构、关键文件,产出初始 Lead |
| **Phase 2** | 多轮线索追踪 — Lead 按 Agent 类型分组并行派发,最多 10 轮迭代 |
| **Phase 2.5** | 覆盖率缺口分析 — 对照 config.yaml 中的 10 个调查领域,自动补漏 |
| **Phase 3** | TimelineAgent 综合所有 evidence 建立事件时间线 |
| **Phase 4** | ReportAgent 生成 Markdown 格式取证报告 |
| **Phase 1** | FileSystemAgent 勘镜像,识别分区/文件系统/关键路径,产出首批 Phenomenon |
| **Phase 2** | 假设生成 — 优先读 `config.yaml:hypotheses`;未配置则由 HypothesisAgent 从 Phase 1 现象自动生成 3-7 个 |
| **Phase 3** | 假设驱动调查(默认 5 轮迭代)。每轮:一次性为所有 active 假设产出 leads → 按 agent 类型并发派发(信号量 = 3→ 一次性判定新现象与各假设的关系。所有假设收敛即提前退出。末尾:失败 lead 重试一次 + Gap Analysis |
| **Phase 4** | TimelineAgent `build_filesystem_timeline` 生成 MAC 时间线,与 Phenomenon 时间戳关联 |
| **Phase 5** | ReportAgent 综合假设、证据、实体,生成 Markdown 报告 |
## 取证工具链
### Gap AnalysisPhase 3 末)
### Sleuth Kit磁盘取证
`config.yaml:investigation_areas` 列出必须覆盖的调查领域系统信息、用户账户、网络配置、邮件配置、IRC 日志、PCAP、删除文件、Prefetch 等。Orchestrator 两层判定覆盖情况:
通过异步子进程调用 TSK 命令行工具:
1. **关键词匹配**`_AREA_KEYWORDS`)— 扫现有 Phenomenon 标题/描述
2. **工具命中**`_AREA_TOOLS`)— 检查是否调用过该领域的关键工具(如 `enumerate_users``parse_pcap_strings`
未覆盖的领域自动派发 lead最多 3 轮补漏。
## Agent 体系
`AgentFactory` 维护 7 个角色模板(`ROLE_TEMPLATES`),每个模板指定默认工具集。`HypothesisAgent``ReportAgent``BaseAgent` 的子类(额外注册专用工具),其余 5 个 Agent 直接由 `BaseAgent` + 工具列表生成。
### Agent 工作流
`BaseAgent.run` 在 system prompt 中强制四阶段:
```
A. INVESTIGATE 先查图状态 / Asset Library再调取证工具
B. RECORD 每条发现写 add_phenomenon
C. LINK 按需 link_to_entity但禁止凭记忆引用 ph-id必须先 list_phenomena
D. ANSWER 以上完成后再给最终答复
```
prompt 内置**反幻觉规则**:只允许记录工具输出中逐字出现的内容;时间戳/路径/inode 必须来自工具返回;输出被截断须标 `[truncated]`
### 动态 Agent 组合
`AgentFactory.create_specialized_agent()` 应对能力缺口:将工具目录与假设描述喂给 LLM由其挑 3-8 个工具并写角色描述,工厂据此实例化新 Agent 并缓存。
## 工具系统
`tool_registry.py` 启动时调用 `register_all_tools(image_path, partition_offset, graph)`,将所有工具一次性注册到全局 `TOOL_CATALOG`
### 工具结果缓存
`CACHEABLE_TOOLS` 集合标记纯读取/确定性工具partition_info、list_directory、parse_registry_key …)。镜像只读,同 args 调用产出固定,命中缓存直接复用,错误结果不入缓存。
### Asset Library
`EvidenceGraph.asset_library` 按 inode 索引所有已提取文件,避免重复 extract。Agent 通过 `list_assets` / `find_extracted_file` 工具查询。新文件按文件名自动归类到 `registry_hive` / `chat_log` / `prefetch` / `network_capture` / `recycle_bin` 等十类之一。
### 取证工具链
**Sleuth Kit磁盘取证** — 异步子进程调用 TSK
| 工具 | 用途 |
|------|------|
@@ -49,47 +120,43 @@ Agent 之间不直接通信,通过 **Blackboard黑板** 共享发现E
| `srch_strings` | 磁盘字符串搜索 |
| `fls -m` | MAC 时间线生成 |
### regipy注册表解析
**regipy注册表解析** — 直接读 SYSTEM / SOFTWARE / SAM / NTUSER.DAT 二进制,提取系统信息、用户账户、网络配置、已安装软件、邮件账户、关机时间等。
直接解析 Windows 注册表 hive 二进制文件SYSTEM、SOFTWARE、SAM、NTUSER.DAT提取系统信息、用户账户、网络配置、已安装软件、邮件账户、关机时间等
**文件解析器** — Prefetch 二进制(`.pf`、PCAP 字符串提取HTTP 请求 / Host / Cookie / UA、通用文本与二进制读取、正则搜索、Hex dump
### 文件解析器
## 断连恢复与运行归档
- **Prefetch** — 二进制解析 Windows XP .pf 文件(运行次数、最后执行时间)
- **PCAP** — 从抓包文件提取 HTTP 请求、Host、Cookie、User-Agent
- **通用文本/二进制** — 按偏移读取、正则搜索、Hex dump
三层防护:
## 断连恢复与数据归档
1. **EvidenceGraph 自动持久化** — 每次 `add_phenomenon` / `add_hypothesis` / `add_edge` / `add_lead` 等写操作均自动落盘(原子写 `.tmp` 后 rename
2. **Agent 级容错** — 单 Agent 失败 → 该 lead 标 `failed`,连续 3 次失败触发 `AnalysisAborted` 优雅退出Phase 3 末尾对失败 lead 重试一次(`retry=True` 防无限循环)
3. **续跑**`main.py` 启动时扫 `runs/*/graph_state.json`,发现存在但缺 `run_metadata.json` 的目录即提示恢复,并按 graph 当前状态决定从哪一阶段续起
系统设计了三层防护,应对长时间运行中的网络中断:
1. **Blackboard 自动持久化** — 每次 add_evidence / add_lead 自动写盘(原子写入)
2. **Agent 级容错** — 单个 Agent 失败标记 Lead 为 failed不影响其他 Agent自动重试一次
3. **优雅退出** — 连续 3 次 Agent 失败后保存现有成果并干净退出
每次运行自动创建带时间戳的归档目录:
### 运行归档目录
```
runs/
2026-04-02T14-30-00/
config.yaml 配置快照
blackboard_state.json 实时状态(用于恢复
evidence.json 结构化证据导出
leads.json 线索及最终状态
report.md 取证报告
run_metadata.json 运行元数据(时长、统计、错误)
masforensics.log 运行日志
config.yaml 配置快照
graph_state.json 实时状态(续跑用)
phenomena.json 现象导出
hypotheses.json 假设 + 置信度日志
entities.json 实体
edges.json
leads.json 线索及最终状态
extracted/ 从镜像提取的文件
<image>_forensic_report.md 取证报告
run_metadata.json 运行元数据(时长、统计、错误)
masforensics.log 运行日志
```
中断后再次运行 `python main.py`,系统自动检测未完成的运行并提示恢复。
## 快速开始
### 环境要求
- Python >= 3.14
- The Sleuth Kit系统安装提供 `mmls``fls``icat` 等命令)
- 磁盘镜像文件置于 `image/` 目录
- 磁盘镜像文件
### 安装
@@ -99,50 +166,76 @@ uv sync
### 配置
编辑 `config.yaml`,填入 LLM API 地址和密钥
编辑 `config.yaml`
```yaml
agent:
base_url: "https://your-api-proxy.com"
api_key: "sk-your-key"
model: "claude-sonnet-4-6"
max_tokens: 4096
max_tokens: 16384
max_investigation_rounds: 5 # Phase 3 最大迭代轮数
# hypotheses: # 可选:手动指定初始假设
# - title: "嫌疑人主动实施网络嗅探"
# description: "..."
investigation_areas: # Gap Analysis 必须覆盖的领域
- area: system_info
agent: registry
task: "..."
# ...
```
`investigation_areas` 部分定义了必须覆盖的调查领域,可按需增减
未配置 `hypotheses` 时由 HypothesisAgent 自动生成
### 运行
```bash
python main.py
python main.py # 交互式选镜像与分区
python main.py /path/to/image/dir # 指定镜像目录
```
报告和所有结构化数据将保存在 `runs/<timestamp>/` 目录下
中断后再次运行会自动检测未完成的 run 并提示是否续跑
### 仅重生成报告
跑完一次后若只想换提示词或修复报告:
```bash
python regenerate_report.py runs/<timestamp>
```
跳过 Phase 1-4直接从已有 `graph_state.json` 重跑 ReportAgent。
## 项目结构
```
MASForensics/
├── main.py 入口
├── orchestrator.py 流水线调度
├── blackboard.py 共享知识库
├── llm_client.py LLM API 客户端
├── base_agent.py Agent 基类
├── config.yaml 配置文件
├── main.py 入口、镜像选择、断连恢复
├── orchestrator.py 五阶段流水线调度
├── evidence_graph.py 证据知识图 + 边权重表 + 持久化
├── base_agent.py Agent 基类 + 内建 graph 工具
├── agent_factory.py 角色模板 + 动态 Agent 组合
├── tool_registry.py 工具目录 + 结果缓存 + 自动归类
├── llm_client.py LLM API 客户端
├── log_config.py 彩色终端日志 + 文件日志
├── regenerate_report.py 从已有 graph_state 重生成报告
├── config.yaml 配置 + 调查领域 + 可选假设
├── agents/
│ ├── filesystem.py 文件系统 Agent
│ ├── registry.py 注册表 Agent
│ ├── communication.py 通信 Agent
── network.py 网络 Agent
│ ├── timeline.py 时间线 Agent
│ └── report.py 报告 Agent
│ ├── hypothesis.py HypothesisAgentadd_hypothesis、link
│ ├── report.py ReportAgent综合报告自带读取工具
│ ├── timeline.py TimelineAgent保留以备扩展
── ... filesystem/registry/communication/network同上
├── tools/
│ ├── sleuthkit.py Sleuth Kit 封装
│ ├── registry.py 注册表解析(regipy
│ └── parsers.py 文件格式解析
├── image/ 磁盘镜像
├── extracted/ 提取的文件(运行时生成)
└── runs/ 运行归档
│ ├── sleuthkit.py TSK 异步封装
│ ├── registry.py regipy 解析
│ └── parsers.py Prefetch / PCAP / 通用文件解析
├── image/ 磁盘镜像(用户放)
├── runs/ 运行归档
└── tests/
└── test_optimizations.py
```
## 依赖
@@ -152,14 +245,16 @@ MASForensics/
| `httpx[socks]` | 异步 HTTP 客户端(支持 SOCKS 代理) |
| `pyyaml` | 配置文件解析 |
| `regipy` | Windows 注册表 hive 解析 |
| `pytest` / `pytest-asyncio` | 测试 |
## 当前案例
## 默认案例
默认配置分析 **CFReDS Hacking Case**NIST 标准取证教学镜像):
**CFReDS Hacking Case**NIST 标准取证教学镜像):
- 镜像SCHARDT.001~4.6GBIBM 硬盘8 个分段)
- 镜像SCHARDT.001~4.6 GBIBM 硬盘8 个分段)
- 系统Windows XP
- 场景:涉嫌黑客入侵的计算机取证分析
- 完整镜像 MD5`AEE4FCD9301C03B3B054623CA261959A``config.yaml` 含各分段 MD5 用于校验)
## 测试

View File

@@ -1,150 +1,50 @@
"""Agent Factory — composes agents from tool registry and role templates.
"""Agent Factory — instantiates agents from registered classes.
Provides both pre-defined agent templates (filesystem, registry, etc.)
and LLM-driven dynamic agent composition for capability gaps.
Each agent type has a dedicated subclass under agents/ that owns its name,
role description, and tool list (single source of truth). The factory just
maps agent_type → class. Also supports LLM-driven dynamic composition for
capability gaps via create_specialized_agent().
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from base_agent import BaseAgent
from evidence_graph import EvidenceGraph
from llm_client import LLMClient
from tool_registry import TOOL_CATALOG, ToolDefinition
from tool_registry import TOOL_CATALOG
# Agent classes with custom tools — keyed by template name
_AGENT_CLASSES: dict[str, type] = {}
# Agent classes keyed by name. Populated lazily to avoid circular imports.
_AGENT_CLASSES: dict[str, type[BaseAgent]] = {}
def _load_agent_classes() -> None:
"""Lazy-import custom agent classes to avoid circular imports."""
"""Lazy-import agent classes to avoid circular imports."""
if _AGENT_CLASSES:
return
from agents.communication import CommunicationAgent
from agents.filesystem import FileSystemAgent
from agents.hypothesis import HypothesisAgent
from agents.network import NetworkAgent
from agents.registry import RegistryAgent
from agents.report import ReportAgent
from agents.timeline import TimelineAgent
_AGENT_CLASSES["filesystem"] = FileSystemAgent
_AGENT_CLASSES["registry"] = RegistryAgent
_AGENT_CLASSES["communication"] = CommunicationAgent
_AGENT_CLASSES["network"] = NetworkAgent
_AGENT_CLASSES["timeline"] = TimelineAgent
_AGENT_CLASSES["hypothesis"] = HypothesisAgent
_AGENT_CLASSES["report"] = ReportAgent
logger = logging.getLogger(__name__)
@dataclass
class RoleTemplate:
"""Pre-defined agent archetype."""
name: str
role: str
default_tools: list[str] # tool names from TOOL_CATALOG
tags: list[str] = field(default_factory=list)
# Pre-defined templates matching the original 6 agents + hypothesis agent.
ROLE_TEMPLATES: dict[str, RoleTemplate] = {
"filesystem": RoleTemplate(
name="filesystem",
role=(
"File system forensic analyst. You examine disk image partition layouts, "
"directory structures, file metadata, and recover deleted files. "
"You identify suspicious files, installed programs, and user data locations. "
"You also handle Recycle Bin forensics and Prefetch execution evidence."
),
default_tools=[
"partition_info", "filesystem_info", "list_directory",
"extract_file", "find_file", "search_strings",
"parse_prefetch", "count_deleted_files",
"read_text_file", "search_text_file", "read_binary_preview",
],
tags=["filesystem", "disk", "files", "deleted", "prefetch"],
),
"registry": RoleTemplate(
name="registry",
role=(
"Windows registry forensic analyst. You parse registry hive files "
"(SYSTEM, SOFTWARE, SAM, NTUSER.DAT) to extract system configuration, "
"user accounts, installed software, network settings, email accounts, "
"and other Windows artifacts."
),
default_tools=[
"extract_file", "list_directory",
"parse_registry_key", "list_installed_software",
"get_user_activity", "search_registry",
"get_system_info", "get_timezone_info", "get_computer_name",
"get_shutdown_time", "enumerate_users",
"get_network_interfaces", "get_email_config",
],
tags=["registry", "windows", "system", "user", "software"],
),
"communication": RoleTemplate(
name="communication",
role=(
"Communication forensic analyst. You analyze email files (.dbx, .pst), "
"IRC/mIRC chat logs, newsgroup data, and other messaging artifacts "
"to identify communication patterns and contacts."
),
default_tools=[
"list_directory", "extract_file",
"read_text_file", "read_binary_preview",
"list_extracted_dir", "search_strings",
"search_text_file", "read_text_file_section",
],
tags=["email", "chat", "irc", "messaging", "communication"],
),
"network": RoleTemplate(
name="network",
role=(
"Network forensic analyst. You analyze browser history, cookies, "
"network captures (PCAP), wireless artifacts, and other network-related "
"evidence to reconstruct online activities."
),
default_tools=[
"list_directory", "extract_file",
"read_text_file", "read_binary_preview",
"list_extracted_dir", "search_strings",
"search_text_file", "read_text_file_section",
"parse_pcap_strings",
],
tags=["network", "browser", "pcap", "http", "internet"],
),
"timeline": RoleTemplate(
name="timeline",
role=(
"Timeline correlation analyst. You build chronological timelines "
"by combining filesystem MAC times with evidence from other agents. "
"You identify temporal patterns and correlate events across categories."
),
default_tools=[
"build_filesystem_timeline",
],
tags=["timeline", "correlation", "temporal"],
),
"report": RoleTemplate(
name="report",
role=(
"Forensic report writer. You synthesize all evidence and hypotheses "
"into a comprehensive forensic analysis report with executive summary, "
"detailed findings organized by hypothesis, timeline of events, and conclusions."
),
default_tools=[], # Report agent uses only graph query tools
tags=["report", "summary", "writing"],
),
"hypothesis": RoleTemplate(
name="hypothesis",
role=(
"Hypothesis analyst. You review all phenomena discovered so far "
"and formulate investigative hypotheses about what happened on the system. "
"For each hypothesis, identify which existing phenomena support or contradict it."
),
default_tools=[], # Uses only graph query + hypothesis tools
tags=["hypothesis", "analysis", "reasoning"],
),
}
class AgentFactory:
"""Creates agents from templates or dynamically via LLM composition."""
"""Creates agents from registered classes or dynamically via LLM composition."""
def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None:
self.llm = llm
@@ -152,40 +52,20 @@ class AgentFactory:
self._cache: dict[str, BaseAgent] = {}
def get_or_create_agent(self, agent_type: str) -> BaseAgent | None:
"""Get a cached agent or create one from a template."""
"""Get a cached agent or instantiate one from its registered class."""
if agent_type in self._cache:
return self._cache[agent_type]
template = ROLE_TEMPLATES.get(agent_type)
if template is None:
logger.warning("No template for agent type: %s", agent_type)
return None
# Use custom agent class if one exists, otherwise BaseAgent
_load_agent_classes()
agent_cls = _AGENT_CLASSES.get(agent_type)
if agent_cls is not None:
agent = agent_cls(self.llm, self.graph)
else:
agent = self._instantiate_from_template(template)
if agent_cls is None:
logger.warning("No agent class for type: %s", agent_type)
return None
agent = agent_cls(self.llm, self.graph)
self._cache[agent_type] = agent
return agent
def _instantiate_from_template(self, template: RoleTemplate) -> BaseAgent:
"""Create a BaseAgent from a role template, registering tools from the catalog."""
agent = BaseAgent(self.llm, self.graph)
agent.name = template.name
agent.role = template.role
for tool_name in template.default_tools:
td = TOOL_CATALOG.get(tool_name)
if td is None:
logger.warning("Tool '%s' not in catalog (template: %s)", tool_name, template.name)
continue
agent.register_tool(td.name, td.description, td.input_schema, td.executor)
return agent
async def create_specialized_agent(
self,
hypothesis_title: str,
@@ -220,18 +100,15 @@ class AgentFactory:
messages=[{"role": "user", "content": prompt}],
)
# Parse response — try to extract JSON
try:
config = json.loads(response)
except json.JSONDecodeError:
# Try to find JSON in the response
import re
match = re.search(r'\{.*\}', response, re.DOTALL)
if match:
config = json.loads(match.group())
else:
logger.error("Failed to parse agent composition response: %s", response[:300])
# Fallback: create a generic agent with all tools
return self._create_fallback_agent(capability_gap)
agent_name = config.get("agent_name", "specialized")
@@ -239,13 +116,11 @@ class AgentFactory:
strategy = config.get("strategy", "")
tool_names = config.get("tools", [])
# Validate tool names against catalog
valid_tools = [t for t in tool_names if t in TOOL_CATALOG]
if not valid_tools:
logger.warning("No valid tools selected by LLM, using fallback")
return self._create_fallback_agent(capability_gap)
# Build agent
agent = BaseAgent(self.llm, self.graph)
agent.name = agent_name
agent.role = f"{role_text}\n\nInvestigation Strategy:\n{strategy}"

View File

@@ -1,12 +1,15 @@
"""Hypothesis Agent — analyzes phenomena and generates investigative hypotheses."""
"""Hypothesis Agent — generates investigative hypotheses from phenomena.
Generates hypotheses only. Phenomenon→Hypothesis linking is handled centrally
by Orchestrator._judge_new_phenomena, so all link logic lives in one place.
"""
from __future__ import annotations
import json
import logging
from base_agent import BaseAgent
from evidence_graph import EvidenceGraph, HYPOTHESIS_EDGE_WEIGHTS
from evidence_graph import EvidenceGraph
from llm_client import LLMClient
logger = logging.getLogger(__name__)
@@ -17,8 +20,7 @@ class HypothesisAgent(BaseAgent):
role = (
"Hypothesis analyst. You review all phenomena discovered so far "
"and formulate investigative hypotheses about what happened on this system. "
"Your ultimate goal: build the most complete picture of events that occurred. "
"For each hypothesis, identify which existing phenomena support or contradict it."
"Your ultimate goal: build the most complete picture of events that occurred."
)
def __init__(self, llm: LLMClient, graph: EvidenceGraph) -> None:
@@ -26,10 +28,6 @@ class HypothesisAgent(BaseAgent):
self._register_hypothesis_tools()
def _register_hypothesis_tools(self) -> None:
"""Register hypothesis-specific tools."""
valid_edge_types = list(HYPOTHESIS_EDGE_WEIGHTS.keys())
self.register_tool(
name="add_hypothesis",
description=(
@@ -53,44 +51,6 @@ class HypothesisAgent(BaseAgent):
executor=self._add_hypothesis,
)
self.register_tool(
name="link_phenomenon_to_hypothesis",
description=(
"Link an existing phenomenon to a hypothesis with a relationship type. "
f"Valid relationship types: {', '.join(valid_edge_types)}. "
"direct_evidence = the phenomenon IS the hypothesis. "
"supports = consistent with the hypothesis. "
"prerequisite_met = a necessary condition is satisfied. "
"consequence_observed = an expected result of the hypothesis is found. "
"contradicts = directly contradicts the hypothesis. "
"weakens = makes the hypothesis less likely."
),
input_schema={
"type": "object",
"properties": {
"phenomenon_id": {
"type": "string",
"description": "ID of the phenomenon (e.g. 'ph-a1b2c3d4').",
},
"hypothesis_id": {
"type": "string",
"description": "ID of the hypothesis (e.g. 'hyp-e5f6g7h8').",
},
"edge_type": {
"type": "string",
"enum": valid_edge_types,
"description": "The edge_type of the relationship.",
},
"reason": {
"type": "string",
"description": "The reason this relationship holds (1-2 sentences).",
},
},
"required": ["phenomenon_id", "hypothesis_id", "edge_type", "reason"],
},
executor=self._link_phenomenon_to_hypothesis,
)
async def _add_hypothesis(self, title: str, description: str) -> str:
hid = await self.graph.add_hypothesis(
title=title,
@@ -98,33 +58,3 @@ class HypothesisAgent(BaseAgent):
created_by=self.name,
)
return f"Hypothesis created: {hid}{title} (confidence: 0.50)"
async def _link_phenomenon_to_hypothesis(
self,
phenomenon_id: str,
hypothesis_id: str,
edge_type: str = "",
reason: str = "",
# Common LLM misnaming — accept as fallbacks
relationship: str = "",
note: str = "",
) -> str:
edge_type = edge_type or relationship
reason = reason or note
if not edge_type:
return "Error: edge_type is required."
try:
new_conf = await self.graph.update_hypothesis_confidence(
hyp_id=hypothesis_id,
phenomenon_id=phenomenon_id,
edge_type=edge_type,
reason=reason,
)
weight = HYPOTHESIS_EDGE_WEIGHTS[edge_type]
direction = "+" if weight > 0 else ""
return (
f"Linked: {phenomenon_id} —[{edge_type}]→ {hypothesis_id} "
f"(weight: {direction}{weight}, new confidence: {new_conf:.3f})"
)
except ValueError as e:
return f"Error linking: {e}"

View File

@@ -37,6 +37,7 @@ class BaseAgent:
self._tools: dict[str, dict] = {} # name -> schema
self._executors: dict[str, Any] = {} # name -> async callable
self._work_log: list[str] = []
self._current_lead_id: str | None = None
def register_tool(
self,
@@ -107,11 +108,12 @@ class BaseAgent:
f"- Do NOT fabricate execution timestamps — only report timestamps returned by tools"
)
async def run(self, task: str) -> str:
async def run(self, task: str, lead_id: str | None = None) -> str:
"""Run this agent with a specific task."""
_log(task, event="agent_start", agent=self.name)
self.graph.agent_status[self.name] = "running"
self.graph._current_agent = self.name
self._current_lead_id = lead_id
self._register_graph_tools()
@@ -375,6 +377,7 @@ class BaseAgent:
raw_data=raw_data,
timestamp=timestamp,
source_tool=source_tool,
from_lead_id=self._current_lead_id,
)
if merged:
return f"Phenomenon merged into existing: {pid}{title} (corroboration boost)"

View File

@@ -18,10 +18,12 @@ from pathlib import Path
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Predefined edge weights for Phenomenon → Hypothesis relationships.
# Default edge weights for Phenomenon → Hypothesis relationships.
# LLM only picks the edge type (categorical); the weight is looked up here.
# Override per-graph via EvidenceGraph(edge_weights=...) or config.yaml's
# `hypothesis_edge_weights` section.
# ---------------------------------------------------------------------------
HYPOTHESIS_EDGE_WEIGHTS: dict[str, float] = {
_DEFAULT_EDGE_WEIGHTS: dict[str, float] = {
"direct_evidence": +0.25,
"supports": +0.15,
"prerequisite_met": +0.10,
@@ -94,6 +96,7 @@ class Phenomenon:
confidence: float = 1.0
source_tool: str = ""
corroborating_agents: list[str] = field(default_factory=list)
from_lead_id: str | None = None
created_at: str = ""
def to_dict(self) -> dict:
@@ -239,8 +242,12 @@ class EvidenceGraph:
self,
case_info: dict | None = None,
persist_path: Path | None = None,
edge_weights: dict[str, float] | None = None,
) -> None:
self.case_info: dict = case_info or {}
self.edge_weights: dict[str, float] = (
dict(edge_weights) if edge_weights else dict(_DEFAULT_EDGE_WEIGHTS)
)
self.image_path: str = ""
self.partition_offset: int = 0
self.extracted_dir: str = "extracted"
@@ -304,12 +311,17 @@ class EvidenceGraph:
self._persist_path = old
@classmethod
def load_state(cls, path: Path) -> EvidenceGraph:
def load_state(
cls,
path: Path,
edge_weights: dict[str, float] | None = None,
) -> EvidenceGraph:
"""Restore an EvidenceGraph from a saved JSON state file."""
data = json.loads(path.read_text())
graph = cls(
case_info=data.get("case_info", {}),
persist_path=path,
edge_weights=edge_weights,
)
graph.image_path = data.get("image_path", "")
graph.partition_offset = data.get("partition_offset", 0)
@@ -403,6 +415,7 @@ class EvidenceGraph:
raw_data: dict | None = None,
timestamp: str | None = None,
source_tool: str = "",
from_lead_id: str | None = None,
) -> tuple[str, bool]:
"""Add a phenomenon. Returns (id, was_merged).
@@ -419,6 +432,8 @@ class EvidenceGraph:
for k, v in raw_data.items():
if k not in similar.raw_data:
similar.raw_data[k] = v
if from_lead_id and similar.from_lead_id is None:
similar.from_lead_id = from_lead_id
self._auto_save()
return similar.id, True
@@ -437,6 +452,7 @@ class EvidenceGraph:
timestamp=timestamp,
confidence=confidence,
source_tool=source_tool,
from_lead_id=from_lead_id,
created_at=datetime.now().isoformat(),
)
self.phenomena[pid] = ph
@@ -532,14 +548,14 @@ class EvidenceGraph:
) -> float:
"""Update hypothesis confidence based on a phenomenon linkage.
The edge_type must be one of HYPOTHESIS_EDGE_WEIGHTS keys.
Weight is looked up from the predefined table, NOT judged by LLM.
The edge_type must be one of self.edge_weights keys.
Weight is looked up from the configured table, NOT judged by LLM.
Returns the new confidence value.
"""
if edge_type not in HYPOTHESIS_EDGE_WEIGHTS:
if edge_type not in self.edge_weights:
raise ValueError(
f"Invalid hypothesis edge type: {edge_type}. "
f"Must be one of: {list(HYPOTHESIS_EDGE_WEIGHTS.keys())}"
f"Must be one of: {list(self.edge_weights.keys())}"
)
async with self._lock:
@@ -549,7 +565,7 @@ class EvidenceGraph:
if hyp is None:
raise ValueError(f"Hypothesis not found: {hyp_id}")
weight = HYPOTHESIS_EDGE_WEIGHTS[edge_type]
weight = self.edge_weights[edge_type]
old_conf = hyp.confidence
if weight > 0:

View File

@@ -229,6 +229,7 @@ async def async_main() -> None:
graph = EvidenceGraph(
case_info=config.get("cfreds_hacking_case", {}),
persist_path=run_dir / "graph_state.json",
edge_weights=config.get("hypothesis_edge_weights"),
)
graph.image_path = image_path
graph.partition_offset = partition_offset

View File

@@ -11,7 +11,7 @@ from datetime import datetime
from pathlib import Path
from agent_factory import AgentFactory
from evidence_graph import EvidenceGraph, HYPOTHESIS_EDGE_WEIGHTS
from evidence_graph import EvidenceGraph
from llm_client import LLMClient
logger = logging.getLogger(__name__)
@@ -149,7 +149,8 @@ class Orchestrator:
await agent.run(
f"Investigate this lead: {lead.description}\n"
f"{hyp_line}"
f"Focus area: {lead.target_agent}"
f"Focus area: {lead.target_agent}",
lead_id=lead.id,
)
await self.graph.mark_lead_completed(lead.id)
self._failure_count = 0
@@ -209,11 +210,9 @@ class Orchestrator:
"1. Specific and testable\n"
"2. About a distinct aspect of activity (e.g., hacking tools, communication, "
"network attacks, data theft)\n\n"
"For each hypothesis:\n"
"- Call add_hypothesis to create it\n"
"- Then call link_phenomenon_to_hypothesis to link relevant existing phenomena\n"
"- Choose the relationship type carefully: direct_evidence, supports, "
"prerequisite_met, consequence_observed, contradicts, or weakens\n\n"
"Call add_hypothesis for each. The orchestrator will automatically link "
"relevant existing phenomena to each hypothesis after you finish — you do "
"not need to (and cannot) create those links yourself.\n\n"
"The ultimate goal is to reconstruct a detailed timeline of what happened on this host."
)
@@ -333,7 +332,7 @@ class Orchestrator:
if not unlinked:
return
valid_types = list(HYPOTHESIS_EDGE_WEIGHTS.keys())
valid_types = list(self.graph.edge_weights.keys())
hyp_section = "\n".join(
f" [{h.id}] {h.title}: {h.description}" for h in active
@@ -370,7 +369,7 @@ class Orchestrator:
if (
hyp_id in self.graph.hypotheses
and ph_id in self.graph.phenomena
and edge_type in HYPOTHESIS_EDGE_WEIGHTS
and edge_type in self.graph.edge_weights
):
await self.graph.update_hypothesis_confidence(
hyp_id=hyp_id,
@@ -413,7 +412,7 @@ class Orchestrator:
ph_id = j.get("phenomenon_id", "")
edge_type = j.get("edge_type", "")
reason = j.get("reason", "")
if ph_id in self.graph.phenomena and edge_type in HYPOTHESIS_EDGE_WEIGHTS:
if ph_id in self.graph.phenomena and edge_type in self.graph.edge_weights:
await self.graph.update_hypothesis_confidence(
hyp_id=hyp.id,
phenomenon_id=ph_id,
@@ -505,6 +504,7 @@ class Orchestrator:
break
_log(f"Gap fill round {round_num}: {len(pending)} leads", event="dispatch")
await self._dispatch_leads_parallel(pending)
await self._judge_new_phenomena()
# ---- Run archiving -------------------------------------------------------
@@ -604,11 +604,13 @@ class Orchestrator:
manual_hypotheses = self.config.get("hypotheses", [])
if manual_hypotheses:
await self._generate_hypotheses_manual(manual_hypotheses)
if self.graph.phenomena:
await self._judge_new_phenomena()
else:
await self._generate_hypotheses_auto()
# Unified judge step — link Phase 1 phenomena to newly-created hypotheses
if self.graph.phenomena and self.graph.hypotheses:
await self._judge_new_phenomena()
for h in self.graph.hypotheses.values():
_log(f" {h.summary()}", event="hypothesis")
_log(

View File

@@ -13,8 +13,16 @@ from tool_registry import register_all_tools
async def main() -> None:
# Find the run to regenerate from
run_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("runs/2026-04-02T15-11-25")
# Find the run: CLI arg, or latest run with a graph_state.json
if len(sys.argv) > 1:
run_dir = Path(sys.argv[1])
else:
states = sorted(Path("runs").glob("*/graph_state.json"), reverse=True)
if not states:
print("No runs found in runs/")
return
run_dir = states[0].parent
print(f"Using latest run: {run_dir.name}")
state_path = run_dir / "graph_state.json"
if not state_path.exists():
@@ -24,8 +32,11 @@ async def main() -> None:
config = yaml.safe_load(open("config.yaml"))
agent_cfg = config["agent"]
# Load graph
graph = EvidenceGraph.load_state(state_path)
# Load graph (edge_weights from config — applied to the loaded graph)
graph = EvidenceGraph.load_state(
state_path,
edge_weights=config.get("hypothesis_edge_weights"),
)
print(f"Loaded: {graph.stats_summary()}")
# LLM client with larger max_tokens for report

957
tests/test_optimizations.py Normal file
View File

@@ -0,0 +1,957 @@
"""Tests for dedup/quality scoring, context management, parallel dispatch, and evidence graph."""
from __future__ import annotations
import asyncio
import time
import pytest
from evidence_graph import (
EvidenceGraph, Phenomenon, Lead,
_compute_quality_score, _jaccard_similarity,
)
from llm_client import (
_truncate_tool_result, _partition_tool_calls, _ToolBatch, READ_ONLY_TOOLS,
_apply_progressive_decay, _FOLD_THRESHOLD, _FOLD_KEEP_RECENT,
TOOL_RESULT_TAG, TOOL_RESULT_END,
)
from tool_registry import (
_tool_result_cache, _cache_key, _make_cached, CACHEABLE_TOOLS,
)
# ---------------------------------------------------------------------------
# Quality scoring & similarity
# ---------------------------------------------------------------------------
class TestQualityScore:
def test_full_score(self):
score = _compute_quality_score(
source_tool="list_dir",
timestamp="2024-01-01",
raw_data={"key": "val"},
description="A" * 50,
related_ids=["ph-0"],
)
assert score == pytest.approx(1.0)
def test_minimal_score(self):
score = _compute_quality_score(
source_tool="",
timestamp=None,
raw_data={},
description="short",
related_ids=[],
)
assert score == pytest.approx(0.0)
def test_partial_score(self):
score = _compute_quality_score(
source_tool="parse_registry_key",
timestamp=None,
raw_data={},
description="A" * 50,
related_ids=[],
)
assert score == pytest.approx(0.40)
class TestJaccardSimilarity:
def test_identical(self):
assert _jaccard_similarity("hello world", "hello world") == pytest.approx(1.0)
def test_no_overlap(self):
assert _jaccard_similarity("hello world", "foo bar") == pytest.approx(0.0)
def test_partial_overlap(self):
sim = _jaccard_similarity("User account Mr. Evil", "User account Mr. Evil enumeration")
assert 0.5 < sim < 1.0
def test_empty_string(self):
assert _jaccard_similarity("", "hello") == pytest.approx(0.0)
# ---------------------------------------------------------------------------
# Evidence graph: phenomenon dedup
# ---------------------------------------------------------------------------
class TestPhenomenonDedup:
@pytest.fixture
def graph(self):
return EvidenceGraph()
@pytest.mark.asyncio
async def test_new_phenomenon_gets_quality_score(self, graph):
pid, merged = await graph.add_phenomenon(
source_agent="fs", category="filesystem",
title="Directory listing",
description="Found important files in the root directory of the disk image",
source_tool="list_directory",
)
assert not merged
ph = graph.phenomena[pid]
assert ph.confidence == pytest.approx(0.40)
@pytest.mark.asyncio
async def test_identical_phenomenon_merges(self, graph):
pid1, m1 = await graph.add_phenomenon(
source_agent="fs", category="registry",
title="User account Mr. Evil",
description="Found user account Mr. Evil with RID 1003 in the SAM hive",
source_tool="enumerate_users",
)
pid2, m2 = await graph.add_phenomenon(
source_agent="reg", category="registry",
title="User account Mr. Evil",
description="User account Mr. Evil discovered with RID 1003 in SAM registry hive",
source_tool="enumerate_users",
)
assert not m1
assert m2
assert pid1 == pid2
ph = graph.phenomena[pid1]
assert "reg" in ph.corroborating_agents
assert ph.confidence > 0.40
@pytest.mark.asyncio
async def test_different_category_no_merge(self, graph):
pid1, _ = await graph.add_phenomenon(
source_agent="fs", category="filesystem",
title="User account Mr. Evil",
description="Found user account Mr. Evil with RID 1003",
source_tool="list_directory",
)
pid2, m2 = await graph.add_phenomenon(
source_agent="reg", category="registry",
title="User account Mr. Evil",
description="Found user account Mr. Evil with RID 1003",
source_tool="enumerate_users",
)
assert not m2
assert pid1 != pid2
@pytest.mark.asyncio
async def test_dissimilar_titles_no_merge(self, graph):
await graph.add_phenomenon(
source_agent="fs", category="registry",
title="Installed software list",
description="Found 25 installed programs",
source_tool="list_installed_software",
)
pid2, m2 = await graph.add_phenomenon(
source_agent="reg", category="registry",
title="Last shutdown time",
description="System last shut down at 2004-08-27",
source_tool="get_shutdown_time",
)
assert not m2
@pytest.mark.asyncio
async def test_merge_adds_raw_data(self, graph):
pid1, _ = await graph.add_phenomenon(
source_agent="fs", category="registry",
title="Network config",
description="Network interface configuration with IP address and MAC address details",
raw_data={"ip": "192.168.1.1"},
source_tool="get_network_interfaces",
)
pid2, m2 = await graph.add_phenomenon(
source_agent="reg", category="registry",
title="Network config",
description="Network interface configuration including DHCP and MAC address info",
raw_data={"mac": "00:11:22:33:44:55"},
source_tool="get_network_interfaces",
)
assert m2
ph = graph.phenomena[pid1]
assert "ip" in ph.raw_data
assert "mac" in ph.raw_data
# ---------------------------------------------------------------------------
# Hypothesis confidence updates
# ---------------------------------------------------------------------------
class TestHypothesisConfidence:
@pytest.fixture
def graph(self):
return EvidenceGraph()
@pytest.mark.asyncio
async def test_supports_increases_confidence(self, graph):
pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t")
hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
old = graph.hypotheses[hid].confidence
new = await graph.update_hypothesis_confidence(hid, pid, "supports", "reason")
assert new > old
@pytest.mark.asyncio
async def test_contradicts_decreases_confidence(self, graph):
pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t")
hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
old = graph.hypotheses[hid].confidence
new = await graph.update_hypothesis_confidence(hid, pid, "contradicts", "reason")
assert new < old
@pytest.mark.asyncio
async def test_direct_evidence_has_largest_effect(self, graph):
pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t")
hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
conf = await graph.update_hypothesis_confidence(hid, pid, "direct_evidence", "reason")
# direct_evidence weight is +0.25 * (1-0.5) = +0.125
assert conf == pytest.approx(0.625)
@pytest.mark.asyncio
async def test_confidence_log_tracked(self, graph):
pid, _ = await graph.add_phenomenon("fs", "filesystem", "test", "test desc", source_tool="t")
hid = await graph.add_hypothesis("test hyp", "desc", created_by="test")
await graph.update_hypothesis_confidence(hid, pid, "supports", "because reasons")
log = graph.hypotheses[hid].confidence_log
assert len(log) == 1
assert log[0]["edge_type"] == "supports"
assert log[0]["reason"] == "because reasons"
@pytest.mark.asyncio
async def test_supported_status_at_threshold(self, graph):
hid = await graph.add_hypothesis("test", "desc", created_by="test")
# Push confidence up with multiple direct_evidence links
for i in range(5):
pid, _ = await graph.add_phenomenon("fs", "filesystem", f"test {i}", f"desc {i}", source_tool="t")
await graph.update_hypothesis_confidence(hid, pid, "direct_evidence", "proof")
assert graph.hypotheses[hid].status == "supported"
# ---------------------------------------------------------------------------
# Context window management
# ---------------------------------------------------------------------------
class TestTruncateToolResult:
def test_short_text_unchanged(self):
text = "short result"
assert _truncate_tool_result(text) == text
def test_long_text_truncated(self):
text = "x" * 5000
result = _truncate_tool_result(text, max_chars=3000)
assert len(result) < 3100
assert "truncated" in result
assert "5000" in result
def test_exact_boundary(self):
text = "x" * 3000
assert _truncate_tool_result(text, max_chars=3000) == text
# ---------------------------------------------------------------------------
# Parallel dispatch & async safety
# ---------------------------------------------------------------------------
class TestGraphAsync:
@pytest.mark.asyncio
async def test_concurrent_add_phenomenon(self):
graph = EvidenceGraph()
async def add_batch(agent_name: str, count: int):
for i in range(count):
await graph.add_phenomenon(
source_agent=agent_name,
category=f"cat-{agent_name}",
title=f"Phenomenon {agent_name}-{i}",
description=f"Description from {agent_name} number {i} with enough detail to be meaningful",
source_tool="test_tool",
)
await asyncio.gather(
add_batch("agent_a", 10),
add_batch("agent_b", 10),
add_batch("agent_c", 10),
)
assert len(graph.phenomena) == 30
@pytest.mark.asyncio
async def test_concurrent_add_lead(self):
graph = EvidenceGraph()
async def add_leads(count: int):
for i in range(count):
await graph.add_lead(target_agent="network", description=f"Lead {i}")
await asyncio.gather(add_leads(10), add_leads(10))
pending = await graph.get_pending_leads()
assert len(pending) == 20
@pytest.mark.asyncio
async def test_mark_lead_completed(self):
graph = EvidenceGraph()
lid = await graph.add_lead(target_agent="fs", description="test lead")
pending = await graph.get_pending_leads()
assert len(pending) == 1
await graph.mark_lead_completed(lid)
pending = await graph.get_pending_leads()
assert len(pending) == 0
class TestParallelDispatch:
@pytest.mark.asyncio
async def test_different_agent_types_run_concurrently(self):
from unittest.mock import AsyncMock
from orchestrator import Orchestrator
from agent_factory import AgentFactory
graph = EvidenceGraph()
llm = AsyncMock()
factory = AgentFactory(llm, graph)
orch = Orchestrator(llm, graph, factory)
execution_log: list[tuple[str, float, float]] = []
class FakeAgent:
def __init__(self, name):
self.name = name
async def run(self, task, lead_id=None):
start = time.monotonic()
await asyncio.sleep(0.05)
end = time.monotonic()
execution_log.append((self.name, start, end))
factory._cache = {
"filesystem": FakeAgent("filesystem"),
"registry": FakeAgent("registry"),
"network": FakeAgent("network"),
}
leads = [
Lead(id="l1", target_agent="filesystem", description="fs task"),
Lead(id="l2", target_agent="registry", description="reg task"),
Lead(id="l3", target_agent="network", description="net task"),
]
graph.leads = leads
start_time = time.monotonic()
await orch._dispatch_leads_parallel(leads)
total_time = time.monotonic() - start_time
assert total_time < 0.12, f"Expected parallel execution but took {total_time:.3f}s"
assert len(execution_log) == 3
@pytest.mark.asyncio
async def test_same_agent_type_runs_serially(self):
from unittest.mock import AsyncMock
from orchestrator import Orchestrator
from agent_factory import AgentFactory
graph = EvidenceGraph()
llm = AsyncMock()
factory = AgentFactory(llm, graph)
orch = Orchestrator(llm, graph, factory)
execution_order: list[str] = []
class FakeAgent:
def __init__(self, name):
self.name = name
async def run(self, task, lead_id=None):
execution_order.append(task)
await asyncio.sleep(0.01)
factory._cache = {"filesystem": FakeAgent("filesystem")}
leads = [
Lead(id="l1", target_agent="filesystem", description="task 1"),
Lead(id="l2", target_agent="filesystem", description="task 2"),
]
graph.leads = leads
await orch._dispatch_leads_parallel(leads)
assert len(execution_order) == 2
assert "task 1" in execution_order[0]
assert "task 2" in execution_order[1]
# ---------------------------------------------------------------------------
# Asset library
# ---------------------------------------------------------------------------
class TestAssetLibrary:
@pytest.fixture
def graph(self):
return EvidenceGraph()
@pytest.mark.asyncio
async def test_register_asset(self, graph):
aid, existed = await graph.register_asset(
inode="334-128-4",
original_path="WINDOWS/system32/config/SYSTEM",
local_path="extracted/system",
category="registry_hive",
filename="system",
size_bytes=262144,
extracted_by="filesystem",
)
assert not existed
assert aid.startswith("asset-")
assert len(graph.asset_library) == 1
@pytest.mark.asyncio
async def test_dedup_by_inode(self, graph):
aid1, _ = await graph.register_asset(
inode="334-128-4",
original_path="WINDOWS/system32/config/SYSTEM",
local_path="extracted/system",
category="registry_hive",
filename="system",
size_bytes=262144,
extracted_by="filesystem",
)
aid2, existed = await graph.register_asset(
inode="334-128-4",
original_path="WINDOWS/system32/config/SYSTEM",
local_path="extracted/system",
category="registry_hive",
filename="system",
size_bytes=262144,
extracted_by="registry",
)
assert existed
assert aid1 == aid2
assert len(graph.asset_library) == 1
@pytest.mark.asyncio
async def test_query_by_category(self, graph):
await graph.register_asset(
inode="334-128-4", original_path="config/SYSTEM",
local_path="extracted/system", category="registry_hive",
filename="system", size_bytes=100, extracted_by="fs",
)
await graph.register_asset(
inode="10090-128-1", original_path="mIRC/logs/hackers.log",
local_path="extracted/hackers.log", category="chat_log",
filename="hackers.log", size_bytes=200, extracted_by="comm",
)
results = graph.query_assets(category="registry_hive")
assert len(results) == 1
assert results[0].filename == "system"
@pytest.mark.asyncio
async def test_lookup_by_inode(self, graph):
await graph.register_asset(
inode="334-128-4", original_path="config/SYSTEM",
local_path="extracted/system", category="registry_hive",
filename="system", size_bytes=100, extracted_by="fs",
)
asset = graph.lookup_asset_by_inode("334-128-4")
assert asset is not None
assert asset.local_path == "extracted/system"
assert graph.lookup_asset_by_inode("999-128-4") is None
@pytest.mark.asyncio
async def test_persistence_round_trip(self, graph, tmp_path):
await graph.register_asset(
inode="334-128-4",
original_path="WINDOWS/system32/config/SYSTEM",
local_path="extracted/system",
category="registry_hive",
filename="system",
size_bytes=262144,
extracted_by="filesystem",
)
path = tmp_path / "test_state.json"
graph.save_state(path)
loaded = EvidenceGraph.load_state(path)
assert len(loaded.asset_library) == 1
assert loaded.lookup_asset_by_inode("334-128-4") is not None
assert loaded.lookup_asset_by_inode("334-128-4").category == "registry_hive"
@pytest.mark.asyncio
async def test_list_assets(self, graph):
await graph.register_asset(
inode="334-128-4", original_path="config/SYSTEM",
local_path="extracted/system", category="registry_hive",
filename="system", size_bytes=8192, extracted_by="fs",
)
summaries = graph.list_assets()
assert len(summaries) == 1
assert "system" in summaries[0]
assert "registry_hive" in summaries[0]
@pytest.mark.asyncio
async def test_query_by_filename(self, graph):
await graph.register_asset(
inode="10080-128-3", original_path="mIRC/mirc.ini",
local_path="extracted/mirc.ini", category="config_file",
filename="mirc.ini", size_bytes=500, extracted_by="comm",
)
await graph.register_asset(
inode="334-128-4", original_path="config/SYSTEM",
local_path="extracted/system", category="registry_hive",
filename="system", size_bytes=100, extracted_by="fs",
)
results = graph.query_assets(filename_pattern="mirc")
assert len(results) == 1
assert results[0].filename == "mirc.ini"
# ---------------------------------------------------------------------------
# Tool call partitioning (P0: parallel tool execution)
# ---------------------------------------------------------------------------
class TestPartitionToolCalls:
def test_consecutive_read_only_grouped(self):
calls = [
{"name": "list_directory", "arguments": {"inode": "33"}},
{"name": "list_directory", "arguments": {"inode": "46"}},
{"name": "list_directory", "arguments": {"inode": "50"}},
]
batches = _partition_tool_calls(calls)
assert len(batches) == 1
assert batches[0].is_read_only
assert len(batches[0].calls) == 3
def test_write_tools_isolated(self):
calls = [
{"name": "extract_file", "arguments": {"inode": "40"}},
{"name": "extract_file", "arguments": {"inode": "41"}},
]
batches = _partition_tool_calls(calls)
assert len(batches) == 2
assert not batches[0].is_read_only
assert not batches[1].is_read_only
def test_mixed_read_write_partitioned(self):
calls = [
{"name": "list_directory", "arguments": {}},
{"name": "search_graph", "arguments": {"keyword": "test"}},
{"name": "add_phenomenon", "arguments": {"title": "x"}},
{"name": "list_phenomena", "arguments": {}},
]
batches = _partition_tool_calls(calls)
assert len(batches) == 3
assert batches[0].is_read_only and len(batches[0].calls) == 2
assert not batches[1].is_read_only and len(batches[1].calls) == 1
assert batches[2].is_read_only and len(batches[2].calls) == 1
def test_empty_list(self):
assert _partition_tool_calls([]) == []
def test_single_read_only(self):
calls = [{"name": "partition_info", "arguments": {}}]
batches = _partition_tool_calls(calls)
assert len(batches) == 1
assert batches[0].is_read_only
def test_single_write(self):
calls = [{"name": "add_lead", "arguments": {}}]
batches = _partition_tool_calls(calls)
assert len(batches) == 1
assert not batches[0].is_read_only
def test_unknown_tool_treated_as_write(self):
calls = [{"name": "unknown_tool", "arguments": {}}]
batches = _partition_tool_calls(calls)
assert len(batches) == 1
assert not batches[0].is_read_only
def test_custom_read_only_set(self):
calls = [
{"name": "foo", "arguments": {}},
{"name": "bar", "arguments": {}},
]
batches = _partition_tool_calls(calls, read_only={"foo", "bar"})
assert len(batches) == 1
assert batches[0].is_read_only
class TestParallelToolExecution:
@pytest.mark.asyncio
async def test_parallel_faster_than_serial(self):
"""Read-only tools should execute concurrently."""
from llm_client import LLMClient
async def slow_tool(**kwargs):
await asyncio.sleep(0.05)
return "ok"
tool_executor = {
"list_directory": slow_tool,
"search_graph": slow_tool,
"list_phenomena": slow_tool,
}
client = LLMClient.__new__(LLMClient) # skip __init__
calls = [
{"name": "list_directory", "arguments": {}},
{"name": "search_graph", "arguments": {"keyword": "x"}},
{"name": "list_phenomena", "arguments": {}},
]
start = time.monotonic()
results = await client._execute_tool_batch_parallel(calls, tool_executor)
elapsed = time.monotonic() - start
assert len(results) == 3
assert all("ok" in r for r in results)
# 3 tasks × 50ms each should take ~50ms parallel, not ~150ms serial
assert elapsed < 0.12, f"Expected parallel execution but took {elapsed:.3f}s"
@pytest.mark.asyncio
async def test_error_in_parallel_batch_does_not_crash(self):
"""A failing tool in a parallel batch should produce an error result, not crash."""
from llm_client import LLMClient
async def ok_tool(**kwargs):
return "success"
async def bad_tool(**kwargs):
raise ValueError("broken")
tool_executor = {
"list_directory": ok_tool,
"search_graph": bad_tool,
}
client = LLMClient.__new__(LLMClient)
calls = [
{"name": "list_directory", "arguments": {}},
{"name": "search_graph", "arguments": {"keyword": "x"}},
]
results = await client._execute_tool_batch_parallel(calls, tool_executor)
assert len(results) == 2
assert "success" in results[0]
assert "Error" in results[1]
# ---------------------------------------------------------------------------
# Batched orchestrator LLM calls (P0: batch lead gen + judging)
# ---------------------------------------------------------------------------
class TestBatchLeadGeneration:
@pytest.mark.asyncio
async def test_batched_leads_created_with_hypothesis_id(self):
from unittest.mock import AsyncMock
from orchestrator import Orchestrator
from agent_factory import AgentFactory
graph = EvidenceGraph()
llm = AsyncMock()
factory = AgentFactory(llm, graph)
orch = Orchestrator(llm, graph, factory)
# Set up 2 hypotheses
hid1 = await graph.add_hypothesis("Wardriving", "Wireless scanning", created_by="test")
hid2 = await graph.add_hypothesis("Data theft", "Stolen files", created_by="test")
# Mock LLM to return batched JSON
llm.chat.return_value = (
f'[{{"hypothesis_id": "{hid1}", "agent": "network", "task": "analyze pcap", "priority": 3}},'
f' {{"hypothesis_id": "{hid2}", "agent": "filesystem", "task": "check deleted files", "priority": 5}}]'
)
await orch._generate_hypothesis_leads()
# Should have made exactly 1 LLM call
assert llm.chat.call_count == 1
# Should have created 2 leads
pending = await graph.get_pending_leads()
assert len(pending) == 2
agents = {l.target_agent for l in pending}
assert "network" in agents
assert "filesystem" in agents
@pytest.mark.asyncio
async def test_fallback_on_parse_error(self):
from unittest.mock import AsyncMock
from orchestrator import Orchestrator
from agent_factory import AgentFactory
graph = EvidenceGraph()
llm = AsyncMock()
factory = AgentFactory(llm, graph)
orch = Orchestrator(llm, graph, factory)
hid = await graph.add_hypothesis("Test", "Test desc", created_by="test")
# First call (batched) returns invalid JSON; fallback calls succeed
llm.chat.side_effect = [
"not valid json at all", # batched call fails
'[{"agent": "registry", "task": "check SAM", "priority": 5}]', # fallback
]
await orch._generate_hypothesis_leads()
# Should have called LLM twice: 1 batched (failed) + 1 fallback
assert llm.chat.call_count == 2
pending = await graph.get_pending_leads()
assert len(pending) == 1
class TestBatchJudging:
@pytest.mark.asyncio
async def test_batched_judgments_applied(self):
from unittest.mock import AsyncMock
from orchestrator import Orchestrator
from agent_factory import AgentFactory
graph = EvidenceGraph()
llm = AsyncMock()
factory = AgentFactory(llm, graph)
orch = Orchestrator(llm, graph, factory)
hid1 = await graph.add_hypothesis("Hyp A", "Desc A", created_by="test")
hid2 = await graph.add_hypothesis("Hyp B", "Desc B", created_by="test")
pid1, _ = await graph.add_phenomenon("fs", "filesystem", "Finding 1", "details 1", source_tool="t")
pid2, _ = await graph.add_phenomenon("fs", "filesystem", "Finding 2", "details 2", source_tool="t")
old_conf1 = graph.hypotheses[hid1].confidence
old_conf2 = graph.hypotheses[hid2].confidence
llm.chat.return_value = (
f'[{{"hypothesis_id": "{hid1}", "phenomenon_id": "{pid1}", "edge_type": "supports", "reason": "evidence for A"}},'
f' {{"hypothesis_id": "{hid2}", "phenomenon_id": "{pid2}", "edge_type": "contradicts", "reason": "disproves B"}}]'
)
await orch._judge_new_phenomena()
assert llm.chat.call_count == 1
assert graph.hypotheses[hid1].confidence > old_conf1
assert graph.hypotheses[hid2].confidence < old_conf2
@pytest.mark.asyncio
async def test_invalid_ids_skipped(self):
from unittest.mock import AsyncMock
from orchestrator import Orchestrator
from agent_factory import AgentFactory
graph = EvidenceGraph()
llm = AsyncMock()
factory = AgentFactory(llm, graph)
orch = Orchestrator(llm, graph, factory)
hid = await graph.add_hypothesis("Hyp", "Desc", created_by="test")
pid, _ = await graph.add_phenomenon("fs", "filesystem", "Finding", "details", source_tool="t")
old_conf = graph.hypotheses[hid].confidence
llm.chat.return_value = (
f'[{{"hypothesis_id": "hyp-nonexistent", "phenomenon_id": "{pid}", "edge_type": "supports", "reason": "bad ref"}},'
f' {{"hypothesis_id": "{hid}", "phenomenon_id": "ph-nonexistent", "edge_type": "supports", "reason": "bad ref"}},'
f' {{"hypothesis_id": "{hid}", "phenomenon_id": "{pid}", "edge_type": "supports", "reason": "good"}}]'
)
await orch._judge_new_phenomena()
# Only the valid judgment should have been applied
assert graph.hypotheses[hid].confidence > old_conf
# Only 1 edge created (the valid one)
hyp_edges = [e for e in graph.edges if e.target_id == hid and e.source_id == pid]
assert len(hyp_edges) == 1
# ---------------------------------------------------------------------------
# Tool result caching
# ---------------------------------------------------------------------------
class TestToolResultCache:
def setup_method(self):
_tool_result_cache.clear()
def test_cache_key_deterministic(self):
k1 = _cache_key("list_directory", {"inode": "33"})
k2 = _cache_key("list_directory", {"inode": "33"})
assert k1 == k2
def test_cache_key_differs_by_args(self):
k1 = _cache_key("list_directory", {"inode": "33"})
k2 = _cache_key("list_directory", {"inode": "46"})
assert k1 != k2
def test_cache_key_differs_by_tool(self):
k1 = _cache_key("list_directory", {"inode": "33"})
k2 = _cache_key("find_file", {"inode": "33"})
assert k1 != k2
@pytest.mark.asyncio
async def test_cached_executor_returns_cached_result(self):
call_count = 0
async def real_executor(**kwargs):
nonlocal call_count
call_count += 1
return f"result for {kwargs}"
cached = _make_cached("test_tool", real_executor)
r1 = await cached(inode="33")
r2 = await cached(inode="33")
assert r1 == r2
assert call_count == 1 # only called once
@pytest.mark.asyncio
async def test_different_args_not_cached(self):
call_count = 0
async def real_executor(**kwargs):
nonlocal call_count
call_count += 1
return f"result-{call_count}"
cached = _make_cached("test_tool", real_executor)
r1 = await cached(inode="33")
r2 = await cached(inode="46")
assert r1 != r2
assert call_count == 2
@pytest.mark.asyncio
async def test_error_results_not_cached(self):
call_count = 0
async def flaky(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return "Error: something went wrong"
return "success"
cached = _make_cached("test_tool", flaky)
r1 = await cached(x=1)
assert "Error" in r1
r2 = await cached(x=1)
assert r2 == "success"
assert call_count == 2 # called twice because error wasn't cached
def test_cacheable_tools_are_read_only(self):
"""Every cacheable tool should also be in READ_ONLY_TOOLS."""
for t in CACHEABLE_TOOLS:
assert t in READ_ONLY_TOOLS, f"{t} is cacheable but not in READ_ONLY_TOOLS"
# ---------------------------------------------------------------------------
# Progressive context decay (Stage A)
# ---------------------------------------------------------------------------
class TestProgressiveDecay:
def _make_messages(self, n_rounds: int) -> list[dict]:
"""Build a synthetic message list with n_rounds of (assistant, user) pairs."""
messages = [{"role": "user", "content": "Start task"}]
for i in range(n_rounds):
messages.append({
"role": "assistant",
"content": f"<tool_call>{{'name': 'tool_{i}'}}</tool_call>",
})
# Tool result message with substantial content
messages.append({
"role": "user",
"content": (
f"{TOOL_RESULT_TAG}\n"
f"[tool_{i}] {'x' * 2500}\n"
f"{TOOL_RESULT_END}"
),
})
return messages
def test_short_conversation_unchanged(self):
msgs = self._make_messages(3)
result = _apply_progressive_decay(msgs)
assert len(result) == len(msgs)
# Content should be identical for short conversations
for orig, decayed in zip(msgs, result):
assert orig["content"] == decayed["content"]
def test_old_messages_truncated(self):
msgs = self._make_messages(20)
result = _apply_progressive_decay(msgs)
# Recent tool results should be full length
last_tool_result = [m for m in result if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]][-1]
assert len(last_tool_result["content"]) > 2000
# Oldest tool results should be truncated
first_tool_result = [m for m in result if m["role"] == "user" and TOOL_RESULT_TAG in m["content"]][0]
assert len(first_tool_result["content"]) < 500
def test_message_count_preserved(self):
msgs = self._make_messages(20)
result = _apply_progressive_decay(msgs)
assert len(result) == len(msgs)
def test_non_tool_messages_unchanged(self):
msgs = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
]
result = _apply_progressive_decay(msgs)
assert result[0]["content"] == "Hello"
assert result[1]["content"] == "Hi there"
# ---------------------------------------------------------------------------
# LLM message folding (Stage B)
# ---------------------------------------------------------------------------
class TestMessageFolding:
@pytest.mark.asyncio
async def test_fold_replaces_old_messages_with_summary(self):
from llm_client import LLMClient
from unittest.mock import AsyncMock
client = LLMClient.__new__(LLMClient)
client.chat = AsyncMock(return_value="Summary: found important files")
# Build messages exceeding fold threshold
messages = [{"role": "user", "content": "Start task"}]
for i in range(30):
messages.append({"role": "assistant", "content": f"thinking step {i}"})
messages.append({"role": "user", "content": f"tool result {i}: {'data ' * 50}"})
result = await client._fold_old_messages(messages, "system prompt")
# Should be significantly shorter
assert len(result) < len(messages)
# First message should be the summary
assert "Context summary" in result[0]["content"]
# Recent messages preserved
assert len(result) == _FOLD_KEEP_RECENT + 1 # +1 for summary
@pytest.mark.asyncio
async def test_fold_not_triggered_below_threshold(self):
from llm_client import LLMClient
from unittest.mock import AsyncMock
client = LLMClient.__new__(LLMClient)
client.chat = AsyncMock()
messages = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
]
result = await client._fold_old_messages(messages, "system")
# Should return original (n_to_fold = 2 - 10 = negative, so no folding)
assert result == messages
client.chat.assert_not_called()
@pytest.mark.asyncio
async def test_fold_graceful_on_llm_failure(self):
from llm_client import LLMClient
from unittest.mock import AsyncMock
client = LLMClient.__new__(LLMClient)
client.chat = AsyncMock(side_effect=Exception("API error"))
messages = [{"role": "user", "content": f"msg {i}"} for i in range(40)]
result = await client._fold_old_messages(messages, "system")
# On failure, should return original messages unchanged
assert len(result) == 40