library(mlr3)

task = tsk(...)
lrn_rpart = lrn(...) # create the learner
lrn_rpart$train(...) # train the learner on the task
lrn_rpart$... # access the raw model object that was fitted
Goal
The goal for this exercise is to familiarize yourself with two very important machine learning methods, the decision tree and random forest. After this exercise, you should be able to train these models and extract important information to understand the model internals.
Exercises
Fit a decision tree
Use task = tsk("german_credit")
to create the classification task for the german_credit
data and create a decision tree learner (e.g., a CART learner). Train the decision tree on the german_credit
classification task. Look at the output of the trained decision tree (you have to access the raw model object).
Hint 1:
The learner we are focusing on here is a decision tree implemented in rpart. The corresponding mlr3 learner key is "classif.rpart". For this exercise, we use the learner with the default hyperparameters. The raw model object can be accessed from the $model slot of the trained learner.
Hint 2:
library(mlr3)

task = tsk("german_credit")
lrn_rpart = lrn("classif.rpart")
lrn_rpart$train(task)
lrn_rpart$model
n= 1000
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 1000 300 good (0.7000000 0.3000000)
2) status=0<= ... < 200 DM,... >= 200 DM / salary for at least 1 year 457 60 good (0.8687090 0.1312910) *
3) status=no checking account,... < 0 DM 543 240 good (0.5580110 0.4419890)
6) duration< 22.5 306 106 good (0.6535948 0.3464052)
12) credit_history=no credits taken/all credits paid back duly,existing credits paid back duly till now,all credits at this bank paid back duly 278 85 good (0.6942446 0.3057554)
24) amount< 7491.5 271 79 good (0.7084871 0.2915129)
48) purpose=others,car (new),car (used),furniture/equipment,domestic appliances,vacation,retraining,business 256 69 good (0.7304688 0.2695312)
96) duration< 11.5 73 9 good (0.8767123 0.1232877) *
97) duration>=11.5 183 60 good (0.6721311 0.3278689)
194) amount>=1387.5 118 29 good (0.7542373 0.2457627) *
195) amount< 1387.5 65 31 good (0.5230769 0.4769231)
390) property=unknown / no property,car or other 45 14 good (0.6888889 0.3111111) *
391) property=building soc. savings agr. / life insurance,real estate 20 3 bad (0.1500000 0.8500000) *
49) purpose=radio/television,repairs 15 5 bad (0.3333333 0.6666667) *
25) amount>=7491.5 7 1 bad (0.1428571 0.8571429) *
13) credit_history=delay in paying off in the past,critical account/other credits elsewhere 28 7 bad (0.2500000 0.7500000) *
7) duration>=22.5 237 103 bad (0.4345992 0.5654008)
14) savings=500 <= ... < 1000 DM,... >= 1000 DM 41 12 good (0.7073171 0.2926829) *
15) savings=unknown/no savings account,... < 100 DM,100 <= ... < 500 DM 196 74 bad (0.3775510 0.6224490)
30) duration< 47.5 160 69 bad (0.4312500 0.5687500)
60) purpose=car (new) 23 6 good (0.7391304 0.2608696) *
61) purpose=others,car (used),furniture/equipment,domestic appliances,repairs,retraining,business 137 52 bad (0.3795620 0.6204380) *
31) duration>=47.5 36 5 bad (0.1388889 0.8611111) *
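The printed splits are not the only information stored in the raw model object. A short sketch of two further things one could inspect (the available fields are documented in ?rpart::rpart.object):

# Further internals of the fitted rpart model (sketch):
lrn_rpart$model$variable.importance # features ranked by their importance in the splits
printcp(lrn_rpart$model) # complexity parameter (cp) table of the fitted tree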
Visualize the tree structure
To interpret the model and to gain more information about the decision making of predictions, we decide to take a closer look at the decision tree structure by visualizing it.
Hint 1:
See the code example on the help page ?rpart::plot.rpart, which shows how to use the plot and text functions on the rpart model object. Note that different packages exist to plot the decision tree structure in a visually more appealing way:
- The rpart.plot function from the equally named package rpart.plot, which is applied to the raw rpart model object.
- The plot.party function from the package partykit, which is applied to an rpart model object after converting it into a party model object using the as.party function.
- The ggparty function from the equally named package ggparty, which is applied after converting the rpart model object into a party model object using the as.party function.
Hint 2:
library("rpart")
...(lrn_rpart$...)
text(lrn_rpart$...)
# Alternative using e.g. the rpart.plot package
library("rpart.plot")
...(lrn_rpart$...)
The possibility of visualizing a tree makes it interpretable and helps to understand how new predictions are calculated.
library(rpart.plot)
Loading required package: rpart
rpart.plot(lrn_rpart$model)
Note: Other functions to visualize an rpart tree are:
- The (very) basic rpart plot method:
plot(lrn_rpart$model)
text(lrn_rpart$model, use.n = TRUE)
- Convert the rpart object to a party object to automatically use the respective plot() method:
library(partykit)
Loading required package: grid
Loading required package: libcoin
Loading required package: mvtnorm
partytree = as.party(lrn_rpart$model)
plot(partytree)
- Use ggparty to create highly customizable plots:
library(ggparty)
Loading required package: ggplot2
ggparty(partytree) +
geom_edge() +
geom_edge_label() +
geom_node_splitvar() +
# pass list to gglist containing all ggplot components we want to plot for each
# (default: terminal) node
geom_node_plot(gglist = list(geom_bar(aes_string(x = NA, fill = "credit_risk"),
position = position_fill()), xlab("Credit Risk")))
Warning: `aes_string()` was deprecated in ggplot2 3.0.0.
ℹ Please use tidy evaluation idioms with `aes()`.
ℹ See also `vignette("ggplot2-in-packages")` for more information.
Fit a random forest
To get a more powerful learner, we decide to also fit a random forest. Therefore, fit a random forest with default hyperparameters to the german_credit task.
Reminder
One of the drawbacks of using trees is the instability of the predictor. Small changes in the data may lead to a very different model and therefore a high variance of the predictions. The random forest takes advantages of that and reduces the variance by applying bagging to decision trees.
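To make the bagging idea concrete, the following rough sketch (illustrative only, not required for the exercise) fits several decision trees on bootstrap samples of the task; a random forest automates exactly this and additionally samples a random subset of features at each split:

# Illustrative sketch of bagging with mlr3: one rpart tree per bootstrap sample.
set.seed(1)
bagged_trees = lapply(1:10, function(i) {
  boot_ids = sample(task$row_ids, size = task$nrow, replace = TRUE) # bootstrap sample
  tree_i = lrn("classif.rpart")
  tree_i$train(task, row_ids = boot_ids) # fit a tree on this bootstrap sample
  tree_i
})
# The ensemble prediction would then be obtained by majority voting over the trees.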
Hint 1:
Use the mlr3 learner classif.ranger, which uses the ranger implementation to train a random forest.
Hint 2:
library(mlr3)
library(mlr3learners)
lrn_ranger = lrn(...) # create the learner
lrn_ranger$...(...) # train the learner on the task
library(mlr3)
library(mlr3learners)
= lrn("classif.ranger")
lrn_ranger $train(task) lrn_ranger
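As with the decision tree, the raw model fitted by ranger can be accessed via the $model slot; for instance, it stores the out-of-bag prediction error (a sketch; the stored fields are described in ?ranger::ranger):

lrn_ranger$model # printing the ranger object shows, among others, the OOB prediction error
lrn_ranger$model$prediction.error # out-of-bag prediction error as a single number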
ROC Analysis
The bank wants to use a tree-based model to predict the credit risk. Conduct a simple benchmark to assess whether a decision tree or a random forest works better for this purpose. Specifically, among the credit applications the system predicts to be “good”, the bank wants at most 10% to actually be “bad”. Simultaneously, the bank aims at correctly classifying 90% or more of all applications that are truly “good”. In other words, both the precision and the recall for the “good” class should be at least 90%. Visualize the benchmark results in a way that helps answer this question. Can the bank expect the model to fulfil their requirements? Which model performs better?
Hint 1:
A benchmark requires three arguments: a task, a list of learners, and a resampling object.
Solution
= lrn("classif.rpart", predict_type = "prob")
tree = lrn("classif.ranger", predict_type = "prob")
forest
= list(tree, forest)
lrns
= rsmp("cv", folds = 5)
cv5 $instantiate(task)
cv5
= benchmark(benchmark_grid(task, lrns, cv5)) bmr
INFO [15:24:33.591] [mlr3] Running benchmark with 10 resampling iterations
INFO [15:24:35.472] [mlr3] Applying learner 'classif.rpart' on task 'german_credit' (iter 1/5)
INFO [15:24:36.742] [mlr3] Applying learner 'classif.rpart' on task 'german_credit' (iter 2/5)
INFO [15:24:37.993] [mlr3] Applying learner 'classif.rpart' on task 'german_credit' (iter 3/5)
INFO [15:24:39.407] [mlr3] Applying learner 'classif.rpart' on task 'german_credit' (iter 4/5)
INFO [15:24:40.629] [mlr3] Applying learner 'classif.rpart' on task 'german_credit' (iter 5/5)
INFO [15:24:41.849] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 1/5)
INFO [15:24:43.137] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 2/5)
INFO [15:24:44.572] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 3/5)
INFO [15:24:45.424] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 4/5)
INFO [15:24:46.300] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 5/5)
INFO [15:24:48.180] [mlr3] Finished benchmark
::autoplot(bmr, type = "prc") mlr3viz
While the random forest dominates the decision tree, neither model can fulfil the bank’s requirement of a precision and recall of >90%.
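Besides the precision-recall curves, one could also compare aggregated precision and recall values directly; note that, unlike the curves, these are computed at the default probability threshold of 0.5 and refer to the task's positive class (a sketch):

# Aggregated precision and recall per learner at the default threshold:
bmr$aggregate(msrs(c("classif.precision", "classif.recall")))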
Understand hyperparameters
Use task = tsk("german_credit")
to create the classification task for the german_credit
data. In this exercise, we want to fit decision trees and random forests with different hyperparameters (which can have a significant impact on the performance). Each learner implemented in R
(e.g. ranger
or rpart
) has a lot of control settings that directly influence the model fitting (the so-called hyperparameters). Here, we will consdider the hyperparameters mtry
for the ranger
learner and maxdepth
for the rpart
learner.
Your task is to manually create a list containing multiple rpart and ranger learners with different hyperparameter values (e.g., try out increasing maxdepth values for rpart). In the next step, we will use this list to see how the model performance changes for different hyperparameter values.
The help page of ranger (?ranger) gives a detailed explanation of the hyperparameters:
- mtry: Number of variables to possibly split at in each node. Default is the (rounded down) square root of the number of variables. Alternatively, a single argument function returning an integer, given the number of independent variables.
NOTE: In a ranger learner created with mlr3, you also have the possibility to set mtry.ratio instead of mtry, which lets you specify the fraction of variables to be used instead of their absolute number.
For rpart, we have to dig a bit deeper: ?rpart contains no description of the hyperparameters. To get further information, we have to open ?rpart.control:
- maxdepth: Set the maximum depth of any node of the final tree, with the root node counted as depth 0. Values greater than 30 rpart will give nonsense results on 32-bit machines.
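To relate mtry.ratio to mtry: the ratio refers to the number of features of the task and is translated to mtry internally, so for german_credit with its 20 features a ratio of 0.5 corresponds to 10 candidate variables per split. A quick way to check this (sketch):

# mtry.ratio is a fraction of the task's features (sketch):
length(task$feature_names) # number of features in german_credit
lrn("classif.ranger", mtry.ratio = 0.5)$param_set$values # the hyperparameter values that were set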
Hint 1:
The learners we are focusing on here are a decision tree implemented in rpart and a random forest implemented in ranger. The corresponding mlr3 learner keys are "classif.rpart" and "classif.ranger". In mlr3, we can get an overview of all hyperparameters in the $param_set slot. With an mlr3 learner it is possible to get help about the underlying method by using the $help() method (e.g. lrn_ranger$help()):
lrn("classif.rpart")$help()
lrn("classif.ranger")$help()
See also ?rpart::rpart.control and ?ranger::ranger.
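For example, printing the parameter set gives an overview of all available hyperparameters together with their types, defaults, and feasible ranges (sketch):

lrn("classif.rpart")$param_set # hyperparameters of the decision tree learner
lrn("classif.ranger")$param_set # hyperparameters of the random forest learner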
Hint 2:
The possible choices for the hyperparameters can also be viewed with $param_set. Setting the hyperparameters can be done directly in the lrn() call:
# Define a list of learners for the benchmark:
lrns = list(
  lrn("classif.rpart", ...),
  lrn("classif.rpart", ...),
  lrn("classif.rpart", ...),
  lrn("classif.ranger", ...),
  lrn("classif.ranger", ...),
  lrn("classif.ranger", ...))
library(mlr3verse)
set.seed(31415L)
= mlr3::tsk("german_credit")
task
= list(
lrns lrn("classif.rpart", maxdepth = 1),
lrn("classif.rpart", maxdepth = 5),
lrn("classif.rpart", maxdepth = 20),
lrn("classif.ranger", mtry.ratio = 0.2),
lrn("classif.ranger", mtry.ratio = 0.5),
lrn("classif.ranger", mtry.ratio = 0.8))
Comparison of trees and random forests
Does it make a difference w.r.t. model performance if we use different hyperparameters? Use the learners from the previous exercise and compare them in a benchmark. Use 5-fold cross-validation as resampling technique and the classification error as performance measure. Visualize the results of the benchmark.
Hint 1:
The function to conduct the benchmark is benchmark and requires defining the resampling with rsmp and the benchmark grid with benchmark_grid.
Hint 2:
set.seed(31415L)
lrns = list(
  lrn("classif.rpart", maxdepth = 1),
  lrn("classif.rpart", maxdepth = 5),
  lrn("classif.rpart", maxdepth = 20),
  lrn("classif.ranger", mtry.ratio = 0.2),
  lrn("classif.ranger", mtry.ratio = 0.5),
  lrn("classif.ranger", mtry.ratio = 0.8))

cv5 = rsmp(..., folds = ...)
cv5$instantiate(...)

bmr = ...(...(task, lrns, cv5))

mlr3viz::autoplot(bmr, measure = msr("classif.ce"))
set.seed(31415L)
lrns = list(
  lrn("classif.rpart", id = "rpart_md1", maxdepth = 1, predict_type = "prob"),
  lrn("classif.rpart", id = "rpart_md5", maxdepth = 5, predict_type = "prob"),
  lrn("classif.rpart", id = "rpart_md20", maxdepth = 20, predict_type = "prob"),
  lrn("classif.ranger", id = "rf_mtryr0.2", mtry.ratio = 0.2, predict_type = "prob"),
  lrn("classif.ranger", id = "rf_mtryr0.5", mtry.ratio = 0.5, predict_type = "prob"),
  lrn("classif.ranger", id = "rf_mtry0.8", mtry.ratio = 0.8, predict_type = "prob"))

cv5 = rsmp("cv", folds = 5)
cv5$instantiate(task)

bmr = benchmark(benchmark_grid(task, lrns, cv5))
INFO [15:24:53.425] [mlr3] Running benchmark with 30 resampling iterations
INFO [15:24:57.074] [mlr3] Applying learner 'rpart_md1' on task 'german_credit' (iter 1/5)
INFO [15:25:01.196] [mlr3] Applying learner 'rpart_md1' on task 'german_credit' (iter 2/5)
INFO [15:25:05.675] [mlr3] Applying learner 'rpart_md1' on task 'german_credit' (iter 3/5)
INFO [15:25:09.801] [mlr3] Applying learner 'rpart_md1' on task 'german_credit' (iter 4/5)
INFO [15:25:14.469] [mlr3] Applying learner 'rpart_md1' on task 'german_credit' (iter 5/5)
INFO [15:25:18.623] [mlr3] Applying learner 'rpart_md5' on task 'german_credit' (iter 1/5)
INFO [15:25:22.188] [mlr3] Applying learner 'rpart_md5' on task 'german_credit' (iter 2/5)
INFO [15:25:26.755] [mlr3] Applying learner 'rpart_md5' on task 'german_credit' (iter 3/5)
INFO [15:25:29.383] [mlr3] Applying learner 'rpart_md5' on task 'german_credit' (iter 4/5)
INFO [15:25:31.529] [mlr3] Applying learner 'rpart_md5' on task 'german_credit' (iter 5/5)
INFO [15:25:33.335] [mlr3] Applying learner 'rpart_md20' on task 'german_credit' (iter 1/5)
INFO [15:25:34.990] [mlr3] Applying learner 'rpart_md20' on task 'german_credit' (iter 2/5)
INFO [15:25:36.561] [mlr3] Applying learner 'rpart_md20' on task 'german_credit' (iter 3/5)
INFO [15:25:38.040] [mlr3] Applying learner 'rpart_md20' on task 'german_credit' (iter 4/5)
INFO [15:25:39.762] [mlr3] Applying learner 'rpart_md20' on task 'german_credit' (iter 5/5)
INFO [15:25:41.187] [mlr3] Applying learner 'rf_mtryr0.2' on task 'german_credit' (iter 1/5)
INFO [15:25:42.751] [mlr3] Applying learner 'rf_mtryr0.2' on task 'german_credit' (iter 2/5)
INFO [15:25:44.388] [mlr3] Applying learner 'rf_mtryr0.2' on task 'german_credit' (iter 3/5)
INFO [15:25:45.799] [mlr3] Applying learner 'rf_mtryr0.2' on task 'german_credit' (iter 4/5)
INFO [15:25:47.310] [mlr3] Applying learner 'rf_mtryr0.2' on task 'german_credit' (iter 5/5)
INFO [15:25:49.597] [mlr3] Applying learner 'rf_mtryr0.5' on task 'german_credit' (iter 1/5)
INFO [15:25:51.239] [mlr3] Applying learner 'rf_mtryr0.5' on task 'german_credit' (iter 2/5)
INFO [15:25:53.017] [mlr3] Applying learner 'rf_mtryr0.5' on task 'german_credit' (iter 3/5)
INFO [15:25:54.643] [mlr3] Applying learner 'rf_mtryr0.5' on task 'german_credit' (iter 4/5)
INFO [15:25:56.334] [mlr3] Applying learner 'rf_mtryr0.5' on task 'german_credit' (iter 5/5)
INFO [15:25:58.093] [mlr3] Applying learner 'rf_mtry0.8' on task 'german_credit' (iter 1/5)
INFO [15:26:00.132] [mlr3] Applying learner 'rf_mtry0.8' on task 'german_credit' (iter 2/5)
INFO [15:26:02.564] [mlr3] Applying learner 'rf_mtry0.8' on task 'german_credit' (iter 3/5)
INFO [15:26:04.397] [mlr3] Applying learner 'rf_mtry0.8' on task 'german_credit' (iter 4/5)
INFO [15:26:06.262] [mlr3] Applying learner 'rf_mtry0.8' on task 'german_credit' (iter 5/5)
INFO [15:26:06.886] [mlr3] Finished benchmark
mlr3viz::autoplot(bmr, measure = msr("classif.ce"))
Looking at the boxplots reveals that the performance of the learners highly depends on the choice of the hyperparameters.
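To complement the boxplots with numbers, the aggregated classification error per learner can be computed as well (sketch):

bmr$aggregate(msr("classif.ce")) # mean classification error per learner over the 5 folds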
Follow-up question: How do we properly set the hyperparameters? Answer: hyperparameter optimization (see the next use case).
Summary
- We learned how to use two of the most widely used learners: a decision tree built with rpart and a random forest built with ranger.
- Finally, we looked at different hyperparameters and how they affect the performance in a benchmark.
- The next step would be to use an algorithm to automatically search for good hyperparameter configurations.
Further information
Tree implementations: One of the longest paragraphs in the CRAN Task View about Machine Learning and Statistical Learning gives an overview of existing tree implementations:
“[…] Tree-structured models for regression, classification and survival analysis, following the ideas in the CART book, are implemented in rpart (shipped with base R) and tree. Package rpart is recommended for computing CART-like trees. A rich toolbox of partitioning algorithms is available in Weka, package RWeka provides an interface to this implementation, including the J4.8-variant of C4.5 and M5. The Cubist package fits rule-based models (similar to trees) with linear regression models in the terminal leaves, instance-based corrections and boosting. The C50 package can fit C5.0 classification trees, rule-based models, and boosted versions of these. pre can fit rule-based models for a wider range of response variable types. […]”