MultiGBTLearner
A multi-model extension of GBTLearner supporting several independent learners, each with its own parameters and optimizer. Useful in actor-critic settings or multi-head architectures. It supports training, prediction, saving/loading, and SHAP value computation per ensemble.
- class gbrl.learners.multi_gbt_learner.MultiGBTLearner(input_dim: int | List[int], output_dim: int | List[int], tree_struct: Dict, optimizers: Dict | List[Dict], params: Dict, n_learners: int, verbose: int = 0, device: str = 'cpu')[source]
Bases: BaseLearner
MultiGBTLearner is a gradient boosted tree learner that holds multiple GBT models and uses a C++ backend for efficient computation. It supports training, prediction, saving, loading, and SHAP value computation.
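Many methods below take an optional `model_idx` and are annotated to return either a single value or a tuple. The following minimal sketch shows that convention. It assumes that `model_idx=None` means "apply to every learner" and that an integer selects one learner. This reading is inferred from the return-type annotations and is not confirmed by this page. `ToyMultiLearner` and its attributes are hypothetical stand-ins for the C++-backed ensembles.

```python
from typing import List, Optional, Tuple, Union


class ToyMultiLearner:
    """Hypothetical stand-in: tracks state for several independent learners."""

    def __init__(self, n_learners: int) -> None:
        # One iteration counter per learner; learners train independently.
        self.iterations: List[int] = [0] * n_learners

    def step(self, model_idx: Optional[int] = None) -> None:
        # Assumed convention: None -> update every learner, int -> update one.
        selected = range(len(self.iterations)) if model_idx is None else [model_idx]
        for i in selected:
            self.iterations[i] += 1

    def get_iteration(self, model_idx: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
        # Assumed convention: None -> tuple over all learners, int -> single value.
        if model_idx is None:
            return tuple(self.iterations)
        return self.iterations[model_idx]


# Actor-critic style usage (illustrative): learner 0 = actor, learner 1 = critic.
learner = ToyMultiLearner(n_learners=2)
learner.step()             # both learners advance
learner.step(model_idx=1)  # only the critic advances
print(learner.get_iteration())             # (1, 2)
print(learner.get_iteration(model_idx=0))  # 1
```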
- distil(obs: ndarray | Tensor, targets: List[ndarray], params: Dict, verbose: int = 0) Tuple[List[int], List[Dict]] [source]
Distills the model into a student model.
- Parameters:
obs (NumericalData) – Input observations.
targets (List[np.ndarray]) – Target values.
params (Dict) – Distillation parameters.
verbose (int, optional) – Verbosity level. Defaults to 0.
- Returns:
The final loss values and the updated parameters.
- Return type:
Tuple[List[int], List[Dict]]
- export(filename: str, modelname: str = None) None [source]
Exports the model to a C header file.
- Parameters:
filename (str) – The filename to export the model to.
modelname (str, optional) – The name of the model in the C code. Defaults to None.
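The following minimal sketch shows what exporting a tree as C code involves. It turns a single depth-1 tree into an inline C function inside a header string. The dict-based tree format, the `emit_c_header` helper, the `<=`-goes-left split convention, and the generated layout are all hypothetical. The header that `export` actually writes may look quite different.

```python
from typing import Dict


def emit_c_header(tree: Dict, modelname: str = "model") -> str:
    """Hypothetical: render a depth-1 tree (one split, two leaves) as a C header."""
    guard = modelname.upper() + "_H"
    feature = int(tree["feature"])
    return "\n".join([
        f"#ifndef {guard}",
        f"#define {guard}",
        "",
        f"static inline float {modelname}_predict(const float *x) {{",
        f"    /* split on feature {feature}; values <= threshold go left */",
        f"    if (x[{feature}] > {float(tree['threshold'])}f)",
        f"        return {float(tree['right'])}f;",
        f"    return {float(tree['left'])}f;",
        "}",
        "",
        f"#endif /* {guard} */",
    ])


stump = {"feature": 2, "threshold": 0.5, "left": -1.0, "right": 1.5}
print(emit_c_header(stump, modelname="policy"))
```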
- fit(features: ndarray | Tensor, targets: List[ndarray | Tensor] | ndarray | Tensor, iterations: int, shuffle: bool = True, loss_type: str = 'MultiRMSE', model_idx: int | None = None) float | List[float] [source]
Fits the model to the provided features and targets for a given number of iterations.
- Parameters:
features (NumericalData) – Input features.
targets (Union[List[NumericalData], NumericalData]) – Target values.
iterations (int) – Number of training iterations.
shuffle (bool, optional) – Whether to shuffle the data. Defaults to True.
loss_type (str, optional) – Type of loss function. Defaults to ‘MultiRMSE’.
model_idx (int, optional) – The index of the model.
- Returns:
The final loss value(s).
- Return type:
Union[float, List[float]]
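The default `loss_type` is `'MultiRMSE'`. The following minimal NumPy sketch implements one common definition: the square root of the per-sample squared Euclidean error, averaged over samples (a CatBoost-style MultiRMSE). It also computes the gradient with respect to the predictions, which is the kind of quantity a boosting step fits its next tree to. Whether GBRL uses exactly this normalization, or the plain squared-error gradient internally, is an assumption.

```python
import numpy as np


def multi_rmse(pred: np.ndarray, target: np.ndarray) -> float:
    # Assumed definition: sqrt of the mean (over samples) of the squared
    # Euclidean error summed across output dimensions.
    return float(np.sqrt(np.mean(np.sum((pred - target) ** 2, axis=1))))


def multi_rmse_grad(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    # d/d pred of sqrt(mean_i ||pred_i - target_i||^2)
    #   = (pred - target) / (n_samples * loss)
    loss = multi_rmse(pred, target)
    return (pred - target) / (pred.shape[0] * loss)


rng = np.random.default_rng(0)
pred = rng.normal(size=(4, 3))    # 4 samples, 3 outputs
target = rng.normal(size=(4, 3))
grad = multi_rmse_grad(pred, target)

# Central finite-difference check of one gradient entry.
eps = 1e-6
bump = np.zeros_like(pred)
bump[1, 2] = eps
numeric = (multi_rmse(pred + bump, target) - multi_rmse(pred - bump, target)) / (2 * eps)
print(abs(numeric - grad[1, 2]) < 1e-6)  # True
```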
- get_bias(model_idx: int | None = None) ndarray | Tuple[ndarray, ...] [source]
Returns the bias of the model.
- Parameters:
model_idx (int, optional) – The index of the model.
- Returns:
The bias.
- Return type:
Union[np.ndarray, Tuple[np.ndarray, …]]
- get_device(model_idx: int | None = None) str | Tuple[str, ...] [source]
Returns the device the model is running on.
- Parameters:
model_idx (int, optional) – The index of the model.
- Returns:
The device.
- Return type:
Union[str, Tuple[str, …]]
- get_feature_weights(model_idx: int | None = None) ndarray | Tuple[ndarray, ...] [source]
Returns the feature weights of the model.
- Parameters:
model_idx (int, optional) – The index of the model.
- Returns:
The feature weights.
- Return type:
Union[np.ndarray, Tuple[np.ndarray, …]]
- get_iteration(model_idx: int | None = None) int | Tuple[int, int] [source]
Returns the current iteration number.
- Parameters:
model_idx (int, optional) – The index of the model.
- Returns:
The current iteration number.
- Return type:
Union[int, Tuple[int, int]]
- get_num_trees(model_idx: int | None = None) int | Tuple[int, int] [source]
Returns the total number of trees in the ensemble.
- Parameters:
model_idx (int, optional) – The index of the model.
- Returns:
The total number of trees.
- Return type:
Union[int, Tuple[int, int]]
- get_schedule_learning_rates(model_idx: int | None = None) int | Tuple[int, int] [source]
Returns the learning rates of the schedulers.
- Parameters:
model_idx (int, optional) – The index of the model.
- Returns:
The learning rates.
- Return type:
Union[int, Tuple[int, int]]
- classmethod load(filename: str, device: str) MultiGBTLearner [source]
Loads a MultiGBTLearner model from files.
- Parameters:
filename (str) – The filename to load the model from.
device (str) – The device to load the model onto.
- Returns:
The loaded MultiGBTLearner instance.
- Return type:
MultiGBTLearner
- plot_tree(tree_idx: int, filename: str, model_idx: int | None = None) None [source]
Plots the tree at the given index and saves it to a file.
- Parameters:
tree_idx (int) – The index of the tree to plot.
filename (str) – The filename to save the plot to.
model_idx (int, optional) – The index of the model to plot.
- predict(features: ndarray | Tensor, requires_grad: bool = True, start_idx: int = 0, stop_idx: int = None, tensor: bool = True, model_idx: int | None = None) ndarray | Tensor | List[ndarray | Tensor] [source]
Predicts the output for the given features.
- Parameters:
features (NumericalData) – Input features.
requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.
start_idx (int, optional) – Start index for prediction. Defaults to 0.
stop_idx (int, optional) – Stop index for prediction. Defaults to None.
tensor (bool, optional) – Whether to return a tensor. Defaults to True.
model_idx (int, optional) – The index of the model to predict with.
- Returns:
The predicted output.
- Return type:
Union[NumericalData, List[NumericalData]]
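A boosted ensemble's output is a sum of per-tree contributions, so `start_idx` and `stop_idx` can be read as selecting a slice of trees. The following minimal sketch shows that reading. It assumes that the bias is always added, that `stop_idx=None` means "up to the last tree", and that per-tree outputs are already scaled by the learning rate. The trees are represented by precomputed outputs, and `predict_slice` is a hypothetical helper, not the GBRL implementation.

```python
from typing import List, Optional

import numpy as np


def predict_slice(bias: np.ndarray, tree_outputs: List[np.ndarray],
                  start_idx: int = 0, stop_idx: Optional[int] = None) -> np.ndarray:
    """Hypothetical: bias plus the summed outputs of trees[start_idx:stop_idx]."""
    # Python slicing treats stop_idx=None as "through the last tree".
    total = bias.copy()
    for out in tree_outputs[start_idx:stop_idx]:
        total = total + out
    return total


bias = np.array([0.5])
trees = [np.array([1.0]), np.array([2.0]), np.array([4.0])]  # per-tree outputs, one sample
print(predict_slice(bias, trees))              # [7.5]
print(predict_slice(bias, trees, start_idx=1)) # [6.5]
print(predict_slice(bias, trees, stop_idx=2))  # [3.5]
```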
- print_ensemble_metadata(model_idx: int | None = None) None [source]
Prints the metadata of the ensemble.
- Parameters:
model_idx (int, optional) – The index of the model.
- print_tree(tree_idx: int, model_idx: int | None = None) None [source]
Prints the tree at the given index.
- Parameters:
tree_idx (int) – The index of the tree to print.
model_idx (int, optional) – The index of the model to print.
- reset() None [source]
Resets the learner to its initial state, reinitializing the C++ model and optimizers.
- save(filename: str, custom_names: List | None = None) None [source]
Saves the models to a file.
- Parameters:
filename (str) – The filename to save the model to.
- set_bias(bias: ndarray | float, model_idx: int | None = None) None [source]
Sets the bias of the model.
- Parameters:
bias (Union[np.ndarray, float]) – The bias value.
model_idx (int, optional) – The index of the model whose bias to set.
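For squared-error-type losses, a natural constant bias is the per-output mean of the targets, because the mean minimizes the squared error of a constant prediction. The following minimal sketch computes such a value. Whether GBRL initializes its bias this way is an assumption. Per its signature, `set_bias` accepts any array or float you pass.

```python
import numpy as np

targets = np.array([[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]])  # 3 samples, 2 outputs
bias = targets.mean(axis=0)  # per-output constant minimizing squared error


def mse_of_constant(c: np.ndarray) -> float:
    # Mean squared error of predicting the constant c for every sample.
    return float(np.mean((targets - c) ** 2))


# Moving the constant away from the mean in either direction increases the error.
print(mse_of_constant(bias) < mse_of_constant(bias + 0.1))  # True
print(mse_of_constant(bias) < mse_of_constant(bias - 0.1))  # True
```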
- set_device(device: str | device, model_idx: int | None = None) None [source]
Sets the device the model should run on.
- Parameters:
device (Union[str, th.device]) – The device to set.
model_idx (int, optional) – The index of the model to move.
- set_feature_weights(feature_weights: ndarray | float, model_idx: int | None = None) None [source]
Sets the feature weights of the model.
- Parameters:
feature_weights (Union[np.ndarray, float]) – The feature weights.
model_idx (int, optional) – The index of the model.
- shap(features: ndarray | Tensor, model_idx: int | None = None) ndarray | Tuple[ndarray, ...] [source]
Computes SHAP values for the entire ensemble.
Uses Linear TreeSHAP on each tree in the ensemble, sequentially. Implementation based on https://github.com/yupbank/linear_tree_shap; see Linear TreeShap, Yu et al., 2023, https://arxiv.org/pdf/2209.08192.
- Parameters:
features (NumericalData) – Input features.
model_idx (int, optional) – The index of the model.
- Returns:
SHAP values.
- Return type:
Union[np.ndarray, Tuple[np.ndarray, …]]
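Linear TreeSHAP computes path-dependent TreeSHAP values in linear time. On tiny trees, the same values can be checked by brute force. The following sketch enumerates every feature subset and uses the node covers to average over features that are not fixed. It then sums the per-tree values to get ensemble values, since Shapley values are additive across the trees of a sum ensemble. This is an exponential-time reference for illustration, not the Linear TreeSHAP algorithm. The dict-based tree format, the `<=`-goes-left convention, and both helpers are hypothetical.

```python
from itertools import combinations
from math import factorial
from typing import Dict, FrozenSet

import numpy as np


def expected_value(node: Dict, x: np.ndarray, fixed: FrozenSet[int]) -> float:
    """Tree output with features in `fixed` taken from x; others averaged by cover."""
    if "value" in node:  # leaf
        return node["value"]
    left, right = node["left"], node["right"]
    if node["feature"] in fixed:
        child = left if x[node["feature"]] <= node["threshold"] else right
        return expected_value(child, x, fixed)
    # Unknown feature: weight each branch by its share of training samples (cover).
    total = left["cover"] + right["cover"]
    return (left["cover"] * expected_value(left, x, fixed)
            + right["cover"] * expected_value(right, x, fixed)) / total


def tree_shap_bruteforce(tree: Dict, x: np.ndarray) -> np.ndarray:
    """Exact Shapley values by enumerating all subsets (exponential in len(x))."""
    n = len(x)
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (expected_value(tree, x, s | {i})
                                    - expected_value(tree, x, s))
    return phi


def leaf(value: float, cover: float) -> Dict:
    return {"value": value, "cover": cover}


# Two tiny trees over 2 features; covers are training-sample counts per node.
tree_a = {"feature": 0, "threshold": 0.5, "cover": 10,
          "left": leaf(1.0, 6), "right": leaf(3.0, 4)}
tree_b = {"feature": 1, "threshold": 0.0, "cover": 10,
          "left": leaf(-1.0, 5),
          "right": {"feature": 0, "threshold": 0.2, "cover": 5,
                    "left": leaf(0.0, 2), "right": leaf(2.0, 3)}}
ensemble = [tree_a, tree_b]

x = np.array([0.9, 1.0])
# Ensemble SHAP values are the sum of per-tree SHAP values.
phi = sum(tree_shap_bruteforce(t, x) for t in ensemble)
prediction = sum(expected_value(t, x, frozenset(range(2))) for t in ensemble)
base = sum(expected_value(t, x, frozenset()) for t in ensemble)
print(np.isclose(phi.sum(), prediction - base))  # local accuracy: True
```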
- step(features: ndarray | Tensor | Tuple[ndarray | Tensor, ...], grads: List[ndarray | Tensor] | ndarray | Tensor, model_idx: int | None = None) None [source]
Performs a single gradient update step (e.g., adding a single decision tree).
- Parameters:
features (Union[np.ndarray, th.Tensor, Tuple]) – Input features.
grads (Union[List[NumericalData], NumericalData]) – Gradients.
model_idx (int, optional) – The index of the model.
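A single step adds one tree fitted to the supplied gradients. The following minimal NumPy sketch shows the idea with a depth-1 multi-output regression tree (a stump) whose leaves hold the negative mean gradient, scaled by a learning rate. The stump search, the `learning_rate` value, and the `fit_stump` and `boosting_step` helpers are illustrative assumptions. GBRL builds its trees in the C++ backend with its own split criteria.

```python
from typing import Tuple

import numpy as np


def fit_stump(features: np.ndarray,
              grads: np.ndarray) -> Tuple[int, float, np.ndarray, np.ndarray]:
    """Pick the (feature, threshold) split that best separates the gradients."""
    best = None
    for f in range(features.shape[1]):
        # Skip the largest value so both sides of the split are non-empty.
        for thr in np.unique(features[:, f])[:-1]:
            mask = features[:, f] <= thr
            left, right = grads[mask], grads[~mask]
            # Within-leaf squared error of the gradients, summed over outputs.
            sse = (((left - left.mean(axis=0)) ** 2).sum()
                   + ((right - right.mean(axis=0)) ** 2).sum())
            if best is None or sse < best[0]:
                best = (sse, f, float(thr), left.mean(axis=0), right.mean(axis=0))
    if best is None:  # every feature is constant: fall back to a single leaf
        mean = grads.mean(axis=0)
        return 0, float("inf"), mean, mean
    _, f, thr, left_mean, right_mean = best
    return f, thr, left_mean, right_mean


def boosting_step(features: np.ndarray, grads: np.ndarray, preds: np.ndarray,
                  learning_rate: float = 0.1) -> np.ndarray:
    """Add one stump whose leaves step against the mean gradient; return new predictions."""
    f, thr, left_mean, right_mean = fit_stump(features, grads)
    goes_left = features[:, f] <= thr
    update = np.where(goes_left[:, None], -left_mean, -right_mean)
    return preds + learning_rate * update


rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
y = np.stack([np.sign(X[:, 0]), 2.0 * (X[:, 1] > 0)], axis=1)  # two outputs
preds = np.zeros_like(y)
for _ in range(50):
    grads = preds - y  # gradient of 0.5 * squared error w.r.t. the predictions
    preds = boosting_step(X, grads, preds)
print(np.mean((preds - y) ** 2) < np.mean(y ** 2))  # True
```

Repeating the step, as in the loop above, is what a fixed number of boosting iterations amounts to in this toy setting.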
- tree_shap(tree_idx: int, features: ndarray | Tensor, model_idx: int | None = None) ndarray | Tuple[ndarray, ...] [source]
Computes SHAP values for a single tree.
Implementation based on https://github.com/yupbank/linear_tree_shap; see Linear TreeShap, Yu et al., 2023, https://arxiv.org/pdf/2209.08192.
- Parameters:
tree_idx (int) – The index of the tree.
features (NumericalData) – Input features.
model_idx (int, optional) – The index of the model.
- Returns:
SHAP values.
- Return type:
Union[np.ndarray, Tuple[np.ndarray, …]]