From 650e3a0bbd81eaf71cb71b29608c7f9badb9faf8 Mon Sep 17 00:00:00 2001 From: noorbhatia Date: Tue, 13 Jan 2026 20:53:59 +0530 Subject: [PATCH 1/2] Implement NSCache to cache MLX ModelContext --- .../Models/MLXLanguageModel.swift | 75 ++++++++++++++----- 1 file changed, 57 insertions(+), 18 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 1f13824..824f6af 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -17,6 +17,21 @@ import Foundation import Tokenizers import Hub + + /// Wrapper to store ModelContext in NSCache (requires NSObject subclass) + private final class CachedContext: NSObject { + let context: ModelContext + init(_ context: ModelContext) { self.context = context } + } + + private nonisolated(unsafe) let modelCache: NSCache = { + let cache = NSCache() + cache.countLimit = 3 + return cache + }() + + // MARK: - MLXLanguageModel + /// A language model that runs locally using MLX. /// /// Use this model to run language models on Apple silicon using the MLX framework. @@ -51,6 +66,25 @@ import Foundation self.directory = directory } + /// Get or load model context with caching + private func loadContext(modelId: String, hub: HubApi?, directory: URL?) async throws -> ModelContext { + let key = (directory?.absoluteString ?? modelId) as NSString + + if let cached = modelCache.object(forKey: key) { + return cached.context + } + + let context: ModelContext + if let directory { + context = try await loadModel(directory: directory) + } else { + context = try await loadModel(hub: hub ?? HubApi(), id: modelId) + } + + modelCache.setObject(CachedContext(context), forKey: key) + return context + } + public func respond( within session: LanguageModelSession, to prompt: Prompt, @@ -63,14 +97,8 @@ import Foundation fatalError("MLXLanguageModel only supports generating String content") } - let context: ModelContext - if let directory { - context = try await loadModel(directory: directory) - } else if let hub { - context = try await loadModel(hub: hub, id: modelId) - } else { - context = try await loadModel(id: modelId) - } + // Get cached or load fresh ModelContext + let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) // Convert session tools to MLX ToolSpec format let toolSpecs: [ToolSpec]? = @@ -179,18 +207,11 @@ import Foundation continuation in let task = Task { @Sendable in do { - let context: ModelContext - if let directory { - context = try await loadModel(directory: directory) - } else if let hub { - context = try await loadModel(hub: hub, id: modelId) - } else { - context = try await loadModel(id: modelId) - } + // Get cached or load fresh ModelContext + let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + // Build chat inside task to avoid Sendable issues let generateParameters = toGenerateParameters(options) - - // Build chat history from full transcript let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) let userInput = MLXLMCommon.UserInput( @@ -462,4 +483,22 @@ import Foundation } return textParts.joined(separator: "\n") } + + // MARK: - Cache Management + + extension MLXLanguageModel { + /// Unloads this model from the shared cache. + /// + /// Call this to free memory when the model is no longer needed. + /// The model will be reloaded automatically on the next request. + public func unload() { + let key = (directory?.absoluteString ?? modelId) as NSString + modelCache.removeObject(forKey: key) + } + + /// Unloads all MLX models from the shared cache. + public static func unloadAll() { + modelCache.removeAllObjects() + } + } #endif // MLX From ee7e4dc6efbe469a72259f0c7ab2a01bc825ce80 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 15 Jan 2026 04:13:47 -0800 Subject: [PATCH 2/2] Implement actor-coalesced MLX model cache backed by NSCache --- .../Models/MLXLanguageModel.swift | 161 ++++++++++++++---- 1 file changed, 124 insertions(+), 37 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 824f6af..8f44b31 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -17,18 +17,114 @@ import Foundation import Tokenizers import Hub - - /// Wrapper to store ModelContext in NSCache (requires NSObject subclass) - private final class CachedContext: NSObject { + /// Wrapper to store ModelContext in NSCache (requires NSObject subclass). + private final class CachedContext: NSObject, @unchecked Sendable { let context: ModelContext init(_ context: ModelContext) { self.context = context } } - private nonisolated(unsafe) let modelCache: NSCache = { - let cache = NSCache() - cache.countLimit = 3 - return cache - }() + /// Coordinates a bounded in-memory cache with structured, coalesced loading. + private final class ModelContextCache { + private let cache: NSCache + private let lock = NSLock() + private var inFlight: [String: Task] = [:] + + /// Creates a cache with a count-based eviction limit. + init(countLimit: Int) { + let cache = NSCache() + cache.countLimit = countLimit + self.cache = cache + } + + /// Returns a cached context or loads it exactly once per key. + func context( + for key: String, + loader: @escaping @Sendable () async throws -> ModelContext + ) async throws -> ModelContext { + let cacheKey = key as NSString + if let cached = cache.object(forKey: cacheKey) { + return cached.context + } + + if let task = inFlightTask(for: key) { + return try await task.value.context + } + + let task = Task { try await CachedContext(loader()) } + setInFlight(task, for: key) + + do { + let cached = try await task.value + cache.setObject(cached, forKey: cacheKey) + clearInFlight(for: key) + return cached.context + } catch { + clearInFlight(for: key) + throw error + } + } + + /// Removes a cached context for the key. + func remove(for key: String) { + cache.removeObject(forKey: key as NSString) + } + + /// Clears all cached contexts. + func removeAll() { + cache.removeAllObjects() + } + + /// Cancels in-flight work and removes cached data for the key. + func removeAndCancel(for key: String) async { + let task = removeInFlight(for: key) + task?.cancel() + cache.removeObject(forKey: key as NSString) + } + + /// Cancels all in-flight work and clears cached data. + func removeAllAndCancel() async { + let tasks = removeAllInFlight() + tasks.forEach { $0.cancel() } + cache.removeAllObjects() + } + + private func inFlightTask(for key: String) -> Task? { + lock.lock() + defer { lock.unlock() } + return inFlight[key] + } + + private func setInFlight(_ task: Task, for key: String) { + lock.lock() + inFlight[key] = task + lock.unlock() + } + + private func clearInFlight(for key: String) { + lock.lock() + inFlight[key] = nil + lock.unlock() + } + + private func removeInFlight(for key: String) -> Task? { + lock.lock() + defer { lock.unlock() } + let task = inFlight[key] + inFlight[key] = nil + return task + } + + private func removeAllInFlight() -> [Task] { + lock.lock() + defer { lock.unlock() } + let tasks = Array(inFlight.values) + inFlight.removeAll() + return tasks + } + } + + /// Shared cache across MLXLanguageModel instances. + private nonisolated(unsafe) let modelCache = ModelContextCache(countLimit: 3) // MARK: - MLXLanguageModel @@ -66,23 +162,31 @@ import Foundation self.directory = directory } + /// Removes this model from the shared cache and cancels any in-flight load. + /// + /// Call this to free memory when the model is no longer needed. + /// The model will be reloaded automatically on the next request. + public func removeFromCache() async { + let key = directory?.absoluteString ?? modelId + await modelCache.removeAndCancel(for: key) + } + + /// Removes all MLX models from the shared cache and cancels in-flight loads. + public static func removeAllFromCache() async { + await modelCache.removeAllAndCancel() + } + /// Get or load model context with caching private func loadContext(modelId: String, hub: HubApi?, directory: URL?) async throws -> ModelContext { - let key = (directory?.absoluteString ?? modelId) as NSString + let key = directory?.absoluteString ?? modelId - if let cached = modelCache.object(forKey: key) { - return cached.context - } + return try await modelCache.context(for: key) { + if let directory { + return try await loadModel(directory: directory) + } - let context: ModelContext - if let directory { - context = try await loadModel(directory: directory) - } else { - context = try await loadModel(hub: hub ?? HubApi(), id: modelId) + return try await loadModel(hub: hub ?? HubApi(), id: modelId) } - - modelCache.setObject(CachedContext(context), forKey: key) - return context } public func respond( @@ -484,21 +588,4 @@ import Foundation return textParts.joined(separator: "\n") } - // MARK: - Cache Management - - extension MLXLanguageModel { - /// Unloads this model from the shared cache. - /// - /// Call this to free memory when the model is no longer needed. - /// The model will be reloaded automatically on the next request. - public func unload() { - let key = (directory?.absoluteString ?? modelId) as NSString - modelCache.removeObject(forKey: key) - } - - /// Unloads all MLX models from the shared cache. - public static func unloadAll() { - modelCache.removeAllObjects() - } - } #endif // MLX