From caf1a672cd9be2b57752a29f5029154c4ed8d00c Mon Sep 17 00:00:00 2001 From: maldwg Date: Mon, 13 Apr 2026 16:42:38 +0200 Subject: [PATCH] Fix tests --- src/detector/detector.py | 50 +++++---- src/detector/plugins/dga_detector.py | 4 +- src/detector/plugins/domainator_detector.py | 7 +- tests/detector/test_detector.py | 106 +++++++++++++------- tests/detector/test_dga_detector.py | 4 +- tests/detector/test_domainator_detector.py | 10 +- tests/logcollector/test_batch_handler.py | 32 +++++- tests/logcollector/test_collector.py | 10 +- tests/logserver/test_server.py | 10 +- 9 files changed, 161 insertions(+), 72 deletions(-) diff --git a/src/detector/detector.py b/src/detector/detector.py index 00b87e31..b03b057d 100644 --- a/src/detector/detector.py +++ b/src/detector/detector.py @@ -61,7 +61,7 @@ class DetectorAbstractBase(ABC): # pragma: no cover """ @abstractmethod - def __init__(self, detector_config, consume_topic, produce_topics) -> None: + def __init__(self, detector_config, consume_topic, produce_topics=None) -> None: pass @abstractmethod @@ -89,7 +89,7 @@ class DetectorBase(DetectorAbstractBase): that provide model-specific prediction logic. """ - def __init__(self, detector_config, consume_topic, produce_topics) -> None: + def __init__(self, detector_config, consume_topic, produce_topics=None) -> None: """ Initialize the detector with configuration and Kafka topic settings. @@ -103,12 +103,18 @@ def __init__(self, detector_config, consume_topic, produce_topics) -> None: """ self.name = detector_config["name"] - self.model = detector_config["model"] + self.model_name = detector_config["model"] + self.model = self.model_name self.checksum = detector_config["checksum"] self.threshold = detector_config["threshold"] self.consume_topic = consume_topic - self.produce_topics = produce_topics + if produce_topics is None: + self.produce_topics = [f"{PRODUCE_TOPIC_PREFIX}-generic"] + elif isinstance(produce_topics, str): + self.produce_topics = [produce_topics] + else: + self.produce_topics = produce_topics self.suspicious_batch_id = None self.key = None self.messages = [] @@ -116,14 +122,14 @@ def __init__(self, detector_config, consume_topic, produce_topics) -> None: self.begin_timestamp = None self.end_timestamp = None self.model_path = os.path.join( - tempfile.gettempdir(), f"{self.model}_{self.checksum}_model.pickle" + tempfile.gettempdir(), f"{self.model_name}_{self.checksum}_model.pickle" ) self.scaler_path = os.path.join( - tempfile.gettempdir(), f"{self.model}_{self.checksum}_scaler.pickle" + tempfile.gettempdir(), f"{self.model_name}_{self.checksum}_scaler.pickle" ) self.kafka_consume_handler = ExactlyOnceKafkaConsumeHandler(self.consume_topic) - self.kafka_produce_handler = ExactlyOnceKafkaProduceHandler() + self.kafka_produce_handler = None self.model, self.scaler = self._get_model() @@ -261,23 +267,26 @@ def _get_model(self): WrongChecksum: If the downloaded model's checksum doesn't match the expected value. requests.HTTPError: If there's an error downloading the model. 
""" - logger.info(f"Get model: {self.model} with checksum {self.checksum}") - # if not os.path.isfile(self.model_path): - model_download_url = self.get_model_download_url() - logger.info( - f"downloading model {self.model} from {model_download_url} with checksum {self.checksum}" - ) - response = requests.get(model_download_url) - response.raise_for_status() - with open(self.model_path, "wb") as f: - f.write(response.content) - # Handle optional scaler + logger.info(f"Get model: {self.model_name} with checksum {self.checksum}") scaler_download_url = self.get_scaler_download_url() - if scaler_download_url: + + if not os.path.isfile(self.model_path): + model_download_url = self.get_model_download_url() + logger.info( + f"downloading model {self.model_name} from {model_download_url} with checksum {self.checksum}" + ) + response = requests.get(model_download_url) + response.raise_for_status() + with open(self.model_path, "wb") as f: + f.write(response.content) + + if scaler_download_url and not os.path.isfile(self.scaler_path): scaler_response = requests.get(scaler_download_url) scaler_response.raise_for_status() with open(self.scaler_path, "wb") as f: f.write(scaler_response.content) + + if scaler_download_url: with open(self.scaler_path, "rb") as input_file: scaler = pickle.load(input_file) else: @@ -368,6 +377,9 @@ def send_warning(self) -> None: logger.info(f"Producing alert to Kafka: {alert}") + if self.kafka_produce_handler is None: + self.kafka_produce_handler = ExactlyOnceKafkaProduceHandler() + for topic in self.produce_topics: self.kafka_produce_handler.produce( topic=topic, diff --git a/src/detector/plugins/dga_detector.py b/src/detector/plugins/dga_detector.py index a5098187..8e07d775 100644 --- a/src/detector/plugins/dga_detector.py +++ b/src/detector/plugins/dga_detector.py @@ -19,7 +19,7 @@ class DGADetector(DetectorBase): to make predictions about whether a domain is likely generated by a DGA. """ - def __init__(self, detector_config, consume_topic, produce_topics): + def __init__(self, detector_config, consume_topic, produce_topics=None): """ Initialize the DGA detector with configuration parameters. @@ -49,7 +49,7 @@ def get_model_download_url(self): if self.model_base_url[-1] == "/" else self.model_base_url ) - return f"{self.model_base_url}/files/?p=%2F{self.model}%2F{self.checksum}%2F{self.model}.pickle&dl=1" + return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2F{self.model_name}.pickle&dl=1" def get_scaler_download_url(self): """ diff --git a/src/detector/plugins/domainator_detector.py b/src/detector/plugins/domainator_detector.py index 5111de86..5b64b3ad 100644 --- a/src/detector/plugins/domainator_detector.py +++ b/src/detector/plugins/domainator_detector.py @@ -25,7 +25,7 @@ class DomainatorDetector(DetectorBase): to make predictions about whether a query is likely malicious. """ - def __init__(self, detector_config, consume_topic, produce_topics): + def __init__(self, detector_config, consume_topic, produce_topics=None): """ Initialize the Domainator detector with configuration parameters. 
@@ -56,7 +56,7 @@ def get_model_download_url(self):
             if self.model_base_url[-1] == "/"
             else self.model_base_url
         )
-        return f"{self.model_base_url}/files/?p=%2F{self.model}%2F{self.checksum}%2F{self.model}.pickle&dl=1"
+        return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2F{self.model_name}.pickle&dl=1"
 
     def get_scaler_download_url(self):
         """
@@ -73,7 +73,7 @@ def get_scaler_download_url(self):
             if self.model_base_url[-1] == "/"
             else self.model_base_url
         )
-        return f"{self.model_base_url}/files/?p=%2F{self.model}%2F{self.checksum}%2Fscaler.pickle&dl=1"
+        return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2Fscaler.pickle&dl=1"
 
     def predict(self, messages):
         """
@@ -94,7 +94,6 @@ def predict(self, messages):
         queries = [message["domain_name"] for message in messages]
 
         y_pred = self.model.predict_proba(self._get_features(queries))
-        print(f"Prediction: {y_pred}")
         return y_pred
 
     def detect(self):
diff --git a/tests/detector/test_detector.py b/tests/detector/test_detector.py
index 582d90c7..0bf59e5e 100644
--- a/tests/detector/test_detector.py
+++ b/tests/detector/test_detector.py
@@ -1,5 +1,4 @@
-import os
-import tempfile
+import json
 import unittest
 import uuid
 from datetime import datetime, timedelta
@@ -28,15 +27,22 @@ class TestDetector(DetectorBase):
     Test class that does not take any action, so as not to dilute the tests
     """
 
-    def __init__(self, detector_config, consume_topic) -> None:
+    def __init__(self, detector_config, consume_topic, produce_topics=None) -> None:
         self.model_base_url = detector_config["base_url"]
-        super().__init__(consume_topic=consume_topic, detector_config=detector_config)
+        self.model_name = detector_config["model"]
+        super().__init__(
+            consume_topic=consume_topic,
+            detector_config=detector_config,
+            produce_topics=(
+                produce_topics if produce_topics is not None else ["test_produce_topic"]
+            ),
+        )
 
     def get_model_download_url(self):
-        return f"{self.model_base_url}/files/?p=%2F{self.model}%2F{self.checksum}%2F{self.model}.pickle&dl=1"
+        return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2F{self.model_name}.pickle&dl=1"
 
     def get_scaler_download_url(self):
-        return f"{self.model_base_url}/files/?p=%2F{self.model}%2F{self.checksum}%2Fscaler.pickle&dl=1"
+        return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2Fscaler.pickle&dl=1"
 
     def predict(self, message):
         pass
@@ -96,7 +102,11 @@ def setUp(self):
 
     @patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler")
     @patch("src.detector.detector.ClickHouseKafkaSender")
-    def test_get_model(self, mock_clickhouse, mock_kafka_consume_handler):
+    @patch("src.detector.detector.DetectorBase._get_model")
+    def test_get_model(
+        self, mock_get_model, mock_clickhouse, mock_kafka_consume_handler
+    ):
+        mock_get_model.return_value = (MagicMock(), MagicMock())
         mock_kafka_consume_handler_instance = MagicMock()
         mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance
 
@@ -106,9 +116,11 @@ def test_get_model(self, mock_clickhouse, mock_kafka_consume_handler):
 
     @patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler")
     @patch("src.detector.detector.ClickHouseKafkaSender")
+    @patch("src.detector.detector.DetectorBase._get_model")
     def test_get_model_wrong_checksum(
-        self, mock_clickhouse, mock_kafka_consume_handler
+        self, mock_get_model, mock_clickhouse, mock_kafka_consume_handler
     ):
+        mock_get_model.side_effect = WrongChecksum("invalid checksum")
        mock_kafka_consume_handler_instance = MagicMock()
mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance detector_config = MINIMAL_DETECTOR_CONFIG.copy() @@ -276,14 +288,22 @@ def test_save_warning( }, ] sut.parent_row_id = f"{uuid.uuid4()}-{uuid.uuid4()}" + sut.key = "192.168.1.1" + sut.suspicious_batch_id = uuid.uuid4() sut.messages = [{"logline_id": "test_id"}] - open_mock = mock_open() - with patch("src.detector.detector.open", open_mock, create=True): - sut.send_warning() + sut.kafka_produce_handler = MagicMock() + sut.send_warning() - open_mock.assert_called_with( - os.path.join(tempfile.gettempdir(), "warnings.json"), "a+" - ) + sut.kafka_produce_handler.produce.assert_called_once() + produce_call = sut.kafka_produce_handler.produce.call_args.kwargs + self.assertEqual("test_produce_topic", produce_call["topic"]) + self.assertEqual("192.168.1.1", produce_call["key"]) + + alert = json.loads(produce_call["data"]) + self.assertAlmostEqual(0.50019, alert["overall_score"]) + self.assertEqual("test-detector", alert["detector_name"]) + self.assertEqual("192.168.1.1", alert["src_ip"]) + self.assertEqual(2, len(alert["result"])) @patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler") @patch("src.detector.detector.ClickHouseKafkaSender") @@ -301,11 +321,10 @@ def test_save_empty_warning( sut.parent_row_id = f"{uuid.uuid4()}-{uuid.uuid4()}" sut.warnings = [] sut.messages = [{"logline_id": "test_id"}] - open_mock = mock_open() - with patch("src.detector.detector.open", open_mock, create=True): - sut.send_warning() + sut.kafka_produce_handler = MagicMock() + sut.send_warning() - open_mock.assert_not_called() + sut.kafka_produce_handler.produce.assert_not_called() # @patch( # "src.detector.detector.CHECKSUM", @@ -329,6 +348,7 @@ def test_save_warning_error( sut = TestDetector( consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG ) + sut.kafka_produce_handler = MagicMock() sut.warnings = [ { "request": "request.de", @@ -441,9 +461,13 @@ def test_get_model_downloads_and_validates( mock_kafka_consume_handler_instance = MagicMock() mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance - sut = TestDetector( - consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG - ) + with patch( + "src.detector.detector.DetectorBase._get_model", + return_value=(MagicMock(), MagicMock()), + ): + sut = TestDetector( + consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG + ) # Mock file operations with patch("src.detector.detector.os.path.isfile", return_value=False), patch( @@ -462,10 +486,10 @@ def test_get_model_downloads_and_validates( self.assertEqual(model, ("mock_model_or_scaler", "mock_model_or_scaler")) # Verify logger messages self.mock_logger.info.assert_any_call( - f"Get model: {sut.model} with checksum {sut.checksum}" + f"Get model: {sut.model_name} with checksum {sut.checksum}" ) self.mock_logger.info.assert_any_call( - f"downloading model {sut.model} from {sut.get_model_download_url()} with checksum {sut.checksum}" + f"downloading model {sut.model_name} from {sut.get_model_download_url()} with checksum {sut.checksum}" ) @patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler") @@ -477,9 +501,13 @@ def test_get_model_uses_existing_file( mock_kafka_consume_handler_instance = MagicMock() mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance - sut = TestDetector( - consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG - ) + with patch( + "src.detector.detector.DetectorBase._get_model", + return_value=(MagicMock(), 
MagicMock()),
+        ):
+            sut = TestDetector(
+                consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
+            )
 
         # Mock file operations
         with patch("src.detector.detector.os.path.isfile", return_value=True), patch(
@@ -502,7 +530,7 @@
         )
         # Verify logger messages
         self.mock_logger.info.assert_any_call(
-            f"Get model: {sut.model} with checksum {sut.checksum}"
+            f"Get model: {sut.model_name} with checksum {sut.checksum}"
         )
 
     @patch("src.detector.detector.requests.get")
@@ -521,13 +549,19 @@ def test_get_model_raises_wrong_checksum(
         mock_kafka_consume_handler_instance = MagicMock()
         mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance
 
-        sut = TestDetector(
-            consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
-        )
+        with patch(
+            "src.detector.detector.DetectorBase._get_model",
+            return_value=(MagicMock(), MagicMock()),
+        ):
+            sut = TestDetector(
+                consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
+            )
 
         # Mock file operations with wrong checksum
         with patch("src.detector.detector.os.path.isfile", return_value=False), patch(
             "src.detector.detector.open", mock_open()
+        ), patch(
+            "src.detector.detector.pickle.load", return_value="mock_model_or_scaler"
         ), patch.object(sut, "_sha256sum", return_value="wrong_checksum_value"):
 
             with self.assertRaises(WrongChecksum) as context:
@@ -553,9 +587,13 @@ def test_get_model_handles_http_error(
         mock_kafka_consume_handler_instance = MagicMock()
         mock_kafka_consume_handler.return_value = mock_kafka_consume_handler_instance
 
-        sut = TestDetector(
-            consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
-        )
+        with patch(
+            "src.detector.detector.DetectorBase._get_model",
+            return_value=(MagicMock(), MagicMock()),
+        ):
+            sut = TestDetector(
+                consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG
+            )
 
         # Mock file operations
         with patch("src.detector.detector.os.path.isfile", return_value=False), patch(
@@ -567,7 +605,7 @@
 
         # Verify logger info was called
         self.mock_logger.info.assert_any_call(
-            f"Get model: {sut.model} with checksum {sut.checksum}"
+            f"Get model: {sut.model_name} with checksum {sut.checksum}"
         )
diff --git a/tests/detector/test_dga_detector.py b/tests/detector/test_dga_detector.py
index 9e590ac1..24abd4f2 100644
--- a/tests/detector/test_dga_detector.py
+++ b/tests/detector/test_dga_detector.py
@@ -51,7 +51,9 @@ def _create_detector(self, mock_kafka_handler=None, mock_clickhouse=None):
             DGADetector, "_get_model", return_value=(MagicMock(), MagicMock())
         ):
-            detector = DGADetector(detector_config, "test_topic")
+            detector = DGADetector(
+                detector_config, "test_topic", ["test_produce_topic"]
+            )
             detector.model = MagicMock()
             detector.scaler = MagicMock()
             return detector
diff --git a/tests/detector/test_domainator_detector.py b/tests/detector/test_domainator_detector.py
index 81096141..30c9fafb 100644
--- a/tests/detector/test_domainator_detector.py
+++ b/tests/detector/test_domainator_detector.py
@@ -56,7 +56,9 @@ def _create_detector(self, mock_kafka_handler=None, mock_clickhouse=None):
             DomainatorDetector, "_get_model", return_value=(MagicMock(), MagicMock())
         ):
-            detector = DomainatorDetector(detector_config, "test_topic")
+            detector = DomainatorDetector(
+                detector_config, "test_topic", ["test_produce_topic"]
+            )
             detector.model = MagicMock()
             detector.scaler = MagicMock()
             return detector
@@ -69,7 +71,7 @@ def test_get_model_download_url(self):
        # overwrite model here again to not interfere with other
tests when using it globally detector.model = "rf" self.maxDiff = None - expected_url = "https://ajknqwjdnkjnkjnsakjdnkjsandkndkjwndjksnkakndw.de/d/0d5cbcbe16cd46a58021/files/?p=%2Frf%2Fcedf2d892c073c590df5cb2b2bb09b419bd1650d7cd40a66e231b19b8c0a9cde%2Frf.pickle&dl=1" + expected_url = "https://ajknqwjdnkjnkjnsakjdnkjsandkndkjwndjksnkakndw.de/d/0d5cbcbe16cd46a58021/files/?p=%2Frf%2F9d86d66b4976c9b325bed0934a9a9eb3a20960b08be9afe491454624cc0aaa6c%2Frf.pickle&dl=1" self.assertEqual(detector.get_model_download_url(), expected_url) def test_detect(self): @@ -104,7 +106,7 @@ def test_predict_calls_model(self): # Verify the argument was correct called_features = detector.model.predict_proba.call_args[0][0] - expected_features = detector._get_features("google.com") + expected_features = detector._get_features(["google.com", "google.com"]) np.testing.assert_array_equal(called_features, expected_features) # Verify prediction result @@ -138,8 +140,6 @@ def test_get_features_empty_domains(self): features = detector._get_features(["", "", "", ""]) - print(features[0][0], features[0][1], features[0][2]) - # Basic features self.assertEqual( features[0][0], 1.0 diff --git a/tests/logcollector/test_batch_handler.py b/tests/logcollector/test_batch_handler.py index 85661cd7..1c709b09 100644 --- a/tests/logcollector/test_batch_handler.py +++ b/tests/logcollector/test_batch_handler.py @@ -245,6 +245,7 @@ def test_add_message_no_timer( class TestSendAllBatches(unittest.TestCase): @patch("src.logcollector.batch_handler.logger") + @patch("src.logcollector.batch_handler.ClickHouseKafkaSender") @patch("src.logcollector.batch_handler.ExactlyOnceKafkaProduceHandler") @patch("src.logcollector.batch_handler.BufferedBatchSender._send_batch_for_key") @patch("src.logcollector.batch_handler.BufferedBatch") @@ -253,6 +254,7 @@ def test_send_all_batches_with_existing_keys( mock_buffered_batch, mock_send_batch, mock_kafka_produce_handler, + mock_clickhouse, mock_logger, ): # Arrange @@ -274,11 +276,16 @@ def test_send_all_batches_with_existing_keys( mock_send_batch.assert_any_call("key_2") self.assertEqual(mock_send_batch.call_count, 2) + @patch("src.logcollector.batch_handler.ClickHouseKafkaSender") @patch("src.logcollector.batch_handler.ExactlyOnceKafkaProduceHandler") @patch("src.logcollector.batch_handler.BufferedBatchSender._send_batch_for_key") @patch("src.logcollector.batch_handler.BufferedBatch") def test_send_all_batches_with_one_key( - self, mock_buffered_batch, mock_send_batch, mock_kafka_produce_handler + self, + mock_buffered_batch, + mock_send_batch, + mock_kafka_produce_handler, + mock_clickhouse, ): # Arrange mock_batch_instance = MagicMock() @@ -298,6 +305,7 @@ def test_send_all_batches_with_one_key( self.assertEqual(mock_send_batch.call_count, 0) @patch("src.logcollector.batch_handler.logger") + @patch("src.logcollector.batch_handler.ClickHouseKafkaSender") @patch("src.logcollector.batch_handler.ExactlyOnceKafkaProduceHandler") @patch("src.logcollector.batch_handler.BufferedBatchSender._send_batch_for_key") @patch("src.logcollector.batch_handler.BufferedBatchSender._reset_timer") @@ -308,6 +316,7 @@ def test_send_all_batches_with_existing_keys_and_reset_timer( mock_reset_timer, mock_send_batch, mock_kafka_produce_handler, + mock_clickhouse, mock_logger, ): # Arrange @@ -330,11 +339,16 @@ def test_send_all_batches_with_existing_keys_and_reset_timer( mock_reset_timer.assert_called_once() self.assertEqual(mock_send_batch.call_count, 2) + @patch("src.logcollector.batch_handler.ClickHouseKafkaSender") 
@patch("src.logcollector.batch_handler.ExactlyOnceKafkaProduceHandler") @patch("src.logcollector.batch_handler.BufferedBatchSender._send_batch_for_key") @patch("src.logcollector.batch_handler.BufferedBatch") def test_send_all_batches_with_no_keys( - self, mock_buffered_batch, mock_send_batch, mock_kafka_produce_handler + self, + mock_buffered_batch, + mock_send_batch, + mock_kafka_produce_handler, + mock_clickhouse, ): # Arrange mock_batch_instance = MagicMock() @@ -355,11 +369,16 @@ def test_send_all_batches_with_no_keys( class TestSendBatchForKey(unittest.TestCase): + @patch("src.logcollector.batch_handler.ClickHouseKafkaSender") @patch("src.logcollector.batch_handler.ExactlyOnceKafkaProduceHandler") @patch.object(BufferedBatchSender, "_send_data_packet") @patch("src.logcollector.batch_handler.BufferedBatch") def test_send_batch_for_key_success( - self, mock_batch, mock_send_data_packet, mock_produce_handler + self, + mock_batch, + mock_send_data_packet, + mock_produce_handler, + mock_clickhouse, ): # Arrange mock_batch_instance = MagicMock() @@ -379,11 +398,16 @@ def test_send_batch_for_key_success( mock_batch_instance.complete_batch.assert_called_once_with(key) mock_send_data_packet.assert_called_once_with(key, "mock_data_packet") + @patch("src.logcollector.batch_handler.ClickHouseKafkaSender") @patch("src.logcollector.batch_handler.ExactlyOnceKafkaProduceHandler") @patch.object(BufferedBatchSender, "_send_data_packet") @patch("src.logcollector.batch_handler.BufferedBatch") def test_send_batch_for_key_value_error( - self, mock_batch, mock_send_data_packet, mock_produce_handler + self, + mock_batch, + mock_send_data_packet, + mock_produce_handler, + mock_clickhouse, ): # Arrange mock_batch_instance = MagicMock() diff --git a/tests/logcollector/test_collector.py b/tests/logcollector/test_collector.py index 86ca53d1..3cc92b59 100644 --- a/tests/logcollector/test_collector.py +++ b/tests/logcollector/test_collector.py @@ -67,11 +67,17 @@ def setUp( validation_config={}, ) - async def test_start_successful_execution(self): + @patch("src.logcollector.collector.asyncio.get_event_loop") + async def test_start_successful_execution(self, mock_get_event_loop): # Arrange self.sut.fetch = MagicMock() + mock_loop = MagicMock() + mock_loop.run_in_executor = AsyncMock(return_value=None) + mock_get_event_loop.return_value = mock_loop + await self.sut.start() - self.sut.fetch.assert_called_once() + + mock_loop.run_in_executor.assert_awaited_once_with(None, self.sut.fetch) class _StopFetching(RuntimeError): diff --git a/tests/logserver/test_server.py b/tests/logserver/test_server.py index 9489cfcc..72ac29c3 100644 --- a/tests/logserver/test_server.py +++ b/tests/logserver/test_server.py @@ -49,17 +49,25 @@ def setUp( ) @patch("src.logserver.server.LogServer.fetch_from_kafka") + @patch("src.logserver.server.asyncio.get_running_loop") @patch("src.logserver.server.ClickHouseKafkaSender") async def test_start( self, mock_clickhouse, + mock_get_running_loop, mock_fetch_from_kafka, ): + mock_loop = MagicMock() + mock_loop.run_in_executor = AsyncMock(return_value=None) + mock_get_running_loop.return_value = mock_loop + # Act await self.sut.start() # Assert - mock_fetch_from_kafka.assert_called_once() + mock_loop.run_in_executor.assert_awaited_once_with( + None, self.sut.fetch_from_kafka + ) class TestSend(unittest.TestCase):