mlnomadpy/flaxchat

flaxchat

A minimal, end-to-end LLM training harness for Google Cloud TPU pods, built on JAX/Flax NNX.

A faithful port of nanochat (Andrej Karpathy's PyTorch GPU trainer) to the JAX ecosystem, with full feature parity plus speculative decoding.

pixi install
pixi run test                     # 148 tests
python -m scripts.run_tinystories # full pipeline on TinyStories

What is this?

flaxchat is a complete LLM pipeline that runs natively on TPUs and GPUs with automatic data parallelism:

| Stage | Script | Description |
|---|---|---|
| Tokenizer | scripts/tok_train.py | Train BPE tokenizer (rustbpe + tiktoken) |
| Pretrain | scripts/pretrain.py | Pretrain GPT on ClimbMix-400B or TinyStories |
| SFT | scripts/sft.py | Supervised fine-tuning on conversations |
| RL | scripts/rl.py | GRPO/REINFORCE on GSM8K with tool use |
| Eval | scripts/eval.py | CORE metric, MMLU, ARC, GSM8K, HumanEval |
| Chat | scripts/chat_web.py | FastAPI WebSocket chat UI |
| Local | scripts/run_tinystories.py | Full pipeline on TinyStories (laptop or GPU) |
| Export | scripts/convert_to_tflite.py | LiteRT/TFLite export for edge deployment |

~7,500 lines of readable, hackable JAX/Flax NNX code across 45 Python files.

Architecture

The GPT model faithfully replicates every feature from nanochat:

  • Rotary Embeddings (RoPE) with 100K base theta
  • Group-Query Attention (GQA) via jax.nn.dot_product_attention (hardware-adaptive)
  • QK Normalization with 1.2x scaling for sharper attention
  • ReLU^2 MLP (squared ReLU activation)
  • Value Embeddings (ResFormer-style, alternating layers with gating)
  • Sliding Window Attention per-layer configurable (SSSL pattern)
  • Per-layer Residual Scaling (resid_lambdas + x0_lambdas)
  • Smear — cheap bigram-like token mixing from previous position
  • Backout — subtract mid-layer residual to remove low-level features
  • Logit Soft-capping via tanh(x/15)*15
  • Gradient Checkpointing (nnx.remat, dots_saveable policy)
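
Two of the listed components are simple scalar functions, and a framework-free sketch may make them concrete. The names `relu_squared` and `softcap` below are illustrative, not flaxchat identifiers; the real model would apply the same formulas elementwise to JAX arrays.

```python
import math


def relu_squared(x: float) -> float:
    """Squared ReLU activation: max(0, x) ** 2."""
    return max(0.0, x) ** 2


def softcap(logit: float, cap: float = 15.0) -> float:
    """Smoothly bound a logit to the open interval (-cap, cap).

    Implements cap * tanh(logit / cap), i.e. tanh(x/15)*15 for cap=15.
    """
    return cap * math.tanh(logit / cap)


if __name__ == "__main__":
    # Small logits pass through almost unchanged; large ones saturate near the cap.
    for value in (-100.0, -1.0, 0.0, 1.0, 100.0):
        print(value, relu_squared(value), softcap(value))
```

Soft-capping keeps gradients finite for extreme logits while leaving typical values close to the identity.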

Optimizer

Mixed Muon + AdamW (ported to optax):

| Group | Optimizer | Notes |
|---|---|---|
| Attention/MLP matrices | Muon | Polar Express orthogonalization + NorMuon variance reduction |
| Embeddings (wte) | AdamW | b1=0.8, b2=0.995 |
| LM head | AdamW | Lower LR for stability |
| Value embeddings | AdamW | Half embedding LR |
| Per-layer scalars | AdamW | Separate groups for resid_lambdas and x0_lambdas |
| Smear/Backout | AdamW | No weight decay |
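
A mixed optimizer like this needs a rule that routes each parameter to a group. The sketch below shows one possible routing by parameter path. It is an assumption, not flaxchat's code: the function name, the path substrings other than `wte`, `resid_lambdas`, and `x0_lambdas`, and the returned labels are all hypothetical.

```python
def optimizer_group(param_path: str, ndim: int) -> str:
    """Return a hypothetical optimizer-group label for one parameter.

    param_path is a slash-separated parameter name; ndim is the array rank.
    """
    if "wte" in param_path:
        return "adamw_embedding"
    if "lm_head" in param_path:
        return "adamw_lm_head"
    if "value_embed" in param_path:
        return "adamw_value_embedding"
    if "resid_lambdas" in param_path:
        return "adamw_resid_lambdas"
    if "x0_lambdas" in param_path:
        return "adamw_x0_lambdas"
    if "smear" in param_path or "backout" in param_path:
        return "adamw_no_decay"
    # Remaining 2-D weights (attention and MLP matrices) go to Muon.
    if ndim == 2:
        return "muon"
    return "adamw_default"


if __name__ == "__main__":
    print(optimizer_group("blocks/0/attn/c_q/kernel", 2))
    print(optimizer_group("wte/embedding", 2))
```

Labels of this kind could then be handed to a label-based combinator such as `optax.multi_transform`, with one transform per label.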

LR schedule: warmup (40 steps) -> constant -> warmdown (spanning 65% of total steps, ending at 5% of the peak LR). The optimizer falls back to pure AdamW on Flax 0.11, where NamedTuple optimizer state has issues.
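
The schedule can be written as a multiplier on the peak learning rate. The sketch below assumes linear ramps for both warmup and warmdown, and reads "65% of total" as the fraction of steps spent in warmdown; neither detail is confirmed by the text above, and `lr_multiplier` is an illustrative name.

```python
def lr_multiplier(step: int, total_steps: int, warmup_steps: int = 40,
                  warmdown_frac: float = 0.65, final_frac: float = 0.05) -> float:
    """Return the factor applied to the peak LR at a given step."""
    warmdown_steps = round(warmdown_frac * total_steps)
    warmdown_start = total_steps - warmdown_steps
    if step < warmup_steps:
        # Assumed linear warmup from 1/warmup_steps up to 1.0.
        return (step + 1) / warmup_steps
    if warmdown_steps == 0 or step < warmdown_start:
        return 1.0
    # Assumed linear warmdown from 1.0 to final_frac.
    progress = min(1.0, (step - warmdown_start) / warmdown_steps)
    return 1.0 - progress * (1.0 - final_frac)


if __name__ == "__main__":
    for s in (0, 39, 100, 350, 675, 1000):
        print(s, lr_multiplier(s, total_steps=1000))
```

With 1,000 total steps, the constant phase runs from step 40 to step 350, and the multiplier then decays to 0.05 by the final step.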

Inference Engine

Four generation modes with increasing performance:

| Mode | Function | Speed | Use Case |
|---|---|---|---|
| Padded | generate() | ~1-2 tok/s | Testing, debugging |
| KV-cached | generate_with_cache() | ~10-50 tok/s | Production, Python loop |
| Fully JIT | generate_fast() | ~200+ tok/s | TPU inference via jax.lax.while_loop |
| Speculative | generate_speculative() | ~2-4x KV-cached | Large model + small draft model |

Tool Use

The Engine class provides streaming generation with automatic tool execution:

engine = Engine(model, tokenizer)
for token_column, masks in engine.generate(prompt_ids, num_samples=3, max_tokens=256):
    print(tokenizer.decode([token_column[0]]), end="")

When the model outputs <|python_start|>2+2<|python_end|>, the engine:

  1. Tries the safe calculator (use_calculator) for math and string.count()
  2. Falls back to sandboxed Python execution (execute_code) for general code
  3. Injects <|output_start|>4<|output_end|> tokens back into the stream
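
The three steps above can be sketched on decoded text. This is a simplified illustration, not the Engine's implementation: it works on strings rather than token IDs, supports only basic arithmetic (no string.count()), and leaves the call untouched where the real engine would fall back to the sandbox. `safe_calc` and `inject_tool_output` are hypothetical names.

```python
import ast
import operator
import re

# Arithmetic operators the toy calculator accepts.
_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv}


def safe_calc(expr: str):
    """Evaluate +, -, *, / on numeric literals; return None for anything else."""
    def ev(node):
        if (isinstance(node, ast.Constant)
                and isinstance(node.value, (int, float))
                and not isinstance(node.value, bool)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](ev(node.left), ev(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -ev(node.operand)
        raise ValueError("unsupported expression")
    try:
        return ev(ast.parse(expr, mode="eval").body)
    except (SyntaxError, ValueError, ZeroDivisionError):
        return None


TOOL_CALL = re.compile(r"<\|python_start\|>(.*?)<\|python_end\|>", re.DOTALL)


def inject_tool_output(text: str) -> str:
    """Append an output span after each tool call the calculator can handle."""
    def repl(match):
        result = safe_calc(match.group(1))
        if result is None:
            return match.group(0)  # a real engine would try the sandbox here
        return f"{match.group(0)}<|output_start|>{result}<|output_end|>"
    return TOOL_CALL.sub(repl, text)


if __name__ == "__main__":
    print(inject_tool_output("<|python_start|>2+2<|python_end|>"))
```

Walking the parsed syntax tree, instead of calling `eval`, is what keeps the calculator path safe: function calls, attribute access, and names are rejected outright.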

Speculative Decoding

Use a smaller draft model to propose tokens, verified in batch by the main model:

from flaxchat.engine import generate_speculative

# draft_model: 2-layer, model: 12-layer (same vocab)
tokens = generate_speculative(model, draft_model, prompt_ids, draft_steps=4)
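
For intuition, here is a toy greedy version of the propose-and-verify loop. It is a sketch under stated assumptions, not `generate_speculative` itself: both "models" are plain functions mapping a token list to the next token, verification is done position by position instead of in one batched forward pass, and sampling is greedy. All names are illustrative.

```python
from typing import Callable, List

NextToken = Callable[[List[int]], int]  # greedy next-token function


def speculative_greedy(target: NextToken, draft: NextToken, prompt: List[int],
                       max_new_tokens: int, draft_steps: int = 4) -> List[int]:
    """Greedy speculative decoding; output matches target-only greedy decoding."""
    tokens = list(prompt)
    limit = len(prompt) + max_new_tokens
    while len(tokens) < limit:
        # 1. The draft model proposes draft_steps tokens autoregressively.
        proposal: List[int] = []
        for _ in range(draft_steps):
            proposal.append(draft(tokens + proposal))
        # 2. The target checks each proposed position in order.
        for i, guess in enumerate(proposal):
            expected = target(tokens + proposal[:i])
            if guess != expected:
                # First mismatch: keep the accepted prefix plus the target's token.
                tokens.extend(proposal[:i] + [expected])
                break
        else:
            # Every proposal accepted: the target adds one bonus token.
            tokens.extend(proposal + [target(tokens + proposal)])
    return tokens[:limit]


def toy_target(ctx: List[int]) -> int:
    return (ctx[-1] + 1) % 10


def toy_draft(ctx: List[int]) -> int:
    # Agrees with the target except right after token 4.
    return 0 if ctx[-1] == 4 else (ctx[-1] + 1) % 10


if __name__ == "__main__":
    print(speculative_greedy(toy_target, toy_draft, [3], max_new_tokens=4))
```

Each loop iteration emits at least one token chosen by the target, so a poor draft model costs speed but never changes the greedy output.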

Sandboxed Code Execution

For HumanEval evaluation and RL tool use:

from flaxchat.execution import execute_code

result = execute_code("print(sum(range(10)))", timeout=5.0)
# ExecutionResult(success=True, stdout="45\n", stderr="", error=None)

The sandbox combines process isolation, signal-based timeouts, memory limits (Linux only), and blocking of dangerous functions.
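
A minimal sketch of the process-isolation and timeout ideas, using only the standard library. It is not flaxchat's `execute_code`: it relies on `subprocess`'s own timeout rather than signals, and it applies no memory limits or function blocking. `RunResult` and `run_isolated` are hypothetical names that loosely mirror the result fields shown above.

```python
import subprocess
import sys
from dataclasses import dataclass
from typing import Optional


@dataclass
class RunResult:
    success: bool
    stdout: str
    stderr: str
    error: Optional[str] = None


def run_isolated(code: str, timeout: float = 5.0) -> RunResult:
    """Run code in a separate Python process and capture its output."""
    try:
        proc = subprocess.run(
            [sys.executable, "-I", "-c", code],  # -I: isolated mode, ignores user site and env vars
            capture_output=True, text=True, timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child before raising.
        return RunResult(False, "", "", "timeout")
    ok = proc.returncode == 0
    return RunResult(ok, proc.stdout, proc.stderr,
                     None if ok else f"exit code {proc.returncode}")


if __name__ == "__main__":
    print(run_isolated("print(sum(range(10)))"))
```

A separate process means a crash or infinite loop in generated code cannot take down the evaluator, which matters when scoring many untrusted HumanEval completions.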

Parallelism (built-in, not optional)

  • compute_init() creates a mesh over ALL available devices automatically
  • Data parallelism: with_sharding_constraint(data, P('data')) in every train step
  • FSDP: shard_model_fsdp() for models exceeding single-device memory
  • Multi-host: jax.distributed.initialize() + jax.make_array_from_process_local_data()
  • No manual all-reduce — JAX SPMD compiler handles gradient synchronization
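
To see what the compiler-inserted gradient synchronization computes, consider a framework-free toy: with equal-sized shards and a mean loss, averaging per-device gradients reproduces the full-batch gradient. This is a conceptual sketch in plain Python, not JAX code, and the function names are illustrative.

```python
from typing import List


def grad_mse(w: float, xs: List[float], ys: List[float]) -> float:
    """Gradient with respect to w of mean((w*x - y)^2) over the batch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def data_parallel_grad(w: float, xs: List[float], ys: List[float],
                       n_devices: int) -> float:
    """Split the batch into equal shards, compute local gradients, then average."""
    if len(xs) % n_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    shard = len(xs) // n_devices
    local = [
        grad_mse(w, xs[d * shard:(d + 1) * shard], ys[d * shard:(d + 1) * shard])
        for d in range(n_devices)
    ]
    # This average is the step an all-reduce (mean) performs across devices.
    return sum(local) / n_devices


if __name__ == "__main__":
    xs = [float(i) for i in range(1, 9)]
    ys = [3.0 * x for x in xs]
    print(grad_mse(0.5, xs, ys), data_parallel_grad(0.5, xs, ys, n_devices=4))
```

Under SPMD, the same result falls out of the sharding annotations: the program is written as if for the full batch, and the compiler places the reduction.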

Configuration

Single-dial depth-based auto-config — all hyperparameters derive from depth:

from flaxchat.config import FlaxChatConfig

config = FlaxChatConfig.from_depth(
    depth=12,            # 12 layers
    aspect_ratio=64,     # base_dim = 12 * 64 = 768
    head_dim=128,        # n_heads = 768 / 128 = 6
    max_seq_len=2048,
    window_pattern="SSSL",
)
# -> 12 layers, 768 dims, 6 heads, ~79M params
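
The arithmetic in the comments above can be reproduced with a few lines. This sketch covers only the width and head count; `DerivedDims` and `derive_dims` are hypothetical names, and rejecting widths that are not a multiple of `head_dim` is an assumption about behavior the source does not describe.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DerivedDims:
    n_layer: int
    n_embd: int
    n_head: int


def derive_dims(depth: int, aspect_ratio: int = 64, head_dim: int = 128) -> DerivedDims:
    """Derive model width and head count from depth alone."""
    n_embd = depth * aspect_ratio  # e.g. 12 * 64 = 768
    if n_embd % head_dim != 0:
        # Assumption: this sketch rejects such widths instead of rounding them.
        raise ValueError(f"n_embd={n_embd} is not a multiple of head_dim={head_dim}")
    return DerivedDims(n_layer=depth, n_embd=n_embd, n_head=n_embd // head_dim)


if __name__ == "__main__":
    print(derive_dims(12))
```

Tying width to depth through one ratio is what makes depth a single dial: a deeper model is automatically a wider one.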

Evaluation Tasks

| Task | Type | Source |
|---|---|---|
| MMLU | Categorical (4-choice) | cais/mmlu |
| ARC-Challenge | Categorical | allenai/ai2_arc |
| GSM8K | Generative (math + calculator) | openai/gsm8k |
| HumanEval | Generative (code + sandbox) | openai/humaneval |
| SpellingBee | Generative (tool use) | Built-in (30+ templates) |
| SmolTalk | Conversation quality | HuggingFaceTB/smol-smoltalk |
| CORE | ICL benchmark (DCLM paper) | Hellaswag, ARC, PIQA, Winogrande |

Quick Start

Install

pixi install    # or: pip install -e ".[dev]"

Train locally on TinyStories

python -m scripts.run_tinystories --depth=4 --steps=1000

Full pipeline on TPU pod

python -m scripts.pretrain --depth=24 --num-iterations=50000
python -m scripts.sft --base-model=d24
python -m scripts.rl --model=d24
python -m scripts.eval --model=d24 --tasks=all
python -m scripts.chat_web --model=d24

Remote execution

# Kaggle GPU (via kgz)
from flaxchat.remote import KaggleRunner
runner = KaggleRunner("https://...")
runner.run_pipeline(depth=8, steps=5000)

# GCP TPU (via tpuz)
from tpuz import TPU
tpu = TPU("my-tpu", accelerator="v6e-8")
tpu.up()
tpu.setup(extra_pip="flaxchat")
tpu.run("python -m scripts.pretrain --depth=12", sync=".")

Project Structure

flaxchat/
├── flaxchat/                  # Core library (~3,500 LOC)
│   ├── gpt.py                 # GPT model (all nanochat features)
│   ├── optim.py               # Mixed Muon+AdamW optimizer (optax)
│   ├── engine.py              # Inference: padded, cached, JIT, speculative, tool use
│   ├── execution.py           # Sandboxed Python code execution
│   ├── eval.py                # CORE metric + BPB evaluation
│   ├── dataloader.py          # BOS-aligned best-fit packing
│   ├── tokenizer.py           # BPE tokenizer (rustbpe + tiktoken + HF)
│   ├── config.py              # Depth-based auto-config
│   ├── common.py              # Mesh, distributed, logging
│   ├── checkpoint.py          # Orbax checkpoint manager
│   ├── report.py              # Training reports
│   └── dataset.py             # Parquet file listing
├── scripts/                   # Executable scripts (~2,500 LOC)
├── tasks/                     # Evaluation tasks (MMLU, ARC, GSM8K, HumanEval, ...)
├── tests/                     # 148 unit tests
├── docs/                      # GitHub Pages documentation
├── configs/                   # YAML configuration templates
└── runs/                      # Launch scripts

Test Suite

148 tests across 10 test files:

| File | Tests | Coverage |
|---|---|---|
| test_model.py | 23 | GPT architecture, forward pass, loss, gradients, masking, JIT |
| test_engine.py | 17 | All 4 gen modes, speculative decoding, tool use |
| test_optim.py | 17 | Muon, LR/WD/momentum schedules |
| test_execution.py | 19 | Sandbox, timeout, safety guards |
| test_tokenizer.py | 15 | BPE train/encode/decode/save/load |
| test_checkpoint.py | 10 | Orbax save/load round-trip |
| test_eval.py | 9 | CORE, multiple-choice, generative |
| test_dataloader.py | 8 | BOS packing, sharding |
| test_config.py | 8 | Depth scaling, YAML/JSON |
| test_common.py | 13 | Mesh, dtype, distributed |

Verified Results

Full Pipeline: Pretrain -> SFT -> RL (Kaggle TPU v5e-8)

End-to-end training pipeline completed on a single Kaggle TPU v5e-8 session:

| Stage | Dataset | Steps | Loss | Throughput | Time |
|---|---|---|---|---|---|
| Pretrain | FineWeb-Edu (2B tokens) | 15,258 | 10.4 -> 2.94 | 379K tok/s | ~1.5h |
| SFT | SmolTalk (50K conversations) | 2,000 | 2.94 -> 1.82 | | ~7 min |
| GRPO | GSM8K (math + calculator) | 500 | RL training running | | |

  • Model: 12L/768d/6h (GQA: 3kv) = 203.7M params
  • Hardware: Kaggle TPU v5e-8 (8 chips, bf16)
  • W&B: irf-sic/flaxchat

Chinchilla Scaling Law (TRC TPU v6e-8)

Nanochat architecture (value embeddings, sliding window, tied embeddings) trained at Chinchilla-optimal token budgets (20× params) on C4 with plain AdamW:

| Depth | Params | Tokens | Final Loss | Throughput |
|---|---|---|---|---|
| 2 | 9M | 0.18B | 7.28 | 1.4M tok/s |
| 4 | 28M | 0.56B | 5.79 | 1.1M tok/s |
| 6 | 61M | 1.22B | 4.24 | 800K tok/s |
| 8 | 109M | 2.18B | 3.95 | 600K tok/s |
| 12 | 261M | 5.22B | 3.42 | 500K tok/s |
| 16 | 503M | 10.06B | 3.39 | 290K tok/s |
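
The Tokens column follows directly from the 20× rule. A short check, using the parameter counts from the table (the function name is illustrative):

```python
def chinchilla_tokens(n_params: float, tokens_per_param: float = 20.0) -> float:
    """Token budget under the 20-tokens-per-parameter rule of thumb."""
    return n_params * tokens_per_param


if __name__ == "__main__":
    # (depth, parameters in millions) as listed in the table above.
    for depth, params_m in [(2, 9), (4, 28), (6, 61), (8, 109), (12, 261), (16, 503)]:
        tokens_b = chinchilla_tokens(params_m * 1e6) / 1e9
        print(f"d={depth}: {params_m}M params -> {tokens_b:.2f}B tokens")
```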

[Figure: Scaling Law]

GELU MLP Ablation — d=12 Chinchilla (TRC TPU v6e-8)

The nanochat architecture at d=12 with GELU replacing the default ReLU² in the MLP (Linear → gelu → Linear), trained on C4 to Chinchilla 20× (5.22B tokens) with plain AdamW. 3 seeds for variance estimation.

| Metric | Seed 0 | Seed 1 | Seed 2 | Mean ± Std |
|---|---|---|---|---|
| C4 smooth loss | 3.1106 | 3.1097 | 3.1261 | 3.1155 ± 0.008 |
| Throughput | 703K tok/s | 717K tok/s | 717K tok/s | 712K ± 7K tok/s |
| Wall time | 2.06 h | 2.02 h | 2.02 h | 2.03 h |
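
The summary column can be recomputed from the three seeds. The source does not say which standard deviation is used; on these numbers the population form rounds to 0.008 and the sample form to 0.009, so the reported value appears consistent with the population form.

```python
import statistics

# Per-seed C4 smooth losses from the table above.
losses = [3.1106, 3.1097, 3.1261]

mean = statistics.fmean(losses)
pop_std = statistics.pstdev(losses)    # divides by N
sample_std = statistics.stdev(losses)  # divides by N - 1

if __name__ == "__main__":
    print(f"mean={mean:.4f} population_std={pop_std:.4f} sample_std={sample_std:.4f}")
```

With only three seeds, either estimate is rough; the spread is best read as an order of magnitude for seed-to-seed noise.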

Downstream evaluation (seed 0):

| Benchmark | Score |
|---|---|
| Wikitext-103 PPL | 46.52 |
| LAMBADA accuracy | 18.4% |
| LAMBADA PPL | 42.0 |
| HellaSwag | 31.4% |
| ARC-Easy | 34.5% |

Config: d=12, n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL sliding window, batch 256, LR 0.01 warmup-cosine-decay. Hardware: single TPU v6e-8 (TRC, europe-west4-a).

Pretrained weights: mlnomad/gelu-d12-chinchilla-261M (Flax/Orbax) · mlnomad/gelu-d12-chinchilla-261M-pytorch (PyTorch, AutoModelForCausalLM compatible)

# Load the PyTorch weights and the tokenizer:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("mlnomad/gelu-d12-chinchilla-261M-pytorch", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

TinyStories Baselines

| Hardware | Model | Throughput | Loss | Time |
|---|---|---|---|---|
| Kaggle 2xT4 GPU | 8L/256d (18.9M) | 55K tok/s | 2.20 | 50 min |
| Kaggle TPU v5e-8 | 8L/512d (90.2M) | 149K tok/s | 2.79 | 109 s |

148 tests passing on CPU (local), GPU (Kaggle 2xT4), and TPU (v5e-8).

Comparison with nanochat

| Aspect | nanochat | flaxchat |
|---|---|---|
| Framework | PyTorch | JAX/Flax NNX |
| Hardware | NVIDIA GPU (8xH100) | TPU pods + GPUs |
| Distributed | DDP + torch.distributed | JAX SPMD mesh (automatic) |
| Compile | torch.compile | jax.jit / nnx.jit |
| Attention | Flash Attention 3 | jax.nn.dot_product_attention |
| Precision | bf16/fp16/fp8 | bf16 (TPU native) |
| Optimizer | Custom MuonAdamW | Custom optax Muon+AdamW |
| Checkpointing | Pickle-based | Orbax (async, cloud-friendly) |
| Generation | KV-cache + Python loop | 4 modes: padded, cached, JIT, speculative |
| Tool use | Calculator + Python REPL | Calculator + sandboxed REPL |
| Remote execution | N/A | Kaggle (kgz) + TPU (tpuz) |
| Config | Manual | Depth-based auto-scaling |

Acknowledgments

This project is part of the 2026 Q1 TPU Sprint, supported by the Google AI Developer Programs team.

We gratefully acknowledge:

  • Google AI Developer Programs for issuing GCP credits that made large-scale training experiments possible
  • TPU Research Cloud (TRC) for providing free access to Cloud TPU v4, v5e, and v6e accelerators
  • Kaggle for providing free TPU v5e access for prototyping and validation

License

MIT
