Low Latency All-to-All Communication

In this tutorial, we demonstrate how to implement the All-to-All communication paradigm in Expert Parallelism (EP) for MoE models using Triton-distributed.

# To run this tutorial
source ./scripts/setenv.sh
bash ./scripts/launch.sh tutorials/04-deepseek-infer-all2all.py

Motivations

First, let’s quickly review the EP workflow: In MoE, the E experts are distributed across N devices (EP ranks). For simplicity, we assume that N divides E evenly, so experts are distributed uniformly. For example, when E = 128 and N = 32, each device will handle 4 experts.

During inference with EP, each device is assigned a subset of tokens, as determined by the MoE router module. The router on each device generates a tensor of shape [num_tokens, topk], containing the indices of the top k experts selected for each token. The experts chosen for a token may reside on other devices, necessitating communication to send the tokens to the appropriate devices. Similarly, if other devices have tokens that select experts located on the current device, those tokens need to be sent to the current device as well. This process is called Dispatch.

After the tokens are processed by their corresponding experts, they need to be returned to their original devices. This operation mirrors Dispatch and is referred to as Combine. From a communication perspective, both Dispatch and Combine are essentially All-to-All collective communication operations.

Next, we demonstrate how to implement an efficient All-to-All operation in Triton-distributed with minimal code.

Triton-distributed provides a programming model that allows fine-grained control over data movement between devices, optimizing hardware utilization. At the core of our implementation are low-level primitives that manage the communication logic.

Kernel

@triton_dist.jit
def all_to_all_kernel(
    send_tensor,
    data_src,
    data_dst,
    scale_src,
    scale_dst,
    splits_src,
    splits_dst,
    signal,
    send_splits_cumsum,
    recv_offset,
    rank: int,
    call_count: int,
    act_pos: int,
    MODE: tl.constexpr,
    ONLINE_QUANT_FP8: tl.constexpr,
    FP8_GSIZE: tl.constexpr,
    WORLD_SIZE: tl.constexpr,
    HIDDEN: tl.constexpr,
    MAX_M: tl.constexpr,
    NUM_TOT_EXPERTS: tl.constexpr,
    BM: tl.constexpr,
    BN: tl.constexpr,
):
    """
    All-to-All kernel for the Dispatch and Combine phases.

    - send_tensor: The tokens to be sent.
    - data/scale/splits_src/dst: The source and destination symmetric buffers for communication.
    - signal: signal buffer for communication.
    - send_splits_cumsum: Cumulative sum of the token splits (expert-level) for the current rank.
    - recv_offset: only used in combine mode, the base offset of the received tokens.
    - call_count: as the unique ID used for signal operation.
    - act_pos: The position of the active buffer (0 or 1) for double buffering.

    - MODE: Determines whether the operation is Dispatch (0) or Combine (1).
    - ONLINE_QUANT_FP8: A flag indicating whether FP8 quantization is used.
    - FP8_GSIZE: The group size for FP8 quantization.
    - WORLD_SIZE: number of EP ranks.
    - HIDDEN: The hidden size for each token.
    - MAX_M: The maximum number of tokens that can be processed per rank.
    - EXPERTS_PER_RANK: The number of experts handled by each rank.
    - NUM_TOT_EXPERTS: The total number of experts.
    - BM, BN: Block size used to copy data to send buffer
    """
    pid = tl.program_id(0)
    # Triton-distributed exposes `tid` that can be used to identify the thread index within a CTA
    threadidx = tid(axis=0)
    NUM_GROUPS: tl.constexpr = HIDDEN // FP8_GSIZE
    EXPERTS_PER_RANK: tl.constexpr = NUM_TOT_EXPERTS // WORLD_SIZE

Calculate the token range for the current program (rank), get the corresponding pointer.

exp_st = pid * EXPERTS_PER_RANK
exp_ed = exp_st + EXPERTS_PER_RANK
m_st = tl.load(send_splits_cumsum + exp_st)
m_ed = tl.load(send_splits_cumsum + exp_ed)
num_rows_cur_block = m_ed - m_st

# Signal pointer to communicate when data is ready
signal_ptr = signal + act_pos * WORLD_SIZE + rank
if MODE == 0:  # dispatch mode
    # Calculate source and destination offsets based on the expert-level token number cumsum
    split_src_ptr = splits_src + (exp_st + pid)
    split_dst_ptr = splits_dst + act_pos * (NUM_TOT_EXPERTS + WORLD_SIZE) + rank * (EXPERTS_PER_RANK + 1)

    off0 = exp_st + tl.arange(0, EXPERTS_PER_RANK)
    off1 = exp_st + tl.arange(0, EXPERTS_PER_RANK) + 1
    cumsum_sts = tl.load(send_splits_cumsum + off0)
    cumsum_eds = tl.load(send_splits_cumsum + off1)
    tl.store(split_src_ptr + tl.arange(0, EXPERTS_PER_RANK), cumsum_eds - cumsum_sts)
    tl.store(split_src_ptr + EXPERTS_PER_RANK, m_st)

    # Calculate the source and destination data offsets for the dispatch operation
    src_off = m_st
    dst_off = rank * MAX_M
    data_src_ptr = data_src + src_off * HIDDEN
    data_dst_ptr = data_dst + act_pos * WORLD_SIZE * MAX_M * HIDDEN + dst_off * HIDDEN
    scale_src_ptr = scale_src + src_off * NUM_GROUPS
    scale_dst_ptr = scale_dst + act_pos * WORLD_SIZE * MAX_M * NUM_GROUPS + dst_off * NUM_GROUPS
else:  # combine mode
    # For the combine phase, source and destination offsets are updated accordingly
    src_off = pid * MAX_M
    dst_off = tl.load(recv_offset + pid)
    data_src_ptr = data_src + act_pos * WORLD_SIZE * MAX_M * HIDDEN + src_off * HIDDEN
    data_dst_ptr = data_dst + dst_off * HIDDEN
    scale_src_ptr = scale_src + act_pos * WORLD_SIZE * MAX_M * NUM_GROUPS + src_off * NUM_GROUPS
    scale_dst_ptr = scale_dst + dst_off * NUM_GROUPS

Copy the data (may be online quantized to FP8) to send buffer.

off_m = tl.arange(0, BM)
if ONLINE_QUANT_FP8 and MODE == 0:
    # TODO: adaptive UNROLL_FACTOR
    UNROLL_FACTOR: tl.constexpr = 4
    group_offs = off_m[:, None] * HIDDEN + tl.arange(0, FP8_GSIZE * UNROLL_FACTOR)[None, :]
    send_tensor_ptrs = send_tensor + m_st * HIDDEN + group_offs
    data_src_ptrs = tl.cast(data_src_ptr, tl.pointer_type(tl.float8e4nv)) + group_offs
    scale_src_ptrs = scale_src_ptr + off_m[:, None] * NUM_GROUPS + tl.arange(0, UNROLL_FACTOR)[None, :]
    # online quant the input data to FP8
    for i in tl.range(ceil_div(num_rows_cur_block, BM)):
        group_mask = off_m[:, None] < num_rows_cur_block - i * BM
        for _ in tl.static_range(0, NUM_GROUPS, UNROLL_FACTOR):
            group = tl.reshape(tl.load(send_tensor_ptrs, group_mask), (BM * UNROLL_FACTOR, FP8_GSIZE))
            scale = tl.max(tl.abs(group), 1, keep_dims=True).to(tl.float32) * FP8_MAX_INV
            quant = tl.reshape((group.to(tl.float32) / scale).to(tl.float8e4nv), (BM, UNROLL_FACTOR * FP8_GSIZE))
            tl.store(data_src_ptrs, quant, group_mask)
            tl.store(scale_src_ptrs, tl.reshape(scale, (BM, UNROLL_FACTOR)), group_mask)
            send_tensor_ptrs += UNROLL_FACTOR * FP8_GSIZE
            data_src_ptrs += UNROLL_FACTOR * FP8_GSIZE
            scale_src_ptrs += UNROLL_FACTOR
        send_tensor_ptrs += (BM - 1) * HIDDEN
        data_src_ptrs += (BM - 1) * HIDDEN
        scale_src_ptrs += (BM - 1) * NUM_GROUPS
else:
    off_n = tl.arange(0, BN)
    send_tensor_ptrs = send_tensor + m_st * HIDDEN + off_m[:, None] * HIDDEN + off_n[None, :]
    data_src_ptrs = data_src_ptr + off_m[:, None] * HIDDEN + off_n[None, :]
    for i in tl.range(ceil_div(num_rows_cur_block, BM)):
        data_mask = (off_m[:, None] < num_rows_cur_block - i * BM) & (off_n[None, :] < HIDDEN)
        tl.store(data_src_ptrs, tl.load(send_tensor_ptrs, data_mask), data_mask)
        send_tensor_ptrs += BM * HIDDEN
        data_src_ptrs += BM * HIDDEN

Perform the memory copy operation using shared memory for inter-rank communication.

# the last argument is the peer id (id of target rank)
libshmem_device.putmem_nbi_block(
    data_dst_ptr,
    data_src_ptr,
    num_rows_cur_block * HIDDEN * (1 if (ONLINE_QUANT_FP8 and MODE == 0) else 2),
    pid,
)
if MODE == 0:
    # Dispatch mode: send split information to the target rank
    libshmem_device.putmem_nbi_block(
        split_dst_ptr,
        split_src_ptr,
        (EXPERTS_PER_RANK + 1) * 4,  # now we use `int32` for splits
        pid,
    )
# If online quantization is enbaled, signal the target rank with the scale data
if ONLINE_QUANT_FP8:
    libshmem_device.putmem_signal_nbi_block(
        scale_dst_ptr,
        scale_src_ptr,
        num_rows_cur_block * NUM_GROUPS * 4,  # assume `float32` for scale
        signal_ptr,
        call_count,
        libshmem_device.NVSHMEM_SIGNAL_SET,
        pid,
    )

Fence data transfer. Then wait for signal

libshmem_device.fence()
if threadidx == 0:
    # notify the target rank (here is the `pid`-th rank) that the data is ready by setting the signal
    if not ONLINE_QUANT_FP8:
        libshmem_device.signal_op(
            signal_ptr,
            call_count,
            libshmem_device.NVSHMEM_SIGNAL_SET,
            pid,
        )
    # wait for the signal from the source rank (here is the `pid`-th rank)
    libshmem_device.signal_wait_until(
        signal + act_pos * WORLD_SIZE + pid,
        libshmem_device.NVSHMEM_CMP_EQ,
        call_count,
    )