library(mlr3extralearners)
library(mlr3pipelines)
library(mlr3proba)
library(distr6)
library(BART) # 2.9.4
library(dplyr)
library(tidyr)
library(tibble)
library(ggplot2)
Intro
Here are some interesting reads regarding BART:
- The first BART paper (Chipman, George, and McCulloch 2010).
- The first implementation of BART for survival data (Bonato et al. 2011). This includes fully parametric AFT and Weibull models and the semi-parametric CoxPH regression model.
- The first non-parametric implementation of BART for survival data (R. A. Sparapani et al. 2016).
- The BART R package tutorial (R. Sparapani, Spanbauer, and McCulloch 2021).
We incorporated the survival BART model in mlr3extralearners. In this tutorial we demonstrate how packages like mlr3, mlr3proba and distr6 make it easier to manipulate the output predictions in order to assess model convergence, validate the model (via several survival metrics) and interpret the model via PDPs (Partial Dependence Plots).
Libraries
Data
We will use the Lung Cancer Dataset. We convert the time variable from days to months to ease the computational burden:
task_lung = tsk('lung')

d = task_lung$data()
# in case we want to select specific columns to keep
# d = d[ ,colnames(d) %in% c("time", "status", "age", "sex", "ph.karno"), with = FALSE]
d$time = ceiling(d$time/30.44)
task_lung = as_task_surv(d, time = 'time', event = 'status', id = 'lung')
task_lung$label = "Lung Cancer"
- The original BART implementation supports categorical features (factors). This results in different importance scores for each dummy level, which doesn’t work well with mlr3. So features of type factor or character are not allowed and we leave it to the user to encode them as they please.
- The original BART implementation supports features with missing values. This is totally fine with mlr3 as well! In this example, we impute the features to show good ML practice.
In our lung dataset, we encode the sex feature and perform model-based imputation with the rpart regression learner:
po_encode = po('encode', method = 'treatment')
po_impute = po('imputelearner', lrn('regr.rpart'))
pre = po_encode %>>% po_impute
task = pre$train(task_lung)[[1]]
task
<TaskSurv:lung> (228 x 10): Lung Cancer
* Target: time, status
* Properties: -
* Features (8):
- int (7): age, inst, meal.cal, pat.karno, ph.ecog, ph.karno, wt.loss
- dbl (1): sex
No missing values in our data:
task$missings()
time status age sex inst meal.cal pat.karno ph.ecog ph.karno wt.loss
0 0 0 0 0 0 0 0 0 0
We partition the data to train and test sets:
set.seed(42)
part = partition(task, ratio = 0.9)
Train and Test
We train the BART model and predict on the test set:
# default `ndpost` value: 1000. We reduce it to 50 to speed up calculations in this tutorial
learner = lrn("surv.bart", nskip = 250, ndpost = 50, keepevery = 10, mc.cores = 10)
learner$train(task, row_ids = part$train)
p = learner$predict(task, row_ids = part$test)
p
<PredictionSurv> for 23 observations:
row_ids time status crank distr
9 8 TRUE 66.19326 <list[1]>
10 6 TRUE 98.43005 <list[1]>
21 10 TRUE 54.82313 <list[1]>
---
160 13 FALSE 37.82089 <list[1]>
163 10 FALSE 69.63534 <list[1]>
194 8 FALSE 81.13678 <list[1]>
See more details about BART’s parameters on the online documentation.
distr
What kind of object is the predicted distr?
p$distr
Arrdist(23x31x50)
The three dimensions of this array are, in order:
- Patients (observations)
- Time points (months)
- Number of posterior draws
Actually, $distr is an active R6 field, which means that some computation is required to create it. What the prediction object actually stores internally is a 3d survival array (it can be used directly with no performance overhead):
dim(p$data$distr)
[1] 23 31 50
This form is easier to understand and manipulate than the full posterior survival matrix prediction from the BART package (R. Sparapani, Spanbauer, and McCulloch 2021, pages 34-35).
Though we have optimized with C++ code the way the Arrdist object is constructed, calling the $distr field can be computationally taxing if the product of the sizes of the 3 dimensions above exceeds ~1 million. In our case, \(23 \times 31 \times 50 = 35650\), so the conversion to an Arrdist via $distr will not create performance issues.
An example using the internal prediction data: get all the posterior probabilities of the 3rd patient in the test set, at 12 months (1 year):
p$data$distr[3, 12, ]
[1] 0.26546909 0.27505937 0.21151435 0.46700513 0.26178380 0.24040003 0.29946469 0.52357780 0.40833108 0.40367780
[11] 0.27027392 0.31781286 0.54151844 0.34460027 0.41826554 0.41866367 0.33694401 0.34511270 0.47244492 0.49423660
[21] 0.42069678 0.20095489 0.48696980 0.48409357 0.35649439 0.47969355 0.16355660 0.33728105 0.40245228 0.42418033
[31] 0.36336145 0.48181667 0.51858238 0.49635078 0.37238179 0.26694030 0.52219952 0.48992897 0.08572207 0.30306005
[41] 0.33881682 0.33463870 0.29102074 0.43176131 0.38554545 0.38053756 0.36808776 0.13772665 0.21898264 0.14552514
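The indexing used above can be tried out on a toy array. The following is a minimal base R sketch, not the mlr3proba implementation: the array `surv`, its sizes and the simulated hazards are all hypothetical, and only the dimension order (patients, time points, posterior draws) is taken from the text.

```r
# Toy 3d survival array: 2 patients x 4 time points x 5 posterior draws
# (all values are simulated, purely for illustration)
set.seed(1)
n_obs = 2; n_times = 4; n_draws = 5

# simulate per-interval hazards, then turn them into decreasing survival curves
haz = array(runif(n_obs * n_times * n_draws, 0.05, 0.2),
            dim = c(n_obs, n_times, n_draws))
surv = apply(haz, c(1, 3), function(h) cumprod(1 - h)) # times x patients x draws
surv = aperm(surv, c(2, 1, 3))                          # patients x times x draws

# all posterior draws for patient 1 at the 3rd time point
draws = surv[1, 3, ]

# collapse the draws dimension: pointwise median survival matrix (patients x times)
median_surv = apply(surv, c(1, 2), median)
dim(median_surv)
```

Each slice `surv[, , k]` is an ordinary survival matrix for posterior draw `k`, which is why the array can be summarised across its third dimension.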
Working with the $distr interface and Arrdist objects is very efficient, as we will see later when predicting survival estimates.
In survival analysis, \(S(t) = 1 - F(t)\), where \(S(t)\) is the survival function and \(F(t)\) is the cumulative distribution function (cdf). The latter can be interpreted as the risk or probability of death up to time \(t\).
We can verify the above from the prediction object:
surv_array = 1 - distr6::gprm(p$distr, "cdf") # 3d array
testthat::expect_equal(p$data$distr, surv_array)
crank
crank is the expected mortality (Sonabend, Bender, and Vollmer 2022), which is the sum of the predicted cumulative hazard function (as is done in random survival forest models). Higher values denote larger risk. To calculate crank, we need a survival matrix, so we have to choose which slice along the 3rd dimension of the predicted survival array to use. This is what the which.curve parameter of the learner does:
learner$param_set$get_values()$which.curve
[1] 0.5
The default value (\(0.5\) quantile) is the median survival probability. It could be any other quantile (e.g. \(0.25\)). Other possible values for which.curve are mean or a number denoting the exact posterior draw to extract (e.g. the last one, which.curve = 50).
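To make the expected mortality computation concrete, here is a minimal base R sketch. It assumes the cumulative hazard is obtained as \(H(t) = -\log S(t)\) and summed over the time points. The survival matrix `surv_mat` is made up for illustration and this is not the code mlr3proba runs internally.

```r
# Hypothetical survival matrix: rows = patients, columns = time points
surv_mat = rbind(
  c(0.90, 0.70, 0.40), # patient with rapidly decreasing survival
  c(0.99, 0.95, 0.90)  # patient with slowly decreasing survival
)

# cumulative hazard at each time point: H(t) = -log(S(t))
cumhaz = -log(surv_mat)

# expected mortality: sum of the cumulative hazard across time points
crank = rowSums(cumhaz)
crank
```

The first patient gets the larger score, which matches the convention that higher crank values denote larger risk.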
Feature importance
Default score is the observed count of each feature in the trees (so the higher the score, the more important the feature):
learner$param_set$values$importance
[1] "count"
learner$importance()
sex meal.cal inst pat.karno ph.karno wt.loss age ph.ecog
7.84 7.46 7.08 6.76 6.60 6.46 5.48 5.42
MCMC Diagnostics
BART internally uses MCMC (Markov Chain Monte Carlo) to sample from the posterior survival distribution. We need to check that MCMC has converged, meaning that the chains have reached a stationary distribution that approximates the true posterior survival distribution (otherwise the predictions may be inaccurate, misleading and unreliable).
We use Geweke’s convergence diagnostic test as it is implemented in the BART R package. We choose 10 random patients from the train set to evaluate the MCMC convergence.
# predictions on the train set
p_train = learner$predict(task, row_ids = part$train)

# choose 10 patients from the train set randomly and make a list
ids = as.list(sample(length(part$train), 10))

z_list = lapply(ids, function(id) {
  # matrix with columns => time points and rows => posterior draws
  post_surv = 1 - t(distr6::gprm(p_train$distr[id], "cdf")[1,,])
  BART::gewekediag(post_surv)$z # get the z-scores
})

# plot the z scores vs time for all patients
dplyr::bind_rows(z_list) %>%
  tidyr::pivot_longer(cols = everything()) %>%
  mutate(name = as.numeric(name)) %>%
  ggplot(aes(x = name, y = value)) +
  geom_point() +
  labs(x = "Time (months)", y = "Z-scores") +
  # add critical values for a = 0.05
  geom_hline(yintercept = 1.96, linetype = 'dashed', color = "red") +
  geom_hline(yintercept = -1.96, linetype = 'dashed', color = "red") +
  theme_bw(base_size = 14)
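The idea behind the z-scores plotted above can be illustrated with a simplified version of Geweke's diagnostic: compare the mean of an early segment of a chain with the mean of a late segment. The helper `geweke_z` below is hypothetical and deliberately naive. It uses plain sample variances, which assumes independent draws, whereas implementations such as the one in the coda package use spectral density estimates to account for autocorrelation.

```r
# Simplified Geweke-style z-score (naive variances, illustration only)
geweke_z = function(x, frac1 = 0.1, frac2 = 0.5) {
  n = length(x)
  a = x[seq_len(floor(frac1 * n))]       # first 10% of the chain
  b = x[(n - floor(frac2 * n) + 1):n]    # last 50% of the chain
  (mean(a) - mean(b)) / sqrt(var(a) / length(a) + var(b) / length(b))
}

set.seed(42)
chain_ok = rnorm(1000)                                    # stationary chain
chain_drift = rnorm(1000) + seq(0, 3, length.out = 1000)  # chain that keeps drifting

z_ok = geweke_z(chain_ok)
z_drift = geweke_z(chain_drift)
c(stationary = z_ok, drifting = z_drift)
```

The drifting chain yields a z-score far outside the \(\pm 1.96\) band, which is the kind of pattern the plot is meant to reveal.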
Performance (test set)
We will use the following survival metrics:
- Integrated Brier Score (requires a survival distribution prediction - distr)
- Uno’s C-index (requires a continuous ranking score prediction - crank)
For the first measure we will use the ERV (Explained Residual Variation) version, which standardizes the score against a Kaplan-Meier (KM) baseline (Sonabend et al. 2022). This means that values close to \(0\) represent performance similar to a KM model, negative values denote worse performance than KM and \(1\) is the absolute best possible score.
measures = list(
  msr("surv.graf", ERV = TRUE),
  msr("surv.cindex", weight_meth = "G2", id = "surv.cindex.uno")
)

for (measure in measures) {
  print(p$score(measure, task = task, train_set = part$train))
}
surv.graf
-0.09950096
surv.cindex.uno
0.551951
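The interpretation of the ERV score can be checked with a small sketch. It assumes the standardization takes the form \(1 - L_{model} / L_{KM}\), where \(L\) is the loss (here the integrated Brier score). The helper `erv` and the loss values are hypothetical and are not taken from the results above.

```r
# Explained residual variation relative to a baseline loss (assumed form)
erv = function(loss_model, loss_baseline) 1 - loss_model / loss_baseline

better = erv(0.10, 0.20) # model halves the baseline loss
equal  = erv(0.20, 0.20) # model performs like the baseline
worse  = erv(0.22, 0.20) # model is worse than the baseline
c(better = better, equal = equal, worse = worse)
```

Under this form, a slightly negative score such as the one reported above means the model's loss is a little higher than that of the Kaplan-Meier baseline.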
By default, all metrics use the median survival distribution from the 3d array, regardless of the which.curve value set when the learner was constructed.
Resampling
Performing resampling with the BART learner is very easy using mlr3.
We first stratify the data by status, so that in each resampling iteration the proportion of censored vs uncensored patients remains the same:
task$col_roles$stratum = 'status'
task$strata
N row_id
1: 165 1,2,4,5,7,8,...
2: 63 3, 6,38,68,71,83,...
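What stratification achieves can be sketched in base R by assigning folds within each stratum separately. This is an illustration under assumptions, not how mlr3 implements it: the vector `status` is rebuilt from the two stratum sizes shown above (165 and 63), and which value labels which stratum is arbitrary here.

```r
set.seed(1)
# two strata with the sizes shown above
status = c(rep(1, 165), rep(0, 63))

# assign each observation to one of 5 folds, separately within each stratum
folds = integer(length(status))
for (s in unique(status)) {
  idx = which(status == s)
  folds[idx] = sample(rep_len(1:5, length(idx)))
}

# proportion of the first stratum in each fold
prop = tapply(status, folds, mean)
round(prop, 3)
```

Every fold ends up with nearly the same proportion of each stratum, which is the property the stratum column role is used for.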
rr = resample(task, learner, resampling = rsmp("cv", folds = 5), store_backends = TRUE)
INFO [17:07:04.946] [mlr3] Applying learner 'surv.bart' on task 'lung' (iter 1/5)
INFO [17:07:07.647] [mlr3] Applying learner 'surv.bart' on task 'lung' (iter 2/5)
INFO [17:07:10.708] [mlr3] Applying learner 'surv.bart' on task 'lung' (iter 3/5)
INFO [17:07:13.382] [mlr3] Applying learner 'surv.bart' on task 'lung' (iter 4/5)
INFO [17:07:16.361] [mlr3] Applying learner 'surv.bart' on task 'lung' (iter 5/5)
No errors or warnings:
rr$errors
Empty data.table (0 rows and 2 cols): iteration,msg
rr$warnings
Empty data.table (0 rows and 2 cols): iteration,msg
Performance in each fold:
rr$score(measures)
task_id learner_id resampling_id iteration surv.graf surv.cindex.uno
1: lung surv.bart cv 1 -0.312614598 0.5869665
2: lung surv.bart cv 2 -0.103181391 0.5502903
3: lung surv.bart cv 3 0.001448263 0.6178001
4: lung surv.bart cv 4 -0.044161171 0.6157215
5: lung surv.bart cv 5 -0.043129352 0.5688389
Hidden columns: task, learner, resampling, prediction
Mean cross-validation performance:
rr$aggregate(measures)
surv.graf surv.cindex.uno
-0.1003276 0.5879235
Uncertainty Quantification in Survival Prediction
We will choose two patients from the test set and plot their survival prediction posterior estimates.
Let’s choose the patients with the worst and the best survival time:
death_times = p$truth[,1]
sort(death_times)
[1] 3 5 5 6 6 6 7 8 8 8 8 10 10 10 12 12 12 13 15 16 17 18 27
worst_indx = which(death_times == min(death_times))[1] # died first
best_indx = which(death_times == max(death_times))[1] # died last

patient_ids = c(worst_indx, best_indx)
patient_ids # which patient IDs
[1] 5 18
death_times = death_times[patient_ids]
death_times # 1st is worst, 2nd is best
[1] 3 27
Subset the Arrdist to only the above 2 patients:
arrd = p$distr[patient_ids]
arrd
Arrdist(2x31x50)
We choose time points (in months) for the survival estimates:
months = seq(1, 36) # 1 month - 3 years
We use the $distr interface and the $survival property to get survival probabilities from an Arrdist object as well as the quantile credible intervals (CIs). The median survival probabilities can be extracted as follows:
med = arrd$survival(months) # 'med' for median

colnames(med) = paste0(patient_ids, "_med")
med = as_tibble(med) %>% add_column(month = months)
head(med)
# A tibble: 6 × 3
`5_med` `18_med` month
<dbl> <dbl> <int>
1 0.874 0.981 1
2 0.767 0.962 2
3 0.670 0.945 3
4 0.569 0.927 4
5 0.465 0.901 5
6 0.366 0.869 6
We can briefly verify the model’s predictions: the 1st patient’s survival probabilities in every month are lower (worse) than those of the 2nd patient.
Note that subsetting an Arrdist (3d array) creates a Matdist (2d matrix). For example, we can explicitly get the median survival probabilities:
matd_median = arrd[, 0.5] # median
head(matd_median$survival(months)) # same as with `arrd`
[,1] [,2]
1 0.8741127 0.9808363
2 0.7670382 0.9621618
3 0.6701276 0.9450867
4 0.5688809 0.9272284
5 0.4647686 0.9007042
6 0.3660939 0.8687270
Using the mean posterior survival probabilities or the ones from the last posterior draw is also possible and can be done as follows:
matd_mean = arrd[, "mean"] # mean (if needed)
head(matd_mean$survival(months))
[,1] [,2]
1 0.8652006 0.9748463
2 0.7533538 0.9521817
3 0.6560050 0.9293229
4 0.5623555 0.9051549
5 0.4750038 0.8758896
6 0.3815333 0.8360373
matd_50draw = arrd[, 50] # the 50th posterior draw
head(matd_50draw$survival(months))
[,1] [,2]
1 0.9178342 0.9920982
2 0.8424195 0.9842589
3 0.7732014 0.9764815
4 0.7096707 0.9687656
5 0.6029119 0.9495583
6 0.5122132 0.9307318
To get the CIs we will subset the Arrdist using a quantile number (0-1), which extracts a Matdist based on the cdf. The survival function is 1 - cdf, so the lower and upper bounds are reversed:
low = arrd[, 0.975]$survival(months) # 2.5% bound
high = arrd[, 0.025]$survival(months) # 97.5% bound
colnames(low) = paste0(patient_ids, "_low")
colnames(high) = paste0(patient_ids, "_high")
low = as_tibble(low)
high = as_tibble(high)
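The reversal of the bounds can be verified on simulated numbers. In this base R sketch, `surv_draws` is a hypothetical set of posterior draws of \(S(t)\) for a single patient and time point. Taking the 97.5% quantile of the cdf draws and converting back to survival gives the 2.5% quantile of the survival draws.

```r
set.seed(7)
# hypothetical posterior draws of S(t) for one patient at one time point
surv_draws = runif(50, 0.2, 0.6)
cdf_draws = 1 - surv_draws

# quantiles are taken on the cdf scale, then mapped back to survival
low  = 1 - quantile(cdf_draws, probs = 0.975, names = FALSE)
high = 1 - quantile(cdf_draws, probs = 0.025, names = FALSE)

# same bounds computed directly on the survival scale
direct = quantile(surv_draws, probs = c(0.025, 0.975), names = FALSE)
all.equal(c(low, high), direct)
```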
The median posterior survival probabilities for the two patients of interest and the corresponding CI bounds in a tidy format are:
surv_tbl =
  bind_cols(low, med, high) %>%
  pivot_longer(cols = !month, values_to = "surv",
    names_to = c("patient_id", ".value"), names_sep = "_") %>%
  relocate(patient_id)
surv_tbl
# A tibble: 72 × 5
patient_id month low med high
<chr> <int> <dbl> <dbl> <dbl>
1 5 1 0.713 0.874 0.953
2 18 1 0.929 0.981 0.996
3 5 2 0.508 0.767 0.903
4 18 2 0.863 0.962 0.991
5 5 3 0.362 0.670 0.855
6 18 3 0.801 0.945 0.985
7 5 4 0.244 0.569 0.804
8 18 4 0.734 0.927 0.977
9 5 5 0.146 0.465 0.748
10 18 5 0.654 0.901 0.969
# ℹ 62 more rows
We draw survival curves with the uncertainty for the survival probability quantified:
my_colors = c("#E41A1C", "#4DAF4A")
names(my_colors) = patient_ids

surv_tbl %>%
  ggplot(aes(x = month, y = med)) +
  geom_step(aes(color = patient_id), linewidth = 1) +
  xlab('Time (Months)') +
  ylab('Survival Probability') +
  geom_ribbon(aes(ymin = low, ymax = high, fill = patient_id),
    alpha = 0.3, show.legend = F) +
  geom_vline(xintercept = death_times[1], linetype = 'dashed', color = my_colors[1]) +
  geom_vline(xintercept = death_times[2], linetype = 'dashed', color = my_colors[2]) +
  theme_bw(base_size = 14) +
  scale_color_manual(values = my_colors) +
  scale_fill_manual(values = my_colors) +
  guides(color = guide_legend(title = "Patient ID"))
Partial Dependence Plot
We will use a Partial Dependence Plot (PDP) (Friedman 2001) to visualize how different males and females are in terms of their average survival predictions across time.
PDPs assume that features are independent. In our case we need to check that sex doesn’t correlate with any of the other features used for training the BART learner. Since sex is a categorical feature, we fit a linear model using every other feature in the data as the target variable (\(lm(feature \sim sex)\)) and conduct an ANOVA (ANalysis Of VAriance) to get the variance explained, or \(R^2\). The square root of that value is the correlation measure we want.
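A quick toy check of the identity behind this measure: for a two-level feature, the square root of \(R^2\) from a simple linear regression equals the absolute Pearson correlation. The simulated variables `x_cat` and `x_num` below are hypothetical and unrelated to the lung data.

```r
set.seed(3)
x_cat = rep(c(0, 1), each = 50)   # two-level feature coded as 0/1
x_num = 2 * x_cat + rnorm(100)    # numeric feature that depends on it

# correlation measure from the linear model's R^2
r_from_lm = sqrt(summary(lm(x_num ~ x_cat))$r.squared)

# absolute Pearson correlation, for comparison
r_pearson = abs(cor(x_num, x_cat))

all.equal(r_from_lm, r_pearson)
```

For a factor with more than two levels the \(R^2\) route still works, while a single Pearson correlation is no longer defined, which is one reason to prefer it.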
# code from https://christophm.github.io/interpretable-ml-book/ale.html
mycor = function(cnames, data) {
  x.num = data[, cnames[1], with = FALSE][[1]]
  x.cat = data[, cnames[2], with = FALSE][[1]]
  # R^2 = Cor(X, Y)^2 in simple linear regression
  sqrt(summary(lm(x.num ~ x.cat))$r.squared)
}

cnames = c("sex")
combs = expand.grid(y = setdiff(colnames(d), "sex"), x = cnames)
combs$cor = apply(combs, 1, mycor, data = task$data()) # uses all rows of the preprocessed task
combs
y x cor
1 time sex 0.12941337
2 status sex 0.24343282
3 age sex 0.12216709
4 inst sex 0.07826337
5 meal.cal sex 0.18389545
6 pat.karno sex 0.04132443
7 ph.ecog sex 0.02564987
8 ph.karno sex 0.01702471
9 wt.loss sex 0.13431983
sex doesn’t correlate strongly with any other feature, so we can compute the PDP:
# create two datasets: one with males and one with females
# all other features remain the same (use train data, 205 patients)
d = task$data(rows = part$train) # `rows = part$test` to use the test set

d$sex = 1
task_males = as_task_surv(d, time = 'time', event = 'status', id = 'lung-males')
d$sex = 0
task_females = as_task_surv(d, time = 'time', event = 'status', id = 'lung-females')

# make predictions
p_males = learner$predict(task_males)
p_females = learner$predict(task_females)

# take the median posterior survival probability
surv_males = p_males$distr$survival(months) # times x patients
surv_females = p_females$distr$survival(months) # times x patients
# tidy up data: average and quantiles across patients
data_males =
  apply(surv_males, 1, function(row) {
    tibble(
      low = quantile(row, probs = 0.025),
      avg = mean(row),
      high = quantile(row, probs = 0.975)
    )
  }) %>%
  bind_rows() %>%
  add_column(sex = 'male', month = months, .before = 1)

data_females =
  apply(surv_females, 1, function(row) {
    tibble(
      low = quantile(row, probs = 0.025),
      avg = mean(row),
      high = quantile(row, probs = 0.975)
    )
  }) %>%
  bind_rows() %>%
  add_column(sex = 'female', month = months, .before = 1)

pdp_tbl = bind_rows(data_males, data_females)
pdp_tbl
# A tibble: 72 × 5
sex month low avg high
<chr> <int> <dbl> <dbl> <dbl>
1 male 1 0.836 0.942 0.981
2 male 2 0.704 0.889 0.963
3 male 3 0.587 0.839 0.943
4 male 4 0.488 0.788 0.924
5 male 5 0.392 0.732 0.897
6 male 6 0.304 0.663 0.860
7 male 7 0.234 0.601 0.829
8 male 8 0.172 0.550 0.799
9 male 9 0.130 0.503 0.766
10 male 10 0.0945 0.455 0.733
# ℹ 62 more rows
my_colors = c("#E41A1C", "#4DAF4A")
names(my_colors) = c('male', 'female')

pdp_tbl %>%
  ggplot(aes(x = month, y = avg)) +
  geom_step(aes(color = sex), linewidth = 1) +
  xlab('Time (Months)') +
  ylab('Survival Probability') +
  geom_ribbon(aes(ymin = low, ymax = high, fill = sex), alpha = 0.2, show.legend = F) +
  theme_bw(base_size = 14) +
  scale_color_manual(values = my_colors) +
  scale_fill_manual(values = my_colors)
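The PDP recipe used above (fix the feature of interest to one value for every patient, predict, then average across patients) is model-agnostic. The following base R sketch applies it to a plain linear model on simulated data. The data frame `dat`, its columns and the model `fit` are all hypothetical stand-ins for the BART learner and the lung task. Only the 0/1 coding (1 = male, 0 = female) follows the code above.

```r
set.seed(10)
n = 200
dat = data.frame(sex = rbinom(n, 1, 0.5), age = rnorm(n, 60, 10))
dat$y = 1 + 0.5 * dat$sex + 0.02 * dat$age + rnorm(n, sd = 0.1)

fit = lm(y ~ sex + age, data = dat)

# partial dependence: set `sex` to a fixed value for everyone, then average predictions
pdp = sapply(c(female = 0, male = 1), function(v) {
  d_mod = dat
  d_mod$sex = v
  mean(predict(fit, newdata = d_mod))
})
pdp
```

For an additive linear model the difference between the two PDP values equals the fitted coefficient of `sex`, which gives an easy sanity check before applying the same recipe to a flexible learner.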