protomotions.utils.export_utils module
Utilities for exporting trained models to ONNX format.
This module provides functions to export TensorDict-based models to ONNX format using torch.onnx.export. The exported models can be used for deployment and inference in production environments.
- Key Functions:
export_onnx: Export a TensorDictModule to ONNX format
export_ppo_model: Export a trained PPO model to ONNX
export_observations: Export observation computation to ONNX
export_unified_pipeline: Export complete pipeline (context -> actions)
Note
Action processing is now handled by ActionProcessor inside the policy network, so exporting the model automatically includes action processing.
- class protomotions.utils.export_utils.ONNXExportWrapper(module, in_keys, batch_size)
Wrapper for TensorDictModule that accepts positional arguments for ONNX export.
TensorDictModules expect a TensorDict argument, but torch.onnx.export uses positional tensor inputs. This wrapper bridges the gap.
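The positional-to-keyed bridge can be sketched as follows. This is a minimal sketch under stated assumptions, not the library's implementation:

- A plain dict stands in for a TensorDict, so the sketch does not depend on tensordict.
- The wrapped module returns a single tensor.
- The batch_size argument is omitted.
- PositionalToKeyedWrapper and ToyPolicy are hypothetical names.

```python
from typing import List

import torch
from torch import nn


class PositionalToKeyedWrapper(nn.Module):
    """Hypothetical stand-in for ONNXExportWrapper: maps positional tensors to keys."""

    def __init__(self, module: nn.Module, in_keys: List[str]):
        super().__init__()
        self.module = module
        self.in_keys = list(in_keys)

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        if len(inputs) != len(self.in_keys):
            raise ValueError(f"expected {len(self.in_keys)} inputs, got {len(inputs)}")
        # Rebuild the keyed input in the same order the exporter passes tensors.
        # (A real TensorDictModule would receive a TensorDict here, not a dict.)
        keyed = dict(zip(self.in_keys, inputs))
        return self.module(keyed)


class ToyPolicy(nn.Module):
    """Hypothetical keyed module: concatenates two inputs, then applies a linear layer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 2)

    def forward(self, td: dict) -> torch.Tensor:
        return self.linear(torch.cat([td["obs"], td["goal"]], dim=-1))


wrapper = PositionalToKeyedWrapper(ToyPolicy(), in_keys=["obs", "goal"])
out = wrapper(torch.randn(1, 3), torch.randn(1, 2))
```

The wrapper's forward takes only positional tensors. A module like this can therefore be handed to torch.onnx.export together with a tuple of sample tensors, ordered the same way as in_keys.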
- protomotions.utils.export_utils.export_onnx(module, td, path, meta=None, validate=True, opset_version=17)
Export a TensorDictModule to ONNX format.
Uses torch.onnx.export to export the module. Creates a wrapper that converts between TensorDict and positional tensor inputs for ONNX compatibility.
- Parameters:
module (tensordict.nn.TensorDictModuleBase) – TensorDictModule to export.
td (TensorDict) – Sample TensorDict input (used for tracing).
path (str) – Path to save the ONNX model (must end with .onnx).
meta (Dict[str, Any] | None) – Optional additional metadata to save.
validate (bool) – If True, validates the exported model with onnxruntime.
opset_version (int) – ONNX opset version to use (default: 17).
- Raises:
ValueError – If path doesn’t end with .onnx.
Example
>>> import torch
>>> from tensordict import TensorDict
>>> from protomotions.agents.ppo.model import PPOModel
>>> from protomotions.utils.export_utils import export_onnx
>>> model = PPOModel(config)
>>> sample_input = TensorDict({"obs": torch.randn(1, 128)}, batch_size=[1])
>>> export_onnx(model, sample_input, "policy.onnx")
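The path check listed under Raises, together with saving meta, can be sketched in plain Python. This page only mentions a JSON sidecar for export_observations. The sidecar naming used here (the .onnx suffix swapped for .json) and both helper names are assumptions for illustration, so the module's actual metadata handling may differ.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional


def check_onnx_path(path: str) -> Path:
    """Mirror the documented check: reject paths that do not end with .onnx."""
    if not path.endswith(".onnx"):
        raise ValueError(f"ONNX export path must end with .onnx, got {path!r}")
    return Path(path)


def write_meta_sidecar(path: str, meta: Optional[Dict[str, Any]] = None) -> Path:
    """Hypothetical helper: write metadata as JSON next to the .onnx file."""
    onnx_path = check_onnx_path(path)
    sidecar = onnx_path.with_suffix(".json")  # assumed naming: policy.onnx -> policy.json
    sidecar.parent.mkdir(parents=True, exist_ok=True)
    sidecar.write_text(json.dumps(meta or {}, indent=2))
    return sidecar


with tempfile.TemporaryDirectory() as tmp:
    sidecar = write_meta_sidecar(str(Path(tmp) / "policy.onnx"), {"opset_version": 17})
    print(sidecar.name)  # policy.json
```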
- protomotions.utils.export_utils.export_ppo_actor(actor, sample_obs, path, validate=True)
Export a PPO actor’s mu network to ONNX.
Exports the mean network (mu) of a PPO actor, which is the core policy network without the distribution layer. Uses real observations from the environment to ensure proper tracing.
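- Parameters:
actor – PPO actor whose mean (mu) network is exported.
sample_obs – Sample observations used for tracing, ideally real observations from the environment.
path – Path to save the ONNX model.
validate – If True, validates the exported model with onnxruntime.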
Example
>>> # Get real observations from the environment
>>> env.reset()
>>> sample_obs = agent.get_obs()
>>> export_ppo_actor(agent.model._actor, sample_obs, "ppo_actor.onnx")
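Exporting only mu yields a deterministic policy. The minimal sketch below illustrates this using a common Gaussian actor layout: a mean network plus a learned, state-independent log-std. The class, attribute names and layer sizes are hypothetical and need not match the internals of agent.model._actor.

```python
import torch
from torch import nn
from torch.distributions import Normal


class GaussianActor(nn.Module):
    """Hypothetical PPO actor: a mean network (mu) plus a state-independent log-std."""

    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__()
        self.mu = nn.Sequential(nn.Linear(obs_dim, 64), nn.ELU(), nn.Linear(64, act_dim))
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs: torch.Tensor) -> Normal:
        # Training-time behavior: a distribution to sample exploratory actions from.
        return Normal(self.mu(obs), self.log_std.exp())


actor = GaussianActor(obs_dim=128, act_dim=12)
obs = torch.randn(1, 128)

with torch.no_grad():
    # The exportable part is just mu: a plain tensor-in/tensor-out module whose
    # output equals the distribution's mean, i.e. the deterministic action.
    mean_action = actor.mu(obs)
    dist_mean = actor(obs).mean
```

Dropping the distribution layer also avoids exporting sampling, which does not map cleanly onto a deterministic deployment graph.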
- protomotions.utils.export_utils.export_ppo_critic(critic, sample_obs, path, validate=True)
Export a PPO critic network to ONNX.
Uses real observations from the environment to ensure proper tracing.
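- Parameters:
critic – PPO critic network to export.
sample_obs – Sample observations used for tracing, ideally real observations from the environment.
path – Path to save the ONNX model.
validate – If True, validates the exported model with onnxruntime.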
Example
>>> # Get real observations from the environment
>>> env.reset()
>>> sample_obs = agent.get_obs()
>>> export_ppo_critic(agent.model._critic, sample_obs, "ppo_critic.onnx")
- protomotions.utils.export_utils.export_ppo_model(model, sample_obs, output_dir, validate=True)
Export a complete PPO model (actor and critic) to ONNX.
Exports both the actor and critic networks to separate ONNX files in the specified directory.
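- Parameters:
model – Trained PPO model containing the actor and critic.
sample_obs – Sample observations used for tracing.
output_dir – Directory in which the actor and critic ONNX files are written.
validate – If True, validates the exported models with onnxruntime.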
- Returns:
Dict with paths to exported files.
Example
>>> import torch
>>> model = trained_agent.model
>>> sample_obs = {"obs": torch.randn(1, 128)}
>>> paths = export_ppo_model(model, sample_obs, "exported_models/")
>>> print(paths)
{'actor': 'exported_models/actor.onnx', 'critic': 'exported_models/critic.onnx'}
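The returned mapping in the example output can be reproduced with os.path.join. The sketch below assumes the fixed file names actor.onnx and critic.onnx shown in that output. The helper name is hypothetical, and directory creation and the actual exports are left out.

```python
import os
from typing import Dict


def ppo_export_paths(output_dir: str) -> Dict[str, str]:
    """Hypothetical helper: build the per-network ONNX paths for a PPO export."""
    return {name: os.path.join(output_dir, f"{name}.onnx") for name in ("actor", "critic")}


paths = ppo_export_paths("exported_models/")
```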
- class protomotions.utils.export_utils.ActionExportModule(action_config, device)
Module that wraps action processing functions for ONNX export.
Takes raw actions from the policy and produces processed actions with stiffness/damping targets.
- Works with the action config format:
{"fn": normalized_pd_fixed_gains_action, "pd_action_offset": ..., ...}
The function is looked up under the "fn" key. All other dict entries are passed to it as keyword arguments, along with the action tensor.
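The "fn"-plus-kwargs convention can be sketched in plain Python, with lists standing in for tensors. The following are hypothetical illustrations:

- toy_pd_action and its pd_action_scale parameter.
- apply_action_config.
- Passing the action positionally.

```python
from typing import Any, Callable, Dict, List


def toy_pd_action(
    action: List[float], pd_action_offset: List[float], pd_action_scale: List[float]
) -> List[float]:
    """Hypothetical action function: map normalized actions to PD targets."""
    return [o + s * a for a, o, s in zip(action, pd_action_offset, pd_action_scale)]


def apply_action_config(action_config: Dict[str, Any], action: List[float]) -> Any:
    """Split the config into the callable under "fn" and keyword arguments for it."""
    fn: Callable[..., Any] = action_config["fn"]
    kwargs = {key: value for key, value in action_config.items() if key != "fn"}
    # Assumption: the action is passed positionally; every other entry becomes a kwarg.
    return fn(action, **kwargs)


config = {"fn": toy_pd_action, "pd_action_offset": [0.0, 0.5], "pd_action_scale": [1.0, 2.0]}
targets = apply_action_config(config, [0.2, -0.25])  # [0.2, 0.0]
```

Keeping the callable and its parameters in one dict lets the export module stay generic: swapping the action function only requires a different config, not a different wrapper class.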
- class protomotions.utils.export_utils.UnifiedPipelineModule(observation_module, policy_module, action_module, policy_in_keys, policy_action_key='mean_action', passthrough_keys=None)
Unified module that combines observations + policy + action processing.
Pipeline: Context -> Observations -> Policy -> Action Processing
- Outputs:
actions: Raw actions from the policy
processed_action: PD targets (clamped and transformed)
stiffness_targets: Per-DOF stiffness values
damping_targets: Per-DOF damping values
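The Pipeline line and Outputs list above can be sketched as three chained stages inside a single module. Every stage here is a toy stand-in:

- a concatenation as the observation step;
- a linear "policy";
- a hypothetical clamp-and-offset action step;
- constant gains.

Only the stage order and the four output names come from this page. Outputs are returned as a tuple because ONNX graph outputs are positional.

```python
from typing import Tuple

import torch
from torch import nn


class ToyUnifiedPipeline(nn.Module):
    """Hypothetical sketch: context -> observations -> policy -> action processing."""

    def __init__(self, num_dofs: int, kp: float = 100.0, kd: float = 5.0):
        super().__init__()
        self.policy = nn.Linear(2 * num_dofs, num_dofs)  # toy policy on [pos, vel]
        self.register_buffer("offset", torch.zeros(num_dofs))
        self.register_buffer("stiffness", torch.full((num_dofs,), kp))
        self.register_buffer("damping", torch.full((num_dofs,), kd))

    def forward(
        self, dof_pos: torch.Tensor, dof_vel: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        obs = torch.cat([dof_pos, dof_vel], dim=-1)  # observation stage
        actions = self.policy(obs)  # policy stage: raw actions
        processed_action = torch.clamp(actions, -1.0, 1.0) + self.offset  # action stage
        batch = actions.shape[0]
        return (
            actions,
            processed_action,
            self.stiffness.expand(batch, -1),  # per-DOF stiffness targets
            self.damping.expand(batch, -1),  # per-DOF damping targets
        )


pipeline = ToyUnifiedPipeline(num_dofs=4)
actions, processed_action, stiffness_targets, damping_targets = pipeline(
    torch.randn(2, 4), torch.randn(2, 4)
)
```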
- protomotions.utils.export_utils.export_unified_pipeline(observation_configs, action_config, sample_context, policy_module, policy_in_keys, policy_action_key, path, device, robot_config, passthrough_obs=None, validate=True, meta=None, dt=None)
Export the complete pipeline (context -> actions) as a single ONNX model.
Chains observation processing, policy, and action processing into a single ONNX model. Generates a YAML configuration file for isaac-deploy.
- Parameters:
observation_configs (Dict[str, Any]) – Dict of MdpComponent instances for observations.
action_config (Dict[str, Any]) – Single action config dict {"fn": action_fn, ...params...}.
sample_context (Dict[str, Any]) – Sample context dict for tracing.
policy_module – Policy network module.
policy_in_keys (list) – Keys required by policy.
policy_action_key (str) – Key for policy output.
path (str) – Output ONNX file path.
device – Device for computation.
robot_config (Any) – Robot configuration.
passthrough_obs (Dict[str, torch.Tensor] | None) – Direct passthrough observations.
validate (bool) – Whether to validate with onnxruntime.
meta (Dict[str, Any] | None) – Optional additional metadata to save.
dt (float | None) – Policy control period in seconds (decimation / fps).
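Under two assumptions, dt works out as in the small example below: fps is the simulation frequency in Hz, and decimation is the number of simulation steps per policy step. The helper name and the numbers are illustrative, not library defaults.

```python
def policy_dt(decimation: int, fps: float) -> float:
    """Control period in seconds: the policy acts once every `decimation` sim steps."""
    return decimation / fps


dt = policy_dt(decimation=4, fps=200)  # 4 sim steps at 200 Hz -> 0.02 s (a 50 Hz policy)
```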
- class protomotions.utils.export_utils.ObservationExportModule(observation_configs, sample_context, device)
Module that wraps observation functions for ONNX export.
This module takes raw context tensors as inputs and computes observations by calling the configured observation functions. It’s designed to be exported to ONNX for deployment.
Works with MdpComponent-based observation config format:
observation_components = {
    "max_coords_obs": MdpComponent(
        compute_func=compute_humanoid_max_coords_observations,
        dynamic_vars={"body_pos": EnvContext.current.rigid_body_pos, ...},
        static_params={"local_obs": True},
    ),
}
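- Parameters:
observation_configs (Dict[str, Any]) – Dict of MdpComponent instances for observations.
sample_context (Dict[str, Any]) – Sample context dict for tracing and shape inference.
device – Device for tensor operations.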
Example
>>> from protomotions.envs.mdp_component import MdpComponent
>>> from protomotions.envs.context_views import EnvContext
>>> configs = {
...     "max_coords_obs": MdpComponent(
...         compute_func=compute_max_coords,
...         dynamic_vars={"body_pos": EnvContext.current.rigid_body_pos},
...     )
... }
>>> context = env.context
>>> module = ObservationExportModule(configs, context, device)
>>> # export_observations takes the configs and builds its own ObservationExportModule
>>> export_observations(configs, context, "observations.onnx", device)
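The sketch below shows, in plain Python, how an MdpComponent-style entry can turn into a function call. It assumes that dynamic_vars resolve to values read from the context and that static_params are forwarded unchanged as keyword arguments. The real dynamic_vars hold EnvContext views rather than the plain string keys used here. ToyComponent, compute_observations and toy_max_coords are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class ToyComponent:
    """Hypothetical stand-in for MdpComponent: a function plus its argument sources."""

    compute_func: Callable[..., Any]
    dynamic_vars: Dict[str, str] = field(default_factory=dict)  # kwarg -> context key
    static_params: Dict[str, Any] = field(default_factory=dict)  # fixed kwargs


def compute_observations(
    components: Dict[str, ToyComponent], context: Dict[str, Any]
) -> Dict[str, Any]:
    """Evaluate each component against one context snapshot."""
    obs = {}
    for name, comp in components.items():
        kwargs = {arg: context[key] for arg, key in comp.dynamic_vars.items()}
        kwargs.update(comp.static_params)
        obs[name] = comp.compute_func(**kwargs)
    return obs


def toy_max_coords(body_pos: List[float], local_obs: bool) -> List[float]:
    """Hypothetical observation: positions relative to the first body when local_obs."""
    root = body_pos[0] if local_obs else 0.0
    return [p - root for p in body_pos]


components = {
    "max_coords_obs": ToyComponent(
        compute_func=toy_max_coords,
        dynamic_vars={"body_pos": "rigid_body_pos"},
        static_params={"local_obs": True},
    )
}
observations = compute_observations(components, {"rigid_body_pos": [1.0, 1.5, 3.0]})
```

Separating per-step inputs (dynamic_vars) from fixed settings (static_params) is what allows the exported graph to take only context tensors as inputs. The static settings are baked in at trace time.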
- protomotions.utils.export_utils.export_observations(observation_configs, sample_context, path, device, validate=True, meta=None)
Export observation computation to ONNX format.
Creates an ObservationExportModule from the observation configs and exports it to ONNX. The exported model takes raw context tensors as inputs and produces observation tensors as outputs.
- Parameters:
observation_configs (Dict[str, Any]) – Dict of observation component configurations.
sample_context (Dict[str, Any]) – Sample context dict for tracing and shape inference.
path (str) – Path to save the ONNX model.
device – Device for tensor operations.
validate (bool) – If True, validates with onnxruntime.
meta (Dict[str, Any] | None) – Optional metadata to include in the JSON sidecar.
- Returns:
Path to the exported ONNX model.
Example
>>> configs = env.config.observation_components
>>> context = env.context
>>> export_observations(configs, context, "observations.onnx", device)