import unittest import json from dataclasses import asdict from pathlib import Path from tempfile import TemporaryDirectory from er_tp_dgp.adapters import ExplicitIRAdapter from er_tp_dgp.candidate_universe import ( build_theia_candidate_universe, select_anchor_for_candidate, select_anchors_for_lifecycle, stratified_sample, ) from er_tp_dgp.candidates import WeakSignalCandidateGenerator, evaluate_candidates from er_tp_dgp.experiments import default_method_registry, validate_method_registry from er_tp_dgp.evaluation_batch import ( build_end_to_end_evaluation_batch, build_evaluation_batch, ) from er_tp_dgp.ground_truth import extract_e3_ground_truth_atoms from er_tp_dgp.ground_truth_mapping import ( evaluate_candidate_recall, match_theia_ground_truth_atoms, ) from er_tp_dgp.labels import LabelMapper, LabelRecord from er_tp_dgp.llm import ( LLMRequestConfig, extract_openai_compatible_text, parse_classification_output, ) from er_tp_dgp.llm_config import load_llm_config from er_tp_dgp.metrics import LayeredEvaluationReport, PredictionRecord, evaluate_classification from examples.synthetic_fixture import build_synthetic_graph from er_tp_dgp.metapaths import APTMetapathExtractor from er_tp_dgp.prompt import PromptBuilder from er_tp_dgp.schema import DatasetSchemaAudit from er_tp_dgp.serialization import read_entities_jsonl, read_events_jsonl, write_jsonl from er_tp_dgp.splits import TargetMetadata, check_leakage, time_based_split from er_tp_dgp.theia import normalize_theia_event_type, theia_action_semantics from er_tp_dgp.trimming import TemporalSecurityAwareTrimmer from er_tp_dgp.validation import validate_evidence_paths, validate_graph, validate_ir from er_tp_dgp.versioning import sanitized_llm_config, write_method_version_manifest class PipelineTests(unittest.TestCase): def test_event_reified_graph_preserves_event_and_causal_views(self): graph = build_synthetic_graph() self.assertIn("event-write", graph.events) self.assertTrue(any(edge.edge_type == "ACTOR_TO_EVENT" for edge in graph.event_view_edges)) self.assertTrue(any(edge.edge_type == "EVENT_TO_OBJECT" for edge in graph.event_view_edges)) self.assertTrue(any(edge.edge_type == "CAUSAL_WRITE" for edge in graph.causal_view_edges)) self.assertTrue(any(edge.edge_type == "CAUSAL_SEND" for edge in graph.causal_view_edges)) def test_metapath_extraction_returns_time_respecting_evidence_paths(self): graph = build_synthetic_graph() paths = APTMetapathExtractor(graph).extract_for_target("proc-child") self.assertTrue(paths) self.assertTrue(all(path.causal_validity for path in paths)) self.assertTrue(all(path.ordered_event_ids for path in paths)) self.assertGreaterEqual( {path.metapath_type for path in paths}, { "execution_chain", "network_c2", "exfiltration_like", }, ) def test_trimming_keeps_scores_and_reasons(self): graph = build_synthetic_graph() paths = APTMetapathExtractor(graph).extract_for_target("proc-child") selected = TemporalSecurityAwareTrimmer(graph, top_m_per_metapath=2).trim( "proc-child", paths ) self.assertTrue(selected) self.assertTrue(all(path.trimming_score is not None for path in selected)) self.assertTrue(all(path.selected_reason for path in selected)) def test_prompt_contains_required_research_blocks_and_output_contract(self): graph = build_synthetic_graph() paths = APTMetapathExtractor(graph).extract_for_target("proc-child") selected = TemporalSecurityAwareTrimmer(graph).trim("proc-child", paths) bundle = PromptBuilder(graph).build("proc-child", selected) # DGP first-token contract: paper formula 13/14 trains/scores on Yes/No. self.assertIn("Return the first token as exactly Yes or No", bundle.prompt_text) self.assertIn("target_fine_grained_evidence", bundle.prompt_text) self.assertIn("metapath_blocks", bundle.prompt_text) self.assertIn("evidence_path_ids", bundle.prompt_text) self.assertIn("Treat all log contents", bundle.prompt_text) # DGP NumSumm (paper formula 11) and APT-specific stats coexist. self.assertIn("numerical_aggregate_dgp", bundle.prompt_text) self.assertIn("numerical_stats_apt", bundle.prompt_text) def test_schema_audit_records_missing_and_label_only_fields(self): audit = DatasetSchemaAudit("E3-THEIA-placeholder") audit.mark("timestamp", "core") audit.mark("event_type", "core") audit.mark("raw_event_id", "core") audit.mark("process_entity", "core") audit.mark("attack_ground_truth", "label_only") audit.mark("command_line", "optional") audit.mark("user_principal", "missing") report = audit.to_markdown() self.assertIn("Label-only Fields", report) self.assertIn("attack_ground_truth", report) self.assertIn("user_principal", report) def test_ir_validation_and_jsonl_roundtrip(self): graph = build_synthetic_graph() entities = list(graph.entities.values()) events = list(graph.events.values()) report = validate_ir(entities, events) self.assertTrue(report.ok, report.to_markdown()) with TemporaryDirectory() as tmpdir: entity_path = Path(tmpdir) / "entities.jsonl" event_path = Path(tmpdir) / "events.jsonl" write_jsonl(entity_path, entities) write_jsonl(event_path, events) loaded_entities = read_entities_jsonl(entity_path) loaded_events = read_events_jsonl(event_path) self.assertEqual(len(loaded_entities), len(entities)) self.assertEqual(len(loaded_events), len(events)) self.assertEqual(loaded_events[0].raw_event_id, events[0].raw_event_id) def test_graph_windows_and_lifecycle_helpers(self): graph = build_synthetic_graph() window = graph.subgraph_by_time_window(host="h1", start_time=2.0, end_time=4.0) self.assertEqual(set(window.events), {"event-create", "event-exec-file", "event-read"}) context = graph.target_context_window("proc-child", lookback=0.5, lookahead=2.5) self.assertIn("event-create", context.events) self.assertIn("event-read", context.events) self.assertEqual(graph.process_parent("proc-child"), "proc-parent") self.assertIn("proc-child", graph.process_children("proc-parent")) self.assertEqual(graph.entity_lifecycle("proc-child")["num_events"], 4) self.assertTrue(validate_graph(graph).ok) def test_adapter_reports_schema_and_mismatches(self): graph = build_synthetic_graph() records = [] for entity in graph.entities.values(): records.append({"record_type": "entity", "payload": asdict(entity)}) for event in graph.events.values(): records.append({"record_type": "event", "payload": asdict(event)}) adapter = ExplicitIRAdapter("synthetic-ir", known_missing_fields={"user_principal"}) result = adapter.adapt(records) self.assertTrue(result.validation_report.ok) self.assertIn("attack_ground_truth", result.mismatch_report.to_markdown()) self.assertIn("user_principal", result.schema_audit.missing_fields) def test_label_mapping_is_label_only_and_candidate_eval_is_separate(self): graph = build_synthetic_graph() mapper = LabelMapper(graph) labels = mapper.from_malicious_event_ids( {"event-read", "event-send"}, label_source="synthetic_event_ids_only", ) mapper.add_high_confidence_benign_outside_windows( labels, attack_windows=[(3.5, 5.5)], label_source="synthetic_outside_attack_window", target_type="EVENT", ) self.assertFalse(labels.get("event-send").prompt_allowed) self.assertEqual(labels.get("proc-child").label, "malicious") self.assertRaises(ValueError, LabelRecord, "x", "EVENT", "malicious", 1.0, "gt", True) candidates = WeakSignalCandidateGenerator(graph).generate_process_candidates() evaluation = evaluate_candidates(candidates, labels, target_type="PROCESS") self.assertEqual(evaluation.num_labeled_positive, 1) self.assertEqual(evaluation.recall, 1.0) def test_evidence_validation_after_trimming(self): graph = build_synthetic_graph() paths = APTMetapathExtractor(graph).extract_for_target("proc-child") selected = TemporalSecurityAwareTrimmer(graph).trim("proc-child", paths) report = validate_evidence_paths(graph, selected) self.assertTrue(report.ok, report.to_markdown()) def test_method_registry_marks_only_graph_dgp_as_main(self): registry = default_method_registry() self.assertFalse(validate_method_registry(registry)) self.assertTrue(registry["graph_dgp"].allowed_as_main) self.assertFalse(registry["target_only_llm"].allowed_as_main) self.assertFalse(registry["flat_log_llm"].uses_event_reified_graph) self.assertTrue(registry["without_evidence_ids"].uses_llm_classifier) self.assertFalse(registry["without_evidence_ids"].uses_evidence_ids) def test_classification_metrics_and_layered_report(self): predictions = [ PredictionRecord( target_id="event-send", target_type="EVENT", score=0.95, predicted_label="malicious", true_label="malicious", timestamp=5.0, campaign_id="camp-a", evidence_path_ids=("ep-1",), prompt_tokens=100, inference_cost=0.01, prompt_construction_time=0.2, ), PredictionRecord( target_id="event-read", target_type="EVENT", score=0.75, predicted_label="malicious", true_label="malicious", timestamp=4.0, campaign_id="camp-a", evidence_path_ids=("ep-2",), prompt_tokens=120, inference_cost=0.01, prompt_construction_time=0.3, ), PredictionRecord( target_id="proc-parent", target_type="PROCESS", score=0.30, predicted_label="benign", true_label="benign", timestamp=2.0, prompt_tokens=80, ), PredictionRecord( target_id="event-write", target_type="EVENT", score=0.10, predicted_label="benign", true_label="benign", timestamp=1.0, prompt_tokens=70, ), ] metrics = evaluate_classification( predictions, k_values=(1, 2), recall_levels=(1.0,), attack_start_by_campaign={"camp-a": 3.0}, ) self.assertEqual(metrics.num_positive, 2) self.assertEqual(metrics.auprc, 1.0) self.assertEqual(metrics.auroc, 1.0) self.assertEqual(metrics.recall_at_k[2], 1.0) self.assertEqual(metrics.attack_case_recall, 1.0) self.assertEqual(metrics.detection_delay, 1.0) self.assertEqual(metrics.evidence_path_hit_rate, 1.0) report = LayeredEvaluationReport(candidate_generation=None, final_classification=metrics) self.assertIn("final_classification", report.to_dict()) def test_time_split_and_leakage_checks(self): targets = [ TargetMetadata( target_id="t1", target_type="EVENT", timestamp=1.0, host="h1", campaign_id="camp-a", raw_event_ids=("raw-1",), process_ids=("pid-1",), file_paths=("/tmp/payload",), prompt_text="same prompt", summary_ids=("sum-1",), ), TargetMetadata( target_id="t2", target_type="EVENT", timestamp=2.0, host="h1", campaign_id="camp-b", raw_event_ids=("raw-2",), process_ids=("pid-2",), prompt_text="validation prompt", ), TargetMetadata( target_id="t3", target_type="EVENT", timestamp=5.0, host="h1", campaign_id="camp-a", raw_event_ids=("raw-3",), process_ids=("pid-1",), file_paths=("/tmp/payload",), prompt_text="same prompt", summary_ids=("sum-1",), ), ] assignment = time_based_split(targets, train_until=1.0, validation_until=3.0) self.assertEqual(assignment.split_by_target["t1"].value, "train") self.assertEqual(assignment.split_by_target["t3"].value, "test") report = check_leakage( targets, assignment, ioc_file_paths={"/tmp/payload"}, host_time_window=10.0, ) markdown = report.to_markdown() self.assertFalse(report.ok) self.assertIn("process_id_leakage", markdown) self.assertIn("duplicated_prompt_leakage", markdown) self.assertIn("campaign_leakage", markdown) self.assertIn("file_path_ioc_leakage", markdown) def test_openai_compatible_response_parser(self): response = { "choices": [ { "message": { "content": ( "MALICIOUS\n" "{\"score\": 0.82, \"predicted_label\": \"MALICIOUS\", " "\"involved_techniques\": [\"execution_chain\"], " "\"evidence_path_ids\": [\"ep-1\"], " "\"concise_explanation\": \"Evidence path supports the label.\", " "\"uncertainty\": \"low\", " "\"missing_fields\": [\"ground_truth_not_in_prompt\"], " "\"recommended_analyst_checks\": [\"review ep-1\"]}" ) } } ], "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, } text = extract_openai_compatible_text(response) output = parse_classification_output(text) self.assertEqual(output.first_token_label, "MALICIOUS") self.assertEqual(output.predicted_label, "MALICIOUS") self.assertEqual(output.score, 0.82) self.assertEqual(output.evidence_path_ids, ("ep-1",)) self.assertEqual(output.involved_techniques, ("execution_chain",)) def test_llm_request_config_url_handling(self): self.assertEqual( LLMRequestConfig( provider_type="api", base_url="https://example.test", model="model-a", ).completions_url(), "https://example.test/v1/chat/completions", ) self.assertEqual( LLMRequestConfig( provider_type="local", base_url="http://localhost:8000/v1", model="model-b", ).completions_url(), "http://localhost:8000/v1/chat/completions", ) def test_llm_yaml_config_loading(self): with TemporaryDirectory() as tmpdir: config_path = Path(tmpdir) / "llm.yaml" config_path.write_text( "\n".join( [ "provider: api", "base_url: https://example.test/v1", "model: model-a", "api_key_env: TEST_API_KEY", "timeout_seconds: 12", "temperature: 0.1", "max_tokens: 128", "user_agent: test-agent/1.0", "extra_headers:", " X-Test: 'yes'", "extra_body:", " seed: 7", ] ), encoding="utf-8", ) config = load_llm_config(config_path) self.assertEqual(config.provider_type, "api") self.assertEqual(config.completions_url(), "https://example.test/v1/chat/completions") self.assertEqual(config.model, "model-a") self.assertEqual(config.api_key_env, "TEST_API_KEY") self.assertEqual(config.timeout_seconds, 12) self.assertEqual(config.temperature, 0.1) self.assertEqual(config.max_tokens, 128) self.assertEqual(config.user_agent, "test-agent/1.0") self.assertEqual(config.extra_headers["X-Test"], "yes") self.assertEqual(config.extra_body["seed"], 7) def test_theia_action_semantics_are_auditable(self): execute = theia_action_semantics("EVENT_EXECUTE") self.assertEqual(execute.normalized_action, "EXEC") self.assertEqual(execute.canonical_action, "PROC_EXEC_FILE") self.assertIn("execution_chain", execute.metapath_hints) mprotect = theia_action_semantics("EVENT_MPROTECT") self.assertEqual(mprotect.canonical_action, "PROC_CHANGE_MEMORY_PROTECTION") self.assertIn("memory_context", mprotect.metapath_hints) self.assertEqual(mprotect.default_risk, "low") self.assertEqual(normalize_theia_event_type("EVENT_RECVFROM"), "RECEIVE") def test_method_version_manifest_redacts_llm_config(self): with TemporaryDirectory() as tmpdir: root = Path(tmpdir) config_dir = root / "configs" config_dir.mkdir() config_path = config_dir / "llm.yaml" config_path.write_text( "\n".join( [ "provider: api", "base_url: https://example.test/v1", "model: model-a", "api_key: secret-value", ] ), encoding="utf-8", ) sanitized = sanitized_llm_config(config_path) self.assertEqual(sanitized["api_key"], "") self.assertIn("sanitized_config_sha256", sanitized) output = root / "reports" / "manifest.json" manifest = write_method_version_manifest( output, repo_root=".", llm_config_path=config_path, ) self.assertTrue(output.exists()) self.assertEqual(manifest.version, "ER-TP-DGP-v0.1") self.assertEqual(manifest.llm_config["api_key"], "") def test_theia_candidate_universe_is_label_free_and_deterministic(self): prefix = "com.bbn.tc.schema.avro.cdm18." def wrapped(record_type, payload): return {"datum": {prefix + record_type: payload}, "source": "test"} with TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "theia.json" records = [ wrapped( "Subject", { "uuid": "subject-1", "hostId": "host-1", "type": "SUBJECT_PROCESS", "cmdLine": {"string": "python /tmp/dropper.py AAAABBBBCCCCDDDDEEEEFFFFGGGGHHHH"}, "properties": {"map": {"path": "/home/admin/bin/python"}}, }, ), wrapped( "NetFlowObject", { "uuid": "flow-1", "localAddress": "10.0.0.2", "localPort": 4444, "remoteAddress": "8.8.8.8", "remotePort": 443, }, ), wrapped( "FileObject", { "uuid": "file-1", "baseObject": {"properties": {"map": {"path": "/tmp/payload"}}}, }, ), wrapped( "Event", { "uuid": "event-recv", "type": "EVENT_RECVFROM", "timestampNanos": 100, "subject": {prefix + "UUID": "subject-1"}, "predicateObject": {prefix + "UUID": "flow-1"}, }, ), wrapped( "Event", { "uuid": "event-write", "type": "EVENT_WRITE", "timestampNanos": 200, "subject": {prefix + "UUID": "subject-1"}, "predicateObject": {prefix + "UUID": "file-1"}, "predicateObjectPath": "/tmp/payload", }, ), wrapped( "Event", { "uuid": "event-exec", "type": "EVENT_EXECUTE", "timestampNanos": 300, "subject": {prefix + "UUID": "subject-1"}, "predicateObject": {prefix + "UUID": "file-1"}, "predicateObjectPath": "/tmp/payload", }, ), ] path.write_text( "\n".join(json.dumps(record, sort_keys=True) for record in records) + "\n", encoding="utf-8", ) universe = build_theia_candidate_universe([path]) candidates = universe.candidate_profiles(min_score=1.0, min_events=1) sample_a = stratified_sample(candidates, per_stratum=1, seed=11) sample_b = stratified_sample(candidates, per_stratum=1, seed=11) self.assertEqual(universe.events_seen, 3) self.assertEqual(len(candidates), 1) profile = candidates[0] self.assertEqual(profile.candidate_id, "subject-1") self.assertEqual(profile.write_then_execute_count, 1) self.assertEqual(profile.recv_then_write_count, 1) self.assertGreater(profile.external_flow_count, 0) self.assertIn("event-recv", profile.sample_raw_event_ids) self.assertIn("This is a label-free candidate universe", universe.to_markdown()) self.assertEqual([item.candidate_id for item in sample_a], [item.candidate_id for item in sample_b]) # End-to-end anchor must come from a weak-signal-triggering event, # never from ground truth. The first event ("event-recv") triggers # external_flow + network_activity, so it should be picked. anchor = select_anchor_for_candidate(profile, strategy="first_weak_signal") self.assertEqual(anchor.anchor_event_id, "event-recv") self.assertEqual(anchor.anchor_timestamp_nanos, 100) self.assertFalse(anchor.fallback_used) self.assertIn("network_activity", anchor.triggering_signals) lifecycle = select_anchors_for_lifecycle(profile, max_anchors=4) self.assertGreaterEqual(len(lifecycle), 1) self.assertEqual(lifecycle[0].anchor_event_id, "event-recv") # All anchors must come from weak-signal-triggering events or the # padding fallback — not from any external label source. self.assertTrue(all(a.anchor_event_id is not None for a in lifecycle)) # Round-trip via JSONL to make sure the anchor selector also works # against the serialized form (this is what the e2e batch builder # consumes in production). serialized = profile.to_json_dict() anchor_from_dict = select_anchor_for_candidate(serialized, strategy="first_weak_signal") self.assertEqual(anchor_from_dict.anchor_event_id, "event-recv") # Backward compatibility: a row missing weak_signal_events must not # crash; it should fall back to the first observed event. legacy_row = { "candidate_id": "legacy-1", "sample_raw_event_ids": ["legacy-event-1", "legacy-event-2"], } legacy_anchor = select_anchor_for_candidate(legacy_row, strategy="first_weak_signal") self.assertEqual(legacy_anchor.anchor_event_id, "legacy-event-1") self.assertTrue(legacy_anchor.fallback_used) def test_end_to_end_evaluation_batch_uses_no_ground_truth(self): with TemporaryDirectory() as tmp: universe_path = Path(tmp) / "universe.jsonl" universe_path.write_text( "\n".join( json.dumps(row, sort_keys=True) for row in [ { "candidate_id": "subject-positive", "process_path": "/tmp/payload", "command_line": "./payload", "weak_signal_score": 5.0, "total_events": 10, "estimated_prompt_tokens": 800, "stratum": "execution_heavy", "sample_raw_event_ids": ["evt-pos-a"], "weak_signal_events": [ { "event_id": "evt-pos-a", "timestamp_nanos": 1_000_000_000, "signals": ["external_flow"], } ], "first_event_id": "evt-pos-a", "first_event_timestamp_nanos": 1_000_000_000, }, { "candidate_id": "subject-benign", "process_path": "/usr/bin/firefox", "command_line": "/usr/bin/firefox", "weak_signal_score": 2.0, "total_events": 200, "estimated_prompt_tokens": 1500, "stratum": "browser_like", "sample_raw_event_ids": ["evt-ben-a"], "weak_signal_events": [ { "event_id": "evt-ben-a", "timestamp_nanos": 2_000_000_000, "signals": ["network_activity"], } ], "first_event_id": "evt-ben-a", "first_event_timestamp_nanos": 2_000_000_000, }, ] ) + "\n", encoding="utf-8", ) label_path = Path(tmp) / "labels.jsonl" label_path.write_text( "\n".join( json.dumps(row, sort_keys=True) for row in [ { "target_id": "subject-positive", "label": "malicious", "label_confidence": "high", "label_source": "fixture_oracle", "atom_id": "atom-1", } ] ) + "\n", encoding="utf-8", ) batch_no_labels = build_end_to_end_evaluation_batch( candidate_universe_path=universe_path, label_lookup_path=None, anchor_strategy="first_weak_signal", ) # No labels supplied → all targets are unlabeled. Anchor selection # is independent of label state. self.assertEqual(len(batch_no_labels.targets), 2) for target in batch_no_labels.targets: self.assertEqual(target.label, "unlabeled") self.assertEqual(target.label_source, "no_ground_truth_join") self.assertIsNotNone(target.anchor_event_id) self.assertIsNotNone(target.anchor_timestamp_nanos) self.assertIn( "End-to-end batch: anchor selected from raw-log weak signals", "\n".join(target.notes), ) batch_with_labels = build_end_to_end_evaluation_batch( candidate_universe_path=universe_path, label_lookup_path=label_path, anchor_strategy="first_weak_signal", ) labels_by_id = {t.target_id: t for t in batch_with_labels.targets} self.assertEqual(labels_by_id["subject-positive"].label, "malicious") self.assertEqual(labels_by_id["subject-positive"].atom_id, "atom-1") self.assertEqual(labels_by_id["subject-benign"].label, "unlabeled") # Critical invariant: anchor for the labeled positive must be # identical between the labeled and unlabeled runs — labels must # not leak into anchor selection. self.assertEqual( labels_by_id["subject-positive"].anchor_event_id, next(t for t in batch_no_labels.targets if t.target_id == "subject-positive").anchor_event_id, ) self.assertEqual( labels_by_id["subject-positive"].anchor_timestamp_nanos, next(t for t in batch_no_labels.targets if t.target_id == "subject-positive").anchor_timestamp_nanos, ) def test_ground_truth_atom_extraction_is_label_only(self): text = """ 3 Nation State 3.3 20180410 1400 THEIA - Firefox Backdoor w/ Drakon In-Memory 3.3.2 Event Log 14:11 Received connection from firefox putfile ./deploy/archive/drakon.x64_128.55.12.1 /tmp/payload connect to 1.2.3.4:80 and www.allstate.com 3.4 20180411 1000 FiveDirections - Firefox Backdoor w/ Drakon In-Memory not the target filter 4 Common Threat 4.6 20180410 1300 THEIA - Phishing E-mail w/ Link Connect to www.nasa.ng (208.75.117.3) """ report = extract_e3_ground_truth_atoms(text, target_filter="THEIA") self.assertEqual(len(report.atoms), 2) self.assertTrue(all(not atom.prompt_allowed for atom in report.atoms)) self.assertEqual(report.atoms[0].attack_group, "nation_state") self.assertIn("1.2.3.4:80", report.atoms[0].ips) self.assertIn("/tmp/payload", report.atoms[0].file_paths) self.assertIn("www.nasa.ng", report.atoms[1].domains) self.assertIn("must not enter LLM prompts", report.to_markdown()) def test_theia_ground_truth_mapping_and_candidate_recall_are_label_only(self): prefix = "com.bbn.tc.schema.avro.cdm18." atoms = extract_e3_ground_truth_atoms( """ 3 Nation State 3.3 20180410 1400 THEIA - Firefox Backdoor w/ Drakon In-Memory connect to 8.8.8.8:443 """, target_filter="THEIA", ).atoms def wrapped(record_type, payload): return {"datum": {prefix + record_type: payload}, "source": "test"} with TemporaryDirectory() as tmpdir: tmp = Path(tmpdir) theia_path = tmp / "theia.json" records = [ wrapped( "Subject", { "uuid": "subject-gt", "hostId": "host-1", "type": "SUBJECT_PROCESS", "cmdLine": {"string": "/usr/bin/firefox"}, "properties": {"map": {"path": "/home/admin/Downloads/firefox/firefox"}}, }, ), wrapped( "NetFlowObject", { "uuid": "flow-gt", "localAddress": "10.0.0.2", "localPort": 4444, "remoteAddress": "8.8.8.8", "remotePort": 443, }, ), wrapped( "Event", { "uuid": "event-gt", "type": "EVENT_CONNECT", "timestampNanos": 1_523_379_600_000_000_000, "subject": {prefix + "UUID": "subject-gt"}, "predicateObject": {prefix + "UUID": "flow-gt"}, }, ), ] theia_path.write_text( "\n".join(json.dumps(record, sort_keys=True) for record in records) + "\n", encoding="utf-8", ) candidate_path = tmp / "candidates.jsonl" candidate_path.write_text( json.dumps({"candidate_id": "subject-gt"}) + "\n", encoding="utf-8", ) mapping = match_theia_ground_truth_atoms([theia_path], atoms, min_score=3.0) recall = evaluate_candidate_recall( candidate_path, mapping.process_labels, mapping.event_matches, ) self.assertEqual(len(mapping.event_matches), 1) self.assertEqual(mapping.event_matches[0].raw_event_id, "event-gt") self.assertFalse(mapping.event_matches[0].prompt_allowed) self.assertEqual(mapping.process_labels[0].subject_uuid, "subject-gt") self.assertEqual(recall.process_recall, 1.0) self.assertIn("label/evaluation-only", mapping.to_markdown()) def test_evaluation_batch_keeps_labels_out_of_prompt_inputs(self): with TemporaryDirectory() as tmpdir: tmp = Path(tmpdir) process_labels = tmp / "process_labels.jsonl" event_matches = tmp / "event_matches.jsonl" candidates = tmp / "candidates.jsonl" all_labels = tmp / "all_labels.jsonl" process_labels.write_text( json.dumps( { "subject_uuid": "proc-positive", "confidence": "high", "atom_id": "atom-1", "matched_event_ids": ["event-positive"], } ) + "\n", encoding="utf-8", ) all_labels.write_text(process_labels.read_text(encoding="utf-8"), encoding="utf-8") event_matches.write_text( json.dumps( { "raw_event_id": "event-positive", "score": 9.0, "subject_path": "/tmp/payload", "command_line": "/tmp/payload", } ) + "\n", encoding="utf-8", ) candidates.write_text( "\n".join( [ json.dumps( { "candidate_id": "proc-positive", "sample_raw_event_ids": ["event-positive"], "weak_signal_score": 10.0, } ), json.dumps( { "candidate_id": "proc-hard-negative", "sample_raw_event_ids": ["event-negative"], "weak_signal_score": 8.0, "process_path": "/usr/bin/ssh", } ), ] ) + "\n", encoding="utf-8", ) batch = build_evaluation_batch( positive_process_labels_path=process_labels, positive_event_matches_path=event_matches, candidate_universe_path=candidates, all_mapped_process_labels_path=all_labels, num_positives=1, num_hard_negative_proxies=1, max_hard_negative_events=1000, seed=1, ) self.assertEqual(len(batch.targets), 2) positive = [target for target in batch.targets if target.label == "malicious"][0] negative = [target for target in batch.targets if target.cohort == "hard_negative_proxy"][0] self.assertFalse(positive.prompt_allowed_label_fields) self.assertEqual(positive.anchor_event_id, "event-positive") self.assertEqual(negative.label, "benign_proxy") self.assertIn("not a high-confidence benign label", " ".join(negative.notes)) class DGPComponentTests(unittest.TestCase): """Coverage for DGP-aligned algorithmic components added in Phase 2/3.""" def test_node_text_summarizer_caches_by_sha256(self): from er_tp_dgp.text_summarizer import NodeTextSummarizer, SummarizerConfig class _RecordingLLM: def __init__(self) -> None: self.calls: list[tuple[str, int]] = [] def complete(self, prompt: str, *, max_tokens: int) -> str: self.calls.append((prompt, max_tokens)) return "summary-of-" + prompt[-12:] with TemporaryDirectory() as tmp: cfg = SummarizerConfig( b_node=10, model_name="testmodel", cache_dir=Path(tmp) ) llm = _RecordingLLM() summ = NodeTextSummarizer(llm=llm, config=cfg) first = summ.summarize("command_line=/usr/bin/python /tmp/payload --sync") second = summ.summarize("command_line=/usr/bin/python /tmp/payload --sync") third = summ.summarize("command_line=/bin/bash -i") self.assertEqual(first, second) self.assertEqual(len(llm.calls), 2, "duplicate text must be served from cache") self.assertNotEqual(first, third) cache_files = list(Path(tmp).iterdir()) self.assertEqual(len(cache_files), 2) def test_node_summarize_batch_concurrency_and_cache(self): """Concurrent batch must (a) call LLM exactly once per unique key, (b) preserve input order, (c) be safe against same-key duplicates. """ import time from er_tp_dgp.text_summarizer import NodeTextSummarizer, SummarizerConfig class _SlowLLM: def __init__(self) -> None: self.calls: list[str] = [] self._lock = __import__("threading").Lock() def complete(self, prompt: str, *, max_tokens: int) -> str: # noqa: ARG002 with self._lock: self.calls.append(prompt) # Sleep simulates network latency. With 4 workers and 5 unique # keys, sequential would take 5 * 0.05 = 0.25s; concurrent # should be roughly 2 * 0.05 ~= 0.10s. time.sleep(0.05) return f"sum:{prompt[-8:]}" with TemporaryDirectory() as tmp: cfg = SummarizerConfig( b_node=10, model_name="t", cache_dir=Path(tmp), max_workers=4 ) llm = _SlowLLM() summ = NodeTextSummarizer(llm=llm, config=cfg) # 5 unique + 2 duplicates of unique[0] = 7 inputs but only 5 LLM calls. inputs = [ "alpha-text-aaaa", "beta-text-bbbb", "gamma-text-gggg", "alpha-text-aaaa", # duplicate "delta-text-dddd", "epsilon-text-eeee", "alpha-text-aaaa", # duplicate ] t0 = time.monotonic() results = summ.summarize_batch(inputs) elapsed = time.monotonic() - t0 self.assertEqual(len(results), len(inputs)) self.assertEqual(results[0], results[3]) self.assertEqual(results[0], results[6]) self.assertEqual(len(llm.calls), 5, f"expected 5 unique LLM calls, got {len(llm.calls)}") # Concurrency with 4 workers on 5 calls should be ~0.10s, not ~0.25s. self.assertLess(elapsed, 0.20, f"batch was not concurrent (elapsed={elapsed:.3f}s)") # Re-running the same batch must hit the disk cache and skip the LLM. results2 = summ.summarize_batch(inputs) self.assertEqual(results, results2) self.assertEqual(len(llm.calls), 5) def test_metapath_text_summarizer_combines_neighbor_summaries(self): from er_tp_dgp.text_summarizer import MetapathTextSummarizer, SummarizerConfig class _Joiner: def complete(self, prompt: str, *, max_tokens: int) -> str: # noqa: ARG002 return prompt.rsplit("\n", 1)[-2][:60] with TemporaryDirectory() as tmp: cfg = SummarizerConfig(b_meta=8, model_name="t", cache_dir=Path(tmp)) summ = MetapathTextSummarizer(llm=_Joiner(), config=cfg) result = summ.summarize_metapath("execution_chain", ["proc=python", "proc=bash"]) self.assertNotEqual(result, "") self.assertEqual( summ.summarize_metapath("execution_chain", []), "" ) def test_markov_diffusion_trimmer_selects_paths_through_top_m_neighbors(self): try: import numpy # noqa: F401 except ImportError: self.skipTest("numpy not installed; skipping MDK trimmer test") from er_tp_dgp.diffusion_trimmer import ( HashingEmbedder, MarkovDiffusionTrimmer, MDKConfig, ) graph = build_synthetic_graph() paths = APTMetapathExtractor(graph).extract_for_target("proc-child") trimmer = MarkovDiffusionTrimmer( graph, embedder=HashingEmbedder(dim=32), config=MDKConfig(k_hops=3, top_m=2), ) kept = trimmer.trim("proc-child", paths) self.assertGreater(len(kept), 0) for path in kept: self.assertTrue(path.causal_validity) self.assertIsNotNone(path.selected_reason) self.assertIn("mdk(", path.selected_reason or "") def test_numerical_aggregator_outputs_means_and_histograms(self): from er_tp_dgp.numerical_aggregator import NumericalAggregator graph = build_synthetic_graph() paths = APTMetapathExtractor(graph).extract_for_target("proc-child") # Use the executions chain block specifically. exec_paths = [p for p in paths if p.metapath_type == "execution_chain"] agg = NumericalAggregator(graph).aggregate("execution_chain", exec_paths) self.assertGreater(agg.neighbor_count, 0) d = agg.to_prompt_dict() self.assertIn("node_type_hist", d) self.assertIn("action_hist", d) self.assertAlmostEqual(sum(d["node_type_hist"].values()), 1.0, places=5) self.assertAlmostEqual(sum(d["action_hist"].values()), 1.0, places=5) def test_score_from_top_logprobs_softmax(self): from er_tp_dgp.scoring import score_from_top_logprobs result = score_from_top_logprobs( [ {"token": "Yes", "logprob": -0.10}, {"token": " No", "logprob": -2.30}, {"token": "the", "logprob": -5.0}, ] ) self.assertIsNotNone(result.score) self.assertGreater(result.score, 0.85) self.assertEqual(result.matched_yes_token, "Yes") self.assertEqual(result.matched_no_token, " No") self.assertFalse(result.fallback_used) empty = score_from_top_logprobs([{"token": "the", "logprob": -1.0}]) self.assertIsNone(empty.score) self.assertTrue(empty.fallback_used) def test_resolve_auto_model_class_routes_by_architecture(self): """Routing must pick image_text_to_text for Qwen3.5-style multimodal.""" try: from transformers import ( # noqa: F401 AutoModelForCausalLM as ACL, AutoModelForImageTextToText as AITT, AutoModelForSeq2SeqLM as AS2S, ) except ImportError: self.skipTest("transformers not installed; skipping resolver test") from er_tp_dgp.llm import _resolve_auto_model_class class _FakeCfg: def __init__(self, architectures, has_vision): self.architectures = architectures self.vision_config = {"hidden_size": 1} if has_vision else None self.assertIs( _resolve_auto_model_class("auto", _FakeCfg(["Qwen3ForCausalLM"], False)), ACL, ) # Multimodal conditional generation with vision_config (Qwen3.5-27B). self.assertIs( _resolve_auto_model_class( "auto", _FakeCfg(["Qwen3_5ForConditionalGeneration"], True) ), AITT, ) # Conditional generation without vision (T5-style). self.assertIs( _resolve_auto_model_class("auto", _FakeCfg(["T5ForConditionalGeneration"], False)), AS2S, ) # Explicit override wins regardless of config. self.assertIs( _resolve_auto_model_class("causal_lm", _FakeCfg(["Qwen3_5ForConditionalGeneration"], True)), ACL, ) def test_yes_no_protocol_parses_to_canonical_labels(self): output = parse_classification_output( "Yes\n{\"predicted_label\": \"Yes\", \"evidence_path_ids\": [\"ep-1\"]}" ) self.assertEqual(output.first_token_label, "MALICIOUS") self.assertEqual(output.predicted_label, "MALICIOUS") output_no = parse_classification_output("No\n{}") self.assertEqual(output_no.first_token_label, "BENIGN") def test_dgp_prompt_block_layout(self): from er_tp_dgp.prompt import PromptComponentSwitches graph = build_synthetic_graph() paths = APTMetapathExtractor(graph).extract_for_target("proc-child") selected = TemporalSecurityAwareTrimmer(graph).trim("proc-child", paths) # without_path_summ ablation: drop LLM PathSumm, keep concat. ablate = PromptBuilder( graph, switches=PromptComponentSwitches( use_text_summarization=False, use_path_summarization_llm=False, ), ).build("proc-child", selected) self.assertIn("path_summary_concat", ablate.prompt_text) self.assertNotIn("\"text_summary\"", ablate.prompt_text) class AblationRegistryTests(unittest.TestCase): def test_registry_contains_dgp_paper_aligned_ablations(self): registry = default_method_registry() for name in ( "graph_dgp", "without_dgp_text_summ", "without_dgp_mdk", "without_dgp_path_summ", "without_dgp_num_summ", ): self.assertIn(name, registry, f"missing variant {name}") def test_main_variant_requires_all_four_dgp_components(self): # Build a deliberately broken main variant to confirm validation fires. from er_tp_dgp.experiments import MethodFamily, MethodVariant broken = MethodVariant( name="graph_dgp_broken", family=MethodFamily.MAIN, description="missing text summ", uses_event_reified_graph=True, uses_target_fine_grained=True, uses_local_context=True, uses_time_respecting_metapaths=True, uses_temporal_trimming=True, uses_security_aware_trimming=True, uses_metapath_summary=True, uses_node_level_summary=True, uses_numerical_summary=True, uses_evidence_ids=True, uses_llm_classifier=True, uses_dgp_text_summarization=False, # the deliberate gap uses_dgp_diffusion_trimming=True, uses_dgp_path_summarization_llm=True, uses_dgp_numerical_aggregation=True, allowed_as_main=True, ) issues = broken.validate_role() self.assertTrue(issues) self.assertIn("uses_dgp_text_summarization", issues[0]) class MultiAnchorWindowIRTests(unittest.TestCase): """Phase 1.5: build_multi_target_window_irs single-scan multi-anchor builder.""" def _write_cdm18(self, dir_path: Path, lines: list[dict]) -> Path: """Write CDM18-shaped JSONL the same way iter_theia_records expects to read it.""" path = dir_path / "ta1-theia-fake.json" with path.open("w", encoding="utf-8") as f: for line in lines: f.write(json.dumps(line) + "\n") return path def _record(self, full_type: str, payload: dict) -> dict: return {"datum": {f"com.bbn.tc.schema.avro.cdm18.{full_type}": payload}} def test_two_anchors_one_pass_demuxes_events_per_window(self): from er_tp_dgp.theia import build_multi_target_window_irs with TemporaryDirectory() as tmp: tmp_p = Path(tmp) # Two anchor events at very different times. # A0 at t=1_000_000_000_000 (ns) → window [t-1s, t+1s] # A1 at t=5_000_000_000_000 (ns) → window [t-1s, t+1s] # Plus events near A0 (in window), near A1 (in window), and one far # from both (out of all windows). host = self._record("Host", {"uuid": "H", "hostName": "h1"}) principal = self._record("Principal", {"uuid": "P"}) subj_a = self._record("Subject", {"uuid": "Sa", "type": "SUBJECT_PROCESS", "cmdLine": {"string": "/bin/a"}}) subj_b = self._record("Subject", {"uuid": "Sb", "type": "SUBJECT_PROCESS", "cmdLine": {"string": "/bin/b"}}) file_obj = self._record("FileObject", {"uuid": "F", "baseObject": {}}) anchor0 = self._record("Event", { "uuid": "A0", "type": "EVENT_OPEN", "timestampNanos": 1_000_000_000_000, "subject": {"com.bbn.tc.schema.avro.cdm18.UUID": "Sa"}, "predicateObject": {"com.bbn.tc.schema.avro.cdm18.UUID": "F"}, }) near_a0 = self._record("Event", { "uuid": "E_NEAR_A0", "type": "EVENT_READ", "timestampNanos": 1_000_500_000_000, "subject": {"com.bbn.tc.schema.avro.cdm18.UUID": "Sa"}, "predicateObject": {"com.bbn.tc.schema.avro.cdm18.UUID": "F"}, }) anchor1 = self._record("Event", { "uuid": "A1", "type": "EVENT_EXECUTE", "timestampNanos": 5_000_000_000_000, "subject": {"com.bbn.tc.schema.avro.cdm18.UUID": "Sb"}, }) near_a1 = self._record("Event", { "uuid": "E_NEAR_A1", "type": "EVENT_READ", "timestampNanos": 5_000_500_000_000, "subject": {"com.bbn.tc.schema.avro.cdm18.UUID": "Sb"}, "predicateObject": {"com.bbn.tc.schema.avro.cdm18.UUID": "F"}, }) far = self._record("Event", { "uuid": "E_FAR", "type": "EVENT_READ", "timestampNanos": 9_000_000_000_000, "subject": {"com.bbn.tc.schema.avro.cdm18.UUID": "Sa"}, }) json_path = self._write_cdm18( tmp_p, [host, principal, subj_a, subj_b, file_obj, anchor0, near_a0, anchor1, near_a1, far], ) cache_dir = tmp_p / "cache" results = build_multi_target_window_irs( [json_path], anchors=[ {"anchor_event_uuid": "A0", "lookback_seconds": 1.0, "lookahead_seconds": 1.0}, {"anchor_event_uuid": "A1", "lookback_seconds": 1.0, "lookahead_seconds": 1.0}, ], cache_dir=cache_dir, ) self.assertIn("A0", results) self.assertIn("A1", results) self.assertEqual(results["A0"].target_subject_id, "Sa") self.assertEqual(results["A1"].target_subject_id, "Sb") # A0 window must contain only A0 + E_NEAR_A0. a0_event_ids = {e.event_id for e in results["A0"].events} self.assertIn("A0", a0_event_ids) self.assertIn("E_NEAR_A0", a0_event_ids) self.assertNotIn("A1", a0_event_ids) self.assertNotIn("E_NEAR_A1", a0_event_ids) self.assertNotIn("E_FAR", a0_event_ids) # A1 window must contain only A1 + E_NEAR_A1. a1_event_ids = {e.event_id for e in results["A1"].events} self.assertIn("A1", a1_event_ids) self.assertIn("E_NEAR_A1", a1_event_ids) self.assertNotIn("A0", a1_event_ids) self.assertNotIn("E_FAR", a1_event_ids) # Cache files must exist for both anchors. cache_files = list(cache_dir.glob("*.json.gz")) self.assertEqual(len(cache_files), 2) # Re-running with same args must hit the cache (no scan needed). cached_results = build_multi_target_window_irs( [json_path], anchors=[ {"anchor_event_uuid": "A0", "lookback_seconds": 1.0, "lookahead_seconds": 1.0}, {"anchor_event_uuid": "A1", "lookback_seconds": 1.0, "lookahead_seconds": 1.0}, ], cache_dir=cache_dir, ) self.assertEqual( {e.event_id for e in cached_results["A0"].events}, a0_event_ids, ) def test_missing_anchor_is_silently_omitted(self): from er_tp_dgp.theia import build_multi_target_window_irs with TemporaryDirectory() as tmp: tmp_p = Path(tmp) # Corpus contains exactly one anchor; the second is not present. host = self._record("Host", {"uuid": "H"}) subj = self._record("Subject", {"uuid": "S", "type": "SUBJECT_PROCESS"}) anchor0 = self._record("Event", { "uuid": "A0", "type": "EVENT_OPEN", "timestampNanos": 1_000_000_000_000, "subject": {"com.bbn.tc.schema.avro.cdm18.UUID": "S"}, }) json_path = self._write_cdm18(tmp_p, [host, subj, anchor0]) results = build_multi_target_window_irs( [json_path], anchors=[ {"anchor_event_uuid": "A0", "lookback_seconds": 1.0, "lookahead_seconds": 1.0}, {"anchor_event_uuid": "NEVER_SEEN", "lookback_seconds": 1.0, "lookahead_seconds": 1.0}, ], cache_dir=None, ) self.assertIn("A0", results) self.assertNotIn("NEVER_SEEN", results) def test_warm_cache_serves_without_scanning(self): """If every anchor has a cache hit, no scan is done. Use a path that does not exist on disk to prove no scan occurs (would otherwise crash). """ from er_tp_dgp.theia import ( TheiaWindowIR, _save_window_ir_to_cache, _multi_window_cache_key, build_multi_target_window_irs, ) from er_tp_dgp.ir import EntityNode, EventNode with TemporaryDirectory() as tmp: tmp_p = Path(tmp) cache_dir = tmp_p / "cache" cache_dir.mkdir() # Pre-populate the cache for anchor 'A0' with a fake window. ent = EntityNode(node_id="S", node_type="PROCESS", stable_name="x", dataset="syn") ev = EventNode( event_id="A0", raw_event_id="A0", timestamp=1.0, action="OPEN", actor_entity_id="S", object_entity_id=None, host=None, raw_event_type="EVENT_OPEN", normalized_action="OPEN", ) window = TheiaWindowIR( target_event_id="A0", target_subject_id="S", start_timestamp_nanos=0, end_timestamp_nanos=2_000_000_000, entities=(ent,), events=(ev,), schema_gaps=(), ) fake_json_path = tmp_p / "imaginary.json" # does not exist cache_key = _multi_window_cache_key( "A0", 1.0, 1.0, "DARPA_TC_E3_THEIA", [fake_json_path] ) _save_window_ir_to_cache(cache_dir / f"{cache_key}.json.gz", window) # If the function tries to open the imaginary file, we'd get # FileNotFoundError. Cache hit must mean no scan. results = build_multi_target_window_irs( [fake_json_path], anchors=[{"anchor_event_uuid": "A0", "lookback_seconds": 1.0, "lookahead_seconds": 1.0}], cache_dir=cache_dir, ) self.assertIn("A0", results) self.assertEqual(results["A0"].target_subject_id, "S") class WindowIRCacheTests(unittest.TestCase): """Phase 1 IR cache round-trip.""" def test_window_ir_cache_round_trip(self): from dataclasses import asdict as _asdict from er_tp_dgp.ir import EntityNode, EventNode from er_tp_dgp.theia import ( TheiaWindowIR, _load_window_ir_from_cache, _save_window_ir_to_cache, ) entity = EntityNode( node_id="e1", node_type="PROCESS", stable_name="/bin/bash", dataset="syn", host="h", text_fields={"command_line": "bash -i"}, ) event = EventNode( event_id="ev1", raw_event_id="ev1", timestamp=1.0, action="EXEC", actor_entity_id="e1", object_entity_id=None, host="h", raw_event_type="EVENT_EXEC", normalized_action="EXEC", ) window = TheiaWindowIR( target_event_id="ev1", target_subject_id="e1", start_timestamp_nanos=0, end_timestamp_nanos=2 * 1_000_000_000, entities=(entity,), events=(event,), schema_gaps=("foo",), ) with TemporaryDirectory() as tmp: path = Path(tmp) / "snap.json.gz" _save_window_ir_to_cache(path, window) self.assertTrue(path.exists()) roundtrip = _load_window_ir_from_cache(path) self.assertEqual(_asdict(roundtrip.entities[0]), _asdict(entity)) self.assertEqual(_asdict(roundtrip.events[0]), _asdict(event)) self.assertEqual(roundtrip.target_event_id, "ev1") self.assertEqual(roundtrip.schema_gaps, ("foo",)) if __name__ == "__main__": unittest.main()