Expert Parallelism All-to-All (EP A2A)
Expert Parallelism (EP) All-to-All communication kernels for Mixture-of-Experts (MoE) models. These kernels enable efficient token dispatch and combine operations across distributed experts.
Overview
In MoE models with Expert Parallelism, tokens need to be routed to their assigned experts, which may reside on different ranks. The EP A2A kernels provide:
Dispatch: Scatter input tokens to remote experts based on Top-K routing decisions
Combine: Gather expert outputs back to the original ranks and perform weighted summation
Dispatch Phase:
┌────────────────┐    ┌────────────────┐    ┌────────────────┐
│     Rank 0     │    │     Rank 1     │    │    Rank N-1    │
│    [tokens]    │    │    [tokens]    │    │    [tokens]    │
│  ↓ TopK Gate   │    │  ↓ TopK Gate   │    │  ↓ TopK Gate   │
│   expert_ids   │    │   expert_ids   │    │   expert_ids   │
└───────┬────────┘    └───────┬────────┘    └───────┬────────┘
        │                     │                     │
        └─────────────────────┼─────────────────────┘
                              ↓
                  ┌───────────────────────┐
                  │  All-to-All Dispatch  │
                  │    (Token Shuffle)    │
                  └───────────┬───────────┘
                              ↓
┌────────────────┐    ┌────────────────┐    ┌────────────────┐
│ Expert 0..E/N  │    │Expert E/N..2E/N│    │Expert (N-1)E/N │
│ [recv tokens]  │    │ [recv tokens]  │    │ [recv tokens]  │
└────────────────┘    └────────────────┘    └────────────────┘
Combine Phase (reverse direction):
Expert outputs → All-to-All → Weighted sum per original token
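The routing in the diagram can be simulated in a single Python process to make the two phases concrete. This is a sketch under stated assumptions, not the kernel: experts are assumed to be assigned to ranks in contiguous blocks (expert e lives on rank e // experts_per_rank, as the diagram suggests), tokens are scalars (hidden = 1), and dispatch, combine, and the toy expert function are hypothetical helpers that model only the data movement and the weighted sum.

```python
from collections import defaultdict


def dispatch(tokens_per_rank, topk_ids_per_rank, experts_per_rank):
    """Route every (token, k) pair to the rank that owns its expert.

    Returns {dst_rank: [(src_rank, token_idx, k, expert, value), ...]}.
    """
    recv = defaultdict(list)
    for src, (tokens, topk_ids) in enumerate(zip(tokens_per_rank, topk_ids_per_rank)):
        for t, (x, ids) in enumerate(zip(tokens, topk_ids)):
            for k, e in enumerate(ids):
                dst = e // experts_per_rank  # assumed contiguous expert blocks per rank
                recv[dst].append((src, t, k, e, x))
    return recv


def combine(recv, expert_fn, weights_per_rank, num_tokens_per_rank):
    """Send expert outputs back to their source rank as a weighted sum per token."""
    out = [[0.0] * n for n in num_tokens_per_rank]
    for entries in recv.values():
        for src, t, k, e, x in entries:
            out[src][t] += weights_per_rank[src][t][k] * expert_fn(e, x)
    return out


# 2 ranks, 2 experts per rank (4 experts total), topk = 2, hidden = 1 (scalars).
tokens = [[1.0, 2.0], [3.0]]
topk_ids = [[[0, 3], [1, 2]], [[2, 3]]]
weights = [[[0.5, 0.5], [0.25, 0.75]], [[0.5, 0.5]]]

recv = dispatch(tokens, topk_ids, experts_per_rank=2)
# Toy "expert" e simply scales its input by e + 1.
out = combine(recv, lambda e, x: x * (e + 1), weights, [2, 1])
print(out)  # [[2.5, 5.5], [10.5]]
```

Rank 0 ends up holding 2 (token, expert) pairs and rank 1 holds 4, which is exactly the load imbalance that the offsets and send requests described below have to account for.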
Key Concepts
Token Layout
The kernel uses a hierarchical layout optimized for multi-node communication:
Inter-node buffer: [nnodes, max_tokens, hidden], with tokens organized by source node
Intra-node buffer: Symmetric memory for direct GPU-to-GPU access within a node
recv_buf_offset_per_expert: [world_size, experts_per_rank, world_size], tracking where tokens from each source rank go for each local expert
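The exact ordering inside the receive buffer is not spelled out here, so the following is a hypothetical sketch of one natural way to fill a [world_size, experts_per_rank, world_size] offset table. It assumes the first index is the destination rank, that received tokens are grouped by local expert and, within each expert, by source rank, and that full_splits[src][expert] holds the all-gathered per-expert token counts. compute_recv_offsets is an illustrative helper, not part of the library.

```python
def compute_recv_offsets(full_splits, experts_per_rank):
    """Hypothetical layout: offsets[dst][local_e][src] is the first slot in dst's
    receive buffer for tokens that src sends to expert dst * experts_per_rank + local_e."""
    world_size = len(full_splits)
    offsets = [[[0] * world_size for _ in range(experts_per_rank)]
               for _ in range(world_size)]
    num_recv = [0] * world_size
    for dst in range(world_size):
        cursor = 0  # running position inside dst's receive buffer
        for local_e in range(experts_per_rank):
            expert = dst * experts_per_rank + local_e
            for src in range(world_size):
                offsets[dst][local_e][src] = cursor
                cursor += full_splits[src][expert]
        num_recv[dst] = cursor  # total (token, expert) pairs dst will receive
    return offsets, num_recv


# 2 ranks x 2 experts per rank; full_splits[src][expert] = token count.
full_splits = [
    [1, 2, 0, 3],  # rank 0 sends 1 token to e0, 2 to e1, 0 to e2, 3 to e3
    [2, 0, 1, 1],  # rank 1
]
offsets, num_recv = compute_recv_offsets(full_splits, experts_per_rank=2)
print(offsets)   # [[[0, 1], [3, 5]], [[0, 0], [1, 4]]]
print(num_recv)  # [5, 5]
```

Because every rank sees the same all-gathered splits, each one can compute these offsets independently without further communication.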
Dispatch Workflow
Preprocessing (get_ag_splits_and_recv_offset_for_dispatch):
  - Count tokens per expert (bincount)
  - AllGather split information across all ranks
  - Compute destination offsets for each token
  - Generate send requests for inter-node transfers
Token Dispatch (kernel_dispatch_token):
  - Ring-based inter-node communication (reduces network congestion)
  - Use NVSHMEM putmem_nbi_warp for asynchronous data transfer
  - Signal-based synchronization between nodes
  - Optionally transfer routing weights
Postprocessing:
  - Reset symmetric buffers for the next iteration
  - Handle skipped token mapping for intra-node optimization
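One plausible reading of the hierarchical layout (a [nnodes, max_tokens, hidden] buffer plus intra-node symmetric memory) is that a token crosses the network at most once per target node and is fanned out to local experts inside that node. The sketch below illustrates that assumption only; group_tokens_by_node is a hypothetical helper and does not reproduce the exact contents of send_reqs_for_nodes.

```python
def group_tokens_by_node(topk_ids, experts_per_rank, local_world_size, nnodes):
    """Hypothetical helper: for each target node, list the local token indices
    that must be copied to it, each token at most once per node."""
    per_node = [[] for _ in range(nnodes)]
    for t, ids in enumerate(topk_ids):
        # expert -> owning rank -> owning node, deduplicated per token
        nodes = sorted({(e // experts_per_rank) // local_world_size for e in ids})
        for n in nodes:
            per_node[n].append(t)
    return per_node


# 2 nodes x 2 ranks per node x 2 experts per rank = 8 experts, topk = 3.
topk_ids = [
    [0, 1, 2],  # all experts on node 0
    [3, 4, 7],  # experts on node 0 and node 1
    [5, 6, 7],  # all experts on node 1
]
per_node = group_tokens_by_node(topk_ids, experts_per_rank=2,
                                local_world_size=2, nnodes=2)
print(per_node)  # [[0, 1], [1, 2]]
```

Under this assumption the three tokens need 4 node-level copies instead of the 9 copies a naive per-(token, expert) send would move.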
Combine Workflow
Intra-node Reduce (kernel_combine_token):
  - Load expert outputs from the symmetric memory of peer ranks
  - Accumulate outputs for tokens routed to multiple local experts
  - Transfer reduced results to the inter-node buffer
Inter-node Reduce:
  - Sum outputs across nodes for each original token
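The two-stage reduction gives the same sum as a flat reduction over all top-k outputs, while only one partial sum per node reaches the inter-node stage. The sketch below shows this for a single token in plain Python. It assumes the expert outputs already include their routing weights and that expert e lives on node (e // experts_per_rank) // local_world_size; two_level_combine is a hypothetical helper.

```python
def two_level_combine(expert_outputs, expert_ids, experts_per_rank, local_world_size):
    """Sum one token's top-k expert outputs hierarchically.

    Stage 1 (intra-node): add up outputs of experts that live on the same node.
    Stage 2 (inter-node): add up the per-node partial sums.
    """
    hidden = len(expert_outputs[0])
    partial = {}  # node id -> partial-sum vector
    for out, e in zip(expert_outputs, expert_ids):
        node = (e // experts_per_rank) // local_world_size
        acc = partial.setdefault(node, [0.0] * hidden)
        for i, v in enumerate(out):
            acc[i] += v
    result = [0.0] * hidden
    for node in sorted(partial):
        for i, v in enumerate(partial[node]):
            result[i] += v
    return result, len(partial)


# experts_per_rank=2, local_world_size=2 -> experts 0..3 on node 0, 4..7 on node 1.
result, num_partials = two_level_combine(
    expert_outputs=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    expert_ids=[0, 1, 5],
    experts_per_rank=2,
    local_world_size=2,
)
print(result, num_partials)  # [9.0, 12.0] 2
```

Here three expert outputs are reduced to two partial sums before the inter-node step.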
API Reference
Core Functions
- ep_dispatch_token_inplace(send_reqs_for_nodes, signal_buf, recv_buf_offset_per_expert, send_buf, output_buf, weight_send_buf, weight_recv_buf, topk_indices_tensor, token_dst_scatter_idx, num_input_tokens_per_rank, max_tokens, topk, hidden, bytes_per_token, experts_per_rank, local_world_size, has_weight, with_scatter_indices, num_sms, use_aot=False)
Dispatch tokens to their assigned experts across all ranks.
- Parameters:
send_reqs_for_nodes – [nnodes, 2, max_tokens]: Send request ranges per target node
signal_buf – [world_size]: NVSHMEM signal buffer for synchronization
recv_buf_offset_per_expert – [world_size, experts_per_rank, world_size]: Destination offsets
send_buf – [nnodes, max_tokens, hidden]: Source token buffer (symmetric)
output_buf – [recv_tokens, hidden]: Destination buffer for received tokens (symmetric)
weight_send_buf – [nnodes, max_tokens, topk]: Optional routing weights to send
weight_recv_buf – [recv_tokens]: Optional buffer for received weights
topk_indices_tensor – [nnodes, max_tokens, topk]: Expert indices per token
token_dst_scatter_idx – [nnodes, max_tokens, topk]: Scatter indices in the output buffer
num_input_tokens_per_rank – [world_size]: Token count per source rank
max_tokens – Maximum tokens per rank
topk – Number of experts per token
hidden – Hidden dimension size
bytes_per_token – hidden * dtype.itemsize
experts_per_rank – Number of experts per rank
local_world_size – Number of ranks per node
has_weight – Whether to transfer routing weights
with_scatter_indices – Whether scatter indices are precomputed
num_sms – Number of SMs to use
use_aot – Whether to use AOT-compiled kernel
- ep_combine_token_inplace(counter_workspace, num_input_tokens_per_rank, send_reqs_recv_tensor, intra_node_reduce_buf, input, send_buf, topk_indices_tensor, token_dst_scatter_idx, max_tokens, topk, hidden, bytes_per_token, experts_per_rank, local_world_size, num_sms, use_aot=False)
Combine expert outputs back to original token positions.
- Parameters:
counter_workspace – [nnodes]: Grid synchronization counters
num_input_tokens_per_rank – [world_size]: Token count per source rank
send_reqs_recv_tensor – [nnodes, 2, max_tokens]: Received send requests
intra_node_reduce_buf – [nnodes, max_tokens, hidden]: Intermediate reduction buffer
input – [recv_tokens, hidden]: Expert output tokens (symmetric)
send_buf – [nnodes, max_tokens, hidden]: Output buffer for combined tokens
topk_indices_tensor – [nnodes, max_tokens, topk]: Expert indices per token
token_dst_scatter_idx – [nnodes, max_tokens, topk]: Scatter indices from dispatch
max_tokens – Maximum tokens per rank
topk – Number of experts per token
hidden – Hidden dimension size
bytes_per_token – hidden * dtype.itemsize
experts_per_rank – Number of experts per rank
local_world_size – Number of ranks per node
num_sms – Number of SMs to use
use_aot – Whether to use AOT-compiled kernel
Preprocessing Functions
- get_ag_splits_and_recv_offset_for_dispatch(send_reqs_for_nodes, send_reqs_recv_bufs, exp_indices, topk_indices_buf, expert_indices_signal_buf, local_splits, full_splits_buf, splits_signal_buf, topk, local_world_size, world_size, max_tokens, experts_per_rank, full_scatter_indices=None, cpu_default_val=-1, offset_dtype=torch.int32, num_sm=20, use_aot=False)
Compute routing metadata for dispatch operation.
- Returns:
Tuple of (recv_buf_offset_per_expert, num_recv_tokens_per_rank_cpu, num_input_tokens_per_rank, token_dst_scatter_idx, send_reqs_recv_bufs_copy, topk_indices_buf_copy)
- get_dispatch_send_reqs(exp_indices, send_reqs_for_nodes, experts_per_rank, local_world_size, num_sms, use_aot=False)
Generate send request ranges for inter-node dispatch.
- Parameters:
exp_indices – [num_tokens, topk]: Expert indices from the Top-K gate
send_reqs_for_nodes – [nnodes, 2, max_tokens]: Output buffer for (start, end) ranges
- bincount(input, length, output=None, output_dtype=torch.int32, num_sm=16, use_aot=False)
Count tokens per expert (GPU-accelerated bincount).
- Parameters:
input – [num_tokens * topk]: Flattened expert indices
length – Number of experts + 1 (for dropped tokens)
- Returns:
Token counts per expert
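A CPU reference is handy for checking the GPU bincount in tests. The sketch below is plain Python; it assumes (hypothetically) that dropped (token, k) slots carry the index length - 1, i.e. num_experts, which is what the extra bin appears to be for. For in-range indices the result should match torch.bincount(input, minlength=length).

```python
def bincount_ref(expert_ids, length):
    """CPU reference: count occurrences of each index in [0, length).

    Assumption: dropped slots are encoded as index length - 1 (num_experts).
    """
    counts = [0] * length
    for e in expert_ids:
        if not 0 <= e < length:
            raise ValueError(f"expert index {e} out of range [0, {length})")
        counts[e] += 1
    return counts


num_experts = 4
topk_ids = [[0, 3], [1, 3], [3, num_experts]]  # last slot is dropped
flat = [e for row in topk_ids for e in row]    # [num_tokens * topk]
counts = bincount_ref(flat, length=num_experts + 1)
print(counts)  # [1, 1, 0, 3, 1]
```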
Kernel Implementation Details
kernel_dispatch_token
@triton_dist.jit
def kernel_dispatch_token(
send_reqs_for_nodes, signals_for_nodes, recv_buf_offset_per_expert,
input_buf, output_buf, weight_send_buf, weight_recv_buf,
topk_indices_tensor, token_dst_scatter_idx, num_input_tokens_per_rank,
max_tokens, topk, hidden_size, bytes_per_token, num_sms,
experts_per_rank: tl.constexpr, local_world_size: tl.constexpr,
HAS_WEIGHT: tl.constexpr, WITH_SCATTER_INDICES: tl.constexpr,
):
...
Key algorithmic steps:
Ring Communication: Iterate through nodes in ring order to balance network load
for node_offset in range(0, nnodes):
    target_node = (node_id + node_offset + 1) % nnodes
Warp-level Data Transfer: Each warp handles one token’s data transfer
libshmem_device.putmem_warp(dst_ptr, src_ptr, bytes_per_token, expert_rank)
Atomic Offset Allocation: Thread-safe destination slot allocation
store_idx = atomic_add_per_warp(
    recv_buf_offset_per_expert + expert_rank * ...,
    1, scope="gpu", semantic="relaxed")
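Two properties of these steps can be checked on the CPU: the ring formula makes every step a permutation, so no node is the target of two senders at the same step (and the last step targets the node itself), and a fetch-and-add hands out distinct slots to concurrent writers. The sketch below is plain Python, not Triton. ring_schedule and the lock-based AtomicCounter are hypothetical stand-ins for the loop above and for atomic_add_per_warp, and the base offset of 10 is an arbitrary example value.

```python
import threading


def ring_schedule(nnodes):
    """schedule[node_offset][node_id] = target node, per the formula above."""
    return [[(node_id + node_offset + 1) % nnodes for node_id in range(nnodes)]
            for node_offset in range(nnodes)]


class AtomicCounter:
    """CPU stand-in for a GPU atomic add: fetch_add returns the pre-add value."""

    def __init__(self, start=0):
        self._value = start
        self._lock = threading.Lock()

    def fetch_add(self, delta):
        with self._lock:
            old = self._value
            self._value += delta
            return old


schedule = ring_schedule(4)
print(schedule[0])   # [1, 2, 3, 0]
print(schedule[-1])  # [0, 1, 2, 3]  (each node handles itself last)

# 8 concurrent "warps" each claim one slot for the same (expert, source) pair.
counter = AtomicCounter(start=10)  # hypothetical base offset
assigned = {}
assigned_lock = threading.Lock()


def warp(token_id):
    store_idx = counter.fetch_add(1)
    with assigned_lock:
        assigned[token_id] = store_idx  # plays the role of a scatter index


threads = [threading.Thread(target=warp, args=(t,)) for t in range(8)]
for th in threads:
    th.start()
for th in threads:
    th.join()
print(sorted(assigned.values()))  # [10, 11, ..., 17]
```

Which token gets which slot depends on scheduling, but the set of slots is always contiguous and collision-free. Only that uniqueness matters for slot allocation, which is presumably why relaxed ordering is sufficient.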
kernel_combine_token
Key algorithmic steps:
Remote Memory Access: Read expert outputs directly from peer ranks
remote_input_ptr = dl.symm_at(input_buf, expert_rank)
token = dl.ld_vector(remote_input_ptr + offset, vec_size=vec_size)
Vectorized Accumulation: Use 128-bit vector operations for efficiency
token_accum = dl.zeros_vector(vec_size, tl.float32)
for j in range(topk):
    token = dl.ld_vector(...).to(tl.float32)
    token_accum = token_accum + token
dl.st_vector(send_buf + ..., token_accum.to(send_buf.dtype.element_ty))
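Accumulating in tl.float32 and casting once at the end avoids a real pitfall: rounding a 16-bit accumulator after every add can silently drop small contributions. The standard struct module has no bfloat16 format, so this sketch uses IEEE half precision ("e") as a stand-in for the 16-bit element type and a Python float as the wide accumulator. The values are chosen so that fp16's spacing of 2 at magnitude 2048 makes the effect visible.

```python
import struct


def to_fp16(x):
    """Round a Python float to IEEE half precision (stand-in for a 16-bit dtype)."""
    return struct.unpack("e", struct.pack("e", x))[0]


def accumulate_low_precision(values):
    acc = 0.0
    for v in values:
        acc = to_fp16(acc + v)  # round after every add
    return acc


def accumulate_wide_then_cast(values):
    acc = 0.0
    for v in values:
        acc += v                # wide accumulator
    return to_fp16(acc)         # round once, like .to(send_buf.dtype.element_ty)


expert_outputs = [2048.0, 1.0, 1.0]  # one lane of topk = 3 expert outputs
print(accumulate_low_precision(expert_outputs))   # 2048.0 (both +1s lost)
print(accumulate_wide_then_cast(expert_outputs))  # 2050.0
```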
AOT Compilation
The kernels support Ahead-of-Time (AOT) compilation for reduced JIT overhead:
@aot_compile_spaces({
"kernel_dispatch_token_bf16_weight_fp32": {
"signature": kernel_dispatch_token_signature.format(...),
"grid": ["num_sms", "1", "1"],
"triton_algo_infos": [
{"experts_per_rank": 64, "local_world_size": 8, ...},
{"experts_per_rank": 32, ...},
...
]
}
})
@triton_dist.jit
def kernel_dispatch_token(...):
...
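Arguments annotated tl.constexpr (experts_per_rank, local_world_size, HAS_WEIGHT, WITH_SCATTER_INDICES) are specialized into the compiled binary, so each combination listed in triton_algo_infos corresponds to its own precompiled variant. The sketch below is a hypothetical illustration of that lookup and is not the triton_dist AOT loader. AOT_VARIANTS, select_kernel, and the artifact id are invented names, and only the fully specified combination from the snippet above (experts_per_rank=64, local_world_size=8) is included.

```python
# Hypothetical illustration only: NOT the triton_dist AOT runtime.
AOT_VARIANTS = {
    # (kernel name, experts_per_rank, local_world_size) -> hypothetical artifact id
    ("kernel_dispatch_token_bf16_weight_fp32", 64, 8): "dispatch_e64_w8",
}


def select_kernel(name, experts_per_rank, local_world_size, use_aot):
    """Pick a precompiled variant when use_aot is set, otherwise defer to JIT."""
    key = (name, experts_per_rank, local_world_size)
    if not use_aot:
        return ("jit", key)  # JIT compiles this constexpr combination on demand
    if key not in AOT_VARIANTS:
        raise KeyError(f"no AOT variant was compiled for {key}")
    return ("aot", AOT_VARIANTS[key])


NAME = "kernel_dispatch_token_bf16_weight_fp32"
print(select_kernel(NAME, 64, 8, use_aot=True))  # ('aot', 'dispatch_e64_w8')
print(select_kernel(NAME, 16, 4, use_aot=False))
```

In this sketch, a configuration that was not listed at AOT time raises an error rather than silently falling back. How the real library handles that case is not covered here.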
Performance Characteristics
Warp-level Parallelism: Each warp handles one token, maximizing parallelism
Ring Communication: Balances network load across nodes
Symmetric Memory: Enables direct GPU-to-GPU access without CPU involvement
Vectorized Operations: 128-bit vector loads/stores for maximum memory bandwidth
Grid Synchronization: Efficient barrier implementation using atomic operations
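A quick calculation shows the per-token work implied by warp-level transfers and 128-bit vectors. It assumes a warp's 32 lanes split one token row into 16-byte chunks. That is a simplification, since the kernel's actual lane mapping is not documented here. per_token_transfer_stats is a hypothetical helper.

```python
def per_token_transfer_stats(hidden, itemsize, vec_bits=128, warp_size=32):
    """Rough per-token work when one warp moves one token in 128-bit vectors."""
    bytes_per_token = hidden * itemsize            # as in the API: hidden * dtype.itemsize
    vec_bytes = vec_bits // 8                      # 16 bytes per 128-bit access
    if bytes_per_token % vec_bytes:
        raise ValueError("token row does not split into whole 128-bit vectors")
    num_vectors = bytes_per_token // vec_bytes
    iters_per_lane = -(-num_vectors // warp_size)  # ceil(num_vectors / warp_size)
    return bytes_per_token, num_vectors, iters_per_lane


# hidden=7168 in bf16 (2 bytes per element), as in the usage example.
stats = per_token_transfer_stats(hidden=7168, itemsize=2)
print(stats)  # (14336, 896, 28)
```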
Note
For optimal performance:
Use num_sms=20 or higher for large token counts
Enable AOT compilation (use_aot=True) in production
Ensure hidden is divisible by 16 for vectorization
Usage Example
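import torch
from triton_dist.layers.nvidia import EPAll2AllLayer
# Assumes torch.distributed and NVSHMEM are already initialized and that
# ep_group, rank, tokens, expert_ids and routing_weights are defined.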
# Create EP layer
ep_layer = EPAll2AllLayer(
ep_group=ep_group,
max_tokens=256,
hidden=7168,
topk=8,
rank=rank,
num_tot_experts=256,
local_world_size=8,
world_size=32,
dtype=torch.bfloat16,
weight_dtype=torch.float32,
)
# Dispatch tokens to experts
output, weights, layout_desc = ep_layer.dispatch(
input=tokens, # [num_tokens, hidden]
exp_indices=expert_ids, # [num_tokens, topk]
weight=routing_weights, # [num_tokens, topk]
)
# After expert computation...
combined = ep_layer.combine(expert_output, layout_desc)
# Cleanup
ep_layer.finalize()
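The layer arguments above imply several of the low-level parameters listed in the API reference. The sketch below derives them under stated assumptions: nodes are full (nnodes = world_size / local_world_size), experts are split evenly across ranks, and the token buffer has the [nnodes, max_tokens, hidden] shape given earlier. derived_ep_config is a hypothetical helper and does not call triton_dist.

```python
def derived_ep_config(world_size, local_world_size, num_tot_experts,
                      max_tokens, hidden, itemsize):
    """Low-level parameters implied by the layer arguments (assumed relationships)."""
    if world_size % local_world_size or num_tot_experts % world_size:
        raise ValueError("expected full nodes and an even expert split")
    nnodes = world_size // local_world_size
    bytes_per_token = hidden * itemsize
    return {
        "nnodes": nnodes,
        "experts_per_rank": num_tot_experts // world_size,
        "bytes_per_token": bytes_per_token,
        # one [nnodes, max_tokens, hidden] token buffer
        "send_buf_bytes": nnodes * max_tokens * bytes_per_token,
        "hidden_divisible_by_16": hidden % 16 == 0,
    }


cfg = derived_ep_config(world_size=32, local_world_size=8, num_tot_experts=256,
                        max_tokens=256, hidden=7168, itemsize=2)  # bf16
print(cfg)
```

For this configuration, that works out to 4 nodes, 8 experts per rank, 14336 bytes per token, and about 14 MB for one inter-node token buffer.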
Run Example
NVSHMEM_SYMMETRIC_SIZE=10000000000 bash scripts/launch.sh python/triton_dist/test/nvidia/test_ep_a2a.py -M 8192 -N 7168 --topk 8 --check
See Also
Expert Parallelism All-to-All (Intra-Node) - Optimized kernels for single-node EP
Low-Latency All-to-All V2 (EP) - Low-latency version with FP8 quantization
Expert Parallelism All-to-All Layer - High-level layer API