|
| 1 | +""" |
| 2 | +Embedding service for generating vector embeddings using IBM Granite model. |
| 3 | +
|
| 4 | +This service uses the IBM Granite-Embedding-30m-English model to generate |
| 5 | +384-dimensional embeddings for text. The model is loaded once at startup |
| 6 | +and cached in memory for fast inference. |
| 7 | +""" |
| 8 | + |
| 9 | +import logging |
| 10 | +import re |
| 11 | +from typing import List, Optional |
| 12 | + |
| 13 | +from sentence_transformers import SentenceTransformer |
| 14 | + |
| 15 | +logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | + |
class EmbeddingService:
    """
    Singleton service for generating embeddings using IBM Granite model.

    The model is loaded once at initialization and cached for fast inference.
    """

    # Forward-reference string annotations avoid evaluating the (possibly
    # heavyweight) imported class at class-creation time; behavior is identical.
    _instance: Optional["EmbeddingService"] = None
    _model: Optional["SentenceTransformer"] = None

    MODEL_NAME = "ibm-granite/granite-embedding-30m-english"
    EMBEDDING_DIMENSION = 384
    MAX_TOKENS = 512

    # Simple tokenizer: a run of word characters, or one non-space punctuation
    # character. Compiled once at class creation instead of on every
    # _clip_to_max_tokens call (the original recompiled it per call).
    _TOKEN_RE = re.compile(r"\w+|[^\w\s]", flags=re.UNICODE)

    def __new__(cls):
        """Return the single shared instance (singleton pattern)."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize the embedding service; load the model on first call only.

        Because the class is a singleton, __init__ runs again on every
        ``EmbeddingService()`` call — the ``_model is None`` guard makes
        repeat calls no-ops.
        """
        if self._model is None:
            # Lazy %-style arguments are the logging-module convention and
            # defer formatting until the record is actually emitted.
            logger.info("Loading embedding model: %s", self.MODEL_NAME)
            self._model = SentenceTransformer(self.MODEL_NAME)
            logger.info(
                "Embedding model loaded successfully. Dimension: %d",
                self.EMBEDDING_DIMENSION,
            )

    def _clip_to_max_tokens(self, text: str) -> str:
        """
        Clip text to the maximum token limit (MAX_TOKENS = 512 tokens).

        Uses a simple tokenizer that matches the pattern used for NVIDIA
        embeddings. This is a conservative approximation — actual model
        tokenization may differ slightly, and rejoining with single spaces
        does not preserve the original whitespace of a clipped text.

        Args:
            text: The input text to clip

        Returns:
            The clipped text if over limit, otherwise the original text
        """
        tokens = self._TOKEN_RE.findall(text)

        if len(tokens) <= self.MAX_TOKENS:
            return text

        # Clip to max tokens and rejoin with spaces. This is approximate but
        # works well enough for an over-limit safeguard.
        clipped_text = " ".join(tokens[: self.MAX_TOKENS])

        logger.warning(
            "Text clipped from %d to %d tokens", len(tokens), self.MAX_TOKENS
        )

        return clipped_text

    def generate_embedding(self, text: str, clip_tokens: bool = True) -> List[float]:
        """
        Generate a 384-dimensional embedding vector for the given text.

        Args:
            text: The input text to embed
            clip_tokens: Whether to clip text to MAX_TOKENS (default: True)

        Returns:
            A list of 384 float values representing the embedding vector

        Raises:
            ValueError: If text is empty/whitespace or the model is not loaded
        """
        if not text or not text.strip():
            raise ValueError("Cannot generate embedding for empty text")

        if self._model is None:
            raise ValueError("Embedding model not loaded")

        # Clip to token limit if requested
        if clip_tokens:
            text = self._clip_to_max_tokens(text)

        # encode() returns a numpy array; convert to a plain list of floats
        embedding = self._model.encode(text, convert_to_numpy=True)

        return embedding.tolist()

    def generate_embeddings_batch(
        self, texts: List[str], clip_tokens: bool = True
    ) -> List[List[float]]:
        """
        Generate embeddings for multiple texts in a batch.

        This is more efficient than calling generate_embedding() multiple times
        as the model can process multiple texts in parallel.

        Args:
            texts: List of input texts to embed
            clip_tokens: Whether to clip texts to MAX_TOKENS (default: True)

        Returns:
            List of embedding vectors, one for each input text

        Raises:
            ValueError: If the list or any text is empty, or the model is
                not loaded
        """
        if not texts:
            raise ValueError("Cannot generate embeddings for empty list")

        if self._model is None:
            raise ValueError("Embedding model not loaded")

        # Validate and (optionally) clip every text before the batch call so
        # a single bad entry fails fast instead of mid-encode.
        processed_texts = []
        for text in texts:
            if not text or not text.strip():
                raise ValueError("Cannot generate embedding for empty text")

            if clip_tokens:
                text = self._clip_to_max_tokens(text)

            processed_texts.append(text)

        # Generate embeddings in a single batched forward pass
        embeddings = self._model.encode(
            processed_texts, convert_to_numpy=True, show_progress_bar=False
        )

        return [emb.tolist() for emb in embeddings]
| 149 | + |
| 150 | + |
# Module-level cache holding the process-wide service instance.
_embedding_service: Optional[EmbeddingService] = None


def get_embedding_service() -> EmbeddingService:
    """
    Return the global embedding service, creating it on first use.

    The first call constructs (and thereby loads the model for) the
    service; every later call hands back the same cached instance.

    Returns:
        The global EmbeddingService instance
    """
    global _embedding_service

    # Guard clause: fast path once the service already exists.
    if _embedding_service is not None:
        return _embedding_service

    _embedding_service = EmbeddingService()
    return _embedding_service