
Add HF multi-GPU auto device mode for GraniteSwitchForCausalLM #23

@antonpibm

Description

Context

When loading GraniteSwitchForCausalLM with device_map="auto" on a multi-GPU machine (required by BitsAndBytes quantization and for parallelization), accelerate raises:

ValueError: The device_map provided does not give any device for the following parameters:
model.adapter_token_ids, model.token_to_group_mask, model.adapter_hiding_matrix

Why accelerate is involved

transformers uses the accelerate library internally whenever device_map="auto" is passed to from_pretrained(). The call chain is:

AutoModelForCausalLM.from_pretrained(..., device_map="auto")
  → transformers calls accelerate.infer_auto_device_map()
  → accelerate enumerates model.state_dict().keys()
  → accelerate assigns each key to a device
  → fails on buffer keys it can't place

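The failure mode in the last step can be illustrated with a simplified stand-in for accelerate's assignment loop (this is a sketch, not accelerate's actual code): every `state_dict()` key must fall under some module prefix that the auto-partitioner assigned to a device, and the top-level buffers have no such prefix.

```python
# Simplified stand-in for accelerate's device-map assignment (sketch, not the
# real implementation): each state_dict key is placed by matching it against
# a module-prefix -> device map.
def assign_devices(state_dict_keys, module_device_map):
    placement, unassigned = {}, []
    for key in state_dict_keys:
        for prefix, device in module_device_map.items():
            if key.startswith(prefix + "."):
                placement[key] = device
                break
        else:
            unassigned.append(key)
    if unassigned:
        raise ValueError(
            "The device_map provided does not give any device for the "
            "following parameters: " + ", ".join(unassigned)
        )
    return placement

# Decoder layers get partitioned across GPUs, but the buffers live directly
# on `model`, which has no entry of its own in the map.
keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",
    "model.adapter_token_ids",  # top-level buffer -> no prefix match
]
device_map = {"model.layers.0": "cuda:0", "model.layers.1": "cuda:1"}
try:
    assign_devices(keys, device_map)
except ValueError as e:
    print(e)  # names model.adapter_token_ids, mirroring the real traceback
```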
BitsAndBytes quantization requires device_map="auto" (transformers enforces this), which is why the error surfaces in the quantization test. However, this bug would also appear in any HF multi-GPU or CPU-offloading scenario that uses device_map.

Root cause

The three tensors are registered as persistent buffers on GraniteSwitchModel. Because they appear in state_dict(), accelerate's device-map logic tries to assign them a device — but fails because they're top-level model buffers not covered by the auto-partitioning heuristic (only modules in _no_split_modules get properly placed).

Suggested Fix: Make buffers non-persistent

These buffers are derived entirely from config at __init__ time (they are not learned weights). Registering them with persistent=False means:

  • They no longer appear in state_dict() → accelerate ignores them
  • They still follow .to(device) (PyTorch guarantees this for non-persistent buffers)
  • They're reconstructed from config every time the model is instantiated, so no information is lost

This matches the vLLM backend implementation, which already uses persistent=False for similar metadata buffers (lora_kernel_meta.py).

Changes

1. HF model — src/granite_switch/hf/modeling_granite_switch.py

Add persistent=False to all 4 register_buffer calls (lines 180, 186, 211, 216):

  • adapter_token_ids (both branches)
  • token_to_group_mask
  • adapter_hiding_matrix

Update the comment at line 170 to mention non-persistence.

2. HF base class — same file, line ~146

Add _keys_to_ignore_on_load_unexpected to GraniteSwitchPreTrainedModel so old checkpoints that still contain these keys don't produce warnings:

_keys_to_ignore_on_load_unexpected = [
    r"model\.adapter_token_ids",
    r"model\.token_to_group_mask",
    r"model\.adapter_hiding_matrix",
]
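As a sanity check, the patterns should cover exactly the legacy buffer keys and nothing else (transformers applies these patterns to unexpected checkpoint keys via regex search; the weight key below is a hypothetical example of a key that must not be ignored):

```python
import re

# Patterns proposed in this change.
patterns = [
    r"model\.adapter_token_ids",
    r"model\.token_to_group_mask",
    r"model\.adapter_hiding_matrix",
]

# Keys an old checkpoint might contain; the last one stands in for a real
# weight that must survive filtering.
old_checkpoint_keys = [
    "model.adapter_token_ids",
    "model.token_to_group_mask",
    "model.adapter_hiding_matrix",
    "model.layers.0.self_attn.q_proj.weight",
]

ignored = [k for k in old_checkpoint_keys if any(re.search(p, k) for p in patterns)]
print(ignored)  # only the three buffer keys
```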

Minimal reproduction

Run this code on a pod with multiple GPUs:

import torch
from transformers import AutoModelForCausalLM
import granite_switch.hf  # noqa: F401 — registers the model with AutoModel

MODEL_ID = "ibm-granite/granite-switch-4.1-3b-preview"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto",
)
# ValueError: The device_map provided does not give any device for the following
# parameters: model.adapter_token_ids, model.token_to_group_mask,
# model.adapter_hiding_matrix

Reproduction via test suite

The following test fails with device_map="auto":

pytest "tests/hf/test_quantization.py::TestBnBAdapterActivation::test_adapter_activates[hallucination_detection(lora)]" -v -s --tb=long

When forced to a single GPU (bypassing accelerate's device_map logic), the same test passes:

CUDA_VISIBLE_DEVICES=0 pytest "tests/hf/test_quantization.py::TestBnBAdapterActivation::test_adapter_activates[hallucination_detection(lora)]" -v -s --tb=long

This confirms the issue is specifically in accelerate's multi-device buffer placement, not in the model logic itself.

Verification (after fix)

After the fix, all tests pass with the following command on a multi-GPU machine:

pytest tests/hf/ -v --tb=short -x -n auto
