GEMM AllReduce
Fused GEMM + AllReduce kernels for efficient tensor parallelism.
API Reference
- gemm_allreduce_op(a, b, ctx, ...)
Performs a GEMM followed by an AllReduce, overlapping computation with communication.
- create_gemm_ar_context(...)
Creates the context (ctx) passed to gemm_allreduce_op.
- low_latency_gemm_allreduce_op(a, b, ctx, ...)
Low-latency GEMM + AllReduce for small batch sizes.
- create_ll_gemm_ar_context(...)
Creates the context (ctx) for low_latency_gemm_allreduce_op.
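To make the intended result concrete, here is a minimal reference sketch in plain Python of what a GEMM followed by an AllReduce computes under tensor parallelism. It assumes a sum reduction and inputs sharded along the K dimension, which is the usual row-parallel layout but is not confirmed by this page. The helper names (matmul, shard_k, reference_gemm_allreduce) are hypothetical and are not part of triton_dist; the fused kernels would produce this result on GPUs without materializing the steps separately.

```python
# Reference semantics only: plain Python, no GPU, no triton_dist.
# Assumes a sum AllReduce and inputs sharded along the K dimension.

def matmul(a, b):
    """Multiply an (M x K) matrix by a (K x N) matrix, both as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def shard_k(a, b, world_size):
    """Split a by columns and b by rows into world_size equal K-shards."""
    k = len(b)
    assert k % world_size == 0, "K must divide evenly across ranks"
    step = k // world_size
    return [
        ([row[r * step:(r + 1) * step] for row in a], b[r * step:(r + 1) * step])
        for r in range(world_size)
    ]


def reference_gemm_allreduce(shards):
    """Each simulated rank computes a partial GEMM; the sum stands in for AllReduce."""
    partials = [matmul(a_r, b_r) for a_r, b_r in shards]
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]


a = [[1, 2, 3, 4], [5, 6, 7, 8]]        # M=2, K=4
b = [[1, 0], [0, 1], [2, 0], [0, 2]]    # K=4, N=2
result = reference_gemm_allreduce(shard_k(a, b, world_size=2))
assert result == matmul(a, b)  # sharded-and-summed equals the full product
```

Because every rank's partial product has the full M x N shape, the AllReduce leaves each rank holding the complete output, which is why the two steps are natural candidates for fusion.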
Example Usage
from triton_dist.kernels.nvidia import (
    gemm_allreduce_op,
    create_gemm_ar_context,
)
# Create context
ctx = create_gemm_ar_context(...)
# Perform GEMM + AllReduce
output = gemm_allreduce_op(input, weight, ctx)
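Since the low-latency variant is described as targeting small batch sizes, a caller may want to choose between the two ops per call. The following is a sketch of one way to do that. The helper select_gemm_ar_op and the cutoff SMALL_BATCH_MAX_ROWS are hypothetical and do not come from triton_dist; the right cutoff depends on hardware and problem shape and should be found by benchmarking.

```python
from typing import Callable

# Hypothetical cutoff: not taken from triton_dist; tune by benchmarking.
SMALL_BATCH_MAX_ROWS = 64


def select_gemm_ar_op(
    num_rows: int,
    standard_op: Callable,
    low_latency_op: Callable,
    max_rows: int = SMALL_BATCH_MAX_ROWS,
) -> Callable:
    """Return the low-latency op for small batches, otherwise the standard op."""
    return low_latency_op if num_rows <= max_rows else standard_op
```

In practice the two callables would be gemm_allreduce_op and low_latency_gemm_allreduce_op, each presumably paired with the context created by its matching create_* function.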