feat: support SD3.5 GRPO training#4
Conversation
Made-with: Cursor
…tions + PipelineConfig) - sgl_d_dtype_patch: monkey-patch DenoisingStage so target_dtype follows pipeline_config.dit_precision instead of hardcoded bf16. Without this, fp16-trained models (e.g. SD3) get a systematic logprob mismatch vs the trainer's fp16 FSDP forward, blowing up approx_kl / clipfrac. - _compute_server_args: build PipelineConfig explicitly via from_kwargs and forward all --sglang-* PipelineConfig flags that differ from the base default. The previous loop dropped them on the floor, so --sglang-dit- precision fp16 never reached sglang. - _launch_server_target / scheduler wrapper: install monkey_patch_torch_ reductions in both the launch_server process and the scheduler grandchild. sgl-d's multimodal_gen weights_updater path (unlike srt) doesn't call it, so the receiver hits AttributeError(_rebuild_cuda_tensor_original) on the first cuda-IPC bucket. - Generalize _scheduler_process_with_qwen_image_patch into _scheduler_process_with_miles_patches: dtype + reductions patches always apply, qwen-image parity patch is opt-in via --apply-qwen-image-sgl-d-patch. Co-authored-by: Cursor <cursoragent@cursor.com>
…pt --colocate - Revert b4509b6's CPU pickle+base64 fallback in DiffusionUpdateWeightFromTensor: serializing every bucket to CPU and shipping ~5 GiB over HTTP took ~238 s per update_weights, dominating step time. Use MultiprocessingSerializer (cuda IPC, zero-copy) again. - The reason b4509b6 went to CPU was "Invalid device_uuid" from sglang. Real fix is --colocate: actor and sglang must share the same Ray placement-group bundle so they see the same CUDA_VISIBLE_DEVICES, which is exactly what the qwen-image script does. Without it Ray hands the two ray actors disjoint visible devices and CUDA IPC can't map the sender's GPU UUID. The companion piece (monkey_patch_torch_reductions on the receiver) lives in the previous commit. - scripts/run-diffusion-grpo-sd3-ocr-sglang.sh: * add --colocate * add --update-weight-buffer-size 2147483648 (2 GiB) so the ~5 GiB DiT fits in 3 buckets instead of ~20 * keep --diffusion-gradient-accumulation-steps 64 * NUM_ROLLOUT env var + MILES_DEBUG_ALIGNMENT=1 toggle for debug paths - One-shot LoRA-sync diagnostic log (only printed for weight_version<=2): sent / merged_lora / unmatched_base_layer counts. Cost is negligible and this is what surfaces LoRA-prefix bugs in the future. Net effect on the 20260508 run: update_weights_time 238s → 2.9s (~82×), overall step_time roughly halved. Co-authored-by: Cursor <cursoragent@cursor.com>
…ruction Reconstructing sigmas from rollout timesteps via timesteps/num_train_timesteps is brittle: for flow-match schedulers with use_dynamic_shifting=True the rollout's sigmas are post-shift, and small drift between rollout and trainer sigmas shows up as SDE logprob mismatch (approx_kl / ratio_abs_minus_1 inflation). Prefer the sigmas snapshot the rollout actually used (carried on DitTrajectory.sigmas), and fall back to the previous reconstruction only when it isn't present. Co-authored-by: Cursor <cursoragent@cursor.com>
…h, and use target branch seed logic - Remove all defensive getattr(args, ...) patterns in actor.py (args fields are registered) - Delete --qkv-format and --true-on-policy-mode CLI parameters from arguments.py - Replace _train_microgroup_seed with inline seed logic from diffusion_RL_v0.1 - Remove sgl_d_dtype_patch.py (dit_precision fix now upstream in sglang)
- sd3_pipeline_with_logprob: cast inputs to model dtype (fp16) before
transformer forward to match actor recompute precision
- router: migrate deprecated on_event('startup') to FastAPI lifespan
- arguments: add --debug-disable-weight-sync flag; derive
num_steps_per_rollout from global_batch_size when not set
- diffusion_rollout_response: simplify log_prob deserialization to
pass full dict to deserialize_func (fixes TypeError with sglang)
- sglang script: fix global-batch-size 128->64, add HF_TOKEN export,
add --sglang-pipeline-class-name StableDiffusion3Pipeline
- actor.py: trim verbose scheduler sigmas comment - diffusion_update_weight_utils.py: trim verbose CUDA IPC comment - placement_group.py: fix RolloutManager GPU allocation to support both sglang (no GPU) and local diffusion rollout (needs GPU from PG) - arguments.py: remove unused --debug-check-update-direction param - delete scripts/run-diffusion-grpo-sd3-ocr-debug.sh (no longer needed)
- Remove --diffusion-train (unused) - Remove --diffusion-timestep-batch (alias for --micro-batch-size-tstep) - Remove --diffusion-dtype (alias for --diffusion-forward-dtype) - Remove --diffusion-gradient-accumulation-steps (dead code, never read) - Remove --debug-disable-weight-sync and related actor.py logic - Simplify noise_level getattr in diffusion_rollout.py (remove non-existent param fallback)
SD3.5 is a gated model; sglang needs HF_TOKEN to fetch model_index.json for auto-detection. With the token set, explicit --pipeline-class-name is unnecessary — sglang resolves StableDiffusion3Pipeline from model_index.json.
| latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents | ||
| # broadcast to batch dimension in a way that's compatible with ONNX/Core ML | ||
| timestep = t.expand(latent_model_input.shape[0]) | ||
| # Cast to transformer dtype (fp16) for forward pass. |
There was a problem hiding this comment.
let's keep flowGRPO as an intact reference and revert these codes
| device: torch.device, | ||
| ) -> dict: | ||
| out = {} | ||
| for key in per_sample_cond_kwargs[0]: |
There was a problem hiding this comment.
Just to double check, we don't need padding here because sd3.5 pad all cond embeddings into the same length right?
There was a problem hiding this comment.
Confirmed, no padding needed for SD3.5.
| out[k] = v | ||
| return out | ||
|
|
||
| def concat_cfg_cond_batches(self, neg_cond_kwargs: dict, pos_cond_kwargs: dict) -> dict: |
There was a problem hiding this comment.
seems like a dead code block here
| with self._get_init_weight_context_manager(): | ||
| pipeline = DiffusionPipeline.from_pretrained( | ||
| self.args.hf_checkpoint, | ||
| diffusion_model_id, |
There was a problem hiding this comment.
maybe we should put hf_checkpoint here since this is for model loading
There was a problem hiding this comment.
and diffusion_model is the "model backbone name"
There was a problem hiding this comment.
Line 85 shows that diffusion_model actually is used for specifying model pipeline config
There was a problem hiding this comment.
Fixed — from_pretrained uses hf_checkpoint; diffusion_model only for get_train_pipeline_config().
| @@ -0,0 +1,30 @@ | |||
| #!/usr/bin/env python3 | |||
| from __future__ import annotations | |||
There was a problem hiding this comment.
This tool is depricated now, let's upload converted dataset directly to https://huggingface.co/datasets/rockdu/miles-diffusion-datasets
There was a problem hiding this comment.
Removed. SD3 script now uses hf download rockdu/miles-diffusion-datasets.
| logger.info("RolloutManager: starting router...") | ||
| _start_router(args) | ||
| logger.info("RolloutManager: router started, init tracking...") | ||
| self._uses_local_diffusion_rollout = ( |
There was a problem hiding this comment.
remove diffusers_rollout logics here as well
There was a problem hiding this comment.
Fixed — always starts sglang router now.
| self.aborted = False | ||
|
|
||
| def submit_generate_tasks(self, samples: list[list[Sample]]) -> None: | ||
| def submit_generate_tasks(self, samples: list[list[Sample]], rollout_id: int) -> None: |
| # get samples from the buffer and submit the generation requests. | ||
| samples = data_source(args.over_sampling_batch_size) | ||
| state.submit_generate_tasks(samples) | ||
| state.submit_generate_tasks(samples, rollout_id) |
There was a problem hiding this comment.
hmmm seems like a dead param here, is passing rollout_id a design for future?
There was a problem hiding this comment.
No, this parameter was used to randomize the seed, but the related logic has been removed, so this parameter can also be removed.
|
|
||
| self.app = FastAPI() | ||
| self.app.add_event_handler("startup", self._start_background_health_check) | ||
| @asynccontextmanager |
There was a problem hiding this comment.
Hi @zhihengy, seems like @niehen6174 also encountered fastAPI issue here, I remember some problems with fast api version, please check if this solves the FastAPI version issue, thanks!
There was a problem hiding this comment.
Migrated to FastAPI lifespan API (≥0.109). No downgrade needed.
| "the trainer disables the LoRA adapter to compute the base-model reference." | ||
| ), | ||
| ) | ||
| parser.add_argument( |
There was a problem hiding this comment.
This arg should be depricated, we have --num-steps-per-rollout and --global-batch-size for this
|
Nice PR! Just some comments in details. |
Fix num_train_timesteps scope bug when rollout sigmas snapshot is present, use hf_checkpoint for model loading, remove dead code/unused CLI args, and correct CUDA IPC comment accuracy. Co-authored-by: Cursor <cursoragent@cursor.com>
Drop diffusion_rollout module, disk-based weight sync, and the local SD3 OCR script. RolloutManager always uses sglang-d via HTTP with no GPU binding. Co-authored-by: Cursor <cursoragent@cursor.com>
Download OCR prompts from rockdu/miles-diffusion-datasets via hf download, remove the deprecated prepare_pickscore_jsonl tool, and point the SD3 sglang script at the shared dataset layout used by other GRPO scripts. Co-authored-by: Cursor <cursoragent@cursor.com>
Remove the fp16 input cast; miles SD3 training no longer uses this path after dropping local diffusers rollout. Co-authored-by: Cursor <cursoragent@cursor.com>
|
Verified SD3 sglang OCR training locally (~95 rollouts): reward tracks the before reference run at the same stage. |
Motivation
Miles_diffusion currently supports GRPO training for image diffusion models (e.g., Qwen-Image) but lacks support for Stable Diffusion 3.5 (SD3). This PR adds full SD3 GRPO training capability, supporting both local diffusers rollout and sglang-based rollout pipelines.
Modifications
Core: SD3 Model Support
SD3TrainPipelineConfigdefining transformer structure, VAE config, and LoRA target modulesCondKwargswithpooled_projectionsfield required by SD3's triple text encoder architectureCore: Actor Training Logic
_init_lora()/_save_local_rollout_lora()for SD3 LoRA weight initialization and syncShardedGradScalerfor fp16 training: prevents gradient underflow in fp16 policy gradients while keeping found_inf synchronized across FSDP ranks (no-op for bf16/fp32)--diffusion-kl-beta): computes reference model log-prob by disabling LoRA adapter, penalizes drift from base model via mean-squared difference in predicted meansCore: Local Diffusion Rollout
Fixes
sd3_pipeline_with_logprob.pydeserialize_func(sglang returns dict, not{"data": ...})[C, H, W](SD3) in addition to 4D[C, F, H, W](video models)num_steps_per_rollout=2(prevents reward collapse)Infra
num_steps_per_rolloutfromglobal_batch_sizewhen not explicitly seton_event("startup")tolifespancontext manager--diffusion-ignore-last,--diffusion-init-lora-weight,--diffusion-kl-betaCLI parametersExperimental Validation