MultiGBTLearner

A multi-model extension of GBTLearner supporting several independent learners, each with its own parameters and optimizer. Useful in actor-critic settings or multi-head architectures. It supports training, prediction, saving/loading, and SHAP value computation per ensemble.

class gbrl.learners.multi_gbt_learner.MultiGBTLearner(input_dim: int | List[int], output_dim: int | List[int], tree_struct: Dict, optimizers: Dict | List[Dict], params: Dict, n_learners: int, verbose: int = 0, device: str = 'cpu')[source]

Bases: BaseLearner

MultiGBTLearner is a gradient boosted tree (GBT) learner that holds multiple GBT models and uses a C++ backend for efficient computation. It supports training, prediction, saving, loading, and SHAP value computation.
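The constructor accepts either a single value or a list for input_dim, output_dim, and optimizers. The sketch below shows one way such arguments could be expanded to one entry per learner. It is a minimal illustration and not the library's implementation: broadcast_per_learner is a hypothetical helper, the "lr" key is a placeholder, and the rule that a single value is shared by every learner is an assumption.

```python
from typing import Any, List


def broadcast_per_learner(value: Any, n_learners: int) -> List[Any]:
    """Return one entry per learner.

    Hypothetical helper, not part of gbrl: a single value is repeated for
    every learner, while a list must already hold n_learners entries.
    """
    if isinstance(value, list):
        if len(value) != n_learners:
            raise ValueError(
                f"expected {n_learners} entries, got {len(value)}"
            )
        return list(value)
    # Copy dicts so that learners do not share one mutable object.
    return [dict(value) if isinstance(value, dict) else value
            for _ in range(n_learners)]


# One shared input dimension, separate output dimensions per learner.
input_dims = broadcast_per_learner(8, n_learners=2)
output_dims = broadcast_per_learner([4, 1], n_learners=2)
# Placeholder optimizer settings; the key name is illustrative only.
optimizers = broadcast_per_learner({"lr": 0.1}, n_learners=2)
```

Copying the dictionary for each learner keeps a later change to one learner's settings from silently affecting the others.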

distil(obs: ndarray | Tensor, targets: List[ndarray], params: Dict, verbose: int = 0) Tuple[List[int], List[Dict]][source]

Distills the model into a student model.

Parameters:
  • obs (NumericalData) – Input observations.

  • targets (List[np.ndarray]) – Target values.

  • params (Dict) – Distillation parameters.

  • verbose (int, optional) – Verbosity level. Defaults to 0.

Returns:

The final loss and updated parameters.

Return type:

Tuple[List[int], List[Dict]]

export(filename: str, modelname: str = None) None[source]

Exports the model to a C header file.

Parameters:
  • filename (str) – The filename to export the model to.

  • modelname (str, optional) – The name of the model in the C code. Defaults to None.

fit(features: ndarray | Tensor, targets: List[ndarray | Tensor] | ndarray | Tensor, iterations: int, shuffle: bool = True, loss_type: str = 'MultiRMSE', model_idx: int | None = None) float | List[float][source]

Fits the model to the provided features and targets for a given number of iterations.

Parameters:
  • features (NumericalData) – Input features.

  • targets (Union[List[NumericalData], NumericalData]) – Target values.

  • iterations (int) – Number of training iterations.

  • shuffle (bool, optional) – Whether to shuffle the data. Defaults to True.

  • loss_type (str, optional) – Type of loss function. Defaults to ‘MultiRMSE’.

  • model_idx (int, optional) – The index of the model.

Returns:

The final loss value.

Return type:

Union[float, List[float]]
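The signature allows targets to be either a list or a single array, and the return value to be either a float or a list of floats. The sketch below illustrates one plausible reading, assumed rather than confirmed by the text above: with model_idx=None there is one target array and one loss per learner, and with an integer model_idx there is a single target array and a single loss. toy_fit is a hypothetical stand-in in which each learner simply predicts the mean of its targets; it does not call gbrl.

```python
import math
from typing import List, Optional, Sequence, Union


def rmse(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Root mean squared error between two equally long sequences."""
    squared = sum((p - t) ** 2 for p, t in zip(preds, targets))
    return math.sqrt(squared / len(targets))


def toy_fit(
    targets: Union[List[List[float]], List[float]],
    n_learners: int,
    model_idx: Optional[int] = None,
) -> Union[float, List[float]]:
    """Hypothetical stand-in for fit(): each learner predicts a constant."""

    def fit_one(t: Sequence[float]) -> float:
        mean = sum(t) / len(t)
        return rmse([mean] * len(t), t)

    if model_idx is not None:
        # Assumed convention: a single target array for the chosen learner.
        return fit_one(targets)
    if len(targets) != n_learners:
        raise ValueError("expected one target array per learner")
    # Assumed convention: one loss per learner, in learner order.
    return [fit_one(t) for t in targets]


all_losses = toy_fit([[1.0, 3.0], [2.0, 2.0]], n_learners=2)
single_loss = toy_fit([1.0, 3.0], n_learners=2, model_idx=0)
```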

get_bias(model_idx: int | None = None) ndarray | Tuple[ndarray, ...][source]

Returns the bias of the model.

Parameters:

model_idx (int, optional) – The index of the model.

Returns:

The bias.

Return type:

Union[np.ndarray, Tuple[np.ndarray, …]]

get_device(model_idx: int | None = None) str | Tuple[str, ...][source]

Returns the device the model is running on.

Parameters:

model_idx (int, optional) – The index of the model.

Returns:

The device.

Return type:

Union[str, Tuple[str, …]]

get_feature_weights(model_idx: int | None = None) ndarray | Tuple[ndarray, ...][source]

Returns the feature weights of the model.

Parameters:

model_idx (int, optional) – The index of the model.

Returns:

The feature weights.

Return type:

Union[np.ndarray, Tuple[np.ndarray, …]]

get_iteration(model_idx: int | None = None) int | Tuple[int, int][source]

Returns the current iteration number.

Parameters:

model_idx (int, optional) – The index of the model.

Returns:

The current iteration number.

Return type:

Union[int, Tuple[int, int]]

get_num_trees(model_idx: int | None = None) int | Tuple[int, int][source]

Returns the total number of trees in the ensemble.

Parameters:

model_idx (int, optional) – The index of the model.

Returns:

The total number of trees.

Return type:

Union[int, Tuple[int, int]]

get_schedule_learning_rates(model_idx: int | None = None) int | Tuple[int, int][source]

Returns the learning rates of the schedulers.

Parameters:

model_idx (int, optional) – The index of the model.

Returns:

The learning rates.

Return type:

Union[int, Tuple[int, int]]

classmethod load(filename: str, device: str) MultiGBTLearner[source]

Loads a MultiGBTLearner model from files.

Parameters:
  • filename (str) – The filename to load the model from.

  • device (str) – The device to load the model onto.

Returns:

The loaded MultiGBTLearner instance.

Return type:

MultiGBTLearner

plot_tree(tree_idx: int, filename: str, model_idx: int | None = None) None[source]

Plots the tree at the given index and saves it to a file.

Parameters:
  • tree_idx (int) – The index of the tree to plot.

  • filename (str) – The filename to save the plot to.

  • model_idx (int, optional) – The index of the model to plot.

predict(features: ndarray | Tensor, requires_grad: bool = True, start_idx: int = 0, stop_idx: int = None, tensor: bool = True, model_idx: int | None = None) ndarray | Tensor | List[ndarray | Tensor][source]

Predicts the output for the given features.

Parameters:
  • features (NumericalData) – Input features.

  • requires_grad (bool, optional) – Whether to compute gradients. Defaults to True.

  • start_idx (int, optional) – Start index for prediction. Defaults to 0.

  • stop_idx (int, optional) – Stop index for prediction. Defaults to None.

  • tensor (bool, optional) – Whether to return a tensor. Defaults to True.

  • model_idx (int, optional) – The index of the model to predict with.

Returns:

The predicted output.

Return type:

Union[NumericalData, List[NumericalData]]
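The start_idx and stop_idx parameters are described only as start and stop indices. The sketch below assumes they select a half-open range of trees within one additive ensemble, and that the bias is always added; neither assumption is confirmed by the text above. partial_predict and the three lambda "trees" are hypothetical and written in plain Python.

```python
from typing import Callable, List, Optional, Sequence

Tree = Callable[[Sequence[float]], float]


def partial_predict(
    trees: List[Tree],
    bias: float,
    x: Sequence[float],
    start_idx: int = 0,
    stop_idx: Optional[int] = None,
) -> float:
    """Sum the bias and the outputs of trees[start_idx:stop_idx]."""
    # stop_idx=None keeps every tree from start_idx onward.
    selected = trees[start_idx:stop_idx]
    return bias + sum(tree(x) for tree in selected)


# Three toy trees; each maps a feature vector to one leaf value.
trees: List[Tree] = [
    lambda x: 1.0 if x[0] <= 0.5 else 2.0,
    lambda x: 0.5 if x[1] <= 0.5 else -0.5,
    lambda x: 0.25,
]
x = [1.0, 0.0]

full = partial_predict(trees, 0.0, x)
first_two = partial_predict(trees, 0.0, x, stop_idx=2)
last_only = partial_predict(trees, 1.0, x, start_idx=2)
```

Restricting the range in this way is mainly useful for inspecting how the prediction builds up as trees are added.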

print_ensemble_metadata(model_idx: int | None = None) None[source]

Prints the metadata of the ensemble.

Parameters:

model_idx (int, optional) – The index of the model.

print_tree(tree_idx: int, model_idx: int | None = None) None[source]

Prints the tree at the given index.

Parameters:
  • tree_idx (int) – The index of the tree to print.

  • model_idx (int, optional) – The index of the model to print.

reset() None[source]

Resets the learner to its initial state, reinitializing the C++ model and optimizers.

save(filename: str, custom_names: List | None = None) None[source]

Saves the models to a file.

Parameters:
  • filename (str) – The filename to save the model to.

  • custom_names (List, optional) – Defaults to None.

set_bias(bias: ndarray | float, model_idx: int | None = None) None[source]

Sets the bias of the model.

Parameters:
  • bias (Union[np.ndarray, float]) – The bias value.

  • model_idx (int, optional) – The index of the model whose bias is set.

set_device(device: str | device, model_idx: int | None = None) None[source]

Sets the device the model should run on.

Parameters:
  • device (Union[str, th.device]) – The device to set.

  • model_idx (int, optional) – The index of the model whose device is set.

set_feature_weights(feature_weights: ndarray | float, model_idx: int | None = None) None[source]

Sets the feature weights of the model.

Parameters:
  • feature_weights (Union[np.ndarray, float]) – The feature weights.

  • model_idx (int, optional) – The index of the model.

shap(features: ndarray | Tensor, model_idx: int | None = None) ndarray | Tuple[ndarray, ...][source]

Computes SHAP values for the entire ensemble.

Uses Linear TreeShap for each tree in the ensemble (sequentially). Implementation based on https://github.com/yupbank/linear_tree_shap. See Linear TreeShap, Yu et al., 2023, https://arxiv.org/pdf/2209.08192.

Parameters:
  • features (NumericalData) – Input features.

  • model_idx (int, optional) – The index of the model.

Returns:

shap values

Return type:

Union[np.ndarray, Tuple[np.ndarray, …]]
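As a reference point for what SHAP values represent, the sketch below computes exact Shapley values by brute force for a tiny hand-written tree. It is not Linear TreeShap and does not call gbrl: toy_tree, the background set, and the choice of value function (averaging over background rows for the features left out) are assumptions made for illustration, and the library may define the expectation differently. The cost is exponential in the number of features, so this is only practical for toy inputs.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List, Sequence, Set


def toy_tree(x: Sequence[float]) -> float:
    """Hypothetical depth-2 decision tree on two features."""
    if x[0] <= 0.5:
        return 1.0
    return 3.0 if x[1] <= 0.5 else 7.0


def shapley_values(
    f: Callable[[Sequence[float]], float],
    x: Sequence[float],
    background: List[Sequence[float]],
) -> List[float]:
    """Exact Shapley values of f at x, enumerating every feature subset."""
    n = len(x)

    def value(subset: Set[int]) -> float:
        # Features in `subset` come from x, the rest from each background row.
        total = 0.0
        for b in background:
            mixed = [x[j] if j in subset else b[j] for j in range(n)]
            total += f(mixed)
        return total / len(background)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for coalitions of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                gain = value(set(subset) | {i}) - value(set(subset))
                phi[i] += weight * gain
    return phi


background = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
x = [1.0, 1.0]
phi = shapley_values(toy_tree, x, background)
baseline = sum(toy_tree(b) for b in background) / len(background)
```

The attributions sum to the difference between the prediction at x and the average prediction over the background set, which is the additivity property SHAP values are expected to satisfy.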

step(features: ndarray | Tensor | Tuple[ndarray | Tensor, ...], grads: List[ndarray | Tensor] | ndarray | Tensor, model_idx: int | None = None) None[source]

Performs a single gradient update step (e.g., adding a single decision tree).

Parameters:
  • features (Union[np.ndarray, th.Tensor, Tuple]) – Input features.

  • grads (Union[List[NumericalData], NumericalData]) – Gradients.

  • model_idx (int, optional) – The index of the model.
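To make the idea of a gradient update step concrete, the sketch below runs a few boosting steps for two independent toy learners. It is a deliberately degenerate stand-in and not the library's algorithm: ToyConstantLearner is hypothetical, every "tree" is a single leaf, and the loss is assumed to be 0.5 * (prediction - target)^2, so the gradient with respect to the prediction is prediction - target.

```python
from typing import List


class ToyConstantLearner:
    """Hypothetical stand-in for one boosted ensemble of single-leaf trees."""

    def __init__(self, lr: float) -> None:
        self.lr = lr
        self.leaves: List[float] = []

    def predict(self, n_samples: int) -> List[float]:
        # An additive ensemble: the prediction is the sum of all leaves.
        total = sum(self.leaves)
        return [total] * n_samples

    def step(self, grads: List[float]) -> None:
        # A single leaf fitted to the negative gradients is their mean.
        leaf = -sum(grads) / len(grads)
        self.leaves.append(self.lr * leaf)


# One target array per learner, e.g. two heads trained side by side.
targets = [[1.0, 3.0], [10.0, 10.0]]
learners = [ToyConstantLearner(lr=0.5), ToyConstantLearner(lr=0.5)]

for _ in range(3):
    for learner, t in zip(learners, targets):
        preds = learner.predict(len(t))
        # Gradient of 0.5 * (p - y)^2 with respect to p.
        grads = [p - y for p, y in zip(preds, t)]
        learner.step(grads)
```

Each call to step adds one tree, so after three steps every learner holds three leaves and its prediction has moved part of the way toward the mean of its targets.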

tree_shap(tree_idx: int, features: ndarray | Tensor, model_idx: int | None = None) ndarray | Tuple[ndarray, ...][source]

Computes SHAP values for a single tree.

Implementation based on https://github.com/yupbank/linear_tree_shap. See Linear TreeShap, Yu et al., 2023, https://arxiv.org/pdf/2209.08192.

Parameters:
  • tree_idx (int) – The index of the tree.

  • features (NumericalData) – Input features.

  • model_idx (int, optional) – The index of the model.

Returns:

shap values

Return type:

Union[np.ndarray, Tuple[np.ndarray, …]]