Skip to content

Commit df196cd

Browse files
authored
Fix multi-stem MDXC bug (#242)
1 parent 466253c commit df196cd

2 files changed

Lines changed: 25 additions & 10 deletions

File tree

audio_separator/separator/architectures/mdxc_separator.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def __init__(self, common_config, arch_config):
6565
self.audio_file_path = None
6666
self.audio_file_base = None
6767

68-
self.is_primary_stem_main_target = False
69-
if self.model_data_cfgdict.training.target_instrument == "Vocals" or len(self.model_data_cfgdict.training.instruments) > 1:
70-
self.is_primary_stem_main_target = True
68+
# Only mark primary stem as main target for single-target models.
69+
# Multi-stem models should not trigger residual subtraction logic.
70+
self.is_primary_stem_main_target = bool(self.model_data_cfgdict.training.target_instrument)
7171

7272
self.logger.debug(f"is_primary_stem_main_target: {self.is_primary_stem_main_target}")
7373

@@ -403,8 +403,8 @@ def demix(self, mix: np.ndarray) -> dict:
403403
self.logger.debug("Deleting accumulated outputs to free up memory")
404404
del accumulated_outputs
405405

406-
if num_stems > 1 or self.is_primary_stem_main_target:
407-
self.logger.debug("Number of stems is greater than 1 or vocals are main target, detaching individual sources and correcting pitch if necessary...")
406+
if num_stems > 1:
407+
self.logger.debug("Number of stems is greater than 1, detaching individual sources and correcting pitch if necessary...")
408408

409409
sources = {}
410410

@@ -420,7 +420,8 @@ def demix(self, mix: np.ndarray) -> dict:
420420
else:
421421
sources[key] = value
422422

423-
if self.is_primary_stem_main_target:
423+
# Residual subtraction is only applicable for single-target models (not multi-stem)
424+
if self.is_primary_stem_main_target and num_stems == 1:
424425
self.logger.debug(f"Primary stem: {self.primary_stem_name} is main target, detaching and matching array shapes if necessary...")
425426
if sources[self.primary_stem_name].shape[1] != orig_mix.shape[1]:
426427
sources[self.primary_stem_name] = spec_utils.match_array_shapes(sources[self.primary_stem_name], orig_mix)
@@ -445,9 +446,23 @@ def demix(self, mix: np.ndarray) -> dict:
445446
self.logger.debug("Deleting inferenced outputs to free up memory")
446447
del inferenced_outputs
447448

449+
# For single-target models (e.g., karaoke), also return the residual as secondary
448450
if self.pitch_shift != 0:
449451
self.logger.debug("Applying pitch correction for single instrument")
450-
return self.pitch_fix(inferenced_output, sample_rate, orig_mix)
452+
primary = self.pitch_fix(inferenced_output, sample_rate, orig_mix)
451453
else:
452-
self.logger.debug("Returning inferenced output for single instrument")
453-
return inferenced_output
454+
primary = inferenced_output
455+
456+
if self.is_primary_stem_main_target:
457+
self.logger.debug("Single-target model detected; computing residual secondary stem from original mix")
458+
# Ensure shapes match before residual subtraction
459+
if primary.shape[1] != orig_mix.shape[1]:
460+
primary = spec_utils.match_array_shapes(primary, orig_mix)
461+
secondary = orig_mix - primary
462+
return {
463+
self.primary_stem_name: primary,
464+
self.secondary_stem_name: secondary,
465+
}
466+
467+
self.logger.debug("Returning inferenced output for single instrument")
468+
return primary

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
44

55
[tool.poetry]
66
name = "audio-separator"
7-
version = "0.39.0"
7+
version = "0.39.1"
88
description = "Easy to use audio stem separation, using various models from UVR trained primarily by @Anjok07"
99
authors = ["Andrew Beveridge <andrew@beveridge.uk>"]
1010
license = "MIT"

0 commit comments

Comments
 (0)