Module artemis.interactions_methods.model_specific

Expand source code
from .gb_trees import SplitScoreMethod
from .random_forest import ConditionalMinimalDepthMethod

# Public API of this subpackage: the two model-specific interaction methods imported above.
__all__ = ["SplitScoreMethod", "ConditionalMinimalDepthMethod"]

Sub-modules

artemis.interactions_methods.model_specific.gb_trees
artemis.interactions_methods.model_specific.random_forest

Classes

class ConditionalMinimalDepthMethod

Conditional Minimal Depth Method for Feature Interaction Extraction. It applies to tree-based models like Random Forests. Currently scikit-learn forest models are supported, i.e., RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, ExtraTreesClassifier.

Attributes

method : str
Method name, used also for naming column with results in ovo pd.DataFrame.
visualizer : Visualizer
Object providing visualization. Automatically created on the basis of a method and used to create visualizations.
ovo : pd.DataFrame
One versus one (pair) feature interaction values.
feature_importance : pd.DataFrame
Feature importance values.
model : object
Explained model.
features_included : List[str]
List of features for which interactions are calculated.
pairs : List[Tuple[str, str]]
List of pairs of features for which interactions are calculated.

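References

- https://modeloriented.github.io/randomForestExplainer/
- https://doi.org/10.1198/jasa.2009.tm08622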

Constructor for ConditionalMinimalDepthMethod

Expand source code
class ConditionalMinimalDepthMethod(FeatureInteractionMethod):
    """
    Conditional Minimal Depth Method for Feature Interaction Extraction.
    It applies to tree-based models like Random Forests.
    Currently scikit-learn forest models are supported, i.e., RandomForestClassifier, RandomForestRegressor,
    ExtraTreesRegressor, ExtraTreesClassifier.

    Attributes
    ----------
    method : str
        Method name, used also for naming column with results in `ovo` pd.DataFrame.
    visualizer : Visualizer
        Object providing visualization. Automatically created on the basis of a method and used to create visualizations.
    ovo : pd.DataFrame
        One versus one (pair) feature interaction values.
    feature_importance : pd.DataFrame
        Feature importance values.
    model : object
        Explained model.
    features_included: List[str]
        List of features for which interactions are calculated.
    pairs : List[Tuple[str, str]]
        List of pairs of features for which interactions are calculated.
    raw_result_df : pd.DataFrame
        Unsummarised results of the conditional minimal depth calculation (set by `fit`).

    References
    ----------
    - https://modeloriented.github.io/randomForestExplainer/
    - https://doi.org/10.1198/jasa.2009.tm08622
    """
    def __init__(self):
        """Constructor for ConditionalMinimalDepthMethod"""
        super().__init__(InteractionMethod.CONDITIONAL_MINIMAL_DEPTH)

    @property
    def _interactions_ascending_order(self):
        # Interaction values of this method are ranked in ascending order
        # (smaller values come first when sorting).
        return True

    @property
    def _compare_ovo(self):
        """Symmetric (undirected) view of `ovo` used for comparing methods.

        The directed pairs (root_variable -> variable) are collapsed so that
        (a, b) and (b, a) form one group, whose interaction value is the mean
        over both directions.

        Raises
        ------
        MethodNotFittedException
            If `fit` has not been called yet.
        """
        if self.ovo is None:
            raise MethodNotFittedException(self.method)
        compare_ovo = self.ovo.copy().rename(columns={"root_variable": "Feature 1", "variable": "Feature 2"})
        # Key each unordered pair by its two feature names in sorted order, kept in
        # separate columns. Concatenating the names into one string would merge
        # distinct pairs such as ("ab", "c") and ("a", "bc").
        sorted_pairs = [sorted(pair) for pair in zip(compare_ovo["Feature 1"], compare_ovo["Feature 2"])]
        compare_ovo["id_first"] = [first for first, _ in sorted_pairs]
        compare_ovo["id_second"] = [second for _, second in sorted_pairs]
        return (compare_ovo.groupby(["id_first", "id_second"])
                           .agg({"Feature 1": "first", "Feature 2": "first", self.method: "mean"})
                           .sort_values(self.method,
                                        ascending=self._interactions_ascending_order,
                                        ignore_index=True))


    def fit(
            self,
            model: Union[RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, ExtraTreesClassifier],
            show_progress: bool = False,
    ):
        """Calculates Conditional Minimal Depth Feature Interactions Strength and Minimal Depth Feature Importance for given model.

        Sets `features_included`, `pairs`, `raw_result_df`, `ovo` and `feature_importance`.

        Parameters
        ----------
        model : RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, or ExtraTreesClassifier
            Model to be explained. Should be fitted and of type RandomForestClassifier, RandomForestRegressor,
            ExtraTreesRegressor, or ExtraTreesClassifier.
        show_progress : bool
            If True, progress bar will be shown. Default is False.
        """
        # Feature names are taken from the fitted scikit-learn model.
        self.features_included = model.feature_names_in_.tolist()
        self.pairs = list(combinations(self.features_included, 2))
        column_dict = _make_column_dict(model.feature_names_in_)
        # The helper also returns per-tree data that is reused for the importance below.
        self.raw_result_df, trees = _calculate_conditional_minimal_depths(model.estimators_, len(model.feature_names_in_), show_progress)
        self.ovo = _summarise_results(self.raw_result_df, column_dict, self.method, self._interactions_ascending_order)
        self._feature_importance_obj = MinimalDepthImportance()
        self.feature_importance = self._feature_importance_obj.importance(model, trees)

    def plot(self, vis_type: str = VisualizationType.HEATMAP, title: str = "default", figsize: Tuple[float, float] = (8, 6), **kwargs):
        """
        Plot results of explanations.

        There are five types of plots available:
        - heatmap - heatmap of feature interactions values with feature importance values on the diagonal (default)
        - bar_chart - bar chart of top feature interactions values
        - graph - graph of feature interactions values
        - summary - combination of heatmap, bar chart and graph plots
        - bar_chart_conditional - bar chart of top feature interactions with additional information about feature importance

        Parameters
        ----------
        vis_type : str
            Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'bar_chart_conditional', 'summary']. Default is 'heatmap'.
        title : str
            Title of plot, default is 'default' which means that title will be automatically generated for selected visualization type.
        figsize : (float, float)
            Size of plot. Default is (8, 6).
        **kwargs : Other Parameters
            Additional parameters for plot. Passed to suitable matplotlib or seaborn functions.
            For 'summary' visualization parameters for respective plots should be in dict with keys corresponding to visualization name.
            See key parameters below.

        Other Parameters
        ----------------
        interaction_color_map : matplotlib colormap name or object, or list of colors
            Used for 'heatmap' visualization. The mapping from interaction values to color space. Default is 'Purples' or 'Purples_r',
            depending on whether a greater value means a greater interaction strength or vice versa.
        importance_color_map :  matplotlib colormap name or object, or list of colors
            Used for 'heatmap' visualization. The mapping from importance values to color space. Default is 'Greens' or 'Greens_r',
            depending on whether a greater value means a greater interaction strength or vice versa.
        annot_fmt : str
            Used for 'heatmap' visualization. String formatting code to use when adding annotations with values. Default is '.3f'.
        linewidths : float
            Used for 'heatmap' visualization. Width of the lines that will divide each cell in matrix. Default is 0.5.
        linecolor : str
            Used for 'heatmap' visualization. Color of the lines that will divide each cell in matrix. Default is 'white'.
        cbar_shrink : float
            Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar. Default is 1.

        top_k : int
            Used for 'bar_chart' visualization. Maximum number of pairs that will be presented in plot. Default is 10.
        color : str
            Used for 'bar_chart' visualization. Color of bars. Default is 'mediumpurple'.

        n_highest_with_labels : int
            Used for 'graph' visualization. Top most important interactions to show as labels on edges.  Default is 5.
        edge_color: str
            Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
        node_color: str
            Used for 'graph' visualization. Color of nodes. Default is 'green'.
        node_size: int
            Used for 'graph' visualization. Size of the nodes (networkX scale).  Default is '1800'.
        font_color: str
            Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
        font_weight: str
            Used for 'graph' visualization. Font weight. Default is 'bold'.
        font_size: int
            Used for 'graph' visualization. Font size (networkX scale). Default is 10.
        threshold_relevant_interaction : float
            Used for 'graph' visualization. Minimum (or maximum, depends on method) value of interaction to display
            corresponding edge on visualization. Default depends on the interaction method.

        top_k : int
            Used for 'bar_chart_conditional' visualization. Maximum number of pairs that will be presented in plot. Default is 15.
        cmap : matplotlib colormap name or object.
            Used for 'bar_chart_conditional' visualization. The mapping from number of pair occurrences to color space. Default is 'Purples'.
        color : str
            Used for 'bar_chart_conditional' visualization. Color of lollipops for parent features. Default is 'black'.

        Raises
        ------
        MethodNotFittedException
            If `fit` has not been called yet.
        """
        if self.ovo is None:
            raise MethodNotFittedException(self.method)

        # Pairs in `ovo` are directed (root_variable -> variable), hence `_directed=True`.
        self.visualizer.plot(self.ovo,
                             vis_type,
                             _feature_column_name_1="root_variable",
                             _feature_column_name_2="variable",
                             _directed=True,
                             feature_importance=self.feature_importance,
                             title=title,
                             figsize=figsize,
                             interactions_ascending_order=self._interactions_ascending_order,
                             importance_ascending_order=self._feature_importance_obj.importance_ascending_order,
                             **kwargs)

Ancestors

  • artemis.interactions_methods._method.FeatureInteractionMethod
  • abc.ABC

Methods

def fit(self, model: Union[sklearn.ensemble._forest.RandomForestClassifier, sklearn.ensemble._forest.RandomForestRegressor, sklearn.ensemble._forest.ExtraTreesRegressor, sklearn.ensemble._forest.ExtraTreesClassifier], show_progress: bool = False)

Calculates Conditional Minimal Depth Feature Interactions Strength and Minimal Depth Feature Importance for given model.

Parameters

model : RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, or ExtraTreesClassifier
Model to be explained. Should be fitted and of type RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, or ExtraTreesClassifier.
show_progress : bool
If True, progress bar will be shown. Default is False.
Expand source code
def fit(
        self,
        model: Union[RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, ExtraTreesClassifier],
        show_progress: bool = False,
):
    """Calculates Conditional Minimal Depth Feature Interactions Strength and Minimal Depth Feature Importance for given model.

    Sets `features_included`, `pairs`, `raw_result_df`, `ovo` and `feature_importance`.

    Parameters
    ----------
    model : RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, or ExtraTreesClassifier
        Model to be explained. Should be fitted and of type RandomForestClassifier, RandomForestRegressor,
        ExtraTreesRegressor, or ExtraTreesClassifier.
    show_progress : bool
        If True, progress bar will be shown. Default is False.
    """
    # Feature names are taken from the fitted scikit-learn model.
    self.features_included = model.feature_names_in_.tolist()
    self.pairs = list(combinations(self.features_included, 2))
    column_dict = _make_column_dict(model.feature_names_in_)
    # The helper also returns per-tree data that is reused for the importance below.
    self.raw_result_df, trees = _calculate_conditional_minimal_depths(model.estimators_, len(model.feature_names_in_), show_progress)
    self.ovo = _summarise_results(self.raw_result_df, column_dict, self.method, self._interactions_ascending_order)
    self._feature_importance_obj = MinimalDepthImportance()
    self.feature_importance = self._feature_importance_obj.importance(model, trees)
def plot(self, vis_type: str = 'heatmap', title: str = 'default', figsize: Tuple[float, float] = (8, 6), **kwargs)

Plot results of explanations.

There are five types of plots available:

- heatmap - heatmap of feature interactions values with feature importance values on the diagonal (default)
- bar_chart - bar chart of top feature interactions values
- graph - graph of feature interactions values
- summary - combination of heatmap, bar chart and graph plots
- bar_chart_conditional - bar chart of top feature interactions with additional information about feature importance

Parameters

vis_type : str
Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'bar_chart_conditional', 'summary']. Default is 'heatmap'.
title : str
Title of plot, default is 'default' which means that title will be automatically generated for selected visualization type.
figsize : (float, float)
Size of plot. Default is (8, 6).
**kwargs : Other Parameters
Additional parameters for plot. Passed to suitable matplotlib or seaborn functions. For 'summary' visualization parameters for respective plots should be in dict with keys corresponding to visualization name. See key parameters below.

Other Parameters

interaction_color_map : matplotlib colormap name or object, or list of colors
Used for 'heatmap' visualization. The mapping from interaction values to color space. Default is 'Purples' or 'Purples_r', depending on whether a greater value means a greater interaction strength or vice versa.
importance_color_map :  matplotlib colormap name or object, or list of colors
Used for 'heatmap' visualization. The mapping from importance values to color space. Default is 'Greens' or 'Greens_r', depending on whether a greater value means a greater interaction strength or vice versa.
annot_fmt : str
Used for 'heatmap' visualization. String formatting code to use when adding annotations with values. Default is '.3f'.
linewidths : float
Used for 'heatmap' visualization. Width of the lines that will divide each cell in matrix. Default is 0.5.
linecolor : str
Used for 'heatmap' visualization. Color of the lines that will divide each cell in matrix. Default is 'white'.
cbar_shrink : float
Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar. Default is 1.
top_k : int
Used for 'bar_chart' visualization. Maximum number of pairs that will be presented in plot. Default is 10.
color : str
Used for 'bar_chart' visualization. Color of bars. Default is 'mediumpurple'.
n_highest_with_labels : int
Used for 'graph' visualization. Top most important interactions to show as labels on edges. Default is 5.
edge_color : str
Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
node_color : str
Used for 'graph' visualization. Color of nodes. Default is 'green'.
node_size : int
Used for 'graph' visualization. Size of the nodes (networkX scale). Default is '1800'.
font_color : str
Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
font_weight : str
Used for 'graph' visualization. Font weight. Default is 'bold'.
font_size : int
Used for 'graph' visualization. Font size (networkX scale). Default is 10.
threshold_relevant_interaction : float
Used for 'graph' visualization. Minimum (or maximum, depends on method) value of interaction to display corresponding edge on visualization. Default depends on the interaction method.
top_k : int
Used for 'bar_chart_conditional' visualization. Maximum number of pairs that will be presented in plot. Default is 15.
cmap : matplotlib colormap name or object.
Used for 'bar_chart_conditional' visualization. The mapping from number of pair occurrences to color space. Default is 'Purples'.
color : str
Used for 'bar_chart_conditional' visualization. Color of lollipops for parent features. Default is 'black'.
Expand source code
def plot(self, vis_type: str = VisualizationType.HEATMAP, title: str = "default", figsize: Tuple[float, float] = (8, 6), **kwargs):
    """
    Plot results of explanations.

    There are five types of plots available:
    - heatmap - heatmap of feature interactions values with feature importance values on the diagonal (default)
    - bar_chart - bar chart of top feature interactions values
    - graph - graph of feature interactions values
    - summary - combination of heatmap, bar chart and graph plots
    - bar_chart_conditional - bar chart of top feature interactions with additional information about feature importance

    Parameters
    ----------
    vis_type : str
        Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'bar_chart_conditional', 'summary']. Default is 'heatmap'.
    title : str
        Title of plot, default is 'default' which means that title will be automatically generated for selected visualization type.
    figsize : (float, float)
        Size of plot. Default is (8, 6).
    **kwargs : Other Parameters
        Additional parameters for plot. Passed to suitable matplotlib or seaborn functions.
        For 'summary' visualization parameters for respective plots should be in dict with keys corresponding to visualization name.
        See key parameters below.

    Other Parameters
    ----------------
    interaction_color_map : matplotlib colormap name or object, or list of colors
        Used for 'heatmap' visualization. The mapping from interaction values to color space. Default is 'Purples' or 'Purples_r',
        depending on whether a greater value means a greater interaction strength or vice versa.
    importance_color_map :  matplotlib colormap name or object, or list of colors
        Used for 'heatmap' visualization. The mapping from importance values to color space. Default is 'Greens' or 'Greens_r',
        depending on whether a greater value means a greater interaction strength or vice versa.
    annot_fmt : str
        Used for 'heatmap' visualization. String formatting code to use when adding annotations with values. Default is '.3f'.
    linewidths : float
        Used for 'heatmap' visualization. Width of the lines that will divide each cell in matrix. Default is 0.5.
    linecolor : str
        Used for 'heatmap' visualization. Color of the lines that will divide each cell in matrix. Default is 'white'.
    cbar_shrink : float
        Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar. Default is 1.

    top_k : int
        Used for 'bar_chart' visualization. Maximum number of pairs that will be presented in plot. Default is 10.
    color : str
        Used for 'bar_chart' visualization. Color of bars. Default is 'mediumpurple'.

    n_highest_with_labels : int
        Used for 'graph' visualization. Top most important interactions to show as labels on edges.  Default is 5.
    edge_color: str
        Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
    node_color: str
        Used for 'graph' visualization. Color of nodes. Default is 'green'.
    node_size: int
        Used for 'graph' visualization. Size of the nodes (networkX scale).  Default is '1800'.
    font_color: str
        Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
    font_weight: str
        Used for 'graph' visualization. Font weight. Default is 'bold'.
    font_size: int
        Used for 'graph' visualization. Font size (networkX scale). Default is 10.
    threshold_relevant_interaction : float
        Used for 'graph' visualization. Minimum (or maximum, depends on method) value of interaction to display
        corresponding edge on visualization. Default depends on the interaction method.

    top_k : int
        Used for 'bar_chart_conditional' visualization. Maximum number of pairs that will be presented in plot. Default is 15.
    cmap : matplotlib colormap name or object.
        Used for 'bar_chart_conditional' visualization. The mapping from number of pair occurrences to color space. Default is 'Purples'.
    color : str
        Used for 'bar_chart_conditional' visualization. Color of lollipops for parent features. Default is 'black'.

    Raises
    ------
    MethodNotFittedException
        If `fit` has not been called yet.
    """
    if self.ovo is None:
        raise MethodNotFittedException(self.method)

    # Pairs in `ovo` are directed (root_variable -> variable), hence `_directed=True`.
    self.visualizer.plot(self.ovo,
                         vis_type,
                         _feature_column_name_1="root_variable",
                         _feature_column_name_2="variable",
                         _directed=True,
                         feature_importance=self.feature_importance,
                         title=title,
                         figsize=figsize,
                         interactions_ascending_order=self._interactions_ascending_order,
                         importance_ascending_order=self._feature_importance_obj.importance_ascending_order,
                         **kwargs)
class SplitScoreMethod

Split Score Method for Feature Interaction Extraction. It applies to gradient boosting tree-based models. Currently, models from the LightGBM and XGBoost packages are supported.

Strength of interaction is defined by the metric selected by user (default is sum of gains).

Attributes

method : str
Method name, used also for naming column with results in ovo pd.DataFrame.
visualizer : Visualizer
Object providing visualization. Automatically created on the basis of a method and used to create visualizations.
ovo : pd.DataFrame
One versus one (pair) feature interaction values.
feature_importance : pd.DataFrame
Feature importance values.
model : object
Explained model.
metric : str
Metric used to calculate strength of interactions.
features_included : List[str]
List of features for which interactions are calculated.
pairs : List[List[str]]
List of pairs of features for which interactions are calculated.

References

- https://modeloriented.github.io/EIX/

Constructor for SplitScoreMethod

Expand source code
class SplitScoreMethod(FeatureInteractionMethod):
    """
    Split Score Method for Feature Interaction Extraction.
    It applies to gradient boosting tree-based models.
    Currently, models from the LightGBM and XGBoost packages are supported.

    Strength of interaction is defined by the metric selected by user (default is sum of gains).

    Attributes
    ----------
    method : str
        Method name, used also for naming column with results in `ovo` pd.DataFrame.
    visualizer : Visualizer
        Object providing visualization. Automatically created on the basis of a method and used to create visualizations.
    ovo : pd.DataFrame
        One versus one (pair) feature interaction values.
    feature_importance : pd.DataFrame
        Feature importance values.
    model : object
        Explained model.
    metric : str
        Metric used to calculate strength of interactions.
    features_included: List[str]
        List of features for which interactions are calculated.
    pairs : List[List[str]]
        List of pairs of features for which interactions are calculated.
    full_result : pd.DataFrame
        Full split-score results computed from the model's trees (set by `fit`, used by 'lolliplot').
    full_ovo : pd.DataFrame
        Summary of `full_result` from which `ovo` is derived (set by `fit`).

    References
    ----------
    - https://modeloriented.github.io/EIX/
    """
    def __init__(self):
        """Constructor for SplitScoreMethod"""
        super().__init__(InteractionMethod.SPLIT_SCORE)
        self.metric = None

    @property
    def _interactions_ascending_order(self):
        # The sort direction depends on the selected interaction metric.
        return self.metric in _ASCENDING_ORDER_METRICS

    def fit(
        self,
        model: GBTreesHandler,
        show_progress: bool = False,
        interaction_selected_metric: str = SplitScoreInteractionMetric.MEAN_GAIN,
        importance_selected_metric: str = SplitScoreImportanceMetric.MEAN_GAIN,
        only_def_interactions: bool = True,
    ):
        """Calculates Split Score Feature Interactions Strength and Split Score Feature Importance for given model.

        Sets `metric`, `full_result`, `full_ovo`, `ovo` and `feature_importance`.

        Parameters
        ----------
        model : GBTreesHandler
            Model to be explained. Should be fitted and of type GBTreesHandler (otherwise it will be converted).
        show_progress : bool
            If True, progress bar will be shown. Default is False.
        interaction_selected_metric : str
            Metric used to calculate strength of interaction,
            one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth']. Default is 'mean_gain'.
        importance_selected_metric : str
            Metric used to calculate feature importance,
            one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth',
            'mean_weighted_depth', 'root_frequency', 'weighted_root_frequency'].
            Default is 'mean_gain'.
        only_def_interactions : bool
            Whether to return only pair of sequential features that fulfill the definition of interaction
            (with better split score for child feature).
        """
        if not isinstance(model, GBTreesHandler):
            model = GBTreesHandler(model)
        # Validate the requested metrics before touching any state. Assigning
        # `self.metric` first would leave it changed when validation fails, which
        # flips `_interactions_ascending_order` for previously fitted results.
        _check_metrics_with_available_info(model.package, interaction_selected_metric, importance_selected_metric)
        self.metric = interaction_selected_metric
        self.full_result = _calculate_full_result(
            model.trees_df, model.package, show_progress
        )
        self.full_ovo = _get_summary(self.full_result, only_def_interactions)
        self.ovo = _get_ovo(self, self.full_ovo, interaction_selected_metric)

        # calculate feature importance
        self._feature_importance_obj = SplitScoreImportance()
        self.feature_importance = self._feature_importance_obj.importance(
            model=model,
            selected_metric=importance_selected_metric,
            trees_df=model.trees_df,
        )

    def plot(self, vis_type: str = VisualizationType.HEATMAP, title: str = "default", figsize: Tuple[float, float] = (8, 6), **kwargs):
        """
        Plot results of explanations.

        There are five types of plots available:
        - heatmap - heatmap of feature interactions values with feature importance values on the diagonal (default)
        - bar_chart - bar chart of top feature interactions values
        - graph - graph of feature interactions values
        - summary - combination of heatmap, bar chart and graph plots
        - lolliplot - lolliplot for first k decision trees with split scores values.

        Parameters
        ----------
        vis_type : str
            Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'lolliplot', 'summary']. Default is 'heatmap'.
        title : str
            Title of plot, default is 'default' which means that title will be automatically generated for selected visualization type.
        figsize : (float, float)
            Size of plot. Default is (8, 6).
        **kwargs : Other Parameters
            Additional parameters for plot. Passed to suitable matplotlib or seaborn functions.
            For 'summary' visualization parameters for respective plots should be in dict with keys corresponding to visualization name.
            See key parameters below.

        Other Parameters
        ----------------
        interaction_color_map : matplotlib colormap name or object, or list of colors
            Used for 'heatmap' visualization. The mapping from interaction values to color space. Default is 'Purples' or 'Purples_r',
            depending on whether a greater value means a greater interaction strength or vice versa.
        importance_color_map :  matplotlib colormap name or object, or list of colors
            Used for 'heatmap' visualization. The mapping from importance values to color space. Default is 'Greens' or 'Greens_r',
            depending on whether a greater value means a greater interaction strength or vice versa.
        annot_fmt : str
            Used for 'heatmap' visualization. String formatting code to use when adding annotations with values. Default is '.3f'.
        linewidths : float
            Used for 'heatmap' visualization. Width of the lines that will divide each cell in matrix. Default is 0.5.
        linecolor : str
            Used for 'heatmap' visualization. Color of the lines that will divide each cell in matrix. Default is 'white'.
        cbar_shrink : float
            Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar. Default is 1.

        top_k : int
            Used for 'bar_chart' visualization. Maximum number of pairs that will be presented in plot. Default is 10.
        color : str
            Used for 'bar_chart' visualization. Color of bars. Default is 'mediumpurple'.

        n_highest_with_labels : int
            Used for 'graph' visualization. Top most important interactions to show as labels on edges.  Default is 5.
        edge_color: str
            Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
        node_color: str
            Used for 'graph' visualization. Color of nodes. Default is 'green'.
        node_size: int
            Used for 'graph' visualization. Size of the nodes (networkX scale).  Default is '1800'.
        font_color: str
            Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
        font_weight: str
            Used for 'graph' visualization. Font weight. Default is 'bold'.
        font_size: int
            Used for 'graph' visualization. Font size (networkX scale). Default is 10.
        threshold_relevant_interaction : float
            Used for 'graph' visualization. Minimum (or maximum, depends on method) value of interaction to display
            corresponding edge on visualization. Default depends on the interaction method.

        max_trees : float
            Used for 'lolliplot' visualization. Fraction of trees that will be presented in plot. Default is 0.2.
        colors : List[str]
            Used for 'lolliplot' visualization. List of colors for nodes with successive depths.
            Default is ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33"].
        shapes : List[str]
            Used for 'lolliplot' visualization. List of shapes for nodes with successive depths.
            Default is ["o", ",", "v", "^", "<", ">"].
        max_depth : int
            Used for 'lolliplot' visualization. Threshold for depth of nodes that will be presented in plot. Default is 1.
        label_threshold : float
            Used for 'lolliplot' visualization. Threshold for fraction of score of nodes that will be labeled in plot. Default is 0.1.
        labels : bool
            Used for 'lolliplot' visualization. Whether to add labels to plot. Default is True.
        scale : str
            Used for 'lolliplot' visualization. Scale for x axis (trees). Default is 'linear'.

        Raises
        ------
        MethodNotFittedException
            If `fit` has not been called yet.
        """
        if self.ovo is None:
            raise MethodNotFittedException(self.method)
        # `full_result` is forwarded as well because the lolliplot is drawn from per-tree data.
        self.visualizer.plot(self.ovo,
                             vis_type,
                             feature_importance=self.feature_importance,
                             title=title,
                             figsize=figsize,
                             interactions_ascending_order=self._interactions_ascending_order,
                             importance_ascending_order=self._feature_importance_obj.importance_ascending_order,
                             _full_result=self.full_result,
                             **kwargs)

Ancestors

  • artemis.interactions_methods._method.FeatureInteractionMethod
  • abc.ABC

Methods

def fit(self, model: artemis._utilities._handler.GBTreesHandler, show_progress: bool = False, interaction_selected_metric: str = 'mean_gain', importance_selected_metric: str = 'mean_gain', only_def_interactions: bool = True)

Calculates Split Score feature interaction strength and Split Score feature importance for the given model.

Parameters

model : GBTreesHandler
Model to be explained. Should be fitted and of type GBTreesHandler (otherwise it will be converted).
show_progress : bool
If True, progress bar will be shown. Default is False.
interaction_selected_metric : str
Metric used to calculate strength of interaction, one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth']. Default is 'mean_gain'.
importance_selected_metric : str
Metric used to calculate feature importance, one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth', 'mean_weighted_depth', 'root_frequency', 'weighted_root_frequency']. Default is 'mean_gain'.
only_def_interactions : bool
Whether to return only pairs of sequential features that fulfill the definition of interaction (with a better split score for the child feature).
Expand source code
def fit(
    self,
    model: GBTreesHandler,
    show_progress: bool = False,
    interaction_selected_metric: str = SplitScoreInteractionMetric.MEAN_GAIN,
    importance_selected_metric: str = SplitScoreImportanceMetric.MEAN_GAIN,
    only_def_interactions: bool = True,
):
    """Calculates Split Score feature interaction strength and Split Score feature importance for the given model.

    Parameters
    ----------
    model : GBTreesHandler
        Model to be explained. Should be fitted and of type GBTreesHandler (otherwise it will be converted).
    show_progress : bool
        If True, progress bar will be shown. Default is False.
    interaction_selected_metric : str
        Metric used to calculate strength of interaction,
        one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth']. Default is 'mean_gain'.
    importance_selected_metric : str
        Metric used to calculate feature importance,
        one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth',
        'mean_weighted_depth', 'root_frequency', 'weighted_root_frequency'].
        Default is 'mean_gain'.
    only_def_interactions : bool
        Whether to return only pairs of sequential features that fulfill the definition of interaction
        (with a better split score for the child feature). Default is True.

    Notes
    -----
    The selected metrics are checked against the model's package before any attribute of this
    object is assigned, so if that check raises, state from a previous ``fit`` is left untouched.
    """
    if not isinstance(model, GBTreesHandler):
        model = GBTreesHandler(model)

    # Validate first: assigning self.metric before this check could leave the object
    # reporting a rejected metric alongside ovo results from an earlier fit.
    _check_metrics_with_available_info(model.package, interaction_selected_metric, importance_selected_metric)
    self.metric = interaction_selected_metric

    # Raw split-score results computed from the model's trees, then summarised into
    # pair-level (one-vs-one) interaction values for the selected metric.
    self.full_result = _calculate_full_result(
        model.trees_df, model.package, show_progress
    )
    self.full_ovo = _get_summary(self.full_result, only_def_interactions)
    self.ovo = _get_ovo(self, self.full_ovo, interaction_selected_metric)

    # calculate feature importance
    self._feature_importance_obj = SplitScoreImportance()
    self.feature_importance = self._feature_importance_obj.importance(
        model=model,
        selected_metric=importance_selected_metric,
        trees_df=model.trees_df,
    )
def plot(self, vis_type: str = 'heatmap', title: str = 'default', figsize: Tuple[float, float] = (8, 6), **kwargs)

Plot results of explanations.

There are five types of plots available:

- heatmap - heatmap of feature interaction values with feature importance values on the diagonal (default)
- bar_chart - bar chart of top feature interaction values
- graph - graph of feature interaction values
- summary - combination of heatmap, bar chart and graph plots
- lolliplot - lolliplot for the first k decision trees with split score values.

Parameters

vis_type : str
Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'lolliplot', 'summary']. Default is 'heatmap'.
title : str
Title of plot, default is 'default' which means that title will be automatically generated for selected visualization type.
figsize : (float, float)
Size of plot. Default is (8, 6).
**kwargs : Other Parameters
Additional parameters for plot. Passed to suitable matplotlib or seaborn functions. For 'summary' visualization parameters for respective plots should be in dict with keys corresponding to visualization name. See key parameters below.

Other Parameters

interaction_color_map : matplotlib colormap name or object, or list of colors
Used for 'heatmap' visualization. The mapping from interaction values to color space. Default is 'Purples' or 'Purples_r', depending on whether a greater value means a greater interaction strength or vice versa.
importance_color_map :  matplotlib colormap name or object, or list of colors
Used for 'heatmap' visualization. The mapping from importance values to color space. Default is 'Greens' or 'Greens_r', depending on whether a greater value means a greater feature importance or vice versa.
annot_fmt : str
Used for 'heatmap' visualization. String formatting code to use when adding annotations with values. Default is '.3f'.
linewidths : float
Used for 'heatmap' visualization. Width of the lines that will divide each cell in matrix. Default is 0.5.
linecolor : str
Used for 'heatmap' visualization. Color of the lines that will divide each cell in matrix. Default is 'white'.
cbar_shrink : float
Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar. Default is 1.
top_k : int
Used for 'bar_chart' visualization. Maximum number of pairs that will be presented in plot. Default is 10.
color : str
Used for 'bar_chart' visualization. Color of bars. Default is 'mediumpurple'.
n_highest_with_labels : int
Used for 'graph' visualization. Number of the most important interactions whose values are shown as labels on edges. Default is 5.
edge_color : str
Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
node_color : str
Used for 'graph' visualization. Color of nodes. Default is 'green'.
node_size : int
Used for 'graph' visualization. Size of the nodes (networkX scale). Default is 1800.
font_color : str
Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
font_weight : str
Used for 'graph' visualization. Font weight. Default is 'bold'.
font_size : int
Used for 'graph' visualization. Font size (networkX scale). Default is 10.
threshold_relevant_interaction : float
Used for 'graph' visualization. Minimum (or maximum, depends on method) value of interaction to display corresponding edge on visualization. Default depends on the interaction method.
max_trees : float
Used for 'lolliplot' visualization. Fraction of trees that will be presented in plot. Default is 0.2.
colors : List[str]
Used for 'lolliplot' visualization. List of colors for nodes with successive depths. Default is ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33"].
shapes : List[str]
Used for 'lolliplot' visualization. List of shapes for nodes with successive depths. Default is ["o", ",", "v", "^", "<", ">"].
max_depth : int
Used for 'lolliplot' visualization. Threshold for depth of nodes that will be presented in plot. Default is 1.
label_threshold : float
Used for 'lolliplot' visualization. Threshold for fraction of score of nodes that will be labeled in plot. Default is 0.1.
labels : bool
Used for 'lolliplot' visualization. Whether to add labels to plot. Default is True.
scale : str
Used for 'lolliplot' visualization. Scale for x axis (trees). Default is 'linear'.
Expand source code
def plot(self, vis_type: str = VisualizationType.HEATMAP, title: str = "default", figsize: Tuple[float, float] = (8, 6), **kwargs):
    """
    Plot results of explanations.

    There are five types of plots available:
    - heatmap - heatmap of feature interaction values with feature importance values on the diagonal (default)
    - bar_chart - bar chart of top feature interaction values
    - graph - graph of feature interaction values
    - summary - combination of heatmap, bar chart and graph plots
    - lolliplot - lolliplot for the first k decision trees with split score values.

    Parameters
    ----------
    vis_type : str
        Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'lolliplot', 'summary']. Default is 'heatmap'.
    title : str
        Title of plot, default is 'default' which means that title will be automatically generated for selected visualization type.
    figsize : (float, float)
        Size of plot. Default is (8, 6).
    **kwargs : Other Parameters
        Additional parameters for plot. Passed to suitable matplotlib or seaborn functions.
        For 'summary' visualization parameters for respective plots should be in dict with keys corresponding to visualization name.
        See key parameters below.

    Other Parameters
    ----------------
    interaction_color_map : matplotlib colormap name or object, or list of colors
        Used for 'heatmap' visualization. The mapping from interaction values to color space. Default is 'Purples' or 'Purples_r',
        depending on whether a greater value means a greater interaction strength or vice versa.
    importance_color_map : matplotlib colormap name or object, or list of colors
        Used for 'heatmap' visualization. The mapping from importance values to color space. Default is 'Greens' or 'Greens_r',
        depending on whether a greater value means a greater feature importance or vice versa.
    annot_fmt : str
        Used for 'heatmap' visualization. String formatting code to use when adding annotations with values. Default is '.3f'.
    linewidths : float
        Used for 'heatmap' visualization. Width of the lines that will divide each cell in matrix. Default is 0.5.
    linecolor : str
        Used for 'heatmap' visualization. Color of the lines that will divide each cell in matrix. Default is 'white'.
    cbar_shrink : float
        Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar. Default is 1.

    top_k : int
        Used for 'bar_chart' visualization. Maximum number of pairs that will be presented in plot. Default is 10.
    color : str
        Used for 'bar_chart' visualization. Color of bars. Default is 'mediumpurple'.

    n_highest_with_labels : int
        Used for 'graph' visualization. Number of the most important interactions whose values are shown as labels on edges. Default is 5.
    edge_color : str
        Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
    node_color : str
        Used for 'graph' visualization. Color of nodes. Default is 'green'.
    node_size : int
        Used for 'graph' visualization. Size of the nodes (networkX scale). Default is 1800.
    font_color : str
        Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
    font_weight : str
        Used for 'graph' visualization. Font weight. Default is 'bold'.
    font_size : int
        Used for 'graph' visualization. Font size (networkX scale). Default is 10.
    threshold_relevant_interaction : float
        Used for 'graph' visualization. Minimum (or maximum, depends on method) value of interaction to display
        corresponding edge on visualization. Default depends on the interaction method.

    max_trees : float
        Used for 'lolliplot' visualization. Fraction of trees that will be presented in plot. Default is 0.2.
    colors : List[str]
        Used for 'lolliplot' visualization. List of colors for nodes with successive depths.
        Default is ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33"].
    shapes : List[str]
        Used for 'lolliplot' visualization. List of shapes for nodes with successive depths.
        Default is ["o", ",", "v", "^", "<", ">"].
    max_depth : int
        Used for 'lolliplot' visualization. Threshold for depth of nodes that will be presented in plot. Default is 1.
    label_threshold : float
        Used for 'lolliplot' visualization. Threshold for fraction of score of nodes that will be labeled in plot. Default is 0.1.
    labels : bool
        Used for 'lolliplot' visualization. Whether to add labels to plot. Default is True.
    scale : str
        Used for 'lolliplot' visualization. Scale for x axis (trees). Default is 'linear'.

    Raises
    ------
    MethodNotFittedException
        If ``fit`` has not produced results yet (``self.ovo`` is None).
    """
    is_fitted = self.ovo is not None
    if not is_fitted:
        raise MethodNotFittedException(self.method)

    # Resolve the drawing callable and the interaction table in the same order the
    # call expression would, then gather the fitted-state options passed every time.
    draw = self.visualizer.plot
    interaction_table = self.ovo
    fitted_options = {
        "feature_importance": self.feature_importance,
        "title": title,
        "figsize": figsize,
        "interactions_ascending_order": self._interactions_ascending_order,
        "importance_ascending_order": self._feature_importance_obj.importance_ascending_order,
        "_full_result": self.full_result,
    }
    # Caller kwargs are forwarded alongside; a key clashing with fitted_options raises TypeError.
    draw(interaction_table, vis_type, **fitted_options, **kwargs)