Expert Parallelism All-to-All Layer

High-level layer API for Expert Parallelism All-to-All communication in MoE models.

Overview

EPAll2AllLayer implements the token dispatch and combine operations of Expert Parallelism (EP), handling intra-node and inter-node communication transparently.

Key Features

  • Automatic Topology Detection: Selects intra-node or inter-node kernels based on world_size and local_world_size

  • Dynamic Buffer Management: Automatically resizes output buffers based on actual token counts

  • Weight Transfer Support: Optionally transfers routing weights along with tokens

  • Scatter Index Precomputation: Supports external scatter index computation for advanced routing

  • AOT Compilation: Optional ahead-of-time compilation for reduced JIT overhead

Architecture

EPAll2AllLayer
├── EPConfig                    # Configuration parameters
├── DispatchCombineContext      # Symmetric memory buffers
│   ├── token_send_buf_rdma     # [nnodes, max_tokens, hidden]
│   ├── dispatch_output_buf     # [recv_tokens, hidden] (resizable)
│   ├── topk_indices_buf_rdma   # [nnodes, max_tokens, topk]
│   ├── weight_send/recv_buf    # Routing weights
│   ├── signal_buf              # NVSHMEM signals
│   └── ...
└── BarrierAllContext           # Intra-node barrier
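As a rough sizing aid, the sketch below estimates the per-rank footprint of the two fixed-shape symmetric buffers in the tree above. It assumes 2-byte token elements (bfloat16, the default dtype) and 4-byte int32 top-k indices, and it ignores the resizable dispatch_output_buf and the weight and signal buffers; `BufferEstimate` and `estimate_fixed_buffers` are illustrative names, not part of the library.

```python
from dataclasses import dataclass
from math import prod


@dataclass(frozen=True)
class BufferEstimate:
    name: str
    shape: tuple
    elem_bytes: int

    @property
    def nbytes(self) -> int:
        # Total bytes = number of elements * bytes per element
        return prod(self.shape) * self.elem_bytes


def estimate_fixed_buffers(nnodes, max_tokens, hidden, topk,
                           token_bytes=2, index_bytes=4):
    """Size the fixed-shape symmetric buffers listed in the tree (hypothetical helper).

    Assumptions: token_bytes=2 (bfloat16 tokens), index_bytes=4 (int32 indices).
    The resizable dispatch_output_buf, weight and signal buffers are left out.
    """
    return [
        BufferEstimate("token_send_buf_rdma", (nnodes, max_tokens, hidden), token_bytes),
        BufferEstimate("topk_indices_buf_rdma", (nnodes, max_tokens, topk), index_bytes),
    ]


# 32 ranks with 8 per node -> 4 nodes; other values from the Basic Usage example
for buf in estimate_fixed_buffers(nnodes=4, max_tokens=256, hidden=7168, topk=8):
    print(f"{buf.name}: {buf.shape} -> {buf.nbytes / 2**20:.2f} MiB")
```

With the Basic Usage configuration, token_send_buf_rdma alone comes to 14 MiB per rank under these assumptions, and it grows linearly with the number of nodes.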

Workflow

┌─────────────────────────────────────────────────────────────────┐
│                        dispatch() Flow                         │
├─────────────────────────────────────────────────────────────────┤
│  1. Copy input to symmetric buffer                             │
│  2. preprocess() - Compute routing metadata                    │
│     ├── bincount(expert_indices)                              │
│     ├── get_ag_splits_and_recv_offset_for_dispatch()          │
│     └── [inter-node] get_dispatch_send_reqs()                 │
│  3. init_output_buffer() - Poll CPU for buffer size           │
│  4. dispatch_token() - Execute dispatch kernel                │
│     ├── [inter-node] ep_dispatch_token_inplace()              │
│     └── [intra-node] kernel_dispatch_token_intra_node()       │
│  5. dispatch_postprocess() - Reset buffers                    │
│  6. Return (output, weights, layout_desc)                     │
└─────────────────────────────────────────────────────────────────┘
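Step 2 is mostly bookkeeping over (token, expert) pairs. The pure-Python sketch below mirrors that bookkeeping for a single rank: a bincount over the flattened expert indices, per-destination-rank split sizes, and an exclusive prefix sum giving where each destination's slice starts. `dispatch_metadata` is a hypothetical helper, not the library's preprocess(); it assumes rank r owns experts [r * E_r, (r + 1) * E_r) with E_r = num_experts // world_size.

```python
from collections import Counter
from itertools import accumulate


def dispatch_metadata(exp_indices, num_experts, world_size):
    """Hypothetical host-side model of dispatch routing metadata for one rank.

    exp_indices: list of per-token top-k expert ids on this rank.
    Assumes contiguous expert ownership: E_r = num_experts // world_size per rank.
    """
    experts_per_rank = num_experts // world_size
    # One entry per (token, k) pair, like flattening exp_indices before bincount
    counts = Counter(e for row in exp_indices for e in row)
    expert_counts = [counts[e] for e in range(num_experts)]  # bincount, minlength=num_experts
    # Number of (token, k) pairs sent to each destination rank
    rank_splits = [
        sum(expert_counts[r * experts_per_rank:(r + 1) * experts_per_rank])
        for r in range(world_size)
    ]
    # Exclusive prefix sum: start of each destination's slice in a send buffer
    send_offsets = [0, *accumulate(rank_splits)][:-1]
    return expert_counts, rank_splits, send_offsets


print(dispatch_metadata([[0, 5], [3, 4], [7, 1]], num_experts=8, world_size=4))
# ([1, 1, 0, 1, 1, 1, 0, 1], [2, 1, 2, 1], [0, 2, 3, 5])
```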

┌─────────────────────────────────────────────────────────────────┐
│                        combine() Flow                          │
├─────────────────────────────────────────────────────────────────┤
│  1. Copy expert output to symmetric buffer                     │
│  2. combine_token_intra_node_and_send()                       │
│     ├── [intra-node] kernel_combine_token_intra_node()        │
│     └── [inter-node] ep_combine_token_inplace()               │
│  3. [inter-node] Sum across nodes                             │
│  4. Return combined output                                     │
└─────────────────────────────────────────────────────────────────┘
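Conceptually, combine() reverses dispatch: each token's top-k expert outputs are gathered from the rows they were scattered to and summed. The single-process sketch below models that reduction with plain lists. `combine_reference` is a hypothetical name, scatter_idx plays the role of token_dst_scatter_idx for a single node, and routing weights are deliberately not applied, since whether combine() applies them is not specified here.

```python
def combine_reference(expert_output, scatter_idx):
    """Sum the expert outputs that belong to each original token.

    expert_output: rows of the (simulated) receive buffer, each a list of floats.
    scatter_idx[t][k]: row that (token t, k-th expert) was written to by dispatch.
    Hypothetical single-process model; routing weights are not applied.
    """
    hidden = len(expert_output[0])
    combined = []
    for slots in scatter_idx:
        acc = [0.0] * hidden
        for row in slots:
            # Element-wise accumulate this expert's output for the token
            acc = [a + x for a, x in zip(acc, expert_output[row])]
        combined.append(acc)
    return combined


rows = [[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]]
print(combine_reference(rows, [[0, 2], [1, 0]]))  # [[101.0, 202.0], [11.0, 22.0]]
```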

API Reference

EPAll2AllLayer

class EPAll2AllLayer(ep_group, max_tokens, hidden, topk, rank, num_tot_experts, local_world_size, world_size, dtype=torch.bfloat16, weight_dtype=torch.float32, num_sm=20, enable_local_combine=False, use_aot=False)

High-level layer for EP All-to-All communication.

Parameters:
  • ep_group – PyTorch distributed process group for EP

  • max_tokens – Maximum number of tokens per rank

  • hidden – Hidden dimension size

  • topk – Number of experts selected per token

  • rank – Current rank

  • num_tot_experts – Total number of experts across all ranks

  • local_world_size – Number of ranks per node (typically 8)

  • world_size – Total number of ranks

  • dtype – Token data type (default: torch.bfloat16)

  • weight_dtype – Routing weight data type (default: torch.float32)

  • num_sm – Number of SMs to use for kernels (default: 20)

  • enable_local_combine – Enable intra-node local combine optimization (default: False)

  • use_aot – Use AOT-compiled kernels (default: False)

dispatch(input, exp_indices, weight=None, full_scatter_indices=None)

Dispatch tokens to their assigned experts.

Parameters:
  • input – [num_tokens, hidden] input tokens

  • exp_indices – [num_tokens, topk] expert indices from the Top-K gate

  • weight – [num_tokens, topk] optional routing weights (default: None)

  • full_scatter_indices – [num_tokens, topk] optional precomputed scatter indices (default: None)

Returns:

Tuple of (output, weights, layout_desc)

  • output – [recv_tokens, hidden] received tokens

  • weights – [recv_tokens] received routing weights, or None when weight is not passed

  • layout_desc – EPAllToAllLayoutDesc carrying the routing metadata required by combine()
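Assuming dispatch produces one output row per (token, expert) pair that lands on this rank's experts, recv_tokens depends on routing decisions made on every rank, which is consistent with the output buffer being resized at run time. The sketch below counts those rows for one destination rank; `recv_tokens_for_rank` is a hypothetical helper that assumes contiguous expert ownership and no local-combine deduplication.

```python
def recv_tokens_for_rank(all_exp_indices, dst_rank, experts_per_rank):
    """Hypothetical count of dispatch output rows on `dst_rank`.

    all_exp_indices[src][t][k] is the k-th expert chosen for token t on rank src.
    Assumes contiguous expert ownership and one row per (token, expert) pair,
    i.e. no local-combine deduplication.
    """
    lo, hi = dst_rank * experts_per_rank, (dst_rank + 1) * experts_per_rank
    return sum(
        1
        for per_rank in all_exp_indices
        for row in per_rank
        for e in row
        if lo <= e < hi
    )


# Two ranks, two experts each (4 experts total), topk = 2
routing = [
    [[0, 3], [1, 2]],  # rank 0 dispatches two tokens
    [[2, 3]],          # rank 1 dispatches one token
]
print([recv_tokens_for_rank(routing, r, experts_per_rank=2) for r in range(2)])  # [2, 4]
```

Summed over all destinations the counts equal the total number of (token, k) pairs, here 3 tokens × topk 2 = 6.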

combine(input, ep_a2a_layout_desc)

Combine expert outputs back to original token positions.

Parameters:
  • input – [recv_tokens, hidden] expert output tokens

  • ep_a2a_layout_desc – layout descriptor returned by dispatch()

Returns:

[num_dispatch_tokens, hidden] combined output, where num_dispatch_tokens is the number of tokens this rank passed to dispatch()

finalize()

Release NVSHMEM symmetric memory buffers.

ep_barrier_all(stream, intra_node_only=False)

Synchronize all ranks in the EP group, or only the ranks on the local node when intra_node_only=True.

Parameters:
  • stream – CUDA stream

  • intra_node_only – Only synchronize within node (default: False)

EPAllToAllLayoutDesc

class EPAllToAllLayoutDesc

Descriptor containing routing metadata from dispatch, needed for combine.

num_dispatch_token_cur_rank: int

Number of tokens dispatched by this rank.

num_input_tokens_per_rank: torch.Tensor

[world_size] - Number of tokens dispatched by each rank.

send_reqs_recv_tensor: torch.Tensor | None

[nnodes, 2, max_tokens] - Received send requests (inter-node only).

topk_indices_tensor: torch.Tensor

[nnodes, max_tokens, topk] or [max_tokens, topk] - Expert indices.

token_dst_scatter_idx: torch.Tensor

[nnodes, max_tokens, topk] - Scatter indices in output buffer.

skipped_token_mapping_indices: torch.Tensor | None

Intra-node optimization: maps skipped tokens to first occurrence.

skipped_token_topk_mapping_indices: torch.Tensor | None

Intra-node optimization: per-TopK mapping for skipped tokens.

EPConfig

class EPConfig

Configuration for EP All-to-All.

max_tokens: int
hidden: int
topk: int
num_experts: int
rank: int
world_size: int
local_world_size: int
token_dtype: torch.dtype
weight_dtype: torch.dtype
offset_dtype: torch.dtype
property num_experts_per_rank: int
property is_intra_node: bool
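The two properties can be read as simple derivations from the fields. The dataclass below is a hypothetical stand-in, not EPConfig itself; it assumes experts are split evenly across ranks and treats the setup as intra-node when all ranks share one node (world_size == local_world_size), in line with the note that kernel selection depends on world_size and local_world_size.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EPConfigSketch:
    """Hypothetical stand-in for a subset of EPConfig (not the real class)."""

    num_experts: int
    world_size: int
    local_world_size: int

    def __post_init__(self):
        # Assumption: experts split evenly over ranks, ranks evenly over nodes
        if self.num_experts % self.world_size:
            raise ValueError("num_experts must be divisible by world_size")
        if self.world_size % self.local_world_size:
            raise ValueError("world_size must be divisible by local_world_size")

    @property
    def num_experts_per_rank(self) -> int:
        return self.num_experts // self.world_size

    @property
    def nnodes(self) -> int:
        return self.world_size // self.local_world_size

    @property
    def is_intra_node(self) -> bool:
        # One node: the intra-node kernel path would be selected
        return self.nnodes == 1


cfg = EPConfigSketch(num_experts=256, world_size=32, local_world_size=8)
print(cfg.num_experts_per_rank, cfg.nnodes, cfg.is_intra_node)  # 8 4 False
```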

DispatchCombineContext

class DispatchCombineContext

Manages symmetric memory buffers for dispatch and combine operations.

classmethod create(ep_config, capacity=2) → DispatchCombineContext

Create context with NVSHMEM symmetric buffers.

Parameters:
  • ep_config – EPConfig instance

  • capacity – Multiplier for output buffer capacity (default: 2)

finalize()

Release all symmetric memory buffers.

reallocate_dispatch_output_buf(dispatch_recv_tokens)

Resize the dispatch output buffer if its current capacity is smaller than dispatch_recv_tokens.

Parameters:
  • dispatch_recv_tokens – Required capacity, in received tokens

Returns:

Tuple of (dispatch_output_buf, weight_recv_buf)
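The "resize if needed" behaviour amounts to a grow-only buffer: nothing happens while the request fits, and the buffer grows with some headroom once it does not. The class below models only the row count, not the NVSHMEM allocation; `GrowOnlyBuffer` and its growth rule (request multiplied by a growth factor, loosely modelled on the capacity multiplier) are assumptions for illustration, not the library's policy.

```python
import math


class GrowOnlyBuffer:
    """Hypothetical model of a grow-only receive buffer (row capacity only).

    Assumption: when a request exceeds the current capacity, the new capacity
    is the request scaled by `growth`, so small fluctuations in the received
    token count do not trigger another reallocation.
    """

    def __init__(self, initial_rows: int, growth: float = 2.0):
        self.rows = initial_rows
        self.growth = growth
        self.reallocations = 0

    def ensure(self, required_rows: int) -> int:
        # Only grow; never shrink when fewer rows are needed
        if required_rows > self.rows:
            self.rows = math.ceil(required_rows * self.growth)
            self.reallocations += 1
        return self.rows


buf = GrowOnlyBuffer(initial_rows=100, growth=2.0)
print([buf.ensure(n) for n in (50, 150, 250)], buf.reallocations)  # [100, 300, 300] 1
```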

Usage Example

Basic Usage

import torch
import torch.distributed as dist
from triton_dist.utils import initialize_distributed
from triton_dist.layers.nvidia import EPAll2AllLayer

# Initialize the distributed runtime, then query rank and world size
initialize_distributed()
rank = dist.get_rank()
world_size = dist.get_world_size()

# Create EP group (all ranks)
ep_group = dist.group.WORLD

# Create layer
ep_layer = EPAll2AllLayer(
    ep_group=ep_group,
    max_tokens=256,
    hidden=7168,
    topk=8,
    rank=rank,
    num_tot_experts=256,      # 256 experts / 32 ranks = 8 per rank
    local_world_size=8,       # ranks (GPUs) per node
    world_size=world_size,    # 32 in this example: 4 nodes x 8 GPUs
    dtype=torch.bfloat16,
    weight_dtype=torch.float32,
    num_sm=20,
)

# Simulate input
tokens = torch.randn(128, 7168, dtype=torch.bfloat16, device="cuda")
# Distinct experts per token, as a real Top-K gate would produce
expert_ids = torch.randn(128, 256, device="cuda").topk(8, dim=-1).indices.to(torch.int32)
routing_weights = torch.softmax(torch.randn(128, 8, device="cuda"), dim=-1).float()

# Dispatch
recv_tokens, recv_weights, layout_desc = ep_layer.dispatch(
    input=tokens,
    exp_indices=expert_ids,
    weight=routing_weights,
)

# Expert FFN computation (placeholder: substitute your own expert FFN,
# which must return a tensor of shape [recv_count, hidden])
expert_output = recv_tokens.clone()

# Combine
combined = ep_layer.combine(expert_output, layout_desc)
# combined: [128, 7168] - back to original token positions

# Cleanup
ep_layer.finalize()

With AOT Compilation

# Enable AOT for production
ep_layer = EPAll2AllLayer(
    ...,
    use_aot=True,  # Use pre-compiled kernels
)

With Local Combine Optimization

# For intra-node scenarios with tokens routed to multiple same-rank experts
ep_layer = EPAll2AllLayer(
    ...,
    world_size=8,             # Single node
    local_world_size=8,
    enable_local_combine=True, # Pre-aggregate locally
)
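Whether enable_local_combine pays off depends on how often a token's top-k experts share a destination rank. The helper below is a hypothetical diagnostic you could run on recorded gate outputs: it measures the fraction of (token, k) pairs whose destination rank already appeared earlier in the same token's list. Treating those pairs as the ones local combine can pre-aggregate is an assumption based on the comment in the example above.

```python
def same_rank_duplicate_fraction(exp_indices, experts_per_rank):
    """Fraction of (token, k) pairs whose destination rank repeats within a token.

    Hypothetical diagnostic; assumes contiguous expert ownership
    (expert e lives on rank e // experts_per_rank).
    """
    total = dup = 0
    for row in exp_indices:
        seen = set()
        for e in row:
            dst = e // experts_per_rank
            total += 1
            if dst in seen:
                dup += 1  # same token already routed to this rank
            else:
                seen.add(dst)
    return dup / total if total else 0.0


print(same_rank_duplicate_fraction([[0, 1, 5, 9], [2, 6, 10, 14]], experts_per_rank=4))  # 0.125
```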

Performance Tips

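  1. Buffer Sizing: Set max_tokens to the largest number of tokens any rank passes to dispatch(); the fixed-shape send and index buffers are sized from it, and an accurate value helps avoid reallocation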

  2. AOT Compilation: Enable use_aot=True in production for faster startup

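  3. SM Count: num_sm sets how many SMs the dispatch and combine kernels use; more SMs can speed up communication but leave fewer for compute running concurrently, so tune it against your model's workload

  4. Local Combine: Enable for intra-node runs when TopK > 1 and a token's top-k experts often reside on the same rank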
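For tip 3, a quick budget calculation helps: the SMs given to the communication kernels are unavailable to compute that overlaps with them (an assumption that holds when both run concurrently). `sm_budget` is a hypothetical helper; the device's total SM count can be read from torch.cuda.get_device_properties(0).multi_processor_count, for example 132 on an H100 SXM GPU.

```python
from __future__ import annotations


def sm_budget(total_sms: int, num_sm: int) -> tuple[int, float]:
    """Return (SMs left for compute, fraction of SMs given to communication).

    Hypothetical helper; assumes communication and compute kernels run
    concurrently and do not share SMs.
    """
    if not 0 < num_sm <= total_sms:
        raise ValueError("num_sm must be in the range (0, total_sms]")
    return total_sms - num_sm, num_sm / total_sms


left, share = sm_budget(total_sms=132, num_sm=20)  # 132 SMs, e.g. H100 SXM
print(left, f"{share:.1%}")  # 112 15.2%
```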

Note

The layer automatically selects between intra-node and inter-node kernels based on world_size and local_world_size.

See Also