@@ -234,6 +234,20 @@ def get_audio_input(
234234 yield np .expand_dims (window , axis = 0 ), window_time , original_length
235235
236236
def get_audio_input_from_array(
    audio_original: npt.NDArray[np.float32], overlap_len: int, hop_size: int
) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float], int]]:
    """A version of get_audio_input that works on an in-memory numpy array.

    Args:
        audio_original: 1-D float32 audio samples. Assumed to already be at
            the model's expected sample rate — TODO confirm against the
            file-based get_audio_input.
        overlap_len: Overlap between consecutive windows, in samples. Must be
            even: half of it is prepended as zero padding before windowing.
        hop_size: Hop between window starts, in samples, passed through to
            window_audio_file.

    Yields:
        Tuples of (window, window_time, original_length) where window carries
        a leading batch axis of size 1, window_time is the timing mapping
        produced by window_audio_file, and original_length is the length of
        the UN-padded input in samples (used downstream to trim model output).

    Raises:
        ValueError: If overlap_len is odd.
    """
    # Raise instead of assert so the check survives `python -O`; the message
    # now names the actual parameter (was "overlap_length").
    if overlap_len % 2 != 0:
        raise ValueError(f"overlap_len must be even, got {overlap_len}")

    original_length = audio_original.shape[0]
    # Prepend half the overlap as silence; // keeps the length an int
    # without the float round-trip of int(overlap_len / 2).
    front_pad = np.zeros((overlap_len // 2,), dtype=np.float32)
    audio_padded = np.concatenate([front_pad, audio_original])
    for window, window_time in window_audio_file(audio_padded, hop_size):
        yield np.expand_dims(window, axis=0), window_time, original_length
250+
237251def unwrap_output (
238252 output : npt .NDArray [np .float32 ],
239253 audio_original_length : int ,
@@ -264,14 +278,14 @@ def unwrap_output(
264278
265279
266280def run_inference (
267- audio_path : Union [pathlib .Path , str ],
281+ audio_input : Union [pathlib .Path , str , npt . NDArray [ np . float32 ] ],
268282 model_or_model_path : Union [Model , pathlib .Path , str ],
269283 debug_file : Optional [pathlib .Path ] = None ,
270284) -> Dict [str , np .array ]:
271- """Run the model on the input audio path.
285+ """Run the model on the input audio path or numpy array .
272286
273287 Args:
274- audio_path : The audio to run inference on.
288+ audio_input : The audio to run inference on, can be a file path or a numpy array .
275289 model_or_model_path: A loaded Model or path to a serialized model to load.
276290 debug_file: An optional path to output debug data to. Useful for testing/verification.
277291
@@ -289,14 +303,21 @@ def run_inference(
289303 hop_size = AUDIO_N_SAMPLES - overlap_len
290304
291305 output : Dict [str , Any ] = {"note" : [], "onset" : [], "contour" : []}
292- for audio_windowed , _ , audio_original_length in get_audio_input (audio_path , overlap_len , hop_size ):
306+
307+ # Choose the correct generator based on input type
308+ if isinstance (audio_input , (pathlib .Path , str )):
309+ audio_generator = get_audio_input (audio_input , overlap_len , hop_size )
310+ else : # It's a numpy array
311+ audio_generator = get_audio_input_from_array (audio_input , overlap_len , hop_size )
312+
313+ for audio_windowed , _ , audio_original_length in audio_generator :
293314 for k , v in model .predict (audio_windowed ).items ():
294315 output [k ].append (v )
295316
296317 unwrapped_output = {
297318 k : unwrap_output (np .concatenate (output [k ]), audio_original_length , n_overlapping_frames ) for k in output
298319 }
299-
320+
300321 if debug_file :
301322 with open (debug_file , "w" ) as f :
302323 json .dump (
@@ -309,7 +330,7 @@ def run_inference(
309330 },
310331 f ,
311332 )
312-
333+
313334 return unwrapped_output
314335
315336
@@ -420,6 +441,7 @@ def predict(
420441 minimum_frequency : Optional [float ] = None ,
421442 maximum_frequency : Optional [float ] = None ,
422443 multiple_pitch_bends : bool = False ,
444+ infer_onsets : bool = True ,
423445 melodia_trick : bool = True ,
424446 debug_file : Optional [pathlib .Path ] = None ,
425447 midi_tempo : float = 120 ,
@@ -436,24 +458,44 @@ def predict(
436458 onset_threshold: Minimum energy required for an onset to be considered present.
437459 frame_threshold: Minimum energy requirement for a frame to be considered present.
438460 minimum_note_length: The minimum allowed note length in milliseconds.
439- minimum_freq : Minimum allowed output frequency, in Hz. If None, all frequencies are used.
440- maximum_freq : Maximum allowed output frequency, in Hz. If None, all frequencies are used.
461+ minimum_frequency : Minimum allowed output frequency, in Hz. If None, all frequencies are used.
462+ maximum_frequency : Maximum allowed output frequency, in Hz. If None, all frequencies are used.
441463 multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
464+ infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
442465 melodia_trick: Use the melodia post-processing step.
443466 debug_file: An optional path to output debug data to. Useful for testing/verification.
467+ midi_tempo: The tempo for the output midi file.
444468 Returns:
445469 The model output, midi data and note events from a single prediction
446470 """
447471
448472 with no_tf_warnings ():
449473 print (f"Predicting MIDI for { audio_path } ..." )
450474
451- model_output = run_inference (audio_path , model_or_model_path , debug_file )
475+ # --- Simplified Workflow ---
476+ # 1. Load the entire audio file into memory.
477+ print ("Loading audio file into memory..." )
478+ y , _ = librosa .load (str (audio_path ), sr = AUDIO_SAMPLE_RATE , mono = True )
479+ audio_duration = len (y ) / AUDIO_SAMPLE_RATE
480+ print (f"Audio loaded. Duration: { audio_duration :.2f} seconds." )
481+
482+ # 2. Add a robust padding to the end of the audio.
483+ # A longer padding ensures the CNN has enough context at the end of the audio stream.
484+ padding_duration_s = 20.0
485+ padding = np .zeros (int (AUDIO_SAMPLE_RATE * padding_duration_s ), dtype = np .float32 )
486+ audio_to_process = np .concatenate ([y , padding ])
487+
488+ # 3. Run inference on the padded audio.
489+ print ("Running inference..." )
490+ model_output = run_inference (audio_to_process , model_or_model_path , debug_file )
491+
492+ # 4. Convert model output to notes.
452493 min_note_len = int (np .round (minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP )))
453494 midi_data , note_events = infer .model_output_to_notes (
454495 model_output ,
455496 onset_thresh = onset_threshold ,
456497 frame_thresh = frame_threshold ,
498+ infer_onsets = infer_onsets ,
457499 min_note_len = min_note_len , # convert to frames
458500 min_freq = minimum_frequency ,
459501 max_freq = maximum_frequency ,
@@ -462,6 +504,7 @@ def predict(
462504 midi_tempo = midi_tempo ,
463505 )
464506
507+ # Write the aggregated results after processing all chunks
465508 if debug_file :
466509 with open (debug_file ) as f :
467510 debug_data = json .load (f )
0 commit comments