.. _sphx_glr_getting-started_tutorials_07-overlapping-allgather-gemm.rst:

Overlapping AllGather GEMM
==========================
In this tutorial, you will write a simple AllGather GEMM fusion kernel using Triton-distributed.

In doing so, you will learn about:

* Writing a GEMM kernel that consumes the results of AllGather.

* Optimizing the internode communication with 2D AllGather.

.. code-block:: bash

    # To run this tutorial
    source ./scripts/setenv.sh
    bash ./scripts/launch.sh tutorials/07-overlapping-allgather-gemm.py

GEMM Kernel
-----------
In the AllGather kernel, we previously output signals that can be used to interact with computation kernels.
For instance, in an AllGather GEMM example, we can achieve overlapping computation and communication.
The GEMM itself doesn't require extensive modifications; you just need to add a small amount of code
compared to Triton's default GEMM.

.. code-block:: Python

    def _matmul_launch_metadata(grid, kernel, args):
        # Build the launch-metadata dict: a display name carrying the problem
        # size, the FLOP count (keyed by element bit-width) and the bytes touched.
        ret = {}
        M, N, K = args["M"], args["N"], args["K"]
        ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
        if "c_ptr" in args:
            bytes_per_elem = args["c_ptr"].element_size()
        else:
            bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
        ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K
        ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
        return ret


    @triton_dist.jit(launch_metadata=_matmul_launch_metadata)
    def kernel_consumer_gemm_persistent(a_ptr, b_ptr, c_ptr,  #
                                        M, N, K,  #
                                        rank: tl.constexpr, num_ranks: tl.constexpr, ready_ptr, comm_buf_ptr,
                                        BLOCK_SIZE_M: tl.constexpr,  #
                                        BLOCK_SIZE_N: tl.constexpr,  #
                                        BLOCK_SIZE_K: tl.constexpr,  #
                                        GROUP_SIZE_M: tl.constexpr,  #
                                        EPILOGUE_SUBTILE: tl.constexpr,  #
                                        NUM_SMS: tl.constexpr, ready_value: tl.constexpr = 1,
                                        local_world_size: tl.constexpr = 8):  #
        # Matmul using TMA and device-side descriptor creation
        dtype = c_ptr.dtype.element_ty
        start_pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
        num_tiles = num_pid_m * num_pid_n
        node_id = rank // local_world_size
        nnodes = num_ranks // local_world_size

        # A is [M, K], B is [N, K] (both with unit stride along K), C is [M, N].
        a_desc = tl.make_tensor_descriptor(
            a_ptr,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
        )
        b_desc = tl.make_tensor_descriptor(
            b_ptr,
            shape=[N, K],
            strides=[K, 1],
            block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
        )
        c_desc = tl.make_tensor_descriptor(
            c_ptr,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[
                BLOCK_SIZE_M,
                BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2,
            ],
        )

        # Output tiles are dealt out round-robin over NUM_SMS programs; the first
        # (num_tiles % NUM_SMS) programs take one extra tile.
        tiles_per_SM = num_tiles // NUM_SMS
        if start_pid < num_tiles % NUM_SMS:
            tiles_per_SM += 1

        # Start one stride *before* the first tile: tile_id is advanced by NUM_SMS
        # every time ki wraps to 0, including on the very first iteration.
        tile_id = start_pid - NUM_SMS
        ki = -1

        pid_m = 0
        pid_n = 0
        offs_am = 0
        offs_bn = 0

        M_per_rank = M // num_ranks
        pid_ms_per_rank = tl.cdiv(M_per_rank, BLOCK_SIZE_M)

        num_pid_in_group = GROUP_SIZE_M * num_pid_n

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        # One flattened loop over (tile, k-step): ki cycles through 0..k_tiles-1.
        for _ in range(0, k_tiles * tiles_per_SM):
            ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
            if ki == 0:
                tile_id += NUM_SMS
                group_id = tile_id // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id % group_size_m)
                pid_n = (tile_id % num_pid_in_group) // group_size_m

                # swizzle m
                # Shift the M tile index by an offset derived from this rank, so
                # that different ranks visit the row segments in different orders
                # (presumably starting with the rows that are available first).
                if nnodes == 1:
                    alpha = 0
                    beta = 0
                    pid_m = (pid_m + ((((rank ^ alpha) + beta) % num_ranks) * pid_ms_per_rank)) % num_pid_m
                else:
                    m_rank = pid_m // pid_ms_per_rank
                    pid_m_intra_rank = pid_m - m_rank * pid_ms_per_rank
                    m_node_id = m_rank // local_world_size
                    m_local_rank = m_rank % local_world_size
                    swizzle_m_node_id = (m_node_id + node_id) % nnodes
                    swizzle_m_local_rank = (m_local_rank + rank) % local_world_size
                    swizzle_m_rank = swizzle_m_node_id * local_world_size + swizzle_m_local_rank
                    pid_m = swizzle_m_rank * pid_ms_per_rank + pid_m_intra_rank

                offs_am = pid_m * BLOCK_SIZE_M
                offs_bn = pid_n * BLOCK_SIZE_N

                # Source ranks whose rows are covered by this M tile.
                rank_beg = offs_am // M_per_rank
                rank_end = (min(offs_am + BLOCK_SIZE_M, M) - 1) // M_per_rank
                # Each tile waits until the corresponding data is ready
                token = dl.wait(ready_ptr + rank_beg, rank_end - rank_beg + 1, "gpu", "acquire", waitValue=ready_value)
                a_desc = dl.consume_token(a_desc, token)

            offs_k = ki * BLOCK_SIZE_K

            # Iteration along k-dimension, and performing multiply.
            a = a_desc.load([offs_am, offs_k])
            b = b_desc.load([offs_bn, offs_k])
            accumulator = tl.dot(a, b.T, accumulator)

            # Last k-step of the current tile: write the result and reset.
            if ki == k_tiles - 1:
                if EPILOGUE_SUBTILE:
                    # Store the tile as two half-width pieces along N.
                    acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                    acc = tl.permute(acc, (0, 2, 1))
                    acc0, acc1 = tl.split(acc)
                    c0 = acc0.to(dtype)
                    c_desc.store([offs_am, offs_bn], c0)
                    c1 = acc1.to(dtype)
                    c_desc.store([offs_am, offs_bn + BLOCK_SIZE_N // 2], c1)
                else:
                    c = accumulator.to(dtype)
                    c_desc.store([offs_am, offs_bn], c)

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

AllGather Kernel
----------------
To fully utilize the bandwidth, the internode AllGather kernel is composed of two parts,
reflecting the bandwidth gap between intra-node links and inter-node links.

.. code-block:: Python

    def inter_node_allgather(local_tensor: torch.Tensor, ag_buffer: list[torch.Tensor], signal_buffer: list[torch.Tensor],
                             signal_target, rank, local_world_size, world_size, intranode_ag_stream, internode_ag_stream):
        """Launch the two-level AllGather: internode puts, then intranode copies.

        The internode put kernel is launched on ``internode_ag_stream`` and the
        copy-engine transfers are enqueued on ``intranode_ag_stream``; finally
        ``intranode_ag_stream`` is made to wait for ``internode_ag_stream``.
        """
        local_rank = rank % local_world_size
        n_nodes = world_size // local_world_size

        M_per_rank, N = local_tensor.shape

        # Each rank sends the local_tensor to ranks of other nodes with the same local_rank
        # Assuming there are 2 nodes, each with 4 workers
        # 0-th local tensor ([0] -> [4]), 4-th local tensor ([4] -> [0])
        # 1-th local tensor ([1] -> [5]), 5-th local tensor ([5] -> [1])
        # 2-th local tensor ([2] -> [6]), 6-th local tensor ([6] -> [2])
        # 3-th local tensor ([3] -> [7]), 7-th local tensor ([7] -> [3])
        with torch.cuda.stream(internode_ag_stream):
            grid = lambda META: (int(n_nodes - 1), )
            nvshmem_device_producer_p2p_put_block_kernel[grid](
                ag_buffer[local_rank],
                signal_buffer[local_rank],
                M_per_rank * N,
                local_tensor.element_size(),
                signal_target,
                rank,
                local_world_size,
                world_size,
                num_warps=32,  # each sm launches 1024 threads
            )

        # Each rank sends the local_tensor and the received internode tensors to intranode ranks.
        # 0-th and 4-th local tensors ([0]->[1,2,3])
        # 1-th and 5-th local tensors ([1]->[0,2,3])
        # 2-th and 6-th local tensors ([2]->[0,1,3])
        # 3-th and 7-th local tensors ([3]->[0,1,2])
        # 0-th and 4-th local tensors ([4]->[5,6,7])
        # 1-th and 5-th local tensors ([5]->[4,6,7])
        # 2-th and 6-th local tensors ([6]->[4,5,7])
        # 3-th and 7-th local tensors ([7]->[4,5,6])
        with torch.cuda.stream(intranode_ag_stream):
            cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N, signal_target, rank,
                                              local_world_size, world_size, intranode_ag_stream)

        intranode_ag_stream.wait_stream(internode_ag_stream)

Let's declare a function to perform internode communication.

.. code-block:: Python

    @triton_dist.jit
    def nvshmem_device_producer_p2p_put_block_kernel(
        ag_buffer_ptr,  # *Pointer* to allgather output vector. The rank-th index has been loaded with local tensor
        signal_buffer_ptr,  # *Pointer* to signal barrier.
        elem_per_rank,
        size_per_elem,
        signal_target,
        rank,
        local_world_size,
        world_size,
    ):
        pid = tl.program_id(axis=0)
        num_pid = tl.num_programs(axis=0)

        n_nodes = world_size // local_world_size
        local_rank = rank % local_world_size
        node_rank = rank // local_world_size

        for i in range(pid, n_nodes - 1, num_pid):
            # Each SM is assigned to one peer.
            # Peer id is calculated based on pid and local_rank.
            peer = local_rank + (node_rank + i + 1) % n_nodes * local_world_size
            # We use putmem_signal_block to send data and notify the peer.
            # Since this is the allgather operation, the offsets of both src and dst tensor are both *rank*.
            libshmem_device.putmem_signal_block(
                ag_buffer_ptr + rank * elem_per_rank,
                ag_buffer_ptr + rank * elem_per_rank,
                elem_per_rank * size_per_elem,
                signal_buffer_ptr + rank,
                signal_target,
                libshmem_device.NVSHMEM_SIGNAL_SET,
                peer,
            )

Let's also declare a function to perform intranode communication.

.. code-block:: Python

    def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N, signal_target, rank,
                                          local_world_size, world_size, intranode_ag_stream):
        """Distribute tensors inside the node with ``cudaMemcpyAsync``.

        Phase 1 copies this rank's own segment to every other rank of the node.
        Phase 2 waits for each segment received from another node and forwards it
        to every other rank of the node. After each copy the matching entry of the
        destination's signal buffer is set to ``signal_target``.

        Raises:
            RuntimeError: if ``cudaMemcpyAsync`` reports an error.
        """
        local_rank = rank % local_world_size
        n_nodes = world_size // local_world_size
        node_rank = rank // local_world_size

        # Loop-invariant quantities: bytes per segment and this rank's byte offset.
        elem_size = local_tensor.element_size()
        nbytes = M_per_rank * N * elem_size
        segment = rank * M_per_rank * N

        for i in range(1, local_world_size):
            local_dst_rank = (local_rank + local_world_size - i) % local_world_size
            src_ptr = ag_buffer[local_rank].data_ptr() + segment * elem_size
            dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * elem_size
            # Using copy engine to perform intranode transmission
            # Sending rank-th local tensor to other ranks inside the node.
            (err, ) = cudart.cudaMemcpyAsync(
                dst_ptr,
                src_ptr,
                nbytes,
                cudart.cudaMemcpyKind.cudaMemcpyDefault,
                intranode_ag_stream.cuda_stream,
            )
            # Do not signal the peer for a copy that was never enqueued.
            if err != cudart.cudaError_t.cudaSuccess:
                raise RuntimeError(f"cudaMemcpyAsync failed: {err}")
            # Notify the peer that the transmission is done.
            set_signal(signal_buffer[local_dst_rank][rank].data_ptr(), signal_target, intranode_ag_stream, True)

        for i in range(1, n_nodes):
            # Rank on another node that shares our local_rank.
            recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
            recv_segment = recv_rank * M_per_rank * N
            # Waiting for the internode data ready
            wait_eq(signal_buffer[local_rank][recv_rank].data_ptr(), signal_target, intranode_ag_stream, True)
            src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * elem_size
            for j in range(1, local_world_size):
                local_dst_rank = (local_rank + local_world_size - j) % local_world_size
                dst_ptr = ag_buffer[local_dst_rank].data_ptr() + recv_segment * elem_size
                # Forwarding the recv_rank-th tensor (received from another node) to other ranks inside the node.
                (err, ) = cudart.cudaMemcpyAsync(
                    dst_ptr,
                    src_ptr,
                    nbytes,
                    cudart.cudaMemcpyKind.cudaMemcpyDefault,
                    intranode_ag_stream.cuda_stream,
                )
                if err != cudart.cudaError_t.cudaSuccess:
                    raise RuntimeError(f"cudaMemcpyAsync failed: {err}")
                # Notify the peer that the transmission is done.
                set_signal(signal_buffer[local_dst_rank][recv_rank].data_ptr(), signal_target, intranode_ag_stream, True)

AllGather GEMM Kernel
---------------------
Now we combine all the kernels here.

.. code-block:: Python

    def ag_gemm_persistent_op(a, b, c, rank, num_ranks, workspace_tensors, barrier_tensors, comm_buf, ag_stream=None,
                              internode_ag_stream=None, BLOCK_M=128, BLOCK_N=256, BLOCK_K=64, stages=3,
                              local_world_size=8, signal_target=1):
        """Launch the AllGather on side streams and the consumer GEMM on the current stream.

        Args:
            a: local ``[M_per_rank, K]`` input.
            b: local ``[N_per_rank, K]`` weight.
            c: output, written by the GEMM kernel.
            rank, num_ranks: global rank and world size.
            workspace_tensors, barrier_tensors: per-local-rank AllGather buffers
                and signal buffers, indexed by local rank.
            comm_buf: forwarded to the GEMM kernel unchanged.
            ag_stream: stream for intranode copies; a new stream if ``None``.
            internode_ag_stream: stream for the internode put kernel; a new
                stream if ``None``.
            BLOCK_M, BLOCK_N, BLOCK_K, stages: GEMM tiling / pipelining knobs.
            local_world_size: number of ranks per node.
            signal_target: value that marks a segment as ready.

        Returns:
            Whatever the GEMM kernel launch returns.
        """
        assert a.shape[1] == b.shape[1], "Incompatible dimensions"
        assert a.dtype == b.dtype, "Incompatible dtypes"

        M_per_rank, K = a.shape
        M = M_per_rank * num_ranks
        N_per_rank, K = b.shape

        local_rank = rank % local_world_size
        n_nodes = num_ranks // local_world_size
        num_ag_sms = n_nodes - 1  # only use n_node-1 SMs for internode communication
        num_gemm_sms = torch.cuda.get_device_properties("cuda").multi_processor_count - num_ag_sms

        ag_stream = torch.cuda.Stream() if ag_stream is None else ag_stream
        internode_ag_stream = torch.cuda.Stream() if internode_ag_stream is None else internode_ag_stream
        current_stream = torch.cuda.current_stream()
        # Both communication streams must see the work already queued on the
        # current stream: the internode put kernel reads this rank's segment of
        # the workspace, just like the intranode copies do.
        ag_stream.wait_stream(current_stream)
        internode_ag_stream.wait_stream(current_stream)

        inter_node_allgather(a, workspace_tensors, barrier_tensors, signal_target, rank, local_world_size, num_ranks,
                             ag_stream, internode_ag_stream)

        def alloc_fn(size: int, alignment: int, stream: Optional[int]):
            return torch.empty(size, device="cuda", dtype=torch.int8)

        triton.set_allocator(alloc_fn)

        grid = lambda META: (min(
            num_gemm_sms,
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N_per_rank, META["BLOCK_SIZE_N"]),
        ), )

        compiled = kernel_consumer_gemm_persistent[grid](
            workspace_tensors[local_rank][:M], b, c,  #
            M, N_per_rank, K,  #
            rank, num_ranks, barrier_tensors[local_rank], comm_buf, BLOCK_M, BLOCK_N, BLOCK_K, 8, False,
            NUM_SMS=num_gemm_sms, ready_value=signal_target,
            # Forward the real node size; the kernel default (8) would otherwise
            # be used to derive node_id / nnodes for the tile swizzle.
            local_world_size=local_world_size,
            num_stages=stages, num_warps=8,
        )

        current_stream.wait_stream(internode_ag_stream)
        current_stream.wait_stream(ag_stream)

        return compiled

Benchmark
---------

.. code-block:: Python

    def torch_ag_gemm(
        pg: torch.distributed.ProcessGroup,
        local_input: torch.Tensor,
        local_weight: torch.Tensor,
        ag_out: torch.Tensor,
    ):
        """Reference implementation: all-gather into ``ag_out``, then matmul."""
        torch.distributed.all_gather_into_tensor(ag_out, local_input, pg)
        ag_gemm_output = torch.matmul(ag_out, local_weight)
        return ag_gemm_output


    if __name__ == "__main__":
        import sys

        WORLD_SIZE = int(os.getenv("WORLD_SIZE", "-1"))
        LOCAL_WORLD_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "-1"))
        if WORLD_SIZE == LOCAL_WORLD_SIZE:
            print("Skip the test because this should be performed with 2 nodes or higher")
            sys.exit()

        if torch.cuda.get_device_capability()[0] < 9:
            print("Skip the test because the device is not sm90 or higher")
            sys.exit()

        TP_GROUP = initialize_distributed()
        rank = TP_GROUP.rank()
        M = 8192
        N = 49152
        K = 12288
        config = {"BM": 128, "BN": 256, "BK": 64, "stage": 3}
        dtype = torch.float16
        assert M % WORLD_SIZE == 0
        assert N % WORLD_SIZE == 0
        M_per_rank = M // WORLD_SIZE
        N_per_rank = N // WORLD_SIZE

        A = torch.randn([M_per_rank, K], dtype=dtype, device="cuda")
        B = torch.randn([N_per_rank, K], dtype=dtype, device="cuda")
        ag_buffer = torch.empty([M, K], dtype=dtype, device="cuda")
        golden = torch_ag_gemm(TP_GROUP, A, B.T, ag_buffer)

        # We can use a context to wrap all the tensors used at runtime.
        # We rely on NVSHMEM to allocate the symmetric memory for communication
        # In practice, the following parts are encapsulated in ag_gemm_inter_node() of triton_dist.kernels.nvidia.allgather_gemm.py
        C = torch.empty([M, N_per_rank], dtype=dtype, device="cuda")
        ctx = create_ag_gemm_context(A, B, rank, WORLD_SIZE, max_M=M, BLOCK_M=config["BM"], BLOCK_N=config["BN"],
                                     BLOCK_K=config["BK"], stages=config["stage"])
        ctx.symm_barrier.fill_(0)
        nvshmem_barrier_all_on_stream(torch.cuda.current_stream())

        # copy local data to the ctx
        ctx.symm_workspace[rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
        set_signal(ctx.symm_barrier[rank].data_ptr(), 1, torch.cuda.current_stream(), True)

        # launch the ag_gemm kernel
        ag_gemm_persistent_op(A, B, C, ctx.rank, ctx.num_ranks, ctx.symm_workspaces, ctx.symm_barriers, ctx.symm_comm_buf,
                              ag_stream=ctx.ag_intranode_stream, internode_ag_stream=ctx.ag_internode_stream,
                              local_world_size=LOCAL_WORLD_SIZE, signal_target=1)

        assert torch.allclose(golden, C, atol=1e-3, rtol=1e-3)
        print("Pass!")
        ctx.finalize()
        nvshmem.core.finalize()
        torch.distributed.destroy_process_group()