protomotions.agents.ase.model module#
- class protomotions.agents.ase.model.ASEDiscriminatorEncoder(*args, **kwargs)[source]#
Bases: Discriminator
Discriminator with MI encoder head for ASE.
Inherits from Discriminator and adds an encoder head for mutual information learning.
- config: ASEDiscriminatorEncoderConfig#
- forward(tensordict)[source]#
Forward pass computing discriminator and MI encoder outputs.
- Parameters:
tensordict (MockTensorDict) – TensorDict containing observations and latents.
- Returns:
TensorDict with disc_logits and mi_enc_output added.
- Return type:
MockTensorDict
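A minimal sketch of the two-headed structure this forward pass implies: a shared trunk feeding a discriminator head (producing disc_logits) and an encoder head (producing mi_enc_output). The dimensions, module layout, and plain-tensor interface are illustrative assumptions; the real module consumes a MockTensorDict and is sized by ASEDiscriminatorEncoderConfig.

```python
import torch
import torch.nn as nn

# Hypothetical dimensions; the real sizes come from ASEDiscriminatorEncoderConfig.
obs_dim, latent_dim, hidden_dim = 105, 64, 512

# Assumed layout: shared trunk, discriminator head, MI encoder head.
trunk = nn.Sequential(nn.Linear(obs_dim, hidden_dim), nn.ReLU())
disc_head = nn.Linear(hidden_dim, 1)
enc_head = nn.Linear(hidden_dim, latent_dim)

obs = torch.randn(8, obs_dim)
features = trunk(obs)
disc_logits = disc_head(features)    # (8, 1): real-vs-fake logits
mi_enc_output = enc_head(features)   # (8, latent_dim): predicted latent for the MI reward
```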
- compute_mi_reward(tensordict, mi_hypersphere_reward_shift)[source]#
Computes the mutual-information-based reward.
- Parameters:
tensordict (MockTensorDict) – TensorDict with mi_enc_output and latents.
mi_hypersphere_reward_shift (bool) – Whether to shift reward to [0, 1].
- Returns:
Mutual Information reward tensor.
- Return type:
torch.Tensor
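A sketch of one plausible implementation, assuming the reward is the negative von Mises-Fisher error (see calc_von_mises_fisher_enc_error below), i.e. the cosine similarity between the normalized encoder prediction and the true unit-norm latent, with the shift flag interpreted per its documented [0, 1] mapping. Plain tensors stand in for the MockTensorDict fields.

```python
import torch
import torch.nn.functional as F

def compute_mi_reward_sketch(mi_enc_output, latents, mi_hypersphere_reward_shift):
    # Reward is the negative vMF error: cosine similarity, in [-1, 1]
    # for unit-norm latents.
    enc_pred = F.normalize(mi_enc_output, dim=-1)
    reward = torch.sum(enc_pred * latents, dim=-1, keepdim=True)
    if mi_hypersphere_reward_shift:
        # Affine map from [-1, 1] onto [0, 1], matching the flag's description.
        reward = 0.5 * (reward + 1.0)
    return reward
```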
- calc_von_mises_fisher_enc_error(enc_pred, latent)[source]#
Calculates the von Mises-Fisher error between the predicted and true latent vectors.
- Parameters:
enc_pred (torch.Tensor) – Predicted encoded latent vector. Shape (batch_size, latent_dim).
latent (torch.Tensor) – True latent vector. Shape (batch_size, latent_dim).
- Returns:
von Mises-Fisher error. Shape (batch_size, 1).
- Return type:
torch.Tensor
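For unit-norm latents, the von Mises-Fisher log-likelihood reduces (up to the concentration constant) to the inner product between the latent and the normalized prediction, so the error is its negation. A minimal sketch under that assumption:

```python
import torch
import torch.nn.functional as F

def calc_von_mises_fisher_enc_error_sketch(enc_pred, latent):
    # Normalize the prediction onto the unit hypersphere, then return the
    # negative inner product with the true latent: low error when aligned.
    enc_pred = F.normalize(enc_pred, dim=-1)
    return -torch.sum(enc_pred * latent, dim=-1, keepdim=True)  # (batch_size, 1)
```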
- all_weights()[source]#
Returns all weights from all sequential modules (trunk + discriminator + encoder).
Uses explicit walking to avoid duplicates in nested structures.
- Returns:
List of all weight parameters.
- Return type:
List[nn.Parameter]
- all_discriminator_weights()[source]#
Returns the weights of the discriminator part only (excludes the encoder head).
Explicitly walks through sequential_models to avoid including the encoder head.
- Returns:
List of discriminator weight parameters.
- Return type:
List[nn.Parameter]
- logit_weights()[source]#
Returns the weights of the final discriminator layer.
- Returns:
List containing the weight parameter of the discriminator’s output layer.
- Return type:
List[nn.Parameter]
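The three accessors above differ only in which modules they walk. A sketch of the pattern with a hypothetical trunk/head layout; the explicit walk mirrors the docstrings' point about avoiding duplicates from nested containers:

```python
import torch.nn as nn

# Hypothetical layout: shared trunk, discriminator head, MI encoder head.
trunk = nn.Sequential(nn.Linear(105, 512), nn.ReLU(), nn.Linear(512, 256), nn.ReLU())
disc_head = nn.Linear(256, 1)
enc_head = nn.Linear(256, 64)

def collect_linear_weights(*modules):
    # Walk each module explicitly so nothing is counted twice via nesting.
    weights = []
    for module in modules:
        for m in module.modules():
            if isinstance(m, nn.Linear):
                weights.append(m.weight)
    return weights

all_w = collect_linear_weights(trunk, disc_head, enc_head)  # all_weights()
disc_w = collect_linear_weights(trunk, disc_head)           # all_discriminator_weights()
logit_w = [disc_head.weight]                                # logit_weights()
```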
- class protomotions.agents.ase.model.ASEModel(*args, **kwargs)[source]#
Bases: AMPModel
ASE model with actor, task critic, disc critic, MI critic, and discriminator.
Extends AMPModel by adding an MI critic for estimating MI reward values.
- Parameters:
config (ASEModelConfig) – Configuration specifying all of the model's networks.
- _actor#
Policy network.
- _critic#
Task value network.
- _disc_critic#
Discriminator reward value network.
- _mi_critic#
MI reward value network.
- _discriminator#
Style discriminator with MI encoder.
- config: ASEModelConfig#
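A sketch of how the listed networks might fit together at inference time, with hypothetical dimensions and a plain-tensor interface. The design point is that each critic estimates the value of one reward stream (task, discriminator, MI); combining their advantages (e.g. as a weighted sum) during policy optimization is an assumption here, not documented behavior.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

obs_dim, latent_dim, act_dim = 105, 64, 31  # illustrative sizes only

actor = nn.Linear(obs_dim + latent_dim, act_dim)   # stands in for _actor
critic = nn.Linear(obs_dim + latent_dim, 1)        # _critic: task reward value
disc_critic = nn.Linear(obs_dim + latent_dim, 1)   # _disc_critic: disc reward value
mi_critic = nn.Linear(obs_dim + latent_dim, 1)     # _mi_critic: MI reward value

obs = torch.randn(8, obs_dim)
z = F.normalize(torch.randn(8, latent_dim), dim=-1)  # latent on the unit hypersphere
x = torch.cat([obs, z], dim=-1)

action_mean = actor(x)
values = {"task": critic(x), "disc": disc_critic(x), "mi": mi_critic(x)}
```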