diff --git a/CHANGELOG.md b/CHANGELOG.md index 502d4ce..7c28483 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Changes +- Added attachment media caching with `discrawl attachments`, `attachments fetch`, `sync --with-media`, and Git snapshot backup/restore for cached non-DM media files. - Docker: add a local image with `/data` persistence and CI smoke coverage. ### Fixes diff --git a/README.md b/README.md index 9ae500c..bc3b3c6 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ Wiretap DMs stay local and are never exported to the Git-backed snapshot mirror. - maintains FTS5 search indexes for fast local text search - builds an offline member directory from archived profile payloads - extracts small text-like attachments into the local search index +- downloads and backs up cached attachment media when requested - records structured user and role mentions for direct querying - tails Gateway events for live updates, with periodic repair syncs - imports classifiable Discord Desktop cache messages with `wiretap`, including proven DMs under `@me` @@ -268,6 +269,7 @@ Bot sync modes: `--all` ignores `default_guild_id` and fans out across every discovered guild the bot can access. `--skip-members` refreshes guild/channel/message data without crawling the full member list, which is useful for frequent Git snapshot publishers that only need latest messages. `--latest-only` is still accepted for explicit latest-only runs; it is now the default for untargeted `sync`. Use `--all-channels` to opt out of the fast default without doing a full historical crawl. +`--with-media` downloads missing attachment media into `cache_dir/media` after the message sync/import phase. When `--channels` includes a forum channel id, `discrawl` expands that forum's threads and syncs their messages as part of the targeted run. `--since` limits initial history/bootstrap and full-history backfill to messages at or after the given RFC3339 timestamp. It does not mark older history as complete, so a later `sync --full` without `--since` can continue the backfill. Long runs now emit periodic progress logs to stderr so large backfills and Git snapshot imports do not look hung. @@ -372,6 +374,19 @@ Notes: - at least one filter is required - `--dm` is shorthand for `--guild @me`, so DM searches and message slices do not need raw SQL +### `attachments` + +Lists attachment metadata and downloads media into the local cache when requested. + +```bash +discrawl attachments --channel general --days 7 +discrawl attachments --filename crash --type image --all +discrawl attachments fetch --channel general --days 7 +discrawl attachments fetch --missing --max-bytes 104857600 +``` + +Media bytes are stored under `cache_dir/media`, not in SQLite. SQLite stores attachment metadata, content hash, relative media path, fetch status, and fetch error. Cached non-DM media is included in Git snapshots by default; `publish --no-media` omits it. + ### `dms` Lists local wiretap DM conversations or reads one DM thread. @@ -491,6 +506,7 @@ Publisher: discrawl publish --remote https://github.com/example/discord-archive.git --push discrawl publish --readme path/to/discord-backup/README.md --push discrawl publish --public-only --include-channels 1458141495701012561 --push +discrawl publish --no-media --push ``` Subscriber: @@ -514,8 +530,8 @@ Once `share.remote` is configured, read commands auto-fetch and import when the Hybrid mode is supported too: keep normal Discord credentials configured and set `share.remote`. `discrawl sync --update=auto` and `discrawl messages --sync` import the Git snapshot first, usually as a changed-shard delta, then use live Discord for latest-message deltas. Use `sync --all-channels` or `sync --full` when you intentionally want a broader live repair/backfill pass. -Git snapshots publish non-DM archive tables by default. DMs, desktop wiretap -rows, and local secrets are never exported. +Git snapshots publish non-DM archive tables and cached non-DM attachment media by default. DMs, desktop wiretap rows, DM media, and local secrets are never exported. Use `publish --no-media` to omit cached media files. +Subscribers can use `subscribe --no-media` or `update --no-media` to import only SQLite rows and skip restoring cached files. Publish filters narrow only the Git snapshot. The publisher can still sync and keep a richer local SQLite archive, then publish a smaller view for Git-only @@ -666,6 +682,8 @@ concurrency = 16 repair_every = "6h" full_history = true attachment_text = true +attachment_media = false +max_attachment_bytes = 104857600 [desktop] path = "~/.config/discord" # macOS default: "~/Library/Application Support/discord" @@ -688,6 +706,7 @@ repo_path = "~/.local/share/discrawl/share" # macOS: "~/Library/Application Supp branch = "main" auto_update = true stale_after = "15m" +media = true ``` The value above is an example. `init` writes an auto-sized default based on the host: `min(32, max(8, GOMAXPROCS*2))`. @@ -757,7 +776,7 @@ Proven DMs use `@me` as their guild id. Unclassifiable desktop-cache payloads ar SQLite schema migrations are versioned with `PRAGMA user_version`. Startup now fails fast when a local DB schema is newer than the supported binary. -Attachment binaries are not stored in SQLite. +Attachment binaries are not stored in SQLite. `attachments fetch` and `sync --with-media` write media bytes under `cache_dir/media`; SQLite stores the relative media path, content hash, size, and fetch status. Set `sync.attachment_text = false` if you want to keep attachment metadata and filenames but disable attachment body fetches for text indexing. diff --git a/docs/commands/attachments.md b/docs/commands/attachments.md new file mode 100644 index 0000000..4155511 --- /dev/null +++ b/docs/commands/attachments.md @@ -0,0 +1,47 @@ +# `attachments` + +Lists attachment metadata and optionally downloads attachment media into the local cache. + +## Usage + +```bash +discrawl attachments --channel general --days 7 +discrawl attachments --filename crash --type image --all +discrawl attachments --message 1456744319972282449 +discrawl attachments fetch --channel general --days 7 +discrawl attachments fetch --missing --max-bytes 104857600 +discrawl --json attachments --missing --all +``` + +## Flags + +- `--channel ` - id, exact name, `#name`, or partial name match +- `--guild ` / `--guilds ` / `--dm` - restrict the guild scope (`--dm` is shorthand for `--guild @me`) +- `--author ` - restrict to one author +- `--message ` - restrict to one message +- `--filename ` - filename substring match +- `--type ` - content-type substring match, such as `image` or `application/pdf` +- `--hours ` - shorthand for "since now minus N hours" +- `--days ` - shorthand for "since now minus N days" +- `--since ` / `--before ` - explicit time window +- `--limit ` - safety limit (default 200; `--all` removes it) +- `--all` - removes the safety limit +- `--missing` - only attachments whose cached media file is absent + +`attachments fetch` also accepts: + +- `--force` - re-download already cached attachments +- `--max-bytes ` - per-attachment download cap (defaults to `[sync].max_attachment_bytes`) + +## Notes + +- media bytes are stored under `cache_dir/media`, not in SQLite +- SQLite stores attachment metadata, content hash, cached media path, fetch status, and errors +- `publish` backs up cached non-DM media files by default; use `publish --no-media` to omit them +- `@me` DM media is local-only and is not published to Git snapshots + +## See also + +- [`messages`](messages.html) +- [`publish`](publish.html) +- [`sync`](sync.html) diff --git a/docs/commands/publish.md b/docs/commands/publish.md index fa7e220..2a99f17 100644 --- a/docs/commands/publish.md +++ b/docs/commands/publish.md @@ -9,6 +9,7 @@ discrawl publish --remote https://github.com/example/discord-archive.git --push discrawl publish --readme path/to/discord-backup/README.md --push discrawl publish --message "sync: discord archive" --push discrawl publish --with-embeddings --push +discrawl publish --no-media --push discrawl publish --public-only --include-channels 1458141495701012561 --push ``` @@ -25,6 +26,7 @@ discrawl publish --public-only --include-channels 1458141495701012561 --push - `--include-channels ` - comma-separated channel ids to export; forum parents include their allowed public threads - `--exclude-channels ` - comma-separated channel ids to omit; exclusions win over includes - `--with-embeddings` - also export stored `message_embeddings` rows +- `--no-media` - omit cached attachment media files from the snapshot Filters narrow only the published snapshot. The local SQLite archive can still be synced from a richer bot-visible dataset. Git-only readers see the filtered @@ -58,6 +60,7 @@ README files without Discrawl report markers are left alone. ## What is published - non-DM archive tables (DM `@me` rows are always excluded) +- cached non-DM attachment media files under `media/` unless `--no-media` is used - when filters are enabled: only matching guilds, channels, messages, events, attachments, mentions, channel-scoped sync-state rows, member rows referenced by matching messages, and matching embedding rows @@ -68,6 +71,7 @@ README files without Discrawl report markers are left alone. ## What is not published - `@me` DM guilds, channels, messages, events, attachments, mentions, wiretap sync state +- `@me` DM media files - when filters are enabled: share manifest state and guild-level member freshness markers, because they describe the full archive - `embedding_jobs` diff --git a/docs/commands/subscribe.md b/docs/commands/subscribe.md index 017bf97..160b0ff 100644 --- a/docs/commands/subscribe.md +++ b/docs/commands/subscribe.md @@ -15,6 +15,7 @@ discrawl subscribe --stale-after 15m https://github.com/example/discord-archive. discrawl subscribe --no-auto-update https://github.com/example/discord-archive.git discrawl subscribe --no-import https://github.com/example/discord-archive.git discrawl subscribe --with-embeddings https://github.com/example/discord-archive.git +discrawl subscribe --no-media https://github.com/example/discord-archive.git ``` ## What it does @@ -31,6 +32,7 @@ discrawl subscribe --with-embeddings https://github.com/example/discord-archive. - `--no-auto-update` - disable auto-refresh (use [`update`](update.html) manually) - `--no-import` - write config only; skip the initial pull/import - `--with-embeddings` - import vectors that match your local `[search.embeddings]` identity +- `--no-media` - skip restoring cached attachment media files into `cache_dir/media` ## Disabled in this mode diff --git a/docs/commands/sync.md b/docs/commands/sync.md index 594286e..36f5d6e 100644 --- a/docs/commands/sync.md +++ b/docs/commands/sync.md @@ -28,6 +28,7 @@ discrawl sync --source wiretap # desktop cache only; aliases: desktop, cache discrawl sync --guild 123456789012345678 --all-channels discrawl sync --channels 111,222 --since 2026-03-01T00:00:00Z discrawl sync --with-embeddings +discrawl sync --with-media ``` ## Sources @@ -61,6 +62,7 @@ discrawl sync --with-embeddings - `--concurrency ` - override worker count (default auto-sized: floor 8, cap 32) - `--skip-members` - refresh guild/channel/message data without crawling members - `--with-embeddings` - also enqueue changed messages into `embedding_jobs` +- `--with-media` - after sync, download missing attachment media into `cache_dir/media` ## Notes diff --git a/docs/commands/update.md b/docs/commands/update.md index 9f3a88f..9a41e0b 100644 --- a/docs/commands/update.md +++ b/docs/commands/update.md @@ -12,6 +12,7 @@ discrawl update \ --repo ~/.local/share/discrawl/share \ --remote https://github.com/example/discord-archive.git discrawl update --with-embeddings +discrawl update --no-media ``` ## Flags @@ -20,6 +21,7 @@ discrawl update --with-embeddings - `--remote ` - target Git remote (defaults to `[share].remote`) - `--branch ` - snapshot branch (defaults to `[share].branch`) - `--with-embeddings` - also import vectors that match your local `[search.embeddings]` identity +- `--no-media` - skip restoring cached attachment media files into `cache_dir/media` ## When to use it diff --git a/docs/configuration.md b/docs/configuration.md index e5391b5..555ddd2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -62,6 +62,8 @@ concurrency = 16 repair_every = "6h" full_history = true attachment_text = true +attachment_media = false +max_attachment_bytes = 104857600 [desktop] path = "~/.config/discord" # macOS default: "~/Library/Application Support/discord" @@ -84,6 +86,7 @@ repo_path = "~/.local/share/discrawl/share" # macOS: "~/Library/Application Supp branch = "main" auto_update = true stale_after = "15m" +media = true [share.filter] public_only = false @@ -118,6 +121,8 @@ Set `discord.token_source = "keyring"` if you want to require keyring lookup and - `guild_ids` is reserved for explicit multi-guild fan-out; usually you do not set this directly - changing `[search.embeddings]` provider/model/input version retargets pending jobs and resets prior attempts; existing vectors for another identity remain in SQLite but are not used for semantic search - changing `db_path` does not migrate existing data; copy the file yourself if you want to keep history +- `sync.attachment_media = true` makes `sync` behave like `sync --with-media`; media bytes are cached under `cache_dir/media` +- `share.media = false` makes publish/update/auto-update omit or skip restoring cached media; `subscribe --no-media` writes this for Git-only readers - `[share.filter]` narrows only `publish` output; sync can still keep a richer local archive - `share.filter.public_only` exports only channels visible to the guild `@everyone` role after category/channel permission overwrites; private diff --git a/docs/guides/data-storage.md b/docs/guides/data-storage.md index e9c603c..3d9058c 100644 --- a/docs/guides/data-storage.md +++ b/docs/guides/data-storage.md @@ -28,7 +28,9 @@ Proven DMs use the synthetic guild id `@me`. Unclassifiable desktop-cache payloa ## Attachments -Attachment binaries are not stored in SQLite. Only attachment metadata, filenames, and (optionally) extracted text. +Attachment binaries are not stored in SQLite. SQLite stores attachment metadata, filenames, optional extracted text, and media cache bookkeeping. + +`discrawl attachments fetch` and `discrawl sync --with-media` download media into `cache_dir/media` and record the relative media path, SHA-256, byte size, fetch time, and fetch status on the attachment row. Set `sync.attachment_text = false` if you want to keep attachment metadata and filenames but disable attachment body fetches for text indexing. diff --git a/docs/security.md b/docs/security.md index 00cac70..8d60841 100644 --- a/docs/security.md +++ b/docs/security.md @@ -38,10 +38,12 @@ CI runs secret scanning with `gitleaks` against git history and the working tree - FTS index rows - optional local embedding queue metadata and vectors -Attachment binaries are not stored in SQLite. Only attachment metadata and (optionally) extracted text. +Attachment binaries are not stored in SQLite. Only attachment metadata, optional extracted text, and media cache bookkeeping are stored there. Cached files live under `cache_dir/media`. Set `sync.attachment_text = false` if you want to keep attachment metadata and filenames but disable attachment body fetches for text indexing. +Git snapshots include cached non-DM media files by default. Use `publish --no-media` to omit them. DM media under `@me` stays local-only. + ## What is sent over the wire With remote embedding providers, message text is sent during `discrawl embed`, and search query text is sent when using `--mode semantic` or `--mode hybrid`. Stored message text is not sent during local vector scoring. diff --git a/internal/cli/admin_commands.go b/internal/cli/admin_commands.go index 8b31975..1d72c8a 100644 --- a/internal/cli/admin_commands.go +++ b/internal/cli/admin_commands.go @@ -17,6 +17,7 @@ import ( "github.com/openclaw/discrawl/internal/config" "github.com/openclaw/discrawl/internal/discord" "github.com/openclaw/discrawl/internal/discorddesktop" + "github.com/openclaw/discrawl/internal/media" "github.com/openclaw/discrawl/internal/share" "github.com/openclaw/discrawl/internal/store" "github.com/openclaw/discrawl/internal/syncer" @@ -32,6 +33,7 @@ type syncRunStats struct { Source string `json:"source"` Discord *syncer.SyncStats `json:"discord,omitempty"` Wiretap *discorddesktop.Stats `json:"wiretap,omitempty"` + Media *media.FetchStats `json:"media,omitempty"` } func (r *runtime) runInit(args []string) error { @@ -110,6 +112,7 @@ func (r *runtime) runSync(args []string) error { concurrency := fs.Int("concurrency", r.cfg.Sync.Concurrency, "") source := fs.String("source", r.cfg.Sync.Source, "") withEmbeddings := fs.Bool("with-embeddings", false, "") + withMedia := fs.Bool("with-media", r.cfg.AttachmentMediaEnabled(), "") skipMembers := fs.Bool("skip-members", false, "") latestOnly := fs.Bool("latest-only", false, "") guildsFlag := fs.String("guilds", "", "") @@ -155,11 +158,11 @@ func (r *runtime) runSync(args []string) error { LatestOnly: syncLatestOnly(*latestOnly, defaultLatest), } return r.withSyncLock(func() error { - return r.runSyncLocked(sources, opts) + return r.runSyncLocked(sources, opts, *withMedia) }) } -func (r *runtime) runSyncLocked(sources syncSources, opts syncer.SyncOptions) error { +func (r *runtime) runSyncLocked(sources syncSources, opts syncer.SyncOptions, withMedia bool) error { var apiStats *syncer.SyncStats if sources.discord { r.setSyncLockPhase("discord sync") @@ -190,13 +193,81 @@ func (r *runtime) runSyncLocked(sources syncSources, opts syncer.SyncOptions) er } wiretapStats = &stats } - if sources.discord && !sources.wiretap { + var mediaStats *media.FetchStats + if withMedia { + r.setSyncLockPhase("attachment media fetch") + cacheDir, err := config.ExpandPath(r.cfg.CacheDir) + if err != nil { + return configErr(err) + } + channelIDs := opts.ChannelIDs + if len(channelIDs) > 0 { + channelIDs, err = r.store.ExpandAttachmentChannelIDs(r.ctx, channelIDs) + if err != nil { + return err + } + } + stats, err := r.fetchSyncMedia(sources, opts, cacheDir, channelIDs) + if err != nil { + return err + } + mediaStats = stats + } + if sources.discord && !sources.wiretap && mediaStats == nil { return r.print(*apiStats) } - if sources.wiretap && !sources.discord { + if sources.wiretap && !sources.discord && mediaStats == nil { return r.print(*wiretapStats) } - return r.print(syncRunStats{Source: sources.name, Discord: apiStats, Wiretap: wiretapStats}) + return r.print(syncRunStats{Source: sources.name, Discord: apiStats, Wiretap: wiretapStats, Media: mediaStats}) +} + +func (r *runtime) fetchSyncMedia(sources syncSources, opts syncer.SyncOptions, cacheDir string, channelIDs []string) (*media.FetchStats, error) { + total := media.FetchStats{} + if sources.discord { + stats, err := media.Fetch(r.ctx, r.store, media.FetchOptions{ + CacheDir: cacheDir, + MaxBytes: r.cfg.Sync.MaxAttachmentBytes, + List: store.AttachmentListOptions{ + GuildIDs: opts.GuildIDs, + ExcludeGuildIDs: []string{store.DirectMessageGuildID}, + ChannelIDs: channelIDs, + Since: opts.Since, + }, + StatusUpdate: true, + Now: r.now, + }) + if err != nil { + return nil, err + } + total = addFetchStats(total, stats) + } + if sources.wiretap { + stats, err := media.Fetch(r.ctx, r.store, media.FetchOptions{ + CacheDir: cacheDir, + MaxBytes: r.cfg.Sync.MaxAttachmentBytes, + List: store.AttachmentListOptions{ + GuildIDs: []string{store.DirectMessageGuildID}, + }, + StatusUpdate: true, + Now: r.now, + }) + if err != nil { + return nil, err + } + total = addFetchStats(total, stats) + } + return &total, nil +} + +func addFetchStats(a, b media.FetchStats) media.FetchStats { + a.Attachments += b.Attachments + a.Fetched += b.Fetched + a.Reused += b.Reused + a.Skipped += b.Skipped + a.Failed += b.Failed + a.Bytes += b.Bytes + return a } func defaultLatestSyncMode(full bool, allChannels bool, since string, channels string) bool { diff --git a/internal/cli/attachments.go b/internal/cli/attachments.go new file mode 100644 index 0000000..946e15c --- /dev/null +++ b/internal/cli/attachments.go @@ -0,0 +1,226 @@ +package cli + +import ( + "errors" + "flag" + "fmt" + "io" + "os" + "strings" + "time" + + "github.com/openclaw/discrawl/internal/config" + "github.com/openclaw/discrawl/internal/media" + "github.com/openclaw/discrawl/internal/store" +) + +func (r *runtime) runAttachments(args []string) error { + if len(args) > 0 && args[0] == "fetch" { + return r.runAttachmentsFetch(args[1:]) + } + opts, limit, err := r.parseAttachmentListFlags("attachments", args, defaultMessageLimit) + if err != nil { + return err + } + missingOnly := opts.MissingOnly + if missingOnly { + opts.MissingOnly = false + opts.Limit = 0 + } else { + opts.Limit = limit + } + rows, err := r.store.ListAttachments(r.ctx, opts) + if err != nil { + return err + } + if missingOnly { + cacheDir, err := config.ExpandPath(r.cfg.CacheDir) + if err != nil { + return configErr(err) + } + rows = filterMissingAttachmentMedia(cacheDir, rows) + if limit > 0 && len(rows) > limit { + rows = rows[:limit] + } + } + return r.print(rows) +} + +func (r *runtime) runAttachmentsFetch(args []string) error { + listArgs := stripFlags(args, map[string]struct{}{"force": {}, "max-bytes": {}}) + opts, limit, err := r.parseAttachmentListFlags("attachments fetch", listArgs, defaultMessageLimit) + if err != nil { + return err + } + fs := flag.NewFlagSet("attachments fetch", flag.ContinueOnError) + fs.SetOutput(io.Discard) + force := fs.Bool("force", false, "") + maxBytes := fs.Int64("max-bytes", r.cfg.Sync.MaxAttachmentBytes, "") + if err := parseKnown(fs, args, attachmentListFlagNames()); err != nil { + return usageErr(err) + } + if *maxBytes <= 0 { + return usageErr(errors.New("--max-bytes must be positive")) + } + cacheDir, err := config.ExpandPath(r.cfg.CacheDir) + if err != nil { + return configErr(err) + } + opts.Limit = limit + stats, err := media.Fetch(r.ctx, r.store, media.FetchOptions{ + CacheDir: cacheDir, + List: opts, + MaxBytes: *maxBytes, + Force: *force, + StatusUpdate: true, + Now: r.now, + }) + if err != nil { + return err + } + return r.print(stats) +} + +func filterMissingAttachmentMedia(cacheDir string, rows []store.AttachmentRow) []store.AttachmentRow { + out := rows[:0] + for _, row := range rows { + if row.MediaPath == "" { + out = append(out, row) + continue + } + path, err := media.LocalPath(cacheDir, row.MediaPath) + if err != nil { + out = append(out, row) + continue + } + if _, err := os.Stat(path); err != nil { + out = append(out, row) + } + } + return out +} + +func (r *runtime) parseAttachmentListFlags(name string, args []string, defaultLimit int) (store.AttachmentListOptions, int, error) { + fs := flag.NewFlagSet(name, flag.ContinueOnError) + fs.SetOutput(io.Discard) + channel := fs.String("channel", "", "") + author := fs.String("author", "", "") + messageID := fs.String("message", "", "") + filename := fs.String("filename", "", "") + contentType := fs.String("type", "", "") + hours := fs.Int("hours", 0, "") + days := fs.Int("days", 0, "") + since := fs.String("since", "", "") + before := fs.String("before", "", "") + limit := fs.Int("limit", defaultLimit, "") + all := fs.Bool("all", false, "") + missing := fs.Bool("missing", false, "") + dm := fs.Bool("dm", false, "") + guildsFlag := fs.String("guilds", "", "") + guildFlag := fs.String("guild", "", "") + if err := fs.Parse(args); err != nil { + return store.AttachmentListOptions{}, 0, usageErr(err) + } + if fs.NArg() != 0 { + return store.AttachmentListOptions{}, 0, usageErr(errors.New(name + " takes flags only")) + } + if *hours < 0 || *days < 0 || *limit < 0 { + return store.AttachmentListOptions{}, 0, usageErr(errors.New("--hours, --days, and --limit must be >= 0")) + } + if countNonZero(*hours > 0, *days > 0, strings.TrimSpace(*since) != "") > 1 { + return store.AttachmentListOptions{}, 0, usageErr(errors.New("use only one of --hours, --days, or --since")) + } + var sinceTime time.Time + var beforeTime time.Time + var err error + if *hours > 0 { + sinceTime = r.nowUTC().Add(-time.Duration(*hours) * time.Hour) + } + if *days > 0 { + sinceTime = r.nowUTC().Add(-time.Duration(*days) * 24 * time.Hour) + } + if strings.TrimSpace(*since) != "" { + sinceTime, err = time.Parse(time.RFC3339, *since) + if err != nil { + return store.AttachmentListOptions{}, 0, usageErr(fmt.Errorf("invalid --since: %w", err)) + } + } + if strings.TrimSpace(*before) != "" { + beforeTime, err = time.Parse(time.RFC3339, *before) + if err != nil { + return store.AttachmentListOptions{}, 0, usageErr(fmt.Errorf("invalid --before: %w", err)) + } + } + guildIDs, err := directMessageGuildScope(*dm, *guildFlag, *guildsFlag) + if err != nil { + return store.AttachmentListOptions{}, 0, usageErr(err) + } + if *all { + *limit = 0 + } + return store.AttachmentListOptions{ + GuildIDs: guildIDs, + Channel: *channel, + Author: *author, + MessageID: *messageID, + Filename: *filename, + ContentType: *contentType, + Since: sinceTime, + Before: beforeTime, + MissingOnly: *missing, + }, *limit, nil +} + +func attachmentListFlagNames() map[string]struct{} { + return map[string]struct{}{ + "channel": {}, "author": {}, "message": {}, "filename": {}, "type": {}, + "hours": {}, "days": {}, "since": {}, "before": {}, "limit": {}, + "all": {}, "missing": {}, "dm": {}, "guilds": {}, "guild": {}, + } +} + +func parseKnown(fs *flag.FlagSet, args []string, known map[string]struct{}) error { + filtered := make([]string, 0, len(args)) + for i := 0; i < len(args); i++ { + arg := args[i] + if !strings.HasPrefix(arg, "-") { + filtered = append(filtered, arg) + continue + } + name := strings.TrimLeft(arg, "-") + if key, _, ok := strings.Cut(name, "="); ok { + name = key + } + if _, skip := known[name]; skip { + if !strings.Contains(arg, "=") && i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") { + i++ + } + continue + } + filtered = append(filtered, arg) + } + return fs.Parse(filtered) +} + +func stripFlags(args []string, strip map[string]struct{}) []string { + out := make([]string, 0, len(args)) + for i := 0; i < len(args); i++ { + arg := args[i] + if !strings.HasPrefix(arg, "-") { + out = append(out, arg) + continue + } + name := strings.TrimLeft(arg, "-") + if key, _, ok := strings.Cut(name, "="); ok { + name = key + } + if _, ok := strip[name]; ok { + if !strings.Contains(arg, "=") && name == "max-bytes" && i+1 < len(args) { + i++ + } + continue + } + out = append(out, arg) + } + return out +} diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 8ccd3a5..a7c20f2 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -238,6 +238,15 @@ func (r *runtime) dispatch(rest []string) error { return r.withLocalStoreRead(false, func() error { return r.runDirectMessages(rest[1:]) }) case "mentions": return r.withLocalStoreRead(true, func() error { return r.runMentions(rest[1:]) }) + case "attachments": + if hasHelpArg(rest[1:]) { + return printCommandUsage(r.stdout, []string{"attachments"}) + } + autoShareUpdate := !hasBoolFlag(rest[1:], "--dm") + if len(rest) > 1 && rest[1] == "fetch" { + return r.withLocalStoreLocked(autoShareUpdate, func() error { return r.runAttachments(rest[1:]) }) + } + return r.withLocalStoreRead(autoShareUpdate, func() error { return r.runAttachments(rest[1:]) }) case "embed": return r.withLocalStoreLocked(true, func() error { return r.runEmbed(rest[1:]) }) case "sql": @@ -413,6 +422,29 @@ func (r *runtime) withLocalStoreReadOnly(fn func() error) error { r.store = nil return fn() } + if errors.Is(openErr, store.ErrSchemaVersionMismatch) { + if err := r.withSyncLock(func() error { + storeFactory := r.openStore + if storeFactory == nil { + storeFactory = store.Open + } + var migrateErr error + r.store, migrateErr = storeFactory(r.ctx, dbPath) + if migrateErr != nil { + return dbErr(migrateErr) + } + closeErr := r.store.Close() + r.store = nil + return closeErr + }); err != nil { + return err + } + r.store, openErr = store.OpenReadOnly(r.ctx, dbPath) + if openErr == nil { + defer func() { _ = r.store.Close() }() + return fn() + } + } return dbErr(openErr) } defer func() { _ = r.store.Close() }() @@ -581,10 +613,16 @@ func (r *runtime) shareOptions() (share.Options, error) { if err != nil { return share.Options{}, configErr(err) } + cacheDir, err := config.ExpandPath(r.cfg.CacheDir) + if err != nil { + return share.Options{}, configErr(err) + } return share.Options{ - RepoPath: repoPath, - Remote: r.cfg.Share.Remote, - Branch: r.cfg.Share.Branch, - Progress: r.shareProgress, + RepoPath: repoPath, + CacheDir: cacheDir, + Remote: r.cfg.Share.Remote, + Branch: r.cfg.Share.Branch, + IncludeMedia: r.cfg.ShareMediaEnabled(), + Progress: r.shareProgress, }, nil } diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 628b6b4..745ca6f 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "flag" "io" "log/slog" "net/http" @@ -23,6 +24,7 @@ import ( "github.com/openclaw/discrawl/internal/config" discordclient "github.com/openclaw/discrawl/internal/discord" "github.com/openclaw/discrawl/internal/discorddesktop" + "github.com/openclaw/discrawl/internal/media" "github.com/openclaw/discrawl/internal/report" "github.com/openclaw/discrawl/internal/share" "github.com/openclaw/discrawl/internal/store" @@ -629,6 +631,132 @@ func TestParseSyncSources(t *testing.T) { require.Error(t, err) } +func TestFetchSyncMediaScopesWiretapToDMs(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + s := seedCLIStore(t, filepath.Join(dir, "discrawl.db")) + defer func() { _ = s.Close() }() + require.NoError(t, addCLIAttachment(ctx, s, "")) + require.NoError(t, s.UpsertGuild(ctx, store.GuildRecord{ID: store.DirectMessageGuildID, Name: "Discord Direct Messages", RawJSON: `{}`})) + require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ID: "dm-c1", GuildID: store.DirectMessageGuildID, Kind: "dm", Name: "Alice", RawJSON: `{}`})) + now := time.Now().UTC().Format(time.RFC3339Nano) + require.NoError(t, s.UpsertMessages(ctx, []store.MessageMutation{{ + Record: store.MessageRecord{ + ID: "dm1", + GuildID: store.DirectMessageGuildID, + ChannelID: "dm-c1", + ChannelName: "Alice", + AuthorID: "u2", + AuthorName: "Alice", + MessageType: 0, + CreatedAt: now, + Content: "dm attachment", + NormalizedContent: "dm attachment private.png", + HasAttachments: true, + RawJSON: `{}`, + }, + Attachments: []store.AttachmentRecord{{ + AttachmentID: "a-dm", + MessageID: "dm1", + GuildID: store.DirectMessageGuildID, + ChannelID: "dm-c1", + AuthorID: "u2", + Filename: "private.png", + ContentType: "image/png", + }}, + }})) + cfg := config.Default() + rt := &runtime{ctx: ctx, cfg: cfg, store: s, now: time.Now} + + stats, err := rt.fetchSyncMedia(syncSources{name: "wiretap", wiretap: true}, syncer.SyncOptions{GuildIDs: []string{"g1"}}, filepath.Join(dir, "cache"), nil) + require.NoError(t, err) + require.Equal(t, &media.FetchStats{Attachments: 1, Skipped: 1}, stats) + + guildRows, err := s.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m100"}) + require.NoError(t, err) + require.Len(t, guildRows, 1) + require.Empty(t, guildRows[0].FetchStatus) + dmRows, err := s.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "dm1"}) + require.NoError(t, err) + require.Len(t, dmRows, 1) + require.Equal(t, "no_url", dmRows[0].FetchStatus) +} + +func TestFetchSyncMediaHonorsSince(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + s := seedCLIStore(t, filepath.Join(dir, "discrawl.db")) + defer func() { _ = s.Close() }() + cutoff := time.Date(2026, 5, 15, 12, 0, 0, 0, time.UTC) + require.NoError(t, s.UpsertMessages(ctx, []store.MessageMutation{ + { + Record: store.MessageRecord{ + ID: "m-old", + GuildID: "g1", + ChannelID: "c1", + ChannelName: "general", + AuthorID: "u1", + AuthorName: "Peter", + MessageType: 0, + CreatedAt: cutoff.Add(-time.Hour).Format(time.RFC3339Nano), + Content: "old attachment", + NormalizedContent: "old attachment old.png", + HasAttachments: true, + RawJSON: `{}`, + }, + Attachments: []store.AttachmentRecord{{ + AttachmentID: "a-old", + MessageID: "m-old", + GuildID: "g1", + ChannelID: "c1", + AuthorID: "u1", + Filename: "old.png", + ContentType: "image/png", + }}, + }, + { + Record: store.MessageRecord{ + ID: "m-new", + GuildID: "g1", + ChannelID: "c1", + ChannelName: "general", + AuthorID: "u1", + AuthorName: "Peter", + MessageType: 0, + CreatedAt: cutoff.Add(time.Hour).Format(time.RFC3339Nano), + Content: "new attachment", + NormalizedContent: "new attachment new.png", + HasAttachments: true, + RawJSON: `{}`, + }, + Attachments: []store.AttachmentRecord{{ + AttachmentID: "a-new", + MessageID: "m-new", + GuildID: "g1", + ChannelID: "c1", + AuthorID: "u1", + Filename: "new.png", + ContentType: "image/png", + }}, + }, + })) + cfg := config.Default() + rt := &runtime{ctx: ctx, cfg: cfg, store: s, now: time.Now} + + stats, err := rt.fetchSyncMedia(syncSources{name: "discord", discord: true}, syncer.SyncOptions{GuildIDs: []string{"g1"}, Since: cutoff}, filepath.Join(dir, "cache"), nil) + require.NoError(t, err) + require.Equal(t, &media.FetchStats{Attachments: 1, Skipped: 1}, stats) + + oldRows, err := s.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m-old"}) + require.NoError(t, err) + require.Len(t, oldRows, 1) + require.Empty(t, oldRows[0].FetchStatus) + newRows, err := s.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m-new"}) + require.NoError(t, err) + require.Len(t, newRows, 1) + require.Equal(t, "no_url", newRows[0].FetchStatus) +} + func TestReadCommandsAutoImportStaleShare(t *testing.T) { ctx := context.Background() dir := t.TempDir() @@ -671,6 +799,172 @@ func TestReadCommandsAutoImportStaleShare(t *testing.T) { require.NotEmpty(t, lastImport) } +func TestAttachmentsCommandListsAndFetchesMedia(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + body := []byte("png-ish") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/file.png" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "image/png") + _, _ = w.Write(body) + })) + defer server.Close() + + dbPath := filepath.Join(dir, "discrawl.db") + s := seedCLIStore(t, dbPath) + require.NoError(t, addCLIAttachment(ctx, s, server.URL+"/file.png")) + require.NoError(t, s.Close()) + + cfgPath := filepath.Join(dir, "config.toml") + cfg := config.Default() + cfg.DBPath = dbPath + cfg.CacheDir = filepath.Join(dir, "cache") + cfg.Share.Remote = "" + require.NoError(t, config.Write(cfgPath, cfg)) + + var out bytes.Buffer + require.NoError(t, Run(ctx, []string{"--config", cfgPath, "attachments", "--all"}, &out, &bytes.Buffer{})) + require.Contains(t, out.String(), "file.png") + require.Contains(t, out.String(), "image/png") + out.Reset() + require.NoError(t, Run(ctx, []string{"--config", cfgPath, "attachments", "--author", "Peter", "--all"}, &out, &bytes.Buffer{})) + require.Contains(t, out.String(), "file.png") + + out.Reset() + require.NoError(t, Run(ctx, []string{"--config", cfgPath, "attachments", "fetch", "--all"}, &out, &bytes.Buffer{})) + require.Contains(t, out.String(), "fetched=1") + + check, err := store.Open(ctx, dbPath) + require.NoError(t, err) + defer func() { _ = check.Close() }() + rows, err := check.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m100"}) + require.NoError(t, err) + require.Len(t, rows, 1) + path, err := media.LocalPath(cfg.CacheDir, rows[0].MediaPath) + require.NoError(t, err) + got, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, body, got) + + require.NoError(t, os.Remove(path)) + out.Reset() + require.NoError(t, Run(ctx, []string{"--config", cfgPath, "attachments", "--missing", "--all"}, &out, &bytes.Buffer{})) + require.Contains(t, out.String(), "file.png") + out.Reset() + require.NoError(t, Run(ctx, []string{"--config", cfgPath, "attachments", "fetch", "--missing", "--all"}, &out, &bytes.Buffer{})) + require.Contains(t, out.String(), "fetched=1") + got, err = os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, body, got) +} + +func TestAttachmentsDMCommandsSkipShareAutoUpdate(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + dbPath := filepath.Join(dir, "discrawl.db") + s := seedCLIStore(t, dbPath) + require.NoError(t, addCLIDMAttachment(ctx, s)) + require.NoError(t, s.Close()) + + cfgPath := filepath.Join(dir, "config.toml") + cfg := config.Default() + cfg.DBPath = dbPath + cfg.CacheDir = filepath.Join(dir, "cache") + cfg.Share.Remote = filepath.Join(dir, "missing-remote.git") + cfg.Share.RepoPath = filepath.Join(dir, "share") + cfg.Share.AutoUpdate = true + require.NoError(t, config.Write(cfgPath, cfg)) + + var out bytes.Buffer + require.NoError(t, Run(ctx, []string{"--config", cfgPath, "attachments", "--dm", "--all"}, &out, &bytes.Buffer{})) + require.Contains(t, out.String(), "private.png") + out.Reset() + require.NoError(t, Run(ctx, []string{"--config", cfgPath, "attachments", "fetch", "--dm", "--all"}, &out, &bytes.Buffer{})) + require.Contains(t, out.String(), "skipped=1") +} + +func TestAttachmentFlagParsing(t *testing.T) { + t.Parallel() + + rt := &runtime{ + now: func() time.Time { return time.Date(2026, 5, 15, 12, 0, 0, 0, time.UTC) }, + } + opts, limit, err := rt.parseAttachmentListFlags("attachments", []string{ + "--channel", "general", + "--author=Peter", + "--message", "m1", + "--filename", "file", + "--type", "image", + "--hours", "2", + "--before", "2026-05-15T13:00:00Z", + "--missing", + "--guilds", "g1,g2", + "--all", + }, 20) + require.NoError(t, err) + require.Zero(t, limit) + require.Equal(t, []string{"g1", "g2"}, opts.GuildIDs) + require.Equal(t, "general", opts.Channel) + require.Equal(t, "Peter", opts.Author) + require.Equal(t, "m1", opts.MessageID) + require.Equal(t, "file", opts.Filename) + require.Equal(t, "image", opts.ContentType) + require.True(t, opts.MissingOnly) + require.Equal(t, time.Date(2026, 5, 15, 10, 0, 0, 0, time.UTC), opts.Since) + + _, _, err = rt.parseAttachmentListFlags("attachments", []string{"--hours", "1", "--since", "2026-05-15T10:00:00Z"}, 20) + require.Error(t, err) + _, _, err = rt.parseAttachmentListFlags("attachments", []string{"positional"}, 20) + require.Error(t, err) + _, _, err = rt.parseAttachmentListFlags("attachments", []string{"--limit", "-1"}, 20) + require.Error(t, err) + _, _, err = rt.parseAttachmentListFlags("attachments", []string{"--since", "bad"}, 20) + require.Error(t, err) + _, _, err = rt.parseAttachmentListFlags("attachments", []string{"--before", "bad"}, 20) + require.Error(t, err) + _, _, err = rt.parseAttachmentListFlags("attachments", []string{"--dm", "--guild", "g1"}, 20) + require.Error(t, err) + + require.Equal(t, []string{"--channel", "general", "tail"}, stripFlags([]string{"--force", "--max-bytes", "10", "--channel", "general", "tail"}, map[string]struct{}{"force": {}, "max-bytes": {}})) + fs := flag.NewFlagSet("attachments fetch", flag.ContinueOnError) + fs.SetOutput(io.Discard) + force := fs.Bool("force", false, "") + maxBytes := fs.Int64("max-bytes", 0, "") + require.NoError(t, parseKnown(fs, []string{"--force", "--max-bytes=10", "--channel", "general"}, attachmentListFlagNames())) + require.True(t, *force) + require.Equal(t, int64(10), *maxBytes) +} + +func TestFilterMissingAttachmentMediaKeepsInvalidAndMissingRows(t *testing.T) { + t.Parallel() + + cacheDir := t.TempDir() + existingPath := "attachments/aa/file.png" + fullPath, err := media.LocalPath(cacheDir, existingPath) + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(fullPath), 0o755)) + require.NoError(t, os.WriteFile(fullPath, []byte("cached"), 0o600)) + + rows := filterMissingAttachmentMedia(cacheDir, []store.AttachmentRow{ + {AttachmentID: "empty"}, + {AttachmentID: "invalid", MediaPath: "../bad"}, + {AttachmentID: "missing", MediaPath: "attachments/bb/missing.png"}, + {AttachmentID: "existing", MediaPath: existingPath}, + }) + require.Equal(t, []string{"empty", "invalid", "missing"}, attachmentIDs(rows)) +} + +func attachmentIDs(rows []store.AttachmentRow) []string { + out := make([]string, 0, len(rows)) + for _, row := range rows { + out = append(out, row.AttachmentID) + } + return out +} + func TestReadCommandsCanDisableAutoImportWithEnv(t *testing.T) { ctx := context.Background() dir := t.TempDir() @@ -714,6 +1008,19 @@ func TestReadCommandsCanDisableAutoImportWithEnv(t *testing.T) { require.Empty(t, lastImport) } +func TestSubscribeNoMediaPersistsShareMediaOptOut(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.toml") + + var out bytes.Buffer + require.NoError(t, Run(ctx, []string{"--config", cfgPath, "subscribe", "--no-import", "--no-media", "https://github.com/example/archive.git"}, &out, &bytes.Buffer{})) + + cfg, err := config.Load(cfgPath) + require.NoError(t, err) + require.False(t, cfg.ShareMediaEnabled()) +} + func TestShareCommandsPublishSubscribeAndUpdate(t *testing.T) { ctx := context.Background() dir := t.TempDir() @@ -1168,7 +1475,32 @@ func TestReadCommandsMigrateOlderLocalStore(t *testing.T) { defer func() { _ = reader.Close() }() var version int require.NoError(t, reader.DB().QueryRowContext(ctx, `pragma user_version`).Scan(&version)) - require.Equal(t, 2, version) + require.Equal(t, 3, version) +} + +func TestReadOnlyCommandsMigrateOlderLocalStore(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + cfg := config.Default() + cfg.DBPath = filepath.Join(dir, "discrawl.db") + cfgPath := filepath.Join(dir, "config.toml") + require.NoError(t, config.Write(cfgPath, cfg)) + + s := seedCLIStore(t, cfg.DBPath) + _, err := s.DB().ExecContext(ctx, `pragma user_version = 1`) + require.NoError(t, err) + require.NoError(t, s.Close()) + + var out bytes.Buffer + require.NoError(t, Run(ctx, []string{"--config", cfgPath, "status"}, &out, &bytes.Buffer{})) + require.Contains(t, out.String(), "messages=1") + + reader, err := store.OpenReadOnly(ctx, cfg.DBPath) + require.NoError(t, err) + defer func() { _ = reader.Close() }() + var version int + require.NoError(t, reader.DB().QueryRowContext(ctx, `pragma user_version`).Scan(&version)) + require.Equal(t, 3, version) } func seedCLIStore(t *testing.T, path string) *store.Store { @@ -1195,6 +1527,72 @@ func seedCLIStore(t *testing.T, path string) *store.Store { return s } +func addCLIAttachment(ctx context.Context, s *store.Store, url string) error { + now := time.Now().UTC().Format(time.RFC3339Nano) + return s.UpsertMessages(ctx, []store.MessageMutation{{ + Record: store.MessageRecord{ + ID: "m100", + GuildID: "g1", + ChannelID: "c1", + ChannelName: "general", + AuthorID: "u1", + AuthorName: "Peter", + MessageType: 0, + CreatedAt: now, + Content: "automatic updates work", + NormalizedContent: "automatic updates work file.png", + HasAttachments: true, + RawJSON: `{"author":{"username":"Peter","global_name":"Peter"}}`, + }, + Attachments: []store.AttachmentRecord{{ + AttachmentID: "a-cli", + MessageID: "m100", + GuildID: "g1", + ChannelID: "c1", + AuthorID: "u1", + Filename: "file.png", + ContentType: "image/png", + Size: 7, + URL: url, + }}, + }}) +} + +func addCLIDMAttachment(ctx context.Context, s *store.Store) error { + now := time.Now().UTC().Format(time.RFC3339Nano) + if err := s.UpsertGuild(ctx, store.GuildRecord{ID: store.DirectMessageGuildID, Name: "Discord Direct Messages", RawJSON: `{}`}); err != nil { + return err + } + if err := s.UpsertChannel(ctx, store.ChannelRecord{ID: "dm-c1", GuildID: store.DirectMessageGuildID, Kind: "dm", Name: "Alice", RawJSON: `{}`}); err != nil { + return err + } + return s.UpsertMessages(ctx, []store.MessageMutation{{ + Record: store.MessageRecord{ + ID: "dm1", + GuildID: store.DirectMessageGuildID, + ChannelID: "dm-c1", + ChannelName: "Alice", + AuthorID: "u2", + AuthorName: "Alice", + MessageType: 0, + CreatedAt: now, + Content: "dm attachment", + NormalizedContent: "dm attachment private.png", + HasAttachments: true, + RawJSON: `{}`, + }, + Attachments: []store.AttachmentRecord{{ + AttachmentID: "a-dm", + MessageID: "dm1", + GuildID: store.DirectMessageGuildID, + ChannelID: "dm-c1", + AuthorID: "u2", + Filename: "private.png", + ContentType: "image/png", + }}, + }}) +} + func publishSnapshot(t *testing.T, ctx context.Context, s *store.Store, opts share.Options, message string) { t.Helper() _, err := share.Export(ctx, s, opts) diff --git a/internal/cli/output.go b/internal/cli/output.go index fc81115..2ccd9b4 100644 --- a/internal/cli/output.go +++ b/internal/cli/output.go @@ -12,6 +12,7 @@ import ( "time" "github.com/openclaw/discrawl/internal/discorddesktop" + "github.com/openclaw/discrawl/internal/media" "github.com/openclaw/discrawl/internal/report" "github.com/openclaw/discrawl/internal/store" "github.com/openclaw/discrawl/internal/syncer" @@ -115,6 +116,7 @@ Commands: analytics dms mentions + attachments embed sql members @@ -150,6 +152,31 @@ Flags: --dm Search local desktop DM cache. --guild ID Restrict to one guild id. --guilds ID,ID Restrict to guild ids. +`, + "attachments": `Usage: + discrawl attachments [flags] + discrawl attachments fetch [flags] + +Flags: + --channel ID_OR_NAME Filter by channel id or name. + --author ID_OR_NAME Filter by author id or name. + --message ID Filter by message id. + --filename TEXT Filter by filename. + --type TEXT Filter by content type. + --hours N Attachments from the last N hours. + --days N Attachments from the last N days. + --since RFC3339 Attachments at or after timestamp. + --before RFC3339 Attachments before timestamp. + --limit N Maximum rows. Default: 200. + --all Return all matching rows. + --missing Only attachments without cached media. + --dm Read local desktop DM cache. + --guild ID Restrict to one guild id. + --guilds ID,ID Restrict to guild ids. + +Fetch flags: + --force Re-download already cached attachments. + --max-bytes N Per-attachment download cap. `, "messages": `Usage: discrawl messages [flags] @@ -209,6 +236,12 @@ func printHuman(w io.Writer, value any) error { return err } } + if v.Media != nil { + if _, err := fmt.Fprintf(w, "media_attachments=%d\nmedia_fetched=%d\nmedia_reused=%d\nmedia_skipped=%d\nmedia_failed=%d\nmedia_bytes=%d\n", + v.Media.Attachments, v.Media.Fetched, v.Media.Reused, v.Media.Skipped, v.Media.Failed, v.Media.Bytes); err != nil { + return err + } + } return nil case syncer.SyncStats: _, err := fmt.Fprintf(w, "guilds=%d channels=%d threads=%d members=%d messages=%d\n", v.Guilds, v.Channels, v.Threads, v.Members, v.Messages) @@ -251,6 +284,19 @@ func printHuman(w io.Writer, value any) error { } } return nil + case []store.AttachmentRow: + for _, row := range v { + if _, err := fmt.Fprintf(w, "[%s/%s] %s %s\n%s type=%s size=%d status=%s media=%s url=%s\n\n", + row.GuildID, row.ChannelName, row.AuthorName, formatTime(row.CreatedAt), row.Filename, + row.ContentType, row.Size, firstNonEmpty(row.FetchStatus, "metadata"), row.MediaPath, firstNonEmpty(row.URL, row.ProxyURL)); err != nil { + return err + } + } + return nil + case media.FetchStats: + _, err := fmt.Fprintf(w, "attachments=%d\nfetched=%d\nreused=%d\nskipped=%d\nfailed=%d\nbytes=%d\n", + v.Attachments, v.Fetched, v.Reused, v.Skipped, v.Failed, v.Bytes) + return err case []store.DirectMessageConversationRow: tw := tabwriter.NewWriter(w, 2, 4, 2, ' ', 0) _, _ = fmt.Fprintln(tw, "CHANNEL\tNAME\tMESSAGES\tAUTHORS\tFIRST\tLAST") diff --git a/internal/cli/share_commands.go b/internal/cli/share_commands.go index dc2dd89..250f4f2 100644 --- a/internal/cli/share_commands.go +++ b/internal/cli/share_commands.go @@ -25,6 +25,7 @@ func (r *runtime) runPublish(args []string) error { noCommit := fs.Bool("no-commit", false, "") push := fs.Bool("push", false, "") withEmbeddings := fs.Bool("with-embeddings", false, "") + noMedia := fs.Bool("no-media", !r.cfg.ShareMediaEnabled(), "") publicOnly := fs.Bool("public-only", r.cfg.Share.Filter.PublicOnly, "") includeChannels := fs.String("include-channels", strings.Join(r.cfg.Share.Filter.IncludeChannelIDs, ","), "") excludeChannels := fs.String("exclude-channels", strings.Join(r.cfg.Share.Filter.ExcludeChannelIDs, ","), "") @@ -49,6 +50,9 @@ func (r *runtime) runPublish(args []string) error { if *withEmbeddings { applyEmbeddingShareOptions(&opts, r.cfg) } + if err := applyMediaShareOptions(&opts, r.cfg, !*noMedia); err != nil { + return err + } manifest, err := share.Export(r.ctx, r.store, opts) if err != nil { return err @@ -95,6 +99,7 @@ func (r *runtime) runPublish(args []string) error { "remote": opts.Remote, "generated_at": manifest.GeneratedAt, "tables": manifest.Tables, + "media": manifest.Media, "embeddings": manifest.Embeddings, "readme": *readmePath, "committed": committed, @@ -127,6 +132,7 @@ func (r *runtime) runSubscribe(args []string) error { noAutoUpdate := fs.Bool("no-auto-update", false, "") noImport := fs.Bool("no-import", false, "") withEmbeddings := fs.Bool("with-embeddings", false, "") + noMedia := fs.Bool("no-media", false, "") if err := fs.Parse(args); err != nil { return usageErr(err) } @@ -145,6 +151,9 @@ func (r *runtime) runSubscribe(args []string) error { cfg.Share.Branch = *branch cfg.Share.AutoUpdate = !*noAutoUpdate cfg.Share.StaleAfter = *staleAfter + if *noMedia { + cfg.Share.Media = new(false) + } cfg.Discord.TokenSource = "none" if err := config.Write(r.configPath, cfg); err != nil { return configErr(err) @@ -171,6 +180,9 @@ func (r *runtime) runSubscribe(args []string) error { return configErr(err) } opts := share.Options{RepoPath: expandedRepo, Remote: cfg.Share.Remote, Branch: cfg.Share.Branch, Progress: r.shareProgress} + if err := applyMediaShareOptions(&opts, cfg, cfg.ShareMediaEnabled() && !*noMedia); err != nil { + return err + } if *withEmbeddings { applyEmbeddingShareOptions(&opts, cfg) } @@ -189,6 +201,7 @@ func (r *runtime) runSubscribe(args []string) error { "remote": opts.Remote, "generated_at": manifest.GeneratedAt, "tables": manifest.Tables, + "media": manifest.Media, "embeddings": manifest.Embeddings, "imported": imported, }) @@ -202,6 +215,7 @@ func (r *runtime) runUpdate(args []string) error { remote := fs.String("remote", r.cfg.Share.Remote, "") branch := fs.String("branch", r.cfg.Share.Branch, "") withEmbeddings := fs.Bool("with-embeddings", false, "") + noMedia := fs.Bool("no-media", !r.cfg.ShareMediaEnabled(), "") if err := fs.Parse(args); err != nil { return usageErr(err) } @@ -216,6 +230,9 @@ func (r *runtime) runUpdate(args []string) error { if *withEmbeddings { applyEmbeddingShareOptions(&opts, r.cfg) } + if err := applyMediaShareOptions(&opts, r.cfg, !*noMedia); err != nil { + return err + } r.setSyncLockPhase("share pull") if err := share.Pull(r.ctx, opts); err != nil { return err @@ -230,6 +247,7 @@ func (r *runtime) runUpdate(args []string) error { "remote": opts.Remote, "generated_at": manifest.GeneratedAt, "tables": manifest.Tables, + "media": manifest.Media, "embeddings": manifest.Embeddings, "imported": imported, }) @@ -256,6 +274,16 @@ func applyEmbeddingShareOptions(opts *share.Options, cfg config.Config) { opts.EmbeddingInputVersion = store.EmbeddingInputVersion } +func applyMediaShareOptions(opts *share.Options, cfg config.Config, enabled bool) error { + cacheDir, err := config.ExpandPath(cfg.CacheDir) + if err != nil { + return configErr(err) + } + opts.CacheDir = cacheDir + opts.IncludeMedia = enabled + return nil +} + func loadConfigOrDefault(path string) (config.Config, error) { cfg, err := config.Load(path) if err == nil { diff --git a/internal/config/config.go b/internal/config/config.go index 9d33c62..df81c88 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -48,11 +48,13 @@ type DesktopConfig struct { } type SyncConfig struct { - Source string `toml:"source"` - Concurrency int `toml:"concurrency"` - RepairEvery string `toml:"repair_every"` - FullHistory bool `toml:"full_history"` - AttachmentText *bool `toml:"attachment_text"` + Source string `toml:"source"` + Concurrency int `toml:"concurrency"` + RepairEvery string `toml:"repair_every"` + FullHistory bool `toml:"full_history"` + AttachmentText *bool `toml:"attachment_text"` + AttachmentMedia *bool `toml:"attachment_media"` + MaxAttachmentBytes int64 `toml:"max_attachment_bytes"` } type SearchConfig struct { @@ -66,6 +68,7 @@ type ShareConfig struct { Branch string `toml:"branch,omitempty"` AutoUpdate bool `toml:"auto_update"` StaleAfter string `toml:"stale_after"` + Media *bool `toml:"media"` Filter ShareFilterConfig `toml:"filter"` } @@ -128,11 +131,13 @@ func Default() Config { MaxFileBytes: 64 << 20, }, Sync: SyncConfig{ - Source: "both", - Concurrency: defaultSyncConcurrency(), - RepairEvery: "6h", - FullHistory: true, - AttachmentText: new(true), + Source: "both", + Concurrency: defaultSyncConcurrency(), + RepairEvery: "6h", + FullHistory: true, + AttachmentText: new(true), + AttachmentMedia: new(false), + MaxAttachmentBytes: 100 << 20, }, Search: SearchConfig{ DefaultMode: "fts", @@ -151,6 +156,7 @@ func Default() Config { Branch: "main", AutoUpdate: true, StaleAfter: "15m", + Media: new(true), }, } } @@ -256,6 +262,12 @@ func (c *Config) Normalize() error { if c.Sync.AttachmentText == nil { c.Sync.AttachmentText = new(true) } + if c.Sync.AttachmentMedia == nil { + c.Sync.AttachmentMedia = new(false) + } + if c.Sync.MaxAttachmentBytes <= 0 { + c.Sync.MaxAttachmentBytes = 100 << 20 + } if c.Search.DefaultMode == "" { c.Search.DefaultMode = "fts" } @@ -293,6 +305,9 @@ func (c *Config) Normalize() error { if c.Share.StaleAfter == "" { c.Share.StaleAfter = "15m" } + if c.Share.Media == nil { + c.Share.Media = new(true) + } c.Share.Filter.IncludeChannelIDs = uniqueStrings(c.Share.Filter.IncludeChannelIDs) c.Share.Filter.ExcludeChannelIDs = uniqueStrings(c.Share.Filter.ExcludeChannelIDs) if c.Search.Embeddings.MaxInputChars <= 0 { @@ -350,6 +365,14 @@ func (c Config) AttachmentTextEnabled() bool { return c.Sync.AttachmentText == nil || *c.Sync.AttachmentText } +func (c Config) AttachmentMediaEnabled() bool { + return c.Sync.AttachmentMedia != nil && *c.Sync.AttachmentMedia +} + +func (c Config) ShareMediaEnabled() bool { + return c.Share.Media == nil || *c.Share.Media +} + func (c Config) ShareEnabled() bool { return strings.TrimSpace(c.Share.Remote) != "" } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 319501e..a40a157 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -423,6 +423,20 @@ func TestAttachmentTextExplicitFalseSurvivesNormalize(t *testing.T) { require.False(t, cfg.AttachmentTextEnabled()) } +func TestMediaBooleansExplicitFalseSurviveNormalize(t *testing.T) { + t.Parallel() + + cfg := Default() + require.False(t, cfg.AttachmentMediaEnabled()) + require.True(t, cfg.ShareMediaEnabled()) + + cfg.Sync.AttachmentMedia = new(true) + cfg.Share.Media = new(false) + require.NoError(t, cfg.Normalize()) + require.True(t, cfg.AttachmentMediaEnabled()) + require.False(t, cfg.ShareMediaEnabled()) +} + func TestExpandPath(t *testing.T) { t.Parallel() diff --git a/internal/media/cache.go b/internal/media/cache.go new file mode 100644 index 0000000..67043cf --- /dev/null +++ b/internal/media/cache.go @@ -0,0 +1,360 @@ +package media + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "strings" + "time" + "unicode" + + "github.com/openclaw/discrawl/internal/store" +) + +const ( + DefaultMaxBytes int64 = 100 << 20 + cacheSubdir = "media" +) + +type FetchOptions struct { + CacheDir string + List store.AttachmentListOptions + MaxBytes int64 + Force bool + HTTPClient *http.Client + Now func() time.Time + StatusUpdate bool +} + +type FetchStats struct { + Attachments int `json:"attachments"` + Fetched int `json:"fetched"` + Reused int `json:"reused"` + Skipped int `json:"skipped"` + Failed int `json:"failed"` + Bytes int64 `json:"bytes"` +} + +func Fetch(ctx context.Context, s *store.Store, opts FetchOptions) (FetchStats, error) { + if strings.TrimSpace(opts.CacheDir) == "" { + return FetchStats{}, errors.New("cache dir is required") + } + if opts.MaxBytes <= 0 { + opts.MaxBytes = DefaultMaxBytes + } + if opts.HTTPClient == nil { + opts.HTTPClient = &http.Client{Timeout: 30 * time.Second} + } + if opts.Now == nil { + opts.Now = time.Now + } + list := opts.List + limit := list.Limit + missingOnly := list.MissingOnly + // A stored media_path is only metadata; the cache file may have been deleted + // or intentionally omitted from a Git snapshot import. Fetch all matching + // rows and skip only after checking the filesystem. + if missingOnly || !opts.Force { + list.MissingOnly = false + list.Limit = 0 + } + attachments, err := s.ListAttachments(ctx, list) + if err != nil { + return FetchStats{}, err + } + stats := FetchStats{} + for _, attachment := range attachments { + if err := ctx.Err(); err != nil { + return stats, err + } + if attachment.MediaPath != "" && (missingOnly || !opts.Force) { + if mediaFileReusable(opts.CacheDir, attachment) { + stats.Reused++ + continue + } + } + if limit > 0 && stats.Attachments >= limit { + break + } + stats.Attachments++ + result, err := fetchOne(ctx, opts, attachment) + switch { + case err != nil: + stats.Failed++ + if opts.StatusUpdate { + if err := s.UpdateAttachmentFetchStatus(ctx, attachment.AttachmentID, opts.Now().UTC().Format(time.RFC3339Nano), "failed", clampError(err.Error())); err != nil { + return stats, err + } + } + case result.status == "skipped": + stats.Skipped++ + if opts.StatusUpdate { + if err := s.UpdateAttachmentFetchStatus(ctx, attachment.AttachmentID, opts.Now().UTC().Format(time.RFC3339Nano), result.reason, ""); err != nil { + return stats, err + } + } + default: + stats.Fetched++ + stats.Bytes += result.size + if err := s.UpdateAttachmentMedia(ctx, store.AttachmentMediaUpdate{ + AttachmentID: attachment.AttachmentID, + MediaPath: result.mediaPath, + ContentSHA256: result.sha256, + ContentSize: result.size, + FetchedAt: opts.Now().UTC().Format(time.RFC3339Nano), + FetchStatus: "fetched", + }); err != nil { + return stats, err + } + } + } + return stats, nil +} + +func mediaFileReusable(cacheDir string, attachment store.AttachmentRow) bool { + path, err := LocalPath(cacheDir, attachment.MediaPath) + if err != nil { + return false + } + info, err := os.Lstat(path) + if err != nil || !info.Mode().IsRegular() { + return false + } + if attachment.ContentSHA256 == "" { + return true + } + current, err := fileSHA256(path) + return err == nil && current == attachment.ContentSHA256 +} + +type fetchResult struct { + status string + reason string + mediaPath string + sha256 string + size int64 +} + +func fetchOne(ctx context.Context, opts FetchOptions, attachment store.AttachmentRow) (fetchResult, error) { + urls := candidateURLs(attachment) + if len(urls) == 0 { + return fetchResult{status: "skipped", reason: "no_url"}, nil + } + var lastErr error + for _, url := range urls { + result, err := fetchURL(ctx, opts, attachment, url) + if err == nil || result.status == "skipped" { + return result, err + } + lastErr = err + } + return fetchResult{}, lastErr +} + +func candidateURLs(attachment store.AttachmentRow) []string { + seen := map[string]struct{}{} + out := []string{} + for _, raw := range []string{attachment.URL, attachment.ProxyURL} { + raw = strings.TrimSpace(raw) + if raw == "" { + continue + } + if _, ok := seen[raw]; ok { + continue + } + seen[raw] = struct{}{} + out = append(out, raw) + } + return out +} + +func fetchURL(ctx context.Context, opts FetchOptions, attachment store.AttachmentRow, url string) (fetchResult, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fetchResult{}, err + } + resp, err := opts.HTTPClient.Do(req) + if err != nil { + return fetchResult{}, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fetchResult{}, fmt.Errorf("GET %s returned HTTP %d", url, resp.StatusCode) + } + if resp.ContentLength > opts.MaxBytes { + return fetchResult{status: "skipped", reason: "too_large"}, nil + } + body, err := io.ReadAll(io.LimitReader(resp.Body, opts.MaxBytes+1)) + if err != nil { + return fetchResult{}, err + } + if int64(len(body)) > opts.MaxBytes { + return fetchResult{status: "skipped", reason: "too_large"}, nil + } + sum := sha256.Sum256(body) + hash := hex.EncodeToString(sum[:]) + mediaPath := attachmentMediaPath(hash, attachment.Filename, attachment.ContentType) + target, err := LocalPath(opts.CacheDir, mediaPath) + if err != nil { + return fetchResult{}, err + } + if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { + return fetchResult{}, err + } + needsWrite, err := mediaTargetNeedsWrite(target, hash) + if err != nil { + return fetchResult{}, err + } + if needsWrite || opts.Force { + tmp, err := os.CreateTemp(filepath.Dir(target), ".download-*") + if err != nil { + return fetchResult{}, err + } + tmpPath := tmp.Name() + if _, err := tmp.Write(body); err != nil { + _ = tmp.Close() + _ = os.Remove(tmpPath) + return fetchResult{}, err + } + if err := tmp.Close(); err != nil { + _ = os.Remove(tmpPath) + return fetchResult{}, err + } + if info, err := os.Lstat(target); err == nil && !info.Mode().IsRegular() { + _ = os.Remove(target) + } + if err := os.Rename(tmpPath, target); err != nil { + _ = os.Remove(tmpPath) + return fetchResult{}, err + } + } + return fetchResult{mediaPath: mediaPath, sha256: hash, size: int64(len(body))}, nil +} + +func mediaTargetNeedsWrite(target, hash string) (bool, error) { + info, err := os.Lstat(target) + if errors.Is(err, os.ErrNotExist) { + return true, nil + } + if err != nil { + return false, err + } + if !info.Mode().IsRegular() { + return true, nil + } + current, err := fileSHA256(target) + if err != nil { + return false, err + } + return current != hash, nil +} + +func attachmentMediaPath(hash, filename, contentType string) string { + name := safeFilename(filename) + if name == "" { + name = "attachment" + extensionForContentType(contentType) + } + name = truncateFilename(name, 190) + return filepath.ToSlash(filepath.Join("attachments", hash[:2], hash+"-"+name)) +} + +func truncateFilename(name string, maxLen int) string { + if len(name) <= maxLen { + return name + } + ext := filepath.Ext(name) + if len(ext) >= maxLen { + return name[:maxLen] + } + baseLen := maxLen - len(ext) + if baseLen <= 0 { + return name[:maxLen] + } + base := strings.TrimRight(name[:baseLen], ".-") + if base == "" { + return strings.TrimLeft(ext, ".") + } + return base + ext +} + +func extensionForContentType(contentType string) string { + mediaType, _, err := mime.ParseMediaType(strings.TrimSpace(contentType)) + if err != nil || mediaType == "" { + return "" + } + exts, err := mime.ExtensionsByType(mediaType) + if err != nil || len(exts) == 0 { + return "" + } + return exts[0] +} + +func safeFilename(raw string) string { + raw = filepath.Base(strings.TrimSpace(raw)) + var b strings.Builder + for _, r := range raw { + switch { + case r == '.' || r == '-' || r == '_': + b.WriteRune(r) + case r <= unicode.MaxASCII && (unicode.IsLetter(r) || unicode.IsDigit(r)): + b.WriteRune(r) + case unicode.IsSpace(r): + b.WriteByte('-') + } + } + return strings.Trim(strings.TrimSpace(b.String()), ".-") +} + +func fileSHA256(path string) (string, error) { + file, err := os.Open(path) // #nosec G304 -- callers pass confined cache paths. + if err != nil { + return "", err + } + defer func() { _ = file.Close() }() + hasher := sha256.New() + if _, err := io.Copy(hasher, file); err != nil { + return "", err + } + return hex.EncodeToString(hasher.Sum(nil)), nil +} + +func LocalPath(cacheDir, mediaPath string) (string, error) { + root := filepath.Clean(filepath.Join(cacheDir, cacheSubdir)) + clean := filepath.Clean(filepath.FromSlash(strings.TrimSpace(mediaPath))) + if clean == "." || filepath.IsAbs(clean) || clean == ".." || strings.HasPrefix(clean, ".."+string(filepath.Separator)) { + return "", fmt.Errorf("invalid media path %q", mediaPath) + } + full := filepath.Clean(filepath.Join(root, clean)) + if full != root && !strings.HasPrefix(full, root+string(filepath.Separator)) { + return "", fmt.Errorf("media path escapes cache: %q", mediaPath) + } + return full, nil +} + +func RepoPath(repoPath, mediaPath string) (string, error) { + root := filepath.Clean(filepath.Join(repoPath, cacheSubdir)) + clean := filepath.Clean(filepath.FromSlash(strings.TrimSpace(mediaPath))) + if clean == "." || filepath.IsAbs(clean) || clean == ".." || strings.HasPrefix(clean, ".."+string(filepath.Separator)) { + return "", fmt.Errorf("invalid media path %q", mediaPath) + } + full := filepath.Clean(filepath.Join(root, clean)) + if full != root && !strings.HasPrefix(full, root+string(filepath.Separator)) { + return "", fmt.Errorf("media path escapes repo: %q", mediaPath) + } + return full, nil +} + +func clampError(message string) string { + message = strings.TrimSpace(message) + if len(message) <= 512 { + return message + } + return message[:512] +} diff --git a/internal/media/cache_test.go b/internal/media/cache_test.go new file mode 100644 index 0000000..b7e546c --- /dev/null +++ b/internal/media/cache_test.go @@ -0,0 +1,478 @@ +package media + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/openclaw/discrawl/internal/store" + "github.com/stretchr/testify/require" +) + +func TestFetchCachesAttachmentMedia(t *testing.T) { + t.Parallel() + + ctx := context.Background() + body := []byte("image-bytes") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/file.png" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "image/png") + _, _ = w.Write(body) + })) + defer server.Close() + + s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + require.NoError(t, seedAttachment(ctx, s, server.URL+"/file.png")) + + cacheDir := t.TempDir() + stats, err := Fetch(ctx, s, FetchOptions{ + CacheDir: cacheDir, + MaxBytes: 1024, + StatusUpdate: true, + Now: func() time.Time { return time.Date(2026, 5, 15, 12, 0, 0, 0, time.UTC) }, + }) + require.NoError(t, err) + require.Equal(t, FetchStats{Attachments: 1, Fetched: 1, Bytes: int64(len(body))}, stats) + + rows, err := s.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m1"}) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, "fetched", rows[0].FetchStatus) + sum := sha256.Sum256(body) + require.Equal(t, hex.EncodeToString(sum[:]), rows[0].ContentSHA256) + path, err := LocalPath(cacheDir, rows[0].MediaPath) + require.NoError(t, err) + got, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, body, got) + + stats, err = Fetch(ctx, s, FetchOptions{CacheDir: cacheDir}) + require.NoError(t, err) + require.Equal(t, 1, stats.Reused) +} + +func TestFetchLimitAppliesAfterExistingCacheCheck(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + require.NoError(t, seedAttachmentWithIDs(ctx, s, "m1", "a1", "https://example.test/one.png")) + require.NoError(t, seedAttachmentWithIDs(ctx, s, "m2", "a2", "https://example.test/two.png")) + + cacheDir := t.TempDir() + _, err = Fetch(ctx, s, FetchOptions{ + CacheDir: cacheDir, + MaxBytes: 1024, + HTTPClient: staticHTTPClient([]byte("one")), + List: store.AttachmentListOptions{MessageID: "m1"}, + }) + require.NoError(t, err) + + stats, err := Fetch(ctx, s, FetchOptions{ + CacheDir: cacheDir, + MaxBytes: 1024, + HTTPClient: staticHTTPClient([]byte("two")), + List: store.AttachmentListOptions{Limit: 1, MissingOnly: true}, + }) + require.NoError(t, err) + require.Equal(t, 1, stats.Reused) + require.Equal(t, 1, stats.Attachments) + require.Equal(t, 1, stats.Fetched) +} + +func TestFetchForceRewritesCachedMedia(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + body := []byte("canonical") + require.NoError(t, seedAttachment(ctx, s, "https://example.test/file.png")) + + cacheDir := t.TempDir() + _, err = Fetch(ctx, s, FetchOptions{ + CacheDir: cacheDir, + MaxBytes: 1024, + HTTPClient: staticHTTPClient(body), + }) + require.NoError(t, err) + rows, err := s.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m1"}) + require.NoError(t, err) + require.Len(t, rows, 1) + path, err := LocalPath(cacheDir, rows[0].MediaPath) + require.NoError(t, err) + require.NoError(t, os.WriteFile(path, []byte("corrupt"), 0o600)) + + stats, err := Fetch(ctx, s, FetchOptions{ + CacheDir: cacheDir, + MaxBytes: 1024, + HTTPClient: staticHTTPClient(body), + Force: true, + }) + require.NoError(t, err) + require.Equal(t, 1, stats.Fetched) + got, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, body, got) +} + +func TestFetchForceMissingSkipsReusableCachedMedia(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + require.NoError(t, seedAttachmentWithIDs(ctx, s, "m1", "a1", "https://example.test/one.png")) + require.NoError(t, seedAttachmentWithIDs(ctx, s, "m2", "a2", "https://example.test/two.png")) + + cacheDir := t.TempDir() + _, err = Fetch(ctx, s, FetchOptions{ + CacheDir: cacheDir, + MaxBytes: 1024, + HTTPClient: staticHTTPClient([]byte("one")), + List: store.AttachmentListOptions{MessageID: "m1"}, + }) + require.NoError(t, err) + + stats, err := Fetch(ctx, s, FetchOptions{ + CacheDir: cacheDir, + MaxBytes: 1024, + HTTPClient: staticHTTPClient([]byte("two")), + Force: true, + List: store.AttachmentListOptions{Limit: 1, MissingOnly: true}, + }) + require.NoError(t, err) + require.Equal(t, 1, stats.Reused) + require.Equal(t, 1, stats.Attachments) + require.Equal(t, 1, stats.Fetched) + + rows, err := s.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m1"}) + require.NoError(t, err) + require.Len(t, rows, 1) + path, err := LocalPath(cacheDir, rows[0].MediaPath) + require.NoError(t, err) + got, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, []byte("one"), got) +} + +func TestFetchRepairsCorruptCachedMedia(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + body := []byte("canonical") + require.NoError(t, seedAttachment(ctx, s, "https://example.test/file.png")) + + cacheDir := t.TempDir() + _, err = Fetch(ctx, s, FetchOptions{ + CacheDir: cacheDir, + MaxBytes: 1024, + HTTPClient: staticHTTPClient(body), + }) + require.NoError(t, err) + rows, err := s.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m1"}) + require.NoError(t, err) + require.Len(t, rows, 1) + path, err := LocalPath(cacheDir, rows[0].MediaPath) + require.NoError(t, err) + require.NoError(t, os.WriteFile(path, []byte("corrupt"), 0o600)) + + stats, err := Fetch(ctx, s, FetchOptions{ + CacheDir: cacheDir, + MaxBytes: 1024, + HTTPClient: staticHTTPClient(body), + }) + require.NoError(t, err) + require.Equal(t, 1, stats.Fetched) + got, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, body, got) +} + +func TestFetchCapsLongCacheFilename(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + longName := strings.Repeat("a", 320) + ".png" + require.NoError(t, seedAttachmentRecord(ctx, s, "m1", "a1", longName, "https://example.test/file.png")) + + cacheDir := t.TempDir() + _, err = Fetch(ctx, s, FetchOptions{ + CacheDir: cacheDir, + MaxBytes: 1024, + HTTPClient: staticHTTPClient([]byte("image")), + }) + require.NoError(t, err) + rows, err := s.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m1"}) + require.NoError(t, err) + require.Len(t, rows, 1) + require.LessOrEqual(t, len(filepath.Base(rows[0].MediaPath)), 255) + require.Truef(t, strings.HasSuffix(filepath.Base(rows[0].MediaPath), ".png"), "media path %q", rows[0].MediaPath) + path, err := LocalPath(cacheDir, rows[0].MediaPath) + require.NoError(t, err) + require.FileExists(t, path) +} + +func TestFetchRecordsSkippedAndFailedStatuses(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + require.NoError(t, seedAttachmentWithIDs(ctx, s, "m1", "a1", "")) + require.NoError(t, seedAttachmentWithIDs(ctx, s, "m2", "a2", "https://example.test/fail.png")) + + stats, err := Fetch(ctx, s, FetchOptions{ + CacheDir: t.TempDir(), + MaxBytes: 1024, + StatusUpdate: true, + HTTPClient: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, errors.New(strings.Repeat("x", 600)) + })}, + }) + require.NoError(t, err) + require.Equal(t, 2, stats.Attachments) + require.Equal(t, 1, stats.Skipped) + require.Equal(t, 1, stats.Failed) + + rows, err := s.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m1"}) + require.NoError(t, err) + require.Equal(t, "no_url", rows[0].FetchStatus) + rows, err = s.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m2"}) + require.NoError(t, err) + require.Equal(t, "failed", rows[0].FetchStatus) + require.Len(t, rows[0].FetchError, 512) +} + +func TestFetchSkipsOversizedResponses(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + require.NoError(t, seedAttachment(ctx, s, "https://example.test/large.bin")) + + stats, err := Fetch(ctx, s, FetchOptions{ + CacheDir: t.TempDir(), + MaxBytes: 4, + StatusUpdate: true, + HTTPClient: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + body := []byte("too-large") + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(body)), + ContentLength: int64(len(body)), + Request: req, + }, nil + })}, + }) + require.NoError(t, err) + require.Equal(t, 1, stats.Skipped) + rows, err := s.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m1"}) + require.NoError(t, err) + require.Equal(t, "too_large", rows[0].FetchStatus) + require.Empty(t, rows[0].MediaPath) +} + +func TestFetchUsesProxyFallbackAndBodyLimit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + require.NoError(t, seedAttachmentRecord(ctx, s, "m1", "a1", "file.bin", "https://example.test/primary")) + require.NoError(t, s.UpsertMessages(ctx, []store.MessageMutation{{ + Record: store.MessageRecord{ + ID: "m1", + GuildID: "g1", + ChannelID: "c1", + ChannelName: "general", + AuthorID: "u1", + AuthorName: "Peter", + MessageType: 0, + CreatedAt: time.Now().UTC().Format(time.RFC3339Nano), + Content: "see attached", + NormalizedContent: "see attached", + HasAttachments: true, + RawJSON: `{}`, + }, + Attachments: []store.AttachmentRecord{{ + AttachmentID: "a1", + MessageID: "m1", + GuildID: "g1", + ChannelID: "c1", + AuthorID: "u1", + Filename: "file.bin", + ContentType: "application/octet-stream", + Size: 8, + URL: "https://example.test/primary", + ProxyURL: "https://example.test/proxy", + }}, + }})) + + calls := []string{} + stats, err := Fetch(ctx, s, FetchOptions{ + CacheDir: t.TempDir(), + MaxBytes: 4, + HTTPClient: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + calls = append(calls, req.URL.Path) + if req.URL.Path == "/primary" { + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("nope")), + Request: req, + }, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("12345")), + Request: req, + }, nil + })}, + }) + require.NoError(t, err) + require.Equal(t, []string{"/primary", "/proxy"}, calls) + require.Equal(t, 1, stats.Skipped) +} + +func TestMediaPathHelpers(t *testing.T) { + t.Parallel() + + hash := strings.Repeat("a", 64) + require.Equal(t, "attachments/aa/"+hash+"-attachment.png", attachmentMediaPath(hash, "", "image/png")) + require.Equal(t, "report-final.png", safeFilename(" ../report final!.png ")) + require.Equal(t, "name.png", truncateFilename("name.png", 20)) + require.Equal(t, "aaa.png", truncateFilename("aaaaa.png", 7)) + require.Empty(t, extensionForContentType("not a content type")) + require.Empty(t, clampError(" ")) + require.Equal(t, "short", clampError(" short ")) + + repo := t.TempDir() + path, err := RepoPath(repo, "attachments/aa/file.png") + require.NoError(t, err) + require.Equal(t, filepath.Join(repo, "media", "attachments", "aa", "file.png"), path) + _, err = RepoPath(repo, "../escape") + require.Error(t, err) + _, err = LocalPath(repo, "") + require.Error(t, err) +} + +func TestMediaTargetNeedsWrite(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + target := filepath.Join(dir, "file.bin") + hash := sha256.Sum256([]byte("same")) + hashString := hex.EncodeToString(hash[:]) + + needsWrite, err := mediaTargetNeedsWrite(target, hashString) + require.NoError(t, err) + require.True(t, needsWrite) + require.NoError(t, os.WriteFile(target, []byte("same"), 0o600)) + needsWrite, err = mediaTargetNeedsWrite(target, hashString) + require.NoError(t, err) + require.False(t, needsWrite) + needsWrite, err = mediaTargetNeedsWrite(target, strings.Repeat("0", 64)) + require.NoError(t, err) + require.True(t, needsWrite) + + require.NoError(t, os.Remove(target)) + require.NoError(t, os.Mkdir(target, 0o755)) + needsWrite, err = mediaTargetNeedsWrite(target, hashString) + require.NoError(t, err) + require.True(t, needsWrite) +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func staticHTTPClient(body []byte) *http.Client { + return &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(body)), + ContentLength: int64(len(body)), + Request: req, + }, nil + })} +} + +func seedAttachment(ctx context.Context, s *store.Store, url string) error { + return seedAttachmentWithIDs(ctx, s, "m1", "a1", url) +} + +func seedAttachmentWithIDs(ctx context.Context, s *store.Store, messageID, attachmentID, url string) error { + return seedAttachmentRecord(ctx, s, messageID, attachmentID, "file.png", url) +} + +func seedAttachmentRecord(ctx context.Context, s *store.Store, messageID, attachmentID, filename, url string) error { + now := time.Date(2026, 5, 15, 12, 0, 0, 0, time.UTC).Format(time.RFC3339Nano) + if err := s.UpsertGuild(ctx, store.GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`}); err != nil { + return err + } + if err := s.UpsertChannel(ctx, store.ChannelRecord{ID: "c1", GuildID: "g1", Kind: "text", Name: "general", RawJSON: `{}`}); err != nil { + return err + } + return s.UpsertMessages(ctx, []store.MessageMutation{{ + Record: store.MessageRecord{ + ID: messageID, + GuildID: "g1", + ChannelID: "c1", + ChannelName: "general", + AuthorID: "u1", + AuthorName: "Peter", + MessageType: 0, + CreatedAt: now, + Content: "see attached", + NormalizedContent: "see attached", + HasAttachments: true, + RawJSON: `{}`, + }, + Attachments: []store.AttachmentRecord{{ + AttachmentID: attachmentID, + MessageID: messageID, + GuildID: "g1", + ChannelID: "c1", + AuthorID: "u1", + Filename: filename, + ContentType: "image/png", + Size: 11, + URL: url, + }}, + }}) +} diff --git a/internal/share/share.go b/internal/share/share.go index 5edf550..050a585 100644 --- a/internal/share/share.go +++ b/internal/share/share.go @@ -3,8 +3,10 @@ package share import ( "compress/gzip" "context" + "crypto/sha256" "database/sql" "encoding/base64" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -13,6 +15,7 @@ import ( "os" "os/exec" "path/filepath" + "slices" "sort" "strconv" "strings" @@ -20,6 +23,7 @@ import ( "github.com/openclaw/crawlkit/mirror" "github.com/openclaw/crawlkit/snapshot" + "github.com/openclaw/discrawl/internal/media" "github.com/openclaw/discrawl/internal/store" ) @@ -31,7 +35,10 @@ const ( directMessageGuildID = "@me" ) -var ErrNoManifest = snapshot.ErrNoManifest +var ( + ErrNoManifest = snapshot.ErrNoManifest + errUnsafeMediaPath = errors.New("unsafe media path") +) const shardFlushRows = 1024 @@ -50,9 +57,11 @@ var SnapshotTables = []string{ type Options struct { RepoPath string + CacheDir string Remote string Branch string Filter FilterOptions + IncludeMedia bool IncludeEmbeddings bool EmbeddingProvider string EmbeddingModel string @@ -84,6 +93,7 @@ type Manifest struct { Version int `json:"version"` GeneratedAt time.Time `json:"generated_at"` Tables []TableManifest `json:"tables"` + Media *MediaManifest `json:"media,omitempty"` Embeddings []EmbeddingManifest `json:"embeddings,omitempty"` Files map[string]string `json:"files,omitempty"` } @@ -99,6 +109,12 @@ type EmbeddingManifest struct { Rows int `json:"rows"` } +type MediaManifest struct { + Attachments int `json:"attachments"` + Files []snapshot.FileManifest `json:"files"` + Bytes int64 `json:"bytes"` +} + func EnsureRepo(ctx context.Context, opts Options) error { if strings.TrimSpace(opts.RepoPath) == "" { return errors.New("share repo path is empty") @@ -129,6 +145,9 @@ func Push(ctx context.Context, opts Options) error { } func Export(ctx context.Context, s *store.Store, opts Options) (Manifest, error) { + if err := validateMediaRoots(opts); err != nil { + return Manifest{}, err + } if err := EnsureRepo(ctx, opts); err != nil { return Manifest{}, err } @@ -161,6 +180,19 @@ func Export(ctx context.Context, s *store.Store, opts Options) (Manifest, error) } manifest.Embeddings = []EmbeddingManifest{entry} } + if opts.IncludeMedia { + entry, err := exportMedia(ctx, s.DB(), opts, filter) + if err != nil { + return Manifest{}, err + } + if entry != nil { + manifest.Media = entry + } + } else { + if err := os.RemoveAll(filepath.Join(opts.RepoPath, "media")); err != nil { + return Manifest{}, fmt.Errorf("reset media dir: %w", err) + } + } body, err := json.MarshalIndent(manifest, "", " ") if err != nil { return Manifest{}, err @@ -184,6 +216,7 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error) return Manifest{}, err } pragmasRestored := false + existingMedia := map[string]attachmentMediaRecord{} defer func() { if !pragmasRestored { _ = restorePragmas(ctx) @@ -208,6 +241,11 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error) return !isDirectMessageSnapshotRow(table, row), nil }, BeforeImport: func(ctx context.Context, tx *sql.Tx) error { + var err error + existingMedia, err = attachmentMediaByID(ctx, tx) + if err != nil { + return err + } for _, table := range []string{"message_fts", "member_fts"} { if _, err := tx.ExecContext(ctx, "drop table if exists "+table); err != nil { return fmt.Errorf("drop %s: %w", table, err) @@ -226,6 +264,9 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error) if err := repairImportedGuildIDs(ctx, tx); err != nil { return err } + if err := preserveImportedAttachmentMedia(ctx, tx, existingMedia); err != nil { + return err + } if opts.IncludeEmbeddings { return importEmbeddings(ctx, tx, opts, manifest.Embeddings) } @@ -238,6 +279,11 @@ func Import(ctx context.Context, s *store.Store, opts Options) (Manifest, error) if err := s.RebuildSearchIndexes(ctx); err != nil { return Manifest{}, err } + if opts.IncludeMedia { + if _, err := importMedia(ctx, opts, manifest.Media); err != nil { + return Manifest{}, err + } + } if err := MarkImported(ctx, s, manifest); err != nil { return Manifest{}, err } @@ -289,6 +335,16 @@ func ImportIfChanged(ctx context.Context, s *store.Store, opts Options) (Manifes return Manifest{}, false, err } } + if opts.IncludeMedia { + copied, err := importMedia(ctx, opts, manifest.Media) + if err != nil { + return Manifest{}, false, err + } + if err := MarkImported(ctx, s, manifest); err != nil { + return Manifest{}, false, err + } + return manifest, copied > 0, nil + } if err := MarkImported(ctx, s, manifest); err != nil { return Manifest{}, false, err } @@ -316,10 +372,18 @@ func ImportIncremental(ctx context.Context, s *store.Store, opts Options, previo return Manifest{}, false, errIncrementalUnsupported } if !plan.Changed() { + copied := 0 + if opts.IncludeMedia { + var err error + copied, err = importMedia(ctx, opts, manifest.Media) + if err != nil { + return Manifest{}, false, err + } + } if err := MarkImported(ctx, s, manifest); err != nil { return Manifest{}, false, err } - return manifest, false, nil + return manifest, copied > 0, nil } opts.reportProgress(ImportProgress{Phase: "start", TotalRows: importPlanRowCount(plan)}) restorePragmas, err := applyImportPragmas(ctx, s.DB()) @@ -327,6 +391,7 @@ func ImportIncremental(ctx context.Context, s *store.Store, opts Options, previo return Manifest{}, false, err } pragmasRestored := false + existingMedia := map[string]attachmentMediaRecord{} defer func() { if !pragmasRestored { _ = restorePragmas(ctx) @@ -351,6 +416,11 @@ func ImportIncremental(ctx context.Context, s *store.Store, opts Options, previo Filter: func(table string, row map[string]any) (bool, error) { return !isDirectMessageSnapshotRow(table, row), nil }, + BeforeImport: func(ctx context.Context, tx *sql.Tx) error { + var err error + existingMedia, err = attachmentMediaByID(ctx, tx) + return err + }, DeleteTable: func(ctx context.Context, tx *sql.Tx, table string) error { query, args := snapshotDeleteQuery(table) if _, err := tx.ExecContext(ctx, query, args...); err != nil { @@ -363,6 +433,9 @@ func ImportIncremental(ctx context.Context, s *store.Store, opts Options, previo if err := repairImportedGuildIDs(ctx, tx); err != nil { return err } + if err := preserveImportedAttachmentMedia(ctx, tx, existingMedia); err != nil { + return err + } if opts.IncludeEmbeddings { return importEmbeddings(ctx, tx, opts, manifest.Embeddings) } @@ -384,6 +457,11 @@ func ImportIncremental(ctx context.Context, s *store.Store, opts Options, previo return Manifest{}, false, err } } + if opts.IncludeMedia { + if _, err := importMedia(ctx, opts, manifest.Media); err != nil { + return Manifest{}, false, err + } + } if err := MarkImported(ctx, s, manifest); err != nil { return Manifest{}, false, err } @@ -409,6 +487,9 @@ func manifestRowCount(manifest Manifest) int { for _, embeddings := range manifest.Embeddings { total += embeddings.Rows } + if manifest.Media != nil { + total += manifest.Media.Attachments + } return total } @@ -798,6 +879,290 @@ func exportEmbeddings(ctx context.Context, db *sql.DB, opts Options) (EmbeddingM }, nil } +func exportMedia(ctx context.Context, db *sql.DB, opts Options, filter *snapshotFilter) (*MediaManifest, error) { + if strings.TrimSpace(opts.CacheDir) == "" { + return nil, nil + } + if err := os.RemoveAll(filepath.Join(opts.RepoPath, "media")); err != nil { + return nil, fmt.Errorf("reset media dir: %w", err) + } + rows, err := db.QueryContext(ctx, ` + select attachment_id, message_id, guild_id, channel_id, coalesce(media_path, ''), coalesce(content_sha256, '') + from message_attachments + where guild_id <> ? and coalesce(media_path, '') <> '' + order by media_path, attachment_id + `, directMessageGuildID) + if err != nil { + return nil, fmt.Errorf("query media attachments: %w", err) + } + defer func() { _ = rows.Close() }() + manifest := &MediaManifest{} + seen := map[string]struct{}{} + for rows.Next() { + var attachmentID, messageID, guildID, channelID, mediaPath, expectedHash string + if err := rows.Scan(&attachmentID, &messageID, &guildID, &channelID, &mediaPath, &expectedHash); err != nil { + return nil, err + } + if filter != nil { + ok := filter.allow("message_attachments", map[string]any{ + "attachment_id": attachmentID, + "message_id": messageID, + "guild_id": guildID, + "channel_id": channelID, + }) + if !ok { + continue + } + } + if _, ok := seen[mediaPath]; ok { + manifest.Attachments++ + continue + } + source, err := media.LocalPath(opts.CacheDir, mediaPath) + if err != nil { + return nil, err + } + info, err := regularMediaFile(filepath.Join(opts.CacheDir, "media"), source, mediaPath) + if errors.Is(err, os.ErrNotExist) { + continue + } + if errors.Is(err, errUnsafeMediaPath) { + continue + } + if err != nil { + return nil, fmt.Errorf("stat media %s: %w", mediaPath, err) + } + manifest.Attachments++ + seen[mediaPath] = struct{}{} + rel := filepath.ToSlash(filepath.Join("media", mediaPath)) + target, err := media.RepoPath(opts.RepoPath, mediaPath) + if err != nil { + return nil, err + } + if err := copyFile(target, source); err != nil { + return nil, fmt.Errorf("copy media %s: %w", mediaPath, err) + } + hash, err := fileSHA256(target) + if err != nil { + return nil, err + } + if expectedHash != "" && hash != expectedHash { + return nil, fmt.Errorf("media hash mismatch for %s: got %s want %s", mediaPath, hash, expectedHash) + } + manifest.Files = append(manifest.Files, snapshot.FileManifest{Path: rel, Size: info.Size(), SHA256: hash}) + manifest.Bytes += info.Size() + } + if err := rows.Err(); err != nil { + return nil, err + } + if len(manifest.Files) == 0 && manifest.Attachments == 0 { + return nil, nil + } + return manifest, nil +} + +func validateMediaRoots(opts Options) error { + if strings.TrimSpace(opts.CacheDir) == "" { + return nil + } + repoMedia, err := resolvePathForOverlap(filepath.Join(opts.RepoPath, "media")) + if err != nil { + return fmt.Errorf("resolve repo media dir: %w", err) + } + cacheMedia, err := resolvePathForOverlap(filepath.Join(opts.CacheDir, "media")) + if err != nil { + return fmt.Errorf("resolve cache media dir: %w", err) + } + if pathsOverlap(repoMedia, cacheMedia) { + return fmt.Errorf("share media dir %s overlaps cache media dir %s", repoMedia, cacheMedia) + } + return nil +} + +func resolvePathForOverlap(path string) (string, error) { + current, err := filepath.Abs(path) + if err != nil { + return "", err + } + current = filepath.Clean(current) + missing := []string{} + for { + resolved, err := filepath.EvalSymlinks(current) + if err == nil { + for _, part := range slices.Backward(missing) { + resolved = filepath.Join(resolved, part) + } + return filepath.Clean(resolved), nil + } + if !errors.Is(err, os.ErrNotExist) { + return "", err + } + parent := filepath.Dir(current) + if parent == current { + return filepath.Clean(path), nil + } + missing = append(missing, filepath.Base(current)) + current = parent + } +} + +func pathsOverlap(a, b string) bool { + if a == b { + return true + } + return pathContains(a, b) || pathContains(b, a) +} + +func pathContains(parent, child string) bool { + rel, err := filepath.Rel(parent, child) + if err != nil { + return false + } + return rel != "." && rel != "" && !filepath.IsAbs(rel) && rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) +} + +func importMedia(ctx context.Context, opts Options, manifest *MediaManifest) (int, error) { + if manifest == nil || strings.TrimSpace(opts.CacheDir) == "" { + return 0, nil + } + copied := 0 + for _, file := range manifest.Files { + if err := ctx.Err(); err != nil { + return copied, err + } + mediaPath, ok := strings.CutPrefix(filepath.ToSlash(file.Path), "media/") + if !ok || strings.TrimSpace(mediaPath) == "" { + return copied, fmt.Errorf("invalid media manifest path %q", file.Path) + } + source, err := media.RepoPath(opts.RepoPath, mediaPath) + if err != nil { + return copied, err + } + if _, err := regularMediaFile(filepath.Join(opts.RepoPath, "media"), source, file.Path); err != nil { + return copied, err + } + info, err := os.Lstat(source) + if err != nil { + return copied, fmt.Errorf("stat media %s: %w", file.Path, err) + } + if !info.Mode().IsRegular() { + return copied, fmt.Errorf("media %s is not a regular file", file.Path) + } + hash, err := fileSHA256(source) + if err != nil { + return copied, fmt.Errorf("hash media %s: %w", file.Path, err) + } + if file.SHA256 != "" && hash != file.SHA256 { + return copied, fmt.Errorf("media hash mismatch for %s: got %s want %s", file.Path, hash, file.SHA256) + } + target, err := media.LocalPath(opts.CacheDir, mediaPath) + if err != nil { + return copied, err + } + if sameFileHash(target, hash) { + continue + } + if err := copyFile(target, source); err != nil { + return copied, fmt.Errorf("restore media %s: %w", file.Path, err) + } + copied++ + } + return copied, nil +} + +func regularMediaFile(root, path, label string) (os.FileInfo, error) { + root = filepath.Clean(root) + path = filepath.Clean(path) + rel, err := filepath.Rel(root, path) + if err != nil || rel == "." || rel == "" || filepath.IsAbs(rel) || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || rel == ".." { + return nil, fmt.Errorf("%w: media %s escapes media root", errUnsafeMediaPath, label) + } + rootInfo, err := os.Lstat(root) + if err != nil { + return nil, err + } + if rootInfo.Mode()&os.ModeSymlink != 0 || !rootInfo.IsDir() { + return nil, fmt.Errorf("%w: media root for %s is not a directory", errUnsafeMediaPath, label) + } + current := root + parts := strings.Split(rel, string(filepath.Separator)) + for i, part := range parts { + if part == "" || part == "." || part == ".." { + return nil, fmt.Errorf("%w: invalid media path %q", errUnsafeMediaPath, label) + } + current = filepath.Join(current, part) + info, err := os.Lstat(current) + if err != nil { + return nil, err + } + if info.Mode()&os.ModeSymlink != 0 { + if i == len(parts)-1 { + return nil, fmt.Errorf("%w: media %s is not a regular file", errUnsafeMediaPath, label) + } + return nil, fmt.Errorf("%w: media %s contains symlinked path component", errUnsafeMediaPath, label) + } + if i < len(parts)-1 { + if !info.IsDir() { + return nil, fmt.Errorf("%w: media %s parent is not a directory", errUnsafeMediaPath, label) + } + continue + } + if !info.Mode().IsRegular() { + return nil, fmt.Errorf("%w: media %s is not a regular file", errUnsafeMediaPath, label) + } + return info, nil + } + return nil, fmt.Errorf("%w: invalid media path %q", errUnsafeMediaPath, label) +} + +func copyFile(target, source string) error { + if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { + return err + } + src, err := os.Open(source) // #nosec G304 -- source is constrained by media path helpers. + if err != nil { + return err + } + defer func() { _ = src.Close() }() + tmp, err := os.CreateTemp(filepath.Dir(target), ".copy-*") + if err != nil { + return err + } + tmpPath := tmp.Name() + if _, err := io.Copy(tmp, src); err != nil { + _ = tmp.Close() + _ = os.Remove(tmpPath) + return err + } + if err := tmp.Close(); err != nil { + _ = os.Remove(tmpPath) + return err + } + if err := os.Rename(tmpPath, target); err != nil { + _ = os.Remove(tmpPath) + return err + } + return nil +} + +func fileSHA256(path string) (string, error) { + file, err := os.Open(path) // #nosec G304 -- callers pass confined repo/cache paths. + if err != nil { + return "", err + } + defer func() { _ = file.Close() }() + hasher := sha256.New() + if _, err := io.Copy(hasher, file); err != nil { + return "", err + } + return hex.EncodeToString(hasher.Sum(nil)), nil +} + +func sameFileHash(path, hash string) bool { + current, err := fileSHA256(path) + return err == nil && current == hash +} + func importTable(ctx context.Context, tx *sql.Tx, opts Options, table TableManifest) error { files := table.Files if len(files) == 0 && strings.TrimSpace(table.File) != "" { @@ -1571,10 +1936,76 @@ func importValue(value any) any { } } +type attachmentMediaRecord struct { + MediaPath string + ContentSHA256 string + ContentSize int64 + FetchedAt string + FetchStatus string + FetchError string +} + +func attachmentMediaByID(ctx context.Context, tx *sql.Tx) (map[string]attachmentMediaRecord, error) { + rows, err := tx.QueryContext(ctx, ` + select attachment_id, coalesce(media_path, ''), coalesce(content_sha256, ''), + content_size, coalesce(fetched_at, ''), coalesce(fetch_status, ''), coalesce(fetch_error, '') + from message_attachments + where coalesce(media_path, '') <> '' + `) + if err != nil { + return nil, fmt.Errorf("query existing attachment media: %w", err) + } + defer func() { _ = rows.Close() }() + out := map[string]attachmentMediaRecord{} + for rows.Next() { + var id string + var record attachmentMediaRecord + if err := rows.Scan(&id, &record.MediaPath, &record.ContentSHA256, &record.ContentSize, &record.FetchedAt, &record.FetchStatus, &record.FetchError); err != nil { + return nil, err + } + out[id] = record + } + return out, rows.Err() +} + +func preserveImportedAttachmentMedia(ctx context.Context, tx *sql.Tx, existing map[string]attachmentMediaRecord) error { + if len(existing) == 0 { + return nil + } + stmt, err := tx.PrepareContext(ctx, ` + update message_attachments + set media_path = ?, content_sha256 = ?, content_size = ?, fetched_at = ?, + fetch_status = ?, fetch_error = ? + where attachment_id = ? and coalesce(media_path, '') = '' + `) + if err != nil { + return err + } + defer func() { _ = stmt.Close() }() + for attachmentID, record := range existing { + if _, err := stmt.ExecContext(ctx, record.MediaPath, nullableString(record.ContentSHA256), record.ContentSize, nullableString(record.FetchedAt), record.FetchStatus, record.FetchError, attachmentID); err != nil { + return fmt.Errorf("preserve attachment media %s: %w", attachmentID, err) + } + } + return nil +} + +func nullableString(v string) any { + if v == "" { + return nil + } + return v +} + func importIncrementalSnapshotRow(ctx context.Context, tx *sql.Tx, table string, row map[string]any) error { if table == "message_events" || table == "mention_events" { delete(row, "event_id") } + if table == "message_attachments" { + if err := preserveIncrementalAttachmentMedia(ctx, tx, row); err != nil { + return err + } + } if err := insertOrReplaceSnapshotRow(ctx, tx, table, row); err != nil { return err } @@ -1588,6 +2019,36 @@ func importIncrementalSnapshotRow(ctx context.Context, tx *sql.Tx, table string, return upsertMessageFTSRow(ctx, tx, messageID) } +func preserveIncrementalAttachmentMedia(ctx context.Context, tx *sql.Tx, row map[string]any) error { + if stringValue(row["media_path"]) != "" { + return nil + } + attachmentID := stringValue(row["attachment_id"]) + if attachmentID == "" { + return nil + } + var record attachmentMediaRecord + err := tx.QueryRowContext(ctx, ` + select coalesce(media_path, ''), coalesce(content_sha256, ''), content_size, + coalesce(fetched_at, ''), coalesce(fetch_status, ''), coalesce(fetch_error, '') + from message_attachments + where attachment_id = ? and coalesce(media_path, '') <> '' + `, attachmentID).Scan(&record.MediaPath, &record.ContentSHA256, &record.ContentSize, &record.FetchedAt, &record.FetchStatus, &record.FetchError) + if errors.Is(err, sql.ErrNoRows) { + return nil + } + if err != nil { + return fmt.Errorf("query existing attachment media %s: %w", attachmentID, err) + } + row["media_path"] = record.MediaPath + row["content_sha256"] = record.ContentSHA256 + row["content_size"] = record.ContentSize + row["fetched_at"] = record.FetchedAt + row["fetch_status"] = record.FetchStatus + row["fetch_error"] = record.FetchError + return nil +} + func insertOrReplaceSnapshotRow(ctx context.Context, tx *sql.Tx, table string, row map[string]any) error { cols := make([]string, 0, len(row)) for col := range row { diff --git a/internal/share/share_test.go b/internal/share/share_test.go index ffbaefc..c32067d 100644 --- a/internal/share/share_test.go +++ b/internal/share/share_test.go @@ -4,6 +4,8 @@ import ( "bytes" "compress/gzip" "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "io" "os" @@ -18,6 +20,7 @@ import ( "github.com/openclaw/crawlkit/snapshot" "github.com/stretchr/testify/require" + "github.com/openclaw/discrawl/internal/media" "github.com/openclaw/discrawl/internal/store" ) @@ -75,6 +78,588 @@ func TestExportImportRoundTrip(t *testing.T) { require.Equal(t, manifest.GeneratedAt, imported.GeneratedAt) } +func TestExportImportRestoresMediaFiles(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + src := seedStore(t, filepath.Join(dir, "src.db")) + defer func() { _ = src.Close() }() + body := []byte("cached-media") + sum := sha256.Sum256(body) + hash := hex.EncodeToString(sum[:]) + mediaPath := filepath.ToSlash(filepath.Join("attachments", hash[:2], hash+"-file.png")) + require.NoError(t, addCachedAttachment(ctx, src, mediaPath, hash, int64(len(body)))) + srcCache := filepath.Join(dir, "src-cache") + srcFile, err := media.LocalPath(srcCache, mediaPath) + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(srcFile), 0o755)) + require.NoError(t, os.WriteFile(srcFile, body, 0o600)) + + repo := filepath.Join(dir, "share") + manifest, err := Export(ctx, src, Options{RepoPath: repo, CacheDir: srcCache, Branch: "main", IncludeMedia: true}) + require.NoError(t, err) + require.NotNil(t, manifest.Media) + require.Equal(t, 1, manifest.Media.Attachments) + require.Len(t, manifest.Media.Files, 1) + require.FileExists(t, filepath.Join(repo, filepath.FromSlash(manifest.Media.Files[0].Path))) + + dst, err := store.Open(ctx, filepath.Join(dir, "dst.db")) + require.NoError(t, err) + defer func() { _ = dst.Close() }() + dstCache := filepath.Join(dir, "dst-cache") + imported, changed, err := ImportIfChanged(ctx, dst, Options{RepoPath: repo, CacheDir: dstCache, Branch: "main", IncludeMedia: true}) + require.NoError(t, err) + require.True(t, changed) + require.NotNil(t, imported.Media) + dstFile, err := media.LocalPath(dstCache, mediaPath) + require.NoError(t, err) + got, err := os.ReadFile(dstFile) + require.NoError(t, err) + require.Equal(t, body, got) + rows, err := dst.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m1"}) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, mediaPath, rows[0].MediaPath) + + require.NoError(t, os.Remove(dstFile)) + imported, changed, err = ImportIfChanged(ctx, dst, Options{RepoPath: repo, CacheDir: dstCache, Branch: "main", IncludeMedia: true}) + require.NoError(t, err) + require.True(t, changed) + require.NotNil(t, imported.Media) + got, err = os.ReadFile(dstFile) + require.NoError(t, err) + require.Equal(t, body, got) +} + +func TestExportRejectsOverlappingMediaRoots(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + src := seedStore(t, filepath.Join(dir, "src.db")) + defer func() { _ = src.Close() }() + cacheDir := filepath.Join(dir, "cache") + mediaPath := "attachments/aa/file.png" + cacheFile, err := media.LocalPath(cacheDir, mediaPath) + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(cacheFile), 0o755)) + require.NoError(t, os.WriteFile(cacheFile, []byte("cached"), 0o600)) + + _, err = Export(ctx, src, Options{RepoPath: cacheDir, CacheDir: cacheDir, Branch: "main", IncludeMedia: true}) + require.Error(t, err) + require.Contains(t, err.Error(), "overlaps cache media dir") + require.FileExists(t, cacheFile) + + _, err = Export(ctx, src, Options{RepoPath: cacheDir, CacheDir: cacheDir, Branch: "main", IncludeMedia: false}) + require.Error(t, err) + require.Contains(t, err.Error(), "overlaps cache media dir") + require.FileExists(t, cacheFile) +} + +func TestExportRejectsSymlinkedOverlappingMediaRoots(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + src := seedStore(t, filepath.Join(dir, "src.db")) + defer func() { _ = src.Close() }() + cacheDir := filepath.Join(dir, "cache") + cacheFile, err := media.LocalPath(cacheDir, "attachments/aa/file.png") + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(cacheFile), 0o755)) + require.NoError(t, os.WriteFile(cacheFile, []byte("cached"), 0o600)) + repoPath := filepath.Join(dir, "repo-link") + if err := os.Symlink(cacheDir, repoPath); err != nil { + t.Skipf("symlink unavailable: %v", err) + } + + _, err = Export(ctx, src, Options{RepoPath: repoPath, CacheDir: cacheDir, Branch: "main", IncludeMedia: false}) + require.Error(t, err) + require.Contains(t, err.Error(), "overlaps cache media dir") + require.FileExists(t, cacheFile) +} + +func TestImportPreservesLocalAttachmentMediaMetadata(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + src := seedStore(t, filepath.Join(dir, "src.db")) + defer func() { _ = src.Close() }() + require.NoError(t, addUncachedAttachment(ctx, src)) + repo := filepath.Join(dir, "share") + _, err := Export(ctx, src, Options{RepoPath: repo, Branch: "main"}) + require.NoError(t, err) + + dst := seedStore(t, filepath.Join(dir, "dst.db")) + defer func() { _ = dst.Close() }() + body := []byte("local-cache") + sum := sha256.Sum256(body) + hash := hex.EncodeToString(sum[:]) + mediaPath := filepath.ToSlash(filepath.Join("attachments", hash[:2], hash+"-file.png")) + require.NoError(t, addCachedAttachment(ctx, dst, mediaPath, hash, int64(len(body)))) + + _, err = Import(ctx, dst, Options{RepoPath: repo, Branch: "main"}) + require.NoError(t, err) + rows, err := dst.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m1"}) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, mediaPath, rows[0].MediaPath) + require.Equal(t, hash, rows[0].ContentSHA256) + require.Equal(t, int64(len(body)), rows[0].ContentSize) + require.Equal(t, "fetched", rows[0].FetchStatus) +} + +func TestIncrementalImportPreservesLocalAttachmentMediaMetadata(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + dst := seedStore(t, filepath.Join(dir, "dst.db")) + defer func() { _ = dst.Close() }() + body := []byte("local-cache") + sum := sha256.Sum256(body) + hash := hex.EncodeToString(sum[:]) + mediaPath := filepath.ToSlash(filepath.Join("attachments", hash[:2], hash+"-file.png")) + require.NoError(t, addCachedAttachment(ctx, dst, mediaPath, hash, int64(len(body)))) + + tx, err := dst.DB().BeginTx(ctx, nil) + require.NoError(t, err) + row := map[string]any{ + "attachment_id": "a1", + "message_id": "m1", + "guild_id": "g1", + "channel_id": "c1", + "author_id": "u1", + "filename": "file.png", + "content_type": "image/png", + "size": int64(len(body)), + "url": "https://cdn.example/file.png", + "proxy_url": nil, + "text_content": "", + "media_path": "", + "content_sha256": "", + "content_size": int64(0), + "fetched_at": nil, + "fetch_status": "", + "fetch_error": "", + "updated_at": time.Now().UTC().Format(time.RFC3339Nano), + } + require.NoError(t, importIncrementalSnapshotRow(ctx, tx, "message_attachments", row)) + require.NoError(t, tx.Commit()) + + rows, err := dst.ListAttachments(ctx, store.AttachmentListOptions{MessageID: "m1"}) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, mediaPath, rows[0].MediaPath) + require.Equal(t, hash, rows[0].ContentSHA256) + require.Equal(t, int64(len(body)), rows[0].ContentSize) + require.Equal(t, "fetched", rows[0].FetchStatus) +} + +func TestImportMediaRejectsSymlinkedFiles(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + repo := filepath.Join(dir, "share") + cacheDir := filepath.Join(dir, "cache") + mediaPath := "attachments/aa/file.png" + source, err := media.RepoPath(repo, mediaPath) + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(source), 0o755)) + target := filepath.Join(dir, "outside.png") + require.NoError(t, os.WriteFile(target, []byte("outside"), 0o600)) + if err := os.Symlink(target, source); err != nil { + t.Skipf("symlink unavailable: %v", err) + } + + _, err = importMedia(ctx, Options{RepoPath: repo, CacheDir: cacheDir}, &MediaManifest{ + Files: []snapshot.FileManifest{{Path: "media/" + mediaPath}}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "not a regular file") + _, err = os.Stat(filepath.Join(cacheDir, "media", filepath.FromSlash(mediaPath))) + require.True(t, os.IsNotExist(err)) +} + +func TestImportMediaRejectsSymlinkedDirectories(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + repo := filepath.Join(dir, "share") + cacheDir := filepath.Join(dir, "cache") + mediaPath := "attachments/aa/file.png" + outside := filepath.Join(dir, "outside") + require.NoError(t, os.MkdirAll(outside, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(outside, "file.png"), []byte("outside"), 0o600)) + linkParent := filepath.Join(repo, "media", "attachments", "aa") + require.NoError(t, os.MkdirAll(filepath.Dir(linkParent), 0o755)) + if err := os.Symlink(outside, linkParent); err != nil { + t.Skipf("symlink unavailable: %v", err) + } + + _, err := importMedia(ctx, Options{RepoPath: repo, CacheDir: cacheDir}, &MediaManifest{ + Files: []snapshot.FileManifest{{Path: "media/" + mediaPath}}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "symlinked path component") + _, err = os.Stat(filepath.Join(cacheDir, "media", filepath.FromSlash(mediaPath))) + require.True(t, os.IsNotExist(err)) +} + +func TestImportMediaRejectsInvalidAndMismatchedFiles(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + repo := filepath.Join(dir, "share") + cacheDir := filepath.Join(dir, "cache") + + copied, err := importMedia(ctx, Options{RepoPath: repo, CacheDir: cacheDir}, nil) + require.NoError(t, err) + require.Zero(t, copied) + + _, err = importMedia(ctx, Options{RepoPath: repo, CacheDir: cacheDir}, &MediaManifest{ + Files: []snapshot.FileManifest{{Path: "tables/not-media.jsonl.gz"}}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid media manifest path") + + mediaPath := "attachments/aa/file.png" + source, err := media.RepoPath(repo, mediaPath) + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(source), 0o755)) + require.NoError(t, os.WriteFile(source, []byte("body"), 0o600)) + _, err = importMedia(ctx, Options{RepoPath: repo, CacheDir: cacheDir}, &MediaManifest{ + Files: []snapshot.FileManifest{{Path: "media/" + mediaPath, SHA256: strings.Repeat("0", 64)}}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "media hash mismatch") +} + +func TestImportIncrementalRestoresMediaWhenTablesUnchanged(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + src := seedStore(t, filepath.Join(dir, "src.db")) + defer func() { _ = src.Close() }() + body := []byte("cached-media") + sum := sha256.Sum256(body) + hash := hex.EncodeToString(sum[:]) + mediaPath := filepath.ToSlash(filepath.Join("attachments", hash[:2], hash+"-file.png")) + require.NoError(t, addCachedAttachment(ctx, src, mediaPath, hash, int64(len(body)))) + srcCache := filepath.Join(dir, "src-cache") + srcFile, err := media.LocalPath(srcCache, mediaPath) + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(srcFile), 0o755)) + require.NoError(t, os.WriteFile(srcFile, body, 0o600)) + + repo := filepath.Join(dir, "share") + manifest, err := Export(ctx, src, Options{RepoPath: repo, CacheDir: srcCache, Branch: "main", IncludeMedia: true}) + require.NoError(t, err) + dst, err := store.Open(ctx, filepath.Join(dir, "dst.db")) + require.NoError(t, err) + defer func() { _ = dst.Close() }() + _, err = Import(ctx, dst, Options{RepoPath: repo, CacheDir: filepath.Join(dir, "dst-cache"), Branch: "main", IncludeMedia: false}) + require.NoError(t, err) + + dstCache := filepath.Join(dir, "dst-cache") + imported, changed, err := ImportIncremental(ctx, dst, Options{RepoPath: repo, CacheDir: dstCache, Branch: "main", IncludeMedia: true}, manifest, manifest) + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, manifest.GeneratedAt, imported.GeneratedAt) + dstFile, err := media.LocalPath(dstCache, mediaPath) + require.NoError(t, err) + got, err := os.ReadFile(dstFile) + require.NoError(t, err) + require.Equal(t, body, got) +} + +func TestShareIncrementalPlanRejectsUnsupportedModes(t *testing.T) { + _, supported := shareIncrementalPlan(snapshot.ImportPlan{Full: true, Reason: "schema changed"}) + require.False(t, supported) + + _, supported = shareIncrementalPlan(snapshot.ImportPlan{Tables: []snapshot.TableImportPlan{{ + Table: snapshot.TableManifest{Name: "custom"}, + Mode: snapshot.TableImportFiles, + }}}) + require.False(t, supported) + + _, supported = shareIncrementalPlan(snapshot.ImportPlan{Tables: []snapshot.TableImportPlan{{ + Table: snapshot.TableManifest{Name: "messages"}, + Mode: snapshot.TableImportReplace, + }}}) + require.False(t, supported) + + plan, supported := shareIncrementalPlan(snapshot.ImportPlan{Tables: []snapshot.TableImportPlan{ + {Table: snapshot.TableManifest{Name: "messages"}, Mode: snapshot.TableImportFiles}, + {Table: snapshot.TableManifest{Name: "guilds"}, Mode: snapshot.TableImportFiles}, + {Table: snapshot.TableManifest{Name: "channels"}, Mode: snapshot.TableImportReplace}, + {Table: snapshot.TableManifest{Name: "sync_state"}, Mode: snapshot.TableImportSkip}, + }}) + require.True(t, supported) + require.Len(t, plan.Tables, 4) + require.Equal(t, snapshot.TableImportReplace, plan.Tables[1].Mode) +} + +func TestMessageFTSHelpers(t *testing.T) { + id, ok := messageFTSRowID("42") + require.True(t, ok) + require.Equal(t, int64(42), id) + id, ok = messageFTSRowID("18446744073709551615") + require.True(t, ok) + require.NotZero(t, id) + _, ok = messageFTSRowID("") + require.False(t, ok) + require.Nil(t, nullIfEmpty("")) + require.Equal(t, "value", nullIfEmpty("value")) +} + +func TestPreviousImportedManifestFallsBackToGitHistory(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + src := seedStore(t, filepath.Join(dir, "src.db")) + defer func() { _ = src.Close() }() + repo := filepath.Join(dir, "share") + require.NoError(t, exec.CommandContext(ctx, "git", "init", repo).Run()) + configureGitUser(t, repo) + manifest, err := Export(ctx, src, Options{RepoPath: repo, Branch: "main"}) + require.NoError(t, err) + require.NoError(t, exec.CommandContext(ctx, "git", "-C", repo, "add", ".").Run()) + require.NoError(t, exec.CommandContext(ctx, "git", "-C", repo, "commit", "-m", "snapshot").Run()) + + dst, err := store.Open(ctx, filepath.Join(dir, "dst.db")) + require.NoError(t, err) + defer func() { _ = dst.Close() }() + require.NoError(t, dst.SetSyncState(ctx, LastImportManifestSyncScope, manifest.GeneratedAt.Format(time.RFC3339Nano))) + + previous, ok := PreviousImportedManifest(ctx, dst, Options{RepoPath: repo, Branch: "main"}) + require.True(t, ok) + require.Equal(t, manifest.GeneratedAt, previous.GeneratedAt) + require.NotEmpty(t, tableEntry(t, previous, "messages").FileManifests) +} + +func TestMediaPathValidationHelpers(t *testing.T) { + dir := t.TempDir() + root := filepath.Join(dir, "media") + require.NoError(t, os.MkdirAll(root, 0o755)) + file := filepath.Join(root, "attachments", "aa", "file.png") + require.NoError(t, os.MkdirAll(filepath.Dir(file), 0o755)) + require.NoError(t, os.WriteFile(file, []byte("body"), 0o600)) + + info, err := regularMediaFile(root, file, "file.png") + require.NoError(t, err) + require.Equal(t, int64(4), info.Size()) + _, err = regularMediaFile(root, root, "root") + require.Error(t, err) + require.ErrorIs(t, err, errUnsafeMediaPath) + require.True(t, pathsOverlap(root, filepath.Join(root, "attachments"))) + require.False(t, pathsOverlap(root, filepath.Join(dir, "other"))) +} + +func TestMediaCopyHashHelpers(t *testing.T) { + dir := t.TempDir() + source := filepath.Join(dir, "source.bin") + target := filepath.Join(dir, "nested", "target.bin") + body := []byte("copy-body") + require.NoError(t, os.WriteFile(source, body, 0o600)) + + require.NoError(t, copyFile(target, source)) + got, err := os.ReadFile(target) + require.NoError(t, err) + require.Equal(t, body, got) + hash, err := fileSHA256(source) + require.NoError(t, err) + require.True(t, sameFileHash(target, hash)) + require.False(t, sameFileHash(filepath.Join(dir, "missing.bin"), hash)) + require.Error(t, copyFile(filepath.Join(dir, "other.bin"), filepath.Join(dir, "missing.bin"))) +} + +func TestPublicPermissionHelpers(t *testing.T) { + rawGuild := `{"roles":[{"id":"g1","permissions":"1024"}]}` + permissions, ok := everyoneGuildPermissions(rawGuild, "g1") + require.True(t, ok) + require.Equal(t, permissionViewChannel, permissions) + _, ok = everyoneGuildPermissions(`{"roles":[{"id":"g1","permissions":{}}]}`, "g1") + require.False(t, ok) + _, ok = everyoneGuildPermissions(`not-json`, "g1") + require.False(t, ok) + _, ok = everyoneGuildPermissions(`{"roles":[]}`, "g1") + require.False(t, ok) + + require.Equal(t, int64(0), applyEveryoneOverwrite(permissionViewChannel, `{"permission_overwrites":[{"id":"g1","type":"role","deny":"1024"}]}`, "g1")) + require.Equal(t, permissionViewChannel, applyEveryoneOverwrite(0, `{"permission_overwrites":[{"id":"g1","type":0,"allow":1024}]}`, "g1")) + require.Equal(t, permissionViewChannel, applyEveryoneOverwrite(permissionViewChannel, `not-json`, "g1")) + + parsed, ok := parsePermissionBits(json.Number("1024")) + require.True(t, ok) + require.Equal(t, permissionViewChannel, parsed) + parsed, ok = parsePermissionBits(float64(1024)) + require.True(t, ok) + require.Equal(t, permissionViewChannel, parsed) + parsed, ok = parsePermissionBits(nil) + require.True(t, ok) + require.Zero(t, parsed) + _, ok = parsePermissionBits(json.Number("bad")) + require.False(t, ok) + _, ok = parsePermissionBits(struct{}{}) + require.False(t, ok) + + require.True(t, isRoleOverwrite(json.Number("0"))) + require.True(t, isRoleOverwrite(float64(0))) + require.True(t, isRoleOverwrite("role")) + require.False(t, isRoleOverwrite("member")) + require.False(t, isRoleOverwrite(json.Number("bad"))) + require.False(t, isRoleOverwrite(struct{}{})) +} + +func TestPublicSnapshotFilterHonorsCategoryAndThreadPermissions(t *testing.T) { + ctx := context.Background() + s, err := store.Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + + require.NoError(t, s.UpsertGuild(ctx, store.GuildRecord{ + ID: "g1", + Name: "Guild", + RawJSON: `{"roles":[{"id":"g1","permissions":"1024"}]}`, + })) + require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ + ID: "cat-deny", + GuildID: "g1", + Kind: "category", + Name: "Private", + RawJSON: `{"permission_overwrites":[{"id":"g1","type":"role","deny":"1024"}]}`, + })) + require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ + ID: "child-denied", + GuildID: "g1", + ParentID: "cat-deny", + Kind: "text", + Name: "denied", + RawJSON: `{}`, + })) + require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ + ID: "public", + GuildID: "g1", + Kind: "text", + Name: "public", + RawJSON: `{}`, + })) + require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ + ID: "thread", + GuildID: "g1", + ThreadParentID: "public", + Kind: "thread_public", + Name: "thread", + RawJSON: `{}`, + })) + require.NoError(t, s.UpsertChannel(ctx, store.ChannelRecord{ + ID: "private-thread", + GuildID: "g1", + ThreadParentID: "public", + Kind: "thread_private", + Name: "private-thread", + IsPrivateThread: true, + RawJSON: `{}`, + })) + + filter, err := newSnapshotFilter(ctx, s.DB(), FilterOptions{PublicOnly: true}) + require.NoError(t, err) + require.False(t, filter.publicChannel("child-denied")) + require.True(t, filter.publicChannel("public")) + require.True(t, filter.publicChannel("thread")) + require.False(t, filter.publicChannel("private-thread")) + require.False(t, filter.publicChannel("missing")) +} + +func TestExportSkipsMissingMediaFiles(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + src := seedStore(t, filepath.Join(dir, "src.db")) + defer func() { _ = src.Close() }() + body := []byte("missing-media") + sum := sha256.Sum256(body) + hash := hex.EncodeToString(sum[:]) + mediaPath := filepath.ToSlash(filepath.Join("attachments", hash[:2], hash+"-missing.png")) + require.NoError(t, addCachedAttachment(ctx, src, mediaPath, hash, int64(len(body)))) + + repo := filepath.Join(dir, "share") + manifest, err := Export(ctx, src, Options{RepoPath: repo, CacheDir: filepath.Join(dir, "cache"), Branch: "main", IncludeMedia: true}) + require.NoError(t, err) + require.Nil(t, manifest.Media) +} + +func TestExportMediaRejectsInvalidPathAndHashMismatch(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + src := seedStore(t, filepath.Join(dir, "src.db")) + defer func() { _ = src.Close() }() + require.NoError(t, addCachedAttachment(ctx, src, "../bad.png", strings.Repeat("a", 64), 4)) + _, err := Export(ctx, src, Options{RepoPath: filepath.Join(dir, "share-bad-path"), CacheDir: filepath.Join(dir, "cache"), Branch: "main", IncludeMedia: true}) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid media path") + + src2 := seedStore(t, filepath.Join(dir, "src2.db")) + defer func() { _ = src2.Close() }() + body := []byte("actual") + actual := sha256.Sum256(body) + actualHash := hex.EncodeToString(actual[:]) + mediaPath := filepath.ToSlash(filepath.Join("attachments", actualHash[:2], actualHash+"-file.png")) + require.NoError(t, addCachedAttachment(ctx, src2, mediaPath, strings.Repeat("b", 64), int64(len(body)))) + cacheDir := filepath.Join(dir, "cache2") + cacheFile, err := media.LocalPath(cacheDir, mediaPath) + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(cacheFile), 0o755)) + require.NoError(t, os.WriteFile(cacheFile, body, 0o600)) + + _, err = Export(ctx, src2, Options{RepoPath: filepath.Join(dir, "share-bad-hash"), CacheDir: cacheDir, Branch: "main", IncludeMedia: true}) + require.Error(t, err) + require.Contains(t, err.Error(), "media hash mismatch") +} + +func TestExportSkipsSymlinkedMediaFiles(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + src := seedStore(t, filepath.Join(dir, "src.db")) + defer func() { _ = src.Close() }() + body := []byte("outside-media") + sum := sha256.Sum256(body) + hash := hex.EncodeToString(sum[:]) + mediaPath := filepath.ToSlash(filepath.Join("attachments", hash[:2], hash+"-file.png")) + require.NoError(t, addCachedAttachment(ctx, src, mediaPath, hash, int64(len(body)))) + cacheDir := filepath.Join(dir, "cache") + cacheFile, err := media.LocalPath(cacheDir, mediaPath) + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(cacheFile), 0o755)) + target := filepath.Join(dir, "outside.png") + require.NoError(t, os.WriteFile(target, body, 0o600)) + if err := os.Symlink(target, cacheFile); err != nil { + t.Skipf("symlink unavailable: %v", err) + } + + repo := filepath.Join(dir, "share") + manifest, err := Export(ctx, src, Options{RepoPath: repo, CacheDir: cacheDir, Branch: "main", IncludeMedia: true}) + require.NoError(t, err) + require.Nil(t, manifest.Media) + _, err = os.Lstat(filepath.Join(repo, "media", filepath.FromSlash(mediaPath))) + require.True(t, os.IsNotExist(err)) +} + +func TestExportSkipsSymlinkedMediaDirectories(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + src := seedStore(t, filepath.Join(dir, "src.db")) + defer func() { _ = src.Close() }() + body := []byte("outside-media") + sum := sha256.Sum256(body) + hash := hex.EncodeToString(sum[:]) + mediaPath := filepath.ToSlash(filepath.Join("attachments", hash[:2], hash+"-file.png")) + require.NoError(t, addCachedAttachment(ctx, src, mediaPath, hash, int64(len(body)))) + cacheDir := filepath.Join(dir, "cache") + outside := filepath.Join(dir, "outside") + require.NoError(t, os.MkdirAll(outside, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(outside, hash+"-file.png"), body, 0o600)) + linkParent := filepath.Join(cacheDir, "media", "attachments", hash[:2]) + require.NoError(t, os.MkdirAll(filepath.Dir(linkParent), 0o755)) + if err := os.Symlink(outside, linkParent); err != nil { + t.Skipf("symlink unavailable: %v", err) + } + + repo := filepath.Join(dir, "share") + manifest, err := Export(ctx, src, Options{RepoPath: repo, CacheDir: cacheDir, Branch: "main", IncludeMedia: true}) + require.NoError(t, err) + require.Nil(t, manifest.Media) + _, err = os.Lstat(filepath.Join(repo, "media", filepath.FromSlash(mediaPath))) + require.True(t, os.IsNotExist(err)) +} + func TestImportIfChangedUsesIncrementalTailImport(t *testing.T) { ctx := context.Background() src := seedStore(t, filepath.Join(t.TempDir(), "src.db")) @@ -1286,6 +1871,76 @@ func seedStore(t *testing.T, path string) *store.Store { return s } +func addCachedAttachment(ctx context.Context, s *store.Store, mediaPath, hash string, size int64) error { + now := time.Now().UTC().Format(time.RFC3339Nano) + if err := s.UpsertMessages(ctx, []store.MessageMutation{{ + Record: store.MessageRecord{ + ID: "m1", + GuildID: "g1", + ChannelID: "c1", + ChannelName: "general", + AuthorID: "u1", + AuthorName: "Peter", + MessageType: 0, + CreatedAt: now, + Content: "launch checklist ready", + NormalizedContent: "launch checklist ready file.png", + HasAttachments: true, + RawJSON: `{}`, + }, + Attachments: []store.AttachmentRecord{{ + AttachmentID: "a1", + MessageID: "m1", + GuildID: "g1", + ChannelID: "c1", + AuthorID: "u1", + Filename: "file.png", + ContentType: "image/png", + Size: size, + URL: "https://cdn.example/file.png", + MediaPath: mediaPath, + ContentSHA256: hash, + ContentSize: size, + FetchedAt: now, + FetchStatus: "fetched", + }}, + }}); err != nil { + return err + } + return nil +} + +func addUncachedAttachment(ctx context.Context, s *store.Store) error { + now := time.Now().UTC().Format(time.RFC3339Nano) + return s.UpsertMessages(ctx, []store.MessageMutation{{ + Record: store.MessageRecord{ + ID: "m1", + GuildID: "g1", + ChannelID: "c1", + ChannelName: "general", + AuthorID: "u1", + AuthorName: "Peter", + MessageType: 0, + CreatedAt: now, + Content: "launch checklist ready", + NormalizedContent: "launch checklist ready file.png", + HasAttachments: true, + RawJSON: `{}`, + }, + Attachments: []store.AttachmentRecord{{ + AttachmentID: "a1", + MessageID: "m1", + GuildID: "g1", + ChannelID: "c1", + AuthorID: "u1", + Filename: "file.png", + ContentType: "image/png", + Size: 11, + URL: "https://cdn.example/file.png", + }}, + }}) +} + func upsertSnapshotFilterChannel(t *testing.T, ctx context.Context, s *store.Store, channel store.ChannelRecord) { t.Helper() require.NoError(t, s.UpsertChannel(ctx, channel)) diff --git a/internal/store/attachments.go b/internal/store/attachments.go new file mode 100644 index 0000000..a16c59f --- /dev/null +++ b/internal/store/attachments.go @@ -0,0 +1,269 @@ +package store + +import ( + "context" + "strings" + "time" +) + +type AttachmentListOptions struct { + GuildIDs []string + ExcludeGuildIDs []string + ChannelIDs []string + Channel string + Author string + MessageID string + Filename string + ContentType string + Since time.Time + Before time.Time + Limit int + MissingOnly bool +} + +type AttachmentRow struct { + AttachmentID string `json:"attachment_id"` + MessageID string `json:"message_id"` + GuildID string `json:"guild_id"` + GuildName string `json:"guild_name,omitempty"` + ChannelID string `json:"channel_id"` + ChannelName string `json:"channel_name"` + AuthorID string `json:"author_id"` + AuthorName string `json:"author_name"` + Filename string `json:"filename"` + ContentType string `json:"content_type,omitempty"` + Size int64 `json:"size"` + URL string `json:"url,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` + TextContent string `json:"text_content,omitempty"` + MediaPath string `json:"media_path,omitempty"` + ContentSHA256 string `json:"content_sha256,omitempty"` + ContentSize int64 `json:"content_size,omitempty"` + FetchedAt time.Time `json:"fetched_at,omitzero"` + FetchStatus string `json:"fetch_status,omitempty"` + FetchError string `json:"fetch_error,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +type AttachmentMediaUpdate struct { + AttachmentID string + MediaPath string + ContentSHA256 string + ContentSize int64 + FetchedAt string + FetchStatus string + FetchError string +} + +func (s *Store) ListAttachments(ctx context.Context, opts AttachmentListOptions) ([]AttachmentRow, error) { + args := []any{} + clauses := []string{"1=1"} + if len(opts.GuildIDs) > 0 { + clauses = append(clauses, "a.guild_id in ("+placeholders(len(opts.GuildIDs))+")") + for _, guildID := range opts.GuildIDs { + args = append(args, guildID) + } + } + if len(opts.ExcludeGuildIDs) > 0 { + clauses = append(clauses, "a.guild_id not in ("+placeholders(len(opts.ExcludeGuildIDs))+")") + for _, guildID := range opts.ExcludeGuildIDs { + args = append(args, guildID) + } + } + if len(opts.ChannelIDs) > 0 { + clauses = append(clauses, "a.channel_id in ("+placeholders(len(opts.ChannelIDs))+")") + for _, channelID := range opts.ChannelIDs { + args = append(args, channelID) + } + } + if channel := normalizeChannelFilter(opts.Channel); channel != "" { + clauses = append(clauses, "(a.channel_id = ? or c.name = ? or c.name like ?)") + args = append(args, channel, channel, "%"+channel+"%") + } + if author := strings.TrimSpace(opts.Author); author != "" { + clauses = append(clauses, `(a.author_id = ? or coalesce(mem.username, '') = ? or coalesce(mem.display_name, '') = ? or coalesce(json_extract(m.raw_json, '$.author.username'), '') = ? or coalesce(json_extract(m.raw_json, '$.author.global_name'), '') = ? or coalesce(mem.username, '') like ? or coalesce(mem.display_name, '') like ? or coalesce(json_extract(m.raw_json, '$.author.username'), '') like ? or coalesce(json_extract(m.raw_json, '$.author.global_name'), '') like ?)`) + args = append(args, author, author, author, author, author, "%"+author+"%", "%"+author+"%", "%"+author+"%", "%"+author+"%") + } + if messageID := strings.TrimSpace(opts.MessageID); messageID != "" { + clauses = append(clauses, "a.message_id = ?") + args = append(args, messageID) + } + if filename := strings.TrimSpace(opts.Filename); filename != "" { + clauses = append(clauses, "a.filename like ?") + args = append(args, "%"+filename+"%") + } + if contentType := strings.TrimSpace(opts.ContentType); contentType != "" { + clauses = append(clauses, "coalesce(a.content_type, '') like ?") + args = append(args, "%"+contentType+"%") + } + if !opts.Since.IsZero() { + clauses = append(clauses, "m.created_at >= ?") + args = append(args, opts.Since.UTC().Format(timeLayout)) + } + if !opts.Before.IsZero() { + clauses = append(clauses, "m.created_at < ?") + args = append(args, opts.Before.UTC().Format(timeLayout)) + } + if opts.MissingOnly { + clauses = append(clauses, "coalesce(a.media_path, '') = ''") + } + query := ` + select + a.attachment_id, + a.message_id, + a.guild_id, + coalesce(g.name, ''), + a.channel_id, + coalesce(c.name, ''), + coalesce(a.author_id, ''), + coalesce( + nullif(mem.display_name, ''), + nullif(mem.nick, ''), + nullif(mem.global_name, ''), + nullif(mem.username, ''), + nullif(json_extract(m.raw_json, '$.author.global_name'), ''), + nullif(json_extract(m.raw_json, '$.author.username'), ''), + '' + ), + a.filename, + coalesce(a.content_type, ''), + a.size, + coalesce(a.url, ''), + coalesce(a.proxy_url, ''), + a.text_content, + coalesce(a.media_path, ''), + coalesce(a.content_sha256, ''), + a.content_size, + coalesce(a.fetched_at, ''), + coalesce(a.fetch_status, ''), + coalesce(a.fetch_error, ''), + m.created_at + from message_attachments a + left join messages m on m.id = a.message_id + left join guilds g on g.id = a.guild_id + left join channels c on c.id = a.channel_id + left join members mem on mem.guild_id = a.guild_id and mem.user_id = a.author_id + where ` + strings.Join(clauses, " and ") + ` + order by m.created_at asc, a.attachment_id asc + ` + if opts.Limit > 0 { + query += ` limit ?` + args = append(args, opts.Limit) + } + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := []AttachmentRow{} + for rows.Next() { + var row AttachmentRow + var created string + var fetched string + if err := rows.Scan( + &row.AttachmentID, + &row.MessageID, + &row.GuildID, + &row.GuildName, + &row.ChannelID, + &row.ChannelName, + &row.AuthorID, + &row.AuthorName, + &row.Filename, + &row.ContentType, + &row.Size, + &row.URL, + &row.ProxyURL, + &row.TextContent, + &row.MediaPath, + &row.ContentSHA256, + &row.ContentSize, + &fetched, + &row.FetchStatus, + &row.FetchError, + &created, + ); err != nil { + return nil, err + } + row.CreatedAt = parseTime(created) + row.FetchedAt = parseTime(fetched) + out = append(out, row) + } + return out, rows.Err() +} + +func (s *Store) ExpandAttachmentChannelIDs(ctx context.Context, channelIDs []string) ([]string, error) { + if len(channelIDs) == 0 { + return nil, nil + } + seen := map[string]struct{}{} + out := make([]string, 0, len(channelIDs)) + args := make([]any, 0, len(channelIDs)*2) + for _, channelID := range channelIDs { + channelID = strings.TrimSpace(channelID) + if channelID == "" { + continue + } + if _, ok := seen[channelID]; ok { + continue + } + seen[channelID] = struct{}{} + out = append(out, channelID) + args = append(args, channelID) + } + if len(out) == 0 { + return nil, nil + } + args = append(args, args...) + rows, err := s.db.QueryContext(ctx, ` + select id + from channels + where id in (`+placeholders(len(out))+`) + or thread_parent_id in (`+placeholders(len(out))+`) + order by id + `, args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + for rows.Next() { + var channelID string + if err := rows.Scan(&channelID); err != nil { + return nil, err + } + if _, ok := seen[channelID]; ok { + continue + } + seen[channelID] = struct{}{} + out = append(out, channelID) + } + return out, rows.Err() +} + +func (s *Store) UpdateAttachmentMedia(ctx context.Context, update AttachmentMediaUpdate) error { + _, err := s.db.ExecContext(ctx, ` + update message_attachments + set media_path = ?, + content_sha256 = ?, + content_size = ?, + fetched_at = ?, + fetch_status = ?, + fetch_error = ?, + updated_at = ? + where attachment_id = ? + `, nullable(update.MediaPath), nullable(update.ContentSHA256), update.ContentSize, nullable(update.FetchedAt), + update.FetchStatus, update.FetchError, time.Now().UTC().Format(timeLayout), update.AttachmentID) + return err +} + +func (s *Store) UpdateAttachmentFetchStatus(ctx context.Context, attachmentID, fetchedAt, status, message string) error { + _, err := s.db.ExecContext(ctx, ` + update message_attachments + set fetched_at = ?, + fetch_status = ?, + fetch_error = ?, + updated_at = ? + where attachment_id = ? + `, nullable(fetchedAt), status, message, time.Now().UTC().Format(timeLayout), attachmentID) + return err +} diff --git a/internal/store/attachments_test.go b/internal/store/attachments_test.go new file mode 100644 index 0000000..6183cdf --- /dev/null +++ b/internal/store/attachments_test.go @@ -0,0 +1,129 @@ +package store + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestExpandAttachmentChannelIDsIncludesForumThreads(t *testing.T) { + ctx := context.Background() + s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + + require.NoError(t, s.UpsertGuild(ctx, GuildRecord{ID: "g1", Name: "Guild", RawJSON: `{}`})) + require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "forum", GuildID: "g1", Kind: "forum", Name: "ideas", RawJSON: `{}`})) + require.NoError(t, s.UpsertChannel(ctx, ChannelRecord{ID: "thread", GuildID: "g1", Kind: "thread_public", Name: "launch", ThreadParentID: "forum", RawJSON: `{}`})) + + ids, err := s.ExpandAttachmentChannelIDs(ctx, []string{"forum"}) + require.NoError(t, err) + require.ElementsMatch(t, []string{"forum", "thread"}, ids) +} + +func TestListAttachmentsCanExcludeGuilds(t *testing.T) { + ctx := context.Background() + s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + + require.NoError(t, seedAttachmentForGuild(ctx, s, "g1", "c1", "m1", "a1")) + require.NoError(t, seedAttachmentForGuild(ctx, s, DirectMessageGuildID, "dm1", "m2", "a2")) + + rows, err := s.ListAttachments(ctx, AttachmentListOptions{ExcludeGuildIDs: []string{DirectMessageGuildID}}) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, "a1", rows[0].AttachmentID) +} + +func TestAttachmentMediaUpdatesAndFilters(t *testing.T) { + ctx := context.Background() + s, err := Open(ctx, filepath.Join(t.TempDir(), "discrawl.db")) + require.NoError(t, err) + defer func() { _ = s.Close() }() + + require.NoError(t, seedAttachmentForGuild(ctx, s, "g1", "c1", "m1", "a1")) + require.NoError(t, seedAttachmentForGuild(ctx, s, "g1", "c1", "m2", "a2")) + require.NoError(t, s.UpdateAttachmentMedia(ctx, AttachmentMediaUpdate{ + AttachmentID: "a1", + MediaPath: "attachments/aa/hash-file.png", + ContentSHA256: "hash", + ContentSize: 4, + FetchedAt: "2026-05-15T12:05:00Z", + FetchStatus: "fetched", + })) + require.NoError(t, s.UpdateAttachmentFetchStatus(ctx, "a2", "2026-05-15T12:06:00Z", "failed", "boom")) + + rows, err := s.ListAttachments(ctx, AttachmentListOptions{MissingOnly: true}) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, "a2", rows[0].AttachmentID) + require.Equal(t, "failed", rows[0].FetchStatus) + require.Equal(t, "boom", rows[0].FetchError) + + rows, err = s.ListAttachments(ctx, AttachmentListOptions{ + GuildIDs: []string{"g1"}, + ChannelIDs: []string{"c1"}, + Channel: "#c1", + Author: "Peter", + Filename: "file", + ContentType: "image/", + Since: time.Date(2026, 5, 15, 11, 0, 0, 0, time.UTC), + Before: time.Date(2026, 5, 15, 13, 0, 0, 0, time.UTC), + Limit: 1, + }) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, "a1", rows[0].AttachmentID) + require.Equal(t, "attachments/aa/hash-file.png", rows[0].MediaPath) + require.Equal(t, "hash", rows[0].ContentSHA256) + require.Equal(t, int64(4), rows[0].ContentSize) + require.Equal(t, "fetched", rows[0].FetchStatus) + require.False(t, rows[0].FetchedAt.IsZero()) + + rows, err = s.ListAttachments(ctx, AttachmentListOptions{MessageID: "missing"}) + require.NoError(t, err) + require.Empty(t, rows) +} + +func seedAttachmentForGuild(ctx context.Context, s *Store, guildID, channelID, messageID, attachmentID string) error { + if err := s.UpsertGuild(ctx, GuildRecord{ID: guildID, Name: guildID, RawJSON: `{}`}); err != nil { + return err + } + if err := s.UpsertChannel(ctx, ChannelRecord{ID: channelID, GuildID: guildID, Kind: "text", Name: channelID, RawJSON: `{}`}); err != nil { + return err + } + if err := s.UpsertMember(ctx, MemberRecord{GuildID: guildID, UserID: "u1", Username: "peter", DisplayName: "Peter", RoleIDsJSON: `[]`, RawJSON: `{}`}); err != nil { + return err + } + return s.UpsertMessages(ctx, []MessageMutation{{ + Record: MessageRecord{ + ID: messageID, + GuildID: guildID, + ChannelID: channelID, + ChannelName: channelID, + AuthorID: "u1", + AuthorName: "Peter", + MessageType: 0, + CreatedAt: "2026-05-15T12:00:00Z", + Content: "attached", + NormalizedContent: "attached file.png", + HasAttachments: true, + RawJSON: `{}`, + }, + Attachments: []AttachmentRecord{{ + AttachmentID: attachmentID, + MessageID: messageID, + GuildID: guildID, + ChannelID: channelID, + AuthorID: "u1", + Filename: "file.png", + ContentType: "image/png", + Size: 7, + URL: "https://example.test/file.png", + }}, + }}) +} diff --git a/internal/store/store.go b/internal/store/store.go index 07c42bf..ddd0559 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -16,7 +16,7 @@ const ( timeLayout = time.RFC3339Nano messageFTSVersion = "2" memberFTSVersion = "1" - storeSchemaVersion = 2 + storeSchemaVersion = 3 ) var ErrSchemaVersionMismatch = errors.New("database schema version mismatch") @@ -175,7 +175,16 @@ func (s *Store) migrate(ctx context.Context) error { if err := s.applyQueryIndexMigration(ctx); err != nil { return err } - if err := s.setSchemaVersion(ctx, storeSchemaVersion); err != nil { + if err := s.setSchemaVersion(ctx, 2); err != nil { + return err + } + currentVersion = 2 + } + if currentVersion < 3 { + if err := s.applyAttachmentMediaMigration(ctx); err != nil { + return err + } + if err := s.setSchemaVersion(ctx, 3); err != nil { return err } } @@ -187,6 +196,9 @@ func (s *Store) migrate(ctx context.Context) error { if err := s.applyQueryIndexMigration(ctx); err != nil { return err } + if err := s.applyAttachmentMediaMigration(ctx); err != nil { + return err + } if err := s.ensureFTSRowIDs(ctx); err != nil { return err } @@ -360,6 +372,12 @@ func (s *Store) applyBaselineSchema(ctx context.Context) error { url text, proxy_url text, text_content text not null default '', + media_path text, + content_sha256 text, + content_size integer not null default 0, + fetched_at text, + fetch_status text not null default '', + fetch_error text not null default '', updated_at text not null );`, `create table if not exists mention_events ( @@ -444,6 +462,39 @@ func (s *Store) applyBaselineSchema(ctx context.Context) error { return tx.Commit() } +func (s *Store) applyAttachmentMediaMigration(ctx context.Context) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer rollback(tx) + for _, column := range []struct { + name string + sql string + }{ + {"media_path", `alter table message_attachments add column media_path text`}, + {"content_sha256", `alter table message_attachments add column content_sha256 text`}, + {"content_size", `alter table message_attachments add column content_size integer not null default 0`}, + {"fetched_at", `alter table message_attachments add column fetched_at text`}, + {"fetch_status", `alter table message_attachments add column fetch_status text not null default ''`}, + {"fetch_error", `alter table message_attachments add column fetch_error text not null default ''`}, + } { + ok, err := columnExists(ctx, tx, "message_attachments", column.name) + if err != nil { + return err + } + if !ok { + if _, err := tx.ExecContext(ctx, column.sql); err != nil { + return fmt.Errorf("add message_attachments.%s: %w", column.name, err) + } + } + } + if _, err := tx.ExecContext(ctx, `create index if not exists idx_attachments_sha256 on message_attachments(content_sha256)`); err != nil { + return fmt.Errorf("ensure attachment media index: %w", err) + } + return tx.Commit() +} + func (s *Store) applyQueryIndexMigration(ctx context.Context) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 77672fd..f8dc7c1 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -1140,7 +1140,7 @@ func TestOpenMigratesSchemaV1ToV2(t *testing.T) { var version int require.NoError(t, s.DB().QueryRowContext(ctx, `pragma user_version`).Scan(&version)) - require.Equal(t, 2, version) + require.Equal(t, storeSchemaVersion, version) _, rows, err := s.ReadOnlyQuery(ctx, "select provider, model, input_version, last_error, locked_at from embedding_jobs where message_id = 'm1'") require.NoError(t, err) @@ -1182,7 +1182,7 @@ func TestOpenMigratesUnversionedV1SchemaToV2(t *testing.T) { var version int require.NoError(t, s.DB().QueryRowContext(ctx, `pragma user_version`).Scan(&version)) - require.Equal(t, 2, version) + require.Equal(t, storeSchemaVersion, version) _, rows, err := s.ReadOnlyQuery(ctx, "select provider, model, input_version, last_error, locked_at from embedding_jobs where message_id = 'm1'") require.NoError(t, err) diff --git a/internal/store/write.go b/internal/store/write.go index 8eae067..9766a04 100644 --- a/internal/store/write.go +++ b/internal/store/write.go @@ -66,17 +66,23 @@ type MessageRecord struct { } type AttachmentRecord struct { - AttachmentID string - MessageID string - GuildID string - ChannelID string - AuthorID string - Filename string - ContentType string - Size int64 - URL string - ProxyURL string - TextContent string + AttachmentID string + MessageID string + GuildID string + ChannelID string + AuthorID string + Filename string + ContentType string + Size int64 + URL string + ProxyURL string + TextContent string + MediaPath string + ContentSHA256 string + ContentSize int64 + FetchedAt string + FetchStatus string + FetchError string } type MentionEventRecord struct { @@ -452,6 +458,10 @@ func appendEventTx(ctx context.Context, tx *sql.Tx, guildID, channelID, messageI } func replaceAttachmentsTx(ctx context.Context, tx *sql.Tx, messageID string, attachments []AttachmentRecord) error { + existing, err := existingAttachmentMediaTx(ctx, tx, messageID) + if err != nil { + return err + } if _, err := tx.ExecContext(ctx, `delete from message_attachments where message_id = ?`, messageID); err != nil { return err } @@ -462,14 +472,23 @@ func replaceAttachmentsTx(ctx context.Context, tx *sql.Tx, messageID string, att stmt, err := tx.PrepareContext(ctx, ` insert into message_attachments( attachment_id, message_id, guild_id, channel_id, author_id, filename, - content_type, size, url, proxy_url, text_content, updated_at - ) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + content_type, size, url, proxy_url, text_content, media_path, content_sha256, + content_size, fetched_at, fetch_status, fetch_error, updated_at + ) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) `) if err != nil { return err } defer func() { _ = stmt.Close() }() for _, attachment := range attachments { + if media, ok := existing[attachment.AttachmentID]; ok && attachment.MediaPath == "" { + attachment.MediaPath = media.MediaPath + attachment.ContentSHA256 = media.ContentSHA256 + attachment.ContentSize = media.ContentSize + attachment.FetchedAt = media.FetchedAt + attachment.FetchStatus = media.FetchStatus + attachment.FetchError = media.FetchError + } if _, err := stmt.ExecContext( ctx, attachment.AttachmentID, @@ -483,6 +502,12 @@ func replaceAttachmentsTx(ctx context.Context, tx *sql.Tx, messageID string, att nullable(attachment.URL), nullable(attachment.ProxyURL), attachment.TextContent, + nullable(attachment.MediaPath), + nullable(attachment.ContentSHA256), + attachment.ContentSize, + nullable(attachment.FetchedAt), + attachment.FetchStatus, + attachment.FetchError, now, ); err != nil { return err @@ -491,6 +516,36 @@ func replaceAttachmentsTx(ctx context.Context, tx *sql.Tx, messageID string, att return nil } +func existingAttachmentMediaTx(ctx context.Context, tx *sql.Tx, messageID string) (map[string]AttachmentRecord, error) { + rows, err := tx.QueryContext(ctx, ` + select attachment_id, coalesce(media_path, ''), coalesce(content_sha256, ''), + content_size, coalesce(fetched_at, ''), coalesce(fetch_status, ''), coalesce(fetch_error, '') + from message_attachments + where message_id = ? + `, messageID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := map[string]AttachmentRecord{} + for rows.Next() { + var record AttachmentRecord + if err := rows.Scan( + &record.AttachmentID, + &record.MediaPath, + &record.ContentSHA256, + &record.ContentSize, + &record.FetchedAt, + &record.FetchStatus, + &record.FetchError, + ); err != nil { + return nil, err + } + out[record.AttachmentID] = record + } + return out, rows.Err() +} + func replaceMentionEventsTx(ctx context.Context, tx *sql.Tx, messageID string, mentions []MentionEventRecord) error { if _, err := tx.ExecContext(ctx, `delete from mention_events where message_id = ?`, messageID); err != nil { return err