From 1f7c1a204e55f0e7d46d1767c6ae610917c434f8 Mon Sep 17 00:00:00 2001 From: Yiltan Date: Fri, 13 Mar 2026 14:59:34 +0000 Subject: [PATCH] ctx_array --- csrc/kernels/internode_ll.cu | 43 ++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 3f6adc2..70555be 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -56,7 +56,7 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, for (int i = thread_id; i < num_clean_int_1; i += kNumThreads) clean_1[i] = 0; - // Barrier after cleaning (make sure low-latency mode work + // Barrier after cleaning (make sure low-latency mode work #ifdef USE_ROCM if (threadIdx.x == 0) internode::shmem_device_barrier_all(); @@ -109,7 +109,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, using scale_t = std::conditional_t; using packed_t = std::conditional_t; EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length"); -#if !defined(ROCM_DISABLE_CTX) +#if !defined(ROCM_DISABLE_CTX) && !defined(NIC_THOR2) __shared__ internode::shmem_ctx_t ctx; if constexpr (kMultinode) EP_DEVICE_ASSERT(internode::shmem_wg_ctx_create(&ctx) == 0 or ctx == ROCSHMEM_CTX_INVALID); @@ -247,7 +247,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, #if defined(ROCM_DISABLE_CTX) internode::shmemx_int8_put_nbi_warp(reinterpret_cast(dst_ptr), reinterpret_cast(src_ptr), num_bytes_per_msg, dst_rank); #else //DISABLE_CTX +#if defined(NIC_THOR2) + internode::shmem_ctx_schar_put_nbi_warp(rocshmem_ctx_array[dst_expert_local_idx], reinterpret_cast(dst_ptr), reinterpret_cast(src_ptr), num_bytes_per_msg, dst_rank); +#else internode::shmem_ctx_schar_put_nbi_warp(ctx, reinterpret_cast(dst_ptr), reinterpret_cast(src_ptr), num_bytes_per_msg, dst_rank); +#endif #endif } #else //USE_ROCM @@ -311,7 +315,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, } } -#if 
defined(NIC_IO) || defined(NIC_THOR2) +#if defined(NIC_IO) if constexpr (kMultinode){ if (thread_id == 0 ){ #if defined(ROCM_DISABLE_CTX) @@ -345,8 +349,12 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, __threadfence_system(); #if defined(ROCM_DISABLE_CTX) internode::shmem_long_atomic_add( rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); +#else +#if defined(NIC_THOR2) + internode::shmem_ctx_long_atomic_add(rocshmem_ctx_array[dst_expert_local_idx], rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); #else internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); +#endif #endif } #else //CUDA @@ -355,7 +363,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, } else { st_na_release(reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1); } -#if defined(NIC_IO) || defined(NIC_THOR2) +#if defined(NIC_IO) if constexpr (kMultinode){ if (thread_id == 0 ){ #if defined(ROCM_DISABLE_CTX) @@ -378,7 +386,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, // Receiving phase LOW_LATENCY_DISPATCH_RECV: if ((phases & LOW_LATENCY_RECV_PHASE) == 0){ -#if !defined(ROCM_DISABLE_CTX) +#if !defined(ROCM_DISABLE_CTX) && !defined(NIC_THOR2) if constexpr (kMultinode) internode::shmem_wg_ctx_destroy(&ctx); #endif @@ -480,7 +488,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, } } } -#if !defined(ROCM_DISABLE_CTX) +#if !defined(ROCM_DISABLE_CTX) && !defined(NIC_THOR2) if constexpr (kMultinode) internode::shmem_wg_ctx_destroy(&ctx); #endif @@ -601,7 +609,7 @@ combine(void* combined_x, int num_experts, int rank, int num_ranks, int phases, bool zero_copy) { -#if !defined(ROCM_DISABLE_CTX) +#if !defined(ROCM_DISABLE_CTX) && !defined(NIC_THOR2) __shared__ internode::shmem_ctx_t ctx; if constexpr(kMultinode) 
EP_DEVICE_ASSERT(internode::shmem_wg_ctx_create(&ctx) == 0 or ctx == ROCSHMEM_CTX_INVALID); @@ -679,29 +687,36 @@ combine(void* combined_x, const auto buf_int4_ptr = reinterpret_cast(buf_ptr); if (not zero_copy) UNROLLED_WARP_COPY(4, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); - + //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(gpu_bfloat16_t), dst_rank, local_expert_idx, lane_id, token_idx - offset); if constexpr (!kMultinode){ internode::shmemx_int8_put_nbi_warp(reinterpret_cast(dst_ptr), reinterpret_cast(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank); }else{ #if defined(ROCM_DISABLE_CTX) internode::shmemx_int8_put_nbi_warp(reinterpret_cast(dst_ptr), reinterpret_cast(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank); +#else +#if defined(NIC_THOR2) + internode::shmem_ctx_schar_put_nbi_warp(rocshmem_ctx_array[local_expert_idx],reinterpret_cast(dst_ptr), reinterpret_cast(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank); #else internode::shmem_ctx_schar_put_nbi_warp(ctx,reinterpret_cast(dst_ptr), reinterpret_cast(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank); +#endif #endif } } } +#if !defined(NIC_THOR2) if constexpr (kMultinode){ - if (sub_warp_id == 0) -#if defined(ROCM_DISABLE_CTX) + if (sub_warp_id == 0) { +#if defined(ROCM_DISABLE_CTX) /* ctx is declared only when ROCM_DISABLE_CTX is undefined */ internode::shmem_fence(); #else internode::shmem_ctx_quiet(ctx); #endif + } } +#endif // Put finishing flag EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); @@ -725,8 +740,12 @@ combine(void* combined_x, __threadfence_system(); #if defined(ROCM_DISABLE_CTX) internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank); +#else +#if defined(NIC_THOR2) + internode::shmem_ctx_long_atomic_add(rocshmem_ctx_array[local_expert_idx], rdma_recv_flag + global_expert_idx, 1, dst_rank); #else internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank); +#endif #endif //DISABLE_CTX } #else @@ 
-742,7 +761,7 @@ combine(void* combined_x, // Receiving phase LOW_LATENCY_COMBINE_RECV: if ((phases & LOW_LATENCY_RECV_PHASE) == 0){ -#if !defined(ROCM_DISABLE_CTX) +#if !defined(ROCM_DISABLE_CTX) && !defined(NIC_THOR2) if constexpr (kMultinode) internode::shmem_wg_ctx_destroy(&ctx); #endif @@ -800,7 +819,7 @@ combine(void* combined_x, (reinterpret_cast(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4; } } -#if !defined(ROCM_DISABLE_CTX) +#if !defined(ROCM_DISABLE_CTX) && !defined(NIC_THOR2) if constexpr (kMultinode) internode::shmem_wg_ctx_destroy(&ctx); #endif