2 changes: 2 additions & 0 deletions src/extensions/CMakeLists.txt
@@ -36,6 +36,7 @@ target_sources(
tiny_llm_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/src/axpby.cpp
${CMAKE_CURRENT_LIST_DIR}/src/flash_attention.cpp
${CMAKE_CURRENT_LIST_DIR}/src/utils.cpp
)

@@ -58,6 +59,7 @@ if(MLX_BUILD_METAL)
tiny_llm_ext
SOURCES
${CMAKE_CURRENT_LIST_DIR}/src/axpby.metal
${CMAKE_CURRENT_LIST_DIR}/src/flash_attention.metal
INCLUDE_DIRS
${PROJECT_SOURCE_DIR}
${MLX_INCLUDE_DIRS}
16 changes: 16 additions & 0 deletions src/extensions/bindings.cpp
@@ -31,4 +31,20 @@ NB_MODULE(_ext, m) {
Returns:
array: ``alpha * x + beta * y``
)");

m.def("flash_attention", &tiny_llm_ext::flash_attention, "query"_a, "key"_a, "value"_a, "mask"_a, "scale"_a = 1.0,
"is_causal"_a = false, "num_kv_heads"_a, "num_heads"_a, "stream"_a = nb::none(), R"(
Flash attention layer (student implementation)

Args:
query (array): Query array.
key (array): Key array.
value (array): Value array.
mask (array): Mask array.
scale (float): Scaling factor.
is_causal (bool): Enable causal-mask fast path.

Returns:
array: ``softmax(query @ key.T * scale) @ value``
)");
}
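While the fused primitive below is still a stub, the documented formula can be cross-checked against an unfused composition of stock MLX ops. A minimal sketch, assuming an additive, broadcastable mask and the mx::matmul / mx::softmax / mx::swapaxes signatures from the MLX C++ API — a testing oracle, not the student solution:

namespace tiny_llm_ext {

// Unfused oracle for the binding's documented formula, assuming the mask is
// additive and broadcasts against the score matrix. Useful for testing until
// the fused primitive is implemented.
mx::array flash_attention_reference(const mx::array &q, const mx::array &k, const mx::array &v,
                                    const mx::array &mask, const float scale, mx::StreamOrDevice s = {}) {
    // scores = (q * scale) @ k.T, shape (..., L, S)
    auto scores = mx::matmul(mx::multiply(q, mx::array(scale), s), mx::swapaxes(k, -1, -2, s), s);
    scores = mx::add(scores, mask, s);
    // out = softmax(scores, axis=-1) @ v, shape (..., L, E)
    return mx::matmul(mx::softmax(scores, /*axes=*/{-1}, /*precise=*/true, s), v, s);
}

}  // namespace tiny_llm_ext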
51 changes: 51 additions & 0 deletions src/extensions/src/flash_attention.cpp
@@ -0,0 +1,51 @@
// Copyright © 2023-2025 Apple Inc.

#include <stdexcept>

#include "tiny_llm_ext.h"

namespace tiny_llm_ext {

mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const mx::array &mask,
const float scale, const bool is_causal, const int num_kv_heads, const int num_heads,
mx::StreamOrDevice s /* = {} */) {
// TODO(student): implement flash attention.
(void)q;
(void)k;
(void)v;
(void)mask;
(void)scale;
(void)is_causal;
(void)num_kv_heads;
(void)num_heads;
(void)s;
throw std::runtime_error("flash_attention is not implemented.");
}

void FlashAttention::eval_cpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) {
// TODO(student): implement CPU kernel.
(void)inputs;
(void)outputs;
throw std::runtime_error("FlashAttention::eval_cpu is not implemented.");
}

#ifdef _METAL_

void FlashAttention::eval_gpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) {
// TODO(student): implement Metal kernel dispatch.
(void)inputs;
(void)outputs;
throw std::runtime_error("FlashAttention::eval_gpu is not implemented.");
}

#else

void FlashAttention::eval_gpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) {
(void)inputs;
(void)outputs;
throw std::runtime_error("FlashAttention has no GPU implementation.");
}

#endif

} // namespace tiny_llm_ext
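One plausible way to fill in the front-end, following the axpby pattern from the MLX custom-extension docs: validate the inputs, then build the output array lazily with the FlashAttention primitive attached, so evaluation is deferred to eval_cpu / eval_gpu. The dtype check and output shape here are illustrative assumptions, not the required solution:

namespace tiny_llm_ext {

// Sketch only: constructs the lazy output; all computation happens in
// FlashAttention::eval_cpu / eval_gpu when the array is evaluated.
mx::array flash_attention_sketch(const mx::array &q, const mx::array &k, const mx::array &v,
                                 const mx::array &mask, const float scale, const bool is_causal,
                                 const int num_kv_heads, const int num_heads, mx::StreamOrDevice s = {}) {
    if (q.dtype() != mx::float32) {
        throw std::runtime_error("flash_attention: this sketch only handles float32.");
    }
    // Output matches the query layout (N, num_heads, L, E).
    return mx::array(q.shape(), q.dtype(),
                     std::make_shared<FlashAttention>(mx::to_stream(s), scale, is_causal, num_kv_heads, num_heads),
                     {q, k, v, mask});
}

}  // namespace tiny_llm_ext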
52 changes: 52 additions & 0 deletions src/extensions/src/flash_attention.metal
@@ -0,0 +1,52 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

[[kernel]] void flash_attention_f32_e128(
device const float *q [[buffer(0)]],
device const float *k [[buffer(1)]],
device const float *v [[buffer(2)]],
device const float *mask [[buffer(3)]],
device float *out [[buffer(4)]],
constant const int *mask_shape [[buffer(5)]],
constant const int64_t *mask_strides [[buffer(6)]],
constant const int &is_causal [[buffer(7)]],
constant const int &N [[buffer(8)]],
constant const int &L [[buffer(9)]],
constant const int &S [[buffer(10)]],
constant const int &E [[buffer(11)]],
constant const int &num_kv_heads [[buffer(12)]],
constant const int &num_heads [[buffer(13)]],
constant const float &scale [[buffer(14)]],
constant const int &Br [[buffer(15)]],
constant const int &Bc [[buffer(16)]],
constant const int &Tr [[buffer(17)]],
constant const int &Tc [[buffer(18)]],
uint2 group_id [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// TODO(student): implement flash attention kernel.
(void)q;
(void)k;
(void)v;
(void)mask;
(void)out;
(void)mask_shape;
(void)mask_strides;
(void)is_causal;
(void)N;
(void)L;
(void)S;
(void)E;
(void)num_kv_heads;
(void)num_heads;
(void)scale;
(void)Br;
(void)Bc;
(void)Tr;
(void)Tc;
(void)group_id;
(void)simd_gid;
(void)simd_lid;
}
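The Br/Bc tile sizes and Tr/Tc tile counts imply the kernel must stream over key/value blocks with an online softmax, rescaling partial sums whenever the running row maximum changes. A scalar C++ sketch of that recurrence, with block size 1 for clarity and illustrative names, not the kernel itself:

#include <cmath>
#include <vector>

// Per query row: keep a running max m, normalizer l, and unnormalized
// accumulator acc; rescale both as each new score arrives, normalize once
// at the end. This is the invariant the tiled Metal kernel must preserve.
std::vector<float> attend_one_row(const std::vector<float> &scores /* length S */,
                                  const std::vector<std::vector<float>> &v /* S x E */) {
    const size_t S = scores.size(), E = v[0].size();
    float m = -INFINITY, l = 0.0f;
    std::vector<float> acc(E, 0.0f);
    for (size_t j = 0; j < S; ++j) {
        float m_new = fmaxf(m, scores[j]);
        float correction = expf(m - m_new);  // rescales the previous partial sums
        float p = expf(scores[j] - m_new);
        l = l * correction + p;
        for (size_t e = 0; e < E; ++e) acc[e] = acc[e] * correction + p * v[j][e];
        m = m_new;
    }
    for (size_t e = 0; e < E; ++e) acc[e] /= l;  // deferred softmax normalization
    return acc;
}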
35 changes: 35 additions & 0 deletions src/extensions/src/tiny_llm_ext.h
@@ -1,12 +1,47 @@
#pragma once

#include <stdexcept>
#include <vector>

#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mx = mlx::core;

namespace tiny_llm_ext {

void load_library(mx::Device d, const char *path);

///////////////////////////////////////////////////////////////////////////////
// Flash Attention (student implementation)
///////////////////////////////////////////////////////////////////////////////

mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const mx::array &mask,
const float scale, const bool is_causal, const int num_kv_heads, const int num_heads,
mx::StreamOrDevice s = {});

class FlashAttention : public mx::Primitive {
public:
explicit FlashAttention(mx::Stream stream, const float scale, const bool is_causal, const int num_kv_heads,
const int num_heads)
: mx::Primitive(stream), scale_(scale), is_causal_(is_causal), num_kv_heads_(num_kv_heads), num_heads_(num_heads) {}

void eval_cpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) override;
void eval_gpu(const std::vector<mx::array> &inputs, std::vector<mx::array> &outputs) override;

std::pair<std::vector<mx::array>, std::vector<int>> vmap(const std::vector<mx::array> &inputs,
const std::vector<int> &axes) override {
throw std::runtime_error("FlashAttention has no vmap implementation.");
}

const char *name() const override { return "FlashAttention"; }

private:
float scale_;
bool is_causal_;
int num_kv_heads_;
int num_heads_;
};

} // namespace tiny_llm_ext
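Since the primitive carries both num_heads and num_kv_heads, the kernels will need the grouped-query mapping from query head to shared key/value head. A hedged one-liner, assuming num_heads is a multiple of num_kv_heads:

// Illustrative helper: query head h reads key/value head h / (num_heads / num_kv_heads),
// i.e. consecutive groups of query heads share one key/value head.
inline int kv_head_for(const int query_head, const int num_heads, const int num_kv_heads) {
    return query_head / (num_heads / num_kv_heads);
}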