Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ def _make_compiled_programs_pool(
argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]] = {}

if static_params:
# Static parameter names come from either JIT compilation options or
# `entry_point.compile(..., **static_args)` while initializing the pool.
# Keep this in sync with the descriptors created in `CompiledProgramsPool.compile()`.
argument_descriptor_mapping[arguments.StaticArg] = static_params

if static_domains:
Expand Down
21 changes: 12 additions & 9 deletions src/gt4py/next/otf/compiled_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,15 +666,18 @@ def compile(
"""
for offset_provider in offset_providers: # not included in product for better type checking
for static_values in itertools.product(*static_args.values()):
argument_descriptors: ArgStaticDescriptorsByType = {}
if static_args:
# Calls from `Program.compile()` / `FieldOperator.compile()`.
# Keep this in sync with `_make_compiled_programs_pool()` in decorator.py.
argument_descriptors[arguments.StaticArg] = dict(
zip(
static_args.keys(),
[arguments.StaticArg(value=v) for v in static_values],
strict=True,
)
)
self._compile_variant(
argument_descriptors={
arguments.StaticArg: dict(
zip(
static_args.keys(),
[arguments.StaticArg(value=v) for v in static_values],
strict=True,
)
),
},
argument_descriptors=argument_descriptors,
offset_provider=offset_provider,
)
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,25 @@ def test_compile(cartesian_case, compile_testee):
assert np.allclose(kwargs["out"].ndarray, args[0].ndarray + args[1].ndarray)


def test_precompile_and_jit_have_same_arg_descr_mapping(cartesian_case, compile_testee):
if cartesian_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

empty_static_args = {}
precompiled_testee = compile_testee.with_backend(cartesian_case.backend)
precompiled_testee.compile(offset_provider=cartesian_case.offset_provider, **empty_static_args)
precompiled_arg_descr_mapping = (
precompiled_testee._compiled_programs.argument_descriptor_mapping
)

jit_testee = compile_testee.with_backend(cartesian_case.backend)
args, kwargs = cases.get_default_data(cartesian_case, jit_testee)
jit_testee(*args, offset_provider=cartesian_case.offset_provider, **kwargs)
jit_arg_descr_mapping = jit_testee._compiled_programs.argument_descriptor_mapping

assert precompiled_arg_descr_mapping == jit_arg_descr_mapping


def test_compile_twice_same_program_errors(cartesian_case, compile_testee):
if cartesian_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")
Expand Down