AllGather GEMM

The AllGather GEMM (AG-GEMM) kernel fuses the AllGather collective with the GEMM that consumes its result, so that communication and computation overlap instead of running back to back.

API Reference

ag_gemm(a, b, ctx, ...)

Gathers the local input tensors from all ranks and multiplies the gathered result by the weight, overlapping the communication with the computation.

Parameters:
  • a – Local input tensor of shape [M_per_rank, K]

  • b – Weight tensor of shape [N, K]

  • ctx – AGGemmContext containing symmetric memory and signals

Returns:

Output tensor of shape [M, N], where M is the total number of rows gathered from all ranks.
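The shapes above imply the following reference semantics, shown here as a minimal pure-Python sketch rather than the library's implementation. It assumes that the per-rank inputs are concatenated along the M dimension in rank order and that b is stored as [N, K], so the product is the gathered matrix times b transposed. The names matmul_nt and ag_gemm_reference are illustrative only.

```python
def matmul_nt(a, b):
    """Return a @ b^T for nested lists a of shape [M, K] and b of shape [N, K]."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


def ag_gemm_reference(a_shards, b):
    """Unfused reference: gather every rank's shard along M, then multiply by b^T.

    Assumption: shards are concatenated in rank order (rank 0 first).
    """
    gathered = [row for shard in a_shards for row in shard]  # [M, K]
    return matmul_nt(gathered, b)  # [M, N]


# Two simulated ranks, each holding a [2, 3] shard (M_per_rank=2, K=3).
a_shards = [
    [[1, 0, 0], [0, 1, 0]],  # rank 0
    [[0, 0, 1], [1, 1, 1]],  # rank 1
]
b = [[1, 2, 3], [4, 5, 6]]  # [N=2, K=3]

out = ag_gemm_reference(a_shards, b)  # shape [M=4, N=2]
```

A fused kernel should produce the same values as this unfused reference, up to floating-point rounding, which makes a comparison of this kind a convenient correctness check.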

create_ag_gemm_context(local_tensor, weight, rank, num_ranks, max_M, BLOCK_M, BLOCK_N, BLOCK_K, stages)

Creates the context for an AG-GEMM operation.

Parameters:
  • local_tensor – Sample local tensor for shape inference

  • weight – Weight tensor

  • rank – Current rank ID

  • num_ranks – Total number of ranks

  • max_M – Maximum M dimension

  • BLOCK_M – Block size in M dimension

  • BLOCK_N – Block size in N dimension

  • BLOCK_K – Block size in K dimension

  • stages – Number of pipeline stages

Returns:

AGGemmContext object
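To give a feel for what the block-size and stage parameters control, the sketch below computes the output tile grid and a rough staging-buffer footprint. It is not part of the library: tile_config is a hypothetical helper, and it assumes 2-byte elements (such as fp16) and one A tile plus one B tile buffered per pipeline stage, which is a common software-pipelining layout but is not confirmed for this kernel.

```python
def cdiv(a, b):
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def tile_config(M, N, BLOCK_M, BLOCK_N, BLOCK_K, stages, elem_bytes=2):
    """Hypothetical helper: tile grid and estimated staging-buffer size.

    Assumption: each stage buffers one [BLOCK_M, BLOCK_K] tile of A and
    one [BLOCK_N, BLOCK_K] tile of B.
    """
    tiles_m = cdiv(M, BLOCK_M)
    tiles_n = cdiv(N, BLOCK_N)
    bytes_per_stage = (BLOCK_M * BLOCK_K + BLOCK_N * BLOCK_K) * elem_bytes
    return {
        "tiles_m": tiles_m,
        "tiles_n": tiles_n,
        "num_tiles": tiles_m * tiles_n,
        "staging_bytes": stages * bytes_per_stage,
    }


# Example sizes (M and N are arbitrary illustrative values).
cfg = tile_config(M=4096, N=8192, BLOCK_M=128, BLOCK_N=256, BLOCK_K=64, stages=3)
```

Under these assumptions the example works out to 1024 output tiles and 147456 bytes (144 KiB) of staging buffer. An estimate like this can help check whether a block-size and stage combination is likely to fit in the shared memory of the target GPU.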

gemm_persistent(...)

Persistent GEMM kernel that consumes AllGather results.

gemm_non_persistent(...)

Non-persistent GEMM kernel variant.
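The difference between the two variants can be illustrated with a small scheduling sketch. It assumes the common meaning of the terms: a non-persistent kernel launches one program per output tile, while a persistent kernel launches a fixed number of programs that each loop over several tiles. It also assumes row-major tile ordering and that M_per_rank is a multiple of BLOCK_M, so the rows of each tile come from exactly one rank. All function names are hypothetical, and the actual kernels may order tiles differently.

```python
def non_persistent_schedule(num_tiles):
    """One program per output tile: program i computes tile i."""
    return {pid: [pid] for pid in range(num_tiles)}


def persistent_schedule(num_tiles, num_programs):
    """Fixed program count: program pid handles tiles pid, pid + P, pid + 2P, ..."""
    return {pid: list(range(pid, num_tiles, num_programs))
            for pid in range(num_programs)}


def source_rank(tile_id, tiles_n, BLOCK_M, M_per_rank):
    """Rank whose shard supplies the rows of a tile (row-major tile order).

    A tile could begin only after this rank's shard has arrived.
    """
    tile_m = tile_id // tiles_n
    return (tile_m * BLOCK_M) // M_per_rank


# 8 tiles along M, 2 tiles along N, 4 programs, 4 ranks with 256 rows each.
tiles_m, tiles_n = 8, 2
schedule = persistent_schedule(tiles_m * tiles_n, num_programs=4)
ranks_for_pid0 = [source_rank(t, tiles_n, BLOCK_M=128, M_per_rank=256)
                  for t in schedule[0]]
```

In this toy configuration program 0 handles tiles 0, 4, 8, and 12, whose rows come from ranks 0, 1, 2, and 3 in turn, so a single long-lived program can work through data from each rank as it arrives.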

Example Usage

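from triton_dist.kernels.nvidia import ag_gemm, create_ag_gemm_context

# A: local input tensor [M_per_rank, K]; B: weight tensor [N, K]
# rank, world_size: this process's rank ID and the total number of ranks
# M: maximum M dimension (passed as max_M)

# Create the context (symmetric memory and signals) for the chosen tile configuration
ctx = create_ag_gemm_context(A, B, rank, world_size, max_M=M,
                              BLOCK_M=128, BLOCK_N=256, BLOCK_K=64, stages=3)

# Perform the overlapped AllGather + GEMM; output has shape [M, N]
output = ag_gemm(A, B, ctx)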