Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
14 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions src/cloudai/_core/test_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import TYPE_CHECKING, Any, List, Optional, Set, Type, TypeAlias, Union

from ..util import flatten_dict
from .registry import Registry
from .system import System

if TYPE_CHECKING:
Expand Down Expand Up @@ -140,6 +141,22 @@ def get_metric_value(self, system: System, metric: str) -> MetricValue:
def is_dse_job(self) -> bool:
    """Tell whether this run is a DSE job.

    The run counts as DSE when the test definition flags itself as such, or
    when ``num_nodes`` is a list (a node-count sweep).
    """
    flagged_by_test = self.test.is_dse_job
    if flagged_by_test:
        return flagged_by_test
    return isinstance(self.num_nodes, list)

@property
def is_domain_randomization_active(self) -> bool:
    """
    Report whether per-trial environment sampling will really happen for this run.

    All three conditions must hold:

    * the test declares domain randomization (``is_domain_randomization_enabled``);
    * the run is a DSE job, so there is a per-trial loop to sample in;
    * the configured agent sets ``supports_variable_environment``.

    An agent name missing from the registry counts as opted-in, so the
    dedicated agent-resolution error is reported rather than this check.
    """
    if not self.test.is_domain_randomization_enabled:
        return False

    agent_cls = Registry().agents_map.get(self.test.agent)
    if not self.is_dse_job:
        return False
    if agent_cls is None:
        # Unknown agent: defer to the agent-resolution error raised elsewhere.
        return True
    return agent_cls.supports_variable_environment

@property
def nnodes(self) -> int:
"""Type safe getter for num_nodes, should only be used on an unrolled DSE job."""
Expand All @@ -156,7 +173,9 @@ def param_space(self) -> dict[str, Any]:
**{
key: value
for key, value in cmd_args_dict.items()
if isinstance(value, list) and not self.test.is_dse_excluded_arg(key)
if isinstance(value, list)
and not self.test.is_dse_excluded_arg(key)
and not self.test.is_env_sampled(key)
},
**{f"extra_env_vars.{key}": value for key, value in extra_env_vars_dict.items() if isinstance(value, list)},
}
Expand Down Expand Up @@ -184,9 +203,13 @@ def all_combinations(self) -> list[dict[str, Any]]:

return all_combinations

def apply_params_set(self, action: dict[str, Any]) -> "TestRun":
def apply_params_set(self, action: dict[str, Any], env_params: dict[str, Any] | None = None) -> "TestRun":
tdef = self.test.model_copy(deep=True)
for key, value in action.items():

# RNG runs in the env before this call; applying only concrete values keeps this deterministic.
# action and env_params target disjoint keys, so a plain merge applies both in one pass.
full_action = action | (env_params or {})
for key, value in full_action.items():
if key.startswith("extra_env_vars."):
tdef.extra_env_vars[key[len("extra_env_vars.") :]] = value
else:
Expand All @@ -199,7 +222,12 @@ def apply_params_set(self, action: dict[str, Any]) -> "TestRun":
else:
setattr(obj, attrs[-1], value)

type(tdef)(**tdef.model_dump()) # trigger validation
# env_params is validated at parse time; after the overlay its target cmd_args fields hold
# concrete scalar draws, so re-validating it here would reject weighted specs. Drop it for
# this validation-only pass, which exists to validate the applied action values.
validation_args = tdef.model_dump()
validation_args.pop("env_params", None)
type(tdef)(**validation_args) # trigger validation

new_tr = copy.deepcopy(self)
new_tr.test = tdef
Expand Down
14 changes: 11 additions & 3 deletions src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import toml
import yaml

from cloudai.configurator.env_params import validate_domain_randomization_active
from cloudai.core import (
BaseInstaller,
CloudAIGymEnv,
Expand All @@ -39,6 +40,7 @@
System,
TestParser,
TestScenario,
TestScenarioParsingError,
)
from cloudai.models.scenario import ReportConfig
from cloudai.models.workload import TestDefinition
Expand Down Expand Up @@ -133,8 +135,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
return 1

err = 0
# Recoverable failures return a non-zero rc and are accumulated here; an unexpected exception
# (a bug) is a hard-fail. We capture it so reports still generate, then re-raise below.
# Capture an unexpected error so reports still generate, then re-raise below.
run_error: Exception | None = None
try:
for tr in runner.runner.test_scenario.test_runs:
Expand Down Expand Up @@ -303,6 +304,12 @@ def handle_dry_run_and_run(args: argparse.Namespace) -> int:
return 1
system, test_scenario, tests = setup_result

try:
validate_domain_randomization_active(test_scenario)
except TestScenarioParsingError as e:
logging.error(str(e))
return 1

Comment thread
rutayan-nv marked this conversation as resolved.
if not _handle_single_sbatch(args, system):
return 1

Expand Down Expand Up @@ -491,7 +498,8 @@ def verify_test_scenarios(
tests = Parser.parse_tests(test_tomls, system)
hook_tests = Parser.parse_tests(hook_test_tomls, system)
hooks = Parser.parse_hooks(hook_tomls, system, {t.name: t for t in hook_tests})
Parser.parse_test_scenario(scenario_file, system, {t.name: t for t in tests}, hooks)
scenario = Parser.parse_test_scenario(scenario_file, system, {t.name: t for t in tests}, hooks)
validate_domain_randomization_active(scenario)
except Exception:
nfailed += 1

Expand Down
17 changes: 11 additions & 6 deletions src/cloudai/configurator/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ class BaseAgent(ABC):
Provides a unified interface and parameter management for action spaces.
"""

# Opt-in: agents that operate over a variable environment - one that changes per trial, whether
# by env_params sampling (domain randomization) or a curriculum schedule - set this True. Default
# False so env_params declared for an agent that cannot handle a varying env are rejected rather
# than silently ignored.
supports_variable_environment: bool = False

def __init__(self, env: BaseGym, config: BaseAgentConfig):
"""
Initialize the agent with the environment.
Expand Down Expand Up @@ -94,9 +100,8 @@ def select_action(self, observation: list[float] | None = None) -> tuple[int, di

Args:
observation: Latest observation produced by the environment (``env.reset()`` on the
first call, then the result of the prior ``env.step()``). Stateless agents such
as grid search or Bayesian optimization may ignore this; observation-conditioned
agents (RL, contextual bandits) should use it.
first call, then the result of the prior ``env.step()``). Stateless agents may
ignore this; observation-conditioned agents should use it.

Returns:
Tuple[int, Dict[str, Any]] | None: The current step index and a dictionary mapping action keys
Expand All @@ -120,8 +125,7 @@ def run(self) -> int:

Default: a step loop driven by the dispatcher (``select_action`` →
``env.step`` → ``update_policy`` per trial). Agents that drive their
own training loop (e.g. RLlib-based agents calling ``algo.train()``)
override this method.
own training loop override this method.

Failure contract (``handle_dse_job`` consumes the result via
``err |= agent.run()``):
Expand All @@ -131,7 +135,8 @@ def run(self) -> int:
accumulated and the next ``TestRun`` still executes. Workload-level
failures are already surfaced this way: ``CloudAIGymEnv.step`` maps a
failed metric to ``rewards.metric_failure`` rather than raising, and
``rllib_run`` catches training errors and returns ``rc=1``.
agents with their own training loop should likewise catch training
errors and return a non-zero code.
- Raise for *unexpected* failures (framework/agent bugs). Exceptions
propagate out of ``handle_dse_job`` and hard-fail the job so the bug
is surfaced instead of masked as a penalizing reward.
Expand Down
50 changes: 44 additions & 6 deletions src/cloudai/configurator/cloudai_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from .base_agent import RewardOverrides
from .base_gym import BaseGym
from .env_params import EnvParams, write_env_params


@dataclasses.dataclass(frozen=True)
Expand All @@ -36,6 +37,7 @@ class TrajectoryEntry:
action: dict[str, Any]
reward: float
observation: list
env_params: dict[str, Any] = dataclasses.field(default_factory=dict)


class CloudAIGymEnv(BaseGym):
Expand All @@ -61,8 +63,14 @@ def __init__(self, test_run: TestRun, runner: BaseRunner, rewards: RewardOverrid
self.max_steps = test_run.test.agent_steps
self.reward_function = Registry().get_reward_function(test_run.test.agent_reward_function)
self.trajectory: dict[int, list[TrajectoryEntry]] = {}
self.params: EnvParams | None = EnvParams.from_test(test_run.test)
super().__init__()

@property
def env_params_record_path(self) -> Path:
    """Location of ``env.csv``, kept in the iteration directory next to ``trajectory.csv``."""
    record_name = "env.csv"
    return self.iteration_dir / record_name

def define_action_space(self) -> Dict[str, list[Any]]:
    """Expose the test run's ``param_space`` as this environment's action space."""
    action_space = self.test_run.param_space
    return action_space

Expand Down Expand Up @@ -119,9 +127,11 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
- info (dict): Additional info for debugging.
"""
self.test_run.increment_step()
self.test_run = self.test_run.apply_params_set(action)
# RNG lives in the env: sample here, then apply action + sample so the run and cache key see them.
sampled_env_params = self.params.sample(self.test_run.step) if self.params else {}
self.test_run = self.test_run.apply_params_set(action, env_params=sampled_env_params)

cached_result = self.get_cached_trajectory_result(action)
cached_result = self.get_cached_trajectory_result(action, sampled_env_params)
if cached_result is not None:
logging.info(
"Retrieved cached result from trajectory with reward %s (from step %s). Skipping execution.",
Expand All @@ -134,6 +144,7 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
action=action,
reward=cached_result.reward,
observation=cached_result.observation,
env_params=sampled_env_params,
)
)
return cached_result.observation, cached_result.reward, False, {}
Expand Down Expand Up @@ -171,6 +182,7 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
action=action,
reward=reward,
observation=observation,
env_params=sampled_env_params,
)
)

Expand Down Expand Up @@ -230,7 +242,14 @@ def get_observation(self, action: Any) -> list:
return observation

def write_trajectory(self, entry: TrajectoryEntry):
"""Append the trajectory to the CSV file and to the local attribute."""
"""
Append the entry to the in-memory cache and trajectory.csv (plus env.csv when declared).

``trajectory.csv`` and the ``env.csv`` projection are sunk from the same
``TrajectoryEntry`` here, so a trial that never produces an entry (e.g. a
constraint failure returns before this call) lands in neither file and the
two stay 1:1 step-aligned.
"""
self.current_trajectory.append(entry)

file_exists = self.trajectory_file_path.exists()
Expand All @@ -243,17 +262,36 @@ def write_trajectory(self, entry: TrajectoryEntry):
writer.writerow(["step", "action", "reward", "observation"])
writer.writerow([entry.step, entry.action, entry.reward, entry.observation])

write_env_params(self.env_params_record_path, entry.step, entry.env_params)

@property
def iteration_dir(self) -> Path:
    """Output directory of the current iteration: ``<scenario_root>/<test run name>/<iteration>``."""
    run_root = self.runner.scenario_root / self.test_run.name
    return run_root / f"{self.test_run.current_iteration}"

@property
def trajectory_file_path(self) -> Path:
    """Path of ``trajectory.csv`` inside the current iteration directory.

    Built from ``iteration_dir`` so the directory layout is defined in one
    place; the earlier hand-built path (and the unreachable second ``return``
    that followed it) is gone.
    """
    return self.iteration_dir / "trajectory.csv"

@property
def current_trajectory(self) -> list[TrajectoryEntry]:
return self.trajectory.setdefault(self.test_run.current_iteration, [])

def get_cached_trajectory_result(
    self, action: Any, env_params: dict[str, Any] | None = None
) -> TrajectoryEntry | None:
    """
    Return a cached entry only when the full trial identity matches.

    Trial identity is ``(action, env_params)``: a trial repeating the same
    action under a different ``env_params`` sample must miss so it is re-run.

    Args:
        action: Action of the current trial.
        env_params: Environment sample of the current trial. ``None`` is
            treated as an empty sample, which keeps callers that pass only
            ``action`` working and matches entries recorded without any
            env sample.

    Returns:
        The first entry of the current trajectory whose action and
        env_params both compare equal via ``_values_match_exact``, or
        ``None`` when there is no such entry.
    """
    wanted_env_params = env_params or {}
    for entry in self.current_trajectory:
        action_match = self._values_match_exact(entry.action, action)
        env_params_match = self._values_match_exact(entry.env_params, wanted_env_params)
        if action_match and env_params_match:
            return entry

    return None
Expand Down
Loading
Loading