From aed0e64a4df4ee35349890df6e669976b6fc3dda Mon Sep 17 00:00:00 2001 From: robbiebusinessacc <65429016+robbiebusinessacc@users.noreply.github.com> Date: Wed, 10 Jun 2026 18:03:31 -0400 Subject: [PATCH] fix: decode str_tokens per token for transformers >= 5 On transformers >= 5, batch_decode treats a 1-D tensor as a single sequence and returns one joined string instead of per-token strings. This silently broke token highlighting in the explainer and classifier scorers and raised IndexError in the intruder scorer. Add decode_per_token to delphi.utils and use it at the four sites that need per-token strings (samplers, non-activating constructor, OpenAI simulator). Fixes #176 --- delphi/latents/constructors.py | 3 +- delphi/latents/samplers.py | 5 +-- .../simulator/simulation/oai_simulator.py | 3 +- delphi/utils.py | 20 ++++++++++++ tests/test_utils.py | 31 +++++++++++++++++++ 5 files changed, 58 insertions(+), 4 deletions(-) create mode 100644 tests/test_utils.py diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index a65bcb32..25893a36 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -12,6 +12,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from delphi import logger +from delphi.utils import decode_per_token from ..config import ConstructorConfig from .latents import ( @@ -50,7 +51,7 @@ def prepare_non_activating_examples( tokens=toks, activations=acts, distance=distance, - str_tokens=tokenizer.batch_decode(toks), + str_tokens=decode_per_token(tokenizer, toks), ) for toks, acts in zip(tokens, activations) ] diff --git a/delphi/latents/samplers.py b/delphi/latents/samplers.py index 3fa8d64e..657dd254 100644 --- a/delphi/latents/samplers.py +++ b/delphi/latents/samplers.py @@ -7,6 +7,7 @@ ) from delphi import logger +from delphi.utils import decode_per_token from ..config import SamplerConfig from .latents import ActivatingExample, LatentRecord @@ -128,7 +129,7 @@ def sampler( # Moved tokenization to sampler to avoid tokenizing # examples that are not going to be used for example in _train: - example.str_tokens = tokenizer.batch_decode(example.tokens) + example.str_tokens = decode_per_token(tokenizer, example.tokens) record.train = _train if cfg.n_examples_test > 0: _test = test( @@ -139,6 +140,6 @@ def sampler( cfg.test_type, ) for example in _test: - example.str_tokens = tokenizer.batch_decode(example.tokens) + example.str_tokens = decode_per_token(tokenizer, example.tokens) record.test = _test return record diff --git a/delphi/scorers/simulator/simulation/oai_simulator.py b/delphi/scorers/simulator/simulation/oai_simulator.py index 033c492d..e2c25f42 100644 --- a/delphi/scorers/simulator/simulation/oai_simulator.py +++ b/delphi/scorers/simulator/simulation/oai_simulator.py @@ -13,6 +13,7 @@ from delphi.latents.latents import ActivatingExample, NonActivatingExample from delphi.scorers.scorer import Scorer, ScorerResult +from delphi.utils import decode_per_token from .data_models import ActivationRecord from .scoring import simulate_and_score @@ -100,7 +101,7 @@ def to_activation_records( result.append( ActivationRecord( - self.tokenizer.batch_decode(example.tokens), + decode_per_token(self.tokenizer, example.tokens), activations, quantile=( example.quantile diff --git a/delphi/utils.py b/delphi/utils.py index 4b062fd8..fcc36c23 100644 --- a/delphi/utils.py +++ b/delphi/utils.py @@ -68,6 +68,26 @@ def load_tokenized_data( return tokens +def decode_per_token( + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + tokens: Tensor, +) -> list[str]: + """ + Decode a 1-D tensor of token ids into one string per token. + + On transformers >= 5, ``batch_decode`` treats a 1-D tensor as a single + sequence and returns one joined string, so it can no longer be used to + get per-token strings. ``convert_ids_to_tokens`` is not a substitute + either: it returns raw vocab strings, which for BPE tokenizers contain + encoding artifacts such as "Ġ". + + Args: + tokenizer: The tokenizer to use. + tokens: A 1-D tensor of token ids of shape (ctx_len,). + """ + return tokenizer.batch_decode(tokens.unsqueeze(-1)) + + T = TypeVar("T") diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..d9e3bd33 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,31 @@ +import pytest +import torch +from transformers import AutoTokenizer + +from delphi.utils import decode_per_token + + +@pytest.fixture(scope="module") +def pythia_tokenizer(): + return AutoTokenizer.from_pretrained("EleutherAI/pythia-70m") + + +def test_decode_per_token_is_per_token(pythia_tokenizer): + tokens = torch.tensor(pythia_tokenizer("Hello world, this is a test.")["input_ids"]) + str_tokens = decode_per_token(pythia_tokenizer, tokens) + + # One string per token id, unlike batch_decode on transformers >= 5, + # which joins a 1-D tensor into a single string. + assert len(str_tokens) == len(tokens) + for token_id, str_token in zip(tokens, str_tokens): + assert str_token == pythia_tokenizer.decode(token_id) + + +def test_decode_per_token_roundtrips_text(pythia_tokenizer): + text = "Hello world, this is a test." + tokens = torch.tensor(pythia_tokenizer(text)["input_ids"]) + str_tokens = decode_per_token(pythia_tokenizer, tokens) + + # BPE space markers must come back as real spaces ("Ġworld" -> " world"), + # so the per-token strings concatenate to the original text. + assert "".join(str_tokens) == text