Distributed Notify and Wait

In this tutorial, you will write a simple notify and wait example using Triton-distributed.

In doing so, you will learn about:

  • The signal exchange concept within a single node using Triton-distributed.

  • The wait, consume_token, and notify primitives, which are used to do signal exchange.

  • The distributed runtime initialization and symmetric tensor management.

  • How to write a producer-consumer data transfer through a small queue.

# To run this tutorial
source ./scripts/setenv.sh
bash ./scripts/launch.sh tutorials/01-distributed-notify-wait.py

Kernel

In this example, the kernel is divided into two parts: one part acts as a producer and the other as a consumer. They transfer data between each other via a small, symmetric memory buffer, requiring handshaking to ensure the buffer’s positions are available for the producer to write to and for the consumer to read from.

import torch
import nvshmem.core

import triton.language as tl
import triton_dist.language as dl
from triton_dist.utils import (
    NVSHMEM_SIGNAL_DTYPE,
    dist_print,
    initialize_distributed,
    nvshmem_barrier_all_on_stream,
    nvshmem_free_tensor_sync,
    nvshmem_create_tensor,
)
from triton_dist.language.extra.language_extra import __syncthreads
import triton_dist


@triton_dist.jit
def producer_consumer_kernel(
    rank: tl.constexpr,
    num_ranks: tl.constexpr,
    input_ptr,
    output_ptr,
    num_inputs: int,
    queue_ptr,
    signal_ptr,  # *Pointer* to signals.
    queue_size: tl.constexpr,  # The length of queue in unit of BLOCKs
    BLOCK_SIZE: tl.constexpr,  # The size of each BLOCK
    NUM_PRODUCER_SMS: tl.constexpr,
    NUM_CONSUMER_SMS: tl.constexpr,
):
    """Stream blocks from this rank's input to the next rank's output through a small queue.

    Programs with ``pid < NUM_PRODUCER_SMS`` act as producers, the following
    ``NUM_CONSUMER_SMS`` programs act as consumers, and any remaining programs
    do nothing.

    Block ``i`` uses queue slot ``i % queue_size`` during round
    ``r = i // queue_size``. Each slot has one signal, used as follows:

      * value ``2 * r``     -> the slot is empty for round ``r`` (the producer may write);
      * value ``2 * r + 1`` -> the slot holds block ``i`` (the consumer may read);
      * value ``2 * r + 2`` -> the consumer is done, i.e. the slot is empty for round ``r + 1``.

    Producers wait on, write to, and notify the queue/signals of
    ``peer_rank = (rank + 1) % num_ranks``; consumers wait on, read from, and
    notify the queue/signals of the current rank.

    Args:
        rank: Index of the current rank.
        num_ranks: Total number of ranks.
        input_ptr: Pointer to the local input; block ``i`` is read at offset ``i * BLOCK_SIZE``.
        output_ptr: Pointer to the local output; block ``i`` is written at offset ``i * BLOCK_SIZE``.
        num_inputs: Number of blocks to transfer.
        queue_ptr: Pointer to the queue buffer (mapped to the peer rank with ``symm_at`` by producers).
        signal_ptr: Pointer to the per-slot signals (one signal per queue slot).
        queue_size: Number of block-sized slots in the queue.
        BLOCK_SIZE: Number of elements in each block.
        NUM_PRODUCER_SMS: Number of programs that act as producers.
        NUM_CONSUMER_SMS: Number of programs that act as consumers.
    """
    pid = tl.program_id(0)
    # This kernel issues async-tasks to two groups of blocks
    if pid < NUM_PRODUCER_SMS:
        #
        # Producer
        #
        peer_rank = (rank + 1) % num_ranks  # Peer is the next rank
        offs = tl.arange(0, BLOCK_SIZE)
        # Each producer program handles blocks pid, pid + NUM_PRODUCER_SMS, ...
        for i in range(pid, num_inputs, NUM_PRODUCER_SMS):
            queue_offset = i % queue_size  # Queue slot used by block i
            queue_repeat = i // queue_size  # How many times this slot has been reused so far
            token = dl.wait(
                # Use `symm_at` to map the data pointer to remote peer rank
                dl.symm_at(signal_ptr, peer_rank) + queue_offset,
                1,  # The number of signals to wait
                "sys",  # The scope of the barrier, `gpu` or `sys`
                "acquire",  # The semantic of the wait
                waitValue=queue_repeat * 2,  # The value expected, should conform to certain order
            )  # This wait ensures that the corresponding position is empty
            # Consume the token so that the pointer used by the load below is derived from the `wait`
            input_ptr = dl.consume_token(input_ptr, token)  # consume the token to make sure the `wait` is needed
            data = tl.load(input_ptr + i * BLOCK_SIZE + offs)
            # Use `symm_at` to map the data pointer to remote peer rank
            tl.store(dl.symm_at(queue_ptr, peer_rank) + queue_offset * BLOCK_SIZE + offs, data)
            # need a syncthreads to make sure all the data has been sent
            __syncthreads()
            # `notify` is also single thread scope, we need to use a different thread from `wait`
            dl.notify(signal_ptr + queue_offset, peer_rank,  # Notify the signal object on the peer rank
                    signal=queue_repeat * 2 + 1,  # Write a value to the signal
                    sig_op="set",  # Set the value. Another choice is `add`
                    comm_scope="intra_node",  # This example is intra-node, another choice is `inter_node`
                    )  # This notifies the consumer that the data is ready
    elif pid < NUM_PRODUCER_SMS + NUM_CONSUMER_SMS:
        #
        # Consumer
        #
        pid = pid - NUM_PRODUCER_SMS  # Re-base pid so that consumers are numbered from 0
        offs = tl.arange(0, BLOCK_SIZE)
        # Each consumer program handles blocks pid, pid + NUM_CONSUMER_SMS, ...
        for i in range(pid, num_inputs, NUM_CONSUMER_SMS):
            queue_offset = i % queue_size  # Queue slot used by block i
            queue_repeat = i // queue_size  # How many times this slot has been reused so far
            token = dl.wait(signal_ptr + queue_offset,  # The base *Pointer* of signals at the current rank
                            1,  # The number of signals to wait
                            "sys",  # The scope of the barrier
                            "acquire",  # The semantic of the wait
                            waitValue=queue_repeat * 2 + 1,  # The value expected
                            )  # This wait ensures that the corresponding position is full
            # Consume the token so that the pointer used by the load below is derived from the `wait`
            queue_ptr = dl.consume_token(queue_ptr, token)
            data = tl.load(queue_ptr + queue_offset * BLOCK_SIZE + offs)
            tl.store(output_ptr + i * BLOCK_SIZE + offs, data)
            # Sync the block before notifying, mirroring the producer side
            __syncthreads()
            dl.notify(signal_ptr + queue_offset, rank,  # Notify the signal object on the current rank
                    signal=queue_repeat * 2 + 2,  # Write a value to the signal
                    sig_op="set",  # Set the value. Another choice is `add`
                    comm_scope="intra_node",  # This example is intra-node, another choice is `inter_node`
                    )  # This notifies the producer that the slot is empty again (value expected by its next round)
    else:
        # Programs beyond the producer and consumer groups have no work
        pass

Initialize the Distributed System

Here, we show you how to initialize a distributed system for our example. Triton-distributed provides a convenient initialize_distributed() function that handles:

  • Setting up PyTorch distributed with NCCL backend

  • Initializing NVSHMEM with a unique ID

  • Creating the tensor parallel group

# Simply import and call initialize_distributed(); the returned group is kept as TP_GROUP.
# NOTE(review): the script at the end of this tutorial calls initialize_distributed() again
# before main() -- confirm this snippet is illustrative only and is not executed as well.
from triton_dist.utils import initialize_distributed

TP_GROUP = initialize_distributed()

Test the Correctness

Let’s now check our notify and wait kernel for correctness.

INPUT_SIZE = 2025  # Number of blocks to transfer (passed to the kernel as `num_inputs`); a large input size
QUEUE_SIZE = 32  # Number of block-sized slots in the queue; smaller than INPUT_SIZE, so slots are reused
BLOCK_SIZE = 128  # Number of elements in each block


def main(TP_GROUP):
    """Run the producer/consumer kernel repeatedly and check the result on every rank.

    Allocates the symmetric queue and signal tensors, then launches
    ``producer_consumer_kernel`` 20 times with fresh random input. After each
    launch the inputs of all ranks are gathered and this rank's output is
    compared against the input of the previous rank. The tensors and the
    verdict are printed on the last iteration only, and the symmetric tensors
    are freed at the end.

    Args:
        TP_GROUP: The process group; it provides this rank's index and the
            number of ranks, and is used for the all-gather.
    """
    cur_stream = torch.cuda.current_stream()
    num_elements = INPUT_SIZE * BLOCK_SIZE

    # Symmetric allocations go through nvshmem_create_tensor; the tensors are
    # created on the current cuda device by default.
    queue = nvshmem_create_tensor(
        (QUEUE_SIZE * BLOCK_SIZE, ),  # Shape on each device
        torch.float32,
    )
    # Signal tensors use NVSHMEM_SIGNAL_DTYPE (uint64)
    signal = nvshmem_create_tensor((QUEUE_SIZE, ), NVSHMEM_SIGNAL_DTYPE)
    queue.fill_(-1)
    signal.fill_(0)  # Signals must start from 0
    # A barrier-all makes the initialization above visible to every other
    # rank. This is usually used for intra-node.
    nvshmem_barrier_all_on_stream(cur_stream)

    # Distributed info
    rank = TP_GROUP.rank()
    num_ranks = TP_GROUP.size()
    every_rank = list(range(num_ranks))
    # The producer sends to the next rank, so the data received here comes
    # from the previous rank.
    source_rank = (rank - 1) % num_ranks

    # Local (non-symmetric) torch buffers
    input_data = torch.randn((num_elements, ), dtype=torch.float32).cuda()
    output_data = torch.empty_like(input_data)

    # Launch configuration: 16 producer SMs + 4 consumer SMs = 20 SMs
    num_producer_sms = 16
    num_consumer_sms = 4
    grid = (num_producer_sms + num_consumer_sms, )

    num_repeats = 20
    # A distributed program has to be run several times to be trusted: this
    # exercises signal resetting, races, etc.
    for step in range(num_repeats):
        last_step = step == num_repeats - 1
        input_data = torch.randn((num_elements, ), dtype=torch.float32).cuda()
        # The signals are reset before every launch. This step can be skipped
        # for better performance by using flipping barriers, which a future
        # tutorial will cover.
        # TODO: tutorial for flipping barriers.
        signal.fill_(0)
        nvshmem_barrier_all_on_stream(cur_stream)

        producer_consumer_kernel[grid](
            rank,
            num_ranks,
            input_data,
            output_data,
            INPUT_SIZE,
            queue,
            signal,
            QUEUE_SIZE,
            BLOCK_SIZE,
            num_producer_sms,
            num_consumer_sms,
            num_warps=4,
        )

        # Check results against the input of the previous rank
        gathered_inputs = [torch.empty_like(input_data) for _ in every_rank]
        torch.distributed.all_gather(gathered_inputs, input_data, group=TP_GROUP)
        golden = gathered_inputs[source_rank]
        if last_step:
            dist_print(f"rank{rank}", output_data, need_sync=True, allowed_ranks=every_rank)
            dist_print(f"rank{rank}", golden, need_sync=True, allowed_ranks=every_rank)
        assert torch.allclose(output_data, golden, atol=1e-5, rtol=1e-5)
        if last_step:
            dist_print(f"rank{rank} Passed✅!", need_sync=True, allowed_ranks=every_rank)

    # Release the symmetric memory
    nvshmem_free_tensor_sync(queue)
    nvshmem_free_tensor_sync(signal)


# Initialize the distributed system; the returned group is passed to main()
TP_GROUP = initialize_distributed()
# The main function: runs the notify/wait kernel and checks its output
main(TP_GROUP)
# Finalize: shut down NVSHMEM first, then destroy the torch.distributed process group
nvshmem.core.finalize()
torch.distributed.destroy_process_group()