50 changes: 31 additions & 19 deletions src/detector/detector.py
@@ -61,7 +61,7 @@ class DetectorAbstractBase(ABC): # pragma: no cover
"""

@abstractmethod
def __init__(self, detector_config, consume_topic, produce_topics) -> None:
def __init__(self, detector_config, consume_topic, produce_topics=None) -> None:
pass

@abstractmethod
@@ -89,7 +89,7 @@ class DetectorBase(DetectorAbstractBase):
that provide model-specific prediction logic.
"""

def __init__(self, detector_config, consume_topic, produce_topics) -> None:
def __init__(self, detector_config, consume_topic, produce_topics=None) -> None:
"""
Initialize the detector with configuration and Kafka topic settings.

@@ -103,27 +103,33 @@ def __init__(self, detector_config, consume_topic, produce_topics) -> None:
"""

self.name = detector_config["name"]
self.model = detector_config["model"]
self.model_name = detector_config["model"]
self.model = self.model_name
self.checksum = detector_config["checksum"]
self.threshold = detector_config["threshold"]

self.consume_topic = consume_topic
self.produce_topics = produce_topics
if produce_topics is None:
self.produce_topics = [f"{PRODUCE_TOPIC_PREFIX}-generic"]
elif isinstance(produce_topics, str):
self.produce_topics = [produce_topics]
else:
self.produce_topics = produce_topics
self.suspicious_batch_id = None
self.key = None
self.messages = []
self.warnings = []
self.begin_timestamp = None
self.end_timestamp = None
self.model_path = os.path.join(
tempfile.gettempdir(), f"{self.model}_{self.checksum}_model.pickle"
tempfile.gettempdir(), f"{self.model_name}_{self.checksum}_model.pickle"
)
self.scaler_path = os.path.join(
tempfile.gettempdir(), f"{self.model}_{self.checksum}_scaler.pickle"
tempfile.gettempdir(), f"{self.model_name}_{self.checksum}_scaler.pickle"
)

self.kafka_consume_handler = ExactlyOnceKafkaConsumeHandler(self.consume_topic)
self.kafka_produce_handler = ExactlyOnceKafkaProduceHandler()
self.kafka_produce_handler = None

self.model, self.scaler = self._get_model()

@@ -261,23 +267,26 @@ def _get_model(self):
WrongChecksum: If the downloaded model's checksum doesn't match the expected value.
requests.HTTPError: If there's an error downloading the model.
"""
logger.info(f"Get model: {self.model} with checksum {self.checksum}")
# if not os.path.isfile(self.model_path):
model_download_url = self.get_model_download_url()
logger.info(
f"downloading model {self.model} from {model_download_url} with checksum {self.checksum}"
)
response = requests.get(model_download_url)
response.raise_for_status()
with open(self.model_path, "wb") as f:
f.write(response.content)
# Handle optional scaler
logger.info(f"Get model: {self.model_name} with checksum {self.checksum}")
scaler_download_url = self.get_scaler_download_url()
if scaler_download_url:

if not os.path.isfile(self.model_path):
model_download_url = self.get_model_download_url()
logger.info(
f"downloading model {self.model_name} from {model_download_url} with checksum {self.checksum}"
)
response = requests.get(model_download_url)
response.raise_for_status()
with open(self.model_path, "wb") as f:
f.write(response.content)

if scaler_download_url and not os.path.isfile(self.scaler_path):
scaler_response = requests.get(scaler_download_url)
scaler_response.raise_for_status()
with open(self.scaler_path, "wb") as f:
f.write(scaler_response.content)

if scaler_download_url:
with open(self.scaler_path, "rb") as input_file:
scaler = pickle.load(input_file)
else:
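
(Reviewer note: the hunk above also changes caching behavior. Artifacts are cached in the temp directory keyed by name and checksum, and the HTTP fetch is skipped when a matching file already exists. A sketch under those assumptions; `fetch_artifact` and its signature are illustrative, not part of the codebase.)

```python
import os
import tempfile

import requests


def fetch_artifact(name: str, checksum: str, url: str, suffix: str = "model") -> str:
    """Download a pickle once; later calls with the same name/checksum reuse it."""
    path = os.path.join(tempfile.gettempdir(), f"{name}_{checksum}_{suffix}.pickle")
    if not os.path.isfile(path):  # cache miss: fetch and persist
        response = requests.get(url)
        response.raise_for_status()
        with open(path, "wb") as f:
            f.write(response.content)
    return path
```

Because the checksum is baked into the file name, bumping the configured checksum yields a new path and forces a fresh download instead of silently reusing a stale cache.
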
@@ -368,6 +377,9 @@ def send_warning(self) -> None:

logger.info(f"Producing alert to Kafka: {alert}")

if self.kafka_produce_handler is None:
self.kafka_produce_handler = ExactlyOnceKafkaProduceHandler()

for topic in self.produce_topics:
self.kafka_produce_handler.produce(
topic=topic,
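
(Reviewer note: the producer is now built lazily. `__init__` leaves `kafka_produce_handler` as `None` and `send_warning()` constructs it on first use, so constructing a detector no longer needs a reachable broker and tests can inject a mock first. A self-contained sketch with a stand-in handler class:)

```python
class ExactlyOnceKafkaProduceHandler:  # stand-in for the real handler
    def produce(self, topic: str, data: str, key: str) -> None:
        print(f"-> {topic} key={key} data={data}")


class Detector:
    def __init__(self) -> None:
        self.kafka_produce_handler = None  # no broker connection yet

    def send_warning(self, topic: str, alert: str, key: str) -> None:
        if self.kafka_produce_handler is None:  # first use: build the producer
            self.kafka_produce_handler = ExactlyOnceKafkaProduceHandler()
        self.kafka_produce_handler.produce(topic=topic, data=alert, key=key)


detector = Detector()  # in a test: detector.kafka_produce_handler = MagicMock()
detector.send_warning("alerts-generic", '{"overall_score": 0.5}', "192.168.1.1")
```
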
4 changes: 2 additions & 2 deletions src/detector/plugins/dga_detector.py
@@ -19,7 +19,7 @@ class DGADetector(DetectorBase):
to make predictions about whether a domain is likely generated by a DGA.
"""

def __init__(self, detector_config, consume_topic, produce_topics):
def __init__(self, detector_config, consume_topic, produce_topics=None):
"""
Initialize the DGA detector with configuration parameters.

@@ -49,7 +49,7 @@ def get_model_download_url(self):
if self.model_base_url[-1] == "/"
else self.model_base_url
)
return f"{self.model_base_url}/files/?p=%2F{self.model}%2F{self.checksum}%2F{self.model}.pickle&dl=1"
return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2F{self.model_name}.pickle&dl=1"

def get_scaler_download_url(self):
"""
7 changes: 3 additions & 4 deletions src/detector/plugins/domainator_detector.py
@@ -25,7 +25,7 @@ class DomainatorDetector(DetectorBase):
to make predictions about whether a query is likely malicious.
"""

def __init__(self, detector_config, consume_topic, produce_topics):
def __init__(self, detector_config, consume_topic, produce_topics=None):
"""
Initialize the Domainator detector with configuration parameters.

@@ -56,7 +56,7 @@ def get_model_download_url(self):
if self.model_base_url[-1] == "/"
else self.model_base_url
)
return f"{self.model_base_url}/files/?p=%2F{self.model}%2F{self.checksum}%2F{self.model}.pickle&dl=1"
return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2F{self.model_name}.pickle&dl=1"

def get_scaler_download_url(self):
"""
@@ -73,7 +73,7 @@ def get_scaler_download_url(self):
if self.model_base_url[-1] == "/"
else self.model_base_url
)
return f"{self.model_base_url}/files/?p=%2F{self.model}%2F{self.checksum}%2Fscaler.pickle&dl=1"
return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2Fscaler.pickle&dl=1"

def predict(self, messages):
"""
@@ -94,7 +94,6 @@ def predict(self, messages):
queries = [message["domain_name"] for message in messages]

y_pred = self.model.predict_proba(self._get_features(queries))
print(f"Prediction: {y_pred}")
return y_pred

def detect(self):
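
(Reviewer note: both plugins derive their download URLs from the renamed `model_name` field in the same way. A sketch with a placeholder host; `%2F` is a URL-encoded `/`, so the repository path travels as a single query parameter.)

```python
def model_download_url(base_url: str, model_name: str, checksum: str) -> str:
    # Strip a trailing slash, as both plugins do before formatting.
    base_url = base_url[:-1] if base_url.endswith("/") else base_url
    return (
        f"{base_url}/files/?p=%2F{model_name}%2F{checksum}"
        f"%2F{model_name}.pickle&dl=1"
    )


print(model_download_url("https://models.example.org/", "dga", "abc123"))
# https://models.example.org/files/?p=%2Fdga%2Fabc123%2Fdga.pickle&dl=1
```
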
106 changes: 72 additions & 34 deletions tests/detector/test_detector.py
@@ -1,5 +1,4 @@
import os
import tempfile
import json
import unittest
import uuid
from datetime import datetime, timedelta
@@ -28,15 +27,22 @@ class TestDetector(DetectorBase):
Test class that does not take any action, so as not to dilute the tests
"""

def __init__(self, detector_config, consume_topic) -> None:
def __init__(self, detector_config, consume_topic, produce_topics=None) -> None:
self.model_base_url = detector_config["base_url"]
super().__init__(consume_topic=consume_topic, detector_config=detector_config)
self.model_name = detector_config["model"]
super().__init__(
consume_topic=consume_topic,
detector_config=detector_config,
produce_topics=(
produce_topics if produce_topics is not None else ["test_produce_topic"]
),
)

def get_model_download_url(self):
return f"{self.model_base_url}/files/?p=%2F{self.model}%2F{self.checksum}%2F{self.model}.pickle&dl=1"
return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2F{self.model_name}.pickle&dl=1"

def get_scaler_download_url(self):
return f"{self.model_base_url}/files/?p=%2F{self.model}%2F{self.checksum}%2Fscaler.pickle&dl=1"
return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2Fscaler.pickle&dl=1"

def predict(self, message):
pass
@@ -96,7 +102,11 @@ def setUp(self):

@patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler")
@patch("src.detector.detector.ClickHouseKafkaSender")
def test_get_model(self, mock_clickhouse, mock_kafka_consume_handler):
@patch("src.detector.detector.DetectorBase._get_model")
def test_get_model(
self, mock_get_model, mock_clickhouse, mock_kafka_consume_handler
):
mock_get_model.return_value = (MagicMock(), MagicMock())
mock_kafka_consume_handler_instance = MagicMock()
mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance

@@ -106,9 +116,11 @@ def test_get_model(self, mock_clickhouse, mock_kafka_consume_handler):

@patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler")
@patch("src.detector.detector.ClickHouseKafkaSender")
@patch("src.detector.detector.DetectorBase._get_model")
def test_get_model_wrong_checksum(
self, mock_clickhouse, mock_kafka_consume_handler
self, mock_get_model, mock_clickhouse, mock_kafka_consume_handler
):
mock_get_model.side_effect = WrongChecksum("invalid checksum")
mock_kafka_consume_handler_instance = MagicMock()
mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance
detector_config = MINIMAL_DETECTOR_CONFIG.copy()
@@ -276,14 +288,22 @@ def test_save_warning(
},
]
sut.parent_row_id = f"{uuid.uuid4()}-{uuid.uuid4()}"
sut.key = "192.168.1.1"
sut.suspicious_batch_id = uuid.uuid4()
sut.messages = [{"logline_id": "test_id"}]
open_mock = mock_open()
with patch("src.detector.detector.open", open_mock, create=True):
sut.send_warning()
sut.kafka_produce_handler = MagicMock()
sut.send_warning()

open_mock.assert_called_with(
os.path.join(tempfile.gettempdir(), "warnings.json"), "a+"
)
sut.kafka_produce_handler.produce.assert_called_once()
produce_call = sut.kafka_produce_handler.produce.call_args.kwargs
self.assertEqual("test_produce_topic", produce_call["topic"])
self.assertEqual("192.168.1.1", produce_call["key"])

alert = json.loads(produce_call["data"])
self.assertAlmostEqual(0.50019, alert["overall_score"])
self.assertEqual("test-detector", alert["detector_name"])
self.assertEqual("192.168.1.1", alert["src_ip"])
self.assertEqual(2, len(alert["result"]))

@patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler")
@patch("src.detector.detector.ClickHouseKafkaSender")
@@ -301,11 +321,10 @@ def test_save_empty_warning(
sut.parent_row_id = f"{uuid.uuid4()}-{uuid.uuid4()}"
sut.warnings = []
sut.messages = [{"logline_id": "test_id"}]
open_mock = mock_open()
with patch("src.detector.detector.open", open_mock, create=True):
sut.send_warning()
sut.kafka_produce_handler = MagicMock()
sut.send_warning()

open_mock.assert_not_called()
sut.kafka_produce_handler.produce.assert_not_called()

# @patch(
# "src.detector.detector.CHECKSUM",
@@ -329,6 +348,7 @@ def test_save_warning_error(
sut = TestDetector(
consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
)
sut.kafka_produce_handler = MagicMock()
sut.warnings = [
{
"request": "request.de",
@@ -441,9 +461,13 @@ def test_get_model_downloads_and_validates(
mock_kafka_consume_handler_instance = MagicMock()
mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance

sut = TestDetector(
consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
)
with patch(
"src.detector.detector.DetectorBase._get_model",
return_value=(MagicMock(), MagicMock()),
):
sut = TestDetector(
consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
)

# Mock file operations
with patch("src.detector.detector.os.path.isfile", return_value=False), patch(
Expand All @@ -462,10 +486,10 @@ def test_get_model_downloads_and_validates(
self.assertEqual(model, ("mock_model_or_scaler", "mock_model_or_scaler"))
# Verify logger messages
self.mock_logger.info.assert_any_call(
f"Get model: {sut.model} with checksum {sut.checksum}"
f"Get model: {sut.model_name} with checksum {sut.checksum}"
)
self.mock_logger.info.assert_any_call(
f"downloading model {sut.model} from {sut.get_model_download_url()} with checksum {sut.checksum}"
f"downloading model {sut.model_name} from {sut.get_model_download_url()} with checksum {sut.checksum}"
)

@patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler")
Expand All @@ -477,9 +501,13 @@ def test_get_model_uses_existing_file(
mock_kafka_consume_handler_instance = MagicMock()
mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance

sut = TestDetector(
consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
)
with patch(
"src.detector.detector.DetectorBase._get_model",
return_value=(MagicMock(), MagicMock()),
):
sut = TestDetector(
consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
)

# Mock file operations
with patch("src.detector.detector.os.path.isfile", return_value=True), patch(
Expand All @@ -502,7 +530,7 @@ def test_get_model_uses_existing_file(
)
# Verify logger messages
self.mock_logger.info.assert_any_call(
f"Get model: {sut.model} with checksum {sut.checksum}"
f"Get model: {sut.model_name} with checksum {sut.checksum}"
)

@patch("src.detector.detector.requests.get")
Expand All @@ -521,13 +549,19 @@ def test_get_model_raises_wrong_checksum(
mock_kafka_consume_handler_instance = MagicMock()
mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance

sut = TestDetector(
consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
)
with patch(
"src.detector.detector.DetectorBase._get_model",
return_value=(MagicMock(), MagicMock()),
):
sut = TestDetector(
consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
)

# Mock file operations with wrong checksum
with patch("src.detector.detector.os.path.isfile", return_value=False), patch(
"src.detector.detector.open", mock_open()
), patch(
"src.detector.detector.pickle.load", return_value="mock_model_or_scaler"
), patch.object(sut, "_sha256sum", return_value="wrong_checksum_value"):

with self.assertRaises(WrongChecksum) as context:
@@ -553,9 +587,13 @@ def test_get_model_handles_http_error(
mock_kafka_consume_handler_instance = MagicMock()
mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance

sut = TestDetector(
consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
)
with patch(
"src.detector.detector.DetectorBase._get_model",
return_value=(MagicMock(), MagicMock()),
):
sut = TestDetector(
consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
)

# Mock file operations
with patch("src.detector.detector.os.path.isfile", return_value=False), patch(
Expand All @@ -567,7 +605,7 @@ def test_get_model_handles_http_error(

# Verify logger info was called
self.mock_logger.info.assert_any_call(
f"Get model: {sut.model} with checksum {sut.checksum}"
f"Get model: {sut.model_name} with checksum {sut.checksum}"
)


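
(Reviewer note: the recurring pattern in these test changes is to patch `DetectorBase._get_model` before constructing the detector, so `__init__` never touches the network. A condensed sketch; `TestDetector` and `MINIMAL_DETECTOR_CONFIG` are the names already defined in `tests/detector/test_detector.py`.)

```python
from unittest.mock import MagicMock, patch

# TestDetector and MINIMAL_DETECTOR_CONFIG come from tests/detector/test_detector.py.
with patch(
    "src.detector.detector.DetectorBase._get_model",
    return_value=(MagicMock(), MagicMock()),  # (model, scaler)
):
    sut = TestDetector(
        consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
    )

# After construction, sut.model and sut.scaler are the injected mocks, so
# predict() and send_warning() can be exercised without any download.
```
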
4 changes: 3 additions & 1 deletion tests/detector/test_dga_detector.py
@@ -51,7 +51,9 @@ def _create_detector(self, mock_kafka_handler=None, mock_clickhouse=None):
DGADetector, "_get_model", return_value=(MagicMock(), MagicMock())
):

detector = DGADetector(detector_config, "test_topic")
detector = DGADetector(
detector_config, "test_topic", ["test_produce_topic"]
)
detector.model = MagicMock()
detector.scaler = MagicMock()
return detector