Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion centml/cli/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
DeploymentType.CSERVE_V2: "cserve",
DeploymentType.CSERVE_V3: "cserve",
DeploymentType.RAG: "rag",
DeploymentType.JOB: "job",
}
# Use the latest deployment type for user requests: each user-facing type name
# maps to a single DeploymentType member.
depl_name_to_type_map = {
"inference": DeploymentType.INFERENCE_V3,
"cserve": DeploymentType.CSERVE_V3,
"compute": DeploymentType.COMPUTE_V2,
"rag": DeploymentType.RAG,
"job": DeploymentType.JOB,
}
rollout_status_to_service_status_map = {
RolloutStatus.HEALTHY: ServiceStatus.HEALTHY,
Expand Down Expand Up @@ -100,6 +102,9 @@ def _get_ready_status(deployment, service_status):
(DeploymentStatus.ACTIVE, ServiceStatus.INITIALIZING): ("starting", "black", "white"),
(DeploymentStatus.ACTIVE, ServiceStatus.MISSING): ("starting", "black", "white"),
(DeploymentStatus.ACTIVE, ServiceStatus.NOTREADY): ("starting", "black", "white"),
(DeploymentStatus.ACTIVE, ServiceStatus.COMPLETED): ("completed", "green", "black"),
(DeploymentStatus.ACTIVE, ServiceStatus.CLEANEDUP): ("cleanedUp", "white", "black"),
(DeploymentStatus.ACTIVE, ServiceStatus.FAILED): ("failed", "red", "black"),
(DeploymentStatus.ACTIVE, ServiceStatus.ERROR): ("error", "red", "black"),
(DeploymentStatus.ACTIVE, ServiceStatus.CREATECONTAINERCONFIGERROR): (
"createContainerConfigError",
Expand Down Expand Up @@ -215,6 +220,8 @@ def get(type, id):
deployment = cclient.get_compute(id)
elif depl_type in [DeploymentType.CSERVE_V2, DeploymentType.CSERVE_V3]:
deployment = cclient.get_cserve(id) # handles both V2 and V3
elif depl_type == DeploymentType.JOB:
deployment = cclient.get_job(id)
Comment thread
michaelshin marked this conversation as resolved.
else:
sys.exit("Please enter correct deployment type")

Expand All @@ -228,11 +235,12 @@ def get(type, id):
detail_rows = [
("Name", deployment.name),
("Status", ready_status),
("Endpoint", deployment.endpoint_url),
("Created at", deployment.created_at.strftime("%Y-%m-%d %H:%M:%S")),
("Hardware", f"{hw.name} ({hw.num_gpu}x {hw.gpu_type})"),
("Cost", f"{hw.cost_per_hr / 100} credits/hr"),
]
if depl_type != DeploymentType.JOB:
detail_rows.insert(2, ("Endpoint", deployment.endpoint_url))

click.echo(tabulate(detail_rows, tablefmt="rounded_outline", disable_numparse=True))
if status_error_messages:
Expand Down Expand Up @@ -277,6 +285,17 @@ def get(type, id):
]

click.echo(tabulate(display_rows, tablefmt="rounded_outline", disable_numparse=True))
elif depl_type == DeploymentType.JOB:
display_rows = [
("Image", deployment.image_url),
("Command", deployment.original_command or "None"),
Comment thread
michaelshin marked this conversation as resolved.
("Environment variables", deployment.env_vars or "None"),
("Completions", deployment.completions),
("Parallelism", deployment.parallelism),
("Logging", deployment.enable_logging),
]

click.echo(tabulate(display_rows, tablefmt="rounded_outline", disable_numparse=True))


@click.command(help="Delete a deployment")
Expand Down
7 changes: 7 additions & 0 deletions centml/sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CreateInferenceV3DeploymentRequest,
CreateComputeDeploymentRequest,
CreateCServeV3DeploymentRequest,
CreateJobDeploymentRequest,
ApiException,
InviteUserRequest,
Metric,
Expand Down Expand Up @@ -58,6 +59,9 @@ def get_inference(self, id):
def get_compute(self, id):
    """Return the platform API's compute-deployment GET response for ``id``."""
    fetch_compute = self._api.get_compute_deployment_deployments_compute_deployment_id_get
    return fetch_compute(id)

def get_job(self, id):
    """Return the platform API's job-deployment GET response for ``id``."""
    fetch_job = self._api.get_job_deployment_deployments_job_deployment_id_get
    return fetch_job(id)

def get_cserve(self, id):
"""Get CServe deployment details - automatically handles both V2 and V3 deployments"""
# Try V3 first (recommended), fallback to V2 if deployment is V2
Expand All @@ -81,6 +85,9 @@ def create_inference(self, request: CreateInferenceV3DeploymentRequest):
def create_compute(self, request: CreateComputeDeploymentRequest):
    """Submit ``request`` to the platform API's compute-deployment POST call and return its response."""
    submit_compute = self._api.create_compute_deployment_deployments_compute_post
    return submit_compute(request)

def create_job(self, request: CreateJobDeploymentRequest):
    """Submit ``request`` to the platform API's job-deployment POST call and return its response."""
    submit_job = self._api.create_job_deployment_deployments_job_post
    return submit_job(request)

def create_cserve(self, request: CreateCServeV3DeploymentRequest):
    """Submit ``request`` to the platform API's CServe V3 deployment POST call and return its response."""
    submit_cserve_v3 = self._api.create_cserve_v3_deployment_deployments_cserve_v3_post
    return submit_cserve_v3(request)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ pyjwt>=2.8.0
cryptography==46.0.7
websockets>=16.0
pyte>=0.8.0
platform-api-python-client==4.9.0
platform-api-python-client==4.10.0
104 changes: 102 additions & 2 deletions tests/test_cli_cluster.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from contextlib import contextmanager
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

from centml.cli.cluster import _get_service_status, _get_status_error_messages
from centml.sdk import RolloutStatus, ServiceStatus
from click.testing import CliRunner

from centml.cli.cluster import _get_ready_status, _get_service_status, _get_status_error_messages
from centml.sdk import DeploymentStatus, DeploymentType, RolloutStatus, ServiceStatus


def test_service_status_uses_legacy_service_status_when_present():
Expand Down Expand Up @@ -77,3 +82,98 @@ def test_status_error_messages_include_legacy_status_message():
status_response = SimpleNamespace(error_message="legacy service failure")

assert _get_status_error_messages(status_response) == ["legacy service failure"]


def test_ready_status_supports_completed_service_status():
    """ACTIVE deployment + COMPLETED service yields a status containing "completed"."""
    active_deployment = SimpleNamespace(status=DeploymentStatus.ACTIVE)
    rendered = _get_ready_status(active_deployment, ServiceStatus.COMPLETED)

    assert "completed" in rendered


def test_ready_status_supports_cleaned_up_service_status():
    """ACTIVE deployment + CLEANEDUP service yields a status containing "cleanedUp"."""
    active_deployment = SimpleNamespace(status=DeploymentStatus.ACTIVE)
    rendered = _get_ready_status(active_deployment, ServiceStatus.CLEANEDUP)

    assert "cleanedUp" in rendered


def test_ready_status_supports_failed_service_status():
    """ACTIVE deployment + FAILED service yields a status containing "failed"."""
    active_deployment = SimpleNamespace(status=DeploymentStatus.ACTIVE)
    rendered = _get_ready_status(active_deployment, ServiceStatus.FAILED)

    assert "failed" in rendered


@contextmanager
def _patch_cluster_client():
    """Patch ``centml.cli.cluster.get_centml_client`` and yield a mock client.

    While the context is active, the patched factory returns a mock context
    manager whose ``__enter__`` hands back the yielded client, so tests can
    configure return values and inspect calls on it.
    """
    mock_client = MagicMock()

    client_cm = MagicMock()
    client_cm.__enter__.return_value = mock_client
    # A falsy __exit__ result means exceptions raised inside the managed
    # block are not suppressed.
    client_cm.__exit__.return_value = False

    factory_target = "centml.cli.cluster.get_centml_client"
    with patch(factory_target, return_value=client_cm):
        yield mock_client


def _deployment(**overrides):
    """Build a stand-in job deployment object for the CLI tests.

    Returns a ``SimpleNamespace`` carrying the default attributes below; each
    keyword argument replaces (or adds) the attribute of the same name.
    """
    job_attrs = dict(
        id=123,
        name="test-job",
        type=DeploymentType.JOB,
        status=DeploymentStatus.PAUSED,
        created_at=datetime(2026, 1, 2, 3, 4, 5, tzinfo=timezone.utc),
        cluster_id=1,
        hardware_instance_id=2,
        endpoint_url="https://jobs.example.com/test-job",
        image_url="registry.example.com/job:latest",
        command=["python", "main.py"],
        args=["--epochs", "1"],
        original_command="python main.py --epochs 1",
        env_vars={"ENV": "test"},
        completions=1,
        parallelism=1,
        enable_logging=True,
    )
    return SimpleNamespace(**{**job_attrs, **overrides})


def test_ls_accepts_job_type_and_displays_jobs():
    """``ls job`` queries the client for JOB deployments and prints the returned job."""
    from centml.cli.cluster import ls

    job = _deployment()

    with _patch_cluster_client() as client:
        client.get.return_value = [job]
        outcome = CliRunner().invoke(ls, ["job"])

    assert outcome.exit_code == 0
    client.get.assert_called_once_with(DeploymentType.JOB)

    printed = outcome.output
    assert "test-job" in printed
    # NOTE(review): "job" is a substring of the default name "test-job", so
    # this check cannot fail once the assertion above has passed.
    assert "job" in printed


def test_get_job_routes_to_job_api_and_displays_job_config():
    """``get job 123`` fetches through ``get_job`` and prints job details without an endpoint."""
    from centml.cli.cluster import get

    active_job = _deployment(status=DeploymentStatus.ACTIVE)
    h100_node = SimpleNamespace(id=2, name="h100", num_gpu=8, gpu_type="H100", cost_per_hr=1200)
    healthy_status = SimpleNamespace(service_status=ServiceStatus.HEALTHY)

    with _patch_cluster_client() as client:
        client.get_job.return_value = active_job
        client.get_status.return_value = healthy_status
        client.get_hardware_instances.return_value = [h100_node]

        outcome = CliRunner().invoke(get, ["job", "123"])

    assert outcome.exit_code == 0
    client.get_job.assert_called_once_with(123)
    client.get_status.assert_called_once_with(123)

    printed = outcome.output
    # Name and status come first, matching the order of the detail table.
    assert "test-job" in printed
    assert "ready" in printed
    # Jobs must not show an endpoint row or its URL.
    for hidden in ("Endpoint", "https://jobs.example.com/test-job"):
        assert hidden not in printed
    # Job-specific configuration rows are shown.
    for shown in ("registry.example.com/job:latest", "Completions", "Parallelism"):
        assert shown in printed
29 changes: 29 additions & 0 deletions tests/test_sdk_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from types import SimpleNamespace
from unittest.mock import MagicMock

from platform_api_python_client import CreateJobDeploymentRequest

from centml.sdk import ApiException
from centml.sdk.api import CentMLClient

Expand Down Expand Up @@ -43,3 +45,30 @@ def test_get_status_raises_v3_error_when_both_status_endpoints_fail():

api.get_deployment_status_v3_deployments_status_v3_deployment_id_get.assert_called_once_with(123)
api.get_deployment_status_deployments_status_deployment_id_get.assert_called_once_with(123)


def test_get_job_delegates_to_platform_client():
    """``CentMLClient.get_job`` passes the id to the platform API and returns its response."""
    platform_api = MagicMock()
    # Child mocks are cached per attribute, so this alias is the same object
    # the client reaches through ``platform_api``.
    job_endpoint = platform_api.get_job_deployment_deployments_job_deployment_id_get
    sentinel_response = MagicMock()
    job_endpoint.return_value = sentinel_response

    result = CentMLClient(platform_api).get_job(123)

    assert result is sentinel_response
    job_endpoint.assert_called_once_with(123)


def test_create_job_delegates_to_platform_client():
    """``CentMLClient.create_job`` passes the request to the platform API and returns its response."""
    platform_api = MagicMock()
    # Child mocks are cached per attribute, so this alias is the same object
    # the client reaches through ``platform_api``.
    create_endpoint = platform_api.create_job_deployment_deployments_job_post
    sentinel_response = MagicMock()
    create_endpoint.return_value = sentinel_response
    job_request = CreateJobDeploymentRequest(
        name="test-job",
        cluster_id=1,
        hardware_instance_id=2,
        image_url="registry.example.com/job:latest",
    )

    result = CentMLClient(platform_api).create_job(job_request)

    assert result is sentinel_response
    create_endpoint.assert_called_once_with(job_request)
Loading