Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion src/cloudai/configurator/env_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import math
import random
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self
Expand Down Expand Up @@ -76,6 +76,50 @@ def _validate_weights(self) -> Self:
return self


class ObsLeafDescriptor(BaseModel):
"""
Description of one leaf of a structured (named) observation.

A structured observation maps each observed name to a self-describing leaf
so adapters can build the matching subspace without guessing: a ``"box"``
leaf becomes a continuous vector of width ``dim`` (e.g. a log-encoded
env_param as ``dim=2``); a ``"discrete"`` leaf becomes a categorical of
size ``n``. Stateless agents that consume the flat observation ignore this.
"""

model_config = ConfigDict(extra="forbid")

kind: Literal["box", "discrete"]
dim: int = 1
n: Optional[int] = None

@model_validator(mode="after")
def _validate(self) -> Self:
if self.dim < 1:
raise ValueError(f"ObsLeafDescriptor dim must be >= 1; got {self.dim}")
if self.kind == "discrete" and (self.n is None or self.n < 1):
raise ValueError(f"ObsLeafDescriptor(kind='discrete') requires n >= 1; got n={self.n}")
return self
Comment thread
coderabbitai[bot] marked this conversation as resolved.


@runtime_checkable
class StructuredObservation(Protocol):
"""
Optional env hooks that expose a structured (per-leaf) observation.

An env opts in by returning per-leaf :class:`ObsLeafDescriptor` from
``structured_observation_descriptors`` (``None`` keeps the flat-vector
path) and encoding a raw observation into the matching named leaves via
``encode_observation``. ``GymnasiumAdapter`` consumes these to expose a
``gymnasium.spaces.Dict`` observation; the hooks are duck-typed, so envs
need not subclass this Protocol.
"""

def structured_observation_descriptors(self) -> Optional[Dict[str, ObsLeafDescriptor]]: ...

def encode_observation(self, observation: list) -> Dict[str, Any]: ...
Comment thread
coderabbitai[bot] marked this conversation as resolved.


@dataclasses.dataclass(frozen=True)
class EnvParam:
"""
Expand Down
3 changes: 3 additions & 0 deletions src/cloudai/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from ._core.test_scenario import METRIC_ERROR, MetricErrorSentinel, MetricValue, TestDependency, TestRun, TestScenario
from .configurator.base_agent import BaseAgent, BaseAgentConfig, RewardOverrides
from .configurator.cloudai_gym import CloudAIGymEnv
from .configurator.env_params import ObsLeafDescriptor, StructuredObservation
from .configurator.grid_search import GridSearchAgent
from .models.workload import CmdArgs, NsysConfiguration, PredictorConfig, TestDefinition
from .parser import Parser
Expand Down Expand Up @@ -85,6 +86,7 @@
"MetricValue",
"MissingTestError",
"NsysConfiguration",
"ObsLeafDescriptor",
"Parser",
"PerTestReporter",
"PredictorConfig",
Expand All @@ -96,6 +98,7 @@
"RewardOverrides",
"Runner",
"StatusReporter",
"StructuredObservation",
"System",
"SystemConfigParsingError",
"TarballReporter",
Expand Down
29 changes: 29 additions & 0 deletions tests/test_env_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
EnvParam,
EnvParams,
EnvParamSpec,
ObsLeafDescriptor,
write_env_params,
)
from cloudai.core import TestRun
Expand Down Expand Up @@ -375,3 +376,31 @@ def test_apply_params_set_accepts_weighted_env_param_draw() -> None:
new_tr = tr.apply_params_set({}, env_params={"ball_speed": 1})

assert new_tr.test.cmd_args.ball_speed == 1


# --- ObsLeafDescriptor: structured-observation leaf schema ---


def test_obs_leaf_descriptor_box_defaults() -> None:
leaf = ObsLeafDescriptor(kind="box", dim=2)
assert leaf.kind == "box"
assert leaf.dim == 2
assert leaf.n is None


def test_obs_leaf_descriptor_discrete_requires_n() -> None:
leaf = ObsLeafDescriptor(kind="discrete", dim=1, n=3)
assert leaf.n == 3
with pytest.raises(ValidationError, match="requires n"):
ObsLeafDescriptor(kind="discrete", dim=1)
with pytest.raises(ValidationError, match="requires n"):
ObsLeafDescriptor(kind="discrete", dim=1, n=0)


def test_obs_leaf_descriptor_rejects_bad_dim_and_extra_fields() -> None:
with pytest.raises(ValidationError, match="dim must be"):
ObsLeafDescriptor(kind="box", dim=0)
with pytest.raises(ValidationError):
ObsLeafDescriptor(kind="box", dim=1, unexpected=1) # type: ignore
with pytest.raises(ValidationError):
ObsLeafDescriptor(kind="categorical", dim=1) # type: ignore
Loading