From 5bf5d712d2985e3a15119deba5bf809585ec8e64 Mon Sep 17 00:00:00 2001 From: avan Date: Thu, 7 Aug 2025 22:44:01 +0800 Subject: [PATCH] Fix audio truncation by adding 20-second silent buffer Modified the `predict` function in `inference.py` to always append 20 seconds of silence to the input audio before running inference. This prevents the model from incorrectly truncating the tail end of the audio, which was happening on long, continuous files due to CNN edge effects. --- basic_pitch/inference.py | 61 ++++++++++++++++++++++++++++++++++------ 1 file changed, 52 insertions(+), 9 deletions(-) diff --git a/basic_pitch/inference.py b/basic_pitch/inference.py index 9bba9e8..1f6990d 100644 --- a/basic_pitch/inference.py +++ b/basic_pitch/inference.py @@ -242,6 +242,20 @@ def get_audio_input( yield np.expand_dims(window, axis=0), window_time, original_length +def get_audio_input_from_array( + audio_original: npt.NDArray[np.float32], overlap_len: int, hop_size: int +) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float], int]]: + """ + A version of get_audio_input that works on an in-memory numpy array. + """ + assert overlap_len % 2 == 0, f"overlap_length must be even, got {overlap_len}" + + original_length = audio_original.shape[0] + audio_padded = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original]) + for window, window_time in window_audio_file(audio_padded, hop_size): + yield np.expand_dims(window, axis=0), window_time, original_length + + def unwrap_output( output: npt.NDArray[np.float32], audio_original_length: int, @@ -272,14 +286,14 @@ def unwrap_output( def run_inference( - audio_path: Union[pathlib.Path, str], + audio_input: Union[pathlib.Path, str, npt.NDArray[np.float32]], model_or_model_path: Union[Model, pathlib.Path, str], debug_file: Optional[pathlib.Path] = None, ) -> Dict[str, np.array]: - """Run the model on the input audio path. + """Run the model on the input audio path or numpy array. Args: - audio_path: The audio to run inference on. + audio_input: The audio to run inference on, can be a file path or a numpy array. model_or_model_path: A loaded Model or path to a serialized model to load. debug_file: An optional path to output debug data to. Useful for testing/verification. @@ -297,14 +311,21 @@ def run_inference( hop_size = AUDIO_N_SAMPLES - overlap_len output: Dict[str, Any] = {"note": [], "onset": [], "contour": []} - for audio_windowed, _, audio_original_length in get_audio_input(audio_path, overlap_len, hop_size): + + # Choose the correct generator based on input type + if isinstance(audio_input, (pathlib.Path, str)): + audio_generator = get_audio_input(audio_input, overlap_len, hop_size) + else: # It's a numpy array + audio_generator = get_audio_input_from_array(audio_input, overlap_len, hop_size) + + for audio_windowed, _, audio_original_length in audio_generator: for k, v in model.predict(audio_windowed).items(): output[k].append(v) unwrapped_output = { k: unwrap_output(np.concatenate(output[k]), audio_original_length, n_overlapping_frames) for k in output } - + if debug_file: with open(debug_file, "w") as f: json.dump( @@ -317,7 +338,7 @@ def run_inference( }, f, ) - + return unwrapped_output @@ -428,6 +449,7 @@ def predict( minimum_frequency: Optional[float] = None, maximum_frequency: Optional[float] = None, multiple_pitch_bends: bool = False, + infer_onsets: bool = True, melodia_trick: bool = True, debug_file: Optional[pathlib.Path] = None, midi_tempo: float = DEFAULT_MINIMUM_MIDI_TEMPO, @@ -444,11 +466,13 @@ def predict( onset_threshold: Minimum energy required for an onset to be considered present. frame_threshold: Minimum energy requirement for a frame to be considered present. minimum_note_length: The minimum allowed note length in milliseconds. - minimum_freq: Minimum allowed output frequency, in Hz. If None, all frequencies are used. - maximum_freq: Maximum allowed output frequency, in Hz. If None, all frequencies are used. + minimum_frequency: Minimum allowed output frequency, in Hz. If None, all frequencies are used. + maximum_frequency: Maximum allowed output frequency, in Hz. If None, all frequencies are used. multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends. + infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes. melodia_trick: Use the melodia post-processing step. debug_file: An optional path to output debug data to. Useful for testing/verification. + midi_tempo: The tempo for the output midi file. Returns: The model output, midi data and note events from a single prediction """ @@ -456,12 +480,30 @@ def predict( with no_tf_warnings(): print(f"Predicting MIDI for {audio_path}...") - model_output = run_inference(audio_path, model_or_model_path, debug_file) + # --- Simplified Workflow --- + # 1. Load the entire audio file into memory. + print("Loading audio file into memory...") + y, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True) + audio_duration = len(y) / AUDIO_SAMPLE_RATE + print(f"Audio loaded. Duration: {audio_duration:.2f} seconds.") + + # 2. Add a robust padding to the end of the audio. + # A longer padding ensures the CNN has enough context at the end of the audio stream. + padding_duration_s = 20.0 + padding = np.zeros(int(AUDIO_SAMPLE_RATE * padding_duration_s), dtype=np.float32) + audio_to_process = np.concatenate([y, padding]) + + # 3. Run inference on the padded audio. + print("Running inference...") + model_output = run_inference(audio_to_process, model_or_model_path, debug_file) + + # 4. Convert model output to notes. min_note_len = int(np.round(minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP))) midi_data, note_events = infer.model_output_to_notes( model_output, onset_thresh=onset_threshold, frame_thresh=frame_threshold, + infer_onsets=infer_onsets, min_note_len=min_note_len, # convert to frames min_freq=minimum_frequency, max_freq=maximum_frequency, @@ -470,6 +512,7 @@ def predict( midi_tempo=midi_tempo, ) + # Write the aggregated results after processing all chunks if debug_file: with open(debug_file) as f: debug_data = json.load(f)