Overlapping GEMM ReduceScatter on AMD GPU
In this tutorial, you will write a fused GEMM + ReduceScatter op using Triton-distributed.
In doing so, you will learn:
How to overlap the reduce-scatter with the GEMM computation to hide communication latency on AMD GPUs.
# To run this tutorial
bash ./scripts/launch_amd.sh tutorials/10-AMD-overlapping-gemm-reduce-scatter.py
Kernel
@triton_dist.jit
def kernel_gemm_rs_producer_fuse_scatter(
        # Pointers to the operands and to the table of scatter-buffer addresses
        a_ptr, b_ptr, scatter_bufs_ptr, rank, num_ranks,
        # Matrix dimensions
        M, N, K,
        # Strides: how many elements to advance a pointer when moving by one
        # element along the named dimension (e.g. `stride_am` moves one row down in A).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
        GROUP_SIZE_M: tl.constexpr, M_PER_COPY_CHUNK: tl.constexpr  #
):
    """Compute one [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of A x B and store it into a scatter buffer.

    A is addressed as (M, K) and B as (K, N) through the given strides. Each
    program accumulates its tile in fp32, casts it to the element type of
    `a_ptr`, and stores it through the base address read from
    `scatter_bufs_ptr[rank_offset]`, in the row slab that starts at
    `M_per_rank * rank`.

    Loads of A and B are unmasked: row/column indices are wrapped with `% M`
    and `% N`, while the K loop relies on K being a multiple of BLOCK_SIZE_K.
    """
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)

    # -----------------------------------------------------------
    # Map the linear program id to a (pid_m, pid_n) tile using a grouped
    # ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    tiles_per_group = GROUP_SIZE_M * tiles_n
    pid_in_group = pid % tiles_per_group
    group_first_m = (pid // tiles_per_group) * GROUP_SIZE_M
    rows_in_group = min(tiles_m - group_first_m, GROUP_SIZE_M)
    pid_m = group_first_m + (pid_in_group % rows_in_group)
    pid_n = pid_in_group // rows_in_group

    # -----------------------------------------------------------
    # Rank swizzle: reinterpret pid_m as (chunk, rank slot, tile-in-chunk) and
    # rotate the rank slot by `rank + 1`, so the first tiles a rank computes
    # belong to rank `(rank + 1) % num_ranks`.
    # NOTE(review): presumably this staggers which peer each rank writes to first - confirm.
    M_per_rank = M // num_ranks
    tiles_per_chunk = M_PER_COPY_CHUNK // BLOCK_SIZE_M
    tiles_per_round = tiles_per_chunk * num_ranks
    chunk_idx = pid_m // tiles_per_round
    rank_slot = (pid_m % tiles_per_round) // tiles_per_chunk
    tile_in_chunk = pid_m % tiles_per_chunk
    rank_offset = (rank_slot + rank + 1) % num_ranks
    pid_m = (rank_offset * M_per_rank + chunk_idx * M_PER_COPY_CHUNK + tile_in_chunk * BLOCK_SIZE_M) // BLOCK_SIZE_M
    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)

    # -----------------------------------------------------------
    # Build the pointer blocks for the first K-slice of A and B:
    # `a_ptrs` is [BLOCK_SIZE_M, BLOCK_SIZE_K], `b_ptrs` is [BLOCK_SIZE_K, BLOCK_SIZE_N].
    row_idx = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    col_idx = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    k_idx = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (row_idx[:, None] * stride_am + k_idx[None, :] * stride_ak)
    b_ptrs = b_ptr + (k_idx[:, None] * stride_bk + col_idx[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Walk along K, accumulating the tile in fp32 for higher accuracy.
    a_step = BLOCK_SIZE_K * stride_ak
    b_step = BLOCK_SIZE_K * stride_bk
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_tile in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Unmasked loads of the current K-slice (see the docstring).
        a_tile = tl.load(a_ptrs)
        b_tile = tl.load(b_ptrs)
        acc = tl.dot(a_tile, b_tile, acc)
        # Advance both pointer blocks to the next K-slice.
        a_ptrs += a_step
        b_ptrs += b_step

    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    out_dtype = a_ptr.dtype.element_ty
    c_tile = acc.to(out_dtype)

    # -----------------------------------------------------------
    # Store the tile into the buffer selected by `rank_offset`, at the row
    # inside that rank's block plus the slab offset of the writing rank.
    dst_row0 = ((pid_m * BLOCK_SIZE_M % M_per_rank) + M_per_rank * rank)
    out_rows = dst_row0 + tl.arange(0, BLOCK_SIZE_M)
    out_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # The table entry is reinterpreted as a pointer to the output element type.
    dst_base = tl.load(scatter_bufs_ptr + rank_offset).to(tl.pointer_type(out_dtype))
    dst_base = tl.multiple_of(dst_base, 16)
    dst_ptrs = dst_base + stride_cm * out_rows[:, None] + stride_cn * out_cols[None, :]
    in_bounds = (out_rows[:, None] < M) & (out_cols[None, :] < N)
    tl.store(dst_ptrs, c_tile, mask=in_bounds)
@triton_dist.jit
def kernel_consumer_reduce(
    c_ptr,  # [M, N]
    out_ptr,  # [M_per_rank, N]
    # shape of matrix
    M_per_rank,
    N,
    rank,
    num_ranks: tl.constexpr,
    # tile size
    BLOCK_SIZE: tl.constexpr,
):
    """Sum the `num_ranks` slabs of `c_ptr` element-wise into `out_ptr`.

    `c_ptr` is treated as `num_ranks` consecutive slabs of `M_per_rank * N`
    elements each. Every program handles BLOCK_SIZE flat offsets, visits the
    slabs starting from slab `(rank + 1) % num_ranks`, and accumulates in the
    element type of `out_ptr`.
    """
    elems_per_rank = M_per_rank * N
    flat = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Out-of-range lanes are redirected to offset 0 instead of being masked,
    # so they re-read and re-write element 0.
    # NOTE(review): presumably chosen to keep the loads/stores unmasked - confirm.
    flat = tl.where(flat < elems_per_rank, flat, 0)

    total = tl.zeros((BLOCK_SIZE, ), dtype=out_ptr.dtype.element_ty)
    for step in range(0, num_ranks):
        src_slab = (step + rank + 1) % num_ranks
        total += tl.load(c_ptr + flat + src_slab * elems_per_rank)
    tl.store(out_ptr + flat, total)
def ring_reduce_after_scatter(
    rank,
    num_ranks,
    scatter_out,  # [M, N]
    stream,
):
    """Reduce the `num_ranks` row slabs of `scatter_out` into one tensor.

    Args:
        rank: Index of the calling rank; forwarded to the reduction kernel.
        num_ranks: Number of slabs stacked along the first dimension.
        scatter_out: 2-D tensor of shape [M, N] holding the slabs.
        stream: CUDA/HIP stream on which the reduction kernel is launched.

    Returns:
        A new tensor of shape [M // num_ranks, N] with the dtype and device
        of `scatter_out`.
    """
    total_rows, n_cols = scatter_out.shape
    rows_per_rank = total_rows // num_ranks
    reduced = torch.empty(
        (rows_per_rank, n_cols),
        dtype=scatter_out.dtype,
        device=scatter_out.device,
    )

    def launch_grid(meta):
        # One program per BLOCK_SIZE flat output elements.
        return (triton.cdiv(rows_per_rank * n_cols, meta["BLOCK_SIZE"]), )

    with torch.cuda.stream(stream):
        kernel_consumer_reduce[launch_grid](
            scatter_out,
            reduced,
            rows_per_rank,
            n_cols,
            rank=rank,
            num_ranks=num_ranks,
            BLOCK_SIZE=2048,
            num_warps=2,
        )
    return reduced
GEMM RS Class
class triton_gemm_rs_intra_node(torch.nn.Module):
    """Fused GEMM + reduce-scatter across the ranks of one tensor-parallel group.

    `forward` launches the producer GEMM kernel, which writes its output tiles
    into the scatter buffers held by the context, then reduces the per-rank
    slabs of this rank's buffer into the final [M // num_ranks, N] output.
    """

    def __init__(
        self,
        tp_group: torch.distributed.ProcessGroup,
        max_M: int,
        N: int,
        K: int,
        input_dtype: torch.dtype,
        output_dtype: torch.dtype,
        fuse_scatter: bool = True,
    ):
        """Record the problem configuration and build the GEMM-RS context.

        Args:
            tp_group: Process group whose ranks take part in the reduce-scatter.
            max_M: Largest M the context is created for.
            N: Number of output columns.
            K: Reduction dimension (stored only; not used by `forward`).
            input_dtype: dtype of the GEMM inputs (stored only).
            output_dtype: dtype passed to the context constructor.
            fuse_scatter: Forwarded to the context constructor.
        """
        # nn.Module must be initialised before it is used as a module
        # (calling the instance, registering parameters or submodules).
        super().__init__()
        self.tp_group = tp_group
        self.rank: int = tp_group.rank()
        self.world_size = tp_group.size()
        self.max_M: int = max_M
        self.N = N
        self.K = K
        self.input_dtype = input_dtype
        self.output_dtype = output_dtype
        self.fuse_scatter = fuse_scatter

        # Use the auxiliary functions provided by Triton-distributed to construct the context required for GEMM-RS.
        # This simplifies the code logic. The context mainly includes:
        # (1) The globally symmetric memory required;
        # (2) The signals used for communication between producer and consumer;
        # (3) Scatter streams.
        self.ctx = create_gemm_rs_intra_node_context(
            self.max_M,
            self.N,
            self.output_dtype,
            self.rank,
            self.world_size,
            self.tp_group,
            self.fuse_scatter,
        )

    def forward(self, input: torch.Tensor,  # [M, local_K]
                weight: torch.Tensor,  # [N, local_K]
                ):
        """Run the fused GEMM + reduce-scatter.

        Args:
            input: 2-D tensor of shape [M, local_K].
            weight: 2-D tensor of shape [N, local_K].

        Returns:
            Tensor of shape [M // num_ranks, N] produced by
            `ring_reduce_after_scatter`; its dtype is that of the scatter buffer.

        Raises:
            AssertionError: If the shapes are inconsistent with each other or
                with the configuration given at construction time.
        """
        ctx = self.ctx
        M, local_K = input.shape
        N, K = weight.shape

        # Validate before entering any barrier or launching any kernel.
        assert K == local_K
        alignment = 256
        assert M % alignment == 0 and N % alignment == 0 and K % alignment == 0
        assert M % ctx.num_ranks == 0, "M must be divisible by the number of ranks"
        # The context was created for (max_M, N); the kernel below writes rows
        # [0, M) with a row stride of N into its buffers.
        assert M <= self.max_M, "M exceeds the max_M the context was created with"
        assert N == self.N, "N differs from the N the context was created with"

        # The weight is stored as [N, K]; expose it to the kernel as B with shape (K, N).
        stride_bk, stride_bn = weight.stride(1), weight.stride(0)

        current_stream = torch.cuda.current_stream()
        # Barrier across all ranks on the current stream before the producer starts.
        # NOTE(review): presumably this ensures peers have finished with the scatter
        # buffers from a previous call - confirm.
        barrier_all_on_stream(ctx.rank, ctx.num_ranks, ctx.sync_bufs_ptr, current_stream)

        # producer gemm fused scatter
        # NOTE(review): the kernel declares BLOCK_SIZE_M / BLOCK_SIZE_N / BLOCK_SIZE_K /
        # GROUP_SIZE_M / M_PER_COPY_CHUNK as constexpr meta-parameters, but none are
        # passed here. Presumably an autotune config outside this section supplies
        # them - confirm, otherwise this launch fails for missing arguments.
        grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
        kernel_gemm_rs_producer_fuse_scatter[grid](
            input,
            weight,
            ctx.scatter_bufs_ptr,
            ctx.rank,
            ctx.num_ranks,
            M,
            N,
            K,
            input.stride(0),
            input.stride(1),
            stride_bk,
            stride_bn,
            N,  # stride_cm: output rows are N elements apart
            1,  # stride_cn
        )
        scatter_out = ctx.scatter_bufs[ctx.rank][:M]

        # barrier all to wait for gemm finish
        barrier_all_on_stream(ctx.rank, ctx.num_ranks, ctx.sync_bufs_ptr, current_stream)

        # consumer reduction
        return ring_reduce_after_scatter(ctx.rank, ctx.num_ranks, scatter_out, current_stream)
Benchmark
def torch_gemm_rs(
    input: torch.Tensor,  # [M, local_k]
    weight: torch.Tensor,  # [N, local_K]
    bias: Optional[torch.Tensor],
    TP_GROUP,
):
    """Reference GEMM + reduce-scatter built from stock PyTorch operators.

    Args:
        input: 2-D tensor of shape [M, local_K].
        weight: 2-D tensor of shape [N, local_K].
        bias: Optional tensor added to the local GEMM result before the
            reduce-scatter; pass None to skip it.
        TP_GROUP: Process group used for the reduce-scatter.

    Returns:
        Tensor of shape [M // world_size, N] filled by
        `torch.distributed.reduce_scatter_tensor`.
    """
    M = input.shape[0]
    N = weight.shape[0]
    world_size = TP_GROUP.size()

    output = torch.matmul(input, weight.T)
    # Test against None explicitly: the truth value of a multi-element tensor
    # is ambiguous and raises.
    if bias is not None:
        output = output + bias

    rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device)
    torch.distributed.reduce_scatter_tensor(rs_output, output, group=TP_GROUP)
    return rs_output
def init():
    """Initialise torch.distributed for this process and seed the RNGs.

    Reads RANK, LOCAL_RANK and WORLD_SIZE from the environment (defaulting to
    0, 0 and 1), selects the local device, creates the default "nccl" process
    group plus a group spanning all ranks, and seeds torch with `3 + RANK`.

    Returns:
        Tuple `(RANK, LOCAL_RANK, WORLD_SIZE, TP_GROUP)`.
    """
    RANK, LOCAL_RANK, WORLD_SIZE = (
        int(os.environ.get(name, fallback))
        for name, fallback in (("RANK", 0), ("LOCAL_RANK", 0), ("WORLD_SIZE", 1))
    )

    torch.cuda.set_device(LOCAL_RANK)
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=WORLD_SIZE,
        rank=RANK,
        timeout=datetime.timedelta(seconds=1800),
    )
    assert torch.distributed.is_initialized()

    # Group containing every rank.
    all_ranks = list(range(WORLD_SIZE))
    TP_GROUP = torch.distributed.new_group(ranks=all_ranks, backend="nccl")
    torch.distributed.barrier(TP_GROUP)

    # Each rank uses a different seed.
    seed = 3 + RANK
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.cuda.synchronize()
    torch.distributed.barrier()
    return RANK, LOCAL_RANK, WORLD_SIZE, TP_GROUP
def destroy():
    """Tear down torch.distributed.

    Synchronises the device, waits for all ranks at a barrier, then destroys
    the process group.
    """
    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
if __name__ == "__main__":
    # init
    RANK, LOCAL_RANK, WORLD_SIZE, TP_GROUP = init()
    # NOTE: We should get device after process group init.
    DEVICE = triton.runtime.driver.active.get_active_torch_device()

    # Problem configuration: K is split evenly across the ranks.
    dtype = torch.float16
    M = 8192
    N = 4096
    K = 12288
    local_K = K // WORLD_SIZE
    input_dtype = dtype
    output_dtype = input_dtype
    atol = 1e-2
    rtol = 1e-2

    # Generate input and weight.
    scale = TP_GROUP.rank() + 1
    data_config = [
        ((M, local_K), dtype, (0.01 * scale, 0), DEVICE),  # input
        ((N, local_K), dtype, (0.01 * scale, 0), DEVICE),  # weight
        None,  # bias
    ]
    generator = generate_data(data_config)
    # Named `input_tensor` to avoid shadowing the builtin `input`.
    input_tensor, weight, bias = next(generator)

    # torch impl
    # torch_gemm_rs takes (input, weight, bias, TP_GROUP); the extra positional
    # argument previously passed here raised a TypeError.
    ref_out = torch_gemm_rs(input_tensor, weight, bias, TP_GROUP)
    torch.cuda.synchronize()
    torch.distributed.barrier()

    # dist triton impl
    dist_gemm_rs_op = triton_gemm_rs_intra_node(TP_GROUP, M, N, K, input_dtype, output_dtype)
    tri_out = dist_gemm_rs_op.forward(input_tensor, weight)

    # Compare the two implementations.
    if torch.allclose(tri_out, ref_out, atol=atol, rtol=rtol):
        dist_print("✅ Triton and Torch match")
    else:
        dist_print(f"The maximum difference between torch and triton is {torch.max(torch.abs(tri_out - ref_out))}")
        dist_print("❌ Triton and Torch differ")

    # Finally destroy distributed process group.
    destroy()