Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions sdk/batch/speechmatics/batch/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ async def submit_job(
config: Optional[JobConfig] = None,
transcription_config: Optional[TranscriptionConfig] = None,
parallel_engines: Optional[int] = None,
user_id: Optional[str] = None,
) -> JobDetails:
"""
Submit a new transcription job.
Expand All @@ -159,6 +160,9 @@ async def submit_job(
parallel_engines: Optional number of parallel engines to request for this job.
Sent as ``{"parallel_engines": N}`` in the ``X-SM-Processing-Data`` header.
This only applies when using the container onPrem on http batch mode.
user_id: Optional user identifier to associate with this job.
Sent as ``{"user_id": "..."}`` in the ``X-SM-Processing-Data`` header.
This only applies when using the container onPrem on http batch mode.

Returns:
JobDetails object containing the job ID and initial status.
Expand Down Expand Up @@ -205,7 +209,9 @@ async def submit_job(
assert audio_file is not None # for type checker; validated above
multipart_data, filename = await self._prepare_file_submission(audio_file, config_dict)

return await self._submit_and_create_job_details(multipart_data, filename, config, parallel_engines)
return await self._submit_and_create_job_details(
multipart_data, filename, config, parallel_engines, user_id
)
except Exception as e:
if isinstance(e, (AuthenticationError, BatchError)):
raise
Expand Down Expand Up @@ -441,6 +447,7 @@ async def transcribe(
polling_interval: float = 5.0,
timeout: Optional[float] = None,
parallel_engines: Optional[int] = None,
user_id: Optional[str] = None,
) -> Union[Transcript, str]:
"""
Complete transcription workflow: submit job and wait for completion.
Expand All @@ -457,6 +464,9 @@ async def transcribe(
parallel_engines: Optional number of parallel engines to request for this job.
Sent as ``{"parallel_engines": N}`` in the ``X-SM-Processing-Data`` header.
This only applies when using the container onPrem on http batch mode.
user_id: Optional user identifier to associate with this job.
Sent as ``{"user_id": "..."}`` in the ``X-SM-Processing-Data`` header.
This only applies when using the container onPrem on http batch mode.

Returns:
Transcript object containing the transcript and metadata.
Expand Down Expand Up @@ -485,6 +495,7 @@ async def transcribe(
config=config,
transcription_config=transcription_config,
parallel_engines=parallel_engines,
user_id=user_id,
)

# Wait for completion and return result
Expand Down Expand Up @@ -538,12 +549,22 @@ async def _prepare_file_submission(self, audio_file: Union[str, BinaryIO], confi
return multipart_data, filename

async def _submit_and_create_job_details(
self, multipart_data: dict, filename: str, config: JobConfig, parallel_engines: Optional[int] = None
self,
multipart_data: dict,
filename: str,
config: JobConfig,
parallel_engines: Optional[int] = None,
user_id: Optional[str] = None,
) -> JobDetails:
"""Submit job and create JobDetails response."""
extra_headers: Optional[dict[str, Any]] = None
processing_data: dict[str, Any] = {}
if parallel_engines is not None:
extra_headers = {PROCESSING_DATA_HEADER: {"parallel_engines": parallel_engines}}
processing_data["parallel_engines"] = parallel_engines
if user_id is not None:
processing_data["user_id"] = user_id
if processing_data:
extra_headers = {PROCESSING_DATA_HEADER: processing_data}
response = await self._transport.post("/jobs", multipart_data=multipart_data, extra_headers=extra_headers)
job_id = response.get("id")
if not job_id:
Expand Down
86 changes: 85 additions & 1 deletion tests/batch/test_submit_job.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Unit tests for AsyncClient.submit_job, focusing on the parallel_engines feature."""
"""Unit tests for AsyncClient.submit_job, focusing on the parallel engines and user_id features."""

import json
from io import BytesIO
Expand Down Expand Up @@ -127,6 +127,90 @@ async def test_header_sent_with_fetch_data_config(self):
assert payload == {"parallel_engines": 2}


class TestUserIdHeader:
    """Verify that the X-SM-Processing-Data header reflects the user_id argument."""

    @pytest.mark.asyncio
    async def test_header_sent_when_user_id_provided(self):
        # A user_id alone must produce the processing-data header.
        client = _make_client()
        fake_audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(fake_audio, user_id="user-abc")

        headers = _captured_extra_headers(mock_post)
        assert headers is not None
        assert PROCESSING_DATA_HEADER in headers
        assert headers[PROCESSING_DATA_HEADER] == {"user_id": "user-abc"}

    @pytest.mark.asyncio
    async def test_header_not_sent_when_user_id_is_none(self):
        # With neither user_id nor parallel_engines, no extra headers are sent.
        client = _make_client()
        fake_audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(fake_audio)

        assert _captured_extra_headers(mock_post) is None

    @pytest.mark.asyncio
    async def test_user_id_and_parallel_engines_sent_together(self):
        # Both options must be merged into a single header payload.
        client = _make_client()
        fake_audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(fake_audio, parallel_engines=4, user_id="user-xyz")

        headers = _captured_extra_headers(mock_post)
        assert headers is not None
        expected = {"parallel_engines": 4, "user_id": "user-xyz"}
        assert headers[PROCESSING_DATA_HEADER] == expected

    @pytest.mark.asyncio
    async def test_user_id_does_not_appear_when_only_parallel_engines_set(self):
        # An unset user_id must not leak into the payload.
        client = _make_client()
        fake_audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(fake_audio, parallel_engines=2)

        sent = _captured_extra_headers(mock_post)[PROCESSING_DATA_HEADER]
        assert "user_id" not in sent

    @pytest.mark.asyncio
    async def test_parallel_engines_does_not_appear_when_only_user_id_set(self):
        # An unset parallel_engines must not leak into the payload.
        client = _make_client()
        fake_audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(fake_audio, user_id="u1")

        sent = _captured_extra_headers(mock_post)[PROCESSING_DATA_HEADER]
        assert "parallel_engines" not in sent

    @pytest.mark.asyncio
    async def test_user_id_forwarded_from_transcribe(self):
        # transcribe() must pass user_id through to the underlying submit call.
        client = _make_client()
        fake_audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            with patch.object(client, "wait_for_completion", new_callable=AsyncMock) as mock_wait:
                mock_wait.return_value = MagicMock()
                await client.transcribe(fake_audio, user_id="transcribe-user")

        headers = _captured_extra_headers(mock_post)
        assert headers is not None
        assert headers[PROCESSING_DATA_HEADER]["user_id"] == "transcribe-user"


class TestSubmitJobReturnValue:
"""submit_job still returns the correct JobDetails regardless of parallel_engines."""

Expand Down
Loading