Initial commit: ER-TP-DGP research prototype
Event-Reified Temporal Provenance Dual-Granularity Prompting for LLM-based APT detection on DARPA provenance datasets. Includes phase 0-14 method spec, IR/graph/metapath/trimming/prompt modules, scripts for THEIA candidate universe, landmark CSG construction, hybrid prompting, and LLM inference. Excludes data/, reports/, and local LLM config from version control.
This commit is contained in:
310
scripts/anchor_coverage_audit.py
Normal file
310
scripts/anchor_coverage_audit.py
Normal file
@@ -0,0 +1,310 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Quantify the gap between oracle (GT-derived) and end-to-end anchor selection.
|
||||
|
||||
For each ground-truth-malicious process target, compare:
|
||||
- oracle anchor: the anchor recorded in the oracle labeled-targets JSONL
|
||||
(from `import_orthrus_ground_truth.py` or the GT-event-match builder).
|
||||
The oracle path picks anchors using ground-truth event matches and so
|
||||
cannot be deployed in production.
|
||||
- end-to-end anchor: produced from raw-log weak signals only, using
|
||||
``select_anchor_for_candidate`` over the candidate universe. This is
|
||||
deployable.
|
||||
|
||||
Outputs per-subject rows and an aggregate report:
|
||||
- end-to-end anchor recall under a fixed lookback/lookahead window:
|
||||
fraction of GT-malicious subjects for which the end-to-end window
|
||||
[t_e2e - L, t_e2e + L] contains the oracle anchor's timestamp (a
|
||||
proxy for "would the LLM see at least one ground-truth attack
|
||||
event in its window?")
|
||||
- delta_seconds distribution: |t_oracle - t_e2e|
|
||||
- reasons for fallback / failure (no weak signal recorded, no events
|
||||
indexed, etc.)
|
||||
|
||||
The number this audit publishes is the ceiling on end-to-end LLM recall
|
||||
for the current weak-signal anchor strategy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import statistics
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Allow running as a standalone script without `pip install -e .`.
|
||||
SRC = Path(__file__).resolve().parent.parent / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from er_tp_dgp.candidate_universe import select_anchor_for_candidate # noqa: E402
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Compare oracle (GT-derived) vs end-to-end (weak-signal) anchor "
|
||||
"selection. Reports the recall ceiling for the deployable anchor."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--oracle-targets",
|
||||
required=True,
|
||||
help=(
|
||||
"Path to oracle labeled_targets JSONL (e.g., the orthrus output). "
|
||||
"Must contain target_id, anchor_event_id, anchor_timestamp_nanos, label."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--candidate-universe",
|
||||
required=True,
|
||||
help="Path to candidate-universe JSONL (with weak_signal_events field).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--anchor-strategy",
|
||||
default="first_weak_signal",
|
||||
choices=("first_weak_signal", "first_event"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lookback-seconds",
|
||||
type=float,
|
||||
default=300.0,
|
||||
help="Window half-width used to score 'oracle event inside e2e window'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lookahead-seconds",
|
||||
type=float,
|
||||
default=300.0,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-jsonl",
|
||||
required=True,
|
||||
help="Per-subject comparison rows.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-markdown",
|
||||
required=True,
|
||||
help="Aggregate audit report.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
oracle_rows = _read_jsonl(args.oracle_targets)
|
||||
universe_index = _index_universe(args.candidate_universe)
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for oracle in oracle_rows:
|
||||
if oracle.get("label") != "malicious":
|
||||
continue
|
||||
target_id = oracle.get("target_id")
|
||||
if not target_id:
|
||||
continue
|
||||
oracle_ts = oracle.get("anchor_timestamp_nanos")
|
||||
oracle_event_id = oracle.get("anchor_event_id")
|
||||
if not isinstance(oracle_ts, int) or not oracle_event_id:
|
||||
continue
|
||||
|
||||
profile_row = universe_index.get(target_id)
|
||||
if profile_row is None:
|
||||
rows.append(
|
||||
{
|
||||
"target_id": target_id,
|
||||
"in_candidate_universe": False,
|
||||
"oracle_anchor_event_id": oracle_event_id,
|
||||
"oracle_anchor_timestamp_nanos": oracle_ts,
|
||||
"e2e_anchor_event_id": None,
|
||||
"e2e_anchor_timestamp_nanos": None,
|
||||
"delta_seconds": None,
|
||||
"oracle_inside_e2e_window": False,
|
||||
"fallback_used": None,
|
||||
"anchor_strategy": args.anchor_strategy,
|
||||
"reason": "candidate_not_in_universe",
|
||||
"atom_id": oracle.get("atom_id"),
|
||||
"process_path": oracle.get("process_path"),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
anchor = select_anchor_for_candidate(profile_row, strategy=args.anchor_strategy)
|
||||
e2e_event_id = anchor.anchor_event_id
|
||||
e2e_ts = anchor.anchor_timestamp_nanos
|
||||
delta_seconds: float | None = None
|
||||
inside = False
|
||||
if isinstance(e2e_ts, int):
|
||||
delta_ns = oracle_ts - e2e_ts
|
||||
delta_seconds = delta_ns / 1_000_000_000
|
||||
window_start = e2e_ts - int(args.lookback_seconds * 1_000_000_000)
|
||||
window_end = e2e_ts + int(args.lookahead_seconds * 1_000_000_000)
|
||||
inside = window_start <= oracle_ts <= window_end
|
||||
|
||||
rows.append(
|
||||
{
|
||||
"target_id": target_id,
|
||||
"in_candidate_universe": True,
|
||||
"oracle_anchor_event_id": oracle_event_id,
|
||||
"oracle_anchor_timestamp_nanos": oracle_ts,
|
||||
"e2e_anchor_event_id": e2e_event_id,
|
||||
"e2e_anchor_timestamp_nanos": e2e_ts,
|
||||
"delta_seconds": delta_seconds,
|
||||
"oracle_inside_e2e_window": inside,
|
||||
"fallback_used": anchor.fallback_used,
|
||||
"anchor_strategy": anchor.strategy,
|
||||
"triggering_signals": list(anchor.triggering_signals),
|
||||
"weak_signal_events_count": len(profile_row.get("weak_signal_events") or []),
|
||||
"weak_signal_events_truncated": bool(profile_row.get("weak_signal_events_truncated")),
|
||||
"reason": anchor.reason,
|
||||
"atom_id": oracle.get("atom_id"),
|
||||
"process_path": oracle.get("process_path"),
|
||||
"weak_signal_score": profile_row.get("weak_signal_score"),
|
||||
}
|
||||
)
|
||||
|
||||
Path(args.out_jsonl).parent.mkdir(parents=True, exist_ok=True)
|
||||
with Path(args.out_jsonl).open("w", encoding="utf-8") as out:
|
||||
for row in rows:
|
||||
out.write(json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n")
|
||||
|
||||
Path(args.out_markdown).parent.mkdir(parents=True, exist_ok=True)
|
||||
Path(args.out_markdown).write_text(
|
||||
_render_markdown(rows, args), encoding="utf-8"
|
||||
)
|
||||
|
||||
summary = _summarize(rows)
|
||||
print(
|
||||
"[anchor-coverage] subjects={total} in_universe={in_u} "
|
||||
"anchor_inside_window={inside} fallback={fb} no_anchor={no_anchor}".format(
|
||||
total=summary["total"],
|
||||
in_u=summary["in_universe"],
|
||||
inside=summary["inside_window"],
|
||||
fb=summary["fallback_used"],
|
||||
no_anchor=summary["no_e2e_anchor"],
|
||||
)
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
def _index_universe(path: str) -> dict[str, dict[str, Any]]:
|
||||
index: dict[str, dict[str, Any]] = {}
|
||||
with Path(path).open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if not line.strip():
|
||||
continue
|
||||
row = json.loads(line)
|
||||
cid = row.get("candidate_id") or row.get("target_id")
|
||||
if cid:
|
||||
index[str(cid)] = row
|
||||
return index
|
||||
|
||||
|
||||
def _read_jsonl(path: str) -> list[dict[str, Any]]:
|
||||
rows: list[dict[str, Any]] = []
|
||||
with Path(path).open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if line.strip():
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
def _summarize(rows: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
total = len(rows)
|
||||
in_universe = sum(1 for r in rows if r["in_candidate_universe"])
|
||||
inside_window = sum(1 for r in rows if r.get("oracle_inside_e2e_window"))
|
||||
fallback_used = sum(1 for r in rows if r.get("fallback_used"))
|
||||
no_e2e_anchor = sum(1 for r in rows if r.get("e2e_anchor_event_id") is None)
|
||||
deltas = [r["delta_seconds"] for r in rows if isinstance(r.get("delta_seconds"), (int, float))]
|
||||
abs_deltas = [abs(d) for d in deltas]
|
||||
summary = {
|
||||
"total": total,
|
||||
"in_universe": in_universe,
|
||||
"inside_window": inside_window,
|
||||
"fallback_used": fallback_used,
|
||||
"no_e2e_anchor": no_e2e_anchor,
|
||||
"anchor_recall_at_window": (inside_window / total) if total else None,
|
||||
"abs_delta_seconds_median": statistics.median(abs_deltas) if abs_deltas else None,
|
||||
"abs_delta_seconds_p90": _percentile(abs_deltas, 0.9) if abs_deltas else None,
|
||||
"abs_delta_seconds_p99": _percentile(abs_deltas, 0.99) if abs_deltas else None,
|
||||
"abs_delta_seconds_max": max(abs_deltas) if abs_deltas else None,
|
||||
}
|
||||
return summary
|
||||
|
||||
|
||||
def _percentile(values: list[float], q: float) -> float:
|
||||
if not values:
|
||||
return float("nan")
|
||||
ordered = sorted(values)
|
||||
k = max(0, min(len(ordered) - 1, int(round(q * (len(ordered) - 1)))))
|
||||
return ordered[k]
|
||||
|
||||
|
||||
def _render_markdown(rows: list[dict[str, Any]], args: argparse.Namespace) -> str:
|
||||
summary = _summarize(rows)
|
||||
lines = [
|
||||
"# Anchor Coverage Audit",
|
||||
"",
|
||||
"This audit measures the deployable anchor strategy against the GT-derived",
|
||||
"oracle anchor. The headline number is `anchor_recall_at_window` — the",
|
||||
"ceiling on end-to-end LLM recall under the chosen anchor strategy and",
|
||||
"lookback/lookahead window.",
|
||||
"",
|
||||
f"- oracle_targets: `{args.oracle_targets}`",
|
||||
f"- candidate_universe: `{args.candidate_universe}`",
|
||||
f"- anchor_strategy: `{args.anchor_strategy}`",
|
||||
f"- lookback_seconds: {args.lookback_seconds}",
|
||||
f"- lookahead_seconds: {args.lookahead_seconds}",
|
||||
"",
|
||||
"## Aggregate",
|
||||
"",
|
||||
f"- ground_truth_positive_subjects: {summary['total']}",
|
||||
f"- in_candidate_universe: {summary['in_universe']}",
|
||||
f"- end_to_end_anchor_resolved: {summary['total'] - summary['no_e2e_anchor']}",
|
||||
f"- end_to_end_anchor_used_fallback: {summary['fallback_used']}",
|
||||
f"- oracle_anchor_inside_e2e_window: {summary['inside_window']}",
|
||||
(
|
||||
"- **anchor_recall_at_window**: "
|
||||
f"{summary['anchor_recall_at_window']:.3f}"
|
||||
if summary["anchor_recall_at_window"] is not None
|
||||
else "- anchor_recall_at_window: n/a"
|
||||
),
|
||||
"",
|
||||
"## |delta_seconds| distribution (oracle_ts - e2e_ts)",
|
||||
"",
|
||||
f"- median: {summary['abs_delta_seconds_median']}",
|
||||
f"- p90: {summary['abs_delta_seconds_p90']}",
|
||||
f"- p99: {summary['abs_delta_seconds_p99']}",
|
||||
f"- max: {summary['abs_delta_seconds_max']}",
|
||||
"",
|
||||
"## Failure breakdown",
|
||||
"",
|
||||
]
|
||||
failures = [r for r in rows if not r.get("oracle_inside_e2e_window")]
|
||||
if not failures:
|
||||
lines.append("- (none)")
|
||||
else:
|
||||
reasons: dict[str, int] = {}
|
||||
for r in failures:
|
||||
key = r.get("reason") or "unknown"
|
||||
reasons[key] = reasons.get(key, 0) + 1
|
||||
for reason, count in sorted(reasons.items(), key=lambda kv: -kv[1]):
|
||||
lines.append(f"- {reason}: {count}")
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"## Interpretation",
|
||||
"",
|
||||
"- If `anchor_recall_at_window` is well below 1.0, the anchor strategy",
|
||||
" is the bottleneck — even a perfect LLM cannot exceed this number.",
|
||||
" Either widen the window, switch to multi-anchor lifecycle tiling",
|
||||
" (`select_anchors_for_lifecycle`), or expand the weak-signal set.",
|
||||
"- If `fallback_used` is high, many candidates have no weak-signal",
|
||||
" trigger at all; consider whether they should be filtered out of",
|
||||
" the candidate universe or treated as low-priority.",
|
||||
"- The oracle column shows what the GT-coupled pipeline was",
|
||||
" effectively assuming — any AUPRC delta between GT-anchored runs",
|
||||
" and end-to-end runs lower-bounds the oracle leakage.",
|
||||
]
|
||||
)
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
348
scripts/build_hybrid_community_prompts.py
Normal file
348
scripts/build_hybrid_community_prompts.py
Normal file
@@ -0,0 +1,348 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Render one hybrid (community + v0.1 fine-grained) prompt per landmark community.
|
||||
|
||||
Hybrid pipeline:
|
||||
1) read landmark communities from Phase 14 output;
|
||||
2) re-stream the raw THEIA corpus once and demux each event into
|
||||
per-community fine-grained subgraphs (community_to_subgraph);
|
||||
3) on each subgraph, run v0.1 APT metapath extraction +
|
||||
temporal-security-aware trimming;
|
||||
4) compose a layered prompt: community overview + landmark skeleton
|
||||
+ landmark bridges + per-metapath blocks (DGP path summary +
|
||||
numerical aggregate + APT stats + evidence path ids).
|
||||
|
||||
Writes:
|
||||
- prompts/<community_id>.txt
|
||||
- prompt_metadata.jsonl — one row per prompt with label + community
|
||||
summary + subgraph stats (entity / event counts, truncation flag,
|
||||
metapath hits). Labels are attached to metadata only, never enter
|
||||
the prompt body.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
SRC = Path(__file__).resolve().parent.parent / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from er_tp_dgp.community_to_subgraph import build_community_subgraphs # noqa: E402
|
||||
from er_tp_dgp.hybrid_prompt import ( # noqa: E402
|
||||
HybridCommunityPromptBuilder,
|
||||
HybridPromptSwitches,
|
||||
)
|
||||
from er_tp_dgp.landmark import ( # noqa: E402
|
||||
LandmarkEdge,
|
||||
LandmarkEvent,
|
||||
read_communities_jsonl,
|
||||
)
|
||||
from er_tp_dgp.theia import discover_theia_json_files # noqa: E402
|
||||
|
||||
|
||||
def _stream_filter_landmarks(
|
||||
path: Path, allowed_ids: set[str]
|
||||
) -> dict[str, LandmarkEvent]:
|
||||
"""Stream-read landmarks.jsonl and keep only rows whose event_id is in allowed_ids.
|
||||
|
||||
The landmarks file is multi-GB on real datasets — a full ``read_landmarks_jsonl``
|
||||
eats hundreds of GB of RAM and minutes of wall time. We need only the
|
||||
landmarks referenced by the selected communities.
|
||||
"""
|
||||
out: dict[str, LandmarkEvent] = {}
|
||||
if not allowed_ids:
|
||||
return out
|
||||
needed = set(allowed_ids)
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if not line.strip():
|
||||
continue
|
||||
r = json.loads(line)
|
||||
event_id = r.get("event_id")
|
||||
if event_id not in needed:
|
||||
continue
|
||||
out[event_id] = LandmarkEvent(
|
||||
event_id=event_id,
|
||||
timestamp_nanos=r["timestamp_nanos"],
|
||||
host_id=r.get("host_id"),
|
||||
actor_subject_id=r["actor_subject_id"],
|
||||
actor_path=r.get("actor_path"),
|
||||
object_id=r.get("object_id"),
|
||||
object_type=r.get("object_type"),
|
||||
object_summary=r.get("object_summary"),
|
||||
canonical_action=r["canonical_action"],
|
||||
raw_event_type=r["raw_event_type"],
|
||||
signals=tuple(r.get("signals") or ()),
|
||||
metapath_hints=tuple(r.get("metapath_hints") or ()),
|
||||
landmark_classes=tuple(r.get("landmark_classes") or ()),
|
||||
)
|
||||
if len(out) == len(needed):
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def _stream_filter_edges(
|
||||
path: Path, allowed_ids: set[str]
|
||||
) -> dict[str, LandmarkEdge]:
|
||||
"""Stream-read landmark_edges.jsonl with allowed_ids filter."""
|
||||
out: dict[str, LandmarkEdge] = {}
|
||||
if not allowed_ids:
|
||||
return out
|
||||
needed = set(allowed_ids)
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if not line.strip():
|
||||
continue
|
||||
r = json.loads(line)
|
||||
edge_id = r.get("edge_id")
|
||||
if edge_id not in needed:
|
||||
continue
|
||||
out[edge_id] = LandmarkEdge(
|
||||
edge_id=edge_id,
|
||||
src_event_id=r["src_event_id"],
|
||||
dst_event_id=r["dst_event_id"],
|
||||
host_id=r.get("host_id"),
|
||||
delta_nanos=r["delta_nanos"],
|
||||
bridge_hops=r["bridge_hops"],
|
||||
bridge_summary=r["bridge_summary"],
|
||||
)
|
||||
if len(out) == len(needed):
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--communities", required=True)
|
||||
parser.add_argument("--landmarks", required=True)
|
||||
parser.add_argument("--landmark-edges", required=True)
|
||||
parser.add_argument(
|
||||
"--labeled-communities",
|
||||
default=None,
|
||||
help="Optional. Adds label/atom_id to prompt_metadata.jsonl, never the prompt body.",
|
||||
)
|
||||
parser.add_argument("--data-dir", default="data/raw/e3_theia_json")
|
||||
parser.add_argument("--input-file", action="append", default=None)
|
||||
parser.add_argument("--output-dir", required=True)
|
||||
parser.add_argument("--margin-seconds", type=float, default=60.0)
|
||||
parser.add_argument("--max-events-per-community", type=int, default=5000)
|
||||
parser.add_argument("--max-landmarks-in-prompt", type=int, default=60)
|
||||
parser.add_argument("--max-edges-in-prompt", type=int, default=80)
|
||||
parser.add_argument("--top-m-per-metapath", type=int, default=5)
|
||||
parser.add_argument("--max-prompts", type=int, default=None)
|
||||
parser.add_argument("--progress-every", type=int, default=2_000_000)
|
||||
parser.add_argument("--max-lines", type=int, default=None)
|
||||
parser.add_argument("--max-lines-per-file", type=int, default=None)
|
||||
parser.add_argument(
|
||||
"--include-only",
|
||||
choices=("all", "malicious", "balanced"),
|
||||
default="balanced",
|
||||
help=(
|
||||
"Which communities to render. 'malicious' = only GT-malicious, "
|
||||
"'balanced' = all malicious + a random benign sample (--benign-per-malicious)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--benign-per-malicious",
|
||||
type=int,
|
||||
default=24,
|
||||
help="When --include-only=balanced, sample this many benign per malicious.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=7)
|
||||
args = parser.parse_args()
|
||||
|
||||
paths = (
|
||||
[Path(p) for p in args.input_file]
|
||||
if args.input_file
|
||||
else discover_theia_json_files(args.data_dir)
|
||||
)
|
||||
if not paths:
|
||||
raise SystemExit(f"No THEIA JSON files found at {args.data_dir}")
|
||||
|
||||
print("[hybrid] reading communities...", flush=True)
|
||||
communities = read_communities_jsonl(args.communities)
|
||||
print(f"[hybrid] communities loaded: {len(communities)}", flush=True)
|
||||
|
||||
label_index: dict[str, dict] = {}
|
||||
if args.labeled_communities:
|
||||
with Path(args.labeled_communities).open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if not line.strip():
|
||||
continue
|
||||
row = json.loads(line)
|
||||
label_index[row["community_id"]] = row
|
||||
|
||||
# --- selection ---------------------------------------------------- #
|
||||
if args.include_only != "all":
|
||||
if not label_index:
|
||||
raise SystemExit("--include-only != all requires --labeled-communities")
|
||||
if args.include_only == "malicious":
|
||||
communities = [
|
||||
c for c in communities
|
||||
if label_index.get(c.community_id, {}).get("label") == "malicious"
|
||||
]
|
||||
elif args.include_only == "balanced":
|
||||
import random
|
||||
|
||||
rng = random.Random(args.seed)
|
||||
mal = [
|
||||
c for c in communities
|
||||
if label_index.get(c.community_id, {}).get("label") == "malicious"
|
||||
]
|
||||
ben = [
|
||||
c for c in communities
|
||||
if label_index.get(c.community_id, {}).get("label") == "benign"
|
||||
]
|
||||
rng.shuffle(ben)
|
||||
target_ben = max(1, args.benign_per_malicious * max(1, len(mal)))
|
||||
communities = mal + ben[:target_ben]
|
||||
communities.sort(
|
||||
key=lambda c: (-len(c.landmark_event_ids), c.start_timestamp_nanos, c.community_id)
|
||||
)
|
||||
|
||||
if args.max_prompts is not None:
|
||||
communities = communities[: args.max_prompts]
|
||||
|
||||
print(
|
||||
f"[hybrid] selected {len(communities)} communities "
|
||||
f"({sum(1 for c in communities if label_index.get(c.community_id, {}).get('label') == 'malicious')} malicious)",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# --- stream-filtered loads of landmarks + edges ------------------- #
|
||||
needed_landmark_ids: set[str] = set()
|
||||
needed_edge_ids: set[str] = set()
|
||||
for c in communities:
|
||||
needed_landmark_ids.update(c.landmark_event_ids)
|
||||
needed_edge_ids.update(c.edge_ids)
|
||||
print(
|
||||
f"[hybrid] need {len(needed_landmark_ids)} landmark rows / "
|
||||
f"{len(needed_edge_ids)} edge rows from disk",
|
||||
flush=True,
|
||||
)
|
||||
print("[hybrid] stream-loading landmarks...", flush=True)
|
||||
landmarks_by_id = _stream_filter_landmarks(Path(args.landmarks), needed_landmark_ids)
|
||||
print(f"[hybrid] landmarks loaded: {len(landmarks_by_id)}", flush=True)
|
||||
print("[hybrid] stream-loading edges...", flush=True)
|
||||
edges_by_id = _stream_filter_edges(Path(args.landmark_edges), needed_edge_ids)
|
||||
print(f"[hybrid] edges loaded: {len(edges_by_id)}", flush=True)
|
||||
|
||||
# --- materialize fine-grained subgraphs (single THEIA pass) ------- #
|
||||
print(f"[hybrid] streaming THEIA from {len(paths)} files to build subgraphs...", flush=True)
|
||||
subgraphs = build_community_subgraphs(
|
||||
communities,
|
||||
paths,
|
||||
margin_seconds=args.margin_seconds,
|
||||
max_events_per_community=args.max_events_per_community,
|
||||
max_lines=args.max_lines,
|
||||
max_lines_per_file=args.max_lines_per_file,
|
||||
progress_every=args.progress_every,
|
||||
)
|
||||
truncated = sum(1 for s in subgraphs.values() if s.truncated)
|
||||
total_events = sum(len(s.events) for s in subgraphs.values())
|
||||
total_entities = sum(len(s.entities) for s in subgraphs.values())
|
||||
print(
|
||||
f"[hybrid] subgraphs ready: communities={len(subgraphs)} "
|
||||
f"truncated={truncated} total_events={total_events} total_entities={total_entities}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# --- build hybrid prompts ----------------------------------------- #
|
||||
output_dir = Path(args.output_dir)
|
||||
prompts_dir = output_dir / "prompts"
|
||||
prompts_dir.mkdir(parents=True, exist_ok=True)
|
||||
metadata_path = output_dir / "prompt_metadata.jsonl"
|
||||
|
||||
builder = HybridCommunityPromptBuilder(
|
||||
landmarks_by_id=landmarks_by_id,
|
||||
edges_by_id=edges_by_id,
|
||||
# No NodeText / PathSumm summarizers — keeps the experiment
|
||||
# cost-bounded and removes a confounder. Switches below disable them.
|
||||
node_summarizer=None,
|
||||
path_summarizer=None,
|
||||
switches=HybridPromptSwitches(
|
||||
use_text_summarization=False,
|
||||
use_path_summarization_llm=False,
|
||||
use_numerical_aggregation_dgp=True,
|
||||
use_apt_numerical_stats=True,
|
||||
include_evidence_ids=True,
|
||||
include_landmark_skeleton=True,
|
||||
include_landmark_bridges=True,
|
||||
max_landmarks_in_prompt=args.max_landmarks_in_prompt,
|
||||
max_edges_in_prompt=args.max_edges_in_prompt,
|
||||
top_m_per_metapath=args.top_m_per_metapath,
|
||||
),
|
||||
)
|
||||
|
||||
written = 0
|
||||
with metadata_path.open("w", encoding="utf-8") as meta_out:
|
||||
for community in communities:
|
||||
sub = subgraphs.get(community.community_id)
|
||||
if sub is None:
|
||||
# Stream filter produced nothing for this community — emit
|
||||
# a stub prompt with empty metapath blocks rather than
|
||||
# silently dropping it (we want to count this in metrics).
|
||||
continue
|
||||
bundle = builder.build(community, sub)
|
||||
(prompts_dir / f"{community.community_id}.txt").write_text(
|
||||
bundle.prompt_text, encoding="utf-8"
|
||||
)
|
||||
label_row = label_index.get(community.community_id) or {}
|
||||
meta_out.write(
|
||||
json.dumps(
|
||||
{
|
||||
"community_id": community.community_id,
|
||||
"host_id": community.host_id,
|
||||
"label": label_row.get("label", "unlabeled"),
|
||||
"label_source": label_row.get(
|
||||
"label_source", "no_ground_truth_join"
|
||||
),
|
||||
"gt_atoms_hit": label_row.get("gt_atoms_hit") or [],
|
||||
"gt_subjects_hit": label_row.get("gt_subjects_hit") or [],
|
||||
"span_seconds": community.span_seconds,
|
||||
"subjects_in_community": len(community.subjects),
|
||||
"num_landmarks_total": len(community.landmark_event_ids),
|
||||
"num_landmarks_in_prompt": bundle.metadata[
|
||||
"num_landmarks_in_prompt"
|
||||
],
|
||||
"num_edges_total": len(community.edge_ids),
|
||||
"num_edges_in_prompt": bundle.metadata["num_edges_in_prompt"],
|
||||
"subgraph_entities_count": bundle.metadata[
|
||||
"subgraph_entities_count"
|
||||
],
|
||||
"subgraph_events_count": bundle.metadata[
|
||||
"subgraph_events_count"
|
||||
],
|
||||
"subgraph_truncated": bundle.metadata["subgraph_truncated"],
|
||||
"metapath_paths_extracted": bundle.metadata[
|
||||
"metapath_paths_extracted"
|
||||
],
|
||||
"metapath_paths_after_trim": bundle.metadata[
|
||||
"metapath_paths_after_trim"
|
||||
],
|
||||
"selected_landmark_ids": list(bundle.selected_landmark_ids),
|
||||
"evidence_path_ids": list(bundle.evidence_path_ids),
|
||||
"prompt_path": str(
|
||||
(prompts_dir / f"{community.community_id}.txt").resolve()
|
||||
),
|
||||
"prompt_char_length": len(bundle.prompt_text),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
sort_keys=True,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
written += 1
|
||||
print(
|
||||
f"[hybrid] wrote {written} prompts to {prompts_dir} "
|
||||
f"and metadata to {metadata_path}",
|
||||
flush=True,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
59
scripts/build_hybrid_labeled_targets.py
Normal file
59
scripts/build_hybrid_labeled_targets.py
Normal file
@@ -0,0 +1,59 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Convert hybrid prompt_metadata.jsonl → labeled_targets.jsonl for run_evaluation.py.
|
||||
|
||||
Hybrid prompts use ``community_id`` as the prompt id; the v0.1 evaluator
|
||||
expects ``target_id``. This script does the rename and emits a minimal
|
||||
labeled_targets.jsonl with the fields the evaluator needs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--prompt-metadata", required=True)
|
||||
parser.add_argument("--output", required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
out = Path(args.output)
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
written = 0
|
||||
with Path(args.prompt_metadata).open("r", encoding="utf-8") as inp, out.open(
|
||||
"w", encoding="utf-8"
|
||||
) as outf:
|
||||
for line in inp:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
label = row.get("label", "unlabeled")
|
||||
if label not in {"malicious", "benign"}:
|
||||
continue # skip unlabeled communities
|
||||
payload = {
|
||||
"target_id": row["community_id"],
|
||||
"target_type": "COMMUNITY_SUBGRAPH",
|
||||
"label": label,
|
||||
"label_confidence": "high" if label == "malicious" else "default",
|
||||
"label_source": row.get("label_source", "no_ground_truth_join"),
|
||||
"anchor_event_id": row.get("selected_landmark_ids", [""])[0]
|
||||
if row.get("selected_landmark_ids")
|
||||
else "",
|
||||
"host_id": row.get("host_id"),
|
||||
"span_seconds": row.get("span_seconds"),
|
||||
"subjects_in_community": row.get("subjects_in_community"),
|
||||
"num_landmarks_total": row.get("num_landmarks_total"),
|
||||
"subgraph_events_count": row.get("subgraph_events_count"),
|
||||
"gt_atoms_hit": row.get("gt_atoms_hit") or [],
|
||||
}
|
||||
outf.write(json.dumps(payload, ensure_ascii=False, sort_keys=True) + "\n")
|
||||
written += 1
|
||||
print(f"wrote {written} labeled targets to {out}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
62
scripts/build_labeled_eval_batch.py
Normal file
62
scripts/build_labeled_eval_batch.py
Normal file
@@ -0,0 +1,62 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Build a protocol-based labeled target batch for prompt generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.evaluation_batch import build_evaluation_batch
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Build labeled target metadata. Labels remain evaluation-only and are not prompt input."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--positive-process-labels",
|
||||
default="reports/ground_truth/e3_mapping_ioc_files_time/process_labels_high_plus.jsonl",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--positive-event-matches",
|
||||
default="reports/ground_truth/e3_mapping_ioc_files_time/event_matches_high_plus.jsonl",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--all-mapped-process-labels",
|
||||
default="reports/ground_truth/e3_mapping_ioc_files_time/process_labels.jsonl",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--candidate-universe",
|
||||
default="reports/theia_candidate_universe_ioc_files/candidate_universe.jsonl",
|
||||
)
|
||||
parser.add_argument("--output-dir", default="reports/evaluation/e3_theia_v0_1")
|
||||
parser.add_argument("--num-positives", type=int, default=8)
|
||||
parser.add_argument("--num-hard-negative-proxies", type=int, default=8)
|
||||
parser.add_argument("--max-hard-negative-events", type=int, default=1000)
|
||||
parser.add_argument("--seed", type=int, default=7)
|
||||
args = parser.parse_args()
|
||||
|
||||
batch = build_evaluation_batch(
|
||||
positive_process_labels_path=args.positive_process_labels,
|
||||
positive_event_matches_path=args.positive_event_matches,
|
||||
candidate_universe_path=args.candidate_universe,
|
||||
all_mapped_process_labels_path=args.all_mapped_process_labels,
|
||||
num_positives=args.num_positives,
|
||||
num_hard_negative_proxies=args.num_hard_negative_proxies,
|
||||
max_hard_negative_events=args.max_hard_negative_events,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
targets_path = output_dir / "labeled_targets.jsonl"
|
||||
report_path = output_dir / "labeled_targets.md"
|
||||
batch.write_jsonl(targets_path)
|
||||
report_path.write_text(batch.to_markdown() + "\n", encoding="utf-8")
|
||||
print(f"targets={len(batch.targets)}")
|
||||
print(f"wrote {targets_path}")
|
||||
print(f"wrote {report_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
184
scripts/build_landmark_graph.py
Normal file
184
scripts/build_landmark_graph.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Stream the THEIA corpus once and emit the Landmark-Bridged Causal Story Graph.
|
||||
|
||||
Outputs:
|
||||
- landmarks.jsonl — one row per landmark event
|
||||
- landmark_edges.jsonl — one row per landmark→landmark causal bridge
|
||||
- landmark_communities.jsonl — one row per detection unit (subgraph)
|
||||
- landmark_stats.json — corpus-level counts and class histogram
|
||||
|
||||
This script is the construction phase of Phase 14. Detection (per-community
|
||||
LLM prompting) is a separate step (`build_landmark_prompts.py`).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
SRC = Path(__file__).resolve().parent.parent / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from er_tp_dgp.landmark import ( # noqa: E402
|
||||
StreamingLandmarkGraphBuilder,
|
||||
compute_landmark_communities,
|
||||
write_communities_jsonl,
|
||||
write_edges_jsonl,
|
||||
write_landmarks_jsonl,
|
||||
)
|
||||
from er_tp_dgp.theia import discover_theia_json_files, iter_theia_records # noqa: E402
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--data-dir", default="data/raw/e3_theia_json")
|
||||
parser.add_argument("--input-file", action="append", default=None)
|
||||
parser.add_argument("--output-dir", default="reports/landmark_csg")
|
||||
parser.add_argument("--progress-every", type=int, default=1_000_000)
|
||||
parser.add_argument(
|
||||
"--k-ancestors",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Per-entity ancestor cache size. Bigger = denser landmark edges.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-bridge-seconds",
|
||||
type=float,
|
||||
default=600.0,
|
||||
help="Drop ancestor→landmark edges whose time delta exceeds this.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-edges-per-landmark-in",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Cap inbound edges per landmark to keep the graph sparse.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--silence-split-seconds",
|
||||
type=float,
|
||||
default=300.0,
|
||||
help="Inside a connected component, split on landmark gaps wider than this.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-community-landmarks",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Drop communities smaller than this (singletons are not stories).",
|
||||
)
|
||||
parser.add_argument("--max-lines", type=int, default=None)
|
||||
parser.add_argument("--max-lines-per-file", type=int, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
paths = (
|
||||
[Path(p) for p in args.input_file]
|
||||
if args.input_file
|
||||
else discover_theia_json_files(args.data_dir)
|
||||
)
|
||||
if not paths:
|
||||
raise SystemExit(f"no THEIA JSON files found under {args.data_dir}")
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(
|
||||
f"[start] paths={len(paths)} k_ancestors={args.k_ancestors} "
|
||||
f"max_bridge_seconds={args.max_bridge_seconds} "
|
||||
f"max_edges_per_landmark_in={args.max_edges_per_landmark_in}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
builder = StreamingLandmarkGraphBuilder(
|
||||
k_ancestors_per_entity=args.k_ancestors,
|
||||
max_bridge_nanos=int(args.max_bridge_seconds * 1_000_000_000),
|
||||
max_edges_per_landmark_in=args.max_edges_per_landmark_in,
|
||||
)
|
||||
builder.feed_iterable(
|
||||
iter_theia_records(
|
||||
paths,
|
||||
max_lines=args.max_lines,
|
||||
max_lines_per_file=args.max_lines_per_file,
|
||||
),
|
||||
progress_every=args.progress_every,
|
||||
)
|
||||
landmarks, edges, stats = builder.finalize()
|
||||
|
||||
print(
|
||||
f"[built] records={stats.records_seen} events={stats.events_seen} "
|
||||
f"landmarks={stats.landmarks} edges={stats.edges} "
|
||||
f"edges_skipped_time={stats.edges_skipped_time} "
|
||||
f"edges_skipped_self={stats.edges_skipped_self}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
print("[community] computing weakly connected components + temporal split", flush=True)
|
||||
communities = compute_landmark_communities(
|
||||
landmarks,
|
||||
edges,
|
||||
min_landmarks=args.min_community_landmarks,
|
||||
silence_split_seconds=args.silence_split_seconds,
|
||||
)
|
||||
print(f"[community] {len(communities)} communities produced", flush=True)
|
||||
|
||||
landmarks_path = output_dir / "landmarks.jsonl"
|
||||
edges_path = output_dir / "landmark_edges.jsonl"
|
||||
communities_path = output_dir / "landmark_communities.jsonl"
|
||||
stats_path = output_dir / "landmark_stats.json"
|
||||
|
||||
write_landmarks_jsonl(landmarks, landmarks_path)
|
||||
write_edges_jsonl(edges, edges_path)
|
||||
write_communities_jsonl(communities, communities_path)
|
||||
|
||||
summary = {
|
||||
"records_seen": stats.records_seen,
|
||||
"events_seen": stats.events_seen,
|
||||
"landmarks": stats.landmarks,
|
||||
"edges": stats.edges,
|
||||
"edges_skipped_time": stats.edges_skipped_time,
|
||||
"edges_skipped_self": stats.edges_skipped_self,
|
||||
"landmarks_by_class": dict(stats.landmarks_by_class),
|
||||
"communities": len(communities),
|
||||
"community_size_min": min((len(c.landmark_event_ids) for c in communities), default=0),
|
||||
"community_size_max": max((len(c.landmark_event_ids) for c in communities), default=0),
|
||||
"community_size_p50": _percentile(
|
||||
[len(c.landmark_event_ids) for c in communities], 0.5
|
||||
),
|
||||
"community_size_p90": _percentile(
|
||||
[len(c.landmark_event_ids) for c in communities], 0.9
|
||||
),
|
||||
"community_size_p99": _percentile(
|
||||
[len(c.landmark_event_ids) for c in communities], 0.99
|
||||
),
|
||||
"config": {
|
||||
"k_ancestors": args.k_ancestors,
|
||||
"max_bridge_seconds": args.max_bridge_seconds,
|
||||
"max_edges_per_landmark_in": args.max_edges_per_landmark_in,
|
||||
"silence_split_seconds": args.silence_split_seconds,
|
||||
"min_community_landmarks": args.min_community_landmarks,
|
||||
},
|
||||
"files": {
|
||||
"landmarks": str(landmarks_path),
|
||||
"landmark_edges": str(edges_path),
|
||||
"landmark_communities": str(communities_path),
|
||||
},
|
||||
}
|
||||
stats_path.write_text(json.dumps(summary, indent=2, sort_keys=True), encoding="utf-8")
|
||||
print(f"[write] {landmarks_path}", flush=True)
|
||||
print(f"[write] {edges_path}", flush=True)
|
||||
print(f"[write] {communities_path}", flush=True)
|
||||
print(f"[write] {stats_path}", flush=True)
|
||||
return 0
|
||||
|
||||
|
||||
def _percentile(values: list[int], q: float) -> int | None:
|
||||
if not values:
|
||||
return None
|
||||
ordered = sorted(values)
|
||||
k = max(0, min(len(ordered) - 1, int(round(q * (len(ordered) - 1)))))
|
||||
return ordered[k]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
156
scripts/build_landmark_prompts.py
Normal file
156
scripts/build_landmark_prompts.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Render one LLM prompt per landmark community.
|
||||
|
||||
Reads:
|
||||
- communities (output of build_landmark_graph.py)
|
||||
- landmarks (output of build_landmark_graph.py)
|
||||
- landmark_edges (output of build_landmark_graph.py)
|
||||
- labeled_communities (output of evaluate_landmark_detection.py) — labels
|
||||
are *only* attached to per-prompt metadata for downstream evaluation;
|
||||
they never enter the prompt body.
|
||||
|
||||
Writes:
|
||||
- prompts/<community_id>.txt
|
||||
- prompt_metadata.jsonl — one row per prompt with label + community
|
||||
summary, suitable for downstream LLM-runner + AUPRC computation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
SRC = Path(__file__).resolve().parent.parent / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from er_tp_dgp.landmark import ( # noqa: E402
|
||||
read_communities_jsonl,
|
||||
read_edges_jsonl,
|
||||
read_landmarks_jsonl,
|
||||
)
|
||||
from er_tp_dgp.landmark_prompt import ( # noqa: E402
|
||||
CommunityPromptSwitches,
|
||||
LandmarkCommunityPromptBuilder,
|
||||
)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--communities", required=True)
|
||||
parser.add_argument("--landmarks", required=True)
|
||||
parser.add_argument("--landmark-edges", required=True)
|
||||
parser.add_argument(
|
||||
"--labeled-communities",
|
||||
default=None,
|
||||
help="Optional. Adds label/atom_id to prompt_metadata.jsonl, never the prompt body.",
|
||||
)
|
||||
parser.add_argument("--output-dir", required=True)
|
||||
parser.add_argument("--max-landmarks-in-prompt", type=int, default=60)
|
||||
parser.add_argument("--max-edges-in-prompt", type=int, default=80)
|
||||
parser.add_argument("--max-prompts", type=int, default=None)
|
||||
parser.add_argument(
|
||||
"--include-only",
|
||||
choices=("all", "malicious", "balanced"),
|
||||
default="all",
|
||||
help=(
|
||||
"Which communities to render. 'malicious' = only GT-malicious, "
|
||||
"'balanced' = all malicious + an equal-sized random benign sample."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=7)
|
||||
args = parser.parse_args()
|
||||
|
||||
communities = read_communities_jsonl(args.communities)
|
||||
landmarks = read_landmarks_jsonl(args.landmarks)
|
||||
edges = read_edges_jsonl(args.landmark_edges)
|
||||
landmarks_by_id = {lm.event_id: lm for lm in landmarks}
|
||||
edges_by_id = {edge.edge_id: edge for edge in edges}
|
||||
|
||||
label_index: dict[str, dict] = {}
|
||||
if args.labeled_communities:
|
||||
with Path(args.labeled_communities).open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if not line.strip():
|
||||
continue
|
||||
row = json.loads(line)
|
||||
label_index[row["community_id"]] = row
|
||||
|
||||
if args.include_only != "all":
|
||||
if not label_index:
|
||||
raise SystemExit(
|
||||
"--include-only != all requires --labeled-communities"
|
||||
)
|
||||
if args.include_only == "malicious":
|
||||
communities = [c for c in communities if label_index.get(c.community_id, {}).get("label") == "malicious"]
|
||||
elif args.include_only == "balanced":
|
||||
import random
|
||||
|
||||
rng = random.Random(args.seed)
|
||||
mal = [c for c in communities if label_index.get(c.community_id, {}).get("label") == "malicious"]
|
||||
ben = [c for c in communities if label_index.get(c.community_id, {}).get("label") == "benign"]
|
||||
rng.shuffle(ben)
|
||||
communities = mal + ben[: len(mal)]
|
||||
communities.sort(
|
||||
key=lambda c: (-len(c.landmark_event_ids), c.start_timestamp_nanos, c.community_id)
|
||||
)
|
||||
|
||||
if args.max_prompts is not None:
|
||||
communities = communities[: args.max_prompts]
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
prompts_dir = output_dir / "prompts"
|
||||
prompts_dir.mkdir(parents=True, exist_ok=True)
|
||||
metadata_path = output_dir / "prompt_metadata.jsonl"
|
||||
|
||||
builder = LandmarkCommunityPromptBuilder(
|
||||
landmarks_by_id=landmarks_by_id,
|
||||
edges_by_id=edges_by_id,
|
||||
switches=CommunityPromptSwitches(
|
||||
max_landmarks_in_prompt=args.max_landmarks_in_prompt,
|
||||
max_edges_in_prompt=args.max_edges_in_prompt,
|
||||
),
|
||||
)
|
||||
|
||||
with metadata_path.open("w", encoding="utf-8") as meta_out:
|
||||
for community in communities:
|
||||
bundle = builder.build(community)
|
||||
(prompts_dir / f"{community.community_id}.txt").write_text(
|
||||
bundle.prompt_text, encoding="utf-8"
|
||||
)
|
||||
label_row = label_index.get(community.community_id) or {}
|
||||
meta_out.write(
|
||||
json.dumps(
|
||||
{
|
||||
"community_id": community.community_id,
|
||||
"host_id": community.host_id,
|
||||
"label": label_row.get("label", "unlabeled"),
|
||||
"label_source": label_row.get("label_source", "no_ground_truth_join"),
|
||||
"gt_atoms_hit": label_row.get("gt_atoms_hit") or [],
|
||||
"gt_subjects_hit": label_row.get("gt_subjects_hit") or [],
|
||||
"num_landmarks_total": len(community.landmark_event_ids),
|
||||
"num_landmarks_in_prompt": bundle.metadata["num_landmarks_in_prompt"],
|
||||
"num_edges_total": len(community.edge_ids),
|
||||
"num_edges_in_prompt": bundle.metadata["num_edges_in_prompt"],
|
||||
"span_seconds": community.span_seconds,
|
||||
"subjects_in_community": len(community.subjects),
|
||||
"selected_landmark_ids": list(bundle.selected_landmark_ids),
|
||||
"prompt_path": str((prompts_dir / f"{community.community_id}.txt").resolve()),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
sort_keys=True,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
print(
|
||||
f"[prompts] wrote {len(communities)} prompts to {prompts_dir} "
|
||||
f"and metadata to {metadata_path}",
|
||||
flush=True,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
204
scripts/build_landmark_prompts_for_ids.py
Normal file
204
scripts/build_landmark_prompts_for_ids.py
Normal file
@@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Render Phase 14 raw landmark prompts for a specific list of community IDs.
|
||||
|
||||
For head-to-head comparison with the hybrid pipeline: feed in the same
|
||||
community_ids the hybrid pipeline rendered, get a parallel set of raw
|
||||
landmark-only prompts on the same set.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
SRC = Path(__file__).resolve().parent.parent / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
from er_tp_dgp.landmark import ( # noqa: E402
|
||||
LandmarkEdge,
|
||||
LandmarkEvent,
|
||||
read_communities_jsonl,
|
||||
)
|
||||
from er_tp_dgp.landmark_prompt import ( # noqa: E402
|
||||
CommunityPromptSwitches,
|
||||
LandmarkCommunityPromptBuilder,
|
||||
)
|
||||
|
||||
|
||||
def _stream_filter_landmarks(path: Path, allowed_ids: set[str]) -> dict[str, LandmarkEvent]:
|
||||
out: dict[str, LandmarkEvent] = {}
|
||||
if not allowed_ids:
|
||||
return out
|
||||
needed = set(allowed_ids)
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if not line.strip():
|
||||
continue
|
||||
r = json.loads(line)
|
||||
event_id = r.get("event_id")
|
||||
if event_id not in needed:
|
||||
continue
|
||||
out[event_id] = LandmarkEvent(
|
||||
event_id=event_id,
|
||||
timestamp_nanos=r["timestamp_nanos"],
|
||||
host_id=r.get("host_id"),
|
||||
actor_subject_id=r["actor_subject_id"],
|
||||
actor_path=r.get("actor_path"),
|
||||
object_id=r.get("object_id"),
|
||||
object_type=r.get("object_type"),
|
||||
object_summary=r.get("object_summary"),
|
||||
canonical_action=r["canonical_action"],
|
||||
raw_event_type=r["raw_event_type"],
|
||||
signals=tuple(r.get("signals") or ()),
|
||||
metapath_hints=tuple(r.get("metapath_hints") or ()),
|
||||
landmark_classes=tuple(r.get("landmark_classes") or ()),
|
||||
)
|
||||
if len(out) == len(needed):
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def _stream_filter_edges(path: Path, allowed_ids: set[str]) -> dict[str, LandmarkEdge]:
|
||||
out: dict[str, LandmarkEdge] = {}
|
||||
if not allowed_ids:
|
||||
return out
|
||||
needed = set(allowed_ids)
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if not line.strip():
|
||||
continue
|
||||
r = json.loads(line)
|
||||
edge_id = r.get("edge_id")
|
||||
if edge_id not in needed:
|
||||
continue
|
||||
out[edge_id] = LandmarkEdge(
|
||||
edge_id=edge_id,
|
||||
src_event_id=r["src_event_id"],
|
||||
dst_event_id=r["dst_event_id"],
|
||||
host_id=r.get("host_id"),
|
||||
delta_nanos=r["delta_nanos"],
|
||||
bridge_hops=r["bridge_hops"],
|
||||
bridge_summary=r["bridge_summary"],
|
||||
)
|
||||
if len(out) == len(needed):
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--communities", required=True)
|
||||
parser.add_argument("--landmarks", required=True)
|
||||
parser.add_argument("--landmark-edges", required=True)
|
||||
parser.add_argument(
|
||||
"--ids-from-metadata",
|
||||
required=True,
|
||||
help="prompt_metadata.jsonl from the hybrid run; we replicate its community_id set.",
|
||||
)
|
||||
parser.add_argument("--labeled-communities", default=None)
|
||||
parser.add_argument("--output-dir", required=True)
|
||||
parser.add_argument("--max-landmarks-in-prompt", type=int, default=60)
|
||||
parser.add_argument("--max-edges-in-prompt", type=int, default=80)
|
||||
args = parser.parse_args()
|
||||
|
||||
target_ids: set[str] = set()
|
||||
with Path(args.ids_from_metadata).open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if not line.strip():
|
||||
continue
|
||||
row = json.loads(line)
|
||||
target_ids.add(row["community_id"])
|
||||
print(f"[raw] target community ids: {len(target_ids)}", flush=True)
|
||||
|
||||
print("[raw] reading communities...", flush=True)
|
||||
communities = read_communities_jsonl(args.communities)
|
||||
communities = [c for c in communities if c.community_id in target_ids]
|
||||
print(f"[raw] communities matched: {len(communities)}", flush=True)
|
||||
|
||||
label_index: dict[str, dict] = {}
|
||||
if args.labeled_communities:
|
||||
with Path(args.labeled_communities).open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if not line.strip():
|
||||
continue
|
||||
r = json.loads(line)
|
||||
label_index[r["community_id"]] = r
|
||||
|
||||
needed_lm_ids: set[str] = set()
|
||||
needed_edge_ids: set[str] = set()
|
||||
for c in communities:
|
||||
needed_lm_ids.update(c.landmark_event_ids)
|
||||
needed_edge_ids.update(c.edge_ids)
|
||||
print(
|
||||
f"[raw] need {len(needed_lm_ids)} landmark rows / {len(needed_edge_ids)} edge rows",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
print("[raw] stream-loading landmarks...", flush=True)
|
||||
landmarks_by_id = _stream_filter_landmarks(Path(args.landmarks), needed_lm_ids)
|
||||
print(f"[raw] landmarks loaded: {len(landmarks_by_id)}", flush=True)
|
||||
print("[raw] stream-loading edges...", flush=True)
|
||||
edges_by_id = _stream_filter_edges(Path(args.landmark_edges), needed_edge_ids)
|
||||
print(f"[raw] edges loaded: {len(edges_by_id)}", flush=True)
|
||||
|
||||
out_dir = Path(args.output_dir)
|
||||
prompts_dir = out_dir / "prompts"
|
||||
prompts_dir.mkdir(parents=True, exist_ok=True)
|
||||
metadata_path = out_dir / "prompt_metadata.jsonl"
|
||||
|
||||
builder = LandmarkCommunityPromptBuilder(
|
||||
landmarks_by_id=landmarks_by_id,
|
||||
edges_by_id=edges_by_id,
|
||||
switches=CommunityPromptSwitches(
|
||||
max_landmarks_in_prompt=args.max_landmarks_in_prompt,
|
||||
max_edges_in_prompt=args.max_edges_in_prompt,
|
||||
),
|
||||
)
|
||||
|
||||
written = 0
|
||||
with metadata_path.open("w", encoding="utf-8") as meta_out:
|
||||
for community in communities:
|
||||
bundle = builder.build(community)
|
||||
(prompts_dir / f"{community.community_id}.txt").write_text(
|
||||
bundle.prompt_text, encoding="utf-8"
|
||||
)
|
||||
label_row = label_index.get(community.community_id) or {}
|
||||
meta_out.write(
|
||||
json.dumps(
|
||||
{
|
||||
"community_id": community.community_id,
|
||||
"host_id": community.host_id,
|
||||
"label": label_row.get("label", "unlabeled"),
|
||||
"label_source": label_row.get(
|
||||
"label_source", "no_ground_truth_join"
|
||||
),
|
||||
"gt_atoms_hit": label_row.get("gt_atoms_hit") or [],
|
||||
"num_landmarks_total": len(community.landmark_event_ids),
|
||||
"num_landmarks_in_prompt": bundle.metadata[
|
||||
"num_landmarks_in_prompt"
|
||||
],
|
||||
"num_edges_total": len(community.edge_ids),
|
||||
"num_edges_in_prompt": bundle.metadata["num_edges_in_prompt"],
|
||||
"span_seconds": community.span_seconds,
|
||||
"subjects_in_community": len(community.subjects),
|
||||
"selected_landmark_ids": list(bundle.selected_landmark_ids),
|
||||
"prompt_path": str(
|
||||
(prompts_dir / f"{community.community_id}.txt").resolve()
|
||||
),
|
||||
"prompt_char_length": len(bundle.prompt_text),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
sort_keys=True,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
written += 1
|
||||
print(f"[raw] wrote {written} prompts to {prompts_dir}", flush=True)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
486
scripts/build_theia_prompt_batch.py
Normal file
486
scripts/build_theia_prompt_batch.py
Normal file
@@ -0,0 +1,486 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate ER-TP-DGP prompts for a labeled THEIA evaluation batch."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.experiments import default_method_registry
|
||||
from er_tp_dgp.metapaths import APTMetapathExtractor
|
||||
from er_tp_dgp.numerical_aggregator import NumericalAggregator
|
||||
from er_tp_dgp.prompt import PromptBuilder, PromptComponentSwitches
|
||||
from er_tp_dgp.theia import (
|
||||
build_cached_theia_window_ir,
|
||||
build_multi_target_window_irs,
|
||||
discover_theia_json_files,
|
||||
)
|
||||
from er_tp_dgp.trimming import TemporalSecurityAwareTrimmer
|
||||
from er_tp_dgp.validation import validate_evidence_paths, validate_graph, validate_ir
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Build graph-enhanced ER-TP-DGP prompts from labeled target metadata. "
|
||||
"Labels are written only to metadata, never into prompt text."
|
||||
)
|
||||
)
|
||||
parser.add_argument("--targets", default="reports/evaluation/e3_theia_v0_1/labeled_targets.jsonl")
|
||||
parser.add_argument("--data-dir", default="data/raw/e3_theia_json")
|
||||
parser.add_argument(
|
||||
"--input-file",
|
||||
action="append",
|
||||
default=None,
|
||||
help="Specific THEIA JSON file to scan. Can be repeated. Overrides --data-dir discovery.",
|
||||
)
|
||||
parser.add_argument("--output-dir", default="reports/evaluation/e3_theia_v0_1/prompts_graph_dgp_full")
|
||||
parser.add_argument("--lookback-seconds", type=float, default=300.0)
|
||||
parser.add_argument("--lookahead-seconds", type=float, default=300.0)
|
||||
parser.add_argument("--top-m-per-metapath", type=int, default=5)
|
||||
parser.add_argument(
|
||||
"--max-window-events",
|
||||
type=int,
|
||||
default=50000,
|
||||
help=(
|
||||
"Soft audit threshold: windows above this size are recorded in "
|
||||
"prompt_size_audit.jsonl but still proceed to prompt construction "
|
||||
"(trimming controls actual prompt size, not raw window event count)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hard-skip-window-events",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"If set, hard-skip targets whose window exceeds this size. Default "
|
||||
"is no hard skip; only soft audit applies."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-dir",
|
||||
default="reports/cache/theia_window_ir",
|
||||
help="Directory for compressed window-IR snapshots. Pass empty to disable.",
|
||||
)
|
||||
parser.add_argument("--max-targets", type=int, default=None)
|
||||
parser.add_argument(
|
||||
"--include-cohort",
|
||||
action="append",
|
||||
default=None,
|
||||
help="Only include this cohort. Can be repeated.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-per-cohort",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum targets to keep from each cohort after filtering.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--method-variant",
|
||||
default="graph_dgp",
|
||||
help=(
|
||||
"Method variant from experiments.default_method_registry(). "
|
||||
"Drives prompt component switches (TextSumm / MDK / PathSumm / "
|
||||
"NumSumm / TempTrim / SecAware / EvidenceIDs)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--summarizer-config",
|
||||
default=None,
|
||||
help=(
|
||||
"Path to summarizer LLM config (YAML). Required if method variant "
|
||||
"enables DGP TextSumm or PathSumm."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--summarizer-workers",
|
||||
type=int,
|
||||
default=8,
|
||||
help=(
|
||||
"Concurrency for batched LLM summarization (ThreadPoolExecutor). "
|
||||
"Higher values shorten first-cold-cache batches; bound by your "
|
||||
"endpoint's per-key rate limit."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-multi-anchor-prewarm",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Skip the one-time multi-anchor IR prewarm and let the per-target "
|
||||
"loop scan the corpus once per target. Only useful for debugging."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
paths = [Path(path) for path in args.input_file] if args.input_file else discover_theia_json_files(args.data_dir)
|
||||
if not paths:
|
||||
raise SystemExit("no THEIA JSON files found")
|
||||
targets = _read_jsonl(args.targets)
|
||||
targets = _filter_targets(targets, args.include_cohort, args.max_per_cohort)
|
||||
if args.max_targets is not None:
|
||||
targets = targets[: args.max_targets]
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
prompts_dir = output_dir / "prompt_text"
|
||||
validations_dir = output_dir / "validations"
|
||||
prompts_dir.mkdir(parents=True, exist_ok=True)
|
||||
validations_dir.mkdir(parents=True, exist_ok=True)
|
||||
metadata_path = output_dir / "prompt_metadata.jsonl"
|
||||
failures_path = output_dir / "prompt_failures.jsonl"
|
||||
audit_path = output_dir / "prompt_size_audit.jsonl"
|
||||
cache_dir = args.cache_dir or None
|
||||
|
||||
registry = default_method_registry()
|
||||
if args.method_variant not in registry:
|
||||
raise SystemExit(
|
||||
f"unknown method variant: {args.method_variant}; "
|
||||
f"choose from {sorted(registry)}"
|
||||
)
|
||||
method = registry[args.method_variant]
|
||||
switches = PromptComponentSwitches(
|
||||
use_text_summarization=method.uses_dgp_text_summarization,
|
||||
use_path_summarization_llm=method.uses_dgp_path_summarization_llm,
|
||||
use_numerical_aggregation_dgp=method.uses_dgp_numerical_aggregation,
|
||||
use_apt_numerical_stats=method.uses_numerical_summary,
|
||||
include_evidence_ids=method.uses_evidence_ids,
|
||||
include_local_one_hop_context=method.uses_local_context,
|
||||
)
|
||||
|
||||
summarizer_pair = _maybe_build_summarizers(
|
||||
switches=switches,
|
||||
summarizer_config_path=args.summarizer_config,
|
||||
max_workers=args.summarizer_workers,
|
||||
)
|
||||
|
||||
# Pre-warm the THEIA window-IR cache for *all* targets in one two-pass scan,
|
||||
# so the per-target loop below hits cache instead of scanning the 80 GB
|
||||
# corpus once per target. For 16 targets this is 16x less disk IO.
|
||||
if cache_dir and not args.skip_multi_anchor_prewarm and len(targets) > 1:
|
||||
anchors = [
|
||||
{
|
||||
"anchor_event_uuid": t["anchor_event_id"],
|
||||
"lookback_seconds": args.lookback_seconds,
|
||||
"lookahead_seconds": args.lookahead_seconds,
|
||||
}
|
||||
for t in targets
|
||||
]
|
||||
from time import time as _now
|
||||
prewarm_started = _now()
|
||||
print(
|
||||
f"[multi-anchor prewarm] {len(anchors)} anchors, lookback={args.lookback_seconds}s, "
|
||||
f"lookahead={args.lookahead_seconds}s, cache={cache_dir}"
|
||||
)
|
||||
prewarm_results = build_multi_target_window_irs(
|
||||
paths,
|
||||
anchors=anchors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
prewarm_elapsed = _now() - prewarm_started
|
||||
print(
|
||||
f"[multi-anchor prewarm] populated {len(prewarm_results)}/{len(anchors)} anchors "
|
||||
f"in {prewarm_elapsed:.1f}s"
|
||||
)
|
||||
|
||||
metadata_rows: list[dict[str, object]] = []
|
||||
failure_rows: list[dict[str, object]] = []
|
||||
audit_rows: list[dict[str, object]] = []
|
||||
for index, target in enumerate(targets, start=1):
|
||||
target_id = target["target_id"]
|
||||
anchor_event_id = target["anchor_event_id"]
|
||||
try:
|
||||
window = build_cached_theia_window_ir(
|
||||
paths,
|
||||
target_event_uuid=anchor_event_id,
|
||||
lookback_seconds=args.lookback_seconds,
|
||||
lookahead_seconds=args.lookahead_seconds,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
graph = window.to_graph()
|
||||
graph_target_id = window.target_subject_id or window.target_event_id
|
||||
if graph_target_id != target_id:
|
||||
raise ValueError(
|
||||
f"anchor event subject mismatch: expected {target_id}, got {graph_target_id}"
|
||||
)
|
||||
if (
|
||||
args.hard_skip_window_events is not None
|
||||
and len(window.events) > args.hard_skip_window_events
|
||||
):
|
||||
raise ValueError(
|
||||
f"window too large for direct prompt construction: "
|
||||
f"{len(window.events)} events > {args.hard_skip_window_events}; "
|
||||
"consider narrower lookback/lookahead or remove --hard-skip-window-events."
|
||||
)
|
||||
window_oversize = len(window.events) > args.max_window_events
|
||||
if window_oversize:
|
||||
audit_rows.append(
|
||||
{
|
||||
"target_id": target_id,
|
||||
"anchor_event_id": anchor_event_id,
|
||||
"cohort": target.get("cohort"),
|
||||
"events": len(window.events),
|
||||
"audit_threshold": args.max_window_events,
|
||||
"note": (
|
||||
"Window exceeded soft threshold; prompt was still "
|
||||
"constructed because trimming controls prompt size."
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
ir_report = validate_ir(list(window.entities), list(window.events))
|
||||
graph_report = validate_graph(graph)
|
||||
paths_all = APTMetapathExtractor(graph).extract_for_target(graph_target_id)
|
||||
selected = _select_paths(
|
||||
graph=graph,
|
||||
graph_target_id=graph_target_id,
|
||||
paths_all=paths_all,
|
||||
method_variant=method,
|
||||
top_m_per_metapath=args.top_m_per_metapath,
|
||||
)
|
||||
evidence_report = validate_evidence_paths(graph, selected)
|
||||
node_summarizer, path_summarizer = summarizer_pair
|
||||
prompt = PromptBuilder(
|
||||
graph,
|
||||
node_summarizer=node_summarizer,
|
||||
path_summarizer=path_summarizer,
|
||||
numerical_aggregator=NumericalAggregator(graph),
|
||||
switches=switches,
|
||||
).build(graph_target_id, selected)
|
||||
|
||||
safe_name = f"{index:04d}_{_safe_id(target_id)}"
|
||||
prompt_path = prompts_dir / f"{safe_name}.txt"
|
||||
prompt_path.write_text(prompt.prompt_text, encoding="utf-8")
|
||||
(validations_dir / f"{safe_name}_ir.md").write_text(ir_report.to_markdown(), encoding="utf-8")
|
||||
(validations_dir / f"{safe_name}_graph.md").write_text(graph_report.to_markdown(), encoding="utf-8")
|
||||
(validations_dir / f"{safe_name}_evidence.md").write_text(evidence_report.to_markdown(), encoding="utf-8")
|
||||
|
||||
metadata_rows.append(
|
||||
{
|
||||
"target_id": target_id,
|
||||
"target_type": target["target_type"],
|
||||
"label": target["label"],
|
||||
"label_confidence": target["label_confidence"],
|
||||
"cohort": target["cohort"],
|
||||
"anchor_event_id": anchor_event_id,
|
||||
"prompt_path": str(prompt_path),
|
||||
"prompt_chars": len(prompt.prompt_text),
|
||||
"prompt_estimated_tokens": int(len(prompt.prompt_text) / 4),
|
||||
"entities": len(window.entities),
|
||||
"events": len(window.events),
|
||||
"extracted_evidence_paths": len(paths_all),
|
||||
"selected_evidence_paths": len(selected),
|
||||
"evidence_path_ids": list(prompt.evidence_path_ids),
|
||||
"ir_ok": ir_report.ok,
|
||||
"graph_ok": graph_report.ok,
|
||||
"evidence_ok": evidence_report.ok,
|
||||
"schema_gaps": list(window.schema_gaps),
|
||||
"label_fields_excluded_from_prompt": True,
|
||||
"method_variant": method.name,
|
||||
"window_exceeded_soft_threshold": window_oversize,
|
||||
}
|
||||
)
|
||||
tag = " (oversize-window)" if window_oversize else ""
|
||||
print(
|
||||
f"[{index}/{len(targets)}] built {target_id} "
|
||||
f"events={len(window.events)} selected={len(selected)}{tag}"
|
||||
)
|
||||
except Exception as exc:
|
||||
failure_rows.append(
|
||||
{
|
||||
"target_id": target_id,
|
||||
"anchor_event_id": anchor_event_id,
|
||||
"cohort": target.get("cohort"),
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
print(f"[{index}/{len(targets)}] failed {target_id}: {exc}")
|
||||
|
||||
_write_jsonl(metadata_path, metadata_rows)
|
||||
_write_jsonl(failures_path, failure_rows)
|
||||
_write_jsonl(audit_path, audit_rows)
|
||||
summary = _summary_markdown(metadata_rows, failure_rows, audit_rows, args)
|
||||
(output_dir / "prompt_batch.md").write_text(summary + "\n", encoding="utf-8")
|
||||
|
||||
print(f"built={len(metadata_rows)} failed={len(failure_rows)} oversize_audited={len(audit_rows)}")
|
||||
print(f"wrote {metadata_path}")
|
||||
print(f"wrote {failures_path}")
|
||||
print(f"wrote {audit_path}")
|
||||
|
||||
|
||||
def _read_jsonl(path: str | Path) -> list[dict[str, object]]:
|
||||
rows: list[dict[str, object]] = []
|
||||
with Path(path).open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if line.strip():
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
def _filter_targets(
|
||||
targets: list[dict[str, object]],
|
||||
include_cohorts: list[str] | None,
|
||||
max_per_cohort: int | None,
|
||||
) -> list[dict[str, object]]:
|
||||
if include_cohorts:
|
||||
allowed = set(include_cohorts)
|
||||
targets = [target for target in targets if target.get("cohort") in allowed]
|
||||
if max_per_cohort is None:
|
||||
return targets
|
||||
counts: dict[str, int] = {}
|
||||
selected: list[dict[str, object]] = []
|
||||
for target in targets:
|
||||
cohort = str(target.get("cohort"))
|
||||
if counts.get(cohort, 0) >= max_per_cohort:
|
||||
continue
|
||||
selected.append(target)
|
||||
counts[cohort] = counts.get(cohort, 0) + 1
|
||||
return selected
|
||||
|
||||
|
||||
def _write_jsonl(path: str | Path, rows: list[dict[str, object]]) -> None:
|
||||
destination = Path(path)
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
with destination.open("w", encoding="utf-8") as handle:
|
||||
for row in rows:
|
||||
handle.write(json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n")
|
||||
|
||||
|
||||
def _summary_markdown(
|
||||
metadata_rows: list[dict[str, object]],
|
||||
failure_rows: list[dict[str, object]],
|
||||
audit_rows: list[dict[str, object]],
|
||||
args: argparse.Namespace,
|
||||
) -> str:
|
||||
cohorts: dict[str, int] = {}
|
||||
for row in metadata_rows:
|
||||
cohort = str(row.get("cohort"))
|
||||
cohorts[cohort] = cohorts.get(cohort, 0) + 1
|
||||
lines = [
|
||||
"# ER-TP-DGP Prompt Batch",
|
||||
"",
|
||||
"Labels are metadata only and are excluded from prompt text.",
|
||||
"",
|
||||
f"- method_variant: {args.method_variant}",
|
||||
f"- built: {len(metadata_rows)}",
|
||||
f"- failed: {len(failure_rows)}",
|
||||
f"- oversize_audited: {len(audit_rows)}",
|
||||
f"- lookback_seconds: {args.lookback_seconds}",
|
||||
f"- lookahead_seconds: {args.lookahead_seconds}",
|
||||
f"- top_m_per_metapath: {args.top_m_per_metapath}",
|
||||
f"- max_window_events_soft: {args.max_window_events}",
|
||||
f"- hard_skip_window_events: {args.hard_skip_window_events}",
|
||||
f"- cache_dir: {args.cache_dir}",
|
||||
"",
|
||||
"## Cohorts",
|
||||
"",
|
||||
]
|
||||
lines.extend([f"- {key}: {value}" for key, value in sorted(cohorts.items())] or ["- none"])
|
||||
lines.extend(["", "## Prompt Size", ""])
|
||||
if metadata_rows:
|
||||
token_values = [int(row["prompt_estimated_tokens"]) for row in metadata_rows]
|
||||
lines.extend(
|
||||
[
|
||||
f"- min_estimated_tokens: {min(token_values)}",
|
||||
f"- max_estimated_tokens: {max(token_values)}",
|
||||
f"- avg_estimated_tokens: {sum(token_values) / len(token_values):.1f}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
lines.append("- none")
|
||||
if failure_rows:
|
||||
lines.extend(["", "## Failures", ""])
|
||||
for row in failure_rows:
|
||||
lines.append(f"- target={row['target_id']} error={row['error']}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _safe_id(value: str) -> str:
|
||||
return re.sub(r"[^A-Za-z0-9_.-]+", "_", value)[:120]
|
||||
|
||||
|
||||
def _select_paths(*, graph, graph_target_id, paths_all, method_variant, top_m_per_metapath):
|
||||
"""Pick the trimmer that matches the method variant's switches.
|
||||
|
||||
Four regimes (in order of preference):
|
||||
0. No graph at all: ``uses_event_reified_graph=False`` → return [].
|
||||
Used by target_only_llm and similar non-graph baselines so the prompt
|
||||
really contains zero metapath context.
|
||||
1. DGP MDK: ``uses_dgp_diffusion_trimming=True`` → MDK trimmer.
|
||||
2. APT rule trimmer: MDK off but TempTrim/SecAware still on.
|
||||
3. No trimming at all: return paths_all (w/o TempTrim ablation when
|
||||
MDK is also off, but graph still present).
|
||||
"""
|
||||
if not method_variant.uses_event_reified_graph:
|
||||
return []
|
||||
if method_variant.uses_dgp_diffusion_trimming:
|
||||
try:
|
||||
from er_tp_dgp.diffusion_trimmer import (
|
||||
HashingEmbedder,
|
||||
MarkovDiffusionTrimmer,
|
||||
MDKConfig,
|
||||
)
|
||||
except RuntimeError:
|
||||
print("WARNING: numpy unavailable; falling back to rule trimmer for MDK request.")
|
||||
else:
|
||||
embedder = HashingEmbedder(dim=64)
|
||||
return MarkovDiffusionTrimmer(
|
||||
graph,
|
||||
embedder=embedder,
|
||||
config=MDKConfig(k_hops=3, top_m=top_m_per_metapath),
|
||||
).trim(graph_target_id, paths_all)
|
||||
|
||||
if method_variant.uses_temporal_trimming or method_variant.uses_security_aware_trimming:
|
||||
return TemporalSecurityAwareTrimmer(
|
||||
graph,
|
||||
top_m_per_metapath=top_m_per_metapath,
|
||||
).trim(graph_target_id, paths_all)
|
||||
|
||||
# No trimming.
|
||||
return paths_all
|
||||
|
||||
|
||||
def _maybe_build_summarizers(*, switches, summarizer_config_path, max_workers):
|
||||
"""Build NodeTextSummarizer / MetapathTextSummarizer iff DGP TextSumm/PathSumm enabled.
|
||||
|
||||
Returns ``(None, None)`` when summarization is disabled.
|
||||
"""
|
||||
needs_node = switches.use_text_summarization
|
||||
needs_path = switches.use_path_summarization_llm
|
||||
if not (needs_node or needs_path):
|
||||
return None, None
|
||||
if not summarizer_config_path:
|
||||
print(
|
||||
"WARNING: method variant requests TextSumm/PathSumm but "
|
||||
"--summarizer-config was not provided; falling back to truncation-only summaries."
|
||||
)
|
||||
from er_tp_dgp.text_summarizer import (
|
||||
MetapathTextSummarizer,
|
||||
NodeTextSummarizer,
|
||||
SummarizerConfig,
|
||||
_NullLLM,
|
||||
)
|
||||
|
||||
cfg = SummarizerConfig(model_name="null-fallback", max_workers=max_workers)
|
||||
node = NodeTextSummarizer(llm=_NullLLM(), config=cfg) if needs_node else None
|
||||
path = MetapathTextSummarizer(llm=_NullLLM(), config=cfg) if needs_path else None
|
||||
return node, path
|
||||
|
||||
from er_tp_dgp.llm import OpenAICompatibleHTTPProvider
|
||||
from er_tp_dgp.llm_config import load_llm_config
|
||||
from er_tp_dgp.text_summarizer import (
|
||||
MetapathTextSummarizer,
|
||||
NodeTextSummarizer,
|
||||
SummarizerConfig,
|
||||
)
|
||||
|
||||
llm_config = load_llm_config(summarizer_config_path)
|
||||
provider = OpenAICompatibleHTTPProvider(llm_config)
|
||||
cfg = SummarizerConfig(model_name=llm_config.model, max_workers=max_workers)
|
||||
node = NodeTextSummarizer(llm=provider, config=cfg) if needs_node else None
|
||||
path = MetapathTextSummarizer(llm=provider, config=cfg) if needs_path else None
|
||||
return node, path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
284
scripts/evaluate_landmark_detection.py
Normal file
284
scripts/evaluate_landmark_detection.py
Normal file
@@ -0,0 +1,284 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Join ORTHRUS ground truth onto landmark communities and report coverage.
|
||||
|
||||
The CSG construction is GT-free. This script is the evaluation phase: it
|
||||
reads the constructed communities and asks two questions:
|
||||
|
||||
1. **Subject coverage** — for each GT-malicious subject, is it touched by
|
||||
at least one community? Lower bounds detection recall.
|
||||
2. **Community-level GT join** — for each community, is any of its landmark
|
||||
events a GT-malicious-subject event? Communities flagged this way are
|
||||
the positive class for downstream LLM evaluation.
|
||||
|
||||
The output of this script is the labeled-community manifest fed to LLM
|
||||
prompting + AUPRC computation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter, defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
SRC = Path(__file__).resolve().parent.parent / "src"
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--communities", required=True)
|
||||
parser.add_argument("--landmarks", required=True)
|
||||
parser.add_argument(
|
||||
"--oracle-targets",
|
||||
required=True,
|
||||
help="ORTHRUS labeled targets (target_id is the malicious subject UUID).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-labeled-communities",
|
||||
required=True,
|
||||
help="Per-community manifest with label/atom_id joined for evaluation only.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-markdown",
|
||||
required=True,
|
||||
help="Aggregate evaluation report.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
communities = _read_jsonl(args.communities)
|
||||
landmarks_by_id = {row["event_id"]: row for row in _read_jsonl(args.landmarks)}
|
||||
oracle_rows = _read_jsonl(args.oracle_targets)
|
||||
|
||||
# Build subject → atom_id lookup over ground-truth-malicious subjects.
|
||||
gt_subject_to_atom: dict[str, str | None] = {}
|
||||
for row in oracle_rows:
|
||||
if row.get("label") == "malicious":
|
||||
gt_subject_to_atom[row["target_id"]] = row.get("atom_id")
|
||||
print(f"[gt] malicious subjects in oracle: {len(gt_subject_to_atom)}", flush=True)
|
||||
|
||||
# Index communities by subject for "did we cover this GT subject?".
|
||||
communities_by_subject: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
for c in communities:
|
||||
for sid in c.get("subjects") or ():
|
||||
communities_by_subject[sid].append(c)
|
||||
|
||||
# Per-GT-subject coverage report.
|
||||
covered = 0
|
||||
coverage_rows: list[dict[str, Any]] = []
|
||||
for sid, atom in gt_subject_to_atom.items():
|
||||
cs = communities_by_subject.get(sid, [])
|
||||
if cs:
|
||||
covered += 1
|
||||
coverage_rows.append(
|
||||
{
|
||||
"subject_uuid": sid,
|
||||
"atom_id": atom,
|
||||
"covered": True,
|
||||
"communities": [c["community_id"] for c in cs],
|
||||
"max_community_landmarks": max(len(c["landmark_event_ids"]) for c in cs),
|
||||
}
|
||||
)
|
||||
else:
|
||||
coverage_rows.append(
|
||||
{
|
||||
"subject_uuid": sid,
|
||||
"atom_id": atom,
|
||||
"covered": False,
|
||||
"communities": [],
|
||||
"max_community_landmarks": 0,
|
||||
}
|
||||
)
|
||||
|
||||
# Per-community label join: a community is malicious if any of its
|
||||
# landmarks' actor_subject_id is in the GT-malicious set.
|
||||
labeled: list[dict[str, Any]] = []
|
||||
malicious_community_count = 0
|
||||
benign_community_count = 0
|
||||
for c in communities:
|
||||
gt_subjects_hit: list[str] = []
|
||||
gt_atoms_hit: set[str] = set()
|
||||
for eid in c["landmark_event_ids"]:
|
||||
lm = landmarks_by_id.get(eid)
|
||||
if not lm:
|
||||
continue
|
||||
sid = lm.get("actor_subject_id")
|
||||
if sid in gt_subject_to_atom:
|
||||
gt_subjects_hit.append(sid)
|
||||
atom = gt_subject_to_atom[sid]
|
||||
if atom:
|
||||
gt_atoms_hit.add(atom)
|
||||
is_malicious = bool(gt_subjects_hit)
|
||||
if is_malicious:
|
||||
malicious_community_count += 1
|
||||
else:
|
||||
benign_community_count += 1
|
||||
labeled.append(
|
||||
{
|
||||
"community_id": c["community_id"],
|
||||
"host_id": c.get("host_id"),
|
||||
"label": "malicious" if is_malicious else "benign",
|
||||
"label_source": (
|
||||
"orthrus_subject_membership" if is_malicious else "no_gt_subject_overlap"
|
||||
),
|
||||
"gt_subjects_hit": sorted(set(gt_subjects_hit)),
|
||||
"gt_atoms_hit": sorted(gt_atoms_hit),
|
||||
"num_landmarks": len(c["landmark_event_ids"]),
|
||||
"num_edges": len(c.get("edge_ids") or ()),
|
||||
"subjects_in_community": len(c.get("subjects") or ()),
|
||||
"span_seconds": c["span_seconds"],
|
||||
"start_timestamp_nanos": c["start_timestamp_nanos"],
|
||||
"landmark_class_counts": c.get("landmark_class_counts") or {},
|
||||
}
|
||||
)
|
||||
|
||||
Path(args.out_labeled_communities).parent.mkdir(parents=True, exist_ok=True)
|
||||
with Path(args.out_labeled_communities).open("w", encoding="utf-8") as out:
|
||||
for row in labeled:
|
||||
out.write(json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n")
|
||||
|
||||
md = _render_markdown(
|
||||
coverage_rows=coverage_rows,
|
||||
labeled=labeled,
|
||||
gt_subjects=gt_subject_to_atom,
|
||||
communities=communities,
|
||||
config={
|
||||
"communities_path": args.communities,
|
||||
"landmarks_path": args.landmarks,
|
||||
"oracle_targets_path": args.oracle_targets,
|
||||
"out_labeled_communities": args.out_labeled_communities,
|
||||
},
|
||||
coverage=covered,
|
||||
)
|
||||
Path(args.out_markdown).parent.mkdir(parents=True, exist_ok=True)
|
||||
Path(args.out_markdown).write_text(md, encoding="utf-8")
|
||||
|
||||
print(
|
||||
f"[eval] gt_subjects={len(gt_subject_to_atom)} covered={covered} "
|
||||
f"communities={len(communities)} malicious={malicious_community_count} "
|
||||
f"benign={benign_community_count}",
|
||||
flush=True,
|
||||
)
|
||||
print(f"[eval] wrote {args.out_labeled_communities}", flush=True)
|
||||
print(f"[eval] wrote {args.out_markdown}", flush=True)
|
||||
return 0
|
||||
|
||||
|
||||
def _read_jsonl(path: str) -> list[dict[str, Any]]:
|
||||
rows: list[dict[str, Any]] = []
|
||||
with Path(path).open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if not line.strip():
|
||||
continue
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
def _render_markdown(
|
||||
*,
|
||||
coverage_rows: list[dict[str, Any]],
|
||||
labeled: list[dict[str, Any]],
|
||||
gt_subjects: dict[str, str | None],
|
||||
communities: list[dict[str, Any]],
|
||||
config: dict[str, Any],
|
||||
coverage: int,
|
||||
) -> str:
|
||||
total_subjects = len(gt_subjects)
|
||||
malicious_communities = sum(1 for r in labeled if r["label"] == "malicious")
|
||||
benign_communities = sum(1 for r in labeled if r["label"] == "benign")
|
||||
atoms_with_at_least_one_community: set[str] = set()
|
||||
for row in labeled:
|
||||
if row["label"] == "malicious":
|
||||
atoms_with_at_least_one_community.update(row.get("gt_atoms_hit") or [])
|
||||
total_atoms = {atom for atom in gt_subjects.values() if atom}
|
||||
|
||||
sizes = [len(c["landmark_event_ids"]) for c in communities]
|
||||
sizes_malicious = [r["num_landmarks"] for r in labeled if r["label"] == "malicious"]
|
||||
sizes_benign = [r["num_landmarks"] for r in labeled if r["label"] == "benign"]
|
||||
|
||||
failures = [
|
||||
(row["atom_id"] or "(no atom)", row["subject_uuid"])
|
||||
for row in coverage_rows
|
||||
if not row["covered"]
|
||||
]
|
||||
failure_atoms = Counter(atom for atom, _ in failures)
|
||||
|
||||
lines = [
|
||||
"# Landmark CSG Detection Coverage",
|
||||
"",
|
||||
"Construction is GT-free. This report joins GT only for evaluation.",
|
||||
"",
|
||||
"## Inputs",
|
||||
"",
|
||||
f"- communities: `{config['communities_path']}`",
|
||||
f"- landmarks: `{config['landmarks_path']}`",
|
||||
f"- oracle: `{config['oracle_targets_path']}`",
|
||||
f"- output (labeled communities): `{config['out_labeled_communities']}`",
|
||||
"",
|
||||
"## Subject coverage",
|
||||
"",
|
||||
f"- GT-malicious subjects: {total_subjects}",
|
||||
f"- subjects touched by at least one community: {coverage}",
|
||||
(
|
||||
f"- **subject_coverage_recall**: {coverage / total_subjects:.3f}"
|
||||
if total_subjects
|
||||
else "- subject_coverage_recall: n/a"
|
||||
),
|
||||
"",
|
||||
"## Community-level join",
|
||||
"",
|
||||
f"- communities total: {len(communities)}",
|
||||
f"- malicious communities: {malicious_communities}",
|
||||
f"- benign communities: {benign_communities}",
|
||||
(
|
||||
f"- malicious_share: {malicious_communities / len(communities):.4f}"
|
||||
if communities
|
||||
else "- malicious_share: n/a"
|
||||
),
|
||||
f"- distinct GT atoms with ≥1 community: {len(atoms_with_at_least_one_community)} / {len(total_atoms)}",
|
||||
"",
|
||||
"## Community size",
|
||||
"",
|
||||
f"- all (n={len(sizes)}): min={min(sizes, default=0)} median={_pct(sizes, 0.5)} p90={_pct(sizes, 0.9)} max={max(sizes, default=0)}",
|
||||
f"- malicious (n={len(sizes_malicious)}): median={_pct(sizes_malicious, 0.5)} p90={_pct(sizes_malicious, 0.9)} max={max(sizes_malicious, default=0)}",
|
||||
f"- benign (n={len(sizes_benign)}): median={_pct(sizes_benign, 0.5)} p90={_pct(sizes_benign, 0.9)} max={max(sizes_benign, default=0)}",
|
||||
"",
|
||||
"## Failure breakdown (uncovered GT subjects)",
|
||||
"",
|
||||
]
|
||||
if not failures:
|
||||
lines.append("- (none)")
|
||||
else:
|
||||
for atom, n in failure_atoms.most_common(20):
|
||||
lines.append(f"- {atom}: {n}")
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"## Interpretation",
|
||||
"",
|
||||
"- `subject_coverage_recall` is the upper bound on detection recall:",
|
||||
" any subject NOT touched by a community cannot be flagged.",
|
||||
"- `malicious_share` is the inverse of the LLM's class imbalance — too low",
|
||||
" means LLM faces an extreme imbalance; too high means the construction is",
|
||||
" over-clustering benign and malicious into shared communities.",
|
||||
"- Median malicious community size vs benign indicates whether attack",
|
||||
" stories naturally form longer chains than benign noise.",
|
||||
"",
|
||||
]
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _pct(values: list[int], q: float) -> int | None:
|
||||
if not values:
|
||||
return None
|
||||
ordered = sorted(values)
|
||||
k = max(0, min(len(ordered) - 1, int(round(q * (len(ordered) - 1)))))
|
||||
return ordered[k]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
49
scripts/extract_e3_ground_truth_atoms.py
Normal file
49
scripts/extract_e3_ground_truth_atoms.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Extract label-only structured atoms from the E3 ground-truth PDF."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.ground_truth import write_ground_truth_atom_report
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extract label-only E3 ground-truth atoms. Output must not be used in prompts."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pdf",
|
||||
default="data/ground_truth/e3/TC_Ground_Truth_Report_E3_Update.pdf",
|
||||
)
|
||||
parser.add_argument("--output-dir", default="reports/ground_truth/e3")
|
||||
parser.add_argument("--target-filter", default="THEIA")
|
||||
args = parser.parse_args()
|
||||
|
||||
pdf_path = Path(args.pdf)
|
||||
if not pdf_path.exists():
|
||||
raise SystemExit(f"missing PDF: {pdf_path}")
|
||||
|
||||
result = subprocess.run(
|
||||
["pdftotext", "-layout", str(pdf_path), "-"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
report = write_ground_truth_atom_report(
|
||||
result.stdout,
|
||||
jsonl_path=output_dir / "ground_truth_atoms.jsonl",
|
||||
markdown_path=output_dir / "ground_truth_atoms.md",
|
||||
target_filter=None if args.target_filter.lower() == "all" else args.target_filter,
|
||||
)
|
||||
print(f"atoms={len(report.atoms)} lines_seen={report.lines_seen}")
|
||||
print(f"wrote {output_dir / 'ground_truth_atoms.jsonl'}")
|
||||
print(f"wrote {output_dir / 'ground_truth_atoms.md'}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
43
scripts/freeze_method_version.py
Normal file
43
scripts/freeze_method_version.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Freeze an auditable ER-TP-DGP method-version manifest."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.versioning import write_method_version_manifest
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Write a sanitized, hash-based ER-TP-DGP method manifest."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default="reports/method_versions/ER-TP-DGP-v0.1.json",
|
||||
help="Destination JSON path.",
|
||||
)
|
||||
parser.add_argument("--version", default="ER-TP-DGP-v0.1")
|
||||
parser.add_argument("--repo-root", default=".")
|
||||
parser.add_argument(
|
||||
"--llm-config",
|
||||
default="configs/llm.yaml",
|
||||
help="LLM YAML to sanitize and include. Use 'none' to skip.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
llm_config_path = None if args.llm_config.lower() == "none" else args.llm_config
|
||||
manifest = write_method_version_manifest(
|
||||
args.output,
|
||||
repo_root=args.repo_root,
|
||||
version=args.version,
|
||||
llm_config_path=llm_config_path,
|
||||
)
|
||||
print(f"wrote {Path(args.output)}")
|
||||
print(f"method={manifest.method_name} version={manifest.version}")
|
||||
print(f"components={len(manifest.components)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
331
scripts/import_orthrus_ground_truth.py
Normal file
331
scripts/import_orthrus_ground_truth.py
Normal file
@@ -0,0 +1,331 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Import ORTHRUS (USENIX Sec 2025) ground truth into ER-TP-DGP labeled_targets.jsonl format.
|
||||
|
||||
ORTHRUS publishes manually curated, attack-graph-aligned ground truth for
|
||||
DARPA TC E3 + E5 (12 attack scenarios across 6 sub-datasets) on Zenodo
|
||||
(record 14641608, https://github.com/ubc-provenance/ground-truth). It is the
|
||||
most conservative and reproducible labeling currently available — see
|
||||
ORTHRUS Appendix C "Ground Truth Construction" for methodology.
|
||||
|
||||
This script:
|
||||
1. Reads ORTHRUS CSV files (UUID, attrs, index_id) per attack scenario
|
||||
2. Filters to subject (process) entities — those are the ones our pipeline
|
||||
can score as targets. Files / netflows are kept as evidence of attack
|
||||
scope but excluded from the target list.
|
||||
3. Produces labeled_targets.jsonl rows with:
|
||||
label = "malicious"
|
||||
atom_id = ORTHRUS scenario name (e.g. e3-theia-Browser_Extension_Drakon_Dropper)
|
||||
process_path = parsed from ORTHRUS attributes['subject']
|
||||
cohort = "positive_high_confidence_orthrus"
|
||||
4. Optionally augments with hard_negative_proxy from candidate_universe.
|
||||
|
||||
NOTE: each malicious target needs an ``anchor_event_id``. ORTHRUS labels
|
||||
entities, not events — so for each subject UUID we pick the FIRST event in
|
||||
the THEIA log where that subject appears as actor (i.e. its earliest action).
|
||||
This requires scanning the corpus once to map subject_uuid → first
|
||||
event_uuid, which is built lazily and cached.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import csv
|
||||
import json
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.theia import discover_theia_json_files, iter_theia_records
|
||||
from er_tp_dgp.theia import _unwrap_uuid
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__.split("\n", 1)[0])
|
||||
parser.add_argument(
|
||||
"--orthrus-dir",
|
||||
default="data/ground_truth/orthrus/ubc-provenance-ground-truth-ff65bc7/darpa",
|
||||
help="Root of the unpacked ORTHRUS ground-truth zip.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sub-dataset",
|
||||
default="E3-THEIA",
|
||||
choices=["E3-CADETS", "E3-CLEARSCOPE", "E3-THEIA", "E5-CADETS", "E5-CLEARSCOPE", "E5-THEIA"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
default="data/raw/e3_theia_json",
|
||||
help="Raw THEIA JSON corpus to scan for first-event-per-subject anchors.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-jsonl",
|
||||
required=True,
|
||||
help="Output labeled_targets.jsonl path (will be written; parent dirs created).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--anchor-cache",
|
||||
default="reports/cache/orthrus_subject_first_event_e3_theia.jsonl",
|
||||
help="Cache mapping subject UUID -> first event UUID (built once).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-non-subject",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Include ORTHRUS-labeled file/netflow entities as separate targets. "
|
||||
"Default: only subject (process) entities."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--candidate-universe",
|
||||
default="reports/theia_candidate_universe/candidate_universe.jsonl",
|
||||
help="Used to draw diverse benign cohort.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-benign",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of hard_negative_proxy candidates to draw from candidate_universe.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--benign-process-paths",
|
||||
nargs="*",
|
||||
default=None,
|
||||
help=(
|
||||
"Optional list of process paths to include in benign cohort. If unset, "
|
||||
"stratify by unique process_path to maximize cohort diversity."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
orthrus_root = Path(args.orthrus_dir) / args.sub_dataset
|
||||
if not orthrus_root.exists():
|
||||
raise SystemExit(f"missing ORTHRUS dir: {orthrus_root}")
|
||||
csv_files = sorted(orthrus_root.glob("node_*.csv"))
|
||||
if not csv_files:
|
||||
raise SystemExit(f"no node_*.csv under {orthrus_root}")
|
||||
|
||||
# Step 1: load all ORTHRUS-labeled UUIDs + their attributes per scenario.
|
||||
malicious_records: list[dict] = []
|
||||
for f in csv_files:
|
||||
scenario = f.stem.replace("node_", "")
|
||||
with f.open() as handle:
|
||||
for row in csv.reader(handle):
|
||||
if not row or len(row) < 3:
|
||||
continue
|
||||
uuid, attrs_str, _idx = row[0], row[1], row[2]
|
||||
try:
|
||||
attrs = ast.literal_eval(attrs_str)
|
||||
except (SyntaxError, ValueError):
|
||||
attrs = {}
|
||||
if not isinstance(attrs, dict) or not attrs:
|
||||
continue
|
||||
attr_type = next(iter(attrs.keys()))
|
||||
attr_value = attrs[attr_type]
|
||||
if not args.include_non_subject and attr_type != "subject":
|
||||
continue
|
||||
process_path, command_line = _parse_subject_attr(attr_value) if attr_type == "subject" else (None, None)
|
||||
malicious_records.append({
|
||||
"target_id": uuid,
|
||||
"target_type": "PROCESS" if attr_type == "subject" else attr_type.upper(),
|
||||
"atom_id": f"{args.sub_dataset.lower()}-{scenario}",
|
||||
"label": "malicious",
|
||||
"label_confidence": "high",
|
||||
"label_source": "orthrus_manual_curated",
|
||||
"cohort": "positive_high_confidence_orthrus",
|
||||
"process_path": process_path,
|
||||
"command_line": command_line,
|
||||
"attrs_raw": attrs,
|
||||
})
|
||||
print(f"ORTHRUS {args.sub_dataset}: scenarios={len(csv_files)} records={len(malicious_records)} (subjects only={not args.include_non_subject})")
|
||||
|
||||
# Step 2: anchor mapping. We need a target_event_uuid for each subject so
|
||||
# build_theia_window_ir can pick a time window. Build subject_uuid →
|
||||
# first_event_uuid via one corpus scan, with on-disk cache.
|
||||
anchor_cache_path = Path(args.anchor_cache)
|
||||
subject_to_anchor: dict[str, dict] = {}
|
||||
if anchor_cache_path.exists():
|
||||
with anchor_cache_path.open() as handle:
|
||||
for line in handle:
|
||||
if line.strip():
|
||||
row = json.loads(line)
|
||||
subject_to_anchor[row["subject_uuid"]] = row
|
||||
print(f"loaded {len(subject_to_anchor)} subject→event anchors from {anchor_cache_path}")
|
||||
else:
|
||||
wanted = {r["target_id"] for r in malicious_records if r["target_type"] == "PROCESS"}
|
||||
print(f"scanning corpus for first event per subject (n={len(wanted)} subjects)... this may take ~5 min for 80 GB E3-THEIA")
|
||||
paths = discover_theia_json_files(args.data_dir)
|
||||
for record in iter_theia_records(paths):
|
||||
if record.record_type != "Event":
|
||||
continue
|
||||
payload = record.payload
|
||||
sid = _unwrap_uuid(payload.get("subject"))
|
||||
if sid not in wanted or sid in subject_to_anchor:
|
||||
continue
|
||||
ts = payload.get("timestampNanos")
|
||||
if not isinstance(ts, int):
|
||||
continue
|
||||
evid = payload.get("uuid")
|
||||
if not evid:
|
||||
continue
|
||||
subject_to_anchor[sid] = {
|
||||
"subject_uuid": sid,
|
||||
"anchor_event_id": evid,
|
||||
"anchor_event_type": payload.get("type"),
|
||||
"anchor_timestamp_nanos": ts,
|
||||
}
|
||||
if len(subject_to_anchor) >= len(wanted):
|
||||
break
|
||||
anchor_cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with anchor_cache_path.open("w") as out:
|
||||
for row in subject_to_anchor.values():
|
||||
out.write(json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n")
|
||||
print(f"cached {len(subject_to_anchor)} subject→event anchors → {anchor_cache_path}")
|
||||
|
||||
# Step 3: emit labeled_targets.jsonl rows.
|
||||
out_rows: list[dict] = []
|
||||
skipped = 0
|
||||
for r in malicious_records:
|
||||
anchor = subject_to_anchor.get(r["target_id"])
|
||||
if r["target_type"] == "PROCESS" and not anchor:
|
||||
skipped += 1
|
||||
continue
|
||||
out_rows.append({
|
||||
"target_id": r["target_id"],
|
||||
"target_type": r["target_type"],
|
||||
"label": r["label"],
|
||||
"label_confidence": r["label_confidence"],
|
||||
"cohort": r["cohort"],
|
||||
"anchor_event_id": (anchor or {}).get("anchor_event_id"),
|
||||
"anchor_timestamp_nanos": (anchor or {}).get("anchor_timestamp_nanos"),
|
||||
"atom_id": r["atom_id"],
|
||||
"label_source": r["label_source"],
|
||||
"matched_event_count": 0,
|
||||
"weak_signal_score": None,
|
||||
"candidate_total_events": None,
|
||||
"candidate_estimated_prompt_tokens": None,
|
||||
"process_path": r["process_path"],
|
||||
"command_line": r["command_line"],
|
||||
"prompt_allowed_label_fields": False,
|
||||
"notes": [
|
||||
"Ground truth from ORTHRUS USENIX Sec 2025 (Zenodo 14641608).",
|
||||
"Manually curated, conservative attack-graph-aligned labels.",
|
||||
f"Attack scenario: {r['atom_id']}.",
|
||||
],
|
||||
})
|
||||
print(f"emitted {len(out_rows)} malicious targets ({skipped} skipped due to missing anchor)")
|
||||
|
||||
# Step 4: optional benign cohort.
|
||||
if args.num_benign > 0:
|
||||
cu_path = Path(args.candidate_universe)
|
||||
if not cu_path.exists():
|
||||
print(f"WARNING: candidate_universe missing at {cu_path}; skipping benign cohort.")
|
||||
else:
|
||||
benign_rows = _select_diverse_benign(
|
||||
cu_path,
|
||||
num=args.num_benign,
|
||||
exclude_uuids={r["target_id"] for r in out_rows},
|
||||
allowed_paths=set(args.benign_process_paths) if args.benign_process_paths else None,
|
||||
)
|
||||
out_rows.extend(benign_rows)
|
||||
print(f"appended {len(benign_rows)} hard_negative_proxy targets")
|
||||
|
||||
out_path = Path(args.out_jsonl)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with out_path.open("w", encoding="utf-8") as handle:
|
||||
for row in out_rows:
|
||||
handle.write(json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n")
|
||||
print(f"wrote {len(out_rows)} targets → {out_path}")
|
||||
return 0
|
||||
|
||||
|
||||
def _parse_subject_attr(value: str) -> tuple[str | None, str | None]:
|
||||
"""ORTHRUS subject attrs look like: '/usr/bin/firefox firefox-bin -P default -e'.
|
||||
|
||||
Heuristic: the FIRST whitespace-separated token that starts with `/` or
|
||||
contains a slash is the path; everything else is the command line.
|
||||
"""
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
return None, None
|
||||
tokens = value.strip().split()
|
||||
path = None
|
||||
for tok in tokens:
|
||||
if tok.startswith("/") or "/" in tok:
|
||||
path = tok
|
||||
break
|
||||
return path, value.strip()
|
||||
|
||||
|
||||
def _select_diverse_benign(
|
||||
candidate_universe_path: Path,
|
||||
*,
|
||||
num: int,
|
||||
exclude_uuids: set[str],
|
||||
allowed_paths: set[str] | None,
|
||||
) -> list[dict]:
|
||||
rows: list[dict] = []
|
||||
by_path: dict[str, list[dict]] = defaultdict(list)
|
||||
with candidate_universe_path.open() as handle:
|
||||
for line in handle:
|
||||
if not line.strip():
|
||||
continue
|
||||
r = json.loads(line)
|
||||
cid = r.get("candidate_id")
|
||||
if not cid or cid in exclude_uuids:
|
||||
continue
|
||||
sample_events = r.get("sample_raw_event_ids") or []
|
||||
if not sample_events:
|
||||
continue
|
||||
path = r.get("process_path") or "unknown"
|
||||
if allowed_paths is not None and path not in allowed_paths:
|
||||
continue
|
||||
by_path[path].append(r)
|
||||
|
||||
# Stratify: round-robin over distinct process_paths to maximize diversity.
|
||||
paths_sorted = sorted(by_path.keys(), key=lambda p: (-len(by_path[p]), p))
|
||||
picked: list[dict] = []
|
||||
while len(picked) < num and paths_sorted:
|
||||
for p in list(paths_sorted):
|
||||
if not by_path[p]:
|
||||
paths_sorted.remove(p)
|
||||
continue
|
||||
picked.append(by_path[p].pop(0))
|
||||
if len(picked) >= num:
|
||||
break
|
||||
|
||||
for r in picked:
|
||||
rows.append({
|
||||
"target_id": r["candidate_id"],
|
||||
"target_type": "PROCESS",
|
||||
"label": "benign_proxy",
|
||||
"label_confidence": "unverified",
|
||||
"cohort": "hard_negative_proxy",
|
||||
"anchor_event_id": str((r.get("sample_raw_event_ids") or [None])[0]),
|
||||
"atom_id": None,
|
||||
"label_source": "candidate_not_in_orthrus_ground_truth",
|
||||
"matched_event_count": 0,
|
||||
"weak_signal_score": _safe_float(r.get("weak_signal_score")),
|
||||
"candidate_total_events": _safe_int(r.get("total_events")),
|
||||
"candidate_estimated_prompt_tokens": _safe_int(r.get("estimated_prompt_tokens")),
|
||||
"process_path": r.get("process_path"),
|
||||
"command_line": r.get("command_line"),
|
||||
"prompt_allowed_label_fields": False,
|
||||
"notes": [
|
||||
"Hard negative proxy: process not in ORTHRUS ground truth and not matching any attack atom.",
|
||||
"Diversity-stratified across process paths from candidate_universe.",
|
||||
],
|
||||
})
|
||||
return rows
|
||||
|
||||
|
||||
def _safe_float(value):
|
||||
try: return float(value)
|
||||
except (TypeError, ValueError): return None
|
||||
|
||||
|
||||
def _safe_int(value):
|
||||
try: return int(value)
|
||||
except (TypeError, ValueError): return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
144
scripts/map_theia_ground_truth.py
Normal file
144
scripts/map_theia_ground_truth.py
Normal file
@@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Map E3 THEIA ground-truth atoms to THEIA events/processes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.ground_truth_mapping import (
|
||||
evaluate_candidate_recall,
|
||||
match_theia_ground_truth_atoms,
|
||||
read_ground_truth_atoms_jsonl,
|
||||
)
|
||||
from er_tp_dgp.theia import discover_theia_json_files
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Map label-only E3 ground-truth atoms to THEIA events/processes. "
|
||||
"Outputs are forbidden from prompt construction."
|
||||
)
|
||||
)
|
||||
parser.add_argument("--data-dir", default="data/raw/e3_theia_json")
|
||||
parser.add_argument(
|
||||
"--input-file",
|
||||
action="append",
|
||||
default=None,
|
||||
help="Specific THEIA JSON file to scan. Can be repeated. Overrides --data-dir discovery.",
|
||||
)
|
||||
parser.add_argument("--atoms", default="reports/ground_truth/e3/ground_truth_atoms.jsonl")
|
||||
parser.add_argument("--candidate-jsonl", default="reports/theia_candidate_universe/candidate_universe.jsonl")
|
||||
parser.add_argument("--output-dir", default="reports/ground_truth/e3_mapping")
|
||||
parser.add_argument("--max-lines", type=int, default=None)
|
||||
parser.add_argument("--max-lines-per-file", type=int, default=None)
|
||||
parser.add_argument("--min-score", type=float, default=3.0)
|
||||
parser.add_argument("--include-term-only", action="store_true")
|
||||
parser.add_argument("--require-time-window", action="store_true")
|
||||
parser.add_argument("--time-window-hours", type=float, default=6.0)
|
||||
parser.add_argument(
|
||||
"--recall-min-confidence",
|
||||
choices=("low", "medium", "high"),
|
||||
default="high",
|
||||
help="Minimum mapped label confidence used for candidate recall.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timezone-offsets-hours",
|
||||
default="0",
|
||||
help="Comma-separated local offsets to try when interpreting ground-truth times.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-target-network-ips",
|
||||
action="store_true",
|
||||
help="Allow 128.55.12.* target network addresses to act as hard match indicators.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
paths = [Path(path) for path in args.input_file] if args.input_file else discover_theia_json_files(args.data_dir)
|
||||
if not paths:
|
||||
raise SystemExit(f"no THEIA JSON files found under {args.data_dir}")
|
||||
atoms = read_ground_truth_atoms_jsonl(args.atoms)
|
||||
offsets = tuple(int(value) for value in args.timezone_offsets_hours.split(",") if value.strip())
|
||||
|
||||
report = match_theia_ground_truth_atoms(
|
||||
paths,
|
||||
atoms,
|
||||
max_lines=args.max_lines,
|
||||
max_lines_per_file=args.max_lines_per_file,
|
||||
min_score=args.min_score,
|
||||
include_term_only=args.include_term_only,
|
||||
require_time_window=args.require_time_window,
|
||||
time_window_hours=args.time_window_hours,
|
||||
timezone_offsets_hours=offsets or (0,),
|
||||
ignore_target_network_prefixes=() if args.include_target_network_ips else ("128.55.12.",),
|
||||
)
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
event_path = output_dir / "event_matches.jsonl"
|
||||
process_path = output_dir / "process_labels.jsonl"
|
||||
report_path = output_dir / "mapping_report.md"
|
||||
report.write_event_jsonl(event_path)
|
||||
report.write_process_jsonl(process_path)
|
||||
report_path.write_text(report.to_markdown() + "\n", encoding="utf-8")
|
||||
|
||||
filtered_event_path = output_dir / f"event_matches_{args.recall_min_confidence}_plus.jsonl"
|
||||
filtered_process_path = output_dir / f"process_labels_{args.recall_min_confidence}_plus.jsonl"
|
||||
_write_filtered_event_matches(filtered_event_path, report.event_matches, args.recall_min_confidence)
|
||||
_write_filtered_process_labels(filtered_process_path, report.process_labels, args.recall_min_confidence)
|
||||
|
||||
recall = evaluate_candidate_recall(
|
||||
args.candidate_jsonl,
|
||||
report.process_labels,
|
||||
report.event_matches,
|
||||
min_confidence=args.recall_min_confidence,
|
||||
)
|
||||
recall_json = output_dir / "candidate_recall.json"
|
||||
recall_md = output_dir / "candidate_recall.md"
|
||||
recall_json.write_text(
|
||||
json.dumps(recall.to_json_dict(), indent=2, sort_keys=True, ensure_ascii=False) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
recall_md.write_text(recall.to_markdown() + "\n", encoding="utf-8")
|
||||
|
||||
print(
|
||||
f"atoms={report.atoms_seen} lines_seen={report.lines_seen} "
|
||||
f"events_seen={report.events_seen} event_matches={len(report.event_matches)} "
|
||||
f"process_labels={len(report.process_labels)}"
|
||||
)
|
||||
print(f"candidate_process_recall={recall.process_recall}")
|
||||
print(f"event_subject_recall={recall.event_subject_recall}")
|
||||
print(f"wrote {event_path}")
|
||||
print(f"wrote {process_path}")
|
||||
print(f"wrote {filtered_event_path}")
|
||||
print(f"wrote {filtered_process_path}")
|
||||
print(f"wrote {report_path}")
|
||||
print(f"wrote {recall_json}")
|
||||
|
||||
|
||||
def _write_filtered_event_matches(path, matches, min_confidence):
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("w", encoding="utf-8") as handle:
|
||||
for match in matches:
|
||||
if _confidence_rank(match.confidence) >= _confidence_rank(min_confidence):
|
||||
handle.write(json.dumps(match.to_json_dict(), ensure_ascii=False, sort_keys=True) + "\n")
|
||||
|
||||
|
||||
def _write_filtered_process_labels(path, labels, min_confidence):
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("w", encoding="utf-8") as handle:
|
||||
for label in labels:
|
||||
if _confidence_rank(label.confidence) >= _confidence_rank(min_confidence):
|
||||
handle.write(json.dumps(label.to_json_dict(), ensure_ascii=False, sort_keys=True) + "\n")
|
||||
|
||||
|
||||
def _confidence_rank(value):
|
||||
return {"low": 0, "medium": 1, "high": 2}.get(value, -1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
127
scripts/retry_skipped_llm.py
Normal file
127
scripts/retry_skipped_llm.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Retry LLM inference for prompts that were skipped due to transient errors.
|
||||
|
||||
Reads predictions_jsonl, identifies rows with ``skipped: true``, looks up
|
||||
the corresponding prompt files, and re-runs LLM inference with retries.
|
||||
Successful retries replace the skipped row; persistent failures keep
|
||||
the original skip row.
|
||||
|
||||
Adds in-process exponential-backoff retry on the API ``no choices``
|
||||
response — the proxy hche3637.com returns transient empty bodies that
|
||||
look like HTTP-200 but lack ``choices``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from er_tp_dgp.llm import OpenAICompatibleHTTPProvider
|
||||
from er_tp_dgp.llm_config import load_llm_config
|
||||
|
||||
|
||||
def _read_predictions(path: Path) -> list[dict[str, Any]]:
|
||||
rows: list[dict[str, Any]] = []
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
def _retry_inference(
|
||||
provider: OpenAICompatibleHTTPProvider,
|
||||
target_id: str,
|
||||
prompt_text: str,
|
||||
*,
|
||||
max_attempts: int,
|
||||
backoff_seconds: float,
|
||||
) -> tuple[dict[str, Any] | None, str | None]:
|
||||
"""Try up to ``max_attempts`` times with exponential backoff. Returns
|
||||
(payload, error_str). On success, payload is the to_json_dict() result."""
|
||||
last_error: str | None = None
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
result = provider.classify(target_id=target_id, prompt_text=prompt_text)
|
||||
return result.to_json_dict(), None
|
||||
except Exception as exc: # noqa: BLE001
|
||||
last_error = f"{type(exc).__name__}: {str(exc)[:200]}"
|
||||
if attempt < max_attempts:
|
||||
time.sleep(backoff_seconds * (2 ** (attempt - 1)))
|
||||
return None, last_error
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--predictions-jsonl", required=True)
|
||||
parser.add_argument("--prompt-dir", required=True)
|
||||
parser.add_argument("--config", default="configs/llm.yaml")
|
||||
parser.add_argument("--max-attempts", type=int, default=4)
|
||||
parser.add_argument("--backoff-seconds", type=float, default=2.0)
|
||||
parser.add_argument("--output-jsonl", default=None,
|
||||
help="Defaults to predictions-jsonl (in-place).")
|
||||
args = parser.parse_args()
|
||||
|
||||
config = load_llm_config(args.config)
|
||||
# Honor request_logprobs from upstream config — does NOT enable here
|
||||
# by default since the proxy seems to ignore it anyway.
|
||||
provider = OpenAICompatibleHTTPProvider(config)
|
||||
|
||||
predictions_path = Path(args.predictions_jsonl)
|
||||
prompt_dir = Path(args.prompt_dir)
|
||||
rows = _read_predictions(predictions_path)
|
||||
skipped = [r for r in rows if r.get("skipped")]
|
||||
print(f"[retry] total rows: {len(rows)}, skipped: {len(skipped)}", flush=True)
|
||||
|
||||
successes = 0
|
||||
persistent_failures = 0
|
||||
for row in rows:
|
||||
if not row.get("skipped"):
|
||||
continue
|
||||
target_id = row.get("target_id")
|
||||
prompt_file = prompt_dir / f"{target_id}.txt"
|
||||
if not prompt_file.exists():
|
||||
print(f"[retry] {target_id}: prompt file missing, keeping skip", flush=True)
|
||||
persistent_failures += 1
|
||||
continue
|
||||
prompt_text = prompt_file.read_text(encoding="utf-8")
|
||||
payload, error = _retry_inference(
|
||||
provider,
|
||||
target_id=target_id,
|
||||
prompt_text=prompt_text,
|
||||
max_attempts=args.max_attempts,
|
||||
backoff_seconds=args.backoff_seconds,
|
||||
)
|
||||
if payload is None:
|
||||
print(f"[retry] {target_id}: persistent failure: {error}", flush=True)
|
||||
row["skip_reason"] = f"after {args.max_attempts} retries: {error}"
|
||||
persistent_failures += 1
|
||||
continue
|
||||
# Replace the skipped row with the successful payload.
|
||||
payload["prompt_file"] = str(prompt_file)
|
||||
row.clear()
|
||||
row.update(payload)
|
||||
successes += 1
|
||||
print(f"[retry] {target_id}: SUCCESS {payload.get('output', {}).get('first_token_label')}",
|
||||
flush=True)
|
||||
|
||||
output_path = Path(args.output_jsonl) if args.output_jsonl else predictions_path
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with output_path.open("w", encoding="utf-8") as handle:
|
||||
for row in rows:
|
||||
handle.write(json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n")
|
||||
print(
|
||||
f"[retry] DONE successes={successes} persistent_failures={persistent_failures} "
|
||||
f"wrote={output_path}",
|
||||
flush=True,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
226
scripts/run_evaluation.py
Normal file
226
scripts/run_evaluation.py
Normal file
@@ -0,0 +1,226 @@
|
||||
#!/usr/bin/env python3
|
||||
"""End-to-end evaluation: join LLM predictions with labels and aggregate metrics.
|
||||
|
||||
Inputs:
|
||||
--predictions-jsonl One file per method variant, produced by
|
||||
run_llm_inference.py. The file's basename is used as
|
||||
the method name in the metrics table.
|
||||
--labeled-targets evaluation_batch jsonl (target_id, label, ...)
|
||||
|
||||
Output:
|
||||
--output-dir/metrics.md Paper-Table-2-style table:
|
||||
method | AUPRC | AUROC | Macro-F1 |
|
||||
Recall@10 | FPR@0.9 | avg_tokens |
|
||||
avg_latency | evidence_path_hit_rate
|
||||
--output-dir/metrics.json Machine-readable equivalent.
|
||||
|
||||
Each row uses the calibrated first-token softmax score from
|
||||
``LLMInferenceResult.first_token_score`` (DGP paper formula 14). If a row's
|
||||
score is missing, it is excluded from the metrics with a warning.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.metrics import PredictionRecord, evaluate_classification
|
||||
|
||||
|
||||
_log = logging.getLogger("run_evaluation")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__.split("\n", 1)[0])
|
||||
parser.add_argument(
|
||||
"--predictions-jsonl",
|
||||
action="append",
|
||||
required=True,
|
||||
help="Repeat once per method variant. Filename stem is used as method name.",
|
||||
)
|
||||
parser.add_argument("--labeled-targets", required=True)
|
||||
parser.add_argument("--output-dir", required=True)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[1, 5, 10],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recall-levels",
|
||||
type=float,
|
||||
nargs="+",
|
||||
default=[0.8, 0.9],
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
labels = _index_labels(Path(args.labeled_targets))
|
||||
|
||||
method_metrics: dict[str, dict] = {}
|
||||
for path in args.predictions_jsonl:
|
||||
prediction_path = Path(path)
|
||||
method_name = prediction_path.stem
|
||||
records = _build_prediction_records(prediction_path, labels)
|
||||
if not records:
|
||||
_log.warning("No usable predictions in %s; skipping.", prediction_path)
|
||||
continue
|
||||
metrics = evaluate_classification(
|
||||
records, k_values=args.k_values, recall_levels=args.recall_levels
|
||||
)
|
||||
method_metrics[method_name] = {
|
||||
"metrics": metrics.to_dict(),
|
||||
"num_records_used": len(records),
|
||||
"predictions_path": str(prediction_path),
|
||||
}
|
||||
|
||||
(output_dir / "metrics.json").write_text(
|
||||
json.dumps(method_metrics, ensure_ascii=False, sort_keys=True, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
(output_dir / "metrics.md").write_text(_render_markdown_table(method_metrics), encoding="utf-8")
|
||||
print(f"wrote {output_dir/'metrics.md'}")
|
||||
print(f"wrote {output_dir/'metrics.json'}")
|
||||
return 0
|
||||
|
||||
|
||||
def _index_labels(path: Path) -> dict[str, dict]:
|
||||
labels: dict[str, dict] = {}
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
target_id = row.get("target_id")
|
||||
if target_id:
|
||||
labels[target_id] = row
|
||||
return labels
|
||||
|
||||
|
||||
def _build_prediction_records(
|
||||
predictions_path: Path, labels: dict[str, dict]
|
||||
) -> list[PredictionRecord]:
|
||||
records: list[PredictionRecord] = []
|
||||
with predictions_path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
target_id = payload.get("target_id")
|
||||
output = payload.get("output") or {}
|
||||
score = (
|
||||
payload.get("first_token_score")
|
||||
if payload.get("first_token_score") is not None
|
||||
else output.get("score")
|
||||
)
|
||||
if score is None:
|
||||
# Fallback: many OpenAI-compatible endpoints don't honor logprobs.
|
||||
# Derive a degraded binary score from the first-token label so the
|
||||
# row is still usable (Macro-F1 / Precision@K stay valid; AUROC
|
||||
# collapses but AUPRC still works on rank order).
|
||||
first_label = (output.get("first_token_label") or "").upper()
|
||||
predicted_upper = str(output.get("predicted_label") or "").upper()
|
||||
if first_label == "MALICIOUS" or predicted_upper == "MALICIOUS":
|
||||
score = 1.0
|
||||
elif first_label == "BENIGN" or predicted_upper == "BENIGN":
|
||||
score = 0.0
|
||||
else:
|
||||
_log.warning(
|
||||
"missing first-token score AND no usable label for %s; skipping",
|
||||
target_id,
|
||||
)
|
||||
continue
|
||||
# Prompt-batch filenames carry an "NNNN_<uuid>" prefix (see
|
||||
# build_theia_prompt_batch.py:_safe_id). Recover the bare UUID
|
||||
# so that labeled_targets.jsonl lookups succeed.
|
||||
label_row = labels.get(target_id)
|
||||
if not label_row and isinstance(target_id, str) and "_" in target_id:
|
||||
bare = target_id.split("_", 1)[1]
|
||||
label_row = labels.get(bare)
|
||||
if label_row:
|
||||
target_id = bare
|
||||
if not label_row:
|
||||
_log.warning("no label for %s; skipping", target_id)
|
||||
continue
|
||||
true_label = "malicious" if label_row.get("label") == "malicious" else "benign"
|
||||
predicted = output.get("predicted_label", "BENIGN")
|
||||
predicted_label = "malicious" if str(predicted).upper() == "MALICIOUS" else "benign"
|
||||
records.append(
|
||||
PredictionRecord(
|
||||
target_id=target_id,
|
||||
target_type=label_row.get("target_type", "PROCESS"),
|
||||
score=float(max(0.0, min(1.0, score))),
|
||||
predicted_label=predicted_label,
|
||||
true_label=true_label,
|
||||
timestamp=label_row.get("anchor_timestamp"),
|
||||
evidence_path_ids=tuple(output.get("evidence_path_ids") or ()),
|
||||
prompt_tokens=payload.get("prompt_tokens"),
|
||||
inference_cost=None,
|
||||
prompt_construction_time=None,
|
||||
)
|
||||
)
|
||||
return records
|
||||
|
||||
|
||||
def _render_markdown_table(method_metrics: dict[str, dict]) -> str:
|
||||
if not method_metrics:
|
||||
return "# ER-TP-DGP Evaluation\n\nNo method metrics produced.\n"
|
||||
headers = [
|
||||
"method",
|
||||
"n",
|
||||
"n+",
|
||||
"AUPRC",
|
||||
"AUROC",
|
||||
"Macro-F1",
|
||||
"Recall@10",
|
||||
"FPR@0.9",
|
||||
"avg_tokens",
|
||||
"evidence_hit",
|
||||
]
|
||||
rows: list[list[str]] = []
|
||||
for method_name, payload in sorted(method_metrics.items()):
|
||||
m = payload["metrics"]
|
||||
rows.append(
|
||||
[
|
||||
method_name,
|
||||
str(m["num_examples"]),
|
||||
str(m["num_positive"]),
|
||||
_fmt(m["auprc"]),
|
||||
_fmt(m["auroc"]),
|
||||
_fmt(m["macro_f1"]),
|
||||
_fmt(m["recall_at_k"].get(10)),
|
||||
_fmt(m["fpr_at_recall"].get(0.9)),
|
||||
_fmt(m["avg_prompt_tokens"]),
|
||||
_fmt(m["evidence_path_hit_rate"]),
|
||||
]
|
||||
)
|
||||
lines = [
|
||||
"# ER-TP-DGP Evaluation",
|
||||
"",
|
||||
"Per-method metrics. Score column is calibrated first-token softmax over (Yes, No)",
|
||||
"(DGP paper formula 14). Records missing logprobs are excluded with a warning.",
|
||||
"",
|
||||
"| " + " | ".join(headers) + " |",
|
||||
"|" + "|".join(["---"] * len(headers)) + "|",
|
||||
]
|
||||
for row in rows:
|
||||
lines.append("| " + " | ".join(row) + " |")
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def _fmt(value) -> str:
|
||||
if isinstance(value, float):
|
||||
return f"{value:.4f}"
|
||||
if value is None:
|
||||
return "n/a"
|
||||
return str(value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
103
scripts/run_hybrid_experiment.sh
Executable file
103
scripts/run_hybrid_experiment.sh
Executable file
@@ -0,0 +1,103 @@
|
||||
#!/usr/bin/env bash
|
||||
# End-to-end hybrid (community + v0.1 fine-grained) experiment driver.
|
||||
#
|
||||
# Steps:
|
||||
# 1) Build hybrid prompts on a balanced set of communities (Phase 14 +
|
||||
# v0.1 fine-grained re-injection).
|
||||
# 2) Build a parallel set of Phase 14 raw landmark-only prompts on the
|
||||
# SAME communities (head-to-head ablation).
|
||||
# 3) Convert prompt metadata → labeled_targets.jsonl.
|
||||
# 4) Run LLM inference on both prompt sets.
|
||||
# 5) Run evaluation, write metrics.md.
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/run_hybrid_experiment.sh [BENIGN_PER_MALICIOUS]
|
||||
#
|
||||
# Defaults to BENIGN_PER_MALICIOUS=24 → 6 mal + 144 ben = 150 communities,
|
||||
# matching the v0.1 evaluation scale of n=146.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
BENIGN_PER_MAL=${1:-24}
|
||||
OUT_ROOT="reports/hybrid_v0_3"
|
||||
PROMPTS_HYBRID="${OUT_ROOT}/prompts_hybrid"
|
||||
PROMPTS_RAW="${OUT_ROOT}/prompts_landmark_raw"
|
||||
LABELED_TARGETS="${OUT_ROOT}/labeled_targets.jsonl"
|
||||
PRED_HYBRID="${OUT_ROOT}/predictions_hybrid.jsonl"
|
||||
PRED_RAW="${OUT_ROOT}/predictions_landmark_raw.jsonl"
|
||||
METRICS_DIR="${OUT_ROOT}/metrics"
|
||||
|
||||
mkdir -p "${OUT_ROOT}" "${METRICS_DIR}"
|
||||
|
||||
LANDMARK_DIR="reports/landmark_csg"
|
||||
COMMUNITIES="${LANDMARK_DIR}/landmark_communities.jsonl"
|
||||
LANDMARKS="${LANDMARK_DIR}/landmarks.jsonl"
|
||||
EDGES="${LANDMARK_DIR}/landmark_edges.jsonl"
|
||||
LABELED_COMMUNITIES="${LANDMARK_DIR}/labeled_communities.jsonl"
|
||||
|
||||
echo "=== STEP 1: build hybrid prompts (community + v0.1 fine-grained) ==="
|
||||
.venv/bin/python -u scripts/build_hybrid_community_prompts.py \
|
||||
--communities "${COMMUNITIES}" \
|
||||
--landmarks "${LANDMARKS}" \
|
||||
--landmark-edges "${EDGES}" \
|
||||
--labeled-communities "${LABELED_COMMUNITIES}" \
|
||||
--output-dir "${PROMPTS_HYBRID}" \
|
||||
--include-only balanced \
|
||||
--benign-per-malicious "${BENIGN_PER_MAL}" \
|
||||
--margin-seconds 60 \
|
||||
--max-events-per-community 5000 \
|
||||
--max-landmarks-in-prompt 60 \
|
||||
--max-edges-in-prompt 80 \
|
||||
--top-m-per-metapath 5 \
|
||||
--progress-every 2000000
|
||||
|
||||
echo "=== STEP 2: build Phase 14 raw landmark prompts on the SAME communities ==="
|
||||
.venv/bin/python -u scripts/build_landmark_prompts_for_ids.py \
|
||||
--communities "${COMMUNITIES}" \
|
||||
--landmarks "${LANDMARKS}" \
|
||||
--landmark-edges "${EDGES}" \
|
||||
--labeled-communities "${LABELED_COMMUNITIES}" \
|
||||
--ids-from-metadata "${PROMPTS_HYBRID}/prompt_metadata.jsonl" \
|
||||
--output-dir "${PROMPTS_RAW}" \
|
||||
--max-landmarks-in-prompt 60 \
|
||||
--max-edges-in-prompt 80
|
||||
|
||||
echo "=== STEP 3: build labeled_targets.jsonl from hybrid metadata ==="
|
||||
.venv/bin/python -u scripts/build_hybrid_labeled_targets.py \
|
||||
--prompt-metadata "${PROMPTS_HYBRID}/prompt_metadata.jsonl" \
|
||||
--output "${LABELED_TARGETS}"
|
||||
|
||||
echo "=== STEP 4a: LLM inference on hybrid prompts ==="
|
||||
.venv/bin/python -u scripts/run_llm_inference.py \
|
||||
--config configs/llm.yaml \
|
||||
--prompt-dir "${PROMPTS_HYBRID}/prompts" \
|
||||
--output-jsonl "${PRED_HYBRID}" \
|
||||
--request-logprobs \
|
||||
--max-prompt-chars 200000
|
||||
|
||||
echo "=== STEP 4b: LLM inference on Phase 14 raw landmark prompts (same set) ==="
|
||||
.venv/bin/python -u scripts/run_llm_inference.py \
|
||||
--config configs/llm.yaml \
|
||||
--prompt-dir "${PROMPTS_RAW}/prompts" \
|
||||
--output-jsonl "${PRED_RAW}" \
|
||||
--request-logprobs \
|
||||
--max-prompt-chars 200000
|
||||
|
||||
echo "=== STEP 5: aggregate metrics ==="
|
||||
.venv/bin/python -u scripts/run_evaluation.py \
|
||||
--predictions-jsonl "${PRED_HYBRID}" \
|
||||
--predictions-jsonl "${PRED_RAW}" \
|
||||
--labeled-targets "${LABELED_TARGETS}" \
|
||||
--output-dir "${METRICS_DIR}"
|
||||
|
||||
echo "=== STEP 6: cross-compare with v0.1/v0.2 baselines ==="
|
||||
.venv/bin/python -u scripts/summarize_hybrid_experiment.py \
|
||||
--hybrid-metrics "${METRICS_DIR}/metrics.json" \
|
||||
--output "${OUT_ROOT}/summary.md"
|
||||
|
||||
echo "=== ALL STAGES COMPLETE ==="
|
||||
echo "Metrics:"
|
||||
cat "${METRICS_DIR}/metrics.md"
|
||||
echo
|
||||
echo "Summary:"
|
||||
cat "${OUT_ROOT}/summary.md"
|
||||
81
scripts/run_hybrid_inference_local.sh
Executable file
81
scripts/run_hybrid_inference_local.sh
Executable file
@@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env bash
|
||||
# Continuation of run_hybrid_experiment.sh — resumes from STEP 4 onward
|
||||
# using local_hf (HuggingFace transformers) provider instead of the API.
|
||||
# Steps 1-3 (prompt build + labeled targets) already completed; reuse their outputs.
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/run_hybrid_inference_local.sh [MODEL]
|
||||
# Default MODEL = Qwen/Qwen3.5-27B (matches v0.1/v0.2 baselines).
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
MODEL=${1:-Qwen/Qwen3.5-27B}
|
||||
MAX_GIB=${HYBRID_MAX_GIB:-30}
|
||||
|
||||
OUT_ROOT="reports/hybrid_v0_3"
|
||||
PROMPTS_HYBRID="${OUT_ROOT}/prompts_hybrid"
|
||||
PROMPTS_RAW="${OUT_ROOT}/prompts_landmark_raw"
|
||||
LABELED_TARGETS="${OUT_ROOT}/labeled_targets.jsonl"
|
||||
PRED_HYBRID="${OUT_ROOT}/predictions_hybrid_local.jsonl"
|
||||
PRED_RAW="${OUT_ROOT}/predictions_landmark_raw_local.jsonl"
|
||||
METRICS_DIR="${OUT_ROOT}/metrics_local"
|
||||
|
||||
mkdir -p "${OUT_ROOT}" "${METRICS_DIR}"
|
||||
|
||||
# Step 2 redo: ensure raw landmark prompts cover the SAME 150 community ids.
|
||||
# (The first run hit a bash-mid-execution cache miss and used the legacy
|
||||
# build_landmark_prompts.py which only produced 12 prompts.)
|
||||
RAW_COUNT=$(ls "${PROMPTS_RAW}/prompts" 2>/dev/null | wc -l | tr -d ' ')
|
||||
if [[ "${RAW_COUNT}" != "150" ]]; then
|
||||
echo "=== STEP 2 (redo): build raw landmark prompts for the same 150 ids ==="
|
||||
.venv/bin/python -u scripts/build_landmark_prompts_for_ids.py \
|
||||
--communities reports/landmark_csg/landmark_communities.jsonl \
|
||||
--landmarks reports/landmark_csg/landmarks.jsonl \
|
||||
--landmark-edges reports/landmark_csg/landmark_edges.jsonl \
|
||||
--labeled-communities reports/landmark_csg/labeled_communities.jsonl \
|
||||
--ids-from-metadata "${PROMPTS_HYBRID}/prompt_metadata.jsonl" \
|
||||
--output-dir "${PROMPTS_RAW}" \
|
||||
--max-landmarks-in-prompt 60 \
|
||||
--max-edges-in-prompt 80
|
||||
fi
|
||||
|
||||
echo "=== STEP 4a: LLM inference on hybrid prompts (local_hf, ${MODEL}) ==="
|
||||
.venv/bin/python -u scripts/run_llm_inference.py \
|
||||
--provider local_hf \
|
||||
--model "${MODEL}" \
|
||||
--dtype bf16 \
|
||||
--device-map auto \
|
||||
--max-memory-per-gpu-gib "${MAX_GIB}" \
|
||||
--prompt-dir "${PROMPTS_HYBRID}/prompts" \
|
||||
--output-jsonl "${PRED_HYBRID}" \
|
||||
--max-prompt-chars 200000
|
||||
|
||||
echo "=== STEP 4b: LLM inference on Phase 14 raw landmark prompts (same set) ==="
|
||||
.venv/bin/python -u scripts/run_llm_inference.py \
|
||||
--provider local_hf \
|
||||
--model "${MODEL}" \
|
||||
--dtype bf16 \
|
||||
--device-map auto \
|
||||
--max-memory-per-gpu-gib "${MAX_GIB}" \
|
||||
--prompt-dir "${PROMPTS_RAW}/prompts" \
|
||||
--output-jsonl "${PRED_RAW}" \
|
||||
--max-prompt-chars 200000
|
||||
|
||||
echo "=== STEP 5: aggregate metrics ==="
|
||||
.venv/bin/python -u scripts/run_evaluation.py \
|
||||
--predictions-jsonl "${PRED_HYBRID}" \
|
||||
--predictions-jsonl "${PRED_RAW}" \
|
||||
--labeled-targets "${LABELED_TARGETS}" \
|
||||
--output-dir "${METRICS_DIR}"
|
||||
|
||||
echo "=== STEP 6: cross-compare with v0.1/v0.2 baselines ==="
|
||||
.venv/bin/python -u scripts/summarize_hybrid_experiment.py \
|
||||
--hybrid-metrics "${METRICS_DIR}/metrics.json" \
|
||||
--output "${OUT_ROOT}/summary_local.md"
|
||||
|
||||
echo "=== ALL STAGES COMPLETE ==="
|
||||
echo "Metrics:"
|
||||
cat "${METRICS_DIR}/metrics.md"
|
||||
echo
|
||||
echo "Summary:"
|
||||
cat "${OUT_ROOT}/summary_local.md"
|
||||
207
scripts/run_llm_inference.py
Normal file
207
scripts/run_llm_inference.py
Normal file
@@ -0,0 +1,207 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Run OpenAI-compatible LLM inference for saved ER-TP-DGP prompts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.llm import LocalHFLogitsProvider, OpenAICompatibleHTTPProvider
|
||||
from er_tp_dgp.llm_config import load_llm_config, merge_llm_config
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", help="YAML LLM config file, e.g. configs/llm.yaml")
|
||||
parser.add_argument("--provider", choices=["api", "local", "local_hf"])
|
||||
parser.add_argument("--base-url")
|
||||
parser.add_argument("--model")
|
||||
parser.add_argument("--prompt-file", action="append", default=[])
|
||||
parser.add_argument("--prompt-dir")
|
||||
parser.add_argument("--output-jsonl", default="reports/llm_predictions.jsonl")
|
||||
parser.add_argument("--api-key-env", default=None)
|
||||
parser.add_argument("--timeout-seconds", type=float)
|
||||
parser.add_argument("--temperature", type=float)
|
||||
parser.add_argument("--max-tokens", type=int)
|
||||
parser.add_argument(
|
||||
"--request-logprobs",
|
||||
action="store_true",
|
||||
help="(API/local-OpenAI) Ask server for first-token top_logprobs and "
|
||||
"compute calibrated softmax score (DGP formula 14).",
|
||||
)
|
||||
parser.add_argument("--lora-adapter", default=None, help="(local_hf) path to LoRA adapter.")
|
||||
parser.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
|
||||
parser.add_argument("--device-map", default="auto")
|
||||
parser.add_argument(
|
||||
"--model-class",
|
||||
default="auto",
|
||||
choices=["auto", "causal_lm", "image_text_to_text", "seq2seq"],
|
||||
help=(
|
||||
"(local_hf) HF AutoModelFor* class. 'auto' inspects "
|
||||
"config.architectures (multimodal Qwen3.5-27B → image_text_to_text)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-memory-per-gpu-gib",
|
||||
type=float,
|
||||
default=None,
|
||||
help=(
|
||||
"(local_hf) Cap per-GPU memory so accelerate balances across cards "
|
||||
"instead of filling GPU 0. Use ~30 for 2x A100 40GB on Qwen3.5-27B."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-prompt-chars",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Skip any prompt larger than this (chars). Outliers (e.g. firefox "
|
||||
"30s windows producing 1M+ tokens) trigger attention OOM even with "
|
||||
"SDPA. The skipped target gets first_token_score=None and is "
|
||||
"excluded by the metrics aggregator."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
prompt_files = [Path(path) for path in args.prompt_file]
|
||||
if args.prompt_dir:
|
||||
prompt_files.extend(sorted(Path(args.prompt_dir).glob("*.txt")))
|
||||
if not prompt_files:
|
||||
raise SystemExit("No prompt files provided. Use --prompt-file or --prompt-dir.")
|
||||
|
||||
if args.provider == "local_hf":
|
||||
if not args.model:
|
||||
raise SystemExit("local_hf requires --model (HF model id, e.g. Qwen/Qwen3-8B).")
|
||||
provider = LocalHFLogitsProvider(
|
||||
base_model=args.model,
|
||||
lora_adapter=args.lora_adapter,
|
||||
dtype=args.dtype,
|
||||
device_map=args.device_map,
|
||||
model_class=args.model_class,
|
||||
max_memory_per_gpu_gib=args.max_memory_per_gpu_gib,
|
||||
)
|
||||
else:
|
||||
if args.config:
|
||||
config = load_llm_config(args.config)
|
||||
config = merge_llm_config(
|
||||
config,
|
||||
provider=args.provider,
|
||||
base_url=args.base_url,
|
||||
model=args.model,
|
||||
api_key_env=args.api_key_env,
|
||||
timeout_seconds=args.timeout_seconds,
|
||||
temperature=args.temperature,
|
||||
max_tokens=args.max_tokens,
|
||||
)
|
||||
else:
|
||||
missing = [
|
||||
name
|
||||
for name, value in (
|
||||
("--provider", args.provider),
|
||||
("--base-url", args.base_url),
|
||||
("--model", args.model),
|
||||
)
|
||||
if not value
|
||||
]
|
||||
if missing:
|
||||
raise SystemExit(
|
||||
f"Missing required arguments without --config: {', '.join(missing)}"
|
||||
)
|
||||
config = merge_llm_config(
|
||||
load_default_inline_config(args.provider, args.base_url, args.model),
|
||||
api_key_env=args.api_key_env,
|
||||
timeout_seconds=args.timeout_seconds,
|
||||
temperature=args.temperature,
|
||||
max_tokens=args.max_tokens,
|
||||
)
|
||||
if args.request_logprobs:
|
||||
config = LLMRequestConfig_with_logprobs(config)
|
||||
provider = OpenAICompatibleHTTPProvider(config)
|
||||
|
||||
output_path = Path(args.output_jsonl)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
skipped = 0
|
||||
with output_path.open("w", encoding="utf-8") as handle:
|
||||
for idx, prompt_file in enumerate(prompt_files, start=1):
|
||||
prompt_text = prompt_file.read_text(encoding="utf-8")
|
||||
target_id = prompt_file.stem
|
||||
if args.max_prompt_chars is not None and len(prompt_text) > args.max_prompt_chars:
|
||||
payload = {
|
||||
"target_id": target_id,
|
||||
"prompt_file": str(prompt_file),
|
||||
"skipped": True,
|
||||
"skip_reason": f"prompt size {len(prompt_text)} > --max-prompt-chars {args.max_prompt_chars}",
|
||||
"first_token_score": None,
|
||||
"first_token_yes_logprob": None,
|
||||
"first_token_no_logprob": None,
|
||||
"output": {"first_token_label": None, "score": None, "predicted_label": None,
|
||||
"evidence_path_ids": []},
|
||||
}
|
||||
handle.write(json.dumps(payload, ensure_ascii=False, sort_keys=True) + "\n")
|
||||
handle.flush()
|
||||
skipped += 1
|
||||
print(f"[{idx}/{len(prompt_files)}] {prompt_file}: SKIP ({len(prompt_text)} chars > cap)")
|
||||
continue
|
||||
try:
|
||||
result = provider.classify(target_id=target_id, prompt_text=prompt_text)
|
||||
except Exception as exc: # noqa: BLE001 - any GPU/inference error → skip, keep batch alive
|
||||
payload = {
|
||||
"target_id": target_id,
|
||||
"prompt_file": str(prompt_file),
|
||||
"skipped": True,
|
||||
"skip_reason": f"inference error: {type(exc).__name__}: {str(exc)[:200]}",
|
||||
"first_token_score": None,
|
||||
"first_token_yes_logprob": None,
|
||||
"first_token_no_logprob": None,
|
||||
"output": {"first_token_label": None, "score": None, "predicted_label": None,
|
||||
"evidence_path_ids": []},
|
||||
}
|
||||
handle.write(json.dumps(payload, ensure_ascii=False, sort_keys=True) + "\n")
|
||||
handle.flush()
|
||||
skipped += 1
|
||||
print(f"[{idx}/{len(prompt_files)}] {prompt_file}: ERROR {type(exc).__name__} (continuing)")
|
||||
# Free CUDA cache before next prompt to avoid cascading OOM.
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
payload = result.to_json_dict()
|
||||
payload["prompt_file"] = str(prompt_file)
|
||||
handle.write(json.dumps(payload, ensure_ascii=False, sort_keys=True) + "\n")
|
||||
score = (
|
||||
result.first_token_score
|
||||
if result.first_token_score is not None
|
||||
else result.output.score
|
||||
)
|
||||
print(
|
||||
f"{prompt_file}: {result.output.first_token_label} "
|
||||
f"score={score} latency={result.latency_seconds:.2f}s"
|
||||
)
|
||||
print(f"wrote={output_path}")
|
||||
return 0
|
||||
|
||||
|
||||
def load_default_inline_config(provider: str, base_url: str, model: str):
|
||||
from er_tp_dgp.llm import LLMRequestConfig
|
||||
|
||||
return LLMRequestConfig(
|
||||
provider_type=provider,
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
api_key_env="OPENAI_COMPAT_API_KEY" if provider == "api" else None,
|
||||
)
|
||||
|
||||
|
||||
def LLMRequestConfig_with_logprobs(config):
|
||||
"""Return a copy of `config` with logprobs/top_logprobs requested."""
|
||||
from dataclasses import replace as _replace
|
||||
|
||||
return _replace(config, request_logprobs=True, top_logprobs=20)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
255
scripts/run_multiround_inference.py
Normal file
255
scripts/run_multiround_inference.py
Normal file
@@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Causal Graph-of-Thought (CGoT) multi-round inference.
|
||||
|
||||
Loads the same prompt batch produced by build_theia_prompt_batch.py BUT also
|
||||
needs the underlying provenance graph (from the cached THEIA window IR) and
|
||||
the labeled_targets to know each target's anchor + UUID. Round prompts are
|
||||
constructed live from the graph; the per-target prompt_text/*.txt files are
|
||||
NOT used here.
|
||||
|
||||
Output format: one JSONL line per target with:
|
||||
target_id, score (final round softmax), yes_logprob, no_logprob,
|
||||
intermediate_findings (list of {round_id, metapath_type, observation}),
|
||||
rounds_run, total_latency_seconds.
|
||||
|
||||
The output is shaped so that scripts/run_evaluation.py can ingest it like any
|
||||
other predictions file (first_token_score / first_token_yes_logprob /
|
||||
first_token_no_logprob fields are populated identically).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.llm import LocalHFLogitsProvider
|
||||
from er_tp_dgp.metapaths import APTMetapathExtractor
|
||||
from er_tp_dgp.multiround import MultiRoundPromptBuilder
|
||||
from er_tp_dgp.numerical_aggregator import NumericalAggregator
|
||||
from er_tp_dgp.prompt import PromptComponentSwitches
|
||||
from er_tp_dgp.scoring import score_from_hf_logits
|
||||
from er_tp_dgp.text_summarizer import (
|
||||
MetapathTextSummarizer,
|
||||
NodeTextSummarizer,
|
||||
SummarizerConfig,
|
||||
_NullLLM,
|
||||
)
|
||||
from er_tp_dgp.theia import build_cached_theia_window_ir, discover_theia_json_files
|
||||
from er_tp_dgp.trimming import TemporalSecurityAwareTrimmer
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__.split("\n", 1)[0])
|
||||
parser.add_argument("--labeled-targets", required=True)
|
||||
parser.add_argument("--data-dir", default="data/raw/e3_theia_json")
|
||||
parser.add_argument(
|
||||
"--cache-dir",
|
||||
default="reports/cache/theia_window_ir",
|
||||
help="Where pre-warmed window-IR snapshots live.",
|
||||
)
|
||||
parser.add_argument("--lookback-seconds", type=float, default=30.0)
|
||||
parser.add_argument("--lookahead-seconds", type=float, default=30.0)
|
||||
parser.add_argument("--top-m-per-metapath", type=int, default=5)
|
||||
parser.add_argument("--model", required=True, help="HF model id, e.g. Qwen/Qwen3-1.7B")
|
||||
parser.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
|
||||
parser.add_argument("--device-map", default="auto")
|
||||
parser.add_argument("--max-memory-per-gpu-gib", type=float, default=None)
|
||||
parser.add_argument("--lora-adapter", default=None)
|
||||
parser.add_argument("--output-jsonl", required=True)
|
||||
parser.add_argument(
|
||||
"--intermediate-max-tokens",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Max new tokens for non-final rounds (short observations).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-llm-summarizer",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Use a remote OpenAI-compat config for TextSumm/PathSumm "
|
||||
"(via --summarizer-config). Default: NullSummarizer (truncation only)."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--summarizer-config", default=None)
|
||||
parser.add_argument("--summarizer-workers", type=int, default=8)
|
||||
parser.add_argument("--max-targets", type=int, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
paths = discover_theia_json_files(args.data_dir)
|
||||
if not paths:
|
||||
raise SystemExit(f"no THEIA JSON files in {args.data_dir}")
|
||||
|
||||
targets = _read_jsonl(Path(args.labeled_targets))
|
||||
if args.max_targets is not None:
|
||||
targets = targets[: args.max_targets]
|
||||
|
||||
provider = LocalHFLogitsProvider(
|
||||
base_model=args.model,
|
||||
lora_adapter=args.lora_adapter,
|
||||
dtype=args.dtype,
|
||||
device_map=args.device_map,
|
||||
max_memory_per_gpu_gib=args.max_memory_per_gpu_gib,
|
||||
)
|
||||
|
||||
node_summ, path_summ = _build_summarizers(
|
||||
use_llm=args.use_llm_summarizer,
|
||||
config_path=args.summarizer_config,
|
||||
workers=args.summarizer_workers,
|
||||
)
|
||||
|
||||
output_path = Path(args.output_jsonl)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with output_path.open("w", encoding="utf-8") as out:
|
||||
for index, target in enumerate(targets, start=1):
|
||||
target_id = target["target_id"]
|
||||
anchor_event_id = target["anchor_event_id"]
|
||||
print(f"[{index}/{len(targets)}] target={target_id} anchor={anchor_event_id}", flush=True)
|
||||
|
||||
window = build_cached_theia_window_ir(
|
||||
paths,
|
||||
target_event_uuid=anchor_event_id,
|
||||
lookback_seconds=args.lookback_seconds,
|
||||
lookahead_seconds=args.lookahead_seconds,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
graph = window.to_graph()
|
||||
graph_target_id = window.target_subject_id or window.target_event_id
|
||||
evidence_paths = APTMetapathExtractor(graph).extract_for_target(graph_target_id)
|
||||
selected = TemporalSecurityAwareTrimmer(
|
||||
graph, top_m_per_metapath=args.top_m_per_metapath
|
||||
).trim(graph_target_id, evidence_paths)
|
||||
|
||||
switches = PromptComponentSwitches(
|
||||
use_text_summarization=(node_summ is not None),
|
||||
use_path_summarization_llm=(path_summ is not None),
|
||||
)
|
||||
builder = MultiRoundPromptBuilder(
|
||||
graph,
|
||||
node_summarizer=node_summ,
|
||||
path_summarizer=path_summ,
|
||||
numerical_aggregator=NumericalAggregator(graph),
|
||||
switches=switches,
|
||||
)
|
||||
plan = builder.build(graph_target_id, selected)
|
||||
|
||||
result = _run_plan(
|
||||
provider=provider,
|
||||
plan=plan,
|
||||
intermediate_max_tokens=args.intermediate_max_tokens,
|
||||
)
|
||||
payload = {
|
||||
"target_id": graph_target_id,
|
||||
"anchor_event_id": anchor_event_id,
|
||||
"rounds_run": len(plan.rounds),
|
||||
"intermediate_findings": result["intermediate_findings"],
|
||||
"raw_text": result["final_text"],
|
||||
"first_token_score": result["score"],
|
||||
"first_token_yes_logprob": result["yes_logprob"],
|
||||
"first_token_no_logprob": result["no_logprob"],
|
||||
"output": {
|
||||
"first_token_label": "MALICIOUS" if (result["score"] or 0.0) >= 0.5 else "BENIGN",
|
||||
"score": result["score"],
|
||||
"predicted_label": "MALICIOUS" if (result["score"] or 0.0) >= 0.5 else "BENIGN",
|
||||
"evidence_path_ids": list(plan.evidence_path_ids),
|
||||
},
|
||||
"latency_seconds": result["total_latency"],
|
||||
}
|
||||
out.write(json.dumps(payload, ensure_ascii=False, sort_keys=True) + "\n")
|
||||
out.flush()
|
||||
print(
|
||||
f" rounds={len(plan.rounds)} score={result['score']} "
|
||||
f"latency={result['total_latency']:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
print(f"wrote {output_path}", flush=True)
|
||||
return 0
|
||||
|
||||
|
||||
def _run_plan(*, provider: LocalHFLogitsProvider, plan, intermediate_max_tokens: int) -> dict:
|
||||
intermediate: list[dict] = []
|
||||
started = time.time()
|
||||
|
||||
def _format(prompt_template: str) -> str:
|
||||
prior_block = "\n".join(
|
||||
f"- {entry['round_id']} ({entry.get('metapath_type') or '-'}): {entry['observation']}"
|
||||
for entry in intermediate
|
||||
)
|
||||
if "{prior_findings}" in prompt_template:
|
||||
return prompt_template.replace(
|
||||
"{prior_findings}",
|
||||
f"Prior reasoning:\n{prior_block}" if prior_block else "Prior reasoning: (none yet)",
|
||||
)
|
||||
return prompt_template
|
||||
|
||||
score = None
|
||||
yes_lp = None
|
||||
no_lp = None
|
||||
final_text = ""
|
||||
|
||||
for round_prompt in plan.rounds:
|
||||
prompt = _format(round_prompt.prompt_text)
|
||||
if round_prompt.is_final:
|
||||
# Final round: classify, read first-token Yes/No softmax.
|
||||
r = provider.classify(target_id=plan.target_id, prompt_text=prompt)
|
||||
score = r.first_token_score
|
||||
yes_lp = r.first_token_yes_logprob
|
||||
no_lp = r.first_token_no_logprob
|
||||
final_text = r.raw_text
|
||||
else:
|
||||
# Intermediate round: short text generation.
|
||||
obs = provider.complete(prompt, max_tokens=intermediate_max_tokens)
|
||||
# Trim the observation aggressively: take up to first newline.
|
||||
short = obs.split("\n", 1)[0].strip()[:280]
|
||||
intermediate.append(
|
||||
{
|
||||
"round_id": round_prompt.round_id,
|
||||
"metapath_type": round_prompt.metapath_type,
|
||||
"observation": short,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"intermediate_findings": intermediate,
|
||||
"final_text": final_text,
|
||||
"score": score,
|
||||
"yes_logprob": yes_lp,
|
||||
"no_logprob": no_lp,
|
||||
"total_latency": time.time() - started,
|
||||
}
|
||||
|
||||
|
||||
def _build_summarizers(
|
||||
*, use_llm: bool, config_path: str | None, workers: int
|
||||
) -> tuple[NodeTextSummarizer | None, MetapathTextSummarizer | None]:
|
||||
if not use_llm:
|
||||
return None, None
|
||||
if not config_path:
|
||||
cfg = SummarizerConfig(model_name="null-fallback", max_workers=workers)
|
||||
return NodeTextSummarizer(llm=_NullLLM(), config=cfg), MetapathTextSummarizer(
|
||||
llm=_NullLLM(), config=cfg
|
||||
)
|
||||
from er_tp_dgp.llm import OpenAICompatibleHTTPProvider
|
||||
from er_tp_dgp.llm_config import load_llm_config
|
||||
|
||||
llm_cfg = load_llm_config(config_path)
|
||||
provider = OpenAICompatibleHTTPProvider(llm_cfg)
|
||||
cfg = SummarizerConfig(model_name=llm_cfg.model, max_workers=workers)
|
||||
return NodeTextSummarizer(llm=provider, config=cfg), MetapathTextSummarizer(llm=provider, config=cfg)
|
||||
|
||||
|
||||
def _read_jsonl(path: Path) -> list[dict]:
|
||||
rows: list[dict] = []
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
if line.strip():
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
124
scripts/summarize_hybrid_experiment.py
Normal file
124
scripts/summarize_hybrid_experiment.py
Normal file
@@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Summarize the hybrid v0.3 experiment + cross-compare with v0.1 v0.2 baselines.
|
||||
|
||||
Reads:
|
||||
- reports/hybrid_v0_3/metrics/metrics.json (this experiment)
|
||||
- reports/evaluation/e3_theia_v0_2/metrics_n146_4methods/metrics.json (v0.1 baseline)
|
||||
|
||||
Writes:
|
||||
- reports/hybrid_v0_3/summary.md — head-to-head comparison table
|
||||
|
||||
The two experiments use different target populations (v0.1 = per-process
|
||||
n=146, hybrid = per-community n=150) so this is NOT a direct AUPRC
|
||||
comparison — it's a "how does the new method compare in absolute
|
||||
detection capability" snapshot. The within-experiment row comparison
|
||||
(hybrid vs Phase 14 raw landmarks on the SAME 150 communities) IS direct.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load(path: Path) -> dict:
|
||||
if not path.exists():
|
||||
return {}
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _row(name: str, m: dict) -> str:
|
||||
metrics = m.get("metrics") if "metrics" in m else m
|
||||
n = metrics.get("num_examples", "?")
|
||||
n_pos = metrics.get("num_positive", "?")
|
||||
return (
|
||||
f"| {name} | {n} | {n_pos} | "
|
||||
f"{_fmt(metrics.get('auprc'))} | {_fmt(metrics.get('auroc'))} | "
|
||||
f"{_fmt(metrics.get('macro_f1'))} | "
|
||||
f"{_fmt((metrics.get('recall_at_k') or {}).get('10') or (metrics.get('recall_at_k') or {}).get(10))} | "
|
||||
f"{_fmt((metrics.get('fpr_at_recall') or {}).get('0.9') or (metrics.get('fpr_at_recall') or {}).get(0.9))} | "
|
||||
f"{_fmt(metrics.get('avg_prompt_tokens'))} | "
|
||||
f"{_fmt(metrics.get('evidence_path_hit_rate'))} |"
|
||||
)
|
||||
|
||||
|
||||
def _fmt(value) -> str:
|
||||
if isinstance(value, (int, float)):
|
||||
return f"{value:.4f}" if isinstance(value, float) else str(value)
|
||||
if value is None:
|
||||
return "n/a"
|
||||
return str(value)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--hybrid-metrics",
|
||||
default="reports/hybrid_v0_3/metrics/metrics.json",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--baseline-metrics",
|
||||
default="reports/evaluation/e3_theia_v0_2/metrics_n146_4methods/metrics.json",
|
||||
)
|
||||
parser.add_argument("--output", default="reports/hybrid_v0_3/summary.md")
|
||||
args = parser.parse_args()
|
||||
|
||||
hybrid_data = _load(Path(args.hybrid_metrics))
|
||||
baseline_data = _load(Path(args.baseline_metrics))
|
||||
|
||||
rows: list[tuple[str, dict]] = []
|
||||
for name, payload in sorted(hybrid_data.items()):
|
||||
rows.append((f"hybrid_v0_3 / {name}", payload))
|
||||
for name, payload in sorted(baseline_data.items()):
|
||||
rows.append((f"baseline_v0_2 / {name}", payload))
|
||||
|
||||
headers = [
|
||||
"method",
|
||||
"n",
|
||||
"n+",
|
||||
"AUPRC",
|
||||
"AUROC",
|
||||
"Macro-F1",
|
||||
"Recall@10",
|
||||
"FPR@0.9",
|
||||
"avg_tokens",
|
||||
"evidence_hit",
|
||||
]
|
||||
lines = [
|
||||
"# ER-TP-DGP Hybrid v0.3 — Head-to-Head Summary",
|
||||
"",
|
||||
"## Comparison axes",
|
||||
"",
|
||||
"- **hybrid_v0_3 / predictions_hybrid** — Phase 14 community detection unit + ",
|
||||
" v0.1 fine-grained subgraph re-injection + DGP-12 layered prompt.",
|
||||
"- **hybrid_v0_3 / predictions_landmark_raw** — Phase 14 raw landmark-only ",
|
||||
" prompts on the SAME 150 communities (head-to-head ablation).",
|
||||
"- **baseline_v0_2 / predictions_graph_dgp_*** — v0.1 graph_dgp pipeline ",
|
||||
" on n=146 per-process targets (different population, included as scale reference).",
|
||||
"- **baseline_v0_2 / predictions_target_only_*** — v0.1 target-only baseline ",
|
||||
" on n=146 per-process targets.",
|
||||
"",
|
||||
"## Metrics",
|
||||
"",
|
||||
"| " + " | ".join(headers) + " |",
|
||||
"|" + "|".join(["---"] * len(headers)) + "|",
|
||||
]
|
||||
for name, payload in rows:
|
||||
lines.append(_row(name, payload))
|
||||
lines.append("")
|
||||
lines.append(
|
||||
"Score column is calibrated first-token softmax over (Yes, No) "
|
||||
"(DGP paper formula 14). Rows missing logprobs are excluded with a warning."
|
||||
)
|
||||
out_path = Path(args.output)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
print(f"wrote {out_path}")
|
||||
print()
|
||||
print("\n".join(lines))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
99
scripts/theia_candidate_universe.py
Normal file
99
scripts/theia_candidate_universe.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Build a label-free THEIA candidate universe and QA sampling frame."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.candidate_universe import (
|
||||
build_theia_candidate_universe,
|
||||
write_stratified_sample_jsonl,
|
||||
)
|
||||
from er_tp_dgp.theia import discover_theia_json_files
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Build protocol-based process candidates from THEIA JSON. "
|
||||
"This is label-free candidate generation, not detection evaluation."
|
||||
)
|
||||
)
|
||||
parser.add_argument("--data-dir", default="data/raw/e3_theia_json")
|
||||
parser.add_argument(
|
||||
"--input-file",
|
||||
action="append",
|
||||
default=None,
|
||||
help="Specific THEIA JSON file to scan. Can be repeated. Overrides --data-dir discovery.",
|
||||
)
|
||||
parser.add_argument("--output-dir", default="reports/theia_candidate_universe")
|
||||
parser.add_argument("--dataset-name", default="DARPA_TC_E3_THEIA")
|
||||
parser.add_argument("--max-lines", type=int, default=None)
|
||||
parser.add_argument("--max-lines-per-file", type=int, default=None)
|
||||
parser.add_argument(
|
||||
"--progress-every",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Emit '[progress] lines=...' every N records. Useful for long full-corpus scans.",
|
||||
)
|
||||
parser.add_argument("--min-score", type=float, default=1.0)
|
||||
parser.add_argument("--min-events", type=int, default=1)
|
||||
parser.add_argument("--per-stratum", type=int, default=5)
|
||||
parser.add_argument("--seed", type=int, default=7)
|
||||
parser.add_argument("--report-limit", type=int, default=40)
|
||||
args = parser.parse_args()
|
||||
|
||||
paths = [Path(path) for path in args.input_file] if args.input_file else discover_theia_json_files(args.data_dir)
|
||||
if not paths:
|
||||
raise SystemExit(f"no THEIA JSON files found under {args.data_dir}")
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
universe = build_theia_candidate_universe(
|
||||
paths,
|
||||
dataset_name=args.dataset_name,
|
||||
max_lines=args.max_lines,
|
||||
max_lines_per_file=args.max_lines_per_file,
|
||||
progress_every=args.progress_every,
|
||||
)
|
||||
candidates = universe.candidate_profiles(
|
||||
min_score=args.min_score,
|
||||
min_events=args.min_events,
|
||||
)
|
||||
|
||||
universe_path = output_dir / "candidate_universe.jsonl"
|
||||
report_path = output_dir / "candidate_universe.md"
|
||||
sample_path = output_dir / "qa_stratified_sample.jsonl"
|
||||
|
||||
universe.write_jsonl(
|
||||
universe_path,
|
||||
min_score=args.min_score,
|
||||
min_events=args.min_events,
|
||||
)
|
||||
report_path.write_text(
|
||||
universe.to_markdown(
|
||||
min_score=args.min_score,
|
||||
min_events=args.min_events,
|
||||
limit=args.report_limit,
|
||||
)
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
sample = write_stratified_sample_jsonl(
|
||||
candidates,
|
||||
sample_path,
|
||||
per_stratum=args.per_stratum,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
print(f"files={len(paths)} lines_seen={universe.lines_seen} events_seen={universe.events_seen}")
|
||||
print(f"profiles={len(universe.profiles)} candidates={len(candidates)} qa_sample={len(sample)}")
|
||||
print(f"wrote {universe_path}")
|
||||
print(f"wrote {report_path}")
|
||||
print(f"wrote {sample_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
108
scripts/theia_idea_validate.py
Normal file
108
scripts/theia_idea_validate.py
Normal file
@@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Build one real THEIA E3 ER-TP-DGP prompt for idea validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.metapaths import APTMetapathExtractor
|
||||
from er_tp_dgp.prompt import PromptBuilder
|
||||
from er_tp_dgp.theia import build_theia_window_ir, discover_theia_json_files
|
||||
from er_tp_dgp.trimming import TemporalSecurityAwareTrimmer
|
||||
from er_tp_dgp.validation import validate_evidence_paths, validate_graph, validate_ir
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", default="data/raw/e3_theia_json")
|
||||
parser.add_argument("--output-dir", default="reports/theia_e3_idea")
|
||||
parser.add_argument(
|
||||
"--target-event",
|
||||
default="86E0FB61-B300-2215-3C6D-8F0000000010",
|
||||
help="Raw THEIA event UUID to use as target anchor.",
|
||||
)
|
||||
parser.add_argument("--lookback-seconds", type=float, default=120.0)
|
||||
parser.add_argument("--lookahead-seconds", type=float, default=120.0)
|
||||
parser.add_argument("--max-lines", type=int, default=1_250_000)
|
||||
parser.add_argument("--max-lines-per-file", type=int, default=50_000)
|
||||
parser.add_argument("--top-m-per-metapath", type=int, default=5)
|
||||
args = parser.parse_args()
|
||||
|
||||
files = discover_theia_json_files(args.data_dir)
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
window = build_theia_window_ir(
|
||||
files,
|
||||
target_event_uuid=args.target_event,
|
||||
lookback_seconds=args.lookback_seconds,
|
||||
lookahead_seconds=args.lookahead_seconds,
|
||||
max_lines=args.max_lines,
|
||||
max_lines_per_file=args.max_lines_per_file,
|
||||
)
|
||||
graph = window.to_graph()
|
||||
target_id = window.target_subject_id or window.target_event_id
|
||||
|
||||
ir_report = validate_ir(list(window.entities), list(window.events))
|
||||
graph_report = validate_graph(graph)
|
||||
paths = APTMetapathExtractor(graph).extract_for_target(target_id)
|
||||
selected = TemporalSecurityAwareTrimmer(
|
||||
graph,
|
||||
top_m_per_metapath=args.top_m_per_metapath,
|
||||
).trim(target_id, paths)
|
||||
evidence_report = validate_evidence_paths(graph, selected)
|
||||
prompt = PromptBuilder(graph).build(target_id, selected)
|
||||
|
||||
summary = [
|
||||
"# THEIA E3 ER-TP-DGP Idea Validation",
|
||||
"",
|
||||
"This is a method plumbing validation on a real THEIA E3 window. It is not a detection-performance result.",
|
||||
"",
|
||||
f"- target_event_id: {window.target_event_id}",
|
||||
f"- target_subject_id: {window.target_subject_id}",
|
||||
f"- window_start_nanos: {window.start_timestamp_nanos}",
|
||||
f"- window_end_nanos: {window.end_timestamp_nanos}",
|
||||
f"- entities: {len(window.entities)}",
|
||||
f"- events: {len(window.events)}",
|
||||
f"- extracted_evidence_paths: {len(paths)}",
|
||||
f"- selected_evidence_paths: {len(selected)}",
|
||||
f"- schema_gaps: {list(window.schema_gaps)}",
|
||||
"",
|
||||
"## Validation",
|
||||
"",
|
||||
f"- ir_ok: {ir_report.ok}",
|
||||
f"- graph_ok: {graph_report.ok}",
|
||||
f"- evidence_ok: {evidence_report.ok}",
|
||||
"",
|
||||
"## Selected Evidence Paths",
|
||||
"",
|
||||
]
|
||||
for path in selected:
|
||||
summary.append(
|
||||
"- "
|
||||
f"{path.path_id} metapath={path.metapath_type} score={path.trimming_score:.3f} "
|
||||
f"events={list(path.ordered_event_ids)} reason={path.selected_reason}"
|
||||
)
|
||||
if not selected:
|
||||
summary.append("- none")
|
||||
|
||||
(output_dir / "idea_validation.md").write_text("\n".join(summary), encoding="utf-8")
|
||||
(output_dir / "prompt.txt").write_text(prompt.prompt_text, encoding="utf-8")
|
||||
(output_dir / "ir_validation.md").write_text(ir_report.to_markdown(), encoding="utf-8")
|
||||
(output_dir / "graph_validation.md").write_text(graph_report.to_markdown(), encoding="utf-8")
|
||||
(output_dir / "evidence_validation.md").write_text(evidence_report.to_markdown(), encoding="utf-8")
|
||||
|
||||
print(f"target_subject_id={window.target_subject_id}")
|
||||
print(f"entities={len(window.entities)}")
|
||||
print(f"events={len(window.events)}")
|
||||
print(f"paths={len(paths)}")
|
||||
print(f"selected={len(selected)}")
|
||||
print(f"schema_gaps={list(window.schema_gaps)}")
|
||||
print(f"wrote={output_dir / 'idea_validation.md'}")
|
||||
print(f"wrote={output_dir / 'prompt.txt'}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
54
scripts/theia_preliminary.py
Normal file
54
scripts/theia_preliminary.py
Normal file
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Run THEIA E3 schema audit and debugging-only preliminary scan."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.theia import audit_theia_files, discover_theia_json_files, preliminary_scan_theia_files
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data-dir", default="data/raw/e3_theia_json")
|
||||
parser.add_argument("--output-dir", default="reports/theia_e3")
|
||||
parser.add_argument("--max-lines", type=int, default=250_000)
|
||||
parser.add_argument("--max-lines-per-file", type=int, default=None)
|
||||
parser.add_argument("--max-candidates", type=int, default=200)
|
||||
args = parser.parse_args()
|
||||
|
||||
data_dir = Path(args.data_dir)
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
files = discover_theia_json_files(data_dir)
|
||||
if not files:
|
||||
raise SystemExit(f"No THEIA JSON files found in {data_dir}")
|
||||
|
||||
profile = audit_theia_files(
|
||||
files,
|
||||
max_lines=args.max_lines,
|
||||
max_lines_per_file=args.max_lines_per_file,
|
||||
)
|
||||
scan = preliminary_scan_theia_files(
|
||||
files,
|
||||
max_lines=args.max_lines,
|
||||
max_lines_per_file=args.max_lines_per_file,
|
||||
max_candidates=args.max_candidates,
|
||||
)
|
||||
|
||||
(output_dir / "schema_profile.md").write_text(profile.to_markdown(), encoding="utf-8")
|
||||
(output_dir / "preliminary_candidates.md").write_text(scan.to_markdown(), encoding="utf-8")
|
||||
|
||||
print(f"files={len(files)}")
|
||||
print(f"schema_lines={profile.lines_seen}")
|
||||
print(f"scan_lines={scan.lines_seen}")
|
||||
print(f"candidates={len(scan.candidates)}")
|
||||
print(f"wrote={output_dir / 'schema_profile.md'}")
|
||||
print(f"wrote={output_dir / 'preliminary_candidates.md'}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
196
scripts/train_lora.py
Normal file
196
scripts/train_lora.py
Normal file
@@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python3
|
||||
"""LoRA fine-tune Qwen3-8B (or compatible) on ER-TP-DGP prompt batches.
|
||||
|
||||
Inputs:
|
||||
--prompt-batch-dir Directory produced by build_theia_prompt_batch.py.
|
||||
Expected files inside:
|
||||
- prompt_metadata.jsonl
|
||||
- prompt_text/<NNNN_targetid>.txt
|
||||
--labeled-targets Path to evaluation_batch.jsonl with `label` field.
|
||||
--train-until / --val-until Time-based split timestamps (paper-aligned
|
||||
anti-leakage protocol; see splits.time_based_split).
|
||||
|
||||
Outputs:
|
||||
--output-dir/lora_final PEFT adapter directory + tokenizer
|
||||
--output-dir/splits.json Train/val/test target ID lists
|
||||
--output-dir/leakage_audit.md splits.check_leakage report
|
||||
|
||||
Implements paper formula 13: CE on first generated token Yes/No, computed
|
||||
under the standard transformers Trainer with label_ids = -100 except at the
|
||||
target token position. Adapter loadable later via LocalHFLogitsProvider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
from er_tp_dgp.splits import TargetMetadata, check_leakage, time_based_split
|
||||
from er_tp_dgp.training import LoRAConfig, TrainConfig, TrainExample, train_lora
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__.split("\n", 1)[0])
|
||||
parser.add_argument("--prompt-batch-dir", required=True)
|
||||
parser.add_argument("--labeled-targets", required=True)
|
||||
parser.add_argument("--output-dir", default="reports/training/v1")
|
||||
parser.add_argument("--base-model", default="Qwen/Qwen3-8B")
|
||||
parser.add_argument("--epochs", type=int, default=3)
|
||||
parser.add_argument("--learning-rate", type=float, default=2e-4)
|
||||
parser.add_argument("--per-device-batch-size", type=int, default=2)
|
||||
parser.add_argument("--gradient-accumulation-steps", type=int, default=8)
|
||||
parser.add_argument("--max-seq-length", type=int, default=8192)
|
||||
parser.add_argument("--lora-r", type=int, default=16)
|
||||
parser.add_argument("--lora-alpha", type=int, default=32)
|
||||
parser.add_argument(
|
||||
"--train-until",
|
||||
type=float,
|
||||
required=True,
|
||||
help="Targets with timestamp <= train_until go to train split.",
|
||||
)
|
||||
parser.add_argument("--val-until", type=float, required=True)
|
||||
parser.add_argument("--seed", type=int, default=7)
|
||||
args = parser.parse_args()
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
examples, target_meta = _load_prompt_batch_examples(
|
||||
prompt_batch_dir=Path(args.prompt_batch_dir),
|
||||
labeled_targets=Path(args.labeled_targets),
|
||||
)
|
||||
if not examples:
|
||||
raise SystemExit("No prompt examples found; check --prompt-batch-dir/--labeled-targets.")
|
||||
|
||||
assignment = time_based_split(
|
||||
target_meta,
|
||||
train_until=args.train_until,
|
||||
validation_until=args.val_until,
|
||||
)
|
||||
leakage = check_leakage(target_meta, assignment)
|
||||
(output_dir / "leakage_audit.md").write_text(leakage.to_markdown(), encoding="utf-8")
|
||||
if not leakage.ok:
|
||||
# Don't abort; the audit file is the artifact. Operator decides.
|
||||
print(
|
||||
f"WARNING: leakage audit reported {len(leakage.findings)} findings; "
|
||||
f"see {output_dir/'leakage_audit.md'}"
|
||||
)
|
||||
|
||||
splits_payload: dict[str, list[str]] = {"train": [], "val": [], "test": []}
|
||||
train_examples: list[TrainExample] = []
|
||||
val_examples: list[TrainExample] = []
|
||||
for example, meta in zip(examples, target_meta, strict=True):
|
||||
split = assignment.split_by_target[meta.target_id].value
|
||||
if split == "train":
|
||||
train_examples.append(example)
|
||||
splits_payload["train"].append(meta.target_id)
|
||||
elif split == "validation":
|
||||
val_examples.append(example)
|
||||
splits_payload["val"].append(meta.target_id)
|
||||
else:
|
||||
splits_payload["test"].append(meta.target_id)
|
||||
(output_dir / "splits.json").write_text(
|
||||
json.dumps(splits_payload, ensure_ascii=False, sort_keys=True, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
print(
|
||||
f"train={len(train_examples)} val={len(val_examples)} "
|
||||
f"test={len(splits_payload['test'])}"
|
||||
)
|
||||
|
||||
train_cfg = TrainConfig(
|
||||
base_model=args.base_model,
|
||||
output_dir=output_dir,
|
||||
epochs=args.epochs,
|
||||
learning_rate=args.learning_rate,
|
||||
per_device_batch_size=args.per_device_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
max_seq_length=args.max_seq_length,
|
||||
seed=args.seed,
|
||||
)
|
||||
lora_cfg = LoRAConfig(r=args.lora_r, alpha=args.lora_alpha)
|
||||
final_dir = train_lora(train_examples, val_examples, train_config=train_cfg, lora_config=lora_cfg)
|
||||
|
||||
manifest = {
|
||||
"base_model": args.base_model,
|
||||
"lora_r": args.lora_r,
|
||||
"lora_alpha": args.lora_alpha,
|
||||
"epochs": args.epochs,
|
||||
"learning_rate": args.learning_rate,
|
||||
"train_until": args.train_until,
|
||||
"val_until": args.val_until,
|
||||
"train_size": len(train_examples),
|
||||
"val_size": len(val_examples),
|
||||
"test_size": len(splits_payload["test"]),
|
||||
"adapter_path": str(final_dir),
|
||||
"splits_path": str(output_dir / "splits.json"),
|
||||
"leakage_audit_path": str(output_dir / "leakage_audit.md"),
|
||||
}
|
||||
(output_dir / "train_manifest.json").write_text(
|
||||
json.dumps(manifest, ensure_ascii=False, sort_keys=True, indent=2), encoding="utf-8"
|
||||
)
|
||||
print(f"adapter saved to: {final_dir}")
|
||||
print(f"manifest: {output_dir/'train_manifest.json'}")
|
||||
return 0
|
||||
|
||||
|
||||
def _load_prompt_batch_examples(
|
||||
*, prompt_batch_dir: Path, labeled_targets: Path
|
||||
) -> tuple[list[TrainExample], list[TargetMetadata]]:
|
||||
"""Cross-reference prompt files with labeled_targets for supervised pairs."""
|
||||
metadata_path = prompt_batch_dir / "prompt_metadata.jsonl"
|
||||
if not metadata_path.exists():
|
||||
raise SystemExit(f"missing {metadata_path}")
|
||||
label_by_id: dict[str, dict] = {}
|
||||
for row in _read_jsonl(labeled_targets):
|
||||
label_by_id[row["target_id"]] = row
|
||||
|
||||
examples: list[TrainExample] = []
|
||||
metas: list[TargetMetadata] = []
|
||||
for row in _read_jsonl(metadata_path):
|
||||
target_id = row["target_id"]
|
||||
prompt_path = Path(row["prompt_path"])
|
||||
label_row = label_by_id.get(target_id)
|
||||
if not prompt_path.exists() or not label_row:
|
||||
continue
|
||||
label_value = label_row.get("label")
|
||||
if label_value not in {"malicious", "benign", "benign_proxy"}:
|
||||
continue
|
||||
prompt_text = prompt_path.read_text(encoding="utf-8")
|
||||
examples.append(
|
||||
TrainExample(
|
||||
prompt_text=prompt_text,
|
||||
label="Yes" if label_value == "malicious" else "No",
|
||||
)
|
||||
)
|
||||
metas.append(
|
||||
TargetMetadata(
|
||||
target_id=target_id,
|
||||
target_type=str(label_row.get("target_type", "PROCESS")),
|
||||
timestamp=float(row.get("anchor_timestamp") or label_row.get("anchor_timestamp") or 0.0),
|
||||
host=label_row.get("host"),
|
||||
campaign_id=label_row.get("atom_id"),
|
||||
prompt_text=prompt_text,
|
||||
raw_event_ids=tuple(row.get("evidence_path_ids") or ()),
|
||||
process_ids=(target_id,) if label_row.get("target_type") == "PROCESS" else (),
|
||||
file_paths=tuple([label_row["process_path"]] if label_row.get("process_path") else ()),
|
||||
)
|
||||
)
|
||||
return examples, metas
|
||||
|
||||
|
||||
def _read_jsonl(path: Path) -> list[dict]:
|
||||
rows: list[dict] = []
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
line = line.strip()
|
||||
if line:
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user