Learners

To keep the dependencies on other packages reasonable, the base package mlr3 ships only with regression and classification trees from the rpart package and some learners for debugging. A subjective selection of implementations for essential ML algorithms can be found in the mlr3learners package. Survival learners are provided by mlr3proba, and cluster learners by mlr3cluster. Additional learners, including some that are not yet considered stable or that are not available on CRAN, are connected via the mlr3extralearners package. For neural networks, see the mlr3torch extension.
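As a minimal sketch of how to see which learners are available in a session, the learner dictionary mlr_learners can be queried once the extension packages are loaded. This assumes mlr3 and mlr3learners are installed; the exact set of keys depends on the installed package versions, so no output is shown.

```r
# Assumption: mlr3 and mlr3learners are installed (e.g. from CRAN).
library(mlr3)
library(mlr3learners)  # loading the package registers its learners in mlr_learners

# keys of all learners currently registered in the dictionary
keys = mlr_learners$keys()

# restrict to classification learners with a regular expression
classif_keys = mlr_learners$keys("^classif\\.")
head(classif_keys)

# tabular overview, including task type and required packages
learner_table = as.data.table(mlr_learners)
head(learner_table)
```

Any key from this list can be passed to lrn() to construct the learner, provided the packages it requires are installed.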

Example Usage

Fit a classification tree on the Wisconsin Breast Cancer Data Set and predict on held-out observations.

library("mlr3verse")
# retrieve the task
task = tsk("breast_cancer")

# split row ids into a training and a test set
split = partition(task)

# retrieve a learner
learner = lrn("classif.rpart", keep_model = TRUE, predict_type = "prob")

# fit decision tree
learner$train(task, split$train)

# access learned model
learner$model
n= 458 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 458 161 benign (0.351528384 0.648471616)  
   2) bare_nuclei=3,4,5,6,7,8,9,10 171  23 malignant (0.865497076 0.134502924)  
     4) cell_shape=3,4,5,6,7,8,9,10 149   7 malignant (0.953020134 0.046979866) *
     5) cell_shape=1,2 22   6 benign (0.272727273 0.727272727)  
      10) cl_thickness=4,5,6,7,8,9,10 8   2 malignant (0.750000000 0.250000000) *
      11) cl_thickness=1,2,3 14   0 benign (0.000000000 1.000000000) *
   3) bare_nuclei=1,2 287  13 benign (0.045296167 0.954703833)  
     6) cell_size=4,5,6,7,8,9,10 14   2 malignant (0.857142857 0.142857143) *
     7) cell_size=1,2,3 273   1 benign (0.003663004 0.996336996) *
# predict on a data frame with new data
predictions = learner$predict_newdata(task$data(split$test))

# predict on a subset of the task
predictions = learner$predict(task, split$test)

# inspect predictions
predictions

── <PredictionClassif> for 225 observations: ───────────────────────────────────
 row_ids     truth  response prob.malignant prob.benign
       5    benign    benign    0.003663004  0.99633700
       9    benign    benign    0.003663004  0.99633700
      14    benign    benign    0.000000000  1.00000000
     ---       ---       ---            ---         ---
     671    benign    benign    0.003663004  0.99633700
     677    benign    benign    0.003663004  0.99633700
     681 malignant malignant    0.953020134  0.04697987
predictions$score(msr("classif.auc"))
classif.auc 
  0.9780656 
autoplot(predictions, type = "roc")