Both undersamples a Task to keep only a fraction of the rows of the majority class,
and oversamples rows of the minority class by repeating data points.
Sampling happens only during the training phase. Class-balancing a Task by sampling
may be beneficial for classification with imbalanced training data.
Format
R6Class object inheriting from PipeOpTaskPreproc/PipeOp.
Construction
id :: character(1)
Identifier of the resulting object, default "classbalancing".
param_vals :: named list
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default list().
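As a minimal sketch of the two construction arguments described above, the following builds the PipeOp once through the po() shortcut and once through the R6 generator. The id "balance_minor" and the chosen hyperparameter values are illustrative assumptions, not defaults.

```r
library("mlr3")
library("mlr3pipelines")

# Illustrative settings (assumed for this sketch, not package defaults).
vals = list(ratio = 2, reference = "minor", adjust = "minor")

# Construction through the po() shortcut with a custom id.
opb = po("classbalancing", id = "balance_minor", param_vals = vals)

# Construction through the R6 generator with the same arguments.
opb2 = PipeOpClassBalancing$new(id = "balance_minor", param_vals = vals)

opb$id
opb$param_set$values$ratio
```

Values passed in param_vals overwrite the settings that construction would otherwise apply, so ratio is 2 here rather than its initial value of 1.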
Input and Output Channels
Input and output channels are inherited from PipeOpTaskPreproc. Instead of a Task, a TaskClassif is used as input and output during training and prediction.
The output during training is the input Task with added or removed rows to balance target classes.
The output during prediction is the unchanged input.
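The difference between the training and prediction outputs can be checked directly. The sketch below assumes the "spam" task that ships with mlr3 and picks parameter values that downsample the majority class to the size of the minority class; those values are chosen for illustration only.

```r
library("mlr3")
library("mlr3pipelines")

task = tsk("spam")

# Illustrative settings: bring the majority class down to the minority count.
opb = po("classbalancing", ratio = 1, reference = "minor", adjust = "major")

# Training output: rows of the majority class have been removed.
trained = opb$train(list(task))[[1L]]
table(trained$truth())

# Prediction output: the task passes through with all of its rows.
predicted = opb$predict(list(task))[[1L]]
predicted$nrow == task$nrow
```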
State
The $state is a named list with the $state elements inherited from PipeOpTaskPreproc.
Parameters
The parameters are the parameters inherited from PipeOpTaskPreproc; however, the affect_columns parameter is not present. Further parameters are:
ratio :: numeric(1)
Ratio of number of rows of classes to keep, relative to the $reference value. Initialized to 1.
reference :: character(1)
What the $ratio value is measured against. Can be "all" (mean instance count of all classes), "major" (instance count of class with most instances), "minor" (instance count of class with fewest instances), "nonmajor" (average instance count of all classes except the major one), "nonminor" (average instance count of all classes except the minor one), and "one" ($ratio determines the number of instances to have, per class). Initialized to "all".
adjust :: character(1)
Which classes to up / downsample. Can be "all" (up and downsample all to match required instance count), "major", "minor", "nonmajor", "nonminor" (see respective values for $reference), "upsample" (only upsample), and "downsample" (only downsample). Initialized to "all".
shuffle :: logical(1)
Whether to shuffle the rows of the resulting task. In case the data is upsampled and shuffle = FALSE, the resulting task will have the original rows (which were not removed in downsampling) in the original order, followed by all newly added rows ordered by target class. Initialized to TRUE.
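When the PipeOp is part of a larger pipeline, these parameters are reachable through the parameter set of the enclosing learner, prefixed with the PipeOp id. The sketch below is one possible setup, assuming the default id "classbalancing", the "classif.rpart" learner from mlr3 and the "spam" task; the parameter values are illustrative.

```r
library("mlr3")
library("mlr3pipelines")

# Balance classes before fitting a decision tree.
graph = po("classbalancing") %>>% lrn("classif.rpart")
glrn = as_learner(graph)

# Parameters carry the PipeOp id as a prefix inside the graph learner.
glrn$param_set$values$classbalancing.reference = "minor"
glrn$param_set$values$classbalancing.adjust = "major"
glrn$param_set$values$classbalancing.ratio = 1

task = tsk("spam")
glrn$train(task)

# Balancing is applied only while training; prediction covers every row.
prediction = glrn$predict(task)
length(prediction$response) == task$nrow
```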
Internals
Up / downsampling happens as follows: First, a "target class count" is calculated by taking the mean class count of all classes indicated by the reference parameter (e.g. if reference is "nonmajor": the mean class count of all classes that are not the "major" class, i.e. the class with the most samples) and multiplying this by the value of the ratio parameter. If reference is "one", then the "target class count" is just the value of ratio (i.e. 1 * ratio).
Then, for each class that is referenced by the adjust parameter (e.g. if adjust is "nonminor": each class that is not the class with the fewest samples), PipeOpClassBalancing either throws out samples (downsampling) or adds additional rows that are equal to randomly chosen samples (upsampling), until the number of samples for these classes equals the "target class count".
Uses task$filter() to remove rows. When identical rows are added during upsampling, task$row_roles$use cannot be used to duplicate rows because of [inaudible]; instead the task$rbind() function is used, and a new data.table is attached that contains all rows that are being duplicated exactly as many times as they are being added.
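The "target class count" calculation described above can be sketched in base R. The helper target_class_count() is hypothetical and not part of mlr3pipelines; its handling of ties between classes and the absence of rounding are assumptions of this sketch, so the package's own implementation may differ in those details.

```r
# Hypothetical helper (not part of mlr3pipelines): computes the
# "target class count" from a named vector of per-class row counts.
target_class_count = function(counts, reference = "all", ratio = 1) {
  is_major = counts == max(counts)  # assumption: tied classes all count as "major"
  is_minor = counts == min(counts)  # assumption: tied classes all count as "minor"
  ref_counts = switch(reference,
    all = counts,
    major = counts[is_major],
    minor = counts[is_minor],
    nonmajor = counts[!is_major],
    nonminor = counts[!is_minor],
    one = 1,
    stop("unknown reference: ", reference)
  )
  # Mean count of the referenced classes, scaled by ratio (no rounding applied).
  mean(ref_counts) * ratio
}

# Class counts of the "spam" task used in the Examples section.
counts = c(spam = 1813, nonspam = 2788)

target_class_count(counts, reference = "minor", ratio = 2)  # 3626
target_class_count(counts, reference = "one", ratio = 20)   # 20
target_class_count(counts, reference = "all", ratio = 1)    # 2300.5
```

The first two calls reproduce the per-class counts that the Examples section reports for the adjusted classes.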
Fields
Only fields inherited from PipeOpTaskPreproc/PipeOp.
Methods
Only methods inherited from PipeOpTaskPreproc/PipeOp.
See also
https://mlr-org.com/pipeops.html
Other PipeOps: PipeOp, PipeOpEnsemble, PipeOpImpute, PipeOpTargetTrafo, PipeOpTaskPreproc, PipeOpTaskPreprocSimple,
mlr_pipeops, mlr_pipeops_adas, mlr_pipeops_blsmote, mlr_pipeops_boxcox, mlr_pipeops_branch, mlr_pipeops_chunk,
mlr_pipeops_classifavg, mlr_pipeops_classweights, mlr_pipeops_colapply, mlr_pipeops_collapsefactors,
mlr_pipeops_colroles, mlr_pipeops_copy, mlr_pipeops_datefeatures, mlr_pipeops_encode, mlr_pipeops_encodeimpact,
mlr_pipeops_encodelmer, mlr_pipeops_featureunion, mlr_pipeops_filter, mlr_pipeops_fixfactors, mlr_pipeops_histbin,
mlr_pipeops_ica, mlr_pipeops_imputeconstant, mlr_pipeops_imputehist, mlr_pipeops_imputelearner,
mlr_pipeops_imputemean, mlr_pipeops_imputemedian, mlr_pipeops_imputemode, mlr_pipeops_imputeoor,
mlr_pipeops_imputesample, mlr_pipeops_kernelpca, mlr_pipeops_learner, mlr_pipeops_learner_pi_cvplus,
mlr_pipeops_learner_quantiles, mlr_pipeops_missind, mlr_pipeops_modelmatrix, mlr_pipeops_multiplicityexply,
mlr_pipeops_multiplicityimply, mlr_pipeops_mutate, mlr_pipeops_nearmiss, mlr_pipeops_nmf, mlr_pipeops_nop,
mlr_pipeops_ovrsplit, mlr_pipeops_ovrunite, mlr_pipeops_pca, mlr_pipeops_proxy, mlr_pipeops_quantilebin,
mlr_pipeops_randomprojection, mlr_pipeops_randomresponse, mlr_pipeops_regravg, mlr_pipeops_removeconstants,
mlr_pipeops_renamecolumns, mlr_pipeops_replicate, mlr_pipeops_rowapply, mlr_pipeops_scale, mlr_pipeops_scalemaxabs,
mlr_pipeops_scalerange, mlr_pipeops_select, mlr_pipeops_smote, mlr_pipeops_smotenc, mlr_pipeops_spatialsign,
mlr_pipeops_subsample, mlr_pipeops_targetinvert, mlr_pipeops_targetmutate, mlr_pipeops_targettrafoscalerange,
mlr_pipeops_textvectorizer, mlr_pipeops_threshold, mlr_pipeops_tomek, mlr_pipeops_tunethreshold,
mlr_pipeops_unbranch, mlr_pipeops_updatetarget, mlr_pipeops_vtreat, mlr_pipeops_yeojohnson
Examples
library("mlr3")
library("mlr3pipelines")  # provides po() and PipeOpClassBalancing

task = tsk("spam")
opb = po("classbalancing")

# target class counts
table(task$truth())
#>
#> spam nonspam
#> 1813 2788

# double the instances in the minority class (spam)
opb$param_set$values = list(ratio = 2, reference = "minor",
  adjust = "minor", shuffle = FALSE)
result = opb$train(list(task))[[1L]]
table(result$truth())
#>
#> spam nonspam
#> 3626 2788

# up or downsample all classes until exactly 20 per class remain
opb$param_set$values = list(ratio = 20, reference = "one",
  adjust = "all", shuffle = FALSE)
result = opb$train(list(task))[[1L]]
table(result$truth())
#>
#> spam nonspam
#> 20 20