Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions internal/cli/admin_commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,61 @@ func (r *runtime) runStatus(args []string) error {
return r.print(status)
}

// runEmbed implements the "embed" subcommand: it parses the command flags,
// builds an embedding provider from the search config, optionally requeues
// all embedding jobs (--rebuild), drains up to --limit jobs in batches of
// --batch-size, and prints the resulting drain statistics.
//
// Flag and argument problems, as well as embeddings being disabled in the
// config, are returned through usageErr; a provider construction failure is
// returned through configErr; store errors are returned unchanged.
func (r *runtime) runEmbed(args []string) error {
	embedCfg := r.cfg.Search.Embeddings

	var (
		limit     int
		batchSize int
		rebuild   bool
	)
	flags := flag.NewFlagSet("embed", flag.ContinueOnError)
	// Parse errors are surfaced through usageErr, so silence the flag
	// package's own output.
	flags.SetOutput(io.Discard)
	flags.IntVar(&limit, "limit", store.DefaultEmbedLimit(), "")
	flags.IntVar(&batchSize, "batch-size", embedCfg.BatchSize, "")
	flags.BoolVar(&rebuild, "rebuild", false, "")
	if err := flags.Parse(args); err != nil {
		return usageErr(err)
	}

	// Validation order matters: the first failing check decides the message.
	switch {
	case flags.NArg() != 0:
		return usageErr(fmt.Errorf("embed takes no positional arguments"))
	case limit <= 0:
		return usageErr(fmt.Errorf("--limit must be positive"))
	case batchSize <= 0:
		return usageErr(fmt.Errorf("--batch-size must be positive"))
	case !embedCfg.Enabled:
		return usageErr(fmt.Errorf("embeddings are disabled in config"))
	}

	// Use the injected factory when one is set on the runtime; otherwise fall
	// back to the real provider constructor.
	newProvider := r.newEmbed
	if newProvider == nil {
		newProvider = func(cfg config.EmbeddingsConfig) (embed.Provider, error) {
			return embed.NewProvider(cfg)
		}
	}
	provider, err := newProvider(embedCfg)
	if err != nil {
		return configErr(err)
	}

	drainOpts := store.EmbeddingDrainOptions{
		Provider:      embedCfg.Provider,
		Model:         embedCfg.Model,
		InputVersion:  store.EmbeddingInputVersion,
		Limit:         limit,
		BatchSize:     batchSize,
		MaxInputChars: embedCfg.MaxInputChars,
		Now:           r.now,
	}

	// With --rebuild every job is requeued before draining; the count is
	// reported alongside the drain statistics.
	var requeued int
	if rebuild {
		if requeued, err = r.store.RequeueAllEmbeddingJobs(r.ctx, drainOpts); err != nil {
			return err
		}
	}

	stats, err := r.store.DrainEmbeddingJobs(r.ctx, provider, drainOpts)
	if err != nil {
		return err
	}
	stats.Requeued = requeued
	return r.print(stats)
}

func (r *runtime) runDoctor(args []string) error {
if len(args) != 0 {
return usageErr(fmt.Errorf("doctor takes no arguments"))
Expand Down
4 changes: 4 additions & 0 deletions internal/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/bwmarrin/discordgo"
"github.com/steipete/discrawl/internal/config"
"github.com/steipete/discrawl/internal/discord"
"github.com/steipete/discrawl/internal/embed"
"github.com/steipete/discrawl/internal/store"
"github.com/steipete/discrawl/internal/syncer"
)
Expand Down Expand Up @@ -94,6 +95,7 @@ type runtime struct {
openStore func(context.Context, string) (*store.Store, error)
newDiscord func(config.Config) (discordClient, error)
newSyncer func(syncer.Client, *store.Store, *slog.Logger) syncService
newEmbed func(config.EmbeddingsConfig) (embed.Provider, error)
now func() time.Time
}

Expand Down Expand Up @@ -128,6 +130,8 @@ func (r *runtime) dispatch(rest []string) error {
return r.withServices(hasBoolFlag(rest[1:], "--sync"), func() error { return r.runMessages(rest[1:]) })
case "mentions":
return r.withServices(false, func() error { return r.runMentions(rest[1:]) })
case "embed":
return r.withServices(false, func() error { return r.runEmbed(rest[1:]) })
case "sql":
return r.withServices(false, func() error { return r.runSQL(rest[1:]) })
case "members":
Expand Down
65 changes: 65 additions & 0 deletions internal/cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cli
import (
"bytes"
"context"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -133,6 +134,70 @@ func TestStatusSearchSQLAndListings(t *testing.T) {
}
}

func TestEmbedCommandDrainsBoundedBacklog(t *testing.T) {
ctx := context.Background()
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.toml")
dbPath := filepath.Join(dir, "discrawl.db")

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/embeddings", r.URL.Path)
var req struct {
Input []string `json:"input"`
}
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
require.Len(t, req.Input, 1)
_, _ = w.Write([]byte(`{"data":[{"index":0,"embedding":[1,2]}]}`))
}))
defer server.Close()

cfg := config.Default()
cfg.DBPath = dbPath
cfg.Search.Embeddings.Enabled = true
cfg.Search.Embeddings.Provider = "openai_compatible"
cfg.Search.Embeddings.Model = "local-model"
cfg.Search.Embeddings.BaseURL = server.URL
cfg.Search.Embeddings.APIKeyEnv = ""
require.NoError(t, config.Write(cfgPath, cfg))

s, err := store.Open(ctx, dbPath)
require.NoError(t, err)
for _, id := range []string{"m1", "m2"} {
require.NoError(t, s.UpsertMessageWithOptions(ctx, store.MessageRecord{
ID: id,
GuildID: "g1",
ChannelID: "c1",
MessageType: 0,
CreatedAt: time.Now().UTC().Format(time.RFC3339Nano),
Content: "hello",
NormalizedContent: "hello",
RawJSON: `{}`,
}, store.WriteOptions{EnqueueEmbedding: true}))
}
require.NoError(t, s.Close())

var out bytes.Buffer
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "embed", "--limit", "1"}, &out, &bytes.Buffer{}))
require.Contains(t, out.String(), "processed=1")
require.Contains(t, out.String(), "succeeded=1")
require.Contains(t, out.String(), "remaining_backlog=1")
require.Contains(t, out.String(), "provider=openai_compatible")

s, err = store.Open(ctx, dbPath)
require.NoError(t, err)
_, rows, err := s.ReadOnlyQuery(ctx, "select count(*) from message_embeddings")
require.NoError(t, err)
require.Equal(t, "1", rows[0][0])
require.NoError(t, s.Close())

out.Reset()
require.NoError(t, Run(ctx, []string{"--config", cfgPath, "embed", "--rebuild", "--limit", "1"}, &out, &bytes.Buffer{}))
require.Contains(t, out.String(), "processed=1")
require.Contains(t, out.String(), "succeeded=1")
require.Contains(t, out.String(), "remaining_backlog=1")
require.Contains(t, out.String(), "requeued=2")
}

type fakeDiscordClient struct {
guilds []*discordgo.UserGuild
self *discordgo.User
Expand Down
16 changes: 16 additions & 0 deletions internal/cli/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ Commands:
search
messages
mentions
embed
sql
members
channels
Expand Down Expand Up @@ -107,6 +108,21 @@ func printHuman(w io.Writer, value any) error {
v.DBPath, v.GuildCount, v.ChannelCount, v.ThreadCount, v.MessageCount, v.MemberCount, v.EmbeddingBacklog,
formatTime(v.LastSyncAt), formatTime(v.LastTailEventAt))
return err
case store.EmbeddingDrainStats:
_, err := fmt.Fprintf(w, "processed=%d\nsucceeded=%d\nfailed=%d\nskipped=%d\nremaining_backlog=%d\nprovider=%s\nmodel=%s\ninput_version=%s\n",
v.Processed, v.Succeeded, v.Failed, v.Skipped, v.RemainingBacklog, v.Provider, v.Model, v.InputVersion)
if err != nil {
return err
}
if v.Requeued > 0 {
if _, err := fmt.Fprintf(w, "requeued=%d\n", v.Requeued); err != nil {
return err
}
}
if v.RateLimited {
_, err = fmt.Fprintln(w, "rate_limited=true")
}
return err
case []store.SearchResult:
for _, row := range v {
if _, err := fmt.Fprintf(w, "[%s/%s] %s %s\n%s\n\n", row.GuildID, row.ChannelName, row.AuthorName, formatTime(row.CreatedAt), row.Content); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/embed/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func postJSON(ctx context.Context, client *http.Client, endpoint, apiKey string,
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return fmt.Errorf("embedding request failed with HTTP %d: %s", resp.StatusCode, string(msg))
return &HTTPError{StatusCode: resp.StatusCode, Body: string(msg)}
}
if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
return fmt.Errorf("decode embedding response: %w", err)
Expand Down
14 changes: 14 additions & 0 deletions internal/embed/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,20 @@ type EmbeddingBatch struct {
Vectors [][]float32
}

// HTTPError is an error carrying the HTTP status code and response body text
// of a failed embedding request.
type HTTPError struct {
	// StatusCode is the HTTP status code of the response.
	StatusCode int
	// Body is the response body text included in the error message.
	Body string
}

// Error implements the error interface, rendering the status code followed by
// the response body.
func (e *HTTPError) Error() string {
	status, body := e.StatusCode, e.Body
	return fmt.Sprintf("embedding request failed with HTTP %d: %s", status, body)
}

// IsRateLimitError reports whether err, or any error it wraps, is an
// *HTTPError whose status code is 429 (Too Many Requests).
func IsRateLimitError(err error) bool {
	var target *HTTPError
	if !errors.As(err, &target) {
		return false
	}
	return target.StatusCode == http.StatusTooManyRequests
}

type CheckResult struct {
Provider string
Model string
Expand Down
21 changes: 21 additions & 0 deletions internal/embed/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,27 @@ func TestCheckProviderWarnsOnLocalProbeFailure(t *testing.T) {
require.False(t, result.Probed)
}

func TestProviderExposesRateLimitErrors(t *testing.T) {
t.Parallel()

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "rate limited", http.StatusTooManyRequests)
}))
defer server.Close()

provider, err := NewProvider(config.EmbeddingsConfig{
Provider: ProviderOpenAICompatible,
Model: "local-model",
BaseURL: server.URL,
RequestTimeout: "5s",
})
require.NoError(t, err)

_, err = provider.Embed(context.Background(), []string{"one"})
require.ErrorContains(t, err, "HTTP 429")
require.True(t, IsRateLimitError(err))
}

func TestProviderRejectsInvalidResponses(t *testing.T) {
t.Parallel()

Expand Down
Loading