diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md
index fb9703ff..c21a55ae 100644
--- a/.github/CODE_OF_CONDUCT.md
+++ b/.github/CODE_OF_CONDUCT.md
@@ -89,4 +89,3 @@ This Code of Conduct is adapted from the Contributor Covenant, version 3.0, perm
Contributor Covenant is stewarded by the Organization for Ethical Source and licensed under CC BY-SA 4.0. To view a copy of this license, visit [https://creativecommons.org/licenses/by-sa/4.0/](https://creativecommons.org/licenses/by-sa/4.0/)
For answers to common questions about Contributor Covenant, see the FAQ at [https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq). Translations are provided at [https://www.contributor-covenant.org/translations](https://www.contributor-covenant.org/translations). Additional enforcement and community guideline resources can be found at [https://www.contributor-covenant.org/resources](https://www.contributor-covenant.org/resources). The enforcement ladder was inspired by the work of [Mozilla’s code of conduct team](https://github.com/mozilla/inclusion).
-
diff --git a/.github/workflows/build_publish_docker.yml b/.github/workflows/build_publish_docker.yml
index 32197ce1..f9851b33 100644
--- a/.github/workflows/build_publish_docker.yml
+++ b/.github/workflows/build_publish_docker.yml
@@ -72,4 +72,4 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
file: ./docker/dockerfiles/Dockerfile.${{ matrix.container }}
cache-from: type=gha
- cache-to: type=gha,mode=max
\ No newline at end of file
+ cache-to: type=gha,mode=max
diff --git a/.gitignore b/.gitignore
index 0cc2a700..69174f9d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -335,4 +335,4 @@ docs/api/
!vcpkg.json
!**/vcpkg.json
!compile_commands.json
-vcpkg_installed
\ No newline at end of file
+vcpkg_installed
diff --git a/.readthedocs.yml b/.readthedocs.yml
index 336e933a..9b23a01e 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -20,4 +20,4 @@ python:
- requirements: requirements/requirements.logserver.txt
sphinx:
- configuration: docs/conf.py
\ No newline at end of file
+ configuration: docs/conf.py
diff --git a/README.md b/README.md
index 28f35375..c246509c 100644
--- a/README.md
+++ b/README.md
@@ -109,7 +109,7 @@ values in detail.
### Testing Your Own Data
-If you want to ingest data to the pipeline, you can do so via the zeek container. Either select the interface in the `config.yaml` zeek should be listening on and set `static_analysis: false` or provide PCAPs to Zeek by adding them in the `data/test_pcaps` directory, which is mounted per default for Zeek to ingest static data.
+If you want to ingest data into the pipeline, you can do so via the Zeek container. Either select the interface Zeek should listen on in the `config.yaml` and set `static_analysis: false`, or provide PCAPs to Zeek by placing them in the `data/test_pcaps` directory, which is mounted by default so Zeek can ingest static data.
### Monitoring
To monitor the system and observe its real-time behavior, multiple Grafana dashboards have been set up.
diff --git a/assets/heidgaf_architecture.svg b/assets/heidgaf_architecture.svg
index 4530acf2..4c26e527 100644
--- a/assets/heidgaf_architecture.svg
+++ b/assets/heidgaf_architecture.svg
@@ -1,4 +1,4 @@
-
\ No newline at end of file
+
diff --git a/config.yaml b/config.yaml
index 17a874a5..617d4a6a 100644
--- a/config.yaml
+++ b/config.yaml
@@ -19,7 +19,7 @@ pipeline:
log_storage:
logserver:
input_file: "/opt/file.txt"
-
+
diff --git a/docs/pipeline.rst b/docs/pipeline.rst
index 5ab59c20..37c11dcc 100644
--- a/docs/pipeline.rst
+++ b/docs/pipeline.rst
@@ -840,4 +840,4 @@ Domainator Detector
The :class:`DomainatorDetector` consumes anomalous batches of requests.
It identifies potential data exfiltration and command & control on the subdomain level by analyzing characteristics of the subdomains.
Messages are grouped by domain into fixed-size windows to allow for sequential anomaly detection. The detector leverages machine learning based on statistical and linguistic features from the domain name
-including label lengths, character frequencies, entropy measures, and counts of different character types across domain name levels.
\ No newline at end of file
+including label lengths, character frequencies, entropy measures, and counts of different character types across domain name levels.
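
As context for the windowing described above, here is a minimal sketch of grouping messages by domain into fixed-size windows. It mirrors the queue trimming visible in the domainator_detector.py hunk below (window length 10, oldest entry dropped), but the enqueue helper and WINDOW_SIZE name are illustrative, not part of the codebase.

import collections

WINDOW_SIZE = 10  # matches the detector's fixed window length

message_queues = collections.defaultdict(list)  # domain -> recent messages

def enqueue(message: dict) -> list | None:
    """Queue a message under its domain; emit a full window once one exists."""
    queue = message_queues[message["domain_name"]]
    queue.append(message)
    if len(queue) >= WINDOW_SIZE:
        window = list(queue)
        del queue[0]  # slide forward, as in DomainatorDetector.detect()
        return window
    return None
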
diff --git a/requirements/requirements.detector.txt b/requirements/requirements.detector.txt
index 6be038c1..5ca06bee 100644
--- a/requirements/requirements.detector.txt
+++ b/requirements/requirements.detector.txt
@@ -7,4 +7,4 @@ confluent-kafka~=2.4.0
marshmallow_dataclass~=8.7.1
clickhouse_connect~=0.8.3
pylcs
-Levenshtein
\ No newline at end of file
+Levenshtein
diff --git a/requirements/requirements.train.txt b/requirements/requirements.train.txt
index 3ca51a33..82e15a1f 100644
--- a/requirements/requirements.train.txt
+++ b/requirements/requirements.train.txt
@@ -13,4 +13,4 @@ seaborn
lightgbm
imblearn
pylcs
-Levenshtein
\ No newline at end of file
+Levenshtein
diff --git a/src/alerter/alerter.py b/src/alerter/alerter.py
index 0eba4a64..5934e3f1 100644
--- a/src/alerter/alerter.py
+++ b/src/alerter/alerter.py
@@ -31,6 +31,7 @@ class AlerterAbstractBase(ABC):
"""
Abstract base class for all alerter implementations.
"""
+
@abstractmethod
def __init__(self, alerter_config, consume_topic) -> None:
pass
@@ -51,6 +52,7 @@ class AlerterBase(AlerterAbstractBase):
executing custom processing via plugins, and performing base actions
like logging to a file or forwarding to an external Kafka topic.
"""
+
def __init__(self, alerter_config, consume_topic) -> None:
self.name = alerter_config.get("name", "generic")
self.consume_topic = consume_topic
@@ -59,12 +61,16 @@ def __init__(self, alerter_config, consume_topic) -> None:
self.key = None
self.kafka_consume_handler = ExactlyOnceKafkaConsumeHandler(self.consume_topic)
-
+
# Base actions config
self.log_to_file = ALERTING_CONFIG.get("log_to_file", False)
- self.log_file_path = ALERTING_CONFIG.get("log_file_path", "/opt/logs/alerts.txt")
+ self.log_file_path = ALERTING_CONFIG.get(
+ "log_file_path", "/opt/logs/alerts.txt"
+ )
self.log_to_kafka = ALERTING_CONFIG.get("log_to_kafka", False)
- self.external_kafka_topic = ALERTING_CONFIG.get("external_kafka_topic", "external_alerts_topic")
+ self.external_kafka_topic = ALERTING_CONFIG.get(
+ "external_kafka_topic", "external_alerts_topic"
+ )
if self.log_to_file:
ensure_directory(self.log_file_path)
@@ -72,12 +78,11 @@ def __init__(self, alerter_config, consume_topic) -> None:
if self.log_to_kafka:
self._setup_kafka_output_topics()
-
def _setup_kafka_output_topics(self):
"""
- Ensure that the external Kafka topic exists.
-
- Since no internal consumer subscribes to this topic, auto-creation
+ Ensure that the external Kafka topic exists.
+
+ Since no internal consumer subscribes to this topic, auto-creation
via consumer polling won't happen. We use AdminClient to ensure
the topic exists before producing to it.
"""
@@ -92,10 +97,12 @@ def _setup_kafka_output_topics(self):
try:
admin_client.create_topics([NewTopic(self.external_kafka_topic, 1, 1)])
except Exception as e:
- logger.warning(f"Could not auto-create topic {self.external_kafka_topic}: {e}")
-
+ logger.warning(
+ f"Could not auto-create topic {self.external_kafka_topic}: {e}"
+ )
+
self.kafka_produce_handler = ExactlyOnceKafkaProduceHandler()
-
+
def get_and_fill_data(self) -> None:
if self.alert_data:
logger.warning(
@@ -121,7 +128,7 @@ def _log_to_file_action(self):
"""
if not self.log_to_file:
return
-
+
logger.info(f"{self.name}: Logging alert to file {self.log_file_path}")
try:
with open(self.log_file_path, "a+") as f:
@@ -138,7 +145,9 @@ def _log_to_kafka_action(self):
if not self.log_to_kafka:
return
- logger.info(f"{self.name}: Forwarding alert to topic {self.external_kafka_topic}")
+ logger.info(
+ f"{self.name}: Forwarding alert to topic {self.external_kafka_topic}"
+ )
try:
self.kafka_produce_handler.produce(
topic=self.external_kafka_topic,
@@ -151,7 +160,7 @@ def _log_to_kafka_action(self):
def bootstrap_alerter_instance(self):
"""
- Main loop for the alerter instance.
+ Main loop for the alerter instance.
Consumes alerts, processes them, and executes base actions.
"""
logger.info(f"Starting {self.name} Alerter")
@@ -185,18 +194,17 @@ async def start(self):
await loop.run_in_executor(None, self.bootstrap_alerter_instance)
-
async def main():
tasks = []
-
+
# Setup Generic Alerter Task
generic_topic = f"{CONSUME_TOPIC_PREFIX}-generic"
- logger.info("Initializing Generic Alerter")
+ logger.info("Initializing Generic Alerter")
class_name = "GenericAlerter"
mod_name = f"{PLUGIN_PATH}.generic_alerter"
module = importlib.import_module(mod_name)
AlerterClass = getattr(module, class_name)
-
+
generic_alerter = AlerterClass(
alerter_config={"name": "generic"}, consume_topic=generic_topic
)
@@ -211,12 +219,12 @@ async def main():
mod_name = f"{PLUGIN_PATH}.{alerter_config['alerter_module_name']}"
module = importlib.import_module(mod_name)
AlerterClass = getattr(module, class_name)
-
+
alerter_instance = AlerterClass(
alerter_config=alerter_config, consume_topic=consume_topic
)
tasks.append(asyncio.create_task(alerter_instance.start()))
-
+
await asyncio.gather(*tasks)
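
The reformatted _setup_kafka_output_topics above exists because a topic with no internal consumer is never auto-created by consumer polling, so it must be created explicitly before producing. Below is a self-contained sketch of that pattern with confluent-kafka; the ensure_topic helper and the local broker address are assumptions for illustration.

from confluent_kafka.admin import AdminClient, NewTopic

def ensure_topic(bootstrap_servers: str, topic: str) -> None:
    """Create a Kafka topic up front so the first produce cannot race creation."""
    admin = AdminClient({"bootstrap.servers": bootstrap_servers})
    # create_topics() is asynchronous and returns a dict of {topic: future}
    futures = admin.create_topics(
        [NewTopic(topic, num_partitions=1, replication_factor=1)]
    )
    for name, future in futures.items():
        try:
            future.result()  # raises if creation failed
        except Exception as exc:
            # Warning-only handling, as in AlerterBase: an already-existing
            # topic is fine for our purposes.
            print(f"Could not auto-create topic {name}: {exc}")

if __name__ == "__main__":
    ensure_topic("localhost:9092", "external_alerts_topic")  # assumed broker
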
diff --git a/src/alerter/plugins/generic_alerter.py b/src/alerter/plugins/generic_alerter.py
index 6816d106..6ba6a62e 100644
--- a/src/alerter/plugins/generic_alerter.py
+++ b/src/alerter/plugins/generic_alerter.py
@@ -11,15 +11,15 @@
class GenericAlerter(AlerterBase):
"""
- Specific implementation for an Alerter that processes alerts
- from a generic topic.
-
+ Specific implementation for an Alerter that processes alerts
+ from a generic topic.
+
It performs no additional processing or transformation by itself,
instead relying solely on the base actions (logging to file/Kafka).
"""
+
def process_alert(self):
"""
Generic implementation: no special processing needed.
"""
pass
-
diff --git a/src/base/utils.py b/src/base/utils.py
index 3328a06e..76115d30 100644
--- a/src/base/utils.py
+++ b/src/base/utils.py
@@ -233,4 +233,4 @@ def generate_collisions_resistant_uuid():
def ensure_directory(file_path):
directory = os.path.dirname(file_path)
if directory:
- os.makedirs(directory, exist_ok=True)
\ No newline at end of file
+ os.makedirs(directory, exist_ok=True)
diff --git a/src/detector/detector.py b/src/detector/detector.py
index deb21f33..00b87e31 100644
--- a/src/detector/detector.py
+++ b/src/detector/detector.py
@@ -317,7 +317,9 @@ def detect(self) -> None:
y_pred = self.predict(message)
logger.info(f"Prediction: {y_pred}")
# TODO: DO NOT USE if TRUE for prod!!!
- if True: # np.argmax(y_pred, axis=1) == 1 and y_pred[0][1] > self.threshold:
+ if (
+ True
+ ): # np.argmax(y_pred, axis=1) == 1 and y_pred[0][1] > self.threshold:
logger.info("Append malicious request to warning.")
warning = {
"request": message,
@@ -361,11 +363,11 @@ def send_warning(self) -> None:
"src_ip": self.key,
"alert_timestamp": datetime.datetime.now().isoformat(),
"suspicious_batch_id": str(self.suspicious_batch_id),
- "detector_name": self.name
+ "detector_name": self.name,
}
logger.info(f"Producing alert to Kafka: {alert}")
-
+
for topic in self.produce_topics:
self.kafka_produce_handler.produce(
topic=topic,
@@ -526,13 +528,16 @@ async def main(): # pragma: no cover
"""
# ensure all detectors configure what to do
# instead of doing ensure alert directly we now use alerter topics
-
+
tasks = []
for detector_config in DETECTORS:
consume_topic = f"{CONSUME_TOPIC_PREFIX}-{detector_config['name']}"
produce_topics_str = detector_config.get("produce_topics", "")
if produce_topics_str:
- produce_topics = [f"{PRODUCE_TOPIC_PREFIX}-{t.strip()}" for t in produce_topics_str.split(",")]
+ produce_topics = [
+ f"{PRODUCE_TOPIC_PREFIX}-{t.strip()}"
+ for t in produce_topics_str.split(",")
+ ]
else:
produce_topics = [f"{PRODUCE_TOPIC_PREFIX}-generic"]
@@ -541,7 +546,9 @@ async def main(): # pragma: no cover
module = importlib.import_module(module_name)
DetectorClass = getattr(module, class_name)
detector = DetectorClass(
- detector_config=detector_config, consume_topic=consume_topic, produce_topics=produce_topics
+ detector_config=detector_config,
+ consume_topic=consume_topic,
+ produce_topics=produce_topics,
)
tasks.append(asyncio.create_task(detector.start()))
await asyncio.gather(*tasks)
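
main() above resolves each detector plugin dynamically and derives its produce topics from a comma-separated config string. A condensed sketch of those two steps; the prefix value and helper names are assumptions for illustration.

import importlib

PRODUCE_TOPIC_PREFIX = "warnings"  # assumed; the real value comes from config

def load_plugin(module_name: str, class_name: str) -> type:
    """Import the plugin module and resolve the detector class by name."""
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

def derive_produce_topics(produce_topics_str: str) -> list[str]:
    """Prefix each configured topic; fall back to the generic alerter topic."""
    if produce_topics_str:
        return [
            f"{PRODUCE_TOPIC_PREFIX}-{t.strip()}"
            for t in produce_topics_str.split(",")
        ]
    return [f"{PRODUCE_TOPIC_PREFIX}-generic"]

# e.g. derive_produce_topics("domainator, generic")
# -> ["warnings-domainator", "warnings-generic"]
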
diff --git a/src/detector/plugins/domainator_detector.py b/src/detector/plugins/domainator_detector.py
index 2fec43d7..5111de86 100644
--- a/src/detector/plugins/domainator_detector.py
+++ b/src/detector/plugins/domainator_detector.py
@@ -13,7 +13,7 @@
class DomainatorDetector(DetectorBase):
"""
- Detector implementation for identifying data exfiltration and command and control on the
+ Detector implementation for identifying data exfiltration and command and control on the
subdomain level.
This class extends the DetectorBase to provide specific functionality for detecting
@@ -91,12 +91,12 @@ def predict(self, messages):
np.ndarray: Prediction probabilities for each class. Typically a 2D array
where the shape is (1, 2) for binary classification (benign/malicious).
"""
- queries = [message['domain_name'] for message in messages]
+ queries = [message["domain_name"] for message in messages]
y_pred = self.model.predict_proba(self._get_features(queries))
print(f"Prediction: {y_pred}")
return y_pred
-
+
def detect(self):
logger.info("Start detecting malicious requests.")
for message in self.messages:
@@ -119,7 +119,6 @@ def detect(self):
if len(self.message_queues[message_domain]) >= 10:
del self.message_queues[message_domain][0]
-
def _strip_domain(self, query: str):
"""Extract the domain name from the message for the window grouping
@@ -140,7 +139,6 @@ def _strip_domain(self, query: str):
return domain
-
def _get_features(self, queries: list) -> np.ndarray:
"""Extracts feature vector from domain name for ML model inference.
@@ -154,32 +152,73 @@ def _get_features(self, queries: list) -> np.ndarray:
Returns:
numpy.ndarray: Feature vector ready for ML model prediction.
"""
-
+
queries = [query.strip(".") for query in queries]
- subdomains = ['.'.join(domain.split(".")[:-2]) for domain in queries]
+ subdomains = [".".join(domain.split(".")[:-2]) for domain in queries]
# Values can be put directly into an array, as the return converts them anyway,
# but this slightly improves readability
metrics = {
- 'levenshtein': [],
- 'jaro': [],
- 'rev_jaro': [],
- 'jaro_winkler': [],
- 'rev_jaro_wink': [],
- 'lcs_seq': [],
- 'lcs_str': [],
+ "levenshtein": [],
+ "jaro": [],
+ "rev_jaro": [],
+ "jaro_winkler": [],
+ "rev_jaro_wink": [],
+ "lcs_seq": [],
+ "lcs_str": [],
}
# if subdomains:
cartesian = list(itertools.combinations(subdomains, 2))
- metrics['levenshtein'] = np.mean([Levenshtein.ratio(product[0], product[1]) for product in cartesian])
- metrics['jaro'] = np.mean([Levenshtein.jaro(product[0], product[1]) for product in cartesian])
- metrics['jaro_winkler'] = np.mean([Levenshtein.jaro_winkler(product[0], product[1], prefix_weight=0.2) for product in cartesian])
- metrics['rev_jaro'] = np.mean([Levenshtein.jaro(product[0][::-1], product[1][::-1]) for product in cartesian])
- metrics['rev_jaro_wink'] = np.mean([Levenshtein.jaro_winkler(product[0][::-1], product[1][::-1], prefix_weight=0.2) for product in cartesian])
+ metrics["levenshtein"] = np.mean(
+ [Levenshtein.ratio(product[0], product[1]) for product in cartesian]
+ )
+ metrics["jaro"] = np.mean(
+ [Levenshtein.jaro(product[0], product[1]) for product in cartesian]
+ )
+ metrics["jaro_winkler"] = np.mean(
+ [
+ Levenshtein.jaro_winkler(product[0], product[1], prefix_weight=0.2)
+ for product in cartesian
+ ]
+ )
+ metrics["rev_jaro"] = np.mean(
+ [
+ Levenshtein.jaro(product[0][::-1], product[1][::-1])
+ for product in cartesian
+ ]
+ )
+ metrics["rev_jaro_wink"] = np.mean(
+ [
+ Levenshtein.jaro_winkler(
+ product[0][::-1], product[1][::-1], prefix_weight=0.2
+ )
+ for product in cartesian
+ ]
+ )
- metrics['lcs_seq'] = np.mean([pylcs.lcs_sequence_length(product[0], product[1])/((len(product[0]) + len(product[1]))/2) if len(product[0]) and len(product[1]) else 0.0 for product in cartesian ])
- metrics['lcs_str'] = np.mean([pylcs.lcs_string_length(product[0], product[1])/((len(product[0]) + len(product[1]))/2) if len(product[0]) and len(product[1]) else 0.0 for product in cartesian])
+ metrics["lcs_seq"] = np.mean(
+ [
+ (
+ pylcs.lcs_sequence_length(product[0], product[1])
+ / ((len(product[0]) + len(product[1])) / 2)
+ if len(product[0]) and len(product[1])
+ else 0.0
+ )
+ for product in cartesian
+ ]
+ )
+ metrics["lcs_str"] = np.mean(
+ [
+ (
+ pylcs.lcs_string_length(product[0], product[1])
+ / ((len(product[0]) + len(product[1])) / 2)
+ if len(product[0]) and len(product[1])
+ else 0.0
+ )
+ for product in cartesian
+ ]
+ )
- return np.fromiter(metrics.values(), dtype=float).reshape(1, -1)
\ No newline at end of file
+ return np.fromiter(metrics.values(), dtype=float).reshape(1, -1)
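
The reformatted _get_features is easier to follow as a condensed restatement: it averages every pairwise similarity metric over all subdomain combinations and returns a (1, 7) feature vector. A runnable sketch under the same dependencies, operating directly on already-extracted subdomains; pairwise_features and norm_lcs are illustrative names.

import itertools

import Levenshtein
import numpy as np
import pylcs

def pairwise_features(subdomains: list[str]) -> np.ndarray:
    """Mean pairwise similarity metrics, in the same order as _get_features."""
    pairs = list(itertools.combinations(subdomains, 2))

    def norm_lcs(fn, a: str, b: str) -> float:
        # Normalize the LCS length by the mean string length; 0.0 for empty inputs.
        return fn(a, b) / ((len(a) + len(b)) / 2) if len(a) and len(b) else 0.0

    metrics = [
        np.mean([Levenshtein.ratio(a, b) for a, b in pairs]),
        np.mean([Levenshtein.jaro(a, b) for a, b in pairs]),
        np.mean([Levenshtein.jaro(a[::-1], b[::-1]) for a, b in pairs]),
        np.mean([Levenshtein.jaro_winkler(a, b, prefix_weight=0.2) for a, b in pairs]),
        np.mean(
            [
                Levenshtein.jaro_winkler(a[::-1], b[::-1], prefix_weight=0.2)
                for a, b in pairs
            ]
        ),
        np.mean([norm_lcs(pylcs.lcs_sequence_length, a, b) for a, b in pairs]),
        np.mean([norm_lcs(pylcs.lcs_string_length, a, b) for a, b in pairs]),
    ]
    return np.asarray(metrics, dtype=float).reshape(1, -1)

print(pairwise_features(["sub1", "sub2", "sub3"]).shape)  # (1, 7)
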
diff --git a/src/train/dataset.py b/src/train/dataset.py
index d3da13cf..cd1606fe 100644
--- a/src/train/dataset.py
+++ b/src/train/dataset.py
@@ -262,7 +262,7 @@ def cast_heicloud(data_path: str, max_rows: int) -> pl.DataFrame:
def cast_domainator(data_path: List[str], max_rows: int) -> pl.DataFrame:
"""Loads and processes Domainator dataset from multiple CSV files.
- Reads Domainator datasets (benign, malicious), appends a user source if not present,
+    Reads Domainator datasets (benign, malicious), adds a 'user' column if not already present,
then processes the queries and combines the datasets into one for training
Args:
@@ -280,10 +280,10 @@ def cast_domainator(data_path: List[str], max_rows: int) -> pl.DataFrame:
path,
separator=",",
has_header=True,
- n_rows=max_rows if max_rows > 0 else None
+ n_rows=max_rows if max_rows > 0 else None,
)
- if 'user' not in df.columns:
- df.insert_column(0, pl.Series('user', ['testbed']*len(df)))
+ if "user" not in df.columns:
+ df.insert_column(0, pl.Series("user", ["testbed"] * len(df)))
df = preprocess(df, keep_all=True)
logger.info(f"Data loaded with shape {df.shape}")
dataframes.append(df)
@@ -349,6 +349,7 @@ def heicloud_dataset(self) -> Dataset:
max_rows=self.max_rows,
)
return self.heicloud_data
+
@property
def dgarchive_dataset(self) -> list[Dataset]:
dgarchive_files = [
@@ -365,17 +366,17 @@ def dgarchive_dataset(self) -> list[Dataset]:
)
)
return self.dgarchive_data
-
+
@property
def domainator_dataset(self) -> Dataset:
self.domainator_data = Dataset(
name="domainator",
data_path={
f"{self.base_path}/domainator/domainator_combined.csv",
- f"{self.base_path}/domainator/domainator_ziza.csv"
+ f"{self.base_path}/domainator/domainator_ziza.csv",
},
cast_dataset=cast_domainator,
- max_rows=self.max_rows
+ max_rows=self.max_rows,
)
logger.debug("Domainator Loader")
diff --git a/src/train/feature.py b/src/train/feature.py
index 900a9890..bc419175 100644
--- a/src/train/feature.py
+++ b/src/train/feature.py
@@ -168,58 +168,115 @@ def transform_domainator(self, x: pl.DataFrame) -> pl.DataFrame:
window_size = 10
min_window_size = 3
-
x = x.with_columns(
- pl.concat_str([pl.col('secondleveldomain'), pl.col('tld')], separator='.').alias('domain')
+ pl.concat_str(
+ [pl.col("secondleveldomain"), pl.col("tld")], separator="."
+ ).alias("domain")
)
- for user in x['user'].unique():
+ for user in x["user"].unique():
# logger.debug(x.filter(pl.col('user') == user))
- for domain in x.filter(pl.col('user') == user)['domain'].unique():
- sub_list = x.filter((pl.col('user') == user) & (pl.col('domain') == domain))['thirdleveldomain']
- true_class = x.filter(pl.col('domain') == domain)['class'].unique() # currently assumes domain is not both malicious and legitimate
-
- windows = [sub_list[i:i+window_size] for i in range(0, len(sub_list), window_size)]
+ for domain in x.filter(pl.col("user") == user)["domain"].unique():
+ sub_list = x.filter(
+ (pl.col("user") == user) & (pl.col("domain") == domain)
+ )["thirdleveldomain"]
+ true_class = x.filter(pl.col("domain") == domain)[
+ "class"
+ ].unique() # currently assumes domain is not both malicious and legitimate
+
+ windows = [
+ sub_list[i : i + window_size]
+ for i in range(0, len(sub_list), window_size)
+ ]
if not windows:
windows = sub_list
-
for item in windows:
if len(item) > min_window_size:
cartesian = list(itertools.combinations(item, 2))
metrics = {
- 'user': user,
- 'class': true_class[0],
- 'query': domain,
- 'levenshtein': [],
- 'jaro': [],
- 'rev_jaro': [],
- 'jaro_winkler': [],
- 'rev_jaro_wink': [],
- 'lcs_seq': [],
- 'lcs_str': [],
+ "user": user,
+ "class": true_class[0],
+ "query": domain,
+ "levenshtein": [],
+ "jaro": [],
+ "rev_jaro": [],
+ "jaro_winkler": [],
+ "rev_jaro_wink": [],
+ "lcs_seq": [],
+ "lcs_str": [],
}
- metrics['levenshtein'] = np.mean([Levenshtein.ratio(product[0], product[1]) for product in cartesian])
- metrics['jaro'] = np.mean([Levenshtein.jaro(product[0], product[1]) for product in cartesian])
- metrics['jaro_winkler'] = np.mean([Levenshtein.jaro_winkler(product[0], product[1], prefix_weight=0.2) for product in cartesian])
- metrics['rev_jaro'] = np.mean([Levenshtein.jaro(product[0][::-1], product[1][::-1]) for product in cartesian])
- metrics['rev_jaro_wink'] = np.mean([Levenshtein.jaro_winkler(product[0][::-1], product[1][::-1], prefix_weight=0.2) for product in cartesian])
+ metrics["levenshtein"] = np.mean(
+ [
+ Levenshtein.ratio(product[0], product[1])
+ for product in cartesian
+ ]
+ )
+ metrics["jaro"] = np.mean(
+ [
+ Levenshtein.jaro(product[0], product[1])
+ for product in cartesian
+ ]
+ )
+ metrics["jaro_winkler"] = np.mean(
+ [
+ Levenshtein.jaro_winkler(
+ product[0], product[1], prefix_weight=0.2
+ )
+ for product in cartesian
+ ]
+ )
+ metrics["rev_jaro"] = np.mean(
+ [
+ Levenshtein.jaro(product[0][::-1], product[1][::-1])
+ for product in cartesian
+ ]
+ )
+ metrics["rev_jaro_wink"] = np.mean(
+ [
+ Levenshtein.jaro_winkler(
+ product[0][::-1],
+ product[1][::-1],
+ prefix_weight=0.2,
+ )
+ for product in cartesian
+ ]
+ )
- metrics['lcs_seq'] = np.mean([pylcs.lcs_sequence_length(product[0], product[1])/((len(product[0]) + len(product[1]))/2) if len(product[0]) and len(product[1]) else 0.0 for product in cartesian ])
- metrics['lcs_str'] = np.mean([pylcs.lcs_string_length(product[0], product[1])/((len(product[0]) + len(product[1]))/2) if len(product[0]) and len(product[1]) else 0.0 for product in cartesian])
+ metrics["lcs_seq"] = np.mean(
+ [
+ (
+ pylcs.lcs_sequence_length(product[0], product[1])
+ / ((len(product[0]) + len(product[1])) / 2)
+ if len(product[0]) and len(product[1])
+ else 0.0
+ )
+ for product in cartesian
+ ]
+ )
+ metrics["lcs_str"] = np.mean(
+ [
+ (
+ pylcs.lcs_string_length(product[0], product[1])
+ / ((len(product[0]) + len(product[1])) / 2)
+ if len(product[0]) and len(product[1])
+ else 0.0
+ )
+ for product in cartesian
+ ]
+ )
metrics_list.append(metrics)
df = pl.from_dicts(metrics_list)
logger.debug(df)
- logger.debug(df['class'].unique())
-
+ logger.debug(df["class"].unique())
df = df.drop(["user"])
logger.debug("Transform done")
- return df
\ No newline at end of file
+ return df
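
transform_domainator first rebuilds the registered domain from its parts with concat_str, then chunks each (user, domain) group's subdomains into fixed-size windows before computing the pairwise metrics. A small self-contained sketch of that grouping, using invented sample data.

import polars as pl

WINDOW_SIZE = 10  # as in transform_domainator

df = pl.DataFrame(
    {
        "user": ["u1"] * 4,
        "thirdleveldomain": ["a1", "a2", "b1", "b2"],
        "secondleveldomain": ["example"] * 4,
        "tld": ["com"] * 4,
    }
)

# Rebuild "example.com" from secondleveldomain + tld, as the hunk above does.
df = df.with_columns(
    pl.concat_str(
        [pl.col("secondleveldomain"), pl.col("tld")], separator="."
    ).alias("domain")
)

for user in df["user"].unique():
    for domain in df.filter(pl.col("user") == user)["domain"].unique():
        subs = df.filter(
            (pl.col("user") == user) & (pl.col("domain") == domain)
        )["thirdleveldomain"]
        # Chunk the subdomain series into fixed-size windows for the metrics pass.
        windows = [subs[i : i + WINDOW_SIZE] for i in range(0, len(subs), WINDOW_SIZE)]
        print(user, domain, [w.to_list() for w in windows])
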
diff --git a/src/train/model.py b/src/train/model.py
index d2b817f4..ad91752d 100644
--- a/src/train/model.py
+++ b/src/train/model.py
@@ -84,9 +84,9 @@ def __init__(
try:
X, y = self._load_npy(ds.name)
except FileNotFoundError:
- if ds.name == 'domainator':
+ if ds.name == "domainator":
ds.data = self.processor.transform_domainator(x=ds.data)
- data = ds.data.drop('query')
+ data = ds.data.drop("query")
else:
data = self.processor.transform(x=ds.data)
X = data.drop("class").to_numpy()
diff --git a/src/train/train.py b/src/train/train.py
index dbab9e2d..2a1d929d 100644
--- a/src/train/train.py
+++ b/src/train/train.py
@@ -41,7 +41,7 @@ class DatasetEnum(str, Enum):
COMBINE = "combine"
DGTA = "dgta"
DGARCHIVE = "dgarchive"
- DOMAINATOR = 'domainator'
+ DOMAINATOR = "domainator"
@unique
diff --git a/tests/detector/test_domainator_detector.py b/tests/detector/test_domainator_detector.py
index ae25f5ae..81096141 100644
--- a/tests/detector/test_domainator_detector.py
+++ b/tests/detector/test_domainator_detector.py
@@ -5,6 +5,7 @@
import os
import sys
+
sys.path.append(os.getcwd())
from src.detector.plugins.domainator_detector import DomainatorDetector
@@ -116,7 +117,9 @@ def test_get_features_basic_attributes(self):
detector = self._create_detector(mock_kafka, mock_ch)
# Test with various 'google.com' subdomains
- features = detector._get_features(["sub1.google.com", "sub2.google.com", "sub3.google.com"])
+ features = detector._get_features(
+ ["sub1.google.com", "sub2.google.com", "sub3.google.com"]
+ )
# Basic features: label_length, label_max, label_average
leven_dist = features[0][0] # Levenshtein distance
@@ -124,7 +127,7 @@ def test_get_features_basic_attributes(self):
lcs = features[0][6] # Longest common string
self.assertEqual(leven_dist, 0.75)
- self.assertAlmostEqual(jaro_dist, 0.833, 3) # Rounded to 3 decimal places
+ self.assertAlmostEqual(jaro_dist, 0.833, 3) # Rounded to 3 decimal places
self.assertEqual(lcs, 0.75)
def test_get_features_empty_domains(self):
@@ -138,13 +141,25 @@ def test_get_features_empty_domains(self):
print(features[0][0], features[0][1], features[0][2])
# Basic features
- self.assertEqual(features[0][0], 1.) # Levenshtein distance of empty strings is 1
- self.assertEqual(features[0][1], 1.) # Jaro distance of empty strings is 1
- self.assertEqual(features[0][2], 1.) # Jaro distance on the reverse empty strings is 1
- self.assertEqual(features[0][3], 1.) # Jaro-Winkler distance of empty strings is 1
- self.assertEqual(features[0][4], 1.) # Jaro-Winkler distance on the reverse empty strings is 1
- self.assertEqual(features[0][5], 0.) # Longest common sequence of empty strings is 0
- self.assertEqual(features[0][6], 0.) # Longest common string of empty strings is 0
+ self.assertEqual(
+ features[0][0], 1.0
+ ) # Levenshtein distance of empty strings is 1
+ self.assertEqual(features[0][1], 1.0) # Jaro distance of empty strings is 1
+ self.assertEqual(
+ features[0][2], 1.0
+ ) # Jaro distance on the reverse empty strings is 1
+ self.assertEqual(
+ features[0][3], 1.0
+ ) # Jaro-Winkler distance of empty strings is 1
+ self.assertEqual(
+ features[0][4], 1.0
+ ) # Jaro-Winkler distance on the reverse empty strings is 1
+ self.assertEqual(
+ features[0][5], 0.0
+ ) # Longest common sequence of empty strings is 0
+ self.assertEqual(
+ features[0][6], 0.0
+ ) # Longest common string of empty strings is 0
def test_get_features_single_same_character(self):
"""Test handling of single character domain."""
@@ -155,13 +170,25 @@ def test_get_features_single_same_character(self):
features = detector._get_features(["a", "a", "a"])
# Basic features
- self.assertEqual(features[0][0], 1.) # Levenshtein distance of same strings is 1
- self.assertEqual(features[0][1], 1.) # Jaro distance of same strings is 1
- self.assertEqual(features[0][2], 1.) # Jaro distance on the reverse same strings is 1
- self.assertEqual(features[0][3], 1.) # Jaro-Winkler distance of same strings is 1
- self.assertEqual(features[0][4], 1.) # Jaro-Winkler distance on the reverse same strings is 1
- self.assertEqual(features[0][5], 0.) # Longest common sequence of same strings is 0
- self.assertEqual(features[0][6], 0.) # Longest common string of same strings is 0
+ self.assertEqual(
+ features[0][0], 1.0
+ ) # Levenshtein distance of same strings is 1
+ self.assertEqual(features[0][1], 1.0) # Jaro distance of same strings is 1
+ self.assertEqual(
+ features[0][2], 1.0
+ ) # Jaro distance on the reverse same strings is 1
+ self.assertEqual(
+ features[0][3], 1.0
+ ) # Jaro-Winkler distance of same strings is 1
+ self.assertEqual(
+ features[0][4], 1.0
+ ) # Jaro-Winkler distance on the reverse same strings is 1
+ self.assertEqual(
+ features[0][5], 0.0
+ ) # Longest common sequence of same strings is 0
+ self.assertEqual(
+ features[0][6], 0.0
+ ) # Longest common string of same strings is 0
def test_get_features_feature_vector_shape(self):
"""Test that the feature vector has the expected shape."""
@@ -169,7 +196,9 @@ def test_get_features_feature_vector_shape(self):
mock_ch = MagicMock()
detector = self._create_detector(mock_kafka, mock_ch)
- features = detector._get_features(["test.domain.com", "test.domain.com", "test.domain.com"])
+ features = detector._get_features(
+ ["test.domain.com", "test.domain.com", "test.domain.com"]
+ )
expected_entropy = 7
@@ -181,12 +210,16 @@ def test_get_features_case_insensitivity(self):
mock_ch = MagicMock()
detector = self._create_detector(mock_kafka, mock_ch)
- features_upper = detector._get_features(["DRIVE.GOOGLE.COM", "WORKSPACE.GOOGLE.COM"])
- features_lower = detector._get_features(["drive.google.com", "workspace.google.com"])
+ features_upper = detector._get_features(
+ ["DRIVE.GOOGLE.COM", "WORKSPACE.GOOGLE.COM"]
+ )
+ features_lower = detector._get_features(
+ ["drive.google.com", "workspace.google.com"]
+ )
# The comparison features should be identical regardless of case
np.testing.assert_array_almost_equal(
features_upper[0][0:],
features_lower[0][0:],
decimal=5,
- )
\ No newline at end of file
+ )
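
The empty-string and identical-string expectations in these tests follow from the similarity libraries' conventions as exercised by _get_features: the edit-distance ratios treat two empty (or identical) strings as maximally similar, while the LCS features short-circuit to 0.0 through the length guard before pylcs is ever called. A quick check, assuming the same Levenshtein package as in the requirements files:

import Levenshtein

assert Levenshtein.ratio("", "") == 1.0
assert Levenshtein.jaro("", "") == 1.0
assert Levenshtein.jaro_winkler("", "", prefix_weight=0.2) == 1.0
assert Levenshtein.ratio("a", "a") == 1.0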