protomotions.agents.ase.model module

class protomotions.agents.ase.model.ASEDiscriminatorEncoder(*args, **kwargs)

Bases: Discriminator

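Discriminator with a mutual-information (MI) encoder head for ASE (Adversarial Skill Embeddings).

Inherits from Discriminator and adds an encoder head that predicts the latent vector, which is used for mutual-information learning.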

config: ASEDiscriminatorEncoderConfig
__init__(config)
forward(tensordict)

Forward pass computing discriminator and MI encoder outputs.

Parameters:

tensordict (MockTensorDict) – TensorDict containing observations and latents.

Returns:

TensorDict with disc_logits and mi_enc_output added.

Return type:

MockTensorDict
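The forward pass described above can be sketched as a shared trunk feeding two heads. This is a minimal illustration, not the ProtoMotions implementation: the class `ToyDiscriminatorEncoder`, the input key `"obs"`, the layer sizes, and the use of a plain `dict` in place of a TensorDict are all assumptions. Normalizing the encoder output onto the unit hypersphere is also an assumption, chosen to match the hypersphere-based MI reward.

```python
import torch
from torch import nn
import torch.nn.functional as F


class ToyDiscriminatorEncoder(nn.Module):
    """Hypothetical stand-in: shared trunk, a 1-unit discriminator head, and an MI encoder head."""

    def __init__(self, obs_dim: int, latent_dim: int, hidden: int = 64):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.disc_head = nn.Linear(hidden, 1)          # real-vs-fake logit
        self.enc_head = nn.Linear(hidden, latent_dim)  # predicted latent

    def forward(self, td: dict) -> dict:
        features = self.trunk(td["obs"])  # "obs" key is an assumption
        td["disc_logits"] = self.disc_head(features)
        # Assumed: project the encoder prediction onto the unit hypersphere.
        td["mi_enc_output"] = F.normalize(self.enc_head(features), dim=-1)
        return td


# Usage: both outputs are added to the same dictionary.
model = ToyDiscriminatorEncoder(obs_dim=10, latent_dim=4)
out = model({"obs": torch.randn(8, 10)})
```

Sharing the trunk means the discriminator and the MI encoder learn from the same features, which is why the weight accessors below distinguish trunk weights from head weights.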

compute_mi_reward(tensordict, mi_hypersphere_reward_shift)

Computes the mutual-information (MI) reward.

Parameters:
  • tensordict (MockTensorDict) – TensorDict with mi_enc_output and latents.

  • mi_hypersphere_reward_shift (bool) – Whether to shift reward to [0, 1].

Returns:

Mutual Information reward tensor.

Return type:

torch.Tensor
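A minimal sketch of how such a reward might be computed, assuming that both the encoder prediction and the latents are unit vectors, so that their dot product lies in [-1, 1]. The function name `mi_reward_sketch` and the unshifted branch (clamping negative agreement to zero) are assumptions for illustration. The shifted branch implements the documented mapping to [0, 1].

```python
import torch


def mi_reward_sketch(enc_pred: torch.Tensor, latents: torch.Tensor, hypersphere_shift: bool) -> torch.Tensor:
    """Hypothetical MI reward: agreement between predicted and true unit latents."""
    # Row-wise dot product, in [-1, 1] for unit vectors; shape (batch_size, 1).
    agreement = torch.sum(enc_pred * latents, dim=-1, keepdim=True)
    if hypersphere_shift:
        # Affinely map [-1, 1] onto [0, 1].
        return (agreement + 1.0) / 2.0
    # Assumed unshifted variant: keep only non-negative agreement.
    return torch.clamp(agreement, min=0.0)


# Usage: row 1 predicts its latent exactly; row 2 is orthogonal to its latent.
z = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
pred = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
shifted = mi_reward_sketch(pred, z, hypersphere_shift=True)    # [[1.0], [0.5]]
clamped = mi_reward_sketch(pred, z, hypersphere_shift=False)   # [[1.0], [0.0]]
```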

calc_von_mises_fisher_enc_error(enc_pred, latent)

Calculates the von Mises-Fisher error between the predicted and true latent vectors.

Parameters:
  • enc_pred (torch.Tensor) – Predicted encoded latent vector. Shape (batch_size, latent_dim).

  • latent (torch.Tensor) – True latent vector. Shape (batch_size, latent_dim).

Returns:

von Mises-Fisher error. Shape (batch_size, 1).

Return type:

torch.Tensor
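For intuition: the von Mises-Fisher density on the unit hypersphere is proportional to exp(kappa * mu^T z). Its negative log-likelihood, with the normalizing constant dropped, is therefore -kappa * mu^T z. The sketch below computes that quantity. The concentration parameter `kappa` and the function name are assumptions. This illustrates the general technique and is not necessarily the exact formula used by this method.

```python
import torch
import torch.nn.functional as F


def von_mises_fisher_error(enc_pred: torch.Tensor, latent: torch.Tensor, kappa: float = 1.0) -> torch.Tensor:
    """Negative vMF log-likelihood of `latent` under mean direction `enc_pred`, up to a constant.

    log p(z | mu, kappa) = kappa * mu^T z + log C(kappa); the log C(kappa) term is omitted.
    `kappa` is a hypothetical parameter.
    """
    # keepdim=True gives the documented (batch_size, 1) shape.
    return -kappa * torch.sum(enc_pred * latent, dim=-1, keepdim=True)


# Usage: a perfect prediction gives -kappa, and the opposite direction gives +kappa.
latent = F.normalize(torch.randn(5, 3), dim=-1)
err_same = von_mises_fisher_error(latent, latent)
err_flip = von_mises_fisher_error(-latent, latent)
```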

all_weights()

Returns all weights from all sequential modules (trunk, discriminator head, and encoder head).

Walks the modules explicitly so that parameters reachable through nested structures are not listed twice.

Returns:

List of all weight parameters.

Return type:

List[nn.Parameter]

all_discriminator_weights()

Returns the weights of the discriminator part only, excluding the encoder head.

Walks sequential_models explicitly so that the encoder head's weights are not included.

Returns:

List of discriminator weight parameters.

Return type:

List[nn.Parameter]
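The difference between all_weights() and all_discriminator_weights() comes down to which modules are walked. The sketch below gathers each `nn.Linear` weight once, tracking object identity to skip duplicates, and applies the helper to different module subsets. The helper `collect_linear_weights` and the trunk/head layout are hypothetical.

```python
import torch
from torch import nn


def collect_linear_weights(*modules: nn.Module) -> list:
    """Hypothetical helper: return each nn.Linear weight once, in traversal order."""
    seen, weights = set(), []
    for module in modules:
        for sub in module.modules():  # recursive walk, including nested Sequentials
            if isinstance(sub, nn.Linear) and id(sub.weight) not in seen:
                seen.add(id(sub.weight))
                weights.append(sub.weight)
    return weights


trunk = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Sequential(nn.Linear(32, 32), nn.ReLU()))
disc_head = nn.Linear(32, 1)
enc_head = nn.Linear(32, 4)

all_w = collect_linear_weights(trunk, disc_head, enc_head)  # trunk + both heads: 4 weights
disc_w = collect_linear_weights(trunk, disc_head)           # encoder head excluded: 3 weights
shared = collect_linear_weights(trunk, trunk)               # duplicates skipped: 2 weights
```

Biases are left out on purpose, because the accessors are documented as returning weight parameters.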

logit_weights()

Returns the weights of the final discriminator layer.

Returns:

List containing the weight parameter of the discriminator’s output layer.

Return type:

List[nn.Parameter]
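One plausible use of the output-layer weights is an L2 penalty in the style of the logit regularization used in AMP-style discriminator training. This usage is an assumption and is not stated on this page. The helper name is hypothetical, and the coefficient that would scale the penalty before it is added to the loss is omitted.

```python
import torch
from torch import nn


def logit_weight_penalty(weights: list) -> torch.Tensor:
    """Hypothetical L2 penalty: sum of squared entries over a list of weight tensors."""
    return torch.stack([w.square().sum() for w in weights]).sum()


# Usage with a 4-input, 1-output logit layer whose weights are all 0.5.
disc_head = nn.Linear(4, 1)
with torch.no_grad():
    disc_head.weight.fill_(0.5)
penalty = logit_weight_penalty([disc_head.weight])  # 4 * 0.5**2 = 1.0
penalty.backward()                                  # gradient is 2 * w = 1.0 per entry
```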

all_enc_weights()

Returns all weights on the encoder path: the shared trunk plus the encoder head (the discriminator head is excluded).

Returns:

List of encoder weight parameters.

Return type:

List[nn.Parameter]

enc_weights()

Returns the weights of the final encoder layer only.

Returns:

List containing the weight parameter of the encoder’s output layer.

Return type:

List[nn.Parameter]
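all_enc_weights() includes the trunk because the trunk is shared: an encoder-only loss backpropagates through it. The sketch below uses a hypothetical trunk/head layout to build analogues of the two lists. It then checks that an encoder loss (negative dot product with unit latents, assumed here) produces gradients on every weight in the encoder path and leaves the discriminator head untouched.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical layout: one trunk shared by both heads.
trunk = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU())
disc_head = nn.Linear(32, 1)
enc_head = nn.Linear(32, 4)

# Analogue of all_enc_weights(): trunk weights followed by the encoder head weight.
all_enc = [m.weight for m in trunk if isinstance(m, nn.Linear)] + [enc_head.weight]
# Analogue of enc_weights(): the encoder output layer only.
enc_last = [enc_head.weight]

# Assumed encoder-only loss: negative agreement with unit latents.
obs = torch.randn(16, 10)
latents = F.normalize(torch.randn(16, 4), dim=-1)
enc_pred = F.normalize(enc_head(trunk(obs)), dim=-1)
enc_loss = -torch.sum(enc_pred * latents, dim=-1).mean()
enc_loss.backward()

# Gradients reach every weight on the encoder path but not the discriminator head.
reached = [w.grad is not None for w in all_enc]
disc_untouched = disc_head.weight.grad is None
```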

class protomotions.agents.ase.model.ASEModel(*args, **kwargs)

Bases: AMPModel

ASE model with an actor, a task critic, a discriminator-reward critic, an MI critic, and a discriminator.

Extends AMPModel with an MI critic that estimates the value of the MI reward.

Parameters:

config (ASEModelConfig) – ASEModelConfig specifying all networks.

_actor

Policy network.

_critic

Task value network.

_disc_critic

Discriminator reward value network.

_mi_critic

MI reward value network.

_discriminator

Style discriminator with MI encoder.

config: ASEModelConfig
__init__(config)
forward(tensordict)

Forward pass through the AMP model and the MI critic.

Parameters:

tensordict (MockTensorDict) – TensorDict containing observations.

Returns:

TensorDict with all model outputs added.

Return type:

MockTensorDict
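The "extend the parent, then add the MI critic" structure can be sketched with toy modules. Everything here is hypothetical and simplified: the class names, the single `"obs"` input, the output keys (`"action"`, `"value"`, `"disc_value"`, `"mi_value"`), and the use of single linear layers. For brevity, the sketch omits the discriminator and the latent conditioning that an ASE policy would normally receive.

```python
import torch
from torch import nn


class ToyAMPModel(nn.Module):
    """Hypothetical stand-in for AMPModel: actor, task critic, and discriminator-reward critic."""

    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__()
        self._actor = nn.Linear(obs_dim, act_dim)
        self._critic = nn.Linear(obs_dim, 1)
        self._disc_critic = nn.Linear(obs_dim, 1)

    def forward(self, td: dict) -> dict:
        obs = td["obs"]
        td["action"] = self._actor(obs)
        td["value"] = self._critic(obs)
        td["disc_value"] = self._disc_critic(obs)
        return td


class ToyASEModel(ToyAMPModel):
    """Adds an MI critic whose output is written alongside the parent's outputs."""

    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__(obs_dim, act_dim)
        self._mi_critic = nn.Linear(obs_dim, 1)

    def forward(self, td: dict) -> dict:
        td = super().forward(td)  # actor, task critic, disc critic
        td["mi_value"] = self._mi_critic(td["obs"])
        return td


model = ToyASEModel(obs_dim=6, act_dim=3)
out = model({"obs": torch.randn(4, 6)})
```

Calling `super().forward` first keeps the AMP outputs unchanged, so the ASE model only needs to add the one extra value estimate.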