Flash Decode
Distributed flash-decoding kernels for batched grouped-query attention (GQA) decode.
API Reference
- gqa_fwd_batch_decode(...)
GQA forward batch decode kernel.
- gqa_fwd_batch_decode_persistent(...)
Persistent version of GQA forward batch decode.
- gqa_fwd_batch_decode_aot(...)
AOT-compiled GQA forward batch decode.
- gqa_fwd_batch_decode_persistent_aot(...)
AOT-compiled persistent GQA forward batch decode.
- gqa_fwd_batch_decode_intra_rank(...)
Intra-rank GQA forward batch decode.
- gqa_fwd_batch_decode_intra_rank_aot(...)
AOT-compiled intra-rank GQA forward batch decode.
- kernel_gqa_fwd_batch_decode_split_kv_persistent(...)
Persistent kernel for split KV flash decode.
- kernel_inter_rank_gqa_fwd_batch_decode_combine_kv(...)
Inter-rank kernel for combining KV results.
- get_triton_combine_kv_algo_info(...)
Gets algorithm info for KV combination.
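The split-KV and combine entries above follow the usual flash-decoding pattern: each KV chunk yields a partial output plus the log-sum-exp of its scores, and a combine step rescales and merges the partials. The following is a minimal pure-Python sketch of that arithmetic for a single query head. It is not the library's implementation or API: `attend_chunk`, `combine_partials`, and `split_kv_decode` are hypothetical names chosen for illustration, and the actual kernels may differ in layout, batching, and how ranks exchange partials.

```python
import math

# Illustrative sketch only: these helper names are hypothetical and are not
# part of the flash decode API listed above.

def attend_chunk(q, keys, values):
    """Attention of one query vector over one KV chunk.

    Returns (out, lse): the chunk-local softmax-weighted average of the
    values, and the log-sum-exp of the chunk's scaled scores.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max before exp for numerical stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom
           for d in range(dim)]
    return out, m + math.log(denom)

def combine_partials(partials):
    """Merge per-chunk (out, lse) pairs into the full-sequence output."""
    global_max = max(lse for _, lse in partials)
    # Each chunk is weighted by its share of the total softmax mass.
    coeffs = [math.exp(lse - global_max) for _, lse in partials]
    total = sum(coeffs)
    dim = len(partials[0][0])
    return [sum(c * out[d] for c, (out, _) in zip(coeffs, partials)) / total
            for d in range(dim)]

def split_kv_decode(q, keys, values, num_splits):
    """Split the KV sequence into chunks, attend per chunk, then combine."""
    chunk = math.ceil(len(keys) / num_splits)
    partials = [attend_chunk(q, keys[i:i + chunk], values[i:i + chunk])
                for i in range(0, len(keys), chunk)]
    return combine_partials(partials)

# Usage: the split result matches attention over the unsplit sequence.
q = [0.1, -0.2, 0.3, 0.4]
keys = [[0.01 * (i + j) for j in range(4)] for i in range(8)]
values = [[float(i), float(-i)] for i in range(8)]
reference, _ = attend_chunk(q, keys, values)
split = split_kv_decode(q, keys, values, num_splits=4)
assert all(abs(a - b) < 1e-9 for a, b in zip(reference, split))
```

Carrying the log-sum-exp with each partial output is what makes the merge exact: the combine step needs only one output vector and one scalar per chunk, not the raw scores. In a distributed setting this keeps the data exchanged per chunk small.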
Performance
Flash decode scales from 1 to 32 GPUs with minimal added latency.