diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 2310b72aee..02f303ffbf 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -141,6 +141,9 @@ def _make_compiled_programs_pool( argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]] = {} if static_params: + # Static parameter names come from either JIT compilation options or + # `entry_point.compile(..., **static_args)` while initializing the pool. + # Keep this in sync with the descriptors created in `CompiledProgramsPool.compile()`. argument_descriptor_mapping[arguments.StaticArg] = static_params if static_domains: diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index f09ae16bd9..5098ad1c26 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -666,15 +666,18 @@ def compile( """ for offset_provider in offset_providers: # not included in product for better type checking for static_values in itertools.product(*static_args.values()): + argument_descriptors: ArgStaticDescriptorsByType = {} + if static_args: + # Calls from `Program.compile()` / `FieldOperator.compile()`. + # Keep this in sync with `_make_compiled_programs_pool()` in decorator.py. + argument_descriptors[arguments.StaticArg] = dict( + zip( + static_args.keys(), + [arguments.StaticArg(value=v) for v in static_values], + strict=True, + ) + ) self._compile_variant( - argument_descriptors={ - arguments.StaticArg: dict( - zip( - static_args.keys(), - [arguments.StaticArg(value=v) for v in static_values], - strict=True, - ) - ), - }, + argument_descriptors=argument_descriptors, offset_provider=offset_provider, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 52e920a48b..94e6924df2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -124,6 +124,25 @@ def test_compile(cartesian_case, compile_testee): assert np.allclose(kwargs["out"].ndarray, args[0].ndarray + args[1].ndarray) +def test_precompile_and_jit_have_same_arg_descr_mapping(cartesian_case, compile_testee): + if cartesian_case.backend is None: + pytest.skip("Embedded compiled program doesn't make sense.") + + empty_static_args = {} + precompiled_testee = compile_testee.with_backend(cartesian_case.backend) + precompiled_testee.compile(offset_provider=cartesian_case.offset_provider, **empty_static_args) + precompiled_arg_descr_mapping = ( + precompiled_testee._compiled_programs.argument_descriptor_mapping + ) + + jit_testee = compile_testee.with_backend(cartesian_case.backend) + args, kwargs = cases.get_default_data(cartesian_case, jit_testee) + jit_testee(*args, offset_provider=cartesian_case.offset_provider, **kwargs) + jit_arg_descr_mapping = jit_testee._compiled_programs.argument_descriptor_mapping + + assert precompiled_arg_descr_mapping == jit_arg_descr_mapping + + def test_compile_twice_same_program_errors(cartesian_case, compile_testee): if cartesian_case.backend is None: pytest.skip("Embedded compiled program doesn't make sense.")