Skip to content

fix: restore jax.tree_util.Partial in omega() for PowerLaw+JAX#352

Merged
Jammy2211 merged 1 commit intomainfrom
feature/group-features
Apr 14, 2026
Merged

fix: restore jax.tree_util.Partial in omega() for PowerLaw+JAX#352
Jammy2211 merged 1 commit intomainfrom
feature/group-features

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Fix a regression in omega() introduced by commit 632ae82, which replaced jax.tree_util.Partial with functools.partial, breaking PowerLaw deflection calculations under JAX.

scan inside omega() is wrapped in jax.jit, so its first argument (the body function) is traced as a pytree. functools.partial is not a registered pytree and fails with:

```
TypeError: Error interpreting argument to as an abstract
array. The problematic value is of type <class 'functools.partial'>
```

jax.tree_util.Partial is still present in JAX 0.9.2, so the "0.5+ compat" rationale in 632ae82 was incorrect. This PR reverts that single-line change.

Reproducer

autolens_workspace_test/scripts/jax_likelihood_functions/imaging/lp.py (which uses al.mp.PowerLaw + a JAX-vmap'd Fitness) is a direct reproducer. It fails on main and passes after this fix.

The bug also blocks all mass_total stages of the group SLaM pipelines in autolens_workspace/scripts/group/features/ that use PowerLaw.

Test Plan

  • `python autolens_workspace_test/scripts/jax_likelihood_functions/imaging/lp.py` — passes numerical assertion `-1.34797827e+09`
  • Group SLaM MGE + linear_light_profiles mass_total stages pass under `PYAUTO_TEST_MODE=2`

🤖 Generated with Claude Code

Commit 632ae82 replaced jax.tree_util.Partial with functools.partial
in omega(), breaking PowerLaw deflection calculations under JAX. scan is
wrapped in jax.jit, so its first argument (the body function) is traced
as a pytree — functools.partial is not a registered pytree and fails
with:

  TypeError: Error interpreting argument to <function scan> as an
  abstract array. The problematic value is of type <class
  'functools.partial'>

jax.tree_util.Partial is still present in JAX 0.9.2, so the "0.5+ compat"
rationale in 632ae82 was incorrect. Reverting just that one line.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@Jammy2211
Copy link
Copy Markdown
Collaborator Author

Workspace PR: PyAutoLabs/autolens_workspace#61

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant