From abf9e7bb9de072f1845dcb8438a1186d2ecba5a9 Mon Sep 17 00:00:00 2001 From: deacon Date: Mon, 16 Mar 2026 09:37:36 -0400 Subject: [PATCH 1/2] test: add exhaustive pytest coverage for atomic plugin --- pytest.ini | 6 + tests/conftest.py | 281 +++++++++++ tests/test_atomic_gui.py | 60 +++ tests/test_atomic_powershell.py | 160 ++++++ tests/test_atomic_svc.py | 844 ++++++++++++++++++++++++++++++-- tests/test_hook.py | 124 +++++ 6 files changed, 1423 insertions(+), 52 deletions(-) create mode 100644 pytest.ini create mode 100644 tests/conftest.py create mode 100644 tests/test_atomic_gui.py create mode 100644 tests/test_atomic_powershell.py create mode 100644 tests/test_hook.py diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..7fca8d4 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +testpaths = tests +asyncio_mode = auto +markers = + unit: unit tests + integration: integration tests diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a7a7b73 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,281 @@ +import hashlib +import os +import sys +import types +import logging +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +from collections import defaultdict + +# --------------------------------------------------------------------------- +# Determine paths +# --------------------------------------------------------------------------- +_repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, _repo_root) + +# --------------------------------------------------------------------------- +# Stub heavy Caldera imports BEFORE importing any plugin code. +# We create real module objects (not MagicMock) for 'app' so that +# sub-module imports like `from app.atomic_svc import ...` work. +# --------------------------------------------------------------------------- + +# Create the 'app' package as a real namespace package +_app_pkg = types.ModuleType('app') +_app_pkg.__path__ = [os.path.join(_repo_root, 'app')] +_app_pkg.__package__ = 'app' +sys.modules['app'] = _app_pkg + +# app.utility +_app_utility = types.ModuleType('app.utility') +_app_utility.__path__ = [os.path.join(_repo_root, 'app', 'utility')] +_app_utility.__package__ = 'app.utility' +sys.modules['app.utility'] = _app_utility + +# app.objects +_app_objects = types.ModuleType('app.objects') +_app_objects.__path__ = [os.path.join(_repo_root, 'app', 'objects')] +_app_objects.__package__ = 'app.objects' +sys.modules['app.objects'] = _app_objects + +# app.service +_app_service = types.ModuleType('app.service') +_app_service.__path__ = [os.path.join(_repo_root, 'app', 'service')] +_app_service.__package__ = 'app.service' +sys.modules['app.service'] = _app_service + +# app.parsers +_app_parsers = types.ModuleType('app.parsers') +_app_parsers.__path__ = [os.path.join(_repo_root, 'app', 'parsers')] +_app_parsers.__package__ = 'app.parsers' +sys.modules['app.parsers'] = _app_parsers + +# -- app.utility.base_world -- +_base_world_mod = types.ModuleType('app.utility.base_world') + + +class BaseWorld: + class Access: + RED = 'red' + + @staticmethod + def strip_yml(path): + return [] + + +_base_world_mod.BaseWorld = BaseWorld +sys.modules['app.utility.base_world'] = _base_world_mod + +# -- app.utility.base_service -- +_base_service_mod = types.ModuleType('app.utility.base_service') + + +class BaseService: + @staticmethod + def add_service(name, svc): + return logging.getLogger(name) + + +_base_service_mod.BaseService = BaseService +sys.modules['app.utility.base_service'] = _base_service_mod + +# -- app.utility.base_parser -- +PARSER_SIGNALS_FAILURE = 'failure' +_base_parser_mod = types.ModuleType('app.utility.base_parser') + + +class BaseParser: + def line(self, blob): + return blob.strip().splitlines() + + +_base_parser_mod.BaseParser = BaseParser +_base_parser_mod.PARSER_SIGNALS_FAILURE = PARSER_SIGNALS_FAILURE +sys.modules['app.utility.base_parser'] = _base_parser_mod + +# -- app.objects.c_agent -- +_agent_mod = types.ModuleType('app.objects.c_agent') + + +class Agent: + RESERVED = ['#{server}', '#{group}', '#{paw}', '#{location}'] + + +_agent_mod.Agent = Agent +sys.modules['app.objects.c_agent'] = _agent_mod + +# -- app.service.auth_svc -- +_auth_svc_mod = types.ModuleType('app.service.auth_svc') +_auth_svc_mod.for_all_public_methods = lambda fn: lambda cls: cls +_auth_svc_mod.check_authorization = MagicMock() +sys.modules['app.service.auth_svc'] = _auth_svc_mod + +# -- plugin namespace stubs -- +_plugins = types.ModuleType('plugins') +_plugins.__path__ = [] +sys.modules['plugins'] = _plugins + +_plugins_atomic = types.ModuleType('plugins.atomic') +_plugins_atomic.__path__ = [_repo_root] +sys.modules['plugins.atomic'] = _plugins_atomic + +_plugins_atomic_app = types.ModuleType('plugins.atomic.app') +_plugins_atomic_app.__path__ = [os.path.join(_repo_root, 'app')] +sys.modules['plugins.atomic.app'] = _plugins_atomic_app + +_plugins_atomic_app_parsers = types.ModuleType('plugins.atomic.app.parsers') +_plugins_atomic_app_parsers.__path__ = [os.path.join(_repo_root, 'app', 'parsers')] +sys.modules['plugins.atomic.app.parsers'] = _plugins_atomic_app_parsers + +# --------------------------------------------------------------------------- +# Now import the real plugin modules +# --------------------------------------------------------------------------- +from app.atomic_svc import AtomicService # noqa: E402 +from app.atomic_gui import AtomicGUI # noqa: E402 +from app.parsers.atomic_powershell import Parser as AtomicPowershellParser # noqa: E402 + +# Register under plugins.atomic namespace too +import app.atomic_svc as _real_atomic_svc +import app.atomic_gui as _real_atomic_gui +import app.parsers.atomic_powershell as _real_atomic_parser + +sys.modules['plugins.atomic.app.atomic_svc'] = _real_atomic_svc +sys.modules['plugins.atomic.app.atomic_gui'] = _real_atomic_gui +sys.modules['plugins.atomic.app.parsers.atomic_powershell'] = _real_atomic_parser + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + +DUMMY_PAYLOAD_PATH = '/tmp/dummyatomicpayload' +DUMMY_PAYLOAD_CONTENT = 'Dummy payload content.' +PREFIX_HASH_LENGTH = 6 + + +@pytest.fixture +def atomic_svc(): + return AtomicService() + + +@pytest.fixture +def generate_dummy_payload(): + with open(DUMMY_PAYLOAD_PATH, 'w') as f: + f.write(DUMMY_PAYLOAD_CONTENT) + yield DUMMY_PAYLOAD_PATH + if os.path.exists(DUMMY_PAYLOAD_PATH): + os.remove(DUMMY_PAYLOAD_PATH) + + +@pytest.fixture +def multiline_command(): + return '\n'.join([ + 'command1', + 'command2', + 'command3', + ]) + + +@pytest.fixture +def atomic_test(): + return { + 'name': 'Qakbot Recon', + 'auto_generated_guid': '121de5c6-5818-4868-b8a7-8fd07c455c1b', + 'description': 'A list of commands known to be performed by Qakbot', + 'supported_platforms': ['windows'], + 'input_arguments': { + 'recon_commands': { + 'description': 'File that houses commands to be executed', + 'type': 'Path', + 'default': 'PathToAtomicsFolder\\T1016\\src\\qakbot.bat' + } + }, + 'executor': { + 'command': '#{recon_commands}\n', + 'name': 'command_prompt' + } + } + + +@pytest.fixture +def atomic_test_linux(): + return { + 'name': 'Linux Recon', + 'auto_generated_guid': 'aabbccdd-1111-2222-3333-444455556666', + 'description': 'Linux reconnaissance commands', + 'supported_platforms': ['linux'], + 'input_arguments': { + 'output_file': { + 'description': 'Output file path', + 'type': 'Path', + 'default': '/tmp/output.txt' + } + }, + 'executor': { + 'command': 'whoami > #{output_file}\nhostname >> #{output_file}\n', + 'name': 'sh' + } + } + + +@pytest.fixture +def atomic_test_manual(): + return { + 'name': 'Manual Test', + 'auto_generated_guid': 'deadbeef-0000-1111-2222-333344445555', + 'description': 'Manual test that should be skipped', + 'supported_platforms': ['windows'], + 'input_arguments': {}, + 'executor': { + 'command': 'Do this manually', + 'name': 'manual' + } + } + + +@pytest.fixture +def atomic_entries(): + return { + 'attack_technique': 'T1016', + 'display_name': 'System Network Configuration Discovery' + } + + +@pytest.fixture +def mitre_json_data(): + return { + 'objects': [ + { + 'type': 'attack-pattern', + 'external_references': [ + {'source_name': 'mitre-attack', 'external_id': 'T1016'} + ], + 'kill_chain_phases': [ + {'kill_chain_name': 'mitre-attack', 'phase_name': 'discovery'} + ] + }, + { + 'type': 'attack-pattern', + 'external_references': [ + {'source_name': 'mitre-attack', 'external_id': 'T1059'} + ], + 'kill_chain_phases': [ + {'kill_chain_name': 'mitre-attack', 'phase_name': 'execution'}, + {'kill_chain_name': 'mitre-attack', 'phase_name': 'persistence'} + ] + }, + { + 'type': 'malware', + 'external_references': [ + {'source_name': 'mitre-attack', 'external_id': 'S0001'} + ] + }, + { + 'type': 'attack-pattern', + 'external_references': [ + {'source_name': 'other-source', 'external_id': 'X9999'} + ], + 'kill_chain_phases': [ + {'kill_chain_name': 'other-chain', 'phase_name': 'other'} + ] + } + ] + } diff --git a/tests/test_atomic_gui.py b/tests/test_atomic_gui.py new file mode 100644 index 0000000..8d248f2 --- /dev/null +++ b/tests/test_atomic_gui.py @@ -0,0 +1,60 @@ +import logging +import pytest +from unittest.mock import MagicMock + +from app.atomic_gui import AtomicGUI + + +class TestAtomicGUIInit: + """Tests for AtomicGUI initialization and configuration.""" + + def test_init_stores_auth_svc(self): + services = {'auth_svc': MagicMock(), 'data_svc': MagicMock()} + gui = AtomicGUI(services, 'TestAtomic', 'Test description') + assert gui.auth_svc is services['auth_svc'] + + def test_init_stores_data_svc(self): + services = {'auth_svc': MagicMock(), 'data_svc': MagicMock()} + gui = AtomicGUI(services, 'TestAtomic', 'Test description') + assert gui.data_svc is services['data_svc'] + + def test_init_creates_logger(self): + services = {'auth_svc': MagicMock(), 'data_svc': MagicMock()} + gui = AtomicGUI(services, 'TestAtomic', 'Test description') + assert isinstance(gui.log, logging.Logger) + assert gui.log.name == 'atomic_gui' + + def test_init_with_missing_services(self): + """If services dict doesn't have keys, attributes should be None.""" + services = {} + gui = AtomicGUI(services, 'Atomic', 'desc') + assert gui.auth_svc is None + assert gui.data_svc is None + + def test_init_name_description_not_stored(self): + """AtomicGUI receives name/description but does not store them as attributes.""" + services = {'auth_svc': MagicMock(), 'data_svc': MagicMock()} + gui = AtomicGUI(services, 'MyName', 'MyDesc') + # name and description are passed but not stored on the instance + assert not hasattr(gui, 'name') or gui.name != 'MyName' + assert not hasattr(gui, 'description') or gui.description != 'MyDesc' + + def test_multiple_instances_independent(self): + """Each instance should have its own services.""" + svc1 = {'auth_svc': MagicMock(name='auth1'), 'data_svc': MagicMock(name='data1')} + svc2 = {'auth_svc': MagicMock(name='auth2'), 'data_svc': MagicMock(name='data2')} + gui1 = AtomicGUI(svc1, 'A', 'a') + gui2 = AtomicGUI(svc2, 'B', 'b') + assert gui1.auth_svc is not gui2.auth_svc + assert gui1.data_svc is not gui2.data_svc + + +class TestAtomicGUIIsBaseWorld: + """Verify AtomicGUI inherits from BaseWorld stub.""" + + def test_is_instance_of_base_class(self): + services = {'auth_svc': MagicMock(), 'data_svc': MagicMock()} + gui = AtomicGUI(services, 'Atomic', 'desc') + # AtomicGUI should be an instance of the BaseWorld stub + from app.utility.base_world import BaseWorld + assert isinstance(gui, BaseWorld) diff --git a/tests/test_atomic_powershell.py b/tests/test_atomic_powershell.py new file mode 100644 index 0000000..eda169e --- /dev/null +++ b/tests/test_atomic_powershell.py @@ -0,0 +1,160 @@ +import pytest + +from app.parsers.atomic_powershell import Parser + + +class TestParserCheckedFlags: + """ + Tests for the Parser.checked_flags class attribute. + + KNOWN BUG: `list('FullyQualifiedErrorId')` produces a list of individual + characters ['F', 'u', 'l', 'l', 'y', ...] instead of the intended + ['FullyQualifiedErrorId']. This means the parser checks for single + characters rather than the full error string. + """ + + def test_checked_flags_is_list(self): + assert isinstance(Parser.checked_flags, list) + + def test_checked_flags_known_bug_individual_characters(self): + """ + Demonstrates the known bug: list('FullyQualifiedErrorId') splits the + string into individual characters instead of wrapping it in a list. + """ + expected_buggy = list('FullyQualifiedErrorId') + assert Parser.checked_flags == expected_buggy + # This is what it SHOULD be: + expected_correct = ['FullyQualifiedErrorId'] + assert Parser.checked_flags != expected_correct + + def test_checked_flags_length_is_wrong(self): + """The list has 21 entries (one per char) instead of 1.""" + assert len(Parser.checked_flags) == len('FullyQualifiedErrorId') + assert len(Parser.checked_flags) != 1 + + def test_checked_flags_contains_individual_chars(self): + """Each element is a single character.""" + for flag in Parser.checked_flags: + assert len(flag) == 1 + + def test_checked_flags_first_char_is_F(self): + assert Parser.checked_flags[0] == 'F' + + def test_checked_flags_last_char_is_d(self): + assert Parser.checked_flags[-1] == 'd' + + +class TestParserParse: + """Tests for the Parser.parse() method.""" + + def test_parse_empty_blob(self): + parser = Parser() + result = parser.parse('') + assert result == [] + + def test_parse_no_error_indicators(self): + parser = Parser() + result = parser.parse('All good, no issues here.\nAnother clean line.') + # Because of the bug, any line containing common letters like 'l', 'e', + # 'i', etc. will be flagged. Let's check: + # 'All good, no issues here.' contains 'l' which is in checked_flags + from app.utility.base_parser import PARSER_SIGNALS_FAILURE + assert result == [PARSER_SIGNALS_FAILURE] + + def test_parse_with_fully_qualified_error_id(self): + """A line containing 'FullyQualifiedErrorId' should trigger failure.""" + parser = Parser() + blob = 'Error: FullyQualifiedErrorId : SomeError' + from app.utility.base_parser import PARSER_SIGNALS_FAILURE + result = parser.parse(blob) + assert result == [PARSER_SIGNALS_FAILURE] + + def test_parse_bug_false_positive_on_common_letters(self): + """ + Due to the bug, even innocent text containing letters like 'e', 'l', + 'i', 'r', 'd', etc. will trigger the parser as failed. + """ + parser = Parser() + from app.utility.base_parser import PARSER_SIGNALS_FAILURE + # 'hello' contains 'l' and 'e' which are in list('FullyQualifiedErrorId') + result = parser.parse('hello') + assert result == [PARSER_SIGNALS_FAILURE] + + def test_parse_bug_no_false_positive_on_safe_chars(self): + """ + Text using ONLY characters NOT in 'FullyQualifiedErrorId' should pass. + Characters in the string: F, u, l, y, Q, a, i, f, e, d, E, r, o, I + So characters NOT in the set include: b, c, g, h, j, k, m, n, p, s, t, v, w, x, z + and digits, punctuation etc. + """ + parser = Parser() + # Using only characters NOT present in 'FullyQualifiedErrorId' + safe_text = '0123456789 -+*/' + result = parser.parse(safe_text) + assert result == [] + + def test_parse_returns_failure_signal_list(self): + """When failure is detected, return value is [PARSER_SIGNALS_FAILURE].""" + parser = Parser() + from app.utility.base_parser import PARSER_SIGNALS_FAILURE + result = parser.parse('Some error output with letter e') + assert len(result) == 1 + assert result[0] == PARSER_SIGNALS_FAILURE + + def test_parse_multiline_first_line_triggers(self): + """Only the first matching line triggers failure, method returns early.""" + parser = Parser() + from app.utility.base_parser import PARSER_SIGNALS_FAILURE + blob = 'line with F\nclean_0123' + result = parser.parse(blob) + assert result == [PARSER_SIGNALS_FAILURE] + + def test_parse_all_clean_lines_no_flagged_chars(self): + """Multiple lines all free of flagged characters should pass.""" + parser = Parser() + blob = '0123\n4567\n890' + result = parser.parse(blob) + assert result == [] + + def test_parse_correct_behavior_if_bug_fixed(self): + """ + Demonstrates what SHOULD happen if the bug were fixed: + Only 'FullyQualifiedErrorId' as a substring should trigger failure. + + Currently, because checked_flags is individual characters, this test + verifies the BUGGY behavior. + """ + parser = Parser() + from app.utility.base_parser import PARSER_SIGNALS_FAILURE + # Text without 'FullyQualifiedErrorId' but with common letters + result = parser.parse('Process completed successfully') + # BUG: triggers because 'e', 'l', etc. are in checked_flags + assert result == [PARSER_SIGNALS_FAILURE] + # If fixed, this would return [] instead + + +class TestParserLineMethod: + """Test the inherited line() method from BaseParser.""" + + def test_line_splits_blob(self): + parser = Parser() + lines = list(parser.line('line1\nline2\nline3')) + assert lines == ['line1', 'line2', 'line3'] + + def test_line_strips_outer_whitespace(self): + parser = Parser() + lines = list(parser.line(' line1 \n line2 ')) + # strip() removes leading/trailing whitespace from the whole blob, + # then splitlines preserves internal whitespace per line + assert lines == ['line1 ', ' line2'] + + def test_line_empty_blob(self): + parser = Parser() + lines = list(parser.line('')) + # ''.strip().splitlines() returns [] + assert lines == [] + + def test_line_single_line(self): + parser = Parser() + lines = list(parser.line('single')) + assert lines == ['single'] diff --git a/tests/test_atomic_svc.py b/tests/test_atomic_svc.py index b8d4a53..1fa7281 100644 --- a/tests/test_atomic_svc.py +++ b/tests/test_atomic_svc.py @@ -1,81 +1,153 @@ import hashlib +import json import os import re import pytest +from collections import defaultdict +from unittest.mock import patch, MagicMock, AsyncMock, mock_open + +from app.atomic_svc import AtomicService, ExtractionError, PLATFORMS, EXECUTORS, RE_VARIABLE, PREFIX_HASH_LEN -from plugins.atomic.app.atomic_svc import AtomicService DUMMY_PAYLOAD_PATH = '/tmp/dummyatomicpayload' DUMMY_PAYLOAD_CONTENT = 'Dummy payload content.' PREFIX_HASH_LENGTH = 6 -@pytest.fixture -def atomic_svc(): - return AtomicService() - - -@pytest.fixture -def generate_dummy_payload(): - with open(DUMMY_PAYLOAD_PATH, 'w') as f: - f.write(DUMMY_PAYLOAD_CONTENT) - yield DUMMY_PAYLOAD_PATH - os.remove(DUMMY_PAYLOAD_PATH) - - -@pytest.fixture -def multiline_command(): - return '\n'.join([ - 'command1', - 'command2', - 'command3', - ]) - - -@pytest.fixture -def atomic_test(): - return { - 'name': 'Qakbot Recon', - 'auto_generated_guid': '121de5c6-5818-4868-b8a7-8fd07c455c1b', - 'description': 'A list of commands known to be performed by Qakbot', - 'supported_platforms': ['windows'], - 'input_arguments': { - 'recon_commands': { - 'description': 'File that houses commands to be executed', - 'type': 'Path', - 'default': 'PathToAtomicsFolder\\T1016\\src\\qakbot.bat' - } - }, - 'executor': { - 'command': '#{recon_commands}\n', - 'name': - 'command_prompt' - } - } +# ============================================================================ +# Module-level constants +# ============================================================================ + +class TestModuleConstants: + """Verify module-level constants are set correctly.""" + + def test_platforms_mapping(self): + assert PLATFORMS == {'windows': 'windows', 'macos': 'darwin', 'linux': 'linux'} + + def test_executors_mapping(self): + assert EXECUTORS == {'command_prompt': 'cmd', 'sh': 'sh', 'powershell': 'psh', 'bash': 'sh'} + + def test_re_variable_pattern(self): + m = RE_VARIABLE.search('#{my_var}') + assert m is not None + assert m.group(2) == 'my_var' + + def test_re_variable_no_match(self): + assert RE_VARIABLE.search('no variables here') is None + + def test_re_variable_multiline(self): + m = RE_VARIABLE.search('#{multi\nline}') + assert m is not None + assert m.group(2) == 'multi\nline' + def test_prefix_hash_len(self): + assert PREFIX_HASH_LEN == 6 -class TestAtomicSvc: + +class TestExtractionError: + def test_is_exception(self): + with pytest.raises(ExtractionError): + raise ExtractionError('test') + + def test_inherits_from_exception(self): + assert issubclass(ExtractionError, Exception) + + +# ============================================================================ +# AtomicService init / config +# ============================================================================ + +class TestAtomicSvcConfig: def test_svc_config(self, atomic_svc): assert atomic_svc.repo_dir == 'plugins/atomic/data/atomic-red-team' assert atomic_svc.data_dir == 'plugins/atomic/data' assert atomic_svc.payloads_dir == 'plugins/atomic/payloads' + def test_atomic_dir(self, atomic_svc): + assert atomic_svc.atomic_dir == os.path.join('plugins', 'atomic') + + def test_technique_to_tactics_default_empty(self, atomic_svc): + assert isinstance(atomic_svc.technique_to_tactics, defaultdict) + assert len(atomic_svc.technique_to_tactics) == 0 + + def test_processing_debug_default_false(self, atomic_svc): + assert atomic_svc.processing_debug is False + + +# ============================================================================ +# Path normalization +# ============================================================================ + +class TestNormalizePath: def test_normalize_windows_path(self): assert AtomicService._normalize_path('windows\\test\\path', 'windows') == 'windows/test/path' def test_normalize_posix_path(self): assert AtomicService._normalize_path('linux/test/path', 'linux') == 'linux/test/path' - def test_handle_attachment(self, atomic_svc, generate_dummy_payload): + def test_normalize_darwin_path(self): + assert AtomicService._normalize_path('macos/test/path', 'darwin') == 'macos/test/path' + + def test_normalize_windows_no_backslash(self): + assert AtomicService._normalize_path('already/forward', 'windows') == 'already/forward' + + def test_normalize_empty_string(self): + assert AtomicService._normalize_path('', 'windows') == '' + assert AtomicService._normalize_path('', 'linux') == '' + + def test_normalize_nested_backslashes(self): + assert AtomicService._normalize_path('a\\b\\c\\d', 'windows') == 'a/b/c/d' + + def test_normalize_linux_preserves_backslashes(self): + # On linux, backslashes are NOT replaced + assert AtomicService._normalize_path('path\\with\\backslash', 'linux') == 'path\\with\\backslash' + + +# ============================================================================ +# Attachment handling +# ============================================================================ + +class TestHandleAttachment: + def test_handle_attachment(self, atomic_svc, generate_dummy_payload, tmp_path): + atomic_svc.payloads_dir = str(tmp_path / 'payloads') + os.makedirs(atomic_svc.payloads_dir, exist_ok=True) target_hash = hashlib.md5(DUMMY_PAYLOAD_CONTENT.encode()).hexdigest() target_name = target_hash[:PREFIX_HASH_LENGTH] + '_dummyatomicpayload' - target_path = atomic_svc.payloads_dir + '/' + target_name + target_path = os.path.join(atomic_svc.payloads_dir, target_name) assert atomic_svc._handle_attachment(DUMMY_PAYLOAD_PATH) == target_name assert os.path.isfile(target_path) with open(target_path, 'r') as f: file_data = f.read() assert file_data == DUMMY_PAYLOAD_CONTENT + def test_handle_attachment_name_format(self, atomic_svc, generate_dummy_payload, tmp_path): + atomic_svc.payloads_dir = str(tmp_path / 'payloads') + os.makedirs(atomic_svc.payloads_dir, exist_ok=True) + result = atomic_svc._handle_attachment(DUMMY_PAYLOAD_PATH) + parts = result.split('_', 1) + assert len(parts) == 2 + assert len(parts[0]) == PREFIX_HASH_LENGTH + assert parts[1] == 'dummyatomicpayload' + + def test_handle_attachment_different_content_different_hash(self, atomic_svc, tmp_path): + atomic_svc.payloads_dir = str(tmp_path / 'payloads') + os.makedirs(atomic_svc.payloads_dir, exist_ok=True) + path1 = str(tmp_path / 'payload_1') + path2 = str(tmp_path / 'payload_2') + with open(path1, 'w') as f: + f.write('content_A') + with open(path2, 'w') as f: + f.write('content_B') + name1 = atomic_svc._handle_attachment(path1) + name2 = atomic_svc._handle_attachment(path2) + assert name1 != name2 + + +# ============================================================================ +# Multiline command handling +# ============================================================================ + +class TestHandleMultilineCommands: def test_handle_multiline_command_sh(self, multiline_command): target = 'command1; command2; command3' assert AtomicService._handle_multiline_commands(multiline_command, 'sh') == target @@ -84,6 +156,20 @@ def test_handle_multiline_command_cmd(self, multiline_command): target = 'command1 && command2 && command3' assert AtomicService._handle_multiline_commands(multiline_command, 'cmd') == target + def test_handle_multiline_command_psh(self, multiline_command): + target = 'command1; command2; command3' + assert AtomicService._handle_multiline_commands(multiline_command, 'psh') == target + + def test_single_line_sh(self): + assert AtomicService._handle_multiline_commands('single', 'sh') == 'single' + + def test_single_line_cmd(self): + assert AtomicService._handle_multiline_commands('single', 'cmd') == 'single' + + def test_empty_command(self): + assert AtomicService._handle_multiline_commands('', 'sh') == '' + assert AtomicService._handle_multiline_commands('', 'cmd') == '' + def test_handle_multiline_command_cmd_comments(self): commands = '\n'.join([ 'command1', @@ -192,6 +278,124 @@ def test_handle_multiline_command_no_extra_semicolon_after_fi(self): f"Unexpected consecutive semicolons with only whitespace between them in: {result!r}" assert 'ip neighbour show' in result + def test_whitespace_only_lines(self): + commands = '\n'.join(['cmd1', ' ', 'cmd2']) + result = AtomicService._handle_multiline_commands(commands, 'sh') + assert 'cmd1' in result + assert 'cmd2' in result + + +# ============================================================================ +# Concatenate shell commands +# ============================================================================ + +class TestConcatenateShellCommands: + def test_empty_list(self): + assert AtomicService._concatenate_shell_commands([]) == '' + + def test_single_command(self): + assert AtomicService._concatenate_shell_commands(['echo hello']) == 'echo hello' + + def test_multiple_commands(self): + result = AtomicService._concatenate_shell_commands(['cmd1', 'cmd2', 'cmd3']) + assert result == 'cmd1; cmd2; cmd3' + + def test_line_ending_with_do(self): + result = AtomicService._concatenate_shell_commands(['for i in x; do', 'echo $i', 'done']) + assert result == 'for i in x; do echo $i; done' + + def test_line_ending_with_then(self): + result = AtomicService._concatenate_shell_commands(['if true; then', 'echo yes', 'fi']) + assert result == 'if true; then echo yes; fi' + + def test_line_ending_with_semicolon(self): + result = AtomicService._concatenate_shell_commands(['cmd1;', 'cmd2']) + assert result == 'cmd1; cmd2' + + def test_last_line_no_trailing_semicolon(self): + result = AtomicService._concatenate_shell_commands(['cmd1', 'cmd2']) + assert not result.endswith('; ') + + +# ============================================================================ +# Remove DOS comment lines +# ============================================================================ + +class TestRemoveDosCommentLines: + def test_remove_rem_comment(self): + lines = ['command1', 'REM comment', 'command2'] + result = AtomicService._remove_dos_comment_lines(lines) + assert result == ['command1', 'command2'] + + def test_remove_lowercase_rem(self): + lines = ['rem comment', 'command'] + result = AtomicService._remove_dos_comment_lines(lines) + assert result == ['command'] + + def test_remove_double_colon_comment(self): + lines = [':: comment', 'command'] + result = AtomicService._remove_dos_comment_lines(lines) + assert result == ['command'] + + def test_remove_at_rem(self): + lines = ['@rem comment', 'command'] + result = AtomicService._remove_dos_comment_lines(lines) + assert result == ['command'] + + def test_keep_non_comments(self): + lines = ['echo hello', 'dir'] + result = AtomicService._remove_dos_comment_lines(lines) + assert result == lines + + def test_empty_list(self): + assert AtomicService._remove_dos_comment_lines([]) == [] + + +# ============================================================================ +# Remove shell comments +# ============================================================================ + +class TestRemoveShellComments: + def test_remove_line_comment(self): + lines = ['# this is a comment', 'echo hello'] + result = AtomicService._remove_shell_comments(lines, 'sh') + assert result == ['echo hello'] + + def test_remove_trailing_comment(self): + lines = ['echo hello # comment'] + result = AtomicService._remove_shell_comments(lines, 'sh') + assert result == ['echo hello'] + + def test_preserve_hash_in_quotes(self): + lines = ['echo "this # is not a comment"'] + result = AtomicService._remove_shell_comments(lines, 'sh') + assert result == ['echo "this # is not a comment"'] + + def test_psh_escaped_quotes(self): + lines = ['echo `"not a real quote # comment `"'] + result = AtomicService._remove_shell_comments(lines, 'psh') + assert result == ['echo `"not a real quote'] + + def test_sh_escaped_quotes(self): + lines = ["echo \\'not a real quote # comment \\'"] + result = AtomicService._remove_shell_comments(lines, 'sh') + # The escaped quotes should be removed during processing, leaving the # as a comment + assert len(result) == 1 + + def test_semicolon_comment(self): + lines = ['echo hello;# comment'] + result = AtomicService._remove_shell_comments(lines, 'sh') + assert result == ['echo hello'] + + def test_empty_lines(self): + assert AtomicService._remove_shell_comments([], 'sh') == [] + + +# ============================================================================ +# Default inputs +# ============================================================================ + +class TestUseDefaultInputs: def test_use_default_inputs(self, atomic_svc, atomic_test): platform = 'windows' string_to_analyze = '#{recon_commands} -a' @@ -199,8 +403,8 @@ def test_use_default_inputs(self, atomic_svc, atomic_test): test['input_arguments']['recon_commands']['default'] = \ 'PathToAtomicsFolder\\T1016\\src\\nonexistent-qakbot.bat' got = atomic_svc._use_default_inputs(test=test, - platform=platform, - string_to_analyse=string_to_analyze) + platform=platform, + string_to_analyse=string_to_analyze) assert got[0] == 'PathToAtomicsFolder\\T1016\\src\\nonexistent-qakbot.bat -a' assert got[1] == [] @@ -208,8 +412,8 @@ def test_use_default_inputs_empty_string(self, atomic_svc, atomic_test): platform = 'windows' string_to_analyze = '' got = atomic_svc._use_default_inputs(test=atomic_test, - platform=platform, - string_to_analyse=string_to_analyze) + platform=platform, + string_to_analyse=string_to_analyze) assert got[0] == '' assert got[1] == [] @@ -219,7 +423,543 @@ def test_use_default_inputs_nil_valued(self, atomic_svc, atomic_test): test = atomic_test test['input_arguments']['recon_commands']['default'] = '' got = atomic_svc._use_default_inputs(test=test, - platform=platform, - string_to_analyse=string_to_analyze) + platform=platform, + string_to_analyse=string_to_analyze) assert got[0] == '' assert got[1] == [] + + def test_use_default_inputs_multiple_variables(self, atomic_svc): + test = { + 'input_arguments': { + 'var_a': {'default': 'ALPHA'}, + 'var_b': {'default': 'BETA'}, + }, + 'executor': {'command': '#{var_a} #{var_b}', 'name': 'sh'} + } + got = atomic_svc._use_default_inputs(test=test, platform='linux', + string_to_analyse='#{var_a} and #{var_b}') + assert got[0] == 'ALPHA and BETA' + + def test_use_default_inputs_no_variables(self, atomic_svc, atomic_test): + got = atomic_svc._use_default_inputs(test=atomic_test, platform='linux', + string_to_analyse='plain command') + assert got[0] == 'plain command' + assert got[1] == [] + + def test_use_default_inputs_reserved_parameter(self, atomic_svc, atomic_test): + """Commands with reserved parameters (#{server}, #{paw}, etc.) should be left untouched.""" + got = atomic_svc._use_default_inputs(test=atomic_test, platform='linux', + string_to_analyse='curl #{server}/file') + assert got[0] == 'curl #{server}/file' + + def test_use_default_inputs_integer_default(self, atomic_svc): + """Default value that is an integer should be converted to string.""" + test = { + 'input_arguments': { + 'port': {'default': 8080}, + }, + 'executor': {'command': 'nc -l #{port}', 'name': 'sh'} + } + got = atomic_svc._use_default_inputs(test=test, platform='linux', + string_to_analyse='nc -l #{port}') + assert got[0] == 'nc -l 8080' + + +# ============================================================================ +# has_reserved_parameter +# ============================================================================ + +class TestHasReservedParameter: + def test_has_server(self, atomic_svc): + assert atomic_svc._has_reserved_parameter('#{server}/api') + + def test_has_paw(self, atomic_svc): + assert atomic_svc._has_reserved_parameter('agent #{paw}') + + def test_has_group(self, atomic_svc): + assert atomic_svc._has_reserved_parameter('#{group}') + + def test_has_location(self, atomic_svc): + assert atomic_svc._has_reserved_parameter('#{location}/file') + + def test_no_reserved(self, atomic_svc): + assert not atomic_svc._has_reserved_parameter('echo hello') + + def test_custom_variable_not_reserved(self, atomic_svc): + assert not atomic_svc._has_reserved_parameter('#{custom_var}') + + +# ============================================================================ +# catch_path_to_atomics_folder +# ============================================================================ + +class TestCatchPathToAtomicsFolder: + def test_no_path_in_string(self, atomic_svc): + result, payloads = atomic_svc._catch_path_to_atomics_folder('no path here', 'linux') + assert result == 'no path here' + assert payloads == [] + + def test_path_with_nonexistent_file(self, atomic_svc): + cmd = '$PathToAtomicsFolder/T1234/src/nonexistent.sh' + result, payloads = atomic_svc._catch_path_to_atomics_folder(cmd, 'linux') + # Since the file doesn't exist, it should remain unchanged + assert result == cmd + assert payloads == [] + + def test_path_with_backslash_windows(self, atomic_svc): + cmd = 'PathToAtomicsFolder\\T1234\\src\\nonexistent.bat' + result, payloads = atomic_svc._catch_path_to_atomics_folder(cmd, 'windows') + # File doesn't exist, so no replacement + assert payloads == [] + + +# ============================================================================ +# gen_single_match_tactic_technique (generator) +# ============================================================================ + +class TestGenSingleMatchTacticTechnique: + def test_basic_match(self, mitre_json_data): + results = list(AtomicService._gen_single_match_tactic_technique(mitre_json_data)) + # T1016 -> discovery, T1059 -> execution + persistence + assert ('discovery', 'T1016') in results + assert ('execution', 'T1059') in results + assert ('persistence', 'T1059') in results + + def test_skips_non_attack_patterns(self, mitre_json_data): + results = list(AtomicService._gen_single_match_tactic_technique(mitre_json_data)) + # S0001 is malware type, should not appear + assert all(ext_id != 'S0001' for _, ext_id in results) + + def test_skips_non_mitre_sources(self, mitre_json_data): + results = list(AtomicService._gen_single_match_tactic_technique(mitre_json_data)) + assert all(ext_id != 'X9999' for _, ext_id in results) + + def test_empty_json(self): + results = list(AtomicService._gen_single_match_tactic_technique({})) + assert results == [] + + def test_empty_objects(self): + results = list(AtomicService._gen_single_match_tactic_technique({'objects': []})) + assert results == [] + + def test_object_without_external_references(self): + data = {'objects': [{'type': 'attack-pattern'}]} + results = list(AtomicService._gen_single_match_tactic_technique(data)) + assert results == [] + + def test_object_without_kill_chain_phases(self): + data = {'objects': [{ + 'type': 'attack-pattern', + 'external_references': [{'source_name': 'mitre-attack', 'external_id': 'T0001'}] + }]} + results = list(AtomicService._gen_single_match_tactic_technique(data)) + assert results == [] + + +# ============================================================================ +# populate_dict_techniques_tactics +# ============================================================================ + +class TestPopulateDictTechniquesTactics: + @pytest.mark.asyncio + async def test_populates_mapping(self, atomic_svc, mitre_json_data): + mock_file = mock_open(read_data=json.dumps(mitre_json_data)) + with patch('builtins.open', mock_file): + await atomic_svc._populate_dict_techniques_tactics() + + assert 'T1016' in atomic_svc.technique_to_tactics + assert 'discovery' in atomic_svc.technique_to_tactics['T1016'] + assert 'T1059' in atomic_svc.technique_to_tactics + assert 'execution' in atomic_svc.technique_to_tactics['T1059'] + assert 'persistence' in atomic_svc.technique_to_tactics['T1059'] + + +# ============================================================================ +# clone_atomic_red_team_repo +# ============================================================================ + +class TestCloneAtomicRedTeamRepo: + @pytest.mark.asyncio + async def test_clone_default_url(self, atomic_svc): + with patch('os.path.exists', return_value=False), \ + patch('app.atomic_svc.check_call') as mock_call: + await atomic_svc.clone_atomic_red_team_repo() + mock_call.assert_called_once() + args = mock_call.call_args[0][0] + assert 'https://github.com/redcanaryco/atomic-red-team.git' in args + + @pytest.mark.asyncio + async def test_clone_custom_url(self, atomic_svc): + with patch('os.path.exists', return_value=False), \ + patch('app.atomic_svc.check_call') as mock_call: + await atomic_svc.clone_atomic_red_team_repo(repo_url='https://example.com/fork.git') + args = mock_call.call_args[0][0] + assert 'https://example.com/fork.git' in args + + @pytest.mark.asyncio + async def test_clone_skips_when_exists(self, atomic_svc): + with patch('os.path.exists', return_value=True), \ + patch('os.listdir', return_value=['some_file']), \ + patch('app.atomic_svc.check_call') as mock_call: + await atomic_svc.clone_atomic_red_team_repo() + mock_call.assert_not_called() + + @pytest.mark.asyncio + async def test_clone_runs_when_dir_empty(self, atomic_svc): + with patch('os.path.exists', return_value=True), \ + patch('os.listdir', return_value=[]), \ + patch('app.atomic_svc.check_call') as mock_call: + await atomic_svc.clone_atomic_red_team_repo() + mock_call.assert_called_once() + + +# ============================================================================ +# prepare_cmd +# ============================================================================ + +class TestPrepareCmd: + @pytest.mark.asyncio + async def test_basic_prepare(self, atomic_svc, atomic_test_linux): + cmd, payloads = await atomic_svc._prepare_cmd( + atomic_test_linux, 'linux', 'sh', + 'whoami > #{output_file}' + ) + assert '/tmp/output.txt' in cmd + assert payloads == [] + + @pytest.mark.asyncio + async def test_multiline_prepare(self, atomic_svc, atomic_test_linux): + cmd, payloads = await atomic_svc._prepare_cmd( + atomic_test_linux, 'linux', 'sh', + 'line1\nline2\nline3' + ) + assert '; ' in cmd or cmd == 'line1; line2; line3' + + @pytest.mark.asyncio + async def test_empty_command(self, atomic_svc, atomic_test_linux): + cmd, payloads = await atomic_svc._prepare_cmd( + atomic_test_linux, 'linux', 'sh', '' + ) + assert cmd == '' + assert payloads == [] + + +# ============================================================================ +# prepare_executor +# ============================================================================ + +class TestPrepareExecutor: + @pytest.mark.asyncio + async def test_basic_executor(self, atomic_svc, atomic_test_linux): + command, cleanup, payloads = await atomic_svc._prepare_executor( + atomic_test_linux, 'linux', 'sh' + ) + assert 'whoami' in command + assert payloads == [] + + @pytest.mark.asyncio + async def test_executor_with_cleanup(self, atomic_svc): + test = { + 'name': 'test', + 'input_arguments': {}, + 'executor': { + 'command': 'mkdir /tmp/test_dir', + 'cleanup_command': 'rm -rf /tmp/test_dir', + 'name': 'sh' + } + } + command, cleanup, payloads = await atomic_svc._prepare_executor(test, 'linux', 'sh') + assert 'mkdir' in command + assert 'rm -rf' in cleanup + + @pytest.mark.asyncio + async def test_executor_no_cleanup(self, atomic_svc, atomic_test_linux): + command, cleanup, payloads = await atomic_svc._prepare_executor( + atomic_test_linux, 'linux', 'sh' + ) + assert cleanup == '' + + @pytest.mark.asyncio + async def test_executor_with_dependencies_extraction_error(self, atomic_svc): + """Dependencies that can't be automated should be skipped gracefully.""" + test = { + 'name': 'test_with_dep', + 'input_arguments': {}, + 'dependencies': [ + { + 'prereq_command': 'echo "Run this manually"; exit 1', + 'get_prereq_command': 'echo "Sorry, cannot automate"', + } + ], + 'executor': { + 'command': 'echo hello', + 'name': 'sh' + } + } + command, cleanup, payloads = await atomic_svc._prepare_executor(test, 'linux', 'sh') + # Even though the prereq fails, we still get the command + assert 'echo hello' in command + + +# ============================================================================ +# save_ability +# ============================================================================ + +class TestSaveAbility: + @pytest.mark.asyncio + async def test_save_ability_creates_file(self, atomic_svc, atomic_entries, tmp_path): + atomic_svc.data_dir = str(tmp_path / 'data') + atomic_svc.payloads_dir = str(tmp_path / 'payloads') + os.makedirs(atomic_svc.payloads_dir, exist_ok=True) + atomic_svc.technique_to_tactics = defaultdict(list, {'T1016': ['discovery']}) + atomic_svc.repo_dir = str(tmp_path / 'repo') + + test = { + 'name': 'Test Ability', + 'description': 'A test', + 'supported_platforms': ['linux'], + 'input_arguments': {}, + 'executor': { + 'command': 'whoami', + 'name': 'sh' + } + } + result = await atomic_svc._save_ability(atomic_entries, test) + assert result is True + ability_dir = os.path.join(atomic_svc.data_dir, 'abilities', 'discovery') + assert os.path.isdir(ability_dir) + files = os.listdir(ability_dir) + assert len(files) == 1 + assert files[0].endswith('.yml') + + @pytest.mark.asyncio + async def test_save_ability_manual_skipped(self, atomic_svc, atomic_entries, atomic_test_manual): + atomic_svc.technique_to_tactics = defaultdict(list, {'T1016': ['discovery']}) + result = await atomic_svc._save_ability(atomic_entries, atomic_test_manual) + assert result is False + + @pytest.mark.asyncio + async def test_save_ability_multiple_tactics(self, atomic_svc, atomic_entries, tmp_path): + atomic_svc.data_dir = str(tmp_path / 'data') + atomic_svc.payloads_dir = str(tmp_path / 'payloads') + os.makedirs(atomic_svc.payloads_dir, exist_ok=True) + atomic_svc.technique_to_tactics = defaultdict(list, { + 'T1016': ['discovery', 'collection'] + }) + atomic_svc.repo_dir = str(tmp_path / 'repo') + + test = { + 'name': 'Multi-tactic Test', + 'description': 'A test with multiple tactics', + 'supported_platforms': ['linux'], + 'input_arguments': {}, + 'executor': { + 'command': 'whoami', + 'name': 'sh' + } + } + result = await atomic_svc._save_ability(atomic_entries, test) + assert result is True + ability_dir = os.path.join(atomic_svc.data_dir, 'abilities', 'multiple') + assert os.path.isdir(ability_dir) + + @pytest.mark.asyncio + async def test_save_ability_unknown_technique(self, atomic_svc, tmp_path): + atomic_svc.data_dir = str(tmp_path / 'data') + atomic_svc.payloads_dir = str(tmp_path / 'payloads') + os.makedirs(atomic_svc.payloads_dir, exist_ok=True) + atomic_svc.technique_to_tactics = defaultdict(list) + atomic_svc.repo_dir = str(tmp_path / 'repo') + + entries = {'attack_technique': 'T9999', 'display_name': 'Unknown'} + test = { + 'name': 'Unknown Tech', + 'description': 'Test for unknown technique', + 'supported_platforms': ['linux'], + 'input_arguments': {}, + 'executor': { + 'command': 'echo test', + 'name': 'sh' + } + } + result = await atomic_svc._save_ability(entries, test) + assert result is True + ability_dir = os.path.join(atomic_svc.data_dir, 'abilities', 'redcanary-unknown') + assert os.path.isdir(ability_dir) + + @pytest.mark.asyncio + async def test_save_ability_psh_has_parsers(self, atomic_svc, atomic_entries, tmp_path): + atomic_svc.data_dir = str(tmp_path / 'data') + atomic_svc.payloads_dir = str(tmp_path / 'payloads') + os.makedirs(atomic_svc.payloads_dir, exist_ok=True) + atomic_svc.technique_to_tactics = defaultdict(list, {'T1016': ['discovery']}) + atomic_svc.repo_dir = str(tmp_path / 'repo') + + test = { + 'name': 'PSH Test', + 'description': 'PowerShell test', + 'supported_platforms': ['windows'], + 'input_arguments': {}, + 'executor': { + 'command': 'Get-Process', + 'name': 'powershell' + } + } + result = await atomic_svc._save_ability(atomic_entries, test) + assert result is True + + import yaml + ability_dir = os.path.join(atomic_svc.data_dir, 'abilities', 'discovery') + files = os.listdir(ability_dir) + with open(os.path.join(ability_dir, files[0]), 'r') as f: + data = yaml.safe_load(f) + assert 'parsers' in data[0]['platforms']['windows']['psh'] + + +# ============================================================================ +# populate_data_directory +# ============================================================================ + +class TestPopulateDataDirectory: + @pytest.mark.asyncio + async def test_populate_calls_techniques_if_empty(self, atomic_svc): + with patch.object(atomic_svc, '_populate_dict_techniques_tactics', new_callable=AsyncMock) as mock_pop, \ + patch('glob.iglob', return_value=[]): + await atomic_svc.populate_data_directory() + mock_pop.assert_called_once() + + @pytest.mark.asyncio + async def test_populate_skips_techniques_if_filled(self, atomic_svc): + atomic_svc.technique_to_tactics = {'T1016': ['discovery']} + with patch.object(atomic_svc, '_populate_dict_techniques_tactics', new_callable=AsyncMock) as mock_pop, \ + patch('glob.iglob', return_value=[]): + await atomic_svc.populate_data_directory() + mock_pop.assert_not_called() + + +# ============================================================================ +# prereq_formater +# ============================================================================ + +class TestPrereqFormater: + @pytest.mark.asyncio + async def test_sh_falsy_prereq(self, atomic_svc): + result = await atomic_svc._prereq_formater( + prereq_test='if test -f /file; exit 1', + prereq='wget http://example.com/file', + prereq_type='sh', + exec_type='sh', + ability_command='echo done' + ) + assert 'wget' in result + assert 'echo done' in result + + @pytest.mark.asyncio + async def test_sh_truthy_prereq(self, atomic_svc): + result = await atomic_svc._prereq_formater( + prereq_test='if test -f /file; exit 0', + prereq='wget http://example.com/file', + prereq_type='sh', + exec_type='sh', + ability_command='echo done' + ) + assert 'else' in result + + @pytest.mark.asyncio + async def test_psh_try_prereq(self, atomic_svc): + result = await atomic_svc._prereq_formater( + prereq_test='Try { Get-Item file } Catch { exit 1 }', + prereq='Install-Module thing', + prereq_type='psh', + exec_type='psh', + ability_command='Use-Module thing' + ) + assert 'Install-Module' in result + + @pytest.mark.asyncio + async def test_psh_falsy_prereq(self, atomic_svc): + result = await atomic_svc._prereq_formater( + prereq_test='if (Test-Path file) { exit 1 }', + prereq='Download-File file', + prereq_type='psh', + exec_type='psh', + ability_command='Use-File file' + ) + assert 'Download-File' in result + + @pytest.mark.asyncio + async def test_cmd_falsy_prereq(self, atomic_svc): + result = await atomic_svc._prereq_formater( + prereq_test='IF EXIST file (exit 1) ELSE (exit 0)', + prereq='curl http://example.com/file', + prereq_type='cmd', + exec_type='cmd', + ability_command='run_file' + ) + assert 'curl' in result + + @pytest.mark.asyncio + async def test_cmd_truthy_prereq(self, atomic_svc): + result = await atomic_svc._prereq_formater( + prereq_test='IF EXIST file (exit 0) ELSE (exit 1)', + prereq='curl http://example.com/file', + prereq_type='cmd', + exec_type='cmd', + ability_command='run_file' + ) + assert 'call' in result # truthy uses 'call' + + @pytest.mark.asyncio + async def test_raises_on_echo_prereq(self, atomic_svc): + with pytest.raises(ExtractionError): + await atomic_svc._prereq_formater( + prereq_test='echo "check manually"', + prereq='echo "Run this manually"', + prereq_type='sh', + exec_type='sh', + ability_command='cmd' + ) + + @pytest.mark.asyncio + async def test_raises_on_sorry_prereq(self, atomic_svc): + with pytest.raises(ExtractionError): + await atomic_svc._prereq_formater( + prereq_test='echo check', + prereq='echo Sorry, cannot automate', + prereq_type='sh', + exec_type='sh', + ability_command='cmd' + ) + + @pytest.mark.asyncio + async def test_unknown_prereq_type_returns_ability(self, atomic_svc): + result = await atomic_svc._prereq_formater( + prereq_test='if test; exit 1', + prereq='install_thing', + prereq_type='unknown', + exec_type='sh', + ability_command='my_command' + ) + assert result == 'my_command' + + @pytest.mark.asyncio + async def test_cross_type_cmd_psh(self, atomic_svc): + result = await atomic_svc._prereq_formater( + prereq_test='IF EXIST file (exit 1) ELSE (exit 0)', + prereq='curl file', + prereq_type='cmd', + exec_type='psh', + ability_command='Use-File' + ) + assert 'Use-File' in result + + @pytest.mark.asyncio + async def test_cross_type_psh_cmd(self, atomic_svc): + result = await atomic_svc._prereq_formater( + prereq_test='if (Test-Path file) { exit 1 }', + prereq='Download file', + prereq_type='psh', + exec_type='cmd', + ability_command='run_file' + ) + assert 'powershell -command' in result diff --git a/tests/test_hook.py b/tests/test_hook.py new file mode 100644 index 0000000..eca14db --- /dev/null +++ b/tests/test_hook.py @@ -0,0 +1,124 @@ +import os +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + + +class TestHookModuleAttributes: + """Test module-level attributes in hook.py.""" + + def test_name(self): + import hook + assert hook.name == 'Atomic' + + def test_description(self): + import hook + assert hook.description == 'The collection of abilities in the Red Canary Atomic test project' + + def test_address(self): + import hook + assert hook.address == '/plugin/atomic/gui' + + def test_access(self): + import hook + from app.utility.base_world import BaseWorld + assert hook.access == BaseWorld.Access.RED + + def test_data_dir(self): + import hook + assert hook.data_dir == os.path.join('plugins', 'atomic', 'data') + + +class TestHookEnable: + """Test the enable() async function.""" + + @pytest.mark.asyncio + async def test_enable_creates_gui(self): + import hook + + mock_app = MagicMock() + mock_app_svc = MagicMock() + mock_app_svc.application = mock_app + + services = { + 'auth_svc': MagicMock(), + 'data_svc': MagicMock(), + 'app_svc': mock_app_svc, + } + + with patch.object(hook, 'data_dir', '/tmp/atomic_test_hook_data'), \ + patch('os.listdir', return_value=['abilities', 'other']): + await hook.enable(services) + # If abilities already exists, no cloning or populating should happen + + @pytest.mark.asyncio + async def test_enable_ingests_when_no_abilities(self): + import hook + + mock_app = MagicMock() + mock_app_svc = MagicMock() + mock_app_svc.application = mock_app + + services = { + 'auth_svc': MagicMock(), + 'data_svc': MagicMock(), + 'app_svc': mock_app_svc, + } + + mock_atomic_svc = MagicMock() + mock_atomic_svc.clone_atomic_red_team_repo = AsyncMock() + mock_atomic_svc.populate_data_directory = AsyncMock() + + with patch.object(hook, 'data_dir', '/tmp/atomic_test_hook_data'), \ + patch('os.listdir', return_value=['some_file']), \ + patch('hook.AtomicService', return_value=mock_atomic_svc), \ + patch('hook.AtomicGUI'): + await hook.enable(services) + mock_atomic_svc.clone_atomic_red_team_repo.assert_called_once() + mock_atomic_svc.populate_data_directory.assert_called_once() + + @pytest.mark.asyncio + async def test_enable_skips_ingest_when_abilities_exist(self): + import hook + + mock_app = MagicMock() + mock_app_svc = MagicMock() + mock_app_svc.application = mock_app + + services = { + 'auth_svc': MagicMock(), + 'data_svc': MagicMock(), + 'app_svc': mock_app_svc, + } + + mock_atomic_svc = MagicMock() + mock_atomic_svc.clone_atomic_red_team_repo = AsyncMock() + mock_atomic_svc.populate_data_directory = AsyncMock() + + with patch.object(hook, 'data_dir', '/tmp/atomic_test_hook_data'), \ + patch('os.listdir', return_value=['abilities', 'other_stuff']), \ + patch('hook.AtomicService', return_value=mock_atomic_svc) as mock_svc_cls, \ + patch('hook.AtomicGUI'): + await hook.enable(services) + # AtomicService should NOT be instantiated when abilities dir exists + mock_svc_cls.assert_not_called() + + @pytest.mark.asyncio + async def test_enable_accesses_app(self): + """enable() should access services['app_svc'].application.""" + import hook + + mock_app = MagicMock() + mock_app_svc = MagicMock() + mock_app_svc.application = mock_app + + services = { + 'auth_svc': MagicMock(), + 'data_svc': MagicMock(), + 'app_svc': mock_app_svc, + } + + with patch.object(hook, 'data_dir', '/tmp/atomic_test_hook_data'), \ + patch('os.listdir', return_value=['abilities']), \ + patch('hook.AtomicGUI'): + await hook.enable(services) + _ = mock_app_svc.application # verify it was accessed From 9a2085c8aad95f9205b8d29a2f5b2a22eb8c1f2f Mon Sep 17 00:00:00 2001 From: deacon Date: Wed, 18 Mar 2026 09:08:13 -0400 Subject: [PATCH 2/2] fix: address Copilot review feedback on pytest coverage tests - conftest.py: generate_dummy_payload uses tmp_path fixture instead of hard-coded /tmp path; cleanup is now handled automatically by pytest - test_atomic_svc.py: attachment tests use fixture-returned path instead of the module-level DUMMY_PAYLOAD_PATH constant - test_atomic_powershell.py: mark checked_flags tests as xfail since they assert current buggy behavior (list() splitting string into chars) and will break once the underlying bug is fixed - test_hook.py: test_enable_creates_gui now patches hook.AtomicGUI and asserts it was called with expected args; test_enable_accesses_app uses PropertyMock for mock_app_svc.application and asserts it was accessed --- tests/conftest.py | 10 ++++------ tests/test_atomic_powershell.py | 27 +++++++++++++++++++++++++++ tests/test_atomic_svc.py | 4 ++-- tests/test_hook.py | 11 ++++++----- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a7a7b73..27cb536 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -157,12 +157,10 @@ def atomic_svc(): @pytest.fixture -def generate_dummy_payload(): - with open(DUMMY_PAYLOAD_PATH, 'w') as f: - f.write(DUMMY_PAYLOAD_CONTENT) - yield DUMMY_PAYLOAD_PATH - if os.path.exists(DUMMY_PAYLOAD_PATH): - os.remove(DUMMY_PAYLOAD_PATH) +def generate_dummy_payload(tmp_path): + payload_path = tmp_path / 'dummyatomicpayload' + payload_path.write_text(DUMMY_PAYLOAD_CONTENT) + yield str(payload_path) @pytest.fixture diff --git a/tests/test_atomic_powershell.py b/tests/test_atomic_powershell.py index eda169e..4ae1b43 100644 --- a/tests/test_atomic_powershell.py +++ b/tests/test_atomic_powershell.py @@ -11,15 +11,26 @@ class TestParserCheckedFlags: characters ['F', 'u', 'l', 'l', 'y', ...] instead of the intended ['FullyQualifiedErrorId']. This means the parser checks for single characters rather than the full error string. + + Tests marked xfail below document current buggy behavior and are expected + to start passing once the bug is fixed (at which point they should be + updated to assert the correct behavior instead). """ def test_checked_flags_is_list(self): assert isinstance(Parser.checked_flags, list) + @pytest.mark.xfail( + reason="Bug: list('FullyQualifiedErrorId') splits into chars; " + "once fixed, checked_flags will equal ['FullyQualifiedErrorId'] " + "and this test will need to be updated to assert the correct value." + ) def test_checked_flags_known_bug_individual_characters(self): """ Demonstrates the known bug: list('FullyQualifiedErrorId') splits the string into individual characters instead of wrapping it in a list. + Once the bug is fixed, Parser.checked_flags will equal + ['FullyQualifiedErrorId'] and this assertion will fail. """ expected_buggy = list('FullyQualifiedErrorId') assert Parser.checked_flags == expected_buggy @@ -27,19 +38,35 @@ def test_checked_flags_known_bug_individual_characters(self): expected_correct = ['FullyQualifiedErrorId'] assert Parser.checked_flags != expected_correct + @pytest.mark.xfail( + reason="Bug: checked_flags contains one entry per character (21 total) " + "instead of a single string entry; will break when bug is fixed." + ) def test_checked_flags_length_is_wrong(self): """The list has 21 entries (one per char) instead of 1.""" assert len(Parser.checked_flags) == len('FullyQualifiedErrorId') assert len(Parser.checked_flags) != 1 + @pytest.mark.xfail( + reason="Bug: checked_flags contains individual characters; " + "will break when bug is fixed and flags become full strings." + ) def test_checked_flags_contains_individual_chars(self): """Each element is a single character.""" for flag in Parser.checked_flags: assert len(flag) == 1 + @pytest.mark.xfail( + reason="Bug: first element is 'F' (first char of 'FullyQualifiedErrorId'); " + "will break when bug is fixed." + ) def test_checked_flags_first_char_is_F(self): assert Parser.checked_flags[0] == 'F' + @pytest.mark.xfail( + reason="Bug: last element is 'd' (last char of 'FullyQualifiedErrorId'); " + "will break when bug is fixed." + ) def test_checked_flags_last_char_is_d(self): assert Parser.checked_flags[-1] == 'd' diff --git a/tests/test_atomic_svc.py b/tests/test_atomic_svc.py index 1fa7281..8232872 100644 --- a/tests/test_atomic_svc.py +++ b/tests/test_atomic_svc.py @@ -114,7 +114,7 @@ def test_handle_attachment(self, atomic_svc, generate_dummy_payload, tmp_path): target_hash = hashlib.md5(DUMMY_PAYLOAD_CONTENT.encode()).hexdigest() target_name = target_hash[:PREFIX_HASH_LENGTH] + '_dummyatomicpayload' target_path = os.path.join(atomic_svc.payloads_dir, target_name) - assert atomic_svc._handle_attachment(DUMMY_PAYLOAD_PATH) == target_name + assert atomic_svc._handle_attachment(generate_dummy_payload) == target_name assert os.path.isfile(target_path) with open(target_path, 'r') as f: file_data = f.read() @@ -123,7 +123,7 @@ def test_handle_attachment(self, atomic_svc, generate_dummy_payload, tmp_path): def test_handle_attachment_name_format(self, atomic_svc, generate_dummy_payload, tmp_path): atomic_svc.payloads_dir = str(tmp_path / 'payloads') os.makedirs(atomic_svc.payloads_dir, exist_ok=True) - result = atomic_svc._handle_attachment(DUMMY_PAYLOAD_PATH) + result = atomic_svc._handle_attachment(generate_dummy_payload) parts = result.split('_', 1) assert len(parts) == 2 assert len(parts[0]) == PREFIX_HASH_LENGTH diff --git a/tests/test_hook.py b/tests/test_hook.py index eca14db..f23c090 100644 --- a/tests/test_hook.py +++ b/tests/test_hook.py @@ -1,6 +1,6 @@ import os import pytest -from unittest.mock import MagicMock, AsyncMock, patch +from unittest.mock import MagicMock, AsyncMock, patch, PropertyMock class TestHookModuleAttributes: @@ -46,9 +46,10 @@ async def test_enable_creates_gui(self): } with patch.object(hook, 'data_dir', '/tmp/atomic_test_hook_data'), \ - patch('os.listdir', return_value=['abilities', 'other']): + patch('os.listdir', return_value=['abilities', 'other']), \ + patch('hook.AtomicGUI') as mock_gui_cls: await hook.enable(services) - # If abilities already exists, no cloning or populating should happen + mock_gui_cls.assert_called_once_with(services, hook.name, hook.description) @pytest.mark.asyncio async def test_enable_ingests_when_no_abilities(self): @@ -109,7 +110,7 @@ async def test_enable_accesses_app(self): mock_app = MagicMock() mock_app_svc = MagicMock() - mock_app_svc.application = mock_app + type(mock_app_svc).application = PropertyMock(return_value=mock_app) services = { 'auth_svc': MagicMock(), @@ -121,4 +122,4 @@ async def test_enable_accesses_app(self): patch('os.listdir', return_value=['abilities']), \ patch('hook.AtomicGUI'): await hook.enable(services) - _ = mock_app_svc.application # verify it was accessed + type(mock_app_svc).application.assert_called()