Low-Latency All-to-All V2 (EP)
Ultra-low-latency All-to-All communication kernels for Expert Parallelism with online FP8 quantization. Designed for latency-sensitive MoE inference scenarios.
Overview
The Low-Latency All-to-All V2 kernels are optimized for:
- Minimal Latency: Single-kernel dispatch and combine operations
- Online FP8 Quantization: Roughly halves the transfer size relative to BF16 (one byte per element plus a small per-group FP32 scale overhead), with minimal accuracy loss
- Double Buffering: Overlaps communication with computation across iterations
- Fused Operations: Quantization, transfer, and postprocessing in a single kernel launch
Performance
Benchmark: 8x H800 GPUs
- Tokens per rank: 128
- TopK: 8
- Hidden size: 7168
- Data type: FP8 (online quantization)
Results:
- Dispatch latency: ~76 µs
- Combine latency: ~126 µs
- Total A2A latency: ~202 µs
Architecture
Message Format
Each token message contains:
```
┌──────────────┬──────────────────────┬──────────────────────┐
│ META (16B)   │ TOKEN (hidden × 1B)  │ SCALE (groups × 4B)  │
│ (padded)     │ (FP8 quantized)      │ (FP32)               │
└──────────────┴──────────────────────┴──────────────────────┘
```
META: Source token index (int32, padded to 16 bytes for alignment)
TOKEN: FP8-quantized hidden state
SCALE: Per-group quantization scales (hidden // fp8_gsize groups)
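To make the layout concrete, the sketch below computes the per-token message size for the benchmark shape (hidden = 7168, fp8_gsize = 128). It assumes there is no padding beyond the 16-byte META field; the helper names are illustrative and not part of triton_dist.

```python
# Sketch: per-token message size for the META | TOKEN | SCALE layout.
# Assumption: no extra padding after SCALE; real buffers may round this up.
META_BYTES = 16   # int32 source token index, padded to 16 bytes
FP8_BYTES = 1     # one byte per hidden element after FP8 quantization
SCALE_BYTES = 4   # one FP32 scale per quantization group


def msg_size(hidden: int, fp8_gsize: int = 128) -> int:
    """Bytes per token message: META + FP8 payload + per-group FP32 scales."""
    assert hidden % fp8_gsize == 0, "hidden must be a multiple of the group size"
    num_groups = hidden // fp8_gsize
    return META_BYTES + hidden * FP8_BYTES + num_groups * SCALE_BYTES


def bf16_payload(hidden: int) -> int:
    """Bytes the same token would occupy if sent as raw BF16."""
    return hidden * 2


if __name__ == "__main__":
    size = msg_size(7168)
    print(size)                                 # 7408
    print(round(bf16_payload(7168) / size, 3))  # 1.935
```

For this shape the FP8 message is 7408 bytes versus 14336 bytes for a raw BF16 token, so the reduction is about 1.94x once META and scale overhead are counted.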
Double Buffering
```
     Phase 0:               Phase 1:
┌─────────────────┐    ┌─────────────────┐
│ Buffer Set 0    │    │ Buffer Set 1    │
│ - send_token    │    │ - send_token    │
│ - recv_token    │    │ - recv_token    │
│ - signal_buf    │    │ - signal_buf    │
└─────────────────┘    └─────────────────┘
         ↑                      ↑
         │                      │
 Iteration 0,2,4...    Iteration 1,3,5...
```
Benefits:
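- No explicit synchronization or signal reset between iterations: each buffer set is reused only every other iteration
- Signal values increase with each reuse of a buffer set: 1,1,2,2,3,3,... (iterations 0 and 1 both wait for 1, iterations 2 and 3 for 2, and so on)
- Enables compute-communication overlap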
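The 1,1,2,2,3,3 sequence follows from alternating buffer sets: a set is reused every second iteration, and each reuse waits for a larger value. Because the wait compares for equality against the new value, a signal left over from the previous use of the same set (which still holds the old, smaller value) cannot be mistaken for fresh data. The sketch models this with a hypothetical `PhaseState` class that mirrors the documented behavior of `update_phase()` and `signal_val`; it is not the library's implementation.

```python
from dataclasses import dataclass


@dataclass
class PhaseState:
    """Hypothetical model of update_phase()/signal_val bookkeeping, not the triton_dist class."""
    iteration: int = 0

    @property
    def buffer_set(self) -> int:
        # Even iterations use buffer set 0, odd iterations use buffer set 1
        return self.iteration % 2

    @property
    def signal_val(self) -> int:
        # A buffer set is reused every second iteration and waits for a value
        # one larger than on its previous use: 1,1,2,2,3,3,...
        return self.iteration // 2 + 1

    def update_phase(self) -> None:
        self.iteration += 1


def ready(signal_word: int, expected: int) -> bool:
    """Mimics signal_wait_until(..., NVSHMEM_CMP_EQ, expected) as a one-shot check."""
    return signal_word == expected


if __name__ == "__main__":
    state = PhaseState()
    schedule = []
    for _ in range(6):
        schedule.append((state.buffer_set, state.signal_val))
        state.update_phase()
    print(schedule)  # [(0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]
    # Iteration 2 reuses buffer set 0 and expects 2; the stale value 1 left by
    # iteration 0 does not satisfy the wait.
    print(ready(1, 2), ready(2, 2))  # False True
```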
Dispatch Kernel (V2)
The dispatch kernel performs:
- Online FP8 Quantization: Per-group scaling and quantization
- Warp-level Transfer: Parallel token transfers to the target experts
- Count Exchange: Each rank sends every peer the token counts for that peer's local experts (an all-to-all of per-expert counts)
- Postprocessing: Reorganize received tokens by expert
```python
@triton_dist.jit
def dispatch_kernel_v2(
    profiler_buf,
    send_tensor,                        # [num_tokens, HIDDEN] - input tokens (bf16)
    send_scale,                         # [num_tokens, NUM_GROUPS] - optional precomputed scales
    topk_idx,                           # [num_tokens, TOPK] - expert routing
    num_tokens,
    send_token_buffer,                  # [max_m, msg_size] - symmetric
    recv_token_buffer,                  # [num_experts_per_rank, world_size, max_m, msg_size] - symmetric
    send_count_buffer,                  # [world_size, num_experts_per_rank] - symmetric
    recv_count_buffer,                  # [world_size, num_experts_per_rank] - symmetric
    recv_slot_counter,                  # [num_experts] - atomic counters
    signal_buffer,                      # [WORLD_SIZE] - NVSHMEM signals
    recv_token_source_indices,          # [num_experts_per_rank, world_size * max_m] - output
    recv_scale,                         # [num_experts_per_rank, world_size * max_m, num_groups] - output
    recv_token,                         # [num_experts_per_rank, world_size * max_m, hidden] - output
    expert_recv_count,                  # [num_experts_per_rank] - output
    recv_token_source_count_and_start,  # [num_experts_per_rank, world_size] - output
    grid_sync_counter,
    signal_val: int,
    ...
):
```
Kernel Phases
Phase 0: Quantize and Transfer
```python
# Per-token processing: CTA `pid` handles tokens pid, pid + num_ctas, ...
for i in range(pid, num_tokens, num_ctas):
    cur_token = tl.load(send_token_ptrs, pertoken_mask)
    # Online FP8 quantization (per-group): scale = group amax / FP8_MAX
    group = tl.reshape(cur_token, (BLOCK_SCALE, FP8_GSIZE))
    scale = tl.max(tl.abs(group), 1, keep_dims=True).to(tl.float32) * FP8_MAX_INV
    quant = (group.to(tl.float32) / scale).to(tl.float8e4nv)
    # Store token index as metadata (META field of the message)
    tl.store(tl.cast(send_buffer, tl.pointer_type(tl.int32)), i)
    tl.store(send_token_buffer_ptrs, quant)
    tl.store(send_scale_buffer_ptrs, scale)
    # Warp-level transfer to each TopK expert
    for warp_id in range(TOPK):
        dst_expert = tl.load(topk_idx_ptrs + warp_id)
        # Experts are partitioned contiguously across ranks
        dst_rank = dst_expert // NUM_EXPERTS_PER_RANK
        # Reserve a slot from this rank's per-expert counter
        dst_slot = atomic_add_per_warp(recv_slot_counter + dst_expert, 1)
        libshmem_device.putmem_nbi_warp(recv_buffer + dst_slot * MSG_SIZE,
                                        send_buffer, MSG_SIZE, dst_rank)
```
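For checking kernel output on the host, a NumPy reference of the per-group scheme above can be handy. It follows the excerpt (scale = group amax × FP8_MAX_INV), assuming FP8_MAX is 448, the largest finite e4m3 value. NumPy has no FP8 dtype, so the `float8e4nv` cast is emulated by rounding to 3 mantissa bits. The zero-group guard and all function names are assumptions of this sketch.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # assumed FP8_MAX: largest finite float8 e4m3 value


def round_to_e4m3(x: np.ndarray) -> np.ndarray:
    """Emulate round-to-nearest-even into FP8 e4m3 (3 mantissa bits), kept in float32."""
    x = np.clip(x, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    # Normal e4m3 values have exponent >= -6; below that the spacing is fixed (subnormals)
    exp = np.floor(np.log2(np.maximum(np.abs(x), 2.0 ** -6)))
    step = 2.0 ** (exp - 3)  # gap between neighbouring representable values
    return (np.round(x / step) * step).astype(np.float32)


def quantize_per_group(x: np.ndarray, gsize: int = 128):
    """Per-group FP8 quantization: scale = amax / FP8_MAX, payload = x / scale."""
    tokens, hidden = x.shape
    groups = x.astype(np.float32).reshape(tokens, hidden // gsize, gsize)
    scale = np.abs(groups).max(axis=-1, keepdims=True) / FP8_E4M3_MAX
    # Assumption: guard all-zero groups (the kernel excerpt does not show this)
    scale = np.where(scale == 0.0, 1.0, scale).astype(np.float32)
    quant = round_to_e4m3(groups / scale)
    return quant.reshape(tokens, hidden), scale[..., 0]


def dequantize_per_group(quant: np.ndarray, scale: np.ndarray, gsize: int = 128):
    """Multiply each group back by its FP32 scale."""
    tokens, hidden = quant.shape
    groups = quant.reshape(tokens, hidden // gsize, gsize) * scale[..., None]
    return groups.reshape(tokens, hidden)
```

Each group's largest-magnitude element maps to ±448, so the round-trip error per element stays within a small fraction of that group's amax.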
Phase 1: Exchange Counts
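```python
# All CTAs must finish their token puts and counter updates first
barrier_on_this_grid(grid_sync_counter, False)
# Order the token puts before the count/signal puts that follow
libshmem_device.fence()
for dst_rank in range(pid, WORLD_SIZE, num_ctas):
    token_counts = tl.load(recv_slot_counter + dst_rank * NUM_EXPERTS_PER_RANK + ...)
    # Stage the counts for dst_rank's experts in the symmetric send buffer
    tl.store(send_count_buffer + dst_rank * NUM_EXPERTS_PER_RANK + ..., token_counts)
    libshmem_device.putmem_signal_nbi_block(
        recv_count_buffer + rank * NUM_EXPERTS_PER_RANK,     # row `rank` on dst_rank
        send_count_buffer + dst_rank * NUM_EXPERTS_PER_RANK,
        NUM_EXPERTS_PER_RANK * 4,                            # int32 counts
        signal_buffer + rank, signal_val,
        libshmem_device.NVSHMEM_SIGNAL_SET, dst_rank)
```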
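The slot arithmetic of Phases 0 and 1 is easiest to follow in a single-process model. The sketch assumes experts are partitioned contiguously (global expert `e` lives on rank `e // num_experts_per_rank`, consistent with the `dst_rank * NUM_EXPERTS_PER_RANK` offsets above) and that slots come from the sender's own `recv_slot_counter`, so each sender fills its own `[local_expert, src_rank]` region of the receiver's buffer. `simulate_dispatch_counts` is a hypothetical helper, not a kernel API.

```python
import numpy as np


def simulate_dispatch_counts(topk_per_rank, num_experts, world_size):
    """Single-process model of Phase 0 slot allocation plus the Phase 1 count exchange.

    topk_per_rank[r] is rank r's [num_tokens, topk] expert-index array.
    Returns (placement, recv_count) where
      placement[(dst_rank, local_expert, src_rank, slot)] = source token index
      recv_count[dst_rank] has shape [world_size, num_experts_per_rank]
    """
    e_per_rank = num_experts // world_size
    placement = {}
    counters = []
    for src, topk in enumerate(topk_per_rank):
        counter = np.zeros(num_experts, dtype=np.int32)  # models recv_slot_counter
        for t, experts in enumerate(topk):
            for e in experts:
                slot = int(counter[e])  # atomic add returns the old value
                counter[e] += 1
                dst_rank, local = divmod(int(e), e_per_rank)
                placement[(dst_rank, local, src, slot)] = t
        counters.append(counter)
    # Phase 1: rank `src` sends counter[dst*E : (dst+1)*E] into recv_count[dst][src]
    recv_count = np.stack([
        np.stack([counters[src][dst * e_per_rank:(dst + 1) * e_per_rank]
                  for src in range(world_size)])
        for dst in range(world_size)
    ])
    return placement, recv_count
```

In this model no remote atomics are needed: a receiver learns how many slots each sender used only when that sender's counts arrive.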
Phase 2: Wait for All Counts
```python
# Wait until the count message from every source rank has arrived
for src_rank in range(pid, WORLD_SIZE, num_ctas):
    libshmem_device.signal_wait_until(
        signal_buffer + src_rank,
        libshmem_device.NVSHMEM_CMP_EQ,
        signal_val)
```
Phase 3: Postprocess
```python
# Reorganize received tokens by (local) expert
for target_expert_idx in range(pid, NUM_EXPERTS_PER_RANK, num_ctas):
    dispatch_postprocess_kernel_v2_for_expert(
        target_expert_idx,
        recv_token_source_indices, recv_scale, recv_token,
        expert_recv_count, recv_token_source_count_and_start,
        recv_token_buffer, recv_count_buffer, ...)
```
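Phase 3 turns the per-source slot regions into dense per-expert outputs. A plausible host-side model, assuming tokens are packed source rank by source rank in rank order (the real ordering inside `dispatch_postprocess_kernel_v2_for_expert` is not shown here), is an exclusive prefix sum over the received counts. The resulting (count, start) pairs correspond to what `recv_token_source_count_and_start` records; since its exact packing is not specified in this section, the sketch keeps them as separate arrays. Both functions are hypothetical.

```python
import numpy as np


def postprocess_layout(recv_count):
    """recv_count: [world_size, num_experts_per_rank] counts received on this rank.

    Returns per-expert totals plus, for each (local_expert, src_rank), the number
    of tokens and their start row in the packed per-expert output.
    """
    counts = recv_count.T                          # [E_per_rank, W]
    starts = np.cumsum(counts, axis=1) - counts    # exclusive prefix sum over source ranks
    expert_recv_count = counts.sum(axis=1)
    return expert_recv_count, counts, starts


def compact(recv_buffer, recv_count):
    """Pack recv_buffer[e, src, :count] rows into a dense [E_per_rank, W * max_m, ...] array."""
    e_per, world, max_m = recv_buffer.shape[:3]
    _, counts, starts = postprocess_layout(recv_count)
    out = np.zeros((e_per, world * max_m) + recv_buffer.shape[3:], recv_buffer.dtype)
    for e in range(e_per):
        for src in range(world):
            n, s = counts[e, src], starts[e, src]
            out[e, s:s + n] = recv_buffer[e, src, :n]
    return out
```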
Combine Kernel (V2)
The combine kernel performs:
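- Copy to Communication Buffer: Copy expert outputs into the symmetric communication buffer (skipped when ZERO_COPY is set)
- Scatter to Source Ranks: Send expert outputs back to the ranks and token positions they came from
- Wait for All Data: Wait on signal_buf until all expected transfers have arrived
- TopK Weighted Reduce: Sum each token's TopK expert outputs, weighted by topk_weights, into combined_out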
```python
@triton_dist.jit
def combine_kernel_v2(
    profiler_buf,
    send_tokens,            # [num_experts_per_rank, world_size * max_m, hidden]
    send_tokens_comm_buf,   # Communication buffer (symmetric)
    topk_indices,           # [num_combined_tokens, topk]
    topk_weights,           # [num_combined_tokens, topk]
    combined_out,           # [num_combined_tokens, hidden] - output
    recv_token_buffer,      # [num_experts, max_m, hidden] - symmetric
    signal_buf,             # [num_experts] - NVSHMEM signals
    dispatch_recv_token_source_indices,
    dispatch_recv_token_source_count_and_start,
    grid_sync_counter,
    num_combined_tokens: int,
    signal_val: int,
    ...
):
```
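The final step computes, for each local token t, combined_out[t] = sum over k of topk_weights[t, k] × (output of expert topk_indices[t, k] for token t). The NumPy reference below assumes the returned rows are addressable as `returned[expert, token]`, which the `[num_experts, max_m, hidden]` shape of `recv_token_buffer` suggests but this section does not confirm; the sketch accumulates in FP32.

```python
import numpy as np


def topk_weighted_reduce(returned, topk_indices, topk_weights):
    """FP32 reference for the TopK Weighted Reduce step on one source rank.

    returned:     [num_experts, max_m, hidden]; returned[e, t] is expert e's output
                  for local token t (assumed indexing of recv_token_buffer)
    topk_indices: [num_tokens, topk] global expert ids
    topk_weights: [num_tokens, topk] routing weights
    """
    num_tokens = topk_indices.shape[0]
    rows = np.arange(num_tokens)[:, None]                       # [num_tokens, 1]
    gathered = returned[topk_indices, rows].astype(np.float32)  # [num_tokens, topk, hidden]
    weights = topk_weights.astype(np.float32)[..., None]        # [num_tokens, topk, 1]
    return (weights * gathered).sum(axis=1)                     # [num_tokens, hidden]
```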
Intra-Node Optimization
For intra-node transfers, the kernel bypasses NVSHMEM put and writes directly into the peer's symmetric buffer with vectorized loads and stores; inter-node transfers still use putmem_nbi_warp:
```python
# Same node if dst_rank falls in this node's block of LOCAL_WORLD_SIZE ranks
is_intra_node = (dst_rank // LOCAL_WORLD_SIZE) == cur_node_id
if is_intra_node:
    # Direct load/store through symmetric memory
    dst_remote_ptr = dl.symm_at(dst_ptr, dst_rank)
    # Each lane copies every WARP_SIZE-th vector of the token
    for h_idx in range(lane_id, num_hidden_iters, WARP_SIZE):
        val_vec = dl.ld_vector(src_ptr + h_idx * VEC_SIZE, vec_size=VEC_SIZE)
        dl.st_vector(dst_remote_ptr + h_idx * VEC_SIZE, val_vec)
else:
    # NVSHMEM put for inter-node
    libshmem_device.putmem_nbi_warp(dst_ptr, src_ptr, nbytes, dst_rank)
```
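The path choice and the lane-strided copy loop can be restated in plain Python to check the work split. This sketch assumes `cur_node_id = rank // LOCAL_WORLD_SIZE`, a BF16 payload, and a 16-byte vector (VEC_SIZE = 8 BF16 elements); none of these values are stated in this section, and the helper names are hypothetical.

```python
WARP_SIZE = 32


def transfer_path(dst_rank: int, cur_rank: int, local_world_size: int) -> str:
    """Mirror of the kernel's node test, assuming cur_node_id = rank // LOCAL_WORLD_SIZE."""
    cur_node_id = cur_rank // local_world_size
    is_intra_node = (dst_rank // local_world_size) == cur_node_id
    return "symm_at + ld/st_vector" if is_intra_node else "putmem_nbi_warp"


def lane_vectors(lane_id: int, num_hidden_iters: int) -> list:
    """Vector indices copied by one lane in the lane-strided loop."""
    return list(range(lane_id, num_hidden_iters, WARP_SIZE))


if __name__ == "__main__":
    hidden, elem_bytes, vec_bytes = 7168, 2, 16  # BF16 token, assumed 16-byte vectors
    vec_size = vec_bytes // elem_bytes           # 8 elements per vector
    num_hidden_iters = hidden // vec_size        # 896 vectors per token
    print(transfer_path(5, 2, 8), transfer_path(9, 2, 8))
    print(num_hidden_iters, len(lane_vectors(0, num_hidden_iters)))  # 896 28
```

Under these assumptions each lane copies 28 vectors per token, and the 32 lanes together cover every vector exactly once.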
API Reference
Context Management
- create_ep_ll_a2a_ctx(max_m, hidden, topk, num_experts, online_quant_fp8, fp8_gsize, dtype, world_size, rank)
Create dispatch and combine contexts for low-latency EP All-to-All.
- Parameters:
max_m – Maximum tokens per rank (e.g., 128 or 256)
hidden – Hidden dimension size
topk – Number of experts per token
num_experts – Total number of experts
online_quant_fp8 – Must be True (only FP8 mode supported)
fp8_gsize – FP8 quantization group size (default: 128)
dtype – Base data type (e.g., torch.bfloat16)
world_size – Total number of ranks
rank – Current rank
- Returns:
Tuple of (LowlatencyDispatchContext, LowlatencyCombineContext)
Data Structures
- class LowlatencyDispatchContext
Context for low-latency dispatch operations.
- signal_val: int
Current signal value for this phase (sequence 1,1,2,2,3,3,…; it advances once every two phases)
- update_phase()
Advance to next phase (toggles buffer set and updates signal value).
- finalize()
Release NVSHMEM symmetric memory buffers.
- class LowlatencyCombineContext
Context for low-latency combine operations.
Same interface as LowlatencyDispatchContext.
- class DispatchMetaInfo
Metadata from dispatch operation needed for combine.
- recv_token_source_indices: torch.Tensor
[num_experts_per_rank, world_size * max_m] - Maps each received token to its index on the source rank.
- recv_token_source_count_and_start: torch.Tensor
[num_experts_per_rank, world_size] - Packed (count, start) for each source rank.
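To see how the two DispatchMetaInfo tensors fit together, the sketch walks them to list the sends a combine scatter would need: each received row goes back to its source rank at the recorded token index. It assumes count and start have already been unpacked into separate `[num_experts_per_rank, world_size]` arrays and that experts are partitioned contiguously across ranks; `scatter_plan` is a hypothetical helper, not a triton_dist function.

```python
import numpy as np


def scatter_plan(rank, source_indices, counts, starts):
    """Hypothetical helper: list the sends the combine scatter performs on `rank`.

    source_indices: [E_per_rank, W * max_m]  (recv_token_source_indices)
    counts, starts: [E_per_rank, W]          (recv_token_source_count_and_start, unpacked)
    Returns (src_rank, global_expert, source_token_index, row) tuples.
    """
    e_per_rank, world_size = counts.shape
    plan = []
    for e in range(e_per_rank):
        global_expert = rank * e_per_rank + e  # assumes contiguous expert partitioning
        for src in range(world_size):
            first = int(starts[e, src])
            for row in range(first, first + int(counts[e, src])):
                plan.append((src, global_expert, int(source_indices[e, row]), row))
    return plan


if __name__ == "__main__":
    counts = np.array([[1, 2], [0, 1]])
    starts = np.array([[0, 1], [0, 0]])
    src_idx = np.array([[5, 0, 3, 0], [7, 0, 0, 0]])
    for send in scatter_plan(1, src_idx, counts, starts):
        print(send)
```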
Kernel Functions
- dispatch_kernel_v2[grid](profiler_buf, send_tensor, send_scale, topk_idx, num_tokens, ...)
Low-latency dispatch kernel with online FP8 quantization.
- Parameters:
send_tensor – [num_tokens, HIDDEN] input tokens (bf16)
topk_idx – [num_tokens, TOPK] expert routing decisions
ONLINE_QUANT_FP8 – Whether to perform online FP8 quantization (must be True)
FP8_GSIZE – Quantization group size (typically 128)
ENABLE_PROFILING – Enable intra-kernel profiling
- combine_kernel_v2[grid](profiler_buf, send_tokens, send_tokens_comm_buf, topk_indices, topk_weights, combined_out, ...)
Low-latency combine kernel with weighted reduction.
- Parameters:
send_tokens – [num_experts_per_rank, world_size * max_m, hidden] expert outputs
topk_weights – [num_combined_tokens, topk] routing weights
combined_out – [num_combined_tokens, hidden] output buffer
ZERO_COPY – If True, assumes send_tokens already resides in the communication buffer
Profiling Support
Built-in profiling for performance analysis:
```python
# Enable profiling
layer = EPLowLatencyAllToAllLayer(..., enable_profiling=True)
# Run operations
recv_token, recv_scale, expert_recv_count, dispatch_meta = layer.dispatch(...)
combined = layer.combine(...)
# Export traces
layer.dump_dispatch_trace()  # Outputs: ./prof/ll_dispatch_RANK_X.json
layer.dump_combine_trace()   # Outputs: ./prof/ll_combine_RANK_X.json
```
Profiling categories:
- Dispatch: quant_and_put, count_put, wait, postprocess
- Combine: copy_and_put, recv_wait, topk_reduce
Usage Example
```python
import torch
from triton_dist.layers.nvidia import EPLowLatencyAllToAllLayer

# Assumes the distributed/NVSHMEM environment is already initialized (e.g. via
# scripts/launch.sh) and that `rank`, `tokens`, `expert_ids`, and
# `routing_weights` are defined on the current device.

# Create layer
layer = EPLowLatencyAllToAllLayer(
    max_m=128,
    hidden=7168,
    topk=8,
    online_quant_fp8=True,
    rank=rank,
    num_experts=256,        # 256 experts / 32 ranks = 8 local experts per rank
    local_world_size=8,
    world_size=32,
    fp8_gsize=128,
    dtype=torch.bfloat16,
    enable_profiling=False,
)

# Dispatch (bf16 input → fp8 quantized transfer)
recv_token, recv_scale, expert_recv_count, dispatch_meta = layer.dispatch(
    send_tokens=tokens,         # [num_tokens, hidden], bf16
    send_scales=None,           # Online quantization
    topk_indices=expert_ids,    # [num_tokens, topk]
)

# Expert computation on recv_token (fp8 with scales)
# ... your expert FFN here ...
expert_output = ...  # [num_experts_per_rank, world_size * max_m, hidden]

# Combine
combined = layer.combine(
    send_tokens=expert_output,
    topk_indices=expert_ids,
    topk_weights=routing_weights,
    dispatch_meta=dispatch_meta,
)

# Cleanup
layer.finalize()
```
Note
This kernel currently requires online_quant_fp8=True.
Pre-quantized input mode is not yet supported.
Run Example
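```bash
NVSHMEM_SYMMETRIC_SIZE=2g bash scripts/launch.sh python/triton_dist/test/nvidia/test_ep_ll_a2a.py -M 128 --profile
```
NVSHMEM_SYMMETRIC_SIZE sets the size of the NVSHMEM symmetric heap on each PE; it must be large enough for all symmetric buffers the contexts allocate.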
See Also
- Expert Parallelism All-to-All (EP A2A) - Standard EP A2A kernels (higher throughput, higher latency)
- Expert Parallelism All-to-All (Intra-Node) - Intra-node optimized kernels
- Low-Latency EP All-to-All Layer - High-level layer API