# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utilities for exporting trained models to ONNX format.
This module provides functions to export TensorDict-based models to ONNX format
using torch.onnx.export. The exported models can be used for deployment
and inference in production environments.
Key Functions:
- export_onnx: Export a TensorDictModule to ONNX format
- export_ppo_model: Export a trained PPO model to ONNX
- export_with_action_processing: Export model with baked-in action processing
"""
import torch
import json
from pathlib import Path
from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase
from typing import Optional, Dict, Any, Union
from protomotions.robot_configs.base import ControlType
class ONNXExportWrapper(torch.nn.Module):
    """Adapter exposing a TensorDictModule through positional tensor inputs.

    ``torch.onnx.export`` traces models using positional tensor arguments,
    while TensorDict-based modules consume a single TensorDict. This wrapper
    rebuilds the TensorDict from the positional arguments on every call.
    """

    def __init__(self, module: TensorDictModuleBase, in_keys: list, batch_size: int):
        super().__init__()
        self.module = module
        self.in_keys = in_keys
        self._batch_size = batch_size

    def forward(self, *args):
        """Rebuild a TensorDict from positional args and run the wrapped module."""
        # The batch size captured at construction time is used because
        # args[0].shape[0] is unreliable during JIT tracing.
        named_inputs = dict(zip(self.in_keys, args))
        td = TensorDict(named_inputs, batch_size=[self._batch_size])
        result_td = self.module(td)
        return tuple(result_td[key] for key in self.module.out_keys)
class ActionProcessingModule(torch.nn.Module):
    """Standalone action post-processing module for ONNX export.

    Takes raw policy actions and applies clamping plus the PD/torque
    transformation, so the processing can be exported separately from the
    policy model.

    For PD control modes (BUILT_IN_PD, PROPORTIONAL):
        output = pd_action_offset + pd_action_scale * clamp(action)
    For TORQUE control mode:
        output = clamp(action) * torque_limits

    Args:
        control_type: Control type enum (BUILT_IN_PD, PROPORTIONAL, or TORQUE).
        clamp_value: Value to clamp actions to [-clamp_value, clamp_value].
        pd_action_offset: Offset for PD target computation (for PD modes).
        pd_action_scale: Scale for PD target computation (for PD modes).
        torque_limits: Torque limits for TORQUE mode (required for TORQUE mode).
        stiffness: Per-DOF stiffness values for PD control (optional).
        damping: Per-DOF damping values for PD control (optional).
    """

    def __init__(
        self,
        control_type: ControlType,
        clamp_value: float,
        pd_action_offset: Optional[torch.Tensor] = None,
        pd_action_scale: Optional[torch.Tensor] = None,
        torque_limits: Optional[torch.Tensor] = None,
        stiffness: Optional[torch.Tensor] = None,
        damping: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        # Resolve the control mode to a plain bool at construction time so
        # no enum logic ends up in the traced graph.
        pd_modes = (ControlType.BUILT_IN_PD, ControlType.PROPORTIONAL)
        self._is_pd_control = control_type in pd_modes
        # Pick the buffer dtype from whichever parameter tensor is available.
        if pd_action_offset is not None:
            dtype = pd_action_offset.dtype
        elif torque_limits is not None:
            dtype = torque_limits.dtype
        else:
            dtype = torch.float32
        self.register_buffer("clamp_value", torch.tensor(clamp_value, dtype=dtype))
        if self._is_pd_control:
            if pd_action_offset is None or pd_action_scale is None:
                raise ValueError(
                    "pd_action_offset and pd_action_scale required for PD control modes"
                )
            self.register_buffer("pd_action_offset", pd_action_offset.clone())
            self.register_buffer("pd_action_scale", pd_action_scale.clone())
            self.output_type = "pd_targets"
        elif control_type == ControlType.TORQUE:
            if torque_limits is None:
                raise ValueError("torque_limits required for TORQUE control mode")
            self.register_buffer("torque_limits", torque_limits.clone())
            self.output_type = "torque_targets"
        else:
            raise ValueError(f"Unsupported control type: {control_type}")
        # Optional gains become buffers so they are baked into the exported
        # graph; absent gains fall back to zeros in forward().
        if stiffness is None:
            self.stiffness = None
        else:
            self.register_buffer("stiffness", stiffness.clone())
        if damping is None:
            self.damping = None
        else:
            self.register_buffer("damping", damping.clone())

    def forward(
        self, action: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Process raw action into PD targets or torques.

        Args:
            action: Raw action tensor [batch_size, num_actions]

        Returns:
            Tuple of (actions, targets, stiffness_targets, damping_targets).
            Each tensor has shape [batch_size, num_actions].
        """
        clamped = action.clamp(-self.clamp_value, self.clamp_value)
        if self._is_pd_control:
            targets = self.pd_action_offset + self.pd_action_scale * clamped
        else:
            targets = clamped * self.torque_limits
        # Broadcast the constant gain vectors across the batch dimension.
        num_envs = action.shape[0]
        if self.stiffness is None:
            stiffness_targets = torch.zeros_like(action)
        else:
            stiffness_targets = self.stiffness.unsqueeze(0).expand(num_envs, -1)
        if self.damping is None:
            damping_targets = torch.zeros_like(action)
        else:
            damping_targets = self.damping.unsqueeze(0).expand(num_envs, -1)
        return clamped, targets, stiffness_targets, damping_targets
class ActionProcessingONNXWrapper(torch.nn.Module):
    """Wrapper that includes action post-processing in the ONNX export.

    This wrapper combines the policy model with action clamping and
    PD/torque transformations, producing final targets that can be
    directly applied by the simulator without further processing.

    For PD control modes (BUILT_IN_PD, PROPORTIONAL):
        output = pd_action_offset + pd_action_scale * clamp(action)
    For TORQUE control mode:
        output = clamp(action) * torque_limits

    Args:
        module: TensorDictModule (policy) to wrap.
        in_keys: List of input keys for the module.
        batch_size: Batch size for TensorDict reconstruction.
        action_key: Key in module output containing the action (default: "action").
        control_type: Control type enum (BUILT_IN_PD, PROPORTIONAL, or TORQUE).
        clamp_value: Value to clamp actions to [-clamp_value, clamp_value].
        pd_action_offset: Offset for PD target computation (for PD modes).
        pd_action_scale: Scale for PD target computation (for PD modes).
        torque_limits: Torque limits for TORQUE mode scaling.

    Attributes:
        out_keys: Output keys for the wrapper (always ["pd_targets"] or ["torque_targets"]).
    """

    def __init__(
        self,
        module: TensorDictModuleBase,
        in_keys: list,
        batch_size: int,
        action_key: str,
        control_type: ControlType,
        clamp_value: float,
        pd_action_offset: Optional[torch.Tensor] = None,
        pd_action_scale: Optional[torch.Tensor] = None,
        torque_limits: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.module = module
        self.in_keys = in_keys
        self._batch_size = batch_size
        self.action_key = action_key
        # Use the standalone ActionProcessingModule
        self.action_processor = ActionProcessingModule(
            control_type=control_type,
            clamp_value=clamp_value,
            pd_action_offset=pd_action_offset,
            pd_action_scale=pd_action_scale,
            torque_limits=torque_limits,
        )
        self.output_type = self.action_processor.output_type
        self.out_keys = [self.output_type]

    def forward(self, *args):
        """Forward pass: model -> action_processor."""
        # Reconstruct TensorDict from positional args. Use the stored
        # batch_size since args[0].shape[0] doesn't work during JIT tracing.
        td = TensorDict(
            {key: tensor for key, tensor in zip(self.in_keys, args)},
            batch_size=[self._batch_size],
        )
        # Run the model
        output_td = self.module(td)
        # Get raw action from model output
        action = output_td[self.action_key]
        # ActionProcessingModule returns a 4-tuple (clamped_action, targets,
        # stiffness_targets, damping_targets); only the processed targets are
        # exported, matching the single entry declared in out_keys.
        # (Previously the whole tuple was returned, so the ONNX graph grew
        # four outputs and the one named pd_targets/torque_targets actually
        # carried the clamped raw action.)
        _, targets, _, _ = self.action_processor(action)
        return (targets,)
def export_action_processing(
    control_type: ControlType,
    clamp_value: float,
    pd_action_offset: torch.Tensor,
    pd_action_scale: torch.Tensor,
    torque_limits: torch.Tensor,
    num_actions: int,
    path: str,
    batch_size: int = 1,
    meta: Optional[Dict[str, Any]] = None,
    validate: bool = True,
    opset_version: int = 17,
):
    """Export action processing module to ONNX.

    Creates a standalone ONNX model that takes raw actions and outputs
    clamped actions, PD targets (or torques), and stiffness/damping targets.

    Args:
        control_type: Control type enum (BUILT_IN_PD, PROPORTIONAL, or TORQUE).
        clamp_value: Value to clamp actions to [-clamp_value, clamp_value].
        pd_action_offset: Offset for PD target computation.
        pd_action_scale: Scale for PD target computation.
        torque_limits: Torque limits for TORQUE mode.
        num_actions: Number of action dimensions.
        path: Path to save the ONNX model (must end with .onnx).
        batch_size: Batch size for sample input (default: 1).
        meta: Optional additional metadata to save.
        validate: If True, validates the exported model with onnxruntime.
        opset_version: ONNX opset version to use (default: 17).

    Returns:
        The ActionProcessingModule instance that was exported.

    Raises:
        ValueError: If path doesn't end with .onnx.
    """
    if not path.endswith(".onnx"):
        raise ValueError(f"Export path must end with .onnx, got {path}.")
    # Move tensors to CPU
    pd_action_offset = pd_action_offset.cpu()
    pd_action_scale = pd_action_scale.cpu()
    torque_limits = torque_limits.cpu()
    # Create the action processing module
    action_processor = ActionProcessingModule(
        control_type=control_type,
        clamp_value=clamp_value,
        pd_action_offset=pd_action_offset,
        pd_action_scale=pd_action_scale,
        torque_limits=torque_limits,
    )
    action_processor.eval()
    action_processor.cpu()
    # Create sample input
    sample_action = torch.randn(batch_size, num_actions, dtype=pd_action_offset.dtype)
    print("Exporting action processing to ONNX...")
    print(f" Control type: {control_type.name}")
    print(f" Clamp value: {clamp_value}")
    print(f" Num actions: {num_actions}")
    print(f" Output type: {action_processor.output_type}")
    # ActionProcessingModule.forward returns FOUR tensors (clamped_action,
    # targets, stiffness_targets, damping_targets), so all four ONNX outputs
    # must be named. Previously only one name was declared, which attached
    # the targets' name to the first output — the clamped raw action.
    output_names = [
        "clamped_action",
        action_processor.output_type,
        "stiffness_targets",
        "damping_targets",
    ]
    torch.onnx.export(
        action_processor,
        (sample_action,),
        path,
        input_names=["raw_action"],
        output_names=output_names,
        dynamic_axes={
            "raw_action": {0: "batch_size"},
            **{name: {0: "batch_size"} for name in output_names},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=False,
    )
    print(f"✓ Exported action processing to {path}")
    # Save metadata. Slice off the validated ".onnx" suffix instead of
    # str.replace, which could touch an earlier occurrence in the path.
    meta_path = path[: -len(".onnx")] + ".json"
    if meta is None:
        meta = {}
    meta["model_type"] = "action_processing"
    meta["control_type"] = control_type.name
    meta["clamp_value"] = clamp_value
    meta["num_actions"] = num_actions
    meta["output_type"] = action_processor.output_type
    meta["input_name"] = "raw_action"
    meta["output_name"] = action_processor.output_type
    meta["output_names"] = output_names
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=4)
    print(f"✓ Exported metadata to {meta_path}")
    # Validate
    if validate:
        try:
            import onnxruntime as ort

            ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
            # Run ONNX
            onnx_outputs = ort_session.run(None, {"raw_action": sample_action.numpy()})
            # Run PyTorch. forward() returns a tuple, so compare output-wise;
            # previously .numpy() was called on the tuple and always failed.
            with torch.no_grad():
                pytorch_outputs = action_processor(sample_action)
            max_diff = max(
                (
                    float(abs(onnx_out - torch_out.numpy()).max())
                    for onnx_out, torch_out in zip(onnx_outputs, pytorch_outputs)
                ),
                default=0.0,
            )
            print(f"✓ Validation: max diff = {max_diff:.2e}")
        except ImportError:
            print("⚠ Warning: onnxruntime not installed, skipping validation.")
        except Exception as e:
            print(f"⚠ Warning: Validation failed: {e}")
    return action_processor
def export_onnx(
    module: TensorDictModuleBase,
    td: TensorDict,
    path: str,
    meta: Optional[Dict[str, Any]] = None,
    validate: bool = True,
    opset_version: int = 17,
):
    """Export a TensorDictModule to ONNX format.

    Uses torch.onnx.export to export the module. Creates a wrapper that
    converts between TensorDict and positional tensor inputs for ONNX compatibility.

    Args:
        module: TensorDictModule to export.
        td: Sample TensorDict input (used for tracing).
        path: Path to save the ONNX model (must end with .onnx).
        meta: Optional additional metadata to save.
        validate: If True, validates the exported model with onnxruntime and
            compares its outputs numerically against the PyTorch module.
        opset_version: ONNX opset version to use (default: 17).

    Raises:
        ValueError: If path doesn't end with .onnx.

    Example:
        >>> from protomotions.agents.ppo.model import PPOModel
        >>> from tensordict import TensorDict
        >>> model = PPOModel(config)
        >>> sample_input = TensorDict({"obs": torch.randn(1, 128)}, batch_size=1)
        >>> export_onnx(model, sample_input, "policy.onnx")
    """
    if not path.endswith(".onnx"):
        raise ValueError(f"Export path must end with .onnx, got {path}.")
    # Move to CPU and select only required input keys
    td = td.cpu().select(*module.in_keys, strict=True)
    module = module.cpu()
    module.eval()
    in_keys = list(module.in_keys)
    out_keys = list(module.out_keys)
    print(f"Exporting model to ONNX (PyTorch {torch.__version__})...")
    print(f" Input keys: {in_keys}")
    print(f" Output keys: {out_keys}")
    # Create wrapper that accepts positional args instead of TensorDict
    batch_size = td.batch_size[0] if td.batch_size else 1
    wrapper = ONNXExportWrapper(module, in_keys, batch_size)
    wrapper.eval()
    # Prepare input tuple for torch.onnx.export
    input_tensors = tuple(td[key] for key in in_keys)
    # Create input/output names for ONNX
    input_names = [f"input_{i}" for i in range(len(in_keys))]
    output_names = [f"output_{i}" for i in range(len(out_keys))]
    torch.onnx.export(
        wrapper,
        input_tensors,
        path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={
            **{name: {0: "batch_size"} for name in input_names},
            **{name: {0: "batch_size"} for name in output_names},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=False,
    )
    print(f"✓ Exported ONNX model to {path}")
    # Save metadata. Slice off the validated ".onnx" suffix instead of
    # str.replace, which could touch an earlier occurrence in the path.
    meta_path = path[: -len(".onnx")] + ".json"
    if meta is None:
        meta = {}
    meta["in_keys"] = in_keys
    meta["out_keys"] = out_keys
    meta["in_shapes"] = [list(td[k].shape) for k in in_keys]
    meta["onnx_input_names"] = input_names
    meta["onnx_output_names"] = output_names
    meta["input_mapping"] = {
        onnx_name: semantic_name
        for onnx_name, semantic_name in zip(input_names, in_keys)
    }
    meta["output_mapping"] = {
        onnx_name: semantic_name
        for onnx_name, semantic_name in zip(output_names, out_keys)
    }
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=4)
    print(f"✓ Exported metadata to {meta_path}")
    # Validate with onnxruntime
    if validate:
        try:
            import onnxruntime as ort

            ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])

            def to_numpy(tensor):
                return (
                    tensor.detach().cpu().numpy()
                    if tensor.requires_grad
                    else tensor.cpu().numpy()
                )

            onnxruntime_input = {
                name: to_numpy(tensor)
                for name, tensor in zip(input_names, input_tensors)
            }
            ort_output = ort_session.run(None, onnxruntime_input)
            assert len(ort_output) == len(out_keys), f"Output length mismatch: {len(ort_output)} vs {len(out_keys)}"
            # Compare ONNX outputs numerically against the PyTorch wrapper
            # (previously only the output count was checked).
            with torch.no_grad():
                torch_output = wrapper(*input_tensors)
            max_diff = max(
                (
                    float(abs(onnx_out - to_numpy(torch_out)).max())
                    for onnx_out, torch_out in zip(ort_output, torch_output)
                ),
                default=0.0,
            )
            print(f"✓ Validation: max diff = {max_diff:.2e}")
            print("✓ ONNX model validation successful!")
        except ImportError:
            print("⚠ Warning: onnxruntime not installed, skipping validation.")
        except Exception as e:
            print(f"⚠ Warning: ONNX validation failed: {e}")
def export_with_action_processing(
    module: TensorDictModuleBase,
    td: TensorDict,
    path: str,
    action_key: str,
    control_type: ControlType,
    clamp_value: float,
    pd_action_offset: Optional[torch.Tensor] = None,
    pd_action_scale: Optional[torch.Tensor] = None,
    torque_limits: Optional[torch.Tensor] = None,
    meta: Optional[Dict[str, Any]] = None,
    validate: bool = True,
    opset_version: int = 17,
):
    """Export a TensorDictModule to ONNX with baked-in action processing.

    This function wraps the model with action clamping and PD/torque
    transformations, producing an ONNX model that outputs final targets
    (PD targets or torques) ready to be applied directly by the simulator.

    For PD control modes:
        output = pd_action_offset + pd_action_scale * clamp(action, -clamp_value, clamp_value)
    For TORQUE control mode:
        output = clamp(action, -clamp_value, clamp_value) * torque_limits

    Args:
        module: TensorDictModule (policy) to export.
        td: Sample TensorDict input (used for tracing).
        path: Path to save the ONNX model (must end with .onnx).
        action_key: Key in module output containing the action (e.g., "action", "mean_action").
        control_type: Control type enum (BUILT_IN_PD, PROPORTIONAL, or TORQUE).
        clamp_value: Value to clamp actions to [-clamp_value, clamp_value].
        pd_action_offset: Offset tensor for PD target computation (required for PD modes).
        pd_action_scale: Scale tensor for PD target computation (required for PD modes).
        torque_limits: Torque limits tensor for TORQUE mode (required for TORQUE mode).
        meta: Optional additional metadata to save.
        validate: If True, validates the exported model with onnxruntime and
            compares its outputs numerically against the PyTorch wrapper.
        opset_version: ONNX opset version to use (default: 17).

    Raises:
        ValueError: If path doesn't end with .onnx or required parameters are missing.

    Example:
        >>> from protomotions.agents.ppo.model import PPOModel
        >>> model = PPOModel(config)
        >>> sample_input = TensorDict({"obs": torch.randn(1, 128)}, batch_size=1)
        >>> export_with_action_processing(
        ...     model._actor,
        ...     sample_input,
        ...     "policy_with_actions.onnx",
        ...     action_key="mean_action",
        ...     control_type=ControlType.BUILT_IN_PD,
        ...     clamp_value=1.0,
        ...     pd_action_offset=torch.zeros(29),
        ...     pd_action_scale=torch.ones(29),
        ... )
    """
    if not path.endswith(".onnx"):
        raise ValueError(f"Export path must end with .onnx, got {path}.")
    # Move to CPU and select only required input keys
    td = td.cpu().select(*module.in_keys, strict=True)
    module = module.cpu()
    module.eval()
    # Move action processing tensors to CPU
    if pd_action_offset is not None:
        pd_action_offset = pd_action_offset.cpu()
    if pd_action_scale is not None:
        pd_action_scale = pd_action_scale.cpu()
    if torque_limits is not None:
        torque_limits = torque_limits.cpu()
    in_keys = list(module.in_keys)
    # Create the wrapper first and take the output type from it, so the label
    # cannot drift from what ActionProcessingModule actually computes.
    # (A duplicate derivation here previously labelled any unsupported
    # control type "torque_targets" before the wrapper raised.)
    batch_size = td.batch_size[0] if td.batch_size else 1
    wrapper = ActionProcessingONNXWrapper(
        module=module,
        in_keys=in_keys,
        batch_size=batch_size,
        action_key=action_key,
        control_type=control_type,
        clamp_value=clamp_value,
        pd_action_offset=pd_action_offset,
        pd_action_scale=pd_action_scale,
        torque_limits=torque_limits,
    )
    wrapper.eval()
    output_type = wrapper.output_type
    out_keys = [output_type]
    print(f"Exporting model with action processing to ONNX (PyTorch {torch.__version__})...")
    print(f" Input keys: {in_keys}")
    print(f" Output keys: {out_keys}")
    print(f" Control type: {control_type.name}")
    print(f" Clamp value: {clamp_value}")
    print(f" Action key: {action_key}")
    # Prepare input tuple for torch.onnx.export
    input_tensors = tuple(td[key] for key in in_keys)
    # Create input/output names for ONNX
    input_names = [f"input_{i}" for i in range(len(in_keys))]
    output_names = [f"output_{i}" for i in range(len(out_keys))]
    torch.onnx.export(
        wrapper,
        input_tensors,
        path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={
            **{name: {0: "batch_size"} for name in input_names},
            **{name: {0: "batch_size"} for name in output_names},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=False,
    )
    print(f"✓ Exported ONNX model to {path}")
    # Save metadata. Slice off the validated ".onnx" suffix instead of
    # str.replace, which could touch an earlier occurrence in the path.
    meta_path = path[: -len(".onnx")] + ".json"
    if meta is None:
        meta = {}
    meta["in_keys"] = in_keys
    meta["out_keys"] = out_keys
    meta["in_shapes"] = [list(td[k].shape) for k in in_keys]
    meta["onnx_input_names"] = input_names
    meta["onnx_output_names"] = output_names
    meta["input_mapping"] = {
        onnx_name: semantic_name
        for onnx_name, semantic_name in zip(input_names, in_keys)
    }
    meta["output_mapping"] = {
        onnx_name: semantic_name
        for onnx_name, semantic_name in zip(output_names, out_keys)
    }
    # Action processing metadata
    meta["action_processing_baked"] = True
    meta["control_type"] = control_type.name
    meta["clamp_value"] = clamp_value
    meta["output_type"] = output_type
    meta["action_key"] = action_key
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=4)
    print(f"✓ Exported metadata to {meta_path}")
    # Validate with onnxruntime
    if validate:
        try:
            import onnxruntime as ort

            ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])

            def to_numpy(tensor):
                return (
                    tensor.detach().cpu().numpy()
                    if tensor.requires_grad
                    else tensor.cpu().numpy()
                )

            onnxruntime_input = {
                name: to_numpy(tensor)
                for name, tensor in zip(input_names, input_tensors)
            }
            ort_output = ort_session.run(None, onnxruntime_input)
            assert len(ort_output) == len(out_keys), (
                f"Output length mismatch: {len(ort_output)} vs {len(out_keys)}"
            )
            # Compare ONNX outputs numerically against the PyTorch wrapper
            # (previously only the output count was checked).
            with torch.no_grad():
                torch_output = wrapper(*input_tensors)
            max_diff = max(
                (
                    float(abs(onnx_out - to_numpy(torch_out)).max())
                    for onnx_out, torch_out in zip(ort_output, torch_output)
                ),
                default=0.0,
            )
            print(f"✓ Validation: max diff = {max_diff:.2e}")
            print("✓ ONNX model validation successful!")
        except ImportError:
            print("⚠ Warning: onnxruntime not installed, skipping validation.")
        except Exception as e:
            print(f"⚠ Warning: ONNX validation failed: {e}")
def export_ppo_actor(
    actor, sample_obs: Dict[str, torch.Tensor], path: str, validate: bool = True
):
    """Export a PPO actor's mu network to ONNX.

    Exports the mean network (mu) of a PPO actor, which is the core policy
    network without the distribution layer. Uses real observations from the
    environment to ensure proper tracing.

    Args:
        actor: PPOActor instance to export.
        sample_obs: Sample observation dict from environment (via agent.get_obs()).
        path: Path to save the ONNX model.
        validate: If True, validates the exported model.

    Example:
        >>> # Get real observations from environment
        >>> env.reset()
        >>> sample_obs = agent.get_obs()
        >>> export_ppo_actor(agent.model._actor, sample_obs, "ppo_actor.onnx")
    """
    # Infer the batch dimension from any observation tensor.
    batch_size = next(iter(sample_obs.values())).shape[0]
    sample_td = TensorDict(sample_obs, batch_size=batch_size)
    metadata = {
        "model_type": "PPOActor",
        "observation_keys": list(sample_obs),
        "observation_shapes": {key: list(obs.shape) for key, obs in sample_obs.items()},
    }
    export_onnx(actor, sample_td, path, meta=metadata, validate=validate)
def export_ppo_critic(
    critic, sample_obs: Dict[str, torch.Tensor], path: str, validate: bool = True
):
    """Export a PPO critic network to ONNX.

    Uses real observations from the environment to ensure proper tracing.

    Args:
        critic: PPO critic (MultiHeadedMLP) instance to export.
        sample_obs: Sample observation dict from environment (via agent.get_obs()).
        path: Path to save the ONNX model.
        validate: If True, validates the exported model.

    Example:
        >>> # Get real observations from environment
        >>> env.reset()
        >>> sample_obs = agent.get_obs()
        >>> export_ppo_critic(agent.model._critic, sample_obs, "ppo_critic.onnx")
    """
    # Infer the batch dimension from any observation tensor.
    batch_size = next(iter(sample_obs.values())).shape[0]
    sample_td = TensorDict(sample_obs, batch_size=batch_size)
    metadata = {
        "model_type": "PPOCritic",
        "num_out": critic.config.num_out,
        "observation_keys": list(sample_obs),
        "observation_shapes": {key: list(obs.shape) for key, obs in sample_obs.items()},
    }
    export_onnx(critic, sample_td, path, meta=metadata, validate=validate)
def export_ppo_model(
    model, sample_obs: Dict[str, torch.Tensor], output_dir: str, validate: bool = True
):
    """Export a complete PPO model (actor and critic) to ONNX.

    Exports both the actor and critic networks to separate ONNX files
    in the specified directory.

    Args:
        model: PPOModel instance to export.
        sample_obs: Sample observation dict for tracing.
        output_dir: Directory to save the ONNX models.
        validate: If True, validates the exported models.

    Returns:
        Dict with paths to exported files.

    Example:
        >>> model = trained_agent.model
        >>> sample_obs = {"obs": torch.randn(1, 128)}
        >>> paths = export_ppo_model(model, sample_obs, "exported_models/")
        >>> print(paths)
        {'actor': 'exported_models/actor.onnx', 'critic': 'exported_models/critic.onnx'}
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Assemble the result dict up front; the export calls below fill the files.
    paths = {
        "actor": str(out_dir / "actor.onnx"),
        "critic": str(out_dir / "critic.onnx"),
        "metadata": {
            "actor_meta": str(out_dir / "actor.json"),
            "critic_meta": str(out_dir / "critic.json"),
        },
    }
    print("Exporting PPO Actor...")
    export_ppo_actor(model._actor, sample_obs, paths["actor"], validate=validate)
    print("\nExporting PPO Critic...")
    export_ppo_critic(model._critic, sample_obs, paths["critic"], validate=validate)
    print(f"\nExport complete! Models saved to {out_dir}")
    return paths
###############################################################################
# Unified Pipeline Export (Context -> Actions)
###############################################################################


class UnifiedPipelineModule(torch.nn.Module):
    """Unified module that combines observations + policy + action processing.

    This module takes raw context tensors as inputs and returns raw actions,
    post-processed actions (PD targets or torques), and stiffness/damping targets.

    Pipeline: Context -> Observations -> Policy -> Action Processing

    Inputs are split into two groups:
        1. Context tensors for observation computation
           (from observation_module.get_input_keys())
        2. Additional tensors passed directly to policy (e.g., historical observations)

    Outputs:
        - actions: Raw actions from the policy (used by env.step())
        - joint_pos_targets: Actions after clamping and PD/torque transform
        - stiffness_targets: Per-DOF stiffness values (constant, broadcast to batch)
        - damping_targets: Per-DOF damping values (constant, broadcast to batch)

    Args:
        observation_module: ObservationExportModule for computing observations.
        policy_module: Policy model (actor) that takes observations and returns actions.
        action_processing_module: ActionProcessingModule for post-processing.
        policy_in_keys: List of observation keys the policy expects as input.
        policy_action_key: Key for the action output from policy (e.g., 'mean_action').
        passthrough_keys: Additional keys passed directly to policy (not from observations).
    """

    def __init__(
        self,
        observation_module: "ObservationExportModule",
        policy_module: torch.nn.Module,
        action_processing_module: "ActionProcessingModule",
        policy_in_keys: list,
        policy_action_key: str = "mean_action",
        passthrough_keys: Optional[list] = None,
    ):
        super().__init__()
        self.observation_module = observation_module
        self.policy_module = policy_module
        self.action_processing_module = action_processing_module
        self.policy_in_keys = policy_in_keys
        self.policy_action_key = policy_action_key
        self.passthrough_keys = passthrough_keys or []
        # Observation output keys, used to map the observation module's
        # positional outputs back to named policy inputs in forward().
        self.obs_output_keys = observation_module.get_output_keys()
        # Context input keys consumed by the observation module; the first
        # len(obs_input_keys) forward() arguments belong to it.
        self.obs_input_keys = observation_module.get_input_keys()
        self.num_obs_inputs = len(self.obs_input_keys)

    def forward(self, *all_tensors) -> tuple:
        """Run the full pipeline from context to actions.

        Args:
            *all_tensors: Input tensors in order of:
                1. observation_module.get_input_keys()
                2. passthrough_keys

        Returns:
            Tuple of (actions, joint_pos_targets, stiffness_targets, damping_targets).
        """
        # Split inputs: observation context vs passthrough.
        obs_context_tensors = all_tensors[: self.num_obs_inputs]
        passthrough_tensors = all_tensors[self.num_obs_inputs :]
        # Step 1: Compute observations from context.
        obs_outputs = self.observation_module(*obs_context_tensors)
        # Build observation dict from computed observations (positional ->
        # named, in get_output_keys() order).
        obs_dict = {
            key: obs_outputs[i] for i, key in enumerate(self.obs_output_keys)
        }
        # Add passthrough tensors directly to obs_dict.
        for key, tensor in zip(self.passthrough_keys, passthrough_tensors):
            obs_dict[key] = tensor
        # Step 2: Run policy with all required inputs.
        from tensordict import TensorDict
        # Infer batch size from passthrough tensors (which always have batch dim)
        # or from a policy input tensor, avoiding constant tensors like body_ids.
        batch_size = None
        if passthrough_tensors:
            batch_size = passthrough_tensors[0].shape[0]
        else:
            # Find a tensor with "current_state_" prefix which should have batch dim
            for key in self.policy_in_keys:
                if key in obs_dict and key.startswith("current_state_"):
                    batch_size = obs_dict[key].shape[0]
                    break
        if batch_size is None:
            # Fallback: use first policy input that's 2D or more
            for key in self.policy_in_keys:
                if key in obs_dict and obs_dict[key].dim() >= 2:
                    batch_size = obs_dict[key].shape[0]
                    break
        if batch_size is None:
            # Last resort: assume a single environment.
            batch_size = 1
        policy_input = TensorDict(
            {key: obs_dict[key] for key in self.policy_in_keys},
            batch_size=[batch_size],
        )
        policy_output = self.policy_module(policy_input)
        # Extract raw actions.
        actions = policy_output[self.policy_action_key]
        # Step 3: Apply action processing (returns 4 tensors).
        actions, joint_pos_targets, stiffness_targets, damping_targets = (
            self.action_processing_module(actions)
        )
        return actions, joint_pos_targets, stiffness_targets, damping_targets
###############################################################################
# YAML Configuration Generation
###############################################################################

# Mapping from ONNX input names to (name, kind) tuples for isaac-deploy YAML.
ONNX_INPUT_MAPPING = {
    # Current state
    "current_state_dof_pos": ("joint_pos", "joint_pos"),
    "current_state_dof_vel": ("joint_vel", "joint_vel"),
    "current_state_root_ang_vel": ("root_body_ang_vel", "root_body_ang_vel"),
    "current_state_root_local_ang_vel": ("root_body_ang_vel", "root_body_ang_vel"),
    "current_state_root_rot": ("root_body_rot", "root_body_rot"),
    "current_state_anchor_rot": ("anchor_rot", "anchor_body_rot"),
    "current_state_rigid_body_pos": ("body_pos", "body_pos"),
    "current_state_rigid_body_rot": ("body_rot", "body_rot"),
    # Historical state
    "historical_actions": ("last_actions", "last_actions"),
    "historical_dof_pos": ("joint_pos_history", "joint_pos"),
    "historical_dof_vel": ("joint_vel_history", "joint_vel"),
    "historical_root_ang_vel": ("root_body_ang_vel_history", "root_body_ang_vel"),
    "historical_root_local_ang_vel": ("root_body_ang_vel_history", "root_body_ang_vel"),
    "historical_root_rot": ("root_body_rot_history", "root_body_rot"),
    "historical_anchor_rot": ("anchor_rot_history", "anchor_body_rot"),
    "historical_rigid_body_pos": ("body_pos_history", "body_pos"),
    "historical_rigid_body_rot": ("body_rot_history", "body_rot"),
    "historical_rigid_body_vel": ("body_vel_history", "body_vel"),
    "historical_rigid_body_ang_vel": ("body_ang_vel_history", "body_ang_vel"),
    "historical_ground_heights": ("ground_heights_history", "ground_heights"),
    "previous_actions": ("previous_actions", "last_actions"),
    # Reference motion (mimic)
    "mimic_ref_ang_vel": ("reference_motion_body_ang_vel", "reference_motion_body_ang_vel"),
    "mimic_ref_dof_pos": ("reference_motion_joint_pos", "reference_motion_joint_pos"),
    "mimic_ref_dof_vel": ("reference_motion_joint_vel", "reference_motion_joint_vel"),
    "mimic_ref_rot": ("reference_motion_body_rot", "reference_motion_body_rot"),
    "mimic_ref_anchor_rot": ("reference_motion_anchor_rot", "reference_motion_body_rot"),
    # MaskedMimic sparse conditioning
    "masked_mimic_ref_pos": ("masked_mimic_ref_pos", "masked_mimic_body_pos"),
    "masked_mimic_ref_rot": ("masked_mimic_ref_rot", "masked_mimic_body_rot"),
    "masked_mimic_target_bodies_masks": ("masked_mimic_body_masks", "masked_mimic_masks"),
    "masked_mimic_target_poses_masks": ("masked_mimic_pose_masks", "masked_mimic_masks"),
    "masked_mimic_time_offsets": ("masked_mimic_time_offsets", "masked_mimic_time"),
    # VAE
    "vae_noise": ("vae_noise", "vae_noise"),
}

# Mapping from ONNX output names to (name, kind) tuples for isaac-deploy YAML.
ONNX_OUTPUT_MAPPING = {
    "actions": ("actions", "actions"),
    "joint_pos_targets": ("joint_pos_targets", "joint_pos_targets"),
    "stiffness_targets": ("stiffness_targets", "stiffness_targets"),
    "damping_targets": ("damping_targets", "damping_targets"),
}
def _build_policy_input(
    onnx_name: str,
    input_shapes: Dict[str, list],
    joint_names: list,
    body_names: list,
    anchor_body: str = "pelvis",
) -> Optional[Dict[str, Any]]:
    """Build a policy input entry for YAML from ONNX name.

    Returns None when the ONNX name has no known semantic mapping.
    """
    mapping = ONNX_INPUT_MAPPING.get(onnx_name)
    if mapping is None:
        return None
    name, kind = mapping
    raw_shape = input_shapes.get(onnx_name, [1, 1])
    # Force the batch dimension to 1.
    shape = [1, *list(raw_shape)[1:]]
    # Reference-motion inputs carry no history metadata.
    is_reference_motion = kind.startswith("reference_motion_")
    # A shape like [B, H, ...] with H > 1 encodes H history steps.
    history = shape[1] if len(shape) >= 3 and shape[1] > 1 else 0
    # Kinds read directly from the simulator state each step.
    simulator_kinds = {
        "joint_pos",
        "joint_vel",
        "root_body_rot",
        "root_body_ang_vel",
        "anchor_body_rot",
    }
    # include_current_value_in_history is True for last_actions and for
    # history-free simulator values; False for buffered historical obs.
    include_current = kind == "last_actions" or (
        kind in simulator_kinds and history == 0
    )
    entry = {
        "name": name,
        "kind": kind,
        "shape": shape,
        "key": onnx_name,
    }
    if kind.endswith("_rot"):
        entry["quaternion_order"] = "xyzw"
    if not is_reference_motion:
        entry["history"] = history
        entry["include_current_value_in_history"] = include_current
    if "joint" in kind or kind == "last_actions":
        entry["joint_names"] = joint_names
    # Body-related observations (except the anchor rotation) list body names.
    if "body" in kind and kind != "anchor_body_rot":
        if onnx_name == "mimic_ref_anchor_rot":
            # The anchor-rotation reference input only tracks the anchor body.
            entry["body_names"] = [anchor_body]
        elif len(body_names) in shape:
            entry["body_names"] = body_names
    if kind == "anchor_body_rot":
        entry["anchor_body"] = anchor_body
    if kind == "last_actions":
        entry["output_key"] = "actions"
    return entry
def _build_policy_output(
    onnx_name: str,
    output_shapes: Dict[str, list],
    joint_names: list,
    stiffness: list,
    damping: list,
    use_onnx_for_gains: bool = True,
) -> Optional[Dict[str, Any]]:
    """Build a policy output entry for YAML from ONNX name.

    Args:
        onnx_name: Name of the ONNX output.
        output_shapes: Dictionary mapping ONNX output names to shapes.
        joint_names: List of joint names.
        stiffness: List of stiffness values (used only if use_onnx_for_gains=False).
        damping: List of damping values (used only if use_onnx_for_gains=False).
        use_onnx_for_gains: If True, read stiffness/damping from ONNX outputs.
            If False, use constant values from YAML.

    Returns:
        The YAML entry dict, or None when the name has no known mapping.
    """
    if onnx_name not in ONNX_OUTPUT_MAPPING:
        return None
    name, kind = ONNX_OUTPUT_MAPPING[onnx_name]
    raw_shape = output_shapes.get(onnx_name, [1, len(joint_names)])
    entry = {
        "name": name,
        "kind": kind,
        "key": onnx_name,  # Always use ONNX output name as key.
        "shape": [1, *list(raw_shape)[1:]],  # Normalize batch size to 1.
    }
    # Joint names are attached only to terms that need them (not the
    # passthrough "actions" term).
    if "joint" in kind or kind in ("stiffness_targets", "damping_targets"):
        entry["joint_names"] = joint_names
    # For gain terms, optionally fall back to YAML constants instead of ONNX.
    if not use_onnx_for_gains:
        if kind == "stiffness_targets":
            entry["key"] = None  # Use YAML values instead of ONNX.
            entry["stiffness"] = stiffness
        elif kind == "damping_targets":
            entry["key"] = None  # Use YAML values instead of ONNX.
            entry["damping"] = damping
    return entry
def _generate_yaml_content(
    input_shapes: Dict[str, list],
    output_shapes: Dict[str, list],
    onnx_in_names: list,
    onnx_out_names: list,
    joint_names: list,
    body_names: list,
    stiffness: list,
    damping: list,
    anchor_body: str = "pelvis",
) -> Dict[str, Any]:
    """Generate the complete YAML content for isaac-deploy.

    Translates each recognized ONNX input/output into a YAML entry,
    silently skipping names with no known semantic mapping.
    """
    policy_inputs = [
        entry
        for entry in (
            _build_policy_input(
                onnx_name, input_shapes, joint_names, body_names, anchor_body
            )
            for onnx_name in onnx_in_names
        )
        if entry
    ]
    policy_outputs = [
        entry
        for entry in (
            _build_policy_output(
                onnx_name, output_shapes, joint_names, stiffness, damping
            )
            for onnx_name in onnx_out_names
        )
        if entry
    ]
    return {
        "type": "unified_pipeline",
        "joint_names": joint_names,
        "body_names": body_names,
        "stiffness": stiffness,
        "damping": damping,
        "policy_inputs": policy_inputs,
        "policy_outputs": policy_outputs,
    }
[docs]
def export_unified_pipeline(
    observation_configs: Dict[str, Any],
    sample_context: Dict[str, Any],
    policy_module: torch.nn.Module,
    policy_in_keys: list,
    policy_action_key: str,
    action_processing_module: "ActionProcessingModule",
    path: str,
    device: torch.device,
    robot_config: Any,
    passthrough_obs: Optional[Dict[str, torch.Tensor]] = None,
    validate: bool = True,
    meta: Optional[Dict[str, Any]] = None,
) -> str:
    """Export the complete pipeline (context -> actions) as a single ONNX model.

    Creates a unified model that:
    1. Computes observations from context
    2. Runs the policy to get raw actions (using both computed obs and passthrough inputs)
    3. Applies action processing (clamp + PD/torque transform)

    Outputs four tensors:
    - actions: Raw actions from policy (used by env.step())
    - joint_pos_targets: Actions after clamping and PD/torque transform
    - stiffness_targets: Per-DOF stiffness values (constant)
    - damping_targets: Per-DOF damping values (constant)

    Also generates a YAML configuration file for isaac-deploy.

    Args:
        observation_configs: Dict of observation component configurations.
        sample_context: Sample context dict for tracing.
        policy_module: Policy model (actor) that takes observations.
        policy_in_keys: List of observation keys the policy expects.
        policy_action_key: Key for action output from policy.
        action_processing_module: ActionProcessingModule instance.
        path: Path to save the ONNX model.
        device: Device for tensor operations.
        robot_config: Robot configuration with kinematic_info and control.
        passthrough_obs: Dict of additional observations passed directly to policy
            (e.g., historical observations). Keys are obs names, values are sample tensors.
        validate: If True, validates with onnxruntime.
        meta: Optional metadata to include.

    Returns:
        Path to the exported ONNX model.

    Raises:
        ValueError: If the policy requires inputs not covered by computed
            observations plus passthrough inputs, or a context entry is not
            a tensor.
    """
    import logging
    import yaml

    log = logging.getLogger(__name__)
    log.info("=" * 60)
    log.info("Exporting Unified Pipeline (Context -> Actions)")
    log.info("=" * 60)
    passthrough_obs = passthrough_obs or {}
    # Extract robot metadata used by both the export and the YAML sidecar.
    joint_names = robot_config.kinematic_info.dof_names
    body_names = robot_config.kinematic_info.body_names
    stiffness = [
        float(robot_config.control.control_info[j].stiffness) for j in joint_names
    ]
    damping = [
        float(robot_config.control.control_info[j].damping) for j in joint_names
    ]
    # Get anchor body name (defaults to first body if not specified).
    anchor_body = (
        robot_config.anchor_body_name
        if robot_config.anchor_body_name
        else body_names[0]
    )
    log.info(f"Joint names: {joint_names}")
    log.info(f"Body names: {body_names}")
    log.info(f"Anchor body: {anchor_body}")
    # Step 1: Create observation module.
    obs_module = ObservationExportModule(observation_configs, sample_context, device)
    obs_module.eval()
    obs_input_keys = obs_module.get_input_keys()
    obs_output_keys = obs_module.get_output_keys()
    passthrough_keys = list(passthrough_obs.keys())
    log.info(f"Context input keys (for obs): {obs_input_keys}")
    log.info(f"Observation output keys: {obs_output_keys}")
    log.info(f"Passthrough keys (direct to policy): {passthrough_keys}")
    log.info(f"Policy input keys: {policy_in_keys}")
    # Check coverage: obs outputs + passthrough should cover all policy inputs.
    available = set(obs_output_keys) | set(passthrough_keys)
    missing = set(policy_in_keys) - available
    if missing:
        raise ValueError(
            f"Policy requires inputs not available: {missing}. "
            f"Available from obs: {obs_output_keys}, passthrough: {passthrough_keys}"
        )
    # Step 2: Create unified module with passthrough.
    # Move everything to CPU for ONNX export.
    unified_module = UnifiedPipelineModule(
        observation_module=obs_module,
        policy_module=policy_module.cpu(),
        action_processing_module=action_processing_module.cpu(),
        policy_in_keys=policy_in_keys,
        policy_action_key=policy_action_key,
        passthrough_keys=passthrough_keys,
    )
    unified_module.cpu()
    unified_module.eval()
    # All input keys: observation context + passthrough.
    all_input_keys = unified_module.get_all_input_keys()
    # Build sample inputs from context (for obs) and passthrough_obs.
    # Move all to CPU for ONNX export.
    sample_inputs = []
    input_shapes = {}
    # Observation context inputs.
    for key in obs_input_keys:
        # NOTE: keys are expressions from trusted config; builtins are
        # stripped from the eval environment to limit what they can do.
        value = eval(key, {"__builtins__": {}}, sample_context)
        if isinstance(value, torch.Tensor):
            sample_inputs.append(value.cpu())
            input_shapes[key] = list(value.shape)
        else:
            raise ValueError(f"Input '{key}' is not a tensor: {type(value)}")
    # Passthrough inputs.
    for key in passthrough_keys:
        value = passthrough_obs[key]
        sample_inputs.append(value.cpu())
        input_shapes[key] = list(value.shape)
    # Run forward pass to get output shapes.
    with torch.no_grad():
        actions, joint_pos_targets, stiffness_targets, damping_targets = (
            unified_module(*sample_inputs)
        )
    log.info(f"Actions shape: {list(actions.shape)}")
    log.info(f"Joint pos targets shape: {list(joint_pos_targets.shape)}")
    log.info(f"Stiffness targets shape: {list(stiffness_targets.shape)}")
    log.info(f"Damping targets shape: {list(damping_targets.shape)}")

    # ONNX disallows some characters in tensor names.
    def sanitize_name(name: str) -> str:
        return (
            name.replace(".", "_").replace("[", "_").replace("]", "_").replace(":", "_")
        )

    onnx_input_names = [sanitize_name(k) for k in all_input_keys]
    onnx_output_names = [
        "actions",
        "joint_pos_targets",
        "stiffness_targets",
        "damping_targets",
    ]
    # Export to ONNX.
    log.info(f"Exporting unified pipeline to {path}...")
    torch.onnx.export(
        unified_module,
        tuple(sample_inputs),
        path,
        input_names=onnx_input_names,
        output_names=onnx_output_names,
        opset_version=17,
        do_constant_folding=True,
        dynamic_axes={
            **{name: {0: "batch_size"} for name in onnx_input_names},
            **{name: {0: "batch_size"} for name in onnx_output_names},
        },
        dynamo=False,
    )
    # Load to get actual ONNX names.
    import onnxruntime as ort

    session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    actual_onnx_in_names = [inp.name for inp in session.get_inputs()]
    actual_onnx_out_names = [out.name for out in session.get_outputs()]
    # Build mapping from ONNX name to semantic key.
    # Match by sanitized name similarity (ONNX may reorder inputs!).
    onnx_name_to_in_key = {}
    sanitized_to_semantic = {sanitize_name(k): k for k in all_input_keys}
    for onnx_name in actual_onnx_in_names:
        matched = False
        # Try exact match with sanitized names.
        if onnx_name in sanitized_to_semantic:
            onnx_name_to_in_key[onnx_name] = sanitized_to_semantic[onnx_name]
            matched = True
        else:
            # Try stripping ONNX suffixes (.1, .2, etc.).
            base_name = onnx_name
            for suffix in [".1", ".2", ".3", "_1", "_2", "_3"]:
                if base_name.endswith(suffix):
                    base_name = base_name[: -len(suffix)]
                    break
            if base_name in sanitized_to_semantic:
                onnx_name_to_in_key[onnx_name] = sanitized_to_semantic[base_name]
                matched = True
        if not matched:
            log.warning(f"Could not match ONNX input '{onnx_name}' to any semantic key")
    # Build output shapes dict.
    output_shapes = {
        "actions": list(actions.shape),
        "joint_pos_targets": list(joint_pos_targets.shape),
        "stiffness_targets": list(stiffness_targets.shape),
        "damping_targets": list(damping_targets.shape),
    }
    # Generate YAML content.
    yaml_content = _generate_yaml_content(
        input_shapes=input_shapes,
        output_shapes=output_shapes,
        onnx_in_names=actual_onnx_in_names,
        onnx_out_names=actual_onnx_out_names,
        joint_names=joint_names,
        body_names=body_names,
        stiffness=stiffness,
        damping=damping,
        anchor_body=anchor_body,
    )
    # Add runtime metadata for visualization/testing (not used by isaac-deploy).
    yaml_content["_runtime"] = {
        "onnx_in_names": actual_onnx_in_names,
        "onnx_out_names": actual_onnx_out_names,
        "onnx_name_to_in_key": onnx_name_to_in_key,
        "passthrough_keys": passthrough_keys,
        "obs_context_keys": obs_input_keys,
    }
    # Add metadata if provided.
    if meta:
        yaml_content["metadata"] = meta
    # Save YAML next to the model. Use with_suffix so only the file extension
    # is replaced (str.replace would hit the first ".onnx" anywhere in path).
    yaml_path = str(Path(path).with_suffix(".yaml"))
    with open(yaml_path, "w") as f:
        yaml.dump(yaml_content, f, default_flow_style=None, sort_keys=False)
    log.info(f"✓ Unified pipeline exported to {path}")
    log.info(f"✓ YAML configuration saved to {yaml_path}")
    # Validate with onnxruntime.
    if validate:
        try:
            import numpy as np

            log.info("Validating with onnxruntime...")
            # Build input dict.
            input_key_to_value = {
                key: inp for key, inp in zip(all_input_keys, sample_inputs)
            }
            onnx_inputs = {}
            for onnx_name in actual_onnx_in_names:
                if onnx_name in onnx_name_to_in_key:
                    semantic_key = onnx_name_to_in_key[onnx_name]
                    onnx_inputs[onnx_name] = (
                        input_key_to_value[semantic_key].detach().cpu().numpy()
                    )
            # Request outputs by the names we exported so the result order is
            # guaranteed to line up with pytorch_outputs below (the session's
            # own output order may differ from the requested one).
            onnx_outputs = session.run(onnx_output_names, onnx_inputs)
            # Compare with PyTorch outputs.
            pytorch_outputs = [
                actions.detach().cpu().numpy(),
                joint_pos_targets.detach().cpu().numpy(),
                stiffness_targets.detach().cpu().numpy(),
                damping_targets.detach().cpu().numpy(),
            ]
            for i, (onnx_name, pytorch_out) in enumerate(
                zip(onnx_output_names, pytorch_outputs)
            ):
                diff = np.abs(onnx_outputs[i] - pytorch_out).max()
                log.info(f" {onnx_name}: max_diff = {diff:.6e}")
                if diff > 1e-4:
                    log.warning(f" ⚠ Large difference detected for {onnx_name}")
            log.info("✓ Validation passed")
        except ImportError:
            log.warning("onnxruntime not installed, skipping validation")
        except Exception as e:
            log.error(f"Validation failed: {e}")
            raise
    return path
###############################################################################
# Observation Export Utilities
###############################################################################
[docs]
class ObservationExportModule(torch.nn.Module):
    """Module that wraps observation functions for ONNX export.

    This module takes raw context tensors as inputs and computes observations
    by calling the configured observation functions. It's designed to be
    exported to ONNX for deployment.

    The module resolves all variable mappings at construction time and stores
    them as a mapping from input tensor names to function argument names.

    Args:
        observation_configs: Dict of observation component configurations.
        sample_context: Sample context dict to determine input shapes and resolve mappings.
        device: Device for tensor operations.

    Example:
        >>> from protomotions.envs.obs.mimic_obs_functions import mimic_target_poses_max_coords_factory
        >>> configs = {"mimic_target_poses": mimic_target_poses_max_coords_factory()}
        >>> context = env._get_global_context()
        >>> module = ObservationExportModule(configs, context, device)
        >>> # Export to ONNX
        >>> export_observations(module, context, "observations.onnx")
    """

    def __init__(
        self,
        observation_configs: Dict[str, Any],
        sample_context: Dict[str, Any],
        device: torch.device,
    ):
        super().__init__()
        self.device = device
        # Store observation functions and their input mappings.
        self._obs_functions = {}
        self._obs_input_mappings = {}  # obs_name -> {arg_name: context_key}
        self._obs_constants = {}  # obs_name -> {arg_name: value}
        self._input_keys = set()  # All unique context keys needed
        self._output_keys = []  # Ordered list of output observation names
        for obs_name, component in observation_configs.items():
            if component.function is None:
                continue
            self._obs_functions[obs_name] = component.function
            self._output_keys.append(obs_name)
            input_mapping = {}
            constants = {}
            for arg_name, var_value in component.variables.items():
                if isinstance(var_value, str):
                    # String values are context references (flat keys like
                    # "current_state_rigid_body_pos").
                    input_mapping[arg_name] = var_value
                    self._input_keys.add(var_value)
                else:
                    # Non-string values are constants (int, float, bool, etc.)
                    constants[arg_name] = var_value
            self._obs_input_mappings[obs_name] = input_mapping
            self._obs_constants[obs_name] = constants
        # Convert to ordered list for consistent ONNX input ordering.
        self._input_keys = sorted(list(self._input_keys))
        # Pre-resolve context entries that are not tensors (like hinge_axes_map)
        # and demote them to constants so they don't become ONNX inputs.
        self._resolved_constants = {}
        for obs_name, mapping in self._obs_input_mappings.items():
            self._resolved_constants[obs_name] = {}
            for arg_name, var_expr in list(mapping.items()):
                try:
                    # NOTE: var_expr comes from trusted config; builtins are
                    # stripped from the eval environment.
                    value = eval(var_expr, {"__builtins__": {}}, sample_context)
                    if not isinstance(value, torch.Tensor):
                        # Move from input_mapping to constants.
                        self._obs_constants[obs_name][arg_name] = value
                        del self._obs_input_mappings[obs_name][arg_name]
                        self._resolved_constants[obs_name][arg_name] = value
                except Exception:
                    # Unresolvable expressions stay as (tensor) inputs.
                    pass
        # Rebuild input keys after removing non-tensors.
        self._input_keys = set()
        for obs_name, mapping in self._obs_input_mappings.items():
            for var_expr in mapping.values():
                self._input_keys.add(var_expr)
        self._input_keys = sorted(list(self._input_keys))

    def get_input_keys(self) -> list:
        """Get ordered list of context keys expected as positional inputs.

        This accessor is required by export_observations() and
        export_unified_pipeline(), which build sample inputs in this order.
        """
        return self._input_keys

    def get_output_keys(self) -> list:
        """Get ordered list of output observation names."""
        return self._output_keys

    def forward(self, *args) -> tuple:
        """Compute all observations from input tensors.

        Args:
            *args: Input tensors in the order of get_input_keys().

        Returns:
            Tuple of observation tensors in the order of get_output_keys().
        """
        # Build context dict from positional args.
        context = {key: tensor for key, tensor in zip(self._input_keys, args)}
        outputs = []
        for obs_name in self._output_keys:
            func = self._obs_functions[obs_name]
            input_mapping = self._obs_input_mappings[obs_name]
            constants = self._obs_constants[obs_name]
            # Build kwargs: tensor inputs from context, then constants.
            func_kwargs = {}
            for arg_name, var_expr in input_mapping.items():
                func_kwargs[arg_name] = context[var_expr]
            func_kwargs.update(constants)
            # Call the observation function.
            obs_value = func(**func_kwargs)
            outputs.append(obs_value)
        return tuple(outputs)
[docs]
def export_observations(
    observation_configs: Dict[str, Any],
    sample_context: Dict[str, Any],
    path: str,
    device: torch.device,
    validate: bool = True,
    meta: Optional[Dict[str, Any]] = None,
) -> str:
    """Export observation computation to ONNX format.

    Creates an ObservationExportModule from the observation configs and exports
    it to ONNX. The exported model takes raw context tensors as inputs and
    produces observation tensors as outputs.

    Args:
        observation_configs: Dict of observation component configurations.
        sample_context: Sample context dict for tracing and shape inference.
        path: Path to save the ONNX model.
        device: Device for tensor operations.
        validate: If True, validates with onnxruntime.
        meta: Optional metadata to include in the JSON sidecar.

    Returns:
        Path to the exported ONNX model.

    Raises:
        ValueError: If a context entry referenced by the module is not a tensor.

    Example:
        >>> configs = env.config.observation_components
        >>> context = env._get_global_context()
        >>> export_observations(configs, context, "observations.onnx", device)
    """
    import logging

    log = logging.getLogger(__name__)
    # Create the export module.
    module = ObservationExportModule(observation_configs, sample_context, device)
    module.eval()
    input_keys = module.get_input_keys()
    output_keys = module.get_output_keys()
    log.info(f"Observation export - Input keys: {input_keys}")
    log.info(f"Observation export - Output keys: {output_keys}")
    # Build sample inputs from context.
    sample_inputs = []
    input_shapes = {}
    for key in input_keys:
        # NOTE: keys are expressions from trusted config; builtins are
        # stripped from the eval environment.
        value = eval(key, {"__builtins__": {}}, sample_context)
        if isinstance(value, torch.Tensor):
            sample_inputs.append(value)
            input_shapes[key] = list(value.shape)
        else:
            raise ValueError(f"Input '{key}' is not a tensor: {type(value)}")
    # Run forward pass to get output shapes.
    with torch.no_grad():
        sample_outputs = module(*sample_inputs)
    output_shapes = {
        key: list(out.shape) for key, out in zip(output_keys, sample_outputs)
    }
    log.info(f"Observation export - Input shapes: {input_shapes}")
    log.info(f"Observation export - Output shapes: {output_shapes}")

    # Create ONNX input/output names (sanitize for ONNX compatibility).
    def sanitize_name(name: str) -> str:
        return name.replace(".", "_").replace("[", "_").replace("]", "_").replace(":", "_")

    onnx_input_names = [sanitize_name(k) for k in input_keys]
    onnx_output_names = [sanitize_name(k) for k in output_keys]
    # Export to ONNX.
    log.info(f"Exporting observations to {path}...")
    torch.onnx.export(
        module,
        tuple(sample_inputs),
        path,
        input_names=onnx_input_names,
        output_names=onnx_output_names,
        opset_version=17,
        do_constant_folding=True,
        dynamic_axes={
            **{name: {0: "batch_size"} for name in onnx_input_names},
            **{name: {0: "batch_size"} for name in onnx_output_names},
        },
        dynamo=False,
    )
    # Load the exported model to get ACTUAL input/output names.
    # ONNX may rename inputs if there are graph-level issues.
    import onnxruntime as ort

    session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    actual_onnx_in_names = [inp.name for inp in session.get_inputs()]
    actual_onnx_out_names = [out.name for out in session.get_outputs()]
    log.info(f"Requested ONNX input names: {onnx_input_names}")
    log.info(f"Actual ONNX input names: {actual_onnx_in_names}")
    if len(actual_onnx_in_names) != len(onnx_input_names):
        log.warning(f"ONNX has {len(actual_onnx_in_names)} inputs but we expected {len(onnx_input_names)}!")
    # Build mapping from actual ONNX input name to semantic key.
    # This handles cases where ONNX adds suffixes like .1, .2.
    onnx_name_to_in_key = {}
    for onnx_name in actual_onnx_in_names:
        matched = False
        # First try exact match with sanitized names.
        for i, expected_name in enumerate(onnx_input_names):
            if onnx_name == expected_name:
                onnx_name_to_in_key[onnx_name] = input_keys[i]
                matched = True
                break
        if not matched:
            # Try matching by stripping .1, .2, etc. suffixes.
            base_name = onnx_name.rsplit(".", 1)[0] if "." in onnx_name else onnx_name
            # Also try removing trailing digits after an underscore
            # (e.g. previous_actions_1). Guard against an empty base_name
            # before indexing its last character.
            base_name_alt = (
                base_name.rsplit("_", 1)[0]
                if base_name and base_name[-1].isdigit()
                else base_name
            )
            for i, expected_name in enumerate(onnx_input_names):
                if base_name == expected_name or base_name_alt == expected_name:
                    onnx_name_to_in_key[onnx_name] = input_keys[i]
                    matched = True
                    break
        if not matched:
            log.warning(f"Could not match ONNX input '{onnx_name}' to any semantic key")
    log.info(f"ONNX name to semantic key mapping: {onnx_name_to_in_key}")
    # Save metadata with ACTUAL ONNX names.
    metadata = {
        "type": "observations",
        "in_keys": input_keys,
        "out_keys": output_keys,
        "onnx_in_names": actual_onnx_in_names,  # Use actual names
        "onnx_out_names": actual_onnx_out_names,  # Use actual names
        "onnx_name_to_in_key": onnx_name_to_in_key,  # Reverse mapping for inference
        "input_shapes": input_shapes,
        "output_shapes": output_shapes,
    }
    if meta:
        metadata.update(meta)
    # Use with_suffix so only the file extension is replaced (str.replace
    # would hit the first ".onnx" substring anywhere in the path).
    meta_path = str(Path(path).with_suffix(".json"))
    with open(meta_path, "w") as f:
        json.dump(metadata, f, indent=2)
    log.info(f"✓ Observations exported to {path}")
    log.info(f"✓ Metadata saved to {meta_path}")
    # Validate with onnxruntime.
    if validate:
        try:
            import numpy as np

            log.info("Validating with onnxruntime...")
            # Build input dict: for each ONNX input, find the semantic key and get the value.
            input_key_to_value = {key: inp for key, inp in zip(input_keys, sample_inputs)}
            onnx_inputs = {}
            for onnx_name in actual_onnx_in_names:
                if onnx_name in onnx_name_to_in_key:
                    semantic_key = onnx_name_to_in_key[onnx_name]
                    onnx_inputs[onnx_name] = input_key_to_value[semantic_key].detach().cpu().numpy()
                else:
                    log.warning(f"No value for ONNX input '{onnx_name}'")
            # Request outputs by the names we exported so the result order is
            # guaranteed to match output_keys (the session's own output order
            # may differ from the requested one).
            onnx_outputs = session.run(onnx_output_names, onnx_inputs)
            # Compare with PyTorch outputs.
            for i, (key, onnx_out) in enumerate(zip(output_keys, onnx_outputs)):
                pytorch_out = sample_outputs[i].detach().cpu().numpy()
                max_diff = np.abs(onnx_out - pytorch_out).max()
                log.info(f" {key}: max_diff = {max_diff:.6e}")
                if max_diff > 1e-4:
                    log.warning(f" ⚠ Large difference detected for {key}")
            log.info("✓ Validation passed")
        except ImportError:
            log.warning("onnxruntime not installed, skipping validation")
        except Exception as e:
            log.error(f"Validation failed: {e}")
            raise
    return path