AllGather GEMM (AMD)

A fused AllGather + GEMM kernel for AMD GPUs that overlaps intra-node communication with computation.

API Reference

ag_gemm_intra_node(a, b, ctx, ...)

Gathers the local shards of a from all ranks and multiplies the gathered matrix by b, overlapping the AllGather with the GEMM on AMD GPUs.

Parameters:
  • a – Local input tensor of shape [M_per_rank, K]

  • b – Weight tensor of shape [N, K]

  • ctx – AGGemmIntraNodeContext, as returned by create_ag_gemm_intra_node_context

Returns:

Output tensor of shape [M, N], where M is the gathered M dimension (the M_per_rank rows of every rank combined)
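
As a mental model of what the fused kernel computes, the sketch below emulates the operation in a single process with plain Python lists. It stacks the per-rank shards of a along the M dimension and multiplies the result by the transpose of b. The transpose is inferred from the documented shapes ([M, K] and [N, K] producing [M, N]) rather than stated on this page, and the name reference_ag_gemm is hypothetical, not part of triton_dist.

```python
def reference_ag_gemm(shards, b):
    """Single-process reference: AllGather along M, then A @ B^T.

    shards: list (one entry per rank) of [M_per_rank][K] row-major matrices.
    b:      [N][K] weight matrix.
    Returns the [M][N] product, with M = total number of shard rows.
    """
    # "AllGather": stack every rank's rows in rank order.
    gathered = [row for shard in shards for row in shard]
    k = len(b[0])
    if any(len(row) != k for row in gathered):
        raise ValueError("every shard row must have K columns, matching b")
    # GEMM against b transposed: out[i][j] = sum_k gathered[i][k] * b[j][k].
    return [[sum(x * w for x, w in zip(row, b_row)) for b_row in b]
            for row in gathered]


# Two ranks, M_per_rank = 1, K = 2, N = 3.
shards = [[[1, 2]], [[3, 4]]]
b = [[1, 0], [0, 1], [1, 1]]
out = reference_ag_gemm(shards, b)
```

A reference like this is only useful for checking shapes and small numerical results; it performs no communication and no overlapping.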

create_ag_gemm_intra_node_context(max_M, N, K, input_dtype, output_dtype, rank, world_size, tp_group, M_PER_CHUNK=256)

Creates the context for AG-GEMM intra-node operation on AMD GPUs.

Parameters:
  • max_M – Maximum M dimension the context must support

  • N – N dimension (rows of b, columns of the output)

  • K – K dimension (inner dimension shared by a and b)

  • input_dtype – Input data type

  • output_dtype – Output data type

  • rank – Current rank ID

  • world_size – Total number of ranks

  • tp_group – Tensor parallel process group

  • M_PER_CHUNK – Chunk size used to overlap communication with computation (default: 256)

Returns:

AGGemmIntraNodeContext object
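
The sizing arguments can be sanity-checked before a context is built. The sketch below assumes, based on the usage example on this page, that max_M bounds the gathered M dimension (M_per_rank times world_size). It further assumes that M_PER_CHUNK splits work along M. Neither assumption is confirmed here, and the helper plan_ag_gemm_shapes is hypothetical, not part of triton_dist.

```python
import math


def plan_ag_gemm_shapes(m_per_rank, n, k, world_size, max_m, m_per_chunk=256):
    """Derive the shapes implied by the documented parameters (hypothetical helper)."""
    m = m_per_rank * world_size  # rows after the AllGather
    if m > max_m:
        raise ValueError(f"gathered M={m} exceeds max_M={max_m}")
    return {
        "a_local": (m_per_rank, k),
        "a_gathered": (m, k),
        "b": (n, k),
        "output": (m, n),
        # Assumption: chunks are taken along M; the last one may be partial.
        "chunks_per_rank": math.ceil(m_per_rank / m_per_chunk),
    }


plan = plan_ag_gemm_shapes(m_per_rank=1024, n=8192, k=29568,
                           world_size=8, max_m=8192)
```

With these example sizes the gathered input is 8192 x 29568 and the output is 8192 x 8192.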

Example Usage

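import torch

from triton_dist.kernels.amd import ag_gemm_intra_node, create_ag_gemm_intra_node_context

# Assumes the distributed environment is already initialized, so that rank,
# world_size, and tp_group are defined, and that M, N, K hold the problem sizes.

# Create context (M is passed as max_M)
ctx = create_ag_gemm_intra_node_context(M, N, K, torch.float16, torch.float16,
                                         rank, world_size, tp_group)

# Perform AllGather + GEMM
# input: local shard [M_per_rank, K]; weight: [N, K]; output: [M, N]
output = ag_gemm_intra_node(input, weight, ctx)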

Running the Test

bash scripts/launch_amd.sh python/triton_dist/test/amd/test_ag_gemm_intra_node.py 8192 8192 29568