Skip to content

Commit 643373b

Browse files
authored
fix: more client optimizations (#8)
1 parent d2e1fde commit 643373b

2 files changed

Lines changed: 211 additions & 6 deletions

File tree

client/client.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package client
33
import (
44
"context"
55
"fmt"
6+
"io"
67
"net"
78
"sync"
89
"time"
@@ -171,18 +172,17 @@ func (c *Client) GetBytes(ctx context.Context, hashBytes []byte) ([]byte, error)
171172
return nil, fmt.Errorf("failed to write hash: %w", err)
172173
}
173174

174-
// Read response
175+
// Read response - exactly 192 bytes
175176
response := make([]byte, needle.NeedleLength)
176-
n, err := conn.Read(response)
177+
n, err := io.ReadFull(conn, response)
177178
if err != nil {
178179
c.connPool.MarkBad(conn)
180+
if err == io.ErrUnexpectedEOF {
181+
return nil, fmt.Errorf("incomplete response: expected %d bytes, got %d", needle.NeedleLength, n)
182+
}
179183
return nil, fmt.Errorf("failed to read response: %w", err)
180184
}
181185

182-
if n != needle.NeedleLength {
183-
return nil, fmt.Errorf("invalid response length: expected %d bytes, got %d", needle.NeedleLength, n)
184-
}
185-
186186
return response, nil
187187
}
188188

cmd/simple-test/main.go

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"flag"
7+
"fmt"
8+
"sync"
9+
"sync/atomic"
10+
"time"
11+
12+
"github.com/nomasters/haystack/client"
13+
"github.com/nomasters/haystack/needle"
14+
)
15+
16+
func main() {
17+
var (
18+
endpoint = flag.String("endpoint", "haystack-example-trunk.fly.dev:1337", "Haystack server endpoint")
19+
messages = flag.Int("messages", 100, "Number of messages to test")
20+
writers = flag.Int("writers", 20, "Number of concurrent writers")
21+
readers = flag.Int("readers", 20, "Number of concurrent readers")
22+
getOps = flag.Int("get-ops", 0, "Total number of GET operations (0 = same as messages)")
23+
)
24+
flag.Parse()
25+
26+
// Determine total GET operations
27+
totalGetOps := *getOps
28+
if totalGetOps == 0 {
29+
totalGetOps = *messages
30+
}
31+
32+
fmt.Printf("🧪 Haystack Simple Test\n")
33+
fmt.Printf("=======================\n")
34+
fmt.Printf("Endpoint: %s\n", *endpoint)
35+
fmt.Printf("Messages: %d\n", *messages)
36+
fmt.Printf("Writers: %d\n", *writers)
37+
fmt.Printf("Readers: %d\n", *readers)
38+
fmt.Printf("GET Ops: %d\n", totalGetOps)
39+
fmt.Printf("\n")
40+
41+
// Create client with reasonable pool size
42+
cfg := &client.Config{
43+
Address: *endpoint,
44+
MaxConnections: 300,
45+
ReadTimeout: 10 * time.Second,
46+
WriteTimeout: 10 * time.Second,
47+
}
48+
49+
c, err := client.New(cfg)
50+
if err != nil {
51+
panic(fmt.Sprintf("Failed to create client: %v", err))
52+
}
53+
defer c.Close()
54+
55+
ctx := context.Background()
56+
57+
// Generate test needles
58+
fmt.Printf("Generating %d test needles... ", *messages)
59+
needles := make([]*needle.Needle, *messages)
60+
for i := 0; i < *messages; i++ {
61+
data := make([]byte, 160)
62+
// Put message number in first bytes for verification
63+
data[0] = byte(i >> 8)
64+
data[1] = byte(i)
65+
// Fill rest with random data
66+
rand.Read(data[2:])
67+
68+
n, err := needle.New(data)
69+
if err != nil {
70+
panic(fmt.Sprintf("Failed to create needle: %v", err))
71+
}
72+
needles[i] = n
73+
}
74+
fmt.Println("✅")
75+
76+
// Phase 1: Write all needles
77+
fmt.Printf("\n📝 Phase 1: Writing %d needles with %d writers\n", *messages, *writers)
78+
fmt.Println("----------------------------------------")
79+
80+
var writeWg sync.WaitGroup
81+
var writeSuccess atomic.Int32
82+
var writeErrors atomic.Int32
83+
84+
messagesPerWriter := *messages / *writers
85+
extraMessages := *messages % *writers
86+
87+
startWrite := time.Now()
88+
89+
for w := 0; w < *writers; w++ {
90+
writeWg.Add(1)
91+
start := w * messagesPerWriter
92+
end := start + messagesPerWriter
93+
94+
// Last writer handles extra messages
95+
if w == *writers-1 {
96+
end += extraMessages
97+
}
98+
99+
go func(workerID int, startIdx, endIdx int) {
100+
defer writeWg.Done()
101+
102+
for i := startIdx; i < endIdx; i++ {
103+
if err := c.Set(ctx, needles[i]); err != nil {
104+
writeErrors.Add(1)
105+
fmt.Printf(" Writer %d: Failed needle %d: %v\n", workerID, i, err)
106+
} else {
107+
writeSuccess.Add(1)
108+
}
109+
}
110+
}(w, start, end)
111+
}
112+
113+
writeWg.Wait()
114+
writeDuration := time.Since(startWrite)
115+
116+
fmt.Printf("\nWrite Results:\n")
117+
fmt.Printf(" ✅ Success: %d\n", writeSuccess.Load())
118+
fmt.Printf(" ❌ Errors: %d\n", writeErrors.Load())
119+
fmt.Printf(" ⏱️ Duration: %v\n", writeDuration)
120+
fmt.Printf(" 📊 Throughput: %.2f ops/sec\n", float64(*messages)/writeDuration.Seconds())
121+
122+
// Wait a bit for data to settle
123+
fmt.Println("\nWaiting 2 second for data to settle...")
124+
time.Sleep(2 * time.Second)
125+
126+
// Phase 2: Read needles (sustained test)
127+
fmt.Printf("\n📖 Phase 2: Reading %d operations across %d needles with %d readers\n", totalGetOps, *messages, *readers)
128+
fmt.Println("----------------------------------------")
129+
130+
var readWg sync.WaitGroup
131+
var readSuccess atomic.Int32
132+
var readErrors atomic.Int32
133+
var dataMatches atomic.Int32
134+
135+
startRead := time.Now()
136+
137+
// Distribute total GET operations across readers
138+
opsPerReader := totalGetOps / *readers
139+
extraOps := totalGetOps % *readers
140+
141+
for r := 0; r < *readers; r++ {
142+
readWg.Add(1)
143+
readerOps := opsPerReader
144+
if r == *readers-1 {
145+
readerOps += extraOps
146+
}
147+
148+
go func(readerID int, numOps int) {
149+
defer readWg.Done()
150+
151+
for op := 0; op < numOps; op++ {
152+
// Round-robin through available needles
153+
needleIdx := op % *messages
154+
hash := needles[needleIdx].Hash()
155+
gotNeedle, err := c.Get(ctx, hash)
156+
157+
if err != nil {
158+
readErrors.Add(1)
159+
// Only log first few errors
160+
if readErrors.Load() <= 5 {
161+
fmt.Printf(" Reader %d: Failed op %d, needle %d (hash %x): %v\n", readerID, op, needleIdx, hash[:8], err)
162+
}
163+
} else {
164+
readSuccess.Add(1)
165+
166+
// Verify data matches
167+
originalPayload := needles[needleIdx].Payload()
168+
gotPayload := gotNeedle.Payload()
169+
170+
match := true
171+
for j := 0; j < len(originalPayload); j++ {
172+
if originalPayload[j] != gotPayload[j] {
173+
match = false
174+
break
175+
}
176+
}
177+
178+
if match {
179+
dataMatches.Add(1)
180+
} else {
181+
fmt.Printf(" Reader %d: Data mismatch for needle %d!\n", readerID, needleIdx)
182+
}
183+
}
184+
}
185+
}(r, readerOps)
186+
}
187+
188+
readWg.Wait()
189+
readDuration := time.Since(startRead)
190+
191+
fmt.Printf("\nRead Results:\n")
192+
fmt.Printf(" ✅ Success: %d/%d\n", readSuccess.Load(), totalGetOps)
193+
fmt.Printf(" ✅ Data matches: %d\n", dataMatches.Load())
194+
fmt.Printf(" ❌ Errors: %d\n", readErrors.Load())
195+
fmt.Printf(" ⏱️ Duration: %v\n", readDuration)
196+
fmt.Printf(" 📊 Throughput: %.2f ops/sec\n", float64(totalGetOps)/readDuration.Seconds())
197+
198+
// Summary
199+
fmt.Printf("\n========================================\n")
200+
fmt.Printf("📊 SUMMARY\n")
201+
fmt.Printf("========================================\n")
202+
fmt.Printf("Total write success rate: %.1f%%\n", 100.0*float64(writeSuccess.Load())/float64(*messages))
203+
fmt.Printf("Total read success rate: %.1f%%\n", 100.0*float64(readSuccess.Load())/float64(totalGetOps))
204+
fmt.Printf("Data integrity rate: %.1f%%\n", 100.0*float64(dataMatches.Load())/float64(totalGetOps))
205+
}

0 commit comments

Comments
 (0)