diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md new file mode 100644 index 000000000000..a1c84d3585fc --- /dev/null +++ b/IMPLEMENTATION_PLAN.md @@ -0,0 +1,79 @@ +# WebSocket Mode Implementation Plan for OpenAI Responses API + +## Overview +Implement WebSocket support for LocalAI's OpenAI API-compatible Responses endpoint, enabling persistent WebSocket connections for long-running, tool-call-heavy agentic workflows. + +## Technical Requirements + +### 1. WebSocket Endpoint +- **Endpoint**: `ws://<host>:<port>/v1/responses` +- **Upgrade**: HTTP upgrade via GET /v1/responses when `Upgrade: websocket` header is present (WebSocket handshakes use the GET method per RFC 6455) + +### 2. Message Types (Client → Server) + +#### response.create (Initial Turn) +```json +{ + "type": "response.create", + "model": "gpt-4o", + "store": false, + "input": [...], + "tools": [] +} +``` + +#### response.create with Continuation (Subsequent Turns) +```json +{ + "type": "response.create", + "model": "gpt-4o", + "store": false, + "previous_response_id": "resp_123", + "input": [...], + "tools": [] +} +``` + +### 3. Response Events (Server → Client) + +1. **response.created** - Response object created +2. **response.progress** - Incremental output +3. **response.function_call_arguments.delta** - Streaming function arguments +4. **response.function_call_arguments.done** - Function call complete +5. **response.done** - Final response + +### 4. Connection Management +- Track active connections with 60-minute timeout +- Connection-local cache for responses (when store=false) +- One in-flight response at a time per connection + +### 5. 
Error Handling +- `previous_response_not_found` (400) +- `websocket_connection_limit_reached` (400) + +## Implementation Steps + +### Step 1: Add WebSocket Schema Types +- Add WebSocket message types to `core/schema/openresponses.go` +- Add connection-related types + +### Step 2: Add WebSocket Route +- Modify `core/http/routes/openresponses.go` to handle WebSocket upgrade +- Add GET /v1/responses WebSocket endpoint + +### Step 3: Create WebSocket Handler +- Create `core/http/endpoints/openresponses/websocket.go` +- Implement connection handling +- Implement message parsing +- Implement event streaming + +### Step 4: Add Connection Store +- Implement connection management in store +- Add 60-minute timeout +- Add connection-local cache + +## Files to Modify/Create +1. `core/schema/openresponses.go` - Add WebSocket types +2. `core/http/routes/openresponses.go` - Add WebSocket route +3. `core/http/endpoints/openresponses/websocket.go` - New WebSocket handler (create) +4. `core/http/endpoints/openresponses/store.go` - Add connection management diff --git a/backend/python/nemo/backend.py b/backend/python/nemo/backend.py index fd2218f695e4..e6ebaeff36a4 100644 --- a/backend/python/nemo/backend.py +++ b/backend/python/nemo/backend.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -gRPC server of LocalAI for NVIDIA NEMO Toolkit ASR. +GRPC server of LocalAI for NVIDIA NEMO Toolkit ASR. 
""" from concurrent import futures import time @@ -12,6 +12,14 @@ import backend_pb2_grpc import torch import nemo.collections.asr as nemo_asr +import numpy as np + +try: + import torchaudio + TORCHAUDIO_AVAILABLE = True +except ImportError: + TORCHAUDIO_AVAILABLE = False + print("[WARNING] torchaudio not available, will use fallback audio loading", file=sys.stderr) import grpc @@ -36,6 +44,50 @@ def is_int(s): MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) +def load_audio_np(audio_path, target_sample_rate=16000): + """Load audio file as numpy array using available methods.""" + if TORCHAUDIO_AVAILABLE: + try: + waveform, sample_rate = torchaudio.load(audio_path) + # Convert to mono if stereo + if waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + # Resample if needed + if sample_rate != target_sample_rate: + resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate) + waveform = resampler(waveform) + # Convert to numpy + audio_np = waveform.squeeze().numpy() + return audio_np, target_sample_rate + except Exception as e: + print(f"[WARNING] torchaudio loading failed: {e}, trying fallback", file=sys.stderr) + + # Fallback: try using scipy or soundfile + try: + import soundfile as sf + audio_np, sample_rate = sf.read(audio_path) + if audio_np.ndim > 1: + audio_np = audio_np.mean(axis=1) + if sample_rate != target_sample_rate: + from scipy.signal import resample + num_samples = int(len(audio_np) * target_sample_rate / sample_rate) + audio_np = resample(audio_np, num_samples) + return audio_np, target_sample_rate + except ImportError: + pass + + try: + from scipy.io import wavfile + sample_rate, audio_np = wavfile.read(audio_path) + if audio_np.ndim > 1: + audio_np = audio_np.mean(axis=1) + return audio_np, sample_rate + except ImportError: + pass + + raise RuntimeError("No audio loading library available (torchaudio, soundfile, scipy)") + + class BackendServicer(backend_pb2_grpc.BackendServicer): def 
Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", 'utf-8')) @@ -89,14 +141,37 @@ def AudioTranscription(self, request, context): print(f"Error: Audio file not found: {audio_path}", file=sys.stderr) return backend_pb2.TranscriptResult(segments=[], text="") - # NEMO's transcribe method accepts a list of audio paths and returns a list of transcripts - results = self.model.transcribe([audio_path]) - + # Load audio as numpy array to avoid lhotse dataloader issues + audio_np, sample_rate = load_audio_np(audio_path, target_sample_rate=16000) + + # Convert to torch tensor + audio_tensor = torch.from_numpy(audio_np).float() + audio_tensor = audio_tensor.unsqueeze(0) # Add batch dimension + + # Use the model's transcribe method with the tensor directly + # Some NEMO models accept audio tensors directly + try: + # Try passing the waveform tensor directly + results = self.model.transcribe(audio_tensor, return_char_alignments=False) + except TypeError: + # Fallback: try with dict format + results = self.model.transcribe( + [{"audio_file": audio_path}], + return_char_alignments=False + ) + if not results or len(results) == 0: + print("[WARNING] No transcription results returned", file=sys.stderr) return backend_pb2.TranscriptResult(segments=[], text="") # Get the transcript text from the first result - text = results[0] + if isinstance(results, list) and len(results) > 0: + text = results[0] + elif isinstance(results, dict) and "text" in results: + text = results["text"] + else: + text = str(results) if results else "" + if text: # Create a single segment with the full transcription result_segments.append(backend_pb2.TranscriptSegment( diff --git a/backend/python/nemo/requirements.txt b/backend/python/nemo/requirements.txt index f18110b3ffcd..3a51811a0c6e 100644 --- a/backend/python/nemo/requirements.txt +++ b/backend/python/nemo/requirements.txt @@ -4,3 +4,7 @@ certifi packaging==24.1 setuptools pyarrow==20.0.0 +torchaudio +soundfile +scipy +numpy 
diff --git a/core/http/routes/openresponses.go b/core/http/routes/openresponses.go index 19cadbbae677..d1a5fa4f18b2 100644 --- a/core/http/routes/openresponses.go +++ b/core/http/routes/openresponses.go @@ -43,6 +43,18 @@ func RegisterOpenResponsesRoutes(app *echo.Echo, cancelResponseHandler := openresponses.CancelResponseEndpoint() app.POST("/v1/responses/:id/cancel", cancelResponseHandler, middleware.TraceMiddleware(application)) app.POST("/responses/:id/cancel", cancelResponseHandler, middleware.TraceMiddleware(application)) + + // WebSocket endpoint for OpenAI Responses API WebSocket Mode + websocketHandler := openresponses.WebSocketEndpoint( + application.ModelConfigLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ) + + // WebSocket at /v1/responses (GET method for upgrade) + app.GET("/v1/responses", websocketHandler, middleware.TraceMiddleware(application)) + app.GET("/responses", websocketHandler, middleware.TraceMiddleware(application)) } // setOpenResponsesRequestContext sets up the context and cancel function for Open Responses requests diff --git a/core/schema/openresponses.go b/core/schema/openresponses.go index b5a81f413362..ddcb11b0ccf6 100644 --- a/core/schema/openresponses.go +++ b/core/schema/openresponses.go @@ -1,6 +1,7 @@ package schema import ( + "time" "context" ) @@ -309,3 +310,72 @@ func ORContentPartWithLogprobs(text string, logprobs *Logprobs) ORContentPart { Logprobs: orLogprobs, // REQUIRED - must always be present as array (empty if none) } } + +// WebSocket message types for Open Responses API WebSocket Mode +// https://developers.openai.com/api/docs/guides/websocket-mode + +// ORWebSocketMessage represents a WebSocket message (client -> server or server -> client) +type ORWebSocketMessage struct { + Type string `json:"type"` // response.create, response.created, response.progress, etc. 
+} + +// ORWebSocketClientMessage represents a client message to the WebSocket endpoint +type ORWebSocketClientMessage struct { + Type string `json:"type"` // "response.create" + Model string `json:"model,omitempty"` + Input interface{} `json:"input,omitempty"` + Tools []ORFunctionTool `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Truncation string `json:"truncation,omitempty"` + Instructions string `json:"instructions,omitempty"` + Reasoning *ORReasoningParam `json:"reasoning,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + Store *bool `json:"store,omitempty"` + TextFormat interface{} `json:"text_format,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + AllowedTools []string `json:"allowed_tools,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` + MaxToolCalls *int `json:"max_tool_calls,omitempty"` + Generate *bool `json:"generate,omitempty"` // If false, just warm up and return response_id +} + +// ORWebSocketServerEvent represents a server event to the WebSocket +type ORWebSocketServerEvent struct { + Type string `json:"type"` // response.created, response.progress, etc. 
+ ResponseID string `json:"response_id,omitempty"` + Response *ORResponseResource `json:"response,omitempty"` + OutputIndex *int `json:"output_index,omitempty"` + Output []ORItemField `json:"output,omitempty"` + ItemID string `json:"item_id,omitempty"` + Item *ORItemField `json:"item,omitempty"` + ContentIndex *int `json:"content_index,omitempty"` + Delta *string `json:"delta,omitempty"` + Text *string `json:"text,omitempty"` + CallID string `json:"call_id,omitempty"` + Arguments *string `json:"arguments,omitempty"` + Error *ORError `json:"error,omitempty"` +} + +// ORWebSocketError represents a WebSocket error event +type ORWebSocketError struct { + Type string `json:"type"` // error + Code string `json:"code,omitempty"` // previous_response_not_found, websocket_connection_limit_reached, etc. + Message string `json:"message"` + Param string `json:"param,omitempty"` +} + +// ConnectionLocalCacheEntry represents a cached response in connection-local storage +type ConnectionLocalCacheEntry struct { + ResponseID string + Response *ORResponseResource + Input *ORWebSocketClientMessage + CachedAt time.Time + ExpiresAt *time.Time +}