protomotions.agents.utils.replay_buffer module

Replay buffer for off-policy learning.

This module provides a circular replay buffer used in AMP and ASE for storing agent transitions. The discriminator trains on batches sampled from this buffer.

Key Classes:
  • ReplayBuffer: Circular buffer with random sampling

class protomotions.agents.utils.replay_buffer.ReplayBuffer(buffer_size, device)

Circular replay buffer for storing and sampling transitions.

Stores agent transitions in a circular buffer and provides random sampling for discriminator training in AMP/ASE. Once the buffer is full, new data overwrites the oldest entries.

Parameters:
  • buffer_size – Maximum number of transitions to store.

  • device (torch.device) – PyTorch device on which the buffer's tensors are stored.

_head

Current write position in the buffer.

_is_full

Whether the buffer has wrapped around at least once.
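The interplay between the write position and the wrap-around flag can be illustrated with a minimal sketch. This is not the ProtoMotions implementation: the class name `RingBufferSketch` and its attributes are hypothetical, and a plain Python list stands in for the tensors the real buffer presumably holds.

```python
class RingBufferSketch:
    """Toy circular buffer over a Python list (hypothetical stand-in for a tensor)."""

    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.data = [None] * buffer_size
        self.head = 0         # next write position
        self.is_full = False  # becomes True once the write position wraps

    def store(self, items):
        # Write items one by one, wrapping to slot 0 at the end of the buffer.
        for item in items:
            self.data[self.head] = item
            self.head += 1
            if self.head == self.buffer_size:
                self.head = 0
                self.is_full = True

    def __len__(self):
        # Number of valid entries: everything once wrapped, else slots before head.
        return self.buffer_size if self.is_full else self.head


buf = RingBufferSketch(buffer_size=4)
buf.store([0, 1, 2])
state_before_wrap = (buf.head, buf.is_full, len(buf))
buf.store([3, 4])  # the second item overwrites the oldest entry in slot 0
state_after_wrap = (buf.head, buf.is_full, len(buf), list(buf.data))
```

In this sketch the number of valid entries is the write position until the first wrap and the full capacity afterwards, which is one plausible reading of what `get_buffer_size()` reports.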

Example

>>> import torch
>>> from protomotions.agents.utils.replay_buffer import ReplayBuffer
>>> buffer = ReplayBuffer(buffer_size=10000, device=torch.device("cuda"))
>>> buffer.store({"obs": observations, "actions": actions})
>>> samples = buffer.sample(256)  # Sample 256 transitions

__init__(buffer_size, device)
reset()
get_buffer_size()
store(data_dict)
sample(n)
property device: torch.device

Get the current device.
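
Random sampling from a circular buffer has to avoid slots that have not been written yet. The sketch below shows one way to do that. It is illustrative only: the helper `sample_valid_indices` is hypothetical, and sampling uniformly with replacement is an assumption, since the page does not state how `sample(n)` draws its batch.

```python
import random


def sample_valid_indices(head, is_full, buffer_size, n, rng=random):
    """Draw n indices uniformly, with replacement, from the filled slots."""
    # Before the first wrap only slots [0, head) hold data; afterwards all do.
    valid = buffer_size if is_full else head
    if valid == 0:
        raise ValueError("cannot sample from an empty buffer")
    return [rng.randrange(valid) for _ in range(n)]


rng = random.Random(0)  # seeded generator for reproducibility
partial = sample_valid_indices(head=3, is_full=False, buffer_size=8, n=5, rng=rng)
wrapped = sample_valid_indices(head=3, is_full=True, buffer_size=8, n=5, rng=rng)
```

With a partially filled buffer the indices stay below the write position. Once the buffer has wrapped, every slot is eligible, and the order of entries no longer matters for uniform sampling.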