1010from typing import Callable , List , Optional , Set , Union
1111
1212from .. import case , common
13- from ..build import MFCTarget , get_target
13+ from ..build import SIMULATION , MFCTarget , get_target
1414from ..run import input
1515from ..state import ARG
1616
@@ -154,11 +154,13 @@ class TestCase(case.Case):
154154 ppn : int
155155 trace : str
156156 override_tol : Optional [float ] = None
157+ restart_check : bool = False
157158
158- def __init__ (self , trace : str , mods : dict , ppn : int = None , override_tol : float = None ) -> None :
159+ def __init__ (self , trace : str , mods : dict , ppn : int = None , override_tol : float = None , restart_check : bool = False ) -> None :
159160 self .trace = trace
160161 self .ppn = ppn or 1
161162 self .override_tol = override_tol
163+ self .restart_check = restart_check
162164 super ().__init__ ({** BASE_CFG .copy (), ** mods })
163165
164166 def run (self , targets : List [Union [str , MFCTarget ]], gpus : Set [int ]) -> subprocess .CompletedProcess :
@@ -183,6 +185,55 @@ def run(self, targets: List[Union[str, MFCTarget]], gpus: Set[int]) -> subproces
183185
184186 return common .system (command , print_cmd = False , text = True , stdout = subprocess .PIPE , stderr = subprocess .STDOUT )
185187
188+ def run_restart (self , targets , gpus ):
189+ """Run a restart roundtrip: simulate to midpoint, then restart to end."""
190+ # NOTE: This method overrides t_step_save to produce exactly one save
191+ # per phase (at the boundary step). Tests using restart_check=True
192+ # must not rely on custom t_step_save values, as the straight run's
193+ # save points would not match the restart run's.
194+ mid_step = (self .params ["t_step_start" ] + self .params ["t_step_stop" ]) // 2
195+ if mid_step <= self .params ["t_step_start" ]:
196+ raise common .MFCException (
197+ f"run_restart: t_step_stop ({ self .params ['t_step_stop' ]} ) is too close to t_step_start ({ self .params ['t_step_start' ]} ) for a restart roundtrip (need t_step_stop - t_step_start >= 2)."
198+ )
199+ orig = dict (self .params )
200+
201+ try :
202+ self .delete_output ()
203+
204+ # Phase 1: Run to midpoint (generates restart data)
205+ self .params = {** orig , "t_step_stop" : mid_step , "t_step_save" : mid_step - orig ["t_step_start" ]}
206+ self .create_directory ()
207+ result1 = self .run (targets , gpus )
208+ if result1 .returncode != 0 :
209+ return result1
210+
211+ # Keep D/ (has steps 0 and mid_step) and p_all/ (restart data).
212+ dirpath = self .get_dirpath ()
213+ common .delete_directory (os .path .join (dirpath , "silo_hdf5" ))
214+
215+ # Phase 2: Restart simulation from midpoint. Only the simulation
216+ # is run — it reads grid + IC directly from p_all/p0/<mid_step>/.
217+ self .params = {** orig , "t_step_start" : mid_step , "t_step_save" : orig ["t_step_stop" ] - mid_step }
218+ self .create_directory ()
219+ result2 = self .run ([SIMULATION ], gpus )
220+
221+ # Remove intermediate step files from D/ so only step 0 and
222+ # t_step_stop remain, matching the straight run's output.
223+ if result2 .returncode == 0 :
224+ d_dir = os .path .join (dirpath , "D" )
225+ mid_tag = f"{ mid_step :06d} "
226+ for f in glob .glob (os .path .join (d_dir , f"*.{ mid_tag } .dat" )):
227+ common .delete_file (f )
228+
229+ return result2
230+ finally :
231+ self .params = orig
232+ try :
233+ self .create_directory ()
234+ except Exception as exc :
235+ print (f"Warning: failed to restore test directory: { exc } " )
236+
186237 def get_trace (self ) -> str :
187238 return self .trace
188239
@@ -307,6 +358,7 @@ class TestCaseBuilder:
307358 ppn : int
308359 functor : Optional [Callable ]
309360 override_tol : Optional [float ] = None
361+ restart_check : bool = False
310362
311363 def get_uuid (self ) -> str :
312364 return trace_to_uuid (self .trace )
@@ -331,7 +383,7 @@ def to_case(self) -> TestCase:
331383 if self .functor :
332384 self .functor (dictionary )
333385
334- return TestCase (self .trace , dictionary , self .ppn , self .override_tol )
386+ return TestCase (self .trace , dictionary , self .ppn , self .override_tol , self . restart_check )
335387
336388
337389@dataclasses .dataclass
@@ -357,7 +409,7 @@ def define_case_f(trace: str, path: str, args: List[str] = None, ppn: int = None
357409 return TestCaseBuilder (trace , mods or {}, path , args or [], ppn or 1 , functor , override_tol )
358410
359411
360- def define_case_d (stack : CaseGeneratorStack , newTrace : str , newMods : dict , ppn : int = None , functor : Callable = None , override_tol : float = None ) -> TestCaseBuilder :
412+ def define_case_d (stack : CaseGeneratorStack , newTrace : str , newMods : dict , ppn : int = None , functor : Callable = None , override_tol : float = None , restart_check : bool = False ) -> TestCaseBuilder :
361413 mods : dict = {}
362414
363415 for mod in stack .mods :
@@ -373,7 +425,7 @@ def define_case_d(stack: CaseGeneratorStack, newTrace: str, newMods: dict, ppn:
373425 if not common .isspace (trace ):
374426 traces .append (trace )
375427
376- return TestCaseBuilder (" -> " .join (traces ), mods , None , None , ppn or 1 , functor , override_tol )
428+ return TestCaseBuilder (" -> " .join (traces ), mods , None , None , ppn or 1 , functor , override_tol , restart_check )
377429
378430
379431def input_bubbles_lagrange (self ):
0 commit comments