
Balances the classes of a Task by undersampling (keeping only a fraction of the rows of the majority class), by oversampling (repeating rows of the minority class), or both.

Sampling happens only during the training phase. Class-balancing a Task by sampling may be beneficial for classification with imbalanced training data.

Format

R6Class object inheriting from PipeOpTaskPreproc/PipeOp.

Construction

PipeOpClassBalancing$new(id = "classbalancing", param_vals = list())

  • id :: character(1) Identifier of the resulting object, default "classbalancing"

  • param_vals :: named list
    List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default list().

Input and Output Channels

Input and output channels are inherited from PipeOpTaskPreproc. Instead of a Task, a TaskClassif is used as input and output during training and prediction.

The output during training is the input Task with added or removed rows to balance target classes. The output during prediction is the unchanged input.

State

The $state is a named list with the $state elements inherited from PipeOpTaskPreproc.

Parameters

The parameters are the parameters inherited from PipeOpTaskPreproc; however, the affect_columns parameter is not present. Further parameters are:

  • ratio :: numeric(1)
    Ratio of number of rows of classes to keep, relative to the $reference value. Initialized to 1.

  • reference :: character(1)
    What the $ratio value is measured against. Can be "all" (mean instance count of all classes), "major" (instance count of class with most instances), "minor" (instance count of class with fewest instances), "nonmajor" (average instance count of all classes except the major one), "nonminor" (average instance count of all classes except the minor one), and "one" ($ratio determines the number of instances to have, per class). Initialized to "all".

  • adjust :: character(1)
    Which classes to up / downsample. Can be "all" (up and downsample all to match the required instance count), "major", "minor", "nonmajor", "nonminor" (see the respective values for $reference), "upsample" (only upsample), and "downsample" (only downsample). Initialized to "all".

  • shuffle :: logical(1)
    Whether to shuffle the rows of the resulting task. In case the data is upsampled and shuffle = FALSE, the resulting task will have the original rows (which were not removed in downsampling) in the original order, followed by all newly added rows ordered by target class. Initialized to TRUE.
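Both reference and adjust use the same class keywords. The following base-R sketch shows which classes each keyword picks out, given named per-class instance counts. The helper name select_classes() is hypothetical and not part of mlr3pipelines; ties between equally sized classes are broken arbitrarily here, which may differ from the package.

```r
# Hypothetical helper: which classes a `reference` / `adjust` keyword
# designates, given a named vector of per-class instance counts.
select_classes = function(counts, keyword) {
  # Order classes from most to fewest instances.
  ordered = names(sort(counts, decreasing = TRUE))
  major = ordered[1L]                # class with the most instances
  minor = ordered[length(ordered)]   # class with the fewest instances
  switch(keyword,
    all = ordered,
    major = major,
    minor = minor,
    nonmajor = setdiff(ordered, major),
    nonminor = setdiff(ordered, minor),
    stop("unknown keyword: ", keyword))
}

counts = c(spam = 1813, nonspam = 2788)
select_classes(counts, "major")     # "nonspam"
select_classes(counts, "nonmajor")  # "spam"
```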

Internals

Up / downsampling happens as follows: First, a "target class count" is calculated by taking the mean class count of all classes indicated by the reference parameter (e.g. if reference is "nonmajor": the mean class count of all classes that are not the "major" class, i.e. the class with the most samples) and multiplying it by the value of the ratio parameter. If reference is "one", then the "target class count" is just the value of ratio (i.e. 1 * ratio).
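The target class count computation can be sketched in base R as follows. The function name target_class_count() is hypothetical; rounding the result to a whole number of rows is an assumption of this sketch, since the description above does not say how fractional counts are handled.

```r
# Sketch of the "target class count": mean count of the classes selected by
# `reference`, multiplied by `ratio` (for "one", simply 1 * ratio).
target_class_count = function(counts, reference, ratio) {
  sorted = sort(counts, decreasing = TRUE)
  n = length(sorted)
  base = switch(reference,
    all = mean(sorted),
    major = sorted[[1L]],
    minor = sorted[[n]],
    nonmajor = mean(sorted[-1L]),
    nonminor = mean(sorted[-n]),
    one = 1,
    stop("unknown reference: ", reference))
  # Assumption: round to a whole number of rows.
  round(base * ratio)
}

counts = c(spam = 1813, nonspam = 2788)
target_class_count(counts, "minor", 2)  # 3626, as in the first example below
target_class_count(counts, "one", 20)   # 20
```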

Then for each class that is referenced by the adjust parameter (e.g. if adjust is "nonminor": each class that is not the class with the fewest samples), PipeOpClassBalancing either throws out samples (downsampling), or adds additional rows that are equal to randomly chosen samples (upsampling), until the number of samples for these classes equals the "target class count".
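Given a target class count, the adjust step only changes the classes that adjust selects. The base-R sketch below computes the resulting per-class counts. The function name balanced_counts() is hypothetical; "upsample" and "downsample" are interpreted as "classes currently below / above the target", which is how this sketch reads the parameter description.

```r
# Sketch: per-class row counts after balancing. Classes selected by `adjust`
# are set to `target`; all other classes keep their current count.
balanced_counts = function(counts, target, adjust) {
  sorted = sort(counts, decreasing = TRUE)
  classes = names(sorted)
  n = length(classes)
  adjusted = switch(adjust,
    all = classes,
    major = classes[1L],
    minor = classes[n],
    nonmajor = classes[-1L],
    nonminor = classes[-n],
    upsample = classes[sorted < target],    # only classes that would grow
    downsample = classes[sorted > target],  # only classes that would shrink
    stop("unknown adjust: ", adjust))
  result = counts
  result[adjusted] = target
  result
}

counts = c(spam = 1813, nonspam = 2788)
balanced_counts(counts, target = 3626, adjust = "minor")  # spam = 3626, nonspam = 2788
```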

Rows are removed with task$filter(). When identical rows are added during upsampling, task$row_roles$use cannot be used to duplicate rows because of [inaudible]; instead, task$rbind() is used to attach a new data.table that contains each duplicated row exactly as many times as it is being added.
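The row bookkeeping can be illustrated on a plain data.frame. In this base-R sketch, keeping a random subset of row indices plays the role of task$filter(), and appending copies of randomly drawn rows with rbind() plays the role of task$rbind(). The function resample_rows() and its arguments are hypothetical and this is not the mlr3pipelines code; it only mirrors the row order described for shuffle = FALSE (surviving original rows in original order, then added rows grouped by class).

```r
# Sketch: resample a data.frame so that class `cls` of column `target_col`
# ends up with new_counts[[cls]] rows.
resample_rows = function(data, target_col, new_counts, shuffle = TRUE) {
  keep = integer(0)   # original row indices that survive
  extra = integer(0)  # row indices to duplicate (one entry per added row)
  for (cls in names(new_counts)) {
    idx = which(data[[target_col]] == cls)
    wanted = new_counts[[cls]]
    if (wanted <= length(idx)) {
      # downsampling (or unchanged): random subset, without replacement
      keep = c(keep, idx[sample.int(length(idx), wanted)])
    } else {
      # upsampling: keep every row, then draw the surplus with replacement
      keep = c(keep, idx)
      extra = c(extra, idx[sample.int(length(idx), wanted - length(idx), replace = TRUE)])
    }
  }
  # Surviving rows in original order, followed by duplicates grouped by class.
  out = rbind(data[sort(keep), , drop = FALSE], data[extra, , drop = FALSE])
  if (shuffle) out = out[sample.int(nrow(out)), , drop = FALSE]
  rownames(out) = NULL
  out
}

set.seed(1)
df = data.frame(x = 1:10, y = c(rep("a", 7), rep("b", 3)))
balanced = resample_rows(df, "y", c(a = 4, b = 6), shuffle = FALSE)
table(balanced$y)
```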

Fields

Only fields inherited from PipeOpTaskPreproc/PipeOp.

Methods

Only methods inherited from PipeOpTaskPreproc/PipeOp.

See also

https://mlr-org.com/pipeops.html

Other PipeOps: PipeOpEnsemble, PipeOpImpute, PipeOpTargetTrafo, PipeOpTaskPreprocSimple, PipeOpTaskPreproc, PipeOp, mlr_pipeops_boxcox, mlr_pipeops_branch, mlr_pipeops_chunk, mlr_pipeops_classifavg, mlr_pipeops_classweights, mlr_pipeops_colapply, mlr_pipeops_collapsefactors, mlr_pipeops_colroles, mlr_pipeops_copy, mlr_pipeops_datefeatures, mlr_pipeops_encodeimpact, mlr_pipeops_encodelmer, mlr_pipeops_encode, mlr_pipeops_featureunion, mlr_pipeops_filter, mlr_pipeops_fixfactors, mlr_pipeops_histbin, mlr_pipeops_ica, mlr_pipeops_imputeconstant, mlr_pipeops_imputehist, mlr_pipeops_imputelearner, mlr_pipeops_imputemean, mlr_pipeops_imputemedian, mlr_pipeops_imputemode, mlr_pipeops_imputeoor, mlr_pipeops_imputesample, mlr_pipeops_kernelpca, mlr_pipeops_learner, mlr_pipeops_missind, mlr_pipeops_modelmatrix, mlr_pipeops_multiplicityexply, mlr_pipeops_multiplicityimply, mlr_pipeops_mutate, mlr_pipeops_nmf, mlr_pipeops_nop, mlr_pipeops_ovrsplit, mlr_pipeops_ovrunite, mlr_pipeops_pca, mlr_pipeops_proxy, mlr_pipeops_quantilebin, mlr_pipeops_randomprojection, mlr_pipeops_randomresponse, mlr_pipeops_regravg, mlr_pipeops_removeconstants, mlr_pipeops_renamecolumns, mlr_pipeops_replicate, mlr_pipeops_scalemaxabs, mlr_pipeops_scalerange, mlr_pipeops_scale, mlr_pipeops_select, mlr_pipeops_smote, mlr_pipeops_spatialsign, mlr_pipeops_subsample, mlr_pipeops_targetinvert, mlr_pipeops_targetmutate, mlr_pipeops_targettrafoscalerange, mlr_pipeops_textvectorizer, mlr_pipeops_threshold, mlr_pipeops_tunethreshold, mlr_pipeops_unbranch, mlr_pipeops_updatetarget, mlr_pipeops_vtreat, mlr_pipeops_yeojohnson, mlr_pipeops

Examples

library("mlr3")
library("mlr3pipelines")

task = tsk("spam")
opb = po("classbalancing")

# target class counts
table(task$truth())
#> 
#>    spam nonspam 
#>    1813    2788 

# double the instances in the minority class (spam)
opb$param_set$values = list(ratio = 2, reference = "minor",
  adjust = "minor", shuffle = FALSE)
result = opb$train(list(task))[[1L]]
table(result$truth())
#> 
#>    spam nonspam 
#>    3626    2788 

# up or downsample all classes until exactly 20 per class remain
opb$param_set$values = list(ratio = 20, reference = "one",
  adjust = "all", shuffle = FALSE)
result = opb$train(list(task))[[1L]]
table(result$truth())
#> 
#>    spam nonspam 
#>      20      20