SharedActorCriticLearner

SharedActorCriticLearner uses a single gradient boosted tree learner to represent both an actor and a critic (value function). It is useful when the actor and critic need to share tree configurations or update rules. It is a wrapper around GBTLearner and supports training, prediction, saving/loading, and SHAP value computation per ensemble.

class gbrl.learners.actor_critic_learner.SharedActorCriticLearner(input_dim: int, output_dim: int, tree_struct: Dict, policy_optimizer: Dict, value_optimizer: Dict, params: Dict = {}, verbose: int = 0, device: str = 'cpu', name: str = 'SharedActorCritic')[source]

Bases: GBTLearner

SharedActorCriticLearner is a variant of GBTLearner in which a single tree ensemble is used for both actor (policy) and critic (value) learning. It uses gradient boosted trees (GBTs) to estimate both policy and value function parameters efficiently.
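
A minimal construction sketch, assuming illustrative values for the configuration dictionaries (the key names inside tree_struct and the optimizer dictionaries are assumptions; consult GBTLearner for the supported keys):

>>> from gbrl.learners.actor_critic_learner import SharedActorCriticLearner
>>> # Hypothetical configuration dictionaries; key names are assumptions.
>>> tree_struct = {'max_depth': 4}
>>> policy_optimizer = {'algo': 'SGD', 'lr': 1.0e-2}
>>> value_optimizer = {'algo': 'SGD', 'lr': 1.0e-2}
>>> learner = SharedActorCriticLearner(
...     input_dim=8,    # number of input features
...     output_dim=3,   # assumption: policy outputs plus one value output
...     tree_struct=tree_struct,
...     policy_optimizer=policy_optimizer,
...     value_optimizer=value_optimizer,
...     device='cpu',
... )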

compress(trees_to_keep: int, gradient_steps: int, features: numpy.ndarray | torch.Tensor, actions: torch.Tensor | None = None, log_std: torch.Tensor | None = None, method: str = 'first_k', dist_type: str = 'deterministic', optimizer_kwargs: Dict[str, Any] | None = None, temperature: float = 1.0, lambda_reg: float = 1.0, **kwargs) float[source]

Compresses the tree ensemble by selecting and retraining a subset of trees.

Parameters:
  • trees_to_keep (int) – Number of trees to retain in the compressed model.

  • gradient_steps (int) – Number of optimization steps during compression.

  • features (NumericalData) – Input feature matrix (n_samples, n_features).

  • actions (th.Tensor, optional) – Target actions (for policy compression). Required unless dist_type is ‘deterministic’ or ‘supervised_learning’.

  • log_std (th.Tensor, optional) – Log standard deviation (only used for certain policy types).

  • method (str) – Tree selection method. Defaults to ‘first_k’.

  • dist_type (str) – Compression type. Supported: ‘deterministic’, ‘supervised_learning’, ‘categorical’, ‘gaussian’. For ‘deterministic’ and ‘supervised_learning’, actions are not required.

  • optimizer_kwargs (dict, optional) – Optimizer configuration.

  • temperature (float) – Temperature parameter for soft selection.

  • lambda_reg (float) – L2 regularization coefficient on weights.

  • **kwargs – Additional keyword arguments passed to the compressor.

Returns:

Final loss value after compression.

Return type:

float
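
A hedged usage sketch, continuing the construction example above. With dist_type='deterministic' no target actions are required; the array shapes and argument values are illustrative:

>>> import numpy as np
>>> features = np.random.randn(512, 8).astype(np.float32)  # (n_samples, n_features)
>>> final_loss = learner.compress(
...     trees_to_keep=50,           # retain 50 trees
...     gradient_steps=100,         # retraining steps after selection
...     features=features,
...     method='first_k',           # default selection method
...     dist_type='deterministic',  # actions not required for this type
... )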

distil(obs: numpy.ndarray, policy_targets: numpy.ndarray, value_targets: numpy.ndarray, params: Dict, verbose: int = 0) Tuple[float, Dict][source]

Distills the trained model into a student model.

Parameters:
  • obs (np.ndarray) – Input observations.

  • policy_targets (np.ndarray) – Target values for the policy (actor).

  • value_targets (np.ndarray) – Target values for the value function (critic).

  • params (Dict) – Distillation parameters.

  • verbose (int) – Verbosity level.

Returns:

The final loss value and the updated distillation parameters.

Return type:

Tuple[float, Dict]
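
A hedged sketch, continuing the example above. Distillation targets would typically be the trained model's own predictions; the key in the params dictionary below is hypothetical:

>>> obs = np.random.randn(512, 8).astype(np.float32)
>>> # Use the trained model's own outputs as distillation targets.
>>> policy_targets, value_targets = learner.predict(obs, requires_grad=False, tensor=False)
>>> distil_params = {'max_steps': 200}  # hypothetical key
>>> loss, updated_params = learner.distil(
...     obs, policy_targets, value_targets, params=distil_params, verbose=0,
... )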

predict(inputs: numpy.ndarray | torch.Tensor, requires_grad: bool = True, start_idx: int | None = None, stop_idx: int | None = None, tensor: bool = True) Tuple[numpy.ndarray | torch.Tensor, numpy.ndarray | torch.Tensor][source]

Predicts both policy and value function outputs.

Parameters:
  • inputs (NumericalData) – Input observations.

  • requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.

  • start_idx (int, optional) – Start index for prediction. Defaults to 0.

  • stop_idx (int, optional) – Stop index for prediction. Defaults to None.

  • tensor (bool, optional) – Whether to return a tensor. Defaults to True.

Returns:

Predicted policy and value outputs.

Return type:

Tuple[NumericalData, NumericalData]
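
For example, continuing the construction example above (NumPy in, NumPy out when tensor=False; the comments on the outputs are assumptions):

>>> obs = np.random.randn(32, 8).astype(np.float32)
>>> theta, values = learner.predict(obs, requires_grad=False, tensor=False)
>>> # theta: per-sample policy parameters; values: per-sample critic estimates.
>>> # The split between policy and value outputs is handled internally.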

predict_critic(obs: numpy.ndarray | torch.Tensor, requires_grad: bool = True, start_idx: int | None = None, stop_idx: int | None = None, tensor: bool = True) numpy.ndarray | torch.Tensor[source]

Predicts the value function (critic) output for the given observations.

Parameters:
  • obs (NumericalData) – Input observations.

  • requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.

  • start_idx (int, optional) – Start index for prediction. Defaults to 0.

  • stop_idx (int, optional) – Stop index for prediction. Defaults to None.

  • tensor (bool, optional) – Whether to return a tensor. Defaults to True.

Returns:

Predicted value function outputs.

Return type:

NumericalData
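
For example, evaluating the critic alone without gradient tracking (arguments illustrative, continuing the example above):

>>> values = learner.predict_critic(obs, requires_grad=False, tensor=False)
>>> # With tensor=True (the default), a torch.Tensor is returned instead.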

predict_policy(obs: numpy.ndarray | torch.Tensor, requires_grad: bool = True, start_idx: int | None = None, stop_idx: int | None = None, tensor: bool = True) numpy.ndarray | torch.Tensor[source]

Predicts the policy (actor) output for the given observations.

Parameters:
  • obs (NumericalData) – Input observations.

  • requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.

  • start_idx (int, optional) – Start index for prediction. Defaults to 0.

  • stop_idx (int, optional) – Stop index for prediction. Defaults to None.

  • tensor (bool, optional) – Whether to return a tensor. Defaults to True.

Returns:

Predicted policy outputs.

Return type:

NumericalData
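
For example, evaluating the policy head alone, optionally over a prefix of the ensemble (the assumption here is that start_idx/stop_idx index trees within the ensemble):

>>> theta = learner.predict_policy(obs, requires_grad=False, tensor=False)
>>> # Assumption: restrict the prediction to the first 10 trees.
>>> theta_partial = learner.predict_policy(obs, start_idx=0, stop_idx=10,
...                                        requires_grad=False, tensor=False)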