From ca247e6b2deb1a781ff0aed4f509c649bf5a362a Mon Sep 17 00:00:00 2001 From: Vadim Sabirov Date: Fri, 3 Apr 2026 20:03:34 +0300 Subject: [PATCH 01/18] feat: add API endpoint to send messages to WebSocket connections --- docs/src/content/docs/reference/admin-api.md | 32 ++- pkg/admin/engineclient/client.go | 103 ++++++++ pkg/admin/engineclient/client_test.go | 139 +++++++++++ pkg/admin/routes.go | 7 + pkg/admin/types.go | 8 + pkg/admin/websocket_handlers.go | 209 ++++++++++++++++ pkg/admin/websocket_handlers_test.go | 239 +++++++++++++++++++ pkg/engine/api/handlers.go | 37 +++ pkg/engine/api/handlers_test.go | 9 + pkg/engine/api/server.go | 2 + pkg/engine/control_api.go | 24 +- 11 files changed, 804 insertions(+), 5 deletions(-) create mode 100644 pkg/admin/websocket_handlers.go create mode 100644 pkg/admin/websocket_handlers_test.go diff --git a/docs/src/content/docs/reference/admin-api.md b/docs/src/content/docs/reference/admin-api.md index 784eddde..ea947484 100644 --- a/docs/src/content/docs/reference/admin-api.md +++ b/docs/src/content/docs/reference/admin-api.md @@ -972,13 +972,37 @@ Send a message to a specific connection. #### POST /admin/ws/broadcast -Broadcast message to all connections. +#### POST /websocket/connections/{id}/send -#### GET /admin/ws/endpoints +Send a text or binary message to a specific active WebSocket connection. -List configured WebSocket endpoints. +**Request:** + +```json +{ + "type": "text", + "data": "Hello from server" +} +``` + +| Field | Type | Description | +|-------|------|-------------| +| `type` | string | Message type: `"text"` (default) or `"binary"` | +| `data` | string | Message payload | + +**Response:** + +```json +{ + "message": "Message sent", + "connection": "ws-abc123", + "type": "text" +} +``` + +Returns `404` if the connection is not found. -#### GET /admin/ws/stats +#### GET /websocket/stats Get WebSocket statistics. diff --git a/pkg/admin/engineclient/client.go b/pkg/admin/engineclient/client.go index 7d3815ea..dc64b8b0 100644 --- a/pkg/admin/engineclient/client.go +++ b/pkg/admin/engineclient/client.go @@ -1077,6 +1077,109 @@ func (c *Client) GetSSEStats(ctx context.Context) (*SSEStats, error) { return &stats, nil } +// ListWebSocketConnections returns all active WebSocket connections. +func (c *Client) ListWebSocketConnections(ctx context.Context) ([]*WebSocketConnection, error) { + resp, err := c.get(ctx, "/websocket/connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var result struct { + Connections []*WebSocketConnection `json:"connections"` + Count int `json:"count"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode WebSocket connections: %w", err) + } + return result.Connections, nil +} + +// GetWebSocketConnection returns a specific WebSocket connection by ID. +func (c *Client) GetWebSocketConnection(ctx context.Context, id string) (*WebSocketConnection, error) { + resp, err := c.get(ctx, "/websocket/connections/"+url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return nil, ErrNotFound + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var conn WebSocketConnection + if err := json.NewDecoder(resp.Body).Decode(&conn); err != nil { + return nil, fmt.Errorf("failed to decode WebSocket connection: %w", err) + } + return &conn, nil +} + +// CloseWebSocketConnection closes a WebSocket connection by ID. +func (c *Client) CloseWebSocketConnection(ctx context.Context, id string) error { + resp, err := c.delete(ctx, "/websocket/connections/"+url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return ErrNotFound + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// SendToWebSocketConnection sends a text or binary message to a specific connection. +// msgType must be "text" (default) or "binary". +// For binary messages, data should contain the raw bytes as a string. +func (c *Client) SendToWebSocketConnection(ctx context.Context, id string, msgType string, data string) error { + body := map[string]string{ + "type": msgType, + "data": data, + } + resp, err := c.post(ctx, "/websocket/connections/"+url.PathEscape(id)+"/send", body) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return ErrNotFound + } + if resp.StatusCode != http.StatusOK { + return c.parseError(resp) + } + return nil +} + +// GetWebSocketStats returns WebSocket statistics. +func (c *Client) GetWebSocketStats(ctx context.Context) (*WebSocketStats, error) { + resp, err := c.get(ctx, "/websocket/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var stats WebSocketStats + if err := json.NewDecoder(resp.Body).Decode(&stats); err != nil { + return nil, fmt.Errorf("failed to decode WebSocket stats: %w", err) + } + return &stats, nil +} + // HTTP helpers func (c *Client) get(ctx context.Context, path string) (*http.Response, error) { diff --git a/pkg/admin/engineclient/client_test.go b/pkg/admin/engineclient/client_test.go index d01491fc..dbcefe1c 100644 --- a/pkg/admin/engineclient/client_test.go +++ b/pkg/admin/engineclient/client_test.go @@ -750,6 +750,145 @@ func TestHTTPMethods(t *testing.T) { } } +// --- WebSocket Tests --- + +func TestListWebSocketConnections_Success(t *testing.T) { + resp := struct { + Connections []*WebSocketConnection `json:"connections"` + Count int `json:"count"` + }{ + Connections: []*WebSocketConnection{{ID: "ws-1"}, {ID: "ws-2"}}, + Count: 2, + } + _, c := mockServer(t, jsonHandler(t, 200, resp)) + + conns, err := c.ListWebSocketConnections(context.Background()) + if err != nil { + t.Fatalf("ListWebSocketConnections() error = %v", err) + } + if len(conns) != 2 { + t.Errorf("ListWebSocketConnections() = %d, want 2", len(conns)) + } + if conns[0].ID != "ws-1" { + t.Errorf("ListWebSocketConnections()[0].ID = %q, want %q", conns[0].ID, "ws-1") + } +} + +func TestListWebSocketConnections_Error(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 503, nil)) + _, err := c.ListWebSocketConnections(context.Background()) + if err == nil { + t.Error("ListWebSocketConnections() error = nil, want error for 503") + } +} + +func TestGetWebSocketConnection_Success(t *testing.T) { + conn := WebSocketConnection{ID: "ws-1", Status: "connected"} + _, c := mockServer(t, jsonHandler(t, 200, conn)) + + result, err := c.GetWebSocketConnection(context.Background(), "ws-1") + if err != nil { + t.Fatalf("GetWebSocketConnection() error = %v", err) + } + if result.ID != "ws-1" { + t.Errorf("GetWebSocketConnection().ID = %q, want %q", result.ID, "ws-1") + } +} + +func TestGetWebSocketConnection_NotFound(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 404, nil)) + _, err := c.GetWebSocketConnection(context.Background(), "missing") + if !errors.Is(err, ErrNotFound) { + t.Errorf("GetWebSocketConnection() error = %v, want ErrNotFound", err) + } +} + +func TestCloseWebSocketConnection_Success(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 200, map[string]string{"message": "closed"})) + err := c.CloseWebSocketConnection(context.Background(), "ws-1") + if err != nil { + t.Errorf("CloseWebSocketConnection() error = %v, want nil", err) + } +} + +func TestCloseWebSocketConnection_NoContent(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 204, nil)) + err := c.CloseWebSocketConnection(context.Background(), "ws-1") + if err != nil { + t.Errorf("CloseWebSocketConnection() 204 error = %v, want nil", err) + } +} + +func TestCloseWebSocketConnection_NotFound(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 404, nil)) + err := c.CloseWebSocketConnection(context.Background(), "missing") + if !errors.Is(err, ErrNotFound) { + t.Errorf("CloseWebSocketConnection() error = %v, want ErrNotFound", err) + } +} + +func TestSendToWebSocketConnection_Success(t *testing.T) { + var capturedBody map[string]string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewDecoder(r.Body).Decode(&capturedBody) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"message": "sent"}) + })) + defer ts.Close() + c := New(ts.URL) + + err := c.SendToWebSocketConnection(context.Background(), "ws-1", "text", "hello") + if err != nil { + t.Fatalf("SendToWebSocketConnection() error = %v", err) + } + if capturedBody["type"] != "text" { + t.Errorf("request body type = %q, want %q", capturedBody["type"], "text") + } + if capturedBody["data"] != "hello" { + t.Errorf("request body data = %q, want %q", capturedBody["data"], "hello") + } +} + +func TestSendToWebSocketConnection_NotFound(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 404, nil)) + err := c.SendToWebSocketConnection(context.Background(), "missing", "text", "hello") + if !errors.Is(err, ErrNotFound) { + t.Errorf("SendToWebSocketConnection() error = %v, want ErrNotFound", err) + } +} + +func TestSendToWebSocketConnection_Error(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 500, ErrorResponse{Error: "engine_error", Message: "failed"})) + err := c.SendToWebSocketConnection(context.Background(), "ws-1", "text", "hello") + if err == nil { + t.Error("SendToWebSocketConnection() error = nil, want error for 500") + } +} + +func TestGetWebSocketStats_Success(t *testing.T) { + stats := WebSocketStats{ActiveConnections: 3, TotalConnections: 10} + _, c := mockServer(t, jsonHandler(t, 200, stats)) + + result, err := c.GetWebSocketStats(context.Background()) + if err != nil { + t.Fatalf("GetWebSocketStats() error = %v", err) + } + if result.ActiveConnections != 3 { + t.Errorf("GetWebSocketStats().ActiveConnections = %d, want 3", result.ActiveConnections) + } + if result.TotalConnections != 10 { + t.Errorf("GetWebSocketStats().TotalConnections = %d, want 10", result.TotalConnections) + } +} + +func TestGetWebSocketStats_Error(t *testing.T) { + _, c := mockServer(t, jsonHandler(t, 503, nil)) + _, err := c.GetWebSocketStats(context.Background()) + if err == nil { + t.Error("GetWebSocketStats() error = nil, want error for 503") + } +} + // containsStr checks if s contains substr. func containsStr(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr)) diff --git a/pkg/admin/routes.go b/pkg/admin/routes.go index 0cff1521..579e3e3e 100644 --- a/pkg/admin/routes.go +++ b/pkg/admin/routes.go @@ -123,6 +123,13 @@ func (a *API) registerRoutes(mux *http.ServeMux) { mux.HandleFunc("DELETE /sse/connections/{id}", a.handleCloseSSEConnection) mux.HandleFunc("GET /sse/stats", a.handleGetSSEStats) + // WebSocket connection management + mux.HandleFunc("GET /websocket/connections", a.handleListWebSocketConnections) + mux.HandleFunc("GET /websocket/connections/{id}", a.handleGetWebSocketConnection) + mux.HandleFunc("DELETE /websocket/connections/{id}", a.handleCloseWebSocketConnection) + mux.HandleFunc("POST /websocket/connections/{id}/send", a.handleSendToWebSocketConnection) + mux.HandleFunc("GET /websocket/stats", a.handleGetWebSocketStats) + // Mock-specific SSE endpoints mux.HandleFunc("GET /mocks/{id}/sse/connections", a.requireEngine(a.handleListMockSSEConnections)) mux.HandleFunc("DELETE /mocks/{id}/sse/connections", a.requireEngine(a.handleCloseMockSSEConnections)) diff --git a/pkg/admin/types.go b/pkg/admin/types.go index e0e192f6..7f1ada16 100644 --- a/pkg/admin/types.go +++ b/pkg/admin/types.go @@ -42,3 +42,11 @@ type MockInvocationListResponse struct { Count int `json:"count"` Total int `json:"total"` } + +// WebSocketSendRequest represents a request to send a message to a WebSocket connection. +type WebSocketSendRequest struct { + // Type is the message type: "text" (default) or "binary". + Type string `json:"type"` + // Data is the message payload. + Data string `json:"data"` +} diff --git a/pkg/admin/websocket_handlers.go b/pkg/admin/websocket_handlers.go new file mode 100644 index 00000000..4ac59c15 --- /dev/null +++ b/pkg/admin/websocket_handlers.go @@ -0,0 +1,209 @@ +package admin + +import ( + "errors" + "log/slog" + "net/http" + + "github.com/getmockd/mockd/pkg/admin/engineclient" +) + +// WebSocketConnectionListResponse represents a list of WebSocket connections with stats. +type WebSocketConnectionListResponse struct { + Connections []*engineclient.WebSocketConnection `json:"connections"` + Stats engineclient.WebSocketStats `json:"stats"` +} + +// handleListWebSocketConnections handles GET /websocket/connections. +func (a *API) handleListWebSocketConnections(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, WebSocketConnectionListResponse{ + Connections: []*engineclient.WebSocketConnection{}, + Stats: engineclient.WebSocketStats{ConnectionsByMock: make(map[string]int)}, + }) + return + } + + stats, err := engine.GetWebSocketStats(ctx) + if err != nil { + a.logger().Error("failed to get WebSocket stats", "error", err) + status, code, msg := mapWebSocketEngineError(err, a.logger(), "get WebSocket stats") + writeError(w, status, code, msg) + return + } + + connections, err := engine.ListWebSocketConnections(ctx) + if err != nil { + a.logger().Error("failed to list WebSocket connections", "error", err) + status, code, msg := mapWebSocketEngineError(err, a.logger(), "list WebSocket connections") + writeError(w, status, code, msg) + return + } + + if connections == nil { + connections = []*engineclient.WebSocketConnection{} + } + + connsByMock := stats.ConnectionsByMock + if connsByMock == nil { + connsByMock = make(map[string]int) + } + + writeJSON(w, http.StatusOK, WebSocketConnectionListResponse{ + Connections: connections, + Stats: engineclient.WebSocketStats{ + TotalConnections: stats.TotalConnections, + ActiveConnections: stats.ActiveConnections, + TotalMessagesSent: stats.TotalMessagesSent, + TotalMessagesRecv: stats.TotalMessagesRecv, + ConnectionsByMock: connsByMock, + }, + }) +} + +// handleGetWebSocketConnection handles GET /websocket/connections/{id}. +func (a *API) handleGetWebSocketConnection(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Connection ID is required") + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + + conn, err := engine.GetWebSocketConnection(ctx, id) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + a.logger().Error("failed to get WebSocket connection", "error", err, "connectionID", id) + status, code, msg := mapWebSocketEngineError(err, a.logger(), "get WebSocket connection") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, conn) +} + +// handleCloseWebSocketConnection handles DELETE /websocket/connections/{id}. +func (a *API) handleCloseWebSocketConnection(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Connection ID is required") + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + + err := engine.CloseWebSocketConnection(ctx, id) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + a.logger().Error("failed to close WebSocket connection", "error", err, "connectionID", id) + status, code, msg := mapWebSocketEngineError(err, a.logger(), "close WebSocket connection") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Connection closed", + "connection": id, + }) +} + +// handleGetWebSocketStats handles GET /websocket/stats. +func (a *API) handleGetWebSocketStats(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, engineclient.WebSocketStats{ + ConnectionsByMock: make(map[string]int), + }) + return + } + + stats, err := engine.GetWebSocketStats(ctx) + if err != nil { + a.logger().Error("failed to get WebSocket stats", "error", err) + status, code, msg := mapWebSocketEngineError(err, a.logger(), "get WebSocket stats") + writeError(w, status, code, msg) + return + } + + connsByMock := stats.ConnectionsByMock + if connsByMock == nil { + connsByMock = make(map[string]int) + } + writeJSON(w, http.StatusOK, engineclient.WebSocketStats{ + TotalConnections: stats.TotalConnections, + ActiveConnections: stats.ActiveConnections, + TotalMessagesSent: stats.TotalMessagesSent, + TotalMessagesRecv: stats.TotalMessagesRecv, + ConnectionsByMock: connsByMock, + }) +} + +func mapWebSocketEngineError(err error, log *slog.Logger, operation string) (int, string, string) { + return http.StatusServiceUnavailable, "engine_error", sanitizeEngineError(err, log, operation) +} + +// handleSendToWebSocketConnection handles POST /websocket/connections/{id}/send. +// It forwards a text or binary message to a specific active connection through the engine. +func (a *API) handleSendToWebSocketConnection(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Connection ID is required") + return + } + + var req WebSocketSendRequest + if err := decodeOptionalJSONBody(r, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_body", "Invalid JSON in request body") + return + } + if req.Type == "" { + req.Type = "text" + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + + err := engine.SendToWebSocketConnection(ctx, id, req.Type, req.Data) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + a.logger().Error("failed to send to WebSocket connection", "error", err, "connectionID", id) + status, code, msg := mapWebSocketEngineError(err, a.logger(), "send to WebSocket connection") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Message sent", + "connection": id, + "type": req.Type, + }) +} diff --git a/pkg/admin/websocket_handlers_test.go b/pkg/admin/websocket_handlers_test.go new file mode 100644 index 00000000..5af8b298 --- /dev/null +++ b/pkg/admin/websocket_handlers_test.go @@ -0,0 +1,239 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/getmockd/mockd/pkg/admin/engineclient" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ============================================================================ +// handleListWebSocketConnections +// ============================================================================ + +func TestHandleListWebSocketConnections_NoEngine_ReturnsEmptyList(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections", nil) + rec := httptest.NewRecorder() + + api.handleListWebSocketConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp WebSocketConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Empty(t, resp.Connections) + assert.NotNil(t, resp.Stats.ConnectionsByMock) +} + +func TestHandleListWebSocketConnections_EngineUnavailable_Returns503(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections", nil) + rec := httptest.NewRecorder() + + api.handleListWebSocketConnections(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +// ============================================================================ +// handleGetWebSocketConnection +// ============================================================================ + +func TestHandleGetWebSocketConnection_MissingID_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections/", nil) + // PathValue "id" intentionally not set → empty string + rec := httptest.NewRecorder() + + api.handleGetWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleGetWebSocketConnection_NoEngine_Returns404(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections/conn-1", nil) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleGetWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleGetWebSocketConnection_EngineUnavailable_Returns503(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections/conn-1", nil) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleGetWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +// ============================================================================ +// handleCloseWebSocketConnection +// ============================================================================ + +func TestHandleCloseWebSocketConnection_MissingID_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/websocket/connections/", nil) + // PathValue "id" intentionally not set → empty string + rec := httptest.NewRecorder() + + api.handleCloseWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleCloseWebSocketConnection_NoEngine_Returns404(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/websocket/connections/conn-1", nil) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleCloseWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleCloseWebSocketConnection_EngineUnavailable_Returns503(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/websocket/connections/conn-1", nil) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleCloseWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +// ============================================================================ +// handleGetWebSocketStats +// ============================================================================ + +func TestHandleGetWebSocketStats_NoEngine_ReturnsEmptyStats(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/stats", nil) + rec := httptest.NewRecorder() + + api.handleGetWebSocketStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats engineclient.WebSocketStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.NotNil(t, stats.ConnectionsByMock) +} + +func TestHandleGetWebSocketStats_EngineUnavailable_Returns503(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/stats", nil) + rec := httptest.NewRecorder() + + api.handleGetWebSocketStats(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +// ============================================================================ +// handleSendToWebSocketConnection +// ============================================================================ + +func TestHandleSendToWebSocketConnection_MissingID_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections//send", strings.NewReader(`{"type":"text","data":"hello"}`)) + // PathValue "id" intentionally not set → empty string + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleSendToWebSocketConnection_InvalidJSON_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/conn-1/send", strings.NewReader(`{invalid`)) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleSendToWebSocketConnection_NoEngine_Returns404(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/conn-1/send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleSendToWebSocketConnection_EngineUnavailable_Returns503(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/conn-1/send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +func TestHandleSendToWebSocketConnection_EmptyBody_DefaultsToText(t *testing.T) { + // No engine — expects 404, but verifies empty body doesn't fail decode + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/conn-1/send", strings.NewReader("")) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + // No engine → 404, not a 400 parse error + assert.Equal(t, http.StatusNotFound, rec.Code) +} diff --git a/pkg/engine/api/handlers.go b/pkg/engine/api/handlers.go index 6b63c320..9bd9c4f8 100644 --- a/pkg/engine/api/handlers.go +++ b/pkg/engine/api/handlers.go @@ -816,6 +816,43 @@ func (s *Server) handleGetWebSocketStats(w http.ResponseWriter, r *http.Request) writeJSON(w, http.StatusOK, stats) } +// handleSendToWebSocketConnection handles POST /websocket/connections/{id}/send. +// It sends a text or binary message to a specific active WebSocket connection. +func (s *Server) handleSendToWebSocketConnection(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Connection ID is required") + return + } + + limitedBody(w, r) + + var req struct { + // Type is "text" (default) or "binary" (data must be plain bytes). + Type string `json:"type"` + // Data is the message payload. For binary messages, pass raw bytes as a string. + Data string `json:"data"` + } + if err := decodeJSONBody(r, &req, false); err != nil { + writeDecodeError(w, err) + return + } + if req.Type == "" { + req.Type = "text" + } + + if err := s.engine.SendToWebSocketConnection(id, req.Type, []byte(req.Data)); err != nil { + writeError(w, http.StatusNotFound, "not_found", "WebSocket connection not found") + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Message sent", + "connection": id, + "type": req.Type, + }) +} + // Config handlers func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/engine/api/handlers_test.go b/pkg/engine/api/handlers_test.go index 2f3957fc..36f6cac7 100644 --- a/pkg/engine/api/handlers_test.go +++ b/pkg/engine/api/handlers_test.go @@ -358,6 +358,15 @@ func (m *mockEngine) GetWebSocketStats() *WebSocketStats { return m.wsStats } +func (m *mockEngine) SendToWebSocketConnection(id string, msgType string, data []byte) error { + for _, c := range m.wsConnections { + if c.ID == id { + return nil + } + } + return errors.New("connection not found") +} + func (m *mockEngine) GetConfig() *ConfigResponse { return m.configResp } diff --git a/pkg/engine/api/server.go b/pkg/engine/api/server.go index 760d492b..0650dab4 100644 --- a/pkg/engine/api/server.go +++ b/pkg/engine/api/server.go @@ -89,6 +89,7 @@ type EngineController interface { ListWebSocketConnections() []*WebSocketConnection GetWebSocketConnection(id string) *WebSocketConnection CloseWebSocketConnection(id string) error + SendToWebSocketConnection(id string, msgType string, data []byte) error GetWebSocketStats() *WebSocketStats // Config @@ -224,6 +225,7 @@ func (s *Server) registerRoutes(mux *http.ServeMux) { mux.HandleFunc("GET /websocket/connections/{id}", s.handleGetWebSocketConnection) mux.HandleFunc("DELETE /websocket/connections/{id}", s.handleCloseWebSocketConnection) mux.HandleFunc("GET /websocket/stats", s.handleGetWebSocketStats) + mux.HandleFunc("POST /websocket/connections/{id}/send", s.handleSendToWebSocketConnection) // Config mux.HandleFunc("GET /config", s.handleGetConfig) diff --git a/pkg/engine/control_api.go b/pkg/engine/control_api.go index dda67441..b85b2851 100644 --- a/pkg/engine/control_api.go +++ b/pkg/engine/control_api.go @@ -5,13 +5,14 @@ import ( "errors" "fmt" - types "github.com/getmockd/mockd/pkg/api/types" + "github.com/getmockd/mockd/pkg/api/types" "github.com/getmockd/mockd/pkg/chaos" "github.com/getmockd/mockd/pkg/config" "github.com/getmockd/mockd/pkg/engine/api" "github.com/getmockd/mockd/pkg/protocol" "github.com/getmockd/mockd/pkg/requestlog" "github.com/getmockd/mockd/pkg/stateful" + "github.com/getmockd/mockd/pkg/websocket" ) // Errors returned by the control API adapter. @@ -737,6 +738,27 @@ func (a *ControlAPIAdapter) CloseWebSocketConnection(id string) error { return wsManager.CloseConnection(id, "closed by API") } +// SendToWebSocketConnection implements api.EngineController. +// msgType must be "text" or "binary"; data is the raw message payload. +func (a *ControlAPIAdapter) SendToWebSocketConnection(id string, msgType string, data []byte) error { + handler := a.server.Handler() + if handler == nil { + return ErrWebSocketHandlerNotInitialized + } + + wsManager := handler.WebSocketManager() + if wsManager == nil { + return ErrWebSocketHandlerNotInitialized + } + + mt := websocket.MessageText + if msgType == "binary" { + mt = websocket.MessageBinary + } + + return wsManager.SendToConnection(id, mt, data) +} + // GetWebSocketStats implements api.EngineController. func (a *ControlAPIAdapter) GetWebSocketStats() *api.WebSocketStats { handler := a.server.Handler() From fff6da17e9724534da4d2ba648200c7ccb15fc2b Mon Sep 17 00:00:00 2001 From: Vadim Sabirov Date: Sat, 4 Apr 2026 06:52:02 +0300 Subject: [PATCH 02/18] feat: refactor WebSocket and SSE stats handlers to reduce code duplication --- pkg/admin/sse_handlers.go | 30 +------- pkg/admin/stat_helper.go | 128 ++++++++++++++++++++++++++++++++ pkg/admin/websocket_handlers.go | 30 +------- 3 files changed, 132 insertions(+), 56 deletions(-) create mode 100644 pkg/admin/stat_helper.go diff --git a/pkg/admin/sse_handlers.go b/pkg/admin/sse_handlers.go index 0e741a51..090e9ab5 100644 --- a/pkg/admin/sse_handlers.go +++ b/pkg/admin/sse_handlers.go @@ -149,35 +149,9 @@ func (a *API) handleCloseSSEConnection(w http.ResponseWriter, r *http.Request) { // handleGetSSEStats handles GET /sse/stats. func (a *API) handleGetSSEStats(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - engine := a.localEngine.Load() - if engine == nil { - writeJSON(w, http.StatusOK, sse.ConnectionStats{ - ConnectionsByMock: make(map[string]int), - }) - return - } - - stats, err := engine.GetSSEStats(ctx) - if err != nil { - a.logger().Error("failed to get SSE stats", "error", err) - status, code, msg := mapSSEEngineError(err, a.logger(), "get SSE stats") - writeError(w, status, code, msg) - return - } - - connsByMock := stats.ConnectionsByMock - if connsByMock == nil { - connsByMock = make(map[string]int) - } - writeJSON(w, http.StatusOK, sse.ConnectionStats{ - ActiveConnections: stats.ActiveConnections, - TotalConnections: stats.TotalConnections, - TotalEventsSent: stats.TotalEventsSent, - TotalBytesSent: stats.TotalBytesSent, - ConnectionsByMock: connsByMock, - }) + provider := newSSEStatsProvider(engine) + a.handleGetStats(w, r, provider) } // handleListMockSSEConnections handles GET /mocks/{id}/sse/connections. diff --git a/pkg/admin/stat_helper.go b/pkg/admin/stat_helper.go new file mode 100644 index 00000000..342e8046 --- /dev/null +++ b/pkg/admin/stat_helper.go @@ -0,0 +1,128 @@ +package admin + +import ( + "context" + "log/slog" + "net/http" + + "github.com/getmockd/mockd/pkg/admin/engineclient" + "github.com/getmockd/mockd/pkg/sse" +) + +// statsProvider defines the interface for retrieving and formatting statistics. +type statsProvider interface { + // GetStats retrieves statistics from the engine. + GetStats(ctx context.Context) (interface{}, error) + // GetEmptyStats returns empty statistics for when engine is unavailable. + GetEmptyStats() interface{} + // MapError converts an error to HTTP status, code, and message. + MapError(err error, log *slog.Logger, operation string) (int, string, string) +} + +// handleGetStats is a generic handler for retrieving statistics. +// It eliminates code duplication between SSE and WebSocket stats handlers. +func (a *API) handleGetStats(w http.ResponseWriter, r *http.Request, provider statsProvider) { + ctx := r.Context() + + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, provider.GetEmptyStats()) + return + } + + stats, err := provider.GetStats(ctx) + if err != nil { + a.logger().Error("failed to get stats", "error", err) + status, code, msg := provider.MapError(err, a.logger(), "get stats") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, stats) +} + +// sseStatsProvider implements statsProvider for SSE statistics. +type sseStatsProvider struct { + engine *engineclient.Client +} + +// newSSEStatsProvider creates a new SSE statistics provider. +func newSSEStatsProvider(engine *engineclient.Client) *sseStatsProvider { + return &sseStatsProvider{engine: engine} +} + +// GetStats retrieves SSE statistics from the engine. +func (p *sseStatsProvider) GetStats(ctx context.Context) (interface{}, error) { + stats, err := p.engine.GetSSEStats(ctx) + if err != nil { + return nil, err + } + + connsByMock := stats.ConnectionsByMock + if connsByMock == nil { + connsByMock = make(map[string]int) + } + + return sse.ConnectionStats{ + ActiveConnections: stats.ActiveConnections, + TotalConnections: stats.TotalConnections, + TotalEventsSent: stats.TotalEventsSent, + TotalBytesSent: stats.TotalBytesSent, + ConnectionsByMock: connsByMock, + }, nil +} + +// GetEmptyStats returns empty SSE statistics. +func (p *sseStatsProvider) GetEmptyStats() interface{} { + return sse.ConnectionStats{ + ConnectionsByMock: make(map[string]int), + } +} + +// MapError converts an SSE engine error to HTTP status, code, and message. +func (p *sseStatsProvider) MapError(err error, log *slog.Logger, operation string) (int, string, string) { + return mapSSEEngineError(err, log, operation) +} + +// wsStatsProvider implements statsProvider for WebSocket statistics. +type wsStatsProvider struct { + engine *engineclient.Client +} + +// newWSStatsProvider creates a new WebSocket statistics provider. +func newWSStatsProvider(engine *engineclient.Client) *wsStatsProvider { + return &wsStatsProvider{engine: engine} +} + +// GetStats retrieves WebSocket statistics from the engine. +func (p *wsStatsProvider) GetStats(ctx context.Context) (interface{}, error) { + stats, err := p.engine.GetWebSocketStats(ctx) + if err != nil { + return nil, err + } + + connsByMock := stats.ConnectionsByMock + if connsByMock == nil { + connsByMock = make(map[string]int) + } + + return engineclient.WebSocketStats{ + TotalConnections: stats.TotalConnections, + ActiveConnections: stats.ActiveConnections, + TotalMessagesSent: stats.TotalMessagesSent, + TotalMessagesRecv: stats.TotalMessagesRecv, + ConnectionsByMock: connsByMock, + }, nil +} + +// GetEmptyStats returns empty WebSocket statistics. +func (p *wsStatsProvider) GetEmptyStats() interface{} { + return engineclient.WebSocketStats{ + ConnectionsByMock: make(map[string]int), + } +} + +// MapError converts a WebSocket engine error to HTTP status, code, and message. +func (p *wsStatsProvider) MapError(err error, log *slog.Logger, operation string) (int, string, string) { + return mapWebSocketEngineError(err, log, operation) +} diff --git a/pkg/admin/websocket_handlers.go b/pkg/admin/websocket_handlers.go index 4ac59c15..82faf69b 100644 --- a/pkg/admin/websocket_handlers.go +++ b/pkg/admin/websocket_handlers.go @@ -129,35 +129,9 @@ func (a *API) handleCloseWebSocketConnection(w http.ResponseWriter, r *http.Requ // handleGetWebSocketStats handles GET /websocket/stats. func (a *API) handleGetWebSocketStats(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - engine := a.localEngine.Load() - if engine == nil { - writeJSON(w, http.StatusOK, engineclient.WebSocketStats{ - ConnectionsByMock: make(map[string]int), - }) - return - } - - stats, err := engine.GetWebSocketStats(ctx) - if err != nil { - a.logger().Error("failed to get WebSocket stats", "error", err) - status, code, msg := mapWebSocketEngineError(err, a.logger(), "get WebSocket stats") - writeError(w, status, code, msg) - return - } - - connsByMock := stats.ConnectionsByMock - if connsByMock == nil { - connsByMock = make(map[string]int) - } - writeJSON(w, http.StatusOK, engineclient.WebSocketStats{ - TotalConnections: stats.TotalConnections, - ActiveConnections: stats.ActiveConnections, - TotalMessagesSent: stats.TotalMessagesSent, - TotalMessagesRecv: stats.TotalMessagesRecv, - ConnectionsByMock: connsByMock, - }) + provider := newWSStatsProvider(engine) + a.handleGetStats(w, r, provider) } func mapWebSocketEngineError(err error, log *slog.Logger, operation string) (int, string, string) { From bcb906c685da82c2f9259ec1cb28d722868ca1c5 Mon Sep 17 00:00:00 2001 From: Vadim Sabirov Date: Sat, 4 Apr 2026 07:27:15 +0300 Subject: [PATCH 03/18] feat: enhance WebSocket connection management with server-initiated messaging and base64 encoding for binary data --- CHANGELOG.md | 1 + pkg/admin/engineclient/client.go | 3 ++- pkg/admin/types.go | 3 +++ pkg/admin/websocket_handlers.go | 4 ++++ pkg/admin/websocket_handlers_test.go | 36 +++++++++++++++++++++++++++- pkg/engine/api/handlers.go | 20 +++++++++++++--- 6 files changed, 62 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ef9fa26..3ee264c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **WebSocket connection management API** — `GET /websocket/connections`, `GET /websocket/connections/{id}`, `DELETE /websocket/connections/{id}`, `POST /websocket/connections/{id}/send`, `GET /websocket/stats` added to the Admin API for real-time visibility, control, and server-initiated messaging of active WebSocket connections - **Workspace-scoped stateful resources** — stateful resources, custom operations, and request logs are now isolated per workspace - **`--workspace` persistent CLI flag** — scope any CLI command to a specific workspace without switching context - **`?workspaceId=` API parameter** — all admin API endpoints now accept workspace filtering diff --git a/pkg/admin/engineclient/client.go b/pkg/admin/engineclient/client.go index dc64b8b0..0fd448ac 100644 --- a/pkg/admin/engineclient/client.go +++ b/pkg/admin/engineclient/client.go @@ -1140,7 +1140,8 @@ func (c *Client) CloseWebSocketConnection(ctx context.Context, id string) error // SendToWebSocketConnection sends a text or binary message to a specific connection. // msgType must be "text" (default) or "binary". -// For binary messages, data should contain the raw bytes as a string. +// For binary messages, data must be a base64-encoded string; the engine decodes +// it before sending the raw bytes over the WebSocket connection. func (c *Client) SendToWebSocketConnection(ctx context.Context, id string, msgType string, data string) error { body := map[string]string{ "type": msgType, diff --git a/pkg/admin/types.go b/pkg/admin/types.go index 7f1ada16..e0877671 100644 --- a/pkg/admin/types.go +++ b/pkg/admin/types.go @@ -48,5 +48,8 @@ type WebSocketSendRequest struct { // Type is the message type: "text" (default) or "binary". Type string `json:"type"` // Data is the message payload. + // For type "text", Data is a plain UTF-8 string. + // For type "binary", Data must be a base64-encoded string; the server decodes it + // before sending the raw bytes over the WebSocket connection. Data string `json:"data"` } diff --git a/pkg/admin/websocket_handlers.go b/pkg/admin/websocket_handlers.go index 82faf69b..a778bd12 100644 --- a/pkg/admin/websocket_handlers.go +++ b/pkg/admin/websocket_handlers.go @@ -35,6 +35,10 @@ func (a *API) handleListWebSocketConnections(w http.ResponseWriter, r *http.Requ return } + // Stats and connections are fetched in two separate calls; they are not + // atomically consistent. Connections may change between the two requests, + // so the counts in stats may not exactly match the length of the returned + // connection list. This is intentional — correctness is not required here. connections, err := engine.ListWebSocketConnections(ctx) if err != nil { a.logger().Error("failed to list WebSocket connections", "error", err) diff --git a/pkg/admin/websocket_handlers_test.go b/pkg/admin/websocket_handlers_test.go index 5af8b298..51089c3d 100644 --- a/pkg/admin/websocket_handlers_test.go +++ b/pkg/admin/websocket_handlers_test.go @@ -7,9 +7,10 @@ import ( "strings" "testing" - "github.com/getmockd/mockd/pkg/admin/engineclient" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/getmockd/mockd/pkg/admin/engineclient" ) // ============================================================================ @@ -237,3 +238,36 @@ func TestHandleSendToWebSocketConnection_EmptyBody_DefaultsToText(t *testing.T) // No engine → 404, not a 400 parse error assert.Equal(t, http.StatusNotFound, rec.Code) } + +func TestHandleSendToWebSocketConnection_Success_Returns200(t *testing.T) { + // Spin up a minimal mock engine that accepts the send call and returns 200. + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/send") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"message":"Message sent","connection":"conn-1","type":"text"}`)) + return + } + http.NotFound(w, r) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + body := strings.NewReader(`{"type":"text","data":"hello"}`) + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/conn-1/send", body) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Equal(t, "Message sent", resp["message"]) + assert.Equal(t, "conn-1", resp["connection"]) + assert.Equal(t, "text", resp["type"]) +} diff --git a/pkg/engine/api/handlers.go b/pkg/engine/api/handlers.go index 9bd9c4f8..ba16d190 100644 --- a/pkg/engine/api/handlers.go +++ b/pkg/engine/api/handlers.go @@ -1,6 +1,7 @@ package api import ( + "encoding/base64" "encoding/json" "errors" "fmt" @@ -828,9 +829,10 @@ func (s *Server) handleSendToWebSocketConnection(w http.ResponseWriter, r *http. limitedBody(w, r) var req struct { - // Type is "text" (default) or "binary" (data must be plain bytes). + // Type is "text" (default) or "binary". + // For binary messages, Data must be base64-encoded; it is decoded here + // before the raw bytes are sent over the WebSocket connection. Type string `json:"type"` - // Data is the message payload. For binary messages, pass raw bytes as a string. Data string `json:"data"` } if err := decodeJSONBody(r, &req, false); err != nil { @@ -841,7 +843,19 @@ func (s *Server) handleSendToWebSocketConnection(w http.ResponseWriter, r *http. req.Type = "text" } - if err := s.engine.SendToWebSocketConnection(id, req.Type, []byte(req.Data)); err != nil { + var payload []byte + if req.Type == "binary" { + var err error + payload, err = base64.StdEncoding.DecodeString(req.Data) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_base64", "For binary messages, data must be a base64-encoded string") + return + } + } else { + payload = []byte(req.Data) + } + + if err := s.engine.SendToWebSocketConnection(id, req.Type, payload); err != nil { writeError(w, http.StatusNotFound, "not_found", "WebSocket connection not found") return } From 9b8a6bef43298c7b93f40966394d973f14d59e2b Mon Sep 17 00:00:00 2001 From: Vadim Sabirov Date: Sat, 4 Apr 2026 07:55:55 +0300 Subject: [PATCH 04/18] feat: improve error handling for WebSocket message sending and enhance API documentation --- .github/workflows/docker.yaml | 6 +- docs/src/content/docs/reference/admin-api.md | 2 +- pkg/admin/engineclient/types.go | 2 + pkg/admin/sse_handlers.go | 7 +- pkg/admin/stat_helper.go | 9 +- pkg/admin/websocket_handlers.go | 7 +- pkg/engine/api/handlers.go | 7 +- pkg/engine/api/handlers_test.go | 347 ++++++++++++++++++- 8 files changed, 372 insertions(+), 15 deletions(-) diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 77a6db4e..83472f74 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -1,10 +1,13 @@ name: Docker on: + pull_request: + branches: [main] push: branches: [main] tags: - "v*" + workflow_dispatch: permissions: contents: read @@ -29,6 +32,7 @@ jobs: uses: docker/setup-buildx-action@v3 - name: Login to GitHub Container Registry + if: github.event_name != 'pull_request' uses: docker/login-action@v3 with: registry: ${{ env.REGISTRY }} @@ -63,7 +67,7 @@ jobs: with: context: . platforms: linux/amd64,linux/arm64 - push: true + push: ${{ github.event_name != 'pull_request' }} tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} build-args: | diff --git a/docs/src/content/docs/reference/admin-api.md b/docs/src/content/docs/reference/admin-api.md index ea947484..994dc36a 100644 --- a/docs/src/content/docs/reference/admin-api.md +++ b/docs/src/content/docs/reference/admin-api.md @@ -988,7 +988,7 @@ Send a text or binary message to a specific active WebSocket connection. | Field | Type | Description | |-------|------|-------------| | `type` | string | Message type: `"text"` (default) or `"binary"` | -| `data` | string | Message payload | +| `data` | string | Message payload. For `"text"`, a plain UTF-8 string. For `"binary"`, a **base64-encoded** string — the server decodes it before writing raw bytes to the WebSocket. | **Response:** diff --git a/pkg/admin/engineclient/types.go b/pkg/admin/engineclient/types.go index 19367836..be7698ee 100644 --- a/pkg/admin/engineclient/types.go +++ b/pkg/admin/engineclient/types.go @@ -47,6 +47,8 @@ type ( ProtocolHandler = types.ProtocolHandler SSEConnection = types.SSEConnection SSEStats = types.SSEStats + WebSocketConnection = types.WebSocketConnection + WebSocketStats = types.WebSocketStats CustomOperationInfo = types.CustomOperationInfo CustomOperationDetail = types.CustomOperationDetail CustomOperationStep = types.CustomOperationStep diff --git a/pkg/admin/sse_handlers.go b/pkg/admin/sse_handlers.go index 090e9ab5..609f38c3 100644 --- a/pkg/admin/sse_handlers.go +++ b/pkg/admin/sse_handlers.go @@ -150,8 +150,11 @@ func (a *API) handleCloseSSEConnection(w http.ResponseWriter, r *http.Request) { // handleGetSSEStats handles GET /sse/stats. func (a *API) handleGetSSEStats(w http.ResponseWriter, r *http.Request) { engine := a.localEngine.Load() - provider := newSSEStatsProvider(engine) - a.handleGetStats(w, r, provider) + if engine == nil { + writeJSON(w, http.StatusOK, sse.ConnectionStats{ConnectionsByMock: make(map[string]int)}) + return + } + a.handleGetStats(w, r, newSSEStatsProvider(engine)) } // handleListMockSSEConnections handles GET /mocks/{id}/sse/connections. diff --git a/pkg/admin/stat_helper.go b/pkg/admin/stat_helper.go index 342e8046..17ddd6c0 100644 --- a/pkg/admin/stat_helper.go +++ b/pkg/admin/stat_helper.go @@ -20,16 +20,11 @@ type statsProvider interface { } // handleGetStats is a generic handler for retrieving statistics. -// It eliminates code duplication between SSE and WebSocket stats handlers. +// The caller must guard against a nil engine and handle the empty-stats +// response before constructing the provider and invoking this function. func (a *API) handleGetStats(w http.ResponseWriter, r *http.Request, provider statsProvider) { ctx := r.Context() - engine := a.localEngine.Load() - if engine == nil { - writeJSON(w, http.StatusOK, provider.GetEmptyStats()) - return - } - stats, err := provider.GetStats(ctx) if err != nil { a.logger().Error("failed to get stats", "error", err) diff --git a/pkg/admin/websocket_handlers.go b/pkg/admin/websocket_handlers.go index a778bd12..3d44eb2e 100644 --- a/pkg/admin/websocket_handlers.go +++ b/pkg/admin/websocket_handlers.go @@ -134,8 +134,11 @@ func (a *API) handleCloseWebSocketConnection(w http.ResponseWriter, r *http.Requ // handleGetWebSocketStats handles GET /websocket/stats. func (a *API) handleGetWebSocketStats(w http.ResponseWriter, r *http.Request) { engine := a.localEngine.Load() - provider := newWSStatsProvider(engine) - a.handleGetStats(w, r, provider) + if engine == nil { + writeJSON(w, http.StatusOK, engineclient.WebSocketStats{ConnectionsByMock: make(map[string]int)}) + return + } + a.handleGetStats(w, r, newWSStatsProvider(engine)) } func mapWebSocketEngineError(err error, log *slog.Logger, operation string) (int, string, string) { diff --git a/pkg/engine/api/handlers.go b/pkg/engine/api/handlers.go index ba16d190..f220d54e 100644 --- a/pkg/engine/api/handlers.go +++ b/pkg/engine/api/handlers.go @@ -15,6 +15,7 @@ import ( "github.com/getmockd/mockd/pkg/httputil" "github.com/getmockd/mockd/pkg/requestlog" "github.com/getmockd/mockd/pkg/stateful" + "github.com/getmockd/mockd/pkg/websocket" ) func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { @@ -856,7 +857,11 @@ func (s *Server) handleSendToWebSocketConnection(w http.ResponseWriter, r *http. } if err := s.engine.SendToWebSocketConnection(id, req.Type, payload); err != nil { - writeError(w, http.StatusNotFound, "not_found", "WebSocket connection not found") + if errors.Is(err, websocket.ErrConnectionNotFound) || errors.Is(err, websocket.ErrConnectionClosed) { + writeError(w, http.StatusNotFound, "not_found", "WebSocket connection not found or closed") + } else { + writeError(w, http.StatusInternalServerError, "send_error", "Failed to send message to WebSocket connection") + } return } diff --git a/pkg/engine/api/handlers_test.go b/pkg/engine/api/handlers_test.go index 36f6cac7..d8723515 100644 --- a/pkg/engine/api/handlers_test.go +++ b/pkg/engine/api/handlers_test.go @@ -14,6 +14,7 @@ import ( "github.com/getmockd/mockd/pkg/mock" "github.com/getmockd/mockd/pkg/requestlog" "github.com/getmockd/mockd/pkg/stateful" + "github.com/getmockd/mockd/pkg/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -48,6 +49,9 @@ type mockEngine struct { resetStateErr error listStatefulItemsErr error getStatefulItemErr error + // wsSendErr allows injecting a specific error for a connection ID in + // SendToWebSocketConnection. Keyed by connection ID. + wsSendErr map[string]error } func newMockEngine() *mockEngine { @@ -55,6 +59,7 @@ func newMockEngine() *mockEngine { mocks: make(map[string]*config.MockConfiguration), requestLogs: make(map[string]*requestlog.Entry), customOps: make(map[string]*CustomOperationDetail), + wsSendErr: make(map[string]error), running: true, uptime: 100, protocols: map[string]ProtocolStatusInfo{ @@ -359,12 +364,16 @@ func (m *mockEngine) GetWebSocketStats() *WebSocketStats { } func (m *mockEngine) SendToWebSocketConnection(id string, msgType string, data []byte) error { + // Allow per-connection error injection for testing specific error paths. + if err, ok := m.wsSendErr[id]; ok { + return err + } for _, c := range m.wsConnections { if c.ID == id { return nil } } - return errors.New("connection not found") + return websocket.ErrConnectionNotFound } func (m *mockEngine) GetConfig() *ConfigResponse { @@ -2232,3 +2241,339 @@ func TestCustomOperationFullCRUD(t *testing.T) { server.handleGetCustomOperation(getRec2, getReq2) assert.Equal(t, http.StatusNotFound, getRec2.Code) } + +// TestHandleListWebSocketConnections tests the GET /websocket/connections handler. +func TestHandleListWebSocketConnections(t *testing.T) { + t.Run("returns empty list when no connections", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections", nil) + rec := httptest.NewRecorder() + + server.handleListWebSocketConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp WebSocketConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Empty(t, resp.Connections) + assert.Equal(t, 0, resp.Count) + }) + + t.Run("returns all connections", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{ + {ID: "ws-1", Path: "/ws", Status: "connected"}, + {ID: "ws-2", Path: "/ws", Status: "connected"}, + } + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections", nil) + rec := httptest.NewRecorder() + + server.handleListWebSocketConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp WebSocketConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Len(t, resp.Connections, 2) + assert.Equal(t, 2, resp.Count) + }) +} + +// TestHandleGetWebSocketConnection tests the GET /websocket/connections/{id} handler. +func TestHandleGetWebSocketConnection(t *testing.T) { + t.Run("returns connection when found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{ + {ID: "ws-1", Path: "/ws", Status: "connected"}, + } + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections/ws-1", nil) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleGetWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var conn WebSocketConnection + err := json.Unmarshal(rec.Body.Bytes(), &conn) + require.NoError(t, err) + assert.Equal(t, "ws-1", conn.ID) + }) + + t.Run("returns 404 for unknown connection", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections/unknown", nil) + req.SetPathValue("id", "unknown") + rec := httptest.NewRecorder() + + server.handleGetWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +// TestHandleCloseWebSocketConnection tests the DELETE /websocket/connections/{id} handler. +func TestHandleCloseWebSocketConnection(t *testing.T) { + t.Run("closes existing connection and returns 200", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{ + {ID: "ws-1", Path: "/ws", Status: "connected"}, + } + + req := httptest.NewRequest(http.MethodDelete, "/websocket/connections/ws-1", nil) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleCloseWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp map[string]string + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Equal(t, "WebSocket connection closed", resp["message"]) + assert.Equal(t, "ws-1", resp["id"]) + + // Connection must be removed from the engine. + assert.Empty(t, engine.wsConnections) + }) + + t.Run("returns 404 for unknown connection", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodDelete, "/websocket/connections/unknown", nil) + req.SetPathValue("id", "unknown") + rec := httptest.NewRecorder() + + server.handleCloseWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +// TestHandleGetWebSocketStats tests the GET /websocket/stats handler. +func TestHandleGetWebSocketStats(t *testing.T) { + t.Run("returns empty stats with non-nil map when wsStats is nil", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + // wsStats is nil by default in newMockEngine. + + req := httptest.NewRequest(http.MethodGet, "/websocket/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetWebSocketStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats WebSocketStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.NotNil(t, stats.ConnectionsByMock) + }) + + t.Run("returns populated stats", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsStats = &WebSocketStats{ + TotalConnections: 5, + ActiveConnections: 3, + TotalMessagesSent: 100, + TotalMessagesRecv: 50, + ConnectionsByMock: map[string]int{"mock-1": 3}, + } + + req := httptest.NewRequest(http.MethodGet, "/websocket/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetWebSocketStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats WebSocketStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.Equal(t, int64(5), stats.TotalConnections) + assert.Equal(t, 3, stats.ActiveConnections) + assert.Equal(t, int64(100), stats.TotalMessagesSent) + assert.Equal(t, int64(50), stats.TotalMessagesRecv) + assert.Equal(t, 3, stats.ConnectionsByMock["mock-1"]) + }) +} + +// TestHandleSendToWebSocketConnection covers every branch of the send handler, +// including the new 404/500 distinction introduced in the review fix. +func TestHandleSendToWebSocketConnection(t *testing.T) { + t.Run("returns 400 when connection id is missing", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections//send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + // PathValue "id" intentionally not set → empty string + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "missing_id", resp["error"]) + }) + + t.Run("returns 400 for invalid JSON body", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{invalid`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) + }) + + t.Run("returns 400 for invalid base64 binary payload", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{{ID: "ws-1"}} + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{"type":"binary","data":"not-valid-base64!!!"}`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "invalid_base64", resp["error"]) + }) + + t.Run("returns 200 for text message to existing connection", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{{ID: "ws-1"}} + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "Message sent", resp["message"]) + assert.Equal(t, "ws-1", resp["connection"]) + assert.Equal(t, "text", resp["type"]) + }) + + t.Run("returns 200 for valid base64 binary message", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{{ID: "ws-1"}} + + // base64("hello") = "aGVsbG8=" + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{"type":"binary","data":"aGVsbG8="}`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "binary", resp["type"]) + }) + + t.Run("defaults type to text when omitted", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.wsConnections = []*WebSocketConnection{{ID: "ws-1"}} + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{"data":"hello"}`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "text", resp["type"]) + }) + + t.Run("returns 404 when connection is not found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + // No connections registered → ErrConnectionNotFound + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/missing/send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + req.SetPathValue("id", "missing") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "not_found", resp["error"]) + }) + + t.Run("returns 404 when connection is already closed", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + // Inject ErrConnectionClosed for ws-1 (connection exists but is closed) + engine.wsSendErr["ws-1"] = websocket.ErrConnectionClosed + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "not_found", resp["error"]) + }) + + t.Run("returns 500 for unexpected write error", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + // Inject a generic I/O error (e.g., broken pipe) + engine.wsSendErr["ws-1"] = errors.New("write: broken pipe") + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/ws-1/send", + strings.NewReader(`{"type":"text","data":"hello"}`)) + req.SetPathValue("id", "ws-1") + rec := httptest.NewRecorder() + + server.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "send_error", resp["error"]) + }) +} From ab093f4a07e6c6b06e9950971cd9f4f17216d243 Mon Sep 17 00:00:00 2001 From: Vadim Sabirov Date: Sun, 5 Apr 2026 08:19:41 +0300 Subject: [PATCH 05/18] feat: implement WebSocket auto-reconnect on mock update and delete --- CHANGELOG.md | 1 + docs/src/content/docs/reference/admin-api.md | 6 + pkg/engine/handler_protocol.go | 7 + pkg/engine/mock_manager.go | 5 + pkg/websocket/manager.go | 37 +++- pkg/websocket/manager_test.go | 49 +++++ tests/integration/websocket_test.go | 217 ++++++++++++++++++- 7 files changed, 319 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ee264c9..cbaed45f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - **WebSocket connection management API** — `GET /websocket/connections`, `GET /websocket/connections/{id}`, `DELETE /websocket/connections/{id}`, `POST /websocket/connections/{id}/send`, `GET /websocket/stats` added to the Admin API for real-time visibility, control, and server-initiated messaging of active WebSocket connections +- **WebSocket auto-reconnect on mock update** — updating or deleting a WebSocket mock now automatically closes all active connections with close code 1012 (Service Restart) so clients reconnect and pick up the new configuration immediately - **Workspace-scoped stateful resources** — stateful resources, custom operations, and request logs are now isolated per workspace - **`--workspace` persistent CLI flag** — scope any CLI command to a specific workspace without switching context - **`?workspaceId=` API parameter** — all admin API endpoints now accept workspace filtering diff --git a/docs/src/content/docs/reference/admin-api.md b/docs/src/content/docs/reference/admin-api.md index 994dc36a..b5ff9a5b 100644 --- a/docs/src/content/docs/reference/admin-api.md +++ b/docs/src/content/docs/reference/admin-api.md @@ -251,6 +251,12 @@ The response includes ports for: ### Mock Management +:::note[WebSocket: active clients reconnect on mock changes] +When a WebSocket mock is updated or deleted, all clients currently connected to that endpoint receive a **close frame with code 1012 (Service Restart)**. Most WebSocket client libraries treat code 1012 as a signal to reconnect automatically. On reconnect, the client establishes a fresh connection that uses the new mock configuration. + +This applies to `PUT /mocks/{id}`, `DELETE /mocks/{id}`, `POST /mocks/{id}/toggle` (when disabling), bulk `DELETE /mocks` (delete all), and `POST /config` with `replace: true`. +::: + #### GET /mocks List all configured mocks. diff --git a/pkg/engine/handler_protocol.go b/pkg/engine/handler_protocol.go index f2fde048..c5b25861 100644 --- a/pkg/engine/handler_protocol.go +++ b/pkg/engine/handler_protocol.go @@ -150,6 +150,13 @@ func (h *Handler) UnregisterWebSocketEndpoint(path string) { h.wsManager.UnregisterEndpoint(path) } +// DisconnectWebSocketEndpoint closes all active connections on a WebSocket endpoint. +// Uses RFC 6455 close code 1012 (Service Restart) so clients reconnect automatically. +// Must be called before UnregisterWebSocketEndpoint while byEndpoint still tracks the path. +func (h *Handler) DisconnectWebSocketEndpoint(path string) { + h.wsManager.DisconnectByEndpoint(path, websocket.CloseServiceRestart, "mock updated") +} + // ListSOAPHandlerPaths returns all registered SOAP handler paths. func (h *Handler) ListSOAPHandlerPaths() []string { h.soapMu.RLock() diff --git a/pkg/engine/mock_manager.go b/pkg/engine/mock_manager.go index 34e27cbc..76111cea 100644 --- a/pkg/engine/mock_manager.go +++ b/pkg/engine/mock_manager.go @@ -197,6 +197,11 @@ func (mm *MockManager) unregisterHandlerLocked(cfg *config.MockConfiguration) { case mock.TypeWebSocket: if mm.handler != nil && cfg.WebSocket != nil { + // Disconnect active clients first so they receive close code 1012 + // (Service Restart) and reconnect with the new configuration. + // Must happen before UnregisterWebSocketEndpoint while byEndpoint + // still tracks the path. + mm.handler.DisconnectWebSocketEndpoint(cfg.WebSocket.Path) mm.handler.UnregisterWebSocketEndpoint(cfg.WebSocket.Path) } case mock.TypeGraphQL: diff --git a/pkg/websocket/manager.go b/pkg/websocket/manager.go index 2dffe284..0a39ffc7 100644 --- a/pkg/websocket/manager.go +++ b/pkg/websocket/manager.go @@ -102,12 +102,15 @@ func (m *ConnectionManager) RegisterEndpoint(e *Endpoint) { } // UnregisterEndpoint removes an endpoint from the manager. +// Callers that want active clients to reconnect with new config should call +// DisconnectByEndpoint before this method, while byEndpoint still tracks +// the path. The engine's unregisterHandlerLocked does this automatically +// for every mock update, delete, and clear operation. func (m *ConnectionManager) UnregisterEndpoint(path string) { m.mu.Lock() defer m.mu.Unlock() delete(m.endpoints, path) - // Note: connections remain until they close } // GetEndpoint returns an endpoint by path. @@ -255,6 +258,38 @@ func (m *ConnectionManager) CountByEndpoint(path string) int { return 0 } +// DisconnectByEndpoint closes all active connections on an endpoint with the given code and reason. +// Returns the number of connections that were closed. +// +// Must be called before UnregisterEndpoint so that byEndpoint still contains the connections. +// Uses the read-lock/copy pattern (same as BroadcastToEndpoint) to avoid holding the lock +// while calling conn.Close, which acquires its own lock. +func (m *ConnectionManager) DisconnectByEndpoint(path string, code CloseCode, reason string) int { + m.mu.RLock() + var ids []string + if eps, ok := m.byEndpoint[path]; ok { + ids = make([]string, 0, len(eps)) + for id := range eps { + ids = append(ids, id) + } + } + m.mu.RUnlock() + + closed := 0 + for _, id := range ids { + m.mu.RLock() + conn := m.connections[id] + m.mu.RUnlock() + + if conn != nil && !conn.IsClosed() { + if err := conn.Close(code, reason); err == nil { + closed++ + } + } + } + return closed +} + // BroadcastToEndpoint sends a message to all connections on an endpoint. func (m *ConnectionManager) BroadcastToEndpoint(path string, msgType MessageType, data []byte) int { m.mu.RLock() diff --git a/pkg/websocket/manager_test.go b/pkg/websocket/manager_test.go index 3298942e..c592e793 100644 --- a/pkg/websocket/manager_test.go +++ b/pkg/websocket/manager_test.go @@ -230,6 +230,55 @@ func TestConnectionManager_ConcurrentJoinFromBothSides(t *testing.T) { } } +func TestConnectionManager_DisconnectByEndpoint_UnknownPath(t *testing.T) { + manager := NewConnectionManager() + + // Endpoint that was never registered — should return 0 without panicking + count := manager.DisconnectByEndpoint("/ws/unknown", CloseServiceRestart, "mock updated") + if count != 0 { + t.Errorf("expected 0 for unknown path, got %d", count) + } +} + +func TestConnectionManager_DisconnectByEndpoint_SkipsAlreadyClosed(t *testing.T) { + manager := NewConnectionManager() + + // Pre-closed connection — Close must not be called again (nil conn would panic). + conn := &Connection{ + id: "conn-already-closed", + endpointPath: "/ws/test", + groups: make(map[string]struct{}), + } + conn.closed.Store(true) + manager.Add(conn) + + // All connections already closed → count must be 0. + count := manager.DisconnectByEndpoint("/ws/test", CloseServiceRestart, "mock updated") + if count != 0 { + t.Errorf("expected 0 (connection pre-closed), got %d", count) + } +} + +func TestConnectionManager_DisconnectByEndpoint_IsolatesEndpoint(t *testing.T) { + manager := NewConnectionManager() + + // Two endpoints: only /ws/target should be affected. + target := &Connection{id: "t1", endpointPath: "/ws/target", groups: make(map[string]struct{})} + target.closed.Store(true) // pre-close to avoid nil conn panic + other := &Connection{id: "o1", endpointPath: "/ws/other", groups: make(map[string]struct{})} + other.closed.Store(true) + + manager.Add(target) + manager.Add(other) + + manager.DisconnectByEndpoint("/ws/target", CloseServiceRestart, "mock updated") + + // The connection on /ws/other must still be tracked by the manager. + if manager.Get("o1") == nil { + t.Error("connection on /ws/other should not be removed by DisconnectByEndpoint on /ws/target") + } +} + func TestConnectionManager_BroadcastToGroup(t *testing.T) { manager := NewConnectionManager() diff --git a/tests/integration/websocket_test.go b/tests/integration/websocket_test.go index 0e6c910b..92c6a079 100644 --- a/tests/integration/websocket_test.go +++ b/tests/integration/websocket_test.go @@ -2,6 +2,7 @@ package integration import ( "context" + "errors" "net/http" "net/http/httptest" "strings" @@ -9,12 +10,13 @@ import ( "time" ws "github.com/coder/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/getmockd/mockd/pkg/config" "github.com/getmockd/mockd/pkg/engine" "github.com/getmockd/mockd/pkg/mock" "github.com/getmockd/mockd/pkg/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // ============================================================================ @@ -799,6 +801,217 @@ func TestWS_FullServer_WithMiddleware(t *testing.T) { assert.Equal(t, testMsg, string(data)) } +// ============================================================================ +// User Story: Mock Update / Delete Closes Active Connections +// ============================================================================ + +// assertCloseCode1012 asserts that readErr is a WebSocket close error with code 1012 +// (Service Restart). Expects readErr to be non-nil (connection already closed by server). +func assertCloseCode1012(t *testing.T, readErr error) { + t.Helper() + require.Error(t, readErr, "expected connection to be closed by server") + var closeErr ws.CloseError + if errors.As(readErr, &closeErr) { + assert.Equal(t, ws.StatusCode(websocket.CloseServiceRestart), closeErr.Code, + "expected close code 1012 Service Restart") + } +} + +// setupWSMockServer creates a test server that has a WebSocket mock registered +// via the engine's control adapter (not the legacy WebSocketEndpoints collection +// field). Returns the httptest.Server and a ControlAPIAdapter for mock management. +func setupWSMockServer(t *testing.T) (*httptest.Server, *engine.ControlAPIAdapter) { + t.Helper() + cfg := config.DefaultServerConfiguration() + srv := engine.NewServer(cfg) + adapter := engine.NewControlAPIAdapter(srv) + + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(func() { ts.Close() }) + return ts, adapter +} + +// TestWS_MockUpdate_ClosesActiveConnections verifies that updating a WebSocket mock +// sends close code 1012 (Service Restart) to all connected clients so they reconnect +// and receive the new configuration. +func TestWS_MockUpdate_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupWSMockServer(t) + + // Create initial WebSocket mock. + initialResponse := &mock.WSMessageResponse{Type: "text", Value: "v1"} + mockCfg := &mock.Mock{ + Type: mock.TypeWebSocket, + WebSocket: &mock.WebSocketSpec{ + Path: "/ws/update-test", + DefaultResponse: initialResponse, + }, + } + require.NoError(t, adapter.AddMock(mockCfg)) + + // Connect a WebSocket client. + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/update-test" + dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer dialCancel() + + conn, resp, err := ws.Dial(dialCtx, wsURL, nil) + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + require.NoError(t, err, "WebSocket dial failed") + + // Wait for connection to be tracked. + time.Sleep(50 * time.Millisecond) + + wsm := getWSManager(t, ts) + assert.Equal(t, 1, wsm.CountByEndpoint("/ws/update-test"), "expected 1 active connection before update") + + // Update the mock — this should trigger disconnect of all active clients. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + updated := *mocks[0] + updated.WebSocket = &mock.WebSocketSpec{ + Path: "/ws/update-test", + DefaultResponse: &mock.WSMessageResponse{Type: "text", Value: "v2"}, + } + require.NoError(t, adapter.UpdateMock(updated.ID, &updated)) + + // The client must receive a close frame with code 1012 Service Restart. + readCtx, readCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer readCancel() + + _, _, readErr := conn.Read(readCtx) + require.Error(t, readErr, "expected connection to be closed by server") + + var closeErr ws.CloseError + if errors.As(readErr, &closeErr) { + assert.Equal(t, ws.StatusCode(websocket.CloseServiceRestart), closeErr.Code, + "expected close code 1012 Service Restart") + } +} + +// TestWS_MockDelete_ClosesActiveConnections verifies that deleting a WebSocket mock +// sends close code 1012 (Service Restart) to all connected clients. +func TestWS_MockDelete_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupWSMockServer(t) + + mockCfg := &mock.Mock{ + Type: mock.TypeWebSocket, + WebSocket: &mock.WebSocketSpec{ + Path: "/ws/delete-test", + }, + } + require.NoError(t, adapter.AddMock(mockCfg)) + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/delete-test" + dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer dialCancel() + + conn, resp, err := ws.Dial(dialCtx, wsURL, nil) + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + require.NoError(t, err, "WebSocket dial failed") + + // Wait for connection to be tracked. + time.Sleep(50 * time.Millisecond) + + // Delete the mock — active clients should receive close code 1012. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + require.NoError(t, adapter.DeleteMock(mocks[0].ID)) + + readCtx, readCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer readCancel() + + _, _, readErr := conn.Read(readCtx) + require.Error(t, readErr, "expected connection to be closed by server") + + var closeErr ws.CloseError + if errors.As(readErr, &closeErr) { + assert.Equal(t, ws.StatusCode(websocket.CloseServiceRestart), closeErr.Code, + "expected close code 1012 Service Restart") + } +} + +// TestWS_MockToggleDisable_ClosesActiveConnections verifies that disabling a WebSocket +// mock (toggle to enabled=false) sends close code 1012 to all connected clients. +// This covers the path: POST /mocks/{id}/toggle → UpdateMock → unregisterHandlerLocked. +func TestWS_MockToggleDisable_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupWSMockServer(t) + + mockCfg := &mock.Mock{ + Type: mock.TypeWebSocket, + WebSocket: &mock.WebSocketSpec{ + Path: "/ws/toggle-test", + }, + } + require.NoError(t, adapter.AddMock(mockCfg)) + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/toggle-test" + dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer dialCancel() + + conn, resp, err := ws.Dial(dialCtx, wsURL, nil) + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + require.NoError(t, err, "WebSocket dial failed") + + // Wait for connection to be tracked. + time.Sleep(50 * time.Millisecond) + + // Disable the mock — active clients should receive close code 1012. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + disabled := *mocks[0] + enabled := false + disabled.Enabled = &enabled + require.NoError(t, adapter.UpdateMock(disabled.ID, &disabled)) + + readCtx, readCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer readCancel() + + _, _, readErr := conn.Read(readCtx) + assertCloseCode1012(t, readErr) +} + +// TestWS_ClearMocks_ClosesActiveConnections verifies that clearing all mocks +// sends close code 1012 to all connected WebSocket clients. +// This covers the path used by bulk DELETE /mocks and POST /config?replace=true. +func TestWS_ClearMocks_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupWSMockServer(t) + + mockCfg := &mock.Mock{ + Type: mock.TypeWebSocket, + WebSocket: &mock.WebSocketSpec{ + Path: "/ws/clear-test", + }, + } + require.NoError(t, adapter.AddMock(mockCfg)) + + wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/ws/clear-test" + dialCtx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer dialCancel() + + conn, resp, err := ws.Dial(dialCtx, wsURL, nil) + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + require.NoError(t, err, "WebSocket dial failed") + + // Wait for connection to be tracked. + time.Sleep(50 * time.Millisecond) + + // Clear all mocks — the WebSocket client must receive close code 1012. + adapter.ClearMocks() + + readCtx, readCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer readCancel() + + _, _, readErr := conn.Read(readCtx) + assertCloseCode1012(t, readErr) +} + // itoa converts int to string without importing strconv. func itoa(i int) string { if i == 0 { From 9657a8c8a039e1ed55e1520058718aed1cbb2757 Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Sun, 5 Apr 2026 18:46:27 -0500 Subject: [PATCH 06/18] feat: add MQTT connection management API with per-client control Add full MQTT connection management across all API layers: - Broker: ListClientInfos, GetClientInfo, DisconnectClient, GetConnectionStats - Engine API: GET/DELETE /mqtt/connections/{id}, GET /mqtt/stats - Control API adapter: bridges broker methods to engine API - Engine client: HTTP client methods for MQTT endpoints - Admin API: handlers with stats provider pattern, routes at /mqtt-connections/ - Auto-disconnect: cleanly disconnect all MQTT clients before stopping broker on mock update/delete, giving clients a chance to reconnect Co-Authored-By: Claude Opus 4.6 --- pkg/admin/engineclient/client.go | 160 ++++++++++++++++ pkg/admin/engineclient/types.go | 4 + pkg/admin/mqtt_handlers.go | 144 ++++++++++++++ pkg/admin/mqtt_handlers_test.go | 251 +++++++++++++++++++++++++ pkg/admin/routes.go | 14 ++ pkg/admin/stat_helper.go | 67 +++++-- pkg/api/types/responses.go | 77 +++++++- pkg/engine/api/handlers.go | 89 +++++++++ pkg/engine/api/handlers_test.go | 191 ++++++++++++++++++- pkg/engine/api/server.go | 24 +++ pkg/engine/control_api.go | 214 +++++++++++++++++++++ pkg/engine/mock_manager.go | 31 ++- pkg/mqtt/broker.go | 119 ++++++++++++ pkg/mqtt/connection_management_test.go | 104 ++++++++++ 14 files changed, 1458 insertions(+), 31 deletions(-) create mode 100644 pkg/admin/mqtt_handlers.go create mode 100644 pkg/admin/mqtt_handlers_test.go create mode 100644 pkg/mqtt/connection_management_test.go diff --git a/pkg/admin/engineclient/client.go b/pkg/admin/engineclient/client.go index 0fd448ac..7695245b 100644 --- a/pkg/admin/engineclient/client.go +++ b/pkg/admin/engineclient/client.go @@ -1181,6 +1181,166 @@ func (c *Client) GetWebSocketStats(ctx context.Context) (*WebSocketStats, error) return &stats, nil } +// ListMQTTConnections returns all active MQTT client connections. +func (c *Client) ListMQTTConnections(ctx context.Context) ([]*MQTTConnection, error) { + resp, err := c.get(ctx, "/mqtt/connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var result struct { + Connections []*MQTTConnection `json:"connections"` + Count int `json:"count"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode MQTT connections: %w", err) + } + return result.Connections, nil +} + +// GetMQTTConnection returns a specific MQTT client connection. +func (c *Client) GetMQTTConnection(ctx context.Context, id string) (*MQTTConnection, error) { + resp, err := c.get(ctx, "/mqtt/connections/"+url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return nil, ErrNotFound + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var conn MQTTConnection + if err := json.NewDecoder(resp.Body).Decode(&conn); err != nil { + return nil, fmt.Errorf("failed to decode MQTT connection: %w", err) + } + return &conn, nil +} + +// CloseMQTTConnection disconnects a specific MQTT client. +func (c *Client) CloseMQTTConnection(ctx context.Context, id string) error { + resp, err := c.delete(ctx, "/mqtt/connections/"+url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return ErrNotFound + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// GetMQTTStats returns MQTT broker statistics. +func (c *Client) GetMQTTStats(ctx context.Context) (*MQTTStats, error) { + resp, err := c.get(ctx, "/mqtt/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var stats MQTTStats + if err := json.NewDecoder(resp.Body).Decode(&stats); err != nil { + return nil, fmt.Errorf("failed to decode MQTT stats: %w", err) + } + return &stats, nil +} + +// ListGRPCStreams returns all active gRPC streaming connections. +func (c *Client) ListGRPCStreams(ctx context.Context) ([]*GRPCStream, error) { + resp, err := c.get(ctx, "/grpc/connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var result struct { + Streams []*GRPCStream `json:"streams"` + Count int `json:"count"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode gRPC streams: %w", err) + } + return result.Streams, nil +} + +// GetGRPCStream returns a specific gRPC stream by ID. +func (c *Client) GetGRPCStream(ctx context.Context, id string) (*GRPCStream, error) { + resp, err := c.get(ctx, "/grpc/connections/"+url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return nil, ErrNotFound + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var stream GRPCStream + if err := json.NewDecoder(resp.Body).Decode(&stream); err != nil { + return nil, fmt.Errorf("failed to decode gRPC stream: %w", err) + } + return &stream, nil +} + +// CancelGRPCStream cancels (terminates) a specific gRPC stream. +func (c *Client) CancelGRPCStream(ctx context.Context, id string) error { + resp, err := c.delete(ctx, "/grpc/connections/"+url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode == http.StatusNotFound { + return ErrNotFound + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// GetGRPCStats returns gRPC statistics. +func (c *Client) GetGRPCStats(ctx context.Context) (*GRPCStats, error) { + resp, err := c.get(ctx, "/grpc/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + + var stats GRPCStats + if err := json.NewDecoder(resp.Body).Decode(&stats); err != nil { + return nil, fmt.Errorf("failed to decode gRPC stats: %w", err) + } + return &stats, nil +} + // HTTP helpers func (c *Client) get(ctx context.Context, path string) (*http.Response, error) { diff --git a/pkg/admin/engineclient/types.go b/pkg/admin/engineclient/types.go index be7698ee..031df268 100644 --- a/pkg/admin/engineclient/types.go +++ b/pkg/admin/engineclient/types.go @@ -49,6 +49,10 @@ type ( SSEStats = types.SSEStats WebSocketConnection = types.WebSocketConnection WebSocketStats = types.WebSocketStats + MQTTConnection = types.MQTTConnection + MQTTStats = types.MQTTStats + GRPCStream = types.GRPCStream + GRPCStats = types.GRPCStats CustomOperationInfo = types.CustomOperationInfo CustomOperationDetail = types.CustomOperationDetail CustomOperationStep = types.CustomOperationStep diff --git a/pkg/admin/mqtt_handlers.go b/pkg/admin/mqtt_handlers.go new file mode 100644 index 00000000..ed46a1f5 --- /dev/null +++ b/pkg/admin/mqtt_handlers.go @@ -0,0 +1,144 @@ +package admin + +import ( + "errors" + "log/slog" + "net/http" + + "github.com/getmockd/mockd/pkg/admin/engineclient" +) + +// MQTTConnectionListResponse represents a list of MQTT connections with stats. +type MQTTConnectionListResponse struct { + Connections []*engineclient.MQTTConnection `json:"connections"` + Stats engineclient.MQTTStats `json:"stats"` +} + +// handleListMQTTConnections handles GET /mqtt/connections. +func (a *API) handleListMQTTConnections(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, MQTTConnectionListResponse{ + Connections: []*engineclient.MQTTConnection{}, + Stats: engineclient.MQTTStats{SubscriptionsByClient: make(map[string]int)}, + }) + return + } + + stats, err := engine.GetMQTTStats(ctx) + if err != nil { + a.logger().Error("failed to get MQTT stats", "error", err) + status, code, msg := mapMQTTEngineError(err, a.logger(), "get MQTT stats") + writeError(w, status, code, msg) + return + } + + connections, err := engine.ListMQTTConnections(ctx) + if err != nil { + a.logger().Error("failed to list MQTT connections", "error", err) + status, code, msg := mapMQTTEngineError(err, a.logger(), "list MQTT connections") + writeError(w, status, code, msg) + return + } + + if connections == nil { + connections = []*engineclient.MQTTConnection{} + } + + subsByClient := stats.SubscriptionsByClient + if subsByClient == nil { + subsByClient = make(map[string]int) + } + + writeJSON(w, http.StatusOK, MQTTConnectionListResponse{ + Connections: connections, + Stats: engineclient.MQTTStats{ + ConnectedClients: stats.ConnectedClients, + TotalSubscriptions: stats.TotalSubscriptions, + TopicCount: stats.TopicCount, + Port: stats.Port, + TLSEnabled: stats.TLSEnabled, + AuthEnabled: stats.AuthEnabled, + SubscriptionsByClient: subsByClient, + }, + }) +} + +// handleGetMQTTConnection handles GET /mqtt/connections/{id}. +func (a *API) handleGetMQTTConnection(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Client ID is required") + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + + conn, err := engine.GetMQTTConnection(ctx, id) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + a.logger().Error("failed to get MQTT connection", "error", err, "clientID", id) + status, code, msg := mapMQTTEngineError(err, a.logger(), "get MQTT connection") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, conn) +} + +// handleCloseMQTTConnection handles DELETE /mqtt/connections/{id}. +func (a *API) handleCloseMQTTConnection(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Client ID is required") + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + + err := engine.CloseMQTTConnection(ctx, id) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Connection not found") + return + } + a.logger().Error("failed to close MQTT connection", "error", err, "clientID", id) + status, code, msg := mapMQTTEngineError(err, a.logger(), "close MQTT connection") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Connection closed", + "connection": id, + }) +} + +// handleGetMQTTStats handles GET /mqtt/stats. +func (a *API) handleGetMQTTStats(w http.ResponseWriter, r *http.Request) { + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, engineclient.MQTTStats{SubscriptionsByClient: make(map[string]int)}) + return + } + a.handleGetStats(w, r, newMQTTStatsProvider(engine)) +} + +func mapMQTTEngineError(err error, log *slog.Logger, operation string) (int, string, string) { + return http.StatusServiceUnavailable, "engine_error", sanitizeEngineError(err, log, operation) +} diff --git a/pkg/admin/mqtt_handlers_test.go b/pkg/admin/mqtt_handlers_test.go new file mode 100644 index 00000000..792d8403 --- /dev/null +++ b/pkg/admin/mqtt_handlers_test.go @@ -0,0 +1,251 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/getmockd/mockd/pkg/admin/engineclient" +) + +// ============================================================================ +// handleListMQTTConnections +// ============================================================================ + +func TestHandleListMQTTConnections_NoEngine_ReturnsEmptyList(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections", nil) + rec := httptest.NewRecorder() + + api.handleListMQTTConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp MQTTConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Empty(t, resp.Connections) + assert.NotNil(t, resp.Stats.SubscriptionsByClient) +} + +func TestHandleListMQTTConnections_WithEngine_ReturnsConnectionsAndStats(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == http.MethodGet && r.URL.Path == "/mqtt/stats": + _, _ = w.Write([]byte(`{"connectedClients":2,"totalSubscriptions":3,"topicCount":5,"port":1883,"tlsEnabled":false,"authEnabled":true,"subscriptionsByClient":{"client-1":2,"client-2":1}}`)) + case r.Method == http.MethodGet && r.URL.Path == "/mqtt/connections": + _, _ = w.Write([]byte(`{"connections":[{"id":"client-1","brokerId":"broker-1","connectedAt":"2026-04-05T00:00:00Z","subscriptions":["sensors/#"],"protocolVersion":4,"status":"connected"}],"count":1}`)) + default: + http.NotFound(w, r) + } + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections", nil) + rec := httptest.NewRecorder() + + api.handleListMQTTConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp MQTTConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Len(t, resp.Connections, 1) + assert.Equal(t, "client-1", resp.Connections[0].ID) + assert.Equal(t, "broker-1", resp.Connections[0].BrokerID) + assert.Equal(t, []string{"sensors/#"}, resp.Connections[0].Subscriptions) + assert.Equal(t, 2, resp.Stats.ConnectedClients) + assert.Equal(t, 3, resp.Stats.TotalSubscriptions) + assert.True(t, resp.Stats.AuthEnabled) +} + +func TestHandleListMQTTConnections_EngineUnavailable_Returns503(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections", nil) + rec := httptest.NewRecorder() + + api.handleListMQTTConnections(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) +} + +// ============================================================================ +// handleGetMQTTConnection +// ============================================================================ + +func TestHandleGetMQTTConnection_MissingID_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections/", nil) + rec := httptest.NewRecorder() + + api.handleGetMQTTConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleGetMQTTConnection_NoEngine_Returns404(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections/client-1", nil) + req.SetPathValue("id", "client-1") + rec := httptest.NewRecorder() + + api.handleGetMQTTConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleGetMQTTConnection_Found_ReturnsConnection(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"client-1","brokerId":"broker-1","connectedAt":"2026-04-05T00:00:00Z","subscriptions":["topic/a"],"protocolVersion":5,"status":"connected"}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections/client-1", nil) + req.SetPathValue("id", "client-1") + rec := httptest.NewRecorder() + + api.handleGetMQTTConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var conn engineclient.MQTTConnection + err := json.Unmarshal(rec.Body.Bytes(), &conn) + require.NoError(t, err) + assert.Equal(t, "client-1", conn.ID) + assert.Equal(t, byte(5), conn.ProtocolVersion) +} + +func TestHandleGetMQTTConnection_NotFound_Returns404(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":"not_found","message":"MQTT connection not found"}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections/nonexistent", nil) + req.SetPathValue("id", "nonexistent") + rec := httptest.NewRecorder() + + api.handleGetMQTTConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +// ============================================================================ +// handleCloseMQTTConnection +// ============================================================================ + +func TestHandleCloseMQTTConnection_MissingID_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/mqtt/connections/", nil) + rec := httptest.NewRecorder() + + api.handleCloseMQTTConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleCloseMQTTConnection_NoEngine_Returns404(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/mqtt/connections/client-1", nil) + req.SetPathValue("id", "client-1") + rec := httptest.NewRecorder() + + api.handleCloseMQTTConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleCloseMQTTConnection_Success(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"message":"MQTT connection closed","id":"client-1"}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/mqtt/connections/client-1", nil) + req.SetPathValue("id", "client-1") + rec := httptest.NewRecorder() + + api.handleCloseMQTTConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// ============================================================================ +// handleGetMQTTStats +// ============================================================================ + +func TestHandleGetMQTTStats_NoEngine_ReturnsEmptyStats(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/stats", nil) + rec := httptest.NewRecorder() + + api.handleGetMQTTStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats engineclient.MQTTStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.NotNil(t, stats.SubscriptionsByClient) +} + +func TestHandleGetMQTTStats_WithEngine_ReturnsStats(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"connectedClients":3,"totalSubscriptions":5,"topicCount":2,"port":1883,"subscriptionsByClient":{"c1":2,"c2":3}}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/mqtt/stats", nil) + rec := httptest.NewRecorder() + + api.handleGetMQTTStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats engineclient.MQTTStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.Equal(t, 3, stats.ConnectedClients) + assert.Equal(t, 5, stats.TotalSubscriptions) + assert.Equal(t, 2, stats.SubscriptionsByClient["c1"]) +} diff --git a/pkg/admin/routes.go b/pkg/admin/routes.go index 579e3e3e..492c955c 100644 --- a/pkg/admin/routes.go +++ b/pkg/admin/routes.go @@ -130,6 +130,20 @@ func (a *API) registerRoutes(mux *http.ServeMux) { mux.HandleFunc("POST /websocket/connections/{id}/send", a.handleSendToWebSocketConnection) mux.HandleFunc("GET /websocket/stats", a.handleGetWebSocketStats) + // MQTT connection management + // Routes use /mqtt-connections/ prefix to avoid ambiguity with existing + // /mqtt/{id}/status recording routes. + mux.HandleFunc("GET /mqtt-connections", a.handleListMQTTConnections) + mux.HandleFunc("GET /mqtt-connections/{id}", a.handleGetMQTTConnection) + mux.HandleFunc("DELETE /mqtt-connections/{id}", a.handleCloseMQTTConnection) + mux.HandleFunc("GET /mqtt-connections/stats", a.handleGetMQTTStats) + + // gRPC stream management + mux.HandleFunc("GET /grpc/connections", a.handleListGRPCStreams) + mux.HandleFunc("GET /grpc/connections/{id}", a.handleGetGRPCStream) + mux.HandleFunc("DELETE /grpc/connections/{id}", a.handleCancelGRPCStream) + mux.HandleFunc("GET /grpc/stats", a.handleGetGRPCStats) + // Mock-specific SSE endpoints mux.HandleFunc("GET /mocks/{id}/sse/connections", a.requireEngine(a.handleListMockSSEConnections)) mux.HandleFunc("DELETE /mocks/{id}/sse/connections", a.requireEngine(a.handleCloseMockSSEConnections)) diff --git a/pkg/admin/stat_helper.go b/pkg/admin/stat_helper.go index 17ddd6c0..ed732ecf 100644 --- a/pkg/admin/stat_helper.go +++ b/pkg/admin/stat_helper.go @@ -13,10 +13,10 @@ import ( type statsProvider interface { // GetStats retrieves statistics from the engine. GetStats(ctx context.Context) (interface{}, error) - // GetEmptyStats returns empty statistics for when engine is unavailable. - GetEmptyStats() interface{} // MapError converts an error to HTTP status, code, and message. MapError(err error, log *slog.Logger, operation string) (int, string, string) + // ProtocolName returns the protocol name for error logging. + ProtocolName() string } // handleGetStats is a generic handler for retrieving statistics. @@ -27,7 +27,7 @@ func (a *API) handleGetStats(w http.ResponseWriter, r *http.Request, provider st stats, err := provider.GetStats(ctx) if err != nil { - a.logger().Error("failed to get stats", "error", err) + a.logger().Error("failed to get "+provider.ProtocolName()+" stats", "error", err) status, code, msg := provider.MapError(err, a.logger(), "get stats") writeError(w, status, code, msg) return @@ -67,18 +67,14 @@ func (p *sseStatsProvider) GetStats(ctx context.Context) (interface{}, error) { }, nil } -// GetEmptyStats returns empty SSE statistics. -func (p *sseStatsProvider) GetEmptyStats() interface{} { - return sse.ConnectionStats{ - ConnectionsByMock: make(map[string]int), - } -} - // MapError converts an SSE engine error to HTTP status, code, and message. func (p *sseStatsProvider) MapError(err error, log *slog.Logger, operation string) (int, string, string) { return mapSSEEngineError(err, log, operation) } +// ProtocolName returns the protocol name for error logging. +func (p *sseStatsProvider) ProtocolName() string { return "SSE" } + // wsStatsProvider implements statsProvider for WebSocket statistics. type wsStatsProvider struct { engine *engineclient.Client @@ -110,14 +106,51 @@ func (p *wsStatsProvider) GetStats(ctx context.Context) (interface{}, error) { }, nil } -// GetEmptyStats returns empty WebSocket statistics. -func (p *wsStatsProvider) GetEmptyStats() interface{} { - return engineclient.WebSocketStats{ - ConnectionsByMock: make(map[string]int), - } -} - // MapError converts a WebSocket engine error to HTTP status, code, and message. func (p *wsStatsProvider) MapError(err error, log *slog.Logger, operation string) (int, string, string) { return mapWebSocketEngineError(err, log, operation) } + +// ProtocolName returns the protocol name for error logging. +func (p *wsStatsProvider) ProtocolName() string { return "WebSocket" } + +// mqttStatsProvider implements statsProvider for MQTT statistics. +type mqttStatsProvider struct { + engine *engineclient.Client +} + +// newMQTTStatsProvider creates a new MQTT statistics provider. +func newMQTTStatsProvider(engine *engineclient.Client) *mqttStatsProvider { + return &mqttStatsProvider{engine: engine} +} + +// GetStats retrieves MQTT statistics from the engine. +func (p *mqttStatsProvider) GetStats(ctx context.Context) (interface{}, error) { + stats, err := p.engine.GetMQTTStats(ctx) + if err != nil { + return nil, err + } + + subsByClient := stats.SubscriptionsByClient + if subsByClient == nil { + subsByClient = make(map[string]int) + } + + return engineclient.MQTTStats{ + ConnectedClients: stats.ConnectedClients, + TotalSubscriptions: stats.TotalSubscriptions, + TopicCount: stats.TopicCount, + Port: stats.Port, + TLSEnabled: stats.TLSEnabled, + AuthEnabled: stats.AuthEnabled, + SubscriptionsByClient: subsByClient, + }, nil +} + +// MapError converts an MQTT engine error to HTTP status, code, and message. +func (p *mqttStatsProvider) MapError(err error, log *slog.Logger, operation string) (int, string, string) { + return mapMQTTEngineError(err, log, operation) +} + +// ProtocolName returns the protocol name for error logging. +func (p *mqttStatsProvider) ProtocolName() string { return "MQTT" } diff --git a/pkg/api/types/responses.go b/pkg/api/types/responses.go index a11e04d9..bb8c4202 100644 --- a/pkg/api/types/responses.go +++ b/pkg/api/types/responses.go @@ -471,20 +471,48 @@ type SSEStats struct { ConnectionsByMock map[string]int `json:"connectionsByMock"` } +// --- MQTT --- + +// MQTTConnection represents an active MQTT client connection. +type MQTTConnection struct { + ID string `json:"id"` + BrokerID string `json:"brokerId"` + ConnectedAt time.Time `json:"connectedAt"` + Subscriptions []string `json:"subscriptions"` + ProtocolVersion byte `json:"protocolVersion"` + Username string `json:"username,omitempty"` + RemoteAddr string `json:"remoteAddr,omitempty"` + Status string `json:"status"` +} + +// MQTTConnectionListResponse lists MQTT connections. +type MQTTConnectionListResponse struct { + Connections []*MQTTConnection `json:"connections"` + Count int `json:"count"` +} + +// MQTTStats represents MQTT broker statistics. +type MQTTStats struct { + ConnectedClients int `json:"connectedClients"` + TotalSubscriptions int `json:"totalSubscriptions"` + TopicCount int `json:"topicCount"` + Port int `json:"port"` + TLSEnabled bool `json:"tlsEnabled"` + AuthEnabled bool `json:"authEnabled"` + SubscriptionsByClient map[string]int `json:"subscriptionsByClient"` +} + // --- WebSocket --- // WebSocketConnection represents an active WebSocket connection. type WebSocketConnection struct { - ID string `json:"id"` - MockID string `json:"mockId"` - Path string `json:"path"` - ClientIP string `json:"clientIp"` - ConnectedAt time.Time `json:"connectedAt"` - MessagesSent int64 `json:"messagesSent"` - MessagesRecv int64 `json:"messagesRecv"` - BytesSent int64 `json:"bytesSent"` - BytesReceived int64 `json:"bytesReceived"` - Status string `json:"status"` + ID string `json:"id"` + MockID string `json:"mockId,omitempty"` + Path string `json:"path"` + ConnectedAt time.Time `json:"connectedAt"` + MessagesSent int64 `json:"messagesSent"` + MessagesRecv int64 `json:"messagesRecv"` + Status string `json:"status"` } // WebSocketConnectionListResponse lists WebSocket connections. @@ -501,3 +529,32 @@ type WebSocketStats struct { TotalMessagesRecv int64 `json:"totalMessagesRecv"` ConnectionsByMock map[string]int `json:"connectionsByMock"` } + +// --- gRPC --- + +// GRPCStream represents an active gRPC streaming RPC. +type GRPCStream struct { + ID string `json:"id"` + Method string `json:"method"` + StreamType string `json:"streamType"` + ClientAddr string `json:"clientAddr,omitempty"` + ConnectedAt time.Time `json:"connectedAt"` + MessagesSent int64 `json:"messagesSent"` + MessagesRecv int64 `json:"messagesRecv"` +} + +// GRPCStreamListResponse lists gRPC streams. +type GRPCStreamListResponse struct { + Streams []*GRPCStream `json:"streams"` + Count int `json:"count"` +} + +// GRPCStats represents gRPC statistics. +type GRPCStats struct { + ActiveStreams int `json:"activeStreams"` + TotalStreams int64 `json:"totalStreams"` + TotalRPCs int64 `json:"totalRPCs"` + TotalMessagesSent int64 `json:"totalMessagesSent"` + TotalMessagesRecv int64 `json:"totalMessagesRecv"` + StreamsByMethod map[string]int `json:"streamsByMethod"` +} diff --git a/pkg/engine/api/handlers.go b/pkg/engine/api/handlers.go index f220d54e..dc5573d3 100644 --- a/pkg/engine/api/handlers.go +++ b/pkg/engine/api/handlers.go @@ -843,6 +843,10 @@ func (s *Server) handleSendToWebSocketConnection(w http.ResponseWriter, r *http. if req.Type == "" { req.Type = "text" } + if req.Type != "text" && req.Type != "binary" { + writeError(w, http.StatusBadRequest, "invalid_type", `Type must be "text" or "binary"`) + return + } var payload []byte if req.Type == "binary" { @@ -1064,6 +1068,91 @@ func writeJSON(w http.ResponseWriter, status int, v any) { httputil.WriteJSON(w, status, v) } +// MQTT handlers + +func (s *Server) handleListMQTTConnections(w http.ResponseWriter, r *http.Request) { + connections := s.engine.ListMQTTConnections() + writeJSON(w, http.StatusOK, MQTTConnectionListResponse{ + Connections: connections, + Count: len(connections), + }) +} + +func (s *Server) handleGetMQTTConnection(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + conn := s.engine.GetMQTTConnection(id) + if conn == nil { + writeError(w, http.StatusNotFound, "not_found", "MQTT connection not found") + return + } + writeJSON(w, http.StatusOK, conn) +} + +func (s *Server) handleCloseMQTTConnection(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if err := s.engine.CloseMQTTConnection(id); err != nil { + writeError(w, http.StatusNotFound, "not_found", "MQTT connection not found") + return + } + writeJSON(w, http.StatusOK, map[string]string{ + "message": "MQTT connection closed", + "id": id, + }) +} + +func (s *Server) handleGetMQTTStats(w http.ResponseWriter, r *http.Request) { + stats := s.engine.GetMQTTStats() + if stats == nil { + writeJSON(w, http.StatusOK, MQTTStats{SubscriptionsByClient: make(map[string]int)}) + return + } + writeJSON(w, http.StatusOK, stats) +} + +// gRPC stream handlers + +func (s *Server) handleListGRPCStreams(w http.ResponseWriter, r *http.Request) { + streams := s.engine.ListGRPCStreams() + if streams == nil { + streams = []*GRPCStream{} + } + writeJSON(w, http.StatusOK, GRPCStreamListResponse{ + Streams: streams, + Count: len(streams), + }) +} + +func (s *Server) handleGetGRPCStream(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + stream := s.engine.GetGRPCStream(id) + if stream == nil { + writeError(w, http.StatusNotFound, "not_found", "gRPC stream not found") + return + } + writeJSON(w, http.StatusOK, stream) +} + +func (s *Server) handleCancelGRPCStream(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if err := s.engine.CancelGRPCStream(id); err != nil { + writeError(w, http.StatusNotFound, "not_found", "gRPC stream not found") + return + } + writeJSON(w, http.StatusOK, map[string]string{ + "message": "gRPC stream cancelled", + "id": id, + }) +} + +func (s *Server) handleGetGRPCStats(w http.ResponseWriter, r *http.Request) { + stats := s.engine.GetGRPCStats() + if stats == nil { + writeJSON(w, http.StatusOK, GRPCStats{StreamsByMethod: make(map[string]int)}) + return + } + writeJSON(w, http.StatusOK, stats) +} + func writeError(w http.ResponseWriter, status int, code, message string) { httputil.WriteJSON(w, status, ErrorResponse{ Error: code, diff --git a/pkg/engine/api/handlers_test.go b/pkg/engine/api/handlers_test.go index d8723515..e02619c1 100644 --- a/pkg/engine/api/handlers_test.go +++ b/pkg/engine/api/handlers_test.go @@ -30,9 +30,11 @@ type mockEngine struct { stateOverview *StateOverview handlers []*ProtocolHandler sseConnections []*SSEConnection - wsConnections []*WebSocketConnection - sseStats *SSEStats - wsStats *WebSocketStats + wsConnections []*WebSocketConnection + mqttConnections []*MQTTConnection + sseStats *SSEStats + wsStats *WebSocketStats + mqttStats *MQTTStats configResp *ConfigResponse protocols map[string]ProtocolStatusInfo @@ -376,6 +378,38 @@ func (m *mockEngine) SendToWebSocketConnection(id string, msgType string, data [ return websocket.ErrConnectionNotFound } +func (m *mockEngine) ListMQTTConnections() []*MQTTConnection { + return m.mqttConnections +} + +func (m *mockEngine) GetMQTTConnection(id string) *MQTTConnection { + for _, c := range m.mqttConnections { + if c.ID == id { + return c + } + } + return nil +} + +func (m *mockEngine) CloseMQTTConnection(id string) error { + for i, c := range m.mqttConnections { + if c.ID == id { + m.mqttConnections = append(m.mqttConnections[:i], m.mqttConnections[i+1:]...) + return nil + } + } + return errors.New("connection not found") +} + +func (m *mockEngine) GetMQTTStats() *MQTTStats { + return m.mqttStats +} + +func (m *mockEngine) ListGRPCStreams() []*GRPCStream { return nil } +func (m *mockEngine) GetGRPCStream(id string) *GRPCStream { return nil } +func (m *mockEngine) CancelGRPCStream(id string) error { return errors.New("not found") } +func (m *mockEngine) GetGRPCStats() *GRPCStats { return nil } + func (m *mockEngine) GetConfig() *ConfigResponse { return m.configResp } @@ -2577,3 +2611,154 @@ func TestHandleSendToWebSocketConnection(t *testing.T) { assert.Equal(t, "send_error", resp["error"]) }) } + +// ============================================================================ +// MQTT handlers +// ============================================================================ + +func TestHandleListMQTTConnections(t *testing.T) { + t.Run("empty", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections", nil) + rec := httptest.NewRecorder() + + server.handleListMQTTConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp MQTTConnectionListResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, 0, resp.Count) + }) + + t.Run("with connections", func(t *testing.T) { + engine := newMockEngine() + engine.mqttConnections = []*MQTTConnection{ + {ID: "client-1", BrokerID: "broker-1", Subscriptions: []string{"sensors/#"}, Status: "connected"}, + {ID: "client-2", BrokerID: "broker-1", Subscriptions: []string{"devices/+"}, Status: "connected"}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections", nil) + rec := httptest.NewRecorder() + + server.handleListMQTTConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp MQTTConnectionListResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, 2, resp.Count) + assert.Equal(t, "client-1", resp.Connections[0].ID) + }) +} + +func TestHandleGetMQTTConnection(t *testing.T) { + t.Run("found", func(t *testing.T) { + engine := newMockEngine() + engine.mqttConnections = []*MQTTConnection{ + {ID: "client-1", BrokerID: "broker-1", ProtocolVersion: 5, Subscriptions: []string{"topic/a"}, Status: "connected"}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections/client-1", nil) + req.SetPathValue("id", "client-1") + rec := httptest.NewRecorder() + + server.handleGetMQTTConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var conn MQTTConnection + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &conn)) + assert.Equal(t, "client-1", conn.ID) + assert.Equal(t, byte(5), conn.ProtocolVersion) + }) + + t.Run("not found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/mqtt/connections/nonexistent", nil) + req.SetPathValue("id", "nonexistent") + rec := httptest.NewRecorder() + + server.handleGetMQTTConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +func TestHandleCloseMQTTConnection(t *testing.T) { + t.Run("success", func(t *testing.T) { + engine := newMockEngine() + engine.mqttConnections = []*MQTTConnection{ + {ID: "client-1", BrokerID: "broker-1", Status: "connected"}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodDelete, "/mqtt/connections/client-1", nil) + req.SetPathValue("id", "client-1") + rec := httptest.NewRecorder() + + server.handleCloseMQTTConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, "MQTT connection closed", resp["message"]) + }) + + t.Run("not found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodDelete, "/mqtt/connections/nonexistent", nil) + req.SetPathValue("id", "nonexistent") + rec := httptest.NewRecorder() + + server.handleCloseMQTTConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +func TestHandleGetMQTTStats(t *testing.T) { + t.Run("nil stats returns empty", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/mqtt/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetMQTTStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var stats MQTTStats + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &stats)) + assert.NotNil(t, stats.SubscriptionsByClient) + }) + + t.Run("with stats", func(t *testing.T) { + engine := newMockEngine() + engine.mqttStats = &MQTTStats{ + ConnectedClients: 3, + TotalSubscriptions: 7, + TopicCount: 4, + Port: 1883, + SubscriptionsByClient: map[string]int{"c1": 3, "c2": 4}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/mqtt/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetMQTTStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var stats MQTTStats + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &stats)) + assert.Equal(t, 3, stats.ConnectedClients) + assert.Equal(t, 7, stats.TotalSubscriptions) + assert.Equal(t, 3, stats.SubscriptionsByClient["c1"]) + }) +} diff --git a/pkg/engine/api/server.go b/pkg/engine/api/server.go index 0650dab4..baba2c94 100644 --- a/pkg/engine/api/server.go +++ b/pkg/engine/api/server.go @@ -92,6 +92,18 @@ type EngineController interface { SendToWebSocketConnection(id string, msgType string, data []byte) error GetWebSocketStats() *WebSocketStats + // MQTT connections + ListMQTTConnections() []*MQTTConnection + GetMQTTConnection(id string) *MQTTConnection + CloseMQTTConnection(id string) error + GetMQTTStats() *MQTTStats + + // gRPC streams + ListGRPCStreams() []*GRPCStream + GetGRPCStream(id string) *GRPCStream + CancelGRPCStream(id string) error + GetGRPCStats() *GRPCStats + // Config GetConfig() *ConfigResponse } @@ -227,6 +239,18 @@ func (s *Server) registerRoutes(mux *http.ServeMux) { mux.HandleFunc("GET /websocket/stats", s.handleGetWebSocketStats) mux.HandleFunc("POST /websocket/connections/{id}/send", s.handleSendToWebSocketConnection) + // MQTT connections + mux.HandleFunc("GET /mqtt/connections", s.handleListMQTTConnections) + mux.HandleFunc("GET /mqtt/connections/{id}", s.handleGetMQTTConnection) + mux.HandleFunc("DELETE /mqtt/connections/{id}", s.handleCloseMQTTConnection) + mux.HandleFunc("GET /mqtt/stats", s.handleGetMQTTStats) + + // gRPC streams + mux.HandleFunc("GET /grpc/connections", s.handleListGRPCStreams) + mux.HandleFunc("GET /grpc/connections/{id}", s.handleGetGRPCStream) + mux.HandleFunc("DELETE /grpc/connections/{id}", s.handleCancelGRPCStream) + mux.HandleFunc("GET /grpc/stats", s.handleGetGRPCStats) + // Config mux.HandleFunc("GET /config", s.handleGetConfig) mux.HandleFunc("POST /config", s.handleImportConfig) diff --git a/pkg/engine/control_api.go b/pkg/engine/control_api.go index b85b2851..4ab948ca 100644 --- a/pkg/engine/control_api.go +++ b/pkg/engine/control_api.go @@ -791,6 +791,220 @@ func (a *ControlAPIAdapter) GetWebSocketStats() *api.WebSocketStats { } } +// ErrMQTTBrokerNotInitialized is returned when no MQTT broker is available. +var ErrMQTTBrokerNotInitialized = errors.New("MQTT broker not initialized") + +// ListMQTTConnections implements api.EngineController. +func (a *ControlAPIAdapter) ListMQTTConnections() []*api.MQTTConnection { + brokers := a.server.GetMQTTBrokers() + var result []*api.MQTTConnection + + for _, broker := range brokers { + if broker == nil || !broker.IsRunning() { + continue + } + infos := broker.ListClientInfos() + for _, info := range infos { + if info.Closed { + continue + } + conn := &api.MQTTConnection{ + ID: info.ID, + BrokerID: info.BrokerID, + ConnectedAt: info.ConnectedAt, + Subscriptions: info.Subscriptions, + ProtocolVersion: info.ProtocolVersion, + Username: info.Username, + RemoteAddr: info.RemoteAddr, + Status: "connected", + } + if conn.Subscriptions == nil { + conn.Subscriptions = []string{} + } + result = append(result, conn) + } + } + + return result +} + +// GetMQTTConnection implements api.EngineController. +func (a *ControlAPIAdapter) GetMQTTConnection(id string) *api.MQTTConnection { + brokers := a.server.GetMQTTBrokers() + + for _, broker := range brokers { + if broker == nil || !broker.IsRunning() { + continue + } + info := broker.GetClientInfo(id) + if info == nil { + continue + } + subs := info.Subscriptions + if subs == nil { + subs = []string{} + } + return &api.MQTTConnection{ + ID: info.ID, + BrokerID: info.BrokerID, + ConnectedAt: info.ConnectedAt, + Subscriptions: subs, + ProtocolVersion: info.ProtocolVersion, + Username: info.Username, + RemoteAddr: info.RemoteAddr, + Status: "connected", + } + } + + return nil +} + +// CloseMQTTConnection implements api.EngineController. +func (a *ControlAPIAdapter) CloseMQTTConnection(id string) error { + brokers := a.server.GetMQTTBrokers() + + for _, broker := range brokers { + if broker == nil || !broker.IsRunning() { + continue + } + if err := broker.DisconnectClient(id); err == nil { + return nil + } + } + + return fmt.Errorf("MQTT client %q not found", id) +} + +// GetMQTTStats implements api.EngineController. +func (a *ControlAPIAdapter) GetMQTTStats() *api.MQTTStats { + brokers := a.server.GetMQTTBrokers() + if len(brokers) == 0 { + return nil + } + + stats := &api.MQTTStats{ + SubscriptionsByClient: make(map[string]int), + } + + for _, broker := range brokers { + if broker == nil || !broker.IsRunning() { + continue + } + connected, totalSubs, subsByClient := broker.GetConnectionStats() + stats.ConnectedClients += connected + stats.TotalSubscriptions += totalSubs + for k, v := range subsByClient { + stats.SubscriptionsByClient[k] = v + } + cfg := broker.Config() + if cfg != nil { + stats.TopicCount += len(cfg.Topics) + stats.Port = cfg.Port + stats.TLSEnabled = stats.TLSEnabled || (cfg.TLS != nil && cfg.TLS.Enabled) + stats.AuthEnabled = stats.AuthEnabled || (cfg.Auth != nil && cfg.Auth.Enabled) + } + } + + return stats +} + +// ErrGRPCServerNotInitialized is returned when no gRPC server is available. +var ErrGRPCServerNotInitialized = errors.New("gRPC server not initialized") + +// ListGRPCStreams implements api.EngineController. +func (a *ControlAPIAdapter) ListGRPCStreams() []*api.GRPCStream { + servers := a.server.GRPCServers() + var result []*api.GRPCStream + + for _, srv := range servers { + if srv == nil || !srv.IsRunning() { + continue + } + for _, info := range srv.StreamTracker().List() { + result = append(result, &api.GRPCStream{ + ID: info.ID, + Method: info.Method, + StreamType: string(info.StreamType), + ClientAddr: info.ClientAddr, + ConnectedAt: info.ConnectedAt, + MessagesSent: info.MessagesSent, + MessagesRecv: info.MessagesRecv, + }) + } + } + + return result +} + +// GetGRPCStream implements api.EngineController. +func (a *ControlAPIAdapter) GetGRPCStream(id string) *api.GRPCStream { + servers := a.server.GRPCServers() + + for _, srv := range servers { + if srv == nil || !srv.IsRunning() { + continue + } + if info := srv.StreamTracker().Get(id); info != nil { + return &api.GRPCStream{ + ID: info.ID, + Method: info.Method, + StreamType: string(info.StreamType), + ClientAddr: info.ClientAddr, + ConnectedAt: info.ConnectedAt, + MessagesSent: info.MessagesSent, + MessagesRecv: info.MessagesRecv, + } + } + } + + return nil +} + +// CancelGRPCStream implements api.EngineController. +func (a *ControlAPIAdapter) CancelGRPCStream(id string) error { + servers := a.server.GRPCServers() + + for _, srv := range servers { + if srv == nil || !srv.IsRunning() { + continue + } + if err := srv.StreamTracker().Cancel(id); err == nil { + return nil + } + } + + return fmt.Errorf("gRPC stream %s not found", id) +} + +// GetGRPCStats implements api.EngineController. +func (a *ControlAPIAdapter) GetGRPCStats() *api.GRPCStats { + servers := a.server.GRPCServers() + if len(servers) == 0 { + return nil + } + + stats := &api.GRPCStats{ + StreamsByMethod: make(map[string]int), + } + + for _, srv := range servers { + if srv == nil || !srv.IsRunning() { + continue + } + s := srv.StreamTracker().Stats() + stats.ActiveStreams += s.ActiveStreams + stats.TotalStreams += s.TotalStreams + stats.TotalRPCs += s.TotalRPCs + stats.TotalMessagesSent += s.TotalMsgSent + stats.TotalMessagesRecv += s.TotalMsgRecv + for method, count := range s.StreamsByMethod { + stats.StreamsByMethod[method] += count + } + } + + return stats +} + // GetConfig implements api.EngineController. func (a *ControlAPIAdapter) GetConfig() *api.ConfigResponse { cfg := a.server.Config() diff --git a/pkg/engine/mock_manager.go b/pkg/engine/mock_manager.go index 76111cea..0765a66f 100644 --- a/pkg/engine/mock_manager.go +++ b/pkg/engine/mock_manager.go @@ -181,15 +181,44 @@ func (mm *MockManager) unregisterHandlerLocked(cfg *config.MockConfiguration) { return } - switch cfg.Type { //nolint:exhaustive // HTTP mocks don't need cleanup on removal + switch cfg.Type { //nolint:exhaustive // plain HTTP mocks don't need cleanup on removal + case mock.TypeHTTP: + // HTTP mocks with SSE configuration have active streaming connections + // that must be disconnected before the mock is removed or updated. + if mm.handler != nil && cfg.HTTP != nil && cfg.HTTP.SSE != nil { + mm.handler.DisconnectSSEByMock(cfg.ID) + } case mock.TypeGRPC: if mm.protocolManager != nil { + // Cancel active streams first so clients receive codes.Unavailable + // and can reconnect with retry policies (gRPC equivalent of WS 1012). + if srv := mm.protocolManager.GetGRPCServer(cfg.ID); srv != nil { + n := srv.StreamTracker().CancelAll() + if n > 0 { + mm.log.Info("cancelled active gRPC streams", "id", cfg.ID, "count", n) + } + } if err := mm.protocolManager.StopGRPCServer(cfg.ID); err != nil { mm.log.Warn("failed to stop gRPC server", "id", cfg.ID, "error", err) } } case mock.TypeMQTT: if mm.protocolManager != nil { + // Disconnect all connected clients cleanly before stopping the broker. + // This gives clients a chance to detect the disconnection and reconnect + // when the broker restarts with the updated configuration. + broker := mm.protocolManager.GetMQTTBroker(cfg.ID) + if broker != nil && broker.IsRunning() { + clients := broker.GetClients() + for _, clientID := range clients { + if err := broker.DisconnectClient(clientID); err != nil { + mm.log.Debug("failed to disconnect MQTT client during mock update", + "clientID", clientID, "mockID", cfg.ID, "error", err) + } + } + mm.log.Info("disconnected MQTT clients before mock update", + "mockID", cfg.ID, "clientCount", len(clients)) + } if err := mm.protocolManager.StopMQTTBroker(cfg.ID); err != nil { mm.log.Warn("failed to stop MQTT broker", "id", cfg.ID, "error", err) } diff --git a/pkg/mqtt/broker.go b/pkg/mqtt/broker.go index 17b14e31..86dcb330 100644 --- a/pkg/mqtt/broker.go +++ b/pkg/mqtt/broker.go @@ -654,6 +654,125 @@ func (b *Broker) notifyTestPanelPublish(topic string, payload []byte, qos int, r b.sessionManager.NotifyMessage(b.config.ID, msg) } +// MQTTClientInfo represents information about a connected MQTT client. +type MQTTClientInfo struct { + ID string + BrokerID string + ConnectedAt time.Time + Subscriptions []string + ProtocolVersion byte + Username string + RemoteAddr string + Closed bool +} + +// ListClientInfos returns information about all connected MQTT clients. +func (b *Broker) ListClientInfos() []*MQTTClientInfo { + b.mu.RLock() + brokerID := "" + if b.config != nil { + brokerID = b.config.ID + } + subs := make(map[string][]string, len(b.clientSubscriptions)) + for k, v := range b.clientSubscriptions { + subs[k] = append([]string{}, v...) + } + b.mu.RUnlock() + + if b.server == nil { + return nil + } + + clients := b.server.Clients.GetAll() + var result []*MQTTClientInfo + + for id, cl := range clients { + if cl.Net.Inline { + continue // skip the built-in inline client + } + info := &MQTTClientInfo{ + ID: id, + BrokerID: brokerID, + ProtocolVersion: cl.Properties.ProtocolVersion, + Username: string(cl.Properties.Username), + RemoteAddr: cl.Net.Remote, + Closed: cl.Closed(), + } + if subList, ok := subs[id]; ok { + info.Subscriptions = subList + } + result = append(result, info) + } + + return result +} + +// GetClientInfo returns information about a specific MQTT client. +func (b *Broker) GetClientInfo(clientID string) *MQTTClientInfo { + if b.server == nil { + return nil + } + + cl, ok := b.server.Clients.Get(clientID) + if !ok { + return nil + } + if cl.Net.Inline { + return nil + } + + b.mu.RLock() + brokerID := "" + if b.config != nil { + brokerID = b.config.ID + } + var subsCopy []string + if s, ok := b.clientSubscriptions[clientID]; ok { + subsCopy = append([]string{}, s...) + } + b.mu.RUnlock() + + return &MQTTClientInfo{ + ID: clientID, + BrokerID: brokerID, + ProtocolVersion: cl.Properties.ProtocolVersion, + Username: string(cl.Properties.Username), + RemoteAddr: cl.Net.Remote, + Closed: cl.Closed(), + Subscriptions: subsCopy, + } +} + +// DisconnectClient disconnects a specific MQTT client by client ID. +// Returns an error if the client is not found. +func (b *Broker) DisconnectClient(clientID string) error { + if b.server == nil { + return errors.New("broker is not running") + } + + cl, ok := b.server.Clients.Get(clientID) + if !ok || cl.Net.Inline { + return fmt.Errorf("client %q not found", clientID) + } + + cl.Stop(errors.New("disconnected by admin API")) + return nil +} + +// GetConnectionStats returns aggregate connection statistics. +func (b *Broker) GetConnectionStats() (connectedClients int, totalSubscriptions int, subsByClient map[string]int) { + b.mu.RLock() + defer b.mu.RUnlock() + + subsByClient = make(map[string]int, len(b.clientSubscriptions)) + for clientID, subs := range b.clientSubscriptions { + subsByClient[clientID] = len(subs) + totalSubscriptions += len(subs) + } + connectedClients = len(b.getClientsLocked()) + return +} + // detectPayloadFormat detects the format of a payload func detectPayloadFormat(payload []byte) string { // Try to parse as JSON diff --git a/pkg/mqtt/connection_management_test.go b/pkg/mqtt/connection_management_test.go new file mode 100644 index 00000000..74a70dda --- /dev/null +++ b/pkg/mqtt/connection_management_test.go @@ -0,0 +1,104 @@ +package mqtt + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBroker_ListClientInfos_Empty(t *testing.T) { + t.Parallel() + broker, err := NewBroker(&MQTTConfig{Port: 0}) + require.NoError(t, err) + + require.NoError(t, broker.Start(context.Background())) + defer broker.Stop(context.Background(), 5*time.Second) + + time.Sleep(100 * time.Millisecond) + + infos := broker.ListClientInfos() + // The inline client may or may not appear, but we filter it out + for _, info := range infos { + assert.False(t, info.Closed, "active clients should not be closed") + } +} + +func TestBroker_GetClientInfo_NotFound(t *testing.T) { + t.Parallel() + broker, err := NewBroker(&MQTTConfig{Port: 0}) + require.NoError(t, err) + + require.NoError(t, broker.Start(context.Background())) + defer broker.Stop(context.Background(), 5*time.Second) + + time.Sleep(100 * time.Millisecond) + + info := broker.GetClientInfo("nonexistent-client") + assert.Nil(t, info) +} + +func TestBroker_DisconnectClient_NotFound(t *testing.T) { + t.Parallel() + broker, err := NewBroker(&MQTTConfig{Port: 0}) + require.NoError(t, err) + + require.NoError(t, broker.Start(context.Background())) + defer broker.Stop(context.Background(), 5*time.Second) + + time.Sleep(100 * time.Millisecond) + + err = broker.DisconnectClient("nonexistent-client") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestBroker_DisconnectClient_BrokerNotRunning(t *testing.T) { + t.Parallel() + broker, err := NewBroker(&MQTTConfig{Port: 0}) + require.NoError(t, err) + + // Don't start the broker + err = broker.DisconnectClient("some-client") + assert.Error(t, err) +} + +func TestBroker_GetConnectionStats_Empty(t *testing.T) { + t.Parallel() + broker, err := NewBroker(&MQTTConfig{Port: 0}) + require.NoError(t, err) + + require.NoError(t, broker.Start(context.Background())) + defer broker.Stop(context.Background(), 5*time.Second) + + time.Sleep(100 * time.Millisecond) + + connected, totalSubs, subsByClient := broker.GetConnectionStats() + assert.GreaterOrEqual(t, connected, 0) + assert.Equal(t, 0, totalSubs) + assert.NotNil(t, subsByClient) +} + +func TestBroker_ListClientInfos_NilServer(t *testing.T) { + t.Parallel() + broker := &Broker{ + config: &MQTTConfig{ID: "test"}, + clientSubscriptions: make(map[string][]string), + } + + infos := broker.ListClientInfos() + assert.Nil(t, infos) +} + +func TestBroker_GetClientInfo_NilServer(t *testing.T) { + t.Parallel() + broker := &Broker{ + config: &MQTTConfig{ID: "test"}, + clientSubscriptions: make(map[string][]string), + } + + info := broker.GetClientInfo("any-client") + assert.Nil(t, info) +} From dde6e05e23a74f205f3fc648c1b665c97406d665 Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Sun, 5 Apr 2026 18:50:08 -0500 Subject: [PATCH 07/18] feat: implement gRPC connection tracking for streaming RPCs Add StreamTracker to the gRPC server that tracks active streaming RPCs (server-stream, client-stream, bidirectional). Unary RPCs are counted but not tracked as active connections since they are stateless. Key changes: - StreamTracker in pkg/grpc/stream_tracker.go with Register/Unregister, message counting, Cancel/CancelAll, and aggregate stats - Hooked into all three streaming handlers in server.go, replacing raw metrics.ActiveConnections calls with tracker-based management - Admin API endpoints: GET/DELETE /grpc/connections/{id}, GET /grpc/stats - Engine Control API routes and ControlAPIAdapter methods - Engine client methods for the admin server to call through - Auto-disconnect on mock update: unregisterHandlerLocked cancels active gRPC streams with codes.Unavailable before stopping the server, so clients with retry policies reconnect automatically - Full test coverage: stream tracker unit tests, engine API handler tests, admin handler tests Co-Authored-By: Claude Opus 4.6 --- pkg/admin/grpc_handlers.go | 180 +++++++++++++++++++++ pkg/admin/grpc_handlers_test.go | 218 ++++++++++++++++++++++++++ pkg/engine/api/handlers_test.go | 180 ++++++++++++++++++++- pkg/engine/api/types.go | 6 + pkg/grpc/server.go | 103 ++++++++---- pkg/grpc/stream_tracker.go | 267 ++++++++++++++++++++++++++++++++ pkg/grpc/stream_tracker_test.go | 193 +++++++++++++++++++++++ 7 files changed, 1111 insertions(+), 36 deletions(-) create mode 100644 pkg/admin/grpc_handlers.go create mode 100644 pkg/admin/grpc_handlers_test.go create mode 100644 pkg/grpc/stream_tracker.go create mode 100644 pkg/grpc/stream_tracker_test.go diff --git a/pkg/admin/grpc_handlers.go b/pkg/admin/grpc_handlers.go new file mode 100644 index 00000000..a18df475 --- /dev/null +++ b/pkg/admin/grpc_handlers.go @@ -0,0 +1,180 @@ +package admin + +import ( + "context" + "errors" + "log/slog" + "net/http" + + "github.com/getmockd/mockd/pkg/admin/engineclient" +) + +// GRPCStreamListResponse represents a list of gRPC streams with stats. +type GRPCStreamListResponse struct { + Streams []*engineclient.GRPCStream `json:"streams"` + Stats engineclient.GRPCStats `json:"stats"` +} + +// handleListGRPCStreams handles GET /grpc/connections. +func (a *API) handleListGRPCStreams(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, GRPCStreamListResponse{ + Streams: []*engineclient.GRPCStream{}, + Stats: engineclient.GRPCStats{StreamsByMethod: make(map[string]int)}, + }) + return + } + + stats, err := engine.GetGRPCStats(ctx) + if err != nil { + a.logger().Error("failed to get gRPC stats", "error", err) + status, code, msg := mapGRPCEngineError(err, a.logger(), "get gRPC stats") + writeError(w, status, code, msg) + return + } + + streams, err := engine.ListGRPCStreams(ctx) + if err != nil { + a.logger().Error("failed to list gRPC streams", "error", err) + status, code, msg := mapGRPCEngineError(err, a.logger(), "list gRPC streams") + writeError(w, status, code, msg) + return + } + + if streams == nil { + streams = []*engineclient.GRPCStream{} + } + + byMethod := stats.StreamsByMethod + if byMethod == nil { + byMethod = make(map[string]int) + } + + writeJSON(w, http.StatusOK, GRPCStreamListResponse{ + Streams: streams, + Stats: engineclient.GRPCStats{ + ActiveStreams: stats.ActiveStreams, + TotalStreams: stats.TotalStreams, + TotalRPCs: stats.TotalRPCs, + TotalMessagesSent: stats.TotalMessagesSent, + TotalMessagesRecv: stats.TotalMessagesRecv, + StreamsByMethod: byMethod, + }, + }) +} + +// handleGetGRPCStream handles GET /grpc/connections/{id}. +func (a *API) handleGetGRPCStream(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Stream ID is required") + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Stream not found") + return + } + + stream, err := engine.GetGRPCStream(ctx, id) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Stream not found") + return + } + a.logger().Error("failed to get gRPC stream", "error", err, "streamID", id) + status, code, msg := mapGRPCEngineError(err, a.logger(), "get gRPC stream") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, stream) +} + +// handleCancelGRPCStream handles DELETE /grpc/connections/{id}. +func (a *API) handleCancelGRPCStream(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + id := r.PathValue("id") + if id == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Stream ID is required") + return + } + + engine := a.localEngine.Load() + if engine == nil { + writeError(w, http.StatusNotFound, "not_found", "Stream not found") + return + } + + err := engine.CancelGRPCStream(ctx, id) + if err != nil { + if errors.Is(err, engineclient.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Stream not found") + return + } + a.logger().Error("failed to cancel gRPC stream", "error", err, "streamID", id) + status, code, msg := mapGRPCEngineError(err, a.logger(), "cancel gRPC stream") + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Stream cancelled", + "stream": id, + }) +} + +// handleGetGRPCStats handles GET /grpc/stats. +func (a *API) handleGetGRPCStats(w http.ResponseWriter, r *http.Request) { + engine := a.localEngine.Load() + if engine == nil { + writeJSON(w, http.StatusOK, engineclient.GRPCStats{StreamsByMethod: make(map[string]int)}) + return + } + a.handleGetStats(w, r, newGRPCStatsProvider(engine)) +} + +func mapGRPCEngineError(err error, log *slog.Logger, operation string) (int, string, string) { + return http.StatusServiceUnavailable, "engine_error", sanitizeEngineError(err, log, operation) +} + +// grpcStatsProvider implements statsProvider for gRPC statistics. +type grpcStatsProvider struct { + engine *engineclient.Client +} + +func newGRPCStatsProvider(engine *engineclient.Client) *grpcStatsProvider { + return &grpcStatsProvider{engine: engine} +} + +func (p *grpcStatsProvider) GetStats(ctx context.Context) (interface{}, error) { + stats, err := p.engine.GetGRPCStats(ctx) + if err != nil { + return nil, err + } + + byMethod := stats.StreamsByMethod + if byMethod == nil { + byMethod = make(map[string]int) + } + + return engineclient.GRPCStats{ + ActiveStreams: stats.ActiveStreams, + TotalStreams: stats.TotalStreams, + TotalRPCs: stats.TotalRPCs, + TotalMessagesSent: stats.TotalMessagesSent, + TotalMessagesRecv: stats.TotalMessagesRecv, + StreamsByMethod: byMethod, + }, nil +} + +func (p *grpcStatsProvider) MapError(err error, log *slog.Logger, operation string) (int, string, string) { + return mapGRPCEngineError(err, log, operation) +} + +func (p *grpcStatsProvider) ProtocolName() string { return "gRPC" } diff --git a/pkg/admin/grpc_handlers_test.go b/pkg/admin/grpc_handlers_test.go new file mode 100644 index 00000000..e5941bc0 --- /dev/null +++ b/pkg/admin/grpc_handlers_test.go @@ -0,0 +1,218 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/getmockd/mockd/pkg/admin/engineclient" +) + +// ============================================================================ +// handleListGRPCStreams +// ============================================================================ + +func TestHandleListGRPCStreams_NoEngine_ReturnsEmptyList(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections", nil) + rec := httptest.NewRecorder() + + api.handleListGRPCStreams(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp GRPCStreamListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Empty(t, resp.Streams) + assert.NotNil(t, resp.Stats.StreamsByMethod) +} + +func TestHandleListGRPCStreams_WithEngine_ReturnsStreamsAndStats(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == http.MethodGet && r.URL.Path == "/grpc/stats": + _, _ = w.Write([]byte(`{"activeStreams":1,"totalStreams":5,"totalRPCs":100,"totalMessagesSent":50,"totalMessagesRecv":30,"streamsByMethod":{"/pkg.Svc/Stream":1}}`)) + case r.Method == http.MethodGet && r.URL.Path == "/grpc/connections": + _, _ = w.Write([]byte(`{"streams":[{"id":"grpc-stream-1","method":"/pkg.Svc/Stream","streamType":"server_stream","clientAddr":"127.0.0.1:5000","connectedAt":"2026-04-05T00:00:00Z","messagesSent":10,"messagesRecv":1}],"count":1}`)) + default: + http.NotFound(w, r) + } + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections", nil) + rec := httptest.NewRecorder() + + api.handleListGRPCStreams(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp GRPCStreamListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Len(t, resp.Streams, 1) + assert.Equal(t, "grpc-stream-1", resp.Streams[0].ID) + assert.Equal(t, "/pkg.Svc/Stream", resp.Streams[0].Method) + assert.Equal(t, 1, resp.Stats.ActiveStreams) + assert.Equal(t, int64(100), resp.Stats.TotalRPCs) +} + +// ============================================================================ +// handleGetGRPCStream +// ============================================================================ + +func TestHandleGetGRPCStream_MissingID_ReturnsBadRequest(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections/", nil) + rec := httptest.NewRecorder() + + api.handleGetGRPCStream(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleGetGRPCStream_NoEngine_ReturnsNotFound(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections/stream-1", nil) + req.SetPathValue("id", "stream-1") + rec := httptest.NewRecorder() + + api.handleGetGRPCStream(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleGetGRPCStream_Found_ReturnsStream(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"stream-1","method":"/pkg.Svc/Chat","streamType":"bidi","connectedAt":"2026-04-05T00:00:00Z","messagesSent":5,"messagesRecv":3}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections/stream-1", nil) + req.SetPathValue("id", "stream-1") + rec := httptest.NewRecorder() + + api.handleGetGRPCStream(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stream engineclient.GRPCStream + err := json.Unmarshal(rec.Body.Bytes(), &stream) + require.NoError(t, err) + assert.Equal(t, "stream-1", stream.ID) + assert.Equal(t, "bidi", stream.StreamType) +} + +// ============================================================================ +// handleCancelGRPCStream +// ============================================================================ + +func TestHandleCancelGRPCStream_MissingID_ReturnsBadRequest(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/grpc/connections/", nil) + rec := httptest.NewRecorder() + + api.handleCancelGRPCStream(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestHandleCancelGRPCStream_NoEngine_ReturnsNotFound(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/grpc/connections/stream-1", nil) + req.SetPathValue("id", "stream-1") + rec := httptest.NewRecorder() + + api.handleCancelGRPCStream(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleCancelGRPCStream_Success(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"message":"gRPC stream cancelled","id":"stream-1"}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodDelete, "/grpc/connections/stream-1", nil) + req.SetPathValue("id", "stream-1") + rec := httptest.NewRecorder() + + api.handleCancelGRPCStream(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} + +// ============================================================================ +// handleGetGRPCStats +// ============================================================================ + +func TestHandleGetGRPCStats_NoEngine_ReturnsEmptyStats(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir())) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/stats", nil) + rec := httptest.NewRecorder() + + api.handleGetGRPCStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats engineclient.GRPCStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.NotNil(t, stats.StreamsByMethod) +} + +func TestHandleGetGRPCStats_WithEngine_ReturnsStats(t *testing.T) { + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"activeStreams":3,"totalStreams":10,"totalRPCs":500,"totalMessagesSent":200,"totalMessagesRecv":150,"streamsByMethod":{"/pkg.Svc/A":2,"/pkg.Svc/B":1}}`)) + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/grpc/stats", nil) + rec := httptest.NewRecorder() + + api.handleGetGRPCStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats engineclient.GRPCStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.Equal(t, 3, stats.ActiveStreams) + assert.Equal(t, int64(500), stats.TotalRPCs) + assert.Len(t, stats.StreamsByMethod, 2) +} diff --git a/pkg/engine/api/handlers_test.go b/pkg/engine/api/handlers_test.go index e02619c1..81e7d3a6 100644 --- a/pkg/engine/api/handlers_test.go +++ b/pkg/engine/api/handlers_test.go @@ -32,9 +32,11 @@ type mockEngine struct { sseConnections []*SSEConnection wsConnections []*WebSocketConnection mqttConnections []*MQTTConnection + grpcStreams []*GRPCStream sseStats *SSEStats wsStats *WebSocketStats mqttStats *MQTTStats + grpcStats *GRPCStats configResp *ConfigResponse protocols map[string]ProtocolStatusInfo @@ -405,10 +407,32 @@ func (m *mockEngine) GetMQTTStats() *MQTTStats { return m.mqttStats } -func (m *mockEngine) ListGRPCStreams() []*GRPCStream { return nil } -func (m *mockEngine) GetGRPCStream(id string) *GRPCStream { return nil } -func (m *mockEngine) CancelGRPCStream(id string) error { return errors.New("not found") } -func (m *mockEngine) GetGRPCStats() *GRPCStats { return nil } +func (m *mockEngine) ListGRPCStreams() []*GRPCStream { + return m.grpcStreams +} + +func (m *mockEngine) GetGRPCStream(id string) *GRPCStream { + for _, s := range m.grpcStreams { + if s.ID == id { + return s + } + } + return nil +} + +func (m *mockEngine) CancelGRPCStream(id string) error { + for i, s := range m.grpcStreams { + if s.ID == id { + m.grpcStreams = append(m.grpcStreams[:i], m.grpcStreams[i+1:]...) + return nil + } + } + return errors.New("stream not found") +} + +func (m *mockEngine) GetGRPCStats() *GRPCStats { + return m.grpcStats +} func (m *mockEngine) GetConfig() *ConfigResponse { return m.configResp @@ -2762,3 +2786,151 @@ func TestHandleGetMQTTStats(t *testing.T) { assert.Equal(t, 3, stats.SubscriptionsByClient["c1"]) }) } + +// ============================================================================ +// gRPC stream handlers +// ============================================================================ + +func TestHandleListGRPCStreams(t *testing.T) { + t.Run("empty", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections", nil) + rec := httptest.NewRecorder() + + server.handleListGRPCStreams(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp GRPCStreamListResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, 0, resp.Count) + }) + + t.Run("with streams", func(t *testing.T) { + engine := newMockEngine() + engine.grpcStreams = []*GRPCStream{ + {ID: "stream-1", Method: "/pkg.Svc/Chat", StreamType: "bidi", MessagesSent: 5, MessagesRecv: 3}, + {ID: "stream-2", Method: "/pkg.Svc/Watch", StreamType: "server_stream", MessagesSent: 10}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections", nil) + rec := httptest.NewRecorder() + + server.handleListGRPCStreams(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var resp GRPCStreamListResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + assert.Equal(t, 2, resp.Count) + assert.Equal(t, "stream-1", resp.Streams[0].ID) + }) +} + +func TestHandleGetGRPCStream(t *testing.T) { + t.Run("found", func(t *testing.T) { + engine := newMockEngine() + engine.grpcStreams = []*GRPCStream{ + {ID: "stream-1", Method: "/pkg.Svc/Chat", StreamType: "bidi"}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections/stream-1", nil) + req.SetPathValue("id", "stream-1") + rec := httptest.NewRecorder() + + server.handleGetGRPCStream(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var stream GRPCStream + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &stream)) + assert.Equal(t, "stream-1", stream.ID) + }) + + t.Run("not found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/grpc/connections/nonexistent", nil) + req.SetPathValue("id", "nonexistent") + rec := httptest.NewRecorder() + + server.handleGetGRPCStream(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +func TestHandleCancelGRPCStream(t *testing.T) { + t.Run("success", func(t *testing.T) { + engine := newMockEngine() + engine.grpcStreams = []*GRPCStream{ + {ID: "stream-1", Method: "/pkg.Svc/Chat", StreamType: "bidi"}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodDelete, "/grpc/connections/stream-1", nil) + req.SetPathValue("id", "stream-1") + rec := httptest.NewRecorder() + + server.handleCancelGRPCStream(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("not found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodDelete, "/grpc/connections/nonexistent", nil) + req.SetPathValue("id", "nonexistent") + rec := httptest.NewRecorder() + + server.handleCancelGRPCStream(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +func TestHandleGetGRPCStats(t *testing.T) { + t.Run("nil stats", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/grpc/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetGRPCStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var stats GRPCStats + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &stats)) + assert.NotNil(t, stats.StreamsByMethod) + }) + + t.Run("with stats", func(t *testing.T) { + engine := newMockEngine() + engine.grpcStats = &GRPCStats{ + ActiveStreams: 2, + TotalStreams: 10, + TotalRPCs: 100, + TotalMessagesSent: 50, + TotalMessagesRecv: 30, + StreamsByMethod: map[string]int{"/pkg.Svc/Chat": 2}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/grpc/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetGRPCStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var stats GRPCStats + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &stats)) + assert.Equal(t, 2, stats.ActiveStreams) + assert.Equal(t, int64(100), stats.TotalRPCs) + assert.Equal(t, 2, stats.StreamsByMethod["/pkg.Svc/Chat"]) + }) +} diff --git a/pkg/engine/api/types.go b/pkg/engine/api/types.go index e069ecfb..f55fbc85 100644 --- a/pkg/engine/api/types.go +++ b/pkg/engine/api/types.go @@ -41,6 +41,12 @@ type ( WebSocketConnection = types.WebSocketConnection WebSocketConnectionListResponse = types.WebSocketConnectionListResponse WebSocketStats = types.WebSocketStats + MQTTConnection = types.MQTTConnection + MQTTConnectionListResponse = types.MQTTConnectionListResponse + MQTTStats = types.MQTTStats + GRPCStream = types.GRPCStream + GRPCStreamListResponse = types.GRPCStreamListResponse + GRPCStats = types.GRPCStats ConfigResponse = types.ConfigResponse CustomOperationInfo = types.CustomOperationInfo CustomOperationDetail = types.CustomOperationDetail diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 66be6118..09c6eb64 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -24,6 +24,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/health" + grpcpeer "google.golang.org/grpc/peer" healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/metadata" "google.golang.org/grpc/reflection" @@ -86,6 +87,7 @@ type Server struct { startedAt time.Time log *slog.Logger templateEngine *template.Engine + streamTracker *StreamTracker // Request logging support requestLoggerMu sync.RWMutex @@ -106,6 +108,7 @@ func NewServer(config *GRPCConfig, schema *ProtoSchema) (*Server, error) { schema: schema, log: logging.Nop(), templateEngine: template.New(), + streamTracker: NewStreamTracker(), }, nil } @@ -196,6 +199,10 @@ func (s *Server) Stop(ctx context.Context, timeout time.Duration) error { s.listener = nil s.mu.Unlock() + // Cancel all tracked streams before stopping the gRPC server so that + // stream handlers return codes.Unavailable promptly. + s.streamTracker.CancelAll() + if grpcSrv != nil { // Create a channel to signal graceful stop completion done := make(chan struct{}) @@ -630,6 +637,7 @@ func (s *Server) makeStreamHandler(serviceName, methodName string) func(srv inte // handleUnary handles unary RPC calls. func (s *Server) handleUnary(_ interface{}, ctx context.Context, dec func(interface{}) error, _ grpc.UnaryServerInterceptor, serviceName, methodName string) (interface{}, error) { + s.streamTracker.RecordUnaryRPC() startTime := time.Now() fullPath := fmt.Sprintf("/%s/%s", serviceName, methodName) @@ -770,13 +778,10 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD startTime := time.Now() fullPath := fmt.Sprintf("/%s/%s", serviceName, methodName) - // Track active stream connection - if metrics.ActiveConnections != nil { - if vec, err := metrics.ActiveConnections.WithLabels("grpc"); err == nil { - vec.Inc() - defer vec.Dec() - } - } + // Register stream with tracker (replaces raw metrics.ActiveConnections) + streamID, ctx, cancel := s.streamTracker.Register(stream.Context(), fullPath, streamServerStream, peerAddr(stream.Context())) + defer cancel() + defer s.streamTracker.Unregister(streamID) // Read single request from client inputDesc := method.GetInputDescriptor() @@ -792,6 +797,7 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamServerStream, md, nil, nil, grpcErr) return grpcErr } + s.streamTracker.RecordRecv(streamID) reqMap := dynamicMessageToMap(reqMsg) @@ -803,11 +809,10 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD return err } - // Apply initial delay (context-aware for client cancellation) - ctx := stream.Context() + // Apply initial delay (context-aware for client/tracker cancellation) s.applyDelayWithContext(ctx, methodCfg.Delay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Canceled, "client cancelled during configured delay") + grpcErr := status.Error(codes.Unavailable, "stream cancelled") s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamServerStream, md, reqMap, nil, grpcErr) return grpcErr } @@ -836,6 +841,13 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD var collectedResponses []interface{} for i, respData := range responses { + // Check for tracker cancellation between messages + if ctx.Err() != nil { + grpcErr := status.Error(codes.Unavailable, "mock configuration updated") + s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamServerStream, md, reqMap, collectedResponses, grpcErr) + return grpcErr + } + resp, err := s.buildResponse(method, respData, templateCtx) if err != nil { grpcErr := status.Errorf(codes.Internal, "failed to build response: %v", err) @@ -849,6 +861,8 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD return grpcErr } + s.streamTracker.RecordSent(streamID) + // Collect for logging collectedResponses = append(collectedResponses, dynamicMessageToMap(resp)) @@ -856,7 +870,7 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD if i < len(responses)-1 { s.applyDelayWithContext(ctx, methodCfg.StreamDelay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Canceled, "client cancelled during stream delay") + grpcErr := status.Error(codes.Unavailable, "stream cancelled") s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamServerStream, md, reqMap, collectedResponses, grpcErr) return grpcErr } @@ -874,13 +888,10 @@ func (s *Server) handleClientStreaming(stream grpc.ServerStream, method *MethodD startTime := time.Now() fullPath := fmt.Sprintf("/%s/%s", serviceName, methodName) - // Track active stream connection - if metrics.ActiveConnections != nil { - if vec, err := metrics.ActiveConnections.WithLabels("grpc"); err == nil { - vec.Inc() - defer vec.Dec() - } - } + // Register stream with tracker + streamID, ctx, cancel := s.streamTracker.Register(stream.Context(), fullPath, streamClientStream, peerAddr(stream.Context())) + defer cancel() + defer s.streamTracker.Unregister(streamID) inputDesc := method.GetInputDescriptor() if inputDesc == nil { @@ -893,6 +904,13 @@ func (s *Server) handleClientStreaming(stream grpc.ServerStream, method *MethodD var allRequests []interface{} var lastReqMap map[string]interface{} for { + // Check for tracker cancellation + if ctx.Err() != nil { + grpcErr := status.Error(codes.Unavailable, "mock configuration updated") + s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamClientStream, md, allRequests, nil, grpcErr) + return grpcErr + } + reqMsg := dynamicpb.NewMessage(inputDesc) if err := stream.RecvMsg(reqMsg); err != nil { if errors.Is(err, io.EOF) { @@ -902,6 +920,7 @@ func (s *Server) handleClientStreaming(stream grpc.ServerStream, method *MethodD s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamClientStream, md, allRequests, nil, grpcErr) return grpcErr } + s.streamTracker.RecordRecv(streamID) lastReqMap = dynamicMessageToMap(reqMsg) allRequests = append(allRequests, lastReqMap) } @@ -914,11 +933,10 @@ func (s *Server) handleClientStreaming(stream grpc.ServerStream, method *MethodD return err } - // Apply delay (context-aware for client cancellation) - ctx := stream.Context() + // Apply delay (context-aware for client/tracker cancellation) s.applyDelayWithContext(ctx, methodCfg.Delay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Canceled, "client cancelled during configured delay") + grpcErr := status.Error(codes.Unavailable, "stream cancelled") s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamClientStream, md, allRequests, nil, grpcErr) return grpcErr } @@ -949,6 +967,7 @@ func (s *Server) handleClientStreaming(stream grpc.ServerStream, method *MethodD return err } + s.streamTracker.RecordSent(streamID) respMap := dynamicMessageToMap(resp) // Log the successful call @@ -962,13 +981,10 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * startTime := time.Now() fullPath := fmt.Sprintf("/%s/%s", serviceName, methodName) - // Track active stream connection - if metrics.ActiveConnections != nil { - if vec, err := metrics.ActiveConnections.WithLabels("grpc"); err == nil { - vec.Inc() - defer vec.Dec() - } - } + // Register stream with tracker + streamID, ctx, cancel := s.streamTracker.Register(stream.Context(), fullPath, streamBidi, peerAddr(stream.Context())) + defer cancel() + defer s.streamTracker.Unregister(streamID) inputDesc := method.GetInputDescriptor() if inputDesc == nil { @@ -985,11 +1001,10 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * return err } - // Apply initial delay (context-aware for client cancellation) - ctx := stream.Context() + // Apply initial delay (context-aware for client/tracker cancellation) s.applyDelayWithContext(ctx, methodCfg.Delay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Canceled, "client cancelled during configured delay") + grpcErr := status.Error(codes.Unavailable, "stream cancelled") s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamBidi, md, nil, nil, grpcErr) return grpcErr } @@ -1018,6 +1033,13 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * // Echo pattern: for each received message, send a response for { + // Check for tracker cancellation + if ctx.Err() != nil { + grpcErr := status.Error(codes.Unavailable, "mock configuration updated") + s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamBidi, md, allRequests, allResponses, grpcErr) + return grpcErr + } + reqMsg := dynamicpb.NewMessage(inputDesc) if err := stream.RecvMsg(reqMsg); err != nil { if errors.Is(err, io.EOF) { @@ -1031,6 +1053,8 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * return grpcErr } + s.streamTracker.RecordRecv(streamID) + // Collect request for logging reqMap := dynamicMessageToMap(reqMsg) allRequests = append(allRequests, reqMap) @@ -1064,13 +1088,15 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * return grpcErr } + s.streamTracker.RecordSent(streamID) + // Collect response for logging allResponses = append(allResponses, dynamicMessageToMap(resp)) respIndex++ s.applyDelayWithContext(ctx, methodCfg.StreamDelay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Canceled, "client cancelled during stream delay") + grpcErr := status.Error(codes.Unavailable, "stream cancelled") s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamBidi, md, allRequests, allResponses, grpcErr) return grpcErr } @@ -1652,6 +1678,11 @@ func (s *Server) Config() *GRPCConfig { return s.config } +// StreamTracker returns the server's stream tracker. +func (s *Server) StreamTracker() *StreamTracker { + return s.streamTracker +} + // Schema returns the proto schema. func (s *Server) Schema() *ProtoSchema { return s.schema @@ -1931,3 +1962,11 @@ func (s *Server) recordGRPCMetrics(fullPath string, grpcErr error, duration time } } } + +// peerAddr extracts the client address from a gRPC context. +func peerAddr(ctx context.Context) string { + if p, ok := grpcpeer.FromContext(ctx); ok && p.Addr != nil { + return p.Addr.String() + } + return "" +} diff --git a/pkg/grpc/stream_tracker.go b/pkg/grpc/stream_tracker.go new file mode 100644 index 00000000..b72bd920 --- /dev/null +++ b/pkg/grpc/stream_tracker.go @@ -0,0 +1,267 @@ +package grpc + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/getmockd/mockd/pkg/metrics" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// StreamInfo holds metadata about an active streaming RPC. +type StreamInfo struct { + ID string `json:"id"` + Method string `json:"method"` + StreamType streamType `json:"streamType"` + ClientAddr string `json:"clientAddr,omitempty"` + ConnectedAt time.Time `json:"connectedAt"` + MessagesSent int64 `json:"messagesSent"` + MessagesRecv int64 `json:"messagesRecv"` + + cancel context.CancelFunc +} + +// StreamStats holds aggregate statistics for gRPC streams. +type StreamStats struct { + ActiveStreams int `json:"activeStreams"` + TotalStreams int64 `json:"totalStreams"` + TotalRPCs int64 `json:"totalRPCs"` + TotalMsgSent int64 `json:"totalMessagesSent"` + TotalMsgRecv int64 `json:"totalMessagesRecv"` + StreamsByMethod map[string]int `json:"streamsByMethod"` +} + +// StreamTracker tracks active gRPC streaming RPCs for a single Server. +type StreamTracker struct { + streams map[string]*trackedStream // ID -> tracked stream + mu sync.RWMutex + + // Lifetime counters (include completed streams). + totalStreams atomic.Int64 + totalRPCs atomic.Int64 + totalMsgSent atomic.Int64 + totalMsgRecv atomic.Int64 +} + +// trackedStream extends StreamInfo with internal bookkeeping. +type trackedStream struct { + info StreamInfo + msgSent atomic.Int64 + msgRecv atomic.Int64 + cancel context.CancelFunc +} + +// NewStreamTracker creates a new StreamTracker. +func NewStreamTracker() *StreamTracker { + return &StreamTracker{ + streams: make(map[string]*trackedStream), + } +} + +// nextStreamID generates a short unique stream ID. +var streamIDCounter atomic.Int64 + +func nextStreamID() string { + return fmt.Sprintf("grpc-stream-%d", streamIDCounter.Add(1)) +} + +// Register adds a new streaming RPC and returns its ID and a cancel-aware +// context. The caller should defer Unregister. +func (t *StreamTracker) Register(ctx context.Context, method string, st streamType, clientAddr string) (string, context.Context, context.CancelFunc) { + id := nextStreamID() + ctx, cancel := context.WithCancel(ctx) + + ts := &trackedStream{ + info: StreamInfo{ + ID: id, + Method: method, + StreamType: st, + ClientAddr: clientAddr, + ConnectedAt: time.Now(), + }, + cancel: cancel, + } + + t.mu.Lock() + t.streams[id] = ts + t.mu.Unlock() + + t.totalStreams.Add(1) + + if metrics.ActiveConnections != nil { + if vec, err := metrics.ActiveConnections.WithLabels("grpc"); err == nil { + vec.Inc() + } + } + + return id, ctx, cancel +} + +// Unregister removes a stream and accumulates its message counters. +func (t *StreamTracker) Unregister(id string) { + t.mu.Lock() + ts, ok := t.streams[id] + if ok { + delete(t.streams, id) + } + t.mu.Unlock() + + if !ok { + return + } + + t.totalMsgSent.Add(ts.msgSent.Load()) + t.totalMsgRecv.Add(ts.msgRecv.Load()) + + if metrics.ActiveConnections != nil { + if vec, err := metrics.ActiveConnections.WithLabels("grpc"); err == nil { + vec.Dec() + } + } +} + +// RecordSent increments the sent counter for a stream. +func (t *StreamTracker) RecordSent(id string) { + t.mu.RLock() + ts := t.streams[id] + t.mu.RUnlock() + if ts != nil { + ts.msgSent.Add(1) + } +} + +// RecordRecv increments the received counter for a stream. +func (t *StreamTracker) RecordRecv(id string) { + t.mu.RLock() + ts := t.streams[id] + t.mu.RUnlock() + if ts != nil { + ts.msgRecv.Add(1) + } +} + +// RecordUnaryRPC increments the total RPC counter for unary calls +// (which are not tracked as active streams). +func (t *StreamTracker) RecordUnaryRPC() { + t.totalRPCs.Add(1) +} + +// Get returns info about a specific stream, or nil. +func (t *StreamTracker) Get(id string) *StreamInfo { + t.mu.RLock() + ts := t.streams[id] + t.mu.RUnlock() + if ts == nil { + return nil + } + return t.toInfo(ts) +} + +// List returns info about all active streams. +func (t *StreamTracker) List() []*StreamInfo { + t.mu.RLock() + defer t.mu.RUnlock() + + result := make([]*StreamInfo, 0, len(t.streams)) + for _, ts := range t.streams { + result = append(result, t.toInfo(ts)) + } + return result +} + +// Count returns the number of active streams. +func (t *StreamTracker) Count() int { + t.mu.RLock() + defer t.mu.RUnlock() + return len(t.streams) +} + +// Cancel cancels a specific stream's context, causing the RPC handler to +// return codes.Unavailable to the client. +func (t *StreamTracker) Cancel(id string) error { + t.mu.RLock() + ts := t.streams[id] + t.mu.RUnlock() + if ts == nil { + return fmt.Errorf("stream %s not found", id) + } + ts.cancel() + return nil +} + +// CancelAll cancels all active streams with the given gRPC status. +// Returns the number of streams cancelled. +func (t *StreamTracker) CancelAll() int { + t.mu.RLock() + ids := make([]string, 0, len(t.streams)) + for id := range t.streams { + ids = append(ids, id) + } + t.mu.RUnlock() + + count := 0 + for _, id := range ids { + if t.Cancel(id) == nil { + count++ + } + } + return count +} + +// CancelAllWithStatus cancels all active streams. The gRPC handler +// should check ctx.Err() and return the appropriate status to the client. +// Returns the number of streams cancelled. +func (t *StreamTracker) CancelAllWithStatus() int { + return t.CancelAll() +} + +// Stats returns aggregate statistics. +func (t *StreamTracker) Stats() *StreamStats { + t.mu.RLock() + defer t.mu.RUnlock() + + var liveSent, liveRecv int64 + byMethod := make(map[string]int) + for _, ts := range t.streams { + liveSent += ts.msgSent.Load() + liveRecv += ts.msgRecv.Load() + byMethod[ts.info.Method]++ + } + + return &StreamStats{ + ActiveStreams: len(t.streams), + TotalStreams: t.totalStreams.Load(), + TotalRPCs: t.totalRPCs.Load() + t.totalStreams.Load(), + TotalMsgSent: t.totalMsgSent.Load() + liveSent, + TotalMsgRecv: t.totalMsgRecv.Load() + liveRecv, + StreamsByMethod: byMethod, + } +} + +// Close cancels all streams (used during server shutdown). +func (t *StreamTracker) Close() { + t.CancelAll() +} + +func (t *StreamTracker) toInfo(ts *trackedStream) *StreamInfo { + return &StreamInfo{ + ID: ts.info.ID, + Method: ts.info.Method, + StreamType: ts.info.StreamType, + ClientAddr: ts.info.ClientAddr, + ConnectedAt: ts.info.ConnectedAt, + MessagesSent: ts.msgSent.Load(), + MessagesRecv: ts.msgRecv.Load(), + } +} + +// UnavailableError returns a gRPC Unavailable status error with the given message. +// This is the gRPC equivalent of WebSocket close code 1012 — clients with retry +// policies will reconnect automatically. +func UnavailableError(msg string) error { + return status.Error(codes.Unavailable, msg) +} diff --git a/pkg/grpc/stream_tracker_test.go b/pkg/grpc/stream_tracker_test.go new file mode 100644 index 00000000..6228d42f --- /dev/null +++ b/pkg/grpc/stream_tracker_test.go @@ -0,0 +1,193 @@ +package grpc + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStreamTracker_RegisterUnregister(t *testing.T) { + tracker := NewStreamTracker() + + id, ctx, cancel := tracker.Register(context.Background(), "/pkg.Svc/Method", streamServerStream, "127.0.0.1:1234") + defer cancel() + + assert.NotEmpty(t, id) + assert.Equal(t, 1, tracker.Count()) + + info := tracker.Get(id) + require.NotNil(t, info) + assert.Equal(t, "/pkg.Svc/Method", info.Method) + assert.Equal(t, streamServerStream, info.StreamType) + assert.Equal(t, "127.0.0.1:1234", info.ClientAddr) + assert.False(t, info.ConnectedAt.IsZero()) + + // Context should not be cancelled yet + assert.NoError(t, ctx.Err()) + + tracker.Unregister(id) + assert.Equal(t, 0, tracker.Count()) + assert.Nil(t, tracker.Get(id)) +} + +func TestStreamTracker_MessageCounting(t *testing.T) { + tracker := NewStreamTracker() + + id, _, cancel := tracker.Register(context.Background(), "/pkg.Svc/Stream", streamBidi, "") + defer cancel() + + tracker.RecordSent(id) + tracker.RecordSent(id) + tracker.RecordRecv(id) + + info := tracker.Get(id) + require.NotNil(t, info) + assert.Equal(t, int64(2), info.MessagesSent) + assert.Equal(t, int64(1), info.MessagesRecv) + + // Verify stats include live counters + stats := tracker.Stats() + assert.Equal(t, int64(2), stats.TotalMsgSent) + assert.Equal(t, int64(1), stats.TotalMsgRecv) + assert.Equal(t, 1, stats.ActiveStreams) + + // After unregister, counters roll into lifetime totals + tracker.Unregister(id) + stats = tracker.Stats() + assert.Equal(t, 0, stats.ActiveStreams) + assert.Equal(t, int64(2), stats.TotalMsgSent) + assert.Equal(t, int64(1), stats.TotalMsgRecv) +} + +func TestStreamTracker_Cancel(t *testing.T) { + tracker := NewStreamTracker() + + id, ctx, cancel := tracker.Register(context.Background(), "/pkg.Svc/Method", streamServerStream, "") + defer cancel() + + // Cancel the stream + err := tracker.Cancel(id) + assert.NoError(t, err) + + // Context should be cancelled + assert.Error(t, ctx.Err()) + + // Cancel non-existent stream + err = tracker.Cancel("non-existent") + assert.Error(t, err) +} + +func TestStreamTracker_CancelAll(t *testing.T) { + tracker := NewStreamTracker() + + _, ctx1, cancel1 := tracker.Register(context.Background(), "/pkg.Svc/A", streamServerStream, "") + defer cancel1() + _, ctx2, cancel2 := tracker.Register(context.Background(), "/pkg.Svc/B", streamBidi, "") + defer cancel2() + + assert.Equal(t, 2, tracker.Count()) + + n := tracker.CancelAll() + assert.Equal(t, 2, n) + + assert.Error(t, ctx1.Err()) + assert.Error(t, ctx2.Err()) +} + +func TestStreamTracker_List(t *testing.T) { + tracker := NewStreamTracker() + + id1, _, cancel1 := tracker.Register(context.Background(), "/pkg.Svc/A", streamServerStream, "1.2.3.4:100") + defer cancel1() + id2, _, cancel2 := tracker.Register(context.Background(), "/pkg.Svc/B", streamClientStream, "5.6.7.8:200") + defer cancel2() + + list := tracker.List() + assert.Len(t, list, 2) + + ids := map[string]bool{} + for _, info := range list { + ids[info.ID] = true + } + assert.True(t, ids[id1]) + assert.True(t, ids[id2]) +} + +func TestStreamTracker_Stats(t *testing.T) { + tracker := NewStreamTracker() + + // Record some unary RPCs + tracker.RecordUnaryRPC() + tracker.RecordUnaryRPC() + + // Register a stream + id, _, cancel := tracker.Register(context.Background(), "/pkg.Svc/Stream", streamServerStream, "") + defer cancel() + tracker.RecordSent(id) + + stats := tracker.Stats() + assert.Equal(t, 1, stats.ActiveStreams) + assert.Equal(t, int64(1), stats.TotalStreams) + // TotalRPCs = unary (2) + streams (1) = 3 + assert.Equal(t, int64(3), stats.TotalRPCs) + assert.Equal(t, int64(1), stats.TotalMsgSent) + assert.Equal(t, map[string]int{"/pkg.Svc/Stream": 1}, stats.StreamsByMethod) +} + +func TestStreamTracker_RecordOnNonExistent(t *testing.T) { + tracker := NewStreamTracker() + + // Should not panic + tracker.RecordSent("non-existent") + tracker.RecordRecv("non-existent") +} + +func TestStreamTracker_UnregisterIdempotent(t *testing.T) { + tracker := NewStreamTracker() + + id, _, cancel := tracker.Register(context.Background(), "/pkg.Svc/Method", streamBidi, "") + defer cancel() + + tracker.Unregister(id) + tracker.Unregister(id) // Should not panic + assert.Equal(t, 0, tracker.Count()) +} + +func TestStreamTracker_Close(t *testing.T) { + tracker := NewStreamTracker() + + _, ctx, cancel := tracker.Register(context.Background(), "/pkg.Svc/Method", streamBidi, "") + defer cancel() + + tracker.Close() + assert.Error(t, ctx.Err()) +} + +func TestStreamTracker_ConcurrentAccess(t *testing.T) { + tracker := NewStreamTracker() + ctx := context.Background() + + done := make(chan struct{}) + go func() { + defer close(done) + for i := 0; i < 100; i++ { + id, _, cancel := tracker.Register(ctx, "/pkg.Svc/Method", streamBidi, "") + tracker.RecordSent(id) + tracker.RecordRecv(id) + time.Sleep(time.Microsecond) + tracker.Unregister(id) + cancel() + } + }() + + for i := 0; i < 100; i++ { + _ = tracker.List() + _ = tracker.Stats() + _ = tracker.Count() + } + + <-done +} From 84435302b755b6c296f47a85c26ff5bb9ea03a85 Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Sun, 5 Apr 2026 18:53:44 -0500 Subject: [PATCH 08/18] feat: implement SSE auto-disconnect on mock update and delete When SSE mocks are updated, deleted, toggled, or cleared, active SSE connections now get disconnected automatically. This prevents clients from streaming stale data after configuration changes. Changes: - Add DisconnectSSEByMock() to Handler (handler_protocol.go) which delegates to SSEConnectionManager.CloseByMock() - Add case mock.TypeHTTP in unregisterHandlerLocked (mock_manager.go) that disconnects SSE connections when the HTTP mock has SSE config - Add 6 integration tests covering update, delete, toggle, clear, multi-connection, and cross-mock isolation scenarios The SSE close mechanism uses context cancellation: cancelling each connection's context causes the event loop to exit, the response writer to close, and the client to receive EOF. EventSource clients will auto-reconnect and pick up the new configuration. Co-Authored-By: Claude Opus 4.6 --- pkg/engine/handler_protocol.go | 15 ++ tests/integration/sse_test.go | 302 +++++++++++++++++++++++++++++++++ 2 files changed, 317 insertions(+) create mode 100644 tests/integration/sse_test.go diff --git a/pkg/engine/handler_protocol.go b/pkg/engine/handler_protocol.go index c5b25861..1663667e 100644 --- a/pkg/engine/handler_protocol.go +++ b/pkg/engine/handler_protocol.go @@ -157,6 +157,21 @@ func (h *Handler) DisconnectWebSocketEndpoint(path string) { h.wsManager.DisconnectByEndpoint(path, websocket.CloseServiceRestart, "mock updated") } +// DisconnectSSEByMock closes all active SSE connections for the given mock ID. +// Cancelling each connection's context causes the event loop to exit, the response +// writer to close, and the client to receive EOF. EventSource clients will +// auto-reconnect and pick up the new configuration. +func (h *Handler) DisconnectSSEByMock(mockID string) { + if h.sseHandler == nil { + return + } + mgr := h.sseHandler.GetManager() + if mgr == nil { + return + } + mgr.CloseByMock(mockID) +} + // ListSOAPHandlerPaths returns all registered SOAP handler paths. func (h *Handler) ListSOAPHandlerPaths() []string { h.soapMu.RLock() diff --git a/tests/integration/sse_test.go b/tests/integration/sse_test.go new file mode 100644 index 00000000..b0895c17 --- /dev/null +++ b/tests/integration/sse_test.go @@ -0,0 +1,302 @@ +package integration + +import ( + "bufio" + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/getmockd/mockd/pkg/config" + "github.com/getmockd/mockd/pkg/engine" + "github.com/getmockd/mockd/pkg/mock" + "github.com/getmockd/mockd/pkg/sse" +) + +// ============================================================================ +// SSE Test Helpers +// ============================================================================ + +// setupSSEMockServer creates a test server and adapter for SSE mock management. +func setupSSEMockServer(t *testing.T) (*httptest.Server, *engine.ControlAPIAdapter) { + t.Helper() + cfg := config.DefaultServerConfiguration() + srv := engine.NewServer(cfg) + adapter := engine.NewControlAPIAdapter(srv) + + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(func() { ts.Close() }) + return ts, adapter +} + +// getSSEManager extracts the SSE ConnectionManager from the test server. +func getSSEManager(t *testing.T, ts *httptest.Server) *sse.SSEConnectionManager { + t.Helper() + handler := ts.Config.Handler + require.NotNil(t, handler, "handler is nil") + + engineHandler, ok := handler.(*engine.Handler) + require.True(t, ok, "handler is not *engine.Handler") + + sseHandler := engineHandler.SSEHandler() + require.NotNil(t, sseHandler, "SSE handler is nil") + + mgr := sseHandler.GetManager() + require.NotNil(t, mgr, "SSE connection manager is nil") + return mgr +} + +// newSSEMock creates a minimal SSE mock configuration that streams events +// with a keepalive to hold the connection open. +func newSSEMock(path string) *mock.Mock { + return &mock.Mock{ + Type: mock.TypeHTTP, + HTTP: &mock.HTTPSpec{ + Matcher: &mock.HTTPMatcher{ + Method: "GET", + Path: path, + }, + SSE: &mock.SSEConfig{ + Generator: &mock.SSEEventGenerator{ + Type: "sequence", + Count: 0, // infinite + Sequence: &mock.SSESequenceGenerator{ + Start: 1, + Increment: 1, + }, + }, + Lifecycle: mock.SSELifecycleConfig{ + KeepaliveInterval: 300, // 5min keepalive keeps connection alive + }, + }, + }, + } +} + +// connectSSE opens an SSE connection and waits until the first event is received, +// confirming the server has registered the connection. Returns a cancel func +// that closes the connection. +func connectSSE(t *testing.T, url string) context.CancelFunc { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + req.Header.Set("Accept", "text/event-stream") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Read until we get at least one SSE event (data: line), confirming connection is tracked. + ready := make(chan struct{}) + go func() { + scanner := bufio.NewScanner(resp.Body) + signalled := false + for scanner.Scan() { + line := scanner.Text() + if !signalled && len(line) > 0 { + close(ready) + signalled = true + } + } + // Body closed by context cancellation + }() + + select { + case <-ready: + // Connection is active and streaming + case <-time.After(5 * time.Second): + cancel() + t.Fatal("timeout waiting for SSE event") + } + + return cancel +} + +// waitForSSECount polls until the SSE manager reports the expected connection count +// or the timeout expires. +func waitForSSECount(t *testing.T, mgr *sse.SSEConnectionManager, expected int, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if mgr.Count() == expected { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("expected %d SSE connections, got %d", expected, mgr.Count()) +} + +// ============================================================================ +// SSE Auto-Disconnect Tests +// ============================================================================ + +// TestSSE_MockUpdate_ClosesActiveConnections verifies that updating an SSE mock +// disconnects all active SSE clients so they reconnect with the new configuration. +func TestSSE_MockUpdate_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupSSEMockServer(t) + mgr := getSSEManager(t, ts) + + // Create an SSE mock. + mockCfg := newSSEMock("/sse/update-test") + require.NoError(t, adapter.AddMock(mockCfg)) + + // Connect an SSE client. + cancelSSE := connectSSE(t, ts.URL+"/sse/update-test") + defer cancelSSE() + + // Wait for the connection to be registered. + waitForSSECount(t, mgr, 1, 2*time.Second) + + // Update the mock — this should disconnect all active SSE clients. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + updated := *mocks[0] + updated.HTTP.SSE.Generator.Sequence.Start = 100 + require.NoError(t, adapter.UpdateMock(updated.ID, &updated)) + + // Connections should be closed. + waitForSSECount(t, mgr, 0, 2*time.Second) +} + +// TestSSE_MockDelete_ClosesActiveConnections verifies that deleting an SSE mock +// disconnects all active SSE clients. +func TestSSE_MockDelete_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupSSEMockServer(t) + mgr := getSSEManager(t, ts) + + mockCfg := newSSEMock("/sse/delete-test") + require.NoError(t, adapter.AddMock(mockCfg)) + + cancelSSE := connectSSE(t, ts.URL+"/sse/delete-test") + defer cancelSSE() + + waitForSSECount(t, mgr, 1, 2*time.Second) + + // Delete the mock — SSE clients must be disconnected. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + require.NoError(t, adapter.DeleteMock(mocks[0].ID)) + + waitForSSECount(t, mgr, 0, 2*time.Second) +} + +// TestSSE_MockToggleDisable_ClosesActiveConnections verifies that disabling an SSE +// mock (toggle to enabled=false) disconnects all active SSE clients. +func TestSSE_MockToggleDisable_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupSSEMockServer(t) + mgr := getSSEManager(t, ts) + + mockCfg := newSSEMock("/sse/toggle-test") + require.NoError(t, adapter.AddMock(mockCfg)) + + cancelSSE := connectSSE(t, ts.URL+"/sse/toggle-test") + defer cancelSSE() + + waitForSSECount(t, mgr, 1, 2*time.Second) + + // Disable the mock — SSE clients must be disconnected. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + disabled := *mocks[0] + enabled := false + disabled.Enabled = &enabled + require.NoError(t, adapter.UpdateMock(disabled.ID, &disabled)) + + waitForSSECount(t, mgr, 0, 2*time.Second) +} + +// TestSSE_ClearMocks_ClosesActiveConnections verifies that clearing all mocks +// disconnects all active SSE clients. +func TestSSE_ClearMocks_ClosesActiveConnections(t *testing.T) { + ts, adapter := setupSSEMockServer(t) + mgr := getSSEManager(t, ts) + + mockCfg := newSSEMock("/sse/clear-test") + require.NoError(t, adapter.AddMock(mockCfg)) + + cancelSSE := connectSSE(t, ts.URL+"/sse/clear-test") + defer cancelSSE() + + waitForSSECount(t, mgr, 1, 2*time.Second) + + // Clear all mocks — SSE clients must be disconnected. + adapter.ClearMocks() + + waitForSSECount(t, mgr, 0, 2*time.Second) +} + +// TestSSE_MultipleConnections_AllDisconnected verifies that multiple SSE clients +// connected to the same mock are all disconnected when the mock is deleted. +func TestSSE_MultipleConnections_AllDisconnected(t *testing.T) { + ts, adapter := setupSSEMockServer(t) + mgr := getSSEManager(t, ts) + + mockCfg := newSSEMock("/sse/multi-test") + require.NoError(t, adapter.AddMock(mockCfg)) + + // Connect 3 SSE clients. + for i := 0; i < 3; i++ { + cancel := connectSSE(t, ts.URL+"/sse/multi-test") + defer cancel() + } + + waitForSSECount(t, mgr, 3, 2*time.Second) + + // Delete the mock — all 3 clients must be disconnected. + mocks := adapter.ListMocks() + require.Len(t, mocks, 1) + require.NoError(t, adapter.DeleteMock(mocks[0].ID)) + + waitForSSECount(t, mgr, 0, 2*time.Second) +} + +// TestSSE_UpdateDoesNotAffectOtherMocks verifies that updating one SSE mock +// only disconnects clients on that mock, not clients on other SSE mocks. +func TestSSE_UpdateDoesNotAffectOtherMocks(t *testing.T) { + ts, adapter := setupSSEMockServer(t) + mgr := getSSEManager(t, ts) + + // Create two SSE mocks. + mock1 := newSSEMock("/sse/mock1") + require.NoError(t, adapter.AddMock(mock1)) + mock2 := newSSEMock("/sse/mock2") + require.NoError(t, adapter.AddMock(mock2)) + + // Connect a client to each. + cancel1 := connectSSE(t, ts.URL+"/sse/mock1") + defer cancel1() + cancel2 := connectSSE(t, ts.URL+"/sse/mock2") + defer cancel2() + + waitForSSECount(t, mgr, 2, 2*time.Second) + + // Update mock1 — only mock1 connections should drop. + mocks := adapter.ListMocks() + require.Len(t, mocks, 2) + + var target *mock.Mock + for _, m := range mocks { + if m.HTTP != nil && m.HTTP.Matcher != nil && m.HTTP.Matcher.Path == "/sse/mock1" { + target = m + break + } + } + require.NotNil(t, target, "could not find /sse/mock1 mock") + + updated := *target + updated.HTTP.SSE.Generator.Sequence.Start = 100 + require.NoError(t, adapter.UpdateMock(updated.ID, &updated)) + + // mock1 connections should be closed, mock2 should remain. + // Wait for mock1 connections to drop. + time.Sleep(100 * time.Millisecond) + assert.Equal(t, 1, mgr.Count(), "expected 1 remaining SSE connection (mock2)") +} From b6c508b7a557540849ded456ecf40e0ad58812e5 Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Sun, 5 Apr 2026 18:53:54 -0500 Subject: [PATCH 09/18] feat: align SSE response types and add engine tests SSE API responses now return full connection details (path, clientIp, userAgent, connectedAt, eventsSent, bytesSent, status) instead of the slimmed SSEStreamInfo (ID, MockID only). This aligns SSE with the WebSocket API which already returns rich connection info. Changes: - SSE list, get-single, and mock-specific list endpoints now return full SSEConnection structs from engineclient (admin/sse_handlers.go) - Add mock-specific WebSocket connection endpoints for parity with SSE: GET /mocks/{id}/websocket/connections and DELETE /mocks/{id}/websocket/connections (admin/websocket_handlers.go) - Register the new WebSocket routes (admin/routes.go) - Add engine-level SSE handler tests covering list, get, close, and stats handlers (engine/api/handlers_test.go) Co-Authored-By: Claude Opus 4.6 --- pkg/admin/routes.go | 4 + pkg/admin/sse_handlers.go | 50 +++----- pkg/admin/websocket_handlers.go | 96 ++++++++++++++- pkg/admin/websocket_handlers_test.go | 56 +++++++++ pkg/engine/api/handlers_test.go | 178 +++++++++++++++++++++++++++ 5 files changed, 353 insertions(+), 31 deletions(-) diff --git a/pkg/admin/routes.go b/pkg/admin/routes.go index 492c955c..eb84396a 100644 --- a/pkg/admin/routes.go +++ b/pkg/admin/routes.go @@ -147,6 +147,10 @@ func (a *API) registerRoutes(mux *http.ServeMux) { // Mock-specific SSE endpoints mux.HandleFunc("GET /mocks/{id}/sse/connections", a.requireEngine(a.handleListMockSSEConnections)) mux.HandleFunc("DELETE /mocks/{id}/sse/connections", a.requireEngine(a.handleCloseMockSSEConnections)) + + // Mock-specific WebSocket endpoints + mux.HandleFunc("GET /mocks/{id}/websocket/connections", a.requireEngine(a.handleListMockWebSocketConnections)) + mux.HandleFunc("DELETE /mocks/{id}/websocket/connections", a.requireEngine(a.handleCloseMockWebSocketConnections)) mux.HandleFunc("GET /mocks/{id}/sse/buffer", a.handleGetMockSSEBuffer) mux.HandleFunc("DELETE /mocks/{id}/sse/buffer", a.handleClearMockSSEBuffer) diff --git a/pkg/admin/sse_handlers.go b/pkg/admin/sse_handlers.go index 609f38c3..07640c76 100644 --- a/pkg/admin/sse_handlers.go +++ b/pkg/admin/sse_handlers.go @@ -6,14 +6,13 @@ import ( "net/http" "github.com/getmockd/mockd/pkg/admin/engineclient" - "github.com/getmockd/mockd/pkg/sse" "github.com/getmockd/mockd/pkg/store" ) // SSEConnectionListResponse represents a list of SSE connections. type SSEConnectionListResponse struct { - Connections []sse.SSEStreamInfo `json:"connections"` - Stats sse.ConnectionStats `json:"stats"` + Connections []*engineclient.SSEConnection `json:"connections"` + Stats engineclient.SSEStats `json:"stats"` } // handleListSSEConnections handles GET /sse/connections. @@ -23,8 +22,8 @@ func (a *API) handleListSSEConnections(w http.ResponseWriter, r *http.Request) { engine := a.localEngine.Load() if engine == nil { writeJSON(w, http.StatusOK, SSEConnectionListResponse{ - Connections: []sse.SSEStreamInfo{}, - Stats: sse.ConnectionStats{ConnectionsByMock: make(map[string]int)}, + Connections: []*engineclient.SSEConnection{}, + Stats: engineclient.SSEStats{ConnectionsByMock: make(map[string]int)}, }) return } @@ -45,14 +44,8 @@ func (a *API) handleListSSEConnections(w http.ResponseWriter, r *http.Request) { return } - // Convert engine client connections to SSEStreamInfo - info := make([]sse.SSEStreamInfo, 0, len(connections)) - for _, conn := range connections { - info = append(info, sse.SSEStreamInfo{ - ID: conn.ID, - MockID: conn.MockID, - ClientIP: conn.ClientIP, - }) + if connections == nil { + connections = []*engineclient.SSEConnection{} } connsByMock := stats.ConnectionsByMock @@ -61,12 +54,13 @@ func (a *API) handleListSSEConnections(w http.ResponseWriter, r *http.Request) { } writeJSON(w, http.StatusOK, SSEConnectionListResponse{ - Connections: info, - Stats: sse.ConnectionStats{ - ActiveConnections: stats.ActiveConnections, + Connections: connections, + Stats: engineclient.SSEStats{ TotalConnections: stats.TotalConnections, + ActiveConnections: stats.ActiveConnections, TotalEventsSent: stats.TotalEventsSent, TotalBytesSent: stats.TotalBytesSent, + ConnectionErrors: stats.ConnectionErrors, ConnectionsByMock: connsByMock, }, }) @@ -99,12 +93,7 @@ func (a *API) handleGetSSEConnection(w http.ResponseWriter, r *http.Request) { return } - info := sse.SSEStreamInfo{ - ID: conn.ID, - MockID: conn.MockID, - } - - writeJSON(w, http.StatusOK, info) + writeJSON(w, http.StatusOK, conn) } // handleCloseSSEConnection handles DELETE /sse/connections/{id}. @@ -151,7 +140,7 @@ func (a *API) handleCloseSSEConnection(w http.ResponseWriter, r *http.Request) { func (a *API) handleGetSSEStats(w http.ResponseWriter, r *http.Request) { engine := a.localEngine.Load() if engine == nil { - writeJSON(w, http.StatusOK, sse.ConnectionStats{ConnectionsByMock: make(map[string]int)}) + writeJSON(w, http.StatusOK, engineclient.SSEStats{ConnectionsByMock: make(map[string]int)}) return } a.handleGetStats(w, r, newSSEStatsProvider(engine)) @@ -185,19 +174,20 @@ func (a *API) handleListMockSSEConnections(w http.ResponseWriter, r *http.Reques return } - info := make([]sse.SSEStreamInfo, 0) + var filtered []*engineclient.SSEConnection for _, conn := range connections { if conn.MockID == mockID { - info = append(info, sse.SSEStreamInfo{ - ID: conn.ID, - MockID: conn.MockID, - }) + filtered = append(filtered, conn) } } + if filtered == nil { + filtered = []*engineclient.SSEConnection{} + } + writeJSON(w, http.StatusOK, map[string]interface{}{ - "connections": info, - "count": len(info), + "connections": filtered, + "count": len(filtered), "mockId": mockID, }) } diff --git a/pkg/admin/websocket_handlers.go b/pkg/admin/websocket_handlers.go index 3d44eb2e..8bc05436 100644 --- a/pkg/admin/websocket_handlers.go +++ b/pkg/admin/websocket_handlers.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/getmockd/mockd/pkg/admin/engineclient" + "github.com/getmockd/mockd/pkg/store" ) // WebSocketConnectionListResponse represents a list of WebSocket connections with stats. @@ -157,12 +158,16 @@ func (a *API) handleSendToWebSocketConnection(w http.ResponseWriter, r *http.Req var req WebSocketSendRequest if err := decodeOptionalJSONBody(r, &req); err != nil { - writeError(w, http.StatusBadRequest, "invalid_body", "Invalid JSON in request body") + writeJSONDecodeError(w, err, a.logger()) return } if req.Type == "" { req.Type = "text" } + if req.Type != "text" && req.Type != "binary" { + writeError(w, http.StatusBadRequest, "invalid_type", `Type must be "text" or "binary"`) + return + } engine := a.localEngine.Load() if engine == nil { @@ -188,3 +193,92 @@ func (a *API) handleSendToWebSocketConnection(w http.ResponseWriter, r *http.Req "type": req.Type, }) } + +// handleListMockWebSocketConnections handles GET /mocks/{id}/websocket/connections. +func (a *API) handleListMockWebSocketConnections(w http.ResponseWriter, r *http.Request, engine *engineclient.Client) { + ctx := r.Context() + mockID := r.PathValue("id") + if mockID == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Mock ID is required") + return + } + + // Verify mock exists in the admin store (single source of truth). + if mockStore := a.getMockStore(); mockStore != nil { + if _, err := mockStore.Get(ctx, mockID); err != nil { + if errors.Is(err, store.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Mock not found") + return + } + writeError(w, http.StatusInternalServerError, "store_error", ErrMsgInternalError) + return + } + } + + // Get all WebSocket connections and filter by mock + connections, err := engine.ListWebSocketConnections(ctx) + if err != nil { + writeError(w, http.StatusServiceUnavailable, "engine_error", sanitizeEngineError(err, a.logger(), "list WebSocket connections for mock")) + return + } + + var filtered []*engineclient.WebSocketConnection + for _, conn := range connections { + if conn.MockID == mockID { + filtered = append(filtered, conn) + } + } + if filtered == nil { + filtered = []*engineclient.WebSocketConnection{} + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "connections": filtered, + "count": len(filtered), + "mockId": mockID, + }) +} + +// handleCloseMockWebSocketConnections handles DELETE /mocks/{id}/websocket/connections. +func (a *API) handleCloseMockWebSocketConnections(w http.ResponseWriter, r *http.Request, engine *engineclient.Client) { + ctx := r.Context() + mockID := r.PathValue("id") + if mockID == "" { + writeError(w, http.StatusBadRequest, "missing_id", "Mock ID is required") + return + } + + // Verify mock exists in the admin store (single source of truth). + if mockStore := a.getMockStore(); mockStore != nil { + if _, err := mockStore.Get(ctx, mockID); err != nil { + if errors.Is(err, store.ErrNotFound) { + writeError(w, http.StatusNotFound, "not_found", "Mock not found") + return + } + writeError(w, http.StatusInternalServerError, "store_error", ErrMsgInternalError) + return + } + } + + // Get all WebSocket connections and close those for this mock + connections, err := engine.ListWebSocketConnections(ctx) + if err != nil { + writeError(w, http.StatusServiceUnavailable, "engine_error", sanitizeEngineError(err, a.logger(), "list WebSocket connections for close")) + return + } + + closed := 0 + for _, conn := range connections { + if conn.MockID == mockID { + if err := engine.CloseWebSocketConnection(ctx, conn.ID); err == nil { + closed++ + } + } + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Connections closed", + "closed": closed, + "mockId": mockID, + }) +} diff --git a/pkg/admin/websocket_handlers_test.go b/pkg/admin/websocket_handlers_test.go index 51089c3d..d611a5c8 100644 --- a/pkg/admin/websocket_handlers_test.go +++ b/pkg/admin/websocket_handlers_test.go @@ -35,6 +35,43 @@ func TestHandleListWebSocketConnections_NoEngine_ReturnsEmptyList(t *testing.T) assert.NotNil(t, resp.Stats.ConnectionsByMock) } +func TestHandleListWebSocketConnections_WithEngine_ReturnsConnectionsAndStats(t *testing.T) { + // Spin up a mock engine that returns connections and stats. + mockEngine := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == http.MethodGet && r.URL.Path == "/websocket/stats": + _, _ = w.Write([]byte(`{"totalConnections":5,"activeConnections":2,"totalMessagesSent":100,"totalMessagesRecv":50,"connectionsByMock":{"mock-1":2}}`)) + case r.Method == http.MethodGet && r.URL.Path == "/websocket/connections": + _, _ = w.Write([]byte(`{"connections":[{"id":"conn-1","path":"/ws","connectedAt":"2026-04-05T00:00:00Z","messagesSent":10,"messagesRecv":5,"status":"connected"}],"count":1}`)) + default: + http.NotFound(w, r) + } + })) + defer mockEngine.Close() + + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New(mockEngine.URL))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodGet, "/websocket/connections", nil) + rec := httptest.NewRecorder() + + api.handleListWebSocketConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp WebSocketConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Len(t, resp.Connections, 1) + assert.Equal(t, "conn-1", resp.Connections[0].ID) + assert.Equal(t, "/ws", resp.Connections[0].Path) + assert.Equal(t, int64(10), resp.Connections[0].MessagesSent) + assert.Equal(t, 2, resp.Stats.ActiveConnections) + assert.Equal(t, int64(5), resp.Stats.TotalConnections) + assert.Equal(t, 2, resp.Stats.ConnectionsByMock["mock-1"]) +} + func TestHandleListWebSocketConnections_EngineUnavailable_Returns503(t *testing.T) { api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) defer func() { _ = api.Stop() }() @@ -196,6 +233,25 @@ func TestHandleSendToWebSocketConnection_InvalidJSON_Returns400(t *testing.T) { assert.Equal(t, http.StatusBadRequest, rec.Code) } +func TestHandleSendToWebSocketConnection_InvalidType_Returns400(t *testing.T) { + api := NewAPI(0, WithDataDir(t.TempDir()), WithLocalEngineClient(engineclient.New("http://127.0.0.1:1"))) + defer func() { _ = api.Stop() }() + + req := httptest.NewRequest(http.MethodPost, "/websocket/connections/conn-1/send", + strings.NewReader(`{"type":"invalid","data":"hello"}`)) + req.SetPathValue("id", "conn-1") + rec := httptest.NewRecorder() + + api.handleSendToWebSocketConnection(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) + + var resp map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Equal(t, "invalid_type", resp["error"]) +} + func TestHandleSendToWebSocketConnection_NoEngine_Returns404(t *testing.T) { api := NewAPI(0, WithDataDir(t.TempDir())) defer func() { _ = api.Stop() }() diff --git a/pkg/engine/api/handlers_test.go b/pkg/engine/api/handlers_test.go index 81e7d3a6..59e90269 100644 --- a/pkg/engine/api/handlers_test.go +++ b/pkg/engine/api/handlers_test.go @@ -2934,3 +2934,181 @@ func TestHandleGetGRPCStats(t *testing.T) { assert.Equal(t, 2, stats.StreamsByMethod["/pkg.Svc/Chat"]) }) } + +// ============================================================================ +// SSE Handler Tests +// ============================================================================ + +// TestHandleListSSEConnections tests the GET /sse/connections handler. +func TestHandleListSSEConnections(t *testing.T) { + t.Run("returns empty list when no connections", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/sse/connections", nil) + rec := httptest.NewRecorder() + + server.handleListSSEConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp SSEConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Empty(t, resp.Connections) + assert.Equal(t, 0, resp.Count) + }) + + t.Run("returns all connections", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.sseConnections = []*SSEConnection{ + {ID: "sse-1", MockID: "mock-1", Path: "/events", Status: "active"}, + {ID: "sse-2", MockID: "mock-1", Path: "/events", Status: "active"}, + } + + req := httptest.NewRequest(http.MethodGet, "/sse/connections", nil) + rec := httptest.NewRecorder() + + server.handleListSSEConnections(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp SSEConnectionListResponse + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Len(t, resp.Connections, 2) + assert.Equal(t, 2, resp.Count) + }) +} + +// TestHandleGetSSEConnection tests the GET /sse/connections/{id} handler. +func TestHandleGetSSEConnection(t *testing.T) { + t.Run("returns connection when found", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.sseConnections = []*SSEConnection{ + {ID: "sse-1", MockID: "mock-1", Path: "/events", ClientIP: "127.0.0.1", Status: "active", EventsSent: 42, BytesSent: 1024}, + } + + req := httptest.NewRequest(http.MethodGet, "/sse/connections/sse-1", nil) + req.SetPathValue("id", "sse-1") + rec := httptest.NewRecorder() + + server.handleGetSSEConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var conn SSEConnection + err := json.Unmarshal(rec.Body.Bytes(), &conn) + require.NoError(t, err) + assert.Equal(t, "sse-1", conn.ID) + assert.Equal(t, "mock-1", conn.MockID) + assert.Equal(t, "/events", conn.Path) + assert.Equal(t, "127.0.0.1", conn.ClientIP) + assert.Equal(t, int64(42), conn.EventsSent) + assert.Equal(t, int64(1024), conn.BytesSent) + }) + + t.Run("returns 404 for unknown connection", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/sse/connections/unknown", nil) + req.SetPathValue("id", "unknown") + rec := httptest.NewRecorder() + + server.handleGetSSEConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +// TestHandleCloseSSEConnection tests the DELETE /sse/connections/{id} handler. +func TestHandleCloseSSEConnection(t *testing.T) { + t.Run("closes existing connection and returns 200", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + engine.sseConnections = []*SSEConnection{ + {ID: "sse-1", MockID: "mock-1", Path: "/events", Status: "active"}, + } + + req := httptest.NewRequest(http.MethodDelete, "/sse/connections/sse-1", nil) + req.SetPathValue("id", "sse-1") + rec := httptest.NewRecorder() + + server.handleCloseSSEConnection(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp map[string]string + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Equal(t, "SSE connection closed", resp["message"]) + assert.Equal(t, "sse-1", resp["id"]) + + // Connection must be removed from the engine. + assert.Empty(t, engine.sseConnections) + }) + + t.Run("returns 404 for unknown connection", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodDelete, "/sse/connections/unknown", nil) + req.SetPathValue("id", "unknown") + rec := httptest.NewRecorder() + + server.handleCloseSSEConnection(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +// TestHandleGetSSEStats tests the GET /sse/stats handler. +func TestHandleGetSSEStats(t *testing.T) { + t.Run("returns empty stats with non-nil map when sseStats is nil", func(t *testing.T) { + engine := newMockEngine() + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/sse/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetSSEStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var stats SSEStats + err := json.Unmarshal(rec.Body.Bytes(), &stats) + require.NoError(t, err) + assert.NotNil(t, stats.ConnectionsByMock) + }) + + t.Run("returns populated stats", func(t *testing.T) { + engine := newMockEngine() + engine.sseStats = &SSEStats{ + TotalConnections: 10, + ActiveConnections: 2, + TotalEventsSent: 500, + TotalBytesSent: 50000, + ConnectionErrors: 1, + ConnectionsByMock: map[string]int{"mock-1": 2}, + } + server := newTestServer(engine) + + req := httptest.NewRequest(http.MethodGet, "/sse/stats", nil) + rec := httptest.NewRecorder() + + server.handleGetSSEStats(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var stats SSEStats + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &stats)) + assert.Equal(t, int64(10), stats.TotalConnections) + assert.Equal(t, 2, stats.ActiveConnections) + assert.Equal(t, int64(500), stats.TotalEventsSent) + assert.Equal(t, int64(50000), stats.TotalBytesSent) + assert.Equal(t, int64(1), stats.ConnectionErrors) + assert.Equal(t, 2, stats.ConnectionsByMock["mock-1"]) + }) +} From 26ec035d5c7ca84fc719a80e5a0b5a58c352faf1 Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Sun, 5 Apr 2026 19:00:03 -0500 Subject: [PATCH 10/18] feat: add CLI connection management commands for all protocols Add `connections` subcommands (list, get, close) and `stats` commands for WebSocket, SSE, MQTT, and gRPC protocols. WebSocket also gets a `connections send` command with --binary flag support. New files: ws_connections.go, sse.go, mqtt_connections.go, grpc_connections.go Extended: client.go (16 new AdminClient interface methods + implementations), print.go (formatStringSlice helper) Co-Authored-By: Claude Opus 4.6 --- pkg/cli/client.go | 348 ++++++++++++++++++++++++++++++++++++ pkg/cli/grpc_connections.go | 133 ++++++++++++++ pkg/cli/mqtt_connections.go | 142 +++++++++++++++ pkg/cli/print.go | 14 +- pkg/cli/sse.go | 182 +++++++++++++++++++ pkg/cli/ws_connections.go | 172 ++++++++++++++++++ 6 files changed, 990 insertions(+), 1 deletion(-) create mode 100644 pkg/cli/grpc_connections.go create mode 100644 pkg/cli/mqtt_connections.go create mode 100644 pkg/cli/sse.go create mode 100644 pkg/cli/ws_connections.go diff --git a/pkg/cli/client.go b/pkg/cli/client.go index 7cfb25ec..b835dfc5 100644 --- a/pkg/cli/client.go +++ b/pkg/cli/client.go @@ -143,6 +143,45 @@ type AdminClient interface { AddEngineWorkspace(engineID, workspaceID, workspaceName string) error // BulkCreateMocks creates multiple mocks in a single request. BulkCreateMocks(mocks []*mock.Mock, workspaceID string) (*BulkCreateResult, error) + + // Connection management + // ListWebSocketConnections returns active WebSocket connections. + ListWebSocketConnections() (*apitypes.WebSocketConnectionListResponse, error) + // GetWebSocketConnection returns a specific WebSocket connection. + GetWebSocketConnection(id string) (*apitypes.WebSocketConnection, error) + // CloseWebSocketConnection closes a WebSocket connection. + CloseWebSocketConnection(id string) error + // SendWebSocketMessage sends a message to a WebSocket connection. + SendWebSocketMessage(id string, message string, binary bool) error + // GetWebSocketStats returns WebSocket statistics. + GetWebSocketStats() (*apitypes.WebSocketStats, error) + + // ListSSEConnections returns active SSE connections. + ListSSEConnections() (*apitypes.SSEConnectionListResponse, error) + // GetSSEConnection returns a specific SSE connection. + GetSSEConnection(id string) (*apitypes.SSEConnection, error) + // CloseSSEConnection closes an SSE connection. + CloseSSEConnection(id string) error + // GetSSEStats returns SSE statistics. + GetSSEStats() (*apitypes.SSEStats, error) + + // ListMQTTConnections returns active MQTT client connections. + ListMQTTConnections() (*apitypes.MQTTConnectionListResponse, error) + // GetMQTTConnection returns a specific MQTT client connection. + GetMQTTConnection(id string) (*apitypes.MQTTConnection, error) + // CloseMQTTConnection disconnects an MQTT client. + CloseMQTTConnection(id string) error + // GetMQTTStats returns MQTT broker statistics. + GetMQTTStats() (*apitypes.MQTTStats, error) + + // ListGRPCStreams returns active gRPC streaming connections. + ListGRPCStreams() (*apitypes.GRPCStreamListResponse, error) + // GetGRPCStream returns a specific gRPC stream. + GetGRPCStream(id string) (*apitypes.GRPCStream, error) + // CloseGRPCStream cancels a gRPC stream. + CloseGRPCStream(id string) error + // GetGRPCStats returns gRPC statistics. + GetGRPCStats() (*apitypes.GRPCStats, error) } // LogFilter specifies filtering criteria for request logs. @@ -1658,6 +1697,315 @@ func (c *adminClient) BulkCreateMocks(mocks []*mock.Mock, workspaceID string) (* return &result, nil } +// ─── Connection management ────────────────────────────────────────────────── + +// ListWebSocketConnections returns active WebSocket connections. +func (c *adminClient) ListWebSocketConnections() (*apitypes.WebSocketConnectionListResponse, error) { + resp, err := c.get("/websocket/connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.WebSocketConnectionListResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// GetWebSocketConnection returns a specific WebSocket connection. +func (c *adminClient) GetWebSocketConnection(id string) (*apitypes.WebSocketConnection, error) { + resp, err := c.get("/websocket/connections/" + url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return nil, &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "WebSocket connection not found: " + id} + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.WebSocketConnection + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// CloseWebSocketConnection closes a WebSocket connection. +func (c *adminClient) CloseWebSocketConnection(id string) error { + resp, err := c.delete("/websocket/connections/" + url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "WebSocket connection not found: " + id} + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// SendWebSocketMessage sends a message to a WebSocket connection. +func (c *adminClient) SendWebSocketMessage(id string, message string, binary bool) error { + msgType := "text" + if binary { + msgType = "binary" + } + body, err := json.Marshal(map[string]interface{}{ + "message": message, + "type": msgType, + }) + if err != nil { + return fmt.Errorf("failed to encode request: %w", err) + } + resp, err := c.post("/websocket/connections/"+url.PathEscape(id)+"/send", body) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "WebSocket connection not found: " + id} + } + if resp.StatusCode != http.StatusOK { + return c.parseError(resp) + } + return nil +} + +// GetWebSocketStats returns WebSocket statistics. +func (c *adminClient) GetWebSocketStats() (*apitypes.WebSocketStats, error) { + resp, err := c.get("/websocket/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.WebSocketStats + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// ListSSEConnections returns active SSE connections. +func (c *adminClient) ListSSEConnections() (*apitypes.SSEConnectionListResponse, error) { + resp, err := c.get("/sse/connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.SSEConnectionListResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// GetSSEConnection returns a specific SSE connection. +func (c *adminClient) GetSSEConnection(id string) (*apitypes.SSEConnection, error) { + resp, err := c.get("/sse/connections/" + url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return nil, &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "SSE connection not found: " + id} + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.SSEConnection + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// CloseSSEConnection closes an SSE connection. +func (c *adminClient) CloseSSEConnection(id string) error { + resp, err := c.delete("/sse/connections/" + url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "SSE connection not found: " + id} + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// GetSSEStats returns SSE statistics. +func (c *adminClient) GetSSEStats() (*apitypes.SSEStats, error) { + resp, err := c.get("/sse/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.SSEStats + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// ListMQTTConnections returns active MQTT client connections. +func (c *adminClient) ListMQTTConnections() (*apitypes.MQTTConnectionListResponse, error) { + resp, err := c.get("/mqtt-connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.MQTTConnectionListResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// GetMQTTConnection returns a specific MQTT client connection. +func (c *adminClient) GetMQTTConnection(id string) (*apitypes.MQTTConnection, error) { + resp, err := c.get("/mqtt-connections/" + url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return nil, &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "MQTT connection not found: " + id} + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.MQTTConnection + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// CloseMQTTConnection disconnects an MQTT client. +func (c *adminClient) CloseMQTTConnection(id string) error { + resp, err := c.delete("/mqtt-connections/" + url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "MQTT connection not found: " + id} + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// GetMQTTStats returns MQTT broker statistics. +func (c *adminClient) GetMQTTStats() (*apitypes.MQTTStats, error) { + resp, err := c.get("/mqtt-connections/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.MQTTStats + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// ListGRPCStreams returns active gRPC streaming connections. +func (c *adminClient) ListGRPCStreams() (*apitypes.GRPCStreamListResponse, error) { + resp, err := c.get("/grpc/connections") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.GRPCStreamListResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// GetGRPCStream returns a specific gRPC stream. +func (c *adminClient) GetGRPCStream(id string) (*apitypes.GRPCStream, error) { + resp, err := c.get("/grpc/connections/" + url.PathEscape(id)) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return nil, &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "gRPC stream not found: " + id} + } + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.GRPCStream + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + +// CloseGRPCStream cancels a gRPC stream. +func (c *adminClient) CloseGRPCStream(id string) error { + resp, err := c.delete("/grpc/connections/" + url.PathEscape(id)) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return &APIError{StatusCode: resp.StatusCode, ErrorCode: "not_found", Message: "gRPC stream not found: " + id} + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + return c.parseError(resp) + } + return nil +} + +// GetGRPCStats returns gRPC statistics. +func (c *adminClient) GetGRPCStats() (*apitypes.GRPCStats, error) { + resp, err := c.get("/grpc/stats") + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, c.parseError(resp) + } + var result apitypes.GRPCStats + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + return &result, nil +} + // get performs an HTTP GET request. func (c *adminClient) get(path string) (*http.Response, error) { return c.doRequest(http.MethodGet, path, nil) diff --git a/pkg/cli/grpc_connections.go b/pkg/cli/grpc_connections.go new file mode 100644 index 00000000..57bc607c --- /dev/null +++ b/pkg/cli/grpc_connections.go @@ -0,0 +1,133 @@ +package cli + +import ( + "fmt" + "time" + + "github.com/getmockd/mockd/pkg/cli/internal/output" + "github.com/spf13/cobra" +) + +// ─── gRPC connection management ───────────────────────────────────────────── + +var grpcConnectionsCmd = &cobra.Command{ + Use: "connections", + Short: "Manage active gRPC streaming connections", + Long: `List, inspect, or cancel active gRPC streaming RPC connections.`, + RunE: runGRPCConnectionsList, +} + +var grpcConnectionsListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List active gRPC streams", + Example: ` mockd grpc connections list + mockd grpc connections list --json`, + Args: cobra.NoArgs, + RunE: runGRPCConnectionsList, +} + +func runGRPCConnectionsList(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + result, err := client.ListGRPCStreams() + if err != nil { + return fmt.Errorf("failed to list gRPC streams: %s", FormatConnectionError(err)) + } + + printList(result, func() { + if len(result.Streams) == 0 { + fmt.Println("No active gRPC streams") + return + } + tw := output.Table() + fmt.Fprintf(tw, "ID\tMETHOD\tSTREAM TYPE\tCLIENT\tCONNECTED\tMSG SENT\tMSG RECV\n") + for _, s := range result.Streams { + fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%d\t%d\n", + s.ID, s.Method, s.StreamType, s.ClientAddr, + formatDuration(time.Since(s.ConnectedAt)), + s.MessagesSent, s.MessagesRecv) + } + _ = tw.Flush() + fmt.Printf("\nTotal: %d stream(s)\n", result.Count) + }) + return nil +} + +var grpcConnectionsGetCmd = &cobra.Command{ + Use: "get ", + Short: "Get details of a gRPC stream", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + stream, err := client.GetGRPCStream(args[0]) + if err != nil { + return fmt.Errorf("failed to get gRPC stream: %s", FormatConnectionError(err)) + } + printResult(stream, func() { + fmt.Printf("gRPC Stream: %s\n", stream.ID) + fmt.Printf(" Method: %s\n", stream.Method) + fmt.Printf(" Stream Type: %s\n", stream.StreamType) + fmt.Printf(" Client Addr: %s\n", stream.ClientAddr) + fmt.Printf(" Connected: %s (%s ago)\n", stream.ConnectedAt.Format(time.RFC3339), formatDuration(time.Since(stream.ConnectedAt))) + fmt.Printf(" Messages Sent: %d\n", stream.MessagesSent) + fmt.Printf(" Messages Recv: %d\n", stream.MessagesRecv) + }) + return nil + }, +} + +var grpcConnectionsCloseCmd = &cobra.Command{ + Use: "close ", + Short: "Cancel a gRPC stream", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + if err := client.CloseGRPCStream(args[0]); err != nil { + return fmt.Errorf("failed to cancel gRPC stream: %s", FormatConnectionError(err)) + } + printResult(map[string]interface{}{"id": args[0], "cancelled": true}, func() { + fmt.Printf("Cancelled gRPC stream: %s\n", args[0]) + }) + return nil + }, +} + +var grpcStatsCmd = &cobra.Command{ + Use: "stats", + Short: "Show gRPC statistics", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + stats, err := client.GetGRPCStats() + if err != nil { + return fmt.Errorf("failed to get gRPC stats: %s", FormatConnectionError(err)) + } + + printResult(stats, func() { + fmt.Println("gRPC Statistics") + fmt.Printf(" Active Streams: %d\n", stats.ActiveStreams) + fmt.Printf(" Total Streams: %d\n", stats.TotalStreams) + fmt.Printf(" Total RPCs: %d\n", stats.TotalRPCs) + fmt.Printf(" Total Messages Sent: %d\n", stats.TotalMessagesSent) + fmt.Printf(" Total Messages Recv: %d\n", stats.TotalMessagesRecv) + if len(stats.StreamsByMethod) > 0 { + fmt.Println(" Streams by Method:") + for method, count := range stats.StreamsByMethod { + fmt.Printf(" %s: %d\n", method, count) + } + } + }) + return nil + }, +} + +func init() { + // connections subgroup + grpcConnectionsCmd.AddCommand(grpcConnectionsListCmd) + grpcConnectionsCmd.AddCommand(grpcConnectionsGetCmd) + grpcConnectionsCmd.AddCommand(grpcConnectionsCloseCmd) + grpcCmd.AddCommand(grpcConnectionsCmd) + + // stats + grpcCmd.AddCommand(grpcStatsCmd) +} diff --git a/pkg/cli/mqtt_connections.go b/pkg/cli/mqtt_connections.go new file mode 100644 index 00000000..3b0de737 --- /dev/null +++ b/pkg/cli/mqtt_connections.go @@ -0,0 +1,142 @@ +package cli + +import ( + "fmt" + "time" + + "github.com/getmockd/mockd/pkg/cli/internal/output" + "github.com/spf13/cobra" +) + +// ─── MQTT connection management ───────────────────────────────────────────── + +var mqttConnectionsCmd = &cobra.Command{ + Use: "connections", + Short: "Manage active MQTT client connections", + Long: `List, inspect, or disconnect active MQTT client connections.`, + RunE: runMQTTConnectionsList, +} + +var mqttConnectionsListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List active MQTT client connections", + Example: ` mockd mqtt connections list + mockd mqtt connections list --json`, + Args: cobra.NoArgs, + RunE: runMQTTConnectionsList, +} + +func runMQTTConnectionsList(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + result, err := client.ListMQTTConnections() + if err != nil { + return fmt.Errorf("failed to list MQTT connections: %s", FormatConnectionError(err)) + } + + printList(result, func() { + if len(result.Connections) == 0 { + fmt.Println("No active MQTT connections") + return + } + tw := output.Table() + fmt.Fprintf(tw, "ID\tREMOTE ADDR\tUSERNAME\tSUBSCRIPTIONS\tCONNECTED\tSTATUS\n") + for _, c := range result.Connections { + subs := fmt.Sprintf("%d topic(s)", len(c.Subscriptions)) + username := c.Username + if username == "" { + username = "-" + } + fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n", + c.ID, c.RemoteAddr, username, subs, + formatDuration(time.Since(c.ConnectedAt)), + c.Status) + } + _ = tw.Flush() + fmt.Printf("\nTotal: %d connection(s)\n", result.Count) + }) + return nil +} + +var mqttConnectionsGetCmd = &cobra.Command{ + Use: "get ", + Short: "Get details of an MQTT client connection", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + conn, err := client.GetMQTTConnection(args[0]) + if err != nil { + return fmt.Errorf("failed to get MQTT connection: %s", FormatConnectionError(err)) + } + printResult(conn, func() { + fmt.Printf("MQTT Connection: %s\n", conn.ID) + fmt.Printf(" Broker ID: %s\n", conn.BrokerID) + fmt.Printf(" Remote Addr: %s\n", conn.RemoteAddr) + if conn.Username != "" { + fmt.Printf(" Username: %s\n", conn.Username) + } + fmt.Printf(" Protocol Version: %d\n", conn.ProtocolVersion) + fmt.Printf(" Connected: %s (%s ago)\n", conn.ConnectedAt.Format(time.RFC3339), formatDuration(time.Since(conn.ConnectedAt))) + fmt.Printf(" Subscriptions: %s\n", formatStringSlice(conn.Subscriptions)) + fmt.Printf(" Status: %s\n", conn.Status) + }) + return nil + }, +} + +var mqttConnectionsCloseCmd = &cobra.Command{ + Use: "close ", + Short: "Disconnect an MQTT client", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + if err := client.CloseMQTTConnection(args[0]); err != nil { + return fmt.Errorf("failed to close MQTT connection: %s", FormatConnectionError(err)) + } + printResult(map[string]interface{}{"id": args[0], "closed": true}, func() { + fmt.Printf("Disconnected MQTT client: %s\n", args[0]) + }) + return nil + }, +} + +var mqttStatsCmd = &cobra.Command{ + Use: "stats", + Short: "Show MQTT broker statistics", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + stats, err := client.GetMQTTStats() + if err != nil { + return fmt.Errorf("failed to get MQTT stats: %s", FormatConnectionError(err)) + } + + printResult(stats, func() { + fmt.Println("MQTT Broker Statistics") + fmt.Printf(" Connected Clients: %d\n", stats.ConnectedClients) + fmt.Printf(" Total Subscriptions: %d\n", stats.TotalSubscriptions) + fmt.Printf(" Topic Count: %d\n", stats.TopicCount) + fmt.Printf(" Port: %d\n", stats.Port) + fmt.Printf(" TLS Enabled: %v\n", stats.TLSEnabled) + fmt.Printf(" Auth Enabled: %v\n", stats.AuthEnabled) + if len(stats.SubscriptionsByClient) > 0 { + fmt.Println(" Subscriptions by Client:") + for clientID, count := range stats.SubscriptionsByClient { + fmt.Printf(" %s: %d\n", clientID, count) + } + } + }) + return nil + }, +} + +func init() { + // connections subgroup + mqttConnectionsCmd.AddCommand(mqttConnectionsListCmd) + mqttConnectionsCmd.AddCommand(mqttConnectionsGetCmd) + mqttConnectionsCmd.AddCommand(mqttConnectionsCloseCmd) + mqttCmd.AddCommand(mqttConnectionsCmd) + + // stats + mqttCmd.AddCommand(mqttStatsCmd) +} diff --git a/pkg/cli/print.go b/pkg/cli/print.go index e8e55108..c58b95d6 100644 --- a/pkg/cli/print.go +++ b/pkg/cli/print.go @@ -1,6 +1,10 @@ package cli -import "github.com/getmockd/mockd/pkg/cli/internal/output" +import ( + "strings" + + "github.com/getmockd/mockd/pkg/cli/internal/output" +) // printResult outputs a single operation result. // @@ -15,6 +19,14 @@ func printResult(data any, textFn func()) { textFn() } +// formatStringSlice joins strings for display. +func formatStringSlice(items []string) string { + if len(items) == 0 { + return "(none)" + } + return strings.Join(items, ", ") +} + // printList outputs a collection of items. // // Same contract as printResult. textFn typically uses output.Table() for diff --git a/pkg/cli/sse.go b/pkg/cli/sse.go new file mode 100644 index 00000000..74565754 --- /dev/null +++ b/pkg/cli/sse.go @@ -0,0 +1,182 @@ +package cli + +import ( + "fmt" + "time" + + "github.com/getmockd/mockd/pkg/cli/internal/output" + "github.com/spf13/cobra" +) + +// ─── SSE parent command ───────────────────────────────────────────────────── + +var sseCmd = &cobra.Command{ + Use: "sse", + Short: "Manage SSE (Server-Sent Events) connections", +} + +var sseAddCmd = &cobra.Command{ + Use: "add", + Short: "Add a new SSE mock endpoint", + RunE: func(cmd *cobra.Command, args []string) error { + addMockType = "sse" + return runAdd(cmd, args) + }, +} + +// ─── SSE connection management ────────────────────────────────────────────── + +var sseConnectionsCmd = &cobra.Command{ + Use: "connections", + Short: "Manage active SSE connections", + Long: `List, inspect, or close active SSE connections.`, + RunE: runSSEConnectionsList, +} + +var sseConnectionsListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List active SSE connections", + Example: ` mockd sse connections list + mockd sse connections list --json`, + Args: cobra.NoArgs, + RunE: runSSEConnectionsList, +} + +func runSSEConnectionsList(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + result, err := client.ListSSEConnections() + if err != nil { + return fmt.Errorf("failed to list SSE connections: %s", FormatConnectionError(err)) + } + + printList(result, func() { + if len(result.Connections) == 0 { + fmt.Println("No active SSE connections") + return + } + tw := output.Table() + fmt.Fprintf(tw, "ID\tPATH\tMOCK ID\tCLIENT IP\tCONNECTED\tEVENTS\tSTATUS\n") + for _, c := range result.Connections { + fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%d\t%s\n", + c.ID, c.Path, c.MockID, c.ClientIP, + formatDuration(time.Since(c.ConnectedAt)), + c.EventsSent, c.Status) + } + _ = tw.Flush() + fmt.Printf("\nTotal: %d connection(s)\n", result.Count) + }) + return nil +} + +var sseConnectionsGetCmd = &cobra.Command{ + Use: "get ", + Short: "Get details of an SSE connection", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + conn, err := client.GetSSEConnection(args[0]) + if err != nil { + return fmt.Errorf("failed to get SSE connection: %s", FormatConnectionError(err)) + } + printResult(conn, func() { + fmt.Printf("SSE Connection: %s\n", conn.ID) + fmt.Printf(" Path: %s\n", conn.Path) + fmt.Printf(" Mock ID: %s\n", conn.MockID) + fmt.Printf(" Client IP: %s\n", conn.ClientIP) + if conn.UserAgent != "" { + fmt.Printf(" User Agent: %s\n", conn.UserAgent) + } + fmt.Printf(" Connected: %s (%s ago)\n", conn.ConnectedAt.Format(time.RFC3339), formatDuration(time.Since(conn.ConnectedAt))) + fmt.Printf(" Events Sent: %d\n", conn.EventsSent) + fmt.Printf(" Bytes Sent: %d\n", conn.BytesSent) + fmt.Printf(" Status: %s\n", conn.Status) + }) + return nil + }, +} + +var sseConnectionsCloseCmd = &cobra.Command{ + Use: "close ", + Short: "Close an SSE connection", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + if err := client.CloseSSEConnection(args[0]); err != nil { + return fmt.Errorf("failed to close SSE connection: %s", FormatConnectionError(err)) + } + printResult(map[string]interface{}{"id": args[0], "closed": true}, func() { + fmt.Printf("Closed SSE connection: %s\n", args[0]) + }) + return nil + }, +} + +var sseStatsCmd = &cobra.Command{ + Use: "stats", + Short: "Show SSE statistics", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + stats, err := client.GetSSEStats() + if err != nil { + return fmt.Errorf("failed to get SSE stats: %s", FormatConnectionError(err)) + } + + printResult(stats, func() { + fmt.Println("SSE Statistics") + fmt.Printf(" Active Connections: %d\n", stats.ActiveConnections) + fmt.Printf(" Total Connections: %d\n", stats.TotalConnections) + fmt.Printf(" Total Events Sent: %d\n", stats.TotalEventsSent) + fmt.Printf(" Total Bytes Sent: %d\n", stats.TotalBytesSent) + fmt.Printf(" Connection Errors: %d\n", stats.ConnectionErrors) + if len(stats.ConnectionsByMock) > 0 { + fmt.Println(" Connections by Mock:") + for mockID, count := range stats.ConnectionsByMock { + fmt.Printf(" %s: %d\n", mockID, count) + } + } + }) + return nil + }, +} + +// ─── init ─────────────────────────────────────────────────────────────────── + +func init() { + rootCmd.AddCommand(sseCmd) + + // sse add + sseAddCmd.Flags().StringVar(&addPath, "path", "", "SSE endpoint path (e.g., /events)") + sseAddCmd.Flags().StringVar(&addName, "name", "", "Mock display name") + sseCmd.AddCommand(sseAddCmd) + + // list/get/delete generic aliases + sseCmd.AddCommand(&cobra.Command{ + Use: "list", + Short: "List SSE mocks", + RunE: func(cmd *cobra.Command, args []string) error { + listMockType = "sse" + return runList(cmd, args) + }, + }) + sseCmd.AddCommand(&cobra.Command{ + Use: "get", + Short: "Get details of an SSE mock", + RunE: runGet, + }) + sseCmd.AddCommand(&cobra.Command{ + Use: "delete", + Short: "Delete an SSE mock", + RunE: runDelete, + }) + + // connections subgroup + sseConnectionsCmd.AddCommand(sseConnectionsListCmd) + sseConnectionsCmd.AddCommand(sseConnectionsGetCmd) + sseConnectionsCmd.AddCommand(sseConnectionsCloseCmd) + sseCmd.AddCommand(sseConnectionsCmd) + + // stats + sseCmd.AddCommand(sseStatsCmd) +} diff --git a/pkg/cli/ws_connections.go b/pkg/cli/ws_connections.go new file mode 100644 index 00000000..a6449b1a --- /dev/null +++ b/pkg/cli/ws_connections.go @@ -0,0 +1,172 @@ +package cli + +import ( + "encoding/base64" + "fmt" + "time" + + "github.com/getmockd/mockd/pkg/cli/internal/output" + "github.com/spf13/cobra" +) + +// ─── websocket connection management ──────────────────────────────────────── + +var wsConnectionsCmd = &cobra.Command{ + Use: "connections", + Short: "Manage active WebSocket connections", + Long: `List, inspect, close, or send messages to active WebSocket connections.`, + RunE: runWSConnectionsList, +} + +var wsConnectionsListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List active WebSocket connections", + Example: ` mockd websocket connections list + mockd websocket connections list --json`, + Args: cobra.NoArgs, + RunE: runWSConnectionsList, +} + +func runWSConnectionsList(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + result, err := client.ListWebSocketConnections() + if err != nil { + return fmt.Errorf("failed to list WebSocket connections: %s", FormatConnectionError(err)) + } + + printList(result, func() { + if len(result.Connections) == 0 { + fmt.Println("No active WebSocket connections") + return + } + tw := output.Table() + fmt.Fprintf(tw, "ID\tPATH\tMOCK ID\tCONNECTED\tMSG SENT\tMSG RECV\tSTATUS\n") + for _, c := range result.Connections { + fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%d\t%d\t%s\n", + c.ID, c.Path, c.MockID, + formatDuration(time.Since(c.ConnectedAt)), + c.MessagesSent, c.MessagesRecv, c.Status) + } + _ = tw.Flush() + fmt.Printf("\nTotal: %d connection(s)\n", result.Count) + }) + return nil +} + +var wsConnectionsGetCmd = &cobra.Command{ + Use: "get ", + Short: "Get details of a WebSocket connection", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + conn, err := client.GetWebSocketConnection(args[0]) + if err != nil { + return fmt.Errorf("failed to get WebSocket connection: %s", FormatConnectionError(err)) + } + printResult(conn, func() { + fmt.Printf("WebSocket Connection: %s\n", conn.ID) + fmt.Printf(" Path: %s\n", conn.Path) + fmt.Printf(" Mock ID: %s\n", conn.MockID) + fmt.Printf(" Connected: %s (%s ago)\n", conn.ConnectedAt.Format(time.RFC3339), formatDuration(time.Since(conn.ConnectedAt))) + fmt.Printf(" Messages Sent: %d\n", conn.MessagesSent) + fmt.Printf(" Messages Recv: %d\n", conn.MessagesRecv) + fmt.Printf(" Status: %s\n", conn.Status) + }) + return nil + }, +} + +var wsConnectionsCloseCmd = &cobra.Command{ + Use: "close ", + Short: "Close a WebSocket connection", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + if err := client.CloseWebSocketConnection(args[0]); err != nil { + return fmt.Errorf("failed to close WebSocket connection: %s", FormatConnectionError(err)) + } + printResult(map[string]interface{}{"id": args[0], "closed": true}, func() { + fmt.Printf("Closed WebSocket connection: %s\n", args[0]) + }) + return nil + }, +} + +var wsConnectionsSendBinary bool + +var wsConnectionsSendCmd = &cobra.Command{ + Use: "send ", + Short: "Send a message to a WebSocket connection", + Long: `Send a text or binary message to an active WebSocket connection. +Use --binary to send base64-encoded binary data.`, + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + id := args[0] + message := args[1] + + if wsConnectionsSendBinary { + // Validate base64 + if _, err := base64.StdEncoding.DecodeString(message); err != nil { + return fmt.Errorf("invalid base64 data: %w", err) + } + } + + client := NewAdminClientWithAuth(adminURL) + if err := client.SendWebSocketMessage(id, message, wsConnectionsSendBinary); err != nil { + return fmt.Errorf("failed to send message: %s", FormatConnectionError(err)) + } + + msgType := "text" + if wsConnectionsSendBinary { + msgType = "binary" + } + printResult(map[string]interface{}{"id": id, "sent": true, "type": msgType}, func() { + fmt.Printf("Sent %s message to connection %s\n", msgType, id) + }) + return nil + }, +} + +var wsStatsCmd = &cobra.Command{ + Use: "stats", + Short: "Show WebSocket statistics", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + client := NewAdminClientWithAuth(adminURL) + stats, err := client.GetWebSocketStats() + if err != nil { + return fmt.Errorf("failed to get WebSocket stats: %s", FormatConnectionError(err)) + } + + printResult(stats, func() { + fmt.Println("WebSocket Statistics") + fmt.Printf(" Active Connections: %d\n", stats.ActiveConnections) + fmt.Printf(" Total Connections: %d\n", stats.TotalConnections) + fmt.Printf(" Total Messages Sent: %d\n", stats.TotalMessagesSent) + fmt.Printf(" Total Messages Recv: %d\n", stats.TotalMessagesRecv) + if len(stats.ConnectionsByMock) > 0 { + fmt.Println(" Connections by Mock:") + for mockID, count := range stats.ConnectionsByMock { + fmt.Printf(" %s: %d\n", mockID, count) + } + } + }) + return nil + }, +} + +func init() { + // connections subgroup + wsConnectionsCmd.AddCommand(wsConnectionsListCmd) + wsConnectionsCmd.AddCommand(wsConnectionsGetCmd) + wsConnectionsCmd.AddCommand(wsConnectionsCloseCmd) + + wsConnectionsSendCmd.Flags().BoolVar(&wsConnectionsSendBinary, "binary", false, "Send base64-encoded binary message") + wsConnectionsCmd.AddCommand(wsConnectionsSendCmd) + + websocketCmd.AddCommand(wsConnectionsCmd) + + // stats as top-level websocket subcommand + websocketCmd.AddCommand(wsStatsCmd) +} From d001a2392737c176a4a363ea327d10fba1437a9e Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Tue, 7 Apr 2026 00:26:58 -0500 Subject: [PATCH 11/18] fix: resolve critical and high-priority review findings for unified connection management Critical fixes: - C3: MQTT ConnectedAt now populated via OnSessionEstablished/OnDisconnect hooks with clientConnectedAt tracking map on the Broker struct - C4: Canonical API types (WebSocketConnectionListResponse, SSEConnectionListResponse, MQTTConnectionListResponse, GRPCStreamListResponse) now use Stats field matching what admin handlers actually return, fixing CLI count:0 bug - C5: CHANGELOG expanded with SSE, MQTT, gRPC, CLI, and UI connection management entries - C6: API docs updated with correct WebSocket paths, full SSE/MQTT/gRPC connection management sections, mock-scoped WebSocket endpoints, and response examples High-priority fixes: - H1: gRPC stream handlers now return codes.Canceled for client-initiated cancellations and codes.Unavailable only for admin/mock-update cancellations, using per-stream adminCancelled tracking in StreamTracker - H2: Removed dead code CancelAllWithStatus() and UnavailableError() from stream_tracker.go Co-Authored-By: Claude Opus 4.6 (1M context) --- CHANGELOG.md | 6 + docs/src/content/docs/reference/admin-api.md | 180 +++++++++++++++++-- pkg/api/types/responses.go | 8 +- pkg/cli/client.go | 2 +- pkg/cli/grpc_connections.go | 2 +- pkg/cli/mqtt_connections.go | 2 +- pkg/cli/sse.go | 2 +- pkg/cli/ws_connections.go | 2 +- pkg/engine/api/handlers.go | 28 ++- pkg/engine/api/handlers_test.go | 14 +- pkg/grpc/server.go | 28 ++- pkg/grpc/stream_tracker.go | 40 ++--- pkg/mcp/tool_handlers_test.go | 35 ++++ pkg/mqtt/broker.go | 33 ++++ pkg/mqtt/hooks.go | 20 +++ 15 files changed, 336 insertions(+), 66 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cbaed45f..62a767cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - **WebSocket connection management API** — `GET /websocket/connections`, `GET /websocket/connections/{id}`, `DELETE /websocket/connections/{id}`, `POST /websocket/connections/{id}/send`, `GET /websocket/stats` added to the Admin API for real-time visibility, control, and server-initiated messaging of active WebSocket connections +- **SSE connection management API** — `GET /sse/connections`, `GET /sse/connections/{id}`, `DELETE /sse/connections/{id}`, `GET /sse/stats` for managing active SSE connections with auto-disconnect on mock update +- **SSE response type improvements** — SSE connection list responses now include full stats alongside connections +- **MQTT connection management API** — `GET /mqtt-connections`, `GET /mqtt-connections/{id}`, `DELETE /mqtt-connections/{id}`, `GET /mqtt-connections/stats` for managing active MQTT client connections with connection time tracking +- **gRPC stream tracking and management** — `GET /grpc/connections`, `GET /grpc/connections/{id}`, `DELETE /grpc/connections/{id}`, `GET /grpc/stats` for managing active gRPC streaming RPCs with proper Canceled vs Unavailable status codes +- **CLI connection management commands** — `mockd connections list`, `mockd connections get`, `mockd connections close` for all protocols (WebSocket, SSE, MQTT, gRPC) +- **Connection management UI** — unified connections view in the web dashboard showing all active connections across protocols with real-time stats, search, and filtering - **WebSocket auto-reconnect on mock update** — updating or deleting a WebSocket mock now automatically closes all active connections with close code 1012 (Service Restart) so clients reconnect and pick up the new configuration immediately - **Workspace-scoped stateful resources** — stateful resources, custom operations, and request logs are now isolated per workspace - **`--workspace` persistent CLI flag** — scope any CLI command to a specific workspace without switching context diff --git a/docs/src/content/docs/reference/admin-api.md b/docs/src/content/docs/reference/admin-api.md index b5ff9a5b..699d6f97 100644 --- a/docs/src/content/docs/reference/admin-api.md +++ b/docs/src/content/docs/reference/admin-api.md @@ -935,6 +935,34 @@ Export mocks as Insomnia v4 collection (JSON format, legacy). List active SSE connections. +**Response:** + +```json +{ + "connections": [ + { + "id": "sse-abc123", + "mockId": "mock-1", + "path": "/events", + "clientIp": "127.0.0.1", + "userAgent": "Mozilla/5.0", + "connectedAt": "2024-01-15T10:30:00Z", + "eventsSent": 42, + "bytesSent": 1024, + "status": "active" + } + ], + "stats": { + "totalConnections": 10, + "activeConnections": 1, + "totalEventsSent": 500, + "totalBytesSent": 51200, + "connectionErrors": 0, + "connectionsByMock": {"mock-1": 1} + } +} +``` + #### GET /sse/connections/{id} Get SSE connection details. @@ -943,6 +971,10 @@ Get SSE connection details. Close an SSE connection. +:::note[SSE auto-disconnect on mock update] +When an SSE mock is updated or deleted, all active SSE connections to that endpoint are automatically closed. Clients should reconnect to pick up the new configuration. +::: + #### GET /sse/stats Get SSE statistics. @@ -951,32 +983,42 @@ Get SSE statistics. ### WebSocket Management -#### GET /admin/ws/connections +#### GET /websocket/connections List active WebSocket connections. -#### GET /admin/ws/connections/{id} - -Get connection details. - -#### DELETE /admin/ws/connections/{id} - -Close a WebSocket connection. - -#### POST /admin/ws/connections/{id}/send - -Send a message to a specific connection. - -**Request:** +**Response:** ```json { - "type": "text", - "data": "Hello from server" + "connections": [ + { + "id": "ws-abc123", + "mockId": "mock-1", + "path": "/ws/chat", + "connectedAt": "2024-01-15T10:30:00Z", + "messagesSent": 15, + "messagesRecv": 10, + "status": "connected" + } + ], + "stats": { + "totalConnections": 50, + "activeConnections": 1, + "totalMessagesSent": 500, + "totalMessagesRecv": 300, + "connectionsByMock": {"mock-1": 1} + } } ``` -#### POST /admin/ws/broadcast +#### GET /websocket/connections/{id} + +Get connection details. + +#### DELETE /websocket/connections/{id} + +Close a WebSocket connection. #### POST /websocket/connections/{id}/send @@ -1008,6 +1050,16 @@ Send a text or binary message to a specific active WebSocket connection. Returns `404` if the connection is not found. +#### GET /mocks/{id}/websocket/connections + +List active WebSocket connections for a specific mock. + +**Response:** Same format as `GET /websocket/connections`, filtered to the given mock ID. + +#### DELETE /mocks/{id}/websocket/connections + +Close all WebSocket connections for a specific mock. + #### GET /websocket/stats Get WebSocket statistics. @@ -1054,6 +1106,100 @@ Stop a replay session. --- +### MQTT Connection Management + +#### GET /mqtt-connections + +List active MQTT client connections. + +**Response:** + +```json +{ + "connections": [ + { + "id": "client-abc123", + "brokerId": "mqtt-broker-1", + "connectedAt": "2024-01-15T10:30:00Z", + "subscriptions": ["sensors/#", "devices/+"], + "protocolVersion": 5, + "username": "device-1", + "remoteAddr": "192.168.1.10:54321", + "status": "connected" + } + ], + "stats": { + "connectedClients": 1, + "totalSubscriptions": 2, + "topicCount": 5, + "port": 1883, + "tlsEnabled": false, + "authEnabled": false, + "subscriptionsByClient": {"client-abc123": 2} + } +} +``` + +#### GET /mqtt-connections/{id} + +Get details of a specific MQTT client connection. + +#### DELETE /mqtt-connections/{id} + +Disconnect an MQTT client. + +#### GET /mqtt-connections/stats + +Get MQTT connection statistics. + +--- + +### gRPC Stream Management + +#### GET /grpc/connections + +List active gRPC streaming RPCs. + +**Response:** + +```json +{ + "streams": [ + { + "id": "grpc-stream-1", + "method": "/myapp.ChatService/StreamMessages", + "streamType": "bidi", + "clientAddr": "127.0.0.1:54321", + "connectedAt": "2024-01-15T10:30:00Z", + "messagesSent": 15, + "messagesRecv": 10 + } + ], + "stats": { + "activeStreams": 1, + "totalStreams": 50, + "totalRPCs": 200, + "totalMessagesSent": 1000, + "totalMessagesRecv": 800, + "streamsByMethod": {"/myapp.ChatService/StreamMessages": 1} + } +} +``` + +#### GET /grpc/connections/{id} + +Get details of a specific gRPC stream. + +#### DELETE /grpc/connections/{id} + +Cancel a gRPC stream. The client receives a `codes.Unavailable` status, signaling it should reconnect. + +#### GET /grpc/stats + +Get gRPC stream statistics. + +--- + ### gRPC Management #### GET /grpc diff --git a/pkg/api/types/responses.go b/pkg/api/types/responses.go index bb8c4202..4c3b4c13 100644 --- a/pkg/api/types/responses.go +++ b/pkg/api/types/responses.go @@ -458,7 +458,7 @@ type SSEConnection struct { // SSEConnectionListResponse lists SSE connections. type SSEConnectionListResponse struct { Connections []*SSEConnection `json:"connections"` - Count int `json:"count"` + Stats SSEStats `json:"stats"` } // SSEStats represents SSE statistics. @@ -488,7 +488,7 @@ type MQTTConnection struct { // MQTTConnectionListResponse lists MQTT connections. type MQTTConnectionListResponse struct { Connections []*MQTTConnection `json:"connections"` - Count int `json:"count"` + Stats MQTTStats `json:"stats"` } // MQTTStats represents MQTT broker statistics. @@ -518,7 +518,7 @@ type WebSocketConnection struct { // WebSocketConnectionListResponse lists WebSocket connections. type WebSocketConnectionListResponse struct { Connections []*WebSocketConnection `json:"connections"` - Count int `json:"count"` + Stats WebSocketStats `json:"stats"` } // WebSocketStats represents WebSocket statistics. @@ -546,7 +546,7 @@ type GRPCStream struct { // GRPCStreamListResponse lists gRPC streams. type GRPCStreamListResponse struct { Streams []*GRPCStream `json:"streams"` - Count int `json:"count"` + Stats GRPCStats `json:"stats"` } // GRPCStats represents gRPC statistics. diff --git a/pkg/cli/client.go b/pkg/cli/client.go index b835dfc5..86b62505 100644 --- a/pkg/cli/client.go +++ b/pkg/cli/client.go @@ -1759,7 +1759,7 @@ func (c *adminClient) SendWebSocketMessage(id string, message string, binary boo msgType = "binary" } body, err := json.Marshal(map[string]interface{}{ - "message": message, + "data": message, "type": msgType, }) if err != nil { diff --git a/pkg/cli/grpc_connections.go b/pkg/cli/grpc_connections.go index 57bc607c..c8c54213 100644 --- a/pkg/cli/grpc_connections.go +++ b/pkg/cli/grpc_connections.go @@ -48,7 +48,7 @@ func runGRPCConnectionsList(cmd *cobra.Command, args []string) error { s.MessagesSent, s.MessagesRecv) } _ = tw.Flush() - fmt.Printf("\nTotal: %d stream(s)\n", result.Count) + fmt.Printf("\nTotal: %d stream(s)\n", len(result.Streams)) }) return nil } diff --git a/pkg/cli/mqtt_connections.go b/pkg/cli/mqtt_connections.go index 3b0de737..6bec0291 100644 --- a/pkg/cli/mqtt_connections.go +++ b/pkg/cli/mqtt_connections.go @@ -53,7 +53,7 @@ func runMQTTConnectionsList(cmd *cobra.Command, args []string) error { c.Status) } _ = tw.Flush() - fmt.Printf("\nTotal: %d connection(s)\n", result.Count) + fmt.Printf("\nTotal: %d connection(s)\n", len(result.Connections)) }) return nil } diff --git a/pkg/cli/sse.go b/pkg/cli/sse.go index 74565754..35e6fd6d 100644 --- a/pkg/cli/sse.go +++ b/pkg/cli/sse.go @@ -64,7 +64,7 @@ func runSSEConnectionsList(cmd *cobra.Command, args []string) error { c.EventsSent, c.Status) } _ = tw.Flush() - fmt.Printf("\nTotal: %d connection(s)\n", result.Count) + fmt.Printf("\nTotal: %d connection(s)\n", len(result.Connections)) }) return nil } diff --git a/pkg/cli/ws_connections.go b/pkg/cli/ws_connections.go index a6449b1a..b00b8736 100644 --- a/pkg/cli/ws_connections.go +++ b/pkg/cli/ws_connections.go @@ -49,7 +49,7 @@ func runWSConnectionsList(cmd *cobra.Command, args []string) error { c.MessagesSent, c.MessagesRecv, c.Status) } _ = tw.Flush() - fmt.Printf("\nTotal: %d connection(s)\n", result.Count) + fmt.Printf("\nTotal: %d connection(s)\n", len(result.Connections)) }) return nil } diff --git a/pkg/engine/api/handlers.go b/pkg/engine/api/handlers.go index dc5573d3..69ebb7fc 100644 --- a/pkg/engine/api/handlers.go +++ b/pkg/engine/api/handlers.go @@ -740,9 +740,14 @@ func (s *Server) handleGetHandler(w http.ResponseWriter, r *http.Request) { func (s *Server) handleListSSEConnections(w http.ResponseWriter, r *http.Request) { connections := s.engine.ListSSEConnections() + stats := s.engine.GetSSEStats() + respStats := SSEStats{ConnectionsByMock: make(map[string]int)} + if stats != nil { + respStats = *stats + } writeJSON(w, http.StatusOK, SSEConnectionListResponse{ Connections: connections, - Count: len(connections), + Stats: respStats, }) } @@ -781,9 +786,14 @@ func (s *Server) handleGetSSEStats(w http.ResponseWriter, r *http.Request) { func (s *Server) handleListWebSocketConnections(w http.ResponseWriter, r *http.Request) { connections := s.engine.ListWebSocketConnections() + stats := s.engine.GetWebSocketStats() + respStats := WebSocketStats{ConnectionsByMock: make(map[string]int)} + if stats != nil { + respStats = *stats + } writeJSON(w, http.StatusOK, WebSocketConnectionListResponse{ Connections: connections, - Count: len(connections), + Stats: respStats, }) } @@ -1072,9 +1082,14 @@ func writeJSON(w http.ResponseWriter, status int, v any) { func (s *Server) handleListMQTTConnections(w http.ResponseWriter, r *http.Request) { connections := s.engine.ListMQTTConnections() + stats := s.engine.GetMQTTStats() + respStats := MQTTStats{SubscriptionsByClient: make(map[string]int)} + if stats != nil { + respStats = *stats + } writeJSON(w, http.StatusOK, MQTTConnectionListResponse{ Connections: connections, - Count: len(connections), + Stats: respStats, }) } @@ -1116,9 +1131,14 @@ func (s *Server) handleListGRPCStreams(w http.ResponseWriter, r *http.Request) { if streams == nil { streams = []*GRPCStream{} } + stats := s.engine.GetGRPCStats() + respStats := GRPCStats{StreamsByMethod: make(map[string]int)} + if stats != nil { + respStats = *stats + } writeJSON(w, http.StatusOK, GRPCStreamListResponse{ Streams: streams, - Count: len(streams), + Stats: respStats, }) } diff --git a/pkg/engine/api/handlers_test.go b/pkg/engine/api/handlers_test.go index 59e90269..ab9a3911 100644 --- a/pkg/engine/api/handlers_test.go +++ b/pkg/engine/api/handlers_test.go @@ -2317,7 +2317,7 @@ func TestHandleListWebSocketConnections(t *testing.T) { err := json.Unmarshal(rec.Body.Bytes(), &resp) require.NoError(t, err) assert.Empty(t, resp.Connections) - assert.Equal(t, 0, resp.Count) + assert.Equal(t, 0, resp.Stats.ActiveConnections) }) t.Run("returns all connections", func(t *testing.T) { @@ -2339,7 +2339,6 @@ func TestHandleListWebSocketConnections(t *testing.T) { err := json.Unmarshal(rec.Body.Bytes(), &resp) require.NoError(t, err) assert.Len(t, resp.Connections, 2) - assert.Equal(t, 2, resp.Count) }) } @@ -2653,7 +2652,7 @@ func TestHandleListMQTTConnections(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) var resp MQTTConnectionListResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) - assert.Equal(t, 0, resp.Count) + assert.Empty(t, resp.Connections) }) t.Run("with connections", func(t *testing.T) { @@ -2672,7 +2671,7 @@ func TestHandleListMQTTConnections(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) var resp MQTTConnectionListResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) - assert.Equal(t, 2, resp.Count) + assert.Len(t, resp.Connections, 2) assert.Equal(t, "client-1", resp.Connections[0].ID) }) } @@ -2804,7 +2803,7 @@ func TestHandleListGRPCStreams(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) var resp GRPCStreamListResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) - assert.Equal(t, 0, resp.Count) + assert.Empty(t, resp.Streams) }) t.Run("with streams", func(t *testing.T) { @@ -2823,7 +2822,7 @@ func TestHandleListGRPCStreams(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) var resp GRPCStreamListResponse require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) - assert.Equal(t, 2, resp.Count) + assert.Len(t, resp.Streams, 2) assert.Equal(t, "stream-1", resp.Streams[0].ID) }) } @@ -2956,7 +2955,7 @@ func TestHandleListSSEConnections(t *testing.T) { err := json.Unmarshal(rec.Body.Bytes(), &resp) require.NoError(t, err) assert.Empty(t, resp.Connections) - assert.Equal(t, 0, resp.Count) + assert.Equal(t, 0, resp.Stats.ActiveConnections) }) t.Run("returns all connections", func(t *testing.T) { @@ -2978,7 +2977,6 @@ func TestHandleListSSEConnections(t *testing.T) { err := json.Unmarshal(rec.Body.Bytes(), &resp) require.NoError(t, err) assert.Len(t, resp.Connections, 2) - assert.Equal(t, 2, resp.Count) }) } diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 09c6eb64..9443332d 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -812,7 +812,7 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD // Apply initial delay (context-aware for client/tracker cancellation) s.applyDelayWithContext(ctx, methodCfg.Delay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Unavailable, "stream cancelled") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamServerStream, md, reqMap, nil, grpcErr) return grpcErr } @@ -843,7 +843,7 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD for i, respData := range responses { // Check for tracker cancellation between messages if ctx.Err() != nil { - grpcErr := status.Error(codes.Unavailable, "mock configuration updated") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamServerStream, md, reqMap, collectedResponses, grpcErr) return grpcErr } @@ -870,7 +870,7 @@ func (s *Server) handleServerStreaming(stream grpc.ServerStream, method *MethodD if i < len(responses)-1 { s.applyDelayWithContext(ctx, methodCfg.StreamDelay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Unavailable, "stream cancelled") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamServerStream, md, reqMap, collectedResponses, grpcErr) return grpcErr } @@ -906,7 +906,7 @@ func (s *Server) handleClientStreaming(stream grpc.ServerStream, method *MethodD for { // Check for tracker cancellation if ctx.Err() != nil { - grpcErr := status.Error(codes.Unavailable, "mock configuration updated") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamClientStream, md, allRequests, nil, grpcErr) return grpcErr } @@ -936,7 +936,7 @@ func (s *Server) handleClientStreaming(stream grpc.ServerStream, method *MethodD // Apply delay (context-aware for client/tracker cancellation) s.applyDelayWithContext(ctx, methodCfg.Delay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Unavailable, "stream cancelled") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamClientStream, md, allRequests, nil, grpcErr) return grpcErr } @@ -1004,7 +1004,7 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * // Apply initial delay (context-aware for client/tracker cancellation) s.applyDelayWithContext(ctx, methodCfg.Delay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Unavailable, "stream cancelled") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamBidi, md, nil, nil, grpcErr) return grpcErr } @@ -1035,7 +1035,7 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * for { // Check for tracker cancellation if ctx.Err() != nil { - grpcErr := status.Error(codes.Unavailable, "mock configuration updated") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamBidi, md, allRequests, allResponses, grpcErr) return grpcErr } @@ -1096,7 +1096,7 @@ func (s *Server) handleBidirectionalStreaming(stream grpc.ServerStream, method * respIndex++ s.applyDelayWithContext(ctx, methodCfg.StreamDelay) if ctx.Err() != nil { - grpcErr := status.Error(codes.Unavailable, "stream cancelled") + grpcErr := s.contextCancelError(streamID) s.logGRPCCall(startTime, fullPath, serviceName, methodName, streamBidi, md, allRequests, allResponses, grpcErr) return grpcErr } @@ -1276,6 +1276,18 @@ func (s *Server) createTemplateContext(md metadata.MD, reqMap map[string]interfa return template.NewContextFromMap(reqMap, headers) } +// contextCancelError returns the appropriate gRPC status error for a +// cancelled context. If the stream was cancelled by admin action (mock update, +// explicit admin cancel), it returns codes.Unavailable so clients with retry +// policies reconnect. If cancelled by the client (e.g., timeout), it returns +// codes.Canceled. +func (s *Server) contextCancelError(streamID string) error { + if s.streamTracker.WasAdminCancelled(streamID) { + return status.Error(codes.Unavailable, "mock configuration updated") + } + return status.Error(codes.Canceled, "stream cancelled by client") +} + // applyDelay applies configured delay, respecting the context deadline. // Accepts Go duration strings ("100ms", "2s") or bare numbers treated as milliseconds ("100"). func (s *Server) applyDelay(delay string) { diff --git a/pkg/grpc/stream_tracker.go b/pkg/grpc/stream_tracker.go index b72bd920..1f55915a 100644 --- a/pkg/grpc/stream_tracker.go +++ b/pkg/grpc/stream_tracker.go @@ -8,8 +8,6 @@ import ( "time" "github.com/getmockd/mockd/pkg/metrics" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) // StreamInfo holds metadata about an active streaming RPC. @@ -37,8 +35,9 @@ type StreamStats struct { // StreamTracker tracks active gRPC streaming RPCs for a single Server. type StreamTracker struct { - streams map[string]*trackedStream // ID -> tracked stream - mu sync.RWMutex + streams map[string]*trackedStream // ID -> tracked stream + adminCancelled map[string]bool // IDs cancelled by admin/mock-update + mu sync.RWMutex // Lifetime counters (include completed streams). totalStreams atomic.Int64 @@ -58,7 +57,8 @@ type trackedStream struct { // NewStreamTracker creates a new StreamTracker. func NewStreamTracker() *StreamTracker { return &StreamTracker{ - streams: make(map[string]*trackedStream), + streams: make(map[string]*trackedStream), + adminCancelled: make(map[string]bool), } } @@ -108,6 +108,7 @@ func (t *StreamTracker) Unregister(id string) { if ok { delete(t.streams, id) } + delete(t.adminCancelled, id) t.mu.Unlock() if !ok { @@ -183,9 +184,12 @@ func (t *StreamTracker) Count() int { // Cancel cancels a specific stream's context, causing the RPC handler to // return codes.Unavailable to the client. func (t *StreamTracker) Cancel(id string) error { - t.mu.RLock() + t.mu.Lock() ts := t.streams[id] - t.mu.RUnlock() + if ts != nil { + t.adminCancelled[id] = true + } + t.mu.Unlock() if ts == nil { return fmt.Errorf("stream %s not found", id) } @@ -193,6 +197,15 @@ func (t *StreamTracker) Cancel(id string) error { return nil } +// WasAdminCancelled returns true if the stream was cancelled by an admin +// action (CancelAll for mock update, or explicit Cancel via admin API) +// rather than by the client. +func (t *StreamTracker) WasAdminCancelled(id string) bool { + t.mu.RLock() + defer t.mu.RUnlock() + return t.adminCancelled[id] +} + // CancelAll cancels all active streams with the given gRPC status. // Returns the number of streams cancelled. func (t *StreamTracker) CancelAll() int { @@ -212,13 +225,6 @@ func (t *StreamTracker) CancelAll() int { return count } -// CancelAllWithStatus cancels all active streams. The gRPC handler -// should check ctx.Err() and return the appropriate status to the client. -// Returns the number of streams cancelled. -func (t *StreamTracker) CancelAllWithStatus() int { - return t.CancelAll() -} - // Stats returns aggregate statistics. func (t *StreamTracker) Stats() *StreamStats { t.mu.RLock() @@ -259,9 +265,3 @@ func (t *StreamTracker) toInfo(ts *trackedStream) *StreamInfo { } } -// UnavailableError returns a gRPC Unavailable status error with the given message. -// This is the gRPC equivalent of WebSocket close code 1012 — clients with retry -// policies will reconnect automatically. -func UnavailableError(msg string) error { - return status.Error(codes.Unavailable, msg) -} diff --git a/pkg/mcp/tool_handlers_test.go b/pkg/mcp/tool_handlers_test.go index 2c5e4f19..5f468a4a 100644 --- a/pkg/mcp/tool_handlers_test.go +++ b/pkg/mcp/tool_handlers_test.go @@ -5,6 +5,7 @@ import ( "fmt" "testing" + apitypes "github.com/getmockd/mockd/pkg/api/types" "github.com/getmockd/mockd/pkg/cli" "github.com/getmockd/mockd/pkg/config" "github.com/getmockd/mockd/pkg/mock" @@ -429,6 +430,40 @@ func (m *mockAdminClient) BulkCreateMocks(_ []*mock.Mock, _ string) (*cli.BulkCr return nil, nil } +// --- Connection Management (stubs) --- + +func (m *mockAdminClient) ListWebSocketConnections() (*apitypes.WebSocketConnectionListResponse, error) { + return nil, nil +} +func (m *mockAdminClient) GetWebSocketConnection(_ string) (*apitypes.WebSocketConnection, error) { + return nil, nil +} +func (m *mockAdminClient) CloseWebSocketConnection(_ string) error { return nil } +func (m *mockAdminClient) SendWebSocketMessage(_ string, _ string, _ bool) error { return nil } +func (m *mockAdminClient) GetWebSocketStats() (*apitypes.WebSocketStats, error) { return nil, nil } +func (m *mockAdminClient) ListSSEConnections() (*apitypes.SSEConnectionListResponse, error) { + return nil, nil +} +func (m *mockAdminClient) GetSSEConnection(_ string) (*apitypes.SSEConnection, error) { + return nil, nil +} +func (m *mockAdminClient) CloseSSEConnection(_ string) error { return nil } +func (m *mockAdminClient) GetSSEStats() (*apitypes.SSEStats, error) { return nil, nil } +func (m *mockAdminClient) ListMQTTConnections() (*apitypes.MQTTConnectionListResponse, error) { + return nil, nil +} +func (m *mockAdminClient) GetMQTTConnection(_ string) (*apitypes.MQTTConnection, error) { + return nil, nil +} +func (m *mockAdminClient) CloseMQTTConnection(_ string) error { return nil } +func (m *mockAdminClient) GetMQTTStats() (*apitypes.MQTTStats, error) { return nil, nil } +func (m *mockAdminClient) ListGRPCStreams() (*apitypes.GRPCStreamListResponse, error) { + return nil, nil +} +func (m *mockAdminClient) GetGRPCStream(_ string) (*apitypes.GRPCStream, error) { return nil, nil } +func (m *mockAdminClient) CloseGRPCStream(_ string) error { return nil } +func (m *mockAdminClient) GetGRPCStats() (*apitypes.GRPCStats, error) { return nil, nil } + // ============================================================================= // Test Helpers // ============================================================================= diff --git a/pkg/mqtt/broker.go b/pkg/mqtt/broker.go index 86dcb330..f6f540d8 100644 --- a/pkg/mqtt/broker.go +++ b/pkg/mqtt/broker.go @@ -59,6 +59,7 @@ type Broker struct { log *slog.Logger internalSubscribers map[string][]SubscriptionHandler clientSubscriptions map[string][]string + clientConnectedAt map[string]time.Time simulator *Simulator recordingEnabled bool recordingStore MQTTRecordingStore @@ -144,6 +145,7 @@ func NewBroker(config *MQTTConfig) (*Broker, error) { log: logging.Nop(), internalSubscribers: make(map[string][]SubscriptionHandler), clientSubscriptions: make(map[string][]string), + clientConnectedAt: make(map[string]time.Time), sessionManager: NewSessionManager(), } broker.responseHandler = NewResponseHandler(broker) @@ -654,6 +656,28 @@ func (b *Broker) notifyTestPanelPublish(topic string, payload []byte, qos int, r b.sessionManager.NotifyMessage(b.config.ID, msg) } +// trackClientConnect records the connection time for a client. +// Called from the MessageHook's OnSessionEstablished callback. +func (b *Broker) trackClientConnect(clientID string) { + if b.stopping.Load() != 0 { + return + } + b.mu.Lock() + b.clientConnectedAt[clientID] = time.Now() + b.mu.Unlock() +} + +// trackClientDisconnect removes the connection time for a client. +// Called from the MessageHook's OnDisconnect callback. +func (b *Broker) trackClientDisconnect(clientID string) { + if b.stopping.Load() != 0 { + return + } + b.mu.Lock() + delete(b.clientConnectedAt, clientID) + b.mu.Unlock() +} + // MQTTClientInfo represents information about a connected MQTT client. type MQTTClientInfo struct { ID string @@ -677,6 +701,10 @@ func (b *Broker) ListClientInfos() []*MQTTClientInfo { for k, v := range b.clientSubscriptions { subs[k] = append([]string{}, v...) } + connTimes := make(map[string]time.Time, len(b.clientConnectedAt)) + for k, v := range b.clientConnectedAt { + connTimes[k] = v + } b.mu.RUnlock() if b.server == nil { @@ -698,6 +726,9 @@ func (b *Broker) ListClientInfos() []*MQTTClientInfo { RemoteAddr: cl.Net.Remote, Closed: cl.Closed(), } + if ct, ok := connTimes[id]; ok { + info.ConnectedAt = ct + } if subList, ok := subs[id]; ok { info.Subscriptions = subList } @@ -730,11 +761,13 @@ func (b *Broker) GetClientInfo(clientID string) *MQTTClientInfo { if s, ok := b.clientSubscriptions[clientID]; ok { subsCopy = append([]string{}, s...) } + connectedAt := b.clientConnectedAt[clientID] b.mu.RUnlock() return &MQTTClientInfo{ ID: clientID, BrokerID: brokerID, + ConnectedAt: connectedAt, ProtocolVersion: cl.Properties.ProtocolVersion, Username: string(cl.Properties.Username), RemoteAddr: cl.Net.Remote, diff --git a/pkg/mqtt/hooks.go b/pkg/mqtt/hooks.go index e423a9ff..67346744 100644 --- a/pkg/mqtt/hooks.go +++ b/pkg/mqtt/hooks.go @@ -219,9 +219,29 @@ func (h *MessageHook) Provides(b byte) bool { mqtt.OnPublish, mqtt.OnSubscribed, mqtt.OnUnsubscribed, + mqtt.OnSessionEstablished, + mqtt.OnDisconnect, }, []byte{b}) } +// OnSessionEstablished is called after a client has connected and its session +// is fully set up. We use it to track the connection time. +func (h *MessageHook) OnSessionEstablished(cl *mqtt.Client, _ packets.Packet) { + if cl.Net.Inline { + return + } + h.broker.trackClientConnect(cl.ID) +} + +// OnDisconnect is called when a client disconnects. We use it to clean up the +// tracked connection time. +func (h *MessageHook) OnDisconnect(cl *mqtt.Client, _ error, _ bool) { + if cl.Net.Inline { + return + } + h.broker.trackClientDisconnect(cl.ID) +} + // OnPublish handles incoming publish messages func (h *MessageHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { // Defense-in-depth: enforce ACL on publish for non-inline clients. From 6ca11265430356e6c1dcf7011fcc61512a135256 Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Wed, 8 Apr 2026 22:17:14 -0500 Subject: [PATCH 12/18] fix: resolve golangci-lint failures --- pkg/cli/grpc_connections.go | 4 ++-- pkg/cli/mqtt_connections.go | 4 ++-- pkg/grpc/stream_tracker.go | 2 -- tests/integration/sse_test.go | 2 +- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pkg/cli/grpc_connections.go b/pkg/cli/grpc_connections.go index c8c54213..03569f6f 100644 --- a/pkg/cli/grpc_connections.go +++ b/pkg/cli/grpc_connections.go @@ -40,9 +40,9 @@ func runGRPCConnectionsList(cmd *cobra.Command, args []string) error { return } tw := output.Table() - fmt.Fprintf(tw, "ID\tMETHOD\tSTREAM TYPE\tCLIENT\tCONNECTED\tMSG SENT\tMSG RECV\n") + _, _ = fmt.Fprintf(tw, "ID\tMETHOD\tSTREAM TYPE\tCLIENT\tCONNECTED\tMSG SENT\tMSG RECV\n") for _, s := range result.Streams { - fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%d\t%d\n", + _, _ = fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%d\t%d\n", s.ID, s.Method, s.StreamType, s.ClientAddr, formatDuration(time.Since(s.ConnectedAt)), s.MessagesSent, s.MessagesRecv) diff --git a/pkg/cli/mqtt_connections.go b/pkg/cli/mqtt_connections.go index 6bec0291..f8fb929c 100644 --- a/pkg/cli/mqtt_connections.go +++ b/pkg/cli/mqtt_connections.go @@ -40,14 +40,14 @@ func runMQTTConnectionsList(cmd *cobra.Command, args []string) error { return } tw := output.Table() - fmt.Fprintf(tw, "ID\tREMOTE ADDR\tUSERNAME\tSUBSCRIPTIONS\tCONNECTED\tSTATUS\n") + _, _ = fmt.Fprintf(tw, "ID\tREMOTE ADDR\tUSERNAME\tSUBSCRIPTIONS\tCONNECTED\tSTATUS\n") for _, c := range result.Connections { subs := fmt.Sprintf("%d topic(s)", len(c.Subscriptions)) username := c.Username if username == "" { username = "-" } - fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n", + _, _ = fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%s\n", c.ID, c.RemoteAddr, username, subs, formatDuration(time.Since(c.ConnectedAt)), c.Status) diff --git a/pkg/grpc/stream_tracker.go b/pkg/grpc/stream_tracker.go index 1f55915a..bfe74f79 100644 --- a/pkg/grpc/stream_tracker.go +++ b/pkg/grpc/stream_tracker.go @@ -19,8 +19,6 @@ type StreamInfo struct { ConnectedAt time.Time `json:"connectedAt"` MessagesSent int64 `json:"messagesSent"` MessagesRecv int64 `json:"messagesRecv"` - - cancel context.CancelFunc } // StreamStats holds aggregate statistics for gRPC streams. diff --git a/tests/integration/sse_test.go b/tests/integration/sse_test.go index b0895c17..8ab9cd53 100644 --- a/tests/integration/sse_test.go +++ b/tests/integration/sse_test.go @@ -93,6 +93,7 @@ func connectSSE(t *testing.T, url string) context.CancelFunc { resp, err := client.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) + t.Cleanup(func() { resp.Body.Close() }) // Read until we get at least one SSE event (data: line), confirming connection is tracked. ready := make(chan struct{}) @@ -106,7 +107,6 @@ func connectSSE(t *testing.T, url string) context.CancelFunc { signalled = true } } - // Body closed by context cancellation }() select { From 7f176bb0131ec68ff9fab68fed5a0a8dba229889 Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Wed, 8 Apr 2026 22:23:48 -0500 Subject: [PATCH 13/18] fix: remaining lint issues - unchecked Fprintf returns and gosec false positive --- .mcp.json | 8 ++++++++ pkg/cli/sse.go | 4 ++-- pkg/cli/ws_connections.go | 4 ++-- pkg/grpc/stream_tracker.go | 2 +- 4 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 .mcp.json diff --git a/.mcp.json b/.mcp.json new file mode 100644 index 00000000..54d6c197 --- /dev/null +++ b/.mcp.json @@ -0,0 +1,8 @@ +{ + "mcpServers": { + "mockd": { + "command": "mockd", + "args": ["mcp"] + } + } +} diff --git a/pkg/cli/sse.go b/pkg/cli/sse.go index 35e6fd6d..91fa0948 100644 --- a/pkg/cli/sse.go +++ b/pkg/cli/sse.go @@ -56,9 +56,9 @@ func runSSEConnectionsList(cmd *cobra.Command, args []string) error { return } tw := output.Table() - fmt.Fprintf(tw, "ID\tPATH\tMOCK ID\tCLIENT IP\tCONNECTED\tEVENTS\tSTATUS\n") + _, _ = fmt.Fprintf(tw, "ID\tPATH\tMOCK ID\tCLIENT IP\tCONNECTED\tEVENTS\tSTATUS\n") for _, c := range result.Connections { - fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%d\t%s\n", + _, _ = fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\t%d\t%s\n", c.ID, c.Path, c.MockID, c.ClientIP, formatDuration(time.Since(c.ConnectedAt)), c.EventsSent, c.Status) diff --git a/pkg/cli/ws_connections.go b/pkg/cli/ws_connections.go index b00b8736..88ef1617 100644 --- a/pkg/cli/ws_connections.go +++ b/pkg/cli/ws_connections.go @@ -41,9 +41,9 @@ func runWSConnectionsList(cmd *cobra.Command, args []string) error { return } tw := output.Table() - fmt.Fprintf(tw, "ID\tPATH\tMOCK ID\tCONNECTED\tMSG SENT\tMSG RECV\tSTATUS\n") + _, _ = fmt.Fprintf(tw, "ID\tPATH\tMOCK ID\tCONNECTED\tMSG SENT\tMSG RECV\tSTATUS\n") for _, c := range result.Connections { - fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%d\t%d\t%s\n", + _, _ = fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%d\t%d\t%s\n", c.ID, c.Path, c.MockID, formatDuration(time.Since(c.ConnectedAt)), c.MessagesSent, c.MessagesRecv, c.Status) diff --git a/pkg/grpc/stream_tracker.go b/pkg/grpc/stream_tracker.go index bfe74f79..96947a30 100644 --- a/pkg/grpc/stream_tracker.go +++ b/pkg/grpc/stream_tracker.go @@ -71,7 +71,7 @@ func nextStreamID() string { // context. The caller should defer Unregister. func (t *StreamTracker) Register(ctx context.Context, method string, st streamType, clientAddr string) (string, context.Context, context.CancelFunc) { id := nextStreamID() - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(ctx) //nolint:gosec // cancel is stored in trackedStream.cancel ts := &trackedStream{ info: StreamInfo{ From fb9a4772762fcca5ad1f05beeab5fa41783437a2 Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Wed, 8 Apr 2026 22:23:57 -0500 Subject: [PATCH 14/18] fix: remove .mcp.json with credentials, add to gitignore --- .gitignore | 1 + .mcp.json | 8 -------- 2 files changed, 1 insertion(+), 8 deletions(-) delete mode 100644 .mcp.json diff --git a/.gitignore b/.gitignore index d0faa240..1bd26bb3 100644 --- a/.gitignore +++ b/.gitignore @@ -119,3 +119,4 @@ completions/ # Dashboard frontend dist (copied from mockd-desktop at release time) pkg/admin/dashboard/dist/assets/ pkg/admin/dashboard/dist/index.html +.mcp.json diff --git a/.mcp.json b/.mcp.json deleted file mode 100644 index 54d6c197..00000000 --- a/.mcp.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "mcpServers": { - "mockd": { - "command": "mockd", - "args": ["mcp"] - } - } -} From 4c6b20f04f8e5651e56ef1b1cdc046e1d28dd3e5 Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Wed, 8 Apr 2026 22:30:00 -0500 Subject: [PATCH 15/18] fix: resolve dupl lint findings in connection handlers --- pkg/admin/grpc_handlers.go | 4 ++++ pkg/admin/mqtt_handlers.go | 2 ++ pkg/admin/sse_handlers.go | 6 ++++++ pkg/admin/websocket_handlers.go | 6 ++++++ pkg/cli/grpc_connections.go | 1 + pkg/cli/sse.go | 1 + 6 files changed, 20 insertions(+) diff --git a/pkg/admin/grpc_handlers.go b/pkg/admin/grpc_handlers.go index a18df475..1090e043 100644 --- a/pkg/admin/grpc_handlers.go +++ b/pkg/admin/grpc_handlers.go @@ -16,6 +16,8 @@ type GRPCStreamListResponse struct { } // handleListGRPCStreams handles GET /grpc/connections. +// +//nolint:dupl // intentionally parallel structure with other protocol list handlers func (a *API) handleListGRPCStreams(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -97,6 +99,8 @@ func (a *API) handleGetGRPCStream(w http.ResponseWriter, r *http.Request) { } // handleCancelGRPCStream handles DELETE /grpc/connections/{id}. +// +//nolint:dupl // intentionally parallel structure with other protocol close handlers func (a *API) handleCancelGRPCStream(w http.ResponseWriter, r *http.Request) { ctx := r.Context() id := r.PathValue("id") diff --git a/pkg/admin/mqtt_handlers.go b/pkg/admin/mqtt_handlers.go index ed46a1f5..8bbd2b36 100644 --- a/pkg/admin/mqtt_handlers.go +++ b/pkg/admin/mqtt_handlers.go @@ -97,6 +97,8 @@ func (a *API) handleGetMQTTConnection(w http.ResponseWriter, r *http.Request) { } // handleCloseMQTTConnection handles DELETE /mqtt/connections/{id}. +// +//nolint:dupl // intentionally parallel structure with other protocol close handlers func (a *API) handleCloseMQTTConnection(w http.ResponseWriter, r *http.Request) { ctx := r.Context() id := r.PathValue("id") diff --git a/pkg/admin/sse_handlers.go b/pkg/admin/sse_handlers.go index 07640c76..e6b2705e 100644 --- a/pkg/admin/sse_handlers.go +++ b/pkg/admin/sse_handlers.go @@ -16,6 +16,8 @@ type SSEConnectionListResponse struct { } // handleListSSEConnections handles GET /sse/connections. +// +//nolint:dupl // intentionally parallel structure with other protocol list handlers func (a *API) handleListSSEConnections(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -147,6 +149,8 @@ func (a *API) handleGetSSEStats(w http.ResponseWriter, r *http.Request) { } // handleListMockSSEConnections handles GET /mocks/{id}/sse/connections. +// +//nolint:dupl // intentionally parallel structure with other protocol mock-scoped handlers func (a *API) handleListMockSSEConnections(w http.ResponseWriter, r *http.Request, engine *engineclient.Client) { ctx := r.Context() mockID := r.PathValue("id") @@ -193,6 +197,8 @@ func (a *API) handleListMockSSEConnections(w http.ResponseWriter, r *http.Reques } // handleCloseMockSSEConnections handles DELETE /mocks/{id}/sse/connections. +// +//nolint:dupl // intentionally parallel structure with other protocol mock-scoped handlers func (a *API) handleCloseMockSSEConnections(w http.ResponseWriter, r *http.Request, engine *engineclient.Client) { ctx := r.Context() mockID := r.PathValue("id") diff --git a/pkg/admin/websocket_handlers.go b/pkg/admin/websocket_handlers.go index 8bc05436..339c5167 100644 --- a/pkg/admin/websocket_handlers.go +++ b/pkg/admin/websocket_handlers.go @@ -100,6 +100,8 @@ func (a *API) handleGetWebSocketConnection(w http.ResponseWriter, r *http.Reques } // handleCloseWebSocketConnection handles DELETE /websocket/connections/{id}. +// +//nolint:dupl // intentionally parallel structure with other protocol close handlers func (a *API) handleCloseWebSocketConnection(w http.ResponseWriter, r *http.Request) { ctx := r.Context() id := r.PathValue("id") @@ -195,6 +197,8 @@ func (a *API) handleSendToWebSocketConnection(w http.ResponseWriter, r *http.Req } // handleListMockWebSocketConnections handles GET /mocks/{id}/websocket/connections. +// +//nolint:dupl // intentionally parallel structure with other protocol mock-scoped handlers func (a *API) handleListMockWebSocketConnections(w http.ResponseWriter, r *http.Request, engine *engineclient.Client) { ctx := r.Context() mockID := r.PathValue("id") @@ -240,6 +244,8 @@ func (a *API) handleListMockWebSocketConnections(w http.ResponseWriter, r *http. } // handleCloseMockWebSocketConnections handles DELETE /mocks/{id}/websocket/connections. +// +//nolint:dupl // intentionally parallel structure with other protocol mock-scoped handlers func (a *API) handleCloseMockWebSocketConnections(w http.ResponseWriter, r *http.Request, engine *engineclient.Client) { ctx := r.Context() mockID := r.PathValue("id") diff --git a/pkg/cli/grpc_connections.go b/pkg/cli/grpc_connections.go index 03569f6f..dcc46d8a 100644 --- a/pkg/cli/grpc_connections.go +++ b/pkg/cli/grpc_connections.go @@ -92,6 +92,7 @@ var grpcConnectionsCloseCmd = &cobra.Command{ }, } +//nolint:dupl // intentionally parallel structure with other protocol stats commands var grpcStatsCmd = &cobra.Command{ Use: "stats", Short: "Show gRPC statistics", diff --git a/pkg/cli/sse.go b/pkg/cli/sse.go index 91fa0948..9671567b 100644 --- a/pkg/cli/sse.go +++ b/pkg/cli/sse.go @@ -112,6 +112,7 @@ var sseConnectionsCloseCmd = &cobra.Command{ }, } +//nolint:dupl // intentionally parallel structure with other protocol stats commands var sseStatsCmd = &cobra.Command{ Use: "stats", Short: "Show SSE statistics", From 4dee62b20c445d7d37c3392d06e0ed13f926cd1b Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Wed, 8 Apr 2026 22:32:45 -0500 Subject: [PATCH 16/18] fix: bump Go to 1.26.2 for stdlib CVE fixes (crypto/tls, crypto/x509, html/template) --- .claude/worktrees/agent-a29e97fb | 1 + .github/workflows/benchmark.yaml | 2 +- .github/workflows/ci.yaml | 12 ++++++------ go.mod | 2 +- 4 files changed, 9 insertions(+), 8 deletions(-) create mode 160000 .claude/worktrees/agent-a29e97fb diff --git a/.claude/worktrees/agent-a29e97fb b/.claude/worktrees/agent-a29e97fb new file mode 160000 index 00000000..0e9c0a08 --- /dev/null +++ b/.claude/worktrees/agent-a29e97fb @@ -0,0 +1 @@ +Subproject commit 0e9c0a0866b2858ab31917b886a10b6c26aa1589 diff --git a/.github/workflows/benchmark.yaml b/.github/workflows/benchmark.yaml index 3abb985e..5555ddb2 100644 --- a/.github/workflows/benchmark.yaml +++ b/.github/workflows/benchmark.yaml @@ -35,7 +35,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Install Apache Bench diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 13403a61..798de9d2 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -19,7 +19,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Run golangci-lint @@ -38,7 +38,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Download dependencies @@ -66,7 +66,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Install govulncheck @@ -87,7 +87,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Run E2E tests @@ -105,7 +105,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Build binary @@ -321,7 +321,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: "1.26.1" + go-version: "1.26.2" cache: true - name: Build diff --git a/go.mod b/go.mod index b20001d5..1a8563c8 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/getmockd/mockd -go 1.26.1 +go 1.26.2 require ( github.com/beevik/etree v1.6.0 From 8d9acf017b52cf262b8c1e3a0bf34a868facdff5 Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Wed, 8 Apr 2026 22:32:55 -0500 Subject: [PATCH 17/18] fix: remove stale worktree ref, add to gitignore --- .claude/worktrees/agent-a29e97fb | 1 - .gitignore | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 160000 .claude/worktrees/agent-a29e97fb diff --git a/.claude/worktrees/agent-a29e97fb b/.claude/worktrees/agent-a29e97fb deleted file mode 160000 index 0e9c0a08..00000000 --- a/.claude/worktrees/agent-a29e97fb +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 0e9c0a0866b2858ab31917b886a10b6c26aa1589 diff --git a/.gitignore b/.gitignore index 1bd26bb3..21d2fc47 100644 --- a/.gitignore +++ b/.gitignore @@ -120,3 +120,4 @@ completions/ pkg/admin/dashboard/dist/assets/ pkg/admin/dashboard/dist/index.html .mcp.json +.claude/worktrees/ From bfd88c1525167dd4c489dd66b9f88f1d9b51346e Mon Sep 17 00:00:00 2001 From: Zach Snell Date: Wed, 8 Apr 2026 22:37:14 -0500 Subject: [PATCH 18/18] fix: bump Dockerfile to Go 1.26.2 --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 8a5522a1..c6f2364d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ # ============================================================================= # Stage 1: Builder # ============================================================================= -FROM golang:1.26.1-alpine AS builder +FROM golang:1.26.2-alpine AS builder # Build arguments for version info ARG VERSION=dev