@@ -183,6 +183,16 @@ import Foundation
183183 /// let model = MLXLanguageModel(modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit")
184184 /// ```
185185 public struct MLXLanguageModel : LanguageModel {
186+ /// Custom generation options for MLX models.
187+ public struct CustomGenerationOptions : AnyLanguageModel . CustomGenerationOptions {
188+ /// Additional key-value pairs injected into the chat template rendering context.
189+ public var additionalContext : [ String : MLXLMCommon . JSONValue ] ?
190+
191+ public init ( additionalContext: [ String : MLXLMCommon . JSONValue ] ? = nil ) {
192+ self . additionalContext = additionalContext
193+ }
194+ }
195+
186196 /// The reason the model is unavailable.
187197 public enum UnavailableReason : Sendable , Equatable , Hashable {
188198 /// The model has not been loaded into memory yet.
@@ -292,6 +302,11 @@ import Foundation
292302 // Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
293303 let generateParameters = toGenerateParameters ( options)
294304
305+ // Extract additional context from custom options
306+ let additionalContext : [ String : any Sendable ] ? = options [ custom: MLXLanguageModel . self]
307+ . flatMap { $0. additionalContext }
308+ . map { $0. mapValues { $0. toSendable ( ) } }
309+
295310 // Build chat history from full transcript
296311 var chat = convertTranscriptToMLXChat ( session: session, fallbackPrompt: prompt. description)
297312
@@ -305,6 +320,7 @@ import Foundation
305320 chat: chat,
306321 processing: . init( resize: . init( width: 512 , height: 512 ) ) ,
307322 tools: toolSpecs,
323+ additionalContext: additionalContext,
308324 )
309325 let lmInput = try await context. processor. prepare ( input: userInput)
310326
@@ -407,10 +423,15 @@ import Foundation
407423 let generateParameters = toGenerateParameters ( options)
408424 let chat = convertTranscriptToMLXChat ( session: session, fallbackPrompt: prompt. description)
409425
426+ let additionalContext : [ String : any Sendable ] ? = options [ custom: MLXLanguageModel . self]
427+ . flatMap { $0. additionalContext }
428+ . map { $0. mapValues { $0. toSendable ( ) } }
429+
410430 let userInput = MLXLMCommon . UserInput (
411431 chat: chat,
412432 processing: . init( resize: . init( width: 512 , height: 512 ) ) ,
413- tools: nil
433+ tools: nil ,
434+ additionalContext: additionalContext
414435 )
415436 let lmInput = try await context. processor. prepare ( input: userInput)
416437
@@ -876,10 +897,16 @@ import Foundation
876897 let baseChat = convertTranscriptToMLXChat ( session: session, fallbackPrompt: prompt. description)
877898 let schemaPrompt = includeSchemaInPrompt ? schemaPrompt ( for: schema) : nil
878899 let chat = normalizeChatForStructuredGeneration ( baseChat, schemaPrompt: schemaPrompt)
900+
901+ let additionalContext : [ String : any Sendable ] ? = options [ custom: MLXLanguageModel . self]
902+ . flatMap { $0. additionalContext }
903+ . map { $0. mapValues { $0. toSendable ( ) } }
904+
879905 let userInput = MLXLMCommon . UserInput (
880906 chat: chat,
881907 processing: . init( resize: . init( width: 512 , height: 512 ) ) ,
882- tools: nil
908+ tools: nil ,
909+ additionalContext: additionalContext,
883910 )
884911 let lmInput = try await context. processor. prepare ( input: userInput)
885912
@@ -1120,4 +1147,18 @@ import Foundation
11201147 return sampledToken. item ( Int . self)
11211148 }
11221149 }
1150+ extension MLXLMCommon . JSONValue {
1151+ /// Recursively converts a `JSONValue` to its primitive Swift equivalent.
1152+ func toSendable( ) -> any Sendable {
1153+ switch self {
1154+ case . string( let s) : return s
1155+ case . int( let i) : return i
1156+ case . double( let d) : return d
1157+ case . bool( let b) : return b
1158+ case . null: return NSNull ( )
1159+ case . array( let arr) : return arr. map { $0. toSendable ( ) }
1160+ case . object( let obj) : return obj. mapValues { $0. toSendable ( ) }
1161+ }
1162+ }
1163+ }
11231164#endif // MLX
0 commit comments