SeparateActorCriticLearner
SeparateActorCriticLearner uses two distinct gradient boosted tree learners to represent an actor and a critic (value function). It is useful when the actor and critic must be trained independently, with different tree configurations or update rules. The class is a wrapper around MultiGBTLearner and supports training, prediction, saving/loading, and per-ensemble SHAP value computation.
- class gbrl.learners.actor_critic_learner.SeparateActorCriticLearner(input_dim: int, output_dim: int, tree_struct: Dict, policy_optimizer: Dict, value_optimizer: Dict, params: Dict = {}, verbose: int = 0, device: str = 'cpu')[source]
Bases: MultiGBTLearner
Implements a separate actor-critic learner using two independent gradient boosted tree models.
This class extends MultiGBTLearner by maintaining two separate models:
- One for policy learning (actor).
- One for value estimation (critic).
It provides separate step_actor and step_critic methods for updating the respective models.
- distil(obs: numpy.ndarray | torch.Tensor, policy_targets: numpy.ndarray, value_targets: numpy.ndarray, params: Dict, verbose: int = 0) Tuple[List[float], List[Dict]][source]
Distills the trained model into a student model.
- Parameters:
obs (NumericalData) – Input observations.
policy_targets (np.ndarray) – Target values for the policy (actor).
value_targets (np.ndarray) – Target values for the value function (critic).
params (Dict) – Distillation parameters.
verbose (int) – Verbosity level.
- Returns:
The final loss values and updated parameters for distillation.
- Return type:
Tuple[List[float], List[Dict]]
- predict_critic(obs: numpy.ndarray | torch.Tensor, requires_grad: bool = True, start_idx: int | None = None, stop_idx: int | None = None, tensor: bool = True) numpy.ndarray | torch.Tensor[source]
Predicts the value function (critic) output for the given observations.
- Parameters:
obs (NumericalData) – Input observations.
requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.
start_idx (int, optional) – Start index for prediction. Defaults to None.
stop_idx (int, optional) – Stop index for prediction. Defaults to None.
tensor (bool, optional) – Whether to return a tensor. Defaults to True.
- Returns:
Predicted value function outputs.
- Return type:
NumericalData
- predict_policy(obs: numpy.ndarray | torch.Tensor, requires_grad: bool = True, start_idx: int | None = None, stop_idx: int | None = None, tensor: bool = True) numpy.ndarray | torch.Tensor[source]
Predicts the policy (actor) output for the given observations.
- Parameters:
obs (NumericalData) – Input observations.
requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.
start_idx (int, optional) – Start index for prediction. Defaults to None.
stop_idx (int, optional) – Stop index for prediction. Defaults to None.
tensor (bool, optional) – Whether to return a tensor. Defaults to True.
- Returns:
Predicted policy outputs.
- Return type:
NumericalData
- step_actor(inputs: numpy.ndarray | torch.Tensor, grads: numpy.ndarray | torch.Tensor) None[source]
Performs a gradient update step for the policy (actor) model.
- Parameters:
inputs (NumericalData) – Input observations.
grads (NumericalData) – Gradient of the loss with respect to the policy (actor) outputs.
- step_critic(inputs: numpy.ndarray | torch.Tensor, grads: numpy.ndarray | torch.Tensor) None[source]
Performs a gradient update step for the value function (critic) model.
- Parameters:
inputs (NumericalData) – Input observations.
grads (NumericalData) – Gradient of the loss with respect to the value function (critic) outputs.