diff --git a/orchestration/_tests/test_bl832/test_mlflow.py b/orchestration/_tests/test_bl832/test_mlflow.py index 49806332..43aafed1 100644 --- a/orchestration/_tests/test_bl832/test_mlflow.py +++ b/orchestration/_tests/test_bl832/test_mlflow.py @@ -226,13 +226,13 @@ class TestLoadJobOptionsMLflowLayer: def _patch_variable_defaults(self, mocker): mocker.patch( - "orchestration.flows.bl832.nersc.Variable.get", + "orchestration.jobs.options.Variable.get", return_value={"defaults": True}, ) def test_mlflow_nersc_path_mapped_to_checkpoint_key(self, mocker, mock_config832): """When MLflow returns a checkpoint, nersc_path is written to mlflow_checkpoint_key.""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options from orchestration.mlflow import ModelCheckpointInfo self._patch_variable_defaults(mocker) @@ -246,12 +246,12 @@ def test_mlflow_nersc_path_mapped_to_checkpoint_key(self, mocker, mock_config832 inference_params={}, ) mocker.patch( - "orchestration.flows.bl832.nersc.get_checkpoint_info", + "orchestration.jobs.options.get_checkpoint_info", return_value=checkpoint_info, ) base_settings = dict(mock_config832.nersc_segment_sam3_settings) - opts = _load_job_options( + opts = load_job_options( "nersc-segmentation-options", base_settings, config=mock_config832, @@ -263,7 +263,7 @@ def test_mlflow_nersc_path_mapped_to_checkpoint_key(self, mocker, mock_config832 def test_mlflow_inference_params_overlay_config_defaults(self, mocker, mock_config832): """inference_params from MLflow overwrite matching config keys.""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options from orchestration.mlflow import ModelCheckpointInfo self._patch_variable_defaults(mocker) @@ -281,12 +281,12 @@ def test_mlflow_inference_params_overlay_config_defaults(self, mocker, mock_conf }, ) mocker.patch( - "orchestration.flows.bl832.nersc.get_checkpoint_info", + "orchestration.jobs.options.get_checkpoint_info", return_value=checkpoint_info, ) base_settings = dict(mock_config832.nersc_segment_sam3_settings) - opts = _load_job_options( + opts = load_job_options( "nersc-segmentation-options", base_settings, config=mock_config832, @@ -300,13 +300,13 @@ def test_mlflow_inference_params_overlay_config_defaults(self, mocker, mock_conf def test_mlflow_layer_skipped_when_config_is_none(self, mocker, mock_config832): """Passing config=None skips the MLflow layer entirely.""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options self._patch_variable_defaults(mocker) - spy = mocker.patch("orchestration.flows.bl832.nersc.get_checkpoint_info") + spy = mocker.patch("orchestration.jobs.options.get_checkpoint_info") base_settings = dict(mock_config832.nersc_segment_sam3_settings) - opts = _load_job_options( + opts = load_job_options( "nersc-segmentation-options", base_settings, config=None, @@ -320,13 +320,13 @@ def test_mlflow_layer_skipped_when_config_is_none(self, mocker, mock_config832): def test_mlflow_layer_skipped_when_model_name_is_none(self, mocker, mock_config832): """Passing mlflow_model_name=None skips the MLflow layer.""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options self._patch_variable_defaults(mocker) - spy = mocker.patch("orchestration.flows.bl832.nersc.get_checkpoint_info") + spy = mocker.patch("orchestration.jobs.options.get_checkpoint_info") base_settings = dict(mock_config832.nersc_segment_sam3_settings) - _load_job_options( + load_job_options( "nersc-segmentation-options", base_settings, config=mock_config832, @@ -337,16 +337,16 @@ def test_mlflow_layer_skipped_when_model_name_is_none(self, mocker, mock_config8 def test_config_defaults_used_when_mlflow_returns_none(self, mocker, mock_config832): """get_checkpoint_info returning None → config defaults unchanged.""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options self._patch_variable_defaults(mocker) mocker.patch( - "orchestration.flows.bl832.nersc.get_checkpoint_info", + "orchestration.jobs.options.get_checkpoint_info", return_value=None, ) base_settings = dict(mock_config832.nersc_segment_sam3_settings) - opts = _load_job_options( + opts = load_job_options( "nersc-segmentation-options", base_settings, config=mock_config832, @@ -358,16 +358,16 @@ def test_config_defaults_used_when_mlflow_returns_none(self, mocker, mock_config def test_config_defaults_used_when_mlflow_raises(self, mocker, mock_config832): """An exception from get_checkpoint_info is caught; config defaults are used.""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options self._patch_variable_defaults(mocker) mocker.patch( - "orchestration.flows.bl832.nersc.get_checkpoint_info", + "orchestration.jobs.options.get_checkpoint_info", side_effect=RuntimeError("Network timeout"), ) base_settings = dict(mock_config832.nersc_segment_sam3_settings) - opts = _load_job_options( + opts = load_job_options( "nersc-segmentation-options", base_settings, config=mock_config832, @@ -379,12 +379,12 @@ def test_config_defaults_used_when_mlflow_raises(self, mocker, mock_config832): def test_prefect_variable_wins_over_mlflow(self, mocker, mock_config832): """Prefect Variable overrides take priority over MLflow inference params (layer 3 > layer 2).""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options from orchestration.mlflow import ModelCheckpointInfo # MLflow says batch_size=8; Prefect Variable says batch_size=16 → 16 wins mocker.patch( - "orchestration.flows.bl832.nersc.Variable.get", + "orchestration.jobs.options.Variable.get", return_value={"defaults": False, "batch_size": 16}, ) @@ -397,12 +397,12 @@ def test_prefect_variable_wins_over_mlflow(self, mocker, mock_config832): inference_params={"batch_size": 8}, ) mocker.patch( - "orchestration.flows.bl832.nersc.get_checkpoint_info", + "orchestration.jobs.options.get_checkpoint_info", return_value=checkpoint_info, ) base_settings = dict(mock_config832.nersc_segment_sam3_settings) - opts = _load_job_options( + opts = load_job_options( "nersc-segmentation-options", base_settings, config=mock_config832, @@ -437,7 +437,7 @@ def test_mlflow_checkpoint_appears_in_job_script(self, mocker, mock_config832): resolved_settings["finetuned_checkpoint_path"] = mlflow_checkpoint mocker.patch( - "orchestration.flows.bl832.nersc._load_job_options", + "orchestration.flows.bl832.nersc.load_job_options", return_value=resolved_settings, ) @@ -455,16 +455,16 @@ def capture_script(script, *args, **kwargs): config=mock_config832, login_method=NERSCLoginMethod.IRIAPI, ) - mocker.patch.object(controller, "_submit_job", side_effect=capture_script) - mocker.patch.object(controller, "_wait_for_job", return_value=True) - mocker.patch.object(controller, "_mkdir_remote", return_value=None) + mocker.patch.object(controller, "submit_job", side_effect=capture_script) + mocker.patch.object(controller, "wait_for_job", return_value=True) + mocker.patch.object(controller, "mkdir_remote", return_value=None) mocker.patch.object(controller, "_fetch_seg_timing_from_output", return_value=None) # _get_nersc_username reads NERSC_USERNAME for IRIAPI; stub it - mocker.patch.object(controller, "_get_nersc_username", return_value="testuser") + mocker.patch.object(controller, "get_nersc_username", return_value="testuser") result = controller.segmentation_sam3(recon_folder_path="folder/recfile") - assert captured, "_submit_job was never called" + assert captured, "submit_job was never called" assert mlflow_checkpoint in captured[0], ( "The MLflow checkpoint path must appear in the SLURM job script" ) @@ -475,11 +475,11 @@ def test_config_default_checkpoint_used_when_mlflow_unavailable(self, mocker, mo mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mocker.patch( - "orchestration.flows.bl832.nersc.Variable.get", + "orchestration.jobs.options.Variable.get", return_value={"defaults": True}, ) mocker.patch( - "orchestration.flows.bl832.nersc.get_checkpoint_info", + "orchestration.jobs.options.get_checkpoint_info", return_value=None, ) @@ -494,16 +494,16 @@ def capture_script(script, *args, **kwargs): config=mock_config832, login_method=NERSCLoginMethod.IRIAPI, ) - mocker.patch.object(controller, "_submit_job", side_effect=capture_script) - mocker.patch.object(controller, "_wait_for_job", return_value=True) - mocker.patch.object(controller, "_mkdir_remote", return_value=None) + mocker.patch.object(controller, "submit_job", side_effect=capture_script) + mocker.patch.object(controller, "wait_for_job", return_value=True) + mocker.patch.object(controller, "mkdir_remote", return_value=None) mocker.patch.object(controller, "_fetch_seg_timing_from_output", return_value=None) - mocker.patch.object(controller, "_get_nersc_username", return_value="testuser") + mocker.patch.object(controller, "get_nersc_username", return_value="testuser") controller.segmentation_sam3(recon_folder_path="folder/recfile") config_default = mock_config832.nersc_segment_sam3_settings["finetuned_checkpoint_path"] - assert captured, "_submit_job was never called" + assert captured, "submit_job was never called" assert config_default in captured[0], ( "Config default checkpoint path must be used when MLflow is unavailable" ) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index abf616fd..97381a18 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -201,60 +201,6 @@ def mock_iriapi_client(mocker): return client -# --------------------------------------------------------------------------- -# _create_sfapi_client -# --------------------------------------------------------------------------- - - -def test_create_sfapi_client_success(mocker): - """Valid credentials produce a Client instance.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret", - }.get(x)) - mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) - mocker.patch( - "builtins.open", - side_effect=[ - mocker.mock_open(read_data="my-client-id")(), - mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), - ] - ) - mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") - mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") - - client = NERSCTomographyHPCController._create_sfapi_client() - - mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") - assert client is mock_client_cls.return_value - - -def test_create_sfapi_client_missing_paths(mocker): - """Unset env vars raise ValueError.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) - - with pytest.raises(ValueError, match="Missing NERSC credentials paths."): - NERSCTomographyHPCController._create_sfapi_client() - - -def test_create_sfapi_client_missing_files(mocker): - """Env vars set but files absent raise FileNotFoundError.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret", - }.get(x)) - mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - - with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): - NERSCTomographyHPCController._create_sfapi_client() - - # ────────────────────────────────────────────────────────────────────────────── # build_multi_resolution # ────────────────────────────────────────────────────────────────────────────── @@ -304,7 +250,7 @@ def test_segmentation_sam3_success(mocker, mock_sfapi_client, mock_config832): from sfapi_client.compute import Machine mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) controller = NERSCTomographyHPCController( client=mock_sfapi_client, config=mock_config832, @@ -326,7 +272,7 @@ def test_segmentation_sam3_submission_failure(mocker, mock_sfapi_client, mock_co from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("GPU queue full") controller = NERSCTomographyHPCController( client=mock_sfapi_client, @@ -346,7 +292,7 @@ def test_segmentation_sam3_uses_variable_options(mocker, mock_sfapi_client, mock from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={ + mocker.patch("orchestration.jobs.options.Variable.get", return_value={ "defaults": False, "batch_size": 8, "patch_size": 512, @@ -395,7 +341,7 @@ def test_segmentation_dinov3_success(mocker, mock_sfapi_client, mock_config832): from sfapi_client.compute import Machine mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) controller = NERSCTomographyHPCController( client=mock_sfapi_client, config=mock_config832, @@ -414,7 +360,7 @@ def test_segmentation_dinov3_submission_failure(mocker, mock_sfapi_client, mock_ from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("No GPU nodes") controller = NERSCTomographyHPCController( client=mock_sfapi_client, @@ -522,7 +468,7 @@ def test_reconstruct_iriapi_job_failed(mocker, mock_iriapi_client, mock_config83 monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} # was {"state": "FAILED"} + mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} controller = NERSCTomographyHPCController( client=mock_iriapi_client, @@ -605,7 +551,7 @@ def test_segmentation_dinov3_output_paths(mocker, mock_sfapi_client, mock_config from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) captured_scripts = [] original_return = mock_sfapi_client.compute.return_value.submit_job.return_value @@ -636,7 +582,7 @@ def test_combine_segmentations_success(mocker, mock_sfapi_client, mock_config832 from sfapi_client.compute import Machine mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) controller = NERSCTomographyHPCController( client=mock_sfapi_client, config=mock_config832, @@ -655,7 +601,7 @@ def test_combine_segmentations_submission_failure(mocker, mock_sfapi_client, moc from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Cluster down") controller = NERSCTomographyHPCController( client=mock_sfapi_client, @@ -673,7 +619,7 @@ def test_combine_segmentations_script_references_sam3_and_dino(mocker, mock_sfap from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) captured_scripts = [] original_return = mock_sfapi_client.compute.return_value.submit_job.return_value diff --git a/orchestration/_tests/test_jobs/__init__.py b/orchestration/_tests/test_jobs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/_tests/test_jobs/alcf/__init__.py b/orchestration/_tests/test_jobs/alcf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/_tests/test_jobs/alcf/test_controller.py b/orchestration/_tests/test_jobs/alcf/test_controller.py new file mode 100644 index 00000000..d8dbf352 --- /dev/null +++ b/orchestration/_tests/test_jobs/alcf/test_controller.py @@ -0,0 +1,198 @@ +"""Tests for orchestration/jobs/alcf/controller.py — ALCFJobController. + +Patch targets (confirmed from module-level imports in alcf/controller.py): + orchestration.jobs.alcf.controller.Variable.get (prefect Variable) + orchestration.jobs.alcf.controller.Secret.load (prefect Secret) + orchestration.jobs.alcf.controller.get_run_logger (prefect run logger) + orchestration.jobs.alcf.controller.Client (globus_compute_sdk.Client) + orchestration.jobs.alcf.controller.Executor (globus_compute_sdk.Executor) + orchestration.jobs.alcf.controller.time.sleep (skip polling delays) + +All TestALCFJobControllerInit and TestSubmit tests request the mock_alcf_prefect +fixture to satisfy Variable.get and Secret.load calls in __init__. + +TestWaitForFuture uses ALCFJobController.wait_for_future() as a @staticmethod — +no instance construction needed, so mock_alcf_prefect is not required there. +""" + +import pytest + +from orchestration.jobs.alcf.controller import ( + ALCFJobController, + _ALLOCATION_ROOT_VARIABLE, + _GLOBUS_COMPUTE_ENDPOINT_SECRET, +) + + +# ── Init ────────────────────────────────────────────────────────────────────── + +class TestALCFJobControllerInit: + def test_reads_allocation_root_from_variable(self, mocker, mock_config, mock_alcf_prefect): + ctrl = ALCFJobController(mock_config) + assert ctrl.allocation_root == "/eagle/IRIProd/ALS" + + def test_reads_endpoint_id_from_secret(self, mocker, mock_config, mock_alcf_prefect): + ctrl = ALCFJobController(mock_config) + assert ctrl.endpoint_id == "mock-endpoint-uuid" + + def test_variable_get_called_with_correct_name(self, mocker, mock_config, mock_alcf_prefect): + ALCFJobController(mock_config) + mock_alcf_prefect.variable.assert_called_once_with( + _ALLOCATION_ROOT_VARIABLE, _sync=True + ) + + def test_secret_load_called_with_correct_name(self, mocker, mock_config, mock_alcf_prefect): + ALCFJobController(mock_config) + # mock_alcf_prefect.secret IS the mock for Secret.load (not Secret itself) + mock_alcf_prefect.secret.assert_called_once_with(_GLOBUS_COMPUTE_ENDPOINT_SECRET) + + def test_raises_when_allocation_root_missing(self, mocker, mock_config): + # allocation_data.get(...) returns None → ValueError + mocker.patch( + "orchestration.jobs.alcf.controller.Variable.get", + return_value={}, # key absent → .get() returns None + ) + mocker.patch("orchestration.jobs.alcf.controller.Secret.load") + with pytest.raises(ValueError, match="Allocation root not found"): + ALCFJobController(mock_config) + + def test_stores_config(self, mocker, mock_config, mock_alcf_prefect): + ctrl = ALCFJobController(mock_config) + assert ctrl.config is mock_config + + +# ── submit ──────────────────────────────────────────────────────────────────── + +class TestSubmit: + def test_constructs_client_and_submits_via_executor(self, mocker, mock_config, mock_alcf_prefect): + mock_client_cls = mocker.patch("orchestration.jobs.alcf.controller.Client") + mock_executor_cls = mocker.patch("orchestration.jobs.alcf.controller.Executor") + + mock_future = mocker.MagicMock() + mock_executor_instance = mocker.MagicMock() + mock_executor_instance.submit.return_value = mock_future + mock_executor_cls.return_value.__enter__ = mocker.MagicMock(return_value=mock_executor_instance) + mock_executor_cls.return_value.__exit__ = mocker.MagicMock(return_value=False) + + def noop(): + pass + + ctrl = ALCFJobController(mock_config) + result = ctrl.submit(noop) + + mock_client_cls.assert_called_once() + mock_executor_cls.assert_called_once_with( + endpoint_id="mock-endpoint-uuid", + client=mock_client_cls.return_value, + ) + mock_executor_instance.submit.assert_called_once() + assert result is mock_future + + def test_returns_future(self, mocker, mock_config, mock_alcf_prefect): + mocker.patch("orchestration.jobs.alcf.controller.Client") + mock_executor_cls = mocker.patch("orchestration.jobs.alcf.controller.Executor") + + mock_future = mocker.MagicMock() + mock_executor_instance = mocker.MagicMock() + mock_executor_instance.submit.return_value = mock_future + mock_executor_cls.return_value.__enter__ = mocker.MagicMock(return_value=mock_executor_instance) + mock_executor_cls.return_value.__exit__ = mocker.MagicMock(return_value=False) + + def identity(x): + return x + + ctrl = ALCFJobController(mock_config) + future = ctrl.submit(identity, 42, key="val") + + mock_executor_instance.submit.assert_called_once_with(identity, 42, key="val") + assert future is mock_future + + +# ── wait_for_future ─────────────────────────────────────────────────────────── + +class TestWaitForFuture: + """wait_for_future is @staticmethod — call directly without constructing an instance.""" + + def _run_logger(self, mocker): + """Patch get_run_logger and return a mock logger.""" + run_logger = mocker.MagicMock() + mocker.patch( + "orchestration.jobs.alcf.controller.get_run_logger", + return_value=run_logger, + ) + return run_logger + + def test_returns_true_on_success(self, mocker): + mocker.patch("orchestration.jobs.alcf.controller.time.sleep") + self._run_logger(mocker) + + future = mocker.MagicMock() + future.done.return_value = True + future.cancelled.return_value = False + future.exception.return_value = None + future.result.return_value = "output" + + result = ALCFJobController.wait_for_future(future, "reconstruction") + assert result is True + + def test_returns_false_when_future_raises(self, mocker): + mocker.patch("orchestration.jobs.alcf.controller.time.sleep") + self._run_logger(mocker) + + future = mocker.MagicMock() + future.done.return_value = True + future.cancelled.return_value = False + future.exception.return_value = RuntimeError("job failed") + + result = ALCFJobController.wait_for_future(future, "reconstruction") + assert result is False + + def test_returns_false_when_cancelled(self, mocker): + mocker.patch("orchestration.jobs.alcf.controller.time.sleep") + self._run_logger(mocker) + + future = mocker.MagicMock() + future.done.return_value = True + future.cancelled.return_value = True + + result = ALCFJobController.wait_for_future(future, "reconstruction") + assert result is False + + def test_returns_false_on_timeout(self, mocker): + mocker.patch("orchestration.jobs.alcf.controller.time.sleep") + self._run_logger(mocker) + + # Simulate time advancing past walltime by patching time.time + call_count = [0] + start = 1000.0 + + def mock_time(): + val = start + call_count[0] * 700 + call_count[0] += 1 + return val + + mocker.patch("orchestration.jobs.alcf.controller.time.time", side_effect=mock_time) + + future = mocker.MagicMock() + future.done.return_value = False # never completes + future.cancelled.return_value = False + + result = ALCFJobController.wait_for_future( + future, "reconstruction", check_interval=1, walltime=600 + ) + future.cancel.assert_called() + assert result is False + + def test_polls_until_done(self, mocker): + mocker.patch("orchestration.jobs.alcf.controller.time.sleep") + self._run_logger(mocker) + + future = mocker.MagicMock() + future.done.side_effect = [False, False, True] + future.cancelled.return_value = False + future.exception.return_value = None + future.result.return_value = "done" + + result = ALCFJobController.wait_for_future(future, "task", check_interval=1, walltime=3600) + assert future.done.call_count == 3 + assert result is True diff --git a/orchestration/_tests/test_jobs/conftest.py b/orchestration/_tests/test_jobs/conftest.py new file mode 100644 index 00000000..80ca8ac5 --- /dev/null +++ b/orchestration/_tests/test_jobs/conftest.py @@ -0,0 +1,87 @@ +import types + +import pytest +from prefect.testing.utilities import prefect_test_harness + + +@pytest.fixture(scope="session", autouse=True) +def prefect_test_fixture(): + """Wrap the entire test_jobs/ session in a Prefect test harness. + + Required because ALCFJobController.__init__ calls Variable.get(_sync=True), + which needs a live Prefect API — even when Variable.get is patched, the + import-time Prefect setup must succeed. + """ + with prefect_test_harness(): + yield + + +@pytest.fixture +def mock_config(): + """Minimal BeamlineConfig-like namespace for tests that need a config object.""" + return types.SimpleNamespace( + nersc_resources={ + "iri": { + "api_base_url": "https://mock-iri.nersc.gov", + "perlmutter_login": "mock-login-uuid", + "perlmutter_job_submit": "mock-submit-uuid", + "compute_resource": "mock-compute-uuid", + }, + "sfapi": {"api_base_url": "https://mock-sfapi.nersc.gov"}, + }, + mlflow={"tracking_uri": "http://mock-mlflow:5000"}, + ) + + +@pytest.fixture +def mock_sfapi_client(mocker): + """MagicMock shaped like an sfapi_client.Client.""" + client = mocker.MagicMock() + user = mocker.MagicMock() + user.name = "testuser" + client.user.return_value = user + return client + + +@pytest.fixture +def mock_iriapi_client(mocker): + """MagicMock shaped like an httpx.Client targeting the IRI API.""" + client = mocker.MagicMock() + client.post.return_value = mocker.MagicMock( + is_success=True, json=lambda: {"id": "job-42"} + ) + client.get.return_value = mocker.MagicMock( + is_success=True, + json=lambda: {"status": {"state": "completed"}}, + ) + return client + + +@pytest.fixture +def mock_alcf_prefect(mocker): + """Patch Variable.get and Secret.load in the ALCF controller module. + + allocation_data is a real dict, not a MagicMock — the constructor calls + allocation_data.get("alcf-allocation-root-path") and checks truthiness of + the result. A MagicMock would pass the check with an arbitrary truthy value, + masking bugs. + + Tests that need Prefect mocked request this fixture explicitly. + """ + allocation_data = {"alcf-allocation-root-path": "/eagle/IRIProd/ALS"} + var_mock = mocker.patch( + "orchestration.jobs.alcf.controller.Variable.get", + return_value=allocation_data, + ) + secret_mock = mocker.patch("orchestration.jobs.alcf.controller.Secret.load") + secret_mock.return_value.get.return_value = "mock-endpoint-uuid" + return types.SimpleNamespace(variable=var_mock, secret=secret_mock) + + +@pytest.fixture +def mock_options_prefect(mocker): + """Patch Variable.get in the options module (used by load_job_options tests).""" + return mocker.patch( + "orchestration.jobs.options.Variable.get", + return_value={"defaults": True}, + ) diff --git a/orchestration/_tests/test_jobs/nersc/__init__.py b/orchestration/_tests/test_jobs/nersc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/_tests/test_jobs/nersc/test_controller.py b/orchestration/_tests/test_jobs/nersc/test_controller.py new file mode 100644 index 00000000..eb38ee52 --- /dev/null +++ b/orchestration/_tests/test_jobs/nersc/test_controller.py @@ -0,0 +1,333 @@ +"""Tests for orchestration/jobs/nersc/controller.py — NERSCJobController. + +Patch targets: + orchestration.jobs.nersc.controller.time.sleep (skip 60-second polling delays) + +No Prefect Variable/Secret imports in this module — no Prefect mocking needed. +""" + +import pytest + +from orchestration.jobs.nersc.controller import NERSCJobController +from orchestration.jobs.nersc.login import NERSCLoginMethod + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _sfapi_controller(mocker, mock_config): + client = mocker.MagicMock() + user = mocker.MagicMock() + user.name = "sfapiuser" + client.user.return_value = user + return NERSCJobController(mock_config, client=client, login_method=NERSCLoginMethod.SFAPI) + + +def _iriapi_controller(mocker, mock_config): + client = mocker.MagicMock() + # POST (submit_job): json() returns the job ID dict + post_response = mocker.MagicMock(is_success=True) + post_response.json.return_value = {"id": "job-99"} + client.post.return_value = post_response + # GET (wait_for_job, read_remote_file): default to completed state + get_response = mocker.MagicMock(is_success=True) + get_response.json.return_value = {"status": {"state": "completed"}} + get_response.text = "" + client.get.return_value = get_response + return NERSCJobController(mock_config, client=client, login_method=NERSCLoginMethod.IRIAPI) + + +# ── Initialization ──────────────────────────────────────────────────────────── + +class TestNERSCJobControllerInit: + def test_sfapi_stores_sfapi_nersc_resources(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) + assert ctrl.nersc_resources == mock_config.nersc_resources["sfapi"] + + def test_iriapi_stores_iri_nersc_resources(self, mocker, mock_config): + ctrl = _iriapi_controller(mocker, mock_config) + assert ctrl.nersc_resources == mock_config.nersc_resources["iri"] + + def test_stores_login_method(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) + assert ctrl.login_method is NERSCLoginMethod.SFAPI + + def test_unknown_login_method_raises(self, mocker, mock_config): + bad_method = mocker.MagicMock() + bad_method.__eq__ = lambda s, o: False + bad_method.__ne__ = lambda s, o: True + with pytest.raises(ValueError, match="Unsupported NERSCLoginMethod"): + NERSCJobController(mock_config, client=None, login_method=bad_method) + + +# ── get_nersc_username ──────────────────────────────────────────────────────── + +class TestGetNerscUsername: + def test_sfapi_reads_name_from_client(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) + assert ctrl.get_nersc_username() == "sfapiuser" + + def test_iriapi_reads_from_env(self, mocker, mock_config, monkeypatch): + monkeypatch.setenv("NERSC_USERNAME", "envuser") + ctrl = _iriapi_controller(mocker, mock_config) + assert ctrl.get_nersc_username() == "envuser" + + def test_iriapi_raises_when_env_unset(self, mocker, mock_config, monkeypatch): + monkeypatch.delenv("NERSC_USERNAME", raising=False) + ctrl = _iriapi_controller(mocker, mock_config) + with pytest.raises(ValueError, match="NERSC_USERNAME must be set"): + ctrl.get_nersc_username() + + +# ── submit_job ──────────────────────────────────────────────────────────────── + +class TestSubmitJob: + def test_sfapi_returns_job_id_string(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) + job = mocker.MagicMock() + job.jobid = 12345 + ctrl.client.compute.return_value.submit_job.return_value = job + result = ctrl.submit_job("#!/bin/bash\n#SBATCH -q debug\necho hi") + assert result == "12345" + + def test_sfapi_calls_perlmutter_submit_job(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) + job = mocker.MagicMock() + job.jobid = "abc" + perlmutter = ctrl.client.compute.return_value + perlmutter.submit_job.return_value = job + ctrl.submit_job("script") + perlmutter.submit_job.assert_called_once_with("script") + + def test_iriapi_returns_job_id_string(self, mocker, mock_config): + ctrl = _iriapi_controller(mocker, mock_config) + script = "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" + result = ctrl.submit_job(script) + assert result == "job-99" + + def test_iriapi_posts_to_job_submit_url(self, mocker, mock_config): + ctrl = _iriapi_controller(mocker, mock_config) + script = "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" + ctrl.submit_job(script) + call_args = ctrl.client.post.call_args + assert "mock-submit-uuid" in call_args[0][0] + + +# ── _submit_job_iriapi SBATCH parsing ───────────────────────────────────────── + +class TestSubmitJobIRIAPI: + """Tests the SBATCH header parsing logic in _submit_job_iriapi.""" + + def _submit_and_capture_spec(self, mocker, mock_config, script): + ctrl = _iriapi_controller(mocker, mock_config) + captured = {} + + def capture_post(url, json=None, **kwargs): + captured["json"] = json + resp = mocker.MagicMock(is_success=True) + resp.json.return_value = {"id": "captured-id"} + return resp + + ctrl.client.post = capture_post + ctrl._submit_job_iriapi(script) + return captured["json"] + + def test_parses_queue_name(self, mocker, mock_config): + script = ( + "#!/bin/bash\n#SBATCH -q premium\n#SBATCH -A als\n" + "#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, mock_config, script) + assert spec["attributes"]["queue_name"] == "premium" + + def test_parses_account(self, mocker, mock_config): + script = ( + "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A myproject\n" + "#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, mock_config, script) + assert spec["attributes"]["account"] == "myproject" + + def test_parses_walltime_to_seconds(self, mocker, mock_config): + script = ( + "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" + "#SBATCH --time=01:30:00\n#SBATCH -N 1\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, mock_config, script) + assert spec["attributes"]["duration"] == 5400 # 1h30m in seconds + + def test_parses_node_count(self, mocker, mock_config): + script = ( + "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" + "#SBATCH --time=00:10:00\n#SBATCH -N 4\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, mock_config, script) + assert spec["resources"]["node_count"] == 4 + + def test_cpu_constraint_adds_cpu_cores(self, mocker, mock_config): + script = ( + "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" + "#SBATCH --time=00:10:00\n#SBATCH -N 1\n#SBATCH -C cpu\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, mock_config, script) + assert "cpu_cores_per_process" in spec["resources"] + assert "gpu_cores_per_process" not in spec["resources"] + + def test_gpu_constraint_adds_gpu_cores(self, mocker, mock_config): + script = ( + "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" + "#SBATCH --time=00:10:00\n#SBATCH -N 1\n#SBATCH -C gpu\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, mock_config, script) + assert "gpu_cores_per_process" in spec["resources"] + assert "cpu_cores_per_process" not in spec["resources"] + + def test_reservation_included_when_present(self, mocker, mock_config): + script = ( + "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" + "#SBATCH --time=00:10:00\n#SBATCH -N 1\n#SBATCH --reservation=myres\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, mock_config, script) + assert spec["attributes"]["reservation_id"] == "myres" + + def test_no_reservation_when_absent(self, mocker, mock_config): + script = "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" + spec = self._submit_and_capture_spec(mocker, mock_config, script) + assert "reservation_id" not in spec["attributes"] + + +# ── wait_for_job ────────────────────────────────────────────────────────────── + +class TestWaitForJob: + def test_sfapi_returns_true_on_complete(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) + job = mocker.MagicMock() + ctrl.client.compute.return_value.job.return_value = job + result = ctrl.wait_for_job("12345") + job.complete.assert_called_once() + assert result is True + + def test_iriapi_returns_true_when_completed(self, mocker, mock_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, mock_config) + ctrl.client.get.return_value.json.return_value = {"status": {"state": "completed"}} + result = ctrl.wait_for_job("42") + assert result is True + + def test_iriapi_returns_false_on_failed(self, mocker, mock_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, mock_config) + ctrl.client.get.return_value.json.return_value = {"status": {"state": "failed"}} + result = ctrl.wait_for_job("42") + assert result is False + + def test_iriapi_returns_false_on_canceled(self, mocker, mock_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, mock_config) + ctrl.client.get.return_value.json.return_value = {"status": {"state": "canceled"}} + result = ctrl.wait_for_job("42") + assert result is False + + def test_iriapi_polls_until_terminal_state(self, mocker, mock_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, mock_config) + responses = [ + {"status": {"state": "running"}}, + {"status": {"state": "running"}}, + {"status": {"state": "completed"}}, + ] + ctrl.client.get.return_value.json.side_effect = responses + result = ctrl.wait_for_job("42") + assert ctrl.client.get.call_count == 3 + assert result is True + + +# ── mkdir_remote ────────────────────────────────────────────────────────────── + +class TestMkdirRemote: + def test_sfapi_runs_mkdir(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) + ctrl.mkdir_remote("/pscratch/sd/t/testuser/mydir") + ctrl.client.compute.return_value.run.assert_called_once() + cmd = ctrl.client.compute.return_value.run.call_args[0][0] + assert "mkdir -p" in cmd + assert "/pscratch/sd/t/testuser/mydir" in cmd + + def test_iriapi_posts_to_mkdir_url(self, mocker, mock_config): + ctrl = _iriapi_controller(mocker, mock_config) + ctrl.mkdir_remote("/pscratch/sd/t/testuser/mydir") + ctrl.client.post.assert_called() + url = ctrl.client.post.call_args[0][0] + assert "mock-login-uuid" in url + + def test_iriapi_posts_path_in_body(self, mocker, mock_config): + ctrl = _iriapi_controller(mocker, mock_config) + ctrl.mkdir_remote("/some/path") + body = ctrl.client.post.call_args[1]["json"] + assert body["path"] == "/some/path" + assert body["parents"] is True + + +# ── read_remote_file ────────────────────────────────────────────────────────── + +class TestReadRemoteFile: + def test_sfapi_returns_string_result(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) + ctrl.client.compute.return_value.run.return_value = "file contents" + result = ctrl.read_remote_file("/some/file.txt") + assert result == "file contents" + + def test_sfapi_extracts_output_attribute(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) + run_result = mocker.MagicMock(spec=[]) # no __str__ shortcuts + run_result.output = "from output attr" + ctrl.client.compute.return_value.run.return_value = run_result + result = ctrl.read_remote_file("/some/file.txt") + assert result == "from output attr" + + def test_iriapi_returns_file_contents_on_completed_task(self, mocker, mock_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, mock_config) + + # First call: GET /filesystem/view → task_id + # Subsequent calls: GET /task/ → status=completed + result + view_response = mocker.MagicMock(is_success=True) + view_response.json.return_value = {"task_id": "task-abc"} + view_response.text = "" + + task_response = mocker.MagicMock(is_success=True) + task_response.json.return_value = {"status": "completed", "result": "file data"} + + ctrl.client.get.side_effect = [view_response, task_response] + result = ctrl.read_remote_file("/pscratch/data.txt") + assert result == "file data" + + def test_iriapi_raises_on_failed_task(self, mocker, mock_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, mock_config) + + view_response = mocker.MagicMock(is_success=True) + view_response.json.return_value = {"task_id": "task-fail"} + view_response.text = "" + + task_response = mocker.MagicMock(is_success=True) + task_response.json.return_value = {"status": "failed", "result": "disk error"} + + ctrl.client.get.side_effect = [view_response, task_response] + with pytest.raises(RuntimeError, match="failed"): + ctrl.read_remote_file("/pscratch/data.txt") + + def test_iriapi_raises_timeout_after_40_polls(self, mocker, mock_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, mock_config) + + view_response = mocker.MagicMock(is_success=True) + view_response.json.return_value = {"task_id": "task-slow"} + view_response.text = "" + + # Always return "pending" — never completes + pending_response = mocker.MagicMock(is_success=True) + pending_response.json.return_value = {"status": "pending", "result": None} + + ctrl.client.get.side_effect = [view_response] + [pending_response] * 40 + with pytest.raises(TimeoutError): + ctrl.read_remote_file("/pscratch/slow.txt") diff --git a/orchestration/_tests/test_jobs/nersc/test_login.py b/orchestration/_tests/test_jobs/nersc/test_login.py new file mode 100644 index 00000000..32f32bfc --- /dev/null +++ b/orchestration/_tests/test_jobs/nersc/test_login.py @@ -0,0 +1,83 @@ +"""Tests for orchestration/jobs/nersc/login.py. + +Patch targets (nersc/login.py has NO Variable or Secret imports): + orchestration.jobs.nersc.login._create_sfapi_client (private builder) + orchestration.jobs.nersc.login._create_iriapi_client (private builder) + +The private builders are patched, not tested directly — they are thin wrappers +around third-party constructors with extensive credential I/O (sfapi_client.Client +__init__, env vars, file reads). Patching them is sufficient to exercise +create_nersc_client's dispatch logic, which is the only thing this module adds. +That choice is intentional. +""" + +from orchestration.jobs.nersc.login import NERSCLoginMethod, create_nersc_client + + +# ── Structural identity check ───────────────────────────────────────────────── + +def test_login_method_is_canonical(): + """NERSCLoginMethod is defined exactly once; bl832 re-export is the same object. + + Uses `is`, not `==` — equality passes even for two separate enum classes + whose members have matching names and values. + """ + from orchestration.jobs.nersc.login import NERSCLoginMethod as A + from orchestration.flows.bl832.job_controller import NERSCLoginMethod as B + assert A is B + + +# ── NERSCLoginMethod enum ───────────────────────────────────────────────────── + +class TestNERSCLoginMethod: + def test_sfapi_value(self): + assert NERSCLoginMethod.SFAPI.value == "sfapi" + + def test_iriapi_value(self): + assert NERSCLoginMethod.IRIAPI.value == "iriapi" + + def test_membership_sfapi(self): + assert NERSCLoginMethod("sfapi") is NERSCLoginMethod.SFAPI + + def test_membership_iriapi(self): + assert NERSCLoginMethod("iriapi") is NERSCLoginMethod.IRIAPI + + +# ── create_nersc_client dispatch ────────────────────────────────────────────── + +class TestCreateNerscClient: + def test_sfapi_dispatches_to_sfapi_builder(self, mocker, mock_config): + mock_client = mocker.MagicMock() + builder = mocker.patch( + "orchestration.jobs.nersc.login._create_sfapi_client", + return_value=mock_client, + ) + result = create_nersc_client(mock_config, NERSCLoginMethod.SFAPI) + builder.assert_called_once_with() + assert result is mock_client + + def test_iriapi_dispatches_to_iriapi_builder(self, mocker, mock_config): + mock_client = mocker.MagicMock() + builder = mocker.patch( + "orchestration.jobs.nersc.login._create_iriapi_client", + return_value=mock_client, + ) + result = create_nersc_client(mock_config, NERSCLoginMethod.IRIAPI) + builder.assert_called_once_with(mock_config.nersc_resources["iri"]["api_base_url"]) + assert result is mock_client + + def test_sfapi_passes_api_base_url_from_config(self, mocker, mock_config): + mocker.patch("orchestration.jobs.nersc.login._create_sfapi_client") + # No assertion on URL for SFAPI (the builder doesn't take a URL arg), + # but create_nersc_client must read the sfapi sub-dict without raising. + create_nersc_client(mock_config, NERSCLoginMethod.SFAPI) + + def test_iriapi_passes_correct_api_base_url(self, mocker, mock_config): + builder = mocker.patch("orchestration.jobs.nersc.login._create_iriapi_client") + create_nersc_client(mock_config, NERSCLoginMethod.IRIAPI) + builder.assert_called_once_with("https://mock-iri.nersc.gov") + + def test_default_login_method_is_iriapi(self, mocker, mock_config): + builder = mocker.patch("orchestration.jobs.nersc.login._create_iriapi_client") + create_nersc_client(mock_config) + builder.assert_called_once() diff --git a/orchestration/_tests/test_jobs/nersc/test_shifter.py b/orchestration/_tests/test_jobs/nersc/test_shifter.py new file mode 100644 index 00000000..d855bc71 --- /dev/null +++ b/orchestration/_tests/test_jobs/nersc/test_shifter.py @@ -0,0 +1,120 @@ +"""Tests for orchestration/jobs/nersc/shifter.py. + +All tests use mocker.MagicMock(spec=NERSCJobController) — spec= is required so +that drift in the controller interface (renamed/removed methods) breaks these +tests rather than silently passing. + +check_shifter_image SFAPI branch deferred import note: + check_shifter_image does `from sfapi_client.compute import Machine` inside the + function body (line 204). When setting up the SFAPI mock, use a plain + mocker.MagicMock() (not spec'd) for the compute() return value — Machine.perlmutter + is evaluated after the patch, so a spec'd mock keyed on the Machine class would fail. + +Patch targets: + orchestration.jobs.nersc.shifter.time.sleep (skips 30-second pull delay) +""" + +from orchestration.jobs.nersc.controller import NERSCJobController +from orchestration.jobs.nersc.login import NERSCLoginMethod +from orchestration.jobs.nersc.shifter import check_shifter_image, pull_shifter_image + + +def _make_controller(mocker, login_method=NERSCLoginMethod.SFAPI): + controller = mocker.MagicMock(spec=NERSCJobController) + controller.login_method = login_method + # client is an instance attribute (set in __init__) — not in the class spec, + # so set it explicitly here so tests can configure controller.client.compute. + controller.client = mocker.MagicMock() + controller.get_nersc_username.return_value = "testuser" + controller.submit_job.return_value = "job-123" + controller.wait_for_job.return_value = True + controller.read_remote_file.return_value = "ghcr.io/als-computing/image:latest found" + return controller + + +class TestPullShifterImage: + def test_submits_pull_script_and_returns_success(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker) + result = pull_shifter_image(controller, "docker:ghcr.io/als/image:latest") + controller.submit_job.assert_called_once() + controller.wait_for_job.assert_called_once_with("job-123") + assert result is True + + def test_script_contains_image_name(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker) + pull_shifter_image(controller, "docker:ghcr.io/als/myimage:v2") + script = controller.submit_job.call_args[0][0] + assert "docker:ghcr.io/als/myimage:v2" in script + + def test_returns_true_when_wait_false(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker) + result = pull_shifter_image(controller, "docker:image:latest", wait=False) + controller.submit_job.assert_called_once() + controller.wait_for_job.assert_not_called() + assert result is True + + def test_returns_false_on_submit_exception(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker) + controller.submit_job.side_effect = RuntimeError("NERSC down") + result = pull_shifter_image(controller, "docker:image:latest") + assert result is False + + def test_mkdir_remote_called_for_log_dir(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker) + pull_shifter_image(controller, "docker:image:latest") + controller.mkdir_remote.assert_called_once() + log_dir_arg = controller.mkdir_remote.call_args[0][0] + assert "testuser" in log_dir_arg + + +class TestCheckShifterImage: + def test_sfapi_returns_true_when_grep_matches(self, mocker): + controller = _make_controller(mocker, login_method=NERSCLoginMethod.SFAPI) + # Use plain MagicMock (not spec'd) for perlmutter — Machine.perlmutter + # is evaluated after the patch inside the function body. + perlmutter = mocker.MagicMock() + perlmutter.run.return_value = "ghcr.io/als/image:latest found in cache" + controller.client.compute.return_value = perlmutter + + result = check_shifter_image(controller, "docker:ghcr.io/als/image:latest") + assert result is True + + def test_sfapi_returns_false_when_no_match(self, mocker): + controller = _make_controller(mocker, login_method=NERSCLoginMethod.SFAPI) + perlmutter = mocker.MagicMock() + perlmutter.run.return_value = "" + controller.client.compute.return_value = perlmutter + + result = check_shifter_image(controller, "docker:ghcr.io/als/image:latest") + assert result is False + + def test_iriapi_submits_check_job_and_reads_output(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker, login_method=NERSCLoginMethod.IRIAPI) + controller.read_remote_file.return_value = "ghcr.io/als/image:latest cached" + + result = check_shifter_image(controller, "docker:ghcr.io/als/image:latest") + controller.submit_job.assert_called_once() + controller.wait_for_job.assert_called_once_with("job-123") + controller.read_remote_file.assert_called_once() + assert result is True + + def test_iriapi_returns_false_when_image_not_in_output(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker, login_method=NERSCLoginMethod.IRIAPI) + controller.read_remote_file.return_value = "" + + result = check_shifter_image(controller, "docker:ghcr.io/als/image:latest") + assert result is False + + def test_returns_false_on_exception(self, mocker): + controller = _make_controller(mocker, login_method=NERSCLoginMethod.SFAPI) + controller.client.compute.side_effect = RuntimeError("network error") + + result = check_shifter_image(controller, "docker:image:latest") + assert result is False diff --git a/orchestration/_tests/test_jobs/test_controller.py b/orchestration/_tests/test_jobs/test_controller.py new file mode 100644 index 00000000..5e9a88e6 --- /dev/null +++ b/orchestration/_tests/test_jobs/test_controller.py @@ -0,0 +1,71 @@ +"""Tests for orchestration/jobs/controller.py.""" + +from orchestration.jobs.controller import JobController, JobTarget + + +# ── Structural identity check ───────────────────────────────────────────────── + +def test_hpc_alias_is_canonical_job_target(): + """HPC in bl832/job_controller.py must be the same object as JobTarget. + + Uses `is`, not `==` — equality passes even for two separate enum classes + whose members have matching names and values. `is` catches a re-introduced + duplicate enum class immediately. + """ + from orchestration.flows.bl832.job_controller import HPC + assert HPC is JobTarget + + +# ── JobTarget enum ──────────────────────────────────────────────────────────── + +class TestJobTarget: + def test_has_alcf_member(self): + assert JobTarget.ALCF is not None + + def test_has_nersc_member(self): + assert JobTarget.NERSC is not None + + def test_has_olcf_member(self): + assert JobTarget.OLCF is not None + + def test_nersc_string_value(self): + assert JobTarget.NERSC.value == "NERSC" + + def test_alcf_string_value(self): + assert JobTarget.ALCF.value == "ALCF" + + def test_olcf_string_value(self): + assert JobTarget.OLCF.value == "OLCF" + + def test_membership_by_value(self): + assert JobTarget("NERSC") is JobTarget.NERSC + + +# ── JobController ABC ───────────────────────────────────────────────────────── + +class TestJobControllerABC: + """JobController is an ABC but declares NO @abstractmethod. + + Direct instantiation succeeds when a valid config is provided. Tests here + verify the contract (stores config, ABC inheritance) rather than checking + for TypeError on instantiation. + """ + + def test_instantiates_with_valid_config(self, mock_config): + controller = JobController(mock_config) + assert controller is not None + + def test_stores_config(self, mock_config): + controller = JobController(mock_config) + assert controller.config is mock_config + + def test_subclass_inherits_config(self, mock_config): + class MockJobController(JobController): + pass + + controller = MockJobController(mock_config) + assert controller.config is mock_config + + def test_is_abc_subclass(self): + from abc import ABC + assert issubclass(JobController, ABC) diff --git a/orchestration/_tests/test_jobs/test_options.py b/orchestration/_tests/test_jobs/test_options.py new file mode 100644 index 00000000..a4a29ec9 --- /dev/null +++ b/orchestration/_tests/test_jobs/test_options.py @@ -0,0 +1,169 @@ +"""Tests for orchestration/jobs/options.py — load_job_options three-layer resolution. + +Patch targets verified from source: + orchestration.jobs.options.Variable.get (prefect.variables.Variable imported at top) + orchestration.jobs.options.get_checkpoint_info (orchestration.mlflow imported at top) +""" + +import json + +from orchestration.jobs.options import load_job_options + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _make_checkpoint(mocker, *, nersc_path="/pscratch/checkpoint.pt", inference_params=None): + cp = mocker.MagicMock() + cp.nersc_path = nersc_path + cp.inference_params = inference_params or {} + return cp + + +# ── Tests ───────────────────────────────────────────────────────────────────── + +class TestLoadJobOptions: + + # ── Layer 1: config defaults ────────────────────────────────────────────── + + def test_returns_config_defaults_when_variable_says_defaults(self, mocker): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + opts = load_job_options("some-var", {"key": "value"}) + assert opts == {"key": "value"} + + def test_returns_copy_not_original(self, mocker): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + base = {"key": "value"} + opts = load_job_options("some-var", base) + opts["key"] = "mutated" + assert base["key"] == "value" + + # ── Layer 2: MLflow ─────────────────────────────────────────────────────── + + def test_mlflow_nersc_path_maps_to_checkpoint_key(self, mocker, mock_config): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + cp = _make_checkpoint(mocker, nersc_path="/pscratch/model.pt") + mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=cp) + + opts = load_job_options( + "var", + {"finetuned_checkpoint_path": "/old/path"}, + config=mock_config, + mlflow_model_name="my-model", + mlflow_checkpoint_key="finetuned_checkpoint_path", + ) + assert opts["finetuned_checkpoint_path"] == "/pscratch/model.pt" + + def test_mlflow_inference_params_overlay_config_defaults(self, mocker, mock_config): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + cp = _make_checkpoint(mocker, inference_params={"batch_size": 16, "threshold": 0.5}) + mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=cp) + + opts = load_job_options( + "var", + {"batch_size": 8, "threshold": 0.3, "other": "kept"}, + config=mock_config, + mlflow_model_name="my-model", + ) + assert opts["batch_size"] == 16 + assert opts["threshold"] == 0.5 + assert opts["other"] == "kept" + + def test_mlflow_layer_skipped_when_config_is_none(self, mocker): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + spy = mocker.patch("orchestration.jobs.options.get_checkpoint_info") + + opts = load_job_options("var", {"key": "value"}, config=None, mlflow_model_name="model") + spy.assert_not_called() + assert opts == {"key": "value"} + + def test_mlflow_layer_skipped_when_model_name_is_none(self, mocker, mock_config): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + spy = mocker.patch("orchestration.jobs.options.get_checkpoint_info") + + opts = load_job_options("var", {"key": "value"}, config=mock_config, mlflow_model_name=None) + spy.assert_not_called() + assert opts == {"key": "value"} + + def test_mlflow_fallback_to_config_when_checkpoint_is_none(self, mocker, mock_config): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=None) + + opts = load_job_options("var", {"key": "value"}, config=mock_config, mlflow_model_name="model") + assert opts == {"key": "value"} + + def test_mlflow_fallback_to_config_when_get_checkpoint_raises(self, mocker, mock_config): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + mocker.patch( + "orchestration.jobs.options.get_checkpoint_info", + side_effect=RuntimeError("mlflow unreachable"), + ) + + opts = load_job_options("var", {"key": "value"}, config=mock_config, mlflow_model_name="model") + assert opts == {"key": "value"} + + def test_mlflow_injects_new_keys_not_in_config(self, mocker, mock_config): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + cp = _make_checkpoint(mocker, inference_params={"new_param": "injected"}) + mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=cp) + + opts = load_job_options("var", {"existing": "kept"}, config=mock_config, mlflow_model_name="model") + assert opts["new_param"] == "injected" + assert opts["existing"] == "kept" + + # ── Layer 3: Prefect Variable overrides ─────────────────────────────────── + + def test_prefect_variable_overrides_win_over_config(self, mocker): + mocker.patch( + "orchestration.jobs.options.Variable.get", + return_value={"defaults": False, "key": "override"}, + ) + opts = load_job_options("var", {"key": "config-default"}) + assert opts["key"] == "override" + + def test_defaults_true_suppresses_variable_overrides(self, mocker): + mocker.patch( + "orchestration.jobs.options.Variable.get", + return_value={"defaults": True, "key": "would-override"}, + ) + opts = load_job_options("var", {"key": "config-default"}) + assert opts["key"] == "config-default" + + def test_json_string_variable_is_parsed(self, mocker): + mocker.patch( + "orchestration.jobs.options.Variable.get", + return_value=json.dumps({"defaults": False, "key": "from-json"}), + ) + opts = load_job_options("var", {"key": "config-default"}) + assert opts["key"] == "from-json" + + def test_variable_get_failure_falls_back_to_opts(self, mocker): + mocker.patch( + "orchestration.jobs.options.Variable.get", + side_effect=Exception("Prefect unavailable"), + ) + opts = load_job_options("var", {"key": "config-default"}) + assert opts == {"key": "config-default"} + + def test_prefect_variable_wins_over_mlflow(self, mocker, mock_config): + mocker.patch( + "orchestration.jobs.options.Variable.get", + return_value={"defaults": False, "batch_size": 99}, + ) + cp = _make_checkpoint(mocker, inference_params={"batch_size": 16}) + mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=cp) + + opts = load_job_options( + "var", + {"batch_size": 8}, + config=mock_config, + mlflow_model_name="model", + ) + assert opts["batch_size"] == 99 + + def test_defaults_key_not_present_in_output(self, mocker): + mocker.patch( + "orchestration.jobs.options.Variable.get", + return_value={"defaults": False, "key": "override"}, + ) + opts = load_job_options("var", {}) + assert "defaults" not in opts diff --git a/orchestration/_tests/test_sfapi_flow.py b/orchestration/_tests/test_sfapi_flow.py deleted file mode 100644 index e0d4a854..00000000 --- a/orchestration/_tests/test_sfapi_flow.py +++ /dev/null @@ -1,63 +0,0 @@ -# orchestration/_tests/test_sfapi_flow.py -import pytest -from uuid import uuid4 - -from prefect.blocks.system import Secret -from prefect.testing.utilities import prefect_test_harness - - -@pytest.fixture(autouse=True, scope="session") -def prefect_test_fixture(): - with prefect_test_harness(): - Secret(value=str(uuid4())).save(name="globus-client-id", overwrite=True) - Secret(value=str(uuid4())).save(name="globus-client-secret", overwrite=True) - yield - - -def test_create_sfapi_client_success(mocker): - """Valid credentials produce a Client instance.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret", - }.get(x)) - mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) - mocker.patch( - "builtins.open", - side_effect=[ - mocker.mock_open(read_data="my-client-id")(), - mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), - ] - ) - mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") - mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") - - client = NERSCTomographyHPCController._create_sfapi_client() - - mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") - assert client is mock_client_cls.return_value - - -def test_create_sfapi_client_missing_paths(mocker): - """Unset env vars raise ValueError.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) - - with pytest.raises(ValueError, match="Missing NERSC credentials paths."): - NERSCTomographyHPCController._create_sfapi_client() - - -def test_create_sfapi_client_missing_files(mocker): - """Env vars set but files absent raise FileNotFoundError.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret", - }.get(x)) - mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - - with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): - NERSCTomographyHPCController._create_sfapi_client() diff --git a/orchestration/flows/bl832/alcf.py b/orchestration/flows/bl832/alcf.py index 28ac7813..326a1b00 100644 --- a/orchestration/flows/bl832/alcf.py +++ b/orchestration/flows/bl832/alcf.py @@ -1,45 +1,33 @@ -from concurrent.futures import Future import datetime +import logging from pathlib import Path -import time from typing import Optional -from globus_compute_sdk import Client, Executor -from globus_compute_sdk.serialize import CombinedCode from prefect import flow, task, get_run_logger -from prefect.blocks.system import Secret from prefect.variables import Variable from orchestration.flows.bl832.config import Config832 from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController +from orchestration.jobs.alcf.controller import ALCFJobController from orchestration.transfer_controller import get_transfer_controller, CopyMethod from orchestration.prefect import schedule_prefect_flow from orchestration.tiled import register_file_to_tiled +logger = logging.getLogger(__name__) -class ALCFTomographyHPCController(TomographyHPCController): - """ - Implementation of TomographyHPCController for ALCF. Methods here leverage Globus Compute for processing tasks. - There is a @staticmethod wrapper for each compute task submitted via Globus Compute. - Also, there is a shared wait_for_globus_compute_future method that waits for the task to complete. - Args: - TomographyHPCController (ABC): Abstract class for tomography HPC controllers. +class ALCFTomographyHPCController(TomographyHPCController, ALCFJobController): + """ALCF tomography HPC controller for BL832. + + Submits reconstruction and multi-resolution jobs to ALCF via Globus Compute. + Beamline-agnostic submit/wait primitives are inherited from ALCFJobController. """ - def __init__( - self, - config: Config832 - ) -> None: - super().__init__(config) - # Load allocation root from the Prefect JSON block - # The block must be registered with the name "alcf-allocation-root-path" - logger = get_run_logger() - allocation_data = Variable.get("alcf-allocation-root-path", _sync=True) - self.allocation_root = allocation_data.get("alcf-allocation-root-path") - if not self.allocation_root: - raise ValueError("Allocation root not found in JSON block 'alcf-allocation-root-path'") - logger.info(f"Allocation root loaded: {self.allocation_root}") + def __init__(self, config: Config832) -> None: + # ALCF doesn't take a pre-built client — Globus Compute Client is + # constructed inside ALCFJobController.submit() per-call. + ALCFJobController.__init__(self, config) + TomographyHPCController.__init__(self, config) def reconstruct( self, @@ -54,26 +42,22 @@ def reconstruct( Returns: bool: True if the task completed successfully, False otherwise. """ - logger = get_run_logger() + run_logger = get_run_logger() file_name = Path(file_path).stem + ".h5" folder_name = Path(file_path).parent.name iri_als_bl832_rundir = f"{self.allocation_root}/data/raw" iri_als_bl832_recon_script = f"{self.allocation_root}/scripts/globus_reconstruction.py" - gcc = Client(code_serialization_strategy=CombinedCode()) - - with Executor(endpoint_id=Secret.load("globus-compute-endpoint").get(), client=gcc) as fxe: - logger.info(f"Running Tomopy reconstruction on {file_name} at ALCF") - future = fxe.submit( - self._reconstruct_wrapper, - iri_als_bl832_rundir, - iri_als_bl832_recon_script, - file_name, - folder_name - ) - result = self._wait_for_globus_compute_future(future, "reconstruction", check_interval=10) - return result + run_logger.info(f"Running Tomopy reconstruction on {file_name} at ALCF") + future = self.submit( + self._reconstruct_wrapper, + iri_als_bl832_rundir, + iri_als_bl832_recon_script, + file_name, + folder_name, + ) + return self.wait_for_future(future, "reconstruction", check_interval=10) @staticmethod def _reconstruct_wrapper( @@ -129,7 +113,7 @@ def build_multi_resolution( Returns: bool: True if the task completed successfully, False otherwise. """ - logger = get_run_logger() + run_logger = get_run_logger() file_name = Path(file_path).stem folder_name = Path(file_path).parent.name @@ -140,19 +124,15 @@ def build_multi_resolution( iri_als_bl832_rundir = f"{self.allocation_root}/data/raw" iri_als_bl832_conversion_script = f"{self.allocation_root}/scripts/tiff_to_zarr.py" - gcc = Client(code_serialization_strategy=CombinedCode()) - - with Executor(endpoint_id=Secret.load("globus-compute-endpoint").get(), client=gcc) as fxe: - logger.info(f"Running Tiff to Zarr on {raw_path} at ALCF") - future = fxe.submit( - self._build_multi_resolution_wrapper, - iri_als_bl832_rundir, - iri_als_bl832_conversion_script, - tiff_scratch_path, - raw_path - ) - result = self._wait_for_globus_compute_future(future, "tiff to zarr conversion", check_interval=10) - return result + run_logger.info(f"Running Tiff to Zarr on {raw_path} at ALCF") + future = self.submit( + self._build_multi_resolution_wrapper, + iri_als_bl832_rundir, + iri_als_bl832_conversion_script, + tiff_scratch_path, + raw_path, + ) + return self.wait_for_future(future, "tiff to zarr conversion", check_interval=10) @staticmethod def _build_multi_resolution_wrapper( @@ -186,76 +166,6 @@ def _build_multi_resolution_wrapper( f"Converted tiff files to zarr;\n {zarr_res}" ) - @staticmethod - def _wait_for_globus_compute_future( - future: Future, - task_name: str, - check_interval: int = 20, - walltime: int = 1200 # seconds = 20 minutes - ) -> bool: - """ - Wait for a Globus Compute task to complete, assuming that if future.done() is False, the task is running. - - Args: - future: The future object returned from the Globus Compute Executor submit method. - task_name: A descriptive name for the task being executed (used for logging). - check_interval: The interval (in seconds) between status checks. - walltime: The maximum time (in seconds) to wait for the task to complete. - - Returns: - bool: True if the task completed successfully within walltime, False otherwise. - """ - logger = get_run_logger() - - start_time = time.time() - success = False - - try: - previous_state = None - while not future.done(): - elapsed_time = time.time() - start_time - if elapsed_time > walltime: - logger.error(f"The {task_name} task exceeded the walltime of {walltime} seconds." - "Cancelling the Globus Compute job.") - future.cancel() - return False - - # Check if the task was cancelled - if future.cancelled(): - logger.warning(f"The {task_name} task was cancelled.") - return False - # Assume the task is running if not done and not cancelled - elif previous_state != 'running': - logger.info(f"The {task_name} task is running...") - previous_state = 'running' - - time.sleep(check_interval) # Wait before the next status check - - # Task is done, check if it was cancelled or raised an exception - if future.cancelled(): - logger.warning(f"The {task_name} task was cancelled after completion.") - return False - - exception = future.exception() - if exception: - logger.error(f"The {task_name} task raised an exception: {exception}") - return False - - # Task completed successfully - result = future.result() - logger.info(f"The {task_name} task completed successfully with result: {result}") - success = True - - except Exception as e: - logger.error(f"An error occurred while waiting for the {task_name} task: {str(e)}") - success = False - - finally: - # Log the total time taken for the task - elapsed_time = time.time() - start_time - logger.info(f"Total duration of the {task_name} task: {elapsed_time:.2f} seconds.") - - return success @task(name="schedule_prune_task") diff --git a/orchestration/flows/bl832/job_controller.py b/orchestration/flows/bl832/job_controller.py index c526f72a..ddf3eff3 100644 --- a/orchestration/flows/bl832/job_controller.py +++ b/orchestration/flows/bl832/job_controller.py @@ -1,48 +1,36 @@ from abc import ABC, abstractmethod from dotenv import load_dotenv -from enum import Enum import logging -from orchestration.flows.bl832.config import Config832 +from orchestration.config import BeamlineConfig +from orchestration.jobs.controller import JobTarget +from orchestration.jobs.nersc.login import NERSCLoginMethod # noqa: F401 — re-exported for callers logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) load_dotenv() -class NERSCLoginMethod(Enum): - """Selects which NERSC API login method to use when creating a NERSC client. - - Each method corresponds to a different set of credentials and API base URL. - """ - - SFAPI = "sfapi" - """Standard Superfacility API via Iris-registered OAuth2 credentials.""" - - IRIAPI = "iriapi" - """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" +# HPC is a BL832-scoped alias for backwards compat. +# New code should use JobTarget from orchestration.jobs.controller. +# TODO: retire this alias after all call sites are migrated. +HPC = JobTarget class TomographyHPCController(ABC): - """ - Abstract class for tomography HPC controllers. - Provides interface methods for reconstruction and building multi-resolution datasets. + """Abstract class for tomography HPC controllers. - Args: - ABC: Abstract Base Class + Provides interface methods for reconstruction and building multi-resolution + datasets. Stays in bl832/job_controller.py because reconstruct and + build_multi_resolution are tomography-specific, not generic infrastructure. """ - def __init__( - self, - config: Config832 - ) -> None: + + def __init__(self, config: BeamlineConfig) -> None: self.config = config @abstractmethod - def reconstruct( - self, - file_path: str = "", - ) -> bool: - """Perform tomography reconstruction + def reconstruct(self, file_path: str = "") -> bool: + """Perform tomography reconstruction. :param file_path: Path to the file to reconstruct. :return: True if successful, False otherwise. @@ -50,11 +38,8 @@ def reconstruct( pass @abstractmethod - def build_multi_resolution( - self, - file_path: str = "", - ) -> bool: - """Generate multi-resolution version of reconstructed tomography + def build_multi_resolution(self, file_path: str = "") -> bool: + """Generate multi-resolution version of reconstructed tomography. :param file_path: Path to the file for which to build multi-resolution data. :return: True if successful, False otherwise. @@ -62,31 +47,18 @@ def build_multi_resolution( pass -class HPC(Enum): - """ - Enum representing different HPC environments. - Use enum names as strings to identify HPC sites, ensuring a standard set of values. - - Members: - ALCF: Argonne Leadership Computing Facility - NERSC: National Energy Research Scientific Computing Center - """ - ALCF = "ALCF" - NERSC = "NERSC" - OLCF = "OLCF" - - def get_controller( hpc_type: HPC, - config: Config832, + config: BeamlineConfig, login_method: NERSCLoginMethod | None = None, ) -> TomographyHPCController: - """ - Factory function that returns an HPC controller instance for the given HPC environment. + """Factory: return the appropriate tomography HPC controller. - :param hpc_type: A string identifying the HPC environment (e.g., 'ALCF', 'NERSC'). - :return: An instance of a TomographyHPCController subclass corresponding to the given HPC environment. - :raises ValueError: If an invalid or unsupported HPC type is specified. + :param hpc_type: Target HPC site (use HPC or JobTarget enum). + :param config: Beamline configuration object. + :param login_method: NERSC-only; which API to authenticate against. + :return: A TomographyHPCController subclass instance. + :raises ValueError: If hpc_type is invalid or config is missing. """ if not isinstance(hpc_type, HPC): raise ValueError(f"Invalid HPC type provided: {hpc_type}") @@ -96,42 +68,25 @@ def get_controller( if hpc_type == HPC.ALCF: from orchestration.flows.bl832.alcf import ALCFTomographyHPCController - return ALCFTomographyHPCController( - config=config - ) + return ALCFTomographyHPCController(config=config) + elif hpc_type == HPC.NERSC: - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod - resolved_login_method = login_method if isinstance(login_method, NERSCLoginMethod) else NERSCLoginMethod.SFAPI + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.jobs.nersc.login import create_nersc_client + resolved_login_method = ( + login_method if isinstance(login_method, NERSCLoginMethod) + else NERSCLoginMethod.SFAPI + ) + client = create_nersc_client(config, resolved_login_method) return NERSCTomographyHPCController( - client=NERSCTomographyHPCController.create_nersc_client( - config=config, - login_method=resolved_login_method - ), config=config, + client=client, login_method=resolved_login_method, ) + elif hpc_type == HPC.OLCF: # TODO: Implement OLCF controller - pass + raise NotImplementedError("OLCF controller not yet implemented") + else: raise ValueError(f"Unsupported HPC type: {hpc_type}") - - -def do_it_all() -> None: - controller = get_controller("ALCF") - controller.reconstruct() - controller.build_multi_resolution() - - file_path = "" - controller = get_controller("NERSC") - controller.reconstruct( - file_path=file_path, - ) - controller.build_multi_resolution( - file_path=file_path, - ) - - -if __name__ == "__main__": - do_it_all() - logger.info("Done.") diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index f5850a61..9d05b8ef 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -2,31 +2,25 @@ import datetime from dotenv import load_dotenv import httpx -import json import logging -import os from pathlib import Path import re import time -from authlib.jose import JsonWebKey from prefect import flow, get_run_logger, task from prefect.variables import Variable from sfapi_client import Client -from sfapi_client.compute import Machine from typing import Any, Optional from orchestration.flows.bl832.config import Config832 - -from orchestration.flows.bl832.job_controller import get_controller, HPC, NERSCLoginMethod, TomographyHPCController +from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) -from orchestration.mlflow import get_checkpoint_info -from orchestration.globus.get_globus_token import ( - get_iri_access_token, - DEFAULT_TOKEN_FILE, -) +from orchestration.jobs.nersc.controller import NERSCJobController +from orchestration.jobs.nersc.login import NERSCLoginMethod +from orchestration.jobs.nersc.shifter import pull_shifter_image, check_shifter_image +from orchestration.jobs.options import load_job_options from orchestration.prefect import schedule_prefect_flow from orchestration.prune_controller import get_prune_controller, PruneMethod from orchestration.tiled import register_file_to_tiled @@ -37,15 +31,12 @@ logger.setLevel(logging.INFO) load_dotenv() -# Applies only to NERSCLoginMethod.IRIAPI -_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" - @dataclass class SegmentationModelSpec: """All config-resolution inputs for a single model+project combination. - Consumed by ``_load_job_options`` and the job-script builders. + Consumed by ``load_job_options`` and the job-script builders. Adding a new model or project means adding one entry to the registry — nothing else changes. @@ -54,8 +45,6 @@ class SegmentationModelSpec: :param mlflow_model_name: Registered MLflow model name. :param mlflow_checkpoint_key: Config key populated from the MLflow model's ``nersc_path`` tag. - :param output_subdir: Subdirectory written under ``seg_folder/``, - e.g. ``'dino'``, ``'sam3'``, ``'dino_moon'``. :param extra_cli_flags: Additional flags injected into the inference command, e.g. ``{'--project': 'moon'}``. Omit flags not needed. """ @@ -66,98 +55,12 @@ class SegmentationModelSpec: extra_cli_flags: dict[str, str] = field(default_factory=dict) -def _load_job_options( - variable_name: str, - config_settings: dict[str, Any], - config: Config832 | None = None, - mlflow_model_name: str | None = None, - mlflow_checkpoint_key: str | None = None, -) -> dict[str, Any]: - """Load job options with three-layer resolution: config → MLflow → Prefect Variable. - - Resolution order (later layers win): - - 1. ``config_settings`` — authoritative defaults from the config YAML. - 2. MLflow Model Registry — if ``mlflow_model_name`` is provided, all - ``inference_params`` tags are overlaid onto opts by their config key name. - ``nersc_path`` is additionally mapped to ``mlflow_checkpoint_key`` if given. - 3. Prefect Variable (``variable_name``) — skipped if absent or ``defaults: true``. - If ``defaults: false``, provided keys override all lower layers. - - Args: - variable_name: Name of the Prefect Variable to load. - config_settings: Settings dict from Config832 used as base defaults. - config: Config832 instance needed for MLflow lookup. If ``None``, the - MLflow layer is skipped. - mlflow_model_name: Registered MLflow model name, e.g. ``'sam3-petiole'``. - If ``None``, the MLflow layer is skipped. - mlflow_checkpoint_key: Config key to populate from the MLflow model's - ``nersc_path`` tag, e.g. ``'finetuned_checkpoint_path'``. +class NERSCTomographyHPCController(TomographyHPCController, NERSCJobController, NerscStreamingMixin): + """NERSC tomography HPC controller for BL832. - Returns: - Resolved options dict ready for use by the caller. - """ - # ── Layer 1: config defaults ────────────────────────────────────────────── - opts = dict(config_settings) - - # ── Layer 2: MLflow registry ────────────────────────────────────────────── - if config is not None and mlflow_model_name: - try: - checkpoint_info = get_checkpoint_info(mlflow_model_name, config) - if checkpoint_info: - # Map nersc_path to the caller-specified checkpoint key - if mlflow_checkpoint_key: - opts[mlflow_checkpoint_key] = checkpoint_info.nersc_path - logger.info( - f"MLflow '{mlflow_model_name}': " - f"{mlflow_checkpoint_key}={checkpoint_info.nersc_path}" - ) - # Overlay all inference params that match existing config keys - overlaid = [] - for k, v in checkpoint_info.inference_params.items(): - if k in opts: - opts[k] = v - overlaid.append(k) - else: - # Also inject new keys (e.g. alcf_path for future use) - opts[k] = v - logger.info( - f"MLflow '{mlflow_model_name}': overlaid params: {overlaid}" - ) - else: - logger.info( - f"MLflow: no production checkpoint for '{mlflow_model_name}', " - "using config defaults." - ) - except Exception as e: - logger.warning( - f"MLflow lookup failed for '{mlflow_model_name}': {e}. " - "Using config defaults." - ) - - # ── Layer 3: Prefect Variable overrides ─────────────────────────────────── - try: - options = Variable.get(variable_name, default={"defaults": True}, _sync=True) - if isinstance(options, str): - options = json.loads(options) - except Exception as e: - logger.warning(f"Could not load '{variable_name}': {e}. Skipping variable overrides.") - return opts - - if options.get("defaults", True): - logger.info(f"Prefect Variable '{variable_name}': no overrides.") - return opts - - overrides = {k: v for k, v in options.items() if k != "defaults"} - logger.info(f"Prefect Variable '{variable_name}': applying overrides: {list(overrides)}") - return {**opts, **overrides} - - -class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin): - """ - Implementation for a NERSC-based tomography HPC controller. - - Submits reconstruction and multi-resolution jobs to NERSC via SFAPI. + Submits reconstruction, multi-resolution, and segmentation jobs to NERSC + via the SFAPI or IRI API. Beamline-agnostic job primitives (submit, wait, + filesystem ops) are inherited from NERSCJobController. """ def __init__( @@ -166,6 +69,7 @@ def __init__( client: Client | httpx.Client | None = None, login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, ) -> None: + NERSCJobController.__init__(self, config, client, login_method) TomographyHPCController.__init__(self, config) self.client = client self.login_method = login_method @@ -176,119 +80,6 @@ def __init__( else: raise ValueError(f"Unsupported NERSCLoginMethod: {login_method}") - @staticmethod - def create_nersc_client( - config: Config832, - login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, - ) -> Client | httpx.Client: - """Create and return a NERSC client for the requested login method. - - Two fundamentally different auth strategies are supported: - - - :attr:`NERSCLoginMethod.SFAPI`: uses an Iris-registered OAuth2 - client ID + private key (NERSC OIDC flow). Set ``PATH_NERSC_CLIENT_ID`` - and ``PATH_NERSC_PRI_KEY`` to the paths of those files. - - - :attr:`NERSCLoginMethod.IRIAPI`: uses a Globus bearer token written - by ``globus_token.py``. Set ``PATH_GLOBUS_TOKEN_FILE`` to the token - file path, or rely on the default (``~/.globus/auth_tokens.json``). - - Args: - config: Config832 instance for accessing config settings needed during client creation. - login_method: Which NERSC API to authenticate against. - Defaults to :attr:`NERSCLoginMethod.SFAPI`. - - Returns: - An authenticated :class:`sfapi_client.Client` instance. - - Raises: - ValueError: If SFAPI credential environment variables are unset. - FileNotFoundError: If credential or token files are absent. - RuntimeError: If the Globus token is expired. - Exception: If the underlying client construction fails. - """ - logger.info(f"Creating NERSC client using login method: {login_method.value}") - - if login_method is NERSCLoginMethod.SFAPI: - api_base_url = config.nersc_resources["sfapi"]["api_base_url"] - client = NERSCTomographyHPCController._create_sfapi_client() - - elif login_method is NERSCLoginMethod.IRIAPI: - api_base_url = config.nersc_resources["iri"]["api_base_url"] - client = NERSCTomographyHPCController._create_iriapi_client(api_base_url) - else: - raise ValueError(f"Unhandled NERSCLoginMethod: {login_method}") - - logger.info( - f"NERSC client created successfully " - f"(method={login_method.value}, api_url={api_base_url})." - ) - return client - - @staticmethod - def _create_iriapi_client(api_base_url: str) -> httpx.Client: - """Create a NERSC client for the IRI API using a Globus bearer token. - - Requires ``GLOBUS_CLIENT_ID`` and ``GLOBUS_CLIENT_SECRET`` in the - environment. Reuses a cached token if valid; otherwise mints a new one - via the client credentials grant. No browser or user interaction. - - Parameters: - api_base_url: The base URL for the NERSC IRI API - Returns: - An authenticated :class:`httpx.Client` targeting the IRI API. - - Raises: - ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. - RuntimeError: If the acquired token is missing required scopes. - """ - token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) - token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE - - access_token = get_iri_access_token( - token_file=token_file, - force_login=False, - prompt_login=False - ) - - return httpx.Client( - base_url=api_base_url, - headers={"Authorization": f"Bearer {access_token}"}, - timeout=httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0), - ) - - @staticmethod - def _create_sfapi_client() -> Client: - """Create and return an NERSC client instance""" - - # When generating the SFAPI Key in Iris, make sure to select "asldev" as the user! - # Otherwise, the key will not have the necessary permissions to access the data. - client_id_path = os.getenv("PATH_NERSC_CLIENT_ID") - client_secret_path = os.getenv("PATH_NERSC_PRI_KEY") - - if not client_id_path or not client_secret_path: - logger.error("NERSC credentials paths are missing.") - raise ValueError("Missing NERSC credentials paths.") - if not os.path.isfile(client_id_path) or not os.path.isfile(client_secret_path): - logger.error("NERSC credential files are missing.") - raise FileNotFoundError("NERSC credential files are missing.") - - client_id = None - client_secret = None - with open(client_id_path, "r") as f: - client_id = f.read() - - with open(client_secret_path, "r") as f: - client_secret = JsonWebKey.import_key(json.loads(f.read())) - - try: - client = Client(client_id, client_secret) - logger.info("NERSC client created successfully.") - return client - except Exception as e: - logger.error(f"Failed to create NERSC client: {e}") - raise e - def _get_segmentation_spec(self, model: str, project: str) -> SegmentationModelSpec: """Return the SegmentationModelSpec for a model+project combination. @@ -327,234 +118,6 @@ def _get_segmentation_spec(self, model: str, project: str) -> SegmentationModelS ) return registry[key] - def _get_nersc_username(self) -> str: - """Get the NERSC username for constructing pscratch paths. - - Uses the sfapi_client user endpoint for SFAPI, or reads - ``NERSC_USERNAME`` from the environment for IRIAPI. - - Returns: - NERSC username string. - - Raises: - ValueError: If IRIAPI is selected and NERSC_USERNAME is unset. - """ - if self.login_method is NERSCLoginMethod.SFAPI: - return self.client.user().name - else: - username = os.getenv("NERSC_USERNAME") - if not username: - raise ValueError( - "NERSC_USERNAME must be set in the environment when using IRIAPI." - ) - return username - - def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: - """Submit a Slurm job script and return the job ID. - - Dispatches to the appropriate submission mechanism based on - ``self.login_method``. - - Args: - job_script: The full Slurm batch script to submit. - num_nodes: The number of nodes to request for the job. - - Returns: - The submitted job ID as a string. - - Raises: - RuntimeError: If job submission fails. - """ - if self.login_method is NERSCLoginMethod.SFAPI: - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - return str(job.jobid) - - elif self.login_method is NERSCLoginMethod.IRIAPI: - sbatch_values = {} - for line in job_script.splitlines(): - if line.startswith("#SBATCH"): - if "-q " in line: - sbatch_values["queue_name"] = line.split("-q ")[-1].strip() - elif "-A " in line: - sbatch_values["account"] = line.split("-A ")[-1].strip() - elif "--time=" in line: - t = line.split("--time=")[-1].strip() - parts = t.split(":") - sbatch_values["duration"] = int(parts[0])*3600 + int(parts[1])*60 + int(parts[2]) - elif "-N " in line: - sbatch_values["node_count"] = int(line.split("-N ")[-1].strip()) - elif "-C " in line: - sbatch_values["constraint"] = line.split("-C ")[-1].strip() - elif "--output=" in line: - sbatch_values["stdout_path"] = line.split("--output=")[-1].strip() - elif "--error=" in line: - sbatch_values["stderr_path"] = line.split("--error=")[-1].strip() - elif "--reservation=" in line: - sbatch_values["reservation"] = line.split("--reservation=")[-1].strip() - - # Strip shebang and SBATCH headers, keep the script body - script_body = "\n".join( - line for line in job_script.splitlines() - if not line.startswith("#SBATCH") and not line.startswith("#!/") - ).strip() - - constraint = sbatch_values.get("constraint", "cpu") - is_gpu = "gpu" in constraint.lower() - - resources = { - "node_count": sbatch_values.get("node_count", 1), - "processes_per_node": 1, - "exclusive_node_use": True, - } - if is_gpu: - resources["gpu_cores_per_process"] = 4 - else: - resources["cpu_cores_per_process"] = 128 - - custom_attributes = {"constraint": constraint} - - attributes = { - "duration": sbatch_values.get("duration", 1800), - "queue_name": sbatch_values.get("queue_name", "regular"), - "account": sbatch_values.get("account", "als"), - "custom_attributes": custom_attributes, - } - if "reservation" in sbatch_values: - attributes["reservation_id"] = sbatch_values["reservation"] - - job_spec = { - "executable": "/bin/bash", - "arguments": ["-s"], # read script from stdin isn't supported, so... - "pre_launch": script_body, # run the body here before the executable - "resources": resources, - "attributes": attributes, - } - - if "stdout_path" in sbatch_values: - job_spec["stdout_path"] = sbatch_values["stdout_path"] - if "stderr_path" in sbatch_values: - job_spec["stderr_path"] = sbatch_values["stderr_path"] - - response = self.client.post( - f"/api/v1/compute/job/{self.nersc_resources['perlmutter_job_submit']}", - json=job_spec, - ) - if not response.is_success: - logger.error(f"Job submission failed: {response.status_code} {response.text}") - logger.error(f"Job spec was: {json.dumps(job_spec, indent=2)}") - response.raise_for_status() - return str(response.json()["id"]) - - else: - raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") - - def _wait_for_job(self, job_id: str) -> bool: - """Block until a submitted job completes. - - Dispatches to the appropriate polling mechanism based on - ``self.login_method``. - - Args: - job_id: The job ID returned by `_submit_job`. - - Returns: - True if the job completed successfully, False otherwise. - """ - if self.login_method is NERSCLoginMethod.SFAPI: - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.job(jobid=job_id) - job.complete() - return True - - elif self.login_method is NERSCLoginMethod.IRIAPI: - while True: - response = self.client.get( - f"/api/v1/compute/status/{self.nersc_resources['compute_resource']}/{job_id}" - ) - response.raise_for_status() - state = response.json().get("status", {}).get("state") - logger.info(f"Job {job_id} state: {state}") - if state == "completed": - return True - if state in ("failed", "canceled", "timeout"): - logger.error(f"Job {job_id} ended with state: {state}") - return False - time.sleep(60) - - else: - raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") - - def _mkdir_remote(self, path: str) -> None: - """Create a directory on Perlmutter remotely. - - Args: - path: Absolute path to create. - """ - if self.login_method is NERSCLoginMethod.SFAPI: - perlmutter = self.client.compute(Machine.perlmutter) - perlmutter.run(f"mkdir -p {path}") - elif self.login_method is NERSCLoginMethod.IRIAPI: - response = self.client.post( - f"/api/v1/filesystem/mkdir/{self.nersc_resources['perlmutter_login']}", - json={"path": path, "parents": True}, - ) - response.raise_for_status() - else: - raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") - - def _read_remote_file(self, path: str) -> str: - """Read a remote file on Perlmutter and return its contents. - - Args: - path: Absolute path to the file on Perlmutter. - - Returns: - File contents as a string. - """ - if self.login_method is NERSCLoginMethod.SFAPI: - perlmutter = self.client.compute(Machine.perlmutter) - result = perlmutter.run(f"cat {path}") - if isinstance(result, str): - return result - elif hasattr(result, 'output'): - return result.output - elif hasattr(result, 'stdout'): - return result.stdout - return str(result) - - elif self.login_method is NERSCLoginMethod.IRIAPI: - response = self.client.get( - f"/api/v1/filesystem/view/{self.nersc_resources['perlmutter_login']}", - params={"path": path}, - ) - response.raise_for_status() - task_id = response.json().get("task_id") - if not task_id: - return response.text - - for _ in range(40): - task_response = self.client.get(f"/api/v1/task/{task_id}") - task_response.raise_for_status() - task = task_response.json() - status = task.get("status") - if status == "completed": - result = task.get("result", "") - if isinstance(result, dict): - output = result.get("output", result) - if isinstance(output, dict): - return output.get("content", str(output)) - return str(output) - return str(result) - elif status == "failed": - raise RuntimeError(f"File read task {task_id} failed: {task.get('result')}") - time.sleep(3) - - raise TimeoutError(f"File read task {task_id} did not complete") - - else: - raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") - def reconstruct( self, file_path: str = "", @@ -569,7 +132,7 @@ def reconstruct( """ logger.info("Starting NERSC reconstruction process.") - username = self._get_nersc_username() + username = self.get_nersc_username() raw_path = self.config.nersc832_alsdev_raw.root_path logger.info(f"{raw_path=}") @@ -597,7 +160,7 @@ def reconstruct( logger.info(f"Folder name: {folder_name}") logger.info(f"Number of nodes: {num_nodes}") - opts = _load_job_options("nersc-reconstruction-options", self.config.nersc_recon_settings) + opts = load_job_options("nersc-reconstruction-options", self.config.nersc_recon_settings) logger.info(f"Resolved options: {opts}") @@ -739,10 +302,10 @@ def reconstruct( job_id = None try: logger.info("Submitting reconstruction job to Perlmutter.") - job_id = self._submit_job(job_script) + job_id = self.submit_job(job_script) logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - success = self._wait_for_job(job_id) + success = self.wait_for_job(job_id) timing = self._fetch_timing_data(pscratch_path, job_id) if success else None return {"success": success, "job_id": job_id, "timing": timing} except Exception as e: @@ -760,7 +323,7 @@ def _fetch_timing_data(self, pscratch_path: str, job_id: str) -> dict: timing_file = f"{pscratch_path}/tomo_recon_logs/timing_{job_id}.txt" try: - output = self._read_remote_file(timing_file) + output = self.read_remote_file(timing_file) logger.info(f"Timing file contents:\n{output}") @@ -816,7 +379,7 @@ def build_multi_resolution( logger.info("Starting NERSC multiresolution process.") - username = self._get_nersc_username() + username = self.get_nersc_username() multires_image = self.config.ghcr_images832["multires_image"] logger.info(f"{multires_image=}") @@ -842,7 +405,7 @@ def build_multi_resolution( # account = self.config.nersc_account - opts = _load_job_options( + opts = load_job_options( "nersc-multiresolution-options", self.config.nersc_multiresolution_settings ) @@ -883,10 +446,10 @@ def build_multi_resolution( """ try: logger.info("Submitting Tiff to Zarr job to Perlmutter.") - job_id = self._submit_job(job_script) + job_id = self.submit_job(job_script) logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - success = self._wait_for_job(job_id) + success = self.wait_for_job(job_id) logger.info(f"Multiresolution job {'completed' if success else 'failed'}.") return success except Exception as e: @@ -903,10 +466,10 @@ def segmentation_sam3( """ logger.info("Starting NERSC segmentation process (inference_v6).") - username = self._get_nersc_username() + username = self.get_nersc_username() pscratch_path = f"/pscratch/sd/{username[0]}/{username}" - opts = _load_job_options( + opts = load_job_options( variable_name="nersc-segmentation-options", config_settings=self.config.nersc_segment_sam3_settings, config=self.config, @@ -1098,14 +661,14 @@ def segmentation_sam3( # Ensure directories exist logger.info("Creating necessary directories...") - self._mkdir_remote(f"{pscratch_path}/tomo_seg_logs") - self._mkdir_remote(output_dir) + self.mkdir_remote(f"{pscratch_path}/tomo_seg_logs") + self.mkdir_remote(output_dir) # Submit job - job_id = self._submit_job(job_script) + job_id = self.submit_job(job_script) logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - success = self._wait_for_job(job_id) + success = self.wait_for_job(job_id) logger.info("Segmentation job completed successfully.") timing = self._fetch_seg_timing_from_output(pscratch_path, job_id, job_name) @@ -1154,12 +717,12 @@ def segmentation_dinov3( """ logger.info("Starting NERSC DINOv3 segmentation process.") - username = self._get_nersc_username() + username = self.get_nersc_username() pscratch_path = f"/pscratch/sd/{username[0]}/{username}" # Load from config spec = self._get_segmentation_spec("dinov3", project) - opts = _load_job_options( + opts = load_job_options( variable_name=spec.variable_name, config_settings=spec.settings, config=self.config, @@ -1295,10 +858,10 @@ def segmentation_dinov3( """ try: logger.info("Submitting DINOv3 segmentation job to Perlmutter.") - job_id = self._submit_job(job_script) + job_id = self.submit_job(job_script) logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - success = self._wait_for_job(job_id) + success = self.wait_for_job(job_id) logger.info(f"DINOv3 segmentation job {'completed successfully' if success else 'failed'}.") return success except Exception as e: @@ -1319,10 +882,10 @@ def combine_segmentations( """ logger.info("Starting NERSC segmentation combination process.") - username = self._get_nersc_username() + username = self.get_nersc_username() pscratch_path = f"/pscratch/sd/{username[0]}/{username}" - opts = _load_job_options( + opts = load_job_options( "nersc-combine-seg-options", self.config.nersc_combine_segmentation_settings ) @@ -1419,10 +982,10 @@ def combine_segmentations( """ try: logger.info("Submitting segmentation combination job to Perlmutter.") - job_id = self._submit_job(job_script) + job_id = self.submit_job(job_script) logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - success = self._wait_for_job(job_id) + success = self.wait_for_job(job_id) logger.info(f"Segmentation combination job {'completed successfully' if success else 'failed'}.") return success except Exception as e: @@ -1441,7 +1004,7 @@ def _fetch_seg_timing_from_output(self, pscratch_path: str, job_id: str, job_nam output_file = f"{pscratch_path}/tomo_seg_logs/{job_name}_{job_id}.out" try: - output = self._read_remote_file(output_file) + output = self.read_remote_file(output_file) logger.info("Job output file contents (last 50 lines):") lines = output.strip().split('\n') @@ -1501,140 +1064,6 @@ def start_streaming_service( walltime=walltime ) - def pull_shifter_image( - self, - image: str = None, - wait: bool = True, - ) -> bool: - """ - Pull a container image into NERSC's Shifter cache. - - This should be run once when the image is updated, not before every reconstruction. - After the image is cached, jobs using --image= will start much faster. - - :param image: Container image to pull (defaults to recon_image from config) - :param wait: Whether to wait for the pull to complete - :return: True if successful, False otherwise - """ - logger.info("Starting Shifter image pull.") - - username = self._get_nersc_username() - pscratch_path = f"/pscratch/sd/{username[0]}/{username}" - - if image is None: - image = self.config.ghcr_images832["recon_image"] - - logger.info(f"Pulling image: {image}") - - job_script = f"""#!/bin/bash -#SBATCH -q debug -#SBATCH -A als -#SBATCH -C cpu -#SBATCH --job-name=shifter_pull -#SBATCH --output={pscratch_path}/tomo_recon_logs/shifter_pull_%j.out -#SBATCH --error={pscratch_path}/tomo_recon_logs/shifter_pull_%j.err -#SBATCH -N 1 -#SBATCH --ntasks=1 -#SBATCH --cpus-per-task=1 -#SBATCH --time=0:15:00 - -echo "Starting Shifter image pull at $(date)" -echo "Image: {image}" - -# Check if image already exists -echo "Checking existing images..." -shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/')" || true - -# Pull the image -echo "Pulling image..." -shifterimg -v pull {image} -PULL_STATUS=$? - -if [ $PULL_STATUS -eq 0 ]; then - echo "Image pull successful" -else - echo "Image pull failed with status $PULL_STATUS" - exit 1 -fi - -# Verify the image is now available -echo "Verifying image..." -shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/')" - -echo "Completed at $(date)" -""" - - try: - logger.info("Submitting Shifter image pull job to Perlmutter.") - job_id = self._submit_job(job_script) - logger.info(f"Submitted job ID: {job_id}") - - if wait: - time.sleep(30) - success = self._wait_for_job(job_id) - logger.info(f"Shifter image pull {'completed successfully' if success else 'failed'}.") - return success - else: - logger.info(f"Job submitted. Check status with job ID: {job_id}") - return True - - except Exception as e: - logger.error(f"Error during Shifter image pull: {e}") - return False - - def check_shifter_image( - self, - image: str = None, - ) -> bool: - """ - Check if a container image is already in NERSC's Shifter cache. - - :param image: Container image to check (defaults to recon_image from config) - :return: True if image exists in cache, False otherwise - """ - logger.info("Checking Shifter image cache.") - - if image is None: - image = self.config.ghcr_images832["recon_image"] - - try: - # Run shifterimg images command - if self.login_method is NERSCLoginMethod.SFAPI: - # synchronous via utilities/command - perlmutter = self.client.compute(Machine.perlmutter) - result = perlmutter.run(f"shifterimg images | grep -E \"$(echo {image} | sed 's/:/.*/g')\"") - output = result if isinstance(result, str) else getattr(result, 'output', str(result)) - - elif self.login_method is NERSCLoginMethod.IRIAPI: - # async: submit job → wait → read stdout file - username = self._get_nersc_username() - pscratch_path = f"/pscratch/sd/{username[0]}/{username}" - output_file = f"{pscratch_path}/tomo_recon_logs/shifter_check.txt" - check_script = f"""#!/bin/bash -#SBATCH -q debug -#SBATCH -A als -#SBATCH -C cpu -#SBATCH -N 1 -#SBATCH --ntasks=1 -#SBATCH --cpus-per-task=1 -#SBATCH --time=0:05:00 -shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/g')" > {output_file} 2>&1 || true -""" - job_id = self._submit_job(check_script) - self._wait_for_job(job_id) - output = self._read_remote_file(output_file) - - if output.strip(): - logger.info(f"Image found in Shifter cache: {output.strip()}") - return True - else: - logger.info(f"Image not found in Shifter cache: {image}") - return False - - except Exception as e: - logger.warning(f"Error checking Shifter cache: {e}") - return False - def schedule_pruning( config: Config832, @@ -1908,14 +1337,6 @@ def nersc_recon_flow( nersc_to_beegfs_zarr_future.result() logger.info("All transfers complete.") - # Register the reconstructed TIFFs in tiled - register_file_to_tiled( - path=Path(config.beegfs_scratch.root_path+tiff_file_path), - prefix="beamlines/bl832/scratch", - overwrite=False, - tags=["scratch", "bl832"], - ) - # Register the reconstructed ZARRs in tiled register_file_to_tiled( path=Path(config.beegfs_scratch.root_path+zarr_file_path), @@ -1954,6 +1375,7 @@ def nersc_petiole_segment_flow( :param file_path: The path to the file to be processed. :param config: Configuration object for the flow. :param num_nodes: Number of nodes for reconstruction. + :login_method: Method to use for logging into NERSC (default: SFAPI). :return: True if reconstruction and at least one segmentation task succeeded. """ logger = get_run_logger() @@ -1996,7 +1418,6 @@ def nersc_petiole_segment_flow( file_path=file_path, num_nodes=num_nodes, config=config, - login_method=login_method ) if isinstance(recon_result, dict): @@ -2050,7 +1471,7 @@ def nersc_petiole_segment_flow( logger.info("Submitting SAM3 and DINOv3 segmentation tasks concurrently.") sam3_future = nersc_segmentation_sam3_task.submit( - recon_folder_path=scratch_path_tiff, config=config, login_method=login_method + recon_folder_path=scratch_path_tiff, config=config ) dinov3_future = nersc_segmentation_dinov3_task.submit( recon_folder_path=scratch_path_tiff, config=config, project="petiole", login_method=login_method @@ -2103,7 +1524,7 @@ def nersc_petiole_segment_flow( logger.info("Running segmentation combination.") combine_future = nersc_combine_segmentations_task.submit( - recon_folder_path=scratch_path_tiff, config=config, login_method=login_method + recon_folder_path=scratch_path_tiff, config=config ) combine_success = combine_future.result() @@ -2535,10 +1956,10 @@ def pull_shifter_image_flow( ) # Check if already cached - if controller.check_shifter_image(image): + if check_shifter_image(controller, image): logger.info("Image already in cache, pulling anyway to update...") - success = controller.pull_shifter_image(image) + success = pull_shifter_image(controller, image) logger.info(f"Shifter image pull success: {success}") return success @@ -2581,7 +2002,7 @@ def nersc_multiresolution_task( :param file_path: Path to the reconstructed data folder to be processed. :param config: Configuration object for the flow. - :param login_method: NERSC API to authenticate against. + :param login_method: Method to use for logging into NERSC (default: SFAPI). :return: True if the task completed successfully, False otherwise. """ logger = get_run_logger() diff --git a/orchestration/jobs/__init__.py b/orchestration/jobs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/jobs/alcf/__init__.py b/orchestration/jobs/alcf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/jobs/alcf/controller.py b/orchestration/jobs/alcf/controller.py new file mode 100644 index 00000000..6dcf9644 --- /dev/null +++ b/orchestration/jobs/alcf/controller.py @@ -0,0 +1,199 @@ +# orchestration/jobs/alcf/controller.py + +"""Generic ALCF job controller. + +Wraps Globus Compute's submit/wait pattern behind a common controller +interface. Beamline-specific job functions (the actual Python callables +executed remotely on ALCF compute nodes) live in +``orchestration/flows//alcf.py`` and subclass this controller. + +ALCF execution is fundamentally different from NERSC in shape: instead of +submitting a Slurm batch script and getting a job ID back, you submit a +Python callable (with arguments) and get a :class:`concurrent.futures.Future` +back. The wait/poll logic is correspondingly different too. + +Authentication is via a Globus Compute endpoint ID stored in the Prefect +Secret block ``globus-compute-endpoint``. The allocation root path on the +remote filesystem (e.g. ``/eagle/IRIProd/ALS``) is stored in the Prefect +Variable ``alcf-allocation-root-path``. +""" + +import logging +import time +from concurrent.futures import Future +from typing import Any, Callable + +from globus_compute_sdk import Client, Executor +from globus_compute_sdk.serialize import CombinedCode +from prefect import get_run_logger +from prefect.blocks.system import Secret +from prefect.variables import Variable + +from orchestration.config import BeamlineConfig +from orchestration.jobs.controller import JobController + +logger = logging.getLogger(__name__) + + +# Defaults for wait_for_future. Override per-call when a job is known to be +# longer-running or faster-polling than typical. +_DEFAULT_CHECK_INTERVAL_SECONDS: int = 20 +_DEFAULT_WALLTIME_SECONDS: int = 1200 # 20 minutes + +# Prefect block / variable names. Defined as constants so callers and tests +# have a single place to override. +_GLOBUS_COMPUTE_ENDPOINT_SECRET: str = "globus-compute-endpoint" +_ALLOCATION_ROOT_VARIABLE: str = "alcf-allocation-root-path" + + +class ALCFJobController(JobController): + """Generic ALCF job submission and monitoring via Globus Compute. + + Subclass for beamline-specific work (e.g. ``ALCFTomographyJobController`` + in ``flows/bl832/alcf.py``). This class knows nothing about tomography, + BL832, or any particular pipeline — it only knows how to submit a + callable to ALCF and wait for it. + + Args: + config: Beamline configuration object. Stored as ``self.config`` + for subclass use; this controller doesn't read any specific + fields from it. + + Attributes: + allocation_root: Remote path prefix on the ALCF filesystem + (e.g. ``/eagle/IRIProd/ALS``). Read from the + ``alcf-allocation-root-path`` Prefect Variable. Available for + subclasses to construct script and data paths. + endpoint_id: Globus Compute endpoint UUID, loaded from the + ``globus-compute-endpoint`` Prefect Secret block. + """ + + def __init__(self, config: BeamlineConfig) -> None: + super().__init__(config) + + allocation_data = Variable.get(_ALLOCATION_ROOT_VARIABLE, _sync=True) + self.allocation_root: str = allocation_data.get(_ALLOCATION_ROOT_VARIABLE) + if not self.allocation_root: + raise ValueError( + f"Allocation root not found in Prefect Variable " + f"'{_ALLOCATION_ROOT_VARIABLE}'" + ) + logger.info(f"Allocation root loaded: {self.allocation_root}") + + self.endpoint_id: str = Secret.load(_GLOBUS_COMPUTE_ENDPOINT_SECRET).get() + + def submit(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Future: + """Submit a callable to the Globus Compute endpoint. + + The callable is shipped to the ALCF compute node and executed there + with the supplied args/kwargs. Code-serialization uses + :class:`CombinedCode` so the function can reference imports and + helpers defined in the same module without manual packing. + + Args: + func: The function to run remotely. Should be picklable and + self-contained — imports inside the function body are + evaluated on the remote node, not locally. + *args: Positional arguments passed to ``func`` on the remote node. + **kwargs: Keyword arguments passed to ``func`` on the remote node. + + Returns: + A :class:`concurrent.futures.Future`-compatible object. Use + :meth:`wait_for_future` to poll until done. + """ + gcc = Client(code_serialization_strategy=CombinedCode()) + # Note: the Executor context manager handles connection cleanup, but + # the Future remains valid after the executor exits — submission has + # already been dispatched to the remote endpoint by that point. + with Executor(endpoint_id=self.endpoint_id, client=gcc) as fxe: + future = fxe.submit(func, *args, **kwargs) + return future + + @staticmethod + def wait_for_future( + future: Future, + task_name: str, + check_interval: int = _DEFAULT_CHECK_INTERVAL_SECONDS, + walltime: int = _DEFAULT_WALLTIME_SECONDS, + ) -> bool: + """Block until a Globus Compute future completes or hits walltime. + + Polls ``future.done()`` every ``check_interval`` seconds. If the + future is still not done after ``walltime`` seconds total, the + future is cancelled and ``False`` is returned. Logging uses Prefect's + run logger so progress shows up in the flow run UI. + + Args: + future: The future returned by :meth:`submit`. + task_name: Short descriptive name used in log messages + (e.g. ``"reconstruction"``, ``"tiff to zarr"``). + check_interval: Seconds between ``future.done()`` polls. + walltime: Maximum total seconds to wait before giving up and + cancelling. + + Returns: + True if the task completed successfully (future returned without + raising). False if the task was cancelled, raised an exception, + timed out, or an error occurred during polling. + """ + run_logger = get_run_logger() + start_time = time.time() + success = False + + try: + previous_state = None + while not future.done(): + elapsed_time = time.time() - start_time + if elapsed_time > walltime: + run_logger.error( + f"The {task_name} task exceeded the walltime of " + f"{walltime} seconds. Cancelling the Globus Compute job." + ) + future.cancel() + return False + + if future.cancelled(): + run_logger.warning(f"The {task_name} task was cancelled.") + return False + + # Assume the task is running if not done and not cancelled. + # Log once per state transition rather than every poll. + if previous_state != "running": + run_logger.info(f"The {task_name} task is running...") + previous_state = "running" + + time.sleep(check_interval) + + # Future is done — check whether it was cancelled or raised. + if future.cancelled(): + run_logger.warning( + f"The {task_name} task was cancelled after completion." + ) + return False + + exception = future.exception() + if exception: + run_logger.error( + f"The {task_name} task raised an exception: {exception}" + ) + return False + + result = future.result() + run_logger.info( + f"The {task_name} task completed successfully with result: {result}" + ) + success = True + + except Exception as e: + run_logger.error( + f"An error occurred while waiting for the {task_name} task: {e}" + ) + success = False + + finally: + elapsed_time = time.time() - start_time + run_logger.info( + f"Total duration of the {task_name} task: {elapsed_time:.2f} seconds." + ) + + return success diff --git a/orchestration/jobs/controller.py b/orchestration/jobs/controller.py new file mode 100644 index 00000000..5e312004 --- /dev/null +++ b/orchestration/jobs/controller.py @@ -0,0 +1,57 @@ +# orchestration/jobs/controller.py +"""Generic job controller abstractions. + +Defines the minimal base class and target enum that any backend-specific +controller (NERSC, ALCF, future local clusters) builds on. Beamline-specific +controllers live under ``orchestration/flows//`` and subclass +:class:`JobController` (or a domain-specific intermediate like +``TomographyJobController``). +""" + +from abc import ABC +from enum import Enum +import logging + +from orchestration.config import BeamlineConfig + +logger = logging.getLogger(__name__) + + +class JobTarget(Enum): + """Identifies a job execution target. + + Used by beamline-specific factories (e.g. ``flows/bl832/job_controller.py``) + to dispatch to the appropriate concrete controller class. + + Members: + ALCF: Argonne Leadership Computing Facility + NERSC: National Energy Research Scientific Computing Center + OLCF: Oak Ridge Leadership Computing Facility + """ + + ALCF = "ALCF" + NERSC = "NERSC" + OLCF = "OLCF" + + +class JobController(ABC): + """Base class for job controllers. + + Subclasses provide target-specific job submission and monitoring (see + :class:`orchestration.jobs.nersc.controller.NERSCJobController` and + :class:`orchestration.jobs.alcf.controller.ALCFJobController`). + + No abstract methods are declared here because submission shapes differ + fundamentally between targets — NERSC submits Slurm scripts and returns + job IDs, ALCF submits Python callables via Globus Compute and returns + futures. Domain-specific intermediates (e.g. ``TomographyJobController``) + may add ``@abstractmethod``s like ``reconstruct(file_path)`` when multiple + controllers share an interface. + + Args: + config: Beamline configuration object. Stored as ``self.config`` for + subclass use. + """ + + def __init__(self, config: BeamlineConfig) -> None: + self.config = config diff --git a/orchestration/jobs/nersc/__init__.py b/orchestration/jobs/nersc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/jobs/nersc/controller.py b/orchestration/jobs/nersc/controller.py new file mode 100644 index 00000000..3d52fe90 --- /dev/null +++ b/orchestration/jobs/nersc/controller.py @@ -0,0 +1,346 @@ +# orchestration/jobs/nersc/controller.py + +"""Generic NERSC job controller. + +Wraps the NERSC SFAPI and IRI API behind a common controller interface: +submit Slurm scripts, poll for completion, and perform basic filesystem +operations on Perlmutter. Beamline-specific job script builders live in +``orchestration/flows//nersc.py`` and subclass this controller. + +Two authentication modes are supported, selected by :class:`NERSCLoginMethod`: + +- :attr:`NERSCLoginMethod.SFAPI` — Iris-registered OAuth2 (NERSC OIDC). + Operations go through ``sfapi_client.Client``. + +- :attr:`NERSCLoginMethod.IRIAPI` — Globus bearer token. Operations are + raw ``httpx`` calls to the IRI API. + +The login method also selects the ``nersc_resources`` sub-dict +(``config.nersc_resources["iri"]`` vs ``["sfapi"]``), which carries the API +base URL and resource UUIDs used in URL construction. +""" + +import json +import logging +import os +import time + +import httpx +from sfapi_client import Client +from sfapi_client.compute import Machine + +from orchestration.config import BeamlineConfig +from orchestration.jobs.controller import JobController +from orchestration.jobs.nersc.login import NERSCLoginMethod + +logger = logging.getLogger(__name__) + + +class NERSCJobController(JobController): + """Generic NERSC job submission and monitoring. + + Subclass for beamline-specific work (e.g. ``NERSCTomographyJobController`` + in ``flows/bl832/nersc.py``). This class knows nothing about tomography, + BL832, or any particular pipeline — it only knows how to submit a Slurm + script and wait for it. + + Args: + config: Beamline configuration object. Must expose + ``config.nersc_resources`` with ``"iri"`` and ``"sfapi"`` sub-dicts + populated from the YAML. + client: Authenticated NERSC client. Build with + :func:`orchestration.jobs.nersc.login.create_nersc_client`. + login_method: Which NERSC API the client targets. Determines URL + construction and which dispatch branch is used in each method. + + Attributes: + client: The authenticated NERSC client. + login_method: The selected :class:`NERSCLoginMethod`. + nersc_resources: Sub-dict of ``config.nersc_resources`` for the chosen + login method. Contains ``api_base_url``, ``perlmutter_login``, + ``perlmutter_job_submit``, ``compute_resource``, etc. + """ + + def __init__( + self, + config: BeamlineConfig, + client: Client | httpx.Client | None = None, + login_method: NERSCLoginMethod = NERSCLoginMethod.IRIAPI, + ) -> None: + super().__init__(config) + self.client = client + self.login_method = login_method + + if login_method is NERSCLoginMethod.IRIAPI: + self.nersc_resources: dict[str, str] = config.nersc_resources["iri"] + elif login_method is NERSCLoginMethod.SFAPI: + self.nersc_resources = config.nersc_resources["sfapi"] + else: + raise ValueError(f"Unsupported NERSCLoginMethod: {login_method}") + + def get_nersc_username(self) -> str: + """Return the NERSC username, used to construct ``pscratch`` paths. + + SFAPI exposes the username via the user endpoint. IRIAPI does not, so + the username is read from the ``NERSC_USERNAME`` environment variable. + + Returns: + NERSC username string. + + Raises: + ValueError: If IRIAPI is selected and ``NERSC_USERNAME`` is unset. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + return self.client.user().name + + username = os.getenv("NERSC_USERNAME") + if not username: + raise ValueError( + "NERSC_USERNAME must be set in the environment when using IRIAPI." + ) + return username + + def submit_job(self, job_script: str, num_nodes: int = 1) -> str: + """Submit a Slurm batch script and return the job ID. + + For SFAPI, the script is passed verbatim to ``perlmutter.submit_job``. + + For IRIAPI, the script is parsed: SBATCH headers become attributes in + the PSI/J-style job spec, and the script body (everything after the + SBATCH block) becomes the ``pre_launch`` payload. The IRI API does not + accept raw Slurm scripts, so this translation is necessary. + + Args: + job_script: The full Slurm batch script to submit. + num_nodes: Reserved for future use; the actual node count is read + from the SBATCH ``-N`` line in the script. + + Returns: + The submitted job ID as a string. + + Raises: + ValueError: If ``self.login_method`` is unrecognized. + httpx.HTTPStatusError: If the IRI API submission returns non-2xx. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + return str(job.jobid) + + elif self.login_method is NERSCLoginMethod.IRIAPI: + return self._submit_job_iriapi(job_script) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def _submit_job_iriapi(self, job_script: str) -> str: + """Translate a Slurm script into a PSI/J job spec and POST to IRI API.""" + sbatch_values: dict[str, object] = {} + for line in job_script.splitlines(): + if not line.startswith("#SBATCH"): + continue + if "-q " in line: + sbatch_values["queue_name"] = line.split("-q ")[-1].strip() + elif "-A " in line: + sbatch_values["account"] = line.split("-A ")[-1].strip() + elif "--time=" in line: + t = line.split("--time=")[-1].strip() + parts = t.split(":") + sbatch_values["duration"] = ( + int(parts[0]) * 3600 + int(parts[1]) * 60 + int(parts[2]) + ) + elif "-N " in line: + sbatch_values["node_count"] = int(line.split("-N ")[-1].strip()) + elif "-C " in line: + sbatch_values["constraint"] = line.split("-C ")[-1].strip() + elif "--output=" in line: + sbatch_values["stdout_path"] = line.split("--output=")[-1].strip() + elif "--error=" in line: + sbatch_values["stderr_path"] = line.split("--error=")[-1].strip() + elif "--reservation=" in line: + sbatch_values["reservation"] = line.split("--reservation=")[-1].strip() + + # Script body: everything except shebang and SBATCH headers. + script_body = "\n".join( + line for line in job_script.splitlines() + if not line.startswith("#SBATCH") and not line.startswith("#!/") + ).strip() + + constraint = sbatch_values.get("constraint", "cpu") + is_gpu = "gpu" in str(constraint).lower() + + resources = { + "node_count": sbatch_values.get("node_count", 1), + "processes_per_node": 1, + "exclusive_node_use": True, + } + if is_gpu: + resources["gpu_cores_per_process"] = 4 + else: + resources["cpu_cores_per_process"] = 128 + + attributes = { + "duration": sbatch_values.get("duration", 1800), + "queue_name": sbatch_values.get("queue_name", "regular"), + "account": sbatch_values.get("account", "als"), + "custom_attributes": {"constraint": constraint}, + } + if "reservation" in sbatch_values: + attributes["reservation_id"] = sbatch_values["reservation"] + + job_spec = { + "executable": "/bin/bash", + # Reading the script from stdin isn't supported, so the body goes + # into pre_launch (runs before the executable's main entry point). + "arguments": ["-s"], + "pre_launch": script_body, + "resources": resources, + "attributes": attributes, + } + if "stdout_path" in sbatch_values: + job_spec["stdout_path"] = sbatch_values["stdout_path"] + if "stderr_path" in sbatch_values: + job_spec["stderr_path"] = sbatch_values["stderr_path"] + + response = self.client.post( + f"/api/v1/compute/job/{self.nersc_resources['perlmutter_job_submit']}", + json=job_spec, + ) + if not response.is_success: + logger.error(f"Job submission failed: {response.status_code} {response.text}") + logger.error(f"Job spec was: {json.dumps(job_spec, indent=2)}") + response.raise_for_status() + return str(response.json()["id"]) + + def wait_for_job(self, job_id: str) -> bool: + """Block until a submitted job reaches a terminal state. + + For SFAPI, this delegates to ``sfapi_client``'s ``job.complete()``, + which handles polling internally. For IRIAPI, polls the status + endpoint every 60 seconds. + + Args: + job_id: The job ID returned by :meth:`submit_job`. + + Returns: + True if the job completed successfully, False if it failed, + was canceled, or hit a timeout. + + Raises: + ValueError: If ``self.login_method`` is unrecognized. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.job(jobid=job_id) + job.complete() + return True + + elif self.login_method is NERSCLoginMethod.IRIAPI: + while True: + response = self.client.get( + f"/api/v1/compute/status/{self.nersc_resources['compute_resource']}/{job_id}" + ) + response.raise_for_status() + state = response.json().get("status", {}).get("state") + logger.info(f"Job {job_id} state: {state}") + if state == "completed": + return True + if state in ("failed", "canceled", "timeout"): + logger.error(f"Job {job_id} ended with state: {state}") + return False + time.sleep(60) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def mkdir_remote(self, path: str) -> None: + """Create a directory on Perlmutter. + + Equivalent to ``mkdir -p`` — intermediate directories are created and + existing directories are not an error. + + Args: + path: Absolute path to create on Perlmutter. + + Raises: + ValueError: If ``self.login_method`` is unrecognized. + httpx.HTTPStatusError: If the IRI filesystem call fails. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + perlmutter.run(f"mkdir -p {path}") + elif self.login_method is NERSCLoginMethod.IRIAPI: + response = self.client.post( + f"/api/v1/filesystem/mkdir/{self.nersc_resources['perlmutter_login']}", + json={"path": path, "parents": True}, + ) + response.raise_for_status() + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def read_remote_file(self, path: str) -> str: + """Read a file on Perlmutter and return its contents as a string. + + SFAPI runs ``cat`` synchronously. IRIAPI uses the async filesystem + view endpoint, which returns a task_id that must be polled. + + Args: + path: Absolute path to the file on Perlmutter. + + Returns: + File contents as a string. + + Raises: + ValueError: If ``self.login_method`` is unrecognized. + RuntimeError: If the IRI read task ends in a failed state. + TimeoutError: If the IRI read task does not complete within + the local polling budget (~120 seconds). + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + result = perlmutter.run(f"cat {path}") + if isinstance(result, str): + return result + if hasattr(result, "output"): + return result.output + if hasattr(result, "stdout"): + return result.stdout + return str(result) + + elif self.login_method is NERSCLoginMethod.IRIAPI: + return self._read_remote_file_iriapi(path) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def _read_remote_file_iriapi(self, path: str) -> str: + """Read a remote file via the IRI async filesystem-view endpoint.""" + response = self.client.get( + f"/api/v1/filesystem/view/{self.nersc_resources['perlmutter_login']}", + params={"path": path}, + ) + response.raise_for_status() + task_id = response.json().get("task_id") + if not task_id: + return response.text + + for _ in range(40): + task_response = self.client.get(f"/api/v1/task/{task_id}") + task_response.raise_for_status() + task = task_response.json() + status = task.get("status") + if status == "completed": + result = task.get("result", "") + if isinstance(result, dict): + output = result.get("output", result) + if isinstance(output, dict): + return output.get("content", str(output)) + return str(output) + return str(result) + elif status == "failed": + raise RuntimeError( + f"File read task {task_id} failed: {task.get('result')}" + ) + time.sleep(3) + + raise TimeoutError(f"File read task {task_id} did not complete") diff --git a/orchestration/jobs/nersc/login.py b/orchestration/jobs/nersc/login.py new file mode 100644 index 00000000..7b0776b1 --- /dev/null +++ b/orchestration/jobs/nersc/login.py @@ -0,0 +1,167 @@ +# orchestration/jobs/nersc/login.py + +"""NERSC API authentication and client construction. + +Two login methods are supported, each backed by a different credential model +and pointed at a different API base URL: + +- :attr:`NERSCLoginMethod.SFAPI`: Iris-registered OAuth2 client ID + private + key (NERSC OIDC flow). Reads ``PATH_NERSC_CLIENT_ID`` and + ``PATH_NERSC_PRI_KEY`` from the environment. +- :attr:`NERSCLoginMethod.IRIAPI`: Globus bearer token written by + ``orchestration/globus/get_globus_token.py``. Reads ``PATH_GLOBUS_TOKEN_FILE`` + from the environment, falling back to ``~/.globus/auth_tokens.json``. + +The base URLs for each method live in the beamline config under +``nersc_resources.{iri,sfapi}.api_base_url``. +""" + +from enum import Enum +import json +import logging +import os +from pathlib import Path + +from authlib.jose import JsonWebKey +from dotenv import load_dotenv +import httpx +from sfapi_client import Client + +from orchestration.config import BeamlineConfig +from orchestration.globus.get_globus_token import ( + DEFAULT_TOKEN_FILE, + get_iri_access_token, +) + +logger = logging.getLogger(__name__) +load_dotenv() + +# Env var pointing at the cached Globus token file used by IRIAPI auth. +_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" + + +class NERSCLoginMethod(Enum): + """Selects which NERSC API to authenticate against. + + Each method has its own credentials and API base URL — see module docstring. + """ + + SFAPI = "sfapi" + """Standard Superfacility API via Iris-registered OAuth2 credentials.""" + + IRIAPI = "iriapi" + """Integrated Research Infrastructure API via Globus bearer token.""" + + +def create_nersc_client( + config: BeamlineConfig, + login_method: NERSCLoginMethod = NERSCLoginMethod.IRIAPI, +) -> Client | httpx.Client: + """Create and return a NERSC client for the requested login method. + + Reads the API base URL from ``config.nersc_resources[login_method.value]``, + then delegates to the appropriate underscored builder. + + Args: + config: Beamline config instance. Must expose ``nersc_resources`` with + ``"iri"`` and ``"sfapi"`` sub-dicts each containing ``api_base_url``. + login_method: Which NERSC API to authenticate against. Defaults to + :attr:`NERSCLoginMethod.IRIAPI`. + + Returns: + An authenticated client — :class:`sfapi_client.Client` for SFAPI, + :class:`httpx.Client` for IRIAPI. + + Raises: + ValueError: If SFAPI credential env vars are unset, or if + ``login_method`` is not a recognized member. + FileNotFoundError: If SFAPI credential files are absent. + RuntimeError: If the Globus token is expired or missing required scopes. + """ + logger.info(f"Creating NERSC client using login method: {login_method.value}") + + if login_method is NERSCLoginMethod.SFAPI: + api_base_url = config.nersc_resources["sfapi"]["api_base_url"] + client = _create_sfapi_client() + elif login_method is NERSCLoginMethod.IRIAPI: + api_base_url = config.nersc_resources["iri"]["api_base_url"] + client = _create_iriapi_client(api_base_url) + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {login_method}") + + logger.info( + f"NERSC client created successfully " + f"(method={login_method.value}, api_url={api_base_url})." + ) + return client + + +def _create_iriapi_client(api_base_url: str) -> httpx.Client: + """Create a NERSC IRI API client using a Globus bearer token. + + Reuses a cached token if valid; otherwise mints a new one via the client + credentials grant. No browser or user interaction. + + Args: + api_base_url: Base URL for the NERSC IRI API. + + Returns: + An authenticated :class:`httpx.Client` targeting the IRI API. + + Raises: + ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. + RuntimeError: If the acquired token is missing required scopes. + """ + token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) + token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE + + access_token = get_iri_access_token( + token_file=token_file, + force_login=False, + prompt_login=False, + ) + + return httpx.Client( + base_url=api_base_url, + headers={"Authorization": f"Bearer {access_token}"}, + timeout=httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0), + ) + + +def _create_sfapi_client() -> Client: + """Create a NERSC SFAPI client from Iris-registered OAuth2 credentials. + + Reads the client ID and private key paths from ``PATH_NERSC_CLIENT_ID`` + and ``PATH_NERSC_PRI_KEY``. When generating the SFAPI key in Iris, the + "asldev" user must be selected so the key has the necessary data-access + permissions. + + Returns: + An authenticated :class:`sfapi_client.Client`. + + Raises: + ValueError: If the credential env vars are unset. + FileNotFoundError: If the credential files are absent. + """ + client_id_path = os.getenv("PATH_NERSC_CLIENT_ID") + client_secret_path = os.getenv("PATH_NERSC_PRI_KEY") + + if not client_id_path or not client_secret_path: + logger.error("NERSC credentials paths are missing.") + raise ValueError("Missing NERSC credentials paths.") + if not os.path.isfile(client_id_path) or not os.path.isfile(client_secret_path): + logger.error("NERSC credential files are missing.") + raise FileNotFoundError("NERSC credential files are missing.") + + with open(client_id_path, "r") as f: + client_id = f.read() + with open(client_secret_path, "r") as f: + client_secret = JsonWebKey.import_key(json.loads(f.read())) + + try: + client = Client(client_id, client_secret) + logger.info("NERSC client created successfully.") + return client + except Exception as e: + logger.error(f"Failed to create NERSC client: {e}") + raise diff --git a/orchestration/jobs/nersc/shifter.py b/orchestration/jobs/nersc/shifter.py new file mode 100644 index 00000000..348b4be7 --- /dev/null +++ b/orchestration/jobs/nersc/shifter.py @@ -0,0 +1,242 @@ +# orchestration/jobs/nersc/shifter.py + +"""Shifter container image cache management at NERSC. + +Shifter is NERSC's container runtime. Images pulled into the per-system cache +are available to all subsequent jobs via ``shifter --image=`` without +a per-job pull penalty. These helpers manage that cache. + +Both functions take a :class:`NERSCJobController` so they can use its +``submit_job`` / ``wait_for_job`` / ``read_remote_file`` primitives. This +keeps the controller focused on the submit/wait/filesystem core and avoids +ballooning it with one-off Shifter operations. +""" + +import logging +import time +from typing import TYPE_CHECKING + +from orchestration.jobs.nersc.login import NERSCLoginMethod + +if TYPE_CHECKING: + from orchestration.jobs.nersc.controller import NERSCJobController + +logger = logging.getLogger(__name__) + + +def _shifter_pull_script( + image: str, + log_dir: str, + *, + account: str, + qos: str = "debug", + walltime: str = "0:15:00", +) -> str: + """ + Build the Slurm script that pulls a Shifter image. + + Args: + - image: The container image to pull + - log_dir: Directory to store Slurm output and error logs + - account: Slurm account to charge the pull job to + - qos: Slurm QoS for the pull job (default: "debug") + - walltime: Walltime for the pull job (default: "0:15:00") + + Returns: + - A string containing the Slurm job script to pull the specified Shifter image. + """ + return f"""#!/bin/bash +#SBATCH -q {qos} +#SBATCH -A {account} +#SBATCH -C cpu +#SBATCH --job-name=shifter_pull +#SBATCH --output={log_dir}/shifter_pull_%j.out +#SBATCH --error={log_dir}/shifter_pull_%j.err +#SBATCH -N 1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=1 +#SBATCH --time={walltime} + +echo "Starting Shifter image pull at $(date)" +echo "Image: {image}" + +echo "Checking existing images..." +shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/')" || true + +echo "Pulling image..." +shifterimg -v pull {image} +PULL_STATUS=$? + +if [ $PULL_STATUS -eq 0 ]; then + echo "Image pull successful" +else + echo "Image pull failed with status $PULL_STATUS" + exit 1 +fi + +echo "Verifying image..." +shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/')" + +echo "Completed at $(date)" +""" + + +def _shifter_check_script( + image: str, + output_file: str, + *, + account: str, + qos: str = "debug", + walltime: str = "0:05:00", +) -> str: + """ + Build the Slurm script that writes Shifter cache state to a file. + + Args: + - image: The container image to check for in the Shifter cache + - output_file: The file to write the check output to + - account: Slurm account to charge the check job to + - qos: Slurm QoS for the check job (default: "debug") + - walltime: Walltime for the check job (default: "0:05:00") + + Returns: + - A string containing the Slurm job script to check for the specified Shifter image and write the results. + """ + return f"""#!/bin/bash +#SBATCH -q {qos} +#SBATCH -A {account} +#SBATCH -C cpu +#SBATCH -N 1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=1 +#SBATCH --time={walltime} +shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/g')" > {output_file} 2>&1 || true +""" + + +def _user_log_dir(controller: "NERSCJobController") -> str: + """ + Return the per-user log directory used for Shifter Slurm output. + This is a convention for storing logs in a user-specific location on the shared filesystem. + It is not an official NERSC requirement, but it helps avoid cluttering home directories and keeps logs organized. + The path is typically of the form: /pscratch/sd///shifter_logs + + Args: + - controller: A NERSCJobController instance used to get the username for constructing the log directory path. + Returns: + - A string representing the path to the user's log directory for Shifter jobs. + """ + username = controller.get_nersc_username() + return f"/pscratch/sd/{username[0]}/{username}/shifter_logs" + + +def pull_shifter_image( + controller: "NERSCJobController", + image: str, + wait: bool = True, + account: str = "als", +) -> bool: + """Pull a container image into NERSC's Shifter cache. + + Run this once when an image is updated, not before every job that uses it. + After the image is cached, jobs using ``--image=`` start much faster. + + Args: + controller: A NERSC job controller used to submit and monitor the + Slurm pull job. + image: Container image to pull, e.g. + ``"docker:ghcr.io/als-computing/tomopy-multinode:latest"``. + wait: If True, block until the pull job finishes and return its + success state. If False, return True as soon as the job is + submitted. + account: Slurm account to charge the pull job to. Defaults to + ``"als"``; pass a different value for other beamlines. + + Returns: + ``True`` if the pull succeeded (or was submitted, when ``wait=False``), + ``False`` if submission or the pull itself failed. + """ + logger.info(f"Pulling Shifter image: {image}") + + log_dir = _user_log_dir(controller) + controller.mkdir_remote(log_dir) + + job_script = _shifter_pull_script(image, log_dir, account=account) + + try: + job_id = controller.submit_job(job_script) + logger.info(f"Submitted Shifter pull job: {job_id}") + + if not wait: + logger.info(f"Returning early; check status with job ID {job_id}") + return True + + time.sleep(30) + success = controller.wait_for_job(job_id) + logger.info(f"Shifter image pull {'completed successfully' if success else 'failed'}.") + return success + + except Exception as e: + logger.error(f"Error during Shifter image pull: {e}") + return False + + +def check_shifter_image( + controller: "NERSCJobController", + image: str, + account: str = "als", +) -> bool: + """Check whether a container image is already in NERSC's Shifter cache. + + Dispatches on the controller's login method. SFAPI can run ``shifterimg`` + synchronously via the utilities endpoint; IRIAPI must submit a Slurm job, + wait for it, then read the captured stdout. + + Args: + controller: A NERSC job controller used to query Shifter. + image: Container image to check. + account: Slurm account for the IRIAPI check job. Ignored for SFAPI. + + Returns: + ``True`` if the image is present in the cache, ``False`` otherwise + (including on error). + """ + from sfapi_client.compute import Machine + + logger.info(f"Checking Shifter cache for: {image}") + + try: + if controller.login_method is NERSCLoginMethod.SFAPI: + # Synchronous: run shifterimg directly via the utilities endpoint + perlmutter = controller.client.compute(Machine.perlmutter) + result = perlmutter.run( + f"shifterimg images | grep -E \"$(echo {image} | sed 's/:/.*/g')\"" + ) + output = ( + result if isinstance(result, str) + else getattr(result, "output", None) or getattr(result, "stdout", "") or str(result) + ) + + elif controller.login_method is NERSCLoginMethod.IRIAPI: + # Async: submit a one-off job, wait, read the captured output file + log_dir = _user_log_dir(controller) + controller.mkdir_remote(log_dir) + output_file = f"{log_dir}/shifter_check_{int(time.time())}.txt" + + job_script = _shifter_check_script(image, output_file, account=account) + job_id = controller.submit_job(job_script) + controller.wait_for_job(job_id) + output = controller.read_remote_file(output_file) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {controller.login_method}") + + if output.strip(): + logger.info(f"Image found in Shifter cache: {output.strip()}") + return True + logger.info(f"Image not found in Shifter cache: {image}") + return False + + except Exception as e: + logger.warning(f"Error checking Shifter cache: {e}") + return False diff --git a/orchestration/jobs/options.py b/orchestration/jobs/options.py new file mode 100644 index 00000000..7b3da21c --- /dev/null +++ b/orchestration/jobs/options.py @@ -0,0 +1,105 @@ +"""Three-layer job option resolution: config defaults → MLflow → Prefect Variable. + +Lets beamline-specific code declare a base settings dict, optionally tie it to +an MLflow Model Registry entry, and accept runtime overrides via a Prefect +Variable — all without hardcoding any one beamline or HPC target. +""" + +import json +import logging +from typing import Any + +from prefect.variables import Variable + +from orchestration.config import BeamlineConfig +from orchestration.mlflow import get_checkpoint_info + +logger = logging.getLogger(__name__) + + +def load_job_options( + variable_name: str, + config_settings: dict[str, Any], + config: BeamlineConfig | None = None, + mlflow_model_name: str | None = None, + mlflow_checkpoint_key: str | None = None, +) -> dict[str, Any]: + """Load job options with three-layer resolution: config → MLflow → Prefect Variable. + + Resolution order (later layers win): + + 1. ``config_settings`` — authoritative defaults from the config YAML. + 2. MLflow Model Registry — if ``mlflow_model_name`` is provided, all + ``inference_params`` tags are overlaid onto opts by their config key name. + ``nersc_path`` is additionally mapped to ``mlflow_checkpoint_key`` if given. + 3. Prefect Variable (``variable_name``) — skipped if absent or ``defaults: true``. + If ``defaults: false``, provided keys override all lower layers. + + Args: + variable_name: Name of the Prefect Variable to load. + config_settings: Settings dict used as base defaults (e.g. + ``config.nersc_segment_sam3_settings``). + config: Beamline config instance needed for MLflow lookup. If ``None``, + the MLflow layer is skipped. + mlflow_model_name: Registered MLflow model name, e.g. ``'sam3-petiole'``. + If ``None``, the MLflow layer is skipped. + mlflow_checkpoint_key: Config key to populate from the MLflow model's + ``nersc_path`` tag, e.g. ``'finetuned_checkpoint_path'``. + + Returns: + Resolved options dict ready for use by the caller. + """ + # ── Layer 1: config defaults ────────────────────────────────────────────── + opts = dict(config_settings) + + # ── Layer 2: MLflow registry ────────────────────────────────────────────── + if config is not None and mlflow_model_name: + try: + checkpoint_info = get_checkpoint_info(mlflow_model_name, config) + if checkpoint_info: + # Map nersc_path to the caller-specified checkpoint key + if mlflow_checkpoint_key: + opts[mlflow_checkpoint_key] = checkpoint_info.nersc_path + logger.info( + f"MLflow '{mlflow_model_name}': " + f"{mlflow_checkpoint_key}={checkpoint_info.nersc_path}" + ) + # Overlay all inference params that match existing config keys + overlaid = [] + for k, v in checkpoint_info.inference_params.items(): + if k in opts: + opts[k] = v + overlaid.append(k) + else: + # Also inject new keys (e.g. alcf_path for future use) + opts[k] = v + logger.info( + f"MLflow '{mlflow_model_name}': overlaid params: {overlaid}" + ) + else: + logger.info( + f"MLflow: no production checkpoint for '{mlflow_model_name}', " + "using config defaults." + ) + except Exception as e: + logger.warning( + f"MLflow lookup failed for '{mlflow_model_name}': {e}. " + "Using config defaults." + ) + + # ── Layer 3: Prefect Variable overrides ─────────────────────────────────── + try: + options = Variable.get(variable_name, default={"defaults": True}, _sync=True) + if isinstance(options, str): + options = json.loads(options) + except Exception as e: + logger.warning(f"Could not load '{variable_name}': {e}. Skipping variable overrides.") + return opts + + if options.get("defaults", True): + logger.info(f"Prefect Variable '{variable_name}': no overrides.") + return opts + + overrides = {k: v for k, v in options.items() if k != "defaults"} + logger.info(f"Prefect Variable '{variable_name}': applying overrides: {list(overrides)}") + return {**opts, **overrides}