From 80d5856831e6875cbdc8340014d7b25190d5a15f Mon Sep 17 00:00:00 2001 From: rubik Date: Wed, 20 May 2026 21:18:11 +0800 Subject: [PATCH] issue/392 [Feature](rope): implement Llama 3.1 frequency-aware RoPE scaling Add native support for `rope_type: "llama3"` to enable proper inference for Llama 3.1 models with extended 128k context lengths. Implement `ModelConfig::createLlama3Scaling` method, introducing the piece-wise smooth interpolation mechanism based on frequency wavelengths, strictly referencing the implementations from SGLang and HuggingFace. To seamlessly integrate this logic without modifying the underlying RoPE kernels, the Llama 3 frequency scaling is elegantly mapped to the existing `LongRopeConfig` structure. Key technical adaptations include: - Deriving `smooth_factor = 1.0 / freq_scale` to align with `LongRopeConfig`'s wavelength multiplier semantics. - Enforcing full double-precision (`double`) for intermediate math to prevent precision truncation with Llama 3's large `rope_theta` (e.g., 500000.0), ensuring bit-wise alignment with PyTorch behavior. - Overriding the outer factor to `1.0f` in the `LongRopeConfig` constructor to bypass the unused amplitude scaling penalty. --- csrc/config/model_config.cpp | 148 +++++++++++++++++++++++++++++------ csrc/config/model_config.hpp | 6 ++ 2 files changed, 132 insertions(+), 22 deletions(-) diff --git a/csrc/config/model_config.cpp b/csrc/config/model_config.cpp index e8bb6f5b..ca3f7505 100644 --- a/csrc/config/model_config.cpp +++ b/csrc/config/model_config.cpp @@ -45,34 +45,138 @@ ModelConfig::get_rope_scaling() const { throw std::runtime_error("rope_scaling must contain 'type' or 'rope_type' field"); } - if (type_str == "longrope") { - // Required fields for LongRopeConfig - if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) { - throw std::runtime_error( - "LongRopeConfig requires 'short_factor', 'long_factor', and 'original_max_position_embeddings'"); + // SGLang's style if-else routing in python/srt/layers/rotary_embedding.py:get_rope + if (type_str == "llama3") { + return createLlama3Scaling(rope_scaling); + } else if (type_str == "longrope") { + return createLongRopeScaling(rope_scaling); + } else if (type_str == "default") { + return createDefaultScaling(rope_scaling); + } else if (type_str == "none") { + return createNoneScaling(rope_scaling); + } else if (type_str == "dynamic") { + return createDynamicScaling(rope_scaling); + } + + throw std::runtime_error("Unsupported rope_scaling type: " + type_str); +} + +std::shared_ptr +ModelConfig::createDefaultScaling(const nlohmann::json &rope_scaling) const { + return nullptr; +} + +std::shared_ptr +ModelConfig::createNoneScaling(const nlohmann::json &rope_scaling) const { + return nullptr; +} + +std::shared_ptr +ModelConfig::createDynamicScaling(const nlohmann::json &rope_scaling) const { + // [TODO]Dynamic scaling: currently not handling extended sequence lengths + return nullptr; +} + +std::shared_ptr +ModelConfig::createLlama3Scaling(const nlohmann::json &rope_scaling) const { + // Native support for Llama 3.1 frequency-aware RoPE scaling + // equivalent to SGLang/HuggingFace implementations + + // 1. Validate and extract Llama3 specific parameters + const std::vector required_keys = { + "factor", "low_freq_factor", "high_freq_factor", "original_max_position_embeddings"}; + for (const auto &key : required_keys) { + if (!rope_scaling.contains(key)) { + throw std::runtime_error("Llama3RoPE requires '" + key + "' in rope_scaling"); } + } + + const double factor = rope_scaling["factor"].get(); + const double low_freq_factor = rope_scaling["low_freq_factor"].get(); + const double high_freq_factor = rope_scaling["high_freq_factor"].get(); + const size_t orig_max_pos = rope_scaling["original_max_position_embeddings"].get(); + + // 2. Validate and extract model base parameters + if (!config_json.contains("hidden_size") || !config_json.contains("num_attention_heads")) { + throw std::runtime_error("Llama3RoPE requires 'hidden_size' and 'num_attention_heads' in config"); + } + const size_t hidden_size = config_json["hidden_size"].get(); + const size_t num_heads = config_json["num_attention_heads"].get(); + const size_t head_dim = hidden_size / num_heads; + const double theta = config_json.value("rope_theta", 10000.0); - auto short_factor = rope_scaling["short_factor"].get>(); - auto long_factor = rope_scaling["long_factor"].get>(); - size_t original_max_position_embeddings = rope_scaling["original_max_position_embeddings"].get(); + // 3. Pre-compute smooth factors based on wavelength + constexpr double kPi = 3.14159265358979323846; + const size_t cache_dim = head_dim / 2; + const double low_freq_wavelen = static_cast(orig_max_pos) / low_freq_factor; + const double high_freq_wavelen = static_cast(orig_max_pos) / high_freq_factor; - float factor = 1.0f; - if (rope_scaling.contains("factor")) { - factor = rope_scaling["factor"].get(); + const bool has_smooth_range = (high_freq_factor != low_freq_factor); + const double smooth_denom = has_smooth_range ? (high_freq_factor - low_freq_factor) : 1.0; + + std::vector smooth_factors(cache_dim); + for (size_t j = 0; j < cache_dim; ++j) { + const double exponent = 2.0 * static_cast(j) / static_cast(head_dim); + const double inv_freq = 1.0 / std::pow(theta, exponent); + const double wavelen = 2.0 * kPi / inv_freq; + + if (wavelen < high_freq_wavelen) { + // High frequency: no scaling (freq_scale = 1.0) + smooth_factors[j] = 1.0f; + } else if (wavelen > low_freq_wavelen) { + // Low frequency: full scaling (freq_scale = 1.0 / factor) + smooth_factors[j] = static_cast(factor); + } else { + // Mid frequency: smooth frequency interpolation + // + // Equivalent to SGLang's implementation: + // smooth = (orig_max_pos / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + // new_freq = (1 - smooth) * inv_freq / factor + smooth * inv_freq + // => freq_scale = (1 - smooth) / factor + smooth + // + // Since LongRopeConfig applies factors as wavelength multipliers + // (i.e., new_freq = inv_freq / scale), the required smooth_factor + // is the inverse of the frequency scale. + // + double smooth = 0.0; + if (has_smooth_range) { + smooth = (static_cast(orig_max_pos) / wavelen - low_freq_factor) / smooth_denom; + } + const double freq_scale = (1.0 - smooth) / factor + smooth; + smooth_factors[j] = static_cast(1.0 / freq_scale); } + } - return std::make_shared( - std::move(short_factor), - std::move(long_factor), - original_max_position_embeddings, - factor); - } else if (type_str == "default" || type_str == "none" || type_str == "dynamic") { - // Default scaling, no scaling applied - // Currently not handling extended sequence lengths for dynamic scaling. Add specific branches when needed. - return nullptr; - } else { - throw std::runtime_error("Unsupported rope_scaling type: " + type_str); + // 4. Adapt to LongRopeConfig + // - short_factor and long_factor use the same smooth_factors + // - Pass factor=1.0f to bypass the amplitude scaling sqrt(log(...)) in LongRopeConfig constructor + return std::make_shared( + smooth_factors, // short_factor + smooth_factors, // long_factor + orig_max_pos, + 1.0f // Force 1.0f to disable amplitude scaling + ); +} + +std::shared_ptr +ModelConfig::createLongRopeScaling(const nlohmann::json &rope_scaling) const { + const std::vector required_keys = {"short_factor", "long_factor", "original_max_position_embeddings"}; + for (const auto &key : required_keys) { + if (!rope_scaling.contains(key)) { + throw std::runtime_error("LongRopeConfig requires '" + key + "' in rope_scaling"); + } } + + auto short_factor = rope_scaling["short_factor"].get>(); + auto long_factor = rope_scaling["long_factor"].get>(); + const size_t orig_max_pos = rope_scaling["original_max_position_embeddings"].get(); + const float factor = rope_scaling.value("factor", 1.0f); + + return std::make_shared( + std::move(short_factor), + std::move(long_factor), + orig_max_pos, + factor); } infinicore::DataType ModelConfig::get_dtype() const { diff --git a/csrc/config/model_config.hpp b/csrc/config/model_config.hpp index 3752e218..0b5fa9ff 100644 --- a/csrc/config/model_config.hpp +++ b/csrc/config/model_config.hpp @@ -105,5 +105,11 @@ class ModelConfig { private: nlohmann::json config_json; QuantConfig quant_config; + + std::shared_ptr createLlama3Scaling(const nlohmann::json &rope_scaling) const; + std::shared_ptr createLongRopeScaling(const nlohmann::json &rope_scaling) const; + std::shared_ptr createDefaultScaling(const nlohmann::json &rope_scaling) const; + std::shared_ptr createNoneScaling(const nlohmann::json &rope_scaling) const; + std::shared_ptr createDynamicScaling(const nlohmann::json &rope_scaling) const; }; } // namespace infinilm::config