GEMM ReduceScatter

The GEMM ReduceScatter (GEMM-RS) kernel fuses a GEMM with a ReduceScatter collective, so that the communication overlaps with the GEMM computation instead of running after it.
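The benefit of this overlap can be estimated with an idealized two-stage pipeline model. The sketch below makes three assumptions. First, the output is produced in chunks. Second, a chunk's ReduceScatter can start once that chunk has been computed and the previous chunk's communication has finished. Third, compute and communication never slow each other down. The chunking, the function names and the timings are illustrative assumptions. They do not describe how the kernel actually schedules its work.

```python
from typing import Sequence


def serial_time(compute: Sequence[float], comm: Sequence[float]) -> float:
    """No overlap: every chunk is computed first, then all communication runs."""
    return sum(compute) + sum(comm)


def overlapped_time(compute: Sequence[float], comm: Sequence[float]) -> float:
    """Idealized two-stage pipeline (assumption: compute and communication
    run concurrently on independent resources).

    Chunks are computed back to back. Chunk i's communication starts once
    chunk i is computed AND chunk i-1's communication has finished.
    """
    compute_done = 0.0
    comm_done = 0.0
    for c, m in zip(compute, comm):
        compute_done += c
        comm_done = max(comm_done, compute_done) + m
    return comm_done


if __name__ == "__main__":
    # Hypothetical timings: 4 chunks, 4 units of compute and 3 units of
    # ReduceScatter each.
    compute = [4] * 4
    comm = [3] * 4
    print(serial_time(compute, comm))      # 28
    print(overlapped_time(compute, comm))  # 19.0
```

With these hypothetical timings, the overlapped schedule finishes in 19 units instead of 28. When communication dominates, the total approaches the first chunk's compute time plus the total communication time.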

API Reference

gemm_rs(a, b, ctx, ...)

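Computes a @ b.T on each rank and sums the partial results across all ranks. It then scatters the sum along the M dimension, so each rank receives M/world_size rows. The ReduceScatter is overlapped with the GEMM.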

Parameters:
  • a – Input tensor of shape [M, K]

  • b – Weight tensor of shape [N, K]

  • ctx – GemmRSContext created by create_gemm_rs_context. It holds the symmetric memory buffers and synchronization signals.

Returns:

Output tensor of shape [M/world_size, N], holding this rank's slice of the reduced result
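To make these semantics concrete, the following pure-Python sketch computes what gemm_rs is documented to return, without any overlap. It simulates every rank in a single process. It assumes that M is divisible by world_size and that rank r receives the r-th contiguous block of rows. matmul_bt and reference_gemm_rs are hypothetical helpers and are not part of triton_dist.

```python
from typing import List

Matrix = List[List[float]]


def matmul_bt(a: Matrix, b: Matrix) -> Matrix:
    """Return a @ b.T for a of shape [M, K] and b of shape [N, K]."""
    k = len(a[0])
    assert all(len(row) == k for row in b), "a and b must share the K dimension"
    return [[sum(x * y for x, y in zip(a_row, b_row)) for b_row in b] for a_row in a]


def reference_gemm_rs(a_per_rank: List[Matrix], b_per_rank: List[Matrix]) -> List[Matrix]:
    """Unfused reference for GEMM + ReduceScatter over simulated ranks.

    Rank r holds a_r: [M, K] and b_r: [N, K]. The partial products a_r @ b_r.T
    are summed element-wise across ranks. Rank r then receives rows
    [r * M / world_size, (r + 1) * M / world_size) of that sum (assumed layout).
    """
    world_size = len(a_per_rank)
    partials = [matmul_bt(a, b) for a, b in zip(a_per_rank, b_per_rank)]
    m, n = len(partials[0]), len(partials[0][0])
    assert m % world_size == 0, "this sketch assumes M is divisible by world_size"
    # Reduce: element-wise sum of the per-rank [M, N] partial products.
    reduced = [[sum(p[i][j] for p in partials) for j in range(n)] for i in range(m)]
    # Scatter: each rank keeps its own contiguous block of M / world_size rows.
    rows = m // world_size
    return [reduced[r * rows:(r + 1) * rows] for r in range(world_size)]
```

For example, with two simulated ranks and M = 2, K = 1, N = 2, reference_gemm_rs([[[1], [2]], [[3], [4]]], [[[1], [10]], [[1], [1]]]) returns [[[4, 13]], [[6, 24]]]. Each rank receives one [1, 2] row of the summed [2, 2] product.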

create_gemm_rs_context(max_M, N, rank, world_size, local_world_size, output_dtype, rs_stream=None)

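Creates the context for a GEMM-RS operation. The context holds the symmetric memory buffers and signals used by gemm_rs and is sized for inputs with at most max_M rows.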

Parameters:
  • max_M – Maximum M dimension (rows of a) that the context must support

  • N – N dimension (columns of the output)

  • rank – Global rank of the current process

  • world_size – Total number of ranks

  • local_world_size – Number of ranks per node

  • output_dtype – Output data type

  • rs_stream – Optional CUDA stream for ReduceScatter

Returns:

GemmRSContext object
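The rank, world_size and local_world_size arguments describe how processes are laid out across nodes. The hypothetical GemmRSTopology dataclass below is not part of triton_dist. It spells out the relationships assumed here:

  • local_world_size divides world_size
  • ranks are numbered contiguously within each node
  • gemm_rs is called with M no larger than max_M

These are illustrative checks, not the library's actual validation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GemmRSTopology:
    """Hypothetical helper mirroring create_gemm_rs_context's sizing arguments."""
    max_M: int
    N: int
    rank: int
    world_size: int
    local_world_size: int

    def __post_init__(self) -> None:
        if not 0 <= self.rank < self.world_size:
            raise ValueError("rank must be in [0, world_size)")
        if self.world_size % self.local_world_size != 0:
            raise ValueError("world_size must be a multiple of local_world_size")

    @property
    def num_nodes(self) -> int:
        return self.world_size // self.local_world_size

    @property
    def node_id(self) -> int:
        # Assumption: ranks are numbered contiguously within each node.
        return self.rank // self.local_world_size

    @property
    def local_rank(self) -> int:
        return self.rank % self.local_world_size

    def output_rows(self, M: int) -> int:
        """Rows of the [M/world_size, N] output owned by this rank."""
        if M > self.max_M:
            raise ValueError(f"M={M} exceeds max_M={self.max_M} used to build the context")
        if M % self.world_size != 0:
            raise ValueError("this sketch assumes M is divisible by world_size")
        return M // self.world_size
```

For example, GemmRSTopology(max_M=8192, N=4096, rank=10, world_size=16, local_world_size=8) describes rank 2 on the second of two nodes. Its output_rows(8192) is 512.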

Example Usage

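import os
import torch
import torch.distributed as dist
from triton_dist.kernels.nvidia import gemm_rs, create_gemm_rs_context

# Assumes a torchrun launch with the distributed environment already initialized
rank = dist.get_rank()
world_size = dist.get_world_size()
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])  # set by torchrun

# Create context (M is passed as max_M, the largest row count gemm_rs will see)
ctx = create_gemm_rs_context(M, N, rank, world_size, local_world_size,
                             output_dtype=torch.bfloat16)

# Perform GEMM + ReduceScatter: x is [M, K], weight is [N, K], output is [M // world_size, N]
output = gemm_rs(x, weight, ctx)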