This vignette demonstrates how to estimate uncertainty in predictions via the bootstrap. We use rsample throughout. First we show how to calculate confidence intervals, and then how to calculate prediction intervals. If you are unfamiliar with the difference between these two, we highly recommend that you read vignette("intervals", package = "safepredict").

Throughout this vignette we use the nonparametric bootstrap. It is more robust than the parametric bootstrap because it does not require us to assume that the model is correctly specified. The parametric bootstrap will typically give you tighter intervals than the procedures we outline here, but we recommend against it unless you are very, very certain that your model is correctly specified.
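To make the distinction concrete, here is a minimal sketch of a single draw from each kind of bootstrap. The model (`lm(mpg ~ wt)` on R's built-in mtcars data) is our own illustrative assumption and has nothing to do with the attrition example below. The nonparametric draw resamples whole rows; the parametric draw keeps the predictors and simulates new outcomes from the fitted model, which is only justified if that model is correct.

```r
# One bootstrap draw of each kind, sketched with lm() on mtcars (illustration only).
set.seed(1)
fit <- lm(mpg ~ wt, data = mtcars)

# Nonparametric: resample whole rows with replacement; no model is used.
np_star <- mtcars[sample(nrow(mtcars), replace = TRUE), ]

# Parametric: keep the predictors, simulate new outcomes from the fitted model.
# This assumes the linear mean and i.i.d. normal errors are correct.
p_star <- mtcars
p_star$mpg <- fitted(fit) + rnorm(nrow(mtcars), sd = sigma(fit))

# Refit on each bootstrapped data set
coef(lm(mpg ~ wt, data = np_star))
coef(lm(mpg ~ wt, data = p_star))
```

Repeating either draw many times and refitting gives a bootstrap distribution; only the nonparametric version avoids leaning on the fitted model to generate the data.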

This vignette assumes you are interested in a continuous quantity: a continuous outcome or, as in the example below, a predicted class probability.

Bootstrapped confidence intervals

Let \(X\) be the original data (containing both predictors and outcome).

  1. Resample the rows of \(X\) with replacement \(B\) times to create bootstrapped data sets \(X_1^*, ..., X_B^*\), each with the same number of rows as \(X\).
  2. Fit your model of choice on each bootstrapped data set to obtain fits \(\hat f_1, ..., \hat f_B\).
  3. Predict the mean at \(X\) with each \(\hat f_i\) to get \(B\) samples from the (approximate) sampling distribution of \(\hat f(X)\).
  4. At each point, take the appropriate quantiles of these \(B\) predictions; for example, the 5th and 95th percentiles give a 90 percent confidence interval. You’re done!
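The four steps above can be sketched in a few lines of base R. This is a minimal illustration under our own assumptions (a linear model `mpg ~ wt + hp` fit to the built-in mtcars data, with B = 1000), not the attrition example that follows. It computes a percentile interval for the predicted mean at every row of the original data.

```r
# Percentile bootstrap confidence interval for the predicted mean (sketch).
set.seed(1)
B <- 1000
X <- mtcars  # original data: predictors and outcome (mpg)

# Steps 1-3: resample rows, refit, predict the mean at the original X.
# replicate() stacks the B prediction vectors as columns of a matrix.
boot_preds <- replicate(B, {
  X_star <- X[sample(nrow(X), replace = TRUE), ]  # bootstrapped data set
  f_hat <- lm(mpg ~ wt + hp, data = X_star)       # refit the model
  predict(f_hat, newdata = X)                     # predicted mean at each row of X
})

# Step 4: the 5% and 95% quantiles of each row give a 90% interval per observation
ci <- t(apply(boot_preds, 1, quantile, probs = c(0.05, 0.95)))
head(ci)
```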

Let’s work through an example, using glmnet for a binary classification problem. Our goal will be to predict Attrition from 30 predictor variables.

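library(dplyr)
library(rsample)  # older rsample releases ship the attrition data; newer ones moved it to modeldata

set.seed(27)

# keep a random subsample of 500 employees
attrition <- attrition %>%
  sample_n(500)

glimpse(attrition)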
#> Observations: 500
#> Variables: 31
#> $ Age                      <int> 32, 51, 33, 31, 31, 50, 50, 33, 27, 5...
#> $ Attrition                <fct> No, No, No, Yes, No, No, No, No, No, ...
#> $ BusinessTravel           <fct> Travel_Rarely, Travel_Rarely, Travel_...
#> $ DailyRate                <int> 234, 684, 867, 1365, 798, 691, 1115, ...
#> $ Department               <fct> Sales, Research_Development, Research...
#> $ DistanceFromHome         <int> 1, 6, 8, 13, 7, 2, 1, 3, 1, 3, 4, 24,...
#> $ Education                <ord> Master, Bachelor, Master, Master, Col...
#> $ EducationField           <fct> Medical, Life_Sciences, Life_Sciences...
#> $ EnvironmentSatisfaction  <ord> Medium, Low, Very_High, Medium, High,...
#> $ Gender                   <fct> Male, Male, Male, Male, Female, Male,...
#> $ HourlyRate               <int> 68, 51, 90, 46, 48, 64, 73, 56, 60, 4...
#> $ JobInvolvement           <ord> Medium, High, Very_High, High, Medium...
#> $ JobLevel                 <int> 1, 5, 1, 2, 3, 4, 5, 1, 2, 4, 5, 2, 2...
#> $ JobRole                  <fct> Sales_Representative, Research_Direct...
#> $ JobSatisfaction          <ord> Medium, High, Low, Low, High, High, M...
#> $ MaritalStatus            <fct> Married, Single, Married, Divorced, M...
#> $ MonthlyIncome            <int> 2269, 19537, 3143, 4233, 8943, 17639,...
#> $ MonthlyRate              <int> 18024, 6462, 6076, 11512, 14034, 6881...
#> $ NumCompaniesWorked       <int> 0, 7, 6, 2, 1, 5, 3, 1, 5, 3, 2, 2, 1...
#> $ OverTime                 <fct> No, No, No, No, No, No, Yes, Yes, No,...
#> $ PercentSalaryHike        <int> 14, 13, 19, 17, 24, 16, 19, 11, 19, 1...
#> $ PerformanceRating        <ord> Excellent, Excellent, Excellent, Exce...
#> $ RelationshipSatisfaction <ord> Medium, High, Medium, High, Low, Very...
#> $ StockOptionLevel         <int> 1, 0, 1, 0, 1, 0, 0, 0, 1, 3, 1, 0, 0...
#> $ TotalWorkingYears        <int> 3, 23, 14, 9, 10, 30, 28, 8, 6, 21, 2...
#> $ TrainingTimesLastYear    <int> 2, 5, 1, 2, 2, 3, 1, 3, 1, 5, 2, 3, 2...
#> $ WorkLifeBalance          <ord> Better, Better, Better, Bad, Better, ...
#> $ YearsAtCompany           <int> 2, 20, 10, 3, 10, 4, 8, 8, 2, 5, 1, 5...
#> $ YearsInCurrentRole       <int> 2, 18, 8, 1, 9, 3, 3, 7, 2, 3, 0, 4, ...
#> $ YearsSinceLastPromotion  <int> 2, 15, 7, 1, 8, 0, 0, 3, 2, 1, 0, 1, ...
#> $ YearsWithCurrManager     <int> 2, 15, 6, 2, 9, 3, 7, 0, 0, 3, 0, 4, ...

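Since we’re using glmnet, we have to start with a bunch of preprocessing: glmnet expects a numeric predictor matrix, so factors must be converted to dummy variables, and the predictors should be on a common scale. The recipes package makes this manageable.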
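A recipe along these lines handles the preprocessing. This is a sketch: the particular steps are our choice, and we assume the attrition data is loaded from the modeldata package.

```r
library(dplyr)
library(recipes)

# Assumption: attrition is taken from modeldata
# (older rsample releases shipped the same data set)
attrition <- modeldata::attrition
set.seed(27)
attrition <- attrition[sample(nrow(attrition), 500), ]

rec <- recipe(Attrition ~ ., data = attrition) %>%
  # glmnet needs numeric inputs: expand every factor predictor into dummy columns
  step_dummy(all_nominal(), -all_outcomes()) %>%
  # drop any column that ended up constant in this subsample
  step_zv(all_predictors()) %>%
  # center and scale so the L1 penalty treats predictors on a common scale
  step_normalize(all_predictors()) %>%
  prep(training = attrition)

processed <- bake(rec, new_data = attrition)

# glmnet takes a numeric predictor matrix and a separate outcome vector
x <- as.matrix(select(processed, -Attrition))
y <- processed$Attrition
dim(x)
```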

We can now fit an L1-penalized (lasso) logistic regression model, and use safe_predict() to calculate the predicted probability of attrition for each employee.
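A minimal sketch of the fit follows. It makes a few assumptions: it uses glmnet's own cv.glmnet() and predict() in place of safe_predict(), builds the predictor matrix with model.matrix() rather than a recipe (glmnet standardizes predictors internally by default), and loads the data from modeldata.

```r
library(glmnet)

# Assumption: attrition is loaded from modeldata
attrition <- modeldata::attrition
set.seed(27)
attrition <- attrition[sample(nrow(attrition), 500), ]

# model.matrix() dummy-encodes the factors; [, -1] drops the intercept column
# because glmnet fits its own intercept
x <- model.matrix(Attrition ~ ., data = attrition)[, -1]
y <- attrition$Attrition  # factor; glmnet models P(y == "Yes"), the second level

# alpha = 1 is the lasso (L1) penalty; cv.glmnet() picks lambda by 10-fold CV
cv_fit <- cv.glmnet(x, y, family = "binomial", alpha = 1)

# predicted probability of attrition at the CV-selected lambda
p_hat <- predict(cv_fit, newx = x, s = "lambda.min", type = "response")
summary(as.vector(p_hat))
```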

Let’s take a quick look to sanity-check our work. We’ll plot the estimated probability of attrition versus the scaled distance from home.
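One way to draw this plot is sketched below. It refits the model with cv.glmnet() so that the block runs on its own, and the names `plot_data`, `distance`, and `prob` are hypothetical. At minimum, every estimated probability should lie in [0, 1], and the predictions should vary across employees.

```r
library(glmnet)
library(ggplot2)

attrition <- modeldata::attrition  # assumption: data from modeldata
set.seed(27)
attrition <- attrition[sample(nrow(attrition), 500), ]

x <- model.matrix(Attrition ~ ., data = attrition)[, -1]
cv_fit <- cv.glmnet(x, attrition$Attrition, family = "binomial", alpha = 1)

plot_data <- data.frame(
  # centered and scaled distance from home
  distance = as.vector(scale(attrition$DistanceFromHome)),
  # estimated P(Attrition = "Yes") at the CV-selected lambda
  prob = as.vector(predict(cv_fit, newx = x, s = "lambda.min", type = "response"))
)

p <- ggplot(plot_data, aes(distance, prob)) +
  geom_point(alpha = 0.5) +
  labs(x = "Distance from home (scaled)", y = "Estimated probability of attrition")
print(p)
```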

This passes the sanity check, so we proceed to the bootstrapping.
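With rsample, the resampling step is a call to bootstraps(), and analysis() extracts each bootstrapped data set. The sketch below fits a full lasso path with glmnet() on each resample. The data source, B = 100, and the helper name `fit_one` are our assumptions, not the original code.

```r
library(rsample)
library(glmnet)

attrition <- modeldata::attrition  # assumption: data from modeldata
set.seed(27)
attrition <- attrition[sample(nrow(attrition), 500), ]

# B = 100 bootstrap resamples; each analysis set has 500 rows drawn with replacement
boots <- bootstraps(attrition, times = 100)

# hypothetical helper: fit an L1-penalized logistic regression path on one resample
fit_one <- function(split) {
  d <- analysis(split)
  glmnet(model.matrix(Attrition ~ ., data = d)[, -1], d$Attrition,
         family = "binomial", alpha = 1)
}

fits <- lapply(boots$splits, fit_one)
length(fits)
```

Because the factor levels are carried along in every resample, model.matrix() produces the same columns for each bootstrapped data set, so all fits can later predict on the original predictor matrix.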

Next we get predictions for each bootstrapped fit, and all that remains to calculate a 90 percent confidence interval is to look at the 5th and 95th percentiles of the bootstrapped predictions for each observation.
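Here is a self-contained sketch of these last two steps. It refits on each bootstrap resample, predicts the probability of attrition for every original row, and takes row-wise percentiles. As before, the data source and B = 100 are our assumptions. We also hold lambda fixed at the value cv.glmnet() selects on the full data, which means these intervals do not reflect uncertainty in tuning lambda.

```r
library(rsample)
library(glmnet)

attrition <- modeldata::attrition  # assumption: data from modeldata
set.seed(27)
attrition <- attrition[sample(nrow(attrition), 500), ]

x <- model.matrix(Attrition ~ ., data = attrition)[, -1]
y <- attrition$Attrition

# choose lambda once on the full data and hold it fixed across resamples
lambda <- cv.glmnet(x, y, family = "binomial", alpha = 1)$lambda.min

boots <- bootstraps(attrition, times = 100)

# column b holds the predicted P(Attrition = "Yes") for every original row from fit b
boot_preds <- sapply(boots$splits, function(split) {
  d <- analysis(split)
  fit <- glmnet(model.matrix(Attrition ~ ., data = d)[, -1], d$Attrition,
                family = "binomial", alpha = 1)
  as.vector(predict(fit, newx = x, s = lambda, type = "response"))
})

# 90% percentile interval: 5th and 95th percentiles of each row's B predictions
ci <- t(apply(boot_preds, 1, quantile, probs = c(0.05, 0.95)))
colnames(ci) <- c("lower", "upper")
head(ci)
```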