115 lines
4.9 KiB
Python
115 lines
4.9 KiB
Python
from __future__ import annotations
|
|
import csv
|
|
import sys
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Callable
|
|
import numpy as np
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
|
from extract_lib import _canonical_key
|
|
|
|
@dataclass(frozen=True)
class CsvFlowAdapter:
    """Describe how to pull flow records and labels out of one CSV dialect.

    ``join_cols`` maps the logical names ``'src_ip'``, ``'src_port'``,
    ``'dst_ip'``, ``'dst_port'``, ``'protocol'`` and ``'timestamp'`` to the
    header names used by a particular CSV. ``parse_csv_rows`` looks up exactly
    those keys. The remaining fields control label normalization and
    timestamp parsing.
    """

    # Logical column name -> CSV header name. Dicts are unhashable, so this
    # field is excluded from the generated __hash__. Otherwise hashing a frozen
    # instance raises "TypeError: unhashable type: 'dict'".
    join_cols: dict[str, str] = field(hash=False)
    # Header of the column holding the raw class label.
    label_col: str
    # strptime formats tried in order by parse_timestamp; the first match wins.
    timestamp_formats: tuple[str, ...]
    # Stripped raw labels that all collapse to ``benign_token``.
    benign_aliases: frozenset[str]
    # Canonical label returned for benign rows.
    benign_token: str = 'normal'
    # Substrings; parse_csv_rows skips rows whose raw label contains any of them.
    drop_label_patterns: tuple[str, ...] = ()
    # Extra stripped-raw-label -> canonical-label renames (applied after the
    # benign check). Excluded from __hash__ for the same reason as join_cols.
    label_aliases: dict[str, str] = field(default_factory=dict, hash=False)
    # Optional full replacement for normalize_label's built-in logic.
    label_normalizer: Callable[[str], str] | None = None

    def normalize_label(self, raw: str) -> str:
        """Map a raw CSV label to its canonical class name.

        If ``label_normalizer`` is set, it is called with ``raw`` exactly as
        given (not stripped) and its result is returned unchanged. Otherwise
        the label is stripped. Any member of ``benign_aliases`` becomes
        ``benign_token``, and anything else is looked up in ``label_aliases``.
        Unknown labels are returned stripped but otherwise as-is.
        """
        if self.label_normalizer is not None:
            return self.label_normalizer(raw)
        s = raw.strip()
        if s in self.benign_aliases:
            return self.benign_token
        return self.label_aliases.get(s, s)

    def parse_timestamp(self, raw: str) -> float | None:
        """Parse ``raw`` into POSIX seconds using ``timestamp_formats``.

        Formats are tried in order and the first that parses wins. Returns
        ``None`` for blank input or when no format matches.

        Note: formats without ``%z`` produce naive datetimes.
        ``datetime.timestamp()`` interprets those in the host's local timezone,
        so their results depend on the machine's TZ setting.
        """
        s = raw.strip()
        if not s:
            return None
        for fmt in self.timestamp_formats:
            try:
                return datetime.strptime(s, fmt).timestamp()
            except ValueError:
                continue
        return None
|
|
|
|
def parse_csv_rows(
    *,
    csv_path: Path,
    row_idx_start: int,
    time_offset_seconds: float,
    adapter: CsvFlowAdapter,
    max_per_class: int | None = None,
    max_benign: int | None = None,
    rng: np.random.Generator | None = None,
    encoding: str | None = None,
) -> tuple[dict[tuple, list[tuple[int, float]]], list[str], int, int, dict[str, int]]:
    """Read one flow CSV, optionally subsample it per class, and index it by flow key.

    Each usable row becomes ``(canonical key, epoch timestamp, normalized
    label)``. The key comes from ``_canonical_key(src_ip, dst_ip, src_port,
    dst_port, protocol)``. Blank lines are ignored without being counted.

    These rows are counted as skipped:

    * rows shorter than the highest referenced column;
    * rows whose raw label contains one of ``adapter.drop_label_patterns``;
    * rows whose port/protocol cells are not finite numbers.

    Args:
        csv_path: CSV file whose first row is the header.
        row_idx_start: index given to the first kept row. Later kept rows get
            consecutive indices.
        time_offset_seconds: added to every successfully parsed timestamp.
        adapter: column mapping and label/timestamp rules for this CSV.
        max_per_class: cap on kept rows per non-benign label (``None`` = no cap).
        max_benign: cap on kept rows labelled ``adapter.benign_token``.
        rng: generator used for subsampling. If a cap is set and this is
            ``None``, ``np.random.default_rng(42)`` is used.
        encoding: passed to ``open``. ``None`` uses the platform default.

    Returns:
        ``(rows_by_key, labels, n_kept, n_skipped, class_counts)``, where:

        * ``rows_by_key`` maps each canonical key to ``(row_idx, ts_epoch)``
          pairs. ``ts_epoch`` is NaN when no timestamp format matched.
        * ``labels[k]`` is the label of row index ``row_idx_start + k``.
        * ``class_counts`` counts kept rows per label.

    Raises:
        ValueError: the file has no header row (it is empty).
        KeyError: a required column is absent from the header.
    """
    if (max_per_class is not None or max_benign is not None) and rng is None:
        rng = np.random.default_rng(42)

    cols = adapter.join_cols
    parsed: list[tuple[tuple, float, str]] = []
    n_skip = 0
    with open(csv_path, 'r', newline='', encoding=encoding) as f:
        reader = csv.reader(f)
        try:
            first_row = next(reader)
        except StopIteration:
            # Don't let a bare StopIteration escape to the caller.
            raise ValueError(f'{csv_path.name}: empty CSV (no header row)') from None
        header = [h.strip() for h in first_row]
        if header:
            # str.strip() does not remove a UTF-8 BOM, which would otherwise
            # hide the first column's name.
            header[0] = header[0].lstrip('\ufeff').strip()
        h2i = {h: i for i, h in enumerate(header)}
        for col in [*cols.values(), adapter.label_col]:
            if col not in h2i:
                raise KeyError(f'{csv_path.name}: missing column {col!r}')

        i_src_ip = h2i[cols['src_ip']]
        i_src_port = h2i[cols['src_port']]
        i_dst_ip = h2i[cols['dst_ip']]
        i_dst_port = h2i[cols['dst_port']]
        i_proto = h2i[cols['protocol']]
        i_ts = h2i[cols['timestamp']]
        i_label = h2i[adapter.label_col]
        # A row shorter than this is truncated. Skip it up front rather than
        # letting an IndexError escape from the IP/timestamp lookups.
        min_len = max(i_src_ip, i_src_port, i_dst_ip, i_dst_port, i_proto, i_ts, i_label) + 1

        for row in reader:
            if not row:
                continue
            if len(row) < min_len:
                n_skip += 1
                continue
            raw_label = row[i_label]
            if any(pat in raw_label for pat in adapter.drop_label_patterns):
                n_skip += 1
                continue
            try:
                sp = _parse_int_cell(row[i_src_port])
                dp = _parse_int_cell(row[i_dst_port])
                proto = _parse_int_cell(row[i_proto])
            except (ValueError, OverflowError):
                # ValueError: non-numeric or NaN; OverflowError: +/-Infinity.
                n_skip += 1
                continue
            ck = _canonical_key(row[i_src_ip].strip(), row[i_dst_ip].strip(), sp, dp, proto)
            ts_parsed = adapter.parse_timestamp(row[i_ts])
            ts_epoch = float('nan') if ts_parsed is None else ts_parsed + time_offset_seconds
            parsed.append((ck, ts_epoch, adapter.normalize_label(raw_label)))

    keep_idx = _select_indices(
        labels=[label for _, _, label in parsed],
        benign_token=adapter.benign_token,
        max_per_class=max_per_class,
        max_benign=max_benign,
        rng=rng,
    )

    rows_by_key: dict[tuple, list[tuple[int, float]]] = {}
    labels_out: list[str] = []
    class_counts: dict[str, int] = {}
    # Kept rows are numbered consecutively from row_idx_start, in the order
    # _select_indices returns them.
    for row_idx, i in enumerate(keep_idx, start=row_idx_start):
        ck, ts_epoch, label = parsed[i]
        rows_by_key.setdefault(ck, []).append((row_idx, ts_epoch))
        labels_out.append(label)
        class_counts[label] = class_counts.get(label, 0) + 1

    return rows_by_key, labels_out, len(labels_out), n_skip, class_counts


def _parse_int_cell(cell: str) -> int:
    """Parse a numeric CSV cell such as ``'80'`` or ``'6.0'`` to int; a blank cell means 0."""
    return int(float(cell)) if cell.strip() else 0
|
|
|
|
def _select_indices(*, labels: list[str], benign_token: str, max_per_class: int | None, max_benign: int | None, rng: np.random.Generator | None) -> list[int]:
    """Return the ascending indices of ``labels`` to keep after per-class capping.

    Indices are grouped by label. The group whose label equals
    ``benign_token`` is capped at ``max_benign``; every other group is capped
    at ``max_per_class`` (``None`` means no cap). A group larger than its cap
    is sampled without replacement using ``rng``. Groups are visited in
    first-appearance order, so a seeded ``rng`` yields a repeatable selection.

    Raises:
        ValueError: a cap is set but ``rng`` is ``None``.
    """
    if max_per_class is None and max_benign is None:
        return list(range(len(labels)))
    if rng is None:
        # Explicit check rather than `assert`, which `python -O` strips.
        raise ValueError('rng is required when max_per_class or max_benign is set')

    buckets: dict[str, list[int]] = {}
    for i, label in enumerate(labels):
        buckets.setdefault(label, []).append(i)

    keep: list[int] = []
    for label, idxs in buckets.items():
        cap = max_benign if label == benign_token else max_per_class
        if cap is not None and len(idxs) > cap:
            pick = rng.choice(len(idxs), size=cap, replace=False)
            idxs = [idxs[j] for j in pick]
        keep.extend(idxs)
    # One final sort restores file order across all groups.
    keep.sort()
    return keep
|