Skip to content

Fix AOT compilation with donated Jittable arguments#334

Open
edawite wants to merge 1 commit into
google-deepmind:mainfrom
edawite:codex/fix-jittable-aot-donation
Open

Fix AOT compilation with donated Jittable arguments#334
edawite wants to merge 1 commit into
google-deepmind:mainfrom
edawite:codex/fix-jittable-aot-donation

Conversation

@edawite

@edawite edawite commented Jun 15, 2026

Copy link
Copy Markdown

Summary

Treat jax.stages.ArgInfo placeholders as dynamic JAX data when flattening Distrax Jittable objects.

During ahead-of-time staging with donate_argnums, JAX replaces donated array leaves with ArgInfo descriptors. Distrax currently classifies those descriptors as static metadata, so the compiled input pytree records different auxiliary data from the runtime distribution and rejects the call.

Keeping ArgInfo in the dynamic children preserves a stable pytree definition between tracing and compiled invocation. The regression test covers the complete trace(...).lower().compile() path with a donated Jittable argument.

Fixes #308.

Testing

  • python -m pytest distrax/_src/utils/jittable_test.py -q with JAX 0.7.2: 7 passed
  • Same focused suite with JAX 0.10.1: 7 passed
  • Original Categorical AOT donation reproduction with JAX 0.10.1
  • python -m ruff check --no-cache distrax/_src/utils/jittable.py distrax/_src/utils/jittable_test.py
  • git diff --check

The full Windows suite was attempted but stalled in parallel JAX workers without reporting a test failure; GitHub CI is the authoritative full-suite run.

Treat JAX ArgInfo placeholders as dynamic pytree children while staging ahead-of-time compiled functions. This keeps Distrax Jittable metadata stable between tracing and compiled invocation when donate_argnums is used, with a regression test covering the compiled-call path.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Unexpected behaviour when passing a Distribution to a function with donate_argnums

1 participant