Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 144 additions & 18 deletions Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,117 @@ import Foundation
import Tokenizers
import Hub

/// Wrapper to store ModelContext in NSCache (requires NSObject subclass).
/// Boxes a `ModelContext` for storage in `NSCache`, which only accepts
/// class instances (here, an `NSObject` subclass) as values.
private final class CachedContext: NSObject, @unchecked Sendable {
    /// The wrapped, immutable model context.
    let context: ModelContext

    /// Wraps `context` so it can be placed in the cache.
    init(_ context: ModelContext) {
        self.context = context
    }
}

/// Coordinates a bounded in-memory cache with structured, coalesced loading.
/// Coordinates a bounded in-memory cache with structured, coalesced loading.
///
/// Thread safety: `NSCache` is itself thread-safe; a single `NSLock` guards
/// the `inFlight` table. The in-flight lookup and registration happen under
/// one lock acquisition so concurrent callers for the same key coalesce onto
/// a single load task instead of racing to start duplicates.
private final class ModelContextCache {
    private let cache: NSCache<NSString, CachedContext>
    private let lock = NSLock()
    private var inFlight: [String: Task<CachedContext, Error>] = [:]

    /// Creates a cache with a count-based eviction limit.
    init(countLimit: Int) {
        let cache = NSCache<NSString, CachedContext>()
        cache.countLimit = countLimit
        self.cache = cache
    }

    /// Returns a cached context or loads it exactly once per key.
    ///
    /// Concurrent callers for the same key share one load task; only the
    /// caller that registered the task writes the result into the cache and
    /// clears the in-flight entry.
    func context(
        for key: String,
        loader: @escaping @Sendable () async throws -> ModelContext
    ) async throws -> ModelContext {
        let cacheKey = key as NSString
        if let cached = cache.object(forKey: cacheKey) {
            return cached.context
        }

        // Atomically join an existing load or register a new one. Checking
        // and inserting under separate lock acquisitions would let two
        // callers both miss and start duplicate loads for the same key, with
        // the second registration clobbering (and orphaning) the first task.
        let (task, isOwner) = joinOrRegisterTask(for: key) {
            Task { try await CachedContext(loader()) }
        }

        guard isOwner else {
            return try await task.value.context
        }

        do {
            let cached = try await task.value
            cache.setObject(cached, forKey: cacheKey)
            clearInFlight(task, for: key)
            return cached.context
        } catch {
            clearInFlight(task, for: key)
            throw error
        }
    }

    /// Removes a cached context for the key.
    func remove(for key: String) {
        cache.removeObject(forKey: key as NSString)
    }

    /// Clears all cached contexts.
    func removeAll() {
        cache.removeAllObjects()
    }

    /// Cancels in-flight work and removes cached data for the key.
    func removeAndCancel(for key: String) async {
        let task = removeInFlight(for: key)
        task?.cancel()
        cache.removeObject(forKey: key as NSString)
    }

    /// Cancels all in-flight work and clears cached data.
    func removeAllAndCancel() async {
        let tasks = removeAllInFlight()
        tasks.forEach { $0.cancel() }
        cache.removeAllObjects()
    }

    /// Returns the in-flight task for `key`, registering `makeTask()` when
    /// none exists. `isOwner` is `true` only for the caller whose task was
    /// registered; that caller is responsible for caching the result and
    /// clearing the entry.
    private func joinOrRegisterTask(
        for key: String,
        makeTask: () -> Task<CachedContext, Error>
    ) -> (task: Task<CachedContext, Error>, isOwner: Bool) {
        lock.lock()
        defer { lock.unlock() }
        if let existing = inFlight[key] {
            return (existing, false)
        }
        let task = makeTask()
        inFlight[key] = task
        return (task, true)
    }

    /// Removes the in-flight entry for `key` only if it is still `task`, so
    /// a load restarted after `removeAndCancel(for:)` is not clobbered by
    /// the cancelled task's cleanup. (`Task` is `Equatable`.)
    private func clearInFlight(_ task: Task<CachedContext, Error>, for key: String) {
        lock.lock()
        defer { lock.unlock() }
        if inFlight[key] == task {
            inFlight[key] = nil
        }
    }

    /// Atomically removes and returns the in-flight task for `key`, if any.
    private func removeInFlight(for key: String) -> Task<CachedContext, Error>? {
        lock.lock()
        defer { lock.unlock() }
        let task = inFlight[key]
        inFlight[key] = nil
        return task
    }

    /// Atomically removes and returns all in-flight tasks.
    private func removeAllInFlight() -> [Task<CachedContext, Error>] {
        lock.lock()
        defer { lock.unlock() }
        let tasks = Array(inFlight.values)
        inFlight.removeAll()
        return tasks
    }
}

/// Shared, process-wide cache across MLXLanguageModel instances.
/// `nonisolated(unsafe)` is acceptable here because `ModelContextCache`
/// synchronizes its own mutable state internally (NSLock-guarded table plus
/// thread-safe NSCache), so the global itself carries no unguarded mutation.
private nonisolated(unsafe) let modelCache = ModelContextCache(countLimit: 3)

// MARK: - MLXLanguageModel

/// A language model that runs locally using MLX.
///
/// Use this model to run language models on Apple silicon using the MLX framework.
Expand Down Expand Up @@ -51,6 +162,33 @@ import Foundation
self.directory = directory
}

/// Evicts this model from the shared cache, cancelling any load in progress.
///
/// Use this to release memory once the model is no longer needed; a later
/// request will transparently reload it.
public func removeFromCache() async {
    await modelCache.removeAndCancel(for: directory?.absoluteString ?? modelId)
}

/// Removes all MLX models from the shared cache and cancels in-flight loads.
///
/// Affects every `MLXLanguageModel` instance in the process, since the
/// underlying cache is shared; each model reloads on its next request.
public static func removeAllFromCache() async {
    await modelCache.removeAllAndCancel()
}

/// Returns the model context for this configuration, loading it through the
/// shared cache so repeated requests reuse (and concurrent requests coalesce
/// onto) a single load.
///
/// The cache key is the local directory URL when one is set, otherwise the
/// hub model identifier.
private func loadContext(modelId: String, hub: HubApi?, directory: URL?) async throws -> ModelContext {
    let cacheKey = directory?.absoluteString ?? modelId
    return try await modelCache.context(for: cacheKey) {
        guard let directory else {
            // No local copy: fetch via the provided hub, or a default one.
            return try await loadModel(hub: hub ?? HubApi(), id: modelId)
        }
        return try await loadModel(directory: directory)
    }
}

public func respond<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
Expand All @@ -63,14 +201,8 @@ import Foundation
fatalError("MLXLanguageModel only supports generating String content")
}

let context: ModelContext
if let directory {
context = try await loadModel(directory: directory)
} else if let hub {
context = try await loadModel(hub: hub, id: modelId)
} else {
context = try await loadModel(id: modelId)
}
// Get cached or load fresh ModelContext
let context = try await loadContext(modelId: modelId, hub: hub, directory: directory)

// Convert session tools to MLX ToolSpec format
let toolSpecs: [ToolSpec]? =
Expand Down Expand Up @@ -179,18 +311,11 @@ import Foundation
continuation in
let task = Task { @Sendable in
do {
let context: ModelContext
if let directory {
context = try await loadModel(directory: directory)
} else if let hub {
context = try await loadModel(hub: hub, id: modelId)
} else {
context = try await loadModel(id: modelId)
}
// Get cached or load fresh ModelContext
let context = try await loadContext(modelId: modelId, hub: hub, directory: directory)

// Build chat inside task to avoid Sendable issues
let generateParameters = toGenerateParameters(options)

// Build chat history from full transcript
let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)

let userInput = MLXLMCommon.UserInput(
Expand Down Expand Up @@ -462,4 +587,5 @@ import Foundation
}
return textParts.joined(separator: "\n")
}

#endif // MLX