diff --git a/app/components/player/player.go b/app/components/player/player.go index e1e90a9..8204427 100644 --- a/app/components/player/player.go +++ b/app/components/player/player.go @@ -43,10 +43,35 @@ type PlayerParams struct { Source sources.Source // the source for this playback ViewOffset int // resume position in milliseconds + // Parent / grandparent ratingKeys. For episodes: the season and show. + // Empty for movies. Used by post-playback cache invalidation to clear + // the season's children listing. + ParentRatingKey string + GrandparentRatingKey string + // NextEpisode is the pre-resolved next episode (nil for movies or last episode). NextEpisode *NextEpisodeInfo } +// PlayerParamsForMetadata builds PlayerParams for a metadata item that has at +// least one playable media part. Caller is responsible for verifying that +// meta.Media[0].Part[0] exists. +func PlayerParamsForMetadata(ctx context.Context, meta *sources.Metadata, src sources.Source, win *gtk.Window, nextEp *NextEpisodeInfo) PlayerParams { + return PlayerParams{ + Ctx: ctx, + Title: meta.Title, + PartKey: meta.Media[0].Part[0].Key, + Window: win, + RatingKey: meta.RatingKey, + ParentRatingKey: meta.ParentRatingKey, + GrandparentRatingKey: meta.GrandparentRatingKey, + Media: meta.Media, + Source: src, + ViewOffset: meta.ViewOffset, + NextEpisode: nextEp, + } +} + // NewPlayer creates a video player with overlay controls. // The player reuses the application's main window — there is no separate // fullscreen modal. The user can still toggle fullscreen via the F/F11 key @@ -553,20 +578,25 @@ func NewPlayer(params PlayerParams) { playNextEpisode = func() { // Resolve the next-next episode before closing (context still alive). var nextNext *NextEpisodeInfo + var parent, grandparent string if nextInfo.Metadata != nil { nextNext = ResolveNextEpisode(ctx, src, nextInfo.Metadata) + parent = nextInfo.Metadata.ParentRatingKey + grandparent = nextInfo.Metadata.GrandparentRatingKey } closePlayer() NewPlayer(PlayerParams{ - Ctx: params.Ctx, - Title: nextInfo.Title, - PartKey: nextInfo.PartKey, - Window: params.Window, - RatingKey: nextInfo.RatingKey, - Media: nextInfo.Media, - Source: src, - ViewOffset: nextInfo.ViewOffset, - NextEpisode: nextNext, + Ctx: params.Ctx, + Title: nextInfo.Title, + PartKey: nextInfo.PartKey, + Window: params.Window, + RatingKey: nextInfo.RatingKey, + ParentRatingKey: parent, + GrandparentRatingKey: grandparent, + Media: nextInfo.Media, + Source: src, + ViewOffset: nextInfo.ViewOffset, + NextEpisode: nextNext, }) } @@ -816,18 +846,26 @@ func NewPlayer(params PlayerParams) { pcore.Pause() dur := currentDurationUs() ts := currentTimestampUs() + progressReported := false if dur > 0 { timeMs := int(ts / 1000) durationMs := int(dur / 1000) if err := src.UpdateProgress(ctx, params.RatingKey, sources.StateStopped, timeMs, durationMs); err != nil { slog.Error("failed to send final progress", "error", err) + } else { + progressReported = true } } if dur > 0 && ts > 0 && float64(ts)/float64(dur) > 0.9 { if err := src.Scrobble(ctx, params.RatingKey); err != nil { slog.Error("failed to scrobble", "error", err) + } else { + progressReported = true } } + if progressReported { + src.InvalidateAfterPlayback(params.RatingKey, params.ParentRatingKey, params.GrandparentRatingKey) + } // Detach paintable before pcore.Close — otherwise the picture // may snapshot a sink whose state is being freed (SIGSEGV in // gdk_paintable_snapshot). A typed-nil PaintableBase passes a diff --git a/app/pages/episode.go b/app/pages/episode.go index 19a85a7..a752cd9 100644 --- a/app/pages/episode.go +++ b/app/pages/episode.go @@ -77,17 +77,7 @@ func Episode(ctx context.Context, appCtx *appctx.AppContext, serverID, ratingKey ConnectClicked(func(b gtk.Button) { if len(meta.Media) > 0 && len(meta.Media[0].Part) > 0 { nextEp := player.ResolveNextEpisode(ctx, src, meta) - player.NewPlayer(player.PlayerParams{ - Ctx: ctx, - Title: meta.Title, - PartKey: meta.Media[0].Part[0].Key, - Window: appCtx.Window, - RatingKey: ratingKey, - Media: meta.Media, - Source: src, - ViewOffset: meta.ViewOffset, - NextEpisode: nextEp, - }) + player.NewPlayer(player.PlayerParamsForMetadata(ctx, meta, src, appCtx.Window, nextEp)) } }), ). @@ -119,6 +109,7 @@ func Episode(ctx context.Context, appCtx *appctx.AppContext, serverID, ratingKey }) return } + src.InvalidateAfterPlayback(ratingKey, meta.ParentRatingKey, meta.GrandparentRatingKey) schwifty.OnMainThreadOncePure(func() { b.SetSensitive(true) if watched { diff --git a/app/pages/movie.go b/app/pages/movie.go index 271d52f..4928023 100644 --- a/app/pages/movie.go +++ b/app/pages/movie.go @@ -87,16 +87,7 @@ func Movie(ctx context.Context, appCtx *appctx.AppContext, serverID, ratingKey s WithCSSClass("pill"). ConnectClicked(func(b gtk.Button) { if len(meta.Media) > 0 && len(meta.Media[0].Part) > 0 { - player.NewPlayer(player.PlayerParams{ - Ctx: ctx, - Title: meta.Title, - PartKey: meta.Media[0].Part[0].Key, - Window: appCtx.Window, - RatingKey: ratingKey, - Media: meta.Media, - Source: src, - ViewOffset: meta.ViewOffset, - }) + player.NewPlayer(player.PlayerParamsForMetadata(ctx, meta, src, appCtx.Window, nil)) } }), ). @@ -128,6 +119,7 @@ func Movie(ctx context.Context, appCtx *appctx.AppContext, serverID, ratingKey s }) return } + src.InvalidateAfterPlayback(ratingKey, meta.ParentRatingKey, meta.GrandparentRatingKey) schwifty.OnMainThreadOncePure(func() { b.SetSensitive(true) if watched { diff --git a/app/pages/season.go b/app/pages/season.go index 047447b..13c19df 100644 --- a/app/pages/season.go +++ b/app/pages/season.go @@ -68,17 +68,7 @@ func Season(ctx context.Context, appCtx *appctx.AppContext, serverID, ratingKey WithCSSClass("pill"). ConnectClicked(func(b gtk.Button) { nextEp := player.ResolveNextEpisode(ctx, src, ep) - player.NewPlayer(player.PlayerParams{ - Ctx: ctx, - Title: ep.Title, - PartKey: ep.Media[0].Part[0].Key, - Window: appCtx.Window, - RatingKey: ep.RatingKey, - Media: ep.Media, - Source: src, - ViewOffset: ep.ViewOffset, - NextEpisode: nextEp, - }) + player.NewPlayer(player.PlayerParamsForMetadata(ctx, ep, src, appCtx.Window, nextEp)) }), ) } diff --git a/app/pages/show.go b/app/pages/show.go index d24d966..c5d1c75 100644 --- a/app/pages/show.go +++ b/app/pages/show.go @@ -74,17 +74,7 @@ func Show(ctx context.Context, appCtx *appctx.AppContext, serverID, ratingKey st WithCSSClass("pill"). ConnectClicked(func(b gtk.Button) { nextEp := player.ResolveNextEpisode(ctx, src, ep) - player.NewPlayer(player.PlayerParams{ - Ctx: ctx, - Title: ep.Title, - PartKey: ep.Media[0].Part[0].Key, - Window: appCtx.Window, - RatingKey: ep.RatingKey, - Media: ep.Media, - Source: src, - ViewOffset: ep.ViewOffset, - NextEpisode: nextEp, - }) + player.NewPlayer(player.PlayerParamsForMetadata(ctx, ep, src, appCtx.Window, nextEp)) }), ) } diff --git a/app/preference/performance.go b/app/preference/performance.go index afc3938..2c8c41d 100644 --- a/app/preference/performance.go +++ b/app/preference/performance.go @@ -37,10 +37,18 @@ func (p *PerformanceSettings) BindCacheLibraries(target *gobject.Object, propert p.settings.Bind("cache-libraries", target, property, gio.GSettingsBindNoSensitivityValue) } +func (p *PerformanceSettings) ShouldCacheLibraries() bool { + return p.settings.GetBoolean("cache-libraries") +} + func (p *PerformanceSettings) BindCacheMetadata(target *gobject.Object, property string) { p.settings.Bind("cache-metadata", target, property, gio.GSettingsBindNoSensitivityValue) } +func (p *PerformanceSettings) ShouldCacheMetadata() bool { + return p.settings.GetBoolean("cache-metadata") +} + // Navigation func (p *PerformanceSettings) BindMaxRouterHistorySize(target *gobject.Object, property string) { diff --git a/app/sources/manager.go b/app/sources/manager.go index 9672e9c..efc4e6a 100644 --- a/app/sources/manager.go +++ b/app/sources/manager.go @@ -8,7 +8,6 @@ import ( "github.com/0skillallluck/scanline/app/secrets" "github.com/0skillallluck/scanline/internal/signals" - "github.com/0skillallluck/scanline/provider/plex" "github.com/0skillallluck/scanline/provider/plex/auth" "github.com/0skillallluck/scanline/provider/plex/watchlist" "github.com/google/uuid" @@ -55,7 +54,7 @@ func NewManager() *Manager { // Assume reachable until a connection probe says otherwise srv.Reachable = srv.URL != "" if srv.Enabled && srv.URL != "" { - client := plex.NewClient(srv.URL, tokenForServer(token, srv), acct.ClientID) + client := newPlexClient(srv.URL, tokenForServer(token, srv), acct.ClientID) m.sources[srv.ID] = NewPlexSource(srv.ID, srv.Name, client) } } @@ -174,7 +173,7 @@ func (m *Manager) AddPlexAccount(ctx context.Context, token, username, clientID m.accounts = append(m.accounts, acct) for _, srv := range servers { if srv.Enabled && srv.Reachable && srv.URL != "" { - client := plex.NewClient(srv.URL, tokenForServer(token, srv), clientID) + client := newPlexClient(srv.URL, tokenForServer(token, srv), clientID) m.sources[srv.ID] = NewPlexSource(srv.ID, srv.Name, client) } } @@ -240,7 +239,7 @@ func (m *Manager) SetServerEnabled(accountID, serverID string, enabled bool) { slog.Warn("server URL not cached, need to refresh servers", "server", srv.Name) } if srv.URL != "" && token != "" { - client := plex.NewClient(srv.URL, tokenForServer(token, srv), acct.ClientID) + client := newPlexClient(srv.URL, tokenForServer(token, srv), acct.ClientID) m.sources[srv.ID] = NewPlexSource(srv.ID, srv.Name, client) } } else { @@ -355,7 +354,7 @@ func (m *Manager) RefreshServers(ctx context.Context) { srv.Reachable = false } else { srv.URL = url - client := plex.NewClient(url, tokenForServer(token, srv), info.clientID) + client := newPlexClient(url, tokenForServer(token, srv), info.clientID) newSources[srv.ID] = NewPlexSource(srv.ID, srv.Name, client) } } diff --git a/app/sources/plex.go b/app/sources/plex.go index ea1c176..6045684 100644 --- a/app/sources/plex.go +++ b/app/sources/plex.go @@ -4,9 +4,20 @@ import ( "context" "net/url" + "github.com/0skillallluck/scanline/app/preference" "github.com/0skillallluck/scanline/provider/plex" ) +// newPlexClient constructs a Plex client wired to the user's caching +// preferences. The two predicates are evaluated per request, so toggling the +// "Cache Libraries" / "Cache Metadata" switches takes effect immediately. +func newPlexClient(serverURL, token, clientID string) *plex.Client { + return plex.NewClient(serverURL, token, clientID, plex.WithCachePolicy( + preference.Performance().ShouldCacheLibraries, + preference.Performance().ShouldCacheMetadata, + )) +} + // PlexSource adapts a plex.Client to the Source interface. type PlexSource struct { client *plex.Client @@ -107,3 +118,7 @@ func (s *PlexSource) Unscrobble(ctx context.Context, ratingKey string) error { func (s *PlexSource) UpdateProgress(ctx context.Context, ratingKey string, state PlaybackState, timeMs, durationMs int) error { return s.client.Timeline.UpdateProgress(ctx, ratingKey, state, timeMs, durationMs) } + +func (s *PlexSource) InvalidateAfterPlayback(ratingKey, parentRatingKey, grandparentRatingKey string) { + s.client.InvalidateAfterPlayback(ratingKey, parentRatingKey, grandparentRatingKey) +} diff --git a/app/sources/source.go b/app/sources/source.go index da8a560..fd5bdda 100644 --- a/app/sources/source.go +++ b/app/sources/source.go @@ -70,4 +70,11 @@ type Source interface { // UpdateProgress reports playback position to the server. UpdateProgress(ctx context.Context, ratingKey string, state PlaybackState, timeMs, durationMs int) error + + // InvalidateAfterPlayback evicts cached metadata that may have changed + // after a playback action (Scrobble, Unscrobble, or playback stop). + // Pass empty strings for parent / grandparent if not known. Callers should + // invoke this after a successful Scrobble / Unscrobble / stop so that the + // next render sees fresh watch state. + InvalidateAfterPlayback(ratingKey, parentRatingKey, grandparentRatingKey string) } diff --git a/assets/meta/go.mod.yml b/assets/meta/go.mod.yml index 5f165f3..ed268eb 100644 --- a/assets/meta/go.mod.yml +++ b/assets/meta/go.mod.yml @@ -52,6 +52,11 @@ strip-components: 5 type: archive url: https://proxy.golang.org/github.com/yeqown/go-qrcode/writer/standard/@v/v1.3.0.zip +- dest: vendor/golang.org/x/sync + sha256: a35481e5ae73e51ef01cf42bcad09c3b73bb3a4abb67d495d4a575021541ed02 + strip-components: 3 + type: archive + url: https://proxy.golang.org/golang.org/x/sync/@v/v0.12.0.zip - dest: vendor/codeberg.org/puregotk/purego sha256: 0d5fb9e11934b151ac357a9c0660e4dcd287d2f93db0829356415f6e6a8ef84b strip-components: 3 diff --git a/assets/meta/modules.txt b/assets/meta/modules.txt index 4fe8b96..82043b7 100644 --- a/assets/meta/modules.txt +++ b/assets/meta/modules.txt @@ -90,6 +90,9 @@ golang.org/x/image/font golang.org/x/image/font/basicfont golang.org/x/image/math/f64 golang.org/x/image/math/fixed +# golang.org/x/sync v0.12.0 +## explicit; go 1.23.0 +golang.org/x/sync/singleflight # golang.org/x/sys v0.43.0 ## explicit; go 1.25.0 golang.org/x/sys/windows diff --git a/flake.nix b/flake.nix index 73cd560..3ab14c2 100644 --- a/flake.nix +++ b/flake.nix @@ -113,7 +113,7 @@ pname = "scanline"; version = "0.4.0"; src = pkgs.lib.cleanSource ./.; - vendorHash = "sha256-VCILJElIVjGr3mZZvDcgA/HaEYFsDY272mv2+tFnZyY="; + vendorHash = "sha256-zp+DQoo5jwJJrvC6KGdWBAvWtR55+6jslVlAAkfcU1U="; ldflags = [ "-X \"github.com/0skillallluck/scanline/app/dialogs/about.Commit=${ diff --git a/go.mod b/go.mod index 14bfa08..69d93b3 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/leonelquinteros/gotext v1.7.2 github.com/yeqown/go-qrcode/v2 v2.2.5 github.com/yeqown/go-qrcode/writer/standard v1.3.0 + golang.org/x/sync v0.12.0 ) require ( diff --git a/go.sum b/go.sum index c3474cb..95e488e 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWB golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= golang.org/x/image v0.38.0 h1:5l+q+Y9JDC7mBOMjo4/aPhMDcxEptsX+Tt3GgRQRPuE= golang.org/x/image v0.38.0/go.mod h1:/3f6vaXC+6CEanU4KJxbcUZyEePbyKbaLoDOe4ehFYY= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/provider/plex/auth/pin.go b/provider/plex/auth/pin.go index 376f4e0..012dfca 100644 --- a/provider/plex/auth/pin.go +++ b/provider/plex/auth/pin.go @@ -2,7 +2,6 @@ package auth import ( "context" - "encoding/json" "errors" "fmt" "log/slog" @@ -10,10 +9,12 @@ import ( "net/url" "strings" "time" + + "github.com/0skillallluck/scanline/utils/httputils/request" ) const ( - plexTVBaseURL = "https://plex.tv" + plexTVBaseURL = "https://plex.tv" authAppBaseURL = "https://app.plex.tv/auth#" ) @@ -45,30 +46,28 @@ type Pin struct { // The clientIdentifier should be a unique identifier for the application instance. func RequestPin(clientIdentifier string) (*Pin, error) { formValues := url.Values{ - "strong": {"true"}, - "X-Plex-Product": {"Scanline"}, - "X-Plex-Client-Identifier": {clientIdentifier}, - } - - req, err := http.NewRequest(http.MethodPost, plexTVBaseURL+"/api/v2/pins", strings.NewReader(formValues.Encode())) - if err != nil { - return nil, fmt.Errorf("creating pin request: %w", err) + "strong": {"true"}, + "X-Plex-Product": {"Scanline"}, + "X-Plex-Client-Identifier": {clientIdentifier}, } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - resp, err := http.DefaultClient.Do(req) + resp, err := request.NewRequest(http.MethodPost, plexTVBaseURL+"/api/v2/pins"). + WithHeaders(map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + }). + WithBody(strings.NewReader(formValues.Encode())). + Do() if err != nil { return nil, fmt.Errorf("executing pin request: %w", err) } - defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusCreated { return nil, fmt.Errorf("failed to request PIN: %s", resp.Status) } var pin Pin - if err := json.NewDecoder(resp.Body).Decode(&pin); err != nil { + if err := resp.JSON(&pin); err != nil { return nil, fmt.Errorf("decoding pin response: %w", err) } slog.Debug("plex: PIN requested", "pin_id", pin.ID, "expires_in", pin.ExpiresIn) @@ -78,29 +77,23 @@ func RequestPin(clientIdentifier string) (*Pin, error) { // CheckPin checks the status of a PIN and returns the updated PIN information. // If the user has authorized the PIN, the returned Pin will have AuthToken populated. func CheckPin(pinID int, pinCode, clientIdentifier string) (*Pin, error) { - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/api/v2/pins/%d", plexTVBaseURL, pinID), nil) - if err != nil { - return nil, fmt.Errorf("creating check request: %w", err) - } - req.Header.Set("Accept", "application/json") - req.Header.Set("X-Plex-Client-Identifier", clientIdentifier) - - params := req.URL.Query() - params.Set("code", pinCode) - req.URL.RawQuery = params.Encode() - - resp, err := http.DefaultClient.Do(req) + resp, err := request.NewRequest(http.MethodGet, fmt.Sprintf("%s/api/v2/pins/%d", plexTVBaseURL, pinID)). + WithHeaders(map[string]string{ + "Accept": "application/json", + "X-Plex-Client-Identifier": clientIdentifier, + }). + WithQuery(map[string]string{"code": pinCode}). + Do() if err != nil { return nil, fmt.Errorf("executing check request: %w", err) } - defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to check PIN: %s", resp.Status) } var pin Pin - if err := json.NewDecoder(resp.Body).Decode(&pin); err != nil { + if err := resp.JSON(&pin); err != nil { return nil, fmt.Errorf("decoding check response: %w", err) } return &pin, nil @@ -134,8 +127,8 @@ func PollPin(ctx context.Context, pinID int, pinCode, clientIdentifier string, i // AuthAppURL returns the URL where the user should be directed to authorize the PIN. func AuthAppURL(clientIdentifier, code, product string) string { params := url.Values{ - "clientID": {clientIdentifier}, - "code": {code}, + "clientID": {clientIdentifier}, + "code": {code}, "context[device][product]": {product}, } return authAppBaseURL + "?" + params.Encode() diff --git a/provider/plex/auth/plextv.go b/provider/plex/auth/plextv.go index 457b359..e850913 100644 --- a/provider/plex/auth/plextv.go +++ b/provider/plex/auth/plextv.go @@ -2,10 +2,11 @@ package auth import ( "context" - "encoding/json" "fmt" "log/slog" "net/http" + + "github.com/0skillallluck/scanline/utils/httputils/request" ) // Resource represents a Plex resource (server, player, etc.) discovered from plex.tv. @@ -89,14 +90,12 @@ type ResourceConnection struct { // PlexTV provides access to plex.tv API endpoints for server discovery. type PlexTV struct { clientIdentifier string - httpClient *http.Client } // NewPlexTV creates a new PlexTV client with the given client identifier. func NewPlexTV(clientIdentifier string) *PlexTV { return &PlexTV{ clientIdentifier: clientIdentifier, - httpClient: http.DefaultClient, } } @@ -104,32 +103,29 @@ func NewPlexTV(clientIdentifier string) *PlexTV { // The token parameter is the authentication token obtained from PIN authentication. func (p *PlexTV) DiscoverServers(ctx context.Context, token string) ([]Resource, error) { slog.Debug("plex: discovering servers") - req, err := http.NewRequestWithContext(ctx, http.MethodGet, plexTVBaseURL+"/api/v2/resources", nil) - if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - - req.Header.Set("Accept", "application/json") - req.Header.Set("X-Plex-Token", token) - req.Header.Set("X-Plex-Client-Identifier", p.clientIdentifier) - - params := req.URL.Query() - params.Set("includeHttps", "1") - params.Set("includeRelay", "1") - req.URL.RawQuery = params.Encode() - - resp, err := p.httpClient.Do(req) + resp, err := request.NewRequest(http.MethodGet, plexTVBaseURL+"/api/v2/resources"). + WithContext(ctx). + WithHeaders(map[string]string{ + "Accept": "application/json", + "X-Plex-Token": token, + "X-Plex-Client-Identifier": p.clientIdentifier, + }). + WithLogging("X-Plex-Token"). + WithQuery(map[string]string{ + "includeHttps": "1", + "includeRelay": "1", + }). + Do() if err != nil { return nil, fmt.Errorf("executing request: %w", err) } - defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to discover servers: %s", resp.Status) } var resources []Resource - if err := json.NewDecoder(resp.Body).Decode(&resources); err != nil { + if err := resp.JSON(&resources); err != nil { return nil, fmt.Errorf("decoding response: %w", err) } slog.Debug("plex: servers discovered", "resource_count", len(resources)) @@ -151,27 +147,27 @@ type User struct { // GetUser retrieves the Plex account information for the authenticated user. func (p *PlexTV) GetUser(ctx context.Context, token string) (*User, error) { slog.Debug("plex: fetching user info") - req, err := http.NewRequestWithContext(ctx, http.MethodGet, plexTVBaseURL+"/api/v2/user", nil) - if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - - req.Header.Set("Accept", "application/json") - req.Header.Set("X-Plex-Token", token) - req.Header.Set("X-Plex-Client-Identifier", p.clientIdentifier) - - resp, err := p.httpClient.Do(req) + resp, err := request.NewRequest(http.MethodGet, plexTVBaseURL+"/api/v2/user"). + WithContext(ctx). + WithHeaders(map[string]string{ + "Accept": "application/json", + "X-Plex-Token": token, + "X-Plex-Client-Identifier": p.clientIdentifier, + }). + WithLogging("X-Plex-Token"). + WithInMemoryCaching(60 * 60). + WithCacheKey(token). + Do() if err != nil { return nil, fmt.Errorf("executing request: %w", err) } - defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to get user: %s", resp.Status) } var user User - if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { + if err := resp.JSON(&user); err != nil { return nil, fmt.Errorf("decoding response: %w", err) } slog.Debug("plex: user fetched", "username", user.Username) diff --git a/provider/plex/base/base.go b/provider/plex/base/base.go index 12c3eb8..a4eed7d 100644 --- a/provider/plex/base/base.go +++ b/provider/plex/base/base.go @@ -8,11 +8,26 @@ import ( "github.com/0skillallluck/scanline/utils/httputils/request" ) +// CachePolicy controls whether the Plex API helpers apply HTTP response +// caching. The two predicates are evaluated per request, so toggling the +// underlying user preference takes effect on the next call. Either field may +// be nil, which disables caching for the associated category. +type CachePolicy struct { + // Libraries gates caching of library-listing-shaped endpoints + // (sections, content, hubs, search, etc.). + Libraries func() bool + + // Metadata gates caching of per-item metadata endpoints + // (item details, children, markers, server info). + Metadata func() bool +} + // Base contains the shared configuration for all Plex API services. type Base struct { BaseURL string Token string ClientID string + Cache CachePolicy } // Request builds a new request with Plex headers pre-configured. @@ -38,6 +53,96 @@ func (b *Base) GetWithQuery(ctx context.Context, path string, query map[string]s WithQuery(query) } +// GetCachedLibraries is Get with disk + memory caching applied iff the +// Libraries cache policy is enabled. The token is folded into the cache key so +// responses don't cross between accounts pointing at the same server URL. +func (b *Base) GetCachedLibraries(ctx context.Context, path string, ttlSeconds int) *request.Request { + req := b.Get(ctx, path) + if b.cacheLibrariesEnabled() { + req = req.WithCaching(ttlSeconds).WithCacheKey(b.Token) + } + return req +} + +// GetCachedLibrariesWithQuery is GetWithQuery with disk + memory caching gated +// by the Libraries cache policy. +func (b *Base) GetCachedLibrariesWithQuery(ctx context.Context, path string, query map[string]string, ttlSeconds int) *request.Request { + req := b.GetWithQuery(ctx, path, query) + if b.cacheLibrariesEnabled() { + req = req.WithCaching(ttlSeconds).WithCacheKey(b.Token) + } + return req +} + +// GetCachedMetadata is Get with disk + memory caching gated by the Metadata +// cache policy. +func (b *Base) GetCachedMetadata(ctx context.Context, path string, ttlSeconds int) *request.Request { + req := b.Get(ctx, path) + if b.cacheMetadataEnabled() { + req = req.WithCaching(ttlSeconds).WithCacheKey(b.Token) + } + return req +} + +// GetCachedMetadataWithQuery is GetWithQuery with disk + memory caching gated +// by the Metadata cache policy. +func (b *Base) GetCachedMetadataWithQuery(ctx context.Context, path string, query map[string]string, ttlSeconds int) *request.Request { + req := b.GetWithQuery(ctx, path, query) + if b.cacheMetadataEnabled() { + req = req.WithCaching(ttlSeconds).WithCacheKey(b.Token) + } + return req +} + +// GetMemCachedLibraries is Get with in-memory-only caching gated by the +// Libraries cache policy. Use for short-lived data that shouldn't survive +// app restarts. +func (b *Base) GetMemCachedLibraries(ctx context.Context, path string, ttlSeconds int) *request.Request { + req := b.Get(ctx, path) + if b.cacheLibrariesEnabled() { + req = req.WithInMemoryCaching(ttlSeconds).WithCacheKey(b.Token) + } + return req +} + +// GetMemCachedLibrariesWithQuery is GetWithQuery with in-memory-only caching +// gated by the Libraries cache policy. +func (b *Base) GetMemCachedLibrariesWithQuery(ctx context.Context, path string, query map[string]string, ttlSeconds int) *request.Request { + req := b.GetWithQuery(ctx, path, query) + if b.cacheLibrariesEnabled() { + req = req.WithInMemoryCaching(ttlSeconds).WithCacheKey(b.Token) + } + return req +} + +// GetMemCachedMetadata is Get with in-memory-only caching gated by the +// Metadata cache policy. +func (b *Base) GetMemCachedMetadata(ctx context.Context, path string, ttlSeconds int) *request.Request { + req := b.Get(ctx, path) + if b.cacheMetadataEnabled() { + req = req.WithInMemoryCaching(ttlSeconds).WithCacheKey(b.Token) + } + return req +} + +// GetMemCachedMetadataWithQuery is GetWithQuery with in-memory-only caching +// gated by the Metadata cache policy. +func (b *Base) GetMemCachedMetadataWithQuery(ctx context.Context, path string, query map[string]string, ttlSeconds int) *request.Request { + req := b.GetWithQuery(ctx, path, query) + if b.cacheMetadataEnabled() { + req = req.WithInMemoryCaching(ttlSeconds).WithCacheKey(b.Token) + } + return req +} + +func (b *Base) cacheLibrariesEnabled() bool { + return b.Cache.Libraries != nil && b.Cache.Libraries() +} + +func (b *Base) cacheMetadataEnabled() bool { + return b.Cache.Metadata != nil && b.Cache.Metadata() +} + // Post returns a POST request. func (b *Base) Post(ctx context.Context, path string) *request.Request { return b.Request(http.MethodPost, path).WithContext(ctx) diff --git a/provider/plex/client.go b/provider/plex/client.go index 4a9646c..ffb1957 100644 --- a/provider/plex/client.go +++ b/provider/plex/client.go @@ -14,6 +14,7 @@ import ( "github.com/0skillallluck/scanline/provider/plex/search" "github.com/0skillallluck/scanline/provider/plex/server" "github.com/0skillallluck/scanline/provider/plex/timeline" + "github.com/0skillallluck/scanline/utils/cacheutils" "github.com/0skillallluck/scanline/utils/httputils/request" ) @@ -49,19 +50,35 @@ type Client struct { Timeline *timeline.Timeline } +// Option configures a Client at construction time. +type Option func(*Client) + +// WithCachePolicy configures HTTP response caching for the cacheable Plex +// endpoints. The two predicates are evaluated per request, so flipping the +// underlying user preference takes effect on the next call. Pass nil for +// either argument to disable caching for that category. +func WithCachePolicy(libraries, metadata func() bool) Option { + return func(c *Client) { + c.base.Cache = base.CachePolicy{ + Libraries: libraries, + Metadata: metadata, + } + } +} + // NewClient creates a new Plex Media Server client. // // The serverURL should be the base URL of the Plex server (e.g., "http://localhost:32400"). // The token is the authentication token for the server. // The clientID should be a unique identifier for the application instance (typically a UUID). -func NewClient(serverURL, token, clientID string) *Client { +func NewClient(serverURL, token, clientID string, opts ...Option) *Client { b := &base.Base{ BaseURL: serverURL, Token: token, ClientID: clientID, } - return &Client{ + c := &Client{ base: b, Server: server.New(b), Library: library.New(b), @@ -70,6 +87,32 @@ func NewClient(serverURL, token, clientID string) *Client { Playlists: playlists.New(b), Timeline: timeline.New(b), } + for _, opt := range opts { + opt(c) + } + return c +} + +// InvalidateAfterPlayback evicts cached metadata for the given item and the +// hubs whose contents may have shifted as a result of a playback action +// (Scrobble, Unscrobble, or playback stop). parentRatingKey and +// grandparentRatingKey may be empty when not known; they extend invalidation +// up the episode → season → show chain. +func (c *Client) InvalidateAfterPlayback(ratingKey, parentRatingKey, grandparentRatingKey string) { + if ratingKey != "" { + // Catches both Library.Metadata (no query) and Library.Markers + // (?includeMarkers=1) since they share this URL prefix. + cacheutils.DeleteByPrefix(c.base.BaseURL + "/library/metadata/" + ratingKey) + } + if parentRatingKey != "" { + cacheutils.DeleteByPrefix(c.base.BaseURL + "/library/metadata/" + parentRatingKey + "/children") + } + if grandparentRatingKey != "" { + cacheutils.DeleteByPrefix(c.base.BaseURL + "/library/metadata/" + grandparentRatingKey + "/children") + } + // Hubs that surface watched / in-progress items. + cacheutils.DeleteByPrefix(c.base.BaseURL + "/hubs/promoted") + cacheutils.DeleteByPrefix(c.base.BaseURL + "/hubs/continueWatching") } // ServerURL returns the base URL of the Plex server. diff --git a/provider/plex/hubs/continue_watching.go b/provider/plex/hubs/continue_watching.go index 2880ecd..29a71e3 100644 --- a/provider/plex/hubs/continue_watching.go +++ b/provider/plex/hubs/continue_watching.go @@ -5,10 +5,9 @@ import "context" // ContinueWatching returns the continue watching hub. // // This contains items the user has started but not finished watching. -// No caching is applied as this data changes frequently. func (h *Hubs) ContinueWatching(ctx context.Context) ([]Hub, error) { var resp mediaContainerResponse[hubsContainer] - err := h.Get(ctx, "/hubs/continueWatching"). + err := h.GetMemCachedLibraries(ctx, "/hubs/continueWatching", 60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/hubs/home.go b/provider/plex/hubs/home.go index c9469a7..8e31b5d 100644 --- a/provider/plex/hubs/home.go +++ b/provider/plex/hubs/home.go @@ -17,7 +17,7 @@ func (h *Hubs) Home(ctx context.Context, count int) ([]Hub, error) { query["count"] = strconv.Itoa(count) } - err := h.GetWithQuery(ctx, "/hubs/promoted", query). + err := h.GetCachedLibrariesWithQuery(ctx, "/hubs/promoted", query, 15*60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/hubs/related.go b/provider/plex/hubs/related.go index f44a011..248388b 100644 --- a/provider/plex/hubs/related.go +++ b/provider/plex/hubs/related.go @@ -7,7 +7,7 @@ import "context" // The id parameter is the rating key of the item. func (h *Hubs) Related(ctx context.Context, id string) ([]Hub, error) { var resp mediaContainerResponse[hubsContainer] - err := h.Get(ctx, "/hubs/metadata/"+id+"/related"). + err := h.GetCachedMetadata(ctx, "/hubs/metadata/"+id+"/related", 60*60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/hubs/section.go b/provider/plex/hubs/section.go index 36916b5..6e2d102 100644 --- a/provider/plex/hubs/section.go +++ b/provider/plex/hubs/section.go @@ -5,7 +5,7 @@ import "context" // Section returns the hubs for a specific library section. func (h *Hubs) Section(ctx context.Context, sectionID string) ([]Hub, error) { var resp mediaContainerResponse[hubsContainer] - err := h.Get(ctx, "/hubs/sections/"+sectionID). + err := h.GetCachedLibraries(ctx, "/hubs/sections/"+sectionID, 15*60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/invalidate_test.go b/provider/plex/invalidate_test.go new file mode 100644 index 0000000..7c4bf5f --- /dev/null +++ b/provider/plex/invalidate_test.go @@ -0,0 +1,80 @@ +package plex + +import ( + "testing" + + "github.com/0skillallluck/scanline/utils/cacheutils" +) + +// TestInvalidateAfterPlayback_RemovesItemAndHubs verifies that the prefix +// invalidation drops cached entries for the affected item, the parent's +// children listing, the home hub, and the continue-watching hub — and leaves +// unrelated entries intact. +func TestInvalidateAfterPlayback_RemovesItemAndHubs(t *testing.T) { + cacheutils.SetFileCacheDir(t.TempDir()) + cacheutils.Clear() //nolint:errcheck + + const baseURL = "https://server" + + // Seed a variety of cache entries. + entries := map[string]struct { + strategy cacheutils.Strategy + }{ + baseURL + "/library/metadata/42": {cacheutils.Layered}, + baseURL + "/library/metadata/42?includeMarkers=1": {cacheutils.Layered}, + baseURL + "/library/metadata/parent/children": {cacheutils.Layered}, + baseURL + "/library/metadata/grandparent/children": {cacheutils.Layered}, + baseURL + "/library/metadata/99": {cacheutils.Layered}, // unrelated + baseURL + "/library/sections/all": {cacheutils.Layered}, // unrelated + baseURL + "/hubs/promoted": {cacheutils.Layered}, + baseURL + "/hubs/continueWatching": {cacheutils.MemoryOnly}, + } + for k, v := range entries { + _ = cacheutils.Store(k, []byte("v:"+k), v.strategy, 600) + } + + c := NewClient(baseURL, "tok", "cid") + c.InvalidateAfterPlayback("42", "parent", "grandparent") + + expectGone := []string{ + baseURL + "/library/metadata/42", + baseURL + "/library/metadata/42?includeMarkers=1", + baseURL + "/library/metadata/parent/children", + baseURL + "/library/metadata/grandparent/children", + baseURL + "/hubs/promoted", + baseURL + "/hubs/continueWatching", + } + for _, k := range expectGone { + if _, ok := cacheutils.Get(k, entries[k].strategy, 600); ok { + t.Errorf("expected %q to be invalidated", k) + } + } + + expectKept := []string{ + baseURL + "/library/metadata/99", + baseURL + "/library/sections/all", + } + for _, k := range expectKept { + if _, ok := cacheutils.Get(k, entries[k].strategy, 600); !ok { + t.Errorf("expected %q to remain cached", k) + } + } +} + +// TestInvalidateAfterPlayback_EmptyParentsAreSkipped checks that empty parent +// IDs don't accidentally invalidate the entire children prefix. +func TestInvalidateAfterPlayback_EmptyParentsAreSkipped(t *testing.T) { + cacheutils.SetFileCacheDir(t.TempDir()) + cacheutils.Clear() //nolint:errcheck + + const baseURL = "https://server" + _ = cacheutils.Store(baseURL+"/library/metadata/parent/children", []byte("v"), cacheutils.Layered, 600) + + c := NewClient(baseURL, "tok", "cid") + c.InvalidateAfterPlayback("42", "", "") + + // The unrelated /children entry must survive empty-parent invalidation. + if _, ok := cacheutils.Get(baseURL+"/library/metadata/parent/children", cacheutils.Layered, 600); !ok { + t.Error("empty parent ratingKey should not blow away unrelated children entries") + } +} diff --git a/provider/plex/library/cache_test.go b/provider/plex/library/cache_test.go new file mode 100644 index 0000000..cc6e0ae --- /dev/null +++ b/provider/plex/library/cache_test.go @@ -0,0 +1,132 @@ +package library + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/0skillallluck/scanline/provider/plex/base" + "github.com/0skillallluck/scanline/utils/cacheutils" +) + +// TestSections_CachingHonorsLibrariesPolicy verifies the libraries policy +// gates caching for Library.Sections end-to-end. +func TestSections_CachingHonorsLibrariesPolicy(t *testing.T) { + cacheutils.SetFileCacheDir(t.TempDir()) + cacheutils.Clear() //nolint:errcheck + + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/library/sections/all") { + t.Errorf("unexpected request path: %s", r.URL.Path) + } + callCount++ + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"MediaContainer":{"Directory":[]}}`)) + })) + defer server.Close() + + cachingOn := true + lib := New(&base.Base{ + BaseURL: server.URL, + Token: "tok", + ClientID: "cid", + Cache: base.CachePolicy{ + Libraries: func() bool { return cachingOn }, + Metadata: func() bool { return false }, + }, + }) + + ctx := context.Background() + + if _, err := lib.Sections(ctx); err != nil { + t.Fatalf("first call: %v", err) + } + if _, err := lib.Sections(ctx); err != nil { + t.Fatalf("second call: %v", err) + } + if callCount != 1 { + t.Errorf("expected second call to hit cache: callCount=%d", callCount) + } + + // Toggle caching off → next call must hit the server. + cachingOn = false + if _, err := lib.Sections(ctx); err != nil { + t.Fatalf("third call: %v", err) + } + if callCount != 2 { + t.Errorf("expected cache bypass after toggling off: callCount=%d", callCount) + } +} + +// TestMetadata_CachingScopedByToken verifies that two clients with different +// tokens against the same server URL do NOT share cached responses. +func TestMetadata_CachingScopedByToken(t *testing.T) { + cacheutils.SetFileCacheDir(t.TempDir()) + cacheutils.Clear() //nolint:errcheck + + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"MediaContainer":{"Metadata":[{"ratingKey":"42","title":"Item"}]}}`)) + })) + defer server.Close() + + always := func() bool { return true } + libA := New(&base.Base{BaseURL: server.URL, Token: "tokA", ClientID: "cid", Cache: base.CachePolicy{Metadata: always}}) + libB := New(&base.Base{BaseURL: server.URL, Token: "tokB", ClientID: "cid", Cache: base.CachePolicy{Metadata: always}}) + + ctx := context.Background() + + if _, err := libA.Metadata(ctx, "42"); err != nil { + t.Fatalf("libA call: %v", err) + } + if _, err := libB.Metadata(ctx, "42"); err != nil { + t.Fatalf("libB call: %v", err) + } + if callCount != 2 { + t.Errorf("expected separate cache entries per token: callCount=%d", callCount) + } + + // Repeat libA — should hit cache. + if _, err := libA.Metadata(ctx, "42"); err != nil { + t.Fatalf("libA repeat: %v", err) + } + if callCount != 2 { + t.Errorf("expected libA repeat to be cached: callCount=%d", callCount) + } +} + +// TestMetadata_NoCachingWhenPolicyOff verifies that with the metadata policy +// disabled, every call hits the server. +func TestMetadata_NoCachingWhenPolicyOff(t *testing.T) { + cacheutils.SetFileCacheDir(t.TempDir()) + cacheutils.Clear() //nolint:errcheck + + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"MediaContainer":{"Metadata":[{"ratingKey":"42"}]}}`)) + })) + defer server.Close() + + lib := New(&base.Base{ + BaseURL: server.URL, + Token: "tok", + ClientID: "cid", + }) + + ctx := context.Background() + for i := 0; i < 3; i++ { + if _, err := lib.Metadata(ctx, "42"); err != nil { + t.Fatalf("call %d: %v", i, err) + } + } + if callCount != 3 { + t.Errorf("expected no caching when policy is nil: callCount=%d", callCount) + } +} diff --git a/provider/plex/library/children.go b/provider/plex/library/children.go index 920dabb..89e43a9 100644 --- a/provider/plex/library/children.go +++ b/provider/plex/library/children.go @@ -7,7 +7,7 @@ import "context" // The id parameter is the rating key of the parent (e.g., season ID to get episodes). func (l *Library) Children(ctx context.Context, id string) ([]Metadata, error) { var resp mediaContainerResponse[metadataContainer] - err := l.Get(ctx, "/library/metadata/"+id+"/children"). + err := l.GetCachedMetadata(ctx, "/library/metadata/"+id+"/children", 24*60*60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/library/collection_items.go b/provider/plex/library/collection_items.go index dce0132..9b9633d 100644 --- a/provider/plex/library/collection_items.go +++ b/provider/plex/library/collection_items.go @@ -5,7 +5,7 @@ import "context" // CollectionItems returns the items in a collection. func (l *Library) CollectionItems(ctx context.Context, collectionID string) ([]Metadata, error) { var resp mediaContainerResponse[metadataContainer] - err := l.Get(ctx, "/library/collections/"+collectionID+"/items"). + err := l.GetCachedLibraries(ctx, "/library/collections/"+collectionID+"/items", 60*60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/library/collections.go b/provider/plex/library/collections.go index c5be348..19c0fdf 100644 --- a/provider/plex/library/collections.go +++ b/provider/plex/library/collections.go @@ -6,7 +6,7 @@ import "context" func (l *Library) Collections(ctx context.Context, sectionID string) ([]Metadata, error) { var resp mediaContainerResponse[metadataContainer] query := map[string]string{"sectionID": sectionID} - err := l.GetWithQuery(ctx, "/library/collections", query). + err := l.GetCachedLibrariesWithQuery(ctx, "/library/collections", query, 60*60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/library/content.go b/provider/plex/library/content.go index 72ea10a..80c56d5 100644 --- a/provider/plex/library/content.go +++ b/provider/plex/library/content.go @@ -41,7 +41,7 @@ func (l *Library) Content(ctx context.Context, sectionID string, opts *ContentOp } var resp mediaContainerResponse[metadataContainer] - err := l.GetWithQuery(ctx, "/library/sections/"+sectionID+"/all", query). + err := l.GetCachedLibrariesWithQuery(ctx, "/library/sections/"+sectionID+"/all", query, 60*60). DoAndDecode(&resp) if err != nil { return nil, 0, err diff --git a/provider/plex/library/markers.go b/provider/plex/library/markers.go index e2c9996..27b542b 100644 --- a/provider/plex/library/markers.go +++ b/provider/plex/library/markers.go @@ -9,9 +9,9 @@ import "context" // returns 404 on most Plex Media Server versions. func (l *Library) Markers(ctx context.Context, id string) ([]Marker, error) { var resp mediaContainerResponse[metadataContainer] - err := l.GetWithQuery(ctx, "/library/metadata/"+id, map[string]string{ + err := l.GetCachedMetadataWithQuery(ctx, "/library/metadata/"+id, map[string]string{ "includeMarkers": "1", - }).DoAndDecode(&resp) + }, 24*60*60).DoAndDecode(&resp) if err != nil { return nil, err } diff --git a/provider/plex/library/metadata.go b/provider/plex/library/metadata.go index 1b9f1ee..61c8ea2 100644 --- a/provider/plex/library/metadata.go +++ b/provider/plex/library/metadata.go @@ -24,7 +24,7 @@ func (e *EmptyResultError) Error() string { // Returns EmptyResultError if the item is not found. func (l *Library) Metadata(ctx context.Context, id string) (*Metadata, error) { var resp mediaContainerResponse[metadataContainer] - err := l.Get(ctx, "/library/metadata/"+id). + err := l.GetCachedMetadata(ctx, "/library/metadata/"+id, 24*60*60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/library/seasons.go b/provider/plex/library/seasons.go index 3355952..8f77407 100644 --- a/provider/plex/library/seasons.go +++ b/provider/plex/library/seasons.go @@ -7,7 +7,7 @@ import "context" // The id parameter is the rating key of the show. func (l *Library) Seasons(ctx context.Context, id string) ([]Metadata, error) { var resp mediaContainerResponse[metadataContainer] - err := l.Get(ctx, "/library/metadata/"+id+"/children"). + err := l.GetCachedMetadata(ctx, "/library/metadata/"+id+"/children", 24*60*60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/library/section.go b/provider/plex/library/section.go index 3c08c0c..bbb2a73 100644 --- a/provider/plex/library/section.go +++ b/provider/plex/library/section.go @@ -5,7 +5,7 @@ import "context" // Section returns information about a specific library section. func (l *Library) Section(ctx context.Context, sectionID string) (*LibrarySection, error) { var resp mediaContainerResponse[LibrarySection] - err := l.Get(ctx, "/library/sections/"+sectionID). + err := l.GetCachedLibraries(ctx, "/library/sections/"+sectionID, 24*60*60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/library/sections.go b/provider/plex/library/sections.go index d0f64c2..fc71d83 100644 --- a/provider/plex/library/sections.go +++ b/provider/plex/library/sections.go @@ -5,7 +5,7 @@ import "context" // Sections returns all library sections. func (l *Library) Sections(ctx context.Context) ([]LibrarySection, error) { var resp mediaContainerResponse[librarySectionsContainer] - err := l.Get(ctx, "/library/sections/all"). + err := l.GetCachedLibraries(ctx, "/library/sections/all", 24*60*60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/playlists/items.go b/provider/plex/playlists/items.go index 0885f04..13e0131 100644 --- a/provider/plex/playlists/items.go +++ b/provider/plex/playlists/items.go @@ -9,7 +9,7 @@ import ( // Items returns the items in a playlist. func (p *Playlists) Items(ctx context.Context, playlistID string) ([]library.Metadata, error) { var resp mediaContainerResponse[metadataContainer] - err := p.Get(ctx, "/playlists/"+playlistID+"/items"). + err := p.GetMemCachedLibraries(ctx, "/playlists/"+playlistID+"/items", 60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/playlists/list.go b/provider/plex/playlists/list.go index 8d38b18..81f2d23 100644 --- a/provider/plex/playlists/list.go +++ b/provider/plex/playlists/list.go @@ -3,9 +3,14 @@ package playlists import "context" // List returns all playlists. +// +// Note: Mutations (Create/Update/Delete) bypass the in-memory cache, but +// their effect on subsequent List calls only becomes visible after the +// short TTL — there's no automatic invalidation hook on the playlist +// mutation endpoints today. func (p *Playlists) List(ctx context.Context) ([]Playlist, error) { var resp mediaContainerResponse[playlistsContainer] - err := p.Get(ctx, "/playlists"). + err := p.GetMemCachedLibraries(ctx, "/playlists", 60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/search/query.go b/provider/plex/search/query.go index a6cdb7c..d90fa31 100644 --- a/provider/plex/search/query.go +++ b/provider/plex/search/query.go @@ -17,7 +17,7 @@ func (s *Search) Query(ctx context.Context, query string, limit int) ([]Hub, err } var resp mediaContainerResponse[hubsContainer] - err := s.GetWithQuery(ctx, "/hubs/search", params). + err := s.GetMemCachedLibrariesWithQuery(ctx, "/hubs/search", params, 5*60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/server/info.go b/provider/plex/server/info.go index bb2cc30..9e17987 100644 --- a/provider/plex/server/info.go +++ b/provider/plex/server/info.go @@ -7,7 +7,7 @@ import "context" // This includes the server name, version, platform, and configuration details. func (s *Server) Info(ctx context.Context) (*ServerInfo, error) { var resp serverInfoContainer - err := s.Get(ctx, "/"). + err := s.GetCachedMetadata(ctx, "/", 24*60*60). DoAndDecode(&resp) if err != nil { return nil, err diff --git a/provider/plex/watchlist/watchlist.go b/provider/plex/watchlist/watchlist.go index ebed701..a3f4b91 100644 --- a/provider/plex/watchlist/watchlist.go +++ b/provider/plex/watchlist/watchlist.go @@ -2,9 +2,10 @@ package watchlist import ( "context" - "encoding/json" "fmt" "net/http" + + "github.com/0skillallluck/scanline/utils/httputils/request" ) const discoverBaseURL = "https://discover.provider.plex.tv" @@ -19,14 +20,13 @@ type Item struct { Thumb string `json:"thumb,omitempty"` Art string `json:"art,omitempty"` Summary string `json:"summary,omitempty"` - GUID string `json:"guid,omitempty"` // e.g. "plex://movie/..." + GUID string `json:"guid,omitempty"` // e.g. "plex://movie/..." } // Client fetches watchlist data from the Plex Discover API. type Client struct { token string clientID string - http *http.Client } // NewClient creates a new watchlist client with account-level credentials. @@ -34,7 +34,6 @@ func NewClient(token, clientID string) *Client { return &Client{ token: token, clientID: clientID, - http: http.DefaultClient, } } @@ -51,33 +50,31 @@ func (c *Client) List(ctx context.Context, filter string) ([]Item, error) { filter = "all" } - url := discoverBaseURL + "/library/sections/watchlist/" + filter - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - - req.Header.Set("Accept", "application/json") - req.Header.Set("X-Plex-Token", c.token) - req.Header.Set("X-Plex-Client-Identifier", c.clientID) - - q := req.URL.Query() - q.Set("includeCollections", "1") - q.Set("includeExternalMedia", "1") - req.URL.RawQuery = q.Encode() - - resp, err := c.http.Do(req) + resp, err := request.NewRequest(http.MethodGet, discoverBaseURL+"/library/sections/watchlist/"+filter). + WithContext(ctx). + WithHeaders(map[string]string{ + "Accept": "application/json", + "X-Plex-Token": c.token, + "X-Plex-Client-Identifier": c.clientID, + }). + WithLogging("X-Plex-Token"). + WithQuery(map[string]string{ + "includeCollections": "1", + "includeExternalMedia": "1", + }). + WithInMemoryCaching(5 * 60). + WithCacheKey(c.token). + Do() if err != nil { return nil, fmt.Errorf("executing request: %w", err) } - defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("watchlist request failed: %s", resp.Status) } var result watchlistResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := resp.JSON(&result); err != nil { return nil, fmt.Errorf("decoding response: %w", err) } diff --git a/utils/cacheutils/cache.go b/utils/cacheutils/cache.go index b33f148..67fd6c0 100644 --- a/utils/cacheutils/cache.go +++ b/utils/cacheutils/cache.go @@ -3,6 +3,8 @@ package cacheutils import ( "crypto/sha256" "encoding/base64" + "strings" + "sync" ) // Strategy specifies the caching strategy to use. @@ -14,6 +16,16 @@ const ( Layered // Memory L1 + File L2 with promotion ) +// rawKeyIndex maps raw cache keys to their hashed counterparts (and back), +// supporting DeleteByPrefix since hashed keys are not prefix-comparable. The +// index is in-process only; on cold start prefix invalidation falls back to +// TTL expiry until entries are re-cached. +var ( + rawKeyMu sync.RWMutex + rawKeyIndex = make(map[string]string) // raw → hashed + hashedToRawKey = make(map[string]string) // hashed → raw +) + // Get retrieves cached data using the given strategy. // For Layered strategy, checks memory first, then file (promoting hits to memory). func Get(key string, strategy Strategy, ttl int) ([]byte, bool) { @@ -39,8 +51,10 @@ func Get(key string, strategy Strategy, ttl int) ([]byte, bool) { return nil, false } - // Promote file hit to memory - storeInMemory(hashedKey, data) + // Re-index on promotion: rawKeyIndex is in-process only, so file hits + // across restarts need re-adding for DeleteByPrefix to find them. + storeInMemory(hashedKey, data, ttl, true) + rememberRawKey(key, hashedKey) return data, true } @@ -53,9 +67,10 @@ func Store(key string, data []byte, strategy Strategy, ttl int) error { } hashedKey := hashKey(key) + rememberRawKey(key, hashedKey) // Always store in memory - storeInMemory(hashedKey, data) + storeInMemory(hashedKey, data, ttl, strategy == Layered) // For memory-only, we're done if strategy == MemoryOnly { @@ -71,11 +86,43 @@ func Delete(key string) { hashedKey := hashKey(key) deleteFromMemory(hashedKey) deleteFromFile(hashedKey) + forgetRawKey(key) +} + +// DeleteByPrefix removes every cache entry whose raw key starts with prefix. +// Entries Stored between the scan and the delete phase aren't affected — +// they're newer than the invalidation request. +func DeleteByPrefix(prefix string) { + rawKeyMu.RLock() + matches := make([]struct{ raw, hashed string }, 0) + for raw, hashed := range rawKeyIndex { + if strings.HasPrefix(raw, prefix) { + matches = append(matches, struct{ raw, hashed string }{raw, hashed}) + } + } + rawKeyMu.RUnlock() + + if len(matches) == 0 { + return + } + + rawKeyMu.Lock() + for _, m := range matches { + delete(rawKeyIndex, m.raw) + delete(hashedToRawKey, m.hashed) + } + rawKeyMu.Unlock() + + for _, m := range matches { + deleteFromMemory(m.hashed) + deleteFromFile(m.hashed) + } } // Clear removes all cache entries from all layers. func Clear() error { clearMemory() + clearRawKeyIndex() return clearFileDir() } @@ -83,3 +130,37 @@ func hashKey(key string) string { hash := sha256.Sum256([]byte(key)) return base64.URLEncoding.EncodeToString(hash[:]) } + +func rememberRawKey(raw, hashed string) { + rawKeyMu.Lock() + rawKeyIndex[raw] = hashed + hashedToRawKey[hashed] = raw + rawKeyMu.Unlock() +} + +func forgetRawKey(raw string) { + rawKeyMu.Lock() + if hashed, ok := rawKeyIndex[raw]; ok { + delete(rawKeyIndex, raw) + delete(hashedToRawKey, hashed) + } + rawKeyMu.Unlock() +} + +// forgetHashedKey drops the index entry for a given hashed key. Called from +// LRU eviction so the index doesn't outlive the data. +func forgetHashedKey(hashed string) { + rawKeyMu.Lock() + if raw, ok := hashedToRawKey[hashed]; ok { + delete(hashedToRawKey, hashed) + delete(rawKeyIndex, raw) + } + rawKeyMu.Unlock() +} + +func clearRawKeyIndex() { + rawKeyMu.Lock() + rawKeyIndex = make(map[string]string) + hashedToRawKey = make(map[string]string) + rawKeyMu.Unlock() +} diff --git a/utils/cacheutils/cache_test.go b/utils/cacheutils/cache_test.go index 4aebe19..269a744 100644 --- a/utils/cacheutils/cache_test.go +++ b/utils/cacheutils/cache_test.go @@ -128,7 +128,7 @@ func TestLayered_MemoryHitSkipsFile(t *testing.T) { hashedKey := hashKey(key) // Store only in memory - storeInMemory(hashedKey, data) + storeInMemory(hashedKey, data, 0, false) // Get should return from memory without needing file got, found := Get(key, Layered, 60) diff --git a/utils/cacheutils/memory.go b/utils/cacheutils/memory.go index d6af60a..dd83848 100644 --- a/utils/cacheutils/memory.go +++ b/utils/cacheutils/memory.go @@ -1,43 +1,125 @@ package cacheutils -import "sync" +import ( + "container/list" + "sync" + "time" +) + +// MemoryCacheMaxBytes is the upper bound on bytes held in the in-memory cache. +// When exceeded, least-recently-used entries are evicted until the total fits. +const MemoryCacheMaxBytes = 64 * 1024 * 1024 // 64 MiB + +type memoryEntry struct { + key string + data []byte + expiresAt time.Time // zero value = no expiration + // layered: a file copy exists, so eviction must keep the rawKeyIndex + // entry alive for DeleteByPrefix to reach it. + layered bool +} var ( - memoryCache = make(map[string][]byte) - memoryCacheMu sync.RWMutex + memoryMu sync.Mutex + memoryList = list.New() + memoryIndex = make(map[string]*list.Element) + memoryBytes int64 ) +// getFromMemory returns cached bytes if present and not expired, promoting the +// entry to the front of the LRU list. func getFromMemory(hashedKey string) ([]byte, bool) { - memoryCacheMu.RLock() - defer memoryCacheMu.RUnlock() + memoryMu.Lock() + defer memoryMu.Unlock() - data, ok := memoryCache[hashedKey] + el, ok := memoryIndex[hashedKey] if !ok { return nil, false } - return copyBytes(data), true + entry := el.Value.(*memoryEntry) + if !entry.expiresAt.IsZero() && time.Now().After(entry.expiresAt) { + removeMemoryElement(el) + return nil, false + } + + memoryList.MoveToFront(el) + return copyBytes(entry.data), true } -func storeInMemory(hashedKey string, data []byte) { - memoryCacheMu.Lock() - defer memoryCacheMu.Unlock() +// storeInMemory stores data under the given hashed key. ttl == 0 means no +// expiration (still subject to LRU). The byte cap is soft for the just- +// inserted entry — an oversized entry survives at least until the next Store, +// keeping the cache useful when a single response exceeds the cap. +func storeInMemory(hashedKey string, data []byte, ttl int, layered bool) { + memoryMu.Lock() + defer memoryMu.Unlock() + + var expiresAt time.Time + if ttl > 0 { + expiresAt = time.Now().Add(time.Duration(ttl) * time.Second) + } + + var inserted *list.Element + if el, ok := memoryIndex[hashedKey]; ok { + entry := el.Value.(*memoryEntry) + memoryBytes -= int64(len(entry.data)) + entry.data = copyBytes(data) + entry.expiresAt = expiresAt + entry.layered = layered + memoryBytes += int64(len(entry.data)) + memoryList.MoveToFront(el) + inserted = el + } else { + entry := &memoryEntry{ + key: hashedKey, + data: copyBytes(data), + expiresAt: expiresAt, + layered: layered, + } + inserted = memoryList.PushFront(entry) + memoryIndex[hashedKey] = inserted + memoryBytes += int64(len(entry.data)) + } - memoryCache[hashedKey] = copyBytes(data) + for memoryBytes > MemoryCacheMaxBytes { + oldest := memoryList.Back() + if oldest == nil || oldest == inserted { + break + } + removeMemoryElement(oldest) + } } func deleteFromMemory(hashedKey string) { - memoryCacheMu.Lock() - defer memoryCacheMu.Unlock() + memoryMu.Lock() + defer memoryMu.Unlock() - delete(memoryCache, hashedKey) + if el, ok := memoryIndex[hashedKey]; ok { + removeMemoryElement(el) + } } func clearMemory() { - memoryCacheMu.Lock() - defer memoryCacheMu.Unlock() + memoryMu.Lock() + defer memoryMu.Unlock() + + memoryList.Init() + memoryIndex = make(map[string]*list.Element) + memoryBytes = 0 +} - memoryCache = make(map[string][]byte) +// removeMemoryElement evicts an LRU entry. Memory-only entries also drop the +// rawKeyIndex; layered entries keep theirs so DeleteByPrefix can still reach +// the file copy. Caller must hold memoryMu. +func removeMemoryElement(el *list.Element) { + entry := el.Value.(*memoryEntry) + memoryList.Remove(el) + delete(memoryIndex, entry.key) + memoryBytes -= int64(len(entry.data)) + if !entry.layered { + forgetHashedKey(entry.key) + } } func copyBytes(b []byte) []byte { diff --git a/utils/cacheutils/memory_test.go b/utils/cacheutils/memory_test.go new file mode 100644 index 0000000..0132368 --- /dev/null +++ b/utils/cacheutils/memory_test.go @@ -0,0 +1,266 @@ +package cacheutils + +import ( + "strings" + "testing" + "time" +) + +func TestMemoryEntry_RespectsTTL(t *testing.T) { + clearMemory() + + storeInMemory("ttl-key", []byte("data"), 1, false) + + if _, ok := getFromMemory("ttl-key"); !ok { + t.Fatal("entry should be present immediately after storing") + } + + time.Sleep(1100 * time.Millisecond) + + if _, ok := getFromMemory("ttl-key"); ok { + t.Error("entry should have expired") + } +} + +func TestMemoryEntry_TTLZeroNeverExpires(t *testing.T) { + clearMemory() + + storeInMemory("indef-key", []byte("forever"), 0, false) + + time.Sleep(50 * time.Millisecond) + + if _, ok := getFromMemory("indef-key"); !ok { + t.Error("TTL=0 entry should not expire") + } +} + +func TestMemoryEntry_LRUEvictionRespectsByteCap(t *testing.T) { + clearMemory() + t.Cleanup(clearMemory) + + // Insert 65 entries of 1 MiB each — total 65 MiB > 64 MiB cap. The + // oldest entries should be evicted. + const oneMiB = 1024 * 1024 + payload := make([]byte, oneMiB) + for i := byte(0); i < 65; i++ { + payload[0] = i // make each payload distinguishable + storeInMemory(string([]byte{'k', i}), payload, 0, false) + } + + if memoryBytes > MemoryCacheMaxBytes { + t.Errorf("memory bytes %d exceeds cap %d", memoryBytes, MemoryCacheMaxBytes) + } + + // First-inserted key should be gone. + if _, ok := getFromMemory(string([]byte{'k', 0})); ok { + t.Error("oldest entry should have been evicted") + } + + // Most recent key should still be present. + if _, ok := getFromMemory(string([]byte{'k', 64})); !ok { + t.Error("most recent entry should still be present") + } +} + +func TestMemoryEntry_GetMovesToFront(t *testing.T) { + clearMemory() + t.Cleanup(clearMemory) + + // Fill the cache to capacity with 64 entries of 1 MiB each. + const oneMiB = 1024 * 1024 + payload := make([]byte, oneMiB) + for i := byte(0); i < 64; i++ { + storeInMemory(string([]byte{'k', i}), payload, 0, false) + } + + // Touch entry 0 — it becomes most recently used; entry 1 takes its + // place as the least recently used. + if _, ok := getFromMemory(string([]byte{'k', 0})); !ok { + t.Fatal("entry 0 should be present") + } + + // Insert one more entry to push us over the cap by 1 MiB. The LRU + // (entry 1, since 0 was promoted) should be evicted. + storeInMemory("new", payload, 0, false) + + if _, ok := getFromMemory(string([]byte{'k', 0})); !ok { + t.Error("entry 0 should have survived because Get promoted it") + } + if _, ok := getFromMemory(string([]byte{'k', 1})); ok { + t.Error("entry 1 should have been evicted as the LRU") + } +} + +func TestDeleteByPrefix_RemovesMatchingEntries(t *testing.T) { + SetFileCacheDir(t.TempDir()) + Clear() //nolint:errcheck + + keys := []string{ + "https://server/library/metadata/123", + "https://server/library/metadata/123?includeMarkers=1", + "https://server/library/metadata/456", + "https://server/hubs/promoted", + } + for _, k := range keys { + _ = Store(k, []byte("v:"+k), Layered, 60) + } + + DeleteByPrefix("https://server/library/metadata/123") + + for _, k := range keys { + _, ok := Get(k, Layered, 60) + shouldExist := !strings.HasPrefix(k, "https://server/library/metadata/123") + if shouldExist && !ok { + t.Errorf("non-matching key %q should still be present", k) + } + if !shouldExist && ok { + t.Errorf("matching key %q should have been deleted", k) + } + } +} + +func TestDeleteByPrefix_RemovesFromBothLayers(t *testing.T) { + SetFileCacheDir(t.TempDir()) + Clear() //nolint:errcheck + + key := "https://server/foo/bar" + _ = Store(key, []byte("v"), Layered, 60) + + // Verify present in both layers + hashed := hashKey(key) + if _, ok := getFromMemory(hashed); !ok { + t.Fatal("expected entry in memory before DeleteByPrefix") + } + if _, ok := getFromFile(hashed, 60); !ok { + t.Fatal("expected entry in file cache before DeleteByPrefix") + } + + DeleteByPrefix("https://server/foo/") + + if _, ok := getFromMemory(hashed); ok { + t.Error("entry should be gone from memory after DeleteByPrefix") + } + if _, ok := getFromFile(hashed, 60); ok { + t.Error("entry should be gone from file cache after DeleteByPrefix") + } +} + +func TestLRUEviction_CleansRawKeyIndex(t *testing.T) { + SetFileCacheDir(t.TempDir()) + Clear() //nolint:errcheck + + const oneMiB = 1024 * 1024 + payload := make([]byte, oneMiB) + + // Fill exactly to capacity then push one over so a single eviction occurs. + for i := byte(0); i < 64; i++ { + _ = Store("key-"+string([]byte{i}), payload, MemoryOnly, 0) + } + _ = Store("key-evictor", payload, MemoryOnly, 0) + + rawKeyMu.RLock() + rawCount := len(rawKeyIndex) + hashedCount := len(hashedToRawKey) + rawKeyMu.RUnlock() + + if rawCount > 64 { + t.Errorf("rawKeyIndex should not retain entries past LRU eviction: got %d", rawCount) + } + if rawCount != hashedCount { + t.Errorf("rawKeyIndex (%d) and hashedToRawKey (%d) should stay in sync", rawCount, hashedCount) + } +} + +// TestFileHit_ReindexesRawKey simulates a process restart (clearing the +// in-process raw-key index and the memory cache while leaving the file cache +// intact). After a Get re-indexes the entry, DeleteByPrefix must be able to +// find and delete the file copy. +func TestFileHit_ReindexesRawKey(t *testing.T) { + SetFileCacheDir(t.TempDir()) + Clear() //nolint:errcheck + + key := "https://server/library/metadata/42" + if err := Store(key, []byte("v"), Layered, 600); err != nil { + t.Fatalf("Store: %v", err) + } + + // Simulate a process restart: in-memory state goes away, file cache survives. + clearMemory() + clearRawKeyIndex() + + // First read after "restart" — should hit the file and re-index. + if _, ok := Get(key, Layered, 600); !ok { + t.Fatal("expected file-cache hit after simulated restart") + } + + rawKeyMu.RLock() + hashed, indexed := rawKeyIndex[key] + rawKeyMu.RUnlock() + if !indexed { + t.Fatal("file-hit Get should re-add the raw key to the index") + } + if hashed != hashKey(key) { + t.Errorf("re-indexed mapping mismatch: got %q want %q", hashed, hashKey(key)) + } + + // Invalidation should now find and delete the file copy. + DeleteByPrefix("https://server/library/metadata/") + if _, ok := getFromFile(hashKey(key), 600); ok { + t.Error("DeleteByPrefix should have removed the file-cache entry") + } +} + +// TestLayeredEviction_PreservesInvalidation verifies that LRU eviction of a +// Layered entry's memory copy keeps the rawKeyIndex entry intact so +// DeleteByPrefix can still reach the file copy. +func TestLayeredEviction_PreservesInvalidation(t *testing.T) { + SetFileCacheDir(t.TempDir()) + Clear() //nolint:errcheck + + const oneMiB = 1024 * 1024 + smallPayload := []byte("v") + bigPayload := make([]byte, oneMiB) + + target := "https://server/library/metadata/42" + if err := Store(target, smallPayload, Layered, 600); err != nil { + t.Fatalf("Store target: %v", err) + } + + // Push enough big entries through to evict the target from memory. + for i := byte(0); i < 65; i++ { + _ = Store("filler-"+string([]byte{i}), bigPayload, MemoryOnly, 0) + } + + if _, inMem := getFromMemory(hashKey(target)); inMem { + t.Fatal("target should have been LRU-evicted from memory") + } + if _, onDisk := getFromFile(hashKey(target), 600); !onDisk { + t.Fatal("target's file copy should have survived memory eviction") + } + + // Index entry must still exist so DeleteByPrefix can find the file copy. + rawKeyMu.RLock() + _, indexed := rawKeyIndex[target] + rawKeyMu.RUnlock() + if !indexed { + t.Fatal("layered entry's raw key should NOT be dropped on memory eviction") + } + + DeleteByPrefix("https://server/library/metadata/") + if _, onDisk := getFromFile(hashKey(target), 600); onDisk { + t.Error("DeleteByPrefix should have deleted the file copy") + } +} + +func TestDeleteByPrefix_NoMatchIsNoop(t *testing.T) { + SetFileCacheDir(t.TempDir()) + Clear() //nolint:errcheck + + _ = Store("keep-me", []byte("v"), Layered, 60) + + DeleteByPrefix("nonexistent-prefix") + + if _, ok := Get("keep-me", Layered, 60); !ok { + t.Error("DeleteByPrefix with non-matching prefix should leave entries intact") + } +} diff --git a/utils/httputils/request/request.go b/utils/httputils/request/request.go index a6a05d3..b4d86a5 100644 --- a/utils/httputils/request/request.go +++ b/utils/httputils/request/request.go @@ -6,34 +6,40 @@ import ( "log/slog" "net/http" "net/url" + "strings" "time" "github.com/0skillallluck/scanline/utils/cacheutils" "github.com/0skillallluck/scanline/utils/httputils/response" + "golang.org/x/sync/singleflight" ) // Request represents a chainable HTTP request builder. type Request struct { - method string - url string - headers http.Header - query url.Values - body io.Reader - ctx context.Context - cancel context.CancelFunc - timeout time.Duration - timeoutSet bool - client *http.Client - cacheStrategy cacheutils.Strategy - cacheTTL int - logging bool - redactHeaders []string - err error + method string + url string + headers http.Header + query url.Values + body io.Reader + ctx context.Context + cancel context.CancelFunc + timeout time.Duration + timeoutSet bool + client *http.Client + cacheStrategy cacheutils.Strategy + cacheTTL int + cacheKeyExtras []string + logging bool + redactHeaders []string + err error } // DefaultTimeout is the default request timeout. const DefaultTimeout = 60 * time.Second +// requestGroup deduplicates concurrent in-flight requests sharing a cache key. +var requestGroup singleflight.Group + // NewRequest creates a new Request with the given method and URL. // By default, requests have a 60-second timeout. Use WithTimeout to override. func NewRequest(method, rawURL string) *Request { @@ -58,24 +64,97 @@ func (r *Request) Do() (*response.Response, error) { return nil, r.err } - // Check cache for GET requests if r.method == http.MethodGet && r.cacheStrategy != cacheutils.None { - cacheKey := r.buildCacheKey() + return r.doCached() + } + return r.doFetch() +} + +// doCached handles the cached path: cache check, then singleflight-wrapped +// fetch + store. Each caller unmarshals its own response from the shared +// bytes so callers can safely mutate the returned response. +func (r *Request) doCached() (*response.Response, error) { + cacheKey := r.buildCacheKey() + + if resp, ok := r.cachedResponse(cacheKey); ok { + return resp, nil + } + + ch := requestGroup.DoChan(cacheKey, func() (any, error) { + // Re-check now that we hold the slot — and validate, since corrupt + // bytes here must fall through to a fresh fetch. if data, found := cacheutils.Get(cacheKey, r.cacheStrategy, r.cacheTTL); found { - if resp, err := unmarshalResponse(data); err == nil { - if r.logging { - slog.Debug("HTTP cache hit", - "method", r.method, - "url", r.url, - "cache_key", cacheKey, - ) - } - return resp, nil + if _, err := unmarshalResponse(data); err == nil { + return data, nil } + cacheutils.Delete(cacheKey) + } + resp, err := r.doFetch() + if err != nil { + return nil, err + } + data, err := marshalResponse(resp) + if err != nil { + return nil, err } + if resp.IsSuccess() { + if storeErr := cacheutils.Store(cacheKey, data, r.cacheStrategy, r.cacheTTL); storeErr != nil { + slog.Debug("Failed to cache response", + "error", storeErr, + "cache_key", cacheKey, + ) + } + } + return data, nil + }) + + // Followers honor their own context: a slow leader's fetch must not + // pin a canceled or timed-out follower. The leader runs on its own ctx. + ctx := r.ctx + if ctx == nil { + ctx = context.Background() } + ctx, cancel := context.WithTimeout(ctx, r.timeout) + defer cancel() - // Apply timeout to the context right before execution + select { + case res := <-ch: + if res.Err != nil { + return nil, res.Err + } + if res.Shared && r.logging { + slog.Debug("HTTP request shared via singleflight", "cache_key", cacheKey) + } + return unmarshalResponse(res.Val.([]byte)) + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// cachedResponse returns a fresh response from cache if present and valid. +// Corrupt entries are deleted so subsequent paths can refetch. +func (r *Request) cachedResponse(cacheKey string) (*response.Response, bool) { + data, found := cacheutils.Get(cacheKey, r.cacheStrategy, r.cacheTTL) + if !found { + return nil, false + } + resp, err := unmarshalResponse(data) + if err != nil { + cacheutils.Delete(cacheKey) + return nil, false + } + if r.logging { + slog.Debug("HTTP cache hit", + "method", r.method, + "url", r.url, + "cache_key", cacheKey, + ) + } + return resp, true +} + +// doFetch executes the HTTP request directly without consulting the cache. +func (r *Request) doFetch() (*response.Response, error) { ctx := r.ctx if ctx == nil { ctx = context.Background() @@ -83,47 +162,38 @@ func (r *Request) Do() (*response.Response, error) { ctx, cancel := context.WithTimeout(ctx, r.timeout) defer cancel() - // Build the URL with query parameters reqURL, err := r.buildURL() if err != nil { return nil, err } - // Create the HTTP request req, err := http.NewRequestWithContext(ctx, r.method, reqURL, r.body) if err != nil { return nil, err } - - // Set headers req.Header = r.headers - // Log request if enabled start := time.Now() if r.logging { r.logRequest(req) } - // Get the client client := r.client if client == nil { client = DefaultClient() } - // Execute the request resp, err := client.Do(req) if err != nil { return nil, err } defer func() { _ = resp.Body.Close() }() - // Read the body body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } - // Build the response result := &response.Response{ StatusCode: resp.StatusCode, Status: resp.Status, @@ -131,24 +201,10 @@ func (r *Request) Do() (*response.Response, error) { Body: body, } - // Log response if enabled if r.logging { r.logResponse(result, time.Since(start)) } - // Cache successful GET responses - if r.method == http.MethodGet && r.cacheStrategy != cacheutils.None && result.IsSuccess() { - cacheKey := r.buildCacheKey() - if data, err := marshalResponse(result); err == nil { - if err := cacheutils.Store(cacheKey, data, r.cacheStrategy, r.cacheTTL); err != nil { - slog.Debug("Failed to cache response", - "error", err, - "cache_key", cacheKey, - ) - } - } - } - return result, nil } @@ -186,10 +242,18 @@ func (r *Request) buildURL() (string, error) { return parsed.String(), nil } -// buildCacheKey generates a cache key from URL and query parameters. +// buildCacheKey generates a cache key from URL, query parameters, and any +// extras added via WithCacheKey. The whole string is later SHA256-hashed by +// cacheutils, so the format only needs to be deterministic. func (r *Request) buildCacheKey() string { reqURL, _ := r.buildURL() - return reqURL + if len(r.cacheKeyExtras) == 0 { + return reqURL + } + parts := make([]string, 0, len(r.cacheKeyExtras)+1) + parts = append(parts, reqURL) + parts = append(parts, r.cacheKeyExtras...) + return strings.Join(parts, "|") } // logRequest logs the outgoing request details. diff --git a/utils/httputils/request/stampede_test.go b/utils/httputils/request/stampede_test.go new file mode 100644 index 0000000..5e6f360 --- /dev/null +++ b/utils/httputils/request/stampede_test.go @@ -0,0 +1,196 @@ +package request + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/0skillallluck/scanline/utils/cacheutils" +) + +// TestSingleflight_DedupesConcurrentCacheMisses verifies that N concurrent +// callers requesting the same uncached URL produce a single backend fetch. +func TestSingleflight_DedupesConcurrentCacheMisses(t *testing.T) { + cacheutils.Clear() //nolint:errcheck + + var serverHits int64 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&serverHits, 1) + // Hold long enough for concurrent callers to pile up on the + // singleflight slot. + time.Sleep(50 * time.Millisecond) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"k":"v"}`)) + })) + defer server.Close() + + const concurrent = 20 + var wg sync.WaitGroup + wg.Add(concurrent) + + bodies := make([][]byte, concurrent) + for i := 0; i < concurrent; i++ { + go func(i int) { + defer wg.Done() + resp, err := NewRequest(http.MethodGet, server.URL). + WithInMemoryCaching(60). + Do() + if err != nil { + t.Errorf("goroutine %d: %v", i, err) + return + } + bodies[i] = resp.Body + }(i) + } + wg.Wait() + + if got := atomic.LoadInt64(&serverHits); got != 1 { + t.Errorf("expected 1 server hit, got %d", got) + } + + for i, b := range bodies { + if string(b) != `{"k":"v"}` { + t.Errorf("goroutine %d body = %q, want %q", i, b, `{"k":"v"}`) + } + } +} + +// TestSingleflight_FollowerHonorsContextCancellation verifies that a follower +// waiting on a slow leader's fetch can abandon its wait when its own context +// is canceled, instead of blocking until the leader returns. +func TestSingleflight_FollowerHonorsContextCancellation(t *testing.T) { + cacheutils.Clear() //nolint:errcheck + + leaderStarted := make(chan struct{}) + releaseLeader := make(chan struct{}) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case leaderStarted <- struct{}{}: + default: + } + <-releaseLeader + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + defer close(releaseLeader) + + // Leader: long timeout, will block until releaseLeader fires. + leaderDone := make(chan error, 1) + go func() { + _, err := NewRequest(http.MethodGet, server.URL).WithInMemoryCaching(60).Do() + leaderDone <- err + }() + + // Wait for the leader's request to actually hit the server before + // dispatching the follower, so the follower lands on the in-flight slot. + <-leaderStarted + + // Follower: short context, should abandon long before the leader returns. + followerCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + _, err := NewRequest(http.MethodGet, server.URL). + WithContext(followerCtx). + WithInMemoryCaching(60). + Do() + elapsed := time.Since(start) + + if err == nil { + t.Fatal("follower should have errored out via ctx cancellation") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("follower err = %v, want context.DeadlineExceeded", err) + } + if elapsed > 500*time.Millisecond { + t.Errorf("follower waited %v — should have abandoned at the 100ms deadline", elapsed) + } +} + +// TestDoCached_CorruptCacheTriggersRefetch verifies that bytes that cannot be +// unmarshaled into a response (e.g. residue from a previous schema or disk +// corruption) trigger a fresh fetch instead of bubbling up the unmarshal error. +func TestDoCached_CorruptCacheTriggersRefetch(t *testing.T) { + cacheutils.SetFileCacheDir(t.TempDir()) + cacheutils.Clear() //nolint:errcheck + + var serverHits int64 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&serverHits, 1) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + // Compute the cache key the request layer will use, then plant garbage. + r := NewRequest(http.MethodGet, server.URL).WithCaching(60) + cacheKey := r.buildCacheKey() + if err := cacheutils.Store(cacheKey, []byte("not a valid response"), cacheutils.Layered, 60); err != nil { + t.Fatalf("Store garbage: %v", err) + } + + resp, err := NewRequest(http.MethodGet, server.URL).WithCaching(60).Do() + if err != nil { + t.Fatalf("Do: %v", err) + } + if got := atomic.LoadInt64(&serverHits); got != 1 { + t.Errorf("expected fallback to network: serverHits=%d", got) + } + if string(resp.Body) != `{"ok":true}` { + t.Errorf("body = %q, want fresh server response", resp.Body) + } + + // And the corrupt entry must be gone — a follow-up call should hit the cache. + if _, err := NewRequest(http.MethodGet, server.URL).WithCaching(60).Do(); err != nil { + t.Fatalf("second Do: %v", err) + } + if got := atomic.LoadInt64(&serverHits); got != 1 { + t.Errorf("expected second call to hit cache: serverHits=%d", got) + } +} + +// TestSingleflight_DoesNotShareResponseBuffers verifies that mutating one +// caller's response body doesn't affect siblings — i.e. each caller gets an +// independent unmarshaled response. +func TestSingleflight_DoesNotShareResponseBuffers(t *testing.T) { + cacheutils.Clear() //nolint:errcheck + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("hello")) + })) + defer server.Close() + + const concurrent = 5 + var wg sync.WaitGroup + wg.Add(concurrent) + results := make([]*[]byte, concurrent) + + for i := 0; i < concurrent; i++ { + go func(i int) { + defer wg.Done() + resp, err := NewRequest(http.MethodGet, server.URL).WithInMemoryCaching(60).Do() + if err != nil { + t.Errorf("goroutine %d: %v", i, err) + return + } + results[i] = &resp.Body + }(i) + } + wg.Wait() + + // Mutate one caller's body and ensure others are unaffected. + (*results[0])[0] = 'X' + for i := 1; i < concurrent; i++ { + if (*results[i])[0] != 'h' { + t.Errorf("goroutine %d body was mutated by sibling: %q", i, *results[i]) + } + } +} diff --git a/utils/httputils/request/with_cache_key.go b/utils/httputils/request/with_cache_key.go new file mode 100644 index 0000000..859d833 --- /dev/null +++ b/utils/httputils/request/with_cache_key.go @@ -0,0 +1,12 @@ +package request + +// WithCacheKey appends extra strings to the cache key for this request. +// Useful for discriminating cached responses by something not present in the +// URL (e.g. an auth token), so that two callers with the same URL but +// different identities don't share cache entries. +// The extras are joined into the cache key with a separator and the whole +// string is SHA256-hashed by cacheutils, so any value is safe to pass. +func (r *Request) WithCacheKey(extras ...string) *Request { + r.cacheKeyExtras = append(r.cacheKeyExtras, extras...) + return r +} diff --git a/utils/httputils/request/with_cache_key_test.go b/utils/httputils/request/with_cache_key_test.go new file mode 100644 index 0000000..bdeb1b9 --- /dev/null +++ b/utils/httputils/request/with_cache_key_test.go @@ -0,0 +1,102 @@ +package request + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/0skillallluck/scanline/utils/cacheutils" +) + +func TestWithCacheKey_DifferentExtrasProduceDifferentEntries(t *testing.T) { + cacheutils.Clear() //nolint:errcheck + + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + _, err := NewRequest(http.MethodGet, server.URL). + WithInMemoryCaching(0). + WithCacheKey("user-A"). + Do() + if err != nil { + t.Fatalf("first request: %v", err) + } + + _, err = NewRequest(http.MethodGet, server.URL). + WithInMemoryCaching(0). + WithCacheKey("user-B"). + Do() + if err != nil { + t.Fatalf("second request: %v", err) + } + + if callCount != 2 { + t.Errorf("expected 2 server calls (one per cache key extra), got %d", callCount) + } +} + +func TestWithCacheKey_SameExtraReusesCacheEntry(t *testing.T) { + cacheutils.Clear() //nolint:errcheck + + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + for i := 0; i < 3; i++ { + _, err := NewRequest(http.MethodGet, server.URL). + WithInMemoryCaching(0). + WithCacheKey("same-token"). + Do() + if err != nil { + t.Fatalf("request %d: %v", i, err) + } + } + + if callCount != 1 { + t.Errorf("expected 1 server call (subsequent requests cached), got %d", callCount) + } +} + +func TestWithInMemoryCaching_TTLExpires(t *testing.T) { + cacheutils.Clear() //nolint:errcheck + + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + _, err := NewRequest(http.MethodGet, server.URL).WithInMemoryCaching(1).Do() + if err != nil { + t.Fatalf("first request: %v", err) + } + + if _, err := NewRequest(http.MethodGet, server.URL).WithInMemoryCaching(1).Do(); err != nil { + t.Fatalf("second request: %v", err) + } + if callCount != 1 { + t.Errorf("expected cache hit on second request: callCount=%d", callCount) + } + + // Wait past the TTL. + time.Sleep(1100 * time.Millisecond) + + if _, err := NewRequest(http.MethodGet, server.URL).WithInMemoryCaching(1).Do(); err != nil { + t.Fatalf("third request: %v", err) + } + if callCount != 2 { + t.Errorf("expected cache miss after TTL expiry: callCount=%d", callCount) + } +} diff --git a/utils/httputils/request/with_in_memory_caching.go b/utils/httputils/request/with_in_memory_caching.go index 9b389fe..391bd36 100644 --- a/utils/httputils/request/with_in_memory_caching.go +++ b/utils/httputils/request/with_in_memory_caching.go @@ -3,10 +3,14 @@ package request import "github.com/0skillallluck/scanline/utils/cacheutils" // WithInMemoryCaching enables in-memory caching for the request. -// Cached responses persist for the lifetime of the process. -// Cache key is auto-calculated from URL and query parameters. +// If ttlSeconds is 0, the entry has no explicit expiration and persists for the +// lifetime of the process (still subject to the memory cache's LRU eviction). +// If ttlSeconds > 0, the entry expires after the given number of seconds. +// Cache key is auto-calculated from URL and query parameters; use WithCacheKey +// to add additional discriminators. // Caching only works for GET requests. -func (r *Request) WithInMemoryCaching() *Request { +func (r *Request) WithInMemoryCaching(ttlSeconds int) *Request { + r.cacheTTL = ttlSeconds r.cacheStrategy = cacheutils.MemoryOnly return r } diff --git a/utils/httputils/request/with_in_memory_caching_test.go b/utils/httputils/request/with_in_memory_caching_test.go index 11db8f5..0641491 100644 --- a/utils/httputils/request/with_in_memory_caching_test.go +++ b/utils/httputils/request/with_in_memory_caching_test.go @@ -20,14 +20,14 @@ func TestWithInMemoryCaching_GetRequest(t *testing.T) { defer server.Close() resp1, err := NewRequest(http.MethodGet, server.URL). - WithInMemoryCaching(). + WithInMemoryCaching(0). Do() if err != nil { t.Fatalf("First Do() error = %v", err) } resp2, err := NewRequest(http.MethodGet, server.URL). - WithInMemoryCaching(). + WithInMemoryCaching(0). Do() if err != nil { t.Fatalf("Second Do() error = %v", err) @@ -55,7 +55,7 @@ func TestWithInMemoryCaching_PersistsForSessionLifetime(t *testing.T) { // First request _, err := NewRequest(http.MethodGet, server.URL). - WithInMemoryCaching(). + WithInMemoryCaching(0). Do() if err != nil { t.Fatalf("First Do() error = %v", err) @@ -64,7 +64,7 @@ func TestWithInMemoryCaching_PersistsForSessionLifetime(t *testing.T) { // Multiple subsequent requests should all be cached for i := range 5 { _, err := NewRequest(http.MethodGet, server.URL). - WithInMemoryCaching(). + WithInMemoryCaching(0). Do() if err != nil { t.Fatalf("Request %d Do() error = %v", i+2, err) @@ -86,8 +86,8 @@ func TestWithInMemoryCaching_NonGetRequestIgnored(t *testing.T) { })) defer server.Close() - _, _ = NewRequest(http.MethodPost, server.URL).WithInMemoryCaching().Do() - _, _ = NewRequest(http.MethodPost, server.URL).WithInMemoryCaching().Do() + _, _ = NewRequest(http.MethodPost, server.URL).WithInMemoryCaching(0).Do() + _, _ = NewRequest(http.MethodPost, server.URL).WithInMemoryCaching(0).Do() if callCount != 2 { t.Errorf("POST requests should not be cached, got %d calls instead of 2", callCount) @@ -105,10 +105,10 @@ func TestWithInMemoryCaching_DifferentURLsDifferentCache(t *testing.T) { })) defer server.Close() - _, _ = NewRequest(http.MethodGet, server.URL+"/path1").WithInMemoryCaching().Do() - _, _ = NewRequest(http.MethodGet, server.URL+"/path2").WithInMemoryCaching().Do() - _, _ = NewRequest(http.MethodGet, server.URL+"/path1").WithInMemoryCaching().Do() - _, _ = NewRequest(http.MethodGet, server.URL+"/path2").WithInMemoryCaching().Do() + _, _ = NewRequest(http.MethodGet, server.URL+"/path1").WithInMemoryCaching(0).Do() + _, _ = NewRequest(http.MethodGet, server.URL+"/path2").WithInMemoryCaching(0).Do() + _, _ = NewRequest(http.MethodGet, server.URL+"/path1").WithInMemoryCaching(0).Do() + _, _ = NewRequest(http.MethodGet, server.URL+"/path2").WithInMemoryCaching(0).Do() if callCount != 2 { t.Errorf("Server was called %d times, want 2 (one per unique path)", callCount)