AllGather GEMM (AMD)
AllGather + GEMM kernel for AMD GPUs with intra-node computation-communication overlapping.
API Reference
- ag_gemm_intra_node(a, b, ctx, ...)
Performs AllGather followed by GEMM on AMD GPUs, overlapping communication with computation.
- Parameters:
a – Local input tensor of shape [M_per_rank, K]
b – Weight tensor of shape [N, K]
ctx – AGGemmIntraNodeContext
- Returns:
Output tensor of shape [M, N]
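The shapes above imply that the result equals the per-rank inputs concatenated along the first dimension, multiplied by the transpose of the weight. The following is a minimal single-process sketch of that reference semantics, assuming output = gathered_a @ b^T. It uses nested lists in place of tensors, and `reference_ag_gemm` and `shards` are hypothetical names for illustration, not part of triton_dist.

```python
def reference_ag_gemm(shards, b):
    """Reference result for AllGather + GEMM, assuming output = gathered_a @ b^T.

    shards: one [M_per_rank, K] matrix per rank (hypothetical stand-in for each rank's `a`)
    b:      [N, K] weight matrix
    returns an [M, N] matrix with M = M_per_rank * len(shards)
    """
    # AllGather: concatenate the per-rank shards along the row (M) dimension.
    gathered = [row for shard in shards for row in shard]
    # GEMM against b^T: out[i][j] = dot(gathered[i], b[j]).
    return [[sum(x * w for x, w in zip(row, b_row)) for b_row in b]
            for row in gathered]


# Two ranks, M_per_rank = 1, K = 2, N = 2.
shards = [[[1, 2]], [[3, 4]]]
b = [[1, 0], [0, 1]]  # identity weight, so the output equals the gathered input
out = reference_ag_gemm(shards, b)
```

A reference like this can serve as a correctness check for the fused kernel on small shapes; it says nothing about how the kernel overlaps the two phases.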
- create_ag_gemm_intra_node_context(max_M, N, K, input_dtype, output_dtype, rank, world_size, tp_group, M_PER_CHUNK=256)
Creates the context for AG-GEMM intra-node operation on AMD GPUs.
- Parameters:
max_M – Maximum M dimension
N – N dimension
K – K dimension
input_dtype – Input data type
output_dtype – Output data type
rank – Current rank ID
world_size – Total number of ranks
tp_group – Tensor parallel process group
M_PER_CHUNK – Chunk size for overlapping
- Returns:
AGGemmIntraNodeContext object
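To illustrate what a chunk size such as M_PER_CHUNK could mean, the sketch below splits the M dimension into consecutive row ranges of at most that many rows. It assumes chunking happens along rows; how the kernel actually schedules chunks is not described here, and `chunk_ranges` is a hypothetical helper, not a triton_dist function.

```python
def chunk_ranges(m, m_per_chunk=256):
    """Split rows [0, m) into consecutive [start, end) ranges of at most m_per_chunk rows."""
    if m_per_chunk <= 0:
        raise ValueError("m_per_chunk must be positive")
    # The last range is shorter when m is not a multiple of m_per_chunk.
    return [(start, min(start + m_per_chunk, m))
            for start in range(0, m, m_per_chunk)]


ranges = chunk_ranges(600)
```

With row ranges like these, the GEMM for one chunk can in principle start as soon as that chunk's rows have arrived, while later chunks are still being communicated.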
Example Usage
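import torch
from triton_dist.kernels.amd import ag_gemm_intra_node, create_ag_gemm_intra_node_context
# Create context (M is passed as max_M; rank, world_size, and tp_group
# are assumed to come from an already initialized distributed setup)
ctx = create_ag_gemm_intra_node_context(M, N, K, torch.float16, torch.float16,
                                        rank, world_size, tp_group)
# Perform AllGather + GEMM: input is [M_per_rank, K], weight is [N, K]
output = ag_gemm_intra_node(input, weight, ctx)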
Running the Test
bash scripts/launch_amd.sh python/triton_dist/test/amd/test_ag_gemm_intra_node.py 8192 8192 29568