Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 123 additions & 1 deletion packages/store/src/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,31 @@ import { __flush } from './scheduler'
import { isUpdaterFunction } from './types'
import type { AnyUpdater, Listener, Updater } from './types'

export interface Storage {
removeItem: (key: string) => void
getItem: (key: string) => string | null
setItem: (key: string, value: string) => void
}

export interface PersistOptions<TState> {
/**
* Storage key to use when persisting state
*/
key: string
/**
* Storage to use for persistence. Defaults to localStorage
*/
storage?: Storage
/**
* Custom serializer. Defaults to JSON.stringify
*/
serialize?: (state: TState) => string
/**
* Custom deserializer. Defaults to JSON.parse
*/
deserialize?: (serializedState: string) => TState
}

export interface StoreOptions<
TState,
TUpdater extends AnyUpdater = (cb: TState) => TState,
Expand All @@ -23,6 +48,10 @@ export interface StoreOptions<
* Called after the state has been updated, used to derive other state.
*/
onUpdate?: () => void
/**
* Options for state persistence
*/
persist?: PersistOptions<TState>
}

export class Store<
Expand All @@ -35,9 +64,99 @@ export class Store<
options?: StoreOptions<TState, TUpdater>

constructor(initialState: TState, options?: StoreOptions<TState, TUpdater>) {
this.options = options

// Try to load persisted state if persistence is enabled
if (options?.persist) {
const persistedState = this.loadPersistedState()
if (persistedState !== null) {
this.prevState = persistedState
this.state = persistedState
return
}
}

this.prevState = initialState
this.state = initialState
this.options = options
}

private getDefaultStorage(): Storage {
if (typeof window === 'undefined') {
return {
getItem: () => null,
setItem: () => undefined,
removeItem: () => undefined,
}
}
return window.localStorage
}

private loadPersistedState(): TState | null {
const { persist } = this.options || {}
if (!persist) return null

const deserialize = persist.deserialize || JSON.parse
const storage = persist.storage || this.getDefaultStorage()

try {
const persistedState = storage.getItem(persist.key)
if (persistedState === null) return null
return deserialize(persistedState)
} catch (error) {
console.error('Failed to load persisted state:', error)
return null
}
}

private persistState(): void {
const { persist } = this.options || {}
if (!persist) return

const serialize = persist.serialize || JSON.stringify
const storage = persist.storage || this.getDefaultStorage()

try {
const serializedState = serialize(this.state)
storage.setItem(persist.key, serializedState)
} catch (error) {
console.error('Failed to persist state:', error)
}
}

/**
* Manually persist the current state
*/
persist(): void {
this.persistState()
}

/**
* Clear the persisted state
*/
clearPersistedState(): void {
const { persist } = this.options || {}
if (!persist) return

const storage = persist.storage || this.getDefaultStorage()
try {
storage.removeItem(persist.key)
} catch (error) {
console.error('Failed to clear persisted state:', error)
}
}

/**
* Reload the state from persistence
* @returns true if state was successfully reloaded, false otherwise
*/
rehydrate(): boolean {
const persistedState = this.loadPersistedState()
if (persistedState === null) return false

this.prevState = this.state
this.state = persistedState
__flush(this as never)
return true
}

subscribe = (listener: Listener<TState>) => {
Expand Down Expand Up @@ -71,6 +190,9 @@ export class Store<
// Always run onUpdate, regardless of batching
this.options?.onUpdate?.()

// Persist state if enabled
this.persistState()

// Attempt to flush
__flush(this as never)
}
Expand Down
212 changes: 212 additions & 0 deletions packages/store/tests/persistence.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { Store } from '../src'

interface TestState {
count: number
text: string
}

describe('Store Persistence', () => {
let mockStorage: {
storage: Record<string, string>
getItem: (key: string) => string | null
setItem: (key: string, value: string) => void
removeItem: (key: string) => void
}

beforeEach(() => {
mockStorage = {
storage: {},
getItem: vi.fn((key: string) => mockStorage.storage[key] || null),
setItem: vi.fn((key: string, value: string) => {
mockStorage.storage[key] = value
}),
removeItem: vi.fn((key: string) => {
delete mockStorage.storage[key]
}),
}
})

it('should initialize with initial state when no persisted state exists', () => {
const initialState: TestState = { count: 0, text: 'initial' }
const store = new Store(initialState, {
persist: {
key: 'test-store',
storage: mockStorage,
},
})

expect(store.state).toEqual(initialState)
expect(mockStorage.getItem).toHaveBeenCalledWith('test-store')
})

it('should persist state updates', () => {
const store = new Store(
{ count: 0, text: 'initial' },
{
persist: {
key: 'test-store',
storage: mockStorage,
},
},
)

store.setState({ count: 1, text: 'updated' })

expect(mockStorage.setItem).toHaveBeenCalledWith(
'test-store',
JSON.stringify({ count: 1, text: 'updated' }),
)
})

it('should use custom serializer and deserializer', () => {
const customSerializer = vi.fn(
(state: TestState) => `count:${state.count};text:${state.text}`,
)
const customDeserializer = vi.fn((str: string) => {
const [countPart, textPart] = str.split(';')
if (!countPart || !textPart) {
throw new Error('Invalid format')
}
const count = Number(countPart.split(':')[1])
const text = textPart.split(':')[1]
if (typeof text !== 'string') {
throw new Error('Invalid format')
}
return {
count,
text,
}
})

mockStorage.storage['test-store'] = 'count:42;text:persisted'

const store = new Store(
{ count: 0, text: 'initial' },
{
persist: {
key: 'test-store',
storage: mockStorage,
serialize: customSerializer,
deserialize: customDeserializer,
},
},
)

expect(store.state).toEqual({ count: 42, text: 'persisted' })
expect(customDeserializer).toHaveBeenCalledWith('count:42;text:persisted')

store.setState({ count: 100, text: 'serialized' })
expect(customSerializer).toHaveBeenCalledWith({
count: 100,
text: 'serialized',
})
})

it('should handle storage errors gracefully', () => {
const errorStorage = {
getItem: vi.fn(() => {
throw new Error('Storage error')
}),
setItem: vi.fn(() => {
throw new Error('Storage error')
}),
removeItem: vi.fn(() => {
throw new Error('Storage error')
}),
}

const initialState = { count: 0, text: 'initial' }
const store = new Store(initialState, {
persist: {
key: 'test-store',
storage: errorStorage,
},
})

expect(store.state).toEqual(initialState)
expect(() => store.setState({ count: 1, text: 'updated' })).not.toThrow()
})

it('should clear persisted state', () => {
const store = new Store(
{ count: 0, text: 'initial' },
{
persist: {
key: 'test-store',
storage: mockStorage,
},
},
)

store.clearPersistedState()
expect(mockStorage.removeItem).toHaveBeenCalledWith('test-store')
})

it('should manually persist state', () => {
const store = new Store(
{ count: 0, text: 'initial' },
{
persist: {
key: 'test-store',
storage: mockStorage,
},
},
)

vi.clearAllMocks()

store.persist()
expect(mockStorage.setItem).toHaveBeenCalledWith(
'test-store',
JSON.stringify({ count: 0, text: 'initial' }),
)
})

it('should rehydrate state from storage', () => {
const store = new Store(
{ count: 0, text: 'initial' },
{
persist: {
key: 'test-store',
storage: mockStorage,
},
},
)

mockStorage.storage['test-store'] = JSON.stringify({
count: 42,
text: 'rehydrated',
})

const success = store.rehydrate()
expect(success).toBe(true)
expect(store.state).toEqual({ count: 42, text: 'rehydrated' })
})

it('should return false when rehydration fails', () => {
const store = new Store(
{ count: 0, text: 'initial' },
{
persist: {
key: 'test-store',
storage: mockStorage,
},
},
)

const success = store.rehydrate()
expect(success).toBe(false)
expect(store.state).toEqual({ count: 0, text: 'initial' })
})

it('should work without persistence options', () => {
const store = new Store({ count: 0, text: 'initial' })

store.persist()
store.rehydrate()
store.clearPersistedState()

expect(store.state).toEqual({ count: 0, text: 'initial' })
})
})