Source code for protomotions.agents.base_agent.config
# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Configuration classes for base agent.
This module defines the configuration dataclasses used by the base agent and all
derived agents. These configurations specify training parameters, optimization
settings, and evaluation parameters.
Key Classes:
- BaseAgentConfig: Main agent configuration
- BaseModelConfig: Model architecture configuration
- OptimizerConfig: Optimizer parameters
- MaxEpisodeLengthManagerConfig: Episode length curriculum
"""
from typing import Optional, List
from dataclasses import dataclass, field
from protomotions.agents.evaluators.config import EvaluatorConfig

@dataclass
class MaxEpisodeLengthManagerConfig:
    """Configuration for managing the max episode length during training."""

    # Example configuration for an agent that slowly increases the max episode length:
    #   max_episode_length_manager:
    #     start_length: 5
    #     end_length: 300
    #     transition_epochs: 100000
    start_length: int = field(default=5, metadata={"help": "Initial max episode length."})
    end_length: int = field(default=300, metadata={"help": "Final max episode length."})
    transition_epochs: int = field(default=100000, metadata={"help": "Epochs to transition."})

    def current_max_episode_length(self, current_epoch: int) -> int:
        """Return the current max episode length based on linear interpolation.

        Args:
            current_epoch: Current training epoch.

        Returns:
            Interpolated max episode length.
        """
        if self.transition_epochs == 0:
            # No interpolation, return the fixed value
            return self.start_length
        # Linear interpolation between start and end values
        progress = min(current_epoch / self.transition_epochs, 1.0)
        return int(self.start_length + progress * (self.end_length - self.start_length))
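
# Illustrative usage sketch for the curriculum above (hypothetical calls, values chosen
# for illustration): with the defaults, the episode cap grows linearly from 5 to 300
# over 100000 epochs and is clamped afterwards:
#
#     curriculum = MaxEpisodeLengthManagerConfig()
#     curriculum.current_max_episode_length(0)        # -> 5
#     curriculum.current_max_episode_length(50_000)   # -> 152, i.e. int(5 + 0.5 * 295)
#     curriculum.current_max_episode_length(100_000)  # -> 300
#     curriculum.current_max_episode_length(200_000)  # -> 300 (progress clamped to 1.0)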

@dataclass
class OptimizerConfig:
    """Configuration for optimizers."""

    _target_: str = "torch.optim.Adam"
    lr: float = field(default=1e-4, metadata={"help": "Learning rate."})
    weight_decay: float = field(default=0.0, metadata={"help": "L2 weight decay."})
    eps: float = field(default=1e-8, metadata={"help": "Epsilon for numerical stability."})
    betas: tuple = field(default_factory=lambda: (0.9, 0.999), metadata={"help": "Adam betas."})
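
# Sketch of how a config like this is typically consumed (an assumption, not shown in
# this module): the `_target_` field follows the Hydra-style instantiation convention,
# so the training code could build the actual optimizer with something like:
#
#     import hydra
#     opt_cfg = OptimizerConfig(lr=3e-4, weight_decay=1e-5)
#     optimizer = hydra.utils.instantiate(opt_cfg, params=model.parameters())
#
# Here `model`, the learning rate, and the weight decay are placeholder values.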

@dataclass
class BaseModelConfig:
    """Configuration for the base agent model."""

    _target_: str = "protomotions.agents.base_agent.model.BaseModel"
    in_keys: List[str] = field(default_factory=list, metadata={"help": "Input keys."})
    out_keys: List[str] = field(default_factory=list, metadata={"help": "Output keys."})
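
# Hypothetical construction sketch (the key names below are placeholders, not taken
# from this module): `_target_` points at the model class to instantiate, while
# in_keys / out_keys name the observation entries the model consumes and the outputs
# it produces, e.g.:
#
#     model_cfg = BaseModelConfig(in_keys=["self_obs"], out_keys=["action"])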

@dataclass
class BaseAgentConfig:
    """Main configuration class for the base agent."""

    batch_size: int = field(metadata={"help": "Training batch size."})
    training_max_steps: int = field(metadata={"help": "Maximum training steps."})
    _target_: str = "protomotions.agents.base_agent.agent.BaseAgent"

    # Model configuration
    model: BaseModelConfig = field(default_factory=BaseModelConfig, metadata={"help": "Model config."})

    # Base agent hyperparameters
    num_steps: int = field(default=32, metadata={"help": "Environment steps per update."})
    gradient_clip_val: float = field(default=0.0, metadata={"help": "Max gradient norm. 0=disabled."})
    fail_on_bad_grads: bool = field(default=False, metadata={"help": "Fail on NaN/Inf gradients."})
    check_grad_mag: bool = field(default=True, metadata={"help": "Log gradient magnitude."})
    gamma: float = field(default=0.99, metadata={"help": "Discount factor."})

    # Bounds and regularization
    bounds_loss_coef: float = field(
        default=0.0, metadata={"help": "Action bounds loss. 0 for tanh outputs."}
    )  # Default policy uses tanh outputs, so we don't need the bounds loss.

    # Training configuration
    task_reward_w: float = field(default=1.0, metadata={"help": "Task reward weight."})
    num_mini_epochs: int = field(default=1, metadata={"help": "Mini-epochs per update."})
    training_early_termination: Optional[int] = field(
        default=None, metadata={"help": "Stop early at this step. None=disabled."}
    )

    # Checkpoint saving configuration
    save_epoch_checkpoint_every: Optional[int] = field(
        default=1000, metadata={"help": "Save epoch_xxx.ckpt every N epochs."}
    )  # Save epoch_xxx.ckpt every N epochs (None = disabled)
    save_last_checkpoint_every: int = field(
        default=10, metadata={"help": "Save last.ckpt every K epochs."}
    )  # Save/overwrite last.ckpt every K epochs

    # Episode length management
    max_episode_length_manager: Optional[MaxEpisodeLengthManagerConfig] = field(
        default=None, metadata={"help": "Episode length curriculum."}
    )

    # Evaluator configuration
    evaluator: EvaluatorConfig = field(
        default_factory=EvaluatorConfig, metadata={"help": "Evaluation config."}
    )

    # Reward normalization
    normalize_rewards: bool = field(default=True, metadata={"help": "Normalize rewards."})
    normalized_reward_clamp_value: float = field(
        default=5.0, metadata={"help": "Clamp normalized rewards to [-val, val]."}
    )