diff --git a/src/cloudai/configurator/env_params.py b/src/cloudai/configurator/env_params.py
index 13db021d2..36d73bf14 100644
--- a/src/cloudai/configurator/env_params.py
+++ b/src/cloudai/configurator/env_params.py
@@ -32,7 +32,7 @@
 import math
 import random
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
 
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Self
@@ -76,6 +76,50 @@ def _validate_weights(self) -> Self:
         return self
 
 
+class ObsLeafDescriptor(BaseModel):
+    """
+    Description of one leaf of a structured (named) observation.
+
+    A structured observation maps each observed name to a self-describing leaf
+    so adapters can build the matching subspace without guessing: a ``"box"``
+    leaf becomes a continuous vector of width ``dim`` (e.g. a log-encoded
+    env_param as ``dim=2``); a ``"discrete"`` leaf becomes a categorical of
+    size ``n``. Stateless agents that consume the flat observation ignore this.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+
+    kind: Literal["box", "discrete"]
+    dim: int = 1
+    n: Optional[int] = None
+
+    @model_validator(mode="after")
+    def _validate(self) -> Self:
+        if self.dim < 1:
+            raise ValueError(f"ObsLeafDescriptor dim must be >= 1; got {self.dim}")
+        if self.kind == "discrete" and (self.n is None or self.n < 1):
+            raise ValueError(f"ObsLeafDescriptor(kind='discrete') requires n >= 1; got n={self.n}")
+        return self
+
+
+@runtime_checkable
+class StructuredObservation(Protocol):
+    """
+    Optional env hooks that expose a structured (per-leaf) observation.
+
+    An env opts in by returning per-leaf :class:`ObsLeafDescriptor` from
+    ``structured_observation_descriptors`` (``None`` keeps the flat-vector
+    path) and encoding a raw observation into the matching named leaves via
+    ``encode_observation``. ``GymnasiumAdapter`` consumes these to expose a
+    ``gymnasium.spaces.Dict`` observation; the hooks are duck-typed, so envs
+    need not subclass this Protocol.
+    """
+
+    def structured_observation_descriptors(self) -> Optional[Dict[str, ObsLeafDescriptor]]: ...
+
+    def encode_observation(self, observation: list) -> Dict[str, Any]: ...
+
+
 @dataclasses.dataclass(frozen=True)
 class EnvParam:
     """
diff --git a/src/cloudai/core.py b/src/cloudai/core.py
index 752d24972..aefcd0573 100644
--- a/src/cloudai/core.py
+++ b/src/cloudai/core.py
@@ -51,6 +51,7 @@
 from ._core.test_scenario import METRIC_ERROR, MetricErrorSentinel, MetricValue, TestDependency, TestRun, TestScenario
 from .configurator.base_agent import BaseAgent, BaseAgentConfig, RewardOverrides
 from .configurator.cloudai_gym import CloudAIGymEnv
+from .configurator.env_params import ObsLeafDescriptor, StructuredObservation
 from .configurator.grid_search import GridSearchAgent
 from .models.workload import CmdArgs, NsysConfiguration, PredictorConfig, TestDefinition
 from .parser import Parser
@@ -85,6 +86,7 @@
     "MetricValue",
     "MissingTestError",
     "NsysConfiguration",
+    "ObsLeafDescriptor",
     "Parser",
     "PerTestReporter",
     "PredictorConfig",
@@ -96,6 +98,7 @@
     "RewardOverrides",
     "Runner",
     "StatusReporter",
+    "StructuredObservation",
     "System",
     "SystemConfigParsingError",
     "TarballReporter",
diff --git a/tests/test_env_params.py b/tests/test_env_params.py
index 362a5e5d1..a0341d615 100644
--- a/tests/test_env_params.py
+++ b/tests/test_env_params.py
@@ -43,6 +43,7 @@
     EnvParam,
     EnvParams,
     EnvParamSpec,
+    ObsLeafDescriptor,
     write_env_params,
 )
 from cloudai.core import TestRun
@@ -375,3 +376,31 @@ def test_apply_params_set_accepts_weighted_env_param_draw() -> None:
     new_tr = tr.apply_params_set({}, env_params={"ball_speed": 1})
 
     assert new_tr.test.cmd_args.ball_speed == 1
+
+
+# --- ObsLeafDescriptor: structured-observation leaf schema ---
+
+
+def test_obs_leaf_descriptor_box_defaults() -> None:
+    leaf = ObsLeafDescriptor(kind="box", dim=2)
+    assert leaf.kind == "box"
+    assert leaf.dim == 2
+    assert leaf.n is None
+
+
+def test_obs_leaf_descriptor_discrete_requires_n() -> None:
+    leaf = ObsLeafDescriptor(kind="discrete", dim=1, n=3)
+    assert leaf.n == 3
+    with pytest.raises(ValidationError, match="requires n"):
+        ObsLeafDescriptor(kind="discrete", dim=1)
+    with pytest.raises(ValidationError, match="requires n"):
+        ObsLeafDescriptor(kind="discrete", dim=1, n=0)
+
+
+def test_obs_leaf_descriptor_rejects_bad_dim_and_extra_fields() -> None:
+    with pytest.raises(ValidationError, match="dim must be"):
+        ObsLeafDescriptor(kind="box", dim=0)
+    with pytest.raises(ValidationError):
+        ObsLeafDescriptor(kind="box", dim=1, unexpected=1)  # type: ignore
+    with pytest.raises(ValidationError):
+        ObsLeafDescriptor(kind="categorical", dim=1)  # type: ignore