GBTModel
Generic GBT model for supervised learning. It can be used independently of reinforcement learning pipelines to perform standard regression or classification with gradient-boosted trees, using custom optimizers and loss functions.
- class gbrl.models.gbt.GBTModel(tree_struct: Dict, input_dim: int, output_dim: int, optimizers: Dict | List[Dict], params: Dict = {}, verbose: int = 0, device: str = 'cpu')
Bases: BaseGBT
General class for gradient boosting trees
- copy() → GBTModel
Returns a copy of this instance
- Returns:
copy of current instance. The actual type will be the type of the subclass that calls this method.
- Return type:
GBTModel
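The note that the copy keeps the type of the calling subclass can be illustrated with a minimal sketch. `BaseModel` and `MyModel` are hypothetical classes, and a `copy.deepcopy`-based implementation is an assumption; GBRL's own `copy()` may be implemented differently (for example, if its trees live in a compiled backend).

```python
import copy


class BaseModel:
    """Hypothetical base class; not GBRL's implementation."""

    def __init__(self, bias):
        self.bias = bias  # mutable state that must not be shared between copies

    def copy(self):
        # copy.deepcopy recreates the object with its runtime class, so calling
        # copy() on a subclass instance returns an instance of that subclass.
        return copy.deepcopy(self)


class MyModel(BaseModel):
    """Hypothetical subclass."""


original = MyModel(bias=[0.5, -1.0])
clone = original.copy()
clone.bias[0] = 99.0  # modifying the clone leaves the original untouched
print(type(clone).__name__, original.bias, clone.bias)
```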
- fit(X: ndarray | Tensor, targets: ndarray | Tensor, iterations: int, shuffle: bool = True, loss_type: str = 'MultiRMSE') → float
Fits the model for multiple boosting iterations (as in standard supervised learning)
- Parameters:
X (NumericalData) – inputs
targets (NumericalData) – targets
iterations (int) – number of boosting iterations
shuffle (bool, optional) – Shuffle dataset. Defaults to True.
loss_type (str, optional) – Loss function to use (only MultiRMSE is currently implemented). Defaults to ‘MultiRMSE’.
- Returns:
final loss over all examples.
- Return type:
float
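A toy, pure-Python sketch of what a boosting `fit` loop of this kind does may help; it is not GBRL's implementation. Assumptions: a single numeric input feature, depth-1 "trees" (stumps), gradients of half the squared error (prediction minus target), a starting bias equal to the per-output target mean, a fixed learning rate, and a CatBoost-style MultiRMSE (square root of the mean, over examples, of the squared error summed across output dimensions); GBRL's normalization, tree structure, and optimizer handling come from `tree_struct`, `optimizers`, and `params` and may differ. `multi_rmse`, `fit_stump`, and `toy_fit` are hypothetical helpers, and the `shuffle` and `loss_type` arguments are omitted.

```python
import math


def multi_rmse(preds, targets):
    """MultiRMSE (assumed CatBoost-style): square root of the mean, over
    examples, of the squared error summed across output dimensions."""
    total = sum(
        sum((p - t) ** 2 for p, t in zip(pred_row, target_row))
        for pred_row, target_row in zip(preds, targets)
    )
    return math.sqrt(total / len(targets))


def fit_stump(xs, grads):
    """Fit a depth-1 tree on a 1-D feature to the negative gradients.
    Each leaf stores the per-output mean of the negative gradients it holds."""
    dim = len(grads[0])

    def leaf_value(rows):
        return [-sum(g[d] for g in rows) / len(rows) for d in range(dim)]

    def leaf_error(rows, value):
        # squared distance between the leaf value and each negative gradient
        return sum((value[d] + g[d]) ** 2 for g in rows for d in range(dim))

    best = None
    for threshold in sorted(set(xs))[:-1]:
        left = [g for x, g in zip(xs, grads) if x <= threshold]
        right = [g for x, g in zip(xs, grads) if x > threshold]
        left_val, right_val = leaf_value(left), leaf_value(right)
        err = leaf_error(left, left_val) + leaf_error(right, right_val)
        if best is None or err < best[0]:
            best = (err, threshold, left_val, right_val)
    if best is None:  # all inputs identical: fall back to a single leaf
        value = leaf_value(grads)
        return lambda x: value
    _, threshold, left_val, right_val = best
    return lambda x: left_val if x <= threshold else right_val


def toy_fit(xs, targets, iterations, lr=0.5):
    """Toy gradient boosting loop; returns the final MultiRMSE over all examples."""
    dim = len(targets[0])
    # start from the per-output target mean (cf. set_bias_from_targets)
    bias = [sum(t[d] for t in targets) / len(targets) for d in range(dim)]
    trees = []

    def predict(x):
        out = list(bias)
        for tree in trees:
            out = [o + lr * v for o, v in zip(out, tree(x))]
        return out

    for _ in range(iterations):
        preds = [predict(x) for x in xs]
        # gradient of 0.5 * squared error w.r.t. each prediction
        grads = [[p - t for p, t in zip(pred, target)]
                 for pred, target in zip(preds, targets)]
        trees.append(fit_stump(xs, grads))  # one boosting step = one new tree
    return multi_rmse([predict(x) for x in xs], targets)


xs = [0, 1, 2, 3, 4, 5]
targets = [[0.0, 1.0]] * 3 + [[1.0, 0.0]] * 3
print(toy_fit(xs, targets, iterations=1))
print(toy_fit(xs, targets, iterations=10))
```

On this data a single stump separates the two groups exactly, so with `lr=0.5` every iteration halves the residuals and the returned loss shrinks geometrically.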
- classmethod load_learner(load_name: str, device: str) → GBTModel
Loads a GBRL model from a file
- Parameters:
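load_name (str) – full path to the file
device (str) – device on which to load the model (as with the constructor's device argument, e.g. 'cpu')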
- Returns:
loaded GBTModel instance
- plot_tree(tree_idx: int, filename: str) → None
Plots a tree using Graphviz (only works if GBRL was compiled with Graphviz support)
- Parameters:
tree_idx (int) – tree index to plot
filename (str) – .png filename to save
- print_tree(tree_idx: int) → None
Prints tree information
- Parameters:
tree_idx (int) – tree index to print
- set_bias_from_targets(targets: ndarray | Tensor) → None
Sets the model bias to the mean of the targets
- Parameters:
targets (NumericalData) – targets
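Why the mean is a natural starting bias can be checked with a minimal sketch: for squared error, the per-output mean is the constant prediction with the lowest error. The targets are assumed to be a 2-D array of shape (n_samples, output_dim) with one bias value per output; `mean_bias` and `mse` are hypothetical helpers, not GBRL functions.

```python
def mean_bias(targets):
    """Per-output mean of a 2-D list of targets, shape (n_samples, output_dim)."""
    n = len(targets)
    return [sum(row[d] for row in targets) / n for d in range(len(targets[0]))]


def mse(bias, targets):
    """Mean (over examples) squared error of predicting `bias` for every example."""
    return sum((b - t) ** 2 for row in targets for b, t in zip(bias, row)) / len(targets)


targets = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
bias = mean_bias(targets)
print(bias, mse(bias, targets), mse([2.1, 20.0], targets))
```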
- step(X: ndarray | Tensor | None = None, grad: ndarray | Tensor | None = None, max_grad_norm: float | None = None) → None
Performs a single boosting step (fits one tree to the gradients)
- Parameters:
X (NumericalData) – inputs
grad (Optional[NumericalData], optional) – manually calculated gradients. Defaults to None.
max_grad_norm (float, optional) – maximum gradient norm; if set, gradients are clipped by norm to this value. Defaults to None.
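Gradient clipping by norm, as controlled by max_grad_norm, can be sketched as follows. The assumption here is that the norm is the L2 norm of each example's gradient row and that rows exceeding the limit are rescaled to it; GBRL may compute the norm differently (for example, over the whole batch). `clip_grad_norm` is a hypothetical helper, and the clipped gradients are what the single new tree would then be fit to.

```python
import math


def clip_grad_norm(grads, max_grad_norm):
    """Rescale each row of a 2-D gradient list so its L2 norm is at most
    max_grad_norm (assumed per-example clipping). None disables clipping."""
    if max_grad_norm is None:
        return [list(row) for row in grads]
    clipped = []
    for row in grads:
        norm = math.sqrt(sum(g * g for g in row))
        # only rows whose norm exceeds the limit are shrunk; direction is kept
        scale = max_grad_norm / norm if norm > max_grad_norm else 1.0
        clipped.append([g * scale for g in row])
    return clipped


out = clip_grad_norm([[3.0, 4.0], [0.3, 0.4]], max_grad_norm=1.0)
print(out)
```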