A Learner that encapsulates a Graph to be used in
mlr3 resampling and benchmarks.
The Graph must return a single Prediction on its $predict()
call. The result of the $train() call is discarded; only the
internal state changes during training are used.
The predict_type of a GraphLearner can be obtained or set via its predict_type active binding.
Setting a new predict_type will try to set the predict_type in all relevant
PipeOps / Learners encapsulated within the Graph.
Similarly, the predict_type of a GraphLearner will always be the smallest common denominator of the predict_types in the Graph.
A GraphLearner is always constructed in an untrained state. When the graph argument has a
non-NULL $state, it is ignored.
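For instance, the predict_type of a wrapped classification pipeline can be switched to probability prediction through the active binding; a minimal sketch (the particular pipeline is illustrative):

```r
library("mlr3")
library("mlr3pipelines")

glrn = as_learner(po("scale") %>>% lrn("classif.rpart"))
glrn$predict_type            # "response" by default
glrn$predict_type = "prob"   # propagated to the wrapped classif.rpart
glrn$graph$pipeops$classif.rpart$learner$predict_type
```

The last line shows that the setting has reached the Learner inside the PipeOpLearner.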
Format
R6Class object inheriting from mlr3::Learner.
Construction
graph :: Graph | PipeOp
  Graph to wrap. Can be a PipeOp, which is automatically converted to a Graph. This argument is usually cloned, unless clone_graph is FALSE; to access the Graph inside the GraphLearner by-reference, use $graph.

id :: character(1)
  Identifier of the resulting Learner.

param_vals :: named list
  List of hyperparameter settings, overwriting the hyperparameter settings given in graph. Default list().

task_type :: character(1)
  What task_type the GraphLearner should have; usually automatically inferred for Graphs that are simple enough.

predict_type :: character(1)
  What predict_type the GraphLearner should have; usually automatically inferred for Graphs that are simple enough.

clone_graph :: logical(1)
  Whether to clone graph upon construction. Unintentionally changing graph by reference can lead to unexpected behaviour, so TRUE (default) is recommended. In particular, note that the $state of $graph is set to NULL by reference on construction of the GraphLearner, during $train(), and during $predict() when clone_graph is FALSE.
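The construction arguments can be sketched as follows; the id and the cp value are arbitrary choices for illustration:

```r
library("mlr3")
library("mlr3pipelines")

graph = po("scale") %>>% lrn("classif.rpart")
glrn = GraphLearner$new(
  graph,
  id = "scaled_rpart",                       # identifier of the resulting Learner
  param_vals = list(classif.rpart.cp = 0.02) # overwrites the setting inside graph
)
glrn$id
glrn$param_set$values$classif.rpart.cp
```

Note that hyperparameter names are prefixed with the ID of the PipeOp they belong to.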
Fields
Fields inherited from Learner, as well as:
graph :: Graph
  Graph that is being wrapped. This field contains the prototype of the Graph that is being trained, but does not contain the model. Use graph_model to access the trained Graph after $train(). Read-only.

graph_model :: Graph
  Graph that is being wrapped. This Graph contains a trained state after $train(). Read-only.

pipeops :: named list of PipeOp
  Contains all PipeOps in the underlying Graph, named by the PipeOps' $ids. Shortcut for $graph_model$pipeops. See Graph for details.

edges :: data.table with columns src_id (character), src_channel (character), dst_id (character), dst_channel (character)
  Table of connections between the PipeOps in the underlying Graph. Shortcut for $graph$edges. See Graph for details.

param_set :: ParamSet
  Parameters of the underlying Graph. Shortcut for $graph$param_set. See Graph for details.

pipeops_param_set :: named list()
  Named list containing the ParamSets of all PipeOps in the Graph. See there for details.

pipeops_param_set_values :: named list()
  Named list containing the set parameter values of all PipeOps in the Graph. See there for details.

internal_tuned_values :: named list() or NULL
  The internal tuned parameter values collected from all PipeOps. NULL is returned if the learner is not trained or none of the wrapped learners supports internal tuning.

internal_valid_scores :: named list() or NULL
  The internal validation scores as retrieved from the PipeOps. The names are prefixed with the respective IDs of the PipeOps. NULL is returned if the learner is not trained or none of the wrapped learners supports internal validation.

validate :: numeric(1), "predefined", "test" or NULL
  How to construct the validation data. This also has to be configured for the individual PipeOps such as PipeOpLearner, see set_validate.GraphLearner. For more details on the possible values, see mlr3::Learner.

marshaled :: logical(1)
  Whether the learner is marshaled.

impute_selected_features :: logical(1)
  Whether to heuristically determine $selected_features() as the union of the $selected_features() of all "base learner" Learners, even if they do not have the "selected_features" property / do not implement $selected_features(). If impute_selected_features is TRUE and the base learners do not implement $selected_features(), the GraphLearner's $selected_features() method will return all features seen by the base learners. This is useful in cases where feature selection is performed inside the Graph: the $selected_features() will then be the set of features that were selected by the Graph. If impute_selected_features is FALSE, the $selected_features() method will throw an error if $selected_features() is not implemented by the base learners.
  This is a heuristic and may report more features than actually used by the base learners, in cases where the base learners do not implement $selected_features(). The default is FALSE.
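Several of these fields are simple shortcuts into the wrapped Graph; a brief sketch:

```r
library("mlr3")
library("mlr3pipelines")

glrn = as_learner(po("scale") %>>% lrn("classif.rpart"))
names(glrn$pipeops)              # IDs of all PipeOps in the Graph
glrn$edges                       # connection table, same as glrn$graph$edges
glrn$param_set$values            # hyperparameters, prefixed by PipeOp ID
glrn$pipeops_param_set_values    # the same values, grouped by PipeOp
```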
Methods
Methods inherited from Learner, as well as:
ids(sorted = FALSE)
  (logical(1)) -> character
  Get IDs of all PipeOps. The IDs are in the order in which the PipeOps were added if sorted is FALSE, and topologically sorted if sorted is TRUE.

plot(html = FALSE, horizontal = FALSE)
  (logical(1), logical(1)) -> NULL
  Plot the Graph, using either the igraph package (for html = FALSE, default) or the visNetwork package (for html = TRUE, producing an htmlWidget). The htmlWidget can be rescaled using visOptions. For html = FALSE, the orientation of the plotted graph can be controlled through horizontal.

marshal
  (any) -> self
  Marshal the model.

unmarshal
  (any) -> self
  Unmarshal the model.

base_learner(recursive = Inf, return_po = FALSE, return_all = FALSE, resolve_branching = TRUE)
  (numeric(1), logical(1), logical(1), logical(1)) -> Learner | PipeOp | list of Learner | list of PipeOp
  Return the base learner of the GraphLearner. If recursive is 0, the GraphLearner itself is returned. Otherwise, the Graph is traversed backwards to find the first PipeOp containing a $learner_model field. If recursive is 1, that $learner_model (or the containing PipeOp, if return_po is TRUE) is returned. If recursive is greater than 1, the discovered base learner's base_learner() method is called with recursive - 1. recursive must be set to 1 if return_po is TRUE, and must be set to at most 1 if return_all is TRUE.
  If return_po is TRUE, the container-PipeOp is returned instead of the Learner. This will typically be a PipeOpLearner or a PipeOpLearnerCV.
  If return_all is TRUE, a list of Learners or PipeOps is returned. If return_po is FALSE, this list may contain Multiplicity objects, which are not unwrapped. If return_all is FALSE and there are multiple possible base learners, an error is thrown. This may also happen if only a single PipeOpLearner is present that was trained with a Multiplicity.
  If resolve_branching is TRUE, and when a PipeOpUnbranch is encountered, the corresponding PipeOpBranch is searched, and its hyperparameter configuration is used to select the base learner. There may be multiple corresponding PipeOpBranches, which are all considered. If resolve_branching is FALSE, PipeOpUnbranch is treated as any other PipeOp with multiple inputs; all possible branch paths are considered equally.
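A short sketch of base_learner() on a simple linear pipeline:

```r
library("mlr3")
library("mlr3pipelines")

glrn = as_learner(po("pca") %>>% lrn("classif.rpart"))
glrn$base_learner(recursive = 0)                    # the GraphLearner itself
glrn$base_learner()                                 # the classif.rpart Learner
glrn$base_learner(recursive = 1, return_po = TRUE)  # the containing PipeOpLearner
```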
The following standard extractors as defined by the Learner class are available.
Note that these typically only extract information from the $base_learner().
This works well for simple Graphs that do not modify features too much, but may give unexpected results for Graphs that
add new features or move information between features.
As an example, consider a feature A with missing values, and a feature B that is used for imputation, using a po("imputelearner").
In a case where the following Learner performs embedded feature selection and only selects feature A,
the selected_features() method could return only feature A, and $importance() may even report 0 for feature B.
This would not be entirely accurate when considering the entire GraphLearner, as feature B is used for imputation and would therefore have an impact on predictions.
The following should therefore only be used if the Graph is known to not have an impact on the relevant properties.
importance()
  () -> numeric
  The $importance() returned by the base learner, if it has the "importance" property. Throws an error otherwise.

selected_features()
  () -> character
  The $selected_features() returned by the base learner, if it has the "selected_features" property. If the base learner does not have the "selected_features" property and impute_selected_features is TRUE, all features seen by the base learners are returned. Throws an error otherwise.

oob_error()
  () -> numeric(1)
  The $oob_error() returned by the base learner, if it has the "oob_error" property. Throws an error otherwise.

loglik()
  () -> numeric(1)
  The $loglik() returned by the base learner, if it has the "loglik" property. Throws an error otherwise.
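As a sketch, for a pipeline ending in classif.rpart (which has both the "importance" and "selected_features" properties), the extractors delegate to that base learner:

```r
library("mlr3")
library("mlr3pipelines")

glrn = as_learner(po("scale") %>>% lrn("classif.rpart"))
glrn$train(tsk("iris"))
glrn$selected_features()  # features used by the rpart base learner
glrn$importance()         # importance as reported by rpart
```

As discussed above, these refer to the base learner only and ignore any feature transformations performed earlier in the Graph.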
Internals
as_graph() is called on the graph argument, so it can technically also be a list of things, which is
automatically converted to a Graph via gunion(); however, this will usually not result in a valid Graph that can
work as a Learner. graph can furthermore be a Learner, which is then automatically
wrapped in a Graph, which is then again wrapped in a GraphLearner object; this usually only adds overhead and is not
recommended.
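A sketch of these conversions; the second construction runs but only adds a wrapping layer:

```r
library("mlr3")
library("mlr3pipelines")

# A single Learner is wrapped in a Graph, then in a GraphLearner:
glrn = as_learner(as_graph(lrn("classif.rpart")))

# A list is combined via gunion(); the result has two disconnected
# outputs and is therefore not a valid single-prediction Learner:
g = as_graph(list(po("pca"), po("scale")))
```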
See also
Other Learners:
mlr_learners_avg
Examples
library("mlr3")
graph = po("pca") %>>% lrn("classif.rpart")
lr = GraphLearner$new(graph)
lr = as_learner(graph) # equivalent
lr$train(tsk("iris"))
lr$graph$state # untrained version!
#> $pca
#> NULL
#>
#> $classif.rpart
#> NULL
#>
# The following is therefore NULL:
lr$graph$pipeops$classif.rpart$learner_model$model
#> NULL
# To access the trained model from the PipeOpLearner's Learner, use:
lr$graph_model$pipeops$classif.rpart$learner_model$model
#> n= 150
#>
#> node), split, n, loss, yval, (yprob)
#> * denotes terminal node
#>
#> 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)
#> 2) PC1< -1.553145 50 0 setosa (1.00000000 0.00000000 0.00000000) *
#> 3) PC1>=-1.553145 100 50 versicolor (0.00000000 0.50000000 0.50000000)
#> 6) PC1< 1.142805 44 1 versicolor (0.00000000 0.97727273 0.02272727) *
#> 7) PC1>=1.142805 56 7 virginica (0.00000000 0.12500000 0.87500000) *
# Feature importance (of principal components):
lr$graph_model$pipeops$classif.rpart$learner_model$importance()
#> PC1 PC2 PC3 PC4
#> 85.795455 18.016529 5.694731 3.254132
