Skip to content

Commit 6ab4bf6

Browse files
committed
chore: Update TTS client
1 parent a543e92 commit 6ab4bf6

2 files changed

Lines changed: 243 additions & 41 deletions

File tree

riva/clients/tts/riva_tts_client.cc

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ namespace nr = nvidia::riva;
2929
namespace nr_tts = nvidia::riva::tts;
3030

3131
DEFINE_string(text, "", "Text to be synthesized");
32+
DEFINE_string(
33+
text_file, "", "Text file with list of sentences to be synthesized. Ignored if 'text' is set.");
3234
DEFINE_string(audio_file, "output.wav", "Output file");
3335
DEFINE_string(audio_encoding, "pcm", "Audio encoding (pcm or opus)");
3436
DEFINE_string(riva_uri, "localhost:50051", "Riva API server URI and port");
@@ -37,6 +39,7 @@ DEFINE_string(ssl_client_key, "", "Path to SSL client certificates key");
3739
DEFINE_string(ssl_client_cert, "", "Path to SSL client certificates file");
3840
DEFINE_int32(rate, 44100, "Sample rate for the TTS output");
3941
DEFINE_bool(online, false, "Whether synthesis should be online or batch");
42+
DEFINE_bool(streaming, false, "Whether synthesis should be streaming or batch");
4043
DEFINE_string(
4144
language, "en-US",
4245
"Language code as per [BCP-47](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) language tag.");
@@ -101,13 +104,15 @@ main(int argc, char** argv)
101104
std::stringstream str_usage;
102105
str_usage << "Usage: riva_tts_client " << std::endl;
103106
str_usage << " --text=<text> " << std::endl;
107+
str_usage << " --text_file=<filename> " << std::endl;
104108
str_usage << " --audio_file=<filename> " << std::endl;
105109
str_usage << " --audio_encoding=<pcm|opus> " << std::endl;
106110
str_usage << " --riva_uri=<server_name:port> " << std::endl;
107111
str_usage << " --rate=<sample_rate> " << std::endl;
108112
str_usage << " --language=<language-code> " << std::endl;
109113
str_usage << " --voice_name=<voice-name> " << std::endl;
110114
str_usage << " --online=<true|false> " << std::endl;
115+
str_usage << " --streaming=<true|false> " << std::endl;
111116
str_usage << " --ssl_root_cert=<filename>" << std::endl;
112117
str_usage << " --ssl_client_key=<filename>" << std::endl;
113118
str_usage << " --ssl_client_cert=<filename>" << std::endl;
@@ -134,10 +139,26 @@ main(int argc, char** argv)
134139
}
135140

136141
auto text = FLAGS_text;
137-
if (text.length() == 0) {
138-
LOG(ERROR) << "Input text cannot be empty." << std::endl;
142+
auto text_file = FLAGS_text_file;
143+
std::vector<std::string> text_lines;
144+
if (text.length() == 0 && text_file.length() == 0) {
145+
LOG(ERROR) << "Input text or text file cannot be empty." << std::endl;
146+
return -1;
147+
}
148+
if (text.length() > 0 && text_file.length() > 0) {
149+
LOG(ERROR) << "Only one of text or text file can be provided." << std::endl;
139150
return -1;
140151
}
152+
if (text_file.length() > 0) {
153+
std::ifstream infile(text_file);
154+
if (infile.is_open()) {
155+
std::string line;
156+
while (std::getline(infile, line)) {
157+
text_lines.push_back(line);
158+
text += line + " ";
159+
}
160+
}
161+
}
141162

142163
bool flag_set = gflags::GetCommandLineFlagInfoOrDie("riva_uri").is_default;
143164
const char* riva_uri = getenv("RIVA_URI");
@@ -152,7 +173,8 @@ main(int argc, char** argv)
152173
auto creds = riva::clients::CreateChannelCredentials(
153174
FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert,
154175
FLAGS_metadata);
155-
grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
176+
grpc_channel = riva::clients::CreateChannelBlocking(
177+
FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
156178
}
157179
catch (const std::exception& e) {
158180
std::cerr << "Error creating GRPC channel: " << e.what() << std::endl;
@@ -251,7 +273,7 @@ main(int argc, char** argv)
251273
decoder.DeserializeOpus(std::vector<unsigned char>(ptr, ptr + audio.size())));
252274
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm.data(), pcm.size());
253275
}
254-
} else { // online inference
276+
} else if (FLAGS_online && not FLAGS_streaming) { // online (non-streaming) inference
255277
if (not FLAGS_zero_shot_transcript.empty()) {
256278
LOG(ERROR) << "Zero shot transcript is not supported for streaming inference.";
257279
return -1;
@@ -261,8 +283,11 @@ main(int argc, char** argv)
261283
size_t audio_len = 0;
262284
nr_tts::SynthesizeSpeechResponse chunk;
263285
auto start = std::chrono::steady_clock::now();
264-
std::unique_ptr<grpc::ClientReader<nr_tts::SynthesizeSpeechResponse>> reader(
265-
tts->SynthesizeOnline(&context, request));
286+
std::unique_ptr<
287+
grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
288+
reader(tts->SynthesizeOnline(&context));
289+
reader->Write(request);
290+
reader->WritesDone();
266291
while (reader->Read(&chunk)) {
267292
// Copy chunk to local buffer
268293
if (audio_len == 0) {
@@ -295,6 +320,61 @@ main(int argc, char** argv)
295320
return -1;
296321
}
297322

323+
if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
324+
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size());
325+
} else if (FLAGS_audio_encoding == "opus") {
326+
riva::utils::opus::Decoder decoder(rate, 1);
327+
auto pcm = decoder.DecodePcm(decoder.DeserializeOpus(opus_buffer));
328+
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm.data(), pcm.size());
329+
}
330+
} else if (FLAGS_online && FLAGS_streaming) { // streaming inference
331+
332+
std::vector<int16_t> pcm_buffer;
333+
std::vector<unsigned char> opus_buffer;
334+
size_t audio_len = 0;
335+
nr_tts::SynthesizeSpeechResponse chunk;
336+
auto start = std::chrono::steady_clock::now();
337+
std::unique_ptr<
338+
grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
339+
reader(tts->SynthesizeOnline(&context));
340+
for (const auto& line : text_lines) {
341+
request.set_text(line);
342+
reader->Write(request);
343+
}
344+
reader->WritesDone();
345+
while (reader->Read(&chunk)) {
346+
// Copy chunk to local buffer
347+
if (audio_len == 0) {
348+
auto t_first_audio = std::chrono::steady_clock::now();
349+
std::chrono::duration<double> elapsed_first_audio = t_first_audio - start;
350+
LOG(INFO) << "Time to first chunk: " << elapsed_first_audio.count() << " s" << std::endl;
351+
}
352+
LOG(INFO) << "Got chunk: " << chunk.audio().size() << " bytes" << std::endl;
353+
if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
354+
int16_t* audio_data = (int16_t*)chunk.audio().data();
355+
size_t len = chunk.audio().length() / sizeof(int16_t);
356+
std::copy(audio_data, audio_data + len, std::back_inserter(pcm_buffer));
357+
audio_len += len;
358+
} else if (FLAGS_audio_encoding == "opus") {
359+
const unsigned char* opus_data = (unsigned char*)chunk.audio().data();
360+
size_t len = chunk.audio().length();
361+
std::copy(opus_data, opus_data + len, std::back_inserter(opus_buffer));
362+
audio_len += len;
363+
}
364+
}
365+
grpc::Status rpc_status = reader->Finish();
366+
auto end = std::chrono::steady_clock::now();
367+
std::chrono::duration<double> elapsed_total = end - start;
368+
LOG(INFO) << "Total streaming time: " << elapsed_total.count() << " s" << std::endl;
369+
370+
if (!rpc_status.ok()) {
371+
// Report the RPC failure.
372+
LOG(ERROR) << rpc_status.error_message() << std::endl;
373+
LOG(ERROR) << "Input was: " << text_lines.size() << " lines." << std::endl;
374+
LOG(ERROR) << "Input was: \'" << text << "\'" << std::endl;
375+
return -1;
376+
}
377+
298378
if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
299379
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size());
300380
} else if (FLAGS_audio_encoding == "opus") {

riva/clients/tts/riva_tts_perf_client.cc

Lines changed: 157 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ DEFINE_string(audio_encoding, "pcm", "Audio encoding (pcm or opus)");
3838
DEFINE_string(riva_uri, "localhost:50051", "Riva API server URI and port");
3939
DEFINE_int32(rate, 44100, "Sample rate for the TTS output");
4040
DEFINE_bool(online, false, "Whether synthesis should be online or batch");
41+
DEFINE_bool(streaming, false, "Whether synthesis should be streaming input");
4142
DEFINE_bool(
4243
write_output_audio, false,
4344
"Whether to dump output audio or not. When true, throughput and latency are not reported.");
@@ -47,6 +48,7 @@ DEFINE_string(
4748
DEFINE_string(voice_name, "", "Desired voice name");
4849
DEFINE_int32(num_iterations, 1, "Number of times to loop over audio files");
4950
DEFINE_int32(num_parallel_requests, 1, "Number of parallel requests to keep in flight");
51+
DEFINE_int32(num_sentences, 1, "Number of sentences to send");
5052
DEFINE_int32(throttle_milliseconds, 0, "Number of milliseconds to sleep for between TTS requests");
5153
DEFINE_int32(offset_milliseconds, 0, "Number of milliseconds to offset each parallel TTS requests");
5254
DEFINE_string(ssl_root_cert, "", "Path to SSL root certificates file");
@@ -260,8 +262,11 @@ synthesizeOnline(
260262
nr_tts::SynthesizeSpeechResponse chunk;
261263

262264
auto start = std::chrono::steady_clock::now();
263-
std::unique_ptr<grpc::ClientReader<nr_tts::SynthesizeSpeechResponse>> reader(
264-
tts->SynthesizeOnline(&context, request));
265+
std::unique_ptr<
266+
grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
267+
reader(tts->SynthesizeOnline(&context));
268+
reader->Write(request);
269+
reader->WritesDone();
265270
DLOG(INFO) << "Sending request for input \"" << text << "\".";
266271

267272
std::vector<int16_t> buffer;
@@ -344,9 +349,11 @@ main(int argc, char** argv)
344349
str_usage << " --language=<language-code> " << std::endl;
345350
str_usage << " --voice_name=<voice-name> " << std::endl;
346351
str_usage << " --online=<true|false> " << std::endl;
352+
str_usage << " --streaming=<true|false> " << std::endl;
347353
str_usage << " --audio_encoding=<pcm|opus> " << std::endl;
348354
str_usage << " --num_parallel_requests=<num-parallel-reqs> " << std::endl;
349355
str_usage << " --num_iterations=<num-iterations> " << std::endl;
356+
str_usage << " --num_sentences=<num-sentences> " << std::endl;
350357
str_usage << " --throttle_milliseconds=<throttle-milliseconds> " << std::endl;
351358
str_usage << " --offset_milliseconds=<offset-milliseconds> " << std::endl;
352359
str_usage << " --ssl_root_cert=<filename>" << std::endl;
@@ -404,7 +411,7 @@ main(int argc, char** argv)
404411

405412
// open text file, load sentences as a vector
406413
int count = 0;
407-
for (int i = 0; i < FLAGS_num_iterations; i++) {
414+
for (int i = 0; i < FLAGS_num_iterations * FLAGS_num_sentences; i++) {
408415
std::ifstream file(text_file);
409416
while (std::getline(file, sentence)) {
410417
if (sentence.find("|") != std::string::npos) {
@@ -458,38 +465,154 @@ main(int argc, char** argv)
458465
usleep(i * FLAGS_offset_milliseconds * 1000);
459466
auto start_time = std::chrono::steady_clock::now();
460467

461-
for (size_t s = 0; s < sentences[i].size(); s++) {
462-
auto current_time = std::chrono::steady_clock::now();
463-
double diff_time =
464-
std::chrono::duration<double, std::milli>(current_time - start_time).count();
465-
double wait_time = (s + 1) * FLAGS_throttle_milliseconds - diff_time;
466-
467-
// To nanoseconds
468-
wait_time *= 1.e3;
469-
wait_time = std::max(wait_time, 0.);
470-
471-
// Round to nearest integer
472-
wait_time = wait_time + 0.5 - (wait_time < 0);
473-
int64_t usecs = (int64_t)wait_time;
474-
// Sleep
475-
if (usecs > 0) {
476-
usleep(usecs);
477-
}
478-
468+
if (FLAGS_streaming) {
469+
// Streaming mode: send all sentences in one stream
479470
auto tts = CreateTTS(grpc_channel);
480-
double time_to_first_chunk = 0.;
481-
auto time_to_next_chunk = new std::vector<double>();
482-
size_t num_samples = 0;
483-
synthesizeOnline(
484-
std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name,
485-
&time_to_first_chunk, time_to_next_chunk, &num_samples,
486-
std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt,
487-
FLAGS_zero_shot_quality);
488-
latencies_first_chunk[i]->push_back(time_to_first_chunk);
489-
latencies_next_chunks[i]->insert(
490-
latencies_next_chunks[i]->end(), time_to_next_chunk->begin(),
491-
time_to_next_chunk->end());
492-
lengths[i]->push_back(num_samples);
471+
472+
nr_tts::SynthesizeSpeechRequest request;
473+
request.set_language_code(FLAGS_language);
474+
request.set_sample_rate_hz(rate);
475+
request.set_voice_name(FLAGS_voice_name);
476+
477+
auto ae = nr::AudioEncoding::ENCODING_UNSPECIFIED;
478+
if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
479+
ae = nr::LINEAR_PCM;
480+
} else if (FLAGS_audio_encoding == "opus") {
481+
ae = nr::OGGOPUS;
482+
} else {
483+
std::cerr << "Unsupported encoding: \'" << FLAGS_audio_encoding << "\'" << std::endl;
484+
return;
485+
}
486+
request.set_encoding(ae);
487+
488+
if (not FLAGS_zero_shot_audio_prompt.empty()) {
489+
auto zero_shot_data = request.mutable_zero_shot_data();
490+
std::vector<std::shared_ptr<WaveData>> audio_prompt;
491+
LoadWavData(audio_prompt, FLAGS_zero_shot_audio_prompt);
492+
if (audio_prompt.size() != 1) {
493+
LOG(ERROR) << "Unsupported number of audio prompts. Need exactly 1 audio prompt."
494+
<< std::endl;
495+
return;
496+
}
497+
498+
if (audio_prompt[0]->encoding != nr::LINEAR_PCM && audio_prompt[0]->encoding != nr::OGGOPUS) {
499+
LOG(ERROR) << "Unsupported encoding for zero shot prompt: \'" << audio_prompt[0]->encoding
500+
<< "\'";
501+
std::cerr << "Unsupported encoding for zero shot prompt: \'" << audio_prompt[0]->encoding
502+
<< "\'" << std::endl;
503+
return;
504+
}
505+
zero_shot_data->set_audio_prompt(&audio_prompt[0]->data[0], audio_prompt[0]->data.size());
506+
int32_t zero_shot_sample_rate = audio_prompt[0]->sample_rate;
507+
zero_shot_data->set_encoding(audio_prompt[0]->encoding);
508+
if (audio_prompt[0]->encoding == nr::OGGOPUS) {
509+
zero_shot_sample_rate =
510+
riva::utils::opus::Decoder::AdjustRateIfUnsupported(zero_shot_sample_rate);
511+
}
512+
zero_shot_data->set_sample_rate_hz(zero_shot_sample_rate);
513+
zero_shot_data->set_quality(FLAGS_zero_shot_quality);
514+
}
515+
516+
grpc::ClientContext context;
517+
nr_tts::SynthesizeSpeechResponse chunk;
518+
auto stream_start = std::chrono::steady_clock::now();
519+
520+
std::unique_ptr<
521+
grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
522+
reader(tts->SynthesizeOnline(&context));
523+
524+
// Write all sentences to the stream
525+
for (size_t s = 0; s < sentences[i].size(); s++) {
526+
request.set_text(sentences[i][s].second);
527+
reader->Write(request);
528+
}
529+
reader->WritesDone();
530+
531+
std::vector<int16_t> buffer;
532+
size_t audio_len = 0;
533+
riva::utils::opus::Decoder opus_decoder(rate, 1);
534+
535+
// Read all responses
536+
while (reader->Read(&chunk)) {
537+
size_t len = 0U;
538+
if (ae == nr::OGGOPUS) {
539+
const unsigned char* opus_data = (unsigned char*)chunk.audio().data();
540+
len = chunk.audio().length();
541+
auto pcm = opus_decoder.DecodePcm(
542+
opus_decoder.DeserializeOpus(std::vector<unsigned char>(opus_data, opus_data + len)));
543+
len = pcm.size();
544+
std::copy(pcm.cbegin(), pcm.cend(), std::back_inserter(buffer));
545+
} else {
546+
int16_t* audio_data;
547+
audio_data = (int16_t*)chunk.audio().data();
548+
len = chunk.audio().length() / sizeof(int16_t);
549+
std::copy(audio_data, audio_data + len, std::back_inserter(buffer));
550+
}
551+
552+
if (audio_len == 0) {
553+
auto t_next_audio = std::chrono::steady_clock::now();
554+
std::chrono::duration<double> elapsed_first_audio = t_next_audio - stream_start;
555+
latencies_first_chunk[i]->push_back(elapsed_first_audio.count());
556+
stream_start = t_next_audio;
557+
} else {
558+
auto t_next_audio = std::chrono::steady_clock::now();
559+
std::chrono::duration<double> elapsed_next_audio = t_next_audio - stream_start;
560+
time_to_next_chunks->push_back(elapsed_next_audio.count());
561+
stream_start = t_next_audio;
562+
}
563+
audio_len += len;
564+
}
565+
566+
grpc::Status rpc_status = reader->Finish();
567+
568+
if (!rpc_status.ok()) {
569+
std::cerr << rpc_status.error_message() << std::endl;
570+
std::cerr << "Streaming input failed for worker " << i << std::endl;
571+
} else {
572+
lengths[i]->push_back(audio_len);
573+
latencies_next_chunks[i]->insert(
574+
latencies_next_chunks[i]->end(), time_to_next_chunks->begin(),
575+
time_to_next_chunks->end());
576+
if (FLAGS_write_output_audio) {
577+
::riva::utils::wav::Write(
578+
"worker_" + std::to_string(i) + ".wav", rate, buffer.data(), buffer.size());
579+
}
580+
}
581+
} else {
582+
// Non-streaming mode: send one sentence per stream
583+
for (size_t s = 0; s < sentences[i].size(); s++) {
584+
auto current_time = std::chrono::steady_clock::now();
585+
double diff_time =
586+
std::chrono::duration<double, std::milli>(current_time - start_time).count();
587+
double wait_time = (s + 1) * FLAGS_throttle_milliseconds - diff_time;
588+
589+
// To nanoseconds
590+
wait_time *= 1.e3;
591+
wait_time = std::max(wait_time, 0.);
592+
593+
// Round to nearest integer
594+
wait_time = wait_time + 0.5 - (wait_time < 0);
595+
int64_t usecs = (int64_t)wait_time;
596+
// Sleep
597+
if (usecs > 0) {
598+
usleep(usecs);
599+
}
600+
601+
auto tts = CreateTTS(grpc_channel);
602+
double time_to_first_chunk = 0.;
603+
auto time_to_next_chunk = new std::vector<double>();
604+
size_t num_samples = 0;
605+
synthesizeOnline(
606+
std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name,
607+
&time_to_first_chunk, time_to_next_chunk, &num_samples,
608+
std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt,
609+
FLAGS_zero_shot_quality);
610+
latencies_first_chunk[i]->push_back(time_to_first_chunk);
611+
latencies_next_chunks[i]->insert(
612+
latencies_next_chunks[i]->end(), time_to_next_chunk->begin(),
613+
time_to_next_chunk->end());
614+
lengths[i]->push_back(num_samples);
615+
}
493616
}
494617
}));
495618
}
@@ -588,4 +711,3 @@ main(int argc, char** argv)
588711
}
589712
return STATUS;
590713
}
591-

0 commit comments

Comments
 (0)