
[Refactor] Decouple RoPE Scaling via Factory Pattern and Introduce TP-Safe RoPE Cache #401

@rubik-hua

Description:

Background & Motivation

The current RoPE instantiation logic in InfiniLM has two major architectural limitations:

  1. Tight Coupling & OCP Violation: ModelConfig::get_rope_scaling() parses the JSON directly and hardcodes the instantiation of InfiniCore scaling objects (e.g., LongRopeConfig) in if-else branches. This violates the Open/Closed Principle: supporting a new scaling type (such as Llama 3 or YaRN) requires modifying ModelConfig itself.
  2. Memory Inefficiency: Every transformer layer instantiates its own RoPE object and allocates separate sin/cos caches on the GPU, even though all layers on the same device share an identical RoPE configuration. Each device therefore holds one copy of the same tables per layer, wasting VRAM.

Refactoring Direction

To address these issues, we are refactoring the RoPE initialization pipeline using the Factory Pattern with a Registry, and introducing a TP-safe caching mechanism for better memory efficiency.

Detailed Changes

1. Introduce Factory & Registry (rotary_embedding_factory.hpp/cpp, rope_scaling_creators.cpp)

  • Factory: Added make_scaling_config(), which routes the rope_scaling_type string from JSON to the corresponding registered creator function.
  • Registry & Creators: Added get_scaling_registry(), which maps type strings to creator functions (e.g., create_longrope). Scaling-specific JSON parsing now lives in independent creator functions, fully decoupling ModelConfig from InfiniCore scaling implementations.
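A minimal sketch of the factory/registry split described above. It assumes a std::map stand-in for the parsed "rope_scaling" JSON (the real code presumably uses a JSON library) and hypothetical ScalingConfig/LongRopeConfig types; the function names come from this issue, but their signatures here are guesses, not the actual InfiniLM API.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the parsed "rope_scaling" JSON object.
using JsonObject = std::map<std::string, std::vector<double>>;

// Hypothetical base class standing in for InfiniCore scaling configs.
struct ScalingConfig {
    virtual ~ScalingConfig() = default;
    virtual std::string type() const = 0;
};

// Hypothetical LongRoPE config (fields are illustrative).
struct LongRopeConfig : ScalingConfig {
    std::vector<double> short_factor;
    std::vector<double> long_factor;
    std::string type() const override { return "longrope"; }
};

using ScalingCreator =
    std::function<std::shared_ptr<ScalingConfig>(const JsonObject &)>;

// Function-local static: constructed on first use, so it is safe to call
// from other static initializers regardless of translation-unit order.
std::unordered_map<std::string, ScalingCreator> &get_scaling_registry() {
    static std::unordered_map<std::string, ScalingCreator> registry;
    return registry;
}

// Creator: all LongRoPE-specific JSON parsing lives here, not in ModelConfig.
std::shared_ptr<ScalingConfig> create_longrope(const JsonObject &j) {
    auto cfg = std::make_shared<LongRopeConfig>();
    cfg->short_factor = j.at("short_factor");
    cfg->long_factor = j.at("long_factor");
    return cfg;
}

// Supporting a new type means adding one line here, not editing ModelConfig.
void init_scaling_registry() {
    get_scaling_registry()["longrope"] = create_longrope;
}

// Factory: route rope_scaling_type to its creator; empty type means no scaling.
std::shared_ptr<ScalingConfig> make_scaling_config(const std::string &type,
                                                   const JsonObject &j) {
    if (type.empty()) {
        return nullptr;
    }
    const auto &registry = get_scaling_registry();
    auto it = registry.find(type);
    if (it == registry.end()) {
        throw std::invalid_argument("unsupported rope_scaling type: " + type);
    }
    return it->second(j);
}
```

With this shape, an unknown type fails loudly in make_scaling_config() instead of silently falling through an if-else chain, and ModelConfig only has to pass along the type string and the raw JSON.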

2. Introduce TP-Safe RoPE Cache (rotary_embedding.cpp)

  • Performance Optimization: Introduced a _ROPE_DICT cache that stores constructed RoPE objects. Because attention layers on the same device share the same RoPE parameters, the cache allocates one set of sin/cos tables per unique configuration per device instead of one per layer, removing the redundant per-layer copies from GPU VRAM.
  • TP-Safety Design: InfiniLM implements Tensor Parallelism using multi-threading (one thread per TP rank). To keep the cache safe and correct under --tp > 1:
    • The cache is declared thread_local, so each TP rank's thread owns a private copy; concurrent inserts from different ranks cannot race, and no mutex is needed.
    • The cache_key includes the device.id() (e.g., ..._dev_cuda:0), ensuring that different TP ranks strictly construct and use RoPE objects on their respective devices, avoiding cross-device memory access errors.
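A minimal sketch of such a cache, assuming hypothetical Device and RoPE types and a hypothetical get_rope() helper; the key layout is illustrative apart from the "_dev_" + device id suffix mentioned above. The map is renamed rope_dict because identifiers starting with an underscore and an uppercase letter are reserved in C++.

```cpp
#include <iomanip>
#include <map>
#include <memory>
#include <sstream>
#include <string>

// Hypothetical device handle; id() mimics the "cuda:0" form used in the key.
struct Device {
    std::string type;
    int index;
    std::string id() const { return type + ":" + std::to_string(index); }
};

// Hypothetical RoPE module owning the sin/cos tables for one device.
struct RoPE {
    RoPE(int head_dim, int max_pos, double theta, Device device)
        : head_dim(head_dim), max_pos(max_pos), theta(theta), device(device) {
        // A real implementation would allocate and fill the sin/cos tables
        // on `device` here; this sketch only records the parameters.
    }
    int head_dim;
    int max_pos;
    double theta;
    Device device;
};

// Plays the role of _ROPE_DICT. thread_local gives every TP-rank thread its
// own map, so lookups and inserts need no mutex.
thread_local std::map<std::string, std::shared_ptr<RoPE>> rope_dict;

// Returns the shared RoPE for this (config, device), building it on first use.
std::shared_ptr<RoPE> get_rope(int head_dim, int max_pos, double theta,
                               const std::string &scaling_type,
                               const Device &device) {
    std::ostringstream key;
    // A real key would also encode the scaling parameters, not just the type.
    key << "rope_" << head_dim << "_" << max_pos << "_"
        << std::setprecision(17) << theta << "_" << scaling_type
        << "_dev_" << device.id();
    auto it = rope_dict.find(key.str());
    if (it != rope_dict.end()) {
        return it->second;  // cache hit: reuse the existing tables
    }
    auto rope = std::make_shared<RoPE>(head_dim, max_pos, theta, device);
    rope_dict.emplace(key.str(), rope);
    return rope;
}
```

Because the map is thread_local, its entries (and any device memory they own) are released when the owning thread exits. The device id in the key additionally keeps entries apart if a single thread ever builds RoPE objects for more than one device.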

3. Clean up ModelConfig (model_config.hpp/cpp)

  • Removed: ModelConfig::get_rope_scaling(). ModelConfig now acts purely as a data provider, unaware of InfiniCore scaling objects.
  • Added: ModelConfig::get_rotary_dim(). Centralized the logic for calculating rotary dimensions (handling partial_rotary_factor, clamping, and ensuring even dimensions) into ModelConfig, removing the old standalone helper function.
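A minimal sketch of the rotary-dimension calculation as a free function. The exact rounding mode (floor) and clamping bounds ([0, head_dim]) are assumptions; the real ModelConfig::get_rotary_dim() may differ in both.

```cpp
#include <algorithm>
#include <cmath>

// Hypothetical free-function version of ModelConfig::get_rotary_dim().
// The rounding mode and clamping bounds are assumptions.
int get_rotary_dim(int head_dim, double partial_rotary_factor) {
    // Scale the head dimension, e.g. 80 * 0.4 -> 32 for a partial-rotary model.
    int dim = static_cast<int>(std::floor(head_dim * partial_rotary_factor));
    // Clamp so a factor > 1 (or < 0) cannot produce an out-of-range dimension.
    dim = std::clamp(dim, 0, head_dim);
    // RoPE rotates dimensions in pairs, so round down to an even number.
    return dim - dim % 2;
}
```

Keeping this in one place means every attention layer and the RoPE cache key derive the same rotary dimension from the same config fields.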

Action Items / TODOs

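  • Ensure init_scaling_registry() is explicitly called in the Python C extension entry point (PyInit__infinilm) before any model config is loaded, since make_scaling_config() can only find creators that have already been registered.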
  • Implement create_llama3 in rope_scaling_creators.cpp and register it to support Llama-3/3.1 models.
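The create_llama3 TODO could look roughly like the sketch below. It assumes a hypothetical registry (a function-local static map from type string to creator), a numeric-only JSON stand-in, and a hypothetical Llama3Config type. The key names follow the rope_scaling block in published Llama-3.1 config.json files (where rope_type is "llama3"). The sketch only parses parameters; it does not implement the Llama-3 frequency adjustment.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a parsed "rope_scaling" object with numeric fields.
using JsonObject = std::map<std::string, double>;

struct ScalingConfig {
    virtual ~ScalingConfig() = default;
};

// Hypothetical config whose fields mirror the Llama-3.1 rope_scaling keys.
struct Llama3Config : ScalingConfig {
    double factor = 1.0;
    double low_freq_factor = 1.0;
    double high_freq_factor = 1.0;
    long original_max_position_embeddings = 0;
};

using ScalingCreator =
    std::function<std::shared_ptr<ScalingConfig>(const JsonObject &)>;

std::unordered_map<std::string, ScalingCreator> &get_scaling_registry() {
    static std::unordered_map<std::string, ScalingCreator> registry;
    return registry;
}

// Proposed creator: parses only the Llama-3 specific keys.
std::shared_ptr<ScalingConfig> create_llama3(const JsonObject &j) {
    auto cfg = std::make_shared<Llama3Config>();
    cfg->factor = j.at("factor");
    cfg->low_freq_factor = j.at("low_freq_factor");
    cfg->high_freq_factor = j.at("high_freq_factor");
    cfg->original_max_position_embeddings =
        static_cast<long>(j.at("original_max_position_embeddings"));
    return cfg;
}

// Must run once before any model config is parsed, e.g. at the start of the
// extension's module init (PyInit__infinilm); otherwise every lookup misses.
void init_scaling_registry() {
    get_scaling_registry()["llama3"] = create_llama3;
}
```

Until init_scaling_registry() runs, the registry is empty, which is why the first TODO matters: a model loaded before registration would fail to resolve its scaling type.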
