Source code for protomotions.agents.evaluators.base_evaluator

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Base evaluator for agent evaluation and metrics computation.

This module provides the base evaluation infrastructure for computing performance
metrics during training and evaluation. Evaluators run periodic assessments of
agent performance and compute task-specific metrics.

Key Classes:
    - BaseEvaluator: Base class for all evaluators with hook-based customization

Key Features:
    - Periodic evaluation during training
    - Hook pattern for subclass customization (4 hooks: start, reset_kwargs, check, step)
    - MdpComponent-based evaluation with threshold failure detection
    - Aggregate metrics via plugin system (see aggregate_metrics.py)
    - Episode statistics aggregation
    - Distributed evaluation support
    
Note:
    Aggregate metric plugins (SmoothnessAggregateMetric, ActionSmoothnessAggregateMetric)
    are defined in aggregate_metrics.py and compute post-hoc statistics over
    accumulated MotionMetrics trajectories.
"""

import logging
import numpy as np
import torch
from torch import Tensor
from typing import Dict, Optional, Tuple, Any
from lightning.fabric import Fabric

from protomotions.agents.evaluators.metrics import MotionMetrics
from protomotions.envs.base_env.env import BaseEnv
from protomotions.envs.component_manager import ComponentManager
from protomotions.envs.base_env.utils import combine_evaluation
from protomotions.agents.evaluators.config import EvaluatorConfig
from protomotions.agents.evaluators.aggregate_metrics import (
    SmoothnessAggregateMetric,
    ActionSmoothnessAggregateMetric,
)

log = logging.getLogger(__name__)


[docs] class BaseEvaluator: """Base class for agent evaluation and metrics computation. Runs periodic evaluations during training to assess agent performance. Collects episode statistics, computes task-specific metrics, and provides feedback for checkpoint selection (best model saving). Args: agent: The agent being evaluated. fabric: Lightning Fabric instance for distributed evaluation. config: Evaluator configuration specifying eval frequency and length. Example: >>> evaluator = BaseEvaluator(agent, fabric, config) >>> metrics, score = evaluator.evaluate() """
[docs] def __init__(self, agent: Any, fabric: Fabric, config: EvaluatorConfig): """ Initialize the evaluator. Args: agent: The agent to evaluate fabric: Lightning Fabric instance for distributed training """ self.agent = agent self.fabric = fabric self.config = config self.metric_plugins = [] self._register_plugins() self.eval_count = 0 self._component_manager: Optional[ComponentManager] = None self._motion_failed: Optional[Tensor] = None self._per_component_failures: Dict[str, Tensor] = {} self._component_value_sum: Dict[str, Tensor] = {} self._component_value_min: Dict[str, Tensor] = {} self._component_value_max: Dict[str, Tensor] = {} self._component_step_count: Dict[str, Tensor] = {} # Instance state for metrics collection during evaluation self._metrics: Optional[Dict] = None
@property def device(self) -> torch.device: """Device for computations (from fabric).""" return self.fabric.device @property def env(self) -> BaseEnv: """Environment instance (from agent).""" return self.agent.env @property def root_dir(self): """Root directory for saving outputs (from agent).""" return self.agent.root_dir @torch.no_grad() def evaluate(self) -> Tuple[Dict, Optional[float]]: """ Evaluate the agent and calculate metrics. This is the main entry point that orchestrates the evaluation process. Returns: Tuple containing: - Dict of evaluation metrics for logging - Optional score value for determining best model """ if not self.config.evaluation_components: return {}, None self.agent.eval() self._metrics = self.initialize_eval() if self._metrics is None: return {}, None self.run_evaluation() evaluation_log, evaluated_score = self.process_eval_results() self.cleanup_after_evaluation() self.eval_count += 1 return evaluation_log, evaluated_score @property def num_envs(self) -> int: """Number of environments (from agent).""" return self.agent.num_envs @property def max_eval_steps(self) -> int: """Maximum steps per evaluation episode.""" return self.config.max_eval_steps
[docs] def initialize_eval(self) -> Dict: """Initialize evaluation tracking.""" self._init_eval_component_buffers(self.num_envs) return {}
[docs] def run_evaluation(self) -> None: """Run the evaluation process.""" env_ids = torch.arange(self.num_envs, device=self.device) self.evaluate_episode(env_ids, self.max_eval_steps)
[docs] def evaluate_episode(self, env_ids: Tensor, max_steps: int) -> None: """Run a single episode batch. Subclasses customize behavior via 4 hooks: - _on_episode_start: pre-reset setup - _get_reset_kwargs: customize env.reset() call - _check_eval_components: per-step evaluation component checking - _on_episode_step: per-step data collection Args: env_ids: Environment IDs to evaluate [num_envs] max_steps: Maximum steps for this episode """ self._on_episode_start(env_ids) obs, _ = self.env.reset(env_ids, **self._get_reset_kwargs()) obs = self.agent.add_agent_info_to_obs(obs) obs_td = self.agent.obs_dict_to_tensordict(obs) for step_idx in range(max_steps): model_outs = self.agent.model(obs_td) actions = model_outs.get("mean_action", model_outs.get("action")) obs, rewards, dones, terminated, extras = self.env.step(actions) obs = self.agent.add_agent_info_to_obs(obs) obs_td = self.agent.obs_dict_to_tensordict(obs) self._check_eval_components(env_ids, step_idx) self._on_episode_step(env_ids, extras, actions)
def _on_episode_start(self, env_ids: Tensor) -> None: """Hook called before episode reset. Override in subclasses for pre-reset setup. Args: env_ids: Environment IDs about to be reset """ pass def _get_reset_kwargs(self) -> dict: """Hook to provide extra kwargs for env.reset(). Override in subclasses. Returns: Dictionary of kwargs passed to env.reset() """ return {} def _check_eval_components(self, env_ids: Tensor, step_idx: int) -> None: """Hook for per-step evaluation component checking. Override in subclasses. Default behavior: check all env_ids, mapping env_ids to eval_ids 1:1. Subclasses can filter env_ids (e.g., skip finished motion clips) and provide custom env-to-eval-ID mapping. Args: env_ids: Environment IDs active this step step_idx: Current step index in the episode """ self._check_evaluation_failures(env_ids, env_ids) def _on_episode_step(self, env_ids: Tensor, extras: Dict, actions: Tensor) -> None: """Hook called after each step. Override in subclasses to collect metrics. Args: env_ids: Environment IDs active this step extras: Extra data from env.step() actions: Actions taken this step """ pass
[docs] def process_eval_results(self) -> Tuple[Dict, Optional[float]]: """Process collected metrics and prepare for logging.""" to_log = {} if self._motion_failed is not None: success_rate = 1.0 - self._motion_failed.float().mean().item() to_log["eval/success_rate"] = success_rate for name, component in self.config.evaluation_components.items(): threshold = component.static_params.get("threshold", None) if threshold is not None: failure_rate = self._per_component_failures[name].float().mean().item() to_log[f"eval/{name}/failure_rate"] = failure_rate for name in self._component_value_sum.keys(): step_count = self._component_step_count[name].float() valid = step_count > 0 if valid.any(): mean_per_motion = self._component_value_sum[name] / step_count.clamp(min=1) to_log[f"eval/{name}/mean"] = mean_per_motion[valid].mean().item() to_log[f"eval/{name}/max"] = self._component_value_max[name][valid].max().item() to_log[f"eval/{name}/min"] = self._component_value_min[name][valid].min().item() return to_log, success_rate return to_log, None
[docs] def cleanup_after_evaluation(self) -> None: """Clean up after evaluation.""" self._metrics = None self._motion_failed = None self._per_component_failures = {} self._component_value_sum = {} self._component_value_min = {} self._component_value_max = {} self._component_step_count = {} self._component_manager = None
def _init_eval_component_buffers(self, num_eval_ids: int) -> None: """Initialize per-component failure and value accumulators for this evaluation run.""" if not self.config.evaluation_components: return self._motion_failed = torch.zeros(num_eval_ids, dtype=torch.bool, device=self.device) self._per_component_failures = { name: torch.zeros(num_eval_ids, dtype=torch.bool, device=self.device) for name in self.config.evaluation_components.keys() } self._component_value_sum = { name: torch.zeros(num_eval_ids, device=self.device) for name in self.config.evaluation_components.keys() } self._component_value_min = { name: torch.full((num_eval_ids,), float('inf'), device=self.device) for name in self.config.evaluation_components.keys() } self._component_value_max = { name: torch.full((num_eval_ids,), float('-inf'), device=self.device) for name in self.config.evaluation_components.keys() } self._component_step_count = { name: torch.zeros(num_eval_ids, dtype=torch.long, device=self.device) for name in self.config.evaluation_components.keys() } self._component_manager = ComponentManager(self.device) def _check_evaluation_failures( self, active_env_ids: Tensor, active_motion_ids: Tensor, ) -> None: """Check evaluation components and accumulate values/failures for active motions.""" if self._component_manager is None: return raw_values = self._component_manager.execute_all( self.config.evaluation_components, self.env.context ) failed_buf, component_values, component_failures = combine_evaluation( raw_values=raw_values, configs=self.config.evaluation_components, num_envs=self.agent.num_envs, device=self.device, ) # Vectorized update of motion failures active_failed = failed_buf[active_env_ids] self._motion_failed[active_motion_ids] = self._motion_failed[active_motion_ids] | active_failed for name, failures in component_failures.items(): active_failures = failures[active_env_ids] self._per_component_failures[name][active_motion_ids] = ( self._per_component_failures[name][active_motion_ids] | active_failures ) for name, values in component_values.items(): active_vals = values[active_env_ids] self._component_value_sum[name][active_motion_ids] += active_vals self._component_value_min[name][active_motion_ids] = torch.minimum( self._component_value_min[name][active_motion_ids], active_vals ) self._component_value_max[name][active_motion_ids] = torch.maximum( self._component_value_max[name][active_motion_ids], active_vals ) self._component_step_count[name][active_motion_ids] += 1 def _create_base_metrics( self, metric_keys: list, num_motions: int, motion_num_frames: torch.Tensor, max_eval_steps: int, ) -> Dict[str, MotionMetrics]: """ Create MotionMetrics objects for a list of keys. Args: metric_keys: List of metric keys to create num_motions: Number of motions to evaluate motion_num_frames: Number of frames per motion max_eval_steps: Maximum evaluation steps Returns: Dictionary of MotionMetrics objects """ metrics = {} for k in metric_keys: metrics[k] = MotionMetrics( num_motions, motion_num_frames, max_eval_steps, device=self.device ) return metrics def _add_robot_state_metrics( self, metrics: Dict[str, MotionMetrics], num_motions: int, motion_num_frames: torch.Tensor, max_eval_steps: int, ) -> None: """ Add metrics for raw robot state (dof_pos, rigid_body_pos, etc.). This is needed for derived metrics like smoothness. Args: metrics: Existing metrics dict to add to num_motions: Number of motions to evaluate motion_num_frames: Number of frames per motion max_eval_steps: Maximum evaluation steps """ # Default implementation for humanoid robot state if not hasattr(self.env, "simulator"): return try: from protomotions.simulator.base_simulator.simulator_state import RobotState dummy_state: RobotState = self.env.simulator.get_robot_state() shape_mapping = dummy_state.get_shape_mapping(flattened=True) for k, shape in shape_mapping.items(): metrics[k] = MotionMetrics( num_motions, motion_num_frames, max_eval_steps, num_sub_features=shape[0], device=self.device, ) except (AttributeError, KeyError, IndexError) as e: log.warning("Could not add robot state metrics: %s", e) def _register_plugins(self) -> None: """Register metric computation plugins. Override in subclasses.""" pass def _register_smoothness_plugin( self, window_sec: float = 0.4, high_jerk_threshold: float = 6500.0 ) -> bool: """ Convenience method to register smoothness aggregate metric. Args: window_sec: Window size in seconds for smoothness computation high_jerk_threshold: Threshold for classifying high jerk frames Returns: True if plugin was registered successfully, False otherwise """ try: self.metric_plugins.append( SmoothnessAggregateMetric(self, window_sec, high_jerk_threshold) ) return True except ValueError as e: log.warning("Skipping smoothness plugin: %s", e) return False def _register_action_smoothness_plugin(self) -> bool: """ Convenience method to register action smoothness aggregate metric. Measures how much actions change between consecutive timesteps. Returns: True if plugin was registered successfully, False otherwise """ try: self.metric_plugins.append(ActionSmoothnessAggregateMetric(self)) return True except (ValueError, TypeError) as e: log.warning("Skipping action smoothness plugin: %s", e) return False def _compute_additional_metrics( self, metrics: Dict[str, MotionMetrics] ) -> Dict[str, float]: """ Run all registered metric plugins to compute additional metrics. Args: metrics: Dictionary of collected MotionMetrics Returns: Dictionary of additional computed metrics """ additional_metrics = {} for plugin in self.metric_plugins: try: plugin_metrics = plugin.compute(metrics) additional_metrics.update(plugin_metrics) except Exception as e: log.warning("Plugin %s failed: %s", plugin.__class__.__name__, e) return additional_metrics def _gen_metrics( self, metrics: Dict[str, MotionMetrics], keys_to_log: list, prefix: str = "eval" ) -> Dict[str, float]: """ Log metrics with mean/max/min aggregations across motions. For each metric, computes: - mean: average across all per-motion means (overall performance) - max: maximum of per-motion means (worst performing motion) - min: minimum of per-motion means (best performing motion) This gives you 3 separate line plot groups that track over time: - {prefix}_mean/{metric}: How well you perform on average - {prefix}_max/{metric}: How well you perform on the hardest motion - {prefix}_min/{metric}: How well you perform on the easiest motion Args: metrics: Dictionary of MotionMetrics keys_to_log: List of metric keys to log prefix: Base prefix for logged metric names (default: "eval") Returns: Dictionary of logged metrics """ to_log = {} for k in keys_to_log: if k in metrics: to_log[f"{prefix}_mean/{k}"] = metrics[k].mean_mean_reduce().item() to_log[f"{prefix}_max/{k}"] = metrics[k].mean_max_reduce().item() to_log[f"{prefix}_min/{k}"] = metrics[k].mean_min_reduce().item() return to_log def _save_list_to_file( self, items: list, filename: str, subdirectory: Optional[str] = None ) -> None: """ Save a list of items to a text file (one per line). Args: items: List of items to save filename: Name of output file subdirectory: Optional subdirectory within root_dir """ if subdirectory: output_dir = self.root_dir / subdirectory output_dir.mkdir(parents=True, exist_ok=True) output_path = output_dir / filename else: output_path = self.root_dir / filename print(f"Saving to: {output_path}") with open(output_path, "w") as f: for item in items: f.write(f"{item}\n") def _plot_per_frame_metrics( self, metrics: Dict[str, MotionMetrics], keys_to_plot: Optional[list] = None, motion_id: int = 0, custom_colors: Optional[Dict[str, str]] = None, output_filename: str = "metrics_per_frame_plot.png", ) -> None: """ Plot per-frame metrics vs time for a single motion. Only plots single-feature metrics (ignores multi-feature metrics). Args: metrics: Dictionary of MotionMetrics objects keys_to_plot: List of keys to plot (None = plot all single-feature metrics) motion_id: Which motion to plot (default: 0) custom_colors: Optional dict mapping metric keys to colors output_filename: Name of output file """ try: import matplotlib.pyplot as plt except ImportError: print("matplotlib not available, skipping plotting") return dt = self.env.dt custom_colors = custom_colors or {} # Filter to only single-feature metrics single_feature_metrics = {} valid_frames = {} # Determine which keys to plot if keys_to_plot is None: keys_to_plot = list(metrics.keys()) for k in keys_to_plot: if k in metrics and metrics[k].num_sub_features == 1: single_feature_metrics[k] = metrics[k] valid_frames[k] = metrics[k].frame_counts[motion_id].item() if not single_feature_metrics: print("No single-feature metrics found for plotting") return # Create subplots for each single-feature metric num_metrics = len(single_feature_metrics) fig, axes = plt.subplots(num_metrics, 1, figsize=(12, 4 * num_metrics)) if num_metrics == 1: axes = [axes] for i, k in enumerate(single_feature_metrics.keys()): metric = single_feature_metrics[k] num_valid_frames = valid_frames[k] if num_valid_frames == 0: axes[i].text( 0.5, 0.5, f"No data for {k}", horizontalalignment="center", verticalalignment="center", transform=axes[i].transAxes, ) axes[i].set_title(f"{k}") continue # Extract data for the single motion (single feature) data = metric.data[motion_id, :num_valid_frames, 0].cpu().numpy() time_steps = np.arange(num_valid_frames) * dt # Use custom color if provided, otherwise matplotlib default plot_kwargs = {"label": k, "linewidth": 2} if k in custom_colors: plot_kwargs["color"] = custom_colors[k] axes[i].plot(time_steps, data, **plot_kwargs) axes[i].set_xlabel("Time (s)") axes[i].set_ylabel(f"{k}") axes[i].set_title(f"{k} vs Time") axes[i].grid(True, alpha=0.3) axes[i].legend() plt.tight_layout() # Save the plot if hasattr(self, "root_dir") and self.root_dir is not None: plot_path = self.root_dir / output_filename plt.savefig(plot_path, dpi=150, bbox_inches="tight") print(f"Per-frame metrics plot saved to: {plot_path}") plt.close(fig) print("Per-frame metrics plotted successfully")
[docs] def simple_test_policy(self, collect_metrics: bool = False) -> None: """ Simple evaluation loop for interactive testing. Runs policy indefinitely, collecting running average of metrics. Press Ctrl+C to stop and print summary. Args: collect_metrics: If True, collect and print average metrics on exit. """ self.agent.eval() done_indices = None step = 0 # Running averages for metrics metric_sums: Dict[str, float] = {} metric_counts: Dict[str, int] = {} print("Evaluating policy... (Ctrl+C to stop)") try: while True: obs, _ = self.env.reset(done_indices) obs = self.agent.add_agent_info_to_obs(obs) obs_td = self.agent.obs_dict_to_tensordict(obs) model_outs = self.agent.model(obs_td) action = model_outs.get("mean_action", model_outs["action"]) obs, rewards, dones, terminated, extras = self.env.step(action) obs = self.agent.add_agent_info_to_obs(obs) obs_td = self.agent.obs_dict_to_tensordict(obs) # Accumulate metrics if collect_metrics and "eval_values" in extras: for k, v in extras["eval_values"].items(): val = v.mean().item() metric_sums[k] = metric_sums.get(k, 0.0) + val metric_counts[k] = metric_counts.get(k, 0) + 1 done_indices = dones.nonzero(as_tuple=False).squeeze(-1) step += 1 except KeyboardInterrupt: print(f"\nStopped after {step} steps.") if collect_metrics and metric_counts: print("Average metrics:") for k in sorted(metric_counts.keys()): avg = metric_sums[k] / metric_counts[k] print(f" {k}: {avg:.4f}")