From 83b8551d5aed1071965534a143141b6e375cdc30 Mon Sep 17 00:00:00 2001 From: eastriver Date: Sat, 20 Dec 2025 12:16:01 +0900 Subject: [PATCH 1/7] Fix SystemLanguageModel to pass schema for structured generation --- .../AnyLanguageModel/GenerationSchema.swift | 19 +- .../Models/SystemLanguageModel.swift | 129 ++++-- .../GenerableMacroTests.swift | 64 +++ .../StructuredGenerationTests.swift | 397 ++++++++++++++++++ 4 files changed, 573 insertions(+), 36 deletions(-) create mode 100644 Tests/AnyLanguageModelTests/StructuredGenerationTests.swift diff --git a/Sources/AnyLanguageModel/GenerationSchema.swift b/Sources/AnyLanguageModel/GenerationSchema.swift index a8065d96..c5a59126 100644 --- a/Sources/AnyLanguageModel/GenerationSchema.swift +++ b/Sources/AnyLanguageModel/GenerationSchema.swift @@ -160,6 +160,16 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible ) } } + + var nodeDescription: String? { + switch self { + case .object(let node): node.description + case .array(let node): node.description + case .string(let node): node.description + case .number(let node): node.description + case .boolean, .anyOf, .ref: nil + } + } } struct ObjectNode: Sendable, Codable { @@ -204,7 +214,7 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible } let root: Node - private var defs: [String: Node] + var defs: [String: Node] /// A string representation of the debug description. /// @@ -703,6 +713,13 @@ extension GenerationSchema { } else { // Complex type - use its schema let schema = Value.generationSchema + + // Arrays should be inlined, not referenced + if case .array(var arrayNode) = schema.root { + arrayNode.description = description + return (.array(arrayNode), schema.defs) + } + let typeName = String(reflecting: Value.self) var deps = schema.defs diff --git a/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift b/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift index df0b9f54..0360e998 100644 --- a/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift @@ -81,17 +81,19 @@ transcript: session.transcript.toFoundationModels(instructions: session.instructions) ) - let fmResponse = try await fmSession.respond(to: fmPrompt, options: fmOptions) - let generatedContent = GeneratedContent(fmResponse.content) - if type == String.self { + let fmResponse = try await fmSession.respond(to: fmPrompt, options: fmOptions) + let generatedContent = GeneratedContent(fmResponse.content) return LanguageModelSession.Response( content: fmResponse.content as! Content, rawContent: generatedContent, transcriptEntries: [] ) } else { - // For non-String types, try to create an instance from the generated content + // For non-String types, use schema-based structured generation + let schema = FoundationModels.GenerationSchema(type.generationSchema) + let fmResponse = try await fmSession.respond(to: fmPrompt, schema: schema, options: fmOptions) + let generatedContent = try AnyLanguageModel.GeneratedContent(fmResponse.content) let content = try type.init(generatedContent) return LanguageModelSession.Response( @@ -321,25 +323,15 @@ extension FoundationModels.GenerationSchema { internal init(_ content: AnyLanguageModel.GenerationSchema) { let resolvedSchema = content.withResolvedRoot() ?? content - - let rawParameters = try? JSONValue(resolvedSchema) - var schema: FoundationModels.GenerationSchema? = nil - if rawParameters?.objectValue is [String: JSONValue] { - if let data = try? 
JSONEncoder().encode(rawParameters) { - if let jsonSchema = try? JSONDecoder().decode(JSONSchema.self, from: data) { - let dynamicSchema = convertToDynamicSchema(jsonSchema) - schema = try? FoundationModels.GenerationSchema(root: dynamicSchema, dependencies: []) - } - } + let dynamicSchema = convertToDynamicSchema(resolvedSchema.root) + let dependencies = resolvedSchema.defs.map { name, node in + convertToDynamicSchema(node, name: name) } - if let schema = schema { + + if let schema = try? FoundationModels.GenerationSchema(root: dynamicSchema, dependencies: dependencies) { self = schema } else { - self = FoundationModels.GenerationSchema( - type: String.self, - properties: [] - ) - + self = FoundationModels.GenerationSchema(type: String.self, properties: []) } } } @@ -368,6 +360,78 @@ } } + @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) + func convertToDynamicSchema( + _ node: GenerationSchema.Node, + name: String? = nil + ) -> FoundationModels.DynamicGenerationSchema { + switch node { + case .object(let objectNode): + return .init( + name: name ?? "", + description: objectNode.description, + properties: objectNode.properties.map { key, value in + .init( + name: key, + description: value.nodeDescription, + schema: convertToDynamicSchema(value), + isOptional: !objectNode.required.contains(key) + ) + } + ) + + case .string(let stringNode): + if let enumChoices = stringNode.enumChoices, !enumChoices.isEmpty { + return .init( + name: name ?? "", + description: stringNode.description, + anyOf: enumChoices.map { .init(type: String.self, guides: [.constant($0)]) } + ) + } + if let pattern = stringNode.pattern, let regex = try? Regex(pattern) { + return .init(type: String.self, guides: [.pattern(regex)]) + } + return .init(type: String.self) + + case .number(let numberNode): + return numberNode.integerOnly + ? 
.init(type: Int.self, guides: intGuides(numberNode)) + : .init(type: Double.self, guides: doubleGuides(numberNode)) + + case .boolean: + return .init(type: Bool.self) + + case .array(let arrayNode): + return .init( + arrayOf: convertToDynamicSchema(arrayNode.items), + minimumElements: arrayNode.minItems, + maximumElements: arrayNode.maxItems + ) + + case .anyOf(let nodes): + return .init(name: "", anyOf: nodes.map { convertToDynamicSchema($0) }) + + case .ref(let refName): + return .init(referenceTo: refName) + } + } + + @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) + private func intGuides(_ numberNode: GenerationSchema.NumberNode) -> [FoundationModels.GenerationGuide] { + var guides: [FoundationModels.GenerationGuide] = [] + if let minimum = numberNode.minimum { guides.append(.minimum(Int(minimum))) } + if let maximum = numberNode.maximum { guides.append(.maximum(Int(maximum))) } + return guides + } + + @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) + private func doubleGuides(_ numberNode: GenerationSchema.NumberNode) -> [FoundationModels.GenerationGuide] { + var guides: [FoundationModels.GenerationGuide] = [] + if let minimum = numberNode.minimum { guides.append(.minimum(minimum)) } + if let maximum = numberNode.maximum { guides.append(.maximum(maximum)) } + return guides + } + @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) func convertToDynamicSchema(_ jsonSchema: JSONSchema) -> FoundationModels.DynamicGenerationSchema { switch jsonSchema { @@ -378,10 +442,13 @@ return .init(name: "", description: jsonSchema.description, properties: schemaProperties) case .string(_, _, _, _, _, _, _, _, pattern: let pattern, _): - var guides: [FoundationModels.GenerationGuide] = [] - if let values = jsonSchema.enum?.compactMap(\.stringValue), !values.isEmpty { - guides.append(.anyOf(values)) + if let enumValues = jsonSchema.enum?.compactMap(\.stringValue), !enumValues.isEmpty { + let enumSchemas = enumValues.map { + FoundationModels.DynamicGenerationSchema(type: String.self, guides: [.constant($0)]) + } + return .init(name: "", description: jsonSchema.description, anyOf: enumSchemas) } + var guides: [FoundationModels.GenerationGuide] = [] if let value = jsonSchema.const?.stringValue { guides.append(.constant(value)) } @@ -397,12 +464,8 @@ } var guides: [FoundationModels.GenerationGuide] = [] - if let min = minimum { - guides.append(.minimum(min)) - } - if let max = maximum { - guides.append(.maximum(max)) - } + if let minimum { guides.append(.minimum(minimum)) } + if let maximum { guides.append(.maximum(maximum)) } if let value = jsonSchema.const?.intValue { guides.append(.range(value ... value)) } @@ -415,12 +478,8 @@ } var guides: [FoundationModels.GenerationGuide] = [] - if let min = minimum { - guides.append(.minimum(min)) - } - if let max = maximum { - guides.append(.maximum(max)) - } + if let minimum { guides.append(.minimum(minimum)) } + if let maximum { guides.append(.maximum(maximum)) } if let value = jsonSchema.const?.doubleValue { guides.append(.range(value ... 
value)) } diff --git a/Tests/AnyLanguageModelTests/GenerableMacroTests.swift b/Tests/AnyLanguageModelTests/GenerableMacroTests.swift index 1e6c12e7..5c6c242b 100644 --- a/Tests/AnyLanguageModelTests/GenerableMacroTests.swift +++ b/Tests/AnyLanguageModelTests/GenerableMacroTests.swift @@ -34,6 +34,38 @@ struct TestArguments { var age: Int } +@Generable +private enum TestEnum: Equatable { + case optionA + case optionB + case optionC +} + +@Generable +private struct TestNestedInner: Equatable { + var value: String + var count: Int +} + +@Generable +private struct TestNestedOuter: Equatable { + var name: String + var inner: TestNestedInner +} + +@Generable +private struct TestStructWithEnum: Equatable { + var label: String + var choice: TestEnum +} + +@Generable +private struct TestStructWithArray: Equatable { + var title: String + @Guide(.count(3)) + var items: [String] +} + @Suite("Generable Macro") struct GenerableMacroTests { @Test("@Guide description with multiline string") @@ -135,4 +167,36 @@ struct GenerableMacroTests { #expect(args.name == "Bob") #expect(args.age == 25) } + + @Test("Enum round-trip conversion") + func enumRoundTrip() throws { + for choice in [TestEnum.optionA, TestEnum.optionB, TestEnum.optionC] { + let restored = try TestEnum(choice.generatedContent) + #expect(choice == restored) + } + } + + @Test("Nested struct round-trip conversion") + func nestedStructRoundTrip() throws { + let original = TestNestedOuter( + name: "outer", + inner: TestNestedInner(value: "inner", count: 42) + ) + let restored = try TestNestedOuter(original.generatedContent) + #expect(original == restored) + } + + @Test("Struct with enum round-trip conversion") + func structWithEnumRoundTrip() throws { + let original = TestStructWithEnum(label: "test", choice: .optionB) + let restored = try TestStructWithEnum(original.generatedContent) + #expect(original == restored) + } + + @Test("Struct with array round-trip conversion") + func structWithArrayRoundTrip() throws { + let original = TestStructWithArray(title: "list", items: ["a", "b", "c"]) + let restored = try TestStructWithArray(original.generatedContent) + #expect(original == restored) + } } diff --git a/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift b/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift new file mode 100644 index 00000000..7ed1dc9a --- /dev/null +++ b/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift @@ -0,0 +1,397 @@ +import Foundation +import Testing + +@testable import AnyLanguageModel + +@Generable +enum Priority: Equatable { + case low + case medium + case high +} + +@Generable +struct SimpleString: Equatable { + @Guide(description: "A greeting message") + var message: String +} + +@Generable +struct SimpleInt: Equatable { + @Guide(description: "A count value", .minimum(0)) + var count: Int +} + +@Generable +struct SimpleBool: Equatable { + @Guide(description: "A boolean flag") + var value: Bool +} + +@Generable +struct SimpleDouble: Equatable { + @Guide(description: "A temperature value") + var temperature: Double +} + +@Generable +struct OptionalFields: Equatable { + @Guide(description: "A required name") + var name: String + + @Guide(description: "An optional nickname") + var nickname: String? 
+} + +@Generable +struct BasicStruct: Equatable { + @Guide(description: "Person's name") + var name: String + + @Guide(description: "Person's age", .minimum(0)) + var age: Int + + @Guide(description: "Is the person active") + var isActive: Bool + + @Guide(description: "Score value") + var score: Double +} + +@Generable +struct Address: Equatable { + @Guide(description: "Street name") + var street: String + + @Guide(description: "City name") + var city: String + + @Guide(description: "Postal code") + var postalCode: String +} + +@Generable +struct Person: Equatable { + @Guide(description: "Person's name") + var name: String + + @Guide(description: "Person's age") + var age: Int + + var address: Address +} + +@Generable +struct TaskItem: Equatable { + @Guide(description: "Task title") + var title: String + + var priority: Priority + + @Guide(description: "Is completed") + var isCompleted: Bool +} + +@Generable +struct SimpleArray: Equatable { + @Guide(description: "A list of color names") + var colors: [String] +} + +@Generable +struct MultiChoiceQuestion: Equatable { + @Guide(description: "The quiz question") + var text: String + + @Guide(.count(4)) + var choices: [String] + + var answer: String + + @Guide(description: "A brief explanation of why the answer is correct") + var explanation: String +} + +private struct SupportedModel: Sendable { + let name: String + let model: any LanguageModel + + static var all: [SupportedModel] { + var models: [SupportedModel] = [] + + #if canImport(FoundationModels) + if #available(macOS 26.0, *) { + if SystemLanguageModel.default.isAvailable { + models.append(SupportedModel(name: "SystemLanguageModel", model: SystemLanguageModel.default)) + } + } + #endif + + #if Llama + if let modelPath = ProcessInfo.processInfo.environment["LLAMA_MODEL_PATH"] { + models.append(SupportedModel(name: "LlamaLanguageModel", model: LlamaLanguageModel(modelPath: modelPath))) + } + #endif + + #if MLX + let shouldRunMLX = ProcessInfo.processInfo.environment["ENABLE_MLX_TESTS"] != nil + || (ProcessInfo.processInfo.environment["CI"] == nil + && ProcessInfo.processInfo.environment["HF_TOKEN"] != nil + && ProcessInfo.processInfo.environment["XCTestConfigurationFilePath"] != nil) + if shouldRunMLX { + models.append( + SupportedModel( + name: "MLXLanguageModel", + model: MLXLanguageModel(modelId: "mlx-community/Qwen3-0.6B-4bit") + ) + ) + } + #endif + + return models + } +} + +private let supportedModels = SupportedModel.all + +private func isGenerationTestsEnabled() -> Bool { + !supportedModels.isEmpty +} + +private func testAllModels(_ test: (SupportedModel) async throws -> Void) async { + var failures: [(name: String, error: any Error)] = [] + + for model in supportedModels { + print("Testing: \(model.name)") + do { + try await test(model) + print(" ✓ \(model.name) passed") + } catch { + print(" ✗ \(model.name) failed: \(error)") + failures.append((model.name, error)) + } + } + + for failure in failures { + Issue.record("[\(failure.name)] \(failure.error)") + } +} + +@Suite("Structured Generation", .serialized, .enabled(if: isGenerationTestsEnabled())) +struct StructuredGenerationTests { + @Test("Generate SimpleString with all supported models") + func generateSimpleString() async { + await testAllModels { model in + let session = LanguageModelSession( + model: model.model, + instructions: "You are a helpful assistant that generates structured data." 
+ ) + + let response = try await session.respond( + to: "Generate a greeting message that says hello", + generating: SimpleString.self + ) + + #expect(!response.content.message.isEmpty, "[\(model.name)] message should not be empty") + } + } + + @Test("Generate SimpleInt with all supported models") + func generateSimpleInt() async { + await testAllModels { model in + let session = LanguageModelSession( + model: model.model, + instructions: "You are a helpful assistant that generates structured data." + ) + + let response = try await session.respond( + to: "Generate a count value of 42", + generating: SimpleInt.self + ) + + #expect(response.content.count >= 0, "[\(model.name)] count should be non-negative") + } + } + + @Test("Generate SimpleDouble with all supported models") + func generateSimpleDouble() async { + await testAllModels { model in + let session = LanguageModelSession( + model: model.model, + instructions: "You are a helpful assistant that generates structured data." + ) + + let response = try await session.respond( + to: "Generate a temperature value of 72.5 degrees", + generating: SimpleDouble.self + ) + + #expect(!response.content.temperature.isNaN, "[\(model.name)] temperature should be a valid number") + } + } + + @Test("Generate SimpleBool with all supported models") + func generateSimpleBool() async { + await testAllModels { model in + let session = LanguageModelSession( + model: model.model, + instructions: "You are a helpful assistant that generates structured data." + ) + + let response = try await session.respond( + to: "Generate a boolean value: true", + generating: SimpleBool.self + ) + + let jsonData = response.rawContent.jsonString.data(using: .utf8) + #expect(jsonData != nil, "[\(model.name)] rawContent should be valid UTF-8 JSON") + if let jsonData { + let json = try JSONSerialization.jsonObject(with: jsonData) + let dictionary = json as? [String: Any] + let boolValue = dictionary?["value"] as? Bool + #expect(boolValue != nil, "[\(model.name)] value should be encoded as a JSON boolean") + } + } + } + + @Test("Generate OptionalFields with all supported models") + func generateOptionalFields() async { + await testAllModels { model in + let session = LanguageModelSession( + model: model.model, + instructions: "You are a helpful assistant that generates structured data." + ) + + let response = try await session.respond( + to: "Generate a person named Alex with nickname 'Lex'. Nickname may be omitted if unsure.", + generating: OptionalFields.self + ) + + #expect(!response.content.name.isEmpty, "[\(model.name)] name should not be empty") + if let nickname = response.content.nickname { + #expect(!nickname.isEmpty, "[\(model.name)] nickname should not be empty when present") + } + } + } + + @Test("Generate Priority enum with all supported models") + func generatePriorityEnum() async { + await testAllModels { model in + let session = LanguageModelSession( + model: model.model, + instructions: "You are a helpful assistant that generates structured data." 
+ ) + + let response = try await session.respond( + to: "Generate a high priority value", + generating: Priority.self + ) + + #expect( + [Priority.low, Priority.medium, Priority.high].contains(response.content), + "[\(model.name)] should generate valid priority" + ) + } + } + + @Test("Generate BasicStruct with all supported models") + func generateBasicStruct() async { + await testAllModels { model in + let session = LanguageModelSession( + model: model.model, + instructions: "You are a helpful assistant that generates structured data." + ) + + let response = try await session.respond( + to: "Generate a person with name Alice, age 30, active status true, and score 95.5", + generating: BasicStruct.self + ) + + #expect(!response.content.name.isEmpty, "[\(model.name)] name should not be empty") + #expect(response.content.age >= 0, "[\(model.name)] age should be non-negative") + } + } + + @Test("Generate nested struct (Person with Address) with all supported models") + func generateNestedStruct() async { + await testAllModels { model in + let session = LanguageModelSession( + model: model.model, + instructions: "You are a helpful assistant that generates structured data." + ) + + let response = try await session.respond( + to: "Generate a person named John, age 25, living at 123 Main St, Springfield, 12345", + generating: Person.self + ) + + #expect(!response.content.name.isEmpty, "[\(model.name)] name should not be empty") + #expect(response.content.age >= 0, "[\(model.name)] age should be non-negative") + #expect(!response.content.address.street.isEmpty, "[\(model.name)] street should not be empty") + #expect(!response.content.address.city.isEmpty, "[\(model.name)] city should not be empty") + } + } + + @Test("Generate struct with enum (TaskItem) with all supported models") + func generateStructWithEnum() async { + await testAllModels { model in + let session = LanguageModelSession( + model: model.model, + instructions: "You are a helpful assistant that generates structured data." + ) + + let response = try await session.respond( + to: "Generate a task titled 'Complete project' with high priority, not completed", + generating: TaskItem.self + ) + + #expect(!response.content.title.isEmpty, "[\(model.name)] title should not be empty") + #expect( + [Priority.low, Priority.medium, Priority.high].contains(response.content.priority), + "[\(model.name)] should have valid priority" + ) + } + } + + @Test("Generate simple array with all supported models") + func generateSimpleArray() async { + await testAllModels { model in + let session = LanguageModelSession( + model: model.model, + instructions: "You are a helpful assistant that generates structured data." + ) + + let response = try await session.respond( + to: "Generate a list of 3 color names: red, green, blue", + generating: SimpleArray.self + ) + + #expect(!response.content.colors.isEmpty, "[\(model.name)] colors should not be empty") + } + } + + @Test("Generate struct with array (MultiChoiceQuestion) with all supported models") + func generateStructWithArray() async { + await testAllModels { model in + let session = LanguageModelSession( + model: model.model, + instructions: "You are a helpful assistant that generates structured data." + ) + + let response = try await session.respond( + to: """ + Generate a quiz question: + - Question: What is the capital of France? 
+ - Choices: London, Paris, Berlin, Madrid + - Answer: Paris + - Explanation: Paris is the capital city of France + """, + generating: MultiChoiceQuestion.self + ) + + #expect(!response.content.text.isEmpty, "[\(model.name)] question text should not be empty") + #expect(response.content.choices.count == 4, "[\(model.name)] should have exactly 4 choices") + #expect(!response.content.answer.isEmpty, "[\(model.name)] answer should not be empty") + } + } +} From a6fb7130fcf185294cf3452ac2b1dc8359961a11 Mon Sep 17 00:00:00 2001 From: eastriver Date: Sat, 20 Dec 2025 16:04:43 +0900 Subject: [PATCH 2/7] Implement logit-constrained structured generation for LlamaLanguageModel --- .../Models/LlamaLanguageModel.swift | 403 +++++++++++++++++- .../StructuredGenerationTests.swift | 3 - 2 files changed, 383 insertions(+), 23 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index a57da336..811f98ef 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -478,14 +478,8 @@ import Foundation includeSchemaInPrompt: Bool, options: GenerationOptions ) async throws -> LanguageModelSession.Response where Content: Generable { - // For now, only String is supported - guard type == String.self else { - fatalError("LlamaLanguageModel only supports generating String content") - } - // Validate that no image segments are present try validateNoImageSegments(in: session) - try await ensureModelLoaded() let runtimeOptions = resolvedOptions(from: options) @@ -495,7 +489,6 @@ import Foundation guard let context = llama_init_from_model(model!, contextParams) else { throw LlamaLanguageModelError.contextInitializationFailed } - defer { llama_free(context) } // Check if this is an embedding model (no KV cache). @@ -510,22 +503,44 @@ import Foundation llama_set_warmup(context, false) llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads) - let maxTokens = runtimeOptions.maximumResponseTokens ?? 100 let fullPrompt = try formatPrompt(for: session) - let text = try await generateText( - context: context, - model: model!, - prompt: fullPrompt, - maxTokens: maxTokens, - options: runtimeOptions - ) + if type == String.self { + let maxTokens = runtimeOptions.maximumResponseTokens ?? 100 + let text = try await generateText( + context: context, + model: model!, + prompt: fullPrompt, + maxTokens: maxTokens, + options: runtimeOptions + ) - return LanguageModelSession.Response( - content: text as! Content, - rawContent: GeneratedContent(text), - transcriptEntries: ArraySlice([]) - ) + return LanguageModelSession.Response( + content: text as! Content, + rawContent: GeneratedContent(text), + transcriptEntries: ArraySlice([]) + ) + } else { + let maxTokens = runtimeOptions.maximumResponseTokens ?? 
512 + let schema = type.generationSchema + let jsonString = try await generateStructuredJSON( + context: context, + model: model!, + prompt: fullPrompt, + schema: schema, + maxTokens: maxTokens, + options: runtimeOptions + ) + + let generatedContent = try GeneratedContent(json: jsonString) + let content = try type.init(generatedContent) + + return LanguageModelSession.Response( + content: content, + rawContent: generatedContent, + transcriptEntries: ArraySlice([]) + ) + } } public func streamResponse( @@ -840,6 +855,354 @@ import Foundation return generatedText } + // MARK: - Structured JSON Generation (logit constrained) + + private func generateStructuredJSON( + context: OpaquePointer, + model: OpaquePointer, + prompt: String, + schema: GenerationSchema, + maxTokens: Int, + options: ResolvedGenerationOptions + ) async throws -> String { + guard let vocab = llama_model_get_vocab(model) else { + throw LlamaLanguageModelError.contextInitializationFailed + } + + let promptTokens = try tokenizeText(vocab: vocab, text: prompt) + guard !promptTokens.isEmpty else { + throw LlamaLanguageModelError.tokenizationFailed + } + + var batch = llama_batch_init(Int32(options.batchSize), 0, 1) + defer { llama_batch_free(batch) } + + let hasEncoder = try prepareInitialBatch( + batch: &batch, + promptTokens: promptTokens, + model: model, + vocab: vocab, + context: context, + batchSize: options.batchSize + ) + + guard let sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()) else { + throw LlamaLanguageModelError.decodingFailed + } + defer { llama_sampler_free(sampler) } + let samplerPointer = UnsafeMutablePointer(sampler) + + applySampling(sampler: samplerPointer, effectiveTemperature: options.temperature, options: options) + + let vocabSize = Int(llama_vocab_n_tokens(vocab)) + let initialPosition: Int32 = hasEncoder ? 1 : batch.n_tokens + + return try withUnsafeMutablePointer(to: &batch) { batchPointer in + let generator = try StructuredJSONGenerator( + context: context, + vocab: vocab, + vocabSize: vocabSize, + sampler: samplerPointer, + tokenToText: { token in self.tokenToText(vocab: vocab, token: token) }, + batch: batchPointer, + initialPosition: initialPosition, + maximumTokens: maxTokens, + schema: schema + ) + return try generator.generate() + } + } + + private struct StructuredJSONGenerator { + let context: OpaquePointer + let vocab: OpaquePointer + let vocabSize: Int + let sampler: UnsafeMutablePointer + let tokenToText: (llama_token) -> String? 
+        let batch: UnsafeMutablePointer<llama_batch>
+        let schema: GenerationSchema
+
+        var position: Int32
+        var remainingTokens: Int
+
+        let quoteToken: llama_token
+        let digitOnlyTokens: Set<llama_token>
+        let validStringTokens: Set<llama_token>
+        let validStringTokensOrQuote: Set<llama_token>
+
+        init(
+            context: OpaquePointer,
+            vocab: OpaquePointer,
+            vocabSize: Int,
+            sampler: UnsafeMutablePointer<llama_sampler>,
+            tokenToText: @escaping (llama_token) -> String?,
+            batch: UnsafeMutablePointer<llama_batch>,
+            initialPosition: Int32,
+            maximumTokens: Int,
+            schema: GenerationSchema
+        ) throws {
+            self.context = context
+            self.vocab = vocab
+            self.vocabSize = vocabSize
+            self.sampler = sampler
+            self.tokenToText = tokenToText
+            self.batch = batch
+            self.position = initialPosition
+            self.remainingTokens = maximumTokens
+            self.schema = schema
+
+            guard let quoteToken = try StructuredJSONGenerator.tokenizeFragment(vocab: vocab, "\"").first else {
+                throw LlamaLanguageModelError.tokenizationFailed
+            }
+            self.quoteToken = quoteToken
+            self.digitOnlyTokens = StructuredJSONGenerator.buildDigitOnlyTokens(vocabSize: vocabSize, tokenToText: tokenToText)
+            self.validStringTokens = StructuredJSONGenerator.buildValidJSONStringContentTokens(
+                vocabSize: vocabSize,
+                tokenToText: tokenToText
+            )
+            var tokensOrQuote = self.validStringTokens
+            tokensOrQuote.insert(quoteToken)
+            self.validStringTokensOrQuote = tokensOrQuote
+        }
+
+        func generate() throws -> String {
+            var generator = self
+            return try generator.generateNode(schema.root)
+        }
+
+        private static func buildDigitOnlyTokens(
+            vocabSize: Int,
+            tokenToText: (llama_token) -> String?
+        ) -> Set<llama_token> {
+            Set((0 ..< vocabSize).compactMap { tokenIndex in
+                let token = llama_token(tokenIndex)
+                guard let text = tokenToText(token), !text.isEmpty else { return nil }
+                return text.allSatisfy({ $0.isNumber }) ? token : nil
+            })
+        }
+
+        private static func buildValidJSONStringContentTokens(
+            vocabSize: Int,
+            tokenToText: (llama_token) -> String?
+        ) -> Set<llama_token> {
+            var allowed = Set<llama_token>()
+            allowed.reserveCapacity(vocabSize / 4)
+
+            for tokenIndex in 0 ..< vocabSize {
+                let token = llama_token(tokenIndex)
+                guard let text = tokenToText(token), !text.isEmpty else { continue }
+                if text.contains("\"") { continue }
+                if text.contains("\\") { continue }
+                if text.unicodeScalars.contains(where: { $0.value < 0x20 }) { continue }
+                allowed.insert(token)
+            }
+            return allowed
+        }
+
+        private static func tokenizeFragment(vocab: OpaquePointer, _ text: String) throws -> [llama_token] {
+            let utf8Count = text.utf8.count
+            let capacity = Int32(max(utf8Count * 2, 8))
+            let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: Int(capacity))
+            defer { tokens.deallocate() }
+
+            let tokenCount = llama_tokenize(
+                vocab,
+                text,
+                Int32(utf8Count),
+                tokens,
+                capacity,
+                false,
+                false
+            )
+            guard tokenCount > 0 else { return [] }
+            return Array(UnsafeBufferPointer(start: tokens, count: Int(tokenCount)))
+        }
+
+        private mutating func decodeToken(_ token: llama_token) throws {
+            batch.pointee.n_tokens = 1
+            batch.pointee.token[0] = token
+            batch.pointee.pos[0] = position
+            batch.pointee.n_seq_id[0] = 1
+            if let seqIds = batch.pointee.seq_id, let seqId = seqIds[0] {
+                seqId[0] = 0
+            }
+            batch.pointee.logits[0] = 1
+
+            position += 1
+            remainingTokens -= 1
+
+            let decodeResult = llama_decode(context, batch.pointee)
+            guard decodeResult == 0 else {
+                throw LlamaLanguageModelError.decodingFailed
+            }
+        }
+
+        private mutating func emitLiteral(_ text: String) throws -> String {
+            for token in try StructuredJSONGenerator.tokenizeFragment(vocab: vocab, text) {
+                guard remainingTokens > 0 else { throw LlamaLanguageModelError.decodingFailed }
+                try decodeToken(token)
+            }
+            return text
+        }
+
+        private mutating func sampleToken(allowedTokens: Set<llama_token>) -> llama_token {
+            guard let logits = llama_get_logits(context) else {
+                return llama_vocab_eos(vocab)
+            }
+
+            for tokenIndex in 0 ..< vocabSize {
+                let token = llama_token(tokenIndex)
+                if !allowedTokens.contains(token) {
+                    logits[tokenIndex] = -Float.infinity
+                }
+            }
+
+            let tokenIndex = batch.pointee.n_tokens - 1
+            let sampled = llama_sampler_sample(sampler, context, tokenIndex)
+            llama_sampler_accept(sampler, sampled)
+            return sampled
+        }
+
+        private mutating func generateFreeString(maxTokens: Int) throws -> String {
+            var result = ""
+            var generatedTokens = 0
+
+            while remainingTokens > 0, generatedTokens < maxTokens {
+                let allowedTokens = result.isEmpty ? validStringTokens : validStringTokensOrQuote
+                let token = sampleToken(allowedTokens: allowedTokens)
+                if token == quoteToken { break }
+
+                result += tokenToText(token) ?? ""
+                generatedTokens += 1
+                try decodeToken(token)
+            }
+
+            return result
+        }
+
+        private mutating func generateLiteralChoice(_ candidates: [String]) throws -> String {
+            let tokenizedCandidates = try candidates.map { try StructuredJSONGenerator.tokenizeFragment(vocab: vocab, $0) }
+                .filter { !$0.isEmpty }
+            guard !tokenizedCandidates.isEmpty else { throw LlamaLanguageModelError.tokenizationFailed }
+
+            var prefixes = tokenizedCandidates
+            var emitted = ""
+            var tokenPosition = 0
+
+            while remainingTokens > 0 {
+                if prefixes.contains(where: { $0.count == tokenPosition }) { break }
+
+                let allowed = Set(prefixes.compactMap { tokens -> llama_token? in
+                    guard tokenPosition < tokens.count else { return nil }
+                    return tokens[tokenPosition]
+                })
+
+                let nextToken = sampleToken(allowedTokens: allowed)
+                emitted += tokenToText(nextToken) ??
"" + try decodeToken(nextToken) + + prefixes = prefixes.filter { tokens in + tokenPosition < tokens.count && tokens[tokenPosition] == nextToken + } + tokenPosition += 1 + if prefixes.isEmpty { break } + } + + return emitted + } + + private mutating func generateNumber(isInteger: Bool) throws -> String { + let maxTokens = isInteger ? 3 : 4 + var generated = "" + + for _ in 0 ..< maxTokens { + guard remainingTokens > 0 else { break } + let token = sampleToken(allowedTokens: digitOnlyTokens) + generated += tokenToText(token) ?? "" + try decodeToken(token) + if !generated.isEmpty { break } + } + + return generated.isEmpty ? "0" : generated + } + + private mutating func generateArray(_ arrayNode: GenerationSchema.ArrayNode) throws -> String { + let elementCount = arrayNode.minItems ?? arrayNode.maxItems ?? 4 + var output = try emitLiteral("[") + + for index in 0 ..< elementCount { + output += try generateNode(arrayNode.items) + if index < elementCount - 1 { + output += try emitLiteral(",") + } + } + + output += try emitLiteral("]") + return output + } + + private mutating func generateObject(_ objectNode: GenerationSchema.ObjectNode) throws -> String { + let keys = objectNode.properties.keys.sorted() + var output = try emitLiteral("{") + + for (index, key) in keys.enumerated() { + output += try emitLiteral("\"") + output += try emitLiteral(key) + output += try emitLiteral("\":") + + if let propertyNode = objectNode.properties[key] { + output += try generateNode(propertyNode) + } else { + output += try emitLiteral("null") + } + + if index < keys.count - 1 { + output += try emitLiteral(",") + } + } + + output += try emitLiteral("}") + return output + } + + private mutating func generateNode(_ node: GenerationSchema.Node) throws -> String { + guard remainingTokens > 0 else { throw LlamaLanguageModelError.decodingFailed } + + switch node { + case .string(let stringNode): + var output = try emitLiteral("\"") + if let enumChoices = stringNode.enumChoices, !enumChoices.isEmpty { + output += try generateLiteralChoice(enumChoices) + } else { + output += try generateFreeString(maxTokens: 12) + } + output += try emitLiteral("\"") + return output + + case .number(let numberNode): + return try generateNumber(isInteger: numberNode.integerOnly) + + case .boolean: + return try generateLiteralChoice(["true", "false"]) + + case .array(let arrayNode): + return try generateArray(arrayNode) + + case .object(let objectNode): + return try generateObject(objectNode) + + case .anyOf(let nodes): + // Pick the first option deterministically (schema-driven). 
+ guard let first = nodes.first else { throw LlamaLanguageModelError.decodingFailed } + return try generateNode(first) + + case .ref(let refName): + guard let referenced = schema.defs[refName] else { throw LlamaLanguageModelError.decodingFailed } + return try generateNode(referenced) + } + } + } + private func generateTextStream( context: OpaquePointer, model: OpaquePointer, diff --git a/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift b/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift index 7ed1dc9a..df70c031 100644 --- a/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift +++ b/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift @@ -162,12 +162,9 @@ private func testAllModels(_ test: (SupportedModel) async throws -> Void) async var failures: [(name: String, error: any Error)] = [] for model in supportedModels { - print("Testing: \(model.name)") do { try await test(model) - print(" ✓ \(model.name) passed") } catch { - print(" ✗ \(model.name) failed: \(error)") failures.append((model.name, error)) } } From 9b1d2db9e2b507ef25abfe3326be8873876db793 Mon Sep 17 00:00:00 2001 From: eastriver Date: Sat, 20 Dec 2025 17:24:06 +0900 Subject: [PATCH 3/7] Implement logit-constrained structured generation for MLXLanguageModel --- .../Models/LlamaLanguageModel.swift | 9 +- .../Models/MLXLanguageModel.swift | 477 +++++++++++++++++- 2 files changed, 468 insertions(+), 18 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index 811f98ef..af68c551 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -924,6 +924,7 @@ import Foundation var position: Int32 var remainingTokens: Int + let totalTokenBudget: Int let quoteToken: llama_token let digitOnlyTokens: Set @@ -949,6 +950,7 @@ import Foundation self.batch = batch self.position = initialPosition self.remainingTokens = maximumTokens + self.totalTokenBudget = maximumTokens self.schema = schema guard let quoteToken = try StructuredJSONGenerator.tokenizeFragment(vocab: vocab, "\"").first else { @@ -1063,6 +1065,11 @@ import Foundation return sampled } + private func maxTokenCountForFreeString() -> Int { + let perStringLimit = max(32, totalTokenBudget / 4) + return min(remainingTokens, perStringLimit) + } + private mutating func generateFreeString(maxTokens: Int) throws -> String { var result = "" var generatedTokens = 0 @@ -1174,7 +1181,7 @@ import Foundation if let enumChoices = stringNode.enumChoices, !enumChoices.isEmpty { output += try generateLiteralChoice(enumChoices) } else { - output += try generateFreeString(maxTokens: 12) + output += try generateFreeString(maxTokens: maxTokenCountForFreeString()) } output += try emitLiteral("\"") return output diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 77accfb5..aabf0b61 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -58,11 +58,6 @@ import Foundation includeSchemaInPrompt: Bool, options: GenerationOptions ) async throws -> LanguageModelSession.Response where Content: Generable { - // For now, only String is supported - guard type == String.self else { - fatalError("MLXLanguageModel only supports generating String content") - } - let context: ModelContext if let directory { context = try await loadModel(directory: directory) @@ -72,6 +67,25 
@@ import Foundation context = try await loadModel(id: modelId) } + if type != String.self { + let jsonString = try await generateStructuredJSON( + session: session, + prompt: prompt, + context: context, + options: options, + schema: type.generationSchema + ) + + let generatedContent = try GeneratedContent(json: jsonString) + let content = try type.init(generatedContent) + + return LanguageModelSession.Response( + content: content, + rawContent: generatedContent, + transcriptEntries: ArraySlice([]) + ) + } + // Convert session tools to MLX ToolSpec format let toolSpecs: [ToolSpec]? = session.tools.isEmpty @@ -168,7 +182,7 @@ import Foundation options: GenerationOptions ) -> sending LanguageModelSession.ResponseStream where Content: Generable { guard type == String.self else { - fatalError("MLXLanguageModel only supports generating String content") + fatalError("MLXLanguageModel streaming only supports String content") } let modelId = self.modelId @@ -189,7 +203,6 @@ import Foundation let generateParameters = toGenerateParameters(options) - // Build chat history from full transcript let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) let userInput = MLXLMCommon.UserInput( @@ -358,30 +371,74 @@ import Foundation private func convertToolToMLXSpec(_ tool: any Tool) -> ToolSpec { // Convert AnyLanguageModel's GenerationSchema to JSON-compatible dictionary - let parametersDict: [String: Any] + let parametersDict: [String: any Sendable] do { let resolvedSchema = tool.parameters.withResolvedRoot() ?? tool.parameters let encoder = JSONEncoder() let data = try encoder.encode(resolvedSchema) if let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] { - parametersDict = json + parametersDict = try convertToSendableJSONObject(json) } else { - parametersDict = ["type": "object", "properties": [:], "required": []] + parametersDict = makeEmptyJSONSchemaObject() } } catch { - parametersDict = ["type": "object", "properties": [:], "required": []] + parametersDict = makeEmptyJSONSchemaObject() } - return [ + let functionSpec: [String: any Sendable] = [ + "name": tool.name, + "description": tool.description, + "parameters": parametersDict, + ] + + let toolSpec: ToolSpec = [ "type": "function", - "function": [ - "name": tool.name, - "description": tool.description, - "parameters": parametersDict, - ], + "function": functionSpec, + ] + + return toolSpec + } + + private func makeEmptyJSONSchemaObject() -> [String: any Sendable] { + [ + "type": "object", + "properties": [String: any Sendable](), + "required": [String](), ] } + private func convertToSendableJSONObject(_ object: [String: Any]) throws -> [String: any Sendable] { + var converted: [String: any Sendable] = [:] + converted.reserveCapacity(object.count) + + for (key, value) in object { + converted[key] = try convertToSendableJSONValue(value) + } + return converted + } + + private func convertToSendableJSONValue(_ value: Any) throws -> any Sendable { + if value is NSNull { return MLXLMCommon.JSONValue.null } + if let stringValue = value as? String { return stringValue } + if let boolValue = value as? Bool { return boolValue } + if let intValue = value as? Int { return intValue } + if let doubleValue = value as? Double { return doubleValue } + if let numberValue = value as? NSNumber { + if CFGetTypeID(numberValue) == CFBooleanGetTypeID() { + return numberValue.boolValue + } + return numberValue.doubleValue + } + if let arrayValue = value as? 
[Any] { + return try arrayValue.map { try convertToSendableJSONValue($0) } + } + if let dictionaryValue = value as? [String: Any] { + return try convertToSendableJSONObject(dictionaryValue) + } + + throw StructuredGenerationError.invalidTokenization + } + // MARK: - Tool Invocation Handling private struct ToolInvocationResult { @@ -464,4 +521,390 @@ import Foundation } return textParts.joined(separator: "\n") } + + // MARK: - Structured JSON Generation (logit constrained) + + private enum StructuredGenerationError: Error { + case missingTokenizer + case emptyPrompt + case invalidQuoteToken + case invalidTokenization + case tokenBudgetExceeded + case invalidVocabSize + } + + private func generateStructuredJSON( + session: LanguageModelSession, + prompt: Prompt, + context: ModelContext, + options: GenerationOptions, + schema: GenerationSchema + ) async throws -> String { + let structuredMaxTokens = options.maximumResponseTokens ?? 512 + let generateParameters = toGenerateParameters(options) + + let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) + let userInput = MLXLMCommon.UserInput( + chat: chat, + processing: .init(resize: .init(width: 512, height: 512)), + tools: nil + ) + let lmInput = try await context.processor.prepare(input: userInput) + + var decoder = try MLXTokenDecoder( + context: context, + input: lmInput, + parameters: generateParameters, + maximumTokens: structuredMaxTokens + ) + + let vocabSize = decoder.vocabSize + var generator = try StructuredJSONGenerator( + schema: schema, + tokenizeFragment: { fragment in + context.tokenizer.encode(text: fragment, addSpecialTokens: false) + }, + tokenText: { token in + context.tokenizer.decode(tokens: [token], skipSpecialTokens: false) + }, + decodeToken: { token in + try decoder.decodeToken(token) + }, + sampleToken: { allowedTokens in + try decoder.sampleToken(allowedTokens: allowedTokens) + }, + maximumTokens: structuredMaxTokens, + vocabSize: vocabSize + ) + + let json = try generator.generate() + Stream().synchronize() + return json + } + + private struct MLXTokenDecoder { + let model: any MLXLMCommon.LanguageModel + var state: MLXLMCommon.LMOutput.State? + var cache: [MLXLMCommon.KVCache] + var processor: MLXLMCommon.LogitProcessor? + let sampler: MLXLMCommon.LogitSampler + + var currentLogits: MLXArray + let vocabSize: Int + + init( + context: ModelContext, + input: MLXLMCommon.LMInput, + parameters: MLXLMCommon.GenerateParameters, + maximumTokens: Int + ) throws { + self.model = context.model + self.state = nil + self.cache = context.model.newCache(parameters: parameters) + self.processor = parameters.processor() + self.sampler = parameters.sampler() + + processor?.prompt(input.text.tokens) + + let prepareResult = try context.model.prepare( + input, + cache: cache, + windowSize: parameters.prefillStepSize + ) + + let output: MLXLMCommon.LMOutput + switch prepareResult { + case .tokens(let tokensToProcess): + output = context.model( + tokensToProcess[text: .newAxis], + cache: cache, + state: state + ) + case .logits(let logitsOutput): + output = logitsOutput + } + + self.state = output.state + self.currentLogits = output.logits + + guard output.logits.shape.count >= 1 else { + throw StructuredGenerationError.invalidVocabSize + } + self.vocabSize = output.logits.shape.last ?? 
0
+            guard self.vocabSize > 0 else {
+                throw StructuredGenerationError.invalidVocabSize
+            }
+        }
+
+        mutating func decodeToken(_ token: Int) throws {
+            let tokenArray = MLXArray(token)
+            processor?.didSample(token: tokenArray)
+
+            let inputText = MLXLMCommon.LMInput.Text(tokens: tokenArray)
+            let output = model(
+                inputText[text: .newAxis],
+                cache: cache.isEmpty ? nil : cache,
+                state: state
+            )
+            state = output.state
+            currentLogits = output.logits
+        }
+
+        mutating func sampleToken(allowedTokens: Set<Int>) throws -> Int {
+            guard !allowedTokens.isEmpty else { throw StructuredGenerationError.invalidTokenization }
+
+            var logits = currentLogits[0..., -1, 0...]
+            logits = processor?.process(logits: logits) ?? logits
+            if logits.dtype == .bfloat16 {
+                logits = logits.asType(.float32)
+            }
+
+            let allowedIndices = MLXArray(allowedTokens.map { UInt32($0) })
+            let maskedLogits = full(logits.shape, values: -Float.infinity)
+            maskedLogits[0..., allowedIndices] = logits[0..., allowedIndices]
+
+            let sampledToken = sampler.sample(logits: maskedLogits)
+            processor?.didSample(token: sampledToken)
+            return sampledToken.item(Int.self)
+        }
+    }
+
+    private struct StructuredJSONGenerator {
+        let schema: GenerationSchema
+        let tokenizeFragment: (String) throws -> [Int]
+        let tokenText: (Int) -> String
+        let decodeToken: (Int) throws -> Void
+        let sampleToken: (Set<Int>) throws -> Int
+
+        var remainingTokens: Int
+        let totalTokenBudget: Int
+
+        let quoteToken: Int
+        let digitOnlyTokens: Set<Int>
+        let validStringTokens: Set<Int>
+        let validStringTokensOrQuote: Set<Int>
+
+        init(
+            schema: GenerationSchema,
+            tokenizeFragment: @escaping (String) throws -> [Int],
+            tokenText: @escaping (Int) -> String,
+            decodeToken: @escaping (Int) throws -> Void,
+            sampleToken: @escaping (Set<Int>) throws -> Int,
+            maximumTokens: Int,
+            vocabSize: Int
+        ) throws {
+            self.schema = schema
+            self.tokenizeFragment = tokenizeFragment
+            self.tokenText = tokenText
+            self.decodeToken = decodeToken
+            self.sampleToken = sampleToken
+            self.remainingTokens = maximumTokens
+            self.totalTokenBudget = maximumTokens
+
+            let quoteTokens = try tokenizeFragment("\"")
+            guard quoteTokens.count == 1, let quoteToken = quoteTokens.first else {
+                throw StructuredGenerationError.invalidQuoteToken
+            }
+            self.quoteToken = quoteToken
+
+            self.digitOnlyTokens = StructuredJSONGenerator.buildDigitOnlyTokens(
+                vocabSize: vocabSize,
+                tokenText: tokenText
+            )
+            self.validStringTokens = StructuredJSONGenerator.buildValidJSONStringContentTokens(
+                vocabSize: vocabSize,
+                tokenText: tokenText
+            )
+            var tokensOrQuote = self.validStringTokens
+            tokensOrQuote.insert(quoteToken)
+            self.validStringTokensOrQuote = tokensOrQuote
+        }
+
+        mutating func generate() throws -> String {
+            try generateNode(schema.root)
+        }
+
+        private func maxTokenCountForFreeString() -> Int {
+            let perStringLimit = max(32, totalTokenBudget / 4)
+            return min(remainingTokens, perStringLimit)
+        }
+
+        private static func buildDigitOnlyTokens(
+            vocabSize: Int,
+            tokenText: (Int) -> String
+        ) -> Set<Int> {
+            Set((0 ..< vocabSize).filter { tokenId in
+                let text = tokenText(tokenId)
+                guard !text.isEmpty else { return false }
+                return text.allSatisfy({ $0.isNumber })
+            })
+        }
+
+        private static func buildValidJSONStringContentTokens(
+            vocabSize: Int,
+            tokenText: (Int) -> String
+        ) -> Set<Int> {
+            var allowed = Set<Int>()
+            allowed.reserveCapacity(vocabSize / 4)
+
+            for tokenId in 0 ..< vocabSize {
+                let text = tokenText(tokenId)
+                guard !text.isEmpty else { continue }
+                if text.contains("\"") { continue }
+                if text.contains("\\") {
continue } + if text.unicodeScalars.contains(where: { $0.value < 0x20 }) { continue } + allowed.insert(tokenId) + } + return allowed + } + + private mutating func emitLiteral(_ text: String) throws -> String { + for token in try tokenizeFragment(text) { + guard remainingTokens > 0 else { throw StructuredGenerationError.tokenBudgetExceeded } + try decodeToken(token) + remainingTokens -= 1 + } + return text + } + + private mutating func generateFreeString(maxTokens: Int) throws -> String { + var result = "" + var generatedTokens = 0 + + while remainingTokens > 0, generatedTokens < maxTokens { + let allowedTokens = result.isEmpty ? validStringTokens : validStringTokensOrQuote + let token = try sampleToken(allowedTokens) + if token == quoteToken { break } + + result += tokenText(token) + generatedTokens += 1 + try decodeToken(token) + remainingTokens -= 1 + } + + return result + } + + private mutating func generateLiteralChoice(_ candidates: [String]) throws -> String { + let tokenizedCandidates = try candidates.map { try tokenizeFragment($0) }.filter { !$0.isEmpty } + guard !tokenizedCandidates.isEmpty else { throw StructuredGenerationError.invalidTokenization } + + var prefixes = tokenizedCandidates + var emitted = "" + var tokenPosition = 0 + + while remainingTokens > 0 { + if prefixes.contains(where: { $0.count == tokenPosition }) { break } + + let allowed = Set(prefixes.compactMap { tokens -> Int? in + guard tokenPosition < tokens.count else { return nil } + return tokens[tokenPosition] + }) + + let nextToken = try sampleToken(allowed) + emitted += tokenText(nextToken) + try decodeToken(nextToken) + remainingTokens -= 1 + + prefixes = prefixes.filter { tokens in + tokenPosition < tokens.count && tokens[tokenPosition] == nextToken + } + tokenPosition += 1 + if prefixes.isEmpty { break } + } + + return emitted + } + + private mutating func generateNumber(isInteger: Bool) throws -> String { + let maxTokens = isInteger ? 3 : 4 + var generated = "" + + for _ in 0 ..< maxTokens { + guard remainingTokens > 0 else { break } + let token = try sampleToken(digitOnlyTokens) + generated += tokenText(token) + try decodeToken(token) + remainingTokens -= 1 + if !generated.isEmpty { break } + } + + return generated.isEmpty ? "0" : generated + } + + private mutating func generateArray(_ arrayNode: GenerationSchema.ArrayNode) throws -> String { + let elementCount = arrayNode.minItems ?? arrayNode.maxItems ?? 
4 + var output = try emitLiteral("[") + + for index in 0 ..< elementCount { + output += try generateNode(arrayNode.items) + if index < elementCount - 1 { + output += try emitLiteral(",") + } + } + + output += try emitLiteral("]") + return output + } + + private mutating func generateObject(_ objectNode: GenerationSchema.ObjectNode) throws -> String { + let keys = objectNode.properties.keys.sorted() + var output = try emitLiteral("{") + + for (index, key) in keys.enumerated() { + output += try emitLiteral("\"") + output += try emitLiteral(key) + output += try emitLiteral("\":") + + if let propertyNode = objectNode.properties[key] { + output += try generateNode(propertyNode) + } else { + output += try emitLiteral("null") + } + + if index < keys.count - 1 { + output += try emitLiteral(",") + } + } + + output += try emitLiteral("}") + return output + } + + private mutating func generateNode(_ node: GenerationSchema.Node) throws -> String { + guard remainingTokens > 0 else { throw StructuredGenerationError.tokenBudgetExceeded } + + switch node { + case .string(let stringNode): + var output = try emitLiteral("\"") + if let enumChoices = stringNode.enumChoices, !enumChoices.isEmpty { + output += try generateLiteralChoice(enumChoices) + } else { + output += try generateFreeString(maxTokens: maxTokenCountForFreeString()) + } + output += try emitLiteral("\"") + return output + + case .number(let numberNode): + return try generateNumber(isInteger: numberNode.integerOnly) + + case .boolean: + return try generateLiteralChoice(["true", "false"]) + + case .array(let arrayNode): + return try generateArray(arrayNode) + + case .object(let objectNode): + return try generateObject(objectNode) + + case .anyOf(let nodes): + guard let first = nodes.first else { throw StructuredGenerationError.invalidTokenization } + return try generateNode(first) + + case .ref(let refName): + guard let referenced = schema.defs[refName] else { throw StructuredGenerationError.invalidTokenization } + return try generateNode(referenced) + } + } + } #endif // MLX From afcbd9f4918362ea701ca4b105dab547ca304709 Mon Sep 17 00:00:00 2001 From: eastriver Date: Sun, 21 Dec 2025 01:09:20 +0900 Subject: [PATCH 4/7] Fix duplicate type crash in schema generation --- .../AnyLanguageModel/GenerationSchema.swift | 43 ++++++++++++- .../Models/LlamaLanguageModel.swift | 55 ++++++++++++---- .../Models/MLXLanguageModel.swift | 62 +++++++++++++------ .../StructuredGenerationTests.swift | 44 +++++++++++++ 4 files changed, 172 insertions(+), 32 deletions(-) diff --git a/Sources/AnyLanguageModel/GenerationSchema.swift b/Sources/AnyLanguageModel/GenerationSchema.swift index c5a59126..6e7b771b 100644 --- a/Sources/AnyLanguageModel/GenerationSchema.swift +++ b/Sources/AnyLanguageModel/GenerationSchema.swift @@ -514,12 +514,32 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible } private static func nodesEqual(_ a: Node, _ b: Node) -> Bool { - // Simple structural equality - could be enhanced switch (a, b) { case (.boolean, .boolean): return true case (.ref(let aName), .ref(let bName)): return aName == bName + case (.string(let aString), .string(let bString)): + return aString.pattern == bString.pattern + && aString.enumChoices == bString.enumChoices + case (.number(let aNumber), .number(let bNumber)): + return aNumber.integerOnly == bNumber.integerOnly + && aNumber.minimum == bNumber.minimum + && aNumber.maximum == bNumber.maximum + case (.array(let aArray), .array(let bArray)): + return aArray.minItems == bArray.minItems + && 
aArray.maxItems == bArray.maxItems + && nodesEqual(aArray.items, bArray.items) + case (.object(let aObject), .object(let bObject)): + return aObject.required == bObject.required + && aObject.properties.keys == bObject.properties.keys + && aObject.properties.allSatisfy { key, aNode in + guard let bNode = bObject.properties[key] else { return false } + return nodesEqual(aNode, bNode) + } + case (.anyOf(let aNodes), .anyOf(let bNodes)): + return aNodes.count == bNodes.count + && zip(aNodes, bNodes).allSatisfy(nodesEqual) default: return false } @@ -817,4 +837,25 @@ extension GenerationSchema { /// let data = try encoder.encode(schema) /// ``` static let omitAdditionalPropertiesKey = CodingUserInfoKey(rawValue: "GenerationSchema.omitAdditionalProperties")! + + package func schemaPrompt() -> String { + let encoder = JSONEncoder() + encoder.outputFormatting = [.prettyPrinted, .sortedKeys] + guard let data = try? encoder.encode(self), + let schemaJSON = String(data: data, encoding: .utf8) else { + return "Respond with valid JSON only." + } + return "Respond with valid JSON matching this schema:\n\(schemaJSON)" + } +} + +extension Character { + package static let jsonQuoteScalars: Set = [0x22, 0x201C, 0x201D, 0x2018, 0x2019] + + package var isValidJSONStringCharacter: Bool { + guard self != "\\" else { return false } + guard let scalar = unicodeScalars.first, scalar.value >= 0x20 else { return false } + guard !Self.jsonQuoteScalars.contains(scalar.value) else { return false } + return true + } } diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index af68c551..181ad69a 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -503,7 +503,15 @@ import Foundation llama_set_warmup(context, false) llama_set_n_threads(context, runtimeOptions.threads, runtimeOptions.threads) - let fullPrompt = try formatPrompt(for: session) + let fullPrompt: String + if includeSchemaInPrompt, type != String.self { + fullPrompt = try formatPrompt( + for: session, + extraSystemMessage: type.generationSchema.schemaPrompt() + ) + } else { + fullPrompt = try formatPrompt(for: session) + } if type == String.self { let maxTokens = runtimeOptions.maximumResponseTokens ?? 100 @@ -892,7 +900,19 @@ import Foundation defer { llama_sampler_free(sampler) } let samplerPointer = UnsafeMutablePointer(sampler) - applySampling(sampler: samplerPointer, effectiveTemperature: options.temperature, options: options) + let effectiveTemperature: Float = 0.2 + if options.repeatPenalty != 1.0 || options.frequencyPenalty != 0.0 || options.presencePenalty != 0.0 { + llama_sampler_chain_add( + samplerPointer, + llama_sampler_init_penalties( + options.repeatLastN, + options.repeatPenalty, + options.frequencyPenalty, + options.presencePenalty + ) + ) + } + applySampling(sampler: samplerPointer, effectiveTemperature: effectiveTemperature, options: options) let vocabSize = Int(llama_vocab_n_tokens(vocab)) let initialPosition: Int32 = hasEncoder ? 
1 : batch.n_tokens @@ -927,6 +947,7 @@ import Foundation let totalTokenBudget: Int let quoteToken: llama_token + let eosToken: llama_token let digitOnlyTokens: Set let validStringTokens: Set let validStringTokensOrQuote: Set @@ -957,13 +978,16 @@ import Foundation throw LlamaLanguageModelError.tokenizationFailed } self.quoteToken = quoteToken + self.eosToken = llama_vocab_eos(vocab) + self.digitOnlyTokens = StructuredJSONGenerator.buildDigitOnlyTokens(vocabSize: vocabSize, tokenToText: tokenToText) self.validStringTokens = StructuredJSONGenerator.buildValidJSONStringContentTokens( vocabSize: vocabSize, tokenToText: tokenToText ) - var tokensOrQuote = self.validStringTokens + var tokensOrQuote = validStringTokens tokensOrQuote.insert(quoteToken) + tokensOrQuote.insert(eosToken) self.validStringTokensOrQuote = tokensOrQuote } @@ -993,10 +1017,9 @@ import Foundation for tokenIndex in 0 ..< vocabSize { let token = llama_token(tokenIndex) guard let text = tokenToText(token), !text.isEmpty else { continue } - if text.contains("\"") { continue } - if text.contains("\\") { continue } - if text.unicodeScalars.contains(where: { $0.value < 0x20 }) { continue } - allowed.insert(token) + if text.allSatisfy({ $0.isValidJSONStringCharacter }) { + allowed.insert(token) + } } return allowed } @@ -1037,6 +1060,9 @@ import Foundation guard decodeResult == 0 else { throw LlamaLanguageModelError.decodingFailed } + + // Keep sampler state aligned with the actual context tokens. + llama_sampler_accept(sampler, token) } private mutating func emitLiteral(_ text: String) throws -> String { @@ -1060,9 +1086,7 @@ import Foundation } let tokenIndex = batch.pointee.n_tokens - 1 - let sampled = llama_sampler_sample(sampler, context, tokenIndex) - llama_sampler_accept(sampler, sampled) - return sampled + return llama_sampler_sample(sampler, context, tokenIndex) } private func maxTokenCountForFreeString() -> Int { @@ -1077,7 +1101,7 @@ import Foundation while remainingTokens > 0, generatedTokens < maxTokens { let allowedTokens = result.isEmpty ? validStringTokens : validStringTokensOrQuote let token = sampleToken(allowedTokens: allowedTokens) - if token == quoteToken { break } + if token == quoteToken || token == eosToken { break } result += tokenToText(token) ?? "" generatedTokens += 1 @@ -1452,7 +1476,10 @@ import Foundation return hasEncoder } - private func formatPrompt(for session: LanguageModelSession) throws -> String { + private func formatPrompt( + for session: LanguageModelSession, + extraSystemMessage: String? 
= nil + ) throws -> String { guard let model = self.model else { throw LlamaLanguageModelError.modelLoadFailed } @@ -1484,6 +1511,10 @@ import Foundation } } + if let extraSystemMessage, !extraSystemMessage.isEmpty { + messages.append(("system", extraSystemMessage)) + } + // Keep C strings alive while using them let cRoles = messages.map { strdup($0.role) } let cContents = messages.map { strdup($0.content) } diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index aabf0b61..fd2738f9 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -68,12 +68,14 @@ import Foundation } if type != String.self { + let schema = type.generationSchema let jsonString = try await generateStructuredJSON( session: session, prompt: prompt, context: context, options: options, - schema: type.generationSchema + schema: schema, + includeSchemaInPrompt: includeSchemaInPrompt ) let generatedContent = try GeneratedContent(json: jsonString) @@ -261,6 +263,20 @@ import Foundation ) } + private func toStructuredGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { + MLXLMCommon.GenerateParameters( + maxTokens: options.maximumResponseTokens, + maxKVSize: nil, + kvBits: nil, + kvGroupSize: 64, + quantizedKVStart: 0, + temperature: 0.2, + topP: 0.95, + repetitionPenalty: 1.1, + repetitionContextSize: 64 + ) + } + // MARK: - Transcript Conversion private func convertTranscriptToMLXChat( @@ -522,7 +538,7 @@ import Foundation return textParts.joined(separator: "\n") } - // MARK: - Structured JSON Generation (logit constrained) + // MARK: - Structured JSON Generation private enum StructuredGenerationError: Error { case missingTokenizer @@ -538,12 +554,16 @@ import Foundation prompt: Prompt, context: ModelContext, options: GenerationOptions, - schema: GenerationSchema + schema: GenerationSchema, + includeSchemaInPrompt: Bool ) async throws -> String { let structuredMaxTokens = options.maximumResponseTokens ?? 512 - let generateParameters = toGenerateParameters(options) + let generateParameters = toStructuredGenerateParameters(options) - let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) + var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) + if includeSchemaInPrompt { + chat.insert(.init(role: .system, content: schema.schemaPrompt()), at: 0) + } let userInput = MLXLMCommon.UserInput( chat: chat, processing: .init(resize: .init(width: 512, height: 512)), @@ -559,6 +579,7 @@ import Foundation ) let vocabSize = decoder.vocabSize + let eosToken = context.tokenizer.eosTokenId ?? 
-1 var generator = try StructuredJSONGenerator( schema: schema, tokenizeFragment: { fragment in @@ -574,7 +595,8 @@ import Foundation try decoder.sampleToken(allowedTokens: allowedTokens) }, maximumTokens: structuredMaxTokens, - vocabSize: vocabSize + vocabSize: vocabSize, + eosToken: eosToken ) let json = try generator.generate() @@ -679,10 +701,10 @@ import Foundation var remainingTokens: Int let totalTokenBudget: Int - let quoteToken: Int let digitOnlyTokens: Set let validStringTokens: Set - let validStringTokensOrQuote: Set + let stringTerminators: Set + let validStringTokensOrTerminators: Set init( schema: GenerationSchema, @@ -691,7 +713,8 @@ import Foundation decodeToken: @escaping (Int) throws -> Void, sampleToken: @escaping (Set) throws -> Int, maximumTokens: Int, - vocabSize: Int + vocabSize: Int, + eosToken: Int ) throws { self.schema = schema self.tokenizeFragment = tokenizeFragment @@ -701,11 +724,15 @@ import Foundation self.remainingTokens = maximumTokens self.totalTokenBudget = maximumTokens + // String terminators: EOS + dumb quote (") + var terminators = Set() + if eosToken >= 0 { terminators.insert(eosToken) } let quoteTokens = try tokenizeFragment("\"") guard quoteTokens.count == 1, let quoteToken = quoteTokens.first else { throw StructuredGenerationError.invalidQuoteToken } - self.quoteToken = quoteToken + terminators.insert(quoteToken) + self.stringTerminators = terminators self.digitOnlyTokens = StructuredJSONGenerator.buildDigitOnlyTokens( vocabSize: vocabSize, @@ -715,9 +742,7 @@ import Foundation vocabSize: vocabSize, tokenText: tokenText ) - var tokensOrQuote = self.validStringTokens - tokensOrQuote.insert(quoteToken) - self.validStringTokensOrQuote = tokensOrQuote + self.validStringTokensOrTerminators = validStringTokens.union(stringTerminators) } mutating func generate() throws -> String { @@ -750,10 +775,9 @@ import Foundation for tokenId in 0 ..< vocabSize { let text = tokenText(tokenId) guard !text.isEmpty else { continue } - if text.contains("\"") { continue } - if text.contains("\\") { continue } - if text.unicodeScalars.contains(where: { $0.value < 0x20 }) { continue } - allowed.insert(tokenId) + if text.allSatisfy({ $0.isValidJSONStringCharacter }) { + allowed.insert(tokenId) + } } return allowed } @@ -772,9 +796,9 @@ import Foundation var generatedTokens = 0 while remainingTokens > 0, generatedTokens < maxTokens { - let allowedTokens = result.isEmpty ? validStringTokens : validStringTokensOrQuote + let allowedTokens = result.isEmpty ? 
validStringTokens : validStringTokensOrTerminators let token = try sampleToken(allowedTokens) - if token == quoteToken { break } + if stringTerminators.contains(token) { break } result += tokenText(token) generatedTokens += 1 diff --git a/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift b/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift index df70c031..accb4c90 100644 --- a/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift +++ b/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift @@ -70,6 +70,18 @@ struct Address: Equatable { var postalCode: String } +@Generable +struct ReusedNestedStruct: Equatable { + @Guide(description: "Some text") + var text: String +} + +@Generable +struct ContainerWithDuplicateNestedType: Equatable { + var first: ReusedNestedStruct + var second: ReusedNestedStruct +} + @Generable struct Person: Equatable { @Guide(description: "Person's name") @@ -158,6 +170,14 @@ private func isGenerationTestsEnabled() -> Bool { !supportedModels.isEmpty } +@Test("GenerationSchema merges duplicate defs for the same type") +func generationSchemaMergesDuplicateDefsForSameType() { + let schema = ContainerWithDuplicateNestedType.generationSchema + + let nestedTypeName = String(reflecting: ReusedNestedStruct.self) + #expect(schema.defs[nestedTypeName] != nil) +} + private func testAllModels(_ test: (SupportedModel) async throws -> Void) async { var failures: [(name: String, error: any Error)] = [] @@ -174,6 +194,19 @@ private func testAllModels(_ test: (SupportedModel) async throws -> Void) async } } +private func logGenerated(_ content: T, model: String) { + let json = content.generatedContent.jsonString + if let data = json.data(using: .utf8), + let object = try? JSONSerialization.jsonObject(with: data), + let prettyData = try? 
JSONSerialization.data(withJSONObject: object, options: [.prettyPrinted, .sortedKeys]), + let prettyJSON = String(data: prettyData, encoding: .utf8) + { + print("\n[\(model)]\n\(prettyJSON)\n") + } else { + print("\n[\(model)]\n\(json)\n") + } +} + @Suite("Structured Generation", .serialized, .enabled(if: isGenerationTestsEnabled())) struct StructuredGenerationTests { @Test("Generate SimpleString with all supported models") @@ -189,6 +222,7 @@ struct StructuredGenerationTests { generating: SimpleString.self ) + logGenerated(response.content, model: model.name) #expect(!response.content.message.isEmpty, "[\(model.name)] message should not be empty") } } @@ -206,6 +240,7 @@ struct StructuredGenerationTests { generating: SimpleInt.self ) + logGenerated(response.content, model: model.name) #expect(response.content.count >= 0, "[\(model.name)] count should be non-negative") } } @@ -223,6 +258,7 @@ struct StructuredGenerationTests { generating: SimpleDouble.self ) + logGenerated(response.content, model: model.name) #expect(!response.content.temperature.isNaN, "[\(model.name)] temperature should be a valid number") } } @@ -240,6 +276,7 @@ struct StructuredGenerationTests { generating: SimpleBool.self ) + logGenerated(response.content, model: model.name) let jsonData = response.rawContent.jsonString.data(using: .utf8) #expect(jsonData != nil, "[\(model.name)] rawContent should be valid UTF-8 JSON") if let jsonData { @@ -264,6 +301,7 @@ struct StructuredGenerationTests { generating: OptionalFields.self ) + logGenerated(response.content, model: model.name) #expect(!response.content.name.isEmpty, "[\(model.name)] name should not be empty") if let nickname = response.content.nickname { #expect(!nickname.isEmpty, "[\(model.name)] nickname should not be empty when present") @@ -284,6 +322,7 @@ struct StructuredGenerationTests { generating: Priority.self ) + logGenerated(response.content, model: model.name) #expect( [Priority.low, Priority.medium, Priority.high].contains(response.content), "[\(model.name)] should generate valid priority" @@ -304,6 +343,7 @@ struct StructuredGenerationTests { generating: BasicStruct.self ) + logGenerated(response.content, model: model.name) #expect(!response.content.name.isEmpty, "[\(model.name)] name should not be empty") #expect(response.content.age >= 0, "[\(model.name)] age should be non-negative") } @@ -322,6 +362,7 @@ struct StructuredGenerationTests { generating: Person.self ) + logGenerated(response.content, model: model.name) #expect(!response.content.name.isEmpty, "[\(model.name)] name should not be empty") #expect(response.content.age >= 0, "[\(model.name)] age should be non-negative") #expect(!response.content.address.street.isEmpty, "[\(model.name)] street should not be empty") @@ -342,6 +383,7 @@ struct StructuredGenerationTests { generating: TaskItem.self ) + logGenerated(response.content, model: model.name) #expect(!response.content.title.isEmpty, "[\(model.name)] title should not be empty") #expect( [Priority.low, Priority.medium, Priority.high].contains(response.content.priority), @@ -363,6 +405,7 @@ struct StructuredGenerationTests { generating: SimpleArray.self ) + logGenerated(response.content, model: model.name) #expect(!response.content.colors.isEmpty, "[\(model.name)] colors should not be empty") } } @@ -386,6 +429,7 @@ struct StructuredGenerationTests { generating: MultiChoiceQuestion.self ) + logGenerated(response.content, model: model.name) #expect(!response.content.text.isEmpty, "[\(model.name)] question text should not be empty") 
#expect(response.content.choices.count == 4, "[\(model.name)] should have exactly 4 choices") #expect(!response.content.answer.isEmpty, "[\(model.name)] answer should not be empty") From 1d4293011cb9303ac4e901ba6b15a131719127a3 Mon Sep 17 00:00:00 2001 From: eastriver Date: Sun, 21 Dec 2025 12:17:19 +0900 Subject: [PATCH 5/7] Enforce count + numeric range guides --- .../AnyLanguageModel/GenerationGuide.swift | 51 ++-- .../AnyLanguageModel/GenerationSchema.swift | 20 +- .../Models/LlamaLanguageModel.swift | 47 +++- .../Models/MLXLanguageModel.swift | 48 ++-- .../GenerableMacro.swift | 219 +++++++++++++----- 5 files changed, 287 insertions(+), 98 deletions(-) diff --git a/Sources/AnyLanguageModel/GenerationGuide.swift b/Sources/AnyLanguageModel/GenerationGuide.swift index 35b4c49a..c9ab16f9 100644 --- a/Sources/AnyLanguageModel/GenerationGuide.swift +++ b/Sources/AnyLanguageModel/GenerationGuide.swift @@ -2,7 +2,24 @@ import struct Foundation.Decimal import class Foundation.NSDecimalNumber /// Guides that control how values are generated. -public struct GenerationGuide {} +public struct GenerationGuide: Sendable { + package var minimumCount: Int? + package var maximumCount: Int? + package var minimum: Double? + package var maximum: Double? + + public init() {} + + package init(minimumCount: Int?, maximumCount: Int?) { + self.minimumCount = minimumCount + self.maximumCount = maximumCount + } + + package init(minimum: Double?, maximum: Double?) { + self.minimum = minimum + self.maximum = maximum + } +} // MARK: - String Guides @@ -45,7 +62,7 @@ extension GenerationGuide where Value == Int { /// } /// ``` public static func minimum(_ value: Int) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimum: Double(value), maximum: nil) } /// Enforces a maximum value. @@ -65,7 +82,7 @@ extension GenerationGuide where Value == Int { /// } /// ``` public static func maximum(_ value: Int) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimum: nil, maximum: Double(value)) } /// Enforces values fall within a range. @@ -85,7 +102,7 @@ extension GenerationGuide where Value == Int { /// } /// ``` public static func range(_ range: ClosedRange) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimum: Double(range.lowerBound), maximum: Double(range.upperBound)) } } @@ -144,18 +161,18 @@ extension GenerationGuide where Value == Double { /// Enforces a minimum value. /// The bounds are inclusive. public static func minimum(_ value: Double) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimum: value, maximum: nil) } /// Enforces a maximum value. /// The bounds are inclusive. public static func maximum(_ value: Double) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimum: nil, maximum: value) } /// Enforces values fall within a range. public static func range(_ range: ClosedRange) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimum: range.lowerBound, maximum: range.upperBound) } } @@ -168,7 +185,7 @@ extension GenerationGuide { /// The bounds are inclusive. public static func minimumCount(_ count: Int) -> GenerationGuide<[Element]> where Value == [Element] { - GenerationGuide<[Element]>() + GenerationGuide<[Element]>(minimumCount: count, maximumCount: nil) } /// Enforces a maximum number of elements in the array. @@ -176,25 +193,23 @@ extension GenerationGuide { /// The bounds are inclusive. 
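+    /// For example, `.maximumCount(5)` permits generated arrays with at most five elements.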
public static func maximumCount(_ count: Int) -> GenerationGuide<[Element]> where Value == [Element] { - GenerationGuide<[Element]>() + GenerationGuide<[Element]>(minimumCount: nil, maximumCount: count) } /// Enforces that the number of elements in the array fall within a closed range. public static func count(_ range: ClosedRange) -> GenerationGuide<[Element]> where Value == [Element] { - GenerationGuide<[Element]>() + GenerationGuide<[Element]>(minimumCount: range.lowerBound, maximumCount: range.upperBound) } /// Enforces that the array has exactly a certain number elements. public static func count(_ count: Int) -> GenerationGuide<[Element]> where Value == [Element] { - GenerationGuide<[Element]>() + GenerationGuide<[Element]>(minimumCount: count, maximumCount: count) } /// Enforces a guide on the elements within the array. - public static func element(_ guide: GenerationGuide) -> GenerationGuide< - [Element] - > + public static func element(_ guide: GenerationGuide) -> GenerationGuide<[Element]> where Value == [Element] { GenerationGuide<[Element]>() } @@ -210,7 +225,7 @@ extension GenerationGuide where Value == [Never] { /// /// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.minimumCount(_:)` on your own. public static func minimumCount(_ count: Int) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimumCount: count, maximumCount: nil) } /// Enforces a maximum number of elements in the array. @@ -219,20 +234,20 @@ extension GenerationGuide where Value == [Never] { /// /// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.maximumCount(_:)` on your own. public static func maximumCount(_ count: Int) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimumCount: nil, maximumCount: count) } /// Enforces that the number of elements in the array fall within a closed range. /// /// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.count(_:)` on your own. public static func count(_ range: ClosedRange) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimumCount: range.lowerBound, maximumCount: range.upperBound) } /// Enforces that the array has exactly a certain number elements. /// /// - Warning: This overload is only used for macro expansion. Don't call `GenerationGuide<[Never]>.count(_:)` on your own. public static func count(_ count: Int) -> GenerationGuide { - GenerationGuide() + GenerationGuide(minimumCount: count, maximumCount: count) } } diff --git a/Sources/AnyLanguageModel/GenerationSchema.swift b/Sources/AnyLanguageModel/GenerationSchema.swift index 6e7b771b..5a226c77 100644 --- a/Sources/AnyLanguageModel/GenerationSchema.swift +++ b/Sources/AnyLanguageModel/GenerationSchema.swift @@ -723,12 +723,24 @@ extension GenerationSchema { } else if type == String.self { return (.string(StringNode(description: description, pattern: nil, enumChoices: nil)), [:]) } else if type == Int.self { + var minimum: Double? + var maximum: Double? + for guide in guides { + if let min = guide.minimum { minimum = min } + if let max = guide.maximum { maximum = max } + } return ( - .number(NumberNode(description: description, minimum: nil, maximum: nil, integerOnly: true)), [:] + .number(NumberNode(description: description, minimum: minimum, maximum: maximum, integerOnly: true)), [:] ) } else if type == Float.self || type == Double.self || type == Decimal.self { + var minimum: Double? + var maximum: Double? 
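+            // Fold the numeric guides into a single bounds pair; if several guides
+            // constrain the same bound, the one listed last wins.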
+ for guide in guides { + if let min = guide.minimum { minimum = min } + if let max = guide.maximum { maximum = max } + } return ( - .number(NumberNode(description: description, minimum: nil, maximum: nil, integerOnly: false)), [:] + .number(NumberNode(description: description, minimum: minimum, maximum: maximum, integerOnly: false)), [:] ) } else { // Complex type - use its schema @@ -737,6 +749,10 @@ extension GenerationSchema { // Arrays should be inlined, not referenced if case .array(var arrayNode) = schema.root { arrayNode.description = description + for guide in guides { + if let min = guide.minimumCount { arrayNode.minItems = min } + if let max = guide.maximumCount { arrayNode.maxItems = max } + } return (.array(arrayNode), schema.defs) } diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index 181ad69a..c42230f7 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -1142,19 +1142,42 @@ import Foundation return emitted } - private mutating func generateNumber(isInteger: Bool) throws -> String { - let maxTokens = isInteger ? 3 : 4 - var generated = "" - - for _ in 0 ..< maxTokens { - guard remainingTokens > 0 else { break } - let token = sampleToken(allowedTokens: digitOnlyTokens) - generated += tokenToText(token) ?? "" - try decodeToken(token) - if !generated.isEmpty { break } + private mutating func generateNumber(_ numberNode: GenerationSchema.NumberNode) throws -> String { + if numberNode.integerOnly { + let clamped = clampInteger(defaultValue: 0, minimum: numberNode.minimum, maximum: numberNode.maximum) + return try emitLiteral(String(clamped)) } - return generated.isEmpty ? "0" : generated + let clamped = clampDouble(defaultValue: 0, minimum: numberNode.minimum, maximum: numberNode.maximum) + return try emitLiteral(formatNumberLiteral(clamped)) + } + + private func clampInteger(defaultValue: Int, minimum: Double?, maximum: Double?) -> Int { + let clampedMinimum = minimum.map { Int(ceil($0)) } + let clampedMaximum = maximum.map { Int(floor($0)) } + let normalized = normalizeBounds(minimum: clampedMinimum, maximum: clampedMaximum) + return clamp(defaultValue, minimum: normalized.minimum, maximum: normalized.maximum) + } + + private func clampDouble(defaultValue: Double, minimum: Double?, maximum: Double?) -> Double { + let normalized = normalizeBounds(minimum: minimum, maximum: maximum) + return clamp(defaultValue, minimum: normalized.minimum, maximum: normalized.maximum) + } + + private func normalizeBounds(minimum: T?, maximum: T?) -> (minimum: T?, maximum: T?) { + guard let minimum, let maximum, minimum > maximum else { return (minimum, maximum) } + return (minimum, minimum) + } + + private func clamp(_ value: T, minimum: T?, maximum: T?) 
-> T { + if let minimum, value < minimum { return minimum } + if let maximum, value > maximum { return maximum } + return value + } + + private func formatNumberLiteral(_ value: Double) -> String { + if value.rounded() == value { return String(Int(value)) } + return String(value) } private mutating func generateArray(_ arrayNode: GenerationSchema.ArrayNode) throws -> String { @@ -1211,7 +1234,7 @@ import Foundation return output case .number(let numberNode): - return try generateNumber(isInteger: numberNode.integerOnly) + return try generateNumber(numberNode) case .boolean: return try generateLiteralChoice(["true", "false"]) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index fd2738f9..a64ce66b 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -840,20 +840,42 @@ import Foundation return emitted } - private mutating func generateNumber(isInteger: Bool) throws -> String { - let maxTokens = isInteger ? 3 : 4 - var generated = "" - - for _ in 0 ..< maxTokens { - guard remainingTokens > 0 else { break } - let token = try sampleToken(digitOnlyTokens) - generated += tokenText(token) - try decodeToken(token) - remainingTokens -= 1 - if !generated.isEmpty { break } + private mutating func generateNumber(_ numberNode: GenerationSchema.NumberNode) throws -> String { + if numberNode.integerOnly { + let clamped = clampInteger(defaultValue: 0, minimum: numberNode.minimum, maximum: numberNode.maximum) + return try emitLiteral(String(clamped)) } - return generated.isEmpty ? "0" : generated + let clamped = clampDouble(defaultValue: 0, minimum: numberNode.minimum, maximum: numberNode.maximum) + return try emitLiteral(formatNumberLiteral(clamped)) + } + + private func clampInteger(defaultValue: Int, minimum: Double?, maximum: Double?) -> Int { + let clampedMinimum = minimum.map { Int(ceil($0)) } + let clampedMaximum = maximum.map { Int(floor($0)) } + let normalized = normalizeBounds(minimum: clampedMinimum, maximum: clampedMaximum) + return clamp(defaultValue, minimum: normalized.minimum, maximum: normalized.maximum) + } + + private func clampDouble(defaultValue: Double, minimum: Double?, maximum: Double?) -> Double { + let normalized = normalizeBounds(minimum: minimum, maximum: maximum) + return clamp(defaultValue, minimum: normalized.minimum, maximum: normalized.maximum) + } + + private func normalizeBounds(minimum: T?, maximum: T?) -> (minimum: T?, maximum: T?) { + guard let minimum, let maximum, minimum > maximum else { return (minimum, maximum) } + return (minimum, minimum) + } + + private func clamp(_ value: T, minimum: T?, maximum: T?) 
-> T { + if let minimum, value < minimum { return minimum } + if let maximum, value > maximum { return maximum } + return value + } + + private func formatNumberLiteral(_ value: Double) -> String { + if value.rounded() == value { return String(Int(value)) } + return String(value) } private mutating func generateArray(_ arrayNode: GenerationSchema.ArrayNode) throws -> String { @@ -910,7 +932,7 @@ import Foundation return output case .number(let numberNode): - return try generateNumber(isInteger: numberNode.integerOnly) + return try generateNumber(numberNode) case .boolean: return try generateLiteralChoice(["true", "false"]) diff --git a/Sources/AnyLanguageModelMacros/GenerableMacro.swift b/Sources/AnyLanguageModelMacros/GenerableMacro.swift index a44707e7..2c1d027f 100644 --- a/Sources/AnyLanguageModelMacros/GenerableMacro.swift +++ b/Sources/AnyLanguageModelMacros/GenerableMacro.swift @@ -106,19 +106,15 @@ public struct GenerableMacro: MemberMacro, ExtensionMacro { let binding = varDecl.bindings.first, let identifier = binding.pattern.as(IdentifierPatternSyntax.self) { - let propertyName = identifier.identifier.text let propertyType = binding.typeAnnotation?.type.description ?? "String" - let guideInfo = extractGuideInfo(from: varDecl.attributes) properties.append( PropertyInfo( name: propertyName, type: propertyType, - guideDescription: guideInfo.description, - guides: guideInfo.guides, - pattern: guideInfo.pattern + guide: guideInfo ) ) } @@ -140,32 +136,96 @@ public struct GenerableMacro: MemberMacro, ExtensionMacro { in: .init(charactersIn: "\"") ) - var guides: [String] = [] - var pattern: String? = nil + var constraints = Constraints() for arg in Array(arguments.dropFirst()) { - let argText = arg.expression.description + let guideExpression = arg.expression + if let parsedPattern = parsePatternFromExpression(guideExpression) { + constraints.pattern = parsedPattern + continue + } - if argText.contains(".pattern(") { - let patternRegex = #/\.pattern\(\"([^\"]*)\"\)/# - if let match = argText.firstMatch(of: patternRegex) { - pattern = String(match.1) - } - } else if argText.contains("pattern(") { - let patternRegex = #/pattern\(\"([^\"]*)\"\)/# - if let match = argText.firstMatch(of: patternRegex) { - pattern = String(match.1) - } - } else { - guides.append(argText) + if let functionCall = guideExpression.as(FunctionCallExprSyntax.self) { + applyConstraints(from: functionCall, into: &constraints) + } else if let memberAccess = guideExpression.as(MemberAccessExprSyntax.self), + let functionCall = memberAccess.base?.as(FunctionCallExprSyntax.self) { + applyConstraints(from: functionCall, into: &constraints) } } - return GuideInfo(description: description, guides: guides, pattern: pattern) + return GuideInfo(description: description, constraints: constraints) } } } - return GuideInfo(description: nil, guides: [], pattern: nil) + return GuideInfo(description: nil, constraints: Constraints()) + } + + private static func applyConstraints(from call: FunctionCallExprSyntax, into constraints: inout Constraints) { + let functionName: String? 
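+        // Guides can be spelled as `.minimumCount(2)` (member access) or `minimumCount(2)`
+        // (bare reference); resolve the called function's base name from either form.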
+ if let memberAccess = call.calledExpression.as(MemberAccessExprSyntax.self) { + functionName = memberAccess.declName.baseName.text + } else if let identifier = call.calledExpression.as(DeclReferenceExprSyntax.self) { + functionName = identifier.baseName.text + } else { + functionName = nil + } + + guard let functionName, let firstArgument = call.arguments.first else { return } + + switch functionName { + case "count": + if let intLiteral = firstArgument.expression.as(IntegerLiteralExprSyntax.self), + let value = Int(intLiteral.literal.text) { + constraints.minimumCount = value + constraints.maximumCount = value + } else if let rangeExpression = firstArgument.expression.as(SequenceExprSyntax.self) { + let (minimum, maximum) = parseClosedRangeInt(rangeExpression) + constraints.minimumCount = minimum + constraints.maximumCount = maximum + } + case "minimumCount": + if let intLiteral = firstArgument.expression.as(IntegerLiteralExprSyntax.self), + let value = Int(intLiteral.literal.text) { + constraints.minimumCount = value + } + case "maximumCount": + if let intLiteral = firstArgument.expression.as(IntegerLiteralExprSyntax.self), + let value = Int(intLiteral.literal.text) { + constraints.maximumCount = value + } + case "minimum": + constraints.minimum = parseNumericLiteral(firstArgument.expression) + case "maximum": + constraints.maximum = parseNumericLiteral(firstArgument.expression) + case "range": + if let rangeExpression = firstArgument.expression.as(SequenceExprSyntax.self) { + let (minimum, maximum) = parseClosedRangeDouble(rangeExpression) + constraints.minimum = minimum + constraints.maximum = maximum + } + default: + break + } + } + + private static func parsePatternFromExpression(_ expression: ExprSyntax) -> String? { + if let functionCall = expression.as(FunctionCallExprSyntax.self) { + let functionName: String? + if let memberAccess = functionCall.calledExpression.as(MemberAccessExprSyntax.self) { + functionName = memberAccess.declName.baseName.text + } else if let identifier = functionCall.calledExpression.as(DeclReferenceExprSyntax.self) { + functionName = identifier.baseName.text + } else { + functionName = nil + } + + if functionName == "pattern", + let firstArg = functionCall.arguments.first, + let stringLiteral = firstArg.expression.as(StringLiteralExprSyntax.self) { + return stringLiteral.segments.description.trimmingCharacters(in: .init(charactersIn: "\"")) + } + } + return nil } private static func isDictionaryType(_ type: String) -> Bool { @@ -173,6 +233,78 @@ public struct GenerableMacro: MemberMacro, ExtensionMacro { return trimmed.hasPrefix("[") && trimmed.contains(":") && trimmed.hasSuffix("]") } + private static func escapeDescriptionString(_ description: String?) -> String { + guard let description else { return "nil" } + return makeSwiftStringLiteralExpression(description) + } + + /// Escapes text so it can be embedded safely inside generated Swift source as a string literal. + /// + /// Multi-line strings need newlines converted to `\n` escape sequences, and special characters + /// (backslashes and quotes) must be escaped. 
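+    ///
+    /// For example, the description `say "hi"` becomes the literal `"say \"hi\""`, and an
+    /// embedded newline becomes the two-character escape sequence `\n`.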
+ private static func makeSwiftStringLiteralExpression(_ value: String) -> String { + let escaped = value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "\"", with: "\\\"") + .replacingOccurrences(of: "\n", with: "\\n") + return "\"\(escaped)\"" + } + + private static func buildGuidesArray(for property: PropertyInfo) -> String { + let baseType = property.type.replacingOccurrences(of: "?", with: "") + + if baseType.hasPrefix("[") && baseType.hasSuffix("]") && !isDictionaryType(baseType) { + if property.guide.constraints.minimumCount != nil || property.guide.constraints.maximumCount != nil { + let minStr = property.guide.constraints.minimumCount.map { String($0) } ?? "nil" + let maxStr = property.guide.constraints.maximumCount.map { String($0) } ?? "nil" + return "[GenerationGuide(minimumCount: \(minStr), maximumCount: \(maxStr))]" + } + return "[]" + } + + if baseType == "Int" || baseType == "Double" || baseType == "Float" { + if property.guide.constraints.minimum != nil || property.guide.constraints.maximum != nil { + let minStr = property.guide.constraints.minimum.map { String($0) } ?? "nil" + let maxStr = property.guide.constraints.maximum.map { String($0) } ?? "nil" + return "[GenerationGuide(minimum: \(minStr), maximum: \(maxStr))]" + } + return "[]" + } + + return "[]" + } + + private static func parseNumericLiteral(_ expression: ExprSyntax) -> Double? { + if let intLiteral = expression.as(IntegerLiteralExprSyntax.self) { + return Double(intLiteral.literal.text) + } else if let floatLiteral = expression.as(FloatLiteralExprSyntax.self) { + return Double(floatLiteral.literal.text) + } else if let prefixExpression = expression.as(PrefixOperatorExprSyntax.self), + prefixExpression.operator.text == "-" { + if let value = parseNumericLiteral(prefixExpression.expression) { + return -value + } + } + return nil + } + + private static func parseClosedRangeInt(_ expression: SequenceExprSyntax) -> (Int?, Int?) { + let elements = Array(expression.elements) + guard elements.count == 3, + let lowerBound = elements[0].as(IntegerLiteralExprSyntax.self), + let upperBound = elements[2].as(IntegerLiteralExprSyntax.self) + else { return (nil, nil) } + return (Int(lowerBound.literal.text), Int(upperBound.literal.text)) + } + + private static func parseClosedRangeDouble(_ expression: SequenceExprSyntax) -> (Double?, Double?) { + let elements = Array(expression.elements) + guard elements.count == 3 else { return (nil, nil) } + let minimum = parseNumericLiteral(elements[0]) + let maximum = parseNumericLiteral(elements[2]) + return (minimum, maximum) + } + private static func extractDictionaryTypes(_ type: String) -> (key: String, value: String)? { let trimmed = type.trimmingCharacters(in: .whitespacesAndNewlines) @@ -601,32 +733,8 @@ public struct GenerableMacro: MemberMacro, ExtensionMacro { properties: [PropertyInfo] ) -> DeclSyntax { let propertySchemas = properties.map { prop in - var guidesArray = "[]" - if !prop.guides.isEmpty || prop.pattern != nil { - var guides: [String] = [] - - if let pattern = prop.pattern { - guides.append(".pattern(\"\(pattern)\")") - } - - guides.append(contentsOf: prop.guides) - guidesArray = "[\(guides.joined(separator: ", "))]" - } - - // Escape the description string so it can be safely embedded in generated code. - // Multi-line strings need newlines converted to \n escape sequences, - // and special characters (backslashes, quotes) must be escaped. 
- let escapedDescription: String - if let desc = prop.guideDescription { - let escaped = - desc - .replacingOccurrences(of: "\\", with: "\\\\") // Escape backslashes first - .replacingOccurrences(of: "\"", with: "\\\"") // Escape quotes - .replacingOccurrences(of: "\n", with: "\\n") // Convert newlines to escape sequences - escapedDescription = "\"\(escaped)\"" - } else { - escapedDescription = "nil" - } + let escapedDescription = escapeDescriptionString(prop.guide.description) + let guidesArray = buildGuidesArray(for: prop) return """ GenerationSchema.Property( @@ -1204,14 +1312,19 @@ private struct EnumCaseInfo { private struct GuideInfo { let description: String? - let guides: [String] - let pattern: String? + let constraints: Constraints +} + +private struct Constraints { + var minimumCount: Int? + var maximumCount: Int? + var minimum: Double? + var maximum: Double? + var pattern: String? } private struct PropertyInfo { let name: String let type: String - let guideDescription: String? - let guides: [String] - let pattern: String? + let guide: GuideInfo } From dc062257e8f299ddcaf4e2345949919428532c1b Mon Sep 17 00:00:00 2001 From: eastriver Date: Sun, 21 Dec 2025 12:23:01 +0900 Subject: [PATCH 6/7] Respect temperature for structured generation --- Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift | 3 +-- Sources/AnyLanguageModel/Models/MLXLanguageModel.swift | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index c42230f7..d30f5022 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -900,7 +900,6 @@ import Foundation defer { llama_sampler_free(sampler) } let samplerPointer = UnsafeMutablePointer(sampler) - let effectiveTemperature: Float = 0.2 if options.repeatPenalty != 1.0 || options.frequencyPenalty != 0.0 || options.presencePenalty != 0.0 { llama_sampler_chain_add( samplerPointer, @@ -912,7 +911,7 @@ import Foundation ) ) } - applySampling(sampler: samplerPointer, effectiveTemperature: effectiveTemperature, options: options) + applySampling(sampler: samplerPointer, effectiveTemperature: options.temperature, options: options) let vocabSize = Int(llama_vocab_n_tokens(vocab)) let initialPosition: Int32 = hasEncoder ? 1 : batch.n_tokens diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index a64ce66b..f0c54327 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -270,7 +270,7 @@ import Foundation kvBits: nil, kvGroupSize: 64, quantizedKVStart: 0, - temperature: 0.2, + temperature: Float(options.temperature ?? 
0.2), topP: 0.95, repetitionPenalty: 1.1, repetitionContextSize: 64 From 63bcd1481f93b797ec3854688f6cd90c819389e2 Mon Sep 17 00:00:00 2001 From: eastriver Date: Mon, 22 Dec 2025 01:15:58 +0900 Subject: [PATCH 7/7] Refactor Llama and MLX structured generation to shared constrained generator --- .../AnyLanguageModel/GenerationSchema.swift | 17 +- .../Models/LlamaLanguageModel.swift | 313 +++--------- .../Models/MLXLanguageModel.swift | 475 ++++++------------ .../StructuredGeneration.swift | 314 ++++++++++++ .../StructuredGenerationTests.swift | 18 +- 5 files changed, 551 insertions(+), 586 deletions(-) create mode 100644 Sources/AnyLanguageModel/StructuredGeneration.swift diff --git a/Sources/AnyLanguageModel/GenerationSchema.swift b/Sources/AnyLanguageModel/GenerationSchema.swift index 5a226c77..69fd79d6 100644 --- a/Sources/AnyLanguageModel/GenerationSchema.swift +++ b/Sources/AnyLanguageModel/GenerationSchema.swift @@ -867,11 +867,26 @@ extension GenerationSchema { extension Character { package static let jsonQuoteScalars: Set = [0x22, 0x201C, 0x201D, 0x2018, 0x2019] + package static let jsonAllowedWhitespaceCharacters: Set = [" ", "\t", "\n"] + + package var containsEmojiScalar: Bool { + unicodeScalars.contains { scalar in + scalar.properties.isEmojiPresentation || scalar.properties.isEmoji + } + } package var isValidJSONStringCharacter: Bool { guard self != "\\" else { return false } guard let scalar = unicodeScalars.first, scalar.value >= 0x20 else { return false } guard !Self.jsonQuoteScalars.contains(scalar.value) else { return false } - return true + + if let ascii = asciiValue { + let char = Character(UnicodeScalar(ascii)) + if Self.jsonAllowedWhitespaceCharacters.contains(char) { return true } + return isLetter || isNumber || (isASCII && (isPunctuation || isSymbol)) + } + + // Allow non-ASCII letters/numbers and emoji, but disallow non-ASCII punctuation (e.g. "】") + return isLetter || isNumber || containsEmojiScalar } } diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index d30f5022..33ddba6c 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -530,19 +530,15 @@ import Foundation ) } else { let maxTokens = runtimeOptions.maximumResponseTokens ?? 512 - let schema = type.generationSchema - let jsonString = try await generateStructuredJSON( + let jsonString = try generateStructuredJSON( context: context, - model: model!, prompt: fullPrompt, - schema: schema, + schema: type.generationSchema, maxTokens: maxTokens, options: runtimeOptions ) - let generatedContent = try GeneratedContent(json: jsonString) let content = try type.init(generatedContent) - return LanguageModelSession.Response( content: content, rawContent: generatedContent, @@ -863,17 +859,16 @@ import Foundation return generatedText } - // MARK: - Structured JSON Generation (logit constrained) + // MARK: - Structured JSON Generation private func generateStructuredJSON( context: OpaquePointer, - model: OpaquePointer, prompt: String, schema: GenerationSchema, maxTokens: Int, options: ResolvedGenerationOptions - ) async throws -> String { - guard let vocab = llama_model_get_vocab(model) else { + ) throws -> String { + guard let vocab = llama_model_get_vocab(model!) 
else { throw LlamaLanguageModelError.contextInitializationFailed } @@ -888,7 +883,7 @@ import Foundation let hasEncoder = try prepareInitialBatch( batch: &batch, promptTokens: promptTokens, - model: model, + model: model!, vocab: vocab, context: context, batchSize: options.batchSize @@ -917,113 +912,92 @@ import Foundation let initialPosition: Int32 = hasEncoder ? 1 : batch.n_tokens return try withUnsafeMutablePointer(to: &batch) { batchPointer in - let generator = try StructuredJSONGenerator( + var backend = LlamaTokenBackend( context: context, vocab: vocab, vocabSize: vocabSize, sampler: samplerPointer, - tokenToText: { token in self.tokenToText(vocab: vocab, token: token) }, batch: batchPointer, - initialPosition: initialPosition, + position: initialPosition, maximumTokens: maxTokens, - schema: schema + tokenToTextFn: { [self] token in self.tokenToText(vocab: vocab, token: llama_token(token)) } ) + var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) return try generator.generate() } } - private struct StructuredJSONGenerator { + private struct LlamaTokenBackend: TokenBackend { let context: OpaquePointer let vocab: OpaquePointer let vocabSize: Int let sampler: UnsafeMutablePointer - let tokenToText: (llama_token) -> String? let batch: UnsafeMutablePointer - let schema: GenerationSchema + let tokenToTextFn: (Int) -> String? + let tokensExcludedFromRepetitionPenalty: Set + let endTokens: Set var position: Int32 var remainingTokens: Int let totalTokenBudget: Int - - let quoteToken: llama_token - let eosToken: llama_token - let digitOnlyTokens: Set - let validStringTokens: Set - let validStringTokensOrQuote: Set + let eosToken: Int init( context: OpaquePointer, vocab: OpaquePointer, vocabSize: Int, sampler: UnsafeMutablePointer, - tokenToText: @escaping (llama_token) -> String?, batch: UnsafeMutablePointer, - initialPosition: Int32, + position: Int32, maximumTokens: Int, - schema: GenerationSchema - ) throws { + tokenToTextFn: @escaping (Int) -> String? + ) { self.context = context self.vocab = vocab self.vocabSize = vocabSize self.sampler = sampler - self.tokenToText = tokenToText self.batch = batch - self.position = initialPosition + self.position = position self.remainingTokens = maximumTokens self.totalTokenBudget = maximumTokens - self.schema = schema + self.eosToken = Int(llama_vocab_eos(vocab)) - guard let quoteToken = try StructuredJSONGenerator.tokenizeFragment(vocab: vocab, "\"").first else { - throw LlamaLanguageModelError.tokenizationFailed - } - self.quoteToken = quoteToken - self.eosToken = llama_vocab_eos(vocab) + let eotTokenValue = llama_vocab_eot(vocab) + let endOfTurnToken = eotTokenValue != LLAMA_TOKEN_NULL ? 
Int(eotTokenValue) : eosToken + self.endTokens = [self.eosToken, endOfTurnToken] - self.digitOnlyTokens = StructuredJSONGenerator.buildDigitOnlyTokens(vocabSize: vocabSize, tokenToText: tokenToText) - self.validStringTokens = StructuredJSONGenerator.buildValidJSONStringContentTokens( + self.tokenToTextFn = tokenToTextFn + self.tokensExcludedFromRepetitionPenalty = Self.buildTokensExcludedFromRepetitionPenalty( vocabSize: vocabSize, - tokenToText: tokenToText + tokenToText: tokenToTextFn ) - var tokensOrQuote = validStringTokens - tokensOrQuote.insert(quoteToken) - tokensOrQuote.insert(eosToken) - self.validStringTokensOrQuote = tokensOrQuote } - func generate() throws -> String { - var generator = self - return try generator.generateNode(schema.root) + func isSpecialToken(_ token: Int) -> Bool { + let attributes = llama_vocab_get_attr(vocab, llama_token(token)) + return (attributes.rawValue & LLAMA_TOKEN_ATTR_CONTROL.rawValue) != 0 } - private static func buildDigitOnlyTokens( + private static func buildTokensExcludedFromRepetitionPenalty( vocabSize: Int, - tokenToText: (llama_token) -> String? - ) -> Set { - Set((0 ..< vocabSize).compactMap { tokenIndex in - let token = llama_token(tokenIndex) - guard let text = tokenToText(token), !text.isEmpty else { return nil } - return text.allSatisfy({ $0.isNumber }) ? token : nil - }) - } - - private static func buildValidJSONStringContentTokens( - vocabSize: Int, - tokenToText: (llama_token) -> String? - ) -> Set { - var allowed = Set() - allowed.reserveCapacity(vocabSize / 4) - - for tokenIndex in 0 ..< vocabSize { - let token = llama_token(tokenIndex) - guard let text = tokenToText(token), !text.isEmpty else { continue } - if text.allSatisfy({ $0.isValidJSONStringCharacter }) { - allowed.insert(token) + tokenToText: (Int) -> String? + ) -> Set { + let excludedTexts: Set = ["{", "}", "[", "]", ",", ":", "\""] + var excluded = Set() + excluded.reserveCapacity(excludedTexts.count * 4) + + for token in 0 ..< vocabSize { + guard let text = tokenToText(token) else { continue } + let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines) + if excludedTexts.contains(trimmed) { + excluded.insert(token) } } - return allowed + + return excluded } - private static func tokenizeFragment(vocab: OpaquePointer, _ text: String) throws -> [llama_token] { + func tokenize(_ text: String) throws -> [Int] { let utf8Count = text.utf8.count let capacity = Int32(max(utf8Count * 2, 8)) let tokens = UnsafeMutablePointer.allocate(capacity: Int(capacity)) @@ -1039,12 +1013,18 @@ import Foundation false ) guard tokenCount > 0 else { return [] } - return Array(UnsafeBufferPointer(start: tokens, count: Int(tokenCount))) + return Array(UnsafeBufferPointer(start: tokens, count: Int(tokenCount))).map { Int($0) } + } + + func tokenText(_ token: Int) -> String? { + tokenToTextFn(token) } - private mutating func decodeToken(_ token: llama_token) throws { + mutating func decode(_ token: Int) throws { + let llamaToken = llama_token(token) + batch.pointee.n_tokens = 1 - batch.pointee.token[0] = token + batch.pointee.token[0] = llamaToken batch.pointee.pos[0] = position batch.pointee.n_seq_id[0] = 1 if let seqIds = batch.pointee.seq_id, let seqId = seqIds[0] { @@ -1060,199 +1040,24 @@ import Foundation throw LlamaLanguageModelError.decodingFailed } - // Keep sampler state aligned with the actual context tokens. 
- llama_sampler_accept(sampler, token) - } - - private mutating func emitLiteral(_ text: String) throws -> String { - for token in try StructuredJSONGenerator.tokenizeFragment(vocab: vocab, text) { - guard remainingTokens > 0 else { throw LlamaLanguageModelError.decodingFailed } - try decodeToken(token) + if !tokensExcludedFromRepetitionPenalty.contains(Int(llamaToken)) { + llama_sampler_accept(sampler, llamaToken) } - return text } - private mutating func sampleToken(allowedTokens: Set) -> llama_token { + mutating func sample(from allowedTokens: Set) throws -> Int { guard let logits = llama_get_logits(context) else { - return llama_vocab_eos(vocab) + return eosToken } for tokenIndex in 0 ..< vocabSize { - let token = llama_token(tokenIndex) - if !allowedTokens.contains(token) { + if !allowedTokens.contains(tokenIndex) { logits[tokenIndex] = -Float.infinity } } let tokenIndex = batch.pointee.n_tokens - 1 - return llama_sampler_sample(sampler, context, tokenIndex) - } - - private func maxTokenCountForFreeString() -> Int { - let perStringLimit = max(32, totalTokenBudget / 4) - return min(remainingTokens, perStringLimit) - } - - private mutating func generateFreeString(maxTokens: Int) throws -> String { - var result = "" - var generatedTokens = 0 - - while remainingTokens > 0, generatedTokens < maxTokens { - let allowedTokens = result.isEmpty ? validStringTokens : validStringTokensOrQuote - let token = sampleToken(allowedTokens: allowedTokens) - if token == quoteToken || token == eosToken { break } - - result += tokenToText(token) ?? "" - generatedTokens += 1 - try decodeToken(token) - } - - return result - } - - private mutating func generateLiteralChoice(_ candidates: [String]) throws -> String { - let tokenizedCandidates = try candidates.map { try StructuredJSONGenerator.tokenizeFragment(vocab: vocab, $0) } - .filter { !$0.isEmpty } - guard !tokenizedCandidates.isEmpty else { throw LlamaLanguageModelError.tokenizationFailed } - - var prefixes = tokenizedCandidates - var emitted = "" - var tokenPosition = 0 - - while remainingTokens > 0 { - if prefixes.contains(where: { $0.count == tokenPosition }) { break } - - let allowed = Set(prefixes.compactMap { tokens -> llama_token? in - guard tokenPosition < tokens.count else { return nil } - return tokens[tokenPosition] - }) - - let nextToken = sampleToken(allowedTokens: allowed) - emitted += tokenToText(nextToken) ?? "" - try decodeToken(nextToken) - - prefixes = prefixes.filter { tokens in - tokenPosition < tokens.count && tokens[tokenPosition] == nextToken - } - tokenPosition += 1 - if prefixes.isEmpty { break } - } - - return emitted - } - - private mutating func generateNumber(_ numberNode: GenerationSchema.NumberNode) throws -> String { - if numberNode.integerOnly { - let clamped = clampInteger(defaultValue: 0, minimum: numberNode.minimum, maximum: numberNode.maximum) - return try emitLiteral(String(clamped)) - } - - let clamped = clampDouble(defaultValue: 0, minimum: numberNode.minimum, maximum: numberNode.maximum) - return try emitLiteral(formatNumberLiteral(clamped)) - } - - private func clampInteger(defaultValue: Int, minimum: Double?, maximum: Double?) -> Int { - let clampedMinimum = minimum.map { Int(ceil($0)) } - let clampedMaximum = maximum.map { Int(floor($0)) } - let normalized = normalizeBounds(minimum: clampedMinimum, maximum: clampedMaximum) - return clamp(defaultValue, minimum: normalized.minimum, maximum: normalized.maximum) - } - - private func clampDouble(defaultValue: Double, minimum: Double?, maximum: Double?) 
-> Double { - let normalized = normalizeBounds(minimum: minimum, maximum: maximum) - return clamp(defaultValue, minimum: normalized.minimum, maximum: normalized.maximum) - } - - private func normalizeBounds(minimum: T?, maximum: T?) -> (minimum: T?, maximum: T?) { - guard let minimum, let maximum, minimum > maximum else { return (minimum, maximum) } - return (minimum, minimum) - } - - private func clamp(_ value: T, minimum: T?, maximum: T?) -> T { - if let minimum, value < minimum { return minimum } - if let maximum, value > maximum { return maximum } - return value - } - - private func formatNumberLiteral(_ value: Double) -> String { - if value.rounded() == value { return String(Int(value)) } - return String(value) - } - - private mutating func generateArray(_ arrayNode: GenerationSchema.ArrayNode) throws -> String { - let elementCount = arrayNode.minItems ?? arrayNode.maxItems ?? 4 - var output = try emitLiteral("[") - - for index in 0 ..< elementCount { - output += try generateNode(arrayNode.items) - if index < elementCount - 1 { - output += try emitLiteral(",") - } - } - - output += try emitLiteral("]") - return output - } - - private mutating func generateObject(_ objectNode: GenerationSchema.ObjectNode) throws -> String { - let keys = objectNode.properties.keys.sorted() - var output = try emitLiteral("{") - - for (index, key) in keys.enumerated() { - output += try emitLiteral("\"") - output += try emitLiteral(key) - output += try emitLiteral("\":") - - if let propertyNode = objectNode.properties[key] { - output += try generateNode(propertyNode) - } else { - output += try emitLiteral("null") - } - - if index < keys.count - 1 { - output += try emitLiteral(",") - } - } - - output += try emitLiteral("}") - return output - } - - private mutating func generateNode(_ node: GenerationSchema.Node) throws -> String { - guard remainingTokens > 0 else { throw LlamaLanguageModelError.decodingFailed } - - switch node { - case .string(let stringNode): - var output = try emitLiteral("\"") - if let enumChoices = stringNode.enumChoices, !enumChoices.isEmpty { - output += try generateLiteralChoice(enumChoices) - } else { - output += try generateFreeString(maxTokens: maxTokenCountForFreeString()) - } - output += try emitLiteral("\"") - return output - - case .number(let numberNode): - return try generateNumber(numberNode) - - case .boolean: - return try generateLiteralChoice(["true", "false"]) - - case .array(let arrayNode): - return try generateArray(arrayNode) - - case .object(let objectNode): - return try generateObject(objectNode) - - case .anyOf(let nodes): - // Pick the first option deterministically (schema-driven). 
- guard let first = nodes.first else { throw LlamaLanguageModelError.decodingFailed } - return try generateNode(first) - - case .ref(let refName): - guard let referenced = schema.defs[refName] else { throw LlamaLanguageModelError.decodingFailed } - return try generateNode(referenced) - } + return Int(llama_sampler_sample(sampler, context, tokenIndex)) } } diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index f0c54327..12a57ae2 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -67,26 +67,22 @@ import Foundation context = try await loadModel(id: modelId) } - if type != String.self { - let schema = type.generationSchema - let jsonString = try await generateStructuredJSON( - session: session, - prompt: prompt, - context: context, - options: options, - schema: schema, - includeSchemaInPrompt: includeSchemaInPrompt - ) - - let generatedContent = try GeneratedContent(json: jsonString) - let content = try type.init(generatedContent) - - return LanguageModelSession.Response( - content: content, - rawContent: generatedContent, - transcriptEntries: ArraySlice([]) - ) - } + if type != String.self { + let jsonString = try await generateStructuredJSON( + context: context, + session: session, + prompt: prompt, + schema: type.generationSchema, + options: options + ) + let generatedContent = try GeneratedContent(json: jsonString) + let content = try type.init(generatedContent) + return LanguageModelSession.Response( + content: content, + rawContent: generatedContent, + transcriptEntries: ArraySlice([]) + ) + } // Convert session tools to MLX ToolSpec format let toolSpecs: [ToolSpec]? = @@ -452,7 +448,7 @@ import Foundation return try convertToSendableJSONObject(dictionaryValue) } - throw StructuredGenerationError.invalidTokenization + throw StructuredGenerationError.unsupportedJSONValueType } // MARK: - Tool Invocation Handling @@ -541,29 +537,22 @@ import Foundation // MARK: - Structured JSON Generation private enum StructuredGenerationError: Error { - case missingTokenizer - case emptyPrompt - case invalidQuoteToken - case invalidTokenization - case tokenBudgetExceeded case invalidVocabSize + case unsupportedJSONValueType } private func generateStructuredJSON( + context: ModelContext, session: LanguageModelSession, prompt: Prompt, - context: ModelContext, - options: GenerationOptions, schema: GenerationSchema, - includeSchemaInPrompt: Bool + options: GenerationOptions ) async throws -> String { - let structuredMaxTokens = options.maximumResponseTokens ?? 512 + let maxTokens = options.maximumResponseTokens ?? 
512 let generateParameters = toStructuredGenerateParameters(options) - var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) - if includeSchemaInPrompt { - chat.insert(.init(role: .system, content: schema.schemaPrompt()), at: 0) - } + let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) + let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schema.schemaPrompt()) let userInput = MLXLMCommon.UserInput( chat: chat, processing: .init(resize: .init(width: 512, height: 512)), @@ -571,48 +560,78 @@ import Foundation ) let lmInput = try await context.processor.prepare(input: userInput) - var decoder = try MLXTokenDecoder( + let backend = try MLXTokenBackend( context: context, input: lmInput, parameters: generateParameters, - maximumTokens: structuredMaxTokens - ) - - let vocabSize = decoder.vocabSize - let eosToken = context.tokenizer.eosTokenId ?? -1 - var generator = try StructuredJSONGenerator( - schema: schema, - tokenizeFragment: { fragment in - context.tokenizer.encode(text: fragment, addSpecialTokens: false) - }, - tokenText: { token in - context.tokenizer.decode(tokens: [token], skipSpecialTokens: false) - }, - decodeToken: { token in - try decoder.decodeToken(token) - }, - sampleToken: { allowedTokens in - try decoder.sampleToken(allowedTokens: allowedTokens) - }, - maximumTokens: structuredMaxTokens, - vocabSize: vocabSize, - eosToken: eosToken + maximumTokens: maxTokens ) + var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) let json = try generator.generate() Stream().synchronize() return json } - private struct MLXTokenDecoder { + private func normalizeChatForStructuredGeneration( + _ chat: [MLXLMCommon.Chat.Message], + schemaPrompt: String + ) -> [MLXLMCommon.Chat.Message] { + var systemMessageParts: [String] = [] + systemMessageParts.append(schemaPrompt) + + var messages: [MLXLMCommon.Chat.Message] = [] + messages.reserveCapacity(chat.count) + + for message in chat { + if message.role == .system { + systemMessageParts.append(message.content) + continue + } + + if let last = messages.last, last.role == message.role { + let merged = MLXLMCommon.Chat.Message(role: last.role, content: "\(last.content)\n\(message.content)") + messages.removeLast() + messages.append(merged) + } else { + messages.append(message) + } + } + + let systemPrefix = systemMessageParts + .map { $0.trimmingCharacters(in: .whitespacesAndNewlines) } + .filter { !$0.isEmpty } + .joined(separator: "\n\n") + + guard !systemPrefix.isEmpty else { + return messages + } + + if let firstUserIndex = messages.firstIndex(where: { $0.role == .user }) { + let existing = messages[firstUserIndex].content + messages[firstUserIndex] = MLXLMCommon.Chat.Message(role: .user, content: "\(systemPrefix)\n\n\(existing)") + return messages + } + + messages.insert(.init(role: .user, content: systemPrefix), at: 0) + return messages + } + + private struct MLXTokenBackend: TokenBackend { let model: any MLXLMCommon.LanguageModel + let tokenizer: any Tokenizer var state: MLXLMCommon.LMOutput.State? var cache: [MLXLMCommon.KVCache] var processor: MLXLMCommon.LogitProcessor? 
let sampler: MLXLMCommon.LogitSampler + let tokensExcludedFromRepetitionPenalty: Set + let endTokens: Set var currentLogits: MLXArray let vocabSize: Int + let eosToken: Int + var remainingTokens: Int + let totalTokenBudget: Int init( context: ModelContext, @@ -621,10 +640,24 @@ import Foundation maximumTokens: Int ) throws { self.model = context.model + self.tokenizer = context.tokenizer self.state = nil self.cache = context.model.newCache(parameters: parameters) self.processor = parameters.processor() self.sampler = parameters.sampler() + self.remainingTokens = maximumTokens + self.totalTokenBudget = maximumTokens + guard let eosTokenId = context.tokenizer.eosTokenId else { + throw StructuredGenerationError.invalidVocabSize + } + self.eosToken = eosTokenId + self.endTokens = Self.buildEndTokens( + eosTokenId: eosTokenId, + tokenizer: context.tokenizer, + configuration: context.configuration + ) + + self.tokensExcludedFromRepetitionPenalty = Self.buildTokensExcludedFromRepetitionPenalty(tokenizer: context.tokenizer) processor?.prompt(input.text.tokens) @@ -658,299 +691,93 @@ import Foundation } } - mutating func decodeToken(_ token: Int) throws { - let tokenArray = MLXArray(token) - processor?.didSample(token: tokenArray) - - let inputText = MLXLMCommon.LMInput.Text(tokens: tokenArray) - let output = model( - inputText[text: .newAxis], - cache: cache.isEmpty ? nil : cache, - state: state - ) - state = output.state - currentLogits = output.logits - } - - mutating func sampleToken(allowedTokens: Set) throws -> Int { - guard !allowedTokens.isEmpty else { throw StructuredGenerationError.invalidTokenization } - - var logits = currentLogits[0..., -1, 0...] - logits = processor?.process(logits: logits) ?? logits - if logits.dtype == .bfloat16 { - logits = logits.asType(.float32) - } - - let allowedIndices = MLXArray(allowedTokens.map { UInt32($0) }) - let maskedLogits = full(logits.shape, values: -Float.infinity) - maskedLogits[0..., allowedIndices] = logits[0..., allowedIndices] - - let sampledToken = sampler.sample(logits: maskedLogits) - processor?.didSample(token: sampledToken) - return sampledToken.item(Int.self) - } - } - - private struct StructuredJSONGenerator { - let schema: GenerationSchema - let tokenizeFragment: (String) throws -> [Int] - let tokenText: (Int) -> String - let decodeToken: (Int) throws -> Void - let sampleToken: (Set) throws -> Int - - var remainingTokens: Int - let totalTokenBudget: Int - - let digitOnlyTokens: Set - let validStringTokens: Set - let stringTerminators: Set - let validStringTokensOrTerminators: Set - - init( - schema: GenerationSchema, - tokenizeFragment: @escaping (String) throws -> [Int], - tokenText: @escaping (Int) -> String, - decodeToken: @escaping (Int) throws -> Void, - sampleToken: @escaping (Set) throws -> Int, - maximumTokens: Int, - vocabSize: Int, - eosToken: Int - ) throws { - self.schema = schema - self.tokenizeFragment = tokenizeFragment - self.tokenText = tokenText - self.decodeToken = decodeToken - self.sampleToken = sampleToken - self.remainingTokens = maximumTokens - self.totalTokenBudget = maximumTokens - - // String terminators: EOS + dumb quote (") - var terminators = Set() - if eosToken >= 0 { terminators.insert(eosToken) } - let quoteTokens = try tokenizeFragment("\"") - guard quoteTokens.count == 1, let quoteToken = quoteTokens.first else { - throw StructuredGenerationError.invalidQuoteToken - } - terminators.insert(quoteToken) - self.stringTerminators = terminators - - self.digitOnlyTokens = 
StructuredJSONGenerator.buildDigitOnlyTokens( - vocabSize: vocabSize, - tokenText: tokenText - ) - self.validStringTokens = StructuredJSONGenerator.buildValidJSONStringContentTokens( - vocabSize: vocabSize, - tokenText: tokenText - ) - self.validStringTokensOrTerminators = validStringTokens.union(stringTerminators) - } - - mutating func generate() throws -> String { - try generateNode(schema.root) - } - - private func maxTokenCountForFreeString() -> Int { - let perStringLimit = max(32, totalTokenBudget / 4) - return min(remainingTokens, perStringLimit) - } - - private static func buildDigitOnlyTokens( - vocabSize: Int, - tokenText: (Int) -> String + private static func buildEndTokens( + eosTokenId: Int, + tokenizer: any Tokenizer, + configuration: ModelConfiguration ) -> Set { - Set((0 ..< vocabSize).filter { tokenId in - let text = tokenText(tokenId) - guard !text.isEmpty else { return false } - return text.allSatisfy({ $0.isNumber }) - }) - } + var tokens: Set = [eosTokenId] - private static func buildValidJSONStringContentTokens( - vocabSize: Int, - tokenText: (Int) -> String - ) -> Set { - var allowed = Set() - allowed.reserveCapacity(vocabSize / 4) - - for tokenId in 0 ..< vocabSize { - let text = tokenText(tokenId) - guard !text.isEmpty else { continue } - if text.allSatisfy({ $0.isValidJSONStringCharacter }) { - allowed.insert(tokenId) - } + // If the tokenizer declares an EOS token string, prefer treating its ID as an end token too. + // Some chat models use a string EOS marker (e.g. "") whose ID may differ from eosTokenId. + if let eosString = tokenizer.eosToken, let eosStringId = tokenizer.convertTokenToId(eosString) { + tokens.insert(eosStringId) } - return allowed - } - private mutating func emitLiteral(_ text: String) throws -> String { - for token in try tokenizeFragment(text) { - guard remainingTokens > 0 else { throw StructuredGenerationError.tokenBudgetExceeded } - try decodeToken(token) - remainingTokens -= 1 + for tokenString in configuration.extraEOSTokens { + if let id = tokenizer.convertTokenToId(tokenString) { + tokens.insert(id) + } } - return text + return tokens } - private mutating func generateFreeString(maxTokens: Int) throws -> String { - var result = "" - var generatedTokens = 0 - - while remainingTokens > 0, generatedTokens < maxTokens { - let allowedTokens = result.isEmpty ? validStringTokens : validStringTokensOrTerminators - let token = try sampleToken(allowedTokens) - if stringTerminators.contains(token) { break } - - result += tokenText(token) - generatedTokens += 1 - try decodeToken(token) - remainingTokens -= 1 - } - - return result + func isSpecialToken(_ token: Int) -> Bool { + // Use swift-transformers' own special token registry (skipSpecialTokens) instead of guessing. + let raw = tokenizer.decode(tokens: [token], skipSpecialTokens: false) + guard !raw.isEmpty else { return false } + let filtered = tokenizer.decode(tokens: [token], skipSpecialTokens: true) + return filtered.isEmpty } - private mutating func generateLiteralChoice(_ candidates: [String]) throws -> String { - let tokenizedCandidates = try candidates.map { try tokenizeFragment($0) }.filter { !$0.isEmpty } - guard !tokenizedCandidates.isEmpty else { throw StructuredGenerationError.invalidTokenization } - - var prefixes = tokenizedCandidates - var emitted = "" - var tokenPosition = 0 - - while remainingTokens > 0 { - if prefixes.contains(where: { $0.count == tokenPosition }) { break } - - let allowed = Set(prefixes.compactMap { tokens -> Int? 
in - guard tokenPosition < tokens.count else { return nil } - return tokens[tokenPosition] - }) - - let nextToken = try sampleToken(allowed) - emitted += tokenText(nextToken) - try decodeToken(nextToken) - remainingTokens -= 1 + private static func buildTokensExcludedFromRepetitionPenalty(tokenizer: any Tokenizer) -> Set { + let excludedTexts = ["{", "}", "[", "]", ",", ":", "\""] + var excluded = Set() + excluded.reserveCapacity(excludedTexts.count * 2) - prefixes = prefixes.filter { tokens in - tokenPosition < tokens.count && tokens[tokenPosition] == nextToken + for text in excludedTexts { + let tokens = tokenizer.encode(text: text, addSpecialTokens: false) + for token in tokens { + excluded.insert(token) } - tokenPosition += 1 - if prefixes.isEmpty { break } } - return emitted - } - - private mutating func generateNumber(_ numberNode: GenerationSchema.NumberNode) throws -> String { - if numberNode.integerOnly { - let clamped = clampInteger(defaultValue: 0, minimum: numberNode.minimum, maximum: numberNode.maximum) - return try emitLiteral(String(clamped)) - } - - let clamped = clampDouble(defaultValue: 0, minimum: numberNode.minimum, maximum: numberNode.maximum) - return try emitLiteral(formatNumberLiteral(clamped)) - } - - private func clampInteger(defaultValue: Int, minimum: Double?, maximum: Double?) -> Int { - let clampedMinimum = minimum.map { Int(ceil($0)) } - let clampedMaximum = maximum.map { Int(floor($0)) } - let normalized = normalizeBounds(minimum: clampedMinimum, maximum: clampedMaximum) - return clamp(defaultValue, minimum: normalized.minimum, maximum: normalized.maximum) - } - - private func clampDouble(defaultValue: Double, minimum: Double?, maximum: Double?) -> Double { - let normalized = normalizeBounds(minimum: minimum, maximum: maximum) - return clamp(defaultValue, minimum: normalized.minimum, maximum: normalized.maximum) + return excluded } - private func normalizeBounds(minimum: T?, maximum: T?) -> (minimum: T?, maximum: T?) { - guard let minimum, let maximum, minimum > maximum else { return (minimum, maximum) } - return (minimum, minimum) + func tokenize(_ text: String) throws -> [Int] { + tokenizer.encode(text: text, addSpecialTokens: false) } - private func clamp(_ value: T, minimum: T?, maximum: T?) -> T { - if let minimum, value < minimum { return minimum } - if let maximum, value > maximum { return maximum } - return value + func tokenText(_ token: Int) -> String? { + let decoded = tokenizer.decode(tokens: [token], skipSpecialTokens: false) + return decoded.isEmpty ? nil : decoded } - private func formatNumberLiteral(_ value: Double) -> String { - if value.rounded() == value { return String(Int(value)) } - return String(value) - } - - private mutating func generateArray(_ arrayNode: GenerationSchema.ArrayNode) throws -> String { - let elementCount = arrayNode.minItems ?? arrayNode.maxItems ?? 4 - var output = try emitLiteral("[") + mutating func decode(_ token: Int) throws { + let inputText = MLXLMCommon.LMInput.Text(tokens: MLXArray([Int32(token)])) + let output = model( + inputText[text: .newAxis], + cache: cache.isEmpty ? 
nil : cache, + state: state + ) + state = output.state + currentLogits = output.logits + remainingTokens -= 1 - for index in 0 ..< elementCount { - output += try generateNode(arrayNode.items) - if index < elementCount - 1 { - output += try emitLiteral(",") - } + if !tokensExcludedFromRepetitionPenalty.contains(token) { + let tokenArray = MLXArray(Int32(token)) + processor?.didSample(token: tokenArray) } - - output += try emitLiteral("]") - return output } - private mutating func generateObject(_ objectNode: GenerationSchema.ObjectNode) throws -> String { - let keys = objectNode.properties.keys.sorted() - var output = try emitLiteral("{") - - for (index, key) in keys.enumerated() { - output += try emitLiteral("\"") - output += try emitLiteral(key) - output += try emitLiteral("\":") - - if let propertyNode = objectNode.properties[key] { - output += try generateNode(propertyNode) - } else { - output += try emitLiteral("null") - } - - if index < keys.count - 1 { - output += try emitLiteral(",") - } + mutating func sample(from allowedTokens: Set) throws -> Int { + guard !allowedTokens.isEmpty else { + throw ConstrainedGenerationError.tokenizationFailed } - output += try emitLiteral("}") - return output - } - - private mutating func generateNode(_ node: GenerationSchema.Node) throws -> String { - guard remainingTokens > 0 else { throw StructuredGenerationError.tokenBudgetExceeded } - - switch node { - case .string(let stringNode): - var output = try emitLiteral("\"") - if let enumChoices = stringNode.enumChoices, !enumChoices.isEmpty { - output += try generateLiteralChoice(enumChoices) - } else { - output += try generateFreeString(maxTokens: maxTokenCountForFreeString()) - } - output += try emitLiteral("\"") - return output - - case .number(let numberNode): - return try generateNumber(numberNode) - - case .boolean: - return try generateLiteralChoice(["true", "false"]) - - case .array(let arrayNode): - return try generateArray(arrayNode) - - case .object(let objectNode): - return try generateObject(objectNode) + var logits = currentLogits[0..., -1, 0...] + logits = processor?.process(logits: logits) ?? logits + if logits.dtype == .bfloat16 { + logits = logits.asType(.float32) + } - case .anyOf(let nodes): - guard let first = nodes.first else { throw StructuredGenerationError.invalidTokenization } - return try generateNode(first) + let allowedIndices = MLXArray(allowedTokens.map { UInt32($0) }) + let maskedLogits = full(logits.shape, values: -Float.infinity) + maskedLogits[0..., allowedIndices] = logits[0..., allowedIndices] - case .ref(let refName): - guard let referenced = schema.defs[refName] else { throw StructuredGenerationError.invalidTokenization } - return try generateNode(referenced) - } + let sampledToken = sampler.sample(logits: maskedLogits) + return sampledToken.item(Int.self) } } #endif // MLX diff --git a/Sources/AnyLanguageModel/StructuredGeneration.swift b/Sources/AnyLanguageModel/StructuredGeneration.swift new file mode 100644 index 00000000..9a8bd0ed --- /dev/null +++ b/Sources/AnyLanguageModel/StructuredGeneration.swift @@ -0,0 +1,314 @@ +import Foundation + +// MARK: - Token Backend + +/// Abstracts token-level operations for structured JSON generation. +package protocol TokenBackend { + func tokenize(_ text: String) throws -> [Int] + func tokenText(_ token: Int) -> String? 
+ func isSpecialToken(_ token: Int) -> Bool + mutating func decode(_ token: Int) throws + mutating func sample(from allowedTokens: Set) throws -> Int + + var eosToken: Int { get } + var endTokens: Set { get } + var vocabSize: Int { get } + var remainingTokens: Int { get set } + var totalTokenBudget: Int { get } +} + +// MARK: - JSON Generator + +/// Generates JSON conforming to a schema using constrained token sampling. +package struct ConstrainedJSONGenerator { + private var backend: Backend + private let schema: GenerationSchema + + private let quoteToken: Int + + private let stringTerminators: Set + private let stringInitialAllowedTokens: Set + private let stringContinuationAllowedTokens: Set + + private let basicTerminators: Set + private let integerTerminators: Set + private let doubleTerminators: Set + + package init(backend: Backend, schema: GenerationSchema) throws { + self.backend = backend + self.schema = schema + + guard let quoteToken = try backend.tokenize("\"").first else { + throw ConstrainedGenerationError.tokenizationFailed + } + self.quoteToken = quoteToken + + self.stringTerminators = backend.endTokens.union([quoteToken]) + + var structuralTerminators = backend.endTokens + for structuralText in [",", "}", "]", ":"] { + if let token = try backend.tokenize(structuralText).first { + structuralTerminators.insert(token) + } + } + self.basicTerminators = structuralTerminators + self.integerTerminators = Self.buildValidIntegerTokens(backend: backend).union(structuralTerminators) + self.doubleTerminators = Self.buildValidDecimalTokens(backend: backend).union(structuralTerminators) + + let stringContentTokens = Self.buildValidStringTokens(backend: backend) + self.stringInitialAllowedTokens = stringContentTokens + self.stringContinuationAllowedTokens = stringContentTokens.union(stringTerminators) + } + + package mutating func generate() throws -> String { + try generateNode(schema.root) + } + + private static func buildValidStringTokens(backend: Backend) -> Set { + let allowedWhitespace: Set = [" ", "\t", "\n"] + var allowed = Set() + allowed.reserveCapacity(backend.vocabSize / 4) + + for token in 0 ..< backend.vocabSize { + if backend.endTokens.contains(token) { continue } + if backend.isSpecialToken(token) { continue } + guard let text = backend.tokenText(token), !text.isEmpty else { continue } + guard text.allSatisfy({ $0.isValidJSONStringCharacter }) else { continue } + + if text.allSatisfy({ $0.isWhitespace }) { + if text.count == 1, let char = text.first, allowedWhitespace.contains(char) { + allowed.insert(token) + } + } else { + allowed.insert(token) + } + } + return allowed + } + + private static func buildValidIntegerTokens(backend: Backend) -> Set { + var allowed = Set() + for token in 0 ..< backend.vocabSize { + guard let text = backend.tokenText(token), !text.isEmpty else { continue } + if text.allSatisfy({ $0.isNumber || $0 == "-" }) { + allowed.insert(token) + } + } + return allowed + } + + private static func buildValidDecimalTokens(backend: Backend) -> Set { + var allowed = Set() + for token in 0 ..< backend.vocabSize { + guard let text = backend.tokenText(token), !text.isEmpty else { continue } + if text.allSatisfy({ $0.isNumber || $0 == "-" || $0 == "." 
}) { + allowed.insert(token) + } + } + return allowed + } + + private mutating func emit(_ text: String) throws -> String { + for token in try backend.tokenize(text) { + guard backend.remainingTokens > 0 else { + throw ConstrainedGenerationError.tokenBudgetExceeded + } + try backend.decode(token) + } + return text + } + + private func maxFreeStringTokens() -> Int { + let perStringLimit = max(32, backend.totalTokenBudget / 4) + return min(backend.remainingTokens, perStringLimit) + } + + private mutating func generateFreeString(maxTokens: Int) throws -> String { + var result = "" + var generated = 0 + + while backend.remainingTokens > 0, generated < maxTokens { + let allowed = result.isEmpty ? stringInitialAllowedTokens : stringContinuationAllowedTokens + let token = try backend.sample(from: allowed) + if stringTerminators.contains(token) { break } + + var text = backend.tokenText(token) ?? "" + if result.last?.isWhitespace == true && text.first?.isWhitespace == true { + text = String(text.drop(while: { $0.isWhitespace })) + } + result += text + generated += 1 + try backend.decode(token) + } + + return result + } + + private mutating func generateChoice(_ candidates: [String]) throws -> String { + let tokenized = try candidates.map { try backend.tokenize($0) }.filter { !$0.isEmpty } + guard !tokenized.isEmpty else { + throw ConstrainedGenerationError.tokenizationFailed + } + + var prefixes = tokenized + var emitted = "" + var position = 0 + + while backend.remainingTokens > 0 { + if prefixes.contains(where: { $0.count == position }) { break } + + let allowed = Set(prefixes.compactMap { tokens -> Int? in + guard position < tokens.count else { return nil } + return tokens[position] + }) + + let token = try backend.sample(from: allowed) + emitted += backend.tokenText(token) ?? "" + try backend.decode(token) + + prefixes = prefixes.filter { $0.count > position && $0[position] == token } + position += 1 + if prefixes.isEmpty { break } + } + + return emitted + } + + private mutating func generateNumber(_ node: GenerationSchema.NumberNode) throws -> String { + let allowedTokens = node.integerOnly ? integerTerminators : doubleTerminators + var result = "" + let maxTokens = 16 + + while backend.remainingTokens > 0, result.count < maxTokens { + let token = try backend.sample(from: allowedTokens) + if basicTerminators.contains(token) { break } + + guard let text = backend.tokenText(token) else { break } + result += text + try backend.decode(token) + } + + return clampNumberString(result.isEmpty ? "0" : result, node: node) + } + + private func clampNumberString(_ text: String, node: GenerationSchema.NumberNode) -> String { + if node.integerOnly { + let value = Int(text) ?? 0 + let clamped = clampInt(value, min: node.minimum, max: node.maximum) + return String(clamped) + } else { + let value = Double(text) ?? 0 + let clamped = clampDouble(value, min: node.minimum, max: node.maximum) + return formatDouble(clamped) + } + } + + private func clampInt(_ value: Int, min: Double?, max: Double?) -> Int { + let lower = min.map { Int(ceil($0)) } + let upper = max.map { Int(floor($0)) } + return clamp(value, min: lower, max: upper) + } + + private func clampDouble(_ value: Double, min: Double?, max: Double?) -> Double { + clamp(value, min: min, max: max) + } + + private func clamp(_ value: T, min: T?, max: T?) 
-> T { + var result = value + if let min { result = Swift.max(result, min) } + if let max { result = Swift.min(result, max) } + return result + } + + private func formatDouble(_ value: Double) -> String { + if value.truncatingRemainder(dividingBy: 1) == 0 { + return String(Int(value)) + } + let formatted = String(format: "%.6g", value) + return formatted + } + + private mutating func generateNode(_ node: GenerationSchema.Node) throws -> String { + guard backend.remainingTokens > 0 else { + throw ConstrainedGenerationError.tokenBudgetExceeded + } + + switch node { + case .object(let objectNode): + return try generateObject(objectNode) + case .array(let arrayNode): + return try generateArray(arrayNode) + case .string(let stringNode): + return try generateString(stringNode) + case .number(let numberNode): + return try generateNumber(numberNode) + case .boolean: + return try generateChoice(["true", "false"]) + case .ref(let typeName): + guard let referenced = schema.defs[typeName] else { + throw ConstrainedGenerationError.missingReference(typeName) + } + return try generateNode(referenced) + case .anyOf(let variants): + guard let first = variants.first else { + throw ConstrainedGenerationError.emptyAnyOf + } + return try generateNode(first) + } + } + + private mutating func generateObject(_ node: GenerationSchema.ObjectNode) throws -> String { + let keys = node.properties.keys.sorted() + var output = try emit("{") + + for (index, key) in keys.enumerated() { + output += try emit("\"\(key)\":") + output += try generateNode(node.properties[key] ?? .string(.init())) + + if index < keys.count - 1 { + output += try emit(",") + } + } + + output += try emit("}") + return output + } + + private mutating func generateArray(_ node: GenerationSchema.ArrayNode) throws -> String { + let count = node.minItems ?? node.maxItems ?? 4 + var output = try emit("[") + + for index in 0 ..< count { + output += try generateNode(node.items) + if index < count - 1 { + output += try emit(",") + } + } + + output += try emit("]") + return output + } + + private mutating func generateString(_ node: GenerationSchema.StringNode) throws -> String { + var output = try emit("\"") + + if let choices = node.enumChoices, !choices.isEmpty { + output += try generateChoice(choices) + } else { + let content = try generateFreeString(maxTokens: maxFreeStringTokens()) + output += content.trimmingCharacters(in: .whitespaces) + } + + output += try emit("\"") + return output + } +} + +// MARK: - Errors + +package enum ConstrainedGenerationError: Error { + case tokenizationFailed + case tokenBudgetExceeded + case missingReference(String) + case emptyAnyOf +} diff --git a/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift b/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift index accb4c90..787e2930 100644 --- a/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift +++ b/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift @@ -129,6 +129,10 @@ private struct SupportedModel: Sendable { let model: any LanguageModel static var all: [SupportedModel] { + func environmentValue(_ key: String) -> String? { + ProcessInfo.processInfo.environment[key] ?? 
ProcessInfo.processInfo.environment["TEST_RUNNER_\(key)"] + } + var models: [SupportedModel] = [] #if canImport(FoundationModels) @@ -140,21 +144,21 @@ private struct SupportedModel: Sendable { #endif #if Llama - if let modelPath = ProcessInfo.processInfo.environment["LLAMA_MODEL_PATH"] { + if let modelPath = environmentValue("LLAMA_MODEL_PATH") { models.append(SupportedModel(name: "LlamaLanguageModel", model: LlamaLanguageModel(modelPath: modelPath))) } #endif #if MLX - let shouldRunMLX = ProcessInfo.processInfo.environment["ENABLE_MLX_TESTS"] != nil - || (ProcessInfo.processInfo.environment["CI"] == nil - && ProcessInfo.processInfo.environment["HF_TOKEN"] != nil - && ProcessInfo.processInfo.environment["XCTestConfigurationFilePath"] != nil) - if shouldRunMLX { + let shouldRunMLX = environmentValue("ENABLE_MLX_TESTS") != nil + || (environmentValue("CI") == nil + && environmentValue("HF_TOKEN") != nil + && environmentValue("XCTestConfigurationFilePath") != nil) + if let modelId = environmentValue("MLX_MODEL_ID"), shouldRunMLX { models.append( SupportedModel( name: "MLXLanguageModel", - model: MLXLanguageModel(modelId: "mlx-community/Qwen3-0.6B-4bit") + model: MLXLanguageModel(modelId: modelId) ) ) }