diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 23e90ed8..73654b53 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -172,7 +172,7 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), []) self.draft_model_dir = draft_model_path self.draft_config = Config.from_directory(str(draft_model_path.resolve())) - self.draft_model = Model.from_config(self.draft_config) + self.draft_model = Model.from_config(self.draft_config,component=("mtp" if (self.draft_model_dir==self.model_dir) else "text")) #TODO: expose this in a sane way, as a config or something default_ndt = self.draft_model.caps.get("default_draft_size", 4) self.draft_num_tokens = draft_args.get("draft_num_tokens", default_ndt) xlogger.info(f"Using draft model: {str(draft_model_path.resolve())}")