Tunes optimal probability thresholds over different PredictionClassifs.

The underlying mlr3::Learner must have predict_type "prob". Thresholds for each learner are optimized using the Optimizer supplied via the param_set, which defaults to GenSA. Returns a single PredictionClassif.

This PipeOp should be used in conjunction with PipeOpLearnerCV in order to optimize thresholds of cross-validated predictions. In order to optimize thresholds without cross-validation, use PipeOpLearnerCV in conjunction with ResamplingInsample.
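
As a minimal sketch of the non-cross-validated setup described above, the following assumes that PipeOpLearnerCV exposes a resampling.method hyperparameter that accepts "insample", and that the bbotk and GenSA packages are installed for the default optimizer. The names learner, graph and th are illustrative only.

```r
library("mlr3")
library("mlr3pipelines")

task = tsk("iris")
learner = lrn("classif.rpart", predict_type = "prob")

# Assumption: resampling.method = "insample" makes learner_cv predict on the
# same data it was trained on, so no cross-validation takes place.
graph = po("learner_cv", learner, resampling.method = "insample") %>>%
  po("tunethreshold")

graph$train(task)

# One tuned threshold per class
th = graph$state$tunethreshold$threshold
print(th)
```

Thresholds tuned on in-sample predictions tend to be optimistic, which is why the cross-validated default is usually the safer choice.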

Format

R6Class object inheriting from PipeOp.

Construction

PipeOpTuneThreshold$new(id = "tunethreshold", param_vals = list())
(character(1), list) -> self

  • id :: character(1)
    Identifier of resulting object. Default: "tunethreshold".

  • param_vals :: named list
    List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default list().

Input and Output Channels

Input and output channels are inherited from PipeOp.

State

The $state is a named list with elements

  • threshold :: numeric learned thresholds

Parameters

The parameters are the parameters inherited from PipeOp, as well as:

  • measure :: Measure | character
    Measure to optimize for. Will be converted to a Measure in case it is character. Initialized to "classif.ce", i.e. misclassification error.

  • optimizer :: Optimizer|character(1)
    Optimizer used to find optimal thresholds. If given as character, it is converted to an Optimizer via opt(). Initialized to OptimizerGenSA.

  • log_level :: character(1) | integer(1)
    Set a temporary log-level for lgr::get_logger("bbotk"). Initialized to: "warn".
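
The parameters listed above can be set at construction through param_vals. The sketch below is illustrative rather than a recommended configuration: it swaps in accuracy as the measure and a random search optimizer obtained with bbotk::opt("random_search"), which assumes the bbotk package is installed. The name po_tune is hypothetical.

```r
library("mlr3")
library("mlr3pipelines")

po_tune = po("tunethreshold", param_vals = list(
  measure = msr("classif.acc"),              # optimize accuracy instead of classif.ce
  optimizer = bbotk::opt("random_search"),   # assumption: bbotk is installed
  log_level = "error"                        # only show bbotk errors while tuning
))

# Inspect the values that were set
print(po_tune$param_set$values$log_level)
```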

Internals

Uses the optimizer set in the optimizer parameter to find an optimal threshold. See the optimizer parameter for more information.
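
To illustrate how the tuned thresholds are used, the following sketch trains a small graph and then predicts with it. It assumes the default GenSA optimizer is available (bbotk and GenSA installed). The variable names graph, thresholds and pred are illustrative.

```r
library("mlr3")
library("mlr3pipelines")

task = tsk("iris")
graph = po("learner_cv", lrn("classif.rpart", predict_type = "prob")) %>>%
  po("tunethreshold")

# Training runs the optimizer and stores the result in the PipeOp's state
graph$train(task)
thresholds = graph$state$tunethreshold$threshold

# Predicting returns a list with a single PredictionClassif
pred = graph$predict(task)[[1]]
print(pred$score(msr("classif.ce")))
```

Predicting on the training task, as done here, only demonstrates the mechanics. An honest performance estimate would need held-out data or an outer resampling.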

Methods

Only methods inherited from PipeOp.

See also

https://mlr-org.com/pipeops.html

Other PipeOps: PipeOpEnsemble, PipeOpImpute, PipeOpTargetTrafo, PipeOpTaskPreprocSimple, PipeOpTaskPreproc, PipeOp, mlr_pipeops_boxcox, mlr_pipeops_branch, mlr_pipeops_chunk, mlr_pipeops_classbalancing, mlr_pipeops_classifavg, mlr_pipeops_classweights, mlr_pipeops_colapply, mlr_pipeops_collapsefactors, mlr_pipeops_colroles, mlr_pipeops_copy, mlr_pipeops_datefeatures, mlr_pipeops_encodeimpact, mlr_pipeops_encodelmer, mlr_pipeops_encode, mlr_pipeops_featureunion, mlr_pipeops_filter, mlr_pipeops_fixfactors, mlr_pipeops_histbin, mlr_pipeops_ica, mlr_pipeops_imputeconstant, mlr_pipeops_imputehist, mlr_pipeops_imputelearner, mlr_pipeops_imputemean, mlr_pipeops_imputemedian, mlr_pipeops_imputemode, mlr_pipeops_imputeoor, mlr_pipeops_imputesample, mlr_pipeops_kernelpca, mlr_pipeops_learner, mlr_pipeops_missind, mlr_pipeops_modelmatrix, mlr_pipeops_multiplicityexply, mlr_pipeops_multiplicityimply, mlr_pipeops_mutate, mlr_pipeops_nmf, mlr_pipeops_nop, mlr_pipeops_ovrsplit, mlr_pipeops_ovrunite, mlr_pipeops_pca, mlr_pipeops_proxy, mlr_pipeops_quantilebin, mlr_pipeops_randomprojection, mlr_pipeops_randomresponse, mlr_pipeops_regravg, mlr_pipeops_removeconstants, mlr_pipeops_renamecolumns, mlr_pipeops_replicate, mlr_pipeops_scalemaxabs, mlr_pipeops_scalerange, mlr_pipeops_scale, mlr_pipeops_select, mlr_pipeops_smote, mlr_pipeops_spatialsign, mlr_pipeops_subsample, mlr_pipeops_targetinvert, mlr_pipeops_targetmutate, mlr_pipeops_targettrafoscalerange, mlr_pipeops_textvectorizer, mlr_pipeops_threshold, mlr_pipeops_unbranch, mlr_pipeops_updatetarget, mlr_pipeops_vtreat, mlr_pipeops_yeojohnson, mlr_pipeops

Examples

library("mlr3")
library("mlr3pipelines")

task = tsk("iris")
pop = po("learner_cv", lrn("classif.rpart", predict_type = "prob")) %>>%
  po("tunethreshold")

task$data()
#>        Species Petal.Length Petal.Width Sepal.Length Sepal.Width
#>   1:    setosa          1.4         0.2          5.1         3.5
#>   2:    setosa          1.4         0.2          4.9         3.0
#>   3:    setosa          1.3         0.2          4.7         3.2
#>   4:    setosa          1.5         0.2          4.6         3.1
#>   5:    setosa          1.4         0.2          5.0         3.6
#>  ---                                                            
#> 146: virginica          5.2         2.3          6.7         3.0
#> 147: virginica          5.0         1.9          6.3         2.5
#> 148: virginica          5.2         2.0          6.5         3.0
#> 149: virginica          5.4         2.3          6.2         3.4
#> 150: virginica          5.1         1.8          5.9         3.0
pop$train(task)
#> $tunethreshold.output
#> NULL
#> 

pop$state
#> $classif.rpart
#> $classif.rpart$model
#> n= 150 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
#>   2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *
#>   3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)  
#>     6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *
#>     7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *
#> 
#> $classif.rpart$log
#> Empty data.table (0 rows and 3 cols): stage,class,msg
#> 
#> $classif.rpart$train_time
#> [1] 0.005
#> 
#> $classif.rpart$param_vals
#> $classif.rpart$param_vals$xval
#> [1] 0
#> 
#> 
#> $classif.rpart$task_hash
#> [1] "b39ef23a66b1f1ee"
#> 
#> $classif.rpart$data_prototype
#> Empty data.table (0 rows and 5 cols): Species,Petal.Length,Petal.Width,Sepal.Length,Sepal.Width
#> 
#> $classif.rpart$task_prototype
#> Empty data.table (0 rows and 5 cols): Species,Petal.Length,Petal.Width,Sepal.Length,Sepal.Width
#> 
#> $classif.rpart$mlr3_version
#> [1] ‘0.16.0’
#> 
#> $classif.rpart$train_task
#> <TaskClassif:iris> (150 x 5): Iris Flowers
#> * Target: Species
#> * Properties: multiclass
#> * Features (4):
#>   - dbl (4): Petal.Length, Petal.Width, Sepal.Length, Sepal.Width
#> 
#> $classif.rpart$affected_cols
#> [1] "Petal.Length" "Petal.Width"  "Sepal.Length" "Sepal.Width" 
#> 
#> $classif.rpart$intasklayout
#>              id    type
#> 1: Petal.Length numeric
#> 2:  Petal.Width numeric
#> 3: Sepal.Length numeric
#> 4:  Sepal.Width numeric
#> 
#> $classif.rpart$outtasklayout
#>                               id    type
#> 1:     classif.rpart.prob.setosa numeric
#> 2: classif.rpart.prob.versicolor numeric
#> 3:  classif.rpart.prob.virginica numeric
#> 
#> $classif.rpart$outtaskshell
#> Empty data.table (0 rows and 4 cols): Species,classif.rpart.prob.setosa,classif.rpart.prob.versicolor,classif.rpart.prob.virginica
#> 
#> 
#> $tunethreshold
#> $tunethreshold$threshold
#>     setosa versicolor  virginica 
#>  0.5134816  0.8486673  0.2710435 
#> 
#>