GEMM ReduceScatter (AMD)

GEMM + ReduceScatter kernel for AMD GPUs that overlaps computation with communication within a single node.
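
As a rough illustration of why overlap helps, the sketch below models an idealized two-stage pipeline. The output is processed in tiles, and one tile's communication runs while the next tile is computed. It is a simplified cost model with assumed per-tile costs, not a description of this kernel's actual scheduling. The function names are illustrative.

```python
def serial_time(num_tiles: int, compute_per_tile: float, comm_per_tile: float) -> float:
    """No overlap: every tile is computed, then every tile is communicated."""
    return num_tiles * (compute_per_tile + comm_per_tile)


def overlapped_time(num_tiles: int, compute_per_tile: float, comm_per_tile: float) -> float:
    """Idealized two-stage pipeline: tile i is communicated while tile i+1 is computed.

    The first tile's compute and the last tile's communication cannot be hidden;
    each step in between is bounded by the slower of the two stages.
    """
    if num_tiles == 0:
        return 0.0
    return (compute_per_tile
            + (num_tiles - 1) * max(compute_per_tile, comm_per_tile)
            + comm_per_tile)


# Assumed costs: 8 tiles, 1.0 time unit of compute and 0.5 of communication per tile.
print(serial_time(8, 1.0, 0.5), overlapped_time(8, 1.0, 0.5))  # 12.0 8.5
```

When per-tile communication is cheaper than per-tile compute, almost all of it is hidden, and the overlapped time approaches the compute-only time plus a single tile of communication.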

API Reference

gemm_rs_intra_node(a, b, ctx, ...)

Computes a GEMM followed by a ReduceScatter on AMD GPUs, overlapping the GEMM computation with the ReduceScatter communication.

Parameters:
  • a – Input tensor of shape [M, K]

  • b – Weight tensor of shape [N, K]

  • ctx – GemmRSIntraNodeContext returned by create_gemm_rs_intra_node_context

Returns:

Output tensor of shape [M/world_size, N] holding this rank's shard of the reduced result
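
The shapes above suggest the usual tensor-parallel pattern. The following NumPy sketch emulates the intended result in a single process, as a reference for checking outputs, not as the AMD kernel itself. It assumes, beyond what this reference states, that each rank holds its own [M, K] slice of a and [N, K] slice of b along a sharded K dimension, that the partial products a @ b.T are summed across ranks, and that rank r receives the r-th contiguous block of M // world_size rows. The name gemm_rs_reference is hypothetical.

```python
import numpy as np


def gemm_rs_reference(a_shards, b_shards):
    """Hypothetical single-process emulation of GEMM + ReduceScatter.

    Assumes rank r holds a_shards[r] with shape [M, K] and b_shards[r] with
    shape [N, K]; returns one [M // world_size, N] array per simulated rank.
    """
    world_size = len(a_shards)
    # Local GEMM on every rank: [M, K] x [K, N] -> [M, N] partial result.
    partials = [a @ b.T for a, b in zip(a_shards, b_shards)]
    # Reduce: element-wise sum of the partial results across ranks.
    reduced = np.sum(partials, axis=0)
    M = reduced.shape[0]
    assert M % world_size == 0, "M must be divisible by world_size"
    # Scatter: rank r keeps rows [r * M // world_size, (r + 1) * M // world_size).
    return np.split(reduced, world_size, axis=0)


rng = np.random.default_rng(0)
world_size, M, N, K = 4, 8, 3, 5
a_shards = [rng.standard_normal((M, K)) for _ in range(world_size)]
b_shards = [rng.standard_normal((N, K)) for _ in range(world_size)]
outputs = gemm_rs_reference(a_shards, b_shards)
print([o.shape for o in outputs])  # [(2, 3), (2, 3), (2, 3), (2, 3)]
```

Because the per-rank K slices together form the full K dimension, concatenating the outputs in rank order should equal the unsharded product, which makes this a convenient correctness check.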

create_gemm_rs_intra_node_context(max_M, N, output_dtype, rank, world_size, tp_group, fuse_scatter=True)

Creates the context used by gemm_rs_intra_node on AMD GPUs.

Parameters:
  • max_M – Maximum M dimension (number of rows of a) that the context supports

  • N – N dimension (number of rows of b and columns of the output)

  • output_dtype – Output data type

  • rank – Current rank ID

  • world_size – Total number of ranks

  • tp_group – Tensor parallel process group

  • fuse_scatter – Whether to fuse the scatter step into the GEMM kernel (default: True)

Returns:

GemmRSIntraNodeContext object
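
Before calling gemm_rs_intra_node, it can help to check a call against the context it was created with. The helper below is hypothetical and not part of triton_dist. Its checks are inferred from the parameter descriptions and return shape above: K must match between a and b, b must have N rows, M may not exceed max_M, and M must split evenly across world_size ranks. The actual kernel may impose further constraints, such as tile alignment, that are not covered here.

```python
from typing import Tuple


def expected_output_shape(a_shape: Tuple[int, int], b_shape: Tuple[int, int],
                          max_M: int, N: int, world_size: int) -> Tuple[int, int]:
    """Hypothetical helper: validate shapes against the documented contract.

    a_shape is [M, K] and b_shape is [N, K]; the result shape is [M // world_size, N].
    """
    M, K = a_shape
    N_b, K_b = b_shape
    if K != K_b:
        raise ValueError(f"K mismatch: a has K={K}, b has K={K_b}")
    if N_b != N:
        raise ValueError(f"b has N={N_b}, but the context was created with N={N}")
    if M > max_M:
        raise ValueError(f"M={M} exceeds the context's max_M={max_M}")
    if M % world_size != 0:
        raise ValueError(f"M={M} is not divisible by world_size={world_size}")
    return (M // world_size, N)


print(expected_output_shape((4096, 8192), (8192, 8192),
                            max_M=8192, N=8192, world_size=8))  # (512, 8192)
```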

Example Usage

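import torch
import torch.distributed as dist
from triton_dist.kernels.amd import gemm_rs_intra_node, create_gemm_rs_intra_node_context
# Assumes torch.distributed is already initialized on one node (e.g. via torchrun)
tp_group = dist.group.WORLD
rank, world_size = dist.get_rank(tp_group), dist.get_world_size(tp_group)
torch.cuda.set_device(rank)  # ROCm builds of PyTorch expose AMD GPUs as "cuda"
M, N, K = 4096, 8192, 8192  # illustrative sizes; M divisible by world_size

# Create context
ctx = create_gemm_rs_intra_node_context(M, N, torch.float16,
                                         rank, world_size, tp_group)
# Perform GEMM + ReduceScatter
a = torch.randn(M, K, dtype=torch.float16, device="cuda")  # [M, K]
weight = torch.randn(N, K, dtype=torch.float16, device="cuda")  # [N, K]
output = gemm_rs_intra_node(a, weight, ctx)  # [M // world_size, N]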