[NNX] NNX migration prep (4.6/N): Linen<->NNX checkpoint comparator#3846
Draft
ecnal-cienet wants to merge 4 commits intomainfrom
Draft
[NNX] NNX migration prep (4.6/N): Linen<->NNX checkpoint comparator#3846ecnal-cienet wants to merge 4 commits intomainfrom
ecnal-cienet wants to merge 4 commits intomainfrom
Conversation
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
4 tasks
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests - Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils - Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py - Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
Part 1 — sharding diagnostics: - maxtext_utils.py: extend print_shardings_params to support NNX (nnx.State input) - run_sharding_dump.py: add --pure_nnx flag Part 2 — post-training bugfixes (NNX-side): - models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the whole object as multimodal_input= kwarg; NNXDecoder only accepts the individual image/audio/mask fields) - optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams (callable() check before invoking learning_rate_fn) - train_distill.py / train_sft.py / train_rl.py: avoid nesting nnx.value_and_grad inside nnx.jit (Tunix's default trainer), which raises "graph structure of a node added to cached_partial was mutated" — refactor to jax.value_and_grad with explicit nnx.split / nnx.merge; train_rl.py also adds with_sharding_constraint + dtype-cast compat shims for jax 0.9 / tpu_inference Linen<->NNX checkpoint conversion utility and validation tool moved to a follow-up PR (PR4.5) to keep this change reviewable.
Bidirectional Linen <-> NNX checkpoint conversion. Same on-disk shape
both directions; round-trips preserve byte values.
Top-level key mapping:
- Linen params/params/<model> <-> NNX model/<model> (double-nesting,
{value:} wrappers).
- Linen opt_state <-> NNX optimizer/opt_state (params level on mu/nu).
- Linen step <-> NNX optimizer/step.
Layer structure:
- scan_layers=True (default): stack layers_N -> layers tensor.
- scan_layers=False: rename layers_N -> integer-keyed layers/{N}.
NNX->Linen direction auto-detects which layer layout the source uses.
--direction=auto picks Linen vs NNX from top-level keys.
Pure utility addition. No production-code dependencies; PR5+ do not
depend on this branch. Comparison utility split into PR4.6.
Structural + numerical comparison utility for validating checkpoint
parity. Supports any combination of Linen and NNX checkpoints:
- Linen vs NNX (cross-format)
- Linen vs Linen (same-format)
- NNX vs NNX (same-format)
Auto-detects format per checkpoint and applies matching normalization
(strip {value:} wrappers on NNX, collapse double 'params' nesting on
Linen, filter NNX-only RNG entries). Cross-format-only transforms
(layer-axis transpose between Linen per-layer arrays and NNX stacked
tensors) only run for Linen<->NNX comparisons.
Reports tree-structure mismatches, shape mismatches, and (with
--compare_values) numerical diffs at configurable --atol / --rtol.
Pure utility addition. No production-code dependencies; PR5+ do not
depend on this branch. Stacks on PR4.5 (converter); the two are
file-disjoint.
ae4b21f to
fdab1ad
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
NNX Migration Route Map
pure_nnxflag,init_state_fn,TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)get_abstract_state_nnx,get_named_sharding_nnx,set_named_sharding_nnx,get_partition_spec_nnx,get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
4.6. 🔄 [This PR] Linen↔NNX checkpoint comparator (stacked on PR4.5; the two are file-disjoint).
9.5. ❌ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix.
custom_vjpfor NNX.True; regenerate sharding goldens; flip back integration-testpure_nnx=Falseannotations.Description
Companion to PR4.5 (the converter). This PR adds
compare_linen_nnx_checkpoint.py— a structural and numerical comparison utility used during the NNX migration to validate parity between Linen and NNX checkpoints. Split out of the original PR4.5 bundle on 2026-05-07 so each PR stays narrowly reviewable. PR4.5 and PR4.6 are file-disjoint, and the comparator does not import the converter — PR4.6 only stacks on PR4.5 so reviewers see the delta cleanly.This is a pure addition — no existing files are modified, no production-code paths reference the utility, and no Linen or NNX runtime behavior changes. PR5+ do not depend on this branch.
Diff: +1110 / −0 across 2 new files.
What it does
src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py:{value: ...}wrappers on NNXparamsnesting on Linen--compare_values) numerical diffs at configurable--atol/--rtol.Tests
tests/unit/compare_linen_nnx_checkpoint_test.py— pure-CPU, 60 cases. Covers structural diffing, shape mismatches, numerical comparison at configurable tolerances, and each of the cross- / same-format combinations.Existing tests untouched.
Stats
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.