Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion csrc/backends/attention_backends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ enum class AttentionBackend {
STATIC_ATTN,
PAGED_ATTN,
FLASH_ATTN,
FLASH_PREFILL,
FLASH_DECODE,
FLASHINFER,
Default = STATIC_ATTN
};
Expand All @@ -25,6 +27,10 @@ inline std::ostream &operator<<(std::ostream &os, AttentionBackend backend) {
return os << "AttentionBackend::PAGED_ATTN";
case AttentionBackend::FLASH_ATTN:
return os << "AttentionBackend::FLASH_ATTN";
case AttentionBackend::FLASH_PREFILL:
return os << "AttentionBackend::FLASH_PREFILL";
case AttentionBackend::FLASH_DECODE:
return os << "AttentionBackend::FLASH_DECODE";
case AttentionBackend::FLASHINFER:
return os << "AttentionBackend::FLASHINFER";
default:
Expand All @@ -46,12 +52,18 @@ inline AttentionBackend parse_attention_backend(const std::string &backend) {
if (backend == "flash-attn") {
return AttentionBackend::FLASH_ATTN;
}
if (backend == "flash-prefill") {
return AttentionBackend::FLASH_PREFILL;
}
if (backend == "flash-decode") {
return AttentionBackend::FLASH_DECODE;
}
if (backend == "flashinfer") {
return AttentionBackend::FLASHINFER;
}

throw std::invalid_argument(
"Invalid attention_backend: " + backend + ". Valid options are: static-attn, paged-attn, flash-attn, flashinfer");
"Invalid attention_backend: " + backend + ". Valid options are: static-attn, paged-attn, flash-attn, flash-prefill, flash-decode, flashinfer");
}

} // namespace infinilm::backends
4 changes: 3 additions & 1 deletion csrc/cache/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ infinicore::Tensor PagedKVCache::create_layer_kv_cache(
size_t block_size = config.block_size();

infinicore::Shape kv_shape;
if (global_state::get_infinilm_config().attention_backend == backends::AttentionBackend::FLASH_ATTN) {
if (global_state::get_infinilm_config().attention_backend == backends::AttentionBackend::FLASH_ATTN ||
global_state::get_infinilm_config().attention_backend == backends::AttentionBackend::FLASH_PREFILL ||
global_state::get_infinilm_config().attention_backend == backends::AttentionBackend::FLASH_DECODE) {
// FLASH_ATTN, FLASH_PREFILL and FLASH_DECODE all allocate the cache in the BSHD layout the flash kernels expect
kv_shape = {2, num_blocks_per_layer, block_size, num_rank_k_heads, k_dim};
} else {
Expand Down
6 changes: 6 additions & 0 deletions csrc/layers/attention/backends/attention_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ AttentionLayer::AttentionLayer(size_t num_heads,
case ::infinilm::backends::AttentionBackend::FLASH_ATTN:
attn_backend_impl_ = std::make_shared<backends::FlashAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
break;
case ::infinilm::backends::AttentionBackend::FLASH_PREFILL:
attn_backend_impl_ = std::make_shared<backends::FlashPrefillAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
break;
case ::infinilm::backends::AttentionBackend::FLASH_DECODE:
attn_backend_impl_ = std::make_shared<backends::FlashDecodeAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
break;
default:
throw std::runtime_error("infinilm::layers::attention::AttentionLayer: unsupported attention backend");
}
Expand Down
6 changes: 4 additions & 2 deletions csrc/layers/attention/backends/attention_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@

#include "../../../backends/attention_backends.hpp"
#include "../../../global_state/global_state.hpp"
#include "flash_attn.hpp"
#include "infinicore/tensor.hpp"
#include "flash_attn.hpp"
#include "flash_decode_attn.hpp"
#include "flash_prefill_attn.hpp"
#include "paged_attn.hpp"
#include "static_attn.hpp"
#include <memory>
#include <variant>

namespace infinilm::layers::attention {
using AttentionImpl = std::variant<std::shared_ptr<backends::StaticAttentionImpl>, std::shared_ptr<backends::PagedAttentionImpl>, std::shared_ptr<backends::FlashAttentionImpl>>;
using AttentionImpl = std::variant<std::shared_ptr<backends::StaticAttentionImpl>, std::shared_ptr<backends::PagedAttentionImpl>, std::shared_ptr<backends::FlashAttentionImpl>, std::shared_ptr<backends::FlashPrefillAttentionImpl>, std::shared_ptr<backends::FlashDecodeAttentionImpl>>;

/**
* @brief Attention layer.
Expand Down
94 changes: 94 additions & 0 deletions csrc/layers/attention/backends/flash_decode_attn.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#include "flash_decode_attn.hpp"

#include "../../../utils.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/mha_kvcache.hpp"

namespace infinilm::layers::attention::backends {

/**
 * @brief Construct the FLASH_DECODE attention backend.
 *
 * Stores the head counts, head size, scale and layer index, and caches the
 * "max_position_embeddings" value read from the global model configuration.
 *
 * @param num_heads     Number of query heads.
 * @param head_size     Size of each head; also stored as head_dim_.
 * @param scale         Scale value forwarded to the attention ops.
 * @param num_kv_heads  Number of key/value heads.
 * @param layer_idx     Index of the layer this backend belongs to.
 * @throws std::runtime_error if the global config has no model_config.
 */
FlashDecodeAttentionImpl::FlashDecodeAttentionImpl(size_t num_heads,
                                                   size_t head_size,
                                                   float scale,
                                                   size_t num_kv_heads,
                                                   size_t layer_idx)
    : num_heads_(num_heads),
      head_size_(head_size),
      scale_(scale),
      num_kv_heads_(num_kv_heads),
      layer_idx_(layer_idx),
      head_dim_(head_size) {
    const auto &config = infinilm::global_state::get_infinilm_config();
    const auto &model_config = config.model_config;

    // Without a model config there is nowhere to read the position limit from.
    if (!model_config) {
        throw std::runtime_error("infinilm::layers::attention::backends::FlashDecodeAttentionImpl: model_config is null");
    }

    max_position_embeddings_ = model_config->get<size_t>("max_position_embeddings");
}

/**
 * @brief Run attention for the FLASH_DECODE backend.
 *
 * First writes the new key/value rows into the paged cache, then computes
 * attention with one of two ops:
 *  - prefill: infinicore::op::paged_attention_prefill_ writing into a freshly
 *    allocated output buffer;
 *  - decode:  infinicore::op::mha_kvcache, which returns its own output.
 *
 * The step is treated as prefill when the first dimension of @p query differs
 * from the number of entries in attn_metadata.total_sequence_lengths.
 *
 * @param layer          Owning attention layer (only forwarded to do_kv_cache_update).
 * @param query          Query tensor; its first dimension is used as seq_len.
 * @param key            New key rows to store in the cache.
 * @param value          New value rows to store in the cache.
 * @param kv_cache       Paged cache tensor for this layer; updated in place.
 * @param attn_metadata  Must provide block_tables, slot_mapping and
 *                       total_sequence_lengths; input_offsets is additionally
 *                       required on the prefill path.
 * @return Attention output viewed as {1, seq_len, num_heads * head_dim}.
 */
infinicore::Tensor FlashDecodeAttentionImpl::forward(const AttentionLayer &layer,
                                                     const infinicore::Tensor &query,
                                                     const infinicore::Tensor &key,
                                                     const infinicore::Tensor &value,
                                                     infinicore::Tensor &kv_cache,
                                                     const infinilm::global_state::AttentionMetadata &attn_metadata) const {
    auto total_sequence_lengths = attn_metadata.total_sequence_lengths;
    auto input_offsets = attn_metadata.input_offsets;
    auto block_tables = attn_metadata.block_tables;
    auto slot_mapping = attn_metadata.slot_mapping;

    // Validate everything that is dereferenced unconditionally before the
    // cache is touched, so a bad call fails without a partial cache write.
    ASSERT(block_tables.has_value());
    ASSERT(slot_mapping.has_value());
    ASSERT(total_sequence_lengths.has_value());

    // 1. update paged kv cache
    auto [k_total, v_total] = do_kv_cache_update(layer, key, value, kv_cache, slot_mapping.value());

    const size_t seq_len = query->shape()[0];
    const bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]);

    // 2. Compute attention
    if (is_prefill) {
        ASSERT(input_offsets.has_value());

        // Only the prefill op writes into a caller-provided buffer, so the
        // allocation is confined to this branch.
        infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device());
        infinicore::op::paged_attention_prefill_(
            attn_output,
            query,
            k_total->permute({0, 2, 1, 3}), // swap dims 1 and 2 of the cache view for this op
            v_total->permute({0, 2, 1, 3}),
            block_tables.value(),
            total_sequence_lengths.value(),
            input_offsets.value(),
            std::nullopt,
            scale_);
        return attn_output->view({1, seq_len, num_heads_ * head_dim_});
    }

    // Decode: mha_kvcache takes a 4-D query and allocates its own result.
    auto q_for_fa = query->view({seq_len, 1, num_heads_, head_dim_});
    auto attn_out_4d = infinicore::op::mha_kvcache(
        q_for_fa,
        k_total, // [num_blocks, block_size, num_kv_heads, head_dim]
        v_total,
        total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence)
        block_tables.value(),           // [seq_len, max_num_blocks_per_seq] int32
        std::nullopt,
        scale_);
    infinicore::Tensor attn_output = attn_out_4d->view({seq_len, num_heads_, head_dim_});
    return attn_output->view({1, seq_len, num_heads_ * head_dim_});
}

/**
 * @brief Store new key/value rows in the paged cache and return the cache views.
 *
 * Takes the index-0 and index-1 slices of @p kv_cache along its first
 * dimension as the key and value caches, then calls
 * infinicore::op::paged_caching_ with @p key, @p value and @p slot_mapping.
 *
 * @param layer         Owning attention layer (not used here).
 * @param key           Key rows to write.
 * @param value         Value rows to write.
 * @param kv_cache      Combined cache tensor for this layer.
 * @param slot_mapping  Slot mapping forwarded to paged_caching_.
 * @return {key cache view, value cache view}.
 */
std::tuple<infinicore::Tensor, infinicore::Tensor> FlashDecodeAttentionImpl::do_kv_cache_update(const AttentionLayer &layer,
                                                                                                const infinicore::Tensor key,
                                                                                                const infinicore::Tensor value,
                                                                                                infinicore::Tensor &kv_cache,
                                                                                                const infinicore::Tensor slot_mapping) const {
    // Leading dimension of kv_cache: entry 0 is the key cache, entry 1 the value cache.
    auto key_cache = kv_cache->narrow({{0, 0, 1}})->squeeze(0);
    auto value_cache = kv_cache->narrow({{0, 1, 1}})->squeeze(0);

    infinicore::op::paged_caching_(key_cache, value_cache, key, value, slot_mapping);

    return std::make_tuple(key_cache, value_cache);
}
} // namespace infinilm::layers::attention::backends
43 changes: 43 additions & 0 deletions csrc/layers/attention/backends/flash_decode_attn.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#pragma once

#include "../../../global_state/global_state.hpp"
#include "infinicore/tensor.hpp"
#include <tuple>

namespace infinilm::layers::attention {
class AttentionLayer;
}

namespace infinilm::layers::attention::backends {

/**
 * @brief Attention backend instantiated for AttentionBackend::FLASH_DECODE.
 *
 * forward() updates the paged KV cache and then computes attention, using
 * paged_attention_prefill_ for prefill steps and mha_kvcache for decode steps.
 */
class FlashDecodeAttentionImpl {
public:
/**
 * @brief Store the attention geometry and read "max_position_embeddings"
 *        from the global model config.
 * @throws std::runtime_error if the global config has no model_config.
 */
FlashDecodeAttentionImpl(size_t num_heads,
size_t head_size,
float scale,
size_t num_kv_heads,
size_t layer_idx);

/**
 * @brief Write key/value into the paged cache, then compute attention.
 *
 * The step is treated as prefill when query's first dimension differs from
 * the number of entries in attn_metadata.total_sequence_lengths.
 *
 * @return Attention output viewed as {1, seq_len, num_heads * head_dim}.
 */
infinicore::Tensor forward(const AttentionLayer &layer,
const infinicore::Tensor &query,
const infinicore::Tensor &key,
const infinicore::Tensor &value,
infinicore::Tensor &kv_cache,
const infinilm::global_state::AttentionMetadata &attn_metadata) const;

/**
 * @brief Call paged_caching_ to store key/value using slot_mapping.
 * @return The key and value cache views sliced out of kv_cache (indices 0
 *         and 1 of its first dimension).
 */
std::tuple<infinicore::Tensor, infinicore::Tensor> do_kv_cache_update(const AttentionLayer &layer,
const infinicore::Tensor key,
const infinicore::Tensor value,
infinicore::Tensor &kv_cache,
const infinicore::Tensor slot_mapping) const;

private:
size_t num_heads_;
size_t head_size_;
float scale_;
size_t num_kv_heads_;
size_t layer_idx_;
size_t head_dim_; // Same value as head_size_: both are initialised from head_size
size_t max_position_embeddings_; // Read from the model config key "max_position_embeddings"
};
} // namespace infinilm::layers::attention::backends
95 changes: 95 additions & 0 deletions csrc/layers/attention/backends/flash_prefill_attn.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#include "flash_prefill_attn.hpp"

#include "../../../utils.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/mha_varlen.hpp"

namespace infinilm::layers::attention::backends {

/**
 * @brief Construct the FLASH_PREFILL attention backend.
 *
 * Stores the head counts, head size, scale and layer index, and caches the
 * "max_position_embeddings" value read from the global model configuration;
 * forward() later passes that value to mha_varlen_.
 *
 * @param num_heads     Number of query heads.
 * @param head_size     Size of each head; also stored as head_dim_.
 * @param scale         Scale value forwarded to the attention ops.
 * @param num_kv_heads  Number of key/value heads.
 * @param layer_idx     Index of the layer this backend belongs to.
 * @throws std::runtime_error if the global config has no model_config.
 */
FlashPrefillAttentionImpl::FlashPrefillAttentionImpl(size_t num_heads,
                                                     size_t head_size,
                                                     float scale,
                                                     size_t num_kv_heads,
                                                     size_t layer_idx)
    : num_heads_(num_heads),
      head_size_(head_size),
      scale_(scale),
      num_kv_heads_(num_kv_heads),
      layer_idx_(layer_idx),
      head_dim_(head_size) {
    const auto &config = infinilm::global_state::get_infinilm_config();
    const auto &model_config = config.model_config;

    // Without a model config there is nowhere to read the position limit from.
    if (!model_config) {
        throw std::runtime_error("infinilm::layers::attention::backends::FlashPrefillAttentionImpl: model_config is null");
    }

    max_position_embeddings_ = model_config->get<size_t>("max_position_embeddings");
}

/**
 * @brief Run attention for the FLASH_PREFILL backend.
 *
 * First writes the new key/value rows into the paged cache, then fills a
 * freshly allocated output buffer with one of two ops:
 *  - prefill: infinicore::op::mha_varlen_;
 *  - decode:  infinicore::op::paged_attention_.
 *
 * The step is treated as prefill when the first dimension of @p query differs
 * from the number of entries in attn_metadata.total_sequence_lengths.
 *
 * @param layer          Owning attention layer (only forwarded to do_kv_cache_update).
 * @param query          Query tensor; its first dimension is used as seq_len.
 * @param key            New key rows to store in the cache.
 * @param value          New value rows to store in the cache.
 * @param kv_cache       Paged cache tensor for this layer; updated in place.
 * @param attn_metadata  Must provide block_tables, slot_mapping and
 *                       total_sequence_lengths; input_offsets and cu_seqlens
 *                       are additionally required on the prefill path.
 * @return Attention output viewed as {1, seq_len, num_heads * head_dim}.
 */
infinicore::Tensor FlashPrefillAttentionImpl::forward(const AttentionLayer &layer,
                                                      const infinicore::Tensor &query,
                                                      const infinicore::Tensor &key,
                                                      const infinicore::Tensor &value,
                                                      infinicore::Tensor &kv_cache,
                                                      const infinilm::global_state::AttentionMetadata &attn_metadata) const {
    auto total_sequence_lengths = attn_metadata.total_sequence_lengths;
    auto input_offsets = attn_metadata.input_offsets;
    auto block_tables = attn_metadata.block_tables;
    auto slot_mapping = attn_metadata.slot_mapping;
    auto cu_seqlens = attn_metadata.cu_seqlens;

    // Validate everything that is dereferenced unconditionally before the
    // cache is touched, so a bad call fails without a partial cache write.
    ASSERT(block_tables.has_value());
    ASSERT(slot_mapping.has_value());
    ASSERT(total_sequence_lengths.has_value());

    // 1. update paged kv cache
    auto [k_total, v_total] = do_kv_cache_update(layer, key, value, kv_cache, slot_mapping.value());

    const size_t seq_len = query->shape()[0];
    const bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]);

    // 2. Compute attention
    infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device());
    if (is_prefill) {
        // These two are only consumed by the varlen kernel.
        ASSERT(input_offsets.has_value());
        ASSERT(cu_seqlens.has_value());

        infinicore::op::mha_varlen_(
            attn_output,
            query,
            k_total,
            v_total,
            input_offsets.value(),
            cu_seqlens.value(),
            block_tables.value(),
            max_position_embeddings_, // the model's position limit is passed for both length arguments
            max_position_embeddings_,
            std::nullopt,
            scale_);
    } else {
        infinicore::op::paged_attention_(
            attn_output,
            query,
            k_total->permute({0, 2, 1, 3}), // swap dims 1 and 2 of the cache view for this op
            v_total->permute({0, 2, 1, 3}),
            block_tables.value(),
            total_sequence_lengths.value(),
            std::nullopt,
            scale_);
    }

    return attn_output->view({1, seq_len, num_heads_ * head_dim_});
}

/**
 * @brief Store new key/value rows in the paged cache and return the cache views.
 *
 * Takes the index-0 and index-1 slices of @p kv_cache along its first
 * dimension as the key and value caches, then calls
 * infinicore::op::paged_caching_ with @p key, @p value and @p slot_mapping.
 *
 * @param layer         Owning attention layer (not used here).
 * @param key           Key rows to write.
 * @param value         Value rows to write.
 * @param kv_cache      Combined cache tensor for this layer.
 * @param slot_mapping  Slot mapping forwarded to paged_caching_.
 * @return {key cache view, value cache view}.
 */
std::tuple<infinicore::Tensor, infinicore::Tensor> FlashPrefillAttentionImpl::do_kv_cache_update(const AttentionLayer &layer,
                                                                                                 const infinicore::Tensor key,
                                                                                                 const infinicore::Tensor value,
                                                                                                 infinicore::Tensor &kv_cache,
                                                                                                 const infinicore::Tensor slot_mapping) const {
    // Leading dimension of kv_cache: entry 0 is the key cache, entry 1 the value cache.
    auto key_cache = kv_cache->narrow({{0, 0, 1}})->squeeze(0);
    auto value_cache = kv_cache->narrow({{0, 1, 1}})->squeeze(0);

    infinicore::op::paged_caching_(key_cache, value_cache, key, value, slot_mapping);

    return std::make_tuple(key_cache, value_cache);
}
} // namespace infinilm::layers::attention::backends
43 changes: 43 additions & 0 deletions csrc/layers/attention/backends/flash_prefill_attn.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#pragma once

#include "../../../global_state/global_state.hpp"
#include "infinicore/tensor.hpp"
#include <tuple>

namespace infinilm::layers::attention {
class AttentionLayer;
}

namespace infinilm::layers::attention::backends {

/**
 * @brief Attention backend instantiated for AttentionBackend::FLASH_PREFILL.
 *
 * forward() updates the paged KV cache and then computes attention, using
 * mha_varlen_ for prefill steps and paged_attention_ for decode steps.
 */
class FlashPrefillAttentionImpl {
public:
/**
 * @brief Store the attention geometry and read "max_position_embeddings"
 *        from the global model config.
 * @throws std::runtime_error if the global config has no model_config.
 */
FlashPrefillAttentionImpl(size_t num_heads,
size_t head_size,
float scale,
size_t num_kv_heads,
size_t layer_idx);

/**
 * @brief Write key/value into the paged cache, then compute attention.
 *
 * The step is treated as prefill when query's first dimension differs from
 * the number of entries in attn_metadata.total_sequence_lengths.
 *
 * @return Attention output viewed as {1, seq_len, num_heads * head_dim}.
 */
infinicore::Tensor forward(const AttentionLayer &layer,
const infinicore::Tensor &query,
const infinicore::Tensor &key,
const infinicore::Tensor &value,
infinicore::Tensor &kv_cache,
const infinilm::global_state::AttentionMetadata &attn_metadata) const;

/**
 * @brief Call paged_caching_ to store key/value using slot_mapping.
 * @return The key and value cache views sliced out of kv_cache (indices 0
 *         and 1 of its first dimension).
 */
std::tuple<infinicore::Tensor, infinicore::Tensor> do_kv_cache_update(const AttentionLayer &layer,
const infinicore::Tensor key,
const infinicore::Tensor value,
infinicore::Tensor &kv_cache,
const infinicore::Tensor slot_mapping) const;

private:
size_t num_heads_;
size_t head_size_;
float scale_;
size_t num_kv_heads_;
size_t layer_idx_;
size_t head_dim_; // Same value as head_size_: both are initialised from head_size
size_t max_position_embeddings_; // Read from the model config key "max_position_embeddings"; passed to mha_varlen_ in forward()
};
} // namespace infinilm::layers::attention::backends
21 changes: 21 additions & 0 deletions csrc/models/infinilm_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,27 @@ std::vector<infinicore::Tensor> InfinilmModel::default_allocate_kv_cache_tensors
case backends::AttentionBackend::FLASH_ATTN: {
;
}
case backends::AttentionBackend::FLASH_PREFILL:
case backends::AttentionBackend::FLASH_DECODE: {
auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
if (nullptr == paged_kv_cache_config) {
throw std::runtime_error(
"infinilm::InfinilmModel::default_allocate_kv_cache_tensors: invalid paged kv cache config type");
}
kv_cache_vec.reserve(num_hidden_layers);

for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) {
auto kv_cache = cache::PagedKVCache::create_layer_kv_cache(
head_dim,
head_dim,
num_key_value_heads,
num_key_value_heads,
dtype,
*paged_kv_cache_config);
kv_cache_vec.push_back(kv_cache);
}
break;
}
case backends::AttentionBackend::PAGED_ATTN: {
auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
if (nullptr == paged_kv_cache_config) {
Expand Down
Loading