Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions csrc/kernels/internode_ll.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
clean_1[i] = 0;

// Barrier after cleaning (make sure low-latency mode work
// Barrier after cleaning (make sure low-latency mode works)
#ifdef USE_ROCM
if (threadIdx.x == 0)
internode::shmem_device_barrier_all();
Expand Down Expand Up @@ -109,7 +109,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX) && !defined(NIC_THOR2)
__shared__ internode::shmem_ctx_t ctx;
if constexpr (kMultinode)
EP_DEVICE_ASSERT(internode::shmem_wg_ctx_create(&ctx) == 0 or ctx == ROCSHMEM_CTX_INVALID);
Expand Down Expand Up @@ -247,7 +247,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#if defined(ROCM_DISABLE_CTX)
internode::shmemx_int8_put_nbi_warp(reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank);
#else //DISABLE_CTX
#if defined(NIC_THOR2)
internode::shmem_ctx_schar_put_nbi_warp(rocshmem_ctx_array[dst_expert_local_idx], reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank);
#else
internode::shmem_ctx_schar_put_nbi_warp(ctx, reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank);
#endif
#endif
}
#else //USE_ROCM
Expand Down Expand Up @@ -311,7 +315,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
}

#if defined(NIC_IO) || defined(NIC_THOR2)
#if defined(NIC_IO)
if constexpr (kMultinode){
if (thread_id == 0 ){
#if defined(ROCM_DISABLE_CTX)
Expand Down Expand Up @@ -345,8 +349,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
__threadfence_system();
#if defined(ROCM_DISABLE_CTX)
internode::shmem_long_atomic_add( rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
#else
#if defined(NIC_THOR2)
internode::shmem_ctx_long_atomic_add(rocshmem_ctx_array[dst_expert_local_idx], rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
#else
internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
#endif
#endif
}
#else //CUDA
Expand All @@ -355,7 +363,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
} else {
st_na_release(reinterpret_cast<int64_t *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
}
#if defined(NIC_IO) || defined(NIC_THOR2)
#if defined(NIC_IO)
if constexpr (kMultinode){
if (thread_id == 0 ){
#if defined(ROCM_DISABLE_CTX)
Expand All @@ -378,7 +386,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0){
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX) && !defined(NIC_THOR2)
if constexpr (kMultinode)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
Expand Down Expand Up @@ -480,7 +488,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
}
}
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX) && !defined(NIC_THOR2)
if constexpr (kMultinode)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
Expand Down Expand Up @@ -601,7 +609,7 @@ combine(void* combined_x,
int num_experts, int rank, int num_ranks,
int phases, bool zero_copy) {

#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX) && !defined(NIC_THOR2)
__shared__ internode::shmem_ctx_t ctx;
if constexpr(kMultinode)
EP_DEVICE_ASSERT(internode::shmem_wg_ctx_create(&ctx) == 0 or ctx == ROCSHMEM_CTX_INVALID);
Expand Down Expand Up @@ -679,29 +687,36 @@ combine(void* combined_x,
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY(4, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);

//nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(gpu_bfloat16_t), dst_rank, local_expert_idx, lane_id, token_idx - offset);
if constexpr (!kMultinode){
internode::shmemx_int8_put_nbi_warp(reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank);
}else{
#if defined(ROCM_DISABLE_CTX)
internode::shmemx_int8_put_nbi_warp(reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank);
#else
#if defined(NIC_THOR2)
internode::shmem_ctx_schar_put_nbi_warp(rocshmem_ctx_array[local_expert_idx],reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank);
#else
internode::shmem_ctx_schar_put_nbi_warp(ctx,reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank);
#endif
#endif
}

}
}

#if !defined(NIC_THOR2)
if constexpr (kMultinode){
if (sub_warp_id == 0)
#if defined(ROCM_DISABLE_CTX)
if (sub_warp_id == 0) {
#if !defined(ROCM_DISABLE_CTX)
internode::shmem_fence();
#else
internode::shmem_ctx_quiet(ctx);
#endif
}
}
#endif

// Put finishing flag
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
Expand All @@ -725,8 +740,12 @@ combine(void* combined_x,
__threadfence_system();
#if defined(ROCM_DISABLE_CTX)
internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank);
#else
#if defined(NIC_THOR2)
internode::shmem_ctx_long_atomic_add(rocshmem_ctx_array[local_expert_idx], rdma_recv_flag + global_expert_idx, 1, dst_rank);
#else
internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank);
#endif
#endif //DISABLE_CTX
}
#else
Expand All @@ -742,7 +761,7 @@ combine(void* combined_x,
// Receiving phase
LOW_LATENCY_COMBINE_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0){
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX) && !defined(NIC_THOR2)
if constexpr (kMultinode)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
Expand Down Expand Up @@ -800,7 +819,7 @@ combine(void* combined_x,
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
}
}
#if !defined(ROCM_DISABLE_CTX)
#if !defined(ROCM_DISABLE_CTX) && !defined(NIC_THOR2)
if constexpr (kMultinode)
internode::shmem_wg_ctx_destroy(&ctx);
#endif
Expand Down