from __future__ import annotations

import csv
import sys
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Callable

import numpy as np

sys.path.insert(0, str(Path(__file__).resolve().parent))
from extract_lib import _canonical_key


@dataclass(frozen=True)
class CsvFlowAdapter:
    """Per-dataset description of how to read and label flow rows from a CSV.

    The dataclass is frozen, so attributes cannot be reassigned; note that the
    ``dict`` fields are still ordinary (mutable) dict objects.
    """

    # Logical column role -> CSV header name. parse_csv_rows reads the roles
    # 'src_ip', 'src_port', 'dst_ip', 'dst_port', 'protocol' and 'timestamp'.
    join_cols: dict[str, str]
    # Header name of the label column.
    label_col: str
    # strptime formats tried in order by parse_timestamp.
    timestamp_formats: tuple[str, ...]
    # Stripped raw labels that normalize_label maps to benign_token.
    benign_aliases: frozenset[str]
    # Normalized label used for benign rows (also what max_benign caps).
    benign_token: str = 'normal'
    # Rows whose raw label contains any of these substrings are skipped.
    drop_label_patterns: tuple[str, ...] = ()
    # Stripped raw label -> normalized label, consulted after the benign check.
    label_aliases: dict[str, str] = field(default_factory=dict)
    # When set, replaces the built-in normalization entirely.
    label_normalizer: Callable[[str], str] | None = None

    def normalize_label(self, raw: str) -> str:
        """Map a raw CSV label to its normalized class name.

        If ``label_normalizer`` is set, it is called on the unmodified raw
        string and its result is returned. Otherwise the label is stripped;
        members of ``benign_aliases`` become ``benign_token`` and anything
        else is looked up in ``label_aliases``, falling back to the stripped
        label itself.
        """
        if self.label_normalizer is not None:
            return self.label_normalizer(raw)
        s = raw.strip()
        if s in self.benign_aliases:
            return self.benign_token
        return self.label_aliases.get(s, s)

    def parse_timestamp(self, raw: str) -> float | None:
        """Parse *raw* into epoch seconds using the first matching format.

        Returns ``None`` for a blank cell or when no format in
        ``timestamp_formats`` matches.
        """
        s = raw.strip()
        if not s:
            return None
        for fmt in self.timestamp_formats:
            try:
                # NOTE(review): for formats without %z the parsed datetime is
                # naive, and .timestamp() interprets it in the host's local
                # timezone, so results depend on the machine's TZ setting.
                return datetime.strptime(s, fmt).timestamp()
            except ValueError:
                continue
        return None


def _int_field(raw: str) -> int:
    """Parse a numeric CSV cell as an int; a blank cell yields 0.

    Float spellings such as '80.0' are accepted. Raises ValueError for
    non-numeric text or NaN and OverflowError for infinities.
    """
    s = raw.strip()
    return int(float(s)) if s else 0


def parse_csv_rows(
    *,
    csv_path: Path,
    row_idx_start: int,
    time_offset_seconds: float,
    adapter: CsvFlowAdapter,
    max_per_class: int | None = None,
    max_benign: int | None = None,
    rng: np.random.Generator | None = None,
) -> tuple[dict[tuple, list[tuple[int, float]]], list[str], int, int, dict[str, int]]:
    """Read flow rows from *csv_path*, optionally subsample by class, and index them.

    Args:
        csv_path: CSV file with a header row.
        row_idx_start: index assigned to the first kept row; later kept rows
            are numbered consecutively.
        time_offset_seconds: added to every successfully parsed timestamp.
        adapter: column mapping and label/timestamp parsing rules.
        max_per_class: cap on kept rows for each non-benign label (None = no cap).
        max_benign: cap on kept rows whose label is ``adapter.benign_token``
            (None = no cap).
        rng: generator used for subsampling; when a cap is given and this is
            None, ``np.random.default_rng(42)`` is used.

    Returns:
        ``(rows_by_key, labels_out, n_kept, n_skip, class_counts)``:
        rows_by_key maps each canonical flow key to its ``(row_idx, ts_epoch)``
        pairs (``ts_epoch`` is NaN when the timestamp did not parse);
        labels_out holds the normalized label of each kept row in row_idx
        order; n_kept is the number of kept rows; n_skip counts malformed or
        pattern-dropped rows (blank lines are ignored, not counted); and
        class_counts is the number of kept rows per label.

    Raises:
        KeyError: if a required column is missing from the header. An empty
            file has no header and raises this as well.
    """
    if (max_per_class is not None or max_benign is not None) and rng is None:
        rng = np.random.default_rng(42)

    parsed: list[tuple[tuple, float, str]] = []
    n_skip = 0
    drop_patterns = adapter.drop_label_patterns
    with open(csv_path, 'r', newline='') as f:
        reader = csv.reader(f)
        # An empty file yields no header; treating it as an empty header
        # routes it through the "missing column" error below rather than
        # leaking a bare StopIteration.
        header_row = next(reader, [])
        # Drop a UTF-8 BOM glued to the first cell; str.strip() keeps it.
        header = [h.lstrip('\ufeff').strip() for h in header_row]
        h2i = {h: i for i, h in enumerate(header)}
        for col in [*adapter.join_cols.values(), adapter.label_col]:
            if col not in h2i:
                raise KeyError(f'{csv_path.name}: missing column {col!r}')

        i_src_ip = h2i[adapter.join_cols['src_ip']]
        i_src_port = h2i[adapter.join_cols['src_port']]
        i_dst_ip = h2i[adapter.join_cols['dst_ip']]
        i_dst_port = h2i[adapter.join_cols['dst_port']]
        i_proto = h2i[adapter.join_cols['protocol']]
        i_ts = h2i[adapter.join_cols['timestamp']]
        i_label = h2i[adapter.label_col]

        for row in reader:
            if not row:
                continue
            try:
                raw_label = row[i_label]
            except IndexError:
                n_skip += 1
                continue
            if any(pat in raw_label for pat in drop_patterns):
                n_skip += 1
                continue
            # Every cell read for the row lives inside the try so truncated
            # rows are skipped and counted instead of aborting the whole file.
            try:
                sp = _int_field(row[i_src_port])
                dp = _int_field(row[i_dst_port])
                proto = _int_field(row[i_proto])
                sip = row[i_src_ip].strip()
                dip = row[i_dst_ip].strip()
                raw_ts = row[i_ts]
            except (ValueError, IndexError, OverflowError):
                n_skip += 1
                continue
            ck = _canonical_key(sip, dip, sp, dp, proto)
            ts_parsed = adapter.parse_timestamp(raw_ts)
            ts_epoch = float('nan') if ts_parsed is None else ts_parsed + time_offset_seconds
            parsed.append((ck, ts_epoch, adapter.normalize_label(raw_label)))

    keep_idx = _select_indices(
        labels=[p[2] for p in parsed],
        benign_token=adapter.benign_token,
        max_per_class=max_per_class,
        max_benign=max_benign,
        rng=rng,
    )

    rows_by_key: dict[tuple, list[tuple[int, float]]] = {}
    labels_out: list[str] = []
    class_counts: dict[str, int] = {}
    for row_idx, i in enumerate(keep_idx, start=row_idx_start):
        ck, ts_epoch, label = parsed[i]
        rows_by_key.setdefault(ck, []).append((row_idx, ts_epoch))
        labels_out.append(label)
        class_counts[label] = class_counts.get(label, 0) + 1
    return rows_by_key, labels_out, len(keep_idx), n_skip, class_counts


def _select_indices(
    *,
    labels: list[str],
    benign_token: str,
    max_per_class: int | None,
    max_benign: int | None,
    rng: np.random.Generator | None,
) -> list[int]:
    """Return ascending indices into *labels* after per-class random capping.

    Labels equal to *benign_token* are capped by *max_benign*; every other
    label is capped by *max_per_class*. A None cap leaves that group intact.
    Groups larger than their cap are sampled without replacement using *rng*.
    """
    if max_per_class is None and max_benign is None:
        return list(range(len(labels)))
    if rng is None:
        # Same fixed seed parse_csv_rows falls back to; an assert here would
        # vanish under `python -O`.
        rng = np.random.default_rng(42)

    buckets: dict[str, list[int]] = {}
    for i, label in enumerate(labels):
        buckets.setdefault(label, []).append(i)

    keep: list[int] = []
    for label, idxs in buckets.items():
        cap = max_benign if label == benign_token else max_per_class
        if cap is not None and len(idxs) > cap:
            pick = rng.choice(len(idxs), size=cap, replace=False)
            idxs = [idxs[j] for j in pick]
        keep.extend(idxs)
    # A single final sort restores original file order across all classes.
    keep.sort()
    return keep