From f6bb43c4102996b2da0468d5e8a3d5a673213294 Mon Sep 17 00:00:00 2001 From: Yash Hayaran Date: Thu, 24 Jul 2025 20:35:40 +0530 Subject: [PATCH 1/6] chore: initial commit --- Dockerfile | 4 +- WORKSPACE | 8 + riva/clients/asr/BUILD | 25 ++ riva/clients/asr/riva_realtime_asr_client.cc | 206 ++++++++++ riva/clients/realtime/BUILD | 34 ++ riva/clients/realtime/audio_chunks.cpp | 160 ++++++++ riva/clients/realtime/audio_chunks.h | 84 ++++ riva/clients/realtime/base_client.cpp | 191 +++++++++ riva/clients/realtime/base_client.h | 86 ++++ riva/clients/realtime/recognition_client.cpp | 401 +++++++++++++++++++ riva/clients/realtime/recognition_client.h | 98 +++++ riva/utils/stats_builder/BUILD | 7 + riva/utils/stats_builder/stats_builder.cpp | 284 +++++++++++++ riva/utils/stats_builder/stats_builder.h | 147 +++++++ third_party/BUILD.websocketpp | 20 + 15 files changed, 1754 insertions(+), 1 deletion(-) create mode 100644 riva/clients/asr/riva_realtime_asr_client.cc create mode 100644 riva/clients/realtime/BUILD create mode 100644 riva/clients/realtime/audio_chunks.cpp create mode 100644 riva/clients/realtime/audio_chunks.h create mode 100644 riva/clients/realtime/base_client.cpp create mode 100644 riva/clients/realtime/base_client.h create mode 100644 riva/clients/realtime/recognition_client.cpp create mode 100644 riva/clients/realtime/recognition_client.h create mode 100644 riva/utils/stats_builder/BUILD create mode 100644 riva/utils/stats_builder/stats_builder.cpp create mode 100644 riva/utils/stats_builder/stats_builder.h create mode 100644 third_party/BUILD.websocketpp diff --git a/Dockerfile b/Dockerfile index a7c6e0a..c026eb9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,8 @@ RUN apt-get update && apt-get install -y \ libasound2t64 \ libogg0 \ openssl \ - ca-certificates + ca-certificates \ + libboost-all-dev FROM base AS builddep ARG BAZEL_VERSION @@ -60,6 +61,7 @@ FROM base as riva-clients WORKDIR /work COPY --from=builder 
/opt/riva/clients/asr/riva_asr_client /usr/local/bin/ +COPY --from=builder /opt/riva/clients/asr/riva_realtime_asr_client /usr/local/bin/ COPY --from=builder /opt/riva/clients/asr/riva_streaming_asr_client /usr/local/bin/ COPY --from=builder /opt/riva/clients/tts/riva_tts_client /usr/local/bin/ COPY --from=builder /opt/riva/clients/tts/riva_tts_perf_client /usr/local/bin/ diff --git a/WORKSPACE b/WORKSPACE index ac622ea..9331c0f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -102,3 +102,11 @@ http_archive( strip_prefix = "platforms-1.0.0", sha256 = "852b71bfa15712cec124e4a57179b6bc95d59fdf5052945f5d550e072501a769", ) + +http_archive( + name = "websocketpp", + urls = ["https://github.com/zaphoyd/websocketpp/archive/refs/tags/0.8.2.tar.gz"], + sha256 = "6ce889d85ecdc2d8fa07408d6787e7352510750daa66b5ad44aacb47bea76755", + strip_prefix = "websocketpp-0.8.2", + build_file = "//third_party:BUILD.websocketpp" +) \ No newline at end of file diff --git a/riva/clients/asr/BUILD b/riva/clients/asr/BUILD index d88c46b..b2aa4ae 100644 --- a/riva/clients/asr/BUILD +++ b/riva/clients/asr/BUILD @@ -115,6 +115,31 @@ cc_binary( ], ) +cc_binary( + name = "riva_realtime_asr_client", + srcs = ["riva_realtime_asr_client.cc"], + includes = ["-Irealtime"], + deps = [ + "//riva/clients/realtime:realtime_audio_client_lib", + "@websocketpp//:websocketpp", + "@rapidjson//:rapidjson", + "//riva/utils/stats_builder:stats_builder_lib", + "//riva/utils/wav:reader", + ] + select({ + "@platforms//cpu:aarch64": [ + "@alsa_aarch64//:libasound" + ], + "//conditions:default": [ + "@alsa//:libasound" + ], + }), + linkopts = [ + "-lssl", + "-lcrypto", + "-lboost_system", + ] +) + cc_test( name = "streaming_recognize_client_test", srcs = ["streaming_recognize_client_test.cc"], diff --git a/riva/clients/asr/riva_realtime_asr_client.cc b/riva/clients/asr/riva_realtime_asr_client.cc new file mode 100644 index 0000000..28958f2 --- /dev/null +++ b/riva/clients/asr/riva_realtime_asr_client.cc @@ -0,0 +1,206 @@ +/* + 
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: MIT + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "riva/clients/realtime/recognition_client.h" +#include "riva/utils/stats_builder/stats_builder.h" + +using namespace nvidia::riva::utils; +using namespace nvidia::riva::realtime; + + +// Global client pointer for signal handling +std::vector g_clients; +std::mutex g_clients_mutex; + +// Signal handler for graceful shutdown +void signal_handler(int signal) { + for (auto client : g_clients) { + std::cout << "\nReceived signal " << signal << ", shutting down gracefully..." << std::endl; + client->Close(); + } + exit(0); +} + +// Helper function to format throughput as 10.246e00 instead of 1.0246e+01 +std::string format_throughput(double value) { + std::ostringstream oss; + oss << std::fixed << std::setprecision(3) << value << "e00"; + return oss.str(); +} + +// Function to run the client example +void client_runner( const std::string& uri, + const std::shared_ptr& audio_chunks, + PerformanceStats& perfCounter, + const std::size_t connectionTimeoutInMs, + const std::size_t sessionInitTimeoutInMs, + const std::size_t sessionUpdateTimeoutInMs, + const std::size_t transcriptionTimeoutInMs, + const std::size_t chunkDelayTimeInMs, + const bool simulateRealtime = false) +{ + nvidia::riva::realtime::RecognitionClient client(perfCounter.GetObjectName(), audio_chunks, perfCounter); + + client.SetVerboseLogging(false); + client.SetTimingConfig(connectionTimeoutInMs, sessionInitTimeoutInMs, sessionUpdateTimeoutInMs, transcriptionTimeoutInMs, chunkDelayTimeInMs); + + // Step 1: Connect to the WebSocket server + client.Connect(uri); + + std::thread client_thread([&client]() { + client.Run(); + }); + + // Step 2: Wait for the connection to be established + if (!client.WaitForConnection()) { + std::cerr << 
"Failed to establish WebSocket connection" << std::endl; + client.Close(); + client_thread.join(); + return; + } + + std::cout << "WebSocket connection established" << std::endl; + + // Step 3: Initialize the session + if (!client.InitializeSession()) { + std::cerr << "Failed to initialize session" << std::endl; + client.Close(); + client_thread.join(); + return; + } + + std::cout << "Waiting for session update confirmation..." << std::endl; + + // Step 4: Wait for the session to be updated + if (!client.WaitForSessionUpdate()) { + std::cerr << "Session update timeout" << std::endl; + client.Close(); + client_thread.join(); + return; + } + + // Step 5: Send the audio chunks with realistic timing + perfCounter.StartProcessingTimer(); + perfCounter.SetAudioDurationInSeconds(audio_chunks->GetDurationSeconds()); + + // Send chunks with realistic timing + client.SendAudioChunks(simulateRealtime); + + std::cout << "Waiting for transcription completion..." << std::endl; + + // Step 6: Wait for the transcription to be completed + if (client.WaitForTranscriptionCompletion()) { + std::cout << "Transcription completed successfully!" 
<< std::endl; + perfCounter.EndProcessingTimer(); + perfCounter.SetSuccess(true); + } else { + std::cout << "Transcription did not complete within timeout" << std::endl; + perfCounter.EndProcessingTimer(); + } + + // Step 7: Close the WebSocket connection + client.Close(); + client_thread.join(); + + { + std::lock_guard lock(g_clients_mutex); + g_clients.push_back(&client); + } + + // Step 8: Report the stats + perfCounter.ReportStats(); +} + + +int main(int argc, char* argv[]) { + std::size_t num_iterations = 1; + std::size_t num_parallel_clients = 50; + bool simulateRealtime = true; + + const std::size_t connectionTimeoutInMs = 1000 * 100; + const std::size_t sessionInitTimeoutInMs = 1000 * 100; + const std::size_t sessionUpdateTimeoutInMs = 1000 * 100; + const std::size_t transcriptionTimeoutInMs = 1000 * 100; + + // Realistic audio chunk timing - based on typical microphone sampling + // For 16kHz audio with 160ms chunks, this would be 160ms + const std::size_t chunkDelayTimeInMs = 160; // Realistic delay matching chunk duration + + const std::string uri = "ws://127.0.0.1:9090/v1/realtime?intent=transcription"; + const std::string audio_file_path = "/home/yhayaran/workspace/codebase/web-socket/new_ws_client/test_files/out5.wav"; + const std::size_t chunk_duration_ms = chunkDelayTimeInMs; + const auto audio_chunks = std::make_shared(audio_file_path, chunk_duration_ms); + if (!audio_chunks->Init()) { + std::cerr << "Failed to initialize audio chunks" << std::endl; + return 1; + } + + PerformanceStats overallPerf("Overall"); + + overallPerf.SetAudioDurationInSeconds(audio_chunks->GetDurationSeconds() * num_iterations * num_parallel_clients); + + // Create StatsBuilder for all clients + StatsBuilder statsBuilder("client", audio_chunks->GetDurationSeconds(), num_parallel_clients); + + // Run iterations asynchronously + std::vector> futures; + std::cout << "Starting " << num_parallel_clients << " async clients..." 
<< std::endl; + + overallPerf.StartProcessingTimer(); + for (std::size_t N = 0; N < num_parallel_clients; ++N) { + // Launch each client asynchronously + futures.emplace_back(std::async(std::launch::async, [&, N]() { + std::cout << "Starting client " << (N + 1) << "/" << num_parallel_clients << std::endl; + + for (std::size_t M = 0; M < num_iterations; ++M) { + std::cout << " Running iteration " << (M + 1) << "/" << num_iterations << std::endl; + client_runner( uri, + audio_chunks, + statsBuilder.GetPerformanceStats(N), + connectionTimeoutInMs, + sessionInitTimeoutInMs, + sessionUpdateTimeoutInMs, + transcriptionTimeoutInMs, + chunkDelayTimeInMs, + simulateRealtime); + } + + std::cout << "Completed client " << (N + 1) << "/" << num_parallel_clients << std::endl; + })); + } + + // Wait for all iterations to complete + std::cout << "Waiting for all iterations to complete..." << std::endl; + for (auto& future : futures) { + future.wait(); + } + std::cout << "All iterations completed!" << std::endl; + overallPerf.EndProcessingTimer(); + + // Set up signal handlers for graceful shutdown + signal(SIGINT, signal_handler); + signal(SIGTERM, signal_handler); + + // Uncomment this section to show detailed stats including success rates + statsBuilder.ReportCumulativeStats(); + statsBuilder.ReportDetailedStats(); + statsBuilder.ReportTabularStats(); + + overallPerf.ReportStats(); + return 0; +} \ No newline at end of file diff --git a/riva/clients/realtime/BUILD b/riva/clients/realtime/BUILD new file mode 100644 index 0000000..7de3f31 --- /dev/null +++ b/riva/clients/realtime/BUILD @@ -0,0 +1,34 @@ +""" +Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +NVIDIA CORPORATION and its licensors retain all intellectual property +and proprietary rights in and to this software, related documentation +and any modifications thereto. 
Any use, reproduction, disclosure or +distribution of this software and related documentation without an express +license agreement from NVIDIA CORPORATION is strictly prohibited. +""" + +package( + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "realtime_audio_client_lib", + srcs = [ + "audio_chunks.cpp", + "base_client.cpp", + "recognition_client.cpp", + ], + hdrs = [ + "audio_chunks.h", + "base_client.h", + "recognition_client.h", + ], + deps = [ + "//riva/utils/wav:reader", + "//riva/utils/stats_builder:stats_builder_lib", + "@websocketpp//:websocketpp", + "@rapidjson//:rapidjson", + "@glog//:glog", + "@com_github_gflags_gflags//:gflags", + ], +) \ No newline at end of file diff --git a/riva/clients/realtime/audio_chunks.cpp b/riva/clients/realtime/audio_chunks.cpp new file mode 100644 index 0000000..98bdc96 --- /dev/null +++ b/riva/clients/realtime/audio_chunks.cpp @@ -0,0 +1,160 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: MIT + */ + +#include "audio_chunks.h" +#include "riva/utils/wav/wav_reader.h" +#include "riva/utils/wav/wav_data.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +nvidia::riva::realtime::AudioChunks::AudioChunks(const std::string& filepath, const int& chunk_size_ms) + : filepath_(filepath), chunk_size_ms_(chunk_size_ms) { +} + +void nvidia::riva::realtime::AudioChunks::CalculateChunkSizeBytes() { + chunk_size_bytes_ = (GetSampleRateHz() * GetChunkSizeMs() / 1000) * sizeof(int16_t); + std::cout << "[AudioChunks] Calculated chunk size: " << chunk_size_bytes_ << " bytes" << std::endl; +} + +void nvidia::riva::realtime::AudioChunks::SplitIntoChunks() { + const std::vector& raw_data = wav_data_->data; + size_t total_size = raw_data.size(); + + std::cout << "[AudioChunks] Splitting WAV file into chunks of " << chunk_size_bytes_ << " bytes" << std::endl; + + chunk_base64s_.clear(); + for (size_t i = 0; i < total_size; i += chunk_size_bytes_) { + size_t current_chunk_size = std::min(chunk_size_bytes_, total_size - i); + std::vector chunk(raw_data.begin() + i, raw_data.begin() + i + current_chunk_size); + std::string chunk_base64 = EncodeBase64(chunk); + chunk_base64s_.push_back(chunk_base64); + } +} + +std::string nvidia::riva::realtime::AudioChunks::EncodeBase64(const std::vector& data) { + const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + + std::string result; + int val = 0, valb = -6; + + for (unsigned char c : data) { + val = (val << 8) + c; + valb += 8; + while (valb >= 0) { + result.push_back(base64_chars[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { + result.push_back(base64_chars[((val << 8) >> (valb + 8)) & 0x3F]); + } + + while (result.size() % 4) { + result.push_back('='); + } + + return result; +} + +bool nvidia::riva::realtime::AudioChunks::Init() { + if (initialized_) { + std::cout << 
"[AudioChunks] Chunks already initialized" << std::endl; + return true; + } + + std::cout << "[AudioChunks] Initializing chunks for file: " << filepath_ << std::endl; + fs::path path(filepath_); + std::string extension = path.extension().string(); + + // File exists + if (!fs::exists(filepath_)) { + std::cerr << "[AudioChunks] Error: File does not exist, " << filepath_ << std::endl; + return false; + } + + // File is a WAV file + if (extension != ".wav") { + std::cerr << "[AudioChunks] Error: File is not a WAV file, " << filepath_ << std::endl; + return false; + } + + // Load WAV file using the existing WAV utilities + std::vector> all_wav; + LoadWavData(all_wav, filepath_); + + if (all_wav.empty()) { + std::cerr << "[AudioChunks] Error: Failed to load WAV file, " << filepath_ << std::endl; + return false; + } + + wav_data_ = all_wav[0]; // Use the first WAV file + + CalculateChunkSizeBytes(); + SplitIntoChunks(); + + initialized_ = true; + + return initialized_; +} + +// Getter implementations +std::string nvidia::riva::realtime::AudioChunks::GetFilepath() const { + return filepath_; +} + +size_t nvidia::riva::realtime::AudioChunks::GetChunkSizeMs() const { + return chunk_size_ms_; +} + +size_t nvidia::riva::realtime::AudioChunks::GetChunkSizeBytes() const { + return chunk_size_bytes_; +} + +bool nvidia::riva::realtime::AudioChunks::IsInitialized() const { + return initialized_; +} + +// WAV file properties +int nvidia::riva::realtime::AudioChunks::GetSampleRateHz() const { + return wav_data_->sample_rate; +} + +int nvidia::riva::realtime::AudioChunks::GetNumChannels() const { + return wav_data_->channels; +} + +int nvidia::riva::realtime::AudioChunks::GetBitDepth() const { + // Calculate bit depth from data size and sample rate + if (wav_data_->channels > 0 && wav_data_->sample_rate > 0) { + return (wav_data_->data.size() * 8) / (wav_data_->channels * wav_data_->sample_rate); + } + return 16; // Default to 16-bit +} + +double 
nvidia::riva::realtime::AudioChunks::GetDurationSeconds() const { + if (wav_data_->sample_rate > 0 && wav_data_->channels > 0) { + return static_cast(wav_data_->data.size()) / (wav_data_->sample_rate * wav_data_->channels * 2); // Assuming 16-bit + } + return 0.0; +} + +int nvidia::riva::realtime::AudioChunks::GetNumSamples() const { + if (wav_data_->channels > 0) { + return wav_data_->data.size() / (wav_data_->channels * 2); // Assuming 16-bit + } + return 0; +} diff --git a/riva/clients/realtime/audio_chunks.h b/riva/clients/realtime/audio_chunks.h new file mode 100644 index 0000000..e8f7647 --- /dev/null +++ b/riva/clients/realtime/audio_chunks.h @@ -0,0 +1,84 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: MIT + */ + +#ifndef AUDIO_CHUNKS_H +#define AUDIO_CHUNKS_H + +#include +#include +#include +#include +#include +#include "riva/utils/wav/wav_reader.h" +#include "riva/utils/wav/wav_data.h" + +namespace fs = std::filesystem; + +namespace nvidia::riva::realtime { + class AudioChunks { + private: + bool initialized_ = false; + std::string filepath_; + size_t chunk_size_ms_; + size_t chunk_size_bytes_; + std::shared_ptr wav_data_; + std::vector chunk_base64s_; + + void CalculateChunkSizeBytes(); + void SplitIntoChunks(); + std::string EncodeBase64(const std::vector& data); + + public: + AudioChunks(const std::string& filepath, const int& chunk_size_ms); + ~AudioChunks() = default; + + bool Init(); + + // Getters + std::string GetFilepath() const; + size_t GetChunkSizeMs() const; + size_t GetChunkSizeBytes() const; + bool IsInitialized() const; + + // WAV file properties + int GetSampleRateHz() const; + int GetNumChannels() const; + int GetBitDepth() const; + double GetDurationSeconds() const; + int GetNumSamples() const; + const std::vector& GetChunkBase64s() const; + + // Iterator support + using iterator = std::vector::iterator; + using const_iterator = 
std::vector::const_iterator; + using reverse_iterator = std::vector::reverse_iterator; + using const_reverse_iterator = std::vector::const_reverse_iterator; + + // Iterator methods + iterator begin() { return chunk_base64s_.begin(); } + const_iterator begin() const { return chunk_base64s_.begin(); } + iterator end() { return chunk_base64s_.end(); } + const_iterator end() const { return chunk_base64s_.end(); } + + // Reverse iterator methods + reverse_iterator rbegin() { return chunk_base64s_.rbegin(); } + const_reverse_iterator rbegin() const { return chunk_base64s_.rbegin(); } + reverse_iterator rend() { return chunk_base64s_.rend(); } + const_reverse_iterator rend() const { return chunk_base64s_.rend(); } + + // Const iterator methods + const_iterator cbegin() const { return chunk_base64s_.cbegin(); } + const_iterator cend() const { return chunk_base64s_.cend(); } + const_reverse_iterator crbegin() const { return chunk_base64s_.crbegin(); } + const_reverse_iterator crend() const { return chunk_base64s_.crend(); } + + // Size methods + size_t size() const { return chunk_base64s_.size(); } + bool empty() const { return chunk_base64s_.empty(); } + }; + +} // namespace nvidia::riva::realtime + +#endif // AUDIO_CHUNKS_H \ No newline at end of file diff --git a/riva/clients/realtime/base_client.cpp b/riva/clients/realtime/base_client.cpp new file mode 100644 index 0000000..ea9928a --- /dev/null +++ b/riva/clients/realtime/base_client.cpp @@ -0,0 +1,191 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: MIT + */ + +#include "base_client.h" +#include +#include + + +nvidia::riva::realtime::WebSocketClientBase::WebSocketClientBase(const std::string& uri) + : connected_(false), + connectionClosedByServer_(false), + connectionTimeoutMs_(std::size_t(5000)), + uri_(uri) { + + // Set up logging - suppress verbose internal messages + wsClient_.set_access_channels(websocketpp::log::alevel::connect); + wsClient_.set_access_channels(websocketpp::log::alevel::disconnect); + wsClient_.set_access_channels(websocketpp::log::alevel::fail); + wsClient_.set_access_channels(websocketpp::log::alevel::app); + + // Initialize ASIO + wsClient_.init_asio(); + + // Set up handlers + wsClient_.set_open_handler( + std::bind(&nvidia::riva::realtime::WebSocketClientBase::OnOpen, this, std::placeholders::_1)); + wsClient_.set_close_handler( + std::bind(&nvidia::riva::realtime::WebSocketClientBase::OnClose, this, std::placeholders::_1)); + wsClient_.set_fail_handler( + std::bind(&nvidia::riva::realtime::WebSocketClientBase::OnFail, this, std::placeholders::_1)); + wsClient_.set_message_handler( + std::bind(&nvidia::riva::realtime::WebSocketClientBase::OnMessage, this, std::placeholders::_1, std::placeholders::_2)); +} + +void nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(const std::size_t connectionTimeoutMs) { + connectionTimeoutMs_ = connectionTimeoutMs; +} + +std::size_t nvidia::riva::realtime::WebSocketClientBase::GetConnectionTimeout() { + return connectionTimeoutMs_; +} + +void nvidia::riva::realtime::WebSocketClientBase::SetVerboseLogging(bool verbose) { + if (verbose) { + // Enable all logging channels + wsClient_.set_access_channels(websocketpp::log::alevel::all); + wsClient_.clear_access_channels(websocketpp::log::alevel::frame_payload); + } else { + // Minimal logging - only important events + wsClient_.clear_access_channels(websocketpp::log::alevel::all); + wsClient_.set_access_channels(websocketpp::log::alevel::connect); + 
wsClient_.set_access_channels(websocketpp::log::alevel::disconnect); + wsClient_.set_access_channels(websocketpp::log::alevel::fail); + wsClient_.set_access_channels(websocketpp::log::alevel::app); + } +} + +void nvidia::riva::realtime::WebSocketClientBase::Connect(const std::string& uri) { + uri_ = uri; + websocketpp::lib::error_code ec; + + websocketpp_client::connection_ptr con = wsClient_.get_connection(uri, ec); + if (ec) { + std::cerr << "Could not create connection: " << ec.message() << std::endl; + return; + } + + wsClient_.connect(con); +} + +void nvidia::riva::realtime::WebSocketClientBase::Run() { + wsClient_.run(); +} + +void nvidia::riva::realtime::WebSocketClientBase::Send(const std::string& message) { + std::lock_guard lock(mutex_); + if (connected_) { + websocketpp::lib::error_code ec; + wsClient_.send(connectionHdl_, message, websocketpp::frame::opcode::text, ec); + if (ec) { + std::cerr << "Send failed: " << ec.message() << std::endl; + } + } +} + +void nvidia::riva::realtime::WebSocketClientBase::Close() { + std::lock_guard lock(mutex_); + if (connected_) { + websocketpp::lib::error_code ec; + wsClient_.close(connectionHdl_, websocketpp::close::status::normal, "Client closing", ec); + } +} + +void nvidia::riva::realtime::WebSocketClientBase::SendJsonMessage(const std::string& type, const std::string& data) { + std::lock_guard lock(mutex_); + if (connected_) { + rapidjson::Document doc; + doc.SetObject(); + rapidjson::Document::AllocatorType& allocator = doc.GetAllocator(); + + doc.AddMember("type", rapidjson::Value(type.c_str(), allocator), allocator); + if (!data.empty()) { + doc.AddMember("data", rapidjson::Value(data.c_str(), allocator), allocator); + } + + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc.Accept(writer); + + websocketpp::lib::error_code ec; + wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec); + if (ec) { + std::cerr << "Send failed: " << ec.message() << std::endl; 
+        } else {
+            std::cout << "Sent: " << buffer.GetString() << std::endl;
+        }
+    }
+}
+
+/**
+ * Open handler: stores the connection handle and marks the client connected.
+ *
+ * @param hdl Handle of the connection that was just opened.
+ */
+void nvidia::riva::realtime::WebSocketClientBase::OnOpen(websocketpp::connection_hdl hdl) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    connectionHdl_ = hdl;
+
+    // connected_ is the predicate WaitForConnection() evaluates while holding
+    // connectionMutex_, so it must be written under that same mutex.
+    {
+        std::lock_guard<std::mutex> conn_lock(connectionMutex_);
+        connected_ = true;
+    }
+    // All Wait* methods share connectionCv_ with different predicates;
+    // notify_all() guarantees the waiter whose predicate changed is woken.
+    connectionCv_.notify_all();
+
+    std::cout << "Connected to " << uri_ << std::endl;
+}
+
+/**
+ * Close handler: marks the client disconnected and wakes every waiter.
+ *
+ * @param hdl Unused.
+ */
+void nvidia::riva::realtime::WebSocketClientBase::OnClose(websocketpp::connection_hdl hdl) {
+    (void)hdl;  // Suppress unused parameter warning
+    std::lock_guard<std::mutex> lock(mutex_);
+
+    // The flag is set unconditionally: this handler does not check which side
+    // initiated the close.
+    {
+        std::lock_guard<std::mutex> conn_lock(connectionMutex_);
+        connected_ = false;
+        connectionClosedByServer_ = true;
+    }
+    connectionCv_.notify_all();
+
+    std::cout << "Connection closed" << std::endl;
+}
+
+/**
+ * Failure handler: treated like a close, i.e. the client is marked
+ * disconnected, the closed-by-server flag is raised and waiters are woken.
+ *
+ * @param hdl Unused.
+ */
+void nvidia::riva::realtime::WebSocketClientBase::OnFail(websocketpp::connection_hdl hdl) {
+    (void)hdl;  // Suppress unused parameter warning
+    std::lock_guard<std::mutex> lock(mutex_);
+
+    {
+        std::lock_guard<std::mutex> conn_lock(connectionMutex_);
+        connected_ = false;
+        connectionClosedByServer_ = true;
+    }
+    connectionCv_.notify_all();
+
+    std::cout << "************************ Connection failed" << std::endl;
+}
+
+/**
+ * Message handler: forwards the payload to the derived class.
+ *
+ * @param hdl Unused.
+ * @param msg Received message; only its payload is used.
+ */
+void nvidia::riva::realtime::WebSocketClientBase::OnMessage(websocketpp::connection_hdl hdl, message_ptr msg) {
+    (void)hdl;  // Suppress unused parameter warning
+    HandleMessage(msg->get_payload());
+}
+
+/**
+ * Blocks until the connection is open or connectionTimeoutMs_ elapses.
+ *
+ * @return true if connected_ became true, false on timeout.
+ */
+bool nvidia::riva::realtime::WebSocketClientBase::WaitForConnection() {
+    std::unique_lock<std::mutex> lock(connectionMutex_);
+    return connectionCv_.wait_for(lock,
+                                  std::chrono::milliseconds(connectionTimeoutMs_),
+                                  [this] { return connected_; });
+}
+
+/**
+ * Blocks until the connection is no longer open or connectionTimeoutMs_ elapses.
+ *
+ * @return true if connected_ became false, false on timeout.
+ */
+bool nvidia::riva::realtime::WebSocketClientBase::WaitForDisconnection() {
+    std::unique_lock<std::mutex> lock(connectionMutex_);
+    return connectionCv_.wait_for(lock,
+                                  std::chrono::milliseconds(connectionTimeoutMs_),
+                                  [this] { return !connected_; });
+}
+
+/**
+ * Blocks until the closed-by-server flag is raised or connectionTimeoutMs_ elapses.
+ *
+ * @return true if connectionClosedByServer_ became true, false on timeout.
+ */
+bool nvidia::riva::realtime::WebSocketClientBase::WaitForServerClose() {
+    std::unique_lock<std::mutex> lock(connectionMutex_);
+    return connectionCv_.wait_for(lock,
+                                  std::chrono::milliseconds(connectionTimeoutMs_),
+                                  [this] { return connectionClosedByServer_; });
+}
+
diff --git a/riva/clients/realtime/base_client.h b/riva/clients/realtime/base_client.h
new file mode 100644
index 0000000..400f319
--- /dev/null
+++ b/riva/clients/realtime/base_client.h
@@ -0,0 +1,86 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#ifndef REALTIME_CLIENT_H
+#define REALTIME_CLIENT_H
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "audio_chunks.h"
+
+namespace nvidia::riva::realtime {
+    /**
+     * Abstract base for websocketpp-based clients.
+     *
+     * Owns the websocketpp client and connection handle, tracks connection
+     * state, and lets callers block on state changes through the WaitFor*
+     * methods. Derived classes implement HandleMessage() to consume payloads.
+     */
+    class WebSocketClientBase {
+    protected:
+        typedef websocketpp::client<websocketpp::config::asio_client> websocketpp_client;
+        typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
+
+        websocketpp_client wsClient_;
+        websocketpp::connection_hdl connectionHdl_;
+
+        std::string uri_;
+        bool connected_;
+        std::mutex mutex_;
+
+        // Connection state. connectionCv_ is waited on with connectionMutex_
+        // by all three WaitFor* methods.
+        bool connectionClosedByServer_;
+        std::condition_variable connectionCv_;
+        std::mutex connectionMutex_;
+        std::size_t connectionTimeoutMs_;
+
+        // Protected access to websocket client for derived classes
+        websocketpp_client& GetWsClient() { return wsClient_; }
+        websocketpp::connection_hdl& GetConnection() { return connectionHdl_; }
+        std::mutex& GetConnectionMutex() { return connectionMutex_; }
+
+    public:
+        explicit WebSocketClientBase(const std::string& uri);
+        // Virtual because the class is a polymorphic base (HandleMessage is
+        // pure virtual): deleting a derived client through a base pointer
+        // must run the derived destructor.
+        virtual ~WebSocketClientBase() = default;
+
+        // Connection timeout (milliseconds) used by the WaitFor* methods
+        void SetConnectionTimeout(const std::size_t connectionTimeoutMs);
+        std::size_t GetConnectionTimeout();
+
+        // Connection status. These read the flags without taking a lock.
+        bool IsConnected() const { return connected_; }
+        bool IsConnectionClosedByServer() const { return connectionClosedByServer_; }
+        bool IsConnectionOpen() const { return connected_ && !connectionClosedByServer_; }
+        bool IsConnectionClosed() const { return !connected_ || connectionClosedByServer_; }
+
+        // Control logging verbosity
+        void SetVerboseLogging(bool verbose);
+
+        // Connection management
+        void Connect(const std::string& uri);
+        void Run();
+        void Send(const std::string& message);
+        void Close();
+        void SendJsonMessage(const std::string& type, const std::string& data = "");
+
+        // Connection waiting methods; each returns false on timeout
+        bool WaitForConnection();
+        bool WaitForDisconnection();
+        bool WaitForServerClose();
+
+        // Event handlers
+        void OnOpen(websocketpp::connection_hdl hdl);
+        void OnClose(websocketpp::connection_hdl hdl);
+        void OnFail(websocketpp::connection_hdl hdl);
+        void OnMessage(websocketpp::connection_hdl hdl, message_ptr msg);
+        // Called by OnMessage() with the payload of every received message
+        virtual void HandleMessage(const std::string& message) = 0;
+    };
+}  // namespace nvidia::riva::realtime
+#endif // REALTIME_CLIENT_H
\ No newline at end of file
diff --git a/riva/clients/realtime/recognition_client.cpp b/riva/clients/realtime/recognition_client.cpp
new file mode 100644
index 0000000..a0ae91d
--- /dev/null
+++ b/riva/clients/realtime/recognition_client.cpp
@@ -0,0 +1,401 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "recognition_client.h"
+#include "base_client.h"
+#include
+#include
+#include
+#include
+
+
+namespace {
+
+/**
+ * Reads a string member from a JSON object.
+ *
+ * @param obj      JSON value already known to be an object.
+ * @param key      Member name to look up.
+ * @param fallback Returned when the member is missing or not a string.
+ * @return The member's string value, or `fallback`.
+ */
+std::string GetStringMember(const rapidjson::Value& obj, const char* key, const char* fallback) {
+    const auto member = obj.FindMember(key);
+    if (member == obj.MemberEnd() || !member->value.IsString()) {
+        return fallback;
+    }
+    return member->value.GetString();
+}
+
+/**
+ * Serializes a client event of the form {"type": eventType[, "audio": ...]}.
+ *
+ * @param eventType   Value of the "type" member.
+ * @param audioBase64 Optional value of the "audio" member; omitted when null.
+ * @return The JSON text of the event.
+ */
+std::string BuildAudioBufferEvent(const char* eventType, const std::string* audioBase64 = nullptr) {
+    rapidjson::Document doc;
+    doc.SetObject();
+    rapidjson::Document::AllocatorType& allocator = doc.GetAllocator();
+
+    doc.AddMember("type", rapidjson::Value(eventType, allocator), allocator);
+    if (audioBase64 != nullptr) {
+        doc.AddMember("audio", rapidjson::Value(audioBase64->c_str(), allocator), allocator);
+    }
+
+    rapidjson::StringBuffer buffer;
+    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+    doc.Accept(writer);
+    return std::string(buffer.GetString(), buffer.GetSize());
+}
+
+}  // namespace
+
+/**
+ * Creates a client with 10 s default timeouts and a 1 s chunk delay.
+ *
+ * @param objectName     Label prepended to messages written by Log().
+ * @param audioChunksPtr Audio to stream; also supplies sample rate and channels.
+ * @param perfCounter    Stored by reference; must outlive this client.
+ */
+nvidia::riva::realtime::RecognitionClient::RecognitionClient(
+    const std::string& objectName,
+    const std::shared_ptr audioChunksPtr,
+    nvidia::riva::utils::PerformanceStats& perfCounter)
+    : WebSocketClientBase("ws://127.0.0.1:9090/v1/realtime?intent=transcription"),
+      sessionInitialized_(false),
+      sessionUpdated_(false),
+      transcriptionCompleted_(false),
+      finalTranscriptionCount_(0),
+      connectionTimeoutInMs_(std::size_t(10000)),
+      sessionInitTimeoutInMs_(std::size_t(10000)),
+      sessionUpdateTimeoutInMs_(std::size_t(10000)),
+      transcriptionTimeoutInMs_(std::size_t(10000)),
+      chunkDelayTimeInMs_(std::size_t(1000)),
+      objectName_(objectName),
+      audioChunksPtr_(audioChunksPtr),
+      perfCounter_(perfCounter) {
+
+    nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(connectionTimeoutInMs_);
+}
+
+
+/**
+ * Replaces all timing parameters (milliseconds) and forwards the connection
+ * timeout to the base class.
+ */
+void nvidia::riva::realtime::RecognitionClient::SetTimingConfig(  const std::size_t connectionTimeoutInMs,
+                                                                  const std::size_t sessionInitTimeoutInMs,
+                                                                  const std::size_t sessionUpdateTimeoutInMs,
+                                                                  const std::size_t transcriptionTimeoutInMs,
+                                                                  const std::size_t chunkDelayTimeInMs) {
+    connectionTimeoutInMs_ = connectionTimeoutInMs;
+    sessionInitTimeoutInMs_ = sessionInitTimeoutInMs;
+    sessionUpdateTimeoutInMs_ = sessionUpdateTimeoutInMs;
+    transcriptionTimeoutInMs_ = transcriptionTimeoutInMs;
+    chunkDelayTimeInMs_ = chunkDelayTimeInMs;
+    nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(connectionTimeoutInMs_);
+}
+
+// Writes "[objectName]message" to stdout.
+void nvidia::riva::realtime::RecognitionClient::Log(const std::string& message) {
+    std::cout << "[" << objectName_ << "]" << message << std::endl;
+}
+
+/**
+ * Waits for the final transcription result, then closes the connection.
+ *
+ * @return true if the final result arrived within transcriptionTimeoutInMs_,
+ *         false on timeout (the connection is then left untouched).
+ */
+bool nvidia::riva::realtime::RecognitionClient::WaitForTranscriptionCompletion() {
+    bool completed = false;
+    {
+        std::unique_lock<std::mutex> lock(transcriptionMutex_);
+
+        // The flag is deliberately not cleared before waiting: HandleMessage()
+        // may already have seen the final result, and clearing here would
+        // discard that event and force a spurious timeout.
+        completed = transcriptionCv_.wait_for(lock,
+                                              std::chrono::milliseconds(transcriptionTimeoutInMs_),
+                                              [this] { return transcriptionCompleted_; });
+        if (completed) {
+            // Consume the event so a later wait needs a fresh completion
+            transcriptionCompleted_ = false;
+        }
+    }
+
+    if (!completed) {
+        Log(" Timeout waiting for transcription completion after " + std::to_string(transcriptionTimeoutInMs_) + " milliseconds");
+        return false;
+    }
+
+    // Close only after transcriptionMutex_ is released, so the message
+    // handler is never blocked on it while the connection shuts down.
+    Close();
+    return true;
+}
+
+/**
+ * Waits until the server acknowledged the session update.
+ *
+ * @return true if the acknowledgement was (or had already been) received,
+ *         false after sessionUpdateTimeoutInMs_ without one.
+ */
+bool nvidia::riva::realtime::RecognitionClient::WaitForSessionUpdate() {
+    std::unique_lock<std::mutex> lock(sessionMutex_);
+
+    if (sessionUpdated_) {
+        return true;
+    }
+
+    const bool updated = sessionCv_.wait_for(
+        lock,
+        std::chrono::milliseconds(sessionUpdateTimeoutInMs_),
+        [this] { return sessionUpdated_; }
+    );
+
+    if (!updated) {
+        Log("Timeout waiting for session update after " + std::to_string(sessionUpdateTimeoutInMs_) + " milliseconds");
+    }
+
+    return updated;
+}
+
+/**
+ * Sends an input_audio_buffer.append event carrying one base64 audio chunk.
+ * Skipped when the connection is not open; a send error marks the connection
+ * as closed.
+ */
+void nvidia::riva::realtime::RecognitionClient::SendAudioAppend(const std::string& audioBase64)
+{
+    std::lock_guard<std::mutex> lock(connectionMutex_);
+    if (!IsConnectionOpen()) {
+        Log("Skipping audio append - connection closed");
+        return;
+    }
+
+    const std::string payload = BuildAudioBufferEvent("input_audio_buffer.append", &audioBase64);
+
+    websocketpp::lib::error_code ec;
+    wsClient_.send(connectionHdl_, payload, websocketpp::frame::opcode::text, ec);
+    if (ec) {
+        Log("Audio append failed: " + ec.message());
+        // connectionMutex_ is already held by `lock`. std::mutex is not
+        // recursive, so it must not be locked a second time here.
+        connectionClosedByServer_ = true;
+    }
+}
+
+/**
+ * Sends an input_audio_buffer.commit event.
+ * Skipped when the connection is not open; a send error marks the connection
+ * as closed.
+ */
+void nvidia::riva::realtime::RecognitionClient::SendAudioCommit() {
+    std::lock_guard<std::mutex> lock(connectionMutex_);
+    if (!IsConnectionOpen()) {
+        Log("Skipping audio commit - connection closed");
+        return;
+    }
+
+    const std::string payload = BuildAudioBufferEvent("input_audio_buffer.commit");
+
+    websocketpp::lib::error_code ec;
+    wsClient_.send(connectionHdl_, payload, websocketpp::frame::opcode::text, ec);
+    if (ec) {
+        Log("Audio commit failed: " + ec.message());
+        // See SendAudioAppend: the mutex is already owned by this thread
+        connectionClosedByServer_ = true;
+    }
+}
+
+/**
+ * Sends an input_audio_buffer.done event that ends the audio stream.
+ * Skipped when the connection is not open; a send error marks the connection
+ * as closed.
+ */
+void nvidia::riva::realtime::RecognitionClient::SendAudioDone() {
+    std::lock_guard<std::mutex> lock(connectionMutex_);
+    if (!IsConnectionOpen()) {
+        Log("Skipping audio done - connection closed");
+        return;
+    }
+
+    const std::string payload = BuildAudioBufferEvent("input_audio_buffer.done");
+
+    websocketpp::lib::error_code ec;
+    wsClient_.send(connectionHdl_, payload, websocketpp::frame::opcode::text, ec);
+    if (ec) {
+        Log("Audio done failed: " + ec.message());
+        // See SendAudioAppend: the mutex is already owned by this thread
+        connectionClosedByServer_ = true;
+    } else {
+        Log("Audio streaming completed");
+    }
+}
+
+/**
+ * Pauses for a fixed 3 s, verifies the connection is still up, then sends the
+ * session configuration.
+ *
+ * @return false if the connection was lost or UpdateSessionConfig() failed.
+ */
+bool nvidia::riva::realtime::RecognitionClient::InitializeSession() {
+    std::cout << "[" << objectName_ << "]" << " Initializing session..." << std::endl;
+
+    // Fixed pause for the initial connection and session creation
+    std::this_thread::sleep_for(std::chrono::milliseconds(3000));
+
+    if (IsConnectionClosed()) {
+        std::cerr << "Connection lost during session initialization" << std::endl;
+        return false;
+    }
+
+    return UpdateSessionConfig();
+}
+
+/**
+ * Sends a transcription_session.update request built from the audio
+ * parameters and waits for the server's acknowledgement.
+ *
+ * @return true only if the request was sent and acknowledged in time; false
+ *         when there is no audio source, the connection is not open, the send
+ *         failed, or the acknowledgement timed out.
+ */
+bool nvidia::riva::realtime::RecognitionClient::UpdateSessionConfig() {
+    if (audioChunksPtr_ == nullptr) {
+        std::cerr << "Audio chunks pointer is null; cannot build session configuration." << std::endl;
+        return false;
+    }
+
+    const int sampleRateHz = audioChunksPtr_->GetSampleRateHz();
+    const int numChannels = audioChunksPtr_->GetNumChannels();
+
+    std::cout << "Updating session configuration..." << std::endl;
+    std::cout << "Using WAV file parameters - Sample rate: " << sampleRateHz
+              << " Hz, Channels: " << numChannels << std::endl;
+
+    // The document is only needed for its allocator
+    rapidjson::Document doc;
+    doc.SetObject();
+    rapidjson::Document::AllocatorType& allocator = doc.GetAllocator();
+
+    rapidjson::Value session_config(rapidjson::kObjectType);
+
+    // Input audio transcription config
+    rapidjson::Value transcription_config(rapidjson::kObjectType);
+    transcription_config.AddMember("language", "en-US", allocator);
+    transcription_config.AddMember("model", "parakeet-1.1b-en-US-asr-streaming-silero-vad-asr-bls-ensemble", allocator);
+    transcription_config.AddMember("prompt", "", allocator);
+    session_config.AddMember("input_audio_transcription", transcription_config, allocator);
+
+    // Input audio params - use actual WAV file parameters
+    rapidjson::Value audio_params(rapidjson::kObjectType);
+    audio_params.AddMember("sample_rate_hz", sampleRateHz, allocator);
+    audio_params.AddMember("num_channels", numChannels, allocator);
+    session_config.AddMember("input_audio_params", audio_params, allocator);
+
+    // Recognition config
+    rapidjson::Value recognition_config(rapidjson::kObjectType);
+    recognition_config.AddMember("max_alternatives", 1, allocator);
+    recognition_config.AddMember("enable_automatic_punctuation", false, allocator);
+    recognition_config.AddMember("enable_word_time_offsets", false, allocator);
+    recognition_config.AddMember("enable_profanity_filter", false, allocator);
+    recognition_config.AddMember("enable_verbatim_transcripts", false, allocator);
+    session_config.AddMember("recognition_config", recognition_config, allocator);
+
+    // Create update request
+    rapidjson::Value update_request(rapidjson::kObjectType);
+    update_request.AddMember("type", "transcription_session.update", allocator);
+    update_request.AddMember("session", session_config, allocator);
+
+    rapidjson::StringBuffer buffer;
+    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+    update_request.Accept(writer);
+
+    {
+        // The connection state is checked under the same lock that guards the
+        // send, and an unsent request is reported as a failure.
+        std::lock_guard<std::mutex> lock(connectionMutex_);
+        if (!IsConnectionOpen()) {
+            std::cout << "Session update failed: connection closed" << std::endl;
+            return false;
+        }
+
+        websocketpp::lib::error_code ec;
+        wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec);
+        if (ec) {
+            std::cout << "Session update failed: " << ec.message() << std::endl;
+            return false;
+        }
+        std::cout << "Session update request sent" << std::endl;
+    }
+
+    // Success means the server acknowledged the update, not merely that the
+    // request left this process.
+    return WaitForSessionUpdate();
+}
+
+/**
+ * Streams every audio chunk (append + commit per chunk), then signals the end
+ * of the stream.
+ *
+ * @param simulateRealtime When true, chunk i+1 is not sent before
+ *        (i + 1) * GetChunkSizeMs() milliseconds have passed since the start;
+ *        when false, chunks are sent back to back.
+ */
+void nvidia::riva::realtime::RecognitionClient::SendAudioChunks(const bool simulateRealtime) {
+    if (audioChunksPtr_ == nullptr) {
+        std::cerr << "Audio chunks pointer is null. Please call InitializeSession first." << std::endl;
+        return;
+    }
+
+    if (!IsSessionInitialized()) {
+        std::cerr << "Session is not initialized. Please call InitializeSession first." << std::endl;
+        return;
+    }
+
+    if (audioChunksPtr_->size() == 0) {
+        std::cerr << "No audio chunks to send. Please add audio chunks to the audio chunks pointer." << std::endl;
+        return;
+    }
+
+    std::cout << "Sending audio chunks with " << (simulateRealtime ? "real-time" : "burst") << " timing..." << std::endl;
+
+    // Deadlines are measured from the stream start so delays do not accumulate
+    const auto stream_start_time = std::chrono::steady_clock::now();
+    size_t chunk_index = 0;
+
+    for (const std::string& chunk_base64 : *audioChunksPtr_) {
+        SendAudioAppend(chunk_base64);
+        SendAudioCommit();
+
+        if (simulateRealtime) {
+            // Point in time at which the next chunk is due
+            const auto chunk_duration_ms = audioChunksPtr_->GetChunkSizeMs();
+            const auto expected_send_time = stream_start_time +
+                std::chrono::milliseconds((chunk_index + 1) * chunk_duration_ms);
+
+            const auto time_to_wait = expected_send_time - std::chrono::steady_clock::now();
+            if (time_to_wait > std::chrono::milliseconds(0)) {
+                std::this_thread::sleep_for(time_to_wait);
+            }
+        }
+
+        chunk_index++;
+    }
+    SendAudioDone();
+}
+
+/**
+ * Dispatches one server message by its "type" member.
+ *
+ * Messages that are not JSON objects are reported and dropped; members of an
+ * unexpected JSON type are treated as absent.
+ */
+void nvidia::riva::realtime::RecognitionClient::HandleMessage(const std::string& message) {
+    rapidjson::Document doc;
+
+    if (doc.Parse(message.c_str()).HasParseError()) {
+        std::cerr << "Failed to parse JSON message" << std::endl;
+        return;
+    }
+    if (!doc.IsObject()) {
+        // Member lookups below are only valid on objects
+        std::cerr << "Ignoring JSON message that is not an object" << std::endl;
+        return;
+    }
+
+    const std::string eventType = GetStringMember(doc, "type", "");
+
+    if (eventType == "conversation.created") {
+        std::cout << "Conversation created" << std::endl;
+    }
+    else if (eventType == "transcription_session.updated") {
+        std::cout << "Session updated successfully" << std::endl;
+        // Both flags are published under the mutex WaitForSessionUpdate() uses
+        {
+            std::lock_guard<std::mutex> lock(sessionMutex_);
+            sessionInitialized_ = true;
+            sessionUpdated_ = true;
+        }
+        sessionCv_.notify_one();
+    }
+    else if (eventType == "conversation.item.input_audio_transcription.delta") {
+        // Partial results are intentionally not printed
+    }
+    else if (eventType == "conversation.item.input_audio_transcription.completed") {
+        finalTranscriptionCount_++;
+        const std::string transcript = GetStringMember(doc, "transcript", "");
+        const auto lastResult = doc.FindMember("is_last_result");
+        const bool is_last_result = lastResult != doc.MemberEnd() &&
+                                    lastResult->value.IsBool() &&
+                                    lastResult->value.GetBool();
+
+        if (is_last_result) {
+            std::cout << "--------------------------------" << std::endl;
+            std::cout << "Final transcript: " << transcript << std::endl;
+            std::cout << "Final transcription count: " << finalTranscriptionCount_ << std::endl;
+            std::cout << "--------------------------------" << std::endl;
+
+            // Signal WaitForTranscriptionCompletion()
+            {
+                std::lock_guard<std::mutex> lock(transcriptionMutex_);
+                transcriptionCompleted_ = true;
+            }
+            transcriptionCv_.notify_one();
+        }
+        else {
+            std::cout << "Interim transcript: " << transcript << std::endl;
+        }
+    }
+    else if (eventType.find("error") != std::string::npos) {
+        std::string errorMsg = "Unknown error";
+        const auto error = doc.FindMember("error");
+        if (error != doc.MemberEnd() && error->value.IsObject()) {
+            errorMsg = GetStringMember(error->value, "message", "Unknown error");
+        }
+        std::cerr << "Error: " << errorMsg << std::endl;
+    }
+    // Any other message type is ignored
+}
+    
\ No newline at end of file
diff --git a/riva/clients/realtime/recognition_client.h b/riva/clients/realtime/recognition_client.h
new file mode 100644
index 0000000..5783647
--- /dev/null
+++ b/riva/clients/realtime/recognition_client.h
@@ -0,0 +1,98 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#ifndef RECOGNITION_CLIENT_H
+#define RECOGNITION_CLIENT_H
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "audio_chunks.h"
+#include "base_client.h"
+#include "riva/utils/stats_builder/stats_builder.h"
+
+namespace nvidia::riva::realtime {
+    /**
+     * WebSocket client for a realtime transcription session.
+     *
+     * Typical call order, as implied by the checks inside the methods:
+     * InitializeSession(), SendAudioChunks(), WaitForTranscriptionCompletion().
+     *
+     * Data members are declared in the order the constructor initializes
+     * them, since members are always initialized in declaration order.
+     */
+    class RecognitionClient : public WebSocketClientBase {
+    private:
+
+        // Session tracking; the flags are signalled through sessionCv_
+        bool sessionInitialized_;
+        bool sessionUpdated_;
+        std::condition_variable sessionCv_;
+        std::mutex sessionMutex_;
+
+        // Event tracking; the flag is signalled through transcriptionCv_
+        bool transcriptionCompleted_;
+        std::condition_variable transcriptionCv_;
+        std::mutex transcriptionMutex_;
+
+        // Number of "completed" transcription events received so far
+        std::size_t finalTranscriptionCount_;
+
+        // Configurable timing parameters (in milliseconds)
+        std::size_t connectionTimeoutInMs_;
+        std::size_t sessionInitTimeoutInMs_;
+        std::size_t sessionUpdateTimeoutInMs_;
+        std::size_t transcriptionTimeoutInMs_;
+        std::size_t chunkDelayTimeInMs_;
+
+        // Label prepended to messages written by Log()
+        std::string objectName_;
+
+        // Audio processing
+        std::shared_ptr audioChunksPtr_;
+
+        // Held by reference: the referenced object must outlive this client
+        nvidia::riva::utils::PerformanceStats& perfCounter_;
+
+        // Audio streaming methods
+        void SendAudioAppend(const std::string& audioBase64);
+        void SendAudioCommit();
+        void SendAudioDone();
+
+        // Override base class methods
+        void HandleMessage(const std::string& message) override;
+
+    public:
+        RecognitionClient(  const std::string& objectName,
+                            const std::shared_ptr audioChunksPtr,
+                            nvidia::riva::utils::PerformanceStats& perfCounter);
+        ~RecognitionClient() = default;
+
+        // Writes "[objectName]message" to stdout
+        void Log(const std::string& message);
+
+        // Timing configuration; all values are in milliseconds
+        void SetTimingConfig(   const std::size_t connectionTimeoutInMs,
+                                const std::size_t sessionInitTimeoutInMs,
+                                const std::size_t sessionUpdateTimeoutInMs,
+                                const std::size_t transcriptionTimeoutInMs,
+                                const std::size_t chunkDelayTimeInMs);
+
+        // Session management methods
+        bool InitializeSession();
+        bool UpdateSessionConfig();
+
+        // Reads the flag without taking sessionMutex_
+        bool IsSessionInitialized() const { return sessionInitialized_; }
+
+        // Wait methods; each returns false on timeout
+        bool WaitForSessionUpdate();
+        bool WaitForTranscriptionCompletion();
+
+        // Streams all audio chunks; see the definition for the timing modes
+        void SendAudioChunks(const bool simulateRealtime = false);
+    };
+
+}  // namespace nvidia::riva::realtime
+
+#endif // RECOGNITION_CLIENT_H
\ No newline at end of file
diff --git a/riva/utils/stats_builder/BUILD b/riva/utils/stats_builder/BUILD
new file mode 100644
index 0000000..5b550ee
--- /dev/null
+++ b/riva/utils/stats_builder/BUILD
@@ -0,0 +1,7 @@
+cc_library(
+    name = "stats_builder_lib",
+    srcs = ["stats_builder.cpp"],
+    hdrs = ["stats_builder.h"],
+    includes = ["."],
+    visibility = ["//visibility:public"],
+)
\ No newline at end of file
diff --git a/riva/utils/stats_builder/stats_builder.cpp b/riva/utils/stats_builder/stats_builder.cpp
new file mode 100644
index 0000000..22794d1
--- /dev/null
+++ b/riva/utils/stats_builder/stats_builder.cpp
@@ -0,0 +1,284 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "stats_builder.h"
+#include
+#include
+#include
+#include
+
+namespace nvidia::riva::utils {
+
+namespace {
+
+// Copies the runtime (in ms) of every recorded iteration into a vector.
+std::vector<double> CollectRuntimesMs(const std::vector<PerformanceStats>& stats) {
+    std::vector<double> runtimes;
+    runtimes.reserve(stats.size());
+    for (const auto& entry : stats) {
+        runtimes.push_back(entry.GetRuntimeInMs());
+    }
+    return runtimes;
+}
+
+}  // namespace
+
+// Starts unsuccessful, with both timestamps set to "now" and no audio duration.
+PerformanceStats::PerformanceStats(const std::string& objectName)
+    : success_(false),
+      objectName_(objectName),
+      processing_start_time_(std::chrono::steady_clock::now()),
+      processing_end_time_(std::chrono::steady_clock::now()),
+      audio_duration_seconds_(0.0) {}
+
+/**
+ * Creates one PerformanceStats entry per iteration, named
+ * "<objectName>-<index>", each preset with the given audio duration.
+ */
+StatsBuilder::StatsBuilder(const std::string& objectName, double audio_duration_seconds, std::size_t num_iterations)
+    : audio_duration_seconds_(audio_duration_seconds), num_iterations_(num_iterations), object_name_(objectName) {
+    performanceStats_.reserve(num_iterations);
+
+    for (std::size_t i = 0; i < num_iterations; ++i) {
+        performanceStats_.emplace_back(objectName + "-" + std::to_string(i));
+        performanceStats_.back().SetAudioDurationInSeconds(audio_duration_seconds);
+    }
+}
+
+// Records the start timestamp.
+void PerformanceStats::StartProcessingTimer() {
+    processing_start_time_ = std::chrono::steady_clock::now();
+}
+
+// Records the end timestamp.
+void PerformanceStats::EndProcessingTimer() {
+    processing_end_time_ = std::chrono::steady_clock::now();
+}
+
+/**
+ * @return Time between the start and end timestamps in milliseconds.
+ *
+ * NOTE(review): the duration_cast target type is missing from this patch
+ * text; std::chrono::milliseconds is assumed from the variable name, which
+ * yields whole milliseconds. Confirm against the repository.
+ */
+double PerformanceStats::GetRuntimeInMs() const {
+    const auto durationInMs = std::chrono::duration_cast<std::chrono::milliseconds>(
+        processing_end_time_ - processing_start_time_);
+    return static_cast<double>(durationInMs.count());
+}
+
+// @return GetRuntimeInMs() converted to seconds.
+double PerformanceStats::GetRuntimeInSeconds() const {
+    return GetRuntimeInMs() / 1000.0;
+}
+
+void PerformanceStats::SetAudioDurationInSeconds(double audio_duration_seconds) {
+    audio_duration_seconds_ = audio_duration_seconds;
+}
+
+/**
+ * @return Audio seconds processed per second of runtime (RTFX), or 0.0 when
+ *         either the runtime or the audio duration is not positive.
+ */
+double PerformanceStats::GetThroughputRTFX() const {
+    const double runtimeInMs = GetRuntimeInMs();
+    if (runtimeInMs > 0.0 && audio_duration_seconds_ > 0.0) {
+        // RTFX = (Total Audio Processed in seconds) × 1000 ÷ (Total Runtime in milliseconds)
+        return (audio_duration_seconds_ * 1000.0) / runtimeInMs;
+    }
+    return 0.0;
+}
+
+void PerformanceStats::SetObjectName(const std::string& objectName) {
+    objectName_ = objectName;
+}
+
+std::string PerformanceStats::GetObjectName() const {
+    return objectName_;
+}
+
+// Prints name, success flag, audio duration, runtime and throughput to stdout.
+void PerformanceStats::ReportStats() {
+    std::cout << "Object Name: " << GetObjectName() << std::endl;
+    std::cout << "Success: " << IsSuccess() << std::endl;
+    std::cout << "Audio Duration: " << audio_duration_seconds_ << " seconds" << std::endl;
+    std::cout << "Total Runtime: " << GetRuntimeInMs() << " ms (" << GetRuntimeInSeconds() << " seconds)" << std::endl;
+    std::cout << "Throughput: " << GetThroughputRTFX() << " RTFX" << std::endl;
+}
+
+
+
+// Prints name, runtime and throughput of every iteration to stdout.
+void StatsBuilder::ReportCumulativeStats() {
+    std::cout << "Cumulative Stats" << std::endl;
+    std::cout << "=================" << std::endl;
+    // Iterate by const reference: each element holds a std::string
+    for (const auto& performanceStats : performanceStats_) {
+        std::cout << "Object Name: " << performanceStats.GetObjectName() << std::endl;
+        std::cout << "Total Runtime: " << performanceStats.GetRuntimeInMs() << " ms (" << performanceStats.GetRuntimeInSeconds() << " seconds)" << std::endl;
+        std::cout << "Throughput: " << performanceStats.GetThroughputRTFX() << " RTFX" << std::endl;
+    }
+}
+
+/**
+ * Percentile with linear interpolation between the two nearest ranks.
+ *
+ * @param values     Samples in any order; the input is not modified.
+ * @param percentile Requested percentile; values outside [0, 100] (and NaN)
+ *                   are clamped into that range.
+ * @return The interpolated percentile, or 0.0 for an empty input.
+ */
+double CalculatePercentile(const std::vector<double>& values, double percentile) {
+    if (values.empty()) return 0.0;
+
+    std::vector<double> sorted_values = values;
+    std::sort(sorted_values.begin(), sorted_values.end());
+
+    // Clamp so the rank below always lies in [0, size - 1]. The negated
+    // comparison also maps a NaN percentile to 0.
+    double clamped = percentile;
+    if (!(clamped > 0.0)) {
+        clamped = 0.0;
+    } else if (clamped > 100.0) {
+        clamped = 100.0;
+    }
+
+    const std::size_t last_index = sorted_values.size() - 1;
+    const double rank = (clamped / 100.0) * static_cast<double>(last_index);
+    const std::size_t lower_index = static_cast<std::size_t>(rank);
+
+    // No upper neighbour to interpolate with
+    if (lower_index >= last_index) {
+        return sorted_values[last_index];
+    }
+
+    const double weight = rank - static_cast<double>(lower_index);
+    return sorted_values[lower_index] * (1 - weight) + sorted_values[lower_index + 1] * weight;
+}
+
+// Statistical methods for runtime (all values in milliseconds)
+
+// @return Mean runtime, or 0.0 when there are no iterations.
+double StatsBuilder::GetAverageRuntime() const {
+    if (performanceStats_.empty()) return 0.0;
+
+    double sum = 0.0;
+    for (const auto& stats : performanceStats_) {
+        sum += stats.GetRuntimeInMs();
+    }
+    return sum / performanceStats_.size();
+}
+
+double StatsBuilder::GetP50Runtime() const {
+    return CalculatePercentile(CollectRuntimesMs(performanceStats_), 50.0);
+}
+
+double StatsBuilder::GetP90Runtime() const {
+    return CalculatePercentile(CollectRuntimesMs(performanceStats_), 90.0);
+}
+
+double StatsBuilder::GetP95Runtime() const {
+    return CalculatePercentile(CollectRuntimesMs(performanceStats_), 95.0);
+}
+
+double StatsBuilder::GetP99Runtime() const {
+    return CalculatePercentile(CollectRuntimesMs(performanceStats_), 99.0);
+}
+
+// @return Smallest runtime, or 0.0 when there are no iterations.
+double StatsBuilder::GetMinRuntime() const {
+    if (performanceStats_.empty()) return 0.0;
+
+    double min_runtime = performanceStats_[0].GetRuntimeInMs();
+    for (const auto& stats : performanceStats_) {
+        min_runtime = std::min(min_runtime, stats.GetRuntimeInMs());
+    }
+    return min_runtime;
+}
+
+// @return Largest runtime, or 0.0 when there are no iterations.
+double StatsBuilder::GetMaxRuntime() const {
+    if (performanceStats_.empty()) return 0.0;
+
+    double max_runtime = performanceStats_[0].GetRuntimeInMs();
+    for (const auto& stats : performanceStats_) {
+        max_runtime = std::max(max_runtime, stats.GetRuntimeInMs());
+    }
+    return max_runtime;
+}
+
+// Statistical methods for throughput
+double StatsBuilder::GetAverageThroughput() const {
+    // Mean RTFX over all iterations; 0.0 when nothing was recorded
+    const std::size_t sample_count = performanceStats_.size();
+    if (sample_count == 0) {
+        return 0.0;
+    }
+
+    double total = 0.0;
+    for (std::size_t i = 0; i < sample_count; ++i) {
+        total += performanceStats_[i].GetThroughputRTFX();
+    }
+    return total / sample_count;
+}
+
+// Sum of the per-iteration RTFX values; 0.0 when nothing was recorded
+double StatsBuilder::GetCumulativeThroughput() const {
+    double total = 0.0;
+    if (!performanceStats_.empty()) {
+        for (auto it = performanceStats_.begin(); it != performanceStats_.end(); ++it) {
+            total += it->GetThroughputRTFX();
+        }
+    }
+    return total;
+}
+
+// 90th percentile of the per-iteration RTFX values
+double StatsBuilder::GetP90Throughput() const {
+    std::vector<double> rtfx_samples;
+    for (auto it = performanceStats_.begin(); it != performanceStats_.end(); ++it) {
+        rtfx_samples.push_back(it->GetThroughputRTFX());
+    }
+    return CalculatePercentile(rtfx_samples, 90.0);
+}
+
+// 95th percentile of the per-iteration RTFX values
+double StatsBuilder::GetP95Throughput() const {
+    std::vector<double> rtfx_samples;
+    for (auto it = performanceStats_.begin(); it != performanceStats_.end(); ++it) {
+        rtfx_samples.push_back(it->GetThroughputRTFX());
+    }
+    return CalculatePercentile(rtfx_samples, 95.0);
+}
+
+// 99th percentile of the per-iteration RTFX values
+double StatsBuilder::GetP99Throughput() const {
+    std::vector<double> rtfx_samples;
+    for (auto it = performanceStats_.begin(); it != performanceStats_.end(); ++it) {
+        rtfx_samples.push_back(it->GetThroughputRTFX());
+    }
+    return CalculatePercentile(rtfx_samples, 99.0);
+}
+
+// True only when at least one iteration exists and none of them failed
+bool StatsBuilder::AreAllIterationsSuccessful() const {
+    if (performanceStats_.empty()) {
+        return false;
+    }
+    return std::all_of(performanceStats_.begin(), performanceStats_.end(),
+                       [](const PerformanceStats& entry) { return entry.IsSuccess(); });
+}
+
+// Number of iterations whose success flag is set
+std::size_t StatsBuilder::GetSuccessfulIterationsCount() const {
+    std::size_t succeeded = 0;
+    for (std::size_t i = 0; i < performanceStats_.size(); ++i) {
+        succeeded += performanceStats_[i].IsSuccess() ? 1 : 0;
+    }
+    return succeeded;
+}
+
+// Number of iterations whose success flag is not set
+std::size_t StatsBuilder::GetFailedIterationsCount() const {
+    const std::size_t succeeded = GetSuccessfulIterationsCount();
+    return performanceStats_.size() - succeeded;
+}
+
+// Share of successful iterations in percent; 0.0 when nothing was recorded
+double StatsBuilder::GetSuccessRate() const {
+    if (performanceStats_.empty()) {
+        return 0.0;
+    }
+    const double success_ratio =
+        static_cast<double>(GetSuccessfulIterationsCount()) / performanceStats_.size();
+    return success_ratio * 100.0;
+}
+
+// Prints the run summary, success rate, runtime statistics (ms) and
+// throughput statistics (RTFX) of all iterations to stdout.
+void StatsBuilder::ReportDetailedStats() const {
+    std::ostream& out = std::cout;
+    const std::size_t sample_count = performanceStats_.size();
+
+    // One labelled value per line, followed by its unit
+    const auto print_ms = [&out](const char* label, double value) {
+        out << label << value << " ms" << std::endl;
+    };
+    const auto print_rtfx = [&out](const char* label, double value) {
+        out << label << value << " RTFX" << std::endl;
+    };
+
+    out << "\n=== DETAILED PERFORMANCE STATISTICS ===" << std::endl;
+    out << "Audio Duration: " << audio_duration_seconds_ << " seconds" << std::endl;
+    out << "Number of Iterations: " << num_iterations_ << std::endl;
+    out << "Sample Count: " << sample_count << std::endl;
+
+    // Success rate information
+    out << "Success Rate: " << GetSuccessRate() << "% (" << GetSuccessfulIterationsCount()
+        << "/" << sample_count << " iterations)" << std::endl;
+    const char* all_ok = AreAllIterationsSuccessful() ? "YES" : "NO";
+    out << "All Iterations Successful: " << all_ok << std::endl;
+
+    out << "\n--- RUNTIME STATISTICS (ms) ---" << std::endl;
+    print_ms("Average: ", GetAverageRuntime());
+    print_ms("P50: ", GetP50Runtime());
+    print_ms("P90: ", GetP90Runtime());
+    print_ms("P95: ", GetP95Runtime());
+    print_ms("P99: ", GetP99Runtime());
+    print_ms("Min: ", GetMinRuntime());
+    print_ms("Max: ", GetMaxRuntime());
+
+    out << "\n--- THROUGHPUT STATISTICS (RTFX) ---" << std::endl;
+    print_rtfx("Average: ", GetAverageThroughput());
+    print_rtfx("Cumulative: ", GetCumulativeThroughput());
+    print_rtfx("P90: ", GetP90Throughput());
+    print_rtfx("P95: ", GetP95Throughput());
+    print_rtfx("P99: ", GetP99Throughput());
+
+    out << "=====================================" << std::endl;
+}
+
+}  // namespace nvidia::riva::utils
\ No newline at end of file
diff --git a/riva/utils/stats_builder/stats_builder.h b/riva/utils/stats_builder/stats_builder.h
new file mode 100644
index 0000000..5cfa89e
--- /dev/null
+++ b/riva/utils/stats_builder/stats_builder.h
@@ -0,0 +1,147
@@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ */
+
+#ifndef STATS_BUILDER_H
+#define STATS_BUILDER_H
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include <iomanip> // Required for std::setw and std::setprecision
+
+namespace nvidia::riva::utils {
+
+/**
+ * Timing and outcome of one processing run: a success flag, start and end
+ * timestamps, and the duration of the audio that was processed.
+ */
+class PerformanceStats {
+  private:
+    bool success_;
+    std::string objectName_;
+    // Timing measurement
+    std::chrono::steady_clock::time_point processing_start_time_;
+    std::chrono::steady_clock::time_point processing_end_time_;
+    double audio_duration_seconds_;
+
+  public:
+    PerformanceStats(const std::string& objectName);
+    ~PerformanceStats() = default;
+
+    bool IsSuccess() const { return success_; }
+    void SetSuccess(bool success) { success_ = success; }
+
+    // Record the start / end timestamp of the measured interval
+    void StartProcessingTimer();
+    void EndProcessingTimer();
+    std::chrono::steady_clock::time_point GetStartTime() const { return processing_start_time_; }
+    // Time between the two timestamps
+    double GetRuntimeInMs() const;
+    double GetRuntimeInSeconds() const;
+    void SetAudioDurationInSeconds(double audio_duration_seconds);
+    double GetAudioDurationInSeconds() const { return audio_duration_seconds_; }
+    // Audio seconds processed per second of runtime
+    double GetThroughputRTFX() const;
+
+    void SetObjectName(const std::string& objectName);
+    std::string GetObjectName() const;
+
+    // Prints this run's figures to stdout
+    void ReportStats();
+};
+
+/**
+ * Holds one PerformanceStats per iteration and reports statistics over them.
+ */
+class StatsBuilder {
+  private:
+    std::vector<PerformanceStats> performanceStats_;
+    double audio_duration_seconds_;
+    std::size_t num_iterations_;
+    std::string object_name_;  // Prefix of the per-iteration names
+
+  public:
+    StatsBuilder(const std::string& objectName, double audio_duration_seconds, std::size_t num_iterations);
+    ~StatsBuilder() = default;
+
+    // NOTE(review): stats_builder.cpp in this patch defines neither of the
+    // next two functions, so calling them would fail to link. Implement or
+    // remove.
+    void SetAudioDurationInSeconds(double audio_duration_seconds);
+    void SetNumIterations(std::size_t num_iterations);
+    void ReportCumulativeStats();
+    // Bounds-checked access; throws std::out_of_range for an invalid index
+    PerformanceStats& GetPerformanceStats(std::size_t index) { return performanceStats_.at(index); }
+
+    // Statistical methods (runtimes in milliseconds)
+    double GetAverageRuntime() const;
+    double GetP50Runtime() const;
+    double GetP90Runtime() const;
+    double GetP95Runtime() const;
+    double GetP99Runtime() const;
+    double GetMinRuntime() const;
+    double GetMaxRuntime() const;
+
+    // Throughput statistics (RTFX)
+    double GetAverageThroughput() const;
+    double GetCumulativeThroughput() const;
+    double GetP90Throughput() const;
+    double GetP95Throughput() const;
+    double GetP99Throughput() const;
+
+    // Comprehensive reporting
+    void ReportDetailedStats() const;
+
+    // Success checking methods
+    bool AreAllIterationsSuccessful() const;
+    std::size_t GetSuccessfulIterationsCount() const;
+    std::size_t GetFailedIterationsCount() const;
+    double GetSuccessRate() const;
+
+    /**
+     * Prints one table row per iteration plus a SUMMARY row to stdout.
+     *
+     * The formatting state of std::cout (flags and precision) is restored
+     * before returning.
+     */
+    void ReportTabularStats() const {
+        // Column widths; both separator rules span their sum so they line up
+        // with the table.
+        constexpr int kNameWidth = 15;
+        constexpr int kSuccessWidth = 10;
+        constexpr int kRuntimeWidth = 12;
+        constexpr int kAudioWidth = 15;
+        constexpr int kThroughputWidth = 15;
+        constexpr int kTableWidth =
+            kNameWidth + kSuccessWidth + kRuntimeWidth + kAudioWidth + kThroughputWidth;
+        const std::string rule(static_cast<std::size_t>(kTableWidth), '-');
+
+        // std::left, std::fixed and the precision are sticky stream state;
+        // remember the caller's settings so they can be put back.
+        const std::ios_base::fmtflags saved_flags = std::cout.flags();
+        const std::streamsize saved_precision = std::cout.precision();
+
+        std::cout << "\n=== Tabular Performance Statistics ===" << std::endl;
+        std::cout << std::left
+                  << std::setw(kNameWidth) << "Name"
+                  << std::setw(kSuccessWidth) << "Success"
+                  << std::setw(kRuntimeWidth) << "Runtime (s)"
+                  << std::setw(kAudioWidth) << "Audio (s)"
+                  << std::setw(kThroughputWidth) << "Throughput"
+                  << std::endl;
+        std::cout << rule << std::endl;
+
+        for (size_t i = 0; i < performanceStats_.size(); ++i) {
+            const auto& stats = performanceStats_[i];
+            const std::string name = object_name_ + "-" + std::to_string(i);
+            const std::string success = stats.IsSuccess() ? "true" : "false";
+
+            std::cout << std::left
+                      << std::setw(kNameWidth) << name
+                      << std::setw(kSuccessWidth) << success
+                      << std::fixed << std::setprecision(3)
+                      << std::setw(kRuntimeWidth) << stats.GetRuntimeInSeconds()
+                      << std::setw(kAudioWidth) << audio_duration_seconds_
+                      << std::setw(kThroughputWidth) << stats.GetThroughputRTFX()
+                      << std::endl;
+        }
+        std::cout << rule << std::endl;
+
+        // Summary row: totals over all iterations
+        size_t success_count = 0;
+        double total_runtime = 0.0;
+        const double total_audio_processed = audio_duration_seconds_ * performanceStats_.size();
+        double total_throughput = 0.0;
+
+        for (const auto& stats : performanceStats_) {
+            if (stats.IsSuccess()) success_count++;
+            total_runtime += stats.GetRuntimeInSeconds();
+            total_throughput += stats.GetThroughputRTFX();
+        }
+
+        std::cout << std::left
+                  << std::setw(kNameWidth) << "SUMMARY"
+                  << std::setw(kSuccessWidth) << (success_count == performanceStats_.size() ? "ALL" : std::to_string(success_count) + "/" + std::to_string(performanceStats_.size()))
+                  << std::fixed << std::setprecision(3)
+                  << std::setw(kRuntimeWidth) << total_runtime
+                  << std::setw(kAudioWidth) << total_audio_processed
+                  << std::setw(kThroughputWidth) << total_throughput
+                  << std::endl;
+        std::cout << std::endl;
+
+        std::cout.flags(saved_flags);
+        std::cout.precision(saved_precision);
+    }
+};
+
+}  // namespace nvidia::riva::utils
+
+#endif // STATS_BUILDER_H
\ No newline at end of file
diff --git a/third_party/BUILD.websocketpp b/third_party/BUILD.websocketpp
new file mode 100644
index 0000000..5269bf1
--- /dev/null
+++ b/third_party/BUILD.websocketpp
@@ -0,0 +1,20 @@
+"""
+SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+SPDX-License-Identifier: MIT +""" + +package( + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "websocketpp", + hdrs = glob([ + "websocketpp/*.hpp", + "websocketpp/**/*.hpp" + ]), + strip_include_prefix = ".", + deps = [ + "@com_google_absl//absl/strings", + ], +) \ No newline at end of file From 33c39a1dfce94be80d5c3e1f67da504a9019582a Mon Sep 17 00:00:00 2001 From: Yash Hayaran Date: Thu, 24 Jul 2025 23:00:01 +0530 Subject: [PATCH 2/6] chore: updating with the params --- riva/clients/asr/riva_realtime_asr_client.cc | 271 ++++++++++++++-- riva/clients/realtime/recognition_client.cpp | 317 ++++++++++++++++++- riva/clients/realtime/recognition_client.h | 61 ++++ 3 files changed, 613 insertions(+), 36 deletions(-) diff --git a/riva/clients/asr/riva_realtime_asr_client.cc b/riva/clients/asr/riva_realtime_asr_client.cc index 28958f2..e4fa74d 100644 --- a/riva/clients/asr/riva_realtime_asr_client.cc +++ b/riva/clients/asr/riva_realtime_asr_client.cc @@ -16,12 +16,73 @@ #include #include #include +#include +#include #include "riva/clients/realtime/recognition_client.h" #include "riva/utils/stats_builder/stats_builder.h" +#include +#include +#include +#include +#include + +// Add these includes for HTTP functionality +#include +#include +#include +#include +#include +#include using namespace nvidia::riva::utils; using namespace nvidia::riva::realtime; +// Define command-line flags (matching streaming client) +DEFINE_string(audio_file, "", "Folder that contains audio files to transcribe or individual audio file name"); +DEFINE_int32(max_alternatives, 1, "Maximum number of alternative transcripts to return (up to limit configured on server)"); +DEFINE_bool(profanity_filter, false, "Flag that controls if generated transcripts should be filtered for the profane words"); +DEFINE_bool(automatic_punctuation, true, "Flag that controls if transcript should be punctuated"); +DEFINE_bool(word_time_offsets, true, "Flag that controls if word time stamps are 
requested"); +DEFINE_bool(simulate_realtime, false, "Flag that controls if audio files should be sent in realtime"); +DEFINE_string(audio_device, "", "Name of audio device to use"); +DEFINE_string(riva_uri, "ws://127.0.0.1:9090/v1/realtime?intent=transcription", "URI to access riva-server"); +DEFINE_int32(num_iterations, 1, "Number of times to loop over audio files"); +DEFINE_int32(num_parallel_requests, 1, "Number of parallel requests to keep in flight"); +DEFINE_int32(chunk_duration_ms, 100, "Chunk duration in milliseconds"); +DEFINE_bool(print_transcripts, true, "Print final transcripts"); +DEFINE_bool(interim_results, true, "Print intermediate transcripts"); +DEFINE_string(output_filename, "final_transcripts.json", "Filename of .json file containing output transcripts"); +DEFINE_string(model_name, "", "Name of the TRTIS model to use"); +DEFINE_string(language_code, "en-US", "Language code of the model to use"); +DEFINE_string(boosted_words_file, "", "File with a list of words to boost. One line per word."); +DEFINE_double(boosted_words_score, 10., "Score by which to boost the boosted words"); +DEFINE_bool(verbatim_transcripts, true, "True returns text exactly as it was said with no normalization. False applies text inverse normalization"); +DEFINE_string(ssl_root_cert, "", "Path to SSL root certificates file"); +DEFINE_string(ssl_client_key, "", "Path to SSL client certificates key"); +DEFINE_string(ssl_client_cert, "", "Path to SSL client certificates file"); +DEFINE_bool(use_ssl, false, "Whether to use SSL credentials or not. 
If ssl_root_cert is specified, this is assumed to be true"); +DEFINE_string(metadata, "", "Comma separated key-value pair(s) of metadata to be sent to server"); +DEFINE_int32(start_history, -1, "Value (in milliseconds) to detect and initiate start of speech utterance"); +DEFINE_double(start_threshold, -1., "Threshold value to determine at what percentage start of speech is initiated"); +DEFINE_int32(stop_history, -1, "Value (in milliseconds) to detect endpoint and reset decoder"); +DEFINE_double(stop_threshold, -1., "Threshold value to determine when endpoint detected"); +DEFINE_int32(stop_history_eou, -1, "Value (in milliseconds) to detect endpoint and generate an intermediate final transcript"); +DEFINE_double(stop_threshold_eou, -1., "Threshold value for likelihood of blanks before detecting end of utterance"); +DEFINE_string(custom_configuration, "", "Custom configurations to be sent to the server as key value pairs "); +DEFINE_bool(speaker_diarization, false, "Flag that controls if speaker diarization is requested"); +DEFINE_int32(diarization_max_speakers, 4, "Max number of speakers to detect when performing speaker diarization. 
Default is 4 (Max)"); +DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation"); +DEFINE_uint64(max_grpc_message_size, 16777216, "Max GRPC message size"); + +// Additional realtime-specific flags +DEFINE_int32(connection_timeout_ms, 100000, "Connection timeout in milliseconds"); +DEFINE_int32(session_init_timeout_ms, 100000, "Session initialization timeout in milliseconds"); +DEFINE_int32(session_update_timeout_ms, 100000, "Session update timeout in milliseconds"); +DEFINE_int32(transcription_timeout_ms, 100000, "Transcription timeout in milliseconds"); +DEFINE_int32(chunk_delay_time_ms, 160, "Delay between audio chunks in milliseconds"); +DEFINE_bool(verbose_logging, false, "Enable verbose logging"); +DEFINE_bool(show_detailed_stats, true, "Show detailed statistics"); +DEFINE_bool(show_tabular_stats, true, "Show tabular statistics"); // Global client pointer for signal handling std::vector g_clients; @@ -43,6 +104,8 @@ std::string format_throughput(double value) { return oss.str(); } + + // Function to run the client example void client_runner( const std::string& uri, const std::shared_ptr& audio_chunks, @@ -56,7 +119,79 @@ void client_runner( const std::string& uri, { nvidia::riva::realtime::RecognitionClient client(perfCounter.GetObjectName(), audio_chunks, perfCounter); - client.SetVerboseLogging(false); + // Extract server URL from URI (remove ws:// and path) + std::string server_url = uri; + if (server_url.find("ws://") == 0) { + server_url = server_url.substr(5); // Remove "ws://" + } else if (server_url.find("wss://") == 0) { + server_url = server_url.substr(6); // Remove "wss://" + } + + // Remove path part (everything after first /) + size_t path_pos = server_url.find('/'); + if (path_pos != std::string::npos) { + server_url = server_url.substr(0, path_pos); + } + + client.SetServerUrl(server_url); + + // Set session configuration from command line flags (these will override defaults) + nvidia::riva::realtime::SessionConfig sessionConfig; + + 
// Only set values if they were provided by user (not default values) + if (!FLAGS_language_code.empty() && FLAGS_language_code != "en-US") { + sessionConfig.language_code_ = FLAGS_language_code; + } + if (!FLAGS_model_name.empty()) { + sessionConfig.model_name_ = FLAGS_model_name; + } + if (FLAGS_max_alternatives != 1) { + sessionConfig.max_alternatives_ = FLAGS_max_alternatives; + } + if (!FLAGS_automatic_punctuation) { // Default is true, so only override if false + sessionConfig.automatic_punctuation_ = FLAGS_automatic_punctuation; + } + if (!FLAGS_word_time_offsets) { // Default is true, so only override if false + sessionConfig.word_time_offsets_ = FLAGS_word_time_offsets; + } + if (FLAGS_profanity_filter) { // Default is false, so only override if true + sessionConfig.profanity_filter_ = FLAGS_profanity_filter; + } + if (!FLAGS_verbatim_transcripts) { // Default is true, so only override if false + sessionConfig.verbatim_transcripts_ = FLAGS_verbatim_transcripts; + } + if (!FLAGS_boosted_words_file.empty()) { + sessionConfig.boosted_words_file_ = FLAGS_boosted_words_file; + sessionConfig.boosted_words_score_ = FLAGS_boosted_words_score; + } + if (FLAGS_speaker_diarization) { // Default is false, so only override if true + sessionConfig.speaker_diarization_ = FLAGS_speaker_diarization; + sessionConfig.diarization_max_speakers_ = FLAGS_diarization_max_speakers; + } + if (FLAGS_start_history > 0) { + sessionConfig.start_history_ = FLAGS_start_history; + } + if (FLAGS_start_threshold > 0) { + sessionConfig.start_threshold_ = FLAGS_start_threshold; + } + if (FLAGS_stop_history > 0) { + sessionConfig.stop_history_ = FLAGS_stop_history; + } + if (FLAGS_stop_threshold > 0) { + sessionConfig.stop_threshold_ = FLAGS_stop_threshold; + } + if (FLAGS_stop_history_eou > 0) { + sessionConfig.stop_history_eou_ = FLAGS_stop_history_eou; + } + if (FLAGS_stop_threshold_eou > 0) { + sessionConfig.stop_threshold_eou_ = FLAGS_stop_threshold_eou; + } + if 
(!FLAGS_custom_configuration.empty()) { + sessionConfig.custom_configuration_ = FLAGS_custom_configuration; + } + + client.SetSessionConfig(sessionConfig); + client.SetVerboseLogging(FLAGS_verbose_logging); client.SetTimingConfig(connectionTimeoutInMs, sessionInitTimeoutInMs, sessionUpdateTimeoutInMs, transcriptionTimeoutInMs, chunkDelayTimeInMs); // Step 1: Connect to the WebSocket server @@ -126,24 +261,112 @@ void client_runner( const std::string& uri, perfCounter.ReportStats(); } - int main(int argc, char* argv[]) { - std::size_t num_iterations = 1; - std::size_t num_parallel_clients = 50; - bool simulateRealtime = true; - - const std::size_t connectionTimeoutInMs = 1000 * 100; - const std::size_t sessionInitTimeoutInMs = 1000 * 100; - const std::size_t sessionUpdateTimeoutInMs = 1000 * 100; - const std::size_t transcriptionTimeoutInMs = 1000 * 100; - - // Realistic audio chunk timing - based on typical microphone sampling - // For 16kHz audio with 160ms chunks, this would be 160ms - const std::size_t chunkDelayTimeInMs = 160; // Realistic delay matching chunk duration - - const std::string uri = "ws://127.0.0.1:9090/v1/realtime?intent=transcription"; - const std::string audio_file_path = "/home/yhayaran/workspace/codebase/web-socket/new_ws_client/test_files/out5.wav"; - const std::size_t chunk_duration_ms = chunkDelayTimeInMs; + google::InitGoogleLogging(argv[0]); + FLAGS_logtostderr = 1; + + // Set up usage message + std::stringstream str_usage; + str_usage << "Usage: riva_realtime_asr_client " << std::endl; + str_usage << " --audio_file= " << std::endl; + str_usage << " --audio_device= " << std::endl; + str_usage << " --automatic_punctuation=" << std::endl; + str_usage << " --max_alternatives=" << std::endl; + str_usage << " --profanity_filter=" << std::endl; + str_usage << " --word_time_offsets=" << std::endl; + str_usage << " --riva_uri= " << std::endl; + str_usage << " --chunk_duration_ms= " << std::endl; + str_usage << " --interim_results= " << 
std::endl; + str_usage << " --simulate_realtime= " << std::endl; + str_usage << " --num_iterations= " << std::endl; + str_usage << " --num_parallel_requests= " << std::endl; + str_usage << " --print_transcripts= " << std::endl; + str_usage << " --output_filename=" << std::endl; + str_usage << " --verbatim_transcripts=" << std::endl; + str_usage << " --language_code=" << std::endl; + str_usage << " --boosted_words_file=" << std::endl; + str_usage << " --boosted_words_score=" << std::endl; + str_usage << " --ssl_root_cert=" << std::endl; + str_usage << " --ssl_client_key=" << std::endl; + str_usage << " --ssl_client_cert=" << std::endl; + str_usage << " --model_name=" << std::endl; + str_usage << " --metadata=" << std::endl; + str_usage << " --start_history=" << std::endl; + str_usage << " --start_threshold=" << std::endl; + str_usage << " --stop_history=" << std::endl; + str_usage << " --stop_history_eou=" << std::endl; + str_usage << " --stop_threshold=" << std::endl; + str_usage << " --stop_threshold_eou=" << std::endl; + str_usage << " --custom_configuration=" << std::endl; + str_usage << " --speaker_diarization=" << std::endl; + str_usage << " --diarization_max_speakers=" << std::endl; + str_usage << " --timeout_ms=" << std::endl; + str_usage << " --max_grpc_message_size=" << std::endl; + str_usage << " --connection_timeout_ms=" << std::endl; + str_usage << " --session_init_timeout_ms=" << std::endl; + str_usage << " --session_update_timeout_ms=" << std::endl; + str_usage << " --transcription_timeout_ms=" << std::endl; + str_usage << " --chunk_delay_time_ms=" << std::endl; + str_usage << " --verbose_logging=" << std::endl; + str_usage << " --show_detailed_stats=" << std::endl; + str_usage << " --show_tabular_stats=" << std::endl; + + gflags::SetUsageMessage(str_usage.str()); + + if (argc < 2) { + std::cout << gflags::ProgramUsage(); + return 1; + } + + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (argc > 1) { + std::cout << gflags::ProgramUsage(); + 
return 1; + } + + // Validate arguments + if (FLAGS_max_alternatives < 1) { + std::cerr << "max_alternatives must be greater than or equal to 1." << std::endl; + return 1; + } + + if (FLAGS_num_iterations < 1) { + std::cerr << "num_iterations must be greater than 0" << std::endl; + return 1; + } + + if (FLAGS_num_parallel_requests < 1) { + std::cerr << "num_parallel_requests must be greater than 0" << std::endl; + return 1; + } + + // Check if audio file or device is specified + if (FLAGS_audio_file.empty() && FLAGS_audio_device.empty()) { + std::cerr << "Either --audio_file or --audio_device must be specified" << std::endl; + return 1; + } + + // Validate audio file exists if specified + if (!FLAGS_audio_file.empty() && !std::filesystem::exists(FLAGS_audio_file)) { + std::cerr << "Audio file does not exist: " << FLAGS_audio_file << std::endl; + return 1; + } + + // Use command-line arguments + const std::string uri = FLAGS_riva_uri; + const std::string audio_file_path = FLAGS_audio_file; + const std::size_t num_iterations = FLAGS_num_iterations; + const std::size_t num_parallel_clients = FLAGS_num_parallel_requests; + const bool simulateRealtime = FLAGS_simulate_realtime; + + const std::size_t connectionTimeoutInMs = FLAGS_connection_timeout_ms; + const std::size_t sessionInitTimeoutInMs = FLAGS_session_init_timeout_ms; + const std::size_t sessionUpdateTimeoutInMs = FLAGS_session_update_timeout_ms; + const std::size_t transcriptionTimeoutInMs = FLAGS_transcription_timeout_ms; + const std::size_t chunkDelayTimeInMs = FLAGS_chunk_delay_time_ms; + const std::size_t chunk_duration_ms = FLAGS_chunk_duration_ms; + const auto audio_chunks = std::make_shared(audio_file_path, chunk_duration_ms); if (!audio_chunks->Init()) { std::cerr << "Failed to initialize audio chunks" << std::endl; @@ -196,11 +419,15 @@ int main(int argc, char* argv[]) { signal(SIGINT, signal_handler); signal(SIGTERM, signal_handler); - // Uncomment this section to show detailed stats including success 
rates - statsBuilder.ReportCumulativeStats(); - statsBuilder.ReportDetailedStats(); - statsBuilder.ReportTabularStats(); + // Conditional stats reporting based on flags + if (FLAGS_show_detailed_stats) { + statsBuilder.ReportDetailedStats(); + } + if (FLAGS_show_tabular_stats) { + statsBuilder.ReportTabularStats(); + } + statsBuilder.ReportCumulativeStats(); overallPerf.ReportStats(); return 0; } \ No newline at end of file diff --git a/riva/clients/realtime/recognition_client.cpp b/riva/clients/realtime/recognition_client.cpp index a0ae91d..4d6b9f1 100644 --- a/riva/clients/realtime/recognition_client.cpp +++ b/riva/clients/realtime/recognition_client.cpp @@ -9,7 +9,230 @@ #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// Helper method for HTTP requests using raw sockets +std::string nvidia::riva::realtime::RecognitionClient::MakeHttpRequest(const std::string& host, int port, const std::string& path, const std::string& method, const std::string& body) { + int sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { + std::cerr << "Failed to create socket" << std::endl; + return ""; + } + + struct hostent* server = gethostbyname(host.c_str()); + if (server == nullptr) { + std::cerr << "Failed to resolve host: " << host << std::endl; + close(sock); + return ""; + } + + struct sockaddr_in serv_addr; + memset(&serv_addr, 0, sizeof(serv_addr)); + serv_addr.sin_family = AF_INET; + memcpy(&serv_addr.sin_addr.s_addr, server->h_addr, server->h_length); + serv_addr.sin_port = htons(port); + + if (connect(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) < 0) { + std::cerr << "Failed to connect to " << host << ":" << port << std::endl; + close(sock); + return ""; + } + + // Build HTTP request + std::ostringstream request; + request << method << " " << path << " HTTP/1.1\r\n"; + request << "Host: " << host << ":" << port << "\r\n"; + request << "Content-Type: application/json\r\n"; + request << 
"Content-Length: " << body.length() << "\r\n"; + request << "Connection: close\r\n"; + request << "\r\n"; + request << body; + + std::string request_str = request.str(); + + // Send request + if (send(sock, request_str.c_str(), request_str.length(), 0) < 0) { + std::cerr << "Failed to send HTTP request" << std::endl; + close(sock); + return ""; + } + + // Receive response + std::string response; + char buffer[4096]; + int bytes_received; + + while ((bytes_received = recv(sock, buffer, sizeof(buffer) - 1, 0)) > 0) { + buffer[bytes_received] = '\0'; + response += buffer; + } + + close(sock); + + // Extract JSON body from HTTP response + size_t body_start = response.find("\r\n\r\n"); + if (body_start != std::string::npos) { + return response.substr(body_start + 4); + } + + return response; +} + +bool nvidia::riva::realtime::RecognitionClient::InitializeHttpSession() { + if (server_url_.empty()) { + std::cerr << "Server URL not set" << std::endl; + return false; + } + + // Parse server URL to extract host and port + std::string host = server_url_; + int port = 80; // Default HTTP port + + // Check if port is specified + size_t colon_pos = host.find(':'); + if (colon_pos != std::string::npos) { + port = std::stoi(host.substr(colon_pos + 1)); + host = host.substr(0, colon_pos); + } + + std::string path = "/v1/realtime/transcription_sessions"; + std::string response_body = MakeHttpRequest(host, port, path, "POST", "{}"); + + if (response_body.empty()) { + std::cerr << "HTTP request failed" << std::endl; + return false; + } + + try { + // Parse JSON response using rapidjson + rapidjson::Document session_data; + if (session_data.Parse(response_body.c_str()).HasParseError()) { + std::cerr << "Failed to parse JSON response" << std::endl; + return false; + } + + // Extract session ID + if (session_data.HasMember("id")) { + session_id_ = session_data["id"].GetString(); + } else { + std::cerr << "No session ID found in response" << std::endl; + return false; + } + + // Store 
server defaults but don't overwrite user-provided values + SessionConfig serverDefaults; + + if (session_data.HasMember("input_audio_transcription")) { + const auto& transcription = session_data["input_audio_transcription"]; + if (transcription.HasMember("language")) { + serverDefaults.language_code_ = transcription["language"].GetString(); + } + if (transcription.HasMember("model")) { + serverDefaults.model_name_ = transcription["model"].GetString(); + } + } + + if (session_data.HasMember("recognition_config")) { + const auto& recognition = session_data["recognition_config"]; + if (recognition.HasMember("max_alternatives")) { + serverDefaults.max_alternatives_ = recognition["max_alternatives"].GetInt(); + } + if (recognition.HasMember("enable_automatic_punctuation")) { + serverDefaults.automatic_punctuation_ = recognition["enable_automatic_punctuation"].GetBool(); + } + if (recognition.HasMember("enable_word_time_offsets")) { + serverDefaults.word_time_offsets_ = recognition["enable_word_time_offsets"].GetBool(); + } + if (recognition.HasMember("enable_profanity_filter")) { + serverDefaults.profanity_filter_ = recognition["enable_profanity_filter"].GetBool(); + } + if (recognition.HasMember("enable_verbatim_transcripts")) { + serverDefaults.verbatim_transcripts_ = recognition["enable_verbatim_transcripts"].GetBool(); + } + } + + if (session_data.HasMember("speaker_diarization")) { + const auto& diarization = session_data["speaker_diarization"]; + if (diarization.HasMember("enable_speaker_diarization")) { + serverDefaults.speaker_diarization_ = diarization["enable_speaker_diarization"].GetBool(); + } + if (diarization.HasMember("max_speaker_count")) { + serverDefaults.diarization_max_speakers_ = diarization["max_speaker_count"].GetInt(); + } + } + + if (session_data.HasMember("endpointing_config")) { + const auto& endpointing = session_data["endpointing_config"]; + if (endpointing.HasMember("start_history")) { + serverDefaults.start_history_ = 
endpointing["start_history"].GetInt(); + } + if (endpointing.HasMember("start_threshold")) { + serverDefaults.start_threshold_ = endpointing["start_threshold"].GetDouble(); + } + if (endpointing.HasMember("stop_history")) { + serverDefaults.stop_history_ = endpointing["stop_history"].GetInt(); + } + if (endpointing.HasMember("stop_threshold")) { + serverDefaults.stop_threshold_ = endpointing["stop_threshold"].GetDouble(); + } + if (endpointing.HasMember("stop_history_eou")) { + serverDefaults.stop_history_eou_ = endpointing["stop_history_eou"].GetInt(); + } + if (endpointing.HasMember("stop_threshold_eou")) { + serverDefaults.stop_threshold_eou_ = endpointing["stop_threshold_eou"].GetDouble(); + } + } + + // Only use server defaults for values that haven't been set by user + if (sessionConfig_.language_code_.empty()) { + sessionConfig_.language_code_ = serverDefaults.language_code_; + } + if (sessionConfig_.model_name_.empty()) { + sessionConfig_.model_name_ = serverDefaults.model_name_; + } + if (sessionConfig_.max_alternatives_ == 0) { + sessionConfig_.max_alternatives_ = serverDefaults.max_alternatives_; + } + if (sessionConfig_.start_history_ == -1) { + sessionConfig_.start_history_ = serverDefaults.start_history_; + } + if (sessionConfig_.start_threshold_ == -1.0) { + sessionConfig_.start_threshold_ = serverDefaults.start_threshold_; + } + if (sessionConfig_.stop_history_ == -1) { + sessionConfig_.stop_history_ = serverDefaults.stop_history_; + } + if (sessionConfig_.stop_threshold_ == -1.0) { + sessionConfig_.stop_threshold_ = serverDefaults.stop_threshold_; + } + if (sessionConfig_.stop_history_eou_ == -1) { + sessionConfig_.stop_history_eou_ = serverDefaults.stop_history_eou_; + } + if (sessionConfig_.stop_threshold_eou_ == -1.0) { + sessionConfig_.stop_threshold_eou_ = serverDefaults.stop_threshold_eou_; + } + + // Convert rapidjson document to string for logging + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + 
session_data.Accept(writer); + + std::cout << "[" << objectName_ << "] Session initialized with defaults: " << buffer.GetString() << std::endl; + return true; + + } catch (const std::exception& e) { + std::cerr << "Failed to parse session response: " << e.what() << std::endl; + return false; + } +} nvidia::riva::realtime::RecognitionClient::RecognitionClient( const std::string& objectName, @@ -30,6 +253,17 @@ nvidia::riva::realtime::RecognitionClient::RecognitionClient( perfCounter_(perfCounter) { nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(connectionTimeoutInMs_); + + // Initialize default session config + sessionConfig_.language_code_ = "en-US"; + sessionConfig_.model_name_ = "parakeet-1.1b-en-US-asr-streaming-silero-vad-asr-bls-ensemble"; + sessionConfig_.max_alternatives_ = 1; + sessionConfig_.automatic_punctuation_ = true; + sessionConfig_.word_time_offsets_ = true; + sessionConfig_.profanity_filter_ = false; + sessionConfig_.verbatim_transcripts_ = true; + sessionConfig_.speaker_diarization_ = false; + sessionConfig_.diarization_max_speakers_ = 4; } @@ -189,19 +423,26 @@ void nvidia::riva::realtime::RecognitionClient::SendAudioDone() { } } -// Session initialization (inspired by Python realtime.py) +// Modify the InitializeSession method to call HTTP initialization first bool nvidia::riva::realtime::RecognitionClient::InitializeSession() { std::cout << "[" << objectName_ << "]" << " Initializing session..." 
<< std::endl; - // Wait for the initial connection and session creation (increased from 1000ms to 3000ms) + // Step 1: Initialize HTTP session + if (!InitializeHttpSession()) { + std::cerr << "Failed to initialize HTTP session" << std::endl; + return false; + } + + // Step 2: Wait for the initial connection and session creation std::this_thread::sleep_for(std::chrono::milliseconds(3000)); - // Check if we're still connected + // Step 3: Check if we're still connected if (IsConnectionClosed()) { std::cerr << "Connection lost during session initialization" << std::endl; return false; } + // Step 4: Update session configuration return UpdateSessionConfig(); } @@ -213,7 +454,7 @@ bool nvidia::riva::realtime::RecognitionClient::UpdateSessionConfig() { std::cout << "Using WAV file parameters - Sample rate: " << sampleRateHz << " Hz, Channels: " << numChannels << std::endl; - // Create session configuration similar to Python client + // Create session configuration using sessionConfig_ (which now has defaults + user overrides) rapidjson::Document doc; doc.SetObject(); rapidjson::Document::AllocatorType& allocator = doc.GetAllocator(); @@ -221,11 +462,19 @@ bool nvidia::riva::realtime::RecognitionClient::UpdateSessionConfig() { // Create session config rapidjson::Value session_config(rapidjson::kObjectType); + // Add modalities + rapidjson::Value modalities(rapidjson::kArrayType); + modalities.PushBack(rapidjson::Value("text", allocator), allocator); + session_config.AddMember("modalities", modalities, allocator); + + // Add input audio format + session_config.AddMember("input_audio_format", rapidjson::Value("pcm16", allocator), allocator); + // Input audio transcription config rapidjson::Value transcription_config(rapidjson::kObjectType); - transcription_config.AddMember("language", "en-US", allocator); - transcription_config.AddMember("model", "parakeet-1.1b-en-US-asr-streaming-silero-vad-asr-bls-ensemble", allocator); - transcription_config.AddMember("prompt", "", 
allocator); + transcription_config.AddMember("language", rapidjson::Value(sessionConfig_.language_code_.c_str(), allocator), allocator); + transcription_config.AddMember("model", rapidjson::Value(sessionConfig_.model_name_.c_str(), allocator), allocator); + transcription_config.AddMember("prompt", rapidjson::Value(rapidjson::kNullType), allocator); session_config.AddMember("input_audio_transcription", transcription_config, allocator); // Input audio params - use actual WAV file parameters @@ -234,18 +483,56 @@ bool nvidia::riva::realtime::RecognitionClient::UpdateSessionConfig() { audio_params.AddMember("num_channels", numChannels, allocator); session_config.AddMember("input_audio_params", audio_params, allocator); - // Recognition config + // Recognition config - use session configuration rapidjson::Value recognition_config(rapidjson::kObjectType); - recognition_config.AddMember("max_alternatives", 1, allocator); - recognition_config.AddMember("enable_automatic_punctuation", false, allocator); - recognition_config.AddMember("enable_word_time_offsets", false, allocator); - recognition_config.AddMember("enable_profanity_filter", false, allocator); - recognition_config.AddMember("enable_verbatim_transcripts", false, allocator); + recognition_config.AddMember("max_alternatives", sessionConfig_.max_alternatives_, allocator); + recognition_config.AddMember("enable_automatic_punctuation", sessionConfig_.automatic_punctuation_, allocator); + recognition_config.AddMember("enable_word_time_offsets", sessionConfig_.word_time_offsets_, allocator); + recognition_config.AddMember("enable_profanity_filter", sessionConfig_.profanity_filter_, allocator); + recognition_config.AddMember("enable_verbatim_transcripts", sessionConfig_.verbatim_transcripts_, allocator); + recognition_config.AddMember("custom_configuration", rapidjson::Value(sessionConfig_.custom_configuration_.c_str(), allocator), allocator); session_config.AddMember("recognition_config", recognition_config, allocator); 
+ // Speaker diarization config + rapidjson::Value diarization_config(rapidjson::kObjectType); + diarization_config.AddMember("enable_speaker_diarization", sessionConfig_.speaker_diarization_, allocator); + diarization_config.AddMember("max_speaker_count", sessionConfig_.diarization_max_speakers_, allocator); + session_config.AddMember("speaker_diarization", diarization_config, allocator); + + // Word boosting config + rapidjson::Value word_boosting_config(rapidjson::kObjectType); + bool enable_word_boosting = !sessionConfig_.boosted_words_file_.empty(); + word_boosting_config.AddMember("enable_word_boosting", enable_word_boosting, allocator); + + if (enable_word_boosting) { + rapidjson::Value word_list(rapidjson::kArrayType); + std::ifstream file(sessionConfig_.boosted_words_file_); + std::string word; + while (std::getline(file, word)) { + if (!word.empty()) { + word_list.PushBack(rapidjson::Value(word.c_str(), allocator), allocator); + } + } + word_boosting_config.AddMember("word_boosting_list", word_list, allocator); + } else { + rapidjson::Value empty_list(rapidjson::kArrayType); + word_boosting_config.AddMember("word_boosting_list", empty_list, allocator); + } + session_config.AddMember("word_boosting", word_boosting_config, allocator); + + // Endpointing config + rapidjson::Value endpointing_config(rapidjson::kObjectType); + endpointing_config.AddMember("start_history", sessionConfig_.start_history_, allocator); + endpointing_config.AddMember("start_threshold", sessionConfig_.start_threshold_, allocator); + endpointing_config.AddMember("stop_history", sessionConfig_.stop_history_, allocator); + endpointing_config.AddMember("stop_threshold", sessionConfig_.stop_threshold_, allocator); + endpointing_config.AddMember("stop_history_eou", sessionConfig_.stop_history_eou_, allocator); + endpointing_config.AddMember("stop_threshold_eou", sessionConfig_.stop_threshold_eou_, allocator); + session_config.AddMember("endpointing_config", endpointing_config, allocator); 
+ // Create update request rapidjson::Value update_request(rapidjson::kObjectType); - update_request.AddMember("type", "transcription_session.update", allocator); + update_request.AddMember("type", rapidjson::Value("transcription_session.update", allocator), allocator); update_request.AddMember("session", session_config, allocator); // Send the update request @@ -266,6 +553,8 @@ bool nvidia::riva::realtime::RecognitionClient::UpdateSessionConfig() { } } + + WaitForSessionUpdate(); return true; } diff --git a/riva/clients/realtime/recognition_client.h b/riva/clients/realtime/recognition_client.h index 5783647..c835c29 100644 --- a/riva/clients/realtime/recognition_client.h +++ b/riva/clients/realtime/recognition_client.h @@ -25,7 +25,48 @@ #include "base_client.h" #include "riva/utils/stats_builder/stats_builder.h" +// Add these includes for HTTP functionality +#include +#include +#include +#include +#include +#include + namespace nvidia::riva::realtime { + class SessionConfig { + public: + std::size_t connectionTimeoutInMs_; + std::size_t sessionInitTimeoutInMs_; + std::size_t sessionUpdateTimeoutInMs_; + std::size_t transcriptionTimeoutInMs_; + std::size_t chunkDelayTimeInMs_; + + // Add session configuration parameters + std::string language_code_; + std::string model_name_; + int max_alternatives_; + bool automatic_punctuation_; + bool word_time_offsets_; + bool profanity_filter_; + bool verbatim_transcripts_; + std::string boosted_words_file_; + double boosted_words_score_; + bool speaker_diarization_; + int diarization_max_speakers_; + int start_history_; + double start_threshold_; + int stop_history_; + double stop_threshold_; + int stop_history_eou_; + double stop_threshold_eou_; + std::string custom_configuration_; + + // Add HTTP session data + std::string session_id_; + std::string server_url_; + }; + class RecognitionClient : public WebSocketClientBase { private: @@ -56,6 +97,19 @@ namespace nvidia::riva::realtime { // Audio processing std::shared_ptr 
audioChunksPtr_; + // Add session configuration + SessionConfig sessionConfig_; + + // Add HTTP session data + std::string session_id_; + std::string server_url_; + + // HTTP session initialization method + bool InitializeHttpSession(); + + // Helper method for HTTP requests + std::string MakeHttpRequest(const std::string& host, int port, const std::string& path, const std::string& method, const std::string& body); + // Audio streaming methods void SendAudioAppend(const std::string& audioBase64); void SendAudioCommit(); @@ -79,6 +133,9 @@ namespace nvidia::riva::realtime { const std::size_t transcriptionTimeoutInMs, const std::size_t chunkDelayTimeInMs); + // Session configuration + void SetSessionConfig(const SessionConfig& config) { sessionConfig_ = config; } + // Session management methods bool InitializeSession(); bool UpdateSessionConfig(); @@ -91,6 +148,10 @@ namespace nvidia::riva::realtime { // WAV file processing methods void SendAudioChunks(const bool simulateRealtime = false); + + // Add method to set server URL + void SetServerUrl(const std::string& server_url) { server_url_ = server_url; } + std::string GetSessionId() const { return session_id_; } }; } // namespace nvidia::riva::realtime From 83351893f8f9c37c6c8babb81bf7eb864d636187 Mon Sep 17 00:00:00 2001 From: Yash Hayaran Date: Mon, 28 Jul 2025 12:56:45 +0530 Subject: [PATCH 3/6] chore: moving the realtime client to its dir --- Dockerfile | 2 +- README.md | 6 +++++ riva/clients/asr/BUILD | 25 ----------------- riva/clients/realtime/BUILD | 27 ++++++++++++++++++- .../riva_realtime_asr_client.cc | 0 5 files changed, 33 insertions(+), 27 deletions(-) rename riva/clients/{asr => realtime}/riva_realtime_asr_client.cc (100%) diff --git a/Dockerfile b/Dockerfile index c026eb9..73da6b6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -61,7 +61,6 @@ FROM base as riva-clients WORKDIR /work COPY --from=builder /opt/riva/clients/asr/riva_asr_client /usr/local/bin/ -COPY --from=builder
/opt/riva/clients/asr/riva_realtime_asr_client /usr/local/bin/ COPY --from=builder /opt/riva/clients/asr/riva_streaming_asr_client /usr/local/bin/ COPY --from=builder /opt/riva/clients/tts/riva_tts_client /usr/local/bin/ COPY --from=builder /opt/riva/clients/tts/riva_tts_perf_client /usr/local/bin/ @@ -69,4 +68,5 @@ COPY --from=builder /opt/riva/clients/nlp/riva_nlp_punct /usr/local/bin/ COPY --from=builder /opt/riva/clients/nmt/riva_nmt_t2t_client /usr/local/bin/ COPY --from=builder /opt/riva/clients/nmt/riva_nmt_streaming_s2t_client /usr/local/bin/ COPY --from=builder /opt/riva/clients/nmt/riva_nmt_streaming_s2s_client /usr/local/bin/ +COPY --from=builder /opt/riva/clients/realtime/riva_realtime_asr_client /usr/local/bin/ COPY examples /work/examples diff --git a/README.md b/README.md index 6092fe8..6301d40 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ NVIDIA Riva is a GPU-accelerated SDK for building Speech AI applications that ar - **Automatic Speech Recognition (ASR)** - `riva_streaming_asr_client` - `riva_asr_client` + - `riva_realtime_asr_client` - **Speech Synthesis (TTS)** - `riva_tts_client` - `riva_tts_perf_client` @@ -73,6 +74,7 @@ You can find the built binaries in `bazel-bin/riva/clients` Riva comes with 2 ASR clients: 1. `riva_asr_client` for offline usage. Using this client, the server will wait until it receives the full audio file before transcribing it and sending it back to the client. 2. `riva_streaming_asr_client` for online usage. Using this client, the server will start transcribing after it receives a sufficient amount of audio data, "streaming" intermediate transcripts as it goes on back to the client. By default, it is set to transcribe after every `100ms`, this can be changed using the `--chunk_duration_ms` command line flag. +3. `riva_realtime_asr_client` for realtime (websocket) usage. This client establishes a persistent websocket connection to the server, allowing for bidirectional real-time communication. 
The server will start transcribing after it receives a sufficient amount of audio data and continuously stream intermediate transcripts back to the client as it processes the audio. By default, it is set to transcribe after every `100ms`, which can be changed using the `--chunk_duration_ms` command line flag. To use the clients, simply pass in a folder containing audio files or an individual audio file name with the `audio_file` flag: ``` @@ -82,6 +84,10 @@ or ``` $ riva_asr_client --audio_file audio_folder ``` +or +``` +$ riva_realtime_asr_client --audio_file individual_audio_file.wav +``` Note that only single-channel audio files in the `.wav` format are currently supported. diff --git a/riva/clients/asr/BUILD b/riva/clients/asr/BUILD index b2aa4ae..d88c46b 100644 --- a/riva/clients/asr/BUILD +++ b/riva/clients/asr/BUILD @@ -115,31 +115,6 @@ cc_binary( ], ) -cc_binary( - name = "riva_realtime_asr_client", - srcs = ["riva_realtime_asr_client.cc"], - includes = ["-Irealtime"], - deps = [ - "//riva/clients/realtime:realtime_audio_client_lib", - "@websocketpp//:websocketpp", - "@rapidjson//:rapidjson", - "//riva/utils/stats_builder:stats_builder_lib", - "//riva/utils/wav:reader", - ] + select({ - "@platforms//cpu:aarch64": [ - "@alsa_aarch64//:libasound" - ], - "//conditions:default": [ - "@alsa//:libasound" - ], - }), - linkopts = [ - "-lssl", - "-lcrypto", - "-lboost_system", - ] -) - cc_test( name = "streaming_recognize_client_test", srcs = ["streaming_recognize_client_test.cc"], diff --git a/riva/clients/realtime/BUILD b/riva/clients/realtime/BUILD index 7de3f31..fdd75d9 100644 --- a/riva/clients/realtime/BUILD +++ b/riva/clients/realtime/BUILD @@ -31,4 +31,29 @@ cc_library( "@glog//:glog", "@com_github_gflags_gflags//:gflags", ], -) \ No newline at end of file +) + +cc_binary( + name = "riva_realtime_asr_client", + srcs = ["riva_realtime_asr_client.cc"], + includes = ["-Irealtime"], + deps = [ + ":realtime_audio_client_lib", + "@websocketpp//:websocketpp", + 
"@rapidjson//:rapidjson", + "//riva/utils/stats_builder:stats_builder_lib", + "//riva/utils/wav:reader", + ] + select({ + "@platforms//cpu:aarch64": [ + "@alsa_aarch64//:libasound" + ], + "//conditions:default": [ + "@alsa//:libasound" + ], + }), + linkopts = [ + "-lssl", + "-lcrypto", + "-lboost_system", + ] +) \ No newline at end of file diff --git a/riva/clients/asr/riva_realtime_asr_client.cc b/riva/clients/realtime/riva_realtime_asr_client.cc similarity index 100% rename from riva/clients/asr/riva_realtime_asr_client.cc rename to riva/clients/realtime/riva_realtime_asr_client.cc From 48b6d3b955f16c929c2991dc78ba228be10c09a1 Mon Sep 17 00:00:00 2001 From: Yash Hayaran Date: Mon, 28 Jul 2025 15:59:19 +0530 Subject: [PATCH 4/6] chore: updating the circle ci config --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 238bafe..5103715 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -8,7 +8,7 @@ jobs: steps: - run: name: "Install build dependencies" - command: "sudo apt-get --allow-releaseinfo-change update && sudo apt-get install -y wget libasound2-dev libopus-dev libopusfile-dev" + command: "sudo apt-get --allow-releaseinfo-change update && sudo apt-get install -y wget libasound2-dev libopus-dev libopusfile-dev libboost-all-dev" - run: name: "Install bazel" command: "wget https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-linux-amd64 && sudo mv bazelisk-linux-amd64 /usr/local/bin/bazelisk && sudo chmod +x /usr/local/bin/bazelisk" From b643f20a6409e0feeb338c635a95edcdbb207c1b Mon Sep 17 00:00:00 2001 From: Yash Hayaran Date: Sat, 9 Aug 2025 13:55:04 +0530 Subject: [PATCH 5/6] chore: rename recognition client to realtime client --- riva/clients/realtime/BUILD | 4 +-- riva/clients/realtime/base_client.h | 6 ++-- ...gnition_client.cpp => realtime_client.cpp} | 30 +++++++++---------- ...recognition_client.h => realtime_client.h} | 12 ++++---- 
.../realtime/riva_realtime_asr_client.cc | 6 ++-- 5 files changed, 29 insertions(+), 29 deletions(-) rename riva/clients/realtime/{recognition_client.cpp => realtime_client.cpp} (95%) rename riva/clients/realtime/{recognition_client.h => realtime_client.h} (95%) diff --git a/riva/clients/realtime/BUILD b/riva/clients/realtime/BUILD index fdd75d9..3d7d4fe 100644 --- a/riva/clients/realtime/BUILD +++ b/riva/clients/realtime/BUILD @@ -16,12 +16,12 @@ cc_library( srcs = [ "audio_chunks.cpp", "base_client.cpp", - "recognition_client.cpp", + "realtime_client.cpp", ], hdrs = [ "audio_chunks.h", "base_client.h", - "recognition_client.h", + "realtime_client.h", ], deps = [ "//riva/utils/wav:reader", diff --git a/riva/clients/realtime/base_client.h b/riva/clients/realtime/base_client.h index 400f319..6ebb3db 100644 --- a/riva/clients/realtime/base_client.h +++ b/riva/clients/realtime/base_client.h @@ -3,8 +3,8 @@ * SPDX-License-Identifier: MIT */ -#ifndef REALTIME_CLIENT_H -#define REALTIME_CLIENT_H +#ifndef BASE_REALTIME_CLIENT_H +#define BASE_REALTIME_CLIENT_H #include #include @@ -83,4 +83,4 @@ namespace nvidia::riva::realtime { virtual void HandleMessage(const std::string& message) = 0; }; } // namespace nvidia::riva::realtime -#endif // REALTIME_CLIENT_H \ No newline at end of file +#endif // BASE_REALTIME_CLIENT_H \ No newline at end of file diff --git a/riva/clients/realtime/recognition_client.cpp b/riva/clients/realtime/realtime_client.cpp similarity index 95% rename from riva/clients/realtime/recognition_client.cpp rename to riva/clients/realtime/realtime_client.cpp index 4d6b9f1..94d4953 100644 --- a/riva/clients/realtime/recognition_client.cpp +++ b/riva/clients/realtime/realtime_client.cpp @@ -3,7 +3,7 @@ * SPDX-License-Identifier: MIT */ -#include "recognition_client.h" +#include "realtime_client.h" #include "base_client.h" #include #include @@ -20,7 +20,7 @@ #include // Helper method for HTTP requests using raw sockets -std::string 
nvidia::riva::realtime::RecognitionClient::MakeHttpRequest(const std::string& host, int port, const std::string& path, const std::string& method, const std::string& body) { +std::string nvidia::riva::realtime::RealtimeClient::MakeHttpRequest(const std::string& host, int port, const std::string& path, const std::string& method, const std::string& body) { int sock = socket(AF_INET, SOCK_STREAM, 0); if (sock < 0) { std::cerr << "Failed to create socket" << std::endl; @@ -86,7 +86,7 @@ std::string nvidia::riva::realtime::RecognitionClient::MakeHttpRequest(const std return response; } -bool nvidia::riva::realtime::RecognitionClient::InitializeHttpSession() { +bool nvidia::riva::realtime::RealtimeClient::InitializeHttpSession() { if (server_url_.empty()) { std::cerr << "Server URL not set" << std::endl; return false; @@ -234,7 +234,7 @@ bool nvidia::riva::realtime::RecognitionClient::InitializeHttpSession() { } } -nvidia::riva::realtime::RecognitionClient::RecognitionClient( +nvidia::riva::realtime::RealtimeClient::RealtimeClient( const std::string& objectName, const std::shared_ptr audioChunksPtr, nvidia::riva::utils::PerformanceStats& perfCounter) @@ -267,7 +267,7 @@ nvidia::riva::realtime::RecognitionClient::RecognitionClient( } -void nvidia::riva::realtime::RecognitionClient::SetTimingConfig( const std::size_t connectionTimeoutInMs, +void nvidia::riva::realtime::RealtimeClient::SetTimingConfig( const std::size_t connectionTimeoutInMs, const std::size_t sessionInitTimeoutInMs, const std::size_t sessionUpdateTimeoutInMs, const std::size_t transcriptionTimeoutInMs, @@ -280,11 +280,11 @@ void nvidia::riva::realtime::RecognitionClient::SetTimingConfig( const std::siz nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(connectionTimeoutInMs_); } -void nvidia::riva::realtime::RecognitionClient::Log(const std::string& message) { +void nvidia::riva::realtime::RealtimeClient::Log(const std::string& message) { std::cout << "[" << objectName_ << "]" << message << 
std::endl; } -bool nvidia::riva::realtime::RecognitionClient::WaitForTranscriptionCompletion() { +bool nvidia::riva::realtime::RealtimeClient::WaitForTranscriptionCompletion() { std::unique_lock lock(transcriptionMutex_); // Reset completion flag @@ -306,7 +306,7 @@ bool nvidia::riva::realtime::RecognitionClient::WaitForTranscriptionCompletion() return completed; } -bool nvidia::riva::realtime::RecognitionClient::WaitForSessionUpdate() { +bool nvidia::riva::realtime::RealtimeClient::WaitForSessionUpdate() { std::unique_lock lock(sessionMutex_); if (sessionUpdated_) { @@ -328,7 +328,7 @@ bool nvidia::riva::realtime::RecognitionClient::WaitForSessionUpdate() { } // Send audio buffer append message (inspired by Python realtime.py) -void nvidia::riva::realtime::RecognitionClient::SendAudioAppend(const std::string& audioBase64) +void nvidia::riva::realtime::RealtimeClient::SendAudioAppend(const std::string& audioBase64) { std::lock_guard lock(connectionMutex_); if (IsConnectionOpen()) @@ -360,7 +360,7 @@ void nvidia::riva::realtime::RecognitionClient::SendAudioAppend(const std::strin } // Send audio buffer commit message (inspired by Python realtime.py) -void nvidia::riva::realtime::RecognitionClient::SendAudioCommit() { +void nvidia::riva::realtime::RealtimeClient::SendAudioCommit() { std::lock_guard lock(connectionMutex_); if (IsConnectionOpen()) { @@ -391,7 +391,7 @@ void nvidia::riva::realtime::RecognitionClient::SendAudioCommit() { } // Send audio buffer done message (inspired by Python realtime.py) -void nvidia::riva::realtime::RecognitionClient::SendAudioDone() { +void nvidia::riva::realtime::RealtimeClient::SendAudioDone() { std::lock_guard lock(connectionMutex_); if (IsConnectionOpen()) { @@ -424,7 +424,7 @@ void nvidia::riva::realtime::RecognitionClient::SendAudioDone() { } // Modify the InitializeSession method to call HTTP initialization first -bool nvidia::riva::realtime::RecognitionClient::InitializeSession() { +bool 
nvidia::riva::realtime::RealtimeClient::InitializeSession() { std::cout << "[" << objectName_ << "]" << " Initializing session..." << std::endl; // Step 1: Initialize HTTP session @@ -446,7 +446,7 @@ bool nvidia::riva::realtime::RecognitionClient::InitializeSession() { return UpdateSessionConfig(); } -bool nvidia::riva::realtime::RecognitionClient::UpdateSessionConfig() { +bool nvidia::riva::realtime::RealtimeClient::UpdateSessionConfig() { int sampleRateHz = audioChunksPtr_->GetSampleRateHz(); int numChannels = audioChunksPtr_->GetNumChannels(); @@ -560,7 +560,7 @@ bool nvidia::riva::realtime::RecognitionClient::UpdateSessionConfig() { } // Send audio chunks -void nvidia::riva::realtime::RecognitionClient::SendAudioChunks(const bool simulateRealtime) { +void nvidia::riva::realtime::RealtimeClient::SendAudioChunks(const bool simulateRealtime) { if (audioChunksPtr_ == nullptr) { std::cerr << "Audio chunks pointer is null. Please call InitializeSession first." << std::endl; return; @@ -624,7 +624,7 @@ void nvidia::riva::realtime::RecognitionClient::SendAudioChunks(const bool simul SendAudioDone(); } -void nvidia::riva::realtime::RecognitionClient::HandleMessage(const std::string& message) { +void nvidia::riva::realtime::RealtimeClient::HandleMessage(const std::string& message) { bool is_last_result = false; rapidjson::Document doc; diff --git a/riva/clients/realtime/recognition_client.h b/riva/clients/realtime/realtime_client.h similarity index 95% rename from riva/clients/realtime/recognition_client.h rename to riva/clients/realtime/realtime_client.h index c835c29..5178cf9 100644 --- a/riva/clients/realtime/recognition_client.h +++ b/riva/clients/realtime/realtime_client.h @@ -3,8 +3,8 @@ * SPDX-License-Identifier: MIT */ -#ifndef RECOGNITION_CLIENT_H -#define RECOGNITION_CLIENT_H +#ifndef REALTIME_CLIENT_H +#define REALTIME_CLIENT_H #include #include @@ -67,7 +67,7 @@ namespace nvidia::riva::realtime { std::string server_url_; }; - class RecognitionClient : public 
WebSocketClientBase { + class RealtimeClient : public WebSocketClientBase { private: // Session tracking @@ -119,10 +119,10 @@ namespace nvidia::riva::realtime { void HandleMessage(const std::string& message) override; public: - RecognitionClient( const std::string& objectName, + RealtimeClient( const std::string& objectName, const std::shared_ptr audioChunksPtr, nvidia::riva::utils::PerformanceStats& perfCounter); - ~RecognitionClient() = default; + ~RealtimeClient() = default; void Log(const std::string& message); @@ -156,4 +156,4 @@ namespace nvidia::riva::realtime { } // namespace nvidia::riva::realtime -#endif // RECOGNITION_CLIENT_H \ No newline at end of file +#endif // REALTIME_CLIENT_H \ No newline at end of file diff --git a/riva/clients/realtime/riva_realtime_asr_client.cc b/riva/clients/realtime/riva_realtime_asr_client.cc index e4fa74d..853461e 100644 --- a/riva/clients/realtime/riva_realtime_asr_client.cc +++ b/riva/clients/realtime/riva_realtime_asr_client.cc @@ -18,7 +18,7 @@ #include #include #include -#include "riva/clients/realtime/recognition_client.h" +#include "riva/clients/realtime/realtime_client.h" #include "riva/utils/stats_builder/stats_builder.h" #include #include @@ -85,7 +85,7 @@ DEFINE_bool(show_detailed_stats, true, "Show detailed statistics"); DEFINE_bool(show_tabular_stats, true, "Show tabular statistics"); // Global client pointer for signal handling -std::vector g_clients; +std::vector g_clients; std::mutex g_clients_mutex; // Signal handler for graceful shutdown @@ -117,7 +117,7 @@ void client_runner( const std::string& uri, const std::size_t chunkDelayTimeInMs, const bool simulateRealtime = false) { - nvidia::riva::realtime::RecognitionClient client(perfCounter.GetObjectName(), audio_chunks, perfCounter); + nvidia::riva::realtime::RealtimeClient client(perfCounter.GetObjectName(), audio_chunks, perfCounter); // Extract server URL from URI (remove ws:// and path) std::string server_url = uri; From 
0e68f6d9244455ab73ef8a11ed04168340b8c2e9 Mon Sep 17 00:00:00 2001 From: Yash Hayaran Date: Mon, 11 Aug 2025 16:00:30 +0530 Subject: [PATCH 6/6] adding microphone support --- riva/clients/realtime/audio_chunks.cpp | 585 ++++++-- riva/clients/realtime/audio_chunks.h | 216 ++- riva/clients/realtime/base_client.cpp | 353 ++--- riva/clients/realtime/base_client.h | 137 +- riva/clients/realtime/realtime_client.cpp | 1268 +++++++++-------- riva/clients/realtime/realtime_client.h | 262 ++-- .../realtime/riva_realtime_asr_client.cc | 870 ++++++----- riva/utils/stats_builder/stats_builder.cpp | 498 ++++--- riva/utils/stats_builder/stats_builder.h | 247 ++-- 9 files changed, 2595 insertions(+), 1841 deletions(-) diff --git a/riva/clients/realtime/audio_chunks.cpp b/riva/clients/realtime/audio_chunks.cpp index 98bdc96..3a36b7f 100644 --- a/riva/clients/realtime/audio_chunks.cpp +++ b/riva/clients/realtime/audio_chunks.cpp @@ -4,157 +4,508 @@ */ #include "audio_chunks.h" -#include "riva/utils/wav/wav_reader.h" -#include "riva/utils/wav/wav_data.h" + +#include + +#include +#include #include #include -#include -#include +#include #include +#include #include -#include -#include -#include +#include #include -nvidia::riva::realtime::AudioChunks::AudioChunks(const std::string& filepath, const int& chunk_size_ms) - : filepath_(filepath), chunk_size_ms_(chunk_size_ms) { -} +#include "riva/utils/wav/wav_data.h" +#include "riva/utils/wav/wav_reader.h" -void nvidia::riva::realtime::AudioChunks::CalculateChunkSizeBytes() { - chunk_size_bytes_ = (GetSampleRateHz() * GetChunkSizeMs() / 1000) * sizeof(int16_t); - std::cout << "[AudioChunks] Calculated chunk size: " << chunk_size_bytes_ << " bytes" << std::endl; +namespace nvidia::riva::realtime { + +// ============================================================================ +// Base AudioChunks class implementation +// ============================================================================ + +AudioChunks::AudioChunks(const int& 
chunk_size_ms) : chunk_size_ms_(chunk_size_ms) {} + +void +AudioChunks::CalculateChunkSizeBytes(int sample_rate) +{ + chunk_size_bytes_ = (sample_rate * chunk_size_ms_ / 1000) * sizeof(int16_t); + std::cout << "[AudioChunks] Calculated chunk size: " << chunk_size_bytes_ << " bytes" + << std::endl; } -void nvidia::riva::realtime::AudioChunks::SplitIntoChunks() { - const std::vector& raw_data = wav_data_->data; - size_t total_size = raw_data.size(); - - std::cout << "[AudioChunks] Splitting WAV file into chunks of " << chunk_size_bytes_ << " bytes" << std::endl; - - chunk_base64s_.clear(); - for (size_t i = 0; i < total_size; i += chunk_size_bytes_) { - size_t current_chunk_size = std::min(chunk_size_bytes_, total_size - i); - std::vector chunk(raw_data.begin() + i, raw_data.begin() + i + current_chunk_size); - std::string chunk_base64 = EncodeBase64(chunk); - chunk_base64s_.push_back(chunk_base64); +std::string +AudioChunks::EncodeBase64(const std::vector& data) +{ + const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + + std::string result; + int val = 0, valb = -6; + + for (unsigned char c : data) { + val = (val << 8) + c; + valb += 8; + while (valb >= 0) { + result.push_back(base64_chars[(val >> valb) & 0x3F]); + valb -= 6; } + } + + if (valb > -6) { + result.push_back(base64_chars[((val << 8) >> (valb + 8)) & 0x3F]); + } + + while (result.size() % 4) { + result.push_back('='); + } + + return result; } -std::string nvidia::riva::realtime::AudioChunks::EncodeBase64(const std::vector& data) { - const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; - - std::string result; - int val = 0, valb = -6; - - for (unsigned char c : data) { - val = (val << 8) + c; - valb += 8; - while (valb >= 0) { - result.push_back(base64_chars[(val >> valb) & 0x3F]); - valb -= 6; - } - } - - if (valb > -6) { - result.push_back(base64_chars[((val << 8) >> (valb + 8)) & 
0x3F]); - } - - while (result.size() % 4) { - result.push_back('='); - } - - return result; +bool +AudioChunks::Init() +{ + if (initialized_) { + std::cout << "[AudioChunks] Chunks already initialized" << std::endl; + return true; + } + + std::cout << "[AudioChunks] Initializing audio chunks..." << std::endl; + + if (!InitializeAudio()) { + std::cerr << "[AudioChunks] Error: Failed to initialize audio" << std::endl; + return false; + } + + ProcessAudioData(); + + initialized_ = true; + std::cout << "[AudioChunks] Successfully initialized with " << chunk_base64s_.size() << " chunks" + << std::endl; + + return initialized_; } -bool nvidia::riva::realtime::AudioChunks::Init() { - if (initialized_) { - std::cout << "[AudioChunks] Chunks already initialized" << std::endl; - return true; - } +// Getter implementations +size_t +AudioChunks::GetChunkSizeMs() const +{ + return chunk_size_ms_; +} - std::cout << "[AudioChunks] Initializing chunks for file: " << filepath_ << std::endl; - fs::path path(filepath_); - std::string extension = path.extension().string(); - - // File exists - if (!fs::exists(filepath_)) { - std::cerr << "[AudioChunks] Error: File does not exist, " << filepath_ << std::endl; - return false; - } +size_t +AudioChunks::GetChunkSizeBytes() const +{ + return chunk_size_bytes_; +} + +bool +AudioChunks::IsInitialized() const +{ + return initialized_; +} + +const std::vector& +AudioChunks::GetChunkBase64s() const +{ + return chunk_base64s_; +} + +// ============================================================================ +// FileAudioChunks derived class implementation +// ============================================================================ + +FileAudioChunks::FileAudioChunks(const std::string& filepath, const int& chunk_size_ms) + : AudioChunks(chunk_size_ms), filepath_(filepath) +{ +} + +void +FileAudioChunks::SplitIntoChunks() +{ + const std::vector& raw_data = wav_data_->data; + size_t total_size = raw_data.size(); + + std::cout << 
"[FileAudioChunks] Splitting WAV file into chunks of " << chunk_size_bytes_ + << " bytes" << std::endl; - // File is a WAV file - if (extension != ".wav") { - std::cerr << "[AudioChunks] Error: File is not a WAV file, " << filepath_ << std::endl; - return false; + chunk_base64s_.clear(); + for (size_t i = 0; i < total_size; i += chunk_size_bytes_) { + size_t current_chunk_size = std::min(chunk_size_bytes_, total_size - i); + std::vector chunk(raw_data.begin() + i, raw_data.begin() + i + current_chunk_size); + std::string chunk_base64 = EncodeBase64(chunk); + chunk_base64s_.push_back(chunk_base64); + } +} + +bool +FileAudioChunks::InitializeAudio() +{ + std::cout << "[FileAudioChunks] Initializing file audio for: " << filepath_ << std::endl; + fs::path path(filepath_); + std::string extension = path.extension().string(); + + // File exists + if (!fs::exists(filepath_)) { + std::cerr << "[FileAudioChunks] Error: File does not exist, " << filepath_ << std::endl; + return false; + } + + // File is a WAV file + if (extension != ".wav") { + std::cerr << "[FileAudioChunks] Error: File is not a WAV file, " << filepath_ << std::endl; + return false; + } + + // Load WAV file using the existing WAV utilities + std::vector> all_wav; + LoadWavData(all_wav, filepath_); + + if (all_wav.empty()) { + std::cerr << "[FileAudioChunks] Error: Failed to load WAV file, " << filepath_ << std::endl; + return false; + } + + wav_data_ = all_wav[0]; // Use the first WAV file + + CalculateChunkSizeBytes(GetSampleRateHz()); + + return true; +} + +void +FileAudioChunks::ProcessAudioData() +{ + SplitIntoChunks(); +} + +// FileAudioChunks getter implementations +std::string +FileAudioChunks::GetFilepath() const +{ + return filepath_; +} + +int +FileAudioChunks::GetSampleRateHz() const +{ + return wav_data_->sample_rate; +} + +int +FileAudioChunks::GetNumChannels() const +{ + return wav_data_->channels; +} + +int +FileAudioChunks::GetBitDepth() const +{ + // Calculate bit depth from data size and 
sample rate + if (wav_data_->channels > 0 && wav_data_->sample_rate > 0) { + return (wav_data_->data.size() * 8) / (wav_data_->channels * wav_data_->sample_rate); + } + return 16; // Default to 16-bit +} + +double +FileAudioChunks::GetDurationSeconds() const +{ + if (wav_data_->sample_rate > 0 && wav_data_->channels > 0) { + return static_cast(wav_data_->data.size()) / + (wav_data_->sample_rate * wav_data_->channels * 2); // Assuming 16-bit + } + return 0.0; +} + +int +FileAudioChunks::GetNumSamples() const +{ + if (wav_data_->channels > 0) { + return wav_data_->data.size() / (wav_data_->channels * 2); // Assuming 16-bit + } + return 0; +} + +// ============================================================================ +// MicrophoneChunks derived class implementation +// ============================================================================ + +MicrophoneChunks::MicrophoneChunks( + const std::string& device_name, const int& chunk_size_ms, int sample_rate, int num_channels, + int bit_depth) + : AudioChunks(chunk_size_ms), device_name_(device_name), alsa_handle_(nullptr), + sample_rate_(sample_rate), num_channels_(num_channels), bit_depth_(bit_depth), + is_capturing_(false), request_exit_(false) +{ +} + +MicrophoneChunks::~MicrophoneChunks() +{ + StopCapture(); + CloseAudioDevice(); +} + +bool +MicrophoneChunks::OpenAudioDevice() +{ + int rc; + static snd_output_t* log; + + std::cout << "[MicrophoneChunks] Opening ALSA device: " << device_name_ << std::endl; + std::cout << "[MicrophoneChunks] Sample rate: " << sample_rate_ + << " Hz, Channels: " << num_channels_ << std::endl; + + if ((rc = snd_pcm_open(&alsa_handle_, device_name_.c_str(), SND_PCM_STREAM_CAPTURE, 0)) < 0) { + std::cerr << "[MicrophoneChunks] Unable to open PCM device for recording: " << snd_strerror(rc) + << std::endl; + return false; + } + + if ((rc = snd_output_stdio_attach(&log, stderr, 0)) < 0) { + std::cerr << "[MicrophoneChunks] Unable to attach log output: " << snd_strerror(rc) + << 
std::endl; + return false; + } + + // Set audio parameters + snd_pcm_format_t format = (bit_depth_ == 16) ? SND_PCM_FORMAT_S16_LE : SND_PCM_FORMAT_S32_LE; + unsigned int latency = 100000; // 100ms latency + + if ((rc = snd_pcm_set_params( + alsa_handle_, format, SND_PCM_ACCESS_RW_INTERLEAVED, num_channels_, sample_rate_, 1, + latency)) < 0) { + std::cerr << "[MicrophoneChunks] snd_pcm_set_params error: " << snd_strerror(rc) << std::endl; + return false; + } + + // Set software parameters for capture + snd_pcm_sw_params_t* sw_params = nullptr; + if ((rc = snd_pcm_sw_params_malloc(&sw_params)) < 0) { + std::cerr << "[MicrophoneChunks] snd_pcm_sw_params_malloc error: " << snd_strerror(rc) + << std::endl; + return false; + } + + if ((rc = snd_pcm_sw_params_current(alsa_handle_, sw_params)) < 0) { + std::cerr << "[MicrophoneChunks] snd_pcm_sw_params_current error: " << snd_strerror(rc) + << std::endl; + snd_pcm_sw_params_free(sw_params); + return false; + } + + if ((rc = snd_pcm_sw_params_set_start_threshold(alsa_handle_, sw_params, 1)) < 0) { + std::cerr << "[MicrophoneChunks] snd_pcm_sw_params_set_start_threshold failed: " + << snd_strerror(rc) << std::endl; + snd_pcm_sw_params_free(sw_params); + return false; + } + + if ((rc = snd_pcm_sw_params(alsa_handle_, sw_params)) < 0) { + std::cerr << "[MicrophoneChunks] snd_pcm_sw_params failed: " << snd_strerror(rc) << std::endl; + snd_pcm_sw_params_free(sw_params); + return false; + } + + snd_pcm_sw_params_free(sw_params); + + std::cout << "[MicrophoneChunks] Successfully opened ALSA device" << std::endl; + return true; +} + +void +MicrophoneChunks::CloseAudioDevice() +{ + if (alsa_handle_) { + snd_pcm_close(alsa_handle_); + alsa_handle_ = nullptr; + std::cout << "[MicrophoneChunks] Closed ALSA device" << std::endl; + } +} + +bool +MicrophoneChunks::InitializeAudio() +{ + std::cout << "[MicrophoneChunks] Initializing microphone audio for device: " << device_name_ + << std::endl; + + if (!OpenAudioDevice()) { + std::cerr << 
"[MicrophoneChunks] Error: Failed to open audio device" << std::endl; + return false; + } + + CalculateChunkSizeBytes(sample_rate_); + + return true; +} + +void +MicrophoneChunks::ProcessAudioData() +{ + // For microphone, we don't pre-process data - it comes in real-time + // This method is called during Init() but doesn't populate chunks initially + std::cout << "[MicrophoneChunks] Microphone initialized, ready for capture" << std::endl; +} + +bool +MicrophoneChunks::StartCapture() +{ + if (is_capturing_) { + std::cout << "[MicrophoneChunks] Already capturing audio" << std::endl; + return true; + } + + if (!initialized_) { + std::cerr << "[MicrophoneChunks] Error: Microphone not initialized" << std::endl; + return false; + } + + request_exit_ = false; + is_capturing_ = true; + + // Start capture thread + capture_thread_ = std::thread(&MicrophoneChunks::CaptureThreadMain, this); + + std::cout << "[MicrophoneChunks] Started audio capture" << std::endl; + return true; +} + +void +MicrophoneChunks::StopCapture() +{ + if (!is_capturing_) { + return; + } + + request_exit_ = true; + is_capturing_ = false; + + if (capture_thread_.joinable()) { + capture_thread_.join(); + } + + std::cout << "[MicrophoneChunks] Stopped audio capture" << std::endl; +} + +void +MicrophoneChunks::CaptureThreadMain() +{ + std::cout << "[MicrophoneChunks] Capture thread started" << std::endl; + + const size_t chunk_size = chunk_size_bytes_; + std::vector chunk(chunk_size); + + while (is_capturing_ && !request_exit_) { + // Read audio chunk from microphone + snd_pcm_sframes_t frames_read = + snd_pcm_readi(alsa_handle_, &chunk[0], chunk_size / sizeof(int16_t)); + + if (frames_read < 0) { + std::cerr << "[MicrophoneChunks] Read failed: " << snd_strerror(frames_read) << std::endl; + // Try to recover from error + if (snd_pcm_recover(alsa_handle_, frames_read, 0) < 0) { + std::cerr << "[MicrophoneChunks] Failed to recover from error" << std::endl; + break; + } + continue; } - - // Load WAV file 
using the existing WAV utilities - std::vector> all_wav; - LoadWavData(all_wav, filepath_); - - if (all_wav.empty()) { - std::cerr << "[AudioChunks] Error: Failed to load WAV file, " << filepath_ << std::endl; - return false; + + if (frames_read > 0) { + // Convert frames to bytes + size_t bytes_read = frames_read * sizeof(int16_t); + + // Create chunk with actual data read + std::vector actual_chunk(chunk.begin(), chunk.begin() + bytes_read); + std::string chunk_base64 = EncodeBase64(actual_chunk); + + // Add to chunks with thread safety + { + std::lock_guard lock(chunks_mutex_); + chunk_base64s_.push_back(chunk_base64); + + // Keep only last 100 chunks to prevent memory issues + if (chunk_base64s_.size() > 100) { + chunk_base64s_.erase(chunk_base64s_.begin()); + } + } + + // Notify waiting threads + chunks_cv_.notify_all(); + + std::cout << "[MicrophoneChunks] Captured chunk " << chunk_base64s_.size() << " (" + << bytes_read << " bytes)" << std::endl; } - - wav_data_ = all_wav[0]; // Use the first WAV file - - CalculateChunkSizeBytes(); - SplitIntoChunks(); - - initialized_ = true; - - return initialized_; + } + + std::cout << "[MicrophoneChunks] Capture thread ended" << std::endl; } -// Getter implementations -std::string nvidia::riva::realtime::AudioChunks::GetFilepath() const { - return filepath_; +// MicrophoneChunks getter implementations +std::string +MicrophoneChunks::GetDeviceName() const +{ + return device_name_; } -size_t nvidia::riva::realtime::AudioChunks::GetChunkSizeMs() const { - return chunk_size_ms_; +bool +MicrophoneChunks::IsCapturing() const +{ + return is_capturing_; } -size_t nvidia::riva::realtime::AudioChunks::GetChunkSizeBytes() const { - return chunk_size_bytes_; +int +MicrophoneChunks::GetSampleRateHz() const +{ + return sample_rate_; } -bool nvidia::riva::realtime::AudioChunks::IsInitialized() const { - return initialized_; +int +MicrophoneChunks::GetNumChannels() const +{ + return num_channels_; } -// WAV file properties -int 
nvidia::riva::realtime::AudioChunks::GetSampleRateHz() const { - return wav_data_->sample_rate; +int +MicrophoneChunks::GetBitDepth() const +{ + return bit_depth_; } -int nvidia::riva::realtime::AudioChunks::GetNumChannels() const { - return wav_data_->channels; +double +MicrophoneChunks::GetDurationSeconds() const +{ + // For microphone, duration is ongoing - return 0 + return 0.0; } -int nvidia::riva::realtime::AudioChunks::GetBitDepth() const { - // Calculate bit depth from data size and sample rate - if (wav_data_->channels > 0 && wav_data_->sample_rate > 0) { - return (wav_data_->data.size() * 8) / (wav_data_->channels * wav_data_->sample_rate); - } - return 16; // Default to 16-bit +int +MicrophoneChunks::GetNumSamples() const +{ + // For microphone, samples are ongoing - return 0 + return 0; } -double nvidia::riva::realtime::AudioChunks::GetDurationSeconds() const { - if (wav_data_->sample_rate > 0 && wav_data_->channels > 0) { - return static_cast(wav_data_->data.size()) / (wav_data_->sample_rate * wav_data_->channels * 2); // Assuming 16-bit - } - return 0.0; +std::string +MicrophoneChunks::GetLatestChunk() const +{ + std::lock_guard lock(chunks_mutex_); + if (chunk_base64s_.empty()) { + return ""; + } + return chunk_base64s_.back(); } -int nvidia::riva::realtime::AudioChunks::GetNumSamples() const { - if (wav_data_->channels > 0) { - return wav_data_->data.size() / (wav_data_->channels * 2); // Assuming 16-bit - } - return 0; +void +MicrophoneChunks::WaitForNewChunk() +{ + std::unique_lock lock(chunks_mutex_); + chunks_cv_.wait(lock, [this] { return !chunk_base64s_.empty(); }); } + +} // namespace nvidia::riva::realtime diff --git a/riva/clients/realtime/audio_chunks.h b/riva/clients/realtime/audio_chunks.h index e8f7647..fa15560 100644 --- a/riva/clients/realtime/audio_chunks.h +++ b/riva/clients/realtime/audio_chunks.h @@ -6,79 +6,161 @@ #ifndef AUDIO_CHUNKS_H #define AUDIO_CHUNKS_H +#include + +#include +#include +#include #include #include +#include 
#include +#include #include -#include -#include "riva/utils/wav/wav_reader.h" + #include "riva/utils/wav/wav_data.h" +#include "riva/utils/wav/wav_reader.h" namespace fs = std::filesystem; namespace nvidia::riva::realtime { - class AudioChunks { - private: - bool initialized_ = false; - std::string filepath_; - size_t chunk_size_ms_; - size_t chunk_size_bytes_; - std::shared_ptr wav_data_; - std::vector chunk_base64s_; - - void CalculateChunkSizeBytes(); - void SplitIntoChunks(); - std::string EncodeBase64(const std::vector& data); - - public: - AudioChunks(const std::string& filepath, const int& chunk_size_ms); - ~AudioChunks() = default; - - bool Init(); - - // Getters - std::string GetFilepath() const; - size_t GetChunkSizeMs() const; - size_t GetChunkSizeBytes() const; - bool IsInitialized() const; - - // WAV file properties - int GetSampleRateHz() const; - int GetNumChannels() const; - int GetBitDepth() const; - double GetDurationSeconds() const; - int GetNumSamples() const; - const std::vector& GetChunkBase64s() const; - - // Iterator support - using iterator = std::vector::iterator; - using const_iterator = std::vector::const_iterator; - using reverse_iterator = std::vector::reverse_iterator; - using const_reverse_iterator = std::vector::const_reverse_iterator; - - // Iterator methods - iterator begin() { return chunk_base64s_.begin(); } - const_iterator begin() const { return chunk_base64s_.begin(); } - iterator end() { return chunk_base64s_.end(); } - const_iterator end() const { return chunk_base64s_.end(); } - - // Reverse iterator methods - reverse_iterator rbegin() { return chunk_base64s_.rbegin(); } - const_reverse_iterator rbegin() const { return chunk_base64s_.rbegin(); } - reverse_iterator rend() { return chunk_base64s_.rend(); } - const_reverse_iterator rend() const { return chunk_base64s_.rend(); } - - // Const iterator methods - const_iterator cbegin() const { return chunk_base64s_.cbegin(); } - const_iterator cend() const { return 
chunk_base64s_.cend(); } - const_reverse_iterator crbegin() const { return chunk_base64s_.crbegin(); } - const_reverse_iterator crend() const { return chunk_base64s_.crend(); } - - // Size methods - size_t size() const { return chunk_base64s_.size(); } - bool empty() const { return chunk_base64s_.empty(); } - }; - -} // namespace nvidia::riva::realtime - -#endif // AUDIO_CHUNKS_H \ No newline at end of file + +// Forward declarations - we'll include the actual headers in the .cpp file +void LoadWavData(std::vector>& all_wav, const std::string& filepath); + +// Base class for audio input +class AudioChunks { + protected: + bool initialized_ = false; + size_t chunk_size_ms_; + size_t chunk_size_bytes_; + std::vector chunk_base64s_; + + // Common methods for derived classes + void CalculateChunkSizeBytes(int sample_rate); + std::string EncodeBase64(const std::vector& data); + + // Virtual methods for derived classes to implement + virtual bool InitializeAudio() = 0; + virtual void ProcessAudioData() = 0; + + public: + AudioChunks(const int& chunk_size_ms); + virtual ~AudioChunks() = default; + + bool Init(); + + // Getters + size_t GetChunkSizeMs() const; + size_t GetChunkSizeBytes() const; + bool IsInitialized() const; + + // Audio properties (to be implemented by derived classes) + virtual int GetSampleRateHz() const = 0; + virtual int GetNumChannels() const = 0; + virtual int GetBitDepth() const = 0; + virtual double GetDurationSeconds() const = 0; + virtual int GetNumSamples() const = 0; + const std::vector& GetChunkBase64s() const; + + // Iterator support + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + using reverse_iterator = std::vector::reverse_iterator; + using const_reverse_iterator = std::vector::const_reverse_iterator; + + // Iterator methods + iterator begin() { return chunk_base64s_.begin(); } + const_iterator begin() const { return chunk_base64s_.begin(); } + iterator end() { return chunk_base64s_.end(); 
} + const_iterator end() const { return chunk_base64s_.end(); } + + // Reverse iterator methods + reverse_iterator rbegin() { return chunk_base64s_.rbegin(); } + const_reverse_iterator rbegin() const { return chunk_base64s_.rbegin(); } + reverse_iterator rend() { return chunk_base64s_.rend(); } + const_reverse_iterator rend() const { return chunk_base64s_.rend(); } + + // Const iterator methods + const_iterator cbegin() const { return chunk_base64s_.cbegin(); } + const_iterator cend() const { return chunk_base64s_.cend(); } + const_reverse_iterator crbegin() const { return chunk_base64s_.crbegin(); } + const_reverse_iterator crend() const { return chunk_base64s_.crend(); } + + // Size methods + size_t size() const { return chunk_base64s_.size(); } + bool empty() const { return chunk_base64s_.empty(); } +}; + +// Derived class for file-based audio input +class FileAudioChunks : public AudioChunks { + private: + std::string filepath_; + std::shared_ptr wav_data_; + + void SplitIntoChunks(); + bool InitializeAudio() override; + void ProcessAudioData() override; + + public: + FileAudioChunks(const std::string& filepath, const int& chunk_size_ms); + ~FileAudioChunks() = default; + + std::string GetFilepath() const; + int GetSampleRateHz() const override; + int GetNumChannels() const override; + int GetBitDepth() const override; + double GetDurationSeconds() const override; + int GetNumSamples() const override; +}; + +// Derived class for microphone input +class MicrophoneChunks : public AudioChunks { + private: + std::string device_name_; + snd_pcm_t* alsa_handle_; + std::thread capture_thread_; + std::atomic is_capturing_; + std::atomic request_exit_; + mutable std::mutex chunks_mutex_; // Make mutable for const member functions + std::condition_variable chunks_cv_; + + // Audio capture parameters + int sample_rate_; + int num_channels_; + int bit_depth_; + + // Capture thread function + void CaptureThreadMain(); + bool OpenAudioDevice(); + void CloseAudioDevice(); + 
bool InitializeAudio() override; + void ProcessAudioData() override; + + public: + MicrophoneChunks( + const std::string& device_name, const int& chunk_size_ms, int sample_rate = 16000, + int num_channels = 1, int bit_depth = 16); + ~MicrophoneChunks(); + + // Microphone-specific methods + bool StartCapture(); + void StopCapture(); + bool IsCapturing() const; + std::string GetDeviceName() const; + + // Audio properties + int GetSampleRateHz() const override; + int GetNumChannels() const override; + int GetBitDepth() const override; + double GetDurationSeconds() const override; + int GetNumSamples() const override; + + // Real-time chunk access + std::string GetLatestChunk() const; + void WaitForNewChunk(); +}; + +} // namespace nvidia::riva::realtime + +#endif // AUDIO_CHUNKS_H \ No newline at end of file diff --git a/riva/clients/realtime/base_client.cpp b/riva/clients/realtime/base_client.cpp index ea9928a..bdeab1b 100644 --- a/riva/clients/realtime/base_client.cpp +++ b/riva/clients/realtime/base_client.cpp @@ -4,188 +4,223 @@ */ #include "base_client.h" -#include + #include +#include nvidia::riva::realtime::WebSocketClientBase::WebSocketClientBase(const std::string& uri) - : connected_(false), - connectionClosedByServer_(false), - connectionTimeoutMs_(std::size_t(5000)), - uri_(uri) { - - // Set up logging - suppress verbose internal messages + : connected_(false), connectionClosedByServer_(false), connectionTimeoutMs_(std::size_t(5000)), + uri_(uri) +{ + // Set up logging - suppress verbose internal messages + wsClient_.set_access_channels(websocketpp::log::alevel::connect); + wsClient_.set_access_channels(websocketpp::log::alevel::disconnect); + wsClient_.set_access_channels(websocketpp::log::alevel::fail); + wsClient_.set_access_channels(websocketpp::log::alevel::app); + + // Initialize ASIO + wsClient_.init_asio(); + + // Set up handlers + wsClient_.set_open_handler( + std::bind(&nvidia::riva::realtime::WebSocketClientBase::OnOpen, this, 
std::placeholders::_1)); + wsClient_.set_close_handler(std::bind( + &nvidia::riva::realtime::WebSocketClientBase::OnClose, this, std::placeholders::_1)); + wsClient_.set_fail_handler( + std::bind(&nvidia::riva::realtime::WebSocketClientBase::OnFail, this, std::placeholders::_1)); + wsClient_.set_message_handler(std::bind( + &nvidia::riva::realtime::WebSocketClientBase::OnMessage, this, std::placeholders::_1, + std::placeholders::_2)); +} + +void +nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout( + const std::size_t connectionTimeoutMs) +{ + connectionTimeoutMs_ = connectionTimeoutMs; +} + +std::size_t +nvidia::riva::realtime::WebSocketClientBase::GetConnectionTimeout() +{ + return connectionTimeoutMs_; +} + +void +nvidia::riva::realtime::WebSocketClientBase::SetVerboseLogging(bool verbose) +{ + if (verbose) { + // Enable all logging channels + wsClient_.set_access_channels(websocketpp::log::alevel::all); + wsClient_.clear_access_channels(websocketpp::log::alevel::frame_payload); + } else { + // Minimal logging - only important events + wsClient_.clear_access_channels(websocketpp::log::alevel::all); wsClient_.set_access_channels(websocketpp::log::alevel::connect); wsClient_.set_access_channels(websocketpp::log::alevel::disconnect); wsClient_.set_access_channels(websocketpp::log::alevel::fail); wsClient_.set_access_channels(websocketpp::log::alevel::app); - - // Initialize ASIO - wsClient_.init_asio(); - - // Set up handlers - wsClient_.set_open_handler( - std::bind(&nvidia::riva::realtime::WebSocketClientBase::OnOpen, this, std::placeholders::_1)); - wsClient_.set_close_handler( - std::bind(&nvidia::riva::realtime::WebSocketClientBase::OnClose, this, std::placeholders::_1)); - wsClient_.set_fail_handler( - std::bind(&nvidia::riva::realtime::WebSocketClientBase::OnFail, this, std::placeholders::_1)); - wsClient_.set_message_handler( - std::bind(&nvidia::riva::realtime::WebSocketClientBase::OnMessage, this, std::placeholders::_1, 
std::placeholders::_2)); -} - -void nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(const std::size_t connectionTimeoutMs) { - connectionTimeoutMs_ = connectionTimeoutMs; -} - -std::size_t nvidia::riva::realtime::WebSocketClientBase::GetConnectionTimeout() { - return connectionTimeoutMs_; -} - -void nvidia::riva::realtime::WebSocketClientBase::SetVerboseLogging(bool verbose) { - if (verbose) { - // Enable all logging channels - wsClient_.set_access_channels(websocketpp::log::alevel::all); - wsClient_.clear_access_channels(websocketpp::log::alevel::frame_payload); - } else { - // Minimal logging - only important events - wsClient_.clear_access_channels(websocketpp::log::alevel::all); - wsClient_.set_access_channels(websocketpp::log::alevel::connect); - wsClient_.set_access_channels(websocketpp::log::alevel::disconnect); - wsClient_.set_access_channels(websocketpp::log::alevel::fail); - wsClient_.set_access_channels(websocketpp::log::alevel::app); - } + } } -void nvidia::riva::realtime::WebSocketClientBase::Connect(const std::string& uri) { - uri_ = uri; - websocketpp::lib::error_code ec; - - websocketpp_client::connection_ptr con = wsClient_.get_connection(uri, ec); - if (ec) { - std::cerr << "Could not create connection: " << ec.message() << std::endl; - return; - } - - wsClient_.connect(con); -} +void +nvidia::riva::realtime::WebSocketClientBase::Connect(const std::string& uri) +{ + uri_ = uri; + websocketpp::lib::error_code ec; -void nvidia::riva::realtime::WebSocketClientBase::Run() { - wsClient_.run(); -} + websocketpp_client::connection_ptr con = wsClient_.get_connection(uri, ec); + if (ec) { + std::cerr << "Could not create connection: " << ec.message() << std::endl; + return; + } -void nvidia::riva::realtime::WebSocketClientBase::Send(const std::string& message) { - std::lock_guard lock(mutex_); - if (connected_) { - websocketpp::lib::error_code ec; - wsClient_.send(connectionHdl_, message, websocketpp::frame::opcode::text, ec); - if (ec) { 
- std::cerr << "Send failed: " << ec.message() << std::endl; - } - } + wsClient_.connect(con); } -void nvidia::riva::realtime::WebSocketClientBase::Close() { - std::lock_guard lock(mutex_); - if (connected_) { - websocketpp::lib::error_code ec; - wsClient_.close(connectionHdl_, websocketpp::close::status::normal, "Client closing", ec); - } +void +nvidia::riva::realtime::WebSocketClientBase::Run() +{ + wsClient_.run(); } -void nvidia::riva::realtime::WebSocketClientBase::SendJsonMessage(const std::string& type, const std::string& data) { - std::lock_guard lock(mutex_); - if (connected_) { - rapidjson::Document doc; - doc.SetObject(); - rapidjson::Document::AllocatorType& allocator = doc.GetAllocator(); - - doc.AddMember("type", rapidjson::Value(type.c_str(), allocator), allocator); - if (!data.empty()) { - doc.AddMember("data", rapidjson::Value(data.c_str(), allocator), allocator); - } - - rapidjson::StringBuffer buffer; - rapidjson::Writer writer(buffer); - doc.Accept(writer); - - websocketpp::lib::error_code ec; - wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec); - if (ec) { - std::cerr << "Send failed: " << ec.message() << std::endl; - } else { - std::cout << "Sent: " << buffer.GetString() << std::endl; - } +void +nvidia::riva::realtime::WebSocketClientBase::Send(const std::string& message) +{ + std::lock_guard lock(mutex_); + if (connected_) { + websocketpp::lib::error_code ec; + wsClient_.send(connectionHdl_, message, websocketpp::frame::opcode::text, ec); + if (ec) { + std::cerr << "Send failed: " << ec.message() << std::endl; } + } } -void nvidia::riva::realtime::WebSocketClientBase::OnOpen(websocketpp::connection_hdl hdl) { - std::lock_guard lock(mutex_); - connectionHdl_ = hdl; - connected_ = true; - - // Notify waiting threads that connection is established - { - std::lock_guard conn_lock(connectionMutex_); - connectionCv_.notify_one(); - } - - std::cout << "Connected to " << uri_ << std::endl; -} - -void 
nvidia::riva::realtime::WebSocketClientBase::OnClose(websocketpp::connection_hdl hdl) { - (void)hdl; // Suppress unused parameter warning - std::lock_guard lock(mutex_); - connected_ = false; - - // Check if this was a server-initiated close - { - std::lock_guard conn_lock(connectionMutex_); - connectionClosedByServer_ = true; - } - connectionCv_.notify_one(); - - std::cout << "Connection closed" << std::endl; -} - -void nvidia::riva::realtime::WebSocketClientBase::OnFail(websocketpp::connection_hdl hdl) { - (void)hdl; // Suppress unused parameter warning - std::lock_guard lock(mutex_); - connected_ = false; - - // Mark as server-initiated failure - { - std::lock_guard conn_lock(connectionMutex_); - connectionClosedByServer_ = true; +void +nvidia::riva::realtime::WebSocketClientBase::Close() +{ + std::lock_guard lock(mutex_); + if (connected_) { + websocketpp::lib::error_code ec; + wsClient_.close(connectionHdl_, websocketpp::close::status::normal, "Client closing", ec); + } +} + +void +nvidia::riva::realtime::WebSocketClientBase::SendJsonMessage( + const std::string& type, const std::string& data) +{ + std::lock_guard lock(mutex_); + if (connected_) { + rapidjson::Document doc; + doc.SetObject(); + rapidjson::Document::AllocatorType& allocator = doc.GetAllocator(); + + doc.AddMember("type", rapidjson::Value(type.c_str(), allocator), allocator); + if (!data.empty()) { + doc.AddMember("data", rapidjson::Value(data.c_str(), allocator), allocator); } - connectionCv_.notify_one(); - - std::cout << "************************ Connection failed" << std::endl; -} -void nvidia::riva::realtime::WebSocketClientBase::OnMessage(websocketpp::connection_hdl hdl, message_ptr msg) { - (void)hdl; // Suppress unused parameter warning - HandleMessage(msg->get_payload()); -} + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc.Accept(writer); -bool nvidia::riva::realtime::WebSocketClientBase::WaitForConnection() { - std::unique_lock lock(connectionMutex_); - 
return connectionCv_.wait_for(lock, - std::chrono::milliseconds(connectionTimeoutMs_), // Use the provided timeout - [this] { return connected_; }); + websocketpp::lib::error_code ec; + wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec); + if (ec) { + std::cerr << "Send failed: " << ec.message() << std::endl; + } else { + std::cout << "Sent: " << buffer.GetString() << std::endl; + } + } } -bool nvidia::riva::realtime::WebSocketClientBase::WaitForDisconnection() { - std::unique_lock lock(connectionMutex_); - return connectionCv_.wait_for(lock, - std::chrono::milliseconds(connectionTimeoutMs_), // Use the provided timeout - [this] { return !connected_; }); -} +void +nvidia::riva::realtime::WebSocketClientBase::OnOpen(websocketpp::connection_hdl hdl) +{ + std::lock_guard lock(mutex_); + connectionHdl_ = hdl; + connected_ = true; -bool nvidia::riva::realtime::WebSocketClientBase::WaitForServerClose() { - std::unique_lock lock(connectionMutex_); - return connectionCv_.wait_for(lock, - std::chrono::milliseconds(connectionTimeoutMs_), // Use the provided timeout - [this] { return connectionClosedByServer_; }); + // Notify waiting threads that connection is established + { + std::lock_guard conn_lock(connectionMutex_); + connectionCv_.notify_one(); + } + + std::cout << "Connected to " << uri_ << std::endl; +} + +void +nvidia::riva::realtime::WebSocketClientBase::OnClose(websocketpp::connection_hdl hdl) +{ + (void)hdl; // Suppress unused parameter warning + std::lock_guard lock(mutex_); + connected_ = false; + + // Check if this was a server-initiated close + { + std::lock_guard conn_lock(connectionMutex_); + connectionClosedByServer_ = true; + } + connectionCv_.notify_one(); + + std::cout << "Connection closed" << std::endl; +} + +void +nvidia::riva::realtime::WebSocketClientBase::OnFail(websocketpp::connection_hdl hdl) +{ + (void)hdl; // Suppress unused parameter warning + std::lock_guard lock(mutex_); + connected_ = false; + + // 
Mark as server-initiated failure + { + std::lock_guard conn_lock(connectionMutex_); + connectionClosedByServer_ = true; + } + connectionCv_.notify_one(); + + std::cout << "************************ Connection failed" << std::endl; +} + +void +nvidia::riva::realtime::WebSocketClientBase::OnMessage( + websocketpp::connection_hdl hdl, message_ptr msg) +{ + (void)hdl; // Suppress unused parameter warning + HandleMessage(msg->get_payload()); +} + +bool +nvidia::riva::realtime::WebSocketClientBase::WaitForConnection() +{ + std::unique_lock lock(connectionMutex_); + return connectionCv_.wait_for( + lock, + std::chrono::milliseconds(connectionTimeoutMs_), // Use the provided timeout + [this] { return connected_; }); +} + +bool +nvidia::riva::realtime::WebSocketClientBase::WaitForDisconnection() +{ + std::unique_lock lock(connectionMutex_); + return connectionCv_.wait_for( + lock, + std::chrono::milliseconds(connectionTimeoutMs_), // Use the provided timeout + [this] { return !connected_; }); +} + +bool +nvidia::riva::realtime::WebSocketClientBase::WaitForServerClose() +{ + std::unique_lock lock(connectionMutex_); + return connectionCv_.wait_for( + lock, + std::chrono::milliseconds(connectionTimeoutMs_), // Use the provided timeout + [this] { return connectionClosedByServer_; }); } - diff --git a/riva/clients/realtime/base_client.h b/riva/clients/realtime/base_client.h index 6ebb3db..ec0ec8c 100644 --- a/riva/clients/realtime/base_client.h +++ b/riva/clients/realtime/base_client.h @@ -6,81 +6,82 @@ #ifndef BASE_REALTIME_CLIENT_H #define BASE_REALTIME_CLIENT_H -#include -#include #include -#include #include +#include +#include +#include +#include #include +#include #include #include -#include -#include -#include #include -#include +#include +#include + #include "audio_chunks.h" namespace nvidia::riva::realtime { - class WebSocketClientBase { - protected: - typedef websocketpp::client websocketpp_client; - typedef websocketpp::config::asio_client::message_type::ptr 
message_ptr; - - websocketpp_client wsClient_; - websocketpp::connection_hdl connectionHdl_; - - std::string uri_; - bool connected_; - std::mutex mutex_; - - // Connection state - bool connectionClosedByServer_; - std::condition_variable connectionCv_; - std::mutex connectionMutex_; - std::size_t connectionTimeoutMs_; - - // Protected access to websocket client for derived classes - websocketpp_client& GetWsClient() { return wsClient_; } - websocketpp::connection_hdl& GetConnection() { return connectionHdl_; } - std::mutex& GetConnectionMutex() { return connectionMutex_; } - - public: - WebSocketClientBase(const std::string& uri); - ~WebSocketClientBase() = default; - - // Connection timeout - void SetConnectionTimeout(const std::size_t connectionTimeoutMs); - std::size_t GetConnectionTimeout(); - - // Connection status - bool IsConnected() const { return connected_; } - bool IsConnectionClosedByServer() const { return connectionClosedByServer_; } - bool IsConnectionOpen() const { return connected_ && !connectionClosedByServer_; } - bool IsConnectionClosed() const { return !connected_ || connectionClosedByServer_; } - - // Control logging verbosity - void SetVerboseLogging(bool verbose); - - // Connection management - void Connect(const std::string& uri); - void Run(); - void Send(const std::string& message); - void Close(); - void SendJsonMessage(const std::string& type, const std::string& data = ""); - - // Connection waiting methods - bool WaitForConnection(); - bool WaitForDisconnection(); - bool WaitForServerClose(); - - // Event handlers - void OnOpen(websocketpp::connection_hdl hdl); - void OnClose(websocketpp::connection_hdl hdl); - void OnFail(websocketpp::connection_hdl hdl); - void OnMessage(websocketpp::connection_hdl hdl, message_ptr msg); - virtual void HandleMessage(const std::string& message) = 0; - }; -} // namespace nvidia::riva::realtime -#endif // BASE_REALTIME_CLIENT_H \ No newline at end of file +class WebSocketClientBase { + protected: + 
typedef websocketpp::client websocketpp_client; + typedef websocketpp::config::asio_client::message_type::ptr message_ptr; + + websocketpp_client wsClient_; + websocketpp::connection_hdl connectionHdl_; + + std::string uri_; + bool connected_; + std::mutex mutex_; + + // Connection state + bool connectionClosedByServer_; + std::condition_variable connectionCv_; + std::mutex connectionMutex_; + std::size_t connectionTimeoutMs_; + + // Protected access to websocket client for derived classes + websocketpp_client& GetWsClient() { return wsClient_; } + websocketpp::connection_hdl& GetConnection() { return connectionHdl_; } + std::mutex& GetConnectionMutex() { return connectionMutex_; } + + public: + WebSocketClientBase(const std::string& uri); + ~WebSocketClientBase() = default; + + // Connection timeout + void SetConnectionTimeout(const std::size_t connectionTimeoutMs); + std::size_t GetConnectionTimeout(); + + // Connection status + bool IsConnected() const { return connected_; } + bool IsConnectionClosedByServer() const { return connectionClosedByServer_; } + bool IsConnectionOpen() const { return connected_ && !connectionClosedByServer_; } + bool IsConnectionClosed() const { return !connected_ || connectionClosedByServer_; } + + // Control logging verbosity + void SetVerboseLogging(bool verbose); + + // Connection management + void Connect(const std::string& uri); + void Run(); + void Send(const std::string& message); + void Close(); + void SendJsonMessage(const std::string& type, const std::string& data = ""); + + // Connection waiting methods + bool WaitForConnection(); + bool WaitForDisconnection(); + bool WaitForServerClose(); + + // Event handlers + void OnOpen(websocketpp::connection_hdl hdl); + void OnClose(websocketpp::connection_hdl hdl); + void OnFail(websocketpp::connection_hdl hdl); + void OnMessage(websocketpp::connection_hdl hdl, message_ptr msg); + virtual void HandleMessage(const std::string& message) = 0; +}; +} // namespace nvidia::riva::realtime 
+#endif // BASE_REALTIME_CLIENT_H \ No newline at end of file diff --git a/riva/clients/realtime/realtime_client.cpp b/riva/clients/realtime/realtime_client.cpp index 94d4953..4c583a0 100644 --- a/riva/clients/realtime/realtime_client.cpp +++ b/riva/clients/realtime/realtime_client.cpp @@ -4,687 +4,717 @@ */ #include "realtime_client.h" -#include "base_client.h" -#include -#include -#include -#include -#include -#include + #include -#include #include +#include +#include +#include + +#include #include -#include -#include #include +#include +#include +#include +#include +#include + +#include "base_client.h" // Helper method for HTTP requests using raw sockets -std::string nvidia::riva::realtime::RealtimeClient::MakeHttpRequest(const std::string& host, int port, const std::string& path, const std::string& method, const std::string& body) { - int sock = socket(AF_INET, SOCK_STREAM, 0); - if (sock < 0) { - std::cerr << "Failed to create socket" << std::endl; - return ""; +std::string +nvidia::riva::realtime::RealtimeClient::MakeHttpRequest( + const std::string& host, int port, const std::string& path, const std::string& method, + const std::string& body) +{ + int sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { + std::cerr << "Failed to create socket" << std::endl; + return ""; + } + + struct hostent* server = gethostbyname(host.c_str()); + if (server == nullptr) { + std::cerr << "Failed to resolve host: " << host << std::endl; + close(sock); + return ""; + } + + struct sockaddr_in serv_addr; + memset(&serv_addr, 0, sizeof(serv_addr)); + serv_addr.sin_family = AF_INET; + memcpy(&serv_addr.sin_addr.s_addr, server->h_addr, server->h_length); + serv_addr.sin_port = htons(port); + + if (connect(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) < 0) { + std::cerr << "Failed to connect to " << host << ":" << port << std::endl; + close(sock); + return ""; + } + + // Build HTTP request + std::ostringstream request; + request << method << " " << path << " 
HTTP/1.1\r\n"; + request << "Host: " << host << ":" << port << "\r\n"; + request << "Content-Type: application/json\r\n"; + request << "Content-Length: " << body.length() << "\r\n"; + request << "Connection: close\r\n"; + request << "\r\n"; + request << body; + + std::string request_str = request.str(); + + // Send request + if (send(sock, request_str.c_str(), request_str.length(), 0) < 0) { + std::cerr << "Failed to send HTTP request" << std::endl; + close(sock); + return ""; + } + + // Receive response + std::string response; + char buffer[4096]; + int bytes_received; + + while ((bytes_received = recv(sock, buffer, sizeof(buffer) - 1, 0)) > 0) { + buffer[bytes_received] = '\0'; + response += buffer; + } + + close(sock); + + // Extract JSON body from HTTP response + size_t body_start = response.find("\r\n\r\n"); + if (body_start != std::string::npos) { + return response.substr(body_start + 4); + } + + return response; +} + +bool +nvidia::riva::realtime::RealtimeClient::InitializeHttpSession() +{ + if (server_url_.empty()) { + std::cerr << "Server URL not set" << std::endl; + return false; + } + + // Parse server URL to extract host and port + std::string host = server_url_; + int port = 80; // Default HTTP port + + // Check if port is specified + size_t colon_pos = host.find(':'); + if (colon_pos != std::string::npos) { + port = std::stoi(host.substr(colon_pos + 1)); + host = host.substr(0, colon_pos); + } + + std::string path = "/v1/realtime/transcription_sessions"; + std::string response_body = MakeHttpRequest(host, port, path, "POST", "{}"); + + if (response_body.empty()) { + std::cerr << "HTTP request failed" << std::endl; + return false; + } + + try { + // Parse JSON response using rapidjson + rapidjson::Document session_data; + if (session_data.Parse(response_body.c_str()).HasParseError()) { + std::cerr << "Failed to parse JSON response" << std::endl; + return false; } - - struct hostent* server = gethostbyname(host.c_str()); - if (server == nullptr) { - 
std::cerr << "Failed to resolve host: " << host << std::endl; - close(sock); - return ""; + + // Extract session ID + if (session_data.HasMember("id")) { + session_id_ = session_data["id"].GetString(); + } else { + std::cerr << "No session ID found in response" << std::endl; + return false; } - - struct sockaddr_in serv_addr; - memset(&serv_addr, 0, sizeof(serv_addr)); - serv_addr.sin_family = AF_INET; - memcpy(&serv_addr.sin_addr.s_addr, server->h_addr, server->h_length); - serv_addr.sin_port = htons(port); - - if (connect(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) < 0) { - std::cerr << "Failed to connect to " << host << ":" << port << std::endl; - close(sock); - return ""; + + // Store server defaults but don't overwrite user-provided values + SessionConfig serverDefaults; + + if (session_data.HasMember("input_audio_transcription")) { + const auto& transcription = session_data["input_audio_transcription"]; + if (transcription.HasMember("language")) { + serverDefaults.language_code_ = transcription["language"].GetString(); + } + if (transcription.HasMember("model")) { + serverDefaults.model_name_ = transcription["model"].GetString(); + } } - - // Build HTTP request - std::ostringstream request; - request << method << " " << path << " HTTP/1.1\r\n"; - request << "Host: " << host << ":" << port << "\r\n"; - request << "Content-Type: application/json\r\n"; - request << "Content-Length: " << body.length() << "\r\n"; - request << "Connection: close\r\n"; - request << "\r\n"; - request << body; - - std::string request_str = request.str(); - - // Send request - if (send(sock, request_str.c_str(), request_str.length(), 0) < 0) { - std::cerr << "Failed to send HTTP request" << std::endl; - close(sock); - return ""; + + if (session_data.HasMember("recognition_config")) { + const auto& recognition = session_data["recognition_config"]; + if (recognition.HasMember("max_alternatives")) { + serverDefaults.max_alternatives_ = recognition["max_alternatives"].GetInt(); 
+ } + if (recognition.HasMember("enable_automatic_punctuation")) { + serverDefaults.automatic_punctuation_ = + recognition["enable_automatic_punctuation"].GetBool(); + } + if (recognition.HasMember("enable_word_time_offsets")) { + serverDefaults.word_time_offsets_ = recognition["enable_word_time_offsets"].GetBool(); + } + if (recognition.HasMember("enable_profanity_filter")) { + serverDefaults.profanity_filter_ = recognition["enable_profanity_filter"].GetBool(); + } + if (recognition.HasMember("enable_verbatim_transcripts")) { + serverDefaults.verbatim_transcripts_ = recognition["enable_verbatim_transcripts"].GetBool(); + } } - - // Receive response - std::string response; - char buffer[4096]; - int bytes_received; - - while ((bytes_received = recv(sock, buffer, sizeof(buffer) - 1, 0)) > 0) { - buffer[bytes_received] = '\0'; - response += buffer; + + if (session_data.HasMember("speaker_diarization")) { + const auto& diarization = session_data["speaker_diarization"]; + if (diarization.HasMember("enable_speaker_diarization")) { + serverDefaults.speaker_diarization_ = diarization["enable_speaker_diarization"].GetBool(); + } + if (diarization.HasMember("max_speaker_count")) { + serverDefaults.diarization_max_speakers_ = diarization["max_speaker_count"].GetInt(); + } } - - close(sock); - - // Extract JSON body from HTTP response - size_t body_start = response.find("\r\n\r\n"); - if (body_start != std::string::npos) { - return response.substr(body_start + 4); + + if (session_data.HasMember("endpointing_config")) { + const auto& endpointing = session_data["endpointing_config"]; + if (endpointing.HasMember("start_history")) { + serverDefaults.start_history_ = endpointing["start_history"].GetInt(); + } + if (endpointing.HasMember("start_threshold")) { + serverDefaults.start_threshold_ = endpointing["start_threshold"].GetDouble(); + } + if (endpointing.HasMember("stop_history")) { + serverDefaults.stop_history_ = endpointing["stop_history"].GetInt(); + } + if 
(endpointing.HasMember("stop_threshold")) { + serverDefaults.stop_threshold_ = endpointing["stop_threshold"].GetDouble(); + } + if (endpointing.HasMember("stop_history_eou")) { + serverDefaults.stop_history_eou_ = endpointing["stop_history_eou"].GetInt(); + } + if (endpointing.HasMember("stop_threshold_eou")) { + serverDefaults.stop_threshold_eou_ = endpointing["stop_threshold_eou"].GetDouble(); + } } - - return response; -} -bool nvidia::riva::realtime::RealtimeClient::InitializeHttpSession() { - if (server_url_.empty()) { - std::cerr << "Server URL not set" << std::endl; - return false; + // Only use server defaults for values that haven't been set by user + if (sessionConfig_.language_code_.empty()) { + sessionConfig_.language_code_ = serverDefaults.language_code_; + } + if (sessionConfig_.model_name_.empty()) { + sessionConfig_.model_name_ = serverDefaults.model_name_; + } + if (sessionConfig_.max_alternatives_ == 0) { + sessionConfig_.max_alternatives_ = serverDefaults.max_alternatives_; + } + if (sessionConfig_.start_history_ == -1) { + sessionConfig_.start_history_ = serverDefaults.start_history_; } - - // Parse server URL to extract host and port - std::string host = server_url_; - int port = 80; // Default HTTP port - - // Check if port is specified - size_t colon_pos = host.find(':'); - if (colon_pos != std::string::npos) { - port = std::stoi(host.substr(colon_pos + 1)); - host = host.substr(0, colon_pos); + if (sessionConfig_.start_threshold_ == -1.0) { + sessionConfig_.start_threshold_ = serverDefaults.start_threshold_; } - - std::string path = "/v1/realtime/transcription_sessions"; - std::string response_body = MakeHttpRequest(host, port, path, "POST", "{}"); - - if (response_body.empty()) { - std::cerr << "HTTP request failed" << std::endl; - return false; + if (sessionConfig_.stop_history_ == -1) { + sessionConfig_.stop_history_ = serverDefaults.stop_history_; } - - try { - // Parse JSON response using rapidjson - rapidjson::Document session_data; - 
if (session_data.Parse(response_body.c_str()).HasParseError()) { - std::cerr << "Failed to parse JSON response" << std::endl; - return false; - } - - // Extract session ID - if (session_data.HasMember("id")) { - session_id_ = session_data["id"].GetString(); - } else { - std::cerr << "No session ID found in response" << std::endl; - return false; - } - - // Store server defaults but don't overwrite user-provided values - SessionConfig serverDefaults; - - if (session_data.HasMember("input_audio_transcription")) { - const auto& transcription = session_data["input_audio_transcription"]; - if (transcription.HasMember("language")) { - serverDefaults.language_code_ = transcription["language"].GetString(); - } - if (transcription.HasMember("model")) { - serverDefaults.model_name_ = transcription["model"].GetString(); - } - } - - if (session_data.HasMember("recognition_config")) { - const auto& recognition = session_data["recognition_config"]; - if (recognition.HasMember("max_alternatives")) { - serverDefaults.max_alternatives_ = recognition["max_alternatives"].GetInt(); - } - if (recognition.HasMember("enable_automatic_punctuation")) { - serverDefaults.automatic_punctuation_ = recognition["enable_automatic_punctuation"].GetBool(); - } - if (recognition.HasMember("enable_word_time_offsets")) { - serverDefaults.word_time_offsets_ = recognition["enable_word_time_offsets"].GetBool(); - } - if (recognition.HasMember("enable_profanity_filter")) { - serverDefaults.profanity_filter_ = recognition["enable_profanity_filter"].GetBool(); - } - if (recognition.HasMember("enable_verbatim_transcripts")) { - serverDefaults.verbatim_transcripts_ = recognition["enable_verbatim_transcripts"].GetBool(); - } - } - - if (session_data.HasMember("speaker_diarization")) { - const auto& diarization = session_data["speaker_diarization"]; - if (diarization.HasMember("enable_speaker_diarization")) { - serverDefaults.speaker_diarization_ = diarization["enable_speaker_diarization"].GetBool(); - } - if 
(diarization.HasMember("max_speaker_count")) { - serverDefaults.diarization_max_speakers_ = diarization["max_speaker_count"].GetInt(); - } - } - - if (session_data.HasMember("endpointing_config")) { - const auto& endpointing = session_data["endpointing_config"]; - if (endpointing.HasMember("start_history")) { - serverDefaults.start_history_ = endpointing["start_history"].GetInt(); - } - if (endpointing.HasMember("start_threshold")) { - serverDefaults.start_threshold_ = endpointing["start_threshold"].GetDouble(); - } - if (endpointing.HasMember("stop_history")) { - serverDefaults.stop_history_ = endpointing["stop_history"].GetInt(); - } - if (endpointing.HasMember("stop_threshold")) { - serverDefaults.stop_threshold_ = endpointing["stop_threshold"].GetDouble(); - } - if (endpointing.HasMember("stop_history_eou")) { - serverDefaults.stop_history_eou_ = endpointing["stop_history_eou"].GetInt(); - } - if (endpointing.HasMember("stop_threshold_eou")) { - serverDefaults.stop_threshold_eou_ = endpointing["stop_threshold_eou"].GetDouble(); - } - } - - // Only use server defaults for values that haven't been set by user - if (sessionConfig_.language_code_.empty()) { - sessionConfig_.language_code_ = serverDefaults.language_code_; - } - if (sessionConfig_.model_name_.empty()) { - sessionConfig_.model_name_ = serverDefaults.model_name_; - } - if (sessionConfig_.max_alternatives_ == 0) { - sessionConfig_.max_alternatives_ = serverDefaults.max_alternatives_; - } - if (sessionConfig_.start_history_ == -1) { - sessionConfig_.start_history_ = serverDefaults.start_history_; - } - if (sessionConfig_.start_threshold_ == -1.0) { - sessionConfig_.start_threshold_ = serverDefaults.start_threshold_; - } - if (sessionConfig_.stop_history_ == -1) { - sessionConfig_.stop_history_ = serverDefaults.stop_history_; - } - if (sessionConfig_.stop_threshold_ == -1.0) { - sessionConfig_.stop_threshold_ = serverDefaults.stop_threshold_; - } - if (sessionConfig_.stop_history_eou_ == -1) { - 
sessionConfig_.stop_history_eou_ = serverDefaults.stop_history_eou_; - } - if (sessionConfig_.stop_threshold_eou_ == -1.0) { - sessionConfig_.stop_threshold_eou_ = serverDefaults.stop_threshold_eou_; - } - - // Convert rapidjson document to string for logging - rapidjson::StringBuffer buffer; - rapidjson::Writer writer(buffer); - session_data.Accept(writer); - - std::cout << "[" << objectName_ << "] Session initialized with defaults: " << buffer.GetString() << std::endl; - return true; - - } catch (const std::exception& e) { - std::cerr << "Failed to parse session response: " << e.what() << std::endl; - return false; + if (sessionConfig_.stop_threshold_ == -1.0) { + sessionConfig_.stop_threshold_ = serverDefaults.stop_threshold_; } + if (sessionConfig_.stop_history_eou_ == -1) { + sessionConfig_.stop_history_eou_ = serverDefaults.stop_history_eou_; + } + if (sessionConfig_.stop_threshold_eou_ == -1.0) { + sessionConfig_.stop_threshold_eou_ = serverDefaults.stop_threshold_eou_; + } + + // Convert rapidjson document to string for logging + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + session_data.Accept(writer); + + std::cout << "[" << objectName_ << "] Session initialized with defaults: " << buffer.GetString() + << std::endl; + return true; + } + catch (const std::exception& e) { + std::cerr << "Failed to parse session response: " << e.what() << std::endl; + return false; + } } -nvidia::riva::realtime::RealtimeClient::RealtimeClient( - const std::string& objectName, - const std::shared_ptr audioChunksPtr, +nvidia::riva::realtime::RealtimeClient::RealtimeClient( + const std::string& objectName, const std::shared_ptr audioChunksPtr, nvidia::riva::utils::PerformanceStats& perfCounter) : WebSocketClientBase("ws://127.0.0.1:9090/v1/realtime?intent=transcription"), - sessionInitialized_(false), - sessionUpdated_(false), - transcriptionCompleted_(false), - finalTranscriptionCount_(0), - connectionTimeoutInMs_(std::size_t(10000)), - 
sessionInitTimeoutInMs_(std::size_t(10000)), - sessionUpdateTimeoutInMs_(std::size_t(10000)), - transcriptionTimeoutInMs_(std::size_t(10000)), - chunkDelayTimeInMs_(std::size_t(1000)), - objectName_(objectName), - audioChunksPtr_(audioChunksPtr), - perfCounter_(perfCounter) { - - nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(connectionTimeoutInMs_); - - // Initialize default session config - sessionConfig_.language_code_ = "en-US"; - sessionConfig_.model_name_ = "parakeet-1.1b-en-US-asr-streaming-silero-vad-asr-bls-ensemble"; - sessionConfig_.max_alternatives_ = 1; - sessionConfig_.automatic_punctuation_ = true; - sessionConfig_.word_time_offsets_ = true; - sessionConfig_.profanity_filter_ = false; - sessionConfig_.verbatim_transcripts_ = true; - sessionConfig_.speaker_diarization_ = false; - sessionConfig_.diarization_max_speakers_ = 4; + sessionInitialized_(false), sessionUpdated_(false), transcriptionCompleted_(false), + finalTranscriptionCount_(0), connectionTimeoutInMs_(std::size_t(10000)), + sessionInitTimeoutInMs_(std::size_t(10000)), sessionUpdateTimeoutInMs_(std::size_t(10000)), + transcriptionTimeoutInMs_(std::size_t(10000)), chunkDelayTimeInMs_(std::size_t(1000)), + objectName_(objectName), audioChunksPtr_(audioChunksPtr), perfCounter_(perfCounter) +{ + nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(connectionTimeoutInMs_); } -void nvidia::riva::realtime::RealtimeClient::SetTimingConfig( const std::size_t connectionTimeoutInMs, - const std::size_t sessionInitTimeoutInMs, - const std::size_t sessionUpdateTimeoutInMs, - const std::size_t transcriptionTimeoutInMs, - const std::size_t chunkDelayTimeInMs) { - connectionTimeoutInMs_ = connectionTimeoutInMs; - sessionInitTimeoutInMs_ = sessionInitTimeoutInMs; - sessionUpdateTimeoutInMs_ = sessionUpdateTimeoutInMs; - transcriptionTimeoutInMs_ = transcriptionTimeoutInMs; - chunkDelayTimeInMs_ = chunkDelayTimeInMs; - 
nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(connectionTimeoutInMs_); +void +nvidia::riva::realtime::RealtimeClient::SetTimingConfig( + const std::size_t connectionTimeoutInMs, const std::size_t sessionInitTimeoutInMs, + const std::size_t sessionUpdateTimeoutInMs, const std::size_t transcriptionTimeoutInMs, + const std::size_t chunkDelayTimeInMs) +{ + connectionTimeoutInMs_ = connectionTimeoutInMs; + sessionInitTimeoutInMs_ = sessionInitTimeoutInMs; + sessionUpdateTimeoutInMs_ = sessionUpdateTimeoutInMs; + transcriptionTimeoutInMs_ = transcriptionTimeoutInMs; + chunkDelayTimeInMs_ = chunkDelayTimeInMs; + nvidia::riva::realtime::WebSocketClientBase::SetConnectionTimeout(connectionTimeoutInMs_); } -void nvidia::riva::realtime::RealtimeClient::Log(const std::string& message) { - std::cout << "[" << objectName_ << "]" << message << std::endl; +void +nvidia::riva::realtime::RealtimeClient::Log(const std::string& message) +{ + std::cout << "[" << objectName_ << "]" << message << std::endl; } -bool nvidia::riva::realtime::RealtimeClient::WaitForTranscriptionCompletion() { - std::unique_lock lock(transcriptionMutex_); - - // Reset completion flag - transcriptionCompleted_ = false; - - // Wait for completion event with timeout (increased from 3 seconds to 10 seconds) - bool completed = transcriptionCv_.wait_for(lock, - std::chrono::milliseconds(transcriptionTimeoutInMs_), - [this] { return transcriptionCompleted_; }); - - if (!completed) { - Log(" Timeout waiting for transcription completion after " + std::to_string(transcriptionTimeoutInMs_) + " milliseconds"); - } - else if (transcriptionCompleted_) { - // Close the connection - Close(); - } - - return completed; +bool +nvidia::riva::realtime::RealtimeClient::WaitForTranscriptionCompletion() +{ + std::unique_lock lock(transcriptionMutex_); + + // Reset completion flag + transcriptionCompleted_ = false; + + // Wait for completion event with timeout (increased from 3 seconds to 10 seconds) + bool 
completed = transcriptionCv_.wait_for( + lock, std::chrono::milliseconds(transcriptionTimeoutInMs_), + [this] { return transcriptionCompleted_; }); + + if (!completed) { + Log(" Timeout waiting for transcription completion after " + + std::to_string(transcriptionTimeoutInMs_) + " milliseconds"); + } else if (transcriptionCompleted_) { + // Close the connection + Close(); + } + + return completed; } -bool nvidia::riva::realtime::RealtimeClient::WaitForSessionUpdate() { - std::unique_lock lock(sessionMutex_); - - if (sessionUpdated_) { - return true; - } +bool +nvidia::riva::realtime::RealtimeClient::WaitForSessionUpdate() +{ + std::unique_lock lock(sessionMutex_); - // Wait for session update event with timeout - sessionUpdated_ = sessionCv_.wait_for( - lock, - std::chrono::milliseconds(sessionUpdateTimeoutInMs_), - [this] { return sessionUpdated_; } - ); + if (sessionUpdated_) { + return true; + } - if (!sessionUpdated_) { - Log("Timeout waiting for session update after " + std::to_string(sessionUpdateTimeoutInMs_) + " milliseconds"); - } - - return sessionUpdated_; + // Wait for session update event with timeout + sessionUpdated_ = sessionCv_.wait_for( + lock, std::chrono::milliseconds(sessionUpdateTimeoutInMs_), + [this] { return sessionUpdated_; }); + + if (!sessionUpdated_) { + Log("Timeout waiting for session update after " + std::to_string(sessionUpdateTimeoutInMs_) + + " milliseconds"); + } + + return sessionUpdated_; } // Send audio buffer append message (inspired by Python realtime.py) -void nvidia::riva::realtime::RealtimeClient::SendAudioAppend(const std::string& audioBase64) +void +nvidia::riva::realtime::RealtimeClient::SendAudioAppend(const std::string& audioBase64) { - std::lock_guard lock(connectionMutex_); - if (IsConnectionOpen()) - { - rapidjson::Document doc; - doc.SetObject(); - rapidjson::Document::AllocatorType& allocator = doc.GetAllocator(); - doc.AddMember("type", rapidjson::Value("input_audio_buffer.append", allocator), allocator); - 
doc.AddMember("audio", rapidjson::Value(audioBase64.c_str(), allocator), allocator); - - rapidjson::StringBuffer buffer; - rapidjson::Writer writer(buffer); - doc.Accept(writer); - - websocketpp::lib::error_code ec; - wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec); - if (ec) { - Log("Audio append failed: " + ec.message()); - // Mark connection as failed - { - std::lock_guard conn_lock(connectionMutex_); - connectionClosedByServer_ = true; - } - } - } - else { - Log("Skipping audio append - connection closed"); - } + std::lock_guard lock(connectionMutex_); + if (IsConnectionOpen()) { + rapidjson::Document doc; + doc.SetObject(); + rapidjson::Document::AllocatorType& allocator = doc.GetAllocator(); + doc.AddMember("type", rapidjson::Value("input_audio_buffer.append", allocator), allocator); + doc.AddMember("audio", rapidjson::Value(audioBase64.c_str(), allocator), allocator); + + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc.Accept(writer); + + websocketpp::lib::error_code ec; + wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec); + if (ec) { + Log("Audio append failed: " + ec.message()); + // Mark connection as failed + { + std::lock_guard conn_lock(connectionMutex_); + connectionClosedByServer_ = true; + } + } + } else { + Log("Skipping audio append - connection closed"); + } } // Send audio buffer commit message (inspired by Python realtime.py) -void nvidia::riva::realtime::RealtimeClient::SendAudioCommit() { - std::lock_guard lock(connectionMutex_); - if (IsConnectionOpen()) - { - rapidjson::Document doc; - doc.SetObject(); - rapidjson::Document::AllocatorType& allocator = doc.GetAllocator(); - - doc.AddMember("type", rapidjson::Value("input_audio_buffer.commit", allocator), allocator); - - rapidjson::StringBuffer buffer; - rapidjson::Writer writer(buffer); - doc.Accept(writer); - - websocketpp::lib::error_code ec; - wsClient_.send(connectionHdl_, 
buffer.GetString(), websocketpp::frame::opcode::text, ec); - if (ec) { - Log("Audio commit failed: " + ec.message()); - // Mark connection as failed - { - std::lock_guard conn_lock(connectionMutex_); - connectionClosedByServer_ = true; - } - } - } - else { - Log("Skipping audio commit - connection closed"); - } +void +nvidia::riva::realtime::RealtimeClient::SendAudioCommit() +{ + std::lock_guard lock(connectionMutex_); + if (IsConnectionOpen()) { + rapidjson::Document doc; + doc.SetObject(); + rapidjson::Document::AllocatorType& allocator = doc.GetAllocator(); + + doc.AddMember("type", rapidjson::Value("input_audio_buffer.commit", allocator), allocator); + + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc.Accept(writer); + + websocketpp::lib::error_code ec; + wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec); + if (ec) { + Log("Audio commit failed: " + ec.message()); + // Mark connection as failed + { + std::lock_guard conn_lock(connectionMutex_); + connectionClosedByServer_ = true; + } + } + } else { + Log("Skipping audio commit - connection closed"); + } } // Send audio buffer done message (inspired by Python realtime.py) -void nvidia::riva::realtime::RealtimeClient::SendAudioDone() { - std::lock_guard lock(connectionMutex_); - if (IsConnectionOpen()) - { - rapidjson::Document doc; - doc.SetObject(); - rapidjson::Document::AllocatorType& allocator = doc.GetAllocator(); - - doc.AddMember("type", rapidjson::Value("input_audio_buffer.done", allocator), allocator); - - rapidjson::StringBuffer buffer; - rapidjson::Writer writer(buffer); - doc.Accept(writer); - - websocketpp::lib::error_code ec; - wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec); - if (ec) { - Log("Audio done failed: " + ec.message()); - // Mark connection as failed - { - std::lock_guard conn_lock(connectionMutex_); - connectionClosedByServer_ = true; - } - } else { - Log("Audio streaming 
completed"); - } - } - else { - Log("Skipping audio done - connection closed"); +void +nvidia::riva::realtime::RealtimeClient::SendAudioDone() +{ + std::lock_guard lock(connectionMutex_); + if (IsConnectionOpen()) { + rapidjson::Document doc; + doc.SetObject(); + rapidjson::Document::AllocatorType& allocator = doc.GetAllocator(); + + doc.AddMember("type", rapidjson::Value("input_audio_buffer.done", allocator), allocator); + + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + doc.Accept(writer); + + websocketpp::lib::error_code ec; + wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec); + if (ec) { + Log("Audio done failed: " + ec.message()); + // Mark connection as failed + { + std::lock_guard conn_lock(connectionMutex_); + connectionClosedByServer_ = true; + } + } else { + Log("Audio streaming completed"); } + } else { + Log("Skipping audio done - connection closed"); + } } // Modify the InitializeSession method to call HTTP initialization first -bool nvidia::riva::realtime::RealtimeClient::InitializeSession() { - std::cout << "[" << objectName_ << "]" << " Initializing session..." << std::endl; - - // Step 1: Initialize HTTP session - if (!InitializeHttpSession()) { - std::cerr << "Failed to initialize HTTP session" << std::endl; - return false; - } - - // Step 2: Wait for the initial connection and session creation - std::this_thread::sleep_for(std::chrono::milliseconds(3000)); - - // Step 3: Check if we're still connected - if (IsConnectionClosed()) { - std::cerr << "Connection lost during session initialization" << std::endl; - return false; - } - - // Step 4: Update session configuration - return UpdateSessionConfig(); +bool +nvidia::riva::realtime::RealtimeClient::InitializeSession() +{ + std::cout << "[" << objectName_ << "]" + << " Initializing session..." 
<< std::endl; + + // Step 1: Initialize HTTP session + if (!InitializeHttpSession()) { + std::cerr << "Failed to initialize HTTP session" << std::endl; + return false; + } + + // Step 2: Wait for the initial connection and session creation + std::this_thread::sleep_for(std::chrono::milliseconds(3000)); + + // Step 3: Check if we're still connected + if (IsConnectionClosed()) { + std::cerr << "Connection lost during session initialization" << std::endl; + return false; + } + + // Step 4: Update session configuration + return UpdateSessionConfig(); } -bool nvidia::riva::realtime::RealtimeClient::UpdateSessionConfig() { - int sampleRateHz = audioChunksPtr_->GetSampleRateHz(); - int numChannels = audioChunksPtr_->GetNumChannels(); - - std::cout << "Updating session configuration..." << std::endl; - std::cout << "Using WAV file parameters - Sample rate: " << sampleRateHz - << " Hz, Channels: " << numChannels << std::endl; - - // Create session configuration using sessionConfig_ (which now has defaults + user overrides) - rapidjson::Document doc; - doc.SetObject(); - rapidjson::Document::AllocatorType& allocator = doc.GetAllocator(); - - // Create session config - rapidjson::Value session_config(rapidjson::kObjectType); - - // Add modalities - rapidjson::Value modalities(rapidjson::kArrayType); - modalities.PushBack(rapidjson::Value("text", allocator), allocator); - session_config.AddMember("modalities", modalities, allocator); - - // Add input audio format - session_config.AddMember("input_audio_format", rapidjson::Value("pcm16", allocator), allocator); - - // Input audio transcription config - rapidjson::Value transcription_config(rapidjson::kObjectType); - transcription_config.AddMember("language", rapidjson::Value(sessionConfig_.language_code_.c_str(), allocator), allocator); - transcription_config.AddMember("model", rapidjson::Value(sessionConfig_.model_name_.c_str(), allocator), allocator); - transcription_config.AddMember("prompt", 
rapidjson::Value(rapidjson::kNullType), allocator); - session_config.AddMember("input_audio_transcription", transcription_config, allocator); - - // Input audio params - use actual WAV file parameters - rapidjson::Value audio_params(rapidjson::kObjectType); - audio_params.AddMember("sample_rate_hz", sampleRateHz, allocator); - audio_params.AddMember("num_channels", numChannels, allocator); - session_config.AddMember("input_audio_params", audio_params, allocator); - - // Recognition config - use session configuration - rapidjson::Value recognition_config(rapidjson::kObjectType); - recognition_config.AddMember("max_alternatives", sessionConfig_.max_alternatives_, allocator); - recognition_config.AddMember("enable_automatic_punctuation", sessionConfig_.automatic_punctuation_, allocator); - recognition_config.AddMember("enable_word_time_offsets", sessionConfig_.word_time_offsets_, allocator); - recognition_config.AddMember("enable_profanity_filter", sessionConfig_.profanity_filter_, allocator); - recognition_config.AddMember("enable_verbatim_transcripts", sessionConfig_.verbatim_transcripts_, allocator); - recognition_config.AddMember("custom_configuration", rapidjson::Value(sessionConfig_.custom_configuration_.c_str(), allocator), allocator); - session_config.AddMember("recognition_config", recognition_config, allocator); - - // Speaker diarization config - rapidjson::Value diarization_config(rapidjson::kObjectType); - diarization_config.AddMember("enable_speaker_diarization", sessionConfig_.speaker_diarization_, allocator); - diarization_config.AddMember("max_speaker_count", sessionConfig_.diarization_max_speakers_, allocator); - session_config.AddMember("speaker_diarization", diarization_config, allocator); - - // Word boosting config - rapidjson::Value word_boosting_config(rapidjson::kObjectType); - bool enable_word_boosting = !sessionConfig_.boosted_words_file_.empty(); - word_boosting_config.AddMember("enable_word_boosting", enable_word_boosting, allocator); - - 
if (enable_word_boosting) { - rapidjson::Value word_list(rapidjson::kArrayType); - std::ifstream file(sessionConfig_.boosted_words_file_); - std::string word; - while (std::getline(file, word)) { - if (!word.empty()) { - word_list.PushBack(rapidjson::Value(word.c_str(), allocator), allocator); - } - } - word_boosting_config.AddMember("word_boosting_list", word_list, allocator); +bool +nvidia::riva::realtime::RealtimeClient::UpdateSessionConfig() +{ + int sampleRateHz = audioChunksPtr_->GetSampleRateHz(); + int numChannels = audioChunksPtr_->GetNumChannels(); + + std::cout << "Updating session configuration..." << std::endl; + std::cout << "Using WAV file parameters - Sample rate: " << sampleRateHz + << " Hz, Channels: " << numChannels << std::endl; + + // Create session configuration using sessionConfig_ (which now has defaults + user overrides) + rapidjson::Document doc; + doc.SetObject(); + rapidjson::Document::AllocatorType& allocator = doc.GetAllocator(); + + // Create session config + rapidjson::Value session_config(rapidjson::kObjectType); + + // Add modalities + rapidjson::Value modalities(rapidjson::kArrayType); + modalities.PushBack(rapidjson::Value("text", allocator), allocator); + session_config.AddMember("modalities", modalities, allocator); + + // Add input audio format + session_config.AddMember("input_audio_format", rapidjson::Value("pcm16", allocator), allocator); + + // Input audio transcription config + rapidjson::Value transcription_config(rapidjson::kObjectType); + transcription_config.AddMember( + "language", rapidjson::Value(sessionConfig_.language_code_.c_str(), allocator), allocator); + transcription_config.AddMember( + "model", rapidjson::Value(sessionConfig_.model_name_.c_str(), allocator), allocator); + transcription_config.AddMember("prompt", rapidjson::Value(rapidjson::kNullType), allocator); + session_config.AddMember("input_audio_transcription", transcription_config, allocator); + + // Input audio params - use actual WAV file 
parameters + rapidjson::Value audio_params(rapidjson::kObjectType); + audio_params.AddMember("sample_rate_hz", sampleRateHz, allocator); + audio_params.AddMember("num_channels", numChannels, allocator); + session_config.AddMember("input_audio_params", audio_params, allocator); + + // Recognition config - use session configuration + rapidjson::Value recognition_config(rapidjson::kObjectType); + recognition_config.AddMember("max_alternatives", sessionConfig_.max_alternatives_, allocator); + recognition_config.AddMember( + "enable_automatic_punctuation", sessionConfig_.automatic_punctuation_, allocator); + recognition_config.AddMember( + "enable_word_time_offsets", sessionConfig_.word_time_offsets_, allocator); + recognition_config.AddMember( + "enable_profanity_filter", sessionConfig_.profanity_filter_, allocator); + recognition_config.AddMember( + "enable_verbatim_transcripts", sessionConfig_.verbatim_transcripts_, allocator); + recognition_config.AddMember( + "custom_configuration", + rapidjson::Value(sessionConfig_.custom_configuration_.c_str(), allocator), allocator); + session_config.AddMember("recognition_config", recognition_config, allocator); + + // Speaker diarization config + rapidjson::Value diarization_config(rapidjson::kObjectType); + diarization_config.AddMember( + "enable_speaker_diarization", sessionConfig_.speaker_diarization_, allocator); + diarization_config.AddMember( + "max_speaker_count", sessionConfig_.diarization_max_speakers_, allocator); + session_config.AddMember("speaker_diarization", diarization_config, allocator); + + // Word boosting config + rapidjson::Value word_boosting_config(rapidjson::kObjectType); + bool enable_word_boosting = !sessionConfig_.boosted_words_file_.empty(); + word_boosting_config.AddMember("enable_word_boosting", enable_word_boosting, allocator); + + if (enable_word_boosting) { + rapidjson::Value word_list(rapidjson::kArrayType); + std::ifstream file(sessionConfig_.boosted_words_file_); + std::string word; + while 
(std::getline(file, word)) { + if (!word.empty()) { + word_list.PushBack(rapidjson::Value(word.c_str(), allocator), allocator); + } + } + word_boosting_config.AddMember("word_boosting_list", word_list, allocator); + } else { + rapidjson::Value empty_list(rapidjson::kArrayType); + word_boosting_config.AddMember("word_boosting_list", empty_list, allocator); + } + session_config.AddMember("word_boosting", word_boosting_config, allocator); + + // Endpointing config + rapidjson::Value endpointing_config(rapidjson::kObjectType); + endpointing_config.AddMember("start_history", sessionConfig_.start_history_, allocator); + endpointing_config.AddMember("start_threshold", sessionConfig_.start_threshold_, allocator); + endpointing_config.AddMember("stop_history", sessionConfig_.stop_history_, allocator); + endpointing_config.AddMember("stop_threshold", sessionConfig_.stop_threshold_, allocator); + endpointing_config.AddMember("stop_history_eou", sessionConfig_.stop_history_eou_, allocator); + endpointing_config.AddMember("stop_threshold_eou", sessionConfig_.stop_threshold_eou_, allocator); + session_config.AddMember("endpointing_config", endpointing_config, allocator); + + // Create update request + rapidjson::Value update_request(rapidjson::kObjectType); + update_request.AddMember( + "type", rapidjson::Value("transcription_session.update", allocator), allocator); + update_request.AddMember("session", session_config, allocator); + + // Send the update request + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + update_request.Accept(writer); + + if (IsConnectionOpen()) { + std::lock_guard lock(connectionMutex_); + websocketpp::lib::error_code ec; + wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec); + if (ec) { + std::cout << "Session update failed: " << ec.message() << std::endl; + return false; } else { - rapidjson::Value empty_list(rapidjson::kArrayType); - word_boosting_config.AddMember("word_boosting_list", 
empty_list, allocator); - } - session_config.AddMember("word_boosting", word_boosting_config, allocator); - - // Endpointing config - rapidjson::Value endpointing_config(rapidjson::kObjectType); - endpointing_config.AddMember("start_history", sessionConfig_.start_history_, allocator); - endpointing_config.AddMember("start_threshold", sessionConfig_.start_threshold_, allocator); - endpointing_config.AddMember("stop_history", sessionConfig_.stop_history_, allocator); - endpointing_config.AddMember("stop_threshold", sessionConfig_.stop_threshold_, allocator); - endpointing_config.AddMember("stop_history_eou", sessionConfig_.stop_history_eou_, allocator); - endpointing_config.AddMember("stop_threshold_eou", sessionConfig_.stop_threshold_eou_, allocator); - session_config.AddMember("endpointing_config", endpointing_config, allocator); - - // Create update request - rapidjson::Value update_request(rapidjson::kObjectType); - update_request.AddMember("type", rapidjson::Value("transcription_session.update", allocator), allocator); - update_request.AddMember("session", session_config, allocator); - - // Send the update request - rapidjson::StringBuffer buffer; - rapidjson::Writer writer(buffer); - update_request.Accept(writer); - - if (IsConnectionOpen()) - { - std::lock_guard lock(connectionMutex_); - websocketpp::lib::error_code ec; - wsClient_.send(connectionHdl_, buffer.GetString(), websocketpp::frame::opcode::text, ec); - if (ec) { - std::cout << "Session update failed: " << ec.message() << std::endl; - return false; - } else { - std::cout << "Session update request sent" << std::endl; - } + std::cout << "Session update request sent" << std::endl; } + } - - WaitForSessionUpdate(); - return true; + WaitForSessionUpdate(); + return true; } // Send audio chunks -void nvidia::riva::realtime::RealtimeClient::SendAudioChunks(const bool simulateRealtime) { - if (audioChunksPtr_ == nullptr) { - std::cerr << "Audio chunks pointer is null. Please call InitializeSession first." 
<< std::endl; - return; - } - - if (!IsSessionInitialized()) { - std::cerr << "Session is not initialized. Please call InitializeSession first." << std::endl; - return; - } +void +nvidia::riva::realtime::RealtimeClient::SendAudioChunks(const bool simulateRealtime) +{ + if (audioChunksPtr_ == nullptr) { + std::cerr << "Audio chunks pointer is null. Please call InitializeSession first." << std::endl; + return; + } - if (audioChunksPtr_->size() == 0) { - std::cerr << "No audio chunks to send. Please add audio chunks to the audio chunks pointer." << std::endl; - return; - } + if (!IsSessionInitialized()) { + std::cerr << "Session is not initialized. Please call InitializeSession first." << std::endl; + return; + } + + if (audioChunksPtr_->size() == 0) { + std::cerr << "No audio chunks to send. Please add audio chunks to the audio chunks pointer." + << std::endl; + return; + } + + std::cout << "Sending audio chunks with " << (simulateRealtime ? "real-time" : "burst") + << " timing..." << std::endl; + + // Track timing for accurate real-time simulation + auto stream_start_time = std::chrono::steady_clock::now(); + size_t chunk_index = 0; + + for (const std::string& chunk_base64 : *audioChunksPtr_) { + SendAudioAppend(chunk_base64); + SendAudioCommit(); - std::cout << "Sending audio chunks with " << (simulateRealtime ? "real-time" : "burst") << " timing..." 
<< std::endl; - - // Track timing for accurate real-time simulation - auto stream_start_time = std::chrono::steady_clock::now(); - size_t chunk_index = 0; - - for (const std::string& chunk_base64 : *audioChunksPtr_) { - SendAudioAppend(chunk_base64); - SendAudioCommit(); - - if (simulateRealtime) { - // Calculate the exact time when this chunk should be sent - auto chunk_duration_ms = audioChunksPtr_->GetChunkSizeMs(); - auto expected_send_time = stream_start_time + - std::chrono::milliseconds((chunk_index + 1) * chunk_duration_ms); - - auto current_time = std::chrono::steady_clock::now(); - auto time_to_wait = expected_send_time - current_time; - - // Log timing information - // Timing calculations for real-time simulation (commented out as unused) - // auto elapsed_ms = std::chrono::duration(current_time - stream_start_time).count(); - // auto expected_ms = (chunk_index + 1) * chunk_duration_ms; - // auto drift_ms = elapsed_ms - expected_ms; - - //auto wait_ms = std::chrono::duration(time_to_wait).count(); - //std::cout << "[" << objectName_ << "] Chunk " << (chunk_index + 1) << "/" << audioChunksPtr_->size() - // << " - Elapsed: " << std::fixed << std::setprecision(1) << elapsed_ms << "ms" - // << " Expected: " << expected_ms << "ms" - // << " Drift: " << drift_ms << "ms"; - // << " Waiting: " << wait_ms << "ms" << std::endl; - - if (time_to_wait > std::chrono::milliseconds(0)) { - std::this_thread::sleep_for(time_to_wait); - } - } - else { - // Burst mode - just log progress - if ((chunk_index + 1) % 10 == 0 || chunk_index == audioChunksPtr_->size() - 1) { - //zstd::cout << "[" << objectName_ << "] Sent " << (chunk_index + 1) << "/" << audioChunksPtr_->size() << " chunks" << std::endl; - } - } - - chunk_index++; + if (simulateRealtime) { + // Calculate the exact time when this chunk should be sent + auto chunk_duration_ms = audioChunksPtr_->GetChunkSizeMs(); + auto expected_send_time = + stream_start_time + std::chrono::milliseconds((chunk_index + 1) * 
chunk_duration_ms); + + auto current_time = std::chrono::steady_clock::now(); + auto time_to_wait = expected_send_time - current_time; + + // Log timing information + // Timing calculations for real-time simulation (commented out as unused) + // auto elapsed_ms = std::chrono::duration(current_time - + // stream_start_time).count(); auto expected_ms = (chunk_index + 1) * chunk_duration_ms; auto + // drift_ms = elapsed_ms - expected_ms; + + // auto wait_ms = std::chrono::duration(time_to_wait).count(); + // std::cout << "[" << objectName_ << "] Chunk " << (chunk_index + 1) << "/" << + // audioChunksPtr_->size() + // << " - Elapsed: " << std::fixed << std::setprecision(1) << elapsed_ms << "ms" + // << " Expected: " << expected_ms << "ms" + // << " Drift: " << drift_ms << "ms"; + // << " Waiting: " << wait_ms << "ms" << std::endl; + + if (time_to_wait > std::chrono::milliseconds(0)) { + std::this_thread::sleep_for(time_to_wait); + } + } else { + // Burst mode - just log progress + if ((chunk_index + 1) % 10 == 0 || chunk_index == audioChunksPtr_->size() - 1) { + // zstd::cout << "[" << objectName_ << "] Sent " << (chunk_index + 1) << "/" << + // audioChunksPtr_->size() << " chunks" << std::endl; + } } - SendAudioDone(); + + chunk_index++; + } + SendAudioDone(); } -void nvidia::riva::realtime::RealtimeClient::HandleMessage(const std::string& message) { - bool is_last_result = false; - rapidjson::Document doc; - - if (doc.Parse(message.c_str()).HasParseError()) { - std::cerr << "Failed to parse JSON message" << std::endl; - return; - } - - std::string eventType = doc.HasMember("type") ? 
doc["type"].GetString() : ""; - - if (eventType == "conversation.created") { - std::cout << "Conversation created" << std::endl; - } - else if (eventType == "transcription_session.updated") { - std::cout << "Session updated successfully" << std::endl; - sessionInitialized_ = true; - // Signal session update completion - { - std::lock_guard lock(sessionMutex_); - sessionUpdated_ = true; - } - sessionCv_.notify_one(); - } - else if (eventType == "conversation.item.input_audio_transcription.delta") { - if (doc.HasMember("delta")) { - std::string delta = doc["delta"].GetString(); - - //std::cout << "Delta: " << delta << std::endl; - std::cout.flush(); // Ensure immediate output for streaming - } +void +nvidia::riva::realtime::RealtimeClient::HandleMessage(const std::string& message) +{ + bool is_last_result = false; + rapidjson::Document doc; + + if (doc.Parse(message.c_str()).HasParseError()) { + std::cerr << "Failed to parse JSON message" << std::endl; + return; + } + + std::string eventType = doc.HasMember("type") ? doc["type"].GetString() : ""; + + if (eventType == "conversation.created") { + std::cout << "Conversation created" << std::endl; + } else if (eventType == "transcription_session.updated") { + std::cout << "Session updated successfully" << std::endl; + sessionInitialized_ = true; + // Signal session update completion + { + std::lock_guard lock(sessionMutex_); + sessionUpdated_ = true; } - else if (eventType == "conversation.item.input_audio_transcription.completed") { - finalTranscriptionCount_++; - std::string transcript = doc.HasMember("transcript") ? doc["transcript"].GetString() : ""; - is_last_result = doc.HasMember("is_last_result") ? 
doc["is_last_result"].GetBool() : false; - - if (is_last_result) { - std::cout << "--------------------------------" << std::endl; - std::cout << "Final transcript: " << transcript << std::endl; - std::cout << "Final transcription count: " << finalTranscriptionCount_ << std::endl; - std::cout << "--------------------------------" << std::endl; - - // Transcription completed - std::lock_guard lock(transcriptionMutex_); - transcriptionCompleted_ = true; - transcriptionCv_.notify_one(); - } - else { - std::cout << "Interim transcript: " << transcript << std::endl; - } + sessionCv_.notify_one(); + } else if (eventType == "conversation.item.input_audio_transcription.delta") { + if (doc.HasMember("delta")) { + std::string delta = doc["delta"].GetString(); + + // std::cout << "Delta: " << delta << std::endl; + std::cout.flush(); // Ensure immediate output for streaming } - else if (eventType.find("error") != std::string::npos) { - std::string errorMsg = "Unknown error"; - if (doc.HasMember("error") && doc["error"].HasMember("message")) { - errorMsg = doc["error"]["message"].GetString(); - } - std::cerr << "Error: " << errorMsg << std::endl; + } else if (eventType == "conversation.item.input_audio_transcription.completed") { + finalTranscriptionCount_++; + std::string transcript = doc.HasMember("transcript") ? doc["transcript"].GetString() : ""; + is_last_result = doc.HasMember("is_last_result") ? 
doc["is_last_result"].GetBool() : false; + + if (is_last_result) { + std::cout << "--------------------------------" << std::endl; + std::cout << "Final transcript: " << transcript << std::endl; + std::cout << "Final transcription count: " << finalTranscriptionCount_ << std::endl; + std::cout << "--------------------------------" << std::endl; + + // Transcription completed + std::lock_guard lock(transcriptionMutex_); + transcriptionCompleted_ = true; + transcriptionCv_.notify_one(); + } else { + std::cout << "Interim transcript: " << transcript << std::endl; } - else { - //std::cout << "Received message type: " << event_type << std::endl; + } else if (eventType.find("error") != std::string::npos) { + std::string errorMsg = "Unknown error"; + if (doc.HasMember("error") && doc["error"].HasMember("message")) { + errorMsg = doc["error"]["message"].GetString(); } + std::cerr << "Error: " << errorMsg << std::endl; + } else { + // std::cout << "Received message type: " << event_type << std::endl; + } +} + +// Public wrapper methods for microphone audio streaming +void +nvidia::riva::realtime::RealtimeClient::SendAudioAppendPublic(const std::string& audioBase64) +{ + SendAudioAppend(audioBase64); } - \ No newline at end of file + +void +nvidia::riva::realtime::RealtimeClient::SendAudioCommitPublic() +{ + SendAudioCommit(); +} + +void +nvidia::riva::realtime::RealtimeClient::SendAudioDonePublic() +{ + SendAudioDone(); +} \ No newline at end of file diff --git a/riva/clients/realtime/realtime_client.h b/riva/clients/realtime/realtime_client.h index 5178cf9..eaf6669 100644 --- a/riva/clients/realtime/realtime_client.h +++ b/riva/clients/realtime/realtime_client.h @@ -9,8 +9,6 @@ #include #include #include -#include -#include #include #include @@ -20,140 +18,148 @@ #include #include #include +#include +#include #include "audio_chunks.h" #include "base_client.h" #include "riva/utils/stats_builder/stats_builder.h" // Add these includes for HTTP functionality -#include -#include 
#include -#include #include +#include +#include +#include + #include namespace nvidia::riva::realtime { - class SessionConfig { - public: - std::size_t connectionTimeoutInMs_; - std::size_t sessionInitTimeoutInMs_; - std::size_t sessionUpdateTimeoutInMs_; - std::size_t transcriptionTimeoutInMs_; - std::size_t chunkDelayTimeInMs_; - - // Add session configuration parameters - std::string language_code_; - std::string model_name_; - int max_alternatives_; - bool automatic_punctuation_; - bool word_time_offsets_; - bool profanity_filter_; - bool verbatim_transcripts_; - std::string boosted_words_file_; - double boosted_words_score_; - bool speaker_diarization_; - int diarization_max_speakers_; - int start_history_; - double start_threshold_; - int stop_history_; - double stop_threshold_; - int stop_history_eou_; - double stop_threshold_eou_; - std::string custom_configuration_; - - // Add HTTP session data - std::string session_id_; - std::string server_url_; - }; - - class RealtimeClient : public WebSocketClientBase { - private: - - // Session tracking - bool sessionInitialized_; - bool sessionUpdated_; - std::condition_variable sessionCv_; - std::mutex sessionMutex_; - nvidia::riva::utils::PerformanceStats& perfCounter_; - - - // Event tracking - bool transcriptionCompleted_; - std::condition_variable transcriptionCv_; - std::mutex transcriptionMutex_; - - std::size_t finalTranscriptionCount_; - - // Configurable timing parameters (in milliseconds) - std::size_t connectionTimeoutInMs_; - std::size_t sessionInitTimeoutInMs_; - std::size_t sessionUpdateTimeoutInMs_; - std::size_t transcriptionTimeoutInMs_; - std::size_t chunkDelayTimeInMs_; - - std::string objectName_; - - // Audio processing - std::shared_ptr audioChunksPtr_; - - // Add session configuration - SessionConfig sessionConfig_; - - // Add HTTP session data - std::string session_id_; - std::string server_url_; - - // HTTP session initialization method - bool InitializeHttpSession(); - - // Helper method 
for HTTP requests - std::string MakeHttpRequest(const std::string& host, int port, const std::string& path, const std::string& method, const std::string& body); - - // Audio streaming methods - void SendAudioAppend(const std::string& audioBase64); - void SendAudioCommit(); - void SendAudioDone(); - - // Override base class methods - void HandleMessage(const std::string& message) override; - - public: - RealtimeClient( const std::string& objectName, - const std::shared_ptr audioChunksPtr, - nvidia::riva::utils::PerformanceStats& perfCounter); - ~RealtimeClient() = default; - - void Log(const std::string& message); - - // Timing configuration - void SetTimingConfig( const std::size_t connectionTimeoutInMs, - const std::size_t sessionInitTimeoutInMs, - const std::size_t sessionUpdateTimeoutInMs, - const std::size_t transcriptionTimeoutInMs, - const std::size_t chunkDelayTimeInMs); - - // Session configuration - void SetSessionConfig(const SessionConfig& config) { sessionConfig_ = config; } - - // Session management methods - bool InitializeSession(); - bool UpdateSessionConfig(); - - bool IsSessionInitialized() const { return sessionInitialized_; } - - // Wait methods - bool WaitForSessionUpdate(); - bool WaitForTranscriptionCompletion(); - - // WAV file processing methods - void SendAudioChunks(const bool simulateRealtime = false); - - // Add method to set server URL - void SetServerUrl(const std::string& server_url) { server_url_ = server_url; } - std::string GetSessionId() const { return session_id_; } - }; - -} // namespace nvidia::riva::realtime - -#endif // REALTIME_CLIENT_H \ No newline at end of file +class SessionConfig { + public: + std::size_t connectionTimeoutInMs_; + std::size_t sessionInitTimeoutInMs_; + std::size_t sessionUpdateTimeoutInMs_; + std::size_t transcriptionTimeoutInMs_; + std::size_t chunkDelayTimeInMs_; + + // Add session configuration parameters + std::string language_code_; + std::string model_name_; + int max_alternatives_; + bool 
automatic_punctuation_; + bool word_time_offsets_; + bool profanity_filter_; + bool verbatim_transcripts_; + std::string boosted_words_file_; + double boosted_words_score_; + bool speaker_diarization_; + int diarization_max_speakers_; + int start_history_; + double start_threshold_; + int stop_history_; + double stop_threshold_; + int stop_history_eou_; + double stop_threshold_eou_; + std::string custom_configuration_; + + // Add HTTP session data + std::string session_id_; + std::string server_url_; +}; + +class RealtimeClient : public WebSocketClientBase { + private: + // Session tracking + bool sessionInitialized_; + bool sessionUpdated_; + std::condition_variable sessionCv_; + std::mutex sessionMutex_; + nvidia::riva::utils::PerformanceStats& perfCounter_; + + + // Event tracking + bool transcriptionCompleted_; + std::condition_variable transcriptionCv_; + std::mutex transcriptionMutex_; + + std::size_t finalTranscriptionCount_; + + // Configurable timing parameters (in milliseconds) + std::size_t connectionTimeoutInMs_; + std::size_t sessionInitTimeoutInMs_; + std::size_t sessionUpdateTimeoutInMs_; + std::size_t transcriptionTimeoutInMs_; + std::size_t chunkDelayTimeInMs_; + + std::string objectName_; + + // Audio processing + std::shared_ptr audioChunksPtr_; + + // Add session configuration + SessionConfig sessionConfig_; + + // Add HTTP session data + std::string session_id_; + std::string server_url_; + + // HTTP session initialization method + bool InitializeHttpSession(); + + // Helper method for HTTP requests + std::string MakeHttpRequest( + const std::string& host, int port, const std::string& path, const std::string& method, + const std::string& body); + + // Audio streaming methods + void SendAudioAppend(const std::string& audioBase64); + void SendAudioCommit(); + void SendAudioDone(); + + // Override base class methods + void HandleMessage(const std::string& message) override; + + public: + RealtimeClient( + const std::string& objectName, const 
std::shared_ptr audioChunksPtr, + nvidia::riva::utils::PerformanceStats& perfCounter); + ~RealtimeClient() = default; + + void Log(const std::string& message); + + // Timing configuration + void SetTimingConfig( + const std::size_t connectionTimeoutInMs, const std::size_t sessionInitTimeoutInMs, + const std::size_t sessionUpdateTimeoutInMs, const std::size_t transcriptionTimeoutInMs, + const std::size_t chunkDelayTimeInMs); + + // Session configuration + void SetSessionConfig(const SessionConfig& config) { sessionConfig_ = config; } + + // Session management methods + bool InitializeSession(); + bool UpdateSessionConfig(); + + bool IsSessionInitialized() const { return sessionInitialized_; } + + // Wait methods + bool WaitForSessionUpdate(); + bool WaitForTranscriptionCompletion(); + + // WAV file processing methods + void SendAudioChunks(const bool simulateRealtime = false); + + // Public audio streaming methods for microphone input + void SendAudioAppendPublic(const std::string& audioBase64); + void SendAudioCommitPublic(); + void SendAudioDonePublic(); + + // Add method to set server URL + void SetServerUrl(const std::string& server_url) { server_url_ = server_url; } + std::string GetSessionId() const { return session_id_; } +}; + +} // namespace nvidia::riva::realtime + +#endif // REALTIME_CLIENT_H \ No newline at end of file diff --git a/riva/clients/realtime/riva_realtime_asr_client.cc b/riva/clients/realtime/riva_realtime_asr_client.cc index 853461e..2c75caa 100644 --- a/riva/clients/realtime/riva_realtime_asr_client.cc +++ b/riva/clients/realtime/riva_realtime_asr_client.cc @@ -3,74 +3,105 @@ * SPDX-License-Identifier: MIT */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include #include -#include "riva/clients/realtime/realtime_client.h" -#include "riva/utils/stats_builder/stats_builder.h" #include #include #include +#include +#include + +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include +#include "audio_chunks.h" +#include "riva/clients/realtime/realtime_client.h" +#include "riva/utils/stats_builder/stats_builder.h" + // Add these includes for HTTP functionality -#include -#include #include -#include #include +#include +#include +#include + #include using namespace nvidia::riva::utils; using namespace nvidia::riva::realtime; // Define command-line flags (matching streaming client) -DEFINE_string(audio_file, "", "Folder that contains audio files to transcribe or individual audio file name"); -DEFINE_int32(max_alternatives, 1, "Maximum number of alternative transcripts to return (up to limit configured on server)"); -DEFINE_bool(profanity_filter, false, "Flag that controls if generated transcripts should be filtered for the profane words"); +DEFINE_string( + audio_file, "", "Folder that contains audio files to transcribe or individual audio file name"); +DEFINE_int32( + max_alternatives, 1, + "Maximum number of alternative transcripts to return (up to limit configured on server)"); +DEFINE_bool( + profanity_filter, false, + "Flag that controls if generated transcripts should be filtered for the profane words"); DEFINE_bool(automatic_punctuation, true, "Flag that controls if transcript should be punctuated"); DEFINE_bool(word_time_offsets, true, "Flag that controls if word time stamps are requested"); -DEFINE_bool(simulate_realtime, false, "Flag that controls if audio files should be sent in realtime"); +DEFINE_bool( + simulate_realtime, false, "Flag that controls if audio files should be sent in realtime"); DEFINE_string(audio_device, "", "Name of audio device to use"); -DEFINE_string(riva_uri, "ws://127.0.0.1:9090/v1/realtime?intent=transcription", "URI to access riva-server"); +DEFINE_string( + riva_uri, "ws://127.0.0.1:9090/v1/realtime?intent=transcription", "URI to access riva-server"); DEFINE_int32(num_iterations, 1, "Number of times to 
loop over audio files"); DEFINE_int32(num_parallel_requests, 1, "Number of parallel requests to keep in flight"); DEFINE_int32(chunk_duration_ms, 100, "Chunk duration in milliseconds"); DEFINE_bool(print_transcripts, true, "Print final transcripts"); DEFINE_bool(interim_results, true, "Print intermediate transcripts"); -DEFINE_string(output_filename, "final_transcripts.json", "Filename of .json file containing output transcripts"); +DEFINE_string( + output_filename, "final_transcripts.json", + "Filename of .json file containing output transcripts"); DEFINE_string(model_name, "", "Name of the TRTIS model to use"); DEFINE_string(language_code, "en-US", "Language code of the model to use"); DEFINE_string(boosted_words_file, "", "File with a list of words to boost. One line per word."); DEFINE_double(boosted_words_score, 10., "Score by which to boost the boosted words"); -DEFINE_bool(verbatim_transcripts, true, "True returns text exactly as it was said with no normalization. False applies text inverse normalization"); +DEFINE_bool( + verbatim_transcripts, true, + "True returns text exactly as it was said with no normalization. False applies text inverse " + "normalization"); DEFINE_string(ssl_root_cert, "", "Path to SSL root certificates file"); DEFINE_string(ssl_client_key, "", "Path to SSL client certificates key"); DEFINE_string(ssl_client_cert, "", "Path to SSL client certificates file"); -DEFINE_bool(use_ssl, false, "Whether to use SSL credentials or not. If ssl_root_cert is specified, this is assumed to be true"); +DEFINE_bool( + use_ssl, false, + "Whether to use SSL credentials or not. 
If ssl_root_cert is specified, this is assumed to be " + "true"); DEFINE_string(metadata, "", "Comma separated key-value pair(s) of metadata to be sent to server"); -DEFINE_int32(start_history, -1, "Value (in milliseconds) to detect and initiate start of speech utterance"); -DEFINE_double(start_threshold, -1., "Threshold value to determine at what percentage start of speech is initiated"); +DEFINE_int32( + start_history, -1, "Value (in milliseconds) to detect and initiate start of speech utterance"); +DEFINE_double( + start_threshold, -1., + "Threshold value to determine at what percentage start of speech is initiated"); DEFINE_int32(stop_history, -1, "Value (in milliseconds) to detect endpoint and reset decoder"); DEFINE_double(stop_threshold, -1., "Threshold value to determine when endpoint detected"); -DEFINE_int32(stop_history_eou, -1, "Value (in milliseconds) to detect endpoint and generate an intermediate final transcript"); -DEFINE_double(stop_threshold_eou, -1., "Threshold value for likelihood of blanks before detecting end of utterance"); -DEFINE_string(custom_configuration, "", "Custom configurations to be sent to the server as key value pairs "); +DEFINE_int32( + stop_history_eou, -1, + "Value (in milliseconds) to detect endpoint and generate an intermediate final transcript"); +DEFINE_double( + stop_threshold_eou, -1., + "Threshold value for likelihood of blanks before detecting end of utterance"); +DEFINE_string( + custom_configuration, "", + "Custom configurations to be sent to the server as key value pairs "); DEFINE_bool(speaker_diarization, false, "Flag that controls if speaker diarization is requested"); -DEFINE_int32(diarization_max_speakers, 4, "Max number of speakers to detect when performing speaker diarization. Default is 4 (Max)"); +DEFINE_int32( + diarization_max_speakers, 4, + "Max number of speakers to detect when performing speaker diarization. 
Default is 4 (Max)"); DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation"); DEFINE_uint64(max_grpc_message_size, 16777216, "Max GRPC message size"); @@ -84,350 +115,499 @@ DEFINE_bool(verbose_logging, false, "Enable verbose logging"); DEFINE_bool(show_detailed_stats, true, "Show detailed statistics"); DEFINE_bool(show_tabular_stats, true, "Show tabular statistics"); +// Microphone configuration (hardcoded like ASR clients) +const int MIC_SAMPLE_RATE = 16000; // 16kHz +const int MIC_CHANNELS = 1; // Mono +const int MIC_BIT_DEPTH = 16; // 16-bit + // Global client pointer for signal handling std::vector g_clients; std::mutex g_clients_mutex; +// Global exit flag for microphone coordination (like ASR clients) +std::atomic g_request_exit(false); + // Signal handler for graceful shutdown -void signal_handler(int signal) { - for (auto client : g_clients) { - std::cout << "\nReceived signal " << signal << ", shutting down gracefully..." << std::endl; - client->Close(); - } - exit(0); +void +signal_handler(int signal) +{ + std::cout << "\nReceived signal " << signal << ", shutting down gracefully..." 
<< std::endl; + g_request_exit = true; + + for (auto client : g_clients) { + client->Close(); + } + exit(0); } // Helper function to format throughput as 10.246e00 instead of 1.0246e+01 -std::string format_throughput(double value) { - std::ostringstream oss; - oss << std::fixed << std::setprecision(3) << value << "e00"; - return oss.str(); +std::string +format_throughput(double value) +{ + std::ostringstream oss; + oss << std::fixed << std::setprecision(3) << value << "e00"; + return oss.str(); } +// Helper function to create appropriate audio chunks based on input type +std::shared_ptr +CreateAudioChunks( + const std::string& audio_file, const std::string& audio_device, int chunk_duration_ms) +{ + if (!audio_device.empty()) { + // Create microphone-based audio chunks + std::cout << "Creating microphone audio chunks for device: " << audio_device << std::endl; + std::cout << "Sample rate: " << MIC_SAMPLE_RATE << " Hz, Channels: " << MIC_CHANNELS + << ", Bit depth: " << MIC_BIT_DEPTH << std::endl; + auto mic_chunks = std::make_shared( + audio_device, chunk_duration_ms, MIC_SAMPLE_RATE, MIC_CHANNELS, MIC_BIT_DEPTH); -// Function to run the client example -void client_runner( const std::string& uri, - const std::shared_ptr& audio_chunks, - PerformanceStats& perfCounter, - const std::size_t connectionTimeoutInMs, - const std::size_t sessionInitTimeoutInMs, - const std::size_t sessionUpdateTimeoutInMs, - const std::size_t transcriptionTimeoutInMs, - const std::size_t chunkDelayTimeInMs, - const bool simulateRealtime = false) -{ - nvidia::riva::realtime::RealtimeClient client(perfCounter.GetObjectName(), audio_chunks, perfCounter); - - // Extract server URL from URI (remove ws:// and path) - std::string server_url = uri; - if (server_url.find("ws://") == 0) { - server_url = server_url.substr(5); // Remove "ws://" - } else if (server_url.find("wss://") == 0) { - server_url = server_url.substr(6); // Remove "wss://" - } - - // Remove path part (everything after first /) - 
size_t path_pos = server_url.find('/'); - if (path_pos != std::string::npos) { - server_url = server_url.substr(0, path_pos); - } - - client.SetServerUrl(server_url); - - // Set session configuration from command line flags (these will override defaults) - nvidia::riva::realtime::SessionConfig sessionConfig; - - // Only set values if they were provided by user (not default values) - if (!FLAGS_language_code.empty() && FLAGS_language_code != "en-US") { - sessionConfig.language_code_ = FLAGS_language_code; - } - if (!FLAGS_model_name.empty()) { - sessionConfig.model_name_ = FLAGS_model_name; - } - if (FLAGS_max_alternatives != 1) { - sessionConfig.max_alternatives_ = FLAGS_max_alternatives; - } - if (!FLAGS_automatic_punctuation) { // Default is true, so only override if false - sessionConfig.automatic_punctuation_ = FLAGS_automatic_punctuation; - } - if (!FLAGS_word_time_offsets) { // Default is true, so only override if false - sessionConfig.word_time_offsets_ = FLAGS_word_time_offsets; + if (!mic_chunks->Init()) { + std::cerr << "Failed to initialize microphone audio chunks" << std::endl; + return nullptr; } - if (FLAGS_profanity_filter) { // Default is false, so only override if true - sessionConfig.profanity_filter_ = FLAGS_profanity_filter; - } - if (!FLAGS_verbatim_transcripts) { // Default is true, so only override if false - sessionConfig.verbatim_transcripts_ = FLAGS_verbatim_transcripts; - } - if (!FLAGS_boosted_words_file.empty()) { - sessionConfig.boosted_words_file_ = FLAGS_boosted_words_file; - sessionConfig.boosted_words_score_ = FLAGS_boosted_words_score; - } - if (FLAGS_speaker_diarization) { // Default is false, so only override if true - sessionConfig.speaker_diarization_ = FLAGS_speaker_diarization; - sessionConfig.diarization_max_speakers_ = FLAGS_diarization_max_speakers; - } - if (FLAGS_start_history > 0) { - sessionConfig.start_history_ = FLAGS_start_history; - } - if (FLAGS_start_threshold > 0) { - sessionConfig.start_threshold_ = 
FLAGS_start_threshold; - } - if (FLAGS_stop_history > 0) { - sessionConfig.stop_history_ = FLAGS_stop_history; - } - if (FLAGS_stop_threshold > 0) { - sessionConfig.stop_threshold_ = FLAGS_stop_threshold; - } - if (FLAGS_stop_history_eou > 0) { - sessionConfig.stop_history_eou_ = FLAGS_stop_history_eou; - } - if (FLAGS_stop_threshold_eou > 0) { - sessionConfig.stop_threshold_eou_ = FLAGS_stop_threshold_eou; + + return mic_chunks; + } else if (!audio_file.empty()) { + // Create file-based audio chunks + std::cout << "Creating file audio chunks for: " << audio_file << std::endl; + + auto file_chunks = std::make_shared(audio_file, chunk_duration_ms); + + if (!file_chunks->Init()) { + std::cerr << "Failed to initialize file audio chunks" << std::endl; + return nullptr; } - if (!FLAGS_custom_configuration.empty()) { - sessionConfig.custom_configuration_ = FLAGS_custom_configuration; + + return file_chunks; + } + + std::cerr << "No audio source specified" << std::endl; + return nullptr; +} + +// Function to run the client example +void +client_runner( + const std::string& uri, + const std::shared_ptr& audio_chunks, + PerformanceStats& perfCounter, const std::size_t connectionTimeoutInMs, + const std::size_t sessionInitTimeoutInMs, const std::size_t sessionUpdateTimeoutInMs, + const std::size_t transcriptionTimeoutInMs, const std::size_t chunkDelayTimeInMs, + const bool simulateRealtime = false) +{ + nvidia::riva::realtime::RealtimeClient client( + perfCounter.GetObjectName(), audio_chunks, perfCounter); + + // Extract server URL from URI (remove ws:// and path) + std::string server_url = uri; + if (server_url.find("ws://") == 0) { + server_url = server_url.substr(5); // Remove "ws://" + } else if (server_url.find("wss://") == 0) { + server_url = server_url.substr(6); // Remove "wss://" + } + + // Remove path part (everything after first /) + size_t path_pos = server_url.find('/'); + if (path_pos != std::string::npos) { + server_url = server_url.substr(0, path_pos); + } + 
+ client.SetServerUrl(server_url); + + // Set session configuration from command line flags (these will override defaults) + nvidia::riva::realtime::SessionConfig sessionConfig; + + // Only set values if they were provided by user (not default values) + if (!FLAGS_language_code.empty() && FLAGS_language_code != "en-US") { + sessionConfig.language_code_ = FLAGS_language_code; + } + if (!FLAGS_model_name.empty()) { + sessionConfig.model_name_ = FLAGS_model_name; + } + if (FLAGS_max_alternatives != 1) { + sessionConfig.max_alternatives_ = FLAGS_max_alternatives; + } + if (!FLAGS_automatic_punctuation) { // Default is true, so only override if false + sessionConfig.automatic_punctuation_ = FLAGS_automatic_punctuation; + } + if (!FLAGS_word_time_offsets) { // Default is true, so only override if false + sessionConfig.word_time_offsets_ = FLAGS_word_time_offsets; + } + if (FLAGS_profanity_filter) { // Default is false, so only override if true + sessionConfig.profanity_filter_ = FLAGS_profanity_filter; + } + if (!FLAGS_verbatim_transcripts) { // Default is true, so only override if false + sessionConfig.verbatim_transcripts_ = FLAGS_verbatim_transcripts; + } + if (!FLAGS_boosted_words_file.empty()) { + sessionConfig.boosted_words_file_ = FLAGS_boosted_words_file; + sessionConfig.boosted_words_score_ = FLAGS_boosted_words_score; + } + if (FLAGS_speaker_diarization) { // Default is false, so only override if true + sessionConfig.speaker_diarization_ = FLAGS_speaker_diarization; + sessionConfig.diarization_max_speakers_ = FLAGS_diarization_max_speakers; + } + if (FLAGS_start_history > 0) { + sessionConfig.start_history_ = FLAGS_start_history; + } + if (FLAGS_start_threshold > 0) { + sessionConfig.start_threshold_ = FLAGS_start_threshold; + } + if (FLAGS_stop_history > 0) { + sessionConfig.stop_history_ = FLAGS_stop_history; + } + if (FLAGS_stop_threshold > 0) { + sessionConfig.stop_threshold_ = FLAGS_stop_threshold; + } + if (FLAGS_stop_history_eou > 0) { + 
sessionConfig.stop_history_eou_ = FLAGS_stop_history_eou; + } + if (FLAGS_stop_threshold_eou > 0) { + sessionConfig.stop_threshold_eou_ = FLAGS_stop_threshold_eou; + } + if (!FLAGS_custom_configuration.empty()) { + sessionConfig.custom_configuration_ = FLAGS_custom_configuration; + } + + client.SetSessionConfig(sessionConfig); + client.SetVerboseLogging(FLAGS_verbose_logging); + client.SetTimingConfig( + connectionTimeoutInMs, sessionInitTimeoutInMs, sessionUpdateTimeoutInMs, + transcriptionTimeoutInMs, chunkDelayTimeInMs); + + // Step 1: Connect to the WebSocket server + client.Connect(uri); + + std::thread client_thread([&client]() { client.Run(); }); + + // Step 2: Wait for the connection to be established + if (!client.WaitForConnection()) { + std::cerr << "Failed to establish WebSocket connection" << std::endl; + client.Close(); + client_thread.join(); + return; + } + + std::cout << "WebSocket connection established" << std::endl; + + // Step 3: Initialize the session + if (!client.InitializeSession()) { + std::cerr << "Failed to initialize session" << std::endl; + client.Close(); + client_thread.join(); + return; + } + + std::cout << "Waiting for session update confirmation..." << std::endl; + + // Step 4: Wait for the session to be updated + if (!client.WaitForSessionUpdate()) { + std::cerr << "Session update timeout" << std::endl; + client.Close(); + client_thread.join(); + return; + } + + // Step 5: Send the audio chunks with realistic timing + perfCounter.StartProcessingTimer(); + perfCounter.SetAudioDurationInSeconds(audio_chunks->GetDurationSeconds()); + + // For microphone input, we need to start capture before sending chunks + if (auto mic_chunks = std::dynamic_pointer_cast(audio_chunks)) { + std::cout << "Starting microphone capture..." 
<< std::endl; + + // Ensure microphone is stopped on early exit + auto mic_cleanup = [mic_chunks]() { + if (mic_chunks->IsCapturing()) { + mic_chunks->StopCapture(); + std::cout << "Stopped microphone capture (cleanup)" << std::endl; + } + }; + + if (!mic_chunks->StartCapture()) { + std::cerr << "Failed to start microphone capture" << std::endl; + mic_cleanup(); + client.Close(); + client_thread.join(); + return; } - - client.SetSessionConfig(sessionConfig); - client.SetVerboseLogging(FLAGS_verbose_logging); - client.SetTimingConfig(connectionTimeoutInMs, sessionInitTimeoutInMs, sessionUpdateTimeoutInMs, transcriptionTimeoutInMs, chunkDelayTimeInMs); - - // Step 1: Connect to the WebSocket server - client.Connect(uri); - - std::thread client_thread([&client]() { - client.Run(); + + // For microphone: start continuous audio streaming in background thread + std::cout << "Starting continuous audio streaming..." << std::endl; + + // Start continuous audio streaming in a separate thread (like ASR clients) + std::thread audio_thread([&client, mic_chunks, simulateRealtime]() { + // Continuous streaming loop for microphone input + while (!g_request_exit && mic_chunks->IsCapturing()) { + // Get the latest audio chunk from microphone + std::string latest_chunk = mic_chunks->GetLatestChunk(); + if (!latest_chunk.empty()) { + // Send the audio chunk to the server + client.SendAudioAppendPublic(latest_chunk); + client.SendAudioCommitPublic(); + } + + // Small delay to prevent busy waiting + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + // Send audio done when streaming ends + if (!g_request_exit) { + client.SendAudioDonePublic(); + } }); - - // Step 2: Wait for the connection to be established - if (!client.WaitForConnection()) { - std::cerr << "Failed to establish WebSocket connection" << std::endl; - client.Close(); - client_thread.join(); - return; - } - - std::cout << "WebSocket connection established" << std::endl; - - // Step 3: Initialize the 
session - if (!client.InitializeSession()) { - std::cerr << "Failed to initialize session" << std::endl; - client.Close(); - client_thread.join(); - return; - } - - std::cout << "Waiting for session update confirmation..." << std::endl; - - // Step 4: Wait for the session to be updated - if (!client.WaitForSessionUpdate()) { - std::cerr << "Session update timeout" << std::endl; - client.Close(); - client_thread.join(); - return; + + // Keep microphone running while waiting for transcription completion + // The microphone will continue capturing until transcription completes or exit is requested + std::cout << "Microphone is now active. Press Ctrl+C to stop." << std::endl; + + // For microphone input, wait for user interruption or transcription completion + // The audio thread will continue running until g_request_exit is set + while (!g_request_exit && mic_chunks->IsCapturing()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } - - // Step 5: Send the audio chunks with realistic timing - perfCounter.StartProcessingTimer(); - perfCounter.SetAudioDurationInSeconds(audio_chunks->GetDurationSeconds()); - - // Send chunks with realistic timing + + // Wait for audio transmission to complete + audio_thread.join(); + + } else { + // For file-based audio, send chunks normally client.SendAudioChunks(simulateRealtime); - - std::cout << "Waiting for transcription completion..." << std::endl; - - // Step 6: Wait for the transcription to be completed - if (client.WaitForTranscriptionCompletion()) { - std::cout << "Transcription completed successfully!" 
<< std::endl; - perfCounter.EndProcessingTimer(); - perfCounter.SetSuccess(true); - } else { - std::cout << "Transcription did not complete within timeout" << std::endl; - perfCounter.EndProcessingTimer(); - } - - // Step 7: Close the WebSocket connection - client.Close(); - client_thread.join(); - - { - std::lock_guard lock(g_clients_mutex); - g_clients.push_back(&client); - } + } + + std::cout << "Waiting for transcription completion..." << std::endl; + + // Step 6: Wait for the transcription to be completed + if (client.WaitForTranscriptionCompletion()) { + std::cout << "Transcription completed successfully!" << std::endl; + perfCounter.EndProcessingTimer(); + perfCounter.SetSuccess(true); + } else { + std::cout << "Transcription did not complete within timeout" << std::endl; + perfCounter.EndProcessingTimer(); + } + + // Step 6.5: Stop microphone capture if it was used (after transcription completes) + if (auto mic_chunks = std::dynamic_pointer_cast(audio_chunks)) { + // Set exit flag to stop the audio streaming thread + g_request_exit = true; - // Step 8: Report the stats - perfCounter.ReportStats(); + // Stop microphone capture + mic_chunks->StopCapture(); + std::cout << "Stopped microphone capture" << std::endl; + } + + // Step 7: Close the WebSocket connection + client.Close(); + client_thread.join(); + + { + std::lock_guard lock(g_clients_mutex); + g_clients.push_back(&client); + } + + // Step 8: Report the stats + perfCounter.ReportStats(); } -int main(int argc, char* argv[]) { - google::InitGoogleLogging(argv[0]); - FLAGS_logtostderr = 1; - - // Set up usage message - std::stringstream str_usage; - str_usage << "Usage: riva_realtime_asr_client " << std::endl; - str_usage << " --audio_file= " << std::endl; - str_usage << " --audio_device= " << std::endl; - str_usage << " --automatic_punctuation=" << std::endl; - str_usage << " --max_alternatives=" << std::endl; - str_usage << " --profanity_filter=" << std::endl; - str_usage << " --word_time_offsets=" << 
std::endl; - str_usage << " --riva_uri= " << std::endl; - str_usage << " --chunk_duration_ms= " << std::endl; - str_usage << " --interim_results= " << std::endl; - str_usage << " --simulate_realtime= " << std::endl; - str_usage << " --num_iterations= " << std::endl; - str_usage << " --num_parallel_requests= " << std::endl; - str_usage << " --print_transcripts= " << std::endl; - str_usage << " --output_filename=" << std::endl; - str_usage << " --verbatim_transcripts=" << std::endl; - str_usage << " --language_code=" << std::endl; - str_usage << " --boosted_words_file=" << std::endl; - str_usage << " --boosted_words_score=" << std::endl; - str_usage << " --ssl_root_cert=" << std::endl; - str_usage << " --ssl_client_key=" << std::endl; - str_usage << " --ssl_client_cert=" << std::endl; - str_usage << " --model_name=" << std::endl; - str_usage << " --metadata=" << std::endl; - str_usage << " --start_history=" << std::endl; - str_usage << " --start_threshold=" << std::endl; - str_usage << " --stop_history=" << std::endl; - str_usage << " --stop_history_eou=" << std::endl; - str_usage << " --stop_threshold=" << std::endl; - str_usage << " --stop_threshold_eou=" << std::endl; - str_usage << " --custom_configuration=" << std::endl; - str_usage << " --speaker_diarization=" << std::endl; - str_usage << " --diarization_max_speakers=" << std::endl; - str_usage << " --timeout_ms=" << std::endl; - str_usage << " --max_grpc_message_size=" << std::endl; - str_usage << " --connection_timeout_ms=" << std::endl; - str_usage << " --session_init_timeout_ms=" << std::endl; - str_usage << " --session_update_timeout_ms=" << std::endl; - str_usage << " --transcription_timeout_ms=" << std::endl; - str_usage << " --chunk_delay_time_ms=" << std::endl; - str_usage << " --verbose_logging=" << std::endl; - str_usage << " --show_detailed_stats=" << std::endl; - str_usage << " --show_tabular_stats=" << std::endl; - - gflags::SetUsageMessage(str_usage.str()); - - if (argc < 2) { - std::cout << 
gflags::ProgramUsage(); - return 1; - } +int +main(int argc, char* argv[]) +{ + google::InitGoogleLogging(argv[0]); + FLAGS_logtostderr = 1; - gflags::ParseCommandLineFlags(&argc, &argv, true); + // Set up usage message + std::stringstream str_usage; + str_usage << "Usage: riva_realtime_asr_client " << std::endl; + str_usage << " --audio_file= " << std::endl; + str_usage << " --audio_device= " << std::endl; + str_usage << " --automatic_punctuation=" << std::endl; + str_usage << " --max_alternatives=" << std::endl; + str_usage << " --profanity_filter=" << std::endl; + str_usage << " --word_time_offsets=" << std::endl; + str_usage << " --riva_uri= " << std::endl; + str_usage << " --chunk_duration_ms= " << std::endl; + str_usage << " --interim_results= " << std::endl; + str_usage << " --simulate_realtime= " << std::endl; + str_usage << " --num_iterations= " << std::endl; + str_usage << " --num_parallel_requests= " << std::endl; + str_usage << " --print_transcripts= " << std::endl; + str_usage << " --output_filename=" << std::endl; + str_usage << " --verbatim_transcripts=" << std::endl; + str_usage << " --language_code=" << std::endl; + str_usage << " --boosted_words_file=" << std::endl; + str_usage << " --boosted_words_score=" << std::endl; + str_usage << " --ssl_root_cert=" << std::endl; + str_usage << " --ssl_client_key=" << std::endl; + str_usage << " --ssl_client_cert=" << std::endl; + str_usage << " --model_name=" << std::endl; + str_usage << " --metadata=" << std::endl; + str_usage << " --start_history=" << std::endl; + str_usage << " --start_threshold=" << std::endl; + str_usage << " --stop_history=" << std::endl; + str_usage << " --stop_history_eou=" << std::endl; + str_usage << " --stop_threshold=" << std::endl; + str_usage << " --stop_threshold_eou=" << std::endl; + str_usage << " --custom_configuration=" << std::endl; + str_usage << " --speaker_diarization=" << std::endl; + str_usage << " --diarization_max_speakers=" << std::endl; + str_usage << " 
--timeout_ms=" << std::endl; + str_usage << " --max_grpc_message_size=" << std::endl; + str_usage << " --connection_timeout_ms=" << std::endl; + str_usage << " --session_init_timeout_ms=" << std::endl; + str_usage << " --session_update_timeout_ms=" << std::endl; + str_usage << " --transcription_timeout_ms=" << std::endl; + str_usage << " --chunk_delay_time_ms=" << std::endl; + str_usage << " --verbose_logging=" << std::endl; + str_usage << " --show_detailed_stats=" << std::endl; + str_usage << " --show_tabular_stats=" << std::endl; + // Note: Microphone configuration is hardcoded (16kHz, mono, 16-bit) like ASR clients - if (argc > 1) { - std::cout << gflags::ProgramUsage(); - return 1; - } + gflags::SetUsageMessage(str_usage.str()); - // Validate arguments - if (FLAGS_max_alternatives < 1) { - std::cerr << "max_alternatives must be greater than or equal to 1." << std::endl; - return 1; - } + if (argc < 2) { + std::cout << gflags::ProgramUsage(); + return 1; + } - if (FLAGS_num_iterations < 1) { - std::cerr << "num_iterations must be greater than 0" << std::endl; - return 1; - } + gflags::ParseCommandLineFlags(&argc, &argv, true); - if (FLAGS_num_parallel_requests < 1) { - std::cerr << "num_parallel_requests must be greater than 0" << std::endl; - return 1; - } + if (argc > 1) { + std::cout << gflags::ProgramUsage(); + return 1; + } - // Check if audio file or device is specified - if (FLAGS_audio_file.empty() && FLAGS_audio_device.empty()) { - std::cerr << "Either --audio_file or --audio_device must be specified" << std::endl; - return 1; - } + // Validate arguments + if (FLAGS_max_alternatives < 1) { + std::cerr << "max_alternatives must be greater than or equal to 1." 
<< std::endl; + return 1; + } - // Validate audio file exists if specified - if (!FLAGS_audio_file.empty() && !std::filesystem::exists(FLAGS_audio_file)) { - std::cerr << "Audio file does not exist: " << FLAGS_audio_file << std::endl; - return 1; - } + if (FLAGS_num_iterations < 1) { + std::cerr << "num_iterations must be greater than 0" << std::endl; + return 1; + } + + if (FLAGS_num_parallel_requests < 1) { + std::cerr << "num_parallel_requests must be greater than 0" << std::endl; + return 1; + } + + // Check if audio file or device is specified + if (FLAGS_audio_file.empty() && FLAGS_audio_device.empty()) { + std::cerr << "Either --audio_file or --audio_device must be specified" << std::endl; + return 1; + } - // Use command-line arguments - const std::string uri = FLAGS_riva_uri; - const std::string audio_file_path = FLAGS_audio_file; - const std::size_t num_iterations = FLAGS_num_iterations; - const std::size_t num_parallel_clients = FLAGS_num_parallel_requests; - const bool simulateRealtime = FLAGS_simulate_realtime; - - const std::size_t connectionTimeoutInMs = FLAGS_connection_timeout_ms; - const std::size_t sessionInitTimeoutInMs = FLAGS_session_init_timeout_ms; - const std::size_t sessionUpdateTimeoutInMs = FLAGS_session_update_timeout_ms; - const std::size_t transcriptionTimeoutInMs = FLAGS_transcription_timeout_ms; - const std::size_t chunkDelayTimeInMs = FLAGS_chunk_delay_time_ms; - const std::size_t chunk_duration_ms = FLAGS_chunk_duration_ms; - - const auto audio_chunks = std::make_shared(audio_file_path, chunk_duration_ms); - if (!audio_chunks->Init()) { - std::cerr << "Failed to initialize audio chunks" << std::endl; - return 1; + // Validate audio file exists if specified + if (!FLAGS_audio_file.empty() && !std::filesystem::exists(FLAGS_audio_file)) { + std::cerr << "Audio file does not exist: " << FLAGS_audio_file << std::endl; + return 1; + } + + // Validate microphone parameters (using hardcoded values like ASR clients) + if 
(!FLAGS_audio_device.empty()) { + // No validation needed since we use hardcoded values + // MIC_SAMPLE_RATE = 16000, MIC_CHANNELS = 1, MIC_BIT_DEPTH = 16 + + // For microphone input, enforce single request and iteration + if (FLAGS_num_parallel_requests != 1) { + std::cout << "Warning: num_parallel_requests set to 1 for microphone input" << std::endl; + FLAGS_num_parallel_requests = 1; } - - PerformanceStats overallPerf("Overall"); - - overallPerf.SetAudioDurationInSeconds(audio_chunks->GetDurationSeconds() * num_iterations * num_parallel_clients); - - // Create StatsBuilder for all clients - StatsBuilder statsBuilder("client", audio_chunks->GetDurationSeconds(), num_parallel_clients); - - // Run iterations asynchronously - std::vector> futures; - std::cout << "Starting " << num_parallel_clients << " async clients..." << std::endl; - - overallPerf.StartProcessingTimer(); - for (std::size_t N = 0; N < num_parallel_clients; ++N) { - // Launch each client asynchronously - futures.emplace_back(std::async(std::launch::async, [&, N]() { - std::cout << "Starting client " << (N + 1) << "/" << num_parallel_clients << std::endl; - - for (std::size_t M = 0; M < num_iterations; ++M) { - std::cout << " Running iteration " << (M + 1) << "/" << num_iterations << std::endl; - client_runner( uri, - audio_chunks, - statsBuilder.GetPerformanceStats(N), - connectionTimeoutInMs, - sessionInitTimeoutInMs, - sessionUpdateTimeoutInMs, - transcriptionTimeoutInMs, - chunkDelayTimeInMs, - simulateRealtime); - } - - std::cout << "Completed client " << (N + 1) << "/" << num_parallel_clients << std::endl; - })); + if (FLAGS_num_iterations != 1) { + std::cout << "Warning: num_iterations set to 1 for microphone input" << std::endl; + FLAGS_num_iterations = 1; } - - // Wait for all iterations to complete - std::cout << "Waiting for all iterations to complete..." 
<< std::endl; - for (auto& future : futures) { - future.wait(); + if (FLAGS_simulate_realtime) { + std::cout << "Warning: simulate_realtime set to false for microphone input" << std::endl; + FLAGS_simulate_realtime = false; } - std::cout << "All iterations completed!" << std::endl; - overallPerf.EndProcessingTimer(); + } - // Set up signal handlers for graceful shutdown - signal(SIGINT, signal_handler); - signal(SIGTERM, signal_handler); + // Use command-line arguments + const std::string uri = FLAGS_riva_uri; + const std::string audio_file_path = FLAGS_audio_file; + const std::string audio_device = FLAGS_audio_device; + const std::size_t num_iterations = FLAGS_num_iterations; + const std::size_t num_parallel_clients = FLAGS_num_parallel_requests; + const bool simulateRealtime = FLAGS_simulate_realtime; - // Conditional stats reporting based on flags - if (FLAGS_show_detailed_stats) { - statsBuilder.ReportDetailedStats(); - } - if (FLAGS_show_tabular_stats) { - statsBuilder.ReportTabularStats(); - } - - statsBuilder.ReportCumulativeStats(); - overallPerf.ReportStats(); - return 0; -} \ No newline at end of file + const std::size_t connectionTimeoutInMs = FLAGS_connection_timeout_ms; + const std::size_t sessionInitTimeoutInMs = FLAGS_session_init_timeout_ms; + const std::size_t sessionUpdateTimeoutInMs = FLAGS_session_update_timeout_ms; + const std::size_t transcriptionTimeoutInMs = FLAGS_transcription_timeout_ms; + const std::size_t chunkDelayTimeInMs = FLAGS_chunk_delay_time_ms; + const std::size_t chunk_duration_ms = FLAGS_chunk_duration_ms; + + // Create appropriate audio chunks based on input type + const auto audio_chunks = CreateAudioChunks(audio_file_path, audio_device, chunk_duration_ms); + if (!audio_chunks) { + std::cerr << "Failed to create audio chunks" << std::endl; + return 1; + } + + PerformanceStats overallPerf("Overall"); + + // For microphone input, duration is ongoing, so we'll use a reasonable estimate + double audio_duration = 
audio_chunks->GetDurationSeconds(); + if (audio_duration <= 0.0 && !audio_device.empty()) { + // Microphone input - use a reasonable duration estimate for stats + audio_duration = 60.0; // Assume 1 minute for microphone sessions + std::cout << "Using estimated duration of " << audio_duration << " seconds for microphone input" + << std::endl; + } + + overallPerf.SetAudioDurationInSeconds(audio_duration * num_iterations * num_parallel_clients); + + // Set up signal handlers for graceful shutdown (before starting async operations) + signal(SIGINT, signal_handler); + signal(SIGTERM, signal_handler); + + // Create StatsBuilder for all clients + StatsBuilder statsBuilder("client", audio_duration, num_parallel_clients); + + // Run iterations asynchronously + std::vector> futures; + std::cout << "Starting " << num_parallel_clients << " async clients..." << std::endl; + + overallPerf.StartProcessingTimer(); + for (std::size_t N = 0; N < num_parallel_clients; ++N) { + // Launch each client asynchronously + futures.emplace_back(std::async(std::launch::async, [&, N]() { + std::cout << "Starting client " << (N + 1) << "/" << num_parallel_clients << std::endl; + + for (std::size_t M = 0; M < num_iterations; ++M) { + std::cout << " Running iteration " << (M + 1) << "/" << num_iterations << std::endl; + client_runner( + uri, audio_chunks, statsBuilder.GetPerformanceStats(N), connectionTimeoutInMs, + sessionInitTimeoutInMs, sessionUpdateTimeoutInMs, transcriptionTimeoutInMs, + chunkDelayTimeInMs, simulateRealtime); + } + + std::cout << "Completed client " << (N + 1) << "/" << num_parallel_clients << std::endl; + })); + } + + // Wait for all iterations to complete + std::cout << "Waiting for all iterations to complete..." << std::endl; + for (auto& future : futures) { + future.wait(); + } + std::cout << "All iterations completed!" 
<< std::endl; + overallPerf.EndProcessingTimer(); + + // Conditional stats reporting based on flags + if (FLAGS_show_detailed_stats) { + statsBuilder.ReportDetailedStats(); + } + if (FLAGS_show_tabular_stats) { + statsBuilder.ReportTabularStats(); + } + + statsBuilder.ReportCumulativeStats(); + overallPerf.ReportStats(); + return 0; +} \ No newline at end of file diff --git a/riva/utils/stats_builder/stats_builder.cpp b/riva/utils/stats_builder/stats_builder.cpp index 22794d1..6484d75 100644 --- a/riva/utils/stats_builder/stats_builder.cpp +++ b/riva/utils/stats_builder/stats_builder.cpp @@ -2,283 +2,357 @@ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: MIT */ - + #include "stats_builder.h" -#include + #include +#include #include #include namespace nvidia::riva::utils { -PerformanceStats::PerformanceStats(const std::string& objectName) - : success_(false), - objectName_(objectName), +PerformanceStats::PerformanceStats(const std::string& objectName) + : success_(false), objectName_(objectName), processing_start_time_(std::chrono::steady_clock::now()), - processing_end_time_(std::chrono::steady_clock::now()), - audio_duration_seconds_(0.0) {} - -StatsBuilder::StatsBuilder(const std::string& objectName, double audio_duration_seconds, std::size_t num_iterations) - : audio_duration_seconds_(audio_duration_seconds), num_iterations_(num_iterations), object_name_(objectName) { - // Pre-allocate the vector with the expected number of iterations - performanceStats_.reserve(num_iterations); - - // Create PerformanceStats objects for each iteration - for (std::size_t i = 0; i < num_iterations; ++i) { - std::string iteration_name = objectName + "-" + std::to_string(i); - performanceStats_.emplace_back(iteration_name); - // Set the audio duration for each performance stats object - performanceStats_.back().SetAudioDurationInSeconds(audio_duration_seconds); - } + 
processing_end_time_(std::chrono::steady_clock::now()), audio_duration_seconds_(0.0) +{ } -void PerformanceStats::StartProcessingTimer() { - processing_start_time_ = std::chrono::steady_clock::now(); - //std::cout << "Starting processing timer: " << std::chrono::duration_cast(processing_start_time_.time_since_epoch()).count() << std::endl; - } - -void PerformanceStats::EndProcessingTimer() { - processing_end_time_ = std::chrono::steady_clock::now(); - //std::cout << "Ending processing timer: " << std::chrono::duration_cast(processing_end_time_.time_since_epoch()).count() << std::endl; - } - -double PerformanceStats::GetRuntimeInMs() const { - auto durationInMs = std::chrono::duration_cast( - processing_end_time_ - processing_start_time_); - return durationInMs.count(); +StatsBuilder::StatsBuilder( + const std::string& objectName, double audio_duration_seconds, std::size_t num_iterations) + : audio_duration_seconds_(audio_duration_seconds), num_iterations_(num_iterations), + object_name_(objectName) +{ + // Pre-allocate the vector with the expected number of iterations + performanceStats_.reserve(num_iterations); + + // Create PerformanceStats objects for each iteration + for (std::size_t i = 0; i < num_iterations; ++i) { + std::string iteration_name = objectName + "-" + std::to_string(i); + performanceStats_.emplace_back(iteration_name); + // Set the audio duration for each performance stats object + performanceStats_.back().SetAudioDurationInSeconds(audio_duration_seconds); + } } -double PerformanceStats::GetRuntimeInSeconds() const { - return GetRuntimeInMs() / 1000.0; +void +PerformanceStats::StartProcessingTimer() +{ + processing_start_time_ = std::chrono::steady_clock::now(); + // std::cout << "Starting processing timer: " << + // std::chrono::duration_cast(processing_start_time_.time_since_epoch()).count() + // << std::endl; } - -void PerformanceStats::SetAudioDurationInSeconds(double audio_duration_seconds) { - audio_duration_seconds_ = 
audio_duration_seconds; + +void +PerformanceStats::EndProcessingTimer() +{ + processing_end_time_ = std::chrono::steady_clock::now(); + // std::cout << "Ending processing timer: " << + // std::chrono::duration_cast(processing_end_time_.time_since_epoch()).count() + // << std::endl; } - -double PerformanceStats::GetThroughputRTFX() const { - double runtimeInMs = GetRuntimeInMs(); - if (runtimeInMs > 0.0 && audio_duration_seconds_ > 0.0) { - // RTFX = (Total Audio Processed in seconds) × 1000 ÷ (Total Runtime in milliseconds) - return (audio_duration_seconds_ * 1000.0) / runtimeInMs; - } - return 0.0; - } -void PerformanceStats::SetObjectName(const std::string& objectName) { - objectName_ = objectName; +double +PerformanceStats::GetRuntimeInMs() const +{ + auto durationInMs = std::chrono::duration_cast( + processing_end_time_ - processing_start_time_); + return durationInMs.count(); } -std::string PerformanceStats::GetObjectName() const { - return objectName_; +double +PerformanceStats::GetRuntimeInSeconds() const +{ + return GetRuntimeInMs() / 1000.0; } -void PerformanceStats::ReportStats() { - std::cout << "Object Name: " << GetObjectName() << std::endl; - std::cout << "Success: " << IsSuccess() << std::endl; - std::cout << "Audio Duration: " << audio_duration_seconds_ << " seconds" << std::endl; - std::cout << "Total Runtime: " << GetRuntimeInMs() << " ms (" << GetRuntimeInSeconds() << " seconds)" << std::endl; - std::cout << "Throughput: " << GetThroughputRTFX() << " RTFX" << std::endl; +void +PerformanceStats::SetAudioDurationInSeconds(double audio_duration_seconds) +{ + audio_duration_seconds_ = audio_duration_seconds; } +double +PerformanceStats::GetThroughputRTFX() const +{ + double runtimeInMs = GetRuntimeInMs(); + if (runtimeInMs > 0.0 && audio_duration_seconds_ > 0.0) { + // RTFX = (Total Audio Processed in seconds) × 1000 ÷ (Total Runtime in milliseconds) + return (audio_duration_seconds_ * 1000.0) / runtimeInMs; + } + return 0.0; +} +void 
+PerformanceStats::SetObjectName(const std::string& objectName) +{ + objectName_ = objectName; +} -void StatsBuilder::ReportCumulativeStats() { - std::cout << "Cumulative Stats" << std::endl; - std::cout << "=================" << std::endl; - for (auto performanceStats : performanceStats_) { - std::cout << "Object Name: " << performanceStats.GetObjectName() << std::endl; - std::cout << "Total Runtime: " << performanceStats.GetRuntimeInMs() << " ms (" << performanceStats.GetRuntimeInSeconds() << " seconds)" << std::endl; - std::cout << "Throughput: " << performanceStats.GetThroughputRTFX() << " RTFX" << std::endl; - } +std::string +PerformanceStats::GetObjectName() const +{ + return objectName_; +} + +void +PerformanceStats::ReportStats() +{ + std::cout << "Object Name: " << GetObjectName() << std::endl; + std::cout << "Success: " << IsSuccess() << std::endl; + std::cout << "Audio Duration: " << audio_duration_seconds_ << " seconds" << std::endl; + std::cout << "Total Runtime: " << GetRuntimeInMs() << " ms (" << GetRuntimeInSeconds() + << " seconds)" << std::endl; + std::cout << "Throughput: " << GetThroughputRTFX() << " RTFX" << std::endl; +} + + +void +StatsBuilder::ReportCumulativeStats() +{ + std::cout << "Cumulative Stats" << std::endl; + std::cout << "=================" << std::endl; + for (auto performanceStats : performanceStats_) { + std::cout << "Object Name: " << performanceStats.GetObjectName() << std::endl; + std::cout << "Total Runtime: " << performanceStats.GetRuntimeInMs() << " ms (" + << performanceStats.GetRuntimeInSeconds() << " seconds)" << std::endl; + std::cout << "Throughput: " << performanceStats.GetThroughputRTFX() << " RTFX" << std::endl; + } } // Helper function to calculate percentile -double CalculatePercentile(const std::vector& values, double percentile) { - if (values.empty()) return 0.0; - - std::vector sorted_values = values; - std::sort(sorted_values.begin(), sorted_values.end()); - - double index = (percentile / 100.0) * 
(sorted_values.size() - 1); - int lower_index = static_cast(index); - int upper_index = lower_index + 1; - - if (upper_index >= sorted_values.size()) { - return sorted_values[lower_index]; - } - - double weight = index - lower_index; - return sorted_values[lower_index] * (1 - weight) + sorted_values[upper_index] * weight; +double +CalculatePercentile(const std::vector& values, double percentile) +{ + if (values.empty()) + return 0.0; + + std::vector sorted_values = values; + std::sort(sorted_values.begin(), sorted_values.end()); + + double index = (percentile / 100.0) * (sorted_values.size() - 1); + int lower_index = static_cast(index); + int upper_index = lower_index + 1; + + if (upper_index >= sorted_values.size()) { + return sorted_values[lower_index]; + } + + double weight = index - lower_index; + return sorted_values[lower_index] * (1 - weight) + sorted_values[upper_index] * weight; } // Statistical methods for runtime -double StatsBuilder::GetAverageRuntime() const { - if (performanceStats_.empty()) return 0.0; - - double sum = 0.0; - for (const auto& stats : performanceStats_) { - sum += stats.GetRuntimeInMs(); - } - return sum / performanceStats_.size(); +double +StatsBuilder::GetAverageRuntime() const +{ + if (performanceStats_.empty()) + return 0.0; + + double sum = 0.0; + for (const auto& stats : performanceStats_) { + sum += stats.GetRuntimeInMs(); + } + return sum / performanceStats_.size(); } -double StatsBuilder::GetP50Runtime() const { - std::vector runtimes; - for (const auto& stats : performanceStats_) { - runtimes.push_back(stats.GetRuntimeInMs()); - } - return CalculatePercentile(runtimes, 50.0); +double +StatsBuilder::GetP50Runtime() const +{ + std::vector runtimes; + for (const auto& stats : performanceStats_) { + runtimes.push_back(stats.GetRuntimeInMs()); + } + return CalculatePercentile(runtimes, 50.0); } -double StatsBuilder::GetP90Runtime() const { - std::vector runtimes; - for (const auto& stats : performanceStats_) { - 
runtimes.push_back(stats.GetRuntimeInMs()); - } - return CalculatePercentile(runtimes, 90.0); +double +StatsBuilder::GetP90Runtime() const +{ + std::vector runtimes; + for (const auto& stats : performanceStats_) { + runtimes.push_back(stats.GetRuntimeInMs()); + } + return CalculatePercentile(runtimes, 90.0); } -double StatsBuilder::GetP95Runtime() const { - std::vector runtimes; - for (const auto& stats : performanceStats_) { - runtimes.push_back(stats.GetRuntimeInMs()); - } - return CalculatePercentile(runtimes, 95.0); +double +StatsBuilder::GetP95Runtime() const +{ + std::vector runtimes; + for (const auto& stats : performanceStats_) { + runtimes.push_back(stats.GetRuntimeInMs()); + } + return CalculatePercentile(runtimes, 95.0); } -double StatsBuilder::GetP99Runtime() const { - std::vector runtimes; - for (const auto& stats : performanceStats_) { - runtimes.push_back(stats.GetRuntimeInMs()); - } - return CalculatePercentile(runtimes, 99.0); +double +StatsBuilder::GetP99Runtime() const +{ + std::vector runtimes; + for (const auto& stats : performanceStats_) { + runtimes.push_back(stats.GetRuntimeInMs()); + } + return CalculatePercentile(runtimes, 99.0); } -double StatsBuilder::GetMinRuntime() const { - if (performanceStats_.empty()) return 0.0; - - double min_runtime = performanceStats_[0].GetRuntimeInMs(); - for (const auto& stats : performanceStats_) { - min_runtime = std::min(min_runtime, stats.GetRuntimeInMs()); - } - return min_runtime; +double +StatsBuilder::GetMinRuntime() const +{ + if (performanceStats_.empty()) + return 0.0; + + double min_runtime = performanceStats_[0].GetRuntimeInMs(); + for (const auto& stats : performanceStats_) { + min_runtime = std::min(min_runtime, stats.GetRuntimeInMs()); + } + return min_runtime; } -double StatsBuilder::GetMaxRuntime() const { - if (performanceStats_.empty()) return 0.0; - - double max_runtime = performanceStats_[0].GetRuntimeInMs(); - for (const auto& stats : performanceStats_) { - max_runtime = 
std::max(max_runtime, stats.GetRuntimeInMs()); - } - return max_runtime; +double +StatsBuilder::GetMaxRuntime() const +{ + if (performanceStats_.empty()) + return 0.0; + + double max_runtime = performanceStats_[0].GetRuntimeInMs(); + for (const auto& stats : performanceStats_) { + max_runtime = std::max(max_runtime, stats.GetRuntimeInMs()); + } + return max_runtime; } // Statistical methods for throughput -double StatsBuilder::GetAverageThroughput() const { - if (performanceStats_.empty()) return 0.0; - - double sum = 0.0; - for (const auto& stats : performanceStats_) { - sum += stats.GetThroughputRTFX(); - } - return sum / performanceStats_.size(); +double +StatsBuilder::GetAverageThroughput() const +{ + if (performanceStats_.empty()) + return 0.0; + + double sum = 0.0; + for (const auto& stats : performanceStats_) { + sum += stats.GetThroughputRTFX(); + } + return sum / performanceStats_.size(); } // Statistical methods for throughput -double StatsBuilder::GetCumulativeThroughput() const { - if (performanceStats_.empty()) return 0.0; - - double sum = 0.0; - for (const auto& stats : performanceStats_) { - sum += stats.GetThroughputRTFX(); - } - return sum; +double +StatsBuilder::GetCumulativeThroughput() const +{ + if (performanceStats_.empty()) + return 0.0; + + double sum = 0.0; + for (const auto& stats : performanceStats_) { + sum += stats.GetThroughputRTFX(); + } + return sum; } -double StatsBuilder::GetP90Throughput() const { - std::vector throughputs; - for (const auto& stats : performanceStats_) { - throughputs.push_back(stats.GetThroughputRTFX()); - } - return CalculatePercentile(throughputs, 90.0); +double +StatsBuilder::GetP90Throughput() const +{ + std::vector throughputs; + for (const auto& stats : performanceStats_) { + throughputs.push_back(stats.GetThroughputRTFX()); + } + return CalculatePercentile(throughputs, 90.0); } -double StatsBuilder::GetP95Throughput() const { - std::vector throughputs; - for (const auto& stats : performanceStats_) { - 
throughputs.push_back(stats.GetThroughputRTFX()); - } - return CalculatePercentile(throughputs, 95.0); +double +StatsBuilder::GetP95Throughput() const +{ + std::vector throughputs; + for (const auto& stats : performanceStats_) { + throughputs.push_back(stats.GetThroughputRTFX()); + } + return CalculatePercentile(throughputs, 95.0); } -double StatsBuilder::GetP99Throughput() const { - std::vector throughputs; - for (const auto& stats : performanceStats_) { - throughputs.push_back(stats.GetThroughputRTFX()); - } - return CalculatePercentile(throughputs, 99.0); +double +StatsBuilder::GetP99Throughput() const +{ + std::vector throughputs; + for (const auto& stats : performanceStats_) { + throughputs.push_back(stats.GetThroughputRTFX()); + } + return CalculatePercentile(throughputs, 99.0); } -bool StatsBuilder::AreAllIterationsSuccessful() const { - if (performanceStats_.empty()) return false; - - for (const auto& stats : performanceStats_) { - if (!stats.IsSuccess()) { - return false; - } +bool +StatsBuilder::AreAllIterationsSuccessful() const +{ + if (performanceStats_.empty()) + return false; + + for (const auto& stats : performanceStats_) { + if (!stats.IsSuccess()) { + return false; } - return true; + } + return true; } -std::size_t StatsBuilder::GetSuccessfulIterationsCount() const { - std::size_t success_count = 0; - for (const auto& stats : performanceStats_) { - if (stats.IsSuccess()) { - success_count++; - } +std::size_t +StatsBuilder::GetSuccessfulIterationsCount() const +{ + std::size_t success_count = 0; + for (const auto& stats : performanceStats_) { + if (stats.IsSuccess()) { + success_count++; } - return success_count; + } + return success_count; } -std::size_t StatsBuilder::GetFailedIterationsCount() const { - return performanceStats_.size() - GetSuccessfulIterationsCount(); +std::size_t +StatsBuilder::GetFailedIterationsCount() const +{ + return performanceStats_.size() - GetSuccessfulIterationsCount(); } -double StatsBuilder::GetSuccessRate() const { 
- if (performanceStats_.empty()) return 0.0; - return static_cast(GetSuccessfulIterationsCount()) / performanceStats_.size() * 100.0; +double +StatsBuilder::GetSuccessRate() const +{ + if (performanceStats_.empty()) + return 0.0; + return static_cast(GetSuccessfulIterationsCount()) / performanceStats_.size() * 100.0; } -void StatsBuilder::ReportDetailedStats() const { - std::cout << "\n=== DETAILED PERFORMANCE STATISTICS ===" << std::endl; - std::cout << "Audio Duration: " << audio_duration_seconds_ << " seconds" << std::endl; - std::cout << "Number of Iterations: " << num_iterations_ << std::endl; - std::cout << "Sample Count: " << performanceStats_.size() << std::endl; - - // Add success rate information - std::cout << "Success Rate: " << GetSuccessRate() << "% (" << GetSuccessfulIterationsCount() - << "/" << performanceStats_.size() << " iterations)" << std::endl; - std::cout << "All Iterations Successful: " << (AreAllIterationsSuccessful() ? "YES" : "NO") << std::endl; - - std::cout << "\n--- RUNTIME STATISTICS (ms) ---" << std::endl; - std::cout << "Average: " << GetAverageRuntime() << " ms" << std::endl; - std::cout << "P50: " << GetP50Runtime() << " ms" << std::endl; - std::cout << "P90: " << GetP90Runtime() << " ms" << std::endl; - std::cout << "P95: " << GetP95Runtime() << " ms" << std::endl; - std::cout << "P99: " << GetP99Runtime() << " ms" << std::endl; - std::cout << "Min: " << GetMinRuntime() << " ms" << std::endl; - std::cout << "Max: " << GetMaxRuntime() << " ms" << std::endl; - - std::cout << "\n--- THROUGHPUT STATISTICS (RTFX) ---" << std::endl; - std::cout << "Average: " << GetAverageThroughput() << " RTFX" << std::endl; - std::cout << "Cumulative: " << GetCumulativeThroughput() << " RTFX" << std::endl; - std::cout << "P90: " << GetP90Throughput() << " RTFX" << std::endl; - std::cout << "P95: " << GetP95Throughput() << " RTFX" << std::endl; - std::cout << "P99: " << GetP99Throughput() << " RTFX" << std::endl; - - std::cout << 
"=====================================" << std::endl; +void +StatsBuilder::ReportDetailedStats() const +{ + std::cout << "\n=== DETAILED PERFORMANCE STATISTICS ===" << std::endl; + std::cout << "Audio Duration: " << audio_duration_seconds_ << " seconds" << std::endl; + std::cout << "Number of Iterations: " << num_iterations_ << std::endl; + std::cout << "Sample Count: " << performanceStats_.size() << std::endl; + + // Add success rate information + std::cout << "Success Rate: " << GetSuccessRate() << "% (" << GetSuccessfulIterationsCount() + << "/" << performanceStats_.size() << " iterations)" << std::endl; + std::cout << "All Iterations Successful: " << (AreAllIterationsSuccessful() ? "YES" : "NO") + << std::endl; + + std::cout << "\n--- RUNTIME STATISTICS (ms) ---" << std::endl; + std::cout << "Average: " << GetAverageRuntime() << " ms" << std::endl; + std::cout << "P50: " << GetP50Runtime() << " ms" << std::endl; + std::cout << "P90: " << GetP90Runtime() << " ms" << std::endl; + std::cout << "P95: " << GetP95Runtime() << " ms" << std::endl; + std::cout << "P99: " << GetP99Runtime() << " ms" << std::endl; + std::cout << "Min: " << GetMinRuntime() << " ms" << std::endl; + std::cout << "Max: " << GetMaxRuntime() << " ms" << std::endl; + + std::cout << "\n--- THROUGHPUT STATISTICS (RTFX) ---" << std::endl; + std::cout << "Average: " << GetAverageThroughput() << " RTFX" << std::endl; + std::cout << "Cumulative: " << GetCumulativeThroughput() << " RTFX" << std::endl; + std::cout << "P90: " << GetP90Throughput() << " RTFX" << std::endl; + std::cout << "P95: " << GetP95Throughput() << " RTFX" << std::endl; + std::cout << "P99: " << GetP99Throughput() << " RTFX" << std::endl; + + std::cout << "=====================================" << std::endl; } } // namespace nvidia::riva::utils \ No newline at end of file diff --git a/riva/utils/stats_builder/stats_builder.h b/riva/utils/stats_builder/stats_builder.h index 5cfa89e..2a1c273 100644 --- 
a/riva/utils/stats_builder/stats_builder.h +++ b/riva/utils/stats_builder/stats_builder.h @@ -6,141 +6,136 @@ #ifndef STATS_BUILDER_H #define STATS_BUILDER_H -#include -#include -#include -#include #include +#include +#include // Required for std::setw and std::fixed +#include #include -#include // Required for std::setw and std::fixed +#include +#include namespace nvidia::riva::utils { class PerformanceStats { - private: - bool success_; - std::string objectName_; - // Timing measurement - std::chrono::steady_clock::time_point processing_start_time_; - std::chrono::steady_clock::time_point processing_end_time_; - double audio_duration_seconds_; - - public: - PerformanceStats(const std::string& objectName); - ~PerformanceStats() = default; - - bool IsSuccess() const { return success_; } - void SetSuccess(bool success) { success_ = success; } - - void StartProcessingTimer(); - void EndProcessingTimer(); - std::chrono::steady_clock::time_point GetStartTime() const { return processing_start_time_; } - double GetRuntimeInMs() const; - double GetRuntimeInSeconds() const; - void SetAudioDurationInSeconds(double audio_duration_seconds); - double GetAudioDurationInSeconds() const { return audio_duration_seconds_; } - double GetThroughputRTFX() const; - - void SetObjectName(const std::string& objectName); - std::string GetObjectName() const; - - void ReportStats(); + private: + bool success_; + std::string objectName_; + // Timing measurement + std::chrono::steady_clock::time_point processing_start_time_; + std::chrono::steady_clock::time_point processing_end_time_; + double audio_duration_seconds_; + + public: + PerformanceStats(const std::string& objectName); + ~PerformanceStats() = default; + + bool IsSuccess() const { return success_; } + void SetSuccess(bool success) { success_ = success; } + + void StartProcessingTimer(); + void EndProcessingTimer(); + std::chrono::steady_clock::time_point GetStartTime() const { return processing_start_time_; } + double 
GetRuntimeInMs() const; + double GetRuntimeInSeconds() const; + void SetAudioDurationInSeconds(double audio_duration_seconds); + double GetAudioDurationInSeconds() const { return audio_duration_seconds_; } + double GetThroughputRTFX() const; + + void SetObjectName(const std::string& objectName); + std::string GetObjectName() const; + + void ReportStats(); }; class StatsBuilder { - private: - std::vector performanceStats_; - double audio_duration_seconds_; - std::size_t num_iterations_; - std::string object_name_; // Added to store the object name - - public: - StatsBuilder(const std::string& objectName, double audio_duration_seconds, std::size_t num_iterations); - ~StatsBuilder() = default; - - void SetAudioDurationInSeconds(double audio_duration_seconds); - void SetNumIterations(std::size_t num_iterations); - void ReportCumulativeStats(); - PerformanceStats& GetPerformanceStats(std::size_t index) { return performanceStats_[index]; } - - // Statistical methods - double GetAverageRuntime() const; - double GetP50Runtime() const; - double GetP90Runtime() const; - double GetP95Runtime() const; - double GetP99Runtime() const; - double GetMinRuntime() const; - double GetMaxRuntime() const; - - // Throughput statistics - double GetAverageThroughput() const; - double GetCumulativeThroughput() const; - double GetP90Throughput() const; - double GetP95Throughput() const; - double GetP99Throughput() const; - - // Comprehensive reporting - void ReportDetailedStats() const; - - // Success checking methods - bool AreAllIterationsSuccessful() const; - std::size_t GetSuccessfulIterationsCount() const; - std::size_t GetFailedIterationsCount() const; - double GetSuccessRate() const; - - void ReportTabularStats() const { - std::cout << "\n=== Tabular Performance Statistics ===" << std::endl; - std::cout << std::left - << std::setw(15) << "Name" - << std::setw(10) << "Success" - << std::setw(12) << "Runtime (s)" - << std::setw(15) << "Audio (s)" - << std::setw(15) << "Throughput" - << 
std::endl; - std::cout << std::string(75, '-') << std::endl; - - for (size_t i = 0; i < performanceStats_.size(); ++i) { - const auto& stats = performanceStats_[i]; - std::string name = object_name_ + "-" + std::to_string(i); - std::string success = stats.IsSuccess() ? "true" : "false"; - double runtime = stats.GetRuntimeInSeconds(); // Changed to GetRuntimeInSeconds - double audio_duration = audio_duration_seconds_; // Total audio processed - double throughput = stats.GetThroughputRTFX(); // Changed to GetThroughputRTFX - - std::cout << std::left - << std::setw(15) << name - << std::setw(10) << success - << std::fixed << std::setprecision(3) - << std::setw(12) << runtime - << std::setw(15) << audio_duration - << std::setw(15) << throughput - << std::endl; - } - std::cout << std::string(60, '-') << std::endl; - - // Summary row - size_t success_count = 0; - double total_runtime = 0.0; - double total_audio_processed = audio_duration_seconds_ * performanceStats_.size(); // Total audio across all iterations - double total_throughput = 0.0; - - for (const auto& stats : performanceStats_) { - if (stats.IsSuccess()) success_count++; - total_runtime += stats.GetRuntimeInSeconds(); // Changed to GetRuntimeInSeconds - total_throughput += stats.GetThroughputRTFX(); // Changed to GetThroughputRTFX - } - - std::cout << std::left - << std::setw(15) << "SUMMARY" - << std::setw(10) << (success_count == performanceStats_.size() ? 
"ALL" : std::to_string(success_count) + "/" + std::to_string(performanceStats_.size())) - << std::fixed << std::setprecision(3) - << std::setw(12) << total_runtime - << std::setw(15) << total_audio_processed - << std::setw(15) << total_throughput - << std::endl; - std::cout << std::endl; - } - }; + private: + std::vector performanceStats_; + double audio_duration_seconds_; + std::size_t num_iterations_; + std::string object_name_; // Added to store the object name + + public: + StatsBuilder( + const std::string& objectName, double audio_duration_seconds, std::size_t num_iterations); + ~StatsBuilder() = default; + + void SetAudioDurationInSeconds(double audio_duration_seconds); + void SetNumIterations(std::size_t num_iterations); + void ReportCumulativeStats(); + PerformanceStats& GetPerformanceStats(std::size_t index) { return performanceStats_[index]; } + + // Statistical methods + double GetAverageRuntime() const; + double GetP50Runtime() const; + double GetP90Runtime() const; + double GetP95Runtime() const; + double GetP99Runtime() const; + double GetMinRuntime() const; + double GetMaxRuntime() const; + + // Throughput statistics + double GetAverageThroughput() const; + double GetCumulativeThroughput() const; + double GetP90Throughput() const; + double GetP95Throughput() const; + double GetP99Throughput() const; + + // Comprehensive reporting + void ReportDetailedStats() const; + + // Success checking methods + bool AreAllIterationsSuccessful() const; + std::size_t GetSuccessfulIterationsCount() const; + std::size_t GetFailedIterationsCount() const; + double GetSuccessRate() const; + + void ReportTabularStats() const + { + std::cout << "\n=== Tabular Performance Statistics ===" << std::endl; + std::cout << std::left << std::setw(15) << "Name" << std::setw(10) << "Success" << std::setw(12) + << "Runtime (s)" << std::setw(15) << "Audio (s)" << std::setw(15) << "Throughput" + << std::endl; + std::cout << std::string(75, '-') << std::endl; + + for (size_t i = 0; i < 
performanceStats_.size(); ++i) { + const auto& stats = performanceStats_[i]; + std::string name = object_name_ + "-" + std::to_string(i); + std::string success = stats.IsSuccess() ? "true" : "false"; + double runtime = stats.GetRuntimeInSeconds(); // Changed to GetRuntimeInSeconds + double audio_duration = audio_duration_seconds_; // Total audio processed + double throughput = stats.GetThroughputRTFX(); // Changed to GetThroughputRTFX + + std::cout << std::left << std::setw(15) << name << std::setw(10) << success << std::fixed + << std::setprecision(3) << std::setw(12) << runtime << std::setw(15) + << audio_duration << std::setw(15) << throughput << std::endl; + } + std::cout << std::string(60, '-') << std::endl; + + // Summary row + size_t success_count = 0; + double total_runtime = 0.0; + double total_audio_processed = + audio_duration_seconds_ * performanceStats_.size(); // Total audio across all iterations + double total_throughput = 0.0; + + for (const auto& stats : performanceStats_) { + if (stats.IsSuccess()) + success_count++; + total_runtime += stats.GetRuntimeInSeconds(); // Changed to GetRuntimeInSeconds + total_throughput += stats.GetThroughputRTFX(); // Changed to GetThroughputRTFX + } + + std::cout << std::left << std::setw(15) << "SUMMARY" << std::setw(10) + << (success_count == performanceStats_.size() + ? "ALL" + : std::to_string(success_count) + "/" + + std::to_string(performanceStats_.size())) + << std::fixed << std::setprecision(3) << std::setw(12) << total_runtime + << std::setw(15) << total_audio_processed << std::setw(15) << total_throughput + << std::endl; + std::cout << std::endl; + } +}; } // namespace nvidia::riva::utils