diff --git a/README.md b/README.md index c6b8057..16353e7 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@
- + # 🎶 Audio Separator 🎶 [![PyPI version](https://badge.fury.io/py/audio-separator.svg)](https://badge.fury.io/py/audio-separator) @@ -318,6 +318,61 @@ The chunking feature supports all model types: Chunks are concatenated without crossfading, which may result in minor artifacts at chunk boundaries in rare cases. For most use cases, these are not noticeable. The simple concatenation approach keeps processing time minimal while solving out-of-memory issues. +### Ensembling Multiple Models + +You can combine the results of multiple models to improve separation quality. This will run each model and then combine their outputs using a specified algorithm. + +#### CLI Usage + +Use the `--model_filename` (or `-m`) flag with multiple arguments. You can also specify the ensemble algorithm using `--ensemble_algorithm`. + +```sh +# Ensemble two models using the default 'avg_wave' algorithm +audio-separator audio.wav -m model1.ckpt model2.onnx + +# Ensemble multiple models using a specific algorithm +audio-separator audio.wav -m model1.ckpt model2.onnx model3.ckpt --ensemble_algorithm max_fft +``` + +#### Python API Usage + +```python +from audio_separator.separator import Separator + +# Initialize the Separator class with custom parameters +separator = Separator( + output_dir='output', + ensemble_algorithm='avg_wave' +) + +# List of models to ensemble +# Note: These models will be downloaded automatically if not present +models = [ + 'UVR-MDX-NET-Inst_HQ_3.onnx', + 'UVR_MDXNET_KARA_2.onnx' +] + +# Specify multiple models for ensembling +separator.load_model(model_filename=models) + +# Perform separation +# The algorithm defaults to 'avg_wave' as specified during Separator initialization +output_files = separator.separate('audio.wav') +``` + +#### Supported Ensemble Algorithms +- `avg_wave`: Weighted average of waveforms (default) +- `median_wave`: Median of waveforms +- `min_wave`: Minimum of waveforms +- `max_wave`: Maximum of waveforms +- `avg_fft`: Weighted average of 
spectrograms +- `median_fft`: Median of spectrograms +- `min_fft`: Minimum of spectrograms +- `max_fft`: Maximum of spectrograms +- `uvr_max_spec`: UVR-based maximum spectrogram ensemble +- `uvr_min_spec`: UVR-based minimum spectrogram ensemble +- `ensemble_wav`: UVR-based least noisy chunk ensemble + ### Full command-line interface options ```sh @@ -525,6 +580,8 @@ You can also rename specific stems: - **`vr_params`:** (Optional) VR Architecture Specific Attributes & Defaults. `Default: {"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False}` - **`demucs_params`:** (Optional) Demucs Architecture Specific Attributes & Defaults. `Default: {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}` - **`mdxc_params`:** (Optional) MDXC Architecture Specific Attributes & Defaults. `Default: {"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0}` +- **`ensemble_algorithm`:** (Optional) Algorithm to use for ensembling multiple models. `Default: 'avg_wave'` +- **`ensemble_weights`:** (Optional) Weights for each model in the ensemble. `Default: None` (equal weights) ## Remote API Usage 🌐 @@ -653,4 +710,4 @@ For questions or feedback, please raise an issue or reach out to @beveradb ([And -
+ \ No newline at end of file diff --git a/audio_separator/separator/ensembler.py b/audio_separator/separator/ensembler.py new file mode 100644 index 0000000..10e357d --- /dev/null +++ b/audio_separator/separator/ensembler.py @@ -0,0 +1,156 @@ +import numpy as np +import librosa +from audio_separator.separator.uvr_lib_v5 import spec_utils + + +class Ensembler: + def __init__(self, logger, algorithm="avg_wave", weights=None): + self.logger = logger + self.algorithm = algorithm + self.weights = weights + + def ensemble(self, waveforms): + """ + Ensemble multiple waveforms using the selected algorithm. + :param waveforms: List of waveforms, each of shape (channels, length) + :return: Ensembled waveform of shape (channels, length) + """ + if not waveforms: + return None + if len(waveforms) == 1: + return waveforms[0] + + # Ensure all waveforms have the same number of channels + num_channels = waveforms[0].shape[0] + if any(w.shape[0] != num_channels for w in waveforms): + raise ValueError("All waveforms must have the same number of channels for ensembling.") + + # Ensure all waveforms have the same length by padding with zeros + max_length = max(w.shape[1] for w in waveforms) + waveforms = [np.pad(w, ((0, 0), (0, max_length - w.shape[1]))) if w.shape[1] < max_length else w for w in waveforms] + + if self.weights is None: + weights = np.ones(len(waveforms)) + else: + weights = np.array(self.weights) + if len(weights) != len(waveforms): + self.logger.warning(f"Number of weights ({len(weights)}) does not match number of waveforms ({len(waveforms)}). Using equal weights.") + weights = np.ones(len(waveforms)) + + # Validate weights are finite and sum is non-zero + weights_sum = np.sum(weights) + if not np.all(np.isfinite(weights)) or not np.isfinite(weights_sum) or weights_sum == 0: + self.logger.warning(f"Weights {self.weights} contain non-finite values or sum to zero. 
Falling back to equal weights.") + weights = np.ones(len(waveforms)) + + self.logger.debug(f"Ensembling {len(waveforms)} waveforms using algorithm {self.algorithm}") + + if self.algorithm == "avg_wave": + ensembled = np.zeros_like(waveforms[0]) + for w, weight in zip(waveforms, weights, strict=True): + ensembled += w * weight + return ensembled / np.sum(weights) + elif self.algorithm == "median_wave": + if self.weights is not None and not np.all(weights == weights[0]): + self.logger.warning(f"Weights are ignored for algorithm {self.algorithm}") + return np.median(waveforms, axis=0) + elif self.algorithm == "min_wave": + if self.weights is not None and not np.all(weights == weights[0]): + self.logger.warning(f"Weights are ignored for algorithm {self.algorithm}") + return self._lambda_min(np.array(waveforms), axis=0, key=np.abs) + elif self.algorithm == "max_wave": + if self.weights is not None and not np.all(weights == weights[0]): + self.logger.warning(f"Weights are ignored for algorithm {self.algorithm}") + return self._lambda_max(np.array(waveforms), axis=0, key=np.abs) + elif self.algorithm in ["avg_fft", "median_fft", "min_fft", "max_fft"]: + return self._ensemble_fft(waveforms, weights) + elif self.algorithm == "uvr_max_spec": + return self._ensemble_uvr(waveforms, spec_utils.MAX_SPEC) + elif self.algorithm == "uvr_min_spec": + return self._ensemble_uvr(waveforms, spec_utils.MIN_SPEC) + elif self.algorithm == "ensemble_wav": + return spec_utils.ensemble_wav(waveforms) + else: + raise ValueError(f"Unknown ensemble algorithm: {self.algorithm}") + + def _lambda_max(self, arr, axis=None, key=None, keepdims=False): + idxs = np.argmax(key(arr), axis) + if axis is not None: + idxs = np.expand_dims(idxs, axis) + result = np.take_along_axis(arr, idxs, axis) + if not keepdims: + result = np.squeeze(result, axis=axis) + return result + else: + return arr.flatten()[idxs] + + def _lambda_min(self, arr, axis=None, key=None, keepdims=False): + idxs = np.argmin(key(arr), 
axis) + if axis is not None: + idxs = np.expand_dims(idxs, axis) + result = np.take_along_axis(arr, idxs, axis) + if not keepdims: + result = np.squeeze(result, axis=axis) + return result + else: + return arr.flatten()[idxs] + + def _stft(self, wave, nfft=2048, hl=1024): + if wave.ndim == 1: + wave = np.stack([wave, wave]) + elif wave.shape[0] == 1: + wave = np.vstack([wave, wave]) + + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl) + spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl) + spec = np.asfortranarray([spec_left, spec_right]) + return spec + + def _istft(self, spec, hl=1024, length=None, original_channels=None): + if spec.shape[0] == 1: + spec = np.vstack([spec, spec]) + + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + wave_left = librosa.istft(spec_left, hop_length=hl, length=length) + wave_right = librosa.istft(spec_right, hop_length=hl, length=length) + wave = np.asfortranarray([wave_left, wave_right]) + + if original_channels == 1: + wave = wave[:1, :] + + return wave + + def _ensemble_fft(self, waveforms, weights): + num_channels = waveforms[0].shape[0] + final_length = waveforms[0].shape[-1] + specs = [self._stft(w) for w in waveforms] + specs = np.array(specs) + + if self.algorithm == "avg_fft": + ense_spec = np.zeros_like(specs[0]) + for s, weight in zip(specs, weights, strict=True): + ense_spec += s * weight + ense_spec /= np.sum(weights) + elif self.algorithm in ["median_fft", "min_fft", "max_fft"]: + if self.weights is not None and not np.all(weights == weights[0]): + self.logger.warning(f"Weights are ignored for algorithm {self.algorithm}") + + if self.algorithm == "median_fft": + # For complex numbers, we take median of real and imag parts separately to be safe + real_median = np.median(np.real(specs), axis=0) + imag_median = np.median(np.imag(specs), axis=0) + ense_spec = real_median + 1j * 
imag_median + elif self.algorithm == "min_fft": + ense_spec = self._lambda_min(specs, axis=0, key=np.abs) + elif self.algorithm == "max_fft": + ense_spec = self._lambda_max(specs, axis=0, key=np.abs) + + return self._istft(ense_spec, length=final_length, original_channels=num_channels) + + def _ensemble_uvr(self, waveforms, uvr_algorithm): + specs = [spec_utils.wave_to_spectrogram_no_mp(w) for w in waveforms] + ense_spec = spec_utils.ensembling(uvr_algorithm, specs) + return spec_utils.spectrogram_to_wave_no_mp(ense_spec) \ No newline at end of file diff --git a/audio_separator/separator/separator.py b/audio_separator/separator/separator.py index 9a8e42f..476bbdf 100644 --- a/audio_separator/separator/separator.py +++ b/audio_separator/separator/separator.py @@ -11,6 +11,8 @@ import importlib import io import re +import librosa +import numpy as np from typing import Optional import hashlib @@ -21,6 +23,7 @@ import torch.amp.autocast_mode as autocast_mode import onnxruntime as ort from tqdm import tqdm +from audio_separator.separator.ensembler import Ensembler class Separator: @@ -100,6 +103,8 @@ def __init__( vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False}, demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}, mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0}, + ensemble_algorithm="avg_wave", + ensemble_weights=None, info_only=False, ): """Initialize the separator.""" @@ -189,6 +194,9 @@ def __init__( if chunk_duration <= 0: raise ValueError("chunk_duration must be greater than 0") + self.ensemble_algorithm = ensemble_algorithm + self.ensemble_weights = ensemble_weights + # These are parameters which users may want to configure so we expose them to the top-level Separator class, # even though they are specific to a single model 
architecture self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params} @@ -724,6 +732,17 @@ def load_model(self, model_filename="model_bs_roformer_ep_317_sdr_12.9755.ckpt") This method instantiates the architecture-specific separation class, loading the separation model into memory, downloading it first if necessary. """ + if isinstance(model_filename, list): + if len(model_filename) > 1: + self.model_filename = model_filename + self.model_filenames = model_filename + self.logger.info(f"Multiple models specified for ensembling: {self.model_filenames}") + return + model_filename = model_filename[0] + + self.model_filename = model_filename + self.model_filenames = [model_filename] + self.logger.info(f"Loading model {model_filename}...") load_model_start_time = time.perf_counter() @@ -786,7 +805,7 @@ def load_model(self, model_filename="model_bs_roformer_ep_317_sdr_12.9755.ckpt") separator_class = getattr(module, class_name) self.logger.debug(f"Instantiating separator class for model type {model_type}: {separator_class}") - + try: self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type]) except Exception as e: @@ -804,7 +823,7 @@ def load_model(self, model_filename="model_bs_roformer_ep_317_sdr_12.9755.ckpt") roformer_stats = self.model_instance.get_roformer_loading_stats() if roformer_stats: self.logger.info(f"Roformer loading stats: {roformer_stats}") - + # Log the completion of the model load process self.logger.debug("Loading model completed.") self.logger.info(f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - load_model_start_time)))}') @@ -825,9 +844,12 @@ def separate(self, audio_file_path, custom_output_names=None): - output_files (list of str): A list containing the paths to the separated audio stem files. 
""" # Check if the model and device are properly initialized - if not (self.torch_device and self.model_instance): + if not (self.torch_device and (self.model_instance or (isinstance(self.model_filename, list) and len(self.model_filename) > 0))): raise ValueError("Initialization failed or model not loaded. Please load a model before attempting to separate.") + if isinstance(self.model_filename, list) and len(self.model_filename) > 1: + return self._separate_ensemble(audio_file_path, custom_output_names) + # If audio_file_path is a string, convert it to a list for uniform processing if isinstance(audio_file_path, str): audio_file_path = [audio_file_path] @@ -1112,3 +1134,162 @@ def sort_key(item): return dict(sorted(filtered_list.items(), key=sort_key, reverse=True)) return simplified_list + + def _separate_ensemble(self, audio_file_path, custom_output_names=None): + """ + Internal method to handle ensembling of multiple models. + """ + import tempfile + import shutil + + if isinstance(audio_file_path, str): + audio_file_path = [audio_file_path] + + output_files = [] + + original_model_filename = self.model_filename + original_model_filenames = self.model_filenames + + for path in audio_file_path: + self.logger.info(f"Ensemble processing for file: {path}") + + # Create temporary directory for intermediate stems + temp_dir = tempfile.mkdtemp(prefix="audio-separator-ensemble-") + self.logger.debug(f"Created temporary directory for ensemble: {temp_dir}") + + try: + # Store paths of intermediate stems grouped by stem name + # { "Vocals": ["temp_dir/model1_Vocals.wav", "temp_dir/model2_Vocals.wav"], ... 
} + stems_by_type = {} + + for model_filename in original_model_filenames: + self.logger.info(f"Processing with model: {model_filename}") + + # Load the model + self.load_model(model_filename) + + # Set temporary output directory for this model + original_output_dir = self.output_dir + self.output_dir = temp_dir + if self.model_instance: + self.model_instance.output_dir = temp_dir + + try: + # Perform separation + model_stems = self._separate_file(path, custom_output_names) + + for stem_path in model_stems: + # Extract stem name from filename: "audio_(Vocals)_model.wav" -> "Vocals" + filename = os.path.basename(stem_path) + match = re.search(r'_\(([^)]+)\)', filename) + if match: + stem_name = match.group(1) + else: + stem_name = "Unknown" + + # Normalize stem names to fix mismatched model labels + lower_name = stem_name.lower() + stem_name_map = { + "vocals": "Vocals", + "instrumental": "Instrumental", + "inst": "Instrumental", + "karaoke": "Instrumental", + "other": "Other", + "no_vocals": "Instrumental", + "drums": "Drums", + "bass": "Bass", + "guitar": "Guitar", + "piano": "Piano", + "synthesizer": "Synthesizer", + "strings": "Strings", + "woodwinds": "Woodwinds", + "brass": "Brass", + "wind inst": "Wind Inst", + "lead vocals": "Lead Vocals", + "backing vocals": "Backing Vocals", + "primary stem": "Primary Stem", + "secondary stem": "Secondary Stem", + } + + if "vocal" in lower_name and "lead" not in lower_name and "backing" not in lower_name: + stem_name = "Vocals" + elif lower_name in stem_name_map: + stem_name = stem_name_map[lower_name] + else: + # Standardize capitalization for other stems (e.g., Drums, Bass) + stem_name = stem_name.title() + + if stem_name not in stems_by_type: + stems_by_type[stem_name] = [] + + # Ensure absolute path + abs_path = stem_path if os.path.isabs(stem_path) else os.path.join(temp_dir, stem_path) + stems_by_type[stem_name].append(abs_path) + finally: + self.output_dir = original_output_dir + + # Perform ensembling for each stem 
type + ensembler = Ensembler(self.logger, self.ensemble_algorithm, self.ensemble_weights) + base_name = os.path.splitext(os.path.basename(path))[0] + + for stem_name, stem_paths in stems_by_type.items(): + self.logger.info(f"Ensembling {len(stem_paths)} stems for type: {stem_name}") + + waveforms = [] + for sp in stem_paths: + wav, _ = librosa.load(sp, mono=False, sr=self.sample_rate) + if wav.ndim == 1: + wav = np.asfortranarray([wav, wav]) + waveforms.append(wav) + + ensembled_wav = ensembler.ensemble(waveforms) + + # Determine output filename + if custom_output_names and stem_name in custom_output_names: + output_filename = custom_output_names[stem_name] + else: + output_filename = f"{base_name}_({stem_name})_Ensemble" + + output_path = f"{output_filename}.{self.output_format.lower()}" + + # Use a dummy model instance to write the audio if necessary, + # or just use the last model instance we had. + # Actually, we can use the write_audio method from the last model_instance + if self.model_instance: + # Ensure the model instance has the correct audio_file_path and output_dir + self.model_instance.audio_file_path = path + self.model_instance.output_dir = self.output_dir + self.model_instance.write_audio(output_path, ensembled_wav.T) + final_output_path = os.path.join(self.output_dir, output_path) + output_files.append(final_output_path) + else: + # Fallback writer if no model instance is available + self.logger.warning(f"No model instance available to write ensembled audio. Using fallback writer for {output_path}") + final_output_path = os.path.join(self.output_dir, output_path) + + import soundfile as sf + + try: + self.logger.debug(f"Attempting to write ensembled audio to {final_output_path}...") + sf.write(final_output_path, ensembled_wav.T, self.sample_rate) + except Exception as e: + self.logger.error(f"Error writing {self.output_format} format: {e}. 
Falling back to WAV.") + final_output_path = final_output_path.rsplit(".", 1)[0] + ".wav" + sf.write(final_output_path, ensembled_wav.T, self.sample_rate) + + output_files.append(final_output_path) + + finally: + # Restore original model filenames state + self.model_filename = original_model_filename + self.model_filenames = original_model_filenames + + # Clear model instance reference + self.model_instance = None + + # Clean up temporary directory + if os.path.exists(temp_dir): + self.logger.debug(f"Cleaning up temporary directory: {temp_dir}") + shutil.rmtree(temp_dir, ignore_errors=True) + + return output_files \ No newline at end of file diff --git a/audio_separator/utils/cli.py b/audio_separator/utils/cli.py index e6a8492..a31445a 100755 --- a/audio_separator/utils/cli.py +++ b/audio_separator/utils/cli.py @@ -37,7 +37,7 @@ def main(): info_params.add_argument("--list_limit", type=int, help="Limit the number of models shown") info_params.add_argument("--list_format", choices=["pretty", "json"], default="pretty", help="Format for listing models: 'pretty' for formatted output, 'json' for raw JSON dump") - model_filename_help = "Model to use for separation (default: %(default)s). Example: -m 2_HP-UVR.pth" + model_filename_help = "Model(s) to use for separation (default: %(default)s). Multiple models can be specified for ensembling. Example: -m model1.ckpt model2.onnx" output_format_help = "Output format for separated files, any common format (default: %(default)s). Example: --output_format=MP3" output_bitrate_help = "Output bitrate for separated files, any ffmpeg-compatible bitrate (default: %(default)s). Example: --output_bitrate=320k" output_dir_help = "Directory to write output files (default: ). Example: --output_dir=/app/separated" @@ -45,7 +45,7 @@ def main(): download_model_only_help = "Download a single model file only, without performing separation." 
io_params = parser.add_argument_group("Separation I/O Params") - io_params.add_argument("-m", "--model_filename", default="model_bs_roformer_ep_317_sdr_12.9755.ckpt", help=model_filename_help) + io_params.add_argument("-m", "--model_filename", default=["model_bs_roformer_ep_317_sdr_12.9755.ckpt"], nargs="+", help=model_filename_help) io_params.add_argument("--output_format", default="FLAC", help=output_format_help) io_params.add_argument("--output_bitrate", default=None, help=output_bitrate_help) io_params.add_argument("--output_dir", default=None, help=output_dir_help) @@ -60,6 +60,8 @@ def main(): use_soundfile_help = "Use soundfile to write audio output (default: %(default)s). Example: --use_soundfile" use_autocast_help = "Use PyTorch autocast for faster inference (default: %(default)s). Do not use for CPU inference. Example: --use_autocast" chunk_duration_help = "Split audio into chunks of this duration in seconds (default: %(default)s = no chunking). Useful for processing very long audio files on systems with limited memory. Recommended: 600 (10 minutes) for files >1 hour. Chunks are concatenated without overlap/crossfade. Example: --chunk_duration=600" + ensemble_algorithm_help = "Algorithm to use for ensembling multiple models (default: %(default)s). Choices: avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft, uvr_max_spec, uvr_min_spec, ensemble_wav. Example: --ensemble_algorithm=uvr_max_spec" + ensemble_weights_help = "Weights for ensembling multiple models (default: %(default)s). Number of weights must match number of models. Example: --ensemble_weights 1.0 0.5" custom_output_names_help = 'Custom names for all output files in JSON format (default: %(default)s). 
Example: --custom_output_names=\'{"Vocals": "vocals_output", "Drums": "drums_output"}\'' common_params = parser.add_argument_group("Common Separation Parameters") @@ -71,6 +73,13 @@ def main(): common_params.add_argument("--use_soundfile", action="store_true", help=use_soundfile_help) common_params.add_argument("--use_autocast", action="store_true", help=use_autocast_help) common_params.add_argument("--chunk_duration", type=float, default=None, help=chunk_duration_help) + common_params.add_argument( + "--ensemble_algorithm", + default="avg_wave", + choices=["avg_wave", "median_wave", "min_wave", "max_wave", "avg_fft", "median_fft", "min_fft", "max_fft", "uvr_max_spec", "uvr_min_spec", "ensemble_wav"], + help=ensemble_algorithm_help, + ) + common_params.add_argument("--ensemble_weights", nargs="+", type=float, default=None, help=ensemble_weights_help) common_params.add_argument("--custom_output_names", type=json.loads, default=None, help=custom_output_names_help) mdx_segment_size_help = "Larger consumes more resources, but may give better results (default: %(default)s). 
Example: --mdx_segment_size=256" @@ -175,10 +184,14 @@ def main(): sys.exit(0) if args.download_model_only: - logger.info(f"Separator version {package_version} downloading model {args.model_filename} to directory {args.model_file_dir}") + models_to_download = args.model_filename if isinstance(args.model_filename, list) else [args.model_filename] separator = Separator(log_formatter=log_formatter, log_level=log_level, model_file_dir=args.model_file_dir) - separator.download_model_and_data(args.model_filename) - logger.info(f"Model {args.model_filename} downloaded successfully.") + for model in models_to_download: + logger.info(f"Separator version {package_version} downloading model {model} to directory {args.model_file_dir}") + separator.download_model_and_data(model) + + models_string = ", ".join(models_to_download) if isinstance(models_to_download, list) else models_to_download + logger.info(f"Model {models_string} downloaded successfully.") sys.exit(0) audio_files = list(getattr(args, "audio_files", [])) @@ -203,6 +216,8 @@ def main(): use_soundfile=args.use_soundfile, use_autocast=args.use_autocast, chunk_duration=args.chunk_duration, + ensemble_algorithm=args.ensemble_algorithm, + ensemble_weights=args.ensemble_weights, mdx_params={ "hop_length": args.mdx_hop_length, "segment_size": args.mdx_segment_size, @@ -237,4 +252,4 @@ def main(): separator.load_model(model_filename=args.model_filename) output_files = separator.separate(audio_files, custom_output_names=args.custom_output_names) - logger.info(f"Separation complete! Output file(s): {' '.join(output_files)}") + logger.info(f"Separation complete! 
Output file(s): {' '.join(output_files)}") \ No newline at end of file diff --git a/tests/integration/test_24bit_preservation.py b/tests/integration/test_24bit_preservation.py index 8a3d560..99f77a1 100644 --- a/tests/integration/test_24bit_preservation.py +++ b/tests/integration/test_24bit_preservation.py @@ -72,7 +72,7 @@ def run_separation_test_24bit(model, audio_path, expected_files): # Run the CLI command result = subprocess.run( - ["audio-separator", "-m", model, audio_path], + ["audio-separator", audio_path, "-m", model], capture_output=True, text=True, check=False diff --git a/tests/integration/test_cli_integration.py b/tests/integration/test_cli_integration.py index c10473f..da38126 100644 --- a/tests/integration/test_cli_integration.py +++ b/tests/integration/test_cli_integration.py @@ -45,7 +45,7 @@ def run_separation_test(model, audio_path, expected_files): os.remove(file) # Run the CLI command - result = subprocess.run(["audio-separator", "-m", model, audio_path], capture_output=True, text=True, check=False) # Explicitly set check to False as we handle errors manually + result = subprocess.run(["audio-separator", audio_path, "-m", model], capture_output=True, text=True, check=False) # Explicitly set check to False as we handle errors manually # Check that the command completed successfully assert result.returncode == 0, f"Command failed with output: {result.stderr}" diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 4ed37d7..fc3e334 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -3,10 +3,27 @@ import logging from audio_separator.utils.cli import main import subprocess +import importlib.metadata from unittest import mock from unittest.mock import patch, MagicMock, mock_open +# Mock metadata.distribution for tests to avoid PackageNotFoundError in environment without installed package +@pytest.fixture(autouse=True) +def mock_distribution(): + original_distribution = importlib.metadata.distribution + + def 
side_effect(package_name): + if package_name == "audio-separator": + mock_dist = MagicMock() + mock_dist.version = "0.41.1" + return mock_dist + return original_distribution(package_name) + + with patch("importlib.metadata.distribution", side_effect=side_effect): + yield + + # Common fixture for expected arguments @pytest.fixture def common_expected_args(): @@ -25,6 +42,8 @@ def common_expected_args(): "use_soundfile": False, "use_autocast": False, "chunk_duration": None, + "ensemble_algorithm": "avg_wave", + "ensemble_weights": None, "mdx_params": {"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}, "vr_params": {"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False}, "demucs_params": {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}, @@ -115,7 +134,7 @@ def test_cli_model_filename_argument(common_expected_args): # Assertions mock_separator.assert_called_once_with(**common_expected_args) - mock_separator_instance.load_model.assert_called_once_with(model_filename="Custom_Model.onnx") + mock_separator_instance.load_model.assert_called_once_with(model_filename=["Custom_Model.onnx"]) # Test using output directory argument @@ -274,4 +293,4 @@ def test_cli_demucs_output_names_argument(common_expected_args): # Assertions mock_separator.assert_called_once_with(**common_expected_args) - mock_separator_instance.separate.assert_called_once_with(["test_audio.mp3"], custom_output_names=demucs_output_names) + mock_separator_instance.separate.assert_called_once_with(["test_audio.mp3"], custom_output_names=demucs_output_names) \ No newline at end of file diff --git a/tests/unit/test_ensembler.py b/tests/unit/test_ensembler.py new file mode 100644 index 0000000..2edf7a6 --- /dev/null +++ b/tests/unit/test_ensembler.py @@ -0,0 +1,111 @@ +import pytest +import numpy as np +import logging +from 
audio_separator.separator.ensembler import Ensembler + +@pytest.fixture +def logger(): + return logging.getLogger("test") + +def test_ensembler_avg_wave(logger): + # Test simple averaging + wav1 = np.ones((2, 100)) + wav2 = np.zeros((2, 100)) + ensembler = Ensembler(logger, algorithm="avg_wave") + result = ensembler.ensemble([wav1, wav2]) + assert np.allclose(result, 0.5) + +def test_ensembler_weighted_avg(logger): + # Test weighted averaging + wav1 = np.ones((2, 100)) + wav2 = np.zeros((2, 100)) + ensembler = Ensembler(logger, algorithm="avg_wave", weights=[3.0, 1.0]) + result = ensembler.ensemble([wav1, wav2]) + assert np.allclose(result, 0.75) + +def test_ensembler_different_lengths(logger): + # Test padding for different lengths + wav1 = np.ones((2, 100)) + wav2 = np.zeros((2, 80)) + ensembler = Ensembler(logger, algorithm="avg_wave") + result = ensembler.ensemble([wav1, wav2]) + assert result.shape == (2, 100) + assert np.allclose(result[:, :80], 0.5) + assert np.allclose(result[:, 80:], 0.5) # 0.5 * 1 + 0.5 * 0 + +def test_ensembler_median_wave(logger): + wav1 = np.ones((2, 100)) + wav2 = np.zeros((2, 100)) + wav3 = np.ones((2, 100)) * 0.7 + ensembler = Ensembler(logger, algorithm="median_wave") + result = ensembler.ensemble([wav1, wav2, wav3]) + assert np.allclose(result, 0.7) + +def test_ensembler_max_wave(logger): + wav1 = np.array([[1.0, -2.0], [3.0, -4.0]]) + wav2 = np.array([[0.5, -1.0], [4.0, -3.0]]) + ensembler = Ensembler(logger, algorithm="max_wave") + result = ensembler.ensemble([wav1, wav2]) + # key=np.abs, so max of (1.0, 0.5) is 1.0, (-2.0, -1.0) is -2.0, (3.0, 4.0) is 4.0, (-4.0, -3.0) is -4.0 + expected = np.array([[1.0, -2.0], [4.0, -4.0]]) + assert np.allclose(result, expected) + +def test_ensembler_min_wave(logger): + wav1 = np.array([[1.0, -2.0], [3.0, -4.0]]) + wav2 = np.array([[0.5, -1.0], [4.0, -3.0]]) + ensembler = Ensembler(logger, algorithm="min_wave") + result = ensembler.ensemble([wav1, wav2]) + # key=np.abs, so min of (1.0, 0.5) 
is 0.5, (-2.0, -1.0) is -1.0, (3.0, 4.0) is 3.0, (-4.0, -3.0) is -3.0 + expected = np.array([[0.5, -1.0], [3.0, -3.0]]) + assert np.allclose(result, expected) + +def test_ensembler_avg_fft(logger): + # FFT algorithms involve STFT/ISTFT which are harder to test with simple constants + # but we can check if it returns a valid waveform of correct shape + wav1 = np.random.rand(2, 1024) + wav2 = np.random.rand(2, 1024) + ensembler = Ensembler(logger, algorithm="avg_fft") + result = ensembler.ensemble([wav1, wav2]) + assert result.shape == (2, 1024) + +def test_ensembler_ensemble_wav_uvr(logger): + # Linear Ensemble (least noisy chunk) + wav1 = np.ones((2, 1000)) + wav2 = np.zeros((2, 1000)) + ensembler = Ensembler(logger, algorithm="ensemble_wav") + # It splits into 240 chunks by default. Each chunk in wav2 is less noisy (all 0s) + # so the result should be all 0s. + result = ensembler.ensemble([wav1, wav2]) + assert np.allclose(result, 0.0) + +def test_ensembler_empty_list(logger): + ensembler = Ensembler(logger) + assert ensembler.ensemble([]) is None + +def test_ensembler_single_waveform(logger): + wav = np.random.rand(2, 100) + ensembler = Ensembler(logger) + result = ensembler.ensemble([wav]) + assert np.array_equal(result, wav) + +def test_ensembler_mismatched_channels(logger): + wav1 = np.random.rand(2, 100) + wav2 = np.random.rand(1, 100) + ensembler = Ensembler(logger) + # Broadcasting will happen in np.zeros_like(waveforms[0]) + w * weight + # but let's see what happens. Actually it should probably be handled or at least tested. + # Current implementation pads length but not channels. 
+ with pytest.raises(ValueError): + ensembler.ensemble([wav1, wav2]) + +def test_ensembler_mono_stft(logger): + wav_mono = np.random.rand(1024) + ensembler = Ensembler(logger) + spec = ensembler._stft(wav_mono) + assert spec.shape[0] == 2 # Should be converted to stereo + +def test_ensembler_single_channel_stft(logger): + wav_mono = np.random.rand(1, 1024) + ensembler = Ensembler(logger) + spec = ensembler._stft(wav_mono) + assert spec.shape[0] == 2 # Should be converted to stereo \ No newline at end of file