Title: [Refactor] Decouple RoPE Scaling via Factory Pattern and Introduce TP-Safe RoPE Cache
Description:
Background & Motivation
The current RoPE instantiation logic in InfiniLM has two major architectural limitations:
- Tight Coupling & OCP Violation: The ModelConfig::get_rope_scaling() method parses JSON directly and hardcodes the instantiation of InfiniCore scaling objects (e.g., LongRopeConfig) in if-else branches. Supporting a new scaling type (such as Llama 3 or YaRN) therefore requires modifying ModelConfig itself.
- Memory Inefficiency: Every transformer layer instantiates its own RoPE object and allocates separate sin/cos caches on the GPU, even though layers on the same device share identical RoPE configurations. This leads to redundant VRAM usage.
Refactoring Direction
To address these issues, we are refactoring the RoPE initialization pipeline using the Factory Pattern with a Registry, and introducing a TP-safe caching mechanism for better memory efficiency.
Detailed Changes
1. Introduce Factory & Registry (rotary_embedding_factory.hpp/cpp, rope_scaling_creators.cpp)
- Factory: Added make_scaling_config(), which routes the rope_scaling_type string from JSON to the corresponding registered creator function.
- Registry & Creators: Added get_scaling_registry(), which maps type strings to creator functions (e.g., create_longrope). This isolates scaling-specific JSON parsing in independent functions and decouples ModelConfig from InfiniCore scaling implementations.
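The registry and factory described above can be sketched roughly as follows. This is a minimal illustration, not the actual InfiniLM code: ScalingParams (a flat map standing in for the parsed rope_scaling JSON), ScalingConfig (standing in for an InfiniCore scaling object), and every signature shown here are assumptions made for the example.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the parsed "rope_scaling" JSON object.
using ScalingParams = std::map<std::string, double>;

// Hypothetical stand-in for an InfiniCore scaling configuration object.
struct ScalingConfig {
    std::string type;
    double factor = 1.0;
};

using ScalingCreator =
    std::function<std::shared_ptr<ScalingConfig>(const ScalingParams &)>;

// Function-local static: the map is constructed on first use, so it does not
// depend on the initialization order of other global objects.
std::unordered_map<std::string, ScalingCreator> &get_scaling_registry() {
    static std::unordered_map<std::string, ScalingCreator> registry;
    return registry;
}

// Creator: all parsing for one scaling type lives in one function.
std::shared_ptr<ScalingConfig> create_longrope(const ScalingParams &params) {
    auto cfg = std::make_shared<ScalingConfig>();
    cfg->type = "longrope";
    auto it = params.find("factor");
    if (it != params.end()) {
        cfg->factor = it->second;
    }
    return cfg;
}

// Explicit registration, called once at start-up.
void init_scaling_registry() {
    get_scaling_registry()["longrope"] = create_longrope;
}

// Factory: looks up the creator for the given type string and invokes it.
std::shared_ptr<ScalingConfig> make_scaling_config(const std::string &type,
                                                   const ScalingParams &params) {
    const auto &registry = get_scaling_registry();
    auto it = registry.find(type);
    if (it == registry.end()) {
        throw std::invalid_argument("unsupported rope_scaling_type: " + type);
    }
    return it->second(params);
}
```

With this shape, supporting a new scaling type means writing one creator function and adding one registration line; the factory and the config class stay untouched.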
2. Introduce TP-Safe RoPE Cache (rotary_embedding.cpp)
- Performance Optimization: Introduced a _ROPE_DICT cache to store constructed RoPE objects. Since multiple attention layers share the same RoPE parameters, the cache ensures that only one set of sin/cos tables is allocated per unique configuration, which removes the per-layer duplication in GPU memory.
- TP-Safety Design: InfiniLM implements Tensor Parallelism using multi-threading. To make the cache safe and correct under --tp>1:
  - The cache is declared as thread_local, so each thread has its own cache instance. This prevents data races during concurrent writes from different TP ranks without requiring heavy mutexes.
  - The cache_key includes the device.id() (e.g., ..._dev_cuda:0), ensuring that each TP rank constructs and uses RoPE objects only on its own device, avoiding cross-device memory access errors.
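A rough sketch of the caching scheme, under stated assumptions: Device, RoPE, make_cache_key(), and get_or_create_rope() are hypothetical names invented for this example, and the key layout only imitates the ..._dev_cuda:0 suffix mentioned above. A real key would also have to encode the scaling parameters.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a device handle.
struct Device {
    std::string type;
    int index;
    std::string id() const { return type + ":" + std::to_string(index); }
};

// Hypothetical stand-in for the RoPE module; the real object would allocate
// its sin/cos tables on `device` at construction time.
struct RoPE {
    Device device;
};

// The key combines the RoPE parameters with the device id, so objects built
// for different devices can never be confused with each other.
std::string make_cache_key(std::size_t rotary_dim, std::size_t max_positions,
                           double theta, const Device &device) {
    return "rope_" + std::to_string(rotary_dim) + "_" +
           std::to_string(max_positions) + "_" + std::to_string(theta) +
           "_dev_" + device.id();
}

std::shared_ptr<RoPE> get_or_create_rope(std::size_t rotary_dim,
                                         std::size_t max_positions,
                                         double theta, const Device &device) {
    // thread_local: each thread owns a separate map, so no locking is needed.
    thread_local std::unordered_map<std::string, std::shared_ptr<RoPE>> rope_cache;

    const std::string key = make_cache_key(rotary_dim, max_positions, theta, device);
    auto it = rope_cache.find(key);
    if (it != rope_cache.end()) {
        return it->second; // reuse: no second set of sin/cos tables
    }
    auto rope = std::make_shared<RoPE>();
    rope->device = device;
    rope_cache.emplace(key, rope);
    return rope;
}
```

In this sketch, every layer running on the same thread with the same parameters and device receives the same shared RoPE instance, while a different device id always yields a separate object.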
3. Clean up ModelConfig (model_config.hpp/cpp)
- Removed: ModelConfig::get_rope_scaling(). ModelConfig now acts purely as a data provider, unaware of InfiniCore scaling objects.
- Added: ModelConfig::get_rotary_dim(). Centralizes the logic for calculating rotary dimensions (handling partial_rotary_factor, clamping, and ensuring even dimensions) in ModelConfig, replacing the old standalone helper function.
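One plausible shape for the rotary-dimension calculation is shown below as a free function. The exact rules are assumptions: this sketch clamps partial_rotary_factor to [0, 1], truncates the product, and rounds down to the nearest even number, which may differ from what ModelConfig::get_rotary_dim() actually does.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Assumed semantics: rotary_dim = head_dim * partial_rotary_factor,
// clamped to [0, head_dim] and rounded down to an even value, because
// RoPE rotates dimensions in pairs.
std::size_t get_rotary_dim(std::size_t head_dim, double partial_rotary_factor) {
    // Clamp the factor so the result can never exceed head_dim.
    const double factor = std::max(0.0, std::min(partial_rotary_factor, 1.0));
    std::size_t dim =
        static_cast<std::size_t>(static_cast<double>(head_dim) * factor);
    dim = std::min(dim, head_dim);
    // Drop the last dimension if the count is odd.
    return dim - (dim % 2);
}
```

For example, under these assumptions a head_dim of 128 with a factor of 0.5 gives 64, and a head_dim of 10 with a factor of 0.5 gives 4 rather than 5.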
Action Items / TODOs
- Ensure init_scaling_registry() is explicitly called in the Python C extension entry point (PyInit__infinilm).
- Implement create_llama3 in rope_scaling_creators.cpp and register it to support Llama-3/3.1 models.
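For orientation on the create_llama3 item, the sketch below shows the frequency adjustment used by the Llama 3.1 scaling scheme as it is implemented in Hugging Face Transformers. It is not InfiniLM or InfiniCore code. The field names mirror the rope_scaling keys in Llama 3.1 config files, while Llama3ScalingParams and scale_inv_freq_llama3() are hypothetical names for this example.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Field names follow the "rope_scaling" keys of Llama 3.1 configs;
// the struct itself is a hypothetical container for this sketch.
struct Llama3ScalingParams {
    double factor;
    double low_freq_factor;
    double high_freq_factor;
    double original_max_position_embeddings;
};

// Rescales each inverse frequency according to its wavelength:
// short wavelengths are kept, long ones are divided by `factor`,
// and the band in between is interpolated smoothly.
std::vector<double> scale_inv_freq_llama3(const std::vector<double> &inv_freq,
                                          const Llama3ScalingParams &p) {
    assert(p.high_freq_factor > p.low_freq_factor);
    const double pi = std::acos(-1.0);
    const double low_freq_wavelen =
        p.original_max_position_embeddings / p.low_freq_factor;
    const double high_freq_wavelen =
        p.original_max_position_embeddings / p.high_freq_factor;

    std::vector<double> scaled;
    scaled.reserve(inv_freq.size());
    for (double f : inv_freq) {
        const double wavelen = 2.0 * pi / f;
        if (wavelen < high_freq_wavelen) {
            scaled.push_back(f); // high frequency: unchanged
        } else if (wavelen > low_freq_wavelen) {
            scaled.push_back(f / p.factor); // low frequency: fully scaled
        } else {
            // Medium band: blend between the scaled and unscaled frequency.
            const double smooth =
                (p.original_max_position_embeddings / wavelen - p.low_freq_factor) /
                (p.high_freq_factor - p.low_freq_factor);
            scaled.push_back((1.0 - smooth) * f / p.factor + smooth * f);
        }
    }
    return scaled;
}
```

A creator registered under the llama3 type string would presumably read these four values from JSON and hand them to the corresponding InfiniCore scaling object.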