ContinuousCritic
GBRL model designed for continuous-action Q-learning tasks. Outputs parameterized Q-functions (linear, quadratic, or tanh forms) where the Q-function is differentiable w.r.t. actions, allowing integration with algorithms like SAC.
- class gbrl.models.critic.ContinuousCritic(tree_struct: Dict, input_dim: int, output_dim: int, weights_optimizer: Dict, bias_optimizer: Dict = None, params: Dict = {}, target_update_interval: int = 100, bias: ndarray = None, verbose: int = 0, device: str = 'cpu')[source]
Bases:
BaseGBT
GBRL model for a Continuous Critic ensemble. Designed for Q-function approximation in continuous action spaces, as used by algorithms such as SAC. The model outputs the parameters of one of three types of Q-function (<> denotes a dot product):
- linear: Q(theta(s), a) = <w_theta, a> + b_theta
- quadratic: Q(theta(s), a) = -(<w_theta, a> - b_theta)**2 + c_theta
- tanh: Q(theta(s), a) = b_theta * tanh(<w_theta, a>)
This allows derivatives with respect to the action a to be computed while the Q-function parameters are produced by a GBT model theta. The target model is approximated as the ensemble without its last <target_update_interval> trees.
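The three forms listed above can be sketched in plain NumPy to show why each one is differentiable with respect to the action. This is a minimal illustration for a single sample, not the library's implementation: the function names and the example values of w, b, c, and a are assumptions chosen for the sketch, and the gradients are written out by hand from the formulas.

```python
import numpy as np

# Hypothetical helpers: each takes the parameters a GBT model would output
# for one state (w, b, and c where needed) plus an action vector a.

def q_linear(w, b, a):
    # Q = <w, a> + b
    return float(np.dot(w, a) + b)

def q_quadratic(w, b, c, a):
    # Q = -(<w, a> - b)**2 + c
    return float(-(np.dot(w, a) - b) ** 2 + c)

def q_tanh(w, b, a):
    # Q = b * tanh(<w, a>)
    return float(b * np.tanh(np.dot(w, a)))

# Closed-form derivatives of each Q with respect to the action a.
def dq_linear_da(w, b, a):
    return w.copy()

def dq_quadratic_da(w, b, c, a):
    return -2.0 * (np.dot(w, a) - b) * w

def dq_tanh_da(w, b, a):
    return b * (1.0 - np.tanh(np.dot(w, a)) ** 2) * w

# Example values (assumed for illustration only).
w = np.array([1.0, 2.0])
a = np.array([0.5, -0.25])
b, c = 0.5, 1.0

print(q_linear(w, b, a), q_quadratic(w, b, c, a), q_tanh(w, b, a))
print(dq_linear_da(w, b, a), dq_quadratic_da(w, b, c, a), dq_tanh_da(w, b, a))
```

In an actor-critic setting such as SAC, it is this gradient with respect to a that lets the policy be updated through the critic, even though the parameters themselves come from a tree ensemble.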
- predict_target(observations: ndarray | Tensor, tensor: bool = True) Tuple[Tensor | ndarray, Tensor | ndarray] [source]
Predict the parameters of a Target Continuous Critic as Tensors. The prediction is made by summing the outputs of the trees of the Continuous Critic model up to n_trees - target_update_interval.
- Parameters:
observations (NumericalData)
tensor (bool, optional) – Return PyTorch Tensor, False returns a numpy array. Defaults to True.
- Returns:
weights and bias parameters for the type of Q-function
- Return type:
Tuple[th.Tensor, th.Tensor]
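The truncated-ensemble idea behind the target prediction can be sketched as follows. This is only an illustration of summing the first n_trees - target_update_interval trees. The array layout (one slice of per-tree outputs per tree) and the helper names are assumptions made for the sketch, and any bias term handled by the real model is left out.

```python
import numpy as np

def ensemble_sum(tree_outputs, n_used):
    # Sum the outputs of the first n_used trees.
    # tree_outputs has the assumed shape (n_trees, n_samples, output_dim).
    return tree_outputs[:n_used].sum(axis=0)

def target_prediction(tree_outputs, target_update_interval):
    # The target drops the most recent target_update_interval trees.
    n_trees = tree_outputs.shape[0]
    n_used = max(n_trees - target_update_interval, 0)
    return ensemble_sum(tree_outputs, n_used)

# Toy ensemble: 5 trees, 2 samples, 3 outputs, every tree contributing 1.0.
tree_outputs = np.ones((5, 2, 3))
online = ensemble_sum(tree_outputs, tree_outputs.shape[0])
target = target_prediction(tree_outputs, target_update_interval=2)
print(online[0], target[0])
```

Because the target is just a prefix of the same ensemble, no separate copy of the model needs to be stored in this sketch; the target lags the online model by the most recent trees.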
- step(observations: ndarray | Tensor | None = None, weight_grad: ndarray | Tensor | None = None, bias_grad: ndarray | Tensor | None = None, q_grad_clip: float | None = None) None [source]
Performs a single boosting step
- Parameters:
observations (NumericalData)
q_grad_clip (float, optional) – . Defaults to None.
weight_grad (Optional[NumericalData], optional) – manually calculated gradients. Defaults to None.
bias_grad (Optional[NumericalData], optional) – manually calculated gradients. Defaults to None.
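As an illustration of what "manually calculated gradients" could look like, the sketch below derives gradients with respect to the weight and bias parameters for the linear form Q = <w, a> + b under a per-sample squared-error loss 0.5 * (Q - y)**2. The loss, the array shapes, and the helper name are assumptions made for this example; the shapes and scaling that step() actually expects are not specified here and should be checked against the library.

```python
import numpy as np

def linear_q_param_grads(w, b, a, y):
    # w: (n, action_dim) weight parameters, b: (n,) bias parameters,
    # a: (n, action_dim) actions, y: (n,) regression targets (assumed shapes).
    q = np.sum(w * a, axis=1) + b   # linear Q-values, shape (n,)
    err = q - y                     # dL/dQ for L = 0.5 * (Q - y)**2
    weight_grad = err[:, None] * a  # dL/dw = (Q - y) * a
    bias_grad = err[:, None]        # dL/db = (Q - y)
    return weight_grad, bias_grad

# Single-sample example with assumed values.
w = np.array([[1.0, 2.0]])
b = np.array([0.5])
a = np.array([[0.5, -0.25]])
y = np.array([0.0])

weight_grad, bias_grad = linear_q_param_grads(w, b, a, y)
print(weight_grad, bias_grad)
```

Arrays of this kind are what one would plausibly pass as weight_grad and bias_grad, together with the matching observations.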