Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 126 additions & 22 deletions csrc/config/model_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,34 +45,138 @@ ModelConfig::get_rope_scaling() const {
throw std::runtime_error("rope_scaling must contain 'type' or 'rope_type' field");
}

if (type_str == "longrope") {
// Required fields for LongRopeConfig
if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) {
throw std::runtime_error(
"LongRopeConfig requires 'short_factor', 'long_factor', and 'original_max_position_embeddings'");
// SGLang's style if-else routing in python/srt/layers/rotary_embedding.py:get_rope
if (type_str == "llama3") {
return createLlama3Scaling(rope_scaling);
} else if (type_str == "longrope") {
return createLongRopeScaling(rope_scaling);
} else if (type_str == "default") {
return createDefaultScaling(rope_scaling);
} else if (type_str == "none") {
return createNoneScaling(rope_scaling);
} else if (type_str == "dynamic") {
return createDynamicScaling(rope_scaling);
}

throw std::runtime_error("Unsupported rope_scaling type: " + type_str);
}

/// Factory for the rope_scaling type "default".
///
/// @param rope_scaling  The rope_scaling JSON object; it is not inspected here.
/// @return Always an empty (null) pointer: no scaling configuration object is created.
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig>
ModelConfig::createDefaultScaling(const nlohmann::json &rope_scaling) const {
    static_cast<void>(rope_scaling); // intentionally unused
    return {};
}

/// Factory for the rope_scaling type "none".
///
/// @param rope_scaling  The rope_scaling JSON object; it is not inspected here.
/// @return Always an empty (null) pointer: no scaling configuration object is created.
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig>
ModelConfig::createNoneScaling(const nlohmann::json &rope_scaling) const {
    static_cast<void>(rope_scaling); // intentionally unused
    return {};
}

/// Factory for the rope_scaling type "dynamic".
///
/// @param rope_scaling  The rope_scaling JSON object; it is not inspected here.
/// @return Always an empty (null) pointer: no scaling configuration object is created.
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig>
ModelConfig::createDynamicScaling(const nlohmann::json &rope_scaling) const {
    // TODO: dynamic scaling does not handle extended sequence lengths yet,
    // so for now this type yields no scaling configuration at all.
    static_cast<void>(rope_scaling); // intentionally unused
    return {};
}

/// Builds the scaling configuration for the rope_scaling type "llama3"
/// (Llama 3.1 frequency-aware RoPE scaling).
///
/// The per-dimension scale factors are pre-computed here and handed to a
/// LongRopeConfig, with identical short and long factor tables.
///
/// @param rope_scaling  The rope_scaling JSON object. Must contain "factor",
///                      "low_freq_factor", "high_freq_factor" and
///                      "original_max_position_embeddings".
/// @return A LongRopeConfig holding head_dim / 2 scale factors.
/// @throws std::runtime_error if a required rope_scaling key is missing, if
///         "hidden_size" or "num_attention_heads" is missing from the model
///         config, if "num_attention_heads" is zero, or if any of the three
///         scaling factors is not strictly positive. JSON type-conversion
///         errors raised by the JSON library are not caught here.
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig>
ModelConfig::createLlama3Scaling(const nlohmann::json &rope_scaling) const {
    // Native support for Llama 3.1 frequency-aware RoPE scaling
    // equivalent to SGLang/HuggingFace implementations

    // 1. Validate and extract Llama3 specific parameters
    const std::vector<std::string> required_keys = {
        "factor", "low_freq_factor", "high_freq_factor", "original_max_position_embeddings"};
    for (const auto &key : required_keys) {
        if (!rope_scaling.contains(key)) {
            throw std::runtime_error("Llama3RoPE requires '" + key + "' in rope_scaling");
        }
    }

    const double factor = rope_scaling["factor"].get<double>();
    const double low_freq_factor = rope_scaling["low_freq_factor"].get<double>();
    const double high_freq_factor = rope_scaling["high_freq_factor"].get<double>();
    const size_t orig_max_pos = rope_scaling["original_max_position_embeddings"].get<size_t>();

    // All three values are used as divisors below; a zero, negative or NaN value
    // would silently produce inf/NaN scale factors, so reject it up front.
    if (!(factor > 0.0) || !(low_freq_factor > 0.0) || !(high_freq_factor > 0.0)) {
        throw std::runtime_error(
            "Llama3RoPE requires 'factor', 'low_freq_factor' and 'high_freq_factor' to be positive");
    }

    // 2. Validate and extract model base parameters
    if (!config_json.contains("hidden_size") || !config_json.contains("num_attention_heads")) {
        throw std::runtime_error("Llama3RoPE requires 'hidden_size' and 'num_attention_heads' in config");
    }
    const size_t hidden_size = config_json["hidden_size"].get<size_t>();
    const size_t num_heads = config_json["num_attention_heads"].get<size_t>();
    if (num_heads == 0) {
        // Integer division by zero is undefined behaviour; fail with a clear message instead.
        throw std::runtime_error("Llama3RoPE requires 'num_attention_heads' to be non-zero");
    }
    const size_t head_dim = hidden_size / num_heads;
    const double theta = config_json.value("rope_theta", 10000.0);

    // 3. Pre-compute smooth factors based on wavelength
    constexpr double kPi = 3.14159265358979323846;
    const size_t cache_dim = head_dim / 2;
    const double low_freq_wavelen = static_cast<double>(orig_max_pos) / low_freq_factor;
    const double high_freq_wavelen = static_cast<double>(orig_max_pos) / high_freq_factor;

    // When both thresholds coincide there is no interpolation band; use a
    // denominator of 1.0 so the (unused) division stays well-defined.
    const bool has_smooth_range = (high_freq_factor != low_freq_factor);
    const double smooth_denom = has_smooth_range ? (high_freq_factor - low_freq_factor) : 1.0;

    std::vector<float> smooth_factors(cache_dim);
    for (size_t j = 0; j < cache_dim; ++j) {
        const double exponent = 2.0 * static_cast<double>(j) / static_cast<double>(head_dim);
        const double inv_freq = 1.0 / std::pow(theta, exponent);
        const double wavelen = 2.0 * kPi / inv_freq;

        if (wavelen < high_freq_wavelen) {
            // High frequency: no scaling (freq_scale = 1.0)
            smooth_factors[j] = 1.0f;
        } else if (wavelen > low_freq_wavelen) {
            // Low frequency: full scaling (freq_scale = 1.0 / factor)
            smooth_factors[j] = static_cast<float>(factor);
        } else {
            // Mid frequency: smooth frequency interpolation
            //
            // Equivalent to SGLang's implementation:
            //   smooth = (orig_max_pos / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            //   new_freq = (1 - smooth) * inv_freq / factor + smooth * inv_freq
            //   => freq_scale = (1 - smooth) / factor + smooth
            //
            // Since LongRopeConfig applies factors as wavelength multipliers
            // (i.e., new_freq = inv_freq / scale), the required smooth_factor
            // is the inverse of the frequency scale.
            double smooth = 0.0;
            if (has_smooth_range) {
                smooth = (static_cast<double>(orig_max_pos) / wavelen - low_freq_factor) / smooth_denom;
            }
            const double freq_scale = (1.0 - smooth) / factor + smooth;
            smooth_factors[j] = static_cast<float>(1.0 / freq_scale);
        }
    }

    // 4. Adapt to LongRopeConfig
    //    - short_factor and long_factor use the same smooth_factors
    //    - Pass factor=1.0f to bypass the amplitude scaling sqrt(log(...)) in LongRopeConfig constructor
    return std::make_shared<infinicore::nn::RoPE::LongRopeConfig>(
        smooth_factors, // short_factor
        smooth_factors, // long_factor
        orig_max_pos,
        1.0f // Force 1.0f to disable amplitude scaling
    );
}

/// Factory for the rope_scaling type "longrope".
///
/// @param rope_scaling  The rope_scaling JSON object. "short_factor",
///                      "long_factor" and "original_max_position_embeddings"
///                      are mandatory; "factor" is optional and defaults to 1.0f.
/// @return A LongRopeConfig built from the extracted values.
/// @throws std::runtime_error naming the first mandatory key that is absent.
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig>
ModelConfig::createLongRopeScaling(const nlohmann::json &rope_scaling) const {
    // Mandatory keys, checked in this order so the first missing one is reported.
    const std::string mandatory[] = {"short_factor", "long_factor", "original_max_position_embeddings"};
    for (const std::string &name : mandatory) {
        if (rope_scaling.contains(name)) {
            continue;
        }
        throw std::runtime_error("LongRopeConfig requires '" + name + "' in rope_scaling");
    }

    // All mandatory keys are known to exist at this point.
    std::vector<float> short_table = rope_scaling.at("short_factor").get<std::vector<float>>();
    std::vector<float> long_table = rope_scaling.at("long_factor").get<std::vector<float>>();
    const size_t original_positions = rope_scaling.at("original_max_position_embeddings").get<size_t>();

    // "factor" is optional; 1.0f is used when the key is absent.
    const float scale = rope_scaling.value("factor", 1.0f);

    // The tables are no longer needed locally, so hand them over by move.
    return std::make_shared<infinicore::nn::RoPE::LongRopeConfig>(
        std::move(short_table),
        std::move(long_table),
        original_positions,
        scale);
}

infinicore::DataType ModelConfig::get_dtype() const {
Expand Down
6 changes: 6 additions & 0 deletions csrc/config/model_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,5 +105,11 @@ class ModelConfig {
private:
nlohmann::json config_json;
QuantConfig quant_config;

std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> createLlama3Scaling(const nlohmann::json &rope_scaling) const;
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> createLongRopeScaling(const nlohmann::json &rope_scaling) const;
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> createDefaultScaling(const nlohmann::json &rope_scaling) const;
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> createNoneScaling(const nlohmann::json &rope_scaling) const;
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> createDynamicScaling(const nlohmann::json &rope_scaling) const;
};
} // namespace infinilm::config