A Learner that encapsulates a Graph to be used in mlr3 resampling and benchmarks.

The Graph must return a single Prediction on its $predict() call. The result of the $train() call is discarded; only the internal state changes during training are used.

The predict_type of a GraphLearner can be obtained or set via its predict_type active binding. Setting a new predict type will try to set the predict_type in all relevant PipeOps / Learners encapsulated within the Graph. Similarly, the predict_type of a Graph will always be the lowest common denominator in the Graph.

A GraphLearner is always constructed in an untrained state. If the graph argument has a non-NULL $state, that state is ignored.
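The predict_type behaviour described above can be sketched as follows. This is a minimal illustration, assuming that mlr3 and mlr3pipelines are installed and that the wrapped Learner is reachable through the PipeOpLearner's $learner binding; the Graph (PCA followed by a classification tree) is chosen only for the example.

```r
library("mlr3")
library("mlr3pipelines")

# Wrap a small Graph: PCA followed by a classification tree
gl = as_learner(po("pca") %>>% lrn("classif.rpart"))

# Setting predict_type on the GraphLearner tries to set it on the
# Learner wrapped inside the Graph as well
gl$predict_type = "prob"
gl$graph$pipeops$classif.rpart$learner$predict_type

# Predictions made after training then carry class probabilities
task = tsk("iris")
gl$train(task)
pred = gl$predict(task)
head(pred$prob)
```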

Format

R6Class object inheriting from mlr3::Learner.

Construction

GraphLearner$new(graph, id = NULL, param_vals = list(), task_type = NULL, predict_type = NULL, clone_graph = TRUE)

  • graph :: Graph | PipeOp
    Graph to wrap. Can be a PipeOp, which is automatically converted to a Graph. This argument is usually cloned, unless clone_graph is FALSE; to access the Graph inside GraphLearner by-reference, use $graph.

  • id :: character(1)
    Identifier of the resulting Learner.

  • param_vals :: named list
    List of hyperparameter settings, overwriting existing hyperparameter settings. Default list().

  • task_type :: character(1)
    What task_type the GraphLearner should have; usually automatically inferred for Graphs that are simple enough.

  • predict_type :: character(1)
    What predict_type the GraphLearner should have; usually automatically inferred for Graphs that are simple enough.

  • clone_graph :: logical(1)
    Whether to clone graph upon construction. Unintentionally changing graph by reference can lead to unexpected behaviour, so TRUE (default) is recommended. In particular, note that the $state of $graph is set to NULL by reference on construction of GraphLearner, during $train(), and during $predict() when clone_graph is FALSE.
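As a rough sketch of how the constructor arguments fit together, the example below builds a GraphLearner with a custom id and a hyperparameter passed through param_vals. It assumes that hyperparameters of the Graph are addressed as `<PipeOp id>.<parameter>`; the id "scaled_rpart" and the maxdepth value are made up for illustration.

```r
library("mlr3")
library("mlr3pipelines")

graph = po("scale") %>>% lrn("classif.rpart")

# Hyperparameters of the Graph are prefixed with the PipeOp id
gl = GraphLearner$new(
  graph,
  id = "scaled_rpart",                              # illustrative id
  param_vals = list(classif.rpart.maxdepth = 3)     # illustrative value
)

gl$id
gl$param_set$values$classif.rpart.maxdepth

# With the default clone_graph = TRUE, the GraphLearner works on a clone,
# so the original graph object does not receive the new setting
graph$param_set$values$classif.rpart.maxdepth
```

With clone_graph = FALSE, the GraphLearner would instead share the Graph object with the caller, so changes on either side are visible to both.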

Fields

Fields inherited from Learner, as well as:

  • graph :: Graph
    Graph that is being wrapped. This field contains the prototype of the Graph that is being trained, but does not contain the model. Use graph_model to access the trained Graph after $train(). Read-only.

  • graph_model :: Graph
    Graph that is being wrapped. This Graph contains a trained state after $train(). Read-only.

  • pipeops :: named list of PipeOp
    Contains all PipeOps in the underlying Graph, named by the PipeOp's $ids. Shortcut for $graph_model$pipeops. See Graph for details.

  • edges :: data.table with columns src_id (character), src_channel (character), dst_id (character), dst_channel (character)
    Table of connections between the PipeOps in the underlying Graph. Shortcut for $graph$edges. See Graph for details.

  • param_set :: ParamSet
    Parameters of the underlying Graph. Shortcut for $graph$param_set. See Graph for details.

  • pipeops_param_set :: named list()
    Named list containing the ParamSets of all PipeOps in the Graph. See there for details.

  • pipeops_param_set_values :: named list()
    Named list containing the set parameter values of all PipeOps in the Graph. See there for details.

  • internal_tuned_values :: named list() or NULL
    The internal tuned parameter values collected from all PipeOps. NULL is returned if the learner is not trained or none of the wrapped learners supports internal tuning.

  • internal_valid_scores :: named list() or NULL
    The internal validation scores as retrieved from the PipeOps. The names are prefixed with the respective IDs of the PipeOps. NULL is returned if the learner is not trained or none of the wrapped learners supports internal validation.

  • validate :: numeric(1), "predefined", "test" or NULL
    How to construct the validation data. This also has to be configured for the individual PipeOps such as PipeOpLearner, see set_validate.GraphLearner. For more details on the possible values, see mlr3::Learner.

  • marshaled :: logical(1)
    Whether the learner is marshaled.

  • impute_selected_features :: logical(1)
    Whether to heuristically determine $selected_features() as all $selected_features() of all "base learner" Learners, even if they do not have the "selected_features" property / do not implement $selected_features(). If impute_selected_features is TRUE and the base learners do not implement $selected_features(), the GraphLearner's $selected_features() method will return all features seen by the base learners. This is useful in cases where feature selection is performed inside the Graph: The $selected_features() will then be the set of features that were selected by the Graph. If impute_selected_features is FALSE, the $selected_features() method will throw an error if $selected_features() is not implemented by the base learners.
    This is a heuristic and may report more features than actually used by the base learners, in cases where the base learners do not implement $selected_features(). The default is FALSE.

Methods

Methods inherited from Learner, as well as:

  • ids(sorted = FALSE)
    (logical(1)) -> character
    Get IDs of all PipeOps. This is in order that PipeOps were added if sorted is FALSE, and topologically sorted if sorted is TRUE.

  • plot(html = FALSE, horizontal = FALSE)
    (logical(1), logical(1)) -> NULL
    Plot the Graph, using either the igraph package (for html = FALSE, default) or the visNetwork package (for html = TRUE), which produces an htmlWidget. The htmlWidget can be rescaled using visOptions. For html = FALSE, the orientation of the plotted graph can be controlled through horizontal.

  • marshal
    (any) -> self
    Marshal the model.

  • unmarshal
    (any) -> self
    Unmarshal the model.

  • base_learner(recursive = Inf, return_po = FALSE, return_all = FALSE, resolve_branching = TRUE)
    (numeric(1), logical(1), logical(1), logical(1)) -> Learner | PipeOp | list of Learner | list of PipeOp
    Return the base learner of the GraphLearner. If recursive is 0, the GraphLearner itself is returned. Otherwise, the Graph is traversed backwards to find the first PipeOp containing a $learner_model field. If recursive is 1, that $learner_model (or containing PipeOp, if return_po is TRUE) is returned. If recursive is greater than 1, the discovered base learner's base_learner() method is called with recursive - 1. recursive must be set to 1 if return_po is TRUE, and must be set to at most 1 if return_all is TRUE.
    If return_po is TRUE, the container-PipeOp is returned instead of the Learner. This will typically be a PipeOpLearner or a PipeOpLearnerCV.
    If return_all is TRUE, a list of Learners or PipeOps is returned. If return_po is FALSE, this list may contain Multiplicity objects, which are not unwrapped. If return_all is FALSE and there are multiple possible base learners, an error is thrown. This may also happen if only a single PipeOpLearner is present that was trained with a Multiplicity.
    If resolve_branching is TRUE, and when a PipeOpUnbranch is encountered, the corresponding PipeOpBranch is searched, and its hyperparameter configuration is used to select the base learner. There may be multiple corresponding PipeOpBranchs, which are all considered. If resolve_branching is FALSE, PipeOpUnbranch is treated as any other PipeOp with multiple inputs; all possible branch paths are considered equally.
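The methods above can be tried on a simple linear Graph. The following is a minimal sketch, assuming an installed version of mlr3pipelines that provides the return_po argument documented here; the Graph is only an example.

```r
library("mlr3")
library("mlr3pipelines")

gl = as_learner(po("pca") %>>% lrn("classif.rpart"))
gl$train(tsk("iris"))

# IDs of all PipeOps in the wrapped Graph
gl$ids()

# The Learner found by traversing the Graph backwards
bl = gl$base_learner()
class(bl)[1]

# The PipeOp that contains the base learner instead of the Learner itself;
# recursive must be 1 when return_po is TRUE
po_bl = gl$base_learner(recursive = 1, return_po = TRUE)
class(po_bl)[1]
```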

The following standard extractors as defined by the Learner class are available. Note that these typically only extract information from the $base_learner(). This works well for simple Graphs that do not modify features too much, but may give unexpected results for Graphs that add new features or move information between features.

As an example, consider a feature A with missing values, and a feature B that is used for imputation, using a po("imputelearner"). In a case where the following Learner performs embedded feature selection and only selects feature A, $selected_features() could return only feature A, and $importance() may even report 0 for feature B. This would not be entirely accurate when considering the entire GraphLearner, as feature B is used for imputation and would therefore have an impact on predictions. The following should therefore only be used if the Graph is known to not have an impact on the relevant properties.

  • importance()
    () -> numeric
    The $importance() returned by the base learner, if it has the "importance" property. Throws an error otherwise.

  • selected_features()
    () -> character
    The $selected_features() returned by the base learner, if it has the "selected_features" property. If the base learner does not have the "selected_features" property and impute_selected_features is TRUE, all features seen by the base learners are returned. Throws an error otherwise.

  • oob_error()
    () -> numeric(1)
    The $oob_error() returned by the base learner, if it has the "oob_error" property. Throws an error otherwise.

  • loglik()
    () -> numeric(1)
    The $loglik() returned by the base learner, if it has the "loglik" property. Throws an error otherwise.
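The caveat about extractors can be seen in a small sketch: with a PCA step in front of the tree, the extractors report principal components rather than the original features of the task. This assumes that the rpart base learner provides the "importance" and "selected_features" properties, as it does in current mlr3 versions.

```r
library("mlr3")
library("mlr3pipelines")

gl = as_learner(po("pca") %>>% lrn("classif.rpart"))
gl$train(tsk("iris"))

# Both extractors are answered by the base learner, which only sees the
# principal components produced by the PCA step
imp = gl$importance()
imp

sel = gl$selected_features()
sel
```

The names refer to the rotated features, not to the columns of the iris task, which is why these values should be interpreted with the whole Graph in mind.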

Internals

as_graph() is called on the graph argument, so it can technically also be a list of things, which is automatically converted to a Graph via gunion(); however, this will usually not result in a valid Graph that can work as a Learner. graph can furthermore be a Learner, which is then automatically wrapped in a Graph, which is then again wrapped in a GraphLearner object; this usually only adds overhead and is not recommended.

See also

Other Learners: mlr_learners_avg

Examples

library("mlr3")
library("mlr3pipelines")  # provides po(), %>>% and GraphLearner

graph = po("pca") %>>% lrn("classif.rpart")

lr = GraphLearner$new(graph)
lr = as_learner(graph)  # equivalent

lr$train(tsk("iris"))

lr$graph$state  # untrained version!
#> $pca
#> NULL
#> 
#> $classif.rpart
#> NULL
#> 
# The following is therefore NULL:
lr$graph$pipeops$classif.rpart$learner_model$model
#> NULL

# To access the trained model from the PipeOpLearner's Learner, use:
lr$graph_model$pipeops$classif.rpart$learner_model$model
#> n= 150 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
#>   2) PC1< -1.553145 50   0 setosa (1.00000000 0.00000000 0.00000000) *
#>   3) PC1>=-1.553145 100  50 versicolor (0.00000000 0.50000000 0.50000000)  
#>     6) PC1< 1.142805 44   1 versicolor (0.00000000 0.97727273 0.02272727) *
#>     7) PC1>=1.142805 56   7 virginica (0.00000000 0.12500000 0.87500000) *

# Feature importance (of principal components):
lr$graph_model$pipeops$classif.rpart$learner_model$importance()
#>       PC1       PC2       PC3       PC4 
#> 85.795455 18.016529  5.694731  3.254132