diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
index 1f13824..8f44b31 100644
--- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
+++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
@@ -17,6 +17,161 @@ import Foundation
     import Tokenizers
     import Hub
 
+    /// Wrapper to store ModelContext in NSCache (requires NSObject subclass).
+    private final class CachedContext: NSObject, @unchecked Sendable {
+        let context: ModelContext
+        init(_ context: ModelContext) { self.context = context }
+    }
+
+    /// Coordinates a bounded in-memory cache with structured, coalesced loading.
+    ///
+    /// `lock` guards `inFlight` and serializes publication into `cache`, so a
+    /// lookup never misses a result that a finishing load is about to publish.
+    private final class ModelContextCache {
+        /// Outcome of atomically resolving a request for a key.
+        private enum Resolution {
+            /// The value was already cached.
+            case cached(CachedContext)
+            /// Another caller is already loading the key; await its task.
+            case joined(Task<CachedContext, Error>)
+            /// This caller started the load and must retire it with `finish`.
+            case started(Task<CachedContext, Error>)
+        }
+
+        private let cache: NSCache<NSString, CachedContext>
+        private let lock = NSLock()
+        private var inFlight: [String: Task<CachedContext, Error>] = [:]
+
+        /// Creates a cache with a count-based eviction limit.
+        init(countLimit: Int) {
+            let cache = NSCache<NSString, CachedContext>()
+            cache.countLimit = countLimit
+            self.cache = cache
+        }
+
+        /// Returns a cached context or loads it exactly once per key.
+        ///
+        /// - Parameters:
+        ///   - key: Identifier the context is cached under.
+        ///   - loader: Produces the context on a cache miss; runs in its own task.
+        /// - Returns: The cached or freshly loaded context.
+        /// - Throws: Any error thrown by `loader`.
+        func context(
+            for key: String,
+            loader: @escaping @Sendable () async throws -> ModelContext
+        ) async throws -> ModelContext {
+            // Fast path: skip the lock entirely on a cache hit.
+            if let cached = cache.object(forKey: key as NSString) {
+                return cached.context
+            }
+
+            switch resolve(key, loader: loader) {
+            case .cached(let cached):
+                return cached.context
+            case .joined(let task):
+                return try await task.value.context
+            case .started(let task):
+                do {
+                    let loaded = try await task.value
+                    finish(task, for: key, publishing: loaded)
+                    return loaded.context
+                } catch {
+                    finish(task, for: key, publishing: nil)
+                    throw error
+                }
+            }
+        }
+
+        /// Removes a cached context for the key.
+        func remove(for key: String) {
+            cache.removeObject(forKey: key as NSString)
+        }
+
+        /// Clears all cached contexts.
+        func removeAll() {
+            cache.removeAllObjects()
+        }
+
+        /// Cancels in-flight work and removes cached data for the key.
+        func removeAndCancel(for key: String) async {
+            // Unregister first so the cancelled load can no longer publish into the cache.
+            let task = removeInFlight(for: key)
+            task?.cancel()
+            cache.removeObject(forKey: key as NSString)
+        }
+
+        /// Cancels all in-flight work and clears cached data.
+        func removeAllAndCancel() async {
+            let tasks = removeAllInFlight()
+            tasks.forEach { $0.cancel() }
+            cache.removeAllObjects()
+        }
+
+        /// Atomically resolves `key` to a cached value, an existing load, or a new load.
+        ///
+        /// The lookup and the registration share one critical section, so two
+        /// concurrent callers can never both start a load for the same key.
+        private func resolve(
+            _ key: String,
+            loader: @escaping @Sendable () async throws -> ModelContext
+        ) -> Resolution {
+            lock.lock()
+            defer { lock.unlock() }
+
+            // Re-check under the lock: a load may have published since the fast path missed.
+            if let cached = cache.object(forKey: key as NSString) {
+                return .cached(cached)
+            }
+            if let task = inFlight[key] {
+                return .joined(task)
+            }
+
+            let task = Task<CachedContext, Error> { try await CachedContext(loader()) }
+            inFlight[key] = task
+            return .started(task)
+        }
+
+        /// Retires `task` and publishes `result` if `task` is still the registered load for `key`.
+        ///
+        /// A load that was cancelled or superseded no longer owns the slot, so it
+        /// neither repopulates the cache nor clears a newer load's registration.
+        private func finish(
+            _ task: Task<CachedContext, Error>,
+            for key: String,
+            publishing result: CachedContext?
+        ) {
+            lock.lock()
+            defer { lock.unlock() }
+
+            guard inFlight[key] == task else { return }
+            inFlight[key] = nil
+            if let result {
+                cache.setObject(result, forKey: key as NSString)
+            }
+        }
+
+        /// Unregisters and returns the in-flight load for `key`, if any.
+        private func removeInFlight(for key: String) -> Task<CachedContext, Error>? {
+            lock.lock()
+            defer { lock.unlock() }
+            return inFlight.removeValue(forKey: key)
+        }
+
+        /// Unregisters and returns every in-flight load.
+        private func removeAllInFlight() -> [Task<CachedContext, Error>] {
+            lock.lock()
+            defer { lock.unlock() }
+            let tasks = Array(inFlight.values)
+            inFlight.removeAll()
+            return tasks
+        }
+    }
+
+    /// Shared cache across MLXLanguageModel instances.
+    private nonisolated(unsafe) let modelCache = ModelContextCache(countLimit: 3)
+
+    // MARK: - MLXLanguageModel
+
     /// A language model that runs locally using MLX.
     ///
     /// Use this model to run language models on Apple silicon using the MLX framework.
@@ -51,6 +206,33 @@ import Foundation
             self.directory = directory
         }
 
+        /// Removes this model from the shared cache and cancels any in-flight load.
+        ///
+        /// Call this to free memory when the model is no longer needed.
+        /// The model will be reloaded automatically on the next request.
+        public func removeFromCache() async {
+            let key = directory?.absoluteString ?? modelId
+            await modelCache.removeAndCancel(for: key)
+        }
+
+        /// Removes all MLX models from the shared cache and cancels in-flight loads.
+        public static func removeAllFromCache() async {
+            await modelCache.removeAllAndCancel()
+        }
+
+        /// Get or load model context with caching
+        private func loadContext(modelId: String, hub: HubApi?, directory: URL?) async throws -> ModelContext {
+            let key = directory?.absoluteString ?? modelId
+
+            return try await modelCache.context(for: key) {
+                if let directory {
+                    return try await loadModel(directory: directory)
+                }
+
+                return try await loadModel(hub: hub ?? HubApi(), id: modelId)
+            }
+        }
+
         public func respond(
             within session: LanguageModelSession,
             to prompt: Prompt,
@@ -63,14 +245,8 @@ import Foundation
                 fatalError("MLXLanguageModel only supports generating String content")
             }
 
-            let context: ModelContext
-            if let directory {
-                context = try await loadModel(directory: directory)
-            } else if let hub {
-                context = try await loadModel(hub: hub, id: modelId)
-            } else {
-                context = try await loadModel(id: modelId)
-            }
+            // Get cached or load fresh ModelContext
+            let context = try await loadContext(modelId: modelId, hub: hub, directory: directory)
 
             // Convert session tools to MLX ToolSpec format
             let toolSpecs: [ToolSpec]? =
@@ -179,18 +355,11 @@ import Foundation
                 continuation in
                 let task = Task { @Sendable in
                     do {
-                        let context: ModelContext
-                        if let directory {
-                            context = try await loadModel(directory: directory)
-                        } else if let hub {
-                            context = try await loadModel(hub: hub, id: modelId)
-                        } else {
-                            context = try await loadModel(id: modelId)
-                        }
+                        // Get cached or load fresh ModelContext
+                        let context = try await loadContext(modelId: modelId, hub: hub, directory: directory)
+
                         // Build chat inside task to avoid Sendable issues
                         let generateParameters = toGenerateParameters(options)
-
-                        // Build chat history from full transcript
                         let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
 
                         let userInput = MLXLMCommon.UserInput(
@@ -462,4 +631,5 @@ import Foundation
         }
         return textParts.joined(separator: "\n")
     }
+
 #endif // MLX