Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 101 additions & 41 deletions memorystore/memorystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ package memorystore

import (
"context"
"hash/fnv"
"sync"
"sync/atomic"
"time"

"github.com/goccy/go-json"
)

const numShards = 256

// item represents a single cache entry with its value and expiration time.
type item struct {
value []byte // Raw data stored as a byte slice
Expand All @@ -27,11 +30,18 @@ type StoreMetrics struct {
Evictions int64 // Total number of items evicted (expired)
}

type shard struct {
mu sync.RWMutex
store map[string]item
}

// MemoryStore implements an in-memory cache with automatic cleanup of expired items.
// It is safe for concurrent use by multiple goroutines.
type MemoryStore struct {
mu sync.RWMutex // Protects access to the store map
store map[string]item // Internal storage for cache items
// lifecycleMu protects the lifecycle state (cancelFunc)
lifecycleMu sync.RWMutex

shards []*shard // Sharded storage
ps *pubSubManager // PubSub manager for cache events
ctx context.Context // Context for controlling the cleanup worker
cancelFunc context.CancelFunc // Function to stop the cleanup worker
Expand All @@ -49,15 +59,29 @@ type MemoryStore struct {
func NewMemoryStore() *MemoryStore {
ctx, cancel := context.WithCancel(context.Background())
ms := &MemoryStore{
store: make(map[string]item),
shards: make([]*shard, numShards),
ctx: ctx,
cancelFunc: cancel,
}

for i := 0; i < numShards; i++ {
ms.shards[i] = &shard{
store: make(map[string]item),
}
}

ms.initPubSub()
ms.startCleanupWorker()
return ms
}

// getShard returns the shard responsible for the given key.
func (m *MemoryStore) getShard(key string) *shard {
h := fnv.New64a()
h.Write([]byte(key))
return m.shards[h.Sum64()%numShards]
}

// Stop gracefully shuts down the MemoryStore by stopping the cleanup goroutine
// and releasing associated resources. After calling Stop, the store cannot be used.
// Multiple calls to Stop will not cause a panic and return nil.
Expand All @@ -67,8 +91,8 @@ func NewMemoryStore() *MemoryStore {
// store := NewMemoryStore()
// defer store.Stop()
func (m *MemoryStore) Stop() error {
m.mu.Lock()
defer m.mu.Unlock()
m.lifecycleMu.Lock()
defer m.lifecycleMu.Unlock()

if m.cancelFunc == nil {
return nil
Expand All @@ -83,7 +107,11 @@ func (m *MemoryStore) Stop() error {
m.wg.Wait()

// Clear the store to free up memory
m.store = nil
for _, s := range m.shards {
s.mu.Lock()
s.store = nil
s.mu.Unlock()
}

return nil
}
Expand All @@ -98,8 +126,8 @@ func (m *MemoryStore) Stop() error {
// return
// }
func (m *MemoryStore) IsStopped() bool {
m.mu.RLock()
defer m.mu.RUnlock()
m.lifecycleMu.RLock()
defer m.lifecycleMu.RUnlock()
return m.cancelFunc == nil
}

Expand All @@ -124,26 +152,31 @@ func (m *MemoryStore) startCleanupWorker() {
}

// cleanupExpiredItems removes all expired items from the cache.
// This method acquires a write lock on the store while performing the cleanup.
// It iterates over shards and cleans them one by one to avoid global locking.
func (m *MemoryStore) cleanupExpiredItems() {
m.mu.Lock()
defer m.mu.Unlock()
for key, item := range m.store {
if time.Now().After(item.expiresAt) {
delete(m.store, key)
atomic.AddInt64(&m.evictions, 1)
now := time.Now()
for _, s := range m.shards {
// Lock only the current shard
s.mu.Lock()
for key, item := range s.store {
if now.After(item.expiresAt) {
delete(s.store, key)
atomic.AddInt64(&m.evictions, 1)
}
}
s.mu.Unlock()
}
}

// Set stores a raw byte slice in the cache with the specified key and duration.
// The item will automatically expire after the specified duration.
// If an error occurs, it will be returned to the caller.
func (m *MemoryStore) Set(key string, value []byte, duration time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()
s := m.getShard(key)
s.mu.Lock()
defer s.mu.Unlock()

m.store[key] = item{
s.store[key] = item{
value: value,
expiresAt: time.Now().Add(duration),
}
Expand Down Expand Up @@ -175,10 +208,11 @@ func (m *MemoryStore) SetJSON(key string, value interface{}, duration time.Durat
// Returns the value and a boolean indicating whether the key was found.
// If the item has expired, returns (nil, false).
func (m *MemoryStore) Get(key string) ([]byte, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
s := m.getShard(key)
s.mu.RLock()
defer s.mu.RUnlock()

it, exists := m.store[key]
it, exists := s.store[key]
if !exists || time.Now().After(it.expiresAt) {
atomic.AddInt64(&m.misses, 1)
return nil, false
Expand Down Expand Up @@ -217,45 +251,68 @@ func (m *MemoryStore) GetJSON(key string, dest interface{}) (bool, error) {
// Delete removes an item from the cache.
// If the key doesn't exist, the operation is a no-op.
func (m *MemoryStore) Delete(key string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.store, key)
s := m.getShard(key)
s.mu.Lock()
defer s.mu.Unlock()
delete(s.store, key)
}

// SetMulti stores multiple key-value pairs in the cache.
// This is more efficient than calling Set multiple times as it acquires the lock only once.
// This is more efficient than calling Set multiple times as it groups keys by shard.
// All items will have the same expiration duration.
func (m *MemoryStore) SetMulti(items map[string][]byte, duration time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()

// Group items by shard
shardItems := make(map[*shard]map[string]item)
expiresAt := time.Now().Add(duration)

for key, value := range items {
m.store[key] = item{
s := m.getShard(key)
if _, ok := shardItems[s]; !ok {
shardItems[s] = make(map[string]item)
}
shardItems[s][key] = item{
value: value,
expiresAt: expiresAt,
}
}

// Apply updates per shard
for s, items := range shardItems {
s.mu.Lock()
for k, v := range items {
s.store[k] = v
}
s.mu.Unlock()
}
return nil
}

// GetMulti retrieves multiple values from the cache.
// It returns a map of found items. Keys that don't exist or are expired are omitted.
func (m *MemoryStore) GetMulti(keys []string) map[string][]byte {
m.mu.RLock()
defer m.mu.RUnlock()

result := make(map[string][]byte)
now := time.Now()

// Group keys by shard
shardKeys := make(map[*shard][]string)
for _, key := range keys {
it, exists := m.store[key]
if exists && !now.After(it.expiresAt) {
result[key] = it.value
atomic.AddInt64(&m.hits, 1)
} else {
atomic.AddInt64(&m.misses, 1)
s := m.getShard(key)
shardKeys[s] = append(shardKeys[s], key)
}

// Retrieve from each shard
for s, keys := range shardKeys {
s.mu.RLock()
for _, key := range keys {
it, exists := s.store[key]
if exists && !now.After(it.expiresAt) {
result[key] = it.value
atomic.AddInt64(&m.hits, 1)
} else {
atomic.AddInt64(&m.misses, 1)
}
}
s.mu.RUnlock()
}

return result
Expand All @@ -264,9 +321,12 @@ func (m *MemoryStore) GetMulti(keys []string) map[string][]byte {
// GetMetrics returns the current statistics of the MemoryStore.
// It returns a copy of the metrics to ensure thread safety.
func (m *MemoryStore) GetMetrics() StoreMetrics {
m.mu.RLock()
itemCount := len(m.store)
m.mu.RUnlock()
itemCount := 0
for _, s := range m.shards {
s.mu.RLock()
itemCount += len(s.store)
s.mu.RUnlock()
}

return StoreMetrics{
Items: itemCount,
Expand Down