DiscreteCritic
GBRL model tailored for discrete-action Q-learning tasks. Represents Q-values as a tree ensemble and supports target network approximations by omitting the latest boosting steps, enabling stable training in DQN-like settings.
- class gbrl.models.critic.DiscreteCritic(tree_struct: Dict, input_dim: int, output_dim: int, critic_optimizer: Dict, params: Dict = {}, target_update_interval: int = 100, bias: ndarray = None, verbose: int = 0, device: str = 'cpu')
Bases: BaseGBT
GBRL model for a Discrete Critic ensemble. Used for Q-function approximation in discrete action spaces. The target model is approximated as the ensemble without the last <target_update_interval> trees.
- predict_target(observations: ndarray | Tensor, tensor: bool = True) → Tensor
- Predict and return Target Critic’s outputs as Tensors.
Prediction is made by summing the outputs of the first n_trees - target_update_interval trees in the ensemble, i.e. every tree except the most recent target_update_interval ones.
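The truncated-sum idea can be illustrated with a toy ensemble. This is a minimal sketch, not GBRL's implementation: `ToyEnsemble`, `add_tree`, and the callables standing in for trees are all hypothetical, and clamping the target to the bias alone when there are fewer than target_update_interval trees is an assumption of this sketch.

```python
import numpy as np


class ToyEnsemble:
    """Hypothetical stand-in for a boosted ensemble (not the GBRL API).

    Each "tree" is a callable mapping observations of shape (N, input_dim)
    to Q-values of shape (N, output_dim).
    """

    def __init__(self, output_dim, bias=None):
        self.output_dim = output_dim
        self.bias = np.zeros(output_dim) if bias is None else np.asarray(bias, dtype=float)
        self.trees = []

    def add_tree(self, tree):
        # One boosting iteration appends one tree.
        self.trees.append(tree)

    def _sum(self, observations, n):
        # Start from the bias and add the outputs of the first n trees.
        out = np.tile(self.bias, (len(observations), 1))
        for tree in self.trees[:n]:
            out = out + tree(observations)
        return out

    def predict(self, observations):
        # Online Q-values: the full ensemble.
        return self._sum(observations, len(self.trees))

    def predict_target(self, observations, target_update_interval):
        # Target Q-values: drop the newest target_update_interval trees
        # (clamped at zero trees in this sketch).
        n = max(len(self.trees) - target_update_interval, 0)
        return self._sum(observations, n)


ens = ToyEnsemble(output_dim=2)
for k in range(5):
    # Toy "tree" k adds the constant k + 1 to every Q-value.
    ens.add_tree(lambda obs, c=k + 1: np.full((len(obs), 2), float(c)))

obs = np.zeros((3, 4))
online = ens.predict(obs)              # every entry is 1+2+3+4+5 = 15
target = ens.predict_target(obs, 2)    # every entry is 1+2+3 = 6
```

Because the target is just a prefix of the same ensemble, no separate copy of the model has to be stored or synchronized; the target automatically advances by one tree each time a new tree is added.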
- Parameters:
observations (NumericalData)
- Returns:
Target Critic’s outputs.
- Return type:
th.Tensor
- step(observations: ndarray | Tensor | None = None, q_grad: ndarray | Tensor | None = None, max_q_grad_norm: ndarray | None = None) → None
Performs a single boosting iteration.
- Parameters:
observations (NumericalData)
max_q_grad_norm (np.ndarray, optional)
q_grad (Optional[NumericalData], optional) – manually calculated gradients. Defaults to None.
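When q_grad is supplied manually, it would typically come from a DQN-style loss. The sketch below computes the gradient of the per-sample loss 0.5 * (Q(s, a) - y)^2 with respect to every Q-value output, using the bootstrapped target y = r + gamma * (1 - done) * max_a' Q_target(s', a'). The helper `dqn_q_grad` is hypothetical, and several details are assumptions not stated in this reference: that step expects the gradient of the loss (rather than its negative), that gradients are per-sample rather than batch-averaged, and that their shape matches the ensemble output (N, output_dim).

```python
import numpy as np


def dqn_q_grad(q_values, target_q_next, actions, rewards, dones, gamma=0.99):
    """Hypothetical helper: d/dQ of 0.5 * (Q(s, a) - y)^2 for each sample.

    q_values:      (N, n_actions) online Q(s, .)
    target_q_next: (N, n_actions) target Q(s', .), e.g. from predict_target
    actions:       (N,) integer indices of the actions taken
    rewards:       (N,) rewards
    dones:         (N,) terminal flags, 1.0 if the episode ended
    """
    n = q_values.shape[0]
    # Bootstrapped TD target; terminal transitions do not bootstrap.
    y = rewards + gamma * (1.0 - dones) * target_q_next.max(axis=1)
    rows = np.arange(n)
    td_error = q_values[rows, actions] - y
    # Only the output of the taken action receives a gradient.
    grad = np.zeros_like(q_values)
    grad[rows, actions] = td_error
    return grad


q = np.array([[1.0, 2.0], [0.5, 0.0]])
q_next = np.array([[3.0, 1.0], [2.0, 4.0]])
actions = np.array([1, 0])
rewards = np.array([1.0, 0.0])
dones = np.array([0.0, 1.0])
g = dqn_q_grad(q, q_next, actions, rewards, dones, gamma=0.5)
# y = [1 + 0.5 * 3, 0] = [2.5, 0], so g = [[0, -0.5], [0.5, 0]]
```

In a training loop, an array like g would then be passed as q_grad to step together with the same batch of observations.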