Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 74 additions & 19 deletions audio_separator/separator/architectures/mdxc_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from ml_collections import ConfigDict
from scipy import signal
Expand All @@ -13,6 +14,64 @@
# Roformer direct constructors removed; loading handled via RoformerLoader in CommonSeparator.


class RoformerDataset(Dataset):
    """
    Dataset for handling Roformer audio chunks.

    Chunk start positions advance by ``step`` along axis 1 of the mix. Any start
    position that would produce a chunk shorter than ``chunk_size`` is moved back
    to the last position at which a full chunk still fits, and start positions
    that become duplicates as a result are dropped (first occurrence kept).
    """

    def __init__(self, mix, chunk_size: int, step: int):
        """
        Initializes the RoformerDataset.

        Args:
            mix (np.ndarray): The audio mix to be processed. Only ``mix.shape[1]``
                and slicing of the form ``mix[:, a:b]`` are used here.
            chunk_size (int): The size of each chunk. Must be positive.
            step (int): The step size between chunk starts. Must be positive.

        Raises:
            ValueError: If ``chunk_size`` or ``step`` is not positive.
        """
        # A non-positive step would either make range() fail with an opaque message (0)
        # or silently schedule no chunks at all (negative); fail loudly instead.
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if step <= 0:
            raise ValueError(f"step must be positive, got {step}")

        self.mix = mix
        self.chunk_size = chunk_size
        self.step = step

        total_length = mix.shape[1]
        last_start = total_length - chunk_size

        if last_start <= 0:
            # Mix is shorter than or equal to chunk_size: a single chunk starting at 0 is needed
            starts = [0]
        else:
            # Clamp every start so that the chunk beginning there is never short
            starts = [min(start, last_start) for start in range(0, total_length, step)]

        # dict.fromkeys preserves insertion order while removing the duplicates created by clamping
        self.indices = list(dict.fromkeys(starts))

    def __len__(self) -> int:
        """
        Returns the number of chunks in the dataset.

        Returns:
            int: The number of chunks.
        """
        return len(self.indices)

    def __getitem__(self, idx: int):
        """
        Gets a chunk from the dataset by index.

        Args:
            idx (int): The index of the chunk.

        Returns:
            tuple: A tuple containing the chunk (slice of the mix along axis 1), the start index (int),
            and the length of the chunk along its last axis (int). The length is smaller than
            ``chunk_size`` only when the whole mix is shorter than ``chunk_size``.
        """
        start_idx = self.indices[idx]
        part = self.mix[:, start_idx : start_idx + self.chunk_size]
        length = part.shape[-1]

        return part, start_idx, length


class MDXCSeparator(CommonSeparator):
"""
MDXCSeparator is responsible for separating audio sources using MDXC models.
Expand Down Expand Up @@ -41,6 +100,7 @@ def __init__(self, common_config, arch_config):

self.overlap = arch_config.get("overlap", 8)
self.batch_size = arch_config.get("batch_size", 1)
self.num_workers = arch_config.get("num_workers", 0)

# Amount of pitch shift to apply during processing (this does NOT affect the pitch of the output audio):
# • Whole numbers indicate semitones.
Expand All @@ -51,7 +111,7 @@ def __init__(self, common_config, arch_config):

self.process_all_stems = arch_config.get("process_all_stems", True)

self.logger.debug(f"MDXC arch params: batch_size={self.batch_size}, segment_size={self.segment_size}, overlap={self.overlap}")
self.logger.debug(f"MDXC arch params: batch_size={self.batch_size}, segment_size={self.segment_size}, overlap={self.overlap}, num_workers={self.num_workers}")
self.logger.debug(f"MDXC arch params: override_model_segment_size={self.override_model_segment_size}, pitch_shift={self.pitch_shift}")
self.logger.debug(f"MDXC multi-stem params: process_all_stems={self.process_all_stems}")

Expand Down Expand Up @@ -317,28 +377,23 @@ def demix(self, mix: np.ndarray) -> dict:
result = torch.zeros(req_shape, dtype=torch.float32)
counter = torch.zeros(req_shape, dtype=torch.float32)

for i in tqdm(range(0, mix.shape[1], step)):
part = mix[:, i : i + chunk_size]
length = part.shape[-1]
if i + chunk_size > mix.shape[1]:
part = mix[:, -chunk_size:]
length = chunk_size
part = part.to(device)
x = self.model_run(part.unsqueeze(0))[0]
x = x.cpu()
# Perform overlap_add on CPU
if i + chunk_size > mix.shape[1]:
# Fixed to correctly add to the end of the tensor
start_idx = result.shape[-1] - chunk_size
dataset = RoformerDataset(mix, chunk_size, step)
dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=(device.type == "cuda"))

for parts, start_idxs, lengths in tqdm(dataloader):
parts = parts.to(device)
xs = self.model_run(parts).detach().cpu()

for b in range(xs.shape[0]):
x = xs[b]
start_idx = start_idxs[b].item()
length = lengths[b].item()

# Perform overlap_add on CPU
result = self.overlap_add(result, x, window, start_idx, length)
safe_len = min(length, x.shape[-1], window.shape[0])
if safe_len > 0:
counter[..., start_idx : start_idx + safe_len] += window[:safe_len]
else:
result = self.overlap_add(result, x, window, i, length)
safe_len = min(length, x.shape[-1], window.shape[0])
if safe_len > 0:
counter[..., i : i + safe_len] += window[:safe_len]

inferenced_outputs = result / counter.clamp(min=1e-10)

Expand Down
10 changes: 5 additions & 5 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0},
mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0, "num_workers": 0},
info_only=False,
):
"""Initialize the separator."""
Expand Down Expand Up @@ -169,7 +169,7 @@ def __init__(

self.invert_using_spec = invert_using_spec
if self.invert_using_spec:
self.logger.debug(f"Secondary step will be inverted using spectogram rather than waveform. This may improve quality but is slightly slower.")
self.logger.debug("Secondary step will be inverted using spectogram rather than waveform. This may improve quality but is slightly slower.")

try:
self.sample_rate = int(sample_rate)
Expand Down Expand Up @@ -496,14 +496,14 @@ def list_supported_model_files(self):
self.download_file_if_not_exists("https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json", download_checks_path)

model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
self.logger.debug(f"UVR model download list loaded")
self.logger.debug("UVR model download list loaded")

# Load the model scores with error handling
model_scores = {}
try:
with resources.open_text("audio_separator", "models-scores.json") as f:
model_scores = json.load(f)
self.logger.debug(f"Model scores loaded")
self.logger.debug("Model scores loaded")
except json.JSONDecodeError as e:
self.logger.warning(f"Failed to load model scores: {str(e)}")
self.logger.warning("Continuing without model scores")
Expand All @@ -529,7 +529,7 @@ def list_supported_model_files(self):
# Load the JSON file using importlib.resources
with resources.open_text("audio_separator", "models.json") as f:
audio_separator_models_list = json.load(f)
self.logger.debug(f"Audio-Separator model list loaded")
self.logger.debug("Audio-Separator model list loaded")

# Return object with list of model names
model_files_grouped_by_type = {
Expand Down
4 changes: 3 additions & 1 deletion audio_separator/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
import json
import sys
import os
from importlib import metadata


Expand Down Expand Up @@ -118,13 +117,15 @@ def main():
mdxc_override_model_segment_size_help = "Override model default segment size instead of using the model default value. Example: --mdxc_override_model_segment_size"
mdxc_overlap_help = "Amount of overlap between prediction windows, 2-50. Higher is better but slower (default: %(default)s). Example: --mdxc_overlap=8"
mdxc_batch_size_help = "Larger consumes more RAM but may process slightly faster (default: %(default)s). Example: --mdxc_batch_size=4"
mdxc_num_workers_help = "Number of workers for DataLoader. Higher = faster preprocessing but more CPU/RAM (default: %(default)s). Example: --mdxc_num_workers=4"
mdxc_pitch_shift_help = "Shift audio pitch by a number of semitones while processing. May improve output for deep/high vocals. (default: %(default)s). Example: --mdxc_pitch_shift=2"

mdxc_params = parser.add_argument_group("MDXC Architecture Parameters")
mdxc_params.add_argument("--mdxc_segment_size", type=int, default=256, help=mdxc_segment_size_help)
mdxc_params.add_argument("--mdxc_override_model_segment_size", action="store_true", help=mdxc_override_model_segment_size_help)
mdxc_params.add_argument("--mdxc_overlap", type=int, default=8, help=mdxc_overlap_help)
mdxc_params.add_argument("--mdxc_batch_size", type=int, default=1, help=mdxc_batch_size_help)
mdxc_params.add_argument("--mdxc_num_workers", type=int, default=0, help=mdxc_num_workers_help)
mdxc_params.add_argument("--mdxc_pitch_shift", type=int, default=0, help=mdxc_pitch_shift_help)

args = parser.parse_args()
Expand Down Expand Up @@ -228,6 +229,7 @@ def main():
mdxc_params={
"segment_size": args.mdxc_segment_size,
"batch_size": args.mdxc_batch_size,
"num_workers": args.mdxc_num_workers,
"overlap": args.mdxc_overlap,
"override_model_segment_size": args.mdxc_override_model_segment_size,
"pitch_shift": args.mdxc_pitch_shift,
Expand Down
21 changes: 19 additions & 2 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pytest
import logging
from audio_separator.utils.cli import main
import subprocess
from unittest import mock
from unittest.mock import patch, MagicMock, mock_open

Expand All @@ -28,7 +27,7 @@ def common_expected_args():
"mdx_params": {"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
"vr_params": {"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
"demucs_params": {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
"mdxc_params": {"segment_size": 256, "batch_size": 1, "overlap": 8, "override_model_segment_size": False, "pitch_shift": 0},
"mdxc_params": {"segment_size": 256, "batch_size": 1, "overlap": 8, "override_model_segment_size": False, "pitch_shift": 0, "num_workers": 0},
}


Expand Down Expand Up @@ -275,3 +274,21 @@ def test_cli_demucs_output_names_argument(common_expected_args):
# Assertions
mock_separator.assert_called_once_with(**common_expected_args)
mock_separator_instance.separate.assert_called_once_with(["test_audio.mp3"], custom_output_names=demucs_output_names)


# Test using mdxc_num_workers argument
def test_cli_mdxc_num_workers_argument(common_expected_args):
    """Check that --mdxc_num_workers is passed to Separator inside mdxc_params."""
    argv = ["cli.py", "test_audio.mp3", "--mdxc_num_workers=2"]

    with patch("sys.argv", argv), patch("audio_separator.separator.Separator") as separator_cls:
        separator_cls.return_value.separate.return_value = ["output_file.mp3"]
        main()

    # Build the expected keyword arguments on fresh dicts so the fixture's own dicts stay untouched
    mdxc_expected = {**common_expected_args["mdxc_params"], "num_workers": 2}
    wanted_kwargs = {**common_expected_args, "mdxc_params": mdxc_expected}

    # Separator must have been constructed exactly once with the overridden worker count
    separator_cls.assert_called_once_with(**wanted_kwargs)
59 changes: 55 additions & 4 deletions tests/unit/test_mdxc_roformer_chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import pytest
import numpy as np
import torch
from unittest.mock import Mock, MagicMock, patch
from unittest.mock import Mock
import logging
from audio_separator.separator.architectures.mdxc_separator import RoformerDataset


class TestMDXCRoformerChunking:
Expand Down Expand Up @@ -103,7 +104,6 @@ def test_counter_updates_safe_len(self):
"""T055: Counter increments match overlap_add safe span."""
# Mock counter and overlap_add logic
counter = torch.zeros(2, 20000)
chunk_size = 8192
safe_len = 6000 # Shorter than chunk_size
start_idx = 1000

Expand Down Expand Up @@ -288,7 +288,7 @@ def mock_setup_chunking_with_logging(model, audio):
audio = Mock()
audio.hop_length = 512

result = mock_setup_chunking_with_logging(model_with_stft, audio)
mock_setup_chunking_with_logging(model_with_stft, audio)

# Verify logging occurred
assert "stft_hop_length=1024" in caplog.text
Expand Down Expand Up @@ -342,10 +342,61 @@ def mock_calculate_iterations(audio_len, chunk_sz, step_sz):
)

# Verify minimum iterations
assert actual_iterations >= 1, f"Should always have at least 1 iteration"
assert actual_iterations >= 1, "Should always have at least 1 iteration"

# Verify maximum reasonable iterations
max_reasonable = (audio_length // step_size) + 2
assert actual_iterations <= max_reasonable, (
f"Too many iterations {actual_iterations} for audio_len={audio_length}"
)


class TestRoformerDataset:
    """Unit tests for the chunk start-index scheduling done by RoformerDataset."""

    @staticmethod
    def _build(num_samples, chunk_size=20, step=10):
        """Create a RoformerDataset over an all-zero two-row mix with ``num_samples`` columns."""
        return RoformerDataset(np.zeros((2, num_samples)), chunk_size, step)

    def test_roformer_dataset_no_duplicates(self):
        """A tail start that is remapped onto an existing start must be listed only once."""
        scheduled = self._build(100).indices

        # Start 90 is remapped to 80, which is already present, so the list ends at 80
        assert scheduled == list(range(0, 90, 10))
        assert len(set(scheduled)) == len(scheduled)

    def test_roformer_dataset_tail_remapped(self):
        """A tail that does not align with the step gets one extra, remapped start index."""
        scheduled = self._build(105).indices

        # Starts 90 and 100 are both remapped to 85 (= 105 - 20) and collapse into one entry
        assert scheduled == [*range(0, 90, 10), 85]
        assert len(set(scheduled)) == len(scheduled)

    def test_roformer_dataset_short_audio(self):
        """A mix shorter than chunk_size yields a single chunk holding the whole mix."""
        dataset = self._build(10)

        # Only one start position, at the beginning of the mix
        assert dataset.indices == [0]

        chunk, chunk_start, chunk_len = dataset[0]
        assert chunk.shape == (2, 10)
        assert chunk_start == 0
        assert chunk_len == 10

    def test_roformer_dataset_exact_overlap(self):
        """With step equal to chunk_size and a mix of exactly two chunks, the starts are 0 and 20."""
        dataset = self._build(40, chunk_size=20, step=20)
        assert dataset.indices == [0, 20]
Loading