diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index 7e873237..b6d2e798 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -101,7 +101,12 @@ public List encodeMessage(Message message) { @Override public int getBeginOfText() { - return beginOfText; + if (beginOfText == -1) { + // deepseek-r1 + return startHeader; + } else { + return beginOfText; + } } @Override diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java index f3ac590e..2e3d8002 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java @@ -12,6 +12,7 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.model.format.ChatFormat; import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens; +import org.beehive.gpullama3.model.qwen2.DeepSeekR1Qwen; import org.beehive.gpullama3.model.qwen2.Qwen2; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; @@ -85,7 +86,9 @@ protected Qwen2 createModel(Qwen2Configuration config, Tokenizer tokenizer, Weig // Qwen2.5-Coder uses <|endoftext|> as stop-token. ChatTokens chatTokens = isDeepSeekR1DistillQwen ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>"); - return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)); + return isDeepSeekR1DistillQwen + ? 
new DeepSeekR1Qwen(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)) + : new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)); } // @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/model/qwen2/DeepSeekR1Qwen.java b/src/main/java/org/beehive/gpullama3/model/qwen2/DeepSeekR1Qwen.java new file mode 100644 index 00000000..cb49c3b6 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/qwen2/DeepSeekR1Qwen.java @@ -0,0 +1,23 @@ +package org.beehive.gpullama3.model.qwen2; + +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.ModelType; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.tokenizer.Tokenizer; + +public class DeepSeekR1Qwen extends Qwen2 { + + public DeepSeekR1Qwen(Qwen2Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) { + super(configuration, tokenizer, weights, chatFormat); + } + + @Override + public ModelType getModelType() { + return ModelType.DEEPSEEK_R1_DISTILL_QWEN; + } + + @Override + public boolean shouldAddBeginOfText() { + return true; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index eadd2e68..a42dc310 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -4,7 +4,8 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java similarity index 83% rename from src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java index 5a151212..9e211051 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.tornadovm; +package org.beehive.gpullama3.tornadovm.layerplanner; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java similarity index 88% rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java index ca844e51..f3bc3d3a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java @@ -1,4 +1,4 @@ -package 
org.beehive.gpullama3.tornadovm.layerplanner.base; +package org.beehive.gpullama3.tornadovm.layerplanner; import org.beehive.gpullama3.inference.state.GraniteState; import org.beehive.gpullama3.tensor.GGMLType; @@ -8,14 +8,15 @@ import org.beehive.gpullama3.inference.state.Qwen3State; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.GenericLayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.MistralFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.GraniteQ8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.MistralQ8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner; @@ -54,7 +55,8 @@ public static GenericLayerPlanner create(GGMLType quantization, State state, Mod // ============ FP16 Planners ============ private static GenericLayerPlanner createFP16Planner(State state, Model model) { return switch (model.getModelType()) { - case LLAMA_3, MISTRAL -> new LlamaFP16LayerPlanner((LlamaState) state, model); + case LLAMA_3 -> new LlamaFP16LayerPlanner((LlamaState) state, model); + case MISTRAL -> new MistralFP16LayerPlanner((LlamaState) state, model); case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model); case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model); @@ -67,7 +69,8 @@ private static GenericLayerPlanner createFP16Planner(State state, Model model) { // ============ Q8_0 Planners ============ private static GenericLayerPlanner createQ8_0Planner(State state, Model model) { return switch (model.getModelType()) { - case LLAMA_3, MISTRAL -> new LlamaQ8_0LayerPlanner((LlamaState) state, model); + case LLAMA_3 -> new LlamaQ8_0LayerPlanner((LlamaState) state, model); + case MISTRAL -> new MistralQ8_0LayerPlanner((LlamaState) state, model); case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model); case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model); case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java new file mode 100644 index 00000000..1b7c1953 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java @@ -0,0 +1,103 @@ +package org.beehive.gpullama3.tornadovm.layerplanner; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import 
org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; + +import java.util.ArrayList; +import java.util.List; + +/** + * Abstract base for all quantization-specific planners. + * + * Extracts common state from the model, detects the hardware scheduler type, + * and assembles the full execution plan via createTornadoInferencePlan(). + * Subclasses (FP16LayerPlanner, Q8_0LayerPlanner) only provide quantization validation. + */ +public abstract class QuantizedLayerPlanner<S extends State, C extends Configuration, W extends Weights> + implements GenericLayerPlanner { + + protected final S state; + protected final C config; + protected final W weights; + protected final KernelContext context; + protected final Model model; + protected final SchedulerType schedulerType; + + protected Activation activationLayer; + protected AbstractFFNLayers ffnLayers; + protected AbstractLogitsLayer logitsLayer; + + private List<ImmutableTaskGraph> immutableTaskGraphs; + private GridScheduler gridScheduler; + + @SuppressWarnings("unchecked") + protected QuantizedLayerPlanner(S state, Model model) { + this.state = state; + this.model = model; + this.config = (C) model.configuration(); + this.weights = (W) model.weights(); + this.context = new KernelContext(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + validateQuantizationType(); + } + + /** Validates that the model weights match the expected quantization type. */ + protected abstract void validateQuantizationType(); + + /** + * Creates the TornadoVM inference execution pipeline. + * It represents the entire forward pass and consists of: + * <ul>
+ *   <li>Activation layer</li> + *   <li>FFN layers (N transformer layers, model-specific)</li> + *   <li>Logits layer</li> + * </ul> + * <p> + * Each component is represented as an {@link ImmutableTaskGraph}, along with a + * corresponding {@link GridScheduler} configuration that defines how tasks are + * mapped on the GPU. + * <p>
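+ * A minimal usage sketch (hypothetical subclass constructor, mirroring the Mistral planner added in this PR): + * <pre>{@code + * this.activationLayer = new Activation("activationUpdate", state, weights, config); + * this.ffnLayers = new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType); + * this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); + * createTornadoInferencePlan(); // assembles and caches the task graphs and scheduler + * }</pre> + * <p>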
+ * This method assembles all components into a unified execution pipeline and + * caches the resulting task graphs and scheduler for reuse across inference runs. + */ + protected final void createTornadoInferencePlan() { + List allTaskGraphs = new ArrayList<>(); + GridScheduler masterScheduler = new GridScheduler(); + + // 1. Activation layer (common to all models) + allTaskGraphs.add(activationLayer.getImmutableTaskGraph()); + activationLayer.updateGridScheduler(masterScheduler); + + // 2. FFN layers (N transformer layers - model-specific) + allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs()); + ffnLayers.updateGridScheduler(masterScheduler); + + // 3. Logits layer (common to all models) + allTaskGraphs.add(logitsLayer.getImmutableTaskGraph()); + logitsLayer.updateGridScheduler(masterScheduler); + + // Cache for future retrievals + this.immutableTaskGraphs = allTaskGraphs; + this.gridScheduler = masterScheduler; + } + + @Override + public final List getImmutableTaskGraphs() { + return this.immutableTaskGraphs; + } + + @Override + public final GridScheduler getGridScheduler() { + return this.gridScheduler; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java deleted file mode 100644 index f95d5406..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java +++ /dev/null @@ -1,65 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.base; - -import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.GenericLayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import uk.ac.manchester.tornado.api.KernelContext; - -/** - * Abstract base for all quantization-specific planners. - * - * Contains shared logic that works regardless of model type but depends on quantization. Subclasses: FP16LayerPlanner, Q8_0LayerPlanner, etc. - */ -public abstract class QuantizedLayerPlanner implements GenericLayerPlanner { - - // Common state for all quantizations - protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32; - protected static final int THREAD_SCALE_FOR_LOGITS = 8; - - protected final S state; - protected final C config; - protected final W weights; - protected final KernelContext context; - protected final Model model; - protected final SchedulerType schedulerType; - - /** - * Constructor: validate quantization type, extract model components - */ - protected QuantizedLayerPlanner(S state, Model model) { - this.state = state; - this.model = model; - this.config = (C) model.configuration(); - this.weights = (W) model.weights(); - this.context = new KernelContext(); - this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); - validateQuantizationType(); - } - - /** - * Override in subclasses to validate correct quantization format. 
E.g., FP16LayerPlanner checks: weights instanceof FP16Weights - */ - protected abstract void validateQuantizationType(); - - /** - * Override in subclasses for model-specific initialization - */ - protected abstract void initializeLayerComponents(); - - // Common helper methods for all quantizations - protected C getConfig() { - return config; - } - - protected W getWeights() { - return weights; - } - - protected S getState() { - return state; - } -} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java new file mode 100644 index 00000000..3b33d98c --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java @@ -0,0 +1,25 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; + +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tornadovm.layerplanner.QuantizedLayerPlanner; + +/** + * Base for all FP16-quantized layer planners. + */ +public abstract class FP16LayerPlanner extends QuantizedLayerPlanner { + + protected FP16LayerPlanner(S state, Model model) { + super(state, model); + } + + @Override + protected void validateQuantizationType() { + if (this.weights.getWeightType() != GGMLType.F16) { + throw new IllegalArgumentException("FP16LayerPlanner requires GGMLType.F16, got: " + this.weights.getWeightType()); + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java index 7cc97d64..488e4b39 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java @@ -4,23 +4,17 @@ import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.granite.GraniteConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer; public class GraniteFP16LayerPlanner extends FP16LayerPlanner { + public GraniteFP16LayerPlanner(GraniteState state, Model model) { super(state, model); - validateQuantizationType(); - setupTornadoForwardPlan(); - } - - @Override - protected void initializeLayerComponents() { - this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config); - this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsGraniteFP16Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + this.activationLayer = new ActivationGranite("activationUpdate", state, weights, config); + this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", state, weights, config, schedulerType); + 
this.logitsLayer = new LogitsGraniteFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); + createTornadoInferencePlan(); } - } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java index 0480d513..d6042c39 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; @@ -13,15 +12,9 @@ public class LlamaFP16LayerPlanner extends FP16LayerPlanner { + + public MistralFP16LayerPlanner(LlamaState state, Model model) { + super(state, model); + this.activationLayer = new Activation("activationUpdate", state, weights, config); + this.ffnLayers = new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType); + this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); + createTornadoInferencePlan(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java index b1f41515..5b50529a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.phi3.Phi3Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.Phi3FP16FFNLayers; @@ -20,15 +19,9 @@ public class Phi3FP16LayerPlanner extends FP16LayerPlanner { + + public MistralQ8_0LayerPlanner(LlamaState state, Model model) { + super(state, model); + this.activationLayer = new Activation("activationUpdate", state, weights, config); + this.ffnLayers = new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType); + this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); + createTornadoInferencePlan(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java index dfa0ec0e..6f088d36 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java @@ -4,7 +4,6 @@ import 
org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.phi3.Phi3Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers; @@ -21,15 +20,9 @@ public class Phi3Q8_0LayerPlanner extends Q8_0LayerPlanner + extends QuantizedLayerPlanner { + + protected Q8_0LayerPlanner(S state, Model model) { + super(state, model); + } + + @Override + protected void validateQuantizationType() { + if (this.weights.getWeightType() != GGMLType.Q8_0) { + throw new IllegalArgumentException("Q8_0LayerPlanner requires GGMLType.Q8_0, got: " + this.weights.getWeightType()); + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java index 34cb1a42..090812e7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java @@ -4,7 +4,6 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen2Q8_0FFNLayers; @@ -21,15 +20,9 @@ public class Qwen2Q8_0LayerPlanner extends Q8_0LayerPlanner extends QuantizedLayerPlanner { - - protected Activation activationLayer; - protected AbstractFFNLayers ffnLayers; - protected LogitsFP16Layer logitsLayer; - - protected List immutableTaskGraphs; - protected GridScheduler gridScheduler ; - - protected FP16LayerPlanner(S state, Model model) { - super(state, model); - initializeLayerComponents(); - } - - @Override - protected void validateQuantizationType() { - if (this.weights.getWeightType() != GGMLType.F16) { - throw new IllegalArgumentException("FP16LayerPlanner requires GGMLType.F16, got: " + this.weights.getWeightType()); - } - } - - @Override - protected void initializeLayerComponents() { - } - - protected final void setupTornadoForwardPlan() { - List allTaskGraphs = new ArrayList<>(); - GridScheduler masterScheduler = new GridScheduler(); - - // 1. Activation layer (common to all models) - allTaskGraphs.add(activationLayer.getImmutableTaskGraph()); - activationLayer.updateGridScheduler(masterScheduler); - - // 2. FFN layers (N transformer layers - model-specific) - allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs()); - ffnLayers.updateGridScheduler(masterScheduler); - - // 3. Logits layer (common to all models) - allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot()); - logitsLayer.updateGridScheduler(masterScheduler); - - // Cache for future retrievals - this.immutableTaskGraphs = allTaskGraphs; - this.gridScheduler = masterScheduler; - } - - /** - * Returns cached task graphs (used by hardware strategy pattern). - * - * Removed from all model-specific planners - centralized here. 
- */ - public final List getImmutableTaskGraphs() { - return this.immutableTaskGraphs; - } - - /** - * Returns cached scheduler (used by hardware strategy pattern). - * - * Removed from all model-specific planners - centralized here. - */ - @Override - public final GridScheduler getGridScheduler() { - return this.gridScheduler; - } - -} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java deleted file mode 100644 index f10f9686..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java +++ /dev/null @@ -1,93 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.quantization; - -import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; -import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizedLayerPlanner; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; - -import java.util.ArrayList; -import java.util.List; - -/** - * Base for all Q8_0-quantized layer planners. - * - * Subclasses: LlamaQ8_0LayerPlanner, Qwen2Q8_0LayerPlanner, etc. - * - * Q8_0 Specific: - Uses 8-bit integer quantization with uniform scaling per 32-element block - Weights: weights.xxxByteArray arrays - Compute: dequantize on-the-fly during matmul - Memory: 2x - * compression vs FP16 - */ -public abstract class Q8_0LayerPlanner extends QuantizedLayerPlanner { - - protected Activation activationLayer; - protected AbstractFFNLayers ffnLayers; - protected LogitsQ8_0Layer logitsLayer; - - // Cache for task graphs and scheduler (set once, reused) - protected List cachedTaskGraphs; - protected GridScheduler cachedScheduler; - - protected Q8_0LayerPlanner(S state, Model model) { - super(state, model); - initializeLayerComponents(); - } - - @Override - protected void validateQuantizationType() { - if (this.weights.getWeightType() != GGMLType.Q8_0) { - throw new IllegalArgumentException("Q8_0LayerPlanner requires GGMLType.Q8_0, got: " + this.weights.getWeightType()); - } - } - - @Override - protected void initializeLayerComponents() { - // Override in subclasses (LlamaQ8_0LayerPlanner, etc.) - } - - protected final void setupTornadoForwardPlan() { - List allTaskGraphs = new ArrayList<>(); - GridScheduler masterScheduler = new GridScheduler(); - - // 1. Activation layer (common to all models) - allTaskGraphs.add(activationLayer.getImmutableTaskGraph()); - activationLayer.updateGridScheduler(masterScheduler); - - // 2. FFN layers (N transformer layers - model-specific) - allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs()); - ffnLayers.updateGridScheduler(masterScheduler); - - // 3. Logits layer (common to all models) - allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot()); - logitsLayer.updateGridScheduler(masterScheduler); - - // Cache for future retrievals - this.cachedTaskGraphs = allTaskGraphs; - this.cachedScheduler = masterScheduler; - } - - /** - * Returns cached task graphs (used by hardware strategy pattern). 
- * - * Removed from all model-specific planners - centralized here. - */ - public final List getImmutableTaskGraphs() { - return this.cachedTaskGraphs; - } - - /** - * Returns cached scheduler (used by hardware strategy pattern). - * - * Removed from all model-specific planners - centralized here. - */ - @Override - public final GridScheduler getGridScheduler() { - return this.cachedScheduler; - } - -} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java index 3b0620c6..1a73897e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java @@ -5,61 +5,77 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; import java.util.List; +import java.util.stream.IntStream; /** * Abstract base class for all FFN (Feed-Forward Network) layer implementations. - * - * Extends AbstractLayer and adds FFN-specific methods: - getFfnLayerTaskGraphs(): Returns task graphs for all transformer layers - getLastTaskGraphID(): Tracks the ID of the last task graph - * - * All model-specific FFN layers extend this: - LlamaFP16FFNLayers, Qwen2FP16FFNLayers, Qwen3FP16FFNLayers, Phi3FP16FFNLayers - LlamaQ8_0FFNLayers, Qwen2Q8_0FFNLayers, Qwen3Q8_0FFNLayers, - * Phi3Q8_0FFNLayers - * - * Used by FP16LayerPlanner and Q8_0LayerPlanner template methods for type-safe polymorphic access to any FFN layer implementation. + * Extended by model- and quantization-specific subclasses that provide the concrete implementations. */ -public abstract class AbstractFFNLayers extends AbstractLayer { +public abstract class AbstractFFNLayers<W extends Weights, C extends Configuration> extends AbstractLayer { - protected String lastTaskGraphID; + /** + * List of TornadoVM {@link ImmutableTaskGraph}s, one per FFN layer. + * Built by {@link #setupFFNLayers()}. + */ + private List<ImmutableTaskGraph> ffnLayerITGs; + protected final W weights; + protected final C config; + + protected String lastFFNLayerTaskGraphID; protected final SchedulerType schedulerType; + protected AbstractFFNLayers(String taskGraphName, State state, W weights, C config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config); + this.weights = weights; + this.config = config; + this.schedulerType = schedulerType; + // ffnLayerITGs is initialized in subclasses + // due to model-specific values (e.g., in Qwen3) + } /** - * Constructor for FFN layers. - * - * @param taskGraphName - * Name for the task graph - * @param state - * Runtime state (LlamaState, Qwen2State, etc.) - * @param weights - * Model weights (FP16Weights, Q8_0Weights, etc.) - * @param config - * Model configuration + * Creates the {@link ImmutableTaskGraph} list for each FFN layer. */ - protected AbstractFFNLayers(String taskGraphName, State state, Weights weights, Configuration config, SchedulerType schedulerType) { - super(taskGraphName, state, weights, config); - this.schedulerType = schedulerType; + protected void setupFFNLayers() { + int numLayers = config.numberOfLayers(); + + this.ffnLayerITGs = IntStream.range(0, numLayers) + .mapToObj(this::setupFFNLayer) + .toList(); } /** - * Returns all task graphs for the FFN layers. 
- * - * For a model with N transformer layers, this returns N ImmutableTaskGraphs, one for each layer (containing RMSNorm, Attention, FFN computations). - * - * @return List of immutable task graphs (one per transformer layer) + * Creates the TaskGraph for a specific FFN layer and produces the {@link ImmutableTaskGraph}. + * In addition, it stores the TaskGraph ID of the last FFN layer for use by the {@link AbstractLogitsLayer}. */ - public abstract List getFfnLayerTaskGraphs(); + private ImmutableTaskGraph setupFFNLayer(int layerIndex) { + TaskGraph tg = createFFNLayerTaskGraph(layerIndex); + + if (layerIndex == config.numberOfLayers() - 1) { + lastFFNLayerTaskGraphID = tg.getTaskGraphName(); + } + + return tg.snapshot(); + } /** - * Get the ID of the last task graph. - * - * Used by LogitsLayer to know where to attach the final logits computation. The last transformer layer's task graph ID is needed to chain the logits computation after all FFN layers complete. - * - * @return Task graph ID of the last FFN layer + * Model and quantization-specific implementation of the FFN layer task graph. + */ + protected abstract TaskGraph createFFNLayerTaskGraph(int layerIndex); + + public List getFFNLayerImmutableTaskGraphs() { + return ffnLayerITGs; + } + + /** + * Returns the TaskGraph ID of the last FFN layer. + * Used by the logits layer to chain its consumeFromDevice call. */ - @Override - public String getLastTaskGraphID() { - return lastTaskGraphID; + public String getLastFFNLayerTaskGraphID() { + return lastFFNLayerTaskGraphID; } /** diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java index 6578777f..f34f5777 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java @@ -4,40 +4,32 @@ import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.Configuration; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.TaskGraph; -import java.util.ArrayList; -import java.util.List; - /** - * Minimal base with common fields/utilities so subclasses compile cleanly. Adjust or remove fields if they already exist in your project. + * Abstract base class for Activations, FFN Layers, and Logits. */ public abstract class AbstractLayer { /** Common constants used in tasks & worker-grid sizing. */ protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32; protected static final int THREAD_SCALE_FOR_LOGITS = 8; - protected static String lastTaskGraphID; + protected final Weights weights; protected final Configuration config; - /** Often a small context/config buffer passed into kernels. Use your real type if available. */ + protected final State state; protected final KernelContext context = new KernelContext(); - /** Collected snapshots for scheduling / debugging. */ - protected final List taskGraphs = new ArrayList<>(); - /** Optional: track the "main" task graph for the layer if one exists. */ - protected TaskGraph taskGraph; - /** Shared runtime objects (exposed because kernels expect them). 
*/ - protected State state; protected AbstractLayer(String taskGraphName, State state, Weights weights, Configuration config) { - this.taskGraph = null; this.state = state; this.weights = weights; this.config = config; } + /** + * Ensures weights are of the expected type. + */ @SuppressWarnings("unchecked") protected static <T> T requireWeightsType(Object weights, Class<T> expectedType, String layerName, String layout) { if (expectedType.isInstance(weights)) { @@ -48,26 +40,8 @@ protected static <T> T requireWeightsType(Object weights, Class<T> expectedType, public abstract GridScheduler updateGridScheduler(GridScheduler scheduler); - public abstract GridScheduler getGridScheduler(); - - public abstract TaskGraph getTaskGraph(); - - public abstract ImmutableTaskGraph getImmutableTaskGraph(); - - /** Allow subclasses to override if they need custom transfers. */ + /** Allow subclasses to override if they need custom data transfers. */ protected TaskGraph configureLayerDataTransfers(TaskGraph tg, int layerIndex) { return tg; } - - public String getLastTaskGraphID() { - return lastTaskGraphID; - } - - public void setupLastID(String taskGraphID) { - if (lastTaskGraphID == null) { - lastTaskGraphID = taskGraphID; - } else if (!lastTaskGraphID.equals(taskGraphID)) { - throw new IllegalStateException("Task graph IDs do not match: " + lastTaskGraphID + " vs " + taskGraphID); - } - } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java new file mode 100644 index 00000000..37288e5f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java @@ -0,0 +1,44 @@ +package org.beehive.gpullama3.tornadovm.layers; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Abstract base for all logits layers (final vocabulary projection step). + * + * Holds the shared fields and calls the protected setupLogitsTaskGraph() hook once + * during construction. Subclasses implement setupLogitsTaskGraph() to define the + * quantization-specific task sequence; Granite variants override it to swap in + * their scaled kernel. 
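+ * <p> + * A minimal override sketch (illustrative only; the real task sequences live in the concrete subclasses such as LogitsFP16Layer): + * <pre>{@code + * protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { + *     TaskGraph logits = new TaskGraph("logits"); + *     logits.consumeFromDevice(lastTaskGraphID, state.wrapX); // chain after the last FFN layer + *     // ... rms_reduce / rms_apply / vocab_proj tasks ... + *     logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + *     return logits; + * } + * }</pre>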
+ */ +public abstract class AbstractLogitsLayer extends AbstractLayer { + + protected final String lastTaskGraphID; + protected final SchedulerType schedulerType; + private final TaskGraph logitsTaskGraph; + + protected AbstractLogitsLayer(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config); + this.lastTaskGraphID = lastTaskGraphID; + this.schedulerType = schedulerType; + TornadoWeights tornadoWeights = requireWeightsType(weights, TornadoWeights.class, + getClass().getSimpleName(), "TornadoTensor"); + this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); + } + + protected abstract TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config); + + public final TaskGraph getTaskGraph() { + return logitsTaskGraph; + } + + public final ImmutableTaskGraph getImmutableTaskGraph() { + return logitsTaskGraph.snapshot(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java index e7822786..6ba36f39 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java @@ -7,62 +7,49 @@ import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; -import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; public class Activation extends AbstractLayer { - private final TaskGraph activationUpdate; + private final TaskGraph activationTaskGraph; - public Activation(String taskGraphHandle, State state, Weights weights, Configuration config) { - super(taskGraphHandle, state, weights, config); - - KernelContext kernelContext = new KernelContext(); + public Activation(String name, State state, Weights weights, Configuration config) { + super(name, state, weights, config); + this.activationTaskGraph = setupActivationTaskGraph(name); + } - // @formatter:off - switch (config.quantization()) { - case "FP16" -> { - this.activationUpdate = new TaskGraph(taskGraphHandle) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", TransformerComputeKernels::convertFP16toFP32, kernelContext, (HalfFloatArray) state.embeddingX, state.wrapX) - .persistOnDevice(state.wrapX); - } - case "Q8_0" -> { - this.activationUpdate = new TaskGraph(taskGraphHandle) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", TransformerComputeKernels::convertQ8_0toFP32, kernelContext, (ByteArray) state.embeddingX, state.wrapX) - .persistOnDevice(state.wrapX); - } + // @formatter:off + protected TaskGraph setupActivationTaskGraph(String name) { + return switch (config.quantization()) { + case "FP16" -> new TaskGraph(name) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", TransformerComputeKernels::convertFP16toFP32, context, (HalfFloatArray) state.embeddingX, state.wrapX) + .persistOnDevice(state.wrapX); + case "Q8_0" -> new TaskGraph(name) + 
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", TransformerComputeKernels::convertQ8_0toFP32, context, (ByteArray) state.embeddingX, state.wrapX) + .persistOnDevice(state.wrapX); default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization()); - } - // @formatter:on + }; } + // @formatter:on @Override public GridScheduler updateGridScheduler(GridScheduler scheduler) { - WorkerGrid worker = new WorkerGrid1D(config.dim()); - worker.setLocalWork(128, 1, 1); - scheduler.addWorkerGrid("activationUpdate.updateX", worker); + WorkerGrid worker = WorkerGridFactory.genericWorker(config.dim(), 128); + scheduler.addWorkerGrid(activationTaskGraph.getTaskGraphName() + ".updateX", worker); return scheduler; } - @Override - public GridScheduler getGridScheduler() { - return null; - } - - @Override public TaskGraph getTaskGraph() { - return activationUpdate; + return activationTaskGraph; } - @Override public ImmutableTaskGraph getImmutableTaskGraph() { - return activationUpdate.snapshot(); + return activationTaskGraph.snapshot(); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java index 20002dac..61bfa573 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java @@ -4,65 +4,37 @@ import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid; -import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +/** + * Granite-specific activation: applies an embedding scale factor during the FP32 conversion. + * Overrides only the task graph builder; all other behaviour is inherited from Activation. + */ public class ActivationGranite extends Activation { - private final TaskGraph activationUpdate; // Granite is a special case where activation X is scaled by embedding scale float value that inside model. 
public ActivationGranite(String taskGraphHandle, State state, Weights weights, GraniteConfiguration config) { super(taskGraphHandle, state, weights, config); - - KernelContext kernelContext = new KernelContext(); - - // @formatter:off - switch (config.quantization()) { - case "FP16" -> { - this.activationUpdate = new TaskGraph(taskGraphHandle) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", GraniteKernels::convertFP16toFP32withGraniteScale, kernelContext, (HalfFloatArray) state.embeddingX, state.wrapX, config.embeddingScale()) - .persistOnDevice(state.wrapX); - } - case "Q8_0" -> { - this.activationUpdate = new TaskGraph(taskGraphHandle) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", GraniteKernels::convertQ8_0toFP32withGraniteScale, kernelContext, (ByteArray) state.embeddingX, state.wrapX, config.embeddingScale()) - .persistOnDevice(state.wrapX); - } - default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization()); - } - // @formatter:on - } - - @Override - public GridScheduler updateGridScheduler(GridScheduler scheduler) { - WorkerGrid worker = new WorkerGrid1D(config.dim()); - worker.setLocalWork(128, 1, 1); - scheduler.addWorkerGrid("activationUpdate.updateX", worker); - return scheduler; - } - - @Override - public GridScheduler getGridScheduler() { - return null; } + // @formatter:off @Override - public TaskGraph getTaskGraph() { - return activationUpdate; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return activationUpdate.snapshot(); + protected TaskGraph setupActivationTaskGraph(String handle) { + GraniteConfiguration cfg = (GraniteConfiguration) config; + return switch (config.quantization()) { + case "FP16" -> new TaskGraph(handle) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", GraniteKernels::convertFP16toFP32withGraniteScale, context, (HalfFloatArray) state.embeddingX, state.wrapX, cfg.embeddingScale()) + .persistOnDevice(state.wrapX); + case "Q8_0" -> new TaskGraph(handle) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", GraniteKernels::convertQ8_0toFP32withGraniteScale, context, (ByteArray) state.embeddingX, state.wrapX, cfg.embeddingScale()) + .persistOnDevice(state.wrapX); + default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization()); + }; } - + // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java index 0387ae54..80f4a6ea 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java @@ -11,23 +11,15 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.List; -import java.util.stream.IntStream; +public class GraniteFP16FFNLayers extends AbstractFFNLayers { -public class GraniteFP16FFNLayers extends AbstractFFNLayers { - - 
TaskGraph ffnTaskGraphs; - GridScheduler scheduler; - List ffnLayerTaskGraphs; - - public GraniteFP16FFNLayers(String taskGraph, State state, Weights weights, GraniteConfiguration config, SchedulerType schedulerType) { + public GraniteFP16FFNLayers(String taskGraph, State state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); - this.ffnLayerTaskGraphs = setupFFNLayered(); + setupFFNLayers(); } @Override @@ -64,35 +56,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnTaskGraphs; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List getFfnLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - List setupFFNLayered() { - return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { - var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i); - if (i == config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); - } - return ffnLayer.snapshot(); - }).toList(); - } - // @formatter:off /** * Transformer Layer Task Flow (LlamaFP16FFNLayers) @@ -179,7 +142,8 @@ List setupFFNLayered() { * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel) * */ - TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguration config, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 8d105e89..56f2c0c3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -1,32 +1,23 @@ package org.beehive.gpullama3.tornadovm.layers.type.fp16; import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.List; -import java.util.stream.IntStream; +public class LlamaFP16FFNLayers extends AbstractFFNLayers { -public class LlamaFP16FFNLayers extends AbstractFFNLayers { - - TaskGraph ffnTaskGraphs; - GridScheduler scheduler; - List ffnLayerTaskGraphs; - - public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, 
Configuration config, SchedulerType schedulerType) { + public LlamaFP16FFNLayers(String taskGraph, State state, LlamaTornadoWeights weights, LlamaConfiguration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); - this.ffnLayerTaskGraphs = setupFFNLayered(); + setupFFNLayers(); } @Override @@ -63,35 +54,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnTaskGraphs; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List getFfnLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - List setupFFNLayered() { - return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { - var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); - if (i == config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); - } - return ffnLayer.snapshot(); - }).toList(); - } - // @formatter:off /** * Transformer Layer Task Flow (LlamaFP16FFNLayers) @@ -178,7 +140,8 @@ List setupFFNLayered() { * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel) * */ - TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index d2a81407..bf938a0d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -2,39 +2,29 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class LogitsFP16Layer extends AbstractLayer { - - private String lastTaskGraphID; - private TaskGraph logitsTaskGraph; - private ImmutableTaskGraph immutableLogitsGraph; - private GridScheduler scheduler; - private SchedulerType schedulerType; +public class LogitsFP16Layer extends AbstractLogitsLayer { - public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String 
lastTaskGraphID, SchedulerType schedulerType) { - super(name, state, weights, config); - this.lastTaskGraphID = lastTaskGraphID; - this.schedulerType = schedulerType; - var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); - this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); + public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); } // @formatter:off - private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { + @Override + protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { var logits = new TaskGraph("logits"); // === Data Setup === logits.consumeFromDevice(lastTaskGraphID, state.wrapX); @@ -96,7 +86,7 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con @Override public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256); + var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), rmsLocalSize()); var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); @@ -106,18 +96,8 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return logitsTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return immutableLogitsGraph; + /** Local workgroup size for RMS norm. Qwen2 requires a smaller group (32 vs 256). */ + protected int rmsLocalSize() { + return weights instanceof Qwen2TornadoWeights ? 
32 : 256; } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java index d55d707e..54ec9641 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java @@ -2,51 +2,40 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid; -import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; +/** + * Granite-specific FP16 logits layer. + * Identical to LogitsFP16Layer except vocab_proj uses a scaled kernel (logitScale). + */ public class LogitsGraniteFP16Layer extends LogitsFP16Layer { - private String lastTaskGraphID; - private TaskGraph logitsTaskGraph; - private ImmutableTaskGraph immutableLogitsGraph; - private GridScheduler scheduler; - private SchedulerType schedulerType; - public LogitsGraniteFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { + public LogitsGraniteFP16Layer(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { super(name, state, weights, config, lastTaskGraphID, schedulerType); - this.lastTaskGraphID = lastTaskGraphID; - this.schedulerType = schedulerType; - var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); - this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config); } // @formatter:off - private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) { + @Override + protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { + GraniteConfiguration graniteCfg = (GraniteConfiguration) config; var logits = new TaskGraph("logits"); + // === Data Setup === logits.consumeFromDevice(lastTaskGraphID, state.wrapX); logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, - // Kernel context context, - // Output buffer state.wrapLogits, - // Intermediate FP16 buffer state.wrapXbFP16, - // Weights weights.wclsByteArray.asHalfFloatArray(), weights.rms_final_weight_as_floatArray.asFloatArray()); @@ -54,72 +43,43 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfigurat logits.task("rms_reduce", TransformerComputeKernels::reductionOneBlockWithLayer, context, - state.tempLogits, // output: partial 
sums + final scale factor - state.wrapX, // input: hidden state - config.dim(), // dimension - config.rmsNormEps(), // epsilon for numerical stability - state.localSize); // local workgroup size + state.tempLogits, + state.wrapX, + config.dim(), + config.rmsNormEps(), + state.localSize); if (schedulerType == SchedulerType.NON_NVIDIA) { logits.task("rms_finalize", TransformerComputeKernelsLayered::reductionFinalNormalization, context, - state.tempLogits, // in/out: combines partial sums - config.dim(), // dimension - config.rmsNormEps()); // epsilon + state.tempLogits, + config.dim(), + config.rmsNormEps()); } logits.task("rms_apply_fp16", TransformerComputeKernels::mapContextWithQuantizeLogits, context, - state.wrapXbFP16, // output: normalized (FP16) - state.wrapX, // input: hidden state - weights.rms_final_weight_as_floatArray.asFloatArray(), // RMS weights - state.tempLogits); // scale factor from reduction + state.wrapXbFP16, + state.wrapX, + weights.rms_final_weight_as_floatArray.asFloatArray(), + state.tempLogits); - // === Vocabulary Projection === + // === Vocabulary Projection (Granite: scaled by logitScale) === logits.task("vocab_proj", GraniteKernels::matrixVectorGenericWithGraniteScale, context, - state.wrapXbFP16, // input (FP16) - state.wrapLogits, // output - weights.wclsByteArray.asHalfFloatArray(), // vocabulary weights - config.dim(), // input dimension - config.vocabularySize(), // output dimension + state.wrapXbFP16, + state.wrapLogits, + weights.wclsByteArray.asHalfFloatArray(), + config.dim(), + config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, - config.logitScale()); // granite logit scaling + graniteCfg.logitScale()); - // === Transfer Results to Host === logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return logits; } // @formatter:on - - @Override - public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 
32 : 256); - var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); - vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); - tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS); - tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); - return tornadoForwardScheduler; - } - - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return logitsTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return immutableLogitsGraph; - } } - diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java new file mode 100644 index 00000000..499fc176 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java @@ -0,0 +1,208 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.mistral.MistralConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +public class MistralFP16FFNLayers extends AbstractFFNLayers { + + public MistralFP16FFNLayers(String taskGraph, State state, LlamaTornadoWeights weights, MistralConfiguration config, SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + setupFFNLayers(); + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + + int fusedQKVRows = config.dim() + 2 * config.kvDim(); + int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512); + + // Map workers to tasks + for (int i = 0; i < config.numberOfLayers(); i++) { + // === Attention Block === + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", 
rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + // === FFN Block === + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); + } + return tornadoForwardScheduler; + } + + // @formatter:off + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + // === Data Setup === + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + // === Attention Block === + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply_fp16", + TransformerComputeKernels::mapContextWithQuantize, + context, state.wrapXbFP16, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulX, + context, + state.wrapXbFP16, + state.wrapQ, + state.wrapK, + state.wrapV, + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + config.dim(), + config.kvDim(), + LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("rope_and_kv_cache", + TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, + state.wrapK, + state.wrapV, + state.wrapKeyCache, + state.wrapValueCache, + config.kvDim(), + config.headSize(), + layerIndex, + config.contextLength()); + + configureAttention(unifiedLayer, layerIndex); + + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asHalfFloatArray(), + config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + // === FFN Block === + unifiedLayer.task("ffn_rms_reduce", + 
TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, + context, + state.wrapX, + state.wrapHb, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + state.tempFFN, + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), + config.dim(), + config.hiddenDim(), + LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asHalfFloatArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + if (layerIndex == 0) { + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, + state.temp, state.tempFFN); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, state.wrapXbFP16); + } else { + unifiedLayer.consumeFromDevice( + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder, state.wrapXbFP16); + } + return unifiedLayer; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { + if (schedulerType == SchedulerType.NVIDIA) { + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), + state.positionHolder, layerIndex, config.contextLength()); + } else { + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsParallel, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), config.contextLength(), + state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + } + } + // @formatter:on +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index a98f9860..065c6802 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -9,14 +9,10 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import 
java.util.ArrayList; -import java.util.List; - - /** * Phi3FP16FFNLayers: FP16 FFN layers for Phi3 with Group Query Attention (GQA) support. * @@ -26,23 +22,18 @@ * * Works directly with Phi3State to access and mutate Phi3-specific state fields. */ -public class Phi3FP16FFNLayers extends AbstractFFNLayers { +public class Phi3FP16FFNLayers extends AbstractFFNLayers { // Typed references to Phi3-specific state and config private final Phi3State phi3State; - private final Phi3Configuration phi3Config; // Phi3-specific dimension for combined QKV buffer private final int opSize; - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; - List<ImmutableTaskGraph> ffnLayerTaskGraphs; public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); this.phi3State = state; - this.phi3Config = config; this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); - ffnLayerTaskGraphs = setupFFNLayered(); + setupFFNLayers(); } @Override @@ -88,40 +79,6 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { return gridScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnLayerTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - /** - * Setup all FFN layers for all transformer layers - */ - List<ImmutableTaskGraph> setupFFNLayered() { - List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>(); - for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) { - TaskGraph ffnLayer = setupSinglePhi3FFNLayer((Phi3TornadoWeights) weights, layerIndex); - if (layerIndex == phi3Config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); - } - ffnGraphs.add(ffnLayer.snapshot()); - } - return ffnGraphs; - } - // @formatter:off /** * Transformer Layer Task Flow (Phi3FP16FFNLayers - Fully Optimized) @@ -207,9 +164,12 @@ List<ImmutableTaskGraph> setupFFNLayered() { * • Inline SiLU+GLU: No intermediate wrapHb buffer needed * */ - TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { + Phi3TornadoWeights weights = (Phi3TornadoWeights) this.weights; var taskGraphName = "layer_" + layerIndex; var unifiedLayer = new TaskGraph(taskGraphName); + unifiedLayer.consumeFromDevice(phi3State.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Attention weights @@ -230,8 +190,8 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { unifiedLayer.task("attn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, phi3State.temp, // output: scale factor phi3State.wrapX, // input: hidden state - phi3Config.dim(), // dimension - phi3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size if (shouldUseFinalNormalization()) { @@ -246,8 +206,8 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { phi3State.wrapK, // output K phi3State.wrapV, // output V weights.rms_att_weightLayered[layerIndex].asFloatArray(), phi3State.temp, // RMS scale - weights.wqkvLayered[layerIndex].asHalfFloatArray(), phi3Config.dim(), // dim - phi3Config.kvDim(), // kvDim +
weights.wqkvLayered[layerIndex].asHalfFloatArray(), config.dim(), // dim + config.kvDim(), // kvDim LOCAL_WORK_GROUP_SIZE_ALLOC); // Fused Phi3 RoPE Rotation + KV Cache Write @@ -258,11 +218,11 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { phi3State.wrapV, // V vectors (in only) phi3State.wrapKeyCache, // key cache (out) phi3State.wrapValueCache, // value cache (out) - phi3Config.numberOfKeyValueHeads(), // nHeadKv - phi3Config.headSize(), // head dimension - phi3Config.kvDim(), // kvDim + config.numberOfKeyValueHeads(), // nHeadKv + config.headSize(), // head dimension + config.kvDim(), // kvDim layerIndex, // layer index for cache offset - phi3Config.contextLength()); // max sequence length + config.contextLength()); // max sequence length // Flash Attention unifiedLayer.task("attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, @@ -270,21 +230,21 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { phi3State.wrapKeyCache, // key cache phi3State.wrapValueCache, // value cache phi3State.wrapXb, // output: attention result - phi3Config.numberOfHeads(), // nHeads - phi3Config.headSize(), // headSize - phi3Config.kvDim(), // kvDim - phi3Config.kvMul(), // kvMul (nHeads / nHeadKv) + config.numberOfHeads(), // nHeads + config.headSize(), // headSize + config.kvDim(), // kvDim + config.kvMul(), // kvMul (nHeads / nHeadKv) phi3State.positionHolder, // position layerIndex, // layer index - phi3Config.contextLength()); // context length + config.contextLength()); // context length // Output Projection with Residual unifiedLayer.task("attn_output_proj", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, phi3State.wrapXb, // input: attention output phi3State.wrapX, // output: wrapX += Wo · wrapXb weights.woLayered[layerIndex].asHalfFloatArray(), // Wo [dim × dim] - phi3Config.dim(), // input dim - phi3Config.dim(), // output dim + config.dim(), // input dim + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); // ═══════════════════════════════════════════════════════════════════════ @@ -295,24 +255,24 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { unifiedLayer.task("ffn_rms_reduce", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, phi3State.tempFFN, // output: scale factor phi3State.wrapX, // input: hidden state - phi3Config.dim(), // dimension - phi3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size // Final normalization (non-NVIDIA only) if (shouldUseFinalNormalization()) { unifiedLayer.task("ffn_rms_finalize", TransformerComputeKernelsLayered::reductionFinalNormalization, context, phi3State.tempFFN, // scale factor (in/out) - phi3Config.dim(), // dimension - phi3Config.rmsNormEps()); // epsilon + config.dim(), // dimension + config.rmsNormEps()); // epsilon } unifiedLayer.task("rms_ffn_silu", Phi3Kernels::fusedRmsNormFFNGateUpSiLU, context, phi3State.wrapX, // input phi3State.wrapHbU, // output (direct to final FFN buffer) weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), phi3State.tempFFN, // RMS scale - weights.wUpLayered[layerIndex].asHalfFloatArray(), phi3Config.dim(), // input dim - phi3Config.hiddenDim(), // output dim (hiddenDim, not 2×hiddenDim!) + weights.wUpLayered[layerIndex].asHalfFloatArray(), config.dim(), // input dim + config.hiddenDim(), // output dim (hiddenDim, not 2×hiddenDim!) 
LOCAL_WORK_GROUP_SIZE_ALLOC); // Down Projection with Residual @@ -320,8 +280,8 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { context, phi3State.wrapHbU, // input: FFN intermediate phi3State.wrapX, // output: wrapX += wDown · wrapHbU weights.wDownLayered[layerIndex].asHalfFloatArray(), // wDown [dim × hiddenDim] - phi3Config.hiddenDim(), // input dim - phi3Config.dim(), // output dim + config.hiddenDim(), // input dim + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); unifiedLayer.persistOnDevice(phi3State.wrapX); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index a6f1c95c..d6ac0eee 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -10,16 +10,12 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.WorkerGrid2D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.ArrayList; -import java.util.List; - /** * Qwen2FP16FFNLayers: FP16 FFN layers for Qwen2 with Group Query Attention (GQA) support. * @@ -29,20 +25,15 @@ * * Works directly with Qwen2State to access and mutate Qwen2-specific state fields. */ -public class Qwen2FP16FFNLayers extends AbstractFFNLayers { +public class Qwen2FP16FFNLayers extends AbstractFFNLayers { - // Typed references to Qwen2-specific state and config + // Typed reference to Qwen2-specific state private final Qwen2State qwen2State; - private final Qwen2Configuration qwen2Config; - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; - List<ImmutableTaskGraph> ffnLayerTaskGraphs; public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeights weights, Qwen2Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); this.qwen2State = state; - this.qwen2Config = config; - ffnLayerTaskGraphs = setupFFNLayered(); + setupFFNLayers(); } @Override @@ -116,37 +107,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnLayerTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - List<ImmutableTaskGraph> setupFFNLayered() { - List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>(qwen2Config.numberOfLayers()); - for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) { - TaskGraph ffnLayer = setupSingleQwen2FFNLayer((Qwen2TornadoWeights) weights, layerIndex); - if (layerIndex == qwen2Config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); - } - ffnGraphs.add(ffnLayer.snapshot()); - } - return ffnGraphs; - } - // @formatter:off /** * Transformer Layer Task Flow (Qwen2FP16FFNLayers - Optimized) * @@ -235,9 +195,11 @@ List<ImmutableTaskGraph> setupFFNLayered() { * • No Q/K RMSNorm: Unlike
Qwen3, Qwen2 doesn't normalize Q/K after projection * */ - TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var taskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(taskGraphName); + unifiedLayer.consumeFromDevice(state.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Attention weights diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index b99a1ab3..60b3cc3e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -9,14 +9,10 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.List; -import java.util.stream.IntStream; - /** * Qwen3FP16FFNLayers: FP16 FFN layers for Qwen3 with Group Query Attention (GQA) support. * @@ -25,11 +21,10 @@ * * Works directly with Qwen3State to access and mutate Qwen3-specific state fields like tempQcur and tempKcur. */ -public class Qwen3FP16FFNLayers extends AbstractFFNLayers { +public class Qwen3FP16FFNLayers extends AbstractFFNLayers { - // Typed references to Qwen3-specific state and config + // Typed reference to Qwen3-specific state private final Qwen3State qwen3State; - private final Qwen3Configuration qwen3Config; // Qwen3-specific GQA parameters private final int nHeadKv; private final int nEmbdHeadK; @@ -38,14 +33,10 @@ public class Qwen3FP16FFNLayers extends AbstractFFNLayers { private final int nEmbdHead; private final int nEmbdGqa; private final int gqa; - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; - List<ImmutableTaskGraph> ffnLayerTaskGraphs; public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); this.qwen3State = state; - this.qwen3Config = config; // Initialize GQA parameters from Qwen3Config this.nHeadKv = config.numberOfKeyValueHeads(); @@ -55,7 +46,7 @@ public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe this.nEmbdHead = nEmbdHeadV; this.nEmbdGqa = nEmbdVGqa; this.gqa = config.numberOfHeads() / config.numberOfKeyValueHeads(); - ffnLayerTaskGraphs = setupFFNLayered(); + setupFFNLayers(); } @Override @@ -76,7 +67,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { int qkRmsNormGroups = config.numberOfHeads() + config.numberOfKeyValueHeads(); WorkerGrid qkRmsNormWorker = WorkerGridFactory.genericWorker(qkRmsNormGroups * nEmbdHead, nEmbdHead); - int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads(); + int qDim0 = nEmbdHeadK * config.numberOfHeads(); int kvDim0 = nEmbdGqa; int fusedQKVRows = qDim0 + 2 * kvDim0; // Q rows + K rows + V rows int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; @@ -102,38 +93,6 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { return gridScheduler; } - @Override - public GridScheduler
getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnLayerTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - /** - * Setup all FFN layers for all transformer layers - */ - List<ImmutableTaskGraph> setupFFNLayered() { - return IntStream.range(0, qwen3Config.numberOfLayers()).mapToObj(i -> { - var ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, i); - if (i == qwen3Config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); - } - return ffnLayer.snapshot(); - }).toList(); - } - // @formatter:off /** * Transformer Layer Task Flow (Qwen3FP16FFNLayers) @@ -221,13 +180,14 @@ List<ImmutableTaskGraph> setupFFNLayered() { * • RoPE theta: 1,000,000 (vs Llama's 10,000 or 50,000) * */ - TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var taskGraphName = "layer_" + layerIndex; // === Dimension Parameters === - int qDim = nEmbdHeadK * qwen3Config.numberOfHeads(); // Q output size (full heads) + int qDim = nEmbdHeadK * config.numberOfHeads(); // Q output size (full heads) int kvDim = nEmbdGqa; // K/V output size (reduced for GQA) - int inputDim = qwen3Config.dim(); // Model dimension + int inputDim = config.dim(); // Model dimension var unifiedLayer = new TaskGraph(taskGraphName); @@ -260,8 +220,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) context, qwen3State.temp, // output: scale factor qwen3State.wrapX, // input: hidden state - qwen3Config.dim(), // dimension - qwen3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon qwen3State.localSize); // local memory size if (shouldUseFinalNormalization()) { @@ -299,11 +259,11 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapK, // K vectors (in/out) weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // Q norm weights weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // K norm weights - qwen3Config.numberOfHeads(), // nHeads (Q heads) - qwen3Config.numberOfKeyValueHeads(), // nHeadKv (K/V heads, GQA) + config.numberOfHeads(), // nHeads (Q heads) + config.numberOfKeyValueHeads(), // nHeadKv (K/V heads, GQA) nEmbdHead, // head dimension nEmbdHead, // local memory size - qwen3Config.rmsNormEps()); // epsilon + config.rmsNormEps()); // epsilon // Fused RoPE Rotation + KV Cache Write unifiedLayer.task("rope_and_kv_cache", @@ -315,11 +275,11 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapQ, // Q vectors (in/out) qwen3State.wrapK, // K vectors (in/out) qwen3State.wrapV, // V vectors (in only) qwen3State.wrapKeyCache, // key cache (out) qwen3State.wrapValueCache, // value cache (out) - qwen3Config.numberOfKeyValueHeads(), // nHeadKv + config.numberOfKeyValueHeads(), // nHeadKv nEmbdHead, // head dimension nEmbdGqa, // kvDim layerIndex, // layer index for cache offset - qwen3Config.contextLength()); // max sequence length + config.contextLength()); // max sequence length // Flash Attention unifiedLayer.task("attention", @@ -329,13 +289,13 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapKeyCache, // key cache qwen3State.wrapValueCache, // value cache qwen3State.wrapXb, // output: attention result - qwen3Config.numberOfHeads(), // nHeads + config.numberOfHeads(), // nHeads nEmbdHead, // headSize nEmbdGqa, // kvDim gqa, // kvMul
(nHeads / nHeadKv) qwen3State.positionHolder, // position layerIndex, // layer index - qwen3Config.contextLength()); // context length + config.contextLength()); // context length // Output Projection with Residual unifiedLayer.task("attn_output_proj", @@ -344,8 +304,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapXb, // input: attention output qwen3State.wrapX, // output: wrapX += Wo · wrapXb weights.woLayered[layerIndex].asHalfFloatArray(), // Wo [dim x qDim] - nEmbdHeadK * qwen3Config.numberOfHeads(), // input dim (qDim) - qwen3Config.dim(), // output dim + nEmbdHeadK * config.numberOfHeads(), // input dim (qDim) + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); // ═══════════════════════════════════════════════════════════════════════ @@ -358,8 +318,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) context, qwen3State.tempFFN, // output: scale factor qwen3State.wrapX, // input: hidden state - qwen3Config.dim(), // dimension - qwen3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon qwen3State.localSize); // local memory size // Final normalization (non-NVIDIA only) @@ -368,8 +328,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) TransformerComputeKernelsLayered::reductionFinalNormalization, context, qwen3State.tempFFN, // scale factor (in/out) - qwen3Config.dim(), // dimension - qwen3Config.rmsNormEps()); // epsilon + config.dim(), // dimension + config.rmsNormEps()); // epsilon } // Fused RMS Apply + Gate/Up Projection + SiLU + GLU @@ -382,8 +342,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.tempFFN, // RMS scale factor weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 (gate) weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 (up) - qwen3Config.dim(), // input dimension - qwen3Config.hiddenDim(), // hidden dimension + config.dim(), // input dimension + config.hiddenDim(), // hidden dimension LOCAL_WORK_GROUP_SIZE_ALLOC); // Down Projection with Residual @@ -393,8 +353,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapHb, // input: FFN intermediate qwen3State.wrapX, // output: wrapX += W2 · wrapHb weights.w2Layered[layerIndex].asHalfFloatArray(), // W2 (down) - qwen3Config.hiddenDim(), // input dim - qwen3Config.dim(), // output dim + config.hiddenDim(), // input dim + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC) .persistOnDevice(qwen3State.wrapX); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java index b7e036d4..8a91c75a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java @@ -1,11 +1,7 @@ package org.beehive.gpullama3.tornadovm.layers.type.q8_0; import org.beehive.gpullama3.inference.state.GraniteState; -import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.model.granite.Granite; import org.beehive.gpullama3.model.granite.GraniteConfiguration; import 
org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; @@ -13,47 +9,15 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.List; -import java.util.stream.IntStream; - -public class GraniteQ8_0FFNLayers extends AbstractFFNLayers { - - GridScheduler scheduler; - List<ImmutableTaskGraph> ffnLayerTaskGraphs; +public class GraniteQ8_0FFNLayers extends AbstractFFNLayers { public GraniteQ8_0FFNLayers(String taskGraphName, GraniteState state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); - ffnLayerTaskGraphs = setupFFNLayered(); - } - - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return null; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - List<ImmutableTaskGraph> setupFFNLayered() { - return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { - var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i); - if (i == config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); - } - return ffnLayer.snapshot(); - }).toList(); + setupFFNLayers(); } /** @@ -143,7 +107,8 @@ List<ImmutableTaskGraph> setupFFNLayered() { * Quantization: Q8_0 format (8-bit weights with block-wise scaling) * */ - TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguration config, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); @@ -328,10 +293,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex, GraniteConfiguration config) { if (schedulerType == SchedulerType.NVIDIA) { // Flash Attention (optimized for NVIDIA GPUs) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index ba1b6a79..1f33f090 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -2,53 +2,21 @@ import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import
uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.List; -import java.util.stream.IntStream; +public class LlamaQ8_0FFNLayers extends AbstractFFNLayers { -public class LlamaQ8_0FFNLayers extends AbstractFFNLayers { - - GridScheduler scheduler; - List<ImmutableTaskGraph> ffnLayerTaskGraphs; - - public LlamaQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, Configuration config, SchedulerType schedulerType) { + public LlamaQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, LlamaConfiguration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); - ffnLayerTaskGraphs = setupFFNLayered(); - } - - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return null; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - List<ImmutableTaskGraph> setupFFNLayered() { - return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { - var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); - if (i == config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); - } - return ffnLayer.snapshot(); - }).toList(); + setupFFNLayers(); } // @formatter:off @@ -139,7 +107,8 @@ List<ImmutableTaskGraph> setupFFNLayered() { * Quantization: Q8_0 format (8-bit weights with block-wise scaling) * */ - TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var layerTaskGraphName = "layer_" + layerIndex; TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); @@ -323,10 +292,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { if (schedulerType == SchedulerType.NVIDIA) { return unifiedLayer.task("attention", diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java index d1fd4f0c..c583bb00 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java @@ -2,69 +2,51 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import
uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class LogitsGraniteQ8_0Layer extends LogitsQ8_0Layer{ - private String lastTaskGraphID; - private TaskGraph logitsTaskGraph; - private ImmutableTaskGraph immutableLogitsGraph; - private GridScheduler scheduler; - private SchedulerType schedulerType; +/** + * Granite-specific Q8_0 logits layer. + * Identical to LogitsQ8_0Layer except vocab_proj uses a scaled kernel (logitScale). + */ +public class LogitsGraniteQ8_0Layer extends LogitsQ8_0Layer { - public LogitsGraniteQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { - super(taskGraphName, state, weights, config, lastTaskGraphID, schedulerType); - this.lastTaskGraphID = lastTaskGraphID; - var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsQ8_0Layer", "TornadoTensor"); - this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config); - this.schedulerType = schedulerType; - } - - @Override - public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { - var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256); - var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); - vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); - tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); - tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); - return tornadoForwardScheduler; + public LogitsGraniteQ8_0Layer(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); } // @formatter:off - private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) { + @Override + protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { + GraniteConfiguration graniteCfg = (GraniteConfiguration) config; var logits = new TaskGraph("logits"); + // === Data Setup === logits.consumeFromDevice(lastTaskGraphID, state.wrapX); logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, // - state.wrapLogits, // - weights.wclsByteArray.asByteArray(), // + context, + state.wrapLogits, + weights.wclsByteArray.asByteArray(), weights.rms_final_weight_as_floatArray); // === Final RMS Normalization === logits.task("rms_reduce", TransformerComputeKernels::reductionOneBlockWithLayer, context, - state.tempLogits, // output: partial sums + final scale factor - state.wrapX, // input: hidden state - config.dim(), // dimension - config.rmsNormEps(), // epsilon for numerical stability - state.localSize); // local workgroup size + state.tempLogits, + state.wrapX, + config.dim(), + config.rmsNormEps(), + state.localSize); if (schedulerType == SchedulerType.NON_NVIDIA) { logits.task("rms_finalize", @@ -74,6 +56,7 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfigurat config.dim(), config.rmsNormEps()); } + logits.task("mapContextLogits", 
TransformerComputeKernels::reductionOneBlock2WithLogits, context, @@ -81,8 +64,9 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfigurat weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits); - // === Vocabulary vocab_proj === - logits.task("vocab_proj", GraniteKernels::matrixVectorGenericQ8ByteWithGraniteScale, // + // === Vocabulary Projection (Granite: scaled by logitScale) === + logits.task("vocab_proj", + GraniteKernels::matrixVectorGenericQ8ByteWithGraniteScale, context, state.wrapX, state.wrapLogits, @@ -90,29 +74,10 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfigurat config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, - config.logitScale() + graniteCfg.logitScale()); - ); - - // === Transfer Results to Host === logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return logits; } // @formatter:on - - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return logitsTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return immutableLogitsGraph; - } - } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java index d54bb3ef..ed0d6cfc 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java @@ -2,69 +2,49 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class LogitsQ8_0Layer extends AbstractLayer { - - private String lastTaskGraphID; - private TaskGraph logitsTaskGraph; - private ImmutableTaskGraph immutableLogitsGraph; - private GridScheduler scheduler; - private SchedulerType schedulerType; +public class LogitsQ8_0Layer extends AbstractLogitsLayer { - public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { - super(taskGraphName, state, weights, config); - this.lastTaskGraphID = lastTaskGraphID; - var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsQ8_0Layer", "TornadoTensor"); - this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); - this.schedulerType = schedulerType; - } - - @Override - public GridScheduler 
updateGridScheduler(GridScheduler tornadoForwardScheduler) { - var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256); - var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); - vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); - tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); - tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); - return tornadoForwardScheduler; + public LogitsQ8_0Layer(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); } // @formatter:off - private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { + @Override + protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { var logits = new TaskGraph("logits"); + // === Data Setup === logits.consumeFromDevice(lastTaskGraphID, state.wrapX); logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, // - state.wrapLogits, // - weights.wclsByteArray.asByteArray(), // - weights.rms_final_weight_as_floatArray); + context, + state.wrapLogits, + weights.wclsByteArray.asByteArray(), + weights.rms_final_weight_as_floatArray); // === Final RMS Normalization === logits.task("rms_reduce", TransformerComputeKernels::reductionOneBlockWithLayer, context, - state.tempLogits, // output: partial sums + final scale factor - state.wrapX, // input: hidden state - config.dim(), // dimension - config.rmsNormEps(), // epsilon for numerical stability - state.localSize); // local workgroup size + state.tempLogits, // output: partial sums + final scale factor + state.wrapX, // input: hidden state + config.dim(), + config.rmsNormEps(), + state.localSize); if (schedulerType == SchedulerType.NON_NVIDIA) { logits.task("rms_finalize", @@ -74,42 +54,44 @@ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration con config.dim(), config.rmsNormEps()); } - logits.task("mapContextLogits", - TransformerComputeKernels::reductionOneBlock2WithLogits, - context, - state.wrapX, - weights.rms_final_weight_as_floatArray.asFloatArray(), + + logits.task("mapContextLogits", + TransformerComputeKernels::reductionOneBlock2WithLogits, + context, + state.wrapX, + weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits); - - // === Vocabulary vocab_proj === - logits.task("vocab_proj", TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, // - context, - state.wrapX, - state.wrapLogits, - weights.wclsByteArray.asByteArray(), - config.dim(), - config.vocabularySize(), + + // === Vocabulary Projection === + logits.task("vocab_proj", + TransformerComputeKernelsLayered::matrixVectorGenericQ8Byte, + context, + state.wrapX, + state.wrapLogits, + weights.wclsByteArray.asByteArray(), + config.dim(), + config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); - // === Transfer Results to Host === logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); return logits; } // @formatter:on @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph 
getTaskGraph() { - return logitsTaskGraph; + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), rmsLocalSize()); + var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; + var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); + vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); + tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); + tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); + return tornadoForwardScheduler; } - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return immutableLogitsGraph; + /** Local workgroup size for RMS norm. Qwen2 requires a smaller group (32 vs 256). */ + protected int rmsLocalSize() { + return weights instanceof Qwen2TornadoWeights ? 32 : 256; } - } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java new file mode 100644 index 00000000..64864114 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java @@ -0,0 +1,189 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.mistral.MistralConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +public class MistralQ8_0FFNLayers extends AbstractFFNLayers { + + public MistralQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, MistralConfiguration config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, schedulerType); + setupFFNLayers(); + } + + // @formatter:off + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + // === Data Setup === + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + weights.woLayered[layerIndex].asByteArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asByteArray(), + weights.w2Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + // === Attention Block === + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + config.dim(), config.rmsNormEps(), 
state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, state.wrapXb, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulQ8, + context, + state.wrapXb, + state.wrapQ, state.wrapK, state.wrapV, + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("rope_and_kv_cache", + TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + config.kvDim(), config.headSize(), layerIndex, config.contextLength()); + + configureAttention(unifiedLayer, layerIndex); + + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asByteArray(), + config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + // === FFN Block === + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fullyFusedRmsNormFFNGateUpQ8, + context, + state.wrapX, state.wrapHb, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray(), + config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asByteArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + if (layerIndex == 0) { + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, state.temp, state.tempFFN); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb); + } else { + unifiedLayer.consumeFromDevice( + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder); + } + return unifiedLayer; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + 
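// Reviewer note (assumption, not part of the original patch): these "row-major" workers appear to dedicate one LOCAL_WORK_GROUP_SIZE_ALLOC-thread workgroup to each output row, so the global size is rows * local size; as a purely illustrative example, dim = 4096 with a local size of 32 would give 4096 * 32 = 131072 threads in 4096 workgroups. +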
WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int fusedQkvGlobal = (config.dim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512); + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + + for (int i = 0; i < config.numberOfLayers(); i++) { + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); + } + + return tornadoForwardScheduler; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { + if (schedulerType == SchedulerType.NVIDIA) { + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), + state.positionHolder, layerIndex, config.contextLength()); + } else { + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsParallel, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), config.contextLength(), + state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + } + } + // @formatter:on +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index 46b0737d..8f693adc 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -9,14 +9,10 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.ArrayList; -import java.util.List; - /** * Phi3Q8_0FFNLayers: Q8_0-quantized FFN layers for Phi3 with Group Query Attention 
(GQA) support. * @@ -25,23 +21,18 @@ * * Works directly with Phi3State to access and mutate Phi3-specific state fields. */ -public class Phi3Q8_0FFNLayers extends AbstractFFNLayers { +public class Phi3Q8_0FFNLayers extends AbstractFFNLayers { - // Typed references to Phi3-specific state and config + // Typed reference to Phi3-specific state private final Phi3State phi3State; - private final Phi3Configuration phi3Config; // Phi3-specific dimension for combined QKV buffer private final int opSize; - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; - List<ImmutableTaskGraph> ffnLayerTaskGraphs; public Phi3Q8_0FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); this.phi3State = state; - this.phi3Config = config; this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); - ffnLayerTaskGraphs = setupFFNLayered(); + setupFFNLayers(); } @Override @@ -76,40 +67,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnLayerTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - /** - * Setup all FFN layers for all transformer layers - */ - List<ImmutableTaskGraph> setupFFNLayered() { - List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>(); - for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) { - TaskGraph ffnLayer = setupSinglePhi3Q8_0FFNLayer((Phi3TornadoWeights) weights, layerIndex); - if (layerIndex == phi3Config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); - } - ffnGraphs.add(ffnLayer.snapshot()); - } - return ffnGraphs; - } - // @formatter:off /** * Transformer Layer Task Flow (Phi3Q8_0FFNLayers - Fully Optimized) @@ -207,9 +164,11 @@ List<ImmutableTaskGraph> setupFFNLayered() { * • wrapHbG: Gate output merged into final computation * */ - TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var taskGraphName = "layer_" + layerIndex; var unifiedLayer = new TaskGraph(taskGraphName); + unifiedLayer.consumeFromDevice(phi3State.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Copy-in quantized weights per layer (Q8_0 format: ByteArray) @@ -232,8 +191,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex context, phi3State.temp, // output: scale factor phi3State.wrapX, // input: hidden state - phi3Config.dim(), // dimension - phi3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size if (shouldUseFinalNormalization()) { @@ -256,8 +215,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS weights phi3State.temp, // RMS scale (precomputed) weights.wqkvLayered[layerIndex].asByteArray(), // Q8 combined QKV [opSize × dim] - phi3Config.dim(), // input dim - phi3Config.kvDim(), // K/V output dim + config.dim(), // input dim + config.kvDim(), // K/V output dim LOCAL_WORK_GROUP_SIZE_ALLOC); // Fused Phi3 RoPE Rotation + KV Cache Write @@ -269,11 +228,11 @@ TaskGraph
setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex phi3State.wrapV, // V vectors (in only) phi3State.wrapKeyCache, // key cache (out) phi3State.wrapValueCache, // value cache (out) - phi3Config.numberOfKeyValueHeads(), // nHeadKv - phi3Config.headSize(), // head dimension - phi3Config.kvDim(), // kvDim + config.numberOfKeyValueHeads(), // nHeadKv + config.headSize(), // head dimension + config.kvDim(), // kvDim layerIndex, // layer index for cache offset - phi3Config.contextLength()); // max sequence length + config.contextLength()); // max sequence length // Flash Attention unifiedLayer.task("attention", @@ -283,13 +242,13 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex phi3State.wrapKeyCache, // key cache phi3State.wrapValueCache, // value cache phi3State.wrapXb, // output: attention result - phi3Config.numberOfHeads(), // nHeads - phi3Config.headSize(), // headSize - phi3Config.kvDim(), // kvDim - phi3Config.kvMul(), // kvMul (nHeads / nHeadKv) + config.numberOfHeads(), // nHeads + config.headSize(), // headSize + config.kvDim(), // kvDim + config.kvMul(), // kvMul (nHeads / nHeadKv) phi3State.positionHolder, // position layerIndex, // layer index - phi3Config.contextLength()); // context length + config.contextLength()); // context length // Output Projection with Residual (Q8 dequantization) unifiedLayer.task("attn_output_proj", @@ -298,8 +257,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex phi3State.wrapXb, // input: attention output phi3State.wrapX, // output: wrapX += Wo · wrapXb weights.woLayered[layerIndex].asByteArray(), // Q8 Wo [dim × dim] - phi3Config.dim(), // input dim - phi3Config.dim(), // output dim + config.dim(), // input dim + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); // ═══════════════════════════════════════════════════════════════════════ @@ -312,8 +271,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex context, phi3State.tempFFN, // output: scale factor phi3State.wrapX, // input: hidden state - phi3Config.dim(), // dimension - phi3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size // Final normalization (non-NVIDIA only) @@ -322,8 +281,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex TransformerComputeKernelsLayered::reductionFinalNormalization, context, phi3State.tempFFN, // scale factor (in/out) - phi3Config.dim(), // dimension - phi3Config.rmsNormEps()); // epsilon + config.dim(), // dimension + config.rmsNormEps()); // epsilon } // Fused: RMS apply + Q8 gate/up matmul + SiLU activation + GLU @@ -335,8 +294,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights phi3State.tempFFN, // RMS scale (precomputed) weights.wUpLayered[layerIndex].asByteArray(), // Q8 combined gate+up [2×hiddenDim × dim] - phi3Config.dim(), // input dim - phi3Config.hiddenDim(), // output dim + config.dim(), // input dim + config.hiddenDim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); // Down Projection with Residual (Q8 dequantization) @@ -346,8 +305,8 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex phi3State.wrapHbU, // input: FFN intermediate phi3State.wrapX, // output: wrapX += wDown · wrapHbU weights.wDownLayered[layerIndex].asByteArray(), // Q8 wDown [dim × hiddenDim] - 
phi3Config.hiddenDim(), // input dim - phi3Config.dim(), // output dim + config.hiddenDim(), // input dim + config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); unifiedLayer.persistOnDevice(phi3State.wrapX); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java index 7aee3b83..6cdf32db 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java @@ -10,16 +10,12 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.WorkerGrid2D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.ArrayList; -import java.util.List; - /** * Qwen2Q8_0FFNLayers: Q8_0-quantized FFN layers for Qwen2 with Group Query Attention (GQA) support. * @@ -32,21 +28,14 @@ * * Works directly with Qwen2State to access and mutate Qwen2-specific state fields. */ -public class Qwen2Q8_0FFNLayers extends AbstractFFNLayers { - - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; - List<ImmutableTaskGraph> ffnLayerTaskGraphs; - - // Typed references to Qwen2-specific state and config +public class Qwen2Q8_0FFNLayers extends AbstractFFNLayers { + // Typed reference to Qwen2-specific state private final Qwen2State qwen2State; - private final Qwen2Configuration qwen2Config; public Qwen2Q8_0FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeights weights, Qwen2Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); this.qwen2State = state; - this.qwen2Config = config; - ffnLayerTaskGraphs = setupFFNLayered(); + setupFFNLayers(); } @Override @@ -115,48 +104,13 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) return tornadoForwardScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnLayerTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - /** - * Setup all FFN layers for all transformer layers - */ - List<ImmutableTaskGraph> setupFFNLayered() { - List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>(); - qwen2State.temp.init(0.0f); - qwen2State.tempFFN.init(0.0f); - - for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) { - TaskGraph ffnLayer = setupSingleQwen2Q8_0FFNLayer((Qwen2TornadoWeights) weights, layerIndex); - if (layerIndex == qwen2Config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); - } - ffnGraphs.add(ffnLayer.snapshot()); - } - return ffnGraphs; - } - /** * Setup a single transformer layer for Qwen2 with Q8_0 quantization and GQA */ - TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerIndex) { - TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex); + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { + TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex); +
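// Note: wrapX is persisted on-device by the previous layer's task graph and consumed here, so the hidden state never round-trips through the host between layers (same pattern as the Mistral and Phi3 layers above). +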
unifiedLayer.consumeFromDevice(state.wrapX); unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Attention weights diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index 6aea5559..43b88fb1 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -9,14 +9,10 @@ import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import java.util.ArrayList; -import java.util.List; - /** * Qwen3Q8_0FFNLayers: Q8_0-quantized FFN layers for Qwen3 with Group Query Attention (GQA) support. * @@ -29,16 +25,10 @@ * Works directly with Qwen3State to access and mutate Qwen3-specific state fields * like tempQcur and tempKcur. */ -public class Qwen3Q8_0FFNLayers extends AbstractFFNLayers { - - String lastTaskGraphID; - TaskGraph ffnLayerTaskGraph; - GridScheduler scheduler; - List<ImmutableTaskGraph> ffnLayerTaskGraphs; +public class Qwen3Q8_0FFNLayers extends AbstractFFNLayers { - // Typed references to Qwen3-specific state and config + // Typed reference to Qwen3-specific state private final Qwen3State qwen3State; - private final Qwen3Configuration qwen3Config; // Qwen3-specific GQA parameters private final int nHeadKv; @@ -52,7 +42,6 @@ public class Qwen3Q8_0FFNLayers extends AbstractFFNLayers { public Qwen3Q8_0FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); this.qwen3State = state; - this.qwen3Config = config; this.nHeadKv = config.numberOfKeyValueHeads(); this.nEmbdHeadK = config.numberOfHeadsKey(); this.nEmbdHeadV = config.numberOfHeadsValue(); @@ -60,7 +49,7 @@ public Qwen3Q8_0FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe this.nEmbdHead = nEmbdHeadV; this.nEmbdGqa = nEmbdVGqa; this.gqa = config.numberOfHeads() / config.numberOfKeyValueHeads(); - ffnLayerTaskGraphs = setupFFNLayered(); + setupFFNLayers(); } @Override @@ -82,7 +71,7 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { int projectionTwoGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; WorkerGrid projectionTwoWorker = WorkerGridFactory.genericWorker(projectionTwoGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); - int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads(); + int qDim0 = nEmbdHeadK * config.numberOfHeads(); int kvDim0 = nEmbdGqa; int fusedQKVRows = qDim0 + 2 * kvDim0; // Q rows + K rows + V rows int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; @@ -105,55 +94,17 @@ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { return gridScheduler; } - @Override - public GridScheduler getGridScheduler() { - return scheduler; - } - - @Override - public TaskGraph getTaskGraph() { - return ffnLayerTaskGraph; - } - - @Override - public ImmutableTaskGraph getImmutableTaskGraph() { - return null; - } - - public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() { - return ffnLayerTaskGraphs; - } - - /** - * Setup all FFN layers for all transformer
layers - */ - List<ImmutableTaskGraph> setupFFNLayered() { - List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>(); - qwen3State.temp.init(0.0f); - qwen3State.tempFFN.init(0.0f); - qwen3State.tempQcur.init(0.0f); - qwen3State.tempKcur.init(0.0f); - - for (int layerIndex = 0; layerIndex < qwen3Config.numberOfLayers(); layerIndex++) { - TaskGraph ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, layerIndex); - if (layerIndex == qwen3Config.numberOfLayers() - 1) { - setupLastID(ffnLayer.getTaskGraphName()); - } - ffnGraphs.add(ffnLayer.snapshot()); - } - return ffnGraphs; - } - /** * Setup a single transformer layer for Qwen3 with GQA (Q8_0 quantized) */ - TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) { + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var taskGraphName = "layer_" + layerIndex; // === Dimension Parameters === - int qDim = nEmbdHeadK * qwen3Config.numberOfHeads(); // Q output size (full heads) + int qDim = nEmbdHeadK * config.numberOfHeads(); // Q output size (full heads) int kvDim = nEmbdGqa; // K/V output size (reduced for GQA) - int inputDim = qwen3Config.dim(); // Model dimension + int inputDim = config.dim(); // Model dimension var unifiedLayer = new TaskGraph(taskGraphName); @@ -225,11 +176,11 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapK, // K vectors (in/out) weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // Q norm weights weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // K norm weights - qwen3Config.numberOfHeads(), // nHeads (Q heads) - qwen3Config.numberOfKeyValueHeads(), // nHeadKv (K/V heads, GQA) + config.numberOfHeads(), // nHeads (Q heads) + config.numberOfKeyValueHeads(), // nHeadKv (K/V heads, GQA) nEmbdHead, // head dimension nEmbdHead, // local memory size - qwen3Config.rmsNormEps()); // epsilon + config.rmsNormEps()); // epsilon // Fused RoPE Rotation + KV Cache Write unifiedLayer.task("rope_and_kv_cache", @@ -241,11 +192,11 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapV, // V vectors (in only) qwen3State.wrapKeyCache, // key cache (out) qwen3State.wrapValueCache, // value cache (out) - qwen3Config.numberOfKeyValueHeads(), // nHeadKv + config.numberOfKeyValueHeads(), // nHeadKv nEmbdHead, // head dimension nEmbdGqa, // kvDim layerIndex, // layer index for cache offset - qwen3Config.contextLength()); // max sequence length + config.contextLength()); // max sequence length // Flash Attention unifiedLayer.task("attention", @@ -255,13 +206,13 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapKeyCache, // key cache qwen3State.wrapValueCache, // value cache qwen3State.wrapXb, // output: attention result - qwen3Config.numberOfHeads(), // nHeads + config.numberOfHeads(), // nHeads nEmbdHead, // headSize nEmbdGqa, // kvDim gqa, // kvMul (nHeads / nHeadKv) qwen3State.positionHolder, // position layerIndex, // layer index - qwen3Config.contextLength()); // context length + config.contextLength()); // context length // Output Projection with Residual unifiedLayer.task("attn_output_proj", @@ -270,7 +221,7 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.wrapXb, // input: attention output qwen3State.wrapX, // output: wrapX += Wo · wrapXb weights.woLayered[layerIndex].asByteArray(), // Wo [dim x qDim] - nEmbdHeadK * qwen3Config.numberOfHeads(), // input dim (qDim) + nEmbdHeadK * config.numberOfHeads(), // input dim
(qDim) config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); @@ -284,8 +235,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) context, qwen3State.tempFFN, // output: scale factor qwen3State.wrapX, // input: hidden state - qwen3Config.dim(), // dimension - qwen3Config.rmsNormEps(), // epsilon + config.dim(), // dimension + config.rmsNormEps(), // epsilon qwen3State.localSize); // local memory size // Final normalization (non-NVIDIA only) @@ -294,8 +245,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) TransformerComputeKernelsLayered::reductionFinalNormalization, context, qwen3State.tempFFN, // scale factor (in/out) - qwen3Config.dim(), // dimension - qwen3Config.rmsNormEps()); // epsilon + config.dim(), // dimension + config.rmsNormEps()); // epsilon } // Fused RMS Apply + Gate/Up Projection + SiLU + GLU @@ -308,8 +259,8 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3State.tempFFN, // RMS scale factor weights.w1Layered[layerIndex].asByteArray(), // W1 (gate) Q8_0 weights.w3Layered[layerIndex].asByteArray(), // W3 (up) Q8_0 - qwen3Config.dim(), // input dimension - qwen3Config.hiddenDim(), // hidden dimension + config.dim(), // input dimension + config.hiddenDim(), // hidden dimension LOCAL_WORK_GROUP_SIZE_ALLOC); // Down Projection with Residual
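+ // (note: as in the other Q8_0 layers above, the down projection maps hiddenDim back to dim and accumulates into wrapX, completing the FFN residual)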