SharedActorCriticLearner

SharedActorCriticLearner uses a single gradient boosted tree learner to represent both an actor (policy) and a critic (value function). It is useful when the actor and critic should share a tree configuration and update rules. It extends GBTLearner and supports training, prediction, saving/loading, and SHAP value computation per ensemble.

class gbrl.learners.actor_critic_learner.SharedActorCriticLearner(input_dim: int, output_dim: int, tree_struct: Dict, policy_optimizer: Dict, value_optimizer: Dict, params: Dict = {}, verbose: int = 0, device: str = 'cpu')

Bases: GBTLearner

SharedActorCriticLearner is a variant of GBTLearner in which a single ensemble of trees is shared by the actor (policy) and the critic (value), so both are learned with one tree structure. It uses gradient boosted trees (GBTs) to estimate both the policy and the value function outputs.
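To make the idea of a shared tree concrete, here is a minimal NumPy sketch that does not use the library itself. It assumes depth-1 trees (stumps) whose leaves each store one vector of length policy_dim + 1, and it assumes the value estimate sits in the last column. SharedStump and shared_predict are hypothetical names chosen for illustration only.

```python
import numpy as np


class SharedStump:
    """Hypothetical depth-1 tree whose leaves store policy and value outputs together."""

    def __init__(self, feature: int, threshold: float,
                 left: np.ndarray, right: np.ndarray) -> None:
        self.feature = feature      # index of the split feature
        self.threshold = threshold  # samples with x[feature] <= threshold go left
        self.left = left            # leaf vector of length policy_dim + 1
        self.right = right

    def predict(self, obs: np.ndarray) -> np.ndarray:
        goes_left = obs[:, self.feature] <= self.threshold
        # One split decision selects a leaf vector for both heads at once.
        return np.where(goes_left[:, None], self.left, self.right)


def shared_predict(trees, obs):
    """Sum leaf vectors over the ensemble, then split into (policy, value)."""
    out = sum(tree.predict(obs) for tree in trees)
    # Assumption: the value estimate is stored in the last output column.
    return out[:, :-1], out[:, -1]


trees = [
    SharedStump(0, 0.5, np.array([1.0, -1.0, 0.2]), np.array([-1.0, 1.0, 0.4])),
    SharedStump(1, 0.0, np.array([0.5, 0.5, 0.1]), np.array([0.0, 0.0, -0.1])),
]
obs = np.array([[0.0, -1.0], [1.0, 1.0]])
policy, value = shared_predict(trees, obs)
```

Because every split routes a sample to one leaf for both heads, the actor and critic cannot choose different partitions of the observation space. That constraint is the trade-off for sharing a single set of trees.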

distil(obs: ndarray | Tensor, policy_targets: ndarray, value_targets: ndarray, params: Dict, verbose: int) → Tuple[float, Dict]

Distills the trained model into a student model.

Parameters:
  • obs (NumericalData) – Input observations.

  • policy_targets (np.ndarray) – Target values for the policy (actor).

  • value_targets (np.ndarray) – Target values for the value function (critic).

  • params (Dict) – Distillation parameters.

  • verbose (int) – Verbosity level.

Returns:

The final loss value and updated parameters for distillation.

Return type:

Tuple[float, Dict]
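The exact distillation objective is not specified here. As one plausible illustration, the sketch below computes a mean squared error between a student's outputs and the policy and value targets, and returns a (float, dict) pair that mirrors the documented return type. The function name distillation_loss is hypothetical, and its dictionary holds per-head losses purely for illustration; it is not the contents of the real method's returned dictionary.

```python
import numpy as np
from typing import Dict, Tuple


def distillation_loss(student_policy: np.ndarray, student_value: np.ndarray,
                      policy_targets: np.ndarray, value_targets: np.ndarray
                      ) -> Tuple[float, Dict[str, float]]:
    """Hypothetical MSE objective for distilling both heads of a shared model."""
    policy_loss = float(np.mean((student_policy - policy_targets) ** 2))
    value_loss = float(np.mean((student_value - value_targets) ** 2))
    # Both heads share the same trees, so one scalar combines the two errors.
    total = policy_loss + value_loss
    return total, {"policy_loss": policy_loss, "value_loss": value_loss}


loss, info = distillation_loss(
    student_policy=np.array([[0.0, 1.0], [1.0, 0.0]]),
    student_value=np.array([0.5, 0.5]),
    policy_targets=np.array([[0.0, 1.0], [0.0, 0.0]]),
    value_targets=np.array([1.0, 0.5]),
)
```

Summing the two losses with equal weight is a simplifying assumption. A real implementation might weight the heads differently.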

predict(obs: ndarray | Tensor, requires_grad: bool = True, start_idx: int = 0, stop_idx: int = None, tensor: bool = True) → Tuple[ndarray, ndarray]

Predicts both policy and value function outputs.

Parameters:
  • obs (NumericalData) – Input observations.

  • requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.

  • start_idx (int, optional) – Start index for prediction. Defaults to 0.

  • stop_idx (int, optional) – Stop index for prediction. Defaults to None.

  • tensor (bool, optional) – Whether to return a tensor. Defaults to True.

Returns:

Predicted policy and value outputs.

Return type:

Tuple[np.ndarray, np.ndarray]
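The start_idx and stop_idx parameters are easiest to picture if they are assumed to select a half-open range of trees, like Python slicing, with stop_idx = None meaning "through the last tree". The sketch below is written under that assumption. predict_range is a hypothetical helper, and any bias or initial prediction the real model may add is ignored.

```python
import numpy as np
from typing import List, Optional


def predict_range(tree_outputs: List[np.ndarray], start_idx: int = 0,
                  stop_idx: Optional[int] = None) -> np.ndarray:
    """Sum per-tree outputs over trees[start_idx:stop_idx] (hypothetical helper)."""
    if stop_idx is None:
        stop_idx = len(tree_outputs)  # None means "up to the last tree"
    selected = tree_outputs[start_idx:stop_idx]
    if not selected:
        raise ValueError("empty tree range")
    return np.sum(selected, axis=0)


# Three trees, each contributing a (n_samples, output_dim) array.
per_tree = [np.full((2, 3), 1.0), np.full((2, 3), 10.0), np.full((2, 3), 100.0)]
full = predict_range(per_tree)                 # all trees
first_two = predict_range(per_tree, 0, 2)      # trees 0 and 1
last_only = predict_range(per_tree, start_idx=2)
```

Restricting the range this way can be useful for inspecting how much the early or late trees contribute to a prediction.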

predict_critic(obs: ndarray | Tensor, requires_grad: bool = True, start_idx: int = 0, stop_idx: int = None, tensor: bool = True)

Predicts the value function (critic) output for the given observations.

Parameters:
  • obs (NumericalData) – Input observations.

  • requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.

  • start_idx (int, optional) – Start index for prediction. Defaults to 0.

  • stop_idx (int, optional) – Stop index for prediction. Defaults to None.

  • tensor (bool, optional) – Whether to return a tensor. Defaults to True.

Returns:

Predicted value function outputs.

Return type:

np.ndarray
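The value_grad argument of step is a gradient of some critic loss with respect to the predicted values. One common choice, assumed here, is the squared error 0.5 * (V - R)^2 against a return target R, whose derivative with respect to V is simply V - R. critic_gradient is a hypothetical helper, and the returns are made-up example targets.

```python
import numpy as np


def critic_gradient(values: np.ndarray, returns: np.ndarray) -> np.ndarray:
    """Gradient of 0.5 * (V - R)**2 with respect to each predicted value V."""
    return values - returns


values = np.array([0.5, 2.0, -1.0])   # e.g. outputs of predict_critic
returns = np.array([1.0, 2.0, 0.0])   # hypothetical bootstrapped return targets
value_grad = critic_gradient(values, returns)
```

If the loss is averaged over the batch instead of summed, each entry would additionally be divided by the batch size.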

predict_policy(obs: ndarray | Tensor, requires_grad: bool = True, start_idx: int = 0, stop_idx: int = None, tensor: bool = True)

Predicts the policy (actor) output for the given observations.

Parameters:
  • obs (NumericalData) – Input observations.

  • requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.

  • start_idx (int, optional) – Start index for prediction. Defaults to 0.

  • stop_idx (int, optional) – Stop index for prediction. Defaults to None.

  • tensor (bool, optional) – Whether to return a tensor. Defaults to True.

Returns:

Predicted policy outputs.

Return type:

np.ndarray
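Similarly, theta_grad can be thought of as a loss gradient with respect to the policy outputs. As an example, the sketch below assumes a discrete action space where the policy outputs are softmax logits and the loss is the vanilla policy-gradient term -A * log pi(a). The gradient of that loss with respect to the logits is A * (softmax(logits) - onehot(a)). The helper names and example values are hypothetical.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def policy_gradient(logits: np.ndarray, actions: np.ndarray,
                    advantages: np.ndarray) -> np.ndarray:
    """Gradient of -A * log softmax(logits)[a] with respect to the logits."""
    probs = softmax(logits)
    one_hot = np.zeros_like(probs)
    one_hot[np.arange(len(actions)), actions] = 1.0
    return advantages[:, None] * (probs - one_hot)


logits = np.array([[0.0, 0.0], [1.0, -1.0]])  # e.g. outputs of predict_policy
actions = np.array([0, 1])                    # actions taken in each state
advantages = np.array([1.0, 2.0])             # hypothetical advantage estimates
theta_grad = policy_gradient(logits, actions, advantages)
```

Each row of the result sums to zero, because the softmax probabilities and the one-hot vector both sum to one.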

step(obs: ndarray | Tensor, theta_grad: ndarray, value_grad: ndarray) → None

Performs a gradient update step for both policy and value function.

Parameters:
  • obs (NumericalData) – Input observations.

  • theta_grad (np.ndarray) – Gradient of the policy parameters.

  • value_grad (np.ndarray) – Gradient of the value function parameters.
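A gradient boosting step for a shared tree can be sketched as follows. The policy and value gradients are concatenated into one target, a single split is chosen that fits both, and each leaf takes a descent step on its mean gradient. Because the constructor accepts separate policy_optimizer and value_optimizer settings, the sketch assumes each head has its own learning rate (policy_lr and value_lr). Those names, fit_shared_stump, and shared_step are all hypothetical. The sketch also uses a depth-1 tree and a squared-error split criterion for simplicity, which is not necessarily how the library grows its trees.

```python
import numpy as np
from typing import Tuple


def fit_shared_stump(obs: np.ndarray, targets: np.ndarray
                     ) -> Tuple[int, float, np.ndarray, np.ndarray]:
    """Greedy depth-1 split shared by all output columns (squared-error criterion)."""
    best = None
    for feature in range(obs.shape[1]):
        # Dropping the largest value keeps both sides of every split non-empty.
        for threshold in np.unique(obs[:, feature])[:-1]:
            left = obs[:, feature] <= threshold
            l_mean = targets[left].mean(axis=0)
            r_mean = targets[~left].mean(axis=0)
            # The split score sums the error over policy AND value columns.
            sse = (((targets[left] - l_mean) ** 2).sum()
                   + ((targets[~left] - r_mean) ** 2).sum())
            if best is None or sse < best[0]:
                best = (sse, feature, float(threshold), l_mean, r_mean)
    if best is None:
        raise ValueError("no valid split: every feature is constant")
    _, feature, threshold, l_mean, r_mean = best
    return feature, threshold, l_mean, r_mean


def shared_step(obs: np.ndarray, theta_grad: np.ndarray, value_grad: np.ndarray,
                policy_lr: float, value_lr: float):
    """One hypothetical boosting step fitting a shared tree to the gradients."""
    # Shared target: policy gradient columns followed by the value gradient column.
    grads = np.hstack([theta_grad, value_grad.reshape(-1, 1)])
    feature, threshold, l_grad, r_grad = fit_shared_stump(obs, grads)
    # Separate step sizes per head (assumed to come from the two optimizer configs).
    lr = np.append(np.full(theta_grad.shape[1], policy_lr), value_lr)
    # Descend: each leaf value is -lr times the mean gradient in that leaf.
    return feature, threshold, -lr * l_grad, -lr * r_grad


obs = np.array([[0.0], [0.0], [1.0], [1.0]])
theta_grad = np.array([[1.0, -1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, 1.0]])
value_grad = np.array([0.5, 0.5, -0.5, -0.5])
feature, threshold, left_leaf, right_leaf = shared_step(
    obs, theta_grad, value_grad, policy_lr=0.1, value_lr=0.01)
```

The new tree's leaf vectors would then be added to the ensemble, so later calls to predict_policy and predict_critic both reflect this one update.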