Distributed Notify and Wait
In this tutorial, you will write a simple notify and wait example using Triton-distributed.
In doing so, you will learn about:
The concept of signal exchange within a single node using Triton-distributed.
The `wait`, `consume_token`, and `notify` primitives, which are used for signal exchange.
Distributed runtime initialization and symmetric tensor management.
How to write a producer-consumer data transfer through a small queue.
# To run this tutorial
source ./scripts/setenv.sh
bash ./scripts/launch.sh tutorials/01-distributed-notify-wait.py
Kernel
In this example, the kernel is divided into two parts: one group of blocks acts as producers and the other as consumers. The producers on each rank send data to the consumers on the next rank through a small, symmetric memory queue. Because the queue is much smaller than the data, its slots are reused, so a handshake on per-slot signals is required: it ensures a slot is empty before a producer writes to it and full before a consumer reads from it.
import torch
import nvshmem.core
import triton.language as tl
import triton_dist.language as dl
from triton_dist.utils import (
NVSHMEM_SIGNAL_DTYPE,
dist_print,
initialize_distributed,
nvshmem_barrier_all_on_stream,
nvshmem_free_tensor_sync,
nvshmem_create_tensor,
)
from triton_dist.language.extra.language_extra import __syncthreads
import triton_dist
# Producer/consumer notify-wait kernel.
#
# Programs [0, NUM_PRODUCER_SMS) are producers: for each input block i they copy
# BLOCK_SIZE elements from `input_ptr` into slot `i % queue_size` of the queue that
# lives on rank `(rank + 1) % num_ranks`.
# Programs [NUM_PRODUCER_SMS, NUM_PRODUCER_SMS + NUM_CONSUMER_SMS) are consumers:
# for each block i they copy slot `i % queue_size` of this rank's queue into
# `output_ptr`. Any further programs do nothing.
#
# Per-slot signal protocol (the signal lives on the rank that owns the queue slot),
# for the r-th reuse of a slot (r = i // queue_size):
#   2r     -> slot is empty; the producer may write it
#   2r + 1 -> slot is full; set by the producer after writing
#   2r + 2 -> slot consumed; set by the consumer, i.e. "empty" for round r + 1
# Because round 0 waits for the value 0, signals must be reset to 0 before each launch.
#
# NOTE(review): producers and consumers spin on each other's signals, so this
# presumably requires all launched programs to be resident on the GPU at once.
@triton_dist.jit
def producer_consumer_kernel(
    rank: tl.constexpr,  # This rank's index
    num_ranks: tl.constexpr,  # Number of ranks in the ring
    input_ptr,  # Local source data, read as num_inputs blocks of BLOCK_SIZE elements
    output_ptr,  # Local destination, written as num_inputs blocks of BLOCK_SIZE elements
    num_inputs: int,  # Number of BLOCK_SIZE-element blocks to transfer
    queue_ptr,  # Symmetric queue buffer: queue_size * BLOCK_SIZE elements on each rank
    signal_ptr,  # *Pointer* to the symmetric per-slot signals (queue_size entries per rank)
    queue_size: tl.constexpr,  # The length of the queue in units of BLOCKs
    BLOCK_SIZE: tl.constexpr,  # The number of elements in each BLOCK
    NUM_PRODUCER_SMS: tl.constexpr,  # Number of programs acting as producers
    NUM_CONSUMER_SMS: tl.constexpr,  # Number of programs acting as consumers
):
    pid = tl.program_id(0)
    # This kernel issues async-tasks to two groups of blocks: producers and consumers
    if pid < NUM_PRODUCER_SMS:
        #
        # Producer
        #
        peer_rank = (rank + 1) % num_ranks  # Peer is the next rank
        offs = tl.arange(0, BLOCK_SIZE)
        # Each producer handles blocks pid, pid + NUM_PRODUCER_SMS, ...
        for i in range(pid, num_inputs, NUM_PRODUCER_SMS):
            queue_offset = i % queue_size  # Slot index in the circular queue
            queue_repeat = i // queue_size  # How many times this slot has been reused before
            token = dl.wait(
                # Use `symm_at` to map the signal pointer to the remote peer rank
                dl.symm_at(signal_ptr, peer_rank) + queue_offset,
                1,  # The number of signals to wait
                "sys",  # The scope of the barrier, `gpu` or `sys`
                "acquire",  # The semantic of the wait
                waitValue=queue_repeat * 2,  # Even value 2r: the slot is empty for round r
            )  # This wait ensures that the corresponding position on the peer is empty
            # Consume the token so the load below depends on (is ordered after) the `wait`
            input_ptr = dl.consume_token(input_ptr, token)
            data = tl.load(input_ptr + i * BLOCK_SIZE + offs)
            # Use `symm_at` to map the queue pointer to the remote peer rank
            tl.store(dl.symm_at(queue_ptr, peer_rank) + queue_offset * BLOCK_SIZE + offs, data)
            # need a syncthreads to make sure all the data has been sent
            __syncthreads()
            # `notify` is also single thread scope, we need to use a different thread from `wait`
            dl.notify(signal_ptr + queue_offset, peer_rank,  # Notify the signal object on the peer rank
                      signal=queue_repeat * 2 + 1,  # Odd value 2r + 1: the slot is full
                      sig_op="set",  # Set the value. Another choice is `add`
                      comm_scope="intra_node",  # This example is intra-node, another choice is `inter_node`
                      )  # This notifies the peer's consumer that the data is ready
    elif pid < NUM_PRODUCER_SMS + NUM_CONSUMER_SMS:
        #
        # Consumer
        #
        pid = pid - NUM_PRODUCER_SMS  # Re-base to a consumer-local program index
        offs = tl.arange(0, BLOCK_SIZE)
        # Each consumer handles blocks pid, pid + NUM_CONSUMER_SMS, ...
        for i in range(pid, num_inputs, NUM_CONSUMER_SMS):
            queue_offset = i % queue_size
            queue_repeat = i // queue_size
            token = dl.wait(signal_ptr + queue_offset,  # The base *Pointer* of signals at the current rank
                            1,  # The number of signals to wait
                            "sys",  # The scope of the barrier
                            "acquire",  # The semantic of the wait
                            waitValue=queue_repeat * 2 + 1,  # Odd value 2r + 1: the slot is full
                            )  # This wait ensures that the corresponding position is full
            # Consume the token so the load from the queue is ordered after the `wait`
            queue_ptr = dl.consume_token(queue_ptr, token)
            data = tl.load(queue_ptr + queue_offset * BLOCK_SIZE + offs)
            tl.store(output_ptr + i * BLOCK_SIZE + offs, data)
            # Make sure all threads are done with the slot before it is released
            __syncthreads()
            dl.notify(signal_ptr + queue_offset, rank,  # Notify the signal object on the current rank
                      signal=queue_repeat * 2 + 2,  # Value 2r + 2: the slot is empty again (round r + 1)
                      sig_op="set",  # Set the value. Another choice is `add`
                      comm_scope="intra_node",  # This example is intra-node, another choice is `inter_node`
                      )  # This notifies the producer on the previous rank that the slot can be reused
    else:
        # Programs beyond producers + consumers have no work
        pass
Initialize the Distributed System
Here, we show you how to initialize a distributed system for our example. Triton-distributed provides a convenient initialize_distributed() function that handles:
Setting up PyTorch distributed with NCCL backend
Initializing NVSHMEM with a unique ID
Creating the tensor parallel group
# Simply import initialize_distributed() and call it once at startup:
#
#     TP_GROUP = initialize_distributed()
#
# It sets up torch.distributed and NVSHMEM, so it should run only once per process.
# Calling it here as well as at the bottom of this script would initialize everything
# twice; the actual call is therefore made once, right before main() is invoked.
from triton_dist.utils import initialize_distributed
Test the Correctness
Let’s now check our notify and wait kernel for correctness.
# Test configuration, shared by main() and the kernel launch.
# Each rank moves INPUT_SIZE blocks of BLOCK_SIZE float32 elements through a queue
# that holds only QUEUE_SIZE blocks, so every queue slot is reused many times.
INPUT_SIZE, QUEUE_SIZE, BLOCK_SIZE = 2025, 32, 128
def main(TP_GROUP):
    """Run the notify/wait producer-consumer test repeatedly and verify the results.

    Every rank streams its own ``input_data`` (``INPUT_SIZE`` blocks of
    ``BLOCK_SIZE`` float32 values) to rank ``(rank + 1) % num_ranks`` through a
    symmetric queue that holds only ``QUEUE_SIZE`` blocks. After each run, the
    ``output_data`` of rank ``r`` must therefore equal the input of rank
    ``(r - 1) % num_ranks``; this is checked against an ``all_gather`` of the inputs.

    Args:
        TP_GROUP: process group returned by ``initialize_distributed()``; its
            ``rank()`` and ``size()`` define this rank and the ring size.

    Raises:
        AssertionError: if any iteration's output differs from the previous
            rank's input.
    """
    stream = torch.cuda.current_stream()
    # The created tensor is by-default on current cuda device
    # Use nvshmem_create_tensor for symmetric memory allocation
    queue = nvshmem_create_tensor(
        (QUEUE_SIZE * BLOCK_SIZE, ),  # Shape on each device
        torch.float32,
    )
    # Use NVSHMEM_SIGNAL_DTYPE for signal tensors (uint64)
    signal = nvshmem_create_tensor((QUEUE_SIZE, ), NVSHMEM_SIGNAL_DTYPE)
    queue.fill_(-1)
    signal.fill_(0)  # The initial value of signal should be 0s
    # You need a barrier all to make sure the above initialization
    # is visible to all the other ranks.
    # This is usually used for intra-node.
    nvshmem_barrier_all_on_stream(stream)

    # Distributed info
    rank = TP_GROUP.rank()
    num_ranks = TP_GROUP.size()
    all_ranks = list(range(num_ranks))

    # Launch configuration. The grid size is derived from the producer/consumer
    # split so the two values can never disagree.
    num_producer_sms = 16
    num_consumer_sms = 4
    grid = (num_producer_sms + num_consumer_sms, )

    numel = INPUT_SIZE * BLOCK_SIZE
    # Allocate directly on the current CUDA device (no host allocation + copy).
    output_data = torch.empty((numel, ), dtype=torch.float32, device="cuda")

    num_repeats = 20
    # For distributed programming, you have to run it multiple times to ensure
    # your program is correct, including resetting signals, avoiding races, etc.
    for iters in range(num_repeats):
        last_iter = iters == num_repeats - 1
        input_data = torch.randn((numel, ), dtype=torch.float32, device="cuda")
        # Need to reset the signals every time, since the kernel's handshake starts at 0.
        # You may also omit this step for better performance by using flipping
        # barriers. We will cover this optimization in a future tutorial.
        # TODO: tutorial for flipping barriers.
        signal.fill_(0)
        nvshmem_barrier_all_on_stream(stream)
        producer_consumer_kernel[grid](
            rank,
            num_ranks,
            input_data,
            output_data,
            INPUT_SIZE,
            queue,
            signal,
            QUEUE_SIZE,
            BLOCK_SIZE,
            num_producer_sms,
            num_consumer_sms,
            num_warps=4,
        )

        # Check results: this rank must have received the previous rank's input.
        inputs_all_ranks = [torch.empty_like(input_data) for _ in range(num_ranks)]
        torch.distributed.all_gather(inputs_all_ranks, input_data, group=TP_GROUP)
        golden = inputs_all_ranks[(rank - 1 + num_ranks) % num_ranks]
        if last_iter:
            dist_print(f"rank{rank}", output_data, need_sync=True, allowed_ranks=all_ranks)
            dist_print(f"rank{rank}", golden, need_sync=True, allowed_ranks=all_ranks)
        assert torch.allclose(output_data, golden, atol=1e-5, rtol=1e-5)
        if last_iter:
            dist_print(f"rank{rank} Passed✅!", need_sync=True, allowed_ranks=all_ranks)

    # Clean up symmetric memory
    nvshmem_free_tensor_sync(queue)
    nvshmem_free_tensor_sync(signal)
# Initialize the distributed system (torch.distributed + NVSHMEM) exactly once.
# If the process group is already up (e.g. initialize_distributed() was already
# called earlier in this file), reuse the existing TP_GROUP instead of
# initializing a second time.
if not torch.distributed.is_initialized():
    TP_GROUP = initialize_distributed()
# The main function
main(TP_GROUP)
# Finalize: release NVSHMEM first, then tear down torch.distributed.
nvshmem.core.finalize()
torch.distributed.destroy_process_group()