diff --git a/basic_pitch/inference.py b/basic_pitch/inference.py
index 9bba9e8..1f6990d 100644
--- a/basic_pitch/inference.py
+++ b/basic_pitch/inference.py
@@ -242,6 +242,20 @@ def get_audio_input(
         yield np.expand_dims(window, axis=0), window_time, original_length
 
 
+def get_audio_input_from_array(
+    audio_original: npt.NDArray[np.float32], overlap_len: int, hop_size: int
+) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float], int]]:
+    """
+    A version of get_audio_input that works on an in-memory numpy array.
+    """
+    assert overlap_len % 2 == 0, f"overlap_length must be even, got {overlap_len}"
+
+    original_length = audio_original.shape[0]
+    audio_padded = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
+    for window, window_time in window_audio_file(audio_padded, hop_size):
+        yield np.expand_dims(window, axis=0), window_time, original_length
+
+
 def unwrap_output(
     output: npt.NDArray[np.float32],
     audio_original_length: int,
@@ -272,14 +286,14 @@ def unwrap_output(
 
 
 def run_inference(
-    audio_path: Union[pathlib.Path, str],
+    audio_input: Union[pathlib.Path, str, npt.NDArray[np.float32]],
     model_or_model_path: Union[Model, pathlib.Path, str],
     debug_file: Optional[pathlib.Path] = None,
 ) -> Dict[str, np.array]:
-    """Run the model on the input audio path.
+    """Run the model on the input audio path or numpy array.
 
     Args:
-        audio_path: The audio to run inference on.
+        audio_input: The audio to run inference on, can be a file path or a numpy array.
         model_or_model_path: A loaded Model or path to a serialized model to load.
         debug_file: An optional path to output debug data to. Useful for testing/verification.
 
@@ -297,14 +311,21 @@ def run_inference(
     hop_size = AUDIO_N_SAMPLES - overlap_len
 
     output: Dict[str, Any] = {"note": [], "onset": [], "contour": []}
-    for audio_windowed, _, audio_original_length in get_audio_input(audio_path, overlap_len, hop_size):
+
+    # Choose the correct generator based on input type
+    if isinstance(audio_input, (pathlib.Path, str)):
+        audio_generator = get_audio_input(audio_input, overlap_len, hop_size)
+    else:  # It's a numpy array
+        audio_generator = get_audio_input_from_array(audio_input, overlap_len, hop_size)
+
+    for audio_windowed, _, audio_original_length in audio_generator:
         for k, v in model.predict(audio_windowed).items():
             output[k].append(v)
 
     unwrapped_output = {
         k: unwrap_output(np.concatenate(output[k]), audio_original_length, n_overlapping_frames) for k in output
     }
-
+
     if debug_file:
         with open(debug_file, "w") as f:
             json.dump(
@@ -317,7 +338,7 @@ def run_inference(
                 },
                 f,
             )
-
+
     return unwrapped_output
 
 
@@ -428,6 +449,7 @@ def predict(
     minimum_frequency: Optional[float] = None,
     maximum_frequency: Optional[float] = None,
     multiple_pitch_bends: bool = False,
+    infer_onsets: bool = True,
     melodia_trick: bool = True,
     debug_file: Optional[pathlib.Path] = None,
     midi_tempo: float = DEFAULT_MINIMUM_MIDI_TEMPO,
@@ -444,11 +466,13 @@ def predict(
         onset_threshold: Minimum energy required for an onset to be considered present.
         frame_threshold: Minimum energy requirement for a frame to be considered present.
         minimum_note_length: The minimum allowed note length in milliseconds.
-        minimum_freq: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
-        maximum_freq: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
+        minimum_frequency: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
+        maximum_frequency: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
         multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
+        infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
         melodia_trick: Use the melodia post-processing step.
         debug_file: An optional path to output debug data to. Useful for testing/verification.
+        midi_tempo: The tempo for the output midi file.
     Returns:
         The model output, midi data and note events from a single prediction
     """
@@ -456,12 +480,30 @@ def predict(
     with no_tf_warnings():
         print(f"Predicting MIDI for {audio_path}...")
 
-        model_output = run_inference(audio_path, model_or_model_path, debug_file)
+        # --- Simplified Workflow ---
+        # 1. Load the entire audio file into memory.
+        print("Loading audio file into memory...")
+        y, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
+        audio_duration = len(y) / AUDIO_SAMPLE_RATE
+        print(f"Audio loaded. Duration: {audio_duration:.2f} seconds.")
+
+        # 2. Add a robust padding to the end of the audio.
+        #    A longer padding ensures the CNN has enough context at the end of the audio stream.
+        padding_duration_s = 20.0
+        padding = np.zeros(int(AUDIO_SAMPLE_RATE * padding_duration_s), dtype=np.float32)
+        audio_to_process = np.concatenate([y, padding])
+
+        # 3. Run inference on the padded audio.
+        print("Running inference...")
+        model_output = run_inference(audio_to_process, model_or_model_path, debug_file)
+
+        # 4. Convert model output to notes.
         min_note_len = int(np.round(minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP)))
         midi_data, note_events = infer.model_output_to_notes(
             model_output,
             onset_thresh=onset_threshold,
             frame_thresh=frame_threshold,
+            infer_onsets=infer_onsets,
            min_note_len=min_note_len,  # convert to frames
            min_freq=minimum_frequency,
            max_freq=maximum_frequency,
@@ -470,6 +512,7 @@ def predict(
             midi_tempo=midi_tempo,
         )
 
+        # Write the aggregated results after processing all chunks
        if debug_file:
            with open(debug_file) as f:
                debug_data = json.load(f)
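
Usage note (not part of the patch): a minimal sketch of how the new in-memory input path could be exercised once this change is applied. The file name "example.wav" is a placeholder; the imports assume basic_pitch's usual package layout (basic_pitch.inference, basic_pitch.constants, ICASSP_2022_MODEL_PATH).

# Minimal sketch, assuming the patch above is applied.
import librosa
import numpy as np

from basic_pitch import ICASSP_2022_MODEL_PATH
from basic_pitch.constants import AUDIO_SAMPLE_RATE
from basic_pitch.inference import run_inference

# Load audio at the model's expected sample rate, mirroring what predict() now does internally.
y, _ = librosa.load("example.wav", sr=AUDIO_SAMPLE_RATE, mono=True)  # placeholder path

# With this change, run_inference accepts an in-memory float32 array as well as a file path.
model_output = run_inference(y.astype(np.float32), ICASSP_2022_MODEL_PATH)
print({k: v.shape for k, v in model_output.items()})  # "note", "onset", "contour" arrays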