Compute partial dependence functions (i.e., marginal effects) for the predictors in a model. Note that get_pd is based on pdp::partial and adds observation weights.
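The brute-force computation behind a partial dependence function can be sketched in base R. For each grid value, the feature is overwritten in every row, the model predicts, and the predictions are averaged. This minimal sketch uses a Poisson glm on simulated data as a stand-in model. partial_dep is a hypothetical helper, not part of get_pd or pdp, and it leaves out the observation weights that get_pd adds.

```r
# Simulated data and a Poisson glm as a stand-in for mfit (illustration only)
set.seed(1)
toy <- data.frame(x1 = runif(200), x2 = runif(200))
toy$y <- rpois(200, lambda = exp(0.5 + 1.2 * toy$x1 - 0.8 * toy$x2))
fit <- glm(y ~ x1 + x2, family = poisson(), data = toy)

# Hypothetical helper: partial dependence of the mean prediction on one feature
partial_dep <- function(model, data, feature, grid_values) {
  vapply(grid_values, function(v) {
    data[[feature]] <- v  # set the feature to v for every observation (local copy)
    mean(predict(model, newdata = data, type = 'response'))  # average over the data
  }, numeric(1))
}

grid_x1 <- seq(0, 1, by = 0.25)
pd_x1 <- partial_dep(fit, toy, 'x1', grid_x1)
print(data.frame(x = grid_x1, y = pd_x1))
```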

get_pd(mfit, var, grid, data, subsample = nrow(data), fun = NULL, ...)

Arguments

mfit

Fitted model object (e.g., a "gbm" or "randomForest" object).

var

Character string giving the name of the predictor variable of interest. For an interaction effect, specify it as "var1_var2". Only main effects and two-way interactions are currently supported, so specify at most two variable names (i.e., the string may contain at most one underscore).
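The naming convention can be illustrated with a small base R helper that splits var on the underscore and rejects higher-order interactions. split_var is hypothetical and only mirrors the rule stated above. It also shows why a feature name that itself contains an underscore would be split incorrectly.

```r
# Hypothetical helper mirroring the "var1_var2" convention (not part of get_pd)
split_var <- function(var) {
  vars <- strsplit(var, split = '_', fixed = TRUE)[[1]]
  if (length(vars) > 2) {
    stop('Only main effects and two-way interactions are supported.')
  }
  vars
}

split_var('ageph')           # main effect: "ageph"
split_var('power_coverage')  # interaction: "power" "coverage"
```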

grid

Data frame containing the (joint) values of interest for the feature(s) listed in var: one column for a main effect and two columns for an interaction effect, with the feature names as column names. See the documentation and examples of get_grid for details.
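A grid with the expected shape can also be built with base R's expand.grid (the examples use get_grid and tidyr::expand_grid instead). The feature values below are made up for illustration; in practice they would come from get_grid.

```r
# Main effect: one column, named after the feature
grid_main <- data.frame(ageph = 18:95)

# Two-way interaction "power_coverage": two columns with all joint values
# (illustrative values only)
grid_int <- expand.grid(power = c(40, 60, 80, 100),
                        coverage = c('A', 'B', 'C'),
                        stringsAsFactors = FALSE)
nrow(grid_int)  # 4 * 3 = 12 combinations
```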

data

Data frame containing the original training data.

subsample

Optional integer specifying the number of observations to use for the computation of the partial dependencies. Defaults to the number of observations in data, but a smaller value saves computation time.
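Subsampling presumably amounts to drawing rows of data before the predictions are made. The sketch below assumes sampling without replacement, which the description above does not state. Since every grid value requires one prediction per retained row, the prediction cost scales with nrow(grid) * subsample.

```r
set.seed(2)
# Simulated training data with 50,000 rows (illustration only)
train <- data.frame(x = rnorm(50000), z = rnorm(50000))
subsample <- 10000

# Draw `subsample` distinct row indices and keep those rows
idx <- sample.int(nrow(train), size = min(subsample, nrow(train)))
train_sub <- train[idx, , drop = FALSE]

# Number of predictions for a 78-value grid, before and after subsampling
78 * nrow(train)      # 3,900,000
78 * nrow(train_sub)  # 780,000
```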

...

Additional optional arguments to be passed on to pdp::partial.

fun

Optional prediction function to calculate feature effects for the model in mfit. It requires two arguments: object and newdata. See pdp::partial and this article for details, and see the function gbm_fun in the examples.
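A prediction function with the required object and newdata arguments, written for a base R glm in the same style as gbm_fun in the examples, might look like the sketch below. Returning the mean prediction yields one value per grid point, i.e. one point of the partial dependence curve. glm_fun and the simulated data are illustrative assumptions.

```r
# Hypothetical prediction function for a glm: average predicted response
glm_fun <- function(object, newdata) {
  mean(predict(object, newdata = newdata, type = 'response'))
}

set.seed(3)
df <- data.frame(x = runif(100))
df$y <- rpois(100, lambda = exp(1 + df$x))
fit <- glm(y ~ x, family = poisson(), data = df)

# For a Poisson glm with an intercept and the canonical log link, the average
# fitted value equals mean(df$y), which gives a quick sanity check
glm_fun(fit, df)
```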

Value

Tidy data frame (i.e., a "tibble" object) with three columns (x, y, w) for a main effect or four columns (x1, x2, y, w) for a two-way interaction effect. Column(s) x contain the feature values, column y the partial dependence effect and column w the observation counts in data for each (joint) feature value. The comment attribute of the data frame contains the variable name, as specified in the var argument.
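The layout of a main-effect result can be mimicked in base R: counting the observations in data at each feature value gives w, and comment() stores the variable name. The y values below are placeholders rather than model output, and the sketch returns a plain data frame instead of a tibble. How w is obtained here is an assumption, not the get_pd code.

```r
# Toy training data with repeated feature values
train <- data.frame(ageph = c(20, 20, 35, 35, 35, 50))

# Placeholder partial dependence result for the main effect 'ageph'
pd <- data.frame(x = c(20, 35, 50), y = c(0.15, 0.11, 0.09))

# w: number of observations in the data at each grid value
counts <- table(train$ageph)
pd$w <- as.integer(counts[as.character(pd$x)])

# Store the variable name in the comment attribute
comment(pd) <- 'ageph'
pd
comment(pd)  # "ageph"
```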


Examples

if (FALSE) {
  library(magrittr)  # provides the %>% pipe
  data('mtpl_be')
  features <- setdiff(names(mtpl_be), c('id', 'nclaims', 'expo', 'long', 'lat'))
  set.seed(12345)
  gbm_fit <- gbm::gbm(as.formula(paste('nclaims ~', paste(features, collapse = ' + '))),
                      distribution = 'poisson',
                      data = mtpl_be,
                      n.trees = 50,
                      interaction.depth = 3,
                      shrinkage = 0.1)
  # Prediction function: average predicted response over the data
  gbm_fun <- function(object, newdata) {
    mean(predict(object, newdata, n.trees = object$n.trees, type = 'response'))
  }
  # Main effect of ageph
  gbm_fit %>% get_pd(var = 'ageph',
                     grid = 'ageph' %>% get_grid(data = mtpl_be),
                     data = mtpl_be,
                     subsample = 10000,
                     fun = gbm_fun)
  # Two-way interaction between power and coverage
  gbm_fit %>% get_pd(var = 'power_coverage',
                     grid = tidyr::expand_grid('power' %>% get_grid(data = mtpl_be),
                                               'coverage' %>% get_grid(data = mtpl_be)),
                     data = mtpl_be,
                     subsample = 10000,
                     fun = gbm_fun)
}