From 60089204b892a2b1b9331ca1dd62c42725676bf8 Mon Sep 17 00:00:00 2001 From: Rutayan Patro Date: Tue, 16 Jun 2026 00:44:38 -0400 Subject: [PATCH 1/2] feat(configurator): add ObsLeafDescriptor + structured-observation protocol Add ObsLeafDescriptor (a self-describing observation leaf: "box" of width dim, or "discrete" of size n) and a StructuredObservation Protocol that documents the optional env hooks structured_observation_descriptors() and encode_observation(). These let an env expose a named, per-leaf observation so adapters (e.g. GymnasiumAdapter) can build the matching gymnasium spaces.Dict; the hooks are duck-typed, so envs need not subclass. Both exported via cloudai.core. --- src/cloudai/configurator/env_params.py | 46 +++++++++++++++++++++++++- src/cloudai/core.py | 3 ++ tests/test_env_params.py | 29 ++++++++++++++++ 3 files changed, 77 insertions(+), 1 deletion(-) diff --git a/src/cloudai/configurator/env_params.py b/src/cloudai/configurator/env_params.py index 13db021d2..36d73bf14 100644 --- a/src/cloudai/configurator/env_params.py +++ b/src/cloudai/configurator/env_params.py @@ -32,7 +32,7 @@ import math import random from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Self @@ -76,6 +76,50 @@ def _validate_weights(self) -> Self: return self +class ObsLeafDescriptor(BaseModel): + """ + Description of one leaf of a structured (named) observation. + + A structured observation maps each observed name to a self-describing leaf + so adapters can build the matching subspace without guessing: a ``"box"`` + leaf becomes a continuous vector of width ``dim`` (e.g. a log-encoded + env_param as ``dim=2``); a ``"discrete"`` leaf becomes a categorical of + size ``n``. Stateless agents that consume the flat observation ignore this. + """ + + model_config = ConfigDict(extra="forbid") + + kind: Literal["box", "discrete"] + dim: int = 1 + n: Optional[int] = None + + @model_validator(mode="after") + def _validate(self) -> Self: + if self.dim < 1: + raise ValueError(f"ObsLeafDescriptor dim must be >= 1; got {self.dim}") + if self.kind == "discrete" and (self.n is None or self.n < 1): + raise ValueError(f"ObsLeafDescriptor(kind='discrete') requires n >= 1; got n={self.n}") + return self + + +@runtime_checkable +class StructuredObservation(Protocol): + """ + Optional env hooks that expose a structured (per-leaf) observation. + + An env opts in by returning per-leaf :class:`ObsLeafDescriptor` from + ``structured_observation_descriptors`` (``None`` keeps the flat-vector + path) and encoding a raw observation into the matching named leaves via + ``encode_observation``. ``GymnasiumAdapter`` consumes these to expose a + ``gymnasium.spaces.Dict`` observation; the hooks are duck-typed, so envs + need not subclass this Protocol. + """ + + def structured_observation_descriptors(self) -> Optional[Dict[str, ObsLeafDescriptor]]: ... + + def encode_observation(self, observation: list) -> Dict[str, Any]: ... + + @dataclasses.dataclass(frozen=True) class EnvParam: """ diff --git a/src/cloudai/core.py b/src/cloudai/core.py index 752d24972..aefcd0573 100644 --- a/src/cloudai/core.py +++ b/src/cloudai/core.py @@ -51,6 +51,7 @@ from ._core.test_scenario import METRIC_ERROR, MetricErrorSentinel, MetricValue, TestDependency, TestRun, TestScenario from .configurator.base_agent import BaseAgent, BaseAgentConfig, RewardOverrides from .configurator.cloudai_gym import CloudAIGymEnv +from .configurator.env_params import ObsLeafDescriptor, StructuredObservation from .configurator.grid_search import GridSearchAgent from .models.workload import CmdArgs, NsysConfiguration, PredictorConfig, TestDefinition from .parser import Parser @@ -85,6 +86,7 @@ "MetricValue", "MissingTestError", "NsysConfiguration", + "ObsLeafDescriptor", "Parser", "PerTestReporter", "PredictorConfig", @@ -96,6 +98,7 @@ "RewardOverrides", "Runner", "StatusReporter", + "StructuredObservation", "System", "SystemConfigParsingError", "TarballReporter", diff --git a/tests/test_env_params.py b/tests/test_env_params.py index 362a5e5d1..5e26c606e 100644 --- a/tests/test_env_params.py +++ b/tests/test_env_params.py @@ -43,6 +43,7 @@ EnvParam, EnvParams, EnvParamSpec, + ObsLeafDescriptor, write_env_params, ) from cloudai.core import TestRun @@ -375,3 +376,31 @@ def test_apply_params_set_accepts_weighted_env_param_draw() -> None: new_tr = tr.apply_params_set({}, env_params={"ball_speed": 1}) assert new_tr.test.cmd_args.ball_speed == 1 + + +# --- ObsLeafDescriptor: structured-observation leaf schema --- + + +def test_obs_leaf_descriptor_box_defaults() -> None: + leaf = ObsLeafDescriptor(kind="box", dim=2) + assert leaf.kind == "box" + assert leaf.dim == 2 + assert leaf.n is None + + +def test_obs_leaf_descriptor_discrete_requires_n() -> None: + leaf = ObsLeafDescriptor(kind="discrete", dim=1, n=3) + assert leaf.n == 3 + with pytest.raises(ValidationError, match="requires n"): + ObsLeafDescriptor(kind="discrete", dim=1) + with pytest.raises(ValidationError, match="requires n"): + ObsLeafDescriptor(kind="discrete", dim=1, n=0) + + +def test_obs_leaf_descriptor_rejects_bad_dim_and_extra_fields() -> None: + with pytest.raises(ValidationError, match="dim must be"): + ObsLeafDescriptor(kind="box", dim=0) + with pytest.raises(ValidationError): + ObsLeafDescriptor(kind="box", dim=1, unexpected=1) + with pytest.raises(ValidationError): + ObsLeafDescriptor(kind="categorical", dim=1) From 3d1d8fb9470406cac2111369f8ba0fb47c718e6b Mon Sep 17 00:00:00 2001 From: Rutayan Patro Date: Tue, 16 Jun 2026 10:57:57 -0400 Subject: [PATCH 2/2] test(env-params): suppress pyright on intentional ObsLeafDescriptor rejection tests Negative tests pass an extra kwarg and an out-of-Literal kind to assert ValidationError; mark the deliberate type violations with type: ignore. --- tests/test_env_params.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_env_params.py b/tests/test_env_params.py index 5e26c606e..a0341d615 100644 --- a/tests/test_env_params.py +++ b/tests/test_env_params.py @@ -401,6 +401,6 @@ def test_obs_leaf_descriptor_rejects_bad_dim_and_extra_fields() -> None: with pytest.raises(ValidationError, match="dim must be"): ObsLeafDescriptor(kind="box", dim=0) with pytest.raises(ValidationError): - ObsLeafDescriptor(kind="box", dim=1, unexpected=1) + ObsLeafDescriptor(kind="box", dim=1, unexpected=1) # type: ignore with pytest.raises(ValidationError): - ObsLeafDescriptor(kind="categorical", dim=1) + ObsLeafDescriptor(kind="categorical", dim=1) # type: ignore