GEMM AllReduce

Fused GEMM + AllReduce kernels for tensor parallelism, where each rank computes a partial GEMM result that must be summed across all ranks.
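To make the reduction concrete, here is a minimal pure-Python sketch (no GPUs, no triton_dist) of the math a GEMM + AllReduce computes. It assumes the common row-parallel layout in which the K dimension is split across ranks. The helpers matmul, allreduce_sum, split_k and simulated_gemm_allreduce are hypothetical and simply simulate all ranks inside one process.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain triple-loop matrix multiply: [M, K] x [K, N] -> [M, N]."""
    k_dim = len(b)
    n_dim = len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(k_dim)) for j in range(n_dim)] for row in a]


def allreduce_sum(partials: List[Matrix]) -> Matrix:
    """Element-wise sum of every rank's partial result (what AllReduce computes)."""
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]


def split_k(a: Matrix, b: Matrix, world_size: int):
    """Hypothetical sharding: rank r holds A[:, K_r] and B[K_r, :]."""
    k_dim = len(b)
    assert k_dim % world_size == 0, "K must divide evenly across ranks"
    step = k_dim // world_size
    a_shards = [[row[r * step:(r + 1) * step] for row in a] for r in range(world_size)]
    b_shards = [b[r * step:(r + 1) * step] for r in range(world_size)]
    return a_shards, b_shards


def simulated_gemm_allreduce(a: Matrix, b: Matrix, world_size: int) -> Matrix:
    a_shards, b_shards = split_k(a, b, world_size)
    # Each simulated rank runs a local GEMM on its own K-slice ...
    partials = [matmul(a_r, b_r) for a_r, b_r in zip(a_shards, b_shards)]
    # ... and AllReduce sums the partial outputs, so every rank ends up with A @ B.
    return allreduce_sum(partials)
```

Because the sum over K-slices equals the full product, simulated_gemm_allreduce(a, b, world_size) matches matmul(a, b) for any world_size that divides K.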

API Reference

gemm_allreduce_op(a, b, ctx, ...)

Performs a GEMM followed by an AllReduce of its output, overlapping the communication with the computation.
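A common way to get this kind of overlap is to split the output into chunks so that the reduction of one chunk can proceed while the next chunk is still being computed. The sketch below is a hypothetical, single-process illustration of that pipeline structure, not the triton_dist kernel: a single-worker ThreadPoolExecutor stands in for a communication stream, and because of Python's GIL it demonstrates ordering and correctness rather than a real speedup. The chunk_rows parameter and all function names are assumptions.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain triple-loop matrix multiply: [M, K] x [K, N] -> [M, N]."""
    return [[sum(row[k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for row in a]


def allreduce_chunk(partials: List[Matrix]) -> Matrix:
    """Sum one output chunk across all simulated ranks (the AllReduce step)."""
    return [[sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
            for i in range(len(partials[0]))]


def pipelined_gemm_allreduce(a: Matrix, b: Matrix, world_size: int, chunk_rows: int) -> Matrix:
    k_dim = len(b)
    assert k_dim % world_size == 0, "K must divide evenly across ranks"
    step = k_dim // world_size
    # Rank r owns columns [r*step, (r+1)*step) of A and the matching rows of B.
    a_shards = [[row[r * step:(r + 1) * step] for row in a] for r in range(world_size)]
    b_shards = [b[r * step:(r + 1) * step] for r in range(world_size)]

    pending: List[Future] = []
    # A single worker plays the role of a dedicated communication stream.
    with ThreadPoolExecutor(max_workers=1) as comm:
        for start in range(0, len(a), chunk_rows):
            # Compute this chunk of output rows on every simulated rank ...
            partials = [matmul(a_r[start:start + chunk_rows], b_r)
                        for a_r, b_r in zip(a_shards, b_shards)]
            # ... then hand it to the comm worker and move straight on to the next
            # chunk, so this chunk's reduction can run while the next GEMM chunk computes.
            pending.append(comm.submit(allreduce_chunk, partials))
    # Collect results in submission order so the output rows stay in order.
    out: Matrix = []
    for fut in pending:
        out.extend(fut.result())
    return out
```

Smaller chunks let communication start sooner but add per-chunk overhead, so the chunk size is a tuning trade-off.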

create_gemm_ar_context(...)

Creates the context object passed as ctx to gemm_allreduce_op.

low_latency_gemm_allreduce_op(a, b, ctx, ...)

Low-latency variant of GEMM + AllReduce, intended for small batch sizes.
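One way to see why small batch sizes call for a separate low-latency path: the AllReduce payload is the [M, N] output, so it scales with M, and for small M the fixed per-message latency dominates the transfer time. The sketch uses a simple single-transfer alpha-beta model, time = alpha + bytes / bandwidth. The alpha and bandwidth values are hypothetical placeholders, not measurements, and the model does not describe how low_latency_gemm_allreduce_op is implemented.

```python
def allreduce_payload_bytes(m: int, n: int, dtype_bytes: int) -> int:
    """Size of the [M, N] GEMM output that has to be reduced across ranks."""
    return m * n * dtype_bytes


def alpha_beta_time_us(num_bytes: int, alpha_us: float, bandwidth_gb_s: float) -> float:
    """Single-transfer alpha-beta model: fixed latency plus bytes / bandwidth.

    bandwidth_gb_s is in GB/s (1e9 bytes/s), so 1 GB/s moves 1e3 bytes per microsecond.
    """
    return alpha_us + num_bytes / (bandwidth_gb_s * 1e3)


def latency_fraction(num_bytes: int, alpha_us: float, bandwidth_gb_s: float) -> float:
    """Share of the modeled time spent on the fixed latency term."""
    return alpha_us / alpha_beta_time_us(num_bytes, alpha_us, bandwidth_gb_s)


if __name__ == "__main__":
    ALPHA_US = 5.0          # hypothetical per-message latency
    BANDWIDTH_GB_S = 100.0  # hypothetical link bandwidth
    for m in (1, 16, 8192):
        nbytes = allreduce_payload_bytes(m, 8192, 2)  # N = 8192, 2-byte dtype such as bf16
        print(m, nbytes, round(latency_fraction(nbytes, ALPHA_US, BANDWIDTH_GB_S), 3))
```

Under these placeholder numbers the latency term accounts for most of the time at M = 1 and almost none of it at M = 8192, which is the regime split the low-latency variant targets.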

create_ll_gemm_ar_context(...)

Creates the context object passed as ctx to low_latency_gemm_allreduce_op.

Example Usage

from triton_dist.kernels.nvidia import (
    gemm_allreduce_op,
    create_gemm_ar_context
)

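# Create the context (constructor arguments elided)
ctx = create_gemm_ar_context(...)

# a: this rank's input tensor, b: this rank's weight tensor
# (avoid naming a variable "input", which shadows the Python builtin)
# Perform GEMM + AllReduce; every rank receives the reduced output
output = gemm_allreduce_op(a, b, ctx)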