Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 18 additions & 50 deletions csrc/config/model_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,56 +25,6 @@ ModelConfig::get_quant_scheme() const {
}
}

std::shared_ptr<infinicore::nn::RoPE::ScalingConfig>
ModelConfig::get_rope_scaling() const {
if (!config_json.contains("rope_scaling") || config_json["rope_scaling"].is_null()) {
return nullptr;
}

const auto &rope_scaling = config_json["rope_scaling"];
if (!rope_scaling.is_object()) {
throw std::runtime_error("rope_scaling must be an object");
}

std::string type_str;
if (rope_scaling.contains("type")) {
type_str = rope_scaling["type"].get<std::string>();
} else if (rope_scaling.contains("rope_type")) {
type_str = rope_scaling["rope_type"].get<std::string>();
} else {
throw std::runtime_error("rope_scaling must contain 'type' or 'rope_type' field");
}

if (type_str == "longrope") {
// Required fields for LongRopeConfig
if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) {
throw std::runtime_error(
"LongRopeConfig requires 'short_factor', 'long_factor', and 'original_max_position_embeddings'");
}

auto short_factor = rope_scaling["short_factor"].get<std::vector<float>>();
auto long_factor = rope_scaling["long_factor"].get<std::vector<float>>();
size_t original_max_position_embeddings = rope_scaling["original_max_position_embeddings"].get<size_t>();

float factor = 1.0f;
if (rope_scaling.contains("factor")) {
factor = rope_scaling["factor"].get<float>();
}

return std::make_shared<infinicore::nn::RoPE::LongRopeConfig>(
std::move(short_factor),
std::move(long_factor),
original_max_position_embeddings,
factor);
} else if (type_str == "default" || type_str == "none" || type_str == "dynamic") {
// Default scaling, no scaling applied
// Currently not handling extended sequence lengths for dynamic scaling. Add specific branches when needed.
return nullptr;
} else {
throw std::runtime_error("Unsupported rope_scaling type: " + type_str);
}
}

infinicore::DataType ModelConfig::get_dtype() const {
std::string dtype_str;
if (config_json.contains("dtype")) {
Expand All @@ -88,6 +38,24 @@ infinicore::DataType ModelConfig::get_dtype() const {
return parse_dtype(dtype_str);
}

size_t ModelConfig::get_rotary_dim() const {
size_t head_dim = get_head_dim();
double partial_rotary_factor = get_or<double>("partial_rotary_factor", 1.0);

if (partial_rotary_factor <= 0.0 || partial_rotary_factor >= 1.0) {
return head_dim;
}

size_t rotary_dim = static_cast<size_t>(std::llround(
static_cast<double>(head_dim) * partial_rotary_factor));
rotary_dim = std::clamp(rotary_dim, static_cast<size_t>(2), head_dim);

if (rotary_dim % 2 != 0) {
rotary_dim -= 1;
}
return std::max(rotary_dim, static_cast<size_t>(2));
}

std::ostream &operator<<(std::ostream &os, const ModelConfig &config) {
os << config.config_json.dump(4);
return os;
Expand Down
5 changes: 4 additions & 1 deletion csrc/config/model_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class ModelConfig {
return get<size_t>("hidden_size") / get<size_t>("num_attention_heads");
}

// Compute the actual rotary dimension based on partial rotation factor
size_t get_rotary_dim() const;

QuantConfig get_quant_config() const {
return quant_config;
}
Expand All @@ -68,7 +71,7 @@ class ModelConfig {

infinicore::DataType get_dtype() const;
infinilm::quantization::QuantScheme get_quant_scheme() const;
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> get_rope_scaling() const;

void set_kv_quant_scheme(infinicore::DataType kv_cache_dtype) {
this->quant_config.set_kv_quant_scheme(kv_cache_dtype);
}
Expand Down
66 changes: 66 additions & 0 deletions csrc/layers/rotary_embedding/rope_scaling_creators.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include "../../config/model_config.hpp"
#include "infinicore/nn/rope_scaling_configs.hpp"
#include "rotary_embedding_factory.hpp"
#include <vector>

namespace infinilm::layers::rotary_embedding {
namespace {
/**
* @brief Default creator for types that apply no scaling.
* Returns nullptr, which the InfiniCore RoPE layer interprets as a 1.0x pass-through.
*/
std::shared_ptr<infinicore::nn::RopeScalingConfig>
create_default_scaling(const std::shared_ptr<config::ModelConfig> &) {
return nullptr;
}

// TODO(rubik) create_dynamic_scaling

/**
* @brief Creator function for LongRoPE scaling configuration.
* Extracts 'short_factor', 'long_factor', etc., from the model config.
*/
std::shared_ptr<infinicore::nn::RopeScalingConfig>
create_longrope(const std::shared_ptr<config::ModelConfig> &cfg) {
const auto &rope_scaling = cfg->get_config_json()["rope_scaling"];

// Required fields for LongRopeConfig
if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) {
throw std::runtime_error(
"LongRopeConfig requires 'short_factor', 'long_factor', and 'original_max_position_embeddings'");
}

auto short_factor = rope_scaling["short_factor"].get<std::vector<float>>();
auto long_factor = rope_scaling["long_factor"].get<std::vector<float>>();
size_t original_max_position_embeddings = rope_scaling["original_max_position_embeddings"].get<size_t>();

float factor = 1.0f;
if (rope_scaling.contains("factor")) {
factor = rope_scaling["factor"].get<float>();
}

return std::make_shared<infinicore::nn::LongRopeConfig>(
std::move(short_factor),
std::move(long_factor),
original_max_position_embeddings,
factor);
}

// Future scaling creators go here (e.g., create_llama3, create_linear)

} // anonymous namespace

// Static self-registration block
// Registers creator functions into the factory registry upon program startup.
static bool _registered = []() {
auto &registry = get_scaling_registry();
registry["default"] = create_default_scaling;
registry["none"] = create_default_scaling;
registry["dynamic"] = create_default_scaling;
registry["longrope"] = create_longrope;
// add new scaling
// registry["llama3"] = create_llama3_scaling;
return true;
}();

} // namespace infinilm::layers::rotary_embedding
72 changes: 35 additions & 37 deletions csrc/layers/rotary_embedding/rotary_embedding.cpp
Original file line number Diff line number Diff line change
@@ -1,46 +1,44 @@
#include "rotary_embedding.hpp"
#include <algorithm> // std::clamp
#include <cmath> // std::llround
#include "../../config/model_config.hpp"
#include "rotary_embedding_factory.hpp"
#include <memory>
#include <string>
#include <unordered_map>

namespace infinilm::layers::rotary_embedding {
namespace {
thread_local std::unordered_map<std::string, std::shared_ptr<infinicore::nn::RoPE>> _ROPE_DICT;
} // namespace

size_t get_rotary_dim(size_t head_dim, double partial_rotary_factor) {
if (partial_rotary_factor <= 0.0 || partial_rotary_factor >= 1.0) {
return head_dim;
}

size_t rotary_dim = static_cast<size_t>(std::llround(
static_cast<double>(head_dim) * partial_rotary_factor));
rotary_dim = std::clamp(rotary_dim, static_cast<size_t>(2), head_dim);
// Cache dictionary to avoid redundant allocations of RoPE instances.
// thread_local ensures it is only visible within this compilation unit.
thread_local std::unordered_map<std::string, std::shared_ptr<infinicore::nn::RoPE>> _ROPE_DICT;

// RoPE operates on complex pairs, so the rotary dimension must be even
if (rotary_dim % 2 != 0) {
rotary_dim -= 1;
std::shared_ptr<infinicore::nn::RoPE>
get_rope(const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
const infinicore::Device &device,
infinicore::nn::RoPE::Algo algo) {

// 1. Compute the actual rotary dimension
size_t rotary_dim = model_config->get_rotary_dim();

// 2. Resolve scaling config via the internal factory
auto scaling = make_scaling_config(model_config);

// 3. Cache key must include rotary_dim AND the actual scaling type
// to avoid reusing the same RoPE instance across models with different settings
// (Enhancement: dynamically determine scaling_type instead of hardcoding "default")
std::string scaling_type_str = "default";
if (scaling) {
// Assuming we can get the type string from the JSON for cache key generation,
// or ideally, ScalingConfig should have a virtual std::string type_name() const method.
// Here we read it from JSON for the cache key purpose only, keeping it decoupled from InfiniCore.
const auto &rope_scaling_json = model_config->get_config_json()["rope_scaling"];
if (rope_scaling_json.contains("type")) {
scaling_type_str = rope_scaling_json["type"].get<std::string>();
} else if (rope_scaling_json.contains("rope_type")) {
scaling_type_str = rope_scaling_json["rope_type"].get<std::string>();
}
}
return std::max(rotary_dim, static_cast<size_t>(2));
}

std::shared_ptr<infinicore::nn::RoPE> get_rope(const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
const infinicore::Device &device,
infinicore::nn::RoPE::Algo algo) {
// 1. Get head dimension
size_t head_dim = model_config->get_head_dim();

// 2. Safely get partial_rotary_factor, defaulting to 1.0 (full rotation)
double partial_rotary_factor = model_config->get_or<double>("partial_rotary_factor", 1.0);

// 3. Compute the actual rotary dimension
size_t rotary_dim = get_rotary_dim(head_dim, partial_rotary_factor);

// 4. Cache key must include rotary_dim to avoid reusing the same RoPE instance
// across models with different partial_rotary_factor values
const std::string scaling_type = "default";
std::string cache_key = scaling_type + "_rope_dim_" + std::to_string(rotary_dim);
std::string cache_key = scaling_type_str + "_rope_dim_" + std::to_string(rotary_dim)
+ "_dev_" + device.toString();
auto it = _ROPE_DICT.find(cache_key);
if (it != _ROPE_DICT.end()) {
return it->second;
Expand All @@ -49,9 +47,9 @@ std::shared_ptr<infinicore::nn::RoPE> get_rope(const std::shared_ptr<infinilm::c
const auto &dtype = model_config->get_dtype();
size_t max_position_embeddings = model_config->get<size_t>("max_position_embeddings");
double rope_theta = model_config->get<double>("rope_theta");

auto rope = std::make_shared<infinicore::nn::RoPE>(rotary_dim, max_position_embeddings, rope_theta,
algo, dtype, device,
model_config->get_rope_scaling());
algo, dtype, device, scaling);

_ROPE_DICT.emplace(cache_key, rope);
return rope;
Expand Down
23 changes: 15 additions & 8 deletions csrc/layers/rotary_embedding/rotary_embedding.hpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
#pragma once

#include "../../config/model_config.hpp"
#include "infinicore/nn/rope.hpp"
#include <memory>

namespace infinilm::layers::rotary_embedding {
namespace infinilm::config {
class ModelConfig; // Forward declaration
}

// Compute the actual number of dimensions involved in rotary position embedding.
// For partial rotation, the dimension is clamped to [2, head_dim] and must be even.
size_t get_rotary_dim(size_t head_dim, double partial_rotary_factor);
namespace infinilm::layers::rotary_embedding {

std::shared_ptr<infinicore::nn::RoPE> get_rope(const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
const infinicore::Device &device,
infinicore::nn::RoPE::Algo algo = infinicore::nn::RoPE::Algo::GPT_NEOX);
/**
* @brief Public API to assemble and construct a complete RoPE module.
*
* @param model_config Model configuration.
* @param device Device to create the cache on.
* @param algo RoPE algorithm type (default: Algo::GPT_NEOX).
*/
std::shared_ptr<infinicore::nn::RoPE>
get_rope(const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
const infinicore::Device &device,
infinicore::nn::RoPE::Algo algo = infinicore::nn::RoPE::Algo::GPT_NEOX);

} // namespace infinilm::layers::rotary_embedding
42 changes: 42 additions & 0 deletions csrc/layers/rotary_embedding/rotary_embedding_factory.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "rotary_embedding_factory.hpp"
#include "../../config/model_config.hpp"
#include <stdexcept>

namespace infinilm::layers::rotary_embedding {

std::unordered_map<std::string, ScalingCreator> &get_scaling_registry() {
static std::unordered_map<std::string, ScalingCreator> registry;
return registry;
}

std::shared_ptr<infinicore::nn::RopeScalingConfig>
make_scaling_config(const std::shared_ptr<config::ModelConfig> &model_config) {
if (!model_config || !model_config->get_config_json().contains("rope_scaling") || model_config->get_config_json()["rope_scaling"].is_null()) {
return nullptr;
}

const auto &rope_scaling = model_config->get_config_json()["rope_scaling"];
if (!rope_scaling.is_object()) {
throw std::runtime_error("rope_scaling must be an object");
}

std::string scaling_type;
if (rope_scaling.contains("type")) {
scaling_type = rope_scaling["type"].get<std::string>();
} else if (rope_scaling.contains("rope_type")) {
scaling_type = rope_scaling["rope_type"].get<std::string>();
} else {
throw std::runtime_error("rope_scaling must contain 'type' or 'rope_type' field");
}

// Registry routing: delegate construction to the specific creator
auto &registry = get_scaling_registry();
auto it = registry.find(scaling_type);
if (it != registry.end()) {
return it->second(model_config);
}

throw std::runtime_error("Unsupported rope_scaling_type: " + scaling_type);
}

} // namespace infinilm::layers::rotary_embedding
35 changes: 35 additions & 0 deletions csrc/layers/rotary_embedding/rotary_embedding_factory.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include "infinicore/nn/rope.hpp"
#include "infinicore/nn/rope_scaling_configs.hpp"
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

namespace infinilm::config {
class ModelConfig; // Forward declaration
}

namespace infinilm::layers::rotary_embedding {

/**
* @brief Function pointer type for creating specific RopeScalingConfig instances.
* Implementations should extract parameters from ModelConfig and construct the corresponding Config object.
*/
using ScalingCreator = std::function<std::shared_ptr<infinicore::nn::RopeScalingConfig>(
const std::shared_ptr<infinilm::config::ModelConfig> &)>;

/**
* @brief Get the singleton registry mapping scaling type strings to their creator functions.
*/
std::unordered_map<std::string, ScalingCreator> &get_scaling_registry();

/**
* @brief Factory method to create a RopeScalingConfig based on the ModelConfig.
* Routes the "rope_scaling_type" string to the corresponding registered creator.
*/
std::shared_ptr<infinicore::nn::RopeScalingConfig>
make_scaling_config(const std::shared_ptr<infinilm::config::ModelConfig> &model_config);

} // namespace infinilm::layers::rotary_embedding
2 changes: 1 addition & 1 deletion csrc/models/glm4/glm4_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Glm4Attention::Glm4Attention(std::shared_ptr<infinilm::config::ModelConfig> mode
num_attention_heads_(model_config->get<size_t>("num_attention_heads")),
num_key_value_heads_(model_config->get<size_t>("num_key_value_heads")),
head_dim_(model_config->get_head_dim()),
rotary_dim_(infinilm::layers::rotary_embedding::get_rotary_dim(model_config->get_head_dim(), model_config->get_or<double>("partial_rotary_factor", 1.0))),
rotary_dim_(model_config->get_rotary_dim()),
use_bias_(model_config->get_or<bool>("attention_bias", true)),
use_output_bias_(model_config->get_or<bool>("attention_output_bias", false)) {

Expand Down
2 changes: 1 addition & 1 deletion csrc/models/llama_legacy/llama_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct LlamaConfig : public InfinilmModel::Config {
size_t max_position_embeddings = 2048; // Maximum sequence length
double rope_theta = 10000.0; // RoPE base frequency

std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> rope_scaling = nullptr; // RoPE scaling type
std::shared_ptr<infinicore::nn::RopeScalingConfig> rope_scaling = nullptr; // RoPE scaling type

// Normalization
double rms_norm_eps = 1e-6; // RMSNorm epsilon
Expand Down
Loading