Skip to content

Commit 40f3308

Browse files
committed
Fix audio truncation by adding 20-second silent buffer
Modified the `predict` function in `inference.py` to always append 20 seconds of silence to the input audio before running inference. This prevents the model from incorrectly truncating the tail end of the audio, which was happening on long, continuous files due to CNN edge effects.
1 parent f423902 commit 40f3308

1 file changed

Lines changed: 52 additions & 9 deletions

File tree

basic_pitch/inference.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,20 @@ def get_audio_input(
234234
yield np.expand_dims(window, axis=0), window_time, original_length
235235

236236

237+
def get_audio_input_from_array(
    audio_original: npt.NDArray[np.float32], overlap_len: int, hop_size: int
) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float], int]]:
    """Yield overlapping model-input windows from an in-memory audio array.

    A version of get_audio_input that operates on a numpy array already
    loaded into memory instead of an audio file path.

    Args:
        audio_original: 1-D float32 array of audio samples.
        overlap_len: Number of samples by which adjacent windows overlap.
            Must be even, since half of it is used to pad the start of the
            signal.
        hop_size: Number of samples to advance between successive windows.

    Yields:
        Tuples of (window with a leading batch axis of 1, window-time
        metadata dict, original length in samples of the input audio).
    """
    assert overlap_len % 2 == 0, f"overlap_length must be even, got {overlap_len}"

    original_length = audio_original.shape[0]
    # Prepend half the overlap as silence so the first window is aligned the
    # same way as in the file-based generator. Use integer floor division
    # (overlap_len is an int and asserted even) rather than int(x / 2).
    audio_padded = np.concatenate(
        [np.zeros((overlap_len // 2,), dtype=np.float32), audio_original]
    )
    for window, window_time in window_audio_file(audio_padded, hop_size):
        yield np.expand_dims(window, axis=0), window_time, original_length
249+
250+
237251
def unwrap_output(
238252
output: npt.NDArray[np.float32],
239253
audio_original_length: int,
@@ -264,14 +278,14 @@ def unwrap_output(
264278

265279

266280
def run_inference(
267-
audio_path: Union[pathlib.Path, str],
281+
audio_input: Union[pathlib.Path, str, npt.NDArray[np.float32]],
268282
model_or_model_path: Union[Model, pathlib.Path, str],
269283
debug_file: Optional[pathlib.Path] = None,
270284
) -> Dict[str, np.array]:
271-
"""Run the model on the input audio path.
285+
"""Run the model on the input audio path or numpy array.
272286
273287
Args:
274-
audio_path: The audio to run inference on.
288+
audio_input: The audio to run inference on, can be a file path or a numpy array.
275289
model_or_model_path: A loaded Model or path to a serialized model to load.
276290
debug_file: An optional path to output debug data to. Useful for testing/verification.
277291
@@ -289,14 +303,21 @@ def run_inference(
289303
hop_size = AUDIO_N_SAMPLES - overlap_len
290304

291305
output: Dict[str, Any] = {"note": [], "onset": [], "contour": []}
292-
for audio_windowed, _, audio_original_length in get_audio_input(audio_path, overlap_len, hop_size):
306+
307+
# Choose the correct generator based on input type
308+
if isinstance(audio_input, (pathlib.Path, str)):
309+
audio_generator = get_audio_input(audio_input, overlap_len, hop_size)
310+
else: # It's a numpy array
311+
audio_generator = get_audio_input_from_array(audio_input, overlap_len, hop_size)
312+
313+
for audio_windowed, _, audio_original_length in audio_generator:
293314
for k, v in model.predict(audio_windowed).items():
294315
output[k].append(v)
295316

296317
unwrapped_output = {
297318
k: unwrap_output(np.concatenate(output[k]), audio_original_length, n_overlapping_frames) for k in output
298319
}
299-
320+
300321
if debug_file:
301322
with open(debug_file, "w") as f:
302323
json.dump(
@@ -309,7 +330,7 @@ def run_inference(
309330
},
310331
f,
311332
)
312-
333+
313334
return unwrapped_output
314335

315336

@@ -420,6 +441,7 @@ def predict(
420441
minimum_frequency: Optional[float] = None,
421442
maximum_frequency: Optional[float] = None,
422443
multiple_pitch_bends: bool = False,
444+
infer_onsets: bool = True,
423445
melodia_trick: bool = True,
424446
debug_file: Optional[pathlib.Path] = None,
425447
midi_tempo: float = 120,
@@ -436,24 +458,44 @@ def predict(
436458
onset_threshold: Minimum energy required for an onset to be considered present.
437459
frame_threshold: Minimum energy requirement for a frame to be considered present.
438460
minimum_note_length: The minimum allowed note length in milliseconds.
439-
minimum_freq: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
440-
maximum_freq: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
461+
minimum_frequency: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
462+
maximum_frequency: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
441463
multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
464+
infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
442465
melodia_trick: Use the melodia post-processing step.
443466
debug_file: An optional path to output debug data to. Useful for testing/verification.
467+
midi_tempo: The tempo for the output midi file.
444468
Returns:
445469
The model output, midi data and note events from a single prediction
446470
"""
447471

448472
with no_tf_warnings():
449473
print(f"Predicting MIDI for {audio_path}...")
450474

451-
model_output = run_inference(audio_path, model_or_model_path, debug_file)
475+
# --- Simplified Workflow ---
476+
# 1. Load the entire audio file into memory.
477+
print("Loading audio file into memory...")
478+
y, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
479+
audio_duration = len(y) / AUDIO_SAMPLE_RATE
480+
print(f"Audio loaded. Duration: {audio_duration:.2f} seconds.")
481+
482+
# 2. Add a robust padding to the end of the audio.
483+
# A longer padding ensures the CNN has enough context at the end of the audio stream.
484+
padding_duration_s = 20.0
485+
padding = np.zeros(int(AUDIO_SAMPLE_RATE * padding_duration_s), dtype=np.float32)
486+
audio_to_process = np.concatenate([y, padding])
487+
488+
# 3. Run inference on the padded audio.
489+
print("Running inference...")
490+
model_output = run_inference(audio_to_process, model_or_model_path, debug_file)
491+
492+
# 4. Convert model output to notes.
452493
min_note_len = int(np.round(minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP)))
453494
midi_data, note_events = infer.model_output_to_notes(
454495
model_output,
455496
onset_thresh=onset_threshold,
456497
frame_thresh=frame_threshold,
498+
infer_onsets=infer_onsets,
457499
min_note_len=min_note_len, # convert to frames
458500
min_freq=minimum_frequency,
459501
max_freq=maximum_frequency,
@@ -462,6 +504,7 @@ def predict(
462504
midi_tempo=midi_tempo,
463505
)
464506

507+
# Write the aggregated results after processing all chunks
465508
if debug_file:
466509
with open(debug_file) as f:
467510
debug_data = json.load(f)

0 commit comments

Comments
 (0)