diff --git a/README.md b/README.md
index c6b8057..16353e7 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-
+
# 🎶 Audio Separator 🎶
[](https://badge.fury.io/py/audio-separator)
@@ -318,6 +318,61 @@ The chunking feature supports all model types:
Chunks are concatenated without crossfading, which may result in minor artifacts at chunk boundaries in rare cases. For most use cases, these are not noticeable. The simple concatenation approach keeps processing time minimal while solving out-of-memory issues.
+### Ensembling Multiple Models
+
+You can combine the results of multiple models to improve separation quality. This will run each model and then combine their outputs using a specified algorithm.
+
+#### CLI Usage
+
+Use the `--model_filename` (or `-m`) flag with multiple arguments. You can also specify the ensemble algorithm using `--ensemble_algorithm`.
+
+```sh
+# Ensemble two models using the default 'avg_wave' algorithm
+audio-separator audio.wav -m model1.ckpt model2.onnx
+
+# Ensemble multiple models using a specific algorithm
+audio-separator audio.wav -m model1.ckpt model2.onnx model3.ckpt --ensemble_algorithm max_fft
+```
+
+#### Python API Usage
+
+```python
+from audio_separator.separator import Separator
+
+# Initialize the Separator class with custom parameters
+separator = Separator(
+ output_dir='output',
+ ensemble_algorithm='avg_wave'
+)
+
+# List of models to ensemble
+# Note: These models will be downloaded automatically if not present
+models = [
+ 'UVR-MDX-NET-Inst_HQ_3.onnx',
+ 'UVR_MDXNET_KARA_2.onnx'
+]
+
+# Specify multiple models for ensembling
+separator.load_model(model_filename=models)
+
+# Perform separation
+# The ensemble uses the 'avg_wave' algorithm configured during Separator initialization
+output_files = separator.separate('audio.wav')
+```
+
+#### Supported Ensemble Algorithms
+- `avg_wave`: Weighted average of waveforms (default; a plain average when no `ensemble_weights` are given)
+- `median_wave`: Median of waveforms
+- `min_wave`: Minimum of waveforms
+- `max_wave`: Maximum of waveforms
+- `avg_fft`: Weighted average of spectrograms
+- `median_fft`: Median of spectrograms
+- `min_fft`: Minimum of spectrograms
+- `max_fft`: Maximum of spectrograms
+- `uvr_max_spec`: UVR-based maximum spectrogram ensemble
+- `uvr_min_spec`: UVR-based minimum spectrogram ensemble
+- `ensemble_wav`: UVR-based least noisy chunk ensemble
+
### Full command-line interface options
```sh
@@ -525,6 +580,8 @@ You can also rename specific stems:
- **`vr_params`:** (Optional) VR Architecture Specific Attributes & Defaults. `Default: {"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False}`
- **`demucs_params`:** (Optional) Demucs Architecture Specific Attributes & Defaults. `Default: {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}`
- **`mdxc_params`:** (Optional) MDXC Architecture Specific Attributes & Defaults. `Default: {"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0}`
+- **`ensemble_algorithm`:** (Optional) Algorithm to use for ensembling multiple models. `Default: 'avg_wave'`
+- **`ensemble_weights`:** (Optional) Weights for each model in the ensemble. `Default: None` (equal weights)
## Remote API Usage 🌐
@@ -653,4 +710,4 @@ For questions or feedback, please raise an issue or reach out to @beveradb ([And

-
+
\ No newline at end of file
diff --git a/audio_separator/separator/ensembler.py b/audio_separator/separator/ensembler.py
new file mode 100644
index 0000000..10e357d
--- /dev/null
+++ b/audio_separator/separator/ensembler.py
@@ -0,0 +1,156 @@
+import numpy as np
+import librosa
+from audio_separator.separator.uvr_lib_v5 import spec_utils
+
+
class Ensembler:
    """Combines stems produced by multiple separation models into one output.

    Algorithms operate either directly on waveforms (``avg_wave``, ``median_wave``,
    ``min_wave``, ``max_wave``), on STFT spectrograms (``avg_fft``, ``median_fft``,
    ``min_fft``, ``max_fft``), or delegate to the UVR spec_utils routines
    (``uvr_max_spec``, ``uvr_min_spec``, ``ensemble_wav``).
    """

    def __init__(self, logger, algorithm="avg_wave", weights=None):
        self.logger = logger
        self.algorithm = algorithm
        # Optional per-model weights; only honored by the avg_* algorithms.
        self.weights = weights

    def ensemble(self, waveforms):
        """
        Ensemble multiple waveforms using the selected algorithm.
        :param waveforms: List of waveforms, each of shape (channels, length)
        :return: Ensembled waveform of shape (channels, length)
        """
        if not waveforms:
            return None
        if len(waveforms) == 1:
            # Nothing to combine — hand back the sole waveform untouched.
            return waveforms[0]

        # All inputs must agree on channel count before they can be combined.
        channel_count = waveforms[0].shape[0]
        for candidate in waveforms:
            if candidate.shape[0] != channel_count:
                raise ValueError("All waveforms must have the same number of channels for ensembling.")

        # Zero-pad shorter waveforms so every input matches the longest one.
        longest = max(signal.shape[1] for signal in waveforms)
        waveforms = [
            signal if signal.shape[1] == longest else np.pad(signal, ((0, 0), (0, longest - signal.shape[1])))
            for signal in waveforms
        ]

        weights = self._resolve_weights(len(waveforms))

        self.logger.debug(f"Ensembling {len(waveforms)} waveforms using algorithm {self.algorithm}")

        if self.algorithm == "avg_wave":
            # In-place accumulation keeps the dtype of the input waveforms.
            accumulator = np.zeros_like(waveforms[0])
            for signal, factor in zip(waveforms, weights):
                accumulator += signal * factor
            return accumulator / np.sum(weights)
        if self.algorithm == "median_wave":
            self._warn_if_weighted(weights)
            return np.median(np.stack(waveforms), axis=0)
        if self.algorithm == "min_wave":
            self._warn_if_weighted(weights)
            return self._select_by_key(np.stack(waveforms), np.argmin, axis=0, key=np.abs)
        if self.algorithm == "max_wave":
            self._warn_if_weighted(weights)
            return self._select_by_key(np.stack(waveforms), np.argmax, axis=0, key=np.abs)
        if self.algorithm in ("avg_fft", "median_fft", "min_fft", "max_fft"):
            return self._ensemble_fft(waveforms, weights)
        if self.algorithm == "uvr_max_spec":
            return self._ensemble_uvr(waveforms, spec_utils.MAX_SPEC)
        if self.algorithm == "uvr_min_spec":
            return self._ensemble_uvr(waveforms, spec_utils.MIN_SPEC)
        if self.algorithm == "ensemble_wav":
            return spec_utils.ensemble_wav(waveforms)
        raise ValueError(f"Unknown ensemble algorithm: {self.algorithm}")

    def _resolve_weights(self, count):
        """Return a validated weight vector of length *count*, falling back to equal weights."""
        if self.weights is None:
            return np.ones(count)
        resolved = np.array(self.weights)
        if len(resolved) != count:
            self.logger.warning(f"Number of weights ({len(resolved)}) does not match number of waveforms ({count}). Using equal weights.")
            return np.ones(count)
        # Reject non-finite weights or a zero sum (would divide by zero later).
        total = np.sum(resolved)
        if not np.all(np.isfinite(resolved)) or not np.isfinite(total) or total == 0:
            self.logger.warning(f"Weights {self.weights} contain non-finite values or sum to zero. Falling back to equal weights.")
            return np.ones(count)
        return resolved

    def _warn_if_weighted(self, weights):
        """Warn once when user-supplied, non-uniform weights will be ignored."""
        if self.weights is not None and not np.all(weights == weights[0]):
            self.logger.warning(f"Weights are ignored for algorithm {self.algorithm}")

    def _select_by_key(self, arr, index_fn, axis=None, key=None, keepdims=False):
        """Pick elements of *arr* where key(arr) is extremal per *index_fn* (np.argmin/np.argmax)."""
        idxs = index_fn(key(arr), axis)
        if axis is None:
            # Flat selection: idxs is a scalar index into the flattened array.
            return arr.flatten()[idxs]
        idxs = np.expand_dims(idxs, axis)
        picked = np.take_along_axis(arr, idxs, axis)
        return picked if keepdims else np.squeeze(picked, axis=axis)

    def _stft(self, wave, nfft=2048, hl=1024):
        """Compute per-channel STFTs, promoting mono input to two identical channels."""
        if wave.ndim == 1:
            wave = np.stack([wave, wave])
        elif wave.shape[0] == 1:
            wave = np.vstack([wave, wave])
        per_channel = [
            librosa.stft(np.asfortranarray(channel), n_fft=nfft, hop_length=hl)
            for channel in (wave[0], wave[1])
        ]
        return np.asfortranarray(per_channel)

    def _istft(self, spec, hl=1024, length=None, original_channels=None):
        """Invert a 2-channel spectrogram back to a waveform, restoring mono if requested."""
        if spec.shape[0] == 1:
            spec = np.vstack([spec, spec])
        per_channel = [
            librosa.istft(np.asfortranarray(channel_spec), hop_length=hl, length=length)
            for channel_spec in (spec[0], spec[1])
        ]
        wave = np.asfortranarray(per_channel)
        if original_channels == 1:
            # Input was mono: drop the duplicated second channel.
            wave = wave[:1, :]
        return wave

    def _ensemble_fft(self, waveforms, weights):
        """Ensemble in the spectrogram domain, then invert back to a waveform."""
        channel_count = waveforms[0].shape[0]
        target_length = waveforms[0].shape[-1]
        spec_stack = np.array([self._stft(signal) for signal in waveforms])

        if self.algorithm == "avg_fft":
            blended = np.zeros_like(spec_stack[0])
            for spectrum, factor in zip(spec_stack, weights):
                blended += spectrum * factor
            blended /= np.sum(weights)
        else:
            # median_fft / min_fft / max_fft — weights cannot apply here.
            self._warn_if_weighted(weights)
            if self.algorithm == "median_fft":
                # np.median is undefined on complex data; take real/imag medians separately.
                blended = np.median(spec_stack.real, axis=0) + 1j * np.median(spec_stack.imag, axis=0)
            elif self.algorithm == "min_fft":
                blended = self._select_by_key(spec_stack, np.argmin, axis=0, key=np.abs)
            else:
                blended = self._select_by_key(spec_stack, np.argmax, axis=0, key=np.abs)

        return self._istft(blended, length=target_length, original_channels=channel_count)

    def _ensemble_uvr(self, waveforms, uvr_algorithm):
        """Delegate ensembling to the UVR spec_utils spectrogram routines."""
        spectrograms = [spec_utils.wave_to_spectrogram_no_mp(signal) for signal in waveforms]
        combined = spec_utils.ensembling(uvr_algorithm, spectrograms)
        return spec_utils.spectrogram_to_wave_no_mp(combined)
\ No newline at end of file
diff --git a/audio_separator/separator/separator.py b/audio_separator/separator/separator.py
index 9a8e42f..476bbdf 100644
--- a/audio_separator/separator/separator.py
+++ b/audio_separator/separator/separator.py
@@ -11,6 +11,8 @@
import importlib
import io
import re
+import librosa
+import numpy as np
from typing import Optional
import hashlib
@@ -21,6 +23,7 @@
import torch.amp.autocast_mode as autocast_mode
import onnxruntime as ort
from tqdm import tqdm
+from audio_separator.separator.ensembler import Ensembler
class Separator:
@@ -100,6 +103,8 @@ def __init__(
vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0},
+ ensemble_algorithm="avg_wave",
+ ensemble_weights=None,
info_only=False,
):
"""Initialize the separator."""
@@ -189,6 +194,9 @@ def __init__(
if chunk_duration <= 0:
raise ValueError("chunk_duration must be greater than 0")
+ self.ensemble_algorithm = ensemble_algorithm
+ self.ensemble_weights = ensemble_weights
+
# These are parameters which users may want to configure so we expose them to the top-level Separator class,
# even though they are specific to a single model architecture
self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}
@@ -724,6 +732,17 @@ def load_model(self, model_filename="model_bs_roformer_ep_317_sdr_12.9755.ckpt")
This method instantiates the architecture-specific separation class,
loading the separation model into memory, downloading it first if necessary.
"""
+ if isinstance(model_filename, list):
+ if len(model_filename) > 1:
+ self.model_filename = model_filename
+ self.model_filenames = model_filename
+ self.logger.info(f"Multiple models specified for ensembling: {self.model_filenames}")
+ return
+ model_filename = model_filename[0]
+
+ self.model_filename = model_filename
+ self.model_filenames = [model_filename]
+
self.logger.info(f"Loading model {model_filename}...")
load_model_start_time = time.perf_counter()
@@ -786,7 +805,7 @@ def load_model(self, model_filename="model_bs_roformer_ep_317_sdr_12.9755.ckpt")
separator_class = getattr(module, class_name)
self.logger.debug(f"Instantiating separator class for model type {model_type}: {separator_class}")
-
+
try:
self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
except Exception as e:
@@ -804,7 +823,7 @@ def load_model(self, model_filename="model_bs_roformer_ep_317_sdr_12.9755.ckpt")
roformer_stats = self.model_instance.get_roformer_loading_stats()
if roformer_stats:
self.logger.info(f"Roformer loading stats: {roformer_stats}")
-
+
# Log the completion of the model load process
self.logger.debug("Loading model completed.")
self.logger.info(f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - load_model_start_time)))}')
@@ -825,9 +844,12 @@ def separate(self, audio_file_path, custom_output_names=None):
- output_files (list of str): A list containing the paths to the separated audio stem files.
"""
# Check if the model and device are properly initialized
- if not (self.torch_device and self.model_instance):
+ if not (self.torch_device and (self.model_instance or (isinstance(self.model_filename, list) and len(self.model_filename) > 0))):
raise ValueError("Initialization failed or model not loaded. Please load a model before attempting to separate.")
+ if isinstance(self.model_filename, list) and len(self.model_filename) > 1:
+ return self._separate_ensemble(audio_file_path, custom_output_names)
+
# If audio_file_path is a string, convert it to a list for uniform processing
if isinstance(audio_file_path, str):
audio_file_path = [audio_file_path]
@@ -1112,3 +1134,162 @@ def sort_key(item):
return dict(sorted(filtered_list.items(), key=sort_key, reverse=True))
return simplified_list
+
    def _separate_ensemble(self, audio_file_path, custom_output_names=None):
        """
        Internal method to handle ensembling of multiple models.

        Runs each model in self.model_filenames over every input file, writing the
        intermediate stems to a temporary directory, then combines stems of the same
        type (e.g. every model's "Vocals") using the configured ensemble algorithm
        and weights, writing one ensembled file per stem type to self.output_dir.

        :param audio_file_path: Path (or list of paths) to the audio file(s) to process.
        :param custom_output_names: Optional dict mapping stem names to custom output filenames.
        :return: List of paths to the final ensembled output files.
        """
        # Imported lazily: only needed on the ensemble code path.
        import tempfile
        import shutil

        if isinstance(audio_file_path, str):
            audio_file_path = [audio_file_path]

        output_files = []

        # Snapshot the model-selection state: the per-model load_model() calls in
        # the loop below overwrite these attributes, and they are restored in the
        # outer finally block.
        original_model_filename = self.model_filename
        original_model_filenames = self.model_filenames

        for path in audio_file_path:
            self.logger.info(f"Ensemble processing for file: {path}")

            # Create temporary directory for intermediate stems
            temp_dir = tempfile.mkdtemp(prefix="audio-separator-ensemble-")
            self.logger.debug(f"Created temporary directory for ensemble: {temp_dir}")

            try:
                # Store paths of intermediate stems grouped by stem name
                # { "Vocals": ["temp_dir/model1_Vocals.wav", "temp_dir/model2_Vocals.wav"], ... }
                stems_by_type = {}

                for model_filename in original_model_filenames:
                    self.logger.info(f"Processing with model: {model_filename}")

                    # Load the model
                    self.load_model(model_filename)

                    # Set temporary output directory for this model
                    original_output_dir = self.output_dir
                    self.output_dir = temp_dir
                    if self.model_instance:
                        self.model_instance.output_dir = temp_dir

                    try:
                        # Perform separation
                        # NOTE(review): _separate_file is defined elsewhere in this class;
                        # assumed to return the list of stem file paths it wrote — confirm.
                        model_stems = self._separate_file(path, custom_output_names)

                        for stem_path in model_stems:
                            # Extract stem name from filename: "audio_(Vocals)_model.wav" -> "Vocals"
                            filename = os.path.basename(stem_path)
                            match = re.search(r'_\(([^)]+)\)', filename)
                            if match:
                                stem_name = match.group(1)
                            else:
                                # No "(Stem)" marker found (e.g. custom output names);
                                # all such stems get grouped together under "Unknown".
                                stem_name = "Unknown"

                            # Normalize stem names to fix mismatched model labels
                            lower_name = stem_name.lower()
                            stem_name_map = {
                                "vocals": "Vocals",
                                "instrumental": "Instrumental",
                                "inst": "Instrumental",
                                "karaoke": "Instrumental",
                                "other": "Other",
                                "no_vocals": "Instrumental",
                                "drums": "Drums",
                                "bass": "Bass",
                                "guitar": "Guitar",
                                "piano": "Piano",
                                "synthesizer": "Synthesizer",
                                "strings": "Strings",
                                "woodwinds": "Woodwinds",
                                "brass": "Brass",
                                "wind inst": "Wind Inst",
                                "lead vocals": "Lead Vocals",
                                "backing vocals": "Backing Vocals",
                                "primary stem": "Primary Stem",
                                "secondary stem": "Secondary Stem",
                            }

                            # Anything containing "vocal" (but not lead/backing) collapses
                            # to the generic "Vocals" bucket so differently-labelled model
                            # outputs can be ensembled together.
                            if "vocal" in lower_name and "lead" not in lower_name and "backing" not in lower_name:
                                stem_name = "Vocals"
                            elif lower_name in stem_name_map:
                                stem_name = stem_name_map[lower_name]
                            else:
                                # Standardize capitalization for other stems (e.g. Drums, Bass)
                                stem_name = stem_name.title()

                            if stem_name not in stems_by_type:
                                stems_by_type[stem_name] = []

                            # Ensure absolute path
                            abs_path = stem_path if os.path.isabs(stem_path) else os.path.join(temp_dir, stem_path)
                            stems_by_type[stem_name].append(abs_path)
                    finally:
                        # Restore the real output directory even if separation failed.
                        self.output_dir = original_output_dir

                # Perform ensembling for each stem type
                ensembler = Ensembler(self.logger, self.ensemble_algorithm, self.ensemble_weights)
                base_name = os.path.splitext(os.path.basename(path))[0]

                for stem_name, stem_paths in stems_by_type.items():
                    self.logger.info(f"Ensembling {len(stem_paths)} stems for type: {stem_name}")

                    waveforms = []
                    for sp in stem_paths:
                        # mono=False keeps channel layout; stems are resampled to
                        # self.sample_rate so all inputs to the ensembler match.
                        wav, _ = librosa.load(sp, mono=False, sr=self.sample_rate)
                        if wav.ndim == 1:
                            # Duplicate mono to a 2-channel array so shapes agree.
                            wav = np.asfortranarray([wav, wav])
                        waveforms.append(wav)

                    ensembled_wav = ensembler.ensemble(waveforms)

                    # Determine output filename
                    if custom_output_names and stem_name in custom_output_names:
                        output_filename = custom_output_names[stem_name]
                    else:
                        output_filename = f"{base_name}_({stem_name})_Ensemble"

                    output_path = f"{output_filename}.{self.output_format.lower()}"

                    # Use a dummy model instance to write the audio if necessary,
                    # or just use the last model instance we had.
                    # Actually, we can use the write_audio method from the last model_instance
                    if self.model_instance:
                        # Ensure the model instance has the correct audio_file_path and output_dir
                        self.model_instance.audio_file_path = path
                        self.model_instance.output_dir = self.output_dir
                        # ensembled_wav is (channels, samples); write_audio presumably
                        # expects (samples, channels), hence the transpose — TODO confirm.
                        self.model_instance.write_audio(output_path, ensembled_wav.T)
                        final_output_path = os.path.join(self.output_dir, output_path)
                        output_files.append(final_output_path)
                    else:
                        # Fallback writer if no model instance is available
                        self.logger.warning(f"No model instance available to write ensembled audio. Using fallback writer for {output_path}")
                        final_output_path = os.path.join(self.output_dir, output_path)

                        import soundfile as sf

                        try:
                            self.logger.debug(f"Attempting to write ensembled audio to {final_output_path}...")
                            # soundfile expects frames-by-channels, hence the transpose.
                            sf.write(final_output_path, ensembled_wav.T, self.sample_rate)
                        except Exception as e:
                            # soundfile may not support the requested format; retry as WAV.
                            self.logger.error(f"Error writing {self.output_format} format: {e}. Falling back to WAV.")
                            final_output_path = final_output_path.rsplit(".", 1)[0] + ".wav"
                            sf.write(final_output_path, ensembled_wav.T, self.sample_rate)

                        output_files.append(final_output_path)

            finally:
                # Restore original model filenames state
                self.model_filename = original_model_filename
                self.model_filenames = original_model_filenames

                # Clear model instance reference
                self.model_instance = None

                # Clean up temporary directory
                if os.path.exists(temp_dir):
                    self.logger.debug(f"Cleaning up temporary directory: {temp_dir}")
                    shutil.rmtree(temp_dir, ignore_errors=True)

        return output_files
\ No newline at end of file
diff --git a/audio_separator/utils/cli.py b/audio_separator/utils/cli.py
index e6a8492..a31445a 100755
--- a/audio_separator/utils/cli.py
+++ b/audio_separator/utils/cli.py
@@ -37,7 +37,7 @@ def main():
info_params.add_argument("--list_limit", type=int, help="Limit the number of models shown")
info_params.add_argument("--list_format", choices=["pretty", "json"], default="pretty", help="Format for listing models: 'pretty' for formatted output, 'json' for raw JSON dump")
- model_filename_help = "Model to use for separation (default: %(default)s). Example: -m 2_HP-UVR.pth"
+ model_filename_help = "Model(s) to use for separation (default: %(default)s). Multiple models can be specified for ensembling. Example: -m model1.ckpt model2.onnx"
output_format_help = "Output format for separated files, any common format (default: %(default)s). Example: --output_format=MP3"
output_bitrate_help = "Output bitrate for separated files, any ffmpeg-compatible bitrate (default: %(default)s). Example: --output_bitrate=320k"
output_dir_help = "Directory to write output files (default: