From 4f5ac0c1b6751da1773f0df00b815a8523b10d4a Mon Sep 17 00:00:00 2001 From: Bennett Wu <57691028+bennettrwu@users.noreply.github.com> Date: Thu, 18 Sep 2025 18:18:23 -0500 Subject: [PATCH] fix(validate_config): fix issue where LocalAgreeModelBase doesn't call BufferAudioModelBase.validate_config --- .../model_bases/buffer_audio_model_base.py | 5 ++-- .../model_bases/local_agree_model_base.py | 1 + whisper-service/utils/config_dict_contains.py | 30 +++++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/whisper-service/model_bases/buffer_audio_model_base.py b/whisper-service/model_bases/buffer_audio_model_base.py index b098508..9f505e9 100644 --- a/whisper-service/model_bases/buffer_audio_model_base.py +++ b/whisper-service/model_bases/buffer_audio_model_base.py @@ -7,7 +7,7 @@ from abc import abstractmethod import numpy as np import numpy.typing as npt -from utils.config_dict_contains import config_dict_contains_int +from utils.config_dict_contains import config_dict_contains_int, config_dict_contains_float from utils.decode_wav import decode_wav from utils.np_circular_buffer import NPCircularBuffer from model_bases.transcription_model_base import TranscriptionModelBase @@ -26,7 +26,7 @@ class BufferAudioModelBase(TranscriptionModelBase): and process_segment() methods must be implemented. ''' __slots__ = ['max_segment_samples', 'min_new_samples', - 'num_last_processed_samples', 'num_purged_samples', 'buffer','silence_threshold'] + 'num_last_processed_samples', 'num_purged_samples', 'buffer', 'silence_threshold'] SAMPLE_RATE = 16_000 def __init__(self, ws, config): @@ -71,6 +71,7 @@ def validate_config(config: dict) -> ImplementationModelConfig: 'max_segment_samples', minimum=config['min_new_samples'] ) + config_dict_contains_float(config, 'silence_threshold') return config def load_model(self) -> None: diff --git a/whisper-service/model_bases/local_agree_model_base.py b/whisper-service/model_bases/local_agree_model_base.py index 9611dbb..b68341d 100644 --- a/whisper-service/model_bases/local_agree_model_base.py +++ b/whisper-service/model_bases/local_agree_model_base.py @@ -113,6 +113,7 @@ def validate_config(config: dict) -> ImplementationModelConfig: Returns: config (TranscriptionModelConfig): Validated config object ''' + config = BufferAudioModelBase.validate_config(config) config_dict_contains_int(config, 'local_agree_dim', minimum=1) return config diff --git a/whisper-service/utils/config_dict_contains.py b/whisper-service/utils/config_dict_contains.py index 6cf9bc1..dc8c1c4 100644 --- a/whisper-service/utils/config_dict_contains.py +++ b/whisper-service/utils/config_dict_contains.py @@ -37,6 +37,36 @@ def config_dict_contains_int(config: dict, key: str, minimum=-sys.maxsize - 1, m ) +def config_dict_contains_float( + config: dict, + key: str, + minimum=-sys.maxsize - 1, + maximum=sys.maxsize +): + ''' + Checks if config contains a property, key, + that is a float between minimum and maximum inclusive + + Parameters: + config (dict) : Config dictionary + key (str) : Key to check in config dictionary + minimum (float): (Optional) minimum value key is allowed to be + maximum (int) : (Optional) maximum value key is allowed to be + ''' + if key not in config: + raise ValueError(f'Config missing "{key}" property') + if not isinstance(config[key], float): + raise ValueError(f'"{key}" property of config must be a float') + if config[key] < minimum: + raise ValueError( + f'{key} property of config must be greater than or equal to {minimum}' + ) + if config[key] > maximum: + raise ValueError( + f'{key} property of config must be less than or equal to {maximum}' + ) + + def config_dict_contains_str(config: dict, key: str, min_length=0, max_length=sys.maxsize): ''' Checks if config contains a property, key,