Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 93 additions & 21 deletions centml/cli/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from typing import Dict
import click
from tabulate import tabulate
from centml.sdk import DeploymentType, DeploymentStatus, ServiceStatus, ApiException, HardwareInstanceResponse
from centml.sdk import (
DeploymentType,
DeploymentStatus,
ServiceStatus,
RolloutStatus,
ApiException,
HardwareInstanceResponse,
)
from centml.sdk.api import get_centml_client

# convert deployment type enum to a user friendly name
Expand All @@ -27,6 +34,12 @@
"compute": DeploymentType.COMPUTE_V2,
"rag": DeploymentType.RAG,
}
rollout_status_to_service_status_map = {
RolloutStatus.HEALTHY: ServiceStatus.HEALTHY,
RolloutStatus.MISSING: ServiceStatus.MISSING,
RolloutStatus.PROGRESSING: ServiceStatus.INITIALIZING,
RolloutStatus.DEGRADED: ServiceStatus.ERROR,
}


def handle_exception(func):
Expand Down Expand Up @@ -75,18 +88,18 @@ def _get_replica_info(deployment):
return {"min": "N/A", "max": "N/A"}


def _get_ready_status(cclient, deployment):
def _get_ready_status(deployment, service_status):
api_status = deployment.status
service_status = (
cclient.get_status(deployment.id).service_status if deployment.status == DeploymentStatus.ACTIVE else None
)

status_styles = {
(DeploymentStatus.PAUSED, None): ("paused", "yellow", "black"),
(DeploymentStatus.DELETED, None): ("deleted", "white", "black"),
(DeploymentStatus.ACTIVE, ServiceStatus.HEALTHY): ("ready", "green", "black"),
(DeploymentStatus.ACTIVE, ServiceStatus.SCALINGUP): ("starting", "black", "white"),
(DeploymentStatus.ACTIVE, ServiceStatus.PULLING): ("starting", "black", "white"),
(DeploymentStatus.ACTIVE, ServiceStatus.INITIALIZING): ("starting", "black", "white"),
(DeploymentStatus.ACTIVE, ServiceStatus.MISSING): ("starting", "black", "white"),
(DeploymentStatus.ACTIVE, ServiceStatus.NOTREADY): ("starting", "black", "white"),
(DeploymentStatus.ACTIVE, ServiceStatus.ERROR): ("error", "red", "black"),
(DeploymentStatus.ACTIVE, ServiceStatus.CREATECONTAINERCONFIGERROR): (
"createContainerConfigError",
Expand All @@ -103,6 +116,62 @@ def _get_ready_status(cclient, deployment):
return click.style(style[0], fg=style[1], bg=style[2])


def _get_service_status(status_response, revision_number):
if status_response is None:
return None

service_status = getattr(status_response, "service_status", None)
if service_status is not None:
return service_status

revision_pod_details_list = getattr(status_response, "revision_pod_details_list", None) or []
current_revision = next(
(
revision
for revision in revision_pod_details_list
if getattr(revision, "revision_number", None) == revision_number
),
(
revision_pod_details_list[0]
if revision_pod_details_list and getattr(revision_pod_details_list[0], "revision_number") is None
else None
),
)
revision_status = getattr(current_revision, "revision_status", None)

return revision_status or rollout_status_to_service_status_map.get(getattr(status_response, "rollout_status", None))


def _append_status_error_message(messages, seen_messages, label, error_message):
if not error_message or error_message in seen_messages:
return

seen_messages.add(error_message)
messages.append(f"{label}: {error_message}")


def _get_status_error_messages(status_response):
if status_response is None:
return []

error_message = getattr(status_response, "error_message", None)
if error_message:
return [error_message]

messages = []
seen_messages = set()

for revision in getattr(status_response, "revision_pod_details_list", None) or []:
revision_label = f"revision {revision.revision_number}" if revision.revision_number is not None else "revision"
_append_status_error_message(messages, seen_messages, revision_label, revision.error_message)

for pod in getattr(revision, "pod_details_list", None) or []:
pod_label = pod.name or "pod"
_append_status_error_message(messages, seen_messages, f"{revision_label} / {pod_label}", pod.error_message)

return messages


@click.command(help="List all deployments")
@click.argument("type", type=click.Choice(list(depl_name_to_type_map.keys())), required=False, default=None)
def ls(type):
Expand Down Expand Up @@ -149,24 +218,27 @@ def get(type, id):
else:
sys.exit("Please enter correct deployment type")

ready_status = _get_ready_status(cclient, deployment)
deployment_status = cclient.get_status(deployment.id) if deployment.status == DeploymentStatus.ACTIVE else None
revision_number = getattr(deployment, "revision_number", None)
service_status = _get_service_status(deployment_status, revision_number)
ready_status = _get_ready_status(deployment, service_status)
status_error_messages = _get_status_error_messages(deployment_status)
_, id_to_hw_map = _get_hw_to_id_map(cclient, deployment.cluster_id)
hw = id_to_hw_map[deployment.hardware_instance_id]

click.echo(
tabulate(
[
("Name", deployment.name),
("Status", ready_status),
("Endpoint", deployment.endpoint_url),
("Created at", deployment.created_at.strftime("%Y-%m-%d %H:%M:%S")),
("Hardware", f"{hw.name} ({hw.num_gpu}x {hw.gpu_type})"),
("Cost", f"{hw.cost_per_hr / 100} credits/hr"),
],
tablefmt="rounded_outline",
disable_numparse=True,
)
)
detail_rows = [
("Name", deployment.name),
("Status", ready_status),
("Endpoint", deployment.endpoint_url),
("Created at", deployment.created_at.strftime("%Y-%m-%d %H:%M:%S")),
("Hardware", f"{hw.name} ({hw.num_gpu}x {hw.gpu_type})"),
("Cost", f"{hw.cost_per_hr / 100} credits/hr"),
]

click.echo(tabulate(detail_rows, tablefmt="rounded_outline", disable_numparse=True))
if status_error_messages:
click.echo("\nStatus errors:")
for message in status_error_messages:
click.echo(f"- {message}")

click.echo("Additional deployment configurations:")
if depl_type in [DeploymentType.INFERENCE_V2, DeploymentType.INFERENCE_V3]:
Expand Down
Empty file.
15 changes: 11 additions & 4 deletions centml/sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from centml.sdk import auth
from centml.sdk.config import settings

STATUS_V3_DEPLOYMENT_TYPES = {DeploymentType.INFERENCE_V3, DeploymentType.CSERVE_V3}


class CentMLClient:
def __init__(self, api):
Expand All @@ -26,10 +28,15 @@ def get(self, depl_type):
return deployments

def get_status(self, id):
return self._api.get_deployment_status_deployments_status_deployment_id_get(id)

def get_status_v3(self, deployment_id):
return self._api.get_deployment_status_v3_deployments_status_v3_deployment_id_get(deployment_id)
try:
return self._api.get_deployment_status_v3_deployments_status_v3_deployment_id_get(id)
except ApiException as e:
Comment thread
michaelshin marked this conversation as resolved.
Comment thread
michaelshin marked this conversation as resolved.
Comment thread
michaelshin marked this conversation as resolved.
if e.status in [404, 400]:
try:
return self._api.get_deployment_status_deployments_status_deployment_id_get(id)
except ApiException as v2_error:
raise e from v2_error
Comment thread
michaelshin marked this conversation as resolved.
raise

def get_inference(self, id):
"""Get Inference deployment details - automatically handles both V2 and V3 deployments"""
Expand Down
4 changes: 2 additions & 2 deletions centml/sdk/shell/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def build_ws_url(api_url, deployment_id, pod_name, shell_type=None):


def get_running_pods(cclient, deployment_id) -> list[str]:
status = cclient.get_status_v3(deployment_id)
status = cclient.get_status(deployment_id)
running_pods = []
for revision in status.revision_pod_details_list or []:
for revision in getattr(status, "revision_pod_details_list", None) or []:
for pod in revision.pod_details_list or []:
if pod.status == PodStatus.RUNNING and pod.name:
running_pods.append(pod.name)
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
-r requirements.txt

torch==2.8.0
black>=23.10.0
black==26.3.1
pylint>=3.0.1
pytest>=7.4.0
pytest-env>=1.1.3
Expand Down
79 changes: 79 additions & 0 deletions tests/test_cli_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from types import SimpleNamespace

from centml.cli.cluster import _get_service_status, _get_status_error_messages
from centml.sdk import RolloutStatus, ServiceStatus


def test_service_status_uses_legacy_service_status_when_present():
status_response = SimpleNamespace(service_status=ServiceStatus.HEALTHY)

assert _get_service_status(status_response, revision_number=None) == ServiceStatus.HEALTHY


def test_service_status_maps_v3_healthy_rollout_status():
status_response = SimpleNamespace(rollout_status=RolloutStatus.HEALTHY)

assert _get_service_status(status_response, revision_number=None) == ServiceStatus.HEALTHY


def test_service_status_uses_current_v3_revision_status():
status_response = SimpleNamespace(
rollout_status=RolloutStatus.PROGRESSING,
revision_pod_details_list=[
SimpleNamespace(revision_number=1, revision_status=ServiceStatus.HEALTHY),
SimpleNamespace(revision_number=2, revision_status=ServiceStatus.SCALINGUP),
],
)

assert _get_service_status(status_response, revision_number=2) == ServiceStatus.SCALINGUP


def test_service_status_uses_v3_revision_without_revision_number_as_fallback():
status_response = SimpleNamespace(
rollout_status=RolloutStatus.DEGRADED,
revision_pod_details_list=[
SimpleNamespace(revision_number=None, revision_status=ServiceStatus.IMAGEPULLBACKOFF)
],
)

assert _get_service_status(status_response, revision_number=2) == ServiceStatus.IMAGEPULLBACKOFF


def test_status_error_messages_include_revision_and_pod_messages():
status_response = SimpleNamespace(
revision_pod_details_list=[
SimpleNamespace(
revision_number=3,
error_message="revision failed",
pod_details_list=[
SimpleNamespace(name="pod-a", error_message="image pull failed"),
SimpleNamespace(name="pod-b", error_message=None),
],
)
]
)

messages = _get_status_error_messages(status_response)

assert messages == ["revision 3: revision failed", "revision 3 / pod-a: image pull failed"]


def test_status_error_messages_do_not_repeat_duplicate_messages():
duplicate_message = "one or more objects failed to apply"
status_response = SimpleNamespace(
revision_pod_details_list=[
SimpleNamespace(
revision_number=None,
error_message=duplicate_message,
pod_details_list=[SimpleNamespace(name=None, error_message=duplicate_message)],
)
]
)

assert _get_status_error_messages(status_response) == [f"revision: {duplicate_message}"]


def test_status_error_messages_include_legacy_status_message():
status_response = SimpleNamespace(error_message="legacy service failure")

assert _get_status_error_messages(status_response) == ["legacy service failure"]
45 changes: 45 additions & 0 deletions tests/test_sdk_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from types import SimpleNamespace
from unittest.mock import MagicMock

from centml.sdk import ApiException
from centml.sdk.api import CentMLClient


def test_get_status_uses_v3_endpoint():
api = MagicMock()
expected_status = SimpleNamespace()
api.get_deployment_status_v3_deployments_status_v3_deployment_id_get.return_value = expected_status

assert CentMLClient(api).get_status(123) is expected_status

api.get_deployment_status_v3_deployments_status_v3_deployment_id_get.assert_called_once_with(123)
api.get_deployment_status_deployments_status_deployment_id_get.assert_not_called()


def test_get_status_falls_back_to_legacy_endpoint_when_v3_is_not_found():
api = MagicMock()
expected_status = SimpleNamespace()
api.get_deployment_status_v3_deployments_status_v3_deployment_id_get.side_effect = ApiException(status=404)
api.get_deployment_status_deployments_status_deployment_id_get.return_value = expected_status

assert CentMLClient(api).get_status(123) is expected_status

api.get_deployment_status_v3_deployments_status_v3_deployment_id_get.assert_called_once_with(123)
api.get_deployment_status_deployments_status_deployment_id_get.assert_called_once_with(123)


def test_get_status_raises_v3_error_when_both_status_endpoints_fail():
api = MagicMock()
v3_error = ApiException(status=404)
api.get_deployment_status_v3_deployments_status_v3_deployment_id_get.side_effect = v3_error
api.get_deployment_status_deployments_status_deployment_id_get.side_effect = ApiException(status=404)

try:
CentMLClient(api).get_status(123)
except ApiException as e:
assert e is v3_error
else:
raise AssertionError("Expected ApiException")

api.get_deployment_status_v3_deployments_status_v3_deployment_id_get.assert_called_once_with(123)
api.get_deployment_status_deployments_status_deployment_id_get.assert_called_once_with(123)
Loading