Skip to content

feat: Implement batch processing for the MDXC separator#262

Open
pedroalmeida415 wants to merge 3 commits intonomadkaraoke:mainfrom
pedroalmeida415:main
Open

feat: Implement batch processing for the MDXC separator#262
pedroalmeida415 wants to merge 3 commits intonomadkaraoke:mainfrom
pedroalmeida415:main

Conversation

@pedroalmeida415
Copy link

@pedroalmeida415 pedroalmeida415 commented Mar 12, 2026

Motivation

This PR properly implements batch processing for the MDXC Arch models, addressing a massive processing bottleneck. Previously, setting batch_size for MDXC separator did not actually influence the execution loop for Roformer models, causing the GPU to be severely underutilized on high-overlap segment sliding window operations.

By restructuring the iteration to leverage torch.utils.data.DataLoader over a custom RoformerDataset, this implementation lazily calculates coordinate slice bounds instead of redundantly slicing massive sections of audio into RAM.

Why this approach?

  1. Massive Throughput Boost: Processing one slice at a time on modern GPUs leaves thousands of CUDA cores idling. Using batches allows for parallel computation over the STFT representations simultaneously.
  2. Fixed Memory Footprint: The previous methodology for overlap-adding could spike RAM or VRAM based on the audio length or the overlap amount. A lazy-loading Dataset handles indices behind the scenes, ensuring deterministic memory bounds regardless of track size.
  3. PCIe Offloading: Passing grouped tensors asynchronously limits the slow host-to-device memory transfer overhead that plagued the frame-by-frame loop.
  4. Weighted Overlap-Add Reassembly: Similarly to industry standards like MONAI's SlidingWindowInferer or Asteroid's LambdaOverlapAdd, overlapping regions are gracefully interpolated through a hamming window weighting buffer to completely nullify any clipping or "seams" at segment boundaries.

Benchmarks

File Details: ~22 minutes of PCM_16 audio
Model Details: mel_band_roformer_kim_ft_unwa.ckpt

Using standard inference parameters

Metric Previous Approach New DataLoader Approach
Execution Time ~5-6 minutes ~2 minutes 1 second
Memory Usage N/A 6.8 GB VRAM batch_size=4

Speedup is ~2.75x

Inference use_autocast=True

Metric Previous Approach New DataLoader Approach
Execution Time ~2 minutes 34 seconds ~1 minute 9 seconds
Memory Usage N/A 5.5 GB VRAM batch_size=4

Speedup is ~2.2x

System Environment

  • Kernel: Linux 6.19.6-2-cachyos
  • CPU: Intel(R) Core(TM) i7-14700HX (28) @ 5.50 GHz
  • GPU 1: NVIDIA GeForce RTX 4070 Max-Q / Mobile [Discrete]
  • Memory: 32 GiB

Summary by CodeRabbit

  • New Features

    • Added --mdxc_num_workers CLI parameter to configure parallel worker threads for MDXC audio separation (default: 0).
  • Refactor

    • Switched to batched processing for MDXC to improve throughput and memory efficiency; logging around MDXC parameters was clarified.

@coderabbitai
Copy link

coderabbitai bot commented Mar 12, 2026

Walkthrough

Adds a DataLoader-based batching path for Roformer demixing: new RoformerDataset provides chunked, stepped audio samples; MDXC separator accepts num_workers and uses DataLoader to run batched model inference with overlap-add reconstruction; CLI and tests expose the new parameter.

Changes

Cohort / File(s) Summary
Roformer batching & dataset
audio_separator/separator/architectures/mdxc_separator.py
Added RoformerDataset(Dataset) with __init__, __len__, __getitem__. Replaced per-step Roformer loop with DataLoader batching (uses batch_size, num_workers, pin_memory), moves inputs to device per-batch, and performs overlap-add accumulation using start indices and safe lengths.
Separator defaults & logging
audio_separator/separator/separator.py
Added num_workers: 0 to default mdxc_params; MDXCSeparator stores self.num_workers and logs it. Standardized several debug log messages to plain string literals.
CLI parameter
audio_separator/utils/cli.py
Added --mdxc_num_workers CLI argument (default 0) and propagated it into mdxc_params passed to Separator.
Tests updated
tests/unit/test_cli.py
Updated common_expected_args to include "num_workers": 0 in mdxc_params fixture expectations.

Sequence Diagram

sequenceDiagram
    participant Sep as Separator
    participant RDS as RoformerDataset
    participant DL as DataLoader
    participant Model as Roformer Model
    participant OA as Overlap-Add

    Sep->>RDS: Create(mix, chunk_size, step)
    Sep->>DL: Create(RoformerDataset, batch_size, num_workers, pin_memory)
    DL->>RDS: __getitem__/batch fetch
    RDS-->>DL: (chunks, start_indices, lengths)
    DL-->>Sep: Batch of chunks
    Sep->>Model: Forward(batch moved to device)
    Model-->>Sep: Predictions (batch)
    Sep->>OA: Overlap-add accumulate using start_indices, lengths (accumulate on CPU)
    loop until all batches processed
    end
    OA-->>Sep: Final demixed output
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐇 I nibble bytes in comfy rows,

Chunks and steps where rhythm grows,
DataLoaders hum, workers leap,
Overlap-add tucks sounds to sleep,
A rabbit's hop makes batching neat.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 77.78% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The PR title accurately describes the main change: implementing batch processing for MDXC separator using DataLoader instead of frame-by-frame execution.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
audio_separator/separator/architectures/mdxc_separator.py (1)

347-364: Good implementation of batched inference with DataLoader.

The DataLoader integration with pin_memory for CUDA and the batch-wise overlap-add logic is well structured.

Minor suggestion: Consider using xs.shape[0] instead of len(xs) at line 354 for clearer tensor semantics.

♻️ Optional: Use tensor shape accessor
-                for b in range(len(xs)):
+                for b in range(xs.shape[0]):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/architectures/mdxc_separator.py` around lines 347 -
364, Replace the Python built-in len(xs) with the tensor shape accessor
xs.shape[0] in the batch loop to use tensor semantics; in mdxc_separator.py
within the batched inference loop that calls self.model_run(parts) and then
iterates over outputs, change "for b in range(len(xs)):" to "for b in
range(xs.shape[0]):" so the loop uses the tensor's first-dimension size reliably
when xs is a tensor returned by self.model_run.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@audio_separator/separator/architectures/mdxc_separator.py`:
- Around line 17-40: The RoformerDataset.__getitem__ last-chunk handling can
produce a negative start_idx and an incorrect length when mix.shape[1] <
chunk_size; update the block to clamp start_idx with start_idx = max(0,
self.mix.shape[1] - self.chunk_size) and set length from the actual slice
(length = part.shape[-1]) after computing part, so that for very short audio the
returned part, start_idx, and length are consistent and non-negative (refer to
RoformerDataset, __getitem__, self.mix, chunk_size, and length).

---

Nitpick comments:
In `@audio_separator/separator/architectures/mdxc_separator.py`:
- Around line 347-364: Replace the Python built-in len(xs) with the tensor shape
accessor xs.shape[0] in the batch loop to use tensor semantics; in
mdxc_separator.py within the batched inference loop that calls
self.model_run(parts) and then iterates over outputs, change "for b in
range(len(xs)):" to "for b in range(xs.shape[0]):" so the loop uses the tensor's
first-dimension size reliably when xs is a tensor returned by self.model_run.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: fc0d3223-134e-403d-8fbc-c4caa174747a

📥 Commits

Reviewing files that changed from the base of the PR and between 12f8fc6 and ec71756.

📒 Files selected for processing (3)
  • audio_separator/separator/architectures/mdxc_separator.py
  • audio_separator/separator/separator.py
  • audio_separator/utils/cli.py

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (3)
tests/unit/test_cli.py (1)

31-31: Add one non-default --mdxc_num_workers test.

Line 31 only covers the default 0. If the CLI accepted --mdxc_num_workers but still always forwarded 0, this suite would still pass. A small targeted test with a non-default value would close that gap.

def test_cli_mdxc_num_workers_argument(common_expected_args):
    test_args = ["cli.py", "test_audio.mp3", "--mdxc_num_workers=2"]
    with patch("sys.argv", test_args):
        with patch("audio_separator.separator.Separator") as mock_separator:
            mock_separator.return_value.separate.return_value = ["output_file.mp3"]
            main()

            expected_args = common_expected_args.copy()
            expected_args["mdxc_params"] = {
                **common_expected_args["mdxc_params"],
                "num_workers": 2,
            }
            mock_separator.assert_called_once_with(**expected_args)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/test_cli.py` at line 31, Add a new unit test that verifies the CLI
forwards a non-default --mdxc_num_workers value into the mdxc_params passed to
Separator: call main() with argv containing "--mdxc_num_workers=2", patch
audio_separator.separator.Separator, assert Separator was called once and that
its mdxc_params dict has "num_workers": 2 (keeping other mdxc_params from
common_expected_args); reference main(), Separator, and the mdxc_params key when
implementing the test.
audio_separator/separator/architectures/mdxc_separator.py (2)

35-35: Avoid scheduling the same tail window twice.

Line 35 schedules starts all the way to the raw tail, and Lines 60-65 remap any short tail back to mix.shape[1] - chunk_size. When that last full-window start is already on a step boundary, the same end chunk gets inferred twice.

♻️ Proposed fix
-        self.indices = list(range(0, mix.shape[1], step))
+        if mix.shape[1] <= chunk_size:
+            self.indices = [0]
+        else:
+            last_start = mix.shape[1] - chunk_size
+            self.indices = list(range(0, last_start + 1, step))
+            if self.indices[-1] != last_start:
+                self.indices.append(last_start)
-        # We need to handle the last chunk where part is smaller than chunk_size
-        if length < self.chunk_size and self.mix.shape[1] >= self.chunk_size:
-            # Take the last chunk_size from the end
-            part = self.mix[:, -self.chunk_size :]
-            length = self.chunk_size
-            start_idx = self.mix.shape[1] - self.chunk_size
-        # If mix is shorter than chunk_size, keep original part and length
-
         return part, start_idx, length

Also applies to: 60-65

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/architectures/mdxc_separator.py` at line 35, The
current scheduling can append a tail start that duplicates an existing start
when mix.shape[1]-chunk_size falls on a step boundary; update the logic that
builds self.indices so that the final list never contains duplicate starts:
after creating indices = list(range(0, mix.shape[1], step)) and before applying
the tail remap (the code that references chunk_size and mix.shape[1]), compute
last_start = mix.shape[1] - chunk_size and ensure you only append or remap to
last_start if it is not already present (or alternatively deduplicate
self.indices preserving order), so the tail window is not scheduled twice.

377-385: Keep the host/device transfers batched.

Line 377 pins the batch, but Line 380 still uses a blocking .to(device). Line 384 then copies each item back separately with .cpu(), which reintroduces per-sample D2H transfers immediately after batching.

⚡ Proposed fix
-                for parts, start_idxs, lengths in tqdm(dataloader):
-                    parts = parts.to(device)
-                    xs = self.model_run(parts)
+                for parts, start_idxs, lengths in tqdm(dataloader):
+                    parts = parts.to(device, non_blocking=(device.type == "cuda"))
+                    xs = self.model_run(parts).cpu()
 
                     for b in range(xs.shape[0]):
-                        x = xs[b].cpu()
+                        x = xs[b]
                         start_idx = start_idxs[b].item()
                         length = lengths[b].item()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/architectures/mdxc_separator.py` around lines 377 -
385, The per-sample device-to-host transfers happen because you call .to(device)
on the whole batch (parts) but then call .cpu() inside the loop for each sample
(xs[b].cpu()); fix by doing the host copy once per batch: after xs =
self.model_run(parts) perform a single batch-level transfer (e.g., xs =
xs.detach().cpu()) and, if needed, ensure start_idxs is on CPU (e.g., start_idxs
= start_idxs.cpu()) before the inner loop, then iterate over xs[b] and
start_idxs[b].item() without per-sample .cpu() calls. This keeps transfers
batched and avoids repeated D2H copies.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@audio_separator/separator/architectures/mdxc_separator.py`:
- Line 35: The current scheduling can append a tail start that duplicates an
existing start when mix.shape[1]-chunk_size falls on a step boundary; update the
logic that builds self.indices so that the final list never contains duplicate
starts: after creating indices = list(range(0, mix.shape[1], step)) and before
applying the tail remap (the code that references chunk_size and mix.shape[1]),
compute last_start = mix.shape[1] - chunk_size and ensure you only append or
remap to last_start if it is not already present (or alternatively deduplicate
self.indices preserving order), so the tail window is not scheduled twice.
- Around line 377-385: The per-sample device-to-host transfers happen because
you call .to(device) on the whole batch (parts) but then call .cpu() inside the
loop for each sample (xs[b].cpu()); fix by doing the host copy once per batch:
after xs = self.model_run(parts) perform a single batch-level transfer (e.g., xs
= xs.detach().cpu()) and, if needed, ensure start_idxs is on CPU (e.g.,
start_idxs = start_idxs.cpu()) before the inner loop, then iterate over xs[b]
and start_idxs[b].item() without per-sample .cpu() calls. This keeps transfers
batched and avoids repeated D2H copies.

In `@tests/unit/test_cli.py`:
- Line 31: Add a new unit test that verifies the CLI forwards a non-default
--mdxc_num_workers value into the mdxc_params passed to Separator: call main()
with argv containing "--mdxc_num_workers=2", patch
audio_separator.separator.Separator, assert Separator was called once and that
its mdxc_params dict has "num_workers": 2 (keeping other mdxc_params from
common_expected_args); reference main(), Separator, and the mdxc_params key when
implementing the test.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 1a40c725-b4a4-4590-89c1-8271d945ff54

📥 Commits

Reviewing files that changed from the base of the PR and between ec71756 and 0bb3d1b.

📒 Files selected for processing (2)
  • audio_separator/separator/architectures/mdxc_separator.py
  • tests/unit/test_cli.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant