Skip to content

Commit f0932bc

Browse files
committed
initial commit
1 parent 8eb0718 commit f0932bc

24 files changed

Lines changed: 1031 additions & 145 deletions

File tree

asr-worker/asr_worker/constants.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,4 @@
88

99
PARAKEET = "parakeet"
1010

11-
DEFAULT_TEMPORAL_ADDRESS = "temporal:7233"
12-
13-
RESPONSE_SUCCESS = "success"
14-
15-
RESPONSE_ERROR = "error"
11+
ASR_WORKFLOW_NAME = "asr-workflow"
Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from dataclasses import dataclass, field, asdict
2-
from typing import Optional
32

4-
from asr_worker.constants import PARAKEET
3+
from .constants import PARAKEET
4+
5+
from datashare_python.objects import WorkerResponse
56

67

78
@dataclass
@@ -45,10 +46,7 @@ class ASRInputs:
4546
pipeline: ASRPipelineConfig
4647

4748

48-
@dataclass
49-
class ASRResponse:
49+
class ASRResponse(WorkerResponse):
5050
"""ASR workflow response"""
5151

52-
status: str
5352
transcriptions: list[dict] = field(default_factory=list)
54-
error: Optional[str] = None

asr-worker/asr_worker/worker.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99

1010
from temporalio import workflow
1111

12-
from asr_worker.constants import (
12+
from datashare_python.constants import DEFAULT_TEMPORAL_ADDRESS
13+
from .constants import (
1314
ASR_TASK_QUEUE,
1415
ASR_WORKER_NAME,
15-
DEFAULT_TEMPORAL_ADDRESS,
1616
)
17-
from asr_worker.workflow import ASRWorkflow
17+
from .workflow import ASRWorkflow
1818

1919
with workflow.unsafe.imports_passed_through():
20-
from asr_worker.activities import ASRActivities
20+
from .activities import ASRActivities
2121

2222
LOGGER = logging.getLogger(__name__)
2323

asr-worker/asr_worker/workflow.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
from more_itertools import flatten
66
from temporalio import workflow
77

8-
from asr_worker.constants import _TEN_MINUTES, RESPONSE_ERROR, RESPONSE_SUCCESS
9-
from asr_worker.models import ASRInputs, ASRResponse
8+
from asr_worker.models import ASRResponse, ASRInputs
9+
from asr_worker.constants import _TEN_MINUTES, RESPONSE_SUCCESS, RESPONSE_ERROR
1010

1111
with workflow.unsafe.imports_passed_through():
1212
from asr_worker.activities import ASRActivities
1313

1414

1515
# TODO: Figure out which modules are violating sandbox restrictions
1616
# and grant a limited passthrough
17-
@workflow.defn(name="asr.transcription", sandboxed=False)
17+
@workflow.defn(sandboxed=False)
1818
class ASRWorkflow:
1919
"""ASR workflow definition"""
2020

@@ -28,6 +28,7 @@ async def run(self, inputs: ASRInputs) -> ASRResponse:
2828
:param inputs: ASRInputs
2929
:return: ASRResponse
3030
"""
31+
3132
try:
3233
# Preprocessing
3334
preprocessed_batches = await gather(
@@ -98,6 +99,3 @@ async def run(self, inputs: ASRInputs) -> ASRResponse:
9899
except ValueError as e:
99100
workflow.logger.exception(e)
100101
return ASRResponse(status=RESPONSE_ERROR, error=str(e))
101-
102-
103-
WORKFLOWS = [ASRWorkflow]

asr-worker/pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ authors = [
2121
readme = "README.md"
2222
requires-python = ">=3.11.0, <3.14"
2323
dependencies = [
24+
"datashare-python",
2425
"temporalio>=1.22.0",
2526
"caul==0.1.5",
2627
"pyarrow<21.0.0",
@@ -37,6 +38,8 @@ workflows = "asr_worker.workflows:REGISTRY"
3738
[project.entry-points."datashare.activities"]
3839
activities = "asr_worker.activities:REGISTRY"
3940

41+
42+
4043
[[tool.uv.index]]
4144
name = "pytorch-cpu"
4245
url = "https://download.pytorch.org/whl/cpu"
@@ -52,6 +55,7 @@ torchaudio = [
5255
torchcodec = [
5356
{ index = "pytorch-cpu" }
5457
]
58+
datashare-python = { path = "../datashare-python", editable = true }
5559

5660
[tool.uv]
5761
package = true
File renamed without changes.

datashare-python/datashare_python/constants.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,15 @@
1111
DEFAULT_DS_ADDRESS = "http://localhost:8080"

DEFAULT_NAMESPACE = "datashare-default"

# Worker response statuses (mirrored by objects.WorkerResponseStatus)
WORKER_RESPONSE_SUCCESS = "success"

WORKER_RESPONSE_ERROR = "error"

# Torch backend / device names (used by utils.find_device)
CPU = "cpu"

CUDA = "cuda"

MPS = "mps"

MKL = "mkl"

datashare-python/datashare_python/objects.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from dataclasses import dataclass
44
from datetime import UTC, datetime
55
from enum import StrEnum, unique
6-
from typing import Any, Literal, Self
6+
from typing import Any, Literal, Self, Optional, TypeVar, Awaitable
77

88
from temporalio import workflow
99

10+
from datashare_python.constants import WORKER_RESPONSE_SUCCESS, WORKER_RESPONSE_ERROR
11+
1012
with workflow.unsafe.imports_passed_through():
1113
from icij_common.es import DOC_CONTENT, DOC_LANGUAGE, DOC_ROOT_ID, ID_, SOURCE
1214

@@ -23,14 +25,28 @@
2325
logger = logging.getLogger(__name__)
2426

2527

28+
T = TypeVar("T")
29+
30+
31+
Predicate = Callable[[T], bool] | Callable[[T], Awaitable[bool]]
32+
33+
2634
class BaseModel(_BaseModel):
2735
model_config = merge_configs(icij_config(), no_enum_values_config())
2836

2937

38+
class BasePayload(_BaseModel):
    """Base model for payloads exchanged with workers, using the shared icij config."""

    model_config = icij_config()
40+
41+
3042
class DatashareModel(BaseModel):
3143
model_config = merge_configs(BaseModel.model_config, lowercamel_case_config())
3244

3345

46+
class LowerCamelCaseModel(_BaseModel):
    """Base model whose fields are (de)serialized using lowerCamelCase aliases."""

    model_config = merge_configs(icij_config(), lowercamel_case_config())
48+
49+
3450
@unique
3551
class TaskState(StrEnum):
3652
CREATED = "CREATED"
@@ -142,3 +158,16 @@ def from_es(cls, es_doc: dict) -> Self:
142158
root_document=sources[DOC_ROOT_ID],
143159
tags=sources.get("tags", []),
144160
)
161+
162+
163+
# Temporal objects
class WorkerResponse(BasePayload):
    """Generic worker response.

    Concrete workflow responses subclass this and add their own result fields.
    """

    # presumably one of the WorkerResponseStatus values — verify against callers
    status: str
    # populated with the failure description when status reports an error
    error: Optional[str] = None
169+
170+
171+
class WorkerResponseStatus(StrEnum):
172+
SUCCESS = WORKER_RESPONSE_SUCCESS
173+
ERROR = WORKER_RESPONSE_ERROR

datashare-python/datashare_python/utils.py

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,19 @@
66
from dataclasses import dataclass
77
from functools import partial, wraps
88
from inspect import signature
9-
from typing import ParamSpec, TypeVar
9+
from itertools import islice
10+
from typing import (
11+
ParamSpec,
12+
TypeVar,
13+
AsyncIterator,
14+
AsyncIterable,
15+
Awaitable,
16+
Generator,
17+
Iterable,
18+
)
1019

1120
import nest_asyncio
21+
import torch
1222
from icij_common.logging_utils import (
1323
DATE_FMT,
1424
STREAM_HANDLER_FMT,
@@ -23,6 +33,8 @@
2333
from temporalio.common import SearchAttributeKey
2434
from temporalio.exceptions import ApplicationError
2535

36+
from .constants import CPU, MPS, MKL, CUDA
37+
from .objects import Predicate
2638
from .types_ import ProgressRateHandler, RawProgressHandler
2739

2840
DependencyLabel = str | None
@@ -201,6 +213,7 @@ async def wrapper(*args, **kwargs) -> T:
201213
# recreate kwargs from pargs
202214
new_args, new_kwargs = _unpack_positional_args(args, keyword_only, params)
203215
return await activity_fn(*new_args, **new_kwargs, **kwargs)
216+
204217
else:
205218

206219
@wraps(activity_fn)
@@ -211,9 +224,11 @@ def wrapper(*args, **kwargs) -> T:
211224

212225
# Update the decorated function signature to appear as p-args only
213226
new_params = [
214-
p.replace(kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
215-
if p.kind == inspect.Parameter.KEYWORD_ONLY
216-
else p
227+
(
228+
p.replace(kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
229+
if p.kind == inspect.Parameter.KEYWORD_ONLY
230+
else p
231+
)
217232
for p in params
218233
]
219234
wrapper.__signature__ = sig.replace(parameters=new_params)
@@ -252,6 +267,7 @@ async def wrapper(*args, **kwargs) -> T:
252267
raise
253268
except Exception as e:
254269
raise fatal_error_from_exception(e) from e
270+
255271
else:
256272

257273
@wraps(activity_fn)
@@ -374,3 +390,83 @@ def _handlers(
374390
handler.addFilter(worker_id_filter)
375391
handler.setLevel(log_level)
376392
return handlers
393+
394+
395+
# Temporal utils
396+
async def async_batches(
397+
iterable: AsyncIterable[T], batch_size: int
398+
) -> AsyncIterator[tuple[T]]:
399+
it = aiter(iterable)
400+
if batch_size < 1:
401+
raise ValueError("n must be at least one")
402+
while True:
403+
batch = []
404+
while len(batch) < batch_size:
405+
try:
406+
batch.append(await anext(it))
407+
except StopAsyncIteration:
408+
if batch:
409+
yield tuple(batch)
410+
return
411+
yield tuple(batch)
412+
413+
414+
def batches(
    iterable: Iterable[T], batch_size: int
) -> Generator[tuple[T, ...], None, None]:
    """Lazily regroup *iterable* into tuples of at most *batch_size* items.

    The final batch may be shorter than *batch_size*. Synchronous counterpart
    of :func:`async_batches`.

    :param iterable: source iterable
    :param batch_size: maximum number of items per batch, must be >= 1
    :raises ValueError: if batch_size < 1 (raised on first iteration, since
        this is a generator and its body runs lazily)
    """
    # Error message names the actual parameter (the itertools.batched recipe
    # this is based on calls it "n", which is misleading here).
    if batch_size < 1:
        raise ValueError("batch_size must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, batch_size)):
        yield batch
422+
423+
424+
async def maybe_await(maybe_awaitable: Awaitable[T] | T) -> T:
425+
if inspect.isawaitable(maybe_awaitable):
426+
return await maybe_awaitable
427+
return maybe_awaitable
428+
429+
430+
async def once(item: T) -> AsyncIterator[T]:
    """Async iterator that yields *item* exactly once, then stops."""
    yield item
432+
433+
434+
def before_and_after(
435+
iterable: AsyncIterable[T], predicate: Predicate[T]
436+
) -> tuple[AsyncIterable[T], AsyncIterable[T]]:
437+
transition = asyncio.get_event_loop().create_future()
438+
439+
async def true_iterator() -> AsyncIterator[T]:
440+
async for elem in iterable:
441+
if await maybe_await(predicate(elem)):
442+
yield elem
443+
else:
444+
transition.set_result(elem)
445+
return
446+
transition.set_exception(StopAsyncIteration)
447+
448+
async def remainder_iterator() -> AsyncIterator[T]:
449+
try:
450+
yield await transition
451+
except StopAsyncIteration:
452+
return
453+
async for elm in iterable:
454+
yield elm
455+
456+
return true_iterator(), remainder_iterator()
457+
458+
459+
# Torch utils
def find_device(device_name: str = CPU) -> torch.device:
    """Find a device by name if available, falling back to CPU.

    :param device_name: device/backend name, e.g. "cpu", "mps", "mkl", "cuda"
    :return: torch.device for *device_name* when available, else CPU
    """
    # NOTE: the return annotation must be torch.device (lowercase) — torch has
    # no top-level `Device` attribute, so `-> torch.Device` raised at import
    # time (annotations are evaluated when the function is defined).
    backend = getattr(torch.backends, device_name, None)
    checker = getattr(backend, "is_available", None)
    if callable(checker) and checker():
        return torch.device(device_name)
    # NOTE(review): torch.backends.cuda exposes is_built() rather than
    # is_available(); CUDA availability is reported by torch.cuda — presumably
    # that is the intended check for device_name == CUDA. TODO confirm.
    if device_name == CUDA and torch.cuda.is_available():
        return torch.device(CUDA)
    return torch.device(CPU)

datashare-python/tests/test_utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
from datashare_python.utils import positional_args_only
1+
from typing import AsyncGenerator, AsyncIterable, AsyncIterator
2+
3+
from aiostream.stream import chain
4+
5+
from datashare_python.utils import positional_args_only, before_and_after, once
26
from temporalio import activity
37

48

@@ -15,3 +19,29 @@ def test_keyword_safe_activity() -> None:
1519
raise AssertionError(
1620
"couldn't create activity from keyword only function "
1721
) from e
22+
23+
24+
async def _num_gen() -> AsyncGenerator[int, None]:
25+
for i in range(10):
26+
yield i // 3
27+
28+
29+
async def test_before_and_after() -> None:
    """before_and_after should support repeated splitting into runs of equal items."""

    # Given: an iterator-of-iterators grouping consecutive equal values
    async def group_by_iterator(
        items: AsyncIterable[int],
    ) -> AsyncIterator[AsyncIterator[int]]:
        while True:
            try:
                head = await anext(aiter(items))
            except StopAsyncIteration:
                return
            # Bind head as a lambda default to dodge the late-binding closure pitfall
            run, items = before_and_after(items, lambda x, expected=head: x == expected)
            yield chain(once(head), run)

    # When
    grouped = []
    async for run in group_by_iterator(_num_gen()):
        grouped.append([item async for item in run])

    # Then
    assert grouped == [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3]]

0 commit comments

Comments
 (0)