protomotions.agents.amp.model module
AMP model components including discriminator network.
This module implements the AMP-specific neural networks, particularly the discriminator that distinguishes between agent and reference motion data.
- Key Classes:
Discriminator: Binary classifier for agent vs. reference motions
AMPModel: PPO model extended with a discriminator
- class protomotions.agents.amp.model.Discriminator(*args, **kwargs)
Bases: TensorDictModuleBase
Discriminator network for AMP style rewards.
Binary classifier that distinguishes between agent-generated and reference motion data. Built on the ModuleContainer structure: it simply chains the configured modules together.
- Parameters:
config (DiscriminatorConfig) – Discriminator configuration (extends ModuleContainerConfig).
- models
ModuleContainer list of modules.
- in_keys
Input keys from config.
- out_keys
Output keys from config.
- config: DiscriminatorConfig
- forward(tensordict)
Forward pass through discriminator.
- Parameters:
tensordict (MockTensorDict) – TensorDict containing observations.
- Returns:
TensorDict with discriminator output added.
- Return type:
MockTensorDict
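The chaining behaviour described for the discriminator can be illustrated with a framework-free sketch. It is not the ProtoMotions implementation: `KeyedModule`, `ChainedModules`, and the key names `obs`, `features`, and `disc_logits` are hypothetical, and a plain dict stands in for the TensorDict.

```python
from typing import Callable, Dict, List, Sequence


class KeyedModule:
    """Hypothetical module: reads in_keys from a dict, writes one out_key."""

    def __init__(self, fn: Callable[..., float], in_keys: Sequence[str], out_key: str):
        self.fn = fn
        self.in_keys = list(in_keys)
        self.out_key = out_key

    def __call__(self, data: Dict[str, float]) -> Dict[str, float]:
        data[self.out_key] = self.fn(*(data[k] for k in self.in_keys))
        return data


class ChainedModules:
    """Hypothetical container: applies each module in order to the same dict."""

    def __init__(self, models: List[KeyedModule]):
        self.models = models
        # Assumed convention: the chain reads the first module's inputs
        # and exposes the last module's output.
        self.in_keys = models[0].in_keys
        self.out_keys = [models[-1].out_key]

    def forward(self, data: Dict[str, float]) -> Dict[str, float]:
        for model in self.models:
            data = model(data)
        return data


# Toy two-stage "discriminator": a feature transform followed by a scalar head.
disc = ChainedModules([
    KeyedModule(lambda x: 2.0 * x, ["obs"], "features"),
    KeyedModule(lambda h: h - 1.0, ["features"], "disc_logits"),
])
out = disc.forward({"obs": 3.0})
```

In this sketch the input dict is returned with the intermediate and final keys added, mirroring the "TensorDict with discriminator output added" behaviour described above.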
- compute_disc_reward(disc_logits, eps=0.0001)
Compute style reward from discriminator logits.
Converts discriminator logits to a reward using a negative log probability. A higher reward means the motion is more similar to the reference data.
- Parameters:
disc_logits (MockTensor) – Discriminator logits.
eps (float) – Small constant for numerical stability.
- Returns:
Style rewards for each sample (higher = more reference-like).
- Return type:
MockTensor
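As a rough illustration of the logit-to-reward conversion, the sketch below assumes the formulation commonly used in AMP implementations, r = -log(max(1 - sigmoid(logit), eps)). This section does not state the exact expression, so treat the formula as an assumption. The sketch uses plain Python floats rather than tensors, and `disc_reward_sketch` is a hypothetical name.

```python
import math
from typing import List


def disc_reward_sketch(disc_logits: List[float], eps: float = 1e-4) -> List[float]:
    """Assumed AMP-style reward: -log(max(1 - sigmoid(logit), eps))."""
    rewards = []
    for logit in disc_logits:
        prob = 1.0 / (1.0 + math.exp(-logit))  # probability of "reference"
        # eps keeps the log argument away from zero, capping the reward at -log(eps).
        rewards.append(-math.log(max(1.0 - prob, eps)))
    return rewards


rewards = disc_reward_sketch([-2.0, 0.0, 2.0, 50.0])
```

Under this assumed form, the reward grows monotonically with the logit and is capped at -log(eps), roughly 9.21 for the default eps.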
- class protomotions.agents.amp.model.AMPModel(*args, **kwargs)
Bases: PPOModel
AMP model with actor, task critic, disc critic, and discriminator networks.
Extends PPOModel by adding a discriminator network that provides style rewards and a separate critic for estimating discriminator reward values.
- Parameters:
config (AMPModelConfig) – AMPModelConfig specifying all networks.
- _actor
Policy network.
- _critic
Task value network.
- _disc_critic
Discriminator reward value network.
- _discriminator
Style discriminator network.
- config: AMPModelConfig
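The four-network layout can be pictured with the following structural sketch. It is an illustration only: `AMPModelSketch`, its `forward` method, and the output key names are hypothetical, the networks are stand-in callables, and nothing here reflects the actual PPOModel or AMPModel API beyond the attribute names listed above.

```python
from dataclasses import dataclass
from typing import Callable, Dict

Net = Callable[[float], float]  # stand-in for a network; real ones operate on tensors


@dataclass
class AMPModelSketch:
    _actor: Net          # policy network
    _critic: Net         # task value network
    _disc_critic: Net    # value network for discriminator (style) rewards
    _discriminator: Net  # style discriminator

    def forward(self, obs: float) -> Dict[str, float]:
        # Hypothetical output keys: one action plus one value estimate per reward stream.
        return {
            "action": self._actor(obs),
            "value": self._critic(obs),
            "disc_value": self._disc_critic(obs),
        }


model = AMPModelSketch(
    _actor=lambda o: 0.5 * o,
    _critic=lambda o: o + 1.0,
    _disc_critic=lambda o: o - 1.0,
    _discriminator=lambda o: 2.0 * o,
)
out = model.forward(2.0)
disc_logit = model._discriminator(2.0)
```

The sketch keeps one value estimate per reward stream, matching the description of a task critic and a separate discriminator-reward critic. How the two estimates are combined during training is not specified in this section.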