
Announcing mlr3spatial

We are happy to announce that mlr3spatial was released on CRAN in November 2021. mlr3spatial simplifies the handling of spatial objects in the mlr3 ecosystem. Before mlr3spatial, users had to extract tabular data from spatial objects before they could train a model or predict on spatial data.

Now, mlr3 tasks can read directly from spatial objects via specialized data backends. Such tasks can be used to train a model or to perform resampling, just like any other mlr3 task. Spatial raster objects created by the terra, raster and stars packages are supported via DataBackendRaster, and vector data created with the sf package is handled by DataBackendVector.
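As a quick illustration of the supported raster inputs, the sketch below wraps the same Landsat 7 example file (shipped with stars) after reading it with terra, raster and stars. It assumes all three packages are installed and that as_data_backend() has methods for SpatRaster, RasterBrick and stars objects; the raster::brick() route in particular is our assumption rather than something shown in this post.

```r
library("mlr3")
library("mlr3spatial")

tif = system.file("tif/L7_ETMs.tif", package = "stars")

# The same file read by three different spatial packages
backend_terra = as_data_backend(terra::rast(tif))
backend_raster = as_data_backend(raster::brick(tif))  # assumption: RasterBrick method exists
backend_stars = as_data_backend(stars::read_stars(tif))

# Print the first class of each backend; all are expected to be DataBackendRaster
sapply(list(backend_terra, backend_raster, backend_stars), function(b) class(b)[1])
```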

The predict_spatial() function creates spatial rasters and features with predictions from mlr3 learners. We only have to pass a task with a spatial data backend, which provides both the data and the spatial reference. To avoid memory issues with large raster files, prediction is done in chunks: the raster map is divided into multiple horizontal strips, and the vertical extent of these strips is controlled by the chunksize parameter (in megabytes). The actual memory usage per core is a multiple of the specified chunk size. We chose a default chunk size of 200 megabytes, which should work on most consumer-grade machines. If more memory is available, a larger chunk size speeds up the prediction.
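The strip arithmetic can be sketched in base R. The helpers rows_per_chunk() and chunk_strips() are hypothetical and not part of mlr3spatial; they assume that one chunk of chunksize megabytes (taken as 10^6 bytes) must hold complete raster rows of all layers, each value stored as an 8-byte double. mlr3spatial's internal calculation may differ.

```r
# Hypothetical helper: how many complete raster rows fit into one chunk,
# assuming every value of every layer is an 8-byte double.
rows_per_chunk = function(chunksize_mb, ncol, nlayers, bytes_per_value = 8) {
  bytes_per_row = ncol * nlayers * bytes_per_value
  max(1, floor(chunksize_mb * 1e6 / bytes_per_row))
}

# Hypothetical helper: split `nrow` raster rows into horizontal strips
# of at most `rows` rows each (start row and number of rows per strip).
chunk_strips = function(nrow, rows) {
  starts = seq(1, nrow, by = rows)
  data.frame(start = starts, nrows = pmin(rows, nrow - starts + 1))
}

# Landsat 7 example: 349 columns, 352 rows, 6 layers
rows_per_chunk(200, ncol = 349, nlayers = 6)  # 11938 rows, more than the 352 available
chunk_strips(352, rows_per_chunk(200, ncol = 349, nlayers = 6))  # a single strip
chunk_strips(352, rows_per_chunk(1, ncol = 349, nlayers = 6))    # six strips of at most 59 rows
```

Under these assumptions one image row needs 349 x 6 x 8 = 16,752 bytes, so the default 200 MB chunk covers the whole example image in one strip, which agrees with the "1 chunk(s)" reported in the prediction log further down.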

The raster chunks are loaded into memory one after the other, and each prediction is written to disk. Once all chunks are processed, the complete raster is available on disk. The learner can also make use of future-based parallelization to accelerate the predictions. The vignette "Benchmarking parallel predictions" showcases the parallelization capabilities of mlr3spatial.
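The load-predict-write loop can be sketched directly with terra's block-wise I/O functions (readStart(), readValues(), writeStart(), writeValues(), writeStop()). This is not mlr3spatial's internal code: we swap in a plain lm() model instead of an mlr3 learner and let writeStart() choose the strip layout, so treat it as an illustration of the idea under those assumptions.

```r
library("terra")

# Input: the 6-layer Landsat 7 example image shipped with {stars}
r = rast(system.file("tif/L7_ETMs.tif", package = "stars"))
names(r) = paste0("layer.", 1:6)

# Fit a simple linear model on 500 random cells (layer.1 as response)
set.seed(42)
train = spatSample(r, 500, method = "random", na.rm = TRUE, as.df = TRUE)
model = lm(layer.1 ~ ., data = train)

# Single-layer output raster with the same geometry, written to a temp file
out = rast(r, nlyrs = 1)
names(out) = "prediction"
filename = tempfile(fileext = ".tif")

readStart(r)
# writeStart() decides how many strips (info$n) are needed and where they start
info = writeStart(out, filename, overwrite = TRUE)
for (i in seq_len(info$n)) {
  # Load one horizontal strip of rows into memory as a data.frame ...
  vals = readValues(r, row = info$row[i], nrows = info$nrows[i], dataframe = TRUE)
  # ... predict it and write the result straight to disk
  writeValues(out, predict(model, newdata = vals), info$row[i], info$nrows[i])
}
out = writeStop(out)
readStop(r)
```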

Use Case - Landsat 7 data as a {stars} object

Data Preparation

library("mlr3")
library("mlr3spatial")

First, the TIFF file is read via stars::read_stars() and put into a DataBackendRaster. The DataBackend is then used to create a regression task with layer.1 as the response.

tif = system.file("tif/L7_ETMs.tif", package = "stars")
stack = stars::read_stars(tif)

backend = as_data_backend(stack)
task = as_task_regr(backend, target = "layer.1")

print(task)
## <TaskRegr:backend> (122848 x 6)
## * Target: layer.1
## * Properties: -
## * Features (5):
##   - dbl (5): layer.2, layer.3, layer.4, layer.5, layer.6

For large raster files with millions of values, predicting in parallel helps. To enable this, set learner$parallel_predict = TRUE and initiate a parallel plan via {future}, e.g. plan("multisession"). Since this is only an example, parallelization is not enabled here. As an example learner we use a simple regression tree. In practice you might want to use a different learner; you can find an overview of available learners here.
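A minimal sketch of enabling parallel prediction as described above, assuming the future package is installed. The choice of two workers is arbitrary, and the task and learner from this example are rebuilt so that the snippet runs on its own.

```r
library("mlr3")
library("mlr3spatial")
library("future")

# Rebuild the example task and learner
tif = system.file("tif/L7_ETMs.tif", package = "stars")
task = as_task_regr(as_data_backend(stars::read_stars(tif)), target = "layer.1")
learner = lrn("regr.rpart")
set.seed(42)
learner$train(task, row_ids = sample(task$row_ids, 500))

# Opt in to parallel prediction and start two background R sessions
learner$parallel_predict = TRUE
plan("multisession", workers = 2)

ras = predict_spatial(task, learner, format = "stars")

# Return to sequential processing afterwards
plan("sequential")
```

For a raster as small as this one, the overhead of starting the background sessions will likely outweigh any speed-up; the benefit is expected only for large rasters.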

learner = lrn("regr.rpart")
set.seed(42)
row_ids = sample(1:task$nrow, 500)
learner$train(task, row_ids = row_ids)

print(learner)
## <LearnerRegrRpart:regr.rpart>
## * Model: rpart
## * Parameters: xval=0
## * Packages: mlr3, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, selected_features, weights

Prediction

For prediction, predict_spatial() is used. It returns a raster object containing the predictions. Users can choose which R spatial format the returned raster should have. Note that by default mlr3spatial operates on SpatRaster objects from the terra package. If a different output format is requested (e.g. "stars"), the result is coerced in the background, which may take some time.
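To see what the coercion costs, one can time a prediction in the default terra format against one returned as stars. This is a sketch that rebuilds the example task and learner; the timings depend on the machine, and for a raster this small the difference may be negligible.

```r
library("mlr3")
library("mlr3spatial")

tif = system.file("tif/L7_ETMs.tif", package = "stars")
task = as_task_regr(as_data_backend(stars::read_stars(tif)), target = "layer.1")
learner = lrn("regr.rpart")
set.seed(42)
learner$train(task, row_ids = sample(task$row_ids, 500))

# Default output: a terra SpatRaster, no coercion needed
t_terra = system.time({ ras_terra = predict_spatial(task, learner) })
# Requesting "stars" adds a conversion step after the prediction
t_stars = system.time({ ras_stars = predict_spatial(task, learner, format = "stars") })

# Elapsed wall-clock seconds for both variants
c(terra = t_terra[["elapsed"]], stars = t_stars[["elapsed"]])
```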

ras = predict_spatial(task, learner, format = "stars")
## INFO  [15:50:07.531] Start raster prediction 
## INFO  [15:50:07.548] Prediction is executed with a chunksize of 200, 1 chunk(s) in total, 122848 values per chunk 
## INFO  [15:50:07.705] Chunk 1 of 1 finished 
## INFO  [15:50:07.716] Finished raster prediction in 0 seconds
names(ras) = "layer.1"

print(ras)
## stars object with 2 dimensions and 1 attribute
## attribute(s):
##             Min.  1st Qu.   Median     Mean 3rd Qu.     Max.
## layer.1  62.3629 70.30233 77.01695 79.05135 89.2809 118.1429
## dimension(s):
##   from  to  offset delta                     refsys point values x/y
## x    1 349  288776  28.5 SIRGAS 2000 / UTM zone 25S FALSE   NULL [x]
## y    1 352 9120761 -28.5 SIRGAS 2000 / UTM zone 25S FALSE   NULL [y]

Visualization

Finally, we can plot the predictions. The color vector was generated from the “viridis” color palette via dput(viridis::viridis_pal()(7)) and is passed to plot(), which in this case dispatches to the plot method for stars objects provided by the stars package.

plot(ras, col = c("#440154FF", "#443A83FF", "#31688EFF", "#21908CFF", "#35B779FF", "#8FD744FF", "#FDE725FF"))
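Instead of hard-coding the hex values, the palette can also be generated at call time. The sketch below assumes the viridis package is installed and, so that it runs on its own, applies the palette to the first band of the input image rather than to the predictions.

```r
library("stars")

# The same seven viridis colours as in the hard-coded vector above
pal = viridis::viridis(7)

# Plot the first band of the example image with this palette
tif = system.file("tif/L7_ETMs.tif", package = "stars")
band1 = read_stars(tif)[, , , 1]
plot(band1, col = pal)
```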