Skip to content

Commit 8bbb0ed

Browse files
committed
Add text processing utility for token clipping
- Introduced a new utility function `clip_to_512_tokens` in `text.py` to truncate input text to a maximum of 512 tokens, ensuring compatibility with NV-Embed model requirements.
- Implemented a regex-based tokenizer that treats words and punctuation marks as separate tokens, providing a lightweight approximation of the provider's tokenization.
- Added unit tests in `test_text.py` validating the clipping behavior, including token counts at and around the limit and handling of Unicode punctuation.
1 parent 79590c1 commit 8bbb0ed

2 files changed

Lines changed: 99 additions & 0 deletions

File tree

app/utils/text.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
"""Text processing helpers used across the KillrVideo backend."""
4+
5+
import re
6+
7+
__all__ = ["clip_to_512_tokens"]
8+
9+
# ---------------------------------------------------------------------------
10+
# Basic tokenizer
11+
# ---------------------------------------------------------------------------
12+
# The NV-Embed provider enforces a hard limit of **512 tokens** for both
13+
# `$vectorize` requests and vector-enabled queries. We approximate token
14+
# boundaries using a lightweight regex that splits on standard *word* chunks
15+
# while treating punctuation and symbols as individual tokens. This is **not**
16+
# an exact match of the provider's internal SentencePiece model but is
17+
# sufficiently close for defensive clipping.
18+
#
19+
# • Consecutive whitespace is ignored (no empty tokens).
20+
# • Unicode punctuation characters are captured as standalone tokens.
21+
# • The pattern is Unicode-aware through the `re.UNICODE` flag (default in
22+
# Python 3 but kept explicit).
23+
# ---------------------------------------------------------------------------
24+
TOKEN_RE = re.compile(r"\w+|[^\w\s]", flags=re.UNICODE)

# Hard limit enforced by the NV-Embed provider for `$vectorize` requests and
# vector-enabled queries.
MAX_TOKENS_NV_EMBED = 512


def clip_to_512_tokens(text: str, *, max_tokens: int = MAX_TOKENS_NV_EMBED) -> str:  # noqa: D401
    """Return *text* truncated to at most *max_tokens* tokens.

    If the input already fits under the limit it is returned unchanged
    (exact byte equality preserved). For longer inputs the first
    *max_tokens* tokens are kept and re-joined with single spaces, so
    downstream code always receives a valid plain-text string.

    Parameters
    ----------
    text:
        Arbitrary input string (may contain newlines, tabs, or Unicode
        punctuation).
    max_tokens:
        Maximum number of tokens to keep. Defaults to
        :pydata:`MAX_TOKENS_NV_EMBED` (512), the NV-Embed provider limit;
        keyword-only so existing callers are unaffected.

    Returns
    -------
    str
        The (possibly) truncated string, guaranteed to be <= *max_tokens*
        tokens when tokenised via :pydata:`TOKEN_RE`.
    """
    if not text:
        # Early exit for the empty string; also passes falsy inputs through.
        return text

    tokens = TOKEN_RE.findall(text)

    if len(tokens) <= max_tokens:
        # No truncation needed – preserve original spacing to avoid surprising
        # callers that might rely on exact text equality (e.g., hashing).
        return text

    # Re-join the kept tokens with a single space. This canonical form is
    # sufficient for embedding purposes and avoids the complexity of
    # reconstructing the original whitespace layout.
    return " ".join(tokens[:max_tokens])

tests/utils/test_text.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import re
2+
3+
import pytest
4+
5+
from app.utils.text import clip_to_512_tokens
6+
7+
TOKEN_RE = re.compile(r"\w+|[^\w\s]", flags=re.UNICODE)
8+
9+
10+
def _count_tokens(s: str) -> int:
11+
return len(TOKEN_RE.findall(s))
12+
13+
14+
@pytest.mark.parametrize("token_count", [600, 513, 512, 100])
def test_clip_to_512_tokens(token_count: int):
    """Text above the 512-token limit is clipped; shorter text is untouched."""
    input_text = " ".join(f"tok{i}" for i in range(token_count))
    result = clip_to_512_tokens(input_text)

    if token_count > 512:
        # Clipped down to exactly the limit, keeping the leading tokens.
        words = result.split()
        assert _count_tokens(result) == 512
        assert words[0] == "tok0"
        assert words[-1] == "tok511"
    else:
        # Under (or at) the limit: returned verbatim.
        assert result == input_text
        assert _count_tokens(result) == token_count
28+
29+
30+
def test_unicode_and_whitespace():
    """Short text containing Unicode punctuation passes through unchanged."""
    text = "你好, 世界! Hello — world…"
    clipped = clip_to_512_tokens(text)

    # Well under the limit, so the input must come back byte-identical.
    assert clipped == text
    assert _count_tokens(clipped) < 512

0 commit comments

Comments
 (0)