Merged
14 changes: 10 additions & 4 deletions src/caption_flow/cli.py
@@ -1087,7 +1087,7 @@ def _display_abandoned_chunks(abandoned_chunks, fix, tracker):
     if fix:
         console.print("\n[yellow]Resetting abandoned chunks to pending...[/yellow]")
         for chunk_id, _, _ in abandoned_chunks:
-            tracker.mark_failed(chunk_id)
+            tracker.mark_pending(chunk_id)
         console.print(f"[green]✓ Reset {len(abandoned_chunks)} chunks[/green]")


@@ -1165,7 +1165,7 @@ def _cross_check_storage(storage, tracker, fix):
     if fix:
         console.print("[yellow]Resetting these chunks to pending...[/yellow]")
         for chunk_id in missing_in_storage:
-            tracker.mark_failed(chunk_id)
+            tracker.mark_pending(chunk_id)
         console.print(f"[green]✓ Reset {len(missing_in_storage)} chunks[/green]")

     if missing_in_tracker:
@@ -1180,10 +1180,16 @@ def _cross_check_storage(storage, tracker, fix):
 @main.command()
 @click.option("--data-dir", default="./caption_data", help="Storage directory")
 @click.option("--checkpoint-dir", default="./checkpoints", help="Checkpoint directory")
-@click.option("--fix", is_flag=True, help="Fix issues by resetting abandoned chunks")
+@click.option(
+    "--fix", is_flag=True, help="Fix issues by resetting abandoned chunks (STOP ORCHESTRATOR FIRST)"
+)
 @click.option("--verbose", is_flag=True, help="Show detailed information")
 def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
-    """Scan for sparse or abandoned chunks and optionally fix them."""
+    """Scan for sparse or abandoned chunks and optionally fix them.
+
+    WARNING: If using --fix, stop the orchestrator first to avoid conflicts.
+    The --fix option preserves partial progress but requires orchestrator restart.
+    """
     from .storage import StorageManager
     from .utils.chunk_tracker import ChunkTracker

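The `tracker.mark_pending()` calls above re-queue a chunk without discarding progress. `ChunkTracker` itself is not part of this diff, so the following is only a minimal sketch of the state transition the two methods imply (class and field names here are hypothetical, not the real implementation); it illustrates why `mark_pending()` is the right call for abandoned chunks:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class ChunkState:
    status: str = "pending"  # pending | assigned | completed | failed
    processed_ranges: List[Tuple[int, int]] = field(default_factory=list)


class ChunkTrackerSketch:
    """Hypothetical stand-in for caption_flow's ChunkTracker (not the real class)."""

    def __init__(self) -> None:
        self.chunks: Dict[str, ChunkState] = {}

    def mark_pending(self, chunk_id: str) -> None:
        # Re-queue the chunk but keep processed_ranges, so only the
        # remaining items are handed out on reassignment.
        self.chunks[chunk_id].status = "pending"

    def mark_failed(self, chunk_id: str) -> None:
        # Full reset: partial progress is discarded and the whole
        # chunk is reprocessed from scratch.
        state = self.chunks[chunk_id]
        state.status = "pending"
        state.processed_ranges = []


tracker = ChunkTrackerSketch()
tracker.chunks["chunk-0"] = ChunkState(status="assigned", processed_ranges=[(0, 49)])
tracker.mark_pending("chunk-0")
assert tracker.chunks["chunk-0"].processed_ranges == [(0, 49)]  # progress kept
```

Under that model, `mark_failed()` on an abandoned chunk would force a full redo, which is exactly what the new docstring's "preserves partial progress" promise rules out.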
10 changes: 7 additions & 3 deletions src/caption_flow/processors/huggingface.py
@@ -147,6 +147,7 @@ def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:

         # Store storage reference for chunk state synchronization
         self.storage = storage
+        storage_cfg = cfg.get("storage", {})

         # Dataset configuration
         dataset_cfg = cfg.get("dataset", {})
@@ -172,7 +173,7 @@ def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
         self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)

         # Initialize chunk tracking
-        self.checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
+        self.checkpoint_dir = Path(storage_cfg.get("checkpoint_dir", "./checkpoints"))
         self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
         self.chunk_tracker = ChunkTracker(self.checkpoint_dir / "chunks.json")

@@ -792,8 +793,11 @@ def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
         self.assigned_units[worker_id].discard(unit_id)
         self.pending_units.append(unit_id)

-        if self.chunk_tracker:
-            self.chunk_tracker.mark_failed(unit_id)
+        # NOTE: We don't call chunk_tracker.mark_failed() here because that would
+        # reset the entire chunk to unprocessed, losing any partial progress that
+        # was already recorded via handle_result(). The chunk tracker should retain
+        # the processed ranges and only make the remaining unprocessed items available
+        # for retry when the unit is reassigned.

     def release_assignments(self, worker_id: str) -> None:
         """Release all assignments for a disconnected worker."""
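The same relocation of `checkpoint_dir` appears in every processor: it is now read from the nested `storage` section instead of the top level of the processor config. A before/after sketch of the config dict (the dataset key is illustrative):

```python
# Old shape: checkpoint_dir at the top level of the processor config.
old_cfg = {
    "dataset": {"dataset_path": "some/dataset"},
    "checkpoint_dir": "./checkpoints",
}

# New shape: checkpoint_dir under "storage", matching
# storage_cfg = cfg.get("storage", {}) in initialize().
new_cfg = {
    "dataset": {"dataset_path": "some/dataset"},
    "storage": {"checkpoint_dir": "./checkpoints"},
}
```

Note that `storage_cfg.get()` keeps the `"./checkpoints"` default, so an unmigrated config that set a custom top-level `checkpoint_dir` will silently fall back to `./checkpoints` until it is moved under `storage`.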
103 changes: 86 additions & 17 deletions src/caption_flow/processors/local_filesystem.py
@@ -90,9 +90,10 @@ def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
         logger.info(f"Root path: {self.dataset_path}, recursive: {self.recursive}")

         # Initialize chunk tracking
-        checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
-        checkpoint_dir.mkdir(parents=True, exist_ok=True)
-        self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
+        storage_cfg = cfg.get("storage", {})
+        self.checkpoint_dir = Path(storage_cfg.get("checkpoint_dir", "./checkpoints"))
+        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        self.chunk_tracker = ChunkTracker(self.checkpoint_dir / "chunks.json")

         # Discover images
         self._discover_images()
@@ -388,6 +389,30 @@ def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
                 unit = self.work_units.get(unit_id)

                 if unit:
+                    # Update the unit's unprocessed_ranges based on current chunk tracker state
+                    # This is important for units that were marked as failed after partial success
+                    if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
+                        chunk_state = self.chunk_tracker.chunks[unit_id]
+                        relative_unprocessed_ranges = chunk_state.get_unprocessed_ranges()
+
+                        # Convert relative ranges to absolute ranges
+                        unprocessed_ranges = []
+                        for start, end in relative_unprocessed_ranges:
+                            abs_start = chunk_state.start_index + start
+                            abs_end = chunk_state.start_index + end
+                            unprocessed_ranges.append((abs_start, abs_end))
+
+                        # If no unprocessed ranges, skip this unit (it's already complete)
+                        if not unprocessed_ranges:
+                            logger.debug(
+                                "Skipping unit %s - no unprocessed ranges (already complete)",
+                                unit_id,
+                            )
+                            continue
+
+                        # Update the work unit with current unprocessed ranges
+                        unit.data["unprocessed_ranges"] = unprocessed_ranges
+
                     self.assigned_units[worker_id].add(unit_id)
                     assigned.append(unit)
                     logger.debug("Assigning unit %s to worker %s", unit_id, worker_id)
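`get_unprocessed_ranges()` returns ranges relative to the chunk, so the loop above offsets each one by `chunk_state.start_index`. A worked example with invented numbers:

```python
# A chunk that starts at absolute dataset index 1000, where items 0-49 and
# 75-99 (chunk-relative, inclusive) are still unprocessed:
start_index = 1000
relative_unprocessed_ranges = [(0, 49), (75, 99)]

unprocessed_ranges = [
    (start_index + start, start_index + end)
    for start, end in relative_unprocessed_ranges
]
assert unprocessed_ranges == [(1000, 1049), (1075, 1099)]
```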
@@ -416,8 +441,11 @@ def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
         self.assigned_units[worker_id].discard(unit_id)
         self.pending_units.append(unit_id)

-        if self.chunk_tracker:
-            self.chunk_tracker.mark_failed(unit_id)
+        # NOTE: We don't call chunk_tracker.mark_failed() here because that would
+        # reset the entire chunk to unprocessed, losing any partial progress that
+        # was already recorded via handle_result(). The chunk tracker should retain
+        # the processed ranges and only make the remaining unprocessed items available
+        # for retry when the unit is reassigned.

     def release_assignments(self, worker_id: str) -> None:
         """Release all assignments for a disconnected worker."""
@@ -500,25 +528,66 @@ def handle_result(self, result: WorkResult) -> Dict[str, Any]:
         """Handle result processing."""
         base_result = super().handle_result(result)

-        # Track processed items
+        # Track processed items - but only items that actually produced successful outputs
         if self.chunk_tracker:
             if "item_indices" not in result.metadata:
                 result.metadata["item_indices"] = [result.metadata.get("_item_index")]
-            indices = result.metadata["item_indices"]
-
-            if indices:
-                indices.sort()
+            reported_indices = result.metadata["item_indices"]
+
+            # Filter indices to only include items that actually have outputs/succeeded
+            # This is a workaround for workers that incorrectly report failed items as processed
+            successful_indices = []
+            if reported_indices and result.outputs:
+                # Check if we have per-item success information
+                if "successful_items" in result.metadata:
+                    successful_indices = result.metadata["successful_items"]
+                else:
+                    # Fallback: assume all reported indices were successful
+                    # unless we can detect failures from the outputs
+                    successful_indices = reported_indices
+
+                    # If outputs indicate some items failed, try to filter them out
+                    if isinstance(result.outputs, dict):
+                        # Check for structured errors
+                        if "errors" in result.outputs:
+                            errors = result.outputs["errors"]
+                            if isinstance(errors, list):
+                                failed_indices = []
+                                for error in errors:
+                                    if isinstance(error, dict) and "item_index" in error:
+                                        failed_indices.append(error["item_index"])
+                                # Remove failed indices
+                                successful_indices = [
+                                    idx for idx in reported_indices if idx not in failed_indices
+                                ]
+
+                        # Also check if the number of successful outputs doesn't match reported indices
+                        # This handles cases where errors are logged but not structured in outputs
+                        elif "captions" in result.outputs:
+                            captions = result.outputs["captions"]
+                            if isinstance(captions, list):
+                                # If we have fewer captions than reported indices, some items failed
+                                num_captions = len(
+                                    [c for c in captions if c is not None and c != ""]
+                                )
+                                if num_captions < len(reported_indices):
+                                    # Conservative approach: only mark the first N items as successful
+                                    # where N = number of actual successful outputs
+                                    successful_indices = reported_indices[:num_captions]
+
+            if successful_indices:
+                successful_indices.sort()
                 ranges = []
-                start = indices[0]
-                end = indices[0]
+                start = successful_indices[0]
+                end = successful_indices[0]

-                for i in range(1, len(indices)):
-                    if indices[i] == end + 1:
-                        end = indices[i]
+                for i in range(1, len(successful_indices)):
+                    if successful_indices[i] == end + 1:
+                        end = successful_indices[i]
                     else:
                         ranges.append((start, end))
-                        start = indices[i]
-                        end = indices[i]
+                        start = successful_indices[i]
+                        end = successful_indices[i]

                 ranges.append((start, end))

10 changes: 7 additions & 3 deletions src/caption_flow/processors/webdataset.py
@@ -54,6 +54,7 @@ def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:

         cfg = config.config
         dataset_cfg = cfg.get("dataset", {})
+        storage_cfg = cfg.get("storage", {})
         self.dataset_path = dataset_cfg.get("dataset_path")
         metadata_path = dataset_cfg.get("metadata_path", None)

@@ -83,7 +84,7 @@ def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
         logger.info(f"Dataset discovered: {self.dataset.num_shards} shards")

         # Initialize chunk tracker
-        checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
+        checkpoint_dir = Path(storage_cfg.get("checkpoint_dir", "./checkpoints"))
         checkpoint_dir.mkdir(parents=True, exist_ok=True)
         self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")

@@ -378,8 +379,11 @@ def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
         self.assigned_units[worker_id].discard(unit_id)
         self.pending_units.append(unit_id)

-        if self.chunk_tracker:
-            self.chunk_tracker.mark_failed(unit_id)
+        # NOTE: We don't call chunk_tracker.mark_failed() here because that would
+        # reset the entire chunk to unprocessed, losing any partial progress that
+        # was already recorded via handle_result(). The chunk tracker should retain
+        # the processed ranges and only make the remaining unprocessed items available
+        # for retry when the unit is reassigned.

     def release_assignments(self, worker_id: str) -> None:
         """Release all assignments for a disconnected worker."""
16 changes: 12 additions & 4 deletions tests/test_duplicate_job_assignments.py
@@ -40,7 +40,9 @@ def orchestrator_config(temp_checkpoint_dir):
             "dataset_config": None,  # Will auto-detect
             "dataset_split": None,  # Will auto-detect
         },
-        "checkpoint_dir": str(temp_checkpoint_dir),
+        "storage": {
+            "checkpoint_dir": str(temp_checkpoint_dir),
+        },
         "chunk_size": 1000,  # Small chunks for testing
         "min_chunk_buffer": 5,
         "chunk_buffer_multiplier": 2,
@@ -910,7 +912,9 @@ async def test_concurrent_workers_same_token_real_storage(self, temp_checkpoint_dir):
                 "dataset_config": None,
                 "dataset_split": None,
             },
-            "checkpoint_dir": str(temp_checkpoint_dir / "checkpoints"),
+            "storage": {
+                "checkpoint_dir": str(temp_checkpoint_dir / "checkpoints"),
+            },
             "chunk_size": 100,  # Small chunks
             "min_chunk_buffer": 10,
             "chunk_buffer_multiplier": 2,
@@ -1321,7 +1325,9 @@ async def test_relative_absolute_index_misalignments(temp_checkpoint_dir):
            "processor_type": "huggingface_datasets",
            "dataset_path": "terminusresearch/pexels-metadata-1.71M",
        },
-        "checkpoint_dir": str(temp_checkpoint_dir / "checkpoints2"),
+        "storage": {
+            "checkpoint_dir": str(temp_checkpoint_dir / "checkpoints2"),
+        },
        "chunk_size": 100,
    }

@@ -1494,7 +1500,9 @@ async def test_huggingface_chunk_start_index_bug(temp_checkpoint_dir):
            "processor_type": "huggingface_datasets",
            "dataset_path": "terminusresearch/pexels-metadata-1.71M",
        },
-        "checkpoint_dir": str(temp_checkpoint_dir / "checkpoints"),
+        "storage": {
+            "checkpoint_dir": str(temp_checkpoint_dir / "checkpoints"),
+        },
        "chunk_size": 100,
    }
