From bcbb1c3990551d115fb3d02320794ec76be54947 Mon Sep 17 00:00:00 2001 From: rubik Date: Thu, 28 May 2026 16:55:02 +0800 Subject: [PATCH] issue/401 refactor(rope): add scaling factory and TP-safe RoPE cache - Decouple scaling config instantiation from ModelConfig via factory and registry pattern. - Add thread-local RoPE cache with device-scoped keys to reduce VRAM usage and ensure TP safety. - Centralize rotary dimension calculation into ModelConfig. --- csrc/config/model_config.cpp | 68 +++++------------- csrc/config/model_config.hpp | 5 +- .../rope_scaling_creators.cpp | 66 +++++++++++++++++ .../rotary_embedding/rotary_embedding.cpp | 72 +++++++++---------- .../rotary_embedding/rotary_embedding.hpp | 23 +++--- .../rotary_embedding_factory.cpp | 42 +++++++++++ .../rotary_embedding_factory.hpp | 35 +++++++++ csrc/models/glm4/glm4_attention.cpp | 2 +- csrc/models/llama_legacy/llama_config.hpp | 2 +- csrc/models/llama_legacy/llama_model.cpp | 4 +- csrc/pybind11/models/llama_legacy.hpp | 6 +- 11 files changed, 223 insertions(+), 102 deletions(-) create mode 100644 csrc/layers/rotary_embedding/rope_scaling_creators.cpp create mode 100644 csrc/layers/rotary_embedding/rotary_embedding_factory.cpp create mode 100644 csrc/layers/rotary_embedding/rotary_embedding_factory.hpp diff --git a/csrc/config/model_config.cpp b/csrc/config/model_config.cpp index e8bb6f5b2..f0095558e 100644 --- a/csrc/config/model_config.cpp +++ b/csrc/config/model_config.cpp @@ -25,56 +25,6 @@ ModelConfig::get_quant_scheme() const { } } -std::shared_ptr -ModelConfig::get_rope_scaling() const { - if (!config_json.contains("rope_scaling") || config_json["rope_scaling"].is_null()) { - return nullptr; - } - - const auto &rope_scaling = config_json["rope_scaling"]; - if (!rope_scaling.is_object()) { - throw std::runtime_error("rope_scaling must be an object"); - } - - std::string type_str; - if (rope_scaling.contains("type")) { - type_str = rope_scaling["type"].get(); - } else if (rope_scaling.contains("rope_type")) { - type_str = rope_scaling["rope_type"].get(); - } else { - throw std::runtime_error("rope_scaling must contain 'type' or 'rope_type' field"); - } - - if (type_str == "longrope") { - // Required fields for LongRopeConfig - if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) { - throw std::runtime_error( - "LongRopeConfig requires 'short_factor', 'long_factor', and 'original_max_position_embeddings'"); - } - - auto short_factor = rope_scaling["short_factor"].get>(); - auto long_factor = rope_scaling["long_factor"].get>(); - size_t original_max_position_embeddings = rope_scaling["original_max_position_embeddings"].get(); - - float factor = 1.0f; - if (rope_scaling.contains("factor")) { - factor = rope_scaling["factor"].get(); - } - - return std::make_shared( - std::move(short_factor), - std::move(long_factor), - original_max_position_embeddings, - factor); - } else if (type_str == "default" || type_str == "none" || type_str == "dynamic") { - // Default scaling, no scaling applied - // Currently not handling extended sequence lengths for dynamic scaling. Add specific branches when needed. - return nullptr; - } else { - throw std::runtime_error("Unsupported rope_scaling type: " + type_str); - } -} - infinicore::DataType ModelConfig::get_dtype() const { std::string dtype_str; if (config_json.contains("dtype")) { @@ -88,6 +38,24 @@ infinicore::DataType ModelConfig::get_dtype() const { return parse_dtype(dtype_str); } +size_t ModelConfig::get_rotary_dim() const { + size_t head_dim = get_head_dim(); + double partial_rotary_factor = get_or("partial_rotary_factor", 1.0); + + if (partial_rotary_factor <= 0.0 || partial_rotary_factor >= 1.0) { + return head_dim; + } + + size_t rotary_dim = static_cast(std::llround( + static_cast(head_dim) * partial_rotary_factor)); + rotary_dim = std::clamp(rotary_dim, static_cast(2), head_dim); + + if (rotary_dim % 2 != 0) { + rotary_dim -= 1; + } + return std::max(rotary_dim, static_cast(2)); +} + std::ostream &operator<<(std::ostream &os, const ModelConfig &config) { os << config.config_json.dump(4); return os; diff --git a/csrc/config/model_config.hpp b/csrc/config/model_config.hpp index 3752e2188..644d40522 100644 --- a/csrc/config/model_config.hpp +++ b/csrc/config/model_config.hpp @@ -58,6 +58,9 @@ class ModelConfig { return get("hidden_size") / get("num_attention_heads"); } + // Compute the actual rotary dimension based on partial rotation factor + size_t get_rotary_dim() const; + QuantConfig get_quant_config() const { return quant_config; } @@ -68,7 +71,7 @@ class ModelConfig { infinicore::DataType get_dtype() const; infinilm::quantization::QuantScheme get_quant_scheme() const; - std::shared_ptr get_rope_scaling() const; + void set_kv_quant_scheme(infinicore::DataType kv_cache_dtype) { this->quant_config.set_kv_quant_scheme(kv_cache_dtype); } diff --git a/csrc/layers/rotary_embedding/rope_scaling_creators.cpp b/csrc/layers/rotary_embedding/rope_scaling_creators.cpp new file mode 100644 index 000000000..2600dcf6a --- /dev/null +++ b/csrc/layers/rotary_embedding/rope_scaling_creators.cpp @@ -0,0 +1,66 @@ +#include "../../config/model_config.hpp" +#include "infinicore/nn/rope_scaling_configs.hpp" +#include "rotary_embedding_factory.hpp" +#include + +namespace infinilm::layers::rotary_embedding { +namespace { +/** + * @brief Default creator for types that apply no scaling. + * Returns nullptr, which the InfiniCore RoPE layer interprets as a 1.0x pass-through. + */ +std::shared_ptr +create_default_scaling(const std::shared_ptr &) { + return nullptr; +} + +// TODO(rubik) create_dynamic_scaling + +/** + * @brief Creator function for LongRoPE scaling configuration. + * Extracts 'short_factor', 'long_factor', etc., from the model config. + */ +std::shared_ptr +create_longrope(const std::shared_ptr &cfg) { + const auto &rope_scaling = cfg->get_config_json()["rope_scaling"]; + + // Required fields for LongRopeConfig + if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) { + throw std::runtime_error( + "LongRopeConfig requires 'short_factor', 'long_factor', and 'original_max_position_embeddings'"); + } + + auto short_factor = rope_scaling["short_factor"].get>(); + auto long_factor = rope_scaling["long_factor"].get>(); + size_t original_max_position_embeddings = rope_scaling["original_max_position_embeddings"].get(); + + float factor = 1.0f; + if (rope_scaling.contains("factor")) { + factor = rope_scaling["factor"].get(); + } + + return std::make_shared( + std::move(short_factor), + std::move(long_factor), + original_max_position_embeddings, + factor); +} + +// Future scaling creators go here (e.g., create_llama3, create_linear) + +} // anonymous namespace + +// Static self-registration block +// Registers creator functions into the factory registry upon program startup. +static bool _registered = []() { + auto ®istry = get_scaling_registry(); + registry["default"] = create_default_scaling; + registry["none"] = create_default_scaling; + registry["dynamic"] = create_default_scaling; + registry["longrope"] = create_longrope; + // add new scaling + // registry["llama3"] = create_llama3_scaling; + return true; +}(); + +} // namespace infinilm::layers::rotary_embedding diff --git a/csrc/layers/rotary_embedding/rotary_embedding.cpp b/csrc/layers/rotary_embedding/rotary_embedding.cpp index 50d695332..78d2a16dc 100644 --- a/csrc/layers/rotary_embedding/rotary_embedding.cpp +++ b/csrc/layers/rotary_embedding/rotary_embedding.cpp @@ -1,46 +1,44 @@ #include "rotary_embedding.hpp" -#include // std::clamp -#include // std::llround +#include "../../config/model_config.hpp" +#include "rotary_embedding_factory.hpp" +#include #include -#include namespace infinilm::layers::rotary_embedding { -namespace { -thread_local std::unordered_map> _ROPE_DICT; -} // namespace - -size_t get_rotary_dim(size_t head_dim, double partial_rotary_factor) { - if (partial_rotary_factor <= 0.0 || partial_rotary_factor >= 1.0) { - return head_dim; - } - size_t rotary_dim = static_cast(std::llround( - static_cast(head_dim) * partial_rotary_factor)); - rotary_dim = std::clamp(rotary_dim, static_cast(2), head_dim); +// Cache dictionary to avoid redundant allocations of RoPE instances. +// thread_local ensures it is only visible within this compilation unit. +thread_local std::unordered_map> _ROPE_DICT; - // RoPE operates on complex pairs, so the rotary dimension must be even - if (rotary_dim % 2 != 0) { - rotary_dim -= 1; +std::shared_ptr +get_rope(const std::shared_ptr &model_config, + const infinicore::Device &device, + infinicore::nn::RoPE::Algo algo) { + + // 1. Compute the actual rotary dimension + size_t rotary_dim = model_config->get_rotary_dim(); + + // 2. Resolve scaling config via the internal factory + auto scaling = make_scaling_config(model_config); + + // 3. Cache key must include rotary_dim AND the actual scaling type + // to avoid reusing the same RoPE instance across models with different settings + // (Enhancement: dynamically determine scaling_type instead of hardcoding "default") + std::string scaling_type_str = "default"; + if (scaling) { + // Assuming we can get the type string from the JSON for cache key generation, + // or ideally, ScalingConfig should have a virtual std::string type_name() const method. + // Here we read it from JSON for the cache key purpose only, keeping it decoupled from InfiniCore. + const auto &rope_scaling_json = model_config->get_config_json()["rope_scaling"]; + if (rope_scaling_json.contains("type")) { + scaling_type_str = rope_scaling_json["type"].get(); + } else if (rope_scaling_json.contains("rope_type")) { + scaling_type_str = rope_scaling_json["rope_type"].get(); + } } - return std::max(rotary_dim, static_cast(2)); -} -std::shared_ptr get_rope(const std::shared_ptr &model_config, - const infinicore::Device &device, - infinicore::nn::RoPE::Algo algo) { - // 1. Get head dimension - size_t head_dim = model_config->get_head_dim(); - - // 2. Safely get partial_rotary_factor, defaulting to 1.0 (full rotation) - double partial_rotary_factor = model_config->get_or("partial_rotary_factor", 1.0); - - // 3. Compute the actual rotary dimension - size_t rotary_dim = get_rotary_dim(head_dim, partial_rotary_factor); - - // 4. Cache key must include rotary_dim to avoid reusing the same RoPE instance - // across models with different partial_rotary_factor values - const std::string scaling_type = "default"; - std::string cache_key = scaling_type + "_rope_dim_" + std::to_string(rotary_dim); + std::string cache_key = scaling_type_str + "_rope_dim_" + std::to_string(rotary_dim) + + "_dev_" + device.toString(); auto it = _ROPE_DICT.find(cache_key); if (it != _ROPE_DICT.end()) { return it->second; @@ -49,9 +47,9 @@ std::shared_ptr get_rope(const std::shared_ptrget_dtype(); size_t max_position_embeddings = model_config->get("max_position_embeddings"); double rope_theta = model_config->get("rope_theta"); + auto rope = std::make_shared(rotary_dim, max_position_embeddings, rope_theta, - algo, dtype, device, - model_config->get_rope_scaling()); + algo, dtype, device, scaling); _ROPE_DICT.emplace(cache_key, rope); return rope; diff --git a/csrc/layers/rotary_embedding/rotary_embedding.hpp b/csrc/layers/rotary_embedding/rotary_embedding.hpp index aab529bea..1e9ff23a9 100644 --- a/csrc/layers/rotary_embedding/rotary_embedding.hpp +++ b/csrc/layers/rotary_embedding/rotary_embedding.hpp @@ -1,17 +1,24 @@ #pragma once -#include "../../config/model_config.hpp" #include "infinicore/nn/rope.hpp" #include -namespace infinilm::layers::rotary_embedding { +namespace infinilm::config { +class ModelConfig; // Forward declaration +} -// Compute the actual number of dimensions involved in rotary position embedding. -// For partial rotation, the dimension is clamped to [2, head_dim] and must be even. -size_t get_rotary_dim(size_t head_dim, double partial_rotary_factor); +namespace infinilm::layers::rotary_embedding { -std::shared_ptr get_rope(const std::shared_ptr &model_config, - const infinicore::Device &device, - infinicore::nn::RoPE::Algo algo = infinicore::nn::RoPE::Algo::GPT_NEOX); +/** + * @brief Public API to assemble and construct a complete RoPE module. + * + * @param model_config Model configuration. + * @param device Device to create the cache on. + * @param algo RoPE algorithm type (default: Algo::GPT_NEOX). + */ +std::shared_ptr +get_rope(const std::shared_ptr &model_config, + const infinicore::Device &device, + infinicore::nn::RoPE::Algo algo = infinicore::nn::RoPE::Algo::GPT_NEOX); } // namespace infinilm::layers::rotary_embedding diff --git a/csrc/layers/rotary_embedding/rotary_embedding_factory.cpp b/csrc/layers/rotary_embedding/rotary_embedding_factory.cpp new file mode 100644 index 000000000..4866de9db --- /dev/null +++ b/csrc/layers/rotary_embedding/rotary_embedding_factory.cpp @@ -0,0 +1,42 @@ +#include "rotary_embedding_factory.hpp" +#include "../../config/model_config.hpp" +#include + +namespace infinilm::layers::rotary_embedding { + +std::unordered_map &get_scaling_registry() { + static std::unordered_map registry; + return registry; +} + +std::shared_ptr +make_scaling_config(const std::shared_ptr &model_config) { + if (!model_config || !model_config->get_config_json().contains("rope_scaling") || model_config->get_config_json()["rope_scaling"].is_null()) { + return nullptr; + } + + const auto &rope_scaling = model_config->get_config_json()["rope_scaling"]; + if (!rope_scaling.is_object()) { + throw std::runtime_error("rope_scaling must be an object"); + } + + std::string scaling_type; + if (rope_scaling.contains("type")) { + scaling_type = rope_scaling["type"].get(); + } else if (rope_scaling.contains("rope_type")) { + scaling_type = rope_scaling["rope_type"].get(); + } else { + throw std::runtime_error("rope_scaling must contain 'type' or 'rope_type' field"); + } + + // Registry routing: delegate construction to the specific creator + auto ®istry = get_scaling_registry(); + auto it = registry.find(scaling_type); + if (it != registry.end()) { + return it->second(model_config); + } + + throw std::runtime_error("Unsupported rope_scaling_type: " + scaling_type); +} + +} // namespace infinilm::layers::rotary_embedding diff --git a/csrc/layers/rotary_embedding/rotary_embedding_factory.hpp b/csrc/layers/rotary_embedding/rotary_embedding_factory.hpp new file mode 100644 index 000000000..088998ffb --- /dev/null +++ b/csrc/layers/rotary_embedding/rotary_embedding_factory.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include "infinicore/nn/rope.hpp" +#include "infinicore/nn/rope_scaling_configs.hpp" +#include +#include +#include +#include + +namespace infinilm::config { +class ModelConfig; // Forward declaration +} + +namespace infinilm::layers::rotary_embedding { + +/** + * @brief Function pointer type for creating specific RopeScalingConfig instances. + * Implementations should extract parameters from ModelConfig and construct the corresponding Config object. + */ +using ScalingCreator = std::function( + const std::shared_ptr &)>; + +/** + * @brief Get the singleton registry mapping scaling type strings to their creator functions. + */ +std::unordered_map &get_scaling_registry(); + +/** + * @brief Factory method to create a RopeScalingConfig based on the ModelConfig. + * Routes the "rope_scaling_type" string to the corresponding registered creator. + */ +std::shared_ptr +make_scaling_config(const std::shared_ptr &model_config); + +} // namespace infinilm::layers::rotary_embedding diff --git a/csrc/models/glm4/glm4_attention.cpp b/csrc/models/glm4/glm4_attention.cpp index a72c3a85e..20dbd688b 100644 --- a/csrc/models/glm4/glm4_attention.cpp +++ b/csrc/models/glm4/glm4_attention.cpp @@ -21,7 +21,7 @@ Glm4Attention::Glm4Attention(std::shared_ptr mode num_attention_heads_(model_config->get("num_attention_heads")), num_key_value_heads_(model_config->get("num_key_value_heads")), head_dim_(model_config->get_head_dim()), - rotary_dim_(infinilm::layers::rotary_embedding::get_rotary_dim(model_config->get_head_dim(), model_config->get_or("partial_rotary_factor", 1.0))), + rotary_dim_(model_config->get_rotary_dim()), use_bias_(model_config->get_or("attention_bias", true)), use_output_bias_(model_config->get_or("attention_output_bias", false)) { diff --git a/csrc/models/llama_legacy/llama_config.hpp b/csrc/models/llama_legacy/llama_config.hpp index c3aa1cfaa..44cee1b89 100644 --- a/csrc/models/llama_legacy/llama_config.hpp +++ b/csrc/models/llama_legacy/llama_config.hpp @@ -36,7 +36,7 @@ struct LlamaConfig : public InfinilmModel::Config { size_t max_position_embeddings = 2048; // Maximum sequence length double rope_theta = 10000.0; // RoPE base frequency - std::shared_ptr rope_scaling = nullptr; // RoPE scaling type + std::shared_ptr rope_scaling = nullptr; // RoPE scaling type // Normalization double rms_norm_eps = 1e-6; // RMSNorm epsilon diff --git a/csrc/models/llama_legacy/llama_model.cpp b/csrc/models/llama_legacy/llama_model.cpp index 1a0a2f185..a2dea5d88 100644 --- a/csrc/models/llama_legacy/llama_model.cpp +++ b/csrc/models/llama_legacy/llama_model.cpp @@ -1,4 +1,5 @@ #include "llama_model.hpp" +#include "../../layers/rotary_embedding/rotary_embedding_factory.hpp" #include "infinicore/nn/embedding.hpp" #include "infinicore/nn/rmsnorm.hpp" #include "infinicore/nn/rope.hpp" @@ -22,9 +23,10 @@ LlamaModel::LlamaModel(std::shared_ptr model_conf } INFINICORE_NN_MODULE_INIT(norm, model_config_->get("hidden_size"), model_config_->get("rms_norm_eps"), dtype, device); + auto rope_scaling_config = infinilm::layers::rotary_embedding::make_scaling_config(model_config_); INFINICORE_NN_MODULE_INIT(rotary_emb, model_config_->get_head_dim(), model_config_->get("max_position_embeddings"), model_config_->get("rope_theta"), infinicore::nn::RoPE::Algo::GPT_NEOX, - dtype, device, model_config_->get_rope_scaling()); + dtype, device, rope_scaling_config); for (auto &layer : layers_) { if (layer) { diff --git a/csrc/pybind11/models/llama_legacy.hpp b/csrc/pybind11/models/llama_legacy.hpp index cd6dbf1b7..3458c94b1 100644 --- a/csrc/pybind11/models/llama_legacy.hpp +++ b/csrc/pybind11/models/llama_legacy.hpp @@ -104,8 +104,8 @@ inline void bind_llama(py::module &m) { return py::none(); } - using ScalingConfig = infinicore::nn::RoPE::ScalingConfig; - using LongRopeConfig = infinicore::nn::RoPE::LongRopeConfig; + using ScalingConfig = infinicore::nn::RopeScalingConfig; + using LongRopeConfig = infinicore::nn::LongRopeConfig; py::dict d; @@ -148,7 +148,7 @@ inline void bind_llama(py::module &m) { : get_str("type"); if (type == "longrope") { - using LongRopeConfig = infinicore::nn::RoPE::LongRopeConfig; + using LongRopeConfig = infinicore::nn::LongRopeConfig; if (!d.contains("short_factor") || !d.contains("long_factor") || !d.contains("original_max_position_embeddings")) { throw py::value_error(