From fde9a668dd06a85f8beede176e05ec5333232b25 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 16 Jun 2026 14:18:04 +0200 Subject: [PATCH 1/3] Fix StaticArgs injection --- src/gt4py/next/otf/compiled_program.py | 21 +++++++++++-------- .../otf_tests/test_compiled_program.py | 11 ++++++++++ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index f09ae16bd9..de66beb47d 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -666,15 +666,18 @@ def compile( """ for offset_provider in offset_providers: # not included in product for better type checking for static_values in itertools.product(*static_args.values()): + # Only inject a `StaticArg` descriptor mapping when static arguments were + # actually given. + argument_descriptors: ArgStaticDescriptorsByType = {} + if static_args: + argument_descriptors[arguments.StaticArg] = dict( + zip( + static_args.keys(), + [arguments.StaticArg(value=v) for v in static_values], + strict=True, + ) + ) self._compile_variant( - argument_descriptors={ - arguments.StaticArg: dict( - zip( - static_args.keys(), - [arguments.StaticArg(value=v) for v in static_values], - strict=True, - ) - ), - }, + argument_descriptors=argument_descriptors, offset_provider=offset_provider, ) diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index def8800c98..281dbc6ff4 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -168,6 +168,17 @@ def test_different_static_args_break_same_prg_after_static_params_change(testee_ prg.compile(cond=[True], offset_provider={}) +def test_compile_with_empty_static_args_does_not_register_static_arg(testee_prog): + pool = testee_prog._compiled_programs + + # Call pool.compile() with empty static_args dict + static_args = {} + pool.compile(offset_providers=[{}], **static_args) + + # when static_args={}, no StaticArg descriptor should be registered + assert arguments.StaticArg not in pool.argument_descriptor_mapping + + def _verify_program_has_expected_domain( program: itir.Program, expected_domain: gtx.Domain, uids: utils.IDGeneratorPool ): From 77271c03a251329cd02994542c1f3c67f6b38f4a Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 17 Jun 2026 17:15:59 +0200 Subject: [PATCH 2/3] Add comments and regression test --- src/gt4py/next/ffront/decorator.py | 8 ++++++- src/gt4py/next/otf/compiled_program.py | 2 ++ .../ffront_tests/test_compiled_program.py | 22 +++++++++++++++++++ .../otf_tests/test_compiled_program.py | 11 ---------- 4 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index c749fcec01..4a1a6f87a0 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -169,8 +169,14 @@ def compile( # the dict directly. Note that we don't need to check any args, since the pool checks # this on compile anyway. if "_compiled_programs" not in self.__dict__: + static_params: tuple[str, ...] = () + if static_args: + # This is reached by `entry_point.compile(..., **static_args)` before delegating + # to `CompiledProgramsPool.compile()` below. + # Keep this in sync with `compile` in compiled_program.py. + static_params = tuple(static_args.keys()) self.__dict__["_compiled_programs"] = self._make_compiled_programs_pool( - static_params=tuple(static_args.keys()), + static_params=static_params, static_domains=self.compilation_options.static_domains, ) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index de66beb47d..e9fb4b04bc 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -670,6 +670,8 @@ def compile( # actually given. argument_descriptors: ArgStaticDescriptorsByType = {} if static_args: + # Calls from `Program.compile()`/`FieldOperator.compile()` reach this. + # Keep this in sync with `compile` in decorator.py. argument_descriptors[arguments.StaticArg] = dict( zip( static_args.keys(), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 52e920a48b..8620c170a5 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -124,6 +124,28 @@ def test_compile(cartesian_case, compile_testee): assert np.allclose(kwargs["out"].ndarray, args[0].ndarray + args[1].ndarray) +def test_compile_with_empty_static_args_matches_no_static_args(cartesian_case, compile_testee): + if cartesian_case.backend is None: + pytest.skip("Embedded compiled program doesn't make sense.") + + empty_static_args = {} + decorator_path_testee = compile_testee.with_backend(cartesian_case.backend) + decorator_path_testee.compile( + offset_provider=cartesian_case.offset_provider, **empty_static_args + ) + decorator_path_mapping = decorator_path_testee._compiled_programs.argument_descriptor_mapping + assert arguments.StaticArg not in decorator_path_mapping + + compiled_program_path_testee = compile_testee.with_backend(cartesian_case.backend) + compiled_program_path_pool = compiled_program_path_testee._make_compiled_programs_pool( + static_params=(), static_domains=False + ) + compiled_program_path_pool.argument_descriptor_mapping = None + compiled_program_path_pool.compile(offset_providers=[cartesian_case.offset_provider]) + + assert compiled_program_path_pool.argument_descriptor_mapping == decorator_path_mapping + + def test_compile_twice_same_program_errors(cartesian_case, compile_testee): if cartesian_case.backend is None: pytest.skip("Embedded compiled program doesn't make sense.") diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 281dbc6ff4..def8800c98 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -168,17 +168,6 @@ def test_different_static_args_break_same_prg_after_static_params_change(testee_ prg.compile(cond=[True], offset_provider={}) -def test_compile_with_empty_static_args_does_not_register_static_arg(testee_prog): - pool = testee_prog._compiled_programs - - # Call pool.compile() with empty static_args dict - static_args = {} - pool.compile(offset_providers=[{}], **static_args) - - # when static_args={}, no StaticArg descriptor should be registered - assert arguments.StaticArg not in pool.argument_descriptor_mapping - - def _verify_program_has_expected_domain( program: itir.Program, expected_domain: gtx.Domain, uids: utils.IDGeneratorPool ): From 2716931ee778da5c0134429d2701c80fcdfa7a25 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 26 Jun 2026 12:51:47 +0200 Subject: [PATCH 3/3] Apply review comments --- src/gt4py/next/ffront/decorator.py | 11 ++++----- src/gt4py/next/otf/compiled_program.py | 6 ++--- .../ffront_tests/test_compiled_program.py | 23 ++++++++----------- 3 files changed, 16 insertions(+), 24 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 4a1a6f87a0..a99f2eaf76 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -133,6 +133,9 @@ def _make_compiled_programs_pool( argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]] = {} if static_params: + # Static parameter names come from either JIT compilation options or + # `entry_point.compile(..., **static_args)` while initializing the pool. + # Keep this in sync with the descriptors created in `CompiledProgramsPool.compile()`. argument_descriptor_mapping[arguments.StaticArg] = static_params if static_domains: @@ -169,14 +172,8 @@ def compile( # the dict directly. Note that we don't need to check any args, since the pool checks # this on compile anyway. if "_compiled_programs" not in self.__dict__: - static_params: tuple[str, ...] = () - if static_args: - # This is reached by `entry_point.compile(..., **static_args)` before delegating - # to `CompiledProgramsPool.compile()` below. - # Keep this in sync with `compile` in compiled_program.py. - static_params = tuple(static_args.keys()) self.__dict__["_compiled_programs"] = self._make_compiled_programs_pool( - static_params=static_params, + static_params=tuple(static_args.keys()), static_domains=self.compilation_options.static_domains, ) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index e9fb4b04bc..5098ad1c26 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -666,12 +666,10 @@ def compile( """ for offset_provider in offset_providers: # not included in product for better type checking for static_values in itertools.product(*static_args.values()): - # Only inject a `StaticArg` descriptor mapping when static arguments were - # actually given. argument_descriptors: ArgStaticDescriptorsByType = {} if static_args: - # Calls from `Program.compile()`/`FieldOperator.compile()` reach this. - # Keep this in sync with `compile` in decorator.py. + # Calls from `Program.compile()` / `FieldOperator.compile()`. + # Keep this in sync with `_make_compiled_programs_pool()` in decorator.py. argument_descriptors[arguments.StaticArg] = dict( zip( static_args.keys(), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 8620c170a5..94e6924df2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -124,26 +124,23 @@ def test_compile(cartesian_case, compile_testee): assert np.allclose(kwargs["out"].ndarray, args[0].ndarray + args[1].ndarray) -def test_compile_with_empty_static_args_matches_no_static_args(cartesian_case, compile_testee): +def test_precompile_and_jit_have_same_arg_descr_mapping(cartesian_case, compile_testee): if cartesian_case.backend is None: pytest.skip("Embedded compiled program doesn't make sense.") empty_static_args = {} - decorator_path_testee = compile_testee.with_backend(cartesian_case.backend) - decorator_path_testee.compile( - offset_provider=cartesian_case.offset_provider, **empty_static_args + precompiled_testee = compile_testee.with_backend(cartesian_case.backend) + precompiled_testee.compile(offset_provider=cartesian_case.offset_provider, **empty_static_args) + precompiled_arg_descr_mapping = ( + precompiled_testee._compiled_programs.argument_descriptor_mapping ) - decorator_path_mapping = decorator_path_testee._compiled_programs.argument_descriptor_mapping - assert arguments.StaticArg not in decorator_path_mapping - compiled_program_path_testee = compile_testee.with_backend(cartesian_case.backend) - compiled_program_path_pool = compiled_program_path_testee._make_compiled_programs_pool( - static_params=(), static_domains=False - ) - compiled_program_path_pool.argument_descriptor_mapping = None - compiled_program_path_pool.compile(offset_providers=[cartesian_case.offset_provider]) + jit_testee = compile_testee.with_backend(cartesian_case.backend) + args, kwargs = cases.get_default_data(cartesian_case, jit_testee) + jit_testee(*args, offset_provider=cartesian_case.offset_provider, **kwargs) + jit_arg_descr_mapping = jit_testee._compiled_programs.argument_descriptor_mapping - assert compiled_program_path_pool.argument_descriptor_mapping == decorator_path_mapping + assert precompiled_arg_descr_mapping == jit_arg_descr_mapping def test_compile_twice_same_program_errors(cartesian_case, compile_testee):