diff --git a/csrc/backends/attention_backends.hpp b/csrc/backends/attention_backends.hpp index b274aacc..b7319966 100644 --- a/csrc/backends/attention_backends.hpp +++ b/csrc/backends/attention_backends.hpp @@ -13,6 +13,8 @@ enum class AttentionBackend { STATIC_ATTN, PAGED_ATTN, FLASH_ATTN, + FLASH_PREFILL, + FLASH_DECODE, FLASHINFER, Default = STATIC_ATTN }; @@ -25,6 +27,10 @@ inline std::ostream &operator<<(std::ostream &os, AttentionBackend backend) { return os << "AttentionBackend::PAGED_ATTN"; case AttentionBackend::FLASH_ATTN: return os << "AttentionBackend::FLASH_ATTN"; + case AttentionBackend::FLASH_PREFILL: + return os << "AttentionBackend::FLASH_PREFILL"; + case AttentionBackend::FLASH_DECODE: + return os << "AttentionBackend::FLASH_DECODE"; case AttentionBackend::FLASHINFER: return os << "AttentionBackend::FLASHINFER"; default: @@ -46,12 +52,18 @@ inline AttentionBackend parse_attention_backend(const std::string &backend) { if (backend == "flash-attn") { return AttentionBackend::FLASH_ATTN; } + if (backend == "flash-prefill") { + return AttentionBackend::FLASH_PREFILL; + } + if (backend == "flash-decode") { + return AttentionBackend::FLASH_DECODE; + } if (backend == "flashinfer") { return AttentionBackend::FLASHINFER; } throw std::invalid_argument( - "Invalid attention_backend: " + backend + ". Valid options are: static-attn, paged-attn, flash-attn, flashinfer"); + "Invalid attention_backend: " + backend + ". Valid options are: static-attn, paged-attn, flash-attn, flash-prefill, flash-decode, flashinfer"); } } // namespace infinilm::backends diff --git a/csrc/cache/kv_cache.cpp b/csrc/cache/kv_cache.cpp index df0ceb29..05da6541 100644 --- a/csrc/cache/kv_cache.cpp +++ b/csrc/cache/kv_cache.cpp @@ -255,7 +255,9 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache( size_t block_size = config.block_size(); infinicore::Shape kv_shape; - if (global_state::get_infinilm_config().attention_backend == backends::AttentionBackend::FLASH_ATTN) { + if (global_state::get_infinilm_config().attention_backend == backends::AttentionBackend::FLASH_ATTN || + global_state::get_infinilm_config().attention_backend == backends::AttentionBackend::FLASH_PREFILL || + global_state::get_infinilm_config().attention_backend == backends::AttentionBackend::FLASH_DECODE) { // FLASH_ATTN kernel expects BSHD layout kv_shape = {2, num_blocks_per_layer, block_size, num_rank_k_heads, k_dim}; } else { diff --git a/csrc/layers/attention/backends/attention_layer.cpp b/csrc/layers/attention/backends/attention_layer.cpp index fcaefa29..beab9545 100644 --- a/csrc/layers/attention/backends/attention_layer.cpp +++ b/csrc/layers/attention/backends/attention_layer.cpp @@ -20,6 +20,12 @@ AttentionLayer::AttentionLayer(size_t num_heads, case ::infinilm::backends::AttentionBackend::FLASH_ATTN: attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); break; + case ::infinilm::backends::AttentionBackend::FLASH_PREFILL: + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + break; + case ::infinilm::backends::AttentionBackend::FLASH_DECODE: + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + break; default: throw std::runtime_error("infinilm::layers::attention::AttentionLayer: unsupported attention backend"); } diff --git a/csrc/layers/attention/backends/attention_layer.hpp b/csrc/layers/attention/backends/attention_layer.hpp index 87411062..bf8b2276 100644 --- a/csrc/layers/attention/backends/attention_layer.hpp +++ b/csrc/layers/attention/backends/attention_layer.hpp @@ -2,15 +2,17 @@ #include "../../../backends/attention_backends.hpp" #include "../../../global_state/global_state.hpp" -#include "flash_attn.hpp" #include "infinicore/tensor.hpp" +#include "flash_attn.hpp" +#include "flash_decode_attn.hpp" +#include "flash_prefill_attn.hpp" #include "paged_attn.hpp" #include "static_attn.hpp" #include #include namespace infinilm::layers::attention { -using AttentionImpl = std::variant, std::shared_ptr, std::shared_ptr>; +using AttentionImpl = std::variant, std::shared_ptr, std::shared_ptr, std::shared_ptr, std::shared_ptr>; /** * @brief Attention layer. diff --git a/csrc/layers/attention/backends/flash_decode_attn.cpp b/csrc/layers/attention/backends/flash_decode_attn.cpp new file mode 100644 index 00000000..134069d9 --- /dev/null +++ b/csrc/layers/attention/backends/flash_decode_attn.cpp @@ -0,0 +1,94 @@ +#include "flash_decode_attn.hpp" + +#include "../../../utils.hpp" +#include "infinicore/ops.hpp" +#include "infinicore/ops/mha_kvcache.hpp" + +namespace infinilm::layers::attention::backends { + +FlashDecodeAttentionImpl::FlashDecodeAttentionImpl(size_t num_heads, + size_t head_size, + float scale, + size_t num_kv_heads, + size_t layer_idx) + : num_heads_(num_heads), + head_size_(head_size), + scale_(scale), + num_kv_heads_(num_kv_heads), + layer_idx_(layer_idx), + head_dim_(head_size) { + + const infinilm::global_state::InfinilmConfig &infinilm_config = infinilm::global_state::get_infinilm_config(); + if (!infinilm_config.model_config) { + throw std::runtime_error("infinilm::layers::attention::backends::FlashDecodeAttentionImpl: model_config is null"); + } + max_position_embeddings_ = infinilm_config.model_config->get("max_position_embeddings"); +} + +infinicore::Tensor FlashDecodeAttentionImpl::forward(const AttentionLayer &layer, + const infinicore::Tensor &query, + const infinicore::Tensor &key, + const infinicore::Tensor &value, + infinicore::Tensor &kv_cache, + const infinilm::global_state::AttentionMetadata &attn_metadata) const { + auto total_sequence_lengths = attn_metadata.total_sequence_lengths; + auto input_offsets = attn_metadata.input_offsets; + auto block_tables = attn_metadata.block_tables; + auto slot_mapping = attn_metadata.slot_mapping; + auto cu_seqlens = attn_metadata.cu_seqlens; + + ASSERT(block_tables.has_value()); + ASSERT(slot_mapping.has_value()); + + // 1. update paged kv cache + auto [k_total, v_total] = do_kv_cache_update(layer, key, value, kv_cache, slot_mapping.value()); + + size_t seq_len = query->shape()[0]; + bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]); + + // 2. Compute attention + infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device()); + if (is_prefill) { + infinicore::op::paged_attention_prefill_( + attn_output, + query, + k_total->permute({0, 2, 1, 3}), + v_total->permute({0, 2, 1, 3}), + block_tables.value(), + total_sequence_lengths.value(), + input_offsets.value(), + std::nullopt, + scale_); + } else { + auto q_for_fa = query->view({seq_len, 1, num_heads_, head_dim_}); + auto attn_out_4d = infinicore::op::mha_kvcache( + q_for_fa, + k_total, // [num_blocks, block_size, num_kv_heads, head_dim] + v_total, + total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence) + block_tables.value(), // [seq_len, max_num_blocks_per_seq] int32 + std::nullopt, + scale_); + attn_output = attn_out_4d->view({seq_len, num_heads_, head_dim_}); + } + attn_output = attn_output->view({1, seq_len, num_heads_ * head_dim_}); + return attn_output; +} + +std::tuple FlashDecodeAttentionImpl::do_kv_cache_update(const AttentionLayer &layer, + const infinicore::Tensor key, + const infinicore::Tensor value, + infinicore::Tensor &kv_cache, + const infinicore::Tensor slot_mapping) const { + auto k_cache_layer = kv_cache->narrow({{0, 0, 1}})->squeeze(0); + auto v_cache_layer = kv_cache->narrow({{0, 1, 1}})->squeeze(0); + infinicore::op::paged_caching_( + k_cache_layer, + v_cache_layer, + key, + value, + slot_mapping); + + return {k_cache_layer, v_cache_layer}; +} +} // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/flash_decode_attn.hpp b/csrc/layers/attention/backends/flash_decode_attn.hpp new file mode 100644 index 00000000..e8866c76 --- /dev/null +++ b/csrc/layers/attention/backends/flash_decode_attn.hpp @@ -0,0 +1,43 @@ +#pragma once + +#include "../../../global_state/global_state.hpp" +#include "infinicore/tensor.hpp" +#include + +namespace infinilm::layers::attention { +class AttentionLayer; +} + +namespace infinilm::layers::attention::backends { + +class FlashDecodeAttentionImpl { +public: + FlashDecodeAttentionImpl(size_t num_heads, + size_t head_size, + float scale, + size_t num_kv_heads, + size_t layer_idx); + + infinicore::Tensor forward(const AttentionLayer &layer, + const infinicore::Tensor &query, + const infinicore::Tensor &key, + const infinicore::Tensor &value, + infinicore::Tensor &kv_cache, + const infinilm::global_state::AttentionMetadata &attn_metadata) const; + + std::tuple do_kv_cache_update(const AttentionLayer &layer, + const infinicore::Tensor key, + const infinicore::Tensor value, + infinicore::Tensor &kv_cache, + const infinicore::Tensor slot_mapping) const; + +private: + size_t num_heads_; + size_t head_size_; + float scale_; + size_t num_kv_heads_; + size_t layer_idx_; + size_t head_dim_; // Note: head_dim equals to head_size + size_t max_position_embeddings_; +}; +} // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/flash_prefill_attn.cpp b/csrc/layers/attention/backends/flash_prefill_attn.cpp new file mode 100644 index 00000000..953c5ca2 --- /dev/null +++ b/csrc/layers/attention/backends/flash_prefill_attn.cpp @@ -0,0 +1,95 @@ +#include "flash_prefill_attn.hpp" + +#include "../../../utils.hpp" +#include "infinicore/ops.hpp" +#include "infinicore/ops/mha_varlen.hpp" + +namespace infinilm::layers::attention::backends { + +FlashPrefillAttentionImpl::FlashPrefillAttentionImpl(size_t num_heads, + size_t head_size, + float scale, + size_t num_kv_heads, + size_t layer_idx) + : num_heads_(num_heads), + head_size_(head_size), + scale_(scale), + num_kv_heads_(num_kv_heads), + layer_idx_(layer_idx), + head_dim_(head_size) { + + const infinilm::global_state::InfinilmConfig &infinilm_config = infinilm::global_state::get_infinilm_config(); + if (!infinilm_config.model_config) { + throw std::runtime_error("infinilm::layers::attention::backends::FlashPrefillAttentionImpl: model_config is null"); + } + max_position_embeddings_ = infinilm_config.model_config->get("max_position_embeddings"); +} + +infinicore::Tensor FlashPrefillAttentionImpl::forward(const AttentionLayer &layer, + const infinicore::Tensor &query, + const infinicore::Tensor &key, + const infinicore::Tensor &value, + infinicore::Tensor &kv_cache, + const infinilm::global_state::AttentionMetadata &attn_metadata) const { + auto total_sequence_lengths = attn_metadata.total_sequence_lengths; + auto input_offsets = attn_metadata.input_offsets; + auto block_tables = attn_metadata.block_tables; + auto slot_mapping = attn_metadata.slot_mapping; + auto cu_seqlens = attn_metadata.cu_seqlens; + + ASSERT(block_tables.has_value()); + ASSERT(slot_mapping.has_value()); + + // 1. update paged kv cache + auto [k_total, v_total] = do_kv_cache_update(layer, key, value, kv_cache, slot_mapping.value()); + + size_t seq_len = query->shape()[0]; + bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]); + + // 2. Compute attention + infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device()); + if (is_prefill) { + infinicore::op::mha_varlen_( + attn_output, + query, + k_total, + v_total, + input_offsets.value(), + cu_seqlens.value(), + block_tables.value(), + max_position_embeddings_, + max_position_embeddings_, + std::nullopt, + scale_); + } else { + infinicore::op::paged_attention_( + attn_output, + query, + k_total->permute({0, 2, 1, 3}), + v_total->permute({0, 2, 1, 3}), + block_tables.value(), + total_sequence_lengths.value(), + std::nullopt, + scale_); + } + attn_output = attn_output->view({1, seq_len, num_heads_ * head_dim_}); + return attn_output; +} + +std::tuple FlashPrefillAttentionImpl::do_kv_cache_update(const AttentionLayer &layer, + const infinicore::Tensor key, + const infinicore::Tensor value, + infinicore::Tensor &kv_cache, + const infinicore::Tensor slot_mapping) const { + auto k_cache_layer = kv_cache->narrow({{0, 0, 1}})->squeeze(0); + auto v_cache_layer = kv_cache->narrow({{0, 1, 1}})->squeeze(0); + infinicore::op::paged_caching_( + k_cache_layer, + v_cache_layer, + key, + value, + slot_mapping); + + return {k_cache_layer, v_cache_layer}; +} +} // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/flash_prefill_attn.hpp b/csrc/layers/attention/backends/flash_prefill_attn.hpp new file mode 100644 index 00000000..96c4610e --- /dev/null +++ b/csrc/layers/attention/backends/flash_prefill_attn.hpp @@ -0,0 +1,43 @@ +#pragma once + +#include "../../../global_state/global_state.hpp" +#include "infinicore/tensor.hpp" +#include + +namespace infinilm::layers::attention { +class AttentionLayer; +} + +namespace infinilm::layers::attention::backends { + +class FlashPrefillAttentionImpl { +public: + FlashPrefillAttentionImpl(size_t num_heads, + size_t head_size, + float scale, + size_t num_kv_heads, + size_t layer_idx); + + infinicore::Tensor forward(const AttentionLayer &layer, + const infinicore::Tensor &query, + const infinicore::Tensor &key, + const infinicore::Tensor &value, + infinicore::Tensor &kv_cache, + const infinilm::global_state::AttentionMetadata &attn_metadata) const; + + std::tuple do_kv_cache_update(const AttentionLayer &layer, + const infinicore::Tensor key, + const infinicore::Tensor value, + infinicore::Tensor &kv_cache, + const infinicore::Tensor slot_mapping) const; + +private: + size_t num_heads_; + size_t head_size_; + float scale_; + size_t num_kv_heads_; + size_t layer_idx_; + size_t head_dim_; // Note: head_dim equals to head_size + size_t max_position_embeddings_; +}; +} // namespace infinilm::layers::attention::backends diff --git a/csrc/models/infinilm_model.cpp b/csrc/models/infinilm_model.cpp index 3923474e..c14e6860 100644 --- a/csrc/models/infinilm_model.cpp +++ b/csrc/models/infinilm_model.cpp @@ -62,6 +62,27 @@ std::vector InfinilmModel::default_allocate_kv_cache_tensors case backends::AttentionBackend::FLASH_ATTN: { ; } + case backends::AttentionBackend::FLASH_PREFILL: + case backends::AttentionBackend::FLASH_DECODE: { + auto paged_kv_cache_config = dynamic_cast(cache_config); + if (nullptr == paged_kv_cache_config) { + throw std::runtime_error( + "infinilm::InfinilmModel::default_allocate_kv_cache_tensors: invalid paged kv cache config type"); + } + kv_cache_vec.reserve(num_hidden_layers); + + for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) { + auto kv_cache = cache::PagedKVCache::create_layer_kv_cache( + head_dim, + head_dim, + num_key_value_heads, + num_key_value_heads, + dtype, + *paged_kv_cache_config); + kv_cache_vec.push_back(kv_cache); + } + break; + } case backends::AttentionBackend::PAGED_ATTN: { auto paged_kv_cache_config = dynamic_cast(cache_config); if (nullptr == paged_kv_cache_config) { diff --git a/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp b/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp index a95abab8..21f1a873 100644 --- a/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp +++ b/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp @@ -57,6 +57,39 @@ std::vector qwen3_next_allocate_kv_cache_tensors( case backends::AttentionBackend::FLASH_ATTN: { ; } + case backends::AttentionBackend::FLASH_PREFILL: + case backends::AttentionBackend::FLASH_DECODE: { + auto paged_kv_cache_config = dynamic_cast(cache_config); + if (nullptr == paged_kv_cache_config) { + throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: invalid paged kv cache config type"); + } + const size_t num_hidden_layers = text_config->get("num_hidden_layers"); + kv_cache_vec.reserve(num_hidden_layers); + + const size_t head_dim = text_config->get("head_dim"); + const size_t num_key_value_heads = text_config->get("num_key_value_heads"); + const auto &dtype{text_config->get_dtype()}; + const std::vector layer_types = text_config->get>("layer_types"); + + for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) { + const std::string &layer_type = layer_types[layer_idx]; + if ("linear_attention" == layer_type) { + kv_cache_vec.emplace_back(); + } else if ("full_attention" == layer_type) { + auto kv_cache = cache::PagedKVCache::create_layer_kv_cache( + head_dim, + head_dim, + num_key_value_heads, + num_key_value_heads, + dtype, + *paged_kv_cache_config); + kv_cache_vec.push_back(kv_cache); + } else { + throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: unsupported layer_type '" + layer_type + "' for layer " + std::to_string(layer_idx)); + } + } + break; + } case backends::AttentionBackend::PAGED_ATTN: { auto paged_kv_cache_config = dynamic_cast(cache_config); if (nullptr == paged_kv_cache_config) { diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index aab5dd45..3f8e0749 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -119,7 +119,7 @@ def _add_common_args(self): "--attn", type=str, default="default", - choices=["default", "paged-attn", "flash-attn"], + choices=["default", "paged-attn", "flash-attn", "flash-prefill", "flash-decode"], ) self.parser.add_argument("--enable-graph", action="store_true") self.parser.add_argument( diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index cba3af83..8a96a1b4 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -113,7 +113,7 @@ def __init__(self, config: EngineConfig): logger.info( f"Using Static KV Cache with max_cache_len={config.max_cache_len}" ) - elif config.cache_type == "paged": + elif config.cache_type in ["paged", "flash-prefill", "flash-decode"]: cache_config = PagedKVCacheConfig( num_blocks=config.num_blocks, block_size=config.block_size )