Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ grpc_extra_deps()
git_repository(
name = "nvriva_common",
remote = "https://github.com/nvidia-riva/common.git",
commit = "1301af41cbf429dda8204b22d817c0e17cf8b369"
commit = "da9435bb6fbfdaaf5bc3d48452795c93c23d09a9"
)

http_archive(
Expand Down
93 changes: 87 additions & 6 deletions riva/clients/tts/riva_tts_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ namespace nr = nvidia::riva;
namespace nr_tts = nvidia::riva::tts;

DEFINE_string(text, "", "Text to be synthesized");
DEFINE_string(
text_file, "", "Text file with list of sentences to be synthesized. Ignored if 'text' is set.");
DEFINE_string(audio_file, "output.wav", "Output file");
DEFINE_string(audio_encoding, "pcm", "Audio encoding (pcm or opus)");
DEFINE_string(riva_uri, "localhost:50051", "Riva API server URI and port");
Expand Down Expand Up @@ -101,6 +103,7 @@ main(int argc, char** argv)
std::stringstream str_usage;
str_usage << "Usage: riva_tts_client " << std::endl;
str_usage << " --text=<text> " << std::endl;
str_usage << " --text_file=<filename> " << std::endl;
str_usage << " --audio_file=<filename> " << std::endl;
str_usage << " --audio_encoding=<pcm|opus> " << std::endl;
str_usage << " --riva_uri=<server_name:port> " << std::endl;
Expand Down Expand Up @@ -134,10 +137,26 @@ main(int argc, char** argv)
}

auto text = FLAGS_text;
if (text.length() == 0) {
LOG(ERROR) << "Input text cannot be empty." << std::endl;
auto text_file = FLAGS_text_file;
std::vector<std::string> text_lines;
if (text.length() == 0 && text_file.length() == 0) {
LOG(ERROR) << "Input text or text file cannot be empty." << std::endl;
return -1;
}
if (text.length() > 0 && text_file.length() > 0) {
LOG(ERROR) << "Only one of text or text file can be provided." << std::endl;
return -1;
}
if (text_file.length() > 0) {
std::ifstream infile(text_file);
if (infile.is_open()) {
std::string line;
while (std::getline(infile, line)) {
text_lines.push_back(line);
text += line + " ";
}
}
}

bool flag_set = gflags::GetCommandLineFlagInfoOrDie("riva_uri").is_default;
const char* riva_uri = getenv("RIVA_URI");
Expand All @@ -152,7 +171,8 @@ main(int argc, char** argv)
auto creds = riva::clients::CreateChannelCredentials(
FLAGS_use_ssl, FLAGS_ssl_root_cert, FLAGS_ssl_client_key, FLAGS_ssl_client_cert,
FLAGS_metadata);
grpc_channel = riva::clients::CreateChannelBlocking(FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
grpc_channel = riva::clients::CreateChannelBlocking(
FLAGS_riva_uri, creds, FLAGS_timeout_ms, FLAGS_max_grpc_message_size);
}
catch (const std::exception& e) {
std::cerr << "Error creating GRPC channel: " << e.what() << std::endl;
Expand Down Expand Up @@ -251,7 +271,7 @@ main(int argc, char** argv)
decoder.DeserializeOpus(std::vector<unsigned char>(ptr, ptr + audio.size())));
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm.data(), pcm.size());
}
} else { // online inference
} else if (FLAGS_online && text_lines.size() == 0) { // if text_lines is empty, the query is already stored in the request object
if (not FLAGS_zero_shot_transcript.empty()) {
LOG(ERROR) << "Zero shot transcript is not supported for streaming inference.";
return -1;
Expand All @@ -261,8 +281,11 @@ main(int argc, char** argv)
size_t audio_len = 0;
nr_tts::SynthesizeSpeechResponse chunk;
auto start = std::chrono::steady_clock::now();
std::unique_ptr<grpc::ClientReader<nr_tts::SynthesizeSpeechResponse>> reader(
tts->SynthesizeOnline(&context, request));
std::unique_ptr<
grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
reader(tts->SynthesizeOnline(&context));
reader->Write(request);
reader->WritesDone();
while (reader->Read(&chunk)) {
// Copy chunk to local buffer
if (audio_len == 0) {
Expand Down Expand Up @@ -295,6 +318,64 @@ main(int argc, char** argv)
return -1;
}

if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size());
} else if (FLAGS_audio_encoding == "opus") {
riva::utils::opus::Decoder decoder(rate, 1);
auto pcm = decoder.DecodePcm(decoder.DeserializeOpus(opus_buffer));
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm.data(), pcm.size());
}
} else if (FLAGS_online && text_lines.size() > 0) { // streaming inference
std::vector<int16_t> pcm_buffer;
std::vector<unsigned char> opus_buffer;
size_t audio_len = 0;
nr_tts::SynthesizeSpeechResponse chunk;
auto start = std::chrono::steady_clock::now();
std::unique_ptr<
grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
reader(tts->SynthesizeOnline(&context));
for (const auto& line : text_lines) {
if (line.find("|") != std::string::npos) {
request.set_text(line.substr(line.find("|") + 1, line.length()));
} else {
request.set_text(line);
}
reader->Write(request);
}
reader->WritesDone();
while (reader->Read(&chunk)) {
// Copy chunk to local buffer
if (audio_len == 0) {
auto t_first_audio = std::chrono::steady_clock::now();
std::chrono::duration<double> elapsed_first_audio = t_first_audio - start;
LOG(INFO) << "Time to first chunk: " << elapsed_first_audio.count() << " s" << std::endl;
}
LOG(INFO) << "Got chunk: " << chunk.audio().size() << " bytes" << std::endl;
if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
int16_t* audio_data = (int16_t*)chunk.audio().data();
size_t len = chunk.audio().length() / sizeof(int16_t);
std::copy(audio_data, audio_data + len, std::back_inserter(pcm_buffer));
audio_len += len;
} else if (FLAGS_audio_encoding == "opus") {
const unsigned char* opus_data = (unsigned char*)chunk.audio().data();
size_t len = chunk.audio().length();
std::copy(opus_data, opus_data + len, std::back_inserter(opus_buffer));
audio_len += len;
}
}
grpc::Status rpc_status = reader->Finish();
auto end = std::chrono::steady_clock::now();
std::chrono::duration<double> elapsed_total = end - start;
LOG(INFO) << "Total streaming time: " << elapsed_total.count() << " s" << std::endl;

if (!rpc_status.ok()) {
// Report the RPC failure.
LOG(ERROR) << rpc_status.error_message() << std::endl;
LOG(ERROR) << "Input was: " << text_lines.size() << " lines." << std::endl;
LOG(ERROR) << "Input was: \'" << text << "\'" << std::endl;
return -1;
}

if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") {
::riva::utils::wav::Write(FLAGS_audio_file, rate, pcm_buffer.data(), pcm_buffer.size());
} else if (FLAGS_audio_encoding == "opus") {
Expand Down
55 changes: 40 additions & 15 deletions riva/clients/tts/riva_tts_perf_client.cc
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mohnishparmar can you check build workflow failure?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the Proto path needed update

Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ DEFINE_string(
DEFINE_string(voice_name, "", "Desired voice name");
DEFINE_int32(num_iterations, 1, "Number of times to loop over audio files");
DEFINE_int32(num_parallel_requests, 1, "Number of parallel requests to keep in flight");
DEFINE_int32(num_sentences, 1, "Number of sentences to send");
DEFINE_int32(throttle_milliseconds, 0, "Number of milliseconds to sleep for between TTS requests");
DEFINE_int32(offset_milliseconds, 0, "Number of milliseconds to offset each parallel TTS requests");
DEFINE_string(ssl_root_cert, "", "Path to SSL root certificates file");
Expand Down Expand Up @@ -203,13 +204,12 @@ synthesizeBatch(

void
synthesizeOnline(
std::unique_ptr<nr_tts::RivaSpeechSynthesis::Stub> tts, std::string text, std::string language,
std::unique_ptr<nr_tts::RivaSpeechSynthesis::Stub> tts, std::vector<std::string> text, std::string language,
uint32_t rate, std::string voice_name, double* time_to_first_chunk,
std::vector<double>* time_to_next_chunk, size_t* num_samples, std::string filepath,
std::string zero_shot_prompt_filename, int32_t zero_shot_quality)
{
nr_tts::SynthesizeSpeechRequest request;
request.set_text(text);
request.set_language_code(language);
request.set_sample_rate_hz(rate);
request.set_voice_name(voice_name);
Expand Down Expand Up @@ -260,9 +260,17 @@ synthesizeOnline(
nr_tts::SynthesizeSpeechResponse chunk;

auto start = std::chrono::steady_clock::now();
std::unique_ptr<grpc::ClientReader<nr_tts::SynthesizeSpeechResponse>> reader(
tts->SynthesizeOnline(&context, request));
DLOG(INFO) << "Sending request for input \"" << text << "\".";
std::unique_ptr<
grpc::ClientReaderWriter<nr_tts::SynthesizeSpeechRequest, nr_tts::SynthesizeSpeechResponse>>
reader(tts->SynthesizeOnline(&context));
std::string text_complete = "";
for (const auto& text_line : text) {
request.set_text(text_line);
text_complete += text_line + " ";
reader->Write(request);
}
reader->WritesDone();
DLOG(INFO) << "Sending request for input \"" << text[0] << "\".";

std::vector<int16_t> buffer;
size_t audio_len = 0;
Expand Down Expand Up @@ -292,7 +300,7 @@ synthesizeOnline(
// std::cerr << "Time to first chunk: " << elapsed_first_audio.count() << " s" << std::endl;
*time_to_first_chunk = elapsed_first_audio.count();
start = t_next_audio;
DLOG(INFO) << "Received first chunk for input \"" << text << "\".";
DLOG(INFO) << "Received first chunk for input \"" << text[0] << "\".";
} else {
auto t_next_audio = std::chrono::steady_clock::now();
std::chrono::duration<double> elapsed_next_audio = t_next_audio - start;
Expand All @@ -302,12 +310,12 @@ synthesizeOnline(
audio_len += len;
}
grpc::Status rpc_status = reader->Finish();
DLOG(INFO) << "Received all chunks for input \"" << text << "\".";
DLOG(INFO) << "Received all chunks for input \"" << text_complete << "\".";

if (!rpc_status.ok()) {
// Report the RPC failure.
std::cerr << rpc_status.error_message() << std::endl;
std::cerr << "Input was: \'" << text << "\'" << std::endl;
std::cerr << "Input was: \'" << text_complete << "\'" << std::endl;
} else {
*num_samples = audio_len;
if (FLAGS_write_output_audio) {
Expand Down Expand Up @@ -347,6 +355,7 @@ main(int argc, char** argv)
str_usage << " --audio_encoding=<pcm|opus> " << std::endl;
str_usage << " --num_parallel_requests=<num-parallel-reqs> " << std::endl;
str_usage << " --num_iterations=<num-iterations> " << std::endl;
str_usage << " --num_sentences=<num-sentences> " << std::endl;
str_usage << " --throttle_milliseconds=<throttle-milliseconds> " << std::endl;
str_usage << " --offset_milliseconds=<offset-milliseconds> " << std::endl;
str_usage << " --ssl_root_cert=<filename>" << std::endl;
Expand Down Expand Up @@ -404,7 +413,7 @@ main(int argc, char** argv)

// open text file, load sentences as a vector
int count = 0;
for (int i = 0; i < FLAGS_num_iterations; i++) {
for (int i = 0; i < FLAGS_num_iterations * FLAGS_num_sentences; i++) {
std::ifstream file(text_file);
while (std::getline(file, sentence)) {
if (sentence.find("|") != std::string::npos) {
Expand Down Expand Up @@ -446,6 +455,7 @@ main(int argc, char** argv)
std::vector<std::vector<size_t>*> lengths;

auto start = std::chrono::steady_clock::now();
std::vector<int> worker_sentence_idx(FLAGS_num_parallel_requests, 0);

for (int i = 0; i < FLAGS_num_parallel_requests; i++) {
auto time_to_first_chunks = new std::vector<double>();
Expand All @@ -458,11 +468,12 @@ main(int argc, char** argv)
usleep(i * FLAGS_offset_milliseconds * 1000);
auto start_time = std::chrono::steady_clock::now();

for (size_t s = 0; s < sentences[i].size(); s++) {
int batch_count = 0;
while (worker_sentence_idx[i] < sentences[i].size()) {
auto current_time = std::chrono::steady_clock::now();
double diff_time =
std::chrono::duration<double, std::milli>(current_time - start_time).count();
double wait_time = (s + 1) * FLAGS_throttle_milliseconds - diff_time;
double wait_time = (batch_count + 1) * FLAGS_throttle_milliseconds - diff_time;

// To nanoseconds
wait_time *= 1.e3;
Expand All @@ -479,17 +490,30 @@ main(int argc, char** argv)
auto tts = CreateTTS(grpc_channel);
double time_to_first_chunk = 0.;
auto time_to_next_chunk = new std::vector<double>();

std::vector<std::string> texts;
std::string text_complete = "";
int count = sentences[i][worker_sentence_idx[i]].first;
for (int j = 0; j < FLAGS_num_sentences; j++) {
if (worker_sentence_idx[i] >= sentences[i].size()) {
break;
}
texts.push_back(sentences[i][worker_sentence_idx[i]].second);
text_complete += sentences[i][worker_sentence_idx[i]].second + " ";
worker_sentence_idx[i]++;
}
size_t num_samples = 0;
synthesizeOnline(
std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name,
std::move(tts), texts, FLAGS_language, rate, FLAGS_voice_name,
&time_to_first_chunk, time_to_next_chunk, &num_samples,
std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt,
std::to_string(count) + ".wav", FLAGS_zero_shot_audio_prompt,
FLAGS_zero_shot_quality);
latencies_first_chunk[i]->push_back(time_to_first_chunk);
latencies_next_chunks[i]->insert(
latencies_next_chunks[i]->end(), time_to_next_chunk->begin(),
time_to_next_chunk->end());
lengths[i]->push_back(num_samples);
batch_count++;
}
}));
}
Expand Down Expand Up @@ -555,13 +579,15 @@ main(int argc, char** argv)
auto results_num_samples_thread = new std::vector<int32_t>();
results_num_samples.push_back(results_num_samples_thread);
workers.push_back(std::thread([&, i]() {
int count = 0;
for (size_t s = 0; s < sentences[i].size(); s++) {
auto tts = CreateTTS(grpc_channel);
int32_t num_samples = synthesizeBatch(
std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name,
std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt,
std::to_string(count) + ".wav", FLAGS_zero_shot_audio_prompt,
FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript);
results_num_samples[i]->push_back(num_samples);
count++;
}
}));
}
Expand All @@ -588,4 +614,3 @@ main(int argc, char** argv)
}
return STATUS;
}