Obtain insights from a black-box model in the form of feature effects.

insights(
  mfit,
  vars,
  data,
  interactions = "user",
  hcut = 0.75,
  pred_fun = NULL,
  fx_in = NULL,
  ncores = -1
)

Arguments

mfit

Fitted model object (e.g., a "gbm" or "randomForest" object).

vars

Character vector specifying the features to get insights on.

data

Data frame containing the original training data.

interactions

String specifying how to deal with interaction effects:

'user'

Specify the interactions directly in vars, in the format "var1_var2".

'auto'

Select the interactions automatically, based on hcut.
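For interactions = 'user', the interaction terms are passed as extra entries in vars. A minimal base R sketch of building such names, using the feature names from the mtpl_be example as illustrative input (main_effects and all_pairs are hypothetical variable names, not part of the function's interface):

```r
# Illustrative main effects (taken from the mtpl_be example)
main_effects <- c("ageph", "bm", "coverage", "fuel")

# All two-way interaction names in the "var1_var2" format
all_pairs <- combn(main_effects, 2, FUN = paste, collapse = "_")
print(all_pairs)

# Main effects plus one hand-picked interaction, as in the second example call
vars <- c(main_effects, "bm_fuel")
print(vars)
```

Passing all of all_pairs would request every two-way interaction; in practice you would usually keep only the pairs you expect to matter.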

hcut

Numeric value in the range [0, 1] specifying the cut-off for the normalized cumulative H-statistic, computed over all two-way interactions between the features in vars and ordered from most to least important. Note that hcut = 0 adds only the single most important interaction, while hcut = 1 adds all possible two-way interactions.
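To make the cut-off concrete, the following sketch applies one plausible reading of this rule to made-up H-statistics: sort the interactions from most to least important, normalize the cumulative sum so that it ends at 1, and keep the interactions up to and including the first one whose cumulative share reaches hcut. The helper select_interactions() and the H-values are hypothetical; the internal implementation of insights() may differ, for example in how ties or boundary values are handled.

```r
# Hypothetical H-statistics for the two-way interactions between the features in vars
h_stats <- c(ageph_bm = 0.30, bm_fuel = 0.12, ageph_fuel = 0.05, coverage_fuel = 0.03)

# Hypothetical helper illustrating the hcut rule (not the package's own code)
select_interactions <- function(h, hcut) {
  stopifnot(hcut >= 0, hcut <= 1)
  h <- sort(h, decreasing = TRUE)        # order from most to least important
  cum_share <- cumsum(h) / sum(h)        # normalized cumulative H-statistic
  k <- if (hcut >= 1) length(h) else which(cum_share >= hcut)[1]
  names(h)[seq_len(k)]                   # keep the k most important interactions
}

select_interactions(h_stats, hcut = 0)     # "ageph_bm"
select_interactions(h_stats, hcut = 0.75)  # "ageph_bm" "bm_fuel"
select_interactions(h_stats, hcut = 1)     # all four interactions
```

Here the cumulative shares are 0.60, 0.84, 0.94 and 1, so hcut = 0.75 stops at the second interaction.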

pred_fun

Optional prediction function used to calculate the feature effects for the model in mfit. The function must take two arguments, object and newdata. See pdp::partial and this article for details, and the function gbm_fun in the examples.
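As an illustration of the required signature, the following base R sketch defines a prediction function for a linear model fitted on the built-in mtcars data. Like gbm_fun in the examples, it returns the average prediction over newdata, i.e. a single partial dependence value. The names lm_fun, fit and newdata_wt3 are hypothetical, and the linear model merely stands in for a black-box model.

```r
# Hypothetical prediction function with the required (object, newdata) signature
lm_fun <- function(object, newdata) {
  mean(predict(object, newdata = newdata))  # average prediction = one PD value
}

fit <- lm(mpg ~ wt + hp, data = mtcars)

# Fix wt at a single grid value for every row, keep the other features as observed
newdata_wt3 <- transform(mtcars, wt = 3)

pd_wt3 <- lm_fun(fit, newdata_wt3)
print(pd_wt3)
```

Repeating this for each value of a grid over wt traces out the partial dependence curve; pdp::partial automates that loop when such a function is supplied through its pred.fun argument.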

fx_in

Optional named list of data frames containing feature effects for features in vars that have already been calculated, so that they do not need to be calculated again. A possible use case is to supply the main effects, such that only the interaction effects still need to be calculated. Precalculated interactions are ignored when interactions = "auto", but can be supplied when interactions = "user". In that case, make sure that you supply the pure interaction effects (i.e., with the main effects removed).

ncores

Integer specifying the number of cores to use. The default ncores = -1 uses all available physical cores (not threads), as determined by parallel::detectCores(logical = FALSE).
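A minimal sketch of how the default ncores = -1 can be resolved to a core count, using a hypothetical helper resolve_ncores(). Per the parallel package documentation, logical = FALSE is honoured only on some platforms (such as Windows and macOS); elsewhere detectCores() may report logical CPUs instead, and it can return NA when the count cannot be determined. The single-core fallback is an assumption of this sketch.

```r
# Hypothetical helper: map ncores = -1 to the number of physical cores
resolve_ncores <- function(ncores = -1) {
  if (ncores == -1) {
    n <- parallel::detectCores(logical = FALSE)
    if (is.na(n)) n <- 1L  # assumed fallback: one core if detection fails
    n
  } else {
    ncores                 # an explicit value is returned unchanged
  }
}

resolve_ncores()   # number of (physical) cores detected on this machine
resolve_ncores(2)  # 2
```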

Value

List of tidy data frames (i.e., "tibble" objects) containing the partial dependence effects for the features (and interactions) in vars.

Examples

if (FALSE) {
  # Load the example MTPL data and select the candidate features
  data('mtpl_be')
  features <- setdiff(names(mtpl_be), c('id', 'nclaims', 'expo', 'long', 'lat'))

  # Fit a Poisson GBM for the number of claims
  set.seed(12345)
  gbm_fit <- gbm::gbm(as.formula(paste('nclaims ~', paste(features, collapse = ' + '))),
                      distribution = 'poisson',
                      data = mtpl_be,
                      n.trees = 50,
                      interaction.depth = 3,
                      shrinkage = 0.1)

  # Prediction function returning the average predicted response
  gbm_fun <- function(object, newdata) {
    mean(predict(object, newdata, n.trees = object$n.trees, type = 'response'))
  }

  # Select the interactions automatically, based on hcut
  gbm_fit %>% insights(vars = c('ageph', 'bm', 'coverage', 'fuel'),
                       data = mtpl_be,
                       interactions = 'auto',
                       hcut = 0.75,
                       pred_fun = gbm_fun)

  # Specify the bm_fuel interaction manually
  gbm_fit %>% insights(vars = c('ageph', 'bm', 'coverage', 'fuel', 'bm_fuel'),
                       data = mtpl_be,
                       interactions = 'user',
                       pred_fun = gbm_fun)
}