Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion delphi/latents/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from delphi import logger
from delphi.utils import decode_per_token

from ..config import ConstructorConfig
from .latents import (
Expand Down Expand Up @@ -50,7 +51,7 @@ def prepare_non_activating_examples(
tokens=toks,
activations=acts,
distance=distance,
str_tokens=tokenizer.batch_decode(toks),
str_tokens=decode_per_token(tokenizer, toks),
)
for toks, acts in zip(tokens, activations)
]
Expand Down
5 changes: 3 additions & 2 deletions delphi/latents/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)

from delphi import logger
from delphi.utils import decode_per_token

from ..config import SamplerConfig
from .latents import ActivatingExample, LatentRecord
Expand Down Expand Up @@ -128,7 +129,7 @@ def sampler(
# Moved tokenization to sampler to avoid tokenizing
# examples that are not going to be used
for example in _train:
example.str_tokens = tokenizer.batch_decode(example.tokens)
example.str_tokens = decode_per_token(tokenizer, example.tokens)
record.train = _train
if cfg.n_examples_test > 0:
_test = test(
Expand All @@ -139,6 +140,6 @@ def sampler(
cfg.test_type,
)
for example in _test:
example.str_tokens = tokenizer.batch_decode(example.tokens)
example.str_tokens = decode_per_token(tokenizer, example.tokens)
record.test = _test
return record
3 changes: 2 additions & 1 deletion delphi/scorers/simulator/simulation/oai_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from delphi.latents.latents import ActivatingExample, NonActivatingExample
from delphi.scorers.scorer import Scorer, ScorerResult
from delphi.utils import decode_per_token

from .data_models import ActivationRecord
from .scoring import simulate_and_score
Expand Down Expand Up @@ -100,7 +101,7 @@ def to_activation_records(

result.append(
ActivationRecord(
self.tokenizer.batch_decode(example.tokens),
decode_per_token(self.tokenizer, example.tokens),
activations,
quantile=(
example.quantile
Expand Down
20 changes: 20 additions & 0 deletions delphi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,26 @@ def load_tokenized_data(
return tokens


def decode_per_token(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    tokens: Tensor,
) -> list[str]:
    """
    Return the decoded text of each token id in ``tokens``, one string per id.

    Each id is placed in its own length-1 row before calling
    ``batch_decode``, so the tokenizer decodes every id as a separate
    sequence. This is needed because, on transformers >= 5, ``batch_decode``
    given a 1-D tensor treats it as one sequence and returns a single joined
    string. ``convert_ids_to_tokens`` is deliberately not used: it yields raw
    vocab entries, which for BPE tokenizers still carry encoding artifacts
    such as "Ġ".

    Args:
        tokenizer: The tokenizer used for decoding.
        tokens: A 1-D tensor of token ids of shape (ctx_len,).

    Returns:
        The result of ``tokenizer.batch_decode`` on the (ctx_len, 1) view of
        ``tokens``: one decoded string per token id, in input order.
    """
    # (ctx_len,) -> (ctx_len, 1): every id becomes its own one-token sequence.
    single_token_rows = tokens.unsqueeze(-1)
    return tokenizer.batch_decode(single_token_rows)


T = TypeVar("T")


Expand Down
31 changes: 31 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
import torch
from transformers import AutoTokenizer

from delphi.utils import decode_per_token


@pytest.fixture(scope="module")
def pythia_tokenizer():
    """Provide the ``EleutherAI/pythia-70m`` tokenizer, loaded once per module."""
    model_name = "EleutherAI/pythia-70m"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer


def test_decode_per_token_is_per_token(pythia_tokenizer):
    """Check that each token id is decoded to its own string."""
    encoded = pythia_tokenizer("Hello world, this is a test.")
    tokens = torch.tensor(encoded["input_ids"])
    pieces = decode_per_token(pythia_tokenizer, tokens)

    # The output must have one entry per token id; batch_decode on
    # transformers >= 5 would instead join a 1-D tensor into a single string.
    assert len(pieces) == len(tokens)
    # Each entry must equal what decoding that single id on its own yields.
    for position, token_id in enumerate(tokens):
        assert pieces[position] == pythia_tokenizer.decode(token_id)


def test_decode_per_token_roundtrips_text(pythia_tokenizer):
    """Check that the per-token strings concatenate back to the input text."""
    text = "Hello world, this is a test."
    encoded = pythia_tokenizer(text)
    tokens = torch.tensor(encoded["input_ids"])

    rebuilt = "".join(decode_per_token(pythia_tokenizer, tokens))

    # Only holds if BPE space markers are decoded to real spaces
    # ("Ġworld" -> " world") rather than left as raw vocab artifacts.
    assert rebuilt == text