
feat: support SD3.5 GRPO training #4

Open
niehen6174 wants to merge 22 commits into main from feat/support_sd3

Conversation

@niehen6174
Collaborator

Motivation

Miles_diffusion currently supports GRPO training for image diffusion models (e.g., Qwen-Image) but lacks support for Stable Diffusion 3.5 (SD3). This PR adds full SD3 GRPO training capability, supporting both local diffusers rollout and sglang-based rollout pipelines.

Modifications

Core: SD3 Model Support

  • Add SD3TrainPipelineConfig defining transformer structure, VAE config, and LoRA target modules
  • Extend CondKwargs with pooled_projections field required by SD3's triple text encoder architecture

Core: Actor Training Logic

  • Implement _init_lora() / _save_local_rollout_lora() for SD3 LoRA weight initialization and sync
  • Support two weight-sync paths: cuda-IPC for sglang, file-based for local rollout
  • Use exact sigmas from rollout snapshots (instead of reconstructing from timesteps) to prevent flow-match scheduler sigma drift — critical for log-prob consistency
  • Add ShardedGradScaler for fp16 training: prevents gradient underflow in fp16 policy gradients while keeping found_inf synchronized across FSDP ranks (no-op for bf16/fp32)
  • Implement KL divergence regularization (--diffusion-kl-beta): computes reference model log-prob by disabling LoRA adapter, penalizes drift from base model via mean-squared difference in predicted means

Core: Local Diffusion Rollout

  • Full local SD3 rollout pipeline: load model → encode prompt → denoise with log-prob → decode → compute reward
  • LoRA hot-reload from file for on-policy weight sync
  • Pipeline lifecycle management (offload/onload for memory efficiency)

Fixes

  • dtype mismatch: Cast transformer inputs to model dtype (fp16) before forward in sd3_pipeline_with_logprob.py
  • sglang response deserialization: Fix TypeError by passing dict directly to deserialize_func (sglang returns dict, not {"data": ...})
  • OCR reward: Support 3D tensor [C, H, W] (SD3) in addition to 4D [C, F, H, W] (video models)
  • global-batch-size: Fix from 128 → 64 to ensure num_steps_per_rollout=2 (prevents reward collapse)
  • Placement group: Properly differentiate sglang mode (no GPU/no PG for RolloutManager) vs local rollout mode (GPU + PG binding)

Infra

  • Auto-derive num_steps_per_rollout from global_batch_size when not explicitly set
  • Migrate FastAPI router from deprecated on_event("startup") to lifespan context manager
  • Add --diffusion-ignore-last, --diffusion-init-lora-weight, --diffusion-kl-beta CLI parameters
  • Add SD3 OCR training scripts (sglang + local rollout) and PickScore prompt conversion tool
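A minimal sketch of how num_steps_per_rollout could be derived from global_batch_size. The formula, the function name, and the `rollout_batch_size` / `n_samples_per_prompt` parameters are assumptions for illustration; the PR only states that a global batch size of 64 yields two steps per rollout.

```python
def derive_num_steps_per_rollout(rollout_batch_size, n_samples_per_prompt, global_batch_size):
    """Number of optimizer steps needed to consume one rollout.

    Assumes each rollout yields rollout_batch_size * n_samples_per_prompt
    samples and that every optimizer step consumes global_batch_size of them.
    """
    samples_per_rollout = rollout_batch_size * n_samples_per_prompt
    if samples_per_rollout % global_batch_size != 0:
        raise ValueError(
            f"{samples_per_rollout} samples per rollout is not divisible by "
            f"global_batch_size={global_batch_size}"
        )
    return samples_per_rollout // global_batch_size


# Hypothetical rollout shape producing 128 samples per rollout.
steps_gbs_64 = derive_num_steps_per_rollout(8, 16, global_batch_size=64)
steps_gbs_128 = derive_num_steps_per_rollout(8, 16, global_batch_size=128)
```

Under these assumptions, a global batch size of 128 collapses the rollout into a single optimizer step, while 64 gives the two steps the fix above targets.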

Experimental Validation

  • SGLang rollout: OCR reward rises from 0.38 → 0.70+ over 300+ rollouts (stable)
image
  • Local rollout: Training loop verified with correct log-prob computation and reward signals
image

niehen6174 and others added 18 commits April 29, 2026 03:31
…tions + PipelineConfig)

- sgl_d_dtype_patch: monkey-patch DenoisingStage so target_dtype follows
  pipeline_config.dit_precision instead of hardcoded bf16. Without this,
  fp16-trained models (e.g. SD3) get a systematic logprob mismatch vs the
  trainer's fp16 FSDP forward, blowing up approx_kl / clipfrac.

- _compute_server_args: build PipelineConfig explicitly via from_kwargs and
  forward all --sglang-* PipelineConfig flags that differ from the base
  default. The previous loop dropped them on the floor, so --sglang-dit-
  precision fp16 never reached sglang.

- _launch_server_target / scheduler wrapper: install monkey_patch_torch_
  reductions in both the launch_server process and the scheduler grandchild.
  sgl-d's multimodal_gen weights_updater path (unlike srt) doesn't call it,
  so the receiver hits AttributeError(_rebuild_cuda_tensor_original) on the
  first cuda-IPC bucket.

- Generalize _scheduler_process_with_qwen_image_patch into
  _scheduler_process_with_miles_patches: dtype + reductions patches always
  apply, qwen-image parity patch is opt-in via --apply-qwen-image-sgl-d-patch.

Co-authored-by: Cursor <cursoragent@cursor.com>
…pt --colocate

- Revert b4509b6's CPU pickle+base64 fallback in
  DiffusionUpdateWeightFromTensor: serializing every bucket to CPU and
  shipping ~5 GiB over HTTP took ~238 s per update_weights, dominating
  step time. Use MultiprocessingSerializer (cuda IPC, zero-copy) again.

- The reason b4509b6 went to CPU was "Invalid device_uuid" from sglang.
  Real fix is --colocate: actor and sglang must share the same Ray
  placement-group bundle so they see the same CUDA_VISIBLE_DEVICES, which
  is exactly what the qwen-image script does. Without it Ray hands the two
  ray actors disjoint visible devices and CUDA IPC can't map the sender's
  GPU UUID. The companion piece (monkey_patch_torch_reductions on the
  receiver) lives in the previous commit.

- scripts/run-diffusion-grpo-sd3-ocr-sglang.sh:
  * add --colocate
  * add --update-weight-buffer-size 2147483648 (2 GiB) so the ~5 GiB DiT
    fits in 3 buckets instead of ~20
  * keep --diffusion-gradient-accumulation-steps 64
  * NUM_ROLLOUT env var + MILES_DEBUG_ALIGNMENT=1 toggle for debug paths

- One-shot LoRA-sync diagnostic log (only printed for weight_version<=2):
  sent / merged_lora / unmatched_base_layer counts. Cost is negligible and
  this is what surfaces LoRA-prefix bugs in the future.

Net effect on the 20260508 run: update_weights_time 238s → 2.9s (~82×),
overall step_time roughly halved.
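The bucket arithmetic behind the buffer-size change can be checked with a small sketch. The greedy packing function and the 256 MiB "previous" buffer size are assumptions chosen to reproduce the "~20 buckets" figure; the commit message does not state the old buffer size or the packing strategy.

```python
GIB = 1024**3
MIB = 1024**2


def pack_into_buckets(tensor_sizes, buffer_bytes):
    """Greedily group tensor sizes into buckets that fit within buffer_bytes.

    A tensor larger than the buffer still gets a bucket of its own.
    """
    buckets, current, used = [], [], 0
    for size in tensor_sizes:
        if current and used + size > buffer_bytes:
            buckets.append(current)
            current, used = [], 0
        current.append(size)
        used += size
    if current:
        buckets.append(current)
    return buckets


# Hypothetical ~5 GiB DiT, modelled as 20 tensors of 256 MiB each.
dit_tensor_sizes = [256 * MIB] * 20

buckets_2gib = pack_into_buckets(dit_tensor_sizes, 2 * GIB)      # new buffer size
buckets_256mib = pack_into_buckets(dit_tensor_sizes, 256 * MIB)  # assumed old size

# Speedup reported in the commit message: 238 s down to 2.9 s.
speedup = 238 / 2.9
```

Fewer, larger buckets mean fewer round trips per update_weights call, which is consistent with the direction of the reported improvement, although most of the gain comes from returning to CUDA IPC.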

Co-authored-by: Cursor <cursoragent@cursor.com>
…ruction

Reconstructing sigmas from rollout timesteps via timesteps/num_train_timesteps
is brittle: for flow-match schedulers with use_dynamic_shifting=True the
rollout's sigmas are post-shift, and small drift between rollout and trainer
sigmas shows up as SDE logprob mismatch (approx_kl / ratio_abs_minus_1
inflation).

Prefer the sigmas snapshot the rollout actually used (carried on
DitTrajectory.sigmas), and fall back to the previous reconstruction only
when it isn't present.
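The prefer-snapshot-then-fall-back logic might look roughly like the following. The function name and argument layout are hypothetical; only the `timesteps / num_train_timesteps` reconstruction and the `DitTrajectory.sigmas` snapshot come from the commit message.

```python
from typing import Optional, Sequence


def resolve_sigmas(
    timesteps: Sequence[float],
    num_train_timesteps: int,
    rollout_sigmas: Optional[Sequence[float]] = None,
) -> list:
    """Return the sigmas the trainer should use for log-prob recomputation.

    Prefers the exact sigmas recorded during rollout (e.g. carried on
    DitTrajectory.sigmas) and only reconstructs them from timesteps when
    no snapshot is available.
    """
    if rollout_sigmas is not None:
        # Schedulers often append a trailing terminal sigma, so the snapshot
        # may be longer than the timestep list, but never shorter.
        if len(rollout_sigmas) < len(timesteps):
            raise ValueError("rollout sigmas snapshot is shorter than timesteps")
        return list(rollout_sigmas)
    # Fallback: the previous, more brittle reconstruction.
    return [t / num_train_timesteps for t in timesteps]


# Illustrative values only: timesteps that lost precision relative to the
# sigmas the rollout actually used.
snapshot = [1.0, 0.8731, 0.6512, 0.0]
timesteps = [1000.0, 873.0, 651.0]

exact = resolve_sigmas(timesteps, 1000, rollout_sigmas=snapshot)
reconstructed = resolve_sigmas(timesteps, 1000)
```

In this made-up example the reconstructed values differ from the snapshot in the fourth decimal place, which is the kind of small drift that surfaces as approx_kl inflation.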

Co-authored-by: Cursor <cursoragent@cursor.com>
…h, and use target branch seed logic

- Remove all defensive getattr(args, ...) patterns in actor.py (args fields are registered)
- Delete --qkv-format and --true-on-policy-mode CLI parameters from arguments.py
- Replace _train_microgroup_seed with inline seed logic from diffusion_RL_v0.1
- Remove sgl_d_dtype_patch.py (dit_precision fix now upstream in sglang)
- sd3_pipeline_with_logprob: cast inputs to model dtype (fp16) before
  transformer forward to match actor recompute precision
- router: migrate deprecated on_event('startup') to FastAPI lifespan
- arguments: add --debug-disable-weight-sync flag; derive
  num_steps_per_rollout from global_batch_size when not set
- diffusion_rollout_response: simplify log_prob deserialization to
  pass full dict to deserialize_func (fixes TypeError with sglang)
- sglang script: fix global-batch-size 128->64, add HF_TOKEN export,
  add --sglang-pipeline-class-name StableDiffusion3Pipeline
- actor.py: trim verbose scheduler sigmas comment
- diffusion_update_weight_utils.py: trim verbose CUDA IPC comment
- placement_group.py: fix RolloutManager GPU allocation to support both
  sglang (no GPU) and local diffusion rollout (needs GPU from PG)
- arguments.py: remove unused --debug-check-update-direction param
- delete scripts/run-diffusion-grpo-sd3-ocr-debug.sh (no longer needed)
- Remove --diffusion-train (unused)
- Remove --diffusion-timestep-batch (alias for --micro-batch-size-tstep)
- Remove --diffusion-dtype (alias for --diffusion-forward-dtype)
- Remove --diffusion-gradient-accumulation-steps (dead code, never read)
- Remove --debug-disable-weight-sync and related actor.py logic
- Simplify noise_level getattr in diffusion_rollout.py (remove non-existent param fallback)
SD3.5 is a gated model; sglang needs HF_TOKEN to fetch model_index.json
for auto-detection. With the token set, explicit --pipeline-class-name is
unnecessary — sglang resolves StableDiffusion3Pipeline from model_index.json.
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# Cast to transformer dtype (fp16) for forward pass.
Collaborator

Let's keep flowGRPO as an intact reference and revert these changes.

Collaborator Author

Fixed

device: torch.device,
) -> dict:
out = {}
for key in per_sample_cond_kwargs[0]:
Collaborator

Just to double check: we don't need padding here because SD3.5 pads all cond embeddings to the same length, right?

Collaborator Author

Confirmed, no padding needed for SD3.5.

out[k] = v
return out

def concat_cfg_cond_batches(self, neg_cond_kwargs: dict, pos_cond_kwargs: dict) -> dict:
Collaborator

seems like a dead code block here

Collaborator Author

yes, removed

Comment thread miles/backends/fsdp_utils/actor.py Outdated
with self._get_init_weight_context_manager():
pipeline = DiffusionPipeline.from_pretrained(
self.args.hf_checkpoint,
diffusion_model_id,
Collaborator

maybe we should put hf_checkpoint here since this is for model loading

Collaborator

and diffusion_model is the "model backbone name"

Collaborator

Line 85 shows that diffusion_model actually is used for specifying model pipeline config

Collaborator Author

Fixed — from_pretrained uses hf_checkpoint; diffusion_model only for get_train_pipeline_config().

Comment thread tools/prepare_pickscore_jsonl.py Outdated
@@ -0,0 +1,30 @@
#!/usr/bin/env python3
from __future__ import annotations
Collaborator

This tool is deprecated now; let's upload the converted dataset directly to https://huggingface.co/datasets/rockdu/miles-diffusion-datasets

Collaborator Author

Removed. SD3 script now uses hf download rockdu/miles-diffusion-datasets.

Comment thread miles/ray/rollout.py Outdated
logger.info("RolloutManager: starting router...")
_start_router(args)
logger.info("RolloutManager: router started, init tracking...")
self._uses_local_diffusion_rollout = (
Collaborator

Remove the diffusers_rollout logic here as well.

Collaborator Author

Fixed — always starts sglang router now.

self.aborted = False

- def submit_generate_tasks(self, samples: list[list[Sample]]) -> None:
+ def submit_generate_tasks(self, samples: list[list[Sample]], rollout_id: int) -> None:
Collaborator

not used

Collaborator Author

yes, removed

# get samples from the buffer and submit the generation requests.
samples = data_source(args.over_sampling_batch_size)
- state.submit_generate_tasks(samples)
+ state.submit_generate_tasks(samples, rollout_id)
Collaborator

Hmm, this looks like a dead param. Is passing rollout_id intended for future use?

Collaborator Author

No, this parameter was used to randomize the seed, but the related logic has been removed, so this parameter can also be removed.

Comment thread miles/router/router.py

self.app = FastAPI()
self.app.add_event_handler("startup", self._start_background_health_check)
@asynccontextmanager
Collaborator

Hi @zhihengy, it seems @niehen6174 also hit a FastAPI issue here. I remember some problems with the FastAPI version; please check whether this change resolves them, thanks!

Collaborator Author

Migrated to FastAPI lifespan API (≥0.109). No downgrade needed.
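For reference, the lifespan pattern can be sketched with the standard library alone. `HealthChecker` and `make_lifespan` are hypothetical stand-ins for the router's background health check, not the code in router.py; with FastAPI installed, the resulting context manager would be passed as `FastAPI(lifespan=...)`.

```python
import asyncio
from contextlib import asynccontextmanager, suppress


class HealthChecker:
    """Hypothetical background job standing in for the router's health check."""

    def __init__(self, interval_s=0.01):
        self.interval_s = interval_s
        self.checks = 0

    async def run(self):
        while True:
            self.checks += 1
            await asyncio.sleep(self.interval_s)


def make_lifespan(checker):
    @asynccontextmanager
    async def lifespan(app):
        # Startup: replaces the deprecated "startup" event handler.
        task = asyncio.create_task(checker.run())
        try:
            yield
        finally:
            # Shutdown: stop the background task cleanly.
            task.cancel()
            with suppress(asyncio.CancelledError):
                await task

    return lifespan


# Wiring with FastAPI (not executed here):
#   app = FastAPI(lifespan=make_lifespan(checker))


async def demo():
    checker = HealthChecker()
    lifespan = make_lifespan(checker)
    # Drive the lifespan manually; a server would do this around its run loop.
    async with lifespan(app=None):
        await asyncio.sleep(0.05)
    return checker.checks


checks = asyncio.run(demo())
```

Compared with a startup handler, the context manager keeps startup and shutdown in one place, so the background task is always cancelled when the app stops.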

Comment thread miles/utils/arguments.py Outdated
"the trainer disables the LoRA adapter to compute the base-model reference."
),
)
parser.add_argument(
Collaborator

This arg should be deprecated; we have --num-steps-per-rollout and --global-batch-size for this.

Collaborator Author

ok, removed

@Rockdu
Collaborator

Rockdu commented May 22, 2026

Nice PR! Just some comments on the details.

niehen6174 and others added 4 commits May 23, 2026 02:46
Fix num_train_timesteps scope bug when rollout sigmas snapshot is present,
use hf_checkpoint for model loading, remove dead code/unused CLI args, and
correct CUDA IPC comment accuracy.

Co-authored-by: Cursor <cursoragent@cursor.com>
Drop diffusion_rollout module, disk-based weight sync, and the local SD3
OCR script. RolloutManager always uses sglang-d via HTTP with no GPU binding.

Co-authored-by: Cursor <cursoragent@cursor.com>
Download OCR prompts from rockdu/miles-diffusion-datasets via hf download,
remove the deprecated prepare_pickscore_jsonl tool, and point the SD3 sglang
script at the shared dataset layout used by other GRPO scripts.

Co-authored-by: Cursor <cursoragent@cursor.com>
Remove the fp16 input cast; miles SD3 training no longer uses this path
after dropping local diffusers rollout.

Co-authored-by: Cursor <cursoragent@cursor.com>
@niehen6174
Collaborator Author

Verified SD3 sglang OCR training locally (~95 rollouts): reward tracks the earlier reference run at the same stage.
