diff --git a/source/app/__init__.py b/source/app/__init__.py index 9b7e756c4..61e3b5a7e 100644 --- a/source/app/__init__.py +++ b/source/app/__init__.py @@ -170,3 +170,11 @@ def after_request(response): lm.user_loader(load_user) lm.request_loader(load_user_from_request) + +from app.blueprints.socket_io_event_handlers.case_event_handlers import register_case_event_handlers +from app.blueprints.socket_io_event_handlers.case_notes_event_handlers import register_notes_event_handlers +from app.blueprints.socket_io_event_handlers.update_event_handlers import register_update_event_handlers + +register_case_event_handlers() +register_notes_event_handlers() +register_update_event_handlers() diff --git a/source/app/blueprints/rest/case/case_timeline_routes.py b/source/app/blueprints/rest/case/case_timeline_routes.py index 680e02176..af08122e2 100644 --- a/source/app/blueprints/rest/case/case_timeline_routes.py +++ b/source/app/blueprints/rest/case/case_timeline_routes.py @@ -74,6 +74,7 @@ from app.blueprints.responses import response_success from app.business.errors import BusinessProcessingError from app.business.events import events_create +from app.business.events import events_update case_timeline_rest_blueprint = Blueprint('case_timeline_rest', __name__) @@ -684,6 +685,7 @@ def event_view(cur_id, caseid): @case_timeline_rest_blueprint.route('/case/timeline/events/update/', methods=['POST']) +@endpoint_deprecated('PUT', '/api/v2/cases/{case_identifier}/events/{identifier}') @ac_requires_case_identifier(CaseAccessLevel.full_access) @ac_api_requires() def case_edit_event(cur_id, caseid): @@ -692,42 +694,10 @@ def case_edit_event(cur_id, caseid): if not event: return response_error("Invalid event ID for this case") - event_schema = EventSchema() - - request_data = call_modules_hook('on_preload_event_update', data=request.get_json(), caseid=caseid) + request_json = request.get_json() + event = events_update(event, request_json) - request_data['event_id'] = cur_id - event = event_schema.load(request_data, instance=event) - - event.event_date, event.event_date_wtz = event_schema.validate_date( - request_data.get(u'event_date'), - request_data.get(u'event_tz') - ) - - event.case_id = caseid - add_obj_history_entry(event, 'updated') - - update_timeline_state(caseid=caseid) - db.session.commit() - - save_event_category(event.event_id, request_data.get('event_category_id')) - - setattr(event, 'event_category_id', request_data.get('event_category_id')) - - success, log = update_event_assets(event.event_id, caseid, request_data.get('event_assets'), - request_data.get('event_iocs'), request_data.get('event_sync_iocs_assets')) - if not success: - return response_error('Error while saving linked assets', data=log) - - success, log = update_event_iocs(event_id=event.event_id, - caseid=caseid, - iocs_list=request_data.get('event_iocs')) - if not success: - return response_error('Error while saving linked iocs', data=log) - - event = call_modules_hook('on_postload_event_update', data=event, caseid=caseid) - - track_activity(f"updated event \"{event.event_title}\"", caseid=caseid) + event_schema = EventSchema() event_dump = event_schema.dump(event) collab_notify(case_id=caseid, object_type='events', @@ -737,8 +707,8 @@ def case_edit_event(cur_id, caseid): return response_success("Event updated", data=event_dump) - except marshmallow.exceptions.ValidationError as e: - return response_error(msg="Data error", data=e.normalized_messages()) + except BusinessProcessingError as e: + return response_error(e.get_message(), data=e.get_data()) @case_timeline_rest_blueprint.route('/case/timeline/events/add', methods=['POST']) diff --git a/source/app/blueprints/rest/v2/auth.py b/source/app/blueprints/rest/v2/auth.py index 9e99bff74..26b062edb 100644 --- a/source/app/blueprints/rest/v2/auth.py +++ b/source/app/blueprints/rest/v2/auth.py @@ -25,6 +25,7 @@ from app import app from app import db from app import oidc_client +from app.logger import logger from app.blueprints.access_controls import is_authentication_ldap from app.blueprints.access_controls import is_authentication_oidc from app.blueprints.access_controls import not_authenticated_redirection_url @@ -37,18 +38,16 @@ auth_blueprint = Blueprint('auth', __name__, url_prefix='/auth') -log = app.logger - @auth_blueprint.post('/login') def login(): """ Login endpoint. Handles taking user/pass combo and authenticating a local session or returning an error. """ - log.info('Authenticating user') + logger.info('Authenticating user') if current_user.is_authenticated: - log.info('User already authenticated - redirecting') - log.debug(f'User {current_user.user} already logged in') + logger.info('User already authenticated - redirecting') + logger.debug(f'User {current_user.user} already logged in') user = return_authed_user_info(user_id=current_user.id) return response_api_success(data=user) diff --git a/source/app/blueprints/rest/v2/case_objects/events.py b/source/app/blueprints/rest/v2/case_objects/events.py index b050ca7ce..c89c0a077 100644 --- a/source/app/blueprints/rest/v2/case_objects/events.py +++ b/source/app/blueprints/rest/v2/case_objects/events.py @@ -27,12 +27,14 @@ from app.blueprints.access_controls import ac_api_return_access_denied from app.business.events import events_create from app.business.events import events_get +from app.business.events import events_update from app.models.cases import CasesEvent from app.schema.marshables import EventSchema from app.business.errors import BusinessProcessingError from app.business.errors import ObjectNotFoundError from app.business.cases import cases_exists from app.iris_engine.access_control.utils import ac_fast_check_current_user_has_case_access +from app.iris_engine.utils.collab import notify from app.models.authorization import CaseAccessLevel @@ -41,7 +43,7 @@ @case_events_blueprint.post('') @ac_api_requires() -def create_evidence(case_identifier): +def create_event(case_identifier): if not cases_exists(case_identifier): return response_api_not_found() if not ac_fast_check_current_user_has_case_access(case_identifier, [CaseAccessLevel.full_access]): @@ -50,7 +52,10 @@ def create_evidence(case_identifier): try: event = events_create(case_identifier, request.get_json()) schema = EventSchema() - return response_api_created(schema.dump(event)) + result = schema.dump(event) + notify(case_identifier, 'events', 'updated', event.event_id, object_data=result) + + return response_api_created(result) except BusinessProcessingError as e: return response_api_error(e.get_message(), data=e.get_data()) @@ -77,6 +82,31 @@ def get_event(case_identifier, identifier): return response_api_error(e.get_message(), data=e.get_data()) +@case_events_blueprint.put('/') +@ac_api_requires() +def update_event(case_identifier, identifier): + if not cases_exists(case_identifier): + return response_api_not_found() + + try: + event = events_get(identifier) + if not ac_fast_check_current_user_has_case_access(event.case_id, [CaseAccessLevel.full_access]): + return ac_api_return_access_denied(caseid=event.case_id) + _check_event_and_case_identifier_match(event, case_identifier) + + event = events_update(event, request.get_json()) + + schema = EventSchema() + result = schema.dump(event) + notify(case_identifier, 'events', 'updated', identifier, object_data=result) + + return response_api_success(result) + except ObjectNotFoundError: + return response_api_not_found() + except BusinessProcessingError as e: + return response_api_error(e.get_message(), data=e.get_data()) + + def _check_event_and_case_identifier_match(event: CasesEvent, case_identifier): if event.case_id != case_identifier: raise BusinessProcessingError(f'Event {event.event_id} does not belong to case {case_identifier}') diff --git a/source/app/blueprints/socket_io_event_handlers/case_event_handlers.py b/source/app/blueprints/socket_io_event_handlers/case_event_handlers.py index e8061f334..6a82f4fdd 100644 --- a/source/app/blueprints/socket_io_event_handlers/case_event_handlers.py +++ b/source/app/blueprints/socket_io_event_handlers/case_event_handlers.py @@ -26,7 +26,6 @@ from app.models.authorization import CaseAccessLevel -@socket_io.on('change') @ac_socket_requires(CaseAccessLevel.full_access) def socket_summary_onchange(data): @@ -34,7 +33,6 @@ def socket_summary_onchange(data): emit('change', data, to=data['channel'], skip_sid=request.sid) -@socket_io.on('save') @ac_socket_requires(CaseAccessLevel.full_access) def socket_summary_onsave(data): @@ -42,17 +40,29 @@ def socket_summary_onsave(data): emit('save', data, to=data['channel'], skip_sid=request.sid) -@socket_io.on('clear_buffer') @ac_socket_requires(CaseAccessLevel.full_access) def socket_summary_on_clear_buffer(message): emit('clear_buffer', message) -@socket_io.on('join') @ac_socket_requires(CaseAccessLevel.full_access) def get_message(data): room = data['channel'] join_room(room=room) emit('join', {'message': f"{current_user.user} just joined"}, room=room) + + +@ac_socket_requires(CaseAccessLevel.full_access) +def socket_join_case_obj_notif(data): + room = data['channel'] + join_room(room=room) + + +def register_case_event_handlers(): + socket_io.on_event('change', socket_summary_onchange) + socket_io.on_event('save', socket_summary_onsave) + socket_io.on_event('clear_buffer', socket_summary_on_clear_buffer) + socket_io.on_event('join', get_message) + socket_io.on_event('join-case-obj-notif', socket_join_case_obj_notif) diff --git a/source/app/blueprints/socket_io_event_handlers/case_notes_event_handlers.py b/source/app/blueprints/socket_io_event_handlers/case_notes_event_handlers.py index fdbff4206..b4faa3adf 100644 --- a/source/app/blueprints/socket_io_event_handlers/case_notes_event_handlers.py +++ b/source/app/blueprints/socket_io_event_handlers/case_notes_event_handlers.py @@ -26,7 +26,6 @@ from app.models.authorization import CaseAccessLevel -@socket_io.on('change-note') @ac_socket_requires(CaseAccessLevel.full_access) def socket_change_note(data): @@ -34,7 +33,6 @@ def socket_change_note(data): emit('change-note', data, to=data['channel'], skip_sid=request.sid, room=data['channel']) -@socket_io.on('save-note') @ac_socket_requires(CaseAccessLevel.full_access) def socket_save_note(data): @@ -42,14 +40,12 @@ def socket_save_note(data): emit('save-note', data, to=data['channel'], skip_sid=request.sid, room=data['channel']) -@socket_io.on('clear_buffer-note') @ac_socket_requires(CaseAccessLevel.full_access) def socket_clear_buffer_note(message): emit('clear_buffer-note', message, room=message['channel']) -@socket_io.on('join-notes') @ac_socket_requires(CaseAccessLevel.full_access) def socket_join_note(data): @@ -62,28 +58,24 @@ def socket_join_note(data): }, room=room) -@socket_io.on('ping-note') @ac_socket_requires(CaseAccessLevel.full_access) def socket_ping_note(data): emit('ping-note', {"user": current_user.name, "note_id": data['note_id']}, room=data['channel']) -@socket_io.on('pong-note') @ac_socket_requires(CaseAccessLevel.full_access) def socket_pong_note(data): emit('pong-note', {"user": current_user.name, "note_id": data['note_id']}, room=data['channel']) -@socket_io.on('overview-map-note') @ac_socket_requires(CaseAccessLevel.full_access) def socket_overview_map_note(data): emit('overview-map-note', {"user": current_user.user, "note_id": data['note_id']}, room=data['channel']) -@socket_io.on('join-notes-overview') @ac_socket_requires(CaseAccessLevel.full_access) def socket_join_overview(data): @@ -96,7 +88,18 @@ def socket_join_overview(data): }, room=room) -@socket_io.on('disconnect') @ac_socket_requires(CaseAccessLevel.full_access) def socket_disconnect(data): emit('disconnect', current_user.user, broadcast=True) + + +def register_notes_event_handlers(): + socket_io.on_event('change-note', socket_change_note) + socket_io.on_event('save-note', socket_save_note) + socket_io.on_event('clear_buffer-note', socket_clear_buffer_note) + socket_io.on_event('join-notes', socket_join_note) + socket_io.on_event('ping-note', socket_ping_note) + socket_io.on_event('pong-note', socket_pong_note) + socket_io.on_event('overview-map-note', socket_overview_map_note) + socket_io.on_event('join-notes-overview', socket_join_overview) + socket_io.on_event('disconnect', socket_disconnect) diff --git a/source/app/blueprints/socket_io_event_handlers/update_event_handlers.py b/source/app/blueprints/socket_io_event_handlers/update_event_handlers.py new file mode 100644 index 000000000..ea727760c --- /dev/null +++ b/source/app/blueprints/socket_io_event_handlers/update_event_handlers.py @@ -0,0 +1,47 @@ +# IRIS Source Code +# Copyright (C) 2024 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +from flask_login import current_user +from flask_socketio import emit +from flask_socketio import join_room + +from app import socket_io +from app import app + + +def get_message(data): + room = data['channel'] + join_room(room=room) + + emit('join', {'message': f"{current_user.user} just joined", 'is_error': False}, room=room, + namespace='/server-updates') + + +def socket_on_update_ping(msg): + emit('update_ping', {'message': "Server connected", 'is_error': False}, + namespace='/server-updates') + + +def socket_on_update_do_reboot(msg): + socket_io.emit('update_current_version', {"version": app.config.get('IRIS_VERSION')}, to='iris_update_status', + namespace='/server-updates') + + +def register_update_event_handlers(): + socket_io.on_event('join-update', get_message, namespace='/server-updates') + socket_io.on_event('update_ping', socket_on_update_ping, namespace='/server-updates') + socket_io.on_event('update_get_current_version', socket_on_update_do_reboot, namespace='/server-updates') diff --git a/source/app/business/cases.py b/source/app/business/cases.py index 7bc500611..4a5e9a5f9 100644 --- a/source/app/business/cases.py +++ b/source/app/business/cases.py @@ -17,27 +17,20 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. import datetime -from app.logger import logger import traceback - from flask_login import current_user - from marshmallow.exceptions import ValidationError from app import db - +from app.logger import logger from app.util import add_obj_history_entry from app.schema.marshables import CaseSchema - from app.models.models import ReviewStatusList - from app.business.errors import BusinessProcessingError from app.business.iocs import iocs_exports_to_json - from app.iris_engine.module_handler.module_handler import call_modules_hook from app.iris_engine.utils.tracker import track_activity from app.iris_engine.access_control.utils import ac_set_new_case_access - from app.datamgmt.case.case_db import case_db_exists from app.datamgmt.case.case_db import save_case_tags from app.datamgmt.case.case_db import register_case_protagonists @@ -60,9 +53,10 @@ from app.datamgmt.reporter.report_db import export_case_tasks_json from app.datamgmt.reporter.report_db import export_case_comments_json from app.datamgmt.reporter.report_db import export_case_notes_json +from app.models.cases import Cases -def _load(request_data, **kwargs): +def _load(request_data, **kwargs) -> Cases: try: add_case_schema = CaseSchema() return add_case_schema.load(request_data, **kwargs) diff --git a/source/app/business/events.py b/source/app/business/events.py index d70236c5f..84700fbd2 100644 --- a/source/app/business/events.py +++ b/source/app/business/events.py @@ -18,6 +18,7 @@ from datetime import datetime +import marshmallow from flask_login import current_user from marshmallow.exceptions import ValidationError @@ -38,8 +39,11 @@ def _load(request_data, **kwargs): try: - evidence_schema = EventSchema() - return evidence_schema.load(request_data, **kwargs) + schema = EventSchema() + event = schema.load(request_data, **kwargs) + event.event_date, event.event_date_wtz = schema.validate_date(request_data.get(u'event_date'), + request_data.get(u'event_tz')) + return event except ValidationError as e: raise BusinessProcessingError('Data error', data=e.normalized_messages()) @@ -48,10 +52,6 @@ def events_create(case_identifier, request_json) -> CasesEvent: request_data = call_modules_hook('on_preload_event_create', data=request_json, caseid=case_identifier) event = _load(request_data) - # TODO this should probably rather be done in the API layer - event_schema = EventSchema() - event.event_date, event.event_date_wtz = event_schema.validate_date(request_data.get(u'event_date'), - request_data.get(u'event_tz')) event.case_id = case_identifier event.event_added = datetime.utcnow() @@ -92,3 +92,38 @@ def events_get(identifier) -> CasesEvent: if not event: raise ObjectNotFoundError() return event + + +def events_update(event: CasesEvent, request_json: dict) -> CasesEvent: + try: + request_data = call_modules_hook('on_preload_event_update', data=request_json, caseid=event.case_id) + + request_data['event_id'] = event.event_id + event = _load(request_data, instance=event) + + add_obj_history_entry(event, 'updated') + + update_timeline_state(caseid=event.case_id) + db.session.commit() + + save_event_category(event.event_id, request_data.get('event_category_id')) + + setattr(event, 'event_category_id', request_data.get('event_category_id')) + + success, log = update_event_assets(event.event_id, event.case_id, request_data.get('event_assets'), + request_data.get('event_iocs'), request_data.get('event_sync_iocs_assets')) + if not success: + raise BusinessProcessingError('Error while saving linked assets', data=log) + + success, log = update_event_iocs(event_id=event.event_id, + caseid=event.case_id, + iocs_list=request_data.get('event_iocs')) + if not success: + raise BusinessProcessingError('Error while saving linked iocs', data=log) + + event = call_modules_hook('on_postload_event_update', data=event, caseid=event.case_id) + + track_activity(f"updated event \"{event.event_title}\"", caseid=event.case_id) + return event + except marshmallow.exceptions.ValidationError as e: + raise BusinessProcessingError('Data error', data=e.normalized_messages()) diff --git a/source/app/iris_engine/updater/updater.py b/source/app/iris_engine/updater/updater.py index 680d83fc6..d3b068dfd 100644 --- a/source/app/iris_engine/updater/updater.py +++ b/source/app/iris_engine/updater/updater.py @@ -26,9 +26,6 @@ import time from celery.schedules import crontab from datetime import datetime -from flask_login import current_user -from flask_socketio import emit -from flask_socketio import join_room from packaging import version from pathlib import Path @@ -70,30 +67,6 @@ def update_log_error(status): update_log_to_socket(status, is_error=True) -@socket_io.on('join-update', namespace='/server-updates') -def get_message(data): - - room = data['channel'] - join_room(room=room) - - emit('join', {'message': f"{current_user.user} just joined", 'is_error': False}, room=room, - namespace='/server-updates') - - -@socket_io.on('update_ping', namespace='/server-updates') -def socket_on_update_ping(msg): - - emit('update_ping', {'message': "Server connected", 'is_error': False}, - namespace='/server-updates') - - -@socket_io.on('update_get_current_version', namespace='/server-updates') -def socket_on_update_do_reboot(msg): - - socket_io.emit('update_current_version', {"version": app.config.get('IRIS_VERSION')}, to='iris_update_status', - namespace='/server-updates') - - def notify_server_ready_to_reboot(): socket_io.emit('server_ready_to_reboot', {}, to='iris_update_status', namespace='/server-updates') diff --git a/source/app/iris_engine/utils/collab.py b/source/app/iris_engine/utils/collab.py index 15d02cb5c..1c9e863d5 100644 --- a/source/app/iris_engine/utils/collab.py +++ b/source/app/iris_engine/utils/collab.py @@ -1,23 +1,26 @@ import json -import app +from app import socket_io -def collab_notify(case_id: int, - object_type: str, - action_type: str, - object_id, - object_data: json = None, - request_sid: int = None - ): - room = f"case-{case_id}" - app.socket_io.emit('case-obj-notif', - json.dumps({ - 'object_id': object_id, - 'action_type': action_type, - 'object_type': object_type, - 'object_data': object_data - }), - room=room, - to=room, - skip_sid=request_sid) +def collab_notify(case_id: int, object_type: str, action_type: str, object_id, + object_data: json = None, request_sid: int = None): + room = f'case-{case_id}' + data = json.dumps({ + 'object_id': object_id, + 'action_type': action_type, + 'object_type': object_type, + 'object_data': object_data + }) + socket_io.emit('case-obj-notif', data, room=room, to=room, skip_sid=request_sid) + + +def notify(case_identifier: int, object_type: str, action_type: str, object_id, object_data: json = None): + room = f'case-{case_identifier}' + data = { + 'object_id': object_id, + 'action_type': action_type, + 'object_type': object_type, + 'object_data': object_data + } + socket_io.emit('case-obj-notif', data, room=room, to=room) diff --git a/source/app/models/models.py b/source/app/models/models.py index aececb1d5..faabf33c6 100644 --- a/source/app/models/models.py +++ b/source/app/models/models.py @@ -680,12 +680,11 @@ def __init__(self, tag_title, namespace=None): def save(self): existing_tag = self.get_by_title(self.tag_title) - if existing_tag is not None: + if existing_tag: return existing_tag - else: - db.session.add(self) - db.session.commit() - return self + db.session.add(self) + db.session.commit() + return self @classmethod def get_by_title(cls, tag_title): @@ -983,15 +982,13 @@ def create_safe_attr(session, attribute_display_name, attribute_description, att CustomAttribute.attribute_description == attribute_description, CustomAttribute.attribute_for == attribute_for ).first() - if cat: - return False - else: - instance = CustomAttribute() - instance.attribute_display_name = attribute_display_name - instance.attribute_description = attribute_description - instance.attribute_for = attribute_for - instance.attribute_content = attribute_content - session.add(instance) - session.commit() - return True + return + + instance = CustomAttribute() + instance.attribute_display_name = attribute_display_name + instance.attribute_description = attribute_description + instance.attribute_for = attribute_for + instance.attribute_content = attribute_content + session.add(instance) + session.commit() diff --git a/tests/iris.py b/tests/iris.py index fee320958..0aa46be9d 100644 --- a/tests/iris.py +++ b/tests/iris.py @@ -16,11 +16,12 @@ # along with this program; if not, write to the Free Software Foundation, # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +from uuid import uuid4 from pathlib import Path from docker_compose import DockerCompose from rest_api import RestApi from user import User -from uuid import uuid4 +from socket_io_context_manager import SocketIOContextManager API_URL = 'http://127.0.0.1:8000' # TODO SSOT: this should be directly read from the .env file @@ -38,6 +39,10 @@ def __init__(self): # TODO remove this field and use _administrator instead self._api = RestApi(API_URL, _API_KEY) self._administrator = User(API_URL, _API_KEY, _ADMINISTRATOR_USER_IDENTIFIER) + self._socket_io_client = SocketIOContextManager(API_URL, _API_KEY) + + def get_socket_io_client(self) -> SocketIOContextManager: + return self._socket_io_client def create(self, path, body, query_parameters=None): return self._api.post(path, body, query_parameters) diff --git a/tests/requirements.txt b/tests/requirements.txt index 4d6d5ef1a..ece97a948 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1 +1,2 @@ requests >= 2.31.0, < 3.0.0 +python-socketio[client] \ No newline at end of file diff --git a/tests/socket_io_client.py b/tests/socket_io_client.py new file mode 100644 index 000000000..03bcec3bc --- /dev/null +++ b/tests/socket_io_client.py @@ -0,0 +1,42 @@ +# IRIS Source Code +# Copyright (C) 2023 - DFIR-IRIS +# contact@dfir-iris.org +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3 of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +from socketio import SimpleClient + + +class SocketIOClient: + + def __init__(self, url, api_key): + self._url = url + self._api_key = api_key + self._client = SimpleClient() + + def connect(self): + self._client.connect(self._url, headers={'Authorization': f'Bearer {self._api_key}'}) + + def emit(self, event, channel): + print(f'==> {event}/{channel}') + self._client.emit(event, {'channel': channel}) + + def receive(self): + message = self._client.receive(timeout=20) + print(f'<== {message[0]}/{message[1]}') + return message[1] + + def disconnect(self): + self._client.disconnect() diff --git a/source/app/blueprints/socket_io_event_handlers/collab.py b/tests/socket_io_context_manager.py similarity index 67% rename from source/app/blueprints/socket_io_event_handlers/collab.py rename to tests/socket_io_context_manager.py index 320811bc6..b6fa2f11f 100644 --- a/source/app/blueprints/socket_io_event_handlers/collab.py +++ b/tests/socket_io_context_manager.py @@ -1,5 +1,5 @@ # IRIS Source Code -# Copyright (C) 2024 - DFIR-IRIS +# Copyright (C) 2023 - DFIR-IRIS # contact@dfir-iris.org # # This program is free software; you can redistribute it and/or @@ -15,15 +15,18 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program; if not, write to the Free Software Foundation, # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -from flask_socketio import join_room -from app import socket_io -from app.blueprints.access_controls import ac_socket_requires -from app.models.authorization import CaseAccessLevel +from socket_io_client import SocketIOClient -@socket_io.on('join-case-obj-notif') -@ac_socket_requires(CaseAccessLevel.full_access) -def socket_join_case_obj_notif(data): - room = data['channel'] - join_room(room=room) +class SocketIOContextManager: + + def __init__(self, url, api_key): + self._client = SocketIOClient(url, api_key) + + def __enter__(self) -> SocketIOClient: + self._client.connect() + return self._client + + def __exit__(self, type, value, traceback): + self._client.disconnect() diff --git a/tests/tests_rest_events.py b/tests/tests_rest_events.py index 316d24edf..c049b37b8 100644 --- a/tests/tests_rest_events.py +++ b/tests/tests_rest_events.py @@ -17,6 +17,7 @@ # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. from unittest import TestCase + from iris import Iris _IDENTIFIER_FOR_NONEXISTENT_OBJECT = 123456789 @@ -85,6 +86,21 @@ def test_create_event_should_set_event_parent_id_when_provided(self): response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() self.assertEqual(identifier, response['parent_event_id']) + def test_create_event_should_change_send_socket_io_message(self): + case_identifier = self._subject.create_dummy_case() + + with self._subject.get_socket_io_client() as socket_io_client: + socket_io_client.emit('join-case-obj-notif', f'case-{case_identifier}') + + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + identifier = response['event_id'] + + message = socket_io_client.receive() + self.assertEqual(identifier, message['object_id']) + def test_get_event_should_return_200(self): case_identifier = self._subject.create_dummy_case() body = {'event_title': 'title', 'event_category_id': 1, @@ -168,3 +184,178 @@ def test_get_event_should_return_children_when_event_is_parent_of_another_event( child_identifier = response['event_id'] response = self._subject.get(f'/api/v2/cases/{case_identifier}/events/{identifier}', body).json() self.assertEqual(child_identifier, response['children'][0]['event_id']) + + def test_update_event_should_return_200(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + identifier = response['event_id'] + body = {'event_title': 'new title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.update(f'/api/v2/cases/{case_identifier}/events/{identifier}', body) + self.assertEqual(200, response.status_code) + + def test_update_event_should_change_event_title(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + identifier = response['event_id'] + body = {'event_title': 'new title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.update(f'/api/v2/cases/{case_identifier}/events/{identifier}', body).json() + self.assertEqual('new title', response['event_title']) + + def test_update_event_should_change_send_socket_io_message(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + identifier = response['event_id'] + + with self._subject.get_socket_io_client() as socket_io_client: + socket_io_client.emit('join-case-obj-notif', f'case-{case_identifier}') + + body = {'event_title': 'new title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + self._subject.update(f'/api/v2/cases/{case_identifier}/events/{identifier}', body).json() + + message = socket_io_client.receive() + + self.assertEqual(identifier, message['object_id']) + + def test_socket_io_join_should_not_fail(self): + case_identifier = self._subject.create_dummy_case() + + with self._subject.get_socket_io_client() as socket_io_client: + socket_io_client.emit('join', f'case-{case_identifier}') + message = socket_io_client.receive() + self.assertEqual('administrator just joined', message['message']) + + def test_update_event_should_return_403_when_user_has_no_permission_to_access_case(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + identifier = response['event_id'] + + user = self._subject.create_dummy_user() + body = {'event_title': 'new title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = user.update(f'/api/v2/cases/{case_identifier}/events/{identifier}', body) + self.assertEqual(403, response.status_code) + + def test_update_event_should_return_404_when_event_does_not_exist(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'new title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.update(f'/api/v2/cases/{case_identifier}/events/{_IDENTIFIER_FOR_NONEXISTENT_OBJECT}', body) + self.assertEqual(404, response.status_code) + + def test_update_event_should_return_404_when_case_does_not_exist(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + identifier = response['event_id'] + body = {'event_title': 'new title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.update(f'/api/v2/cases/{_IDENTIFIER_FOR_NONEXISTENT_OBJECT}/events/{identifier}', body) + self.assertEqual(404, response.status_code) + + def test_update_event_should_return_400_when_event_date_format_is_incorrect(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + identifier = response['event_id'] + body = {'event_title': 'new title', 'event_category_id': 1, + 'event_date': '1744181930.204785', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.update(f'/api/v2/cases/{case_identifier}/events/{identifier}', body) + self.assertEqual(400, response.status_code) + + def test_update_event_should_return_400_when_case_identifier_does_not_match_event_case(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + identifier = response['event_id'] + case_identifier2 = self._subject.create_dummy_case() + body = {'event_title': 'new title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.update(f'/api/v2/cases/{case_identifier2}/events/{identifier}', body) + self.assertEqual(400, response.status_code) + + def test_update_event_should_return_400_when_field_event_category_id_is_missing(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + identifier = response['event_id'] + body = {'event_title': 'new title', + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.update(f'/api/v2/cases/{case_identifier}/events/{identifier}', body) + self.assertEqual(400, response.status_code) + + def test_update_event_should_return_400_when_field_event_assets_is_missing(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + identifier = response['event_id'] + body = {'event_title': 'new title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_iocs': []} + response = self._subject.update(f'/api/v2/cases/{case_identifier}/events/{identifier}', body) + self.assertEqual(400, response.status_code) + + def test_update_event_should_return_400_when_field_event_iocs_is_missing(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + identifier = response['event_id'] + body = {'event_title': 'new title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': []} + response = self._subject.update(f'/api/v2/cases/{case_identifier}/events/{identifier}', body) + self.assertEqual(400, response.status_code) + + def test_update_event_should_set_event_parent_id_when_provided(self): + case_identifier = self._subject.create_dummy_case() + body = {'event_title': 'title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + parent_event_identifier = response['event_id'] + body = {'event_title': 'title2', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': []} + response = self._subject.create(f'/api/v2/cases/{case_identifier}/events', body).json() + identifier = response['event_id'] + body = {'event_title': 'new title', 'event_category_id': 1, + 'event_date': '2025-03-26T00:00:00.000', 'event_tz': '+00:00', + 'event_assets': [], 'event_iocs': [], + 'parent_event_id': parent_event_identifier} + response = self._subject.update(f'/api/v2/cases/{case_identifier}/events/{identifier}', body).json() + self.assertEqual(parent_event_identifier, response['parent_event_id']) diff --git a/tests/tests_rest_evidences.py b/tests/tests_rest_evidences.py index 698453143..a4ee16134 100644 --- a/tests/tests_rest_evidences.py +++ b/tests/tests_rest_evidences.py @@ -262,8 +262,8 @@ def test_update_evidence_should_return_400_when_case_identifier_does_not_match_e body = {'filename': 'filename'} response = self._subject.create(f'/api/v2/cases/{case_identifier}/evidences', body).json() identifier = response['id'] - body = {'filename': 'filename2'} case_identifier2 = self._subject.create_dummy_case() + body = {'filename': 'filename2'} response = self._subject.update(f'/api/v2/cases/{case_identifier2}/evidences/{identifier}', body) self.assertEqual(400, response.status_code) diff --git a/tests/tests_rest_notes.py b/tests/tests_rest_notes.py index 5647cc750..089004487 100644 --- a/tests/tests_rest_notes.py +++ b/tests/tests_rest_notes.py @@ -256,3 +256,12 @@ def test_get_note_should_return_404_when_note_is_deleted(self): self._subject.delete(f'/api/v2/cases/{case_identifier}/notes/{identifier}') response = self._subject.get(f'/api/v2/cases/{case_identifier}/notes/{identifier}') self.assertEqual(404, response.status_code) + + def test_socket_io_join_notes_overview_should_not_fail(self): + case_identifier = self._subject.create_dummy_case() + + with self._subject.get_socket_io_client() as socket_io_client: + socket_io_client.emit('join-notes-overview', f'case-{case_identifier}-notes') + message = socket_io_client.receive() + self.assertEqual('administrator', message['user']) +