Inter-node AllGather

In this tutorial, you will write a low-latency AllGather kernel using Triton-distributed.

# To run this tutorial
source ./scripts/setenv.sh
bash ./scripts/launch.sh tutorials/03-inter-node-allgather.py

Motivations

When communicating across machines, you need to overlap inter-machine and intra-machine transfers. In this scenario, only NVSHMEM primitives can be used. The general approach is to assign some blocks to inter-machine transfers and others to intra-machine transfers: each rank first sends its local data, and the intra-machine blocks then wait for the inter-machine data to arrive before forwarding it to local peers. When a block has to wait for signals, only designated threads should poll them (a single thread such as thread 0, or one thread per signal), followed by a block-wide `__syncthreads()` so that the whole block proceeds together. The typical way to write this, with one thread polling each peer's flag, is:

# Thread i (for every peer i != rank) polls peer i's flag until it equals signal_value;
# __syncthreads() then holds the whole block until every polling thread has returned.
if thread_idx < WORLD_SIZE and thread_idx != rank:
    libshmem_device.signal_wait_until(symm_flag + thread_idx, libshmem_device.NVSHMEM_CMP_EQ, signal_value)
__syncthreads()

1D Kernel

@triton_dist.jit(do_not_specialize=["rank", "signal_value"])
def all_gather_push_1d_kernel(symm_ptr, bytes_per_rank, symm_flag, WORLD_SIZE: tl.constexpr, rank, signal_value):
    # Push-based AllGather over a flat (1D) rank layout.
    #
    # symm_ptr:      symmetric buffer viewed as WORLD_SIZE segments of `bytes_per_rank` bytes;
    #                segment `rank` holds this rank's local data.
    # symm_flag:     symmetric signal array; entries symm_flag[0:WORLD_SIZE] are addressed.
    # signal_value:  value written to / awaited on the flags for this call.
    #
    # Program `rank` waits for all incoming segments; every other program `pid` pushes this
    # rank's segment to rank `pid` and sets flag[rank] there.
    pid = tl.program_id(0)
    thread_idx = tid(0)
    # there are WORLD_SIZE programs processing different data.
    if pid == rank:  # 1 program waits for all other peers' putmem to finish
        peer = thread_idx
        # One thread per peer polls that peer's flag. This needs at least WORLD_SIZE threads
        # in the block; flags of peers >= the thread count would never be awaited.
        if peer < WORLD_SIZE and peer != rank:
            # rank `peer` is responsible for putmem segment `peer`.
            # wait for rank `peer` putmem done, then quit the kernel
            libshmem_device.signal_wait_until(
                symm_flag + peer,
                libshmem_device.NVSHMEM_CMP_EQ,
                signal_value,
            )
        # wait for all peers done
        __syncthreads()
    else:  # other `WORLD_SIZE - 1` programs putmem with signals to other peers
        peer = pid
        segment = rank
        # libshmem_device.putmem_signal_block calls NVSHMEM API `nvshmemx_putmem_signal_block`, which
        # sends `segment` at `rank` to `segment` at `peer` and set `flag[segment]` at `peer` to `signal_value`
        # NVSHMEM takes care of memory barrier semantic, so don't worry.
        libshmem_device.putmem_signal_block(
            tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank,
            tl.cast(symm_ptr, tl.pointer_type(tl.int8)) + segment * bytes_per_rank,
            bytes_per_rank,
            symm_flag + segment,
            signal_value,
            libshmem_device.NVSHMEM_SIGNAL_SET,
            peer,
        )  # write and tell peer remote that remote copy is done

2D Kernel

@triton_dist.jit(do_not_specialize=["rank", "signal_value"])
def all_gather_push_2d_kernel(
    symm_ptr,
    bytes_per_rank,
    symm_flag,
    NNODES: tl.constexpr,
    WORLD_SIZE: tl.constexpr,
    rank,
    signal_value,
):
    # Two-level (node, local rank) push AllGather. Program `pid` handles peer rank `pid`
    # and plays one of three roles:
    #   * same local rank, other node -> inter-node: push this rank's own segment to that peer;
    #   * pid == rank                 -> wait until every flag in symm_flag[0:WORLD_SIZE]
    #                                    except our own equals `signal_value`;
    #   * other local rank            -> intra-node: wait (if it comes from another node) for
    #                                    segment (peer_node_id, local_rank), then forward it to
    #                                    the local peer (node_id, peer_local_rank).
    # NOTE(review): assumes WORLD_SIZE is a multiple of NNODES, one program per rank, and at
    # least WORLD_SIZE threads per program (the waiting role polls one flag per thread).
    pid = tl.program_id(0)
    thread_idx = tid(0)
    # Here we represent rank or segment in 2d fashion:
    #  (node_id, local_rank) => node_id * local_world_size + local_rank = rank
    LOCAL_WORLD_SIZE = WORLD_SIZE // NNODES
    node_id = rank // LOCAL_WORLD_SIZE
    local_rank = rank % LOCAL_WORLD_SIZE

    peer_rank = pid
    peer_node_id = peer_rank // LOCAL_WORLD_SIZE
    peer_local_rank = peer_rank % LOCAL_WORLD_SIZE
    symm_ptr = tl.cast(symm_ptr, tl.pointer_type(tl.int8))

    if peer_local_rank == local_rank:
        if peer_rank != rank:  # cross NODE communication
            # send segment `(node_id, local_rank)` at rank `(node_id, local_rank)` to segment `(node_id, local_rank)` at rank `(peer_node_id, local_rank)`
            peer = peer_node_id * LOCAL_WORLD_SIZE + local_rank
            segment = rank
            # non-blocking (nbi) variant for the inter-node push
            libshmem_device.putmem_signal_nbi_block(
                symm_ptr + segment * bytes_per_rank,
                symm_ptr + segment * bytes_per_rank,
                bytes_per_rank,
                symm_flag + segment,
                signal_value,
                libshmem_device.NVSHMEM_SIGNAL_SET,
                peer,
            )
        else:  # wait for all peers to finish
            # one thread per peer flag, then a block-wide barrier
            if thread_idx < WORLD_SIZE and thread_idx != rank:
                libshmem_device.signal_wait_until(
                    symm_flag + thread_idx,
                    libshmem_device.NVSHMEM_CMP_EQ,
                    signal_value,
                )
            __syncthreads()
    else:  # intra-NODE communication
        peer = node_id * LOCAL_WORLD_SIZE + peer_local_rank
        segment = peer_node_id * LOCAL_WORLD_SIZE + local_rank
        # wait for inter-NODE putmem_signal done from other nodes
        # (when peer_node_id == node_id, segment == rank is our own data: no wait needed)
        if peer_node_id != node_id:
            if thread_idx == 0:
                libshmem_device.signal_wait_until(
                    symm_flag + segment,
                    libshmem_device.NVSHMEM_CMP_EQ,
                    signal_value,
                )
            __syncthreads()
        # send segment (peer_node_id, local_rank) to local peer (node_id, peer_local_rank); across all
        # programs this forwards segment (i, local_rank) for every node i to every peer_local_rank != local_rank
        libshmem_device.putmem_signal_block(
            symm_ptr + segment * bytes_per_rank,
            symm_ptr + segment * bytes_per_rank,
            bytes_per_rank,
            symm_flag + segment,
            signal_value,
            libshmem_device.NVSHMEM_SIGNAL_SET,
            peer,
        )

Benchmark

import os
from dataclasses import dataclass

import torch

import triton_dist
import triton.language as tl
from triton_dist.language.extra.language_extra import __syncthreads, tid
from triton_dist.language.extra import libshmem_device
from triton_dist.profiler_utils import perf_func
from triton_dist.utils import (
    finalize_distributed,
    initialize_distributed,
    nvshmem_barrier_all_on_stream,
    NVSHMEM_SIGNAL_DTYPE,
    nvshmem_free_tensor_sync,
    nvshmem_create_tensor,
    sleep_async,
)


@dataclass
class AllGatherContext:
    """Host-side state shared by the all-gather wrappers.

    Every wrapper call increments ``signal_value`` before launching. The parity of
    the new value selects which of the two ``symm_signals`` buffers that call uses.
    """

    rank: int  # this process's rank
    node: int  # node index of this rank
    num_ranks: int  # total number of ranks
    num_nodes: int  # number of nodes
    # Two symmetric signal tensors, indexed by ``signal_value % 2`` (double-buffered flags).
    symm_signals: list[torch.Tensor]
    # Flag value written/awaited by the kernels; bumped before every launch so each call
    # waits for a value that no earlier call has written.
    signal_value: int = 15
    # Not read by any code in this tutorial.
    max_buffer_size: int = 2 * 32 * 1024 * 1024


def all_gather_push_1d(ctx: AllGatherContext, symm_buffer: torch.Tensor):
    """Gather every rank's segment of ``symm_buffer`` in place with the 1D push kernel.

    Args:
        ctx: shared state; ``ctx.signal_value`` is incremented and its parity selects
            the signal buffer for this call.
        symm_buffer: symmetric buffer split into ``ctx.num_ranks`` equal segments, where
            segment ``ctx.rank`` already holds this rank's data.

    Returns:
        ``symm_buffer`` itself, filled with all segments once the kernel completes.
    """
    ctx.signal_value += 1
    all_gather_push_1d_kernel[(ctx.num_ranks, )](
        symm_buffer,
        symm_buffer.nbytes // ctx.num_ranks,
        ctx.symm_signals[ctx.signal_value % 2],
        ctx.num_ranks,
        ctx.rank,
        ctx.signal_value,
        # The waiting program polls one peer flag per thread (thread_idx < WORLD_SIZE). The
        # default 4 warps (128 threads) would leave peers >= 128 unawaited, so use as many
        # threads as possible, matching all_gather_push_2d.
        num_warps=32,
    )
    return symm_buffer


def all_gather_push_2d(ctx: AllGatherContext, symm_buffer: torch.Tensor):
    """Gather every rank's segment of ``symm_buffer`` in place with the 2D (node-aware) kernel.

    Args:
        ctx: shared state; ``ctx.signal_value`` is incremented first and the new value's
            parity picks the signal buffer used by this call.
        symm_buffer: symmetric buffer of ``ctx.num_ranks`` equal segments whose segment
            ``ctx.rank`` already holds the local data.

    Returns:
        ``symm_buffer`` (the same tensor), gathered once the kernel finishes.
    """
    ctx.signal_value += 1
    value = ctx.signal_value
    flags = ctx.symm_signals[value % 2]
    per_rank_bytes = symm_buffer.nbytes // ctx.num_ranks
    grid = (ctx.num_ranks, )
    all_gather_push_2d_kernel[grid](
        symm_buffer,
        per_rank_bytes,
        flags,
        ctx.num_nodes,
        ctx.num_ranks,
        ctx.rank,
        value,
        num_warps=32,  # the waiting program needs one thread per rank flag
    )
    return symm_buffer


def perf_ag(func, ag_buffers: torch.Tensor, nbytes: int, ctx: AllGatherContext):
    """Check ``func`` against reference data, then time NCCL's all-gather and ``func``.

    Args:
        func: an all-gather wrapper called as ``func(ctx, buffer)`` that returns the
            gathered buffer (e.g. ``all_gather_push_1d`` / ``all_gather_push_2d``).
        ag_buffers: double-buffered symmetric buffer. Each call uses row
            ``ctx.signal_value % 2``, read before ``func`` increments ``signal_value``.
        nbytes: total number of bytes gathered across all ranks.
        ctx: shared all-gather state; its rank and rank count define this rank's segment.
    """
    # Use the same rank / world size that the kernels receive (ctx.rank, ctx.num_ranks),
    # so the local segment written below is exactly the one the kernels push out.
    world_size = ctx.num_ranks
    rank = ctx.rank
    nbytes_per_rank = nbytes // world_size

    # Random reference data, made identical on every rank by the broadcast from rank 0.
    ref_tensor = (torch.randint(0, 9999, [nbytes // 4], dtype=torch.int32).view(torch.int8).cuda())
    torch.distributed.broadcast(ref_tensor, src=0)

    # local copy
    # suppose you already wrote ag_buffer[index_start:index_end], so this copy does not count in the profile
    ag_buffer = ag_buffers[ctx.signal_value % 2]
    index_start, index_end = nbytes_per_rank * rank, nbytes_per_rank * (rank + 1)
    ag_buffer[index_start:index_end].copy_(ref_tensor[index_start:index_end])

    def _run_all_gather_triton():
        ag_buffer = ag_buffers[ctx.signal_value % 2][:nbytes]
        return func(ctx, ag_buffer)

    def _run_all_gather_nccl():
        # in-place all-gather: the input is this rank's slice of the output tensor
        torch.distributed.all_gather_into_tensor(ref_tensor, ref_tensor[index_start:index_end], group=TP_GROUP)

    result = _run_all_gather_triton()

    # verify
    torch.testing.assert_close(result, ref_tensor, atol=0, rtol=0)
    print(f"✅ RANK[{rank}] check passed")

    # perf all-gather by NCCL
    sleep_async(1000)  # in case CPU bound
    _, duration_per_iter_ms = perf_func(
        _run_all_gather_nccl,
        warmup_iters=5,
        iters=10,
    )
    gbps = nbytes * 1e-9 / (duration_per_iter_ms * 1e-3) * (world_size - 1) / world_size
    print(
        f"[NCCL] RANK = {rank}, {nbytes // 1024} KB, Latency {duration_per_iter_ms * 1000:0.2f} us, Bus bandwith = {gbps:0.2f} GB/S"
    )

    # perf all-gather by triton-distributed
    nvshmem_barrier_all_on_stream(torch.cuda.current_stream())
    sleep_async(1000)  # in case CPU bound
    _, duration_per_iter_ms = perf_func(
        _run_all_gather_triton,
        warmup_iters=5,
        iters=10,
    )

    gbps = nbytes * 1e-9 / (duration_per_iter_ms * 1e-3) * (world_size - 1) / world_size
    print(
        f"[Triton] RANK = {rank}, {nbytes // 1024} KB, Latency {duration_per_iter_ms * 1000:0.2f} us, Bus bandwith = {gbps:0.2f} GB/S"
    )


# get all distributed arguments from environment, which are set by torchrun
RANK = int(os.environ.get("RANK", 0))
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
LOCAL_WORLD_SIZE = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
NNODES = WORLD_SIZE // LOCAL_WORLD_SIZE

nbytes = 8 * 1024  # total bytes for AllGather
# The buffer is split into WORLD_SIZE equal segments and the reference data is built from
# int32 words, so nbytes must be divisible by WORLD_SIZE and by 4.
if nbytes % WORLD_SIZE != 0 or nbytes % 4 != 0:
    raise ValueError(f"nbytes={nbytes} must be divisible by WORLD_SIZE={WORLD_SIZE} and by 4")

# Use initialize_distributed() for easy setup
TP_GROUP = initialize_distributed()

# use nvshmem_create_tensor as a torch-friendly wrapper of nvshmem_malloc.
# since our implementation does not wait for other peers done, so a double buffer is
# used to avoid data corrupt when all_gather kernels are in different phases.
symm_ag_buffer = nvshmem_create_tensor((2, nbytes), torch.int8)

# The kernels address symm_flag + peer / symm_flag + segment for every rank index in
# [0, WORLD_SIZE), so each (double-buffered) signal tensor needs one entry per rank.
symm_signals = [nvshmem_create_tensor((WORLD_SIZE, ), NVSHMEM_SIGNAL_DTYPE) for _ in range(2)]
# Start every flag from a known value. The barrier ensures no peer can set one of our
# flags before we have cleared it.
for symm_signal in symm_signals:
    symm_signal.zero_()
nvshmem_barrier_all_on_stream(torch.cuda.current_stream())

# keep some variables here
ctx = AllGatherContext(
    rank=TP_GROUP.rank(),
    node=RANK // LOCAL_WORLD_SIZE,
    num_ranks=WORLD_SIZE,
    num_nodes=NNODES,
    symm_signals=symm_signals,
    signal_value=10,
)
print("using push 1d...")
perf_ag(
    all_gather_push_1d,
    symm_ag_buffer,
    nbytes,
    ctx,
)
print("using push 2d...")
perf_ag(
    all_gather_push_2d,
    symm_ag_buffer,
    nbytes,
    ctx,
)

# Clean up symmetric memory
nvshmem_free_tensor_sync(symm_ag_buffer)
nvshmem_free_tensor_sync(ctx.symm_signals[0])
nvshmem_free_tensor_sync(ctx.symm_signals[1])

finalize_distributed()