diff --git a/.github/workflows/codeChecks.yml b/.github/workflows/codeChecks.yml new file mode 100644 index 0000000..100f27b --- /dev/null +++ b/.github/workflows/codeChecks.yml @@ -0,0 +1,42 @@ +--- +name: code checks +on: + push: + paths: + - "cmd/**" + - "internal/**" + - "pkg/**" + - "*.go" + - "go.*" +jobs: + code_check_job: + + runs-on: ubuntu-latest + strategy: + matrix: + go-version: ['1.24', '1.25'] + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup Go + uses: actions/setup-go@v6 + with: + go-version: ${{ matrix.go-version }} + + - name: Install dependencies + run: go get . + + - name: Build + run: go build -v ./... + + - name: Test with the Go CLI + run: go test -v ./... + + - name: Check for vulnerabilities + uses: golang/govulncheck-action@v1 + with: + go-version-input: ${{ matrix.go-version }} + go-package: ./... + work-dir: . diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 136409d..579aa8e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -13,7 +13,7 @@ jobs: DOCKER_CLI_EXPERIMENTAL: "enabled" steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 0 - uses: cachix/install-nix-action@v31 @@ -28,9 +28,9 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GH_GORELEASER_TOKEN }} - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: - go-version: 1.24 + go-version: '1.24.9' - name: Run GoReleaser uses: goreleaser/goreleaser-action@v6 with: diff --git a/.github/workflows/vulncheck.yml b/.github/workflows/vulncheck.yml deleted file mode 100644 index d990cbf..0000000 --- a/.github/workflows/vulncheck.yml +++ /dev/null @@ -1,22 +0,0 @@ ---- -on: - push: - paths: - - "cmd/**" - - "internal/**" - - "pkg/**" - - "*.go" - - "go.*" -jobs: - govulncheck_job: - runs-on: ubuntu-latest - name: Run govulncheck - steps: - - name: Checkout - uses: actions/checkout@v4 - - id: govulncheck - uses: golang/govulncheck-action@v1 - with: - go-version-input: 1.24 - go-package: ./... - work-dir: . diff --git a/go.mod b/go.mod index 5bde940..c744827 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/xenos76/https-wrench -go 1.24.4 +go 1.24.9 require ( github.com/alecthomas/assert/v2 v2.11.0 diff --git a/internal/requests/requests.go b/internal/requests/requests.go index 2de2c6a..23821df 100644 --- a/internal/requests/requests.go +++ b/internal/requests/requests.go @@ -7,9 +7,11 @@ import ( "crypto/x509" "errors" "fmt" + "io" "net" "net/http" "net/http/httputil" + "os" "strings" "time" @@ -181,11 +183,13 @@ func (r *RequestsMetaConfig) SetRequests(requests []RequestConfig) *RequestsMeta return r } -func (r *RequestsMetaConfig) PrintCmd() { +func (r *RequestsMetaConfig) PrintCmd(w io.Writer) { if r.RequestVerbose { - fmt.Println() - fmt.Println(style.LgSprintf(style.Cmd, "Requests")) - fmt.Println() + fmt.Fprintf( + w, + "\n%s\n", + style.LgSprintf(style.Cmd, "Requests"), + ) } } @@ -201,50 +205,62 @@ func (r *RequestConfig) PrintTitle(isVerbose bool) { } } -func (r *RequestConfig) PrintRequestDebug(req *http.Request) { +func (r *RequestConfig) PrintRequestDebug(w io.Writer, req *http.Request) error { + if req == nil { + return errors.New("nil pointer to http.Request") + } + if r.RequestDebug { reqDump, err := httputil.DumpRequestOut(req, true) if err != nil { - fmt.Printf("Warning: failed to dump request: %v\n", err) - return + fmt.Fprintf(w, "Warning: failed to dump request: %v\n", err) + return err } - fmt.Printf("Requesting url: %s\n", req.URL) - fmt.Printf("Request dump:\n%s\n", string(reqDump)) + _, err = fmt.Fprintf(w, "Requesting url: %s\nRequest dump:\n%s\n", req.URL, string(reqDump)) + + return err } + + return nil } -func (r *RequestConfig) PrintResponseDebug(resp *http.Response) { +func (r *RequestConfig) PrintResponseDebug(w io.Writer, resp *http.Response) { + // TODO: return an error + if resp == nil { + return + } + if r.ResponseDebug { respDump, err := httputil.DumpResponse(resp, true) if err != nil { - fmt.Printf("Warning: failed to dump response: %v\n", err) + fmt.Fprintf(w, "Warning: failed to dump response: %v\n", err) return } - fmt.Printf("Requested url: %s\n", resp.Request.URL) - fmt.Printf("Response dump:\n%s\n", string(respDump)) + fmt.Fprintf(w, "Requested url: %s\n", resp.Request.URL) + fmt.Fprintf(w, "Response dump:\n%s\n", string(respDump)) if resp.TLS != nil { - fmt.Println("TLS:") - fmt.Printf("Version: %v\n", TLSVersionName(resp.TLS.Version)) - fmt.Printf("CipherSuite: %v\n", cipherSuiteName(resp.TLS.CipherSuite)) + fmt.Fprintln(w, "TLS:") + fmt.Fprintf(w, "Version: %v\n", TLSVersionName(resp.TLS.Version)) + fmt.Fprintf(w, "CipherSuite: %v\n", cipherSuiteName(resp.TLS.CipherSuite)) for i, cert := range resp.TLS.PeerCertificates { - fmt.Printf("Certificate %d:\n", i) + fmt.Fprintf(w, "Certificate %d:\n", i) certinfo.PrintCertInfo(cert, 1) } for i, chain := range resp.TLS.VerifiedChains { - fmt.Printf("Verified Chain %d:\n", i) + fmt.Fprintf(w, "Verified Chain %d:\n", i) for j, cert := range chain { - fmt.Printf(" Cert %d:\n", j) + fmt.Fprintf(w, " Cert %d:\n", j) certinfo.PrintCertInfo(cert, 2) } } } else { - fmt.Println("TLS: Not available (non-TLS connection)") + fmt.Fprintln(w, "TLS: Not available (non-TLS connection)") } } } @@ -275,6 +291,14 @@ func (rc *RequestHTTPClient) SetServerName(serverName string) (*RequestHTTPClien "*RequestHTTPClient.client is nil. Use NewRequestHTTPClient to initialize") } + if serverName == emptyString { + return nil, errors.New("serverName cannot be empty") + } + + if strings.Contains(serverName, "://") { + return nil, fmt.Errorf("serverName should be a hostname, not a URL: %s", serverName) + } + transport, ok := rc.client.Transport.(*http.Transport) if !ok { return nil, fmt.Errorf("expected *http.Transport, got %T", rc.client.Transport) @@ -406,9 +430,15 @@ func (rc *RequestHTTPClient) SetTransportOverride(transportURL string) (*Request return rc, nil } -func (rc *RequestHTTPClient) SetProxyProtocolV2(header proxyproto.Header) (*RequestHTTPClient, error) { +func (rc *RequestHTTPClient) SetProxyProtocolV2(enable bool) *RequestHTTPClient { + rc.enableProxyProtoV2 = enable + + return rc +} + +func (rc *RequestHTTPClient) SetProxyProtocolHeader(header proxyproto.Header) (*RequestHTTPClient, error) { if rc.transportAddress == emptyString { - return nil, errors.New("SetProxyProtocolV2 failed: transportOverrideURL not set") + return nil, errors.New("SetProxyProtocolHeader failed: transportOverrideURL not set") } if rc.client == nil { @@ -457,6 +487,10 @@ func (rc *RequestHTTPClient) SetClientTimeout(timeout int) (*RequestHTTPClient, "*RequestHTTPClient.client is nil. Use NewRequestHTTPClient to initialize") } + if timeout < 0 { + return nil, fmt.Errorf("timeout value must be positive: %v provided", timeout) + } + t := time.Duration(timeout) * time.Second rc.client.Timeout = t @@ -464,11 +498,6 @@ func (rc *RequestHTTPClient) SetClientTimeout(timeout int) (*RequestHTTPClient, } func NewHTTPClientFromRequestConfig(r RequestConfig, serverName string, caPool *x509.CertPool) (*RequestHTTPClient, error) { - if r.EnableProxyProtocolV2 && r.TransportOverrideURL == emptyString { - return nil, errors.New( - "if EnableProxyProtocolV2 is true, a TransportOverrideURL must be set") - } - reqClient := NewRequestHTTPClient() _, err := reqClient.SetCACertsPool(caPool) @@ -501,15 +530,22 @@ func NewHTTPClientFromRequestConfig(r RequestConfig, serverName string, caPool * return nil, fmt.Errorf("SetTransportOverride error: %w", err) } + reqClient.SetProxyProtocolV2(r.EnableProxyProtocolV2) + + if r.EnableProxyProtocolV2 && r.TransportOverrideURL == emptyString { + return nil, errors.New( + "if EnableProxyProtocolV2 is true, a TransportOverrideURL must be set") + } + if r.EnableProxyProtocolV2 && reqClient.transportAddress != emptyString { header, err := proxyProtoHeaderFromRequest(r, serverName) if err != nil { return nil, fmt.Errorf("error creating proxyproto Header: %w", err) } - _, err = reqClient.SetProxyProtocolV2(header) + _, err = reqClient.SetProxyProtocolHeader(header) if err != nil { - return nil, fmt.Errorf("SetProxyProtocolV2 error: %w", err) + return nil, fmt.Errorf("SetProxyProtocolHeader error: %w", err) } } @@ -563,7 +599,9 @@ func processHTTPRequestsByHost(r RequestConfig, caPool *x509.CertPool, isVerbose req.Header.Add(header.Key, header.Value) } - r.PrintRequestDebug(req) + if err := r.PrintRequestDebug(os.Stdout, req); err != nil { + fmt.Fprintf(os.Stderr, "Warning: PrintRequestDebug failed: %v\n", err) + } resp, err := reqClient.client.Do(req) if err != nil { @@ -580,7 +618,7 @@ func processHTTPRequestsByHost(r RequestConfig, caPool *x509.CertPool, isVerbose continue } - r.PrintResponseDebug(resp) + r.PrintResponseDebug(os.Stdout, resp) responseData.Response = resp diff --git a/internal/requests/requests_handlers.go b/internal/requests/requests_handlers.go index 776c8d3..8f88eb3 100644 --- a/internal/requests/requests_handlers.go +++ b/internal/requests/requests_handlers.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/url" + "os" "regexp" "slices" "strconv" @@ -205,7 +206,7 @@ func proxyProtoHeaderFromRequest(r RequestConfig, serverName string) (proxyproto func HandleRequests(cfg *RequestsMetaConfig) (map[string][]ResponseData, error) { responseDataMap := make(map[string][]ResponseData) - cfg.PrintCmd() + cfg.PrintCmd(os.Stdout) for _, r := range cfg.Requests { responseDataList, err := processHTTPRequestsByHost( diff --git a/internal/requests/requests_test.go b/internal/requests/requests_test.go index 29a7ac2..95630ad 100644 --- a/internal/requests/requests_test.go +++ b/internal/requests/requests_test.go @@ -1,14 +1,18 @@ package requests import ( + "bytes" "crypto/tls" "crypto/x509" "fmt" "net/http" "net/http/httptest" + "net/url" "testing" + "time" "github.com/google/go-cmp/cmp" + "github.com/pires/go-proxyproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -49,7 +53,8 @@ func TestNewRequestsMetaConfig(t *testing.T) { func TestRequestsMetaConfig_SetVerbose(t *testing.T) { tests := []bool{true, false} - for _, tt := range tests { + for _, tc := range tests { + tt := tc // safer when using t.Parallel() testname := fmt.Sprintf("SetVerbose(%v)", tt) t.Run(testname, func(t *testing.T) { t.Parallel() @@ -64,7 +69,8 @@ func TestRequestsMetaConfig_SetVerbose(t *testing.T) { func TestRequestsMetaConfig_SetDebug(t *testing.T) { tests := []bool{true, false} - for _, tt := range tests { + for _, tc := range tests { + tt := tc // safer when using t.Parallel() testname := fmt.Sprintf("SetDebug(%v)", tt) t.Run(testname, func(t *testing.T) { t.Parallel() @@ -89,7 +95,8 @@ func TestRequestsMetaConfig_SetCaPoolFromYAML(t *testing.T) { }, } - for _, tt := range tests { + for _, tc := range tests { + tt := tc // safer when using t.Parallel() testname := fmt.Sprintf("SetCaPoolFromYAML(%v)", tt.desc) t.Run(testname, func(t *testing.T) { t.Parallel() @@ -134,7 +141,8 @@ func TestRequestsMetaConfig_SetCaPoolFromFile(t *testing.T) { }, } - for _, tt := range tests { + for _, tc := range tests { + tt := tc // safer when using t.Parallel() testname := fmt.Sprintf("SetCaPoolFromFile(%v)", tt.desc) t.Run(testname, func(t *testing.T) { t.Parallel() @@ -155,6 +163,45 @@ func TestRequestsMetaConfig_SetCaPoolFromFile(t *testing.T) { } } +func TestRequestsMetaConfig_SetRequests(t *testing.T) { + hosts := []Host{ + {Name: "www.example.com"}, + {Name: "example.com", URIList: []URI{"/test", "test2"}}, + } + + requestConfigs := []RequestConfig{ + { + Name: "first request", + Insecure: true, + TransportOverrideURL: "localhost:443", + }, + { + Name: "second request", + PrintResponseBody: true, + TransportOverrideURL: "localhost:443", + Hosts: hosts, + }, + } + + tests := [][]RequestConfig{requestConfigs} + + for _, tc := range tests { + tt := tc // use a local copy of tc should be safer when using t.Parallel() + testname := fmt.Sprintf("SetRequests(%v)", tt[0].Name) + t.Run(testname, func(t *testing.T) { + t.Parallel() + + rmc, err := NewRequestsMetaConfig() + require.NoError(t, err) + rmc.SetRequests(tt) + + if diff := cmp.Diff(tt, rmc.Requests); diff != "" { + t.Errorf("Requests mismatch (-want +got):\n%s", diff) + } + }) + } +} + func TestNewRequestHTTPClient(t *testing.T) { t.Run("NewRequestHTTPClient", func(t *testing.T) { t.Parallel() @@ -197,18 +244,185 @@ func TestNewRequestHTTPClient(t *testing.T) { }) } +func TestNewHTTPClientFromRequestConfig_Error(t *testing.T) { + tests := []struct { + desc string + reqConf RequestConfig + serverName string + errMsg string + }{ + { + desc: "EnableProxyProtocolV2", + reqConf: RequestConfig{ + EnableProxyProtocolV2: true, + }, + serverName: "localhost", + errMsg: "if EnableProxyProtocolV2 is true, a TransportOverrideURL must be set", + }, + { + desc: "EnableProxyProtoNoServerName", + reqConf: RequestConfig{ + TransportOverrideURL: "https://localhost:8443", + EnableProxyProtocolV2: true, + }, + serverName: emptyString, + errMsg: "SetServerName error: serverName cannot be empty", + }, + } + + for _, tc := range tests { + tt := tc // safer when using t.Parallel() + t.Run(tt.desc, func(t *testing.T) { + t.Parallel() + + _, err := NewHTTPClientFromRequestConfig( + tt.reqConf, + tt.serverName, + nil, + ) + require.Error(t, err) + assert.Equal(t, + tt.errMsg, + err.Error(), + ) + }) + } +} + +func TestNewHTTPClientFromRequestConfig(t *testing.T) { + tests := []struct { + desc string + reqConf RequestConfig + serverName string + pool *x509.CertPool + transportAddress string + }{ + { + desc: "transpAddr", + reqConf: RequestConfig{ + ClientTimeout: 3, + UserAgent: "test1-ua", + TransportOverrideURL: "https://localhost:45555", + Insecure: true, + RequestMethod: http.MethodGet, + }, + serverName: "localhost", + transportAddress: "localhost:45555", + }, + { + desc: "proxyProto", + reqConf: RequestConfig{ + TransportOverrideURL: "https://localhost:8443", + RequestMethod: http.MethodHead, + EnableProxyProtocolV2: true, + }, + serverName: "localhost", + transportAddress: "localhost:8443", + }, + { + desc: "caPool", + reqConf: RequestConfig{ + RequestMethod: http.MethodPut, + }, + serverName: "localhost", + pool: caCertPool, + }, + } + + for _, tc := range tests { + tt := tc // safer when using t.Parallel() + t.Run(tt.desc, func(t *testing.T) { + t.Parallel() + + rcClient, err := NewHTTPClientFromRequestConfig( + tt.reqConf, + tt.serverName, + tt.pool, + ) + if err != nil { + t.Fatal(err) + } + + var i any = rcClient.client + + client, ok := i.(*http.Client) + if !ok { + t.Errorf("expecting *http.Client, got %T", client) + } + + assert.Equal(t, + time.Duration(tt.reqConf.ClientTimeout)*time.Second, + client.Timeout, + "check client Timeout", + ) + + assert.Equal(t, + tt.reqConf.RequestMethod, + rcClient.method, + "check client Method", + ) + + assert.Equal(t, + tt.reqConf.EnableProxyProtocolV2, + rcClient.enableProxyProtoV2, + "check proxy proto enabled", + ) + + if tt.transportAddress != emptyString { + assert.Equal(t, + tt.transportAddress, + rcClient.transportAddress, + "check transportAddress", + ) + } + + var ti any = rcClient.client.Transport + + transport, ok := ti.(*http.Transport) + if !ok { + t.Errorf("expecting *http.Transport, got %T", transport) + } + + assert.Equal(t, + tt.reqConf.Insecure, + transport.TLSClientConfig.InsecureSkipVerify, + "check Insecure", + ) + + currPool := systemCertPool + + if tt.pool != nil { + currPool = caCertPool + } + + if diff := cmp.Diff(currPool, transport.TLSClientConfig.RootCAs); diff != "" { + t.Errorf("Client CA Pool mismatch (-want +got):\n%s", diff) + } + }) + } +} + func TestNewRequestHTTPClient_SetServerName(t *testing.T) { tests := []string{ - "", "localhost", "127.0.0.1", "[::1]", "example.com", " a silly string ", + "[::1]", + "localhost", + "127.0.0.1", + "example.com", + " a silly string ", } - for _, tt := range tests { + for _, tc := range tests { + tt := tc // safer when using t.Parallel() testname := fmt.Sprintf("%v", tt) t.Run(testname, func(t *testing.T) { t.Parallel() c := NewRequestHTTPClient() - c.SetServerName(tt) + + _, err := c.SetServerName(tt) + if err != nil { + t.Fatal(err) + } // Extract the transport via type assertion transport, ok := c.client.Transport.(*http.Transport) @@ -216,18 +430,136 @@ func TestNewRequestHTTPClient_SetServerName(t *testing.T) { t.Fatalf("expected *http.Transport, got %T", c.client.Transport) } - if transport.TLSClientConfig == nil { - t.Fatal("TLSClientConfig is nil") - } + assert.NotNil(t, + transport.TLSClientConfig, + "check TLSClientConfig not nil", + ) - if transport.TLSClientConfig.ServerName != tt { - t.Errorf("expected ServerName to be %s, got %s", - tt, transport.TLSClientConfig.ServerName) + assert.Equal(t, + tt, + transport.TLSClientConfig.ServerName, + "check ServerName in TLSClientConfig", + ) + }) + } +} + +func TestNewRequestHTTPClient_SetServerName_clientError(t *testing.T) { + t.Run("malformedClient", func(t *testing.T) { + t.Parallel() + + noTransportClient := http.Client{} + c := NewRequestHTTPClient() + + c.client = &noTransportClient + + _, err := c.SetServerName("localhost") + + require.Error(t, err) + }) +} + +func TestNewRequestHTTPClient_SetServerName_Error(t *testing.T) { + testsError := []struct { + desc string + serverName string + errMsg string + }{ + { + "empty serverName", + emptyString, + "serverName cannot be empty", + }, + { + "url as serverName", + "https://localhost", + "serverName should be a hostname, not a URL: https://localhost", + }, + } + + for _, tc := range testsError { + tt := tc // safer when using t.Parallel() + t.Run(tt.desc, func(t *testing.T) { + t.Parallel() + + c := NewRequestHTTPClient() + _, err := c.SetServerName(tt.serverName) + require.Error(t, err) + assert.Equal(t, tt.errMsg, err.Error()) + }) + } + + t.Run("Error: Nil HTTP client", func(t *testing.T) { + t.Parallel() + + var c RequestHTTPClient + + _, err := c.SetServerName("localhost") + require.Error(t, err) + assert.Equal(t, + "*RequestHTTPClient.client is nil. Use NewRequestHTTPClient to initialize", + err.Error(), + ) + }) +} + +func TestNewRequestHTTPClient_SetClientTimeout(t *testing.T) { + tests := []int{ + 3, 0, 50, + } + + for _, tc := range tests { + tt := tc // safer when using t.Parallel() + testname := fmt.Sprintf("%v", tt) + t.Run(testname, func(t *testing.T) { + t.Parallel() + + c := NewRequestHTTPClient() + + _, err := c.SetClientTimeout(tt) + require.NoError(t, err) + + var i any = c.client.Timeout + + duration, ok := i.(time.Duration) + if !ok { + t.Fatalf("expected time.Duration, got %T", c.client.Timeout) } + + assert.Equal(t, time.Duration(tt)*time.Second, duration) }) } } +func TestNewRequestHTTPClient_SetClientTimeout_Error(t *testing.T) { + t.Run("Negative Timeout", func(t *testing.T) { + t.Parallel() + + c := NewRequestHTTPClient() + timeout := -1 + + _, err := c.SetClientTimeout(timeout) + require.Error(t, err) + assert.Equal(t, "timeout value must be positive: -1 provided", err.Error()) + }) + + t.Run("Nil Timeout", func(t *testing.T) { + t.Parallel() + + var c RequestHTTPClient + + timeout := 10 + + _, err := c.SetClientTimeout(timeout) + require.Error(t, err) + assert.Equal( + t, + "*RequestHTTPClient.client is nil. Use NewRequestHTTPClient to initialize", + err.Error(), + ) + }) +} + func TestNewRequestHTTPClient_SetCaCertsPool(t *testing.T) { var emptyPool *x509.CertPool @@ -245,7 +577,8 @@ func TestNewRequestHTTPClient_SetCaCertsPool(t *testing.T) { {"System Cert Pool", defaultCertPool, defaultCertPool}, } - for _, tt := range tests { + for _, tc := range tests { + tt := tc // safer when using t.Parallel() t.Run(tt.testname, func(t *testing.T) { t.Parallel() @@ -269,9 +602,44 @@ func TestNewRequestHTTPClient_SetCaCertsPool(t *testing.T) { } } +func TestNewRequestHTTPClient_SetCaCertsPool_Error(t *testing.T) { + t.Run("nilClient", func(t *testing.T) { + t.Parallel() + + var c RequestHTTPClient + + _, err := c.SetCACertsPool(caCertPool) + + require.Error(t, err) + assert.Equal(t, + "*RequestHTTPClient.client is nil. Use NewRequestHTTPClient to initialize", + err.Error(), + ) + }) + + t.Run("malformedClient", func(t *testing.T) { + t.Parallel() + + incompleteClient := http.Client{} + c := NewRequestHTTPClient() + + c.client = &incompleteClient + + _, err := c.SetCACertsPool(caCertPool) + + require.Error(t, err) + + assert.Equal(t, + "expected *http.Transport, got ", + err.Error(), + ) + }) +} + func TestNewRequestHTTPClient_SetInsecureSkipVerify_struct(t *testing.T) { tests := []bool{true, false} - for _, tt := range tests { + for _, tc := range tests { + tt := tc // safer when using t.Parallel() testname := fmt.Sprintf("%v", tt) t.Run(testname, func(t *testing.T) { @@ -300,7 +668,8 @@ func TestNewRequestHTTPClient_SetInsecureSkipVerify_struct(t *testing.T) { func TestNewRequestHTTPClient_SetInsecureSkipVerify_tlsServer(t *testing.T) { tests := []bool{true, false} - for _, tt := range tests { + for _, tc := range tests { + tt := tc // safer when using t.Parallel() testname := fmt.Sprintf("%v", tt) t.Run(testname, func(t *testing.T) { @@ -322,13 +691,47 @@ func TestNewRequestHTTPClient_SetInsecureSkipVerify_tlsServer(t *testing.T) { } if tt { - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) } }) } } +func TestNewRequestHTTPClient_SetInsecureSkipVerify_Error(t *testing.T) { + t.Run("nilClient", func(t *testing.T) { + t.Parallel() + + var c RequestHTTPClient + + _, err := c.SetInsecureSkipVerify(true) + + require.Error(t, err) + assert.Equal(t, + "*RequestHTTPClient.client is nil. Use NewRequestHTTPClient to initialize", + err.Error(), + ) + }) + + t.Run("malformedClient", func(t *testing.T) { + t.Parallel() + + incompleteClient := http.Client{} + c := NewRequestHTTPClient() + + c.client = &incompleteClient + + _, err := c.SetInsecureSkipVerify(true) + + require.Error(t, err) + + assert.Equal(t, + "expected *http.Transport, got ", + err.Error(), + ) + }) +} + func TestNewRequestHTTPClient_SetMethod(t *testing.T) { tests := []struct { got string @@ -348,7 +751,8 @@ func TestNewRequestHTTPClient_SetMethod(t *testing.T) { {"post", http.MethodPost}, } - for _, tt := range tests { + for _, tc := range tests { + tt := tc // safer when using t.Parallel() testname := fmt.Sprintf("%v", tt.got) t.Run(testname, func(t *testing.T) { @@ -374,7 +778,8 @@ func TestRequestHTTPClient_SetTransportOverride_transportAddress_struc(t *testin {"https://example.com", "example.com:443"}, } - for _, tt := range tests { + for _, tc := range tests { + tt := tc // safer when using t.Parallel() testname := fmt.Sprintf("%v", tt.got) t.Run(testname, func(t *testing.T) { t.Parallel() @@ -389,6 +794,39 @@ func TestRequestHTTPClient_SetTransportOverride_transportAddress_struc(t *testin } } +func TestRequestHTTPClient_SetTransportOverride_Error(t *testing.T) { + t.Run("nilClient", func(t *testing.T) { + t.Parallel() + + var c RequestHTTPClient + + _, err := c.SetTransportOverride("http://localhost") + require.Error(t, err) + assert.Equal(t, + "*RequestHTTPClient.client is nil. Use NewRequestHTTPClient to initialize", + err.Error(), + ) + }) + + t.Run("malformedClient", func(t *testing.T) { + t.Parallel() + + incompleteClient := http.Client{} + c := NewRequestHTTPClient() + + c.client = &incompleteClient + + _, err := c.SetTransportOverride("http://localhost") + + require.Error(t, err) + + assert.Equal(t, + "expected *http.Transport, got ", + err.Error(), + ) + }) +} + // Test SetTransportOverride method for RequestHTTPClient. // Use case: we want our client to redirect HTTPS request meant for https://hostname to a // third party proxy. This in order to test the settings of that proxy before pointing to it the DNS @@ -411,7 +849,8 @@ func TestRequestHTTPClient_SetTransportOverride_transportAddress_server(t *testi }, } - for _, tt := range tests { + for _, tc := range tests { + tt := tc // safer when using t.Parallel() testname := fmt.Sprintf("%v", tt.trasportURL) t.Run(testname, func(t *testing.T) { t.Parallel() @@ -507,7 +946,8 @@ func TestRequestHTTPClient_SetProxyProtocolV2_server(t *testing.T) { }, } - for _, tt := range tests { + for _, tc := range tests { + tt := tc // safer when using t.Parallel() t.Run(tt.testname, func(t *testing.T) { t.Parallel() @@ -544,7 +984,7 @@ func TestRequestHTTPClient_SetProxyProtocolV2_server(t *testing.T) { c := NewRequestHTTPClient() c.SetTransportOverride(transportURL) - c.SetProxyProtocolV2(header) + c.SetProxyProtocolHeader(header) // Extract the transport via type assertion transport, ok := c.client.Transport.(*http.Transport) @@ -589,3 +1029,459 @@ func TestRequestHTTPClient_SetProxyProtocolV2_server(t *testing.T) { }) } } + +func TestRequestHTTPClient_SetProxyProtocolHeader_Error(t *testing.T) { + t.Run("nilClient", func(t *testing.T) { + t.Parallel() + + c := RequestHTTPClient{transportAddress: "127.0.0.1:443"} + header := proxyproto.Header{} + _, err := c.SetProxyProtocolHeader(header) + + require.Error(t, err) + assert.Equal(t, + "*RequestHTTPClient.client is nil. Use NewRequestHTTPClient to initialize", + err.Error(), + ) + }) + + t.Run("malformedClient", func(t *testing.T) { + t.Parallel() + + incompleteClient := http.Client{} + c := NewRequestHTTPClient() + c.transportAddress = "127.0.0.1:443" + c.client = &incompleteClient + + header := proxyproto.Header{} + _, err := c.SetProxyProtocolHeader(header) + require.Error(t, err) + + assert.Equal(t, + "expected *http.Transport, got ", + err.Error(), + ) + }) + + t.Run("noTransportAddress", func(t *testing.T) { + t.Parallel() + + incompleteClient := http.Client{} + c := NewRequestHTTPClient() + c.client = &incompleteClient + + header := proxyproto.Header{} + _, err := c.SetProxyProtocolHeader(header) + require.Error(t, err) + + assert.Equal(t, + "SetProxyProtocolHeader failed: transportOverrideURL not set", + err.Error(), + ) + }) +} + +func TestPrintCmd(t *testing.T) { + tests := []bool{true, false} + + for _, tc := range tests { + tt := tc // safer when using t.Parallel() + testname := fmt.Sprintf("%v", tt) + t.Run(testname, func(t *testing.T) { + t.Parallel() + + buffer := bytes.Buffer{} + r := RequestsMetaConfig{RequestVerbose: tt} + r.PrintCmd(&buffer) + + got := buffer.String() + if tt { + assert.Contains(t, got, + "Requests", + "check PrintCmd when verbose", + ) + } else { + assert.Empty(t, + got, + "check empty outputs from PrintCmd when not verbose", + ) + } + }) + } +} + +func TestPrintResponseDebug(t *testing.T) { + tests := []struct { + desc string + srvAddr string + verbose bool + outputs []string + }{ + { + desc: "verboseTrue", + srvAddr: "localhost:46010", + verbose: true, + outputs: []string{ + "Requested url:", + "Response dump:", + "DemoHTTPSServer Handler - client output", + "TLS:", + "CipherSuite:", + }, + }, + { + desc: "verboseFalse", + srvAddr: "localhost:46011", + verbose: false, + outputs: []string{emptyString}, + }, + } + + for _, tc := range tests { + tt := tc // safer when using t.Parallel() + t.Run(tt.desc, func(t *testing.T) { + t.Parallel() + + httpSrvData := demoHttpServerData{ + serverAddr: tt.srvAddr, + proxyprotoEnabled: false, + serverName: "localhost", + } + + ts, err := NewHTTPSTestServer(httpSrvData) + if err != nil { + t.Fatal(err) + } + defer ts.Close() + + tr := &http.Transport{TLSClientConfig: &tls.Config{ + RootCAs: caCertPool, + }} + + client := &http.Client{Transport: tr} + + res, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + rc := RequestConfig{ResponseDebug: tt.verbose} + buffer := bytes.Buffer{} + rc.PrintResponseDebug(&buffer, res) + + got := buffer.String() + fmt.Printf("got:\n%s\n", got) + + if !tt.verbose && len(got) == 0 { + assert.Empty(t, + buffer.Bytes(), + "check PrintResponseDebug with verbose False", + ) + } + + if tt.verbose { + for _, output := range tt.outputs { + assert.True(t, + bytes.Contains(buffer.Bytes(), []byte(output)), + "check PrintResponseDebug contains: %s", output, + ) + } + } + }) + } +} + +func TestPrintResponseDebug_Error(t *testing.T) { + t.Run("NilResponse", func(t *testing.T) { + rc := RequestConfig{ResponseDebug: true} + buffer := bytes.Buffer{} + rc.PrintResponseDebug(&buffer, nil) + + got := buffer.String() + assert.Empty(t, + got, + "output should be empty when Response is nil", + ) + }) + + // TODO: trigger error on malformed Response + // + // t.Run("MalformedResponse", func(t *testing.T) { + // httpTestHeader := http.Header{} + // httpTestHeader.Add("user-agent", "go-test") + // + // httpTestRequest := http.Request{ + // // Method: http.MethodGet, + // // URL: &url.URL{Scheme: "https", Host: "localhost"}, + // ContentLength: 300, + // Body: nil, + // } + // + // response := http.Response{ + // // Status: "200 OK", + // // StatusCode: 200, + // // Header: httpTestHeader, + // Request: &httpTestRequest, + // } + // rc := RequestConfig{ResponseDebug: true} + // buffer := bytes.Buffer{} + // rc.PrintResponseDebug(&buffer, &response) + // + // got := buffer.String() + // fmt.Printf("got:\n%s\n", got) + // assert.Contains(t, + // got, + // "Warning: failed to dump response:", + // "check PrintResponseDebug: MalformedResponse", + // ) + // }) +} + +func TestPrintResponseDebug_nonTLS(t *testing.T) { + t.Run("non-TLS", func(t *testing.T) { + respURL := url.URL{Scheme: "http", Host: "localhost"} + req := http.Request{URL: &respURL} + resp := http.Response{ + StatusCode: 200, + Request: &req, + } + rc := RequestConfig{ResponseDebug: true} + buffer := bytes.Buffer{} + rc.PrintResponseDebug(&buffer, &resp) + + assert.True(t, + bytes.Contains(buffer.Bytes(), []byte("TLS: Not available")), + "check non-TLS connection", + ) + }) +} + +func TestPrintRequestDebug(t *testing.T) { + httpTestHeader := http.Header{} + httpTestHeader.Add("user-agent", "go-test") + requestTest := http.Request{ + Method: http.MethodGet, + URL: &url.URL{Scheme: "https", Host: "localhost"}, + Header: httpTestHeader, + } + + requestTestIncomplete := http.Request{ + Method: http.MethodGet, + URL: &url.URL{Scheme: "https", Host: "localhost"}, + } + + var requestTestNilPointer *http.Request + + expectedOutput := "Requesting url: https://localhost\nRequest dump:\nGET / " + expectedOutput += "HTTP/1.1\r\nHost: localhost\r\nUser-Agent: go-test\r\nAccept-Encoding: " + expectedOutput += "gzip\r\n\r\n\n" + + expectedOutputIncomplete := "Warning: failed to dump request: http: nil Request.Header\n" + + tests := []struct { + desc string + verbose bool + request *http.Request + output string + expectErr bool + }{ + { + desc: "verboseTrue", + verbose: true, + request: &requestTest, + output: expectedOutput, + }, + + { + desc: "verboseFalse", + verbose: false, + request: &requestTest, + output: emptyString, + }, + { + desc: "nilRequestError", + verbose: true, + request: requestTestNilPointer, + output: emptyString, + expectErr: true, + }, + { + desc: "incompleteRequestError", + verbose: true, + request: &requestTestIncomplete, + output: expectedOutputIncomplete, + expectErr: true, + }, + } + + for _, tc := range tests { + tt := tc // safer when using t.Parallel() + t.Run(tt.desc, func(t *testing.T) { + t.Parallel() + + buffer := bytes.Buffer{} + r := RequestConfig{RequestDebug: tt.verbose} + + err := r.PrintRequestDebug(&buffer, tt.request) + + if tt.expectErr { + require.Error(t, err, "PrintRequestDebug should return an error") + } else { + require.NoError(t, err, "PrintRequestDebug should not return an error") + } + + got := buffer.String() + want := tt.output + + assert.Equal(t, want, got, "check PrintRequestDebug") + }) + } +} + +func TestProcessHTTPRequestsByHost(t *testing.T) { + tests := []struct { + srvAddr string + reqConf RequestConfig + pool *x509.CertPool + verbose bool + respStatusCode int + errMsg string + }{ + { + srvAddr: "localhost:46001", + reqConf: RequestConfig{ + Name: "StatusOK", + TransportOverrideURL: "https://localhost:46001", + UserAgent: "test-ua", + RequestHeaders: []RequestHeader{ + {Key: "testKey", Value: "testValue"}, + {Key: "testKey2", Value: "testValue2"}, + }, + Hosts: []Host{ + {Name: "example.com"}, + }, + }, + pool: caCertPool, + verbose: false, + respStatusCode: http.StatusOK, + }, + + { + srvAddr: "localhost:46002", + reqConf: RequestConfig{ + Name: "invalidServerName", + TransportOverrideURL: "https://localhost:46002", + Hosts: []Host{ + {Name: "localhost"}, + }, + }, + pool: caCertPool, + verbose: false, + respStatusCode: 0, + errMsg: "Get \"https://localhost\": tls: failed to verify certificate: x509: certificate is valid for example.com, example.net, example.de, not localhost", + }, + + { + srvAddr: "localhost:46003", + reqConf: RequestConfig{ + Name: "bodyRex", + ResponseBodyMatchRegexp: "DemoHTTPSServer Handler - client output", + PrintResponseBody: true, + TransportOverrideURL: "https://localhost:46003", + Hosts: []Host{ + {Name: "example.com"}, + }, + }, + pool: caCertPool, + verbose: true, + respStatusCode: http.StatusOK, + }, + } + + for _, tc := range tests { + tt := tc // safer when using t.Parallel() + t.Run(tt.reqConf.Name, func(t *testing.T) { + t.Parallel() + + httpSrvData := demoHttpServerData{ + serverAddr: tt.srvAddr, + proxyprotoEnabled: false, + serverName: "localhost", + } + + ts, err := NewHTTPSTestServer(httpSrvData) + if err != nil { + t.Fatal(err) + } + defer ts.Close() + + respList, err := processHTTPRequestsByHost( + tt.reqConf, + tt.pool, + tt.verbose, + ) + if err != nil { + t.Error(err) + } + + for _, r := range respList { + fmt.Printf("resp type: %T\n", r) + + assert.Equal(t, + tt.srvAddr, + r.TransportAddress, + "check TransportAddress", + ) + + if tt.respStatusCode == 0 { + assert.Equal(t, + tt.errMsg, + r.Error.Error(), + "check Response Error", + ) + } + + // if expecting and error from the request do not + // check values from the response + if tt.respStatusCode != 0 { + require.NoError(t, + r.Error, + "check NoError in ResponseData", + ) + + ua := httpUserAgent + + if tt.reqConf.UserAgent != emptyString { + ua = tt.reqConf.UserAgent + } + + assert.Equal(t, + ua, + r.Response.Request.Header.Get("user-agent"), + "check UserAgent", + ) + + assert.Equal(t, + len(tt.reqConf.ResponseBodyMatchRegexp) > 0, + r.ResponseBodyRegexpMatched, + "check body rex match", + ) + + assert.Equal(t, + tt.respStatusCode, + r.Response.StatusCode, + "check StatusCode", + ) + + for _, headers := range tt.reqConf.RequestHeaders { + assert.Equal(t, + headers.Value, + r.Response.Request.Header.Get(headers.Key), + "check RequestHeaders Key", + ) + } + } + } + }) + } +}