-
Notifications
You must be signed in to change notification settings - Fork 0
Unit test for swarm #36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
057095d
9f41728
21139b8
6b32073
71804c5
596548c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,6 +38,15 @@ async def run_success(self, result: Any) -> bool: | |
| task_id = self.task_data.get(TASK_ID_PARAM_NAME, None) | ||
| if task_id: | ||
| current_task = await TaskSignature.get_safe(task_id) | ||
| if current_task is None: | ||
| # Task was deleted before success callback could be triggered | ||
| # This can happen if TTL expired or task was manually removed | ||
| import logging | ||
| logging.warning( | ||
|
Comment on lines
+44
to
+45
|
||
| f"run_success: TaskSignature {task_id} not found in Redis - " | ||
| f"success callbacks will not be triggered!" | ||
| ) | ||
| return False | ||
| task_success_workflows = current_task.activate_success(result) | ||
| success_publish_tasks.append(asyncio.create_task(task_success_workflows)) | ||
|
|
||
|
|
@@ -51,6 +60,14 @@ async def run_error(self) -> bool: | |
| task_id = self.task_data.get(TASK_ID_PARAM_NAME, None) | ||
| if task_id: | ||
| current_task = await TaskSignature.get_safe(task_id) | ||
| if current_task is None: | ||
| # Task was deleted before error callback could be triggered | ||
| import logging | ||
| logging.warning( | ||
|
Comment on lines
+65
to
+66
|
||
| f"run_error: TaskSignature {task_id} not found in Redis - " | ||
| f"error callbacks will not be triggered!" | ||
| ) | ||
| return False | ||
| task_error_workflows = current_task.activate_error(self.message) | ||
| error_publish_tasks.append(asyncio.create_task(task_error_workflows)) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -218,30 +218,57 @@ async def add_to_running_tasks(self, task: TaskSignatureConvertible) -> bool: | |
| await self.tasks_left_to_run.aappend(task.key) | ||
| return False | ||
|
|
||
| async def fill_running_tasks(self) -> int: | ||
| async def fill_running_tasks(self, logger=None) -> int: | ||
| resource_to_run = self.config.max_concurrency - self.current_running_tasks | ||
| if resource_to_run <= 0: | ||
| return 0 | ||
| num_of_task_to_run = min(resource_to_run, len(self.tasks_left_to_run)) | ||
| task_ids = await asyncio.gather( | ||
| *[self.tasks_left_to_run.apop() for i in range(num_of_task_to_run)] | ||
| ) | ||
|
|
||
| # Log popped task IDs for debugging | ||
| valid_task_ids = [tid for tid in task_ids if tid] | ||
| if logger: | ||
| logger(f"fill_running_tasks: popped {len(valid_task_ids)} task IDs from queue: {valid_task_ids}") | ||
|
|
||
| tasks = await asyncio.gather( | ||
| *[ | ||
| BatchItemTaskSignature.get_safe(task_id) | ||
| for task_id in task_ids | ||
| if task_id # Check not None | ||
| for task_id in valid_task_ids | ||
| ] | ||
| ) | ||
|
|
||
| # Identify and log missing tasks | ||
| missing_task_ids = [] | ||
| valid_tasks = [] | ||
| for task_id, task in zip(valid_task_ids, tasks): | ||
| if task is None: | ||
| missing_task_ids.append(task_id) | ||
| if logger: | ||
| logger(f"WARN: BatchItemTaskSignature {task_id} not found in Redis - task lost!") | ||
| else: | ||
| valid_tasks.append(task) | ||
|
|
||
| if missing_task_ids and logger: | ||
| logger(f"MAJOR: {len(missing_task_ids)} tasks were popped but not found: {missing_task_ids}") | ||
|
|
||
| publish_coroutine = [ | ||
| next_task.aio_run_no_wait(EmptyModel()) | ||
| for next_task in tasks | ||
| if next_task is not None | ||
| for next_task in valid_tasks | ||
| ] | ||
|
|
||
| if logger: | ||
| logger(f"fill_running_tasks: publishing {len(publish_coroutine)} tasks") | ||
|
|
||
| await asyncio.gather(*publish_coroutine) | ||
|
|
||
| if len(tasks) != len(publish_coroutine): | ||
| raise MissingSwarmItemError(f"swarm item was deleted before swarm is done") | ||
| return len(tasks) | ||
| raise MissingSwarmItemError( | ||
| f"swarm item was deleted before swarm is done. " | ||
| f"Missing: {missing_task_ids}" | ||
| ) | ||
|
Comment on lines
266
to
+270
|
||
| return len(valid_tasks) | ||
|
|
||
| async def decrease_running_tasks_count(self): | ||
| await self.current_running_tasks.increase(-1) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -12,19 +12,18 @@ | |||
| from threading import Thread | ||||
| from typing import Generator, Callable, AsyncGenerator | ||||
|
|
||||
| import mageflow | ||||
|
||||
| import mageflow |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition
`github.event.base_ref == 'refs/heads/main'` will not work as expected for tag push events. When a tag is pushed, `github.event.base_ref` is typically empty or undefined. This condition should likely check that the tag was created from the main branch using a different approach, such as checking out the repository and verifying the tag points to a commit on main, or removing this condition entirely if tags should only be created from main by policy.