DiscreteCritic Class

This class implements a gradient boosted trees (GBT) based discrete critic learner for reinforcement learning. The DiscreteCritic class is designed to output the parameters of a discrete critic (e.g., a Q-function with one output per discrete action). Typical usage: GBT-based DQN implementations.

class gbrl.ac_gbrl.DiscreteCritic(tree_struct: Dict, output_dim: int, critic_optimizer: Dict, gbrl_params: Dict = {}, target_update_interval: int = 100, bias: ndarray = None, verbose: int = 0, device: str = 'cpu')

Bases: GBRL
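
Example (a minimal construction sketch; the tree_struct and critic_optimizer key names below, such as 'max_depth', 'algo', and 'lr', are assumptions based on common GBRL configurations, not guarantees of this reference):

    from gbrl.ac_gbrl import DiscreteCritic

    n_actions = 4  # hypothetical discrete action space

    # Hypothetical hyperparameter dictionaries; key names are assumptions.
    tree_struct = {'max_depth': 4, 'n_bins': 256, 'min_data_in_leaf': 0,
                   'par_th': 2, 'grow_policy': 'oblivious'}
    critic_optimizer = {'algo': 'SGD', 'lr': 0.01}

    critic = DiscreteCritic(
        tree_struct=tree_struct,
        output_dim=n_actions,              # one Q-value per discrete action
        critic_optimizer=critic_optimizer,
        target_update_interval=100,        # trees between target snapshots
        device='cpu',
    )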

predict_target(observations: ndarray | Tensor, tensor: bool = True) → Tensor
Predict and return Target Critic’s outputs as Tensors.

Prediction is made by summing the outputs of the trees of the Discrete Critic model up to n_trees - target_update_interval.

Parameters:

  • observations (Union[np.ndarray, th.Tensor])

  • tensor (bool, optional) – whether to return the prediction as a th.Tensor. Defaults to True.

Returns:

Target Critic’s outputs.

Return type:

th.Tensor
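
Example (a minimal call, reusing the critic constructed above; batch and feature sizes are hypothetical):

    import numpy as np

    obs = np.random.randn(32, 8).astype(np.float32)  # 32 observations, 8 features
    target_q = critic.predict_target(obs)            # th.Tensor of shape (32, output_dim)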

step(observations: ndarray | Tensor, max_q_grad_norm: ndarray = None, q_grad: ndarray | Tensor | None = None) → None

Performs a single boosting iteration.

Parameters:
  • observations (Union[np.ndarray, th.Tensor])

  • max_q_grad_norm (np.ndarray, optional)

  • q_grad (Optional[Union[np.ndarray, th.Tensor]], optional) – manually calculated gradients. Defaults to None.
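
Example (a DQN-style update sketch using the manual-gradient path; the forward call critic(obs) for online Q-values and the 0.5 * MSE TD loss are assumptions here, only predict_target and step come from this reference):

    import numpy as np

    gamma = 0.99
    batch_size, n_features, n_actions = 32, 8, 4

    # Synthetic transition batch (illustrative only).
    obs = np.random.randn(batch_size, n_features).astype(np.float32)
    next_obs = np.random.randn(batch_size, n_features).astype(np.float32)
    actions = np.random.randint(0, n_actions, size=batch_size)
    rewards = np.random.randn(batch_size).astype(np.float32)
    dones = np.zeros(batch_size, dtype=np.float32)

    # Bootstrap targets from the frozen target trees.
    next_q = critic.predict_target(next_obs).detach().cpu().numpy()
    y = rewards + gamma * (1.0 - dones) * next_q.max(axis=1)

    # Online Q-values; calling the model directly mirrors other GBRL
    # actor-critic classes and is an assumption, not documented above.
    q = critic(obs).detach().cpu().numpy()

    # Gradient of the 0.5 * MSE TD loss w.r.t. the full Q-output vector:
    # the per-sample TD error at the taken action's index, zero elsewhere.
    q_grad = np.zeros_like(q)
    q_grad[np.arange(batch_size), actions] = q[np.arange(batch_size), actions] - y

    critic.step(observations=obs, q_grad=q_grad)  # one boosting iteration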