Transform and Re-Transform the Target Variable
Source: R/pipeline_targettrafo.R
mlr_graphs_targettrafo.Rd
Wraps a Graph so that the target is transformed during training and the transformation is inverted during prediction. This is done as follows:
- Specify a transformation and inversion function using any subclass of PipeOpTargetTrafo (defaults to PipeOpTargetMutate).
- Afterwards, apply graph.
- At the very end, during prediction, the transformation is inverted using PipeOpTargetInvert.

To set a transformation and inversion function for PipeOpTargetMutate, see the parameters trafo and inverter of the param_set of the resulting Graph.

Note that the input graph is not explicitly checked to actually return a Prediction during prediction.
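The transform, fit, and invert steps described above can be illustrated without mlr3pipelines. The following is a minimal base-R sketch, not the package's implementation: it reuses the log2 trafo and inverter from the Examples on this page, and assumes stats::lm() as a stand-in for the wrapped learner and a plain list with a response element as a stand-in for a Prediction object.

```r
# Base-R sketch of the flow that pipeline_targettrafo() automates.
# lm() stands in for the wrapped learner; a list with `response` stands in
# for a Prediction (both are assumptions for illustration only).
trafo = function(x) log(x, base = 2)
inverter = function(x) list(response = 2 ^ x$response)

set.seed(1)
train_df = data.frame(x = 1:20)
train_df$y = 2 ^ (0.3 * train_df$x + rnorm(20, sd = 0.1))

# Training: transform the target, then fit on the transformed scale.
train_df$y_trafo = trafo(train_df$y)
fit = lm(y_trafo ~ x, data = train_df)

# Prediction: predict on the transformed scale, then invert back.
new_df = data.frame(x = c(5, 10))
pred_trafo = list(response = unname(predict(fit, newdata = new_df)))
pred = inverter(pred_trafo)
print(pred$response)
```

Because the model only ever sees log2-scale targets, its raw predictions are on that scale too; applying the inverter is what brings them back to the original units, which is the role PipeOpTargetInvert plays at the end of the Graph.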
All input arguments are cloned and have no references in common with the returned Graph.
Usage
pipeline_targettrafo(
graph,
trafo_pipeop = PipeOpTargetMutate$new(),
id_prefix = ""
)
Arguments
- graph
  PipeOpLearner | Graph
  A PipeOpLearner or Graph to wrap between a transformation and re-transformation of the target variable.
- trafo_pipeop
  PipeOp
  A PipeOp that is a subclass of PipeOpTargetTrafo. Default is PipeOpTargetMutate.
- id_prefix
  character(1)
  Optional id prefix to prepend to the PipeOpTargetInvert ID. The resulting ID will be "[id_prefix]targetinvert". Default is "".
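The naming rule for id_prefix can be spelled out in a few lines. The helper below, invert_id(), is hypothetical and not part of mlr3pipelines; it only restates the documented rule that the PipeOpTargetInvert ID is "[id_prefix]targetinvert".

```r
# Hypothetical helper (not an mlr3pipelines function) that restates the
# documented naming rule for the PipeOpTargetInvert ID.
invert_id = function(id_prefix = "") {
  # id_prefix is documented as character(1)
  stopifnot(is.character(id_prefix), length(id_prefix) == 1L)
  paste0(id_prefix, "targetinvert")
}

invert_id()         # "targetinvert"
invert_id("log2.")  # "log2.targetinvert"
```

A non-empty prefix is presumably most useful when the resulting Graph is combined with other PipeOps, since PipeOp IDs within a Graph must be unique and a second PipeOpTargetInvert with the default prefix would also be called targetinvert.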
Examples
library("mlr3")
library("mlr3pipelines")

# Wrap a regression tree between a log2 target transformation and its inversion
tt = pipeline_targettrafo(PipeOpLearner$new(LearnerRegrRpart$new()))
tt$param_set$values$targetmutate.trafo = function(x) log(x, base = 2)
tt$param_set$values$targetmutate.inverter = function(x) list(response = 2 ^ x$response)

# gives the same as building the Graph by hand:
g = Graph$new()
g$add_pipeop(PipeOpTargetMutate$new(param_vals = list(
  trafo = function(x) log(x, base = 2),
  inverter = function(x) list(response = 2 ^ x$response))
))
g$add_pipeop(LearnerRegrRpart$new())
g$add_pipeop(PipeOpTargetInvert$new())
# Output 1 of targetmutate (the inversion function) goes directly to targetinvert
g$add_edge(src_id = "targetmutate", dst_id = "targetinvert",
  src_channel = 1, dst_channel = 1)
# Output 2 of targetmutate (the task with transformed target) feeds the learner
g$add_edge(src_id = "targetmutate", dst_id = "regr.rpart",
  src_channel = 2, dst_channel = 1)
# The learner's prediction is the second input of targetinvert
g$add_edge(src_id = "regr.rpart", dst_id = "targetinvert",
  src_channel = 1, dst_channel = 2)
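Because the inverter is what maps the learner's output back to the original scale during prediction, it is worth confirming that it really undoes the trafo before wiring the pair into a Graph. This base-R sketch checks the log2 pair from the example above; the plain list with a response element is an assumption standing in for a Prediction object, and the inputs must be positive for log() to be defined.

```r
# Round-trip check for the trafo/inverter pair used in the example.
# A list with a `response` element stands in for a Prediction (assumption).
trafo = function(x) log(x, base = 2)
inverter = function(x) list(response = 2 ^ x$response)

y = c(0.5, 1, 8, 100)  # strictly positive, since log() is undefined otherwise
recovered = inverter(list(response = trafo(y)))$response

# all.equal() allows for floating-point error in log/exponentiation
stopifnot(isTRUE(all.equal(recovered, y)))
print(recovered)
```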