Tandem-RLVR

Tandem Reinforcement Learning (TRL) carries the tandem-training paradigm into RLVR (reinforcement learning with verifiable rewards). Each rollout is co-generated by a trainable senior and a frozen junior initialized from the senior's pre-RL base. The two alternate at word boundaries, and the standard GRPO loss is applied to senior-emitted tokens only. This repo contains the vLLM-native TRL backend, the verl training pipeline used for the paper's main experiments on Qwen3-4B-Instruct-2507 + DeepScaleR, and the tandem-eval suite.
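To make "loss on senior-emitted tokens only" concrete, here is a minimal sketch, not the repo's implementation. It assumes a per-token authorship mask (1 = senior, 0 = junior) and uses a plain on-policy policy-gradient surrogate in place of the full clipped GRPO objective. The function names `group_advantages` and `senior_only_pg_loss` are hypothetical.

```python
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-6):
    """GRPO-style advantages: standardise rewards within one prompt's rollout group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def senior_only_pg_loss(logprobs, authorship_masks, advantages):
    """Token-mean policy-gradient loss restricted to senior-emitted tokens.

    logprobs[i][t]         -- senior log-prob of token t in rollout i
    authorship_masks[i][t] -- 1 if the senior emitted token t, 0 if the junior did
    advantages[i]          -- one scalar advantage per rollout
    Simplification (assumption): no importance ratio or clipping.
    """
    total, count = 0.0, 0
    for lp_seq, mask_seq, adv in zip(logprobs, authorship_masks, advantages):
        for lp, m in zip(lp_seq, mask_seq):
            if m:  # junior tokens contribute nothing to the loss
                total += -adv * lp
                count += 1
    return total / max(count, 1)


if __name__ == "__main__":
    adv = group_advantages([1.0, 0.0])
    loss = senior_only_pg_loss(
        logprobs=[[-1.0, -2.0], [-0.5, -3.0]],
        authorship_masks=[[1, 0], [0, 1]],
        advantages=adv,
    )
    print(loss)
```

Because junior tokens are masked out, changing their log-probs leaves the loss unchanged. The junior affects the senior only through the shared context and the rollout reward.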

Paper ↔ code naming map: model_spec/TERMINOLOGY.md. The current source still spells senior/junior as primary/frozen; a sweeping rename is queued.

1. vLLM-native TRL backend (vllm_source/)

vLLM 0.8.5 with a surgical dual-decoder patch.

| Module | Role |
| --- | --- |
| v1/worker/tandem.py (TandemModelManager) | Loads the junior into the same engine with prefix tandem_junior., on a separate device, with its own paged KV cache. Shares per-step attention metadata with the senior. |
| v1/sample/tandem_sampler.py (TandemSampler) | Five selection strategies; paper default is word: Bernoulli(p=0.5) at every orthographic word boundary, with a max_gap_tokens=32 cap inside long subword-only runs. |
| config.py (TandemConfig) | Configuration surface, exposed to verl via engine_kwargs.vllm.tandem_config.*. |
| v1/core/sched/, v1/engine/, outputs.py | Per-token authorship mask propagated through the engine output chain; surfaced as output.outputs[k].authorship_mask. |
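The word strategy and its gap cap can be illustrated with an offline sketch that replays a fixed token list. This is one reading of the description above, not the TandemSampler code. It assumes a word boundary is any token whose text starts with whitespace, that the author is redrawn at each boundary, and that a redraw is forced once max_gap_tokens tokens have passed without one. The function name `assign_authors` is hypothetical.

```python
import random


def assign_authors(tokens, prob_primary=0.5, max_gap_tokens=32, seed=0):
    """Return a per-token authorship mask (1 = senior, 0 = junior).

    The author is redrawn with Bernoulli(prob_primary) at every token that
    begins with whitespace (assumed word boundary), and also whenever
    max_gap_tokens tokens have been emitted since the last redraw.
    """
    rng = random.Random(seed)
    current = 1 if rng.random() < prob_primary else 0
    mask, gap = [], 0
    for i, tok in enumerate(tokens):
        at_boundary = tok[:1].isspace()
        if i > 0 and (at_boundary or gap >= max_gap_tokens):
            current = 1 if rng.random() < prob_primary else 0
            gap = 0
        mask.append(current)
        gap += 1
    return mask


if __name__ == "__main__":
    toks = ["The", " quick", "est", " fox", " jump", "ed"]
    print(list(zip(toks, assign_authors(toks))))
```

In this sketch, subword continuations such as "est" and "ed" always inherit the author of the token that opened the word. A live sampler has to make the same decision online, one decoding step at a time.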

Per-step latency is ≈ 2× single-model vLLM. Per-file detail in model_spec/SPEC.md §6.

Install (Linux, 2× A100 80GB)

The full install — base image, pinned pip stack, our vLLM overlay, editable verl, and five import-smoke checks — is captured in Dockerfile.repro. Build (~7 min from scratch, ~18 GB final image) and drop into the container:

git clone https://github.com/CSSLab/Tandem-RLVR.git && cd Tandem-RLVR
docker build -f Dockerfile.repro -t tandem-rlvr:repro .
# Create the host-side scratch directory first so Docker does not create it as root
mkdir -p scratch
docker run --gpus all -it -v "$(pwd)/scratch:/workspace/Tandem-RLVR/scratch" tandem-rlvr:repro bash

Pinned stack inside the image: Python 3.10, vLLM 0.8.5, torch 2.6.0+cu124, transformers 4.57.3, flash-attn 2.7.4.post1, flashinfer 0.2.2.post1 (full list in Dockerfile.repro).
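If you install outside the image, a quick pin check can catch version drift early. The sketch below is not one of the repo's smoke checks. It covers only three of the pins listed above, and it assumes the distribution names `vllm`, `torch`, and `transformers`. The helper `check_pins` is hypothetical.

```python
from importlib import metadata

# Subset of the pinned stack; extend with the remaining pins from Dockerfile.repro.
PINNED = {
    "vllm": "0.8.5",
    "torch": "2.6.0+cu124",
    "transformers": "4.57.3",
}


def check_pins(pinned=PINNED, lookup=metadata.version):
    """Return {distribution: (wanted, found)} for every pin that does not match.

    `found` is None when the distribution is not installed.
    """
    problems = {}
    for dist, want in pinned.items():
        try:
            got = lookup(dist)
        except metadata.PackageNotFoundError:
            got = None
        if got != want:
            problems[dist] = (want, got)
    return problems


if __name__ == "__main__":
    for dist, (want, got) in check_pins().items():
        print(f"{dist}: expected {want}, found {got}")
```

An empty result means the three checked pins match. It does not show that the vLLM overlay or the editable verl install are in place.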

2. Training (verl, DeepScaleR, Qwen3-4B-Instruct-2507)

The launchers source scratch/wandb_secrets.env for credentials (gitignored). Create it on the host before running:

cat > scratch/wandb_secrets.env <<'EOF'
WANDB_API_KEY=<your key>
WANDB_ENTITY=<your entity>
WANDB_PROJECT_TANDEM_NATIVE_GRPO_DEEPSCALER=tandem-native-grpo-deepscaler
WANDB_PROJECT_VANILLA_GRPO_DEEPSCALER=vanilla-grpo-deepscaler
EOF

Then:

bash verl/run_tandem_native_grpo_deepscaler.sh    # canonical TRL
bash verl/run_vanilla_grpo_deepscaler.sh          # matched-optimisation GRPO baseline

Canonical TRL config (paper §3, §4.1): selection_strategy=word, prob_primary=0.5, max_gap_tokens=32, tandem_jr_tkn_weight=0, temperature=0.6, use_kl_loss=False, entropy_coeff=0, senior on cuda:0, junior on cuda:1 (both TP=1). Both senior and self-paired junior init from Qwen/Qwen3-4B-Instruct-2507. Train on DeepScaleR; validate on AMC 23–25 / AIME 24–26 / Minerva at n=4 per checkpoint. Wall-clock on 2× A100 80GB: ~9.4h to TRL best (step 120), ~7.8h to GRPO best (step 200).
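As an illustration of how the sampler-side settings might be passed through the engine_kwargs.vllm.tandem_config.* surface, the sketch below renders them as dotted key=value overrides. It assumes selection_strategy, prob_primary, and max_gap_tokens are TandemConfig fields, and that the launcher accepts Hydra-style overrides under some leading path that is not shown here. The helper `to_overrides` is hypothetical; the launcher scripts remain the authoritative source.

```python
# Prefix as named in the module table; the full override path used by the
# launchers may carry additional leading components (assumption).
TANDEM_PREFIX = "engine_kwargs.vllm.tandem_config"

# Sampler-side values from the canonical TRL config.
CANONICAL_TANDEM = {
    "selection_strategy": "word",
    "prob_primary": 0.5,
    "max_gap_tokens": 32,
}


def to_overrides(prefix, values):
    """Render a flat dict as dotted key=value override strings."""
    return [f"{prefix}.{key}={value}" for key, value in values.items()]


if __name__ == "__main__":
    for line in to_overrides(TANDEM_PREFIX, CANONICAL_TANDEM):
        print(line)
```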

3. Evaluation

python tandem_eval/MATH/eval_solo.py              # Figure 2: solo pass@k
python tandem_eval/MATH/eval_sentence_handoff.py  # Figure 3: handoff pass@k

Both scripts run vLLM, dump per-sample scores to jsonl, and compute the unbiased pass@k estimator (Chen et al., 2021). eval_sentence_handoff.py pairs the senior with the frozen Qwen3-4B-Instruct-2507 junior, alternating at \n\n-boundary tokens (one reasoning step at a time). Grader is tandem_eval/grader.py (boxed-extraction, dispatched to verl/utils/reward_score).
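For reference, the unbiased estimator of Chen et al. (2021) for a problem with n samples, c of them correct, is pass@k = 1 - C(n-c, k) / C(n, k). The sketch below is a standalone version of that formula. It is not copied from the eval scripts, and the function name `pass_at_k` is an assumption.

```python
from math import comb


def pass_at_k(n, c, k):
    """Unbiased pass@k for one problem: n samples drawn, c correct, budget k."""
    if not (0 <= c <= n) or not (1 <= k <= n):
        raise ValueError("need 0 <= c <= n and 1 <= k <= n")
    if n - c < k:
        # Fewer than k incorrect samples: every size-k subset contains a correct one.
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)


def mean_pass_at_k(results, k):
    """Average pass@k over problems; results is a list of (n, c) pairs."""
    return sum(pass_at_k(n, c, k) for n, c in results) / len(results)


if __name__ == "__main__":
    print(pass_at_k(n=4, c=1, k=2))
```

With n=4 and c=1, three of the six possible pairs contain no correct sample, so pass@2 is 0.5.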
