# SPDX-FileCopyrightText: Copyright (c) 2025-2026 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utilities for exporting trained models to ONNX format.
This module provides functions to export TensorDict-based models to ONNX format
using torch.onnx.export. The exported models can be used for deployment
and inference in production environments.
Key Functions:
- export_onnx: Export a TensorDictModule to ONNX format
- export_ppo_actor / export_ppo_critic: Export a single PPO network to ONNX
- export_ppo_model: Export a trained PPO model (actor and critic) to ONNX
- export_observations: Export observation computation to ONNX
- export_unified_pipeline: Export complete pipeline (context -> actions)
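Note:
Action processing is now handled by ActionProcessor in the policy network, so
it is included automatically when the policy model is exported. The unified
pipeline export additionally applies an explicit action config
(ActionExportModule) after the policy.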
"""
import torch
import json
from pathlib import Path
from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase
from typing import Optional, Dict, Any
def _resolve_context_path(path: str, context: Any) -> Any:
"""Resolve a dotted attribute path on a context object.
Args:
path: Dotted path string, e.g. "current.rigid_body_pos".
context: The root context object (e.g. EnvContext instance).
Returns:
The resolved value at the given path.
"""
obj = context
for attr in path.split("."):
obj = getattr(obj, attr)
return obj
[docs]
class ONNXExportWrapper(torch.nn.Module):
    """Adapter that exposes a TensorDictModule through positional tensor args.

    torch.onnx.export drives the traced callable with plain tensors, whereas a
    TensorDictModule consumes a TensorDict. This wrapper rebuilds the
    TensorDict from the positional inputs (ordered like ``in_keys``) and
    returns the module's outputs as a tuple ordered like ``module.out_keys``.
    """

    def __init__(self, module: TensorDictModuleBase, in_keys: list, batch_size: int):
        super().__init__()
        self.module = module
        self.in_keys = in_keys
        self._batch_size = batch_size

    def forward(self, *args):
        """Pack ``args`` into a TensorDict, run the module, unpack its outputs."""
        # The batch size is fixed at construction: args[0].shape[0] is not
        # usable while the graph is being traced.
        packed_inputs = dict(zip(self.in_keys, args))
        result_td = self.module(
            TensorDict(packed_inputs, batch_size=[self._batch_size])
        )
        outputs = [result_td[out_key] for out_key in self.module.out_keys]
        return tuple(outputs)
[docs]
def export_onnx(
    module: TensorDictModuleBase,
    td: TensorDict,
    path: str,
    meta: Optional[Dict[str, Any]] = None,
    validate: bool = True,
    opset_version: int = 17,
):
    """Export a TensorDictModule to ONNX format.

    Uses torch.onnx.export to export the module. Creates a wrapper that
    converts between TensorDict and positional tensor inputs for ONNX
    compatibility. The leading (batch) axis of every input and output is
    exported as dynamic. A JSON metadata file is written next to the model,
    with the same stem and a ``.json`` suffix.

    Note:
        ``module`` is moved to CPU and switched to eval mode in place.

    Args:
        module: TensorDictModule to export.
        td: Sample TensorDict input (used for tracing). Only
            ``module.in_keys`` are used, and all of them must be present.
        path: Path to save the ONNX model (must end with .onnx).
        meta: Optional additional metadata to save. The caller's dict is
            copied, not modified; the generated key/shape fields override
            entries with the same name.
        validate: If True, runs the exported model with onnxruntime on the
            sample input and compares its outputs with PyTorch. Problems are
            printed as warnings and never raise.
        opset_version: ONNX opset version to use (default: 17).

    Raises:
        ValueError: If path doesn't end with .onnx.

    Example:
        >>> from protomotions.agents.ppo.model import PPOModel
        >>> from tensordict import TensorDict
        >>> model = PPOModel(config)
        >>> sample_input = TensorDict({"obs": torch.randn(1, 128)}, batch_size=1)
        >>> export_onnx(model, sample_input, "policy.onnx")
    """
    if not path.endswith(".onnx"):
        raise ValueError(f"Export path must end with .onnx, got {path}.")

    # Move to CPU and select only required input keys.
    td = td.cpu().select(*module.in_keys, strict=True)
    module = module.cpu()
    module.eval()

    in_keys = list(module.in_keys)
    out_keys = list(module.out_keys)
    print(f"Exporting model to ONNX (PyTorch {torch.__version__})...")
    print(f" Input keys: {in_keys}")
    print(f" Output keys: {out_keys}")

    # Create wrapper that accepts positional args instead of TensorDict.
    batch_size = td.batch_size[0] if td.batch_size else 1
    wrapper = ONNXExportWrapper(module, in_keys, batch_size)
    wrapper.eval()

    # Positional inputs and generic ONNX names; the JSON sidecar maps the
    # generic names back to the semantic TensorDict keys.
    input_tensors = tuple(td[key] for key in in_keys)
    input_names = [f"input_{i}" for i in range(len(in_keys))]
    output_names = [f"output_{i}" for i in range(len(out_keys))]

    torch.onnx.export(
        wrapper,
        input_tensors,
        path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={
            **{name: {0: "batch_size"} for name in input_names},
            **{name: {0: "batch_size"} for name in output_names},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=False,
    )
    print(f"✓ Exported ONNX model to {path}")

    # Save metadata. with_suffix only swaps the final ".onnx" suffix, whereas
    # str.replace would also rewrite ".onnx" inside directory names.
    meta_path = str(Path(path).with_suffix(".json"))
    meta = dict(meta) if meta else {}
    meta["in_keys"] = in_keys
    meta["out_keys"] = out_keys
    meta["in_shapes"] = [list(td[k].shape) for k in in_keys]
    meta["onnx_input_names"] = input_names
    meta["onnx_output_names"] = output_names
    meta["input_mapping"] = dict(zip(input_names, in_keys))
    meta["output_mapping"] = dict(zip(output_names, out_keys))
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=4)
    print(f"✓ Exported metadata to {meta_path}")

    # Validate with onnxruntime (best effort: report problems, never raise).
    if validate:
        try:
            import onnxruntime as ort

            ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
            onnxruntime_input = {
                name: tensor.detach().cpu().numpy()
                for name, tensor in zip(input_names, input_tensors)
            }
            ort_output = ort_session.run(None, onnxruntime_input)
            if len(ort_output) != len(out_keys):
                raise RuntimeError(
                    f"Output length mismatch: {len(ort_output)} vs {len(out_keys)}"
                )

            # Compare numerically against PyTorch on the same sample inputs.
            with torch.no_grad():
                torch_outputs = wrapper(*input_tensors)
            all_close = True
            for key, ort_out, torch_out in zip(out_keys, ort_output, torch_outputs):
                expected = torch_out.detach().cpu().double()
                actual = torch.as_tensor(ort_out).double()
                if actual.shape != expected.shape:
                    raise RuntimeError(
                        f"Shape mismatch for '{key}': "
                        f"{list(actual.shape)} vs {list(expected.shape)}"
                    )
                if expected.numel() and not torch.allclose(
                    actual, expected, rtol=1e-4, atol=1e-4
                ):
                    max_diff = (actual - expected).abs().max().item()
                    print(
                        f"⚠ Warning: ONNX output '{key}' differs from PyTorch "
                        f"(max_diff = {max_diff:.6e})"
                    )
                    all_close = False
            if all_close:
                print("✓ ONNX model validation successful!")
        except ImportError:
            print("⚠ Warning: onnxruntime not installed, skipping validation.")
        except Exception as e:
            print(f"⚠ Warning: ONNX validation failed: {e}")
[docs]
def export_ppo_actor(
    actor, sample_obs: Dict[str, torch.Tensor], path: str, validate: bool = True
):
    """Export a PPO actor's mu network to ONNX.

    Exports the mean network (mu) of a PPO actor, which is the core policy
    network without the distribution layer. Uses real observations from the
    environment to ensure proper tracing.

    Args:
        actor: PPOActor instance to export.
        sample_obs: Sample observation dict from environment (via
            agent.get_obs()). All tensors must share the same leading batch
            dimension.
        path: Path to save the ONNX model (must end with .onnx).
        validate: If True, validates the exported model.

    Raises:
        ValueError: If ``sample_obs`` is empty, or (from export_onnx) if
            ``path`` does not end with .onnx.

    Example:
        >>> # Get real observations from environment
        >>> env.reset()
        >>> sample_obs = agent.get_obs()
        >>> export_ppo_actor(agent.model._actor, sample_obs, "ppo_actor.onnx")
    """
    if not sample_obs:
        raise ValueError("sample_obs must contain at least one observation tensor.")
    # Take the batch size from any entry; TensorDict checks the remaining
    # entries against it.
    batch_size = next(iter(sample_obs.values())).shape[0]
    td = TensorDict(sample_obs, batch_size=batch_size)
    meta = {
        "model_type": "PPOActor",
        "observation_keys": list(sample_obs),
        "observation_shapes": {k: list(v.shape) for k, v in sample_obs.items()},
    }
    export_onnx(actor, td, path, meta=meta, validate=validate)
[docs]
def export_ppo_critic(
    critic, sample_obs: Dict[str, torch.Tensor], path: str, validate: bool = True
):
    """Export a PPO critic network to ONNX.

    Uses real observations from the environment to ensure proper tracing.
    The critic's ``config.num_out`` is recorded in the JSON metadata.

    Args:
        critic: PPO critic (MultiHeadedMLP) instance to export.
        sample_obs: Sample observation dict from environment (via
            agent.get_obs()). All tensors must share the same leading batch
            dimension.
        path: Path to save the ONNX model (must end with .onnx).
        validate: If True, validates the exported model.

    Raises:
        ValueError: If ``sample_obs`` is empty, or (from export_onnx) if
            ``path`` does not end with .onnx.

    Example:
        >>> # Get real observations from environment
        >>> env.reset()
        >>> sample_obs = agent.get_obs()
        >>> export_ppo_critic(agent.model._critic, sample_obs, "ppo_critic.onnx")
    """
    if not sample_obs:
        raise ValueError("sample_obs must contain at least one observation tensor.")
    # Take the batch size from any entry; TensorDict checks the remaining
    # entries against it.
    batch_size = next(iter(sample_obs.values())).shape[0]
    td = TensorDict(sample_obs, batch_size=batch_size)
    meta = {
        "model_type": "PPOCritic",
        "num_out": critic.config.num_out,
        "observation_keys": list(sample_obs),
        "observation_shapes": {k: list(v.shape) for k, v in sample_obs.items()},
    }
    export_onnx(critic, td, path, meta=meta, validate=validate)
[docs]
def export_ppo_model(
    model, sample_obs: Dict[str, torch.Tensor], output_dir: str, validate: bool = True
):
    """Export a complete PPO model (actor and critic) to ONNX.

    Exports ``model._actor`` and ``model._critic`` to ``actor.onnx`` and
    ``critic.onnx`` in ``output_dir`` (created if missing). Each export also
    writes a JSON metadata file with the same stem next to its model.

    Args:
        model: PPOModel instance to export.
        sample_obs: Sample observation dict for tracing.
        output_dir: Directory to save the ONNX models.
        validate: If True, validates the exported models.

    Returns:
        Dict with paths to exported files::

            {
                "actor": "<output_dir>/actor.onnx",
                "critic": "<output_dir>/critic.onnx",
                "metadata": {
                    "actor_meta": "<output_dir>/actor.json",
                    "critic_meta": "<output_dir>/critic.json",
                },
            }

    Example:
        >>> model = trained_agent.model
        >>> sample_obs = {"obs": torch.randn(1, 128)}
        >>> paths = export_ppo_model(model, sample_obs, "exported_models/")
        >>> paths["actor"], paths["metadata"]["actor_meta"]
        ('exported_models/actor.onnx', 'exported_models/actor.json')
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    actor_path = str(output_dir / "actor.onnx")
    critic_path = str(output_dir / "critic.onnx")

    print("Exporting PPO Actor...")
    export_ppo_actor(model._actor, sample_obs, actor_path, validate=validate)
    print("\nExporting PPO Critic...")
    export_ppo_critic(model._critic, sample_obs, critic_path, validate=validate)
    print(f"\nExport complete! Models saved to {output_dir}")

    return {
        "actor": actor_path,
        "critic": critic_path,
        "metadata": {
            "actor_meta": str(output_dir / "actor.json"),
            "critic_meta": str(output_dir / "critic.json"),
        },
    }
###############################################################################
# Unified Pipeline Export (Context -> Actions)
###############################################################################
[docs]
class ActionExportModule(torch.nn.Module):
    """Module that wraps action processing functions for ONNX export.

    Takes raw actions from the policy and produces processed actions
    with stiffness/damping targets.

    Works with action config format::

        {"fn": normalized_pd_fixed_gains_action, "pd_action_offset": ..., ...}

    The function is taken from the "fn" key. Every other entry is passed to
    the function as a keyword argument, together with ``action``:

    - tensor entries are registered as buffers, so they follow the module
      through ``.cpu()`` / ``.to()``;
    - all other entries are passed through unchanged.

    The function must return a mapping containing "processed_action",
    "stiffness_targets" and "damping_targets".
    """

    def __init__(
        self,
        action_config: Dict[str, Any],
        device: torch.device,
    ):
        super().__init__()
        self.device = device
        self._constants = {}
        # Names of tensor parameters registered as buffers, in config order.
        # Tracked explicitly so aliased tensors (same object under two keys)
        # are all forwarded; named_buffers() would de-duplicate them.
        self._buffer_names = []
        self._output_keys = ["processed_action", "stiffness_targets", "damping_targets"]

        if action_config is None:
            self._action_function = None
            return

        # Extract function from config.
        self._action_function = action_config.get("fn")
        if self._action_function is None:
            return

        # All non-"fn" keys are parameters for the function.
        for key, value in action_config.items():
            if key == "fn":
                continue
            if isinstance(value, torch.Tensor):
                self.register_buffer(key, value.to(device))
                self._buffer_names.append(key)
            else:
                self._constants[key] = value

    def get_output_keys(self) -> list:
        """Return the names of the tensors produced by ``forward``, in order."""
        return self._output_keys

    def forward(self, action: torch.Tensor) -> tuple:
        """Apply the configured action function.

        Args:
            action: Raw policy action tensor.

        Returns:
            ``(processed_action, stiffness_targets, damping_targets)``.

        Raises:
            RuntimeError: If no action function was configured.
        """
        if self._action_function is None:
            raise RuntimeError(
                "ActionExportModule has no action function; provide an "
                "action_config with an 'fn' entry."
            )
        func_kwargs = {"action": action}
        for name in self._buffer_names:
            func_kwargs[name] = getattr(self, name)
        func_kwargs.update(self._constants)
        result = self._action_function(**func_kwargs)
        return tuple(result[key] for key in self._output_keys)
[docs]
class UnifiedPipelineModule(torch.nn.Module):
    """Unified module that combines observations + policy + action processing.

    Pipeline: Context -> Observations -> Policy -> Action Processing

    ``forward`` takes the observation-module context tensors (in
    ``observation_module.get_input_keys()`` order) followed by the passthrough
    tensors (in ``passthrough_keys`` order); ``get_all_input_keys`` returns
    the matching semantic keys.

    Outputs:
        - actions: Raw actions from the policy (``policy_action_key`` entry)
        - processed_action: PD targets (clamped and transformed)
        - stiffness_targets: Per-DOF stiffness values
        - damping_targets: Per-DOF damping values
    """

    def __init__(
        self,
        observation_module: "ObservationExportModule",
        policy_module: torch.nn.Module,
        action_module: "ActionExportModule",
        policy_in_keys: list,
        policy_action_key: str = "mean_action",
        passthrough_keys: list = None,
    ):
        super().__init__()
        self.observation_module = observation_module
        self.policy_module = policy_module
        self.action_module = action_module
        self.policy_in_keys = policy_in_keys
        self.policy_action_key = policy_action_key
        self.passthrough_keys = passthrough_keys or []
        self.obs_output_keys = observation_module.get_output_keys()
        self.obs_input_keys = observation_module.get_input_keys()
        self.num_obs_inputs = len(self.obs_input_keys)

    def get_all_input_keys(self) -> list:
        """Return the semantic keys of ``forward``'s positional inputs, in order."""
        return list(self.obs_input_keys) + list(self.passthrough_keys)

    def _infer_batch_size(self, obs_dict: dict, passthrough_tensors: tuple) -> int:
        """Pick the batch size for the policy TensorDict.

        Constant inputs (e.g. body ids) carry no batch axis, so the size is
        taken, in order of preference, from: the first passthrough tensor, the
        first "current_state_*" policy input, the first policy input with at
        least two dimensions; otherwise 1.
        """
        if passthrough_tensors:
            return passthrough_tensors[0].shape[0]
        for key in self.policy_in_keys:
            if key in obs_dict and key.startswith("current_state_"):
                return obs_dict[key].shape[0]
        for key in self.policy_in_keys:
            if key in obs_dict and obs_dict[key].dim() >= 2:
                return obs_dict[key].shape[0]
        return 1

    def forward(self, *all_tensors) -> tuple:
        """Run observations -> policy -> action processing.

        Returns:
            ``(actions, processed_action, stiffness_targets, damping_targets)``.

        Raises:
            ValueError: If the number of positional inputs does not equal the
                number of observation context keys plus passthrough keys.
        """
        expected = self.num_obs_inputs + len(self.passthrough_keys)
        if len(all_tensors) != expected:
            raise ValueError(
                f"Expected {expected} input tensors ({self.num_obs_inputs} "
                f"observation context + {len(self.passthrough_keys)} passthrough), "
                f"got {len(all_tensors)}"
            )
        obs_context_tensors = all_tensors[: self.num_obs_inputs]
        passthrough_tensors = all_tensors[self.num_obs_inputs :]

        obs_outputs = self.observation_module(*obs_context_tensors)
        obs_dict = {key: obs_outputs[i] for i, key in enumerate(self.obs_output_keys)}
        obs_dict.update(zip(self.passthrough_keys, passthrough_tensors))

        batch_size = self._infer_batch_size(obs_dict, passthrough_tensors)
        policy_input = TensorDict(
            {key: obs_dict[key] for key in self.policy_in_keys},
            batch_size=[batch_size],
        )
        policy_output = self.policy_module(policy_input)
        actions = policy_output[self.policy_action_key]

        processed_action, stiffness_targets, damping_targets = self.action_module(
            actions
        )
        return actions, processed_action, stiffness_targets, damping_targets
###############################################################################
# YAML Configuration Generation
###############################################################################
# Mapping from sanitized ONNX input names to (name, kind) tuples for the
# isaac-deploy YAML.
#
# Keys are context paths after ONNX-name sanitization ("." "[" "]" ":" -> "_"),
# e.g. the context path "current.dof_pos" becomes "current_dof_pos".
# `name` and `kind` become the "name" and "kind" fields of a `policy_inputs`
# entry (see _build_policy_input), which also uses `kind` to decide the
# history / future_steps / element_names fields. Several ONNX names share the
# same (name, kind) pair, e.g. the "current_state_*" and "current_*" aliases
# and the "mimic_ref_*" / "mimic_future_*" variants.
ONNX_INPUT_MAPPING = {
    # Current state (current_state.* context paths)
    # NOTE(review): kind "local_root_ang_vel" is not in _build_policy_input's
    # simulator-kind set nor its element_names branches (only "root_ang_vel"
    # is) — confirm this is intended.
    "current_state_dof_pos": ("joint_pos", "joint_pos"),
    "current_state_dof_vel": ("joint_vel", "joint_vel"),
    "current_state_root_ang_vel": ("root_ang_vel", "root_ang_vel"),
    "current_state_root_local_ang_vel": ("root_ang_vel", "local_root_ang_vel"),
    "current_state_root_rot": ("root_body_rot", "root_body_rot"),
    "current_state_anchor_rot": ("anchor_rot", "anchor_rot"),
    "current_state_rigid_body_pos": ("body_pos", "body_pos"),
    "current_state_rigid_body_rot": ("body_rot", "body_rot"),
    # Current state (current.* context paths)
    "current_dof_pos": ("joint_pos", "joint_pos"),
    "current_dof_vel": ("joint_vel", "joint_vel"),
    "current_root_ang_vel": ("root_ang_vel", "root_ang_vel"),
    "current_root_local_ang_vel": ("root_ang_vel", "local_root_ang_vel"),
    "current_root_rot": ("root_body_rot", "root_body_rot"),
    "current_anchor_rot": ("anchor_rot", "anchor_rot"),
    "current_rigid_body_pos": ("body_pos", "body_pos"),
    "current_rigid_body_rot": ("body_rot", "body_rot"),
    # Historical state. All action variants use kind "last_actions";
    # _build_policy_input sets output_key "robot_action" for the
    # *_processed_actions names and "actions" for the raw ones.
    "historical_actions": ("last_actions", "last_actions"),
    "historical_processed_actions": ("processed_actions_history", "last_actions"),
    "historical_dof_pos": ("joint_pos_history", "joint_pos"),
    "historical_dof_vel": ("joint_vel_history", "joint_vel"),
    "historical_root_ang_vel": ("root_ang_vel_history", "root_ang_vel"),
    "historical_root_local_ang_vel": ("root_ang_vel_history", "local_root_ang_vel"),
    "historical_root_rot": ("root_body_rot_history", "root_body_rot"),
    "historical_anchor_rot": ("anchor_rot_history", "anchor_rot"),
    "historical_rigid_body_pos": ("body_pos_history", "body_pos"),
    "historical_rigid_body_rot": ("body_rot_history", "body_rot"),
    "historical_rigid_body_vel": ("body_vel_history", "body_vel"),
    "historical_rigid_body_ang_vel": ("body_ang_vel_history", "body_ang_vel"),
    "historical_ground_heights": ("ground_heights_history", "ground_heights"),
    "previous_actions": ("previous_actions", "last_actions"),
    "previous_processed_actions": ("previous_processed_actions", "last_actions"),
    # Reference motion (mimic.ref_* context paths). Kinds prefixed
    # "reference_motion_" get `future_steps` instead of history fields.
    # The anchor variants reuse kind "reference_motion_body_rot"; the builder
    # gives them the anchor body as their only element name.
    "mimic_ref_ang_vel": (
        "reference_motion_body_ang_vel",
        "reference_motion_body_ang_vel",
    ),
    "mimic_ref_dof_pos": ("reference_motion_joint_pos", "reference_motion_joint_pos"),
    "mimic_ref_dof_vel": ("reference_motion_joint_vel", "reference_motion_joint_vel"),
    "mimic_ref_rot": ("reference_motion_body_rot", "reference_motion_body_rot"),
    "mimic_ref_anchor_rot": (
        "reference_motion_anchor_rot",
        "reference_motion_body_rot",
    ),
    # Reference motion (mimic.future_* context paths)
    "mimic_future_ang_vel": (
        "reference_motion_body_ang_vel",
        "reference_motion_body_ang_vel",
    ),
    "mimic_future_dof_pos": (
        "reference_motion_joint_pos",
        "reference_motion_joint_pos",
    ),
    "mimic_future_dof_vel": (
        "reference_motion_joint_vel",
        "reference_motion_joint_vel",
    ),
    "mimic_future_rot": ("reference_motion_body_rot", "reference_motion_body_rot"),
    "mimic_future_anchor_rot": (
        "reference_motion_anchor_rot",
        "reference_motion_body_rot",
    ),
    # MaskedMimic sparse conditioning
    "masked_mimic_ref_pos": ("masked_mimic_ref_pos", "masked_mimic_body_pos"),
    "masked_mimic_ref_rot": ("masked_mimic_ref_rot", "masked_mimic_body_rot"),
    "masked_mimic_target_bodies_masks": (
        "masked_mimic_body_masks",
        "masked_mimic_masks",
    ),
    "masked_mimic_target_poses_masks": (
        "masked_mimic_pose_masks",
        "masked_mimic_masks",
    ),
    "masked_mimic_time_offsets": ("masked_mimic_time_offsets", "masked_mimic_time"),
    # VAE
    "vae_noise": ("vae_noise", "vae_noise"),
}
# Mapping from ONNX output names to (name, kind) tuples for isaac-deploy YAML.
# Each output uses its own ONNX name as both the YAML name and the kind.
ONNX_OUTPUT_MAPPING = {
    output_name: (output_name, output_name)
    for output_name in (
        "actions",
        "joint_pos_targets",
        "stiffness_targets",
        "damping_targets",
    )
}
def _build_policy_input(
    onnx_name: str,
    input_shapes: Dict[str, list],
    joint_names: list,
    body_names: list,
    anchor_body: str = "pelvis",
) -> Optional[Dict[str, Any]]:
    """Translate one ONNX input name into an isaac-deploy ``policy_inputs`` entry.

    Args:
        onnx_name: ONNX graph input name, looked up in ``ONNX_INPUT_MAPPING``.
        input_shapes: ONNX input name -> traced shape; a missing name falls
            back to ``[1, 1]``. The leading dimension is always reported as 1.
        joint_names: Element names for joint-space inputs.
        body_names: Element names for body-rotation reference inputs.
        anchor_body: Element name for the anchor-rotation reference inputs.

    Returns:
        The entry dict (key order matters for the dumped YAML), or ``None``
        when ``onnx_name`` has no mapping.
    """
    spec = ONNX_INPUT_MAPPING.get(onnx_name)
    if spec is None:
        return None
    name, kind = spec

    shape = [1, *list(input_shapes.get(onnx_name, [1, 1]))[1:]]
    # On rank>=3 tensors a second axis longer than 1 is a time axis: history
    # for state inputs, future steps for reference-motion inputs.
    has_time_axis = len(shape) >= 3 and shape[1] > 1
    history = shape[1] if has_time_axis else 0
    is_reference_motion = kind.startswith("reference_motion_")

    # last_actions always include the current value; raw simulator signals
    # only do so when they carry no history axis.
    include_current = kind == "last_actions" or (
        history == 0
        and kind
        in {"joint_pos", "joint_vel", "root_body_rot", "root_ang_vel", "anchor_rot"}
    )

    entry = {"name": name, "kind": kind, "shape": shape, "key": onnx_name}
    if is_reference_motion:
        if has_time_axis:
            entry["future_steps"] = shape[1]
    else:
        entry["history"] = history
        entry["include_current_value_in_history"] = include_current

    # Element ordering isaac-deploy should assume for the tensor's entries.
    xyzw = list("xyzw")
    if kind in (
        "joint_pos",
        "joint_vel",
        "reference_motion_joint_pos",
        "reference_motion_joint_vel",
    ):
        entry["element_names"] = [joint_names]
    elif kind in ("root_rot", "anchor_rot", "root_body_rot"):
        entry["element_names"] = [xyzw]
    elif kind == "root_ang_vel":
        entry["element_names"] = [["x", "y", "z"]]
    elif kind == "reference_motion_body_rot":
        is_anchor = onnx_name in ("mimic_ref_anchor_rot", "mimic_future_anchor_rot")
        entry["element_names"] = [[anchor_body] if is_anchor else body_names, xyzw]

    if kind == "last_actions":
        # Processed variants read the commanded positions (after
        # interpolation); raw variants read the policy output directly.
        reads_processed = onnx_name in (
            "historical_processed_actions",
            "previous_processed_actions",
        )
        entry["output_key"] = "robot_action" if reads_processed else "actions"
    return entry
def _build_policy_output(
    onnx_name: str,
    output_shapes: Dict[str, list],
    joint_names: list,
    stiffness: list,
    damping: list,
    use_onnx_for_gains: bool = True,
) -> Optional[Dict[str, Any]]:
    """Translate one ONNX output name into an isaac-deploy ``policy_outputs`` entry.

    Args:
        onnx_name: ONNX graph output name, looked up in ``ONNX_OUTPUT_MAPPING``.
        output_shapes: ONNX output name -> traced shape; a missing name falls
            back to ``[1, len(joint_names)]``. The leading dimension is always
            reported as 1.
        joint_names: List of joint names.
        stiffness: Stiffness values, written only if use_onnx_for_gains=False.
        damping: Damping values, written only if use_onnx_for_gains=False.
        use_onnx_for_gains: If True, stiffness/damping are read from the ONNX
            outputs; if False, the entry's key is None and the constant values
            are embedded in the YAML instead.

    Returns:
        The entry dict, or ``None`` when ``onnx_name`` has no mapping.
    """
    spec = ONNX_OUTPUT_MAPPING.get(onnx_name)
    if spec is None:
        return None
    name, kind = spec

    raw_shape = output_shapes.get(onnx_name, [1, len(joint_names)])
    entry = {
        "name": name,
        "kind": kind,
        "key": onnx_name,  # The ONNX output name is the default key.
        "shape": [1, *list(raw_shape)[1:]],
    }

    gain_sources = {
        "stiffness_targets": ("stiffness", stiffness),
        "damping_targets": ("damping", damping),
    }
    is_gain = kind in gain_sources
    # Joint-indexed outputs list their joints; the raw "actions" term does not.
    if is_gain or "joint" in kind:
        entry["joint_names"] = joint_names
    if is_gain and not use_onnx_for_gains:
        gain_field, gain_values = gain_sources[kind]
        entry["key"] = None  # Use the YAML constants rather than ONNX.
        entry[gain_field] = gain_values
    return entry
def _generate_yaml_content(
    input_shapes: Dict[str, list],
    output_shapes: Dict[str, list],
    onnx_in_names: list,
    onnx_out_names: list,
    joint_names: list,
    body_names: list,
    stiffness: list,
    damping: list,
    anchor_body: str = "pelvis",
    dt: Optional[float] = None,
) -> Dict[str, Any]:
    """Assemble the top-level isaac-deploy YAML mapping.

    ONNX inputs/outputs for which the entry builders return ``None`` (no
    mapping) are left out. Key order is meaningful for the dumped YAML:
    ``type``, then ``dt`` (only when given), then robot metadata, then the
    policy inputs and outputs.
    """
    candidate_inputs = (
        _build_policy_input(
            in_name, input_shapes, joint_names, body_names, anchor_body
        )
        for in_name in onnx_in_names
    )
    candidate_outputs = (
        _build_policy_output(
            out_name, output_shapes, joint_names, stiffness, damping
        )
        for out_name in onnx_out_names
    )

    content: Dict[str, Any] = {"type": "unified_pipeline"}
    if dt is not None:
        content["dt"] = dt
    content["joint_names"] = joint_names
    content["body_names"] = body_names
    content["default_joint_stiffness"] = stiffness
    content["default_joint_damping"] = damping
    content["policy_inputs"] = [entry for entry in candidate_inputs if entry]
    content["policy_outputs"] = [entry for entry in candidate_outputs if entry]
    return content
[docs]
def _sanitize_onnx_name(name: str) -> str:
    """Return ``name`` with ".", "[", "]" and ":" replaced by "_" (ONNX-safe)."""
    for char in ".[]:":
        name = name.replace(char, "_")
    return name


def _match_onnx_input_names(
    onnx_names: list, semantic_keys: list, log: Any
) -> Dict[str, str]:
    """Map ONNX graph input names back to semantic (context/passthrough) keys.

    A name is matched against the sanitized semantic keys, first as-is and
    then with one trailing ".1"-".3" / "_1"-"_3" suffix removed. Unmatched
    names are logged and omitted from the result.
    """
    sanitized_to_semantic = {_sanitize_onnx_name(k): k for k in semantic_keys}
    name_to_key = {}
    for onnx_name in onnx_names:
        base_name = onnx_name
        if base_name not in sanitized_to_semantic:
            for suffix in (".1", ".2", ".3", "_1", "_2", "_3"):
                if base_name.endswith(suffix):
                    base_name = base_name[: -len(suffix)]
                    break
        if base_name in sanitized_to_semantic:
            name_to_key[onnx_name] = sanitized_to_semantic[base_name]
        else:
            log.warning(f"Could not match ONNX input '{onnx_name}' to any semantic key")
    return name_to_key


def _validate_unified_export(
    session: Any,
    onnx_in_names: list,
    onnx_out_names: list,
    onnx_name_to_in_key: Dict[str, str],
    input_key_to_value: Dict[str, torch.Tensor],
    reference_outputs: Dict[str, torch.Tensor],
    log: Any,
    atol: float = 1e-4,
) -> None:
    """Run the ONNX session on the sample inputs and log max |ONNX - PyTorch|.

    Outputs are compared by name. Differences above ``atol`` are logged as
    warnings; errors raised by onnxruntime propagate to the caller.
    """
    import numpy as np

    onnx_inputs = {
        onnx_name: input_key_to_value[onnx_name_to_in_key[onnx_name]]
        .detach()
        .cpu()
        .numpy()
        for onnx_name in onnx_in_names
        if onnx_name in onnx_name_to_in_key
    }
    onnx_outputs = session.run(onnx_out_names, onnx_inputs)
    for onnx_name, onnx_out in zip(onnx_out_names, onnx_outputs):
        reference = reference_outputs.get(onnx_name)
        if reference is None:
            log.warning(f"  No PyTorch reference for ONNX output '{onnx_name}'")
            continue
        diff = np.abs(onnx_out - reference.detach().cpu().numpy()).max()
        log.info(f"  {onnx_name}: max_diff = {diff:.6e}")
        if diff > atol:
            log.warning(f"  ⚠ Large difference detected for {onnx_name}")


def export_unified_pipeline(
    observation_configs: Dict[str, Any],
    action_config: Dict[str, Any],
    sample_context: Dict[str, Any],
    policy_module: torch.nn.Module,
    policy_in_keys: list,
    policy_action_key: str,
    path: str,
    device: torch.device,
    robot_config: Any,
    passthrough_obs: Optional[Dict[str, torch.Tensor]] = None,
    validate: bool = True,
    meta: Optional[Dict[str, Any]] = None,
    dt: Optional[float] = None,
    opset_version: int = 17,
) -> str:
    """Export the complete pipeline (context -> actions) as a single ONNX model.

    Chains observation processing, policy, and action processing into a single
    ONNX model and writes a YAML configuration for isaac-deploy next to it
    (same stem, ``.yaml`` suffix). The policy module is moved to CPU in place.

    Args:
        observation_configs: Dict of MdpComponent instances for observations.
        action_config: Single action config dict {"fn": action_fn, ...params...}.
        sample_context: Sample context for tracing; observation input keys are
            resolved on it as dotted attribute paths.
        policy_module: Policy network module.
        policy_in_keys: Keys required by policy.
        policy_action_key: Key for policy output.
        path: Output ONNX file path (must end with .onnx).
        device: Device for computation.
        robot_config: Robot configuration.
        passthrough_obs: Direct passthrough observations.
        validate: Whether to validate with onnxruntime.
        meta: Additional metadata.
        dt: Policy control period in seconds (decimation / fps).
        opset_version: ONNX opset version to use (default: 17).

    Returns:
        ``path``.

    Raises:
        ValueError: If ``path`` does not end with .onnx, if the policy needs
            inputs not provided by observations or passthrough, or if an
            observation context input is not a tensor.
        ImportError: If onnxruntime is not installed; it is required to read
            back the exported graph's input/output names.
    """
    import logging

    import yaml

    log = logging.getLogger(__name__)

    if not path.endswith(".onnx"):
        raise ValueError(f"Export path must end with .onnx, got {path}.")
    # onnxruntime is needed after export to read the graph's real names, so
    # fail before doing any expensive work if it is missing.
    try:
        import onnxruntime as ort
    except ImportError as err:
        raise ImportError(
            "export_unified_pipeline requires onnxruntime to resolve the "
            "exported ONNX input/output names."
        ) from err

    log.info("=" * 60)
    log.info("Exporting Unified Pipeline (Context -> Actions)")
    log.info("=" * 60)

    passthrough_obs = passthrough_obs or {}

    # Extract robot metadata.
    joint_names = robot_config.kinematic_info.dof_names
    body_names = robot_config.kinematic_info.body_names
    control_info = robot_config.control.control_info
    stiffness = [float(control_info[j].stiffness) for j in joint_names]
    damping = [float(control_info[j].damping) for j in joint_names]
    # Anchor body defaults to the first body when not configured.
    anchor_body = getattr(robot_config, "anchor_body_name", None) or body_names[0]
    log.info(f"Joint names: {joint_names}")
    log.info(f"Body names: {body_names}")
    log.info(f"Anchor body: {anchor_body}")

    # Step 1: Create observation module.
    obs_module = ObservationExportModule(observation_configs, sample_context, device)
    obs_module.eval()
    obs_input_keys = obs_module.get_input_keys()
    obs_output_keys = obs_module.get_output_keys()
    passthrough_keys = list(passthrough_obs)
    log.info(f"Context input keys (for obs): {obs_input_keys}")
    log.info(f"Observation output keys: {obs_output_keys}")
    log.info(f"Passthrough keys (direct to policy): {passthrough_keys}")
    log.info(f"Policy input keys: {policy_in_keys}")

    # Check coverage: obs outputs + passthrough should cover all policy inputs.
    missing = set(policy_in_keys) - (set(obs_output_keys) | set(passthrough_keys))
    if missing:
        raise ValueError(
            f"Policy requires inputs not available: {missing}. "
            f"Available from obs: {obs_output_keys}, passthrough: {passthrough_keys}"
        )

    # Step 2: Create action processing module.
    action_module = ActionExportModule(action_config, device)
    action_module.cpu()
    action_module.eval()

    # Step 3: Create unified module with passthrough.
    unified_module = UnifiedPipelineModule(
        observation_module=obs_module,
        policy_module=policy_module.cpu(),
        action_module=action_module,
        policy_in_keys=policy_in_keys,
        policy_action_key=policy_action_key,
        passthrough_keys=passthrough_keys,
    )
    unified_module.cpu()
    unified_module.eval()

    # Positional input order of UnifiedPipelineModule.forward: observation
    # context keys first, then passthrough keys.
    all_input_keys = list(obs_input_keys) + passthrough_keys

    # Build CPU sample inputs; shapes are stored under the sanitized ONNX name
    # so _build_policy_input can find them.
    sample_inputs = []
    input_shapes = {}
    for key in obs_input_keys:
        value = _resolve_context_path(key, sample_context)
        if not isinstance(value, torch.Tensor):
            raise ValueError(f"Input '{key}' is not a tensor: {type(value)}")
        sample_inputs.append(value.cpu())
        input_shapes[_sanitize_onnx_name(key)] = list(value.shape)
    for key in passthrough_keys:
        value = passthrough_obs[key]
        sample_inputs.append(value.cpu())
        input_shapes[_sanitize_onnx_name(key)] = list(value.shape)

    # Run forward pass to get reference outputs and their shapes.
    with torch.no_grad():
        actions, joint_pos_targets, stiffness_targets, damping_targets = unified_module(
            *sample_inputs
        )
    log.info(f"Actions shape: {list(actions.shape)}")
    log.info(f"Joint pos targets shape: {list(joint_pos_targets.shape)}")
    log.info(f"Stiffness targets shape: {list(stiffness_targets.shape)}")
    log.info(f"Damping targets shape: {list(damping_targets.shape)}")

    onnx_input_names = [_sanitize_onnx_name(k) for k in all_input_keys]
    onnx_output_names = [
        "actions",
        "joint_pos_targets",
        "stiffness_targets",
        "damping_targets",
    ]
    reference_outputs = dict(
        zip(
            onnx_output_names,
            (actions, joint_pos_targets, stiffness_targets, damping_targets),
        )
    )

    # Export to ONNX.
    log.info(f"Exporting unified pipeline to {path}...")
    torch.onnx.export(
        unified_module,
        tuple(sample_inputs),
        path,
        input_names=onnx_input_names,
        output_names=onnx_output_names,
        opset_version=opset_version,
        do_constant_folding=True,
        dynamic_axes={
            **{name: {0: "batch_size"} for name in onnx_input_names},
            **{name: {0: "batch_size"} for name in onnx_output_names},
        },
        dynamo=False,
    )

    # Read back the names actually present in the exported graph; they may
    # not line up one-to-one with the requested names.
    session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    actual_onnx_in_names = [inp.name for inp in session.get_inputs()]
    actual_onnx_out_names = [out.name for out in session.get_outputs()]
    onnx_name_to_in_key = _match_onnx_input_names(
        actual_onnx_in_names, all_input_keys, log
    )

    # Generate YAML content.
    output_shapes = {
        name: list(tensor.shape) for name, tensor in reference_outputs.items()
    }
    yaml_content = _generate_yaml_content(
        input_shapes=input_shapes,
        output_shapes=output_shapes,
        onnx_in_names=actual_onnx_in_names,
        onnx_out_names=actual_onnx_out_names,
        joint_names=joint_names,
        body_names=body_names,
        stiffness=stiffness,
        damping=damping,
        anchor_body=anchor_body,
        dt=dt,
    )
    # Runtime metadata for visualization/testing (not used by isaac-deploy).
    yaml_content["_runtime"] = {
        "onnx_in_names": actual_onnx_in_names,
        "onnx_out_names": actual_onnx_out_names,
        "onnx_name_to_in_key": onnx_name_to_in_key,
        "passthrough_keys": passthrough_keys,
        "obs_context_keys": obs_input_keys,
    }
    if meta:
        yaml_content["metadata"] = meta

    # Save YAML. with_suffix only swaps the final ".onnx" suffix, so a
    # directory name containing ".onnx" is left untouched.
    yaml_path = str(Path(path).with_suffix(".yaml"))
    with open(yaml_path, "w", encoding="utf-8") as f:
        yaml.dump(yaml_content, f, default_flow_style=None, sort_keys=False)
    log.info(f"✓ Unified pipeline exported to {path}")
    log.info(f"✓ YAML configuration saved to {yaml_path}")

    # Validate with onnxruntime.
    if validate:
        log.info("Validating with onnxruntime...")
        try:
            _validate_unified_export(
                session,
                actual_onnx_in_names,
                actual_onnx_out_names,
                onnx_name_to_in_key,
                dict(zip(all_input_keys, sample_inputs)),
                reference_outputs,
                log,
            )
        except Exception as e:
            log.error(f"Validation failed: {e}")
            raise
        log.info("✓ Validation passed")

    return path
###############################################################################
# Observation Export Utilities
###############################################################################
[docs]
class ObservationExportModule(torch.nn.Module):
    """Module that wraps observation functions for ONNX export.

    This module takes raw context tensors as positional inputs and computes
    observations by calling the configured observation functions. It is
    designed to be exported to ONNX for deployment.

    Works with the MdpComponent-based observation config format::

        observation_components = {
            "max_coords_obs": MdpComponent(
                compute_func=compute_humanoid_max_coords_observations,
                dynamic_vars={"body_pos": EnvContext.current.rigid_body_pos, ...},
                static_params={"local_obs": True},
            ),
        }

    Each binding path is resolved once on ``sample_context``. Bindings that
    resolve to a ``torch.Tensor`` (or cannot be resolved) become positional
    inputs, ordered by ``get_input_keys()``. Bindings that resolve to any
    other value are frozen into the module as constant keyword arguments.

    Args:
        observation_configs: Dict mapping observation name to an MdpComponent
            instance. Other config types are rejected.
        sample_context: Sample context used to resolve the dotted binding
            paths via attribute access (see ``_resolve_context_path``).
        device: Device for tensor operations (stored as ``self.device``).

    Raises:
        TypeError: If a value in ``observation_configs`` is not an MdpComponent.

    Example:
        >>> from protomotions.envs.mdp_component import MdpComponent
        >>> from protomotions.envs.context_views import EnvContext
        >>> configs = {
        ...     "max_coords_obs": MdpComponent(
        ...         compute_func=compute_max_coords,
        ...         dynamic_vars={"body_pos": EnvContext.current.rigid_body_pos},
        ...     )
        ... }
        >>> context = env.context
        >>> module = ObservationExportModule(configs, context, device)
        >>> export_observations(configs, context, "observations.onnx", device)
    """

    def __init__(
        self,
        observation_configs: Dict[str, Any],
        sample_context: Dict[str, Any],
        device: torch.device,
    ):
        super().__init__()
        self.device = device

        # Imported here to avoid a circular dependency.
        from protomotions.envs.mdp_component import is_mdp_component

        self._obs_functions = {}  # obs_name -> compute function
        self._obs_input_mappings = {}  # obs_name -> {arg_name: context_path}
        self._obs_constants = {}  # obs_name -> {arg_name: value}
        self._resolved_constants = {}  # obs_name -> {arg_name: value} promoted from bindings
        self._output_keys = []  # observation names, in ONNX output order

        for obs_name, cfg in observation_configs.items():
            if not is_mdp_component(cfg):
                raise TypeError(
                    f"Observation config '{obs_name}' must be a MdpComponent, "
                    f"got {type(cfg).__name__}"
                )
            self._obs_functions[obs_name] = cfg.get_compute_func()
            self._output_keys.append(obs_name)

            # Copy so that promoting bindings to constants below never mutates
            # the component's own params dict.
            constants = dict(cfg.get_params() or {})
            input_mapping = {}
            resolved = {}
            for arg_name, context_path in cfg.get_bindings_dict().items():
                try:
                    value = _resolve_context_path(context_path, sample_context)
                except (AttributeError, NameError, KeyError, TypeError):
                    # Not resolvable on the sample context: keep it as an input.
                    input_mapping[arg_name] = context_path
                    continue
                if isinstance(value, torch.Tensor):
                    input_mapping[arg_name] = context_path
                else:
                    # Non-tensor values (e.g. lookup tables) are baked in as constants.
                    constants[arg_name] = value
                    resolved[arg_name] = value

            self._obs_input_mappings[obs_name] = input_mapping
            self._obs_constants[obs_name] = constants
            self._resolved_constants[obs_name] = resolved

        # Sorted unique context paths give a deterministic ONNX input order.
        self._input_keys = sorted(
            {
                context_path
                for mapping in self._obs_input_mappings.values()
                for context_path in mapping.values()
            }
        )

    def get_input_keys(self) -> list:
        """Get ordered list of context paths expected as positional inputs."""
        return list(self._input_keys)

    def get_output_keys(self) -> list:
        """Get ordered list of output observation names."""
        return list(self._output_keys)

    def forward(self, *args) -> tuple:
        """Compute all observations from input tensors.

        Args:
            *args: Input tensors in the order of get_input_keys().

        Returns:
            Tuple of observation tensors in the order of get_output_keys().

        Raises:
            ValueError: If the number of inputs does not match get_input_keys().
        """
        if len(args) != len(self._input_keys):
            raise ValueError(
                f"Expected {len(self._input_keys)} input tensors "
                f"({self._input_keys}), got {len(args)}"
            )
        context = dict(zip(self._input_keys, args))

        outputs = []
        for obs_name in self._output_keys:
            # Tensor inputs first; constants override on name collisions.
            func_kwargs = {
                arg_name: context[context_path]
                for arg_name, context_path in self._obs_input_mappings[obs_name].items()
            }
            func_kwargs.update(self._obs_constants[obs_name])
            outputs.append(self._obs_functions[obs_name](**func_kwargs))
        return tuple(outputs)
[docs]
def export_observations(
    observation_configs: Dict[str, Any],
    sample_context: Dict[str, Any],
    path: str,
    device: torch.device,
    validate: bool = True,
    meta: Optional[Dict[str, Any]] = None,
) -> str:
    """Export observation computation to ONNX format.

    Creates an ObservationExportModule from the observation configs and exports
    it to ONNX. The exported model takes raw context tensors as inputs and
    produces observation tensors as outputs. A JSON sidecar (same stem as
    ``path``, ``.json`` suffix) records the input/output keys, the ONNX names,
    the ONNX-name -> context-path mapping and the sample shapes.

    Args:
        observation_configs: Dict of observation component configurations.
        sample_context: Sample context for tracing and shape inference.
        path: Path to save the ONNX model.
        device: Device for tensor operations.
        validate: If True and onnxruntime is installed, run the exported model
            and compare against PyTorch; differences above 1e-4 are logged as
            warnings.
        meta: Optional metadata merged into the JSON sidecar.

    Returns:
        Path to the exported ONNX model (``path`` as given).

    Raises:
        ValueError: If a required context input does not resolve to a tensor.
        Exception: Errors raised during onnxruntime validation are logged and
            re-raised.

    Example:
        >>> configs = env.config.observation_components
        >>> context = env.context
        >>> export_observations(configs, context, "observations.onnx", device)
    """
    import logging

    log = logging.getLogger(__name__)

    # Create the export module; it fixes the positional input/output order.
    module = ObservationExportModule(observation_configs, sample_context, device)
    module.eval()
    input_keys = module.get_input_keys()
    output_keys = module.get_output_keys()
    log.info(f"Observation export - Input keys: {input_keys}")
    log.info(f"Observation export - Output keys: {output_keys}")

    # Build sample inputs from the context, in input_keys order.
    sample_inputs = []
    input_shapes = {}
    for key in input_keys:
        value = _resolve_context_path(key, sample_context)
        if not isinstance(value, torch.Tensor):
            raise ValueError(f"Input '{key}' is not a tensor: {type(value)}")
        sample_inputs.append(value)
        input_shapes[key] = list(value.shape)

    # Run once in PyTorch: output shapes for metadata, reference values for validation.
    with torch.no_grad():
        sample_outputs = module(*sample_inputs)
    output_shapes = {
        key: list(out.shape) for key, out in zip(output_keys, sample_outputs)
    }
    log.info(f"Observation export - Input shapes: {input_shapes}")
    log.info(f"Observation export - Output shapes: {output_shapes}")

    # ONNX names must not contain the separators used in context paths.
    name_table = str.maketrans(".[]:", "____")
    onnx_input_names = [k.translate(name_table) for k in input_keys]
    onnx_output_names = [k.translate(name_table) for k in output_keys]

    log.info(f"Exporting observations to {path}...")
    torch.onnx.export(
        module,
        tuple(sample_inputs),
        path,
        input_names=onnx_input_names,
        output_names=onnx_output_names,
        opset_version=17,
        do_constant_folding=True,
        dynamic_axes={
            **{name: {0: "batch_size"} for name in onnx_input_names},
            **{name: {0: "batch_size"} for name in onnx_output_names},
        },
        dynamo=False,
    )

    # Read back the ACTUAL input/output names, because the exporter may rename
    # inputs. onnxruntime is optional: without it, record the requested names.
    try:
        import onnxruntime as ort
    except ImportError:
        ort = None
    if ort is not None:
        session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
        actual_onnx_in_names = [inp.name for inp in session.get_inputs()]
        actual_onnx_out_names = [out.name for out in session.get_outputs()]
    else:
        session = None
        log.warning(
            "onnxruntime not installed; recording requested ONNX names without verification"
        )
        actual_onnx_in_names = list(onnx_input_names)
        actual_onnx_out_names = list(onnx_output_names)
    log.info(f"Requested ONNX input names: {onnx_input_names}")
    log.info(f"Actual ONNX input names: {actual_onnx_in_names}")
    if len(actual_onnx_in_names) != len(onnx_input_names):
        log.warning(
            f"ONNX has {len(actual_onnx_in_names)} inputs but we expected {len(onnx_input_names)}!"
        )

    # Map each actual ONNX input name back to its context path. Try the exact
    # name first, then the name with an exporter-added ".N" or "_N" suffix
    # stripped. If two keys sanitize to the same name, the first key wins.
    expected_name_to_key = {}
    for expected_name, key in zip(onnx_input_names, input_keys):
        expected_name_to_key.setdefault(expected_name, key)
    onnx_name_to_in_key = {}
    for onnx_name in actual_onnx_in_names:
        base_name = onnx_name.rsplit(".", 1)[0]
        head, sep, tail = base_name.rpartition("_")
        base_name_alt = head if sep and tail.isdigit() else base_name
        for candidate in (onnx_name, base_name, base_name_alt):
            if candidate in expected_name_to_key:
                onnx_name_to_in_key[onnx_name] = expected_name_to_key[candidate]
                break
        else:
            log.warning(f"Could not match ONNX input '{onnx_name}' to any semantic key")
    log.info(f"ONNX name to semantic key mapping: {onnx_name_to_in_key}")

    # Save metadata with the ACTUAL ONNX names next to the model.
    metadata = {
        "type": "observations",
        "in_keys": input_keys,
        "out_keys": output_keys,
        "onnx_in_names": actual_onnx_in_names,
        "onnx_out_names": actual_onnx_out_names,
        "onnx_name_to_in_key": onnx_name_to_in_key,  # reverse mapping for inference
        "input_shapes": input_shapes,
        "output_shapes": output_shapes,
    }
    if meta:
        metadata.update(meta)
    # with_suffix only touches the final suffix, so the sidecar can never
    # overwrite the model (str.replace would if ".onnx" were absent).
    meta_path = str(Path(path).with_suffix(".json"))
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
    log.info(f"✓ Observations exported to {path}")
    log.info(f"✓ Metadata saved to {meta_path}")

    if not validate:
        return path
    if session is None:
        log.warning("onnxruntime not installed, skipping validation")
        return path

    try:
        import numpy as np

        log.info("Validating with onnxruntime...")
        input_key_to_value = dict(zip(input_keys, sample_inputs))
        onnx_inputs = {}
        for onnx_name in actual_onnx_in_names:
            semantic_key = onnx_name_to_in_key.get(onnx_name)
            if semantic_key is None:
                log.warning(f"No value for ONNX input '{onnx_name}'")
                continue
            onnx_inputs[onnx_name] = (
                input_key_to_value[semantic_key].detach().cpu().numpy()
            )
        onnx_outputs = session.run(actual_onnx_out_names, onnx_inputs)

        # Compare with PyTorch outputs (float64 avoids int overflow / bool subtraction).
        large_diff_keys = []
        for key, onnx_out, torch_out in zip(output_keys, onnx_outputs, sample_outputs):
            reference = torch_out.detach().cpu().numpy().astype(np.float64)
            diff = np.abs(np.asarray(onnx_out, dtype=np.float64) - reference)
            max_diff = diff.max(initial=0.0)  # initial handles empty outputs
            log.info(f" {key}: max_diff = {max_diff:.6e}")
            if max_diff > 1e-4:
                log.warning(f" ⚠ Large difference detected for {key}")
                large_diff_keys.append(key)
        if large_diff_keys:
            log.warning(f"Validation finished with large differences for: {large_diff_keys}")
        else:
            log.info("✓ Validation passed")
    except ImportError:
        log.warning("numpy not installed, skipping validation")
    except Exception as e:
        log.error(f"Validation failed: {e}")
        raise
    return path