protomotions.utils.export_utils module

Utilities for exporting trained models to ONNX format.

This module provides functions to export TensorDict-based models to ONNX format using torch.onnx.export. The exported models can be used for deployment and inference in production environments.

Key Functions:
  • export_onnx: Export a TensorDictModule to ONNX format

  • export_ppo_model: Export a trained PPO model to ONNX

  • export_observations: Export observation computation to ONNX

  • export_unified_pipeline: Export complete pipeline (context -> actions)

Note

Action processing is now handled by ActionProcessor in the policy network. When you export the model, action processing is automatically included.

class protomotions.utils.export_utils.ONNXExportWrapper(module, in_keys, batch_size)

Bases: torch.nn.Module

Wrapper for TensorDictModule that accepts positional args for ONNX export.

TensorDictModules expect a TensorDict argument, but torch.onnx.export uses positional tensor inputs. This wrapper bridges the gap.
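The wrapper idea can be sketched in plain Python. This is a hypothetical illustration, not the ProtoMotions implementation: a dict stands in for the TensorDict, `PositionalWrapperSketch` and `toy_module` are made-up names, and the real wrapper's `batch_size` argument is omitted because a dict has no batch dimension.

```python
from typing import Any, Callable, Dict, Sequence


class PositionalWrapperSketch:
    """Hypothetical sketch: turns positional inputs back into keyed inputs."""

    def __init__(self, module: Callable[[Dict[str, Any]], Any], in_keys: Sequence[str]):
        self.module = module
        # The order of in_keys fixes the order of the positional inputs.
        self.in_keys = list(in_keys)

    def __call__(self, *args: Any) -> Any:
        if len(args) != len(self.in_keys):
            raise ValueError(f"expected {len(self.in_keys)} inputs, got {len(args)}")
        # Pair each positional argument with its key; a real wrapper would
        # build a TensorDict here (with a batch size) instead of a dict.
        keyed = dict(zip(self.in_keys, args))
        return self.module(keyed)


# Usage with a toy "module" that reads two named inputs.
def toy_module(inputs: Dict[str, Any]) -> Any:
    return inputs["obs"] * 2 + inputs["bias"]


wrapped = PositionalWrapperSketch(toy_module, in_keys=["obs", "bias"])
print(wrapped(3, 1))  # 7
```

Fixing the key order once, at construction time, is what lets an exporter that only understands positional inputs trace a module that was written against named inputs.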

__init__(module, in_keys, batch_size)
forward(*args)

Forward that reconstructs TensorDict from positional args.

protomotions.utils.export_utils.export_onnx(module, td, path, meta=None, validate=True, opset_version=17)

Export a TensorDictModule to ONNX format.

Uses torch.onnx.export to export the module. Creates a wrapper that converts between TensorDict and positional tensor inputs for ONNX compatibility.

Parameters:
  • module (tensordict.nn.TensorDictModuleBase) – TensorDictModule to export.

  • td (tensordict.TensorDict) – Sample TensorDict input (used for tracing).

  • path (str) – Path to save the ONNX model (must end with .onnx).

  • meta (Dict[str, Any] | None) – Optional additional metadata to save.

  • validate (bool) – If True, validates the exported model with onnxruntime.

  • opset_version (int) – ONNX opset version to use (default: 17).

Raises:

ValueError – If path doesn’t end with .onnx.

Example

>>> import torch
>>> from tensordict import TensorDict
>>> from protomotions.agents.ppo.model import PPOModel
>>> from protomotions.utils.export_utils import export_onnx
>>> model = PPOModel(config)  # config: the PPO model config from your experiment
>>> sample_input = TensorDict({"obs": torch.randn(1, 128)}, batch_size=[1])
>>> export_onnx(model, sample_input, "policy.onnx")

protomotions.utils.export_utils.export_ppo_actor(actor, sample_obs, path, validate=True)

Export a PPO actor’s mu network to ONNX.

Exports the mean network (mu) of a PPO actor, which is the core policy network without the distribution layer. Uses real observations from the environment to ensure proper tracing.
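To make the mean/distribution split concrete, here is a toy Gaussian policy in plain Python. All names, weights, and the fixed standard deviation are hypothetical; the point is only that the exported mu network returns the deterministic mean action, while sampling around it happens outside the exported graph.

```python
import random
from typing import List

# Hypothetical toy actor: mu is a fixed linear map, std is a constant.
WEIGHTS = [[0.5, -1.0], [2.0, 0.0]]
STD = 0.1


def mu(obs: List[float]) -> List[float]:
    """Mean network: the deterministic part that gets exported."""
    return [sum(w * o for w, o in zip(row, obs)) for row in WEIGHTS]


def sample_action(obs: List[float], rng: random.Random) -> List[float]:
    """Training-time policy: Gaussian noise around the mean (not exported)."""
    return [m + STD * rng.gauss(0.0, 1.0) for m in mu(obs)]


obs = [1.0, 2.0]
print(mu(obs))  # [-1.5, 2.0]
print(sample_action(obs, random.Random(0)))  # noisy, differs from the mean
```

Calling the exported model twice on the same observation therefore gives the same action, which is usually what deployment wants.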

Parameters:
  • actor – PPOActor instance to export.

  • sample_obs (Dict[str, torch.Tensor]) – Sample observation dict from environment (via agent.get_obs()).

  • path (str) – Path to save the ONNX model.

  • validate (bool) – If True, validates the exported model.

Example

>>> from protomotions.utils.export_utils import export_ppo_actor
>>> # Get real observations from the environment
>>> env.reset()
>>> sample_obs = agent.get_obs()
>>> export_ppo_actor(agent.model._actor, sample_obs, "ppo_actor.onnx")

protomotions.utils.export_utils.export_ppo_critic(critic, sample_obs, path, validate=True)

Export a PPO critic network to ONNX.

Uses real observations from the environment to ensure proper tracing.

Parameters:
  • critic – PPO critic (MultiHeadedMLP) instance to export.

  • sample_obs (Dict[str, torch.Tensor]) – Sample observation dict from environment (via agent.get_obs()).

  • path (str) – Path to save the ONNX model.

  • validate (bool) – If True, validates the exported model.

Example

>>> from protomotions.utils.export_utils import export_ppo_critic
>>> # Get real observations from the environment
>>> env.reset()
>>> sample_obs = agent.get_obs()
>>> export_ppo_critic(agent.model._critic, sample_obs, "ppo_critic.onnx")

protomotions.utils.export_utils.export_ppo_model(model, sample_obs, output_dir, validate=True)

Export a complete PPO model (actor and critic) to ONNX.

Exports both the actor and critic networks to separate ONNX files in the specified directory.

Parameters:
  • model – PPOModel instance to export.

  • sample_obs (Dict[str, torch.Tensor]) – Sample observation dict for tracing.

  • output_dir (str) – Directory to save the ONNX models.

  • validate (bool) – If True, validates the exported models.

Returns:

Dict with paths to exported files.
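Assuming the file names shown in the example output (`actor.onnx` and `critic.onnx`), the returned dict could be assembled as follows. `ppo_export_paths` is a hypothetical helper, not part of the module, and the real function may name or create files differently.

```python
import os
from typing import Dict


def ppo_export_paths(output_dir: str) -> Dict[str, str]:
    """Hypothetical: map each exported network to its ONNX file path."""
    return {
        "actor": os.path.join(output_dir, "actor.onnx"),
        "critic": os.path.join(output_dir, "critic.onnx"),
    }


print(ppo_export_paths("exported_models/"))
```

Returning a dict keyed by network name lets callers pick the actor file for deployment without hard-coding file names.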

Example

>>> import torch
>>> from protomotions.utils.export_utils import export_ppo_model
>>> model = trained_agent.model
>>> sample_obs = {"obs": torch.randn(1, 128)}
>>> paths = export_ppo_model(model, sample_obs, "exported_models/")
>>> print(paths)
{'actor': 'exported_models/actor.onnx', 'critic': 'exported_models/critic.onnx'}

class protomotions.utils.export_utils.ActionExportModule(action_config, device)

Bases: torch.nn.Module

Module that wraps action processing functions for ONNX export.

Takes raw actions from the policy and produces processed actions with stiffness/damping targets.

It expects an action config of the form:

{"fn": normalized_pd_fixed_gains_action, "pd_action_offset": …, …}

The function is taken from the "fn" key; every other entry is passed to it as a keyword argument, along with the action tensor.
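The "fn" dispatch can be sketched with plain lists in place of tensors. `scaled_pd_action` and `apply_action_config` are hypothetical names, the parameter names other than `pd_action_offset` are assumptions, and passing the action positionally is also an assumption, since the text only says it is passed along with the keyword arguments.

```python
from typing import Any, Callable, Dict, List


def scaled_pd_action(action: List[float], pd_action_offset: List[float],
                     pd_action_scale: List[float], stiffness: float,
                     damping: float) -> Dict[str, List[float]]:
    """Hypothetical action function: affine map to PD targets plus fixed gains."""
    targets = [o + s * a for a, o, s in zip(action, pd_action_offset, pd_action_scale)]
    n = len(action)
    return {
        "processed_action": targets,
        "stiffness_targets": [stiffness] * n,
        "damping_targets": [damping] * n,
    }


def apply_action_config(action_config: Dict[str, Any], action: List[float]) -> Dict[str, Any]:
    """Take the function from "fn" and pass every other entry as a keyword argument."""
    params = dict(action_config)  # shallow copy so the caller's config keeps its "fn"
    fn: Callable[..., Dict[str, Any]] = params.pop("fn")
    return fn(action, **params)


config = {
    "fn": scaled_pd_action,
    "pd_action_offset": [0.0, 0.5],
    "pd_action_scale": [1.0, 2.0],
    "stiffness": 40.0,
    "damping": 2.0,
}
print(apply_action_config(config, [0.1, -0.25]))
```

Keeping the function and its parameters in one dict means a different action scheme only needs a different config, not a different export module.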

__init__(action_config, device)
get_output_keys()
forward(action)
class protomotions.utils.export_utils.UnifiedPipelineModule(observation_module, policy_module, action_module, policy_in_keys, policy_action_key='mean_action', passthrough_keys=None)

Bases: torch.nn.Module

Unified module that combines observations + policy + action processing.

Pipeline: Context -> Observations -> Policy -> Action Processing

Outputs:
  • actions: Raw actions from the policy

  • processed_action: PD targets (clamped and transformed)

  • stiffness_targets: Per-DOF stiffness values

  • damping_targets: Per-DOF damping values
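A plain-Python sketch of the chaining and output order listed above, with floats standing in for tensors and lambdas standing in for the observation, policy, and action modules. `make_pipeline`, the context keys, and all stage logic are hypothetical; passthrough keys are not modeled.

```python
from typing import Any, Callable, Dict, Sequence, Tuple


def make_pipeline(
    obs_fn: Callable[[Dict[str, Any]], Dict[str, Any]],
    policy_fn: Callable[[Dict[str, Any]], Any],
    action_fn: Callable[[Any], Dict[str, Any]],
    input_keys: Sequence[str],
) -> Callable[..., Tuple[Any, ...]]:
    """Hypothetical sketch: chain observations -> policy -> action processing."""
    def forward(*tensors: Any) -> Tuple[Any, ...]:
        context = dict(zip(input_keys, tensors))  # positional inputs -> named context
        obs = obs_fn(context)                     # context -> observations
        actions = policy_fn(obs)                  # observations -> raw actions
        processed = action_fn(actions)            # raw actions -> PD targets and gains
        return (
            actions,
            processed["processed_action"],
            processed["stiffness_targets"],
            processed["damping_targets"],
        )
    return forward


# Toy stages standing in for real observation, policy and action modules.
pipeline = make_pipeline(
    obs_fn=lambda ctx: {"obs": ctx["dof_pos"] - ctx["default_pos"]},
    policy_fn=lambda obs: -0.5 * obs["obs"],
    action_fn=lambda a: {
        "processed_action": max(-1.0, min(1.0, a)),  # clamp to [-1, 1]
        "stiffness_targets": 40.0,
        "damping_targets": 2.0,
    },
    input_keys=["dof_pos", "default_pos"],
)
print(pipeline(0.4, 0.2))
```

Because every stage lives inside one forward call, a single exported graph maps raw context directly to all four outputs.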

__init__(observation_module, policy_module, action_module, policy_in_keys, policy_action_key='mean_action', passthrough_keys=None)
get_all_input_keys()
forward(*all_tensors)
protomotions.utils.export_utils.export_unified_pipeline(observation_configs, action_config, sample_context, policy_module, policy_in_keys, policy_action_key, path, device, robot_config, passthrough_obs=None, validate=True, meta=None, dt=None)

Export the complete pipeline (context -> actions) as a single ONNX model.

Chains observation processing, policy, and action processing into a single ONNX model. Generates a YAML configuration file for isaac-deploy.

Parameters:
  • observation_configs (Dict[str, Any]) – Dict of MdpComponent instances for observations.

  • action_config (Dict[str, Any]) – Single action config dict {"fn": action_fn, …params…}.

  • sample_context (Dict[str, Any]) – Sample context dict for tracing.

  • policy_module (torch.nn.Module) – Policy network module.

  • policy_in_keys (list) – Input keys required by the policy.

  • policy_action_key (str) – Key of the policy output that holds the action.

  • path (str) – Output ONNX file path.

  • device (torch.device) – Device for computation.

  • robot_config (Any) – Robot configuration.

  • passthrough_obs (Dict[str, torch.Tensor] | None) – Observations passed through directly, without recomputation.

  • validate (bool) – Whether to validate with onnxruntime.

  • meta (Dict[str, Any] | None) – Additional metadata.

  • dt (float | None) – Policy control period in seconds (decimation / fps).
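As a worked example of `dt`: the policy acts once every `decimation` simulation steps, so its period is decimation / fps. The 200 Hz simulation rate and decimation of 4 below are assumed values, and `policy_dt` is a hypothetical helper.

```python
def policy_dt(decimation: int, sim_fps: float) -> float:
    """Control period in seconds: the policy acts once every `decimation` sim steps."""
    if decimation <= 0 or sim_fps <= 0:
        raise ValueError("decimation and sim_fps must be positive")
    return decimation / sim_fps


# A 200 Hz simulation with decimation 4 gives a 0.02 s (50 Hz) policy step.
print(policy_dt(4, 200.0))
```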

class protomotions.utils.export_utils.ObservationExportModule(observation_configs, sample_context, device)

Bases: torch.nn.Module

Module that wraps observation functions for ONNX export.

This module takes raw context tensors as inputs and computes observations by calling the configured observation functions. It’s designed to be exported to ONNX for deployment.

Works with MdpComponent-based observation config format:

observation_components = {
    "max_coords_obs": MdpComponent(
        compute_func=compute_humanoid_max_coords_observations,
        dynamic_vars={"body_pos": EnvContext.current.rigid_body_pos, ...},
        static_params={"local_obs": True},
    ),
}

Parameters:
  • observation_configs (Dict[str, Any]) – Dict of MdpComponent instances or legacy dict configs.

  • sample_context (Dict[str, Any]) – Sample context dict to determine input shapes and resolve mappings.

  • device (torch.device) – Device for tensor operations.

Example

>>> from protomotions.envs.mdp_component import MdpComponent
>>> from protomotions.envs.context_views import EnvContext
>>> from protomotions.utils.export_utils import ObservationExportModule, export_observations
>>> configs = {
...     "max_coords_obs": MdpComponent(
...         compute_func=compute_max_coords,
...         dynamic_vars={"body_pos": EnvContext.current.rigid_body_pos},
...     )
... }
>>> context = env.context
>>> module = ObservationExportModule(configs, context, device)
>>> # export_observations builds its own module from the configs
>>> export_observations(configs, context, "observations.onnx", device)

__init__(observation_configs, sample_context, device)
get_input_keys()

Get ordered list of input context keys needed.

get_output_keys()

Get ordered list of output observation names.

forward(*args)

Compute all observations from input tensors.

Parameters:

*args – Input tensors in the order of get_input_keys().

Returns:

Tuple of observation tensors in the order of get_output_keys().

Return type:

tuple
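The ordering contract between get_input_keys(), get_output_keys(), and forward(*args) can be sketched without torch. This is a hypothetical illustration: context keys are plain strings rather than EnvContext references, the stand-in classes are not MdpComponent, and deduplicating keys in first-seen order is an assumed convention.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


class ObsComponentSketch:
    """Hypothetical stand-in for MdpComponent that uses plain context-key strings."""

    def __init__(self, compute_func: Callable[..., Any],
                 dynamic_vars: Dict[str, str],
                 static_params: Optional[Dict[str, Any]] = None):
        self.compute_func = compute_func
        self.dynamic_vars = dynamic_vars          # function argument -> context key
        self.static_params = static_params or {}  # constants baked into the graph


class ObservationModuleSketch:
    """Hypothetical: positional context tensors in, tuple of observations out."""

    def __init__(self, components: Dict[str, ObsComponentSketch]):
        self.components = components

    def get_input_keys(self) -> List[str]:
        # Unique context keys in first-seen order; this fixes the positional order.
        keys: List[str] = []
        for comp in self.components.values():
            for key in comp.dynamic_vars.values():
                if key not in keys:
                    keys.append(key)
        return keys

    def get_output_keys(self) -> List[str]:
        return list(self.components)

    def forward(self, *args: Any) -> Tuple[Any, ...]:
        context = dict(zip(self.get_input_keys(), args))
        outputs = []
        for name in self.get_output_keys():
            comp = self.components[name]
            kwargs = {arg: context[key] for arg, key in comp.dynamic_vars.items()}
            outputs.append(comp.compute_func(**kwargs, **comp.static_params))
        return tuple(outputs)


# Toy observation functions operating on plain lists.
def height_obs(body_pos, offset=0.0):
    return body_pos[2] + offset


def speed_obs(body_vel):
    return sum(v * v for v in body_vel) ** 0.5


module = ObservationModuleSketch({
    "height": ObsComponentSketch(height_obs, {"body_pos": "rigid_body_pos"}, {"offset": 0.1}),
    "speed": ObsComponentSketch(speed_obs, {"body_vel": "rigid_body_vel"}),
})
print(module.get_input_keys())  # ['rigid_body_pos', 'rigid_body_vel']
print(module.forward([0.0, 0.0, 1.0], [3.0, 4.0, 0.0]))
```

A deployment runtime only needs the two key lists to feed inputs in the right order and to name the outputs it gets back.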

protomotions.utils.export_utils.export_observations(observation_configs, sample_context, path, device, validate=True, meta=None)

Export observation computation to ONNX format.

Creates an ObservationExportModule from the observation configs and exports it to ONNX. The exported model takes raw context tensors as inputs and produces observation tensors as outputs.

Parameters:
  • observation_configs (Dict[str, Any]) – Dict of observation component configurations.

  • sample_context (Dict[str, Any]) – Sample context dict for tracing and shape inference.

  • path (str) – Path to save the ONNX model.

  • device (torch.device) – Device for tensor operations.

  • validate (bool) – If True, validates with onnxruntime.

  • meta (Dict[str, Any] | None) – Optional metadata to include in the JSON sidecar.

Returns:

Path to the exported ONNX model.

Return type:

str
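The meta argument ends up in a JSON sidecar. A minimal sketch of writing one, assuming the sidecar sits next to the model with the .onnx suffix replaced by .json; `write_meta_sidecar` is a hypothetical helper, and the actual file name and contents produced by export_observations are not specified here.

```python
import json
import os
import tempfile
from typing import Any, Dict, Optional


def write_meta_sidecar(onnx_path: str, meta: Optional[Dict[str, Any]]) -> str:
    """Hypothetical: store metadata next to the model as <name>.json."""
    root, ext = os.path.splitext(onnx_path)
    if ext != ".onnx":
        raise ValueError(f"expected a .onnx path, got {onnx_path!r}")
    sidecar = root + ".json"
    with open(sidecar, "w", encoding="utf-8") as f:
        json.dump(meta or {}, f, indent=2, sort_keys=True)
    return sidecar


with tempfile.TemporaryDirectory() as tmp:
    path = write_meta_sidecar(os.path.join(tmp, "observations.onnx"), {"dt": 0.02})
    with open(path, encoding="utf-8") as f:
        print(json.load(f))
```

Keeping metadata in a separate text file leaves the ONNX graph untouched and lets deployment tooling read settings without loading the model.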

Example

>>> from protomotions.utils.export_utils import export_observations
>>> configs = env.config.observation_components
>>> context = env.context
>>> export_observations(configs, context, "observations.onnx", device)