BaseLearner
Abstract base class for all GBRL learners. Defines the common interface for gradient boosting learners, including methods for fitting, stepping, SHAP value computation, saving/loading, and device management.
- class gbrl.learners.base.BaseLearner(input_dim: int, output_dim: int, tree_struct: Dict, params: Dict, verbose: int = 0, device: str = 'cpu')
Bases: ABC
Abstract base class for gradient boosting tree learners.
This class defines the fundamental interface for gradient boosting learners and serves as a wrapper for the C++ backend.
- input_dim
The number of input features.
- Type:
int
- output_dim
The number of output dimensions.
- Type:
int
- tree_struct
Dictionary containing tree structure parameters.
- Type:
Dict
- params
Dictionary containing model parameters.
- Type:
Union[Dict, List[Dict]]
- verbose
Verbosity level (0 = silent, 1 = debug).
- Type:
int
- device
The device the model runs on (e.g., ‘cpu’ or ‘cuda’).
- Type:
str
- iteration
The current training iteration.
- Type:
int
- total_iterations
Total number of training iterations.
- Type:
int
- feature_weights
Feature importance weights.
- Type:
np.ndarray
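Because BaseLearner is an ABC, it cannot be instantiated directly. A concrete learner must override every method marked abstractmethod below, while the one concrete method, get_total_iterations, is inherited. The sketch below illustrates that contract with a hypothetical, heavily reduced stand-in (`ToyLearner`, `ConstantLearner`). It does not import gbrl, and its method bodies are assumptions chosen for brevity, not GBRL's behavior.

```python
from abc import ABC, abstractmethod
from typing import List


class ToyLearner(ABC):
    """Hypothetical stand-in for BaseLearner with a reduced interface."""

    def __init__(self, input_dim: int, output_dim: int, verbose: int = 0, device: str = "cpu"):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.verbose = verbose
        self.device = device
        self.iteration = 0
        self.total_iterations = 0

    @abstractmethod
    def step(self, gradients: List[float]) -> None:
        """Subclasses must implement a single update."""

    @abstractmethod
    def predict(self, features: List[float]) -> List[float]:
        """Subclasses must implement inference."""

    def get_total_iterations(self) -> int:
        # Concrete method shared by every subclass.
        return self.total_iterations


class ConstantLearner(ToyLearner):
    """Hypothetical concrete learner: predicts a constant vector moved by each step."""

    def __init__(self, input_dim: int, output_dim: int, lr: float = 0.1):
        super().__init__(input_dim, output_dim)
        self.lr = lr
        self.value = [0.0] * output_dim

    def step(self, gradients: List[float]) -> None:
        # Move the constant prediction against the gradient.
        self.value = [v - self.lr * g for v, g in zip(self.value, gradients)]
        self.iteration += 1
        self.total_iterations += 1

    def predict(self, features: List[float]) -> List[float]:
        return list(self.value)


# Instantiating the abstract class itself raises TypeError.
try:
    ToyLearner(4, 2)
    abstract_blocked = False
except TypeError:
    abstract_blocked = True

learner = ConstantLearner(input_dim=4, output_dim=2, lr=0.5)
learner.step([1.0, -2.0])
print(abstract_blocked, learner.predict([0.0] * 4), learner.get_total_iterations())  # True [-0.5, 1.0] 1
```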
- abstractmethod distil(*args, **kwargs) → Tuple[int, Dict]
Distills the model into a smaller, simplified version.
- Returns:
Final loss and updated parameters.
- Return type:
Tuple[int, Dict]
- abstractmethod fit(*args, **kwargs) → float | List[float]
Trains the model on provided data.
- Returns:
The final training loss; a list with one loss per model when the learner wraps several models.
- Return type:
Union[float, List[float]]
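Because fit may return either a single float or a list with one loss per model, calling code that logs losses has to handle both shapes. A small helper such as the hypothetical `as_loss_list` below (not part of GBRL) normalizes the result; the same pattern applies to the other getters whose return type is a value or a tuple of values.

```python
import numbers
from typing import List, Sequence, Union

# A learner may report one loss, or one loss per wrapped model.
LossResult = Union[float, Sequence[float]]


def as_loss_list(result: LossResult) -> List[float]:
    """Normalize a single loss or a per-model sequence of losses to a list of floats."""
    if isinstance(result, numbers.Real):
        return [float(result)]
    return [float(loss) for loss in result]


single = as_loss_list(0.25)          # e.g. a learner with one model
multi = as_loss_list([0.25, 0.75])   # e.g. a learner wrapping two models
print(single, multi)  # [0.25] [0.25, 0.75]
```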
- abstractmethod get_bias(*args, **kwargs) → ndarray | Tuple[ndarray, ...]
Retrieves the model bias.
- Returns:
Bias values.
- Return type:
Union[np.ndarray, Tuple[np.ndarray, …]]
- abstractmethod get_device(*args, **kwargs) → str | Tuple[str, ...]
Retrieves the current device the model is running on.
- Returns:
Device (e.g., ‘cpu’, ‘cuda’).
- Return type:
Union[str, Tuple[str, …]]
- abstractmethod get_feature_weights(*args, **kwargs) → ndarray | Tuple[ndarray, ...]
Retrieves the feature importance weights.
- Returns:
Feature weights.
- Return type:
Union[np.ndarray, Tuple[np.ndarray, …]]
- abstractmethod get_iteration(*args, **kwargs) → int | Tuple[int, ...]
Retrieves the current iteration count.
- Returns:
The current iteration number.
- Return type:
Union[int, Tuple[int, …]]
- abstractmethod get_num_trees(*args, **kwargs) → int | Tuple[int, ...]
Retrieves the number of trees in the model.
- Returns:
Number of trees.
- Return type:
Union[int, Tuple[int, …]]
- abstractmethod get_schedule_learning_rates(*args, **kwargs) → int | Tuple[int, ...]
Retrieves the scheduled learning rates.
- Returns:
Learning rate(s).
- Return type:
Union[int, Tuple[int, …]]
- get_total_iterations() → int
Returns the total number of iterations performed.
- Returns:
The total number of iterations.
- Return type:
int
- abstractmethod classmethod load(filename: str, device: str, *args, **kwargs) → BaseLearner
Loads a model from a file.
- Parameters:
filename (str) – Path to the model file.
device (str) – Device to load the model onto.
- Returns:
Loaded model instance.
- Return type:
BaseLearner
- abstractmethod plot_tree(tree_idx: int, filename: str, *args, **kwargs) → None
Plots a decision tree and saves it to a file.
- Parameters:
tree_idx (int) – Index of the tree.
filename (str) – Path to save the tree visualization.
- abstractmethod predict(*args, **kwargs) → ndarray | Tuple[ndarray, ...]
Generates predictions using the trained model.
- Returns:
Model predictions.
- Return type:
Union[np.ndarray, Tuple[np.ndarray, …]]
- abstractmethod print_ensemble_metadata() → None
Prints metadata information about the entire ensemble.
- abstractmethod print_tree(tree_idx: int, *args, **kwargs) → None
Prints the structure of a specific decision tree.
- Parameters:
tree_idx (int) – Index of the tree.
- abstractmethod reset() → None
Resets the model, reinitializing internal states and parameters.
- abstractmethod save(filename: str, *args, **kwargs) → None
Saves the model to a file.
- Parameters:
filename (str) – The filename to save the model to.
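save and load form a round trip: save is an instance method, while load is an abstract classmethod that builds a new instance and takes the target device as an argument. The sketch below shows that pattern on a hypothetical stand-in class; the pickle-based file format and the `PersistableLearner` and `PickleLearner` names are assumptions for illustration, not GBRL's on-disk format. When the two decorators are combined, `abstractmethod` must be the innermost one.

```python
import os
import pickle
import tempfile
from abc import ABC, abstractmethod


class PersistableLearner(ABC):
    """Hypothetical stand-in showing an abstract save/load pair."""

    @abstractmethod
    def save(self, filename: str) -> None: ...

    @classmethod
    @abstractmethod  # abstractmethod must be the innermost decorator
    def load(cls, filename: str, device: str) -> "PersistableLearner": ...


class PickleLearner(PersistableLearner):
    def __init__(self, bias: float, device: str = "cpu"):
        self.bias = bias
        self.device = device

    def save(self, filename: str) -> None:
        # Persist the model state only, not its device placement.
        with open(filename, "wb") as f:
            pickle.dump({"bias": self.bias}, f)

    @classmethod
    def load(cls, filename: str, device: str) -> "PickleLearner":
        with open(filename, "rb") as f:
            state = pickle.load(f)
        # The caller decides where the restored model lives.
        return cls(bias=state["bias"], device=device)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    PickleLearner(bias=1.5).save(path)
    restored = PickleLearner.load(path, device="cuda")

print(restored.bias, restored.device)  # 1.5 cuda
```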
- abstractmethod set_bias(bias: ndarray | float, *args, **kwargs) → None
Sets the bias term for the model.
- Parameters:
bias (Union[np.ndarray, float]) – Bias value(s).
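In gradient boosting, the bias is the constant term every prediction starts from before any tree output is added. One common choice, assumed here rather than taken from GBRL, is the per-output mean of the training targets, which is the constant that minimizes squared error. The hypothetical helper `mean_bias` computes it; the result could then be passed (for example as a NumPy array) to a concrete learner's set_bias.

```python
from typing import List, Sequence


def mean_bias(targets: Sequence[Sequence[float]]) -> List[float]:
    """Per-output mean of the targets: the constant that minimizes squared error."""
    n = len(targets)
    output_dim = len(targets[0])
    return [sum(row[j] for row in targets) / n for j in range(output_dim)]


targets = [[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]]
bias = mean_bias(targets)
print(bias)  # [3.0, 20.0]
```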
- abstractmethod set_device(device: str | th.device, *args, **kwargs) → None
Sets the device the model should run on.
- Parameters:
device (Union[str, th.device]) – Target device.
- abstractmethod set_feature_weights(feature_weights: ndarray | float, *args, **kwargs) → None
Sets the feature importance weights.
- Parameters:
feature_weights (Union[np.ndarray, float]) – Feature weights.
- abstractmethod shap(features: ndarray | Tensor, *args, **kwargs) → ndarray | Tuple[ndarray, ...]
Computes SHAP values for the entire model.
Applies Linear TreeSHAP to each tree in the ensemble, sequentially. The implementation is based on https://github.com/yupbank/linear_tree_shap; see Linear TreeShap, Yu et al., 2023, https://arxiv.org/pdf/2209.08192.
- Parameters:
features (NumericalData) – Input features.
- Returns:
SHAP values.
- Return type:
Union[np.ndarray, Tuple[np.ndarray, …]]
- abstractmethod step(*args, **kwargs) → None
Performs a single update step using provided gradients.
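In gradient boosting, one update step typically fits a new tree to the gradients of the loss with respect to the current predictions and adds it to the ensemble, scaled by a learning rate. The toy sketch below illustrates that idea with a hypothetical `ToyBooster` that fits a single-split regression stump to the negative gradients of a squared-error loss. GBRL's C++ backend builds its trees with its own split criteria and tree structure (and may equivalently fit the gradients and subtract), so treat this only as an illustration of the concept.

```python
from typing import List, Sequence, Tuple

# (feature index, threshold, value if feature < threshold, value otherwise)
Stump = Tuple[int, float, float, float]


def fit_stump(features: Sequence[Sequence[float]], targets: Sequence[float]) -> Stump:
    """Choose the single split that minimizes squared error against `targets`."""
    best = None
    for j in range(len(features[0])):
        for threshold in sorted({row[j] for row in features}):
            left = [t for row, t in zip(features, targets) if row[j] < threshold]
            right = [t for row, t in zip(features, targets) if row[j] >= threshold]
            if not left or not right:
                continue  # a split must put samples on both sides
            left_value = sum(left) / len(left)
            right_value = sum(right) / len(right)
            error = sum((t - left_value) ** 2 for t in left) + sum((t - right_value) ** 2 for t in right)
            if best is None or error < best[0]:
                best = (error, (j, threshold, left_value, right_value))
    if best is None:
        raise ValueError("need at least two distinct values in some feature")
    return best[1]


class ToyBooster:
    """Hypothetical additive ensemble: prediction = bias + lr * sum of stump outputs."""

    def __init__(self, bias: float, lr: float):
        self.bias = bias
        self.lr = lr
        self.trees: List[Stump] = []

    def predict(self, features: Sequence[Sequence[float]]) -> List[float]:
        preds = []
        for row in features:
            out = self.bias
            for j, threshold, left_value, right_value in self.trees:
                out += self.lr * (left_value if row[j] < threshold else right_value)
            preds.append(out)
        return preds

    def step(self, features: Sequence[Sequence[float]], gradients: Sequence[float]) -> None:
        # Fit the new stump to the negative gradients so that adding it reduces the loss.
        self.trees.append(fit_stump(features, [-g for g in gradients]))


features = [[0.0], [1.0], [2.0], [3.0]]
targets = [0.0, 0.0, 1.0, 1.0]
booster = ToyBooster(bias=0.5, lr=0.5)

# For the loss 0.5 * (pred - target)^2, the gradient w.r.t. pred is (pred - target).
gradients = [p - t for p, t in zip(booster.predict(features), targets)]
booster.step(features, gradients)
print(booster.predict(features))  # [0.25, 0.25, 0.75, 0.75]
```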
- abstractmethod tree_shap(tree_idx: int, features: ndarray | Tensor, *args, **kwargs) → ndarray | Tuple[ndarray, ...]
Computes SHAP values for a specific tree.
The implementation is based on https://github.com/yupbank/linear_tree_shap; see Linear TreeShap, Yu et al., 2023, https://arxiv.org/pdf/2209.08192.
- Parameters:
tree_idx (int) – Index of the tree.
features (NumericalData) – Input features.
- Returns:
SHAP values.
- Return type:
Union[np.ndarray, Tuple[np.ndarray, …]]
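shap and tree_shap are related by a property of Shapley values that can be checked by brute force on tiny models. Shapley values are additive across the trees of an ensemble, so ensemble values can be accumulated tree by tree (scaled by the learning rate, with a constant bias contributing nothing), and they sum to the prediction minus a baseline. The sketch below enumerates every feature subset (exponential cost) using an interventional value function over a small background set. It is a hypothetical reference check, not Linear TreeSHAP: the trees, `BIAS`, `LR`, and helper names are all illustrative, and its numbers need not match GBRL's output, because Linear TreeSHAP uses its own tree-based definition of the value function.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List, Sequence

Model = Callable[[Sequence[float]], float]
BIAS, LR = 0.5, 0.1  # hypothetical ensemble settings


def tree_a(x: Sequence[float]) -> float:
    """Hand-written depth-2 regression tree over 3 features."""
    if x[0] < 0.5:
        return 1.0 if x[1] < 0.5 else 2.0
    return 3.0 if x[2] < 0.5 else 4.0


def tree_b(x: Sequence[float]) -> float:
    """Hand-written stump on feature 1."""
    return -1.0 if x[1] < 0.5 else 1.0


def ensemble(x: Sequence[float]) -> float:
    """Additive ensemble: bias plus learning-rate-scaled tree outputs."""
    return BIAS + LR * (tree_a(x) + tree_b(x))


def value(model: Model, x: Sequence[float], background: List[List[float]], subset: Sequence[int]) -> float:
    """Interventional value: features in `subset` come from x, the rest from each background row."""
    total = 0.0
    for row in background:
        mixed = [x[j] if j in subset else row[j] for j in range(len(x))]
        total += model(mixed)
    return total / len(background)


def shapley(model: Model, x: Sequence[float], background: List[List[float]]) -> List[float]:
    """Exact Shapley values by enumerating every subset of the other features."""
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            for s in combinations(others, k):
                phi[i] += weight * (value(model, x, background, s + (i,)) - value(model, x, background, s))
    return phi


background = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
x = [1.0, 0.0, 1.0]

phi_ensemble = shapley(ensemble, x, background)
phi_a = shapley(tree_a, x, background)
phi_b = shapley(tree_b, x, background)

# Per-tree values, scaled by the learning rate, add up to the ensemble values.
per_tree_sum = [LR * (a + b) for a, b in zip(phi_a, phi_b)]
# Efficiency: the values sum to the prediction minus the baseline (empty-subset value).
gap = ensemble(x) - value(ensemble, x, background, ())
print(phi_ensemble, per_tree_sum, sum(phi_ensemble), gap)
```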