Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions .github/actions/pylint/action.yml

This file was deleted.

31 changes: 23 additions & 8 deletions .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ on:
- 'datashare-python/**.py'
- 'worker-template/**.py'
- 'asr-worker/**.py'
- '.github/workflows/linting.yml'
- 'translation-worker/**.py'

# TODO: leverage some caching here
jobs:
Expand Down Expand Up @@ -37,26 +37,41 @@ jobs:
- name: Lint test
run: ruff check --config qa/ruff.toml worker-template

doc:
asr-worker:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/ruff-action@v3
with:
args: "--version" # skips test by displaying the version
- name: Check formatting
run: ruff format --config qa/ruff.toml --check docs
run: ruff format --config qa/ruff.toml --check asr-worker
- name: Lint test
run: ruff check --config qa/ruff.toml docs
run: ruff check --config qa/ruff.toml asr-worker

asr-worker:
translation-worker:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/pylint
name: "Lint ASR worker"
- uses: astral-sh/ruff-action@v3
with:
path: asr-worker
args: "--version" # skips test by displaying the version
- name: Check formatting
run: ruff format --config qa/ruff.toml --check translation-worker
- name: Lint test
run: ruff check --config qa/ruff.toml translation-worker

doc:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/ruff-action@v3
with:
args: "--version" # skips test by displaying the version
- name: Check formatting
run: ruff format --config qa/ruff.toml --check docs
- name: Lint test
run: ruff check --config qa/ruff.toml docs

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
Expand Down
54 changes: 54 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,60 @@ jobs:
uv sync --frozen --all-extras
uv run --frozen python -m pytest --timeout=180 -vvv --cache-clear --show-capture=all -r A

test-asr-worker:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Setup Python project
uses: actions/setup-python@v6
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Cache Docker images
uses: ScribeMD/docker-cache@0.5.0
with:
key: docker-${{ runner.os }}-${{ hashFiles('docker-compose.yml') }}
- name: Start test services
run: docker compose up -d datashare temporal-post-init elasticsearch
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
version: ${{ env.ASTRAL_VERSION }}
python-version: ${{ env.PYTHON_VERSION }}
enable-cache: true
working-directory: worker-template
- name: Run tests
run: |
cd asr-worker
uv sync --frozen --all-extras
uv run --frozen python -m pytest --timeout=180 -vvv --cache-clear --show-capture=all -r A

test-translation-worker:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Setup Python project
uses: actions/setup-python@v6
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Cache Docker images
uses: ScribeMD/docker-cache@0.5.0
with:
key: docker-${{ runner.os }}-${{ hashFiles('docker-compose.yml') }}
- name: Start test services
run: docker compose up -d datashare temporal-post-init elasticsearch
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
version: ${{ env.ASTRAL_VERSION }}
python-version: ${{ env.PYTHON_VERSION }}
enable-cache: true
working-directory: worker-template
- name: Run tests
run: |
cd translation-worker
uv sync --frozen --all-extras
uv run --frozen python -m pytest --timeout=180 -vvv --cache-clear --show-capture=all -r A

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
5 changes: 3 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ WORKDIR /app

# add task cli
ADD datashare-python/ ./datashare-python/
ADD worker-template/ ./worker-template/

# install python deps
RUN --mount=type=cache,target=~/.cache/uv \
uv pip install --system datashare-python/
RUN --mount=type=cache,target=~/.cache/uv uv pip install --system worker-template/
RUN --mount=type=cache,target=~/.cache/uv uv pip install --system datashare-python/

# copy build-independant files
ADD scripts scripts
Expand Down
1 change: 1 addition & 0 deletions asr-worker/asr_worker/activities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torchaudio

from caul.configs.parakeet import ParakeetConfig
from caul.model_handlers.helpers import ParakeetModelHandlerResult
from caul.tasks.preprocessing.helpers import PreprocessedInput
Expand Down
6 changes: 1 addition & 5 deletions asr-worker/asr_worker/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,4 @@

PARAKEET = "parakeet"

DEFAULT_TEMPORAL_ADDRESS = "temporal:7233"

RESPONSE_SUCCESS = "success"

RESPONSE_ERROR = "error"
ASR_WORKFLOW_NAME = "asr-workflow"
54 changes: 0 additions & 54 deletions asr-worker/asr_worker/models.py

This file was deleted.

40 changes: 40 additions & 0 deletions asr-worker/asr_worker/objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pydantic import BaseModel, Field
from datashare_python.objects import WorkerResponse, BasePayload

from .constants import PARAKEET


class BatchSize(BaseModel):
"""Batch size helper"""

batch_size: int = 32


class PreprocessingConfig(BatchSize):
"""Preprocessing config"""


class InferenceConfig(BatchSize):
"""Inference config"""

model_name: str = PARAKEET


class ASRPipelineConfig(BaseModel):
"""ASR pipeline config"""

preprocessing: PreprocessingConfig = Field(default_factory=PreprocessingConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig)


class ASRRequest(BasePayload):
"""Inputs to ASR workflow"""

file_paths: list[str]
pipeline: ASRPipelineConfig


class ASRResponse(WorkerResponse):
"""ASR workflow response"""

transcriptions: list[dict] = Field(default_factory=list)
12 changes: 5 additions & 7 deletions asr-worker/asr_worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,20 @@
import logging
import os
import socket
from concurrent.futures import ThreadPoolExecutor

from concurrent.futures import ThreadPoolExecutor
from temporalio.client import Client
from temporalio.worker import Worker

from temporalio import workflow

from asr_worker.constants import (
from datashare_python.constants import DEFAULT_TEMPORAL_ADDRESS
from .constants import (
ASR_TASK_QUEUE,
ASR_WORKER_NAME,
DEFAULT_TEMPORAL_ADDRESS,
)
from asr_worker.workflow import ASRWorkflow
from .workflow import ASRWorkflow

with workflow.unsafe.imports_passed_through():
from asr_worker.activities import ASRActivities
from .activities import ASRActivities

LOGGER = logging.getLogger(__name__)

Expand Down
26 changes: 12 additions & 14 deletions asr-worker/asr_worker/workflow.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
from asyncio import gather
from dataclasses import asdict
from datetime import timedelta

from more_itertools import flatten
from temporalio import workflow

from asr_worker.constants import _TEN_MINUTES, RESPONSE_ERROR, RESPONSE_SUCCESS
from asr_worker.models import ASRInputs, ASRResponse
from datashare_python.objects import WorkerResponseStatus
from .objects import ASRResponse, ASRRequest
from .constants import _TEN_MINUTES

with workflow.unsafe.imports_passed_through():
from asr_worker.activities import ASRActivities
from .activities import ASRActivities


# TODO: Figure out which modules are violating sandbox restrictions
# and grant a limited passthrough
@workflow.defn(name="asr.transcription", sandboxed=False)
@workflow.defn(sandboxed=False)
class ASRWorkflow:
"""ASR workflow definition"""

def __init__(self):
pass

@workflow.run
async def run(self, inputs: ASRInputs) -> ASRResponse:
async def run(self, inputs: ASRRequest) -> ASRResponse:
"""Run ASR workflow

:param inputs: ASRInputs
:param inputs: ASRRequest
:return: ASRResponse
"""
try:
Expand Down Expand Up @@ -81,8 +81,8 @@ async def run(self, inputs: ASRInputs) -> ASRResponse:
serialized_transcriptions = []

# drop unnecessary fields, serialize
for transcription in flatten(transcriptions):
transcription = asdict(transcription)
for trans in flatten(transcriptions):
transcription = asdict(trans)

del transcription["input_ordering"]

Expand All @@ -93,11 +93,9 @@ async def run(self, inputs: ASRInputs) -> ASRResponse:
# TODO: Output formatting; do we want to keep PreprocessedInput metadata
# and remap results to it?
return ASRResponse(
status=RESPONSE_SUCCESS, transcriptions=serialized_transcriptions
status=WorkerResponseStatus.SUCCESS,
transcriptions=serialized_transcriptions,
)
except ValueError as e:
workflow.logger.exception(e)
return ASRResponse(status=RESPONSE_ERROR, error=str(e))


WORKFLOWS = [ASRWorkflow]
return ASRResponse(status=WorkerResponseStatus.ERROR, error=str(e))
Loading