diff --git a/.github/workflows/node-server-ci.yml b/.github/workflows/node-server-ci.yml index d0e40fb..f67ce92 100644 --- a/.github/workflows/node-server-ci.yml +++ b/.github/workflows/node-server-ci.yml @@ -42,7 +42,7 @@ jobs: - name: Run tests working-directory: "node-server" - run: npm run ci-test + run: npm run test:ci - name: Run linter working-directory: "node-server" diff --git a/.github/workflows/whisper-service-ci.yml b/.github/workflows/whisper-service-ci.yml index fd864d4..648ee17 100644 --- a/.github/workflows/whisper-service-ci.yml +++ b/.github/workflows/whisper-service-ci.yml @@ -40,9 +40,9 @@ jobs: working-directory: "whisper-service" run: pytest --cov=. - - name: Run tests + - name: Run linter working-directory: "whisper-service" - run: pylint --disable=import-error $(git ls-files '*.py') + run: pylint $(git ls-files '*.py') build-cpu-container-whisper-service: needs: test-lint-whisper-service diff --git a/node-server/.editorconfig b/node-server/.editorconfig deleted file mode 100644 index 79fe802..0000000 --- a/node-server/.editorconfig +++ /dev/null @@ -1,8 +0,0 @@ -root = true - -[*] -indent_style = space -indent_size = 2 -end_of_line = lf -charset = utf-8 -insert_final_newline = true diff --git a/node-server/package-lock.json b/node-server/package-lock.json index c06c476..7cd1f23 100644 --- a/node-server/package-lock.json +++ b/node-server/package-lock.json @@ -1,13 +1,12 @@ { - "name": "scribear-backend", + "name": "scribear-node-server", "version": "0.0.0", "lockfileVersion": 3, "requires": true, "packages": { "": { - "name": "scribear-backend", + "name": "scribear-node-server", "version": "0.0.0", - "license": "MIT", "dependencies": { "@fastify/cors": "11.0.1", "@fastify/helmet": "13.0.1", @@ -43,7 +42,7 @@ "vitest": "3.0.5" }, "engines": { - "node": "^20.0.0" + "node": ">=20.0.0" } }, "node_modules/@ampproject/remapping": { diff --git a/node-server/package.json b/node-server/package.json index fdb87f1..5a207b8 100644 --- 
a/node-server/package.json +++ b/node-server/package.json @@ -1,25 +1,23 @@ { - "name": "scribear-backend", + "name": "scribear-node-server", "version": "0.0.0", - "main": "build/index.js", + "main": "build/src/index.js", "type": "module", "scripts": { "lint": "eslint", + "lint:fix": "eslint --fix", "clean": "gts clean", - "fix": "eslint --fix", "dev": "tsc-watch --compiler ts-patch/compiler/tsc.js --onSuccess \"node build/src/index.js\" | pino-pretty", - "pretest": "npm run lint", - "test": "vitest --ui --coverage", - "ci-test": "vitest run", + "test:dev": "vitest --ui --coverage", + "test:ci": "vitest run", "build": "tspc", "prestart": "npm run build", "start": "node ./build/src/index.js" }, - "author": "bwu1324", - "license": "MIT", + "author": "scribear", "description": "", "engines": { - "node": "^20.0.0" + "node": ">=20.0.0" }, "devDependencies": { "@eslint/compat": "1.2.6", diff --git a/node-server/src/index.ts b/node-server/src/index.ts index efcf12b..7107747 100644 --- a/node-server/src/index.ts +++ b/node-server/src/index.ts @@ -1,5 +1,5 @@ import loadConfig from './shared/config/load_config.js'; -import createServer from './server/start_server.js'; +import createServer from './server/create_server.js'; import createLogger from './shared/logger/logger.js'; async function init() { diff --git a/node-server/src/server/start_server.ts b/node-server/src/server/create_server.ts similarity index 100% rename from node-server/src/server/start_server.ts rename to node-server/src/server/create_server.ts diff --git a/whisper-service/.dockerignore b/whisper-service/.dockerignore index 1800114..ae493f0 100644 --- a/whisper-service/.dockerignore +++ b/whisper-service/.dockerignore @@ -1,3 +1,5 @@ +device_config.json + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/whisper-service/.gitignore b/whisper-service/.gitignore index 1800114..ae493f0 100644 --- a/whisper-service/.gitignore +++ b/whisper-service/.gitignore @@ -1,3 +1,5 @@ 
+device_config.json + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/whisper-service/create_server.py b/whisper-service/create_server.py index 17318a2..d29c1eb 100644 --- a/whisper-service/create_server.py +++ b/whisper-service/create_server.py @@ -10,21 +10,24 @@ from typing import Annotated, Callable from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query from model_bases.transcription_model_base import TranscriptionModelBase -from model_factory import model_factory -from load_config import AppConfig, load_config +from load_config import AppConfig +from init_device_config import DeviceConfig def create_server( config: AppConfig, - model_factory_func: Callable[[str, WebSocket], TranscriptionModelBase] + device_config: DeviceConfig, + model_factory_func: Callable[[DeviceConfig, + str, WebSocket], TranscriptionModelBase] ) -> FastAPI: ''' Instanciates FastAPI webserver. Parameters: - config (Config) : Application configuration object - model_factory_func (function): Function that takes in a modelKey and a WebSocket and - returns the corresponding model implementation + config (AppConfig) : Application configuration object + device_config (DeviceConfig): Application device configuration object + model_factory_func (function) : Function that takes in a modelKey and a WebSocket and + returns the corresponding model implementation Returns: FastAPI webserver @@ -56,7 +59,11 @@ async def whisper( return # Intanciate and setup requested model - transcription_model = model_factory_func(model_key, websocket) + transcription_model = model_factory_func( + device_config, + model_key, + websocket + ) transcription_model.load_model() # Send any audio chunks to transcription model @@ -69,8 +76,3 @@ async def whisper( return return fastapi_app - - -if __name__ == 'create_server': - app_config = load_config() - app = create_server(app_config, model_factory) diff --git a/whisper-service/create_server_test.py 
b/whisper-service/create_server_test.py index ed23691..f406b6c 100644 --- a/whisper-service/create_server_test.py +++ b/whisper-service/create_server_test.py @@ -10,7 +10,6 @@ from fastapi.testclient import TestClient from load_config import AppConfig from model_bases.transcription_model_base import TranscriptionModelBase - from create_server import create_server @@ -38,7 +37,11 @@ class Fake(TranscriptionModelBase): ''' def __init__(self): - super().__init__(None) + super().__init__(None, {}) + + @staticmethod + def validate_config(config): + return config def load_model(self): return None @@ -51,19 +54,27 @@ async def queue_audio_chunk(self, audio_chunk): return mock.Mock(wraps=Fake()) + @pytest.fixture(scope='function') def test_client(fake_config, fake_transcription_model): ''' Create a FastAPI test client for each test ''' - def fake_factory(model_key: str, ws: WebSocket): - if isinstance(ws, WebSocket) and model_key == 'test-model': + fake_device_config = { + 'test-model': {} + } + + def fake_factory(device_config, model_key: str, ws: WebSocket): + if (isinstance(ws, WebSocket) and + model_key == 'test-model' and + fake_device_config != device_config + ): return fake_transcription_model raise NotImplementedError( 'Invalid model key or invalid websocket argument.' 
) - app = create_server(fake_config, fake_factory) + app = create_server(fake_config, {}, fake_factory) return TestClient(app) diff --git a/whisper-service/device_config.template.json b/whisper-service/device_config.template.json new file mode 100644 index 0000000..08c9c5d --- /dev/null +++ b/whisper-service/device_config.template.json @@ -0,0 +1,22 @@ +{ + "mock_transcription_duration": { + "display_name": "Sanity Test", + "description": "Returns how many seconds of audio was received by whisper service.", + "implementation_id": "mock_transcription_duration", + "implementation_configuration": {}, + "available_features": {} + }, + "faster-whisper:cpu-tiny-en": { + "display_name": "Tiny Faster Whisper", + "description": "Faster Whisper implementation of Open AI Whisper tiny.en model.", + "implementation_id": "faster_whisper", + "implementation_configuration": { + "model": "tiny.en", + "device": "cpu", + "local_agree_dim": 2, + "min_new_samples": 48000, + "max_segment_samples": 480000 + }, + "available_features": {} + } +} diff --git a/whisper-service/index.py b/whisper-service/index.py index a0b0015..b6e4341 100644 --- a/whisper-service/index.py +++ b/whisper-service/index.py @@ -1,18 +1,28 @@ ''' Entry point for whisper-service application. 
''' +import sys import uvicorn from load_config import load_config from create_server import create_server from model_factory import model_factory +from init_device_config import init_device_config + +config = load_config() +device_config = init_device_config('device_config.json') +APP = create_server(config, device_config, model_factory) if __name__ == '__main__': - config = load_config() - app = create_server(config, model_factory) + dev_mode = len(sys.argv) > 1 and sys.argv[1] == '--dev' + + if dev_mode: + APP = 'index:APP' uvicorn.run( - app, + APP, log_level=config.LOG_LEVEL, port=config.PORT, - host=config.HOST + host=config.HOST, + use_colors=dev_mode, + reload=dev_mode ) diff --git a/whisper-service/init_device_config.py b/whisper-service/init_device_config.py new file mode 100644 index 0000000..ff68f9b --- /dev/null +++ b/whisper-service/init_device_config.py @@ -0,0 +1,117 @@ +''' +Function to load then initialize whisper service according to device config + +Functions: + init_device_config + +Types: + AvailableFeaturesConfig + ModelConfig + DeviceConfig +''' +import json +import logging +from typing import Any, TypedDict +from model_implementations.import_model_implementation import \ + ModelImplementationId, import_model_implementation +from utils.config_dict_contains import \ + config_dict_contains_dict, config_dict_contains_one_of, config_dict_contains_str + + +class AvailableFeaturesConfig(TypedDict): + ''' + Type hint for available features configuration dict + ''' + + +class ModelConfig(TypedDict): + ''' + Type hint for model configuration dict + ''' + display_name: str + description: str + implementation_id: ModelImplementationId + implementation_configuration: dict + available_features: AvailableFeaturesConfig + + +# Type hint for loaded device configuration dict +type DeviceConfig = dict[str, ModelConfig] + + +def init_model(device_config: dict[str, Any], key: str) -> ModelConfig: + ''' + Validates and initializes given model_key in
device_config. + Checks if all required properties for ModelConfig are present. Throws error if not. + Implementation configuration is checked automatically when implementation is initialized. + Models are initialized by calling load_model() then unload_model(). + + Parameters: + device_config (dict): Loaded device_config dict + key (str) : model_key to initialize + + Returns: + Validated ModelConfig dict + ''' + logger = logging.getLogger('uvicorn.error') + + # Grab config specific to model + config_dict_contains_dict(device_config, key) + model_config = device_config[key] + + # Check required properties + config_dict_contains_str(model_config, 'display_name', min_length=1) + config_dict_contains_str(model_config, 'description', min_length=1) + config_dict_contains_one_of( + model_config, 'implementation_id', list(ModelImplementationId)) + config_dict_contains_dict(model_config, 'implementation_configuration') + config_dict_contains_dict(model_config, 'available_features') + + # Initialize the configured model + implementation_id: ModelImplementationId = model_config['implementation_id'] + implementation_config = model_config['implementation_configuration'] + logger.info( + 'Initializing implementation: %s for model_key: %s', implementation_id, key + ) + + implementation = import_model_implementation(implementation_id) + model = implementation({}, implementation_config) + + model.load_model() + model.unload_model() + logger.info( + 'Successfully initialized implementation: %s for model_key: %s', implementation_id, key + ) + + return { + 'display_name': model_config['display_name'], + 'description': model_config['description'], + 'implementation_id': implementation_id, + 'implementation_configuration': implementation_config, + 'available_features': model_config['available_features'] + } + + +def init_device_config(device_config_path: str) -> DeviceConfig: + ''' + Loads device config file from provided path then initializes configured models.
+ + + Parameters: + device_config_path (str): Path to device config file + ''' + logger = logging.getLogger('uvicorn.error') + + logger.info('Loading device config from: %s', device_config_path) + with open(device_config_path, 'r', encoding='utf-8') as file: + loaded_config = json.load(file) + + if not isinstance(loaded_config, dict): + raise ValueError('Device config must be an object') + + device_config: DeviceConfig = {} + for key in loaded_config.keys(): + model_config = init_model(loaded_config, key) + device_config[key] = model_config + + return device_config diff --git a/whisper-service/model_bases/buffer_audio_model_base.py b/whisper-service/model_bases/buffer_audio_model_base.py index 6ece80f..1a4d2a0 100644 --- a/whisper-service/model_bases/buffer_audio_model_base.py +++ b/whisper-service/model_bases/buffer_audio_model_base.py @@ -4,10 +4,12 @@ Classes: BufferAudioModelBase ''' +from abc import abstractmethod import numpy as np import numpy.typing as npt -from utils.np_circular_buffer import NPCircularBuffer +from utils.config_dict_contains import config_dict_contains_int from utils.decode_wav import decode_wav +from utils.np_circular_buffer import NPCircularBuffer from model_bases.transcription_model_base import TranscriptionModelBase @@ -26,22 +28,18 @@ class BufferAudioModelBase(TranscriptionModelBase): 'num_last_processed_samples', 'num_purged_samples', 'buffer'] SAMPLE_RATE = 16_000 - def __init__( - self, - ws, - max_segment_samples=SAMPLE_RATE * 30, - min_new_samples=SAMPLE_RATE - ): + def __init__(self, ws, config): ''' + Called when a websocket requests a transcription model.
+ Parameters: - ws (WebSocket): FastAPI websocket that requested the model - max_segment_samples (int) : Maximum number of samples to be passed to process_segment() - min_new_samples (int) : Minimum number of fresh samples in buffer before - process_segment() should be called + ws (WebSocket) : FastAPI websocket that requested the model + config (TranscriptionModelConfig): Custom JSON object containing configuration for model + Defined by implementation ''' - super().__init__(ws) - self.max_segment_samples = max_segment_samples - self.min_new_samples = min_new_samples + super().__init__(ws, config) + self.max_segment_samples = config['max_segment_samples'] + self.min_new_samples = config['min_new_samples'] self.num_last_processed_samples = 0 self.num_purged_samples = 0 @@ -50,6 +48,29 @@ def __init__( dtype=np.float16 ) + @staticmethod + def validate_config(config): + ''' + Should check if loaded JSON config is valid. Called when model is instantiated. + Throw an error if provided config is not valid + Remember to call validate_config for any model_bases to ensure configuration + for model_bases is checked as well. + e.g. if you use LocalAgreeModelBase: config = LocalAgreeModelBase.validate_config(config) + + Parameters: + config (dict): Parsed JSON config from server device_config.json. Guaranteed to be a dict. + + Returns: + config (TranscriptionModelConfig): Validated config object + ''' + config_dict_contains_int(config, 'min_new_samples') + config_dict_contains_int( + config, + 'max_segment_samples', + minimum=config['min_new_samples'] + ) + return config + def load_model(self) -> None: ''' Should load model into memory to be ready for transcription.
@@ -64,6 +85,7 @@ def unload_model(self) -> None: ''' raise NotImplementedError('Must implement per model') + @abstractmethod async def process_segment( self, audio_segment: npt.NDArray, diff --git a/whisper-service/model_bases/local_agree_model_base.py b/whisper-service/model_bases/local_agree_model_base.py index 7ccdfdf..888e97a 100644 --- a/whisper-service/model_bases/local_agree_model_base.py +++ b/whisper-service/model_bases/local_agree_model_base.py @@ -5,9 +5,11 @@ TranscriptionSegment LocalAgreeModelBase ''' +from abc import abstractmethod import math import numpy.typing as npt from model_bases.buffer_audio_model_base import BufferAudioModelBase +from utils.config_dict_contains import config_dict_contains_int class TranscriptionSegment: @@ -95,6 +97,24 @@ def __init__(self, ws, *args, local_agree_dim=2, **kwargs): self.local_agree_dim = local_agree_dim self.prev_transcriptions: list[list[TranscriptionSegment]] = [] + @staticmethod + def validate_config(config): + ''' + Should check if loaded JSON config is valid. Called when model is instantiated. + Throw an error if provided config is not valid + Remember to call validate_config for any model_bases to ensure configuration + for model_bases is checked as well. + e.g. if you use LocalAgreeModelBase: config = LocalAgreeModelBase.validate_config(config) + + Parameters: + config (dict): Parsed JSON config from server device_config.json. Guaranteed to be a dict. + + Returns: + config (TranscriptionModelConfig): Validated config object + ''' + config_dict_contains_int(config, 'local_agree_dim', minimum=1) + return config + def load_model(self) -> None: ''' Should load model into memory to be ready for transcription.
@@ -109,6 +129,7 @@ def unload_model(self) -> None: ''' raise NotImplementedError('Must implement per model') + @abstractmethod async def transcribe_audio( self, audio_segment: npt.NDArray, diff --git a/whisper-service/model_bases/transcription_model_base.py b/whisper-service/model_bases/transcription_model_base.py index abb408b..ceb9b10 100644 --- a/whisper-service/model_bases/transcription_model_base.py +++ b/whisper-service/model_bases/transcription_model_base.py @@ -4,9 +4,14 @@ Classes: BackendTranscriptionBlockType TranscriptionModelBase + +Types: + TranscriptionModelConfig ''' import io import logging +from typing import Union, List, Dict +from abc import ABC, abstractmethod from enum import IntEnum from fastapi import WebSocket @@ -20,25 +25,53 @@ class BackendTranscriptionBlockType(IntEnum): IN_PROGRESS = 1 -class TranscriptionModelBase: +type JsonType = Union[None, int, str, bool, + List[JsonType], Dict[str, JsonType]] +type ImplementationModelConfig = JsonType + + +class TranscriptionModelBase(ABC): ''' Base transcription model class. Presents a unified interface for using different transcription models on the backend. - The load_model(), unload_model(), and queue_audio_chunk() methods must be implemented. + The validate_config(), load_model(), unload_model(), and + queue_audio_chunk() methods must be implemented. ''' - __slots__ = ['logger', 'ws'] + __slots__ = ['logger', 'ws', 'config'] - def __init__(self, ws: WebSocket): + def __init__(self, ws: WebSocket, config: ImplementationModelConfig): ''' Called when a websocket requests a transcription model. 
Parameters: - ws (WebSocket): FastAPI websocket that requested the model + ws (WebSocket) : FastAPI websocket that requested the model + config (TranscriptionModelConfig): Custom JSON object containing configuration for model + Defined by implementation ''' self.ws = ws + self.config = self.validate_config(config) self.logger = logging.getLogger('uvicorn.error') + @staticmethod + @abstractmethod + def validate_config(config: dict) -> ImplementationModelConfig: + ''' + Should check if loaded JSON config is valid. Called when model is instantiated. + Throw an error if provided config is not valid + Remember to call validate_config for any model_bases to ensure configuration + for model_bases is checked as well. + e.g. if you use LocalAgreeModelBase: config = LocalAgreeModelBase.validate_config(config) + + Parameters: + config (dict): Parsed JSON config from server device_config.json. Guaranteed to be a dict. + + Returns: + config (TranscriptionModelConfig): Validated config object + ''' + raise NotImplementedError('Must implement per model') + + @abstractmethod def load_model(self) -> None: ''' Should load model into memory to be ready for transcription. @@ -46,6 +79,7 @@ def load_model(self) -> None: ''' raise NotImplementedError('Must implement per model') + @abstractmethod def unload_model(self) -> None: ''' Should unload model from memory and cleanup. @@ -53,6 +87,7 @@ def unload_model(self) -> None: ''' raise NotImplementedError('Must implement per model') + @abstractmethod async def queue_audio_chunk(self, audio_chunk: io.BytesIO) -> None: ''' Called when an audio chunk is received.
diff --git a/whisper-service/model_bases/transcription_model_base_test.py b/whisper-service/model_bases/transcription_model_base_test.py index 23db815..c9381e3 100644 --- a/whisper-service/model_bases/transcription_model_base_test.py +++ b/whisper-service/model_bases/transcription_model_base_test.py @@ -1,10 +1,26 @@ ''' Unit tests for TranscriptionModelBase class ''' +# pylint: disable=redefined-outer-name import pytest from model_bases.transcription_model_base import \ TranscriptionModelBase, BackendTranscriptionBlockType +fake_config = { + 'some_param': 'string', + 'another_param': 0, + 'nested_param': { + 'array_param': ['erased'] + } +} +returned_fake_config = { + 'some_param': 'str', + 'another_param': 1, + 'nested_param': { + 'array_param': [] + } +} + class FakeWebSocket: ''' @@ -27,13 +43,38 @@ def get_sent_messages(self): return self.sent_messages +@pytest.fixture(scope='function') +def fake_implementation(): + ''' + Create a fake transcription model for each test + ''' + class Fake(TranscriptionModelBase): + ''' + Fake transcription model to track how object's methods are called + ''' + @staticmethod + def validate_config(config): + return returned_fake_config + + def load_model(self): + return None + + def unload_model(self): + return None + + async def queue_audio_chunk(self, audio_chunk): + return None + + return Fake + + @pytest.mark.asyncio -async def test_on_final_transcript(): +async def test_on_final_transcript(fake_implementation): ''' Test that on_final_transcript_block() sends correct websocket message ''' fake_ws = FakeWebSocket() - model_base = TranscriptionModelBase(fake_ws) + model_base = fake_implementation(fake_ws, fake_config) await model_base.on_final_transcript_block("Hello world", start=0, end=1) @@ -46,12 +87,12 @@ async def test_on_final_transcript(): @pytest.mark.asyncio -async def test_on_in_progress_transcript(): +async def test_on_in_progress_transcript(fake_implementation): ''' - Tests that on_in_progress_transcript_block() 
sends correct websocket message + Test that on_in_progress_transcript_block() sends correct websocket message ''' fake_ws = FakeWebSocket() - model_base = TranscriptionModelBase(fake_ws) + model_base = fake_implementation(fake_ws, fake_config) await model_base.on_in_progress_transcript_block("Processing...", start=0, end=1) @@ -61,3 +102,21 @@ async def test_on_in_progress_transcript(): assert message['text'] == "Processing..." assert message['start'] == 0 assert message['end'] == 1 + + +def test_validate_config_called(fake_implementation): + ''' + Test that validate_config() is called when model is instantiated and + return value is set as config property + ''' + fake_config = { + 'some_param': 'str', + 'another_param': 1, + 'nested_param': { + 'array_param': [] + } + } + fake_ws = FakeWebSocket() + model = fake_implementation(fake_ws, fake_config) + + assert model.config == returned_fake_config, 'config property not set' diff --git a/whisper-service/model_factory.py b/whisper-service/model_factory.py index 1ac92d5..74b7d1d 100644 --- a/whisper-service/model_factory.py +++ b/whisper-service/model_factory.py @@ -1,48 +1,36 @@ ''' -Function for instantiating specified model +Function for instantiating a specified model Functions: model_factory - -Classes: - ModelKey ''' -# pylint: disable=import-outside-toplevel from fastapi import WebSocket from model_bases.transcription_model_base import TranscriptionModelBase +from model_implementations.import_model_implementation import import_model_implementation +from init_device_config import DeviceConfig + -def model_factory(model_key: str, websocket: WebSocket) -> TranscriptionModelBase: +def model_factory( + device_config: DeviceConfig, + model_key: str, + websocket: WebSocket +) -> TranscriptionModelBase: ''' - Instantiates model with corresponding ModelKey. + Instantiates model with corresponding model_key. 
Parameters: - model_key (str) : Unique identifier for model to instantiate - websocket (WebSocket): Websocket requesting model + device_config (DeviceConfig): Device config object + model_key (str) : Unique identifier for model to instantiate + websocket (WebSocket) : Websocket requesting model Returns: A TranscriptionModelBase instance ''' - match model_key: - case 'mock-transcription-duration': - from models.mock_transcription_duration import MockTranscribeDuration - return MockTranscribeDuration(websocket) - case 'faster-whisper:gpu-large-v3': - from models.faster_whisper_model import FasterWhisperModel - return FasterWhisperModel( - websocket, - 'large-v3', - device='cuda', - local_agree_dim=2, - min_new_samples=FasterWhisperModel.SAMPLE_RATE * 3 - ) - case "faster-whisper:cpu-tiny-en": - from models.faster_whisper_model import FasterWhisperModel - return FasterWhisperModel( - websocket, - 'tiny.en', - device='cpu', - local_agree_dim=2, - min_new_samples=FasterWhisperModel.SAMPLE_RATE * 3 - ) - case _: - raise KeyError('No model matching model_key') + if model_key not in device_config: + raise KeyError('No model matching model_key') + + implementation = import_model_implementation( + device_config[model_key]['implementation_id'] + ) + + return implementation(websocket, device_config[model_key]['implementation_configuration']) diff --git a/whisper-service/models/faster_whisper_model.py b/whisper-service/model_implementations/faster_whisper_model.py similarity index 58% rename from whisper-service/models/faster_whisper_model.py rename to whisper-service/model_implementations/faster_whisper_model.py index b708f79..f85a1b0 100644 --- a/whisper-service/models/faster_whisper_model.py +++ b/whisper-service/model_implementations/faster_whisper_model.py @@ -4,7 +4,6 @@ Classes: FasterWhisperModel ''' -from fastapi import WebSocket from faster_whisper import WhisperModel from model_bases.local_agree_model_base import LocalAgreeModelBase, TranscriptionSegment @@ -13,26 
+12,47 @@ class FasterWhisperModel(LocalAgreeModelBase): ''' Implementation of TranscriptionModelBase using faster whisper and local agreement. ''' - __slots__ = ['model_size', 'device', 'model'] + __slots__ = ['model'] - def __init__(self, ws: WebSocket, model_size: str, *args, device='auto', **kwargs): + def __init__(self, ws, config): ''' + Called when a websocket requests a transcription model. + Parameters: - ws (WebSocket): FastAPI websocket that requested the model. - model_size (str): faster_whisper model to run - *args, **kwargs : Args passed to LocalAgreeModelBase + ws (WebSocket) : FastAPI websocket that requested the model + config (TranscriptionModelConfig): Custom JSON object containing configuration for model + Defined by implementation ''' - super().__init__(ws, *args, **kwargs) - self.model_size = model_size - self.device = device + super().__init__(ws, config) self.model = None + @staticmethod + def validate_config(config): + ''' + Should check if loaded JSON config is valid. Called when model is instantiated. + Throw an error if provided config is not valid + Remember to call validate_config for any model_bases to ensure configuration + for model_bases is checked as well. + e.g. if you use LocalAgreeModelBase: config = LocalAgreeModelBase.validate_config(config) + + Parameters: + config (dict): Parsed JSON config from server device_config.json. Guaranteed to be a dict. + + Returns: + config (TranscriptionModelConfig): Validated config object + ''' + config = LocalAgreeModelBase.validate_config(config) + return config + def load_model(self): ''' Loads model into memory to be ready for transcription. Called when websocket connects.
''' - self.model = WhisperModel(self.model_size, device=self.device) + self.model = WhisperModel( + self.config['model'], + device=self.config['device'] + ) def unload_model(self): ''' diff --git a/whisper-service/model_implementations/import_model_implementation.py b/whisper-service/model_implementations/import_model_implementation.py new file mode 100644 index 0000000..7a7f812 --- /dev/null +++ b/whisper-service/model_implementations/import_model_implementation.py @@ -0,0 +1,42 @@ +''' +Function for importing specified model implementation + +Functions: + import_model_implementation + +Enums: + ModelImplementationId +''' +# pylint: disable=import-outside-toplevel +from enum import StrEnum + + +class ModelImplementationId(StrEnum): + ''' + Unique keys for all available implementations of TranscriptionModelBase + ''' + MOCK_TRANSCRIPTION_DURATION = "mock_transcription_duration" + FASTER_WHISPER = "faster_whisper" + + +def import_model_implementation(model_implementation_id: ModelImplementationId): + ''' + Imports model with corresponding model_implementation_id. 
+ + Parameters: + model_implementation_id (str): Unique identifier for model to instantiate + + Returns: + A TranscriptionModelBase class + ''' + match(model_implementation_id): + case ModelImplementationId.MOCK_TRANSCRIPTION_DURATION: + from model_implementations.mock_transcription_duration import MockTranscribeDuration + return MockTranscribeDuration + case ModelImplementationId.FASTER_WHISPER: + from model_implementations.faster_whisper_model import FasterWhisperModel + return FasterWhisperModel + case _: + raise KeyError( + f'No model implementation matching {model_implementation_id}' + ) diff --git a/whisper-service/models/mock_transcription_duration.py b/whisper-service/model_implementations/mock_transcription_duration.py similarity index 67% rename from whisper-service/models/mock_transcription_duration.py rename to whisper-service/model_implementations/mock_transcription_duration.py index f0bd8d9..502b8d3 100644 --- a/whisper-service/models/mock_transcription_duration.py +++ b/whisper-service/model_implementations/mock_transcription_duration.py @@ -7,6 +7,7 @@ import wave from model_bases.transcription_model_base import TranscriptionModelBase + + class MockTranscribeDuration(TranscriptionModelBase): ''' Dummy TranscriptionModelBase implementation that returns the @@ -14,6 +15,23 @@ class MockTranscribeDuration(TranscriptionModelBase): ''' time = 0 + @staticmethod + def validate_config(config): + ''' + Should check if loaded JSON config is valid. Called when model is instantiated. + Throw an error if provided config is not valid + Remember to call validate_config for any model_bases to ensure configuration + for model_bases is checked as well. + e.g. if you use LocalAgreeModelBase: config = LocalAgreeModelBase.validate_config(config) + + Parameters: + config (dict): Parsed JSON config from server device_config.json. Guaranteed to be a dict.
+ + Returns: + config (TranscriptionModelConfig): Validated config object + ''' + return config + def load_model(self): ''' Loads model into memory to be ready for transcription. diff --git a/whisper-service/utils/config_dict_contains.py b/whisper-service/utils/config_dict_contains.py new file mode 100644 index 0000000..6cf9bc1 --- /dev/null +++ b/whisper-service/utils/config_dict_contains.py @@ -0,0 +1,108 @@ +''' +Utility functions to help when implementing TranscriptionModelBase.validate_config() + +Functions: + config_dict_contains_int + config_dict_contains_str + config_dict_contains_dict + config_dict_contains_list + config_dict_contains_one_of +''' +import sys +from typing import Any + + +def config_dict_contains_int(config: dict, key: str, minimum=-sys.maxsize - 1, maximum=sys.maxsize): + ''' + Checks if config contains a property, key, + that is an integer between minimum and maximum inclusive + + Parameters: + config (dict): Config dictionary + key (str) : Key to check in config dictionary + minimum (int) : (Optional) minimum value key is allowed to be + maximum (int) : (Optional) maximum value key is allowed to be + ''' + if key not in config: + raise ValueError(f'Config missing "{key}" property') + if not isinstance(config[key], int): + raise ValueError(f'"{key}" property of config must be an integer') + if config[key] < minimum: + raise ValueError( + f'{key} property of config must be greater than or equal to {minimum}' + ) + if config[key] > maximum: + raise ValueError( + f'{key} property of config must be less than or equal to {maximum}' + ) + + +def config_dict_contains_str(config: dict, key: str, min_length=0, max_length=sys.maxsize): + ''' + Checks if config contains a property, key, + that is a string with length between min_length and max_length inclusive + + Parameters: + config (dict) : Config dictionary + key (str) : Key to check in config dictionary + min_length (int): (Optional) minimum value key is allowed to be + min_length (int): 
(Optional) maximum value key is allowed to be + ''' + if key not in config: + raise ValueError(f'Config missing "{key}" property') + if not isinstance(config[key], str): + raise ValueError(f'"{key}" property of config must be an string') + if len(config[key]) < min_length: + raise ValueError( + f'{key} property of config must be string with length \ +greater than or equal to {min_length}' + ) + if len(config[key]) > max_length: + raise ValueError( + f'{key} property of config must be string with length \ +less than or equal to {max_length}' + ) + + +def config_dict_contains_dict(config: dict, key: str): + ''' + Checks if config contains a property, key, that is a dict + + Parameters: + config (dict): Config dictionary + key (str) : Key to check in config dictionary + ''' + if key not in config: + raise ValueError(f'Config missing "{key}" property') + if not isinstance(config[key], dict): + raise ValueError(f'"{key}" property of config must be an object') + + +def config_dict_contains_list(config: dict, key: Any): + ''' + Checks if config contains a property, key, that is a list + + Parameters: + config (dict): Config dictionary + key (str) : Key to check in config dictionary + ''' + if key not in config: + raise ValueError(f'Config missing "{key}" property') + if not isinstance(config[key], list): + raise ValueError(f'"{key}" property of config must be an array') + + +def config_dict_contains_one_of(config: dict, key: Any, options: list[Any]): + ''' + Checks if config contains a property, key, that one of the options provided + + Parameters: + config (dict): Config dictionary + key (Any) : Key to check in config dictionary + options (list): List of possible options property can have + ''' + if key not in config: + raise ValueError(f'Config missing "{key}" property') + if config[key] not in options: + raise ValueError( + f'"{key}" property of config must be one of: {options}') diff --git a/whisper-service/weights/tiny-quantized.bin 
b/whisper-service/weights/tiny-quantized.bin deleted file mode 100644 index 7fbba82..0000000 Binary files a/whisper-service/weights/tiny-quantized.bin and /dev/null differ