From 40dc260a31d50e6cea87ec7eb83816ab5e4a715d Mon Sep 17 00:00:00 2001 From: Matt Van Horn <455140+mvanhorn@users.noreply.github.com> Date: Sun, 7 Jun 2026 09:14:46 -0700 Subject: [PATCH] feat(cli): record torch/onnxruntime versions in export cache and warn on mismatch Stores build-time torch and onnxruntime versions in cache metadata and warns at load when they differ from the current environment, so a stale/incompatible cached export isn't reused silently. Closes #47 Co-Authored-By: Claude Opus 4.8 (1M context) --- src/tether/cli.py | 44 ++++++++++++ tests/test_cli.py | 168 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 212 insertions(+) diff --git a/src/tether/cli.py b/src/tether/cli.py index faae252..f1a520d 100644 --- a/src/tether/cli.py +++ b/src/tether/cli.py @@ -4260,12 +4260,28 @@ def go( from tether import __version__ as _current_tether_version except Exception: # noqa: BLE001 _current_tether_version = "unknown" + try: + import torch as _torch_cache + _current_torch_version = str( + getattr(_torch_cache, "__version__", "unknown") or "unknown" + ) + except Exception: # noqa: BLE001 + _current_torch_version = "unknown" + try: + import onnxruntime as _ort_cache + _current_ort_version = str( + getattr(_ort_cache, "__version__", "unknown") or "unknown" + ) + except Exception: # noqa: BLE001 + _current_ort_version = "unknown" if meta_marker.exists(): try: import json as _json_cache meta = _json_cache.loads(meta_marker.read_text()) cached_version = meta.get("tether_version", "?") cached_target = meta.get("export_target", "?") + cached_torch_version = meta.get("torch_version", "unknown") + cached_ort_version = meta.get("ort_version", "unknown") if cached_version != _current_tether_version: console.print( f"[yellow]⚠ Cache stale[/yellow]: built by tether {cached_version}, " @@ -4276,6 +4292,18 @@ def go( f"[yellow]⚠ Cache target mismatch[/yellow]: built for " f"{cached_target}, you need {export_target}. Rebuilding." ) + elif cached_torch_version != _current_torch_version: + console.print( + f"[yellow]⚠ Cache torch version mismatch[/yellow]: built with " + f"torch {cached_torch_version}, you're on " + f"{_current_torch_version}. Rebuilding." + ) + elif cached_ort_version != _current_ort_version: + console.print( + f"[yellow]⚠ Cache ORT version mismatch[/yellow]: built with " + f"onnxruntime {cached_ort_version}, you're on " + f"{_current_ort_version}. Rebuilding." + ) else: cache_valid = True except Exception as _exc: # noqa: BLE001 @@ -4338,8 +4366,24 @@ def go( import json as _json_meta from tether import __version__ as _current_tether_version from datetime import datetime as _dt + try: + import torch as _torch_meta + _torch_version = str( + getattr(_torch_meta, "__version__", "unknown") or "unknown" + ) + except Exception: # noqa: BLE001 + _torch_version = "unknown" + try: + import onnxruntime as _ort_meta + _ort_version = str( + getattr(_ort_meta, "__version__", "unknown") or "unknown" + ) + except Exception: # noqa: BLE001 + _ort_version = "unknown" meta = { "tether_version": _current_tether_version, + "torch_version": _torch_version, + "ort_version": _ort_version, "model_id": entry.model_id, "export_target": export_target, "export_mode": "monolithic", diff --git a/tests/test_cli.py b/tests/test_cli.py index 11e29ae..91f169c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,11 @@ """Tests for CLI smoke tests.""" +import json +import sys +import types +from pathlib import Path +from unittest.mock import Mock, patch + from typer.testing import CliRunner from tether import __version__ @@ -8,6 +14,61 @@ runner = CliRunner() +def _fake_runtime_modules(torch_version="2.7.1", ort_version="1.25.1"): + torch = types.ModuleType("torch") + torch.__version__ = torch_version + ort = types.ModuleType("onnxruntime") + ort.__version__ = ort_version + return {"torch": torch, "onnxruntime": ort} + + +def _seed_go_model_cache(tmp_path): + target = tmp_path / "model_cache" + target.mkdir() + (target / "weights.bin").write_text("stub") + return target + + +def _seed_go_export_cache(tmp_path, meta): + export_dir = tmp_path / "tether_cache" / "exports" / "smolvla-base" + export_dir.mkdir(parents=True) + (export_dir / "VERIFICATION.md").write_text("# stub") + (export_dir / "_tether_meta.json").write_text(json.dumps(meta)) + return export_dir + + +def _fake_export(model_path, output_dir, num_steps=10, target=None): + export_dir = Path(output_dir) + export_dir.mkdir(parents=True, exist_ok=True) + (export_dir / "VERIFICATION.md").write_text("# stub") + return {"onnx_path": str(export_dir / "model.onnx"), "size_mb": 100.0} + + +def _invoke_go_with_export_cache(tmp_path, monkeypatch, export_mock): + target = _seed_go_model_cache(tmp_path) + monkeypatch.setenv("TETHER_HOME", str(tmp_path / "tether_cache")) + server = types.ModuleType("tether.runtime.server") + server.create_app = Mock(side_effect=RuntimeError("serve-stub")) + server.TetherServer = object + + with ( + patch("tether.exporters.monolithic.export_monolithic", export_mock), + patch.dict("sys.modules", {"tether.runtime.server": server}), + ): + return runner.invoke( + app, + [ + "go", + "--model", + "smolvla-base", + "--device-class", + "a10g", + "--target-dir", + str(target), + ], + ) + + def test_help(): result = runner.invoke(app, ["--help"]) assert result.exit_code == 0 @@ -117,3 +178,110 @@ def test_serve_help(): def test_serve_missing_dir(): result = runner.invoke(app, ["serve", "/nonexistent/path"]) assert result.exit_code == 1 + + +def test_go_export_cache_accepts_current_torch_and_ort_versions(tmp_path, monkeypatch): + meta = { + "tether_version": __version__, + "torch_version": "2.7.1", + "ort_version": "1.25.1", + "model_id": "smolvla-base", + "export_target": "desktop", + "export_mode": "monolithic", + } + _seed_go_export_cache(tmp_path, meta) + export_mock = Mock(side_effect=_fake_export) + + with patch.dict("sys.modules", _fake_runtime_modules()): + result = _invoke_go_with_export_cache(tmp_path, monkeypatch, export_mock) + + assert result.exit_code == 1 + assert "export hit:" in result.output + export_mock.assert_not_called() + + +def test_go_export_cache_rebuilds_when_torch_version_changes(tmp_path, monkeypatch): + export_dir = _seed_go_export_cache( + tmp_path, + { + "tether_version": __version__, + "torch_version": "2.0.0", + "ort_version": "1.25.1", + "model_id": "smolvla-base", + "export_target": "desktop", + "export_mode": "monolithic", + }, + ) + export_mock = Mock(side_effect=_fake_export) + + with patch.dict("sys.modules", _fake_runtime_modules()): + result = _invoke_go_with_export_cache(tmp_path, monkeypatch, export_mock) + + assert result.exit_code == 1 + assert "Cache torch version mismatch" in result.output + export_mock.assert_called_once() + meta = json.loads((export_dir / "_tether_meta.json").read_text()) + assert meta["torch_version"] == "2.7.1" + assert meta["ort_version"] == "1.25.1" + + +def test_go_export_cache_rebuilds_when_ort_version_changes(tmp_path, monkeypatch): + _seed_go_export_cache( + tmp_path, + { + "tether_version": __version__, + "torch_version": "2.7.1", + "ort_version": "1.20.0", + "model_id": "smolvla-base", + "export_target": "desktop", + "export_mode": "monolithic", + }, + ) + export_mock = Mock(side_effect=_fake_export) + + with patch.dict("sys.modules", _fake_runtime_modules()): + result = _invoke_go_with_export_cache(tmp_path, monkeypatch, export_mock) + + assert result.exit_code == 1 + assert "Cache ORT version mismatch" in result.output + export_mock.assert_called_once() + + +def test_go_export_cache_rebuilds_legacy_meta_without_runtime_versions(tmp_path, monkeypatch): + _seed_go_export_cache( + tmp_path, + { + "tether_version": __version__, + "model_id": "smolvla-base", + "export_target": "desktop", + "export_mode": "monolithic", + }, + ) + export_mock = Mock(side_effect=_fake_export) + + with patch.dict("sys.modules", _fake_runtime_modules()): + result = _invoke_go_with_export_cache(tmp_path, monkeypatch, export_mock) + + assert result.exit_code == 1 + assert "Cache torch version mismatch" in result.output + assert "torch unknown" in result.output + export_mock.assert_called_once() + + +def test_go_export_meta_write_records_unknown_when_runtime_imports_fail(tmp_path, monkeypatch): + def fake_export_and_hide_runtime_versions(model_path, output_dir, num_steps=10, target=None): + result = _fake_export(model_path, output_dir, num_steps=num_steps, target=target) + monkeypatch.setitem(sys.modules, "torch", None) + monkeypatch.setitem(sys.modules, "onnxruntime", None) + return result + + export_mock = Mock(side_effect=fake_export_and_hide_runtime_versions) + + result = _invoke_go_with_export_cache(tmp_path, monkeypatch, export_mock) + + assert result.exit_code == 1 + export_mock.assert_called_once() + meta_path = tmp_path / "tether_cache" / "exports" / "smolvla-base" / "_tether_meta.json" + meta = json.loads(meta_path.read_text()) + assert meta["torch_version"] == "unknown" + assert meta["ort_version"] == "unknown"