From c99e33f0ae8ffbbd9ebe5988e70eb396c407d5f6 Mon Sep 17 00:00:00 2001 From: Cathryn Lavery Date: Tue, 5 May 2026 13:00:21 -0500 Subject: [PATCH] Auto-detect NeMo model type during conversion --- README.md | 2 +- scripts/convert_nemo.py | 53 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 08aa69a..a6fc490 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ pip install safetensors torch python scripts/convert_nemo.py parakeet-tdt_ctc-110m.nemo -o model.safetensors ``` -The converter supports all model types: `110m-tdt-ctc` (default), `600m-tdt`, `eou-120m`, `nemotron-600m`, `sortformer`. +The converter auto-detects the model type by default. It also supports explicit model types: `110m-tdt-ctc`, `600m-tdt`, `eou-120m`, `nemotron-600m`, `sortformer`. ```bash python scripts/convert_nemo.py checkpoint.nemo -o model.safetensors --model 600m-tdt diff --git a/scripts/convert_nemo.py b/scripts/convert_nemo.py index 4289f4e..8946cc0 100644 --- a/scripts/convert_nemo.py +++ b/scripts/convert_nemo.py @@ -19,6 +19,7 @@ """ import argparse +import re import tarfile import tempfile import sys @@ -93,6 +94,44 @@ NUM_DURATIONS = 5 +def infer_model_type(state_dict): + """Infer the model preset from checkpoint tensor shapes. + + This keeps the common path safe: converting a 600M checkpoint with the + 110M default silently produces loadable but invalid joint weights. + """ + if any(k.startswith("sortformer_modules.") for k in state_dict): + return "sortformer" + + embed = state_dict.get("decoder.prediction.embed.weight") + if embed is None: + raise ValueError( + "could not infer model type: missing decoder.prediction.embed.weight; " + "pass --model explicitly" + ) + + vocab_size = int(embed.shape[0]) + layer_indices = set() + for key in state_dict: + match = re.match(r"encoder\.layers\.(\d+)\.", key) + if match: + layer_indices.add(int(match.group(1))) + num_layers = len(layer_indices) + + if vocab_size == 8193: + return "600m-tdt" + if vocab_size == 1027: + return "eou-120m" + if vocab_size == 1025 and num_layers == 17: + return "110m-tdt-ctc" + + raise ValueError( + "could not infer model type from checkpoint " + f"(vocab_size={vocab_size}, encoder_layers={num_layers}); " + "pass --model explicitly" + ) + + # ─── NeMo → Axiom name mapping ────────────────────────────────────────────── def build_subsampling_map(axiom_prefix="encoder_"): @@ -384,8 +423,13 @@ def dump_keys(ckpt_path): print(f" {key:70s} {list(t.shape)}") -def convert(ckpt_path, output_path, model_type=DEFAULT_MODEL): +def convert(ckpt_path, output_path, model_type="auto"): """Convert NeMo checkpoint to axiom safetensors.""" + state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) + if model_type == "auto": + model_type = infer_model_type(state_dict) + print(f"Auto-detected model type: {model_type}") + preset = MODEL_PRESETS[model_type] vocab_size = preset["vocab_size"] num_durations = preset["num_durations"] @@ -397,7 +441,6 @@ def convert(ckpt_path, output_path, model_type=DEFAULT_MODEL): print(f" Encoder layers: {preset['num_layers']}, vocab: {vocab_size}, " f"LSTM layers: {num_lstm_layers}, CTC: {preset.get('has_ctc', False)}") - state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) mapping = build_full_mapping(preset) output = {} @@ -517,9 +560,9 @@ def main(): help="Output safetensors file (default: model.safetensors)") parser.add_argument("--dump", action="store_true", help="Just dump checkpoint keys and shapes") - parser.add_argument("--model", choices=list(MODEL_PRESETS.keys()), - default=DEFAULT_MODEL, - help=f"Model type (default: {DEFAULT_MODEL})") + parser.add_argument("--model", choices=["auto"] + list(MODEL_PRESETS.keys()), + default="auto", + help="Model type (default: auto-detect from checkpoint shapes)") args = parser.parse_args() ckpt_path = extract_checkpoint(args.input)