SeparateActorCriticLearner

SeparateActorCriticLearner uses two distinct gradient boosted tree learners to represent an actor and a critic (value function). It is useful when the actor and critic need to be trained independently with different tree configurations or update rules. The class is a wrapper around MultiGBTLearner and supports training, prediction, saving/loading, and SHAP value computation per ensemble.

class gbrl.learners.actor_critic_learner.SeparateActorCriticLearner(input_dim: int, output_dim: int, tree_struct: Dict, policy_optimizer: Dict, value_optimizer: Dict, params: Dict = {}, verbose: int = 0, device: str = 'cpu')[source]

Bases: MultiGBTLearner

Implements a separate actor-critic learner using two independent gradient boosted trees.

This class extends MultiGBTLearner by maintaining two separate models:
  • One for policy learning (the actor).

  • One for value estimation (the critic).

It provides separate step_actor and step_critic methods for updating the respective models.
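A minimal construction sketch. The configuration keys below (`max_depth`, `n_bins`, `algo`, `lr`) and the dimension values are illustrative assumptions, not a definitive recipe; consult your GBRL version for the exact supported keys. The import is guarded so the sketch runs even where GBRL is not installed.

```python
# Hypothetical configuration; all keys are assumptions, check your GBRL version.
tree_struct = {
    "max_depth": 4,   # assumed key: depth of each boosted tree
    "n_bins": 256,    # assumed key: histogram bins used for split search
}
policy_optimizer = {"algo": "SGD", "lr": 1e-2}  # assumed keys
value_optimizer = {"algo": "SGD", "lr": 1e-2}   # assumed keys

try:
    # Guarded import so the sketch degrades gracefully without GBRL.
    from gbrl.learners.actor_critic_learner import SeparateActorCriticLearner

    learner = SeparateActorCriticLearner(
        input_dim=8,              # observation dimension (assumed)
        output_dim=4,             # policy output dimension (assumed)
        tree_struct=tree_struct,
        policy_optimizer=policy_optimizer,
        value_optimizer=value_optimizer,
        device="cpu",
    )
except ImportError:
    learner = None  # GBRL not available in this environment
```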

distil(obs: numpy.ndarray | torch.Tensor, policy_targets: numpy.ndarray, value_targets: numpy.ndarray, params: Dict, verbose: int = 0) Tuple[List[float], List[Dict]][source]

Distills the trained actor and critic models into a student model.

Parameters:
  • obs (NumericalData) – Input observations.

  • policy_targets (np.ndarray) – Target values for the policy (actor).

  • value_targets (np.ndarray) – Target values for the value function (critic).

  • params (Dict) – Distillation parameters.

  • verbose (int) – Verbosity level.

Returns:

The final loss values and updated parameters for distillation.

Return type:

Tuple[List[float], List[Dict]]
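A sketch of the call contract: observations and per-ensemble targets must share a batch dimension. The arrays below are hypothetical stand-ins for a teacher's `predict_policy` / `predict_critic` outputs, and the `max_steps` key in the params dict is an assumption, not a documented GBRL option.

```python
import numpy as np

rng = np.random.default_rng(0)
obs = rng.standard_normal((64, 8)).astype(np.float32)             # 64 observations, 8 features
policy_targets = rng.standard_normal((64, 4)).astype(np.float32)  # stand-in teacher policy outputs
value_targets = rng.standard_normal((64, 1)).astype(np.float32)   # stand-in teacher value outputs

distil_params = {"max_steps": 100}  # assumed key; check your GBRL version

# Hypothetical call against a constructed learner (not executed here):
# losses, updated_params = learner.distil(
#     obs, policy_targets, value_targets, distil_params, verbose=0)
```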

predict_critic(obs: numpy.ndarray | torch.Tensor, requires_grad: bool = True, start_idx: int | None = None, stop_idx: int | None = None, tensor: bool = True) numpy.ndarray | torch.Tensor[source]

Predicts the value function (critic) output for the given observations.

Parameters:
  • obs (NumericalData) – Input observations.

  • requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.

  • start_idx (int, optional) – Start index for prediction. Defaults to None.

  • stop_idx (int, optional) – Stop index for prediction. Defaults to None.

  • tensor (bool, optional) – Whether to return a tensor. Defaults to True.

Returns:

Predicted value function outputs.

Return type:

NumericalData

predict_policy(obs: numpy.ndarray | torch.Tensor, requires_grad: bool = True, start_idx: int | None = None, stop_idx: int | None = None, tensor: bool = True) numpy.ndarray | torch.Tensor[source]

Predicts the policy (actor) output for the given observations.

Parameters:
  • obs (NumericalData) – Input observations.

  • requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.

  • start_idx (int, optional) – Start index for prediction. Defaults to None.

  • stop_idx (int, optional) – Stop index for prediction. Defaults to None.

  • tensor (bool, optional) – Whether to return a tensor. Defaults to True.

Returns:

Predicted policy outputs.

Return type:

NumericalData
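Both predict methods accept `start_idx` / `stop_idx` to select a contiguous range of trees. A toy NumPy sketch of that semantics, under the assumption that a boosted ensemble's prediction is the sum of per-tree contributions over the selected index range:

```python
import numpy as np

# Toy ensemble: each row is one tree's additive contribution
# for a single observation (5 trees, 3 outputs).
tree_contribs = np.array([
    [0.5, 0.0, 0.1],
    [0.2, 0.1, 0.0],
    [0.1, 0.3, 0.2],
    [0.0, 0.2, 0.1],
    [0.1, 0.1, 0.1],
])

def partial_predict(contribs, start_idx=None, stop_idx=None):
    """Sum tree contributions over [start_idx, stop_idx) --
    the assumed semantics of predict_policy / predict_critic."""
    return contribs[start_idx:stop_idx].sum(axis=0)

full = partial_predict(tree_contribs)               # all trees
head = partial_predict(tree_contribs, stop_idx=2)   # first two trees
tail = partial_predict(tree_contribs, start_idx=2)  # remaining trees
# Partial predictions of an additive ensemble compose:
assert np.allclose(head + tail, full)
```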

step_actor(inputs: numpy.ndarray | torch.Tensor, grads: numpy.ndarray | torch.Tensor) None[source]

Performs a gradient update step for the policy (actor) model.

Parameters:
  • inputs (NumericalData) – Input observations.

  • grads (NumericalData) – Gradient update for the policy (actor).

step_critic(inputs: numpy.ndarray | torch.Tensor, grads: numpy.ndarray | torch.Tensor) None[source]

Performs a gradient update step for the value function (critic) model.

Parameters:
  • inputs (NumericalData) – Input observations.

  • grads (NumericalData) – Gradient update for the value function (critic).
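Unlike target-fitting APIs, step_actor and step_critic consume raw gradients computed by the caller. A minimal NumPy sketch of one update, assuming a 0.5·(V − R)² critic loss and a softmax policy with a −A·log π(a|s) actor loss; the arrays standing in for learner predictions and the commented learner calls are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)
obs = rng.standard_normal((32, 8)).astype(np.float32)      # batch of observations
returns = rng.standard_normal((32, 1)).astype(np.float32)  # sampled returns
actions = rng.integers(0, 4, size=32)                      # sampled discrete actions

# Stand-ins for predict_critic / predict_policy outputs
# (with tensor=False the real calls would return numpy arrays).
values = rng.standard_normal((32, 1)).astype(np.float32)
logits = rng.standard_normal((32, 4)).astype(np.float32)

# Critic gradient for a 0.5 * (V - R)^2 loss:
value_grad = values - returns

# Actor gradient of -A * log pi(a|s) w.r.t. softmax logits:
advantage = returns - values
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
onehot = np.eye(4, dtype=np.float32)[actions]
policy_grad = advantage * (probs - onehot)

# Hypothetical calls against a constructed learner (not executed here):
# learner.step_critic(obs, value_grad)
# learner.step_actor(obs, policy_grad)
```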