Expert Parallelism All-to-All Fused Megakernel
Expert Parallelism (EP) All-to-All fused megakernel for single-node MoE models. This kernel implements computation-communication fusion for dispatch+groupgemm and groupgemm+combine operations.
Overview
The fused megakernel combines multiple operations into a single kernel launch to reduce kernel launch overhead, improve SM utilization, and enable fine-grained task scheduling. It is specifically optimized for single-node 8-GPU EP MoE scenarios.
Key Features
Megakernel Architecture: Each fused stage (dispatch+groupgemm and groupgemm+combine) runs as a single kernel launch
Task-based Scheduling: Dynamic task queue with atomic counter for load balancing
Token Saving/Skipping: Optimizes communication by avoiding redundant token transfers
Token Sorting: Reorders tokens to improve memory access patterns and enable early dispatch completion
SM Scheduling: Fine-grained SM-level task distribution for optimal GPU utilization
Architecture
┌─────────────────────────────────────────────────────────────┐
│                   Mega Kernel Task Queue                    │
│                                                             │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐       │
│  │  Dispatch    │  │  GroupGEMM   │  │  Combine     │       │
│  │  Tasks       │  │  Tasks       │  │  Tasks       │       │
│  └──────┬───────┘  └──────┬───────┘  └──────┬───────┘       │
│         │                 │                 │               │
│         └─────────────────┼─────────────────┘               │
│                           │                                 │
│              ┌────────────▼────────────┐                    │
│              │   Atomic Task Counter   │                    │
│              │   (Shared across SMs)   │                    │
│              └────────────┬────────────┘                    │
│                           │                                 │
│         ┌─────────────────┼─────────────────┐               │
│         │                 │                 │               │
│    ┌────▼────┐       ┌────▼────┐       ┌────▼────┐          │
│    │  SM 0   │       │  SM 1   │  ...  │  SM N   │          │
│    │         │       │         │       │         │          │
│    │ Fetch   │       │ Fetch   │       │ Fetch   │          │
│    │ Execute │       │ Execute │       │ Execute │          │
│    └─────────┘       └─────────┘       └─────────┘          │
└─────────────────────────────────────────────────────────────┘
Megakernel Functions
Dispatch + GroupGEMM Fusion
- mega_kernel_dispatch_token_moe_grouped_gemm(...)
Fused kernel that performs token dispatch and groupgemm computation in a single launch.
Key Components:
Task Counter: Atomic counter shared across all SMs for dynamic task distribution
Dispatch Tasks: Token routing and communication operations
GroupGEMM Tasks: Matrix multiplication for expert computation
Workflow:
# Fetch the first task, then keep fetching until the queue is drained
task_id = atomic_add(task_counter_ptr, 1)
while task_id < total_tasks:
    if task_id < num_dispatch_tasks:
        # Execute dispatch token operation
        tile_kernel_dispatch_token_intra_node(...)
    else:
        # Execute groupgemm operation
        tile_kernel_moe_grouped_gemm_nk_const(...)
    task_id = atomic_add(task_counter_ptr, 1)
Two-Stage Dispatch (when NUM_TAIL_SMS > 0):
Stage 1: Main dispatch SMs handle token routing
Stage 2: Tail SMs handle local copy and barrier notification
Enables overlap between dispatch completion and groupgemm start
GroupGEMM + Combine Fusion
- mega_kernel_moe_grouped_gemm_combine_token(...)
Fused kernel that performs groupgemm computation and token combine in a single launch.
Two Modes:
Serial Mode (USE_SCATTER_MODE=False):
Execute all groupgemm tasks first
Synchronize with barrier_all
Execute combine tasks
Fuse Scatter Mode (USE_SCATTER_MODE=True):
Interleave scatter, groupgemm, and reduce tasks
Fine-grained barrier synchronization per token block
Enables better overlap
Workflow (Fuse Scatter Mode):
# Fetch the first task, then keep fetching until the queue is drained
task_id = atomic_add(task_counter_ptr, 1)
while task_id < total_tasks:
    if task_id < num_combine_tasks:
        # Scatter tokens to output buffer
        tile_kernel_scatter_token_intra_node(...)
    elif task_id < num_combine_tasks + group_gemm_tasks:
        # Execute groupgemm with notification
        tile_kernel_moe_grouped_gemm_nk_const(..., NEED_NOTIFY=True)
    else:
        # TopK reduce
        tile_kernel_topk_reduce_token_intra_node(...)
    task_id = atomic_add(task_counter_ptr, 1)
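The way one task counter covers three task types can be illustrated on the host in plain Python. This is a minimal sketch, not the kernel: `classify_task` is a hypothetical helper, and it assumes `total_tasks` is the sum of the scatter, groupgemm, and reduce task counts, in the range order shown in the workflow above.

```python
def classify_task(task_id, num_combine_tasks, group_gemm_tasks, num_reduce_tasks):
    """Map a global task id to (kind, local index), mirroring the workflow above.

    Hypothetical helper for illustration only; the real kernel branches inline.
    """
    total = num_combine_tasks + group_gemm_tasks + num_reduce_tasks
    if not 0 <= task_id < total:
        raise ValueError(f"task_id {task_id} outside [0, {total})")
    if task_id < num_combine_tasks:
        return "scatter", task_id
    if task_id < num_combine_tasks + group_gemm_tasks:
        return "groupgemm", task_id - num_combine_tasks
    return "reduce", task_id - num_combine_tasks - group_gemm_tasks


# Example: 2 scatter tasks, 3 groupgemm tiles, 2 reduce tasks.
schedule = [classify_task(t, 2, 3, 2) for t in range(7)]
```

Because lower task ids are handed out first, scatter tasks are picked up before groupgemm tiles, and reduce tasks last.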
Core Optimizations
Token Saving/Skipping
Problem: When a token is routed to multiple experts on the same rank, sending it once per expert transfers the same data to that rank several times. A single transfer is enough if the receiving rank reuses the data.
Solution: Two-stage dispatch with token rank table and indirect position tracking.
Implementation (tile_kernel_dispatch_token_intra_node_two_stage):
# Stage 1: Main dispatch SMs
for send_token_offset in range(global_warp_id, token_num * topk, total_warps):
    sort_token_offset = ld(token_sort_indices + send_token_offset)
    if sort_token_offset >= 0:  # ignore dropped tokens
        token_offset = sort_token_offset // topk
        expert_idx = ld(topk_indices_tensor + sort_token_offset)
        expert_rank = expert_idx // experts_per_rank
        # Check if token already sent to this rank
        has_sent = ld(token_rank_table_buf + token_offset * world_size + expert_rank)
        if has_sent < 0:
            # First time sending to this rank
            has_sent = store_idx
            libshmem_device.putmem_warp(dst_ptr, src_ptr, bytes_per_token, expert_rank)
            st(token_rank_table_buf + token_offset * world_size + expert_rank, store_idx)
        # Store indirect position for later lookup
        remote_token_indirect_pos = dl.symm_at(token_indirect_pos_buf, expert_rank)
        st(remote_token_indirect_pos + store_idx, has_sent, scope="sys", semantic="release")
# Stage 2: Tail SMs copy locally and notify
for tile_id in range(pid - num_pid, gemm_total_tiles_m, num_tail_sms):
    # Wait for indirect position
    has_sent = ld_acquire(token_indirect_pos_buf + real_offset, scope="sys")
    while has_sent < 0:
        has_sent = ld_acquire(token_indirect_pos_buf + real_offset, scope="sys")
    # Copy from original position if different
    copy_warp(dispatch_output_local + real_offset * hidden_size,
              output_buf + has_sent * hidden_size, bytes_per_token)
Benefits:
Reduces communication volume by ~30% for tokens with multiple same-rank experts
Enables early groupgemm start while dispatch completes
Maintains correctness through indirect position tracking
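The effect of the token rank table on transfer count can be modeled in plain Python. The sketch below is a host-side illustration under simplifying assumptions (one process, no real communication); `count_transfers` is a hypothetical helper and its counter only stands in for `store_idx`.

```python
def count_transfers(topk_indices, experts_per_rank, world_size):
    """Count token transfers with and without same-rank deduplication.

    topk_indices[t] lists the expert ids chosen for token t; a negative id
    marks a dropped slot. Pure-Python model of the token rank table idea.
    """
    naive = 0
    saved = 0
    for experts in topk_indices:
        # -1 means "not yet sent to this rank", as in token_rank_table_buf.
        rank_table = [-1] * world_size
        for expert_idx in experts:
            if expert_idx < 0:
                continue
            naive += 1  # one transfer per (token, expert) pair
            expert_rank = expert_idx // experts_per_rank
            if rank_table[expert_rank] < 0:
                rank_table[expert_rank] = saved  # stand-in for store_idx
                saved += 1  # first (and only) transfer to this rank


    return naive, saved


# 3 tokens, topk=2, 2 ranks with 2 experts each (experts 0-1 live on rank 0).
naive, saved = count_transfers([[0, 1], [0, 2], [3, 2]],
                               experts_per_rank=2, world_size=2)
```

In this toy routing, two of the three tokens pick both experts on the same rank, so six logical copies shrink to four transfers. The actual saving depends on the routing distribution.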
Token Sorting
Problem: Tokens arrive in arbitrary order, causing poor memory access patterns and delayed dispatch completion.
Solution: Pre-sort tokens by expert rank and expert index to improve locality.
Implementation (get_ag_splits_and_recv_offset_for_dispatch):
The preprocessing kernel generates token_sort_indices that reorders tokens:
# Sort tokens by: (expert_rank, expert_idx_intra_rank, token_idx)
# This ensures:
# 1. Tokens going to same expert are contiguous
# 2. Memory access patterns are sequential
# 3. Early completion signals can be sent per expert
for token_idx in range(...):
    expert_idx = ld(topk_indices_buf + token_idx)
    expert_rank = expert_idx // experts_per_rank
    expert_idx_intra_rank = expert_idx % experts_per_rank
    # Compute sort key
    sort_key = (expert_rank * experts_per_rank + expert_idx_intra_rank) * max_tokens + token_idx
    st(token_sort_indices + sort_key, token_idx)
Dispatch Loop (tile_kernel_dispatch_token_intra_node):
# Iterate through sorted token offsets
for send_token_offset in range(global_warp_id, token_num * topk, total_warps):
    sort_token_offset = ld(token_sort_indices + send_token_offset)
    if sort_token_offset >= 0:  # ignore dropped tokens
        token_offset = sort_token_offset // topk
        expert_idx = ld(topk_indices_tensor + sort_token_offset)
        # ... dispatch logic ...
Benefits:
Sequential Memory Access: Tokens for same expert are processed together
Early Completion: Can signal completion per expert as soon as all tokens arrive
Better Cache Utilization: Improved L2 cache hit rate
Reduced Barrier Overhead: Per-expert barriers instead of global barriers
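The sort order described above can be reproduced on the host with an ordinary sort. This is a minimal sketch, not the preprocessing kernel: `build_token_sort_indices` is a hypothetical helper, and the dense output layout with dropped slots moved to the tail as -1 is an assumption made for illustration.

```python
def build_token_sort_indices(topk_indices, experts_per_rank):
    """Return flat (token, k) offsets ordered by (expert_rank, expert, token).

    topk_indices is a flat list of length token_num * topk; negative entries
    are dropped slots and are emitted as -1 at the tail so a dispatch loop
    can skip them. The dense layout is an assumption of this sketch.
    """
    valid = [i for i, e in enumerate(topk_indices) if e >= 0]
    # Key mirrors the documented order: expert rank, expert index within the
    # rank, then flat offset (which increases with token index).
    valid.sort(key=lambda i: (topk_indices[i] // experts_per_rank,
                              topk_indices[i] % experts_per_rank,
                              i))
    return valid + [-1] * (len(topk_indices) - len(valid))


# 3 tokens, topk=2; the slot holding -1 is a dropped token.
topk_indices = [2, 0, 1, -1, 0, 3]
token_sort_indices = build_token_sort_indices(topk_indices, experts_per_rank=2)
```

Reading `token_sort_indices` front to back visits every slot for expert 0 first, then expert 1, and so on, which is what lets a per-expert completion signal fire early.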
SM Scheduling
Problem: Static task assignment leads to load imbalance and poor SM utilization.
Solution: Dynamic task queue driven by a shared atomic counter, so each SM pulls its next task as soon as it finishes the previous one.
Task Distribution:
# Each SM fetches tasks dynamically
task_id = tl.atomic_add(task_counter_ptr, 1)
while task_id < total_tasks:
    if task_id < num_dispatch_tasks:
        # Dispatch task
        execute_dispatch_task(task_id)
    elif task_id < num_dispatch_tasks + group_gemm_tasks:
        # GroupGEMM task
        execute_groupgemm_task(task_id - num_dispatch_tasks)
    else:
        # Combine task
        execute_combine_task(task_id - num_dispatch_tasks - group_gemm_tasks)
    # Fetch next task
    task_id = tl.atomic_add(task_counter_ptr, 1)
SM Allocation Strategies:
Dispatch SMs (NUM_DISPATCH_SM): Dedicated SMs for token dispatch
- Typically 20-40 SMs depending on token count
- Handles communication-intensive operations
Tail SMs (NUM_TAIL_SMS): SMs for two-stage dispatch completion
- Typically 4-8 SMs
- Handles local copy and barrier notification
- Enables overlap with groupgemm
GroupGEMM SMs: Remaining SMs for computation
- Automatically distributed via task queue
- Load balanced across all experts
Benefits:
Load Balancing: Fast SMs automatically take more tasks
Overlap: Dispatch and groupgemm can overlap naturally
Flexibility: Can adjust SM allocation based on workload characteristics
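The fetch-execute loop can be tried out on the CPU with threads standing in for SMs. The sketch below is an analogy only: `AtomicCounter` and `run_workers` are hypothetical names, and a lock replaces the hardware atomic.

```python
import threading


class AtomicCounter:
    """Lock-based stand-in for the device-side atomic task counter."""

    def __init__(self):
        self._value = 0
        self._lock = threading.Lock()

    def fetch_add(self, amount=1):
        # Returns the value before the increment, like atomic_add.
        with self._lock:
            old = self._value
            self._value += amount
            return old


def run_workers(total_tasks, num_workers):
    """Each worker (a stand-in for an SM) pulls task ids until none remain."""
    counter = AtomicCounter()
    executed = [[] for _ in range(num_workers)]

    def worker(worker_id):
        task_id = counter.fetch_add()
        while task_id < total_tasks:
            executed[worker_id].append(task_id)  # "execute" the task
            task_id = counter.fetch_add()

    threads = [threading.Thread(target=worker, args=(w,)) for w in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return executed


executed = run_workers(total_tasks=100, num_workers=4)
all_tasks = sorted(task for per_worker in executed for task in per_worker)
```

Every task id is claimed exactly once regardless of how the threads interleave, and a worker that finishes its tasks quickly simply claims more of them.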
Barrier Synchronization
Per-Expert Barriers (Default):
# Signal completion when all tokens for an expert arrive
sent_tokens = atomic_add(counter_ptr + expert_idx, 1)
if sent_tokens == tokens_this_expert - 1:
    libshmem_device.fence()
    libshmem_device.signal_op(
        barriers_ptr + expert_idx_intra_rank * world_size + rank,
        1, libshmem_device.NVSHMEM_SIGNAL_SET, expert_rank)
Block-wise Barriers (USE_BLOCK_WISE_BARRIER=True):
# Per-tile barrier for finer granularity
# Enables groupgemm to start as soon as a tile is ready
barrier_idx = local_pid_m + tile_begin
while ld_acquire(barriers_ptr + barrier_idx, scope="gpu") != 1:
    pass
Per-Token Block Barriers (Combine phase):
# Barrier per hidden_size / BLOCK_SIZE_N chunk
barrier_n_idx = elem_idx * VEC_SIZE // BARRIER_TOKEN_BLOCK_SIZE
barrier_idx = token_scatter_idx * N_BARRIERS_PER_TOKEN + barrier_n_idx
token = ld_acquire(remote_barriers_ptr + barrier_idx, scope="sys")
while token != 1:
    token = ld_acquire(remote_barriers_ptr + barrier_idx, scope="sys")
Benefits:
Finer Granularity: Reduces wait time by enabling partial execution
Better Overlap: GroupGEMM can start processing as soon as data is ready
Reduced Synchronization Overhead: Smaller barrier scopes
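The per-expert completion check (`sent_tokens == tokens_this_expert - 1`) can be traced with a small sequential model. This is a sketch under simplifying assumptions: a single sender, no concurrency, and a hypothetical helper `dispatch_with_signals` that records where the signal would fire instead of calling NVSHMEM.

```python
def dispatch_with_signals(token_experts, tokens_per_expert):
    """Model the per-expert completion check from the barrier snippet above.

    token_experts: expert id for each dispatched token, in send order.
    tokens_per_expert: expected token count per expert (like local_splits_buf).
    Returns (send position, expert) pairs at which the signal would fire.
    """
    counters = [0] * len(tokens_per_expert)
    signals = []
    for position, expert_idx in enumerate(token_experts):
        sent_tokens = counters[expert_idx]  # value before the increment
        counters[expert_idx] += 1
        if sent_tokens == tokens_per_expert[expert_idx] - 1:
            signals.append((position, expert_idx))  # last token: signal once
    return signals


# Sorted order finishes expert 0 after two sends; interleaved order delays it.
sorted_signals = dispatch_with_signals([0, 0, 1, 1], [2, 2])
interleaved_signals = dispatch_with_signals([0, 1, 1, 0], [2, 2])
```

With sorted sends, expert 0 is signaled at position 1, so its groupgemm tiles could start while expert 1 is still being dispatched. With interleaved sends, the first signal only arrives at position 2.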
Kernel Implementation Details
Dispatch Tile Kernel
@triton_dist.jit(do_not_specialize=["pid", "num_pid"])
def tile_kernel_dispatch_token_intra_node(
    pid, num_pid,
    counter_ptr, barriers_ptr,
    recv_buf_offset_per_expert,
    local_splits_buf,
    input_buf, output_buf,
    weight_send_buf, weight_recv_buf,
    topk_indices_tensor,
    token_dst_scatter_idx,
    num_input_tokens_per_rank,
    token_sort_indices,
    topk: tl.constexpr,
    hidden_size: tl.constexpr,
    experts_per_rank: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    WITH_SCATTER_INDICES: tl.constexpr,
    num_warps: tl.constexpr,
    profiler: Profiler,
    ENABLE_PROFILING: tl.constexpr,
):
    WARP_SIZE = 32
    rank = dl.rank()
    world_size = dl.num_ranks()
    thread_idx = tid(0)
    lane_idx = thread_idx % WARP_SIZE
    warp_id = thread_idx // WARP_SIZE
    total_warps = num_warps * num_pid
    global_warp_id = pid * num_warps + warp_id
    token_num = tl.load(num_input_tokens_per_rank + rank)
    # Process tokens in sorted order
    for send_token_offset in range(global_warp_id, token_num * topk, total_warps):
        sort_token_offset = ld(token_sort_indices + send_token_offset)
        if sort_token_offset >= 0:  # ignore dropped tokens
            token_offset = sort_token_offset // topk
            expert_idx = ld(topk_indices_tensor + sort_token_offset)
            expert_rank = expert_idx // experts_per_rank
            expert_idx_intra_rank = expert_idx % experts_per_rank
            # Allocate destination slot
            if not WITH_SCATTER_INDICES:
                store_idx = atomic_add_per_warp(
                    recv_buf_offset_per_expert + expert_rank * experts_per_rank * world_size +
                    expert_idx_intra_rank * world_size + rank, 1,
                    scope="gpu", semantic="relaxed")
            else:
                store_idx = ld(token_dst_scatter_idx + sort_token_offset)
            # Transfer token data
            src_ptr = input_buf + token_offset * hidden_size
            dst_ptr = output_buf + store_idx.to(tl.int64) * hidden_size
            libshmem_device.putmem_warp(dst_ptr, src_ptr, bytes_per_token, expert_rank)
            # Transfer weight if needed
            if HAS_WEIGHT:
                libshmem_device.putmem_warp(
                    weight_recv_buf + store_idx,
                    weight_send_buf + sort_token_offset,
                    weight_elem_size, expert_rank)
            # Signal completion per expert
            sync_warp()
            if lane_idx == 0:
                tokens_this_expert = ld(local_splits_buf + expert_idx)
                sent_tokens = atomic_add(counter_ptr + expert_idx, 1,
                                         scope="gpu", semantic="relaxed")
                if sent_tokens == tokens_this_expert - 1:
                    libshmem_device.fence()
                    libshmem_device.signal_op(
                        barriers_ptr + expert_idx_intra_rank * world_size + rank,
                        1, libshmem_device.NVSHMEM_SIGNAL_SET, expert_rank)
GroupGEMM Tile Kernel
@triton_dist.jit(do_not_specialize=["pid", "num_pid", "M"])
def tile_kernel_moe_grouped_gemm_nk_const(
    pid, num_pid,
    counter_ptr, barriers_ptr,
    a_ptr, b_ptr, c_ptr,
    expert_ids_ptr,
    split_size_ptr, split_size_cum_ptr,
    tile_num_ptr, tile_num_cum_ptr,
    num_total_tiles_ptr,
    M, N: tl.constexpr, K: tl.constexpr,
    stride_am, stride_ak,
    stride_be, stride_bn, stride_bk,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    profiler: Profiler,
    NEED_WAIT: tl.constexpr,
    NEED_NOTIFY: tl.constexpr,
    USE_BLOCK_WISE_BARRIER: tl.constexpr,
    IS_DISPATCH_TWO_STAGET: tl.constexpr,
    ENABLE_PROFILING: tl.constexpr,
):
    num_block_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_block_n
    pid_n = pid % num_block_n
    expert_id = tl.load(expert_ids_ptr + pid_m)
    split_size = tl.load(split_size_ptr + expert_id)
    split_size_cum = tl.load(split_size_cum_ptr + pid_m)
    row_begin = split_size_cum
    # Wait for dispatch to complete (if needed)
    if NEED_WAIT:
        if IS_DISPATCH_TWO_STAGET:
            if USE_BLOCK_WISE_BARRIER:
                barrier_idx = local_pid_m + tile_begin
                while ld_acquire(barriers_ptr + barrier_idx, scope="gpu") != 1:
                    pass
            else:
                barrier_idx = expert_id
                while ld_acquire(barriers_ptr + barrier_idx, scope="gpu") != 1:
                    pass
        else:
            # Per-expert barrier
            barrier_idx = expert_id * world_size + thread_idx
            while ld_acquire(barriers_ptr + barrier_idx, scope="gpu") != 1:
                pass
    # Execute GEMM
    # ... GEMM computation ...
    # Notify combine phase (if needed)
    if NEED_NOTIFY:
        __syncthreads()
        token_begin = row_begin + local_pid_m * BLOCK_SIZE_M
        valid_tokens = min(row_remain, BLOCK_SIZE_M)
        if thread_idx < valid_tokens:
            st(barriers_ptr + (token_begin + thread_idx) * num_block_n + pid_n,
               1, scope="gpu", semantic="release")
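The tile bookkeeping in the kernel prologue can be worked through on the host. The sketch below is an assumption-laden illustration: `build_m_tiles` and `decompose_pid` are hypothetical helpers, and the tile table they produce only approximates what the kernel reads from its precomputed device buffers.

```python
def cdiv(a, b):
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def build_m_tiles(split_sizes, block_size_m):
    """List (expert_id, row_begin, valid_rows) for every M tile.

    Hypothetical host-side helper; the real kernel reads comparable
    metadata from precomputed device buffers.
    """
    tiles = []
    row_begin = 0
    for expert_id, split in enumerate(split_sizes):
        for local_pid_m in range(cdiv(split, block_size_m)):
            start = row_begin + local_pid_m * block_size_m
            valid_rows = min(split - local_pid_m * block_size_m, block_size_m)
            tiles.append((expert_id, start, valid_rows))
        row_begin += split
    return tiles


def decompose_pid(pid, n, block_size_n):
    """Split a flat tile id into (pid_m, pid_n) as in the kernel prologue."""
    num_block_n = cdiv(n, block_size_n)
    return pid // num_block_n, pid % num_block_n


# Three experts receiving 5, 0, and 3 tokens, with 4-row tiles.
tiles = build_m_tiles([5, 0, 3], block_size_m=4)
pid_m, pid_n = decompose_pid(pid=7, n=512, block_size_n=128)
```

An expert with no tokens contributes no tiles, and the last tile of each expert may be partial, which is why the notify step only writes barriers for the valid rows of a tile.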
Combine Tile Kernels
Gather Combine (tile_kernel_gather_combine_token_intra_node):
Reads expert outputs from remote ranks
Accumulates outputs for each token’s topk experts
Uses vectorized operations (128-bit loads/stores)
Scatter Combine (tile_kernel_scatter_token_intra_node):
Scatters expert outputs back to source ranks
Waits for per-token-block barriers
Transfers gate values if needed
TopK Reduce (tile_kernel_topk_reduce_token_intra_node):
Reduces scattered outputs for each token
Performs weighted sum across topk experts
Final output per original token
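The TopK reduce step amounts to a weighted sum over each token's expert outputs. A plain-Python model follows; `topk_reduce` is a hypothetical helper and makes no claim about the kernel's vectorization or memory layout.

```python
def topk_reduce(expert_outputs, weights):
    """Weighted sum over a token's topk expert outputs.

    expert_outputs: topk vectors of length hidden_size for one token.
    weights: topk gate values. Plain-Python model; the kernel is vectorized.
    """
    hidden_size = len(expert_outputs[0])
    result = [0.0] * hidden_size
    for output, weight in zip(expert_outputs, weights):
        for i in range(hidden_size):
            result[i] += weight * output[i]
    return result


# One token, topk=2, hidden_size=2.
reduced = topk_reduce([[1.0, 2.0], [3.0, 4.0]], weights=[0.5, 0.25])
```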
Performance Characteristics
Num Tokens Per Rank: 32k
Single-Node 8-GPU Configuration:
Communication: NVSHMEM symmetric memory for direct GPU-to-GPU access
Latency: ~3.5 ms dispatch, ~5.9 ms dispatch+groupgemm (depending on hidden size)
Throughput: algorithm bandwidth ~201 GB/s (token saving lets this exceed the hardware limit, because duplicate same-rank transfers are skipped)
Optimization Impact:
Token Saving: ~18%
Token Sorting: ~22%
Other Optimizations: ~8%
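How an algorithm bandwidth figure can exceed the link bandwidth is easiest to see with a small calculation. The sketch below assumes that algorithm bandwidth counts one logical copy per (token, expert) pair, which the text does not define explicitly. Both helper functions are hypothetical, and the inputs are illustrative values, not measurements.

```python
def algorithm_bandwidth_gbps(num_tokens, topk, hidden_size, bytes_per_elem, seconds):
    """Logical bytes (one copy per token-expert pair) divided by elapsed time."""
    logical_bytes = num_tokens * topk * hidden_size * bytes_per_elem
    return logical_bytes / seconds / 1e9


def wire_bandwidth_gbps(algo_gbps, saved_fraction):
    """Bandwidth actually carried by the links when a fraction of copies is skipped."""
    return algo_gbps * (1.0 - saved_fraction)


# Illustrative numbers only, not measurements from this document.
algo = algorithm_bandwidth_gbps(num_tokens=1000, topk=8, hidden_size=4096,
                                bytes_per_elem=2, seconds=0.001)
wire = wire_bandwidth_gbps(algo, saved_fraction=0.25)
```

The links only carry the `wire` figure, while the reported number is `algo`, so the gap between the two grows with the fraction of transfers that token saving removes.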
Usage Example
from triton_dist.kernels.nvidia.ep_all2all_fused import (
    mega_kernel_dispatch_token_moe_grouped_gemm,
    mega_kernel_moe_grouped_gemm_combine_token,
)
# Dispatch + GroupGEMM
mega_kernel_dispatch_token_moe_grouped_gemm[grid](
    task_counter_ptr,
    # ... dispatch params ...
    num_dispatch_tasks=NUM_DISPATCH_SM,
    # ... groupgemm params ...
    NUM_WARPS=16,
    NUM_TAIL_SMS=4,
    USE_BLOCK_WISE_BARRIER=True,
)
# GroupGEMM + Combine
mega_kernel_moe_grouped_gemm_combine_token[grid](
    task_counter_ptr,
    # ... groupgemm params ...
    # ... combine params ...
    num_combine_tasks=COMBINE_SM,
    num_reduce_tasks=REDUCE_SM,
    USE_SCATTER_MODE=True,
    NUM_WARPS=32,
)
See Also
Expert Parallelism All-to-All (Intra-Node) - Non-fused intra-node kernels
Expert Parallelism All-to-All Fused Layer - High-level fused layer API
Expert Parallelism All-to-All (EP A2A) - Inter-node EP All-to-All kernels