diff --git a/dojo/celery.py b/dojo/celery.py index 93f2a1fd150..336fd420aca 100644 --- a/dojo/celery.py +++ b/dojo/celery.py @@ -76,13 +76,44 @@ def apply_async(self, args=None, kwargs=None, **options): return super().apply_async(args=args, kwargs=kwargs, **options) -class PgHistoryTask(DojoAsyncTask): +class PluggableContextTask(DojoAsyncTask): + + """ + Extends DojoAsyncTask with pluggable context managers loaded from settings. + + CELERY_TASK_CONTEXT_MANAGERS is a list of dotted paths to callables that + return context managers. Each task execution is wrapped in all of them. + This replaces the celery signal-based approach (task_prerun/task_postrun) + which does not work reliably with prefork worker pools. + """ + + def __call__(self, *args, **kwargs): + from contextlib import ExitStack # noqa: PLC0415 + + from django.utils.module_loading import import_string # noqa: PLC0415 + + cm_paths = getattr(settings, "CELERY_TASK_CONTEXT_MANAGERS", []) + if not cm_paths: + return super().__call__(*args, **kwargs) + + # ExitStack ensures all entered context managers are properly exited + # (via __exit__) even if the task raises an exception, so cleanup + # and batch dispatch always happen. + with ExitStack() as stack: + for path in cm_paths: + cm_factory = import_string(path) + stack.enter_context(cm_factory()) + return super().__call__(*args, **kwargs) + + +class PgHistoryTask(PluggableContextTask): """ Custom Celery base task that automatically applies pghistory context. - This class inherits from DojoAsyncTask to provide: + This class inherits from PluggableContextTask to provide: - User context injection and task tracking (from DojoAsyncTask) + - Pluggable context managers from settings (from PluggableContextTask) - Automatic pghistory context application (from this class) When a task is dispatched via dojo_dispatch_task or dojo_async_task, the current diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index 8dd2aa4a4f9..71536d6af17 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -273,6 +273,7 @@ def process_findings( product_grading_option=True, issue_updater_option=True, push_to_jira=push_to_jira, + sync=kwargs.get("sync", False), ) # No chord: tasks are dispatched immediately above per batch diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 5075eb6409b..9c45ab46de9 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -441,6 +441,7 @@ def process_findings( issue_updater_option=True, push_to_jira=push_to_jira, jira_instance_id=getattr(self.jira_instance, "id", None), + sync=kwargs.get("sync", False), ) # No chord: tasks are dispatched immediately above per batch