Skip to content

Commit c1bbf17

Browse files
authored
Merge pull request #708 from projectdiscovery/dwisiswant0/fix/httputil/racy-in-ResponseChain
fix(httputil): racy in `ResponseChain`
2 parents 332bd1c + 4a19d93 commit c1bbf17

3 files changed

Lines changed: 125 additions & 14 deletions

File tree

http/normalization.go

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"io"
77
"net/http"
88
"strings"
9+
"sync"
910

1011
"github.com/dsnet/compress/brotli"
1112
"github.com/klauspost/compress/gzip"
@@ -18,6 +19,41 @@ import (
1819
stringsutil "github.com/projectdiscovery/utils/strings"
1920
)
2021

22+
const (
23+
// DefaultChunkSize defines the default chunk size for reading response
24+
// bodies.
25+
//
26+
// Use [SetChunkSize] to adjust the size.
27+
DefaultChunkSize = 32 * 1024 // 32KB
28+
)
29+
30+
var (
31+
chunkSize = DefaultChunkSize
32+
chunkPool = sync.Pool{
33+
New: func() any {
34+
b := make([]byte, chunkSize)
35+
return &b
36+
},
37+
}
38+
)
39+
40+
// SetChunkSize sets the chunk size for reading response bodies.
41+
//
42+
// If size is less than or equal to zero, it resets to the default chunk size.
43+
func SetChunkSize(size int) {
44+
if size <= 0 {
45+
size = DefaultChunkSize
46+
}
47+
48+
chunkSize = size
49+
chunkPool = sync.Pool{
50+
New: func() any {
51+
b := make([]byte, chunkSize)
52+
return &b
53+
},
54+
}
55+
}
56+
2157
// limitedBuffer wraps [bytes.Buffer] to prevent capacity growth beyond maxCap.
2258
// This prevents bytes.Buffer.ReadFrom() from over-allocating when it doesn't
2359
// know the final size.
@@ -27,8 +63,9 @@ type limitedBuffer struct {
2763
}
2864

2965
func (lb *limitedBuffer) ReadFrom(r io.Reader) (n int64, err error) {
30-
const chunkSize = 32 * 1024 // 32KB chunks
31-
chunk := make([]byte, chunkSize)
66+
chunkPtr := chunkPool.Get().(*[]byte)
67+
defer chunkPool.Put(chunkPtr)
68+
chunk := *chunkPtr
3269

3370
for {
3471
available := lb.buf.Cap() - lb.buf.Len()

http/respChain.go

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -305,21 +305,23 @@ func (r *ResponseChain) FullResponse() *bytes.Buffer {
305305

306306
// FullResponseBytes returns the current response (headers+body) as byte slice.
307307
//
308-
// The returned slice is valid only until Close() is called.
309-
// Note: This creates a new buffer internally which is returned to the pool.
308+
// The returned slice is a copy and remains valid even after Close() is called.
310309
func (r *ResponseChain) FullResponseBytes() []byte {
311-
buf := r.FullResponse()
312-
defer putBuffer(buf)
310+
size := r.headers.Len() + r.body.Len()
311+
buf := make([]byte, size)
312+
313+
copy(buf, r.headers.Bytes())
314+
copy(buf[r.headers.Len():], r.body.Bytes())
313315

314-
return buf.Bytes()
316+
return buf
315317
}
316318

317319
// FullResponseString returns the current response as string in the chain.
318320
//
319-
// The returned string is valid only until Close() is called.
320-
// This is a zero-copy operation for performance.
321+
// The returned string is a copy and remains valid even after Close() is called.
322+
// This is a zero-copy operation from the byte slice.
321323
func (r *ResponseChain) FullResponseString() string {
322-
return conversion.String(r.FullResponse().Bytes())
324+
return conversion.String(r.FullResponseBytes())
323325
}
324326

325327
// previous updates response pointer to previous response
@@ -372,10 +374,15 @@ func (r *ResponseChain) Fill() error {
372374

373375
// Close the response chain and releases the buffers.
374376
func (r *ResponseChain) Close() {
375-
putBuffer(r.headers)
376-
putBuffer(r.body)
377-
r.headers = nil
378-
r.body = nil
377+
if r.headers != nil {
378+
putBuffer(r.headers)
379+
r.headers = nil
380+
}
381+
382+
if r.body != nil {
383+
putBuffer(r.body)
384+
r.body = nil
385+
}
379386
}
380387

381388
// Has returns true if the response chain has a response

http/respChain_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,3 +1115,70 @@ func TestResponseChain_BurstWithPoolExhaustion(t *testing.T) {
11151115
require.NoError(t, err)
11161116
}
11171117
}
1118+
1119+
func TestResponseChain_FullResponseBytes_Race(t *testing.T) {
1120+
resp := &http.Response{
1121+
StatusCode: 200,
1122+
Status: "200 OK",
1123+
}
1124+
chain := NewResponseChain(resp, 1024)
1125+
chain.headers.WriteString("Header: Value\r\n")
1126+
chain.body.WriteString("Body Content")
1127+
1128+
data := chain.FullResponseBytes()
1129+
initialContent := string(data)
1130+
1131+
// Trigger buffer reuse
1132+
// We need to get a buffer from the pool.
1133+
1134+
found := false
1135+
for i := 0; i < 100; i++ {
1136+
b := getBuffer()
1137+
b.WriteString("OVERWRITTEN_DATA_XXXXXXXXXXXXXXXX")
1138+
1139+
if string(data) != initialContent {
1140+
found = true
1141+
t.Logf("Iteration %d: Content changed to %q", i, string(data))
1142+
1143+
break
1144+
}
1145+
1146+
putBuffer(b)
1147+
}
1148+
1149+
if found {
1150+
t.Fatalf("Race detected! Content changed from %q", initialContent)
1151+
}
1152+
}
1153+
1154+
func TestResponseChain_Close_Idempotency(t *testing.T) {
1155+
resp := &http.Response{
1156+
StatusCode: 200,
1157+
}
1158+
rc := NewResponseChain(resp, 1024)
1159+
1160+
rc.Close()
1161+
1162+
defer func() {
1163+
if r := recover(); r != nil {
1164+
t.Errorf("Close() panicked on second call: %v", r)
1165+
}
1166+
}()
1167+
1168+
rc.Close()
1169+
}
1170+
1171+
func TestLimitedBuffer_Pool(t *testing.T) {
1172+
buf := getBuffer()
1173+
defer putBuffer(buf)
1174+
1175+
lb := &limitedBuffer{buf: buf, maxCap: 1024 * 1024}
1176+
data := bytes.Repeat([]byte("A"), 100*1024) // 100KB
1177+
r := bytes.NewReader(data)
1178+
1179+
n, err := lb.ReadFrom(r)
1180+
require.NoError(t, err)
1181+
require.Equal(t, int64(len(data)), n)
1182+
require.Equal(t, len(data), buf.Len())
1183+
require.Equal(t, data, buf.Bytes())
1184+
}

0 commit comments

Comments
 (0)