From 877550ac856e7ccf684114e241ded69c5a8b6d17 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 24 Mar 2026 16:50:14 +0200 Subject: [PATCH 01/13] [refactor] Move and rename last FFN layer's TaskGraph ID field and its get/set methods to AbstractFFNLayers for consistency --- .../model/fp16/GraniteFP16LayerPlanner.java | 2 +- .../model/fp16/LlamaFP16LayerPlanner.java | 2 +- .../model/fp16/Phi3FP16LayerPlanner.java | 2 +- .../model/fp16/Qwen2FP16LayerPlanner.java | 2 +- .../model/fp16/Qwen3FP16LayerPlanner.java | 2 +- .../model/q8_0/GraniteQ8_0LayerPlanner.java | 2 +- .../model/q8_0/LlamaQ8_0LayerPlanner.java | 2 +- .../model/q8_0/Phi3Q8_0LayerPlanner.java | 2 +- .../model/q8_0/Qwen2Q8_0LayerPlanner.java | 2 +- .../model/q8_0/Qwen3Q8_0LayerPlanner.java | 2 +- .../tornadovm/layers/AbstractFFNLayers.java | 14 +++++--------- .../gpullama3/tornadovm/layers/AbstractLayer.java | 14 +------------- .../layers/type/fp16/GraniteFP16FFNLayers.java | 2 +- .../layers/type/fp16/LlamaFP16FFNLayers.java | 2 +- .../layers/type/fp16/Phi3FP16FFNLayers.java | 2 +- .../layers/type/fp16/Qwen2FP16FFNLayers.java | 2 +- .../layers/type/fp16/Qwen3FP16FFNLayers.java | 2 +- .../layers/type/q8_0/GraniteQ8_0FFNLayers.java | 2 +- .../layers/type/q8_0/LlamaQ8_0FFNLayers.java | 2 +- .../layers/type/q8_0/Phi3Q8_0FFNLayers.java | 2 +- .../layers/type/q8_0/Qwen2Q8_0FFNLayers.java | 2 +- .../layers/type/q8_0/Qwen3Q8_0FFNLayers.java | 2 +- 22 files changed, 26 insertions(+), 42 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java index 7cc97d64..4a6c853a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java @@ -20,7 +20,7 @@ public GraniteFP16LayerPlanner(GraniteState 
state, Model model) { protected void initializeLayerComponents() { this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config); this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsGraniteFP16Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + this.logitsLayer = new LogitsGraniteFP16Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java index 0480d513..7cad2949 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java @@ -21,7 +21,7 @@ public LlamaFP16LayerPlanner(LlamaState state, Model model) { protected void initializeLayerComponents() { this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config); this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsFP16Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + this.logitsLayer = new LogitsFP16Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType); } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java index b1f41515..8a2554c9 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java @@ -28,7 +28,7 @@ public Phi3FP16LayerPlanner(Phi3State state, Model model) { protected void initializeLayerComponents() { this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config); this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsFP16Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(),this.schedulerType); + this.logitsLayer = new LogitsFP16Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(),this.schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java index b87dafd8..f1458ee1 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java @@ -28,6 +28,6 @@ public Qwen2FP16LayerPlanner(Qwen2State state, Model model) { protected void initializeLayerComponents() { this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config); this.ffnLayers = new Qwen2FP16FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsFP16Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + this.logitsLayer = new LogitsFP16Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType); } } diff --git 
a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java index ef3dcee4..d03a98ee 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java @@ -28,7 +28,7 @@ public Qwen3FP16LayerPlanner(Qwen3State state, Model model) { protected void initializeLayerComponents() { this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config); this.ffnLayers = new Qwen3FP16FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsFP16Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + this.logitsLayer = new LogitsFP16Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType); } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java index ee818080..4be8d900 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java @@ -21,6 +21,6 @@ public GraniteQ8_0LayerPlanner(GraniteState state, Model model) { protected void initializeLayerComponents() { this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config); this.ffnLayers = new GraniteQ8_0FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsGraniteQ8_0Layer("graniteLogits", this.state, 
this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + this.logitsLayer = new LogitsGraniteQ8_0Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java index 2560d8d7..2d25e14b 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java @@ -21,7 +21,7 @@ public LlamaQ8_0LayerPlanner(LlamaState state, Model model) { protected void initializeLayerComponents() { this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config); this.ffnLayers = new LlamaQ8_0FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsQ8_0Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + this.logitsLayer = new LogitsQ8_0Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType); } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java index dfa0ec0e..ea4d8f07 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java @@ -29,7 +29,7 @@ public Phi3Q8_0LayerPlanner(Phi3State state, Model model) { protected void initializeLayerComponents() { this.activationLayer = new Activation("activationUpdate", 
this.state, this.weights, this.config); this.ffnLayers = new Phi3Q8_0FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsQ8_0Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + this.logitsLayer = new LogitsQ8_0Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java index 34cb1a42..4966a5d9 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java @@ -29,7 +29,7 @@ public Qwen2Q8_0LayerPlanner(Qwen2State state, Model model) { protected void initializeLayerComponents() { this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config); this.ffnLayers = new Qwen2Q8_0FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsQ8_0Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + this.logitsLayer = new LogitsQ8_0Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java index fb4d4ef3..c5489cb6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java @@ 
-29,6 +29,6 @@ public Qwen3Q8_0LayerPlanner(Qwen3State state, Model model) { protected void initializeLayerComponents() { this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config); this.ffnLayers = new Qwen3Q8_0FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsQ8_0Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(),this.schedulerType); + this.logitsLayer = new LogitsQ8_0Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(),this.schedulerType); } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java index 3b0620c6..97a2666e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java @@ -20,7 +20,7 @@ */ public abstract class AbstractFFNLayers extends AbstractLayer { - protected String lastTaskGraphID; + protected String lastFFNLayerTaskGraphID; protected final SchedulerType schedulerType; @@ -51,15 +51,11 @@ protected AbstractFFNLayers(String taskGraphName, State state, Weights weights, public abstract List getFfnLayerTaskGraphs(); /** - * Get the ID of the last task graph. - * - * Used by LogitsLayer to know where to attach the final logits computation. The last transformer layer's task graph ID is needed to chain the logits computation after all FFN layers complete. - * - * @return Task graph ID of the last FFN layer + * Returns the TaskGraph ID of the last FFN layer. + * Used by the logits layer to chain its consumeFromDevice call. 
*/ - @Override - public String getLastTaskGraphID() { - return lastTaskGraphID; + public String getLastFFNLayerTaskGraphID() { + return lastFFNLayerTaskGraphID; } /** diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java index 6578777f..b9cdf3e3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java @@ -19,7 +19,7 @@ public abstract class AbstractLayer { /** Common constants used in tasks & worker-grid sizing. */ protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32; protected static final int THREAD_SCALE_FOR_LOGITS = 8; - protected static String lastTaskGraphID; + protected final Weights weights; protected final Configuration config; /** Often a small context/config buffer passed into kernels. Use your real type if available. */ @@ -58,16 +58,4 @@ protected static T requireWeightsType(Object weights, Class expectedType, protected TaskGraph configureLayerDataTransfers(TaskGraph tg, int layerIndex) { return tg; } - - public String getLastTaskGraphID() { - return lastTaskGraphID; - } - - public void setupLastID(String taskGraphID) { - if (lastTaskGraphID == null) { - lastTaskGraphID = taskGraphID; - } else if (!lastTaskGraphID.equals(taskGraphID)) { - throw new IllegalStateException("Task graph IDs do not match: " + lastTaskGraphID + " vs " + taskGraphID); - } - } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java index 0387ae54..4567a460 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java @@ -87,7 +87,7 @@ List setupFFNLayered() { return IntStream.range(0, 
config.numberOfLayers()).mapToObj(i -> { var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i); if (i == config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); + this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); } return ffnLayer.snapshot(); }).toList(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 8d105e89..8497ba60 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -86,7 +86,7 @@ List setupFFNLayered() { return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); if (i == config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); + this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); } return ffnLayer.snapshot(); }).toList(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index a98f9860..11df3b40 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -115,7 +115,7 @@ List setupFFNLayered() { for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSinglePhi3FFNLayer((Phi3TornadoWeights) weights, layerIndex); if (layerIndex == phi3Config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); + this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); } ffnGraphs.add(ffnLayer.snapshot()); } diff --git 
a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index a6f1c95c..88cfe462 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -140,7 +140,7 @@ List setupFFNLayered() { for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSingleQwen2FFNLayer((Qwen2TornadoWeights) weights, layerIndex); if (layerIndex == qwen2Config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); + this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); } ffnGraphs.add(ffnLayer.snapshot()); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index b99a1ab3..46d400a0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -128,7 +128,7 @@ List setupFFNLayered() { return IntStream.range(0, qwen3Config.numberOfLayers()).mapToObj(i -> { var ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, i); if (i == qwen3Config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); + this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); } return ffnLayer.snapshot(); }).toList(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java index b7e036d4..a724f34f 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java +++ 
b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java @@ -50,7 +50,7 @@ List setupFFNLayered() { return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i); if (i == config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); + this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); } return ffnLayer.snapshot(); }).toList(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index ba1b6a79..39934546 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -45,7 +45,7 @@ List setupFFNLayered() { return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); if (i == config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); + this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); } return ffnLayer.snapshot(); }).toList(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index 46b0737d..bfe748de 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -103,7 +103,7 @@ List setupFFNLayered() { for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSinglePhi3Q8_0FFNLayer((Phi3TornadoWeights) weights, layerIndex); if (layerIndex == phi3Config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); + 
this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); } ffnGraphs.add(ffnLayer.snapshot()); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java index 7aee3b83..1997090c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java @@ -145,7 +145,7 @@ List setupFFNLayered() { for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSingleQwen2Q8_0FFNLayer((Qwen2TornadoWeights) weights, layerIndex); if (layerIndex == qwen2Config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); + this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); } ffnGraphs.add(ffnLayer.snapshot()); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index 6aea5559..a54fe678 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -137,7 +137,7 @@ List setupFFNLayered() { for (int layerIndex = 0; layerIndex < qwen3Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, layerIndex); if (layerIndex == qwen3Config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); + this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); } ffnGraphs.add(ffnLayer.snapshot()); } From 4ff791994dfd11c6daa208756ecc548f984ed040 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 24 Mar 2026 18:43:53 +0200 Subject: [PATCH 02/13] [refactor] Rename setupFFNLayered() to setupFFNLayerTaskGraphs() and abstractify it for 
visibility and consistency across all FFN layers --- .../tornadovm/layers/AbstractFFNLayers.java | 7 +++++++ .../layers/type/fp16/GraniteFP16FFNLayers.java | 5 +++-- .../layers/type/fp16/LlamaFP16FFNLayers.java | 5 +++-- .../layers/type/fp16/Phi3FP16FFNLayers.java | 5 +++-- .../layers/type/fp16/Qwen2FP16FFNLayers.java | 5 +++-- .../layers/type/fp16/Qwen3FP16FFNLayers.java | 5 +++-- .../layers/type/q8_0/GraniteQ8_0FFNLayers.java | 18 ++---------------- .../layers/type/q8_0/LlamaQ8_0FFNLayers.java | 18 ++---------------- .../layers/type/q8_0/Phi3Q8_0FFNLayers.java | 5 +++-- .../layers/type/q8_0/Qwen2Q8_0FFNLayers.java | 5 +++-- .../layers/type/q8_0/Qwen3Q8_0FFNLayers.java | 5 +++-- 11 files changed, 35 insertions(+), 48 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java index 97a2666e..5305b7a7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java @@ -41,6 +41,13 @@ protected AbstractFFNLayers(String taskGraphName, State state, Weights weights, this.schedulerType = schedulerType; } + /** + * Creates the TornadoVM {@link uk.ac.manchester.tornado.api.TaskGraph} for the FFN layers. + * It creates one TaskGraph per layer and snapshots it to produce an {@link ImmutableTaskGraph} per layer. + * It also stores the TaskGraph ID of the last FFN layer for use by the {@link AbstractLogitsLayer}. + */ + protected abstract List setupFFNLayerTaskGraphs(); + /** * Returns all task graphs for the FFN layers. 
* diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java index 4567a460..d8241b82 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java @@ -27,7 +27,7 @@ public class GraniteFP16FFNLayers extends AbstractFFNLayers { public GraniteFP16FFNLayers(String taskGraph, State state, Weights weights, GraniteConfiguration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); - this.ffnLayerTaskGraphs = setupFFNLayered(); + this.ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); } @Override @@ -83,7 +83,8 @@ public List getFfnLayerTaskGraphs() { return ffnLayerTaskGraphs; } - List setupFFNLayered() { + @Override + protected List setupFFNLayerTaskGraphs() { return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i); if (i == config.numberOfLayers() - 1) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 8497ba60..5fc834cd 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -26,7 +26,7 @@ public class LlamaFP16FFNLayers extends AbstractFFNLayers { public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); - this.ffnLayerTaskGraphs = setupFFNLayered(); + this.ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); } @Override @@ -82,7 +82,8 @@ public List 
getFfnLayerTaskGraphs() { return ffnLayerTaskGraphs; } - List setupFFNLayered() { + @Override + protected List setupFFNLayerTaskGraphs() { return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); if (i == config.numberOfLayers() - 1) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 11df3b40..f809eadc 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -42,7 +42,7 @@ public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeigh this.phi3State = state; this.phi3Config = config; this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); - ffnLayerTaskGraphs = setupFFNLayered(); + ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); } @Override @@ -110,7 +110,8 @@ public List getFfnLayerTaskGraphs() { /** * Setup all FFN layers for all transformer layers */ - List setupFFNLayered() { + @Override + protected List setupFFNLayerTaskGraphs() { List ffnGraphs = new ArrayList<>(); for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSinglePhi3FFNLayer((Phi3TornadoWeights) weights, layerIndex); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index 88cfe462..5de9af9c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -42,7 +42,7 @@ public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWe super(taskGraphName, 
state, weights, config, schedulerType); this.qwen2State = state; this.qwen2Config = config; - ffnLayerTaskGraphs = setupFFNLayered(); + ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); } @Override @@ -135,7 +135,8 @@ public List getFfnLayerTaskGraphs() { return ffnLayerTaskGraphs; } - List setupFFNLayered() { + @Override + protected List setupFFNLayerTaskGraphs() { List ffnGraphs = new ArrayList<>(qwen2Config.numberOfLayers()); for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSingleQwen2FFNLayer((Qwen2TornadoWeights) weights, layerIndex); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 46d400a0..865417a1 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -55,7 +55,7 @@ public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe this.nEmbdHead = nEmbdHeadV; this.nEmbdGqa = nEmbdVGqa; this.gqa = config.numberOfHeads() / config.numberOfKeyValueHeads(); - ffnLayerTaskGraphs = setupFFNLayered(); + ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); } @Override @@ -124,7 +124,8 @@ public List getFfnLayerTaskGraphs() { /** * Setup all FFN layers for all transformer layers */ - List setupFFNLayered() { + @Override + protected List setupFFNLayerTaskGraphs() { return IntStream.range(0, qwen3Config.numberOfLayers()).mapToObj(i -> { var ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, i); if (i == qwen3Config.numberOfLayers() - 1) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java index a724f34f..f565cd57 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java @@ -28,25 +28,11 @@ public class GraniteQ8_0FFNLayers extends AbstractFFNLayers { public GraniteQ8_0FFNLayers(String taskGraphName, GraniteState state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); - ffnLayerTaskGraphs = setupFFNLayered(); + ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); } @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return null; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - List setupFFNLayered() { + protected List setupFFNLayerTaskGraphs() { return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i); if (i == config.numberOfLayers() - 1) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index 39934546..0bbf6b85 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -23,25 +23,11 @@ public class LlamaQ8_0FFNLayers extends AbstractFFNLayers { public LlamaQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); - ffnLayerTaskGraphs = setupFFNLayered(); + ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); } @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public 
TaskGraph getTaskGraph() { - return null; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - List setupFFNLayered() { + protected List setupFFNLayerTaskGraphs() { return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); if (i == config.numberOfLayers() - 1) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index bfe748de..f1bf0836 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -41,7 +41,7 @@ public Phi3Q8_0FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeigh this.phi3State = state; this.phi3Config = config; this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); - ffnLayerTaskGraphs = setupFFNLayered(); + ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); } @Override @@ -98,7 +98,8 @@ public List getFfnLayerTaskGraphs() { /** * Setup all FFN layers for all transformer layers */ - List setupFFNLayered() { + @Override + protected List setupFFNLayerTaskGraphs() { List ffnGraphs = new ArrayList<>(); for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) { TaskGraph ffnLayer = setupSinglePhi3Q8_0FFNLayer((Phi3TornadoWeights) weights, layerIndex); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java index 1997090c..53c9b293 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java @@ -46,7 +46,7 @@ public 
Qwen2Q8_0FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWe super(taskGraphName, state, weights, config, schedulerType); this.qwen2State = state; this.qwen2Config = config; - ffnLayerTaskGraphs = setupFFNLayered(); + ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); } @Override @@ -137,7 +137,8 @@ public List getFfnLayerTaskGraphs() { /** * Setup all FFN layers for all transformer layers */ - List setupFFNLayered() { + @Override + protected List setupFFNLayerTaskGraphs() { List ffnGraphs = new ArrayList<>(); qwen2State.temp.init(0.0f); qwen2State.tempFFN.init(0.0f); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index a54fe678..3d893360 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -60,7 +60,7 @@ public Qwen3Q8_0FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe this.nEmbdHead = nEmbdHeadV; this.nEmbdGqa = nEmbdVGqa; this.gqa = config.numberOfHeads() / config.numberOfKeyValueHeads(); - ffnLayerTaskGraphs = setupFFNLayered(); + ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); } @Override @@ -127,7 +127,8 @@ public List getFfnLayerTaskGraphs() { /** * Setup all FFN layers for all transformer layers */ - List setupFFNLayered() { + @Override + protected List setupFFNLayerTaskGraphs() { List ffnGraphs = new ArrayList<>(); qwen3State.temp.init(0.0f); qwen3State.tempFFN.init(0.0f); From 16659c76d9b23622af0f0dddf92dcac43603b935 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 13:31:26 +0200 Subject: [PATCH 03/13] [refactor] Remove unused fields and methods across FFN layer implementations --- .../tornadovm/layers/AbstractFFNLayers.java | 10 +++----- .../tornadovm/layers/AbstractLayer.java | 23 +++-------------- 
.../tornadovm/layers/Activation.java | 7 ------ .../tornadovm/layers/ActivationGranite.java | 25 ------------------- .../type/fp16/GraniteFP16FFNLayers.java | 19 +------------- .../layers/type/fp16/LlamaFP16FFNLayers.java | 20 ++------------- .../layers/type/fp16/Phi3FP16FFNLayers.java | 19 +------------- .../layers/type/fp16/Qwen2FP16FFNLayers.java | 19 +------------- .../layers/type/fp16/Qwen3FP16FFNLayers.java | 19 +------------- .../type/q8_0/GraniteQ8_0FFNLayers.java | 7 +----- .../layers/type/q8_0/LlamaQ8_0FFNLayers.java | 3 +-- .../layers/type/q8_0/Phi3Q8_0FFNLayers.java | 19 +------------- .../layers/type/q8_0/Qwen2Q8_0FFNLayers.java | 19 +------------- .../layers/type/q8_0/Qwen3Q8_0FFNLayers.java | 20 +-------------- 14 files changed, 18 insertions(+), 211 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java index 5305b7a7..b62b1532 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java @@ -11,12 +11,10 @@ /** * Abstract base class for all FFN (Feed-Forward Network) layer implementations. * - * Extends AbstractLayer and adds FFN-specific methods: - getFfnLayerTaskGraphs(): Returns task graphs for all transformer layers - getLastTaskGraphID(): Tracks the ID of the last task graph + * Each subclass builds N ImmutableTaskGraphs (one per FFN layer) via + * {@link #setupFFNLayerTaskGraphs}, covering RMSNorm, Attention, and FFN computations. * - * All model-specific FFN layers extend this: - LlamaFP16FFNLayers, Qwen2FP16FFNLayers, Qwen3FP16FFNLayers, Phi3FP16FFNLayers - LlamaQ8_0FFNLayers, Qwen2Q8_0FFNLayers, Qwen3Q8_0FFNLayers, - * Phi3Q8_0FFNLayers - * - * Used by FP16LayerPlanner and Q8_0LayerPlanner template methods for type-safe polymorphic access to any FFN layer implementation. 
+ * Model-specific subclasses: Llama, Qwen2, Qwen3, Phi3, Granite — each in FP16 and Q8_0 variants. */ public abstract class AbstractFFNLayers extends AbstractLayer { @@ -55,7 +53,7 @@ protected AbstractFFNLayers(String taskGraphName, State state, Weights weights, * * @return List of immutable task graphs (one per transformer layer) */ - public abstract List getFfnLayerTaskGraphs(); + public abstract List getFFNLayerTaskGraphs(); /** * Returns the TaskGraph ID of the last FFN layer. diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java index b9cdf3e3..3e2a1beb 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java @@ -4,15 +4,11 @@ import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.Configuration; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.TaskGraph; -import java.util.ArrayList; -import java.util.List; - /** - * Minimal base with common fields/utilities so subclasses compile cleanly. Adjust or remove fields if they already exist in your project. + * Abstract base class for Activations, FFN Layers, and Logits. */ public abstract class AbstractLayer { @@ -22,17 +18,10 @@ public abstract class AbstractLayer { protected final Weights weights; protected final Configuration config; - /** Often a small context/config buffer passed into kernels. Use your real type if available. */ + protected final State state; protected final KernelContext context = new KernelContext(); - /** Collected snapshots for scheduling / debugging. */ - protected final List taskGraphs = new ArrayList<>(); - /** Optional: track the "main" task graph for the layer if one exists. 
*/ - protected TaskGraph taskGraph; - /** Shared runtime objects (exposed because kernels expect them). */ - protected State state; protected AbstractLayer(String taskGraphName, State state, Weights weights, Configuration config) { - this.taskGraph = null; this.state = state; this.weights = weights; this.config = config; @@ -48,13 +37,7 @@ protected static T requireWeightsType(Object weights, Class expectedType, public abstract GridScheduler updateGridScheduler(GridScheduler scheduler); - public abstract GridScheduler getGridScheduler(); - - public abstract TaskGraph getTaskGraph(); - - public abstract ImmutableTaskGraph getImmutableTaskGraph(); - - /** Allow subclasses to override if they need custom transfers. */ + /** Allow subclasses to override if they need custom data transfers. */ protected TaskGraph configureLayerDataTransfers(TaskGraph tg, int layerIndex) { return tg; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java index e7822786..ccd49d63 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java @@ -50,17 +50,10 @@ public GridScheduler updateGridScheduler(GridScheduler scheduler) { return scheduler; } - @Override - public GridScheduler getGridScheduler() { - return null; - } - - @Override public TaskGraph getTaskGraph() { return activationUpdate; } - @Override public ImmutableTaskGraph getImmutableTaskGraph() { return activationUpdate.snapshot(); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java index 20002dac..1345d0df 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java @@ -41,28 +41,3 @@ public 
ActivationGranite(String taskGraphHandle, State state, Weights weights, G } // @formatter:on } - - @Override - public GridScheduler updateGridScheduler(GridScheduler scheduler) { - WorkerGrid worker = new WorkerGrid1D(config.dim()); - worker.setLocalWork(128, 1, 1); - scheduler.addWorkerGrid("activationUpdate.updateX", worker); - return scheduler; - } - - @Override - public GridScheduler getGridScheduler() { - return null; - } - - @Override - public TaskGraph getTaskGraph() { - return activationUpdate; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return activationUpdate.snapshot(); - } - -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java index d8241b82..4a471e2f 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java @@ -21,8 +21,6 @@ public class GraniteFP16FFNLayers extends AbstractFFNLayers { - TaskGraph ffnTaskGraphs; - GridScheduler scheduler; List ffnLayerTaskGraphs; public GraniteFP16FFNLayers(String taskGraph, State state, Weights weights, GraniteConfiguration config, SchedulerType schedulerType) { @@ -64,22 +62,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnTaskGraphs; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List getFfnLayerTaskGraphs() { + public List getFFNLayerTaskGraphs() { return ffnLayerTaskGraphs; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java 
b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 5fc834cd..fc1c69c2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -20,9 +20,7 @@ public class LlamaFP16FFNLayers extends AbstractFFNLayers { - TaskGraph ffnTaskGraphs; - GridScheduler scheduler; - List ffnLayerTaskGraphs; + private List ffnLayerTaskGraphs; public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); @@ -64,21 +62,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) } @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnTaskGraphs; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List getFfnLayerTaskGraphs() { + public List getFFNLayerTaskGraphs() { return ffnLayerTaskGraphs; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index f809eadc..0d1f6b28 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -33,8 +33,6 @@ public class Phi3FP16FFNLayers extends AbstractFFNLayers { private final Phi3Configuration phi3Config; // Phi3-specific dimension for combined QKV buffer private final int opSize; - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; List ffnLayerTaskGraphs; public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) { @@ -88,22 +86,7 @@ public 
GridScheduler updateGridScheduler(GridScheduler gridScheduler) { return gridScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnLayerTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List getFfnLayerTaskGraphs() { + public List getFFNLayerTaskGraphs() { return ffnLayerTaskGraphs; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index 5de9af9c..8732918c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -34,8 +34,6 @@ public class Qwen2FP16FFNLayers extends AbstractFFNLayers { // Typed references to Qwen2-specific state and config private final Qwen2State qwen2State; private final Qwen2Configuration qwen2Config; - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; List ffnLayerTaskGraphs; public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeights weights, Qwen2Configuration config, SchedulerType schedulerType) { @@ -116,22 +114,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnLayerTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List getFfnLayerTaskGraphs() { + public List getFFNLayerTaskGraphs() { return ffnLayerTaskGraphs; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 
865417a1..ba991c9f 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -38,8 +38,6 @@ public class Qwen3FP16FFNLayers extends AbstractFFNLayers { private final int nEmbdHead; private final int nEmbdGqa; private final int gqa; - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; List ffnLayerTaskGraphs; public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) { @@ -102,22 +100,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { return gridScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnLayerTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List getFfnLayerTaskGraphs() { + public List getFFNLayerTaskGraphs() { return ffnLayerTaskGraphs; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java index f565cd57..1f88d600 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java @@ -1,11 +1,7 @@ package org.beehive.gpullama3.tornadovm.layers.type.q8_0; import org.beehive.gpullama3.inference.state.GraniteState; -import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.model.granite.Granite; import 
org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; @@ -23,7 +19,6 @@ public class GraniteQ8_0FFNLayers extends AbstractFFNLayers { - GridScheduler scheduler; List ffnLayerTaskGraphs; public GraniteQ8_0FFNLayers(String taskGraphName, GraniteState state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) { @@ -314,7 +309,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - public List getFfnLayerTaskGraphs() { + public List getFFNLayerTaskGraphs() { return ffnLayerTaskGraphs; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index 0bbf6b85..91b60b20 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -18,7 +18,6 @@ public class LlamaQ8_0FFNLayers extends AbstractFFNLayers { - GridScheduler scheduler; List ffnLayerTaskGraphs; public LlamaQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, Configuration config, SchedulerType schedulerType) { @@ -309,7 +308,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - public List getFfnLayerTaskGraphs() { + public List getFFNLayerTaskGraphs() { return ffnLayerTaskGraphs; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index f1bf0836..ee5744f8 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ 
b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -32,8 +32,6 @@ public class Phi3Q8_0FFNLayers extends AbstractFFNLayers { private final Phi3Configuration phi3Config; // Phi3-specific dimension for combined QKV buffer private final int opSize; - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; List ffnLayerTaskGraphs; public Phi3Q8_0FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) { @@ -76,22 +74,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnLayerTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List getFfnLayerTaskGraphs() { + public List getFFNLayerTaskGraphs() { return ffnLayerTaskGraphs; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java index 53c9b293..e279c2bb 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java @@ -34,8 +34,6 @@ */ public class Qwen2Q8_0FFNLayers extends AbstractFFNLayers { - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; List ffnLayerTaskGraphs; // Typed references to Qwen2-specific state and config @@ -115,22 +113,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnLayerTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() 
{ - return null; - } - - public List getFfnLayerTaskGraphs() { + public List getFFNLayerTaskGraphs() { return ffnLayerTaskGraphs; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index 3d893360..77c19af1 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -31,9 +31,6 @@ */ public class Qwen3Q8_0FFNLayers extends AbstractFFNLayers { - String lastTaskGraphID; - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; List ffnLayerTaskGraphs; // Typed references to Qwen3-specific state and config @@ -105,22 +102,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { return gridScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnLayerTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List getFfnLayerTaskGraphs() { + public List getFFNLayerTaskGraphs() { return ffnLayerTaskGraphs; } From 5baba0e7bdd9dfda6de2f6229e4abccbed035ec7 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 16:51:47 +0200 Subject: [PATCH 04/13] [refactor] Generalize AbstractFFNLayers and unify task graph setup logic across all subclasses --- .../base/QuantizationPlannerFactory.java | 8 +- .../model/fp16/MistralFP16LayerPlanner.java | 21 ++ .../model/q8_0/MistralQ8_0LayerPlanner.java | 21 ++ .../quantization/FP16LayerPlanner.java | 4 +- .../quantization/Q8_0LayerPlanner.java | 4 +- .../tornadovm/layers/AbstractFFNLayers.java | 73 +++--- .../type/fp16/GraniteFP16FFNLayers.java | 30 +-- .../layers/type/fp16/LlamaFP16FFNLayers.java | 34 +-- .../type/fp16/MistralFP16FFNLayers.java | 208 ++++++++++++++++++ 
.../layers/type/fp16/Phi3FP16FFNLayers.java | 82 +++---- .../layers/type/fp16/Qwen2FP16FFNLayers.java | 34 +-- .../layers/type/fp16/Qwen3FP16FFNLayers.java | 78 +++---- .../type/q8_0/GraniteQ8_0FFNLayers.java | 28 +-- .../layers/type/q8_0/LlamaQ8_0FFNLayers.java | 32 +-- .../type/q8_0/MistralQ8_0FFNLayers.java | 189 ++++++++++++++++ .../layers/type/q8_0/Phi3Q8_0FFNLayers.java | 83 +++---- .../layers/type/q8_0/Qwen2Q8_0FFNLayers.java | 42 +--- .../layers/type/q8_0/Qwen3Q8_0FFNLayers.java | 76 ++----- 18 files changed, 633 insertions(+), 414 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java index ca844e51..9efa6a08 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java @@ -11,11 +11,13 @@ import org.beehive.gpullama3.tornadovm.GenericLayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.MistralFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner; import 
org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.GraniteQ8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.MistralQ8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner; @@ -54,7 +56,8 @@ public static GenericLayerPlanner create(GGMLType quantization, State state, Mod // ============ FP16 Planners ============ private static GenericLayerPlanner createFP16Planner(State state, Model model) { return switch (model.getModelType()) { - case LLAMA_3, MISTRAL -> new LlamaFP16LayerPlanner((LlamaState) state, model); + case LLAMA_3 -> new LlamaFP16LayerPlanner((LlamaState) state, model); + case MISTRAL -> new MistralFP16LayerPlanner((LlamaState) state, model); case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model); case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model); @@ -67,7 +70,8 @@ private static GenericLayerPlanner createFP16Planner(State state, Model model) { // ============ Q8_0 Planners ============ private static GenericLayerPlanner createQ8_0Planner(State state, Model model) { return switch (model.getModelType()) { - case LLAMA_3, MISTRAL -> new LlamaQ8_0LayerPlanner((LlamaState) state, model); + case LLAMA_3 -> new LlamaQ8_0LayerPlanner((LlamaState) state, model); + case MISTRAL -> new MistralQ8_0LayerPlanner((LlamaState) state, model); case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model); case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model); case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model); diff --git 
a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java new file mode 100644 index 00000000..3c82ebbb --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java @@ -0,0 +1,21 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.mistral.MistralConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.MistralFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; + +public class MistralFP16LayerPlanner extends FP16LayerPlanner { + + public MistralFP16LayerPlanner(LlamaState state, Model model) { + super(state, model); + this.activationLayer = new Activation("activationUpdate", state, weights, config); + this.ffnLayers = new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType); + this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); + buildForwardPlan(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java new file mode 100644 index 00000000..4e9dd449 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java @@ -0,0 +1,21 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; + +import 
org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.mistral.MistralConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.MistralQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; + +public class MistralQ8_0LayerPlanner extends Q8_0LayerPlanner { + + public MistralQ8_0LayerPlanner(LlamaState state, Model model) { + super(state, model); + this.activationLayer = new Activation("activationUpdate", state, weights, config); + this.ffnLayers = new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType); + this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); + buildForwardPlan(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/FP16LayerPlanner.java index 9be5e08b..a8d5e12d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/FP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/FP16LayerPlanner.java @@ -25,7 +25,7 @@ public abstract class FP16LayerPlanner extends QuantizedLayerPlanner { protected Activation activationLayer; - protected AbstractFFNLayers ffnLayers; + protected AbstractFFNLayers ffnLayers; protected LogitsFP16Layer logitsLayer; protected List immutableTaskGraphs; @@ -56,7 +56,7 @@ protected final void setupTornadoForwardPlan() { activationLayer.updateGridScheduler(masterScheduler); // 2. 
FFN layers (N transformer layers - model-specific) - allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs()); + allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs()); ffnLayers.updateGridScheduler(masterScheduler); // 3. Logits layer (common to all models) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java index f10f9686..06aefee5 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java @@ -26,7 +26,7 @@ public abstract class Q8_0LayerPlanner extends QuantizedLayerPlanner { protected Activation activationLayer; - protected AbstractFFNLayers ffnLayers; + protected AbstractFFNLayers ffnLayers; protected LogitsQ8_0Layer logitsLayer; // Cache for task graphs and scheduler (set once, reused) @@ -59,7 +59,7 @@ protected final void setupTornadoForwardPlan() { activationLayer.updateGridScheduler(masterScheduler); // 2. FFN layers (N transformer layers - model-specific) - allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs()); + allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs()); ffnLayers.updateGridScheduler(masterScheduler); // 3. 
Logits layer (common to all models) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java index b62b1532..1a73897e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java @@ -5,55 +5,70 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; import java.util.List; +import java.util.stream.IntStream; /** * Abstract base class for all FFN (Feed-Forward Network) layer implementations. - * - * Each subclass builds N ImmutableTaskGraphs (one per FFN layer) via - * {@link #setupFFNLayerTaskGraphs}, covering RMSNorm, Attention, and FFN computations. - * - * Model-specific subclasses: Llama, Qwen2, Qwen3, Phi3, Granite — each in FP16 and Q8_0 variants. + * Extended by model and quantization-specific subclasses that provide specific implementations. */ -public abstract class AbstractFFNLayers extends AbstractLayer { +public abstract class AbstractFFNLayers extends AbstractLayer { + + /** + * List of TornadoVM {@link ImmutableTaskGraph}s, one per FFN layer. + * Built by {@link #setupFFNLayers()}. + */ + private List ffnLayerITGs; + protected final W weights; + protected final C config; protected String lastFFNLayerTaskGraphID; protected final SchedulerType schedulerType; + protected AbstractFFNLayers(String taskGraphName, State state, W weights, C config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config); + this.weights = weights; + this.config = config; + this.schedulerType = schedulerType; + // the ffnLayerITGs is initialized in subclasses + // due to some model-specific values (e.g. in Qwen3) + } /** - * Constructor for FFN layers.
- * - * @param taskGraphName - * Name for the task graph - * @param state - * Runtime state (LlamaState, Qwen2State, etc.) - * @param weights - * Model weights (FP16Weights, Q8_0Weights, etc.) - * @param config - * Model configuration + * Creates the {@link ImmutableTaskGraph} list for each FFN layer. */ - protected AbstractFFNLayers(String taskGraphName, State state, Weights weights, Configuration config, SchedulerType schedulerType) { - super(taskGraphName, state, weights, config); - this.schedulerType = schedulerType; + protected void setupFFNLayers() { + int numLayers = config.numberOfLayers(); + + this.ffnLayerITGs = IntStream.range(0, numLayers) + .mapToObj(this::setupFFNLayer) + .toList(); } /** - * Creates the TornadoVM {@link uk.ac.manchester.tornado.api.TaskGraph} for the FFN layers. - * It creates one TaskGraph per layer and snapshots it to produce an {@link ImmutableTaskGraph} per layer. - * It also stores the TaskGraph ID of the last FFN layer for use by the {@link AbstractLogitsLayer}. + * Creates the TaskGraph for a specific FFN layer and produces the {@link ImmutableTaskGraph}. + * In addition, it stores the TaskGraph ID of the last FFN layer for use by the {@link AbstractLogitsLayer}. */ - protected abstract List setupFFNLayerTaskGraphs(); + private ImmutableTaskGraph setupFFNLayer(int layerIndex) { + TaskGraph tg = createFFNLayerTaskGraph(layerIndex); + + if (layerIndex == config.numberOfLayers() - 1) { + lastFFNLayerTaskGraphID = tg.getTaskGraphName(); + } + + return tg.snapshot(); + } /** - * Returns all task graphs for the FFN layers. - * - * For a model with N transformer layers, this returns N ImmutableTaskGraphs, one for each layer (containing RMSNorm, Attention, FFN computations). - * - * @return List of immutable task graphs (one per transformer layer) + * Model and quantization-specific implementation of the FFN layer task graph. 
*/ - public abstract List getFFNLayerTaskGraphs(); + protected abstract TaskGraph createFFNLayerTaskGraph(int layerIndex); + + public List getFFNLayerImmutableTaskGraphs() { + return ffnLayerITGs; + } /** * Returns the TaskGraph ID of the last FFN layer. diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java index 4a471e2f..80f4a6ea 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java @@ -11,21 +11,15 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.List; -import java.util.stream.IntStream; +public class GraniteFP16FFNLayers extends AbstractFFNLayers { -public class GraniteFP16FFNLayers extends AbstractFFNLayers { - - List ffnLayerTaskGraphs; - - public GraniteFP16FFNLayers(String taskGraph, State state, Weights weights, GraniteConfiguration config, SchedulerType schedulerType) { + public GraniteFP16FFNLayers(String taskGraph, State state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); - this.ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); + setupFFNLayers(); } @Override @@ -62,21 +56,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - public List getFFNLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - @Override - protected List 
setupFFNLayerTaskGraphs() { - return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { - var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i); - if (i == config.numberOfLayers() - 1) { - this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); - } - return ffnLayer.snapshot(); - }).toList(); - } - // @formatter:off /** * Transformer Layer Task Flow (LlamaFP16FFNLayers) @@ -163,7 +142,8 @@ protected List setupFFNLayerTaskGraphs() { * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel) * */ - TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguration config, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index fc1c69c2..56f2c0c3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -1,30 +1,23 @@ package org.beehive.gpullama3.tornadovm.layers.type.fp16; import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import 
org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.List; -import java.util.stream.IntStream; +public class LlamaFP16FFNLayers extends AbstractFFNLayers { -public class LlamaFP16FFNLayers extends AbstractFFNLayers { - - private List ffnLayerTaskGraphs; - - public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) { + public LlamaFP16FFNLayers(String taskGraph, State state, LlamaTornadoWeights weights, LlamaConfiguration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); - this.ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); + setupFFNLayers(); } @Override @@ -61,22 +54,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - @Override - public List getFFNLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - @Override - protected List setupFFNLayerTaskGraphs() { - return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { - var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); - if (i == config.numberOfLayers() - 1) { - this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); - } - return ffnLayer.snapshot(); - }).toList(); - } - // @formatter:off /** * Transformer Layer Task Flow (LlamaFP16FFNLayers) @@ -163,7 +140,8 @@ protected List setupFFNLayerTaskGraphs() { * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel) * */ - TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; 
TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java new file mode 100644 index 00000000..499fc176 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java @@ -0,0 +1,208 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.mistral.MistralConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +public class MistralFP16FFNLayers extends AbstractFFNLayers { + + public MistralFP16FFNLayers(String taskGraph, State state, LlamaTornadoWeights weights, MistralConfiguration config, SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + setupFFNLayers(); + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = 
config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + + int fusedQKVRows = config.dim() + 2 * config.kvDim(); + int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512); + + // Map workers to tasks + for (int i = 0; i < config.numberOfLayers(); i++) { + // === Attention Block === + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + // === FFN Block === + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); + } + return tornadoForwardScheduler; + } + + // @formatter:off + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + // === Data Setup === + unifiedLayer.consumeFromDevice(state.wrapX); + 
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + // === Attention Block === + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply_fp16", + TransformerComputeKernels::mapContextWithQuantize, + context, state.wrapXbFP16, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulX, + context, + state.wrapXbFP16, + state.wrapQ, + state.wrapK, + state.wrapV, + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + config.dim(), + config.kvDim(), + LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("rope_and_kv_cache", + TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, + state.wrapK, + state.wrapV, + state.wrapKeyCache, + state.wrapValueCache, + config.kvDim(), + config.headSize(), + layerIndex, + config.contextLength()); + + 
configureAttention(unifiedLayer, layerIndex); + + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asHalfFloatArray(), + config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + // === FFN Block === + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, + context, + state.wrapX, + state.wrapHb, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + state.tempFFN, + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), + config.dim(), + config.hiddenDim(), + LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asHalfFloatArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + if (layerIndex == 0) { + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, + state.temp, state.tempFFN); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, state.wrapXbFP16); + } else { + unifiedLayer.consumeFromDevice( 
+ context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder, state.wrapXbFP16); + } + return unifiedLayer; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { + if (schedulerType == SchedulerType.NVIDIA) { + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), + state.positionHolder, layerIndex, config.contextLength()); + } else { + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsParallel, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), config.contextLength(), + state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + } + } + // @formatter:on +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 0d1f6b28..065c6802 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -9,14 +9,10 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.ArrayList; -import java.util.List; - /** * Phi3FP16FFNLayers: FP16 FFN layers 
for Phi3 with Group Query Attention (GQA) support. * @@ -26,21 +22,18 @@ * * Works directly with Phi3State to access and mutate Phi3-specific state fields. */ -public class Phi3FP16FFNLayers extends AbstractFFNLayers { +public class Phi3FP16FFNLayers extends AbstractFFNLayers { // Typed references to Phi3-specific state and config private final Phi3State phi3State; - private final Phi3Configuration phi3Config; // Phi3-specific dimension for combined QKV buffer private final int opSize; - List ffnLayerTaskGraphs; public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); this.phi3State = state; - this.phi3Config = config; this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); - ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); + setupFFNLayers(); } @Override @@ -86,26 +79,6 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { return gridScheduler; } - public List getFFNLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - /** - * Setup all FFN layers for all transformer layers - */ - @Override - protected List setupFFNLayerTaskGraphs() { - List ffnGraphs = new ArrayList<>(); - for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) { - TaskGraph ffnLayer = setupSinglePhi3FFNLayer((Phi3TornadoWeights) weights, layerIndex); - if (layerIndex == phi3Config.numberOfLayers() - 1) { - this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); - } - ffnGraphs.add(ffnLayer.snapshot()); - } - return ffnGraphs; - } - // @formatter:off /** * Transformer Layer Task Flow (Phi3FP16FFNLayers - Fully Optimized) @@ -191,9 +164,12 @@ protected List setupFFNLayerTaskGraphs() { * • Inline SiLU+GLU: No intermediate wrapHb buffer needed * */ - TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { + @Override + protected TaskGraph 
createFFNLayerTaskGraph(int layerIndex) { + Phi3TornadoWeights weights = (Phi3TornadoWeights) this.weights; var taskGraphName = "layer_" + layerIndex; var unifiedLayer = new TaskGraph(taskGraphName); + unifiedLayer.consumeFromDevice(phi3State.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Attention weights @@ -214,8 +190,8 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { unifiedLayer.task("attn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, phi3State.temp, // output: scale factor phi3State.wrapX, // input: hidden state - phi3Config.dim(), // dimension - phi3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size if (shouldUseFinalNormalization()) { @@ -230,8 +206,8 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { phi3State.wrapK, // output K phi3State.wrapV, // output V weights.rms_att_weightLayered[layerIndex].asFloatArray(), phi3State.temp, // RMS scale - weights.wqkvLayered[layerIndex].asHalfFloatArray(), phi3Config.dim(), // dim - phi3Config.kvDim(), // kvDim + weights.wqkvLayered[layerIndex].asHalfFloatArray(), config.dim(), // dim + config.kvDim(), // kvDim LOCAL_WORK_GROUP_SIZE_ALLOC); // Fused Phi3 RoPE Rotation + KV Cache Write @@ -242,11 +218,11 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { phi3State.wrapV, // V vectors (in only) phi3State.wrapKeyCache, // key cache (out) phi3State.wrapValueCache, // value cache (out) - phi3Config.numberOfKeyValueHeads(), // nHeadKv - phi3Config.headSize(), // head dimension - phi3Config.kvDim(), // kvDim + config.numberOfKeyValueHeads(), // nHeadKv + config.headSize(), // head dimension + config.kvDim(), // kvDim layerIndex, // layer index for cache offset - phi3Config.contextLength()); // max sequence length + config.contextLength()); // max sequence length // Flash 
Attention unifiedLayer.task("attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, @@ -254,21 +230,21 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { phi3State.wrapKeyCache, // key cache phi3State.wrapValueCache, // value cache phi3State.wrapXb, // output: attention result - phi3Config.numberOfHeads(), // nHeads - phi3Config.headSize(), // headSize - phi3Config.kvDim(), // kvDim - phi3Config.kvMul(), // kvMul (nHeads / nHeadKv) + config.numberOfHeads(), // nHeads + config.headSize(), // headSize + config.kvDim(), // kvDim + config.kvMul(), // kvMul (nHeads / nHeadKv) phi3State.positionHolder, // position layerIndex, // layer index - phi3Config.contextLength()); // context length + config.contextLength()); // context length // Output Projection with Residual unifiedLayer.task("attn_output_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, phi3State.wrapXb, // input: attention output phi3State.wrapX, // output: wrapX += Wo · wrapXb weights.woLayered[layerIndex].asHalfFloatArray(), // Wo [dim × dim] - phi3Config.dim(), // input dim - phi3Config.dim(), // output dim + config.dim(), // input dim + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); // ═══════════════════════════════════════════════════════════════════════ @@ -279,24 +255,24 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { unifiedLayer.task("ffn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, phi3State.tempFFN, // output: scale factor phi3State.wrapX, // input: hidden state - phi3Config.dim(), // dimension - phi3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size // Final normalization (non-NVIDIA only) if (shouldUseFinalNormalization()) { unifiedLayer.task("ffn_rms_finalize", TransformerComputeKernelsLayered::reductionFinalNormalization, context, 
phi3State.tempFFN, // scale factor (in/out) - phi3Config.dim(), // dimension - phi3Config.rmsNormEps()); // epsilon + config.dim(), // dimension + config.rmsNormEps()); // epsilon } unifiedLayer.task("rms_ffn_silu", Phi3Kernels::fusedRmsNormFFNGateUpSiLU, context, phi3State.wrapX, // input phi3State.wrapHbU, // output (direct to final FFN buffer) weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), phi3State.tempFFN, // RMS scale - weights.wUpLayered[layerIndex].asHalfFloatArray(), phi3Config.dim(), // input dim - phi3Config.hiddenDim(), // output dim (hiddenDim, not 2×hiddenDim!) + weights.wUpLayered[layerIndex].asHalfFloatArray(), config.dim(), // input dim + config.hiddenDim(), // output dim (hiddenDim, not 2×hiddenDim!) LOCAL_WORK_GROUP_SIZE_ALLOC); // Down Projection with Residual @@ -304,8 +280,8 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { context, phi3State.wrapHbU, // input: FFN intermediate phi3State.wrapX, // output: wrapX += wDown · wrapHbU weights.wDownLayered[layerIndex].asHalfFloatArray(), // wDown [dim × hiddenDim] - phi3Config.hiddenDim(), // input dim - phi3Config.dim(), // output dim + config.hiddenDim(), // input dim + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); unifiedLayer.persistOnDevice(phi3State.wrapX); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index 8732918c..d6ac0eee 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -10,16 +10,12 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import 
uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.WorkerGrid2D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.ArrayList; -import java.util.List; - /** * Qwen2FP16FFNLayers: FP16 FFN layers for Qwen2 with Group Query Attention (GQA) support. * @@ -29,18 +25,15 @@ * * Works directly with Qwen2State to access and mutate Qwen2-specific state fields. */ -public class Qwen2FP16FFNLayers extends AbstractFFNLayers { +public class Qwen2FP16FFNLayers extends AbstractFFNLayers { - // Typed references to Qwen2-specific state and config + // Typed reference to Qwen2-specific state private final Qwen2State qwen2State; - private final Qwen2Configuration qwen2Config; - List ffnLayerTaskGraphs; public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeights weights, Qwen2Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); this.qwen2State = state; - this.qwen2Config = config; - ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); + setupFFNLayers(); } @Override @@ -114,23 +107,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - public List getFFNLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - @Override - protected List setupFFNLayerTaskGraphs() { - List ffnGraphs = new ArrayList<>(qwen2Config.numberOfLayers()); - for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) { - TaskGraph ffnLayer = setupSingleQwen2FFNLayer((Qwen2TornadoWeights) weights, layerIndex); - if (layerIndex == qwen2Config.numberOfLayers() - 1) { - this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); - } - ffnGraphs.add(ffnLayer.snapshot()); - } - return ffnGraphs; - } - // @formatter:off /** * Transformer Layer Task Flow (Qwen2FP16FFNLayers - Optimized) @@ -219,9 +195,11 @@ 
protected List setupFFNLayerTaskGraphs() { * • No Q/K RMSNorm: Unlike Qwen3, Qwen2 doesn't normalize Q/K after projection * */ - TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var taskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(taskGraphName); + unifiedLayer.consumeFromDevice(state.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Attention weights diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index ba991c9f..60b3cc3e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -9,14 +9,10 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.List; -import java.util.stream.IntStream; - /** * Qwen3FP16FFNLayers: FP16 FFN layers for Qwen3 with Group Query Attention (GQA) support. * @@ -25,11 +21,10 @@ * * Works directly with Qwen3State to access and mutate Qwen3-specific state fields like tempQcur and tempKcur. 
*/ -public class Qwen3FP16FFNLayers extends AbstractFFNLayers { +public class Qwen3FP16FFNLayers extends AbstractFFNLayers { - // Typed references to Qwen3-specific state and config + // Typed reference to Qwen3-specific state private final Qwen3State qwen3State; - private final Qwen3Configuration qwen3Config; // Qwen3-specific GQA parameters private final int nHeadKv; private final int nEmbdHeadK; @@ -38,12 +33,10 @@ public class Qwen3FP16FFNLayers extends AbstractFFNLayers { private final int nEmbdHead; private final int nEmbdGqa; private final int gqa; - List ffnLayerTaskGraphs; public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); this.qwen3State = state; - this.qwen3Config = config; // Initialize GQA parameters from Qwen3Config this.nHeadKv = config.numberOfKeyValueHeads(); @@ -53,7 +46,7 @@ public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe this.nEmbdHead = nEmbdHeadV; this.nEmbdGqa = nEmbdVGqa; this.gqa = config.numberOfHeads() / config.numberOfKeyValueHeads(); - ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); + setupFFNLayers(); } @Override @@ -74,7 +67,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { int qkRmsNormGroups = config.numberOfHeads() + config.numberOfKeyValueHeads(); WorkerGrid qkRmsNormWorker = WorkerGridFactory.genericWorker(qkRmsNormGroups * nEmbdHead, nEmbdHead); - int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads(); + int qDim0 = nEmbdHeadK * config.numberOfHeads(); int kvDim0 = nEmbdGqa; int fusedQKVRows = qDim0 + 2 * kvDim0; // Q rows + K rows + V rows int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; @@ -100,24 +93,6 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { return gridScheduler; } - public List getFFNLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - /** - * Setup all 
FFN layers for all transformer layers - */ - @Override - protected List setupFFNLayerTaskGraphs() { - return IntStream.range(0, qwen3Config.numberOfLayers()).mapToObj(i -> { - var ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, i); - if (i == qwen3Config.numberOfLayers() - 1) { - this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); - } - return ffnLayer.snapshot(); - }).toList(); - } - // @formatter:off /** * Transformer Layer Task Flow (Qwen3FP16FFNLayers) @@ -205,13 +180,14 @@ protected List setupFFNLayerTaskGraphs() { * • RoPE theta: 1,000,000 (vs Llama's 10,000 or 50,000) * */ - TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var taskGraphName = "layer_" + layerIndex; // === Dimension Parameters === - int qDim = nEmbdHeadK * qwen3Config.numberOfHeads(); // Q output size (full heads) + int qDim = nEmbdHeadK * config.numberOfHeads(); // Q output size (full heads) int kvDim = nEmbdGqa; // K/V output size (reduced for GQA) - int inputDim = qwen3Config.dim(); // Model dimension + int inputDim = config.dim(); // Model dimension var unifiedLayer = new TaskGraph(taskGraphName); @@ -244,8 +220,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) context, qwen3State.temp, // output: scale factor qwen3State.wrapX, // input: hidden state - qwen3Config.dim(), // dimension - qwen3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon qwen3State.localSize); // local memory size if (shouldUseFinalNormalization()) { @@ -283,11 +259,11 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapK, // K vectors (in/out) weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // Q norm weights weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // K norm weights - qwen3Config.numberOfHeads(), // nHeads (Q heads) - 
qwen3Config.numberOfKeyValueHeads(), // nHeadKv (K/V heads, GQA) + config.numberOfHeads(), // nHeads (Q heads) + config.numberOfKeyValueHeads(), // nHeadKv (K/V heads, GQA) nEmbdHead, // head dimension nEmbdHead, // local memory size - qwen3Config.rmsNormEps()); // epsilon + config.rmsNormEps()); // epsilon // Fused RoPE Rotation + KV Cache Write unifiedLayer.task("rope_and_kv_cache", @@ -299,11 +275,11 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapV, // V vectors (in only) qwen3State.wrapKeyCache, // key cache (out) qwen3State.wrapValueCache, // value cache (out) - qwen3Config.numberOfKeyValueHeads(), // nHeadKv + config.numberOfKeyValueHeads(), // nHeadKv nEmbdHead, // head dimension nEmbdGqa, // kvDim layerIndex, // layer index for cache offset - qwen3Config.contextLength()); // max sequence length + config.contextLength()); // max sequence length // Flash Attention unifiedLayer.task("attention", @@ -313,13 +289,13 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapKeyCache, // key cache qwen3State.wrapValueCache, // value cache qwen3State.wrapXb, // output: attention result - qwen3Config.numberOfHeads(), // nHeads + config.numberOfHeads(), // nHeads nEmbdHead, // headSize nEmbdGqa, // kvDim gqa, // kvMul (nHeads / nHeadKv) qwen3State.positionHolder, // position layerIndex, // layer index - qwen3Config.contextLength()); // context length + config.contextLength()); // context length // Output Projection with Residual unifiedLayer.task("attn_output_proj", @@ -328,8 +304,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapXb, // input: attention output qwen3State.wrapX, // output: wrapX += Wo · wrapXb weights.woLayered[layerIndex].asHalfFloatArray(), // Wo [dim x qDim] - nEmbdHeadK * qwen3Config.numberOfHeads(), // input dim (qDim) - qwen3Config.dim(), // output dim + nEmbdHeadK * config.numberOfHeads(), // input dim 
(qDim) + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); // ═══════════════════════════════════════════════════════════════════════ @@ -342,8 +318,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) context, qwen3State.tempFFN, // output: scale factor qwen3State.wrapX, // input: hidden state - qwen3Config.dim(), // dimension - qwen3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon qwen3State.localSize); // local memory size // Final normalization (non-NVIDIA only) @@ -352,8 +328,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) TransformerComputeKernelsLayered::reductionFinalNormalization, context, qwen3State.tempFFN, // scale factor (in/out) - qwen3Config.dim(), // dimension - qwen3Config.rmsNormEps()); // epsilon + config.dim(), // dimension + config.rmsNormEps()); // epsilon } // Fused RMS Apply + Gate/Up Projection + SiLU + GLU @@ -366,8 +342,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.tempFFN, // RMS scale factor weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 (gate) weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 (up) - qwen3Config.dim(), // input dimension - qwen3Config.hiddenDim(), // hidden dimension + config.dim(), // input dimension + config.hiddenDim(), // hidden dimension LOCAL_WORK_GROUP_SIZE_ALLOC); // Down Projection with Residual @@ -377,8 +353,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapHb, // input: FFN intermediate qwen3State.wrapX, // output: wrapX += W2 · wrapHb weights.w2Layered[layerIndex].asHalfFloatArray(), // W2 (down) - qwen3Config.hiddenDim(), // input dim - qwen3Config.dim(), // output dim + config.hiddenDim(), // input dim + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC) .persistOnDevice(qwen3State.wrapX); diff --git 
a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java index 1f88d600..8a91c75a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java @@ -9,32 +9,15 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.List; -import java.util.stream.IntStream; - -public class GraniteQ8_0FFNLayers extends AbstractFFNLayers { - - List ffnLayerTaskGraphs; +public class GraniteQ8_0FFNLayers extends AbstractFFNLayers { public GraniteQ8_0FFNLayers(String taskGraphName, GraniteState state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); - ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); - } - - @Override - protected List setupFFNLayerTaskGraphs() { - return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { - var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i); - if (i == config.numberOfLayers() - 1) { - this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); - } - return ffnLayer.snapshot(); - }).toList(); + setupFFNLayers(); } /** @@ -124,7 +107,8 @@ protected List setupFFNLayerTaskGraphs() { * Quantization: Q8_0 format (8-bit weights with block-wise scaling) * */ - TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguration config, int layerIndex) { + @Override + 
protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); @@ -309,10 +293,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - public List getFFNLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex, GraniteConfiguration config) { if (schedulerType == SchedulerType.NVIDIA) { // Flash Attention (optimized for NVIDIA GPUs) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index 91b60b20..1f33f090 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -2,38 +2,21 @@ import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.List; -import java.util.stream.IntStream; +public class LlamaQ8_0FFNLayers extends AbstractFFNLayers { -public class LlamaQ8_0FFNLayers extends AbstractFFNLayers { - - List 
ffnLayerTaskGraphs; - - public LlamaQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, Configuration config, SchedulerType schedulerType) { + public LlamaQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, LlamaConfiguration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); - ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); - } - - @Override - protected List setupFFNLayerTaskGraphs() { - return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { - var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); - if (i == config.numberOfLayers() - 1) { - this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); - } - return ffnLayer.snapshot(); - }).toList(); + setupFFNLayers(); } // @formatter:off @@ -124,7 +107,8 @@ protected List setupFFNLayerTaskGraphs() { * Quantization: Q8_0 format (8-bit weights with block-wise scaling) * */ - TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); @@ -308,10 +292,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - public List getFFNLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { if (schedulerType == SchedulerType.NVIDIA) { return unifiedLayer.task("attention", diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java new file mode 100644 index 00000000..64864114 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java @@ -0,0 +1,189 
@@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.mistral.MistralConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +public class MistralQ8_0FFNLayers extends AbstractFFNLayers { + + public MistralQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, MistralConfiguration config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, schedulerType); + setupFFNLayers(); + } + + // @formatter:off + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + // === Data Setup === + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + weights.woLayered[layerIndex].asByteArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asByteArray(), + weights.w2Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + // === Attention Block === + 
unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, state.wrapXb, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulQ8, + context, + state.wrapXb, + state.wrapQ, state.wrapK, state.wrapV, + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("rope_and_kv_cache", + TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + config.kvDim(), config.headSize(), layerIndex, config.contextLength()); + + configureAttention(unifiedLayer, layerIndex); + + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asByteArray(), + config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + // === FFN Block === + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, 
config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fullyFusedRmsNormFFNGateUpQ8, + context, + state.wrapX, state.wrapHb, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray(), + config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asByteArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + if (layerIndex == 0) { + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, state.temp, state.tempFFN); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb); + } else { + unifiedLayer.consumeFromDevice( + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder); + } + return unifiedLayer; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = 
WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int fusedQkvGlobal = (config.dim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512); + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + + for (int i = 0; i < config.numberOfLayers(); i++) { + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); + } + + return tornadoForwardScheduler; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { + if (schedulerType == SchedulerType.NVIDIA) { + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), + state.positionHolder, layerIndex, config.contextLength()); + } else { + return unifiedLayer.task("attention", + 
TransformerComputeKernelsLayered::processHeadsParallel, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), config.contextLength(), + state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + } + } + // @formatter:on +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index ee5744f8..8f693adc 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -9,14 +9,10 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.ArrayList; -import java.util.List; - /** * Phi3Q8_0FFNLayers: Q8_0-quantized FFN layers for Phi3 with Group Query Attention (GQA) support. * @@ -25,21 +21,18 @@ * * Works directly with Phi3State to access and mutate Phi3-specific state fields. 
*/ -public class Phi3Q8_0FFNLayers extends AbstractFFNLayers { +public class Phi3Q8_0FFNLayers extends AbstractFFNLayers { - // Typed references to Phi3-specific state and config + // Typed reference to Phi3-specific state private final Phi3State phi3State; - private final Phi3Configuration phi3Config; // Phi3-specific dimension for combined QKV buffer private final int opSize; - List ffnLayerTaskGraphs; public Phi3Q8_0FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); this.phi3State = state; - this.phi3Config = config; this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); - ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); + setupFFNLayers(); } @Override @@ -74,26 +67,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - public List getFFNLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - /** - * Setup all FFN layers for all transformer layers - */ - @Override - protected List setupFFNLayerTaskGraphs() { - List ffnGraphs = new ArrayList<>(); - for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) { - TaskGraph ffnLayer = setupSinglePhi3Q8_0FFNLayer((Phi3TornadoWeights) weights, layerIndex); - if (layerIndex == phi3Config.numberOfLayers() - 1) { - this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); - } - ffnGraphs.add(ffnLayer.snapshot()); - } - return ffnGraphs; - } - // @formatter:off /** * Transformer Layer Task Flow (Phi3Q8_0FFNLayers - Fully Optimized) @@ -191,9 +164,11 @@ protected List setupFFNLayerTaskGraphs() { * • wrapHbG: Gate output merged into final computation * */ - TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var taskGraphName = "layer_" + layerIndex; var 
unifiedLayer = new TaskGraph(taskGraphName); + unifiedLayer.consumeFromDevice(phi3State.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Copy-in quantized weights per layer (Q8_0 format: ByteArray) @@ -216,8 +191,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex context, phi3State.temp, // output: scale factor phi3State.wrapX, // input: hidden state - phi3Config.dim(), // dimension - phi3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size if (shouldUseFinalNormalization()) { @@ -240,8 +215,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights phi3State.temp, // RMS scale (precomputed) weights.wqkvLayered[layerIndex].asByteArray(), // Q8 combined QKV [opSize × dim] - phi3Config.dim(), // input dim - phi3Config.kvDim(), // K/V output dim + config.dim(), // input dim + config.kvDim(), // K/V output dim LOCAL_WORK_GROUP_SIZE_ALLOC); // Fused Phi3 RoPE Rotation + KV Cache Write @@ -253,11 +228,11 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex phi3State.wrapV, // V vectors (in only) phi3State.wrapKeyCache, // key cache (out) phi3State.wrapValueCache, // value cache (out) - phi3Config.numberOfKeyValueHeads(), // nHeadKv - phi3Config.headSize(), // head dimension - phi3Config.kvDim(), // kvDim + config.numberOfKeyValueHeads(), // nHeadKv + config.headSize(), // head dimension + config.kvDim(), // kvDim layerIndex, // layer index for cache offset - phi3Config.contextLength()); // max sequence length + config.contextLength()); // max sequence length // Flash Attention unifiedLayer.task("attention", @@ -267,13 +242,13 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex phi3State.wrapKeyCache, // key cache phi3State.wrapValueCache, // value cache 
phi3State.wrapXb, // output: attention result - phi3Config.numberOfHeads(), // nHeads - phi3Config.headSize(), // headSize - phi3Config.kvDim(), // kvDim - phi3Config.kvMul(), // kvMul (nHeads / nHeadKv) + config.numberOfHeads(), // nHeads + config.headSize(), // headSize + config.kvDim(), // kvDim + config.kvMul(), // kvMul (nHeads / nHeadKv) phi3State.positionHolder, // position layerIndex, // layer index - phi3Config.contextLength()); // context length + config.contextLength()); // context length // Output Projection with Residual (Q8 dequantization) unifiedLayer.task("attn_output_proj", @@ -282,8 +257,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex phi3State.wrapXb, // input: attention output phi3State.wrapX, // output: wrapX += Wo · wrapXb weights.woLayered[layerIndex].asByteArray(), // Q8 Wo [dim × dim] - phi3Config.dim(), // input dim - phi3Config.dim(), // output dim + config.dim(), // input dim + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); // ═══════════════════════════════════════════════════════════════════════ @@ -296,8 +271,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex context, phi3State.tempFFN, // output: scale factor phi3State.wrapX, // input: hidden state - phi3Config.dim(), // dimension - phi3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size // Final normalization (non-NVIDIA only) @@ -306,8 +281,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex TransformerComputeKernelsLayered::reductionFinalNormalization, context, phi3State.tempFFN, // scale factor (in/out) - phi3Config.dim(), // dimension - phi3Config.rmsNormEps()); // epsilon + config.dim(), // dimension + config.rmsNormEps()); // epsilon } // Fused: RMS apply + Q8 gate/up matmul + SiLU activation + GLU @@ -319,8 +294,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights 
weights, int layerIndex weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights phi3State.tempFFN, // RMS scale (precomputed) weights.wUpLayered[layerIndex].asByteArray(), // Q8 combined gate+up [2×hiddenDim × dim] - phi3Config.dim(), // input dim - phi3Config.hiddenDim(), // output dim + config.dim(), // input dim + config.hiddenDim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); // Down Projection with Residual (Q8 dequantization) @@ -330,8 +305,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex phi3State.wrapHbU, // input: FFN intermediate phi3State.wrapX, // output: wrapX += wDown · wrapHbU weights.wDownLayered[layerIndex].asByteArray(), // Q8 wDown [dim × hiddenDim] - phi3Config.hiddenDim(), // input dim - phi3Config.dim(), // output dim + config.hiddenDim(), // input dim + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); unifiedLayer.persistOnDevice(phi3State.wrapX); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java index e279c2bb..6cdf32db 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java @@ -10,16 +10,12 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.WorkerGrid2D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.ArrayList; -import java.util.List; - /** * Qwen2Q8_0FFNLayers: Q8_0-quantized FFN layers for Qwen2 with 
Group Query Attention (GQA) support. * @@ -32,19 +28,14 @@ * * Works directly with Qwen2State to access and mutate Qwen2-specific state fields. */ -public class Qwen2Q8_0FFNLayers extends AbstractFFNLayers { - - List ffnLayerTaskGraphs; - - // Typed references to Qwen2-specific state and config +public class Qwen2Q8_0FFNLayers extends AbstractFFNLayers { + // Typed reference to Qwen2-specific state private final Qwen2State qwen2State; - private final Qwen2Configuration qwen2Config; public Qwen2Q8_0FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeights weights, Qwen2Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); this.qwen2State = state; - this.qwen2Config = config; - ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); + setupFFNLayers(); } @Override @@ -113,34 +104,13 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - public List getFFNLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - /** - * Setup all FFN layers for all transformer layers + * Setup a single transformer layer for Qwen2 with Q8_0 quantization and GQA */ @Override - protected List setupFFNLayerTaskGraphs() { - List ffnGraphs = new ArrayList<>(); - qwen2State.temp.init(0.0f); - qwen2State.tempFFN.init(0.0f); - - for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) { - TaskGraph ffnLayer = setupSingleQwen2Q8_0FFNLayer((Qwen2TornadoWeights) weights, layerIndex); - if (layerIndex == qwen2Config.numberOfLayers() - 1) { - this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); - } - ffnGraphs.add(ffnLayer.snapshot()); - } - return ffnGraphs; - } + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { + TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex); - /** - * Setup a single transformer layer for Qwen2 with Q8_0 quantization and GQA - */ - TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights 
weights, int layerIndex) { - TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex); unifiedLayer.consumeFromDevice(state.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Attention weights diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index 77c19af1..43b88fb1 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -9,14 +9,10 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.ArrayList; -import java.util.List; - /** * Qwen3Q8_0FFNLayers: Q8_0-quantized FFN layers for Qwen3 with Group Query Attention (GQA) support. * @@ -29,13 +25,10 @@ * Works directly with Qwen3State to access and mutate Qwen3-specific state fields * like tempQcur and tempKcur. 
*/ -public class Qwen3Q8_0FFNLayers extends AbstractFFNLayers { - - List ffnLayerTaskGraphs; +public class Qwen3Q8_0FFNLayers extends AbstractFFNLayers { - // Typed references to Qwen3-specific state and config + // Typed reference to Qwen3-specific state private final Qwen3State qwen3State; - private final Qwen3Configuration qwen3Config; // Qwen3-specific GQA parameters private final int nHeadKv; @@ -49,7 +42,6 @@ public class Qwen3Q8_0FFNLayers extends AbstractFFNLayers { public Qwen3Q8_0FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); this.qwen3State = state; - this.qwen3Config = config; this.nHeadKv = config.numberOfKeyValueHeads(); this.nEmbdHeadK = config.numberOfHeadsKey(); this.nEmbdHeadV = config.numberOfHeadsValue(); @@ -57,7 +49,7 @@ public Qwen3Q8_0FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe this.nEmbdHead = nEmbdHeadV; this.nEmbdGqa = nEmbdVGqa; this.gqa = config.numberOfHeads() / config.numberOfKeyValueHeads(); - ffnLayerTaskGraphs = setupFFNLayerTaskGraphs(); + setupFFNLayers(); } @Override @@ -79,7 +71,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { int projectionTwoGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid projectionTwoWorker = WorkerGridFactory.genericWorker(projectionTwoGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads(); + int qDim0 = nEmbdHeadK * config.numberOfHeads(); int kvDim0 = nEmbdGqa; int fusedQKVRows = qDim0 + 2 * kvDim0; // Q rows + K rows + V rows int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; @@ -102,41 +94,17 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { return gridScheduler; } - public List getFFNLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - /** - * Setup all FFN layers for all transformer layers - */ - @Override - 
protected List setupFFNLayerTaskGraphs() { - List ffnGraphs = new ArrayList<>(); - qwen3State.temp.init(0.0f); - qwen3State.tempFFN.init(0.0f); - qwen3State.tempQcur.init(0.0f); - qwen3State.tempKcur.init(0.0f); - - for (int layerIndex = 0; layerIndex < qwen3Config.numberOfLayers(); layerIndex++) { - TaskGraph ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, layerIndex); - if (layerIndex == qwen3Config.numberOfLayers() - 1) { - this.lastFFNLayerTaskGraphID = ffnLayer.getTaskGraphName(); - } - ffnGraphs.add(ffnLayer.snapshot()); - } - return ffnGraphs; - } - /** * Setup a single transformer layer for Qwen3 with GQA (Q8_0 quantized) */ - TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var taskGraphName = "layer_" + layerIndex; // === Dimension Parameters === - int qDim = nEmbdHeadK * qwen3Config.numberOfHeads(); // Q output size (full heads) + int qDim = nEmbdHeadK * config.numberOfHeads(); // Q output size (full heads) int kvDim = nEmbdGqa; // K/V output size (reduced for GQA) - int inputDim = qwen3Config.dim(); // Model dimension + int inputDim = config.dim(); // Model dimension var unifiedLayer = new TaskGraph(taskGraphName); @@ -208,11 +176,11 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapK, // K vectors (in/out) weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // Q norm weights weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // K norm weights - qwen3Config.numberOfHeads(), // nHeads (Q heads) - qwen3Config.numberOfKeyValueHeads(), // nHeadKv (K/V heads, GQA) + config.numberOfHeads(), // nHeads (Q heads) + config.numberOfKeyValueHeads(), // nHeadKv (K/V heads, GQA) nEmbdHead, // head dimension nEmbdHead, // local memory size - qwen3Config.rmsNormEps()); // epsilon + config.rmsNormEps()); // epsilon // Fused RoPE Rotation + KV Cache Write 
unifiedLayer.task("rope_and_kv_cache", @@ -224,11 +192,11 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapV, // V vectors (in only) qwen3State.wrapKeyCache, // key cache (out) qwen3State.wrapValueCache, // value cache (out) - qwen3Config.numberOfKeyValueHeads(), // nHeadKv + config.numberOfKeyValueHeads(), // nHeadKv nEmbdHead, // head dimension nEmbdGqa, // kvDim layerIndex, // layer index for cache offset - qwen3Config.contextLength()); // max sequence length + config.contextLength()); // max sequence length // Flash Attention unifiedLayer.task("attention", @@ -238,13 +206,13 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapKeyCache, // key cache qwen3State.wrapValueCache, // value cache qwen3State.wrapXb, // output: attention result - qwen3Config.numberOfHeads(), // nHeads + config.numberOfHeads(), // nHeads nEmbdHead, // headSize nEmbdGqa, // kvDim gqa, // kvMul (nHeads / nHeadKv) qwen3State.positionHolder, // position layerIndex, // layer index - qwen3Config.contextLength()); // context length + config.contextLength()); // context length // Output Projection with Residual unifiedLayer.task("attn_output_proj", @@ -253,7 +221,7 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapXb, // input: attention output qwen3State.wrapX, // output: wrapX += Wo · wrapXb weights.woLayered[layerIndex].asByteArray(), // Wo [dim x qDim] - nEmbdHeadK * qwen3Config.numberOfHeads(), // input dim (qDim) + nEmbdHeadK * config.numberOfHeads(), // input dim (qDim) config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); @@ -267,8 +235,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) context, qwen3State.tempFFN, // output: scale factor qwen3State.wrapX, // input: hidden state - qwen3Config.dim(), // dimension - qwen3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // 
epsilon qwen3State.localSize); // local memory size // Final normalization (non-NVIDIA only) @@ -277,8 +245,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) TransformerComputeKernelsLayered::reductionFinalNormalization, context, qwen3State.tempFFN, // scale factor (in/out) - qwen3Config.dim(), // dimension - qwen3Config.rmsNormEps()); // epsilon + config.dim(), // dimension + config.rmsNormEps()); // epsilon } // Fused RMS Apply + Gate/Up Projection + SiLU + GLU @@ -291,8 +259,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.tempFFN, // RMS scale factor weights.w1Layered[layerIndex].asByteArray(), // W1 (gate) Q8_0 weights.w3Layered[layerIndex].asByteArray(), // W3 (up) Q8_0 - qwen3Config.dim(), // input dimension - qwen3Config.hiddenDim(), // hidden dimension + config.dim(), // input dimension + config.hiddenDim(), // hidden dimension LOCAL_WORK_GROUP_SIZE_ALLOC); // Down Projection with Residual From bf6823d512e62b3f612edf3835b09a2f187e30a9 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 22:26:16 +0200 Subject: [PATCH 05/13] [refactor] Move GenericLayerPlanner to layerplanner package --- .../tornadovm/{ => layerplanner}/GenericLayerPlanner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename src/main/java/org/beehive/gpullama3/tornadovm/{ => layerplanner}/GenericLayerPlanner.java (83%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java similarity index 83% rename from src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java index 5a151212..9e211051 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java @@ 
-1,4 +1,4 @@ -package org.beehive.gpullama3.tornadovm; +package org.beehive.gpullama3.tornadovm.layerplanner; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; From 3c09bca5c416774a7a8bba15d9c82beed19a12e6 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 22:27:31 +0200 Subject: [PATCH 06/13] [refactor] Introduce AbstractLogitsLayer to centralize shared logic for logits layers --- .../tornadovm/layers/AbstractLogitsLayer.java | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java new file mode 100644 index 00000000..f3215501 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java @@ -0,0 +1,48 @@ +package org.beehive.gpullama3.tornadovm.layers; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Abstract base for all logits layers (final vocabulary projection step). + * + * Holds the shared fields and calls the protected buildLogitsTaskGraph() hook once + * during construction. Subclasses implement buildLogitsTaskGraph() to define the + * quantization-specific task sequence; Granite variants override it to swap in + * their scaled kernel. 
+ */ +public abstract class AbstractLogitsLayer extends AbstractLayer { + + protected final String lastTaskGraphID; + protected final SchedulerType schedulerType; + private final TaskGraph logitsTaskGraph; + + protected AbstractLogitsLayer(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config); + this.lastTaskGraphID = lastTaskGraphID; + this.schedulerType = schedulerType; + TornadoWeights tornadoWeights = requireWeightsType(weights, TornadoWeights.class, + getClass().getSimpleName(), "TornadoTensor"); + this.logitsTaskGraph = buildLogitsTaskGraph(tornadoWeights, config); + } + + /** + * Builds the logits task graph. Called once from the constructor. + * Subclasses define the quantization-specific task sequence here. + */ + protected abstract TaskGraph buildLogitsTaskGraph(TornadoWeights weights, Configuration config); + + public final TaskGraph getTaskGraph() { + return logitsTaskGraph; + } + + public final ImmutableTaskGraph getImmutableTaskGraph() { + return logitsTaskGraph.snapshot(); + } +} From dc76fdedbdd59d91868f051c1b2768649fa8455b Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 22:34:36 +0200 Subject: [PATCH 07/13] [refactor] Move QuantizationPlannerFactory to layerplanner package root level --- .../layerplanner/{base => }/QuantizationPlannerFactory.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) rename src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/{base => }/QuantizationPlannerFactory.java (97%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java similarity index 97% rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java rename to 
src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java index 9efa6a08..f3bc3d3a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.base; +package org.beehive.gpullama3.tornadovm.layerplanner; import org.beehive.gpullama3.inference.state.GraniteState; import org.beehive.gpullama3.tensor.GGMLType; @@ -8,7 +8,6 @@ import org.beehive.gpullama3.inference.state.Qwen3State; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.GenericLayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.MistralFP16LayerPlanner; From 080fea4d4b6e9ecf05f209a1490ca2bdc6fd336a Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 22:41:04 +0200 Subject: [PATCH 08/13] [refactor] Simplify and unify layer planners by centralizing inference plan creation logic in layerplanner package --- .../base/QuantizedLayerPlanner.java | 85 ++++++++++++++----- .../model/fp16/GraniteFP16LayerPlanner.java | 15 ++-- .../model/fp16/LlamaFP16LayerPlanner.java | 16 ++-- .../model/fp16/MistralFP16LayerPlanner.java | 2 +- .../model/fp16/Phi3FP16LayerPlanner.java | 14 +-- .../model/fp16/Qwen2FP16LayerPlanner.java | 13 +-- .../model/fp16/Qwen3FP16LayerPlanner.java | 16 ++-- .../model/q8_0/GraniteQ8_0LayerPlanner.java | 13 +-- .../model/q8_0/LlamaQ8_0LayerPlanner.java | 16 ++-- .../model/q8_0/MistralQ8_0LayerPlanner.java | 2 +- .../model/q8_0/Phi3Q8_0LayerPlanner.java | 14 +-- .../model/q8_0/Qwen2Q8_0LayerPlanner.java | 14 +-- .../model/q8_0/Qwen3Q8_0LayerPlanner.java | 15 
++-- .../quantization/FP16LayerPlanner.java | 59 +------------ .../quantization/Q8_0LayerPlanner.java | 65 +------------- 15 files changed, 113 insertions(+), 246 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java index f95d5406..5ef5a1f6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java @@ -4,21 +4,28 @@ import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.GenericLayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; +import java.util.ArrayList; +import java.util.List; + /** * Abstract base for all quantization-specific planners. * - * Contains shared logic that works regardless of model type but depends on quantization. Subclasses: FP16LayerPlanner, Q8_0LayerPlanner, etc. + * Extracts common state from the model, detects the hardware scheduler type, + * and assembles the full execution plan via createTornadoInferencePlan(). + * Subclasses (FP16LayerPlanner, Q8_0LayerPlanner) only provide quantization validation. 
*/ -public abstract class QuantizedLayerPlanner implements GenericLayerPlanner { - - // Common state for all quantizations - protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32; - protected static final int THREAD_SCALE_FOR_LOGITS = 8; +public abstract class QuantizedLayerPlanner + implements GenericLayerPlanner { protected final S state; protected final C config; @@ -27,9 +34,14 @@ public abstract class QuantizedLayerPlanner ffnLayers; + protected AbstractLogitsLayer logitsLayer; + + private List immutableTaskGraphs; + private GridScheduler gridScheduler; + + @SuppressWarnings("unchecked") protected QuantizedLayerPlanner(S state, Model model) { this.state = state; this.model = model; @@ -40,26 +52,53 @@ protected QuantizedLayerPlanner(S state, Model model) { validateQuantizationType(); } - /** - * Override in subclasses to validate correct quantization format. E.g., FP16LayerPlanner checks: weights instanceof FP16Weights - */ + /** Validates that the model weights match the expected quantization type. */ protected abstract void validateQuantizationType(); /** - * Override in subclasses for model-specific initialization + * Creates the TornadoVM inference execution pipeline. + * It represents the entire Feed-Forward Network (FFN) and consists of: + *
+ * <ul>
+ *   <li>Activation layer</li>
+ *   <li>FFN layers (N transformer layers, model-specific)</li>
+ *   <li>Logits layer</li>
+ * </ul>
+ * <p>
+ * Each component is represented as an {@link ImmutableTaskGraph}, along with a + * corresponding {@link GridScheduler} configuration that defines how tasks are + * mapped on the GPU. + *

+ * This method assembles all components into a unified execution pipeline and + * caches the resulting task graphs and scheduler for reuse across inference runs. */ - protected abstract void initializeLayerComponents(); + protected final void createTornadoInferencePlan() { + List allTaskGraphs = new ArrayList<>(); + GridScheduler masterScheduler = new GridScheduler(); + + // 1. Activation layer (common to all models) + allTaskGraphs.add(activationLayer.getImmutableTaskGraph()); + activationLayer.updateGridScheduler(masterScheduler); + + // 2. FFN layers (N transformer layers - model-specific) + allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs()); + ffnLayers.updateGridScheduler(masterScheduler); + + // 3. Logits layer (common to all models) + allTaskGraphs.add(logitsLayer.getImmutableTaskGraph()); + logitsLayer.updateGridScheduler(masterScheduler); - // Common helper methods for all quantizations - protected C getConfig() { - return config; + // Cache for future retrievals + this.immutableTaskGraphs = allTaskGraphs; + this.gridScheduler = masterScheduler; } - protected W getWeights() { - return weights; + @Override + public final List getImmutableTaskGraphs() { + return this.immutableTaskGraphs; } - protected S getState() { - return state; + @Override + public final GridScheduler getGridScheduler() { + return this.gridScheduler; } -} \ No newline at end of file +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java index 4a6c853a..8f8e5539 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java @@ -10,17 +10,12 @@ import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer; public class GraniteFP16LayerPlanner 
extends FP16LayerPlanner { + public GraniteFP16LayerPlanner(GraniteState state, Model model) { super(state, model); - validateQuantizationType(); - setupTornadoForwardPlan(); - } - - @Override - protected void initializeLayerComponents() { - this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config); - this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsGraniteFP16Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastFFNLayerTaskGraphID(), this.schedulerType); + this.activationLayer = new ActivationGranite("activationUpdate", state, weights, config); + this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", state, weights, config, schedulerType); + this.logitsLayer = new LogitsGraniteFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); + createTornadoInferencePlan(); } - } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java index 7cad2949..ccb376db 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java @@ -13,15 +13,9 @@ public class LlamaFP16LayerPlanner extends FP16LayerPlanner extends QuantizedLayerPlanner { - protected Activation activationLayer; - protected AbstractFFNLayers ffnLayers; - protected LogitsFP16Layer logitsLayer; - - protected List immutableTaskGraphs; - protected GridScheduler gridScheduler ; - protected FP16LayerPlanner(S state, Model model) { super(state, model); - initializeLayerComponents(); } @Override @@ -42,49 +30,4 @@ protected void validateQuantizationType() { throw new IllegalArgumentException("FP16LayerPlanner requires 
GGMLType.F16, got: " + this.weights.getWeightType()); } } - - @Override - protected void initializeLayerComponents() { - } - - protected final void setupTornadoForwardPlan() { - List allTaskGraphs = new ArrayList<>(); - GridScheduler masterScheduler = new GridScheduler(); - - // 1. Activation layer (common to all models) - allTaskGraphs.add(activationLayer.getImmutableTaskGraph()); - activationLayer.updateGridScheduler(masterScheduler); - - // 2. FFN layers (N transformer layers - model-specific) - allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs()); - ffnLayers.updateGridScheduler(masterScheduler); - - // 3. Logits layer (common to all models) - allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot()); - logitsLayer.updateGridScheduler(masterScheduler); - - // Cache for future retrievals - this.immutableTaskGraphs = allTaskGraphs; - this.gridScheduler = masterScheduler; - } - - /** - * Returns cached task graphs (used by hardware strategy pattern). - * - * Removed from all model-specific planners - centralized here. - */ - public final List getImmutableTaskGraphs() { - return this.immutableTaskGraphs; - } - - /** - * Returns cached scheduler (used by hardware strategy pattern). - * - * Removed from all model-specific planners - centralized here. - */ - @Override - public final GridScheduler getGridScheduler() { - return this.gridScheduler; - } - -} \ No newline at end of file +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java index 06aefee5..d76e37ed 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java @@ -17,25 +17,12 @@ /** * Base for all Q8_0-quantized layer planners. - * - * Subclasses: LlamaQ8_0LayerPlanner, Qwen2Q8_0LayerPlanner, etc. 
- * - * Q8_0 Specific: - Uses 8-bit integer quantization with uniform scaling per 32-element block - Weights: weights.xxxByteArray arrays - Compute: dequantize on-the-fly during matmul - Memory: 2x - * compression vs FP16 */ -public abstract class Q8_0LayerPlanner extends QuantizedLayerPlanner { - - protected Activation activationLayer; - protected AbstractFFNLayers ffnLayers; - protected LogitsQ8_0Layer logitsLayer; - - // Cache for task graphs and scheduler (set once, reused) - protected List cachedTaskGraphs; - protected GridScheduler cachedScheduler; +public abstract class Q8_0LayerPlanner + extends QuantizedLayerPlanner { protected Q8_0LayerPlanner(S state, Model model) { super(state, model); - initializeLayerComponents(); } @Override @@ -44,50 +31,4 @@ protected void validateQuantizationType() { throw new IllegalArgumentException("Q8_0LayerPlanner requires GGMLType.Q8_0, got: " + this.weights.getWeightType()); } } - - @Override - protected void initializeLayerComponents() { - // Override in subclasses (LlamaQ8_0LayerPlanner, etc.) - } - - protected final void setupTornadoForwardPlan() { - List allTaskGraphs = new ArrayList<>(); - GridScheduler masterScheduler = new GridScheduler(); - - // 1. Activation layer (common to all models) - allTaskGraphs.add(activationLayer.getImmutableTaskGraph()); - activationLayer.updateGridScheduler(masterScheduler); - - // 2. FFN layers (N transformer layers - model-specific) - allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs()); - ffnLayers.updateGridScheduler(masterScheduler); - - // 3. Logits layer (common to all models) - allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot()); - logitsLayer.updateGridScheduler(masterScheduler); - - // Cache for future retrievals - this.cachedTaskGraphs = allTaskGraphs; - this.cachedScheduler = masterScheduler; - } - - /** - * Returns cached task graphs (used by hardware strategy pattern). - * - * Removed from all model-specific planners - centralized here. 
- */ - public final List getImmutableTaskGraphs() { - return this.cachedTaskGraphs; - } - - /** - * Returns cached scheduler (used by hardware strategy pattern). - * - * Removed from all model-specific planners - centralized here. - */ - @Override - public final GridScheduler getGridScheduler() { - return this.cachedScheduler; - } - -} \ No newline at end of file +} From 170db11db294777adb51fd0c51820c6a05fb04b2 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 22:42:02 +0200 Subject: [PATCH 09/13] [refactor] Move QuantizedLayerPlanner to layerplanner package root-level --- .../org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java | 3 ++- .../layerplanner/{base => }/QuantizedLayerPlanner.java | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) rename src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/{base => }/QuantizedLayerPlanner.java (96%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index eadd2e68..a42dc310 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -4,7 +4,8 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java 
similarity index 96% rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java index 5ef5a1f6..1b7c1953 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java @@ -1,10 +1,9 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.base; +package org.beehive.gpullama3.tornadovm.layerplanner; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; From a26f2a9ff9a0acd5ca5d9aa23aacee1b06403755 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 22:52:01 +0200 Subject: [PATCH 10/13] [refactor] Move FP16LayerPlanner and Q8_0LayerPlanner to quantized-specific subpackages --- .../fp16}/FP16LayerPlanner.java | 12 ++---------- .../model/fp16/GraniteFP16LayerPlanner.java | 1 - .../model/fp16/LlamaFP16LayerPlanner.java | 1 - .../model/fp16/MistralFP16LayerPlanner.java | 1 - .../model/fp16/Phi3FP16LayerPlanner.java | 1 - .../model/fp16/Qwen2FP16LayerPlanner.java | 1 - .../model/fp16/Qwen3FP16LayerPlanner.java | 1 - .../model/q8_0/GraniteQ8_0LayerPlanner.java | 1 - .../model/q8_0/LlamaQ8_0LayerPlanner.java | 1 - .../model/q8_0/MistralQ8_0LayerPlanner.java | 1 - .../model/q8_0/Phi3Q8_0LayerPlanner.java | 1 - .../q8_0}/Q8_0LayerPlanner.java | 12 ++---------- .../model/q8_0/Qwen2Q8_0LayerPlanner.java | 1 - .../model/q8_0/Qwen3Q8_0LayerPlanner.java | 
1 - 14 files changed, 4 insertions(+), 32 deletions(-) rename src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/{quantization => model/fp16}/FP16LayerPlanner.java (62%) rename src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/{quantization => model/q8_0}/Q8_0LayerPlanner.java (62%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java similarity index 62% rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/FP16LayerPlanner.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java index 6c57ce13..3b33d98c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/FP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java @@ -1,19 +1,11 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.quantization; +package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizedLayerPlanner; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; - -import java.util.ArrayList; -import java.util.List; +import org.beehive.gpullama3.tornadovm.layerplanner.QuantizedLayerPlanner; /** * Base for all FP16-quantized layer planners. 
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java index 8f8e5539..488e4b39 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.granite.GraniteConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java index ccb376db..d6042c39 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; diff --git 
a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java index 4d21b972..78e1bffb 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.mistral.MistralConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.fp16.MistralFP16FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java index 6a439ffb..5b50529a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.phi3.Phi3Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.Phi3FP16FFNLayers; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java 
b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java index 657d1a54..a3dcc5e2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen2FP16FFNLayers; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java index 796414f8..76239ae6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java index 31685f42..f7735dc0 
100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.granite.GraniteConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java index ace04d77..827ae538 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java index a6d11567..65e0150c 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.mistral.MistralConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.MistralQ8_0FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java index 64a91d2e..6f088d36 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.phi3.Phi3Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Q8_0LayerPlanner.java similarity index 62% rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java rename to 
src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Q8_0LayerPlanner.java index d76e37ed..b525d2a3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Q8_0LayerPlanner.java @@ -1,19 +1,11 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.quantization; +package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizedLayerPlanner; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; - -import java.util.ArrayList; -import java.util.List; +import org.beehive.gpullama3.tornadovm.layerplanner.QuantizedLayerPlanner; /** * Base for all Q8_0-quantized layer planners. 
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java index 84917752..090812e7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen2Q8_0FFNLayers; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java index 75eb8f5c..b5a19be2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers; From 4be811a77bfa1ca68de78df174aea166064a5f67 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 23:14:34 +0200 
Subject: [PATCH 11/13] [refactor] Unify task graph setup for Logits layers and centralize shared logic into AbstractLogitsLayer --- .../tornadovm/layers/AbstractLogitsLayer.java | 8 +- .../layers/type/fp16/LogitsFP16Layer.java | 44 ++----- .../type/fp16/LogitsGraniteFP16Layer.java | 98 +++++----------- .../type/q8_0/LogitsGraniteQ8_0Layer.java | 85 ++++---------- .../layers/type/q8_0/LogitsQ8_0Layer.java | 110 ++++++++---------- 5 files changed, 114 insertions(+), 231 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java index f3215501..37288e5f 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java @@ -29,14 +29,10 @@ protected AbstractLogitsLayer(String name, State state, Weights weights, Configu this.schedulerType = schedulerType; TornadoWeights tornadoWeights = requireWeightsType(weights, TornadoWeights.class, getClass().getSimpleName(), "TornadoTensor"); - this.logitsTaskGraph = buildLogitsTaskGraph(tornadoWeights, config); + this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); } - /** - * Builds the logits task graph. Called once from the constructor. - * Subclasses define the quantization-specific task sequence here. 
- */ - protected abstract TaskGraph buildLogitsTaskGraph(TornadoWeights weights, Configuration config); + protected abstract TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config); public final TaskGraph getTaskGraph() { return logitsTaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index d2a81407..bf938a0d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -2,39 +2,29 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class LogitsFP16Layer extends AbstractLayer { - - private String lastTaskGraphID; - private TaskGraph logitsTaskGraph; - private ImmutableTaskGraph immutableLogitsGraph; - private 
GridScheduler scheduler; - private SchedulerType schedulerType; +public class LogitsFP16Layer extends AbstractLogitsLayer { - public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { - super(name, state, weights, config); - this.lastTaskGraphID = lastTaskGraphID; - this.schedulerType = schedulerType; - var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); - this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); + public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); } // @formatter:off - private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { + @Override + protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { var logits = new TaskGraph("logits"); // === Data Setup === logits.consumeFromDevice(lastTaskGraphID, state.wrapX); @@ -96,7 +86,7 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con @Override public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 
32 : 256); + var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), rmsLocalSize()); var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); @@ -106,18 +96,8 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return logitsTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return immutableLogitsGraph; + /** Local workgroup size for RMS norm. Qwen2 requires a smaller group (32 vs 256). */ + protected int rmsLocalSize() { + return weights instanceof Qwen2TornadoWeights ? 32 : 256; } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java index d55d707e..54ec9641 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java @@ -2,51 +2,40 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import 
org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid; -import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; +/** + * Granite-specific FP16 logits layer. + * Identical to LogitsFP16Layer except vocab_proj uses a scaled kernel (logitScale). + */ public class LogitsGraniteFP16Layer extends LogitsFP16Layer { - private String lastTaskGraphID; - private TaskGraph logitsTaskGraph; - private ImmutableTaskGraph immutableLogitsGraph; - private GridScheduler scheduler; - private SchedulerType schedulerType; - public LogitsGraniteFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { + public LogitsGraniteFP16Layer(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { super(name, state, weights, config, lastTaskGraphID, schedulerType); - this.lastTaskGraphID = lastTaskGraphID; - this.schedulerType = schedulerType; - var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); - this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config); } // @formatter:off - private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) { + @Override + protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { + GraniteConfiguration graniteCfg = (GraniteConfiguration) config; var logits = new TaskGraph("logits"); + // === Data Setup === logits.consumeFromDevice(lastTaskGraphID, state.wrapX); logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); 
logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, - // Kernel context context, - // Output buffer state.wrapLogits, - // Intermediate FP16 buffer state.wrapXbFP16, - // Weights weights.wclsByteArray.asHalfFloatArray(), weights.rms_final_weight_as_floatArray.asFloatArray()); @@ -54,72 +43,43 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfigurat logits.task("rms_reduce", TransformerComputeKernels::reductionOneBlockWithLayer, context, - state.tempLogits, // output: partial sums + final scale factor - state.wrapX, // input: hidden state - config.dim(), // dimension - config.rmsNormEps(), // epsilon for numerical stability - state.localSize); // local workgroup size + state.tempLogits, + state.wrapX, + config.dim(), + config.rmsNormEps(), + state.localSize); if (schedulerType == SchedulerType.NON_NVIDIA) { logits.task("rms_finalize", TransformerComputeKernelsLayered::reductionFinalNormalization, context, - state.tempLogits, // in/out: combines partial sums - config.dim(), // dimension - config.rmsNormEps()); // epsilon + state.tempLogits, + config.dim(), + config.rmsNormEps()); } logits.task("rms_apply_fp16", TransformerComputeKernels::mapContextWithQuantizeLogits, context, - state.wrapXbFP16, // output: normalized (FP16) - state.wrapX, // input: hidden state - weights.rms_final_weight_as_floatArray.asFloatArray(), // RMS weights - state.tempLogits); // scale factor from reduction + state.wrapXbFP16, + state.wrapX, + weights.rms_final_weight_as_floatArray.asFloatArray(), + state.tempLogits); - // === Vocabulary Projection === + // === Vocabulary Projection (Granite: scaled by logitScale) === logits.task("vocab_proj", GraniteKernels::matrixVectorGenericWithGraniteScale, context, - state.wrapXbFP16, // input (FP16) - state.wrapLogits, // output - weights.wclsByteArray.asHalfFloatArray(), // vocabulary weights - config.dim(), // input dimension - config.vocabularySize(), // output dimension + state.wrapXbFP16, + state.wrapLogits, + 
weights.wclsByteArray.asHalfFloatArray(), + config.dim(), + config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, - config.logitScale()); // granite logit scaling + graniteCfg.logitScale()); - // === Transfer Results to Host === logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return logits; } // @formatter:on - - @Override - public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256); - var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); - vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); - tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS); - tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); - return tornadoForwardScheduler; - } - - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return logitsTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return immutableLogitsGraph; - } } - diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java index d1fd4f0c..c583bb00 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java @@ -2,69 +2,51 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import 
org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class LogitsGraniteQ8_0Layer extends LogitsQ8_0Layer{ - private String lastTaskGraphID; - private TaskGraph logitsTaskGraph; - private ImmutableTaskGraph immutableLogitsGraph; - private GridScheduler scheduler; - private SchedulerType schedulerType; +/** + * Granite-specific Q8_0 logits layer. + * Identical to LogitsQ8_0Layer except vocab_proj uses a scaled kernel (logitScale). 
+ */ +public class LogitsGraniteQ8_0Layer extends LogitsQ8_0Layer { - public LogitsGraniteQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { - super(taskGraphName, state, weights, config, lastTaskGraphID, schedulerType); - this.lastTaskGraphID = lastTaskGraphID; - var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsQ8_0Layer", "TornadoTensor"); - this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config); - this.schedulerType = schedulerType; - } - - @Override - public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256); - var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); - vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); - tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); - tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); - return tornadoForwardScheduler; + public LogitsGraniteQ8_0Layer(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); } // @formatter:off - private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) { + @Override + protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { + GraniteConfiguration graniteCfg = (GraniteConfiguration) config; var logits = new TaskGraph("logits"); + // === Data Setup === logits.consumeFromDevice(lastTaskGraphID, state.wrapX); 
logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, // - state.wrapLogits, // - weights.wclsByteArray.asByteArray(), // + context, + state.wrapLogits, + weights.wclsByteArray.asByteArray(), weights.rms_final_weight_as_floatArray); // === Final RMS Normalization === logits.task("rms_reduce", TransformerComputeKernels::reductionOneBlockWithLayer, context, - state.tempLogits, // output: partial sums + final scale factor - state.wrapX, // input: hidden state - config.dim(), // dimension - config.rmsNormEps(), // epsilon for numerical stability - state.localSize); // local workgroup size + state.tempLogits, + state.wrapX, + config.dim(), + config.rmsNormEps(), + state.localSize); if (schedulerType == SchedulerType.NON_NVIDIA) { logits.task("rms_finalize", @@ -74,6 +56,7 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfigurat config.dim(), config.rmsNormEps()); } + logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, @@ -81,8 +64,9 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfigurat weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits); - // === Vocabulary vocab_proj === - logits.task("vocab_proj", GraniteKernels::matrixVectorGenericQ8ByteWithGraniteScale, // + // === Vocabulary Projection (Granite: scaled by logitScale) === + logits.task("vocab_proj", + GraniteKernels::matrixVectorGenericQ8ByteWithGraniteScale, context, state.wrapX, state.wrapLogits, @@ -90,29 +74,10 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfigurat config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, - config.logitScale() + graniteCfg.logitScale()); - ); - - // === Transfer Results to Host === logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return logits; } // @formatter:on - - @Override 
- public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return logitsTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return immutableLogitsGraph; - } - } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java index d54bb3ef..ed0d6cfc 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java @@ -2,69 +2,49 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class LogitsQ8_0Layer extends AbstractLayer { - - private String lastTaskGraphID; - private TaskGraph logitsTaskGraph; - private ImmutableTaskGraph immutableLogitsGraph; - private GridScheduler scheduler; - private SchedulerType schedulerType; 
+public class LogitsQ8_0Layer extends AbstractLogitsLayer { - public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { - super(taskGraphName, state, weights, config); - this.lastTaskGraphID = lastTaskGraphID; - var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsQ8_0Layer", "TornadoTensor"); - this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); - this.schedulerType = schedulerType; - } - - @Override - public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256); - var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); - vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); - tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); - tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); - return tornadoForwardScheduler; + public LogitsQ8_0Layer(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); } // @formatter:off - private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { + @Override + protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { var logits = new TaskGraph("logits"); + // === Data Setup === logits.consumeFromDevice(lastTaskGraphID, state.wrapX); logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, // - state.wrapLogits, // - 
weights.wclsByteArray.asByteArray(), // - weights.rms_final_weight_as_floatArray); + context, + state.wrapLogits, + weights.wclsByteArray.asByteArray(), + weights.rms_final_weight_as_floatArray); // === Final RMS Normalization === logits.task("rms_reduce", TransformerComputeKernels::reductionOneBlockWithLayer, context, - state.tempLogits, // output: partial sums + final scale factor - state.wrapX, // input: hidden state - config.dim(), // dimension - config.rmsNormEps(), // epsilon for numerical stability - state.localSize); // local workgroup size + state.tempLogits, // output: partial sums + final scale factor + state.wrapX, // input: hidden state + config.dim(), + config.rmsNormEps(), + state.localSize); if (schedulerType == SchedulerType.NON_NVIDIA) { logits.task("rms_finalize", @@ -74,42 +54,44 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con config.dim(), config.rmsNormEps()); } - logits.task("mapContextLogits", - TransformerComputeKernels::reductionOneBlock2WithLogits, - context, - state.wrapX, - weights.rms_final_weight_as_floatArray.asFloatArray(), + + logits.task("mapContextLogits", + TransformerComputeKernels::reductionOneBlock2WithLogits, + context, + state.wrapX, + weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits); - - // === Vocabulary vocab_proj === - logits.task("vocab_proj", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, // - context, - state.wrapX, - state.wrapLogits, - weights.wclsByteArray.asByteArray(), - config.dim(), - config.vocabularySize(), + + // === Vocabulary Projection === + logits.task("vocab_proj", + TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, + context, + state.wrapX, + state.wrapLogits, + weights.wclsByteArray.asByteArray(), + config.dim(), + config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); - // === Transfer Results to Host === logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return 
logits; } // @formatter:on @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return logitsTaskGraph; + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), rmsLocalSize()); + var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; + var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); + vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); + tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); + tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); + return tornadoForwardScheduler; } - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return immutableLogitsGraph; + /** Local workgroup size for RMS norm. Qwen2 requires a smaller group (32 vs 256). */ + protected int rmsLocalSize() { + return weights instanceof Qwen2TornadoWeights ? 
32 : 256; } - } From 3aa399b05292498f3ad4e0cec2f88133eec6a84e Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 23:16:36 +0200 Subject: [PATCH 12/13] [refactor] Simplify and unify Activation task graph setup logic --- .../tornadovm/layers/AbstractLayer.java | 3 ++ .../tornadovm/layers/Activation.java | 54 +++++++++---------- .../tornadovm/layers/ActivationGranite.java | 45 ++++++++-------- 3 files changed, 48 insertions(+), 54 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java index 3e2a1beb..f34f5777 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java @@ -27,6 +27,9 @@ protected AbstractLayer(String taskGraphName, State state, Weights weights, Conf this.config = config; } + /** + * Ensures weights are of the expected type. + */ @SuppressWarnings("unchecked") protected static T requireWeightsType(Object weights, Class expectedType, String layerName, String layout) { if (expectedType.isInstance(weights)) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java index ccd49d63..6ba36f39 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java @@ -7,55 +7,49 @@ import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; -import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; import 
uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; public class Activation extends AbstractLayer { - private final TaskGraph activationUpdate; - - public Activation(String taskGraphHandle, State state, Weights weights, Configuration config) { - super(taskGraphHandle, state, weights, config); - - KernelContext kernelContext = new KernelContext(); - - // @formatter:off - switch (config.quantization()) { - case "FP16" -> { - this.activationUpdate = new TaskGraph(taskGraphHandle) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", TransformerComputeKernels::convertFP16toFP32, kernelContext, (HalfFloatArray) state.embeddingX, state.wrapX) - .persistOnDevice(state.wrapX); - } - case "Q8_0" -> { - this.activationUpdate = new TaskGraph(taskGraphHandle) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", TransformerComputeKernels::convertQ8_0toFP32, kernelContext, (ByteArray) state.embeddingX, state.wrapX) - .persistOnDevice(state.wrapX); - } + private final TaskGraph activationTaskGraph; + + public Activation(String name, State state, Weights weights, Configuration config) { + super(name, state, weights, config); + this.activationTaskGraph = setupActivationTaskGraph(name); + } + + // @formatter:off + protected TaskGraph setupActivationTaskGraph(String name) { + return switch (config.quantization()) { + case "FP16" -> new TaskGraph(name) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", TransformerComputeKernels::convertFP16toFP32, context, (HalfFloatArray) state.embeddingX, state.wrapX) + .persistOnDevice(state.wrapX); + case "Q8_0" -> new TaskGraph(name) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", TransformerComputeKernels::convertQ8_0toFP32, context, (ByteArray) state.embeddingX, state.wrapX) + .persistOnDevice(state.wrapX); default -> throw 
new UnsupportedOperationException("Unsupported quantization format: " + config.quantization()); - } - // @formatter:on + }; } + // @formatter:on @Override public GridScheduler updateGridScheduler(GridScheduler scheduler) { - WorkerGrid worker = new WorkerGrid1D(config.dim()); - worker.setLocalWork(128, 1, 1); - scheduler.addWorkerGrid("activationUpdate.updateX", worker); + WorkerGrid worker = WorkerGridFactory.genericWorker(config.dim(), 128); + scheduler.addWorkerGrid(activationTaskGraph.getTaskGraphName() + ".updateX", worker); return scheduler; } public TaskGraph getTaskGraph() { - return activationUpdate; + return activationTaskGraph; } public ImmutableTaskGraph getImmutableTaskGraph() { - return activationUpdate.snapshot(); + return activationTaskGraph.snapshot(); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java index 1345d0df..61bfa573 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java @@ -4,40 +4,37 @@ import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid; -import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +/** + * Granite-specific activation: applies an embedding scale factor during the FP32 conversion. 
+ * Overrides only the task graph builder; all other behaviour is inherited from Activation. + */ public class ActivationGranite extends Activation { - private final TaskGraph activationUpdate; // Granite is a special case where activation X is scaled by embedding scale float value that inside model. public ActivationGranite(String taskGraphHandle, State state, Weights weights, GraniteConfiguration config) { super(taskGraphHandle, state, weights, config); + } - KernelContext kernelContext = new KernelContext(); - - // @formatter:off - switch (config.quantization()) { - case "FP16" -> { - this.activationUpdate = new TaskGraph(taskGraphHandle) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", GraniteKernels::convertFP16toFP32withGraniteScale, kernelContext, (HalfFloatArray) state.embeddingX, state.wrapX, config.embeddingScale()) - .persistOnDevice(state.wrapX); - } - case "Q8_0" -> { - this.activationUpdate = new TaskGraph(taskGraphHandle) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", GraniteKernels::convertQ8_0toFP32withGraniteScale, kernelContext, (ByteArray) state.embeddingX, state.wrapX, config.embeddingScale()) - .persistOnDevice(state.wrapX); - } + // @formatter:off + @Override + protected TaskGraph setupActivationTaskGraph(String handle) { + GraniteConfiguration cfg = (GraniteConfiguration) config; + return switch (config.quantization()) { + case "FP16" -> new TaskGraph(handle) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", GraniteKernels::convertFP16toFP32withGraniteScale, context, (HalfFloatArray) state.embeddingX, state.wrapX, cfg.embeddingScale()) + .persistOnDevice(state.wrapX); + case "Q8_0" -> new TaskGraph(handle) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", GraniteKernels::convertQ8_0toFP32withGraniteScale, context, (ByteArray) state.embeddingX, state.wrapX, 
cfg.embeddingScale()) + .persistOnDevice(state.wrapX); default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization()); - } - // @formatter:on + }; } + // @formatter:on +} From a3f145093d24df65b9cfbc9dfaff829f4c4fe85d Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 26 Mar 2026 23:28:02 +0200 Subject: [PATCH 13/13] Introduce DeepSeekR1Qwen model and integrate with Qwen2ModelLoader --- .../model/format/Qwen3ChatFormat.java | 7 +++++- .../model/loader/Qwen2ModelLoader.java | 5 +++- .../gpullama3/model/qwen2/DeepSeekR1Qwen.java | 23 +++++++++++++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/model/qwen2/DeepSeekR1Qwen.java diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index 7e873237..b6d2e798 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -101,7 +101,12 @@ public List encodeMessage(Message message) { @Override public int getBeginOfText() { - return beginOfText; + if (beginOfText == -1) { + // deepseek-r1 + return startHeader; + } else { + return beginOfText; + } } @Override diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java index f3ac590e..2e3d8002 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java @@ -12,6 +12,7 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.model.format.ChatFormat; import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens; +import org.beehive.gpullama3.model.qwen2.DeepSeekR1Qwen; import org.beehive.gpullama3.model.qwen2.Qwen2; 
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; @@ -85,7 +86,9 @@ protected Qwen2 createModel(Qwen2Configuration config, Tokenizer tokenizer, Weig // Qwen2.5-Coder uses <|endoftext|> as stop-token. ChatTokens chatTokens = isDeepSeekR1DistillQwen ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>"); - return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)); + return isDeepSeekR1DistillQwen + ? new DeepSeekR1Qwen(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)) + : new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)); } // @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/model/qwen2/DeepSeekR1Qwen.java b/src/main/java/org/beehive/gpullama3/model/qwen2/DeepSeekR1Qwen.java new file mode 100644 index 00000000..cb49c3b6 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/qwen2/DeepSeekR1Qwen.java @@ -0,0 +1,23 @@ +package org.beehive.gpullama3.model.qwen2; + +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.ModelType; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.tokenizer.Tokenizer; + +public class DeepSeekR1Qwen extends Qwen2 { + + public DeepSeekR1Qwen(Qwen2Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) { + super(configuration, tokenizer, weights, chatFormat); + } + + @Override + public ModelType getModelType() { + return ModelType.DEEPSEEK_R1_DISTILL_QWEN; + } + + @Override + public boolean shouldAddBeginOfText() { + return true; + } +}