End-to-End Model Level Demo

This document presents an end-to-end (E2E) integration demo for Triton-Distributed. It shows how to integrate Triton-Distributed’s high-performance distributed kernels into a complete LLM, using Qwen3-32B as the reference example. The demo covers the tensor-parallel implementation and performance testing, from individual layers (Attention, MLP) up to the entire model.

[Figure: e2e_qwen_32b.png]

Features

  • Two Strategies for Tensor Parallelism (TP):

    • AllGather-GEMM + GEMM-ReduceScatter: Uses the AllGather-GEMM and GEMM-ReduceScatter kernels. The input is sharded along the batch dimension, and communication is heavily overlapped with computation.

    • GEMM-AllReduce: Uses the GEMM-AllReduce kernel. The input is replicated across all devices.

  • Layer-wise Module Implementation: Provides TP_Attn and TP_MLP modules that can easily replace corresponding layers in existing models to enable distributed parallelism.

  • Full Model Integration: Demonstrates how to seamlessly integrate the parallel modules into a dense model, using Qwen3-32B as an example. We also include a complete inference Engine with CUDA Graph integration.
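The two strategies listed above can be checked against each other in a small single-process simulation. The sketch below is not Triton-Distributed code: the "ranks" are list entries, the collectives are plain Python helpers, the sizes (`WORLD`, `M`, `H`, `F`) are toy values chosen for illustration, and the activation function between the two GEMMs is omitted. It only shows that both data layouts produce the same result as the unsharded computation.

```python
# Single-process simulation of the two TP strategies for one MLP-like block.
# All names and sizes here are hypothetical; no real communication happens.
import random

WORLD = 4  # number of simulated ranks


def matmul(a, b):
    """Multiply two row-major matrices stored as lists of lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def add(a, b):
    """Element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def split_rows(m, n):
    step = len(m) // n
    return [m[i * step:(i + 1) * step] for i in range(n)]


def split_cols(m, n):
    step = len(m[0]) // n
    return [[row[i * step:(i + 1) * step] for row in m] for i in range(n)]


def rand(rows, cols, rng):
    # Integer entries keep the comparison with the reference exact.
    return [[rng.randint(-3, 3) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
M, H, F = 8, 4, 8            # tokens, hidden size, intermediate size
x = rand(M, H, rng)
w_up = rand(H, F, rng)       # column-parallel: each rank owns F // WORLD columns
w_down = rand(F, H, rng)     # row-parallel: each rank owns F // WORLD rows
reference = matmul(matmul(x, w_up), w_down)

up_shards = split_cols(w_up, WORLD)
down_shards = split_rows(w_down, WORLD)


def partial_outputs(x_full):
    """Per-rank partial results of shape (M, H); their sum is the full output."""
    return [matmul(matmul(x_full, up_shards[r]), down_shards[r]) for r in range(WORLD)]


def all_reduce(parts):
    total = parts[0]
    for p in parts[1:]:
        total = add(total, p)
    return total


# Strategy 1: input sharded along M. AllGather before the first GEMM,
# ReduceScatter after the second one, so each rank ends with M // WORLD rows.
x_shards = split_rows(x, WORLD)
gathered = [row for shard in x_shards for row in shard]      # AllGather
out_ag_rs = split_rows(all_reduce(partial_outputs(gathered)), WORLD)  # ReduceScatter

# Strategy 2: input replicated. A single AllReduce after the second GEMM
# leaves the full (M, H) output on every rank.
out_gemm_ar = all_reduce(partial_outputs(x))

print(out_gemm_ar == reference)
print([row for shard in out_ag_rs for row in shard] == reference)
```

In the real kernels the communication is overlapped with the GEMMs rather than run as separate steps; the simulation only illustrates where the data lives before and after each collective.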

Performance on 8xH800: Large tensor shapes are best served by pipelined AllGather-GEMM + GEMM-ReduceScatter, which overlaps computation with communication, while smaller shapes run more efficiently with GEMM-AllReduce.

  • AllGather-GEMM + GEMM-ReduceScatter

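| Test Case         | Parameters        | Torch AR (ms) | Triton-Distributed (ms) | Speedup |
|-------------------|-------------------|---------------|-------------------------|---------|
| MLP               | M=2048            | 1.076972      | 0.8854406               | 1.216   |
| Attn Prefill      | bsz=32, ctx=128   | 0.71913       | 0.748670                | 0.961*  |
| Attn Decode       | bsz=4096, ctx=128 | 1.29802       | 1.31813                 | 0.985*  |
| E2E Model Prefill | bsz=32, ctx=128   | 123.3569      | 104.2794                | 1.183   |
| E2E Model Decode  | bsz=4096, ctx=128 | 160.1424      | 140.393                 | 1.141   |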

*Entries marked with an asterisk are slower than the PyTorch baseline (speedup below 1). The weight tensors in the Attention computation are small, and for small tensors the extra overhead of splitting communication into separate AllGather and ReduceScatter operations outweighs the gain from overlapping communication with computation. As a result, performance falls below that of PyTorch’s single AllReduce operation.
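The Speedup column is the ratio of the Torch AR time to the Triton-Distributed time, so values below 1 mean the baseline is faster. A minimal sketch of that arithmetic, using the timings from the AllGather-GEMM + GEMM-ReduceScatter table above (the `speedup` helper is a name introduced here for illustration):

```python
# (test case, Torch AR ms, Triton-Distributed ms), copied from the table above.
rows = [
    ("MLP", 1.076972, 0.8854406),
    ("Attn Prefill", 0.71913, 0.748670),
    ("Attn Decode", 1.29802, 1.31813),
    ("E2E Model Prefill", 123.3569, 104.2794),
    ("E2E Model Decode", 160.1424, 140.393),
]


def speedup(baseline_ms, candidate_ms):
    """Ratio of baseline time to candidate time; > 1 means the candidate is faster."""
    return baseline_ms / candidate_ms


for name, torch_ms, triton_ms in rows:
    s = speedup(torch_ms, triton_ms)
    marker = "*" if s < 1.0 else ""   # same asterisk convention as the table
    print(f"{name:<18} {s:.3f}{marker}")
```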

  • GEMM-AllReduce

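| Test Case         | Parameters       | Torch AR (ms) | Triton-Distributed AR (ms) | Speedup |
|-------------------|------------------|---------------|----------------------------|---------|
| MLP               | M=2048           | 0.6012        | 0.4756                     | 1.26x   |
| Attn Prefill      | bsz=1, ctx=128   | 0.1292        | 0.0900                     | 1.44x   |
| Attn Decode       | bsz=128, ctx=128 | 0.1435        | 0.1036                     | 1.39x   |
| E2E Model Prefill | bsz=1, ctx=128   | 15.78         | 11.70                      | 1.35x   |
| E2E Model Decode  | bsz=128, ctx=128 | 16.54         | 12.41                      | 1.33x   |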

Performance on 8xMI308X:

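| Test Case    | Parameters        | Torch AR (ms) | Triton-Distributed (ms) | Speedup |
|--------------|-------------------|---------------|-------------------------|---------|
| AG_GEMM      | M=4096            | 1.8047        | 1.8002                  | 1.0025x |
| GEMM_RS      | M=4096            | 1.057         | 0.837                   | 1.2627x |
| MLP          | M=4096            | 3.019         | 2.829                   | 1.067x  |
| Attn Prefill | bsz=32, ctx=128   | 1.555         | 1.50833                 | 1.0312x |
| Attn Decode  | bsz=4096, ctx=128 | 3.3783        | 3.2765                  | 1.0310x |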

Environment Setup

First, run the following script to install the necessary dependencies and configure your environment variables.

# Build the environment and install dependencies
bash ./scripts/build_e2e_env.sh

Running the Demos

We provide a set of test scripts for various use cases.

1. Layer-Level Benchmarks

These scripts are used to benchmark the performance of the TP_Attn and TP_MLP layers in isolation.

MLP Layer (test_tp_mlp.py)

AG_GEMM + GEMM_RS Mode: This command benchmarks the performance of ag_gemm + gemm_rs. The input tensor x’s M dimension (batch_size * seq_len) is sharded across GPUs.

bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_tp_mlp.py --M 4096 --model Qwen/Qwen3-32B --mode ag_rs

AllReduce Mode: Use the --mode gemm_ar flag to switch to the GEMM-AllReduce paradigm. In this mode, the input is replicated on all GPUs.

NVSHMEM_DISABLE_CUDA_VMM=0 bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_tp_mlp.py --M 2048 --model Qwen/Qwen3-32B --mode gemm_ar
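To make the difference between the two input layouts concrete, the sketch below computes how many rows of the input tensor x each GPU holds in each mode. The helper `rows_per_rank` is hypothetical and not part of the test scripts; the requirement that M divide evenly across GPUs in `ag_rs` mode is an assumption made for this example.

```python
def rows_per_rank(m, world_size, mode):
    """Rows of the (M, hidden) input held by each GPU for a given --mode value."""
    if mode == "ag_rs":
        # Input is sharded along M (batch_size * seq_len).
        if m % world_size != 0:
            raise ValueError("assumed here: M must be divisible by the number of GPUs")
        return m // world_size
    if mode == "gemm_ar":
        # Input is replicated on every GPU.
        return m
    raise ValueError(f"unknown mode: {mode}")


print(rows_per_rank(4096, 8, "ag_rs"))    # rows per GPU when sharded
print(rows_per_rank(2048, 8, "gemm_ar"))  # rows per GPU when replicated
```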

Attention Layer (test_tp_attn.py)

The Attention layer benchmark is divided into prefill and decode modes.

AG_GEMM + GEMM_RS Mode:

# prefill
bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_tp_attn.py --bsz 32 --seq_len 128 --model Qwen/Qwen3-32B --run_type prefill --mode ag_rs

# decode
bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_tp_attn.py --bsz 4096 --seq_len 128 --model Qwen/Qwen3-32B --run_type decode --mode ag_rs

GEMM-AllReduce Mode:

# prefill
NVSHMEM_DISABLE_CUDA_VMM=0 bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_tp_attn.py --bsz 8 --seq_len 128 --model Qwen/Qwen3-32B --run_type prefill --mode gemm_ar

# decode
NVSHMEM_DISABLE_CUDA_VMM=0 bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_tp_attn.py --bsz 128 --seq_len 128 --model Qwen/Qwen3-32B --run_type decode --mode gemm_ar

2. Model-Level End-to-End Tests (test_tp_e2e.py)

This script tests a single forward pass of the complete Qwen3 model, which can be used for correctness validation or performance evaluation.

Correctness Check (--check): This mode compares the output of the Triton-Distributed implementation against the native PyTorch eager mode implementation to ensure numerical consistency.

# AG_GEMM + GEMM_RS Mode
bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_tp_e2e.py --bsz 8 --seq_len 256 --model Qwen/Qwen3-32B --check --mode ag_rs

# GEMM-AllReduce Mode
NVSHMEM_DISABLE_CUDA_VMM=0 bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_tp_e2e.py --bsz 8 --seq_len 128 --model Qwen/Qwen3-32B --check --mode gemm_ar
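The kind of numerical comparison that a correctness check like --check performs can be sketched as an element-wise tolerance test. This document does not state the criteria the script actually uses, so the function name `allclose` and the tolerances below are assumptions for illustration only.

```python
def allclose(actual, expected, rtol=1e-2, atol=1e-2):
    """Return True if |a - e| <= atol + rtol * |e| for every element pair.

    The tolerances are hypothetical defaults, loose enough for
    low-precision outputs; they are not taken from test_tp_e2e.py.
    """
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


# Flattened outputs from the two implementations (made-up values).
eager_out = [0.1250, -1.5000, 3.0000]
dist_out = [0.1251, -1.4990, 3.0020]
print(allclose(dist_out, eager_out))
```

A combined relative and absolute tolerance is a common choice here because distributed reductions may sum partial results in a different order than the eager implementation, which changes rounding slightly.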

Performance Benchmark:

# AG_GEMM + GEMM_RS Mode
# Prefill
bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_tp_e2e.py --bsz 32 --seq_len 128 --model Qwen/Qwen3-32B --mode ag_rs --run_type prefill

# Decode
bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_tp_e2e.py --bsz 4096 --seq_len 128 --model Qwen/Qwen3-32B --mode ag_rs --run_type decode

# GEMM-AllReduce Mode
# Prefill
NVSHMEM_DISABLE_CUDA_VMM=0 bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_tp_e2e.py --bsz 1 --seq_len 128 --model Qwen/Qwen3-32B --mode gemm_ar --run_type prefill

# Decode
NVSHMEM_DISABLE_CUDA_VMM=0 bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_tp_e2e.py --bsz 128 --seq_len 128 --model Qwen/Qwen3-32B --mode gemm_ar --run_type decode
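When sweeping several configurations, the benchmark commands above can be generated rather than typed by hand. The sketch below only assembles command strings from the flags shown in this section; the helper `e2e_command` is a name introduced here, and prefixing NVSHMEM_DISABLE_CUDA_VMM=0 only for gemm_ar simply mirrors the commands above.

```python
import shlex

LAUNCH = "./scripts/launch.sh"
E2E_TEST = "./python/triton_dist/test/nvidia/test_tp_e2e.py"


def e2e_command(mode, run_type, bsz, seq_len=128, model="Qwen/Qwen3-32B"):
    """Build the shell command for one test_tp_e2e.py benchmark run."""
    args = [
        "bash", LAUNCH, E2E_TEST,
        "--bsz", str(bsz),
        "--seq_len", str(seq_len),
        "--model", model,
        "--mode", mode,
        "--run_type", run_type,
    ]
    cmd = shlex.join(args)
    if mode == "gemm_ar":
        # The GEMM-AllReduce commands in this document set this variable.
        cmd = "NVSHMEM_DISABLE_CUDA_VMM=0 " + cmd
    return cmd


# The four benchmark configurations listed above.
for mode, run_type, bsz in [
    ("ag_rs", "prefill", 32),
    ("ag_rs", "decode", 4096),
    ("gemm_ar", "prefill", 1),
    ("gemm_ar", "decode", 128),
]:
    print(e2e_command(mode, run_type, bsz))
```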

3. Full Inference Pipeline (test_e2e_inference.py)

This script runs a complete generation task (including one prefill step and multiple decode steps) using the Engine class. It measures end-to-end throughput and latency.

# Baseline PyTorch Eager Mode
bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_e2e_inference.py --bsz 4096 --gen_len 128 --max_length 150 --backend torch

bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_e2e_inference.py --bsz 128 --gen_len 128 --max_length 150 --backend torch

# Triton-Distributed AG_GEMM + GEMM_RS Mode
bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_e2e_inference.py --bsz 4096 --gen_len 128 --max_length 150 --backend triton_dist

# Triton-Distributed GEMM-AllReduce Mode
NVSHMEM_DISABLE_CUDA_VMM=0 bash ./scripts/launch.sh ./python/triton_dist/test/nvidia/test_e2e_inference.py --bsz 128 --gen_len 128 --max_length 150 --backend triton_dist_gemm_ar
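For reference, end-to-end throughput and latency for a generation run can be derived from the batch size, the number of generated tokens, and the wall-clock time. The sketch below is one plausible definition, not necessarily the one test_e2e_inference.py reports; the function name, the metric names, and the elapsed time are hypothetical.

```python
def generation_metrics(bsz, gen_len, elapsed_s):
    """Derive throughput and average per-step latency for one generation run.

    Assumes every sequence in the batch generates exactly gen_len tokens and
    that elapsed_s covers the whole run (prefill plus all decode steps).
    """
    total_tokens = bsz * gen_len
    return {
        "tokens_per_s": total_tokens / elapsed_s,
        "ms_per_step": elapsed_s * 1000.0 / gen_len,
    }


# Made-up elapsed time for a run with --bsz 128 --gen_len 128.
metrics = generation_metrics(bsz=128, gen_len=128, elapsed_s=2.0)
print(metrics["tokens_per_s"], metrics["ms_per_step"])
```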