Files
ER-TP-DGP/tests/test_pipeline.py
BattleTag b86ae87b75 Initial commit: ER-TP-DGP research prototype
Event-Reified Temporal Provenance Dual-Granularity Prompting for
LLM-based APT detection on DARPA provenance datasets.

Includes phase 0-14 method spec, IR/graph/metapath/trimming/prompt
modules, scripts for THEIA candidate universe, landmark CSG construction,
hybrid prompting, and LLM inference. Excludes data/, reports/, and
local LLM config from version control.
2026-05-15 16:53:57 +08:00

1366 lines
58 KiB
Python

import unittest
import json
from dataclasses import asdict
from pathlib import Path
from tempfile import TemporaryDirectory
from er_tp_dgp.adapters import ExplicitIRAdapter
from er_tp_dgp.candidate_universe import (
build_theia_candidate_universe,
select_anchor_for_candidate,
select_anchors_for_lifecycle,
stratified_sample,
)
from er_tp_dgp.candidates import WeakSignalCandidateGenerator, evaluate_candidates
from er_tp_dgp.experiments import default_method_registry, validate_method_registry
from er_tp_dgp.evaluation_batch import (
build_end_to_end_evaluation_batch,
build_evaluation_batch,
)
from er_tp_dgp.ground_truth import extract_e3_ground_truth_atoms
from er_tp_dgp.ground_truth_mapping import (
evaluate_candidate_recall,
match_theia_ground_truth_atoms,
)
from er_tp_dgp.labels import LabelMapper, LabelRecord
from er_tp_dgp.llm import (
LLMRequestConfig,
extract_openai_compatible_text,
parse_classification_output,
)
from er_tp_dgp.llm_config import load_llm_config
from er_tp_dgp.metrics import LayeredEvaluationReport, PredictionRecord, evaluate_classification
from examples.synthetic_fixture import build_synthetic_graph
from er_tp_dgp.metapaths import APTMetapathExtractor
from er_tp_dgp.prompt import PromptBuilder
from er_tp_dgp.schema import DatasetSchemaAudit
from er_tp_dgp.serialization import read_entities_jsonl, read_events_jsonl, write_jsonl
from er_tp_dgp.splits import TargetMetadata, check_leakage, time_based_split
from er_tp_dgp.theia import normalize_theia_event_type, theia_action_semantics
from er_tp_dgp.trimming import TemporalSecurityAwareTrimmer
from er_tp_dgp.validation import validate_evidence_paths, validate_graph, validate_ir
from er_tp_dgp.versioning import sanitized_llm_config, write_method_version_manifest
class PipelineTests(unittest.TestCase):
def test_event_reified_graph_preserves_event_and_causal_views(self):
    """The synthetic graph must expose both event-view and causal-view edges."""
    fixture = build_synthetic_graph()

    def has_edge(edges, wanted_type):
        # True when at least one edge in ``edges`` carries ``wanted_type``.
        return any(edge.edge_type == wanted_type for edge in edges)

    self.assertIn("event-write", fixture.events)
    # Event view: actor -> event -> object.
    self.assertTrue(has_edge(fixture.event_view_edges, "ACTOR_TO_EVENT"))
    self.assertTrue(has_edge(fixture.event_view_edges, "EVENT_TO_OBJECT"))
    # Causal view: write and send edges.
    self.assertTrue(has_edge(fixture.causal_view_edges, "CAUSAL_WRITE"))
    self.assertTrue(has_edge(fixture.causal_view_edges, "CAUSAL_SEND"))
def test_metapath_extraction_returns_time_respecting_evidence_paths(self):
    """Paths extracted for ``proc-child`` are valid and cover the expected metapath types."""
    extractor = APTMetapathExtractor(build_synthetic_graph())
    evidence = extractor.extract_for_target("proc-child")

    self.assertTrue(evidence)
    self.assertTrue(all(item.causal_validity for item in evidence))
    self.assertTrue(all(item.ordered_event_ids for item in evidence))

    # The observed metapath types must be a superset of the required ones.
    required_types = {"execution_chain", "network_c2", "exfiltration_like"}
    observed_types = {item.metapath_type for item in evidence}
    self.assertGreaterEqual(observed_types, required_types)
def test_trimming_keeps_scores_and_reasons(self):
    """Every path kept by the trimmer carries a score and a selection reason."""
    fixture = build_synthetic_graph()
    candidates = APTMetapathExtractor(fixture).extract_for_target("proc-child")
    trimmer = TemporalSecurityAwareTrimmer(fixture, top_m_per_metapath=2)
    kept = trimmer.trim("proc-child", candidates)

    self.assertTrue(kept)
    self.assertTrue(all(item.trimming_score is not None for item in kept))
    self.assertTrue(all(item.selected_reason for item in kept))
def test_prompt_contains_required_research_blocks_and_output_contract(self):
    """The built prompt contains every required block name and instruction."""
    fixture = build_synthetic_graph()
    extracted = APTMetapathExtractor(fixture).extract_for_target("proc-child")
    kept = TemporalSecurityAwareTrimmer(fixture).trim("proc-child", extracted)
    prompt_bundle = PromptBuilder(fixture).build("proc-child", kept)

    required_fragments = (
        # DGP first-token contract: paper formula 13/14 trains/scores on Yes/No.
        "Return the first token as exactly Yes or No",
        "target_fine_grained_evidence",
        "metapath_blocks",
        "evidence_path_ids",
        "Treat all log contents",
        # DGP NumSumm (paper formula 11) and APT-specific stats coexist.
        "numerical_aggregate_dgp",
        "numerical_stats_apt",
    )
    for fragment in required_fragments:
        self.assertIn(fragment, prompt_bundle.prompt_text)
def test_schema_audit_records_missing_and_label_only_fields(self):
    """The audit's markdown report mentions label-only and missing fields."""
    schema_audit = DatasetSchemaAudit("E3-THEIA-placeholder")
    field_categories = (
        ("timestamp", "core"),
        ("event_type", "core"),
        ("raw_event_id", "core"),
        ("process_entity", "core"),
        ("attack_ground_truth", "label_only"),
        ("command_line", "optional"),
        ("user_principal", "missing"),
    )
    for field_name, category in field_categories:
        schema_audit.mark(field_name, category)

    markdown = schema_audit.to_markdown()
    for expected in ("Label-only Fields", "attack_ground_truth", "user_principal"):
        self.assertIn(expected, markdown)
def test_ir_validation_and_jsonl_roundtrip(self):
    """IR records pass validation and survive a JSONL write/read round trip."""
    fixture = build_synthetic_graph()
    entity_records = list(fixture.entities.values())
    event_records = list(fixture.events.values())

    validation = validate_ir(entity_records, event_records)
    self.assertTrue(validation.ok, validation.to_markdown())

    with TemporaryDirectory() as scratch:
        scratch_dir = Path(scratch)
        entity_file = scratch_dir / "entities.jsonl"
        event_file = scratch_dir / "events.jsonl"

        write_jsonl(entity_file, entity_records)
        write_jsonl(event_file, event_records)
        entities_back = read_entities_jsonl(entity_file)
        events_back = read_events_jsonl(event_file)

        # Counts match and the first event keeps its raw id.
        self.assertEqual(len(entities_back), len(entity_records))
        self.assertEqual(len(events_back), len(event_records))
        self.assertEqual(events_back[0].raw_event_id, event_records[0].raw_event_id)
def test_graph_windows_and_lifecycle_helpers(self):
    """Time-window subgraphs, context windows and process helpers return the fixture's values."""
    fixture = build_synthetic_graph()

    # Host-scoped time window.
    time_slice = fixture.subgraph_by_time_window(host="h1", start_time=2.0, end_time=4.0)
    expected_events = {"event-create", "event-exec-file", "event-read"}
    self.assertEqual(set(time_slice.events), expected_events)

    # Target-centred context window.
    around_child = fixture.target_context_window("proc-child", lookback=0.5, lookahead=2.5)
    for event_id in ("event-create", "event-read"):
        self.assertIn(event_id, around_child.events)

    # Parent/child relationship and lifecycle summary.
    self.assertEqual(fixture.process_parent("proc-child"), "proc-parent")
    self.assertIn("proc-child", fixture.process_children("proc-parent"))
    lifecycle = fixture.entity_lifecycle("proc-child")
    self.assertEqual(lifecycle["num_events"], 4)

    self.assertTrue(validate_graph(fixture).ok)
def test_adapter_reports_schema_and_mismatches(self):
    """The explicit-IR adapter validates fixture records and reports known gaps."""
    fixture = build_synthetic_graph()
    # Entities first, then events.
    entity_rows = [
        {"record_type": "entity", "payload": asdict(item)}
        for item in fixture.entities.values()
    ]
    event_rows = [
        {"record_type": "event", "payload": asdict(item)}
        for item in fixture.events.values()
    ]

    ir_adapter = ExplicitIRAdapter("synthetic-ir", known_missing_fields={"user_principal"})
    adapted = ir_adapter.adapt(entity_rows + event_rows)

    self.assertTrue(adapted.validation_report.ok)
    self.assertIn("attack_ground_truth", adapted.mismatch_report.to_markdown())
    self.assertIn("user_principal", adapted.schema_audit.missing_fields)
def test_label_mapping_is_label_only_and_candidate_eval_is_separate(self):
    """Labels are kept out of prompts, and candidate recall is computed separately."""
    fixture = build_synthetic_graph()
    label_mapper = LabelMapper(fixture)
    label_store = label_mapper.from_malicious_event_ids(
        {"event-read", "event-send"},
        label_source="synthetic_event_ids_only",
    )
    label_mapper.add_high_confidence_benign_outside_windows(
        label_store,
        attack_windows=[(3.5, 5.5)],
        label_source="synthetic_outside_attack_window",
        target_type="EVENT",
    )

    self.assertFalse(label_store.get("event-send").prompt_allowed)
    self.assertEqual(label_store.get("proc-child").label, "malicious")
    # Building a LabelRecord with these positional arguments must raise.
    with self.assertRaises(ValueError):
        LabelRecord("x", "EVENT", "malicious", 1.0, "gt", True)

    generator = WeakSignalCandidateGenerator(fixture)
    process_candidates = generator.generate_process_candidates()
    outcome = evaluate_candidates(process_candidates, label_store, target_type="PROCESS")
    self.assertEqual(outcome.num_labeled_positive, 1)
    self.assertEqual(outcome.recall, 1.0)
def test_evidence_validation_after_trimming(self):
    """Paths that survive trimming still pass evidence-path validation."""
    fixture = build_synthetic_graph()
    extracted = APTMetapathExtractor(fixture).extract_for_target("proc-child")
    kept = TemporalSecurityAwareTrimmer(fixture).trim("proc-child", extracted)

    validation = validate_evidence_paths(fixture, kept)
    self.assertTrue(validation.ok, validation.to_markdown())
def test_method_registry_marks_only_graph_dgp_as_main(self):
    """The default registry validates falsy and flags methods as expected."""
    methods = default_method_registry()
    # The validator's result is expected to be falsy for the default registry.
    self.assertFalse(validate_method_registry(methods))

    self.assertTrue(methods["graph_dgp"].allowed_as_main)
    self.assertFalse(methods["target_only_llm"].allowed_as_main)
    self.assertFalse(methods["flat_log_llm"].uses_event_reified_graph)

    ablation = methods["without_evidence_ids"]
    self.assertTrue(ablation.uses_llm_classifier)
    self.assertFalse(ablation.uses_evidence_ids)
def test_classification_metrics_and_layered_report(self):
    """Two positives ranked above two negatives give perfect metrics and a layered report."""
    send_positive = PredictionRecord(
        target_id="event-send",
        target_type="EVENT",
        score=0.95,
        predicted_label="malicious",
        true_label="malicious",
        timestamp=5.0,
        campaign_id="camp-a",
        evidence_path_ids=("ep-1",),
        prompt_tokens=100,
        inference_cost=0.01,
        prompt_construction_time=0.2,
    )
    read_positive = PredictionRecord(
        target_id="event-read",
        target_type="EVENT",
        score=0.75,
        predicted_label="malicious",
        true_label="malicious",
        timestamp=4.0,
        campaign_id="camp-a",
        evidence_path_ids=("ep-2",),
        prompt_tokens=120,
        inference_cost=0.01,
        prompt_construction_time=0.3,
    )
    parent_negative = PredictionRecord(
        target_id="proc-parent",
        target_type="PROCESS",
        score=0.30,
        predicted_label="benign",
        true_label="benign",
        timestamp=2.0,
        prompt_tokens=80,
    )
    write_negative = PredictionRecord(
        target_id="event-write",
        target_type="EVENT",
        score=0.10,
        predicted_label="benign",
        true_label="benign",
        timestamp=1.0,
        prompt_tokens=70,
    )

    summary = evaluate_classification(
        [send_positive, read_positive, parent_negative, write_negative],
        k_values=(1, 2),
        recall_levels=(1.0,),
        attack_start_by_campaign={"camp-a": 3.0},
    )

    self.assertEqual(summary.num_positive, 2)
    self.assertEqual(summary.recall_at_k[2], 1.0)
    # Each of these scalar metrics is expected to equal exactly 1.0 here.
    for metric_name in (
        "auprc",
        "auroc",
        "attack_case_recall",
        "detection_delay",
        "evidence_path_hit_rate",
    ):
        self.assertEqual(getattr(summary, metric_name), 1.0)

    layered = LayeredEvaluationReport(candidate_generation=None, final_classification=summary)
    self.assertIn("final_classification", layered.to_dict())
def test_time_split_and_leakage_checks(self):
    """The time split assigns train/test by timestamp and the leakage check flags shared attributes."""
    # t1 and t3 share campaign, process, file path, prompt and summary id.
    shared_fields = dict(
        target_type="EVENT",
        host="h1",
        campaign_id="camp-a",
        process_ids=("pid-1",),
        file_paths=("/tmp/payload",),
        prompt_text="same prompt",
        summary_ids=("sum-1",),
    )
    early = TargetMetadata(
        target_id="t1",
        timestamp=1.0,
        raw_event_ids=("raw-1",),
        **shared_fields,
    )
    middle = TargetMetadata(
        target_id="t2",
        target_type="EVENT",
        timestamp=2.0,
        host="h1",
        campaign_id="camp-b",
        raw_event_ids=("raw-2",),
        process_ids=("pid-2",),
        prompt_text="validation prompt",
    )
    late = TargetMetadata(
        target_id="t3",
        timestamp=5.0,
        raw_event_ids=("raw-3",),
        **shared_fields,
    )
    all_targets = [early, middle, late]

    split = time_based_split(all_targets, train_until=1.0, validation_until=3.0)
    self.assertEqual(split.split_by_target["t1"].value, "train")
    self.assertEqual(split.split_by_target["t3"].value, "test")

    leakage = check_leakage(
        all_targets,
        split,
        ioc_file_paths={"/tmp/payload"},
        host_time_window=10.0,
    )
    rendered = leakage.to_markdown()
    self.assertFalse(leakage.ok)
    for leakage_kind in (
        "process_id_leakage",
        "duplicated_prompt_leakage",
        "campaign_leakage",
        "file_path_ioc_leakage",
    ):
        self.assertIn(leakage_kind, rendered)
def test_openai_compatible_response_parser(self):
    """A label line plus JSON body in a chat-completion response parses into structured output."""
    # First line is the label token; the JSON object follows on the next line.
    message_text = (
        "MALICIOUS\n"
        '{"score": 0.82, "predicted_label": "MALICIOUS", '
        '"involved_techniques": ["execution_chain"], '
        '"evidence_path_ids": ["ep-1"], '
        '"concise_explanation": "Evidence path supports the label.", '
        '"uncertainty": "low", '
        '"missing_fields": ["ground_truth_not_in_prompt"], '
        '"recommended_analyst_checks": ["review ep-1"]}'
    )
    api_response = {
        "choices": [{"message": {"content": message_text}}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    }

    parsed = parse_classification_output(extract_openai_compatible_text(api_response))

    self.assertEqual(parsed.first_token_label, "MALICIOUS")
    self.assertEqual(parsed.predicted_label, "MALICIOUS")
    self.assertEqual(parsed.score, 0.82)
    self.assertEqual(parsed.evidence_path_ids, ("ep-1",))
    self.assertEqual(parsed.involved_techniques, ("execution_chain",))
def test_llm_request_config_url_handling(self):
    """``completions_url`` gives one ``/v1/chat/completions`` path with or without ``/v1`` in the base URL."""
    remote = LLMRequestConfig(
        provider_type="api",
        base_url="https://example.test",
        model="model-a",
    )
    self.assertEqual(
        remote.completions_url(),
        "https://example.test/v1/chat/completions",
    )

    # A base URL that already ends in /v1 must not get a second /v1.
    local = LLMRequestConfig(
        provider_type="local",
        base_url="http://localhost:8000/v1",
        model="model-b",
    )
    self.assertEqual(
        local.completions_url(),
        "http://localhost:8000/v1/chat/completions",
    )
def test_llm_yaml_config_loading(self):
    """A YAML config written to disk loads into the expected request settings."""
    yaml_lines = [
        "provider: api",
        "base_url: https://example.test/v1",
        "model: model-a",
        "api_key_env: TEST_API_KEY",
        "timeout_seconds: 12",
        "temperature: 0.1",
        "max_tokens: 128",
        "user_agent: test-agent/1.0",
        "extra_headers:",
        " X-Test: 'yes'",
        "extra_body:",
        " seed: 7",
    ]
    with TemporaryDirectory() as scratch:
        yaml_file = Path(scratch) / "llm.yaml"
        yaml_file.write_text("\n".join(yaml_lines), encoding="utf-8")
        loaded = load_llm_config(yaml_file)

        self.assertEqual(loaded.provider_type, "api")
        self.assertEqual(loaded.completions_url(), "https://example.test/v1/chat/completions")
        self.assertEqual(loaded.model, "model-a")
        self.assertEqual(loaded.api_key_env, "TEST_API_KEY")
        self.assertEqual(loaded.timeout_seconds, 12)
        self.assertEqual(loaded.temperature, 0.1)
        self.assertEqual(loaded.max_tokens, 128)
        self.assertEqual(loaded.user_agent, "test-agent/1.0")
        # Nested mappings are exposed as dict-like attributes.
        self.assertEqual(loaded.extra_headers["X-Test"], "yes")
        self.assertEqual(loaded.extra_body["seed"], 7)
def test_theia_action_semantics_are_auditable(self):
    """THEIA event types map to the expected actions, hints and default risk."""
    exec_semantics = theia_action_semantics("EVENT_EXECUTE")
    self.assertEqual(exec_semantics.normalized_action, "EXEC")
    self.assertEqual(exec_semantics.canonical_action, "PROC_EXEC_FILE")
    self.assertIn("execution_chain", exec_semantics.metapath_hints)

    mprotect_semantics = theia_action_semantics("EVENT_MPROTECT")
    self.assertEqual(mprotect_semantics.canonical_action, "PROC_CHANGE_MEMORY_PROTECTION")
    self.assertIn("memory_context", mprotect_semantics.metapath_hints)
    self.assertEqual(mprotect_semantics.default_risk, "low")

    normalized_recv = normalize_theia_event_type("EVENT_RECVFROM")
    self.assertEqual(normalized_recv, "RECEIVE")
def test_method_version_manifest_redacts_llm_config(self):
    """The API key is redacted in the sanitized config and in the written manifest."""
    yaml_lines = [
        "provider: api",
        "base_url: https://example.test/v1",
        "model: model-a",
        "api_key: secret-value",
    ]
    with TemporaryDirectory() as scratch:
        scratch_root = Path(scratch)
        configs = scratch_root / "configs"
        configs.mkdir()
        llm_yaml = configs / "llm.yaml"
        llm_yaml.write_text("\n".join(yaml_lines), encoding="utf-8")

        redacted = sanitized_llm_config(llm_yaml)
        self.assertEqual(redacted["api_key"], "<redacted>")
        self.assertIn("sanitized_config_sha256", redacted)

        manifest_file = scratch_root / "reports" / "manifest.json"
        written = write_method_version_manifest(
            manifest_file,
            repo_root=".",
            llm_config_path=llm_yaml,
        )
        self.assertTrue(manifest_file.exists())
        self.assertEqual(written.version, "ER-TP-DGP-v0.1")
        self.assertEqual(written.llm_config["api_key"], "<redacted>")
def test_theia_candidate_universe_is_label_free_and_deterministic(self):
    """A small THEIA log yields one candidate, stable sampling and weak-signal anchors."""
    namespace = "com.bbn.tc.schema.avro.cdm18."
    uuid_key = namespace + "UUID"

    def cdm_record(record_type, payload):
        # Wrap a payload in the namespaced ``datum`` envelope.
        return {"datum": {namespace + record_type: payload}, "source": "test"}

    def cdm_event(uuid, event_type, timestamp_nanos, object_uuid, object_path=None):
        # Every event in this fixture is performed by subject-1.
        payload = {
            "uuid": uuid,
            "type": event_type,
            "timestampNanos": timestamp_nanos,
            "subject": {uuid_key: "subject-1"},
            "predicateObject": {uuid_key: object_uuid},
        }
        if object_path is not None:
            payload["predicateObjectPath"] = object_path
        return cdm_record("Event", payload)

    subject = cdm_record(
        "Subject",
        {
            "uuid": "subject-1",
            "hostId": "host-1",
            "type": "SUBJECT_PROCESS",
            "cmdLine": {"string": "python /tmp/dropper.py AAAABBBBCCCCDDDDEEEEFFFFGGGGHHHH"},
            "properties": {"map": {"path": "/home/admin/bin/python"}},
        },
    )
    netflow = cdm_record(
        "NetFlowObject",
        {
            "uuid": "flow-1",
            "localAddress": "10.0.0.2",
            "localPort": 4444,
            "remoteAddress": "8.8.8.8",
            "remotePort": 443,
        },
    )
    payload_file = cdm_record(
        "FileObject",
        {
            "uuid": "file-1",
            "baseObject": {"properties": {"map": {"path": "/tmp/payload"}}},
        },
    )
    log_records = [
        subject,
        netflow,
        payload_file,
        cdm_event("event-recv", "EVENT_RECVFROM", 100, "flow-1"),
        cdm_event("event-write", "EVENT_WRITE", 200, "file-1", "/tmp/payload"),
        cdm_event("event-exec", "EVENT_EXECUTE", 300, "file-1", "/tmp/payload"),
    ]

    with TemporaryDirectory() as scratch:
        log_file = Path(scratch) / "theia.json"
        serialized_lines = [json.dumps(item, sort_keys=True) for item in log_records]
        log_file.write_text("\n".join(serialized_lines) + "\n", encoding="utf-8")

        universe = build_theia_candidate_universe([log_file])
        profiles = universe.candidate_profiles(min_score=1.0, min_events=1)
        first_sample = stratified_sample(profiles, per_stratum=1, seed=11)
        second_sample = stratified_sample(profiles, per_stratum=1, seed=11)

        self.assertEqual(universe.events_seen, 3)
        self.assertEqual(len(profiles), 1)
        profile = profiles[0]
        self.assertEqual(profile.candidate_id, "subject-1")
        self.assertEqual(profile.write_then_execute_count, 1)
        self.assertEqual(profile.recv_then_write_count, 1)
        self.assertGreater(profile.external_flow_count, 0)
        self.assertIn("event-recv", profile.sample_raw_event_ids)
        self.assertIn("This is a label-free candidate universe", universe.to_markdown())
        # Same seed, same sample.
        self.assertEqual(
            [item.candidate_id for item in first_sample],
            [item.candidate_id for item in second_sample],
        )

        # End-to-end anchor must come from a weak-signal-triggering event,
        # never from ground truth. The first event ("event-recv") triggers
        # external_flow + network_activity, so it should be picked.
        anchor = select_anchor_for_candidate(profile, strategy="first_weak_signal")
        self.assertEqual(anchor.anchor_event_id, "event-recv")
        self.assertEqual(anchor.anchor_timestamp_nanos, 100)
        self.assertFalse(anchor.fallback_used)
        self.assertIn("network_activity", anchor.triggering_signals)

        lifecycle_anchors = select_anchors_for_lifecycle(profile, max_anchors=4)
        self.assertGreaterEqual(len(lifecycle_anchors), 1)
        self.assertEqual(lifecycle_anchors[0].anchor_event_id, "event-recv")
        # All anchors must come from weak-signal-triggering events or the
        # padding fallback — not from any external label source.
        self.assertTrue(all(item.anchor_event_id is not None for item in lifecycle_anchors))

        # Round-trip via JSONL to make sure the anchor selector also works
        # against the serialized form (this is what the e2e batch builder
        # consumes in production).
        profile_as_dict = profile.to_json_dict()
        dict_anchor = select_anchor_for_candidate(profile_as_dict, strategy="first_weak_signal")
        self.assertEqual(dict_anchor.anchor_event_id, "event-recv")

        # Backward compatibility: a row missing weak_signal_events must not
        # crash; it should fall back to the first observed event.
        row_without_signals = {
            "candidate_id": "legacy-1",
            "sample_raw_event_ids": ["legacy-event-1", "legacy-event-2"],
        }
        fallback_anchor = select_anchor_for_candidate(
            row_without_signals, strategy="first_weak_signal"
        )
        self.assertEqual(fallback_anchor.anchor_event_id, "legacy-event-1")
        self.assertTrue(fallback_anchor.fallback_used)
def test_end_to_end_evaluation_batch_uses_no_ground_truth(self):
    """Anchors in the end-to-end batch are the same with or without a label file."""
    positive_row = {
        "candidate_id": "subject-positive",
        "process_path": "/tmp/payload",
        "command_line": "./payload",
        "weak_signal_score": 5.0,
        "total_events": 10,
        "estimated_prompt_tokens": 800,
        "stratum": "execution_heavy",
        "sample_raw_event_ids": ["evt-pos-a"],
        "weak_signal_events": [
            {
                "event_id": "evt-pos-a",
                "timestamp_nanos": 1_000_000_000,
                "signals": ["external_flow"],
            }
        ],
        "first_event_id": "evt-pos-a",
        "first_event_timestamp_nanos": 1_000_000_000,
    }
    benign_row = {
        "candidate_id": "subject-benign",
        "process_path": "/usr/bin/firefox",
        "command_line": "/usr/bin/firefox",
        "weak_signal_score": 2.0,
        "total_events": 200,
        "estimated_prompt_tokens": 1500,
        "stratum": "browser_like",
        "sample_raw_event_ids": ["evt-ben-a"],
        "weak_signal_events": [
            {
                "event_id": "evt-ben-a",
                "timestamp_nanos": 2_000_000_000,
                "signals": ["network_activity"],
            }
        ],
        "first_event_id": "evt-ben-a",
        "first_event_timestamp_nanos": 2_000_000_000,
    }
    label_row = {
        "target_id": "subject-positive",
        "label": "malicious",
        "label_confidence": "high",
        "label_source": "fixture_oracle",
        "atom_id": "atom-1",
    }

    def to_jsonl(rows):
        # One sorted-key JSON document per line, with a trailing newline.
        return "\n".join(json.dumps(row, sort_keys=True) for row in rows) + "\n"

    with TemporaryDirectory() as tmp:
        universe_path = Path(tmp) / "universe.jsonl"
        universe_path.write_text(to_jsonl([positive_row, benign_row]), encoding="utf-8")
        label_path = Path(tmp) / "labels.jsonl"
        label_path.write_text(to_jsonl([label_row]), encoding="utf-8")

        batch_no_labels = build_end_to_end_evaluation_batch(
            candidate_universe_path=universe_path,
            label_lookup_path=None,
            anchor_strategy="first_weak_signal",
        )
        # No labels supplied → all targets are unlabeled. Anchor selection
        # is independent of label state.
        self.assertEqual(len(batch_no_labels.targets), 2)
        for target in batch_no_labels.targets:
            self.assertEqual(target.label, "unlabeled")
            self.assertEqual(target.label_source, "no_ground_truth_join")
            self.assertIsNotNone(target.anchor_event_id)
            self.assertIsNotNone(target.anchor_timestamp_nanos)
            self.assertIn(
                "End-to-end batch: anchor selected from raw-log weak signals",
                "\n".join(target.notes),
            )

        batch_with_labels = build_end_to_end_evaluation_batch(
            candidate_universe_path=universe_path,
            label_lookup_path=label_path,
            anchor_strategy="first_weak_signal",
        )
        labelled_by_id = {item.target_id: item for item in batch_with_labels.targets}
        labelled_positive = labelled_by_id["subject-positive"]
        self.assertEqual(labelled_positive.label, "malicious")
        self.assertEqual(labelled_positive.atom_id, "atom-1")
        self.assertEqual(labelled_by_id["subject-benign"].label, "unlabeled")

        # Critical invariant: anchor for the labeled positive must be
        # identical between the labeled and unlabeled runs — labels must
        # not leak into anchor selection.
        unlabelled_positive = next(
            item for item in batch_no_labels.targets if item.target_id == "subject-positive"
        )
        self.assertEqual(
            labelled_positive.anchor_event_id,
            unlabelled_positive.anchor_event_id,
        )
        self.assertEqual(
            labelled_positive.anchor_timestamp_nanos,
            unlabelled_positive.anchor_timestamp_nanos,
        )
def test_ground_truth_atom_extraction_is_label_only(self):
    """Only THEIA sections become atoms, and no atom is allowed into prompts."""
    # The FiveDirections section (3.4) must be dropped by the THEIA filter.
    report_text = (
        "\n"
        "3 Nation State\n"
        "3.3 20180410 1400 THEIA - Firefox Backdoor w/ Drakon In-Memory\n"
        "3.3.2 Event Log\n"
        "14:11 Received connection from firefox\n"
        "putfile ./deploy/archive/drakon.x64_128.55.12.1 /tmp/payload\n"
        "connect to 1.2.3.4:80 and www.allstate.com\n"
        "3.4 20180411 1000 FiveDirections - Firefox Backdoor w/ Drakon In-Memory\n"
        "not the target filter\n"
        "4 Common Threat\n"
        "4.6 20180410 1300 THEIA - Phishing E-mail w/ Link\n"
        "Connect to www.nasa.ng (208.75.117.3)\n"
    )
    extraction = extract_e3_ground_truth_atoms(report_text, target_filter="THEIA")

    self.assertEqual(len(extraction.atoms), 2)
    self.assertTrue(all(not atom.prompt_allowed for atom in extraction.atoms))

    nation_state_atom = extraction.atoms[0]
    self.assertEqual(nation_state_atom.attack_group, "nation_state")
    self.assertIn("1.2.3.4:80", nation_state_atom.ips)
    self.assertIn("/tmp/payload", nation_state_atom.file_paths)

    self.assertIn("www.nasa.ng", extraction.atoms[1].domains)
    self.assertIn("must not enter LLM prompts", extraction.to_markdown())
def test_theia_ground_truth_mapping_and_candidate_recall_are_label_only(self):
    """A matching connect event maps to a ground-truth atom and gives full process recall."""
    namespace = "com.bbn.tc.schema.avro.cdm18."
    uuid_key = namespace + "UUID"
    report_text = (
        "\n"
        "3 Nation State\n"
        "3.3 20180410 1400 THEIA - Firefox Backdoor w/ Drakon In-Memory\n"
        "connect to 8.8.8.8:443\n"
    )
    atoms = extract_e3_ground_truth_atoms(report_text, target_filter="THEIA").atoms

    def cdm_record(record_type, payload):
        # Wrap a payload in the namespaced ``datum`` envelope.
        return {"datum": {namespace + record_type: payload}, "source": "test"}

    subject = cdm_record(
        "Subject",
        {
            "uuid": "subject-gt",
            "hostId": "host-1",
            "type": "SUBJECT_PROCESS",
            "cmdLine": {"string": "/usr/bin/firefox"},
            "properties": {"map": {"path": "/home/admin/Downloads/firefox/firefox"}},
        },
    )
    netflow = cdm_record(
        "NetFlowObject",
        {
            "uuid": "flow-gt",
            "localAddress": "10.0.0.2",
            "localPort": 4444,
            "remoteAddress": "8.8.8.8",
            "remotePort": 443,
        },
    )
    connect_event = cdm_record(
        "Event",
        {
            "uuid": "event-gt",
            "type": "EVENT_CONNECT",
            "timestampNanos": 1_523_379_600_000_000_000,
            "subject": {uuid_key: "subject-gt"},
            "predicateObject": {uuid_key: "flow-gt"},
        },
    )
    log_records = [subject, netflow, connect_event]

    with TemporaryDirectory() as scratch:
        scratch_dir = Path(scratch)
        log_file = scratch_dir / "theia.json"
        serialized_lines = [json.dumps(item, sort_keys=True) for item in log_records]
        log_file.write_text("\n".join(serialized_lines) + "\n", encoding="utf-8")

        candidates_file = scratch_dir / "candidates.jsonl"
        candidates_file.write_text(
            json.dumps({"candidate_id": "subject-gt"}) + "\n",
            encoding="utf-8",
        )

        mapping = match_theia_ground_truth_atoms([log_file], atoms, min_score=3.0)
        recall_report = evaluate_candidate_recall(
            candidates_file,
            mapping.process_labels,
            mapping.event_matches,
        )

        self.assertEqual(len(mapping.event_matches), 1)
        matched_event = mapping.event_matches[0]
        self.assertEqual(matched_event.raw_event_id, "event-gt")
        self.assertFalse(matched_event.prompt_allowed)
        self.assertEqual(mapping.process_labels[0].subject_uuid, "subject-gt")
        self.assertEqual(recall_report.process_recall, 1.0)
        self.assertIn("label/evaluation-only", mapping.to_markdown())
def test_evaluation_batch_keeps_labels_out_of_prompt_inputs(self):
    """The batch holds one positive and one hard-negative proxy, with no label fields allowed in prompts."""
    positive_label_row = {
        "subject_uuid": "proc-positive",
        "confidence": "high",
        "atom_id": "atom-1",
        "matched_event_ids": ["event-positive"],
    }
    event_match_row = {
        "raw_event_id": "event-positive",
        "score": 9.0,
        "subject_path": "/tmp/payload",
        "command_line": "/tmp/payload",
    }
    candidate_rows = [
        {
            "candidate_id": "proc-positive",
            "sample_raw_event_ids": ["event-positive"],
            "weak_signal_score": 10.0,
        },
        {
            "candidate_id": "proc-hard-negative",
            "sample_raw_event_ids": ["event-negative"],
            "weak_signal_score": 8.0,
            "process_path": "/usr/bin/ssh",
        },
    ]

    with TemporaryDirectory() as scratch:
        scratch_dir = Path(scratch)
        process_labels = scratch_dir / "process_labels.jsonl"
        event_matches = scratch_dir / "event_matches.jsonl"
        candidates = scratch_dir / "candidates.jsonl"
        all_labels = scratch_dir / "all_labels.jsonl"

        process_labels.write_text(json.dumps(positive_label_row) + "\n", encoding="utf-8")
        # The "all mapped labels" file is a copy of the positive labels file.
        all_labels.write_text(process_labels.read_text(encoding="utf-8"), encoding="utf-8")
        event_matches.write_text(json.dumps(event_match_row) + "\n", encoding="utf-8")
        candidates.write_text(
            "\n".join([json.dumps(row) for row in candidate_rows]) + "\n",
            encoding="utf-8",
        )

        batch = build_evaluation_batch(
            positive_process_labels_path=process_labels,
            positive_event_matches_path=event_matches,
            candidate_universe_path=candidates,
            all_mapped_process_labels_path=all_labels,
            num_positives=1,
            num_hard_negative_proxies=1,
            max_hard_negative_events=1000,
            seed=1,
        )
        self.assertEqual(len(batch.targets), 2)

        malicious_targets = [item for item in batch.targets if item.label == "malicious"]
        proxy_targets = [item for item in batch.targets if item.cohort == "hard_negative_proxy"]
        positive = malicious_targets[0]
        negative = proxy_targets[0]

        self.assertFalse(positive.prompt_allowed_label_fields)
        self.assertEqual(positive.anchor_event_id, "event-positive")
        self.assertEqual(negative.label, "benign_proxy")
        self.assertIn("not a high-confidence benign label", " ".join(negative.notes))
class DGPComponentTests(unittest.TestCase):
"""Coverage for DGP-aligned algorithmic components added in Phase 2/3."""
def test_node_text_summarizer_caches_by_sha256(self):
    """Repeated text is summarized once; each distinct text gets one cache file."""
    from er_tp_dgp.text_summarizer import NodeTextSummarizer, SummarizerConfig

    class _RecordingLLM:
        """Fake LLM that logs every call and echoes the prompt tail."""

        def __init__(self) -> None:
            self.calls: list[tuple[str, int]] = []

        def complete(self, prompt: str, *, max_tokens: int) -> str:
            self.calls.append((prompt, max_tokens))
            return "summary-of-" + prompt[-12:]

    python_cmd = "command_line=/usr/bin/python /tmp/payload --sync"
    bash_cmd = "command_line=/bin/bash -i"

    with TemporaryDirectory() as tmp:
        cache_root = Path(tmp)
        settings = SummarizerConfig(b_node=10, model_name="testmodel", cache_dir=cache_root)
        fake_llm = _RecordingLLM()
        summarizer = NodeTextSummarizer(llm=fake_llm, config=settings)

        initial = summarizer.summarize(python_cmd)
        repeated = summarizer.summarize(python_cmd)
        different = summarizer.summarize(bash_cmd)

        self.assertEqual(initial, repeated)
        self.assertEqual(len(fake_llm.calls), 2, "duplicate text must be served from cache")
        self.assertNotEqual(initial, different)
        # One cache entry per distinct input text.
        self.assertEqual(len(list(cache_root.iterdir())), 2)
def test_node_summarize_batch_concurrency_and_cache(self):
    """Concurrent batch must (a) call LLM exactly once per unique key,
    (b) preserve input order, (c) be safe against same-key duplicates.
    """
    # Plain function-scope imports instead of the ``__import__("threading")``
    # hack: the dependency is explicit and visible to linters.
    import threading
    import time

    from er_tp_dgp.text_summarizer import NodeTextSummarizer, SummarizerConfig

    class _SlowLLM:
        """Fake LLM that records prompts under a lock and sleeps per call."""

        def __init__(self) -> None:
            self.calls: list[str] = []
            # Guards ``calls`` because ``complete`` may run on worker threads.
            self._lock = threading.Lock()

        def complete(self, prompt: str, *, max_tokens: int) -> str:  # noqa: ARG002
            with self._lock:
                self.calls.append(prompt)
            # Sleep simulates network latency. With 4 workers and 5 unique
            # keys, sequential would take 5 * 0.05 = 0.25s; concurrent
            # should be roughly 2 * 0.05 ~= 0.10s.
            time.sleep(0.05)
            return f"sum:{prompt[-8:]}"

    with TemporaryDirectory() as tmp:
        cfg = SummarizerConfig(
            b_node=10, model_name="t", cache_dir=Path(tmp), max_workers=4
        )
        llm = _SlowLLM()
        summ = NodeTextSummarizer(llm=llm, config=cfg)
        # 5 unique + 2 duplicates of unique[0] = 7 inputs but only 5 LLM calls.
        inputs = [
            "alpha-text-aaaa",
            "beta-text-bbbb",
            "gamma-text-gggg",
            "alpha-text-aaaa",  # duplicate
            "delta-text-dddd",
            "epsilon-text-eeee",
            "alpha-text-aaaa",  # duplicate
        ]
        t0 = time.monotonic()
        results = summ.summarize_batch(inputs)
        elapsed = time.monotonic() - t0

        self.assertEqual(len(results), len(inputs))
        # Duplicates of inputs[0] must resolve to the same summary.
        self.assertEqual(results[0], results[3])
        self.assertEqual(results[0], results[6])
        self.assertEqual(len(llm.calls), 5, f"expected 5 unique LLM calls, got {len(llm.calls)}")
        # Concurrency with 4 workers on 5 calls should be ~0.10s, not ~0.25s.
        # NOTE(review): wall-clock bound; may be flaky on heavily loaded CI hosts.
        self.assertLess(elapsed, 0.20, f"batch was not concurrent (elapsed={elapsed:.3f}s)")

        # Re-running the same batch must hit the disk cache and skip the LLM.
        results2 = summ.summarize_batch(inputs)
        self.assertEqual(results, results2)
        self.assertEqual(len(llm.calls), 5)
def test_metapath_text_summarizer_combines_neighbor_summaries(self):
    """Neighbor summaries give a non-empty result; an empty list gives an empty string."""
    from er_tp_dgp.text_summarizer import MetapathTextSummarizer, SummarizerConfig

    class _Joiner:
        """Fake LLM that returns up to 60 characters of the prompt before its last newline."""

        def complete(self, prompt: str, *, max_tokens: int) -> str:  # noqa: ARG002
            return prompt.rsplit("\n", 1)[-2][:60]

    with TemporaryDirectory() as tmp:
        settings = SummarizerConfig(b_meta=8, model_name="t", cache_dir=Path(tmp))
        summarizer = MetapathTextSummarizer(llm=_Joiner(), config=settings)

        combined = summarizer.summarize_metapath(
            "execution_chain", ["proc=python", "proc=bash"]
        )
        self.assertNotEqual(combined, "")

        without_neighbors = summarizer.summarize_metapath("execution_chain", [])
        self.assertEqual(without_neighbors, "")
def test_markov_diffusion_trimmer_selects_paths_through_top_m_neighbors(self):
    """The MDK trimmer keeps causally valid paths and tags each with an ``mdk(`` reason."""
    # The test is skipped when numpy is not installed.
    try:
        import numpy  # noqa: F401
    except ImportError:
        self.skipTest("numpy not installed; skipping MDK trimmer test")

    from er_tp_dgp.diffusion_trimmer import (
        HashingEmbedder,
        MarkovDiffusionTrimmer,
        MDKConfig,
    )

    fixture = build_synthetic_graph()
    extracted = APTMetapathExtractor(fixture).extract_for_target("proc-child")
    mdk_trimmer = MarkovDiffusionTrimmer(
        fixture,
        embedder=HashingEmbedder(dim=32),
        config=MDKConfig(k_hops=3, top_m=2),
    )
    survivors = mdk_trimmer.trim("proc-child", extracted)

    self.assertGreater(len(survivors), 0)
    for item in survivors:
        self.assertTrue(item.causal_validity)
        self.assertIsNotNone(item.selected_reason)
        self.assertIn("mdk(", item.selected_reason or "")
def test_numerical_aggregator_outputs_means_and_histograms(self):
    """Aggregating the execution-chain paths gives normalized type and action histograms."""
    from er_tp_dgp.numerical_aggregator import NumericalAggregator

    fixture = build_synthetic_graph()
    extracted = APTMetapathExtractor(fixture).extract_for_target("proc-child")
    # Use the execution-chain block specifically.
    execution_only = [item for item in extracted if item.metapath_type == "execution_chain"]

    aggregate = NumericalAggregator(fixture).aggregate("execution_chain", execution_only)
    self.assertGreater(aggregate.neighbor_count, 0)

    prompt_view = aggregate.to_prompt_dict()
    self.assertIn("node_type_hist", prompt_view)
    self.assertIn("action_hist", prompt_view)
    # Each histogram's values must sum to 1 (to 5 decimal places).
    node_type_total = sum(prompt_view["node_type_hist"].values())
    action_total = sum(prompt_view["action_hist"].values())
    self.assertAlmostEqual(node_type_total, 1.0, places=5)
    self.assertAlmostEqual(action_total, 1.0, places=5)
def test_score_from_top_logprobs_softmax(self):
    """Yes/No logprobs give a score; with neither token present the fallback is used."""
    from er_tp_dgp.scoring import score_from_top_logprobs

    top_logprobs = [
        {"token": "Yes", "logprob": -0.10},
        {"token": " No", "logprob": -2.30},
        {"token": "the", "logprob": -5.0},
    ]
    scored = score_from_top_logprobs(top_logprobs)
    self.assertIsNotNone(scored.score)
    self.assertGreater(scored.score, 0.85)
    self.assertEqual(scored.matched_yes_token, "Yes")
    self.assertEqual(scored.matched_no_token, " No")
    self.assertFalse(scored.fallback_used)

    # Neither a Yes nor a No token: no score, fallback flagged.
    unscored = score_from_top_logprobs([{"token": "the", "logprob": -1.0}])
    self.assertIsNone(unscored.score)
    self.assertTrue(unscored.fallback_used)
def test_resolve_auto_model_class_routes_by_architecture(self):
    """Routing must pick image_text_to_text for Qwen3.5-style multimodal."""
    try:
        from transformers import (  # noqa: F401
            AutoModelForCausalLM as ACL,
            AutoModelForImageTextToText as AITT,
            AutoModelForSeq2SeqLM as AS2S,
        )
    except ImportError:
        # Also reached when transformers is installed but lacks one of the names.
        self.skipTest("transformers not installed; skipping resolver test")

    from er_tp_dgp.llm import _resolve_auto_model_class

    class _FakeCfg:
        """Minimal config stand-in exposing ``architectures`` and ``vision_config``."""

        def __init__(self, architectures, has_vision):
            self.architectures = architectures
            self.vision_config = {"hidden_size": 1} if has_vision else None

    # (requested mode, architectures, has vision config, expected auto class)
    routing_cases = (
        ("auto", ["Qwen3ForCausalLM"], False, ACL),
        # Multimodal conditional generation with vision_config (Qwen3.5-27B).
        ("auto", ["Qwen3_5ForConditionalGeneration"], True, AITT),
        # Conditional generation without vision (T5-style).
        ("auto", ["T5ForConditionalGeneration"], False, AS2S),
        # Explicit override wins regardless of config.
        ("causal_lm", ["Qwen3_5ForConditionalGeneration"], True, ACL),
    )
    for mode, architectures, has_vision, expected_class in routing_cases:
        resolved = _resolve_auto_model_class(mode, _FakeCfg(architectures, has_vision))
        self.assertIs(resolved, expected_class)
def test_yes_no_protocol_parses_to_canonical_labels(self):
    """A leading "Yes" parses to MALICIOUS and a leading "No" to BENIGN."""
    yes_reply = "Yes\n{\"predicted_label\": \"Yes\", \"evidence_path_ids\": [\"ep-1\"]}"
    parsed_yes = parse_classification_output(yes_reply)
    self.assertEqual(parsed_yes.first_token_label, "MALICIOUS")
    self.assertEqual(parsed_yes.predicted_label, "MALICIOUS")

    no_reply = "No\n{}"
    parsed_no = parse_classification_output(no_reply)
    self.assertEqual(parsed_no.first_token_label, "BENIGN")
def test_dgp_prompt_block_layout(self):
    """With both summarization switches off, only the concat summary appears.

    The prompt text must contain ``path_summary_concat`` and must not contain
    the quoted ``"text_summary"`` key.
    """
    from er_tp_dgp.prompt import PromptComponentSwitches

    target_id = "proc-child"
    synthetic = build_synthetic_graph()
    extracted = APTMetapathExtractor(synthetic).extract_for_target(target_id)
    trimmed = TemporalSecurityAwareTrimmer(synthetic).trim(target_id, extracted)

    # without_path_summ ablation: drop LLM PathSumm, keep concat.
    ablation_switches = PromptComponentSwitches(
        use_text_summarization=False,
        use_path_summarization_llm=False,
    )
    builder = PromptBuilder(synthetic, switches=ablation_switches)
    ablated_prompt = builder.build(target_id, trimmed)

    self.assertIn("path_summary_concat", ablated_prompt.prompt_text)
    self.assertNotIn("\"text_summary\"", ablated_prompt.prompt_text)
class AblationRegistryTests(unittest.TestCase):
    """Checks on the default method registry and on main-variant role validation."""

    def test_registry_contains_dgp_paper_aligned_ablations(self):
        """The default registry exposes ``graph_dgp`` and four ``without_dgp_*`` variants."""
        expected_variants = (
            "graph_dgp",
            "without_dgp_text_summ",
            "without_dgp_mdk",
            "without_dgp_path_summ",
            "without_dgp_num_summ",
        )
        available = default_method_registry()
        for variant_name in expected_variants:
            self.assertIn(variant_name, available, f"missing variant {variant_name}")

    def test_main_variant_requires_all_four_dgp_components(self):
        """A MAIN variant with text summarization off must fail ``validate_role``.

        The first reported issue has to name ``uses_dgp_text_summarization``.
        """
        # Build a deliberately broken main variant to confirm validation fires.
        from er_tp_dgp.experiments import MethodFamily, MethodVariant

        # Start with every capability flag enabled ...
        capability_flags = dict.fromkeys(
            (
                "uses_event_reified_graph",
                "uses_target_fine_grained",
                "uses_local_context",
                "uses_time_respecting_metapaths",
                "uses_temporal_trimming",
                "uses_security_aware_trimming",
                "uses_metapath_summary",
                "uses_node_level_summary",
                "uses_numerical_summary",
                "uses_evidence_ids",
                "uses_llm_classifier",
                "uses_dgp_text_summarization",
                "uses_dgp_diffusion_trimming",
                "uses_dgp_path_summarization_llm",
                "uses_dgp_numerical_aggregation",
            ),
            True,
        )
        # ... then switch off exactly one: the deliberate gap.
        capability_flags["uses_dgp_text_summarization"] = False

        broken_variant = MethodVariant(
            name="graph_dgp_broken",
            family=MethodFamily.MAIN,
            description="missing text summ",
            allowed_as_main=True,
            **capability_flags,
        )
        problems = broken_variant.validate_role()
        self.assertTrue(problems)
        self.assertIn("uses_dgp_text_summarization", problems[0])
class MultiAnchorWindowIRTests(unittest.TestCase):
    """Phase 1.5: build_multi_target_window_irs single-scan multi-anchor builder."""

    # Avro namespace prefix used for record-type and UUID-reference keys below.
    _CDM18_NAMESPACE = "com.bbn.tc.schema.avro.cdm18"

    def _write_cdm18(self, dir_path: Path, lines: list[dict]) -> Path:
        """Write CDM18-shaped JSONL the same way iter_theia_records expects to read it.

        Each dict in ``lines`` is serialized as one JSON document per line into
        ``dir_path / "ta1-theia-fake.json"``; that path is returned.
        """
        path = dir_path / "ta1-theia-fake.json"
        with path.open("w", encoding="utf-8") as f:
            f.writelines(json.dumps(line) + "\n" for line in lines)
        return path

    def _record(self, full_type: str, payload: dict) -> dict:
        """Wrap ``payload`` as ``{"datum": {"<namespace>.<full_type>": payload}}``."""
        return {"datum": {f"{self._CDM18_NAMESPACE}.{full_type}": payload}}

    def _event(self, uuid, event_type, timestamp_nanos, subject, predicate_object=None):
        """Build an ``Event`` record.

        ``subject`` and the optional ``predicate_object`` are UUID strings; each
        is wrapped in a namespaced ``UUID`` reference dict. The
        ``predicateObject`` key is omitted entirely when ``predicate_object``
        is ``None``.
        """
        uuid_key = f"{self._CDM18_NAMESPACE}.UUID"
        payload = {
            "uuid": uuid,
            "type": event_type,
            "timestampNanos": timestamp_nanos,
            "subject": {uuid_key: subject},
        }
        if predicate_object is not None:
            payload["predicateObject"] = {uuid_key: predicate_object}
        return self._record("Event", payload)

    def _anchor(self, anchor_event_uuid: str, seconds: float = 1.0) -> dict:
        """Return a fresh anchor spec with a symmetric lookback/lookahead of ``seconds``.

        A new dict is built on every call so repeated builder invocations never
        share (and cannot be affected by mutation of) the same spec object.
        """
        return {
            "anchor_event_uuid": anchor_event_uuid,
            "lookback_seconds": seconds,
            "lookahead_seconds": seconds,
        }

    def test_two_anchors_one_pass_demuxes_events_per_window(self):
        """Each anchor's window holds its own anchor and nearby event, nothing foreign.

        Also checks that one cache file is written per anchor and that a second
        call with the same arguments returns the same event ids for both anchors.
        """
        from er_tp_dgp.theia import build_multi_target_window_irs

        with TemporaryDirectory() as tmp:
            tmp_p = Path(tmp)
            # Two anchor events at very different times.
            #   A0 at t=1_000_000_000_000 (ns) -> window [t-1s, t+1s]
            #   A1 at t=5_000_000_000_000 (ns) -> window [t-1s, t+1s]
            # Plus events near A0 (in window), near A1 (in window), and one far
            # from both (out of all windows).
            host = self._record("Host", {"uuid": "H", "hostName": "h1"})
            principal = self._record("Principal", {"uuid": "P"})
            subj_a = self._record(
                "Subject",
                {"uuid": "Sa", "type": "SUBJECT_PROCESS", "cmdLine": {"string": "/bin/a"}},
            )
            subj_b = self._record(
                "Subject",
                {"uuid": "Sb", "type": "SUBJECT_PROCESS", "cmdLine": {"string": "/bin/b"}},
            )
            file_obj = self._record("FileObject", {"uuid": "F", "baseObject": {}})
            anchor0 = self._event("A0", "EVENT_OPEN", 1_000_000_000_000, "Sa", "F")
            near_a0 = self._event("E_NEAR_A0", "EVENT_READ", 1_000_500_000_000, "Sa", "F")
            anchor1 = self._event("A1", "EVENT_EXECUTE", 5_000_000_000_000, "Sb")
            near_a1 = self._event("E_NEAR_A1", "EVENT_READ", 5_000_500_000_000, "Sb", "F")
            far = self._event("E_FAR", "EVENT_READ", 9_000_000_000_000, "Sa")
            json_path = self._write_cdm18(
                tmp_p,
                [host, principal, subj_a, subj_b, file_obj, anchor0, near_a0, anchor1, near_a1, far],
            )
            cache_dir = tmp_p / "cache"
            results = build_multi_target_window_irs(
                [json_path],
                anchors=[self._anchor("A0"), self._anchor("A1")],
                cache_dir=cache_dir,
            )
            self.assertIn("A0", results)
            self.assertIn("A1", results)
            self.assertEqual(results["A0"].target_subject_id, "Sa")
            self.assertEqual(results["A1"].target_subject_id, "Sb")

            # A0 window must contain only A0 + E_NEAR_A0.
            a0_event_ids = {e.event_id for e in results["A0"].events}
            self.assertIn("A0", a0_event_ids)
            self.assertIn("E_NEAR_A0", a0_event_ids)
            self.assertNotIn("A1", a0_event_ids)
            self.assertNotIn("E_NEAR_A1", a0_event_ids)
            self.assertNotIn("E_FAR", a0_event_ids)

            # A1 window must contain only A1 + E_NEAR_A1.
            a1_event_ids = {e.event_id for e in results["A1"].events}
            self.assertIn("A1", a1_event_ids)
            self.assertIn("E_NEAR_A1", a1_event_ids)
            self.assertNotIn("A0", a1_event_ids)
            # Mirror of the A0-side check: A0's neighbor must not leak into A1.
            self.assertNotIn("E_NEAR_A0", a1_event_ids)
            self.assertNotIn("E_FAR", a1_event_ids)

            # Cache files must exist for both anchors.
            cache_files = list(cache_dir.glob("*.json.gz"))
            self.assertEqual(len(cache_files), 2)

            # Re-running with same args must hit the cache (no scan needed).
            cached_results = build_multi_target_window_irs(
                [json_path],
                anchors=[self._anchor("A0"), self._anchor("A1")],
                cache_dir=cache_dir,
            )
            self.assertEqual(
                {e.event_id for e in cached_results["A0"].events},
                a0_event_ids,
            )
            # The second anchor's window must survive the re-run unchanged too.
            self.assertEqual(
                {e.event_id for e in cached_results["A1"].events},
                a1_event_ids,
            )

    def test_missing_anchor_is_silently_omitted(self):
        """An anchor uuid absent from the corpus is left out of the result mapping."""
        from er_tp_dgp.theia import build_multi_target_window_irs

        with TemporaryDirectory() as tmp:
            tmp_p = Path(tmp)
            # Corpus contains exactly one anchor; the second is not present.
            host = self._record("Host", {"uuid": "H"})
            subj = self._record("Subject", {"uuid": "S", "type": "SUBJECT_PROCESS"})
            anchor0 = self._event("A0", "EVENT_OPEN", 1_000_000_000_000, "S")
            json_path = self._write_cdm18(tmp_p, [host, subj, anchor0])
            results = build_multi_target_window_irs(
                [json_path],
                anchors=[self._anchor("A0"), self._anchor("NEVER_SEEN")],
                cache_dir=None,
            )
            self.assertIn("A0", results)
            self.assertNotIn("NEVER_SEEN", results)

    def test_warm_cache_serves_without_scanning(self):
        """If every anchor has a cache hit, no scan is done. Use a path that
        does not exist on disk to prove no scan occurs (would otherwise crash).
        """
        from er_tp_dgp.theia import (
            TheiaWindowIR,
            _save_window_ir_to_cache,
            _multi_window_cache_key,
            build_multi_target_window_irs,
        )
        from er_tp_dgp.ir import EntityNode, EventNode

        with TemporaryDirectory() as tmp:
            tmp_p = Path(tmp)
            cache_dir = tmp_p / "cache"
            cache_dir.mkdir()
            # Pre-populate the cache for anchor 'A0' with a fake window.
            ent = EntityNode(node_id="S", node_type="PROCESS", stable_name="x", dataset="syn")
            ev = EventNode(
                event_id="A0", raw_event_id="A0", timestamp=1.0, action="OPEN",
                actor_entity_id="S", object_entity_id=None, host=None,
                raw_event_type="EVENT_OPEN", normalized_action="OPEN",
            )
            window = TheiaWindowIR(
                target_event_id="A0",
                target_subject_id="S",
                start_timestamp_nanos=0,
                end_timestamp_nanos=2_000_000_000,
                entities=(ent,),
                events=(ev,),
                schema_gaps=(),
            )
            fake_json_path = tmp_p / "imaginary.json"  # does not exist
            # Make the premise of this test explicit rather than implied.
            self.assertFalse(fake_json_path.exists())
            cache_key = _multi_window_cache_key(
                "A0", 1.0, 1.0, "DARPA_TC_E3_THEIA", [fake_json_path]
            )
            _save_window_ir_to_cache(cache_dir / f"{cache_key}.json.gz", window)
            # If the function tries to open the imaginary file, we'd get
            # FileNotFoundError. Cache hit must mean no scan.
            results = build_multi_target_window_irs(
                [fake_json_path],
                anchors=[self._anchor("A0")],
                cache_dir=cache_dir,
            )
            self.assertIn("A0", results)
            self.assertEqual(results["A0"].target_subject_id, "S")
class WindowIRCacheTests(unittest.TestCase):
    """Phase 1 IR cache round-trip."""

    def test_window_ir_cache_round_trip(self):
        """Saving then loading a window preserves its entity, event, target id and gaps."""
        from dataclasses import asdict as _asdict

        from er_tp_dgp.ir import EntityNode, EventNode
        from er_tp_dgp.theia import (
            TheiaWindowIR,
            _load_window_ir_from_cache,
            _save_window_ir_to_cache,
        )

        entity_fields = {
            "node_id": "e1",
            "node_type": "PROCESS",
            "stable_name": "/bin/bash",
            "dataset": "syn",
            "host": "h",
            "text_fields": {"command_line": "bash -i"},
        }
        event_fields = {
            "event_id": "ev1",
            "raw_event_id": "ev1",
            "timestamp": 1.0,
            "action": "EXEC",
            "actor_entity_id": "e1",
            "object_entity_id": None,
            "host": "h",
            "raw_event_type": "EVENT_EXEC",
            "normalized_action": "EXEC",
        }
        process_entity = EntityNode(**entity_fields)
        exec_event = EventNode(**event_fields)
        original_window = TheiaWindowIR(
            target_event_id="ev1",
            target_subject_id="e1",
            start_timestamp_nanos=0,
            end_timestamp_nanos=2_000_000_000,
            entities=(process_entity,),
            events=(exec_event,),
            schema_gaps=("foo",),
        )

        with TemporaryDirectory() as scratch:
            snapshot_path = Path(scratch) / "snap.json.gz"
            _save_window_ir_to_cache(snapshot_path, original_window)
            self.assertTrue(snapshot_path.exists())

            restored = _load_window_ir_from_cache(snapshot_path)
            # Compare as plain dicts, field by field.
            self.assertEqual(_asdict(restored.entities[0]), _asdict(process_entity))
            self.assertEqual(_asdict(restored.events[0]), _asdict(exec_event))
            self.assertEqual(restored.target_event_id, "ev1")
            self.assertEqual(restored.schema_gaps, ("foo",))
# Run every TestCase in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()