From 548667aea9ef117601de81ea06545c34f2ace895 Mon Sep 17 00:00:00 2001 From: mohnishp Date: Fri, 19 Dec 2025 11:23:38 +0530 Subject: [PATCH] chore: Update TTS client --- WORKSPACE | 2 +- riva/clients/tts/riva_tts_client.cc | 93 ++++++++++++++++++++++-- riva/clients/tts/riva_tts_perf_client.cc | 55 ++++++++++---- 3 files changed, 128 insertions(+), 22 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index ac622ea..d89c586 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -77,7 +77,7 @@ grpc_extra_deps() git_repository( name = "nvriva_common", remote = "https://github.com/nvidia-riva/common.git", - commit = "1301af41cbf429dda8204b22d817c0e17cf8b369" + commit = "da9435bb6fbfdaaf5bc3d48452795c93c23d09a9" ) http_archive( diff --git a/riva/clients/tts/riva_tts_client.cc b/riva/clients/tts/riva_tts_client.cc index 05d9f94..07355aa 100644 --- a/riva/clients/tts/riva_tts_client.cc +++ b/riva/clients/tts/riva_tts_client.cc @@ -29,6 +29,8 @@ namespace nr = nvidia::riva; namespace nr_tts = nvidia::riva::tts; DEFINE_string(text, "", "Text to be synthesized"); +DEFINE_string( + text_file, "", "Text file with list of sentences to be synthesized. Ignored if 'text' is set."); DEFINE_string(audio_file, "output.wav", "Output file"); DEFINE_string(audio_encoding, "pcm", "Audio encoding (pcm or opus)"); DEFINE_string(riva_uri, "localhost:50051", "Riva API server URI and port"); @@ -101,6 +103,7 @@ main(int argc, char** argv) std::stringstream str_usage; str_usage << "Usage: riva_tts_client " << std::endl; str_usage << " --text= " << std::endl; + str_usage << " --text_file= " << std::endl; str_usage << " --audio_file= " << std::endl; str_usage << " --audio_encoding= " << std::endl; str_usage << " --riva_uri= " << std::endl; @@ -134,10 +137,26 @@ main(int argc, char** argv) } auto text = FLAGS_text; - if (text.length() == 0) { - LOG(ERROR) << "Input text cannot be empty." 
<< std::endl; + auto text_file = FLAGS_text_file; + std::vector<std::string> text_lines; + if (text.length() == 0 && text_file.length() == 0) { + LOG(ERROR) << "Input text or text file cannot be empty." << std::endl; + return -1; + } + if (text.length() > 0 && text_file.length() > 0) { + LOG(ERROR) << "Only one of text or text file can be provided." << std::endl; return -1; } + if (text_file.length() > 0) { + std::ifstream infile(text_file); + if (infile.is_open()) { + std::string line; + while (std::getline(infile, line)) { + text_lines.push_back(line); + text += line + " "; + } + } + } bool flag_set = gflags::GetCommandLineFlagInfoOrDie("riva_uri").is_default; const char* riva_uri = getenv("RIVA_URI"); @@ -152,7 +171,8 @@ main(int argc, char** argv) auto creds = riva::clients::CreateChannelCredentials( FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert, FLAGS_metadata); - grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size); + grpc_channel = riva::clients::CreateChannelBlocking( + FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size); } catch (const std::exception& e) { std::cerr << "Error creating GRPC channel: " << e.what() << std::endl; @@ -251,7 +271,7 @@ main(int argc, char** argv) decoder.DeserializeOpus(std::vector<unsigned char>(ptr, ptr + audio.size()))); ::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm.data(), pcm.size()); } - } else { // online inference + } else if (FLAGS_online && text_lines.size() == 0) { // if text_lines is empty, the query is already stored in the request object if (not FLAGS_zero_shot_transcript.empty()) { LOG(ERROR) << "Zero shot transcript is not supported for streaming inference."; return -1; @@ -261,8 +281,11 @@ main(int argc, char** argv) size_t audio_len = 0; nr_tts::SynthesizeSpeechResponse chunk; auto start = std::chrono::steady_clock::now(); - std::unique_ptr<grpc::ClientReader<nr_tts::SynthesizeSpeechResponse>> reader( - tts->SynthesizeOnline(&context, request)); + std::unique_ptr<
+ grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>> + reader(tts->SynthesizeOnline(&context)); + reader->Write(request); + reader->WritesDone(); while (reader->Read(&chunk)) { // Copy chunk to local buffer if (audio_len == 0) { @@ -295,6 +318,64 @@ main(int argc, char** argv) return -1; } + if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") { + ::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size()); + } else if (FLAGS_audio_encoding == "opus") { + riva::utils::opus::Decoder decoder(rate, 1); + auto pcm = decoder.DecodePcm(decoder.DeserializeOpus(opus_buffer)); + ::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm.data(), pcm.size()); + } + } else if (FLAGS_online && text_lines.size() > 0) { // streaming inference + std::vector<int16_t> pcm_buffer; + std::vector<unsigned char> opus_buffer; + size_t audio_len = 0; + nr_tts::SynthesizeSpeechResponse chunk; + auto start = std::chrono::steady_clock::now(); + std::unique_ptr< + grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>> + reader(tts->SynthesizeOnline(&context)); + for (const auto& line : text_lines) { + if (line.find("|") != std::string::npos) { + request.set_text(line.substr(line.find("|") + 1, line.length())); + } else { + request.set_text(line); + } + reader->Write(request); + } + reader->WritesDone(); + while (reader->Read(&chunk)) { + // Copy chunk to local buffer + if (audio_len == 0) { + auto t_first_audio = std::chrono::steady_clock::now(); + std::chrono::duration<double> elapsed_first_audio = t_first_audio - start; + LOG(INFO) << "Time to first chunk: " << elapsed_first_audio.count() << " s" << std::endl; + } + LOG(INFO) << "Got chunk: " << chunk.audio().size() << " bytes" << std::endl; + if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") { + int16_t* audio_data = (int16_t*)chunk.audio().data(); + size_t len = chunk.audio().length() / sizeof(int16_t); + std::copy(audio_data, audio_data + len, std::back_inserter(pcm_buffer)); + audio_len += len; + } else if (FLAGS_audio_encoding == "opus") { + const unsigned char*
opus_data = (unsigned char*)chunk.audio().data(); + size_t len = chunk.audio().length(); + std::copy(opus_data, opus_data + len, std::back_inserter(opus_buffer)); + audio_len += len; + } + } + grpc::Status rpc_status = reader->Finish(); + auto end = std::chrono::steady_clock::now(); + std::chrono::duration<double> elapsed_total = end - start; + LOG(INFO) << "Total streaming time: " << elapsed_total.count() << " s" << std::endl; + + if (!rpc_status.ok()) { + // Report the RPC failure. + LOG(ERROR) << rpc_status.error_message() << std::endl; + LOG(ERROR) << "Input was: " << text_lines.size() << " lines." << std::endl; + LOG(ERROR) << "Input was: \'" << text << "\'" << std::endl; + return -1; + } + if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") { ::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size()); } else if (FLAGS_audio_encoding == "opus") { diff --git a/riva/clients/tts/riva_tts_perf_client.cc b/riva/clients/tts/riva_tts_perf_client.cc index 0f15e6e..2fc7dc8 100644 --- a/riva/clients/tts/riva_tts_perf_client.cc +++ b/riva/clients/tts/riva_tts_perf_client.cc @@ -47,6 +47,7 @@ DEFINE_string( DEFINE_string(voice_name, "", "Desired voice name"); DEFINE_int32(num_iterations, 1, "Number of times to loop over audio files"); DEFINE_int32(num_parallel_requests, 1, "Number of parallel requests to keep in flight"); +DEFINE_int32(num_sentences, 1, "Number of sentences to send"); DEFINE_int32(throttle_milliseconds, 0, "Number of milliseconds to sleep for between TTS requests"); DEFINE_int32(offset_milliseconds, 0, "Number of milliseconds to offset each parallel TTS requests"); DEFINE_string(ssl_root_cert, "", "Path to SSL root certificates file"); @@ -203,13 +204,12 @@ synthesizeBatch( void synthesizeOnline( - std::unique_ptr<nr_tts::RivaSpeechSynthesis::Stub> tts, std::string text, std::string language, + std::unique_ptr<nr_tts::RivaSpeechSynthesis::Stub> tts, std::vector<std::string> text, std::string language, uint32_t rate, std::string voice_name, double* time_to_first_chunk, std::vector<double>* time_to_next_chunk,
size_t* num_samples, std::string filepath, std::string zero_shot_prompt_filename, int32_t zero_shot_quality) { nr_tts::SynthesizeSpeechRequest request; - request.set_text(text); request.set_language_code(language); request.set_sample_rate_hz(rate); request.set_voice_name(voice_name); @@ -260,9 +260,17 @@ synthesizeOnline( nr_tts::SynthesizeSpeechResponse chunk; auto start = std::chrono::steady_clock::now(); - std::unique_ptr<grpc::ClientReader<nr_tts::SynthesizeSpeechResponse>> reader( - tts->SynthesizeOnline(&context, request)); - DLOG(INFO) << "Sending request for input \"" << text << "\"."; + std::unique_ptr< + grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>> + reader(tts->SynthesizeOnline(&context)); + std::string text_complete = ""; + for (const auto& text_line : text) { + request.set_text(text_line); + text_complete += text_line + " "; + reader->Write(request); + } + reader->WritesDone(); + DLOG(INFO) << "Sending request for input \"" << text[0] << "\"."; std::vector<int16_t> buffer; size_t audio_len = 0; @@ -292,7 +300,7 @@ synthesizeOnline( // std::cerr << "Time to first chunk: " << elapsed_first_audio.count() << " s" << std::endl; *time_to_first_chunk = elapsed_first_audio.count(); start = t_next_audio; - DLOG(INFO) << "Received first chunk for input \"" << text << "\"."; + DLOG(INFO) << "Received first chunk for input \"" << text[0] << "\"."; } else { auto t_next_audio = std::chrono::steady_clock::now(); std::chrono::duration<double> elapsed_next_audio = t_next_audio - start; @@ -302,12 +310,12 @@ synthesizeOnline( audio_len += len; } grpc::Status rpc_status = reader->Finish(); - DLOG(INFO) << "Received all chunks for input \"" << text << "\"."; + DLOG(INFO) << "Received all chunks for input \"" << text_complete << "\"."; if (!rpc_status.ok()) { // Report the RPC failure.
std::cerr << rpc_status.error_message() << std::endl; - std::cerr << "Input was: \'" << text << "\'" << std::endl; + std::cerr << "Input was: \'" << text_complete << "\'" << std::endl; } else { *num_samples = audio_len; if (FLAGS_write_output_audio) { @@ -347,6 +355,7 @@ main(int argc, char** argv) str_usage << " --audio_encoding= " << std::endl; str_usage << " --num_parallel_requests= " << std::endl; str_usage << " --num_iterations= " << std::endl; + str_usage << " --num_sentences= " << std::endl; str_usage << " --throttle_milliseconds= " << std::endl; str_usage << " --offset_milliseconds= " << std::endl; str_usage << " --ssl_root_cert=" << std::endl; @@ -404,7 +413,7 @@ main(int argc, char** argv) // open text file, load sentences as a vector int count = 0; - for (int i = 0; i < FLAGS_num_iterations; i++) { + for (int i = 0; i < FLAGS_num_iterations * FLAGS_num_sentences; i++) { std::ifstream file(text_file); while (std::getline(file, sentence)) { if (sentence.find("|") != std::string::npos) { @@ -446,6 +455,7 @@ main(int argc, char** argv) std::vector<std::vector<size_t>*> lengths; auto start = std::chrono::steady_clock::now(); + std::vector<size_t> worker_sentence_idx(FLAGS_num_parallel_requests, 0); for (int i = 0; i < FLAGS_num_parallel_requests; i++) { auto time_to_first_chunks = new std::vector<double>(); @@ -458,11 +468,12 @@ main(int argc, char** argv) usleep(i * FLAGS_offset_milliseconds * 1000); auto start_time = std::chrono::steady_clock::now(); - for (size_t s = 0; s < sentences[i].size(); s++) { + int batch_count = 0; + while (worker_sentence_idx[i] < sentences[i].size()) { auto current_time = std::chrono::steady_clock::now(); double diff_time = std::chrono::duration<double, std::milli>(current_time - start_time).count(); - double wait_time = (s + 1) * FLAGS_throttle_milliseconds - diff_time; + double wait_time = (batch_count + 1) * FLAGS_throttle_milliseconds - diff_time; // To nanoseconds wait_time *= 1.e3; @@ -479,17 +490,30 @@ main(int argc, char** argv) auto tts = CreateTTS(grpc_channel); double
time_to_first_chunk = 0.; auto time_to_next_chunk = new std::vector<double>(); + + std::vector<std::string> texts; + std::string text_complete = ""; + int count = sentences[i][worker_sentence_idx[i]].first; + for (int j = 0; j < FLAGS_num_sentences; j++) { + if (worker_sentence_idx[i] >= sentences[i].size()) { + break; + } + texts.push_back(sentences[i][worker_sentence_idx[i]].second); + text_complete += sentences[i][worker_sentence_idx[i]].second + " "; + worker_sentence_idx[i]++; + } size_t num_samples = 0; synthesizeOnline( - std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name, + std::move(tts), texts, FLAGS_language, rate, FLAGS_voice_name, &time_to_first_chunk, time_to_next_chunk, &num_samples, - std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt, + std::to_string(count) + ".wav", FLAGS_zero_shot_audio_prompt, FLAGS_zero_shot_quality); latencies_first_chunk[i]->push_back(time_to_first_chunk); latencies_next_chunks[i]->insert( latencies_next_chunks[i]->end(), time_to_next_chunk->begin(), time_to_next_chunk->end()); lengths[i]->push_back(num_samples); + batch_count++; } })); } @@ -555,13 +579,15 @@ main(int argc, char** argv) auto results_num_samples_thread = new std::vector<int32_t>(); results_num_samples.push_back(results_num_samples_thread); workers.push_back(std::thread([&, i]() { + int count = 0; for (size_t s = 0; s < sentences[i].size(); s++) { auto tts = CreateTTS(grpc_channel); int32_t num_samples = synthesizeBatch( std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name, - std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt, + std::to_string(count) + ".wav", FLAGS_zero_shot_audio_prompt, FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript); results_num_samples[i]->push_back(num_samples); + count++; } })); } @@ -588,4 +614,3 @@ main(int argc, char** argv) } return STATUS; } -