From b1e3225c66278a419ff95744252b2a0c42f05f17 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 16 Mar 2026 15:39:22 -0700 Subject: [PATCH 01/26] Initial commit for adding NERSC IRI-API support alongside SFAPI for job submission --- orchestration/globus/token.py | 235 ++++++++++++++++++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 orchestration/globus/token.py diff --git a/orchestration/globus/token.py b/orchestration/globus/token.py new file mode 100644 index 00000000..81b5438f --- /dev/null +++ b/orchestration/globus/token.py @@ -0,0 +1,235 @@ +import json +import logging +import os +from pathlib import Path +import stat +import time + +import globus_sdk +from globus_sdk.exc import GlobusAPIError + +logger = logging.getLogger(__name__) + +# Default token file location, matching the Globus SDK convention. +DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" +GLOBUS_OIDC_TOKEN_URL: str = "https://auth.globus.org/v2/oauth2/token" + + +def get_access_token_confidential( + client_id: str, + client_secret: str, + required_scopes: frozenset[str], + resource_server: str, + token_file: Path | None = None, +) -> str: + """Get a valid Globus access token using a Confidential Client (machine-to-machine). + + No browser or user interaction required. If a valid unexpired token exists + on disk it is reused; otherwise a new one is minted via the client + credentials grant and saved. + + Args: + client_id: Globus Confidential App client ID. + client_secret: Globus Confidential App client secret. + required_scopes: Set of OAuth2 scopes that must be present on the token. + resource_server: Resource server key to extract from the token response. + token_file: Path to the JSON token cache file. Defaults to + ``~/.globus/auth_tokens.json``. + + Returns: + A valid Globus access token string. + + Raises: + RuntimeError: If the acquired token is missing required scopes. + KeyError: If ``access_token`` is absent from the token response. 
+ """ + resolved_token_file = token_file or DEFAULT_TOKEN_FILE + + # 1. Do we already have a valid token? + stored = load_token_file(resolved_token_file) + if stored: + expires_at = stored.get("expires_at_seconds") + if expires_at and time.time() < expires_at: + logger.info("Using cached Globus token (still valid).") + return stored["access_token"] + logger.info("Cached Globus token is expired; minting a new one.") + else: + logger.info("No cached Globus token found; minting a new one.") + + # 2. Mint a new token — same call whether first time or expired. + globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) + token_response = globus_client.oauth2_client_credentials_tokens( + requested_scopes=" ".join(sorted(required_scopes)) + ) + auth_data = token_response.by_resource_server[resource_server] + + granted = set(auth_data.get("scope", "").split()) + missing = required_scopes - granted + if missing: + raise RuntimeError( + f"Globus token is missing required scopes: {sorted(missing)}" + ) + + save_token_file(resolved_token_file, auth_data) + logger.info(f"New Globus token saved to {resolved_token_file}.") + + return auth_data["access_token"] + + +def load_token_file(token_file: Path) -> dict | None: + """Load saved Globus token data from disk. + + Args: + token_file: Path to the JSON token file. + + Returns: + Parsed token dict, or None if the file does not exist. + """ + if not token_file.exists(): + return None + with token_file.open("r", encoding="utf-8") as f: + return json.load(f) + + +def save_token_file(token_file: Path, tokens: dict) -> None: + """Atomically save Globus token data to disk with owner-only permissions. + + Writes to a temporary file then renames to avoid partial writes. + + Args: + token_file: Destination path for the JSON token file. + tokens: Token dict to serialise. 
+ """ + _ensure_private_parent_dir(token_file) + tmp = token_file.with_suffix(".tmp") + with os.fdopen( + os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), + "w", + encoding="utf-8", + ) as f: + json.dump(tokens, f, indent=2) + os.replace(tmp, token_file) + os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + + +def interactive_login( + client: globus_sdk.NativeAppAuthClient, + required_scopes: frozenset[str], + resource_server: str, +) -> dict: + """Run an interactive browser-based Globus login flow. + + Prints an authorization URL, waits for the user to paste an auth code, + and exchanges it for tokens. + + Args: + client: Globus NativeAppAuthClient to drive the flow. + required_scopes: Set of OAuth2 scopes to request. + resource_server: Resource server key to extract from the token response + (e.g. ``"auth.globus.org"``). + + Returns: + Token dict for the given resource server. + """ + client.oauth2_start_flow( + requested_scopes=" ".join(sorted(required_scopes)), + refresh_tokens=True, + ) + logger.info("Open this URL in your browser to authenticate with Globus:") + logger.info(client.oauth2_get_authorize_url()) + code = input("\nEnter authorization code: ").strip() + token_response = client.oauth2_exchange_code_for_tokens(code) + return token_response.by_resource_server[resource_server] + + +def refresh_tokens( + client: globus_sdk.NativeAppAuthClient, + refresh_token: str, + resource_server: str, +) -> dict | None: + """Attempt a silent Globus token refresh. + + Args: + client: Globus NativeAppAuthClient to drive the refresh. + refresh_token: The stored refresh token. + resource_server: Resource server key to extract from the token response. + + Returns: + Fresh token dict for the given resource server, or None if refresh failed. 
+ """ + try: + token_response = client.oauth2_refresh_token(refresh_token) + return token_response.by_resource_server[resource_server] + except GlobusAPIError as e: + logger.warning( + f"Globus token refresh failed ({e.http_status}); " + "falling back to interactive login." + ) + return None + + +def get_access_token( + client_id: str, + required_scopes: frozenset[str], + resource_server: str, + token_file: Path | None = None, + force_login: bool = False, +) -> str: + """Get a valid Globus access token, refreshing or logging in as needed. + + Attempts a silent refresh from the saved token file first. Falls back to + interactive browser login if no saved tokens exist, the refresh token is + absent, or the refresh fails. Saves the resulting tokens back to disk. + + Args: + client_id: Globus NativeApp client ID. + required_scopes: Set of OAuth2 scopes that must be present on the token. + resource_server: Resource server key to extract from the token response. + token_file: Path to the JSON token file. Defaults to + ``~/.globus/auth_tokens.json``. + force_login: If True, skip refresh and force interactive login. + + Returns: + A valid Globus access token string. + + Raises: + RuntimeError: If the acquired token is missing required scopes. + KeyError: If ``access_token`` is absent from the token response. 
+ """ + resolved_token_file = token_file or DEFAULT_TOKEN_FILE + globus_client = globus_sdk.NativeAppAuthClient(client_id) + + auth_data: dict | None = None + + if not force_login: + stored = load_token_file(resolved_token_file) + if stored and stored.get("refresh_token"): + auth_data = refresh_tokens( + globus_client, stored["refresh_token"], resource_server + ) + + if auth_data is None: + logger.info("Initiating interactive Globus login.") + auth_data = interactive_login(globus_client, required_scopes, resource_server) + + granted = set(auth_data.get("scope", "").split()) + missing = required_scopes - granted + if missing: + raise RuntimeError( + f"Globus token is missing required scopes: {sorted(missing)}" + ) + + save_token_file(resolved_token_file, auth_data) + logger.info(f"Globus token saved to {resolved_token_file}.") + + return auth_data["access_token"] + + +def _ensure_private_parent_dir(path: Path) -> None: + """Create parent directories for path with owner-only permissions. + + Args: + path: The file path whose parent directory should be created. + """ + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) From 924c4a0b8b40209e713055e477d7e9b2ad42913f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:25:00 -0700 Subject: [PATCH 02/26] Removed NERSCLoginMethod(Enum) from nersc.py. Created a temporary test flow for reconstruction to test job submission. In reconstruct(), replaced the SFAPI-specific job submission/polling code with the general _submit_job() and _wait_for_job() methods. 
--- orchestration/flows/bl832/nersc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index f5850a61..aab29c13 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -17,7 +17,6 @@ from typing import Any, Optional from orchestration.flows.bl832.config import Config832 - from orchestration.flows.bl832.job_controller import get_controller, HPC, NERSCLoginMethod, TomographyHPCController from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block From 5b0607e70bdfd5187b5fb68381e7f4ae75d4064f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:50:35 -0700 Subject: [PATCH 03/26] Updating pytests --- orchestration/_tests/test_bl832/test_nersc.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index abf616fd..f303bac1 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -23,6 +23,16 @@ def prefect_test_fixture(): # Shared fixtures # --------------------------------------------------------------------------- +@pytest.fixture +def mock_config(mocker): + config = mocker.MagicMock() + config.ghcr_images832 = { + "recon_image": "mock_recon_image", + "multires_image": "mock_multires_image", + } + return config + + @pytest.fixture def mock_sfapi_client(mocker): """sfapi_client.Client mock with user, compute, submit_job, and job chained.""" @@ -201,11 +211,11 @@ def mock_iriapi_client(mocker): return client + # --------------------------------------------------------------------------- # _create_sfapi_client # --------------------------------------------------------------------------- - def test_create_sfapi_client_success(mocker): """Valid credentials produce a Client instance.""" 
from orchestration.flows.bl832.nersc import NERSCTomographyHPCController @@ -522,7 +532,7 @@ def test_reconstruct_iriapi_job_failed(mocker, mock_iriapi_client, mock_config83 monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} # was {"state": "FAILED"} + mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} controller = NERSCTomographyHPCController( client=mock_iriapi_client, From ce358325fbfc93c4cafeb91aadfd79ed10598297 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 30 Mar 2026 14:53:53 -0700 Subject: [PATCH 04/26] successfully ran reconstruction using the IRI-API --- orchestration/flows/bl832/nersc.py | 1 + orchestration/globus/token.py | 390 +++++++++++++++++++++-------- scripts/get_globus_token.py | 337 +++++++++++++++++++++++++ 3 files changed, 624 insertions(+), 104 deletions(-) create mode 100644 scripts/get_globus_token.py diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index aab29c13..2860eeda 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -393,6 +393,7 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: sbatch_values["reservation"] = line.split("--reservation=")[-1].strip() # Strip shebang and SBATCH headers, keep the script body + script_body = "\n".join( line for line in job_script.splitlines() if not line.startswith("#SBATCH") and not line.startswith("#!/") diff --git a/orchestration/globus/token.py b/orchestration/globus/token.py index 81b5438f..4970eaa7 100644 --- a/orchestration/globus/token.py +++ b/orchestration/globus/token.py @@ -1,3 +1,4 @@ +# orchestration/globus/token.py import json import logging import os @@ -12,69 +13,20 @@ # Default token file location, matching the Globus SDK convention. 
DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" -GLOBUS_OIDC_TOKEN_URL: str = "https://auth.globus.org/v2/oauth2/token" +# IRI API Globus scope and resource server. +# The IRI access token lives in other_tokens under this scope, not at the +# top level of the auth.globus.org response. +IRI_SCOPE: str = ( + "https://auth.globus.org/scopes/" + "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" +) +IRI_RESOURCE_SERVER: str = "ed3e577d-f7f3-4639-b96e-ff5a8445d699" -def get_access_token_confidential( - client_id: str, - client_secret: str, - required_scopes: frozenset[str], - resource_server: str, - token_file: Path | None = None, -) -> str: - """Get a valid Globus access token using a Confidential Client (machine-to-machine). - - No browser or user interaction required. If a valid unexpired token exists - on disk it is reused; otherwise a new one is minted via the client - credentials grant and saved. - - Args: - client_id: Globus Confidential App client ID. - client_secret: Globus Confidential App client secret. - required_scopes: Set of OAuth2 scopes that must be present on the token. - resource_server: Resource server key to extract from the token response. - token_file: Path to the JSON token cache file. Defaults to - ``~/.globus/auth_tokens.json``. - - Returns: - A valid Globus access token string. - - Raises: - RuntimeError: If the acquired token is missing required scopes. - KeyError: If ``access_token`` is absent from the token response. - """ - resolved_token_file = token_file or DEFAULT_TOKEN_FILE - - # 1. Do we already have a valid token? - stored = load_token_file(resolved_token_file) - if stored: - expires_at = stored.get("expires_at_seconds") - if expires_at and time.time() < expires_at: - logger.info("Using cached Globus token (still valid).") - return stored["access_token"] - logger.info("Cached Globus token is expired; minting a new one.") - else: - logger.info("No cached Globus token found; minting a new one.") - - # 2. 
Mint a new token — same call whether first time or expired. - globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) - token_response = globus_client.oauth2_client_credentials_tokens( - requested_scopes=" ".join(sorted(required_scopes)) - ) - auth_data = token_response.by_resource_server[resource_server] - - granted = set(auth_data.get("scope", "").split()) - missing = required_scopes - granted - if missing: - raise RuntimeError( - f"Globus token is missing required scopes: {sorted(missing)}" - ) - - save_token_file(resolved_token_file, auth_data) - logger.info(f"New Globus token saved to {resolved_token_file}.") - - return auth_data["access_token"] +# --------------------------------------------------------------------------- +# File I/O +# --------------------------------------------------------------------------- def load_token_file(token_file: Path) -> dict | None: """Load saved Globus token data from disk. @@ -112,105 +64,345 @@ def save_token_file(token_file: Path, tokens: dict) -> None: os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) +def _ensure_private_parent_dir(path: Path) -> None: + """Create parent directories for path with owner-only permissions. + + Args: + path: The file path whose parent directory should be created. + """ + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) + + +# --------------------------------------------------------------------------- +# IRI token helpers +# --------------------------------------------------------------------------- + +def _parse_scope_string(scope_string: str) -> set[str]: + """Split a space-separated scope string into a set. + + Args: + scope_string: Space-separated OAuth2 scope string. + + Returns: + Set of individual scope strings. + """ + return set(scope_string.split()) if scope_string else set() + + +def extract_iri_token(token_response_data: dict) -> dict: + """Extract the IRI access token entry from a Globus token response. 
+ + The IRI token is not returned at the top level — it lives inside + ``other_tokens``, identified by :data:`IRI_SCOPE`. + + Args: + token_response_data: Full token response dict as returned by the + Globus SDK (i.e. ``token_response.data``). + + Returns: + Token dict for the IRI resource server. + + Raises: + RuntimeError: If no token matching the IRI scope is found. + """ + for token_data in token_response_data.get("other_tokens", []): + if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): + return token_data + raise RuntimeError( + f"Missing token for required IRI scope: {IRI_SCOPE}. " + "Re-run with --force-login and ensure consent is granted for the IRI scope." + ) + + +def _replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: + """Return a copy of token_response_data with the IRI entry replaced. + + Args: + token_response_data: Full stored token response dict. + iri_token_data: Updated IRI token dict to splice in. + + Returns: + Updated token response dict. + """ + merged = dict(token_response_data) + other_tokens = list(merged.get("other_tokens", [])) + for i, token_data in enumerate(other_tokens): + if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): + other_tokens[i] = iri_token_data + break + else: + other_tokens.append(iri_token_data) + merged["other_tokens"] = other_tokens + return merged + + +def _get_iri_refresh_token(stored_tokens: dict) -> str | None: + """Extract the IRI refresh token from stored token data, if present. + + Args: + stored_tokens: Full stored token response dict. + + Returns: + The IRI refresh token string, or None if absent. + """ + try: + return extract_iri_token(stored_tokens).get("refresh_token") + except RuntimeError: + return None + + +def _get_auth_refresh_token(stored_tokens: dict) -> str | None: + """Extract the top-level Globus Auth refresh token from stored data. + + Args: + stored_tokens: Full stored token response dict. 
+ + Returns: + The auth refresh token string, or None if absent. + """ + if "refresh_token" in stored_tokens: + return stored_tokens["refresh_token"] + auth_tokens = stored_tokens.get("auth.globus.org") + if isinstance(auth_tokens, dict): + return auth_tokens.get("refresh_token") + return None + + +# --------------------------------------------------------------------------- +# NativeApp flow (interactive) +# --------------------------------------------------------------------------- + def interactive_login( client: globus_sdk.NativeAppAuthClient, - required_scopes: frozenset[str], - resource_server: str, + requested_scopes: frozenset[str], + prompt_login: bool = False, ) -> dict: """Run an interactive browser-based Globus login flow. Prints an authorization URL, waits for the user to paste an auth code, - and exchanges it for tokens. + and returns the full token response data including ``other_tokens``. Args: client: Globus NativeAppAuthClient to drive the flow. - required_scopes: Set of OAuth2 scopes to request. - resource_server: Resource server key to extract from the token response - (e.g. ``"auth.globus.org"``). + requested_scopes: Set of OAuth2 scopes to request. Should include + :data:`IRI_SCOPE` to obtain an IRI API token. + prompt_login: If True, add ``prompt=login`` to the authorize URL to + force a fresh identity-provider login. Returns: - Token dict for the given resource server. + Full token response dict (``token_response.data``), including + ``other_tokens``. + + Raises: + RuntimeError: If no authorization code is entered, or if the code + exchange fails. 
""" client.oauth2_start_flow( - requested_scopes=" ".join(sorted(required_scopes)), + requested_scopes=" ".join(sorted(requested_scopes)), refresh_tokens=True, ) logger.info("Open this URL in your browser to authenticate with Globus:") - logger.info(client.oauth2_get_authorize_url()) + prompt = "login" if prompt_login else globus_sdk.MISSING + logger.info(client.oauth2_get_authorize_url(prompt=prompt)) code = input("\nEnter authorization code: ").strip() - token_response = client.oauth2_exchange_code_for_tokens(code) - return token_response.by_resource_server[resource_server] + if not code: + raise RuntimeError( + "No authorization code entered. Re-run the script and paste the " + "code shown by Globus after login." + ) + try: + token_response = client.oauth2_exchange_code_for_tokens(code) + except GlobusAPIError as e: + if e.http_status == 400: + raise RuntimeError( + "Authorization code exchange failed — the code was empty, " + "invalid, expired, or already used. Re-run and try again." + ) from e + raise RuntimeError( + f"Authorization code exchange failed with HTTP {e.http_status}." + ) from e + return token_response.data -def refresh_tokens( +def _refresh_single_token( client: globus_sdk.NativeAppAuthClient, refresh_token: str, - resource_server: str, ) -> dict | None: - """Attempt a silent Globus token refresh. + """Attempt a single Globus token refresh, returning raw response data. Args: - client: Globus NativeAppAuthClient to drive the refresh. + client: NativeAppAuthClient to drive the refresh. refresh_token: The stored refresh token. - resource_server: Resource server key to extract from the token response. Returns: - Fresh token dict for the given resource server, or None if refresh failed. + Raw token response data dict, or None if the refresh failed. 
""" try: token_response = client.oauth2_refresh_token(refresh_token) - return token_response.by_resource_server[resource_server] + return token_response.data except GlobusAPIError as e: logger.warning( f"Globus token refresh failed ({e.http_status}); " - "falling back to interactive login." + "will fall back to interactive login." ) return None +def _refresh_stored_tokens( + client: globus_sdk.NativeAppAuthClient, + stored_tokens: dict, +) -> tuple[dict | None, bool]: + """Try to refresh stored tokens, preferring the IRI refresh token. + + Attempts the IRI-specific refresh token first, then falls back to the + top-level Globus Auth refresh token. + + Args: + client: NativeAppAuthClient to drive the refresh. + stored_tokens: Full stored token response dict. + + Returns: + Tuple of ``(updated_token_data, success)``. On failure both values + are ``(None, False)``. + """ + iri_refresh = _get_iri_refresh_token(stored_tokens) + if iri_refresh: + iri_token_data = _refresh_single_token(client, iri_refresh) + if iri_token_data is not None: + return _replace_iri_token(stored_tokens, iri_token_data), True + + auth_refresh = _get_auth_refresh_token(stored_tokens) + if auth_refresh: + auth_data = _refresh_single_token(client, auth_refresh) + if auth_data is not None: + return auth_data, True + + return None, False + + def get_access_token( client_id: str, - required_scopes: frozenset[str], - resource_server: str, + requested_scopes: frozenset[str], token_file: Path | None = None, force_login: bool = False, + prompt_login: bool = False, ) -> str: - """Get a valid Globus access token, refreshing or logging in as needed. + """Get a valid IRI API access token via the NativeApp interactive flow. Attempts a silent refresh from the saved token file first. Falls back to interactive browser login if no saved tokens exist, the refresh token is absent, or the refresh fails. Saves the resulting tokens back to disk. 
+ The IRI token is extracted from ``other_tokens`` in the response — it is + not the top-level Globus Auth token. + Args: client_id: Globus NativeApp client ID. - required_scopes: Set of OAuth2 scopes that must be present on the token. - resource_server: Resource server key to extract from the token response. + requested_scopes: Set of OAuth2 scopes to request. Must include + :data:`IRI_SCOPE` to obtain a usable IRI API token. token_file: Path to the JSON token file. Defaults to ``~/.globus/auth_tokens.json``. force_login: If True, skip refresh and force interactive login. + prompt_login: If True, add ``prompt=login`` to the authorize URL. Returns: - A valid Globus access token string. + A valid IRI API access token string. Raises: - RuntimeError: If the acquired token is missing required scopes. - KeyError: If ``access_token`` is absent from the token response. + RuntimeError: If the IRI scope token is missing from the response. """ resolved_token_file = token_file or DEFAULT_TOKEN_FILE globus_client = globus_sdk.NativeAppAuthClient(client_id) - auth_data: dict | None = None + token_response_data: dict | None = None + used_refresh = False if not force_login: stored = load_token_file(resolved_token_file) - if stored and stored.get("refresh_token"): - auth_data = refresh_tokens( - globus_client, stored["refresh_token"], resource_server + if stored: + token_response_data, used_refresh = _refresh_stored_tokens( + globus_client, stored ) - if auth_data is None: + if token_response_data is None: logger.info("Initiating interactive Globus login.") - auth_data = interactive_login(globus_client, required_scopes, resource_server) + token_response_data = interactive_login( + globus_client, requested_scopes, prompt_login=prompt_login + ) + + # Extract IRI token — if a refresh ran but didn't return the IRI token, + # fall back to interactive login before raising. 
+ try: + iri_token = extract_iri_token(token_response_data) + except RuntimeError: + if used_refresh: + logger.warning( + "Refreshed tokens did not include the IRI token; " + "falling back to interactive login." + ) + token_response_data = interactive_login( + globus_client, requested_scopes, prompt_login=prompt_login + ) + iri_token = extract_iri_token(token_response_data) + else: + raise + + save_token_file(resolved_token_file, token_response_data) + logger.info(f"Globus token saved to {resolved_token_file}.") + + return iri_token["access_token"] + + +# --------------------------------------------------------------------------- +# Confidential Client flow (machine-to-machine) +# --------------------------------------------------------------------------- + +def get_access_token_confidential( + client_id: str, + client_secret: str, + required_scopes: frozenset[str], + resource_server: str, + token_file: Path | None = None, +) -> str: + """Get a valid Globus access token using a Confidential Client. + + No browser or user interaction required. If a valid unexpired token exists + on disk it is reused; otherwise a new one is minted via the client + credentials grant and saved. + + Args: + client_id: Globus Confidential App client ID. + client_secret: Globus Confidential App client secret. + required_scopes: Set of OAuth2 scopes that must be present on the token. + resource_server: Resource server key to extract from the token response. + token_file: Path to the JSON token cache file. Defaults to + ``~/.globus/auth_tokens.json``. + + Returns: + A valid Globus access token string. + + Raises: + RuntimeError: If the acquired token is missing required scopes. + KeyError: If ``access_token`` is absent from the token response. 
+ """ + resolved_token_file = token_file or DEFAULT_TOKEN_FILE + + stored = load_token_file(resolved_token_file) + if stored: + expires_at = stored.get("expires_at_seconds") + if expires_at and time.time() < expires_at: + logger.info("Using cached Globus token (still valid).") + return stored["access_token"] + logger.info("Cached Globus token is expired; minting a new one.") + else: + logger.info("No cached Globus token found; minting a new one.") + + globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) + token_response = globus_client.oauth2_client_credentials_tokens( + requested_scopes=" ".join(sorted(required_scopes)) + ) + auth_data = token_response.by_resource_server[resource_server] granted = set(auth_data.get("scope", "").split()) missing = required_scopes - granted @@ -220,16 +412,6 @@ def get_access_token( ) save_token_file(resolved_token_file, auth_data) - logger.info(f"Globus token saved to {resolved_token_file}.") + logger.info(f"New Globus token saved to {resolved_token_file}.") return auth_data["access_token"] - - -def _ensure_private_parent_dir(path: Path) -> None: - """Create parent directories for path with owner-only permissions. - - Args: - path: The file path whose parent directory should be created. 
- """ - path.parent.mkdir(parents=True, exist_ok=True) - os.chmod(path.parent, 0o700) diff --git a/scripts/get_globus_token.py b/scripts/get_globus_token.py new file mode 100644 index 00000000..6b615378 --- /dev/null +++ b/scripts/get_globus_token.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import stat +import time +import urllib.error +import urllib.request +from pathlib import Path + +import globus_sdk +from globus_sdk.exc import GlobusAPIError + +CLIENT_ID = "fae5c579-490a-4d76-b6eb-d78f65caeb63" +RESOURCE_SERVER = "auth.globus.org" +IRI_SCOPE = ( + "https://auth.globus.org/scopes/" + "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" +) +REQUIRED_SCOPES = { + "openid", + "profile", + "email", + "urn:globus:auth:scope:auth.globus.org:view_identities", +} +REQUESTED_SCOPES = REQUIRED_SCOPES | {IRI_SCOPE} +DEFAULT_IRI_VALIDATE_URL = "https://api.iri.nersc.gov/api/v1/account/projects" + + +def parse_args() -> argparse.Namespace: + default_token_file = Path.home() / ".globus" / "auth_tokens.json" + parser = argparse.ArgumentParser( + description=( + "Get Globus Auth tokens with required scopes. " + "Tokens are saved to a secure local file by default." 
+ ) + ) + parser.add_argument( + "--token-file", + type=Path, + default=default_token_file, + help=f"Path for saved token JSON (default: {default_token_file})", + ) + parser.add_argument( + "--print-token", + action="store_true", + help="Print the access token to stdout (off by default).", + ) + parser.add_argument( + "--force-login", + action="store_true", + help="Skip refresh and force interactive browser login.", + ) + parser.add_argument( + "--refresh-only", + action="store_true", + help="Refresh saved tokens only; do not fall back to interactive login.", + ) + parser.add_argument( + "--prompt-login", + action="store_true", + help="Add prompt=login to the Globus authorize URL to force re-authentication.", + ) + parser.add_argument( + "--validate-iri", + action="store_true", + help="Validate the IRI token by calling the IRI account/projects endpoint.", + ) + parser.add_argument( + "--iri-validate-url", + default=DEFAULT_IRI_VALIDATE_URL, + help=( + "IRI endpoint used by --validate-iri " + f"(default: {DEFAULT_IRI_VALIDATE_URL})" + ), + ) + return parser.parse_args() + + +def parse_scope_string(scope_string: str) -> set[str]: + return set(scope_string.split()) if scope_string else set() + + +def ensure_private_parent_dir(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) + + +def load_tokens(token_file: Path) -> dict | None: + if not token_file.exists(): + return None + with token_file.open("r", encoding="utf-8") as f: + return json.load(f) + + +def save_tokens(token_file: Path, tokens: dict) -> None: + ensure_private_parent_dir(token_file) + tmp = token_file.with_suffix(".tmp") + with os.fdopen( + os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), + "w", + encoding="utf-8", + ) as f: + json.dump(tokens, f, indent=2) + os.replace(tmp, token_file) + os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + + +def get_refresh_token(stored_tokens: dict) -> str | None: + if "refresh_token" in stored_tokens: + 
return stored_tokens.get("refresh_token") + + auth_tokens = stored_tokens.get(RESOURCE_SERVER) + if isinstance(auth_tokens, dict): + return auth_tokens.get("refresh_token") + + return None + + +def get_iri_token(token_response_data: dict) -> dict: + for token_data in token_response_data.get("other_tokens", []): + if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): + return token_data + raise RuntimeError(f"Missing token for required IRI scope: {IRI_SCOPE}") + + +def get_iri_refresh_token(stored_tokens: dict) -> str | None: + try: + return get_iri_token(stored_tokens).get("refresh_token") + except RuntimeError: + return None + + +def replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: + merged = dict(token_response_data) + other_tokens = list(merged.get("other_tokens", [])) + for index, token_data in enumerate(other_tokens): + if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): + other_tokens[index] = iri_token_data + break + else: + other_tokens.append(iri_token_data) + merged["other_tokens"] = other_tokens + return merged + + +def validate_auth_data(auth_data: dict) -> dict: + if auth_data.get("resource_server") != RESOURCE_SERVER: + raise RuntimeError( + f"Missing token for required resource server: {RESOURCE_SERVER}" + ) + + granted = parse_scope_string(auth_data.get("scope", "")) + missing = REQUIRED_SCOPES - granted + if missing: + raise RuntimeError(f"Missing required scopes: {sorted(missing)}") + + return get_iri_token(auth_data) + + +def validate_iri_token(iri_token_data: dict, validate_url: str) -> dict | list: + request = urllib.request.Request( + validate_url, + headers={ + "accept": "application/json", + "Authorization": f"Bearer {iri_token_data['access_token']}", + }, + method="GET", + ) + try: + with urllib.request.urlopen(request) as response: + body = response.read().decode("utf-8") + data = json.loads(body) if body else {} + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8") + 
details = body.strip() or exc.reason + raise RuntimeError( + f"IRI validation failed with HTTP {exc.code} from {validate_url}: {details}" + ) from exc + except urllib.error.URLError as exc: + raise RuntimeError( + f"IRI validation request failed for {validate_url}: {exc.reason}" + ) from exc + except json.JSONDecodeError as exc: + raise RuntimeError( + f"IRI validation returned non-JSON data from {validate_url}" + ) from exc + + if isinstance(data, dict): + session_info = data.get("session_info") + if isinstance(session_info, dict): + authentications = session_info.get("authentications") + if isinstance(authentications, dict) and not authentications: + raise RuntimeError( + "IRI validation succeeded but session_info.authentications is empty. " + "Re-run with --force-login --prompt-login and use a Chrome incognito window." + ) + + return data + + +def interactive_login( + client: globus_sdk.NativeAppAuthClient, + *, + prompt_login: bool = False, +) -> dict: + client.oauth2_start_flow( + requested_scopes=" ".join(sorted(REQUESTED_SCOPES)), + refresh_tokens=True, + ) + print("Open this URL, login, and consent:") + prompt = "login" if prompt_login else globus_sdk.MISSING + print(client.oauth2_get_authorize_url(prompt=prompt)) + code = input("\nEnter authorization code: ").strip() + if not code: + raise RuntimeError( + "No authorization code entered. Re-run the script and paste the code " + "shown by Globus after login." + ) + try: + token_response = client.oauth2_exchange_code_for_tokens(code) + except GlobusAPIError as exc: + if exc.http_status == 400: + raise RuntimeError( + "Authorization code exchange failed. The code was empty, invalid, " + "expired, or already used. Re-run the script and complete the " + "Globus login flow again." + ) from exc + raise RuntimeError( + f"Authorization code exchange failed with HTTP {exc.http_status}. " + "Re-run the script and try again." 
+ ) from exc + return token_response.data + + +def refresh_tokens( + client: globus_sdk.NativeAppAuthClient, refresh_token: str +) -> dict | None: + try: + token_response = client.oauth2_refresh_token(refresh_token) + return token_response.data + except GlobusAPIError as exc: + print( + f"Refresh failed ({exc.http_status}); switching to interactive login." + ) + return None + + +def refresh_stored_tokens( + client: globus_sdk.NativeAppAuthClient, stored_tokens: dict +) -> tuple[dict | None, bool]: + iri_refresh_token = get_iri_refresh_token(stored_tokens) + if iri_refresh_token: + iri_token_data = refresh_tokens(client, iri_refresh_token) + if iri_token_data is not None: + return replace_iri_token(stored_tokens, iri_token_data), True + + auth_refresh_token = get_refresh_token(stored_tokens) + if auth_refresh_token: + auth_data = refresh_tokens(client, auth_refresh_token) + if auth_data is not None: + return auth_data, True + + return None, False + + +def main() -> None: + args = parse_args() + if args.force_login and args.refresh_only: + raise RuntimeError("Choose only one of --force-login or --refresh-only") + + client = globus_sdk.NativeAppAuthClient(CLIENT_ID) + + auth_data = None + used_refresh = False + if not args.force_login: + stored = load_tokens(args.token_file) + if stored: + auth_data, used_refresh = refresh_stored_tokens(client, stored) + + if auth_data is None: + if args.refresh_only: + raise RuntimeError( + "Refresh-only mode failed. No usable saved refresh token was found " + "or token refresh did not return the required IRI token." + ) + auth_data = interactive_login(client, prompt_login=args.prompt_login) + + try: + iri_token_data = validate_auth_data(auth_data) + except RuntimeError as exc: + if used_refresh and "Missing token for required IRI scope" in str(exc): + print( + "Refreshed tokens did not include the IRI token; " + "switching to interactive login." 
+ ) + auth_data = interactive_login(client, prompt_login=args.prompt_login) + iri_token_data = validate_auth_data(auth_data) + else: + raise + + save_tokens(args.token_file, auth_data) + + if args.validate_iri: + validation_data = validate_iri_token(iri_token_data, args.iri_validate_url) + print(f"IRI validation succeeded against {args.iri_validate_url}") + if isinstance(validation_data, dict): + session_info = validation_data.get("session_info") + if isinstance(session_info, dict): + session_id = session_info.get("session_id") + if session_id: + print(f"IRI session_id: {session_id}") + elif isinstance(validation_data, list): + print(f"IRI validation response items: {len(validation_data)}") + + expires_at = iri_token_data.get("expires_at_seconds") + if expires_at: + ttl = int(expires_at - time.time()) + print(f"\nIRI access token valid for ~{max(ttl, 0)} seconds.") + + print(f"Saved token data to {args.token_file}") + print(f"Granted Globus Auth scopes: {auth_data.get('scope', '')}") + print(f"IRI token resource server: {iri_token_data.get('resource_server')}") + print(f"IRI token scopes: {iri_token_data.get('scope', '')}") + + if args.print_token: + print("\nIRI access token:") + print(iri_token_data["access_token"]) + else: + print( + "IRI access token not printed " + "(use --print-token to display it for the NERSC IRI API)." 
+ ) + + +if __name__ == "__main__": + main() From 1b98b7c1445c6eec76fabfda760df4f8f690eaa6 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:21:14 -0700 Subject: [PATCH 05/26] removing token.py and moving the logic to get_globus_token.py --- orchestration/globus/token.py | 417 ---------------------------------- 1 file changed, 417 deletions(-) delete mode 100644 orchestration/globus/token.py diff --git a/orchestration/globus/token.py b/orchestration/globus/token.py deleted file mode 100644 index 4970eaa7..00000000 --- a/orchestration/globus/token.py +++ /dev/null @@ -1,417 +0,0 @@ -# orchestration/globus/token.py -import json -import logging -import os -from pathlib import Path -import stat -import time - -import globus_sdk -from globus_sdk.exc import GlobusAPIError - -logger = logging.getLogger(__name__) - -# Default token file location, matching the Globus SDK convention. -DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" - -# IRI API Globus scope and resource server. -# The IRI access token lives in other_tokens under this scope, not at the -# top level of the auth.globus.org response. -IRI_SCOPE: str = ( - "https://auth.globus.org/scopes/" - "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" -) -IRI_RESOURCE_SERVER: str = "ed3e577d-f7f3-4639-b96e-ff5a8445d699" - - -# --------------------------------------------------------------------------- -# File I/O -# --------------------------------------------------------------------------- - -def load_token_file(token_file: Path) -> dict | None: - """Load saved Globus token data from disk. - - Args: - token_file: Path to the JSON token file. - - Returns: - Parsed token dict, or None if the file does not exist. - """ - if not token_file.exists(): - return None - with token_file.open("r", encoding="utf-8") as f: - return json.load(f) - - -def save_token_file(token_file: Path, tokens: dict) -> None: - """Atomically save Globus token data to disk with owner-only permissions. 
- - Writes to a temporary file then renames to avoid partial writes. - - Args: - token_file: Destination path for the JSON token file. - tokens: Token dict to serialise. - """ - _ensure_private_parent_dir(token_file) - tmp = token_file.with_suffix(".tmp") - with os.fdopen( - os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), - "w", - encoding="utf-8", - ) as f: - json.dump(tokens, f, indent=2) - os.replace(tmp, token_file) - os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) - - -def _ensure_private_parent_dir(path: Path) -> None: - """Create parent directories for path with owner-only permissions. - - Args: - path: The file path whose parent directory should be created. - """ - path.parent.mkdir(parents=True, exist_ok=True) - os.chmod(path.parent, 0o700) - - -# --------------------------------------------------------------------------- -# IRI token helpers -# --------------------------------------------------------------------------- - -def _parse_scope_string(scope_string: str) -> set[str]: - """Split a space-separated scope string into a set. - - Args: - scope_string: Space-separated OAuth2 scope string. - - Returns: - Set of individual scope strings. - """ - return set(scope_string.split()) if scope_string else set() - - -def extract_iri_token(token_response_data: dict) -> dict: - """Extract the IRI access token entry from a Globus token response. - - The IRI token is not returned at the top level — it lives inside - ``other_tokens``, identified by :data:`IRI_SCOPE`. - - Args: - token_response_data: Full token response dict as returned by the - Globus SDK (i.e. ``token_response.data``). - - Returns: - Token dict for the IRI resource server. - - Raises: - RuntimeError: If no token matching the IRI scope is found. - """ - for token_data in token_response_data.get("other_tokens", []): - if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): - return token_data - raise RuntimeError( - f"Missing token for required IRI scope: {IRI_SCOPE}. 
" - "Re-run with --force-login and ensure consent is granted for the IRI scope." - ) - - -def _replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: - """Return a copy of token_response_data with the IRI entry replaced. - - Args: - token_response_data: Full stored token response dict. - iri_token_data: Updated IRI token dict to splice in. - - Returns: - Updated token response dict. - """ - merged = dict(token_response_data) - other_tokens = list(merged.get("other_tokens", [])) - for i, token_data in enumerate(other_tokens): - if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): - other_tokens[i] = iri_token_data - break - else: - other_tokens.append(iri_token_data) - merged["other_tokens"] = other_tokens - return merged - - -def _get_iri_refresh_token(stored_tokens: dict) -> str | None: - """Extract the IRI refresh token from stored token data, if present. - - Args: - stored_tokens: Full stored token response dict. - - Returns: - The IRI refresh token string, or None if absent. - """ - try: - return extract_iri_token(stored_tokens).get("refresh_token") - except RuntimeError: - return None - - -def _get_auth_refresh_token(stored_tokens: dict) -> str | None: - """Extract the top-level Globus Auth refresh token from stored data. - - Args: - stored_tokens: Full stored token response dict. - - Returns: - The auth refresh token string, or None if absent. 
- """ - if "refresh_token" in stored_tokens: - return stored_tokens["refresh_token"] - auth_tokens = stored_tokens.get("auth.globus.org") - if isinstance(auth_tokens, dict): - return auth_tokens.get("refresh_token") - return None - - -# --------------------------------------------------------------------------- -# NativeApp flow (interactive) -# --------------------------------------------------------------------------- - -def interactive_login( - client: globus_sdk.NativeAppAuthClient, - requested_scopes: frozenset[str], - prompt_login: bool = False, -) -> dict: - """Run an interactive browser-based Globus login flow. - - Prints an authorization URL, waits for the user to paste an auth code, - and returns the full token response data including ``other_tokens``. - - Args: - client: Globus NativeAppAuthClient to drive the flow. - requested_scopes: Set of OAuth2 scopes to request. Should include - :data:`IRI_SCOPE` to obtain an IRI API token. - prompt_login: If True, add ``prompt=login`` to the authorize URL to - force a fresh identity-provider login. - - Returns: - Full token response dict (``token_response.data``), including - ``other_tokens``. - - Raises: - RuntimeError: If no authorization code is entered, or if the code - exchange fails. - """ - client.oauth2_start_flow( - requested_scopes=" ".join(sorted(requested_scopes)), - refresh_tokens=True, - ) - logger.info("Open this URL in your browser to authenticate with Globus:") - prompt = "login" if prompt_login else globus_sdk.MISSING - logger.info(client.oauth2_get_authorize_url(prompt=prompt)) - code = input("\nEnter authorization code: ").strip() - if not code: - raise RuntimeError( - "No authorization code entered. Re-run the script and paste the " - "code shown by Globus after login." 
- ) - try: - token_response = client.oauth2_exchange_code_for_tokens(code) - except GlobusAPIError as e: - if e.http_status == 400: - raise RuntimeError( - "Authorization code exchange failed — the code was empty, " - "invalid, expired, or already used. Re-run and try again." - ) from e - raise RuntimeError( - f"Authorization code exchange failed with HTTP {e.http_status}." - ) from e - return token_response.data - - -def _refresh_single_token( - client: globus_sdk.NativeAppAuthClient, - refresh_token: str, -) -> dict | None: - """Attempt a single Globus token refresh, returning raw response data. - - Args: - client: NativeAppAuthClient to drive the refresh. - refresh_token: The stored refresh token. - - Returns: - Raw token response data dict, or None if the refresh failed. - """ - try: - token_response = client.oauth2_refresh_token(refresh_token) - return token_response.data - except GlobusAPIError as e: - logger.warning( - f"Globus token refresh failed ({e.http_status}); " - "will fall back to interactive login." - ) - return None - - -def _refresh_stored_tokens( - client: globus_sdk.NativeAppAuthClient, - stored_tokens: dict, -) -> tuple[dict | None, bool]: - """Try to refresh stored tokens, preferring the IRI refresh token. - - Attempts the IRI-specific refresh token first, then falls back to the - top-level Globus Auth refresh token. - - Args: - client: NativeAppAuthClient to drive the refresh. - stored_tokens: Full stored token response dict. - - Returns: - Tuple of ``(updated_token_data, success)``. On failure both values - are ``(None, False)``. 
- """ - iri_refresh = _get_iri_refresh_token(stored_tokens) - if iri_refresh: - iri_token_data = _refresh_single_token(client, iri_refresh) - if iri_token_data is not None: - return _replace_iri_token(stored_tokens, iri_token_data), True - - auth_refresh = _get_auth_refresh_token(stored_tokens) - if auth_refresh: - auth_data = _refresh_single_token(client, auth_refresh) - if auth_data is not None: - return auth_data, True - - return None, False - - -def get_access_token( - client_id: str, - requested_scopes: frozenset[str], - token_file: Path | None = None, - force_login: bool = False, - prompt_login: bool = False, -) -> str: - """Get a valid IRI API access token via the NativeApp interactive flow. - - Attempts a silent refresh from the saved token file first. Falls back to - interactive browser login if no saved tokens exist, the refresh token is - absent, or the refresh fails. Saves the resulting tokens back to disk. - - The IRI token is extracted from ``other_tokens`` in the response — it is - not the top-level Globus Auth token. - - Args: - client_id: Globus NativeApp client ID. - requested_scopes: Set of OAuth2 scopes to request. Must include - :data:`IRI_SCOPE` to obtain a usable IRI API token. - token_file: Path to the JSON token file. Defaults to - ``~/.globus/auth_tokens.json``. - force_login: If True, skip refresh and force interactive login. - prompt_login: If True, add ``prompt=login`` to the authorize URL. - - Returns: - A valid IRI API access token string. - - Raises: - RuntimeError: If the IRI scope token is missing from the response. 
- """ - resolved_token_file = token_file or DEFAULT_TOKEN_FILE - globus_client = globus_sdk.NativeAppAuthClient(client_id) - - token_response_data: dict | None = None - used_refresh = False - - if not force_login: - stored = load_token_file(resolved_token_file) - if stored: - token_response_data, used_refresh = _refresh_stored_tokens( - globus_client, stored - ) - - if token_response_data is None: - logger.info("Initiating interactive Globus login.") - token_response_data = interactive_login( - globus_client, requested_scopes, prompt_login=prompt_login - ) - - # Extract IRI token — if a refresh ran but didn't return the IRI token, - # fall back to interactive login before raising. - try: - iri_token = extract_iri_token(token_response_data) - except RuntimeError: - if used_refresh: - logger.warning( - "Refreshed tokens did not include the IRI token; " - "falling back to interactive login." - ) - token_response_data = interactive_login( - globus_client, requested_scopes, prompt_login=prompt_login - ) - iri_token = extract_iri_token(token_response_data) - else: - raise - - save_token_file(resolved_token_file, token_response_data) - logger.info(f"Globus token saved to {resolved_token_file}.") - - return iri_token["access_token"] - - -# --------------------------------------------------------------------------- -# Confidential Client flow (machine-to-machine) -# --------------------------------------------------------------------------- - -def get_access_token_confidential( - client_id: str, - client_secret: str, - required_scopes: frozenset[str], - resource_server: str, - token_file: Path | None = None, -) -> str: - """Get a valid Globus access token using a Confidential Client. - - No browser or user interaction required. If a valid unexpired token exists - on disk it is reused; otherwise a new one is minted via the client - credentials grant and saved. - - Args: - client_id: Globus Confidential App client ID. - client_secret: Globus Confidential App client secret. 
- required_scopes: Set of OAuth2 scopes that must be present on the token. - resource_server: Resource server key to extract from the token response. - token_file: Path to the JSON token cache file. Defaults to - ``~/.globus/auth_tokens.json``. - - Returns: - A valid Globus access token string. - - Raises: - RuntimeError: If the acquired token is missing required scopes. - KeyError: If ``access_token`` is absent from the token response. - """ - resolved_token_file = token_file or DEFAULT_TOKEN_FILE - - stored = load_token_file(resolved_token_file) - if stored: - expires_at = stored.get("expires_at_seconds") - if expires_at and time.time() < expires_at: - logger.info("Using cached Globus token (still valid).") - return stored["access_token"] - logger.info("Cached Globus token is expired; minting a new one.") - else: - logger.info("No cached Globus token found; minting a new one.") - - globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) - token_response = globus_client.oauth2_client_credentials_tokens( - requested_scopes=" ".join(sorted(required_scopes)) - ) - auth_data = token_response.by_resource_server[resource_server] - - granted = set(auth_data.get("scope", "").split()) - missing = required_scopes - granted - if missing: - raise RuntimeError( - f"Globus token is missing required scopes: {sorted(missing)}" - ) - - save_token_file(resolved_token_file, auth_data) - logger.info(f"New Globus token saved to {resolved_token_file}.") - - return auth_data["access_token"] From e7c0eece1df10b9f8660e29991fc802c1dcebc7d Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:21:55 -0700 Subject: [PATCH 06/26] moving get_globus_token.py to orchestration/globus/ to be used as a module --- orchestration/globus/get_globus_token.py | 13 - scripts/get_globus_token.py | 337 ----------------------- 2 files changed, 350 deletions(-) delete mode 100644 scripts/get_globus_token.py diff --git a/orchestration/globus/get_globus_token.py 
b/orchestration/globus/get_globus_token.py index d987947e..4cca122f 100644 --- a/orchestration/globus/get_globus_token.py +++ b/orchestration/globus/get_globus_token.py @@ -94,19 +94,6 @@ def load_tokens(token_file: Path) -> dict | None: return json.load(f) -# def save_tokens(token_file: Path, tokens: dict) -> None: -# ensure_private_parent_dir(token_file) -# tmp = token_file.with_suffix(".tmp") -# with os.fdopen( -# os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), -# "w", -# encoding="utf-8", -# ) as f: -# json.dump(tokens, f, indent=2) -# os.replace(tmp, token_file) -# os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) - - def save_tokens(token_file: Path, tokens: dict) -> None: ensure_private_parent_dir(token_file) # Per-process unique tmp name to avoid races between concurrent writers diff --git a/scripts/get_globus_token.py b/scripts/get_globus_token.py deleted file mode 100644 index 6b615378..00000000 --- a/scripts/get_globus_token.py +++ /dev/null @@ -1,337 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import json -import os -import stat -import time -import urllib.error -import urllib.request -from pathlib import Path - -import globus_sdk -from globus_sdk.exc import GlobusAPIError - -CLIENT_ID = "fae5c579-490a-4d76-b6eb-d78f65caeb63" -RESOURCE_SERVER = "auth.globus.org" -IRI_SCOPE = ( - "https://auth.globus.org/scopes/" - "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" -) -REQUIRED_SCOPES = { - "openid", - "profile", - "email", - "urn:globus:auth:scope:auth.globus.org:view_identities", -} -REQUESTED_SCOPES = REQUIRED_SCOPES | {IRI_SCOPE} -DEFAULT_IRI_VALIDATE_URL = "https://api.iri.nersc.gov/api/v1/account/projects" - - -def parse_args() -> argparse.Namespace: - default_token_file = Path.home() / ".globus" / "auth_tokens.json" - parser = argparse.ArgumentParser( - description=( - "Get Globus Auth tokens with required scopes. " - "Tokens are saved to a secure local file by default." 
- ) - ) - parser.add_argument( - "--token-file", - type=Path, - default=default_token_file, - help=f"Path for saved token JSON (default: {default_token_file})", - ) - parser.add_argument( - "--print-token", - action="store_true", - help="Print the access token to stdout (off by default).", - ) - parser.add_argument( - "--force-login", - action="store_true", - help="Skip refresh and force interactive browser login.", - ) - parser.add_argument( - "--refresh-only", - action="store_true", - help="Refresh saved tokens only; do not fall back to interactive login.", - ) - parser.add_argument( - "--prompt-login", - action="store_true", - help="Add prompt=login to the Globus authorize URL to force re-authentication.", - ) - parser.add_argument( - "--validate-iri", - action="store_true", - help="Validate the IRI token by calling the IRI account/projects endpoint.", - ) - parser.add_argument( - "--iri-validate-url", - default=DEFAULT_IRI_VALIDATE_URL, - help=( - "IRI endpoint used by --validate-iri " - f"(default: {DEFAULT_IRI_VALIDATE_URL})" - ), - ) - return parser.parse_args() - - -def parse_scope_string(scope_string: str) -> set[str]: - return set(scope_string.split()) if scope_string else set() - - -def ensure_private_parent_dir(path: Path) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - os.chmod(path.parent, 0o700) - - -def load_tokens(token_file: Path) -> dict | None: - if not token_file.exists(): - return None - with token_file.open("r", encoding="utf-8") as f: - return json.load(f) - - -def save_tokens(token_file: Path, tokens: dict) -> None: - ensure_private_parent_dir(token_file) - tmp = token_file.with_suffix(".tmp") - with os.fdopen( - os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), - "w", - encoding="utf-8", - ) as f: - json.dump(tokens, f, indent=2) - os.replace(tmp, token_file) - os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) - - -def get_refresh_token(stored_tokens: dict) -> str | None: - if "refresh_token" in stored_tokens: - 
return stored_tokens.get("refresh_token") - - auth_tokens = stored_tokens.get(RESOURCE_SERVER) - if isinstance(auth_tokens, dict): - return auth_tokens.get("refresh_token") - - return None - - -def get_iri_token(token_response_data: dict) -> dict: - for token_data in token_response_data.get("other_tokens", []): - if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): - return token_data - raise RuntimeError(f"Missing token for required IRI scope: {IRI_SCOPE}") - - -def get_iri_refresh_token(stored_tokens: dict) -> str | None: - try: - return get_iri_token(stored_tokens).get("refresh_token") - except RuntimeError: - return None - - -def replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: - merged = dict(token_response_data) - other_tokens = list(merged.get("other_tokens", [])) - for index, token_data in enumerate(other_tokens): - if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): - other_tokens[index] = iri_token_data - break - else: - other_tokens.append(iri_token_data) - merged["other_tokens"] = other_tokens - return merged - - -def validate_auth_data(auth_data: dict) -> dict: - if auth_data.get("resource_server") != RESOURCE_SERVER: - raise RuntimeError( - f"Missing token for required resource server: {RESOURCE_SERVER}" - ) - - granted = parse_scope_string(auth_data.get("scope", "")) - missing = REQUIRED_SCOPES - granted - if missing: - raise RuntimeError(f"Missing required scopes: {sorted(missing)}") - - return get_iri_token(auth_data) - - -def validate_iri_token(iri_token_data: dict, validate_url: str) -> dict | list: - request = urllib.request.Request( - validate_url, - headers={ - "accept": "application/json", - "Authorization": f"Bearer {iri_token_data['access_token']}", - }, - method="GET", - ) - try: - with urllib.request.urlopen(request) as response: - body = response.read().decode("utf-8") - data = json.loads(body) if body else {} - except urllib.error.HTTPError as exc: - body = exc.read().decode("utf-8") - 
details = body.strip() or exc.reason - raise RuntimeError( - f"IRI validation failed with HTTP {exc.code} from {validate_url}: {details}" - ) from exc - except urllib.error.URLError as exc: - raise RuntimeError( - f"IRI validation request failed for {validate_url}: {exc.reason}" - ) from exc - except json.JSONDecodeError as exc: - raise RuntimeError( - f"IRI validation returned non-JSON data from {validate_url}" - ) from exc - - if isinstance(data, dict): - session_info = data.get("session_info") - if isinstance(session_info, dict): - authentications = session_info.get("authentications") - if isinstance(authentications, dict) and not authentications: - raise RuntimeError( - "IRI validation succeeded but session_info.authentications is empty. " - "Re-run with --force-login --prompt-login and use a Chrome incognito window." - ) - - return data - - -def interactive_login( - client: globus_sdk.NativeAppAuthClient, - *, - prompt_login: bool = False, -) -> dict: - client.oauth2_start_flow( - requested_scopes=" ".join(sorted(REQUESTED_SCOPES)), - refresh_tokens=True, - ) - print("Open this URL, login, and consent:") - prompt = "login" if prompt_login else globus_sdk.MISSING - print(client.oauth2_get_authorize_url(prompt=prompt)) - code = input("\nEnter authorization code: ").strip() - if not code: - raise RuntimeError( - "No authorization code entered. Re-run the script and paste the code " - "shown by Globus after login." - ) - try: - token_response = client.oauth2_exchange_code_for_tokens(code) - except GlobusAPIError as exc: - if exc.http_status == 400: - raise RuntimeError( - "Authorization code exchange failed. The code was empty, invalid, " - "expired, or already used. Re-run the script and complete the " - "Globus login flow again." - ) from exc - raise RuntimeError( - f"Authorization code exchange failed with HTTP {exc.http_status}. " - "Re-run the script and try again." 
- ) from exc - return token_response.data - - -def refresh_tokens( - client: globus_sdk.NativeAppAuthClient, refresh_token: str -) -> dict | None: - try: - token_response = client.oauth2_refresh_token(refresh_token) - return token_response.data - except GlobusAPIError as exc: - print( - f"Refresh failed ({exc.http_status}); switching to interactive login." - ) - return None - - -def refresh_stored_tokens( - client: globus_sdk.NativeAppAuthClient, stored_tokens: dict -) -> tuple[dict | None, bool]: - iri_refresh_token = get_iri_refresh_token(stored_tokens) - if iri_refresh_token: - iri_token_data = refresh_tokens(client, iri_refresh_token) - if iri_token_data is not None: - return replace_iri_token(stored_tokens, iri_token_data), True - - auth_refresh_token = get_refresh_token(stored_tokens) - if auth_refresh_token: - auth_data = refresh_tokens(client, auth_refresh_token) - if auth_data is not None: - return auth_data, True - - return None, False - - -def main() -> None: - args = parse_args() - if args.force_login and args.refresh_only: - raise RuntimeError("Choose only one of --force-login or --refresh-only") - - client = globus_sdk.NativeAppAuthClient(CLIENT_ID) - - auth_data = None - used_refresh = False - if not args.force_login: - stored = load_tokens(args.token_file) - if stored: - auth_data, used_refresh = refresh_stored_tokens(client, stored) - - if auth_data is None: - if args.refresh_only: - raise RuntimeError( - "Refresh-only mode failed. No usable saved refresh token was found " - "or token refresh did not return the required IRI token." - ) - auth_data = interactive_login(client, prompt_login=args.prompt_login) - - try: - iri_token_data = validate_auth_data(auth_data) - except RuntimeError as exc: - if used_refresh and "Missing token for required IRI scope" in str(exc): - print( - "Refreshed tokens did not include the IRI token; " - "switching to interactive login." 
- ) - auth_data = interactive_login(client, prompt_login=args.prompt_login) - iri_token_data = validate_auth_data(auth_data) - else: - raise - - save_tokens(args.token_file, auth_data) - - if args.validate_iri: - validation_data = validate_iri_token(iri_token_data, args.iri_validate_url) - print(f"IRI validation succeeded against {args.iri_validate_url}") - if isinstance(validation_data, dict): - session_info = validation_data.get("session_info") - if isinstance(session_info, dict): - session_id = session_info.get("session_id") - if session_id: - print(f"IRI session_id: {session_id}") - elif isinstance(validation_data, list): - print(f"IRI validation response items: {len(validation_data)}") - - expires_at = iri_token_data.get("expires_at_seconds") - if expires_at: - ttl = int(expires_at - time.time()) - print(f"\nIRI access token valid for ~{max(ttl, 0)} seconds.") - - print(f"Saved token data to {args.token_file}") - print(f"Granted Globus Auth scopes: {auth_data.get('scope', '')}") - print(f"IRI token resource server: {iri_token_data.get('resource_server')}") - print(f"IRI token scopes: {iri_token_data.get('scope', '')}") - - if args.print_token: - print("\nIRI access token:") - print(iri_token_data["access_token"]) - else: - print( - "IRI access token not printed " - "(use --print-token to display it for the NERSC IRI API)." 
- ) - - -if __name__ == "__main__": - main() From 252199d30e199573d7375664f020be6e748d859c Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:34:53 -0700 Subject: [PATCH 07/26] Updating unit tests --- orchestration/_tests/test_bl832/test_nersc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index f303bac1..460c43d6 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -211,11 +211,11 @@ def mock_iriapi_client(mocker): return client - # --------------------------------------------------------------------------- # _create_sfapi_client # --------------------------------------------------------------------------- + def test_create_sfapi_client_success(mocker): """Valid credentials produce a Client instance.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController From c472f9d3ae6d7ddc8d9c0ead93bb3c40820f32c3 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 7 Apr 2026 14:29:40 -0700 Subject: [PATCH 08/26] Rebasing and including segmentation flows as part of iri/sfapi abstraction --- orchestration/_tests/test_bl832/test_nersc.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 460c43d6..8db8ae0f 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -23,16 +23,6 @@ def prefect_test_fixture(): # Shared fixtures # --------------------------------------------------------------------------- -@pytest.fixture -def mock_config(mocker): - config = mocker.MagicMock() - config.ghcr_images832 = { - "recon_image": "mock_recon_image", - "multires_image": "mock_multires_image", - } - return config - - @pytest.fixture def mock_sfapi_client(mocker): """sfapi_client.Client mock with user, compute, 
submit_job, and job chained.""" From 1b4624d83f8bdd10523f52a77ff646886a208145 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 13 Apr 2026 14:50:03 -0700 Subject: [PATCH 09/26] Making IRIAPI the default login method for now --- orchestration/flows/bl832/nersc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 2860eeda..6d57f70f 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -2742,6 +2742,7 @@ def nersc_segmentation_sam3_integration_test() -> bool: flow_success = nersc_segmentation_sam3_task( recon_folder_path=recon_folder_path, config=Config832(), + login_method=NERSCLoginMethod.IRIAPI ) logger.info(f"Flow success: {flow_success}") return flow_success From c36d9c471c836af7201a900ce6b03b4961f65a1d Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 13 Apr 2026 19:21:45 -0700 Subject: [PATCH 10/26] Making the IRI job submission read sbatch settings --- orchestration/flows/bl832/nersc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 6d57f70f..0fd99b97 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -370,6 +370,7 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: return str(job.jobid) elif self.login_method is NERSCLoginMethod.IRIAPI: + # Parse SBATCH directives before stripping them sbatch_values = {} for line in job_script.splitlines(): if line.startswith("#SBATCH"): @@ -379,6 +380,7 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: sbatch_values["account"] = line.split("-A ")[-1].strip() elif "--time=" in line: t = line.split("--time=")[-1].strip() + # convert HH:MM:SS to seconds parts = t.split(":") sbatch_values["duration"] = int(parts[0])*3600 + int(parts[1])*60 + int(parts[2]) elif "-N " in line: From c3d3b1e61ae474a09ceee31b6d40202f814a3b2b Mon Sep 17 00:00:00 
2001 From: David Abramov Date: Tue, 14 Apr 2026 14:41:33 -0700 Subject: [PATCH 11/26] Fixing IRIAPI bugs, also commenting out Globus transfers for now --- orchestration/flows/bl832/nersc.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 0fd99b97..85774465 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -370,7 +370,6 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: return str(job.jobid) elif self.login_method is NERSCLoginMethod.IRIAPI: - # Parse SBATCH directives before stripping them sbatch_values = {} for line in job_script.splitlines(): if line.startswith("#SBATCH"): @@ -380,7 +379,6 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: sbatch_values["account"] = line.split("-A ")[-1].strip() elif "--time=" in line: t = line.split("--time=")[-1].strip() - # convert HH:MM:SS to seconds parts = t.split(":") sbatch_values["duration"] = int(parts[0])*3600 + int(parts[1])*60 + int(parts[2]) elif "-N " in line: @@ -395,7 +393,6 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: sbatch_values["reservation"] = line.split("--reservation=")[-1].strip() # Strip shebang and SBATCH headers, keep the script body - script_body = "\n".join( line for line in job_script.splitlines() if not line.startswith("#SBATCH") and not line.startswith("#!/") @@ -448,6 +445,13 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: response.raise_for_status() return str(response.json()["id"]) + # response = self.client.post( + # "/api/v1/compute/job/3cf3c048-855e-4dd8-a189-065a483954bb", + # json=job_spec, + # ) + # response.raise_for_status() + # return str(response.json()["id"]) + else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") @@ -1998,7 +2002,6 @@ def nersc_petiole_segment_flow( file_path=file_path, num_nodes=num_nodes, 
config=config, - login_method=login_method ) if isinstance(recon_result, dict): @@ -2052,7 +2055,7 @@ def nersc_petiole_segment_flow( logger.info("Submitting SAM3 and DINOv3 segmentation tasks concurrently.") sam3_future = nersc_segmentation_sam3_task.submit( - recon_folder_path=scratch_path_tiff, config=config, login_method=login_method + recon_folder_path=scratch_path_tiff, config=config ) dinov3_future = nersc_segmentation_dinov3_task.submit( recon_folder_path=scratch_path_tiff, config=config, project="petiole", login_method=login_method @@ -2105,7 +2108,7 @@ def nersc_petiole_segment_flow( logger.info("Running segmentation combination.") combine_future = nersc_combine_segmentations_task.submit( - recon_folder_path=scratch_path_tiff, config=config, login_method=login_method + recon_folder_path=scratch_path_tiff, config=config ) combine_success = combine_future.result() @@ -2651,7 +2654,7 @@ def nersc_segmentation_sam3_task( tomography_controller = get_controller( hpc_type=HPC.NERSC, config=config, - login_method=login_method + login_method=NERSCLoginMethod.IRIAPI ) logger.info(f"Starting NERSC segmentation task for {recon_folder_path=}") nersc_segmentation_success = tomography_controller.segmentation_sam3( @@ -2750,14 +2753,14 @@ def nersc_segmentation_sam3_integration_test() -> bool: return flow_success -if __name__ == "__main__": +# if __name__ == "__main__": # nersc_segmentation_dinov3_task( # recon_folder_path='dabramov/recmoon/', # config=Config832(), # project="moon" # ) - nersc_petiole_segment_flow( - file_path='dabramov/20260221_143000_petiole28', - num_nodes=4, - login_method=NERSCLoginMethod.IRIAPI - ) + # nersc_petiole_segment_flow( + # file_path='dabramov/20260221_143000_petiole28', + # num_nodes=2, + # login_method=NERSCLoginMethod.IRIAPI + # ) From 33bb1f3e602a00448857ccf129d1e6a35d0fc932 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 23 Apr 2026 11:04:53 -0700 Subject: [PATCH 12/26] Updating logger comments --- 
orchestration/flows/bl832/nersc.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 85774465..3613b3f5 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -445,13 +445,6 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: response.raise_for_status() return str(response.json()["id"]) - # response = self.client.post( - # "/api/v1/compute/job/3cf3c048-855e-4dd8-a189-065a483954bb", - # json=job_spec, - # ) - # response.raise_for_status() - # return str(response.json()["id"]) - else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") @@ -2747,7 +2740,6 @@ def nersc_segmentation_sam3_integration_test() -> bool: flow_success = nersc_segmentation_sam3_task( recon_folder_path=recon_folder_path, config=Config832(), - login_method=NERSCLoginMethod.IRIAPI ) logger.info(f"Flow success: {flow_success}") return flow_success From 5192ecdaa137ecab000bfb364c74ab3ed598a938 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 12:37:31 -0700 Subject: [PATCH 13/26] connecting to AmSC MLflow service --- orchestration/flows/bl832/register_mlflow.py | 76 ++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/orchestration/flows/bl832/register_mlflow.py b/orchestration/flows/bl832/register_mlflow.py index c224c769..1358028d 100644 --- a/orchestration/flows/bl832/register_mlflow.py +++ b/orchestration/flows/bl832/register_mlflow.py @@ -93,6 +93,82 @@ def register_mlflow_checkpoints(): }, ) + # register_checkpoint( + # model_name="sam3-petiole", + # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + # nersc_path=f"{scripts_dir}sam3_finetune/sam3/checkpoint_v6.pt", + # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # 
alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_path="/eagle/SYNAPS-I/segmentation/sam3_finetune/sam3/checkpoint_v6.pt", + # config=config, + # alias="production", + # description="SAM3 v6 fine-tuned on petiole micro-CT data.", + # inference_params={ + # # ── paths ────────────────────────────────────────────────────────── + # "original_checkpoint_path": + # f"{scripts_dir}sam3_finetune/sam3/sam3.pt", + # "bpe_path": f"{scripts_dir}sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz", + # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311", + # "seg_scripts_dir": f"{scripts_dir}inference_latest/forge_feb_seg_model_demo/", + # "checkpoints_dir": f"{scripts_dir}sam3_finetune/sam3/", + # # ── inference hyperparameters ─────────────────────────────────────── + # "script_name": "src/inference_v6.py", + # "batch_size": 1, + # "patch_size": 400, + # "confidence": [0.5], # list → JSON-encoded automatically + # "overlap": 0.25, + # "prompts": [ # list → JSON-encoded automatically + # "Phloem Fibers", + # "Hydrated Xylem vessels", + # "Air-based Pith cells", + # "Dehydrated Xylem vessels", + # ], + # }, + # ) + + # register_checkpoint( + # model_name="dinov3-petiole", + # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + # nersc_checkpoint_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt", + # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_path="/eagle/SYNAPS-I/segmentation/dino/best.ckpt", + # config=config, + # alias="production", + # description="DINOv3 fine-tuned on petiole micro-CT data.", + # inference_params={ + # # ── paths ────────────────────────────────────────────────────────── + # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", + # 
"seg_scripts_dir": f"{scripts_dir}inference_v5_multiseg/forge_feb_seg_model_demo/", + # # ── inference hyperparameters ─────────────────────────────────────── + # "script_name": "src.inference_dino_v1", + # "batch_size": 4, + # "nproc_per_node": 4, + # }, + # ) + + # register_checkpoint( + # model_name="dinov3-moon", + # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + # nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best_moon.ckpt", + # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_path="/eagle/SYNAPS-I/segmentation/seg_models/dino/best_moon.ckpt", + # config=config, + # alias="production", + # description="DINOv3 fine-tuned on lunar regolith micro-CT data (ice, particles, pores).", + # inference_params={ + # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", + # "seg_scripts_dir": f"{scripts_dir}moon_seg/forge_feb_seg_model_demo/", + # "script_name": "src.inference_dino_v2", + # "batch_size": 4, + # "nproc_per_node": 4, + # }, + # ) + def retrieve_mlflow_params_test() -> bool: """Test that _load_job_options correctly pulls inference params from the MLflow registry. 
From fbfaac707e91dbe99473eefc6136cde827c16281 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 13:20:29 -0700 Subject: [PATCH 14/26] removing old commented code --- orchestration/flows/bl832/register_mlflow.py | 76 -------------------- 1 file changed, 76 deletions(-) diff --git a/orchestration/flows/bl832/register_mlflow.py b/orchestration/flows/bl832/register_mlflow.py index 1358028d..c224c769 100644 --- a/orchestration/flows/bl832/register_mlflow.py +++ b/orchestration/flows/bl832/register_mlflow.py @@ -93,82 +93,6 @@ def register_mlflow_checkpoints(): }, ) - # register_checkpoint( - # model_name="sam3-petiole", - # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", - # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", - # nersc_path=f"{scripts_dir}sam3_finetune/sam3/checkpoint_v6.pt", - # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_path="/eagle/SYNAPS-I/segmentation/sam3_finetune/sam3/checkpoint_v6.pt", - # config=config, - # alias="production", - # description="SAM3 v6 fine-tuned on petiole micro-CT data.", - # inference_params={ - # # ── paths ────────────────────────────────────────────────────────── - # "original_checkpoint_path": - # f"{scripts_dir}sam3_finetune/sam3/sam3.pt", - # "bpe_path": f"{scripts_dir}sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz", - # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311", - # "seg_scripts_dir": f"{scripts_dir}inference_latest/forge_feb_seg_model_demo/", - # "checkpoints_dir": f"{scripts_dir}sam3_finetune/sam3/", - # # ── inference hyperparameters ─────────────────────────────────────── - # "script_name": "src/inference_v6.py", - # "batch_size": 1, - # "patch_size": 400, - # "confidence": [0.5], # list → JSON-encoded automatically - # "overlap": 0.25, - # "prompts": [ # list → JSON-encoded automatically - # "Phloem 
Fibers", - # "Hydrated Xylem vessels", - # "Air-based Pith cells", - # "Dehydrated Xylem vessels", - # ], - # }, - # ) - - # register_checkpoint( - # model_name="dinov3-petiole", - # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", - # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", - # nersc_checkpoint_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt", - # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_path="/eagle/SYNAPS-I/segmentation/dino/best.ckpt", - # config=config, - # alias="production", - # description="DINOv3 fine-tuned on petiole micro-CT data.", - # inference_params={ - # # ── paths ────────────────────────────────────────────────────────── - # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", - # "seg_scripts_dir": f"{scripts_dir}inference_v5_multiseg/forge_feb_seg_model_demo/", - # # ── inference hyperparameters ─────────────────────────────────────── - # "script_name": "src.inference_dino_v1", - # "batch_size": 4, - # "nproc_per_node": 4, - # }, - # ) - - # register_checkpoint( - # model_name="dinov3-moon", - # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", - # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", - # nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best_moon.ckpt", - # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_path="/eagle/SYNAPS-I/segmentation/seg_models/dino/best_moon.ckpt", - # config=config, - # alias="production", - # description="DINOv3 fine-tuned on lunar regolith micro-CT data (ice, particles, pores).", - # inference_params={ - # "conda_env_path": 
"/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", - # "seg_scripts_dir": f"{scripts_dir}moon_seg/forge_feb_seg_model_demo/", - # "script_name": "src.inference_dino_v2", - # "batch_size": 4, - # "nproc_per_node": 4, - # }, - # ) - def retrieve_mlflow_params_test() -> bool: """Test that _load_job_options correctly pulls inference params from the MLflow registry. From 6c1837d1f9f6af7b8e9f1c461dc50572feacf76f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 13:31:46 -0700 Subject: [PATCH 15/26] adjusting import in pytest to avoid error on github that did not occur locally --- orchestration/_tests/test_bl832/test_nersc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 8db8ae0f..51230c1f 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -479,6 +479,7 @@ def test_reconstruct_sfapi_submission_failure(mocker, mock_sfapi_client, mock_co def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, monkeypatch): """IRIAPI reconstruct POSTs a job and polls for COMPLETED state.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from orchestration.flows.bl832.nersc import RESOURCE_IDS, _IRI_COMPUTE_RESOURCE monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep") From 8055eaeb0da3002f11b0eeb44c6f3746f082245f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 7 May 2026 09:39:23 -0700 Subject: [PATCH 16/26] Getting NERSC reservations working with IRI API --- orchestration/flows/bl832/nersc.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 3613b3f5..a07cbe28 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -163,7 +163,7 @@ 
def __init__( self, config: Config832, client: Client | httpx.Client | None = None, - login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, + login_method: NERSCLoginMethod = NERSCLoginMethod.IRIAPI, ) -> None: TomographyHPCController.__init__(self, config) self.client = client @@ -2745,14 +2745,14 @@ def nersc_segmentation_sam3_integration_test() -> bool: return flow_success -# if __name__ == "__main__": +if __name__ == "__main__": # nersc_segmentation_dinov3_task( # recon_folder_path='dabramov/recmoon/', # config=Config832(), # project="moon" # ) - # nersc_petiole_segment_flow( - # file_path='dabramov/20260221_143000_petiole28', - # num_nodes=2, - # login_method=NERSCLoginMethod.IRIAPI - # ) + nersc_petiole_segment_flow( + file_path='dabramov/20260221_143000_petiole28', + num_nodes=4, + login_method=NERSCLoginMethod.IRIAPI + ) From cbd7b7ebba3897c38d07fac520dad98a39520da6 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 7 May 2026 15:48:54 -0700 Subject: [PATCH 17/26] fixing globus token race condition when jobs are launched simultaneously --- orchestration/globus/get_globus_token.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/orchestration/globus/get_globus_token.py b/orchestration/globus/get_globus_token.py index 4cca122f..d987947e 100644 --- a/orchestration/globus/get_globus_token.py +++ b/orchestration/globus/get_globus_token.py @@ -94,6 +94,19 @@ def load_tokens(token_file: Path) -> dict | None: return json.load(f) +# def save_tokens(token_file: Path, tokens: dict) -> None: +# ensure_private_parent_dir(token_file) +# tmp = token_file.with_suffix(".tmp") +# with os.fdopen( +# os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), +# "w", +# encoding="utf-8", +# ) as f: +# json.dump(tokens, f, indent=2) +# os.replace(tmp, token_file) +# os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + + def save_tokens(token_file: Path, tokens: dict) -> None: ensure_private_parent_dir(token_file) # Per-process unique tmp name to avoid races 
between concurrent writers From 34b51c04e9ef274002693673fcddab4e02ce84f3 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 12 May 2026 07:33:45 -0700 Subject: [PATCH 18/26] updating config with confab reservation --- config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yml b/config.yml index fb0b9989..cfa3e8fc 100644 --- a/config.yml +++ b/config.yml @@ -225,7 +225,7 @@ hpc_submission_settings832: # ── SLURM resource allocation ───────────────────────────────────────────── qos: realtime account: als - reservation: "" + reservation: "_CAP_SYNAPS_LIVEDEMO_CPU1" cpus-per-task: 128 walltime: "0:15:00" From 91dfa00888c3280239f7c9fd906f9a82445b3cda Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 20 May 2026 14:35:16 -0700 Subject: [PATCH 19/26] moving nersc iri/sf-api resource definitions to config (no longer global variables) --- config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yml b/config.yml index cfa3e8fc..c8935df7 100644 --- a/config.yml +++ b/config.yml @@ -225,7 +225,7 @@ hpc_submission_settings832: # ── SLURM resource allocation ───────────────────────────────────────────── qos: realtime account: als - reservation: "_CAP_SYNAPS_LIVEDEMO_CPU1" + reservation: "_CAP_SYNAPS_LIVEDEMO_CPU2" cpus-per-task: 128 walltime: "0:15:00" From 284de47d98042ab90dab99e527f6eb9c247b2d47 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 20 May 2026 14:38:30 -0700 Subject: [PATCH 20/26] Updating nersc.py to pull iri/sf-api parameters from the config, rather than a global variable --- orchestration/_tests/test_bl832/test_nersc.py | 1 - orchestration/flows/bl832/nersc.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 51230c1f..8db8ae0f 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -479,7 +479,6 @@ def 
test_reconstruct_sfapi_submission_failure(mocker, mock_sfapi_client, mock_co def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, monkeypatch): """IRIAPI reconstruct POSTs a job and polls for COMPLETED state.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod - from orchestration.flows.bl832.nersc import RESOURCE_IDS, _IRI_COMPUTE_RESOURCE monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep") diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index a07cbe28..7a13bd18 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -195,7 +195,7 @@ def create_nersc_client( Args: config: Config832 instance for accessing config settings needed during client creation. login_method: Which NERSC API to authenticate against. - Defaults to :attr:`NERSCLoginMethod.SFAPI`. + Defaults to :attr:`NERSCLoginMethod.IRIAPI`. Returns: An authenticated :class:`sfapi_client.Client` instance. 
From d864e80d2ed2e638bcf456965debc4c21323bd87 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 21 May 2026 13:39:20 -0700 Subject: [PATCH 21/26] Initial commit for a new orchestration/jobs/ structure --- orchestration/jobs/__init__.py | 0 orchestration/jobs/alcf/__init__.py | 0 orchestration/jobs/alcf/controller.py | 199 ++++++++++++++ orchestration/jobs/controller.py | 57 ++++ orchestration/jobs/nersc/__init__.py | 0 orchestration/jobs/nersc/controller.py | 346 +++++++++++++++++++++++++ orchestration/jobs/nersc/login.py | 167 ++++++++++++ orchestration/jobs/nersc/shifter.py | 242 +++++++++++++++++ orchestration/jobs/options.py | 105 ++++++++ 9 files changed, 1116 insertions(+) create mode 100644 orchestration/jobs/__init__.py create mode 100644 orchestration/jobs/alcf/__init__.py create mode 100644 orchestration/jobs/alcf/controller.py create mode 100644 orchestration/jobs/controller.py create mode 100644 orchestration/jobs/nersc/__init__.py create mode 100644 orchestration/jobs/nersc/controller.py create mode 100644 orchestration/jobs/nersc/login.py create mode 100644 orchestration/jobs/nersc/shifter.py create mode 100644 orchestration/jobs/options.py diff --git a/orchestration/jobs/__init__.py b/orchestration/jobs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/jobs/alcf/__init__.py b/orchestration/jobs/alcf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/jobs/alcf/controller.py b/orchestration/jobs/alcf/controller.py new file mode 100644 index 00000000..6dcf9644 --- /dev/null +++ b/orchestration/jobs/alcf/controller.py @@ -0,0 +1,199 @@ +# orchestration/jobs/alcf/controller.py + +"""Generic ALCF job controller. + +Wraps Globus Compute's submit/wait pattern behind a common controller +interface. Beamline-specific job functions (the actual Python callables +executed remotely on ALCF compute nodes) live in +``orchestration/flows//alcf.py`` and subclass this controller. 
+ +ALCF execution is fundamentally different from NERSC in shape: instead of +submitting a Slurm batch script and getting a job ID back, you submit a +Python callable (with arguments) and get a :class:`concurrent.futures.Future` +back. The wait/poll logic is correspondingly different too. + +Authentication is via a Globus Compute endpoint ID stored in the Prefect +Secret block ``globus-compute-endpoint``. The allocation root path on the +remote filesystem (e.g. ``/eagle/IRIProd/ALS``) is stored in the Prefect +Variable ``alcf-allocation-root-path``. +""" + +import logging +import time +from concurrent.futures import Future +from typing import Any, Callable + +from globus_compute_sdk import Client, Executor +from globus_compute_sdk.serialize import CombinedCode +from prefect import get_run_logger +from prefect.blocks.system import Secret +from prefect.variables import Variable + +from orchestration.config import BeamlineConfig +from orchestration.jobs.controller import JobController + +logger = logging.getLogger(__name__) + + +# Defaults for wait_for_future. Override per-call when a job is known to be +# longer-running or faster-polling than typical. +_DEFAULT_CHECK_INTERVAL_SECONDS: int = 20 +_DEFAULT_WALLTIME_SECONDS: int = 1200 # 20 minutes + +# Prefect block / variable names. Defined as constants so callers and tests +# have a single place to override. +_GLOBUS_COMPUTE_ENDPOINT_SECRET: str = "globus-compute-endpoint" +_ALLOCATION_ROOT_VARIABLE: str = "alcf-allocation-root-path" + + +class ALCFJobController(JobController): + """Generic ALCF job submission and monitoring via Globus Compute. + + Subclass for beamline-specific work (e.g. ``ALCFTomographyJobController`` + in ``flows/bl832/alcf.py``). This class knows nothing about tomography, + BL832, or any particular pipeline — it only knows how to submit a + callable to ALCF and wait for it. + + Args: + config: Beamline configuration object. 
Stored as ``self.config`` + for subclass use; this controller doesn't read any specific + fields from it. + + Attributes: + allocation_root: Remote path prefix on the ALCF filesystem + (e.g. ``/eagle/IRIProd/ALS``). Read from the + ``alcf-allocation-root-path`` Prefect Variable. Available for + subclasses to construct script and data paths. + endpoint_id: Globus Compute endpoint UUID, loaded from the + ``globus-compute-endpoint`` Prefect Secret block. + """ + + def __init__(self, config: BeamlineConfig) -> None: + super().__init__(config) + + allocation_data = Variable.get(_ALLOCATION_ROOT_VARIABLE, _sync=True) + self.allocation_root: str = allocation_data.get(_ALLOCATION_ROOT_VARIABLE) + if not self.allocation_root: + raise ValueError( + f"Allocation root not found in Prefect Variable " + f"'{_ALLOCATION_ROOT_VARIABLE}'" + ) + logger.info(f"Allocation root loaded: {self.allocation_root}") + + self.endpoint_id: str = Secret.load(_GLOBUS_COMPUTE_ENDPOINT_SECRET).get() + + def submit(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Future: + """Submit a callable to the Globus Compute endpoint. + + The callable is shipped to the ALCF compute node and executed there + with the supplied args/kwargs. Code-serialization uses + :class:`CombinedCode` so the function can reference imports and + helpers defined in the same module without manual packing. + + Args: + func: The function to run remotely. Should be picklable and + self-contained — imports inside the function body are + evaluated on the remote node, not locally. + *args: Positional arguments passed to ``func`` on the remote node. + **kwargs: Keyword arguments passed to ``func`` on the remote node. + + Returns: + A :class:`concurrent.futures.Future`-compatible object. Use + :meth:`wait_for_future` to poll until done. 
+ """ + gcc = Client(code_serialization_strategy=CombinedCode()) + # Note: the Executor context manager handles connection cleanup, but + # the Future remains valid after the executor exits — submission has + # already been dispatched to the remote endpoint by that point. + with Executor(endpoint_id=self.endpoint_id, client=gcc) as fxe: + future = fxe.submit(func, *args, **kwargs) + return future + + @staticmethod + def wait_for_future( + future: Future, + task_name: str, + check_interval: int = _DEFAULT_CHECK_INTERVAL_SECONDS, + walltime: int = _DEFAULT_WALLTIME_SECONDS, + ) -> bool: + """Block until a Globus Compute future completes or hits walltime. + + Polls ``future.done()`` every ``check_interval`` seconds. If the + future is still not done after ``walltime`` seconds total, the + future is cancelled and ``False`` is returned. Logging uses Prefect's + run logger so progress shows up in the flow run UI. + + Args: + future: The future returned by :meth:`submit`. + task_name: Short descriptive name used in log messages + (e.g. ``"reconstruction"``, ``"tiff to zarr"``). + check_interval: Seconds between ``future.done()`` polls. + walltime: Maximum total seconds to wait before giving up and + cancelling. + + Returns: + True if the task completed successfully (future returned without + raising). False if the task was cancelled, raised an exception, + timed out, or an error occurred during polling. + """ + run_logger = get_run_logger() + start_time = time.time() + success = False + + try: + previous_state = None + while not future.done(): + elapsed_time = time.time() - start_time + if elapsed_time > walltime: + run_logger.error( + f"The {task_name} task exceeded the walltime of " + f"{walltime} seconds. Cancelling the Globus Compute job." + ) + future.cancel() + return False + + if future.cancelled(): + run_logger.warning(f"The {task_name} task was cancelled.") + return False + + # Assume the task is running if not done and not cancelled. 
+ # Log once per state transition rather than every poll. + if previous_state != "running": + run_logger.info(f"The {task_name} task is running...") + previous_state = "running" + + time.sleep(check_interval) + + # Future is done — check whether it was cancelled or raised. + if future.cancelled(): + run_logger.warning( + f"The {task_name} task was cancelled after completion." + ) + return False + + exception = future.exception() + if exception: + run_logger.error( + f"The {task_name} task raised an exception: {exception}" + ) + return False + + result = future.result() + run_logger.info( + f"The {task_name} task completed successfully with result: {result}" + ) + success = True + + except Exception as e: + run_logger.error( + f"An error occurred while waiting for the {task_name} task: {e}" + ) + success = False + + finally: + elapsed_time = time.time() - start_time + run_logger.info( + f"Total duration of the {task_name} task: {elapsed_time:.2f} seconds." + ) + + return success diff --git a/orchestration/jobs/controller.py b/orchestration/jobs/controller.py new file mode 100644 index 00000000..5e312004 --- /dev/null +++ b/orchestration/jobs/controller.py @@ -0,0 +1,57 @@ +# orchestration/jobs/controller.py +"""Generic job controller abstractions. + +Defines the minimal base class and target enum that any backend-specific +controller (NERSC, ALCF, future local clusters) builds on. Beamline-specific +controllers live under ``orchestration/flows//`` and subclass +:class:`JobController` (or a domain-specific intermediate like +``TomographyJobController``). +""" + +from abc import ABC +from enum import Enum +import logging + +from orchestration.config import BeamlineConfig + +logger = logging.getLogger(__name__) + + +class JobTarget(Enum): + """Identifies a job execution target. + + Used by beamline-specific factories (e.g. ``flows/bl832/job_controller.py``) + to dispatch to the appropriate concrete controller class. 
+ + Members: + ALCF: Argonne Leadership Computing Facility + NERSC: National Energy Research Scientific Computing Center + OLCF: Oak Ridge Leadership Computing Facility + """ + + ALCF = "ALCF" + NERSC = "NERSC" + OLCF = "OLCF" + + +class JobController(ABC): + """Base class for job controllers. + + Subclasses provide target-specific job submission and monitoring (see + :class:`orchestration.jobs.nersc.controller.NERSCJobController` and + :class:`orchestration.jobs.alcf.controller.ALCFJobController`). + + No abstract methods are declared here because submission shapes differ + fundamentally between targets — NERSC submits Slurm scripts and returns + job IDs, ALCF submits Python callables via Globus Compute and returns + futures. Domain-specific intermediates (e.g. ``TomographyJobController``) + may add ``@abstractmethod``s like ``reconstruct(file_path)`` when multiple + controllers share an interface. + + Args: + config: Beamline configuration object. Stored as ``self.config`` for + subclass use. + """ + + def __init__(self, config: BeamlineConfig) -> None: + self.config = config diff --git a/orchestration/jobs/nersc/__init__.py b/orchestration/jobs/nersc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/jobs/nersc/controller.py b/orchestration/jobs/nersc/controller.py new file mode 100644 index 00000000..3d52fe90 --- /dev/null +++ b/orchestration/jobs/nersc/controller.py @@ -0,0 +1,346 @@ +# orchestration/jobs/nersc/controller.py + +"""Generic NERSC job controller. + +Wraps the NERSC SFAPI and IRI API behind a common controller interface: +submit Slurm scripts, poll for completion, and perform basic filesystem +operations on Perlmutter. Beamline-specific job script builders live in +``orchestration/flows//nersc.py`` and subclass this controller. + +Two authentication modes are supported, selected by :class:`NERSCLoginMethod`: + +- :attr:`NERSCLoginMethod.SFAPI` — Iris-registered OAuth2 (NERSC OIDC). 
# orchestration/jobs/nersc/controller.py

"""Generic NERSC job controller.

Wraps the NERSC SFAPI and IRI API behind a common controller interface:
submit Slurm scripts, poll for completion, and perform basic filesystem
operations on Perlmutter. Beamline-specific job script builders live in
``orchestration/flows/<beamline>/nersc.py`` and subclass this controller.

Two authentication modes are supported, selected by :class:`NERSCLoginMethod`:

- :attr:`NERSCLoginMethod.SFAPI` — Iris-registered OAuth2 (NERSC OIDC).
  Operations go through ``sfapi_client.Client``.

- :attr:`NERSCLoginMethod.IRIAPI` — Globus bearer token. Operations are
  raw ``httpx`` calls to the IRI API.

The login method also selects the ``nersc_resources`` sub-dict
(``config.nersc_resources["iri"]`` vs ``["sfapi"]``), which carries the API
base URL and resource UUIDs used in URL construction.
"""

import json
import logging
import os
import shlex
import time

import httpx
from sfapi_client import Client
from sfapi_client.compute import Machine

from orchestration.config import BeamlineConfig
from orchestration.jobs.controller import JobController
from orchestration.jobs.nersc.login import NERSCLoginMethod

logger = logging.getLogger(__name__)

_SBATCH_PREFIX = "#SBATCH"

# Every sbatch option spelling we translate, mapped to the key it feeds in the
# intermediate directive dict. Short and long spellings are equivalent to
# Slurm, so both must land on the same key.
_SBATCH_OPTION_FIELDS: dict[str, str] = {
    "-q": "queue_name",
    "--qos": "queue_name",
    "-A": "account",
    "--account": "account",
    "-t": "duration",
    "--time": "duration",
    "-N": "node_count",
    "--nodes": "node_count",
    "-C": "constraint",
    "--constraint": "constraint",
    "-o": "stdout_path",
    "--output": "stdout_path",
    "-e": "stderr_path",
    "--error": "stderr_path",
    "--reservation": "reservation",
}

# IRI job states treated as unsuccessful terminal states. Both spellings of
# "canceled" are accepted because the state string is compared verbatim.
_IRI_FAILURE_STATES: tuple[str, ...] = ("failed", "canceled", "cancelled", "timeout")

# Polling budget for the IRI async file-view task: 40 polls x 3 s = ~120 s.
_IRI_READ_POLL_ATTEMPTS = 40
_IRI_READ_POLL_SECONDS = 3


def _parse_slurm_walltime(value: str) -> int:
    """Convert a Slurm time limit into seconds.

    Handles the formats sbatch documents for ``--time``: ``minutes``,
    ``minutes:seconds``, ``hours:minutes:seconds``, ``days-hours``,
    ``days-hours:minutes`` and ``days-hours:minutes:seconds``.

    Args:
        value: The raw time-limit text taken from the directive.

    Returns:
        The time limit in whole seconds.

    Raises:
        ValueError: If ``value`` matches none of the formats above.
    """
    text = value.strip()
    days_text, has_days, clock_text = text.partition("-")
    try:
        if has_days:
            # With a day prefix the clock part is read as H[:M[:S]].
            days = int(days_text)
            fields = [int(part) for part in clock_text.split(":")]
            if len(fields) > 3:
                raise ValueError("too many fields")
            hours, minutes, seconds = (fields + [0, 0])[:3]
        else:
            days = 0
            fields = [int(part) for part in text.split(":")]
            if len(fields) == 1:
                # A bare number is minutes, not seconds or hours.
                hours, minutes, seconds = 0, fields[0], 0
            elif len(fields) == 2:
                hours, minutes, seconds = 0, fields[0], fields[1]
            elif len(fields) == 3:
                hours, minutes, seconds = fields
            else:
                raise ValueError("too many fields")
    except ValueError as exc:
        raise ValueError(f"Unrecognized Slurm time limit: {value!r}") from exc
    return ((days * 24 + hours) * 60 + minutes) * 60 + seconds


def _parse_sbatch_directives(job_script: str) -> dict[str, object]:
    """Collect the ``#SBATCH`` options this controller knows how to translate.

    Each directive line is tokenised like a shell command line, so a single
    line may carry several options, and both ``--opt=value`` / ``--opt value``
    and ``-o value`` / ``-ovalue`` spellings are recognised. Options missing
    from ``_SBATCH_OPTION_FIELDS`` are ignored.

    Args:
        job_script: The full Slurm batch script.

    Returns:
        Dict keyed by the field names in ``_SBATCH_OPTION_FIELDS``.
        ``duration`` is converted to seconds and ``node_count`` to ``int``;
        every other value is the option's text.

    Raises:
        ValueError: If a time limit or node count cannot be parsed.
    """
    values: dict[str, object] = {}
    for line in job_script.splitlines():
        if not line.startswith(_SBATCH_PREFIX):
            continue
        remainder = line[len(_SBATCH_PREFIX):]
        try:
            tokens = shlex.split(remainder)
        except ValueError:
            # Unbalanced quotes: fall back to plain whitespace splitting
            # rather than rejecting the whole script.
            tokens = remainder.split()

        index = 0
        while index < len(tokens):
            token = tokens[index]
            index += 1
            if token.startswith("--"):
                option, has_value, value = token.partition("=")
            elif token.startswith("-") and len(token) > 1:
                option, value = token[:2], token[2:]
                has_value = bool(value)
            else:
                continue

            field = _SBATCH_OPTION_FIELDS.get(option)
            if field is None:
                # Unknown option: skip only this token. It may be a flag that
                # takes no value, so the next token must not be consumed.
                continue
            if not has_value:
                if index >= len(tokens):
                    continue
                value = tokens[index]
                index += 1

            if field == "duration":
                values[field] = _parse_slurm_walltime(value)
            elif field == "node_count":
                values[field] = int(value)
            else:
                values[field] = value.strip()
    return values


def _build_iri_job_spec(job_script: str) -> dict[str, object]:
    """Translate a Slurm batch script into the job spec posted to the IRI API.

    SBATCH headers become ``resources`` / ``attributes`` entries and the rest
    of the script becomes the ``pre_launch`` payload.

    Args:
        job_script: The full Slurm batch script.

    Returns:
        The JSON-serialisable job specification.
    """
    sbatch_values = _parse_sbatch_directives(job_script)

    # Script body: everything except shebang and SBATCH headers.
    script_body = "\n".join(
        line
        for line in job_script.splitlines()
        if not line.startswith((_SBATCH_PREFIX, "#!/"))
    ).strip()

    constraint = sbatch_values.get("constraint", "cpu")
    is_gpu = "gpu" in str(constraint).lower()

    resources: dict[str, object] = {
        "node_count": sbatch_values.get("node_count", 1),
        "processes_per_node": 1,
        "exclusive_node_use": True,
    }
    if is_gpu:
        resources["gpu_cores_per_process"] = 4
    else:
        resources["cpu_cores_per_process"] = 128

    attributes: dict[str, object] = {
        "duration": sbatch_values.get("duration", 1800),
        "queue_name": sbatch_values.get("queue_name", "regular"),
        "account": sbatch_values.get("account", "als"),
        "custom_attributes": {"constraint": constraint},
    }
    if "reservation" in sbatch_values:
        attributes["reservation_id"] = sbatch_values["reservation"]

    job_spec: dict[str, object] = {
        "executable": "/bin/bash",
        # Reading the script from stdin isn't supported, so the body goes
        # into pre_launch (runs before the executable's main entry point).
        "arguments": ["-s"],
        "pre_launch": script_body,
        "resources": resources,
        "attributes": attributes,
    }
    for path_key in ("stdout_path", "stderr_path"):
        if path_key in sbatch_values:
            job_spec[path_key] = sbatch_values[path_key]
    return job_spec


class NERSCJobController(JobController):
    """Generic NERSC job submission and monitoring.

    Subclass for beamline-specific work (e.g. ``NERSCTomographyJobController``
    in ``flows/bl832/nersc.py``). This class knows nothing about tomography,
    BL832, or any particular pipeline — it only knows how to submit a Slurm
    script and wait for it.

    Args:
        config: Beamline configuration object. Must expose
            ``config.nersc_resources`` with ``"iri"`` and ``"sfapi"`` sub-dicts
            populated from the YAML.
        client: Authenticated NERSC client. Build with
            :func:`orchestration.jobs.nersc.login.create_nersc_client`.
        login_method: Which NERSC API the client targets. Determines URL
            construction and which dispatch branch is used in each method.

    Attributes:
        client: The authenticated NERSC client.
        login_method: The selected :class:`NERSCLoginMethod`.
        nersc_resources: Sub-dict of ``config.nersc_resources`` for the chosen
            login method. The methods below read ``perlmutter_login``,
            ``perlmutter_job_submit`` and ``compute_resource`` from it.

    Raises:
        ValueError: If ``login_method`` is not a supported member.
    """

    def __init__(
        self,
        config: BeamlineConfig,
        client: Client | httpx.Client | None = None,
        login_method: NERSCLoginMethod = NERSCLoginMethod.IRIAPI,
    ) -> None:
        super().__init__(config)
        self.client = client
        self.login_method = login_method

        if login_method is NERSCLoginMethod.IRIAPI:
            resources_key = "iri"
        elif login_method is NERSCLoginMethod.SFAPI:
            resources_key = "sfapi"
        else:
            raise ValueError(f"Unsupported NERSCLoginMethod: {login_method}")
        self.nersc_resources: dict[str, str] = config.nersc_resources[resources_key]

    def get_nersc_username(self) -> str:
        """Return the NERSC username, used to construct ``pscratch`` paths.

        With SFAPI the username comes from ``client.user().name``. With
        IRIAPI it is read from the ``NERSC_USERNAME`` environment variable.

        Returns:
            NERSC username string.

        Raises:
            ValueError: If IRIAPI is selected and ``NERSC_USERNAME`` is unset
                or empty.
        """
        if self.login_method is NERSCLoginMethod.SFAPI:
            return self.client.user().name

        username = os.getenv("NERSC_USERNAME")
        if not username:
            raise ValueError(
                "NERSC_USERNAME must be set in the environment when using IRIAPI."
            )
        return username

    def submit_job(self, job_script: str, num_nodes: int = 1) -> str:
        """Submit a Slurm batch script and return the job ID.

        For SFAPI, the script is passed verbatim to ``perlmutter.submit_job``.

        For IRIAPI, the script is translated into a job spec: recognised
        SBATCH options (queue/QoS, account, time limit, node count,
        constraint, output/error paths, reservation — short or long
        spelling) become spec attributes, and the script body (everything
        except the shebang and SBATCH lines) becomes the ``pre_launch``
        payload.

        Args:
            job_script: The full Slurm batch script to submit.
            num_nodes: Currently unused; the node count is read from the
                SBATCH ``-N`` / ``--nodes`` directive in the script.

        Returns:
            The submitted job ID as a string.

        Raises:
            ValueError: If ``self.login_method`` is unrecognized, or (IRIAPI)
                a time limit or node count in the script cannot be parsed.
            httpx.HTTPStatusError: If the IRI API submission returns non-2xx.
        """
        if self.login_method is NERSCLoginMethod.SFAPI:
            perlmutter = self.client.compute(Machine.perlmutter)
            job = perlmutter.submit_job(job_script)
            return str(job.jobid)

        if self.login_method is NERSCLoginMethod.IRIAPI:
            return self._submit_job_iriapi(job_script)

        raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}")

    def _submit_job_iriapi(self, job_script: str) -> str:
        """Translate a Slurm script into a job spec and POST it to the IRI API."""
        job_spec = _build_iri_job_spec(job_script)

        response = self.client.post(
            f"/api/v1/compute/job/{self.nersc_resources['perlmutter_job_submit']}",
            json=job_spec,
        )
        if not response.is_success:
            # Log the server's reply and the spec before raising, so a
            # rejected submission can be diagnosed from the logs alone.
            logger.error(f"Job submission failed: {response.status_code} {response.text}")
            logger.error(f"Job spec was: {json.dumps(job_spec, indent=2)}")
        response.raise_for_status()
        return str(response.json()["id"])

    def wait_for_job(
        self,
        job_id: str,
        *,
        poll_interval: float = 60.0,
        timeout: float | None = None,
    ) -> bool:
        """Block until a submitted job reaches a terminal state.

        For SFAPI, this delegates to ``sfapi_client``'s ``job.complete()`` and
        then inspects the final job state. For IRIAPI, the status endpoint is
        polled every ``poll_interval`` seconds.

        Args:
            job_id: The job ID returned by :meth:`submit_job`.
            poll_interval: Seconds between IRIAPI status polls. Ignored for
                SFAPI.
            timeout: Optional cap, in seconds, on how long the IRIAPI loop
                keeps polling. ``None`` (the default) polls until the job
                reaches a terminal state. Ignored for SFAPI.

        Returns:
            True if the job completed successfully; False if it failed, was
            canceled, hit its Slurm time limit, or ``timeout`` elapsed first.

        Raises:
            ValueError: If ``self.login_method`` is unrecognized.
            httpx.HTTPStatusError: If an IRI status request returns non-2xx.
        """
        if self.login_method is NERSCLoginMethod.SFAPI:
            perlmutter = self.client.compute(Machine.perlmutter)
            job = perlmutter.job(jobid=job_id)
            final_state = job.complete()
            if final_state is None:
                final_state = getattr(job, "state", None)
            if final_state is None:
                # NOTE(review): no state was exposed by the client; keep the
                # previous behaviour of reporting success in that case.
                logger.warning(f"Job {job_id} finished but its final state is unknown.")
                return True
            state_name = getattr(final_state, "name", None) or str(final_state)
            logger.info(f"Job {job_id} state: {state_name}")
            if state_name.upper() == "COMPLETED":
                return True
            logger.error(f"Job {job_id} ended with state: {state_name}")
            return False

        if self.login_method is NERSCLoginMethod.IRIAPI:
            status_url = (
                f"/api/v1/compute/status/{self.nersc_resources['compute_resource']}/{job_id}"
            )
            deadline = None if timeout is None else time.monotonic() + timeout
            while True:
                response = self.client.get(status_url)
                response.raise_for_status()
                raw_state = response.json().get("status", {}).get("state")
                state = raw_state.lower() if isinstance(raw_state, str) else raw_state
                logger.info(f"Job {job_id} state: {raw_state}")
                if state == "completed":
                    return True
                if state in _IRI_FAILURE_STATES:
                    logger.error(f"Job {job_id} ended with state: {raw_state}")
                    return False
                if deadline is not None and time.monotonic() >= deadline:
                    logger.error(
                        f"Gave up waiting for job {job_id} after {timeout} seconds "
                        f"(last state: {raw_state})"
                    )
                    return False
                time.sleep(poll_interval)

        raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}")

    def mkdir_remote(self, path: str) -> None:
        """Create a directory on Perlmutter.

        SFAPI runs ``mkdir -p`` on the path; IRIAPI posts to the filesystem
        ``mkdir`` endpoint with ``parents`` set to True.

        Args:
            path: Absolute path to create on Perlmutter.

        Raises:
            ValueError: If ``self.login_method`` is unrecognized.
            httpx.HTTPStatusError: If the IRI filesystem call fails.
        """
        if self.login_method is NERSCLoginMethod.SFAPI:
            perlmutter = self.client.compute(Machine.perlmutter)
            # Quote the path so spaces or shell metacharacters in it cannot
            # split or alter the remote command.
            perlmutter.run(f"mkdir -p {shlex.quote(path)}")
        elif self.login_method is NERSCLoginMethod.IRIAPI:
            response = self.client.post(
                f"/api/v1/filesystem/mkdir/{self.nersc_resources['perlmutter_login']}",
                json={"path": path, "parents": True},
            )
            response.raise_for_status()
        else:
            raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}")

    def read_remote_file(self, path: str) -> str:
        """Read a file on Perlmutter and return its contents as a string.

        SFAPI runs ``cat`` through the client's ``run`` call. IRIAPI uses the
        filesystem view endpoint, which may return a ``task_id`` that is then
        polled.

        Args:
            path: Absolute path to the file on Perlmutter.

        Returns:
            File contents as a string.

        Raises:
            ValueError: If ``self.login_method`` is unrecognized.
            RuntimeError: If the IRI read task ends in a failed state.
            TimeoutError: If the IRI read task does not complete within
                the local polling budget (~120 seconds).
        """
        if self.login_method is NERSCLoginMethod.SFAPI:
            perlmutter = self.client.compute(Machine.perlmutter)
            result = perlmutter.run(f"cat {shlex.quote(path)}")
            # The result may be a plain string or an object carrying the
            # text; try the known shapes before falling back to str().
            if isinstance(result, str):
                return result
            if hasattr(result, "output"):
                return result.output
            if hasattr(result, "stdout"):
                return result.stdout
            return str(result)

        if self.login_method is NERSCLoginMethod.IRIAPI:
            return self._read_remote_file_iriapi(path)

        raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}")

    def _read_remote_file_iriapi(self, path: str) -> str:
        """Read a remote file via the IRI filesystem-view endpoint.

        If the first response carries no ``task_id`` its text is returned
        directly; otherwise the task endpoint is polled until the task
        completes, fails, or the polling budget runs out.
        """
        response = self.client.get(
            f"/api/v1/filesystem/view/{self.nersc_resources['perlmutter_login']}",
            params={"path": path},
        )
        response.raise_for_status()
        task_id = response.json().get("task_id")
        if not task_id:
            return response.text

        for _ in range(_IRI_READ_POLL_ATTEMPTS):
            task_response = self.client.get(f"/api/v1/task/{task_id}")
            task_response.raise_for_status()
            task = task_response.json()
            status = task.get("status")
            if status == "completed":
                result = task.get("result", "")
                if not isinstance(result, dict):
                    return str(result)
                # The payload may be nested as result -> output -> content.
                output = result.get("output", result)
                if isinstance(output, dict):
                    return output.get("content", str(output))
                return str(output)
            if status == "failed":
                raise RuntimeError(
                    f"File read task {task_id} failed: {task.get('result')}"
                )
            time.sleep(_IRI_READ_POLL_SECONDS)

        raise TimeoutError(f"File read task {task_id} did not complete")
# orchestration/jobs/nersc/login.py

"""NERSC API authentication and client construction.

Two login methods are supported, each backed by a different credential model
and pointed at a different API base URL:

- :attr:`NERSCLoginMethod.SFAPI`: Iris-registered OAuth2 client ID + private
  key (NERSC OIDC flow). Reads ``PATH_NERSC_CLIENT_ID`` and
  ``PATH_NERSC_PRI_KEY`` from the environment.
- :attr:`NERSCLoginMethod.IRIAPI`: Globus bearer token written by
  ``orchestration/globus/get_globus_token.py``. Reads ``PATH_GLOBUS_TOKEN_FILE``
  from the environment, falling back to ``~/.globus/auth_tokens.json``.

The base URLs for each method live in the beamline config under
``nersc_resources.{iri,sfapi}.api_base_url``.
"""

from enum import Enum
import json
import logging
import os
from pathlib import Path

from authlib.jose import JsonWebKey
from dotenv import load_dotenv
import httpx
from sfapi_client import Client

from orchestration.config import BeamlineConfig
from orchestration.globus.get_globus_token import (
    DEFAULT_TOKEN_FILE,
    get_iri_access_token,
)

logger = logging.getLogger(__name__)
load_dotenv()

# Env var pointing at the cached Globus token file used by IRIAPI auth.
_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE"


class NERSCLoginMethod(Enum):
    """Selects which NERSC API to authenticate against.

    Each method has its own credentials and API base URL — see module docstring.
    """

    SFAPI = "sfapi"
    """Standard Superfacility API via Iris-registered OAuth2 credentials."""

    IRIAPI = "iriapi"
    """Integrated Research Infrastructure API via Globus bearer token."""


def create_nersc_client(
    config: BeamlineConfig,
    login_method: NERSCLoginMethod = NERSCLoginMethod.IRIAPI,
) -> Client | httpx.Client:
    """Create and return a NERSC client for the requested login method.

    Reads the API base URL from ``config.nersc_resources["sfapi"]`` or
    ``config.nersc_resources["iri"]`` (note: the IRIAPI key is ``"iri"``, not
    the enum value ``"iriapi"``), then delegates to the matching private
    builder.

    Args:
        config: Beamline config instance. Must expose ``nersc_resources`` with
            ``"iri"`` and ``"sfapi"`` sub-dicts each containing ``api_base_url``.
        login_method: Which NERSC API to authenticate against. Defaults to
            :attr:`NERSCLoginMethod.IRIAPI`.

    Returns:
        An authenticated client — :class:`sfapi_client.Client` for SFAPI,
        :class:`httpx.Client` for IRIAPI.

    Raises:
        ValueError: If SFAPI credential env vars are unset, or if
            ``login_method`` is not a recognized member.
        FileNotFoundError: If SFAPI credential files are absent.
    """
    logger.info(f"Creating NERSC client using login method: {login_method.value}")

    if login_method is NERSCLoginMethod.SFAPI:
        # NOTE(review): for SFAPI the configured base URL is only logged; the
        # sfapi_client.Client below is built without it — confirm intended.
        api_base_url = config.nersc_resources["sfapi"]["api_base_url"]
        client = _create_sfapi_client()
    elif login_method is NERSCLoginMethod.IRIAPI:
        api_base_url = config.nersc_resources["iri"]["api_base_url"]
        client = _create_iriapi_client(api_base_url)
    else:
        raise ValueError(f"Unhandled NERSCLoginMethod: {login_method}")

    logger.info(
        f"NERSC client created successfully "
        f"(method={login_method.value}, api_url={api_base_url})."
    )
    return client


def _create_iriapi_client(api_base_url: str) -> httpx.Client:
    """Create a NERSC IRI API client using a Globus bearer token.

    The token file path comes from ``PATH_GLOBUS_TOKEN_FILE`` when set,
    otherwise ``DEFAULT_TOKEN_FILE``. The token itself is obtained from
    ``get_iri_access_token`` with ``force_login`` and ``prompt_login`` both
    disabled.

    Args:
        api_base_url: Base URL for the NERSC IRI API.

    Returns:
        An :class:`httpx.Client` targeting the IRI API, with the bearer token
        set in its ``Authorization`` header.
    """
    token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV)
    token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE

    access_token = get_iri_access_token(
        token_file=token_file,
        force_login=False,
        prompt_login=False,
    )

    return httpx.Client(
        base_url=api_base_url,
        headers={"Authorization": f"Bearer {access_token}"},
        timeout=httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0),
    )


def _create_sfapi_client() -> Client:
    """Create a NERSC SFAPI client from Iris-registered OAuth2 credentials.

    Reads the client ID and private key paths from ``PATH_NERSC_CLIENT_ID``
    and ``PATH_NERSC_PRI_KEY``. When generating the SFAPI key in Iris, the
    "asldev" user must be selected so the key has the necessary data-access
    permissions.

    Returns:
        An authenticated :class:`sfapi_client.Client`.

    Raises:
        ValueError: If the credential env vars are unset, or the client ID
            file is empty.
        FileNotFoundError: If the credential files are absent.
    """
    client_id_path = os.getenv("PATH_NERSC_CLIENT_ID")
    client_secret_path = os.getenv("PATH_NERSC_PRI_KEY")

    if not client_id_path or not client_secret_path:
        logger.error("NERSC credentials paths are missing.")
        raise ValueError("Missing NERSC credentials paths.")
    if not os.path.isfile(client_id_path) or not os.path.isfile(client_secret_path):
        logger.error("NERSC credential files are missing.")
        raise FileNotFoundError("NERSC credential files are missing.")

    with open(client_id_path, "r", encoding="utf-8") as f:
        # Strip surrounding whitespace: a trailing newline in the file would
        # otherwise become part of the OAuth client ID.
        client_id = f.read().strip()
    if not client_id:
        logger.error("NERSC client ID file is empty.")
        raise ValueError("NERSC client ID file is empty.")

    with open(client_secret_path, "r", encoding="utf-8") as f:
        client_secret = JsonWebKey.import_key(json.load(f))

    try:
        client = Client(client_id, client_secret)
    except Exception as e:
        # Log for context, then re-raise unchanged so callers see the
        # original failure.
        logger.error(f"Failed to create NERSC client: {e}")
        raise
    logger.info("NERSC client created successfully.")
    return client
# orchestration/jobs/nersc/shifter.py

"""Shifter container image cache management at NERSC.

Shifter is NERSC's container runtime. Images pulled into the per-system cache
are available to all subsequent jobs via ``shifter --image=<image>`` without
a per-job pull penalty. These helpers manage that cache.

Both public functions take a :class:`NERSCJobController` so they can use its
``submit_job`` / ``wait_for_job`` / ``read_remote_file`` primitives. This
keeps the controller focused on the submit/wait/filesystem core and avoids
ballooning it with one-off Shifter operations.
"""

import logging
import time
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from orchestration.jobs.nersc.controller import NERSCJobController

logger = logging.getLogger(__name__)


def _shifter_pull_script(
    image: str,
    log_dir: str,
    *,
    account: str,
    qos: str = "debug",
    walltime: str = "0:15:00",
) -> str:
    """Build the Slurm script that pulls a Shifter image.

    Args:
        image: The container image to pull.
        log_dir: Directory that receives the Slurm output and error logs.
        account: Slurm account to charge the pull job to.
        qos: Slurm QoS for the pull job.
        walltime: Time limit for the pull job.

    Returns:
        The Slurm batch script as a string.
    """
    return f"""#!/bin/bash
#SBATCH -q {qos}
#SBATCH -A {account}
#SBATCH -C cpu
#SBATCH --job-name=shifter_pull
#SBATCH --output={log_dir}/shifter_pull_%j.out
#SBATCH --error={log_dir}/shifter_pull_%j.err
#SBATCH -N 1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=1
#SBATCH --time={walltime}

echo "Starting Shifter image pull at $(date)"
echo "Image: {image}"

echo "Checking existing images..."
shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/')" || true

echo "Pulling image..."
shifterimg -v pull {image}
PULL_STATUS=$?

if [ $PULL_STATUS -eq 0 ]; then
    echo "Image pull successful"
else
    echo "Image pull failed with status $PULL_STATUS"
    exit 1
fi

echo "Verifying image..."
shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/')"

echo "Completed at $(date)"
"""


def _shifter_check_script(
    image: str,
    output_file: str,
    *,
    account: str,
    qos: str = "debug",
    walltime: str = "0:05:00",
) -> str:
    """Build the Slurm script that writes Shifter cache state to a file.

    The script greps ``shifterimg images`` for the image and redirects the
    result to ``output_file``; ``|| true`` keeps the job from failing when
    grep finds nothing, so an empty file means "not cached".

    Args:
        image: The container image to look for in the Shifter cache.
        output_file: Remote file the grep output is written to.
        account: Slurm account to charge the check job to.
        qos: Slurm QoS for the check job.
        walltime: Time limit for the check job.

    Returns:
        The Slurm batch script as a string.
    """
    return f"""#!/bin/bash
#SBATCH -q {qos}
#SBATCH -A {account}
#SBATCH -C cpu
#SBATCH -N 1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=1
#SBATCH --time={walltime}
shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/g')" > {output_file} 2>&1 || true
"""


def _user_log_dir(controller: "NERSCJobController") -> str:
    """Return the per-user directory used for Shifter Slurm logs.

    The path has the form ``/pscratch/sd/<first letter>/<username>/shifter_logs``.

    Args:
        controller: Controller whose ``get_nersc_username()`` supplies the
            username.

    Returns:
        The log directory path.
    """
    username = controller.get_nersc_username()
    return f"/pscratch/sd/{username[0]}/{username}/shifter_logs"


def _run_output_text(result: object) -> str:
    """Extract the text produced by a remote ``run`` call.

    Accepts a plain string or an object exposing ``output`` or ``stdout``.
    An attribute that is present but empty is returned as an empty string;
    ``str(result)`` is used only when neither attribute exists.
    """
    if isinstance(result, str):
        return result
    for attribute in ("output", "stdout"):
        text = getattr(result, attribute, None)
        if text is not None:
            return str(text)
    return str(result)


def pull_shifter_image(
    controller: "NERSCJobController",
    image: str,
    wait: bool = True,
    account: str = "als",
) -> bool:
    """Pull a container image into NERSC's Shifter cache.

    Run this once when an image is updated, not before every job that uses it.

    Args:
        controller: A NERSC job controller used to submit and monitor the
            Slurm pull job.
        image: Container image to pull, e.g.
            ``"docker:ghcr.io/als-computing/tomopy-multinode:latest"``.
        wait: If True, block until the pull job finishes and return its
            success state. If False, return True as soon as the job is
            submitted.
        account: Slurm account to charge the pull job to. Defaults to
            ``"als"``; pass a different value for other beamlines.

    Returns:
        ``True`` if the pull succeeded (or was submitted, when ``wait=False``),
        ``False`` if preparing the log directory, submitting the job, or the
        pull itself failed.
    """
    logger.info(f"Pulling Shifter image: {image}")

    try:
        # Username lookup and directory creation can fail too (missing
        # NERSC_USERNAME, API error); keep them inside the try so every
        # failure becomes the documented False return.
        log_dir = _user_log_dir(controller)
        controller.mkdir_remote(log_dir)

        job_script = _shifter_pull_script(image, log_dir, account=account)
        job_id = controller.submit_job(job_script)
        logger.info(f"Submitted Shifter pull job: {job_id}")

        if not wait:
            logger.info(f"Returning early; check status with job ID {job_id}")
            return True

        # NOTE(review): fixed 30 s pause before polling starts — presumably
        # to let the scheduler register the job; confirm.
        time.sleep(30)
        success = controller.wait_for_job(job_id)
        logger.info(f"Shifter image pull {'completed successfully' if success else 'failed'}.")
        return success

    except Exception as e:
        logger.error(f"Error during Shifter image pull: {e}")
        return False


def check_shifter_image(
    controller: "NERSCJobController",
    image: str,
    account: str = "als",
) -> bool:
    """Check whether a container image is already in NERSC's Shifter cache.

    Dispatches on the controller's login method. SFAPI runs ``shifterimg``
    directly through the client's ``run`` call; IRIAPI submits a Slurm job,
    waits for it, then reads the file the job wrote.

    Args:
        controller: A NERSC job controller used to query Shifter.
        image: Container image to check.
        account: Slurm account for the IRIAPI check job. Ignored for SFAPI.

    Returns:
        ``True`` if the image is present in the cache, ``False`` otherwise
        (including on error or when the IRIAPI check job does not succeed).
    """
    # Imported here rather than at module level so the script builders above
    # can be imported without loading the login module and its dependencies.
    from sfapi_client.compute import Machine

    from orchestration.jobs.nersc.login import NERSCLoginMethod

    logger.info(f"Checking Shifter cache for: {image}")

    try:
        if controller.login_method is NERSCLoginMethod.SFAPI:
            # Synchronous: run shifterimg directly via the utilities endpoint
            perlmutter = controller.client.compute(Machine.perlmutter)
            result = perlmutter.run(
                f"shifterimg images | grep -E \"$(echo {image} | sed 's/:/.*/g')\""
            )
            output = _run_output_text(result)

        elif controller.login_method is NERSCLoginMethod.IRIAPI:
            # Async: submit a one-off job, wait, read the captured output file
            log_dir = _user_log_dir(controller)
            controller.mkdir_remote(log_dir)
            output_file = f"{log_dir}/shifter_check_{int(time.time())}.txt"

            job_script = _shifter_check_script(image, output_file, account=account)
            job_id = controller.submit_job(job_script)
            if not controller.wait_for_job(job_id):
                # A failed check job leaves no trustworthy output to read.
                logger.warning(f"Shifter check job {job_id} did not complete successfully.")
                return False
            output = controller.read_remote_file(output_file)

        else:
            raise ValueError(f"Unhandled NERSCLoginMethod: {controller.login_method}")

        if output.strip():
            logger.info(f"Image found in Shifter cache: {output.strip()}")
            return True
        logger.info(f"Image not found in Shifter cache: {image}")
        return False

    except Exception as e:
        logger.warning(f"Error checking Shifter cache: {e}")
        return False
"""Three-layer job option resolution: config defaults → MLflow → Prefect Variable.

Lets beamline-specific code declare a base settings dict, optionally tie it to
an MLflow Model Registry entry, and accept runtime overrides via a Prefect
Variable — all without hardcoding any one beamline or HPC target.
"""

import json
import logging
from typing import Any

from prefect.variables import Variable

from orchestration.config import BeamlineConfig
from orchestration.mlflow import get_checkpoint_info

logger = logging.getLogger(__name__)


def load_job_options(
    variable_name: str,
    config_settings: dict[str, Any],
    config: BeamlineConfig | None = None,
    mlflow_model_name: str | None = None,
    mlflow_checkpoint_key: str | None = None,
) -> dict[str, Any]:
    """Load job options with three-layer resolution: config → MLflow → Prefect Variable.

    Resolution order (later layers win):

    1. ``config_settings`` — base defaults, copied so the caller's dict is
       never mutated.
    2. MLflow Model Registry — only when both ``config`` and
       ``mlflow_model_name`` are given. Every entry of the checkpoint's
       ``inference_params`` is written into the options (overwriting
       existing keys and adding new ones). ``nersc_path`` is additionally
       stored under ``mlflow_checkpoint_key`` if that is given. Any failure
       in this layer is logged and ignored.
    3. Prefect Variable (``variable_name``) — skipped if it cannot be loaded,
       is not a mapping, or has a truthy ``defaults`` entry (missing counts
       as truthy). Otherwise every key except ``defaults`` overrides the
       lower layers.

    Args:
        variable_name: Name of the Prefect Variable to load.
        config_settings: Settings dict used as base defaults (e.g.
            ``config.nersc_segment_sam3_settings``).
        config: Beamline config instance needed for MLflow lookup. If ``None``,
            the MLflow layer is skipped.
        mlflow_model_name: Registered MLflow model name, e.g. ``'sam3-petiole'``.
            If ``None`` or empty, the MLflow layer is skipped.
        mlflow_checkpoint_key: Config key to populate from the MLflow model's
            ``nersc_path``, e.g. ``'finetuned_checkpoint_path'``.

    Returns:
        Resolved options dict ready for use by the caller.
    """
    # ── Layer 1: config defaults ──────────────────────────────────────────────
    opts = dict(config_settings)

    # ── Layer 2: MLflow registry ──────────────────────────────────────────────
    if config is not None and mlflow_model_name:
        try:
            checkpoint_info = get_checkpoint_info(mlflow_model_name, config)
            if checkpoint_info:
                # Map nersc_path to the caller-specified checkpoint key
                if mlflow_checkpoint_key:
                    opts[mlflow_checkpoint_key] = checkpoint_info.nersc_path
                    logger.info(
                        f"MLflow '{mlflow_model_name}': "
                        f"{mlflow_checkpoint_key}={checkpoint_info.nersc_path}"
                    )
                # Every inference param is applied. Keys already in opts are
                # reported as overlaid and new ones (e.g. alcf_path for
                # future use) as injected, so the log shows what changed.
                params = checkpoint_info.inference_params
                overlaid = [key for key in params if key in opts]
                injected = [key for key in params if key not in opts]
                opts.update(params)
                logger.info(
                    f"MLflow '{mlflow_model_name}': overlaid params: {overlaid}; "
                    f"injected new params: {injected}"
                )
            else:
                logger.info(
                    f"MLflow: no production checkpoint for '{mlflow_model_name}', "
                    "using config defaults."
                )
        except Exception as e:
            # Best-effort layer: a registry outage must not block job
            # submission, so fall back to what has been resolved so far.
            logger.warning(
                f"MLflow lookup failed for '{mlflow_model_name}': {e}. "
                "Using config defaults."
            )

    # ── Layer 3: Prefect Variable overrides ───────────────────────────────────
    try:
        options = Variable.get(variable_name, default={"defaults": True}, _sync=True)
        if isinstance(options, str):
            options = json.loads(options)
    except Exception as e:
        logger.warning(f"Could not load '{variable_name}': {e}. Skipping variable overrides.")
        return opts

    if not isinstance(options, dict):
        # A list, number or null cannot carry key/value overrides; ignore it
        # instead of failing on ``.get`` below.
        logger.warning(
            f"Prefect Variable '{variable_name}' is not a mapping "
            f"(got {type(options).__name__}). Skipping variable overrides."
        )
        return opts

    if options.get("defaults", True):
        logger.info(f"Prefect Variable '{variable_name}': no overrides.")
        return opts

    overrides = {k: v for k, v in options.items() if k != "defaults"}
    logger.info(f"Prefect Variable '{variable_name}': applying overrides: {list(overrides)}")
    return {**opts, **overrides}
+246,12 @@ def test_mlflow_nersc_path_mapped_to_checkpoint_key(self, mocker, mock_config832 inference_params={}, ) mocker.patch( - "orchestration.flows.bl832.nersc.get_checkpoint_info", + "orchestration.jobs.options.get_checkpoint_info", return_value=checkpoint_info, ) base_settings = dict(mock_config832.nersc_segment_sam3_settings) - opts = _load_job_options( + opts = load_job_options( "nersc-segmentation-options", base_settings, config=mock_config832, @@ -263,7 +263,7 @@ def test_mlflow_nersc_path_mapped_to_checkpoint_key(self, mocker, mock_config832 def test_mlflow_inference_params_overlay_config_defaults(self, mocker, mock_config832): """inference_params from MLflow overwrite matching config keys.""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options from orchestration.mlflow import ModelCheckpointInfo self._patch_variable_defaults(mocker) @@ -281,12 +281,12 @@ def test_mlflow_inference_params_overlay_config_defaults(self, mocker, mock_conf }, ) mocker.patch( - "orchestration.flows.bl832.nersc.get_checkpoint_info", + "orchestration.jobs.options.get_checkpoint_info", return_value=checkpoint_info, ) base_settings = dict(mock_config832.nersc_segment_sam3_settings) - opts = _load_job_options( + opts = load_job_options( "nersc-segmentation-options", base_settings, config=mock_config832, @@ -300,13 +300,13 @@ def test_mlflow_inference_params_overlay_config_defaults(self, mocker, mock_conf def test_mlflow_layer_skipped_when_config_is_none(self, mocker, mock_config832): """Passing config=None skips the MLflow layer entirely.""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options self._patch_variable_defaults(mocker) - spy = mocker.patch("orchestration.flows.bl832.nersc.get_checkpoint_info") + spy = mocker.patch("orchestration.jobs.options.get_checkpoint_info") base_settings = dict(mock_config832.nersc_segment_sam3_settings) - opts 
= _load_job_options( + opts = load_job_options( "nersc-segmentation-options", base_settings, config=None, @@ -320,13 +320,13 @@ def test_mlflow_layer_skipped_when_config_is_none(self, mocker, mock_config832): def test_mlflow_layer_skipped_when_model_name_is_none(self, mocker, mock_config832): """Passing mlflow_model_name=None skips the MLflow layer.""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options self._patch_variable_defaults(mocker) - spy = mocker.patch("orchestration.flows.bl832.nersc.get_checkpoint_info") + spy = mocker.patch("orchestration.jobs.options.get_checkpoint_info") base_settings = dict(mock_config832.nersc_segment_sam3_settings) - _load_job_options( + load_job_options( "nersc-segmentation-options", base_settings, config=mock_config832, @@ -337,16 +337,16 @@ def test_mlflow_layer_skipped_when_model_name_is_none(self, mocker, mock_config8 def test_config_defaults_used_when_mlflow_returns_none(self, mocker, mock_config832): """get_checkpoint_info returning None → config defaults unchanged.""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options self._patch_variable_defaults(mocker) mocker.patch( - "orchestration.flows.bl832.nersc.get_checkpoint_info", + "orchestration.jobs.options.get_checkpoint_info", return_value=None, ) base_settings = dict(mock_config832.nersc_segment_sam3_settings) - opts = _load_job_options( + opts = load_job_options( "nersc-segmentation-options", base_settings, config=mock_config832, @@ -358,16 +358,16 @@ def test_config_defaults_used_when_mlflow_returns_none(self, mocker, mock_config def test_config_defaults_used_when_mlflow_raises(self, mocker, mock_config832): """An exception from get_checkpoint_info is caught; config defaults are used.""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options 
self._patch_variable_defaults(mocker) mocker.patch( - "orchestration.flows.bl832.nersc.get_checkpoint_info", + "orchestration.jobs.options.get_checkpoint_info", side_effect=RuntimeError("Network timeout"), ) base_settings = dict(mock_config832.nersc_segment_sam3_settings) - opts = _load_job_options( + opts = load_job_options( "nersc-segmentation-options", base_settings, config=mock_config832, @@ -379,12 +379,12 @@ def test_config_defaults_used_when_mlflow_raises(self, mocker, mock_config832): def test_prefect_variable_wins_over_mlflow(self, mocker, mock_config832): """Prefect Variable overrides take priority over MLflow inference params (layer 3 > layer 2).""" - from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.jobs.options import load_job_options from orchestration.mlflow import ModelCheckpointInfo # MLflow says batch_size=8; Prefect Variable says batch_size=16 → 16 wins mocker.patch( - "orchestration.flows.bl832.nersc.Variable.get", + "orchestration.jobs.options.Variable.get", return_value={"defaults": False, "batch_size": 16}, ) @@ -397,12 +397,12 @@ def test_prefect_variable_wins_over_mlflow(self, mocker, mock_config832): inference_params={"batch_size": 8}, ) mocker.patch( - "orchestration.flows.bl832.nersc.get_checkpoint_info", + "orchestration.jobs.options.get_checkpoint_info", return_value=checkpoint_info, ) base_settings = dict(mock_config832.nersc_segment_sam3_settings) - opts = _load_job_options( + opts = load_job_options( "nersc-segmentation-options", base_settings, config=mock_config832, @@ -437,7 +437,7 @@ def test_mlflow_checkpoint_appears_in_job_script(self, mocker, mock_config832): resolved_settings["finetuned_checkpoint_path"] = mlflow_checkpoint mocker.patch( - "orchestration.flows.bl832.nersc._load_job_options", + "orchestration.flows.bl832.nersc.load_job_options", return_value=resolved_settings, ) @@ -455,16 +455,16 @@ def capture_script(script, *args, **kwargs): config=mock_config832, 
login_method=NERSCLoginMethod.IRIAPI, ) - mocker.patch.object(controller, "_submit_job", side_effect=capture_script) - mocker.patch.object(controller, "_wait_for_job", return_value=True) - mocker.patch.object(controller, "_mkdir_remote", return_value=None) + mocker.patch.object(controller, "submit_job", side_effect=capture_script) + mocker.patch.object(controller, "wait_for_job", return_value=True) + mocker.patch.object(controller, "mkdir_remote", return_value=None) mocker.patch.object(controller, "_fetch_seg_timing_from_output", return_value=None) # _get_nersc_username reads NERSC_USERNAME for IRIAPI; stub it - mocker.patch.object(controller, "_get_nersc_username", return_value="testuser") + mocker.patch.object(controller, "get_nersc_username", return_value="testuser") result = controller.segmentation_sam3(recon_folder_path="folder/recfile") - assert captured, "_submit_job was never called" + assert captured, "submit_job was never called" assert mlflow_checkpoint in captured[0], ( "The MLflow checkpoint path must appear in the SLURM job script" ) @@ -475,11 +475,11 @@ def test_config_default_checkpoint_used_when_mlflow_unavailable(self, mocker, mo mocker.patch("orchestration.flows.bl832.nersc.time.sleep") mocker.patch( - "orchestration.flows.bl832.nersc.Variable.get", + "orchestration.jobs.options.Variable.get", return_value={"defaults": True}, ) mocker.patch( - "orchestration.flows.bl832.nersc.get_checkpoint_info", + "orchestration.jobs.options.get_checkpoint_info", return_value=None, ) @@ -494,16 +494,16 @@ def capture_script(script, *args, **kwargs): config=mock_config832, login_method=NERSCLoginMethod.IRIAPI, ) - mocker.patch.object(controller, "_submit_job", side_effect=capture_script) - mocker.patch.object(controller, "_wait_for_job", return_value=True) - mocker.patch.object(controller, "_mkdir_remote", return_value=None) + mocker.patch.object(controller, "submit_job", side_effect=capture_script) + mocker.patch.object(controller, "wait_for_job", 
return_value=True) + mocker.patch.object(controller, "mkdir_remote", return_value=None) mocker.patch.object(controller, "_fetch_seg_timing_from_output", return_value=None) - mocker.patch.object(controller, "_get_nersc_username", return_value="testuser") + mocker.patch.object(controller, "get_nersc_username", return_value="testuser") controller.segmentation_sam3(recon_folder_path="folder/recfile") config_default = mock_config832.nersc_segment_sam3_settings["finetuned_checkpoint_path"] - assert captured, "_submit_job was never called" + assert captured, "submit_job was never called" assert config_default in captured[0], ( "Config default checkpoint path must be used when MLflow is unavailable" ) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 8db8ae0f..97381a18 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -201,60 +201,6 @@ def mock_iriapi_client(mocker): return client -# --------------------------------------------------------------------------- -# _create_sfapi_client -# --------------------------------------------------------------------------- - - -def test_create_sfapi_client_success(mocker): - """Valid credentials produce a Client instance.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret", - }.get(x)) - mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) - mocker.patch( - "builtins.open", - side_effect=[ - mocker.mock_open(read_data="my-client-id")(), - mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), - ] - ) - mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") - mock_client_cls = 
mocker.patch("orchestration.flows.bl832.nersc.Client") - - client = NERSCTomographyHPCController._create_sfapi_client() - - mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") - assert client is mock_client_cls.return_value - - -def test_create_sfapi_client_missing_paths(mocker): - """Unset env vars raise ValueError.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) - - with pytest.raises(ValueError, match="Missing NERSC credentials paths."): - NERSCTomographyHPCController._create_sfapi_client() - - -def test_create_sfapi_client_missing_files(mocker): - """Env vars set but files absent raise FileNotFoundError.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret", - }.get(x)) - mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - - with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): - NERSCTomographyHPCController._create_sfapi_client() - - # ────────────────────────────────────────────────────────────────────────────── # build_multi_resolution # ────────────────────────────────────────────────────────────────────────────── @@ -304,7 +250,7 @@ def test_segmentation_sam3_success(mocker, mock_sfapi_client, mock_config832): from sfapi_client.compute import Machine mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) controller = NERSCTomographyHPCController( client=mock_sfapi_client, config=mock_config832, @@ -326,7 +272,7 @@ def test_segmentation_sam3_submission_failure(mocker, 
mock_sfapi_client, mock_co from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("GPU queue full") controller = NERSCTomographyHPCController( client=mock_sfapi_client, @@ -346,7 +292,7 @@ def test_segmentation_sam3_uses_variable_options(mocker, mock_sfapi_client, mock from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={ + mocker.patch("orchestration.jobs.options.Variable.get", return_value={ "defaults": False, "batch_size": 8, "patch_size": 512, @@ -395,7 +341,7 @@ def test_segmentation_dinov3_success(mocker, mock_sfapi_client, mock_config832): from sfapi_client.compute import Machine mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) controller = NERSCTomographyHPCController( client=mock_sfapi_client, config=mock_config832, @@ -414,7 +360,7 @@ def test_segmentation_dinov3_submission_failure(mocker, mock_sfapi_client, mock_ from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("No GPU nodes") 
controller = NERSCTomographyHPCController( client=mock_sfapi_client, @@ -605,7 +551,7 @@ def test_segmentation_dinov3_output_paths(mocker, mock_sfapi_client, mock_config from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) captured_scripts = [] original_return = mock_sfapi_client.compute.return_value.submit_job.return_value @@ -636,7 +582,7 @@ def test_combine_segmentations_success(mocker, mock_sfapi_client, mock_config832 from sfapi_client.compute import Machine mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) controller = NERSCTomographyHPCController( client=mock_sfapi_client, config=mock_config832, @@ -655,7 +601,7 @@ def test_combine_segmentations_submission_failure(mocker, mock_sfapi_client, moc from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Cluster down") controller = NERSCTomographyHPCController( client=mock_sfapi_client, @@ -673,7 +619,7 @@ def test_combine_segmentations_script_references_sam3_and_dino(mocker, mock_sfap from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - 
mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) captured_scripts = [] original_return = mock_sfapi_client.compute.return_value.submit_job.return_value diff --git a/orchestration/_tests/test_sfapi_flow.py b/orchestration/_tests/test_sfapi_flow.py deleted file mode 100644 index e0d4a854..00000000 --- a/orchestration/_tests/test_sfapi_flow.py +++ /dev/null @@ -1,63 +0,0 @@ -# orchestration/_tests/test_sfapi_flow.py -import pytest -from uuid import uuid4 - -from prefect.blocks.system import Secret -from prefect.testing.utilities import prefect_test_harness - - -@pytest.fixture(autouse=True, scope="session") -def prefect_test_fixture(): - with prefect_test_harness(): - Secret(value=str(uuid4())).save(name="globus-client-id", overwrite=True) - Secret(value=str(uuid4())).save(name="globus-client-secret", overwrite=True) - yield - - -def test_create_sfapi_client_success(mocker): - """Valid credentials produce a Client instance.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret", - }.get(x)) - mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) - mocker.patch( - "builtins.open", - side_effect=[ - mocker.mock_open(read_data="my-client-id")(), - mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), - ] - ) - mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") - mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") - - client = NERSCTomographyHPCController._create_sfapi_client() - - mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") - assert client is mock_client_cls.return_value - - -def 
test_create_sfapi_client_missing_paths(mocker): - """Unset env vars raise ValueError.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) - - with pytest.raises(ValueError, match="Missing NERSC credentials paths."): - NERSCTomographyHPCController._create_sfapi_client() - - -def test_create_sfapi_client_missing_files(mocker): - """Env vars set but files absent raise FileNotFoundError.""" - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret", - }.get(x)) - mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - - with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): - NERSCTomographyHPCController._create_sfapi_client() diff --git a/orchestration/flows/bl832/alcf.py b/orchestration/flows/bl832/alcf.py index 28ac7813..326a1b00 100644 --- a/orchestration/flows/bl832/alcf.py +++ b/orchestration/flows/bl832/alcf.py @@ -1,45 +1,33 @@ -from concurrent.futures import Future import datetime +import logging from pathlib import Path -import time from typing import Optional -from globus_compute_sdk import Client, Executor -from globus_compute_sdk.serialize import CombinedCode from prefect import flow, task, get_run_logger -from prefect.blocks.system import Secret from prefect.variables import Variable from orchestration.flows.bl832.config import Config832 from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController +from orchestration.jobs.alcf.controller import ALCFJobController from orchestration.transfer_controller import get_transfer_controller, CopyMethod from orchestration.prefect import schedule_prefect_flow from orchestration.tiled import register_file_to_tiled +logger = 
logging.getLogger(__name__) -class ALCFTomographyHPCController(TomographyHPCController): - """ - Implementation of TomographyHPCController for ALCF. Methods here leverage Globus Compute for processing tasks. - There is a @staticmethod wrapper for each compute task submitted via Globus Compute. - Also, there is a shared wait_for_globus_compute_future method that waits for the task to complete. - Args: - TomographyHPCController (ABC): Abstract class for tomography HPC controllers. +class ALCFTomographyHPCController(TomographyHPCController, ALCFJobController): + """ALCF tomography HPC controller for BL832. + + Submits reconstruction and multi-resolution jobs to ALCF via Globus Compute. + Beamline-agnostic submit/wait primitives are inherited from ALCFJobController. """ - def __init__( - self, - config: Config832 - ) -> None: - super().__init__(config) - # Load allocation root from the Prefect JSON block - # The block must be registered with the name "alcf-allocation-root-path" - logger = get_run_logger() - allocation_data = Variable.get("alcf-allocation-root-path", _sync=True) - self.allocation_root = allocation_data.get("alcf-allocation-root-path") - if not self.allocation_root: - raise ValueError("Allocation root not found in JSON block 'alcf-allocation-root-path'") - logger.info(f"Allocation root loaded: {self.allocation_root}") + def __init__(self, config: Config832) -> None: + # ALCF doesn't take a pre-built client — Globus Compute Client is + # constructed inside ALCFJobController.submit() per-call. + ALCFJobController.__init__(self, config) + TomographyHPCController.__init__(self, config) def reconstruct( self, @@ -54,26 +42,22 @@ def reconstruct( Returns: bool: True if the task completed successfully, False otherwise. 
""" - logger = get_run_logger() + run_logger = get_run_logger() file_name = Path(file_path).stem + ".h5" folder_name = Path(file_path).parent.name iri_als_bl832_rundir = f"{self.allocation_root}/data/raw" iri_als_bl832_recon_script = f"{self.allocation_root}/scripts/globus_reconstruction.py" - gcc = Client(code_serialization_strategy=CombinedCode()) - - with Executor(endpoint_id=Secret.load("globus-compute-endpoint").get(), client=gcc) as fxe: - logger.info(f"Running Tomopy reconstruction on {file_name} at ALCF") - future = fxe.submit( - self._reconstruct_wrapper, - iri_als_bl832_rundir, - iri_als_bl832_recon_script, - file_name, - folder_name - ) - result = self._wait_for_globus_compute_future(future, "reconstruction", check_interval=10) - return result + run_logger.info(f"Running Tomopy reconstruction on {file_name} at ALCF") + future = self.submit( + self._reconstruct_wrapper, + iri_als_bl832_rundir, + iri_als_bl832_recon_script, + file_name, + folder_name, + ) + return self.wait_for_future(future, "reconstruction", check_interval=10) @staticmethod def _reconstruct_wrapper( @@ -129,7 +113,7 @@ def build_multi_resolution( Returns: bool: True if the task completed successfully, False otherwise. 
""" - logger = get_run_logger() + run_logger = get_run_logger() file_name = Path(file_path).stem folder_name = Path(file_path).parent.name @@ -140,19 +124,15 @@ def build_multi_resolution( iri_als_bl832_rundir = f"{self.allocation_root}/data/raw" iri_als_bl832_conversion_script = f"{self.allocation_root}/scripts/tiff_to_zarr.py" - gcc = Client(code_serialization_strategy=CombinedCode()) - - with Executor(endpoint_id=Secret.load("globus-compute-endpoint").get(), client=gcc) as fxe: - logger.info(f"Running Tiff to Zarr on {raw_path} at ALCF") - future = fxe.submit( - self._build_multi_resolution_wrapper, - iri_als_bl832_rundir, - iri_als_bl832_conversion_script, - tiff_scratch_path, - raw_path - ) - result = self._wait_for_globus_compute_future(future, "tiff to zarr conversion", check_interval=10) - return result + run_logger.info(f"Running Tiff to Zarr on {raw_path} at ALCF") + future = self.submit( + self._build_multi_resolution_wrapper, + iri_als_bl832_rundir, + iri_als_bl832_conversion_script, + tiff_scratch_path, + raw_path, + ) + return self.wait_for_future(future, "tiff to zarr conversion", check_interval=10) @staticmethod def _build_multi_resolution_wrapper( @@ -186,76 +166,6 @@ def _build_multi_resolution_wrapper( f"Converted tiff files to zarr;\n {zarr_res}" ) - @staticmethod - def _wait_for_globus_compute_future( - future: Future, - task_name: str, - check_interval: int = 20, - walltime: int = 1200 # seconds = 20 minutes - ) -> bool: - """ - Wait for a Globus Compute task to complete, assuming that if future.done() is False, the task is running. - - Args: - future: The future object returned from the Globus Compute Executor submit method. - task_name: A descriptive name for the task being executed (used for logging). - check_interval: The interval (in seconds) between status checks. - walltime: The maximum time (in seconds) to wait for the task to complete. - - Returns: - bool: True if the task completed successfully within walltime, False otherwise. 
- """ - logger = get_run_logger() - - start_time = time.time() - success = False - - try: - previous_state = None - while not future.done(): - elapsed_time = time.time() - start_time - if elapsed_time > walltime: - logger.error(f"The {task_name} task exceeded the walltime of {walltime} seconds." - "Cancelling the Globus Compute job.") - future.cancel() - return False - - # Check if the task was cancelled - if future.cancelled(): - logger.warning(f"The {task_name} task was cancelled.") - return False - # Assume the task is running if not done and not cancelled - elif previous_state != 'running': - logger.info(f"The {task_name} task is running...") - previous_state = 'running' - - time.sleep(check_interval) # Wait before the next status check - - # Task is done, check if it was cancelled or raised an exception - if future.cancelled(): - logger.warning(f"The {task_name} task was cancelled after completion.") - return False - - exception = future.exception() - if exception: - logger.error(f"The {task_name} task raised an exception: {exception}") - return False - - # Task completed successfully - result = future.result() - logger.info(f"The {task_name} task completed successfully with result: {result}") - success = True - - except Exception as e: - logger.error(f"An error occurred while waiting for the {task_name} task: {str(e)}") - success = False - - finally: - # Log the total time taken for the task - elapsed_time = time.time() - start_time - logger.info(f"Total duration of the {task_name} task: {elapsed_time:.2f} seconds.") - - return success @task(name="schedule_prune_task") diff --git a/orchestration/flows/bl832/job_controller.py b/orchestration/flows/bl832/job_controller.py index c526f72a..ddf3eff3 100644 --- a/orchestration/flows/bl832/job_controller.py +++ b/orchestration/flows/bl832/job_controller.py @@ -1,48 +1,36 @@ from abc import ABC, abstractmethod from dotenv import load_dotenv -from enum import Enum import logging -from orchestration.flows.bl832.config 
import Config832 +from orchestration.config import BeamlineConfig +from orchestration.jobs.controller import JobTarget +from orchestration.jobs.nersc.login import NERSCLoginMethod # noqa: F401 — re-exported for callers logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) load_dotenv() -class NERSCLoginMethod(Enum): - """Selects which NERSC API login method to use when creating a NERSC client. - - Each method corresponds to a different set of credentials and API base URL. - """ - - SFAPI = "sfapi" - """Standard Superfacility API via Iris-registered OAuth2 credentials.""" - - IRIAPI = "iriapi" - """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" +# HPC is a BL832-scoped alias for backwards compat. +# New code should use JobTarget from orchestration.jobs.controller. +# TODO: retire this alias after all call sites are migrated. +HPC = JobTarget class TomographyHPCController(ABC): - """ - Abstract class for tomography HPC controllers. - Provides interface methods for reconstruction and building multi-resolution datasets. + """Abstract class for tomography HPC controllers. - Args: - ABC: Abstract Base Class + Provides interface methods for reconstruction and building multi-resolution + datasets. Stays in bl832/job_controller.py because reconstruct and + build_multi_resolution are tomography-specific, not generic infrastructure. """ - def __init__( - self, - config: Config832 - ) -> None: + + def __init__(self, config: BeamlineConfig) -> None: self.config = config @abstractmethod - def reconstruct( - self, - file_path: str = "", - ) -> bool: - """Perform tomography reconstruction + def reconstruct(self, file_path: str = "") -> bool: + """Perform tomography reconstruction. :param file_path: Path to the file to reconstruct. :return: True if successful, False otherwise. 
@@ -50,11 +38,8 @@ def reconstruct( pass @abstractmethod - def build_multi_resolution( - self, - file_path: str = "", - ) -> bool: - """Generate multi-resolution version of reconstructed tomography + def build_multi_resolution(self, file_path: str = "") -> bool: + """Generate multi-resolution version of reconstructed tomography. :param file_path: Path to the file for which to build multi-resolution data. :return: True if successful, False otherwise. @@ -62,31 +47,18 @@ def build_multi_resolution( pass -class HPC(Enum): - """ - Enum representing different HPC environments. - Use enum names as strings to identify HPC sites, ensuring a standard set of values. - - Members: - ALCF: Argonne Leadership Computing Facility - NERSC: National Energy Research Scientific Computing Center - """ - ALCF = "ALCF" - NERSC = "NERSC" - OLCF = "OLCF" - - def get_controller( hpc_type: HPC, - config: Config832, + config: BeamlineConfig, login_method: NERSCLoginMethod | None = None, ) -> TomographyHPCController: - """ - Factory function that returns an HPC controller instance for the given HPC environment. + """Factory: return the appropriate tomography HPC controller. - :param hpc_type: A string identifying the HPC environment (e.g., 'ALCF', 'NERSC'). - :return: An instance of a TomographyHPCController subclass corresponding to the given HPC environment. - :raises ValueError: If an invalid or unsupported HPC type is specified. + :param hpc_type: Target HPC site (use HPC or JobTarget enum). + :param config: Beamline configuration object. + :param login_method: NERSC-only; which API to authenticate against. + :return: A TomographyHPCController subclass instance. + :raises ValueError: If hpc_type is invalid or config is missing. 
""" if not isinstance(hpc_type, HPC): raise ValueError(f"Invalid HPC type provided: {hpc_type}") @@ -96,42 +68,25 @@ def get_controller( if hpc_type == HPC.ALCF: from orchestration.flows.bl832.alcf import ALCFTomographyHPCController - return ALCFTomographyHPCController( - config=config - ) + return ALCFTomographyHPCController(config=config) + elif hpc_type == HPC.NERSC: - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod - resolved_login_method = login_method if isinstance(login_method, NERSCLoginMethod) else NERSCLoginMethod.SFAPI + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.jobs.nersc.login import create_nersc_client + resolved_login_method = ( + login_method if isinstance(login_method, NERSCLoginMethod) + else NERSCLoginMethod.SFAPI + ) + client = create_nersc_client(config, resolved_login_method) return NERSCTomographyHPCController( - client=NERSCTomographyHPCController.create_nersc_client( - config=config, - login_method=resolved_login_method - ), config=config, + client=client, login_method=resolved_login_method, ) + elif hpc_type == HPC.OLCF: # TODO: Implement OLCF controller - pass + raise NotImplementedError("OLCF controller not yet implemented") + else: raise ValueError(f"Unsupported HPC type: {hpc_type}") - - -def do_it_all() -> None: - controller = get_controller("ALCF") - controller.reconstruct() - controller.build_multi_resolution() - - file_path = "" - controller = get_controller("NERSC") - controller.reconstruct( - file_path=file_path, - ) - controller.build_multi_resolution( - file_path=file_path, - ) - - -if __name__ == "__main__": - do_it_all() - logger.info("Done.") diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 7a13bd18..5a1d8846 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -2,30 +2,25 @@ import datetime from dotenv import load_dotenv import httpx -import json 
import logging -import os from pathlib import Path import re import time -from authlib.jose import JsonWebKey from prefect import flow, get_run_logger, task from prefect.variables import Variable from sfapi_client import Client -from sfapi_client.compute import Machine from typing import Any, Optional from orchestration.flows.bl832.config import Config832 -from orchestration.flows.bl832.job_controller import get_controller, HPC, NERSCLoginMethod, TomographyHPCController +from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) -from orchestration.mlflow import get_checkpoint_info -from orchestration.globus.get_globus_token import ( - get_iri_access_token, - DEFAULT_TOKEN_FILE, -) +from orchestration.jobs.nersc.controller import NERSCJobController +from orchestration.jobs.nersc.login import NERSCLoginMethod +from orchestration.jobs.nersc.shifter import pull_shifter_image, check_shifter_image +from orchestration.jobs.options import load_job_options from orchestration.prefect import schedule_prefect_flow from orchestration.prune_controller import get_prune_controller, PruneMethod from orchestration.tiled import register_file_to_tiled @@ -36,15 +31,12 @@ logger.setLevel(logging.INFO) load_dotenv() -# Applies only to NERSCLoginMethod.IRIAPI -_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" - @dataclass class SegmentationModelSpec: """All config-resolution inputs for a single model+project combination. - Consumed by ``_load_job_options`` and the job-script builders. + Consumed by ``load_job_options`` and the job-script builders. Adding a new model or project means adding one entry to the registry — nothing else changes. @@ -53,8 +45,6 @@ class SegmentationModelSpec: :param mlflow_model_name: Registered MLflow model name. 
:param mlflow_checkpoint_key: Config key populated from the MLflow model's ``nersc_path`` tag. - :param output_subdir: Subdirectory written under ``seg_folder/``, - e.g. ``'dino'``, ``'sam3'``, ``'dino_moon'``. :param extra_cli_flags: Additional flags injected into the inference command, e.g. ``{'--project': 'moon'}``. Omit flags not needed. """ @@ -65,98 +55,12 @@ class SegmentationModelSpec: extra_cli_flags: dict[str, str] = field(default_factory=dict) -def _load_job_options( - variable_name: str, - config_settings: dict[str, Any], - config: Config832 | None = None, - mlflow_model_name: str | None = None, - mlflow_checkpoint_key: str | None = None, -) -> dict[str, Any]: - """Load job options with three-layer resolution: config → MLflow → Prefect Variable. - - Resolution order (later layers win): - - 1. ``config_settings`` — authoritative defaults from the config YAML. - 2. MLflow Model Registry — if ``mlflow_model_name`` is provided, all - ``inference_params`` tags are overlaid onto opts by their config key name. - ``nersc_path`` is additionally mapped to ``mlflow_checkpoint_key`` if given. - 3. Prefect Variable (``variable_name``) — skipped if absent or ``defaults: true``. - If ``defaults: false``, provided keys override all lower layers. +class NERSCTomographyHPCController(TomographyHPCController, NERSCJobController, NerscStreamingMixin): + """NERSC tomography HPC controller for BL832. - Args: - variable_name: Name of the Prefect Variable to load. - config_settings: Settings dict from Config832 used as base defaults. - config: Config832 instance needed for MLflow lookup. If ``None``, the - MLflow layer is skipped. - mlflow_model_name: Registered MLflow model name, e.g. ``'sam3-petiole'``. - If ``None``, the MLflow layer is skipped. - mlflow_checkpoint_key: Config key to populate from the MLflow model's - ``nersc_path`` tag, e.g. ``'finetuned_checkpoint_path'``. - - Returns: - Resolved options dict ready for use by the caller. 
- """ - # ── Layer 1: config defaults ────────────────────────────────────────────── - opts = dict(config_settings) - - # ── Layer 2: MLflow registry ────────────────────────────────────────────── - if config is not None and mlflow_model_name: - try: - checkpoint_info = get_checkpoint_info(mlflow_model_name, config) - if checkpoint_info: - # Map nersc_path to the caller-specified checkpoint key - if mlflow_checkpoint_key: - opts[mlflow_checkpoint_key] = checkpoint_info.nersc_path - logger.info( - f"MLflow '{mlflow_model_name}': " - f"{mlflow_checkpoint_key}={checkpoint_info.nersc_path}" - ) - # Overlay all inference params that match existing config keys - overlaid = [] - for k, v in checkpoint_info.inference_params.items(): - if k in opts: - opts[k] = v - overlaid.append(k) - else: - # Also inject new keys (e.g. alcf_path for future use) - opts[k] = v - logger.info( - f"MLflow '{mlflow_model_name}': overlaid params: {overlaid}" - ) - else: - logger.info( - f"MLflow: no production checkpoint for '{mlflow_model_name}', " - "using config defaults." - ) - except Exception as e: - logger.warning( - f"MLflow lookup failed for '{mlflow_model_name}': {e}. " - "Using config defaults." - ) - - # ── Layer 3: Prefect Variable overrides ─────────────────────────────────── - try: - options = Variable.get(variable_name, default={"defaults": True}, _sync=True) - if isinstance(options, str): - options = json.loads(options) - except Exception as e: - logger.warning(f"Could not load '{variable_name}': {e}. 
Skipping variable overrides.") - return opts - - if options.get("defaults", True): - logger.info(f"Prefect Variable '{variable_name}': no overrides.") - return opts - - overrides = {k: v for k, v in options.items() if k != "defaults"} - logger.info(f"Prefect Variable '{variable_name}': applying overrides: {list(overrides)}") - return {**opts, **overrides} - - -class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin): - """ - Implementation for a NERSC-based tomography HPC controller. - - Submits reconstruction and multi-resolution jobs to NERSC via SFAPI. + Submits reconstruction, multi-resolution, and segmentation jobs to NERSC + via the SFAPI or IRI API. Beamline-agnostic job primitives (submit, wait, + filesystem ops) are inherited from NERSCJobController. """ def __init__( @@ -165,6 +69,7 @@ def __init__( client: Client | httpx.Client | None = None, login_method: NERSCLoginMethod = NERSCLoginMethod.IRIAPI, ) -> None: + NERSCJobController.__init__(self, config, client, login_method) TomographyHPCController.__init__(self, config) self.client = client self.login_method = login_method @@ -175,119 +80,6 @@ def __init__( else: raise ValueError(f"Unsupported NERSCLoginMethod: {login_method}") - @staticmethod - def create_nersc_client( - config: Config832, - login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, - ) -> Client | httpx.Client: - """Create and return a NERSC client for the requested login method. - - Two fundamentally different auth strategies are supported: - - - :attr:`NERSCLoginMethod.SFAPI`: uses an Iris-registered OAuth2 - client ID + private key (NERSC OIDC flow). Set ``PATH_NERSC_CLIENT_ID`` - and ``PATH_NERSC_PRI_KEY`` to the paths of those files. - - - :attr:`NERSCLoginMethod.IRIAPI`: uses a Globus bearer token written - by ``globus_token.py``. Set ``PATH_GLOBUS_TOKEN_FILE`` to the token - file path, or rely on the default (``~/.globus/auth_tokens.json``). 
- - Args: - config: Config832 instance for accessing config settings needed during client creation. - login_method: Which NERSC API to authenticate against. - Defaults to :attr:`NERSCLoginMethod.IRIAPI`. - - Returns: - An authenticated :class:`sfapi_client.Client` instance. - - Raises: - ValueError: If SFAPI credential environment variables are unset. - FileNotFoundError: If credential or token files are absent. - RuntimeError: If the Globus token is expired. - Exception: If the underlying client construction fails. - """ - logger.info(f"Creating NERSC client using login method: {login_method.value}") - - if login_method is NERSCLoginMethod.SFAPI: - api_base_url = config.nersc_resources["sfapi"]["api_base_url"] - client = NERSCTomographyHPCController._create_sfapi_client() - - elif login_method is NERSCLoginMethod.IRIAPI: - api_base_url = config.nersc_resources["iri"]["api_base_url"] - client = NERSCTomographyHPCController._create_iriapi_client(api_base_url) - else: - raise ValueError(f"Unhandled NERSCLoginMethod: {login_method}") - - logger.info( - f"NERSC client created successfully " - f"(method={login_method.value}, api_url={api_base_url})." - ) - return client - - @staticmethod - def _create_iriapi_client(api_base_url: str) -> httpx.Client: - """Create a NERSC client for the IRI API using a Globus bearer token. - - Requires ``GLOBUS_CLIENT_ID`` and ``GLOBUS_CLIENT_SECRET`` in the - environment. Reuses a cached token if valid; otherwise mints a new one - via the client credentials grant. No browser or user interaction. - - Parameters: - api_base_url: The base URL for the NERSC IRI API - Returns: - An authenticated :class:`httpx.Client` targeting the IRI API. - - Raises: - ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. - RuntimeError: If the acquired token is missing required scopes. 
- """ - token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) - token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE - - access_token = get_iri_access_token( - token_file=token_file, - force_login=False, - prompt_login=False - ) - - return httpx.Client( - base_url=api_base_url, - headers={"Authorization": f"Bearer {access_token}"}, - timeout=httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0), - ) - - @staticmethod - def _create_sfapi_client() -> Client: - """Create and return an NERSC client instance""" - - # When generating the SFAPI Key in Iris, make sure to select "asldev" as the user! - # Otherwise, the key will not have the necessary permissions to access the data. - client_id_path = os.getenv("PATH_NERSC_CLIENT_ID") - client_secret_path = os.getenv("PATH_NERSC_PRI_KEY") - - if not client_id_path or not client_secret_path: - logger.error("NERSC credentials paths are missing.") - raise ValueError("Missing NERSC credentials paths.") - if not os.path.isfile(client_id_path) or not os.path.isfile(client_secret_path): - logger.error("NERSC credential files are missing.") - raise FileNotFoundError("NERSC credential files are missing.") - - client_id = None - client_secret = None - with open(client_id_path, "r") as f: - client_id = f.read() - - with open(client_secret_path, "r") as f: - client_secret = JsonWebKey.import_key(json.loads(f.read())) - - try: - client = Client(client_id, client_secret) - logger.info("NERSC client created successfully.") - return client - except Exception as e: - logger.error(f"Failed to create NERSC client: {e}") - raise e - def _get_segmentation_spec(self, model: str, project: str) -> SegmentationModelSpec: """Return the SegmentationModelSpec for a model+project combination. @@ -326,234 +118,6 @@ def _get_segmentation_spec(self, model: str, project: str) -> SegmentationModelS ) return registry[key] - def _get_nersc_username(self) -> str: - """Get the NERSC username for constructing pscratch paths. 
- - Uses the sfapi_client user endpoint for SFAPI, or reads - ``NERSC_USERNAME`` from the environment for IRIAPI. - - Returns: - NERSC username string. - - Raises: - ValueError: If IRIAPI is selected and NERSC_USERNAME is unset. - """ - if self.login_method is NERSCLoginMethod.SFAPI: - return self.client.user().name - else: - username = os.getenv("NERSC_USERNAME") - if not username: - raise ValueError( - "NERSC_USERNAME must be set in the environment when using IRIAPI." - ) - return username - - def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: - """Submit a Slurm job script and return the job ID. - - Dispatches to the appropriate submission mechanism based on - ``self.login_method``. - - Args: - job_script: The full Slurm batch script to submit. - num_nodes: The number of nodes to request for the job. - - Returns: - The submitted job ID as a string. - - Raises: - RuntimeError: If job submission fails. - """ - if self.login_method is NERSCLoginMethod.SFAPI: - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - return str(job.jobid) - - elif self.login_method is NERSCLoginMethod.IRIAPI: - sbatch_values = {} - for line in job_script.splitlines(): - if line.startswith("#SBATCH"): - if "-q " in line: - sbatch_values["queue_name"] = line.split("-q ")[-1].strip() - elif "-A " in line: - sbatch_values["account"] = line.split("-A ")[-1].strip() - elif "--time=" in line: - t = line.split("--time=")[-1].strip() - parts = t.split(":") - sbatch_values["duration"] = int(parts[0])*3600 + int(parts[1])*60 + int(parts[2]) - elif "-N " in line: - sbatch_values["node_count"] = int(line.split("-N ")[-1].strip()) - elif "-C " in line: - sbatch_values["constraint"] = line.split("-C ")[-1].strip() - elif "--output=" in line: - sbatch_values["stdout_path"] = line.split("--output=")[-1].strip() - elif "--error=" in line: - sbatch_values["stderr_path"] = line.split("--error=")[-1].strip() - elif "--reservation=" in line: - 
sbatch_values["reservation"] = line.split("--reservation=")[-1].strip() - - # Strip shebang and SBATCH headers, keep the script body - script_body = "\n".join( - line for line in job_script.splitlines() - if not line.startswith("#SBATCH") and not line.startswith("#!/") - ).strip() - - constraint = sbatch_values.get("constraint", "cpu") - is_gpu = "gpu" in constraint.lower() - - resources = { - "node_count": sbatch_values.get("node_count", 1), - "processes_per_node": 1, - "exclusive_node_use": True, - } - if is_gpu: - resources["gpu_cores_per_process"] = 4 - else: - resources["cpu_cores_per_process"] = 128 - - custom_attributes = {"constraint": constraint} - - attributes = { - "duration": sbatch_values.get("duration", 1800), - "queue_name": sbatch_values.get("queue_name", "regular"), - "account": sbatch_values.get("account", "als"), - "custom_attributes": custom_attributes, - } - if "reservation" in sbatch_values: - attributes["reservation_id"] = sbatch_values["reservation"] - - job_spec = { - "executable": "/bin/bash", - "arguments": ["-s"], # read script from stdin isn't supported, so... - "pre_launch": script_body, # run the body here before the executable - "resources": resources, - "attributes": attributes, - } - - if "stdout_path" in sbatch_values: - job_spec["stdout_path"] = sbatch_values["stdout_path"] - if "stderr_path" in sbatch_values: - job_spec["stderr_path"] = sbatch_values["stderr_path"] - - response = self.client.post( - f"/api/v1/compute/job/{self.nersc_resources['perlmutter_job_submit']}", - json=job_spec, - ) - if not response.is_success: - logger.error(f"Job submission failed: {response.status_code} {response.text}") - logger.error(f"Job spec was: {json.dumps(job_spec, indent=2)}") - response.raise_for_status() - return str(response.json()["id"]) - - else: - raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") - - def _wait_for_job(self, job_id: str) -> bool: - """Block until a submitted job completes. 
- - Dispatches to the appropriate polling mechanism based on - ``self.login_method``. - - Args: - job_id: The job ID returned by `_submit_job`. - - Returns: - True if the job completed successfully, False otherwise. - """ - if self.login_method is NERSCLoginMethod.SFAPI: - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.job(jobid=job_id) - job.complete() - return True - - elif self.login_method is NERSCLoginMethod.IRIAPI: - while True: - response = self.client.get( - f"/api/v1/compute/status/{self.nersc_resources['compute_resource']}/{job_id}" - ) - response.raise_for_status() - state = response.json().get("status", {}).get("state") - logger.info(f"Job {job_id} state: {state}") - if state == "completed": - return True - if state in ("failed", "canceled", "timeout"): - logger.error(f"Job {job_id} ended with state: {state}") - return False - time.sleep(60) - - else: - raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") - - def _mkdir_remote(self, path: str) -> None: - """Create a directory on Perlmutter remotely. - - Args: - path: Absolute path to create. - """ - if self.login_method is NERSCLoginMethod.SFAPI: - perlmutter = self.client.compute(Machine.perlmutter) - perlmutter.run(f"mkdir -p {path}") - elif self.login_method is NERSCLoginMethod.IRIAPI: - response = self.client.post( - f"/api/v1/filesystem/mkdir/{self.nersc_resources['perlmutter_login']}", - json={"path": path, "parents": True}, - ) - response.raise_for_status() - else: - raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") - - def _read_remote_file(self, path: str) -> str: - """Read a remote file on Perlmutter and return its contents. - - Args: - path: Absolute path to the file on Perlmutter. - - Returns: - File contents as a string. 
- """ - if self.login_method is NERSCLoginMethod.SFAPI: - perlmutter = self.client.compute(Machine.perlmutter) - result = perlmutter.run(f"cat {path}") - if isinstance(result, str): - return result - elif hasattr(result, 'output'): - return result.output - elif hasattr(result, 'stdout'): - return result.stdout - return str(result) - - elif self.login_method is NERSCLoginMethod.IRIAPI: - response = self.client.get( - f"/api/v1/filesystem/view/{self.nersc_resources['perlmutter_login']}", - params={"path": path}, - ) - response.raise_for_status() - task_id = response.json().get("task_id") - if not task_id: - return response.text - - for _ in range(40): - task_response = self.client.get(f"/api/v1/task/{task_id}") - task_response.raise_for_status() - task = task_response.json() - status = task.get("status") - if status == "completed": - result = task.get("result", "") - if isinstance(result, dict): - output = result.get("output", result) - if isinstance(output, dict): - return output.get("content", str(output)) - return str(output) - return str(result) - elif status == "failed": - raise RuntimeError(f"File read task {task_id} failed: {task.get('result')}") - time.sleep(3) - - raise TimeoutError(f"File read task {task_id} did not complete") - - else: - raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") - def reconstruct( self, file_path: str = "", @@ -568,7 +132,7 @@ def reconstruct( """ logger.info("Starting NERSC reconstruction process.") - username = self._get_nersc_username() + username = self.get_nersc_username() raw_path = self.config.nersc832_alsdev_raw.root_path logger.info(f"{raw_path=}") @@ -596,7 +160,7 @@ def reconstruct( logger.info(f"Folder name: {folder_name}") logger.info(f"Number of nodes: {num_nodes}") - opts = _load_job_options("nersc-reconstruction-options", self.config.nersc_recon_settings) + opts = load_job_options("nersc-reconstruction-options", self.config.nersc_recon_settings) logger.info(f"Resolved options: {opts}") @@ -738,10 
+302,10 @@ def reconstruct( job_id = None try: logger.info("Submitting reconstruction job to Perlmutter.") - job_id = self._submit_job(job_script) + job_id = self.submit_job(job_script) logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - success = self._wait_for_job(job_id) + success = self.wait_for_job(job_id) timing = self._fetch_timing_data(pscratch_path, job_id) if success else None return {"success": success, "job_id": job_id, "timing": timing} except Exception as e: @@ -759,7 +323,7 @@ def _fetch_timing_data(self, pscratch_path: str, job_id: str) -> dict: timing_file = f"{pscratch_path}/tomo_recon_logs/timing_{job_id}.txt" try: - output = self._read_remote_file(timing_file) + output = self.read_remote_file(timing_file) logger.info(f"Timing file contents:\n{output}") @@ -815,7 +379,7 @@ def build_multi_resolution( logger.info("Starting NERSC multiresolution process.") - username = self._get_nersc_username() + username = self.get_nersc_username() multires_image = self.config.ghcr_images832["multires_image"] logger.info(f"{multires_image=}") @@ -841,7 +405,7 @@ def build_multi_resolution( # account = self.config.nersc_account - opts = _load_job_options( + opts = load_job_options( "nersc-multiresolution-options", self.config.nersc_multiresolution_settings ) @@ -882,10 +446,10 @@ def build_multi_resolution( """ try: logger.info("Submitting Tiff to Zarr job to Perlmutter.") - job_id = self._submit_job(job_script) + job_id = self.submit_job(job_script) logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - success = self._wait_for_job(job_id) + success = self.wait_for_job(job_id) logger.info(f"Multiresolution job {'completed' if success else 'failed'}.") return success except Exception as e: @@ -902,10 +466,10 @@ def segmentation_sam3( """ logger.info("Starting NERSC segmentation process (inference_v6).") - username = self._get_nersc_username() + username = self.get_nersc_username() pscratch_path = f"/pscratch/sd/{username[0]}/{username}" - opts = 
_load_job_options( + opts = load_job_options( variable_name="nersc-segmentation-options", config_settings=self.config.nersc_segment_sam3_settings, config=self.config, @@ -1097,14 +661,14 @@ def segmentation_sam3( # Ensure directories exist logger.info("Creating necessary directories...") - self._mkdir_remote(f"{pscratch_path}/tomo_seg_logs") - self._mkdir_remote(output_dir) + self.mkdir_remote(f"{pscratch_path}/tomo_seg_logs") + self.mkdir_remote(output_dir) # Submit job - job_id = self._submit_job(job_script) + job_id = self.submit_job(job_script) logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - success = self._wait_for_job(job_id) + success = self.wait_for_job(job_id) logger.info("Segmentation job completed successfully.") timing = self._fetch_seg_timing_from_output(pscratch_path, job_id, job_name) @@ -1153,12 +717,12 @@ def segmentation_dinov3( """ logger.info("Starting NERSC DINOv3 segmentation process.") - username = self._get_nersc_username() + username = self.get_nersc_username() pscratch_path = f"/pscratch/sd/{username[0]}/{username}" # Load from config spec = self._get_segmentation_spec("dinov3", project) - opts = _load_job_options( + opts = load_job_options( variable_name=spec.variable_name, config_settings=spec.settings, config=self.config, @@ -1294,10 +858,10 @@ def segmentation_dinov3( """ try: logger.info("Submitting DINOv3 segmentation job to Perlmutter.") - job_id = self._submit_job(job_script) + job_id = self.submit_job(job_script) logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - success = self._wait_for_job(job_id) + success = self.wait_for_job(job_id) logger.info(f"DINOv3 segmentation job {'completed successfully' if success else 'failed'}.") return success except Exception as e: @@ -1318,10 +882,10 @@ def combine_segmentations( """ logger.info("Starting NERSC segmentation combination process.") - username = self._get_nersc_username() + username = self.get_nersc_username() pscratch_path = 
f"/pscratch/sd/{username[0]}/{username}" - opts = _load_job_options( + opts = load_job_options( "nersc-combine-seg-options", self.config.nersc_combine_segmentation_settings ) @@ -1418,10 +982,10 @@ def combine_segmentations( """ try: logger.info("Submitting segmentation combination job to Perlmutter.") - job_id = self._submit_job(job_script) + job_id = self.submit_job(job_script) logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - success = self._wait_for_job(job_id) + success = self.wait_for_job(job_id) logger.info(f"Segmentation combination job {'completed successfully' if success else 'failed'}.") return success except Exception as e: @@ -1440,7 +1004,7 @@ def _fetch_seg_timing_from_output(self, pscratch_path: str, job_id: str, job_nam output_file = f"{pscratch_path}/tomo_seg_logs/{job_name}_{job_id}.out" try: - output = self._read_remote_file(output_file) + output = self.read_remote_file(output_file) logger.info("Job output file contents (last 50 lines):") lines = output.strip().split('\n') @@ -1500,140 +1064,6 @@ def start_streaming_service( walltime=walltime ) - def pull_shifter_image( - self, - image: str = None, - wait: bool = True, - ) -> bool: - """ - Pull a container image into NERSC's Shifter cache. - - This should be run once when the image is updated, not before every reconstruction. - After the image is cached, jobs using --image= will start much faster. 
- - :param image: Container image to pull (defaults to recon_image from config) - :param wait: Whether to wait for the pull to complete - :return: True if successful, False otherwise - """ - logger.info("Starting Shifter image pull.") - - username = self._get_nersc_username() - pscratch_path = f"/pscratch/sd/{username[0]}/{username}" - - if image is None: - image = self.config.ghcr_images832["recon_image"] - - logger.info(f"Pulling image: {image}") - - job_script = f"""#!/bin/bash -#SBATCH -q debug -#SBATCH -A als -#SBATCH -C cpu -#SBATCH --job-name=shifter_pull -#SBATCH --output={pscratch_path}/tomo_recon_logs/shifter_pull_%j.out -#SBATCH --error={pscratch_path}/tomo_recon_logs/shifter_pull_%j.err -#SBATCH -N 1 -#SBATCH --ntasks=1 -#SBATCH --cpus-per-task=1 -#SBATCH --time=0:15:00 - -echo "Starting Shifter image pull at $(date)" -echo "Image: {image}" - -# Check if image already exists -echo "Checking existing images..." -shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/')" || true - -# Pull the image -echo "Pulling image..." -shifterimg -v pull {image} -PULL_STATUS=$? - -if [ $PULL_STATUS -eq 0 ]; then - echo "Image pull successful" -else - echo "Image pull failed with status $PULL_STATUS" - exit 1 -fi - -# Verify the image is now available -echo "Verifying image..." -shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/')" - -echo "Completed at $(date)" -""" - - try: - logger.info("Submitting Shifter image pull job to Perlmutter.") - job_id = self._submit_job(job_script) - logger.info(f"Submitted job ID: {job_id}") - - if wait: - time.sleep(30) - success = self._wait_for_job(job_id) - logger.info(f"Shifter image pull {'completed successfully' if success else 'failed'}.") - return success - else: - logger.info(f"Job submitted. 
Check status with job ID: {job_id}") - return True - - except Exception as e: - logger.error(f"Error during Shifter image pull: {e}") - return False - - def check_shifter_image( - self, - image: str = None, - ) -> bool: - """ - Check if a container image is already in NERSC's Shifter cache. - - :param image: Container image to check (defaults to recon_image from config) - :return: True if image exists in cache, False otherwise - """ - logger.info("Checking Shifter image cache.") - - if image is None: - image = self.config.ghcr_images832["recon_image"] - - try: - # Run shifterimg images command - if self.login_method is NERSCLoginMethod.SFAPI: - # synchronous via utilities/command - perlmutter = self.client.compute(Machine.perlmutter) - result = perlmutter.run(f"shifterimg images | grep -E \"$(echo {image} | sed 's/:/.*/g')\"") - output = result if isinstance(result, str) else getattr(result, 'output', str(result)) - - elif self.login_method is NERSCLoginMethod.IRIAPI: - # async: submit job → wait → read stdout file - username = self._get_nersc_username() - pscratch_path = f"/pscratch/sd/{username[0]}/{username}" - output_file = f"{pscratch_path}/tomo_recon_logs/shifter_check.txt" - check_script = f"""#!/bin/bash -#SBATCH -q debug -#SBATCH -A als -#SBATCH -C cpu -#SBATCH -N 1 -#SBATCH --ntasks=1 -#SBATCH --cpus-per-task=1 -#SBATCH --time=0:05:00 -shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/g')" > {output_file} 2>&1 || true -""" - job_id = self._submit_job(check_script) - self._wait_for_job(job_id) - output = self._read_remote_file(output_file) - - if output.strip(): - logger.info(f"Image found in Shifter cache: {output.strip()}") - return True - else: - logger.info(f"Image not found in Shifter cache: {image}") - return False - - except Exception as e: - logger.warning(f"Error checking Shifter cache: {e}") - return False - def schedule_pruning( config: Config832, @@ -2533,10 +1963,10 @@ def pull_shifter_image_flow( ) # Check if already cached - if 
controller.check_shifter_image(image): + if check_shifter_image(controller, image): logger.info("Image already in cache, pulling anyway to update...") - success = controller.pull_shifter_image(image) + success = pull_shifter_image(controller, image) logger.info(f"Shifter image pull success: {success}") return success From 2806e9a15de7fdfc8c83b49fd8fbfe1b2250da43 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 26 May 2026 11:29:46 -0700 Subject: [PATCH 23/26] Adding pytests mirroring the jobs/ directory structure --- orchestration/_tests/test_jobs/__init__.py | 0 .../_tests/test_jobs/alcf/__init__.py | 0 .../_tests/test_jobs/alcf/test_controller.py | 198 +++++++++++ orchestration/_tests/test_jobs/conftest.py | 87 +++++ .../_tests/test_jobs/nersc/__init__.py | 0 .../_tests/test_jobs/nersc/test_controller.py | 333 ++++++++++++++++++ .../_tests/test_jobs/nersc/test_login.py | 85 +++++ .../_tests/test_jobs/nersc/test_shifter.py | 120 +++++++ .../_tests/test_jobs/test_controller.py | 71 ++++ .../_tests/test_jobs/test_options.py | 169 +++++++++ 10 files changed, 1063 insertions(+) create mode 100644 orchestration/_tests/test_jobs/__init__.py create mode 100644 orchestration/_tests/test_jobs/alcf/__init__.py create mode 100644 orchestration/_tests/test_jobs/alcf/test_controller.py create mode 100644 orchestration/_tests/test_jobs/conftest.py create mode 100644 orchestration/_tests/test_jobs/nersc/__init__.py create mode 100644 orchestration/_tests/test_jobs/nersc/test_controller.py create mode 100644 orchestration/_tests/test_jobs/nersc/test_login.py create mode 100644 orchestration/_tests/test_jobs/nersc/test_shifter.py create mode 100644 orchestration/_tests/test_jobs/test_controller.py create mode 100644 orchestration/_tests/test_jobs/test_options.py diff --git a/orchestration/_tests/test_jobs/__init__.py b/orchestration/_tests/test_jobs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/_tests/test_jobs/alcf/__init__.py 
b/orchestration/_tests/test_jobs/alcf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/_tests/test_jobs/alcf/test_controller.py b/orchestration/_tests/test_jobs/alcf/test_controller.py new file mode 100644 index 00000000..3cc463a1 --- /dev/null +++ b/orchestration/_tests/test_jobs/alcf/test_controller.py @@ -0,0 +1,198 @@ +"""Tests for orchestration/jobs/alcf/controller.py — ALCFJobController. + +Patch targets (confirmed from module-level imports in alcf/controller.py): + orchestration.jobs.alcf.controller.Variable.get (prefect Variable) + orchestration.jobs.alcf.controller.Secret.load (prefect Secret) + orchestration.jobs.alcf.controller.get_run_logger (prefect run logger) + orchestration.jobs.alcf.controller.Client (globus_compute_sdk.Client) + orchestration.jobs.alcf.controller.Executor (globus_compute_sdk.Executor) + orchestration.jobs.alcf.controller.time.sleep (skip polling delays) + +TestALCFJobControllerInit and TestSubmit tests request the mock_alcf_prefect fixture to satisfy the +Variable.get and Secret.load calls in __init__; test_raises_when_allocation_root_missing patches both itself instead. + +TestWaitForFuture uses ALCFJobController.wait_for_future() as a @staticmethod — +no instance construction needed, so mock_alcf_prefect is not required there. 
+""" + +import pytest + +from orchestration.jobs.alcf.controller import ( + ALCFJobController, + _ALLOCATION_ROOT_VARIABLE, + _GLOBUS_COMPUTE_ENDPOINT_SECRET, +) + + +# ── Init ────────────────────────────────────────────────────────────────────── + +class TestALCFJobControllerInit: + def test_reads_allocation_root_from_variable(self, mocker, fake_config, mock_alcf_prefect): + ctrl = ALCFJobController(fake_config) + assert ctrl.allocation_root == "/eagle/IRIProd/ALS" + + def test_reads_endpoint_id_from_secret(self, mocker, fake_config, mock_alcf_prefect): + ctrl = ALCFJobController(fake_config) + assert ctrl.endpoint_id == "mock-endpoint-uuid" + + def test_variable_get_called_with_correct_name(self, mocker, fake_config, mock_alcf_prefect): + ALCFJobController(fake_config) + mock_alcf_prefect.variable.assert_called_once_with( + _ALLOCATION_ROOT_VARIABLE, _sync=True + ) + + def test_secret_load_called_with_correct_name(self, mocker, fake_config, mock_alcf_prefect): + ALCFJobController(fake_config) + # mock_alcf_prefect.secret IS the mock for Secret.load (not Secret itself) + mock_alcf_prefect.secret.assert_called_once_with(_GLOBUS_COMPUTE_ENDPOINT_SECRET) + + def test_raises_when_allocation_root_missing(self, mocker, fake_config): + # allocation_data.get(...) 
returns None → ValueError + mocker.patch( + "orchestration.jobs.alcf.controller.Variable.get", + return_value={}, # key absent → .get() returns None + ) + mocker.patch("orchestration.jobs.alcf.controller.Secret.load") + with pytest.raises(ValueError, match="Allocation root not found"): + ALCFJobController(fake_config) + + def test_stores_config(self, mocker, fake_config, mock_alcf_prefect): + ctrl = ALCFJobController(fake_config) + assert ctrl.config is fake_config + + +# ── submit ──────────────────────────────────────────────────────────────────── + +class TestSubmit: + def test_constructs_client_and_submits_via_executor(self, mocker, fake_config, mock_alcf_prefect): + mock_client_cls = mocker.patch("orchestration.jobs.alcf.controller.Client") + mock_executor_cls = mocker.patch("orchestration.jobs.alcf.controller.Executor") + + mock_future = mocker.MagicMock() + mock_executor_instance = mocker.MagicMock() + mock_executor_instance.submit.return_value = mock_future + mock_executor_cls.return_value.__enter__ = mocker.MagicMock(return_value=mock_executor_instance) + mock_executor_cls.return_value.__exit__ = mocker.MagicMock(return_value=False) + + def noop(): + pass + + ctrl = ALCFJobController(fake_config) + result = ctrl.submit(noop) + + mock_client_cls.assert_called_once() + mock_executor_cls.assert_called_once_with( + endpoint_id="mock-endpoint-uuid", + client=mock_client_cls.return_value, + ) + mock_executor_instance.submit.assert_called_once() + assert result is mock_future + + def test_returns_future(self, mocker, fake_config, mock_alcf_prefect): + mocker.patch("orchestration.jobs.alcf.controller.Client") + mock_executor_cls = mocker.patch("orchestration.jobs.alcf.controller.Executor") + + mock_future = mocker.MagicMock() + mock_executor_instance = mocker.MagicMock() + mock_executor_instance.submit.return_value = mock_future + mock_executor_cls.return_value.__enter__ = mocker.MagicMock(return_value=mock_executor_instance) + 
mock_executor_cls.return_value.__exit__ = mocker.MagicMock(return_value=False) + + def identity(x): + return x + + ctrl = ALCFJobController(fake_config) + future = ctrl.submit(identity, 42, key="val") + + mock_executor_instance.submit.assert_called_once_with(identity, 42, key="val") + assert future is mock_future + + +# ── wait_for_future ─────────────────────────────────────────────────────────── + +class TestWaitForFuture: + """wait_for_future is @staticmethod — call directly without constructing an instance.""" + + def _run_logger(self, mocker): + """Patch get_run_logger and return a mock logger.""" + run_logger = mocker.MagicMock() + mocker.patch( + "orchestration.jobs.alcf.controller.get_run_logger", + return_value=run_logger, + ) + return run_logger + + def test_returns_true_on_success(self, mocker): + mocker.patch("orchestration.jobs.alcf.controller.time.sleep") + self._run_logger(mocker) + + future = mocker.MagicMock() + future.done.return_value = True + future.cancelled.return_value = False + future.exception.return_value = None + future.result.return_value = "output" + + result = ALCFJobController.wait_for_future(future, "reconstruction") + assert result is True + + def test_returns_false_when_future_raises(self, mocker): + mocker.patch("orchestration.jobs.alcf.controller.time.sleep") + self._run_logger(mocker) + + future = mocker.MagicMock() + future.done.return_value = True + future.cancelled.return_value = False + future.exception.return_value = RuntimeError("job failed") + + result = ALCFJobController.wait_for_future(future, "reconstruction") + assert result is False + + def test_returns_false_when_cancelled(self, mocker): + mocker.patch("orchestration.jobs.alcf.controller.time.sleep") + self._run_logger(mocker) + + future = mocker.MagicMock() + future.done.return_value = True + future.cancelled.return_value = True + + result = ALCFJobController.wait_for_future(future, "reconstruction") + assert result is False + + def 
test_returns_false_on_timeout(self, mocker): + mocker.patch("orchestration.jobs.alcf.controller.time.sleep") + self._run_logger(mocker) + + # Simulate time advancing past walltime by patching time.time + call_count = [0] + start = 1000.0 + + def fake_time(): + val = start + call_count[0] * 700 + call_count[0] += 1 + return val + + mocker.patch("orchestration.jobs.alcf.controller.time.time", side_effect=fake_time) + + future = mocker.MagicMock() + future.done.return_value = False # never completes + future.cancelled.return_value = False + + result = ALCFJobController.wait_for_future( + future, "reconstruction", check_interval=1, walltime=600 + ) + future.cancel.assert_called() + assert result is False + + def test_polls_until_done(self, mocker): + mocker.patch("orchestration.jobs.alcf.controller.time.sleep") + self._run_logger(mocker) + + future = mocker.MagicMock() + future.done.side_effect = [False, False, True] + future.cancelled.return_value = False + future.exception.return_value = None + future.result.return_value = "done" + + result = ALCFJobController.wait_for_future(future, "task", check_interval=1, walltime=3600) + assert future.done.call_count == 3 + assert result is True diff --git a/orchestration/_tests/test_jobs/conftest.py b/orchestration/_tests/test_jobs/conftest.py new file mode 100644 index 00000000..cd0c70b2 --- /dev/null +++ b/orchestration/_tests/test_jobs/conftest.py @@ -0,0 +1,87 @@ +import types + +import pytest +from prefect.testing.utilities import prefect_test_harness + + +@pytest.fixture(scope="session", autouse=True) +def prefect_test_fixture(): + """Wrap the entire test_jobs/ session in a Prefect test harness. + + Required because ALCFJobController.__init__ calls Variable.get(_sync=True), + which needs a live Prefect API — even when Variable.get is patched, the + import-time Prefect setup must succeed. 
+ """ + with prefect_test_harness(): + yield + + +@pytest.fixture +def fake_config(): + """Minimal BeamlineConfig-like namespace for tests that need a config object.""" + return types.SimpleNamespace( + nersc_resources={ + "iri": { + "api_base_url": "https://mock-iri.nersc.gov", + "perlmutter_login": "mock-login-uuid", + "perlmutter_job_submit": "mock-submit-uuid", + "compute_resource": "mock-compute-uuid", + }, + "sfapi": {"api_base_url": "https://mock-sfapi.nersc.gov"}, + }, + mlflow={"tracking_uri": "http://mock-mlflow:5000"}, + ) + + +@pytest.fixture +def mock_sfapi_client(mocker): + """MagicMock shaped like an sfapi_client.Client.""" + client = mocker.MagicMock() + user = mocker.MagicMock() + user.name = "testuser" + client.user.return_value = user + return client + + +@pytest.fixture +def mock_iriapi_client(mocker): + """MagicMock shaped like an httpx.Client targeting the IRI API.""" + client = mocker.MagicMock() + client.post.return_value = mocker.MagicMock( + is_success=True, json=lambda: {"id": "job-42"} + ) + client.get.return_value = mocker.MagicMock( + is_success=True, + json=lambda: {"status": {"state": "completed"}}, + ) + return client + + +@pytest.fixture +def mock_alcf_prefect(mocker): + """Patch Variable.get and Secret.load in the ALCF controller module. + + allocation_data is a real dict, not a MagicMock — the constructor calls + allocation_data.get("alcf-allocation-root-path") and checks truthiness of + the result. A MagicMock would pass the check with an arbitrary truthy value, + masking bugs. + + Tests that need Prefect mocked request this fixture explicitly. 
+ """ + allocation_data = {"alcf-allocation-root-path": "/eagle/IRIProd/ALS"} + var_mock = mocker.patch( + "orchestration.jobs.alcf.controller.Variable.get", + return_value=allocation_data, + ) + secret_mock = mocker.patch("orchestration.jobs.alcf.controller.Secret.load") + secret_mock.return_value.get.return_value = "mock-endpoint-uuid" + return types.SimpleNamespace(variable=var_mock, secret=secret_mock) + + +@pytest.fixture +def mock_options_prefect(mocker): + """Patch Variable.get in the options module (used by load_job_options tests).""" + return mocker.patch( + "orchestration.jobs.options.Variable.get", + return_value={"defaults": True}, + ) diff --git a/orchestration/_tests/test_jobs/nersc/__init__.py b/orchestration/_tests/test_jobs/nersc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/_tests/test_jobs/nersc/test_controller.py b/orchestration/_tests/test_jobs/nersc/test_controller.py new file mode 100644 index 00000000..b3fad7f2 --- /dev/null +++ b/orchestration/_tests/test_jobs/nersc/test_controller.py @@ -0,0 +1,333 @@ +"""Tests for orchestration/jobs/nersc/controller.py — NERSCJobController. + +Patch targets: + orchestration.jobs.nersc.controller.time.sleep (skip 60-second polling delays) + +No Prefect Variable/Secret imports in this module — no Prefect mocking needed. 
+""" + +import pytest + +from orchestration.jobs.nersc.controller import NERSCJobController +from orchestration.jobs.nersc.login import NERSCLoginMethod + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _sfapi_controller(mocker, fake_config): + client = mocker.MagicMock() + user = mocker.MagicMock() + user.name = "sfapiuser" + client.user.return_value = user + return NERSCJobController(fake_config, client=client, login_method=NERSCLoginMethod.SFAPI) + + +def _iriapi_controller(mocker, fake_config): + client = mocker.MagicMock() + # POST (submit_job): json() returns the job ID dict + post_response = mocker.MagicMock(is_success=True) + post_response.json.return_value = {"id": "job-99"} + client.post.return_value = post_response + # GET (wait_for_job, read_remote_file): default to completed state + get_response = mocker.MagicMock(is_success=True) + get_response.json.return_value = {"status": {"state": "completed"}} + get_response.text = "" + client.get.return_value = get_response + return NERSCJobController(fake_config, client=client, login_method=NERSCLoginMethod.IRIAPI) + + +# ── Initialization ──────────────────────────────────────────────────────────── + +class TestNERSCJobControllerInit: + def test_sfapi_stores_sfapi_nersc_resources(self, mocker, fake_config): + ctrl = _sfapi_controller(mocker, fake_config) + assert ctrl.nersc_resources == fake_config.nersc_resources["sfapi"] + + def test_iriapi_stores_iri_nersc_resources(self, mocker, fake_config): + ctrl = _iriapi_controller(mocker, fake_config) + assert ctrl.nersc_resources == fake_config.nersc_resources["iri"] + + def test_stores_login_method(self, mocker, fake_config): + ctrl = _sfapi_controller(mocker, fake_config) + assert ctrl.login_method is NERSCLoginMethod.SFAPI + + def test_unknown_login_method_raises(self, mocker, fake_config): + bad_method = mocker.MagicMock() + bad_method.__eq__ = lambda s, o: False + bad_method.__ne__ = lambda s, o: True + with 
pytest.raises(ValueError, match="Unsupported NERSCLoginMethod"): + NERSCJobController(fake_config, client=None, login_method=bad_method) + + +# ── get_nersc_username ──────────────────────────────────────────────────────── + +class TestGetNerscUsername: + def test_sfapi_reads_name_from_client(self, mocker, fake_config): + ctrl = _sfapi_controller(mocker, fake_config) + assert ctrl.get_nersc_username() == "sfapiuser" + + def test_iriapi_reads_from_env(self, mocker, fake_config, monkeypatch): + monkeypatch.setenv("NERSC_USERNAME", "envuser") + ctrl = _iriapi_controller(mocker, fake_config) + assert ctrl.get_nersc_username() == "envuser" + + def test_iriapi_raises_when_env_unset(self, mocker, fake_config, monkeypatch): + monkeypatch.delenv("NERSC_USERNAME", raising=False) + ctrl = _iriapi_controller(mocker, fake_config) + with pytest.raises(ValueError, match="NERSC_USERNAME must be set"): + ctrl.get_nersc_username() + + +# ── submit_job ──────────────────────────────────────────────────────────────── + +class TestSubmitJob: + def test_sfapi_returns_job_id_string(self, mocker, fake_config): + ctrl = _sfapi_controller(mocker, fake_config) + job = mocker.MagicMock() + job.jobid = 12345 + ctrl.client.compute.return_value.submit_job.return_value = job + result = ctrl.submit_job("#!/bin/bash\n#SBATCH -q debug\necho hi") + assert result == "12345" + + def test_sfapi_calls_perlmutter_submit_job(self, mocker, fake_config): + ctrl = _sfapi_controller(mocker, fake_config) + job = mocker.MagicMock() + job.jobid = "abc" + perlmutter = ctrl.client.compute.return_value + perlmutter.submit_job.return_value = job + ctrl.submit_job("script") + perlmutter.submit_job.assert_called_once_with("script") + + def test_iriapi_returns_job_id_string(self, mocker, fake_config): + ctrl = _iriapi_controller(mocker, fake_config) + script = "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" + result = ctrl.submit_job(script) + assert result == "job-99" + + 
def test_iriapi_posts_to_job_submit_url(self, mocker, fake_config): + ctrl = _iriapi_controller(mocker, fake_config) + script = "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" + ctrl.submit_job(script) + call_args = ctrl.client.post.call_args + assert "mock-submit-uuid" in call_args[0][0] + + +# ── _submit_job_iriapi SBATCH parsing ───────────────────────────────────────── + +class TestSubmitJobIRIAPI: + """Tests the SBATCH header parsing logic in _submit_job_iriapi.""" + + def _submit_and_capture_spec(self, mocker, fake_config, script): + ctrl = _iriapi_controller(mocker, fake_config) + captured = {} + + def capture_post(url, json=None, **kwargs): + captured["json"] = json + resp = mocker.MagicMock(is_success=True) + resp.json.return_value = {"id": "captured-id"} + return resp + + ctrl.client.post = capture_post + ctrl._submit_job_iriapi(script) + return captured["json"] + + def test_parses_queue_name(self, mocker, fake_config): + script = ( + "#!/bin/bash\n#SBATCH -q premium\n#SBATCH -A als\n" + "#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, fake_config, script) + assert spec["attributes"]["queue_name"] == "premium" + + def test_parses_account(self, mocker, fake_config): + script = ( + "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A myproject\n" + "#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, fake_config, script) + assert spec["attributes"]["account"] == "myproject" + + def test_parses_walltime_to_seconds(self, mocker, fake_config): + script = ( + "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" + "#SBATCH --time=01:30:00\n#SBATCH -N 1\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, fake_config, script) + assert spec["attributes"]["duration"] == 5400 # 1h30m in seconds + + def test_parses_node_count(self, mocker, fake_config): + script = ( + "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" + "#SBATCH 
--time=00:10:00\n#SBATCH -N 4\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, fake_config, script) + assert spec["resources"]["node_count"] == 4 + + def test_cpu_constraint_adds_cpu_cores(self, mocker, fake_config): + script = ( + "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" + "#SBATCH --time=00:10:00\n#SBATCH -N 1\n#SBATCH -C cpu\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, fake_config, script) + assert "cpu_cores_per_process" in spec["resources"] + assert "gpu_cores_per_process" not in spec["resources"] + + def test_gpu_constraint_adds_gpu_cores(self, mocker, fake_config): + script = ( + "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" + "#SBATCH --time=00:10:00\n#SBATCH -N 1\n#SBATCH -C gpu\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, fake_config, script) + assert "gpu_cores_per_process" in spec["resources"] + assert "cpu_cores_per_process" not in spec["resources"] + + def test_reservation_included_when_present(self, mocker, fake_config): + script = ( + "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" + "#SBATCH --time=00:10:00\n#SBATCH -N 1\n#SBATCH --reservation=myres\necho hi" + ) + spec = self._submit_and_capture_spec(mocker, fake_config, script) + assert spec["attributes"]["reservation_id"] == "myres" + + def test_no_reservation_when_absent(self, mocker, fake_config): + script = "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" + spec = self._submit_and_capture_spec(mocker, fake_config, script) + assert "reservation_id" not in spec["attributes"] + + +# ── wait_for_job ────────────────────────────────────────────────────────────── + +class TestWaitForJob: + def test_sfapi_returns_true_on_complete(self, mocker, fake_config): + ctrl = _sfapi_controller(mocker, fake_config) + job = mocker.MagicMock() + ctrl.client.compute.return_value.job.return_value = job + result = ctrl.wait_for_job("12345") + job.complete.assert_called_once() + assert result is True + + def 
test_iriapi_returns_true_when_completed(self, mocker, fake_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, fake_config) + ctrl.client.get.return_value.json.return_value = {"status": {"state": "completed"}} + result = ctrl.wait_for_job("42") + assert result is True + + def test_iriapi_returns_false_on_failed(self, mocker, fake_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, fake_config) + ctrl.client.get.return_value.json.return_value = {"status": {"state": "failed"}} + result = ctrl.wait_for_job("42") + assert result is False + + def test_iriapi_returns_false_on_canceled(self, mocker, fake_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, fake_config) + ctrl.client.get.return_value.json.return_value = {"status": {"state": "canceled"}} + result = ctrl.wait_for_job("42") + assert result is False + + def test_iriapi_polls_until_terminal_state(self, mocker, fake_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, fake_config) + responses = [ + {"status": {"state": "running"}}, + {"status": {"state": "running"}}, + {"status": {"state": "completed"}}, + ] + ctrl.client.get.return_value.json.side_effect = responses + result = ctrl.wait_for_job("42") + assert ctrl.client.get.call_count == 3 + assert result is True + + +# ── mkdir_remote ────────────────────────────────────────────────────────────── + +class TestMkdirRemote: + def test_sfapi_runs_mkdir(self, mocker, fake_config): + ctrl = _sfapi_controller(mocker, fake_config) + ctrl.mkdir_remote("/pscratch/sd/t/testuser/mydir") + ctrl.client.compute.return_value.run.assert_called_once() + cmd = ctrl.client.compute.return_value.run.call_args[0][0] + assert "mkdir -p" in cmd + assert "/pscratch/sd/t/testuser/mydir" in cmd + + def test_iriapi_posts_to_mkdir_url(self, mocker, fake_config): 
+ ctrl = _iriapi_controller(mocker, fake_config) + ctrl.mkdir_remote("/pscratch/sd/t/testuser/mydir") + ctrl.client.post.assert_called() + url = ctrl.client.post.call_args[0][0] + assert "mock-login-uuid" in url + + def test_iriapi_posts_path_in_body(self, mocker, fake_config): + ctrl = _iriapi_controller(mocker, fake_config) + ctrl.mkdir_remote("/some/path") + body = ctrl.client.post.call_args[1]["json"] + assert body["path"] == "/some/path" + assert body["parents"] is True + + +# ── read_remote_file ────────────────────────────────────────────────────────── + +class TestReadRemoteFile: + def test_sfapi_returns_string_result(self, mocker, fake_config): + ctrl = _sfapi_controller(mocker, fake_config) + ctrl.client.compute.return_value.run.return_value = "file contents" + result = ctrl.read_remote_file("/some/file.txt") + assert result == "file contents" + + def test_sfapi_extracts_output_attribute(self, mocker, fake_config): + ctrl = _sfapi_controller(mocker, fake_config) + run_result = mocker.MagicMock(spec=[]) # no __str__ shortcuts + run_result.output = "from output attr" + ctrl.client.compute.return_value.run.return_value = run_result + result = ctrl.read_remote_file("/some/file.txt") + assert result == "from output attr" + + def test_iriapi_returns_file_contents_on_completed_task(self, mocker, fake_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, fake_config) + + # First call: GET /filesystem/view → task_id + # Subsequent calls: GET /task/ → status=completed + result + view_response = mocker.MagicMock(is_success=True) + view_response.json.return_value = {"task_id": "task-abc"} + view_response.text = "" + + task_response = mocker.MagicMock(is_success=True) + task_response.json.return_value = {"status": "completed", "result": "file data"} + + ctrl.client.get.side_effect = [view_response, task_response] + result = ctrl.read_remote_file("/pscratch/data.txt") + assert result == "file data" + + def 
test_iriapi_raises_on_failed_task(self, mocker, fake_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, fake_config) + + view_response = mocker.MagicMock(is_success=True) + view_response.json.return_value = {"task_id": "task-fail"} + view_response.text = "" + + task_response = mocker.MagicMock(is_success=True) + task_response.json.return_value = {"status": "failed", "result": "disk error"} + + ctrl.client.get.side_effect = [view_response, task_response] + with pytest.raises(RuntimeError, match="failed"): + ctrl.read_remote_file("/pscratch/data.txt") + + def test_iriapi_raises_timeout_after_40_polls(self, mocker, fake_config): + mocker.patch("orchestration.jobs.nersc.controller.time.sleep") + ctrl = _iriapi_controller(mocker, fake_config) + + view_response = mocker.MagicMock(is_success=True) + view_response.json.return_value = {"task_id": "task-slow"} + view_response.text = "" + + # Always return "pending" — never completes + pending_response = mocker.MagicMock(is_success=True) + pending_response.json.return_value = {"status": "pending", "result": None} + + ctrl.client.get.side_effect = [view_response] + [pending_response] * 40 + with pytest.raises(TimeoutError): + ctrl.read_remote_file("/pscratch/slow.txt") diff --git a/orchestration/_tests/test_jobs/nersc/test_login.py b/orchestration/_tests/test_jobs/nersc/test_login.py new file mode 100644 index 00000000..c43538d7 --- /dev/null +++ b/orchestration/_tests/test_jobs/nersc/test_login.py @@ -0,0 +1,85 @@ +"""Tests for orchestration/jobs/nersc/login.py. + +Patch targets (nersc/login.py has NO Variable or Secret imports): + orchestration.jobs.nersc.login._create_sfapi_client (private builder) + orchestration.jobs.nersc.login._create_iriapi_client (private builder) + +The private builders are patched, not tested directly — they are thin wrappers +around third-party constructors with extensive credential I/O (sfapi_client.Client +__init__, env vars, file reads). 
Patching them is sufficient to exercise +create_nersc_client's dispatch logic, which is the only thing this module adds. +That choice is intentional. +""" + +import pytest + +from orchestration.jobs.nersc.login import NERSCLoginMethod, create_nersc_client + + +# ── Structural identity check ───────────────────────────────────────────────── + +def test_login_method_is_canonical(): + """NERSCLoginMethod is defined exactly once; bl832 re-export is the same object. + + Uses `is`, not `==` — equality passes even for two separate enum classes + whose members have matching names and values. + """ + from orchestration.jobs.nersc.login import NERSCLoginMethod as A + from orchestration.flows.bl832.job_controller import NERSCLoginMethod as B + assert A is B + + +# ── NERSCLoginMethod enum ───────────────────────────────────────────────────── + +class TestNERSCLoginMethod: + def test_sfapi_value(self): + assert NERSCLoginMethod.SFAPI.value == "sfapi" + + def test_iriapi_value(self): + assert NERSCLoginMethod.IRIAPI.value == "iriapi" + + def test_membership_sfapi(self): + assert NERSCLoginMethod("sfapi") is NERSCLoginMethod.SFAPI + + def test_membership_iriapi(self): + assert NERSCLoginMethod("iriapi") is NERSCLoginMethod.IRIAPI + + +# ── create_nersc_client dispatch ────────────────────────────────────────────── + +class TestCreateNerscClient: + def test_sfapi_dispatches_to_sfapi_builder(self, mocker, fake_config): + mock_client = mocker.MagicMock() + builder = mocker.patch( + "orchestration.jobs.nersc.login._create_sfapi_client", + return_value=mock_client, + ) + result = create_nersc_client(fake_config, NERSCLoginMethod.SFAPI) + builder.assert_called_once_with() + assert result is mock_client + + def test_iriapi_dispatches_to_iriapi_builder(self, mocker, fake_config): + mock_client = mocker.MagicMock() + builder = mocker.patch( + "orchestration.jobs.nersc.login._create_iriapi_client", + return_value=mock_client, + ) + result = create_nersc_client(fake_config, 
NERSCLoginMethod.IRIAPI) + builder.assert_called_once_with(fake_config.nersc_resources["iri"]["api_base_url"]) + assert result is mock_client + + def test_sfapi_passes_api_base_url_from_config(self, mocker, fake_config): + mocker.patch("orchestration.jobs.nersc.login._create_sfapi_client") + # No assertion on URL for SFAPI (the builder doesn't take a URL arg), + # but create_nersc_client must read the sfapi sub-dict without raising. + create_nersc_client(fake_config, NERSCLoginMethod.SFAPI) + + def test_iriapi_passes_correct_api_base_url(self, mocker, fake_config): + builder = mocker.patch("orchestration.jobs.nersc.login._create_iriapi_client") + create_nersc_client(fake_config, NERSCLoginMethod.IRIAPI) + builder.assert_called_once_with("https://mock-iri.nersc.gov") + + def test_default_login_method_is_iriapi(self, mocker, fake_config): + builder = mocker.patch("orchestration.jobs.nersc.login._create_iriapi_client") + create_nersc_client(fake_config) + builder.assert_called_once() diff --git a/orchestration/_tests/test_jobs/nersc/test_shifter.py b/orchestration/_tests/test_jobs/nersc/test_shifter.py new file mode 100644 index 00000000..d855bc71 --- /dev/null +++ b/orchestration/_tests/test_jobs/nersc/test_shifter.py @@ -0,0 +1,120 @@ +"""Tests for orchestration/jobs/nersc/shifter.py. + +All tests use mocker.MagicMock(spec=NERSCJobController) — spec= is required so +that drift in the controller interface (renamed/removed methods) breaks these +tests rather than silently passing. + +check_shifter_image SFAPI branch deferred import note: + check_shifter_image does `from sfapi_client.compute import Machine` inside the + function body (line 204). When setting up the SFAPI mock, use a plain + mocker.MagicMock() (not spec'd) for the compute() return value — Machine.perlmutter + is evaluated after the patch, so a spec'd mock keyed on the Machine class would fail. 
+ +Patch targets: + orchestration.jobs.nersc.shifter.time.sleep (skips 30-second pull delay) +""" + +from orchestration.jobs.nersc.controller import NERSCJobController +from orchestration.jobs.nersc.login import NERSCLoginMethod +from orchestration.jobs.nersc.shifter import check_shifter_image, pull_shifter_image + + +def _make_controller(mocker, login_method=NERSCLoginMethod.SFAPI): + controller = mocker.MagicMock(spec=NERSCJobController) + controller.login_method = login_method + # client is an instance attribute (set in __init__) — not in the class spec, + # so set it explicitly here so tests can configure controller.client.compute. + controller.client = mocker.MagicMock() + controller.get_nersc_username.return_value = "testuser" + controller.submit_job.return_value = "job-123" + controller.wait_for_job.return_value = True + controller.read_remote_file.return_value = "ghcr.io/als-computing/image:latest found" + return controller + + +class TestPullShifterImage: + def test_submits_pull_script_and_returns_success(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker) + result = pull_shifter_image(controller, "docker:ghcr.io/als/image:latest") + controller.submit_job.assert_called_once() + controller.wait_for_job.assert_called_once_with("job-123") + assert result is True + + def test_script_contains_image_name(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker) + pull_shifter_image(controller, "docker:ghcr.io/als/myimage:v2") + script = controller.submit_job.call_args[0][0] + assert "docker:ghcr.io/als/myimage:v2" in script + + def test_returns_true_when_wait_false(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker) + result = pull_shifter_image(controller, "docker:image:latest", wait=False) + controller.submit_job.assert_called_once() + 
controller.wait_for_job.assert_not_called() + assert result is True + + def test_returns_false_on_submit_exception(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker) + controller.submit_job.side_effect = RuntimeError("NERSC down") + result = pull_shifter_image(controller, "docker:image:latest") + assert result is False + + def test_mkdir_remote_called_for_log_dir(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker) + pull_shifter_image(controller, "docker:image:latest") + controller.mkdir_remote.assert_called_once() + log_dir_arg = controller.mkdir_remote.call_args[0][0] + assert "testuser" in log_dir_arg + + +class TestCheckShifterImage: + def test_sfapi_returns_true_when_grep_matches(self, mocker): + controller = _make_controller(mocker, login_method=NERSCLoginMethod.SFAPI) + # Use plain MagicMock (not spec'd) for perlmutter — Machine.perlmutter + # is evaluated after the patch inside the function body. 
+ perlmutter = mocker.MagicMock() + perlmutter.run.return_value = "ghcr.io/als/image:latest found in cache" + controller.client.compute.return_value = perlmutter + + result = check_shifter_image(controller, "docker:ghcr.io/als/image:latest") + assert result is True + + def test_sfapi_returns_false_when_no_match(self, mocker): + controller = _make_controller(mocker, login_method=NERSCLoginMethod.SFAPI) + perlmutter = mocker.MagicMock() + perlmutter.run.return_value = "" + controller.client.compute.return_value = perlmutter + + result = check_shifter_image(controller, "docker:ghcr.io/als/image:latest") + assert result is False + + def test_iriapi_submits_check_job_and_reads_output(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker, login_method=NERSCLoginMethod.IRIAPI) + controller.read_remote_file.return_value = "ghcr.io/als/image:latest cached" + + result = check_shifter_image(controller, "docker:ghcr.io/als/image:latest") + controller.submit_job.assert_called_once() + controller.wait_for_job.assert_called_once_with("job-123") + controller.read_remote_file.assert_called_once() + assert result is True + + def test_iriapi_returns_false_when_image_not_in_output(self, mocker): + mocker.patch("orchestration.jobs.nersc.shifter.time.sleep") + controller = _make_controller(mocker, login_method=NERSCLoginMethod.IRIAPI) + controller.read_remote_file.return_value = "" + + result = check_shifter_image(controller, "docker:ghcr.io/als/image:latest") + assert result is False + + def test_returns_false_on_exception(self, mocker): + controller = _make_controller(mocker, login_method=NERSCLoginMethod.SFAPI) + controller.client.compute.side_effect = RuntimeError("network error") + + result = check_shifter_image(controller, "docker:image:latest") + assert result is False diff --git a/orchestration/_tests/test_jobs/test_controller.py b/orchestration/_tests/test_jobs/test_controller.py new file mode 100644 index 
00000000..505866e7 --- /dev/null +++ b/orchestration/_tests/test_jobs/test_controller.py @@ -0,0 +1,71 @@ +"""Tests for orchestration/jobs/controller.py.""" + +from orchestration.jobs.controller import JobController, JobTarget + + +# ── Structural identity check ───────────────────────────────────────────────── + +def test_hpc_alias_is_canonical_job_target(): + """HPC in bl832/job_controller.py must be the same object as JobTarget. + + Uses `is`, not `==` — equality passes even for two separate enum classes + whose members have matching names and values. `is` catches a re-introduced + duplicate enum class immediately. + """ + from orchestration.flows.bl832.job_controller import HPC + assert HPC is JobTarget + + +# ── JobTarget enum ──────────────────────────────────────────────────────────── + +class TestJobTarget: + def test_has_alcf_member(self): + assert JobTarget.ALCF is not None + + def test_has_nersc_member(self): + assert JobTarget.NERSC is not None + + def test_has_olcf_member(self): + assert JobTarget.OLCF is not None + + def test_nersc_string_value(self): + assert JobTarget.NERSC.value == "NERSC" + + def test_alcf_string_value(self): + assert JobTarget.ALCF.value == "ALCF" + + def test_olcf_string_value(self): + assert JobTarget.OLCF.value == "OLCF" + + def test_membership_by_value(self): + assert JobTarget("NERSC") is JobTarget.NERSC + + +# ── JobController ABC ───────────────────────────────────────────────────────── + +class TestJobControllerABC: + """JobController is an ABC but declares NO @abstractmethod. + + Direct instantiation succeeds when a valid config is provided. Tests here + verify the contract (stores config, ABC inheritance) rather than checking + for TypeError on instantiation. 
+ """ + + def test_instantiates_with_valid_config(self, fake_config): + controller = JobController(fake_config) + assert controller is not None + + def test_stores_config(self, fake_config): + controller = JobController(fake_config) + assert controller.config is fake_config + + def test_subclass_inherits_config(self, fake_config): + class DummyJobController(JobController): + pass + + controller = DummyJobController(fake_config) + assert controller.config is fake_config + + def test_is_abc_subclass(self): + from abc import ABC + assert issubclass(JobController, ABC) diff --git a/orchestration/_tests/test_jobs/test_options.py b/orchestration/_tests/test_jobs/test_options.py new file mode 100644 index 00000000..2ad1394c --- /dev/null +++ b/orchestration/_tests/test_jobs/test_options.py @@ -0,0 +1,169 @@ +"""Tests for orchestration/jobs/options.py — load_job_options three-layer resolution. + +Patch targets verified from source: + orchestration.jobs.options.Variable.get (prefect.variables.Variable imported at top) + orchestration.jobs.options.get_checkpoint_info (orchestration.mlflow imported at top) +""" + +import json + +from orchestration.jobs.options import load_job_options + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def _make_checkpoint(mocker, *, nersc_path="/pscratch/checkpoint.pt", inference_params=None): + cp = mocker.MagicMock() + cp.nersc_path = nersc_path + cp.inference_params = inference_params or {} + return cp + + +# ── Tests ───────────────────────────────────────────────────────────────────── + +class TestLoadJobOptions: + + # ── Layer 1: config defaults ────────────────────────────────────────────── + + def test_returns_config_defaults_when_variable_says_defaults(self, mocker): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + opts = load_job_options("some-var", {"key": "value"}) + assert opts == {"key": "value"} + + def test_returns_copy_not_original(self, mocker): 
+ mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + base = {"key": "value"} + opts = load_job_options("some-var", base) + opts["key"] = "mutated" + assert base["key"] == "value" + + # ── Layer 2: MLflow ─────────────────────────────────────────────────────── + + def test_mlflow_nersc_path_maps_to_checkpoint_key(self, mocker, fake_config): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + cp = _make_checkpoint(mocker, nersc_path="/pscratch/model.pt") + mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=cp) + + opts = load_job_options( + "var", + {"finetuned_checkpoint_path": "/old/path"}, + config=fake_config, + mlflow_model_name="my-model", + mlflow_checkpoint_key="finetuned_checkpoint_path", + ) + assert opts["finetuned_checkpoint_path"] == "/pscratch/model.pt" + + def test_mlflow_inference_params_overlay_config_defaults(self, mocker, fake_config): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + cp = _make_checkpoint(mocker, inference_params={"batch_size": 16, "threshold": 0.5}) + mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=cp) + + opts = load_job_options( + "var", + {"batch_size": 8, "threshold": 0.3, "other": "kept"}, + config=fake_config, + mlflow_model_name="my-model", + ) + assert opts["batch_size"] == 16 + assert opts["threshold"] == 0.5 + assert opts["other"] == "kept" + + def test_mlflow_layer_skipped_when_config_is_none(self, mocker): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + spy = mocker.patch("orchestration.jobs.options.get_checkpoint_info") + + opts = load_job_options("var", {"key": "value"}, config=None, mlflow_model_name="model") + spy.assert_not_called() + assert opts == {"key": "value"} + + def test_mlflow_layer_skipped_when_model_name_is_none(self, mocker, fake_config): + 
mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + spy = mocker.patch("orchestration.jobs.options.get_checkpoint_info") + + opts = load_job_options("var", {"key": "value"}, config=fake_config, mlflow_model_name=None) + spy.assert_not_called() + assert opts == {"key": "value"} + + def test_mlflow_fallback_to_config_when_checkpoint_is_none(self, mocker, fake_config): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=None) + + opts = load_job_options("var", {"key": "value"}, config=fake_config, mlflow_model_name="model") + assert opts == {"key": "value"} + + def test_mlflow_fallback_to_config_when_get_checkpoint_raises(self, mocker, fake_config): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + mocker.patch( + "orchestration.jobs.options.get_checkpoint_info", + side_effect=RuntimeError("mlflow unreachable"), + ) + + opts = load_job_options("var", {"key": "value"}, config=fake_config, mlflow_model_name="model") + assert opts == {"key": "value"} + + def test_mlflow_injects_new_keys_not_in_config(self, mocker, fake_config): + mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) + cp = _make_checkpoint(mocker, inference_params={"new_param": "injected"}) + mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=cp) + + opts = load_job_options("var", {"existing": "kept"}, config=fake_config, mlflow_model_name="model") + assert opts["new_param"] == "injected" + assert opts["existing"] == "kept" + + # ── Layer 3: Prefect Variable overrides ─────────────────────────────────── + + def test_prefect_variable_overrides_win_over_config(self, mocker): + mocker.patch( + "orchestration.jobs.options.Variable.get", + return_value={"defaults": False, "key": "override"}, + ) + opts = load_job_options("var", {"key": "config-default"}) 
+ assert opts["key"] == "override" + + def test_defaults_true_suppresses_variable_overrides(self, mocker): + mocker.patch( + "orchestration.jobs.options.Variable.get", + return_value={"defaults": True, "key": "would-override"}, + ) + opts = load_job_options("var", {"key": "config-default"}) + assert opts["key"] == "config-default" + + def test_json_string_variable_is_parsed(self, mocker): + mocker.patch( + "orchestration.jobs.options.Variable.get", + return_value=json.dumps({"defaults": False, "key": "from-json"}), + ) + opts = load_job_options("var", {"key": "config-default"}) + assert opts["key"] == "from-json" + + def test_variable_get_failure_falls_back_to_opts(self, mocker): + mocker.patch( + "orchestration.jobs.options.Variable.get", + side_effect=Exception("Prefect unavailable"), + ) + opts = load_job_options("var", {"key": "config-default"}) + assert opts == {"key": "config-default"} + + def test_prefect_variable_wins_over_mlflow(self, mocker, fake_config): + mocker.patch( + "orchestration.jobs.options.Variable.get", + return_value={"defaults": False, "batch_size": 99}, + ) + cp = _make_checkpoint(mocker, inference_params={"batch_size": 16}) + mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=cp) + + opts = load_job_options( + "var", + {"batch_size": 8}, + config=fake_config, + mlflow_model_name="model", + ) + assert opts["batch_size"] == 99 + + def test_defaults_key_not_present_in_output(self, mocker): + mocker.patch( + "orchestration.jobs.options.Variable.get", + return_value={"defaults": False, "key": "override"}, + ) + opts = load_job_options("var", {}) + assert "defaults" not in opts From 31de0f17b15ef6a753ab5e10ab0c331f40346437 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 26 May 2026 11:35:24 -0700 Subject: [PATCH 24/26] renaming dummy/fake --> mock in unit tests --- .../_tests/test_jobs/alcf/test_controller.py | 38 ++--- orchestration/_tests/test_jobs/conftest.py | 2 +- 
.../_tests/test_jobs/nersc/test_controller.py | 144 +++++++++--------- .../_tests/test_jobs/nersc/test_login.py | 24 ++- .../_tests/test_jobs/test_controller.py | 18 +-- .../_tests/test_jobs/test_options.py | 28 ++-- 6 files changed, 126 insertions(+), 128 deletions(-) diff --git a/orchestration/_tests/test_jobs/alcf/test_controller.py b/orchestration/_tests/test_jobs/alcf/test_controller.py index 3cc463a1..d8dbf352 100644 --- a/orchestration/_tests/test_jobs/alcf/test_controller.py +++ b/orchestration/_tests/test_jobs/alcf/test_controller.py @@ -27,26 +27,26 @@ # ── Init ────────────────────────────────────────────────────────────────────── class TestALCFJobControllerInit: - def test_reads_allocation_root_from_variable(self, mocker, fake_config, mock_alcf_prefect): - ctrl = ALCFJobController(fake_config) + def test_reads_allocation_root_from_variable(self, mocker, mock_config, mock_alcf_prefect): + ctrl = ALCFJobController(mock_config) assert ctrl.allocation_root == "/eagle/IRIProd/ALS" - def test_reads_endpoint_id_from_secret(self, mocker, fake_config, mock_alcf_prefect): - ctrl = ALCFJobController(fake_config) + def test_reads_endpoint_id_from_secret(self, mocker, mock_config, mock_alcf_prefect): + ctrl = ALCFJobController(mock_config) assert ctrl.endpoint_id == "mock-endpoint-uuid" - def test_variable_get_called_with_correct_name(self, mocker, fake_config, mock_alcf_prefect): - ALCFJobController(fake_config) + def test_variable_get_called_with_correct_name(self, mocker, mock_config, mock_alcf_prefect): + ALCFJobController(mock_config) mock_alcf_prefect.variable.assert_called_once_with( _ALLOCATION_ROOT_VARIABLE, _sync=True ) - def test_secret_load_called_with_correct_name(self, mocker, fake_config, mock_alcf_prefect): - ALCFJobController(fake_config) + def test_secret_load_called_with_correct_name(self, mocker, mock_config, mock_alcf_prefect): + ALCFJobController(mock_config) # mock_alcf_prefect.secret IS the mock for Secret.load (not Secret itself) 
mock_alcf_prefect.secret.assert_called_once_with(_GLOBUS_COMPUTE_ENDPOINT_SECRET) - def test_raises_when_allocation_root_missing(self, mocker, fake_config): + def test_raises_when_allocation_root_missing(self, mocker, mock_config): # allocation_data.get(...) returns None → ValueError mocker.patch( "orchestration.jobs.alcf.controller.Variable.get", @@ -54,17 +54,17 @@ def test_raises_when_allocation_root_missing(self, mocker, fake_config): ) mocker.patch("orchestration.jobs.alcf.controller.Secret.load") with pytest.raises(ValueError, match="Allocation root not found"): - ALCFJobController(fake_config) + ALCFJobController(mock_config) - def test_stores_config(self, mocker, fake_config, mock_alcf_prefect): - ctrl = ALCFJobController(fake_config) - assert ctrl.config is fake_config + def test_stores_config(self, mocker, mock_config, mock_alcf_prefect): + ctrl = ALCFJobController(mock_config) + assert ctrl.config is mock_config # ── submit ──────────────────────────────────────────────────────────────────── class TestSubmit: - def test_constructs_client_and_submits_via_executor(self, mocker, fake_config, mock_alcf_prefect): + def test_constructs_client_and_submits_via_executor(self, mocker, mock_config, mock_alcf_prefect): mock_client_cls = mocker.patch("orchestration.jobs.alcf.controller.Client") mock_executor_cls = mocker.patch("orchestration.jobs.alcf.controller.Executor") @@ -77,7 +77,7 @@ def test_constructs_client_and_submits_via_executor(self, mocker, fake_config, m def noop(): pass - ctrl = ALCFJobController(fake_config) + ctrl = ALCFJobController(mock_config) result = ctrl.submit(noop) mock_client_cls.assert_called_once() @@ -88,7 +88,7 @@ def noop(): mock_executor_instance.submit.assert_called_once() assert result is mock_future - def test_returns_future(self, mocker, fake_config, mock_alcf_prefect): + def test_returns_future(self, mocker, mock_config, mock_alcf_prefect): mocker.patch("orchestration.jobs.alcf.controller.Client") mock_executor_cls = 
mocker.patch("orchestration.jobs.alcf.controller.Executor") @@ -101,7 +101,7 @@ def test_returns_future(self, mocker, fake_config, mock_alcf_prefect): def identity(x): return x - ctrl = ALCFJobController(fake_config) + ctrl = ALCFJobController(mock_config) future = ctrl.submit(identity, 42, key="val") mock_executor_instance.submit.assert_called_once_with(identity, 42, key="val") @@ -166,12 +166,12 @@ def test_returns_false_on_timeout(self, mocker): call_count = [0] start = 1000.0 - def fake_time(): + def mock_time(): val = start + call_count[0] * 700 call_count[0] += 1 return val - mocker.patch("orchestration.jobs.alcf.controller.time.time", side_effect=fake_time) + mocker.patch("orchestration.jobs.alcf.controller.time.time", side_effect=mock_time) future = mocker.MagicMock() future.done.return_value = False # never completes diff --git a/orchestration/_tests/test_jobs/conftest.py b/orchestration/_tests/test_jobs/conftest.py index cd0c70b2..80ca8ac5 100644 --- a/orchestration/_tests/test_jobs/conftest.py +++ b/orchestration/_tests/test_jobs/conftest.py @@ -17,7 +17,7 @@ def prefect_test_fixture(): @pytest.fixture -def fake_config(): +def mock_config(): """Minimal BeamlineConfig-like namespace for tests that need a config object.""" return types.SimpleNamespace( nersc_resources={ diff --git a/orchestration/_tests/test_jobs/nersc/test_controller.py b/orchestration/_tests/test_jobs/nersc/test_controller.py index b3fad7f2..eb38ee52 100644 --- a/orchestration/_tests/test_jobs/nersc/test_controller.py +++ b/orchestration/_tests/test_jobs/nersc/test_controller.py @@ -14,15 +14,15 @@ # ── Helpers ─────────────────────────────────────────────────────────────────── -def _sfapi_controller(mocker, fake_config): +def _sfapi_controller(mocker, mock_config): client = mocker.MagicMock() user = mocker.MagicMock() user.name = "sfapiuser" client.user.return_value = user - return NERSCJobController(fake_config, client=client, login_method=NERSCLoginMethod.SFAPI) + return 
NERSCJobController(mock_config, client=client, login_method=NERSCLoginMethod.SFAPI) -def _iriapi_controller(mocker, fake_config): +def _iriapi_controller(mocker, mock_config): client = mocker.MagicMock() # POST (submit_job): json() returns the job ID dict post_response = mocker.MagicMock(is_success=True) @@ -33,47 +33,47 @@ def _iriapi_controller(mocker, fake_config): get_response.json.return_value = {"status": {"state": "completed"}} get_response.text = "" client.get.return_value = get_response - return NERSCJobController(fake_config, client=client, login_method=NERSCLoginMethod.IRIAPI) + return NERSCJobController(mock_config, client=client, login_method=NERSCLoginMethod.IRIAPI) # ── Initialization ──────────────────────────────────────────────────────────── class TestNERSCJobControllerInit: - def test_sfapi_stores_sfapi_nersc_resources(self, mocker, fake_config): - ctrl = _sfapi_controller(mocker, fake_config) - assert ctrl.nersc_resources == fake_config.nersc_resources["sfapi"] + def test_sfapi_stores_sfapi_nersc_resources(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) + assert ctrl.nersc_resources == mock_config.nersc_resources["sfapi"] - def test_iriapi_stores_iri_nersc_resources(self, mocker, fake_config): - ctrl = _iriapi_controller(mocker, fake_config) - assert ctrl.nersc_resources == fake_config.nersc_resources["iri"] + def test_iriapi_stores_iri_nersc_resources(self, mocker, mock_config): + ctrl = _iriapi_controller(mocker, mock_config) + assert ctrl.nersc_resources == mock_config.nersc_resources["iri"] - def test_stores_login_method(self, mocker, fake_config): - ctrl = _sfapi_controller(mocker, fake_config) + def test_stores_login_method(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) assert ctrl.login_method is NERSCLoginMethod.SFAPI - def test_unknown_login_method_raises(self, mocker, fake_config): + def test_unknown_login_method_raises(self, mocker, mock_config): bad_method = mocker.MagicMock() 
bad_method.__eq__ = lambda s, o: False bad_method.__ne__ = lambda s, o: True with pytest.raises(ValueError, match="Unsupported NERSCLoginMethod"): - NERSCJobController(fake_config, client=None, login_method=bad_method) + NERSCJobController(mock_config, client=None, login_method=bad_method) # ── get_nersc_username ──────────────────────────────────────────────────────── class TestGetNerscUsername: - def test_sfapi_reads_name_from_client(self, mocker, fake_config): - ctrl = _sfapi_controller(mocker, fake_config) + def test_sfapi_reads_name_from_client(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) assert ctrl.get_nersc_username() == "sfapiuser" - def test_iriapi_reads_from_env(self, mocker, fake_config, monkeypatch): + def test_iriapi_reads_from_env(self, mocker, mock_config, monkeypatch): monkeypatch.setenv("NERSC_USERNAME", "envuser") - ctrl = _iriapi_controller(mocker, fake_config) + ctrl = _iriapi_controller(mocker, mock_config) assert ctrl.get_nersc_username() == "envuser" - def test_iriapi_raises_when_env_unset(self, mocker, fake_config, monkeypatch): + def test_iriapi_raises_when_env_unset(self, mocker, mock_config, monkeypatch): monkeypatch.delenv("NERSC_USERNAME", raising=False) - ctrl = _iriapi_controller(mocker, fake_config) + ctrl = _iriapi_controller(mocker, mock_config) with pytest.raises(ValueError, match="NERSC_USERNAME must be set"): ctrl.get_nersc_username() @@ -81,16 +81,16 @@ def test_iriapi_raises_when_env_unset(self, mocker, fake_config, monkeypatch): # ── submit_job ──────────────────────────────────────────────────────────────── class TestSubmitJob: - def test_sfapi_returns_job_id_string(self, mocker, fake_config): - ctrl = _sfapi_controller(mocker, fake_config) + def test_sfapi_returns_job_id_string(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) job = mocker.MagicMock() job.jobid = 12345 ctrl.client.compute.return_value.submit_job.return_value = job result = 
ctrl.submit_job("#!/bin/bash\n#SBATCH -q debug\necho hi") assert result == "12345" - def test_sfapi_calls_perlmutter_submit_job(self, mocker, fake_config): - ctrl = _sfapi_controller(mocker, fake_config) + def test_sfapi_calls_perlmutter_submit_job(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) job = mocker.MagicMock() job.jobid = "abc" perlmutter = ctrl.client.compute.return_value @@ -98,14 +98,14 @@ def test_sfapi_calls_perlmutter_submit_job(self, mocker, fake_config): ctrl.submit_job("script") perlmutter.submit_job.assert_called_once_with("script") - def test_iriapi_returns_job_id_string(self, mocker, fake_config): - ctrl = _iriapi_controller(mocker, fake_config) + def test_iriapi_returns_job_id_string(self, mocker, mock_config): + ctrl = _iriapi_controller(mocker, mock_config) script = "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" result = ctrl.submit_job(script) assert result == "job-99" - def test_iriapi_posts_to_job_submit_url(self, mocker, fake_config): - ctrl = _iriapi_controller(mocker, fake_config) + def test_iriapi_posts_to_job_submit_url(self, mocker, mock_config): + ctrl = _iriapi_controller(mocker, mock_config) script = "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" ctrl.submit_job(script) call_args = ctrl.client.post.call_args @@ -117,8 +117,8 @@ def test_iriapi_posts_to_job_submit_url(self, mocker, fake_config): class TestSubmitJobIRIAPI: """Tests the SBATCH header parsing logic in _submit_job_iriapi.""" - def _submit_and_capture_spec(self, mocker, fake_config, script): - ctrl = _iriapi_controller(mocker, fake_config) + def _submit_and_capture_spec(self, mocker, mock_config, script): + ctrl = _iriapi_controller(mocker, mock_config) captured = {} def capture_post(url, json=None, **kwargs): @@ -131,105 +131,105 @@ def capture_post(url, json=None, **kwargs): ctrl._submit_job_iriapi(script) return captured["json"] - def 
test_parses_queue_name(self, mocker, fake_config): + def test_parses_queue_name(self, mocker, mock_config): script = ( "#!/bin/bash\n#SBATCH -q premium\n#SBATCH -A als\n" "#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" ) - spec = self._submit_and_capture_spec(mocker, fake_config, script) + spec = self._submit_and_capture_spec(mocker, mock_config, script) assert spec["attributes"]["queue_name"] == "premium" - def test_parses_account(self, mocker, fake_config): + def test_parses_account(self, mocker, mock_config): script = ( "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A myproject\n" "#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" ) - spec = self._submit_and_capture_spec(mocker, fake_config, script) + spec = self._submit_and_capture_spec(mocker, mock_config, script) assert spec["attributes"]["account"] == "myproject" - def test_parses_walltime_to_seconds(self, mocker, fake_config): + def test_parses_walltime_to_seconds(self, mocker, mock_config): script = ( "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" "#SBATCH --time=01:30:00\n#SBATCH -N 1\necho hi" ) - spec = self._submit_and_capture_spec(mocker, fake_config, script) + spec = self._submit_and_capture_spec(mocker, mock_config, script) assert spec["attributes"]["duration"] == 5400 # 1h30m in seconds - def test_parses_node_count(self, mocker, fake_config): + def test_parses_node_count(self, mocker, mock_config): script = ( "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" "#SBATCH --time=00:10:00\n#SBATCH -N 4\necho hi" ) - spec = self._submit_and_capture_spec(mocker, fake_config, script) + spec = self._submit_and_capture_spec(mocker, mock_config, script) assert spec["resources"]["node_count"] == 4 - def test_cpu_constraint_adds_cpu_cores(self, mocker, fake_config): + def test_cpu_constraint_adds_cpu_cores(self, mocker, mock_config): script = ( "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" "#SBATCH --time=00:10:00\n#SBATCH -N 1\n#SBATCH -C cpu\necho hi" ) - spec = self._submit_and_capture_spec(mocker, 
fake_config, script) + spec = self._submit_and_capture_spec(mocker, mock_config, script) assert "cpu_cores_per_process" in spec["resources"] assert "gpu_cores_per_process" not in spec["resources"] - def test_gpu_constraint_adds_gpu_cores(self, mocker, fake_config): + def test_gpu_constraint_adds_gpu_cores(self, mocker, mock_config): script = ( "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" "#SBATCH --time=00:10:00\n#SBATCH -N 1\n#SBATCH -C gpu\necho hi" ) - spec = self._submit_and_capture_spec(mocker, fake_config, script) + spec = self._submit_and_capture_spec(mocker, mock_config, script) assert "gpu_cores_per_process" in spec["resources"] assert "cpu_cores_per_process" not in spec["resources"] - def test_reservation_included_when_present(self, mocker, fake_config): + def test_reservation_included_when_present(self, mocker, mock_config): script = ( "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n" "#SBATCH --time=00:10:00\n#SBATCH -N 1\n#SBATCH --reservation=myres\necho hi" ) - spec = self._submit_and_capture_spec(mocker, fake_config, script) + spec = self._submit_and_capture_spec(mocker, mock_config, script) assert spec["attributes"]["reservation_id"] == "myres" - def test_no_reservation_when_absent(self, mocker, fake_config): + def test_no_reservation_when_absent(self, mocker, mock_config): script = "#!/bin/bash\n#SBATCH -q debug\n#SBATCH -A als\n#SBATCH --time=00:10:00\n#SBATCH -N 1\necho hi" - spec = self._submit_and_capture_spec(mocker, fake_config, script) + spec = self._submit_and_capture_spec(mocker, mock_config, script) assert "reservation_id" not in spec["attributes"] # ── wait_for_job ────────────────────────────────────────────────────────────── class TestWaitForJob: - def test_sfapi_returns_true_on_complete(self, mocker, fake_config): - ctrl = _sfapi_controller(mocker, fake_config) + def test_sfapi_returns_true_on_complete(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) job = mocker.MagicMock() 
ctrl.client.compute.return_value.job.return_value = job result = ctrl.wait_for_job("12345") job.complete.assert_called_once() assert result is True - def test_iriapi_returns_true_when_completed(self, mocker, fake_config): + def test_iriapi_returns_true_when_completed(self, mocker, mock_config): mocker.patch("orchestration.jobs.nersc.controller.time.sleep") - ctrl = _iriapi_controller(mocker, fake_config) + ctrl = _iriapi_controller(mocker, mock_config) ctrl.client.get.return_value.json.return_value = {"status": {"state": "completed"}} result = ctrl.wait_for_job("42") assert result is True - def test_iriapi_returns_false_on_failed(self, mocker, fake_config): + def test_iriapi_returns_false_on_failed(self, mocker, mock_config): mocker.patch("orchestration.jobs.nersc.controller.time.sleep") - ctrl = _iriapi_controller(mocker, fake_config) + ctrl = _iriapi_controller(mocker, mock_config) ctrl.client.get.return_value.json.return_value = {"status": {"state": "failed"}} result = ctrl.wait_for_job("42") assert result is False - def test_iriapi_returns_false_on_canceled(self, mocker, fake_config): + def test_iriapi_returns_false_on_canceled(self, mocker, mock_config): mocker.patch("orchestration.jobs.nersc.controller.time.sleep") - ctrl = _iriapi_controller(mocker, fake_config) + ctrl = _iriapi_controller(mocker, mock_config) ctrl.client.get.return_value.json.return_value = {"status": {"state": "canceled"}} result = ctrl.wait_for_job("42") assert result is False - def test_iriapi_polls_until_terminal_state(self, mocker, fake_config): + def test_iriapi_polls_until_terminal_state(self, mocker, mock_config): mocker.patch("orchestration.jobs.nersc.controller.time.sleep") - ctrl = _iriapi_controller(mocker, fake_config) + ctrl = _iriapi_controller(mocker, mock_config) responses = [ {"status": {"state": "running"}}, {"status": {"state": "running"}}, @@ -244,23 +244,23 @@ def test_iriapi_polls_until_terminal_state(self, mocker, fake_config): # ── mkdir_remote 
────────────────────────────────────────────────────────────── class TestMkdirRemote: - def test_sfapi_runs_mkdir(self, mocker, fake_config): - ctrl = _sfapi_controller(mocker, fake_config) + def test_sfapi_runs_mkdir(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) ctrl.mkdir_remote("/pscratch/sd/t/testuser/mydir") ctrl.client.compute.return_value.run.assert_called_once() cmd = ctrl.client.compute.return_value.run.call_args[0][0] assert "mkdir -p" in cmd assert "/pscratch/sd/t/testuser/mydir" in cmd - def test_iriapi_posts_to_mkdir_url(self, mocker, fake_config): - ctrl = _iriapi_controller(mocker, fake_config) + def test_iriapi_posts_to_mkdir_url(self, mocker, mock_config): + ctrl = _iriapi_controller(mocker, mock_config) ctrl.mkdir_remote("/pscratch/sd/t/testuser/mydir") ctrl.client.post.assert_called() url = ctrl.client.post.call_args[0][0] assert "mock-login-uuid" in url - def test_iriapi_posts_path_in_body(self, mocker, fake_config): - ctrl = _iriapi_controller(mocker, fake_config) + def test_iriapi_posts_path_in_body(self, mocker, mock_config): + ctrl = _iriapi_controller(mocker, mock_config) ctrl.mkdir_remote("/some/path") body = ctrl.client.post.call_args[1]["json"] assert body["path"] == "/some/path" @@ -270,23 +270,23 @@ def test_iriapi_posts_path_in_body(self, mocker, fake_config): # ── read_remote_file ────────────────────────────────────────────────────────── class TestReadRemoteFile: - def test_sfapi_returns_string_result(self, mocker, fake_config): - ctrl = _sfapi_controller(mocker, fake_config) + def test_sfapi_returns_string_result(self, mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) ctrl.client.compute.return_value.run.return_value = "file contents" result = ctrl.read_remote_file("/some/file.txt") assert result == "file contents" - def test_sfapi_extracts_output_attribute(self, mocker, fake_config): - ctrl = _sfapi_controller(mocker, fake_config) + def test_sfapi_extracts_output_attribute(self, 
mocker, mock_config): + ctrl = _sfapi_controller(mocker, mock_config) run_result = mocker.MagicMock(spec=[]) # no __str__ shortcuts run_result.output = "from output attr" ctrl.client.compute.return_value.run.return_value = run_result result = ctrl.read_remote_file("/some/file.txt") assert result == "from output attr" - def test_iriapi_returns_file_contents_on_completed_task(self, mocker, fake_config): + def test_iriapi_returns_file_contents_on_completed_task(self, mocker, mock_config): mocker.patch("orchestration.jobs.nersc.controller.time.sleep") - ctrl = _iriapi_controller(mocker, fake_config) + ctrl = _iriapi_controller(mocker, mock_config) # First call: GET /filesystem/view → task_id # Subsequent calls: GET /task/ → status=completed + result @@ -301,9 +301,9 @@ def test_iriapi_returns_file_contents_on_completed_task(self, mocker, fake_confi result = ctrl.read_remote_file("/pscratch/data.txt") assert result == "file data" - def test_iriapi_raises_on_failed_task(self, mocker, fake_config): + def test_iriapi_raises_on_failed_task(self, mocker, mock_config): mocker.patch("orchestration.jobs.nersc.controller.time.sleep") - ctrl = _iriapi_controller(mocker, fake_config) + ctrl = _iriapi_controller(mocker, mock_config) view_response = mocker.MagicMock(is_success=True) view_response.json.return_value = {"task_id": "task-fail"} @@ -316,9 +316,9 @@ def test_iriapi_raises_on_failed_task(self, mocker, fake_config): with pytest.raises(RuntimeError, match="failed"): ctrl.read_remote_file("/pscratch/data.txt") - def test_iriapi_raises_timeout_after_40_polls(self, mocker, fake_config): + def test_iriapi_raises_timeout_after_40_polls(self, mocker, mock_config): mocker.patch("orchestration.jobs.nersc.controller.time.sleep") - ctrl = _iriapi_controller(mocker, fake_config) + ctrl = _iriapi_controller(mocker, mock_config) view_response = mocker.MagicMock(is_success=True) view_response.json.return_value = {"task_id": "task-slow"} diff --git 
a/orchestration/_tests/test_jobs/nersc/test_login.py b/orchestration/_tests/test_jobs/nersc/test_login.py index c43538d7..32f32bfc 100644 --- a/orchestration/_tests/test_jobs/nersc/test_login.py +++ b/orchestration/_tests/test_jobs/nersc/test_login.py @@ -11,8 +11,6 @@ That choice is intentional. """ -import pytest - from orchestration.jobs.nersc.login import NERSCLoginMethod, create_nersc_client @@ -48,38 +46,38 @@ def test_membership_iriapi(self): # ── create_nersc_client dispatch ────────────────────────────────────────────── class TestCreateNerscClient: - def test_sfapi_dispatches_to_sfapi_builder(self, mocker, fake_config): + def test_sfapi_dispatches_to_sfapi_builder(self, mocker, mock_config): mock_client = mocker.MagicMock() builder = mocker.patch( "orchestration.jobs.nersc.login._create_sfapi_client", return_value=mock_client, ) - result = create_nersc_client(fake_config, NERSCLoginMethod.SFAPI) + result = create_nersc_client(mock_config, NERSCLoginMethod.SFAPI) builder.assert_called_once_with() assert result is mock_client - def test_iriapi_dispatches_to_iriapi_builder(self, mocker, fake_config): + def test_iriapi_dispatches_to_iriapi_builder(self, mocker, mock_config): mock_client = mocker.MagicMock() builder = mocker.patch( "orchestration.jobs.nersc.login._create_iriapi_client", return_value=mock_client, ) - result = create_nersc_client(fake_config, NERSCLoginMethod.IRIAPI) - builder.assert_called_once_with(fake_config.nersc_resources["iri"]["api_base_url"]) + result = create_nersc_client(mock_config, NERSCLoginMethod.IRIAPI) + builder.assert_called_once_with(mock_config.nersc_resources["iri"]["api_base_url"]) assert result is mock_client - def test_sfapi_passes_api_base_url_from_config(self, mocker, fake_config): + def test_sfapi_passes_api_base_url_from_config(self, mocker, mock_config): mocker.patch("orchestration.jobs.nersc.login._create_sfapi_client") # No assertion on URL for SFAPI (the builder doesn't take a URL arg), # but create_nersc_client 
must read the sfapi sub-dict without raising. - create_nersc_client(fake_config, NERSCLoginMethod.SFAPI) + create_nersc_client(mock_config, NERSCLoginMethod.SFAPI) - def test_iriapi_passes_correct_api_base_url(self, mocker, fake_config): + def test_iriapi_passes_correct_api_base_url(self, mocker, mock_config): builder = mocker.patch("orchestration.jobs.nersc.login._create_iriapi_client") - create_nersc_client(fake_config, NERSCLoginMethod.IRIAPI) + create_nersc_client(mock_config, NERSCLoginMethod.IRIAPI) builder.assert_called_once_with("https://mock-iri.nersc.gov") - def test_default_login_method_is_iriapi(self, mocker, fake_config): + def test_default_login_method_is_iriapi(self, mocker, mock_config): builder = mocker.patch("orchestration.jobs.nersc.login._create_iriapi_client") - create_nersc_client(fake_config) + create_nersc_client(mock_config) builder.assert_called_once() diff --git a/orchestration/_tests/test_jobs/test_controller.py b/orchestration/_tests/test_jobs/test_controller.py index 505866e7..5e9a88e6 100644 --- a/orchestration/_tests/test_jobs/test_controller.py +++ b/orchestration/_tests/test_jobs/test_controller.py @@ -51,20 +51,20 @@ class TestJobControllerABC: for TypeError on instantiation. 
""" - def test_instantiates_with_valid_config(self, fake_config): - controller = JobController(fake_config) + def test_instantiates_with_valid_config(self, mock_config): + controller = JobController(mock_config) assert controller is not None - def test_stores_config(self, fake_config): - controller = JobController(fake_config) - assert controller.config is fake_config + def test_stores_config(self, mock_config): + controller = JobController(mock_config) + assert controller.config is mock_config - def test_subclass_inherits_config(self, fake_config): - class DummyJobController(JobController): + def test_subclass_inherits_config(self, mock_config): + class MockJobController(JobController): pass - controller = DummyJobController(fake_config) - assert controller.config is fake_config + controller = MockJobController(mock_config) + assert controller.config is mock_config def test_is_abc_subclass(self): from abc import ABC diff --git a/orchestration/_tests/test_jobs/test_options.py b/orchestration/_tests/test_jobs/test_options.py index 2ad1394c..a4a29ec9 100644 --- a/orchestration/_tests/test_jobs/test_options.py +++ b/orchestration/_tests/test_jobs/test_options.py @@ -39,7 +39,7 @@ def test_returns_copy_not_original(self, mocker): # ── Layer 2: MLflow ─────────────────────────────────────────────────────── - def test_mlflow_nersc_path_maps_to_checkpoint_key(self, mocker, fake_config): + def test_mlflow_nersc_path_maps_to_checkpoint_key(self, mocker, mock_config): mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) cp = _make_checkpoint(mocker, nersc_path="/pscratch/model.pt") mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=cp) @@ -47,13 +47,13 @@ def test_mlflow_nersc_path_maps_to_checkpoint_key(self, mocker, fake_config): opts = load_job_options( "var", {"finetuned_checkpoint_path": "/old/path"}, - config=fake_config, + config=mock_config, mlflow_model_name="my-model", 
mlflow_checkpoint_key="finetuned_checkpoint_path", ) assert opts["finetuned_checkpoint_path"] == "/pscratch/model.pt" - def test_mlflow_inference_params_overlay_config_defaults(self, mocker, fake_config): + def test_mlflow_inference_params_overlay_config_defaults(self, mocker, mock_config): mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) cp = _make_checkpoint(mocker, inference_params={"batch_size": 16, "threshold": 0.5}) mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=cp) @@ -61,7 +61,7 @@ def test_mlflow_inference_params_overlay_config_defaults(self, mocker, fake_conf opts = load_job_options( "var", {"batch_size": 8, "threshold": 0.3, "other": "kept"}, - config=fake_config, + config=mock_config, mlflow_model_name="my-model", ) assert opts["batch_size"] == 16 @@ -76,37 +76,37 @@ def test_mlflow_layer_skipped_when_config_is_none(self, mocker): spy.assert_not_called() assert opts == {"key": "value"} - def test_mlflow_layer_skipped_when_model_name_is_none(self, mocker, fake_config): + def test_mlflow_layer_skipped_when_model_name_is_none(self, mocker, mock_config): mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) spy = mocker.patch("orchestration.jobs.options.get_checkpoint_info") - opts = load_job_options("var", {"key": "value"}, config=fake_config, mlflow_model_name=None) + opts = load_job_options("var", {"key": "value"}, config=mock_config, mlflow_model_name=None) spy.assert_not_called() assert opts == {"key": "value"} - def test_mlflow_fallback_to_config_when_checkpoint_is_none(self, mocker, fake_config): + def test_mlflow_fallback_to_config_when_checkpoint_is_none(self, mocker, mock_config): mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=None) - opts = load_job_options("var", {"key": "value"}, config=fake_config, mlflow_model_name="model") 
+ opts = load_job_options("var", {"key": "value"}, config=mock_config, mlflow_model_name="model") assert opts == {"key": "value"} - def test_mlflow_fallback_to_config_when_get_checkpoint_raises(self, mocker, fake_config): + def test_mlflow_fallback_to_config_when_get_checkpoint_raises(self, mocker, mock_config): mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) mocker.patch( "orchestration.jobs.options.get_checkpoint_info", side_effect=RuntimeError("mlflow unreachable"), ) - opts = load_job_options("var", {"key": "value"}, config=fake_config, mlflow_model_name="model") + opts = load_job_options("var", {"key": "value"}, config=mock_config, mlflow_model_name="model") assert opts == {"key": "value"} - def test_mlflow_injects_new_keys_not_in_config(self, mocker, fake_config): + def test_mlflow_injects_new_keys_not_in_config(self, mocker, mock_config): mocker.patch("orchestration.jobs.options.Variable.get", return_value={"defaults": True}) cp = _make_checkpoint(mocker, inference_params={"new_param": "injected"}) mocker.patch("orchestration.jobs.options.get_checkpoint_info", return_value=cp) - opts = load_job_options("var", {"existing": "kept"}, config=fake_config, mlflow_model_name="model") + opts = load_job_options("var", {"existing": "kept"}, config=mock_config, mlflow_model_name="model") assert opts["new_param"] == "injected" assert opts["existing"] == "kept" @@ -144,7 +144,7 @@ def test_variable_get_failure_falls_back_to_opts(self, mocker): opts = load_job_options("var", {"key": "config-default"}) assert opts == {"key": "config-default"} - def test_prefect_variable_wins_over_mlflow(self, mocker, fake_config): + def test_prefect_variable_wins_over_mlflow(self, mocker, mock_config): mocker.patch( "orchestration.jobs.options.Variable.get", return_value={"defaults": False, "batch_size": 99}, @@ -155,7 +155,7 @@ def test_prefect_variable_wins_over_mlflow(self, mocker, fake_config): opts = load_job_options( "var", {"batch_size": 8}, 
- config=fake_config, + config=mock_config, mlflow_model_name="model", ) assert opts["batch_size"] == 99 From d583480a7c72cb300101a79f801ab1d6e3aa452e Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 27 May 2026 11:38:14 -0700 Subject: [PATCH 25/26] polishing up post-rebase --- orchestration/flows/bl832/nersc.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 5a1d8846..9d05b8ef 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -67,7 +67,7 @@ def __init__( self, config: Config832, client: Client | httpx.Client | None = None, - login_method: NERSCLoginMethod = NERSCLoginMethod.IRIAPI, + login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, ) -> None: NERSCJobController.__init__(self, config, client, login_method) TomographyHPCController.__init__(self, config) @@ -1337,14 +1337,6 @@ def nersc_recon_flow( nersc_to_beegfs_zarr_future.result() logger.info("All transfers complete.") - # Register the reconstructed TIFFs in tiled - register_file_to_tiled( - path=Path(config.beegfs_scratch.root_path+tiff_file_path), - prefix="beamlines/bl832/scratch", - overwrite=False, - tags=["scratch", "bl832"], - ) - # Register the reconstructed ZARRs in tiled register_file_to_tiled( path=Path(config.beegfs_scratch.root_path+zarr_file_path), @@ -1383,6 +1375,7 @@ def nersc_petiole_segment_flow( :param file_path: The path to the file to be processed. :param config: Configuration object for the flow. :param num_nodes: Number of nodes for reconstruction. + :param login_method: Method to use for logging into NERSC (default: SFAPI). :return: True if reconstruction and at least one segmentation task succeeded. """ logger = get_run_logger() @@ -2009,7 +2002,7 @@ def nersc_multiresolution_task( :param file_path: Path to the reconstructed data folder to be processed. :param config: Configuration object for the flow. 
- :param login_method: NERSC API to authenticate against. + :param login_method: Method to use for logging into NERSC (default: SFAPI). :return: True if the task completed successfully, False otherwise. """ logger = get_run_logger() @@ -2077,7 +2070,7 @@ def nersc_segmentation_sam3_task( tomography_controller = get_controller( hpc_type=HPC.NERSC, config=config, - login_method=NERSCLoginMethod.IRIAPI + login_method=login_method ) logger.info(f"Starting NERSC segmentation task for {recon_folder_path=}") nersc_segmentation_success = tomography_controller.segmentation_sam3( From e4e1b5a4f352d0a633d196020f8477b30c566c3a Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 27 May 2026 11:49:44 -0700 Subject: [PATCH 26/26] removing reservation from config --- config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.yml b/config.yml index c8935df7..fb0b9989 100644 --- a/config.yml +++ b/config.yml @@ -225,7 +225,7 @@ hpc_submission_settings832: # ── SLURM resource allocation ───────────────────────────────────────────── qos: realtime account: als - reservation: "_CAP_SYNAPS_LIVEDEMO_CPU2" + reservation: "" cpus-per-task: 128 walltime: "0:15:00"