Expert Parallelism All-to-All Fused Layer
High-level layer API for fused Expert Parallelism All-to-All operations in single-node MoE models. This layer combines dispatch, groupgemm, and combine operations into megakernels for optimal performance.
Overview
EpAll2AllFusedOp provides a complete fused solution for EP MoE operations, specifically optimized
for single-node 8-GPU configurations. It implements computation-communication fusion to minimize
kernel launch overhead and maximize GPU utilization.
Key Features
Megakernel Fusion: Single kernel launch for dispatch+groupgemm and groupgemm+combine
Token Optimization: Token saving/skipping and sorting for reduced communication
Dynamic SM Scheduling: Fine-grained task distribution across SMs
Two-Stage Dispatch: Overlaps dispatch completion with groupgemm start
Fuse Scatter Mode: Interleaves scatter, groupgemm, and reduce for better overlap
Lazy Memory Allocation: Optional lazy NVSHMEM allocation for reduced startup time
Architecture
EpAll2AllFusedOp
├── Preprocessing
│   ├── get_ag_splits_and_recv_offset_for_dispatch()
│   │   ├── Token sorting (token_sort_indices)
│   │   ├── Expert split computation
│   │   └── Scatter index generation
│   └── Buffer initialization
│
├── Mega Dispatch + GroupGEMM
│   ├── mega_kernel_dispatch_token_moe_grouped_gemm()
│   │   ├── Dispatch tasks (NUM_DISPATCH_SM SMs)
│   │   │   ├── Token routing with token saving
│   │   │   ├── Two-stage dispatch (if NUM_TAIL_SMS > 0)
│   │   │   └── Per-expert barrier signaling
│   │   └── GroupGEMM tasks (remaining SMs)
│   │       ├── Wait for dispatch completion
│   │       ├── Execute GEMM computation
│   │       └── Notify combine phase
│   └── Checkpoint (dispatch_output_local)
│
└── Mega GroupGEMM + Combine
    ├── mega_kernel_moe_grouped_gemm_combine_token()
    │   ├── GroupGEMM tasks
    │   │   ├── Execute GEMM computation
    │   │   └── Per-token-block barrier notification
    │   ├── Scatter tasks (fuse_scatter mode)
    │   │   ├── Scatter expert outputs
    │   │   └── Transfer gate values
    │   └── Reduce tasks
    │       └── TopK weighted sum
    └── Output copy
API Reference
EpAll2AllFusedOp
- class EpAll2AllFusedOp(ep_group, max_tokens, hidden, topk, rank, num_tot_experts, local_world_size, world_size, dtype=torch.bfloat16, weight_dtype=torch.float32, num_sm=20, sm_margin=0, duplicate_comm_buffer=1, capacity=4.0, FWD_GEMM_BLOCK_SIZE_N=256, need_reversed_token_scatter_idx=False, lazy=False)
Fused EP All-to-All layer for single-node MoE models.
- Parameters:
ep_group – PyTorch distributed process group for EP
max_tokens – Maximum number of tokens per rank
hidden – Hidden dimension size
topk – Number of experts selected per token
rank – Current rank
num_tot_experts – Total number of experts across all ranks
local_world_size – Number of ranks per node (must equal world_size for fused op)
world_size – Total number of ranks (must equal local_world_size)
dtype – Token data type (default: torch.bfloat16)
weight_dtype – Routing weight data type (default: torch.float32)
num_sm – Number of SMs to use for kernels (default: 20)
sm_margin – SMs to reserve (default: 0)
duplicate_comm_buffer – Number of communication buffers for pipelining (default: 1)
capacity – Buffer capacity multiplier (default: 4.0)
FWD_GEMM_BLOCK_SIZE_N – Block size N for forward GEMM (default: 256)
need_reversed_token_scatter_idx – Whether to generate reverse scatter indices (default: False)
lazy – Use lazy NVSHMEM allocation (default: False)
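To make the relationship between num_tot_experts and world_size concrete, the following is a minimal sketch that assumes experts are split evenly and contiguously across ranks. The source does not state the partitioning scheme, so the contiguous layout and the helper names (experts_per_rank, owner_rank, local_expert_id) are hypothetical and not part of the layer's API.

```python
def experts_per_rank(num_tot_experts: int, world_size: int) -> int:
    """Experts hosted by each rank, assuming an even split (assumption)."""
    if num_tot_experts % world_size != 0:
        raise ValueError("num_tot_experts must be divisible by world_size")
    return num_tot_experts // world_size


def owner_rank(expert_id: int, num_tot_experts: int, world_size: int) -> int:
    """Rank hosting `expert_id` under a contiguous partition (assumption)."""
    return expert_id // experts_per_rank(num_tot_experts, world_size)


def local_expert_id(expert_id: int, num_tot_experts: int, world_size: int) -> int:
    """Index of `expert_id` within its owning rank (same assumption)."""
    return expert_id % experts_per_rank(num_tot_experts, world_size)
```

With num_tot_experts=64 and world_size=8, each rank would host 8 experts under this assumption.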
Note
This layer only supports single-node configurations (world_size == local_world_size).
- preprocess(exp_indices, full_scatter_indices=None, local_scatter_indices=None)
Preprocess expert indices and compute routing metadata.
- Parameters:
exp_indices – [num_tokens, topk] - Expert indices from Top-K gate
full_scatter_indices – [num_tokens, topk] - Optional global scatter indices
local_scatter_indices – [num_tokens, topk] - Optional local scatter indices
- Returns:
EPAllToAllLayoutDesc - Layout descriptor for dispatch/combine
Key Operations:
Bincount: Count tokens per expert
AllGather Splits: Exchange split information across ranks
Token Sorting: Generate token_sort_indices for optimal memory access
Scatter Index Computation: Compute destination offsets for each token
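The key operations above can be pictured with a small single-rank CPU reference. This is only a sketch: it skips the AllGather of splits across ranks, uses plain Python lists instead of GPU tensors, and the function name preprocess_reference is hypothetical. The layout actually produced by preprocess may differ.

```python
from typing import List, Tuple


def preprocess_reference(
    exp_indices: List[List[int]], num_experts: int
) -> Tuple[List[int], List[int], List[List[int]]]:
    """Single-rank reference for bincount, token sort, and scatter indices.

    exp_indices has shape [num_tokens, topk]; each entry is an expert ID.
    Returns (splits, sort_indices, scatter_idx).
    """
    topk = len(exp_indices[0]) if exp_indices else 0
    flat = [e for row in exp_indices for e in row]  # [num_tokens * topk]

    # Bincount: number of (token, k) slots routed to each expert.
    splits = [0] * num_experts
    for e in flat:
        splits[e] += 1

    # Token sorting: stable argsort of the flattened slots by expert ID.
    sort_indices = sorted(range(len(flat)), key=lambda i: flat[i])

    # Scatter index: row of each slot in the expert-sorted buffer.
    scatter_flat = [0] * len(flat)
    for dst, src in enumerate(sort_indices):
        scatter_flat[src] = dst
    scatter_idx = [
        scatter_flat[t * topk:(t + 1) * topk] for t in range(len(exp_indices))
    ]
    return splits, sort_indices, scatter_idx
```

Because the sort is stable, slots routed to the same expert keep their original token order, which is one way to obtain contiguous per-expert segments.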
- mega_dispatch_group_gemm(input, exp_indices, ep_a2a_layout_desc, gemm_weight, gemm_expert_ids, gemm_split_size, gemm_split_size_cum, gemm_tile_num, gemm_tile_num_cum, gemm_num_tiles_total, gemm_expert_offs, weight=None, with_cpy_flag=True, comm_buffer_id=0, optional_sm=None, num_tail_sms=0, gemm_input_reduce_last_dim=True, gemm_weight_reduce_last_dim=True, gemm_output_data=None, gemm_BLOCK_SIZE_N=256, gemm_BLOCK_SIZE_K=64, gemm_GROUP_SIZE_M=3, gemm_num_stages=3, use_block_wise_barrier=False, num_warps=16, enable_profiler=False, profile_file_name='mega_dispatch_group_gemm')
Fused dispatch and groupgemm operation.
- Parameters:
input – [num_tokens, hidden] - Input tokens
exp_indices – [num_tokens, topk] - Expert indices
ep_a2a_layout_desc – EPAllToAllLayoutDesc - Layout descriptor from preprocess
gemm_weight – [G, N, K] or [G, K, N] - Expert weights
gemm_expert_ids – [M_grid] - Expert ID for each tile
gemm_split_size – [G] - Token count per expert
gemm_split_size_cum – [M_grid] - Cumulative token counts
gemm_tile_num – [M_grid] - Number of tiles per expert
gemm_tile_num_cum – [M_grid] - Cumulative tile counts
gemm_num_tiles_total – [1] - Total number of tiles
gemm_expert_offs – [experts_per_rank] - Expert offsets
weight – [num_tokens, topk] - Optional routing weights
with_cpy_flag – Whether to copy input to symmetric buffer (default: True)
comm_buffer_id – Communication buffer ID for pipelining (default: 0)
optional_sm – Override number of dispatch SMs (default: None)
num_tail_sms – Number of tail SMs for two-stage dispatch (default: 0)
gemm_input_reduce_last_dim – Whether input last dim is reduced (default: True)
gemm_weight_reduce_last_dim – Whether weight last dim is reduced (default: True)
gemm_output_data – Pre-allocated output buffer (default: None)
gemm_BLOCK_SIZE_N – Block size N for GEMM (default: 256)
gemm_BLOCK_SIZE_K – Block size K for GEMM (default: 64)
gemm_GROUP_SIZE_M – Group size M for GEMM (default: 3)
gemm_num_stages – Number of pipeline stages (default: 3)
use_block_wise_barrier – Use per-tile barriers (default: False)
num_warps – Number of warps per SM (default: 16)
enable_profiler – Enable profiling (default: False)
profile_file_name – Profile file name (default: “mega_dispatch_group_gemm”)
- Returns:
Tuple of (dispatch_output_local, weight_res, ep_a2a_layout_desc, gemm_output_data)
Workflow:
Copy input to symmetric buffer (if with_cpy_flag=True)
Initialize output buffer based on actual token counts
Launch megakernel with dispatch and groupgemm tasks
Return dispatch output (local copy) and groupgemm output
Two-Stage Dispatch (when num_tail_sms > 0):
Main SMs: Handle token routing and communication
Tail SMs: Handle local copy and barrier notification
Enables overlap between dispatch completion and groupgemm start
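The gemm_* tile arguments describe how per-expert token counts are cut into M-tiles. The sketch below shows one common way such bookkeeping is derived from split sizes. It is an assumption throughout: the helper build_tile_metadata, the block_size_m parameter, and the use of inclusive per-expert prefix sums are illustrative only. The parameter list above gives several of these tensors the shape [M_grid], so the layout the kernel expects may differ from what is shown here.

```python
from typing import Dict, List


def build_tile_metadata(split_size: List[int], block_size_m: int) -> Dict[str, object]:
    """Per-expert tile bookkeeping for a grouped GEMM (assumed layout)."""
    if block_size_m <= 0:
        raise ValueError("block_size_m must be positive")

    # Tiles per expert: ceil(tokens / block_size_m); empty experts get 0 tiles.
    tile_num = [-(-m // block_size_m) for m in split_size]

    split_size_cum: List[int] = []
    tile_num_cum: List[int] = []
    tokens, tiles = 0, 0
    for m, t in zip(split_size, tile_num):
        tokens += m
        tiles += t
        split_size_cum.append(tokens)  # inclusive prefix sum of tokens
        tile_num_cum.append(tiles)     # inclusive prefix sum of tiles

    # One entry per M-tile: the expert whose weights that tile multiplies.
    expert_ids = [g for g, t in enumerate(tile_num) for _ in range(t)]

    return {
        "expert_ids": expert_ids,
        "split_size_cum": split_size_cum,
        "tile_num": tile_num,
        "tile_num_cum": tile_num_cum,
        "num_tiles_total": tiles,
    }
```

In practice these values would live in device tensors; plain lists are used here only to keep the arithmetic easy to follow.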
- mega_group_gemm_combine(gemm_input_data, gemm_weight, gemm_expert_ids, gemm_split_size, gemm_split_size_cum, gemm_tile_num, gemm_tile_num_cum, gemm_num_tiles_total, ep_a2a_layout_desc, gemm_input_reduce_last_dim=True, gemm_weight_reduce_last_dim=True, gemm_BLOCK_SIZE_N=256, gemm_BLOCK_SIZE_K=64, gemm_GROUP_SIZE_M=3, gemm_num_stages=3, gate_input=None, cp_flag=True, combine_output=None, output_gate=None, optional_sm=None, num_reduce_sms=0, optional_signal_tensor=None, num_warps=32, combine_mode='serial', grad_output=None, orig_input=None, grad_weight=None, split_size_cum_per_expert=None, grad_BLOCK_SIZE_M=64, grad_BLOCK_SIZE_N=128, grad_BLOCK_SIZE_K=256, grad_GROUP_SIZE_M=3, enable_profiler=False, profile_file_name='mega_group_gemm_combine')
Fused groupgemm and combine operation.
- Parameters:
gemm_input_data – [M, K] - Input data for groupgemm
gemm_weight – [G, N, K] or [G, K, N] - Expert weights
gemm_expert_ids – [M_grid] - Expert ID for each tile
gemm_split_size – [G] - Token count per expert
gemm_split_size_cum – [M_grid] - Cumulative token counts
gemm_tile_num – [M_grid] - Number of tiles per expert
gemm_tile_num_cum – [M_grid] - Cumulative tile counts
gemm_num_tiles_total – [1] - Total number of tiles
ep_a2a_layout_desc – EPAllToAllLayoutDesc - Layout descriptor from dispatch
gemm_input_reduce_last_dim – Whether input last dim is reduced (default: True)
gemm_weight_reduce_last_dim – Whether weight last dim is reduced (default: True)
gemm_BLOCK_SIZE_N – Block size N for GEMM (default: 256)
gemm_BLOCK_SIZE_K – Block size K for GEMM (default: 64)
gemm_GROUP_SIZE_M – Group size M for GEMM (default: 3)
gemm_num_stages – Number of pipeline stages (default: 3)
gate_input – [recv_tokens] - Optional gate input values
cp_flag – Whether to copy gate input (default: True)
combine_output – Pre-allocated output buffer (default: None)
output_gate – Pre-allocated gate output buffer (default: None)
optional_sm – Override number of combine SMs (default: None)
num_reduce_sms – Number of SMs for reduce phase (default: 0)
optional_signal_tensor – Optional signal tensor (default: None)
num_warps – Number of warps per SM (default: 32)
combine_mode – Combine mode - “serial” or “fuse_scatter” (default: “serial”)
grad_output – [M, N] - Gradient output for backward (default: None)
orig_input – [M, K] - Original input for backward (default: None)
grad_weight – [G, N, K] - Gradient weight output (default: None)
split_size_cum_per_expert – [G] - Cumulative split size per expert (default: None)
grad_BLOCK_SIZE_M – Block size M for gradient GEMM (default: 64)
grad_BLOCK_SIZE_N – Block size N for gradient GEMM (default: 128)
grad_BLOCK_SIZE_K – Block size K for gradient GEMM (default: 256)
grad_GROUP_SIZE_M – Group size M for gradient GEMM (default: 3)
enable_profiler – Enable profiling (default: False)
profile_file_name – Profile file name (default: “mega_group_gemm_combine”)
- Returns:
Combined output (and optionally gate output and grad_weight)
Combine Modes:
Serial Mode (combine_mode="serial"):
Execute all groupgemm tasks first
Synchronize with barrier_all
Execute combine tasks
Suitable for small token counts
Fuse Scatter Mode (combine_mode="fuse_scatter"):
Interleave scatter, groupgemm, and reduce tasks
Fine-grained per-token-block barriers
Better overlap for large token counts
Recommended for production use
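Regardless of combine mode, the reduce step computes a top-k weighted sum per token. The following CPU reference sketches that arithmetic only, assuming each token's k expert outputs are addressed through scatter indices into an expert-sorted buffer. The function name topk_weighted_sum and the list-based layout are hypothetical, and nothing here models the communication or barriers.

```python
from typing import List


def topk_weighted_sum(
    expert_out: List[List[float]],
    scatter_idx: List[List[int]],
    weights: List[List[float]],
) -> List[List[float]]:
    """Combine expert outputs per token: out[t] = sum_k w[t][k] * expert_out[idx[t][k]].

    expert_out:  [num_slots, hidden] rows in expert-sorted order
    scatter_idx: [num_tokens, topk] row of each (token, k) slot in expert_out
    weights:     [num_tokens, topk] routing weights
    """
    hidden = len(expert_out[0]) if expert_out else 0
    combined = []
    for token_slots, token_weights in zip(scatter_idx, weights):
        acc = [0.0] * hidden
        for row, w in zip(token_slots, token_weights):
            for h in range(hidden):
                acc[h] += w * expert_out[row][h]
        combined.append(acc)
    return combined
```

A reference like this can be handy for checking the fused output numerically on small inputs, within the tolerance expected for bfloat16.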
- init_output_buffer(num_recv_tokens_per_rank, min_m=None)
Initialize output buffer based on actual token counts.
- Parameters:
num_recv_tokens_per_rank – [world_size] - Token counts per rank (CPU pinned memory)
min_m – Minimum M dimension for groupgemm (default: None)
- Returns:
Tuple of (output_buf, weight_recv_buf)
Note: This method polls CPU memory to avoid GPU-CPU synchronization overhead.
- get_nvshmem_size()
Get total NVSHMEM memory size in bytes.
- get_nvshmem_size_gb()
Get total NVSHMEM memory size in GB.
- get_nvshmem_size_mb()
Get total NVSHMEM memory size in MB.
- get_nvshmem_breakdown()
Get breakdown of NVSHMEM usage by buffer name.
- sync()
Materialize all NVSHMEM tensors (required if lazy=True).
- finalize()
Release all NVSHMEM symmetric memory buffers.
EPAllToAllLayoutDesc
- class EPAllToAllLayoutDesc
Layout descriptor containing routing metadata.
- num_dispatch_token_cur_rank: int
Number of tokens dispatched by this rank.
- recv_buf_offset_per_expert: torch.Tensor
[world_size, experts_per_rank, world_size] - Destination offsets for each token.
- recv_buf_tokens_per_expert: torch.Tensor
[world_size, experts_per_rank] - Token count per expert per rank.
- num_recv_tokens_per_rank: torch.Tensor
[world_size] - Total tokens received per rank.
- num_input_tokens_per_rank: torch.Tensor
[world_size] - Tokens dispatched per rank.
- topk_indices_tensor: torch.Tensor
[nnodes, max_tokens, topk] - Expert indices per token.
- token_dst_scatter_idx: torch.Tensor
[nnodes, max_tokens, topk] - Scatter indices in output buffer.
- token_sort_indices: torch.Tensor
[nnodes, max_tokens * topk] - Sorted token indices for optimal memory access.
- reversed_token_scatter_idx: torch.Tensor
[world_size * max_tokens * topk, 2] - Reverse mapping for combine phase.
Usage Example
Basic Usage
import torch
import torch.distributed as dist
from triton_dist.layers.nvidia.ep_a2a_fused_layer import EpAll2AllFusedOp

# Assumes torch.distributed (and the NVSHMEM runtime) has already been initialized
rank = dist.get_rank()
world_size = dist.get_world_size()
assert world_size == 8, "Fused op only supports single-node 8-GPU"

# Create EP group
ep_group = dist.group.WORLD

# Create fused layer
ep_op = EpAll2AllFusedOp(
    ep_group=ep_group,
    max_tokens=256,
    hidden=7168,
    topk=2,
    rank=rank,
    num_tot_experts=64,  # 64 experts / 8 ranks = 8 per rank
    local_world_size=8,
    world_size=8,
    dtype=torch.bfloat16,
    weight_dtype=torch.float32,
    num_sm=20,
    lazy=False,  # Set to True for lazy allocation
)

# Sync if using lazy allocation
if ep_op._lazy:
    ep_op.sync()

# Simulate input
num_tokens = 128
tokens = torch.randn(num_tokens, 7168, dtype=torch.bfloat16, device="cuda")
expert_ids = torch.randint(0, 64, (num_tokens, 2), dtype=torch.int32, device="cuda")
routing_weights = torch.softmax(torch.randn(num_tokens, 2, device="cuda"), dim=-1).float()

# Preprocess
layout_desc = ep_op.preprocess(
    exp_indices=expert_ids,
)

# Dispatch + GroupGEMM
# expert_weights, expert_ids_tiles, split_size, split_size_cum, tile_num,
# tile_num_cum, num_tiles_total, and expert_offs are placeholders that the
# caller must prepare; they are not defined in this example.
dispatch_output, weights, layout_desc, gemm_output = ep_op.mega_dispatch_group_gemm(
    input=tokens,
    exp_indices=expert_ids,
    ep_a2a_layout_desc=layout_desc,
    gemm_weight=expert_weights,  # [64, 7168, 2048]
    gemm_expert_ids=expert_ids_tiles,  # [M_grid]
    gemm_split_size=split_size,  # [64]
    gemm_split_size_cum=split_size_cum,  # [M_grid]
    gemm_tile_num=tile_num,  # [M_grid]
    gemm_tile_num_cum=tile_num_cum,  # [M_grid]
    gemm_num_tiles_total=num_tiles_total,  # [1]
    gemm_expert_offs=expert_offs,  # [8]
    weight=routing_weights,
    num_tail_sms=4,  # Enable two-stage dispatch
    use_block_wise_barrier=True,  # Use per-tile barriers
    num_warps=16,
)

# GroupGEMM + Combine
# expert_weights_2 is likewise a placeholder for the second expert weight tensor.
combined_output = ep_op.mega_group_gemm_combine(
    gemm_input_data=gemm_output,  # Output from first groupgemm
    gemm_weight=expert_weights_2,  # [64, 2048, 7168]
    gemm_expert_ids=expert_ids_tiles,
    gemm_split_size=split_size,
    gemm_split_size_cum=split_size_cum,
    gemm_tile_num=tile_num,
    gemm_tile_num_cum=tile_num_cum,
    gemm_num_tiles_total=num_tiles_total,
    ep_a2a_layout_desc=layout_desc,
    gate_input=routing_weights,
    combine_mode="fuse_scatter",  # Use fuse scatter mode
    num_warps=32,
)

# Cleanup
ep_op.finalize()
With Two-Stage Dispatch
Two-stage dispatch enables overlap between dispatch completion and groupgemm start:
dispatch_output, weights, layout_desc, gemm_output = ep_op.mega_dispatch_group_gemm(
    ...,
    num_tail_sms=4,  # Allocate 4 SMs for tail operations
    use_block_wise_barrier=True,  # Required for two-stage dispatch
)
Benefits:
Reduced Wait Time: GroupGEMM can start as soon as a tile is ready
Better Overlap: Dispatch completion overlaps with groupgemm execution
Improved Throughput: 10-20% improvement for large token counts
With Fuse Scatter Mode
Fuse scatter mode interleaves operations for better overlap:
combined_output = ep_op.mega_group_gemm_combine(
    ...,
    combine_mode="fuse_scatter",  # Enable fuse scatter mode
    num_reduce_sms=8,  # Allocate SMs for reduce phase
)
Benefits:
Fine-Grained Barriers: Per-token-block synchronization
Better Overlap: Scatter, groupgemm, and reduce can overlap
Improved Latency: 15-25% improvement for large token counts
With Lazy Allocation
Lazy allocation defers NVSHMEM allocation until sync() is called:
ep_op = EpAll2AllFusedOp(
    ...,
    lazy=True,  # Enable lazy allocation
)
# Query memory requirements before allocation
print(f"NVSHMEM size: {ep_op.get_nvshmem_size_gb():.2f} GB")
print(ep_op.get_nvshmem_breakdown())
# Materialize when ready
ep_op.sync()
Benefits:
Faster Startup: Avoids allocation during initialization
Memory Planning: Query requirements before allocation
Flexible: Can adjust buffer sizes based on actual usage
Performance Tuning
SM Allocation
Dispatch Phase:
NUM_DISPATCH_SM: 20-40 SMs (depends on token count)
NUM_TAIL_SMS: 4-8 SMs (for two-stage dispatch)
Remaining SMs: Automatically used for groupgemm
Combine Phase:
COMBINE_SM: 20-40 SMs (depends on token count)
NUM_REDUCE_SMS: 4-8 SMs (for fuse scatter mode)
Remaining SMs: Automatically used for groupgemm
Tuning Guidelines:
Start with num_sm=20 and adjust based on profiling
Use num_tail_sms=4 for two-stage dispatch
Use num_reduce_sms=8 for fuse scatter mode
Monitor SM utilization with profiling
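The split between communication SMs and GEMM SMs can be sanity-checked before launch with simple arithmetic. The sketch below is a hypothetical helper (SmBudget and plan_sm_budget are not part of the layer), and it assumes tail or reduce SMs are counted in addition to the communication SMs, which the source does not confirm. The total SM count must come from the device; 132 is used in the usage note purely as an example figure.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SmBudget:
    comm_sms: int  # SMs for dispatch or combine communication tasks
    aux_sms: int   # tail SMs (dispatch) or reduce SMs (combine)
    gemm_sms: int  # SMs left over for groupgemm tasks


def plan_sm_budget(
    total_sms: int, comm_sms: int, aux_sms: int = 0, sm_margin: int = 0
) -> SmBudget:
    """Split a device's SMs into communication, auxiliary, and GEMM shares.

    Assumption: aux_sms are separate from comm_sms rather than a subset.
    """
    if min(total_sms, comm_sms, aux_sms, sm_margin) < 0:
        raise ValueError("SM counts must be non-negative")
    gemm_sms = total_sms - sm_margin - comm_sms - aux_sms
    if gemm_sms <= 0:
        raise ValueError("no SMs left for groupgemm; reduce comm_sms or aux_sms")
    return SmBudget(comm_sms=comm_sms, aux_sms=aux_sms, gemm_sms=gemm_sms)
```

For example, plan_sm_budget(132, comm_sms=20, aux_sms=4) leaves 108 SMs for groupgemm under these assumptions. On a real system the total can be read from torch.cuda.get_device_properties(0).multi_processor_count.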
GEMM Block Sizes
Forward GEMM:
BLOCK_SIZE_N: 256 (default, good for most cases)
BLOCK_SIZE_K: 64 (default, good for most cases)
GROUP_SIZE_M: 3 (default, balances parallelism and overhead)
Gradient GEMM:
grad_BLOCK_SIZE_M: 64 (default)
grad_BLOCK_SIZE_N: 128 (default)
grad_BLOCK_SIZE_K: 256 (default)
Tuning Guidelines:
Larger block sizes improve compute efficiency but reduce parallelism
Smaller block sizes improve parallelism but increase overhead
Use profiling to find optimal values for your workload
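The trade-off between block size and parallelism can be estimated with tile-count arithmetic before profiling. The sketch below is a simplified model with stated assumptions: one output tile occupies one SM at a time, all tiles cost the same, and the helper name gemm_tile_stats and its block_m parameter are hypothetical. It is not a substitute for measurement.

```python
from typing import Tuple


def gemm_tile_stats(
    m: int, n: int, block_m: int, block_n: int, num_sms: int
) -> Tuple[int, int, float]:
    """Return (tiles, waves, occupancy) for an M x N output.

    tiles:     number of block_m x block_n output tiles
    waves:     rounds needed if each SM processes one tile per round
    occupancy: fraction of SM slots doing work across all waves
    """
    if min(m, n, block_m, block_n, num_sms) <= 0:
        raise ValueError("all arguments must be positive")
    tiles = (-(-m // block_m)) * (-(-n // block_n))  # product of ceil divisions
    waves = -(-tiles // num_sms)
    occupancy = tiles / (waves * num_sms)
    return tiles, waves, occupancy
```

With m=1024, n=2048, block_m=64, and 108 SMs, block_n=256 gives 128 tiles in 2 waves, while block_n=128 gives 256 tiles in 3 waves with a fuller last wave. Which is faster still depends on per-tile efficiency, which this model ignores.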
Warp Configuration
Dispatch: num_warps=16 (good balance)
Combine: num_warps=32 (more parallelism for reduction)
Tuning Guidelines:
More warps improve parallelism but reduce shared memory per warp
Use num_warps=16 for dispatch (communication-bound)
Use num_warps=32 for combine (computation-bound)
Profiling
Enable profiling to analyze performance:
dispatch_output, weights, layout_desc, gemm_output = ep_op.mega_dispatch_group_gemm(
    ...,
    enable_profiler=True,
    profile_file_name="my_dispatch_profile",
)
combined_output = ep_op.mega_group_gemm_combine(
    ...,
    enable_profiler=True,
    profile_file_name="my_combine_profile",
)
Profiles are saved to the prof/mega/ directory in Perfetto trace format.
Key Metrics:
dispatch_token_main: Main dispatch time
dispatch_token_tail_notify: Tail notification time
group_gemm_wait: Wait time for dispatch completion
group_gemm_main: GEMM computation time
combine_scatter_token: Scatter time
combine_topk_reduce: Reduce time
See Also
Expert Parallelism All-to-All Fused Megakernel - Underlying megakernel implementation
Expert Parallelism All-to-All Layer - Non-fused EP All-to-All layer
Expert Parallelism All-to-All (Intra-Node) - Intra-node optimized kernels