@@ -65,9 +65,9 @@ def __init__(self, common_config, arch_config):
6565 self .audio_file_path = None
6666 self .audio_file_base = None
6767
68- self . is_primary_stem_main_target = False
69- if self . model_data_cfgdict . training . target_instrument == "Vocals" or len ( self . model_data_cfgdict . training . instruments ) > 1 :
70- self .is_primary_stem_main_target = True
68+ # Only mark primary stem as main target for single-target models.
69+ # Multi-stem models should not trigger residual subtraction logic.
70+ self .is_primary_stem_main_target = bool ( self . model_data_cfgdict . training . target_instrument )
7171
7272 self .logger .debug (f"is_primary_stem_main_target: { self .is_primary_stem_main_target } " )
7373
@@ -403,8 +403,8 @@ def demix(self, mix: np.ndarray) -> dict:
403403 self .logger .debug ("Deleting accumulated outputs to free up memory" )
404404 del accumulated_outputs
405405
406- if num_stems > 1 or self . is_primary_stem_main_target :
407- self .logger .debug ("Number of stems is greater than 1 or vocals are main target , detaching individual sources and correcting pitch if necessary..." )
406+ if num_stems > 1 :
407+ self .logger .debug ("Number of stems is greater than 1, detaching individual sources and correcting pitch if necessary..." )
408408
409409 sources = {}
410410
@@ -420,7 +420,8 @@ def demix(self, mix: np.ndarray) -> dict:
420420 else :
421421 sources [key ] = value
422422
423- if self .is_primary_stem_main_target :
423+ # Residual subtraction is only applicable for single-target models (not multi-stem)
424+ if self .is_primary_stem_main_target and num_stems == 1 :
424425 self .logger .debug (f"Primary stem: { self .primary_stem_name } is main target, detaching and matching array shapes if necessary..." )
425426 if sources [self .primary_stem_name ].shape [1 ] != orig_mix .shape [1 ]:
426427 sources [self .primary_stem_name ] = spec_utils .match_array_shapes (sources [self .primary_stem_name ], orig_mix )
@@ -445,9 +446,23 @@ def demix(self, mix: np.ndarray) -> dict:
445446 self .logger .debug ("Deleting inferenced outputs to free up memory" )
446447 del inferenced_outputs
447448
449+ # For single-target models (e.g., karaoke), also return the residual as secondary
448450 if self .pitch_shift != 0 :
449451 self .logger .debug ("Applying pitch correction for single instrument" )
450- return self .pitch_fix (inferenced_output , sample_rate , orig_mix )
452+ primary = self .pitch_fix (inferenced_output , sample_rate , orig_mix )
451453 else :
452- self .logger .debug ("Returning inferenced output for single instrument" )
453- return inferenced_output
454+ primary = inferenced_output
455+
456+ if self .is_primary_stem_main_target :
457+ self .logger .debug ("Single-target model detected; computing residual secondary stem from original mix" )
458+ # Ensure shapes match before residual subtraction
459+ if primary .shape [1 ] != orig_mix .shape [1 ]:
460+ primary = spec_utils .match_array_shapes (primary , orig_mix )
461+ secondary = orig_mix - primary
462+ return {
463+ self .primary_stem_name : primary ,
464+ self .secondary_stem_name : secondary ,
465+ }
466+
467+ self .logger .debug ("Returning inferenced output for single instrument" )
468+ return primary
0 commit comments