Skip to content

Commit df19229

Browse files
committed
Add additionalContext support to MLXLanguageModel
1 parent 7b311b1 commit df19229

File tree

2 files changed

+65
-2
lines changed

2 files changed

+65
-2
lines changed

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,16 @@ import Foundation
183183
/// let model = MLXLanguageModel(modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit")
184184
/// ```
185185
public struct MLXLanguageModel: LanguageModel {
186+
/// Custom generation options for MLX models.
187+
public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions {
188+
/// Additional key-value pairs injected into the chat template rendering context.
189+
public var additionalContext: [String: MLXLMCommon.JSONValue]?
190+
191+
public init(additionalContext: [String: MLXLMCommon.JSONValue]? = nil) {
192+
self.additionalContext = additionalContext
193+
}
194+
}
195+
186196
/// The reason the model is unavailable.
187197
public enum UnavailableReason: Sendable, Equatable, Hashable {
188198
/// The model has not been loaded into memory yet.
@@ -292,6 +302,11 @@ import Foundation
292302
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
293303
let generateParameters = toGenerateParameters(options)
294304

305+
// Extract additional context from custom options
306+
let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self]
307+
.flatMap { $0.additionalContext }
308+
.map { $0.mapValues { $0.toSendable() } }
309+
295310
// Build chat history from full transcript
296311
var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
297312

@@ -305,6 +320,7 @@ import Foundation
305320
chat: chat,
306321
processing: .init(resize: .init(width: 512, height: 512)),
307322
tools: toolSpecs,
323+
additionalContext: additionalContext,
308324
)
309325
let lmInput = try await context.processor.prepare(input: userInput)
310326

@@ -407,10 +423,15 @@ import Foundation
407423
let generateParameters = toGenerateParameters(options)
408424
let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
409425

426+
let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self]
427+
.flatMap { $0.additionalContext }
428+
.map { $0.mapValues { $0.toSendable() } }
429+
410430
let userInput = MLXLMCommon.UserInput(
411431
chat: chat,
412432
processing: .init(resize: .init(width: 512, height: 512)),
413-
tools: nil
433+
tools: nil,
434+
additionalContext: additionalContext
414435
)
415436
let lmInput = try await context.processor.prepare(input: userInput)
416437

@@ -876,10 +897,16 @@ import Foundation
876897
let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
877898
let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil
878899
let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt)
900+
901+
let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self]
902+
.flatMap { $0.additionalContext }
903+
.map { $0.mapValues { $0.toSendable() } }
904+
879905
let userInput = MLXLMCommon.UserInput(
880906
chat: chat,
881907
processing: .init(resize: .init(width: 512, height: 512)),
882-
tools: nil
908+
tools: nil,
909+
additionalContext: additionalContext,
883910
)
884911
let lmInput = try await context.processor.prepare(input: userInput)
885912

@@ -1120,4 +1147,18 @@ import Foundation
11201147
return sampledToken.item(Int.self)
11211148
}
11221149
}
1150+
extension MLXLMCommon.JSONValue {
1151+
/// Recursively converts a `JSONValue` to its primitive Swift equivalent.
1152+
func toSendable() -> any Sendable {
1153+
switch self {
1154+
case .string(let s): return s
1155+
case .int(let i): return i
1156+
case .double(let d): return d
1157+
case .bool(let b): return b
1158+
case .null: return NSNull()
1159+
case .array(let arr): return arr.map { $0.toSendable() }
1160+
case .object(let obj): return obj.mapValues { $0.toSendable() }
1161+
}
1162+
}
1163+
}
11231164
#endif // MLX

Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,28 @@ import Testing
220220
#expect([Priority.low, Priority.medium, Priority.high].contains(response.content))
221221
}
222222

223+
@Test func withAdditionalContext() async throws {
224+
let session = LanguageModelSession(model: model)
225+
226+
var options = GenerationOptions(
227+
temperature: 0.7,
228+
maximumResponseTokens: 32
229+
)
230+
options[custom: MLXLanguageModel.self] = .init(
231+
additionalContext: [
232+
"user_name": .string("Alice"),
233+
"turn_count": .int(3),
234+
"verbose": .bool(true),
235+
]
236+
)
237+
238+
let response = try await session.respond(
239+
to: "Say hello",
240+
options: options
241+
)
242+
#expect(!response.content.isEmpty)
243+
}
244+
223245
@Test func unavailableForNonexistentModel() async {
224246
let model = MLXLanguageModel(modelId: "mlx-community/does-not-exist-anylanguagemodel-test")
225247
await model.removeFromCache()

0 commit comments

Comments
 (0)