Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
name: Python tests
name: CI

on:
push:
branches: [main]
pull_request:
workflow_dispatch:

jobs:
validate-branch-target:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:

jobs:
create-github-release:
needs: check-tests
if: github.ref_type == 'tag' && github.event.base_ref == 'refs/heads/main'
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition github.event.base_ref == 'refs/heads/main' will not work as expected for tag push events. When a tag is pushed, github.event.base_ref is typically empty or undefined. This condition should likely check that the tag was created from the main branch using a different approach, such as checking out the repository and verifying the tag points to a commit on main, or removing this condition entirely if tags should only be created from main by policy.

Suggested change
if: github.ref_type == 'tag' && github.event.base_ref == 'refs/heads/main'
if: github.ref_type == 'tag'

Copilot uses AI. Check for mistakes.
runs-on: ubuntu-22.04
permissions:
contents: write # Required for creating releases
Expand Down
17 changes: 17 additions & 0 deletions mageflow/invokers/hatchet.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ async def run_success(self, result: Any) -> bool:
task_id = self.task_data.get(TASK_ID_PARAM_NAME, None)
if task_id:
current_task = await TaskSignature.get_safe(task_id)
if current_task is None:
# Task was deleted before success callback could be triggered
# This can happen if TTL expired or task was manually removed
import logging
logging.warning(
Comment on lines +44 to +45
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logging module is imported inside the function body rather than at the module level. This is inefficient as the import statement will be executed every time the function is called when a task is not found. Move the import statement to the top of the file with other imports.

Copilot uses AI. Check for mistakes.
f"run_success: TaskSignature {task_id} not found in Redis - "
f"success callbacks will not be triggered!"
)
return False
task_success_workflows = current_task.activate_success(result)
success_publish_tasks.append(asyncio.create_task(task_success_workflows))

Expand All @@ -51,6 +60,14 @@ async def run_error(self) -> bool:
task_id = self.task_data.get(TASK_ID_PARAM_NAME, None)
if task_id:
current_task = await TaskSignature.get_safe(task_id)
if current_task is None:
# Task was deleted before error callback could be triggered
import logging
logging.warning(
Comment on lines +65 to +66
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logging module is imported inside the function body rather than at the module level. This is inefficient as the import statement will be executed every time the function is called when a task is not found. Move the import statement to the top of the file with other imports.

Copilot uses AI. Check for mistakes.
f"run_error: TaskSignature {task_id} not found in Redis - "
f"error callbacks will not be triggered!"
)
return False
task_error_workflows = current_task.activate_error(self.message)
error_publish_tasks.append(asyncio.create_task(task_error_workflows))

Expand Down
41 changes: 34 additions & 7 deletions mageflow/swarm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,30 +218,57 @@ async def add_to_running_tasks(self, task: TaskSignatureConvertible) -> bool:
await self.tasks_left_to_run.aappend(task.key)
return False

async def fill_running_tasks(self) -> int:
async def fill_running_tasks(self, logger=None) -> int:
resource_to_run = self.config.max_concurrency - self.current_running_tasks
if resource_to_run <= 0:
return 0
num_of_task_to_run = min(resource_to_run, len(self.tasks_left_to_run))
task_ids = await asyncio.gather(
*[self.tasks_left_to_run.apop() for i in range(num_of_task_to_run)]
)

# Log popped task IDs for debugging
valid_task_ids = [tid for tid in task_ids if tid]
if logger:
logger(f"fill_running_tasks: popped {len(valid_task_ids)} task IDs from queue: {valid_task_ids}")

tasks = await asyncio.gather(
*[
BatchItemTaskSignature.get_safe(task_id)
for task_id in task_ids
if task_id # Check not None
for task_id in valid_task_ids
]
)

# Identify and log missing tasks
missing_task_ids = []
valid_tasks = []
for task_id, task in zip(valid_task_ids, tasks):
if task is None:
missing_task_ids.append(task_id)
if logger:
logger(f"WARN: BatchItemTaskSignature {task_id} not found in Redis - task lost!")
else:
valid_tasks.append(task)

if missing_task_ids and logger:
logger(f"MAJOR: {len(missing_task_ids)} tasks were popped but not found: {missing_task_ids}")

publish_coroutine = [
next_task.aio_run_no_wait(EmptyModel())
for next_task in tasks
if next_task is not None
for next_task in valid_tasks
]

if logger:
logger(f"fill_running_tasks: publishing {len(publish_coroutine)} tasks")

await asyncio.gather(*publish_coroutine)

if len(tasks) != len(publish_coroutine):
raise MissingSwarmItemError(f"swarm item was deleted before swarm is done")
return len(tasks)
raise MissingSwarmItemError(
f"swarm item was deleted before swarm is done. "
f"Missing: {missing_task_ids}"
)
Comment on lines 266 to +270
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error check is comparing the wrong variables. The condition checks if len(tasks) differs from len(publish_coroutine), but publish_coroutine is built from valid_tasks, not tasks. This check will always pass even when tasks are missing. The condition should compare len(valid_tasks) with len(publish_coroutine) (which should always be equal), or better yet, compare len(tasks) with len(valid_tasks) to detect missing tasks.

Copilot uses AI. Check for mistakes.
return len(valid_tasks)

async def decrease_running_tasks_count(self):
await self.current_running_tasks.increase(-1)
Expand Down
84 changes: 75 additions & 9 deletions mageflow/swarm/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,50 @@ async def swarm_start_tasks(msg: EmptyModel, ctx: Context):
task_data = HatchetInvoker(msg, ctx).task_ctx
swarm_task_id = task_data[SWARM_TASK_ID_PARAM_NAME]
swarm_task = await SwarmTaskSignature.get_safe(swarm_task_id)

if swarm_task is None:
ctx.log(f"MAJOR - Swarm {swarm_task_id} not found in Redis!")
raise MissingSwarmItemError(f"Swarm {swarm_task_id} not found")

ctx.log(
f"Swarm state: total_tasks={len(swarm_task.tasks)} "
f"max_concurrency={swarm_task.config.max_concurrency} "
f"is_closed={swarm_task.is_swarm_closed}"
)

if swarm_task.has_swarm_started:
ctx.log(f"Swarm task started but already running {msg}")
return

tasks_ids_to_run = swarm_task.tasks[: swarm_task.config.max_concurrency]
tasks_left_to_run = swarm_task.tasks[swarm_task.config.max_concurrency :]

ctx.log(
f"Initial batch: {len(tasks_ids_to_run)} tasks to run immediately, "
f"{len(tasks_left_to_run)} tasks queued"
)

async with swarm_task.pipeline() as swarm_task:
await swarm_task.tasks_left_to_run.aclear()
await swarm_task.tasks_left_to_run.aextend(tasks_left_to_run)

tasks_to_run = await asyncio.gather(
*[TaskSignature.get_safe(task_id) for task_id in tasks_ids_to_run]
)
await asyncio.gather(*[task.aio_run_no_wait(msg) for task in tasks_to_run])
ctx.log(f"Swarm task started with tasks {tasks_ids_to_run} {msg}")

# Check for missing tasks
missing_ids = [
task_id for task_id, task in zip(tasks_ids_to_run, tasks_to_run)
if task is None
]
if missing_ids:
ctx.log(f"WARN: {len(missing_ids)} initial tasks not found: {missing_ids}")

valid_tasks = [t for t in tasks_to_run if t is not None]
ctx.log(f"Publishing {len(valid_tasks)} initial tasks")

await asyncio.gather(*[task.aio_run_no_wait(msg) for task in valid_tasks])
ctx.log(f"Swarm task started with {len(valid_tasks)} tasks {msg}")
except Exception:
ctx.log(f"MAJOR - Error in swarm start tasks")
raise
Expand All @@ -47,17 +78,23 @@ async def swarm_item_done(msg: SwarmResultsMessage, ctx: Context):
try:
swarm_task_id = task_data[SWARM_TASK_ID_PARAM_NAME]
swarm_item_id = task_data[SWARM_ITEM_TASK_ID_PARAM_NAME]
ctx.log(f"Swarm item done {swarm_item_id}")
ctx.log(f"Swarm item done: item={swarm_item_id} swarm={swarm_task_id}")

# Update swarm tasks
swarm_task = await SwarmTaskSignature.get_safe(swarm_task_id)
if swarm_task is None:
ctx.log(f"MAJOR - Swarm {swarm_task_id} not found! Item {swarm_item_id} completion lost!")
raise MissingSwarmItemError(f"Swarm {swarm_task_id} not found for item {swarm_item_id}")

res = msg.results
async with swarm_task.lock(save_at_end=False) as swarm_task:
ctx.log(f"Swarm item done {swarm_item_id} - saving results")
ctx.log(f"Swarm item done {swarm_item_id} - saving results (lock acquired)")
await swarm_task.finished_tasks.aappend(swarm_item_id)
await swarm_task.tasks_results.aappend(res)
ctx.log(f"Swarm item {swarm_item_id} added to finished_tasks")
await handle_finish_tasks(swarm_task, ctx, msg)
except Exception as e:
ctx.log(f"MAJOR - Error in swarm start item done")
ctx.log(f"MAJOR - Error in swarm item done: {type(e).__name__}: {e}")
raise
finally:
await TaskSignature.try_remove(task_id)
Expand All @@ -69,11 +106,18 @@ async def swarm_item_failed(msg: EmptyModel, ctx: Context):
try:
swarm_task_key = task_data[SWARM_TASK_ID_PARAM_NAME]
swarm_item_key = task_data[SWARM_ITEM_TASK_ID_PARAM_NAME]
ctx.log(f"Swarm item failed {swarm_item_key}")
ctx.log(f"Swarm item failed: item={swarm_item_key} swarm={swarm_task_key}")

# Check if the swarm should end
swarm_task = await SwarmTaskSignature.get_safe(swarm_task_key)
if swarm_task is None:
ctx.log(f"MAJOR - Swarm {swarm_task_key} not found! Item {swarm_item_key} failure lost!")
raise MissingSwarmItemError(f"Swarm {swarm_task_key} not found for failed item {swarm_item_key}")

async with swarm_task.lock(save_at_end=False) as swarm_task:
await swarm_task.add_to_failed_tasks(swarm_item_key)
ctx.log(f"Swarm item {swarm_item_key} added to failed_tasks (total failed: {len(swarm_task.failed_tasks)})")

should_stop_after_failures = (
swarm_task.config.stop_after_n_failures is not None
)
Expand All @@ -91,7 +135,7 @@ async def swarm_item_failed(msg: EmptyModel, ctx: Context):

await handle_finish_tasks(swarm_task, ctx, msg)
except Exception as e:
ctx.log(f"MAJOR - Error in swarm item failed")
ctx.log(f"MAJOR - Error in swarm item failed: {type(e).__name__}: {e}")
raise
finally:
await TaskSignature.try_remove(task_key)
Expand All @@ -100,15 +144,37 @@ async def swarm_item_failed(msg: EmptyModel, ctx: Context):
async def handle_finish_tasks(
swarm_task: SwarmTaskSignature, ctx: Context, msg: BaseModel
):
# Log current state before decrementing
ctx.log(
f"handle_finish_tasks: swarm={swarm_task.key} "
f"running={swarm_task.current_running_tasks} "
f"finished={len(swarm_task.finished_tasks)} "
f"failed={len(swarm_task.failed_tasks)} "
f"queued={len(swarm_task.tasks_left_to_run)} "
f"total={len(swarm_task.tasks)} "
f"closed={swarm_task.is_swarm_closed}"
)

await swarm_task.decrease_running_tasks_count()
num_task_started = await swarm_task.fill_running_tasks()
num_task_started = await swarm_task.fill_running_tasks(logger=ctx.log)
if num_task_started:
ctx.log(f"Swarm item started new task {num_task_started}/{swarm_task.key}")
else:
ctx.log(f"Swarm item no new task to run in {swarm_task.key}")

# Check if the swarm should end
if await swarm_task.is_swarm_done():
is_done = await swarm_task.is_swarm_done()
ctx.log(f"is_swarm_done check: {is_done} (closed={swarm_task.is_swarm_closed})")

if is_done:
ctx.log(f"Swarm item done - closing swarm {swarm_task.key}")
await swarm_task.activate_success(msg)
ctx.log(f"Swarm item done - closed swarm {swarm_task.key}")
else:
# Log why swarm is not done yet
done_tasks = set(swarm_task.finished_tasks) | set(swarm_task.failed_tasks)
missing = set(swarm_task.tasks) - done_tasks
ctx.log(
f"Swarm not done yet: {len(missing)} tasks remaining. "
f"Missing task IDs: {list(missing)[:5]}{'...' if len(missing) > 5 else ''}"
)
1 change: 0 additions & 1 deletion tests/integration/hatchet/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from hatchet_sdk import Hatchet
from hatchet_sdk.clients.rest import V1TaskStatus, V1TaskSummary

from mageflow.chain.model import ChainTaskSignature
from mageflow.signature.consts import TASK_ID_PARAM_NAME, MAGEFLOW_TASK_INITIALS
from mageflow.signature.model import TaskSignature
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/hatchet/chain/test__chain.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import asyncio

import pytest

import mageflow
import pytest
from mageflow.signature.model import TaskSignature
from tests.integration.hatchet.assertions import (
assert_redis_is_clean,
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/hatchet/chain/test_edge_cases.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import asyncio

import pytest

import mageflow
import pytest
from tests.integration.hatchet.assertions import (
get_runs,
assert_signature_done,
Expand Down
5 changes: 2 additions & 3 deletions tests/integration/hatchet/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,18 @@
from threading import Thread
from typing import Generator, Callable, AsyncGenerator

import mageflow
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Module 'mageflow' is imported with both 'import' and 'import from'.

Suggested change
import mageflow

Copilot uses AI. Check for mistakes.
import psutil
import pytest
import pytest_asyncio
import requests
from hatchet_sdk import Hatchet
from hatchet_sdk.clients.admin import TriggerWorkflowOptions
from redis.asyncio.client import Redis

import mageflow
from mageflow import Mageflow
from mageflow.client import HatchetMageflow
from mageflow.startup import mageflow_config, init_mageflow
from mageflow.task.model import HatchetTaskModel
from redis.asyncio.client import Redis
from tests.integration.hatchet.worker import (
config_obj,
task1,
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/hatchet/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any, Annotated

from pydantic import BaseModel, Field

from mageflow.models.message import ReturnValue
from pydantic import BaseModel, Field


class ContextMessage(BaseModel):
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/hatchet/signature/test__signature.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import asyncio

import pytest

import mageflow
import pytest
from tests.integration.hatchet.assertions import (
assert_task_done,
assert_redis_is_clean,
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/hatchet/signature/test_edge_case.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import asyncio
from datetime import datetime

import pytest

import mageflow
import pytest
from tests.integration.hatchet.assertions import (
get_runs,
assert_signature_done,
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/hatchet/signature/test_stop_resume.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import asyncio

import pytest

import mageflow
import pytest
from mageflow.signature.model import TaskSignature
from tests.integration.hatchet.assertions import (
assert_redis_is_clean,
Expand Down
5 changes: 1 addition & 4 deletions tests/integration/hatchet/swarm/test__swarm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import asyncio

import mageflow
import pytest
from hatchet_sdk.clients.rest import V1TaskStatus
from hatchet_sdk.runnables.types import EmptyModel

import mageflow
from mageflow.signature.model import TaskSignature
from mageflow.swarm.model import SwarmConfig, BatchItemTaskSignature
from tests.integration.hatchet.assertions import (
Expand All @@ -14,7 +12,6 @@
assert_signature_done,
map_wf_by_id,
assert_overlaps_leq_k_workflows,
is_wf_done,
find_sub_calls_by_signature,
)
from tests.integration.hatchet.conftest import HatchetInitData
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/hatchet/swarm/test_edge_cases.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import asyncio

import pytest

import mageflow
import pytest
from mageflow.signature.model import TaskSignature
from mageflow.swarm.model import BatchItemTaskSignature, SwarmConfig
from tests.integration.hatchet.assertions import get_runs, assert_swarm_task_done
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/hatchet/swarm/test_stop_resume.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import asyncio
from datetime import datetime

import pytest

import mageflow
import pytest
from mageflow.signature.model import TaskSignature
from mageflow.swarm.model import SwarmConfig, BatchItemTaskSignature
from tests.integration.hatchet.assertions import (
Expand Down
Loading
Loading