-
Notifications
You must be signed in to change notification settings - Fork 16
perf: parallelize EmbedBatch using goroutines and errgroup #107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
df59ed4
a957c34
b6c2825
9b2f394
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,9 +15,13 @@ import ( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "time" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "github.com/google/generative-ai-go/genai" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "golang.org/x/sync/errgroup" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "google.golang.org/api/option" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // maxBatchConcurrency limits the number of concurrent embedding requests in EmbedBatch. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const maxBatchConcurrency = 10 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Embedder generates embeddings using Gemini or OpenAI. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| type Embedder struct { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| provider Provider | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -166,19 +170,30 @@ func (e *Embedder) embedOpenAI(ctx context.Context, text string) ([]float32, err | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // EmbedBatch generates embeddings for multiple texts. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // EmbedBatch generates embeddings for multiple texts in parallel. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // It uses up to maxBatchConcurrency concurrent requests to the embedding API. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| func (e *Embedder) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(texts) == 0 { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return nil, fmt.Errorf("texts cannot be empty") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| embeddings := make([][]float32, len(texts)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| g, ctx := errgroup.WithContext(ctx) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| g.SetLimit(maxBatchConcurrency) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i, text := range texts { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| embedding, err := e.Embed(ctx, text) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return nil, fmt.Errorf("failed to embed text %d: %w", i, err) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| embeddings[i] = embedding | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| g.Go(func() error { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| embedding, err := e.Embed(ctx, text) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return fmt.Errorf("failed to embed text %d: %w", i, err) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| embeddings[i] = embedding | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return nil | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
180
to
+193
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if err := g.Wait(); err != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return nil, err | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+181
to
+196
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: cat -n internal/integrations/ai/embedder.go | head -220 | tail -60Repository: similigh/simili-bot Length of output: 2061 🏁 Script executed: cat -n internal/integrations/ai/embedder.go | head -210 | tail -80Repository: similigh/simili-bot Length of output: 2724 🏁 Script executed: wc -l internal/integrations/ai/embedder.goRepository: similigh/simili-bot Length of output: 103 🏁 Script executed: # Check the Go version requirement and imports
head -20 internal/integrations/ai/embedder.goRepository: similigh/simili-bot Length of output: 441 🏁 Script executed: # Search for other uses of errgroup in the codebase to see patterns
rg "errgroup\." --type go -B 2 -A 5Repository: similigh/simili-bot Length of output: 567 🏁 Script executed: # Check if there are any tests for EmbedBatch to understand expected behavior
find . -name "*test.go" -type f | xargs grep -l "EmbedBatch" 2>/dev/nullRepository: similigh/simili-bot Length of output: 106 🏁 Script executed: cat -n internal/integrations/ai/embedder_test.goRepository: similigh/simili-bot Length of output: 8732 🏁 Script executed: rg "maxBatchConcurrency" internal/integrations/ai/embedder.go -B 2 -A 2Repository: similigh/simili-bot Length of output: 654 🏁 Script executed: # Check for Go version constraints in go.mod
head -5 go.modRepository: similigh/simili-bot Length of output: 122 🏁 Script executed: # Look at errgroup usage patterns to understand the expected behavior better
# Also check if there's documentation or comments about context handling
rg "Wait\(\)" internal/integrations/ai/embedder.go -B 5 -A 5Repository: similigh/simili-bot Length of output: 196 🏁 Script executed: # Look for any race detector runs or tests
find . -name "*test.go" -type f -exec grep -l "race" {} \;Repository: similigh/simili-bot Length of output: 216 Stop queueing new embeds once the errgroup context is canceled.
Suggested fix- g, ctx := errgroup.WithContext(ctx)
+ g, groupCtx := errgroup.WithContext(ctx)
g.SetLimit(maxBatchConcurrency)
+ stoppedEarly := false
for i, text := range texts {
+ if groupCtx.Err() != nil {
+ stoppedEarly = true
+ break
+ }
g.Go(func() error {
- embedding, err := e.Embed(ctx, text)
+ embedding, err := e.Embed(groupCtx, text)
if err != nil {
return fmt.Errorf("failed to embed text %d: %w", i, err)
}
embeddings[i] = embedding
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
+ if stoppedEarly && ctx.Err() != nil {
+ return nil, ctx.Err()
+ }📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return embeddings, nil | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,8 +6,15 @@ | |
| package ai | ||
|
|
||
| import ( | ||
| "context" | ||
| "encoding/json" | ||
| "fmt" | ||
| "net/http" | ||
| "net/http/httptest" | ||
| "strings" | ||
| "sync/atomic" | ||
| "testing" | ||
| "time" | ||
| ) | ||
|
|
||
| func TestInferEmbeddingDimensions(t *testing.T) { | ||
|
|
@@ -89,6 +96,142 @@ func TestIsLikelyGeminiEmbeddingModel(t *testing.T) { | |
| } | ||
| } | ||
|
|
||
| func TestEmbedBatch_EmptyInput(t *testing.T) { | ||
| srv, _ := statusServer([]int{200}, func(_ int) []byte { return embeddingOKBody() }) | ||
| defer srv.Close() | ||
|
|
||
| e := newTestEmbedder(srv.URL) | ||
| _, err := e.EmbedBatch(context.Background(), []string{}) | ||
| if err == nil { | ||
| t.Fatal("expected error for empty input, got nil") | ||
| } | ||
| } | ||
|
|
||
| func TestEmbedBatch_ResultsPreserveOrder(t *testing.T) { | ||
| // The handler parses the input text ("text-N") and encodes N+1 as the first | ||
| // embedding value. This lets us deterministically verify that results[i] | ||
| // corresponds to texts[i] regardless of goroutine scheduling order. | ||
| srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| var req struct { | ||
| Input string `json:"input"` | ||
| } | ||
| if err := json.NewDecoder(r.Body).Decode(&req); err != nil { | ||
| http.Error(w, "bad request", http.StatusBadRequest) | ||
| return | ||
| } | ||
| var idx int | ||
| if _, err := fmt.Sscanf(req.Input, "text-%d", &idx); err != nil { | ||
| http.Error(w, "invalid input format", http.StatusBadRequest) | ||
| return | ||
| } | ||
|
|
||
| type embItem struct { | ||
| Embedding []float64 `json:"embedding"` | ||
| } | ||
| type embResp struct { | ||
| Data []embItem `json:"data"` | ||
| } | ||
| b, _ := json.Marshal(embResp{Data: []embItem{{Embedding: []float64{float64(idx + 1), 0.0, 0.0}}}}) | ||
| w.Header().Set("Content-Type", "application/json") | ||
| w.WriteHeader(200) | ||
| _, _ = w.Write(b) | ||
| })) | ||
| defer srv.Close() | ||
|
|
||
| e := newTestEmbedder(srv.URL) | ||
| e.retryConfig = fastRetry | ||
|
|
||
| texts := make([]string, 20) | ||
| for i := range texts { | ||
| texts[i] = fmt.Sprintf("text-%d", i) | ||
| } | ||
|
|
||
| results, err := e.EmbedBatch(context.Background(), texts) | ||
| if err != nil { | ||
| t.Fatalf("unexpected error: %v", err) | ||
| } | ||
| if len(results) != len(texts) { | ||
| t.Fatalf("expected %d results, got %d", len(texts), len(results)) | ||
| } | ||
| for i, emb := range results { | ||
| if len(emb) == 0 { | ||
| t.Errorf("result[%d] is empty", i) | ||
| continue | ||
| } | ||
| // texts[i] = "text-i", so the handler encodes i+1 as emb[0]. | ||
| if got, want := emb[0], float32(i+1); got != want { | ||
| t.Errorf("result[%d][0] = %v, want %v (order not preserved)", i, got, want) | ||
| } | ||
| } | ||
| } | ||
|
Comment on lines
+110
to
+166
|
||
|
|
||
| func TestEmbedBatch_PropagatesError(t *testing.T) { | ||
| // First request succeeds, subsequent ones fail with a non-retryable 400. | ||
| var counter atomic.Int32 | ||
| srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| n := counter.Add(1) | ||
| w.Header().Set("Content-Type", "application/json") | ||
| if n == 1 { | ||
| w.WriteHeader(200) | ||
| _, _ = w.Write(embeddingOKBody()) | ||
| } else { | ||
| w.WriteHeader(400) | ||
| _, _ = w.Write([]byte(`{"error":{"message":"bad request"}}`)) | ||
| } | ||
| })) | ||
| defer srv.Close() | ||
|
|
||
| e := newTestEmbedder(srv.URL) | ||
| e.retryConfig = fastRetry | ||
|
|
||
| _, err := e.EmbedBatch(context.Background(), []string{"a", "b", "c"}) | ||
| if err == nil { | ||
| t.Fatal("expected error from failed embedding, got nil") | ||
| } | ||
| } | ||
|
|
||
| func TestEmbedBatch_ConcurrencyLimit(t *testing.T) { | ||
| // Verify that EmbedBatch honours maxBatchConcurrency: concurrent in-flight | ||
| // requests should never exceed the limit. | ||
| var ( | ||
| inFlight atomic.Int32 | ||
| maxObserved atomic.Int32 | ||
| ) | ||
| srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| cur := inFlight.Add(1) | ||
| defer inFlight.Add(-1) | ||
| // Update max observed — simple CAS loop. | ||
| for { | ||
| old := maxObserved.Load() | ||
| if cur <= old || maxObserved.CompareAndSwap(old, cur) { | ||
| break | ||
| } | ||
| } | ||
| // Hold the request briefly so multiple goroutines overlap in-flight, | ||
| // making the concurrency cap observable. | ||
| time.Sleep(5 * time.Millisecond) | ||
| w.Header().Set("Content-Type", "application/json") | ||
| w.WriteHeader(200) | ||
| _, _ = w.Write(embeddingOKBody()) | ||
| })) | ||
| defer srv.Close() | ||
|
|
||
| e := newTestEmbedder(srv.URL) | ||
| e.retryConfig = fastRetry | ||
|
|
||
| texts := make([]string, maxBatchConcurrency*3) | ||
| for i := range texts { | ||
| texts[i] = fmt.Sprintf("text-%d", i) | ||
| } | ||
|
|
||
| if _, err := e.EmbedBatch(context.Background(), texts); err != nil { | ||
| t.Fatalf("unexpected error: %v", err) | ||
| } | ||
| if got := maxObserved.Load(); got > int32(maxBatchConcurrency) { | ||
| t.Errorf("max concurrent requests = %d, want <= %d", got, maxBatchConcurrency) | ||
| } | ||
| } | ||
|
Comment on lines
+193
to
+233
|
||
|
|
||
| func TestNewEmbedderRejectsLegacyGeminiModels(t *testing.T) { | ||
| // Fake a Gemini API key so provider resolution picks Gemini. | ||
| t.Setenv("GEMINI_API_KEY", "fake-key-for-test") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Semaphore acquisition (
sem <- struct{}{}) does not respect context cancellation. Ifctxis cancelled while this send is blocked (because the semaphore is full), the goroutine will remain stuck waiting for capacity, delaying cancellation and potentially causingEmbedBatchto hang until other calls return. Use aselectonctx.Done()when acquiring the semaphore (or rely onerrgroup.SetLimit, which avoids this pattern).