Merged
2 changes: 1 addition & 1 deletion src/binding/py_train.cpp
@@ -283,7 +283,7 @@ std::vector<std::pair<std::string, Tensor>> MultiGPUPyTrainer::get_gradients(int
auto& block = grads.get_block_shard(l, nullptr);
result.emplace_back(prefix + ".self_attn.qkv.weight", block.Attn_QKV_w);
if (block.Attn_QKV_b)
- result.emplace_back(prefix + ".self_attn.qkv.bias", block.Attn_QKV_b.value());
+ result.emplace_back(prefix + ".self_attn.qkv.bias", block.Attn_QKV_b);
result.emplace_back(prefix + ".self_attn.o_proj.weight", block.Attn_Out_w);
result.emplace_back(prefix + ".mlp.up.weight", block.MLP_Up_w);
result.emplace_back(prefix + ".mlp.down_proj.weight", block.MLP_Down_w);
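
Note on the pattern used throughout this PR: members and parameters that were std::optional<Tensor> (such as Attn_QKV_b above) become plain Tensor values that may be empty. The sketch below is a minimal reconstruction of the interface the new call sites rely on — operator bool, empty(), nelem(), get<T>(), and get_optional<T>() — assuming a simple pointer-plus-count layout; the repository's actual Tensor class is not shown in this diff and will differ in detail.

// Minimal sketch (NOT the repository's actual Tensor class) of the nullable-Tensor
// interface the new call sites use. Only the member names are taken from the diff;
// the storage layout and error handling here are assumptions.
#include <cstddef>
#include <stdexcept>

enum class ETensorDType { FP32, BF16, FP8_E4M3, FP8_E5M2 };

struct Tensor {
    void* Data = nullptr;             // hypothetical storage pointer
    std::size_t Count = 0;            // hypothetical element count
    ETensorDType DType = ETensorDType::FP32;

    bool empty() const { return Data == nullptr; }
    explicit operator bool() const { return !empty(); }   // enables `if (block.Attn_QKV_b)`
    std::size_t nelem() const { return Count; }

    template <typename T>
    T* get() const {                  // strict access; the real class presumably also checks dtype
        if (empty()) throw std::runtime_error("Tensor::get() on empty tensor");
        return static_cast<T*>(Data);
    }

    template <typename T>
    T* get_optional() const {         // nullptr when empty, replacing has_value()/value() at call sites
        return empty() ? nullptr : static_cast<T*>(Data);
    }
};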
26 changes: 13 additions & 13 deletions src/kernels/kernels.cpp
@@ -103,11 +103,11 @@ void fused_classifier(Tensor& logits, Tensor& losses,
}
}

- void encoder_forward(Tensor& out, const Tensor& inp, const Tensor& wte, std::optional<Tensor> wpe, int B, int T, int C, int V, cudaStream_t stream) {
+ void encoder_forward(Tensor& out, const Tensor& inp, const Tensor& wte, const Tensor& wpe, int B, int T, int C, int V, cudaStream_t stream) {
if(out.DType == ETensorDType::FP32) {
- encoder_forward(out.get<float>(), inp.get<std::int32_t>(), wte.get<float>(), wpe.has_value() ? wpe->get<float>() : nullptr, B, T, C, V, stream);
+ encoder_forward(out.get<float>(), inp.get<std::int32_t>(), wte.get<float>(), wpe.get_optional<float>(), B, T, C, V, stream);
} else if(out.DType == ETensorDType::BF16) {
- encoder_forward(out.get<nv_bfloat16>(), inp.get<std::int32_t>(), wte.get<nv_bfloat16>(), wpe.has_value() ? wpe->get<nv_bfloat16>() : nullptr, B, T, C, V, stream);
+ encoder_forward(out.get<nv_bfloat16>(), inp.get<std::int32_t>(), wte.get<nv_bfloat16>(), wpe.get_optional<nv_bfloat16>(), B, T, C, V, stream);
} else {
throw std::runtime_error("encoder_forward: unsupported dtype");
}
@@ -270,36 +270,36 @@ void fill_constant(Tensor& dest, float value, std::size_t count, cudaStream_t st
}
}

- void matmul(Tensor& c, const Tensor& a, const Tensor& b, std::optional<Tensor> bias,
+ void matmul(Tensor& c, const Tensor& a, const Tensor& b, const Tensor& bias,
const float* scale_a, const float* scale_b,
cublasLtHandle_t handle, Tensor& workspace,
int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream) {
std::byte* ws = workspace.get<std::byte>();
std::size_t ws_size = workspace.bytes();
if(c.DType == ETensorDType::FP32 && a.DType == ETensorDType::FP32) {
- float* bias_ptr = bias.has_value() ? bias.value().get<float>() : nullptr;
+ const float* bias_ptr = bias.get_optional<float>();
matmul(c.get<float>(), a.get<float>(), b.get<float>(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream);
} else if(c.DType == ETensorDType::FP32 && a.DType == ETensorDType::BF16) {
- float* bias_ptr = bias.has_value() ? bias.value().get<float>() : nullptr;
+ const float* bias_ptr = bias.get_optional<float>();
matmul(c.get<float>(), a.get<nv_bfloat16>(), b.get<nv_bfloat16>(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream);
} else if(c.DType == ETensorDType::FP32 && a.DType == ETensorDType::FP8_E4M3) {
- if(bias.has_value()) {
- if(bias.value().DType == ETensorDType::BF16) {
- matmul(c.get<float>(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias->get<nv_bfloat16>(), scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream);
+ if(!bias.empty()) {
+ if(bias.DType == ETensorDType::BF16) {
+ matmul(c.get<float>(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias.get<nv_bfloat16>(), scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream);
} else {
- matmul(c.get<float>(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias->get<float>(), scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream);
+ matmul(c.get<float>(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias.get<float>(), scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream);
}
} else {
matmul(c.get<float>(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), (nv_bfloat16*)nullptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream);
}
} else if(c.DType == ETensorDType::BF16 && a.DType == ETensorDType::FP8_E4M3 && b.DType == ETensorDType::FP8_E4M3) {
- nv_bfloat16* bias_ptr = bias.has_value() ? bias.value().get<nv_bfloat16>() : nullptr;
+ const nv_bfloat16* bias_ptr = bias.get_optional<nv_bfloat16>();
matmul(c.get<nv_bfloat16>(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e4m3>(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream);
} else if(c.DType == ETensorDType::BF16 && a.DType == ETensorDType::FP8_E4M3 && b.DType == ETensorDType::FP8_E5M2) {
- nv_bfloat16* bias_ptr = bias.has_value() ? bias.value().get<nv_bfloat16>() : nullptr;
+ const nv_bfloat16* bias_ptr = bias.get_optional<nv_bfloat16>();
matmul(c.get<nv_bfloat16>(), a.get<__nv_fp8_e4m3>(), b.get<__nv_fp8_e5m2>(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream);
} else if(c.DType == ETensorDType::BF16) {
- nv_bfloat16* bias_ptr = bias.has_value() ? bias.value().get<nv_bfloat16>() : nullptr;
+ const nv_bfloat16* bias_ptr = bias.get_optional<nv_bfloat16>();
matmul(c.get<nv_bfloat16>(), a.get<nv_bfloat16>(), b.get<nv_bfloat16>(), bias_ptr, scale_a, scale_b, handle, ws, ws_size, M, N, K, mode, accumulate, stream);
} else {
throw std::logic_error("matmul_forward: invalid DType combination");
4 changes: 2 additions & 2 deletions src/kernels/kernels.h
@@ -23,7 +23,7 @@ enum class EMMTranspose { TT, TN, NT, NN };

void encoder_forward(float* out, const int* inp, const float* wte, const float* wpe, int B, int T, int C, int V, cudaStream_t stream);
void encoder_forward(nv_bfloat16* out, const int* inp, const nv_bfloat16* wte, const nv_bfloat16* wpe, int B, int T, int C, int V, cudaStream_t stream);
- void encoder_forward(Tensor& out, const Tensor& inp, const Tensor& wte, std::optional<Tensor> wpe, int B, int T, int C, int V, cudaStream_t stream);
+ void encoder_forward(Tensor& out, const Tensor& inp, const Tensor& wte, const Tensor& wpe, int B, int T, int C, int V, cudaStream_t stream);

void encoder_backward(float* dwte, int* scratch,
int* workload_indices, int4* bucket_info,
@@ -92,7 +92,7 @@ void matmul(nv_bfloat16* c, const __nv_fp8_e4m3* a, const __nv_fp8_e5m2* b, cons
cublasLtHandle_t handle, std::byte* workspace, std::size_t workspace_size,
int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream);

- void matmul(Tensor& c, const Tensor& a, const Tensor& b, std::optional<Tensor> bias, const float* scale_a, const float* scale_b,
+ void matmul(Tensor& c, const Tensor& a, const Tensor& b, const Tensor& bias, const float* scale_a, const float* scale_b,
cublasLtHandle_t handle, Tensor& workspace,
int M, int N, int K, EMMTranspose mode, bool accumulate, cudaStream_t stream);

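For callers of these declarations, the "no tensor" argument changes from std::nullopt to an empty Tensor. A hypothetical before/after call site is sketched below; it assumes Tensor is default-constructible to an empty state and that the header is included as "kernels.h", neither of which is confirmed by this diff.

// Hypothetical call-site comparison for the encoder_forward signature change above.
// Variable names are placeholders; Tensor{} stands in for "no positional embedding".
#include <cuda_runtime.h>   // cudaStream_t
#include "kernels.h"        // assumed include path for the declarations above

void run_embedding(Tensor& out, const Tensor& inp, const Tensor& wte,
                   int B, int T, int C, int V, cudaStream_t stream) {
    // Before: encoder_forward(out, inp, wte, std::nullopt, B, T, C, V, stream);
    // After: an empty Tensor expresses the absent argument instead of std::optional.
    encoder_forward(out, inp, wte, Tensor{}, B, T, C, V, stream);
}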
82 changes: 36 additions & 46 deletions src/models/llama_gradients.cpp
@@ -26,9 +26,7 @@ void LLamaGradsManager::scatter_reduce(int layer_idx, sLLamaBlockWeights<Tensor>
comm.schedule_reduce_scatter(block.LN2_w);
comm.schedule_reduce_scatter(block.MLP_Up_w);
comm.schedule_reduce_scatter(block.MLP_Down_w);
- if(block.Attn_QKV_b.has_value()) {
- comm.schedule_reduce_scatter(block.Attn_QKV_b.value());
- }
+ comm.schedule_reduce_scatter(block.Attn_QKV_b);
comm.execute_transaction(signal);
}

@@ -93,10 +91,7 @@ void LLamaGradientsUnsharded::on_first_micro_step(cudaStream_t stream) {
for(auto& layer: mFullGradient.Blocks) {
fill_zero(layer.LN1_w, stream);
fill_zero(layer.LN2_w, stream);
- 
- if(auto& qkv_b = layer.Attn_QKV_b; qkv_b.has_value()) {
- fill_zero(qkv_b.value(), stream);
- }
+ fill_zero(layer.Attn_QKV_b, stream);
// no need to zero out the matrix weights, we'll just overwrite them on the first
// grad accumulation step
}
@@ -289,9 +284,7 @@ sLLamaBlockWeights<Tensor>& LLamaGradientsBlockShardedBase::get_block_full(int l
// reset local gradient buffers
fill_zero(dw.LN1_w, stream);
fill_zero(dw.LN2_w, stream);
- if (dw.Attn_QKV_b.has_value()) {
- fill_zero(dw.Attn_QKV_b.value(), stream);
- }
+ fill_zero(dw.Attn_QKV_b, stream);
return dw;
}

@@ -371,17 +364,17 @@ void LLamaGradientsBlockSharded_ScatterReduce::sr_accumulate_layer(int layer_idx
cudaStream_t stream,
NCCLCommunicator& comm) {
NvtxRange range("accumulate_layer", layer_idx);
- auto rng_1 = mRng.generate(2*mStepCounter + 0, layer_idx);
- auto rng_2 = mRng.generate(2*mStepCounter + 1, layer_idx);
- 
- sr_accumulate_tensor(sw.LN1_w, dw.LN1_w, stream, rng_1[0]);
- sr_accumulate_tensor(sw.LN2_w, dw.LN2_w, stream, rng_1[1]);
- sr_accumulate_tensor(sw.MLP_Up_w, dw.MLP_Up_w, stream, rng_1[2]);
- sr_accumulate_tensor(sw.MLP_Down_w, dw.MLP_Down_w, stream, rng_1[3]);
- sr_accumulate_tensor(sw.Attn_QKV_w, dw.Attn_QKV_w, stream, rng_2[0]);
- sr_accumulate_tensor(sw.Attn_Out_w, dw.Attn_Out_w, stream, rng_2[1]);
- if(sw.Attn_QKV_b.has_value()) {
- sr_accumulate_tensor(sw.Attn_QKV_b.value(), dw.Attn_QKV_b.value(), stream, rng_2[2]);
+ std::array<std::uint32_t, 8> rng;
+ mRng.generate(std::span(rng), 2*mStepCounter, layer_idx);
+ 
+ sr_accumulate_tensor(sw.LN1_w, dw.LN1_w, stream, rng[0]);
+ sr_accumulate_tensor(sw.LN2_w, dw.LN2_w, stream, rng[1]);
+ sr_accumulate_tensor(sw.MLP_Up_w, dw.MLP_Up_w, stream, rng[2]);
+ sr_accumulate_tensor(sw.MLP_Down_w, dw.MLP_Down_w, stream, rng[3]);
+ sr_accumulate_tensor(sw.Attn_QKV_w, dw.Attn_QKV_w, stream, rng[4]);
+ sr_accumulate_tensor(sw.Attn_Out_w, dw.Attn_Out_w, stream, rng[5]);
+ if(sw.Attn_QKV_b) {
+ sr_accumulate_tensor(sw.Attn_QKV_b, dw.Attn_QKV_b, stream, rng[6]);
}
}

@@ -407,17 +400,16 @@ void LLamaGradientsBlockSharded_AllToAll::scatter_reduce(int layer_idx, sLLamaBl
// accumulate local slice of block to local gradient
{
NvtxRange range("accumulate-own-shard", layer_idx);
- auto rng_1 = mRng.generate(2 * mStepCounter + 0, layer_idx);
- auto rng_2 = mRng.generate(2 * mStepCounter + 1, layer_idx);
- sr_accumulate_tensor(sw.LN1_w, dw.LN1_w, stream, mIsFirstMicroStep, 1.f, rank, rng_1[0]);
- sr_accumulate_tensor(sw.LN2_w, dw.LN2_w, stream, mIsFirstMicroStep, 1.f, rank, rng_1[1]);
- sr_accumulate_tensor(sw.Attn_QKV_w, dw.Attn_QKV_w, stream, mIsFirstMicroStep, 1.f, rank, rng_1[2]);
- sr_accumulate_tensor(sw.Attn_Out_w, dw.Attn_Out_w, stream, mIsFirstMicroStep, 1.f, rank, rng_1[3]);
- sr_accumulate_tensor(sw.MLP_Up_w, dw.MLP_Up_w, stream, mIsFirstMicroStep, 1.f, rank, rng_2[0]);
- sr_accumulate_tensor(sw.MLP_Down_w, dw.MLP_Down_w, stream, mIsFirstMicroStep, 1.f, rank, rng_2[1]);
- if (sw.Attn_QKV_b.has_value()) {
- sr_accumulate_tensor(sw.Attn_QKV_b.value(), dw.Attn_QKV_b.value(), stream, mIsFirstMicroStep, 1.f, rank,
- rng_2[2]);
+ std::array<std::uint32_t, 8> rng;
+ mRng.generate(std::span(rng), 2*mStepCounter, layer_idx);
+ sr_accumulate_tensor(sw.LN1_w, dw.LN1_w, stream, mIsFirstMicroStep, 1.f, rank, rng[0]);
+ sr_accumulate_tensor(sw.LN2_w, dw.LN2_w, stream, mIsFirstMicroStep, 1.f, rank, rng[1]);
+ sr_accumulate_tensor(sw.Attn_QKV_w, dw.Attn_QKV_w, stream, mIsFirstMicroStep, 1.f, rank, rng[2]);
+ sr_accumulate_tensor(sw.Attn_Out_w, dw.Attn_Out_w, stream, mIsFirstMicroStep, 1.f, rank, rng[3]);
+ sr_accumulate_tensor(sw.MLP_Up_w, dw.MLP_Up_w, stream, mIsFirstMicroStep, 1.f, rank, rng[4]);
+ sr_accumulate_tensor(sw.MLP_Down_w, dw.MLP_Down_w, stream, mIsFirstMicroStep, 1.f, rank, rng[5]);
+ if (sw.Attn_QKV_b) {
+ sr_accumulate_tensor(sw.Attn_QKV_b, dw.Attn_QKV_b, stream, mIsFirstMicroStep, 1.f, rank, rng[6]);
}
}

@@ -433,9 +425,7 @@ void LLamaGradientsBlockSharded_AllToAll::scatter_reduce(int layer_idx, sLLamaBl
comm.schedule_destructive_all_to_all(dw.LN2_w);
comm.schedule_destructive_all_to_all(dw.MLP_Up_w);
comm.schedule_destructive_all_to_all(dw.MLP_Down_w);
- if(dw.Attn_QKV_b.has_value()) {
- comm.schedule_destructive_all_to_all(dw.Attn_QKV_b.value());
- }
+ comm.schedule_destructive_all_to_all(dw.Attn_QKV_b);
comm.execute_transaction(signal);
}

@@ -462,17 +452,17 @@ void LLamaGradientsBlockSharded_AllToAll::sr_accumulate_layer(int layer_idx,
scale = 1.f / world;
}

- auto rng_1 = mRng.generate(2*mStepCounter + 0, layer_idx);
- auto rng_2 = mRng.generate(2*mStepCounter + 1, layer_idx + 12345);
- 
- vector_reduce_sr(sw.LN1_w, dw.LN1_w, scale, world, (rank + world - 1) % world, sw.LN1_w.nelem(), true, rng_1[0], stream);
- vector_reduce_sr(sw.LN2_w, dw.LN2_w, scale, world, (rank + world - 1) % world, sw.LN2_w.nelem(), true, rng_1[1], stream);
- vector_reduce_sr(sw.MLP_Up_w, dw.MLP_Up_w, scale, world, (rank + world - 1) % world, sw.MLP_Up_w.nelem(), true, rng_1[2], stream);
- vector_reduce_sr(sw.MLP_Down_w, dw.MLP_Down_w, scale, world, (rank + world - 1) % world, sw.MLP_Down_w.nelem(), true, rng_1[3], stream);
- vector_reduce_sr(sw.Attn_QKV_w, dw.Attn_QKV_w, scale, world, (rank + world - 1) % world, sw.Attn_QKV_w.nelem(), true, rng_2[0], stream);
- vector_reduce_sr(sw.Attn_Out_w, dw.Attn_Out_w, scale, world, (rank + world - 1) % world, sw.Attn_Out_w.nelem(), true, rng_2[1], stream);
- if(sw.Attn_QKV_b.has_value()) {
- vector_reduce_sr(sw.Attn_QKV_b.value(), dw.Attn_QKV_b.value(), scale, world, (rank + world - 1) % world, sw.Attn_QKV_b->nelem(), true, rng_2[2], stream);
+ std::array<std::uint32_t, 8> rng;
+ mRng.generate(std::span(rng), 2*mStepCounter, layer_idx);
Copilot AI Jan 4, 2026
The RNG sequence has been changed. The old code generated the second batch of 4 random numbers with y = layer_idx + 12345, but the new code uses y = layer_idx for all batches. This means rng[4], rng[5], and rng[6] will have different values than they had before, potentially breaking deterministic behavior or changing the results. If this change is intentional, it should be documented.

Suggested change
- mRng.generate(std::span(rng), 2*mStepCounter, layer_idx);
+ mRng.generate(std::span(rng).subspan(0, 4), 2*mStepCounter, layer_idx);
+ mRng.generate(std::span(rng).subspan(4, 4), 2*mStepCounter, layer_idx + 12345);
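A small self-contained toy (below) can make the comment concrete. It is not the repository's mRng: a stateless hash stands in for generate(span, x, y), keyed on x, y, and the element index. Under that assumption, the first four draws match between the old two-batch keying and the new single-batch keying, while the draws that previously used layer_idx + 12345 come out different — which is the determinism concern raised above. Whether the real generator behaves the same way depends on its implementation, which this diff does not show.

// Toy illustration only; the real mRng.generate is not shown in this diff, so a
// splitmix64-style stateless hash stands in for it here.
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <span>

static std::uint32_t toy_hash(std::uint64_t x, std::uint64_t y, std::uint64_t i) {
    std::uint64_t z = x * 0x9E3779B97F4A7C15ULL ^ y * 0xBF58476D1CE4E5B9ULL ^ i * 0x94D049BB133111EBULL;
    z ^= z >> 30; z *= 0xBF58476D1CE4E5B9ULL;
    z ^= z >> 27; z *= 0x94D049BB133111EBULL;
    return static_cast<std::uint32_t>(z ^ (z >> 31));
}

static void toy_generate(std::span<std::uint32_t> out, std::uint64_t x, std::uint64_t y) {
    for (std::size_t i = 0; i < out.size(); ++i) out[i] = toy_hash(x, y, i);
}

int main() {
    const std::uint64_t step = 7, layer = 3;   // arbitrary example values
    std::array<std::uint32_t, 8> old_vals{}, new_vals{};

    // Old scheme: two batches of 4, the second keyed by layer + 12345.
    toy_generate(std::span(old_vals).subspan(0, 4), 2 * step + 0, layer);
    toy_generate(std::span(old_vals).subspan(4, 4), 2 * step + 1, layer + 12345);

    // New scheme: one batch of 8, all keyed by layer.
    toy_generate(std::span(new_vals), 2 * step, layer);

    for (int i = 0; i < 8; ++i)
        std::printf("i=%d old=%u new=%u %s\n", i, old_vals[i], new_vals[i],
                    old_vals[i] == new_vals[i] ? "" : "(differs)");
    return 0;
}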


+ vector_reduce_sr(sw.LN1_w, dw.LN1_w, scale, world, (rank + world - 1) % world, sw.LN1_w.nelem(), true, rng[0], stream);
+ vector_reduce_sr(sw.LN2_w, dw.LN2_w, scale, world, (rank + world - 1) % world, sw.LN2_w.nelem(), true, rng[1], stream);
+ vector_reduce_sr(sw.MLP_Up_w, dw.MLP_Up_w, scale, world, (rank + world - 1) % world, sw.MLP_Up_w.nelem(), true, rng[2], stream);
+ vector_reduce_sr(sw.MLP_Down_w, dw.MLP_Down_w, scale, world, (rank + world - 1) % world, sw.MLP_Down_w.nelem(), true, rng[3], stream);
+ vector_reduce_sr(sw.Attn_QKV_w, dw.Attn_QKV_w, scale, world, (rank + world - 1) % world, sw.Attn_QKV_w.nelem(), true, rng[4], stream);
+ vector_reduce_sr(sw.Attn_Out_w, dw.Attn_Out_w, scale, world, (rank + world - 1) % world, sw.Attn_Out_w.nelem(), true, rng[5], stream);
+ if(sw.Attn_QKV_b) {
+ vector_reduce_sr(sw.Attn_QKV_b, dw.Attn_QKV_b, scale, world, (rank + world - 1) % world, sw.Attn_QKV_b.nelem(), true, rng[6], stream);
}
}
