GEMM ReduceScatter
The GEMM ReduceScatter kernel fuses the GEMM computation with the ReduceScatter collective, so that computation and communication overlap.
API Reference
- gemm_rs(a, b, ctx, ...)
Performs a GEMM followed by a ReduceScatter, overlapping the two.
- Parameters:
a – Input tensor of shape [M, K]
b – Weight tensor of shape [N, K]
ctx – GemmRSContext containing symmetric memory and signals
- Returns:
Output tensor of shape [M/world_size, N]
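The shapes above can be made concrete with a small single-process reference. This is only a sketch of the semantics, not the fused kernel: it assumes each rank holds its own `a` of shape [M, K] and `b` of shape [N, K], that the GEMM computes `a @ b^T` (inferred from the [N, K] weight shape), and that ranks receive contiguous row blocks in rank order. The helpers `matmul_nt` and `reference_gemm_rs` are hypothetical names introduced for illustration.

```python
def matmul_nt(a, b):
    """Return a @ b^T for a of shape [M, K] and b of shape [N, K] (nested lists)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in b] for row in a]


def reference_gemm_rs(a_per_rank, b_per_rank):
    """Simulate GEMM + ReduceScatter across ranks in one process.

    Assumption: rank r ends up with rows [r * M/world_size, (r + 1) * M/world_size)
    of the element-wise sum of all per-rank GEMM results.
    """
    world_size = len(a_per_rank)
    # Step 1: every rank computes its partial [M, N] product.
    partials = [matmul_nt(a, b) for a, b in zip(a_per_rank, b_per_rank)]
    M, N = len(partials[0]), len(partials[0][0])
    assert M % world_size == 0, "M must be divisible by world_size"
    # Step 2 (reduce): sum the partial products element-wise.
    full = [[sum(p[i][j] for p in partials) for j in range(N)] for i in range(M)]
    # Step 3 (scatter): hand each rank its block of M / world_size rows.
    rows = M // world_size
    return [full[r * rows:(r + 1) * rows] for r in range(world_size)]


# Two ranks, M = 2, N = 2, K = 1 per rank.
a_per_rank = [[[1.0], [2.0]], [[3.0], [4.0]]]
b_per_rank = [[[1.0], [0.0]], [[0.0], [1.0]]]
outputs = reference_gemm_rs(a_per_rank, b_per_rank)
```

A reference like this can serve as a correctness oracle for small inputs; the fused kernel is expected to produce the same per-rank result while overlapping the two phases.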
- create_gemm_rs_context(max_M, N, rank, world_size, local_world_size, output_dtype, rs_stream=None)
Creates the context for a GEMM-RS operation.
- Parameters:
max_M – Maximum M dimension
N – N dimension
rank – Current rank ID
world_size – Total number of ranks
local_world_size – Number of ranks per node
output_dtype – Output data type
rs_stream – Optional CUDA stream for ReduceScatter
- Returns:
GemmRSContext object
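The `rank`, `world_size`, and `local_world_size` arguments usually come from the process launcher. The sketch below shows one way to derive them, assuming a launcher such as `torchrun` that exports the `RANK`, `WORLD_SIZE`, and `LOCAL_WORLD_SIZE` environment variables, and assuming ranks are numbered contiguously within each node. `rank_layout` is a hypothetical helper, not part of the library.

```python
import os


def rank_layout(env=None):
    """Derive rank/node layout values from launcher-style environment variables."""
    env = os.environ if env is None else env
    rank = int(env.get("RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    # Fall back to a single-node layout if LOCAL_WORLD_SIZE is not set.
    local_world_size = int(env.get("LOCAL_WORLD_SIZE", str(world_size)))
    if world_size % local_world_size != 0:
        raise ValueError("world_size must be a multiple of local_world_size")
    return {
        "rank": rank,
        "world_size": world_size,
        "local_world_size": local_world_size,
        # Assumption: ranks are assigned contiguously per node.
        "node_id": rank // local_world_size,
        "local_rank": rank % local_world_size,
        "num_nodes": world_size // local_world_size,
    }


# Example: rank 5 in a 2-node job with 4 ranks per node.
layout = rank_layout({"RANK": "5", "WORLD_SIZE": "8", "LOCAL_WORLD_SIZE": "4"})
```

The first three entries of the returned dictionary map directly onto the `rank`, `world_size`, and `local_world_size` parameters listed above.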
Example Usage
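import torch
from triton_dist.kernels.nvidia import gemm_rs, create_gemm_rs_context

# M, N, rank, world_size and local_world_size are assumed to be defined by the caller.
# Create context
ctx = create_gemm_rs_context(M, N, rank, world_size, local_world_size,
                             output_dtype=torch.bfloat16)
# Perform GEMM + ReduceScatter; input has shape [M, K], weight has shape [N, K]
output = gemm_rs(input, weight, ctx)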
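
Before calling the kernel it can help to check shapes up front. The following is a minimal sketch built only from the shapes documented above; the constraints that M must not exceed `max_M` and must be divisible by `world_size` are assumptions inferred from the parameter descriptions, and `expected_gemm_rs_shape` is a hypothetical helper.

```python
def expected_gemm_rs_shape(a_shape, b_shape, max_M, world_size):
    """Validate input shapes and return the expected per-rank output shape."""
    M, K = a_shape
    N, K_b = b_shape
    if K != K_b:
        raise ValueError(f"K mismatch: a has K={K}, b has K={K_b}")
    # Assumption: the context was sized for at most max_M rows.
    if M > max_M:
        raise ValueError(f"M={M} exceeds max_M={max_M}")
    # Assumption: rows are split evenly across ranks.
    if M % world_size != 0:
        raise ValueError(f"M={M} is not divisible by world_size={world_size}")
    return (M // world_size, N)


shape = expected_gemm_rs_shape((4096, 1024), (8192, 1024), max_M=4096, world_size=8)
```

With a real tensor, `tuple(output.shape)` could then be compared against the returned value.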