GEMM ReduceScatter (AMD)
GEMM + ReduceScatter kernel for AMD GPUs that overlaps computation with intra-node communication.
API Reference
- gemm_rs_intra_node(a, b, ctx, ...)
Performs GEMM followed by ReduceScatter with overlapping on AMD GPUs.
- Parameters:
a – Input tensor of shape [M, K]
b – Weight tensor of shape [N, K]
ctx – GemmRSIntraNodeContext
- Returns:
Output tensor of shape [M/world_size, N]
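The shapes above can be illustrated with a small single-process reference in NumPy. This is only a sketch of the expected result, not the kernel itself. It assumes that each rank holds its own [M, K] input and [N, K] weight (a shard of the reduction dimension), that the local GEMM is a @ b.T (inferred from the weight shape), and that the summed result is split evenly along M. The name gemm_rs_reference and the list-of-shards layout are hypothetical, introduced here for illustration.

```python
import numpy as np


def gemm_rs_reference(a_shards, b_shards):
    """Reference GEMM + ReduceScatter over simulated ranks (hypothetical helper).

    a_shards[r]: [M, K] input held by rank r (assumed layout)
    b_shards[r]: [N, K] weight held by rank r (assumed layout)
    Returns a list whose entry r is the [M / world_size, N] slice for rank r.
    """
    world_size = len(a_shards)
    M = a_shards[0].shape[0]
    assert M % world_size == 0, "this sketch assumes M is divisible by world_size"

    # Local GEMM on every rank: [M, K] @ [K, N] -> [M, N]
    partials = [a @ b.T for a, b in zip(a_shards, b_shards)]

    # Reduce: element-wise sum of the partial results across ranks
    total = np.sum(np.stack(partials), axis=0)

    # Scatter: rank r keeps rows [r * rows, (r + 1) * rows)
    rows = M // world_size
    return [total[r * rows:(r + 1) * rows] for r in range(world_size)]


# Simulate 4 ranks with small random shards
rng = np.random.default_rng(0)
world_size, M, N, K = 4, 8, 6, 5
a_shards = [rng.standard_normal((M, K)) for _ in range(world_size)]
b_shards = [rng.standard_normal((N, K)) for _ in range(world_size)]
outs = gemm_rs_reference(a_shards, b_shards)
```

A reference like this can serve as a correctness check: the slice for a given rank should match that rank's output of gemm_rs_intra_node up to floating-point tolerance, provided the assumptions above hold for the real kernel.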
- create_gemm_rs_intra_node_context(max_M, N, output_dtype, rank, world_size, tp_group, fuse_scatter=True)
Creates the context for GEMM-RS intra-node operation on AMD GPUs.
- Parameters:
max_M – Maximum M dimension
N – N dimension
output_dtype – Output data type
rank – Current rank ID
world_size – Total number of ranks
tp_group – Tensor parallel process group
fuse_scatter – Whether to fuse the scatter into the GEMM (default: True)
- Returns:
GemmRSIntraNodeContext object
Example Usage
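import torch
from triton_dist.kernels.amd import gemm_rs_intra_node, create_gemm_rs_intra_node_context

# Create context (M, N, rank, world_size and tp_group come from the caller's distributed setup)
ctx = create_gemm_rs_intra_node_context(M, N, torch.float16,
                                        rank, world_size, tp_group)
# Perform GEMM + ReduceScatter: input is [M, K], weight is [N, K]
output = gemm_rs_intra_node(input, weight, ctx)  # shape [M/world_size, N]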