@@ -38,6 +38,7 @@ DEFINE_string(audio_encoding, "pcm", "Audio encoding (pcm or opus)");
 DEFINE_string(riva_uri, "localhost:50051", "Riva API server URI and port");
 DEFINE_int32(rate, 44100, "Sample rate for the TTS output");
 DEFINE_bool(online, false, "Whether synthesis should be online or batch");
+DEFINE_bool(streaming, false, "Whether synthesis should be streaming input");
 DEFINE_bool(
     write_output_audio, false,
     "Whether to dump output audio or not. When true, throughput and latency are not reported.");
@@ -47,6 +48,7 @@ DEFINE_string(
 DEFINE_string(voice_name, "", "Desired voice name");
 DEFINE_int32(num_iterations, 1, "Number of times to loop over audio files");
 DEFINE_int32(num_parallel_requests, 1, "Number of parallel requests to keep in flight");
+DEFINE_int32(num_sentences, 1, "Number of sentences to send");
 DEFINE_int32(throttle_milliseconds, 0, "Number of milliseconds to sleep for between TTS requests");
 DEFINE_int32(offset_milliseconds, 0, "Number of milliseconds to offset each parallel TTS requests");
 DEFINE_string(ssl_root_cert, "", "Path to SSL root certificates file");
@@ -260,8 +262,11 @@ synthesizeOnline(
   nr_tts::SynthesizeSpeechResponse chunk;

   auto start = std::chrono::steady_clock::now();
-  std::unique_ptr<grpc::ClientReader<nr_tts::SynthesizeSpeechResponse>> reader(
-      tts->SynthesizeOnline(&context, request));
+  std::unique_ptr<
+      grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
+      reader(tts->SynthesizeOnline(&context));
+  reader->Write(request);
+  reader->WritesDone();
   DLOG(INFO) << "Sending request for input \"" << text << "\".";

   std::vector<int16_t> buffer;
@@ -344,9 +349,11 @@ main(int argc, char** argv)
   str_usage << "  --language=<language-code> " << std::endl;
   str_usage << "  --voice_name=<voice-name> " << std::endl;
   str_usage << "  --online=<true|false> " << std::endl;
+  str_usage << "  --streaming=<true|false> " << std::endl;
   str_usage << "  --audio_encoding=<pcm|opus> " << std::endl;
   str_usage << "  --num_parallel_requests=<num-parallel-reqs> " << std::endl;
   str_usage << "  --num_iterations=<num-iterations> " << std::endl;
+  str_usage << "  --num_sentences=<num-sentences> " << std::endl;
   str_usage << "  --throttle_milliseconds=<throttle-milliseconds> " << std::endl;
   str_usage << "  --offset_milliseconds=<offset-milliseconds> " << std::endl;
   str_usage << "  --ssl_root_cert=<filename>" << std::endl;
@@ -404,7 +411,7 @@ main(int argc, char** argv)

   // open text file, load sentences as a vector
   int count = 0;
-  for (int i = 0; i < FLAGS_num_iterations; i++) {
+  for (int i = 0; i < FLAGS_num_iterations * FLAGS_num_sentences; i++) {
     std::ifstream file(text_file);
     while (std::getline(file, sentence)) {
       if (sentence.find("|") != std::string::npos) {
@@ -458,38 +465,154 @@ main(int argc, char** argv)
       usleep(i * FLAGS_offset_milliseconds * 1000);
       auto start_time = std::chrono::steady_clock::now();

-      for (size_t s = 0; s < sentences[i].size(); s++) {
-        auto current_time = std::chrono::steady_clock::now();
-        double diff_time =
-            std::chrono::duration<double, std::milli>(current_time - start_time).count();
-        double wait_time = (s + 1) * FLAGS_throttle_milliseconds - diff_time;
-
-        // To nanoseconds
-        wait_time *= 1.e3;
-        wait_time = std::max(wait_time, 0.);
-
-        // Round to nearest integer
-        wait_time = wait_time + 0.5 - (wait_time < 0);
-        int64_t usecs = (int64_t)wait_time;
-        // Sleep
-        if (usecs > 0) {
-          usleep(usecs);
-        }
-
+      if (FLAGS_streaming) {
+        // Streaming mode: send all sentences in one stream
         auto tts = CreateTTS(grpc_channel);
-        double time_to_first_chunk = 0.;
-        auto time_to_next_chunk = new std::vector<double>();
-        size_t num_samples = 0;
-        synthesizeOnline(
-            std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name,
-            &time_to_first_chunk, time_to_next_chunk, &num_samples,
-            std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt,
-            FLAGS_zero_shot_quality);
-        latencies_first_chunk[i]->push_back(time_to_first_chunk);
-        latencies_next_chunks[i]->insert(
-            latencies_next_chunks[i]->end(), time_to_next_chunk->begin(),
-            time_to_next_chunk->end());
-        lengths[i]->push_back(num_samples);
+
+        nr_tts::SynthesizeSpeechRequest request;
+        request.set_language_code(FLAGS_language);
+        request.set_sample_rate_hz(rate);
+        request.set_voice_name(FLAGS_voice_name);
+
+        auto ae = nr::AudioEncoding::ENCODING_UNSPECIFIED;
+        if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
+          ae = nr::LINEAR_PCM;
+        } else if (FLAGS_audio_encoding == "opus") {
+          ae = nr::OGGOPUS;
+        } else {
+          std::cerr << "Unsupported encoding: \'" << FLAGS_audio_encoding << "\'" << std::endl;
+          return;
+        }
+        request.set_encoding(ae);
+
+        if (not FLAGS_zero_shot_audio_prompt.empty()) {
+          auto zero_shot_data = request.mutable_zero_shot_data();
+          std::vector<std::shared_ptr<WaveData>> audio_prompt;
+          LoadWavData(audio_prompt, FLAGS_zero_shot_audio_prompt);
+          if (audio_prompt.size() != 1) {
+            LOG(ERROR) << "Unsupported number of audio prompts. Need exactly 1 audio prompt."
+                       << std::endl;
+            return;
+          }
+
+          if (audio_prompt[0]->encoding != nr::LINEAR_PCM && audio_prompt[0]->encoding != nr::OGGOPUS) {
+            LOG(ERROR) << "Unsupported encoding for zero shot prompt: \'" << audio_prompt[0]->encoding
+                       << "\'";
+            std::cerr << "Unsupported encoding for zero shot prompt: \'" << audio_prompt[0]->encoding
+                      << "\'" << std::endl;
+            return;
+          }
+          zero_shot_data->set_audio_prompt(&audio_prompt[0]->data[0], audio_prompt[0]->data.size());
+          int32_t zero_shot_sample_rate = audio_prompt[0]->sample_rate;
+          zero_shot_data->set_encoding(audio_prompt[0]->encoding);
+          if (audio_prompt[0]->encoding == nr::OGGOPUS) {
+            zero_shot_sample_rate =
+                riva::utils::opus::Decoder::AdjustRateIfUnsupported(zero_shot_sample_rate);
+          }
+          zero_shot_data->set_sample_rate_hz(zero_shot_sample_rate);
+          zero_shot_data->set_quality(FLAGS_zero_shot_quality);
+        }
+
+        grpc::ClientContext context;
+        nr_tts::SynthesizeSpeechResponse chunk;
+        auto stream_start = std::chrono::steady_clock::now();
+
+        std::unique_ptr<
+            grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
+            reader(tts->SynthesizeOnline(&context));
+
+        // Write all sentences to the stream
+        for (size_t s = 0; s < sentences[i].size(); s++) {
+          request.set_text(sentences[i][s].second);
+          reader->Write(request);
+        }
+        reader->WritesDone();
+
+        std::vector<int16_t> buffer;
+        size_t audio_len = 0;
+        riva::utils::opus::Decoder opus_decoder(rate, 1);
+
+        // Read all responses
+        while (reader->Read(&chunk)) {
+          size_t len = 0U;
+          if (ae == nr::OGGOPUS) {
+            const unsigned char* opus_data = (unsigned char*)chunk.audio().data();
+            len = chunk.audio().length();
+            auto pcm = opus_decoder.DecodePcm(
+                opus_decoder.DeserializeOpus(std::vector<unsigned char>(opus_data, opus_data + len)));
+            len = pcm.size();
+            std::copy(pcm.cbegin(), pcm.cend(), std::back_inserter(buffer));
+          } else {
+            int16_t* audio_data;
+            audio_data = (int16_t*)chunk.audio().data();
+            len = chunk.audio().length() / sizeof(int16_t);
+            std::copy(audio_data, audio_data + len, std::back_inserter(buffer));
+          }
+
+          if (audio_len == 0) {
+            auto t_next_audio = std::chrono::steady_clock::now();
+            std::chrono::duration<double> elapsed_first_audio = t_next_audio - stream_start;
+            latencies_first_chunk[i]->push_back(elapsed_first_audio.count());
+            stream_start = t_next_audio;
+          } else {
+            auto t_next_audio = std::chrono::steady_clock::now();
+            std::chrono::duration<double> elapsed_next_audio = t_next_audio - stream_start;
+            time_to_next_chunks->push_back(elapsed_next_audio.count());
+            stream_start = t_next_audio;
+          }
+          audio_len += len;
+        }
+
+        grpc::Status rpc_status = reader->Finish();
+
+        if (!rpc_status.ok()) {
+          std::cerr << rpc_status.error_message() << std::endl;
+          std::cerr << "Streaming input failed for worker " << i << std::endl;
+        } else {
+          lengths[i]->push_back(audio_len);
+          latencies_next_chunks[i]->insert(
+              latencies_next_chunks[i]->end(), time_to_next_chunks->begin(),
+              time_to_next_chunks->end());
+          if (FLAGS_write_output_audio) {
+            ::riva::utils::wav::Write(
+                "worker_" + std::to_string(i) + ".wav", rate, buffer.data(), buffer.size());
+          }
+        }
+      } else {
+        // Non-streaming mode: send one sentence per stream
+        for (size_t s = 0; s < sentences[i].size(); s++) {
+          auto current_time = std::chrono::steady_clock::now();
+          double diff_time =
+              std::chrono::duration<double, std::milli>(current_time - start_time).count();
+          double wait_time = (s + 1) * FLAGS_throttle_milliseconds - diff_time;
+
+          // To nanoseconds
+          wait_time *= 1.e3;
+          wait_time = std::max(wait_time, 0.);
+
+          // Round to nearest integer
+          wait_time = wait_time + 0.5 - (wait_time < 0);
+          int64_t usecs = (int64_t)wait_time;
+          // Sleep
+          if (usecs > 0) {
+            usleep(usecs);
+          }
+
+          auto tts = CreateTTS(grpc_channel);
+          double time_to_first_chunk = 0.;
+          auto time_to_next_chunk = new std::vector<double>();
+          size_t num_samples = 0;
+          synthesizeOnline(
+              std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name,
+              &time_to_first_chunk, time_to_next_chunk, &num_samples,
+              std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt,
+              FLAGS_zero_shot_quality);
+          latencies_first_chunk[i]->push_back(time_to_first_chunk);
+          latencies_next_chunks[i]->insert(
+              latencies_next_chunks[i]->end(), time_to_next_chunk->begin(),
+              time_to_next_chunk->end());
+          lengths[i]->push_back(num_samples);
+        }
       }
     }));
   }
@@ -588,4 +711,3 @@ main(int argc, char** argv)
   }
   return STATUS;
 }
-