.. _sphx_glr_getting-started_tutorials_09-AMD-overlapping-allgather-gemm.rst: Overlapping AllGather GEMM on AMD GPU ===================================== This script demonstrates the implementations allgather-gemm intra node with computational communication overlap using Triton-distrbuted. In this tutorial, you will write a simple fused AllGather and Gemm using Triton-distributed. In doing so, you will learn about: * How to finely partition communication (producers) and computation (consumers) and overlap them. * Intra node communication on AMD GPUs. .. code-block:: bash # To run this tutorial bash ./scripts/launch_amd.sh tutorials/09-AMD-overlapping-allgather-gemm.py Background ---------- The typical implementation of AG + Gemm is sequential, and communication and computation cannot be parallelized. As shown below: .. image:: images/09/09-01/09-01-1.png :alt: :align: center The total time consumption is t1 + t2. Notice that when the shape of GEMM is relatively large, it is always tiled into small blocks for computation on heterogeneous accelerators. This means that there are always some tiles computed first and some computed later. Therefore, we can first transfer the data that is first relied on the GEMM computation during AG, so that the GEMM computation and the AG communication can be overlapped, improving the overall MFU. As shown below, both communication and computation are divided into 4 blocks, and the blocks with the same number have dependencies: .. image:: images/09/09-02/09-02-1.png :alt: :align: center Then the effect of the overlap is as follows: .. image:: images/09/09-03/09-03-1.png :alt: :align: center Under the ideal computation-to-memory access ratio, we can hide the communication time of blocks 1, 2, and 3 in the GEMM computation time overhead. Fortunately, there is always a part of the AG communication that is local. Local communication can be omitted or has very low overhead. If the communication of block 0 only occurs locally, then the GEMM waiting for the communication of block 0 always has almost no overhead. Assume that the shapes of matrices A and B in GEMM are [M, K] and [K, N] respectively. We can partition both communication and computation along the M-axis into several chunks. Communication is carried out on a per-chunk basis, and each tile of GEMM needs to wait for the data on the chunks it depends on to arrive. We can use signals to build a producer-consumer dependency model. That is, GEMM acts as the consumer and waits for the signals that the current tile's computation depends on, while AG acts as the producer and sets the signals after the data transfer is completed. For example, assume that the size of M is 1024, the size of each chunk is 128, and the BLOCK_SIZE_M of GEMM is 256. Then each tile of GEMM needs to wait for the signals of two chunks to arrive. Kernel ------ The differences from the gemm-only kernel are: (1) Each tile needs to wait for the arrival of the dependent signals before computation; (2) Swizzling is performed on the order of the chunks for AG communication. .. code-block:: Python @triton_dist.jit def consumer_gemm_persistent_kernel( A, localA, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, rank, world_size, barrier_ptr, # signals BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, M_PER_CHUNK: tl.constexpr, NUM_SMS: tl.constexpr, NUM_XCDS: tl.constexpr, EVEN_K: tl.constexpr, ): pid = tl.program_id(0) if NUM_XCDS != 1: pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n tl.assume(stride_am > 0) tl.assume(stride_ak > 0) tl.assume(stride_bn > 0) tl.assume(stride_bk > 0) tl.assume(stride_cm > 0) tl.assume(stride_cn > 0) M_per_rank = M // world_size pid_m_per_rank = tl.cdiv(M_per_rank, BLOCK_SIZE_M) acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 for tile_id in range(pid, total_tiles, NUM_SMS): num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) pid_n = (tile_id % num_pid_in_group) // group_size_m ## swizzle num_pid_m_per_copy_chunk = M_PER_CHUNK // BLOCK_SIZE_M chunk_offset = pid_m // (num_pid_m_per_copy_chunk * world_size) rank_offset = pid_m % (num_pid_m_per_copy_chunk * world_size) // num_pid_m_per_copy_chunk block_offset = pid_m % num_pid_m_per_copy_chunk # swizzle rank offset, launch current rank firstly. rank_offset = (rank_offset + rank) % world_size pid_m = (rank_offset * M_per_rank + chunk_offset * M_PER_CHUNK + block_offset * BLOCK_SIZE_M) // BLOCK_SIZE_M offs_am = pid_m * BLOCK_SIZE_M offs_sig = offs_am // M_PER_CHUNK offs_rank = pid_m // pid_m_per_rank # Skip waiting local tensor as we pass it int kernel function as argument. if offs_rank != rank: token = dl.wait(barrier_ptr + offs_sig, 1, "sys", "acquire", waitValue=1) A = dl.consume_token(A, token) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N rk = tl.arange(0, BLOCK_SIZE_K) rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) if offs_rank == rank: rm = rm % M_per_rank A_BASE = localA + rm[:, None] * stride_am + rk[None, :] * stride_ak else: A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn tl.assume(pid_m > 0) tl.assume(pid_n > 0) loop_k = tl.cdiv(K, BLOCK_SIZE_K) if not EVEN_K: loop_k -= 1 acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in range(0, loop_k): a = tl.load(tl.multiple_of(A_BASE, (1, 16))) b = tl.load(tl.multiple_of(B_BASE, (16, 1))) acc += tl.dot(a, b) A_BASE += BLOCK_SIZE_K * stride_ak B_BASE += BLOCK_SIZE_K * stride_bk if not EVEN_K: k = loop_k rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn A_BASE = tl.multiple_of(A_BASE, (1, 16)) B_BASE = tl.multiple_of(B_BASE, (16, 1)) a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0) b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0) acc += tl.dot(a, b) c = acc.to(C.type.element_ty) rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) c_mask = (rm[:, None] < M) & (rn[None, :] < N) C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn tl.store(C_, c, c_mask) To fully utilize the communication bandwidth of AMD GPUs, we use multiple streams for parallel processing and perform swizzling on the stream selection to enable AMD GPU links to work simultaneously. .. code-block:: Python def producer_ag_push_mode( rank, num_ranks, local_tensor: torch.Tensor, remote_tensor_buffers: List[torch.Tensor], one: torch.Tensor, M_PER_CHUNK: int, ag_stream_pool: List[torch.cuda.Stream], barrier_buffers: List[torch.Tensor], ): M_per_rank, N = local_tensor.shape chunk_num_per_rank = M_per_rank // M_PER_CHUNK num_stream = len(ag_stream_pool) rank_orders = [(rank + i) % num_ranks for i in range(num_ranks)] data_elem_size = local_tensor.element_size() barrier_elem_size = one.element_size() for idx, remote_rank in enumerate(rank_orders): if remote_rank == rank: continue for chunk_idx_intra_rank in range(chunk_num_per_rank): # Chunk swizzle chunk_pos = rank * chunk_num_per_rank + chunk_idx_intra_rank stream_pos = idx % num_stream ag_stream = ag_stream_pool[stream_pos] M_dst_start_pos = rank * M_per_rank + chunk_idx_intra_rank * M_PER_CHUNK M_src_start_pos = chunk_idx_intra_rank * M_PER_CHUNK src_ptr = local_tensor.data_ptr() + M_src_start_pos * N * data_elem_size dst_ptr = remote_tensor_buffers[remote_rank].data_ptr() + M_dst_start_pos * N * data_elem_size nbytes = M_PER_CHUNK * N * data_elem_size cp_res = hip.hipMemcpyAsync( dst_ptr, src_ptr, nbytes, hip.hipMemcpyKind.hipMemcpyDeviceToDeviceNoCU, ag_stream.cuda_stream, ) HIP_CHECK(cp_res) # Set signal through copy-engine. This is a trick for AMD GPU. cp_res = hip.hipMemcpyAsync( barrier_buffers[remote_rank].data_ptr() + chunk_pos * barrier_elem_size, one.data_ptr(), barrier_elem_size, hip.hipMemcpyKind.hipMemcpyDeviceToDeviceNoCU, ag_stream.cuda_stream, ) HIP_CHECK(cp_res) AG GEMM Class ------------- .. code-block:: Python class triton_ag_gemm_intra_node(torch.nn.Module): def __init__( self, tp_group: torch.distributed.ProcessGroup, max_M: int, N: int, K: int, M_PER_CHUNK: int, input_dtype: torch.dtype, output_dtype: torch.dtype, ): self.tp_group = tp_group self.rank: int = tp_group.rank() self.world_size = tp_group.size() self.max_M: int = max_M self.N = N self.K = K self.M_PER_CHUNK = M_PER_CHUNK self.input_dtype = input_dtype self.output_dtype = output_dtype # Use the auxiliary functions provided by Triton-distributed to construct the context required for AG-GEMM. This simplifies the code logic. # The context mainly includes: # (1) The globally symmetric memory required for AllGather; # (2) The signals used for communication between AG and GEMM; # (3) Streams. self.ctx = create_ag_gemm_intra_node_context( self.max_M, self.N, self.K, self.input_dtype, self.output_dtype, self.rank, self.world_size, self.tp_group, M_PER_CHUNK=M_PER_CHUNK, ) def forward(self, input: torch.Tensor, # [local_M, K] weight: torch.Tensor, # [local_N, K] transed_weight: bool, # indicates whether weight already transposed ): current_stream = torch.cuda.current_stream() ctx = self.ctx M = self.max_M K = self.K N_PER_RANK = weight.shape[1] if transed_weight else weight.shape[0] # Create empty output buffer. output = torch.zeros([M, N_PER_RANK], dtype=input.dtype, device=input.device) + 1.0 # Barrier before the current AG-GEMM operation. This step is necessary when running multiple rounds of computations. torch.cuda.synchronize() barrier_all_on_stream(self.rank, ctx.num_ranks, ctx.comm_buf_ptr, current_stream) # Sync work streams. for ag_stream in ctx.ag_streams: ag_stream.wait_stream(current_stream) # producer allgather producer_ag_push_mode(ctx.rank, ctx.num_ranks, input, ctx.workspace_tensors, ctx.one, ctx.M_PER_CHUNK, ctx.ag_streams, ctx.barrier_tensors) torch.cuda.synchronize() torch.distributed.barrier() # consumer gemm NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count NUM_XCDS = 4 grid = lambda META: (min( NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N_PER_RANK, META["BLOCK_SIZE_N"]), ), ) full_input = ctx.workspace_tensors[ctx.rank][:M] local_input = input consumer_gemm_persistent_kernel[grid]( full_input, local_input, weight, output, M, N_PER_RANK, K, full_input.stride(0), full_input.stride(1), weight.stride(1), weight.stride(0), output.stride(0), output.stride(1), ctx.rank, ctx.num_ranks, ctx.barrier_tensors[ctx.rank], M_PER_CHUNK=ctx.M_PER_CHUNK, NUM_SMS=NUM_SMS, NUM_XCDS=NUM_XCDS, ) return output Benchmark --------- .. code-block:: Python def torch_ag_gemm( input: torch.Tensor, # [local_M, k] weight: torch.Tensor, # [local_N, K] transed_weight: bool, bias: Optional[torch.Tensor], TP_GROUP, ): local_M, K = input.shape world_size = TP_GROUP.size() if transed_weight: assert K == weight.shape[0] else: assert K == weight.shape[1] weight = weight.T assert input.device == weight.device # AG + Gemm full_input = torch.empty((local_M * world_size, K), dtype=input.dtype, device=input.device) torch.distributed.all_gather_into_tensor(full_input, input, group=TP_GROUP) output = torch.matmul(full_input, weight) if bias: output = output + bias return output def init(): RANK = int(os.environ.get("RANK", 0)) LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0)) WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) torch.cuda.set_device(LOCAL_RANK) torch.distributed.init_process_group( backend="nccl", world_size=WORLD_SIZE, rank=RANK, timeout=datetime.timedelta(seconds=1800), ) assert torch.distributed.is_initialized() TP_GROUP = torch.distributed.new_group(ranks=list(range(WORLD_SIZE)), backend="nccl") torch.distributed.barrier(TP_GROUP) torch.use_deterministic_algorithms(False, warn_only=True) torch.set_printoptions(precision=5) torch.manual_seed(3 + RANK) torch.cuda.manual_seed_all(3 + RANK) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False np.random.seed(3 + RANK) torch.cuda.synchronize() torch.distributed.barrier() return RANK, LOCAL_RANK, WORLD_SIZE, TP_GROUP def destroy(): torch.cuda.synchronize() torch.distributed.barrier() torch.distributed.destroy_process_group() if __name__ == "__main__": # init RANK, LOCAL_RANK, WORLD_SIZE, TP_GROUP = init() # NOTE: We should get device after process group init. DEVICE = triton.runtime.driver.active.get_active_torch_device() dtype = torch.float16 M = 8192 N = 11008 K = 4096 chunk_size = 256 local_M = M // WORLD_SIZE local_N = N // WORLD_SIZE input_dtype = dtype output_dtype = input_dtype atol = 1e-2 rtol = 1e-2 # Generate input and weight. scale = TP_GROUP.rank() + 1 data_config = [((local_M, K), dtype, (0.01 * scale, 0), DEVICE), # input ((local_N, K), dtype, (0.01 * scale, 0), DEVICE), # weight (None), # bias ] generator = generate_data(data_config) input, weight, bias = next(generator) # torch ref_out = torch_ag_gemm(input, weight, False, bias, TP_GROUP) torch.cuda.synchronize() torch.distributed.barrier() # dist triton dist_ag_gemm_op = triton_ag_gemm_intra_node(TP_GROUP, M, N, K, chunk_size, input_dtype, output_dtype) tri_out = dist_ag_gemm_op.forward(input, weight, False) if torch.allclose(tri_out, ref_out, atol=atol, rtol=rtol): dist_print("✅ Triton and Torch match") else: dist_print(f"The maximum difference between torch and triton is {torch.max(torch.abs(tri_out - ref_out))}") dist_print("❌ Triton and Torch differ") # After all, destroy distributed process group. destroy()