Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions IMPLEMENTATION_PLAN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# WebSocket Mode Implementation Plan for OpenAI Responses API

## Overview
Implement WebSocket support for LocalAI's OpenAI API-compatible Responses endpoint, enabling persistent WebSocket connections for long-running, tool-call-heavy agentic workflows.

## Technical Requirements

### 1. WebSocket Endpoint
- **Endpoint**: `ws://<host>:<port>/v1/responses`
- **Upgrade**: HTTP upgrade via GET /v1/responses when `Upgrade: websocket` header is present (WebSocket handshakes are GET requests per RFC 6455)

### 2. Message Types (Client → Server)

#### response.create (Initial Turn)
```json
{
"type": "response.create",
"model": "gpt-4o",
"store": false,
"input": [...],
"tools": []
}
```

#### response.create with Continuation (Subsequent Turns)
```json
{
"type": "response.create",
"model": "gpt-4o",
"store": false,
"previous_response_id": "resp_123",
"input": [...],
"tools": []
}
```

### 3. Response Events (Server → Client)

1. **response.created** - Response object created
2. **response.progress** - Incremental output
3. **response.function_call_arguments.delta** - Streaming function arguments
4. **response.function_call_arguments.done** - Function call complete
5. **response.done** - Final response

### 4. Connection Management
- Track active connections with 60-minute timeout
- Connection-local cache for responses (when store=false)
- One in-flight response at a time per connection

### 5. Error Handling
- `previous_response_not_found` (400)
- `websocket_connection_limit_reached` (400)

## Implementation Steps

### Step 1: Add WebSocket Schema Types
- Add WebSocket message types to `core/schema/openresponses.go`
- Add connection-related types

### Step 2: Add WebSocket Route
- Modify `core/http/routes/openresponses.go` to handle WebSocket upgrade
- Add GET /v1/responses WebSocket endpoint

### Step 3: Create WebSocket Handler
- Create `core/http/endpoints/openresponses/websocket.go`
- Implement connection handling
- Implement message parsing
- Implement event streaming

### Step 4: Add Connection Store
- Implement connection management in store
- Add 60-minute timeout
- Add connection-local cache

## Files to Modify/Create
1. `core/schema/openresponses.go` - Add WebSocket types
2. `core/http/routes/openresponses.go` - Add WebSocket route
3. `core/http/endpoints/openresponses/websocket.go` - New WebSocket handler (create)
4. `core/http/endpoints/openresponses/store.go` - Add connection management
85 changes: 80 additions & 5 deletions backend/python/nemo/backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
"""
gRPC server of LocalAI for NVIDIA NEMO Toolkit ASR.
gRPC server of LocalAI for NVIDIA NEMO Toolkit ASR.
"""
from concurrent import futures
import time
Expand All @@ -12,6 +12,14 @@
import backend_pb2_grpc
import torch
import nemo.collections.asr as nemo_asr
import numpy as np

try:
import torchaudio
TORCHAUDIO_AVAILABLE = True
except ImportError:
TORCHAUDIO_AVAILABLE = False
print("[WARNING] torchaudio not available, will use fallback audio loading", file=sys.stderr)

import grpc

Expand All @@ -36,6 +44,50 @@ def is_int(s):
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


def load_audio_np(audio_path, target_sample_rate=16000):
    """Load an audio file as a mono float numpy array at ``target_sample_rate``.

    Tries torchaudio first, then soundfile (with scipy for resampling),
    then scipy's WAV reader.  All paths return float samples normalized to
    [-1.0, 1.0] and resampled to the target rate, so downstream ASR always
    sees a consistent format (the original scipy.io.wavfile fallback leaked
    raw integer PCM at the file's native rate).

    Args:
        audio_path: Path to an audio file on disk.
        target_sample_rate: Desired output sample rate in Hz (default 16000).

    Returns:
        Tuple ``(audio_np, sample_rate)`` where ``audio_np`` is a 1-D float
        numpy array and ``sample_rate == target_sample_rate``.

    Raises:
        RuntimeError: If none of torchaudio/soundfile/scipy is installed.
    """
    # Preferred path: torchaudio decodes many container formats and ships a
    # high-quality resampler.  Imported lazily so the fallbacks still work
    # when torchaudio is absent.
    try:
        import torchaudio
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            # Down-mix to mono if the file has multiple channels.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            if sample_rate != target_sample_rate:
                resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
                waveform = resampler(waveform)
            return waveform.squeeze().numpy(), target_sample_rate
        except Exception as e:
            print(f"[WARNING] torchaudio loading failed: {e}, trying fallback", file=sys.stderr)
    except ImportError:
        pass

    # Fallback 1: soundfile for decoding (returns float64 in [-1, 1]),
    # scipy only when resampling is actually needed.
    try:
        import soundfile as sf
        audio_np, sample_rate = sf.read(audio_path)
        if audio_np.ndim > 1:
            audio_np = audio_np.mean(axis=1)
        if sample_rate != target_sample_rate:
            from scipy.signal import resample
            num_samples = int(len(audio_np) * target_sample_rate / sample_rate)
            audio_np = resample(audio_np, num_samples)
        return audio_np.astype(np.float32), target_sample_rate
    except ImportError:
        pass

    # Fallback 2: scipy's WAV reader (WAV only).  It yields raw integer PCM
    # at the file's native rate, so normalize and resample here to keep the
    # function's contract consistent with the paths above.
    try:
        from scipy.io import wavfile
        from scipy.signal import resample
        sample_rate, audio_np = wavfile.read(audio_path)
        if np.issubdtype(audio_np.dtype, np.integer):
            # Scale integer PCM into [-1.0, 1.0] before any float math.
            audio_np = audio_np.astype(np.float32) / np.iinfo(audio_np.dtype).max
        else:
            audio_np = audio_np.astype(np.float32)
        if audio_np.ndim > 1:
            audio_np = audio_np.mean(axis=1)
        if sample_rate != target_sample_rate:
            num_samples = int(len(audio_np) * target_sample_rate / sample_rate)
            audio_np = resample(audio_np, num_samples)
        return audio_np, target_sample_rate
    except ImportError:
        pass

    raise RuntimeError("No audio loading library available (torchaudio, soundfile, scipy)")


class BackendServicer(backend_pb2_grpc.BackendServicer):
def Health(self, request, context):
    """gRPC liveness probe; always replies with a fixed OK payload."""
    reply_payload = b"OK"
    return backend_pb2.Reply(message=reply_payload)
Expand Down Expand Up @@ -89,14 +141,37 @@ def AudioTranscription(self, request, context):
print(f"Error: Audio file not found: {audio_path}", file=sys.stderr)
return backend_pb2.TranscriptResult(segments=[], text="")

# NEMO's transcribe method accepts a list of audio paths and returns a list of transcripts
results = self.model.transcribe([audio_path])

# Load audio as numpy array to avoid lhotse dataloader issues
audio_np, sample_rate = load_audio_np(audio_path, target_sample_rate=16000)

# Convert to torch tensor
audio_tensor = torch.from_numpy(audio_np).float()
audio_tensor = audio_tensor.unsqueeze(0) # Add batch dimension

# Use the model's transcribe method with the tensor directly
# Some NEMO models accept audio tensors directly
try:
# Try passing the waveform tensor directly
results = self.model.transcribe(audio_tensor, return_char_alignments=False)
except TypeError:
# Fallback: try with dict format
results = self.model.transcribe(
[{"audio_file": audio_path}],
return_char_alignments=False
)

if not results or len(results) == 0:
print("[WARNING] No transcription results returned", file=sys.stderr)
return backend_pb2.TranscriptResult(segments=[], text="")

# Get the transcript text from the first result
text = results[0]
if isinstance(results, list) and len(results) > 0:
text = results[0]
elif isinstance(results, dict) and "text" in results:
text = results["text"]
else:
text = str(results) if results else ""

if text:
# Create a single segment with the full transcription
result_segments.append(backend_pb2.TranscriptSegment(
Expand Down
4 changes: 4 additions & 0 deletions backend/python/nemo/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ certifi
packaging==24.1
setuptools
pyarrow==20.0.0
torchaudio
soundfile
scipy
numpy
12 changes: 12 additions & 0 deletions core/http/routes/openresponses.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ func RegisterOpenResponsesRoutes(app *echo.Echo,
cancelResponseHandler := openresponses.CancelResponseEndpoint()
app.POST("/v1/responses/:id/cancel", cancelResponseHandler, middleware.TraceMiddleware(application))
app.POST("/responses/:id/cancel", cancelResponseHandler, middleware.TraceMiddleware(application))

// WebSocket endpoint for OpenAI Responses API WebSocket Mode
websocketHandler := openresponses.WebSocketEndpoint(
application.ModelConfigLoader(),
application.ModelLoader(),
application.TemplatesEvaluator(),
application.ApplicationConfig(),
)

// WebSocket at /v1/responses (GET method for upgrade)
app.GET("/v1/responses", websocketHandler, middleware.TraceMiddleware(application))
app.GET("/responses", websocketHandler, middleware.TraceMiddleware(application))
}

// setOpenResponsesRequestContext sets up the context and cancel function for Open Responses requests
Expand Down
70 changes: 70 additions & 0 deletions core/schema/openresponses.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package schema

import (
"time"
"context"
)

Expand All @@ -17,7 +18,7 @@
// OpenResponsesRequest represents a request to the Open Responses API
// https://www.openresponses.org/specification
type OpenResponsesRequest struct {
Model string `json:"model"`

Check failure on line 21 in core/schema/openresponses.go

View workflow job for this annotation

GitHub Actions / tests-apple (1.25.x)

other declaration of ORWebSocketMessage

Check failure on line 21 in core/schema/openresponses.go

View workflow job for this annotation

GitHub Actions / backend-jobs-darwin (nemo, -metal-darwin-arm64-nemo, mps) / darwin-backend-build (1.24.x)

other declaration of ORWebSocketMessage
Input interface{} `json:"input"` // string or []ORItemParam
Tools []ORFunctionTool `json:"tools,omitempty"`
ToolChoice interface{} `json:"tool_choice,omitempty"` // "auto"|"required"|"none"|{type:"function",name:"..."}
Expand Down Expand Up @@ -309,3 +310,72 @@
Logprobs: orLogprobs, // REQUIRED - must always be present as array (empty if none)
}
}

// WebSocket message types for Open Responses API WebSocket Mode
// https://developers.openai.com/api/docs/guides/websocket-mode

// ORWebSocketMessage represents a bare WebSocket message envelope
// (client -> server or server -> client); Type discriminates the payload.
// NOTE(review): CI reports "ORWebSocketMessage redeclared in this block" —
// confirm this type is not already declared elsewhere in this file and
// remove the duplicate declaration.
type ORWebSocketMessage struct {
	Type string `json:"type"` // response.create, response.created, response.progress, etc.
}

// ORWebSocketClientMessage represents a client message to the WebSocket
// endpoint. It carries the same generation parameters as an HTTP
// OpenResponsesRequest plus WebSocket-mode controls (PreviousResponseID,
// Generate).
type ORWebSocketClientMessage struct {
	Type               string            `json:"type"` // "response.create"
	Model              string            `json:"model,omitempty"`
	Input              interface{}       `json:"input,omitempty"` // string or []ORItemParam (same shape as OpenResponsesRequest.Input)
	Tools              []ORFunctionTool  `json:"tools,omitempty"`
	ToolChoice         interface{}       `json:"tool_choice,omitempty"`
	MaxOutputTokens    *int              `json:"max_output_tokens,omitempty"`
	Temperature        *float64          `json:"temperature,omitempty"`
	TopP               *float64          `json:"top_p,omitempty"`
	Truncation         string            `json:"truncation,omitempty"`
	Instructions       string            `json:"instructions,omitempty"`
	Reasoning          *ORReasoningParam `json:"reasoning,omitempty"`
	Metadata           map[string]string `json:"metadata,omitempty"`
	PreviousResponseID string            `json:"previous_response_id,omitempty"` // continuation of an earlier turn on this connection
	Store              *bool             `json:"store,omitempty"`                // false selects connection-local caching
	TextFormat         interface{}       `json:"text_format,omitempty"`
	ServiceTier        string            `json:"service_tier,omitempty"`
	AllowedTools       []string          `json:"allowed_tools,omitempty"`
	ParallelToolCalls  *bool             `json:"parallel_tool_calls,omitempty"`
	PresencePenalty    *float64          `json:"presence_penalty,omitempty"`
	FrequencyPenalty   *float64          `json:"frequency_penalty,omitempty"`
	TopLogprobs        *int              `json:"top_logprobs,omitempty"`
	MaxToolCalls       *int              `json:"max_tool_calls,omitempty"`
	Generate           *bool             `json:"generate,omitempty"` // If false, just warm up and return response_id
}

// ORWebSocketServerEvent represents a server event sent to the WebSocket
// client. One struct is reused for every event type; only the fields
// relevant to a given Type are populated (omitempty keeps the rest off the
// wire).
type ORWebSocketServerEvent struct {
	Type         string              `json:"type"` // response.created, response.progress, etc.
	ResponseID   string              `json:"response_id,omitempty"`   // ID of the response the event belongs to
	Response     *ORResponseResource `json:"response,omitempty"`      // full response resource (e.g. response.created / response.done)
	OutputIndex  *int                `json:"output_index,omitempty"`  // index of the output item the event refers to
	Output       []ORItemField       `json:"output,omitempty"`        // output items payload
	ItemID       string              `json:"item_id,omitempty"`       // ID of the output item
	Item         *ORItemField        `json:"item,omitempty"`          // single output item payload
	ContentIndex *int                `json:"content_index,omitempty"` // index within the item's content array
	Delta        *string             `json:"delta,omitempty"`         // incremental fragment (e.g. function_call_arguments.delta)
	Text         *string             `json:"text,omitempty"`          // completed text payload
	CallID       string              `json:"call_id,omitempty"`       // function/tool call identifier
	Arguments    *string             `json:"arguments,omitempty"`     // completed function-call arguments (JSON-encoded string — TODO confirm)
	Error        *ORError            `json:"error,omitempty"`         // error payload for error events
}

// ORWebSocketError represents a WebSocket error event sent to the client.
type ORWebSocketError struct {
	Type    string `json:"type"`            // always "error"
	Code    string `json:"code,omitempty"`  // previous_response_not_found, websocket_connection_limit_reached, etc.
	Message string `json:"message"`         // human-readable error description
	Param   string `json:"param,omitempty"` // offending request parameter, if applicable
}

// ConnectionLocalCacheEntry represents a cached response in connection-local
// storage (used for WebSocket-mode responses created with store=false, so
// they never reach the shared response store).
type ConnectionLocalCacheEntry struct {
	ResponseID string                    // ID of the cached response
	Response   *ORResponseResource       // the generated response resource
	Input      *ORWebSocketClientMessage // the client message that produced the response
	CachedAt   time.Time                 // when the entry was added to the cache
	ExpiresAt  *time.Time                // optional expiry; nil presumably means no per-entry expiry — TODO confirm
}
Loading