Skip to content

Commit 273ebe3

Browse files
fix: add error handling for file writes and HTTP response body cleanup
1 parent cdce892 commit 273ebe3

10 files changed

Lines changed: 158 additions & 83 deletions

File tree

cmd/protocol.go

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -499,10 +499,16 @@ func saveProtocolResults(findings []types.Finding, output string) error {
499499
}
500500
}()
501501

502-
// Create report
503-
fmt.Fprintf(file, "Protocol Security Scan Report\n")
504-
fmt.Fprintf(file, "Generated: %s\n", time.Now().Format(time.RFC3339))
505-
fmt.Fprintf(file, "=====================================\n\n")
502+
// Create report - check all write errors
503+
if _, err := fmt.Fprintf(file, "Protocol Security Scan Report\n"); err != nil {
504+
return fmt.Errorf("failed to write report header: %w", err)
505+
}
506+
if _, err := fmt.Fprintf(file, "Generated: %s\n", time.Now().Format(time.RFC3339)); err != nil {
507+
return fmt.Errorf("failed to write report timestamp: %w", err)
508+
}
509+
if _, err := fmt.Fprintf(file, "=====================================\n\n"); err != nil {
510+
return fmt.Errorf("failed to write report separator: %w", err)
511+
}
506512

507513
// Group by target
508514
byTarget := make(map[string][]types.Finding)
@@ -517,28 +523,48 @@ func saveProtocolResults(findings []types.Finding, output string) error {
517523
}
518524

519525
for target, targetFindings := range byTarget {
520-
fmt.Fprintf(file, "Target: %s\n", target)
521-
fmt.Fprintf(file, "-------------------\n")
526+
if _, err := fmt.Fprintf(file, "Target: %s\n", target); err != nil {
527+
return fmt.Errorf("failed to write target header: %w", err)
528+
}
529+
if _, err := fmt.Fprintf(file, "-------------------\n"); err != nil {
530+
return fmt.Errorf("failed to write target separator: %w", err)
531+
}
522532

523533
for _, finding := range targetFindings {
524-
fmt.Fprintf(file, "\nTitle: %s\n", finding.Title)
525-
fmt.Fprintf(file, "Type: %s\n", finding.Type)
526-
fmt.Fprintf(file, "Severity: %s\n", finding.Severity)
527-
fmt.Fprintf(file, "Description: %s\n", finding.Description)
534+
if _, err := fmt.Fprintf(file, "\nTitle: %s\n", finding.Title); err != nil {
535+
return fmt.Errorf("failed to write finding title: %w", err)
536+
}
537+
if _, err := fmt.Fprintf(file, "Type: %s\n", finding.Type); err != nil {
538+
return fmt.Errorf("failed to write finding type: %w", err)
539+
}
540+
if _, err := fmt.Fprintf(file, "Severity: %s\n", finding.Severity); err != nil {
541+
return fmt.Errorf("failed to write finding severity: %w", err)
542+
}
543+
if _, err := fmt.Fprintf(file, "Description: %s\n", finding.Description); err != nil {
544+
return fmt.Errorf("failed to write finding description: %w", err)
545+
}
528546

529547
if finding.Solution != "" {
530-
fmt.Fprintf(file, "Remediation: %s\n", finding.Solution)
548+
if _, err := fmt.Fprintf(file, "Remediation: %s\n", finding.Solution); err != nil {
549+
return fmt.Errorf("failed to write finding remediation: %w", err)
550+
}
531551
}
532552

533553
if len(finding.References) > 0 {
534-
fmt.Fprintf(file, "References:\n")
554+
if _, err := fmt.Fprintf(file, "References:\n"); err != nil {
555+
return fmt.Errorf("failed to write references header: %w", err)
556+
}
535557
for _, ref := range finding.References {
536-
fmt.Fprintf(file, " - %s\n", ref)
558+
if _, err := fmt.Fprintf(file, " - %s\n", ref); err != nil {
559+
return fmt.Errorf("failed to write reference: %w", err)
560+
}
537561
}
538562
}
539563
}
540564

541-
fmt.Fprintf(file, "\n=====================================\n\n")
565+
if _, err := fmt.Fprintf(file, "\n=====================================\n\n"); err != nil {
566+
return fmt.Errorf("failed to write section separator: %w", err)
567+
}
542568
}
543569

544570
return nil

cmd/root.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2252,7 +2252,10 @@ func runMainDiscovery(cmd *cobra.Command, args []string, log *logger.Logger, db
22522252
startTime := time.Now()
22532253

22542254
// Set bug bounty mode and reduce log noise
2255-
os.Setenv("SHELLS_BUG_BOUNTY_MODE", "true")
2255+
if err := os.Setenv("SHELLS_BUG_BOUNTY_MODE", "true"); err != nil {
2256+
fmt.Fprintf(os.Stderr, "Warning: failed to set bug bounty mode: %v\n", err)
2257+
fmt.Fprintf(os.Stderr, "Impact: Some features may not operate in optimized mode\n")
2258+
}
22562259

22572260
// Force clean console output for bug bounty mode
22582261
viper.Set("log.format", "console")

cmd/scan.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strings"
1111
"time"
1212

13+
"github.com/CodeMonkeyCybersecurity/shells/internal/httpclient"
1314
"github.com/CodeMonkeyCybersecurity/shells/pkg/security"
1415
"github.com/CodeMonkeyCybersecurity/shells/pkg/types"
1516
"github.com/google/uuid"
@@ -622,7 +623,7 @@ func executeDirScan(target string, options map[string]string, scanID string) ([]
622623
if err != nil {
623624
continue
624625
}
625-
resp.Body.Close()
626+
httpclient.CloseBody(resp)
626627

627628
if resp.StatusCode < 400 {
628629
log.Infow("Directory found", "path", path, "status", resp.StatusCode)
@@ -645,7 +646,7 @@ func executeSCIMScan(target string, options map[string]string, scanID string) ([
645646
if err != nil {
646647
continue
647648
}
648-
resp.Body.Close()
649+
httpclient.CloseBody(resp)
649650

650651
if resp.StatusCode < 400 {
651652
log.Infow("SCIM endpoint found", "path", path, "status", resp.StatusCode)
@@ -684,7 +685,7 @@ func executeOAuth2Scan(target string, options map[string]string, scanID string)
684685
if err != nil {
685686
continue
686687
}
687-
resp.Body.Close()
688+
httpclient.CloseBody(resp)
688689

689690
if resp.StatusCode < 400 {
690691
log.Infow("OAuth2 endpoint found", "path", path, "status", resp.StatusCode)
@@ -731,7 +732,7 @@ func executeJSScan(target string, options map[string]string, scanID string) ([]t
731732
if err != nil {
732733
continue
733734
}
734-
resp.Body.Close()
735+
httpclient.CloseBody(resp)
735736

736737
contentType := resp.Header.Get("Content-Type")
737738
if resp.StatusCode < 400 && strings.Contains(contentType, "javascript") {
@@ -754,7 +755,7 @@ func executeAPISecan(target string, options map[string]string, scanID string) ([
754755
if err != nil {
755756
continue
756757
}
757-
resp.Body.Close()
758+
httpclient.CloseBody(resp)
758759

759760
if resp.StatusCode < 400 {
760761
log.Infow("API endpoint found", "path", path, "status", resp.StatusCode)
@@ -921,7 +922,7 @@ func legacyDirScan(target string, options map[string]string) error {
921922
if err != nil {
922923
continue
923924
}
924-
resp.Body.Close()
925+
httpclient.CloseBody(resp)
925926

926927
if resp.StatusCode < 400 {
927928
log.Infow("Directory found", "path", path, "status", resp.StatusCode)
@@ -944,7 +945,7 @@ func legacySCIMScan(target string, options map[string]string) error {
944945
if err != nil {
945946
continue
946947
}
947-
resp.Body.Close()
948+
httpclient.CloseBody(resp)
948949

949950
if resp.StatusCode < 400 {
950951
log.Infow("SCIM endpoint found", "path", path, "status", resp.StatusCode)

cmd/scan_aws.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,9 @@ func runAWSValidate(cmd *cobra.Command, args []string) error {
274274
fmt.Printf("Checking AWS credentials (profile: %s)... ", awsProfile)
275275

276276
// Set AWS profile environment variable
277-
os.Setenv("AWS_PROFILE", awsProfile)
277+
if err := os.Setenv("AWS_PROFILE", awsProfile); err != nil {
278+
return fmt.Errorf("failed to set AWS_PROFILE environment variable: %w", err)
279+
}
278280

279281
// Test AWS credentials by calling STS GetCallerIdentity
280282
cmd := exec.Command("aws", "sts", "get-caller-identity", "--profile", awsProfile)

internal/httpclient/factory.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ package httpclient
44
import (
55
"context"
66
"fmt"
7+
"io"
78
"net"
89
"net/http"
10+
"os"
911
"time"
1012
)
1113

@@ -238,3 +240,38 @@ func DoWithContext(ctx context.Context, client *http.Client, req *http.Request)
238240

239241
return resp, nil
240242
}
243+
244+
// CloseBody safely closes an HTTP response body and logs any errors.
245+
// This is critical for connection pool health - unclosed bodies leak HTTP connections.
246+
//
247+
// Usage:
248+
// defer httpclient.CloseBody(resp)
249+
//
250+
// Philosophy alignment: Transparent error handling (human-centric principle)
251+
func CloseBody(resp *http.Response) {
252+
if resp == nil || resp.Body == nil {
253+
return
254+
}
255+
256+
// Drain body before closing to enable connection reuse
257+
// HTTP/1.1 connections can only be reused if body is fully read
258+
_, _ = io.Copy(io.Discard, resp.Body)
259+
260+
if err := resp.Body.Close(); err != nil {
261+
fmt.Fprintf(os.Stderr, "Warning: failed to close HTTP response body: %v\n", err)
262+
fmt.Fprintf(os.Stderr, "Impact: HTTP connection may leak (pool exhaustion possible)\n")
263+
}
264+
}
265+
266+
// MustCloseBody is like CloseBody but panics on error (use only in tests)
267+
func MustCloseBody(resp *http.Response) {
268+
if resp == nil || resp.Body == nil {
269+
return
270+
}
271+
272+
_, _ = io.Copy(io.Discard, resp.Body)
273+
274+
if err := resp.Body.Close(); err != nil {
275+
panic(fmt.Sprintf("failed to close response body: %v", err))
276+
}
277+
}

internal/plugins/api/graphql.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"time"
1212

1313
"github.com/CodeMonkeyCybersecurity/shells/internal/core"
14+
"github.com/CodeMonkeyCybersecurity/shells/internal/httpclient"
1415
"github.com/CodeMonkeyCybersecurity/shells/pkg/types"
1516
)
1617

@@ -311,7 +312,7 @@ func (s *graphQLScanner) isGraphQLEndpoint(ctx context.Context, url string) bool
311312
if err != nil {
312313
return false
313314
}
314-
defer resp.Body.Close()
315+
defer httpclient.CloseBody(resp)
315316

316317
body, err := io.ReadAll(resp.Body)
317318
if err != nil {
@@ -474,7 +475,7 @@ func (s *graphQLScanner) testIntrospection(ctx context.Context, endpoint string,
474475
if err != nil {
475476
return findings
476477
}
477-
defer resp.Body.Close()
478+
defer httpclient.CloseBody(resp)
478479

479480
body, err := io.ReadAll(resp.Body)
480481
if err != nil {
@@ -833,7 +834,7 @@ func (s *graphQLScanner) testBatchingAttacks(ctx context.Context, endpoint strin
833834
if err != nil {
834835
return findings
835836
}
836-
defer resp.Body.Close()
837+
defer httpclient.CloseBody(resp)
837838

838839
body, err := io.ReadAll(resp.Body)
839840
if err != nil {
@@ -937,7 +938,7 @@ func (s *graphQLScanner) testQueryComplexity(ctx context.Context, endpoint strin
937938
if err != nil {
938939
continue
939940
}
940-
resp.Body.Close()
941+
httpclient.CloseBody(resp)
941942

942943
// If complex queries execute without limits, it's a finding
943944
if resp.StatusCode == 200 && duration > 3*time.Second {
@@ -1018,7 +1019,7 @@ func (s *graphQLScanner) testDepthLimits(ctx context.Context, endpoint string, o
10181019
if err != nil {
10191020
return findings
10201021
}
1021-
defer resp.Body.Close()
1022+
defer httpclient.CloseBody(resp)
10221023

10231024
if resp.StatusCode == 200 {
10241025
severity := types.SeverityMedium
@@ -1098,7 +1099,7 @@ func (s *graphQLScanner) testFieldSuggestion(ctx context.Context, endpoint strin
10981099
if err != nil {
10991100
continue
11001101
}
1101-
defer resp.Body.Close()
1102+
defer httpclient.CloseBody(resp)
11021103

11031104
body, err := io.ReadAll(resp.Body)
11041105
if err != nil {
@@ -1177,7 +1178,7 @@ func (s *graphQLScanner) testAuthorizationBypass(ctx context.Context, endpoint s
11771178
if err != nil {
11781179
continue
11791180
}
1180-
defer resp.Body.Close()
1181+
defer httpclient.CloseBody(resp)
11811182

11821183
body, err := io.ReadAll(resp.Body)
11831184
if err != nil {
@@ -1329,7 +1330,7 @@ func (s *graphQLScanner) testInjectionQuery(ctx context.Context, endpoint, query
13291330
if err != nil {
13301331
return false
13311332
}
1332-
defer resp.Body.Close()
1333+
defer httpclient.CloseBody(resp)
13331334

13341335
body, err := io.ReadAll(resp.Body)
13351336
if err != nil {
@@ -1412,7 +1413,7 @@ func (s *graphQLScanner) testInformationDisclosure(ctx context.Context, endpoint
14121413
if err != nil {
14131414
continue
14141415
}
1415-
defer resp.Body.Close()
1416+
defer httpclient.CloseBody(resp)
14161417

14171418
body, err := io.ReadAll(resp.Body)
14181419
if err != nil {
@@ -1494,7 +1495,7 @@ func (s *graphQLScanner) testCSRFProtection(ctx context.Context, endpoint string
14941495
if err != nil {
14951496
continue
14961497
}
1497-
defer resp.Body.Close()
1498+
defer httpclient.CloseBody(resp)
14981499

14991500
if resp.StatusCode == http.StatusOK {
15001501
var result GraphQLResponse
@@ -1569,7 +1570,7 @@ func (s *graphQLScanner) testRateLimiting(ctx context.Context, endpoint string,
15691570
if resp.StatusCode == http.StatusOK {
15701571
successCount++
15711572
}
1572-
resp.Body.Close()
1573+
httpclient.CloseBody(resp)
15731574

15741575
// Don't delay between requests to test rate limiting
15751576
}

0 commit comments

Comments
 (0)