feat(cluster): TP decode loop (rank 0 + rank 1)#196
Conversation
Add tensor-parallel inference engine and server: - ClusterControlMessage: add promptTokens (0x08), stepToken (0x09), sessionStop (0x0A) message types with Codable payload structs - ClusterSession/ClusterPeer: split jacclBootstrap into its own bootstrapHandler parameter, separate from inferenceHandler - TensorParallelInference: implement TensorParallelEngine (rank 0, greedy decode loop with AsyncStream) and TensorParallelServer (rank 1 frame handler); add ClusterSessionSendable protocol for testability; expose UncheckedSendableLLMModel for Swift 6 safety - ClusterModelLoader: load LlamaModelTP from a model directory using the jaccl DistributedGroup from the process environment - ClusterDiscovery: wire LlamaModelTP construction into the bootstrap completion path; expose currentEngine()/currentServer() accessors - TensorParallelDecodeTests: 17 tests covering message types, payload round-trips, engine/server construction, generate() semantics (maxTokens, EOS, determinism, frame sequence), and handleFrame routing
|
The latest updates on your projects. Learn more about Vercel for GitHub.
|
Benchmark ResultsRunner: 1-provider-streaming1 providers, 1 users, 30 requests, concurrency=5, streaming=true
Latency Decomposition
Assertion Report: FAIL
1-provider-non-streaming1 providers, 1 users, 20 requests, concurrency=5, streaming=false
Latency Decomposition
Assertion Report: FAIL
7-provider-multi-model7 providers, 5 users, 50 requests, concurrency=10, streaming=true
Latency Decomposition
Assertion Report: FAIL
3-provider-high-concurrency3 providers, 10 users, 60 requests, concurrency=20, streaming=true
Latency Decomposition
Assertion Report: FAIL
1-provider-queue-saturation1 providers, 10 users, 40 requests, concurrency=15, streaming=true
Latency Decomposition
Assertion Report: FAIL
3-provider-20-users3 providers, 20 users, 60 requests, concurrency=10, streaming=true
Latency Decomposition
Assertion Report: FAIL
1-provider-scaling1 providers, 5 users, 30 requests, concurrency=10, streaming=true
Latency Decomposition
Assertion Report: FAIL
3-provider-scaling3 providers, 5 users, 30 requests, concurrency=10, streaming=true
Latency Decomposition
Assertion Report: FAIL
5-provider-scaling5 providers, 5 users, 30 requests, concurrency=10, streaming=true
Latency Decomposition
Assertion Report: FAIL
3-provider-heavy-100conc-10kb3 providers, 20 users, 100 requests, concurrency=100, streaming=true
Latency Decomposition
Assertion Report: FAIL
|
What
Implements PR 4b of the tensor-parallel inference stack: the rank-0 decode loop and rank-1 serve loop for
LlamaModelTP.Changes
Protocol
ClusterControlMessage: three new message types —promptTokens(0x08),stepToken(0x09),sessionStop(0x0A) — each with aCodable, Sendablepayload struct carrying a requestuidClusterPeer.serve:jacclBootstrapframes now route to a dedicatedbootstrapHandlerparameter; inference frames go toinferenceHandler. Clean separation that was a stopgap in PR 4a.Engine (rank 0) —
TensorParallelEngineLlamaModelTPvianonisolated(unsafe)for Swift 6 compatibilitygenerate(promptTokens:maxTokens:eosTokenIDs:) -> AsyncStream<Int>: sendspromptTokensto rank 1, prefills, greedily samples, loopsstepTokenper token, sendssessionStopat endServer (rank 1) —
TensorParallelServerhandleFrame(_ data: Data)dispatches onClusterMsgTypepromptTokens→ reset KV cache + prefill (discards logits)stepToken→ decode step (discards logits; rank 0 samples)sessionStop→ clear cacheModel loading —
ClusterModelLoaderconfig.json, decodesLlamaConfiguration, callsMLX.DistributedGroup()(reads jaccl env vars set during bootstrap), constructsLlamaModelTP, loads weights viaMLXLMCommon.loadWeightsLLMModelFactorybecause the factory doesn't threadDistributedGroupthrough its pipelineWiring —
ClusterDiscoverymodelDirectoryis setsetModelDirectory(_:)public setter;currentEngine()/currentServer()accessors for the provider serve loop (PR 4d)Swift 6 Sendable
UncheckedSendableLLMModel(public struct,@unchecked Sendable) wrapsany LLMModelfor safe single-owner transfers across actor boundaries withoutsendingUncheckedSendableLLMModel; internal storage usesnonisolated(unsafe) letTests (
TensorParallelDecodeTests.swift— 17 tests, all passing)ClusterMsgTypecasesTensorParallelEngineandTensorParallelServerconstruct without error on singleton groupgenerate()produces ≤ maxTokens tokens and completes cleanlygenerate()is deterministic (same weights + same prompt → same output)generate()stops at EOS token after exactly 1 tokenpromptTokens → stepToken × N → sessionStoppromptTokensframe content (uid, tokens, maxTokens)handleFramerouting for all three server-side frame typesClusterPeer.servesignature compilation check (bootstrapHandler present)Out of scope
LlamaModelTPQ(quantized variant)Need help on this PR? Tag
@codesmithwith what you need. Autofix is disabled.