Tensor Parallel Attention
Tensor Parallel Attention layer for distributed inference.
Description
The TP Attention layer provides distributed attention computation with AllGather/ReduceScatter overlapping for efficient tensor parallelism.
Modes
ag_rs: AllGather input + ReduceScatter outputallreduce: AllReduce-based parallelismgemm_ar: GEMM + AllReduce fusion
Example Usage
# Prefill mode
bash scripts/launch.sh python/triton_dist/test/nvidia/test_tp_attn.py \
--bsz 32 --seq_len 128 --model <model_path> --run_type prefill --mode ag_rs
# Decode mode
bash scripts/launch.sh python/triton_dist/test/nvidia/test_tp_attn.py \
--bsz 128 --seq_len 128 --model <model_path> --run_type decode --mode ag_rs