diff --git a/CHANGELOG.md b/CHANGELOG.md index e20ed73c2..b07244ed7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- feat(example): align server MTP support with llama.cpp by @abetlen in #2283 - feat: update llama.cpp to ggml-org/llama.cpp@9e3b928fd - feat(example): add OpenAI-compatible embeddings endpoint by @abetlen in #2281 diff --git a/examples/server/README.md b/examples/server/README.md index 6a1cc4448..ff04374fc 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -434,6 +434,22 @@ Use MTP when the loaded model and llama.cpp build expose the required draft stat } ``` +By default `draft-mtp` creates the MTP context from the target model. +Set `draft_model_path` or `draft_model_from_pretrained` when the model uses a separate assistant GGUF. + +```json +{ + "model": { + "draft_model": "draft-mtp", + "draft_model_num_pred_tokens": 2, + "draft_model_from_pretrained": { + "repo_id": "example/gemma-assistant-GGUF", + "filename": "assistant.gguf" + } + } +} +``` + MTP currently applies to text-only requests. ## Disk Sequence Cache diff --git a/examples/server/server.py b/examples/server/server.py index 2530cf4be..64a16f0bd 100644 --- a/examples/server/server.py +++ b/examples/server/server.py @@ -36,7 +36,6 @@ import copy import shutil import inspect -import importlib.util import sys import urllib.error import urllib.parse @@ -1262,27 +1261,33 @@ def __init__( self, *, model: "Model", + draft_model: Any, context_params: Any, num_pred_tokens: int, top_k: int, p_min: float, ) -> None: self.target_ctx = model.ctx - self.model = model.llama_model + self.model = draft_model self.n_seq_max = model.n_seq_max self.n_vocab = model.n_vocab - self.n_embd = int(llama_cpp.llama_model_n_embd(self.model)) + self.n_embd = int(llama_cpp.llama_model_n_embd_out(self.model)) + if self.n_embd <= 0: + self.n_embd = int(llama_cpp.llama_model_n_embd(self.model)) + if self.n_embd != model.n_embd: + raise RuntimeError( + "MTP draft model output embedding size must match target model " + f"embedding size ({self.n_embd} != {model.n_embd})" + ) self.num_pred_tokens = max(0, int(num_pred_tokens)) self.top_k = max(1, int(top_k)) self.p_min = max(0.0, min(1.0, float(p_min))) - ( - self.target_hidden_norm_weight, - self.draft_hidden_norm_weight, - self.hidden_norm_epsilon, - ) = self._load_hidden_norm_weights(model.model_path) self.ctx = llama_cpp.llama_init_from_model(self.model, context_params) if self.ctx is None: raise RuntimeError("failed to create MTP draft context") + ctx_other = llama_cpp_ext.llama_get_ctx_other(self.ctx) + self.is_mem_shared = bool(ctx_other and ctx_other == self.target_ctx) + self.sampled_batch_draft = not self.is_mem_shared self.n_batch = int(llama_cpp.llama_n_batch(self.ctx)) mem = llama_cpp.llama_get_memory(self.ctx) if mem is None: @@ -1313,135 +1318,13 @@ def __init__( self.decode_failures_total = 0 self.target_processing_enabled = False self.set_target_processing_enabled(True) - llama_cpp_ext.llama_set_embeddings_pre_norm( + llama_cpp_ext.llama_set_embeddings_nextn( self.ctx, True, True, ) self._init_samplers() - @staticmethod - def _load_gguf_reader() -> Any: - try: - from gguf import GGUFReader # type: ignore[import-not-found] - return GGUFReader - except ImportError: - pass - - gguf_init = ( - Path(__file__).resolve().parents[2] - / "vendor/llama.cpp/gguf-py/gguf/__init__.py" - ) - spec = importlib.util.spec_from_file_location( - "gguf", - gguf_init, - submodule_search_locations=[str(gguf_init.parent)], - ) - if spec is None or spec.loader is None: - raise RuntimeError("MTP requires the gguf Python reader from llama.cpp") - module = importlib.util.module_from_spec(spec) - previous = sys.modules.get("gguf") - # Package-relative imports in gguf-py expect the package to be registered. - sys.modules["gguf"] = module - try: - spec.loader.exec_module(module) - except Exception: - if previous is None: - sys.modules.pop("gguf", None) - else: - sys.modules["gguf"] = previous - raise - return module.GGUFReader - - @staticmethod - def _gguf_field_contents(reader: Any, name: str) -> Any: - field = reader.fields.get(name) - if field is None: - return None - return field.contents() - - def _load_hidden_norm_weights( - self, - model_path: str, - ) -> Tuple[np.ndarray, np.ndarray, float]: - GGUFReader = self._load_gguf_reader() - - target_weight: Optional[np.ndarray] = None - draft_weights: List[np.ndarray] = [] - reader = GGUFReader(model_path) - arch = self._gguf_field_contents(reader, "general.architecture") - if arch not in {"qwen35", "qwen35moe"}: - raise RuntimeError( - "draft-mtp currently supports qwen35/qwen35moe GGUF models, " - f"got {arch!r}" - ) - nextn_layers = self._gguf_field_contents( - reader, - f"{arch}.nextn_predict_layers", - ) - # The current MTP path follows llama.cpp's Qwen3.5 one-nextn-layer graph. - if int(nextn_layers or 0) != 1: - raise RuntimeError( - "draft-mtp currently supports exactly one Qwen3.5 nextn prediction layer" - ) - epsilon = self._gguf_field_contents( - reader, - f"{arch}.attention.layer_norm_rms_epsilon", - ) - if epsilon is None: - raise RuntimeError( - f"MTP requires {arch}.attention.layer_norm_rms_epsilon" - ) - for tensor in reader.tensors: - if tensor.name == "output_norm.weight": - target_weight = np.asarray(tensor.data, dtype=np.float32).copy() - elif tensor.name.endswith(".nextn.shared_head_norm.weight"): - draft_weights.append(np.asarray(tensor.data, dtype=np.float32).copy()) - if target_weight is None: - raise RuntimeError("MTP requires output_norm.weight in GGUF model") - if len(draft_weights) > 1: - raise RuntimeError( - "MTP requires at most one blk.*.nextn.shared_head_norm.weight in GGUF model" - ) - draft_weight = draft_weights[0] if draft_weights else target_weight - if target_weight.shape != (self.n_embd,): - raise RuntimeError( - "MTP target norm weight shape does not match model embedding size " - f"({target_weight.shape} != ({self.n_embd},))" - ) - if draft_weight.shape != (self.n_embd,): - raise RuntimeError( - "MTP draft norm weight shape does not match model embedding size " - f"({draft_weight.shape} != ({self.n_embd},))" - ) - return target_weight, draft_weight, float(epsilon) - - def _normalize_hidden_rows(self, rows: np.ndarray, weight: np.ndarray) -> np.ndarray: - rows = np.asarray(rows, dtype=np.float32) - scale = np.reciprocal( - np.sqrt( - np.mean(np.square(rows), axis=-1, keepdims=True) - + self.hidden_norm_epsilon - ) - ) - return rows * scale * weight.reshape(1, -1) - - def _normalize_target_hidden_rows(self, rows: np.ndarray) -> np.ndarray: - return self._normalize_hidden_rows(rows, self.target_hidden_norm_weight) - - def _normalize_draft_hidden_row( - self, - row: Union[np.ndarray, ctypes.POINTER(ctypes.c_float)], - ) -> np.ndarray: - if isinstance(row, np.ndarray): - row_array = row.reshape(1, -1) - else: - row_array = np.ctypeslib.as_array(row, shape=(self.n_embd,)).reshape(1, -1) - return self._normalize_hidden_rows( - row_array, - self.draft_hidden_norm_weight, - )[0] - def _init_samplers(self) -> None: for seq_id in range(self.n_seq_max): params = llama_cpp.llama_sampler_chain_default_params() @@ -1507,7 +1390,7 @@ def close(self) -> None: def set_target_processing_enabled(self, enabled: bool) -> None: if self.target_processing_enabled == enabled: return - llama_cpp_ext.llama_set_embeddings_pre_norm( + llama_cpp_ext.llama_set_embeddings_nextn( self.target_ctx, enabled, False, @@ -1641,14 +1524,13 @@ def process(self, batch: Any, /) -> None: ): return - h_tgt = llama_cpp_ext.llama_get_embeddings_pre_norm(self.target_ctx) + h_tgt = llama_cpp_ext.llama_get_embeddings_nextn(self.target_ctx) if not h_tgt: - raise RuntimeError("missing target pre-norm embeddings for MTP") + raise RuntimeError("missing target nextn embeddings for MTP") h_tgt_rows = np.ctypeslib.as_array( h_tgt, shape=(n_tokens, self.n_embd), ) - h_tgt_rows = self._normalize_target_hidden_rows(h_tgt_rows) previous_row_by_seq: Dict[int, int] = {} first_pos_by_seq: Dict[int, int] = {} @@ -1708,17 +1590,17 @@ def _process_rows( self.ready[seq_id] and self.ready_pos[seq_id] == first_pos ) aligned_by_seq.setdefault(seq_id, aligned) - mtp_pos = ( - pos - 1 - if previous_row_by_seq.get(seq_id) is None - else int(batch.pos[previous_row_by_seq[seq_id]]) - ) - if aligned and mtp_pos >= 0 and mtp_pos >= self.context_pos[seq_id]: - previous_row = previous_row_by_seq.get(seq_id) + previous_row = previous_row_by_seq.get(seq_id) + if ( + aligned + and not self.is_mem_shared + and pos >= 0 + and pos >= self.context_pos[seq_id] + ): slot = int(self.batch.n_tokens) self._add_batch_token( token=int(batch.token[index]), - pos=mtp_pos, + pos=pos, seq_id=seq_id, logits=False, ) @@ -1726,7 +1608,7 @@ def _process_rows( self._set_batch_embedding_row(slot, self.pending_h[seq_id]) else: self._set_batch_embedding_row(slot, h_tgt_rows[previous_row]) - added_pos_by_seq[seq_id] = mtp_pos + added_pos_by_seq[seq_id] = pos previous_row_by_seq[seq_id] = index target_rows_by_seq.setdefault(seq_id, []).append(index) @@ -1760,15 +1642,15 @@ def draft( n_past = int(input_ids.size) - 1 if self.ready_pos[seq_id] != n_past: return np.array([], dtype=np.intc) - first_pos = n_past - 1 + first_pos = n_past if first_pos < 0: return np.array([], dtype=np.intc) token = int(input_ids[-1]) drafted: List[int] = [] - if self.context_pos[seq_id] > first_pos: + if not self.is_mem_shared and self.context_pos[seq_id] > first_pos: self.truncate(seq_id, first_pos) - if self.context_pos[seq_id] < first_pos: + if not self.is_mem_shared and self.context_pos[seq_id] < first_pos: self.ready[seq_id] = False return np.array([], dtype=np.intc) @@ -1782,9 +1664,11 @@ def draft( ) self._set_batch_embedding_row(0, self.pending_h[seq_id]) if not self._try_decode_batch(): - self.truncate(seq_id, first_pos) + if not self.is_mem_shared: + self.truncate(seq_id, first_pos) return np.array([], dtype=np.intc) - self.context_pos[seq_id] = n_past + if not self.is_mem_shared: + self.context_pos[seq_id] = first_pos + 1 while len(drafted) < n_predict: sampled_token = self._sample_token(seq_id=seq_id) @@ -1794,26 +1678,28 @@ def draft( drafted.append(token) if len(drafted) >= n_predict: break - h_row = llama_cpp_ext.llama_get_embeddings_pre_norm_ith(self.ctx, 0) + h_row = llama_cpp_ext.llama_get_embeddings_nextn_ith(self.ctx, 0) if not h_row: break - h_row = self._normalize_draft_hidden_row(h_row) self._clear_batch() self._add_batch_token( token=token, - pos=first_pos + len(drafted), + pos=first_pos if self.is_mem_shared else first_pos + len(drafted), seq_id=seq_id, logits=True, ) self._set_batch_embedding_row(0, h_row) if not self._try_decode_batch(): break - self.context_pos[seq_id] = first_pos + len(drafted) + 1 + if not self.is_mem_shared: + self.context_pos[seq_id] = first_pos + len(drafted) + 1 if not drafted: - self.truncate(seq_id, n_past) + if not self.is_mem_shared: + self.truncate(seq_id, n_past) return np.array([], dtype=np.intc) - self.truncate(seq_id, n_past) + if not self.is_mem_shared: + self.truncate(seq_id, n_past) return np.asarray(drafted, dtype=np.intc) def draft_many( @@ -1843,12 +1729,12 @@ def draft_many( n_past = int(input_ids.size) - 1 if self.ready_pos[seq_id] != n_past: continue - first_pos = n_past - 1 + first_pos = n_past if first_pos < 0: continue - if self.context_pos[seq_id] > first_pos: + if not self.is_mem_shared and self.context_pos[seq_id] > first_pos: self.truncate(seq_id, first_pos) - if self.context_pos[seq_id] < first_pos: + if not self.is_mem_shared and self.context_pos[seq_id] < first_pos: self.ready[seq_id] = False continue self._reset_sampler(seq_id) @@ -1904,10 +1790,11 @@ def draft_many( ): if sampled_token is None: continue - self.context_pos[representative.seq_id] = max( - self.context_pos[representative.seq_id], - representative.keep_len, - ) + if not self.is_mem_shared: + self.context_pos[representative.seq_id] = max( + self.context_pos[representative.seq_id], + representative.first_pos + 1, + ) for state in grouped[representative.cache_key]: state.drafted.append(sampled_token) active = [] @@ -1917,7 +1804,11 @@ def draft_many( for row, state in enumerate(active): self._add_batch_token( token=state.token, - pos=state.first_pos + len(state.drafted), + pos=( + state.first_pos + if self.is_mem_shared + else state.first_pos + len(state.drafted) + ), seq_id=state.seq_id, logits=True, ) @@ -1932,14 +1823,19 @@ def draft_many( for row, state in enumerate(active) ] for row, (state, sampled_token) in enumerate(zip(active, sampled_tokens)): - decoded_pos = state.first_pos + len(state.drafted) - self.context_pos[state.seq_id] = max( - self.context_pos[state.seq_id], - decoded_pos + 1, + decoded_pos = ( + state.first_pos + if self.is_mem_shared + else state.first_pos + len(state.drafted) ) + if not self.is_mem_shared: + self.context_pos[state.seq_id] = max( + self.context_pos[state.seq_id], + decoded_pos + 1, + ) if sampled_token is None: continue - h_row_ptr = llama_cpp_ext.llama_get_embeddings_pre_norm_ith( + h_row_ptr = llama_cpp_ext.llama_get_embeddings_nextn_ith( self.ctx, row ) state.drafted.append(sampled_token) @@ -1947,14 +1843,17 @@ def draft_many( continue if not h_row_ptr: continue - h_row = self._normalize_draft_hidden_row(h_row_ptr) state.token = sampled_token - state.embedding = h_row + state.embedding = np.ctypeslib.as_array( + h_row_ptr, + shape=(self.n_embd,), + ).copy() next_active.append(state) active = next_active finally: - for state in touched: - self.truncate(state.seq_id, state.keep_len) + if not self.is_mem_shared: + for state in touched: + self.truncate(state.seq_id, state.keep_len) for state in touched: if state.drafted: @@ -2036,7 +1935,7 @@ def _build_sampled_batch_plan( pending_token = update.pending_token if ( pending_token is None - or start_pos <= 0 + or start_pos < 0 or target_count <= 0 or target_count > len(tokens) or target_count > len(row_indices) @@ -2046,9 +1945,7 @@ def _build_sampled_batch_plan( continue for target_index in range(sample_index + 1): - mtp_pos = start_pos + target_index - 1 - if mtp_pos < 0: - continue + mtp_pos = start_pos + target_index source_row = ( None if target_index == 0 @@ -2063,7 +1960,7 @@ def _build_sampled_batch_plan( ) ) - actual_pos = start_pos + sample_index + actual_pos = start_pos + sample_index + 1 pending_rows.append( self.SampledPendingRow( update_index=update_index, @@ -2090,6 +1987,8 @@ def _decode_sampled_context_rows( self._clear_batch() decoded_context_rows: List[Tuple[int, int]] = [] for row in context_rows: + if self.is_mem_shared: + continue if row.draft_pos < self.context_pos[row.seq_id]: continue if row.source_row is None: @@ -2124,7 +2023,10 @@ def _decode_sampled_pending_rows( self._clear_batch() for pending_index, row in enumerate(pending_rows): - if row.draft_pos < self.context_pos[row.seq_id]: + if ( + not self.is_mem_shared + and row.draft_pos < self.context_pos[row.seq_id] + ): continue is_sample_pending = ( pending_index @@ -2144,17 +2046,18 @@ def _decode_sampled_pending_rows( update_index=row.update_index, seq_id=row.seq_id, output_index=slot, - keep_len=row.draft_pos + 1, - ready_pos=row.draft_pos + 1, + keep_len=row.draft_pos, + ready_pos=row.draft_pos, ) ) self._decode_batch() - for row in pending_rows: - self.context_pos[row.seq_id] = max( - self.context_pos[row.seq_id], - row.draft_pos + 1, - ) + if not self.is_mem_shared: + for row in pending_rows: + self.context_pos[row.seq_id] = max( + self.context_pos[row.seq_id], + row.draft_pos + 1, + ) return sampled_outputs @@ -2184,7 +2087,7 @@ def _start_sampled_draft_states( if n_predict <= 0: continue if n_predict > 1: - h_row = llama_cpp_ext.llama_get_embeddings_pre_norm_ith( + h_row = llama_cpp_ext.llama_get_embeddings_nextn_ith( self.ctx, output.output_index ) if h_row: @@ -2193,11 +2096,18 @@ def _start_sampled_draft_states( update_index=output.update_index, seq_id=seq_id, keep_len=output.keep_len, - pos=output.keep_len, + pos=( + output.keep_len + if self.is_mem_shared + else output.keep_len + 1 + ), token=sampled_token, drafted=[sampled_token], n_predict=n_predict, - embedding=self._normalize_draft_hidden_row(h_row), + embedding=np.ctypeslib.as_array( + h_row, + shape=(self.n_embd,), + ).copy(), ) ) results[output.update_index] = np.asarray([sampled_token], dtype=np.intc) @@ -2235,31 +2145,36 @@ def _extend_sampled_draft_states( for batch_row, (state, sampled_token) in enumerate( zip(active, sampled_tokens) ): - self.context_pos[state.seq_id] = max( - self.context_pos[state.seq_id], - state.pos + 1, - ) + if not self.is_mem_shared: + self.context_pos[state.seq_id] = max( + self.context_pos[state.seq_id], + state.pos + 1, + ) if sampled_token is None: continue state.drafted.append(sampled_token) if len(state.drafted) >= state.n_predict: continue - h_row = llama_cpp_ext.llama_get_embeddings_pre_norm_ith( + h_row = llama_cpp_ext.llama_get_embeddings_nextn_ith( self.ctx, batch_row ) if not h_row: continue - h_row = self._normalize_draft_hidden_row(h_row) state.token = sampled_token - state.embedding = h_row - state.pos += 1 + state.embedding = np.ctypeslib.as_array( + h_row, + shape=(self.n_embd,), + ).copy() + if not self.is_mem_shared: + state.pos += 1 next_active.append(state) active = next_active finally: for state in touched: cleanup_keep_len_by_seq[state.seq_id] = state.keep_len - for seq_id, keep_len in cleanup_keep_len_by_seq.items(): - self._truncate_memory(seq_id, keep_len) + if not self.is_mem_shared: + for seq_id, keep_len in cleanup_keep_len_by_seq.items(): + self._truncate_memory(seq_id, keep_len) for state in touched: if state.drafted: @@ -2276,9 +2191,9 @@ def process_sampled_batch( results = [np.array([], dtype=np.intc) for _ in updates] if self.num_pred_tokens <= 0 or not updates: return results - h_tgt = llama_cpp_ext.llama_get_embeddings_pre_norm(self.target_ctx) + h_tgt = llama_cpp_ext.llama_get_embeddings_nextn(self.target_ctx) if not h_tgt: - raise RuntimeError("missing target pre-norm embeddings for MTP") + raise RuntimeError("missing target nextn embeddings for MTP") n_target_rows = max( ( max(update.row_indices) + 1 @@ -2290,7 +2205,6 @@ def process_sampled_batch( if n_target_rows <= 0: return results h_tgt_rows = np.ctypeslib.as_array(h_tgt, shape=(n_target_rows, self.n_embd)) - h_tgt_rows = self._normalize_target_hidden_rows(h_tgt_rows) plan = self._build_sampled_batch_plan(updates) if not plan.context_rows and not plan.pending_rows: @@ -2313,8 +2227,9 @@ def process_sampled_batch( self.ready_pos[output.seq_id] = output.ready_pos cleanup_keep_len_by_seq[output.seq_id] = output.keep_len - for seq_id, keep_len in cleanup_keep_len_by_seq.items(): - self._truncate_memory(seq_id, keep_len) + if not self.is_mem_shared: + for seq_id, keep_len in cleanup_keep_len_by_seq.items(): + self._truncate_memory(seq_id, keep_len) if sampled_outputs: active = self._start_sampled_draft_states( @@ -2344,6 +2259,9 @@ def accept(self, seq_id: int, accepted_draft_tokens: int) -> None: def _truncate_memory(self, seq_id: int, keep_len: int) -> None: if seq_id < 0 or seq_id >= self.n_seq_max: return + if self.is_mem_shared: + self.context_pos[seq_id] = min(self.context_pos[seq_id], keep_len) + return if not llama_cpp.llama_memory_seq_rm( self.mem, seq_id, @@ -2386,13 +2304,14 @@ def copy_sequence( or dest_seq_id >= self.n_seq_max ): return - llama_cpp.llama_memory_seq_cp( - self.mem, - source_seq_id, - dest_seq_id, - p0, - p1, - ) + if not self.is_mem_shared: + llama_cpp.llama_memory_seq_cp( + self.mem, + source_seq_id, + dest_seq_id, + p0, + p1, + ) source_ready_pos = self.ready_pos[source_seq_id] copied_full_ready_state = p1 < 0 or p1 == source_ready_pos if self.ready[source_seq_id] and copied_full_ready_state: @@ -3354,6 +3273,10 @@ class ModelOptions(BaseModel): max_output_tokens: Optional[int] = Field(default=None, ge=0) kv_unified: bool = True draft_model: Optional[Literal["prompt-lookup-decoding", "draft-mtp"]] = None + draft_model_path: Optional[str] = None + draft_model_from_pretrained: Optional[ + "ConfigFile.FromPretrainedOptions" + ] = None draft_model_num_pred_tokens: int = 16 draft_model_max_ngram_size: int = 2 draft_model_top_k: int = Field(default=1, ge=1) @@ -3368,8 +3291,21 @@ class ModelOptions(BaseModel): def validate_source(self) -> "ConfigFile.ModelOptions": if (self.path is None) == (self.from_pretrained is None): raise ValueError("exactly one of model.path or model.from_pretrained is required") + if ( + self.draft_model_path is not None + and self.draft_model_from_pretrained is not None + ): + raise ValueError( + "model.draft_model_path and model.draft_model_from_pretrained " + "are mutually exclusive" + ) return self + def resolve_draft_model_path(self) -> Optional[str]: + if self.draft_model_from_pretrained is not None: + return self.draft_model_from_pretrained.resolve_model_path() + return self.draft_model_path + @field_validator("chat_template", mode="before") @classmethod def normalize_chat_template(cls, value: Any) -> Any: @@ -10549,6 +10485,7 @@ def __init__( max_seq_len: Optional[int] = None, max_output_tokens: Optional[int] = None, draft_model: Optional[str] = None, + draft_model_path: Optional[str] = None, draft_model_num_pred_tokens: int = 16, draft_model_max_ngram_size: int = 2, draft_model_top_k: int = 1, @@ -10573,6 +10510,7 @@ def __init__( self._lora_adapters: List[Any] = [] self._lora_adapter_array: Optional[Any] = None self._lora_scales_array: Optional[Any] = None + self.draft_llama_model: Optional[Any] = None model_params, self._c_tensor_split, self._kv_overrides_array = ( self.build_model_params( n_gpu_layers=n_gpu_layers, @@ -10630,11 +10568,6 @@ def __init__( "speculative decoding is only supported for attention models" ) n_ctx_train = int(llama_cpp.llama_model_n_ctx_train(llama_model)) - target_n_rs_seq = ( - max(1, draft_model_num_pred_tokens) - if normalized_draft_model == "draft-mtp" - else None - ) context_params = self.build_context_params( n_ctx=n_ctx if n_ctx is not None else n_ctx_train, @@ -10662,7 +10595,7 @@ def __init__( type_k=type_k, type_v=type_v, kv_unified=kv_unified, - n_rs_seq=target_n_rs_seq, + n_rs_seq=None, ctx_type=None, ) ctx = llama_cpp.llama_init_from_model(llama_model, context_params) @@ -10692,11 +10625,6 @@ def __init__( "MTP requires runtime n_batch to fit the pending token plus draft tokens " f"(required {required_mtp_batch}, got {self.n_batch})" ) - if target_n_rs_seq is not None and self.n_rs_seq < target_n_rs_seq: - raise RuntimeError( - "MTP requires retained recurrent-state slots for rollback " - f"(required {target_n_rs_seq}, got {self.n_rs_seq})" - ) self.n_ctx_train = n_ctx_train self.n_vocab = int(llama_cpp.llama_vocab_n_tokens(self.vocab)) self.n_embd = int(llama_cpp.llama_model_n_embd(self.llama_model)) @@ -10747,6 +10675,22 @@ def __init__( num_pred_tokens=draft_model_num_pred_tokens, ) elif normalized_draft_model == "draft-mtp": + draft_llama_model = self.llama_model + if draft_model_path is not None: + draft_llama_model = llama_cpp.llama_model_load_from_file( + draft_model_path.encode("utf-8"), + model_params, + ) + if draft_llama_model is None: + llama_cpp.llama_batch_free(self.batch) + llama_cpp.llama_free(self.ctx) + self._free_lora_adapters() + llama_cpp.llama_model_free(self.llama_model) + if self.backend_initialized: + llama_cpp.llama_backend_free() + self.backend_initialized = False + raise RuntimeError(f"failed to load MTP draft model: {draft_model_path}") + self.draft_llama_model = draft_llama_model if self.n_ubatch < self.n_seq_max: mtp_n_batch = self.n_batch else: @@ -10792,13 +10736,15 @@ def __init__( type_k=type_k, type_v=type_v, kv_unified=kv_unified, - n_rs_seq=target_n_rs_seq, + n_rs_seq=0, ctx_type=llama_cpp.LLAMA_CONTEXT_TYPE_MTP, n_outputs_max=min(mtp_n_batch, self.n_seq_max), + ctx_other=self.ctx, ) try: self.draft_provider = MTPDraftProvider( model=self, + draft_model=draft_llama_model, context_params=mtp_context_params, num_pred_tokens=draft_model_num_pred_tokens, top_k=draft_model_top_k, @@ -10808,6 +10754,9 @@ def __init__( llama_cpp.llama_batch_free(self.batch) llama_cpp.llama_free(self.ctx) self._free_lora_adapters() + if self.draft_llama_model is not None: + llama_cpp.llama_model_free(self.draft_llama_model) + self.draft_llama_model = None llama_cpp.llama_model_free(self.llama_model) if self.backend_initialized: llama_cpp.llama_backend_free() @@ -10818,7 +10767,10 @@ def __init__( try: self._load_lora_adapters(self.loras) self._apply_lora_adapters(self.ctx, "target") - if isinstance(self.draft_provider, MTPDraftProvider): + if ( + isinstance(self.draft_provider, MTPDraftProvider) + and self.draft_llama_model is None + ): self._apply_lora_adapters(self.draft_provider.ctx, "MTP draft") except BaseException: if self.draft_provider is not None: @@ -10826,6 +10778,9 @@ def __init__( llama_cpp.llama_batch_free(self.batch) llama_cpp.llama_free(self.ctx) self._free_lora_adapters() + if self.draft_llama_model is not None: + llama_cpp.llama_model_free(self.draft_llama_model) + self.draft_llama_model = None llama_cpp.llama_model_free(self.llama_model) if self.backend_initialized: llama_cpp.llama_backend_free() @@ -10938,6 +10893,7 @@ def build_context_params( n_rs_seq: Optional[int] = None, ctx_type: Optional[int] = None, n_outputs_max: Optional[int] = None, + ctx_other: Optional[Any] = None, ) -> Any: context_params = llama_cpp.llama_context_default_params() if n_ctx is not None: @@ -10958,6 +10914,8 @@ def build_context_params( context_params.ctx_type = ctx_type if n_outputs_max is not None: context_params.n_outputs_max = n_outputs_max + if ctx_other is not None: + context_params.ctx_other = ctx_other if rope_scaling_type is not None: context_params.rope_scaling_type = rope_scaling_type if pooling_type is not None: @@ -11070,6 +11028,9 @@ def close(self) -> None: llama_cpp.llama_batch_free(self.batch) llama_cpp.llama_free(self.ctx) self._free_lora_adapters() + if self.draft_llama_model is not None: + llama_cpp.llama_model_free(self.draft_llama_model) + self.draft_llama_model = None llama_cpp.llama_model_free(self.llama_model) if self.backend_initialized: llama_cpp.llama_backend_free() @@ -15649,6 +15610,7 @@ def main() -> None: max_seq_len=config.model.max_seq_len, max_output_tokens=config.model.max_output_tokens, draft_model=config.model.draft_model, + draft_model_path=config.model.resolve_draft_model_path(), draft_model_num_pred_tokens=config.model.draft_model_num_pred_tokens, draft_model_max_ngram_size=config.model.draft_model_max_ngram_size, draft_model_top_k=config.model.draft_model_top_k, diff --git a/llama_cpp/llama_cpp_ext.py b/llama_cpp/llama_cpp_ext.py index f6ab9197b..284811086 100644 --- a/llama_cpp/llama_cpp_ext.py +++ b/llama_cpp/llama_cpp_ext.py @@ -42,58 +42,76 @@ def decorator(f): return decorator -# LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); +# LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); @_ctypes_function_from_names( ( - "llama_set_embeddings_pre_norm", - "_Z29llama_set_embeddings_pre_normP13llama_contextbb", - "?llama_set_embeddings_pre_norm@@YAXPEAUllama_context@@_N1@Z", + "llama_set_embeddings_nextn", + "_Z26llama_set_embeddings_nextnP13llama_contextbb", + "?llama_set_embeddings_nextn@@YAXPEAUllama_context@@_N1@Z", ), [llama_cpp.llama_context_p_ctypes, ctypes.c_bool, ctypes.c_bool], None, ) -def llama_set_embeddings_pre_norm( +def llama_set_embeddings_nextn( ctx: llama_cpp.llama_context_p, value: bool, masked: bool, /, ): - """Set whether the context outputs pre-norm embeddings or not.""" + """Set whether the context outputs nextn embeddings or not.""" ... -# LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx); +# LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); @_ctypes_function_from_names( ( - "llama_get_embeddings_pre_norm", - "_Z29llama_get_embeddings_pre_normP13llama_context", - "?llama_get_embeddings_pre_norm@@YAPEAMPEAUllama_context@@@Z", + "llama_get_embeddings_nextn", + "_Z26llama_get_embeddings_nextnP13llama_context", + "?llama_get_embeddings_nextn@@YAPEAMPEAUllama_context@@@Z", ), [llama_cpp.llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float), ) -def llama_get_embeddings_pre_norm( +def llama_get_embeddings_nextn( ctx: llama_cpp.llama_context_p, /, ): - """Get the pre-norm embeddings from the last evaluation.""" + """Get the nextn embeddings from the last evaluation.""" ... -# LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); +# LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i); @_ctypes_function_from_names( ( - "llama_get_embeddings_pre_norm_ith", - "_Z33llama_get_embeddings_pre_norm_ithP13llama_contexti", - "?llama_get_embeddings_pre_norm_ith@@YAPEAMPEAUllama_context@@H@Z", + "llama_get_embeddings_nextn_ith", + "_Z30llama_get_embeddings_nextn_ithP13llama_contexti", + "?llama_get_embeddings_nextn_ith@@YAPEAMPEAUllama_context@@H@Z", ), [llama_cpp.llama_context_p_ctypes, ctypes.c_int32], ctypes.POINTER(ctypes.c_float), ) -def llama_get_embeddings_pre_norm_ith( +def llama_get_embeddings_nextn_ith( ctx: llama_cpp.llama_context_p, i: Union[ctypes.c_int32, int], /, ): - """Get the pre-norm embeddings for the ith output row from the last evaluation.""" + """Get the nextn embeddings for the ith output row from the last evaluation.""" + ... + + +# LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx); +@_ctypes_function_from_names( + ( + "llama_get_ctx_other", + "_Z19llama_get_ctx_otherP13llama_context", + "?llama_get_ctx_other@@YAPEAUllama_context@@PEAU1@@Z", + ), + [llama_cpp.llama_context_p_ctypes], + llama_cpp.llama_context_p_ctypes, +) +def llama_get_ctx_other( + ctx: llama_cpp.llama_context_p, + /, +): + """Get the context linked through llama_context_params.ctx_other.""" ...