Expert Parallelism All-to-All Layer
High-level layer API for Expert Parallelism All-to-All communication in MoE models.
Overview
EPAll2AllLayer provides a complete solution for token dispatch and combine operations
in Expert Parallelism, handling both intra-node and inter-node communication transparently.
Key Features
Automatic Topology Detection: Selects optimal kernels for single-node vs multi-node
Dynamic Buffer Management: Automatically resizes output buffers based on actual token counts
Weight Transfer Support: Optionally transfers routing weights along with tokens
Scatter Index Precomputation: Supports external scatter index computation for advanced routing
AOT Compilation: Optional ahead-of-time compilation for reduced JIT overhead
Architecture
EPAll2AllLayer
├── EPConfig # Configuration parameters
├── DispatchCombineContext # Symmetric memory buffers
│ ├── token_send_buf_rdma # [nnodes, max_tokens, hidden]
│ ├── dispatch_output_buf # [recv_tokens, hidden] (resizable)
│ ├── topk_indices_buf_rdma # [nnodes, max_tokens, topk]
│ ├── weight_send/recv_buf # Routing weights
│ ├── signal_buf # NVSHMEM signals
│ └── ...
└── BarrierAllContext # Intra-node barrier
Workflow
┌─────────────────────────────────────────────────────────────────┐
│ dispatch() Flow │
├─────────────────────────────────────────────────────────────────┤
│ 1. Copy input to symmetric buffer │
│ 2. preprocess() - Compute routing metadata │
│ ├── bincount(expert_indices) │
│ ├── get_ag_splits_and_recv_offset_for_dispatch() │
│ └── [inter-node] get_dispatch_send_reqs() │
│ 3. init_output_buffer() - Poll CPU for buffer size │
│ 4. dispatch_token() - Execute dispatch kernel │
│ ├── [inter-node] ep_dispatch_token_inplace() │
│ └── [intra-node] kernel_dispatch_token_intra_node() │
│ 5. dispatch_postprocess() - Reset buffers │
│ 6. Return (output, weights, layout_desc) │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ combine() Flow │
├─────────────────────────────────────────────────────────────────┤
│ 1. Copy expert output to symmetric buffer │
│ 2. combine_token_intra_node_and_send() │
│ ├── [intra-node] kernel_combine_token_intra_node() │
│ └── [inter-node] ep_combine_token_inplace() │
│ 3. [inter-node] Sum across nodes │
│ 4. Return combined output │
└─────────────────────────────────────────────────────────────────┘
API Reference
EPAll2AllLayer
- class EPAll2AllLayer(ep_group, max_tokens, hidden, topk, rank, num_tot_experts, local_world_size, world_size, dtype=torch.bfloat16, weight_dtype=torch.float32, num_sm=20, enable_local_combine=False, use_aot=False)
High-level layer for EP All-to-All communication.
- Parameters:
ep_group – PyTorch distributed process group for EP
max_tokens – Maximum number of tokens per rank
hidden – Hidden dimension size
topk – Number of experts selected per token
rank – Current rank
num_tot_experts – Total number of experts across all ranks
local_world_size – Number of ranks per node (typically 8)
world_size – Total number of ranks
dtype – Token data type (default:
torch.bfloat16)weight_dtype – Routing weight data type (default:
torch.float32)num_sm – Number of SMs to use for kernels (default: 20)
enable_local_combine – Enable intra-node local combine optimization (default: False)
use_aot – Use AOT-compiled kernels (default: False)
- dispatch(input, exp_indices, weight=None, full_scatter_indices=None)
Dispatch tokens to their assigned experts.
- Parameters:
input –
[num_tokens, hidden]- Input tokensexp_indices –
[num_tokens, topk]- Expert indices from Top-K gateweight –
[num_tokens, topk]- Optional routing weights (default: None)full_scatter_indices –
[num_tokens, topk]- Optional precomputed scatter indices
- Returns:
Tuple of (output, weights, layout_desc)
output:
[recv_tokens, hidden]- Received tokensweights:
[recv_tokens]- Received weights (or None)layout_desc:
EPAllToAllLayoutDesc- Metadata for combine
- combine(input, ep_a2a_layout_desc)
Combine expert outputs back to original token positions.
- Parameters:
input –
[recv_tokens, hidden]- Expert output tokensep_a2a_layout_desc – Layout descriptor from dispatch
- Returns:
[num_dispatch_tokens, hidden]- Combined output
- finalize()
Release NVSHMEM symmetric memory buffers.
- ep_barrier_all(stream, intra_node_only=False)
Synchronize all ranks (or just intra-node ranks).
- Parameters:
stream – CUDA stream
intra_node_only – Only synchronize within node (default: False)
EPAllToAllLayoutDesc
- class EPAllToAllLayoutDesc
Descriptor containing routing metadata from dispatch, needed for combine.
- num_dispatch_token_cur_rank: int
Number of tokens dispatched by this rank.
- num_input_tokens_per_rank: torch.Tensor
[world_size]- Number of tokens dispatched by each rank.
- send_reqs_recv_tensor: torch.Tensor | None
[nnodes, 2, max_tokens]- Received send requests (inter-node only).
- topk_indices_tensor: torch.Tensor
[nnodes, max_tokens, topk]or[max_tokens, topk]- Expert indices.
- token_dst_scatter_idx: torch.Tensor
[nnodes, max_tokens, topk]- Scatter indices in output buffer.
- skipped_token_mapping_indices: torch.Tensor | None
Intra-node optimization: maps skipped tokens to first occurrence.
- skipped_token_topk_mapping_indices: torch.Tensor | None
Intra-node optimization: per-TopK mapping for skipped tokens.
EPConfig
DispatchCombineContext
- class DispatchCombineContext
Manages symmetric memory buffers for dispatch and combine operations.
- classmethod create(ep_config, capacity=2) DispatchCombineContext
Create context with NVSHMEM symmetric buffers.
- Parameters:
ep_config – EPConfig instance
capacity – Multiplier for output buffer capacity (default: 2)
- finalize()
Release all symmetric memory buffers.
- reallocate_dispatch_output_buf(dispatch_recv_tokens)
Resize output buffer if needed.
- Parameters:
dispatch_recv_tokens – Required capacity
- Returns:
Tuple of (dispatch_output_buf, weight_recv_buf)
Usage Example
Basic Usage
import torch
import torch.distributed as dist
from triton_dist.utils import initialize_distributed
from triton_dist.layers.nvidia import EPAll2AllLayer
# Initialize distributed runtime
rank, world_size = initialize_distributed()
# Create EP group (all ranks)
ep_group = dist.group.WORLD
# Create layer
ep_layer = EPAll2AllLayer(
ep_group=ep_group,
max_tokens=256,
hidden=7168,
topk=8,
rank=rank,
num_tot_experts=256, # 256 experts / 32 ranks = 8 per rank
local_world_size=8,
world_size=32,
dtype=torch.bfloat16,
weight_dtype=torch.float32,
num_sm=20,
)
# Simulate input
tokens = torch.randn(128, 7168, dtype=torch.bfloat16, device="cuda")
expert_ids = torch.randint(0, 256, (128, 8), dtype=torch.int32, device="cuda")
routing_weights = torch.softmax(torch.randn(128, 8, device="cuda"), dim=-1).float()
# Dispatch
recv_tokens, recv_weights, layout_desc = ep_layer.dispatch(
input=tokens,
exp_indices=expert_ids,
weight=routing_weights,
)
# Expert FFN computation (example)
# recv_tokens: [recv_count, hidden]
expert_output = your_expert_ffn(recv_tokens)
# Combine
combined = ep_layer.combine(expert_output, layout_desc)
# combined: [128, 7168] - back to original token positions
# Cleanup
ep_layer.finalize()
With AOT Compilation
# Enable AOT for production
ep_layer = EPAll2AllLayer(
...,
use_aot=True, # Use pre-compiled kernels
)
With Local Combine Optimization
# For intra-node scenarios with tokens routed to multiple same-rank experts
ep_layer = EPAll2AllLayer(
...,
world_size=8, # Single node
local_world_size=8,
enable_local_combine=True, # Pre-aggregate locally
)
Performance Tips
Buffer Sizing: Set
max_tokensto your expected maximum to avoid reallocationAOT Compilation: Enable
use_aot=Truein production for faster startupSM Count: Tune
num_smbased on your model’s compute requirementsLocal Combine: Enable for intra-node when TopK > 1 and experts overlap
Note
The layer automatically selects between intra-node and inter-node kernels
based on world_size and local_world_size.
See Also
Expert Parallelism All-to-All (EP A2A) - Underlying dispatch/combine kernels
Expert Parallelism All-to-All (Intra-Node) - Intra-node optimized kernels
Low-Latency EP All-to-All Layer - Low-latency variant with FP8 quantization