AllGather GEMM
The AllGather GEMM kernel fuses the AllGather collective with the GEMM computation so that communication and computation overlap.
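The overlap idea can be illustrated with a minimal CPU-only sketch. This is not the library's kernel: a producer thread stands in for the AllGather and delivers one rank's shard at a time, while the consumer multiplies each shard as soon as it arrives. All names here (`matmul_nt`, `gather_producer`, `overlapped_ag_gemm`) are hypothetical, and the weight is assumed to be stored as [N, K], so each shard computes shard @ B^T.

```python
import queue
import threading


def matmul_nt(a, b):
    """Return a @ b^T for nested lists: a is [m, k], b is [n, k]."""
    return [[sum(x * y for x, y in zip(a_row, b_row)) for b_row in b] for a_row in a]


def gather_producer(shards, out_queue):
    """Stand-in for AllGather: deliver (source_rank, shard) pairs one at a time."""
    for src_rank, shard in enumerate(shards):
        out_queue.put((src_rank, shard))
    out_queue.put(None)  # sentinel: every shard has been delivered


def overlapped_ag_gemm(shards, b):
    """Multiply each shard as soon as it arrives instead of waiting for all of them."""
    arrivals = queue.Queue()
    producer = threading.Thread(target=gather_producer, args=(shards, arrivals))
    producer.start()

    results = [None] * len(shards)
    while True:
        item = arrivals.get()
        if item is None:
            break
        src_rank, shard = item
        # This GEMM runs while the producer keeps delivering later shards.
        results[src_rank] = matmul_nt(shard, b)
    producer.join()

    # Stack per-rank results in rank order to form the [M, N] output.
    return [row for block in results for row in block]


# Two ranks, each holding a [1, 2] shard; b is [N=2, K=2].
shards = [[[1, 2]], [[3, 4]]]
b = [[1, 0], [0, 1]]
output = overlapped_ag_gemm(shards, b)
```

In the real kernel the "arrival" of a shard would be signalled on the GPU rather than through a Python queue; the sketch only shows why computation does not have to wait for the full gather.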
API Reference
- ag_gemm(a, b, ctx, ...)
Performs an AllGather followed by a GEMM, overlapping communication with computation.
- Parameters:
a – Local input tensor of shape [M_per_rank, K]
b – Weight tensor of shape [N, K]
ctx – AGGemmContext containing symmetric memory and signals
- Returns:
Output tensor of shape [M, N]
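The shapes above suggest the following unfused reference, which can be handy as a correctness baseline. It is a sketch under two assumptions that the text does not state explicitly: every rank contributes an equally sized shard, so M = num_ranks * M_per_rank, and the [N, K] weight is applied as A @ B^T. The function name `ag_gemm_reference` is hypothetical and uses plain Python lists instead of GPU tensors.

```python
def ag_gemm_reference(local_shards, b):
    """Unfused baseline: gather every rank's shard along M, then compute A @ B^T.

    local_shards: one [M_per_rank, K] matrix per rank (nested lists).
    b: weight matrix of shape [N, K].
    """
    k = len(b[0])
    # AllGather step: concatenate shards in rank order along the M dimension.
    a_full = [row for shard in local_shards for row in shard]
    if any(len(row) != k for row in a_full):
        raise ValueError("every shard must have K columns matching b")
    # GEMM step: output[i][j] = dot(a_full[i], b[j]) because b is stored as [N, K].
    return [[sum(x * y for x, y in zip(a_row, b_row)) for b_row in b]
            for a_row in a_full]


# 2 ranks, M_per_rank = 1, K = 2, N = 3, so the output is [M, N] = [2, 3].
shards = [[[1, 2]], [[3, 4]]]
weight = [[1, 0], [0, 1], [1, 1]]
reference = ag_gemm_reference(shards, weight)
```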
- create_ag_gemm_context(local_tensor, weight, rank, num_ranks, max_M, BLOCK_M, BLOCK_N, BLOCK_K, stages)
Creates the context for the AG-GEMM operation.
- Parameters:
local_tensor – Sample local tensor for shape inference
weight – Weight tensor
rank – Current rank ID
num_ranks – Total number of ranks
max_M – Maximum M dimension
BLOCK_M – Block size in M dimension
BLOCK_N – Block size in N dimension
BLOCK_K – Block size in K dimension
stages – Number of pipeline stages
- Returns:
AGGemmContext object
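As a rough guide to how the block sizes relate to the problem size, the sketch below counts output tiles and K-loop steps for a tiled GEMM using ceiling division. This is the generic tiling arithmetic, not a description of what the context stores internally; `ceil_div` and `tile_grid` are hypothetical helpers, and the problem size is made up for illustration.

```python
def ceil_div(a, b):
    """Integer ceiling division for positive integers."""
    return (a + b - 1) // b


def tile_grid(max_m, n, k, block_m, block_n, block_k):
    """Return (tiles along M, tiles along N, K-steps per tile) for a tiled GEMM."""
    tiles_m = ceil_div(max_m, block_m)
    tiles_n = ceil_div(n, block_n)
    k_steps = ceil_div(k, block_k)
    return tiles_m, tiles_n, k_steps


# Hypothetical problem size with BLOCK_M=128, BLOCK_N=256, BLOCK_K=64.
tiles_m, tiles_n, k_steps = tile_grid(max_m=4096, n=8192, k=1000,
                                      block_m=128, block_n=256, block_k=64)
total_tiles = tiles_m * tiles_n
```

When K is not a multiple of BLOCK_K, as here, the last K-step covers a partial block, which a real kernel would typically handle with masking.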
- gemm_persistent(...)
Persistent GEMM kernel that consumes AllGather results.
- gemm_non_persistent(...)
Non-persistent GEMM kernel variant.
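The difference between the two variants can be sketched in terms of how output tiles are assigned to programs. The code below shows the general persistent-kernel pattern (a fixed pool of programs striding over the tiles) next to the one-program-per-tile pattern. It is an illustration of the technique only; the actual tile ordering used by these kernels is not described here, and both function names are hypothetical.

```python
def non_persistent_schedule(num_tiles):
    """One program per output tile: program i handles exactly tile i."""
    return [[tile] for tile in range(num_tiles)]


def persistent_schedule(num_tiles, num_programs):
    """A fixed pool of programs; each loops over tiles with a stride of num_programs."""
    return [list(range(pid, num_tiles, num_programs)) for pid in range(num_programs)]


non_persistent = non_persistent_schedule(6)
persistent = persistent_schedule(num_tiles=6, num_programs=4)
```

With 6 tiles and 4 programs, the persistent schedule gives programs 0 and 1 two tiles each and programs 2 and 3 one tile each, whereas the non-persistent schedule launches 6 programs.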
Example Usage
from triton_dist.kernels.nvidia import ag_gemm, create_ag_gemm_context
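# Create context (A is the local [M_per_rank, K] shard, B the [N, K] weight;
# rank, world_size, and M are assumed to be defined by the caller)
ctx = create_ag_gemm_context(A, B, rank, world_size, max_M=M,
                             BLOCK_M=128, BLOCK_N=256, BLOCK_K=64, stages=3)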
# Perform AllGather + GEMM
output = ag_gemm(A, B, ctx)