
Transform and Re-Transform the Target Variable
Source: R/pipeline_targettrafo.R
mlr_graphs_targettrafo.Rd
Wraps a Graph that transforms a target during training and inverts the transformation
during prediction. This is done as follows:
- Specify a transformation and inversion function using any subclass of
  PipeOpTargetTrafo (defaults to PipeOpTargetMutate), then apply graph.
- At the very end, during prediction, the transformation is inverted using
  PipeOpTargetInvert.
To set a transformation and inversion function for PipeOpTargetMutate, see the
parameters trafo and inverter of the param_set of the resulting Graph.
Note that the input graph is not explicitly checked to actually return a Prediction during prediction.
All input arguments are cloned and have no references in common with the returned Graph.
Usage
pipeline_targettrafo(
  graph,
  trafo_pipeop = PipeOpTargetMutate$new(),
  id_prefix = ""
)
Arguments
- graph
PipeOpLearner | Graph
A PipeOpLearner or Graph to wrap between a transformation and re-transformation of the target variable.
- trafo_pipeop
PipeOp
A PipeOp that is a subclass of PipeOpTargetTrafo. Default is PipeOpTargetMutate.
- id_prefix
character(1)
Optional id prefix to prepend to the PipeOpTargetInvert ID. The resulting ID will be "[id_prefix]targetinvert". Default is "".
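The effect of id_prefix described above can be sketched as follows. This is a minimal illustration, assuming that both mlr3 and mlr3pipelines are attached; the prefix "log_" is a hypothetical value chosen only for this sketch.

```r
library("mlr3")
library("mlr3pipelines")

# Wrap a learner and prefix the ID of the PipeOpTargetInvert.
# "log_" is an arbitrary, illustrative prefix.
tt = pipeline_targettrafo(
  PipeOpLearner$new(LearnerRegrRpart$new()),
  id_prefix = "log_"
)

# List the IDs of all PipeOps in the resulting Graph.
tt$ids()
```

A prefix of this kind is mainly useful when several target transformations end up in one larger Graph and their PipeOpTargetInvert IDs would otherwise clash.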
Examples
library("mlr3")
library("mlr3pipelines")
# wrap a learner between a log2 transformation and its inversion
tt = pipeline_targettrafo(PipeOpLearner$new(LearnerRegrRpart$new()))
tt$param_set$values$targetmutate.trafo = function(x) log(x, base = 2)
tt$param_set$values$targetmutate.inverter = function(x) list(response = 2 ^ x$response)
# gives the same as
g = Graph$new()
g$add_pipeop(PipeOpTargetMutate$new(param_vals = list(
  trafo = function(x) log(x, base = 2),
  inverter = function(x) list(response = 2 ^ x$response)
)))
g$add_pipeop(LearnerRegrRpart$new())
g$add_pipeop(PipeOpTargetInvert$new())
# channel 1 of targetmutate passes the inversion function to targetinvert
g$add_edge(src_id = "targetmutate", dst_id = "targetinvert",
  src_channel = 1, dst_channel = 1)
# channel 2 of targetmutate passes the transformed task to the learner
g$add_edge(src_id = "targetmutate", dst_id = "regr.rpart",
  src_channel = 2, dst_channel = 1)
# the learner's prediction is inverted by targetinvert
g$add_edge(src_id = "regr.rpart", dst_id = "targetinvert",
  src_channel = 1, dst_channel = 2)
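As a rough sketch of how the wrapped graph might be used end to end, the following trains it and predicts on a task. The "mtcars" task and the conversion with as_learner() are assumptions made for this illustration and are not part of the example above.

```r
library("mlr3")
library("mlr3pipelines")

# Illustrative regression task with a strictly positive target (mpg).
task = tsk("mtcars")

# Same log2 transformation and inversion as in the example above.
tt = pipeline_targettrafo(PipeOpLearner$new(LearnerRegrRpart$new()))
tt$param_set$values$targetmutate.trafo = function(x) log(x, base = 2)
tt$param_set$values$targetmutate.inverter = function(x) list(response = 2 ^ x$response)

# Treat the Graph as a Learner: the target is transformed during training
# and the predictions are inverted during prediction.
glrn = as_learner(tt)
glrn$train(task)
pred = glrn$predict(task)
head(pred$response)
```

Because the inverter is applied during prediction, the values in pred$response are on the original scale of the target rather than on the log2 scale.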