Merged
2 changes: 1 addition & 1 deletion go.mod
@@ -9,6 +9,7 @@ require (
github.com/qdrant/go-client v1.15.2
github.com/spf13/cobra v1.10.2
golang.org/x/oauth2 v0.27.0
golang.org/x/sync v0.18.0
google.golang.org/api v0.186.0
google.golang.org/grpc v1.71.0-dev
gopkg.in/yaml.v3 v3.0.1
@@ -41,7 +42,6 @@ require (
go.opentelemetry.io/otel/trace v1.38.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/time v0.14.0 // indirect
27 changes: 21 additions & 6 deletions internal/integrations/ai/embedder.go
@@ -15,9 +15,13 @@ import (
"time"

"github.com/google/generative-ai-go/genai"
"golang.org/x/sync/errgroup"
"google.golang.org/api/option"
)

// maxBatchConcurrency limits the number of concurrent embedding requests in EmbedBatch.
const maxBatchConcurrency = 10

// Embedder generates embeddings using Gemini or OpenAI.
type Embedder struct {
provider Provider
@@ -166,19 +170,30 @@ func (e *Embedder) embedOpenAI(ctx context.Context, text string) ([]float32, err
})
}

// EmbedBatch generates embeddings for multiple texts.
// EmbedBatch generates embeddings for multiple texts in parallel.
// It uses up to maxBatchConcurrency concurrent requests to the embedding API.
func (e *Embedder) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
if len(texts) == 0 {
return nil, fmt.Errorf("texts cannot be empty")
}

embeddings := make([][]float32, len(texts))
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(maxBatchConcurrency)

for i, text := range texts {
embedding, err := e.Embed(ctx, text)
if err != nil {
return nil, fmt.Errorf("failed to embed text %d: %w", i, err)
}
embeddings[i] = embedding
g.Go(func() error {
embedding, err := e.Embed(ctx, text)
Comment on lines +185 to +186

Copilot AI Mar 11, 2026

Semaphore acquisition (sem <- struct{}{}) does not respect context cancellation. If ctx is cancelled while this send is blocked (because the semaphore is full), the goroutine will remain stuck waiting for capacity, delaying cancellation and potentially causing EmbedBatch to hang until other calls return. Use a select on ctx.Done() when acquiring the semaphore (or rely on errgroup.SetLimit, which avoids this pattern).
if err != nil {
return fmt.Errorf("failed to embed text %d: %w", i, err)
}
embeddings[i] = embedding
return nil
})
}
Comment on lines 180 to +193

Copilot AI Mar 11, 2026

The semaphore is acquired inside the goroutine, so EmbedBatch still spawns one goroutine per input text. For large batches this can create thousands of blocked goroutines and unnecessary memory/GC pressure; the concurrency cap only limits in-flight API calls, not goroutine count. Consider using errgroup.Group.SetLimit(maxBatchConcurrency) (or acquiring the semaphore before starting each goroutine / using a worker pool) so goroutine creation is also bounded.

if err := g.Wait(); err != nil {
return nil, err
Comment on lines +181 to +196
⚠️ Potential issue | 🟠 Major



Stop queueing new embeds once the errgroup context is canceled.

g.Go() blocks when SetLimit is full and does not watch the derived context. After the first failure or an upstream cancellation, this loop continues walking the rest of texts and enqueuing canceled work, which makes large failing batches return much more slowly than necessary.

Suggested fix
-	g, ctx := errgroup.WithContext(ctx)
+	g, groupCtx := errgroup.WithContext(ctx)
 	g.SetLimit(maxBatchConcurrency)
+	stoppedEarly := false
 
 	for i, text := range texts {
+		if groupCtx.Err() != nil {
+			stoppedEarly = true
+			break
+		}
 		g.Go(func() error {
-			embedding, err := e.Embed(ctx, text)
+			embedding, err := e.Embed(groupCtx, text)
 			if err != nil {
 				return fmt.Errorf("failed to embed text %d: %w", i, err)
 			}
 			embeddings[i] = embedding
 			return nil
 		})
 	}
 
 	if err := g.Wait(); err != nil {
 		return nil, err
 	}
+	if stoppedEarly && ctx.Err() != nil {
+		return nil, ctx.Err()
+	}

}

return embeddings, nil
143 changes: 143 additions & 0 deletions internal/integrations/ai/embedder_test.go
@@ -6,8 +6,15 @@
package ai

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
)

func TestInferEmbeddingDimensions(t *testing.T) {
@@ -89,6 +96,142 @@ func TestIsLikelyGeminiEmbeddingModel(t *testing.T) {
}
}

func TestEmbedBatch_EmptyInput(t *testing.T) {
srv, _ := statusServer([]int{200}, func(_ int) []byte { return embeddingOKBody() })
defer srv.Close()

e := newTestEmbedder(srv.URL)
_, err := e.EmbedBatch(context.Background(), []string{})
if err == nil {
t.Fatal("expected error for empty input, got nil")
}
}

func TestEmbedBatch_ResultsPreserveOrder(t *testing.T) {
// The handler parses the input text ("text-N") and encodes N+1 as the first
// embedding value. This lets us deterministically verify that results[i]
// corresponds to texts[i] regardless of goroutine scheduling order.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req struct {
Input string `json:"input"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
var idx int
if _, err := fmt.Sscanf(req.Input, "text-%d", &idx); err != nil {
http.Error(w, "invalid input format", http.StatusBadRequest)
return
}

type embItem struct {
Embedding []float64 `json:"embedding"`
}
type embResp struct {
Data []embItem `json:"data"`
}
b, _ := json.Marshal(embResp{Data: []embItem{{Embedding: []float64{float64(idx + 1), 0.0, 0.0}}}})
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
_, _ = w.Write(b)
}))
defer srv.Close()

e := newTestEmbedder(srv.URL)
e.retryConfig = fastRetry

texts := make([]string, 20)
for i := range texts {
texts[i] = fmt.Sprintf("text-%d", i)
}

results, err := e.EmbedBatch(context.Background(), texts)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != len(texts) {
t.Fatalf("expected %d results, got %d", len(texts), len(results))
}
for i, emb := range results {
if len(emb) == 0 {
t.Errorf("result[%d] is empty", i)
continue
}
// texts[i] = "text-i", so the handler encodes i+1 as emb[0].
if got, want := emb[0], float32(i+1); got != want {
t.Errorf("result[%d][0] = %v, want %v (order not preserved)", i, got, want)
}
}
}
Comment on lines +110 to +166

Copilot AI Mar 11, 2026

TestEmbedBatch_ResultsPreserveOrder doesn't actually assert order preservation. The handler returns embeddings based on request arrival order (a shared counter), but the test only checks that each result is non-empty, so it would pass even if EmbedBatch returned embeddings in the wrong order. To verify ordering, make the response embedding depend deterministically on the request's input text (e.g. parse the JSON body and encode the input index into the embedding) and assert results[i] matches texts[i].

func TestEmbedBatch_PropagatesError(t *testing.T) {
// First request succeeds, subsequent ones fail with a non-retryable 400.
var counter atomic.Int32
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := counter.Add(1)
w.Header().Set("Content-Type", "application/json")
if n == 1 {
w.WriteHeader(200)
_, _ = w.Write(embeddingOKBody())
} else {
w.WriteHeader(400)
_, _ = w.Write([]byte(`{"error":{"message":"bad request"}}`))
}
}))
defer srv.Close()

e := newTestEmbedder(srv.URL)
e.retryConfig = fastRetry

_, err := e.EmbedBatch(context.Background(), []string{"a", "b", "c"})
if err == nil {
t.Fatal("expected error from failed embedding, got nil")
}
}

func TestEmbedBatch_ConcurrencyLimit(t *testing.T) {
// Verify that EmbedBatch honours maxBatchConcurrency: concurrent in-flight
// requests should never exceed the limit.
var (
inFlight atomic.Int32
maxObserved atomic.Int32
)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cur := inFlight.Add(1)
defer inFlight.Add(-1)
// Update max observed — simple CAS loop.
for {
old := maxObserved.Load()
if cur <= old || maxObserved.CompareAndSwap(old, cur) {
break
}
}
// Hold the request briefly so multiple goroutines overlap in-flight,
// making the concurrency cap observable.
time.Sleep(5 * time.Millisecond)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
_, _ = w.Write(embeddingOKBody())
}))
defer srv.Close()

e := newTestEmbedder(srv.URL)
e.retryConfig = fastRetry

texts := make([]string, maxBatchConcurrency*3)
for i := range texts {
texts[i] = fmt.Sprintf("text-%d", i)
}

if _, err := e.EmbedBatch(context.Background(), texts); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got := maxObserved.Load(); got > int32(maxBatchConcurrency) {
t.Errorf("max concurrent requests = %d, want <= %d", got, maxBatchConcurrency)
}
}
Comment on lines +193 to +233

Copilot AI Mar 11, 2026

TestEmbedBatch_ConcurrencyLimit may be ineffective/flaky because the handler responds immediately, so requests may not overlap enough to observe concurrency >1 (and the test only asserts maxObserved <= maxBatchConcurrency). Consider adding a small synchronized delay/barrier in the handler (e.g., block until N requests are in-flight or sleep for a short duration) so the test reliably exercises the concurrency limit and would fail if the limit were removed.

func TestNewEmbedderRejectsLegacyGeminiModels(t *testing.T) {
// Fake a Gemini API key so provider resolution picks Gemini.
t.Setenv("GEMINI_API_KEY", "fake-key-for-test")