diff --git a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py index f33c71964..2d988afa0 100644 --- a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py @@ -73,13 +73,76 @@ def gen_exec_command(self) -> str: launcher_py = (mbridge_repo_path / "scripts" / "performance" / "setup_experiment.py").absolute() - parts = self._build_launcher_parts(args, tdef, mbridge_repo_path, launcher_py) + pre_hook_sbatch_path: Optional[Path] = None + base_slurm_params: str = "" + if self.test_run.pre_test: + pre_hook_sbatch_path = self._gen_pre_hook_sbatch() + parts = self._build_launcher_parts(args, tdef, mbridge_repo_path, launcher_py, include_slurm_params=False) + base_slurm_params = ";".join(self._collect_additional_slurm_params()) + else: + parts = self._build_launcher_parts(args, tdef, mbridge_repo_path, launcher_py) + launcher_python = str((venv_path / "bin" / "python").absolute()) - full_cmd = self._wrap_launcher_for_job_id_and_quiet_output(" ".join(parts), launcher_python) + full_cmd = self._wrap_launcher_for_job_id_and_quiet_output( + " ".join(parts), + launcher_python, + pre_hook_sbatch_path=pre_hook_sbatch_path, + base_slurm_params=base_slurm_params, + ) self._write_command_to_file(full_cmd, self.test_run.output_path) return full_cmd + def _collect_additional_slurm_params(self) -> list[str]: + """Return the additional_slurm_params list (without dependency).""" + params: list[str] = [] + if self.system.gpus_per_node and self.system.supports_gpu_directives: + params.append(f"gpus-per-node={self.system.gpus_per_node}") + params.append(f"gres=gpu:{self.system.gpus_per_node}") + _, node_list = self.get_cached_nodes_spec() + if node_list: + params.append(f"nodelist={','.join(node_list)}") + elif self.test_run.exclude_nodes: + params.append(f"exclude={','.join(self.test_run.exclude_nodes)}") + for source in (self.system.extra_srun_args, self.test_run.extra_srun_args): + if source: + params.extend(self._parse_srun_args_as_slurm_params(source)) + return params + + def _gen_pre_hook_sbatch(self) -> Path: + """Generate a standalone sbatch script for pre-hook tests; return its path.""" + pre_hook_output = self.test_run.output_path / "pre_hook" + pre_hook_output.mkdir(parents=True, exist_ok=True) + + for tr in self.test_run.pre_test.test_runs: + tr.num_nodes = self.test_run.nnodes + + pre_hook_cmds = self.gen_pre_test(self.test_run.pre_test, self.test_run.output_path) + + num_nodes, node_list = self.get_cached_nodes_spec() + sbatch_lines = [ + "#!/bin/bash", + "# Pre-hook sbatch generated by CloudAI", + f"#SBATCH --job-name=pre_hook_{self.job_name()}", + f"#SBATCH --output={pre_hook_output.absolute() / 'stdout.txt'}", + f"#SBATCH --error={pre_hook_output.absolute() / 'stderr.txt'}", + f"#SBATCH --partition={self.system.default_partition}", + ] + if self.system.account: + sbatch_lines.append(f"#SBATCH --account={self.system.account}") + if node_list: + sbatch_lines.append(f"#SBATCH --nodelist={','.join(node_list)}") + elif num_nodes: + sbatch_lines.append(f"#SBATCH --nodes={num_nodes}") + if self.test_run.time_limit: + sbatch_lines.append(f"#SBATCH --time={self.test_run.time_limit}") + sbatch_lines.extend(["", pre_hook_cmds]) + + sbatch_path = self.test_run.output_path / "pre_hook_sbatch_script.sh" + sbatch_path.write_text("\n".join(sbatch_lines)) + sbatch_path.chmod(sbatch_path.stat().st_mode | stat.S_IXUSR) + return sbatch_path + def store_test_run(self) -> None: test_cmd = self.gen_exec_command() trd = TestRunDetails.from_test_run(self.test_run, test_cmd=test_cmd, full_cmd=test_cmd) @@ -166,12 +229,21 @@ def _normalize_cuda_graph_scope_arg(self, val: Any) -> str: parts = [p.strip().strip("\"'") for p in s.split(",") if p.strip()] return ",".join(parts) - def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str, launcher_python: str) -> str: + def _wrap_launcher_for_job_id_and_quiet_output( + self, + launcher_cmd: str, + launcher_python: str, + pre_hook_sbatch_path: Optional[Path] = None, + base_slurm_params: str = "", + ) -> str: """ Run the Megatron-Bridge launcher quietly and ensure CloudAI can parse a job ID. CloudAI's SlurmRunner expects stdout to include "Submitted batch job ". This writes a readable wrapper script (with section breaks) into the test output directory, then runs it. + + If pre_hook_sbatch_path is provided, the pre-hook sbatch is submitted first and its job ID is used as + a Slurm dependency (afterok) for the main training job, so training only starts if the pre-hook passed. """ output_dir = self.test_run.output_path.absolute() output_dir.mkdir(parents=True, exist_ok=True) @@ -181,6 +253,28 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str, launcher container_runtime_exports = self._container_runtime_env_exports() + pre_hook_lines: list[str] = [] + launch_line: str + if pre_hook_sbatch_path is not None: + pre_hook_lines = [ + f'PRE_HOOK_SBATCH="{pre_hook_sbatch_path.absolute()}"', + 'PRE_HOOK_OUTPUT=$(sbatch "$PRE_HOOK_SBATCH" 2>&1)', + 'PRE_HOOK_JOB_ID=$(echo "$PRE_HOOK_OUTPUT" | grep -Eo "Submitted batch job [0-9]+" | grep -Eo "[0-9]+" | tail -n1 || true)', # noqa: E501 + 'if [ -z "$PRE_HOOK_JOB_ID" ]; then', + ' echo "Failed to submit pre-hook job: $PRE_HOOK_OUTPUT" >&2', + " exit 1", + "fi", + 'echo "Submitted pre-hook batch job $PRE_HOOK_JOB_ID"', + f'ADDITIONAL_SLURM_PARAMS="{base_slurm_params}"', + 'if [ -n "$PRE_HOOK_JOB_ID" ]; then', + ' ADDITIONAL_SLURM_PARAMS="${ADDITIONAL_SLURM_PARAMS};dependency=afterok:${PRE_HOOK_JOB_ID}"', + "fi", + "", + ] + launch_line = f'{launcher_cmd} --additional_slurm_params "$ADDITIONAL_SLURM_PARAMS" >>"$LOG" 2>&1 || LAUNCH_RC=$?' # noqa: E501 + else: + launch_line = f'{launcher_cmd} >>"$LOG" 2>&1 || LAUNCH_RC=$?' + script_lines = [ "#!/usr/bin/env bash", "set -o pipefail", @@ -195,6 +289,7 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str, launcher "", *container_runtime_exports, "", + *pre_hook_lines, ': >"$LOG"', "WANDB_INSTALL_RC=0", f'{shlex.quote(launcher_python)} -m pip install wandb numpy==1.26.4 >>"$LOG" 2>&1 || WANDB_INSTALL_RC=$?', @@ -205,7 +300,7 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str, launcher "fi", "", "LAUNCH_RC=0", - f'{launcher_cmd} >>"$LOG" 2>&1 || LAUNCH_RC=$?', + launch_line, "", # Parse job id from Megatron-Bridge output (multiple possible formats) # Patterns: "Submitted batch job 694112", "Job id: 694112", "- Job id: 694112", "Job ID: 694112" @@ -247,7 +342,12 @@ def _list_or_comma_str(self, val: str | list[str] | None) -> Optional[str]: raise RuntimeError("Unexpected sweeps list. At this point code expects scalars only") def _build_launcher_parts( # noqa: C901 - self, args: MegatronBridgeCmdArgs, tdef: MegatronBridgeTestDefinition, repo_path: Path, launcher_py: Path + self, + args: MegatronBridgeCmdArgs, + tdef: MegatronBridgeTestDefinition, + repo_path: Path, + launcher_py: Path, + include_slurm_params: bool = True, ) -> list[str]: fields_set = args.model_fields_set force_fields = { @@ -451,25 +551,10 @@ def add_field(field: str, flag: str, value: Any) -> None: add_field("nsys_trace", "--nsys_trace", self._list_or_comma_str(args.nsys_trace)) add_field("nsys_extra_args", "--nsys_extra_args", self._list_or_comma_str(args.nsys_extra_args)) - additional_slurm_params: list[str] = [] - - if self.system.gpus_per_node and self.system.supports_gpu_directives: - additional_slurm_params.append(f"gpus-per-node={self.system.gpus_per_node}") - additional_slurm_params.append(f"gres=gpu:{self.system.gpus_per_node}") - - _, node_list = self.get_cached_nodes_spec() - if node_list: - nodelist_str = ",".join(node_list) - additional_slurm_params.append(f"nodelist={nodelist_str}") - elif self.test_run.exclude_nodes: - additional_slurm_params.append(f"exclude={','.join(self.test_run.exclude_nodes)}") - - for source in (self.system.extra_srun_args, self.test_run.extra_srun_args): - if source: - additional_slurm_params.extend(self._parse_srun_args_as_slurm_params(source)) - - if additional_slurm_params: - parts.extend(["--additional_slurm_params", shlex.quote(";".join(additional_slurm_params))]) + if include_slurm_params: + additional_slurm_params = self._collect_additional_slurm_params() + if additional_slurm_params: + parts.extend(["--additional_slurm_params", shlex.quote(";".join(additional_slurm_params))]) # Config variant add_field("config_variant", "-cv", args.config_variant)