Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions pkg/audit/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@ import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"

"github.com/stacklok/toolhive/pkg/transport/types"
)

// Config represents the audit logging configuration.
Expand Down Expand Up @@ -104,28 +101,6 @@ func (c *Config) ShouldAuditEvent(eventType string) bool {
return true
}

// CreateMiddlewareWithTransport creates an HTTP middleware from the audit configuration with transport information.
func (c *Config) CreateMiddlewareWithTransport(transportType string) (types.MiddlewareFunction, error) {
auditor, err := NewAuditorWithTransport(c, transportType)
if err != nil {
return nil, fmt.Errorf("failed to create auditor: %w", err)
}
return auditor.Middleware, nil
}

// GetMiddlewareFromFile loads the audit configuration from a file and creates an HTTP middleware.
// Note: This function requires a transport type to be provided separately.
func GetMiddlewareFromFile(path string, transportType string) (func(http.Handler) http.Handler, error) {
// Load the configuration
config, err := LoadFromFile(path)
if err != nil {
return nil, fmt.Errorf("failed to load audit config: %w", err)
}

// Create the middleware with transport information
return config.CreateMiddlewareWithTransport(transportType)
}

// Validate validates the audit configuration.
func (c *Config) Validate() error {
if c.MaxDataSize < 0 {
Expand Down
17 changes: 0 additions & 17 deletions pkg/audit/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,6 @@ func TestShouldAuditEventExcludeTakesPrecedence(t *testing.T) {
assert.False(t, config.ShouldAuditEvent("mcp_resource_read")) // Not in EventTypes
}

func TestCreateMiddleware(t *testing.T) {
t.Parallel()
config := &Config{}

middleware, err := config.CreateMiddlewareWithTransport("sse")
assert.NoError(t, err)
assert.NotNil(t, middleware)
}

func TestValidateValidConfig(t *testing.T) {
t.Parallel()
config := &Config{
Expand Down Expand Up @@ -238,14 +229,6 @@ func TestConfigMinimalJSON(t *testing.T) {
assert.Equal(t, 0, config.MaxDataSize) // Default zero value
}

func TestGetMiddlewareFromFileError(t *testing.T) {
t.Parallel()
// Test with non-existent file
_, err := GetMiddlewareFromFile("/non/existent/file.json", "sse")
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to load audit config")
}

func TestLoadFromFilePathCleaning(t *testing.T) {
t.Parallel()
// Test that filepath.Clean is used (this is more of a smoke test)
Expand Down
16 changes: 11 additions & 5 deletions pkg/audit/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type MiddlewareParams struct {
// Middleware wraps audit middleware functionality
type Middleware struct {
middleware types.MiddlewareFunction
auditor *Auditor
}

// Handler returns the middleware function used by the proxy.
Expand All @@ -32,8 +33,10 @@ func (m *Middleware) Handler() types.MiddlewareFunction {
}

// Close cleans up any resources used by the middleware.
func (*Middleware) Close() error {
// Audit middleware doesn't need cleanup
func (m *Middleware) Close() error {
if m.auditor != nil {
return m.auditor.Close()
}
return nil
}

Expand Down Expand Up @@ -67,13 +70,16 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
auditConfig.Component = params.Component
}

// Always use the transport-aware constructor
middleware, err := auditConfig.CreateMiddlewareWithTransport(params.TransportType)
// Create the auditor directly so we can store a reference for cleanup
auditor, err := NewAuditorWithTransport(auditConfig, params.TransportType)
if err != nil {
return fmt.Errorf("failed to create audit middleware: %w", err)
}

auditMw := &Middleware{middleware: middleware}
auditMw := &Middleware{
middleware: auditor.Middleware,
auditor: auditor,
}
runner.AddMiddleware(config.Type, auditMw)
return nil
}
Loading