protomotions.agents.ase.agent module#
Adversarial Skill Embeddings (ASE) agent implementation.
This module implements the ASE algorithm which extends AMP with learned skill embeddings. The discriminator encodes motions into a latent skill space, and the policy is conditioned on these latent codes. This enables learning diverse skills from motion data and composing them for complex tasks.
- Key Classes:
ASE: Main ASE agent class extending AMP
References
Peng et al. “ASE: Large-Scale Reusable Adversarial Skill Embeddings for Physically Simulated Characters” (2022)
- class protomotions.agents.ase.agent.ASE(fabric, env, config, root_dir=None)[source]#
Bases: AMP
Adversarial Skill Embeddings (ASE) agent.
Extends AMP with a low-level policy conditioned on learned skill embeddings. The discriminator learns to encode skills from motion data into a latent space, while the policy learns to execute behaviors conditioned on these latent codes. This enables learning diverse skills from motion data and composing them for tasks.
Key components:
- Low-level policy: Conditioned on latent skill codes
- Discriminator: Encodes motions into skill embeddings
- Mutual information: Encourages skill diversity
- Latent sampling: Periodically samples new skills during rollouts
- Parameters:
fabric (MockFabric) – Lightning Fabric instance for distributed training.
env (BaseEnv) – Environment instance with diverse motion library.
config – ASE-specific configuration including latent dimensions.
root_dir (Path | None) – Optional root directory for saving outputs.
- latents#
Current latent skill codes for each environment.
- latent_reset_steps#
Steps until next latent resample.
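The interplay between these two attributes can be illustrated with a small sketch (hypothetical variable names and step bounds, not the library's actual rollout code): each environment keeps a countdown, and when it expires a fresh unit-norm latent is drawn and the countdown is reset.

```python
import torch

# Minimal sketch: how `latents` and `latent_reset_steps` could drive periodic
# skill resampling during rollouts. `latent_dim`, `min_steps`, and `max_steps`
# stand in for config-driven values.
num_envs, latent_dim = 16, 64
min_steps, max_steps = 150, 300

latents = torch.zeros(num_envs, latent_dim)
latent_reset_steps = torch.zeros(num_envs, dtype=torch.long)

def maybe_resample(step: int) -> None:
    # Environments whose countdown expired get a fresh latent and a new deadline.
    expired = (latent_reset_steps <= step).nonzero(as_tuple=False).flatten()
    if expired.numel() == 0:
        return
    z = torch.randn(expired.numel(), latent_dim)
    latents[expired] = z / z.norm(dim=-1, keepdim=True)  # project to unit sphere
    latent_reset_steps[expired] = step + torch.randint(
        min_steps, max_steps + 1, (expired.numel(),)
    )
```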
Example
>>> fabric = Fabric(devices=4)
>>> env = Mimic(config, robot_config, simulator_config, device)
>>> agent = ASE(fabric, env, config)
>>> agent.setup()
>>> agent.train()
Note
Requires large diverse motion dataset for effective skill learning.
- discriminator: ASEDiscriminatorEncoder#
- config: ASEAgentConfig#
- __init__(fabric, env, config, root_dir=None)[source]#
Initialize the base agent.
Sets up distributed training infrastructure, initializes tracking metrics, and creates the evaluator. Subclasses should call super().__init__() first.
- Parameters:
fabric (MockFabric) – Lightning Fabric for distributed training and device management.
env (BaseEnv) – Environment instance for agent-environment interaction.
config – Configuration containing hyperparameters and training settings.
root_dir (Path | None) – Optional directory for saving outputs (uses logger dir if None).
- create_optimizers(model)[source]#
Create separate optimizers for actor and critic.
Sets up Adam optimizers for policy and value networks with independent learning rates. Uses Fabric for distributed training setup.
- Parameters:
model (ASEModel) – Model with actor and critic networks.
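As a rough illustration, a minimal version of this setup might look as follows. It assumes the model exposes `actor` and `critic` submodules and uses placeholder learning rates; Fabric's `setup` call is real API, the rest is a sketch rather than the library's implementation.

```python
import torch

# Hedged sketch: separate Adam optimizers for actor and critic, registered
# with Lightning Fabric. `actor_lr` and `critic_lr` are placeholder values.
def create_optimizers_sketch(fabric, model, actor_lr=3e-4, critic_lr=1e-3):
    actor_opt = torch.optim.Adam(model.actor.parameters(), lr=actor_lr)
    critic_opt = torch.optim.Adam(model.critic.parameters(), lr=critic_lr)
    # fabric.setup wraps the module and optimizers for distributed execution.
    model, actor_opt, critic_opt = fabric.setup(model, actor_opt, critic_opt)
    return model, actor_opt, critic_opt
```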
- load_parameters(state_dict)[source]#
Load PPO-specific parameters from checkpoint.
Loads actor, critic, and optimizer states. Preserves config overrides for actor_logstd if specified at command line.
- Parameters:
state_dict – Checkpoint state dictionary containing model and optimizer states.
- get_state_dict(state_dict)[source]#
Get complete state dictionary for checkpointing.
Collects all agent state including model weights, training progress, and normalization statistics into a single dictionary for saving.
- Parameters:
state_dict – Existing state dict to update (typically empty dict).
- Returns:
Updated state dictionary containing all agent state.
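A hedged usage sketch of the checkpoint round-trip (the file name and direct `torch.save`/`torch.load` calls are illustrative; in practice Fabric's checkpointing utilities may be used instead):

```python
import torch

# Illustrative checkpoint round-trip; not the library's training-loop code.
state = agent.get_state_dict({})        # collect model, optimizer, and stats
torch.save(state, "ase_checkpoint.pt")

restored = torch.load("ase_checkpoint.pt")
agent.load_parameters(restored)         # restore actor/critic/optimizer state
```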
- reset_latents(env_ids=None)[source]#
Resets latent variables for specified environments or all environments if None.
- Parameters:
env_ids (torch.Tensor, optional) – Environment indices to reset latents for. Defaults to None (all envs).
- store_latents(latents, env_ids)[source]#
Stores latent variables for specified environments.
- Parameters:
latents (torch.Tensor) – Latent variables to store. Shape (num_envs, latent_dim).
env_ids (torch.Tensor) – Environment indices to store latents for. Shape (num_envs,).
- sample_latents(n)[source]#
Samples new latent variables uniformly on the unit sphere.
- Parameters:
n (int) – Number of latent variables to sample.
- Returns:
Sampled latent variables. Shape (n, latent_dim).
- Return type:
Tensor
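A minimal sketch of this sampling scheme: normalizing standard Gaussian draws yields points uniformly distributed on the unit sphere. `latent_dim` stands in for the config-driven latent dimension.

```python
import torch

# Sketch of uniform sampling on the unit sphere via normalized Gaussians.
def sample_latents_sketch(n: int, latent_dim: int = 64) -> torch.Tensor:
    z = torch.randn(n, latent_dim)
    return z / z.norm(dim=-1, keepdim=True)  # shape (n, latent_dim), unit norm
```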
- mi_enc_forward(obs_dict)[source]#
Forward pass through the Mutual Information encoder.
- Parameters:
obs_dict (dict) – Dictionary containing observations.
- Returns:
Encoded observation tensor. Shape (batch_size, encoder_output_dim).
- Return type:
Tensor
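Following the ASE paper, the encoder output is typically used to form a mutual-information style reward. The sketch below assumes that both encodings and latents are unit-normalized; whether this exact form is used here is an assumption.

```python
import torch

# Sketch of an ASE-style mutual-information reward: with unit-norm encodings
# e(s) and latents z, the reward is the (clamped) dot product between them.
def mi_reward_sketch(encodings: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
    # encodings, latents: (batch, latent_dim), both assumed unit-normalized
    return torch.clamp_min((encodings * latents).sum(dim=-1), 0.0)
```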
- register_algorithm_experience_buffer_keys()[source]#
Register algorithm-specific keys in the experience buffer.
Subclasses override this to add custom keys to the experience buffer (e.g., AMP adds discriminator observations, ASE adds latent codes).
- add_agent_info_to_obs(obs)[source]#
Inject the agent's current latent skill codes into the observations.
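A minimal sketch of the assumed behavior (the observation key name is hypothetical):

```python
import torch

# Sketch: copy the agent's current latent codes into the observation
# dictionary so the policy can condition on them.
def add_latents_to_obs_sketch(obs: dict, latents: torch.Tensor) -> dict:
    obs = dict(obs)            # avoid mutating the caller's dict
    obs["latents"] = latents   # key name is an assumption for illustration
    return obs
```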
- perform_optimization_step(batch_dict, batch_idx)[source]#
Perform one PPO optimization step on a minibatch.
Computes actor and critic losses, performs backpropagation, clips gradients, and updates both networks.
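A generic clipped-PPO minibatch update, shown only to illustrate the documented steps (losses, backpropagation, gradient clipping, optimizer updates). The `batch` key names and the separate actor/critic optimizers are assumptions, not the library's exact code.

```python
import torch

# Hedged sketch of one clipped-PPO minibatch step.
def ppo_step_sketch(actor, critic, actor_opt, critic_opt, batch, clip=0.2):
    dist = actor(batch["obs"])                      # assumed to return a distribution
    logp = dist.log_prob(batch["actions"]).sum(-1)
    ratio = torch.exp(logp - batch["old_logp"])
    adv = batch["advantages"]
    actor_loss = -torch.min(ratio * adv,
                            torch.clamp(ratio, 1 - clip, 1 + clip) * adv).mean()
    critic_loss = (critic(batch["obs"]).squeeze(-1) - batch["returns"]).pow(2).mean()

    actor_opt.zero_grad()
    actor_loss.backward()
    torch.nn.utils.clip_grad_norm_(actor.parameters(), 1.0)
    actor_opt.step()

    critic_opt.zero_grad()
    critic_loss.backward()
    torch.nn.utils.clip_grad_norm_(critic.parameters(), 1.0)
    critic_opt.step()
    return actor_loss.item(), critic_loss.item()
```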
- get_expert_disc_obs(num_samples)[source]#
Build expert observations from motion library for discriminator training.
Iterates over reference_obs_components defined in AMPAgentConfig and uses them to compute demo observations from sampled motions.
- discriminator_step(batch_dict)[source]#
Performs a discriminator update step.
- Parameters:
batch_dict (dict) – Batch of data from the experience buffer.
- Returns:
Discriminator loss and logging dictionary.
- Return type:
Tuple[Tensor, Dict]
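The adversarial part of this step typically reduces to a binary classification loss between agent and expert samples; the sketch below is a generic AMP-style formulation, and any gradient-penalty or encoder terms used by this implementation are left out as assumptions.

```python
import torch
import torch.nn.functional as F

# Generic AMP-style discriminator loss sketch: expert samples pushed toward 1,
# agent samples toward 0.
def discriminator_loss_sketch(disc, agent_obs, expert_obs):
    agent_logits = disc(agent_obs)
    expert_logits = disc(expert_obs)
    loss = (
        F.binary_cross_entropy_with_logits(agent_logits, torch.zeros_like(agent_logits))
        + F.binary_cross_entropy_with_logits(expert_logits, torch.ones_like(expert_logits))
    )
    return 0.5 * loss
```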
- compute_uniformity_loss(encodings)[source]#
Computes uniformity loss to encourage uniform distribution on unit sphere.
- Parameters:
encodings (Tensor) – Normalized encodings on unit sphere. Shape (batch_size, latent_dim).
- Returns:
Uniformity loss value.
- Return type:
Tensor
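One common formulation is the pairwise Gaussian-potential uniformity objective of Wang & Isola (2020); whether this exact form is the one implemented here is an assumption.

```python
import torch

# Sketch of a uniformity objective on the unit sphere: lower values mean the
# encodings spread more evenly over the sphere.
def uniformity_loss_sketch(encodings: torch.Tensor, t: float = 2.0) -> torch.Tensor:
    # encodings: (batch, latent_dim), assumed unit-normalized
    sq_pdist = torch.pdist(encodings, p=2).pow(2)   # pairwise squared distances
    return sq_pdist.mul(-t).exp().mean().log()
```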
- calculate_extra_actor_loss(batch_td)[source]#
Adds the diversity loss, if enabled.
- Parameters:
batch_td (TensorDict) – Batch of data from the experience buffer and the actor.
- Returns:
Extra actor loss and logging dictionary.
- Return type:
Tuple[Tensor, Dict]
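The diversity term in the ASE paper penalizes policies whose action means change too little when the latent changes; the sketch below follows that formulation, with illustrative variable names rather than the library's exact code.

```python
import torch

# Hedged sketch of an ASE-style diversity loss: action-mean differences should
# scale with how different the (unit-norm) latents are.
def diversity_loss_sketch(mu_old, mu_new, z_old, z_new, target=1.0):
    a_diff = (torch.clamp(mu_old, -1, 1) - torch.clamp(mu_new, -1, 1)).pow(2).mean(dim=-1)
    z_diff = 0.5 - 0.5 * (z_old * z_new).sum(dim=-1)   # in [0, 1] for unit latents
    diversity = a_diff / (z_diff + 1e-5)
    return (target - diversity).pow(2).mean()
```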