Expert Parallelism All-to-All Fused Layer

High-level layer API for fused Expert Parallelism All-to-All operations in single-node MoE models. This layer combines dispatch, groupgemm, and combine operations into megakernels for optimal performance.

Overview

EpAll2AllFusedOp provides a complete fused solution for EP MoE operations, specifically optimized for single-node 8-GPU configurations. It implements computation-communication fusion to minimize kernel launch overhead and maximize GPU utilization.

Key Features

  • Megakernel Fusion: Single kernel launch for dispatch+groupgemm and groupgemm+combine

  • Token Optimization: Token saving/skipping and sorting for reduced communication

  • Dynamic SM Scheduling: Fine-grained task distribution across SMs

  • Two-Stage Dispatch: Overlaps dispatch completion with groupgemm start

  • Fuse Scatter Mode: Interleaves scatter, groupgemm, and reduce for better overlap

  • Lazy Memory Allocation: Optional lazy NVSHMEM allocation for reduced startup time

Architecture

EpAll2AllFusedOp
├── Preprocessing
│   ├── get_ag_splits_and_recv_offset_for_dispatch()
│   │   ├── Token sorting (token_sort_indices)
│   │   ├── Expert split computation
│   │   └── Scatter index generation
│   └── Buffer initialization
│
├── Mega Dispatch + GroupGEMM
│   ├── mega_kernel_dispatch_token_moe_grouped_gemm()
│   │   ├── Dispatch tasks (NUM_DISPATCH_SM SMs)
│   │   │   ├── Token routing with token saving
│   │   │   ├── Two-stage dispatch (if NUM_TAIL_SMS > 0)
│   │   │   └── Per-expert barrier signaling
│   │   └── GroupGEMM tasks (remaining SMs)
│   │       ├── Wait for dispatch completion
│   │       ├── Execute GEMM computation
│   │       └── Notify combine phase
│   └── Checkpoint (dispatch_output_local)
│
└── Mega GroupGEMM + Combine
    ├── mega_kernel_moe_grouped_gemm_combine_token()
    │   ├── GroupGEMM tasks
    │   │   ├── Execute GEMM computation
    │   │   └── Per-token-block barrier notification
    │   ├── Scatter tasks (fuse_scatter mode)
    │   │   ├── Scatter expert outputs
    │   │   └── Transfer gate values
    │   └── Reduce tasks
    │       └── TopK weighted sum
    └── Output copy

API Reference

EpAll2AllFusedOp

class EpAll2AllFusedOp(ep_group, max_tokens, hidden, topk, rank, num_tot_experts, local_world_size, world_size, dtype=torch.bfloat16, weight_dtype=torch.float32, num_sm=20, sm_margin=0, duplicate_comm_buffer=1, capacity=4.0, FWD_GEMM_BLOCK_SIZE_N=256, need_reversed_token_scatter_idx=False, lazy=False)

Fused EP All-to-All layer for single-node MoE models.

Parameters:
  • ep_group – PyTorch distributed process group for EP

  • max_tokens – Maximum number of tokens per rank

  • hidden – Hidden dimension size

  • topk – Number of experts selected per token

  • rank – Current rank

  • num_tot_experts – Total number of experts across all ranks

  • local_world_size – Number of ranks per node (must equal world_size for fused op)

  • world_size – Total number of ranks (must equal local_world_size)

  • dtype – Token data type (default: torch.bfloat16)

  • weight_dtype – Routing weight data type (default: torch.float32)

  • num_sm – Number of SMs to use for kernels (default: 20)

  • sm_margin – SMs to reserve (default: 0)

  • duplicate_comm_buffer – Number of communication buffers for pipelining (default: 1)

  • capacity – Buffer capacity multiplier (default: 4.0)

  • FWD_GEMM_BLOCK_SIZE_N – Block size N for forward GEMM (default: 256)

  • need_reversed_token_scatter_idx – Whether to generate reverse scatter indices (default: False)

  • lazy – Use lazy NVSHMEM allocation (default: False)

Note

This layer only supports single-node configurations (world_size == local_world_size).
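
The single-node constraint and the expert partitioning used later in the usage example (64 experts over 8 ranks) can be checked before constructing the layer. The helper below is a minimal sketch: `check_ep_config` is a hypothetical name, not part of the library, and the requirement that experts divide evenly across ranks is an assumption this page does not state.

```python
def check_ep_config(num_tot_experts: int, world_size: int, local_world_size: int) -> int:
    """Validate a single-node EP configuration and return experts per rank.

    Hypothetical helper for illustration; not part of the library API.
    """
    # The fused op is documented as single-node only.
    if world_size != local_world_size:
        raise ValueError(
            f"single-node only: world_size={world_size}, "
            f"local_world_size={local_world_size}"
        )
    # Assumption: experts are split evenly across ranks.
    if num_tot_experts % world_size != 0:
        raise ValueError("num_tot_experts must be divisible by world_size")
    return num_tot_experts // world_size


experts_per_rank = check_ep_config(num_tot_experts=64, world_size=8, local_world_size=8)
print(experts_per_rank)  # 8
```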

preprocess(exp_indices, full_scatter_indices=None, local_scatter_indices=None)

Preprocess expert indices and compute routing metadata.

Parameters:
  • exp_indices[num_tokens, topk] - Expert indices from Top-K gate

  • full_scatter_indices[num_tokens, topk] - Optional global scatter indices

  • local_scatter_indices[num_tokens, topk] - Optional local scatter indices

Returns:

EPAllToAllLayoutDesc - Layout descriptor for dispatch/combine

Key Operations:

  1. Bincount: Count tokens per expert

  2. AllGather Splits: Exchange split information across ranks

  3. Token Sorting: Generate token_sort_indices for optimal memory access

  4. Scatter Index Computation: Compute destination offsets for each token
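
The four key operations above can be sketched on plain Python lists. This is a single-process illustration only: `preprocess_sketch` is a hypothetical name, the AllGather step is replaced by summing local counts per destination rank, and the real method's buffer layout and offsets may differ.

```python
from typing import List, Tuple


def preprocess_sketch(
    exp_indices: List[List[int]], num_tot_experts: int, world_size: int
) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Single-process sketch of bincount, splits, token sort, and scatter indices."""
    experts_per_rank = num_tot_experts // world_size
    # Flatten [num_tokens, topk] into one slot per (token, k) pair.
    flat = [e for row in exp_indices for e in row]

    # 1. Bincount: tokens routed to each expert.
    counts = [0] * num_tot_experts
    for e in flat:
        counts[e] += 1

    # 2. Splits: tokens destined for each rank (the real op all-gathers these).
    splits_per_rank = [
        sum(counts[r * experts_per_rank:(r + 1) * experts_per_rank])
        for r in range(world_size)
    ]

    # 3. Token sorting: stable sort of slots by expert id.
    sort_indices = sorted(range(len(flat)), key=lambda i: flat[i])

    # 4. Scatter index: position of each slot in the expert-sorted order.
    scatter = [0] * len(flat)
    for pos, slot in enumerate(sort_indices):
        scatter[slot] = pos
    return counts, splits_per_rank, sort_indices, scatter


counts, splits, order, scatter = preprocess_sketch(
    [[1, 0], [3, 1], [0, 2]], num_tot_experts=4, world_size=2
)
print(counts, splits)
```

Because the sort is stable, slots routed to the same expert keep their original token order, so each expert's slice of the output is contiguous.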

mega_dispatch_group_gemm(input, exp_indices, ep_a2a_layout_desc, gemm_weight, gemm_expert_ids, gemm_split_size, gemm_split_size_cum, gemm_tile_num, gemm_tile_num_cum, gemm_num_tiles_total, gemm_expert_offs, weight=None, with_cpy_flag=True, comm_buffer_id=0, optional_sm=None, num_tail_sms=0, gemm_input_reduce_last_dim=True, gemm_weight_reduce_last_dim=True, gemm_output_data=None, gemm_BLOCK_SIZE_N=256, gemm_BLOCK_SIZE_K=64, gemm_GROUP_SIZE_M=3, gemm_num_stages=3, use_block_wise_barrier=False, num_warps=16, enable_profiler=False, profile_file_name='mega_dispatch_group_gemm')

Fused dispatch and groupgemm operation.

Parameters:
  • input[num_tokens, hidden] - Input tokens

  • exp_indices[num_tokens, topk] - Expert indices

  • ep_a2a_layout_desc (EPAllToAllLayoutDesc) - Layout descriptor from preprocess

  • gemm_weight[G, N, K] or [G, K, N] - Expert weights

  • gemm_expert_ids[M_grid] - Expert ID for each tile

  • gemm_split_size[G] - Token count per expert

  • gemm_split_size_cum[M_grid] - Cumulative token counts

  • gemm_tile_num[M_grid] - Number of tiles per expert

  • gemm_tile_num_cum[M_grid] - Cumulative tile counts

  • gemm_num_tiles_total[1] - Total number of tiles

  • gemm_expert_offs[experts_per_rank] - Expert offsets

  • weight[num_tokens, topk] - Optional routing weights

  • with_cpy_flag – Whether to copy input to symmetric buffer (default: True)

  • comm_buffer_id – Communication buffer ID for pipelining (default: 0)

  • optional_sm – Override number of dispatch SMs (default: None)

  • num_tail_sms – Number of tail SMs for two-stage dispatch (default: 0)

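  • gemm_input_reduce_last_dim – Whether the reduction (K) dimension is the last dimension of the input (default: True)

  • gemm_weight_reduce_last_dim – Whether the reduction (K) dimension is the last dimension of the weight, i.e. [G, N, K] rather than [G, K, N] (default: True)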

  • gemm_output_data – Pre-allocated output buffer (default: None)

  • gemm_BLOCK_SIZE_N – Block size N for GEMM (default: 256)

  • gemm_BLOCK_SIZE_K – Block size K for GEMM (default: 64)

  • gemm_GROUP_SIZE_M – Group size M for GEMM (default: 3)

  • gemm_num_stages – Number of pipeline stages (default: 3)

  • use_block_wise_barrier – Use per-tile barriers (default: False)

  • num_warps – Number of warps per SM (default: 16)

  • enable_profiler – Enable profiling (default: False)

  • profile_file_name – Profile file name (default: "mega_dispatch_group_gemm")

Returns:

Tuple of (dispatch_output_local, weight_res, ep_a2a_layout_desc, gemm_output_data)

Workflow:

  1. Copy input to symmetric buffer (if with_cpy_flag=True)

  2. Initialize output buffer based on actual token counts

  3. Launch megakernel with dispatch and groupgemm tasks

  4. Return dispatch output (local copy) and groupgemm output

Two-Stage Dispatch (when num_tail_sms > 0):

  • Main SMs: Handle token routing and communication

  • Tail SMs: Handle local copy and barrier notification

  • Enables overlap between dispatch completion and groupgemm start
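
The tile metadata arguments of this method (gemm_expert_ids, gemm_split_size_cum, gemm_tile_num, gemm_tile_num_cum, gemm_num_tiles_total) are all derived from the per-expert token counts. The sketch below shows one plausible construction. `build_tile_metadata` and `block_size_m` are hypothetical names, and the use of inclusive cumulative sums and per-expert array lengths are assumptions; the library's own helper may lay these out differently.

```python
from typing import Dict, List


def build_tile_metadata(split_size: List[int], block_size_m: int) -> Dict[str, object]:
    """Derive per-tile GEMM metadata from per-expert token counts (illustrative)."""
    # Tiles needed to cover each expert's rows, rounding up.
    tile_num = [(s + block_size_m - 1) // block_size_m for s in split_size]

    # Inclusive running totals of tokens and tiles (assumption: inclusive).
    split_size_cum, tile_num_cum = [], []
    tokens, tiles = 0, 0
    for s, t in zip(split_size, tile_num):
        tokens += s
        tiles += t
        split_size_cum.append(tokens)
        tile_num_cum.append(tiles)

    # One entry per tile naming the expert that owns it.
    expert_ids = [e for e, t in enumerate(tile_num) for _ in range(t)]

    return {
        "gemm_expert_ids": expert_ids,
        "gemm_split_size_cum": split_size_cum,
        "gemm_tile_num": tile_num,
        "gemm_tile_num_cum": tile_num_cum,
        "gemm_num_tiles_total": tiles,
    }


meta = build_tile_metadata(split_size=[5, 0, 3], block_size_m=4)
print(meta["gemm_expert_ids"])  # [0, 0, 2]
```

An expert that receives no tokens contributes no tiles, so it never appears in gemm_expert_ids.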

mega_group_gemm_combine(gemm_input_data, gemm_weight, gemm_expert_ids, gemm_split_size, gemm_split_size_cum, gemm_tile_num, gemm_tile_num_cum, gemm_num_tiles_total, ep_a2a_layout_desc, gemm_input_reduce_last_dim=True, gemm_weight_reduce_last_dim=True, gemm_BLOCK_SIZE_N=256, gemm_BLOCK_SIZE_K=64, gemm_GROUP_SIZE_M=3, gemm_num_stages=3, gate_input=None, cp_flag=True, combine_output=None, output_gate=None, optional_sm=None, num_reduce_sms=0, optional_signal_tensor=None, num_warps=32, combine_mode='serial', grad_output=None, orig_input=None, grad_weight=None, split_size_cum_per_expert=None, grad_BLOCK_SIZE_M=64, grad_BLOCK_SIZE_N=128, grad_BLOCK_SIZE_K=256, grad_GROUP_SIZE_M=3, enable_profiler=False, profile_file_name='mega_group_gemm_combine')

Fused groupgemm and combine operation.

Parameters:
  • gemm_input_data[M, K] - Input data for groupgemm

  • gemm_weight[G, N, K] or [G, K, N] - Expert weights

  • gemm_expert_ids[M_grid] - Expert ID for each tile

  • gemm_split_size[G] - Token count per expert

  • gemm_split_size_cum[M_grid] - Cumulative token counts

  • gemm_tile_num[M_grid] - Number of tiles per expert

  • gemm_tile_num_cum[M_grid] - Cumulative tile counts

  • gemm_num_tiles_total[1] - Total number of tiles

  • ep_a2a_layout_desc (EPAllToAllLayoutDesc) - Layout descriptor from dispatch

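  • gemm_input_reduce_last_dim – Whether the reduction (K) dimension is the last dimension of the input (default: True)

  • gemm_weight_reduce_last_dim – Whether the reduction (K) dimension is the last dimension of the weight, i.e. [G, N, K] rather than [G, K, N] (default: True)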

  • gemm_BLOCK_SIZE_N – Block size N for GEMM (default: 256)

  • gemm_BLOCK_SIZE_K – Block size K for GEMM (default: 64)

  • gemm_GROUP_SIZE_M – Group size M for GEMM (default: 3)

  • gemm_num_stages – Number of pipeline stages (default: 3)

  • gate_input[recv_tokens] - Optional gate input values

  • cp_flag – Whether to copy gate input (default: True)

  • combine_output – Pre-allocated output buffer (default: None)

  • output_gate – Pre-allocated gate output buffer (default: None)

  • optional_sm – Override number of combine SMs (default: None)

  • num_reduce_sms – Number of SMs for reduce phase (default: 0)

  • optional_signal_tensor – Optional signal tensor (default: None)

  • num_warps – Number of warps per SM (default: 32)

  • combine_mode – Combine mode: "serial" or "fuse_scatter" (default: "serial")

  • grad_output[M, N] - Gradient output for backward (default: None)

  • orig_input[M, K] - Original input for backward (default: None)

  • grad_weight[G, N, K] - Gradient weight output (default: None)

  • split_size_cum_per_expert[G] - Cumulative split size per expert (default: None)

  • grad_BLOCK_SIZE_M – Block size M for gradient GEMM (default: 64)

  • grad_BLOCK_SIZE_N – Block size N for gradient GEMM (default: 128)

  • grad_BLOCK_SIZE_K – Block size K for gradient GEMM (default: 256)

  • grad_GROUP_SIZE_M – Group size M for gradient GEMM (default: 3)

  • enable_profiler – Enable profiling (default: False)

  • profile_file_name – Profile file name (default: "mega_group_gemm_combine")

Returns:

Combined output (and optionally gate output and grad_weight)

Combine Modes:

  1. Serial Mode (combine_mode="serial"):

    • Execute all groupgemm tasks first

    • Synchronize with barrier_all

    • Execute combine tasks

    • Suitable for small token counts

  2. Fuse Scatter Mode (combine_mode="fuse_scatter"):

    • Interleave scatter, groupgemm, and reduce tasks

    • Fine-grained per-token-block barriers

    • Better overlap for large token counts

    • Recommended for production use
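
Both modes finish with the reduce step that the architecture tree calls "TopK weighted sum". A reference version on plain lists is sketched below for checking results on small inputs. `topk_weighted_sum` is a hypothetical name, and the assumption that the routing weights are applied during the reduce (rather than earlier) is mine, not something this page confirms.

```python
from typing import List


def topk_weighted_sum(
    expert_out: List[List[float]],
    scatter_idx: List[List[int]],
    weights: List[List[float]],
) -> List[List[float]]:
    """Combine each token's top-k expert outputs into one weighted row.

    expert_out:  [num_rows, hidden] rows produced by the experts
    scatter_idx: [num_tokens, topk] row in expert_out for each (token, k)
    weights:     [num_tokens, topk] routing weight for each (token, k)
    """
    hidden = len(expert_out[0])
    combined = []
    for idx_row, w_row in zip(scatter_idx, weights):
        acc = [0.0] * hidden
        for idx, w in zip(idx_row, w_row):
            row = expert_out[idx]
            for h in range(hidden):
                acc[h] += w * row[h]
        combined.append(acc)
    return combined


out = topk_weighted_sum(
    expert_out=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    scatter_idx=[[0, 2], [1, 1]],
    weights=[[0.5, 0.5], [0.25, 0.75]],
)
print(out)  # [[3.0, 4.0], [3.0, 4.0]]
```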

init_output_buffer(num_recv_tokens_per_rank, min_m=None)

Initialize output buffer based on actual token counts.

Parameters:
  • num_recv_tokens_per_rank[world_size] - Token counts per rank (CPU pinned memory)

  • min_m – Minimum M dimension for groupgemm (default: None)

Returns:

Tuple of (output_buf, weight_recv_buf)

Note: This method polls CPU memory to avoid GPU-CPU synchronization overhead.

get_nvshmem_size()

Get total NVSHMEM memory size in bytes.

get_nvshmem_size_gb()

Get total NVSHMEM memory size in GB.

get_nvshmem_size_mb()

Get total NVSHMEM memory size in MB.

get_nvshmem_breakdown()

Get breakdown of NVSHMEM usage by buffer name.
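
The size queries above relate to each other by simple unit conversion. The sketch below assumes that get_nvshmem_breakdown() yields a mapping from buffer name to size in bytes and that MB and GB are binary units (2**20 and 2**30 bytes); neither is confirmed here. `summarize_breakdown` and the buffer names are hypothetical.

```python
from typing import Dict


def summarize_breakdown(breakdown: Dict[str, int]) -> Dict[str, float]:
    """Total a name -> bytes mapping and express it in bytes, MB, and GB."""
    total = sum(breakdown.values())
    return {
        "bytes": total,
        "mb": total / 2**20,  # assumption: binary megabytes
        "gb": total / 2**30,  # assumption: binary gigabytes
    }


# Hypothetical buffer names and sizes, for illustration only.
summary = summarize_breakdown({"send_buf": 2**30, "recv_buf": 2**29})
print(summary["gb"])  # 1.5
```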

sync()

Materialize all NVSHMEM tensors (required if lazy=True).

finalize()

Release all NVSHMEM symmetric memory buffers.

EPAllToAllLayoutDesc

class EPAllToAllLayoutDesc

Layout descriptor containing routing metadata.

num_dispatch_token_cur_rank: int

Number of tokens dispatched by this rank.

recv_buf_offset_per_expert: torch.Tensor

[world_size, experts_per_rank, world_size] - Destination offsets for each token.

recv_buf_tokens_per_expert: torch.Tensor

[world_size, experts_per_rank] - Token count per expert per rank.

num_recv_tokens_per_rank: torch.Tensor

[world_size] - Total tokens received per rank.

num_input_tokens_per_rank: torch.Tensor

[world_size] - Tokens dispatched per rank.

topk_indices_tensor: torch.Tensor

[nnodes, max_tokens, topk] - Expert indices per token.

token_dst_scatter_idx: torch.Tensor

[nnodes, max_tokens, topk] - Scatter indices in output buffer.

token_sort_indices: torch.Tensor

[nnodes, max_tokens * topk] - Sorted token indices for optimal memory access.

reversed_token_scatter_idx: torch.Tensor

[world_size * max_tokens * topk, 2] - Reverse mapping for combine phase.
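
The documented shapes can be collected into a table for sanity-checking a descriptor in tests. This is a sketch: `expected_layout_shapes` is a hypothetical helper, and it assumes experts_per_rank = num_tot_experts // world_size and that nnodes is 1 in the single-node case.

```python
from typing import Dict, Tuple


def expected_layout_shapes(
    world_size: int, num_tot_experts: int, max_tokens: int, topk: int, nnodes: int = 1
) -> Dict[str, Tuple[int, ...]]:
    """Return the tensor shapes documented for EPAllToAllLayoutDesc fields."""
    experts_per_rank = num_tot_experts // world_size
    return {
        "recv_buf_offset_per_expert": (world_size, experts_per_rank, world_size),
        "recv_buf_tokens_per_expert": (world_size, experts_per_rank),
        "num_recv_tokens_per_rank": (world_size,),
        "num_input_tokens_per_rank": (world_size,),
        "topk_indices_tensor": (nnodes, max_tokens, topk),
        "token_dst_scatter_idx": (nnodes, max_tokens, topk),
        "token_sort_indices": (nnodes, max_tokens * topk),
        "reversed_token_scatter_idx": (world_size * max_tokens * topk, 2),
    }


shapes = expected_layout_shapes(world_size=8, num_tot_experts=64, max_tokens=256, topk=2)
print(shapes["reversed_token_scatter_idx"])  # (4096, 2)
```

A test could compare `tuple(getattr(desc, name).shape)` against each entry, assuming the descriptor exposes these fields as attributes.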

Usage Example

Basic Usage

import torch
import torch.distributed as dist
from triton_dist.layers.nvidia.ep_a2a_fused_layer import EpAll2AllFusedOp

# Query the distributed runtime (assumes it has already been initialized)
rank = dist.get_rank()
world_size = dist.get_world_size()
assert world_size == 8, "Fused op only supports single-node 8-GPU"

# Create EP group
ep_group = dist.group.WORLD

# Create fused layer
ep_op = EpAll2AllFusedOp(
    ep_group=ep_group,
    max_tokens=256,
    hidden=7168,
    topk=2,
    rank=rank,
    num_tot_experts=64,  # 64 experts / 8 ranks = 8 per rank
    local_world_size=8,
    world_size=8,
    dtype=torch.bfloat16,
    weight_dtype=torch.float32,
    num_sm=20,
    lazy=False,  # Set to True for lazy allocation
)

# Sync if using lazy allocation
if ep_op._lazy:
    ep_op.sync()

# Simulate input
num_tokens = 128
tokens = torch.randn(num_tokens, 7168, dtype=torch.bfloat16, device="cuda")
expert_ids = torch.randint(0, 64, (num_tokens, 2), dtype=torch.int32, device="cuda")
routing_weights = torch.softmax(torch.randn(num_tokens, 2, device="cuda"), dim=-1).float()

# Preprocess
layout_desc = ep_op.preprocess(
    exp_indices=expert_ids,
)

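# Dispatch + GroupGEMM
# Note: expert_weights, expert_weights_2, and the gemm_* tile metadata tensors
# are assumed to be prepared by the caller; they are not defined in this snippet.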
dispatch_output, weights, layout_desc, gemm_output = ep_op.mega_dispatch_group_gemm(
    input=tokens,
    exp_indices=expert_ids,
    ep_a2a_layout_desc=layout_desc,
    gemm_weight=expert_weights,  # [64, 7168, 2048]
    gemm_expert_ids=expert_ids_tiles,  # [M_grid]
    gemm_split_size=split_size,  # [64]
    gemm_split_size_cum=split_size_cum,  # [M_grid]
    gemm_tile_num=tile_num,  # [M_grid]
    gemm_tile_num_cum=tile_num_cum,  # [M_grid]
    gemm_num_tiles_total=num_tiles_total,  # [1]
    gemm_expert_offs=expert_offs,  # [8]
    weight=routing_weights,
    num_tail_sms=4,  # Enable two-stage dispatch
    use_block_wise_barrier=True,  # Use per-tile barriers
    num_warps=16,
)

# GroupGEMM + Combine
combined_output = ep_op.mega_group_gemm_combine(
    gemm_input_data=gemm_output,  # Output from first groupgemm
    gemm_weight=expert_weights_2,  # [64, 2048, 7168]
    gemm_expert_ids=expert_ids_tiles,
    gemm_split_size=split_size,
    gemm_split_size_cum=split_size_cum,
    gemm_tile_num=tile_num,
    gemm_tile_num_cum=tile_num_cum,
    gemm_num_tiles_total=num_tiles_total,
    ep_a2a_layout_desc=layout_desc,
    gate_input=routing_weights,
    combine_mode="fuse_scatter",  # Use fuse scatter mode
    num_warps=32,
)

# Cleanup
ep_op.finalize()

With Two-Stage Dispatch

Two-stage dispatch enables overlap between dispatch completion and groupgemm start:

dispatch_output, weights, layout_desc, gemm_output = ep_op.mega_dispatch_group_gemm(
    ...,
    num_tail_sms=4,  # Allocate 4 SMs for tail operations
    use_block_wise_barrier=True,  # Required for two-stage dispatch
)

Benefits:

  • Reduced Wait Time: GroupGEMM can start as soon as a tile is ready

  • Better Overlap: Dispatch completion overlaps with groupgemm execution

  • Improved Throughput: 10-20% improvement for large token counts

With Fuse Scatter Mode

Fuse scatter mode interleaves operations for better overlap:

combined_output = ep_op.mega_group_gemm_combine(
    ...,
    combine_mode="fuse_scatter",  # Enable fuse scatter mode
    num_reduce_sms=8,  # Allocate SMs for reduce phase
)

Benefits:

  • Fine-Grained Barriers: Per-token-block synchronization

  • Better Overlap: Scatter, groupgemm, and reduce can overlap

  • Improved Latency: 15-25% improvement for large token counts

With Lazy Allocation

Lazy allocation defers NVSHMEM allocation until sync() is called:

ep_op = EpAll2AllFusedOp(
    ...,
    lazy=True,  # Enable lazy allocation
)

# Query memory requirements before allocation
print(f"NVSHMEM size: {ep_op.get_nvshmem_size_gb():.2f} GB")
print(ep_op.get_nvshmem_breakdown())

# Materialize when ready
ep_op.sync()

Benefits:

  • Faster Startup: Avoids allocation during initialization

  • Memory Planning: Query requirements before allocation

  • Flexible: Can adjust buffer sizes based on actual usage
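
The pattern behind these benefits is to record buffer requests first and allocate only when asked. The class below is a conceptual sketch of that pattern using ordinary host memory. `LazyBufferPool` and its methods are hypothetical and are not the library's implementation, which allocates NVSHMEM symmetric memory.

```python
from typing import Dict, Optional


class LazyBufferPool:
    """Record buffer requests up front; allocate them all on sync()."""

    def __init__(self) -> None:
        self._requests: Dict[str, int] = {}
        self._buffers: Optional[Dict[str, bytearray]] = None

    def request(self, name: str, nbytes: int) -> None:
        # Only the size is recorded; nothing is allocated yet.
        self._requests[name] = nbytes

    def total_bytes(self) -> int:
        # Can be queried before any allocation happens.
        return sum(self._requests.values())

    def sync(self) -> None:
        # Materialize every requested buffer exactly once.
        if self._buffers is None:
            self._buffers = {n: bytearray(b) for n, b in self._requests.items()}

    def get(self, name: str) -> bytearray:
        if self._buffers is None:
            raise RuntimeError("call sync() before using buffers")
        return self._buffers[name]


pool = LazyBufferPool()
pool.request("send_buf", 1024)
pool.request("recv_buf", 4096)
print(pool.total_bytes())  # 5120
pool.sync()
print(len(pool.get("recv_buf")))  # 4096
```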

Performance Tuning

SM Allocation

Dispatch Phase:

  • NUM_DISPATCH_SM: 20-40 SMs (depends on token count)

  • NUM_TAIL_SMS: 4-8 SMs (for two-stage dispatch)

  • Remaining SMs: Automatically used for groupgemm

Combine Phase:

  • COMBINE_SM: 20-40 SMs (depends on token count)

  • NUM_REDUCE_SMS: 4-8 SMs (for fuse scatter mode)

  • Remaining SMs: Automatically used for groupgemm

Tuning Guidelines:

  • Start with num_sm=20 and adjust based on profiling

  • Use num_tail_sms=4 for two-stage dispatch

  • Use num_reduce_sms=8 for fuse scatter mode

  • Monitor SM utilization with profiling
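
The split described above can be written down as simple arithmetic to see how many SMs remain for groupgemm. This is a sketch under stated assumptions: `partition_sms` is a hypothetical helper, the tail SMs are assumed to be counted in addition to the dispatch SMs, and the total of 132 SMs is only an illustrative value that depends on the GPU.

```python
from typing import Dict


def partition_sms(total_sms: int, num_dispatch_sm: int, num_tail_sms: int = 0) -> Dict[str, int]:
    """Split a GPU's SMs into dispatch, tail, and groupgemm groups (illustrative)."""
    # Assumption: tail SMs are reserved on top of the dispatch SMs.
    communication = num_dispatch_sm + num_tail_sms
    if communication >= total_sms:
        raise ValueError("no SMs left for groupgemm")
    return {
        "dispatch": num_dispatch_sm,
        "tail": num_tail_sms,
        "groupgemm": total_sms - communication,
    }


plan = partition_sms(total_sms=132, num_dispatch_sm=20, num_tail_sms=4)
print(plan["groupgemm"])  # 108
```

Raising either communication count shrinks the groupgemm share one-for-one, which is why the guidelines suggest starting small and adjusting from profiles.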

GEMM Block Sizes

Forward GEMM:

  • BLOCK_SIZE_N: 256 (default, good for most cases)

  • BLOCK_SIZE_K: 64 (default, good for most cases)

  • GROUP_SIZE_M: 3 (default, balances parallelism and overhead)

Gradient GEMM:

  • grad_BLOCK_SIZE_M: 64 (default)

  • grad_BLOCK_SIZE_N: 128 (default)

  • grad_BLOCK_SIZE_K: 256 (default)

Tuning Guidelines:

  • Larger block sizes improve compute efficiency but reduce parallelism

  • Smaller block sizes improve parallelism but increase overhead

  • Use profiling to find optimal values for your workload
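
The trade-off between block size and parallelism can be made concrete by counting output tiles. In the sketch below, `gemm_tile_count` and `waves` are hypothetical helpers, `block_m` is an assumed tile height (this page documents only the N and K block sizes for the forward GEMM), and the problem sizes are illustrative.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def gemm_tile_count(m: int, n: int, block_m: int, block_n: int) -> int:
    """Number of output tiles for an [m, n] result."""
    return ceil_div(m, block_m) * ceil_div(n, block_n)


def waves(num_tiles: int, num_sms: int) -> int:
    """Rounds needed if each SM processes one tile at a time."""
    return ceil_div(num_tiles, num_sms)


# Same problem, two choices of BLOCK_SIZE_N.
large = gemm_tile_count(m=1024, n=2048, block_m=64, block_n=256)
small = gemm_tile_count(m=1024, n=2048, block_m=64, block_n=128)
print(large, small)  # 128 256
print(waves(large, 108), waves(small, 108))  # 2 3
```

Halving the block size doubles the tile count, which gives the scheduler more work items to balance but also adds a wave in this example.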

Warp Configuration

  • Dispatch: num_warps=16 (good balance)

  • Combine: num_warps=32 (more parallelism for reduction)

Tuning Guidelines:

  • More warps improve parallelism but reduce shared memory per warp

  • Use num_warps=16 for dispatch (communication-bound)

  • Use num_warps=32 for combine (computation-bound)

Profiling

Enable profiling to analyze performance:

dispatch_output, weights, layout_desc, gemm_output = ep_op.mega_dispatch_group_gemm(
    ...,
    enable_profiler=True,
    profile_file_name="my_dispatch_profile",
)

combined_output = ep_op.mega_group_gemm_combine(
    ...,
    enable_profiler=True,
    profile_file_name="my_combine_profile",
)

Profiles are saved to prof/mega/ directory in Perfetto trace format.

Key Metrics:

  • dispatch_token_main: Main dispatch time

  • dispatch_token_tail_notify: Tail notification time

  • group_gemm_wait: Wait time for dispatch completion

  • group_gemm_main: GEMM computation time

  • combine_scatter_token: Scatter time

  • combine_topk_reduce: Reduce time
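
Once event names and durations have been extracted from a trace, the metrics above can be totalled to see how much of the groupgemm time is spent waiting on dispatch. The sketch below works on plain (name, duration) pairs. Loading the trace file is not shown, and both the record format and the helper names `total_by_name` and `wait_fraction` are hypothetical.

```python
from collections import defaultdict
from typing import Dict, Iterable, Tuple


def total_by_name(events: Iterable[Tuple[str, float]]) -> Dict[str, float]:
    """Sum durations per event name."""
    totals: Dict[str, float] = defaultdict(float)
    for name, duration in events:
        totals[name] += duration
    return dict(totals)


def wait_fraction(totals: Dict[str, float]) -> float:
    """Share of groupgemm time spent in group_gemm_wait."""
    wait = totals.get("group_gemm_wait", 0.0)
    main = totals.get("group_gemm_main", 0.0)
    busy = wait + main
    return wait / busy if busy else 0.0


# Hypothetical durations in arbitrary time units.
totals = total_by_name([
    ("group_gemm_wait", 2.0),
    ("group_gemm_main", 6.0),
    ("group_gemm_wait", 2.0),
])
print(wait_fraction(totals))
```

A high wait fraction suggests the groupgemm SMs are starved by dispatch, which is the case two-stage dispatch is meant to address.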

See Also