
str_tokens is a single joined string (not per-token) on transformers 5.x — token highlighting silently broken in explainer and classifier scorers #176

Description

@LakeSJS

Summary

The sampler and constructor set example.str_tokens = tokenizer.batch_decode(example.tokens), where example.tokens is a 1-D tensor of shape (ctx_len,). On transformers ≥ 5, batch_decode forwards to decode(), which treats 1-D input as a single sequence — so it returns a length-1 list containing one joined string instead of the per-token list delphi assumes. On transformers 4.x, batch_decode was [self.decode(seq) for seq in sequences], and iterating a 1-D tensor yielded per-token strings.

Since pyproject.toml declares transformers without a version constraint, fresh installs resolve to 5.x and hit this.
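
To make the difference concrete without downloading a model, here is a toy sketch of the two behaviors described above. ToyTokenizer, VOCAB, batch_decode_4x and batch_decode_5x are hypothetical names made up for illustration; they mimic the semantics stated in this issue and are not the transformers implementation.

```python
# Toy model of the two batch_decode behaviors described in this issue.
# ToyTokenizer is a hypothetical stand-in, NOT the transformers code.

VOCAB = {0: "Hello", 1: " world", 2: ",", 3: " testing"}


class ToyTokenizer:
    def decode(self, ids):
        # A bare int is one token; a flat list is one whole sequence.
        if isinstance(ids, int):
            ids = [ids]
        return "".join(VOCAB[i] for i in ids)

    def batch_decode_4x(self, sequences):
        # 4.x-style: decode each element of the input separately, so
        # iterating a 1-D input yields one string per token.
        return [self.decode(seq) for seq in sequences]

    def batch_decode_5x(self, sequences):
        # 5.x-style as described above: a 1-D input is treated as a
        # single sequence and comes back as a length-1 list.
        if all(isinstance(s, int) for s in sequences):
            return [self.decode(sequences)]
        return [self.decode(seq) for seq in sequences]


tok = ToyTokenizer()
tokens = [0, 1, 2, 3]  # plays the role of example.tokens, shape (ctx_len,)

per_token = tok.batch_decode_4x(tokens)
joined = tok.batch_decode_5x(tokens)

print(len(per_token), per_token)
print(len(joined), joined)
```

The second result has length 1 regardless of ctx_len, which is the shape mismatch that the downstream highlighting code does not detect.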

Consequences (mostly silent — no error)

  • Explainer: _highlight / _join_activations (explainers/explainer.py:75-116) iterate str_toks against per-token activation lists (length ctx_len). With len(str_tokens) == 1 they emit no <<>> marks and an empty Activations: line, so the explainer never sees which tokens activate.
  • Detection/fuzz scorers: scorers/classifier/sample.py:99 reads example.str_tokens the same way, so fuzzing prompts lose per-token highlighting. Because the fuzz scorer's premise is per-token marks, scores degrade silently.
  • Simulation scorer: scorers/simulator/simulation/oai_simulator.py:103 builds ActivationRecord(self.tokenizer.batch_decode(example.tokens), activations) — a length-1 token list zipped against per-token activations.
  • Intruder scorer (loud symptom): scorers/classifier/intruder.py:122 indexes str_tokens[index] with token positions, which raises IndexError for any index > 0.
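
Since most of these failures are silent, a length check at the point where str_tokens is assigned would turn them into an immediate error. The sketch below is one possible guard; check_str_tokens is a hypothetical helper name and does not exist in delphi.

```python
def check_str_tokens(str_tokens, tokens):
    """Raise if str_tokens does not hold exactly one string per token id.

    Hypothetical helper for illustration. `tokens` may be a list or a
    1-D tensor; only len() is used on it.
    """
    n_tokens = len(tokens)
    if len(str_tokens) != n_tokens:
        raise ValueError(
            f"expected {n_tokens} per-token strings, got {len(str_tokens)}; "
            "str_tokens may have been decoded as a single sequence"
        )
    return str_tokens


# Per-token list: passes through unchanged.
ok = check_str_tokens(["Hello", " world"], [0, 1])

# Single joined string: rejected instead of silently mis-highlighting.
try:
    check_str_tokens(["Hello world"], [0, 1])
except ValueError as err:
    print(err)
```

With such a check in place, all four consequences above would fail the same way the intruder scorer does, rather than producing degraded prompts.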

Minimal repro (transformers 5.5.4)

from transformers import AutoTokenizer
import torch

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
ids = torch.arange(1000, 1010)
print(len(tok.batch_decode(ids)))           # 5.x -> 1   (4.x -> 10)
print(len([tok.decode(i) for i in ids]))    # -> 10  (per-token, matches 4.x batch_decode)

Affected sites

  • latents/samplers.py:131 and latents/samplers.py:142 (train/test examples)
  • latents/constructors.py:53 (non-activating examples)
  • scorers/simulator/simulation/oai_simulator.py:103

(latents/constructors.py:428 and :465 use the same pattern but immediately "".join(...) the result, so they still produce the intended text and are not functionally affected.)

Suggested fix

Replace batch_decode(example.tokens) at the affected sites with a per-token decode, which exactly restores the 4.x behavior:

example.str_tokens = [tokenizer.decode(t) for t in example.tokens]
# or equivalently: tokenizer.batch_decode(example.tokens.unsqueeze(-1))
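
The per-token decode could also be factored into one shared helper so that every affected site uses the same code path. This is a minimal sketch: decode_per_token is a hypothetical name, and StubTokenizer exists only so the example runs without transformers installed. Each id is wrapped in its own list, an assumption chosen so that decode() always receives a one-token sequence whatever it does with 1-D input.

```python
def decode_per_token(tokenizer, tokens):
    """Return one decoded string per token id (hypothetical helper)."""
    # Tensors and arrays expose .tolist(); plain sequences are copied.
    ids = tokens.tolist() if hasattr(tokens, "tolist") else list(tokens)
    # Wrap each id in its own list so decode() sees a one-token sequence.
    return [tokenizer.decode([i]) for i in ids]


class StubTokenizer:
    """Hypothetical stand-in tokenizer so the sketch is self-contained."""

    vocab = {0: "Hello", 1: " world", 2: ",", 3: " testing"}

    def decode(self, ids):
        return "".join(self.vocab[i] for i in ids)


str_tokens = decode_per_token(StubTokenizer(), [0, 1, 2, 3])
print(str_tokens)
```

At a call site this would read example.str_tokens = decode_per_token(tokenizer, example.tokens).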

Note that tokenizer.convert_ids_to_tokens(...) is not a drop-in fix: it returns raw vocab strings, which for BPE tokenizers (GPT-2/Llama-style) injects encoding artifacts into prompts — e.g. ['Hello', 'Ġworld', ',', 'Ġtesting'] instead of ['Hello', ' world', ',', ' testing'] — and would also corrupt the scorers that "".join(str_tokens) (embedding, surprisal). For WordPiece tokenizers the two happen to be identical, but per-token decode is correct for all tokenizer types.

Alternatively/additionally, pinning transformers<5 in pyproject.toml would prevent the silent failure until the call sites are migrated.
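
Where a pin is not possible (for example in a downstream environment that already has 5.x), a runtime check can at least surface the problem. The sketch below uses only the standard library; major_version, is_affected and installed_transformers_version are hypothetical names, and the ">= 5" threshold is taken from the description in this issue rather than from a changelog.

```python
from importlib import metadata


def major_version(version):
    """Leading integer of a version string such as '5.5.4'."""
    return int(version.split(".")[0])


def is_affected(version):
    # Assumption from this issue: the behavior change applies to 5.x and later.
    return major_version(version) >= 5


def installed_transformers_version():
    """Installed transformers version, or None if it is not installed."""
    try:
        return metadata.version("transformers")
    except metadata.PackageNotFoundError:
        return None


installed = installed_transformers_version()
if installed is not None and is_affected(installed):
    print(
        f"transformers {installed}: batch_decode on a 1-D tensor may "
        "return a single joined string"
    )
```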

Env: delphi @ a70e7ee, transformers 5.5.4, torch 2.11.0.


Found while debugging unhighlighted explainer prompts in a downstream pipeline; behavior verified against both the 4.x and 5.x batch_decode implementations. Investigation was AI-assisted (Claude Code).
