From a9e7e160f226cf43bcaf7bfe6b7015b0a84524e8 Mon Sep 17 00:00:00 2001 From: georgebisbas Date: Thu, 21 May 2026 17:22:12 +0200 Subject: [PATCH] Add: AllGather and ReduceScatter distributed L3 examples MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add two new symmetric collective communication examples modelled on allreduce_distributed, plus a shared comm_utils.h header. New files: examples/workers/l3/common/comm_utils.h Shared CommRemotePtr template extracted from the allreduce pattern. New kernels include it as "common/comm_utils.h"; existing kernels are not modified. examples/workers/l3/allgather_distributed/ 3-phase kernel: stage-in → barrier → gather. Input: COUNT_PER_RANK=64 floats/rank. Output: nranks*64 floats (rank-ordered concatenation, same on every rank). Golden: output[r*C+i] = r*100 + i (closed-form, no reference run). examples/workers/l3/reduce_scatter_distributed/ 4-phase kernel: stage-in N chunks → barrier → reduce my chunk → stage-out. Input: nranks*64 floats/rank. Output: 64 floats/rank (rank-specific shard). Golden per dest: nranks*(dest*C+j) + 100*nranks*(nranks-1)/2. Scratch window is nranks-dependent; computed at runtime in run(). Both examples follow the orch.allocate_domain() API, the same Worker/orch_fn/TaskArgs structure as allreduce_distributed, and include pytest fixtures mirroring test_allreduce.py (a2a3sim/a2a3/a5sim, n_devices 2 and 4). --- .../l3/allgather_distributed/__init__.py | 9 + .../kernels/aiv/allgather_kernel.cpp | 140 +++++++++++ .../kernels/orchestration/allgather_orch.cpp | 49 ++++ .../workers/l3/allgather_distributed/main.py | 220 +++++++++++++++++ .../allgather_distributed/test_allgather.py | 28 +++ .../l3/reduce_scatter_distributed/__init__.py | 9 + .../kernels/aiv/reduce_scatter_kernel.cpp | 163 ++++++++++++ .../orchestration/reduce_scatter_orch.cpp | 49 ++++ .../l3/reduce_scatter_distributed/main.py | 233 ++++++++++++++++++ .../test_reduce_scatter.py | 28 +++ 10 files changed, 928 insertions(+) create mode 100644 examples/workers/l3/allgather_distributed/__init__.py create mode 100644 examples/workers/l3/allgather_distributed/kernels/aiv/allgather_kernel.cpp create mode 100644 examples/workers/l3/allgather_distributed/kernels/orchestration/allgather_orch.cpp create mode 100644 examples/workers/l3/allgather_distributed/main.py create mode 100644 examples/workers/l3/allgather_distributed/test_allgather.py create mode 100644 examples/workers/l3/reduce_scatter_distributed/__init__.py create mode 100644 examples/workers/l3/reduce_scatter_distributed/kernels/aiv/reduce_scatter_kernel.cpp create mode 100644 examples/workers/l3/reduce_scatter_distributed/kernels/orchestration/reduce_scatter_orch.cpp create mode 100644 examples/workers/l3/reduce_scatter_distributed/main.py create mode 100644 examples/workers/l3/reduce_scatter_distributed/test_reduce_scatter.py diff --git a/examples/workers/l3/allgather_distributed/__init__.py b/examples/workers/l3/allgather_distributed/__init__.py new file mode 100644 index 000000000..25708baec --- /dev/null +++ b/examples/workers/l3/allgather_distributed/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Package marker so ``test_*.py`` can do ``from .main import run``.""" diff --git a/examples/workers/l3/allgather_distributed/kernels/aiv/allgather_kernel.cpp b/examples/workers/l3/allgather_distributed/kernels/aiv/allgather_kernel.cpp new file mode 100644 index 000000000..20668e22b --- /dev/null +++ b/examples/workers/l3/allgather_distributed/kernels/aiv/allgather_kernel.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +/** + * AllGather kernel — symmetric, 3-phase, HCCL-window scratch pattern. + * + * Phase 1 (stage-in): input[0..COUNT_PER_RANK) → my scratch slot (in window) + * Phase 2 (barrier): signal matrix + TWAIT cross-rank sync + * Phase 3 (gather): for r in 0..nranks-1: TLOAD(peer_scratch), TSTORE(output[r*COUNT_PER_RANK]) + * + * Every rank produces the identical full-gather output: the concatenation of + * all ranks' inputs in rank order. The signal area lives at the tail of + * scratch (COUNT_PER_RANK floats, then nranks int32 slots). + * + * args layout: + * tensor(0) = input COUNT_PER_RANK floats (INPUT) + * tensor(1) = output nranks*COUNT_PER_RANK floats (OUTPUT_EXISTING) + * tensor(2) = scratch HCCL window slot (INOUT) + * scalar(0) = nranks + * scalar(1) = CommContext device pointer + */ + +#include +#include +#include "pto/comm/comm_types.hpp" +#include "pto/comm/pto_comm_inst.hpp" +#include "platform_comm/comm_context.h" +#include "tensor.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +AICORE inline __gm__ T *CommRemotePtr(__gm__ CommContext *ctx, __gm__ T *localPtr, int pe) { + uint64_t localBase = ctx->windowsIn[ctx->rankId]; + uint64_t offset = (uint64_t)localPtr - localBase; + return (__gm__ T *)(ctx->windowsIn[pe] + offset); +} + +static constexpr size_t COUNT_PER_RANK = 64; +static constexpr int kMaxSupportedRanks = 16; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *input_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *output_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *scratch_tensor = reinterpret_cast<__gm__ Tensor *>(args[2]); + int nranks = static_cast(args[3]); + __gm__ CommContext *commCtx = reinterpret_cast<__gm__ CommContext *>(args[4]); + + __gm__ float *input = reinterpret_cast<__gm__ float *>(input_tensor->buffer.addr) + input_tensor->start_offset; + __gm__ float *output = reinterpret_cast<__gm__ float *>(output_tensor->buffer.addr) + output_tensor->start_offset; + __gm__ float *scratch = + reinterpret_cast<__gm__ float *>(scratch_tensor->buffer.addr) + scratch_tensor->start_offset; + // Signal area: nranks int32 slots at the tail of the scratch buffer. + // Peer r writes into my_rank's signal[r] when its stage-in is done. + __gm__ int32_t *signal_base = reinterpret_cast<__gm__ int32_t *>(scratch + COUNT_PER_RANK); + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + using TileData = pto::Tile; + + int my_rank = static_cast(commCtx->rankId); + + if (nranks <= 0 || nranks > kMaxSupportedRanks) { + pipe_barrier(PIPE_ALL); + return; + } + + ShapeDyn shape(1, 1, 1, 1, COUNT_PER_RANK); + StrideDyn stride(COUNT_PER_RANK, COUNT_PER_RANK, COUNT_PER_RANK, COUNT_PER_RANK, 1); + + TileData stageTile(1, COUNT_PER_RANK); + TileData recvTile(1, COUNT_PER_RANK); + TASSIGN(stageTile, 0x0); + TASSIGN(recvTile, 0x10000); + + Global inputG(input, shape, stride); + Global scratchG(scratch, shape, stride); + + // ------------------------------------------------------------------ + // Phase 1: stage-in — copy local input into my scratch slot (HCCL + // window) so that all peers can TLOAD it in Phase 3. + // ------------------------------------------------------------------ + TLOAD(stageTile, inputG); + set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); + TSTORE(scratchG, stageTile); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + pipe_barrier(PIPE_ALL); + + // ------------------------------------------------------------------ + // Phase 2: device barrier — notify every peer that stage-in is done, + // then wait until every peer has notified us. + // ------------------------------------------------------------------ + for (int peer = 0; peer < nranks; ++peer) { + if (peer == my_rank) continue; + __gm__ int32_t *remote_signal = CommRemotePtr(commCtx, signal_base + my_rank, peer); + pto::comm::Signal sig(remote_signal); + pto::comm::TNOTIFY(sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd); + } + for (int peer = 0; peer < nranks; ++peer) { + if (peer == my_rank) continue; + pto::comm::Signal sig(signal_base + peer); + pto::comm::TWAIT(sig, (int32_t)1, pto::comm::WaitCmp::GE); + } + pipe_barrier(PIPE_ALL); + + // ------------------------------------------------------------------ + // Phase 3: gather — read each rank's scratch slot and write it into + // the corresponding slice of the output tensor. + // CommRemotePtr with pe==my_rank returns localPtr unchanged, so the + // self-read goes through the same code path as the remote reads. + // ------------------------------------------------------------------ + for (int r = 0; r < nranks; ++r) { + __gm__ float *remote_scratch = CommRemotePtr(commCtx, scratch, r); + Global remoteG(remote_scratch, shape, stride); + Global outputSlotG(output + r * COUNT_PER_RANK, shape, stride); + TLOAD(recvTile, remoteG); + set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); + TSTORE(outputSlotG, recvTile); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + } + pipe_barrier(PIPE_ALL); +} diff --git a/examples/workers/l3/allgather_distributed/kernels/orchestration/allgather_orch.cpp b/examples/workers/l3/allgather_distributed/kernels/orchestration/allgather_orch.cpp new file mode 100644 index 000000000..0051c93a2 --- /dev/null +++ b/examples/workers/l3/allgather_distributed/kernels/orchestration/allgather_orch.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +/** + * AllGather orchestration shim. + * + * tensor(0) input INPUT (COUNT_PER_RANK floats) + * tensor(1) output OUTPUT_EXISTING (nranks*COUNT_PER_RANK floats) + * tensor(2) scratch INOUT (HCCL window slot; written in phase 1, read in phase 3) + * scalar(0) nranks + * scalar(1) CommContext device pointer + */ + +#include + +#include "pto_orchestration_api.h" + +extern "C" { + +__attribute__((visibility("default"))) PTO2OrchestrationConfig +allgather_orchestration_config(const ChipStorageTaskArgs &orch_args) { + (void)orch_args; + return PTO2OrchestrationConfig{ + .expected_arg_count = 5, // 3 tensors + 2 scalars + }; +} + +__attribute__((visibility("default"))) void allgather_orchestration(const ChipStorageTaskArgs &orch_args) { + Tensor input = from_tensor_arg(orch_args.tensor(0)); + Tensor output = from_tensor_arg(orch_args.tensor(1)); + Tensor scratch = from_tensor_arg(orch_args.tensor(2)); + + Arg params; + params.add_input(input); + params.add_output(output); + params.add_inout(scratch); + params.add_scalar(orch_args.scalar(0)); // nranks + params.add_scalar(orch_args.scalar(1)); // CommContext + rt_submit_aiv_task(0, params); +} + +} // extern "C" diff --git a/examples/workers/l3/allgather_distributed/main.py b/examples/workers/l3/allgather_distributed/main.py new file mode 100644 index 000000000..5000e2126 --- /dev/null +++ b/examples/workers/l3/allgather_distributed/main.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""End-to-end distributed allgather — symmetric 3-phase pattern. + +Each rank owns a private input of COUNT_PER_RANK floats. After the allgather +every rank holds the full concatenation of all ranks' inputs in rank order: + + Phase 1 stage-in input → my scratch slot (HCCL window) + Phase 2 device barrier signal matrix cross-rank sync via TNOTIFY/TWAIT + Phase 3 gather for r in 0..N-1: TLOAD(rank r's scratch) → TSTORE(output[r*COUNT_PER_RANK]) + +Run: + python examples/workers/l3/allgather_distributed/main.py -p a2a3sim -d 0-1 + +""" + +from __future__ import annotations + +import argparse +import os +import sys + +os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE") + +import torch # noqa: E402 +from simpler.task_interface import ( # noqa: E402 + ArgDirection, + CallConfig, + ChipCallable, + CommBufferSpec, + ContinuousTensor, + CoreCallable, + DataType, + TaskArgs, + TensorArgType, +) +from simpler.worker import Worker # noqa: E402 + +from simpler_setup.elf_parser import extract_text_section # noqa: E402 +from simpler_setup.kernel_compiler import KernelCompiler # noqa: E402 +from simpler_setup.pto_isa import ensure_pto_isa_root # noqa: E402 +from simpler_setup.torch_interop import make_tensor_arg # noqa: E402 + +HERE = os.path.dirname(os.path.abspath(__file__)) + +# Must match COUNT_PER_RANK in kernels/aiv/allgather_kernel.cpp. +COUNT_PER_RANK = 64 +DTYPE_NBYTES = 4 # float32 +BUFFER_NBYTES = COUNT_PER_RANK * DTYPE_NBYTES # 256 B per rank's scratch slot +# Signal tail: one int32 slot per rank, bounded by kMaxSupportedRanks. +SIGNAL_TAIL_NBYTES = 16 * 4 # 64 B +SCRATCH_NBYTES = BUFFER_NBYTES + SIGNAL_TAIL_NBYTES # 320 B + + +def parse_device_range(spec: str) -> list[int]: + if "-" in spec: + lo, hi = (int(x) for x in spec.split("-")) + ids = list(range(lo, hi + 1)) + else: + ids = [int(spec)] + if not (2 <= len(ids) <= 16): + raise ValueError(f"allgather_distributed needs between 2 and 16 devices, got {len(ids)} ({ids})") + return ids + + +def build_chip_callable(platform: str, pto_isa_commit: str | None) -> ChipCallable: + """Compile the AIV allgather kernel + its C++ orchestration shim.""" + kc = KernelCompiler(platform=platform) + runtime = "tensormap_and_ringbuffer" + pto_isa_root = ensure_pto_isa_root(commit=pto_isa_commit, clone_protocol="https") + include_dirs = kc.get_orchestration_include_dirs(runtime) + + # src/common — for platform_comm/comm_context.h + kernel_include_dirs = list(include_dirs) + [ + str(kc.project_root / "src" / "common"), + ] + kernel_bytes = kc.compile_incore( + source_path=os.path.join(HERE, "kernels/aiv/allgather_kernel.cpp"), + core_type="aiv", + pto_isa_root=pto_isa_root, + extra_include_dirs=kernel_include_dirs, + ) + if not platform.endswith("sim"): + kernel_bytes = extract_text_section(kernel_bytes) + + orch_bytes = kc.compile_orchestration( + runtime_name=runtime, + source_path=os.path.join(HERE, "kernels/orchestration/allgather_orch.cpp"), + ) + core_callable = CoreCallable.build( + signature=[ArgDirection.IN, ArgDirection.OUT, ArgDirection.INOUT], + binary=kernel_bytes, + ) + return ChipCallable.build( + signature=[ArgDirection.IN, ArgDirection.OUT, ArgDirection.INOUT], + func_name="allgather_orchestration", + config_name="allgather_orchestration_config", + binary=orch_bytes, + children=[(0, core_callable)], + ) + + +def expected_output(nranks: int) -> list[float]: + """Rank-ordered concatenation of all inputs: output[r*C+i] = r*100 + i.""" + return [float(r * 100 + i) for r in range(nranks) for i in range(COUNT_PER_RANK)] + + +def run( + device_ids: list[int], + platform: str = "a2a3", + pto_isa_commit: str | None = None, + build: bool = False, +) -> int: + """Core logic — callable from both CLI and pytest.""" + nranks = len(device_ids) + window_size = max(SCRATCH_NBYTES, 4 * 1024) + + print(f"[allgather] platform={platform} devices={device_ids} nranks={nranks}") + + host_inputs = [ + torch.tensor([i + rank * 100 for i in range(COUNT_PER_RANK)], dtype=torch.float32).share_memory_() + for rank in range(nranks) + ] + host_outputs = [torch.zeros(nranks * COUNT_PER_RANK, dtype=torch.float32).share_memory_() for _ in range(nranks)] + + print("[allgather] compiling kernels...") + chip_callable = build_chip_callable(platform, pto_isa_commit) + + worker = Worker( + level=3, + platform=platform, + runtime="tensormap_and_ringbuffer", + device_ids=device_ids, + num_sub_workers=0, + build=build, + ) + chip_cid = worker.register(chip_callable) + + try: + print("[allgather] init worker (forks chip children; base comm is lazy)...") + worker.init() + + def orch_fn(orch, _args, cfg): + with orch.allocate_domain( + name="default", + workers=list(range(nranks)), + window_size=window_size, + buffers=[CommBufferSpec(name="scratch", dtype="float32", count=COUNT_PER_RANK, nbytes=SCRATCH_NBYTES)], + ) as handle: + for i in range(nranks): + domain = handle[i] + print( + f"[allgather] chip {i}: rank={domain.domain_rank}/{domain.domain_size} " + f"window=[0x{domain.local_window_base:x} +{domain.actual_window_size}B] " + f"scratch=0x{domain.buffer_ptrs['scratch']:x}" + ) + chip_args = TaskArgs() + chip_args.add_tensor(make_tensor_arg(host_inputs[i]), TensorArgType.INPUT) + chip_args.add_tensor(make_tensor_arg(host_outputs[i]), TensorArgType.OUTPUT_EXISTING) + chip_args.add_tensor( + ContinuousTensor.make( + data=domain.buffer_ptrs["scratch"], + shapes=(COUNT_PER_RANK,), + dtype=DataType.FLOAT32, + child_memory=True, + ), + TensorArgType.INOUT, + ) + chip_args.add_scalar(domain.domain_size) + chip_args.add_scalar(domain.device_ctx) + orch.submit_next_level(chip_cid, chip_args, cfg, worker=i) + + print(f"[allgather] running {nranks}-chip allgather DAG...") + worker.run(orch_fn, args=None, config=CallConfig()) + + expected = torch.tensor(expected_output(nranks), dtype=torch.float32) + ok = True + for i in range(nranks): + max_diff = float(torch.max(torch.abs(host_outputs[i] - expected))) + print(f"[allgather] chip {i}: max |out - expected| = {max_diff:.3e}") + if max_diff > 1e-3: + ok = False + for j in range(min(4, nranks * COUNT_PER_RANK)): + print(f" output[{j}]={float(host_outputs[i][j])!r} expected={float(expected[j])!r}") + + if not ok: + print("[allgather] golden check FAILED") + return 1 + print("[allgather] all ranks matched golden ✅") + return 0 + finally: + worker.close() + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("-p", "--platform", default="a2a3", help="Platform backend, e.g. a2a3 or a2a3sim.") + parser.add_argument( + "-d", "--device", default="0-1", help="Device range, e.g. '0-1' or '0-3'. 2 to 16 chips required." + ) + parser.add_argument( + "--build", action="store_true", help="Rebuild runtime from source instead of using cached libs." + ) + parser.add_argument("--pto-isa-commit", default=None, help="Optional PTO ISA commit/tag to fetch before compiling.") + cli = parser.parse_args() + + return run( + parse_device_range(cli.device), platform=cli.platform, pto_isa_commit=cli.pto_isa_commit, build=cli.build + ) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/workers/l3/allgather_distributed/test_allgather.py b/examples/workers/l3/allgather_distributed/test_allgather.py new file mode 100644 index 000000000..b9152e5b4 --- /dev/null +++ b/examples/workers/l3/allgather_distributed/test_allgather.py @@ -0,0 +1,28 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""ST for examples/workers/l3/allgather_distributed.""" + +import pytest + +from .main import run + + +@pytest.mark.platforms(["a2a3sim", "a2a3", "a5sim"]) +@pytest.mark.runtime("tensormap_and_ringbuffer") +@pytest.mark.parametrize( + "n_devices", + [ + pytest.param(2, marks=pytest.mark.device_count(2)), + pytest.param(4, marks=pytest.mark.device_count(4)), + ], +) +def test_allgather_distributed(st_platform, st_device_ids, n_devices): + assert len(st_device_ids) == n_devices + rc = run([int(d) for d in st_device_ids], platform=st_platform) + assert rc == 0 diff --git a/examples/workers/l3/reduce_scatter_distributed/__init__.py b/examples/workers/l3/reduce_scatter_distributed/__init__.py new file mode 100644 index 000000000..25708baec --- /dev/null +++ b/examples/workers/l3/reduce_scatter_distributed/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Package marker so ``test_*.py`` can do ``from .main import run``.""" diff --git a/examples/workers/l3/reduce_scatter_distributed/kernels/aiv/reduce_scatter_kernel.cpp b/examples/workers/l3/reduce_scatter_distributed/kernels/aiv/reduce_scatter_kernel.cpp new file mode 100644 index 000000000..7e49e67f4 --- /dev/null +++ b/examples/workers/l3/reduce_scatter_distributed/kernels/aiv/reduce_scatter_kernel.cpp @@ -0,0 +1,163 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +/** + * ReduceScatter kernel — symmetric, 4-phase, HCCL-window scratch pattern. + * + * Phase 1 (stage-in): for chunk in 0..nranks-1: input[chunk*C..(chunk+1)*C) → scratch[chunk*C..) + * Phase 2 (barrier): signal matrix + TWAIT cross-rank sync + * Phase 3 (reduce): acc = scratch[my_rank*C]; for each peer: acc += peer's scratch[my_rank*C] + * Phase 4 (stage-out): TSTORE(output, acc) + * + * Each rank is responsible for reducing the chunk at index my_rank across all + * ranks. The signal area lives at the tail of the scratch buffer, after the + * full nranks*COUNT_PER_RANK float staging area. + * + * args layout: + * tensor(0) = input nranks*COUNT_PER_RANK floats (INPUT) + * tensor(1) = output COUNT_PER_RANK floats (OUTPUT_EXISTING) + * tensor(2) = scratch HCCL window slot (INOUT) + * scalar(0) = nranks + * scalar(1) = CommContext device pointer + */ + +#include +#include +#include "pto/comm/comm_types.hpp" +#include "pto/comm/pto_comm_inst.hpp" +#include "platform_comm/comm_context.h" +#include "tensor.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +AICORE inline __gm__ T *CommRemotePtr(__gm__ CommContext *ctx, __gm__ T *localPtr, int pe) { + uint64_t localBase = ctx->windowsIn[ctx->rankId]; + uint64_t offset = (uint64_t)localPtr - localBase; + return (__gm__ T *)(ctx->windowsIn[pe] + offset); +} + +static constexpr size_t COUNT_PER_RANK = 64; +static constexpr int kMaxSupportedRanks = 16; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *input_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *output_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *scratch_tensor = reinterpret_cast<__gm__ Tensor *>(args[2]); + int nranks = static_cast(args[3]); + __gm__ CommContext *commCtx = reinterpret_cast<__gm__ CommContext *>(args[4]); + + __gm__ float *input = reinterpret_cast<__gm__ float *>(input_tensor->buffer.addr) + input_tensor->start_offset; + __gm__ float *output = reinterpret_cast<__gm__ float *>(output_tensor->buffer.addr) + output_tensor->start_offset; + __gm__ float *scratch = + reinterpret_cast<__gm__ float *>(scratch_tensor->buffer.addr) + scratch_tensor->start_offset; + // Signal area: nranks int32 slots at the tail of the staging area. + // Peer r writes into my_rank's signal[r] when its stage-in is done. + + using ShapeDyn = pto::Shape; + using StrideDyn = pto::Stride; + using Global = pto::GlobalTensor; + using TileData = pto::Tile; + + int my_rank = static_cast(commCtx->rankId); + + if (nranks <= 0 || nranks > kMaxSupportedRanks) { + pipe_barrier(PIPE_ALL); + return; + } + + // signal_base follows the nranks * COUNT_PER_RANK float staging region. + __gm__ int32_t *signal_base = reinterpret_cast<__gm__ int32_t *>(scratch + nranks * COUNT_PER_RANK); + + ShapeDyn shape(1, 1, 1, 1, COUNT_PER_RANK); + StrideDyn stride(COUNT_PER_RANK, COUNT_PER_RANK, COUNT_PER_RANK, COUNT_PER_RANK, 1); + + TileData stageTile(1, COUNT_PER_RANK); + TileData accTile(1, COUNT_PER_RANK); + TileData recvTile(1, COUNT_PER_RANK); + TASSIGN(stageTile, 0x0); + TASSIGN(accTile, 0x10000); + TASSIGN(recvTile, 0x20000); + + Global outputG(output, shape, stride); + + // ------------------------------------------------------------------ + // Phase 1: stage-in — copy all N input chunks into scratch so that + // every peer can read our entire input in Phase 3. + // ------------------------------------------------------------------ + for (int chunk = 0; chunk < nranks; ++chunk) { + Global inputChunkG(input + chunk * COUNT_PER_RANK, shape, stride); + Global scratchChunkG(scratch + chunk * COUNT_PER_RANK, shape, stride); + TLOAD(stageTile, inputChunkG); + set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0); + TSTORE(scratchChunkG, stageTile); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + } + pipe_barrier(PIPE_ALL); + + // ------------------------------------------------------------------ + // Phase 2: device barrier — notify every peer that stage-in is done, + // then wait until every peer has notified us. + // ------------------------------------------------------------------ + for (int peer = 0; peer < nranks; ++peer) { + if (peer == my_rank) continue; + __gm__ int32_t *remote_signal = CommRemotePtr(commCtx, signal_base + my_rank, peer); + pto::comm::Signal sig(remote_signal); + pto::comm::TNOTIFY(sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd); + } + for (int peer = 0; peer < nranks; ++peer) { + if (peer == my_rank) continue; + pto::comm::Signal sig(signal_base + peer); + pto::comm::TWAIT(sig, (int32_t)1, pto::comm::WaitCmp::GE); + } + pipe_barrier(PIPE_ALL); + + // ------------------------------------------------------------------ + // Phase 3: reduce — sum chunk my_rank from every rank's scratch into + // accTile. Start with my own copy, then add all peers via + // CommRemotePtr (same pattern as allreduce Phase 3). + // ------------------------------------------------------------------ + Global myChunkG(scratch + my_rank * COUNT_PER_RANK, shape, stride); + TLOAD(accTile, myChunkG); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + for (int peer = 0; peer < nranks; ++peer) { + if (peer == my_rank) continue; + __gm__ float *remote_chunk = CommRemotePtr(commCtx, scratch + my_rank * COUNT_PER_RANK, peer); + Global remoteG(remote_chunk, shape, stride); + TLOAD(recvTile, remoteG); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TADD(accTile, accTile, recvTile); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + } + + // ------------------------------------------------------------------ + // Phase 4: stage-out — write the reduced accumulator into the local + // output (plain device mem), no remote traffic involved. + // ------------------------------------------------------------------ + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(outputG, accTile); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/workers/l3/reduce_scatter_distributed/kernels/orchestration/reduce_scatter_orch.cpp b/examples/workers/l3/reduce_scatter_distributed/kernels/orchestration/reduce_scatter_orch.cpp new file mode 100644 index 000000000..e713d0afb --- /dev/null +++ b/examples/workers/l3/reduce_scatter_distributed/kernels/orchestration/reduce_scatter_orch.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +/** + * ReduceScatter orchestration shim. + * + * tensor(0) input INPUT (nranks*COUNT_PER_RANK floats) + * tensor(1) output OUTPUT_EXISTING (COUNT_PER_RANK floats) + * tensor(2) scratch INOUT (HCCL window slot; written in phase 1, read in phase 3) + * scalar(0) nranks + * scalar(1) CommContext device pointer + */ + +#include + +#include "pto_orchestration_api.h" + +extern "C" { + +__attribute__((visibility("default"))) PTO2OrchestrationConfig +reduce_scatter_orchestration_config(const ChipStorageTaskArgs &orch_args) { + (void)orch_args; + return PTO2OrchestrationConfig{ + .expected_arg_count = 5, // 3 tensors + 2 scalars + }; +} + +__attribute__((visibility("default"))) void reduce_scatter_orchestration(const ChipStorageTaskArgs &orch_args) { + Tensor input = from_tensor_arg(orch_args.tensor(0)); + Tensor output = from_tensor_arg(orch_args.tensor(1)); + Tensor scratch = from_tensor_arg(orch_args.tensor(2)); + + Arg params; + params.add_input(input); + params.add_output(output); + params.add_inout(scratch); + params.add_scalar(orch_args.scalar(0)); // nranks + params.add_scalar(orch_args.scalar(1)); // CommContext + rt_submit_aiv_task(0, params); +} + +} // extern "C" diff --git a/examples/workers/l3/reduce_scatter_distributed/main.py b/examples/workers/l3/reduce_scatter_distributed/main.py new file mode 100644 index 000000000..21ea0d392 --- /dev/null +++ b/examples/workers/l3/reduce_scatter_distributed/main.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""End-to-end distributed reduce-scatter — symmetric 4-phase pattern. + +Each rank owns a private input of nranks*COUNT_PER_RANK floats (N equal-sized +chunks). After reduce-scatter, rank r holds the element-wise sum of chunk r +from every rank: + + Phase 1 stage-in all N input chunks → scratch slots in HCCL window + Phase 2 device barrier signal matrix cross-rank sync via TNOTIFY/TWAIT + Phase 3 reduce acc = my scratch[my_rank*C]; acc += peer's scratch[my_rank*C] for each peer + Phase 4 stage-out TSTORE acc → output + +Run: + python examples/workers/l3/reduce_scatter_distributed/main.py -p a2a3sim -d 0-1 + +""" + +from __future__ import annotations + +import argparse +import os +import sys + +os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE") + +import torch # noqa: E402 +from simpler.task_interface import ( # noqa: E402 + ArgDirection, + CallConfig, + ChipCallable, + CommBufferSpec, + ContinuousTensor, + CoreCallable, + DataType, + TaskArgs, + TensorArgType, +) +from simpler.worker import Worker # noqa: E402 + +from simpler_setup.elf_parser import extract_text_section # noqa: E402 +from simpler_setup.kernel_compiler import KernelCompiler # noqa: E402 +from simpler_setup.pto_isa import ensure_pto_isa_root # noqa: E402 +from simpler_setup.torch_interop import make_tensor_arg # noqa: E402 + +HERE = os.path.dirname(os.path.abspath(__file__)) + +# Must match COUNT_PER_RANK in kernels/aiv/reduce_scatter_kernel.cpp. +COUNT_PER_RANK = 64 +DTYPE_NBYTES = 4 # float32 +# Signal tail: one int32 slot per rank, bounded by kMaxSupportedRanks. +SIGNAL_TAIL_NBYTES = 16 * 4 # 64 B +# NOTE: the full scratch size depends on nranks (staging area = nranks * COUNT_PER_RANK floats). +# It is computed inside run() once nranks is known. + + +def parse_device_range(spec: str) -> list[int]: + if "-" in spec: + lo, hi = (int(x) for x in spec.split("-")) + ids = list(range(lo, hi + 1)) + else: + ids = [int(spec)] + if not (2 <= len(ids) <= 16): + raise ValueError(f"reduce_scatter_distributed needs between 2 and 16 devices, got {len(ids)} ({ids})") + return ids + + +def build_chip_callable(platform: str, pto_isa_commit: str | None) -> ChipCallable: + """Compile the AIV reduce-scatter kernel + its C++ orchestration shim.""" + kc = KernelCompiler(platform=platform) + runtime = "tensormap_and_ringbuffer" + pto_isa_root = ensure_pto_isa_root(commit=pto_isa_commit, clone_protocol="https") + include_dirs = kc.get_orchestration_include_dirs(runtime) + + # src/common — for platform_comm/comm_context.h + kernel_include_dirs = list(include_dirs) + [ + str(kc.project_root / "src" / "common"), + ] + kernel_bytes = kc.compile_incore( + source_path=os.path.join(HERE, "kernels/aiv/reduce_scatter_kernel.cpp"), + core_type="aiv", + pto_isa_root=pto_isa_root, + extra_include_dirs=kernel_include_dirs, + ) + if not platform.endswith("sim"): + kernel_bytes = extract_text_section(kernel_bytes) + + orch_bytes = kc.compile_orchestration( + runtime_name=runtime, + source_path=os.path.join(HERE, "kernels/orchestration/reduce_scatter_orch.cpp"), + ) + core_callable = CoreCallable.build( + signature=[ArgDirection.IN, ArgDirection.OUT, ArgDirection.INOUT], + binary=kernel_bytes, + ) + return ChipCallable.build( + signature=[ArgDirection.IN, ArgDirection.OUT, ArgDirection.INOUT], + func_name="reduce_scatter_orchestration", + config_name="reduce_scatter_orchestration_config", + binary=orch_bytes, + children=[(0, core_callable)], + ) + + +def expected_output(nranks: int, dest: int) -> list[float]: + """output[j] = sum_r input_r[dest*C + j] = nranks*(dest*C+j) + 100*nranks*(nranks-1)/2.""" + return [ + float(nranks * (dest * COUNT_PER_RANK + j) + 100 * nranks * (nranks - 1) // 2) for j in range(COUNT_PER_RANK) + ] + + +def run( + device_ids: list[int], + platform: str = "a2a3", + pto_isa_commit: str | None = None, + build: bool = False, +) -> int: + """Core logic — callable from both CLI and pytest.""" + nranks = len(device_ids) + # Scratch = nranks * COUNT_PER_RANK floats (staging) + signal tail. + scratch_nbytes = nranks * COUNT_PER_RANK * DTYPE_NBYTES + SIGNAL_TAIL_NBYTES + window_size = max(scratch_nbytes, 4 * 1024) + + print(f"[reduce_scatter] platform={platform} devices={device_ids} nranks={nranks}") + + host_inputs = [ + torch.tensor([i + rank * 100 for i in range(nranks * COUNT_PER_RANK)], dtype=torch.float32).share_memory_() + for rank in range(nranks) + ] + host_outputs = [torch.zeros(COUNT_PER_RANK, dtype=torch.float32).share_memory_() for _ in range(nranks)] + + print("[reduce_scatter] compiling kernels...") + chip_callable = build_chip_callable(platform, pto_isa_commit) + + worker = Worker( + level=3, + platform=platform, + runtime="tensormap_and_ringbuffer", + device_ids=device_ids, + num_sub_workers=0, + build=build, + ) + chip_cid = worker.register(chip_callable) + + try: + print("[reduce_scatter] init worker (forks chip children; base comm is lazy)...") + worker.init() + + def orch_fn(orch, _args, cfg): + with orch.allocate_domain( + name="default", + workers=list(range(nranks)), + window_size=window_size, + buffers=[ + CommBufferSpec( + name="scratch", + dtype="float32", + count=nranks * COUNT_PER_RANK, + nbytes=scratch_nbytes, + ) + ], + ) as handle: + for i in range(nranks): + domain = handle[i] + print( + f"[reduce_scatter] chip {i}: rank={domain.domain_rank}/{domain.domain_size} " + f"window=[0x{domain.local_window_base:x} +{domain.actual_window_size}B] " + f"scratch=0x{domain.buffer_ptrs['scratch']:x}" + ) + chip_args = TaskArgs() + chip_args.add_tensor(make_tensor_arg(host_inputs[i]), TensorArgType.INPUT) + chip_args.add_tensor(make_tensor_arg(host_outputs[i]), TensorArgType.OUTPUT_EXISTING) + chip_args.add_tensor( + ContinuousTensor.make( + data=domain.buffer_ptrs["scratch"], + shapes=(nranks * COUNT_PER_RANK,), + dtype=DataType.FLOAT32, + child_memory=True, + ), + TensorArgType.INOUT, + ) + chip_args.add_scalar(domain.domain_size) + chip_args.add_scalar(domain.device_ctx) + orch.submit_next_level(chip_cid, chip_args, cfg, worker=i) + + print(f"[reduce_scatter] running {nranks}-chip reduce-scatter DAG...") + worker.run(orch_fn, args=None, config=CallConfig()) + + ok = True + for i in range(nranks): + expected = torch.tensor(expected_output(nranks, i), dtype=torch.float32) + max_diff = float(torch.max(torch.abs(host_outputs[i] - expected))) + print(f"[reduce_scatter] chip {i}: max |out - expected| = {max_diff:.3e}") + if max_diff > 1e-3: + ok = False + for j in range(min(4, COUNT_PER_RANK)): + print(f" output[{j}]={float(host_outputs[i][j])!r} expected={float(expected[j])!r}") + + if not ok: + print("[reduce_scatter] golden check FAILED") + return 1 + print("[reduce_scatter] all ranks matched golden ✅") + return 0 + finally: + worker.close() + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("-p", "--platform", default="a2a3", help="Platform backend, e.g. a2a3 or a2a3sim.") + parser.add_argument( + "-d", "--device", default="0-1", help="Device range, e.g. '0-1' or '0-3'. 2 to 16 chips required." + ) + parser.add_argument( + "--build", action="store_true", help="Rebuild runtime from source instead of using cached libs." + ) + parser.add_argument("--pto-isa-commit", default=None, help="Optional PTO ISA commit/tag to fetch before compiling.") + cli = parser.parse_args() + + return run( + parse_device_range(cli.device), platform=cli.platform, pto_isa_commit=cli.pto_isa_commit, build=cli.build + ) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/workers/l3/reduce_scatter_distributed/test_reduce_scatter.py b/examples/workers/l3/reduce_scatter_distributed/test_reduce_scatter.py new file mode 100644 index 000000000..3ae7137df --- /dev/null +++ b/examples/workers/l3/reduce_scatter_distributed/test_reduce_scatter.py @@ -0,0 +1,28 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""ST for examples/workers/l3/reduce_scatter_distributed.""" + +import pytest + +from .main import run + + +@pytest.mark.platforms(["a2a3sim", "a2a3", "a5sim"]) +@pytest.mark.runtime("tensormap_and_ringbuffer") +@pytest.mark.parametrize( + "n_devices", + [ + pytest.param(2, marks=pytest.mark.device_count(2)), + pytest.param(4, marks=pytest.mark.device_count(4)), + ], +) +def test_reduce_scatter_distributed(st_platform, st_device_ids, n_devices): + assert len(st_device_ids) == n_devices + rc = run([int(d) for d in st_device_ids], platform=st_platform) + assert rc == 0