61 changes: 59 additions & 2 deletions README.md
@@ -1,5 +1,5 @@
<div align="center">

# 🎶 Audio Separator 🎶

[![PyPI version](https://badge.fury.io/py/audio-separator.svg)](https://badge.fury.io/py/audio-separator)
@@ -318,6 +318,61 @@ The chunking feature supports all model types:

Chunks are concatenated without crossfading, which may result in minor artifacts at chunk boundaries in rare cases. For most use cases, these are not noticeable. The simple concatenation approach keeps processing time minimal while solving out-of-memory issues.

### Ensembling Multiple Models

You can combine the results of multiple models to improve separation quality. This will run each model and then combine their outputs using a specified algorithm.

#### CLI Usage

Use the `--model_filename` (or `-m`) flag with multiple arguments. You can also specify the ensemble algorithm using `--ensemble_algorithm`.

```sh
# Ensemble two models using the default 'avg_wave' algorithm
audio-separator audio.wav -m model1.ckpt model2.onnx

# Ensemble multiple models using a specific algorithm
audio-separator audio.wav -m model1.ckpt model2.onnx model3.ckpt --ensemble_algorithm max_fft
```

#### Python API Usage

```python
from audio_separator.separator import Separator

# Initialize the Separator class with custom parameters
separator = Separator(
    output_dir='output',
    ensemble_algorithm='avg_wave'
)

# List of models to ensemble
# Note: These models will be downloaded automatically if not present
models = [
    'UVR-MDX-NET-Inst_HQ_3.onnx',
    'UVR_MDXNET_KARA_2.onnx'
]

# Specify multiple models for ensembling
separator.load_model(model_filename=models)

# Perform separation
# The algorithm defaults to 'avg_wave' as specified during Separator initialization
output_files = separator.separate('audio.wav')
```

#### Supported Ensemble Algorithms
- `avg_wave`: Weighted average of waveforms (default)
- `median_wave`: Median of waveforms
- `min_wave`: Minimum of waveforms
- `max_wave`: Maximum of waveforms
- `avg_fft`: Weighted average of spectrograms
- `median_fft`: Median of spectrograms
- `min_fft`: Minimum of spectrograms
- `max_fft`: Maximum of spectrograms
- `uvr_max_spec`: UVR-based maximum spectrogram ensemble
- `uvr_min_spec`: UVR-based minimum spectrogram ensemble
- `ensemble_wav`: UVR-based least noisy chunk ensemble
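
To make the difference between the averaging and the min/max families concrete, here is a minimal NumPy sketch. It is not the library's implementation. It assumes, based on `ensembler.py` in this change, that each waveform has shape `(channels, length)` and that `min_wave`/`max_wave` pick, per sample, the value with the smallest or largest absolute amplitude. The function names `weighted_average` and `pick_by_magnitude` are hypothetical.

```python
import numpy as np


def weighted_average(waveforms, weights=None):
    """Weighted mean of equally shaped waveforms, each (channels, length)."""
    stacked = np.stack(waveforms)  # (models, channels, length)
    if weights is None:
        weights = np.ones(len(waveforms))
    weights = np.asarray(weights, dtype=float)
    # Contract the model axis against the weights, then normalise by their sum.
    return np.tensordot(weights, stacked, axes=1) / weights.sum()


def pick_by_magnitude(waveforms, largest=True):
    """Per sample, keep the value whose absolute amplitude is largest (or smallest)."""
    stacked = np.stack(waveforms)
    magnitudes = np.abs(stacked)
    idx = magnitudes.argmax(axis=0) if largest else magnitudes.argmin(axis=0)
    # take_along_axis needs an index array with the same ndim as `stacked`.
    return np.take_along_axis(stacked, idx[np.newaxis], axis=0)[0]


# Two hypothetical single-channel model outputs.
a = np.array([[0.5, -0.2, 0.0]])
b = np.array([[-0.1, 0.4, 0.3]])

avg = weighted_average([a, b], weights=[3, 1])        # model `a` counts three times as much
loudest = pick_by_magnitude([a, b], largest=True)     # sign of the chosen sample is kept
quietest = pick_by_magnitude([a, b], largest=False)
```

Note that selection by magnitude keeps the original sign of the chosen sample, so the result is not the same as `np.max` or `np.min` over the raw values.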

### Full command-line interface options

```sh
@@ -525,6 +580,8 @@ You can also rename specific stems:
- **`vr_params`:** (Optional) VR Architecture Specific Attributes & Defaults. `Default: {"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False}`
- **`demucs_params`:** (Optional) Demucs Architecture Specific Attributes & Defaults. `Default: {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}`
- **`mdxc_params`:** (Optional) MDXC Architecture Specific Attributes & Defaults. `Default: {"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0}`
- **`ensemble_algorithm`:** (Optional) Algorithm to use for ensembling multiple models. `Default: 'avg_wave'`
- **`ensemble_weights`:** (Optional) Weights for each model in the ensemble. `Default: None` (equal weights)
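
As a rough illustration of how `ensemble_weights` appears to be treated, the sketch below mirrors the fallbacks visible in `ensembler.py` from this change: missing weights, a wrong number of weights, non-finite values, or a zero sum all fall back to equal weights. The helper name `resolve_weights` is hypothetical and this is not the library's code.

```python
import numpy as np


def resolve_weights(weights, num_models):
    """Return usable weights, falling back to equal weights when the input is unusable."""
    equal = np.ones(num_models)
    if weights is None:
        return equal
    weights = np.asarray(weights, dtype=float)
    if weights.shape != (num_models,):
        return equal  # wrong number of weights for the models given
    if not np.all(np.isfinite(weights)) or weights.sum() == 0:
        return equal  # NaN/inf entries, or nothing to normalise by
    return weights


favour_first = resolve_weights([2, 1], num_models=2)
mismatched = resolve_weights([1], num_models=2)
cancelling = resolve_weights([1, -1], num_models=2)
```

Under these assumptions, weights such as `[1, -1]` are rejected because the average is normalised by the sum of the weights.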

## Remote API Usage 🌐

@@ -653,4 +710,4 @@ For questions or feedback, please raise an issue or reach out to @beveradb ([And
<img src="https://contrib.rocks/image?repo=nomadkaraoke/python-audio-separator" />
</a>

</div>
</div>
156 changes: 156 additions & 0 deletions audio_separator/separator/ensembler.py
@@ -0,0 +1,156 @@
import numpy as np
import librosa
from audio_separator.separator.uvr_lib_v5 import spec_utils


class Ensembler:
    def __init__(self, logger, algorithm="avg_wave", weights=None):
        self.logger = logger
        self.algorithm = algorithm
        self.weights = weights

    def ensemble(self, waveforms):
        """
        Ensemble multiple waveforms using the selected algorithm.
        :param waveforms: List of waveforms, each of shape (channels, length)
        :return: Ensembled waveform of shape (channels, length)
        """
        if not waveforms:
            return None
        if len(waveforms) == 1:
            return waveforms[0]

        # Ensure all waveforms have the same number of channels
        num_channels = waveforms[0].shape[0]
        if any(w.shape[0] != num_channels for w in waveforms):
            raise ValueError("All waveforms must have the same number of channels for ensembling.")

        # Ensure all waveforms have the same length by padding with zeros
        max_length = max(w.shape[1] for w in waveforms)
        waveforms = [np.pad(w, ((0, 0), (0, max_length - w.shape[1]))) if w.shape[1] < max_length else w for w in waveforms]

        if self.weights is None:
            weights = np.ones(len(waveforms))
        else:
            weights = np.array(self.weights)
            if len(weights) != len(waveforms):
                self.logger.warning(f"Number of weights ({len(weights)}) does not match number of waveforms ({len(waveforms)}). Using equal weights.")
                weights = np.ones(len(waveforms))

        # Validate weights are finite and sum is non-zero
        weights_sum = np.sum(weights)
        if not np.all(np.isfinite(weights)) or not np.isfinite(weights_sum) or weights_sum == 0:
            self.logger.warning(f"Weights {self.weights} contain non-finite values or sum to zero. Falling back to equal weights.")
            weights = np.ones(len(waveforms))

        self.logger.debug(f"Ensembling {len(waveforms)} waveforms using algorithm {self.algorithm}")

        if self.algorithm == "avg_wave":
            ensembled = np.zeros_like(waveforms[0])
            for w, weight in zip(waveforms, weights, strict=True):
                ensembled += w * weight
            return ensembled / np.sum(weights)
        elif self.algorithm == "median_wave":
            if self.weights is not None and not np.all(weights == weights[0]):
                self.logger.warning(f"Weights are ignored for algorithm {self.algorithm}")
            return np.median(waveforms, axis=0)
        elif self.algorithm == "min_wave":
            if self.weights is not None and not np.all(weights == weights[0]):
                self.logger.warning(f"Weights are ignored for algorithm {self.algorithm}")
            return self._lambda_min(np.array(waveforms), axis=0, key=np.abs)
        elif self.algorithm == "max_wave":
            if self.weights is not None and not np.all(weights == weights[0]):
                self.logger.warning(f"Weights are ignored for algorithm {self.algorithm}")
            return self._lambda_max(np.array(waveforms), axis=0, key=np.abs)
        elif self.algorithm in ["avg_fft", "median_fft", "min_fft", "max_fft"]:
            return self._ensemble_fft(waveforms, weights)
        elif self.algorithm == "uvr_max_spec":
            return self._ensemble_uvr(waveforms, spec_utils.MAX_SPEC)
        elif self.algorithm == "uvr_min_spec":
            return self._ensemble_uvr(waveforms, spec_utils.MIN_SPEC)
        elif self.algorithm == "ensemble_wav":
            return spec_utils.ensemble_wav(waveforms)
        else:
            raise ValueError(f"Unknown ensemble algorithm: {self.algorithm}")

    def _lambda_max(self, arr, axis=None, key=None, keepdims=False):
        idxs = np.argmax(key(arr), axis)
        if axis is not None:
            idxs = np.expand_dims(idxs, axis)
            result = np.take_along_axis(arr, idxs, axis)
            if not keepdims:
                result = np.squeeze(result, axis=axis)
            return result
        else:
            return arr.flatten()[idxs]

    def _lambda_min(self, arr, axis=None, key=None, keepdims=False):
        idxs = np.argmin(key(arr), axis)
        if axis is not None:
            idxs = np.expand_dims(idxs, axis)
            result = np.take_along_axis(arr, idxs, axis)
            if not keepdims:
                result = np.squeeze(result, axis=axis)
            return result
        else:
            return arr.flatten()[idxs]

    def _stft(self, wave, nfft=2048, hl=1024):
        if wave.ndim == 1:
            wave = np.stack([wave, wave])
        elif wave.shape[0] == 1:
            wave = np.vstack([wave, wave])

        wave_left = np.asfortranarray(wave[0])
        wave_right = np.asfortranarray(wave[1])
        spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
        spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
        spec = np.asfortranarray([spec_left, spec_right])
        return spec

    def _istft(self, spec, hl=1024, length=None, original_channels=None):
        if spec.shape[0] == 1:
            spec = np.vstack([spec, spec])

        spec_left = np.asfortranarray(spec[0])
        spec_right = np.asfortranarray(spec[1])
        wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
        wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
        wave = np.asfortranarray([wave_left, wave_right])

        if original_channels == 1:
            wave = wave[:1, :]

        return wave

    def _ensemble_fft(self, waveforms, weights):
        num_channels = waveforms[0].shape[0]
        final_length = waveforms[0].shape[-1]
        specs = [self._stft(w) for w in waveforms]
        specs = np.array(specs)

        if self.algorithm == "avg_fft":
            ense_spec = np.zeros_like(specs[0])
            for s, weight in zip(specs, weights, strict=True):
                ense_spec += s * weight
            ense_spec /= np.sum(weights)
        elif self.algorithm in ["median_fft", "min_fft", "max_fft"]:
            if self.weights is not None and not np.all(weights == weights[0]):
                self.logger.warning(f"Weights are ignored for algorithm {self.algorithm}")

            if self.algorithm == "median_fft":
                # For complex numbers, we take median of real and imag parts separately to be safe
                real_median = np.median(np.real(specs), axis=0)
                imag_median = np.median(np.imag(specs), axis=0)
                ense_spec = real_median + 1j * imag_median
            elif self.algorithm == "min_fft":
                ense_spec = self._lambda_min(specs, axis=0, key=np.abs)
            elif self.algorithm == "max_fft":
                ense_spec = self._lambda_max(specs, axis=0, key=np.abs)

        return self._istft(ense_spec, length=final_length, original_channels=num_channels)

    def _ensemble_uvr(self, waveforms, uvr_algorithm):
        specs = [spec_utils.wave_to_spectrogram_no_mp(w) for w in waveforms]
        ense_spec = spec_utils.ensembling(uvr_algorithm, specs)
        return spec_utils.spectrogram_to_wave_no_mp(ense_spec)
Comment on lines +153 to +156
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check spec_utils for mono handling in ensemble functions
rg -n "def wave_to_spectrogram_no_mp|def spectrogram_to_wave_no_mp|def ensembling" audio_separator/separator/uvr_lib_v5/spec_utils.py -A 10

Repository: nomadkaraoke/python-audio-separator

Length of output: 983


🏁 Script executed:

# Find and examine _ensemble_fft and _ensemble_uvr methods
rg -n "def _ensemble_fft|def _ensemble_uvr" audio_separator/separator/ensembler.py -A 15

Repository: nomadkaraoke/python-audio-separator

Length of output: 1163


🏁 Script executed:

# Get the full ensembling function from spec_utils
rg -n "def ensembling" audio_separator/separator/uvr_lib_v5/spec_utils.py -A 30

Repository: nomadkaraoke/python-audio-separator

Length of output: 1300


🏁 Script executed:

# Check how _ensemble_uvr is called to understand expected input/output
rg -n "_ensemble_uvr" audio_separator/separator/ensembler.py -B 3 -A 3

Repository: nomadkaraoke/python-audio-separator

Length of output: 975


🏁 Script executed:

# Find _istft method to see how it uses original_channels
rg -n "def _istft" audio_separator/separator/ensembler.py -A 10

Repository: nomadkaraoke/python-audio-separator

Length of output: 612


🏁 Script executed:

# Also check the _stft method to understand the flow
rg -n "def _stft" audio_separator/separator/ensembler.py -A 10

Repository: nomadkaraoke/python-audio-separator

Length of output: 610


🏁 Script executed:

# Get the rest of the _istft method to see the original_channels handling
rg -n "def _istft" audio_separator/separator/ensembler.py -A 15

Repository: nomadkaraoke/python-audio-separator

Length of output: 734


`_ensemble_uvr` does not preserve mono channel shape.

Unlike `_ensemble_fft`, which captures the original channel count and passes it to `_istft` for restoration, `_ensemble_uvr` has no mechanism to track or restore the input channel configuration. The `spec_utils.wave_to_spectrogram_no_mp` function converts mono (1D) to stereo (2D) by duplicating channels, and `spec_utils.spectrogram_to_wave_no_mp` has no `original_channels` parameter to reverse this conversion. As a result, mono waveforms will be output as stereo.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/ensembler.py` around lines 153 - 156, The
_ensemble_uvr implementation doesn't preserve original channel count: capture
the original channel configuration from the input (e.g., inspect the first entry
in waveforms to determine if it was mono vs multichannel) before calling
spec_utils.wave_to_spectrogram_no_mp; after getting the waveform back from
spec_utils.spectrogram_to_wave_no_mp, if the original was mono but the returned
wave is stereo/2D, collapse it back to mono (for example by averaging channels
or taking the first channel) so mono inputs produce mono outputs — mirror the
behavior of _ensemble_fft/_istft by using original channel info to restore
shape.
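
One possible shape for the suggested fix, as a sketch only: a small helper that collapses a duplicated-stereo result back to mono when the input was mono. The name `restore_channel_count` is hypothetical, and the sketch assumes waveforms are laid out as `(channels, length)`, as the `ensemble` docstring states. It takes the first channel, matching the `wave[:1, :]` slice already used in `_istft`.

```python
import numpy as np


def restore_channel_count(wave, original_channels):
    """Collapse a duplicated-stereo result back to mono when the input was mono."""
    wave = np.atleast_2d(wave)
    if original_channels == 1 and wave.shape[0] > 1:
        # Keep only the first channel, as `_istft` does for mono inputs.
        return wave[:1, :]
    return wave


mono_in = np.array([[0.1, 0.2, 0.3]])        # shape (1, 3)
stereo_out = np.vstack([mono_in, mono_in])   # stand-in for a mono-to-stereo round trip
restored = restore_channel_count(stereo_out, original_channels=mono_in.shape[0])
```

In `_ensemble_uvr`, this would presumably be applied to the result of `spec_utils.spectrogram_to_wave_no_mp`, with `waveforms[0].shape[0]` captured before the spectrogram conversion. That wiring is not verified here.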
