Threshold Tuning for Classification Tasks

Adjust the probability thresholds of classes.

Author

Florian Pfisterer

Published

October 14, 2020

Intro

Predicting probabilities in classification tasks allows us to adjust the probability thresholds required for assigning an observation to a certain class. This can improve classification performance, especially when we aim to trade off metrics such as the false positive and false negative rates.

This is often done in ROC analysis, for example, and the mlr3book has a chapter on ROC analysis for the interested reader. This post does not focus on ROC analysis, but on the general problem of adjusting classification thresholds for arbitrary metrics.

This post assumes some familiarity with mlr3, as well as with the mlr3pipelines and mlr3tuning packages, as both are used throughout. The mlr3book contains more details on these two packages. This post is a more in-depth version of the article on threshold tuning in the mlr3book.

Prerequisites

We load the mlr3, mlr3pipelines, and mlr3tuning packages, which provide everything needed for this example.

library(mlr3)
library(mlr3pipelines)
library(mlr3tuning)

We initialize the random number generator with a fixed seed for reproducibility, and decrease the verbosity of the loggers to keep the output concise.

set.seed(7832)
lgr::get_logger("mlr3")$set_threshold("warn")
lgr::get_logger("bbotk")$set_threshold("warn")

Thresholds: A short intro

To understand thresholds, we first briefly showcase the effect of setting different thresholds.

First, we create a learner that predicts probabilities, use it to predict on holdout data, and store the prediction.

learner = lrn("classif.rpart", predict_type = "prob")
rr = resample(tsk("pima"), learner, rsmp("holdout"))
prd = rr$prediction()
prd
<PredictionClassif> for 256 observations:
 row_ids truth response  prob.pos  prob.neg
       4   neg      neg 0.1057692 0.8942308
       6   neg      neg 0.0200000 0.9800000
      10   pos      neg 0.1428571 0.8571429
     ---   ---      ---       ---       ---
     764   neg      neg 0.2777778 0.7222222
     766   neg      neg 0.0200000 0.9800000
     767   pos      pos 0.8000000 0.2000000

If we now look at the confusion matrix, the off-diagonal elements are errors made by our model (false positives and false negatives), while the on-diagonal elements are where our model predicted correctly.

# Print confusion matrix
prd$confusion
        truth
response pos neg
     pos  53  27
     neg  37 139
# Print False Positives and False Negatives
prd$score(list(msr("classif.fp"), msr("classif.fn")))
classif.fp classif.fn 
        27         37 
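To make the bookkeeping explicit, here is a small base-R sketch (independent of mlr3) that builds a confusion matrix with the same layout as prd$confusion (rows are responses, columns are the truth) and reads the false positives and false negatives off the off-diagonal cells. The truth and response vectors are made up for illustration.

```r
# Made-up truth and response vectors (not the pima data)
lvls = c("pos", "neg")
truth    = factor(c("pos", "pos", "neg", "neg", "neg", "pos"), levels = lvls)
response = factor(c("pos", "neg", "neg", "pos", "pos", "pos"), levels = lvls)

# Same layout as prd$confusion: rows = response, columns = truth
cm = table(response = response, truth = truth)
print(cm)

# Off-diagonal cells are the errors
fp = cm["pos", "neg"]  # predicted pos, but truly neg
fn = cm["neg", "pos"]  # predicted neg, but truly pos
print(c(fp = fp, fn = fn))
```

The diagonal (cm["pos", "pos"] and cm["neg", "neg"]) counts the correct predictions.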

By adjusting the classification threshold, in this case the probability required to predict the positive class, we can trade off false negatives against false positives: lowering the threshold predicts the positive class more often (first row of the confusion matrix), while raising it predicts the negative class more often (second row).

# Lower threshold: More positives
prd$set_threshold(0.25)$confusion
        truth
response pos neg
     pos  78  71
     neg  12  95
# Higher threshold: Fewer positives
prd$set_threshold(0.75)$confusion
        truth
response pos neg
     pos  52  20
     neg  38 146
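The same effect can be reproduced on a handful of made-up probabilities. This base-R sketch assumes that an observation is labeled pos when its positive-class probability exceeds the threshold; exact ties are ignored here, and mlr3 may break them differently. apply_threshold is a hypothetical helper, not part of mlr3.

```r
# Hypothetical helper: label "pos" when prob_pos > threshold (ties ignored)
apply_threshold = function(prob_pos, threshold, lvls = c("pos", "neg")) {
  factor(ifelse(prob_pos > threshold, lvls[1], lvls[2]), levels = lvls)
}

# Made-up truth and positive-class probabilities
truth    = factor(c("pos", "pos", "neg", "neg", "pos", "neg"), levels = c("pos", "neg"))
prob_pos = c(0.92, 0.33, 0.18, 0.64, 0.57, 0.07)

# Lower thresholds trade false negatives for false positives
for (t in c(0.25, 0.5, 0.75)) {
  response = apply_threshold(prob_pos, t)
  fp = sum(response == "pos" & truth == "neg")
  fn = sum(response == "neg" & truth == "pos")
  cat(sprintf("threshold %.2f: FP = %d, FN = %d\n", t, fp, fn))
}
```

With these values, moving from 0.25 to 0.75 takes the false negatives from 0 to 2 while the false positives drop from 1 to 0, mirroring the shift between the two confusion matrices above.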

This threshold can now be tuned for a given measure, such as accuracy. The following section discusses how this can be done.
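As a minimal illustration of tuning a threshold for a measure, the following base-R sketch evaluates accuracy on a grid of candidate thresholds for a fixed set of made-up predictions and keeps the best one. The grid and the helper accuracy_at are illustrative choices, not what mlr3 does internally.

```r
# Made-up predictions (the probabilities avoid the grid points, so no ties)
truth    = factor(c("pos", "pos", "neg", "neg", "pos", "neg"), levels = c("pos", "neg"))
prob_pos = c(0.92, 0.33, 0.18, 0.64, 0.57, 0.07)

# Accuracy of the thresholded predictions
accuracy_at = function(threshold) {
  response = ifelse(prob_pos > threshold, "pos", "neg")
  mean(response == as.character(truth))
}

grid = seq(0.05, 0.95, by = 0.05)     # candidate thresholds
acc = sapply(grid, accuracy_at)       # accuracy for each candidate
best = grid[which.max(acc)]           # first threshold with maximal accuracy
cat("best threshold:", best, "accuracy:", max(acc), "\n")
```

With these toy values, accuracy peaks at 5/6 for every grid point between 0.18 and 0.33; which.max returns the first of them, 0.2.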

Adjusting thresholds: Two strategies

Currently, mlr3pipelines offers two main strategies for adjusting classification thresholds. We can either expose the thresholds as a hyperparameter of the Learner by using PipeOpThreshold, which allows us to tune them with an external optimizer from mlr3tuning.

Alternatively, we can use PipeOpTuneThreshold, which automatically tunes the thresholds after each learner fit.

In this blog post, we'll go through both strategies.

PipeOpThreshold

PipeOpThreshold can be put directly after a Learner.

A simple example would be:

gr = lrn("classif.rpart", predict_type = "prob") %>>% po("threshold")
l = GraphLearner$new(gr)

Note that predict_type = "prob" is required for po("threshold") to have any effect.

The thresholds are now exposed as a hyperparameter of the GraphLearner we created:

as.data.table(l$param_set)[, .(id, class, lower, upper, nlevels)]
                              id    class lower upper nlevels
                          <char>   <char> <num> <num>   <num>
 1:             classif.rpart.cp ParamDbl     0     1     Inf
 2:     classif.rpart.keep_model ParamLgl    NA    NA       2
 3:     classif.rpart.maxcompete ParamInt     0   Inf     Inf
 4:       classif.rpart.maxdepth ParamInt     1    30      30
 5:   classif.rpart.maxsurrogate ParamInt     0   Inf     Inf
 6:      classif.rpart.minbucket ParamInt     1   Inf     Inf
 7:       classif.rpart.minsplit ParamInt     1   Inf     Inf
 8: classif.rpart.surrogatestyle ParamInt     0     1       2
 9:   classif.rpart.usesurrogate ParamInt     0     2       3
10:           classif.rpart.xval ParamInt     0   Inf     Inf
11:         threshold.thresholds ParamUty    NA    NA     Inf

We can now tune these thresholds from the outside.

Before tuning, we have to define which hyperparameters we want to tune over. In this example, we only tune the thresholds parameter of the threshold PipeOp, but we could just as well tune additional hyperparameters jointly, e.g. rpart's cp parameter.

As the task we aim to optimize for is binary, we can simply specify a single threshold parameter:

search_space = ps(
  threshold.thresholds = p_dbl(lower = 0, upper = 1)
)

We now create an AutoTuner, which automatically tunes the supplied learner over the search space defined above.

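at = auto_tuner(
  tuner = tnr("random_search"),         # sample thresholds at random
  learner = l,
  resampling = rsmp("cv", folds = 3L),  # score each threshold by 3-fold CV
  measure = msr("classif.ce"),          # minimize classification error
  search_space = search_space,
  term_evals = 5L                       # stop after 5 evaluated thresholds
)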

at$train(tsk("german_credit"))
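Conceptually, the AutoTuner above draws 5 random thresholds, scores each one with 3-fold cross-validation, refitting the learner in every fold, and then fits a final model on all data. The base-R sketch below mimics that loop, with a logistic regression (glm) on simulated data standing in for rpart and german_credit. It is an analogy to clarify the mechanics, not mlr3's implementation; the fit counter only makes the cost visible.

```r
set.seed(1)

# Simulated binary data: y01 = 1 marks the positive class (illustrative only)
n = 300
x = rnorm(n)
y01 = as.integer(runif(n) < plogis(-0.5 + 1.5 * x))
df = data.frame(x = x, y01 = y01)

folds = sample(rep(1:3, length.out = n))  # fixed 3-fold assignment
n_fits = 0                                 # counts model fits

# Score one threshold by 3-fold CV, refitting the model in every fold
cv_error = function(threshold) {
  errs = numeric(3)
  for (k in 1:3) {
    fit = glm(y01 ~ x, data = df[folds != k, ], family = binomial())
    n_fits <<- n_fits + 1
    p_pos = predict(fit, newdata = df[folds == k, ], type = "response")
    errs[k] = mean(as.integer(p_pos > threshold) != df$y01[folds == k])
  }
  mean(errs)  # average misclassification error over the folds
}

# Random search: 5 thresholds drawn uniformly from [0, 1]
candidates = runif(5)
errors = sapply(candidates, cv_error)
best = candidates[which.min(errors)]

# Final fit on all data, to be used with the chosen threshold
final_fit = glm(y01 ~ x, data = df, family = binomial())
n_fits = n_fits + 1

cat("best threshold:", round(best, 3), "| model fits:", n_fits, "\n")
```

Because the fold assignment is fixed and the model does not depend on the threshold, this sketch fits the same three fold models five times over, for 16 fits in total.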

For multi-class tasks, this is a little more complicated: threshold.thresholds expects a named numeric vector containing one threshold per class. We therefore use a trafo function to transform a set of ParamDbl into this format:

search_space = ps(
  versicolor = p_dbl(lower = 0, upper = 1),
  setosa = p_dbl(lower = 0, upper = 1),
  virginica = p_dbl(lower = 0, upper = 1),
  .extra_trafo = function(x, param_set) {
    list(threshold.thresholds = mlr3misc::map_dbl(x, identity))
  }
)

Inside the .extra_trafo, we simply collect all parameter values into a named vector via map_dbl and store it in the threshold.thresholds slot expected by the learner.
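To see what the trafo produces, and how a named threshold vector can change predictions, here is a base-R sketch in which vapply stands in for mlr3misc::map_dbl. The decision rule (divide each class probability by its threshold and predict the class with the largest ratio) is our understanding of how mlr3 applies multiclass thresholds; treat it as an assumption. The probability matrix is made up.

```r
# What the trafo does: a named list of per-class values -> named numeric vector
x = list(versicolor = 0.4, setosa = 0.2, virginica = 0.5)
thresholds = vapply(x, identity, numeric(1))
print(thresholds)

# Made-up class probabilities for three observations
prob = matrix(
  c(0.10, 0.50, 0.40,
    0.70, 0.20, 0.10,
    0.05, 0.45, 0.50),
  nrow = 3, byrow = TRUE,
  dimnames = list(NULL, c("setosa", "versicolor", "virginica"))
)

# Assumed rule: rescale by the thresholds (aligned by class name), then argmax
scaled = sweep(prob, 2, thresholds[colnames(prob)], "/")
response = colnames(prob)[max.col(scaled, ties.method = "first")]
print(response)
```

In the third row, a plain argmax would predict virginica (0.50), but after rescaling versicolor wins (0.45 / 0.4 = 1.125 versus 0.50 / 0.5 = 1). Indexing by colnames(prob) matters because the trafo returns the classes in the order they were declared in the search space.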

Again, we create an AutoTuner, which automatically tunes the supplied learner over the search space defined above.

at_2 = auto_tuner(
  tuner = tnr("random_search"),
  learner = l,
  resampling = rsmp("cv", folds = 3L),
  measure = msr("classif.ce"),
  search_space = search_space,
  term_evals = 5L
)

at_2$train(tsk("iris"))

One drawback of this strategy is that it requires fitting a new model for each threshold setting. While setting a threshold and computing performance is relatively cheap, fitting the learner is often far more computationally demanding. A better strategy is therefore often to optimize the thresholds separately after each model fit.

PipeOpTuneThreshold

PipeOpTuneThreshold, on the other hand, works together with PipeOpLearnerCV: it optimizes the thresholds directly on the cross-validated predictions produced by PipeOpLearnerCV.

A simple example would be:

gr = po("learner_cv", lrn("classif.rpart", predict_type = "prob")) %>>%
  po("tunethreshold")
l2 = GraphLearner$new(gr)

Note that predict_type = "prob" is required for po("tunethreshold") to have any effect. Additionally, no threshold parameter is exposed this time, as the thresholds are tuned internally:

as.data.table(l2$param_set)[, .(id, class, lower, upper, nlevels)]
                                        id    class lower upper nlevels
                                    <char>   <char> <num> <num>   <num>
 1:        classif.rpart.resampling.method ParamFct    NA    NA       2
 2:         classif.rpart.resampling.folds ParamInt     2   Inf     Inf
 3: classif.rpart.resampling.keep_response ParamLgl    NA    NA       2
 4:                       classif.rpart.cp ParamDbl     0     1     Inf
 5:               classif.rpart.keep_model ParamLgl    NA    NA       2
 6:               classif.rpart.maxcompete ParamInt     0   Inf     Inf
 7:                 classif.rpart.maxdepth ParamInt     1    30      30
 8:             classif.rpart.maxsurrogate ParamInt     0   Inf     Inf
 9:                classif.rpart.minbucket ParamInt     1   Inf     Inf
10:                 classif.rpart.minsplit ParamInt     1   Inf     Inf
11:           classif.rpart.surrogatestyle ParamInt     0     1       2
12:             classif.rpart.usesurrogate ParamInt     0     2       3
13:                     classif.rpart.xval ParamInt     0   Inf     Inf
14:           classif.rpart.affect_columns ParamUty    NA    NA     Inf
15:                  tunethreshold.measure ParamUty    NA    NA     Inf
16:                tunethreshold.optimizer ParamUty    NA    NA     Inf
17:                tunethreshold.log_level ParamUty    NA    NA     Inf

When the GraphLearner is trained, PipeOpTuneThreshold tunes the thresholds on the cross-validated predictions, and the tuned thresholds are then applied automatically at prediction time.

Note that we can set ResamplingInsample as the resampling strategy of PipeOpLearnerCV, so that the thresholds are tuned on predictions for the “training” data itself. This can significantly reduce runtime, but is generally not advised, as it might lead to over-fitting of the thresholds.
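For contrast with the refit-per-threshold loop, the following base-R sketch follows the PipeOpTuneThreshold idea: fit the model once per fold to obtain out-of-fold probabilities, optimize the threshold on those, and fit one final model on all data (as we understand it, PipeOpLearnerCV also fits the learner on the full data). A logistic regression on simulated data stands in for the real learner and task, and a plain grid search replaces the optimizer PipeOpTuneThreshold uses internally. The insample = TRUE branch mimics the ResamplingInsample shortcut: a single fit, at the price of tuning on predictions for data the model has already seen. oof_probs is a hypothetical helper.

```r
set.seed(1)

# Simulated binary data: y01 = 1 marks the positive class (illustrative only)
n = 300
x = rnorm(n)
y01 = as.integer(runif(n) < plogis(-0.5 + 1.5 * x))
df = data.frame(x = x, y01 = y01)

n_fits = 0  # counts model fits

# Hypothetical helper: out-of-fold probabilities from k-fold CV,
# or in-sample probabilities from a single fit if insample = TRUE
oof_probs = function(df, k = 3, insample = FALSE) {
  if (insample) {
    fit = glm(y01 ~ x, data = df, family = binomial())
    n_fits <<- n_fits + 1
    return(predict(fit, type = "response"))
  }
  folds = sample(rep(seq_len(k), length.out = nrow(df)))
  p = numeric(nrow(df))
  for (i in seq_len(k)) {
    fit = glm(y01 ~ x, data = df[folds != i, ], family = binomial())
    n_fits <<- n_fits + 1
    p[folds == i] = predict(fit, newdata = df[folds == i, ], type = "response")
  }
  p
}

# Fit once per fold, then tune the threshold on the stored predictions
p = oof_probs(df)
ce_at = function(t) mean(as.integer(p > t) != df$y01)
grid = seq(0.01, 0.99, by = 0.01)
best = grid[which.min(sapply(grid, ce_at))]

# Final model on all data, used with `best` to predict new observations
final_fit = glm(y01 ~ x, data = df, family = binomial())
n_fits = n_fits + 1

cat("best threshold:", best, "| model fits:", n_fits, "\n")
```

Here the 99 candidate thresholds cost only 4 model fits, because every candidate is scored on the same stored predictions.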

Comparison of the approaches

Finally, we compare no threshold tuning with the tunethreshold approach:

bmr = benchmark(benchmark_grid(
  learners = list(no_tuning = lrn("classif.rpart"), internal = l2),
  tasks = tsk("german_credit"),
  rsmp("cv", folds = 3L)
))
bmr$aggregate(list(msr("classif.ce"), msr("classif.fnr")))
      nr       task_id                  learner_id resampling_id iters classif.ce classif.fnr
   <int>        <char>                      <char>        <char> <int>      <num>       <num>
1:     1 german_credit               classif.rpart            cv     3  0.2760095  0.12723983
2:     2 german_credit classif.rpart.tunethreshold            cv     3  0.2879916  0.04485325
Hidden columns: resample_result

Session Information

sessioninfo::session_info(info = "packages")
═ Session info ═══════════════════════════════════════════════════════════════════════════════════════════════════════
─ Packages ───────────────────────────────────────────────────────────────────────────────────────────────────────────
 ! package        * version    date (UTC) lib source
   backports        1.5.0      2024-05-23 [1] CRAN (R 4.4.1)
   bbotk            1.1.1      2024-10-15 [1] CRAN (R 4.4.1)
   checkmate        2.3.2      2024-07-29 [1] CRAN (R 4.4.1)
   cli              3.6.3      2024-06-21 [1] CRAN (R 4.4.1)
 P codetools        0.2-20     2024-03-31 [?] CRAN (R 4.4.0)
   crayon           1.5.3      2024-06-20 [1] CRAN (R 4.4.1)
   data.table     * 1.16.2     2024-10-10 [1] CRAN (R 4.4.1)
   digest           0.6.37     2024-08-19 [1] CRAN (R 4.4.1)
   evaluate         1.0.1      2024-10-10 [1] CRAN (R 4.4.1)
   fastmap          1.2.0      2024-05-15 [1] CRAN (R 4.4.1)
   future           1.34.0     2024-07-29 [1] CRAN (R 4.4.1)
   future.apply     1.11.2     2024-03-28 [1] CRAN (R 4.4.1)
   GenSA            1.1.14.1   2024-09-21 [1] CRAN (R 4.4.1)
   globals          0.16.3     2024-03-08 [1] CRAN (R 4.4.1)
   htmltools        0.5.8.1    2024-04-04 [1] CRAN (R 4.4.1)
   htmlwidgets      1.6.4      2023-12-06 [1] CRAN (R 4.4.1)
   jsonlite         1.8.9      2024-09-20 [1] CRAN (R 4.4.1)
   knitr            1.48       2024-07-07 [1] CRAN (R 4.4.1)
   lgr              0.4.4      2022-09-05 [1] CRAN (R 4.4.1)
   listenv          0.9.1      2024-01-29 [1] CRAN (R 4.4.1)
   mlr3           * 0.21.1     2024-10-18 [1] CRAN (R 4.4.1)
   mlr3measures     1.0.0      2024-09-11 [1] CRAN (R 4.4.1)
   mlr3misc         0.15.1     2024-06-24 [1] CRAN (R 4.4.1)
   mlr3pipelines  * 0.7.0      2024-09-24 [1] CRAN (R 4.4.1)
   mlr3tuning     * 1.0.2      2024-10-14 [1] CRAN (R 4.4.1)
   mlr3website    * 0.0.0.9000 2024-10-18 [1] Github (mlr-org/mlr3website@20d1ddf)
   palmerpenguins   0.1.1      2022-08-15 [1] CRAN (R 4.4.1)
   paradox        * 1.0.1      2024-07-09 [1] CRAN (R 4.4.1)
   parallelly       1.38.0     2024-07-27 [1] CRAN (R 4.4.1)
   R6               2.5.1      2021-08-19 [1] CRAN (R 4.4.1)
   renv             1.0.11     2024-10-12 [1] CRAN (R 4.4.1)
   rlang            1.1.4      2024-06-04 [1] CRAN (R 4.4.1)
   rmarkdown        2.28       2024-08-17 [1] CRAN (R 4.4.1)
 P rpart            4.1.23     2023-12-05 [?] CRAN (R 4.4.0)
   sessioninfo      1.2.2      2021-12-06 [1] CRAN (R 4.4.1)
   stringi          1.8.4      2024-05-06 [1] CRAN (R 4.4.1)
   uuid             1.2-1      2024-07-29 [1] CRAN (R 4.4.1)
   withr            3.0.1      2024-07-31 [1] CRAN (R 4.4.1)
   xfun             0.48       2024-10-03 [1] CRAN (R 4.4.1)
   yaml             2.3.10     2024-07-26 [1] CRAN (R 4.4.1)

 [1] /home/marc/repositories/mlr3website/mlr-org/renv/library/linux-ubuntu-noble/R-4.4/x86_64-pc-linux-gnu
 [2] /home/marc/.cache/R/renv/sandbox/linux-ubuntu-noble/R-4.4/x86_64-pc-linux-gnu/9a444a72

 P ── Loaded and on-disk path mismatch.

──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────