feat: Implement batch processing for the MDXC separator #262
pedroalmeida415 wants to merge 3 commits into nomadkaraoke:main from
Conversation
Walkthrough: Adds a DataLoader-based batching path for Roformer demixing.
Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant Sep as Separator
    participant RDS as RoformerDataset
    participant DL as DataLoader
    participant Model as Roformer Model
    participant OA as Overlap-Add
    Sep->>RDS: Create(mix, chunk_size, step)
    Sep->>DL: Create(RoformerDataset, batch_size, num_workers, pin_memory)
    loop until all batches processed
        DL->>RDS: __getitem__/batch fetch
        RDS-->>DL: (chunks, start_indices, lengths)
        DL-->>Sep: Batch of chunks
        Sep->>Model: Forward(batch moved to device)
        Model-->>Sep: Predictions (batch)
        Sep->>OA: Overlap-add accumulate using start_indices, lengths (accumulate on CPU)
    end
    OA-->>Sep: Final demixed output
```
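The diagrammed flow can be sketched as a short loop. This is a hedged illustration, not the PR's actual code: the dataset attributes (`mix`, `chunk_size`), the `model_run` callable, and the hamming-window weighting are stand-ins for the components the diagram names.

```python
import torch

def demix_batched(dataset, model_run, batch_size=4, device="cpu", num_workers=0):
    """Sketch of the diagrammed loop: batch chunks via DataLoader, run the
    model on-device, then overlap-add the predictions back together on CPU."""
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers,
        pin_memory=(device == "cuda"),
    )
    n_samples = dataset.mix.shape[1]
    accum = torch.zeros(dataset.mix.shape[0], n_samples)  # weighted sum of predictions
    weight = torch.zeros(n_samples)                       # sum of window weights
    window = torch.hamming_window(dataset.chunk_size, periodic=False)

    for parts, start_idxs, lengths in loader:
        preds = model_run(parts.to(device)).cpu()         # one batched D2H copy
        for b in range(preds.shape[0]):
            s, l = start_idxs[b].item(), lengths[b].item()
            accum[:, s : s + l] += preds[b, :, :l] * window[:l]
            weight[s : s + l] += window[:l]
    # Normalize overlapping regions by the accumulated window weight.
    return accum / weight.clamp_min(1e-8)
```

With an identity model, overlap-adding fully covering windows reconstructs the input, which is a convenient sanity check for the weighting buffer.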
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Actionable comments posted: 1
🧹 Nitpick comments (1)
audio_separator/separator/architectures/mdxc_separator.py (1)
347-364: Good implementation of batched inference with DataLoader. The DataLoader integration with pin_memory for CUDA and the batch-wise overlap-add logic is well structured. Minor suggestion: consider using xs.shape[0] instead of len(xs) at line 354 for clearer tensor semantics.

♻️ Optional: Use tensor shape accessor

```diff
- for b in range(len(xs)):
+ for b in range(xs.shape[0]):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@audio_separator/separator/architectures/mdxc_separator.py` around lines 347 - 364, Replace the Python built-in len(xs) with the tensor shape accessor xs.shape[0] in the batch loop to use tensor semantics; in mdxc_separator.py within the batched inference loop that calls self.model_run(parts) and then iterates over outputs, change "for b in range(len(xs)):" to "for b in range(xs.shape[0]):" so the loop uses the tensor's first-dimension size reliably when xs is a tensor returned by self.model_run.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@audio_separator/separator/architectures/mdxc_separator.py`:
- Around line 17-40: The RoformerDataset.__getitem__ last-chunk handling can
produce a negative start_idx and an incorrect length when mix.shape[1] <
chunk_size; update the block to clamp start_idx with start_idx = max(0,
self.mix.shape[1] - self.chunk_size) and set length from the actual slice
(length = part.shape[-1]) after computing part, so that for very short audio the
returned part, start_idx, and length are consistent and non-negative (refer to
RoformerDataset, __getitem__, self.mix, chunk_size, and length).
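The clamping described in the prompt above can be sketched concretely. This is a hypothetical minimal version of the dataset (the PR's real RoformerDataset has more to it); the point is that start_idx is clamped non-negative and length is derived from the actual slice:

```python
import torch

class RoformerDataset(torch.utils.data.Dataset):
    """Minimal sketch: yields (chunk, start_idx, length) per sliding-window start."""

    def __init__(self, mix: torch.Tensor, chunk_size: int, step: int):
        self.mix = mix  # shape: (channels, samples)
        self.chunk_size = chunk_size
        self.indices = list(range(0, mix.shape[1], step))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        start_idx = self.indices[i]
        part = self.mix[:, start_idx : start_idx + self.chunk_size]
        if part.shape[-1] < self.chunk_size and self.mix.shape[1] >= self.chunk_size:
            # Re-anchor the tail window; max(0, ...) keeps start_idx non-negative
            # even when the mix is barely longer than chunk_size.
            start_idx = max(0, self.mix.shape[1] - self.chunk_size)
            part = self.mix[:, start_idx : start_idx + self.chunk_size]
        length = part.shape[-1]  # derive length from the actual slice
        return part, start_idx, length
```

For audio shorter than chunk_size the re-anchor branch is skipped, so the returned part, start_idx, and length stay mutually consistent.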
---
Nitpick comments:
In `@audio_separator/separator/architectures/mdxc_separator.py`:
- Around line 347-364: Replace the Python built-in len(xs) with the tensor shape
accessor xs.shape[0] in the batch loop to use tensor semantics; in
mdxc_separator.py within the batched inference loop that calls
self.model_run(parts) and then iterates over outputs, change "for b in
range(len(xs)):" to "for b in range(xs.shape[0]):" so the loop uses the tensor's
first-dimension size reliably when xs is a tensor returned by self.model_run.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: fc0d3223-134e-403d-8fbc-c4caa174747a
📒 Files selected for processing (3)
- audio_separator/separator/architectures/mdxc_separator.py
- audio_separator/separator/separator.py
- audio_separator/utils/cli.py
🧹 Nitpick comments (3)
tests/unit/test_cli.py (1)

31-31: Add one non-default --mdxc_num_workers test. Line 31 only covers the default 0. If the CLI accepted --mdxc_num_workers but still always forwarded 0, this suite would still pass. A small targeted test with a non-default value would close that gap.

```python
def test_cli_mdxc_num_workers_argument(common_expected_args):
    test_args = ["cli.py", "test_audio.mp3", "--mdxc_num_workers=2"]
    with patch("sys.argv", test_args):
        with patch("audio_separator.separator.Separator") as mock_separator:
            mock_separator.return_value.separate.return_value = ["output_file.mp3"]
            main()
            expected_args = common_expected_args.copy()
            expected_args["mdxc_params"] = {
                **common_expected_args["mdxc_params"],
                "num_workers": 2,
            }
            mock_separator.assert_called_once_with(**expected_args)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/test_cli.py` at line 31, Add a new unit test that verifies the CLI forwards a non-default --mdxc_num_workers value into the mdxc_params passed to Separator: call main() with argv containing "--mdxc_num_workers=2", patch audio_separator.separator.Separator, assert Separator was called once and that its mdxc_params dict has "num_workers": 2 (keeping other mdxc_params from common_expected_args); reference main(), Separator, and the mdxc_params key when implementing the test.

audio_separator/separator/architectures/mdxc_separator.py (2)
35-35: Avoid scheduling the same tail window twice. Line 35 schedules starts all the way to the raw tail, and lines 60-65 remap any short tail back to mix.shape[1] - chunk_size. When that last full-window start is already on a step boundary, the same end chunk gets inferred twice.

♻️ Proposed fix

```diff
- self.indices = list(range(0, mix.shape[1], step))
+ if mix.shape[1] <= chunk_size:
+     self.indices = [0]
+ else:
+     last_start = mix.shape[1] - chunk_size
+     self.indices = list(range(0, last_start + 1, step))
+     if self.indices[-1] != last_start:
+         self.indices.append(last_start)
```

```diff
- # We need to handle the last chunk where part is smaller than chunk_size
- if length < self.chunk_size and self.mix.shape[1] >= self.chunk_size:
-     # Take the last chunk_size from the end
-     part = self.mix[:, -self.chunk_size :]
-     length = self.chunk_size
-     start_idx = self.mix.shape[1] - self.chunk_size
- # If mix is shorter than chunk_size, keep original part and length
  return part, start_idx, length
```

Also applies to: 60-65
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@audio_separator/separator/architectures/mdxc_separator.py` at line 35, The current scheduling can append a tail start that duplicates an existing start when mix.shape[1]-chunk_size falls on a step boundary; update the logic that builds self.indices so that the final list never contains duplicate starts: after creating indices = list(range(0, mix.shape[1], step)) and before applying the tail remap (the code that references chunk_size and mix.shape[1]), compute last_start = mix.shape[1] - chunk_size and ensure you only append or remap to last_start if it is not already present (or alternatively deduplicate self.indices preserving order), so the tail window is not scheduled twice.
377-385: Keep the host/device transfers batched. Line 377 pins the batch, but line 380 still uses a blocking .to(device). Line 384 then copies each item back separately with .cpu(), which reintroduces per-sample D2H transfers immediately after batching.

⚡ Proposed fix

```diff
- for parts, start_idxs, lengths in tqdm(dataloader):
-     parts = parts.to(device)
-     xs = self.model_run(parts)
+ for parts, start_idxs, lengths in tqdm(dataloader):
+     parts = parts.to(device, non_blocking=(device.type == "cuda"))
+     xs = self.model_run(parts).cpu()
      for b in range(xs.shape[0]):
-         x = xs[b].cpu()
+         x = xs[b]
          start_idx = start_idxs[b].item()
          length = lengths[b].item()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@audio_separator/separator/architectures/mdxc_separator.py` around lines 377 - 385, The per-sample device-to-host transfers happen because you call .to(device) on the whole batch (parts) but then call .cpu() inside the loop for each sample (xs[b].cpu()); fix by doing the host copy once per batch: after xs = self.model_run(parts) perform a single batch-level transfer (e.g., xs = xs.detach().cpu()) and, if needed, ensure start_idxs is on CPU (e.g., start_idxs = start_idxs.cpu()) before the inner loop, then iterate over xs[b] and start_idxs[b].item() without per-sample .cpu() calls. This keeps transfers batched and avoids repeated D2H copies.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@audio_separator/separator/architectures/mdxc_separator.py`:
- Line 35: The current scheduling can append a tail start that duplicates an
existing start when mix.shape[1]-chunk_size falls on a step boundary; update the
logic that builds self.indices so that the final list never contains duplicate
starts: after creating indices = list(range(0, mix.shape[1], step)) and before
applying the tail remap (the code that references chunk_size and mix.shape[1]),
compute last_start = mix.shape[1] - chunk_size and ensure you only append or
remap to last_start if it is not already present (or alternatively deduplicate
self.indices preserving order), so the tail window is not scheduled twice.
- Around line 377-385: The per-sample device-to-host transfers happen because
you call .to(device) on the whole batch (parts) but then call .cpu() inside the
loop for each sample (xs[b].cpu()); fix by doing the host copy once per batch:
after xs = self.model_run(parts) perform a single batch-level transfer (e.g., xs
= xs.detach().cpu()) and, if needed, ensure start_idxs is on CPU (e.g.,
start_idxs = start_idxs.cpu()) before the inner loop, then iterate over xs[b]
and start_idxs[b].item() without per-sample .cpu() calls. This keeps transfers
batched and avoids repeated D2H copies.
In `@tests/unit/test_cli.py`:
- Line 31: Add a new unit test that verifies the CLI forwards a non-default
--mdxc_num_workers value into the mdxc_params passed to Separator: call main()
with argv containing "--mdxc_num_workers=2", patch
audio_separator.separator.Separator, assert Separator was called once and that
its mdxc_params dict has "num_workers": 2 (keeping other mdxc_params from
common_expected_args); reference main(), Separator, and the mdxc_params key when
implementing the test.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 1a40c725-b4a4-4590-89c1-8271d945ff54
📒 Files selected for processing (2)
- audio_separator/separator/architectures/mdxc_separator.py
- tests/unit/test_cli.py
Motivation
This PR properly implements batch processing for the MDXC arch models, addressing a massive processing bottleneck. Previously, setting batch_size for the MDXC separator did not actually influence the execution loop for Roformer models, causing the GPU to be severely underutilized during high-overlap sliding-window operations.

By restructuring the iteration to leverage torch.utils.data.DataLoader over a custom RoformerDataset, this implementation lazily calculates coordinate slice bounds instead of redundantly slicing massive sections of audio into RAM.

Why this approach?
- The Dataset handles indices behind the scenes, ensuring deterministic memory bounds regardless of track size.
- Unlike SlidingWindowInferer or Asteroid's LambdaOverlapAdd, overlapping regions are gracefully interpolated through a hamming-window weighting buffer to completely nullify any clipping or "seams" at segment boundaries.

Benchmarks
File Details: ~22 minutes of PCM_16 audio
Model Details: mel_band_roformer_kim_ft_unwa.ckpt

Using standard inference parameters with batch_size=4: speedup is ~2.75x
Inference with use_autocast=True and batch_size=4: speedup is ~2.2x

System Environment
Summary by CodeRabbit
New Features
Refactor