Tunes optimal probability thresholds over different `PredictionClassif`s.

The `mlr3::Learner` must have `predict_type` set to `"prob"`. Thresholds for each learner are optimized using the `Optimizer` supplied via the `param_set`. The default optimizer is GenSA. Returns a single `PredictionClassif`.
This PipeOp should be used in conjunction with `PipeOpLearnerCV` in order to optimize thresholds of cross-validated predictions. To optimize thresholds without cross-validation, use `PipeOpLearnerCV` in conjunction with `ResamplingInsample`.
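The in-sample variant described above might look like the following minimal sketch. It assumes that `PipeOpLearnerCV` exposes a `resampling.method` hyperparameter accepting `"insample"`, and that the bbotk and GenSA packages needed by the default optimizer are installed.

```r
library("mlr3")
library("mlr3pipelines")

task = tsk("iris")

# Assumption: resampling.method = "insample" makes learner_cv predict on its
# own training data, so thresholds are tuned on in-sample probabilities
# instead of cross-validated ones.
graph = po("learner_cv", lrn("classif.rpart", predict_type = "prob"),
  resampling.method = "insample") %>>%
  po("tunethreshold")

graph$train(task)

# Prediction applies the thresholds learned during training.
prediction = graph$predict(task)[[1]]
print(prediction)
```

In-sample probabilities tend to be optimistic, so thresholds tuned this way may fit the training data more closely than cross-validated ones.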
`R6Class` object inheriting from `PipeOp`.
* `PipeOpTuneThreshold$new(id = "tunethreshold", param_vals = list())` (`character(1)`, `list`) -> `self`
* `id` :: `character(1)`
  Identifier of the resulting object. Default: `"tunethreshold"`.
* `param_vals` :: named `list`
  List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default: `list()`.
Input and output channels are inherited from `PipeOp`.
The `$state` is a named `list` with elements:

* `threshold` :: `numeric`
  Learned thresholds, one per class label.
The parameters are the parameters inherited from `PipeOp`, as well as:

* `measure` :: `Measure` | `character`
  `Measure` to optimize for. A `character` value is converted to a `Measure`. Initialized to `"classif.ce"`, i.e. misclassification error.
* `optimizer` :: `Optimizer` | `character(1)`
  `Optimizer` used to find optimal thresholds. A `character` value is converted to an `Optimizer` via `opt`. Initialized to `OptimizerGenSA`.
* `log_level` :: `character(1)` | `integer(1)`
  Temporary log level set for `lgr::get_logger("bbotk")`. Initialized to `"warn"`.
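As a small illustration of the parameters listed above, the sketch below sets `measure` and `optimizer` as character values at construction. The keys `"classif.acc"` and `"random_search"` are assumptions chosen for illustration; they should exist in the mlr3 measure dictionary and the bbotk optimizer dictionary respectively, but any valid key could be used instead.

```r
library("mlr3")
library("mlr3pipelines")

# Character values are accepted for both parameters; according to the
# parameter descriptions they are converted to Measure / Optimizer objects.
po_thresh = po("tunethreshold", param_vals = list(
  measure = "classif.acc",      # assumed key: classification accuracy
  optimizer = "random_search"   # assumed key: bbotk random search
))

# Inspect which hyperparameters are currently set.
set_params = names(po_thresh$param_set$values)
print(set_params)
```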
Uses the `optimizer` provided as a `param_val` in order to find an optimal threshold. See the `optimizer` parameter for more info.
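To make the objective being optimized concrete, here is a base-R sketch of how a threshold vector can turn class probabilities into labels. It assumes the rule used by `PredictionClassif$set_threshold()` in mlr3, where each class probability is divided by its threshold and the class with the largest ratio is predicted. The helper names `apply_thresholds` and `ce` are hypothetical and not part of mlr3pipelines.

```r
# Hypothetical helper: divide each class column by its threshold and
# predict the class with the largest ratio.
apply_thresholds = function(prob, thresholds) {
  ratio = sweep(prob, 2, thresholds[colnames(prob)], "/")
  colnames(prob)[max.col(ratio, ties.method = "first")]
}

# Toy probabilities for two observations and two classes.
prob = matrix(c(0.6, 0.4,
                0.3, 0.7),
  ncol = 2, byrow = TRUE, dimnames = list(NULL, c("a", "b")))
truth = c("b", "b")

# Hypothetical objective: misclassification error for a threshold vector.
ce = function(thresholds) mean(apply_thresholds(prob, thresholds) != truth)

ce_equal = ce(c(a = 0.5, b = 0.5))   # first observation is labelled "a"
ce_shifted = ce(c(a = 0.8, b = 0.2)) # both observations are labelled "b"
print(c(equal = ce_equal, shifted = ce_shifted))
```

An optimizer searches over such threshold vectors for the one that gives the best value of the chosen measure.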
Only methods inherited from `PipeOp`.
Other PipeOps: PipeOpEnsemble, PipeOpImpute, PipeOpTargetTrafo, PipeOpTaskPreprocSimple, PipeOpTaskPreproc, PipeOp, mlr_pipeops_boxcox, mlr_pipeops_branch, mlr_pipeops_chunk, mlr_pipeops_classbalancing, mlr_pipeops_classifavg, mlr_pipeops_classweights, mlr_pipeops_colapply, mlr_pipeops_collapsefactors, mlr_pipeops_colroles, mlr_pipeops_copy, mlr_pipeops_datefeatures, mlr_pipeops_encodeimpact, mlr_pipeops_encodelmer, mlr_pipeops_encode, mlr_pipeops_featureunion, mlr_pipeops_filter, mlr_pipeops_fixfactors, mlr_pipeops_histbin, mlr_pipeops_ica, mlr_pipeops_imputeconstant, mlr_pipeops_imputehist, mlr_pipeops_imputelearner, mlr_pipeops_imputemean, mlr_pipeops_imputemedian, mlr_pipeops_imputemode, mlr_pipeops_imputeoor, mlr_pipeops_imputesample, mlr_pipeops_kernelpca, mlr_pipeops_learner, mlr_pipeops_missind, mlr_pipeops_modelmatrix, mlr_pipeops_multiplicityexply, mlr_pipeops_multiplicityimply, mlr_pipeops_mutate, mlr_pipeops_nmf, mlr_pipeops_nop, mlr_pipeops_ovrsplit, mlr_pipeops_ovrunite, mlr_pipeops_pca, mlr_pipeops_proxy, mlr_pipeops_quantilebin, mlr_pipeops_randomprojection, mlr_pipeops_randomresponse, mlr_pipeops_regravg, mlr_pipeops_removeconstants, mlr_pipeops_renamecolumns, mlr_pipeops_replicate, mlr_pipeops_scalemaxabs, mlr_pipeops_scalerange, mlr_pipeops_scale, mlr_pipeops_select, mlr_pipeops_smote, mlr_pipeops_spatialsign, mlr_pipeops_subsample, mlr_pipeops_targetinvert, mlr_pipeops_targetmutate, mlr_pipeops_targettrafoscalerange, mlr_pipeops_textvectorizer, mlr_pipeops_threshold, mlr_pipeops_unbranch, mlr_pipeops_updatetarget, mlr_pipeops_vtreat, mlr_pipeops_yeojohnson, mlr_pipeops
    library("mlr3")
    library("mlr3pipelines")  # provides po() and %>>%

    task = tsk("iris")

    # Cross-validated probability predictions feed into the threshold tuner.
    pop = po("learner_cv", lrn("classif.rpart", predict_type = "prob")) %>>%
      po("tunethreshold")

    task$data()
    #>        Species Petal.Length Petal.Width Sepal.Length Sepal.Width
    #>   1:    setosa          1.4         0.2          5.1         3.5
    #>   2:    setosa          1.4         0.2          4.9         3.0
    #>   3:    setosa          1.3         0.2          4.7         3.2
    #>   4:    setosa          1.5         0.2          4.6         3.1
    #>   5:    setosa          1.4         0.2          5.0         3.6
    #>  ---
    #> 146: virginica          5.2         2.3          6.7         3.0
    #> 147: virginica          5.0         1.9          6.3         2.5
    #> 148: virginica          5.2         2.0          6.5         3.0
    #> 149: virginica          5.4         2.3          6.2         3.4
    #> 150: virginica          5.1         1.8          5.9         3.0

    pop$train(task)
    #> $tunethreshold.output
    #> NULL
    #>

    # The state holds the fitted learner and the tuned thresholds.
    pop$state
    #> $classif.rpart
    #> $classif.rpart$model
    #> n= 150
    #>
    #> node), split, n, loss, yval, (yprob)
    #>       * denotes terminal node
    #>
    #> 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)
    #>   2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *
    #>   3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)
    #>     6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *
    #>     7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *
    #>
    #> $classif.rpart$log
    #> Empty data.table (0 rows and 3 cols): stage,class,msg
    #>
    #> $classif.rpart$train_time
    #> [1] 0.006
    #>
    #> $classif.rpart$train_task
    #> <TaskClassif:iris> (0 x 5)
    #> * Target: Species
    #> * Properties: multiclass
    #> * Features (4):
    #>   - dbl (4): Petal.Length, Petal.Width, Sepal.Length, Sepal.Width
    #>
    #> $classif.rpart$affected_cols
    #> [1] "Petal.Length" "Petal.Width"  "Sepal.Length" "Sepal.Width"
    #>
    #> $classif.rpart$intasklayout
    #>              id    type
    #> 1: Petal.Length numeric
    #> 2:  Petal.Width numeric
    #> 3: Sepal.Length numeric
    #> 4:  Sepal.Width numeric
    #>
    #> $classif.rpart$outtasklayout
    #>                               id    type
    #> 1:     classif.rpart.prob.setosa numeric
    #> 2: classif.rpart.prob.versicolor numeric
    #> 3:  classif.rpart.prob.virginica numeric
    #>
    #> $classif.rpart$outtaskshell
    #> Empty data.table (0 rows and 4 cols): Species,classif.rpart.prob.setosa,classif.rpart.prob.versicolor,classif.rpart.prob.virginica
    #>
    #>
    #> $tunethreshold
    #> $tunethreshold$threshold
    #>     setosa versicolor  virginica
    #>  0.5134816  0.8486673  0.2710435
    #>
    #>