ContinuousCritic Class
This class implements a continuous critic learner for reinforcement learning based on gradient boosted trees (GBT). ContinuousCritic outputs the parameters of a differentiable critic function (e.g., linear, quadratic). Typical uses are GBT-based SAC, DDPG, and TD3 implementations.
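As a rough illustration of what "parameters of a differentiable critic function" can mean, the sketch below evaluates a linear critic Q(s, a) = w(s) · a + b(s) from a per-observation weight vector and bias. It uses plain Python only; the function name `linear_q_value` and the numbers are hypothetical and not part of gbrl, and the exact functional form used by any given agent is an assumption here.

```python
from typing import Sequence


def linear_q_value(weights: Sequence[float], bias: float, action: Sequence[float]) -> float:
    """Evaluate a linear critic Q(s, a) = w(s) . a + b(s).

    `weights` and `bias` stand in for the per-observation parameters that a
    critic model would output; `action` is the action being scored.
    """
    if len(weights) != len(action):
        raise ValueError("weights and action must have the same length")
    return sum(w * a for w, a in zip(weights, action)) + bias


# Hypothetical parameters for a single observation (not produced by gbrl).
weights = [0.5, -1.0]
bias = 0.25
action = [2.0, 1.0]

q = linear_q_value(weights, bias, action)  # 0.5*2.0 + (-1.0)*1.0 + 0.25
```

Because Q is linear in the action, its gradient with respect to the action is simply the weight vector, which is the property that lets an actor be updated through the critic.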
- class gbrl.ac_gbrl.ContinuousCritic(tree_struct: Dict, output_dim: int, weights_optimizer: Dict, bias_optimizer: Dict = None, gbrl_params: Dict = {}, target_update_interval: int = 100, bias: ndarray = None, verbose: int = 0, device: str = 'cpu')
Bases: GBRL
- get_num_trees() → int
Get the number of trees in the model.
- Returns:
Number of trees in the model.
- Return type:
int
- predict_target(observations: ndarray | Tensor, tensor: bool = True) → Tuple[Tensor | ndarray, Tensor | ndarray]
Predict the parameters of a target continuous critic, as Tensors by default. The prediction is made by summing the outputs of the trees in the continuous critic model up to n_trees - target_update_interval.
- Parameters:
observations (Union[np.ndarray, th.Tensor])
tensor (bool, optional) – Return PyTorch Tensor, False returns a numpy array. Defaults to True.
- Returns:
Weight and bias parameters of the target Q-function.
- Return type:
Tuple[Union[th.Tensor, np.ndarray], Union[th.Tensor, np.ndarray]]
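The truncation rule described for predict_target can be sketched in plain Python as follows. This is a toy model under stated assumptions, not gbrl's implementation: each "tree" is a callable returning a single float (the real model returns weight and bias vectors), the helper names are hypothetical, and clamping to zero trees when the ensemble is shorter than target_update_interval is an assumption.

```python
from typing import Callable, List, Sequence

# A toy "tree": maps an observation to one scalar contribution.
Tree = Callable[[Sequence[float]], float]


def predict_ensemble(trees: List[Tree], observation: Sequence[float], n_used: int) -> float:
    """Sum the outputs of the first `n_used` trees (clamped to a valid range)."""
    n_used = max(0, min(n_used, len(trees)))
    return sum(tree(observation) for tree in trees[:n_used])


def predict_target_sketch(trees: List[Tree], observation: Sequence[float],
                          target_update_interval: int) -> float:
    """Target prediction: use only trees up to n_trees - target_update_interval."""
    return predict_ensemble(trees, observation, len(trees) - target_update_interval)


# Four constant stumps standing in for boosted trees.
trees = [lambda obs, v=v: v for v in [1.0, 0.5, 0.25, 0.125]]
obs = [0.0]

full = predict_ensemble(trees, obs, len(trees))                       # all four trees
target = predict_target_sketch(trees, obs, target_update_interval=1)  # first three trees
```

In this picture the target critic needs no separate copy of the model: it is the same ensemble evaluated without its most recent trees, so it lags the online critic by target_update_interval boosting steps.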
- step(observations: ndarray | Tensor, q_grad_clip: float = None, weight_grad: ndarray | Tensor | None = None, bias_grad: ndarray | Tensor | None = None) → None
Performs a single boosting step.
- Parameters:
observations (Union[np.ndarray, th.Tensor])
q_grad_clip (float, optional) – Defaults to None.
weight_grad (Optional[Union[np.ndarray, th.Tensor]], optional) – Manually calculated gradients. Defaults to None.
bias_grad (Optional[Union[np.ndarray, th.Tensor]], optional) – Manually calculated gradients. Defaults to None.
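To make "manually calculated gradients" concrete, the sketch below derives weight and bias gradients by hand for one sample, assuming a linear critic Q = w · a + b and a squared-error loss L = 0.5 (Q - y)^2. Both the loss and the helper name `linear_critic_grads` are assumptions for illustration; the source does not state which loss, batching, or scaling step expects, so check those before passing such values as weight_grad and bias_grad.

```python
from typing import List, Sequence, Tuple


def linear_critic_grads(weights: Sequence[float], bias: float,
                        action: Sequence[float], target_q: float) -> Tuple[List[float], float]:
    """Gradients of L = 0.5 * (Q - y)^2 for a linear critic Q = w . a + b.

    dL/dw_i = (Q - y) * a_i
    dL/db   = (Q - y)
    """
    q = sum(w * a for w, a in zip(weights, action)) + bias
    error = q - target_q  # prediction minus regression target
    weight_grad = [error * a for a in action]
    bias_grad = error
    return weight_grad, bias_grad


# Hypothetical single-sample values (not produced by gbrl).
weight_grad, bias_grad = linear_critic_grads(
    weights=[0.5, -1.0], bias=0.25, action=[2.0, 1.0], target_q=1.25
)
# Here Q = 0.25, so the error is -1.0.
```

In practice these quantities would be computed for a whole batch, typically with automatic differentiation, and converted to arrays or tensors matching the observations.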