[NNX] NNX migration prep (4.6/N): Linen<->NNX checkpoint comparator#3846

Draft
ecnal-cienet wants to merge 4 commits into main from feat/nnx-linen-nnx-ckpt-compare

Conversation


@ecnal-cienet ecnal-cienet commented May 7, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics on NNX, plus post-training bugfixes that surfaced once the NNX path got exercised end-to-end. (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
    4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
    4.6. 🔄 [This PR] Linen↔NNX checkpoint comparator (stacked on PR4.5; the two are file-disjoint).
  5. ❌ NNX correctness fixes, feature enablements, and vocab tiling on NNX.
  6. ❌ NNX-native DPO.
  7. ❌ NNX-native MaxEngine inference.
  8. ❌ NNX-native LoRA + GRPO.
  9. ❌ NNX-aware QK-Clip + remaining checkpoint utilities.
    9.5. ❌ NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix.
  10. ❌ Vocab tiling custom_vjp for NNX.
  11. ❌ Set NNX defaults to True; regenerate sharding goldens; flip back integration-test pure_nnx=False annotations.
  12. ❌ Delete Linen-specific code paths and NNX compatibility flags.

Description

Companion to PR4.5 (the converter). This PR adds compare_linen_nnx_checkpoint.py — a structural and numerical comparison utility used during the NNX migration to validate parity between Linen and NNX checkpoints. Split out of the original PR4.5 bundle on 2026-05-07 so each PR stays narrowly reviewable. PR4.5 and PR4.6 are file-disjoint, and the comparator does not import the converter — PR4.6 only stacks on PR4.5 so reviewers see the delta cleanly.

This is a pure addition — no existing files are modified, no production-code paths reference the utility, and no Linen or NNX runtime behavior changes. PR5+ do not depend on this branch.

Diff: +1110 / −0 across 2 new files.

What it does

src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py:

  • Three comparison modes — Linen vs NNX (cross-format), Linen vs Linen, NNX vs NNX.
  • Format auto-detection — picks Linen vs NNX per checkpoint and applies matching normalization:
    • Strip {value: ...} wrappers on NNX
    • Collapse double params nesting on Linen
    • Filter NNX-only RNG entries
  • Cross-format-only transforms — layer-axis transposition between Linen per-layer arrays and NNX stacked tensors only runs for Linen↔NNX comparisons.
  • What it reports — tree-structure mismatches, shape mismatches, and (with --compare_values) numerical diffs at configurable --atol / --rtol.
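To make the auto-detection and normalization steps concrete, here is a minimal sketch of the logic described above. This is illustrative only, not the code in `compare_linen_nnx_checkpoint.py`; the helper names (`detect_format`, `strip_value_wrappers`, `collapse_double_params`) and the detection heuristic are assumptions based on the bullet points.

```python
def detect_format(tree):
    """Guess a checkpoint's format from its top-level keys.

    Assumption for this sketch: Linen trees carry a top-level 'params'
    key, while NNX trees do not.
    """
    return "linen" if "params" in tree else "nnx"


def strip_value_wrappers(tree):
    """Recursively replace NNX-style {'value': x} wrapper dicts with x."""
    if isinstance(tree, dict):
        if set(tree) == {"value"}:
            return strip_value_wrappers(tree["value"])
        return {k: strip_value_wrappers(v) for k, v in tree.items()}
    return tree


def collapse_double_params(tree):
    """Collapse Linen's params/params double nesting to a single level."""
    inner = tree.get("params", {})
    if isinstance(inner, dict) and "params" in inner:
        rest = {k: v for k, v in tree.items() if k != "params"}
        return {"params": inner["params"], **rest}
    return tree
```

After both trees pass through their format-specific normalization, a single structural diff can treat them uniformly, which is what allows the same comparator to serve the cross-format and same-format modes.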

Tests

tests/unit/compare_linen_nnx_checkpoint_test.py — pure-CPU, 60 cases. Covers structural diffing, shape mismatches, numerical comparison at configurable tolerances, and each of the cross- / same-format combinations.

Existing tests untouched.

Stats

  • Diff: +1110 / −0 across 2 files (2 new, 0 modified).
  • Production-code impact: none. No existing source file imports the utility.
  • Linen preservation: trivially preserved — no Linen file is touched.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch

Part 1 — sharding diagnostics:
- maxtext_utils.py: extend print_shardings_params to support NNX (nnx.State input)
- run_sharding_dump.py: add --pure_nnx flag

Part 2 — post-training bugfixes (NNX-side):
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing
  the whole object as multimodal_input= kwarg; NNXDecoder only accepts the
  individual image/audio/mask fields)
- optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams
  (callable() check before invoking learning_rate_fn)
- train_distill.py / train_sft.py / train_rl.py: avoid nesting nnx.value_and_grad
  inside nnx.jit (Tunix's default trainer), which raises "graph structure of a
  node added to cached_partial was mutated" — refactor to jax.value_and_grad
  with explicit nnx.split / nnx.merge; train_rl.py also adds with_sharding_constraint
  + dtype-cast compat shims for jax 0.9 / tpu_inference

Linen<->NNX checkpoint conversion utility and validation tool moved to a
follow-up PR (PR4.5) to keep this change reviewable.

Bidirectional Linen <-> NNX checkpoint conversion. Same on-disk shape
both directions; round-trips preserve byte values.

Top-level key mapping:
- Linen params/params/<model> <-> NNX model/<model> (double-nesting,
  {value:} wrappers).
- Linen opt_state <-> NNX optimizer/opt_state (params level on mu/nu).
- Linen step <-> NNX optimizer/step.

Layer structure:
- scan_layers=True (default): stack layers_N -> layers tensor.
- scan_layers=False: rename layers_N -> integer-keyed layers/{N}.

NNX->Linen direction auto-detects which layer layout the source uses.
--direction=auto picks Linen vs NNX from top-level keys.
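The scan_layers=True stacking above can be sketched as follows. This is a hypothetical illustration, not the converter's actual code; it assumes per-layer Linen arrays named `layers_0`, `layers_1`, ... with identical shapes, stacked along a new leading axis for the NNX layout.

```python
import re

import numpy as np


def stack_layers(linen_block):
    """Stack layers_N arrays into a single NNX-style 'layers' tensor.

    Keys not matching layers_N are passed through unchanged. The new
    layer axis is placed at position 0 in this sketch; the real
    converter's axis placement may differ.
    """
    pattern = re.compile(r"layers_(\d+)$")
    layer_keys = sorted(
        (k for k in linen_block if pattern.match(k)),
        key=lambda k: int(pattern.match(k).group(1)),
    )
    stacked = np.stack([linen_block[k] for k in layer_keys], axis=0)
    rest = {k: v for k, v in linen_block.items() if not pattern.match(k)}
    return {"layers": stacked, **rest}
```

The reverse (NNX->Linen) direction would unstack along the same axis back into `layers_N` entries, which is why round-trips can preserve byte values: stacking and unstacking are inverses as long as the layer ordering is numeric, not lexicographic (note the `int(...)` sort key above, which keeps `layers_10` after `layers_2`).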

Pure utility addition. No production-code dependencies; PR5+ do not
depend on this branch. Comparison utility split into PR4.6.

Structural + numerical comparison utility for validating checkpoint
parity. Supports any combination of Linen and NNX checkpoints:
- Linen vs NNX (cross-format)
- Linen vs Linen (same-format)
- NNX vs NNX (same-format)

Auto-detects format per checkpoint and applies matching normalization
(strip {value:} wrappers on NNX, collapse double 'params' nesting on
Linen, filter NNX-only RNG entries). Cross-format-only transforms
(layer-axis transpose between Linen per-layer arrays and NNX stacked
tensors) only run for Linen<->NNX comparisons.

Reports tree-structure mismatches, shape mismatches, and (with
--compare_values) numerical diffs at configurable --atol / --rtol.
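A minimal sketch of what the --compare_values mode might look like, assuming np.allclose-style tolerance semantics; the function name and reporting format are illustrative, not the utility's actual API.

```python
import numpy as np


def compare_values(tree_a, tree_b, atol=1e-6, rtol=1e-5, path=()):
    """Yield (path, max_abs_diff) for leaves that differ beyond tolerance.

    Structural and shape mismatches are reported with a diff of None.
    """
    if isinstance(tree_a, dict) and isinstance(tree_b, dict):
        for key in sorted(set(tree_a) | set(tree_b)):
            if key not in tree_a or key not in tree_b:
                yield (path + (key,), None)  # key present on one side only
                continue
            yield from compare_values(
                tree_a[key], tree_b[key], atol, rtol, path + (key,)
            )
    else:
        a, b = np.asarray(tree_a), np.asarray(tree_b)
        if a.shape != b.shape:
            yield (path, None)  # shape mismatch
        elif not np.allclose(a, b, atol=atol, rtol=rtol):
            yield (path, float(np.max(np.abs(a - b))))
```

Because normalization runs before comparison, a leaf-level walk like this sees the same tree shape for both checkpoints in the cross-format case, so value diffs can be attributed to real numerical divergence rather than layout differences.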

Pure utility addition. No production-code dependencies; PR5+ do not
depend on this branch. Stacks on PR4.5 (converter); the two are
file-disjoint.
ecnal-cienet force-pushed the feat/nnx-linen-nnx-ckpt-compare branch from ae4b21f to fdab1ad on May 8, 2026.