Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/tether/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4260,12 +4260,28 @@ def go(
from tether import __version__ as _current_tether_version
except Exception: # noqa: BLE001
_current_tether_version = "unknown"
try:
import torch as _torch_cache
_current_torch_version = str(
getattr(_torch_cache, "__version__", "unknown") or "unknown"
)
except Exception: # noqa: BLE001
_current_torch_version = "unknown"
try:
import onnxruntime as _ort_cache
_current_ort_version = str(
getattr(_ort_cache, "__version__", "unknown") or "unknown"
)
except Exception: # noqa: BLE001
_current_ort_version = "unknown"
if meta_marker.exists():
try:
import json as _json_cache
meta = _json_cache.loads(meta_marker.read_text())
cached_version = meta.get("tether_version", "?")
cached_target = meta.get("export_target", "?")
cached_torch_version = meta.get("torch_version", "unknown")
cached_ort_version = meta.get("ort_version", "unknown")
if cached_version != _current_tether_version:
console.print(
f"[yellow]⚠ Cache stale[/yellow]: built by tether {cached_version}, "
Expand All @@ -4276,6 +4292,18 @@ def go(
f"[yellow]⚠ Cache target mismatch[/yellow]: built for "
f"{cached_target}, you need {export_target}. Rebuilding."
)
elif cached_torch_version != _current_torch_version:
console.print(
f"[yellow]⚠ Cache torch version mismatch[/yellow]: built with "
f"torch {cached_torch_version}, you're on "
f"{_current_torch_version}. Rebuilding."
)
elif cached_ort_version != _current_ort_version:
console.print(
f"[yellow]⚠ Cache ORT version mismatch[/yellow]: built with "
f"onnxruntime {cached_ort_version}, you're on "
f"{_current_ort_version}. Rebuilding."
)
else:
cache_valid = True
except Exception as _exc: # noqa: BLE001
Expand Down Expand Up @@ -4338,8 +4366,24 @@ def go(
import json as _json_meta
from tether import __version__ as _current_tether_version
from datetime import datetime as _dt
try:
import torch as _torch_meta
_torch_version = str(
getattr(_torch_meta, "__version__", "unknown") or "unknown"
)
except Exception: # noqa: BLE001
_torch_version = "unknown"
try:
import onnxruntime as _ort_meta
_ort_version = str(
getattr(_ort_meta, "__version__", "unknown") or "unknown"
)
except Exception: # noqa: BLE001
_ort_version = "unknown"
meta = {
"tether_version": _current_tether_version,
"torch_version": _torch_version,
"ort_version": _ort_version,
"model_id": entry.model_id,
"export_target": export_target,
"export_mode": "monolithic",
Expand Down
168 changes: 168 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""Tests for CLI smoke tests."""

import json
import sys
import types
from pathlib import Path
from unittest.mock import Mock, patch

from typer.testing import CliRunner

from tether import __version__
Expand All @@ -8,6 +14,61 @@
runner = CliRunner()


def _fake_runtime_modules(torch_version="2.7.1", ort_version="1.25.1"):
torch = types.ModuleType("torch")
torch.__version__ = torch_version
ort = types.ModuleType("onnxruntime")
ort.__version__ = ort_version
return {"torch": torch, "onnxruntime": ort}


def _seed_go_model_cache(tmp_path):
target = tmp_path / "model_cache"
target.mkdir()
(target / "weights.bin").write_text("stub")
return target


def _seed_go_export_cache(tmp_path, meta):
export_dir = tmp_path / "tether_cache" / "exports" / "smolvla-base"
export_dir.mkdir(parents=True)
(export_dir / "VERIFICATION.md").write_text("# stub")
(export_dir / "_tether_meta.json").write_text(json.dumps(meta))
return export_dir


def _fake_export(model_path, output_dir, num_steps=10, target=None):
export_dir = Path(output_dir)
export_dir.mkdir(parents=True, exist_ok=True)
(export_dir / "VERIFICATION.md").write_text("# stub")
return {"onnx_path": str(export_dir / "model.onnx"), "size_mb": 100.0}


def _invoke_go_with_export_cache(tmp_path, monkeypatch, export_mock):
target = _seed_go_model_cache(tmp_path)
monkeypatch.setenv("TETHER_HOME", str(tmp_path / "tether_cache"))
server = types.ModuleType("tether.runtime.server")
server.create_app = Mock(side_effect=RuntimeError("serve-stub"))
server.TetherServer = object

with (
patch("tether.exporters.monolithic.export_monolithic", export_mock),
patch.dict("sys.modules", {"tether.runtime.server": server}),
):
return runner.invoke(
app,
[
"go",
"--model",
"smolvla-base",
"--device-class",
"a10g",
"--target-dir",
str(target),
],
)


def test_help():
result = runner.invoke(app, ["--help"])
assert result.exit_code == 0
Expand Down Expand Up @@ -117,3 +178,110 @@ def test_serve_help():
def test_serve_missing_dir():
result = runner.invoke(app, ["serve", "/nonexistent/path"])
assert result.exit_code == 1


def test_go_export_cache_accepts_current_torch_and_ort_versions(tmp_path, monkeypatch):
meta = {
"tether_version": __version__,
"torch_version": "2.7.1",
"ort_version": "1.25.1",
"model_id": "smolvla-base",
"export_target": "desktop",
"export_mode": "monolithic",
}
_seed_go_export_cache(tmp_path, meta)
export_mock = Mock(side_effect=_fake_export)

with patch.dict("sys.modules", _fake_runtime_modules()):
result = _invoke_go_with_export_cache(tmp_path, monkeypatch, export_mock)

assert result.exit_code == 1
assert "export hit:" in result.output
export_mock.assert_not_called()


def test_go_export_cache_rebuilds_when_torch_version_changes(tmp_path, monkeypatch):
export_dir = _seed_go_export_cache(
tmp_path,
{
"tether_version": __version__,
"torch_version": "2.0.0",
"ort_version": "1.25.1",
"model_id": "smolvla-base",
"export_target": "desktop",
"export_mode": "monolithic",
},
)
export_mock = Mock(side_effect=_fake_export)

with patch.dict("sys.modules", _fake_runtime_modules()):
result = _invoke_go_with_export_cache(tmp_path, monkeypatch, export_mock)

assert result.exit_code == 1
assert "Cache torch version mismatch" in result.output
export_mock.assert_called_once()
meta = json.loads((export_dir / "_tether_meta.json").read_text())
assert meta["torch_version"] == "2.7.1"
assert meta["ort_version"] == "1.25.1"


def test_go_export_cache_rebuilds_when_ort_version_changes(tmp_path, monkeypatch):
_seed_go_export_cache(
tmp_path,
{
"tether_version": __version__,
"torch_version": "2.7.1",
"ort_version": "1.20.0",
"model_id": "smolvla-base",
"export_target": "desktop",
"export_mode": "monolithic",
},
)
export_mock = Mock(side_effect=_fake_export)

with patch.dict("sys.modules", _fake_runtime_modules()):
result = _invoke_go_with_export_cache(tmp_path, monkeypatch, export_mock)

assert result.exit_code == 1
assert "Cache ORT version mismatch" in result.output
export_mock.assert_called_once()


def test_go_export_cache_rebuilds_legacy_meta_without_runtime_versions(tmp_path, monkeypatch):
_seed_go_export_cache(
tmp_path,
{
"tether_version": __version__,
"model_id": "smolvla-base",
"export_target": "desktop",
"export_mode": "monolithic",
},
)
export_mock = Mock(side_effect=_fake_export)

with patch.dict("sys.modules", _fake_runtime_modules()):
result = _invoke_go_with_export_cache(tmp_path, monkeypatch, export_mock)

assert result.exit_code == 1
assert "Cache torch version mismatch" in result.output
assert "torch unknown" in result.output
export_mock.assert_called_once()


def test_go_export_meta_write_records_unknown_when_runtime_imports_fail(tmp_path, monkeypatch):
def fake_export_and_hide_runtime_versions(model_path, output_dir, num_steps=10, target=None):
result = _fake_export(model_path, output_dir, num_steps=num_steps, target=target)
monkeypatch.setitem(sys.modules, "torch", None)
monkeypatch.setitem(sys.modules, "onnxruntime", None)
return result

export_mock = Mock(side_effect=fake_export_and_hide_runtime_versions)

result = _invoke_go_with_export_cache(tmp_path, monkeypatch, export_mock)

assert result.exit_code == 1
export_mock.assert_called_once()
meta_path = tmp_path / "tether_cache" / "exports" / "smolvla-base" / "_tether_meta.json"
meta = json.loads(meta_path.read_text())
assert meta["torch_version"] == "unknown"
assert meta["ort_version"] == "unknown"