diff --git a/demo/tests/test_process.py b/demo/tests/test_process.py index cf6ef9a..c2dc42e 100644 --- a/demo/tests/test_process.py +++ b/demo/tests/test_process.py @@ -37,6 +37,7 @@ def test_invoice_callbacks(self, debug_method): self.assertEqual(invoice.status, 'sent') self.assertEqual(debug_method.call_count, 5) expected_side_effects_kwargs = { + 'foo': 'bar', 'app_label': 'demo', 'model_name': 'invoice', 'instance_id': invoice.id, @@ -45,6 +46,7 @@ def test_invoice_callbacks(self, debug_method): 'transition': InvoiceProcess.transitions[3] } expected_callbacks_kwargs = { + 'foo': 'bar', 'app_label': 'demo', 'model_name': 'invoice', 'instance_id': invoice.id, @@ -79,6 +81,7 @@ def test_invoice_failure_callbacks(self, debug_method): self.assertEqual(invoice.status, 'failed') self.assertEqual(debug_method.call_count, 3) expected_side_effects_kwargs = { + 'foo': 'bar', 'app_label': 'demo', 'model_name': 'invoice', 'instance_id': invoice.id, @@ -87,6 +90,7 @@ def test_invoice_failure_callbacks(self, debug_method): 'transition': InvoiceProcess.transitions[5] } expected_callbacks_kwargs = { + 'foo': 'bar', 'app_label': 'demo', 'model_name': 'invoice', 'instance_id': invoice.id, diff --git a/django_logic_celery/commands.py b/django_logic_celery/commands.py index de25f6d..c9b966c 100644 --- a/django_logic_celery/commands.py +++ b/django_logic_celery/commands.py @@ -11,11 +11,11 @@ @shared_task(acks_late=True) def complete_transition(*args, **kwargs): """Completes transition """ - app_label = kwargs['app_label'] - model_name = kwargs['model_name'] - instance_id = kwargs['instance_id'] - transition = kwargs['transition'] - process_name = kwargs['process_name'] + app_label = kwargs.pop('app_label', None) + model_name = kwargs.pop('model_name', None) + instance_id = kwargs.pop('instance_id', None) + transition = kwargs.pop('transition', None) + process_name = kwargs.pop('process_name', None) app = apps.get_app_config(app_label) model = app.get_model(model_name) @@ -34,15 +34,15 @@ def fail_transition(task_id, *args, **kwargs): Make sure to catch all exceptions by this failure handler as otherwise it leads to the worker crash. """ - app_label = kwargs['app_label'] - model_name = kwargs['model_name'] - instance_id = kwargs['instance_id'] - transition = kwargs['transition'] + app_label = kwargs.pop('app_label', None) + model_name = kwargs.pop('model_name', None) + instance_id = kwargs.pop('instance_id', None) + transition = kwargs.pop('transition', None) try: app = apps.get_app_config(app_label) - model = app.get_model(kwargs['model_name']) - instance = model.objects.get(id=kwargs['instance_id']) + model = app.get_model(model_name) + instance = model.objects.get(id=instance_id) state = getattr(instance, kwargs['process_name']).state try: # If exception is raised in success callback, it will be passed through args @@ -106,7 +106,7 @@ class CeleryCommandMixin: def execute(self, state: State, **kwargs): if not self.commands: - return super().execute(state) + return super().execute(state, **kwargs) task_kwargs = self.get_task_kwargs(state, **kwargs) self.queue_task(task_kwargs) @@ -114,17 +114,15 @@ def execute(self, state: State, **kwargs): f'the following parameters {task_kwargs}') def get_task_kwargs(self, state: State, **kwargs): - task_kwargs = dict( + kwargs.update(dict( app_label=state.instance._meta.app_label, model_name=state.instance._meta.model_name, instance_id=state.instance.pk, process_name=state.process_name, - field_name=state.field_name - ) - if 'exception' in kwargs: - task_kwargs['exception'] = kwargs['exception'] + field_name=state.field_name, + )) - return task_kwargs + return kwargs def queue_task(self, task_kwargs): return NotImplementedError