[Big] Add sharded replay buffer support #1834
Open
QuantuMope wants to merge 1 commit into
le-horizon
reviewed
May 29, 2026
Contributor
Thanks for the ambitious undertaking, Andrew.
A high-level remark regarding incomplete episodes. An episode can be incomplete in two ways: 1) it has no LAST step yet, or 2) because the replay buffer is a ring buffer, old episodes can be overwritten by new data, and a partially overwritten old episode has no FIRST step. Currently, incomplete episodes are removed from replay. This can be a problem if episodes are very long. Maybe log how many episodes and steps are discarded so the user can see it.
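To make the logging suggestion concrete, here is a minimal sketch that counts what gets discarded for a single env row. The step-type constants, the `count_discarded` helper, and the flat-list representation are assumptions for illustration only; the actual replay buffer stores batched tensors and its enum values may differ.

```python
import logging

# Hypothetical step-type constants for this sketch only.
FIRST, MID, LAST = 0, 1, 2


def count_discarded(step_types):
    """Scan one env row and count complete vs. discarded episodes.

    An episode is kept only if it starts with FIRST and ends with LAST.
    Returns (kept_episodes, discarded_episodes, discarded_steps).
    """
    kept = dropped_eps = dropped_steps = 0
    length = 0          # number of steps in the fragment being scanned
    has_first = False   # whether the current fragment began with FIRST
    for st in step_types:
        if st == FIRST:
            if length > 0:
                # A new episode started before the previous one saw LAST.
                dropped_eps += 1
                dropped_steps += length
            has_first = True
            length = 1
            continue
        length += 1
        if st == LAST:
            if has_first:
                kept += 1
            else:
                # Head of the episode was overwritten by the ring buffer.
                dropped_eps += 1
                dropped_steps += length
            has_first = False
            length = 0
    if length > 0:
        # Trailing fragment with no LAST step yet.
        dropped_eps += 1
        dropped_steps += length
    return kept, dropped_eps, dropped_steps


# Headless fragment, one complete episode, then an unfinished episode.
row = [MID, LAST, FIRST, MID, LAST, FIRST, MID]
kept, dropped_eps, dropped_steps = count_discarded(row)
logging.warning(
    "Replay checkpoint: kept %d episodes, discarded %d incomplete episodes (%d steps)",
    kept, dropped_eps, dropped_steps)
```

Summing these counts over all env rows and ranks would give the user a single line showing how much data was lost at save time.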
A few misses found by codex:
- Requested/crash checkpoints still save only rank 0 replay data.
policy_trainer.py:305-313 only registers the HTTP checkpoint endpoint on rank 0, so _checkpoint_requested is only set there; policy_trainer.py:765-767 then calls _save_checkpoint() only on that rank. Crash checkpoint paths at policy_trainer.py:380-390 are also gated by self._rank <= 0.
Let's add unittest coverage for these cases as well.
- Redistribution preserves stale env_id fields.
In checkpoint_utils.py:641-665, an episode assigned to target env_id is inserted with add_batch(..., env_ids=torch.tensor([env_id])), but the stored episode’s own env_id tensor is not rewritten. After M-to-N redistribution, sampled experiences can contain source env ids that no longer match the replay row, or are out of range for the resumed worker.
Let's add an integration test that loads and redistributes the existing replay buffer and resumes training?
- Prioritized replay priorities are dropped on restore.
Save stores only flattened episode tensors from replay_buffer._buffer (checkpoint_utils.py:443-446), and load rebuilds via add_batch() (checkpoint_utils.py:662-665), which initializes priorities to default. The old checkpoint path saved segment-tree buffers, so this is a behavioral regression for priority_replay=True.
Let's cover this in a unit test as well.
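The last two items could be pinned down with a small round-trip check along these lines. This is only a sketch under stated assumptions: the `Episode` dataclass and the `place_episode` helper are hypothetical stand-ins using plain lists, whereas the real code works with nested tensors, `add_batch()`, and segment trees.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class Episode:
    """Hypothetical flattened episode, one list entry per step."""
    env_id: List[int]       # env id stored inside each experience
    priority: List[float]   # per-step priority (prioritized replay only)
    reward: List[float]


def place_episode(episode: Episode, target_env_id: int) -> Episode:
    """Return a copy whose stored env_id matches the target replay row.

    Priorities and all other fields are carried over unchanged.
    """
    return replace(episode, env_id=[target_env_id] * len(episode.env_id))


def check_row(row_id: int, episodes: List[Episode], num_envs: int) -> None:
    """Invariant a restored replay row should satisfy after redistribution."""
    assert 0 <= row_id < num_envs
    for ep in episodes:
        assert all(e == row_id for e in ep.env_id), "stale env_id in row"
        assert len(ep.priority) == len(ep.env_id)


# An episode saved from source env 7, restored into row 1 of a 2-env worker.
saved = Episode(env_id=[7, 7, 7], priority=[0.5, 2.0, 1.0], reward=[0.0, 0.0, 1.0])
restored = place_episode(saved, target_env_id=1)
check_row(1, [restored], num_envs=2)
assert restored.priority == saved.priority  # priorities survive the round trip
```

The real tests would assert the same two invariants on the actual buffer after an M-to-N load: every stored env_id equals its row index, and restored priorities equal the saved ones when priority_replay=True.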
Adds sharded replay buffer checkpointing so distributed training can save replay buffer data from every rank instead of only rank 0. Previously, if ReplayBuffer.enable_checkpoint=True, replay buffer checkpoints were written only for rank 0. This change saves each rank's local replay buffer as an individual source shard, among many other fixes. Checkpointer will now redistribute the M saved shards into N roughly equivalent shards. This redistribution is deterministic.
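For intuition, one way a deterministic M-to-N split could work is sketched below. This is not necessarily the algorithm used in this PR: the `redistribute` function, the greedy longest-first strategy, and the representation of shards as lists of episode lengths are all assumptions made for illustration.

```python
def redistribute(shards, num_targets):
    """Assign episodes from M source shards to N target shards.

    `shards` is a list of M lists of episode lengths (in steps).
    Returns (targets, loads), where targets[t] is a list of
    (source_shard, episode_index, length) tuples and loads[t] is the
    total number of steps assigned to target t.
    """
    episodes = [
        (length, src, idx)
        for src, shard in enumerate(shards)
        for idx, length in enumerate(shard)
    ]
    # Longest first; ties broken by (source shard, index) so the order
    # does not depend on anything but the saved data.
    episodes.sort(key=lambda e: (-e[0], e[1], e[2]))

    targets = [[] for _ in range(num_targets)]
    loads = [0] * num_targets
    for length, src, idx in episodes:
        # Pick the least-loaded target; the lowest index wins ties.
        t = min(range(num_targets), key=lambda k: (loads[k], k))
        targets[t].append((src, idx, length))
        loads[t] += length
    return targets, loads


# Three saved shards resumed on two workers.
targets, loads = redistribute([[5, 3], [4], [2, 2]], num_targets=2)
print(loads)
```

Because both the sort key and the tie-breaking rule depend only on the saved data, every rank that runs this on the same shards computes the same assignment without any communication.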