Low-Latency EP All-to-All Layer

High-level layer for ultra-low-latency Expert Parallelism All-to-All communication with online FP8 quantization.

Overview

EPLowLatencyAllToAllLayer is optimized for latency-sensitive MoE inference, achieving ~202µs end-to-end All-to-All latency for 128 tokens per rank on 8 H800 GPUs.

Key Features

  • Ultra-Low Latency: ~202µs for 128 tokens per rank on 8 GPUs

  • Online FP8 Quantization: 2x bandwidth reduction with per-group scaling

  • Double Buffering: Overlaps computation and communication across iterations

  • Fused Operations: Quantization, transfer, and postprocessing in a single kernel

  • Built-in Profiling: Detailed intra-kernel timing via Perfetto traces

Comparison with EPAll2AllLayer

Feature            EPAll2AllLayer                     EPLowLatencyAllToAllLayer
-----------------  ---------------------------------  ---------------------------
Latency            Higher (multiple kernel launches)  Lower (~202µs for 8 GPUs)
Throughput         Higher (larger batches)            Optimized for small batches
Data Type          BF16/FP16                          Online FP8 quantization
Buffer Management  Dynamic resizing                   Fixed-size double buffering
Use Case           Training, large batch inference    Low-latency inference

Architecture

Double Buffering

Iteration 0        Iteration 1        Iteration 2
┌─────────────┐    ┌─────────────┐    ┌─────────────┐
│ Phase 0     │    │ Phase 1     │    │ Phase 0     │
│ Buffer 0    │    │ Buffer 1    │    │ Buffer 0    │
│ Signal=1    │    │ Signal=1    │    │ Signal=2    │
└─────────────┘    └─────────────┘    └─────────────┘

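Signal cycling prevents false matches. Both phases use the same formula, so the expected signal value increases each time a buffer is reused, and a signal left over from the previous use of that buffer cannot be mistaken for the current one:
- Phase 0: signal = call_count // 2 + 1
- Phase 1: signal = call_count // 2 + 1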
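The schedule in the diagram can be reproduced with a few lines of Python. This is an illustrative sketch: the signal formula is the one given above, while phase = call_count % 2 is an assumption inferred from the diagram, and phase_and_signal is a hypothetical helper name.

```python
def phase_and_signal(call_count: int) -> tuple[int, int]:
    """Return (phase, expected signal value) for the given call index."""
    phase = call_count % 2        # assumption: the two buffers alternate every call
    signal = call_count // 2 + 1  # formula from the text above
    return phase, signal


for call_count in range(4):
    phase, signal = phase_and_signal(call_count)
    print(f"iteration {call_count}: buffer {phase}, signal {signal}")
```

Iterations 0, 1, and 2 map to (buffer 0, signal 1), (buffer 1, signal 1), and (buffer 0, signal 2), matching the diagram.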

Memory Layout

Dispatch Buffers (per phase):
├── send_token_buffer    [max_m, msg_size]              # Local staging
├── recv_token_buffer    [experts/rank, world, max_m, msg_size]  # Symmetric
├── send_count_buffer    [world_size, experts/rank]     # Token counts
├── recv_count_buffer    [world_size, experts/rank]     # Symmetric
├── signal_buffer        [num_experts]                  # NVSHMEM signals
└── recv_slot_counter    [num_experts]                  # Atomic counters

Combine Buffers (per phase):
├── send_tokens_comm_buf [experts/rank, world * max_m, hidden]  # Symmetric
├── recv_token_buffer    [num_experts, max_m, hidden]   # Symmetric
└── signal_buffer        [num_experts]                  # NVSHMEM signals
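Because the buffers are fixed-size, their footprint can be estimated directly from the shapes listed above. The sketch below does this for the two combine token buffers. It assumes 2 bytes per element (BF16), which the layout does not state, and it ignores the small signal buffers; both function names are hypothetical.

```python
from math import prod


def combine_buffer_elements(num_experts: int, world_size: int, max_m: int, hidden: int) -> dict:
    """Element counts of the combine token buffers, taken from the layout above."""
    experts_per_rank = num_experts // world_size
    shapes = {
        "send_tokens_comm_buf": (experts_per_rank, world_size * max_m, hidden),
        "recv_token_buffer": (num_experts, max_m, hidden),
    }
    return {name: prod(shape) for name, shape in shapes.items()}


def bytes_per_phase(element_counts: dict, bytes_per_element: int = 2) -> int:
    """Total bytes for one phase. Assumption: BF16 elements (2 bytes each)."""
    return sum(element_counts.values()) * bytes_per_element
```

With num_experts=256, world_size=32, max_m=128, and hidden=7168, both buffers hold 234,881,024 elements, or 448 MiB each under the BF16 assumption. Double buffering allocates every per-phase buffer twice.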

API Reference

EPLowLatencyAllToAllLayer

class EPLowLatencyAllToAllLayer(max_m, hidden, topk, online_quant_fp8, rank, num_experts, local_world_size, world_size, fp8_gsize=128, dtype=torch.bfloat16, enable_profiling=False)

Low-latency EP All-to-All layer with online FP8 quantization.

Parameters:
  • max_m – Maximum tokens per rank (e.g., 128 or 256)

  • hidden – Hidden dimension size

  • topk – Number of experts per token

  • online_quant_fp8 – Must be True (only FP8 mode supported)

  • rank – Current rank

  • num_experts – Total number of experts

  • local_world_size – Number of ranks per node

  • world_size – Total number of ranks

  • fp8_gsize – FP8 quantization group size (default: 128)

  • dtype – Base data type for computation (default: torch.bfloat16)

  • enable_profiling – Enable intra-kernel profiling (default: False)

dispatch(send_tokens, send_scales, topk_indices)

Dispatch tokens to experts with online FP8 quantization.

Parameters:
  • send_tokens – [num_tokens, hidden] Input tokens (bf16)

  • send_scales – Must be None (online quantization only)

  • topk_indices – [num_tokens, topk] Expert routing decisions

Returns:

Tuple of (recv_token, recv_scale, expert_recv_count, dispatch_meta)

  • recv_token: [experts/rank, world * max_m, hidden] - FP8 tokens

  • recv_scale: [experts/rank, world * max_m, num_groups] - FP32 scales

  • expert_recv_count: [experts/rank] - Tokens per local expert

  • dispatch_meta: DispatchMetaInfo for combine
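The return shapes follow from the constructor arguments. The helper below is a small sketch for computing them ahead of time, for example to preallocate expert workspaces. It assumes experts/rank = num_experts // world_size and num_groups = hidden // fp8_gsize, and that both divisions must be exact; dispatch_output_shapes is a hypothetical name, not part of the layer's API.

```python
def dispatch_output_shapes(num_experts: int, world_size: int, max_m: int,
                           hidden: int, fp8_gsize: int = 128) -> dict:
    """Expected shapes of the tensors returned by dispatch()."""
    # Assumption: experts split evenly across ranks, hidden splits evenly into groups.
    if num_experts % world_size != 0:
        raise ValueError("num_experts must be divisible by world_size")
    if hidden % fp8_gsize != 0:
        raise ValueError("hidden must be divisible by fp8_gsize")

    experts_per_rank = num_experts // world_size
    num_groups = hidden // fp8_gsize
    rows = world_size * max_m  # worst case: every rank sends max_m tokens to one expert

    return {
        "recv_token": (experts_per_rank, rows, hidden),
        "recv_scale": (experts_per_rank, rows, num_groups),
        "expert_recv_count": (experts_per_rank,),
    }
```

For num_experts=256, world_size=32, max_m=128, and hidden=7168, this gives recv_token of shape (8, 4096, 7168) and recv_scale of shape (8, 4096, 56). Only the first expert_recv_count[e] rows of each expert slice hold valid tokens.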

combine(send_tokens, topk_indices, topk_weights, dispatch_meta, zero_copy=False)

Combine expert outputs with weighted reduction.

Parameters:
  • send_tokens – [experts/rank, world * max_m, hidden] Expert outputs

  • topk_indices – [num_combined_tokens, topk] Expert routing

  • topk_weights – [num_combined_tokens, topk] Routing weights

  • dispatch_meta – Metadata from dispatch

  • zero_copy – If True, assumes send_tokens is in comm buffer (default: False)

Returns:

[num_combined_tokens, hidden] - Combined output
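The weighted reduction that combine performs can be written as a single-device reference, which is handy for checking results. This sketch assumes the top-k expert outputs for each token have already been gathered into one tensor and that accumulation happens in FP32; neither detail is confirmed by this page, and topk_weighted_reduce is a hypothetical name.

```python
import torch


def topk_weighted_reduce(expert_outputs: torch.Tensor, topk_weights: torch.Tensor) -> torch.Tensor:
    """Reference weighted sum over the top-k experts of each token.

    expert_outputs: [num_tokens, topk, hidden] outputs of each token's selected experts
    topk_weights:   [num_tokens, topk] routing weights
    returns:        [num_tokens, hidden]
    """
    # out[t, h] = sum_k topk_weights[t, k] * expert_outputs[t, k, h]
    return torch.einsum("tk,tkh->th", topk_weights.float(), expert_outputs.float())
```

Comparing this reference against the layer's output with a tolerance suited to FP8 and BF16 rounding is one way to validate an integration.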

finalize()

Release NVSHMEM symmetric memory buffers.

dump_dispatch_trace()

Export dispatch profiling to Perfetto trace.

Output: ./prof/ll_dispatch_RANK_{rank}.json

dump_combine_trace()

Export combine profiling to Perfetto trace.

Output: ./prof/ll_combine_RANK_{rank}.json

DispatchMetaInfo

class DispatchMetaInfo

Metadata from dispatch needed for combine operation.

recv_token_source_indices: torch.Tensor

[num_experts_per_rank, world_size * max_m] - Maps received positions to source token indices.

recv_token_source_count_and_start: torch.Tensor

[num_experts_per_rank, world_size] - Packed int64 containing (count, start) per source rank.
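Packing two 32-bit values into one int64 typically looks like the sketch below. The bit layout is an assumption made for illustration (count in the upper 32 bits, start in the lower 32 bits); this page does not specify which half holds which field, so check the implementation before relying on it. Both helper names are hypothetical.

```python
def pack_count_start(count: int, start: int) -> int:
    """Pack two non-negative 32-bit values into one 64-bit integer.

    Assumption: count occupies the upper 32 bits, start the lower 32 bits.
    """
    assert 0 <= count < 2**31 and 0 <= start < 2**32
    return (count << 32) | start


def unpack_count_start(packed: int) -> tuple[int, int]:
    """Inverse of pack_count_start: return (count, start)."""
    return packed >> 32, packed & 0xFFFFFFFF
```

Storing both fields in one word lets a single 64-bit write publish the count and the start offset for a source rank together.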

Usage Example

Basic Usage

import torch
from triton_dist.utils import initialize_distributed
from triton_dist.layers.nvidia import EPLowLatencyAllToAllLayer

# Initialize distributed runtime
rank, world_size = initialize_distributed()

# Create layer
layer = EPLowLatencyAllToAllLayer(
    max_m=128,
    hidden=7168,
    topk=8,
    online_quant_fp8=True,
    rank=rank,
    num_experts=256,
    local_world_size=8,
    world_size=world_size,
    fp8_gsize=128,
    dtype=torch.bfloat16,
    enable_profiling=False,
)

# Prepare inputs
tokens = torch.randn(128, 7168, dtype=torch.bfloat16, device="cuda")
# Pick 8 distinct experts per token, as a top-k router would
expert_ids = torch.rand(128, 256, device="cuda").topk(8, dim=-1).indices.to(torch.int32)
routing_weights = torch.softmax(torch.randn(128, 8, device="cuda"), dim=-1).float()

# Dispatch (bf16 → fp8 quantized transfer)
recv_token, recv_scale, expert_recv_count, dispatch_meta = layer.dispatch(
    send_tokens=tokens,
    send_scales=None,  # Online quantization
    topk_indices=expert_ids,
)

# Expert computation
# recv_token: [experts/rank, world * max_m, hidden] - FP8
# recv_scale: [experts/rank, world * max_m, num_groups] - FP32
expert_output = your_fp8_expert_ffn(recv_token, recv_scale)

# Combine
combined = layer.combine(
    send_tokens=expert_output,
    topk_indices=expert_ids,
    topk_weights=routing_weights,
    dispatch_meta=dispatch_meta,
)

# Cleanup
layer.finalize()

With Profiling

# Enable profiling for performance analysis
layer = EPLowLatencyAllToAllLayer(
    ...,
    enable_profiling=True,
)

# Run multiple iterations
for _ in range(100):
    recv_token, recv_scale, expert_recv_count, dispatch_meta = layer.dispatch(...)
    combined = layer.combine(...)

# Export traces (requires barrier for complete data)
torch.distributed.barrier()
layer.dump_dispatch_trace()
layer.dump_combine_trace()

# View in Perfetto UI (https://ui.perfetto.dev)
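The exported files can also be summarized offline instead of being opened in the UI. The sketch below assumes the traces use the Chrome Trace Event JSON format that Perfetto can load, with complete events ("ph": "X") carrying a "dur" field in microseconds; the actual schema written by the layer is not documented here, and summarize_trace is a hypothetical helper.

```python
import json
from collections import defaultdict


def summarize_trace(path: str) -> dict:
    """Sum event durations per event name in a trace file.

    Assumption: Chrome Trace Event JSON, either a bare list of events or an
    object with a "traceEvents" list.
    """
    with open(path) as f:
        data = json.load(f)
    events = data["traceEvents"] if isinstance(data, dict) else data

    totals = defaultdict(float)
    for event in events:
        if event.get("ph") == "X":  # complete events carry their own duration
            totals[event["name"]] += event.get("dur", 0)
    return dict(totals)
```

For example, summarize_trace("./prof/ll_dispatch_RANK_0.json") would return one total per event name if the file follows the assumed format.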

Expert FFN with FP8

import torch


def fp8_expert_ffn(recv_token, recv_scale, expert_recv_count, expert_weights):
    """
    Process tokens with per-expert FFNs (reference sketch, not optimized).

    recv_token: [experts/rank, world * max_m, hidden] - FP8
    recv_scale: [experts/rank, world * max_m, num_groups] - FP32
    expert_recv_count: [experts/rank] - Number of valid tokens per expert
    expert_weights: Per-expert callables mapping [count, hidden] to [count, hidden]
    """
    num_experts_per_rank, _, hidden = recv_token.shape
    num_groups = recv_scale.shape[-1]

    # Rows beyond each expert's valid count stay zero
    output = torch.zeros(recv_token.shape, dtype=torch.bfloat16, device=recv_token.device)

    counts = expert_recv_count.tolist()  # one host sync instead of one per expert
    for expert_idx, count in enumerate(counts):
        if count == 0:
            continue

        # Extract valid tokens for this expert
        tokens = recv_token[expert_idx, :count]  # [count, hidden]
        scales = recv_scale[expert_idx, :count]  # [count, num_groups]

        # Dequantize (assumption: value = fp8 * per-group scale) and compute
        grouped = tokens.float().reshape(count, num_groups, hidden // num_groups)
        tokens_bf16 = (grouped * scales.unsqueeze(-1)).reshape(count, hidden).to(torch.bfloat16)

        # Write results back into the expert's slice
        output[expert_idx, :count] = expert_weights[expert_idx](tokens_bf16)

    return output

Performance Benchmark

# Run benchmark
NVSHMEM_SYMMETRIC_SIZE=2g bash scripts/launch.sh \
    python/triton_dist/test/nvidia/test_ep_ll_a2a.py \
    -M 128 --iters 100 --verify-iters 20 --check

Expected results on 32x H800 (4 nodes):

Configuration:
- Tokens per rank: 128
- TopK: 8
- Hidden size: 7168
- Data type: FP8 (online quantization)

Results:
- Dispatch latency: ~70 µs (median)
- Combine latency: ~67 µs (median)
- Total A2A latency: ~137 µs

Profiling Categories

Dispatch Phases:

  • quant_and_put: Online FP8 quantization and warp-level transfer

  • count_put: Exchange per-expert token counts

  • wait: Wait for all counts to arrive

  • postprocess: Reorganize received tokens by expert

Combine Phases:

  • copy_and_put: Copy to comm buffer and scatter to source ranks

  • recv_wait: Wait for all expert outputs

  • topk_reduce: Compute weighted sum of TopK expert outputs

Note

This layer requires online_quant_fp8=True. Pre-quantized input is not supported. For BF16/FP16 without quantization, use EPAll2AllLayer instead.

See Also