Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ version = "0.4.1"
description = "Self-contained distributed community captioning system"
readme = "README.md"
requires-python = ">=3.11,<3.13"
license = {text = "MIT"}
authors = [
{name = "bghira", email = "bghira@users.github.com"},
]
license = { text = "MIT" }
authors = [{ name = "bghira", email = "bghira@users.github.com" }]
keywords = ["captioning", "distributed", "vllm", "dataset", "community"]
classifiers = [
"Development Status :: 4 - Beta",
Expand Down Expand Up @@ -70,8 +68,21 @@ target-version = ['py310']

[tool.ruff]
line-length = 100

[tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "B", "C90", "D"]
ignore = ["D100", "D101", "D102", "D103", "D104", "D105", "D106", "D107"]
ignore = [
"D100",
"D101",
"D102",
"D103",
"D104",
"D105",
"D106",
"D107",
"D203",
"D213",
]

[tool.mypy]
python_version = "3.11"
Expand All @@ -81,4 +92,3 @@ disallow_untyped_defs = true

[tool.poetry.group.dev.dependencies]
pytest-asyncio = "^1.1.0"

110 changes: 110 additions & 0 deletions src/caption_flow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,90 @@ def apply_cli_overrides(config: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return ConfigManager.merge_configs(config, overrides)


def validate_orchestrator_auth_config(
    config_data: Dict[str, Any], base_config: Dict[str, Any]
) -> Dict[str, Any]:
    """Validate and normalize orchestrator auth configuration.

    Handles fallback from top-level auth to orchestrator.auth and validates
    that auth configuration exists.

    Args:
        config_data: The orchestrator section config
        base_config: The full loaded config (may contain top-level auth)

    Returns:
        Updated config_data with validated auth section

    Raises:
        ValueError: If no auth configuration is found, if it is not a
            dictionary, or if it contains no non-empty token list.
    """
    # Preferred location: auth nested under the orchestrator section.
    if "auth" in config_data:
        auth_config = config_data["auth"]
        console.print("[dim]Using auth config from orchestrator section[/dim]")
    # Deprecated fallback: top-level auth, normalized into the orchestrator section.
    elif "auth" in base_config:
        auth_config = base_config["auth"]
        console.print(
            "[yellow]Warning: Found auth config at top level, should be under 'orchestrator' section[/yellow]"
        )
        console.print("[dim]Moving auth config to orchestrator.auth section[/dim]")
        config_data["auth"] = auth_config
    else:
        # No auth config found anywhere: show a worked example, then abort.
        console.print("[red]Error: No auth configuration found[/red]")
        console.print("[dim]Auth config must be present either as:[/dim]")
        console.print("[dim]  orchestrator.auth (recommended) or top-level auth (deprecated)[/dim]")
        console.print("\n[cyan]Example auth configuration:[/cyan]")
        console.print("orchestrator:")
        console.print("  auth:")
        console.print("    worker_tokens:")
        console.print("      - token: 'your-worker-token'")
        console.print("        name: 'Worker 1'")
        console.print("    admin_tokens:")
        console.print("      - token: 'your-admin-token'")
        console.print("        name: 'Admin User'")
        raise ValueError("Auth configuration is required for orchestrator")

    # Auth must be a mapping of token-type -> token list.
    if not isinstance(auth_config, dict):
        raise ValueError("Auth configuration must be a dictionary")

    # At least one token type must be present as a non-empty list.
    # Fix: use isinstance() rather than comparing type(...).__name__ to "list",
    # which was non-idiomatic and rejected list subclasses.
    token_types = ("worker_tokens", "admin_tokens", "monitor_tokens")
    has_tokens = any(
        isinstance(auth_config.get(token_type), list) and auth_config[token_type]
        for token_type in token_types
    )

    if not has_tokens:
        console.print("[red]Error: Auth configuration must contain at least one token type[/red]")
        console.print(
            "[dim]Required token types: worker_tokens, admin_tokens, or monitor_tokens[/dim]"
        )
        raise ValueError("At least one token type must be configured")

    # Warn about missing critical token types (missing key or empty/falsy value).
    if not auth_config.get("worker_tokens"):
        console.print(
            "[yellow]Warning: No worker tokens configured - workers won't be able to connect[/yellow]"
        )

    if not auth_config.get("admin_tokens"):
        console.print(
            "[yellow]Warning: No admin tokens configured - admin operations will be unavailable[/yellow]"
        )

    # Plain string (was an f-string with no placeholders, ruff F541).
    console.print("[green]✓ Auth validation passed[/green]")
    return config_data


@click.group()
@click.option("--verbose", is_flag=True, help="Enable verbose logging")
@click.pass_context
Expand Down Expand Up @@ -228,6 +312,13 @@ def orchestrator(ctx, config: Optional[str], **kwargs):
else:
config_data = base_config

# Validate and normalize auth configuration
try:
config_data = validate_orchestrator_auth_config(config_data, base_config)
except ValueError as e:
console.print(f"[red]Configuration error: {e}[/red]")
sys.exit(1)

# Apply CLI overrides
if kwargs.get("port"):
config_data["port"] = kwargs["port"]
Expand Down Expand Up @@ -916,6 +1007,25 @@ def reload_config(
console.print("[red]Failed to load configuration[/red]")
sys.exit(1)

# Validate and normalize auth configuration for reload
console.print(f"[cyan]Validating configuration...[/cyan]")
if "orchestrator" in new_cfg:
orchestrator_config = new_cfg["orchestrator"]
try:
# Apply same auth validation as orchestrator startup
orchestrator_config = validate_orchestrator_auth_config(orchestrator_config, new_cfg)
new_cfg["orchestrator"] = orchestrator_config
except ValueError as e:
console.print(f"[red]Configuration validation error: {e}[/red]")
sys.exit(1)
else:
# If no orchestrator section, treat whole config as orchestrator config
try:
new_cfg = validate_orchestrator_auth_config(new_cfg, new_cfg)
except ValueError as e:
console.print(f"[red]Configuration validation error: {e}[/red]")
sys.exit(1)

ssl_context = _setup_ssl_context(server, no_verify_ssl)

async def send_reload():
Expand Down
36 changes: 19 additions & 17 deletions src/caption_flow/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,20 +621,20 @@ async def _send_monitor_initial_data(self, websocket: ServerConnection):
# Send current stats (already in memory)
stats_start = time.time()
await websocket.send(safe_json_dumps({"type": "stats", "data": self.stats}))
logger.debug(f"Monitor stats sent in {(time.time() - stats_start)*1000:.1f}ms")
logger.debug(f"Monitor stats sent in {(time.time() - stats_start) * 1000:.1f}ms")

# Get processor stats instead of chunk stats
processor_stats_start = time.time()
processor_stats = self.processor.get_stats()
logger.debug(
f"Processor stats retrieved in {(time.time() - processor_stats_start)*1000:.1f}ms"
f"Processor stats retrieved in {(time.time() - processor_stats_start) * 1000:.1f}ms"
)

stats_send_start = time.time()
await websocket.send(
safe_json_dumps({"type": "processor_stats", "data": processor_stats})
)
logger.debug(f"Processor stats sent in {(time.time() - stats_send_start)*1000:.1f}ms")
logger.debug(f"Processor stats sent in {(time.time() - stats_send_start) * 1000:.1f}ms")

if websocket not in self.monitors:
return
Expand All @@ -647,18 +647,18 @@ async def _send_monitor_initial_data(self, websocket: ServerConnection):
safe_json_dumps({"type": "leaderboard", "data": self._cached_leaderboard})
)
logger.debug(
f"Cached leaderboard sent in {(time.time() - cache_send_start)*1000:.1f}ms"
f"Cached leaderboard sent in {(time.time() - cache_send_start) * 1000:.1f}ms"
)
else:
# Schedule leaderboard update separately
leaderboard_task_start = time.time()
asyncio.create_task(self._send_leaderboard_to_monitor(websocket))
logger.debug(
f"Leaderboard task created in {(time.time() - leaderboard_task_start)*1000:.1f}ms"
f"Leaderboard task created in {(time.time() - leaderboard_task_start) * 1000:.1f}ms"
)

logger.debug(
f"Monitor initial data send completed in {(time.time() - total_start)*1000:.1f}ms"
f"Monitor initial data send completed in {(time.time() - total_start) * 1000:.1f}ms"
)

except websockets.exceptions.ConnectionClosed:
Expand All @@ -677,7 +677,7 @@ async def _send_monitor_leaderboard(self, websocket: ServerConnection):
contributors_start = time.time()
contributors = await self.storage.get_top_contributors(10)
logger.debug(
f"Contributors retrieved in {(time.time() - contributors_start)*1000:.1f}ms"
f"Contributors retrieved in {(time.time() - contributors_start) * 1000:.1f}ms"
)

# Get worker counts in thread pool
Expand All @@ -690,7 +690,7 @@ async def _send_monitor_leaderboard(self, websocket: ServerConnection):
),
)
logger.debug(
f"Worker counts retrieved in {(time.time() - worker_counts_start)*1000:.1f}ms"
f"Worker counts retrieved in {(time.time() - worker_counts_start) * 1000:.1f}ms"
)

# Build enhanced contributors list
Expand All @@ -707,7 +707,9 @@ async def _send_monitor_leaderboard(self, websocket: ServerConnection):
),
}
enhanced_contributors.append(contrib_dict)
logger.debug(f"Enhanced contributors built in {(time.time() - build_start)*1000:.1f}ms")
logger.debug(
f"Enhanced contributors built in {(time.time() - build_start) * 1000:.1f}ms"
)

# Cache for future monitors
self._cached_leaderboard = enhanced_contributors
Expand All @@ -719,11 +721,11 @@ async def _send_monitor_leaderboard(self, websocket: ServerConnection):
safe_json_dumps({"type": "leaderboard", "data": enhanced_contributors})
)
logger.debug(
f"Leaderboard sent to monitor in {(time.time() - send_start)*1000:.1f}ms"
f"Leaderboard sent to monitor in {(time.time() - send_start) * 1000:.1f}ms"
)

logger.debug(
f"Leaderboard send to monitor completed in {(time.time() - total_start)*1000:.1f}ms"
f"Leaderboard send to monitor completed in {(time.time() - total_start) * 1000:.1f}ms"
)

except websockets.exceptions.ConnectionClosed:
Expand Down Expand Up @@ -779,7 +781,7 @@ async def _broadcast_leaderboard(self):
contributors_start = time.time()
contributors = await self.storage.get_top_contributors(10)
logger.debug(
f"Contributors retrieved for broadcast in {(time.time() - contributors_start)*1000:.1f}ms"
f"Contributors retrieved for broadcast in {(time.time() - contributors_start) * 1000:.1f}ms"
)

# Get worker counts
Expand All @@ -792,7 +794,7 @@ async def _broadcast_leaderboard(self):
),
)
logger.debug(
f"Worker counts retrieved for broadcast in {(time.time() - worker_counts_start)*1000:.1f}ms"
f"Worker counts retrieved for broadcast in {(time.time() - worker_counts_start) * 1000:.1f}ms"
)

# Build enhanced contributors list
Expand All @@ -810,7 +812,7 @@ async def _broadcast_leaderboard(self):
}
enhanced_contributors.append(contrib_dict)
logger.debug(
f"Enhanced contributors built for broadcast in {(time.time() - build_start)*1000:.1f}ms"
f"Enhanced contributors built for broadcast in {(time.time() - build_start) * 1000:.1f}ms"
)

# Cache it
Expand All @@ -822,7 +824,7 @@ async def _broadcast_leaderboard(self):
{"type": "leaderboard", "data": enhanced_contributors}
)
logger.debug(
f"Leaderboard message created in {(time.time() - message_create_start)*1000:.1f}ms"
f"Leaderboard message created in {(time.time() - message_create_start) * 1000:.1f}ms"
)

# Send to all monitors in parallel
Expand All @@ -849,10 +851,10 @@ async def send_leaderboard(monitor):
self.monitors -= disconnected

logger.debug(
f"Leaderboard sent to {len(monitors_copy)} monitors in {(time.time() - send_start)*1000:.1f}ms"
f"Leaderboard sent to {len(monitors_copy)} monitors in {(time.time() - send_start) * 1000:.1f}ms"
)
logger.debug(
f"Leaderboard broadcast completed in {(time.time() - total_start)*1000:.1f}ms"
f"Leaderboard broadcast completed in {(time.time() - total_start) * 1000:.1f}ms"
)

except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion src/caption_flow/processors/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def log_memory(location: str):
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
logger.info(
f"Memory at {location}: RSS={mem_info.rss/1024/1024:.1f}MB, VMS={mem_info.vms/1024/1024:.1f}MB"
f"Memory at {location}: RSS={mem_info.rss / 1024 / 1024:.1f}MB, VMS={mem_info.vms / 1024 / 1024:.1f}MB"
)
# Force garbage collection
gc.collect()
Expand Down
23 changes: 23 additions & 0 deletions src/caption_flow/processors/local_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,29 @@ def handle_result(self, result: WorkResult) -> Dict[str, Any]:

return base_result

def cleanup(self):
    """Release orchestrator resources: background thread, HTTP task, tracker."""
    logger.info("Cleaning up orchestrator")

    # Signal the unit-creation thread to stop, then wait briefly for it to exit.
    self.stop_creation.set()
    thread = self.unit_creation_thread
    if thread:
        thread.join(timeout=5)

    # Cancel the HTTP server task when one was started.
    task = getattr(self, "http_server_task", None)
    if task:
        try:
            # Cancellation should work even across event loops.
            if not task.done():
                task.cancel()
                logger.info("Cancelled HTTP server task")
        except Exception as e:
            logger.debug(f"Error cleaning up HTTP server task: {e}")

    # Persist any pending chunk-tracker state as a final checkpoint.
    tracker = self.chunk_tracker
    if tracker:
        tracker.flush()
def get_image_paths(self) -> List[Tuple[Path, int]]:
"""Get the list of discovered image paths and sizes."""
return self.all_images
Expand Down
Loading