From 81b704a85511626b1c5673d1d776ea65beaf3044 Mon Sep 17 00:00:00 2001 From: MrBrain <176294248+GaosCode@users.noreply.github.com> Date: Sun, 12 Apr 2026 17:27:08 +0800 Subject: [PATCH] feat(embed): add embedding job drain --- internal/cli/admin_commands.go | 55 ++++ internal/cli/cli.go | 4 + internal/cli/cli_test.go | 65 ++++ internal/cli/output.go | 16 + internal/embed/ollama.go | 2 +- internal/embed/provider.go | 14 + internal/embed/provider_test.go | 21 ++ internal/store/embeddings.go | 487 +++++++++++++++++++++++++++++ internal/store/store.go | 99 +++++- internal/store/store_test.go | 218 +++++++++++++ internal/store/store_write_test.go | 310 ++++++++++++++++++ internal/store/write.go | 34 +- 12 files changed, 1321 insertions(+), 4 deletions(-) create mode 100644 internal/store/embeddings.go diff --git a/internal/cli/admin_commands.go b/internal/cli/admin_commands.go index 2841df8..beadb91 100644 --- a/internal/cli/admin_commands.go +++ b/internal/cli/admin_commands.go @@ -164,6 +164,61 @@ func (r *runtime) runStatus(args []string) error { return r.print(status) } +func (r *runtime) runEmbed(args []string) error { + fs := flag.NewFlagSet("embed", flag.ContinueOnError) + fs.SetOutput(io.Discard) + limit := fs.Int("limit", store.DefaultEmbedLimit(), "") + batchSize := fs.Int("batch-size", r.cfg.Search.Embeddings.BatchSize, "") + rebuild := fs.Bool("rebuild", false, "") + if err := fs.Parse(args); err != nil { + return usageErr(err) + } + if fs.NArg() != 0 { + return usageErr(fmt.Errorf("embed takes no positional arguments")) + } + if *limit <= 0 { + return usageErr(fmt.Errorf("--limit must be positive")) + } + if *batchSize <= 0 { + return usageErr(fmt.Errorf("--batch-size must be positive")) + } + if !r.cfg.Search.Embeddings.Enabled { + return usageErr(fmt.Errorf("embeddings are disabled in config")) + } + providerFactory := r.newEmbed + if providerFactory == nil { + providerFactory = func(cfg config.EmbeddingsConfig) (embed.Provider, error) { + return 
embed.NewProvider(cfg) + } + } + provider, err := providerFactory(r.cfg.Search.Embeddings) + if err != nil { + return configErr(err) + } + opts := store.EmbeddingDrainOptions{ + Provider: r.cfg.Search.Embeddings.Provider, + Model: r.cfg.Search.Embeddings.Model, + InputVersion: store.EmbeddingInputVersion, + Limit: *limit, + BatchSize: *batchSize, + MaxInputChars: r.cfg.Search.Embeddings.MaxInputChars, + Now: r.now, + } + requeued := 0 + if *rebuild { + requeued, err = r.store.RequeueAllEmbeddingJobs(r.ctx, opts) + if err != nil { + return err + } + } + stats, err := r.store.DrainEmbeddingJobs(r.ctx, provider, opts) + if err != nil { + return err + } + stats.Requeued = requeued + return r.print(stats) +} + func (r *runtime) runDoctor(args []string) error { if len(args) != 0 { return usageErr(fmt.Errorf("doctor takes no arguments")) diff --git a/internal/cli/cli.go b/internal/cli/cli.go index a49c711..5ed8b0c 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -12,6 +12,7 @@ import ( "github.com/bwmarrin/discordgo" "github.com/steipete/discrawl/internal/config" "github.com/steipete/discrawl/internal/discord" + "github.com/steipete/discrawl/internal/embed" "github.com/steipete/discrawl/internal/store" "github.com/steipete/discrawl/internal/syncer" ) @@ -94,6 +95,7 @@ type runtime struct { openStore func(context.Context, string) (*store.Store, error) newDiscord func(config.Config) (discordClient, error) newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService + newEmbed func(config.EmbeddingsConfig) (embed.Provider, error) now func() time.Time } @@ -128,6 +130,8 @@ func (r *runtime) dispatch(rest []string) error { return r.withServices(hasBoolFlag(rest[1:], "--sync"), func() error { return r.runMessages(rest[1:]) }) case "mentions": return r.withServices(false, func() error { return r.runMentions(rest[1:]) }) + case "embed": + return r.withServices(false, func() error { return r.runEmbed(rest[1:]) }) case "sql": return r.withServices(false, 
func() error { return r.runSQL(rest[1:]) }) case "members": diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index b305994..871fce9 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -3,6 +3,7 @@ package cli import ( "bytes" "context" + "encoding/json" "log/slog" "net/http" "net/http/httptest" @@ -133,6 +134,70 @@ func TestStatusSearchSQLAndListings(t *testing.T) { } } +func TestEmbedCommandDrainsBoundedBacklog(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.toml") + dbPath := filepath.Join(dir, "discrawl.db") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/embeddings", r.URL.Path) + var req struct { + Input []string `json:"input"` + } + require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + require.Len(t, req.Input, 1) + _, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1,2]}]}`)) + })) + defer server.Close() + + cfg := config.Default() + cfg.DBPath = dbPath + cfg.Search.Embeddings.Enabled = true + cfg.Search.Embeddings.Provider = "openai_compatible" + cfg.Search.Embeddings.Model = "local-model" + cfg.Search.Embeddings.BaseURL = server.URL + cfg.Search.Embeddings.APIKeyEnv = "" + require.NoError(t, config.Write(cfgPath, cfg)) + + s, err := store.Open(ctx, dbPath) + require.NoError(t, err) + for _, id := range []string{"m1", "m2"} { + require.NoError(t, s.UpsertMessageWithOptions(ctx, store.MessageRecord{ + ID: id, + GuildID: "g1", + ChannelID: "c1", + MessageType: 0, + CreatedAt: time.Now().UTC().Format(time.RFC3339Nano), + Content: "hello", + NormalizedContent: "hello", + RawJSON: `{}`, + }, store.WriteOptions{EnqueueEmbedding: true})) + } + require.NoError(t, s.Close()) + + var out bytes.Buffer + require.NoError(t, Run(ctx, []string{"--config", cfgPath, "embed", "--limit", "1"}, &out, &bytes.Buffer{})) + require.Contains(t, out.String(), "processed=1") + require.Contains(t, out.String(), 
"succeeded=1") + require.Contains(t, out.String(), "remaining_backlog=1") + require.Contains(t, out.String(), "provider=openai_compatible") + + s, err = store.Open(ctx, dbPath) + require.NoError(t, err) + _, rows, err := s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings") + require.NoError(t, err) + require.Equal(t, "1", rows[0][0]) + require.NoError(t, s.Close()) + + out.Reset() + require.NoError(t, Run(ctx, []string{"--config", cfgPath, "embed", "--rebuild", "--limit", "1"}, &out, &bytes.Buffer{})) + require.Contains(t, out.String(), "processed=1") + require.Contains(t, out.String(), "succeeded=1") + require.Contains(t, out.String(), "remaining_backlog=1") + require.Contains(t, out.String(), "requeued=2") +} + type fakeDiscordClient struct { guilds []*discordgo.UserGuild self *discordgo.User diff --git a/internal/cli/output.go b/internal/cli/output.go index 8c9c7ec..f10af3e 100644 --- a/internal/cli/output.go +++ b/internal/cli/output.go @@ -80,6 +80,7 @@ Commands: search messages mentions + embed sql members channels @@ -107,6 +108,21 @@ func printHuman(w io.Writer, value any) error { v.DBPath, v.GuildCount, v.ChannelCount, v.ThreadCount, v.MessageCount, v.MemberCount, v.EmbeddingBacklog, formatTime(v.LastSyncAt), formatTime(v.LastTailEventAt)) return err + case store.EmbeddingDrainStats: + _, err := fmt.Fprintf(w, "processed=%d\nsucceeded=%d\nfailed=%d\nskipped=%d\nremaining_backlog=%d\nprovider=%s\nmodel=%s\ninput_version=%s\n", + v.Processed, v.Succeeded, v.Failed, v.Skipped, v.RemainingBacklog, v.Provider, v.Model, v.InputVersion) + if err != nil { + return err + } + if v.Requeued > 0 { + if _, err := fmt.Fprintf(w, "requeued=%d\n", v.Requeued); err != nil { + return err + } + } + if v.RateLimited { + _, err = fmt.Fprintln(w, "rate_limited=true") + } + return err case []store.SearchResult: for _, row := range v { if _, err := fmt.Fprintf(w, "[%s/%s] %s %s\n%s\n\n", row.GuildID, row.ChannelName, row.AuthorName, formatTime(row.CreatedAt), 
row.Content); err != nil { diff --git a/internal/embed/ollama.go b/internal/embed/ollama.go index 2a33930..b5daa15 100644 --- a/internal/embed/ollama.go +++ b/internal/embed/ollama.go @@ -82,7 +82,7 @@ func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string, defer func() { _ = resp.Body.Close() }() if resp.StatusCode < 200 || resp.StatusCode >= 300 { msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return fmt.Errorf("embedding request failed with HTTP %d: %s", resp.StatusCode, string(msg)) + return &HTTPError{StatusCode: resp.StatusCode, Body: string(msg)} } if err := json.NewDecoder(resp.Body).Decode(target); err != nil { return fmt.Errorf("decode embedding response: %w", err) diff --git a/internal/embed/provider.go b/internal/embed/provider.go index 8f3f964..e65edb5 100644 --- a/internal/embed/provider.go +++ b/internal/embed/provider.go @@ -41,6 +41,20 @@ type EmbeddingBatch struct { Vectors [][]float32 } +type HTTPError struct { + StatusCode int + Body string +} + +func (e *HTTPError) Error() string { + return fmt.Sprintf("embedding request failed with HTTP %d: %s", e.StatusCode, e.Body) +} + +func IsRateLimitError(err error) bool { + var httpErr *HTTPError + return errors.As(err, &httpErr) && httpErr.StatusCode == http.StatusTooManyRequests +} + type CheckResult struct { Provider string Model string diff --git a/internal/embed/provider_test.go b/internal/embed/provider_test.go index b37c9a0..8d604b6 100644 --- a/internal/embed/provider_test.go +++ b/internal/embed/provider_test.go @@ -172,6 +172,27 @@ func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) { require.False(t, result.Probed) } +func TestProviderExposesRateLimitErrors(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "rate limited", http.StatusTooManyRequests) + })) + defer server.Close() + + provider, err := NewProvider(config.EmbeddingsConfig{ + Provider: 
ProviderOpenAICompatible, + Model: "local-model", + BaseURL: server.URL, + RequestTimeout: "5s", + }) + require.NoError(t, err) + + _, err = provider.Embed(context.Background(), []string{"one"}) + require.ErrorContains(t, err, "HTTP 429") + require.True(t, IsRateLimitError(err)) +} + func TestProviderRejectsInvalidResponses(t *testing.T) { t.Parallel() diff --git a/internal/store/embeddings.go b/internal/store/embeddings.go new file mode 100644 index 0000000..bff99ce --- /dev/null +++ b/internal/store/embeddings.go @@ -0,0 +1,487 @@ +package store + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "strings" + "time" + + "github.com/steipete/discrawl/internal/embed" +) + +const ( + EmbeddingInputVersion = "message_normalized_v1" + defaultEmbedLimit = 1000 + maxEmbeddingAttempts = 3 + maxStoredErrorChars = 500 +) + +type EmbeddingDrainOptions struct { + Provider string + Model string + InputVersion string + Limit int + BatchSize int + MaxInputChars int + Now func() time.Time +} + +type EmbeddingDrainStats struct { + Processed int `json:"processed"` + Succeeded int `json:"succeeded"` + Failed int `json:"failed"` + Skipped int `json:"skipped"` + Requeued int `json:"requeued,omitempty"` + RemainingBacklog int `json:"remaining_backlog"` + Provider string `json:"provider"` + Model string `json:"model"` + InputVersion string `json:"input_version"` + RateLimited bool `json:"rate_limited,omitempty"` +} + +type embeddingJob struct { + MessageID string + NormalizedContent string + Attempts int + Provider string + Model string + InputVersion string +} + +func DefaultEmbedLimit() int { + return defaultEmbedLimit +} + +func (s *Store) DrainEmbeddingJobs(ctx context.Context, provider embed.Provider, opts EmbeddingDrainOptions) (EmbeddingDrainStats, error) { + opts = normalizeEmbeddingDrainOptions(opts) + stats := EmbeddingDrainStats{ + Provider: opts.Provider, + Model: opts.Model, + InputVersion: opts.InputVersion, + } + if provider == nil { + return stats, 
errors.New("embedding provider is nil") + } + jobs, err := s.pendingEmbeddingJobs(ctx, opts.Limit) + if err != nil { + return stats, err + } + var batch []embeddingJob + flush := func() error { + if len(batch) == 0 { + return nil + } + rateLimited, err := s.processEmbeddingBatch(ctx, provider, opts, batch, &stats) + batch = batch[:0] + if err != nil { + return err + } + if rateLimited { + stats.RateLimited = true + } + return nil + } + for _, job := range jobs { + if !sameEmbeddingIdentity(job, opts) { + resetAttempts := !emptyEmbeddingIdentity(job) + if err := s.resetEmbeddingJobIdentity(ctx, job.MessageID, opts, resetAttempts); err != nil { + return stats, err + } + job.Provider = opts.Provider + job.Model = opts.Model + job.InputVersion = opts.InputVersion + if resetAttempts { + job.Attempts = 0 + } + } + if strings.TrimSpace(job.NormalizedContent) == "" { + if err := s.markEmbeddingJobsDone(ctx, opts, []embeddingJob{job}); err != nil { + return stats, err + } + stats.Processed++ + stats.Skipped++ + continue + } + batch = append(batch, job) + if len(batch) >= opts.BatchSize { + if err := flush(); err != nil { + return stats, err + } + if stats.RateLimited { + break + } + } + } + if !stats.RateLimited { + if err := flush(); err != nil { + return stats, err + } + } + stats.RemainingBacklog, err = s.EmbeddingBacklog(ctx) + if err != nil { + return stats, err + } + return stats, nil +} + +func normalizeEmbeddingDrainOptions(opts EmbeddingDrainOptions) EmbeddingDrainOptions { + opts.Provider = strings.ToLower(strings.TrimSpace(opts.Provider)) + opts.Model = strings.TrimSpace(opts.Model) + opts.InputVersion = strings.TrimSpace(opts.InputVersion) + if opts.InputVersion == "" { + opts.InputVersion = EmbeddingInputVersion + } + if opts.Limit <= 0 { + opts.Limit = defaultEmbedLimit + } + if opts.BatchSize <= 0 { + opts.BatchSize = embed.DefaultBatchSize + } + if opts.BatchSize > opts.Limit { + opts.BatchSize = opts.Limit + } + if opts.MaxInputChars <= 0 { + 
opts.MaxInputChars = embed.DefaultMaxInputChars + } + if opts.Now == nil { + opts.Now = func() time.Time { return time.Now().UTC() } + } + return opts +} + +func sameEmbeddingIdentity(job embeddingJob, opts EmbeddingDrainOptions) bool { + return job.Provider == opts.Provider && job.Model == opts.Model && job.InputVersion == opts.InputVersion +} + +func emptyEmbeddingIdentity(job embeddingJob) bool { + return job.Provider == "" && job.Model == "" && job.InputVersion == "" +} + +func (s *Store) pendingEmbeddingJobs(ctx context.Context, limit int) ([]embeddingJob, error) { + rows, err := s.db.QueryContext(ctx, ` + select + j.message_id, + m.normalized_content, + j.attempts, + j.provider, + j.model, + j.input_version + from embedding_jobs j + join messages m on m.id = j.message_id + where j.state = 'pending' + order by j.updated_at, j.message_id + limit ? + `, limit) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var jobs []embeddingJob + for rows.Next() { + var job embeddingJob + if err := rows.Scan(&job.MessageID, &job.NormalizedContent, &job.Attempts, &job.Provider, &job.Model, &job.InputVersion); err != nil { + return nil, err + } + jobs = append(jobs, job) + } + return jobs, rows.Err() +} + +func (s *Store) resetEmbeddingJobIdentity(ctx context.Context, messageID string, opts EmbeddingDrainOptions, resetAttempts bool) error { + if resetAttempts { + _, err := s.db.ExecContext(ctx, ` + update embedding_jobs + set provider = ?, + model = ?, + input_version = ?, + attempts = 0, + last_error = '', + locked_at = null, + updated_at = ? + where message_id = ? + `, opts.Provider, opts.Model, opts.InputVersion, opts.Now().Format(timeLayout), messageID) + return err + } + _, err := s.db.ExecContext(ctx, ` + update embedding_jobs + set provider = ?, + model = ?, + input_version = ?, + last_error = '', + locked_at = null, + updated_at = ? + where message_id = ? 
+ `, opts.Provider, opts.Model, opts.InputVersion, opts.Now().Format(timeLayout), messageID) + return err +} + +func (s *Store) processEmbeddingBatch(ctx context.Context, provider embed.Provider, opts EmbeddingDrainOptions, jobs []embeddingJob, stats *EmbeddingDrainStats) (bool, error) { + if err := s.lockEmbeddingJobs(ctx, jobs, opts.Now().Format(timeLayout)); err != nil { + return false, err + } + inputs := make([]string, 0, len(jobs)) + for _, job := range jobs { + inputs = append(inputs, capRunes(job.NormalizedContent, opts.MaxInputChars)) + } + batch, err := provider.Embed(ctx, inputs) + if err != nil { + if markErr := s.markEmbeddingJobsFailed(ctx, opts, jobs, err); markErr != nil { + return false, markErr + } + stats.Processed += len(jobs) + stats.Failed += len(jobs) + return embed.IsRateLimitError(err), nil + } + dimensions, err := validateEmbeddingBatch(batch, len(jobs)) + if err != nil { + if markErr := s.markEmbeddingJobsFailed(ctx, opts, jobs, err); markErr != nil { + return false, markErr + } + stats.Processed += len(jobs) + stats.Failed += len(jobs) + return false, nil + } + if err := s.storeEmbeddingBatch(ctx, opts, jobs, batch.Vectors, dimensions); err != nil { + return false, err + } + stats.Processed += len(jobs) + stats.Succeeded += len(jobs) + return false, nil +} + +func (s *Store) lockEmbeddingJobs(ctx context.Context, jobs []embeddingJob, lockedAt string) error { + if len(jobs) == 0 { + return nil + } + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer rollback(tx) + for _, job := range jobs { + if _, err := tx.ExecContext(ctx, ` + update embedding_jobs + set locked_at = ?, updated_at = ? + where message_id = ? 
+ `, lockedAt, lockedAt, job.MessageID); err != nil { + return err + } + } + return tx.Commit() +} + +func validateEmbeddingBatch(batch embed.EmbeddingBatch, expected int) (int, error) { + if len(batch.Vectors) != expected { + return 0, fmt.Errorf("embedding provider returned %d vectors for %d inputs", len(batch.Vectors), expected) + } + dimensions := batch.Dimensions + for _, vector := range batch.Vectors { + if len(vector) == 0 { + return 0, errors.New("embedding provider returned an empty vector") + } + if dimensions == 0 { + dimensions = len(vector) + continue + } + if len(vector) != dimensions { + return 0, fmt.Errorf("embedding provider dimensions mismatch: got %d want %d", len(vector), dimensions) + } + } + return dimensions, nil +} + +func (s *Store) storeEmbeddingBatch(ctx context.Context, opts EmbeddingDrainOptions, jobs []embeddingJob, vectors [][]float32, dimensions int) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer rollback(tx) + embeddedAt := opts.Now().Format(timeLayout) + for i, job := range jobs { + blob, err := EncodeEmbeddingVector(vectors[i]) + if err != nil { + return err + } + if _, err := tx.ExecContext(ctx, ` + insert into message_embeddings( + message_id, provider, model, input_version, dimensions, embedding_blob, embedded_at + ) values(?, ?, ?, ?, ?, ?, ?) + on conflict(message_id, provider, model, input_version) do update set + dimensions = excluded.dimensions, + embedding_blob = excluded.embedding_blob, + embedded_at = excluded.embedded_at + `, job.MessageID, opts.Provider, opts.Model, opts.InputVersion, dimensions, blob, embeddedAt); err != nil { + return err + } + if _, err := tx.ExecContext(ctx, ` + update embedding_jobs + set state = 'done', + attempts = 0, + provider = ?, + model = ?, + input_version = ?, + last_error = '', + locked_at = null, + updated_at = ? + where message_id = ? 
+ `, opts.Provider, opts.Model, opts.InputVersion, embeddedAt, job.MessageID); err != nil { + return err + } + } + return tx.Commit() +} + +func (s *Store) markEmbeddingJobsDone(ctx context.Context, opts EmbeddingDrainOptions, jobs []embeddingJob) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer rollback(tx) + now := opts.Now().Format(timeLayout) + for _, job := range jobs { + if _, err := tx.ExecContext(ctx, ` + update embedding_jobs + set state = 'done', + provider = ?, + model = ?, + input_version = ?, + last_error = '', + locked_at = null, + updated_at = ? + where message_id = ? + `, opts.Provider, opts.Model, opts.InputVersion, now, job.MessageID); err != nil { + return err + } + } + return tx.Commit() +} + +func (s *Store) markEmbeddingJobsFailed(ctx context.Context, opts EmbeddingDrainOptions, jobs []embeddingJob, cause error) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer rollback(tx) + now := opts.Now().Format(timeLayout) + lastError := trimStoredError(cause) + for _, job := range jobs { + attempts := job.Attempts + 1 + state := "pending" + if attempts >= maxEmbeddingAttempts { + state = "failed" + } + if _, err := tx.ExecContext(ctx, ` + update embedding_jobs + set state = ?, + attempts = ?, + provider = ?, + model = ?, + input_version = ?, + last_error = ?, + locked_at = null, + updated_at = ? + where message_id = ? 
+ `, state, attempts, opts.Provider, opts.Model, opts.InputVersion, lastError, now, job.MessageID); err != nil { + return err + } + } + return tx.Commit() +} + +func trimStoredError(err error) string { + if err == nil { + return "" + } + msg := strings.TrimSpace(err.Error()) + runes := []rune(msg) + if len(runes) > maxStoredErrorChars { + msg = string(runes[:maxStoredErrorChars]) + } + return msg +} + +func capRunes(value string, maxChars int) string { + if maxChars <= 0 { + return value + } + runes := []rune(value) + if len(runes) <= maxChars { + return value + } + return string(runes[:maxChars]) +} + +func EncodeEmbeddingVector(vector []float32) ([]byte, error) { + buf := bytes.NewBuffer(make([]byte, 0, len(vector)*4)) + for _, value := range vector { + if err := binary.Write(buf, binary.LittleEndian, value); err != nil { + return nil, fmt.Errorf("encode embedding vector: %w", err) + } + } + return buf.Bytes(), nil +} + +func DecodeEmbeddingVector(blob []byte) ([]float32, error) { + if len(blob)%4 != 0 { + return nil, fmt.Errorf("embedding blob length %d is not a float32 multiple", len(blob)) + } + out := make([]float32, len(blob)/4) + reader := bytes.NewReader(blob) + for i := range out { + if err := binary.Read(reader, binary.LittleEndian, &out[i]); err != nil { + return nil, fmt.Errorf("decode embedding vector: %w", err) + } + } + return out, nil +} + +func (s *Store) EmbeddingBacklog(ctx context.Context) (int, error) { + var count int + err := s.db.QueryRowContext(ctx, `select count(*) from embedding_jobs where state = 'pending'`).Scan(&count) + return count, err +} + +func (s *Store) RequeueAllEmbeddingJobs(ctx context.Context, opts EmbeddingDrainOptions) (int, error) { + opts = normalizeEmbeddingDrainOptions(opts) + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return 0, err + } + defer rollback(tx) + now := opts.Now().Format(timeLayout) + if _, err := tx.ExecContext(ctx, ` + insert or ignore into embedding_jobs( + message_id, state, attempts, 
provider, model, input_version, last_error, locked_at, updated_at + ) + select id, 'pending', 0, ?, ?, ?, '', null, ? + from messages + `, opts.Provider, opts.Model, opts.InputVersion, now); err != nil { + return 0, err + } + result, err := tx.ExecContext(ctx, ` + update embedding_jobs + set state = 'pending', + attempts = 0, + provider = ?, + model = ?, + input_version = ?, + last_error = '', + locked_at = null, + updated_at = ? + where message_id in (select id from messages) + `, opts.Provider, opts.Model, opts.InputVersion, now) + if err != nil { + return 0, err + } + if err := tx.Commit(); err != nil { + return 0, err + } + affected, err := result.RowsAffected() + if err != nil { + return 0, err + } + return int(affected), nil +} diff --git a/internal/store/store.go b/internal/store/store.go index 6d2ec62..d8961c7 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -18,7 +18,7 @@ const ( timeLayout = time.RFC3339Nano messageFTSVersion = "2" memberFTSVersion = "1" - storeSchemaVersion = 1 + storeSchemaVersion = 2 ) type Store struct { @@ -199,9 +199,21 @@ func (s *Store) migrate(ctx context.Context) error { if err := s.applyBaselineSchema(ctx); err != nil { return err } + if err := s.applySchemaV2(ctx); err != nil { + return err + } if err := s.setSchemaVersion(ctx, storeSchemaVersion); err != nil { return err } + currentVersion = storeSchemaVersion + } + if currentVersion == 1 { + if err := s.applySchemaV2(ctx); err != nil { + return err + } + if err := s.setSchemaVersion(ctx, 2); err != nil { + return err + } } if version, err := s.schemaVersion(ctx); err != nil { return err @@ -339,8 +351,23 @@ func (s *Store) applyBaselineSchema(ctx context.Context) error { message_id text primary key, state text not null, attempts integer not null default 0, + provider text not null default '', + model text not null default '', + input_version text not null default '', + last_error text not null default '', + locked_at text, updated_at text not null );`, + 
`create table if not exists message_embeddings ( + message_id text not null, + provider text not null, + model text not null, + input_version text not null, + dimensions integer not null, + embedding_blob blob not null, + embedded_at text not null, + primary key (message_id, provider, model, input_version) + );`, `create virtual table if not exists message_fts using fts5( message_id unindexed, guild_id unindexed, @@ -368,6 +395,7 @@ func (s *Store) applyBaselineSchema(ctx context.Context) error { `create index if not exists idx_mentions_message_id on mention_events(message_id);`, `create index if not exists idx_mentions_target on mention_events(target_type, target_id, event_at);`, `create index if not exists idx_mentions_author on mention_events(author_id, event_at);`, + `create index if not exists idx_embedding_jobs_state_updated on embedding_jobs(state, updated_at);`, } for _, stmt := range stmts { if _, err := tx.ExecContext(ctx, stmt); err != nil { @@ -377,6 +405,75 @@ func (s *Store) applyBaselineSchema(ctx context.Context) error { return tx.Commit() } +func (s *Store) applySchemaV2(ctx context.Context) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer rollback(tx) + for _, column := range []struct { + name string + sql string + }{ + {"provider", `alter table embedding_jobs add column provider text not null default ''`}, + {"model", `alter table embedding_jobs add column model text not null default ''`}, + {"input_version", `alter table embedding_jobs add column input_version text not null default ''`}, + {"last_error", `alter table embedding_jobs add column last_error text not null default ''`}, + {"locked_at", `alter table embedding_jobs add column locked_at text`}, + } { + ok, err := columnExists(ctx, tx, "embedding_jobs", column.name) + if err != nil { + return err + } + if !ok { + if _, err := tx.ExecContext(ctx, column.sql); err != nil { + return fmt.Errorf("add embedding_jobs.%s: %w", column.name, err) + } + } + } + 
stmts := []string{ + `create table if not exists message_embeddings ( + message_id text not null, + provider text not null, + model text not null, + input_version text not null, + dimensions integer not null, + embedding_blob blob not null, + embedded_at text not null, + primary key (message_id, provider, model, input_version) + );`, + `create index if not exists idx_embedding_jobs_state_updated on embedding_jobs(state, updated_at);`, + } + for _, stmt := range stmts { + if _, err := tx.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("migrate schema v2: %w", err) + } + } + return tx.Commit() +} + +func columnExists(ctx context.Context, tx *sql.Tx, table, column string) (bool, error) { + rows, err := tx.QueryContext(ctx, `pragma table_info(`+table+`)`) + if err != nil { + return false, fmt.Errorf("inspect %s columns: %w", table, err) + } + defer func() { _ = rows.Close() }() + for rows.Next() { + var cid int + var name, typ string + var notNull int + var defaultValue sql.NullString + var pk int + if err := rows.Scan(&cid, &name, &typ, &notNull, &defaultValue, &pk); err != nil { + return false, err + } + if name == column { + return true, nil + } + } + return false, rows.Err() +} + func (s *Store) ensureFTSRowIDs(ctx context.Context) error { var version sql.NullString err := s.db.QueryRowContext(ctx, ` diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 70607a7..33ada11 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -273,6 +273,88 @@ func TestOpenBackfillsMissingSchemaVersion(t *testing.T) { require.Equal(t, storeSchemaVersion, version) } +func TestOpenMigratesSchemaV1ToV2(t *testing.T) { + t.Parallel() + + ctx := context.Background() + dbPath := filepath.Join(t.TempDir(), "discrawl.db") + require.NoError(t, createV1Schema(ctx, dbPath)) + + db, err := sql.Open("sqlite", dbPath) + require.NoError(t, err) + _, err = db.ExecContext(ctx, ` + insert into messages( + id, guild_id, channel_id, message_type, 
created_at, content, + normalized_content, raw_json, updated_at + ) values('m1', 'g1', 'c1', 0, '2026-01-01T00:00:00Z', 'hello', 'hello', '{}', '2026-01-01T00:00:00Z') + `) + require.NoError(t, err) + _, err = db.ExecContext(ctx, ` + insert into embedding_jobs(message_id, state, attempts, updated_at) + values('m1', 'pending', 1, '2026-01-01T00:00:00Z') + `) + require.NoError(t, err) + require.NoError(t, db.Close()) + + s, err := Open(ctx, dbPath) + require.NoError(t, err) + defer func() { _ = s.Close() }() + + var version int + require.NoError(t, s.DB().QueryRowContext(ctx, `pragma user_version`).Scan(&version)) + require.Equal(t, 2, version) + + _, rows, err := s.ReadOnlyQuery(ctx, "select provider, model, input_version, last_error, locked_at from embedding_jobs where message_id = 'm1'") + require.NoError(t, err) + require.Equal(t, [][]string{{"", "", "", "", ""}}, rows) + + _, rows, err = s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings") + require.NoError(t, err) + require.Equal(t, "0", rows[0][0]) +} + +func TestOpenMigratesUnversionedV1SchemaToV2(t *testing.T) { + t.Parallel() + + ctx := context.Background() + dbPath := filepath.Join(t.TempDir(), "discrawl.db") + require.NoError(t, createV1Schema(ctx, dbPath)) + + db, err := sql.Open("sqlite", dbPath) + require.NoError(t, err) + _, err = db.ExecContext(ctx, ` + insert into messages( + id, guild_id, channel_id, message_type, created_at, content, + normalized_content, raw_json, updated_at + ) values('m1', 'g1', 'c1', 0, '2026-01-01T00:00:00Z', 'hello', 'hello', '{}', '2026-01-01T00:00:00Z') + `) + require.NoError(t, err) + _, err = db.ExecContext(ctx, ` + insert into embedding_jobs(message_id, state, attempts, updated_at) + values('m1', 'pending', 1, '2026-01-01T00:00:00Z') + `) + require.NoError(t, err) + _, err = db.ExecContext(ctx, `pragma user_version = 0`) + require.NoError(t, err) + require.NoError(t, db.Close()) + + s, err := Open(ctx, dbPath) + require.NoError(t, err) + defer func() { _ = 
s.Close() }() + + var version int + require.NoError(t, s.DB().QueryRowContext(ctx, `pragma user_version`).Scan(&version)) + require.Equal(t, 2, version) + + _, rows, err := s.ReadOnlyQuery(ctx, "select provider, model, input_version, last_error, locked_at from embedding_jobs where message_id = 'm1'") + require.NoError(t, err) + require.Equal(t, [][]string{{"", "", "", "", ""}}, rows) + + _, rows, err = s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings") + require.NoError(t, err) + require.Equal(t, "0", rows[0][0]) +} + func TestReadOnlyQueryGuards(t *testing.T) { t.Parallel() @@ -290,6 +372,142 @@ func TestReadOnlyQueryGuards(t *testing.T) { require.Error(t, err) } +func createV1Schema(ctx context.Context, path string) error { + db, err := sql.Open("sqlite", path) + if err != nil { + return err + } + defer func() { _ = db.Close() }() + stmts := []string{ + `create table guilds ( + id text primary key, + name text not null, + icon text, + raw_json text not null, + updated_at text not null + );`, + `create table channels ( + id text primary key, + guild_id text not null, + parent_id text, + kind text not null, + name text not null, + topic text, + position integer, + is_nsfw integer not null default 0, + is_archived integer not null default 0, + is_locked integer not null default 0, + is_private_thread integer not null default 0, + thread_parent_id text, + archive_timestamp text, + raw_json text not null, + updated_at text not null + );`, + `create table members ( + guild_id text not null, + user_id text not null, + username text not null, + global_name text, + display_name text, + nick text, + discriminator text, + avatar text, + bot integer not null default 0, + joined_at text, + role_ids_json text not null, + raw_json text not null, + updated_at text not null, + primary key (guild_id, user_id) + );`, + `create table messages ( + id text primary key, + guild_id text not null, + channel_id text not null, + author_id text, + message_type integer not null, 
+ created_at text not null, + edited_at text, + deleted_at text, + content text not null, + normalized_content text not null, + reply_to_message_id text, + pinned integer not null default 0, + has_attachments integer not null default 0, + raw_json text not null, + updated_at text not null + );`, + `create table message_events ( + event_id integer primary key autoincrement, + guild_id text not null, + channel_id text not null, + message_id text not null, + event_type text not null, + event_at text not null, + payload_json text not null + );`, + `create table message_attachments ( + attachment_id text primary key, + message_id text not null, + guild_id text not null, + channel_id text not null, + author_id text, + filename text not null, + content_type text, + size integer not null default 0, + url text, + proxy_url text, + text_content text not null default '', + updated_at text not null + );`, + `create table mention_events ( + event_id integer primary key autoincrement, + message_id text not null, + guild_id text not null, + channel_id text not null, + author_id text, + target_type text not null, + target_id text not null, + target_name text not null default '', + event_at text not null + );`, + `create table sync_state ( + scope text primary key, + cursor text, + updated_at text not null + );`, + `create table embedding_jobs ( + message_id text primary key, + state text not null, + attempts integer not null default 0, + updated_at text not null + );`, + `create virtual table message_fts using fts5( + message_id unindexed, + guild_id unindexed, + channel_id unindexed, + author_id unindexed, + author_name, + channel_name, + content + );`, + `create virtual table member_fts using fts5( + member_key unindexed, + guild_id unindexed, + user_id unindexed, + username, + display_name, + profile_text + );`, + `pragma user_version = 1;`, + } + for _, stmt := range stmts { + if _, err := db.ExecContext(ctx, stmt); err != nil { + return err + } + } + return nil +} + func 
TestQueryAndExec(t *testing.T) { t.Parallel() diff --git a/internal/store/store_write_test.go b/internal/store/store_write_test.go index 8ad0799..207ca91 100644 --- a/internal/store/store_write_test.go +++ b/internal/store/store_write_test.go @@ -3,12 +3,15 @@ package store import ( "context" "database/sql" + "errors" "path/filepath" "sync" "testing" "time" "github.com/stretchr/testify/require" + + "github.com/steipete/discrawl/internal/embed" ) func TestUpsertMessagesBatch(t *testing.T) { @@ -123,6 +126,293 @@ func TestUpsertMessageWithEmbeddingsQueuesJob(t *testing.T) { require.Equal(t, "1", rows[0][0]) } +func TestUpsertMessageWithEmbeddingsQueuesExistingMessageWithoutJob(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + + record := MessageRecord{ + ID: "m1", + GuildID: "g1", + ChannelID: "c1", + MessageType: 0, + CreatedAt: time.Now().UTC().Format(time.RFC3339Nano), + Content: "hello", + NormalizedContent: "hello", + RawJSON: `{}`, + } + require.NoError(t, s.UpsertMessage(ctx, record)) + require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true})) + + _, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts from embedding_jobs where message_id = 'm1'") + require.NoError(t, err) + require.Equal(t, [][]string{{"pending", "0"}}, rows) +} + +func TestDrainEmbeddingJobsStoresVectorsAndSkipsEmptyInput(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + + now := time.Now().UTC().Format(time.RFC3339Nano) + require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{ + ID: "m1", + GuildID: "g1", + ChannelID: "c1", + MessageType: 0, + CreatedAt: now, + Content: "abcdef世界", + NormalizedContent: "abcdef世界", + RawJSON: `{}`, + }, 
WriteOptions{EnqueueEmbedding: true})) + require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{ + ID: "m2", + GuildID: "g1", + ChannelID: "c1", + MessageType: 0, + CreatedAt: now, + Content: "", + NormalizedContent: " ", + RawJSON: `{}`, + }, WriteOptions{EnqueueEmbedding: true})) + + provider := &fakeEmbeddingProvider{ + batches: []embed.EmbeddingBatch{{ + Vectors: [][]float32{{1.25, 2.5}}, + }}, + } + stats, err := s.DrainEmbeddingJobs(ctx, provider, EmbeddingDrainOptions{ + Provider: "ollama", + Model: "nomic-embed-text", + Limit: 10, + BatchSize: 2, + MaxInputChars: 7, + }) + require.NoError(t, err) + require.Equal(t, 2, stats.Processed) + require.Equal(t, 1, stats.Succeeded) + require.Equal(t, 1, stats.Skipped) + require.Equal(t, 0, stats.RemainingBacklog) + require.Equal(t, [][]string{{"abcdef世"}}, provider.inputs) + + _, rows, err := s.ReadOnlyQuery(ctx, "select message_id, provider, model, input_version, dimensions from message_embeddings") + require.NoError(t, err) + require.Equal(t, [][]string{{"m1", "ollama", "nomic-embed-text", EmbeddingInputVersion, "2"}}, rows) + + var blob []byte + require.NoError(t, s.DB().QueryRowContext(ctx, `select embedding_blob from message_embeddings where message_id = 'm1'`).Scan(&blob)) + vector, err := DecodeEmbeddingVector(blob) + require.NoError(t, err) + require.Equal(t, []float32{1.25, 2.5}, vector) + + _, rows, err = s.ReadOnlyQuery(ctx, "select message_id, state, provider, model, input_version from embedding_jobs order by message_id") + require.NoError(t, err) + require.Equal(t, [][]string{ + {"m1", "done", "ollama", "nomic-embed-text", EmbeddingInputVersion}, + {"m2", "done", "ollama", "nomic-embed-text", EmbeddingInputVersion}, + }, rows) +} + +func TestUpsertMessageWithEmbeddingsDoesNotRequeueUnchangedDoneJob(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + + 
record := MessageRecord{ + ID: "m1", + GuildID: "g1", + ChannelID: "c1", + MessageType: 0, + CreatedAt: time.Now().UTC().Format(time.RFC3339Nano), + Content: "hello", + NormalizedContent: "hello", + RawJSON: `{}`, + } + require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true})) + + stats, err := s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{ + batches: []embed.EmbeddingBatch{{Vectors: [][]float32{{1, 2}}}}, + }, EmbeddingDrainOptions{Provider: "ollama", Model: "nomic-embed-text", Limit: 10, BatchSize: 1}) + require.NoError(t, err) + require.Equal(t, 1, stats.Succeeded) + + require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true})) + _, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'") + require.NoError(t, err) + require.Equal(t, [][]string{{"done", "0", ""}}, rows) + + backlog, err := s.EmbeddingBacklog(ctx) + require.NoError(t, err) + require.Equal(t, 0, backlog) + + record.NormalizedContent = "hello updated" + record.Content = "hello updated" + require.NoError(t, s.UpsertMessageWithOptions(ctx, record, WriteOptions{EnqueueEmbedding: true})) + _, rows, err = s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'") + require.NoError(t, err) + require.Equal(t, [][]string{{"pending", "0", ""}}, rows) +} + +func TestDrainEmbeddingJobsFailsWholeBatchOnDimensionMismatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + + require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{ + ID: "m1", + GuildID: "g1", + ChannelID: "c1", + MessageType: 0, + CreatedAt: time.Now().UTC().Format(time.RFC3339Nano), + Content: "hello", + NormalizedContent: "hello", + RawJSON: `{}`, + }, WriteOptions{EnqueueEmbedding: true})) + + stats, err := 
s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{ + batches: []embed.EmbeddingBatch{{ + Dimensions: 3, + Vectors: [][]float32{{1, 2}}, + }}, + }, EmbeddingDrainOptions{Provider: "ollama", Model: "nomic-embed-text", Limit: 10, BatchSize: 1}) + require.NoError(t, err) + require.Equal(t, 1, stats.Failed) + + _, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'") + require.NoError(t, err) + require.Equal(t, "pending", rows[0][0]) + require.Equal(t, "1", rows[0][1]) + require.Contains(t, rows[0][2], "dimensions mismatch") + + _, rows, err = s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings") + require.NoError(t, err) + require.Equal(t, "0", rows[0][0]) +} + +func TestDrainEmbeddingJobsMarksFailedAfterMaxAttempts(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + + require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{ + ID: "m1", + GuildID: "g1", + ChannelID: "c1", + MessageType: 0, + CreatedAt: time.Now().UTC().Format(time.RFC3339Nano), + Content: "hello", + NormalizedContent: "hello", + RawJSON: `{}`, + }, WriteOptions{EnqueueEmbedding: true})) + _, err = s.DB().ExecContext(ctx, `update embedding_jobs set attempts = 2 where message_id = 'm1'`) + require.NoError(t, err) + + stats, err := s.DrainEmbeddingJobs(ctx, &fakeEmbeddingProvider{err: errors.New("provider down")}, EmbeddingDrainOptions{ + Provider: "ollama", + Model: "nomic-embed-text", + Limit: 10, + }) + require.NoError(t, err) + require.Equal(t, 1, stats.Failed) + + _, rows, err := s.ReadOnlyQuery(ctx, "select state, attempts, last_error from embedding_jobs where message_id = 'm1'") + require.NoError(t, err) + require.Equal(t, [][]string{{"failed", "3", "provider down"}}, rows) +} + +func TestDrainEmbeddingJobsStopsOnRateLimit(t *testing.T) { + t.Parallel() + + ctx := 
context.Background() + s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + + for _, id := range []string{"m1", "m2"} { + require.NoError(t, s.UpsertMessageWithOptions(ctx, MessageRecord{ + ID: id, + GuildID: "g1", + ChannelID: "c1", + MessageType: 0, + CreatedAt: time.Now().UTC().Format(time.RFC3339Nano), + Content: "hello", + NormalizedContent: "hello", + RawJSON: `{}`, + }, WriteOptions{EnqueueEmbedding: true})) + } + + provider := &fakeEmbeddingProvider{err: &embed.HTTPError{StatusCode: 429, Body: "slow down"}} + stats, err := s.DrainEmbeddingJobs(ctx, provider, EmbeddingDrainOptions{ + Provider: "ollama", + Model: "nomic-embed-text", + Limit: 10, + BatchSize: 1, + }) + require.NoError(t, err) + require.True(t, stats.RateLimited) + require.Equal(t, 1, stats.Processed) + require.Equal(t, 1, stats.Failed) + require.Equal(t, 2, stats.RemainingBacklog) + require.Len(t, provider.inputs, 1) +} + +func TestRequeueAllEmbeddingJobsUsesCurrentIdentity(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + + for _, id := range []string{"m1", "m2"} { + require.NoError(t, s.UpsertMessage(ctx, MessageRecord{ + ID: id, + GuildID: "g1", + ChannelID: "c1", + MessageType: 0, + CreatedAt: time.Now().UTC().Format(time.RFC3339Nano), + Content: "hello", + NormalizedContent: "hello", + RawJSON: `{}`, + })) + } + _, err = s.DB().ExecContext(ctx, ` + insert into embedding_jobs(message_id, state, attempts, provider, model, input_version, last_error, updated_at) + values('m1', 'failed', 3, 'old', 'old-model', 'old-input', 'old error', ?) 
+ `, time.Now().UTC().Format(timeLayout)) + require.NoError(t, err) + + requeued, err := s.RequeueAllEmbeddingJobs(ctx, EmbeddingDrainOptions{ + Provider: "ollama", + Model: "nomic-embed-text", + InputVersion: EmbeddingInputVersion, + }) + require.NoError(t, err) + require.Equal(t, 2, requeued) + + _, rows, err := s.ReadOnlyQuery(ctx, "select message_id, state, attempts, provider, model, input_version, last_error from embedding_jobs order by message_id") + require.NoError(t, err) + require.Equal(t, [][]string{ + {"m1", "pending", "0", "ollama", "nomic-embed-text", EmbeddingInputVersion, ""}, + {"m2", "pending", "0", "ollama", "nomic-embed-text", EmbeddingInputVersion, ""}, + }, rows) +} + func TestConcurrentMessageUpsertsShareSingleWriter(t *testing.T) { t.Parallel() @@ -160,6 +450,26 @@ func TestConcurrentMessageUpsertsShareSingleWriter(t *testing.T) { require.Equal(t, "8", rows[0][0]) } +type fakeEmbeddingProvider struct { + batches []embed.EmbeddingBatch + err error + inputs [][]string +} + +func (f *fakeEmbeddingProvider) Embed(_ context.Context, inputs []string) (embed.EmbeddingBatch, error) { + copied := append([]string(nil), inputs...) 
+ f.inputs = append(f.inputs, copied) + if f.err != nil { + return embed.EmbeddingBatch{}, f.err + } + if len(f.batches) == 0 { + return embed.EmbeddingBatch{}, nil + } + batch := f.batches[0] + f.batches = f.batches[1:] + return batch, nil +} + func TestMessageFTSUsesSnowflakeRowID(t *testing.T) { t.Parallel() diff --git a/internal/store/write.go b/internal/store/write.go index a23f133..e885b15 100644 --- a/internal/store/write.go +++ b/internal/store/write.go @@ -287,6 +287,30 @@ func (s *Store) UpsertMessages(ctx context.Context, messages []MessageMutation) func upsertMessageTx(ctx context.Context, tx *sql.Tx, message MessageRecord, opts WriteOptions) error { now := time.Now().UTC().Format(timeLayout) + var previousNormalized sql.NullString + previousErr := sql.ErrNoRows + jobExists := false + if opts.EnqueueEmbedding { + previousErr = tx.QueryRowContext(ctx, ` + select normalized_content + from messages + where id = ? + `, message.ID).Scan(&previousNormalized) + if previousErr != nil && previousErr != sql.ErrNoRows { + return previousErr + } + if previousErr == nil { + var existingJobs int + if err := tx.QueryRowContext(ctx, ` + select count(*) + from embedding_jobs + where message_id = ? + `, message.ID).Scan(&existingJobs); err != nil { + return err + } + jobExists = existingJobs > 0 + } + } if _, err := tx.ExecContext(ctx, ` insert into messages( id, guild_id, channel_id, author_id, message_type, created_at, edited_at, deleted_at, @@ -323,11 +347,17 @@ func upsertMessageTx(ctx context.Context, tx *sql.Tx, message MessageRecord, opt return err } } - if opts.EnqueueEmbedding { + queueEmbedding := opts.EnqueueEmbedding && (previousErr == sql.ErrNoRows || previousNormalized.String != message.NormalizedContent || !jobExists) + if queueEmbedding { if _, err := tx.ExecContext(ctx, ` insert into embedding_jobs(message_id, state, attempts, updated_at) values(?, 'pending', 0, ?) 
- on conflict(message_id) do nothing + on conflict(message_id) do update set + state = 'pending', + attempts = 0, + last_error = '', + locked_at = null, + updated_at = excluded.updated_at `, message.ID, now); err != nil { return err }