issue/401 refactor(rope): add scaling factory and TP-safe RoPE cache#402
Open
rubik-hua wants to merge 1 commit into
Open
issue/401 refactor(rope): add scaling factory and TP-safe RoPE cache#402rubik-hua wants to merge 1 commit into
rubik-hua wants to merge 1 commit into
Conversation
- Decouple scaling config instantiation from ModelConfig via factory and registry pattern. - Add thread-local RoPE cache with device-scoped keys to reduce VRAM usage and ensure TP safety. - Centralize rotary dimension calculation into ModelConfig.
Author
|
@wooway777 @pengcheng888 rope重构可以帮忙检视起来了,infinicore上也有一个pr |
Collaborator
谢谢老师,在看了 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
ModelConfig 扩展(纯粹的数据承载)
ModelConfig 不再掺杂任何具体模型的业务判断,仅提供默认值和读写接口。RoPE::Algo 的差异由具体的模型构建入口(如 csrc/models/chatglm/chatglm_for_causal_lm.cpp)显式指定:
// model_config.hpp
class ModelConfig {
private:
infinicore::nn::RoPE::Algo rope_algo_ = infinicore::nn::RoPE::Algo::GPT_NEOX; // 默认值
public:
infinicore::nn::RoPE::Algo get_rope_algo() const { return rope_algo_; }
};
// csrc/models/chatglm/chatglm_for_causal_lm.cpp
std::shared_ptr create_chatglm_config(const json& hf_config) {
auto config = std::make_shared(hf_config);
// 只有 ChatGLM/GLM4 需要 GPT_J,在此处显式注入,不污染基类
config->set_rope_algo(infinicore::nn::RoPE::Algo::GPT_J);
return config;
}
工厂与注册表机制(字符串路由分发)
引入注册表模式,将 JSON 中的字符串(如 "longrope")映射到具体的对象构造逻辑,替代冗长的 if-else。
// rotary_embedding_factory.hpp
using ScalingCreator = std::function<std::shared_ptrinfinicore::nn::ScalingConfig(
const std::shared_ptrinfinilm::config::ModelConfig&)>;
std::unordered_map<std::string, ScalingCreator>& get_scaling_registry();
std::shared_ptrinfinicore::nn::RoPE make_rope(/* ... */);
工厂核心实现极简,仅负责组装与路由,不因为新增类型而修改:
// rotary_embedding_factory.cpp
std::shared_ptrinfinicore::nn::ScalingConfig
make_scaling_config(const std::shared_ptrinfinilm::config::ModelConfig& model_config) {
std::string scaling_type = model_config->get_orstd::string("rope_scaling_type", "default");
}
需要与InfiniCore的下面PR一起合入,
InfiniTensor/InfiniCore#1181
重构后,新增的rope实现都集中在csrc/layers/rotary_embedding/rope_scaling_creators.cpp增加,其它地方无需修改,跟rotary_embedding.cpp和model_config.cpp解耦掉。
之前@pengcheng888 给的建议是把algo参数收编进model_config中,然后在xx_for_causal_lm.cpp 中写入,我实现了一版,但感觉特别别扭,我理解model_config还是纯粹一点好,能从json中读出来或者加工出来的。后来,我又改动了一下,还是直接放到运行时传参更加优雅吧。
重构后所有现有支持的模型已经跑通