Context
When loading GraniteSwitchForCausalLM on a multi-GPU machine with device_map="auto" (required by BitsAndBytes quantization and for model parallelization), accelerate raises:
ValueError: The device_map provided does not give any device for the following parameters:
model.adapter_token_ids, model.token_to_group_mask, model.adapter_hiding_matrix
Why accelerate is involved
transformers uses the accelerate library internally whenever device_map="auto" is passed to from_pretrained(). The call chain is:
AutoModelForCausalLM.from_pretrained(..., device_map="auto")
→ transformers calls accelerate.infer_auto_device_map()
→ accelerate enumerates model.state_dict().keys()
→ accelerate assigns each key to a device
→ fails on buffer keys it can't place
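For debugging, the same inference step can be run by hand. This is a hedged sketch, assuming the model ID from the reproduction below and arbitrary per-device memory limits chosen so the model cannot sit on a single device:
import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
import granite_switch.hf  # noqa: F401  registers the model with AutoModel

MODEL_ID = "ibm-granite/granite-switch-4.1-3b-preview"
config = AutoConfig.from_pretrained(MODEL_ID)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Hypothetical limits; small enough to force a split across the two GPUs.
device_map = infer_auto_device_map(
    model, max_memory={0: "8GiB", 1: "8GiB"}, dtype=torch.bfloat16
)

for name in (
    "model.adapter_token_ids",
    "model.token_to_group_mask",
    "model.adapter_hiding_matrix",
):
    covered = any(k == "" or name == k or name.startswith(k + ".") for k in device_map)
    print(name, "covered" if covered else "NOT covered by the inferred map")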
BitsAndBytes quantization requires device_map="auto" (transformers enforces this), which is why the error surfaces in the quantization test. However, this bug would also appear in any HF multi-GPU or CPU-offloading scenario that uses device_map.
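For reference, the quantized loading path looks roughly like this (a sketch; the specific BitsAndBytesConfig options are assumptions, not necessarily what the test suite uses):
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import granite_switch.hf  # noqa: F401

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# BitsAndBytes loading goes through device_map="auto", so accelerate's
# placement logic always runs on this path.
model = AutoModelForCausalLM.from_pretrained(
    "ibm-granite/granite-switch-4.1-3b-preview",
    quantization_config=bnb_config,
    device_map="auto",
)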
Root cause
The three tensors are registered as persistent buffers on GraniteSwitchModel. Because they appear in state_dict(), accelerate's device-map logic tries to assign them a device — but fails because they're top-level model buffers not covered by the auto-partitioning heuristic (only modules in _no_split_modules get properly placed).
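The PyTorch side of this is plain register_buffer behavior, shown here with a toy module (names are illustrative, not the real class):
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # persistent defaults to True, so the buffer key lands in state_dict()
        self.register_buffer("token_mask", torch.ones(4, dtype=torch.bool))

print(list(Toy().state_dict().keys()))
# ['linear.weight', 'linear.bias', 'token_mask']; every key needs a device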
Suggested Fix: Make buffers non-persistent
These buffers are derived entirely from config at __init__ time (not learned weights). Making them persistent=False means:
- They no longer appear in state_dict(), so accelerate ignores them
- They still follow .to(device) (PyTorch guarantees this for non-persistent buffers)
- They're reconstructed from config every time the model is instantiated, so no information is lost
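A minimal sketch of the first two properties, using an illustrative module rather than the real GraniteSwitchModel:
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer("token_mask", torch.ones(4, dtype=torch.bool), persistent=False)

m = Toy()
print("token_mask" in m.state_dict())  # False: accelerate never sees this key
if torch.cuda.is_available():
    m.to("cuda:0")
    print(m.token_mask.device)  # cuda:0, the non-persistent buffer still follows .to()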
The vLLM backend already uses persistent=False for similar metadata buffers (lora_kernel_meta.py), so this change mirrors an existing pattern.
Changes
1. HF model — src/granite_switch/hf/modeling_granite_switch.py
Add persistent=False to all 4 register_buffer calls (lines 180, 186, 211, 216):
- adapter_token_ids (both branches)
- token_to_group_mask
- adapter_hiding_matrix
Update the comment at line 170 to mention non-persistence.
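The surrounding __init__ code isn't reproduced in this issue, so here is a simplified sketch of the registration pattern after the change (the constructor arguments stand in for the tensors that the real code builds from config):
import torch.nn as nn

# Simplified stand-in for GraniteSwitchModel; the real __init__ derives these
# tensors from config before registering them.
class GraniteSwitchModelSketch(nn.Module):
    def __init__(self, adapter_token_ids, token_to_group_mask, adapter_hiding_matrix):
        super().__init__()
        # persistent=False keeps the derived tensors out of state_dict(), so the
        # device_map never has to cover them, while .to(device) still moves them.
        self.register_buffer("adapter_token_ids", adapter_token_ids, persistent=False)
        self.register_buffer("token_to_group_mask", token_to_group_mask, persistent=False)
        self.register_buffer("adapter_hiding_matrix", adapter_hiding_matrix, persistent=False)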
2. HF base class — same file, line ~146
Add _keys_to_ignore_on_load_unexpected to GraniteSwitchPreTrainedModel so old checkpoints that still contain these keys don't produce warnings:
_keys_to_ignore_on_load_unexpected = [
r"model\.adapter_token_ids",
r"model\.token_to_group_mask",
r"model\.adapter_hiding_matrix",
]
Minimal reproduction
Run this code on a pod with multiple GPUs:
import torch
from transformers import AutoModelForCausalLM
import granite_switch.hf # noqa: F401 — registers the model with AutoModel
MODEL_ID = "ibm-granite/granite-switch-4.1-3b-preview"
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
dtype=torch.bfloat16,
device_map="auto",
)
# ValueError: The device_map provided does not give any device for the following
# parameters: model.adapter_token_ids, model.token_to_group_mask,
# model.adapter_hiding_matrix
Reproduction via test suite
The following test fails with device_map="auto":
pytest "tests/hf/test_quantization.py::TestBnBAdapterActivation::test_adapter_activates[hallucination_detection(lora)]" -v -s --tb=long
When restricted to a single GPU (so accelerate does not need to split the model across devices), the same test passes:
CUDA_VISIBLE_DEVICES=0 pytest "tests/hf/test_quantization.py::TestBnBAdapterActivation::test_adapter_activates[hallucination_detection(lora)]" -v -s --tb=long
This confirms the issue is specifically in accelerate's multi-device buffer placement, not in the model logic itself.
Verification (after fix)
All tests pass with the following command on a multi-GPU machine:
pytest tests/hf/ -v --tb=short -x -n auto