[NNX] NNX migration prep (9.5/N): NNX + AQT in MaxEngine + serve-mode reload + gpt3 prefill fix#3844

Draft
ecnal-cienet wants to merge 9 commits into main from feat/nnx-aqt-maxengine

Conversation


@ecnal-cienet commented May 7, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR NNX migration prep (1/N): pure_nnx flag and init_state_fn scaffolding #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR NNX migration prep (2/N): NNX utils and sharding utilities #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500)
  4. ✅ Sharding diagnostics on NNX, plus post-training bugfixes that surfaced once the NNX path got exercised end-to-end. (PR [NNX] NNX migration prep (4/N): sharding tools and post-training fixes #3652)
    4.5. ✅ Linen↔NNX checkpoint converter. (PR [NNX] NNX migration prep (4.5/N): Linen<->NNX checkpoint converter #3843)
    4.6. ❌ Linen↔NNX checkpoint comparator (sibling branch on PR4.5).
  5. ✅ NNX correctness fixes, feature enablements, and vocab tiling on NNX.
  6. ✅ NNX-native DPO.
  7. ✅ NNX-native MaxEngine inference. (PR [NNX] NNX migration prep (7/N): NNX-native MaxEngine inference #3821)
  8. ✅ NNX-native LoRA + GRPO. (PR [NNX] NNX migration prep (8/N): NNX native lora grpo #3824)
  9. ✅ NNX-aware QK-Clip + remaining checkpoint utilities. (PR [NNX] NNX migration prep (9/N): NNX-aware QK-Clip + checkpoint utilities #3836)
    9.5. 🔄 [This PR] NNX + AQT in MaxEngine: pre-quantized loading (checkpoint_is_quantized=True) via quant_mode_str="serve", convert-on-load via TRAIN-mode AQT, and a pre-existing gpt3 prefill / non-TRAIN inference bug fix (Gpt3MultiHeadAttention.__call__ was missing update_kv_caches). Split out of original PR9 on 2026-05-07. Stacks on PR9; PR9 and PR9.5 are file-disjoint.
  10. ❌ Vocab tiling custom_vjp for NNX.
  11. ❌ Set NNX defaults to True; regenerate sharding goldens; flip back integration-test pure_nnx=False annotations.
  12. ❌ Delete Linen-specific code paths and NNX compatibility flags.

Description

This PR migrates the NNX + AQT integration in MaxEngine so that pure_nnx=True can both load pre-quantized checkpoints directly and convert full-precision checkpoints to int8 on load. It also bundles a fix for a pre-existing gpt3 prefill / autoregressive bug surfaced by the AQT end-to-end validation.

Originally part of PR9; split into its own follow-up on 2026-05-07 because the AQT chain has 5 chained QTensor / sharding / restore-target bugs that warrant focused review independent of the QK-Clip + checkpoint-utilities work in PR9.

Diff: +637 / −57 across 8 files (5 source + 3 tests, of which 2 are new). Stacks on PR9 (feat/nnx-qk-clip-and-checkpoint-utils); the two halves are file-disjoint, so PR9.5 could equally well be retargeted as a sibling of PR8 once PR9 lands.

Part 1: NNX + AQT in MaxEngine — two paths

  • Pre-quantized load (checkpoint_is_quantized=True): from_pretrained(quant_mode_str="serve") reads the on-disk qrhs.frozen directly so AQT layers don't materialize the full-precision kernel.
  • Convert-on-load (checkpoint_is_quantized=False + quantization=int8): full-precision kernels load normally and AQT layers quantize against them on every forward pass. Same numerical result as serve mode for absmax calibration; slower but correct.

Threaded quant_mode_str ("train" | "convert" | "serve") through from_config → create_model → get_nnx_create_model_fn → create_nnx_abstract_model → from_pretrained. Default "train" preserves existing callers; "serve" propagates to configure_quantization. maxengine.__init__ selects the quant mode from config.checkpoint_is_quantized; _load_params_nnx drops its NotImplementedError.
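
A minimal sketch of that mode selection, assuming only the config fields named above (checkpoint_is_quantized, quantization); the helper itself is illustrative rather than the actual maxengine code:

```python
def _select_quant_mode(config) -> str:
  # Sketch only: field names follow the PR description; the helper is illustrative.
  if getattr(config, "checkpoint_is_quantized", False):
    # Pre-quantized load: AQT layers read qrhs.frozen directly and never
    # materialize the full-precision kernel.
    return "serve"
  # Convert-on-load (quantization=int8) and the unquantized default both build
  # the model in TRAIN mode; AQT quantizes per-forward when enabled.
  return "train"

# quant_mode = _select_quant_mode(config)
# from_pretrained(config, quant_mode_str=quant_mode)  # return shape elided
```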

Part 2: _load_and_quantize_nnx — NNX whole-model convert path

src/maxtext/utils/layerwise_quantization.py:

  • Loads full-precision in TRAIN mode via from_pretrained(quant_mode_str="train").
  • Builds a separate CONVERT-mode model and copies kernels into it via _copy_kernel_leaves_.
  • Runs a forward — the ToNNX(AqtDotGeneral) bridge auto-captures qrhs.frozen per flax/nnx/bridge/wrappers.py:230-243.
  • Strips kernels at quantized paths via _strip_kernels_at_quantized_paths.
  • Saves serve-mode-shaped state.

The DeepSeek-only assertion is lifted for NNX since the whole-model approach is decoder-agnostic.
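
A structural sketch of that flow; from_pretrained, create_model, _copy_kernel_leaves_, and _strip_kernels_at_quantized_paths are the names used above, while the exact signatures, return shapes, and the save step are assumptions:

```python
from flax import nnx

def _load_and_quantize_nnx_sketch(config, sample_batch, save_fn):
  # 1. Full-precision model with TRAIN-mode AQT (no quantized state yet).
  fp_model = from_pretrained(config, quant_mode_str="train")
  # 2. Fresh CONVERT-mode model; copy the full-precision kernels into it.
  convert_model = create_model(config, quant_mode_str="convert")
  _copy_kernel_leaves_(src=fp_model, dst=convert_model)
  # 3. One forward pass: the ToNNX(AqtDotGeneral) bridge captures qrhs.frozen
  #    as variables on convert_model.
  _ = convert_model(sample_batch)
  # 4. Drop the now-redundant kernels at quantized paths and save a
  #    serve-mode-shaped state.
  state = _strip_kernels_at_quantized_paths(nnx.state(convert_model))
  save_fn(state)
```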

Part 3: Sharding helpers + from_pretrained QTensor handling — 5 chained fixes

The serve-mode reload chain surfaced five bugs in the NNX/AQT-serve interaction, all fixed here:

  1. Sharding helper for QTensor leaves (maxtext_utils.get_nnx_named_sharding_with_scan_axis): emits a parallel-tree of replicated NamedSharding leaves when a Variable's value is a composite pytree (AQT serve-mode QTensor with an int8 qvalue leaf and a list of bf16 scale leaves). Previously returned the Variable as-is when val had no .shape, leaving ShapeDtypeStruct leaves where the downstream jax.ShapeDtypeStruct(..., sharding=s) call expected Shardings.
  2. Variable indexing on QTensor: _build_value_target, _free_device_memory, and _unwrap_for_align in from_pretrained now use Variable.get_value() instead of v[...]. QTensor's __getitem__ calls qvalue[idx] on a LogicallyPartitioned wrapper — that fails. Composite leaves now flow through unchanged.
  3. Filter widening: both from_pretrained's NNX-detection branch and maxengine._load_params_nnx previously filtered sharded_state to nnx.Param only, dropping AQT qrhs.frozen leaves (which are stored as a separate aqt Variable type, not a Param subclass). They now filter to "everything except nnx.RngState and nnx.Cache". _load_params_nnx also adds a 4-way nnx.split + overlay step so the loaded aqt-typed leaves survive into _nnx_rest_state.
  4. Partitioned-unwrap for QTensor leaf paths: the abstract NNX model's QTensor qvalue / scale come back wrapped in LogicallyPartitioned. Under jax.tree.flatten_with_path, that wrapper adds an extra GetAttrKey('value') to every leaf — so the restore target's tree path looks like qrhs.frozen.value.qvalue.value, but _load_and_quantize_nnx flushes the QTensor as plain arrays at qrhs.frozen.value.qvalue (no extra .value). Orbax silently filled the missing paths with the model's init values (qvalue=0, scale=1 — exactly the symptom we saw). _build_value_target now strips Partitioned wrappers around composite-leaf values so the tree path matches the on-disk layout.
  5. Shape-alignment crash on QTensor: _walk_align previously called ckpt_arr.shape on every leaf, which hit qvalue.shape on a LogicallyPartitioned. Composite leaves are now passed through unchanged in the per-axis alignment dispatch — quantized payloads are saved at the exact model shape, no alignment needed.

Also dropped a redundant jax.set_mesh(mesh) wrap inside create_nnx_abstract_model's nnx.eval_shape call. Under jax.set_mesh, Flax 0.12.6's _to_variable rejects serve-mode AQT variables because they hit NamedSharding(mesh=AbstractMesh, spec=None). Sharding is resolved afterwards via get_nnx_named_sharding_with_scan_axis, so the wrap was redundant; removing it lets serve-mode model construction reach the orbax restore step.
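
A hedged sketch of the two most mechanical fixes above: the parallel tree of replicated NamedSharding leaves for composite (QTensor) Variable values, and reading Variables via get_value() instead of indexing. The function names here are illustrative; the real logic lives in get_nnx_named_sharding_with_scan_axis and from_pretrained.

```python
import jax
from jax.sharding import NamedSharding, PartitionSpec

def _sharding_for_value(val, mesh):
  """Fix 1 (schematic): one sharding per array leaf, even for composite values."""
  if hasattr(val, "shape"):
    # Ordinary array leaf. The real helper resolves logical/scan axes here;
    # this sketch just replicates.
    return NamedSharding(mesh, PartitionSpec())
  # Composite pytree (AQT serve-mode QTensor: int8 qvalue + list of bf16 scales):
  # mirror its structure with one replicated sharding per leaf.
  return jax.tree.map(lambda _: NamedSharding(mesh, PartitionSpec()), val)

def _read_variable(variable):
  """Fix 2 (schematic): variable[...] routes through QTensor.__getitem__ and
  fails on the LogicallyPartitioned qvalue wrapper; get_value() does not."""
  return variable.get_value()
```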

Part 4: gpt3 Prefill / Autoregressive Fix

A pre-existing gpt3 bug surfaced when validating the AQT pre-quantized load end-to-end: Gpt3MultiHeadAttention.__call__ (src/maxtext/models/gpt3.py) invoked self.attention_op(...) without ever calling update_kv_caches to build cached_values, so any non-TRAIN forward (prefill or autoregressive) tripped the assert prefill_kv_cache check at the top of AttentionOp.__call__. Affects every gpt3 inference call regardless of quantization; included here because the AQT e2e validation exercises this path.

Mirrors the standard Attention class plumbing in attentions.py:

  • __init__ constructs a KVCache_0 module when model_mode != MODEL_MODE_TRAIN, sized from max_prefill_predict_length / max_target_length / batch / num_heads / head_dim.
  • __init__ also threads max_prefill_predict_length into AttentionOp (was previously left at the -1 default, breaking the prefill-cache shape sizing).
  • __call__ calls self.KVCache_0(...) to produce [prefill_kv_cache, ar_kv_cache] and passes that as the cached_values argument to attention_op.

TRAIN-mode shape unchanged (KVCache_0 stays None, no extra parameters).
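
A schematic of the __call__ change, mirroring the Attention plumbing in attentions.py. Only the cache-related lines are shown; the projection helpers and argument lists are abbreviated and are not the real gpt3.py signatures.

```python
def __call__(self, inputs_q, inputs_kv, decoder_segment_ids, model_mode):
  # _qkv_projections is a stand-in for the existing projection code.
  query, key, value = self._qkv_projections(inputs_q, inputs_kv)
  cached_values = [None, None]
  if model_mode != MODEL_MODE_TRAIN:
    # Previously missing: build [prefill_kv_cache, ar_kv_cache] so the
    # `assert prefill_kv_cache` check in AttentionOp.__call__ is satisfied.
    cached_values = self.KVCache_0(key, value, decoder_segment_ids, model_mode)
  out = self.attention_op(query, key, value, decoder_segment_ids, model_mode,
                          cached_values=cached_values)
  return self._output_projection(out)  # stand-in for the existing output proj
```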

Tests

New unit tests (tests/unit/layerwise_quantization_nnx_test.py, 3 tests) cover _strip_kernels_at_quantized_paths: quantized-kernel removal, preservation of non-quantized kernels (norms, embeddings), and mixed-shape trees.

New unit tests (tests/unit/aqt_serve_roundtrip_nnx_test.py, 1 test — end-to-end regression): builds a small NNX model in CONVERT mode with int8, runs a forward to populate qrhs.frozen via the ToNNX bridge, saves the serve-mode-shape state to a tmp local orbax checkpoint, reloads via from_pretrained(quant_mode_str="serve"), and asserts every saved qrhs.frozen.qvalue array byte-matches what came back. Guards the full chain of QTensor / Partitioned / filter fixes. Runs on CPU under DECOUPLE_GCLOUD=TRUE.

Modified test (tests/unit/maxengine_test.py): test_quantize_raises_for_nnx (asserted NotImplementedError) replaced by test_quantize_passes_gate_for_nnx (verifies the convert-on-load path reaches from_pretrained in TRAIN mode). Added test_load_pre_quantized_nnx_passes_quant_gate (verifies checkpoint_is_quantized=True reaches from_pretrained in SERVE mode) and test_quantized_prefill_nnx_train_mode (full prefill with quantization=int8 + random params + TRAIN mode produces finite logits — real numerical verification).

Existing Linen tests: untouched and still pass.

End-to-end on TPU (gpt3-52k): convert-mode forward → qrhs.frozen extraction → serve-mode-shape save to orbax → reload via from_pretrained(quant_mode_str="serve") → quantized prefill forward → finite logits. Save-side qvalue nonzero_frac=0.99x; reload preserves bytes exactly.

Linting: bash lint.sh — pyink + pylint clean.

Stats

  • Diff: +637 / −57 across 8 files (2 new, 6 modified).
  • Production-code impact: Linen behavior preserved; every NNX edit is gated on config.pure_nnx or runtime state-shape detection. The gpt3 KVCache plumbing is gated on model_mode != MODEL_MODE_TRAIN, so TRAIN-mode shape is unchanged.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
Part 1 — sharding diagnostics:
- maxtext_utils.py: extend print_shardings_params to support NNX (nnx.State input)
- run_sharding_dump.py: add --pure_nnx flag

Part 2 — post-training bugfixes (NNX-side):
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing
  the whole object as multimodal_input= kwarg; NNXDecoder only accepts the
  individual image/audio/mask fields)
- optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams
  (callable() check before invoking learning_rate_fn)
- train_distill.py / train_sft.py / train_rl.py: avoid nesting nnx.value_and_grad
  inside nnx.jit (Tunix's default trainer), which raises "graph structure of a
  node added to cached_partial was mutated" — refactor to jax.value_and_grad
  with explicit nnx.split / nnx.merge; train_rl.py also adds with_sharding_constraint
  + dtype-cast compat shims for jax 0.9 / tpu_inference

Linen<->NNX checkpoint conversion utility and validation tool moved to a
follow-up PR (PR4.5) to keep this change reviewable.
Bidirectional Linen <-> NNX checkpoint conversion. Same on-disk shape
both directions; round-trips preserve byte values.

Top-level key mapping:
- Linen params/params/<model> <-> NNX model/<model> (double-nesting,
  {value:} wrappers).
- Linen opt_state <-> NNX optimizer/opt_state (params level on mu/nu).
- Linen step <-> NNX optimizer/step.

Layer structure:
- scan_layers=True (default): stack layers_N -> layers tensor.
- scan_layers=False: rename layers_N -> integer-keyed layers/{N}.

NNX->Linen direction auto-detects which layer layout the source uses.
--direction=auto picks Linen vs NNX from top-level keys.

Pure utility addition. No production-code dependencies; PR5+ do not
depend on this branch. Comparison utility split into PR4.6.

@ecnal-cienet force-pushed the feat/nnx-aqt-maxengine branch from e173538 to 7af8157 on May 8, 2026 16:19

Bug fixes (run as no-op while pure_nnx=False stays default):
- nnx_wrappers.py: add _refresh_variable_trace_state + is_linen_initializing;
  call from ToLinen after nnx.update to fix "Cannot extract graph node from
  different trace level" when grad tracers leak into Variable._trace_state.
- gpt_oss.py / olmo3.py: replace inline nn.Dropout(...) with self.dropout =
  linears.Dropout(...) in __init__ to fix CallCompactUnboundModuleError.
- normalizations.py: Qwen3NextRMSNorm signature: eps -> epsilon, accept
  shard_mode/kernel_axes/parameter_memory_host_offload for callsite parity.
- attentions.py / qwen3.py: callsites eps= -> epsilon=.
- moe.py: per_expert_scale block moved into the unfused-kernel else branch
  (was scaling wo even when fused_kernel was active).
- models.py: build MTP block as MultiTokenPredictionBlock(...) directly
  (drop the ToNNX(linen) + lazy_init wrap); pass multimodal_input whole
  to NNXDecoder instead of unpacking 5 fields.
- gradient_accumulation.py: ZeRO-1+GA all-reduce annotation deferred until
  after lax.scan (reduced/unreduced PartitionSpec is rejected inside scan
  carry); use nnx.merge(..., copy=True) to avoid Variable reuse.
- diloco.py: NNX-aware state handling — state.params -> state.model.filter
  (nnx.Param), step counter at state.optimizer.step, replace_nnx_model_params
  helper for jax.lax.cond pytree-structure parity.
- train_compile.py: new _collect_nnx_activation_shardings helper (forward
  pass populates _ACTIVATION_SHARDINGS_DUMP — get_abstract_state_nnx only
  traces __init__); NNX path now passes 2-arg shaped_train_args (no rng);
  diloco path patched to handle the 2-vs-3 length difference.
- muon_utils.py: get_model_mdn default pure_nnx=True; wrap NNX result as
  {"params": nnx.to_pure_dict(...)} for parity with Linen tree shape.
- nnx_decoders.py: FP8+NNX scan fix — Linen FP8 ops (fp8_nanoo, fp8_gpu)
  retain tracers in Linen scope across re-traces. Skip jax.checkpoint and
  use a Python for-loop instead of jax.lax.scan when quantization is FP8.
  Makes FP8 quantization usable on the NNX path.
- train.py (pre-train train_step): return nnx.state(new_state,
  nnx.Not(nnx.Intermediate)) so sowed forward-pass artifacts (e.g. max_logits
  for QK-Clip) don't break leaf-count parity with state_mesh_shardings.
- llama2.py: pass parameter_memory_host_offload to the
  pre_self_attention_layer_norm RMSNorm (was missing on this norm only).
- base.yml: add 4 pipeline-related logical_axis_rules — layers_outside_pipeline,
  layers_per_stage, num_activations, circular_repeats. Additive, no-op without
  use_nnx_pipeline=True.

NNX feature enablements (clear all 17 "Pure NNX support has not been
implemented yet" NotImplementedError sites by routing Linen-coupled
utilities to the Linen path; their on-disk format is Linen):
- layerwise_quantization.py (2 sites): operates on Linen-format checkpoints
  via DeepSeek*ToLinen layers.
- lora_utils.py (1 site): downstream get_lora_abstract_state expects Linen
  tree shape; LoRA adapters on disk are Linen.
- standalone_checkpointer.py (2 sites): add_entropy_to_checkpoint accesses
  state.opt_state[0]._replace(mu=..., nu=...) — Linen-only.
- generate_param_only_checkpoint.py (3 sites): _possibly_unroll_params and
  _save_decode_checkpoint use state.params["params"]["decoder"] — Linen.
- convert_gpt3_ckpt_from_paxml.py (2 sites): keystr_map targets Linen tree
  paths (.params['params'], .opt_state.mu['params']).
- maxengine.py (3 sites): inference engine uses state.params and serves
  Linen-format inference checkpoints.
- grpo_trainer.py (4 sites): RL trainer is end-to-end Linen-shaped; route
  to Linen with a clear log warning since NNX-format checkpoints will fail
  at restore time.

Vocab tiling on NNX (real implementation, not just routing):
- models.py: add Transformer.logits_from_hidden_states on the NNX
  Transformer class — wraps NNXDecoder.apply_output_head with the
  token_embedder; mirrors TransformerLinenPure.logits_from_hidden_states.
- vocabulary_tiling.py: add vocab_tiling_nnx_loss — chunks the vocab axis
  via jax.lax.scan and calls model.logits_from_hidden_states(chunk) per
  chunk. The NNX model carries its parameters internally so no explicit
  FSDP gather is needed (unlike the Linen gathered_params pattern). MVP
  uses default autograd; custom_vjp memory-savings optimization is a
  follow-up if backward memory becomes a concern.
- train.py (NNX loss_fn): replace the NotImplementedError with the call
  to vocab_tiling_nnx_loss using hidden_states from intermediates.
- pyconfig_deprecated.py / configs/types.py: drop the num_vocab_tiling > 1
  and enable_nnx validation guards (no longer needed).

DPO + NNX retained as NotImplementedError but with a much more informative
message (points users at pure_nnx=False workaround). Full implementation
is deferred — needs a new TrainState shape carrying both policy and
reference NNX models plus an NNX dpo_loss_fn.

Stats: 26 source files modified, +406 / -171 lines. Linen invariant
verified: pure_nnx / enable_nnx / pure_nnx_decoder still default to False;
Linen-path UTs unaffected (3 pre-existing failures on the parent branch
remain unchanged — sharding_compare_test::deepseek2-16b,
optimizers_test::test_model_integration_kimi-k2-1t, diloco_test::two_slices
x2). All "Pure NNX support has not been implemented yet"
NotImplementedError sites cleared (was 17, now 0).
Implements NNX-native DPO so that the pure_nnx=True training path no
longer raises NotImplementedError on use_dpo runs. The Linen DPO
overlay pattern (model.apply(params=..., reference_params=...)) does
not translate to NNX modules, which carry their parameters internally.
Instead the policy and reference models are held as separate
nnx.Module instances on TrainStateNNX, and a new dpo_loss_fn_nnx
runs both forwards with stop_gradient on the reference logits.

TrainStateNNX:
- Add optional `reference_model: nnx.Module` field. apply_gradients
  continues to update only `self.model`, leaving `self.reference_model`
  bit-identical across steps.

dpo_utils.py:
- Add dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params,
  reference_model, is_train=True). Signature mirrors the Linen
  dpo_loss_fn so it slots into gradient_accumulation_loss_and_grad's
  dispatcher (dropout_rng / params slots are unused for NNX; carried
  for parity, and reference_model is passed as the single
  extra_dpo_args entry). With nnx.value_and_grad(..., argnums=0) over
  the policy, no gradient flows to the reference model's nnx.Param
  leaves; the explicit jax.lax.stop_gradient on ref_logits is a
  belt-and-braces guard.
- Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include
  indexer_loss=0.0 and mtp_loss=0.0 in aux so the
  gradient_accumulation aux pytree shape matches the non-DPO loss_fn.
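
A minimal sketch of the loss core described above: both models run a forward, the reference logits get stop_gradient, and the standard DPO objective is computed from the chosen/rejected log-ratios. sequence_logprobs is a hypothetical helper standing in for the per-sequence log-prob extraction.

```python
import jax
import jax.numpy as jnp

def dpo_loss_core(policy_model, reference_model, batch, beta=0.1):
  pi_logits = policy_model(batch["input_ids"])
  # Belt-and-braces: nnx.value_and_grad(argnums=0) already excludes the
  # reference model; stop_gradient makes it explicit.
  ref_logits = jax.lax.stop_gradient(reference_model(batch["input_ids"]))
  pi_chosen, pi_rejected = sequence_logprobs(pi_logits, batch)     # hypothetical
  ref_chosen, ref_rejected = sequence_logprobs(ref_logits, batch)  # hypothetical
  logratio_diff = (pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)
  loss = -jnp.mean(jax.nn.log_sigmoid(beta * logratio_diff))
  # Strict >, so identical policy/reference gives loss=log(2), accuracy=0.0.
  reward_accuracy = jnp.mean(logratio_diff > 0)
  return loss, {"reward_accuracy": reward_accuracy}
```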

train.py:
- Drop the NotImplementedError in train_step's NNX branch. When
  use_dpo, dispatch to dpo_loss_fn_nnx with state.reference_model as
  extra_dpo_args; otherwise use the regular loss_fn. eval_step gains
  the same dispatch.
- diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init
  block, so both the GA and non-GA NNX paths route DPO identically.
- Checkpoint-save _split_dpo_state stripping is now Linen-only;
  TrainStateNNX saves whole (reference_model included) — the step-0
  reload later overwrites reference_model from the step-0 checkpoint.

train_utils.py:
- NNX init_state_fn materializes a frozen reference_model alongside
  the policy when config.use_dpo. Both are constructed by
  _create_model_partial() with config.init_weights_seed, so they
  start identical (standard DPO practice) until the step-0 reload.
- Step-0 checkpoint reload: copy step0_state["model"] into
  state["reference_model"]. Linen path unchanged.

Tests:
- New tests/unit/dpo_nnx_test.py (7 tests): TrainStateNNX
  reference_model init/hasattr semantics; apply_gradients leaves
  reference bit-identical; aux key set; identical policy/reference
  yields loss=log(2) and reward_accuracy=0.0 (strict > on equal
  logratios); dropout_rng/params slots are signature-compat only;
  nnx.value_and_grad(argnums=0) over the policy yields finite grads
  on policy params only.
- train_nnx_test.py: drop the two stale negative tests
  (vocab_tiling_raises_not_implemented,
  train_step_dpo_raises_for_nnx) — both features are now real.

Stats: 4 source files + 2 test files, +199/-22 source lines. Linen
DPO path behaviorally unchanged (only adds two harmless aux-dict
keys); NNX non-DPO path unchanged (all changes gated on
config.use_dpo).

PR5 audited maxengine.py and routed the inference path to the Linen
implementation regardless of pure_nnx, with a comment block explaining
that "the flag affects training, not inference serving." That kept the
Linen serving path unchanged but meant pure_nnx=True users silently got
the Linen engine. This change replaces the route with a real NNX flow:
when config.pure_nnx=True, the engine builds an NNX Transformer, splits
out (params, cache, rest) with nnx.split, and at every JIT body merges
the model concretely with nnx.merge to run the forward pass. Linen is
preserved byte-for-byte; every NNX edit is gated `if config.pure_nnx:`
and pure_nnx=False is still the default.

maxengine.py (__init__):
- Build two abstract NNX Transformers on the NNX path: self.model with
  model_mode=PREFILL (batch=1, single padded prompt) and self.model_ar
  with model_mode=AUTOREGRESSIVE (batch=micro_batch_size_to_train_on,
  decode_state shape). Both are needed because NNX cache vars inherit
  CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode,
  and bulk_insert searches for the substring "cache_batch" in the
  AR-mode logical-axes tuple. nnx.eval_shape is called directly inside
  nn_partitioning.axis_rules rather than through create_nnx_abstract_model
  to avoid the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only
  axes like "norm" (same reason get_abstract_state_nnx avoids set_mesh).
- Cache the graphdef from a 3-way nnx.split(Param, Cache, ...) so JIT
  bodies can pass (params, cache, rest) separately to nnx.merge. The
  rest slot (RNG vars etc.) is materialized concretely in load_params.

maxengine.py (cache adapter + _nnx_run_model):
- bulk_insert / _insert_jit / _maybe_*_prefill_result_cache walk the
  cache via tree_map_with_path and switch on path[-1].key (the cache
  variable name like "cached_prefill_key"). Linen mutable cache is a
  plain nested dict. NNX Cache state would expose a ".value" accessor
  at that position. Bridge via nnx.State.to_pure_dict() (after the
  model run) and nnx.replace_by_pure_dict (before nnx.merge), so the
  cache plumbing helpers see the same shape on both paths.
- Add _nnx_run_model: nnx.merge(graphdef, params, cache, rest, copy=True)
  -> model(...) -> nnx.state(model, nnx.Cache).to_pure_dict(). copy=True
  avoids reusing Variable objects across traces (TraceContextError),
  mirroring train.py's diff_wrapper workaround.
- Add _nnx_cache_state_template / _nnx_init_cache_dict helpers
  parametrised by mode so prefill (batch 1) and decode_state (batch N)
  pull from the right abstract model.
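
The _nnx_run_model shape described above, as a hedged sketch (argument handling simplified, cache-template refill schematic):

```python
from flax import nnx

def _nnx_run_model_sketch(graphdef, params, cache_pure_dict, rest,
                          cache_template, *model_args):
  # Re-hydrate the pure-dict cache into an nnx.State with the right metadata.
  nnx.replace_by_pure_dict(cache_template, cache_pure_dict)
  # copy=True gives fresh Variable objects per trace (avoids TraceContextError).
  model = nnx.merge(graphdef, params, cache_template, rest, copy=True)
  logits = model(*model_args)
  # Hand the mutated cache back as a plain nested dict so the Linen-shaped
  # cache helpers (bulk_insert, _insert_jit, ...) keep working unchanged.
  new_cache = nnx.state(model, nnx.Cache).to_pure_dict()
  return logits, new_cache
```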

maxengine.py (load_params):
- New _load_params_nnx: accepts user-provided NNX-shape params or loads
  via from_pretrained. For user-provided params, materializes a concrete
  model once via _create_model_fn() to capture a real rest state for
  nnx.merge (wasteful but simple; the from_pretrained branch avoids
  this). Refreshes self.graphdef from the concrete model so subsequent
  merges line up exactly.
- Builds self.abstract_params, populates self.prefill_kv_cache_annotations
  and self.kv_cache_annotations (using model_ar for the latter so
  bulk_insert's substring lookup hits), wraps both into NamedSharding.
- pure_nnx + quantization, pure_nnx + LoRA, pure_nnx +
  stack_prefill_result_cache=True, pure_nnx + prefill_multisampling,
  and pure_nnx + prefill_concat raise NotImplementedError for now;
  the Linen path is the workaround. AOT compilation
  (aot_compile / _compile_generate_and_get_layouts) is not gated and
  may work as-is; not exercised by tests yet.

maxengine.py (init_decode_state, _prefill_jit, _generate_jit):
- _init_decode_state_nnx zero-initializes a pure-dict cache from
  model_ar (so the leading batch dim matches generate's input shape)
  and builds kv_cache_annotations_named per leaf by reading
  nnx.Cache.metadata. Tries "out_sharding", "sharding", and
  "sharding_names" because Flax 0.12.6 renamed these.
- _prefill_jit / _generate_jit add an `if config.pure_nnx:` branch
  that calls _nnx_run_model in place of self.model.apply with
  mutable=["cache"]. existing_prefix.cache is threaded as a pure-dict
  cache directly (no params|{"cache":...} dict-merge — params is an
  nnx.State, not a dict).

maxtext_utils.py:
- New get_prefill_kv_cache_annotations_nnx / get_kv_cache_annotations_nnx
  that mirror the Linen helpers' return shape (per-leaf PartitionSpec
  tree). Both delegate to _nnx_cache_partition_specs which extracts
  nnx.Cache state via nnx.split, calls
  get_nnx_named_sharding_with_scan_axis inside
  nn_partitioning.axis_rules so logical axes ("layers", "cache_batch",
  "norm", ...) resolve to physical mesh axes, and converts the result
  to a pure-dict tree.

tests/unit/maxengine_test.py:
- New tests: test_init_nnx, test_basic_prefill_nnx (with NaN/inf and
  per-layer cache shape checks), test_basic_decode_nnx (4-step generate
  with next_pos advancement check), test_quantize_raises_for_nnx,
  test_lora_raises_for_nnx.
- New test_linen_nnx_parity_prefill: bridges Linen-init params into
  the NNX engine via linen_nnx_converter (convert_linen_to_nnx ->
  _strip_value_wrappers -> nnx.replace_by_pure_dict) and asserts the
  NNX engine's prefill matches Linen on the same weights — logits
  within bf16 tolerance (rtol=0.05, atol=0.1; the test config uses
  bf16 compute) and exact greedy first-token argmax.
- Existing Linen tests untouched.

Test summary: 9 passed, 1 skipped (test_chunked_prefill is a
pre-existing CPU-only skip). bash lint.sh: codespell + pylint + pyink
all green.
Closes the QK-Clip TODO and migrates the remaining Linen-only
checkpoint utilities to NNX. Linen paths preserved byte-for-byte
(every NNX edit is gated on `config.pure_nnx` or runtime state-shape
detection).

QK-Clip:
- qk_clip_utils.apply_qk_clip_nnx mutates state.model in place via
  nnx.split -> pure-dict tree_map -> nnx.replace_by_pure_dict ->
  nnx.update. Accepts both the production NNX intermediate shape
  (self_attention.attention_op.max_logits) and the synthetic-fixture
  shape from the existing Linen tests (self_attention.max_logits).
- train.py train_step dispatches to apply_qk_clip_nnx for NNX,
  removing the prior TODO at the QK-Clip call site.
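
The in-place update pattern apply_qk_clip_nnx relies on (split, pure dict, tree_map, replace_by_pure_dict, update) looks roughly like this sketch; the per-head clipping math is elided and replaced by a placeholder transform:

```python
from flax import nnx
import jax

def _update_params_in_place(model: nnx.Module, transform_leaf):
  # Split off nnx.Param state, round-trip it through a pure dict, write it back.
  graphdef, params, rest = nnx.split(model, nnx.Param, ...)
  pure = params.to_pure_dict()
  # transform_leaf stands in for the real QK-Clip scaling of wq_b / wkv_b.
  pure = jax.tree.map(transform_leaf, pure)
  nnx.replace_by_pure_dict(params, pure)
  nnx.update(model, params)
```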

Checkpoint utilities (NNX paths added):
- standalone_checkpointer.checkpoint_loop builds an NNX init_state_fn
  under pure_nnx; add_entropy_to_checkpoint dispatches across Linen
  TrainState, NNX TrainStateNNX Module, and post-split nnx.State
  shapes.
- generate_param_only_checkpoint: NNX init_state_fn under pure_nnx;
  _possibly_unroll_params_nnx slices scanned NNX layers via dict-style
  state mutation; _save_decode_checkpoint_nnx writes a bf16 pure-dict
  tree to orbax. Parallel LoRA decode flow operates on the
  single-nested LoRA delta tree from PR8's get_lora_abstract_state_nnx.
- convert_gpt3_ckpt_from_paxml: parallel NNX state_map keystr
  translation (.params['params']<rest> -> .model<rest>.value, etc.).
  End-to-end paxml -> NNX conversion is wired but not yet validated
  on hardware.

Tests:
- qk_clip_test: 7 new NNX cases covering attention-type guard, MLA
  wq_b/wkv_b math, both intermediate shapes, no-clip-below-threshold,
  missing-stats resilience, Linen<->NNX numeric parity.
- standalone_checkpointer_nnx_test (new): 3 cases for adam mu/nu
  overwrite on TrainStateNNX Module shape, no mutation of state.model
  params, post-split nnx.State shape from setup_training_state.
- generate_param_only_checkpoint_nnx_test (new): 3 cases for scanned
  layer slicing (Llama-style; DeepSeek-style dense+moe split; LoRA
  delta unroll on the single-nested NNX shape).

NNX + AQT in MaxEngine and the layerwise_quantization NNX path are
split into the follow-up PR9.5.
Builds on PR9. Migrates the NNX + AQT integration so MaxEngine can both
load pre-quantized checkpoints directly and convert full-precision
checkpoints to int8 on load. Also bundles a pre-existing gpt3 prefill
bug surfaced by the AQT end-to-end validation.

NNX + AQT in MaxEngine:
- model_creation_utils threads quant_mode_str ("train" | "convert" |
  "serve") through from_config / create_model /
  get_nnx_create_model_fn / create_nnx_abstract_model /
  from_pretrained. Default "train" preserves existing callers; "serve"
  propagates to configure_quantization so AQT layers don't materialize
  the full-precision kernel when the on-disk checkpoint already
  carries qrhs scale factors.
- maxengine.__init__ selects the quant mode from
  config.checkpoint_is_quantized; _load_params_nnx drops its
  NotImplementedError. Two paths: pre-quantized
  (checkpoint_is_quantized=True) loads via quant_mode_str="serve";
  full-precision + quantization=int8 loads in TRAIN mode and AQT
  layers quantize per-forward (same numerical result for absmax
  calibration).
- layerwise_quantization._load_and_quantize_nnx: whole-model NNX
  convert path. Loads full-precision in TRAIN mode, transfers kernels
  into a CONVERT-mode model, runs forward to populate qrhs.frozen via
  the ToNNX(AqtDotGeneral) bridge, strips kernels at quantized paths,
  saves serve-mode-shaped state.

Sharding helpers and from_pretrained QTensor handling (5 chained fixes
that kept the serve-mode reload from working):
- maxtext_utils.get_nnx_named_sharding_with_scan_axis emits a
  parallel-tree of replicated NamedSharding leaves when a Variable's
  value is a composite pytree (AQT serve-mode QTensor with a qvalue
  int8 leaf and a list of bf16 scale leaves).
- model_creation_utils.from_pretrained: drops a redundant
  jax.set_mesh wrap in create_nnx_abstract_model that broke serve-mode
  AQT under Flax 0.12.6. _build_value_target / _free_device_memory /
  _unwrap_for_align use Variable.get_value() instead of v[...]
  indexing for QTensor leaves (QTensor.__getitem__ trips on the
  LogicallyPartitioned wrapper around qvalue). Widens the restore
  filter beyond nnx.Param to cover the aqt-typed qrhs.frozen Variable
  type. Skips QTensor leaves in the per-axis shape-alignment dispatch
  (their saved shape already matches the model). _build_value_target
  strips Partitioned wrappers around composite-leaf values so the
  restore tree path matches the on-disk layout (LogicallyPartitioned
  was adding an extra .value key under each QTensor leaf, which made
  orbax silently fill the path with zero-init values).

gpt3 prefill / autoregressive fix (pre-existing, surfaced here):
- Gpt3MultiHeadAttention.__call__ invoked attention_op(...) without
  ever calling update_kv_caches to build cached_values, so any
  non-TRAIN forward (prefill or autoregressive) tripped the
  `assert prefill_kv_cache` check. Mirror the standard Attention
  plumbing in attentions.py: __init__ constructs a KVCache_0 module
  when model_mode != MODEL_MODE_TRAIN, threads
  max_prefill_predict_length into AttentionOp; __call__ calls
  self.KVCache_0(...) and passes [prefill_kv_cache, ar_kv_cache] as
  cached_values to attention_op. TRAIN-mode shape unchanged.

Tests:
- layerwise_quantization_nnx_test (new): 3 cases for
  _strip_kernels_at_quantized_paths covering quantized removal,
  non-quantized preservation (norms, embeddings), mixed-shape trees.
- aqt_serve_roundtrip_nnx_test (new): end-to-end regression test that
  builds a small NNX model in CONVERT mode with int8, runs a forward
  to populate qrhs.frozen via the ToNNX bridge, saves the
  serve-mode-shape state to a tmp local orbax checkpoint, reloads via
  from_pretrained(quant_mode_str="serve"), and asserts every saved
  qrhs.frozen.qvalue array byte-matches what came back. Guards the
  full chain of QTensor / Partitioned / filter fixes.
- maxengine_test: replaced test_quantize_raises_for_nnx with
  test_quantize_passes_gate_for_nnx; added
  test_load_pre_quantized_nnx_passes_quant_gate and
  test_quantized_prefill_nnx_train_mode (real numerical verification
  with quantization=int8 + random params + TRAIN mode).

End-to-end on TPU (gpt3-52k): convert-mode forward + qrhs.frozen
extraction + serve-mode-shape save + reload via
from_pretrained(quant_mode_str="serve") + maxengine.load_params +
quantized prefill forward all work; loaded qrhs.frozen.qvalue
byte-matches the on-disk state.

@ecnal-cienet force-pushed the feat/nnx-aqt-maxengine branch from 7af8157 to d24b69a on May 8, 2026 17:30
