Skip to content

Commit f6d93c4

Browse files
author
Kyle Bridburg
authored
AISDK-229: Add skip_postprocessing option (#97)
1 parent 26c5bdc commit f6d93c4

File tree

6 files changed

+60
-31
lines changed

6 files changed

+60
-31
lines changed

examples/async_example.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646
# delete_after_seconds=None,
4747
# language=None,
4848
# custom_vocabulary_id=None,
49-
# notification_config=None)
49+
# notification_config=None,
50+
# skip_postprocessing=False)
5051

5152

5253
# Submitting a job with a link to the file you want transcribed
@@ -63,7 +64,8 @@
6364
language=None,
6465
custom_vocabulary_id=None,
6566
source_config=None,
66-
notification_config=None)
67+
notification_config=None,
68+
skip_postprocessing=False)
6769

6870
print("Submitted Job")
6971

src/rev_ai/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22
"""Top-level package for rev_ai"""
33

4-
__version__ = '2.17.0'
4+
__version__ = '2.17.1'
55

66
from .models import Job, JobStatus, Account, Transcript, Monologue, Element, MediaConfig, \
77
CaptionType, CustomVocabulary, TopicExtractionJob, TopicExtractionResult, Topic, Informant, \

src/rev_ai/apiclient.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def submit_job_url(
6363
segments_to_transcribe=None,
6464
speaker_names=None,
6565
source_config=None,
66-
notification_config=None):
66+
notification_config=None,
67+
skip_postprocessing=False):
6768
"""Submit media given a URL for transcription.
6869
The audio data is downloaded from the URL
6970
:param media_url: web location of the media file
@@ -109,6 +110,7 @@ def submit_job_url(
109110
:param notification_config: CustomerUrlData object containing the callback url to
110111
invoke on job completion as a webhook and optional authentication headers to use when
111112
calling the callback url
113+
:param skip_postprocessing: skip all text postprocessing: punctuation, capitalization, and inverse text normalization (ITN)
112114
:returns: raw response data
113115
:raises: HTTPError
114116
"""
@@ -120,7 +122,8 @@ def submit_job_url(
120122
language, custom_vocabulary_id, transcriber,
121123
verbatim, rush, test_mode,
122124
segments_to_transcribe, speaker_names,
123-
source_config, notification_config)
125+
source_config, notification_config,
126+
skip_postprocessing)
124127

125128
response = self._make_http_request(
126129
"POST",
@@ -150,7 +153,8 @@ def submit_job_local_file(
150153
test_mode=None,
151154
segments_to_transcribe=None,
152155
speaker_names=None,
153-
notification_config=None):
156+
notification_config=None,
157+
skip_postprocessing=False):
154158
"""Submit a local file for transcription.
155159
Note that the content type is inferred if not provided.
156160
@@ -193,6 +197,7 @@ def submit_job_local_file(
193197
:param notification_config: CustomerUrlData object containing the callback url to
194198
invoke on job completion as a webhook and optional authentication headers to use when
195199
calling the callback url
200+
:param skip_postprocessing: skip all text postprocessing: punctuation, capitalization, and inverse text normalization (ITN)
196201
:returns: raw response data
197202
:raises: HTTPError, ValueError
198203
"""
@@ -207,7 +212,7 @@ def submit_job_local_file(
207212
language, custom_vocabulary_id, transcriber,
208213
verbatim, rush, test_mode,
209214
segments_to_transcribe, speaker_names, None,
210-
notification_config)
215+
notification_config, skip_postprocessing)
211216

212217
with open(filename, 'rb') as f:
213218
files = {
@@ -457,7 +462,8 @@ def _create_job_options_payload(
457462
segments_to_transcribe=None,
458463
speaker_names=None,
459464
source_config=None,
460-
notification_config=None):
465+
notification_config=None,
466+
skip_postprocessing=False):
461467
payload = {}
462468
if media_url:
463469
payload['media_url'] = media_url
@@ -500,6 +506,8 @@ def _create_job_options_payload(
500506
payload['source_config'] = source_config.to_dict()
501507
if notification_config:
502508
payload['notification_config'] = notification_config.to_dict()
509+
if skip_postprocessing:
510+
payload['skip_postprocessing'] = skip_postprocessing
503511
return payload
504512

505513
def _create_captions_query(self, speaker_channel):

src/rev_ai/streamingclient.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def on_connected(job_id):
2525
print('Connected, Job ID : {}'.format(job_id))
2626

2727

28-
class RevAiStreamingClient():
28+
class RevAiStreamingClient:
2929
def __init__(self,
3030
access_token,
3131
config,
@@ -40,7 +40,7 @@ def __init__(self,
4040
:param config: a MediaConfig object containing audio information.
4141
See MediaConfig.py for more information
4242
:param version (optional): version of the streaming api to be used
43-
:param on_error (optional): function to be called when recieving an
43+
:param on_error (optional): function to be called when receiving an
4444
error from the server
4545
:param on_close (optional): function to be called when the websocket
4646
closes
@@ -72,7 +72,8 @@ def start(self,
7272
detailed_partials=None,
7373
start_ts=None,
7474
transcriber=None,
75-
language=None):
75+
language=None,
76+
skip_postprocessing=None):
7677
"""Function to connect the websocket to the URL and start the response
7778
thread
7879
:param generator: generator object that yields binary audio data
@@ -85,6 +86,7 @@ def start(self,
8586
:param start_ts: number of seconds to offset all hypotheses timings
8687
:param transcriber: type of transcriber to use to transcribe the media file
8788
:param language: language to use for the streaming job
89+
:param skip_postprocessing: skip all text postprocessing on final hypotheses
8890
"""
8991
url = self.base_url + '?' + urlencode({
9092
'access_token': self.access_token,
@@ -119,6 +121,9 @@ def start(self,
119121
if language:
120122
url += '&' + urlencode({'language': language})
121123

124+
if skip_postprocessing:
125+
url += '&' + urlencode({'skip_postprocessing': 'true'})
126+
122127
try:
123128
self.client.connect(url)
124129
except Exception as e:
@@ -153,7 +158,7 @@ def _start_send_data_thread(self, generator):
153158

154159
def _send_data(self, generator):
155160
"""Function used in a thread to send requests to the server.
156-
:param generator: enerator object that yields binary audio data
161+
:param generator: generator object that yields binary audio data
157162
"""
158163
if not generator:
159164
raise ValueError('generator must be provided')
@@ -164,7 +169,7 @@ def _send_data(self, generator):
164169
self.client.send("EOS")
165170

166171
def _get_response_generator(self):
167-
"""A generator of reponses from the server. Yields the data decoded.
172+
"""A generator of responses from the server. Yields the data decoded.
168173
"""
169174
while True:
170175
with self.client.readlock:

tests/test_job.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ def test_submit_job_url_with_success(self, mock_session, make_mock_response):
119119
'remove_disfluencies': True,
120120
'delete_after_seconds': 0,
121121
'language': LANGUAGE,
122-
'transcriber': TRANSCRIBER
122+
'transcriber': TRANSCRIBER,
123+
'skip_postprocessing': True
123124
}
124125
response = make_mock_response(url=JOB_ID_URL, json_data=data)
125126
mock_session.request.return_value = response
@@ -129,7 +130,7 @@ def test_submit_job_url_with_success(self, mock_session, make_mock_response):
129130
NOTIFICATION_URL, True,
130131
True, 1, CUSTOM_VOCAB, True,
131132
True, 0, LANGUAGE, CUSTOM_VOCAB_ID,
132-
TRANSCRIBER)
133+
TRANSCRIBER, skip_postprocessing=True)
133134

134135
assert res == Job(JOB_ID,
135136
CREATED_ON,
@@ -159,7 +160,8 @@ def test_submit_job_url_with_success(self, mock_session, make_mock_response):
159160
'delete_after_seconds': 0,
160161
'language': LANGUAGE,
161162
'custom_vocabulary_id': CUSTOM_VOCAB_ID,
162-
'transcriber': TRANSCRIBER
163+
'transcriber': TRANSCRIBER,
164+
'skip_postprocessing': True
163165
},
164166
headers=client.default_headers)
165167

@@ -176,7 +178,8 @@ def test_submit_job_url_with_auth_options(self, mock_session, make_mock_response
176178
'remove_disfluencies': True,
177179
'delete_after_seconds': 0,
178180
'language': LANGUAGE,
179-
'transcriber': TRANSCRIBER
181+
'transcriber': TRANSCRIBER,
182+
'skip_postprocessing': True
180183
}
181184
response = make_mock_response(url=JOB_ID_URL, json_data=data)
182185
mock_session.request.return_value = response
@@ -188,7 +191,8 @@ def test_submit_job_url_with_auth_options(self, mock_session, make_mock_response
188191
delete_after_seconds=0, language=LANGUAGE,
189192
custom_vocabulary_id=CUSTOM_VOCAB_ID, transcriber=TRANSCRIBER,
190193
source_config=SOURCE_CONFIG,
191-
notification_config=NOTIFICATION_CONFIG)
194+
notification_config=NOTIFICATION_CONFIG,
195+
skip_postprocessing=True)
192196

193197
assert res == Job(JOB_ID,
194198
CREATED_ON,
@@ -218,7 +222,8 @@ def test_submit_job_url_with_auth_options(self, mock_session, make_mock_response
218222
'delete_after_seconds': 0,
219223
'language': LANGUAGE,
220224
'custom_vocabulary_id': CUSTOM_VOCAB_ID,
221-
'transcriber': TRANSCRIBER
225+
'transcriber': TRANSCRIBER,
226+
'skip_postprocessing': True
222227
},
223228
headers=client.default_headers)
224229

@@ -277,7 +282,8 @@ def test_submit_job_local_file_with_success(self, mocker, mock_session, make_moc
277282
'remove_disfluencies': True,
278283
'delete_after_seconds': 0,
279284
'language': LANGUAGE,
280-
'transcriber': TRANSCRIBER
285+
'transcriber': TRANSCRIBER,
286+
'skip_postprocessing': True
281287
}
282288
response = make_mock_response(url=JOB_ID_URL, json_data=data)
283289
mock_session.request.return_value = response
@@ -288,7 +294,7 @@ def test_submit_job_local_file_with_success(self, mocker, mock_session, make_moc
288294
NOTIFICATION_URL, True,
289295
True, 1, CUSTOM_VOCAB, True,
290296
True, 0, LANGUAGE, CUSTOM_VOCAB_ID,
291-
TRANSCRIBER)
297+
TRANSCRIBER, skip_postprocessing=True)
292298

293299
assert res == Job(JOB_ID,
294300
CREATED_ON,
@@ -322,7 +328,8 @@ def test_submit_job_local_file_with_success(self, mocker, mock_session, make_moc
322328
'delete_after_seconds': 0,
323329
'language': LANGUAGE,
324330
'custom_vocabulary_id': CUSTOM_VOCAB_ID,
325-
'transcriber': TRANSCRIBER
331+
'transcriber': TRANSCRIBER,
332+
'skip_postprocessing': True
326333
}, sort_keys=True)
327334
)
328335
},
@@ -343,7 +350,8 @@ def test_submit_job_local_file_auth_options_with_success(self, mocker, mock_sess
343350
'remove_disfluencies': True,
344351
'delete_after_seconds': 0,
345352
'language': LANGUAGE,
346-
'transcriber': TRANSCRIBER
353+
'transcriber': TRANSCRIBER,
354+
'skip_postprocessing': True
347355
}
348356
response = make_mock_response(url=JOB_ID_URL, json_data=data)
349357
mock_session.request.return_value = response
@@ -358,7 +366,8 @@ def test_submit_job_local_file_auth_options_with_success(self, mocker, mock_sess
358366
delete_after_seconds=0, language=LANGUAGE,
359367
custom_vocabulary_id=CUSTOM_VOCAB_ID,
360368
transcriber=TRANSCRIBER,
361-
notification_config=NOTIFICATION_CONFIG)
369+
notification_config=NOTIFICATION_CONFIG,
370+
skip_postprocessing=True)
362371

363372
assert res == Job(JOB_ID,
364373
CREATED_ON,
@@ -392,7 +401,8 @@ def test_submit_job_local_file_auth_options_with_success(self, mocker, mock_sess
392401
'delete_after_seconds': 0,
393402
'language': LANGUAGE,
394403
'custom_vocabulary_id': CUSTOM_VOCAB_ID,
395-
'transcriber': TRANSCRIBER
404+
'transcriber': TRANSCRIBER,
405+
'skip_postprocessing': True
396406
}, sort_keys=True)
397407
)
398408
},

tests/test_streamingclient.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_constructor_no_token_no_config(self):
5555

5656
def test_start_noparams_success(self, mock_streaming_client, mock_generator, capsys):
5757
expected_query_dict = build_expected_query_dict(mock_streaming_client, None, None, None, None, None, None, None,
58-
None, None)
58+
None, None, None)
5959

6060
example_data = '{"type":"partial","transcript":"Test"}'
6161
example_connected = '{"type":"connected","id":"testid"}'
@@ -93,9 +93,10 @@ def test_start_noparams_success(self, mock_streaming_client, mock_generator, cap
9393
@pytest.mark.parametrize("start_ts", [10])
9494
@pytest.mark.parametrize("transcriber", ["machine"])
9595
@pytest.mark.parametrize("language", ["en"])
96+
@pytest.mark.parametrize("skip_postprocessing", [True])
9697
def test_start_allparams_success(self, mock_streaming_client, mock_generator, capsys,
97-
metadata, custom_vocabulary_id, filter_profanity, remove_disfluencies, delete_after_seconds, detailed_partials,
98-
start_ts, transcriber, language):
98+
metadata, custom_vocabulary_id, filter_profanity, remove_disfluencies, delete_after_seconds,
99+
detailed_partials, start_ts, transcriber, language, skip_postprocessing):
99100

100101
expected_query_dict = build_expected_query_dict(
101102
mock_streaming_client,
@@ -107,7 +108,8 @@ def test_start_allparams_success(self, mock_streaming_client, mock_generator, ca
107108
detailed_partials,
108109
start_ts,
109110
transcriber,
110-
language
111+
language,
112+
skip_postprocessing
111113
)
112114
example_data = '{"type":"partial","transcript":"Test"}'
113115
example_connected = '{"type":"connected","id":"testid"}'
@@ -124,7 +126,7 @@ def test_start_allparams_success(self, mock_streaming_client, mock_generator, ca
124126

125127
response_gen = mock_streaming_client.start(mock_generator(),
126128
metadata, custom_vocabulary_id, filter_profanity, remove_disfluencies, delete_after_seconds,
127-
detailed_partials, start_ts, transcriber, language)
129+
detailed_partials, start_ts, transcriber, language, skip_postprocessing)
128130

129131
called_url = mock_streaming_client.client.connect.call_args_list[0][0][0]
130132
validate_query_parameters(called_url, expected_query_dict)
@@ -152,7 +154,7 @@ def test_end(self, mock_streaming_client):
152154

153155
def build_expected_query_dict(mock_streaming_client,
154156
metadata, custom_vocabulary_id, filter_profanity, remove_disfluencies, delete_after_seconds, detailed_partials,
155-
start_ts, transcriber, language):
157+
start_ts, transcriber, language, skip_postprocessing):
156158
expected_query_dict = {
157159
'access_token': mock_streaming_client.access_token,
158160
'content_type': mock_streaming_client.config.get_content_type_string(),
@@ -177,6 +179,8 @@ def build_expected_query_dict(mock_streaming_client,
177179
expected_query_dict["transcriber"] = transcriber
178180
if language:
179181
expected_query_dict["language"] = language
182+
if skip_postprocessing:
183+
expected_query_dict["skip_postprocessing"] = "true"
180184

181185
return expected_query_dict
182186

0 commit comments

Comments
 (0)