SharedActorCriticLearner
SharedActorCriticLearner uses a single gradient boosted tree learner to represent both an actor (policy) and a critic (value function). This is useful when the actor and critic should share tree configurations or update rules. It wraps GBTLearner and supports training, prediction, saving/loading, and SHAP value computation per ensemble.
- class gbrl.learners.actor_critic_learner.SharedActorCriticLearner(input_dim: int, output_dim: int, tree_struct: Dict, policy_optimizer: Dict, value_optimizer: Dict, params: Dict = {}, verbose: int = 0, device: str = 'cpu')
Bases: GBTLearner
SharedActorCriticLearner is a variant of GBTLearner in which a single tree ensemble is used for both actor (policy) and critic (value) learning. It uses gradient boosted trees (GBTs) to estimate policy and value-function outputs with one ensemble instead of two, so both outputs share the same tree splits.
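As a rough mental model only, the sketch below treats the shared ensemble as a list of per-tree output matrices (one row per observation) that are summed on top of a bias. It assumes the critic's value estimate occupies the last output column and the remaining columns belong to the actor; `shared_ensemble_output` and `split_actor_critic` are hypothetical names, not part of gbrl.

```python
from typing import List, Tuple

import numpy as np

# Hypothetical illustration, not the gbrl API.
# ASSUMPTION: the value estimate is stored in the last output column.


def shared_ensemble_output(tree_outputs: List[np.ndarray], bias: np.ndarray) -> np.ndarray:
    """Sum every tree's (n_obs, output_dim) contribution on top of a bias vector."""
    return bias + np.sum(tree_outputs, axis=0)


def split_actor_critic(shared: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Split one shared output matrix into actor and critic parts."""
    policy = shared[:, :-1]  # every column except the last: actor outputs
    value = shared[:, -1]    # last column: critic (value) estimate
    return policy, value


# Two observations, two trees, 3 policy outputs + 1 value output.
trees = [np.ones((2, 4)), np.full((2, 4), 0.5)]
shared = shared_ensemble_output(trees, bias=np.zeros(4))
policy, value = split_actor_critic(shared)
print(policy.shape, value.shape)  # (2, 3) (2,)
```

Because every tree contributes to all columns at once, a split chosen to improve the policy also shapes the value estimate, which is the main trade-off of sharing one ensemble.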
- distil(obs: ndarray | Tensor, policy_targets: ndarray, value_targets: ndarray, params: Dict, verbose: int) → Tuple[float, Dict]
Distills the trained model into a student model.
- Parameters:
obs (NumericalData) – Input observations.
policy_targets (np.ndarray) – Target values for the policy (actor).
value_targets (np.ndarray) – Target values for the value function (critic).
params (Dict) – Distillation parameters.
verbose (int) – Verbosity level.
- Returns:
The final loss value and the updated distillation parameters.
- Return type:
Tuple[float, Dict]
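`distil` fits a student to teacher-provided policy and value targets. The sketch below shows the idea with a deliberately simple stand-in: a linear student fitted by least squares to the stacked targets, returning a loss and a parameter dict in the same `Tuple[float, Dict]` shape. The function `distil_linear_student` and the linear student itself are assumptions for illustration; the real student and the contents of the `params` dict are defined by gbrl.

```python
from typing import Dict, Tuple

import numpy as np


def distil_linear_student(obs: np.ndarray, policy_targets: np.ndarray,
                          value_targets: np.ndarray) -> Tuple[float, Dict]:
    """Hypothetical stand-in: fit a linear student to the teacher's targets."""
    # One target matrix so a single student learns both actor and critic outputs.
    targets = np.column_stack([policy_targets, value_targets])
    # Append a constant column so the student can learn a bias.
    X = np.column_stack([obs, np.ones(len(obs))])
    weights, *_ = np.linalg.lstsq(X, targets, rcond=None)
    loss = float(np.mean((X @ weights - targets) ** 2))
    return loss, {"weights": weights}


rng = np.random.default_rng(0)
obs = rng.normal(size=(50, 3))
policy_targets = obs @ np.array([[1.0, -1.0], [0.5, 0.0], [0.0, 2.0]])
value_targets = obs.sum(axis=1) + 1.0
loss, params = distil_linear_student(obs, policy_targets, value_targets)
print(f"distillation loss: {loss:.2e}")
```

Because these toy targets are an exact linear function of `obs`, the loss is essentially zero; targets from a real teacher would generally leave a residual.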
- predict(obs: ndarray | Tensor, requires_grad: bool = True, start_idx: int = 0, stop_idx: int = None, tensor: bool = True) → Tuple[ndarray, ndarray]
Predicts both policy and value function outputs.
- Parameters:
obs (NumericalData) – Input observations.
requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.
start_idx (int, optional) – Start index for prediction. Defaults to 0.
stop_idx (int, optional) – Stop index for prediction. Defaults to None.
tensor (bool, optional) – Whether to return a tensor. Defaults to True.
- Returns:
Predicted policy and value outputs.
- Return type:
Tuple[np.ndarray, np.ndarray]
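This page does not say what `start_idx` and `stop_idx` index. A plausible reading, assumed here and not confirmed by the documentation, is that they select a contiguous range of trees `[start_idx, stop_idx)`, with `stop_idx=None` meaning "through the last tree". The hypothetical `predict_range` below sums only the selected trees' contributions (plus a bias) and then splits the shared output into policy and value parts, again assuming the value is the last column.

```python
from typing import List, Optional, Tuple

import numpy as np


def predict_range(tree_outputs: List[np.ndarray], bias: np.ndarray,
                  start_idx: int = 0,
                  stop_idx: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical: sum the trees in [start_idx, stop_idx) and split the result."""
    stacked = np.stack(tree_outputs)        # shape (n_trees, n_obs, output_dim)
    selected = stacked[start_idx:stop_idx]  # slicing with None runs to the end
    shared = bias + selected.sum(axis=0)    # an empty selection sums to zeros
    return shared[:, :-1], shared[:, -1]    # ASSUMPTION: value is the last column


# Three trees adding 1, 2 and 4 to every output of two observations.
trees = [np.full((2, 4), k) for k in (1.0, 2.0, 4.0)]
bias = np.zeros(4)
_, full_value = predict_range(trees, bias)               # all trees
_, tail_value = predict_range(trees, bias, start_idx=1)  # skip the first tree
_, head_value = predict_range(trees, bias, stop_idx=1)   # first tree only
print(full_value, tail_value, head_value)
```

Under this reading, `start_idx=0, stop_idx=None` reproduces the full-ensemble prediction, while a narrower range shows what a subset of trees contributes.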
- predict_critic(obs: ndarray | Tensor, requires_grad: bool = True, start_idx: int = 0, stop_idx: int = None, tensor: bool = True)
Predicts the value function (critic) output for the given observations.
- Parameters:
obs (NumericalData) – Input observations.
requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.
start_idx (int, optional) – Start index for prediction. Defaults to 0.
stop_idx (int, optional) – Stop index for prediction. Defaults to None.
tensor (bool, optional) – Whether to return a tensor. Defaults to True.
- Returns:
Predicted value function outputs.
- Return type:
np.ndarray
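The output of `predict_critic` typically feeds a value loss, whose gradient with respect to the predictions is what a boosting learner fits next. The sketch below assumes a mean-squared-error value loss against returns, and assumes that the `value_grad` argument of `step` is the gradient of that loss with respect to the critic's predicted outputs (the functional-gradient convention). `value_loss_and_grad` is a hypothetical helper; with `requires_grad=True`, an autograd framework could compute the same gradient, but the closed form is written out here so the example needs only NumPy.

```python
from typing import Tuple

import numpy as np


def value_loss_and_grad(values: np.ndarray, returns: np.ndarray) -> Tuple[float, np.ndarray]:
    """Hypothetical helper: MSE value loss and its gradient w.r.t. each prediction."""
    residual = values - returns
    loss = float(np.mean(residual ** 2))
    # d/dv_i of mean((v - R)^2) is 2 * (v_i - R_i) / n
    grad = 2.0 * residual / len(values)
    return loss, grad


values = np.array([0.5, 1.0, -0.2])  # e.g. predict_critic(obs) converted to NumPy
returns = np.array([1.0, 1.0, 0.0])
loss, value_grad = value_loss_and_grad(values, returns)
print(loss, value_grad)
```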
- predict_policy(obs: ndarray | Tensor, requires_grad: bool = True, start_idx: int = 0, stop_idx: int = None, tensor: bool = True)
Predicts the policy (actor) output for the given observations.
- Parameters:
obs (NumericalData) – Input observations.
requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.
start_idx (int, optional) – Start index for prediction. Defaults to 0.
stop_idx (int, optional) – Stop index for prediction. Defaults to None.
tensor (bool, optional) – Whether to return a tensor. Defaults to True.
- Returns:
Predicted policy outputs.
- Return type:
np.ndarray
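For a discrete-action actor, policy outputs are commonly treated as softmax logits. Assuming that parameterization and a vanilla policy-gradient loss `-mean(A * log pi(a|s))`, the gradient with respect to each sample's logits is `A * (softmax(logits) - one_hot(a)) / n`. The hypothetical helper below computes this closed form, which would be one candidate for the `theta_grad` argument of `step`; gbrl may use different losses or parameterizations.

```python
from typing import Tuple

import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def policy_loss_and_grad(logits: np.ndarray, actions: np.ndarray,
                         advantages: np.ndarray) -> Tuple[float, np.ndarray]:
    """Hypothetical helper: vanilla policy-gradient loss and its gradient w.r.t. logits."""
    n = len(actions)
    rows = np.arange(n)
    probs = softmax(logits)
    loss = float(-np.mean(advantages * np.log(probs[rows, actions])))
    one_hot = np.zeros_like(probs)
    one_hot[rows, actions] = 1.0
    # d/dz of -A * log softmax(z)[a] is A * (softmax(z) - one_hot(a)); divide by n for the mean
    grad = advantages[:, None] * (probs - one_hot) / n
    return loss, grad


logits = np.array([[1.0, 0.0, -1.0], [0.2, 0.3, 0.5]])  # e.g. predict_policy(obs) as NumPy
actions = np.array([0, 2])
advantages = np.array([1.0, -0.5])
loss, theta_grad = policy_loss_and_grad(logits, actions, advantages)
print(theta_grad)
```

Each row of the gradient sums to zero, since adding the same constant to all logits of a sample leaves its softmax unchanged.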
- step(obs: ndarray | Tensor, theta_grad: ndarray, value_grad: ndarray) → None
Performs a gradient update step for both policy and value function.
- Parameters:
obs (NumericalData) – Input observations.
theta_grad (np.ndarray) – Gradient of the policy parameters.
value_grad (np.ndarray) – Gradient of the value function parameters.
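Since a single ensemble serves both heads, one natural way to realize such a step is to concatenate `theta_grad` and `value_grad` column-wise and fit one tree to the negative gradient, so the same split updates policy and value outputs together. The sketch below does this with a depth-1 regression stump and applies separate learning rates to the policy columns and the value column, loosely mirroring the separate `policy_optimizer` and `value_optimizer` dicts. The stump, the `policy_lr`/`value_lr` parameters, and the helper names are assumptions for illustration, not gbrl's implementation.

```python
from typing import Tuple

import numpy as np


def fit_shared_stump(obs: np.ndarray,
                     targets: np.ndarray) -> Tuple[int, float, np.ndarray, np.ndarray]:
    """Fit one depth-1 regression tree whose leaves predict all output columns."""
    best = None
    for feature in range(obs.shape[1]):
        # Drop the largest value so the right-hand leaf is never empty.
        for threshold in np.unique(obs[:, feature])[:-1]:
            left = obs[:, feature] <= threshold
            left_val, right_val = targets[left].mean(axis=0), targets[~left].mean(axis=0)
            sse = ((targets[left] - left_val) ** 2).sum() + ((targets[~left] - right_val) ** 2).sum()
            if best is None or sse < best[0]:
                best = (sse, feature, threshold, left_val, right_val)
    if best is None:  # no valid split (all features constant): use a single leaf
        mean = targets.mean(axis=0)
        return 0, np.inf, mean, mean
    _, feature, threshold, left_val, right_val = best
    return feature, float(threshold), left_val, right_val


def shared_step(obs, preds, theta_grad, value_grad, policy_lr=0.1, value_lr=0.05):
    """Hypothetical boosting step: one shared tree fitted to the joint negative gradient."""
    grads = np.column_stack([theta_grad, value_grad])  # policy columns, then value column
    feature, threshold, left_val, right_val = fit_shared_stump(obs, -grads)
    goes_left = (obs[:, feature] <= threshold)[:, None]
    leaf = np.where(goes_left, left_val, right_val)
    # ASSUMPTION: separate learning rates for the policy block and the value column.
    lr = np.concatenate([np.full(theta_grad.shape[1], policy_lr), [value_lr]])
    return preds + lr * leaf


obs = np.array([[0.0], [1.0], [2.0], [3.0]])
preds = np.zeros((4, 3))  # 2 policy outputs + 1 value output
theta_grad = np.array([[1.0, -1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, 1.0]])
value_grad = np.array([0.5, 0.5, -0.5, -0.5])
new_preds = shared_step(obs, preds, theta_grad, value_grad)
print(new_preds)
```

Each prediction moves against its gradient, and because the split is shared, observations on the same side of the threshold receive the same combined policy and value update.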