Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions ci/cscs-ci-dace-determinism.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,9 @@ dace_determinism_cscs_gh200:
TEST_VARIANTS: 'cpu cuda12'
SLURM_GPUS_PER_NODE: 1
SLURM_PARTITION: 'shared'
GT4PY_BUILD_JOBS: 8
PYTEST_XDIST_AUTO_NUM_WORKERS: 32
rules:
- *exclude_variants_rules
- if: $SUBPACKAGE == 'next' && $DETAIL == 'nomesh'
variables:
# TODO: investigate why the dace tests seem to hang with multiple jobs
GT4PY_BUILD_JOBS: 1
- when: on_success

dace_determinism_cscs_amd_rocm:
Expand All @@ -126,7 +121,6 @@ dace_determinism_cscs_amd_rocm:
variables:
TEST_VARIANTS: 'rocm7'
SLURM_GPUS_PER_NODE: 4
GT4PY_BUILD_JOBS: 8
PYTEST_XDIST_AUTO_NUM_WORKERS: 32
SLURM_PARTITION: mi300
CMAKE_PREFIX_PATH: /opt/rocm
Expand Down
5 changes: 1 addition & 4 deletions ci/cscs-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ build_cscs_amd_rocm:
stage: test
image: ${CSCS_REGISTRY_PATH}/public/${ARCH}/base/gt4py-ci-${PY_VERSION}:${DOCKER_TAG}
variables:
GT4PY_BUILD_JOBS: 1 # Limit pool size to 1 in order to catch the build errors in the unit test they originate from.
TEST_VARIANTS: 'cpu' # Extended jobs should redefine which variants (cpu, cuda12, rocm6) to test
USE_MPI: 0 # TODO(havogt): to workaround the libfabric hook injecting incompatible libraries
SLURM_JOB_NUM_NODES: 1
Expand Down Expand Up @@ -89,15 +90,12 @@ test_cscs_gh200:
TEST_VARIANTS: 'cpu cuda12'
SLURM_GPUS_PER_NODE: 1
SLURM_PARTITION: 'shared'
GT4PY_BUILD_JOBS: 8
# Limit test parallelism to avoid "OSError: too many open files" in the gt4py build stage.
PYTEST_XDIST_AUTO_NUM_WORKERS: 32
rules:
- *exclude_variants_rules
- if: $SUBPACKAGE == 'next' && $VARIANT == 'dace' && $DETAIL == 'nomesh'
variables:
# TODO: investigate why the dace tests seem to hang with multiple jobs
GT4PY_BUILD_JOBS: 1
SLURM_TIMELIMIT: "00:15:00"
- if: $SUBPACKAGE == 'cartesian' && $VARIANT == 'dace' && $SUBVARIANT == 'cuda12'
variables:
Expand All @@ -113,7 +111,6 @@ test_cscs_amd_rocm:
variables:
TEST_VARIANTS: 'cpu rocm7'
SLURM_GPUS_PER_NODE: 4
GT4PY_BUILD_JOBS: 8
# Limit test parallelism to avoid "OSError: too many open files" in the gt4py build stage.
PYTEST_XDIST_AUTO_NUM_WORKERS: 32
SLURM_PARTITION: mi300
Expand Down
21 changes: 20 additions & 1 deletion src/gt4py/next/otf/compiled_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ def compiled_program_call_context(
# TODO(havogt): We would like this to be a ProcessPoolExecutor, which requires (to decide what) to pickle.
_async_compilation_pool: concurrent.futures.Executor | None = None

# Registry of the futures of all async compilation jobs that have been submitted to
# '_async_compilation_pool'. It is used by 'wait_for_compilation' to surface compilation
# errors that would otherwise stay hidden in the futures until the program is called.
# We use a 'WeakSet' such that futures are removed automatically once they are no longer
# referenced by the owning 'CompiledProgramsPool'.
_pending_compilation_futures: weakref.WeakSet[concurrent.futures.Future] = weakref.WeakSet()
_pending_compilation_futures_lock: threading.Lock = threading.Lock()


def _init_async_compilation_pool() -> None:
global _async_compilation_pool
Expand All @@ -182,6 +190,14 @@ def wait_for_compilation() -> None:
_async_compilation_pool = None
_init_async_compilation_pool()

# All jobs are finished now; re-raise the first compilation error (if any), which would
# otherwise only surface when the corresponding program is called.
with _pending_compilation_futures_lock:
futures = list(_pending_compilation_futures)
_pending_compilation_futures.clear()
for future in futures:
future.result() # re-raises any exception that occurred during compilation


def _make_tuple_expr(el_exprs: list[str]) -> str:
return "".join((f"{el},") for el in el_exprs)
Expand Down Expand Up @@ -643,7 +659,10 @@ def _compile_variant(
if _async_compilation_pool is None:
self.compiled_programs[key] = compile_call()
else:
self._compilation_jobs[key] = _async_compilation_pool.submit(compile_call)
future = _async_compilation_pool.submit(compile_call)
self._compilation_jobs[key] = future
with _pending_compilation_futures_lock:
_pending_compilation_futures.add(future)

# TODO(tehrengruber): Rework the interface to allow precompilation with compile time
# domains and of scans.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -892,20 +892,41 @@ def test_synchronous_compilation(cartesian_case, compile_testee):
assert np.allclose(out.ndarray, a.ndarray + b.ndarray)


@pytest.mark.parametrize("compile_fails", [False, True], ids=["success", "failure"])
@pytest.mark.parametrize("synchronous", [True, False], ids=["synchronous", "asynchronous"])
def test_wait_for_compilation(cartesian_case, compile_testee, compile_testee_domain, synchronous):
def test_wait_for_compilation(
cartesian_case, compile_testee, compile_testee_domain, synchronous, compile_fails
):
if cartesian_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

with (
mock.patch.object(compiled_program, "_async_compilation_pool", None)
if synchronous
else contextlib.nullcontext()
):
msg = "compilation failed"

class FailingBackend:
# Delegates everything to the real backend but raises on 'compile'. This scopes the
# failure to this single program's compilation, instead of patching the executor type
# which would affect every program in the process.
def __getattr__(self, name):
return getattr(cartesian_case.backend, name)

def compile(self, program, compile_time_args):
raise RuntimeError(msg)

with contextlib.ExitStack() as stack:
if synchronous:
stack.enter_context(
mock.patch.object(compiled_program, "_async_compilation_pool", None)
)
if compile_fails:
# The pool captures 'backend' when it is first built (on 'compile'), so swap before.
# In synchronous mode the error surfaces from 'compile', in asynchronous mode it is
# deferred to 'wait_for_compilation'.
object.__setattr__(compile_testee, "backend", FailingBackend())
stack.enter_context(pytest.raises(RuntimeError, match=msg))

compile_testee.compile(offset_provider=cartesian_case.offset_provider)
# TODO(havogt): currently only tests that the function call does not crash...
gtx.wait_for_compilation()
# ... and afterwards compilation still works
# If it did not throw an error, compilation still works afterwards.
compile_testee_domain.compile(offset_provider=cartesian_case.offset_provider)


Expand Down