Skip to content

Commit 3cf7548

Browse files
committed
Incorporate feedback from review
1 parent 3a2b9ea commit 3cf7548

File tree

1 file changed

+53
-23
lines changed

1 file changed

+53
-23
lines changed

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 53 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -497,16 +497,17 @@ import Foundation
497497
return limits.max() ?? effectiveIdleLimit()
498498
}
499499

500+
private func idlePolicyConfiguration() -> GPUMemoryConfiguration {
501+
knownConfigs.max(by: { $0.idleCacheLimit < $1.idleCacheLimit })
502+
?? GPUMemoryConfiguration.automatic
503+
}
504+
500505
private func effectiveIdleLimit() -> Int {
501-
let limits = knownConfigs.map(\.idleCacheLimit)
502-
return limits.min() ?? GPUMemoryConfiguration.automatic.idleCacheLimit
506+
idlePolicyConfiguration().idleCacheLimit
503507
}
504508

505509
private func shouldClearOnEviction() -> Bool {
506-
if knownConfigs.isEmpty {
507-
return GPUMemoryConfiguration.automatic.clearCacheOnEviction
508-
}
509-
return knownConfigs.contains { $0.clearCacheOnEviction }
510+
idlePolicyConfiguration().clearCacheOnEviction
510511
}
511512
}
512513

@@ -619,7 +620,25 @@ import Foundation
619620
.concurrentRequests(
620621
.init(
621622
debugDescription:
622-
"Concurrent requests on the same LanguageModelSession are not supported while MLX KV cache reuse is enabled."
623+
"Concurrent requests on the same LanguageModelSession are not supported for MLX due to cache and memory management constraints."
624+
)
625+
)
626+
}
627+
628+
private static func maxToolIterationsExceededError(limit: Int) -> LanguageModelSession.GenerationError {
629+
.decodingFailure(
630+
.init(
631+
debugDescription:
632+
"Exceeded maximum tool iterations (\(limit)) while processing MLX tool calls."
633+
)
634+
)
635+
}
636+
637+
private static func repeatedToolCallLoopError() -> LanguageModelSession.GenerationError {
638+
.decodingFailure(
639+
.init(
640+
debugDescription:
641+
"Detected repeated MLX tool-call signature and aborted to avoid an infinite tool loop."
623642
)
624643
)
625644
}
@@ -663,7 +682,7 @@ import Foundation
663682
guard entry.prefixTokens.count == entry.prefillTokenCount else {
664683
return false
665684
}
666-
return Array(currentTokens.prefix(entry.prefillTokenCount)) == entry.prefixTokens
685+
return currentTokens.starts(with: entry.prefixTokens)
667686
}
668687

669688
private func resolveCache(
@@ -846,15 +865,19 @@ import Foundation
846865
if !collectedToolCalls.isEmpty {
847866
toolIteration += 1
848867
if toolIteration > maxToolIterations {
849-
break
868+
let unresolvedCalls = try makeTranscriptToolCalls(from: collectedToolCalls)
869+
allEntries.append(Transcript.Entry.toolCalls(Transcript.ToolCalls(unresolvedCalls)))
870+
throw Self.maxToolIterationsExceededError(limit: maxToolIterations)
850871
}
851872

852873
let signature =
853874
collectedToolCalls
854875
.map { "\($0.function.name):\($0.function.arguments)" }
855876
.joined(separator: "|")
856877
if signature == previousToolCallSignature {
857-
break
878+
let unresolvedCalls = try makeTranscriptToolCalls(from: collectedToolCalls)
879+
allEntries.append(Transcript.Entry.toolCalls(Transcript.ToolCalls(unresolvedCalls)))
880+
throw Self.repeatedToolCallLoopError()
858881
}
859882
previousToolCallSignature = signature
860883

@@ -1294,19 +1317,9 @@ import Foundation
12941317
case invocations([ToolInvocationResult])
12951318
}
12961319

1297-
private func resolveToolCalls(
1298-
_ toolCalls: [MLXLMCommon.ToolCall],
1299-
session: LanguageModelSession
1300-
) async throws -> ToolResolutionOutcome {
1301-
if toolCalls.isEmpty { return .invocations([]) }
1302-
1303-
var toolsByName: [String: any Tool] = [:]
1304-
for tool in session.tools {
1305-
if toolsByName[tool.name] == nil {
1306-
toolsByName[tool.name] = tool
1307-
}
1308-
}
1309-
1320+
private func makeTranscriptToolCalls(
1321+
from toolCalls: [MLXLMCommon.ToolCall]
1322+
) throws -> [Transcript.ToolCall] {
13101323
var transcriptCalls: [Transcript.ToolCall] = []
13111324
transcriptCalls.reserveCapacity(toolCalls.count)
13121325
for call in toolCalls {
@@ -1320,6 +1333,23 @@ import Foundation
13201333
)
13211334
)
13221335
}
1336+
return transcriptCalls
1337+
}
1338+
1339+
private func resolveToolCalls(
1340+
_ toolCalls: [MLXLMCommon.ToolCall],
1341+
session: LanguageModelSession
1342+
) async throws -> ToolResolutionOutcome {
1343+
if toolCalls.isEmpty { return .invocations([]) }
1344+
1345+
var toolsByName: [String: any Tool] = [:]
1346+
for tool in session.tools {
1347+
if toolsByName[tool.name] == nil {
1348+
toolsByName[tool.name] = tool
1349+
}
1350+
}
1351+
1352+
let transcriptCalls = try makeTranscriptToolCalls(from: toolCalls)
13231353

13241354
if let delegate = session.toolExecutionDelegate {
13251355
await delegate.didGenerateToolCalls(transcriptCalls, in: session)

0 commit comments

Comments
 (0)