Skip to content

[Big] Add sharded replay buffer support#1834

Open
QuantuMope wants to merge 1 commit into
pytorchfrom
PR/andrew/sharded-replay-buffer
Open

[Big] Add sharded replay buffer support#1834
QuantuMope wants to merge 1 commit into
pytorchfrom
PR/andrew/sharded-replay-buffer

Conversation

@QuantuMope

Copy link
Copy Markdown
Contributor

Adds sharded replay buffer checkpointing so distributed training can save replay buffer data from every rank instead of only rank 0. Previously, if ReplayBuffer.enable_checkpoint=True, replay buffer checkpoints were written only for rank 0. This change now saves each rank’s local replay buffer as an individual source shard among many other fixes:

  1. Allows for flexible loading. If we have M saved replay buffer shards and N workers for a continuing job, Checkpointer will now redistribute the M saved shards into N roughly equivalent shards. This distribution is deterministic.
  2. Partial episodes are now truncated when saving.
  3. Replay buffers were always saved according to their max-allotted size. Now they are saved according to the size of actual valid entries.

@le-horizon le-horizon left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the ambitious undertaking, Andrew.

A high level remark regarding incomplete episodes: incomplete episode can happen in two ways, 1) no LAST step yet, 2) since replay buffer is a ring buffer, old episodes can be covered by new data, if part of an old episode is overwritten then it has no FIRST step. Currently they are removed from replay. This can be a problem if episodes are very long. Maybe log how many episodes and steps are discarded for the user to see.

A few misses found by codex:

  1. Requested/crash checkpoints still save only rank 0 replay data.:
    policy_trainer.py:305-313 only registers the HTTP checkpoint endpoint on rank 0, so _checkpoint_requested is only set there; policy_trainer.py:765-767 then calls _save_checkpoint() only on that rank. Crash checkpoint paths at policy_trainer.py:380-390 are also gated by self._rank <= 0.

Let's add unittest coverage for these cases as well.

  1. Redistribution preserves stale env_id fields.
    In checkpoint_utils.py:641-665, an episode assigned to target env_id is inserted with add_batch(..., env_ids=torch.tensor([env_id])), but the stored episode’s own env_id tensor is not rewritten. After M-to-N redistribution, sampled experiences can contain source env ids that no longer match the replay row, or are out of range for the resumed worker.

Let's add an integration test that loads and redistributes the existing replay buffer and resumes training?

  1. Prioritized replay priorities are dropped on restore.
    Save stores only flattened episode tensors from replay_buffer._buffer (checkpoint_utils.py:443-446), and load rebuilds via add_batch() (checkpoint_utils.py:662-665), which initializes priorities to default. The old checkpoint path saved segment-tree buffers, so this is a behavioral regression for priority_replay=True.

Let's include this in unittest.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants