diff --git a/.github/workflows/benchmark.yaml b/.github/workflows/benchmark.yaml index 3abb985e..5555ddb2 100644 --- a/.github/workflows/benchmark.yaml +++ b/.github/workflows/benchmark.yaml @@ -35,7 +35,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Install Apache Bench diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 13403a61..798de9d2 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -19,7 +19,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Run golangci-lint @@ -38,7 +38,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Download dependencies @@ -66,7 +66,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Install govulncheck @@ -87,7 +87,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Run E2E tests @@ -105,7 +105,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Build binary @@ -321,7 +321,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Build diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 77a6db4e..83472f74 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -1,10 +1,13 @@ name: Docker on: + pull_request: + branches: [main] push: branches: [main] tags: - "v*" + workflow_dispatch: permissions: contents: read @@ -29,6 +32,7 @@ jobs: uses: docker/setup-buildx-action@v3 - name: Login to GitHub Container Registry + if: github.event_name != 'pull_request' uses: docker/login-action@v3 with: registry: ${{ env.REGISTRY }} @@ -63,7 +67,7 @@ jobs: with: context: . platforms: linux/amd64,linux/arm64 - push: true + push: ${{ github.event_name != 'pull_request' }} tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} build-args: | diff --git a/.gitignore b/.gitignore index d0faa240..21d2fc47 100644 --- a/.gitignore +++ b/.gitignore @@ -119,3 +119,5 @@ completions/ # Dashboard frontend dist (copied from mockd-desktop at release time) pkg/admin/dashboard/dist/assets/ pkg/admin/dashboard/dist/index.html +.mcp.json +.claude/worktrees/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ef9fa26..62a767cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **WebSocket connection management API** — `GET /websocket/connections`, `GET /websocket/connections/{id}`, `DELETE /websocket/connections/{id}`, `POST /websocket/connections/{id}/send`, `GET /websocket/stats` added to the Admin API for real-time visibility, control, and server-initiated messaging of active WebSocket connections +- **SSE connection management API** — `GET /sse/connections`, `GET /sse/connections/{id}`, `DELETE /sse/connections/{id}`, `GET /sse/stats` for managing active SSE connections with auto-disconnect on mock update +- **SSE response type improvements** — SSE connection list responses now include full stats alongside connections +- **MQTT connection management API** — `GET /mqtt-connections`, `GET /mqtt-connections/{id}`, `DELETE /mqtt-connections/{id}`, `GET /mqtt-connections/stats` for managing active MQTT client connections with connection time tracking +- **gRPC stream tracking and management** — `GET /grpc/connections`, `GET /grpc/connections/{id}`, `DELETE /grpc/connections/{id}`, `GET /grpc/stats` for managing active gRPC streaming RPCs with proper Canceled vs Unavailable status codes +- **CLI connection management commands** — `mockd connections list`, `mockd connections get`, `mockd connections close` for all protocols (WebSocket, SSE, MQTT, gRPC) +- **Connection management UI** — unified connections view in the web dashboard showing all active connections across protocols with real-time stats, search, and filtering +- **WebSocket auto-reconnect on mock update** — updating or deleting a WebSocket mock now automatically closes all active connections with close code 1012 (Service Restart) so clients reconnect and pick up the new configuration immediately - **Workspace-scoped stateful resources** — stateful resources, custom operations, and request logs are now isolated per workspace - **`--workspace` persistent CLI flag** — scope any CLI command to a specific workspace without switching context - **`?workspaceId=` API parameter** — all admin API endpoints now accept workspace filtering diff --git a/Dockerfile b/Dockerfile index 8a5522a1..c6f2364d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ # ============================================================================= # Stage 1: Builder # ============================================================================= -FROM golang:1.26.1-alpine AS builder +FROM golang:1.26.2-alpine AS builder # Build arguments for version info ARG VERSION=dev diff --git a/docs/src/content/docs/reference/admin-api.md b/docs/src/content/docs/reference/admin-api.md index 784eddde..699d6f97 100644 --- a/docs/src/content/docs/reference/admin-api.md +++ b/docs/src/content/docs/reference/admin-api.md @@ -251,6 +251,12 @@ The response includes ports for: ### Mock Management +:::note[WebSocket: active clients reconnect on mock changes] +When a WebSocket mock is updated or deleted, all clients currently connected to that endpoint receive a **close frame with code 1012 (Service Restart)**. Most WebSocket client libraries treat code 1012 as a signal to reconnect automatically. On reconnect, the client establishes a fresh connection that uses the new mock configuration. + +This applies to `PUT /mocks/{id}`, `DELETE /mocks/{id}`, `POST /mocks/{id}/toggle` (when disabling), bulk `DELETE /mocks` (delete all), and `POST /config` with `replace: true`. +::: + #### GET /mocks List all configured mocks. @@ -929,6 +935,34 @@ Export mocks as Insomnia v4 collection (JSON format, legacy). List active SSE connections. +**Response:** + +```json +{ + "connections": [ + { + "id": "sse-abc123", + "mockId": "mock-1", + "path": "/events", + "clientIp": "127.0.0.1", + "userAgent": "Mozilla/5.0", + "connectedAt": "2024-01-15T10:30:00Z", + "eventsSent": 42, + "bytesSent": 1024, + "status": "active" + } + ], + "stats": { + "totalConnections": 10, + "activeConnections": 1, + "totalEventsSent": 500, + "totalBytesSent": 51200, + "connectionErrors": 0, + "connectionsByMock": {"mock-1": 1} + } +} +``` + #### GET /sse/connections/{id} Get SSE connection details. @@ -937,6 +971,10 @@ Get SSE connection details. Close an SSE connection. +:::note[SSE auto-disconnect on mock update] +When an SSE mock is updated or deleted, all active SSE connections to that endpoint are automatically closed. Clients should reconnect to pick up the new configuration. +::: + #### GET /sse/stats Get SSE statistics. @@ -945,21 +983,46 @@ Get SSE statistics. ### WebSocket Management -#### GET /admin/ws/connections +#### GET /websocket/connections List active WebSocket connections. -#### GET /admin/ws/connections/{id} +**Response:** + +```json +{ + "connections": [ + { + "id": "ws-abc123", + "mockId": "mock-1", + "path": "/ws/chat", + "connectedAt": "2024-01-15T10:30:00Z", + "messagesSent": 15, + "messagesRecv": 10, + "status": "connected" + } + ], + "stats": { + "totalConnections": 50, + "activeConnections": 1, + "totalMessagesSent": 500, + "totalMessagesRecv": 300, + "connectionsByMock": {"mock-1": 1} + } +} +``` + +#### GET /websocket/connections/{id} Get connection details. -#### DELETE /admin/ws/connections/{id} +#### DELETE /websocket/connections/{id} Close a WebSocket connection. -#### POST /admin/ws/connections/{id}/send +#### POST /websocket/connections/{id}/send -Send a message to a specific connection. +Send a text or binary message to a specific active WebSocket connection. **Request:** @@ -970,15 +1033,34 @@ Send a message to a specific connection. } ``` -#### POST /admin/ws/broadcast +| Field | Type | Description | +|-------|------|-------------| +| `type` | string | Message type: `"text"` (default) or `"binary"` | +| `data` | string | Message payload. For `"text"`, a plain UTF-8 string. For `"binary"`, a **base64-encoded** string — the server decodes it before writing raw bytes to the WebSocket. | -Broadcast message to all connections. +**Response:** -#### GET /admin/ws/endpoints +```json +{ + "message": "Message sent", + "connection": "ws-abc123", + "type": "text" +} +``` + +Returns `404` if the connection is not found. + +#### GET /mocks/{id}/websocket/connections + +List active WebSocket connections for a specific mock. -List configured WebSocket endpoints. +**Response:** Same format as `GET /websocket/connections`, filtered to the given mock ID. -#### GET /admin/ws/stats +#### DELETE /mocks/{id}/websocket/connections + +Close all WebSocket connections for a specific mock. + +#### GET /websocket/stats Get WebSocket statistics. @@ -1024,6 +1106,100 @@ Stop a replay session. --- +### MQTT Connection Management + +#### GET /mqtt-connections + +List active MQTT client connections. + +**Response:** + +```json +{ + "connections": [ + { + "id": "client-abc123", + "brokerId": "mqtt-broker-1", + "connectedAt": "2024-01-15T10:30:00Z", + "subscriptions": ["sensors/#", "devices/+"], + "protocolVersion": 5, + "username": "device-1", + "remoteAddr": "192.168.1.10:54321", + "status": "connected" + } + ], + "stats": { + "connectedClients": 1, + "totalSubscriptions": 2, + "topicCount": 5, + "port": 1883, + "tlsEnabled": false, + "authEnabled": false, + "subscriptionsByClient": {"client-abc123": 2} + } +} +``` + +#### GET /mqtt-connections/{id} + +Get details of a specific MQTT client connection. + +#### DELETE /mqtt-connections/{id} + +Disconnect an MQTT client. + +#### GET /mqtt-connections/stats + +Get MQTT connection statistics. + +--- + +### gRPC Stream Management + +#### GET /grpc/connections + +List active gRPC streaming RPCs. + +**Response:** + +```json +{ + "streams": [ + { + "id": "grpc-stream-1", + "method": "/myapp.ChatService/StreamMessages", + "streamType": "bidi", + "clientAddr": "127.0.0.1:54321", + "connectedAt": "2024-01-15T10:30:00Z", + "messagesSent": 15, + "messagesRecv": 10 + } + ], + "stats": { + "activeStreams": 1, + "totalStreams": 50, + "totalRPCs": 200, + "totalMessagesSent": 1000, + "totalMessagesRecv": 800, + "streamsByMethod": {"/myapp.ChatService/StreamMessages": 1} + } +} +``` + +#### GET /grpc/connections/{id} + +Get details of a specific gRPC stream. + +#### DELETE /grpc/connections/{id} + +Cancel a gRPC stream. The client receives a `codes.Unavailable` status, signaling it should reconnect. + +#### GET /grpc/stats + +Get gRPC stream statistics. + +--- + ### gRPC Management #### GET /grpc diff --git a/go.mod b/go.mod index b20001d5..1a8563c8 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/getmockd/mockd -go 1.26.1 +go 1.26.2 require ( github.com/beevik/etree v1.6.0 diff --git a/pkg/admin/engineclient/client.go b/pkg/admin/engineclient/client.go index 7d3815ea..7695245b 100644 --- a/pkg/admin/engineclient/client.go +++ b/pkg/admin/engineclient/client.go @@ -1077,6 +1077,270 @@ func (c *Client) GetSSEStats(ctx context.Context) (*SSEStats, error) { return &stats, nil } +// ListWebSocketConnections returns all active WebSocket connections. +func (c *Client) ListWebSocketConnections(ctx context.Context) ([]*WebSocketConnection, error) { + resp, err := c.get(ctx, "/websocket/connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var result struct { + Connections []*WebSocketConnection `json:"connections"` + Count int `json:"count"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode WebSocket connections: %w", err) + } + return result.Connections, nil +} + +// GetWebSocketConnection returns a specific WebSocket connection by ID. +func (c *Client) GetWebSocketConnection(ctx context.Context, id string) (*WebSocketConnection, error) { + resp, err := c.get(ctx, "/websocket/connections/"+url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return nil, ErrNotFound + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var conn WebSocketConnection + if err := json.NewDecoder(resp.Body).Decode(&conn); err != nil { + return nil, fmt.Errorf("failed to decode WebSocket connection: %w", err) + } + return &conn, nil +} + +// CloseWebSocketConnection closes a WebSocket connection by ID. +func (c *Client) CloseWebSocketConnection(ctx context.Context, id string) error { + resp, err := c.delete(ctx, "/websocket/connections/"+url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return ErrNotFound + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// SendToWebSocketConnection sends a text or binary message to a specific connection. +// msgType must be "text" (default) or "binary". +// For binary messages, data must be a base64-encoded string; the engine decodes +// it before sending the raw bytes over the WebSocket connection. +func (c *Client) SendToWebSocketConnection(ctx context.Context, id string, msgType string, data string) error { + body := map[string]string{ + "type": msgType, + "data": data, + } + resp, err := c.post(ctx, "/websocket/connections/"+url.PathEscape(id)+"/send", body) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return ErrNotFound + } + if resp.StatusCode != http.StatusOK { + return c.parseError(resp) + } + return nil +} + +// GetWebSocketStats returns WebSocket statistics. +func (c *Client) GetWebSocketStats(ctx context.Context) (*WebSocketStats, error) { + resp, err := c.get(ctx, "/websocket/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var stats WebSocketStats + if err := json.NewDecoder(resp.Body).Decode(&stats); err != nil { + return nil, fmt.Errorf("failed to decode WebSocket stats: %w", err) + } + return &stats, nil +} + +// ListMQTTConnections returns all active MQTT client connections. +func (c *Client) ListMQTTConnections(ctx context.Context) ([]*MQTTConnection, error) { + resp, err := c.get(ctx, "/mqtt/connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var result struct { + Connections []*MQTTConnection `json:"connections"` + Count int `json:"count"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode MQTT connections: %w", err) + } + return result.Connections, nil +} + +// GetMQTTConnection returns a specific MQTT client connection. +func (c *Client) GetMQTTConnection(ctx context.Context, id string) (*MQTTConnection, error) { + resp, err := c.get(ctx, "/mqtt/connections/"+url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return nil, ErrNotFound + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var conn MQTTConnection + if err := json.NewDecoder(resp.Body).Decode(&conn); err != nil { + return nil, fmt.Errorf("failed to decode MQTT connection: %w", err) + } + return &conn, nil +} + +// CloseMQTTConnection disconnects a specific MQTT client. +func (c *Client) CloseMQTTConnection(ctx context.Context, id string) error { + resp, err := c.delete(ctx, "/mqtt/connections/"+url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return ErrNotFound + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// GetMQTTStats returns MQTT broker statistics. +func (c *Client) GetMQTTStats(ctx context.Context) (*MQTTStats, error) { + resp, err := c.get(ctx, "/mqtt/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var stats MQTTStats + if err := json.NewDecoder(resp.Body).Decode(&stats); err != nil { + return nil, fmt.Errorf("failed to decode MQTT stats: %w", err) + } + return &stats, nil +} + +// ListGRPCStreams returns all active gRPC streaming connections. +func (c *Client) ListGRPCStreams(ctx context.Context) ([]*GRPCStream, error) { + resp, err := c.get(ctx, "/grpc/connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var result struct { + Streams []*GRPCStream `json:"streams"` + Count int `json:"count"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode gRPC streams: %w", err) + } + return result.Streams, nil +} + +// GetGRPCStream returns a specific gRPC stream by ID. +func (c *Client) GetGRPCStream(ctx context.Context, id string) (*GRPCStream, error) { + resp, err := c.get(ctx, "/grpc/connections/"+url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return nil, ErrNotFound + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var stream GRPCStream + if err := json.NewDecoder(resp.Body).Decode(&stream); err != nil { + return nil, fmt.Errorf("failed to decode gRPC stream: %w", err) + } + return &stream, nil +} + +// CancelGRPCStream cancels (terminates) a specific gRPC stream. +func (c *Client) CancelGRPCStream(ctx context.Context, id string) error { + resp, err := c.delete(ctx, "/grpc/connections/"+url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return ErrNotFound + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// GetGRPCStats returns gRPC statistics. +func (c *Client) GetGRPCStats(ctx context.Context) (*GRPCStats, error) { + resp, err := c.get(ctx, "/grpc/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var stats GRPCStats + if err := json.NewDecoder(resp.Body).Decode(&stats); err != nil { + return nil, fmt.Errorf("failed to decode gRPC stats: %w", err) + } + return &stats, nil +} + // HTTP helpers func (c *Client) get(ctx context.Context, path string) (*http.Response, error) { diff --git a/pkg/admin/engineclient/client_test.go b/pkg/admin/engineclient/client_test.go index d01491fc..dbcefe1c 100644 --- a/pkg/admin/engineclient/client_test.go +++ b/pkg/admin/engineclient/client_test.go @@ -750,6 +750,145 @@ func TestHTTPMethods(t *testing.T) { } } +// --- WebSocket Tests --- + +func TestListWebSocketConnections_Success(t *testing.T) { + resp := struct { + Connections []*WebSocketConnection `json:"connections"` + Count int `json:"count"` + }{ + Connections: []*WebSocketConnection{{ID: "ws-1"}, {ID: "ws-2"}}, + Count: 2, + } + _, c := mockServer(t, jsonHandler(t, 200, resp)) + + conns, err := c.ListWebSocketConnections(context.Background()) + if err != nil { + t.Fatalf("ListWebSocketConnections() error = %v", err) + } + if len(conns) != 2 { + t.Errorf("ListWebSocketConnections() = %d, want 2", len(conns)) + } + if conns[0].ID != "ws-1" { + t.Errorf("ListWebSocketConnections()[0].ID = %q, want %q", conns[0].ID, "ws-1") + } +} + +func TestListWebSocketConnections_Error(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 503, nil)) + _, err := c.ListWebSocketConnections(context.Background()) + if err == nil { + t.Error("ListWebSocketConnections() error = nil, want error for 503") + } +} + +func TestGetWebSocketConnection_Success(t *testing.T) { + conn := WebSocketConnection{ID: "ws-1", Status: "connected"} + _, c := mockServer(t, jsonHandler(t, 200, conn)) + + result, err := c.GetWebSocketConnection(context.Background(), "ws-1") + if err != nil { + t.Fatalf("GetWebSocketConnection() error = %v", err) + } + if result.ID != "ws-1" { + t.Errorf("GetWebSocketConnection().ID = %q, want %q", result.ID, "ws-1") + } +} + +func TestGetWebSocketConnection_NotFound(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 404, nil)) + _, err := c.GetWebSocketConnection(context.Background(), "missing") + if !errors.Is(err, ErrNotFound) { + t.Errorf("GetWebSocketConnection() error = %v, want ErrNotFound", err) + } +} + +func TestCloseWebSocketConnection_Success(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 200, map[string]string{"message": "closed"})) + err := c.CloseWebSocketConnection(context.Background(), "ws-1") + if err != nil { + t.Errorf("CloseWebSocketConnection() error = %v, want nil", err) + } +} + +func TestCloseWebSocketConnection_NoContent(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 204, nil)) + err := c.CloseWebSocketConnection(context.Background(), "ws-1") + if err != nil { + t.Errorf("CloseWebSocketConnection() 204 error = %v, want nil", err) + } +} + +func TestCloseWebSocketConnection_NotFound(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 404, nil)) + err := c.CloseWebSocketConnection(context.Background(), "missing") + if !errors.Is(err, ErrNotFound) { + t.Errorf("CloseWebSocketConnection() error = %v, want ErrNotFound", err) + } +} + +func TestSendToWebSocketConnection_Success(t *testing.T) { + var capturedBody map[string]string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewDecoder(r.Body).Decode(&capturedBody) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"message": "sent"}) + })) + defer ts.Close() + c := New(ts.URL) + + err := c.SendToWebSocketConnection(context.Background(), "ws-1", "text", "hello") + if err != nil { + t.Fatalf("SendToWebSocketConnection() error = %v", err) + } + if capturedBody["type"] != "text" { + t.Errorf("request body type = %q, want %q", capturedBody["type"], "text") + } + if capturedBody["data"] != "hello" { + t.Errorf("request body data = %q, want %q", capturedBody["data"], "hello") + } +} + +func TestSendToWebSocketConnection_NotFound(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 404, nil)) + err := c.SendToWebSocketConnection(context.Background(), "missing", "text", "hello") + if !errors.Is(err, ErrNotFound) { + t.Errorf("SendToWebSocketConnection() error = %v, want ErrNotFound", err) + } +} + +func TestSendToWebSocketConnection_Error(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 500, ErrorResponse{Error: "engine_error", Message: "failed"})) + err := c.SendToWebSocketConnection(context.Background(), "ws-1", "text", "hello") + if err == nil { + t.Error("SendToWebSocketConnection() error = nil, want error for 500") + } +} + +func TestGetWebSocketStats_Success(t *testing.T) { + stats := WebSocketStats{ActiveConnections: 3, TotalConnections: 10} + _, c := mockServer(t, jsonHandler(t, 200, stats)) + + result, err := c.GetWebSocketStats(context.Background()) + if err != nil { + t.Fatalf("GetWebSocketStats() error = %v", err) + } + if result.ActiveConnections != 3 { + t.Errorf("GetWebSocketStats().ActiveConnections = %d, want 3", result.ActiveConnections) + } + if result.TotalConnections != 10 { + t.Errorf("GetWebSocketStats().TotalConnections = %d, want 10", result.TotalConnections) + } +} + +func TestGetWebSocketStats_Error(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 503, nil)) + _, err := c.GetWebSocketStats(context.Background()) + if err == nil { + t.Error("GetWebSocketStats() error = nil, want error for 503") + } +} + // containsStr checks if s contains substr. func containsStr(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr)) diff --git a/pkg/admin/engineclient/types.go b/pkg/admin/engineclient/types.go index 19367836..031df268 100644 --- a/pkg/admin/engineclient/types.go +++ b/pkg/admin/engineclient/types.go @@ -47,6 +47,12 @@ type ( ProtocolHandler = types.ProtocolHandler SSEConnection = types.SSEConnection SSEStats = types.SSEStats + WebSocketConnection = types.WebSocketConnection + WebSocketStats = types.WebSocketStats + MQTTConnection = types.MQTTConnection + MQTTStats = types.MQTTStats + GRPCStream = types.GRPCStream + GRPCStats = types.GRPCStats CustomOperationInfo = types.CustomOperationInfo CustomOperationDetail = types.CustomOperationDetail CustomOperationStep = types.CustomOperationStep diff --git a/pkg/admin/grpc_handlers.go b/pkg/admin/grpc_handlers.go new file mode 100644 index 00000000..1090e043 --- /dev/null +++ b/pkg/admin/grpc_handlers.go @@ -0,0 +1,184 @@ +package admin + +import ( + "context" + "errors" + "log/slog" + "net/http" + + "github.com/getmockd/mockd/pkg/admin/engineclient" +) + +// GRPCStreamListResponse represents a list of gRPC streams with stats. +type GRPCStreamListResponse struct { + Streams []*engineclient.GRPCStream `json:"streams"` + Stats engineclient.GRPCStats `json:"stats"` +} + +// handleListGRPCStreams handles GET /grpc/connections. +// +//nolint:dupl // intentionally parallel structure with other protocol list handlers +func (a *API) handleListGRPCStreams(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, GRPCStreamListResponse{ + Streams: []*engineclient.GRPCStream{}, + Stats: engineclient.GRPCStats{StreamsByMethod: make(map[string]int)}, + }) + return + } + + stats, err := engine.GetGRPCStats(ctx) + if err != nil { + a.logger().Error("failed to get gRPC stats", "error", err) + status, code, msg := mapGRPCEngineError(err, a.logger(), "get gRPC stats") + writeError(w, status, code, msg) + return + } + + streams, err := engine.ListGRPCStreams(ctx) + if err != nil { + a.logger().Error("failed to list gRPC streams", "error", err) + status, code, msg := mapGRPCEngineError(err, a.logger(), "list gRPC streams") + writeError(w, status, code, msg) + return + } + + if streams == nil { + streams = []*engineclient.GRPCStream{} + } + + byMethod := stats.StreamsByMethod + if byMethod == nil { + byMethod = make(map[string]int) + } + + writeJSON(w, http.StatusOK, GRPCStreamListResponse{ + Streams: streams, + Stats: engineclient.GRPCStats{ + ActiveStreams: stats.ActiveStreams, + TotalStreams: stats.TotalStreams, + TotalRPCs: stats.TotalRPCs, + TotalMessagesSent: stats.TotalMessagesSent, + TotalMessagesRecv: stats.TotalMessagesRecv, + StreamsByMethod: byMethod, + }, + }) +} + +// handleGetGRPCStream handles GET /grpc/connections/{id}. +func (a *API) handleGetGRPCStream(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Stream ID is required") + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Stream not found") + return + } + + stream, err := engine.GetGRPCStream(ctx, id) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Stream not found") + return + } + a.logger().Error("failed to get gRPC stream", "error", err, "streamID", id) + status, code, msg := mapGRPCEngineError(err, a.logger(), "get gRPC stream") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, stream) +} + +// handleCancelGRPCStream handles DELETE /grpc/connections/{id}. +// +//nolint:dupl // intentionally parallel structure with other protocol close handlers +func (a *API) handleCancelGRPCStream(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Stream ID is required") + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Stream not found") + return + } + + err := engine.CancelGRPCStream(ctx, id) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Stream not found") + return + } + a.logger().Error("failed to cancel gRPC stream", "error", err, "streamID", id) + status, code, msg := mapGRPCEngineError(err, a.logger(), "cancel gRPC stream") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Stream cancelled", + "stream": id, + }) +} + +// handleGetGRPCStats handles GET /grpc/stats. +func (a *API) handleGetGRPCStats(w http.ResponseWriter, r *http.Request) { + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, engineclient.GRPCStats{StreamsByMethod: make(map[string]int)}) + return + } + a.handleGetStats(w, r, newGRPCStatsProvider(engine)) +} + +func mapGRPCEngineError(err error, log *slog.Logger, operation string) (int, string, string) { + return http.StatusServiceUnavailable, "engine_error", sanitizeEngineError(err, log, operation) +} + +// grpcStatsProvider implements statsProvider for gRPC statistics. +type grpcStatsProvider struct { + engine *engineclient.Client +} + +func newGRPCStatsProvider(engine *engineclient.Client) *grpcStatsProvider { + return &grpcStatsProvider{engine: engine} +} + +func (p *grpcStatsProvider) GetStats(ctx context.Context) (interface{}, error) { + stats, err := p.engine.GetGRPCStats(ctx) + if err != nil { + return nil, err + } + + byMethod := stats.StreamsByMethod + if byMethod == nil { + byMethod = make(map[string]int) + } + + return engineclient.GRPCStats{ + ActiveStreams: stats.ActiveStreams, + TotalStreams: stats.TotalStreams, + TotalRPCs: stats.TotalRPCs, + TotalMessagesSent: stats.TotalMessagesSent, + TotalMessagesRecv: stats.TotalMessagesRecv, + StreamsByMethod: byMethod, + }, nil +} + +func (p *grpcStatsProvider) MapError(err error, log *slog.Logger, operation string) (int, string, string) { + return mapGRPCEngineError(err, log, operation) +} + +func (p *grpcStatsProvider) ProtocolName() string { return "gRPC" } diff --git a/pkg/admin/grpc_handlers_test.go b/pkg/admin/grpc_handlers_test.go new file mode 100644 index 00000000..e5941bc0 --- /dev/null +++ b/pkg/admin/grpc_handlers_test.go @@ -0,0 +1,218 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/getmockd/mockd/pkg/admin/engineclient" +) + +// ============================================================================ +// handleListGRPCStreams +// ============================================================================ + +func TestHandleListGRPCStreams_NoEngine_ReturnsEmptyList(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections", nil) + rec := httptest.NewRecorder() + + api.handleListGRPCStreams(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp GRPCStreamListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Empty(t, resp.Streams) + assert.NotNil(t, resp.Stats.StreamsByMethod) +} + +func TestHandleListGRPCStreams_WithEngine_ReturnsStreamsAndStats(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == http.MethodGet && r.URL.Path == "/grpc/stats": + _, _ = w.Write([]byte(`{"activeStreams":1,"totalStreams":5,"totalRPCs":100,"totalMessagesSent":50,"totalMessagesRecv":30,"streamsByMethod":{"/pkg.Svc/Stream":1}}`)) + case r.Method == http.MethodGet && r.URL.Path == "/grpc/connections": + _, _ = w.Write([]byte(`{"streams":[{"id":"grpc-stream-1","method":"/pkg.Svc/Stream","streamType":"server_stream","clientAddr":"127.0.0.1:5000","connectedAt":"2026-04-05T00:00:00Z","messagesSent":10,"messagesRecv":1}],"count":1}`)) + default: + http.NotFound(w, r) + } + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections", nil) + rec := httptest.NewRecorder() + + api.handleListGRPCStreams(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp GRPCStreamListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Len(t, resp.Streams, 1) + assert.Equal(t, "grpc-stream-1", resp.Streams[0].ID) + assert.Equal(t, "/pkg.Svc/Stream", resp.Streams[0].Method) + assert.Equal(t, 1, resp.Stats.ActiveStreams) + assert.Equal(t, int64(100), resp.Stats.TotalRPCs) +} + +// ============================================================================ +// handleGetGRPCStream +// ============================================================================ + +func TestHandleGetGRPCStream_MissingID_ReturnsBadRequest(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections/", nil) + rec := httptest.NewRecorder() + + api.handleGetGRPCStream(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleGetGRPCStream_NoEngine_ReturnsNotFound(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections/stream-1", nil) + req.SetPathValue("id", "stream-1") + rec := httptest.NewRecorder() + + api.handleGetGRPCStream(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleGetGRPCStream_Found_ReturnsStream(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"stream-1","method":"/pkg.Svc/Chat","streamType":"bidi","connectedAt":"2026-04-05T00:00:00Z","messagesSent":5,"messagesRecv":3}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections/stream-1", nil) + req.SetPathValue("id", "stream-1") + rec := httptest.NewRecorder() + + api.handleGetGRPCStream(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stream engineclient.GRPCStream + err := json.Unmarshal(rec.Body.Bytes(), &stream) + require.NoError(t, err) + assert.Equal(t, "stream-1", stream.ID) + assert.Equal(t, "bidi", stream.StreamType) +} + +// ============================================================================ +// handleCancelGRPCStream +// ============================================================================ + +func TestHandleCancelGRPCStream_MissingID_ReturnsBadRequest(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/grpc/connections/", nil) + rec := httptest.NewRecorder() + + api.handleCancelGRPCStream(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleCancelGRPCStream_NoEngine_ReturnsNotFound(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/grpc/connections/stream-1", nil) + req.SetPathValue("id", "stream-1") + rec := httptest.NewRecorder() + + api.handleCancelGRPCStream(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleCancelGRPCStream_Success(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"message":"gRPC stream cancelled","id":"stream-1"}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/grpc/connections/stream-1", nil) + req.SetPathValue("id", "stream-1") + rec := httptest.NewRecorder() + + api.handleCancelGRPCStream(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// ============================================================================ +// handleGetGRPCStats +// ============================================================================ + +func TestHandleGetGRPCStats_NoEngine_ReturnsEmptyStats(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/stats", nil) + rec := httptest.NewRecorder() + + api.handleGetGRPCStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats engineclient.GRPCStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.NotNil(t, stats.StreamsByMethod) +} + +func TestHandleGetGRPCStats_WithEngine_ReturnsStats(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"activeStreams":3,"totalStreams":10,"totalRPCs":500,"totalMessagesSent":200,"totalMessagesRecv":150,"streamsByMethod":{"/pkg.Svc/A":2,"/pkg.Svc/B":1}}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/stats", nil) + rec := httptest.NewRecorder() + + api.handleGetGRPCStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats engineclient.GRPCStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.Equal(t, 3, stats.ActiveStreams) + assert.Equal(t, int64(500), stats.TotalRPCs) + assert.Len(t, stats.StreamsByMethod, 2) +} diff --git a/pkg/admin/mqtt_handlers.go b/pkg/admin/mqtt_handlers.go new file mode 100644 index 00000000..8bbd2b36 --- /dev/null +++ b/pkg/admin/mqtt_handlers.go @@ -0,0 +1,146 @@ +package admin + +import ( + "errors" + "log/slog" + "net/http" + + "github.com/getmockd/mockd/pkg/admin/engineclient" +) + +// MQTTConnectionListResponse represents a list of MQTT connections with stats. +type MQTTConnectionListResponse struct { + Connections []*engineclient.MQTTConnection `json:"connections"` + Stats engineclient.MQTTStats `json:"stats"` +} + +// handleListMQTTConnections handles GET /mqtt/connections. +func (a *API) handleListMQTTConnections(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, MQTTConnectionListResponse{ + Connections: []*engineclient.MQTTConnection{}, + Stats: engineclient.MQTTStats{SubscriptionsByClient: make(map[string]int)}, + }) + return + } + + stats, err := engine.GetMQTTStats(ctx) + if err != nil { + a.logger().Error("failed to get MQTT stats", "error", err) + status, code, msg := mapMQTTEngineError(err, a.logger(), "get MQTT stats") + writeError(w, status, code, msg) + return + } + + connections, err := engine.ListMQTTConnections(ctx) + if err != nil { + a.logger().Error("failed to list MQTT connections", "error", err) + status, code, msg := mapMQTTEngineError(err, a.logger(), "list MQTT connections") + writeError(w, status, code, msg) + return + } + + if connections == nil { + connections = []*engineclient.MQTTConnection{} + } + + subsByClient := stats.SubscriptionsByClient + if subsByClient == nil { + subsByClient = make(map[string]int) + } + + writeJSON(w, http.StatusOK, MQTTConnectionListResponse{ + Connections: connections, + Stats: engineclient.MQTTStats{ + ConnectedClients: stats.ConnectedClients, + TotalSubscriptions: stats.TotalSubscriptions, + TopicCount: stats.TopicCount, + Port: stats.Port, + TLSEnabled: stats.TLSEnabled, + AuthEnabled: stats.AuthEnabled, + SubscriptionsByClient: subsByClient, + }, + }) +} + +// handleGetMQTTConnection handles GET /mqtt/connections/{id}. +func (a *API) handleGetMQTTConnection(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Client ID is required") + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + + conn, err := engine.GetMQTTConnection(ctx, id) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + a.logger().Error("failed to get MQTT connection", "error", err, "clientID", id) + status, code, msg := mapMQTTEngineError(err, a.logger(), "get MQTT connection") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, conn) +} + +// handleCloseMQTTConnection handles DELETE /mqtt/connections/{id}. +// +//nolint:dupl // intentionally parallel structure with other protocol close handlers +func (a *API) handleCloseMQTTConnection(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Client ID is required") + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + + err := engine.CloseMQTTConnection(ctx, id) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + a.logger().Error("failed to close MQTT connection", "error", err, "clientID", id) + status, code, msg := mapMQTTEngineError(err, a.logger(), "close MQTT connection") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Connection closed", + "connection": id, + }) +} + +// handleGetMQTTStats handles GET /mqtt/stats. +func (a *API) handleGetMQTTStats(w http.ResponseWriter, r *http.Request) { + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, engineclient.MQTTStats{SubscriptionsByClient: make(map[string]int)}) + return + } + a.handleGetStats(w, r, newMQTTStatsProvider(engine)) +} + +func mapMQTTEngineError(err error, log *slog.Logger, operation string) (int, string, string) { + return http.StatusServiceUnavailable, "engine_error", sanitizeEngineError(err, log, operation) +} diff --git a/pkg/admin/mqtt_handlers_test.go b/pkg/admin/mqtt_handlers_test.go new file mode 100644 index 00000000..792d8403 --- /dev/null +++ b/pkg/admin/mqtt_handlers_test.go @@ -0,0 +1,251 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/getmockd/mockd/pkg/admin/engineclient" +) + +// ============================================================================ +// handleListMQTTConnections +// ============================================================================ + +func TestHandleListMQTTConnections_NoEngine_ReturnsEmptyList(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections", nil) + rec := httptest.NewRecorder() + + api.handleListMQTTConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp MQTTConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Empty(t, resp.Connections) + assert.NotNil(t, resp.Stats.SubscriptionsByClient) +} + +func TestHandleListMQTTConnections_WithEngine_ReturnsConnectionsAndStats(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == http.MethodGet && r.URL.Path == "/mqtt/stats": + _, _ = w.Write([]byte(`{"connectedClients":2,"totalSubscriptions":3,"topicCount":5,"port":1883,"tlsEnabled":false,"authEnabled":true,"subscriptionsByClient":{"client-1":2,"client-2":1}}`)) + case r.Method == http.MethodGet && r.URL.Path == "/mqtt/connections": + _, _ = w.Write([]byte(`{"connections":[{"id":"client-1","brokerId":"broker-1","connectedAt":"2026-04-05T00:00:00Z","subscriptions":["sensors/#"],"protocolVersion":4,"status":"connected"}],"count":1}`)) + default: + http.NotFound(w, r) + } + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections", nil) + rec := httptest.NewRecorder() + + api.handleListMQTTConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp MQTTConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Len(t, resp.Connections, 1) + assert.Equal(t, "client-1", resp.Connections[0].ID) + assert.Equal(t, "broker-1", resp.Connections[0].BrokerID) + assert.Equal(t, []string{"sensors/#"}, resp.Connections[0].Subscriptions) + assert.Equal(t, 2, resp.Stats.ConnectedClients) + assert.Equal(t, 3, resp.Stats.TotalSubscriptions) + assert.True(t, resp.Stats.AuthEnabled) +} + +func TestHandleListMQTTConnections_EngineUnavailable_Returns503(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections", nil) + rec := httptest.NewRecorder() + + api.handleListMQTTConnections(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +// ============================================================================ +// handleGetMQTTConnection +// ============================================================================ + +func TestHandleGetMQTTConnection_MissingID_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections/", nil) + rec := httptest.NewRecorder() + + api.handleGetMQTTConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleGetMQTTConnection_NoEngine_Returns404(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections/client-1", nil) + req.SetPathValue("id", "client-1") + rec := httptest.NewRecorder() + + api.handleGetMQTTConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleGetMQTTConnection_Found_ReturnsConnection(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"client-1","brokerId":"broker-1","connectedAt":"2026-04-05T00:00:00Z","subscriptions":["topic/a"],"protocolVersion":5,"status":"connected"}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections/client-1", nil) + req.SetPathValue("id", "client-1") + rec := httptest.NewRecorder() + + api.handleGetMQTTConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var conn engineclient.MQTTConnection + err := json.Unmarshal(rec.Body.Bytes(), &conn) + require.NoError(t, err) + assert.Equal(t, "client-1", conn.ID) + assert.Equal(t, byte(5), conn.ProtocolVersion) +} + +func TestHandleGetMQTTConnection_NotFound_Returns404(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":"not_found","message":"MQTT connection not found"}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections/nonexistent", nil) + req.SetPathValue("id", "nonexistent") + rec := httptest.NewRecorder() + + api.handleGetMQTTConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +// ============================================================================ +// handleCloseMQTTConnection +// ============================================================================ + +func TestHandleCloseMQTTConnection_MissingID_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/mqtt/connections/", nil) + rec := httptest.NewRecorder() + + api.handleCloseMQTTConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleCloseMQTTConnection_NoEngine_Returns404(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/mqtt/connections/client-1", nil) + req.SetPathValue("id", "client-1") + rec := httptest.NewRecorder() + + api.handleCloseMQTTConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleCloseMQTTConnection_Success(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"message":"MQTT connection closed","id":"client-1"}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/mqtt/connections/client-1", nil) + req.SetPathValue("id", "client-1") + rec := httptest.NewRecorder() + + api.handleCloseMQTTConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// ============================================================================ +// handleGetMQTTStats +// ============================================================================ + +func TestHandleGetMQTTStats_NoEngine_ReturnsEmptyStats(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/stats", nil) + rec := httptest.NewRecorder() + + api.handleGetMQTTStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats engineclient.MQTTStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.NotNil(t, stats.SubscriptionsByClient) +} + +func TestHandleGetMQTTStats_WithEngine_ReturnsStats(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"connectedClients":3,"totalSubscriptions":5,"topicCount":2,"port":1883,"subscriptionsByClient":{"c1":2,"c2":3}}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/stats", nil) + rec := httptest.NewRecorder() + + api.handleGetMQTTStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats engineclient.MQTTStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.Equal(t, 3, stats.ConnectedClients) + assert.Equal(t, 5, stats.TotalSubscriptions) + assert.Equal(t, 2, stats.SubscriptionsByClient["c1"]) +} diff --git a/pkg/admin/routes.go b/pkg/admin/routes.go index 0cff1521..eb84396a 100644 --- a/pkg/admin/routes.go +++ b/pkg/admin/routes.go @@ -123,9 +123,34 @@ func (a *API) registerRoutes(mux *http.ServeMux) { mux.HandleFunc("DELETE /sse/connections/{id}", a.handleCloseSSEConnection) mux.HandleFunc("GET /sse/stats", a.handleGetSSEStats) + // WebSocket connection management + mux.HandleFunc("GET /websocket/connections", a.handleListWebSocketConnections) + mux.HandleFunc("GET /websocket/connections/{id}", a.handleGetWebSocketConnection) + mux.HandleFunc("DELETE /websocket/connections/{id}", a.handleCloseWebSocketConnection) + mux.HandleFunc("POST /websocket/connections/{id}/send", a.handleSendToWebSocketConnection) + mux.HandleFunc("GET /websocket/stats", a.handleGetWebSocketStats) + + // MQTT connection management + // Routes use /mqtt-connections/ prefix to avoid ambiguity with existing + // /mqtt/{id}/status recording routes. + mux.HandleFunc("GET /mqtt-connections", a.handleListMQTTConnections) + mux.HandleFunc("GET /mqtt-connections/{id}", a.handleGetMQTTConnection) + mux.HandleFunc("DELETE /mqtt-connections/{id}", a.handleCloseMQTTConnection) + mux.HandleFunc("GET /mqtt-connections/stats", a.handleGetMQTTStats) + + // gRPC stream management + mux.HandleFunc("GET /grpc/connections", a.handleListGRPCStreams) + mux.HandleFunc("GET /grpc/connections/{id}", a.handleGetGRPCStream) + mux.HandleFunc("DELETE /grpc/connections/{id}", a.handleCancelGRPCStream) + mux.HandleFunc("GET /grpc/stats", a.handleGetGRPCStats) + // Mock-specific SSE endpoints mux.HandleFunc("GET /mocks/{id}/sse/connections", a.requireEngine(a.handleListMockSSEConnections)) mux.HandleFunc("DELETE /mocks/{id}/sse/connections", a.requireEngine(a.handleCloseMockSSEConnections)) + + // Mock-specific WebSocket endpoints + mux.HandleFunc("GET /mocks/{id}/websocket/connections", a.requireEngine(a.handleListMockWebSocketConnections)) + mux.HandleFunc("DELETE /mocks/{id}/websocket/connections", a.requireEngine(a.handleCloseMockWebSocketConnections)) mux.HandleFunc("GET /mocks/{id}/sse/buffer", a.handleGetMockSSEBuffer) mux.HandleFunc("DELETE /mocks/{id}/sse/buffer", a.handleClearMockSSEBuffer) diff --git a/pkg/admin/sse_handlers.go b/pkg/admin/sse_handlers.go index 0e741a51..e6b2705e 100644 --- a/pkg/admin/sse_handlers.go +++ b/pkg/admin/sse_handlers.go @@ -6,25 +6,26 @@ import ( "net/http" "github.com/getmockd/mockd/pkg/admin/engineclient" - "github.com/getmockd/mockd/pkg/sse" "github.com/getmockd/mockd/pkg/store" ) // SSEConnectionListResponse represents a list of SSE connections. type SSEConnectionListResponse struct { - Connections []sse.SSEStreamInfo `json:"connections"` - Stats sse.ConnectionStats `json:"stats"` + Connections []*engineclient.SSEConnection `json:"connections"` + Stats engineclient.SSEStats `json:"stats"` } // handleListSSEConnections handles GET /sse/connections. +// +//nolint:dupl // intentionally parallel structure with other protocol list handlers func (a *API) handleListSSEConnections(w http.ResponseWriter, r *http.Request) { ctx := r.Context() engine := a.localEngine.Load() if engine == nil { writeJSON(w, http.StatusOK, SSEConnectionListResponse{ - Connections: []sse.SSEStreamInfo{}, - Stats: sse.ConnectionStats{ConnectionsByMock: make(map[string]int)}, + Connections: []*engineclient.SSEConnection{}, + Stats: engineclient.SSEStats{ConnectionsByMock: make(map[string]int)}, }) return } @@ -45,14 +46,8 @@ func (a *API) handleListSSEConnections(w http.ResponseWriter, r *http.Request) { return } - // Convert engine client connections to SSEStreamInfo - info := make([]sse.SSEStreamInfo, 0, len(connections)) - for _, conn := range connections { - info = append(info, sse.SSEStreamInfo{ - ID: conn.ID, - MockID: conn.MockID, - ClientIP: conn.ClientIP, - }) + if connections == nil { + connections = []*engineclient.SSEConnection{} } connsByMock := stats.ConnectionsByMock @@ -61,12 +56,13 @@ func (a *API) handleListSSEConnections(w http.ResponseWriter, r *http.Request) { } writeJSON(w, http.StatusOK, SSEConnectionListResponse{ - Connections: info, - Stats: sse.ConnectionStats{ - ActiveConnections: stats.ActiveConnections, + Connections: connections, + Stats: engineclient.SSEStats{ TotalConnections: stats.TotalConnections, + ActiveConnections: stats.ActiveConnections, TotalEventsSent: stats.TotalEventsSent, TotalBytesSent: stats.TotalBytesSent, + ConnectionErrors: stats.ConnectionErrors, ConnectionsByMock: connsByMock, }, }) @@ -99,12 +95,7 @@ func (a *API) handleGetSSEConnection(w http.ResponseWriter, r *http.Request) { return } - info := sse.SSEStreamInfo{ - ID: conn.ID, - MockID: conn.MockID, - } - - writeJSON(w, http.StatusOK, info) + writeJSON(w, http.StatusOK, conn) } // handleCloseSSEConnection handles DELETE /sse/connections/{id}. @@ -149,38 +140,17 @@ func (a *API) handleCloseSSEConnection(w http.ResponseWriter, r *http.Request) { // handleGetSSEStats handles GET /sse/stats. func (a *API) handleGetSSEStats(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - engine := a.localEngine.Load() if engine == nil { - writeJSON(w, http.StatusOK, sse.ConnectionStats{ - ConnectionsByMock: make(map[string]int), - }) + writeJSON(w, http.StatusOK, engineclient.SSEStats{ConnectionsByMock: make(map[string]int)}) return } - - stats, err := engine.GetSSEStats(ctx) - if err != nil { - a.logger().Error("failed to get SSE stats", "error", err) - status, code, msg := mapSSEEngineError(err, a.logger(), "get SSE stats") - writeError(w, status, code, msg) - return - } - - connsByMock := stats.ConnectionsByMock - if connsByMock == nil { - connsByMock = make(map[string]int) - } - writeJSON(w, http.StatusOK, sse.ConnectionStats{ - ActiveConnections: stats.ActiveConnections, - TotalConnections: stats.TotalConnections, - TotalEventsSent: stats.TotalEventsSent, - TotalBytesSent: stats.TotalBytesSent, - ConnectionsByMock: connsByMock, - }) + a.handleGetStats(w, r, newSSEStatsProvider(engine)) } // handleListMockSSEConnections handles GET /mocks/{id}/sse/connections. +// +//nolint:dupl // intentionally parallel structure with other protocol mock-scoped handlers func (a *API) handleListMockSSEConnections(w http.ResponseWriter, r *http.Request, engine *engineclient.Client) { ctx := r.Context() mockID := r.PathValue("id") @@ -208,24 +178,27 @@ func (a *API) handleListMockSSEConnections(w http.ResponseWriter, r *http.Reques return } - info := make([]sse.SSEStreamInfo, 0) + var filtered []*engineclient.SSEConnection for _, conn := range connections { if conn.MockID == mockID { - info = append(info, sse.SSEStreamInfo{ - ID: conn.ID, - MockID: conn.MockID, - }) + filtered = append(filtered, conn) } } + if filtered == nil { + filtered = []*engineclient.SSEConnection{} + } + writeJSON(w, http.StatusOK, map[string]interface{}{ - "connections": info, - "count": len(info), + "connections": filtered, + "count": len(filtered), "mockId": mockID, }) } // handleCloseMockSSEConnections handles DELETE /mocks/{id}/sse/connections. +// +//nolint:dupl // intentionally parallel structure with other protocol mock-scoped handlers func (a *API) handleCloseMockSSEConnections(w http.ResponseWriter, r *http.Request, engine *engineclient.Client) { ctx := r.Context() mockID := r.PathValue("id") diff --git a/pkg/admin/stat_helper.go b/pkg/admin/stat_helper.go new file mode 100644 index 00000000..ed732ecf --- /dev/null +++ b/pkg/admin/stat_helper.go @@ -0,0 +1,156 @@ +package admin + +import ( + "context" + "log/slog" + "net/http" + + "github.com/getmockd/mockd/pkg/admin/engineclient" + "github.com/getmockd/mockd/pkg/sse" +) + +// statsProvider defines the interface for retrieving and formatting statistics. +type statsProvider interface { + // GetStats retrieves statistics from the engine. + GetStats(ctx context.Context) (interface{}, error) + // MapError converts an error to HTTP status, code, and message. + MapError(err error, log *slog.Logger, operation string) (int, string, string) + // ProtocolName returns the protocol name for error logging. + ProtocolName() string +} + +// handleGetStats is a generic handler for retrieving statistics. +// The caller must guard against a nil engine and handle the empty-stats +// response before constructing the provider and invoking this function. +func (a *API) handleGetStats(w http.ResponseWriter, r *http.Request, provider statsProvider) { + ctx := r.Context() + + stats, err := provider.GetStats(ctx) + if err != nil { + a.logger().Error("failed to get "+provider.ProtocolName()+" stats", "error", err) + status, code, msg := provider.MapError(err, a.logger(), "get stats") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, stats) +} + +// sseStatsProvider implements statsProvider for SSE statistics. +type sseStatsProvider struct { + engine *engineclient.Client +} + +// newSSEStatsProvider creates a new SSE statistics provider. +func newSSEStatsProvider(engine *engineclient.Client) *sseStatsProvider { + return &sseStatsProvider{engine: engine} +} + +// GetStats retrieves SSE statistics from the engine. +func (p *sseStatsProvider) GetStats(ctx context.Context) (interface{}, error) { + stats, err := p.engine.GetSSEStats(ctx) + if err != nil { + return nil, err + } + + connsByMock := stats.ConnectionsByMock + if connsByMock == nil { + connsByMock = make(map[string]int) + } + + return sse.ConnectionStats{ + ActiveConnections: stats.ActiveConnections, + TotalConnections: stats.TotalConnections, + TotalEventsSent: stats.TotalEventsSent, + TotalBytesSent: stats.TotalBytesSent, + ConnectionsByMock: connsByMock, + }, nil +} + +// MapError converts an SSE engine error to HTTP status, code, and message. +func (p *sseStatsProvider) MapError(err error, log *slog.Logger, operation string) (int, string, string) { + return mapSSEEngineError(err, log, operation) +} + +// ProtocolName returns the protocol name for error logging. +func (p *sseStatsProvider) ProtocolName() string { return "SSE" } + +// wsStatsProvider implements statsProvider for WebSocket statistics. +type wsStatsProvider struct { + engine *engineclient.Client +} + +// newWSStatsProvider creates a new WebSocket statistics provider. +func newWSStatsProvider(engine *engineclient.Client) *wsStatsProvider { + return &wsStatsProvider{engine: engine} +} + +// GetStats retrieves WebSocket statistics from the engine. +func (p *wsStatsProvider) GetStats(ctx context.Context) (interface{}, error) { + stats, err := p.engine.GetWebSocketStats(ctx) + if err != nil { + return nil, err + } + + connsByMock := stats.ConnectionsByMock + if connsByMock == nil { + connsByMock = make(map[string]int) + } + + return engineclient.WebSocketStats{ + TotalConnections: stats.TotalConnections, + ActiveConnections: stats.ActiveConnections, + TotalMessagesSent: stats.TotalMessagesSent, + TotalMessagesRecv: stats.TotalMessagesRecv, + ConnectionsByMock: connsByMock, + }, nil +} + +// MapError converts a WebSocket engine error to HTTP status, code, and message. +func (p *wsStatsProvider) MapError(err error, log *slog.Logger, operation string) (int, string, string) { + return mapWebSocketEngineError(err, log, operation) +} + +// ProtocolName returns the protocol name for error logging. +func (p *wsStatsProvider) ProtocolName() string { return "WebSocket" } + +// mqttStatsProvider implements statsProvider for MQTT statistics. +type mqttStatsProvider struct { + engine *engineclient.Client +} + +// newMQTTStatsProvider creates a new MQTT statistics provider. +func newMQTTStatsProvider(engine *engineclient.Client) *mqttStatsProvider { + return &mqttStatsProvider{engine: engine} +} + +// GetStats retrieves MQTT statistics from the engine. +func (p *mqttStatsProvider) GetStats(ctx context.Context) (interface{}, error) { + stats, err := p.engine.GetMQTTStats(ctx) + if err != nil { + return nil, err + } + + subsByClient := stats.SubscriptionsByClient + if subsByClient == nil { + subsByClient = make(map[string]int) + } + + return engineclient.MQTTStats{ + ConnectedClients: stats.ConnectedClients, + TotalSubscriptions: stats.TotalSubscriptions, + TopicCount: stats.TopicCount, + Port: stats.Port, + TLSEnabled: stats.TLSEnabled, + AuthEnabled: stats.AuthEnabled, + SubscriptionsByClient: subsByClient, + }, nil +} + +// MapError converts an MQTT engine error to HTTP status, code, and message. +func (p *mqttStatsProvider) MapError(err error, log *slog.Logger, operation string) (int, string, string) { + return mapMQTTEngineError(err, log, operation) +} + +// ProtocolName returns the protocol name for error logging. +func (p *mqttStatsProvider) ProtocolName() string { return "MQTT" } diff --git a/pkg/admin/types.go b/pkg/admin/types.go index e0e192f6..e0877671 100644 --- a/pkg/admin/types.go +++ b/pkg/admin/types.go @@ -42,3 +42,14 @@ type MockInvocationListResponse struct { Count int `json:"count"` Total int `json:"total"` } + +// WebSocketSendRequest represents a request to send a message to a WebSocket connection. +type WebSocketSendRequest struct { + // Type is the message type: "text" (default) or "binary". + Type string `json:"type"` + // Data is the message payload. + // For type "text", Data is a plain UTF-8 string. + // For type "binary", Data must be a base64-encoded string; the server decodes it + // before sending the raw bytes over the WebSocket connection. + Data string `json:"data"` +} diff --git a/pkg/admin/websocket_handlers.go b/pkg/admin/websocket_handlers.go new file mode 100644 index 00000000..339c5167 --- /dev/null +++ b/pkg/admin/websocket_handlers.go @@ -0,0 +1,290 @@ +package admin + +import ( + "errors" + "log/slog" + "net/http" + + "github.com/getmockd/mockd/pkg/admin/engineclient" + "github.com/getmockd/mockd/pkg/store" +) + +// WebSocketConnectionListResponse represents a list of WebSocket connections with stats. +type WebSocketConnectionListResponse struct { + Connections []*engineclient.WebSocketConnection `json:"connections"` + Stats engineclient.WebSocketStats `json:"stats"` +} + +// handleListWebSocketConnections handles GET /websocket/connections. +func (a *API) handleListWebSocketConnections(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, WebSocketConnectionListResponse{ + Connections: []*engineclient.WebSocketConnection{}, + Stats: engineclient.WebSocketStats{ConnectionsByMock: make(map[string]int)}, + }) + return + } + + stats, err := engine.GetWebSocketStats(ctx) + if err != nil { + a.logger().Error("failed to get WebSocket stats", "error", err) + status, code, msg := mapWebSocketEngineError(err, a.logger(), "get WebSocket stats") + writeError(w, status, code, msg) + return + } + + // Stats and connections are fetched in two separate calls; they are not + // atomically consistent. Connections may change between the two requests, + // so the counts in stats may not exactly match the length of the returned + // connection list. This is intentional — correctness is not required here. + connections, err := engine.ListWebSocketConnections(ctx) + if err != nil { + a.logger().Error("failed to list WebSocket connections", "error", err) + status, code, msg := mapWebSocketEngineError(err, a.logger(), "list WebSocket connections") + writeError(w, status, code, msg) + return + } + + if connections == nil { + connections = []*engineclient.WebSocketConnection{} + } + + connsByMock := stats.ConnectionsByMock + if connsByMock == nil { + connsByMock = make(map[string]int) + } + + writeJSON(w, http.StatusOK, WebSocketConnectionListResponse{ + Connections: connections, + Stats: engineclient.WebSocketStats{ + TotalConnections: stats.TotalConnections, + ActiveConnections: stats.ActiveConnections, + TotalMessagesSent: stats.TotalMessagesSent, + TotalMessagesRecv: stats.TotalMessagesRecv, + ConnectionsByMock: connsByMock, + }, + }) +} + +// handleGetWebSocketConnection handles GET /websocket/connections/{id}. +func (a *API) handleGetWebSocketConnection(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Connection ID is required") + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + + conn, err := engine.GetWebSocketConnection(ctx, id) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + a.logger().Error("failed to get WebSocket connection", "error", err, "connectionID", id) + status, code, msg := mapWebSocketEngineError(err, a.logger(), "get WebSocket connection") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, conn) +} + +// handleCloseWebSocketConnection handles DELETE /websocket/connections/{id}. +// +//nolint:dupl // intentionally parallel structure with other protocol close handlers +func (a *API) handleCloseWebSocketConnection(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Connection ID is required") + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + + err := engine.CloseWebSocketConnection(ctx, id) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + a.logger().Error("failed to close WebSocket connection", "error", err, "connectionID", id) + status, code, msg := mapWebSocketEngineError(err, a.logger(), "close WebSocket connection") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Connection closed", + "connection": id, + }) +} + +// handleGetWebSocketStats handles GET /websocket/stats. +func (a *API) handleGetWebSocketStats(w http.ResponseWriter, r *http.Request) { + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, engineclient.WebSocketStats{ConnectionsByMock: make(map[string]int)}) + return + } + a.handleGetStats(w, r, newWSStatsProvider(engine)) +} + +func mapWebSocketEngineError(err error, log *slog.Logger, operation string) (int, string, string) { + return http.StatusServiceUnavailable, "engine_error", sanitizeEngineError(err, log, operation) +} + +// handleSendToWebSocketConnection handles POST /websocket/connections/{id}/send. +// It forwards a text or binary message to a specific active connection through the engine. +func (a *API) handleSendToWebSocketConnection(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Connection ID is required") + return + } + + var req WebSocketSendRequest + if err := decodeOptionalJSONBody(r, &req); err != nil { + writeJSONDecodeError(w, err, a.logger()) + return + } + if req.Type == "" { + req.Type = "text" + } + if req.Type != "text" && req.Type != "binary" { + writeError(w, http.StatusBadRequest, "invalid_type", `Type must be "text" or "binary"`) + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + + err := engine.SendToWebSocketConnection(ctx, id, req.Type, req.Data) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + a.logger().Error("failed to send to WebSocket connection", "error", err, "connectionID", id) + status, code, msg := mapWebSocketEngineError(err, a.logger(), "send to WebSocket connection") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Message sent", + "connection": id, + "type": req.Type, + }) +} + +// handleListMockWebSocketConnections handles GET /mocks/{id}/websocket/connections. +// +//nolint:dupl // intentionally parallel structure with other protocol mock-scoped handlers +func (a *API) handleListMockWebSocketConnections(w http.ResponseWriter, r *http.Request, engine *engineclient.Client) { + ctx := r.Context() + mockID := r.PathValue("id") + if mockID == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Mock ID is required") + return + } + + // Verify mock exists in the admin store (single source of truth). + if mockStore := a.getMockStore(); mockStore != nil { + if _, err := mockStore.Get(ctx, mockID); err != nil { + if errors.Is(err, store.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Mock not found") + return + } + writeError(w, http.StatusInternalServerError, "store_error", ErrMsgInternalError) + return + } + } + + // Get all WebSocket connections and filter by mock + connections, err := engine.ListWebSocketConnections(ctx) + if err != nil { + writeError(w, http.StatusServiceUnavailable, "engine_error", sanitizeEngineError(err, a.logger(), "list WebSocket connections for mock")) + return + } + + var filtered []*engineclient.WebSocketConnection + for _, conn := range connections { + if conn.MockID == mockID { + filtered = append(filtered, conn) + } + } + if filtered == nil { + filtered = []*engineclient.WebSocketConnection{} + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "connections": filtered, + "count": len(filtered), + "mockId": mockID, + }) +} + +// handleCloseMockWebSocketConnections handles DELETE /mocks/{id}/websocket/connections. +// +//nolint:dupl // intentionally parallel structure with other protocol mock-scoped handlers +func (a *API) handleCloseMockWebSocketConnections(w http.ResponseWriter, r *http.Request, engine *engineclient.Client) { + ctx := r.Context() + mockID := r.PathValue("id") + if mockID == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Mock ID is required") + return + } + + // Verify mock exists in the admin store (single source of truth). + if mockStore := a.getMockStore(); mockStore != nil { + if _, err := mockStore.Get(ctx, mockID); err != nil { + if errors.Is(err, store.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Mock not found") + return + } + writeError(w, http.StatusInternalServerError, "store_error", ErrMsgInternalError) + return + } + } + + // Get all WebSocket connections and close those for this mock + connections, err := engine.ListWebSocketConnections(ctx) + if err != nil { + writeError(w, http.StatusServiceUnavailable, "engine_error", sanitizeEngineError(err, a.logger(), "list WebSocket connections for close")) + return + } + + closed := 0 + for _, conn := range connections { + if conn.MockID == mockID { + if err := engine.CloseWebSocketConnection(ctx, conn.ID); err == nil { + closed++ + } + } + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Connections closed", + "closed": closed, + "mockId": mockID, + }) +} diff --git a/pkg/admin/websocket_handlers_test.go b/pkg/admin/websocket_handlers_test.go new file mode 100644 index 00000000..d611a5c8 --- /dev/null +++ b/pkg/admin/websocket_handlers_test.go @@ -0,0 +1,329 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/getmockd/mockd/pkg/admin/engineclient" +) + +// ============================================================================ +// handleListWebSocketConnections +// ============================================================================ + +func TestHandleListWebSocketConnections_NoEngine_ReturnsEmptyList(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections", nil) + rec := httptest.NewRecorder() + + api.handleListWebSocketConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp WebSocketConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Empty(t, resp.Connections) + assert.NotNil(t, resp.Stats.ConnectionsByMock) +} + +func TestHandleListWebSocketConnections_WithEngine_ReturnsConnectionsAndStats(t *testing.T) { + // Spin up a mock engine that returns connections and stats. + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == http.MethodGet && r.URL.Path == "/websocket/stats": + _, _ = w.Write([]byte(`{"totalConnections":5,"activeConnections":2,"totalMessagesSent":100,"totalMessagesRecv":50,"connectionsByMock":{"mock-1":2}}`)) + case r.Method == http.MethodGet && r.URL.Path == "/websocket/connections": + _, _ = w.Write([]byte(`{"connections":[{"id":"conn-1","path":"/ws","connectedAt":"2026-04-05T00:00:00Z","messagesSent":10,"messagesRecv":5,"status":"connected"}],"count":1}`)) + default: + http.NotFound(w, r) + } + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections", nil) + rec := httptest.NewRecorder() + + api.handleListWebSocketConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp WebSocketConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Len(t, resp.Connections, 1) + assert.Equal(t, "conn-1", resp.Connections[0].ID) + assert.Equal(t, "/ws", resp.Connections[0].Path) + assert.Equal(t, int64(10), resp.Connections[0].MessagesSent) + assert.Equal(t, 2, resp.Stats.ActiveConnections) + assert.Equal(t, int64(5), resp.Stats.TotalConnections) + assert.Equal(t, 2, resp.Stats.ConnectionsByMock["mock-1"]) +} + +func TestHandleListWebSocketConnections_EngineUnavailable_Returns503(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections", nil) + rec := httptest.NewRecorder() + + api.handleListWebSocketConnections(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +// ============================================================================ +// handleGetWebSocketConnection +// ============================================================================ + +func TestHandleGetWebSocketConnection_MissingID_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections/", nil) + // PathValue "id" intentionally not set → empty string + rec := httptest.NewRecorder() + + api.handleGetWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleGetWebSocketConnection_NoEngine_Returns404(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections/conn-1", nil) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleGetWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleGetWebSocketConnection_EngineUnavailable_Returns503(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections/conn-1", nil) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleGetWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +// ============================================================================ +// handleCloseWebSocketConnection +// ============================================================================ + +func TestHandleCloseWebSocketConnection_MissingID_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/websocket/connections/", nil) + // PathValue "id" intentionally not set → empty string + rec := httptest.NewRecorder() + + api.handleCloseWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleCloseWebSocketConnection_NoEngine_Returns404(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/websocket/connections/conn-1", nil) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleCloseWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleCloseWebSocketConnection_EngineUnavailable_Returns503(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/websocket/connections/conn-1", nil) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleCloseWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +// ============================================================================ +// handleGetWebSocketStats +// ============================================================================ + +func TestHandleGetWebSocketStats_NoEngine_ReturnsEmptyStats(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/stats", nil) + rec := httptest.NewRecorder() + + api.handleGetWebSocketStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats engineclient.WebSocketStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.NotNil(t, stats.ConnectionsByMock) +} + +func TestHandleGetWebSocketStats_EngineUnavailable_Returns503(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/stats", nil) + rec := httptest.NewRecorder() + + api.handleGetWebSocketStats(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +// ============================================================================ +// handleSendToWebSocketConnection +// ============================================================================ + +func TestHandleSendToWebSocketConnection_MissingID_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections//send", strings.NewReader(`{"type":"text","data":"hello"}`)) + // PathValue "id" intentionally not set → empty string + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleSendToWebSocketConnection_InvalidJSON_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/conn-1/send", strings.NewReader(`{invalid`)) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleSendToWebSocketConnection_InvalidType_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/conn-1/send", + strings.NewReader(`{"type":"invalid","data":"hello"}`)) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) + + var resp map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Equal(t, "invalid_type", resp["error"]) +} + +func TestHandleSendToWebSocketConnection_NoEngine_Returns404(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/conn-1/send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleSendToWebSocketConnection_EngineUnavailable_Returns503(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/conn-1/send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +func TestHandleSendToWebSocketConnection_EmptyBody_DefaultsToText(t *testing.T) { + // No engine — expects 404, but verifies empty body doesn't fail decode + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/conn-1/send", strings.NewReader("")) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + // No engine → 404, not a 400 parse error + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleSendToWebSocketConnection_Success_Returns200(t *testing.T) { + // Spin up a minimal mock engine that accepts the send call and returns 200. + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/send") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"message":"Message sent","connection":"conn-1","type":"text"}`)) + return + } + http.NotFound(w, r) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + body := strings.NewReader(`{"type":"text","data":"hello"}`) + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/conn-1/send", body) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Equal(t, "Message sent", resp["message"]) + assert.Equal(t, "conn-1", resp["connection"]) + assert.Equal(t, "text", resp["type"]) +} diff --git a/pkg/api/types/responses.go b/pkg/api/types/responses.go index a11e04d9..4c3b4c13 100644 --- a/pkg/api/types/responses.go +++ b/pkg/api/types/responses.go @@ -458,7 +458,7 @@ type SSEConnection struct { // SSEConnectionListResponse lists SSE connections. type SSEConnectionListResponse struct { Connections []*SSEConnection `json:"connections"` - Count int `json:"count"` + Stats SSEStats `json:"stats"` } // SSEStats represents SSE statistics. @@ -471,26 +471,54 @@ type SSEStats struct { ConnectionsByMock map[string]int `json:"connectionsByMock"` } +// --- MQTT --- + +// MQTTConnection represents an active MQTT client connection. +type MQTTConnection struct { + ID string `json:"id"` + BrokerID string `json:"brokerId"` + ConnectedAt time.Time `json:"connectedAt"` + Subscriptions []string `json:"subscriptions"` + ProtocolVersion byte `json:"protocolVersion"` + Username string `json:"username,omitempty"` + RemoteAddr string `json:"remoteAddr,omitempty"` + Status string `json:"status"` +} + +// MQTTConnectionListResponse lists MQTT connections. +type MQTTConnectionListResponse struct { + Connections []*MQTTConnection `json:"connections"` + Stats MQTTStats `json:"stats"` +} + +// MQTTStats represents MQTT broker statistics. +type MQTTStats struct { + ConnectedClients int `json:"connectedClients"` + TotalSubscriptions int `json:"totalSubscriptions"` + TopicCount int `json:"topicCount"` + Port int `json:"port"` + TLSEnabled bool `json:"tlsEnabled"` + AuthEnabled bool `json:"authEnabled"` + SubscriptionsByClient map[string]int `json:"subscriptionsByClient"` +} + // --- WebSocket --- // WebSocketConnection represents an active WebSocket connection. type WebSocketConnection struct { - ID string `json:"id"` - MockID string `json:"mockId"` - Path string `json:"path"` - ClientIP string `json:"clientIp"` - ConnectedAt time.Time `json:"connectedAt"` - MessagesSent int64 `json:"messagesSent"` - MessagesRecv int64 `json:"messagesRecv"` - BytesSent int64 `json:"bytesSent"` - BytesReceived int64 `json:"bytesReceived"` - Status string `json:"status"` + ID string `json:"id"` + MockID string `json:"mockId,omitempty"` + Path string `json:"path"` + ConnectedAt time.Time `json:"connectedAt"` + MessagesSent int64 `json:"messagesSent"` + MessagesRecv int64 `json:"messagesRecv"` + Status string `json:"status"` } // WebSocketConnectionListResponse lists WebSocket connections. type WebSocketConnectionListResponse struct { Connections []*WebSocketConnection `json:"connections"` - Count int `json:"count"` + Stats WebSocketStats `json:"stats"` } // WebSocketStats represents WebSocket statistics. @@ -501,3 +529,32 @@ type WebSocketStats struct { TotalMessagesRecv int64 `json:"totalMessagesRecv"` ConnectionsByMock map[string]int `json:"connectionsByMock"` } + +// --- gRPC --- + +// GRPCStream represents an active gRPC streaming RPC. +type GRPCStream struct { + ID string `json:"id"` + Method string `json:"method"` + StreamType string `json:"streamType"` + ClientAddr string `json:"clientAddr,omitempty"` + ConnectedAt time.Time `json:"connectedAt"` + MessagesSent int64 `json:"messagesSent"` + MessagesRecv int64 `json:"messagesRecv"` +} + +// GRPCStreamListResponse lists gRPC streams. +type GRPCStreamListResponse struct { + Streams []*GRPCStream `json:"streams"` + Stats GRPCStats `json:"stats"` +} + +// GRPCStats represents gRPC statistics. +type GRPCStats struct { + ActiveStreams int `json:"activeStreams"` + TotalStreams int64 `json:"totalStreams"` + TotalRPCs int64 `json:"totalRPCs"` + TotalMessagesSent int64 `json:"totalMessagesSent"` + TotalMessagesRecv int64 `json:"totalMessagesRecv"` + StreamsByMethod map[string]int `json:"streamsByMethod"` +} diff --git a/pkg/cli/client.go b/pkg/cli/client.go index 7cfb25ec..86b62505 100644 --- a/pkg/cli/client.go +++ b/pkg/cli/client.go @@ -143,6 +143,45 @@ type AdminClient interface { AddEngineWorkspace(engineID, workspaceID, workspaceName string) error // BulkCreateMocks creates multiple mocks in a single request. BulkCreateMocks(mocks []*mock.Mock, workspaceID string) (*BulkCreateResult, error) + + // Connection management + // ListWebSocketConnections returns active WebSocket connections. + ListWebSocketConnections() (*apitypes.WebSocketConnectionListResponse, error) + // GetWebSocketConnection returns a specific WebSocket connection. + GetWebSocketConnection(id string) (*apitypes.WebSocketConnection, error) + // CloseWebSocketConnection closes a WebSocket connection. + CloseWebSocketConnection(id string) error + // SendWebSocketMessage sends a message to a WebSocket connection. + SendWebSocketMessage(id string, message string, binary bool) error + // GetWebSocketStats returns WebSocket statistics. + GetWebSocketStats() (*apitypes.WebSocketStats, error) + + // ListSSEConnections returns active SSE connections. + ListSSEConnections() (*apitypes.SSEConnectionListResponse, error) + // GetSSEConnection returns a specific SSE connection. + GetSSEConnection(id string) (*apitypes.SSEConnection, error) + // CloseSSEConnection closes an SSE connection. + CloseSSEConnection(id string) error + // GetSSEStats returns SSE statistics. + GetSSEStats() (*apitypes.SSEStats, error) + + // ListMQTTConnections returns active MQTT client connections. + ListMQTTConnections() (*apitypes.MQTTConnectionListResponse, error) + // GetMQTTConnection returns a specific MQTT client connection. + GetMQTTConnection(id string) (*apitypes.MQTTConnection, error) + // CloseMQTTConnection disconnects an MQTT client. + CloseMQTTConnection(id string) error + // GetMQTTStats returns MQTT broker statistics. + GetMQTTStats() (*apitypes.MQTTStats, error) + + // ListGRPCStreams returns active gRPC streaming connections. + ListGRPCStreams() (*apitypes.GRPCStreamListResponse, error) + // GetGRPCStream returns a specific gRPC stream. + GetGRPCStream(id string) (*apitypes.GRPCStream, error) + // CloseGRPCStream cancels a gRPC stream. + CloseGRPCStream(id string) error + // GetGRPCStats returns gRPC statistics. + GetGRPCStats() (*apitypes.GRPCStats, error) } // LogFilter specifies filtering criteria for request logs. @@ -1658,6 +1697,315 @@ func (c *adminClient) BulkCreateMocks(mocks []*mock.Mock, workspaceID string) (* return &result, nil } +// ─── Connection management ────────────────────────────────────────────────── + +// ListWebSocketConnections returns active WebSocket connections. +func (c *adminClient) ListWebSocketConnections() (*apitypes.WebSocketConnectionListResponse, error) { + resp, err := c.get("/websocket/connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.WebSocketConnectionListResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// GetWebSocketConnection returns a specific WebSocket connection. +func (c *adminClient) GetWebSocketConnection(id string) (*apitypes.WebSocketConnection, error) { + resp, err := c.get("/websocket/connections/" + url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return nil, &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "WebSocket connection not found: " + id} + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.WebSocketConnection + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// CloseWebSocketConnection closes a WebSocket connection. +func (c *adminClient) CloseWebSocketConnection(id string) error { + resp, err := c.delete("/websocket/connections/" + url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "WebSocket connection not found: " + id} + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// SendWebSocketMessage sends a message to a WebSocket connection. +func (c *adminClient) SendWebSocketMessage(id string, message string, binary bool) error { + msgType := "text" + if binary { + msgType = "binary" + } + body, err := json.Marshal(map[string]interface{}{ + "data": message, + "type": msgType, + }) + if err != nil { + return fmt.Errorf("failed to encode request: %w", err) + } + resp, err := c.post("/websocket/connections/"+url.PathEscape(id)+"/send", body) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "WebSocket connection not found: " + id} + } + if resp.StatusCode != http.StatusOK { + return c.parseError(resp) + } + return nil +} + +// GetWebSocketStats returns WebSocket statistics. +func (c *adminClient) GetWebSocketStats() (*apitypes.WebSocketStats, error) { + resp, err := c.get("/websocket/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.WebSocketStats + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// ListSSEConnections returns active SSE connections. +func (c *adminClient) ListSSEConnections() (*apitypes.SSEConnectionListResponse, error) { + resp, err := c.get("/sse/connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.SSEConnectionListResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// GetSSEConnection returns a specific SSE connection. +func (c *adminClient) GetSSEConnection(id string) (*apitypes.SSEConnection, error) { + resp, err := c.get("/sse/connections/" + url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return nil, &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "SSE connection not found: " + id} + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.SSEConnection + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// CloseSSEConnection closes an SSE connection. +func (c *adminClient) CloseSSEConnection(id string) error { + resp, err := c.delete("/sse/connections/" + url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "SSE connection not found: " + id} + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// GetSSEStats returns SSE statistics. +func (c *adminClient) GetSSEStats() (*apitypes.SSEStats, error) { + resp, err := c.get("/sse/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.SSEStats + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// ListMQTTConnections returns active MQTT client connections. +func (c *adminClient) ListMQTTConnections() (*apitypes.MQTTConnectionListResponse, error) { + resp, err := c.get("/mqtt-connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.MQTTConnectionListResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// GetMQTTConnection returns a specific MQTT client connection. +func (c *adminClient) GetMQTTConnection(id string) (*apitypes.MQTTConnection, error) { + resp, err := c.get("/mqtt-connections/" + url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return nil, &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "MQTT connection not found: " + id} + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.MQTTConnection + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// CloseMQTTConnection disconnects an MQTT client. +func (c *adminClient) CloseMQTTConnection(id string) error { + resp, err := c.delete("/mqtt-connections/" + url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "MQTT connection not found: " + id} + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// GetMQTTStats returns MQTT broker statistics. +func (c *adminClient) GetMQTTStats() (*apitypes.MQTTStats, error) { + resp, err := c.get("/mqtt-connections/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.MQTTStats + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// ListGRPCStreams returns active gRPC streaming connections. +func (c *adminClient) ListGRPCStreams() (*apitypes.GRPCStreamListResponse, error) { + resp, err := c.get("/grpc/connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.GRPCStreamListResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// GetGRPCStream returns a specific gRPC stream. +func (c *adminClient) GetGRPCStream(id string) (*apitypes.GRPCStream, error) { + resp, err := c.get("/grpc/connections/" + url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return nil, &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "gRPC stream not found: " + id} + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.GRPCStream + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// CloseGRPCStream cancels a gRPC stream. +func (c *adminClient) CloseGRPCStream(id string) error { + resp, err := c.delete("/grpc/connections/" + url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "gRPC stream not found: " + id} + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// GetGRPCStats returns gRPC statistics. +func (c *adminClient) GetGRPCStats() (*apitypes.GRPCStats, error) { + resp, err := c.get("/grpc/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.GRPCStats + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + // get performs an HTTP GET request. func (c *adminClient) get(path string) (*http.Response, error) { return c.doRequest(http.MethodGet, path, nil) diff --git a/pkg/cli/grpc_connections.go b/pkg/cli/grpc_connections.go new file mode 100644 index 00000000..dcc46d8a --- /dev/null +++ b/pkg/cli/grpc_connections.go @@ -0,0 +1,134 @@ +package cli + +import ( + "fmt" + "time" + + "github.com/getmockd/mockd/pkg/cli/internal/output" + "github.com/spf13/cobra" +) + +// ─── gRPC connection management ───────────────────────────────────────────── + +var grpcConnectionsCmd = &cobra.Command{ + Use: "connections", + Short: "Manage active gRPC streaming connections", + Long: `List, inspect, or cancel active gRPC streaming RPC connections.`, + RunE: runGRPCConnectionsList, +} + +var grpcConnectionsListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List active gRPC streams", + Example: ` mockd grpc connections list + mockd grpc connections list --json`, + Args: cobra.NoArgs, + RunE: runGRPCConnectionsList, +} + +func runGRPCConnectionsList(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + result, err := client.ListGRPCStreams() + if err != nil { + return fmt.Errorf("failed to list gRPC streams: %s", FormatConnectionError(err)) + } + + printList(result, func() { + if len(result.Streams) == 0 { + fmt.Println("No active gRPC streams") + return + } + tw := output.Table() + _, _ = fmt.Fprintf(tw, "ID\tMETHOD\tSTREAM TYPE\tCLIENT\tCONNECTED\tMSG SENT\tMSG RECV\n") + for _, s := range result.Streams { + _, _ = fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%d\t%d\n", + s.ID, s.Method, s.StreamType, s.ClientAddr, + formatDuration(time.Since(s.ConnectedAt)), + s.MessagesSent, s.MessagesRecv) + } + _ = tw.Flush() + fmt.Printf("\nTotal: %d stream(s)\n", len(result.Streams)) + }) + return nil +} + +var grpcConnectionsGetCmd = &cobra.Command{ + Use: "get ", + Short: "Get details of a gRPC stream", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + stream, err := client.GetGRPCStream(args[0]) + if err != nil { + return fmt.Errorf("failed to get gRPC stream: %s", FormatConnectionError(err)) + } + printResult(stream, func() { + fmt.Printf("gRPC Stream: %s\n", stream.ID) + fmt.Printf(" Method: %s\n", stream.Method) + fmt.Printf(" Stream Type: %s\n", stream.StreamType) + fmt.Printf(" Client Addr: %s\n", stream.ClientAddr) + fmt.Printf(" Connected: %s (%s ago)\n", stream.ConnectedAt.Format(time.RFC3339), formatDuration(time.Since(stream.ConnectedAt))) + fmt.Printf(" Messages Sent: %d\n", stream.MessagesSent) + fmt.Printf(" Messages Recv: %d\n", stream.MessagesRecv) + }) + return nil + }, +} + +var grpcConnectionsCloseCmd = &cobra.Command{ + Use: "close ", + Short: "Cancel a gRPC stream", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + if err := client.CloseGRPCStream(args[0]); err != nil { + return fmt.Errorf("failed to cancel gRPC stream: %s", FormatConnectionError(err)) + } + printResult(map[string]interface{}{"id": args[0], "cancelled": true}, func() { + fmt.Printf("Cancelled gRPC stream: %s\n", args[0]) + }) + return nil + }, +} + +//nolint:dupl // intentionally parallel structure with other protocol stats commands +var grpcStatsCmd = &cobra.Command{ + Use: "stats", + Short: "Show gRPC statistics", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + stats, err := client.GetGRPCStats() + if err != nil { + return fmt.Errorf("failed to get gRPC stats: %s", FormatConnectionError(err)) + } + + printResult(stats, func() { + fmt.Println("gRPC Statistics") + fmt.Printf(" Active Streams: %d\n", stats.ActiveStreams) + fmt.Printf(" Total Streams: %d\n", stats.TotalStreams) + fmt.Printf(" Total RPCs: %d\n", stats.TotalRPCs) + fmt.Printf(" Total Messages Sent: %d\n", stats.TotalMessagesSent) + fmt.Printf(" Total Messages Recv: %d\n", stats.TotalMessagesRecv) + if len(stats.StreamsByMethod) > 0 { + fmt.Println(" Streams by Method:") + for method, count := range stats.StreamsByMethod { + fmt.Printf(" %s: %d\n", method, count) + } + } + }) + return nil + }, +} + +func init() { + // connections subgroup + grpcConnectionsCmd.AddCommand(grpcConnectionsListCmd) + grpcConnectionsCmd.AddCommand(grpcConnectionsGetCmd) + grpcConnectionsCmd.AddCommand(grpcConnectionsCloseCmd) + grpcCmd.AddCommand(grpcConnectionsCmd) + + // stats + grpcCmd.AddCommand(grpcStatsCmd) +} diff --git a/pkg/cli/mqtt_connections.go b/pkg/cli/mqtt_connections.go new file mode 100644 index 00000000..f8fb929c --- /dev/null +++ b/pkg/cli/mqtt_connections.go @@ -0,0 +1,142 @@ +package cli + +import ( + "fmt" + "time" + + "github.com/getmockd/mockd/pkg/cli/internal/output" + "github.com/spf13/cobra" +) + +// ─── MQTT connection management ───────────────────────────────────────────── + +var mqttConnectionsCmd = &cobra.Command{ + Use: "connections", + Short: "Manage active MQTT client connections", + Long: `List, inspect, or disconnect active MQTT client connections.`, + RunE: runMQTTConnectionsList, +} + +var mqttConnectionsListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List active MQTT client connections", + Example: ` mockd mqtt connections list + mockd mqtt connections list --json`, + Args: cobra.NoArgs, + RunE: runMQTTConnectionsList, +} + +func runMQTTConnectionsList(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + result, err := client.ListMQTTConnections() + if err != nil { + return fmt.Errorf("failed to list MQTT connections: %s", FormatConnectionError(err)) + } + + printList(result, func() { + if len(result.Connections) == 0 { + fmt.Println("No active MQTT connections") + return + } + tw := output.Table() + _, _ = fmt.Fprintf(tw, "ID\tREMOTE ADDR\tUSERNAME\tSUBSCRIPTIONS\tCONNECTED\tSTATUS\n") + for _, c := range result.Connections { + subs := fmt.Sprintf("%d topic(s)", len(c.Subscriptions)) + username := c.Username + if username == "" { + username = "-" + } + _, _ = fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n", + c.ID, c.RemoteAddr, username, subs, + formatDuration(time.Since(c.ConnectedAt)), + c.Status) + } + _ = tw.Flush() + fmt.Printf("\nTotal: %d connection(s)\n", len(result.Connections)) + }) + return nil +} + +var mqttConnectionsGetCmd = &cobra.Command{ + Use: "get ", + Short: "Get details of an MQTT client connection", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + conn, err := client.GetMQTTConnection(args[0]) + if err != nil { + return fmt.Errorf("failed to get MQTT connection: %s", FormatConnectionError(err)) + } + printResult(conn, func() { + fmt.Printf("MQTT Connection: %s\n", conn.ID) + fmt.Printf(" Broker ID: %s\n", conn.BrokerID) + fmt.Printf(" Remote Addr: %s\n", conn.RemoteAddr) + if conn.Username != "" { + fmt.Printf(" Username: %s\n", conn.Username) + } + fmt.Printf(" Protocol Version: %d\n", conn.ProtocolVersion) + fmt.Printf(" Connected: %s (%s ago)\n", conn.ConnectedAt.Format(time.RFC3339), formatDuration(time.Since(conn.ConnectedAt))) + fmt.Printf(" Subscriptions: %s\n", formatStringSlice(conn.Subscriptions)) + fmt.Printf(" Status: %s\n", conn.Status) + }) + return nil + }, +} + +var mqttConnectionsCloseCmd = &cobra.Command{ + Use: "close ", + Short: "Disconnect an MQTT client", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + if err := client.CloseMQTTConnection(args[0]); err != nil { + return fmt.Errorf("failed to close MQTT connection: %s", FormatConnectionError(err)) + } + printResult(map[string]interface{}{"id": args[0], "closed": true}, func() { + fmt.Printf("Disconnected MQTT client: %s\n", args[0]) + }) + return nil + }, +} + +var mqttStatsCmd = &cobra.Command{ + Use: "stats", + Short: "Show MQTT broker statistics", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + stats, err := client.GetMQTTStats() + if err != nil { + return fmt.Errorf("failed to get MQTT stats: %s", FormatConnectionError(err)) + } + + printResult(stats, func() { + fmt.Println("MQTT Broker Statistics") + fmt.Printf(" Connected Clients: %d\n", stats.ConnectedClients) + fmt.Printf(" Total Subscriptions: %d\n", stats.TotalSubscriptions) + fmt.Printf(" Topic Count: %d\n", stats.TopicCount) + fmt.Printf(" Port: %d\n", stats.Port) + fmt.Printf(" TLS Enabled: %v\n", stats.TLSEnabled) + fmt.Printf(" Auth Enabled: %v\n", stats.AuthEnabled) + if len(stats.SubscriptionsByClient) > 0 { + fmt.Println(" Subscriptions by Client:") + for clientID, count := range stats.SubscriptionsByClient { + fmt.Printf(" %s: %d\n", clientID, count) + } + } + }) + return nil + }, +} + +func init() { + // connections subgroup + mqttConnectionsCmd.AddCommand(mqttConnectionsListCmd) + mqttConnectionsCmd.AddCommand(mqttConnectionsGetCmd) + mqttConnectionsCmd.AddCommand(mqttConnectionsCloseCmd) + mqttCmd.AddCommand(mqttConnectionsCmd) + + // stats + mqttCmd.AddCommand(mqttStatsCmd) +} diff --git a/pkg/cli/print.go b/pkg/cli/print.go index e8e55108..c58b95d6 100644 --- a/pkg/cli/print.go +++ b/pkg/cli/print.go @@ -1,6 +1,10 @@ package cli -import "github.com/getmockd/mockd/pkg/cli/internal/output" +import ( + "strings" + + "github.com/getmockd/mockd/pkg/cli/internal/output" +) // printResult outputs a single operation result. // @@ -15,6 +19,14 @@ func printResult(data any, textFn func()) { textFn() } +// formatStringSlice joins strings for display. +func formatStringSlice(items []string) string { + if len(items) == 0 { + return "(none)" + } + return strings.Join(items, ", ") +} + // printList outputs a collection of items. // // Same contract as printResult. textFn typically uses output.Table() for diff --git a/pkg/cli/sse.go b/pkg/cli/sse.go new file mode 100644 index 00000000..9671567b --- /dev/null +++ b/pkg/cli/sse.go @@ -0,0 +1,183 @@ +package cli + +import ( + "fmt" + "time" + + "github.com/getmockd/mockd/pkg/cli/internal/output" + "github.com/spf13/cobra" +) + +// ─── SSE parent command ───────────────────────────────────────────────────── + +var sseCmd = &cobra.Command{ + Use: "sse", + Short: "Manage SSE (Server-Sent Events) connections", +} + +var sseAddCmd = &cobra.Command{ + Use: "add", + Short: "Add a new SSE mock endpoint", + RunE: func(cmd *cobra.Command, args []string) error { + addMockType = "sse" + return runAdd(cmd, args) + }, +} + +// ─── SSE connection management ────────────────────────────────────────────── + +var sseConnectionsCmd = &cobra.Command{ + Use: "connections", + Short: "Manage active SSE connections", + Long: `List, inspect, or close active SSE connections.`, + RunE: runSSEConnectionsList, +} + +var sseConnectionsListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List active SSE connections", + Example: ` mockd sse connections list + mockd sse connections list --json`, + Args: cobra.NoArgs, + RunE: runSSEConnectionsList, +} + +func runSSEConnectionsList(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + result, err := client.ListSSEConnections() + if err != nil { + return fmt.Errorf("failed to list SSE connections: %s", FormatConnectionError(err)) + } + + printList(result, func() { + if len(result.Connections) == 0 { + fmt.Println("No active SSE connections") + return + } + tw := output.Table() + _, _ = fmt.Fprintf(tw, "ID\tPATH\tMOCK ID\tCLIENT IP\tCONNECTED\tEVENTS\tSTATUS\n") + for _, c := range result.Connections { + _, _ = fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%d\t%s\n", + c.ID, c.Path, c.MockID, c.ClientIP, + formatDuration(time.Since(c.ConnectedAt)), + c.EventsSent, c.Status) + } + _ = tw.Flush() + fmt.Printf("\nTotal: %d connection(s)\n", len(result.Connections)) + }) + return nil +} + +var sseConnectionsGetCmd = &cobra.Command{ + Use: "get ", + Short: "Get details of an SSE connection", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + conn, err := client.GetSSEConnection(args[0]) + if err != nil { + return fmt.Errorf("failed to get SSE connection: %s", FormatConnectionError(err)) + } + printResult(conn, func() { + fmt.Printf("SSE Connection: %s\n", conn.ID) + fmt.Printf(" Path: %s\n", conn.Path) + fmt.Printf(" Mock ID: %s\n", conn.MockID) + fmt.Printf(" Client IP: %s\n", conn.ClientIP) + if conn.UserAgent != "" { + fmt.Printf(" User Agent: %s\n", conn.UserAgent) + } + fmt.Printf(" Connected: %s (%s ago)\n", conn.ConnectedAt.Format(time.RFC3339), formatDuration(time.Since(conn.ConnectedAt))) + fmt.Printf(" Events Sent: %d\n", conn.EventsSent) + fmt.Printf(" Bytes Sent: %d\n", conn.BytesSent) + fmt.Printf(" Status: %s\n", conn.Status) + }) + return nil + }, +} + +var sseConnectionsCloseCmd = &cobra.Command{ + Use: "close ", + Short: "Close an SSE connection", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + if err := client.CloseSSEConnection(args[0]); err != nil { + return fmt.Errorf("failed to close SSE connection: %s", FormatConnectionError(err)) + } + printResult(map[string]interface{}{"id": args[0], "closed": true}, func() { + fmt.Printf("Closed SSE connection: %s\n", args[0]) + }) + return nil + }, +} + +//nolint:dupl // intentionally parallel structure with other protocol stats commands +var sseStatsCmd = &cobra.Command{ + Use: "stats", + Short: "Show SSE statistics", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + stats, err := client.GetSSEStats() + if err != nil { + return fmt.Errorf("failed to get SSE stats: %s", FormatConnectionError(err)) + } + + printResult(stats, func() { + fmt.Println("SSE Statistics") + fmt.Printf(" Active Connections: %d\n", stats.ActiveConnections) + fmt.Printf(" Total Connections: %d\n", stats.TotalConnections) + fmt.Printf(" Total Events Sent: %d\n", stats.TotalEventsSent) + fmt.Printf(" Total Bytes Sent: %d\n", stats.TotalBytesSent) + fmt.Printf(" Connection Errors: %d\n", stats.ConnectionErrors) + if len(stats.ConnectionsByMock) > 0 { + fmt.Println(" Connections by Mock:") + for mockID, count := range stats.ConnectionsByMock { + fmt.Printf(" %s: %d\n", mockID, count) + } + } + }) + return nil + }, +} + +// ─── init ─────────────────────────────────────────────────────────────────── + +func init() { + rootCmd.AddCommand(sseCmd) + + // sse add + sseAddCmd.Flags().StringVar(&addPath, "path", "", "SSE endpoint path (e.g., /events)") + sseAddCmd.Flags().StringVar(&addName, "name", "", "Mock display name") + sseCmd.AddCommand(sseAddCmd) + + // list/get/delete generic aliases + sseCmd.AddCommand(&cobra.Command{ + Use: "list", + Short: "List SSE mocks", + RunE: func(cmd *cobra.Command, args []string) error { + listMockType = "sse" + return runList(cmd, args) + }, + }) + sseCmd.AddCommand(&cobra.Command{ + Use: "get", + Short: "Get details of an SSE mock", + RunE: runGet, + }) + sseCmd.AddCommand(&cobra.Command{ + Use: "delete", + Short: "Delete an SSE mock", + RunE: runDelete, + }) + + // connections subgroup + sseConnectionsCmd.AddCommand(sseConnectionsListCmd) + sseConnectionsCmd.AddCommand(sseConnectionsGetCmd) + sseConnectionsCmd.AddCommand(sseConnectionsCloseCmd) + sseCmd.AddCommand(sseConnectionsCmd) + + // stats + sseCmd.AddCommand(sseStatsCmd) +} diff --git a/pkg/cli/ws_connections.go b/pkg/cli/ws_connections.go new file mode 100644 index 00000000..88ef1617 --- /dev/null +++ b/pkg/cli/ws_connections.go @@ -0,0 +1,172 @@ +package cli + +import ( + "encoding/base64" + "fmt" + "time" + + "github.com/getmockd/mockd/pkg/cli/internal/output" + "github.com/spf13/cobra" +) + +// ─── websocket connection management ──────────────────────────────────────── + +var wsConnectionsCmd = &cobra.Command{ + Use: "connections", + Short: "Manage active WebSocket connections", + Long: `List, inspect, close, or send messages to active WebSocket connections.`, + RunE: runWSConnectionsList, +} + +var wsConnectionsListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List active WebSocket connections", + Example: ` mockd websocket connections list + mockd websocket connections list --json`, + Args: cobra.NoArgs, + RunE: runWSConnectionsList, +} + +func runWSConnectionsList(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + result, err := client.ListWebSocketConnections() + if err != nil { + return fmt.Errorf("failed to list WebSocket connections: %s", FormatConnectionError(err)) + } + + printList(result, func() { + if len(result.Connections) == 0 { + fmt.Println("No active WebSocket connections") + return + } + tw := output.Table() + _, _ = fmt.Fprintf(tw, "ID\tPATH\tMOCK ID\tCONNECTED\tMSG SENT\tMSG RECV\tSTATUS\n") + for _, c := range result.Connections { + _, _ = fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%d\t%d\t%s\n", + c.ID, c.Path, c.MockID, + formatDuration(time.Since(c.ConnectedAt)), + c.MessagesSent, c.MessagesRecv, c.Status) + } + _ = tw.Flush() + fmt.Printf("\nTotal: %d connection(s)\n", len(result.Connections)) + }) + return nil +} + +var wsConnectionsGetCmd = &cobra.Command{ + Use: "get ", + Short: "Get details of a WebSocket connection", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + conn, err := client.GetWebSocketConnection(args[0]) + if err != nil { + return fmt.Errorf("failed to get WebSocket connection: %s", FormatConnectionError(err)) + } + printResult(conn, func() { + fmt.Printf("WebSocket Connection: %s\n", conn.ID) + fmt.Printf(" Path: %s\n", conn.Path) + fmt.Printf(" Mock ID: %s\n", conn.MockID) + fmt.Printf(" Connected: %s (%s ago)\n", conn.ConnectedAt.Format(time.RFC3339), formatDuration(time.Since(conn.ConnectedAt))) + fmt.Printf(" Messages Sent: %d\n", conn.MessagesSent) + fmt.Printf(" Messages Recv: %d\n", conn.MessagesRecv) + fmt.Printf(" Status: %s\n", conn.Status) + }) + return nil + }, +} + +var wsConnectionsCloseCmd = &cobra.Command{ + Use: "close ", + Short: "Close a WebSocket connection", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + if err := client.CloseWebSocketConnection(args[0]); err != nil { + return fmt.Errorf("failed to close WebSocket connection: %s", FormatConnectionError(err)) + } + printResult(map[string]interface{}{"id": args[0], "closed": true}, func() { + fmt.Printf("Closed WebSocket connection: %s\n", args[0]) + }) + return nil + }, +} + +var wsConnectionsSendBinary bool + +var wsConnectionsSendCmd = &cobra.Command{ + Use: "send ", + Short: "Send a message to a WebSocket connection", + Long: `Send a text or binary message to an active WebSocket connection. +Use --binary to send base64-encoded binary data.`, + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + id := args[0] + message := args[1] + + if wsConnectionsSendBinary { + // Validate base64 + if _, err := base64.StdEncoding.DecodeString(message); err != nil { + return fmt.Errorf("invalid base64 data: %w", err) + } + } + + client := NewAdminClientWithAuth(adminURL) + if err := client.SendWebSocketMessage(id, message, wsConnectionsSendBinary); err != nil { + return fmt.Errorf("failed to send message: %s", FormatConnectionError(err)) + } + + msgType := "text" + if wsConnectionsSendBinary { + msgType = "binary" + } + printResult(map[string]interface{}{"id": id, "sent": true, "type": msgType}, func() { + fmt.Printf("Sent %s message to connection %s\n", msgType, id) + }) + return nil + }, +} + +var wsStatsCmd = &cobra.Command{ + Use: "stats", + Short: "Show WebSocket statistics", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + stats, err := client.GetWebSocketStats() + if err != nil { + return fmt.Errorf("failed to get WebSocket stats: %s", FormatConnectionError(err)) + } + + printResult(stats, func() { + fmt.Println("WebSocket Statistics") + fmt.Printf(" Active Connections: %d\n", stats.ActiveConnections) + fmt.Printf(" Total Connections: %d\n", stats.TotalConnections) + fmt.Printf(" Total Messages Sent: %d\n", stats.TotalMessagesSent) + fmt.Printf(" Total Messages Recv: %d\n", stats.TotalMessagesRecv) + if len(stats.ConnectionsByMock) > 0 { + fmt.Println(" Connections by Mock:") + for mockID, count := range stats.ConnectionsByMock { + fmt.Printf(" %s: %d\n", mockID, count) + } + } + }) + return nil + }, +} + +func init() { + // connections subgroup + wsConnectionsCmd.AddCommand(wsConnectionsListCmd) + wsConnectionsCmd.AddCommand(wsConnectionsGetCmd) + wsConnectionsCmd.AddCommand(wsConnectionsCloseCmd) + + wsConnectionsSendCmd.Flags().BoolVar(&wsConnectionsSendBinary, "binary", false, "Send base64-encoded binary message") + wsConnectionsCmd.AddCommand(wsConnectionsSendCmd) + + websocketCmd.AddCommand(wsConnectionsCmd) + + // stats as top-level websocket subcommand + websocketCmd.AddCommand(wsStatsCmd) +} diff --git a/pkg/engine/api/handlers.go b/pkg/engine/api/handlers.go index 6b63c320..69ebb7fc 100644 --- a/pkg/engine/api/handlers.go +++ b/pkg/engine/api/handlers.go @@ -1,6 +1,7 @@ package api import ( + "encoding/base64" "encoding/json" "errors" "fmt" @@ -14,6 +15,7 @@ import ( "github.com/getmockd/mockd/pkg/httputil" "github.com/getmockd/mockd/pkg/requestlog" "github.com/getmockd/mockd/pkg/stateful" + "github.com/getmockd/mockd/pkg/websocket" ) func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { @@ -738,9 +740,14 @@ func (s *Server) handleGetHandler(w http.ResponseWriter, r *http.Request) { func (s *Server) handleListSSEConnections(w http.ResponseWriter, r *http.Request) { connections := s.engine.ListSSEConnections() + stats := s.engine.GetSSEStats() + respStats := SSEStats{ConnectionsByMock: make(map[string]int)} + if stats != nil { + respStats = *stats + } writeJSON(w, http.StatusOK, SSEConnectionListResponse{ Connections: connections, - Count: len(connections), + Stats: respStats, }) } @@ -779,9 +786,14 @@ func (s *Server) handleGetSSEStats(w http.ResponseWriter, r *http.Request) { func (s *Server) handleListWebSocketConnections(w http.ResponseWriter, r *http.Request) { connections := s.engine.ListWebSocketConnections() + stats := s.engine.GetWebSocketStats() + respStats := WebSocketStats{ConnectionsByMock: make(map[string]int)} + if stats != nil { + respStats = *stats + } writeJSON(w, http.StatusOK, WebSocketConnectionListResponse{ Connections: connections, - Count: len(connections), + Stats: respStats, }) } @@ -816,6 +828,64 @@ func (s *Server) handleGetWebSocketStats(w http.ResponseWriter, r *http.Request) writeJSON(w, http.StatusOK, stats) } +// handleSendToWebSocketConnection handles POST /websocket/connections/{id}/send. +// It sends a text or binary message to a specific active WebSocket connection. +func (s *Server) handleSendToWebSocketConnection(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Connection ID is required") + return + } + + limitedBody(w, r) + + var req struct { + // Type is "text" (default) or "binary". + // For binary messages, Data must be base64-encoded; it is decoded here + // before the raw bytes are sent over the WebSocket connection. + Type string `json:"type"` + Data string `json:"data"` + } + if err := decodeJSONBody(r, &req, false); err != nil { + writeDecodeError(w, err) + return + } + if req.Type == "" { + req.Type = "text" + } + if req.Type != "text" && req.Type != "binary" { + writeError(w, http.StatusBadRequest, "invalid_type", `Type must be "text" or "binary"`) + return + } + + var payload []byte + if req.Type == "binary" { + var err error + payload, err = base64.StdEncoding.DecodeString(req.Data) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_base64", "For binary messages, data must be a base64-encoded string") + return + } + } else { + payload = []byte(req.Data) + } + + if err := s.engine.SendToWebSocketConnection(id, req.Type, payload); err != nil { + if errors.Is(err, websocket.ErrConnectionNotFound) || errors.Is(err, websocket.ErrConnectionClosed) { + writeError(w, http.StatusNotFound, "not_found", "WebSocket connection not found or closed") + } else { + writeError(w, http.StatusInternalServerError, "send_error", "Failed to send message to WebSocket connection") + } + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Message sent", + "connection": id, + "type": req.Type, + }) +} + // Config handlers func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { @@ -1008,6 +1078,101 @@ func writeJSON(w http.ResponseWriter, status int, v any) { httputil.WriteJSON(w, status, v) } +// MQTT handlers + +func (s *Server) handleListMQTTConnections(w http.ResponseWriter, r *http.Request) { + connections := s.engine.ListMQTTConnections() + stats := s.engine.GetMQTTStats() + respStats := MQTTStats{SubscriptionsByClient: make(map[string]int)} + if stats != nil { + respStats = *stats + } + writeJSON(w, http.StatusOK, MQTTConnectionListResponse{ + Connections: connections, + Stats: respStats, + }) +} + +func (s *Server) handleGetMQTTConnection(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + conn := s.engine.GetMQTTConnection(id) + if conn == nil { + writeError(w, http.StatusNotFound, "not_found", "MQTT connection not found") + return + } + writeJSON(w, http.StatusOK, conn) +} + +func (s *Server) handleCloseMQTTConnection(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if err := s.engine.CloseMQTTConnection(id); err != nil { + writeError(w, http.StatusNotFound, "not_found", "MQTT connection not found") + return + } + writeJSON(w, http.StatusOK, map[string]string{ + "message": "MQTT connection closed", + "id": id, + }) +} + +func (s *Server) handleGetMQTTStats(w http.ResponseWriter, r *http.Request) { + stats := s.engine.GetMQTTStats() + if stats == nil { + writeJSON(w, http.StatusOK, MQTTStats{SubscriptionsByClient: make(map[string]int)}) + return + } + writeJSON(w, http.StatusOK, stats) +} + +// gRPC stream handlers + +func (s *Server) handleListGRPCStreams(w http.ResponseWriter, r *http.Request) { + streams := s.engine.ListGRPCStreams() + if streams == nil { + streams = []*GRPCStream{} + } + stats := s.engine.GetGRPCStats() + respStats := GRPCStats{StreamsByMethod: make(map[string]int)} + if stats != nil { + respStats = *stats + } + writeJSON(w, http.StatusOK, GRPCStreamListResponse{ + Streams: streams, + Stats: respStats, + }) +} + +func (s *Server) handleGetGRPCStream(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + stream := s.engine.GetGRPCStream(id) + if stream == nil { + writeError(w, http.StatusNotFound, "not_found", "gRPC stream not found") + return + } + writeJSON(w, http.StatusOK, stream) +} + +func (s *Server) handleCancelGRPCStream(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if err := s.engine.CancelGRPCStream(id); err != nil { + writeError(w, http.StatusNotFound, "not_found", "gRPC stream not found") + return + } + writeJSON(w, http.StatusOK, map[string]string{ + "message": "gRPC stream cancelled", + "id": id, + }) +} + +func (s *Server) handleGetGRPCStats(w http.ResponseWriter, r *http.Request) { + stats := s.engine.GetGRPCStats() + if stats == nil { + writeJSON(w, http.StatusOK, GRPCStats{StreamsByMethod: make(map[string]int)}) + return + } + writeJSON(w, http.StatusOK, stats) +} + func writeError(w http.ResponseWriter, status int, code, message string) { httputil.WriteJSON(w, status, ErrorResponse{ Error: code, diff --git a/pkg/engine/api/handlers_test.go b/pkg/engine/api/handlers_test.go index 2f3957fc..ab9a3911 100644 --- a/pkg/engine/api/handlers_test.go +++ b/pkg/engine/api/handlers_test.go @@ -14,6 +14,7 @@ import ( "github.com/getmockd/mockd/pkg/mock" "github.com/getmockd/mockd/pkg/requestlog" "github.com/getmockd/mockd/pkg/stateful" + "github.com/getmockd/mockd/pkg/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -29,9 +30,13 @@ type mockEngine struct { stateOverview *StateOverview handlers []*ProtocolHandler sseConnections []*SSEConnection - wsConnections []*WebSocketConnection - sseStats *SSEStats - wsStats *WebSocketStats + wsConnections []*WebSocketConnection + mqttConnections []*MQTTConnection + grpcStreams []*GRPCStream + sseStats *SSEStats + wsStats *WebSocketStats + mqttStats *MQTTStats + grpcStats *GRPCStats configResp *ConfigResponse protocols map[string]ProtocolStatusInfo @@ -48,6 +53,9 @@ type mockEngine struct { resetStateErr error listStatefulItemsErr error getStatefulItemErr error + // wsSendErr allows injecting a specific error for a connection ID in + // SendToWebSocketConnection. Keyed by connection ID. + wsSendErr map[string]error } func newMockEngine() *mockEngine { @@ -55,6 +63,7 @@ func newMockEngine() *mockEngine { mocks: make(map[string]*config.MockConfiguration), requestLogs: make(map[string]*requestlog.Entry), customOps: make(map[string]*CustomOperationDetail), + wsSendErr: make(map[string]error), running: true, uptime: 100, protocols: map[string]ProtocolStatusInfo{ @@ -358,6 +367,73 @@ func (m *mockEngine) GetWebSocketStats() *WebSocketStats { return m.wsStats } +func (m *mockEngine) SendToWebSocketConnection(id string, msgType string, data []byte) error { + // Allow per-connection error injection for testing specific error paths. + if err, ok := m.wsSendErr[id]; ok { + return err + } + for _, c := range m.wsConnections { + if c.ID == id { + return nil + } + } + return websocket.ErrConnectionNotFound +} + +func (m *mockEngine) ListMQTTConnections() []*MQTTConnection { + return m.mqttConnections +} + +func (m *mockEngine) GetMQTTConnection(id string) *MQTTConnection { + for _, c := range m.mqttConnections { + if c.ID == id { + return c + } + } + return nil +} + +func (m *mockEngine) CloseMQTTConnection(id string) error { + for i, c := range m.mqttConnections { + if c.ID == id { + m.mqttConnections = append(m.mqttConnections[:i], m.mqttConnections[i+1:]...) + return nil + } + } + return errors.New("connection not found") +} + +func (m *mockEngine) GetMQTTStats() *MQTTStats { + return m.mqttStats +} + +func (m *mockEngine) ListGRPCStreams() []*GRPCStream { + return m.grpcStreams +} + +func (m *mockEngine) GetGRPCStream(id string) *GRPCStream { + for _, s := range m.grpcStreams { + if s.ID == id { + return s + } + } + return nil +} + +func (m *mockEngine) CancelGRPCStream(id string) error { + for i, s := range m.grpcStreams { + if s.ID == id { + m.grpcStreams = append(m.grpcStreams[:i], m.grpcStreams[i+1:]...) + return nil + } + } + return errors.New("stream not found") +} + +func (m *mockEngine) GetGRPCStats() *GRPCStats { + return m.grpcStats +} + func (m *mockEngine) GetConfig() *ConfigResponse { return m.configResp } @@ -2223,3 +2299,814 @@ func TestCustomOperationFullCRUD(t *testing.T) { server.handleGetCustomOperation(getRec2, getReq2) assert.Equal(t, http.StatusNotFound, getRec2.Code) } + +// TestHandleListWebSocketConnections tests the GET /websocket/connections handler. +func TestHandleListWebSocketConnections(t *testing.T) { + t.Run("returns empty list when no connections", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections", nil) + rec := httptest.NewRecorder() + + server.handleListWebSocketConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp WebSocketConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Empty(t, resp.Connections) + assert.Equal(t, 0, resp.Stats.ActiveConnections) + }) + + t.Run("returns all connections", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{ + {ID: "ws-1", Path: "/ws", Status: "connected"}, + {ID: "ws-2", Path: "/ws", Status: "connected"}, + } + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections", nil) + rec := httptest.NewRecorder() + + server.handleListWebSocketConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp WebSocketConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Len(t, resp.Connections, 2) + }) +} + +// TestHandleGetWebSocketConnection tests the GET /websocket/connections/{id} handler. +func TestHandleGetWebSocketConnection(t *testing.T) { + t.Run("returns connection when found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{ + {ID: "ws-1", Path: "/ws", Status: "connected"}, + } + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections/ws-1", nil) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleGetWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var conn WebSocketConnection + err := json.Unmarshal(rec.Body.Bytes(), &conn) + require.NoError(t, err) + assert.Equal(t, "ws-1", conn.ID) + }) + + t.Run("returns 404 for unknown connection", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections/unknown", nil) + req.SetPathValue("id", "unknown") + rec := httptest.NewRecorder() + + server.handleGetWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +// TestHandleCloseWebSocketConnection tests the DELETE /websocket/connections/{id} handler. +func TestHandleCloseWebSocketConnection(t *testing.T) { + t.Run("closes existing connection and returns 200", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{ + {ID: "ws-1", Path: "/ws", Status: "connected"}, + } + + req := httptest.NewRequest(http.MethodDelete, "/websocket/connections/ws-1", nil) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleCloseWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp map[string]string + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Equal(t, "WebSocket connection closed", resp["message"]) + assert.Equal(t, "ws-1", resp["id"]) + + // Connection must be removed from the engine. + assert.Empty(t, engine.wsConnections) + }) + + t.Run("returns 404 for unknown connection", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodDelete, "/websocket/connections/unknown", nil) + req.SetPathValue("id", "unknown") + rec := httptest.NewRecorder() + + server.handleCloseWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +// TestHandleGetWebSocketStats tests the GET /websocket/stats handler. +func TestHandleGetWebSocketStats(t *testing.T) { + t.Run("returns empty stats with non-nil map when wsStats is nil", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + // wsStats is nil by default in newMockEngine. + + req := httptest.NewRequest(http.MethodGet, "/websocket/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetWebSocketStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats WebSocketStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.NotNil(t, stats.ConnectionsByMock) + }) + + t.Run("returns populated stats", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsStats = &WebSocketStats{ + TotalConnections: 5, + ActiveConnections: 3, + TotalMessagesSent: 100, + TotalMessagesRecv: 50, + ConnectionsByMock: map[string]int{"mock-1": 3}, + } + + req := httptest.NewRequest(http.MethodGet, "/websocket/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetWebSocketStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats WebSocketStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.Equal(t, int64(5), stats.TotalConnections) + assert.Equal(t, 3, stats.ActiveConnections) + assert.Equal(t, int64(100), stats.TotalMessagesSent) + assert.Equal(t, int64(50), stats.TotalMessagesRecv) + assert.Equal(t, 3, stats.ConnectionsByMock["mock-1"]) + }) +} + +// TestHandleSendToWebSocketConnection covers every branch of the send handler, +// including the new 404/500 distinction introduced in the review fix. +func TestHandleSendToWebSocketConnection(t *testing.T) { + t.Run("returns 400 when connection id is missing", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections//send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + // PathValue "id" intentionally not set → empty string + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "missing_id", resp["error"]) + }) + + t.Run("returns 400 for invalid JSON body", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{invalid`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) + }) + + t.Run("returns 400 for invalid base64 binary payload", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{{ID: "ws-1"}} + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{"type":"binary","data":"not-valid-base64!!!"}`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "invalid_base64", resp["error"]) + }) + + t.Run("returns 200 for text message to existing connection", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{{ID: "ws-1"}} + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "Message sent", resp["message"]) + assert.Equal(t, "ws-1", resp["connection"]) + assert.Equal(t, "text", resp["type"]) + }) + + t.Run("returns 200 for valid base64 binary message", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{{ID: "ws-1"}} + + // base64("hello") = "aGVsbG8=" + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{"type":"binary","data":"aGVsbG8="}`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "binary", resp["type"]) + }) + + t.Run("defaults type to text when omitted", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{{ID: "ws-1"}} + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{"data":"hello"}`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "text", resp["type"]) + }) + + t.Run("returns 404 when connection is not found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + // No connections registered → ErrConnectionNotFound + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/missing/send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + req.SetPathValue("id", "missing") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "not_found", resp["error"]) + }) + + t.Run("returns 404 when connection is already closed", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + // Inject ErrConnectionClosed for ws-1 (connection exists but is closed) + engine.wsSendErr["ws-1"] = websocket.ErrConnectionClosed + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "not_found", resp["error"]) + }) + + t.Run("returns 500 for unexpected write error", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + // Inject a generic I/O error (e.g., broken pipe) + engine.wsSendErr["ws-1"] = errors.New("write: broken pipe") + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "send_error", resp["error"]) + }) +} + +// ============================================================================ +// MQTT handlers +// ============================================================================ + +func TestHandleListMQTTConnections(t *testing.T) { + t.Run("empty", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections", nil) + rec := httptest.NewRecorder() + + server.handleListMQTTConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp MQTTConnectionListResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Empty(t, resp.Connections) + }) + + t.Run("with connections", func(t *testing.T) { + engine := newMockEngine() + engine.mqttConnections = []*MQTTConnection{ + {ID: "client-1", BrokerID: "broker-1", Subscriptions: []string{"sensors/#"}, Status: "connected"}, + {ID: "client-2", BrokerID: "broker-1", Subscriptions: []string{"devices/+"}, Status: "connected"}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections", nil) + rec := httptest.NewRecorder() + + server.handleListMQTTConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp MQTTConnectionListResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Len(t, resp.Connections, 2) + assert.Equal(t, "client-1", resp.Connections[0].ID) + }) +} + +func TestHandleGetMQTTConnection(t *testing.T) { + t.Run("found", func(t *testing.T) { + engine := newMockEngine() + engine.mqttConnections = []*MQTTConnection{ + {ID: "client-1", BrokerID: "broker-1", ProtocolVersion: 5, Subscriptions: []string{"topic/a"}, Status: "connected"}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections/client-1", nil) + req.SetPathValue("id", "client-1") + rec := httptest.NewRecorder() + + server.handleGetMQTTConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var conn MQTTConnection + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &conn)) + assert.Equal(t, "client-1", conn.ID) + assert.Equal(t, byte(5), conn.ProtocolVersion) + }) + + t.Run("not found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections/nonexistent", nil) + req.SetPathValue("id", "nonexistent") + rec := httptest.NewRecorder() + + server.handleGetMQTTConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +func TestHandleCloseMQTTConnection(t *testing.T) { + t.Run("success", func(t *testing.T) { + engine := newMockEngine() + engine.mqttConnections = []*MQTTConnection{ + {ID: "client-1", BrokerID: "broker-1", Status: "connected"}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodDelete, "/mqtt/connections/client-1", nil) + req.SetPathValue("id", "client-1") + rec := httptest.NewRecorder() + + server.handleCloseMQTTConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "MQTT connection closed", resp["message"]) + }) + + t.Run("not found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodDelete, "/mqtt/connections/nonexistent", nil) + req.SetPathValue("id", "nonexistent") + rec := httptest.NewRecorder() + + server.handleCloseMQTTConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +func TestHandleGetMQTTStats(t *testing.T) { + t.Run("nil stats returns empty", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/mqtt/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetMQTTStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var stats MQTTStats + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &stats)) + assert.NotNil(t, stats.SubscriptionsByClient) + }) + + t.Run("with stats", func(t *testing.T) { + engine := newMockEngine() + engine.mqttStats = &MQTTStats{ + ConnectedClients: 3, + TotalSubscriptions: 7, + TopicCount: 4, + Port: 1883, + SubscriptionsByClient: map[string]int{"c1": 3, "c2": 4}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/mqtt/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetMQTTStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var stats MQTTStats + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &stats)) + assert.Equal(t, 3, stats.ConnectedClients) + assert.Equal(t, 7, stats.TotalSubscriptions) + assert.Equal(t, 3, stats.SubscriptionsByClient["c1"]) + }) +} + +// ============================================================================ +// gRPC stream handlers +// ============================================================================ + +func TestHandleListGRPCStreams(t *testing.T) { + t.Run("empty", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections", nil) + rec := httptest.NewRecorder() + + server.handleListGRPCStreams(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp GRPCStreamListResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Empty(t, resp.Streams) + }) + + t.Run("with streams", func(t *testing.T) { + engine := newMockEngine() + engine.grpcStreams = []*GRPCStream{ + {ID: "stream-1", Method: "/pkg.Svc/Chat", StreamType: "bidi", MessagesSent: 5, MessagesRecv: 3}, + {ID: "stream-2", Method: "/pkg.Svc/Watch", StreamType: "server_stream", MessagesSent: 10}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections", nil) + rec := httptest.NewRecorder() + + server.handleListGRPCStreams(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp GRPCStreamListResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Len(t, resp.Streams, 2) + assert.Equal(t, "stream-1", resp.Streams[0].ID) + }) +} + +func TestHandleGetGRPCStream(t *testing.T) { + t.Run("found", func(t *testing.T) { + engine := newMockEngine() + engine.grpcStreams = []*GRPCStream{ + {ID: "stream-1", Method: "/pkg.Svc/Chat", StreamType: "bidi"}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections/stream-1", nil) + req.SetPathValue("id", "stream-1") + rec := httptest.NewRecorder() + + server.handleGetGRPCStream(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var stream GRPCStream + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &stream)) + assert.Equal(t, "stream-1", stream.ID) + }) + + t.Run("not found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections/nonexistent", nil) + req.SetPathValue("id", "nonexistent") + rec := httptest.NewRecorder() + + server.handleGetGRPCStream(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +func TestHandleCancelGRPCStream(t *testing.T) { + t.Run("success", func(t *testing.T) { + engine := newMockEngine() + engine.grpcStreams = []*GRPCStream{ + {ID: "stream-1", Method: "/pkg.Svc/Chat", StreamType: "bidi"}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodDelete, "/grpc/connections/stream-1", nil) + req.SetPathValue("id", "stream-1") + rec := httptest.NewRecorder() + + server.handleCancelGRPCStream(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("not found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodDelete, "/grpc/connections/nonexistent", nil) + req.SetPathValue("id", "nonexistent") + rec := httptest.NewRecorder() + + server.handleCancelGRPCStream(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +func TestHandleGetGRPCStats(t *testing.T) { + t.Run("nil stats", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/grpc/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetGRPCStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var stats GRPCStats + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &stats)) + assert.NotNil(t, stats.StreamsByMethod) + }) + + t.Run("with stats", func(t *testing.T) { + engine := newMockEngine() + engine.grpcStats = &GRPCStats{ + ActiveStreams: 2, + TotalStreams: 10, + TotalRPCs: 100, + TotalMessagesSent: 50, + TotalMessagesRecv: 30, + StreamsByMethod: map[string]int{"/pkg.Svc/Chat": 2}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/grpc/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetGRPCStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var stats GRPCStats + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &stats)) + assert.Equal(t, 2, stats.ActiveStreams) + assert.Equal(t, int64(100), stats.TotalRPCs) + assert.Equal(t, 2, stats.StreamsByMethod["/pkg.Svc/Chat"]) + }) +} + +// ============================================================================ +// SSE Handler Tests +// ============================================================================ + +// TestHandleListSSEConnections tests the GET /sse/connections handler. +func TestHandleListSSEConnections(t *testing.T) { + t.Run("returns empty list when no connections", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/sse/connections", nil) + rec := httptest.NewRecorder() + + server.handleListSSEConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp SSEConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Empty(t, resp.Connections) + assert.Equal(t, 0, resp.Stats.ActiveConnections) + }) + + t.Run("returns all connections", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.sseConnections = []*SSEConnection{ + {ID: "sse-1", MockID: "mock-1", Path: "/events", Status: "active"}, + {ID: "sse-2", MockID: "mock-1", Path: "/events", Status: "active"}, + } + + req := httptest.NewRequest(http.MethodGet, "/sse/connections", nil) + rec := httptest.NewRecorder() + + server.handleListSSEConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp SSEConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Len(t, resp.Connections, 2) + }) +} + +// TestHandleGetSSEConnection tests the GET /sse/connections/{id} handler. +func TestHandleGetSSEConnection(t *testing.T) { + t.Run("returns connection when found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.sseConnections = []*SSEConnection{ + {ID: "sse-1", MockID: "mock-1", Path: "/events", ClientIP: "127.0.0.1", Status: "active", EventsSent: 42, BytesSent: 1024}, + } + + req := httptest.NewRequest(http.MethodGet, "/sse/connections/sse-1", nil) + req.SetPathValue("id", "sse-1") + rec := httptest.NewRecorder() + + server.handleGetSSEConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var conn SSEConnection + err := json.Unmarshal(rec.Body.Bytes(), &conn) + require.NoError(t, err) + assert.Equal(t, "sse-1", conn.ID) + assert.Equal(t, "mock-1", conn.MockID) + assert.Equal(t, "/events", conn.Path) + assert.Equal(t, "127.0.0.1", conn.ClientIP) + assert.Equal(t, int64(42), conn.EventsSent) + assert.Equal(t, int64(1024), conn.BytesSent) + }) + + t.Run("returns 404 for unknown connection", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/sse/connections/unknown", nil) + req.SetPathValue("id", "unknown") + rec := httptest.NewRecorder() + + server.handleGetSSEConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +// TestHandleCloseSSEConnection tests the DELETE /sse/connections/{id} handler. +func TestHandleCloseSSEConnection(t *testing.T) { + t.Run("closes existing connection and returns 200", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.sseConnections = []*SSEConnection{ + {ID: "sse-1", MockID: "mock-1", Path: "/events", Status: "active"}, + } + + req := httptest.NewRequest(http.MethodDelete, "/sse/connections/sse-1", nil) + req.SetPathValue("id", "sse-1") + rec := httptest.NewRecorder() + + server.handleCloseSSEConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp map[string]string + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Equal(t, "SSE connection closed", resp["message"]) + assert.Equal(t, "sse-1", resp["id"]) + + // Connection must be removed from the engine. + assert.Empty(t, engine.sseConnections) + }) + + t.Run("returns 404 for unknown connection", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodDelete, "/sse/connections/unknown", nil) + req.SetPathValue("id", "unknown") + rec := httptest.NewRecorder() + + server.handleCloseSSEConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +// TestHandleGetSSEStats tests the GET /sse/stats handler. +func TestHandleGetSSEStats(t *testing.T) { + t.Run("returns empty stats with non-nil map when sseStats is nil", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/sse/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetSSEStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats SSEStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.NotNil(t, stats.ConnectionsByMock) + }) + + t.Run("returns populated stats", func(t *testing.T) { + engine := newMockEngine() + engine.sseStats = &SSEStats{ + TotalConnections: 10, + ActiveConnections: 2, + TotalEventsSent: 500, + TotalBytesSent: 50000, + ConnectionErrors: 1, + ConnectionsByMock: map[string]int{"mock-1": 2}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/sse/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetSSEStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var stats SSEStats + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &stats)) + assert.Equal(t, int64(10), stats.TotalConnections) + assert.Equal(t, 2, stats.ActiveConnections) + assert.Equal(t, int64(500), stats.TotalEventsSent) + assert.Equal(t, int64(50000), stats.TotalBytesSent) + assert.Equal(t, int64(1), stats.ConnectionErrors) + assert.Equal(t, 2, stats.ConnectionsByMock["mock-1"]) + }) +} diff --git a/pkg/engine/api/server.go b/pkg/engine/api/server.go index 760d492b..baba2c94 100644 --- a/pkg/engine/api/server.go +++ b/pkg/engine/api/server.go @@ -89,8 +89,21 @@ type EngineController interface { ListWebSocketConnections() []*WebSocketConnection GetWebSocketConnection(id string) *WebSocketConnection CloseWebSocketConnection(id string) error + SendToWebSocketConnection(id string, msgType string, data []byte) error GetWebSocketStats() *WebSocketStats + // MQTT connections + ListMQTTConnections() []*MQTTConnection + GetMQTTConnection(id string) *MQTTConnection + CloseMQTTConnection(id string) error + GetMQTTStats() *MQTTStats + + // gRPC streams + ListGRPCStreams() []*GRPCStream + GetGRPCStream(id string) *GRPCStream + CancelGRPCStream(id string) error + GetGRPCStats() *GRPCStats + // Config GetConfig() *ConfigResponse } @@ -224,6 +237,19 @@ func (s *Server) registerRoutes(mux *http.ServeMux) { mux.HandleFunc("GET /websocket/connections/{id}", s.handleGetWebSocketConnection) mux.HandleFunc("DELETE /websocket/connections/{id}", s.handleCloseWebSocketConnection) mux.HandleFunc("GET /websocket/stats", s.handleGetWebSocketStats) + mux.HandleFunc("POST /websocket/connections/{id}/send", s.handleSendToWebSocketConnection) + + // MQTT connections + mux.HandleFunc("GET /mqtt/connections", s.handleListMQTTConnections) + mux.HandleFunc("GET /mqtt/connections/{id}", s.handleGetMQTTConnection) + mux.HandleFunc("DELETE /mqtt/connections/{id}", s.handleCloseMQTTConnection) + mux.HandleFunc("GET /mqtt/stats", s.handleGetMQTTStats) + + // gRPC streams + mux.HandleFunc("GET /grpc/connections", s.handleListGRPCStreams) + mux.HandleFunc("GET /grpc/connections/{id}", s.handleGetGRPCStream) + mux.HandleFunc("DELETE /grpc/connections/{id}", s.handleCancelGRPCStream) + mux.HandleFunc("GET /grpc/stats", s.handleGetGRPCStats) // Config mux.HandleFunc("GET /config", s.handleGetConfig) diff --git a/pkg/engine/api/types.go b/pkg/engine/api/types.go index e069ecfb..f55fbc85 100644 --- a/pkg/engine/api/types.go +++ b/pkg/engine/api/types.go @@ -41,6 +41,12 @@ type ( WebSocketConnection = types.WebSocketConnection WebSocketConnectionListResponse = types.WebSocketConnectionListResponse WebSocketStats = types.WebSocketStats + MQTTConnection = types.MQTTConnection + MQTTConnectionListResponse = types.MQTTConnectionListResponse + MQTTStats = types.MQTTStats + GRPCStream = types.GRPCStream + GRPCStreamListResponse = types.GRPCStreamListResponse + GRPCStats = types.GRPCStats ConfigResponse = types.ConfigResponse CustomOperationInfo = types.CustomOperationInfo CustomOperationDetail = types.CustomOperationDetail diff --git a/pkg/engine/control_api.go b/pkg/engine/control_api.go index dda67441..4ab948ca 100644 --- a/pkg/engine/control_api.go +++ b/pkg/engine/control_api.go @@ -5,13 +5,14 @@ import ( "errors" "fmt" - types "github.com/getmockd/mockd/pkg/api/types" + "github.com/getmockd/mockd/pkg/api/types" "github.com/getmockd/mockd/pkg/chaos" "github.com/getmockd/mockd/pkg/config" "github.com/getmockd/mockd/pkg/engine/api" "github.com/getmockd/mockd/pkg/protocol" "github.com/getmockd/mockd/pkg/requestlog" "github.com/getmockd/mockd/pkg/stateful" + "github.com/getmockd/mockd/pkg/websocket" ) // Errors returned by the control API adapter. @@ -737,6 +738,27 @@ func (a *ControlAPIAdapter) CloseWebSocketConnection(id string) error { return wsManager.CloseConnection(id, "closed by API") } +// SendToWebSocketConnection implements api.EngineController. +// msgType must be "text" or "binary"; data is the raw message payload. +func (a *ControlAPIAdapter) SendToWebSocketConnection(id string, msgType string, data []byte) error { + handler := a.server.Handler() + if handler == nil { + return ErrWebSocketHandlerNotInitialized + } + + wsManager := handler.WebSocketManager() + if wsManager == nil { + return ErrWebSocketHandlerNotInitialized + } + + mt := websocket.MessageText + if msgType == "binary" { + mt = websocket.MessageBinary + } + + return wsManager.SendToConnection(id, mt, data) +} + // GetWebSocketStats implements api.EngineController. func (a *ControlAPIAdapter) GetWebSocketStats() *api.WebSocketStats { handler := a.server.Handler() @@ -769,6 +791,220 @@ func (a *ControlAPIAdapter) GetWebSocketStats() *api.WebSocketStats { } } +// ErrMQTTBrokerNotInitialized is returned when no MQTT broker is available. +var ErrMQTTBrokerNotInitialized = errors.New("MQTT broker not initialized") + +// ListMQTTConnections implements api.EngineController. +func (a *ControlAPIAdapter) ListMQTTConnections() []*api.MQTTConnection { + brokers := a.server.GetMQTTBrokers() + var result []*api.MQTTConnection + + for _, broker := range brokers { + if broker == nil || !broker.IsRunning() { + continue + } + infos := broker.ListClientInfos() + for _, info := range infos { + if info.Closed { + continue + } + conn := &api.MQTTConnection{ + ID: info.ID, + BrokerID: info.BrokerID, + ConnectedAt: info.ConnectedAt, + Subscriptions: info.Subscriptions, + ProtocolVersion: info.ProtocolVersion, + Username: info.Username, + RemoteAddr: info.RemoteAddr, + Status: "connected", + } + if conn.Subscriptions == nil { + conn.Subscriptions = []string{} + } + result = append(result, conn) + } + } + + return result +} + +// GetMQTTConnection implements api.EngineController. +func (a *ControlAPIAdapter) GetMQTTConnection(id string) *api.MQTTConnection { + brokers := a.server.GetMQTTBrokers() + + for _, broker := range brokers { + if broker == nil || !broker.IsRunning() { + continue + } + info := broker.GetClientInfo(id) + if info == nil { + continue + } + subs := info.Subscriptions + if subs == nil { + subs = []string{} + } + return &api.MQTTConnection{ + ID: info.ID, + BrokerID: info.BrokerID, + ConnectedAt: info.ConnectedAt, + Subscriptions: subs, + ProtocolVersion: info.ProtocolVersion, + Username: info.Username, + RemoteAddr: info.RemoteAddr, + Status: "connected", + } + } + + return nil +} + +// CloseMQTTConnection implements api.EngineController. +func (a *ControlAPIAdapter) CloseMQTTConnection(id string) error { + brokers := a.server.GetMQTTBrokers() + + for _, broker := range brokers { + if broker == nil || !broker.IsRunning() { + continue + } + if err := broker.DisconnectClient(id); err == nil { + return nil + } + } + + return fmt.Errorf("MQTT client %q not found", id) +} + +// GetMQTTStats implements api.EngineController. +func (a *ControlAPIAdapter) GetMQTTStats() *api.MQTTStats { + brokers := a.server.GetMQTTBrokers() + if len(brokers) == 0 { + return nil + } + + stats := &api.MQTTStats{ + SubscriptionsByClient: make(map[string]int), + } + + for _, broker := range brokers { + if broker == nil || !broker.IsRunning() { + continue + } + connected, totalSubs, subsByClient := broker.GetConnectionStats() + stats.ConnectedClients += connected + stats.TotalSubscriptions += totalSubs + for k, v := range subsByClient { + stats.SubscriptionsByClient[k] = v + } + cfg := broker.Config() + if cfg != nil { + stats.TopicCount += len(cfg.Topics) + stats.Port = cfg.Port + stats.TLSEnabled = stats.TLSEnabled || (cfg.TLS != nil && cfg.TLS.Enabled) + stats.AuthEnabled = stats.AuthEnabled || (cfg.Auth != nil && cfg.Auth.Enabled) + } + } + + return stats +} + +// ErrGRPCServerNotInitialized is returned when no gRPC server is available. +var ErrGRPCServerNotInitialized = errors.New("gRPC server not initialized") + +// ListGRPCStreams implements api.EngineController. +func (a *ControlAPIAdapter) ListGRPCStreams() []*api.GRPCStream { + servers := a.server.GRPCServers() + var result []*api.GRPCStream + + for _, srv := range servers { + if srv == nil || !srv.IsRunning() { + continue + } + for _, info := range srv.StreamTracker().List() { + result = append(result, &api.GRPCStream{ + ID: info.ID, + Method: info.Method, + StreamType: string(info.StreamType), + ClientAddr: info.ClientAddr, + ConnectedAt: info.ConnectedAt, + MessagesSent: info.MessagesSent, + MessagesRecv: info.MessagesRecv, + }) + } + } + + return result +} + +// GetGRPCStream implements api.EngineController. +func (a *ControlAPIAdapter) GetGRPCStream(id string) *api.GRPCStream { + servers := a.server.GRPCServers() + + for _, srv := range servers { + if srv == nil || !srv.IsRunning() { + continue + } + if info := srv.StreamTracker().Get(id); info != nil { + return &api.GRPCStream{ + ID: info.ID, + Method: info.Method, + StreamType: string(info.StreamType), + ClientAddr: info.ClientAddr, + ConnectedAt: info.ConnectedAt, + MessagesSent: info.MessagesSent, + MessagesRecv: info.MessagesRecv, + } + } + } + + return nil +} + +// CancelGRPCStream implements api.EngineController. +func (a *ControlAPIAdapter) CancelGRPCStream(id string) error { + servers := a.server.GRPCServers() + + for _, srv := range servers { + if srv == nil || !srv.IsRunning() { + continue + } + if err := srv.StreamTracker().Cancel(id); err == nil { + return nil + } + } + + return fmt.Errorf("gRPC stream %s not found", id) +} + +// GetGRPCStats implements api.EngineController. +func (a *ControlAPIAdapter) GetGRPCStats() *api.GRPCStats { + servers := a.server.GRPCServers() + if len(servers) == 0 { + return nil + } + + stats := &api.GRPCStats{ + StreamsByMethod: make(map[string]int), + } + + for _, srv := range servers { + if srv == nil || !srv.IsRunning() { + continue + } + s := srv.StreamTracker().Stats() + stats.ActiveStreams += s.ActiveStreams + stats.TotalStreams += s.TotalStreams + stats.TotalRPCs += s.TotalRPCs + stats.TotalMessagesSent += s.TotalMsgSent + stats.TotalMessagesRecv += s.TotalMsgRecv + for method, count := range s.StreamsByMethod { + stats.StreamsByMethod[method] += count + } + } + + return stats +} + // GetConfig implements api.EngineController. func (a *ControlAPIAdapter) GetConfig() *api.ConfigResponse { cfg := a.server.Config() diff --git a/pkg/engine/handler_protocol.go b/pkg/engine/handler_protocol.go index f2fde048..1663667e 100644 --- a/pkg/engine/handler_protocol.go +++ b/pkg/engine/handler_protocol.go @@ -150,6 +150,28 @@ func (h *Handler) UnregisterWebSocketEndpoint(path string) { h.wsManager.UnregisterEndpoint(path) } +// DisconnectWebSocketEndpoint closes all active connections on a WebSocket endpoint. +// Uses RFC 6455 close code 1012 (Service Restart) so clients reconnect automatically. +// Must be called before UnregisterWebSocketEndpoint while byEndpoint still tracks the path. +func (h *Handler) DisconnectWebSocketEndpoint(path string) { + h.wsManager.DisconnectByEndpoint(path, websocket.CloseServiceRestart, "mock updated") +} + +// DisconnectSSEByMock closes all active SSE connections for the given mock ID. +// Cancelling each connection's context causes the event loop to exit, the response +// writer to close, and the client to receive EOF. EventSource clients will +// auto-reconnect and pick up the new configuration. +func (h *Handler) DisconnectSSEByMock(mockID string) { + if h.sseHandler == nil { + return + } + mgr := h.sseHandler.GetManager() + if mgr == nil { + return + } + mgr.CloseByMock(mockID) +} + // ListSOAPHandlerPaths returns all registered SOAP handler paths. func (h *Handler) ListSOAPHandlerPaths() []string { h.soapMu.RLock() diff --git a/pkg/engine/mock_manager.go b/pkg/engine/mock_manager.go index 34e27cbc..0765a66f 100644 --- a/pkg/engine/mock_manager.go +++ b/pkg/engine/mock_manager.go @@ -181,15 +181,44 @@ func (mm *MockManager) unregisterHandlerLocked(cfg *config.MockConfiguration) { return } - switch cfg.Type { //nolint:exhaustive // HTTP mocks don't need cleanup on removal + switch cfg.Type { //nolint:exhaustive // plain HTTP mocks don't need cleanup on removal + case mock.TypeHTTP: + // HTTP mocks with SSE configuration have active streaming connections + // that must be disconnected before the mock is removed or updated. + if mm.handler != nil && cfg.HTTP != nil && cfg.HTTP.SSE != nil { + mm.handler.DisconnectSSEByMock(cfg.ID) + } case mock.TypeGRPC: if mm.protocolManager != nil { + // Cancel active streams first so clients receive codes.Unavailable + // and can reconnect with retry policies (gRPC equivalent of WS 1012). + if srv := mm.protocolManager.GetGRPCServer(cfg.ID); srv != nil { + n := srv.StreamTracker().CancelAll() + if n > 0 { + mm.log.Info("cancelled active gRPC streams", "id", cfg.ID, "count", n) + } + } if err := mm.protocolManager.StopGRPCServer(cfg.ID); err != nil { mm.log.Warn("failed to stop gRPC server", "id", cfg.ID, "error", err) } } case mock.TypeMQTT: if mm.protocolManager != nil { + // Disconnect all connected clients cleanly before stopping the broker. + // This gives clients a chance to detect the disconnection and reconnect + // when the broker restarts with the updated configuration. + broker := mm.protocolManager.GetMQTTBroker(cfg.ID) + if broker != nil && broker.IsRunning() { + clients := broker.GetClients() + for _, clientID := range clients { + if err := broker.DisconnectClient(clientID); err != nil { + mm.log.Debug("failed to disconnect MQTT client during mock update", + "clientID", clientID, "mockID", cfg.ID, "error", err) + } + } + mm.log.Info("disconnected MQTT clients before mock update", + "mockID", cfg.ID, "clientCount", len(clients)) + } if err := mm.protocolManager.StopMQTTBroker(cfg.ID); err != nil { mm.log.Warn("failed to stop MQTT broker", "id", cfg.ID, "error", err) } @@ -197,6 +226,11 @@ func (mm *MockManager) unregisterHandlerLocked(cfg *config.MockConfiguration) { case mock.TypeWebSocket: if mm.handler != nil && cfg.WebSocket != nil { + // Disconnect active clients first so they receive close code 1012 + // (Service Restart) and reconnect with the new configuration. + // Must happen before UnregisterWebSocketEndpoint while byEndpoint + // still tracks the path. + mm.handler.DisconnectWebSocketEndpoint(cfg.WebSocket.Path) mm.handler.UnregisterWebSocketEndpoint(cfg.WebSocket.Path) } case mock.TypeGraphQL: diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 66be6118..9443332d 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -24,6 +24,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/health" + grpcpeer "google.golang.org/grpc/peer" healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/metadata" "google.golang.org/grpc/reflection" @@ -86,6 +87,7 @@ type Server struct { startedAt time.Time log *slog.Logger templateEngine *template.Engine + streamTracker *StreamTracker // Request logging support requestLoggerMu sync.RWMutex @@ -106,6 +108,7 @@ func NewServer(config *GRPCConfig, schema *ProtoSchema) (*Server, error) { schema: schema, log: logging.Nop(), templateEngine: template.New(), + streamTracker: NewStreamTracker(), }, nil } @@ -196,6 +199,10 @@ func (s *Server) Stop(ctx context.Context, timeout time.Duration) error { s.listener = nil s.mu.Unlock() + // Cancel all tracked streams before stopping the gRPC server so that + // stream handlers return codes.Unavailable promptly. + s.streamTracker.CancelAll() + if grpcSrv != nil { // Create a channel to signal graceful stop completion done := make(chan struct{}) @@ -630,6 +637,7 @@ func (s *Server) makeStreamHandler(serviceName, methodName string) func(srv inte // handleUnary handles unary RPC calls. func (s *Server) handleUnary(_ interface{}, ctx context.Context, dec func(interface{}) error, _ grpc.UnaryServerInterceptor, serviceName, methodName string) (interface{}, error) { + s.streamTracker.RecordUnaryRPC() startTime := time.Now() fullPath := fmt.Sprintf("/%s/%s", serviceName, methodName) @@ -770,13 +778,10 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD startTime := time.Now() fullPath := fmt.Sprintf("/%s/%s", serviceName, methodName) - // Track active stream connection - if metrics.ActiveConnections != nil { - if vec, err := metrics.ActiveConnections.WithLabels("grpc"); err == nil { - vec.Inc() - defer vec.Dec() - } - } + // Register stream with tracker (replaces raw metrics.ActiveConnections) + streamID, ctx, cancel := s.streamTracker.Register(stream.Context(), fullPath, streamServerStream, peerAddr(stream.Context())) + defer cancel() + defer s.streamTracker.Unregister(streamID) // Read single request from client inputDesc := method.GetInputDescriptor() @@ -792,6 +797,7 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamServerStream, md, nil, nil, grpcErr) return grpcErr } + s.streamTracker.RecordRecv(streamID) reqMap := dynamicMessageToMap(reqMsg) @@ -803,11 +809,10 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD return err } - // Apply initial delay (context-aware for client cancellation) - ctx := stream.Context() + // Apply initial delay (context-aware for client/tracker cancellation) s.applyDelayWithContext(ctx, methodCfg.Delay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Canceled, "client cancelled during configured delay") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamServerStream, md, reqMap, nil, grpcErr) return grpcErr } @@ -836,6 +841,13 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD var collectedResponses []interface{} for i, respData := range responses { + // Check for tracker cancellation between messages + if ctx.Err() != nil { + grpcErr := s.contextCancelError(streamID) + s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamServerStream, md, reqMap, collectedResponses, grpcErr) + return grpcErr + } + resp, err := s.buildResponse(method, respData, templateCtx) if err != nil { grpcErr := status.Errorf(codes.Internal, "failed to build response: %v", err) @@ -849,6 +861,8 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD return grpcErr } + s.streamTracker.RecordSent(streamID) + // Collect for logging collectedResponses = append(collectedResponses, dynamicMessageToMap(resp)) @@ -856,7 +870,7 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD if i < len(responses)-1 { s.applyDelayWithContext(ctx, methodCfg.StreamDelay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Canceled, "client cancelled during stream delay") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamServerStream, md, reqMap, collectedResponses, grpcErr) return grpcErr } @@ -874,13 +888,10 @@ func (s *Server) handleClientStreaming(stream grpc.ServerStream, method *MethodD startTime := time.Now() fullPath := fmt.Sprintf("/%s/%s", serviceName, methodName) - // Track active stream connection - if metrics.ActiveConnections != nil { - if vec, err := metrics.ActiveConnections.WithLabels("grpc"); err == nil { - vec.Inc() - defer vec.Dec() - } - } + // Register stream with tracker + streamID, ctx, cancel := s.streamTracker.Register(stream.Context(), fullPath, streamClientStream, peerAddr(stream.Context())) + defer cancel() + defer s.streamTracker.Unregister(streamID) inputDesc := method.GetInputDescriptor() if inputDesc == nil { @@ -893,6 +904,13 @@ func (s *Server) handleClientStreaming(stream grpc.ServerStream, method *MethodD var allRequests []interface{} var lastReqMap map[string]interface{} for { + // Check for tracker cancellation + if ctx.Err() != nil { + grpcErr := s.contextCancelError(streamID) + s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamClientStream, md, allRequests, nil, grpcErr) + return grpcErr + } + reqMsg := dynamicpb.NewMessage(inputDesc) if err := stream.RecvMsg(reqMsg); err != nil { if errors.Is(err, io.EOF) { @@ -902,6 +920,7 @@ func (s *Server) handleClientStreaming(stream grpc.ServerStream, method *MethodD s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamClientStream, md, allRequests, nil, grpcErr) return grpcErr } + s.streamTracker.RecordRecv(streamID) lastReqMap = dynamicMessageToMap(reqMsg) allRequests = append(allRequests, lastReqMap) } @@ -914,11 +933,10 @@ func (s *Server) handleClientStreaming(stream grpc.ServerStream, method *MethodD return err } - // Apply delay (context-aware for client cancellation) - ctx := stream.Context() + // Apply delay (context-aware for client/tracker cancellation) s.applyDelayWithContext(ctx, methodCfg.Delay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Canceled, "client cancelled during configured delay") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamClientStream, md, allRequests, nil, grpcErr) return grpcErr } @@ -949,6 +967,7 @@ func (s *Server) handleClientStreaming(stream grpc.ServerStream, method *MethodD return err } + s.streamTracker.RecordSent(streamID) respMap := dynamicMessageToMap(resp) // Log the successful call @@ -962,13 +981,10 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * startTime := time.Now() fullPath := fmt.Sprintf("/%s/%s", serviceName, methodName) - // Track active stream connection - if metrics.ActiveConnections != nil { - if vec, err := metrics.ActiveConnections.WithLabels("grpc"); err == nil { - vec.Inc() - defer vec.Dec() - } - } + // Register stream with tracker + streamID, ctx, cancel := s.streamTracker.Register(stream.Context(), fullPath, streamBidi, peerAddr(stream.Context())) + defer cancel() + defer s.streamTracker.Unregister(streamID) inputDesc := method.GetInputDescriptor() if inputDesc == nil { @@ -985,11 +1001,10 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * return err } - // Apply initial delay (context-aware for client cancellation) - ctx := stream.Context() + // Apply initial delay (context-aware for client/tracker cancellation) s.applyDelayWithContext(ctx, methodCfg.Delay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Canceled, "client cancelled during configured delay") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamBidi, md, nil, nil, grpcErr) return grpcErr } @@ -1018,6 +1033,13 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * // Echo pattern: for each received message, send a response for { + // Check for tracker cancellation + if ctx.Err() != nil { + grpcErr := s.contextCancelError(streamID) + s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamBidi, md, allRequests, allResponses, grpcErr) + return grpcErr + } + reqMsg := dynamicpb.NewMessage(inputDesc) if err := stream.RecvMsg(reqMsg); err != nil { if errors.Is(err, io.EOF) { @@ -1031,6 +1053,8 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * return grpcErr } + s.streamTracker.RecordRecv(streamID) + // Collect request for logging reqMap := dynamicMessageToMap(reqMsg) allRequests = append(allRequests, reqMap) @@ -1064,13 +1088,15 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * return grpcErr } + s.streamTracker.RecordSent(streamID) + // Collect response for logging allResponses = append(allResponses, dynamicMessageToMap(resp)) respIndex++ s.applyDelayWithContext(ctx, methodCfg.StreamDelay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Canceled, "client cancelled during stream delay") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamBidi, md, allRequests, allResponses, grpcErr) return grpcErr } @@ -1250,6 +1276,18 @@ func (s *Server) createTemplateContext(md metadata.MD, reqMap map[string]interfa return template.NewContextFromMap(reqMap, headers) } +// contextCancelError returns the appropriate gRPC status error for a +// cancelled context. If the stream was cancelled by admin action (mock update, +// explicit admin cancel), it returns codes.Unavailable so clients with retry +// policies reconnect. If cancelled by the client (e.g., timeout), it returns +// codes.Canceled. +func (s *Server) contextCancelError(streamID string) error { + if s.streamTracker.WasAdminCancelled(streamID) { + return status.Error(codes.Unavailable, "mock configuration updated") + } + return status.Error(codes.Canceled, "stream cancelled by client") +} + // applyDelay applies configured delay, respecting the context deadline. // Accepts Go duration strings ("100ms", "2s") or bare numbers treated as milliseconds ("100"). func (s *Server) applyDelay(delay string) { @@ -1652,6 +1690,11 @@ func (s *Server) Config() *GRPCConfig { return s.config } +// StreamTracker returns the server's stream tracker. +func (s *Server) StreamTracker() *StreamTracker { + return s.streamTracker +} + // Schema returns the proto schema. func (s *Server) Schema() *ProtoSchema { return s.schema @@ -1931,3 +1974,11 @@ func (s *Server) recordGRPCMetrics(fullPath string, grpcErr error, duration time } } } + +// peerAddr extracts the client address from a gRPC context. +func peerAddr(ctx context.Context) string { + if p, ok := grpcpeer.FromContext(ctx); ok && p.Addr != nil { + return p.Addr.String() + } + return "" +} diff --git a/pkg/grpc/stream_tracker.go b/pkg/grpc/stream_tracker.go new file mode 100644 index 00000000..96947a30 --- /dev/null +++ b/pkg/grpc/stream_tracker.go @@ -0,0 +1,265 @@ +package grpc + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/getmockd/mockd/pkg/metrics" +) + +// StreamInfo holds metadata about an active streaming RPC. +type StreamInfo struct { + ID string `json:"id"` + Method string `json:"method"` + StreamType streamType `json:"streamType"` + ClientAddr string `json:"clientAddr,omitempty"` + ConnectedAt time.Time `json:"connectedAt"` + MessagesSent int64 `json:"messagesSent"` + MessagesRecv int64 `json:"messagesRecv"` +} + +// StreamStats holds aggregate statistics for gRPC streams. +type StreamStats struct { + ActiveStreams int `json:"activeStreams"` + TotalStreams int64 `json:"totalStreams"` + TotalRPCs int64 `json:"totalRPCs"` + TotalMsgSent int64 `json:"totalMessagesSent"` + TotalMsgRecv int64 `json:"totalMessagesRecv"` + StreamsByMethod map[string]int `json:"streamsByMethod"` +} + +// StreamTracker tracks active gRPC streaming RPCs for a single Server. +type StreamTracker struct { + streams map[string]*trackedStream // ID -> tracked stream + adminCancelled map[string]bool // IDs cancelled by admin/mock-update + mu sync.RWMutex + + // Lifetime counters (include completed streams). + totalStreams atomic.Int64 + totalRPCs atomic.Int64 + totalMsgSent atomic.Int64 + totalMsgRecv atomic.Int64 +} + +// trackedStream extends StreamInfo with internal bookkeeping. +type trackedStream struct { + info StreamInfo + msgSent atomic.Int64 + msgRecv atomic.Int64 + cancel context.CancelFunc +} + +// NewStreamTracker creates a new StreamTracker. +func NewStreamTracker() *StreamTracker { + return &StreamTracker{ + streams: make(map[string]*trackedStream), + adminCancelled: make(map[string]bool), + } +} + +// nextStreamID generates a short unique stream ID. +var streamIDCounter atomic.Int64 + +func nextStreamID() string { + return fmt.Sprintf("grpc-stream-%d", streamIDCounter.Add(1)) +} + +// Register adds a new streaming RPC and returns its ID and a cancel-aware +// context. The caller should defer Unregister. +func (t *StreamTracker) Register(ctx context.Context, method string, st streamType, clientAddr string) (string, context.Context, context.CancelFunc) { + id := nextStreamID() + ctx, cancel := context.WithCancel(ctx) //nolint:gosec // cancel is stored in trackedStream.cancel + + ts := &trackedStream{ + info: StreamInfo{ + ID: id, + Method: method, + StreamType: st, + ClientAddr: clientAddr, + ConnectedAt: time.Now(), + }, + cancel: cancel, + } + + t.mu.Lock() + t.streams[id] = ts + t.mu.Unlock() + + t.totalStreams.Add(1) + + if metrics.ActiveConnections != nil { + if vec, err := metrics.ActiveConnections.WithLabels("grpc"); err == nil { + vec.Inc() + } + } + + return id, ctx, cancel +} + +// Unregister removes a stream and accumulates its message counters. +func (t *StreamTracker) Unregister(id string) { + t.mu.Lock() + ts, ok := t.streams[id] + if ok { + delete(t.streams, id) + } + delete(t.adminCancelled, id) + t.mu.Unlock() + + if !ok { + return + } + + t.totalMsgSent.Add(ts.msgSent.Load()) + t.totalMsgRecv.Add(ts.msgRecv.Load()) + + if metrics.ActiveConnections != nil { + if vec, err := metrics.ActiveConnections.WithLabels("grpc"); err == nil { + vec.Dec() + } + } +} + +// RecordSent increments the sent counter for a stream. +func (t *StreamTracker) RecordSent(id string) { + t.mu.RLock() + ts := t.streams[id] + t.mu.RUnlock() + if ts != nil { + ts.msgSent.Add(1) + } +} + +// RecordRecv increments the received counter for a stream. +func (t *StreamTracker) RecordRecv(id string) { + t.mu.RLock() + ts := t.streams[id] + t.mu.RUnlock() + if ts != nil { + ts.msgRecv.Add(1) + } +} + +// RecordUnaryRPC increments the total RPC counter for unary calls +// (which are not tracked as active streams). +func (t *StreamTracker) RecordUnaryRPC() { + t.totalRPCs.Add(1) +} + +// Get returns info about a specific stream, or nil. +func (t *StreamTracker) Get(id string) *StreamInfo { + t.mu.RLock() + ts := t.streams[id] + t.mu.RUnlock() + if ts == nil { + return nil + } + return t.toInfo(ts) +} + +// List returns info about all active streams. +func (t *StreamTracker) List() []*StreamInfo { + t.mu.RLock() + defer t.mu.RUnlock() + + result := make([]*StreamInfo, 0, len(t.streams)) + for _, ts := range t.streams { + result = append(result, t.toInfo(ts)) + } + return result +} + +// Count returns the number of active streams. +func (t *StreamTracker) Count() int { + t.mu.RLock() + defer t.mu.RUnlock() + return len(t.streams) +} + +// Cancel cancels a specific stream's context, causing the RPC handler to +// return codes.Unavailable to the client. +func (t *StreamTracker) Cancel(id string) error { + t.mu.Lock() + ts := t.streams[id] + if ts != nil { + t.adminCancelled[id] = true + } + t.mu.Unlock() + if ts == nil { + return fmt.Errorf("stream %s not found", id) + } + ts.cancel() + return nil +} + +// WasAdminCancelled returns true if the stream was cancelled by an admin +// action (CancelAll for mock update, or explicit Cancel via admin API) +// rather than by the client. +func (t *StreamTracker) WasAdminCancelled(id string) bool { + t.mu.RLock() + defer t.mu.RUnlock() + return t.adminCancelled[id] +} + +// CancelAll cancels all active streams with the given gRPC status. +// Returns the number of streams cancelled. +func (t *StreamTracker) CancelAll() int { + t.mu.RLock() + ids := make([]string, 0, len(t.streams)) + for id := range t.streams { + ids = append(ids, id) + } + t.mu.RUnlock() + + count := 0 + for _, id := range ids { + if t.Cancel(id) == nil { + count++ + } + } + return count +} + +// Stats returns aggregate statistics. +func (t *StreamTracker) Stats() *StreamStats { + t.mu.RLock() + defer t.mu.RUnlock() + + var liveSent, liveRecv int64 + byMethod := make(map[string]int) + for _, ts := range t.streams { + liveSent += ts.msgSent.Load() + liveRecv += ts.msgRecv.Load() + byMethod[ts.info.Method]++ + } + + return &StreamStats{ + ActiveStreams: len(t.streams), + TotalStreams: t.totalStreams.Load(), + TotalRPCs: t.totalRPCs.Load() + t.totalStreams.Load(), + TotalMsgSent: t.totalMsgSent.Load() + liveSent, + TotalMsgRecv: t.totalMsgRecv.Load() + liveRecv, + StreamsByMethod: byMethod, + } +} + +// Close cancels all streams (used during server shutdown). +func (t *StreamTracker) Close() { + t.CancelAll() +} + +func (t *StreamTracker) toInfo(ts *trackedStream) *StreamInfo { + return &StreamInfo{ + ID: ts.info.ID, + Method: ts.info.Method, + StreamType: ts.info.StreamType, + ClientAddr: ts.info.ClientAddr, + ConnectedAt: ts.info.ConnectedAt, + MessagesSent: ts.msgSent.Load(), + MessagesRecv: ts.msgRecv.Load(), + } +} + diff --git a/pkg/grpc/stream_tracker_test.go b/pkg/grpc/stream_tracker_test.go new file mode 100644 index 00000000..6228d42f --- /dev/null +++ b/pkg/grpc/stream_tracker_test.go @@ -0,0 +1,193 @@ +package grpc + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStreamTracker_RegisterUnregister(t *testing.T) { + tracker := NewStreamTracker() + + id, ctx, cancel := tracker.Register(context.Background(), "/pkg.Svc/Method", streamServerStream, "127.0.0.1:1234") + defer cancel() + + assert.NotEmpty(t, id) + assert.Equal(t, 1, tracker.Count()) + + info := tracker.Get(id) + require.NotNil(t, info) + assert.Equal(t, "/pkg.Svc/Method", info.Method) + assert.Equal(t, streamServerStream, info.StreamType) + assert.Equal(t, "127.0.0.1:1234", info.ClientAddr) + assert.False(t, info.ConnectedAt.IsZero()) + + // Context should not be cancelled yet + assert.NoError(t, ctx.Err()) + + tracker.Unregister(id) + assert.Equal(t, 0, tracker.Count()) + assert.Nil(t, tracker.Get(id)) +} + +func TestStreamTracker_MessageCounting(t *testing.T) { + tracker := NewStreamTracker() + + id, _, cancel := tracker.Register(context.Background(), "/pkg.Svc/Stream", streamBidi, "") + defer cancel() + + tracker.RecordSent(id) + tracker.RecordSent(id) + tracker.RecordRecv(id) + + info := tracker.Get(id) + require.NotNil(t, info) + assert.Equal(t, int64(2), info.MessagesSent) + assert.Equal(t, int64(1), info.MessagesRecv) + + // Verify stats include live counters + stats := tracker.Stats() + assert.Equal(t, int64(2), stats.TotalMsgSent) + assert.Equal(t, int64(1), stats.TotalMsgRecv) + assert.Equal(t, 1, stats.ActiveStreams) + + // After unregister, counters roll into lifetime totals + tracker.Unregister(id) + stats = tracker.Stats() + assert.Equal(t, 0, stats.ActiveStreams) + assert.Equal(t, int64(2), stats.TotalMsgSent) + assert.Equal(t, int64(1), stats.TotalMsgRecv) +} + +func TestStreamTracker_Cancel(t *testing.T) { + tracker := NewStreamTracker() + + id, ctx, cancel := tracker.Register(context.Background(), "/pkg.Svc/Method", streamServerStream, "") + defer cancel() + + // Cancel the stream + err := tracker.Cancel(id) + assert.NoError(t, err) + + // Context should be cancelled + assert.Error(t, ctx.Err()) + + // Cancel non-existent stream + err = tracker.Cancel("non-existent") + assert.Error(t, err) +} + +func TestStreamTracker_CancelAll(t *testing.T) { + tracker := NewStreamTracker() + + _, ctx1, cancel1 := tracker.Register(context.Background(), "/pkg.Svc/A", streamServerStream, "") + defer cancel1() + _, ctx2, cancel2 := tracker.Register(context.Background(), "/pkg.Svc/B", streamBidi, "") + defer cancel2() + + assert.Equal(t, 2, tracker.Count()) + + n := tracker.CancelAll() + assert.Equal(t, 2, n) + + assert.Error(t, ctx1.Err()) + assert.Error(t, ctx2.Err()) +} + +func TestStreamTracker_List(t *testing.T) { + tracker := NewStreamTracker() + + id1, _, cancel1 := tracker.Register(context.Background(), "/pkg.Svc/A", streamServerStream, "1.2.3.4:100") + defer cancel1() + id2, _, cancel2 := tracker.Register(context.Background(), "/pkg.Svc/B", streamClientStream, "5.6.7.8:200") + defer cancel2() + + list := tracker.List() + assert.Len(t, list, 2) + + ids := map[string]bool{} + for _, info := range list { + ids[info.ID] = true + } + assert.True(t, ids[id1]) + assert.True(t, ids[id2]) +} + +func TestStreamTracker_Stats(t *testing.T) { + tracker := NewStreamTracker() + + // Record some unary RPCs + tracker.RecordUnaryRPC() + tracker.RecordUnaryRPC() + + // Register a stream + id, _, cancel := tracker.Register(context.Background(), "/pkg.Svc/Stream", streamServerStream, "") + defer cancel() + tracker.RecordSent(id) + + stats := tracker.Stats() + assert.Equal(t, 1, stats.ActiveStreams) + assert.Equal(t, int64(1), stats.TotalStreams) + // TotalRPCs = unary (2) + streams (1) = 3 + assert.Equal(t, int64(3), stats.TotalRPCs) + assert.Equal(t, int64(1), stats.TotalMsgSent) + assert.Equal(t, map[string]int{"/pkg.Svc/Stream": 1}, stats.StreamsByMethod) +} + +func TestStreamTracker_RecordOnNonExistent(t *testing.T) { + tracker := NewStreamTracker() + + // Should not panic + tracker.RecordSent("non-existent") + tracker.RecordRecv("non-existent") +} + +func TestStreamTracker_UnregisterIdempotent(t *testing.T) { + tracker := NewStreamTracker() + + id, _, cancel := tracker.Register(context.Background(), "/pkg.Svc/Method", streamBidi, "") + defer cancel() + + tracker.Unregister(id) + tracker.Unregister(id) // Should not panic + assert.Equal(t, 0, tracker.Count()) +} + +func TestStreamTracker_Close(t *testing.T) { + tracker := NewStreamTracker() + + _, ctx, cancel := tracker.Register(context.Background(), "/pkg.Svc/Method", streamBidi, "") + defer cancel() + + tracker.Close() + assert.Error(t, ctx.Err()) +} + +func TestStreamTracker_ConcurrentAccess(t *testing.T) { + tracker := NewStreamTracker() + ctx := context.Background() + + done := make(chan struct{}) + go func() { + defer close(done) + for i := 0; i < 100; i++ { + id, _, cancel := tracker.Register(ctx, "/pkg.Svc/Method", streamBidi, "") + tracker.RecordSent(id) + tracker.RecordRecv(id) + time.Sleep(time.Microsecond) + tracker.Unregister(id) + cancel() + } + }() + + for i := 0; i < 100; i++ { + _ = tracker.List() + _ = tracker.Stats() + _ = tracker.Count() + } + + <-done +} diff --git a/pkg/mcp/tool_handlers_test.go b/pkg/mcp/tool_handlers_test.go index 2c5e4f19..5f468a4a 100644 --- a/pkg/mcp/tool_handlers_test.go +++ b/pkg/mcp/tool_handlers_test.go @@ -5,6 +5,7 @@ import ( "fmt" "testing" + apitypes "github.com/getmockd/mockd/pkg/api/types" "github.com/getmockd/mockd/pkg/cli" "github.com/getmockd/mockd/pkg/config" "github.com/getmockd/mockd/pkg/mock" @@ -429,6 +430,40 @@ func (m *mockAdminClient) BulkCreateMocks(_ []*mock.Mock, _ string) (*cli.BulkCr return nil, nil } +// --- Connection Management (stubs) --- + +func (m *mockAdminClient) ListWebSocketConnections() (*apitypes.WebSocketConnectionListResponse, error) { + return nil, nil +} +func (m *mockAdminClient) GetWebSocketConnection(_ string) (*apitypes.WebSocketConnection, error) { + return nil, nil +} +func (m *mockAdminClient) CloseWebSocketConnection(_ string) error { return nil } +func (m *mockAdminClient) SendWebSocketMessage(_ string, _ string, _ bool) error { return nil } +func (m *mockAdminClient) GetWebSocketStats() (*apitypes.WebSocketStats, error) { return nil, nil } +func (m *mockAdminClient) ListSSEConnections() (*apitypes.SSEConnectionListResponse, error) { + return nil, nil +} +func (m *mockAdminClient) GetSSEConnection(_ string) (*apitypes.SSEConnection, error) { + return nil, nil +} +func (m *mockAdminClient) CloseSSEConnection(_ string) error { return nil } +func (m *mockAdminClient) GetSSEStats() (*apitypes.SSEStats, error) { return nil, nil } +func (m *mockAdminClient) ListMQTTConnections() (*apitypes.MQTTConnectionListResponse, error) { + return nil, nil +} +func (m *mockAdminClient) GetMQTTConnection(_ string) (*apitypes.MQTTConnection, error) { + return nil, nil +} +func (m *mockAdminClient) CloseMQTTConnection(_ string) error { return nil } +func (m *mockAdminClient) GetMQTTStats() (*apitypes.MQTTStats, error) { return nil, nil } +func (m *mockAdminClient) ListGRPCStreams() (*apitypes.GRPCStreamListResponse, error) { + return nil, nil +} +func (m *mockAdminClient) GetGRPCStream(_ string) (*apitypes.GRPCStream, error) { return nil, nil } +func (m *mockAdminClient) CloseGRPCStream(_ string) error { return nil } +func (m *mockAdminClient) GetGRPCStats() (*apitypes.GRPCStats, error) { return nil, nil } + // ============================================================================= // Test Helpers // ============================================================================= diff --git a/pkg/mqtt/broker.go b/pkg/mqtt/broker.go index 17b14e31..f6f540d8 100644 --- a/pkg/mqtt/broker.go +++ b/pkg/mqtt/broker.go @@ -59,6 +59,7 @@ type Broker struct { log *slog.Logger internalSubscribers map[string][]SubscriptionHandler clientSubscriptions map[string][]string + clientConnectedAt map[string]time.Time simulator *Simulator recordingEnabled bool recordingStore MQTTRecordingStore @@ -144,6 +145,7 @@ func NewBroker(config *MQTTConfig) (*Broker, error) { log: logging.Nop(), internalSubscribers: make(map[string][]SubscriptionHandler), clientSubscriptions: make(map[string][]string), + clientConnectedAt: make(map[string]time.Time), sessionManager: NewSessionManager(), } broker.responseHandler = NewResponseHandler(broker) @@ -654,6 +656,156 @@ func (b *Broker) notifyTestPanelPublish(topic string, payload []byte, qos int, r b.sessionManager.NotifyMessage(b.config.ID, msg) } +// trackClientConnect records the connection time for a client. +// Called from the MessageHook's OnSessionEstablished callback. +func (b *Broker) trackClientConnect(clientID string) { + if b.stopping.Load() != 0 { + return + } + b.mu.Lock() + b.clientConnectedAt[clientID] = time.Now() + b.mu.Unlock() +} + +// trackClientDisconnect removes the connection time for a client. +// Called from the MessageHook's OnDisconnect callback. +func (b *Broker) trackClientDisconnect(clientID string) { + if b.stopping.Load() != 0 { + return + } + b.mu.Lock() + delete(b.clientConnectedAt, clientID) + b.mu.Unlock() +} + +// MQTTClientInfo represents information about a connected MQTT client. +type MQTTClientInfo struct { + ID string + BrokerID string + ConnectedAt time.Time + Subscriptions []string + ProtocolVersion byte + Username string + RemoteAddr string + Closed bool +} + +// ListClientInfos returns information about all connected MQTT clients. +func (b *Broker) ListClientInfos() []*MQTTClientInfo { + b.mu.RLock() + brokerID := "" + if b.config != nil { + brokerID = b.config.ID + } + subs := make(map[string][]string, len(b.clientSubscriptions)) + for k, v := range b.clientSubscriptions { + subs[k] = append([]string{}, v...) + } + connTimes := make(map[string]time.Time, len(b.clientConnectedAt)) + for k, v := range b.clientConnectedAt { + connTimes[k] = v + } + b.mu.RUnlock() + + if b.server == nil { + return nil + } + + clients := b.server.Clients.GetAll() + var result []*MQTTClientInfo + + for id, cl := range clients { + if cl.Net.Inline { + continue // skip the built-in inline client + } + info := &MQTTClientInfo{ + ID: id, + BrokerID: brokerID, + ProtocolVersion: cl.Properties.ProtocolVersion, + Username: string(cl.Properties.Username), + RemoteAddr: cl.Net.Remote, + Closed: cl.Closed(), + } + if ct, ok := connTimes[id]; ok { + info.ConnectedAt = ct + } + if subList, ok := subs[id]; ok { + info.Subscriptions = subList + } + result = append(result, info) + } + + return result +} + +// GetClientInfo returns information about a specific MQTT client. +func (b *Broker) GetClientInfo(clientID string) *MQTTClientInfo { + if b.server == nil { + return nil + } + + cl, ok := b.server.Clients.Get(clientID) + if !ok { + return nil + } + if cl.Net.Inline { + return nil + } + + b.mu.RLock() + brokerID := "" + if b.config != nil { + brokerID = b.config.ID + } + var subsCopy []string + if s, ok := b.clientSubscriptions[clientID]; ok { + subsCopy = append([]string{}, s...) + } + connectedAt := b.clientConnectedAt[clientID] + b.mu.RUnlock() + + return &MQTTClientInfo{ + ID: clientID, + BrokerID: brokerID, + ConnectedAt: connectedAt, + ProtocolVersion: cl.Properties.ProtocolVersion, + Username: string(cl.Properties.Username), + RemoteAddr: cl.Net.Remote, + Closed: cl.Closed(), + Subscriptions: subsCopy, + } +} + +// DisconnectClient disconnects a specific MQTT client by client ID. +// Returns an error if the client is not found. +func (b *Broker) DisconnectClient(clientID string) error { + if b.server == nil { + return errors.New("broker is not running") + } + + cl, ok := b.server.Clients.Get(clientID) + if !ok || cl.Net.Inline { + return fmt.Errorf("client %q not found", clientID) + } + + cl.Stop(errors.New("disconnected by admin API")) + return nil +} + +// GetConnectionStats returns aggregate connection statistics. +func (b *Broker) GetConnectionStats() (connectedClients int, totalSubscriptions int, subsByClient map[string]int) { + b.mu.RLock() + defer b.mu.RUnlock() + + subsByClient = make(map[string]int, len(b.clientSubscriptions)) + for clientID, subs := range b.clientSubscriptions { + subsByClient[clientID] = len(subs) + totalSubscriptions += len(subs) + } + connectedClients = len(b.getClientsLocked()) + return +} + // detectPayloadFormat detects the format of a payload func detectPayloadFormat(payload []byte) string { // Try to parse as JSON diff --git a/pkg/mqtt/connection_management_test.go b/pkg/mqtt/connection_management_test.go new file mode 100644 index 00000000..74a70dda --- /dev/null +++ b/pkg/mqtt/connection_management_test.go @@ -0,0 +1,104 @@ +package mqtt + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBroker_ListClientInfos_Empty(t *testing.T) { + t.Parallel() + broker, err := NewBroker(&MQTTConfig{Port: 0}) + require.NoError(t, err) + + require.NoError(t, broker.Start(context.Background())) + defer broker.Stop(context.Background(), 5*time.Second) + + time.Sleep(100 * time.Millisecond) + + infos := broker.ListClientInfos() + // The inline client may or may not appear, but we filter it out + for _, info := range infos { + assert.False(t, info.Closed, "active clients should not be closed") + } +} + +func TestBroker_GetClientInfo_NotFound(t *testing.T) { + t.Parallel() + broker, err := NewBroker(&MQTTConfig{Port: 0}) + require.NoError(t, err) + + require.NoError(t, broker.Start(context.Background())) + defer broker.Stop(context.Background(), 5*time.Second) + + time.Sleep(100 * time.Millisecond) + + info := broker.GetClientInfo("nonexistent-client") + assert.Nil(t, info) +} + +func TestBroker_DisconnectClient_NotFound(t *testing.T) { + t.Parallel() + broker, err := NewBroker(&MQTTConfig{Port: 0}) + require.NoError(t, err) + + require.NoError(t, broker.Start(context.Background())) + defer broker.Stop(context.Background(), 5*time.Second) + + time.Sleep(100 * time.Millisecond) + + err = broker.DisconnectClient("nonexistent-client") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestBroker_DisconnectClient_BrokerNotRunning(t *testing.T) { + t.Parallel() + broker, err := NewBroker(&MQTTConfig{Port: 0}) + require.NoError(t, err) + + // Don't start the broker + err = broker.DisconnectClient("some-client") + assert.Error(t, err) +} + +func TestBroker_GetConnectionStats_Empty(t *testing.T) { + t.Parallel() + broker, err := NewBroker(&MQTTConfig{Port: 0}) + require.NoError(t, err) + + require.NoError(t, broker.Start(context.Background())) + defer broker.Stop(context.Background(), 5*time.Second) + + time.Sleep(100 * time.Millisecond) + + connected, totalSubs, subsByClient := broker.GetConnectionStats() + assert.GreaterOrEqual(t, connected, 0) + assert.Equal(t, 0, totalSubs) + assert.NotNil(t, subsByClient) +} + +func TestBroker_ListClientInfos_NilServer(t *testing.T) { + t.Parallel() + broker := &Broker{ + config: &MQTTConfig{ID: "test"}, + clientSubscriptions: make(map[string][]string), + } + + infos := broker.ListClientInfos() + assert.Nil(t, infos) +} + +func TestBroker_GetClientInfo_NilServer(t *testing.T) { + t.Parallel() + broker := &Broker{ + config: &MQTTConfig{ID: "test"}, + clientSubscriptions: make(map[string][]string), + } + + info := broker.GetClientInfo("any-client") + assert.Nil(t, info) +} diff --git a/pkg/mqtt/hooks.go b/pkg/mqtt/hooks.go index e423a9ff..67346744 100644 --- a/pkg/mqtt/hooks.go +++ b/pkg/mqtt/hooks.go @@ -219,9 +219,29 @@ func (h *MessageHook) Provides(b byte) bool { mqtt.OnPublish, mqtt.OnSubscribed, mqtt.OnUnsubscribed, + mqtt.OnSessionEstablished, + mqtt.OnDisconnect, }, []byte{b}) } +// OnSessionEstablished is called after a client has connected and its session +// is fully set up. We use it to track the connection time. +func (h *MessageHook) OnSessionEstablished(cl *mqtt.Client, _ packets.Packet) { + if cl.Net.Inline { + return + } + h.broker.trackClientConnect(cl.ID) +} + +// OnDisconnect is called when a client disconnects. We use it to clean up the +// tracked connection time. +func (h *MessageHook) OnDisconnect(cl *mqtt.Client, _ error, _ bool) { + if cl.Net.Inline { + return + } + h.broker.trackClientDisconnect(cl.ID) +} + // OnPublish handles incoming publish messages func (h *MessageHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { // Defense-in-depth: enforce ACL on publish for non-inline clients. diff --git a/pkg/websocket/manager.go b/pkg/websocket/manager.go index 2dffe284..0a39ffc7 100644 --- a/pkg/websocket/manager.go +++ b/pkg/websocket/manager.go @@ -102,12 +102,15 @@ func (m *ConnectionManager) RegisterEndpoint(e *Endpoint) { } // UnregisterEndpoint removes an endpoint from the manager. +// Callers that want active clients to reconnect with new config should call +// DisconnectByEndpoint before this method, while byEndpoint still tracks +// the path. The engine's unregisterHandlerLocked does this automatically +// for every mock update, delete, and clear operation. func (m *ConnectionManager) UnregisterEndpoint(path string) { m.mu.Lock() defer m.mu.Unlock() delete(m.endpoints, path) - // Note: connections remain until they close } // GetEndpoint returns an endpoint by path. @@ -255,6 +258,38 @@ func (m *ConnectionManager) CountByEndpoint(path string) int { return 0 } +// DisconnectByEndpoint closes all active connections on an endpoint with the given code and reason. +// Returns the number of connections that were closed. +// +// Must be called before UnregisterEndpoint so that byEndpoint still contains the connections. +// Uses the read-lock/copy pattern (same as BroadcastToEndpoint) to avoid holding the lock +// while calling conn.Close, which acquires its own lock. +func (m *ConnectionManager) DisconnectByEndpoint(path string, code CloseCode, reason string) int { + m.mu.RLock() + var ids []string + if eps, ok := m.byEndpoint[path]; ok { + ids = make([]string, 0, len(eps)) + for id := range eps { + ids = append(ids, id) + } + } + m.mu.RUnlock() + + closed := 0 + for _, id := range ids { + m.mu.RLock() + conn := m.connections[id] + m.mu.RUnlock() + + if conn != nil && !conn.IsClosed() { + if err := conn.Close(code, reason); err == nil { + closed++ + } + } + } + return closed +} + // BroadcastToEndpoint sends a message to all connections on an endpoint. func (m *ConnectionManager) BroadcastToEndpoint(path string, msgType MessageType, data []byte) int { m.mu.RLock() diff --git a/pkg/websocket/manager_test.go b/pkg/websocket/manager_test.go index 3298942e..c592e793 100644 --- a/pkg/websocket/manager_test.go +++ b/pkg/websocket/manager_test.go @@ -230,6 +230,55 @@ func TestConnectionManager_ConcurrentJoinFromBothSides(t *testing.T) { } } +func TestConnectionManager_DisconnectByEndpoint_UnknownPath(t *testing.T) { + manager := NewConnectionManager() + + // Endpoint that was never registered — should return 0 without panicking + count := manager.DisconnectByEndpoint("/ws/unknown", CloseServiceRestart, "mock updated") + if count != 0 { + t.Errorf("expected 0 for unknown path, got %d", count) + } +} + +func TestConnectionManager_DisconnectByEndpoint_SkipsAlreadyClosed(t *testing.T) { + manager := NewConnectionManager() + + // Pre-closed connection — Close must not be called again (nil conn would panic). + conn := &Connection{ + id: "conn-already-closed", + endpointPath: "/ws/test", + groups: make(map[string]struct{}), + } + conn.closed.Store(true) + manager.Add(conn) + + // All connections already closed → count must be 0. + count := manager.DisconnectByEndpoint("/ws/test", CloseServiceRestart, "mock updated") + if count != 0 { + t.Errorf("expected 0 (connection pre-closed), got %d", count) + } +} + +func TestConnectionManager_DisconnectByEndpoint_IsolatesEndpoint(t *testing.T) { + manager := NewConnectionManager() + + // Two endpoints: only /ws/target should be affected. + target := &Connection{id: "t1", endpointPath: "/ws/target", groups: make(map[string]struct{})} + target.closed.Store(true) // pre-close to avoid nil conn panic + other := &Connection{id: "o1", endpointPath: "/ws/other", groups: make(map[string]struct{})} + other.closed.Store(true) + + manager.Add(target) + manager.Add(other) + + manager.DisconnectByEndpoint("/ws/target", CloseServiceRestart, "mock updated") + + // The connection on /ws/other must still be tracked by the manager. + if manager.Get("o1") == nil { + t.Error("connection on /ws/other should not be removed by DisconnectByEndpoint on /ws/target") + } +} + func TestConnectionManager_BroadcastToGroup(t *testing.T) { manager := NewConnectionManager() diff --git a/tests/integration/sse_test.go b/tests/integration/sse_test.go new file mode 100644 index 00000000..8ab9cd53 --- /dev/null +++ b/tests/integration/sse_test.go @@ -0,0 +1,302 @@ +package integration + +import ( + "bufio" + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/getmockd/mockd/pkg/config" + "github.com/getmockd/mockd/pkg/engine" + "github.com/getmockd/mockd/pkg/mock" + "github.com/getmockd/mockd/pkg/sse" +) + +// ============================================================================ +// SSE Test Helpers +// ============================================================================ + +// setupSSEMockServer creates a test server and adapter for SSE mock management. +func setupSSEMockServer(t *testing.T) (*httptest.Server, *engine.ControlAPIAdapter) { + t.Helper() + cfg := config.DefaultServerConfiguration() + srv := engine.NewServer(cfg) + adapter := engine.NewControlAPIAdapter(srv) + + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(func() { ts.Close() }) + return ts, adapter +} + +// getSSEManager extracts the SSE ConnectionManager from the test server. +func getSSEManager(t *testing.T, ts *httptest.Server) *sse.SSEConnectionManager { + t.Helper() + handler := ts.Config.Handler + require.NotNil(t, handler, "handler is nil") + + engineHandler, ok := handler.(*engine.Handler) + require.True(t, ok, "handler is not *engine.Handler") + + sseHandler := engineHandler.SSEHandler() + require.NotNil(t, sseHandler, "SSE handler is nil") + + mgr := sseHandler.GetManager() + require.NotNil(t, mgr, "SSE connection manager is nil") + return mgr +} + +// newSSEMock creates a minimal SSE mock configuration that streams events +// with a keepalive to hold the connection open. +func newSSEMock(path string) *mock.Mock { + return &mock.Mock{ + Type: mock.TypeHTTP, + HTTP: &mock.HTTPSpec{ + Matcher: &mock.HTTPMatcher{ + Method: "GET", + Path: path, + }, + SSE: &mock.SSEConfig{ + Generator: &mock.SSEEventGenerator{ + Type: "sequence", + Count: 0, // infinite + Sequence: &mock.SSESequenceGenerator{ + Start: 1, + Increment: 1, + }, + }, + Lifecycle: mock.SSELifecycleConfig{ + KeepaliveInterval: 300, // 5min keepalive keeps connection alive + }, + }, + }, + } +} + +// connectSSE opens an SSE connection and waits until the first event is received, +// confirming the server has registered the connection. Returns a cancel func +// that closes the connection. +func connectSSE(t *testing.T, url string) context.CancelFunc { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + req.Header.Set("Accept", "text/event-stream") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + t.Cleanup(func() { resp.Body.Close() }) + + // Read until we get at least one SSE event (data: line), confirming connection is tracked. + ready := make(chan struct{}) + go func() { + scanner := bufio.NewScanner(resp.Body) + signalled := false + for scanner.Scan() { + line := scanner.Text() + if !signalled && len(line) > 0 { + close(ready) + signalled = true + } + } + }() + + select { + case <-ready: + // Connection is active and streaming + case <-time.After(5 * time.Second): + cancel() + t.Fatal("timeout waiting for SSE event") + } + + return cancel +} + +// waitForSSECount polls until the SSE manager reports the expected connection count +// or the timeout expires. +func waitForSSECount(t *testing.T, mgr *sse.SSEConnectionManager, expected int, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if mgr.Count() == expected { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("expected %d SSE connections, got %d", expected, mgr.Count()) +} + +// ============================================================================ +// SSE Auto-Disconnect Tests +// ============================================================================ + +// TestSSE_MockUpdate_ClosesActiveConnections verifies that updating an SSE mock +// disconnects all active SSE clients so they reconnect with the new configuration. +func TestSSE_MockUpdate_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupSSEMockServer(t) + mgr := getSSEManager(t, ts) + + // Create an SSE mock. + mockCfg := newSSEMock("/sse/update-test") + require.NoError(t, adapter.AddMock(mockCfg)) + + // Connect an SSE client. + cancelSSE := connectSSE(t, ts.URL+"/sse/update-test") + defer cancelSSE() + + // Wait for the connection to be registered. + waitForSSECount(t, mgr, 1, 2*time.Second) + + // Update the mock — this should disconnect all active SSE clients. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + updated := *mocks[0] + updated.HTTP.SSE.Generator.Sequence.Start = 100 + require.NoError(t, adapter.UpdateMock(updated.ID, &updated)) + + // Connections should be closed. + waitForSSECount(t, mgr, 0, 2*time.Second) +} + +// TestSSE_MockDelete_ClosesActiveConnections verifies that deleting an SSE mock +// disconnects all active SSE clients. +func TestSSE_MockDelete_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupSSEMockServer(t) + mgr := getSSEManager(t, ts) + + mockCfg := newSSEMock("/sse/delete-test") + require.NoError(t, adapter.AddMock(mockCfg)) + + cancelSSE := connectSSE(t, ts.URL+"/sse/delete-test") + defer cancelSSE() + + waitForSSECount(t, mgr, 1, 2*time.Second) + + // Delete the mock — SSE clients must be disconnected. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + require.NoError(t, adapter.DeleteMock(mocks[0].ID)) + + waitForSSECount(t, mgr, 0, 2*time.Second) +} + +// TestSSE_MockToggleDisable_ClosesActiveConnections verifies that disabling an SSE +// mock (toggle to enabled=false) disconnects all active SSE clients. +func TestSSE_MockToggleDisable_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupSSEMockServer(t) + mgr := getSSEManager(t, ts) + + mockCfg := newSSEMock("/sse/toggle-test") + require.NoError(t, adapter.AddMock(mockCfg)) + + cancelSSE := connectSSE(t, ts.URL+"/sse/toggle-test") + defer cancelSSE() + + waitForSSECount(t, mgr, 1, 2*time.Second) + + // Disable the mock — SSE clients must be disconnected. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + disabled := *mocks[0] + enabled := false + disabled.Enabled = &enabled + require.NoError(t, adapter.UpdateMock(disabled.ID, &disabled)) + + waitForSSECount(t, mgr, 0, 2*time.Second) +} + +// TestSSE_ClearMocks_ClosesActiveConnections verifies that clearing all mocks +// disconnects all active SSE clients. +func TestSSE_ClearMocks_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupSSEMockServer(t) + mgr := getSSEManager(t, ts) + + mockCfg := newSSEMock("/sse/clear-test") + require.NoError(t, adapter.AddMock(mockCfg)) + + cancelSSE := connectSSE(t, ts.URL+"/sse/clear-test") + defer cancelSSE() + + waitForSSECount(t, mgr, 1, 2*time.Second) + + // Clear all mocks — SSE clients must be disconnected. + adapter.ClearMocks() + + waitForSSECount(t, mgr, 0, 2*time.Second) +} + +// TestSSE_MultipleConnections_AllDisconnected verifies that multiple SSE clients +// connected to the same mock are all disconnected when the mock is deleted. +func TestSSE_MultipleConnections_AllDisconnected(t *testing.T) { + ts, adapter := setupSSEMockServer(t) + mgr := getSSEManager(t, ts) + + mockCfg := newSSEMock("/sse/multi-test") + require.NoError(t, adapter.AddMock(mockCfg)) + + // Connect 3 SSE clients. + for i := 0; i < 3; i++ { + cancel := connectSSE(t, ts.URL+"/sse/multi-test") + defer cancel() + } + + waitForSSECount(t, mgr, 3, 2*time.Second) + + // Delete the mock — all 3 clients must be disconnected. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + require.NoError(t, adapter.DeleteMock(mocks[0].ID)) + + waitForSSECount(t, mgr, 0, 2*time.Second) +} + +// TestSSE_UpdateDoesNotAffectOtherMocks verifies that updating one SSE mock +// only disconnects clients on that mock, not clients on other SSE mocks. +func TestSSE_UpdateDoesNotAffectOtherMocks(t *testing.T) { + ts, adapter := setupSSEMockServer(t) + mgr := getSSEManager(t, ts) + + // Create two SSE mocks. + mock1 := newSSEMock("/sse/mock1") + require.NoError(t, adapter.AddMock(mock1)) + mock2 := newSSEMock("/sse/mock2") + require.NoError(t, adapter.AddMock(mock2)) + + // Connect a client to each. + cancel1 := connectSSE(t, ts.URL+"/sse/mock1") + defer cancel1() + cancel2 := connectSSE(t, ts.URL+"/sse/mock2") + defer cancel2() + + waitForSSECount(t, mgr, 2, 2*time.Second) + + // Update mock1 — only mock1 connections should drop. + mocks := adapter.ListMocks() + require.Len(t, mocks, 2) + + var target *mock.Mock + for _, m := range mocks { + if m.HTTP != nil && m.HTTP.Matcher != nil && m.HTTP.Matcher.Path == "/sse/mock1" { + target = m + break + } + } + require.NotNil(t, target, "could not find /sse/mock1 mock") + + updated := *target + updated.HTTP.SSE.Generator.Sequence.Start = 100 + require.NoError(t, adapter.UpdateMock(updated.ID, &updated)) + + // mock1 connections should be closed, mock2 should remain. + // Wait for mock1 connections to drop. + time.Sleep(100 * time.Millisecond) + assert.Equal(t, 1, mgr.Count(), "expected 1 remaining SSE connection (mock2)") +} diff --git a/tests/integration/websocket_test.go b/tests/integration/websocket_test.go index 0e6c910b..92c6a079 100644 --- a/tests/integration/websocket_test.go +++ b/tests/integration/websocket_test.go @@ -2,6 +2,7 @@ package integration import ( "context" + "errors" "net/http" "net/http/httptest" "strings" @@ -9,12 +10,13 @@ import ( "time" ws "github.com/coder/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/getmockd/mockd/pkg/config" "github.com/getmockd/mockd/pkg/engine" "github.com/getmockd/mockd/pkg/mock" "github.com/getmockd/mockd/pkg/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // ============================================================================ @@ -799,6 +801,217 @@ func TestWS_FullServer_WithMiddleware(t *testing.T) { assert.Equal(t, testMsg, string(data)) } +// ============================================================================ +// User Story: Mock Update / Delete Closes Active Connections +// ============================================================================ + +// assertCloseCode1012 asserts that readErr is a WebSocket close error with code 1012 +// (Service Restart). Expects readErr to be non-nil (connection already closed by server). +func assertCloseCode1012(t *testing.T, readErr error) { + t.Helper() + require.Error(t, readErr, "expected connection to be closed by server") + var closeErr ws.CloseError + if errors.As(readErr, &closeErr) { + assert.Equal(t, ws.StatusCode(websocket.CloseServiceRestart), closeErr.Code, + "expected close code 1012 Service Restart") + } +} + +// setupWSMockServer creates a test server that has a WebSocket mock registered +// via the engine's control adapter (not the legacy WebSocketEndpoints collection +// field). Returns the httptest.Server and a ControlAPIAdapter for mock management. +func setupWSMockServer(t *testing.T) (*httptest.Server, *engine.ControlAPIAdapter) { + t.Helper() + cfg := config.DefaultServerConfiguration() + srv := engine.NewServer(cfg) + adapter := engine.NewControlAPIAdapter(srv) + + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(func() { ts.Close() }) + return ts, adapter +} + +// TestWS_MockUpdate_ClosesActiveConnections verifies that updating a WebSocket mock +// sends close code 1012 (Service Restart) to all connected clients so they reconnect +// and receive the new configuration. +func TestWS_MockUpdate_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupWSMockServer(t) + + // Create initial WebSocket mock. + initialResponse := &mock.WSMessageResponse{Type: "text", Value: "v1"} + mockCfg := &mock.Mock{ + Type: mock.TypeWebSocket, + WebSocket: &mock.WebSocketSpec{ + Path: "/ws/update-test", + DefaultResponse: initialResponse, + }, + } + require.NoError(t, adapter.AddMock(mockCfg)) + + // Connect a WebSocket client. + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/update-test" + dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer dialCancel() + + conn, resp, err := ws.Dial(dialCtx, wsURL, nil) + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + require.NoError(t, err, "WebSocket dial failed") + + // Wait for connection to be tracked. + time.Sleep(50 * time.Millisecond) + + wsm := getWSManager(t, ts) + assert.Equal(t, 1, wsm.CountByEndpoint("/ws/update-test"), "expected 1 active connection before update") + + // Update the mock — this should trigger disconnect of all active clients. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + updated := *mocks[0] + updated.WebSocket = &mock.WebSocketSpec{ + Path: "/ws/update-test", + DefaultResponse: &mock.WSMessageResponse{Type: "text", Value: "v2"}, + } + require.NoError(t, adapter.UpdateMock(updated.ID, &updated)) + + // The client must receive a close frame with code 1012 Service Restart. + readCtx, readCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer readCancel() + + _, _, readErr := conn.Read(readCtx) + require.Error(t, readErr, "expected connection to be closed by server") + + var closeErr ws.CloseError + if errors.As(readErr, &closeErr) { + assert.Equal(t, ws.StatusCode(websocket.CloseServiceRestart), closeErr.Code, + "expected close code 1012 Service Restart") + } +} + +// TestWS_MockDelete_ClosesActiveConnections verifies that deleting a WebSocket mock +// sends close code 1012 (Service Restart) to all connected clients. +func TestWS_MockDelete_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupWSMockServer(t) + + mockCfg := &mock.Mock{ + Type: mock.TypeWebSocket, + WebSocket: &mock.WebSocketSpec{ + Path: "/ws/delete-test", + }, + } + require.NoError(t, adapter.AddMock(mockCfg)) + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/delete-test" + dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer dialCancel() + + conn, resp, err := ws.Dial(dialCtx, wsURL, nil) + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + require.NoError(t, err, "WebSocket dial failed") + + // Wait for connection to be tracked. + time.Sleep(50 * time.Millisecond) + + // Delete the mock — active clients should receive close code 1012. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + require.NoError(t, adapter.DeleteMock(mocks[0].ID)) + + readCtx, readCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer readCancel() + + _, _, readErr := conn.Read(readCtx) + require.Error(t, readErr, "expected connection to be closed by server") + + var closeErr ws.CloseError + if errors.As(readErr, &closeErr) { + assert.Equal(t, ws.StatusCode(websocket.CloseServiceRestart), closeErr.Code, + "expected close code 1012 Service Restart") + } +} + +// TestWS_MockToggleDisable_ClosesActiveConnections verifies that disabling a WebSocket +// mock (toggle to enabled=false) sends close code 1012 to all connected clients. +// This covers the path: POST /mocks/{id}/toggle → UpdateMock → unregisterHandlerLocked. +func TestWS_MockToggleDisable_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupWSMockServer(t) + + mockCfg := &mock.Mock{ + Type: mock.TypeWebSocket, + WebSocket: &mock.WebSocketSpec{ + Path: "/ws/toggle-test", + }, + } + require.NoError(t, adapter.AddMock(mockCfg)) + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/toggle-test" + dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer dialCancel() + + conn, resp, err := ws.Dial(dialCtx, wsURL, nil) + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + require.NoError(t, err, "WebSocket dial failed") + + // Wait for connection to be tracked. + time.Sleep(50 * time.Millisecond) + + // Disable the mock — active clients should receive close code 1012. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + disabled := *mocks[0] + enabled := false + disabled.Enabled = &enabled + require.NoError(t, adapter.UpdateMock(disabled.ID, &disabled)) + + readCtx, readCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer readCancel() + + _, _, readErr := conn.Read(readCtx) + assertCloseCode1012(t, readErr) +} + +// TestWS_ClearMocks_ClosesActiveConnections verifies that clearing all mocks +// sends close code 1012 to all connected WebSocket clients. +// This covers the path used by bulk DELETE /mocks and POST /config?replace=true. +func TestWS_ClearMocks_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupWSMockServer(t) + + mockCfg := &mock.Mock{ + Type: mock.TypeWebSocket, + WebSocket: &mock.WebSocketSpec{ + Path: "/ws/clear-test", + }, + } + require.NoError(t, adapter.AddMock(mockCfg)) + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/clear-test" + dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer dialCancel() + + conn, resp, err := ws.Dial(dialCtx, wsURL, nil) + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + require.NoError(t, err, "WebSocket dial failed") + + // Wait for connection to be tracked. + time.Sleep(50 * time.Millisecond) + + // Clear all mocks — the WebSocket client must receive close code 1012. + adapter.ClearMocks() + + readCtx, readCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer readCancel() + + _, _, readErr := conn.Read(readCtx) + assertCloseCode1012(t, readErr) +} + // itoa converts int to string without importing strconv. func itoa(i int) string { if i == 0 {