# Source code for protomotions.agents.masked_mimic.config
# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Configuration classes for MaskedMimic agent.
MaskedMimic uses a VAE-based architecture for versatile motion imitation
with masked conditioning and latent space learning.
"""
from typing import Union, Optional
from enum import Enum
from protomotions.agents.common.config import ModuleContainerConfig
from protomotions.agents.base_agent.config import (
OptimizerConfig,
BaseAgentConfig,
BaseModelConfig,
)
from dataclasses import dataclass, field
@dataclass
class KLDScheduleConfig:
    """Schedule for the KL divergence coefficient used in VAE training.

    The coefficient starts at ``init_kld_coeff`` and ends at
    ``end_kld_coeff``; annealing begins at ``start_epoch`` and finishes at
    ``end_epoch``. This class only stores the values; the annealing itself
    is performed by whoever consumes the config.

    Attributes:
        init_kld_coeff: Initial KL divergence coefficient.
        end_kld_coeff: Final KL divergence coefficient.
        start_epoch: Epoch to start KLD coefficient annealing.
        end_epoch: Epoch to end KLD coefficient annealing.
    """

    # Coefficient bounds; "min" in the metadata records the lower limit
    # (presumably enforced by the config tooling -- confirm with the consumer).
    init_kld_coeff: float = field(
        default=0.0001,
        metadata={
            "help": "Initial KL divergence coefficient.",
            "min": 0.0,
        },
    )
    end_kld_coeff: float = field(
        default=0.01,
        metadata={
            "help": "Final KL divergence coefficient.",
            "min": 0.0,
        },
    )

    # Epoch window over which the coefficient is annealed.
    start_epoch: int = field(
        default=3000,
        metadata={
            "help": "Epoch to start KLD coefficient annealing.",
            "min": 0,
        },
    )
    end_epoch: int = field(
        default=6000,
        metadata={
            "help": "Epoch to end KLD coefficient annealing.",
            "min": 0,
        },
    )
class VaeNoiseType(Enum):
    """Type of noise for VAE sampling.

    Each member's value is the lowercase string that identifies it; use
    :meth:`from_str` to look a member up from user-supplied text.
    """

    NORMAL = "normal"
    UNIFORM = "uniform"
    ZEROS = "zeros"

    @classmethod
    def from_str(cls, value: str) -> "VaeNoiseType":
        """Create enum from string, case-insensitive.

        Args:
            value: Text naming a noise type, e.g. ``"normal"`` or ``"NORMAL"``.

        Returns:
            The member whose value equals ``value`` when both are lowercased.

        Raises:
            ValueError: If ``value`` matches none of the members.
        """
        wanted = value.lower()
        for member in cls:
            if member.value.lower() == wanted:
                return member
        # Raised outside any ``except`` block so the traceback is not
        # cluttered with an internal StopIteration as implicit context.
        raise ValueError(
            f"'{value}' is not a valid {cls.__name__}. "
            f"Valid values are: {[e.value for e in cls]}"
        )
@dataclass
class VaeConfig:
    """VAE-specific settings: KLD schedule, latent size and sampling noise.

    Attributes:
        kld_schedule: KL divergence annealing schedule; a fresh
            ``KLDScheduleConfig`` is created per instance by default.
        vae_latent_dim: Dimension of VAE latent space (default 64).
        vae_noise_type: Type of noise for latent sampling; defaults to
            ``VaeNoiseType.NORMAL``.
    """

    kld_schedule: KLDScheduleConfig = field(
        default_factory=KLDScheduleConfig,
        metadata={
            "help": "KL divergence annealing schedule.",
        },
    )
    vae_latent_dim: int = field(
        default=64,
        metadata={
            "help": "Dimension of VAE latent space.",
            "min": 1,
        },
    )
    vae_noise_type: VaeNoiseType = field(
        default=VaeNoiseType.NORMAL,
        metadata={
            "help": "Type of noise for latent sampling: normal, uniform, or zeros.",
        },
    )
@dataclass
class FeedForwardModelConfig(BaseModelConfig):
    """Settings for the FeedForwardModel, the non-VAE model variant.

    Attributes:
        _target_: Dotted path naming the ``FeedForwardModel`` class
            (presumably used to instantiate the model from this config --
            confirm against the loader).
        trunk: Main trunk network for forward pass; defaults to an empty
            ``ModuleContainerConfig()``.
    """

    _target_: str = "protomotions.agents.masked_mimic.model.FeedForwardModel"

    trunk: ModuleContainerConfig = field(
        default_factory=ModuleContainerConfig,
        metadata={
            "help": "Main trunk network for forward pass.",
        },
    )
@dataclass
class MaskedMimicModelConfig(BaseModelConfig):
    """Settings for the MaskedMimic model (VAE-based imitation learning).

    Attributes:
        _target_: Dotted path naming the ``MaskedMimicModel`` class
            (presumably used to instantiate the model from this config --
            confirm against the loader).
        encoder: VAE encoder network (maps observations to latent).
        prior: Prior network for latent distribution.
        trunk: Decoder trunk network (latent to actions).
        vae: VAE configuration (latent dim, KLD schedule, etc).
        optimizer: Optimizer settings for model training; the default is
            ``OptimizerConfig(lr=2e-5)``.
    """

    _target_: str = "protomotions.agents.masked_mimic.model.MaskedMimicModel"

    # Network sub-modules; each defaults to a fresh ModuleContainerConfig().
    encoder: ModuleContainerConfig = field(
        default_factory=ModuleContainerConfig,
        metadata={
            "help": "VAE encoder network (maps observations to latent).",
        },
    )
    prior: ModuleContainerConfig = field(
        default_factory=ModuleContainerConfig,
        metadata={
            "help": "Prior network for latent distribution.",
        },
    )
    trunk: ModuleContainerConfig = field(
        default_factory=ModuleContainerConfig,
        metadata={
            "help": "Decoder trunk network (latent to actions).",
        },
    )

    vae: VaeConfig = field(
        default_factory=VaeConfig,
        metadata={
            "help": "VAE configuration (latent dim, KLD schedule, etc).",
        },
    )

    # A lambda is needed so every config instance gets its own optimizer
    # config built with the non-default learning rate.
    optimizer: OptimizerConfig = field(
        default_factory=lambda: OptimizerConfig(lr=2e-5),
        metadata={
            "help": "Optimizer settings for model training.",
        },
    )
@dataclass
class MaskedMimicAgentConfig(BaseAgentConfig):
    """Top-level configuration for the MaskedMimic agent.

    Attributes:
        _target_: Dotted path naming the ``MaskedMimic`` agent class
            (presumably used to instantiate the agent from this config --
            confirm against the loader).
        model: Model configuration, either the VAE variant
            (``MaskedMimicModelConfig``, the default) or the FeedForward
            variant (``FeedForwardModelConfig``).
        expert_model_path: Path to pre-trained expert model checkpoint, or
            ``None`` (the default) when no path is given.
    """

    _target_: str = "protomotions.agents.masked_mimic.agent.MaskedMimic"

    model: Union[MaskedMimicModelConfig, FeedForwardModelConfig] = field(
        default_factory=MaskedMimicModelConfig,
        metadata={
            "help": "Model configuration (VAE or FeedForward variant).",
        },
    )
    expert_model_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to pre-trained expert model checkpoint.",
        },
    )