Expert Parallelism All-to-All Fused Megakernel

This document describes the Expert Parallelism (EP) All-to-All fused megakernel for single-node MoE models. The kernel fuses computation with communication for two operation pairs: dispatch+groupgemm and groupgemm+combine.

Overview

The fused megakernel combines multiple operations into a single kernel launch to reduce kernel launch overhead, improve SM utilization, and enable fine-grained task scheduling. It is specifically optimized for single-node 8-GPU EP MoE scenarios.

Key Features

  • Megakernel Architecture: Each fused pair (dispatch+groupgemm, groupgemm+combine) runs in a single kernel launch

  • Task-based Scheduling: Dynamic task queue with atomic counter for load balancing

  • Token Saving/Skipping: Optimizes communication by avoiding redundant token transfers

  • Token Sorting: Reorders tokens to improve memory access patterns and enable early dispatch completion

  • SM Scheduling: Fine-grained SM-level task distribution for optimal GPU utilization

Architecture

┌─────────────────────────────────────────────────────────────┐
│                   Mega Kernel Task Queue                    │
│                                                             │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐       │
│  │ Dispatch     │  │ GroupGEMM    │  │ Combine      │       │
│  │ Tasks        │  │ Tasks        │  │ Tasks        │       │
│  └──────┬───────┘  └──────┬───────┘  └──────┬───────┘       │
│         │                 │                 │               │
│         └─────────────────┼─────────────────┘               │
│                           │                                 │
│              ┌────────────▼────────────┐                    │
│              │  Atomic Task Counter    │                    │
│              │  (Shared across SMs)    │                    │
│              └────────────┬────────────┘                    │
│                           │                                 │
│         ┌─────────────────┼─────────────────┐               │
│         │                 │                 │               │
│    ┌────▼────┐       ┌────▼────┐       ┌────▼────┐          │
│    │  SM 0   │       │  SM 1   │  ...  │  SM N   │          │
│    │         │       │         │       │         │          │
│    │ Fetch   │       │ Fetch   │       │ Fetch   │          │
│    │ Execute │       │ Execute │       │ Execute │          │
│    └─────────┘       └─────────┘       └─────────┘          │
└─────────────────────────────────────────────────────────────┘

Megakernel Functions

Dispatch + GroupGEMM Fusion

mega_kernel_dispatch_token_moe_grouped_gemm(...)

Fused kernel that performs token dispatch and groupgemm computation in a single launch.

Key Components:

  1. Task Counter: Atomic counter shared across all SMs for dynamic task distribution

  2. Dispatch Tasks: Token routing and communication operations

  3. GroupGEMM Tasks: Matrix multiplication for expert computation

Workflow:

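# Claim the first task before entering the loop
task_id = atomic_add(task_counter_ptr, 1)
while task_id < total_tasks:
    if task_id < num_dispatch_tasks:
        # Execute dispatch token operation
        tile_kernel_dispatch_token_intra_node(...)
    else:
        # Execute groupgemm operation
        tile_kernel_moe_grouped_gemm_nk_const(...)
    # Claim the next task; the loop exits once the counter passes total_tasks
    task_id = atomic_add(task_counter_ptr, 1)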

Two-Stage Dispatch (when NUM_TAIL_SMS > 0):

  • Stage 1: Main dispatch SMs handle token routing

  • Stage 2: Tail SMs handle local copy and barrier notification

  • Enables overlap between dispatch completion and groupgemm start

GroupGEMM + Combine Fusion

mega_kernel_moe_grouped_gemm_combine_token(...)

Fused kernel that performs groupgemm computation and token combine in a single launch.

Two Modes:

  1. Serial Mode (USE_SCATTER_MODE=False):

    • Execute all groupgemm tasks first

    • Synchronize with barrier_all

    • Execute combine tasks

  2. Fuse Scatter Mode (USE_SCATTER_MODE=True):

    • Interleave scatter, groupgemm, and reduce tasks

    • Fine-grained barrier synchronization per token block

    • Enables better overlap

Workflow (Fuse Scatter Mode):

# Claim the first task before entering the loop
task_id = atomic_add(task_counter_ptr, 1)
while task_id < total_tasks:
    if task_id < num_combine_tasks:
        # Scatter tokens to output buffer
        tile_kernel_scatter_token_intra_node(...)
    elif task_id < num_combine_tasks + group_gemm_tasks:
        # Execute groupgemm with notification
        tile_kernel_moe_grouped_gemm_nk_const(..., NEED_NOTIFY=True)
    else:
        # TopK reduce
        tile_kernel_topk_reduce_token_intra_node(...)
    # Claim the next task; the loop exits once the counter passes total_tasks
    task_id = atomic_add(task_counter_ptr, 1)
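
The task-ID ranges in the fuse scatter workflow can be made concrete with a small host-side sketch. This is plain Python rather than kernel code. The helper `classify_task` is hypothetical, and the sketch assumes that the total task count is the sum of the scatter, groupgemm, and reduce counts; it only mirrors the branch order shown above.

```python
def classify_task(task_id, num_combine_tasks, group_gemm_tasks, num_reduce_tasks):
    """Map a global task ID to (kind, local index), following the branch order above."""
    total = num_combine_tasks + group_gemm_tasks + num_reduce_tasks  # assumed total
    if not 0 <= task_id < total:
        raise ValueError(f"task_id {task_id} out of range [0, {total})")
    if task_id < num_combine_tasks:
        return "scatter", task_id
    if task_id < num_combine_tasks + group_gemm_tasks:
        return "groupgemm", task_id - num_combine_tasks
    return "reduce", task_id - num_combine_tasks - group_gemm_tasks


# Hypothetical counts, chosen only for illustration.
layout = [classify_task(t, 2, 3, 2) for t in range(7)]
print(layout)
```

Because scatter tasks occupy the lowest IDs, they are claimed first; each one then waits on the barriers that the groupgemm tasks set through NEED_NOTIFY.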

Core Optimizations

Token Saving/Skipping

Problem: When a token is routed to multiple experts on the same rank, a naive dispatch sends the same token data to that rank once per expert. Sending it once and reusing the data on the destination avoids this redundant communication.

Solution: Two-stage dispatch with token rank table and indirect position tracking.

Implementation (tile_kernel_dispatch_token_intra_node_two_stage):

# Stage 1: Main dispatch SMs
for send_token_offset in range(global_warp_id, token_num * topk, total_warps):
    sort_token_offset = ld(token_sort_indices + send_token_offset)
    if sort_token_offset >= 0:  # ignore dropped tokens
        token_offset = sort_token_offset // topk
        expert_idx = ld(topk_indices_tensor + sort_token_offset)
        expert_rank = expert_idx // experts_per_rank

        # Check if token already sent to this rank
        has_sent = ld(token_rank_table_buf + token_offset * world_size + expert_rank)
        if has_sent < 0:
            # First time sending to this rank
            has_sent = store_idx
            libshmem_device.putmem_warp(dst_ptr, src_ptr, bytes_per_token, expert_rank)
            st(token_rank_table_buf + token_offset * world_size + expert_rank, store_idx)

        # Store indirect position for later lookup
        remote_token_indirect_pos = dl.symm_at(token_indirect_pos_buf, expert_rank)
        st(remote_token_indirect_pos + store_idx, has_sent, scope="sys", semantic="release")

# Stage 2: Tail SMs copy locally and notify
for tile_id in range(pid - num_pid, gemm_total_tiles_m, num_tail_sms):
    # Wait for indirect position
    has_sent = ld_acquire(token_indirect_pos_buf + real_offset, scope="sys")
    while has_sent < 0:
        has_sent = ld_acquire(token_indirect_pos_buf + real_offset, scope="sys")

    # Copy from original position if different
    copy_warp(dispatch_output_local + real_offset * hidden_size,
              output_buf + has_sent * hidden_size, bytes_per_token)

Benefits:

  • Reduces communication volume by ~30% for tokens with multiple same-rank experts

  • Enables early groupgemm start while dispatch completes

  • Maintains correctness through indirect position tracking
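
To see where the saving comes from, the following host-side sketch counts remote copies with and without deduplication per (token, destination rank) pair. It is a plain-Python model and not the kernel logic: `count_transfers` and the routing table are hypothetical, and the actual saving depends on the routing distribution.

```python
def count_transfers(topk_indices, experts_per_rank):
    """Return (naive, deduplicated) copy counts for one source rank.

    topk_indices[t] lists the expert IDs selected for token t.
    """
    naive = 0
    sent = set()  # (token, destination rank) pairs that already carry the data
    for token, experts in enumerate(topk_indices):
        for expert_idx in experts:
            naive += 1
            sent.add((token, expert_idx // experts_per_rank))
    return naive, len(sent)


# Hypothetical routing: 3 tokens, topk=2, 2 experts per rank.
routing = [[0, 1], [0, 2], [2, 3]]
naive, deduped = count_transfers(routing, experts_per_rank=2)
print(naive, deduped)
```

In this toy routing, tokens 0 and 2 each pick two experts on a single rank, so only four of the six copies are needed.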

Token Sorting

Problem: Tokens arrive in arbitrary order, causing poor memory access patterns and delayed dispatch completion.

Solution: Pre-sort tokens by expert rank and expert index to improve locality.

Implementation (get_ag_splits_and_recv_offset_for_dispatch):

The preprocessing kernel generates token_sort_indices that reorders tokens:

# Sort tokens by: (expert_rank, expert_idx_intra_rank, token_idx)
# This ensures:
# 1. Tokens going to same expert are contiguous
# 2. Memory access patterns are sequential
# 3. Early completion signals can be sent per expert

for token_idx in range(...):
    expert_idx = ld(topk_indices_buf + token_idx)
    expert_rank = expert_idx // experts_per_rank
    expert_idx_intra_rank = expert_idx % experts_per_rank

    # Compute sort key
    sort_key = (expert_rank * experts_per_rank + expert_idx_intra_rank) * max_tokens + token_idx
    st(token_sort_indices + sort_key, token_idx)

Dispatch Loop (tile_kernel_dispatch_token_intra_node):

# Iterate through sorted token offsets
for send_token_offset in range(global_warp_id, token_num * topk, total_warps):
    sort_token_offset = ld(token_sort_indices + send_token_offset)
    if sort_token_offset >= 0:  # ignore dropped tokens
        token_offset = sort_token_offset // topk
        expert_idx = ld(topk_indices_tensor + sort_token_offset)
        # ... dispatch logic ...

Benefits:

  • Sequential Memory Access: Tokens for same expert are processed together

  • Early Completion: Can signal completion per expert as soon as all tokens arrive

  • Better Cache Utilization: Improved L2 cache hit rate

  • Reduced Barrier Overhead: Per-expert barriers instead of global barriers
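
The ordering can be reproduced on the host with an ordinary sort. The sketch below is an assumption-laden model rather than the preprocessing kernel: `build_token_sort_indices` is a hypothetical helper that returns flattened `token_idx * topk + k` offsets, which matches how the dispatch loop recovers `token_offset` with `// topk`. Dropped tokens are not modelled.

```python
def build_token_sort_indices(topk_indices, experts_per_rank):
    """Order flattened (token, k) slots by (expert_rank, expert_idx_intra_rank, token_idx)."""
    topk = len(topk_indices[0])
    keyed = []
    for token_idx, experts in enumerate(topk_indices):
        for k, expert_idx in enumerate(experts):
            expert_rank = expert_idx // experts_per_rank
            expert_idx_intra_rank = expert_idx % experts_per_rank
            offset = token_idx * topk + k  # flattened slot in topk_indices
            keyed.append(((expert_rank, expert_idx_intra_rank, token_idx), offset))
    keyed.sort()
    return [offset for _, offset in keyed]


# Hypothetical routing: 3 tokens, topk=2, 4 experts on 2 ranks.
routing = [[3, 0], [1, 3], [0, 2]]
order = build_token_sort_indices(routing, experts_per_rank=2)
experts_in_order = [routing[o // 2][o % 2] for o in order]
print(order)
print(experts_in_order)
```

Reading the routing table through `order` yields expert IDs in non-decreasing order, so all slots for one expert are contiguous.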

SM Scheduling

Problem: Static task assignment leads to load imbalance and poor SM utilization.

Solution: Dynamic task queue driven by a shared atomic counter. Each SM claims the next unclaimed task as soon as it finishes its current one.

Task Distribution:

# Each SM fetches tasks dynamically
task_id = tl.atomic_add(task_counter_ptr, 1)
while task_id < total_tasks:
    if task_id < num_dispatch_tasks:
        # Dispatch task
        execute_dispatch_task(task_id)
    elif task_id < num_dispatch_tasks + group_gemm_tasks:
        # GroupGEMM task
        execute_groupgemm_task(task_id - num_dispatch_tasks)
    else:
        # Combine task
        execute_combine_task(task_id - num_dispatch_tasks - group_gemm_tasks)

    # Fetch next task
    task_id = tl.atomic_add(task_counter_ptr, 1)

SM Allocation Strategies:

  1. Dispatch SMs (NUM_DISPATCH_SM): Dedicated SMs for token dispatch

    • Typically 20-40 SMs depending on token count

    • Handles communication-intensive operations

  2. Tail SMs (NUM_TAIL_SMS): SMs for two-stage dispatch completion

    • Typically 4-8 SMs

    • Handles local copy and barrier notification

    • Enables overlap with groupgemm

  3. GroupGEMM SMs: Remaining SMs for computation

    • Automatically distributed via task queue

    • Load balanced across all experts

Benefits:

  • Load Balancing: SMs that finish their tasks early automatically claim more tasks

  • Overlap: Dispatch and groupgemm can overlap naturally

  • Flexibility: Can adjust SM allocation based on workload characteristics
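
The claim-execute-claim loop can be simulated on the CPU to check that every task runs exactly once. In this sketch, Python threads stand in for SMs and a lock-protected integer stands in for the GPU atomic counter; `run_task_queue` is a hypothetical name and nothing here is part of the kernel.

```python
import threading


def run_task_queue(total_tasks, num_workers):
    """Simulate workers claiming tasks from one shared counter.

    Returns, per worker, the list of task IDs that worker executed.
    """
    counter = 0
    lock = threading.Lock()
    executed = [[] for _ in range(num_workers)]

    def fetch_task():
        nonlocal counter
        with lock:  # emulates atomic_add(task_counter_ptr, 1), which returns the old value
            task_id = counter
            counter += 1
        return task_id

    def worker(worker_id):
        task_id = fetch_task()
        while task_id < total_tasks:
            executed[worker_id].append(task_id)  # stand-in for running the task
            task_id = fetch_task()

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return executed


claimed = run_task_queue(total_tasks=100, num_workers=4)
all_tasks = sorted(t for per_worker in claimed for t in per_worker)
print(all_tasks == list(range(100)))
```

How the tasks split between workers varies from run to run, but the union is always the full task range with no duplicates.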

Barrier Synchronization

Per-Expert Barriers (Default):

# Signal completion when all tokens for an expert arrive
sent_tokens = atomic_add(counter_ptr + expert_idx, 1)
if sent_tokens == tokens_this_expert - 1:
    libshmem_device.fence()
    libshmem_device.signal_op(
        barriers_ptr + expert_idx_intra_rank * world_size + rank,
        1, libshmem_device.NVSHMEM_SIGNAL_SET, expert_rank)

Block-wise Barriers (USE_BLOCK_WISE_BARRIER=True):

# Per-tile barrier for finer granularity
# Enables groupgemm to start as soon as a tile is ready
barrier_idx = local_pid_m + tile_begin
while ld_acquire(barriers_ptr + barrier_idx, scope="gpu") != 1:
    pass

Per-Token Block Barriers (Combine phase):

# Barrier per hidden_size / BLOCK_SIZE_N chunk
barrier_n_idx = elem_idx * VEC_SIZE // BARRIER_TOKEN_BLOCK_SIZE
barrier_idx = token_scatter_idx * N_BARRIERS_PER_TOKEN + barrier_n_idx
token = ld_acquire(remote_barriers_ptr + barrier_idx, scope="sys")
while token != 1:
    token = ld_acquire(remote_barriers_ptr + barrier_idx, scope="sys")

Benefits:

  • Finer Granularity: Reduces wait time by enabling partial execution

  • Better Overlap: GroupGEMM can start processing as soon as data is ready

  • Reduced Synchronization Overhead: Smaller barrier scopes
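
The per-expert completion rule (the sender whose `atomic_add` returns `tokens_this_expert - 1` raises the signal) can be replayed sequentially. The sketch below is a plain-Python model with hypothetical names and data; it records at which slot each expert's barrier would be set.

```python
def dispatch_with_completion_signal(expert_per_slot, tokens_per_expert):
    """Replay the per-expert counter logic in slot order.

    Returns a list of (slot position, expert) pairs, one per raised signal.
    """
    counters = {expert: 0 for expert in tokens_per_expert}
    signals = []
    for pos, expert_idx in enumerate(expert_per_slot):
        sent_tokens = counters[expert_idx]  # value before the add, as atomic_add returns
        counters[expert_idx] += 1
        if sent_tokens == tokens_per_expert[expert_idx] - 1:
            signals.append((pos, expert_idx))  # last token for this expert
    return signals


# Hypothetical sorted dispatch order: expert ID of each slot.
sorted_slots = [0, 0, 1, 2, 2, 2]
signals = dispatch_with_completion_signal(sorted_slots, {0: 2, 1: 1, 2: 3})
print(signals)
```

With sorted slots, expert 0 is signalled after two of the six slots, so its groupgemm tiles could start while the remaining slots are still being dispatched.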

Kernel Implementation Details

Dispatch Tile Kernel

@triton_dist.jit(do_not_specialize=["pid", "num_pid"])
def tile_kernel_dispatch_token_intra_node(
    pid, num_pid,
    counter_ptr, barriers_ptr,
    recv_buf_offset_per_expert,
    local_splits_buf,
    input_buf, output_buf,
    weight_send_buf, weight_recv_buf,
    topk_indices_tensor,
    token_dst_scatter_idx,
    num_input_tokens_per_rank,
    token_sort_indices,
    topk: tl.constexpr,
    hidden_size: tl.constexpr,
    experts_per_rank: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    WITH_SCATTER_INDICES: tl.constexpr,
    num_warps: tl.constexpr,
    profiler: Profiler,
    ENABLE_PROFILING: tl.constexpr,
):
    WARP_SIZE = 32
    rank = dl.rank()
    world_size = dl.num_ranks()
    thread_idx = tid(0)
    lane_idx = thread_idx % WARP_SIZE
    warp_id = thread_idx // WARP_SIZE
    total_warps = num_warps * num_pid
    global_warp_id = pid * num_warps + warp_id

    token_num = tl.load(num_input_tokens_per_rank + rank)

    # Process tokens in sorted order
    for send_token_offset in range(global_warp_id, token_num * topk, total_warps):
        sort_token_offset = ld(token_sort_indices + send_token_offset)
        if sort_token_offset >= 0:  # ignore dropped tokens
            token_offset = sort_token_offset // topk
            expert_idx = ld(topk_indices_tensor + sort_token_offset)
            expert_rank = expert_idx // experts_per_rank
            expert_idx_intra_rank = expert_idx % experts_per_rank

            # Allocate destination slot
            if not WITH_SCATTER_INDICES:
                store_idx = atomic_add_per_warp(
                    recv_buf_offset_per_expert + expert_rank * experts_per_rank * world_size +
                    expert_idx_intra_rank * world_size + rank, 1,
                    scope="gpu", semantic="relaxed")
            else:
                store_idx = ld(token_dst_scatter_idx + sort_token_offset)

            # Transfer token data
            src_ptr = input_buf + token_offset * hidden_size
            dst_ptr = output_buf + store_idx.to(tl.int64) * hidden_size
            libshmem_device.putmem_warp(dst_ptr, src_ptr, bytes_per_token, expert_rank)

            # Transfer weight if needed
            if HAS_WEIGHT:
                libshmem_device.putmem_warp(
                    weight_recv_buf + store_idx,
                    weight_send_buf + sort_token_offset,
                    weight_elem_size, expert_rank)

            # Signal completion per expert
            sync_warp()
            if lane_idx == 0:
                tokens_this_expert = ld(local_splits_buf + expert_idx)
                sent_tokens = atomic_add(counter_ptr + expert_idx, 1,
                                       scope="gpu", semantic="relaxed")
                if sent_tokens == tokens_this_expert - 1:
                    libshmem_device.fence()
                    libshmem_device.signal_op(
                        barriers_ptr + expert_idx_intra_rank * world_size + rank,
                        1, libshmem_device.NVSHMEM_SIGNAL_SET, expert_rank)
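
The strided loop `range(global_warp_id, token_num * topk, total_warps)` partitions the sorted slots across all warps of all dispatch programs. A small host-side check of that partition is sketched below, with hypothetical sizes and a hypothetical helper `slots_for_warp`.

```python
def slots_for_warp(pid, warp_id, num_warps, num_pid, num_slots):
    """Return the slot offsets one warp visits, using the kernel's index arithmetic."""
    total_warps = num_warps * num_pid
    global_warp_id = pid * num_warps + warp_id
    return list(range(global_warp_id, num_slots, total_warps))


# Hypothetical sizes: 2 programs x 4 warps, 5 tokens with topk=2.
num_pid, num_warps, token_num, topk = 2, 4, 5, 2
assignment = {
    (pid, w): slots_for_warp(pid, w, num_warps, num_pid, token_num * topk)
    for pid in range(num_pid)
    for w in range(num_warps)
}
covered = sorted(s for slots in assignment.values() for s in slots)
print(assignment[(0, 0)], assignment[(1, 3)])
print(covered == list(range(token_num * topk)))
```

Every slot is visited by exactly one warp, which is why a single per-expert counter increment per slot is enough to detect completion.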

GroupGEMM Tile Kernel

@triton_dist.jit(do_not_specialize=["pid", "num_pid", "M"])
def tile_kernel_moe_grouped_gemm_nk_const(
    pid, num_pid,
    counter_ptr, barriers_ptr,
    a_ptr, b_ptr, c_ptr,
    expert_ids_ptr,
    split_size_ptr, split_size_cum_ptr,
    tile_num_ptr, tile_num_cum_ptr,
    num_total_tiles_ptr,
    M, N: tl.constexpr, K: tl.constexpr,
    stride_am, stride_ak,
    stride_be, stride_bn, stride_bk,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    profiler: Profiler,
    NEED_WAIT: tl.constexpr,
    NEED_NOTIFY: tl.constexpr,
    USE_BLOCK_WISE_BARRIER: tl.constexpr,
    IS_DISPATCH_TWO_STAGET: tl.constexpr,
    ENABLE_PROFILING: tl.constexpr,
):
    num_block_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_block_n
    pid_n = pid % num_block_n

    expert_id = tl.load(expert_ids_ptr + pid_m)
    split_size = tl.load(split_size_ptr + expert_id)
    split_size_cum = tl.load(split_size_cum_ptr + pid_m)
    row_begin = split_size_cum

    # Wait for dispatch to complete (if needed)
    if NEED_WAIT:
        if IS_DISPATCH_TWO_STAGET:
            if USE_BLOCK_WISE_BARRIER:
                barrier_idx = local_pid_m + tile_begin
                while ld_acquire(barriers_ptr + barrier_idx, scope="gpu") != 1:
                    pass
            else:
                barrier_idx = expert_id
                while ld_acquire(barriers_ptr + barrier_idx, scope="gpu") != 1:
                    pass
        else:
            # Per-expert barrier
            barrier_idx = expert_id * world_size + thread_idx
            while ld_acquire(barriers_ptr + barrier_idx, scope="gpu") != 1:
                pass

    # Execute GEMM
    # ... GEMM computation ...

    # Notify combine phase (if needed)
    if NEED_NOTIFY:
        __syncthreads()
        token_begin = row_begin + local_pid_m * BLOCK_SIZE_M
        valid_tokens = min(row_remain, BLOCK_SIZE_M)
        if thread_idx < valid_tokens:
            st(barriers_ptr + (token_begin + thread_idx) * num_block_n + pid_n,
               1, scope="gpu", semantic="release")
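
The tile bookkeeping behind `pid_m`, `pid_n`, `row_begin`, and `valid_tokens` can be illustrated on the host. The sketch below assumes that each expert's rows are packed contiguously and that each expert gets `cdiv(split_size, BLOCK_SIZE_M)` tiles along M; the real tables (`expert_ids_ptr`, `split_size_cum_ptr`, and so on) may be laid out differently. `build_tile_table` and `decode_pid` are hypothetical helpers.

```python
def cdiv(a, b):
    """Ceiling division, as tl.cdiv does for positive integers."""
    return (a + b - 1) // b


def build_tile_table(split_sizes, block_size_m):
    """Return one (expert_id, first row, valid rows) entry per M tile."""
    table = []
    row_begin = 0
    for expert_id, split in enumerate(split_sizes):
        for local_pid_m in range(cdiv(split, block_size_m)):
            start = row_begin + local_pid_m * block_size_m
            valid = min(split - local_pid_m * block_size_m, block_size_m)
            table.append((expert_id, start, valid))
        row_begin += split  # assumption: rows are packed without padding
    return table


def decode_pid(pid, n, block_size_n):
    """Split a linear tile ID into (pid_m, pid_n) as the kernel does."""
    num_block_n = cdiv(n, block_size_n)
    return pid // num_block_n, pid % num_block_n


# Hypothetical splits: expert 1 receives no tokens and gets no tiles.
table = build_tile_table([5, 0, 3], block_size_m=4)
print(table)
print(decode_pid(5, n=256, block_size_n=128))
```

The partially filled tile (one valid row out of four) corresponds to the `valid_tokens = min(row_remain, BLOCK_SIZE_M)` guard in the notify step.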

Combine Tile Kernels

Gather Combine (tile_kernel_gather_combine_token_intra_node):

  • Reads expert outputs from remote ranks

  • Accumulates outputs for each token’s topk experts

  • Uses vectorized operations (128-bit loads/stores)

Scatter Combine (tile_kernel_scatter_token_intra_node):

  • Scatters expert outputs back to source ranks

  • Waits for per-token-block barriers

  • Transfers gate values if needed

TopK Reduce (tile_kernel_topk_reduce_token_intra_node):

  • Reduces scattered outputs for each token

  • Performs weighted sum across topk experts

  • Final output per original token
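
The weighted sum performed by the TopK reduce step can be written out as a reference computation. This is a plain-Python sketch with a hypothetical function name and toy data; it ignores vectorization, barriers, and data types.

```python
def topk_reduce(expert_outputs, weights):
    """Weighted sum over each token's topk expert outputs.

    expert_outputs[t][k] is a hidden-size vector and weights[t][k] its gate value.
    """
    reduced = []
    for outs, gates in zip(expert_outputs, weights):
        hidden = len(outs[0])
        acc = [0.0] * hidden
        for vec, gate in zip(outs, gates):
            for i in range(hidden):
                acc[i] += gate * vec[i]
        reduced.append(acc)
    return reduced


# One token, topk=2, hidden size 2 (toy values).
outputs = [[[1.0, 2.0], [3.0, 4.0]]]
gates = [[0.25, 0.75]]
result = topk_reduce(outputs, gates)
print(result)
```

A reference like this is useful for checking the fused kernel's output against an unfused computation on small inputs.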

Performance Characteristics

  • Num Tokens Per Rank: 32k

Single-Node 8-GPU Configuration:

  • Communication: NVSHMEM symmetric memory for direct GPU-to-GPU access

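  • Latency: ~3.5 ms for dispatch alone, ~5.9 ms for dispatch+groupgemm (depending on hidden size)

  • Throughput: algorithmic bandwidth of ~201 GB/s. Because token saving moves fewer bytes than the algorithm nominally counts, this figure exceeds the hardware bandwidth limit.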

Optimization Impact:

  • Token Saving: ~18%

  • Token Sorting: ~22%

  • Other Optimizations: ~8%

Usage Example

from triton_dist.kernels.nvidia.ep_all2all_fused import (
    mega_kernel_dispatch_token_moe_grouped_gemm,
    mega_kernel_moe_grouped_gemm_combine_token,
)

# Dispatch + GroupGEMM
mega_kernel_dispatch_token_moe_grouped_gemm[grid](
    task_counter_ptr,
    # ... dispatch params ...
    num_dispatch_tasks=NUM_DISPATCH_SM,
    # ... groupgemm params ...
    NUM_WARPS=16,
    NUM_TAIL_SMS=4,
    USE_BLOCK_WISE_BARRIER=True,
)

# GroupGEMM + Combine
mega_kernel_moe_grouped_gemm_combine_token[grid](
    task_counter_ptr,
    # ... groupgemm params ...
    # ... combine params ...
    num_combine_tasks=COMBINE_SM,
    num_reduce_tasks=REDUCE_SM,
    USE_SCATTER_MODE=True,
    NUM_WARPS=32,
)

See Also