Add several fixes to distributed training #1835

Open
QuantuMope wants to merge 2 commits into pytorch from PR/andrew/dist-rl-fixes

@QuantuMope

Contributor

This PR adds several fixes for distributed training. The first five are in the first commit, d281324, and the last is in the second commit, 4688708.

  1. Remove optimizers from the distributed wrappers and instead handle state dicts according to the wrapped core alg. This makes checkpoints agnostic to whether training uses the distributed trainer.
  2. Add the missing summarize_metrics(self) call so that summaries from the algorithm are recorded.
  3. Pass weights_only=False to torch.load for torch.__version__ >= 2.6, where the default for weights_only changed to True.
  4. Increase the unroller acknowledge timeout from 3 to 60 seconds. Three seconds is generally too short for transmitting VLA weights.
  5. Fix a bug in DistributedUnroller triggered by i % 0 (modulo by zero).
  6. Add checkpoint loading support for DistributedTrainer.
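As an illustration of item 3, here is a minimal sketch of gating the extra torch.load argument on the installed version. The helper names parse_major_minor and torch_load_kwargs are hypothetical and not taken from the PR; the sketch compares parsed integers because a plain string comparison would order "2.10" before "2.6".

```python
import re
from typing import Dict, Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from a version string such as '2.6.0+cu124'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def torch_load_kwargs(torch_version: str) -> Dict[str, bool]:
    """Hypothetical helper: extra keyword arguments for torch.load.

    From 2.6 on, torch.load defaults to weights_only=True, which rejects
    checkpoints containing arbitrary pickled Python objects, so the flag
    is set explicitly only for those versions.
    """
    if parse_major_minor(torch_version) >= (2, 6):
        return {"weights_only": False}
    return {}


# Assumed usage (requires torch, so it is left as a comment here):
#   state = torch.load(path, **torch_load_kwargs(torch.__version__))
```

Note that weights_only=False re-enables full unpickling, so this is only appropriate for checkpoints from a trusted source.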

@le-horizon le-horizon left a comment

Contributor

Thanks for the fixes, Andrew.

A couple of points from Codex, if they make sense:

  1. alf/algorithms/distributed_off_policy_algorithm.py:155
    DistributedOffPolicyAlgorithm.state_dict() now returns only self._core_alg.state_dict(). The trainer replay buffer lives on the distributed wrapper, not the core alg, so distributed checkpoints saved after this change will not include _replay_buffer.* keys. That makes the new checkpoint-loading setup at lines 519-569 ineffective for replay data: resume creates an empty multiprocessing replay buffer, then loads only core model/optimizer state. It also cannot load older distributed checkpoints that do contain wrapper _replay_buffer.* keys because load_state_dict() delegates straight to the core alg.

We need to raise an error when resuming from a ckpt encounters an empty replay buffer.
We also need to add unittest coverage for correct ckpt saving and reloading.
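One way to address this point could look like the sketch below. It is not ALF's implementation: CoreAlg and DistributedWrapper are hypothetical stand-ins built on plain dicts, and the require_replay flag is an assumption. The wrapper saves its replay data under the _replay_buffer. prefix next to the core alg's keys, and on load it raises if the resumed replay buffer ends up empty.

```python
REPLAY_PREFIX = "_replay_buffer."


class CoreAlg:
    """Hypothetical stand-in for the wrapped core algorithm."""

    def __init__(self):
        self.weights = {"policy.w": 0.0}

    def state_dict(self):
        return dict(self.weights)

    def load_state_dict(self, state):
        self.weights = dict(state)


class DistributedWrapper:
    """Hypothetical wrapper that owns the trainer-side replay buffer."""

    def __init__(self, core_alg):
        self._core_alg = core_alg
        self._replay_buffer = {}  # maps field name -> stored data

    def state_dict(self):
        # Core keys stay unprefixed; replay data is added under its own prefix.
        state = self._core_alg.state_dict()
        for name, value in self._replay_buffer.items():
            state[REPLAY_PREFIX + name] = value
        return state

    def load_state_dict(self, state, require_replay=True):
        # Split wrapper-owned replay keys from the core alg's keys.
        replay = {
            key[len(REPLAY_PREFIX):]: value
            for key, value in state.items()
            if key.startswith(REPLAY_PREFIX)
        }
        core = {
            key: value
            for key, value in state.items()
            if not key.startswith(REPLAY_PREFIX)
        }
        if require_replay and not replay:
            raise RuntimeError(
                "Checkpoint has no replay buffer data; refusing to resume "
                "with an empty replay buffer.")
        self._core_alg.load_state_dict(core)
        self._replay_buffer = replay
```

A unit test for the round trip would save from one wrapper, load into a fresh one, and compare both the core state and the replay contents, plus one case asserting that a core-only checkpoint raises.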

  2. alf/algorithms/distributed_off_policy_algorithm.py:211
    summarize_metrics() now only calls self._core_alg.summarize_metrics(). The wrapper owns the env metrics created by RLAlgorithm.__init__, and the unroller rollout updates those wrapper metrics via the inherited observe_for_metrics(). This drops the normal unroller env metric summaries such as episode count, env steps, return, and episode length. It should likely call super().summarize_metrics() as well.
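The suggested fix can be sketched as follows. The classes are hypothetical stand-ins rather than ALF's real RLAlgorithm and core alg; the list named recorded exists only to make the call order visible. The point is that the wrapper's override reports its own env metrics through super() before delegating to the core alg.

```python
class RLAlgorithmBase:
    """Hypothetical stand-in for the base class that owns the env metrics."""

    def __init__(self):
        self.recorded = []

    def summarize_metrics(self):
        # In the real base class this would write episode count, env steps,
        # return, and episode length summaries.
        self.recorded.append("env_metrics")


class CoreAlg:
    """Hypothetical stand-in for the wrapped core algorithm."""

    def __init__(self, log):
        self._log = log

    def summarize_metrics(self):
        self._log.append("core_metrics")


class DistributedWrapper(RLAlgorithmBase):
    def __init__(self):
        super().__init__()
        self._core_alg = CoreAlg(self.recorded)

    def summarize_metrics(self):
        # Keep the wrapper-owned env metrics, then add the core alg's own.
        super().summarize_metrics()
        self._core_alg.summarize_metrics()
```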
