Overlapping AllGather GEMM on AMD GPU

This script demonstrates an intra-node AllGather-GEMM implementation that overlaps computation with communication using Triton-distributed.

In this tutorial, you will write a simple fused AllGather and GEMM using Triton-distributed.

In doing so, you will learn about:

  • How to finely partition communication (producers) and computation (consumers) and overlap them.

  • Intra-node communication on AMD GPUs.

# To run this tutorial
bash ./scripts/launch_amd.sh tutorials/09-AMD-overlapping-allgather-gemm.py

Background

The typical implementation of AG + GEMM is sequential: the GEMM starts only after the AllGather has finished, so communication and computation do not run in parallel, as shown below:

The total time is t1 + t2, the sum of the communication and computation times.

Note that a large GEMM is always tiled into small blocks when it runs on an accelerator, so some tiles are computed earlier than others. We can therefore transfer first the data that the earliest GEMM tiles depend on, which lets the GEMM computation overlap with the AG communication and improves the overall MFU (model FLOPs utilization). As shown below, both communication and computation are divided into 4 blocks, and blocks with the same number depend on each other:

Then the effect of the overlap is as follows:

With an ideal computation-to-memory-access ratio, the communication time of blocks 1, 2, and 3 can be hidden behind the GEMM computation. Fortunately, part of the AG communication is always local, and local communication can be omitted or has very low overhead. If the communication of block 0 is purely local, the GEMM tiles that depend on block 0 can start with almost no waiting.
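
The gain can be estimated with a simple pipeline model. The sketch below is illustrative and not part of the tutorial code: it assumes equally sized chunks, a single link that sends the remote chunks back to back, and a local chunk 0 that needs no transfer. The function names and the timings are hypothetical.

```python
def sequential_time(t_comm: float, t_gemm: float) -> float:
    """AllGather finishes completely before the GEMM starts."""
    return t_comm + t_gemm


def overlapped_time(t_comm: float, t_gemm: float, num_chunks: int) -> float:
    """Pipeline model: the GEMM on chunk i starts once chunk i has arrived
    and the GEMM on chunk i-1 has finished. Chunk 0 is local (no transfer)."""
    if num_chunks < 2:
        raise ValueError("need at least one local and one remote chunk")
    comm_per_chunk = t_comm / (num_chunks - 1)  # only remote chunks are sent
    gemm_per_chunk = t_gemm / num_chunks
    arrival = 0.0  # time at which the current chunk's data is available
    finish = 0.0   # time at which the GEMM finishes the current chunk
    for i in range(num_chunks):
        if i > 0:
            arrival += comm_per_chunk
        start = max(arrival, finish)
        finish = start + gemm_per_chunk
    return finish


print(sequential_time(3.0, 4.0))     # 7.0
print(overlapped_time(3.0, 4.0, 4))  # 4.0
```

In this model the communication is fully hidden when each remote chunk arrives no later than the GEMM finishes the previous chunk. When communication dominates, the total approaches the communication time plus the GEMM time of the last chunk.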

Assume that the shapes of matrices A and B in GEMM are [M, K] and [K, N] respectively. We can partition both communication and computation along the M-axis into several chunks. Communication is carried out on a per-chunk basis, and each tile of GEMM needs to wait for the data on the chunks it depends on to arrive. We can use signals to build a producer-consumer dependency model. That is, GEMM acts as the consumer and waits for the signals that the current tile’s computation depends on, while AG acts as the producer and sets the signals after the data transfer is completed. For example, assume that the size of M is 1024, the size of each chunk is 128, and the BLOCK_SIZE_M of GEMM is 256. Then each tile of GEMM needs to wait for the signals of two chunks to arrive.
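
The dependency between tiles and chunks in this example is plain integer arithmetic. The following sketch is for illustration only; the helper name chunks_needed_by_tile is hypothetical and does not appear in Triton-distributed.

```python
from typing import List


def chunks_needed_by_tile(pid_m: int, block_size_m: int, m_per_chunk: int) -> List[int]:
    """Indices of the AllGather chunks that cover rows
    [pid_m * block_size_m, (pid_m + 1) * block_size_m)."""
    first_row = pid_m * block_size_m
    last_row = first_row + block_size_m - 1
    return list(range(first_row // m_per_chunk, last_row // m_per_chunk + 1))


M, M_PER_CHUNK, BLOCK_SIZE_M = 1024, 128, 256
for pid_m in range(M // BLOCK_SIZE_M):
    print(pid_m, chunks_needed_by_tile(pid_m, BLOCK_SIZE_M, M_PER_CHUNK))
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]
```

When M_PER_CHUNK is a multiple of BLOCK_SIZE_M, every tile lies inside a single chunk and the list has exactly one entry, so one signal per tile is enough.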

Kernel

The differences from the gemm-only kernel are: (1) Each tile needs to wait for the arrival of the dependent signals before computation; (2) Swizzling is performed on the order of the chunks for AG communication.
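
The swizzle in (2) can be reproduced in plain Python to see the order in which a rank visits row tiles. The sketch below copies the index arithmetic used in the kernel; the helper name and the example sizes (4 ranks, M = 2048, BLOCK_SIZE_M = 128, M_PER_CHUNK = 256) are assumptions chosen for illustration, and it presumes that M_PER_CHUNK is a multiple of BLOCK_SIZE_M.

```python
def swizzle_pid_m(pid_m, rank, world_size, M, BLOCK_SIZE_M, M_PER_CHUNK):
    """Map a linear row-tile index to the swizzled row-tile index."""
    M_per_rank = M // world_size
    tiles_per_chunk = M_PER_CHUNK // BLOCK_SIZE_M
    chunk_offset = pid_m // (tiles_per_chunk * world_size)
    rank_offset = pid_m % (tiles_per_chunk * world_size) // tiles_per_chunk
    block_offset = pid_m % tiles_per_chunk
    # Rotate by the local rank so that local rows are visited first.
    rank_offset = (rank_offset + rank) % world_size
    row = rank_offset * M_per_rank + chunk_offset * M_PER_CHUNK + block_offset * BLOCK_SIZE_M
    return row // BLOCK_SIZE_M


WORLD_SIZE, M, BLOCK_SIZE_M, M_PER_CHUNK = 4, 2048, 128, 256
num_tiles = M // BLOCK_SIZE_M
order = [swizzle_pid_m(p, 1, WORLD_SIZE, M, BLOCK_SIZE_M, M_PER_CHUNK) for p in range(num_tiles)]
print(order)
# [4, 5, 8, 9, 12, 13, 0, 1, 6, 7, 10, 11, 14, 15, 2, 3]
```

Rank 1 starts with its own rows (tiles 4 and 5), which need no signal, then visits the first chunk of ranks 2, 3, and 0 before coming back for the second chunk of each rank.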

@triton_dist.jit
def consumer_gemm_persistent_kernel(
    A,
    localA,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    rank,
    world_size,
    barrier_ptr,  # signals
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    M_PER_CHUNK: tl.constexpr,
    NUM_SMS: tl.constexpr,
    NUM_XCDS: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    pid = tl.program_id(0)
    if NUM_XCDS != 1:
        pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    total_tiles = num_pid_m * num_pid_n

    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)

    M_per_rank = M // world_size
    pid_m_per_rank = tl.cdiv(M_per_rank, BLOCK_SIZE_M)

    acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32
    for tile_id in range(pid, total_tiles, NUM_SMS):
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m

        ## swizzle
        num_pid_m_per_copy_chunk = M_PER_CHUNK // BLOCK_SIZE_M
        chunk_offset = pid_m // (num_pid_m_per_copy_chunk * world_size)
        rank_offset = pid_m % (num_pid_m_per_copy_chunk * world_size) // num_pid_m_per_copy_chunk
        block_offset = pid_m % num_pid_m_per_copy_chunk
        # Swizzle the rank offset so that tiles owned by the current rank are launched first.
        rank_offset = (rank_offset + rank) % world_size
        pid_m = (rank_offset * M_per_rank + chunk_offset * M_PER_CHUNK + block_offset * BLOCK_SIZE_M) // BLOCK_SIZE_M

        offs_am = pid_m * BLOCK_SIZE_M
        offs_sig = offs_am // M_PER_CHUNK
        offs_rank = pid_m // pid_m_per_rank

        # Skip waiting for the local tensor, since it is passed into the kernel as an argument.
        if offs_rank != rank:
            token = dl.wait(barrier_ptr + offs_sig, 1, "sys", "acquire", waitValue=1)
            A = dl.consume_token(A, token)

        rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        rk = tl.arange(0, BLOCK_SIZE_K)
        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        if offs_rank == rank:
            rm = rm % M_per_rank
            A_BASE = localA + rm[:, None] * stride_am + rk[None, :] * stride_ak
        else:
            A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak

        B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
        tl.assume(pid_m >= 0)
        tl.assume(pid_n >= 0)

        loop_k = tl.cdiv(K, BLOCK_SIZE_K)
        if not EVEN_K:
            loop_k -= 1

        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
        for k in range(0, loop_k):
            a = tl.load(tl.multiple_of(A_BASE, (1, 16)))
            b = tl.load(tl.multiple_of(B_BASE, (16, 1)))
            acc += tl.dot(a, b)
            A_BASE += BLOCK_SIZE_K * stride_ak
            B_BASE += BLOCK_SIZE_K * stride_bk

        if not EVEN_K:
            k = loop_k
            rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
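            # rm was reduced modulo M_per_rank for local tiles, so keep reading from localA.
            if offs_rank == rank:
                A_BASE = localA + rm[:, None] * stride_am + rk[None, :] * stride_ak
            else:
                A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak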
            B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
            A_BASE = tl.multiple_of(A_BASE, (1, 16))
            B_BASE = tl.multiple_of(B_BASE, (16, 1))
            a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0)
            b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0)
            acc += tl.dot(a, b)

        c = acc.to(C.type.element_ty)
        rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
        c_mask = (rm[:, None] < M) & (rn[None, :] < N)
        C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
        tl.store(C_, c, c_mask)

To fully utilize the communication bandwidth of AMD GPUs, we use multiple streams for parallel processing and perform swizzling on the stream selection to enable AMD GPU links to work simultaneously.
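
The stream selection can be previewed without a GPU. The sketch below mirrors the rank ordering and the modulo stream choice of the producer function; the helper name stream_assignment and the pool size of 4 streams are assumptions for illustration.

```python
def stream_assignment(rank: int, num_ranks: int, num_streams: int) -> dict:
    """Return {remote_rank: stream_index} in the order the copies are issued."""
    # Each rank starts from itself and walks the ranks in ring order.
    rank_orders = [(rank + i) % num_ranks for i in range(num_ranks)]
    plan = {}
    for idx, remote_rank in enumerate(rank_orders):
        if remote_rank == rank:
            continue  # local data is not copied
        plan[remote_rank] = idx % num_streams
    return plan


for r in range(4):
    print(r, stream_assignment(r, num_ranks=4, num_streams=4))
# 0 {1: 1, 2: 2, 3: 3}
# 1 {2: 1, 3: 2, 0: 3}
# 2 {3: 1, 0: 2, 1: 3}
# 3 {0: 1, 1: 2, 2: 3}
```

At step idx every rank r targets peer (r + idx) % num_ranks, so no two ranks push to the same destination in the same step, and copies to different peers are placed on different streams as long as the pool is large enough.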

def producer_ag_push_mode(
    rank,
    num_ranks,
    local_tensor: torch.Tensor,
    remote_tensor_buffers: List[torch.Tensor],
    one: torch.Tensor,
    M_PER_CHUNK: int,
    ag_stream_pool: List[torch.cuda.Stream],
    barrier_buffers: List[torch.Tensor],
):
    M_per_rank, N = local_tensor.shape
    chunk_num_per_rank = M_per_rank // M_PER_CHUNK
    num_stream = len(ag_stream_pool)
    rank_orders = [(rank + i) % num_ranks for i in range(num_ranks)]

    data_elem_size = local_tensor.element_size()
    barrier_elem_size = one.element_size()

    for idx, remote_rank in enumerate(rank_orders):
        if remote_rank == rank:
            continue
        for chunk_idx_intra_rank in range(chunk_num_per_rank):
            # Chunk swizzle
            chunk_pos = rank * chunk_num_per_rank + chunk_idx_intra_rank
            stream_pos = idx % num_stream
            ag_stream = ag_stream_pool[stream_pos]
            M_dst_start_pos = rank * M_per_rank + chunk_idx_intra_rank * M_PER_CHUNK
            M_src_start_pos = chunk_idx_intra_rank * M_PER_CHUNK
            src_ptr = local_tensor.data_ptr() + M_src_start_pos * N * data_elem_size
            dst_ptr = remote_tensor_buffers[remote_rank].data_ptr() + M_dst_start_pos * N * data_elem_size

            nbytes = M_PER_CHUNK * N * data_elem_size
            cp_res = hip.hipMemcpyAsync(
                dst_ptr,
                src_ptr,
                nbytes,
                hip.hipMemcpyKind.hipMemcpyDeviceToDeviceNoCU,
                ag_stream.cuda_stream,
            )
            HIP_CHECK(cp_res)
            # Set signal through copy-engine. This is a trick for AMD GPU.
            cp_res = hip.hipMemcpyAsync(
                barrier_buffers[remote_rank].data_ptr() + chunk_pos * barrier_elem_size,
                one.data_ptr(),
                barrier_elem_size,
                hip.hipMemcpyKind.hipMemcpyDeviceToDeviceNoCU,
                ag_stream.cuda_stream,
            )
            HIP_CHECK(cp_res)
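
For the wait in the consumer to match the flag written by the producer, both sides must compute the same signal index for a given chunk. The following sketch checks this with plain integers. The two helper names are hypothetical; the formulas are taken from chunk_pos in the producer and offs_sig in the consumer kernel, and the sizes are example values.

```python
def producer_signal_index(src_rank: int, chunk_idx: int, M_per_rank: int, M_PER_CHUNK: int) -> int:
    """Flag slot written by src_rank after copying its chunk chunk_idx."""
    chunk_num_per_rank = M_per_rank // M_PER_CHUNK
    return src_rank * chunk_num_per_rank + chunk_idx


def consumer_signal_index(pid_m: int, BLOCK_SIZE_M: int, M_PER_CHUNK: int) -> int:
    """Flag slot a GEMM tile waits on, derived from its first row."""
    offs_am = pid_m * BLOCK_SIZE_M
    return offs_am // M_PER_CHUNK


WORLD_SIZE, M, BLOCK_SIZE_M, M_PER_CHUNK = 4, 2048, 128, 256
M_per_rank = M // WORLD_SIZE
mismatches = 0
for pid_m in range(M // BLOCK_SIZE_M):
    row = pid_m * BLOCK_SIZE_M
    src_rank = row // M_per_rank                   # rank that owns these rows
    chunk_idx = (row % M_per_rank) // M_PER_CHUNK  # chunk within that rank
    produced = producer_signal_index(src_rank, chunk_idx, M_per_rank, M_PER_CHUNK)
    consumed = consumer_signal_index(pid_m, BLOCK_SIZE_M, M_PER_CHUNK)
    mismatches += produced != consumed
print(mismatches)  # 0
```

The indices agree as long as M_per_rank is a multiple of M_PER_CHUNK, which is the layout the producer loop assumes when it computes chunk_num_per_rank.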

AG GEMM Class

class triton_ag_gemm_intra_node(torch.nn.Module):

    def __init__(
        self,
        tp_group: torch.distributed.ProcessGroup,
        max_M: int,
        N: int,
        K: int,
        M_PER_CHUNK: int,
        input_dtype: torch.dtype,
        output_dtype: torch.dtype,
    ):
        super().__init__()
        self.tp_group = tp_group
        self.rank: int = tp_group.rank()
        self.world_size = tp_group.size()
        self.max_M: int = max_M
        self.N = N
        self.K = K
        self.M_PER_CHUNK = M_PER_CHUNK
        self.input_dtype = input_dtype
        self.output_dtype = output_dtype

        # Use the auxiliary functions provided by Triton-distributed to construct the context required for AG-GEMM. This simplifies the code logic.
        # The context mainly includes:
        # (1) The globally symmetric memory required for AllGather;
        # (2) The signals used for communication between AG and GEMM;
        # (3) Streams.
        self.ctx = create_ag_gemm_intra_node_context(
            self.max_M,
            self.N,
            self.K,
            self.input_dtype,
            self.output_dtype,
            self.rank,
            self.world_size,
            self.tp_group,
            M_PER_CHUNK=M_PER_CHUNK,
        )

    def forward(self, input: torch.Tensor,  # [local_M, K]
                weight: torch.Tensor,  # [local_N, K]
                transed_weight: bool,  # indicates whether weight already transposed
                ):
        current_stream = torch.cuda.current_stream()
        ctx = self.ctx
        M = self.max_M
        K = self.K
        N_PER_RANK = weight.shape[1] if transed_weight else weight.shape[0]
        # Allocate the output buffer (initialized to 1.0).
        output = torch.zeros([M, N_PER_RANK], dtype=input.dtype, device=input.device) + 1.0

        # Barrier before the current AG-GEMM operation. This step is necessary when running multiple rounds of computations.
        torch.cuda.synchronize()
        barrier_all_on_stream(self.rank, ctx.num_ranks, ctx.comm_buf_ptr, current_stream)

        # Sync work streams.
        for ag_stream in ctx.ag_streams:
            ag_stream.wait_stream(current_stream)

        # producer allgather
        producer_ag_push_mode(ctx.rank, ctx.num_ranks, input, ctx.workspace_tensors, ctx.one, ctx.M_PER_CHUNK,
                            ctx.ag_streams, ctx.barrier_tensors)
        torch.cuda.synchronize()
        torch.distributed.barrier()
        # consumer gemm
        NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count
        NUM_XCDS = 4

        grid = lambda META: (min(
            NUM_SMS,
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N_PER_RANK, META["BLOCK_SIZE_N"]),
        ), )

        full_input = ctx.workspace_tensors[ctx.rank][:M]
        local_input = input

        consumer_gemm_persistent_kernel[grid](
            full_input,
            local_input,
            weight,
            output,
            M,
            N_PER_RANK,
            K,
            full_input.stride(0),
            full_input.stride(1),
            weight.stride(1),
            weight.stride(0),
            output.stride(0),
            output.stride(1),
            ctx.rank,
            ctx.num_ranks,
            ctx.barrier_tensors[ctx.rank],
            M_PER_CHUNK=ctx.M_PER_CHUNK,
            NUM_SMS=NUM_SMS,
            NUM_XCDS=NUM_XCDS,
        )
        return output

Benchmark

def torch_ag_gemm(
    input: torch.Tensor,  # [local_M, K]
    weight: torch.Tensor,  # [local_N, K]
    transed_weight: bool,
    bias: Optional[torch.Tensor],
    TP_GROUP,
):
    local_M, K = input.shape
    world_size = TP_GROUP.size()
    if transed_weight:
        assert K == weight.shape[0]
    else:
        assert K == weight.shape[1]
        weight = weight.T
    assert input.device == weight.device
    # AG + Gemm
    full_input = torch.empty((local_M * world_size, K), dtype=input.dtype, device=input.device)
    torch.distributed.all_gather_into_tensor(full_input, input, group=TP_GROUP)
    output = torch.matmul(full_input, weight)

    if bias is not None:
        output = output + bias

    return output


def init():
    RANK = int(os.environ.get("RANK", 0))
    LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
    WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
    torch.cuda.set_device(LOCAL_RANK)
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=WORLD_SIZE,
        rank=RANK,
        timeout=datetime.timedelta(seconds=1800),
    )
    assert torch.distributed.is_initialized()
    TP_GROUP = torch.distributed.new_group(ranks=list(range(WORLD_SIZE)), backend="nccl")
    torch.distributed.barrier(TP_GROUP)

    torch.use_deterministic_algorithms(False, warn_only=True)
    torch.set_printoptions(precision=5)
    torch.manual_seed(3 + RANK)
    torch.cuda.manual_seed_all(3 + RANK)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
    np.random.seed(3 + RANK)

    torch.cuda.synchronize()
    torch.distributed.barrier()

    return RANK, LOCAL_RANK, WORLD_SIZE, TP_GROUP


def destroy():
    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    # init
    RANK, LOCAL_RANK, WORLD_SIZE, TP_GROUP = init()

    # NOTE: We should get device after process group init.
    DEVICE = triton.runtime.driver.active.get_active_torch_device()

    dtype = torch.float16
    M = 8192
    N = 11008
    K = 4096
    chunk_size = 256
    local_M = M // WORLD_SIZE
    local_N = N // WORLD_SIZE
    input_dtype = dtype
    output_dtype = input_dtype
    atol = 1e-2
    rtol = 1e-2

    # Generate input and weight.
    scale = TP_GROUP.rank() + 1
    data_config = [((local_M, K), dtype, (0.01 * scale, 0), DEVICE),  # input
                ((local_N, K), dtype, (0.01 * scale, 0), DEVICE),  # weight
                (None),  # bias
                ]
    generator = generate_data(data_config)
    input, weight, bias = next(generator)

    # torch
    ref_out = torch_ag_gemm(input, weight, False, bias, TP_GROUP)
    torch.cuda.synchronize()
    torch.distributed.barrier()

    # dist triton
    dist_ag_gemm_op = triton_ag_gemm_intra_node(TP_GROUP, M, N, K, chunk_size, input_dtype, output_dtype)
    tri_out = dist_ag_gemm_op.forward(input, weight, False)

    if torch.allclose(tri_out, ref_out, atol=atol, rtol=rtol):
        dist_print("✅ Triton and Torch match")
    else:
        dist_print(f"The maximum difference between torch and triton is {torch.max(torch.abs(tri_out - ref_out))}")
        dist_print("❌ Triton and Torch differ")

    # Finally, destroy the distributed process group.
    destroy()