Module artemis.interactions_methods.model_specific
Expand source code
from .gb_trees import SplitScoreMethod
from .random_forest import ConditionalMinimalDepthMethod
__all__ = ["SplitScoreMethod", "ConditionalMinimalDepthMethod"]
Sub-modules
artemis.interactions_methods.model_specific.gb_trees
artemis.interactions_methods.model_specific.random_forest
Classes
class ConditionalMinimalDepthMethod
-
Conditional Minimal Depth Method for Feature Interaction Extraction. It applies to tree-based models like Random Forests. Currently scikit-learn forest models are supported, i.e., RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, ExtraTreesClassifier.
Attributes
method
:str
- Method name, also used to name the results column in the `ovo` pd.DataFrame.
visualizer
:Visualizer
- Object providing visualization. Automatically created on the basis of a method and used to create visualizations.
ovo
:pd.DataFrame
- One versus one (pair) feature interaction values.
feature_importance
:pd.DataFrame
- Feature importance values.
model
:object
- Explained model.
features_included
:List[str]
- List of features for which interactions are calculated.
pairs
:List[List[str]]
- List of pairs of features for which interactions are calculated.
References
Constructor for ConditionalMinimalDepthMethod
Expand source code
class ConditionalMinimalDepthMethod(FeatureInteractionMethod):
    """
    Conditional Minimal Depth Method for Feature Interaction Extraction.
    It applies to tree-based models like Random Forests.
    Currently scikit-learn forest models are supported, i.e., RandomForestClassifier, RandomForestRegressor,
    ExtraTreesRegressor, ExtraTreesClassifier.

    Attributes
    ----------
    method : str
        Method name, used also for naming column with results in `ovo` pd.DataFrame.
    visualizer : Visualizer
        Object providing visualization. Automatically created on the basis of a method and used to create
        visualizations.
    ovo : pd.DataFrame
        One versus one (pair) feature interaction values.
    feature_importance : pd.DataFrame
        Feature importance values.
    model : object
        Explained model.
    features_included : List[str]
        List of features for which interactions are calculated.
    pairs : List[List[str]]
        List of pairs of features for which interactions are calculated.

    References
    ----------
    - https://modeloriented.github.io/randomForestExplainer/
    - https://doi.org/10.1198/jasa.2009.tm08622
    """

    def __init__(self):
        """Constructor for ConditionalMinimalDepthMethod"""
        super().__init__(InteractionMethod.CONDITIONAL_MINIMAL_DEPTH)

    @property
    def _interactions_ascending_order(self):
        """Interaction values of this method are always ordered ascending."""
        return True

    @property
    def _compare_ovo(self):
        """Undirected view of `ovo`, with both directions of each feature pair merged into one row.

        Rows describing the same unordered pair are grouped together; the interaction value is averaged
        and the feature names of the first row in the group are kept. The result is sorted by the
        interaction column.

        Raises
        ------
        MethodNotFittedException
            If `ovo` is None.
        """
        if self.ovo is None:
            raise MethodNotFittedException(self.method)
        compare_ovo = self.ovo.copy().rename(columns={"root_variable": "Feature 1", "variable": "Feature 2"})
        # Join the sorted names with a NUL separator: plain concatenation would give distinct pairs
        # such as ("ab", "c") and ("a", "bc") the same id and silently average them together.
        compare_ovo['id'] = compare_ovo[["Feature 1", "Feature 2"]].apply(lambda x: "\x00".join(sorted(x)), axis=1)
        return (compare_ovo.groupby("id")
                .agg({"Feature 1": "first", "Feature 2": "first", self.method: "mean"})
                .sort_values(InteractionMethod.CONDITIONAL_MINIMAL_DEPTH,
                             ascending=self._interactions_ascending_order,
                             ignore_index=True))

    def fit(
        self,
        model: Union[RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, ExtraTreesClassifier],
        show_progress: bool = False,
    ):
        """Calculates Conditional Minimal Depth Feature Interactions Strength and Minimal Depth Feature Importance
        for given model.

        Parameters
        ----------
        model : RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, or ExtraTreesClassifier
            Model to be explained. Should be fitted and of type RandomForestClassifier, RandomForestRegressor,
            ExtraTreesRegressor, or ExtraTreesClassifier.
        show_progress : bool
            If True, progress bar will be shown. Default is False.
        """
        self.features_included = model.feature_names_in_.tolist()
        self.pairs = list(combinations(self.features_included, 2))
        column_dict = _make_column_dict(model.feature_names_in_)
        # The helper returns the raw per-tree result together with the `trees` object,
        # which is reused below for the feature importance computation.
        self.raw_result_df, trees = _calculate_conditional_minimal_depths(model.estimators_,
                                                                          len(model.feature_names_in_),
                                                                          show_progress)
        self.ovo = _summarise_results(self.raw_result_df, column_dict, self.method,
                                      self._interactions_ascending_order)
        self._feature_importance_obj = MinimalDepthImportance()
        self.feature_importance = self._feature_importance_obj.importance(model, trees)

    def plot(self,
             vis_type: str = VisualizationType.HEATMAP,
             title: str = "default",
             figsize: Tuple[float, float] = (8, 6),
             **kwargs):
        """
        Plot results of explanations.

        There are five types of plots available:
        - heatmap - heatmap of feature interactions values with feature importance values on the diagonal (default)
        - bar_chart - bar chart of top feature interactions values
        - graph - graph of feature interactions values
        - summary - combination of heatmap, bar chart and graph plots
        - bar_chart_conditional - bar chart of top feature interactions with additional information about
          feature importance

        Parameters
        ----------
        vis_type : str
            Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'bar_chart_conditional', 'summary'].
            Default is 'heatmap'.
        title : str
            Title of plot, default is 'default' which means that title will be automatically generated
            for selected visualization type.
        figsize : (float, float)
            Size of plot. Default is (8, 6).
        **kwargs : Other Parameters
            Additional parameters for plot. Passed to suitable matplotlib or seaborn functions.
            For 'summary' visualization parameters for respective plots should be in dict with keys
            corresponding to visualization name. See key parameters below.

        Other Parameters
        ----------------
        interaction_color_map : matplotlib colormap name or object, or list of colors
            Used for 'heatmap' visualization. The mapping from interaction values to color space.
            Default is 'Purples' or 'Purples_r', depending on whether a greater value means a greater
            interaction strength or vice versa.
        importance_color_map : matplotlib colormap name or object, or list of colors
            Used for 'heatmap' visualization. The mapping from importance values to color space.
            Default is 'Greens' or 'Greens_r', depending on whether a greater value means a greater
            interaction strength or vice versa.
        annot_fmt : str
            Used for 'heatmap' visualization. String formatting code to use when adding annotations with values.
            Default is '.3f'.
        linewidths : float
            Used for 'heatmap' visualization. Width of the lines that will divide each cell in matrix.
            Default is 0.5.
        linecolor : str
            Used for 'heatmap' visualization. Color of the lines that will divide each cell in matrix.
            Default is 'white'.
        cbar_shrink : float
            Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar.
            Default is 1.
        top_k : int
            Used for 'bar_chart' visualization. Maximum number of pairs that will be presented in plot.
            Default is 10.
        color : str
            Used for 'bar_chart' visualization. Color of bars. Default is 'mediumpurple'.
        n_highest_with_labels : int
            Used for 'graph' visualization. Top most important interactions to show as labels on edges.
            Default is 5.
        edge_color : str
            Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
        node_color : str
            Used for 'graph' visualization. Color of nodes. Default is 'green'.
        node_size : int
            Used for 'graph' visualization. Size of the nodes (networkX scale). Default is 1800.
        font_color : str
            Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
        font_weight : str
            Used for 'graph' visualization. Font weight. Default is 'bold'.
        font_size : int
            Used for 'graph' visualization. Font size (networkX scale). Default is 10.
        threshold_relevant_interaction : float
            Used for 'graph' visualization. Minimum (or maximum, depends on method) value of interaction
            to display corresponding edge on visualization. Default depends on the interaction method.
        top_k : int
            Used for 'bar_chart_conditional' visualization. Maximum number of pairs that will be presented
            in plot. Default is 15.
        cmap : matplotlib colormap name or object
            Used for 'bar_chart_conditional' visualization. The mapping from number of pair occurrences
            to color space. Default is 'Purples'.
        color : str
            Used for 'bar_chart_conditional' visualization. Color of lollipops for parent features.
            Default is 'black'.

        Raises
        ------
        MethodNotFittedException
            If `ovo` is None.
        """
        if self.ovo is None:
            raise MethodNotFittedException(self.method)
        # `ovo` is directed for this method: columns are root_variable -> variable.
        self.visualizer.plot(self.ovo,
                             vis_type,
                             _feature_column_name_1="root_variable",
                             _feature_column_name_2="variable",
                             _directed=True,
                             feature_importance=self.feature_importance,
                             title=title,
                             figsize=figsize,
                             interactions_ascending_order=self._interactions_ascending_order,
                             importance_ascending_order=self._feature_importance_obj.importance_ascending_order,
                             **kwargs)
Ancestors
- artemis.interactions_methods._method.FeatureInteractionMethod
- abc.ABC
Methods
def fit(self, model: Union[sklearn.ensemble._forest.RandomForestClassifier, sklearn.ensemble._forest.RandomForestRegressor, sklearn.ensemble._forest.ExtraTreesRegressor, sklearn.ensemble._forest.ExtraTreesClassifier], show_progress: bool = False)
-
Calculates Conditional Minimal Depth Feature Interactions Strength and Minimal Depth Feature Importance for given model.
Parameters
model
:RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, or ExtraTreesClassifier
- Model to be explained. Should be fitted and of type RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, or ExtraTreesClassifier.
show_progress
:bool
- If True, progress bar will be shown. Default is False.
Expand source code
def fit(
    self,
    model: Union[RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, ExtraTreesClassifier],
    show_progress: bool = False,
):
    """Compute Conditional Minimal Depth interaction strength and Minimal Depth feature importance.

    Parameters
    ----------
    model : RandomForestClassifier, RandomForestRegressor, ExtraTreesRegressor, or ExtraTreesClassifier
        Fitted forest model to explain.
    show_progress : bool
        Whether to display a progress bar. Default is False.
    """
    feature_names = model.feature_names_in_

    # Features and feature pairs covered by the explanation.
    self.features_included = feature_names.tolist()
    self.pairs = list(combinations(self.features_included, 2))

    # Raw result plus the `trees` object that the importance step reuses.
    column_dict = _make_column_dict(feature_names)
    depths_df, trees = _calculate_conditional_minimal_depths(
        model.estimators_,
        len(feature_names),
        show_progress,
    )
    self.raw_result_df = depths_df
    self.ovo = _summarise_results(
        depths_df,
        column_dict,
        self.method,
        self._interactions_ascending_order,
    )

    importance_obj = MinimalDepthImportance()
    self._feature_importance_obj = importance_obj
    self.feature_importance = importance_obj.importance(model, trees)
def plot(self, vis_type: str = 'heatmap', title: str = 'default', figsize: Tuple[float, float] = (8, 6), **kwargs)
-
Plot results of explanations.
There are five types of plots available:
- heatmap - heatmap of feature interactions values with feature importance values on the diagonal (default)
- bar_chart - bar chart of top feature interactions values
- graph - graph of feature interactions values
- summary - combination of heatmap, bar chart and graph plots
- bar_chart_conditional - bar chart of top feature interactions with additional information about feature importance
Parameters
vis_type
:str
- Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'bar_chart_conditional', 'summary']. Default is 'heatmap'.
title
:str
- Title of plot, default is 'default' which means that title will be automatically generated for selected visualization type.
figsize
:(float, float)
- Size of plot. Default is (8, 6).
**kwargs
:Other Parameters
- Additional parameters for plot. Passed to suitable matplotlib or seaborn functions. For 'summary' visualization parameters for respective plots should be in dict with keys corresponding to visualization name. See key parameters below.
Other Parameters
interaction_color_map
:matplotlib colormap name or object, or list of colors
- Used for 'heatmap' visualization. The mapping from interaction values to color space. Default is 'Purples' or 'Purples_r', depending on whether a greater value means a greater interaction strength or vice versa.
importance_color_map
:matplotlib colormap name or object, or list of colors
- Used for 'heatmap' visualization. The mapping from importance values to color space. Default is 'Greens' or 'Greens_r', depending on whether a greater value means a greater interaction strength or vice versa.
annot_fmt
:str
- Used for 'heatmap' visualization. String formatting code to use when adding annotations with values. Default is '.3f'.
linewidths
:float
- Used for 'heatmap' visualization. Width of the lines that will divide each cell in matrix. Default is 0.5.
linecolor
:str
- Used for 'heatmap' visualization. Color of the lines that will divide each cell in matrix. Default is 'white'.
cbar_shrink
:float
- Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar. Default is 1.
top_k
:int
- Used for 'bar_chart' visualization. Maximum number of pairs that will be presented in plot. Default is 10.
color
:str
- Used for 'bar_chart' visualization. Color of bars. Default is 'mediumpurple'.
n_highest_with_labels
:int
- Used for 'graph' visualization. Top most important interactions to show as labels on edges. Default is 5.
edge_color
:str
- Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
node_color
:str
- Used for 'graph' visualization. Color of nodes. Default is 'green'.
node_size
:int
- Used for 'graph' visualization. Size of the nodes (networkX scale). Default is 1800.
font_color
:str
- Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
font_weight
:str
- Used for 'graph' visualization. Font weight. Default is 'bold'.
font_size
:int
- Used for 'graph' visualization. Font size (networkX scale). Default is 10.
threshold_relevant_interaction
:float
- Used for 'graph' visualization. Minimum (or maximum, depends on method) value of interaction to display corresponding edge on visualization. Default depends on the interaction method.
top_k
:int
- Used for 'bar_chart_conditional' visualization. Maximum number of pairs that will be presented in plot. Default is 15.
cmap
:matplotlib colormap name or object
- Used for 'bar_chart_conditional' visualization. The mapping from number of pair occurrences to color space. Default is 'Purples'.
color
:str
- Used for 'bar_chart_conditional' visualization. Color of lollipops for parent features. Default is 'black'.
Expand source code
def plot(self,
         vis_type: str = VisualizationType.HEATMAP,
         title: str = "default",
         figsize: Tuple[float, float] = (8, 6),
         **kwargs):
    """
    Draw the fitted explanation with the selected visualization.

    Supported visualizations:
    - heatmap - interaction values as a heatmap, with feature importance on the diagonal (default)
    - bar_chart - bar chart of the top interaction values
    - graph - graph of interaction values
    - summary - heatmap, bar chart and graph combined
    - bar_chart_conditional - bar chart of the top interactions, extended with feature importance information

    Parameters
    ----------
    vis_type : str
        Visualization to draw: 'heatmap', 'bar_chart', 'graph', 'bar_chart_conditional' or 'summary'.
        Default is 'heatmap'.
    title : str
        Plot title. The default, 'default', requests a title generated automatically for the selected
        visualization.
    figsize : (float, float)
        Size of plot. Default is (8, 6).
    **kwargs : Other Parameters
        Extra plot options, passed on to the suitable matplotlib or seaborn functions. For 'summary',
        options of each component plot go in a dict keyed by that visualization's name.
        The key options are listed below, grouped by visualization.

    Other Parameters
    ----------------
    interaction_color_map : matplotlib colormap name or object, or list of colors
        ('heatmap') Mapping from interaction values to color space. Default is 'Purples' or 'Purples_r',
        depending on whether a greater value means a greater interaction strength or vice versa.
    importance_color_map : matplotlib colormap name or object, or list of colors
        ('heatmap') Mapping from importance values to color space. Default is 'Greens' or 'Greens_r',
        depending on whether a greater value means a greater interaction strength or vice versa.
    annot_fmt : str
        ('heatmap') String formatting code for value annotations. Default is '.3f'.
    linewidths : float
        ('heatmap') Width of the lines dividing matrix cells. Default is 0.5.
    linecolor : str
        ('heatmap') Color of the lines dividing matrix cells. Default is 'white'.
    cbar_shrink : float
        ('heatmap') Fraction by which the colorbar size is multiplied. Default is 1.
    top_k : int
        ('bar_chart') Maximum number of pairs presented. Default is 10.
    color : str
        ('bar_chart') Color of bars. Default is 'mediumpurple'.
    n_highest_with_labels : int
        ('graph') Number of top interactions shown as edge labels. Default is 5.
    edge_color : str
        ('graph') Color of the edges. Default is 'rebeccapurple'.
    node_color : str
        ('graph') Color of nodes. Default is 'green'.
    node_size : int
        ('graph') Size of the nodes (networkX scale). Default is 1800.
    font_color : str
        ('graph') Font color. Default is '#3B1F2B'.
    font_weight : str
        ('graph') Font weight. Default is 'bold'.
    font_size : int
        ('graph') Font size (networkX scale). Default is 10.
    threshold_relevant_interaction : float
        ('graph') Minimum (or maximum, depends on method) interaction value for an edge to be displayed.
        Default depends on the interaction method.
    top_k : int
        ('bar_chart_conditional') Maximum number of pairs presented. Default is 15.
    cmap : matplotlib colormap name or object
        ('bar_chart_conditional') Mapping from number of pair occurrences to color space.
        Default is 'Purples'.
    color : str
        ('bar_chart_conditional') Color of lollipops for parent features. Default is 'black'.

    Raises
    ------
    MethodNotFittedException
        If `ovo` is None.
    """
    interactions = self.ovo
    if interactions is None:
        raise MethodNotFittedException(self.method)

    # Interactions are directed here: root_variable -> variable.
    self.visualizer.plot(
        interactions,
        vis_type,
        _feature_column_name_1="root_variable",
        _feature_column_name_2="variable",
        _directed=True,
        feature_importance=self.feature_importance,
        title=title,
        figsize=figsize,
        interactions_ascending_order=self._interactions_ascending_order,
        importance_ascending_order=self._feature_importance_obj.importance_ascending_order,
        **kwargs,
    )
class SplitScoreMethod
-
Split Score Method for Feature Interaction Extraction. It applies to gradient boosting tree-based models. Currently models from LightGBM and XGBoost packages are supported.
Strength of interaction is defined by the metric selected by user (default is sum of gains).
Attributes
method
:str
- Method name, also used to name the results column in the `ovo` pd.DataFrame.
visualizer
:Visualizer
- Object providing visualization. Automatically created on the basis of a method and used to create visualizations.
ovo
:pd.DataFrame
- One versus one (pair) feature interaction values.
feature_importance
:pd.DataFrame
- Feature importance values.
model
:object
- Explained model.
metric
:str
- Metric used to calculate strength of interactions.
features_included
:List[str]
- List of features for which interactions are calculated.
pairs
:List[List[str]]
- List of pairs of features for which interactions are calculated.
References
Constructor for SplitScoreMethod
Expand source code
class SplitScoreMethod(FeatureInteractionMethod):
    """
    Split Score Method for Feature Interaction Extraction.
    It applies to gradient boosting tree-based models.
    Currently models from LightGBM and XGBoost packages are supported.

    Strength of interaction is defined by the metric selected by user (default is mean of gains).

    Attributes
    ----------
    method : str
        Method name, used also for naming column with results in `ovo` pd.DataFrame.
    visualizer : Visualizer
        Object providing visualization. Automatically created on the basis of a method and used to create
        visualizations.
    ovo : pd.DataFrame
        One versus one (pair) feature interaction values.
    feature_importance : pd.DataFrame
        Feature importance values.
    model : object
        Explained model.
    metric : str
        Metric used to calculate strength of interactions.
    features_included : List[str]
        List of features for which interactions are calculated.
    pairs : List[List[str]]
        List of pairs of features for which interactions are calculated.

    References
    ----------
    - https://modeloriented.github.io/EIX/
    """

    def __init__(self):
        """Constructor for SplitScoreMethod"""
        super().__init__(InteractionMethod.SPLIT_SCORE)
        # Set by `fit`; None until the method has been fitted.
        self.metric = None

    @property
    def _interactions_ascending_order(self):
        """True when the currently selected metric is one of the ascending-order metrics."""
        return self.metric in _ASCENDING_ORDER_METRICS

    def fit(
        self,
        model: GBTreesHandler,
        show_progress: bool = False,
        interaction_selected_metric: str = SplitScoreInteractionMetric.MEAN_GAIN,
        importance_selected_metric: str = SplitScoreImportanceMetric.MEAN_GAIN,
        only_def_interactions: bool = True,
    ):
        """Calculates Split Score Feature Interactions Strength and Split Score Feature Importance for given model.

        Parameters
        ----------
        model : GBTreesHandler
            Model to be explained. Should be fitted and of type GBTreesHandler (otherwise it will be converted).
        show_progress : bool
            If True, progress bar will be shown. Default is False.
        interaction_selected_metric : str
            Metric used to calculate strength of interaction,
            one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth'].
            Default is 'mean_gain'.
        importance_selected_metric : str
            Metric used to calculate feature importance,
            one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth',
            'mean_weighted_depth', 'root_frequency', 'weighted_root_frequency'].
            Default is 'mean_gain'.
        only_def_interactions : bool
            Whether to return only pair of sequential features that fulfill the definition of interaction
            (with better split score for child feature). Default is True.
        """
        if not isinstance(model, GBTreesHandler):
            model = GBTreesHandler(model)
        # Check the metrics before touching instance state: if the check rejects them, a previously
        # fitted `ovo` must not be left paired with a new metric (which drives the sort order).
        _check_metrics_with_available_info(model.package, interaction_selected_metric, importance_selected_metric)
        self.metric = interaction_selected_metric
        self.full_result = _calculate_full_result(
            model.trees_df, model.package, show_progress
        )
        self.full_ovo = _get_summary(self.full_result, only_def_interactions)
        self.ovo = _get_ovo(self, self.full_ovo, interaction_selected_metric)

        # calculate feature importance
        self._feature_importance_obj = SplitScoreImportance()
        self.feature_importance = self._feature_importance_obj.importance(
            model=model,
            selected_metric=importance_selected_metric,
            trees_df=model.trees_df,
        )

    def plot(self,
             vis_type: str = VisualizationType.HEATMAP,
             title: str = "default",
             figsize: Tuple[float, float] = (8, 6),
             **kwargs):
        """
        Plot results of explanations.

        There are five types of plots available:
        - heatmap - heatmap of feature interactions values with feature importance values on the diagonal (default)
        - bar_chart - bar chart of top feature interactions values
        - graph - graph of feature interactions values
        - summary - combination of heatmap, bar chart and graph plots
        - lolliplot - lolliplot for first k decision trees with split scores values.

        Parameters
        ----------
        vis_type : str
            Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'lolliplot', 'summary'].
            Default is 'heatmap'.
        title : str
            Title of plot, default is 'default' which means that title will be automatically generated
            for selected visualization type.
        figsize : (float, float)
            Size of plot. Default is (8, 6).
        **kwargs : Other Parameters
            Additional parameters for plot. Passed to suitable matplotlib or seaborn functions.
            For 'summary' visualization parameters for respective plots should be in dict with keys
            corresponding to visualization name. See key parameters below.

        Other Parameters
        ----------------
        interaction_color_map : matplotlib colormap name or object, or list of colors
            Used for 'heatmap' visualization. The mapping from interaction values to color space.
            Default is 'Purples' or 'Purples_r', depending on whether a greater value means a greater
            interaction strength or vice versa.
        importance_color_map : matplotlib colormap name or object, or list of colors
            Used for 'heatmap' visualization. The mapping from importance values to color space.
            Default is 'Greens' or 'Greens_r', depending on whether a greater value means a greater
            interaction strength or vice versa.
        annot_fmt : str
            Used for 'heatmap' visualization. String formatting code to use when adding annotations with values.
            Default is '.3f'.
        linewidths : float
            Used for 'heatmap' visualization. Width of the lines that will divide each cell in matrix.
            Default is 0.5.
        linecolor : str
            Used for 'heatmap' visualization. Color of the lines that will divide each cell in matrix.
            Default is 'white'.
        cbar_shrink : float
            Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar.
            Default is 1.
        top_k : int
            Used for 'bar_chart' visualization. Maximum number of pairs that will be presented in plot.
            Default is 10.
        color : str
            Used for 'bar_chart' visualization. Color of bars. Default is 'mediumpurple'.
        n_highest_with_labels : int
            Used for 'graph' visualization. Top most important interactions to show as labels on edges.
            Default is 5.
        edge_color : str
            Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
        node_color : str
            Used for 'graph' visualization. Color of nodes. Default is 'green'.
        node_size : int
            Used for 'graph' visualization. Size of the nodes (networkX scale). Default is 1800.
        font_color : str
            Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
        font_weight : str
            Used for 'graph' visualization. Font weight. Default is 'bold'.
        font_size : int
            Used for 'graph' visualization. Font size (networkX scale). Default is 10.
        threshold_relevant_interaction : float
            Used for 'graph' visualization. Minimum (or maximum, depends on method) value of interaction
            to display corresponding edge on visualization. Default depends on the interaction method.
        max_trees : float
            Used for 'lolliplot' visualization. Fraction of trees that will be presented in plot.
            Default is 0.2.
        colors : List[str]
            Used for 'lolliplot' visualization. List of colors for nodes with successive depths.
            Default is ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33"].
        shapes : List[str]
            Used for 'lolliplot' visualization. List of shapes for nodes with successive depths.
            Default is ["o", ",", "v", "^", "<", ">"].
        max_depth : int
            Used for 'lolliplot' visualization. Threshold for depth of nodes that will be presented in plot.
            Default is 1.
        label_threshold : float
            Used for 'lolliplot' visualization. Threshold for fraction of score of nodes that will be labeled
            in plot. Default is 0.1.
        labels : bool
            Used for 'lolliplot' visualization. Whether to add labels to plot. Default is True.
        scale : str
            Used for 'lolliplot' visualization. Scale for x axis (trees). Default is 'linear'.

        Raises
        ------
        MethodNotFittedException
            If `ovo` is None.
        """
        if self.ovo is None:
            raise MethodNotFittedException(self.method)
        self.visualizer.plot(self.ovo,
                             vis_type,
                             feature_importance=self.feature_importance,
                             title=title,
                             figsize=figsize,
                             interactions_ascending_order=self._interactions_ascending_order,
                             importance_ascending_order=self._feature_importance_obj.importance_ascending_order,
                             _full_result=self.full_result,
                             **kwargs)
Ancestors
- artemis.interactions_methods._method.FeatureInteractionMethod
- abc.ABC
Methods
def fit(self, model: artemis._utilities._handler.GBTreesHandler, show_progress: bool = False, interaction_selected_metric: str = 'mean_gain', importance_selected_metric: str = 'mean_gain', only_def_interactions: bool = True)
-
Calculates Split Score Feature Interactions Strength and Split Score Feature Importance for given model.
Parameters
model
:GBTreesHandler
- Model to be explained. Should be fitted and of type GBTreesHandler (otherwise it will be converted).
show_progress
:bool
- If True, progress bar will be shown. Default is False.
interaction_selected_metric
:str
- Metric used to calculate strength of interaction, one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth']. Default is 'mean_gain'.
importance_selected_metric
:str
- Metric used to calculate feature importance, one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth', 'mean_weighted_depth', 'root_frequency', 'weighted_root_frequency']. Default is 'mean_gain'.
only_def_interactions
:bool
- Whether to return only pair of sequential features that fulfill the definition of interaction (with better split score for child feature).
Expand source code
def fit(
    self,
    model: GBTreesHandler,
    show_progress: bool = False,
    interaction_selected_metric: str = SplitScoreInteractionMetric.MEAN_GAIN,
    importance_selected_metric: str = SplitScoreImportanceMetric.MEAN_GAIN,
    only_def_interactions: bool = True,
):
    """Calculates Split Score Feature Interactions Strength and Split Score Feature Importance for given model.

    Parameters
    ----------
    model : GBTreesHandler
        Model to be explained. Should be fitted and of type GBTreesHandler (otherwise it will be converted).
    show_progress : bool
        If True, progress bar will be shown. Default is False.
    interaction_selected_metric : str
        Metric used to calculate strength of interaction,
        one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth'].
        Default is 'mean_gain'.
    importance_selected_metric : str
        Metric used to calculate feature importance,
        one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth',
        'mean_weighted_depth', 'root_frequency', 'weighted_root_frequency'].
        Default is 'mean_gain'.
    only_def_interactions : bool
        Whether to return only pair of sequential features that fulfill the definition of interaction
        (with better split score for child feature). Default is True.
    """
    if not isinstance(model, GBTreesHandler):
        model = GBTreesHandler(model)
    # Check the metrics before touching instance state: if the check rejects them, a previously
    # fitted `ovo` must not be left paired with a new metric (which drives the sort order).
    _check_metrics_with_available_info(model.package, interaction_selected_metric, importance_selected_metric)
    self.metric = interaction_selected_metric
    self.full_result = _calculate_full_result(
        model.trees_df, model.package, show_progress
    )
    self.full_ovo = _get_summary(self.full_result, only_def_interactions)
    self.ovo = _get_ovo(self, self.full_ovo, interaction_selected_metric)

    # calculate feature importance
    self._feature_importance_obj = SplitScoreImportance()
    self.feature_importance = self._feature_importance_obj.importance(
        model=model,
        selected_metric=importance_selected_metric,
        trees_df=model.trees_df,
    )
def plot(self, vis_type: str = 'heatmap', title: str = 'default', figsize: Tuple[float, float] = (8, 6), **kwargs)
-
Plot results of explanations.
There are five types of plots available: - heatmap - heatmap of feature interactions values with feature importance values on the diagonal (default) - bar_chart - bar chart of top feature interactions values - graph - graph of feature interactions values - summary - combination of heatmap, bar chart and graph plots - lolliplot - lolliplot for first k decision trees with split scores values.
Parameters
vis_type
:str
- Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'lolliplot', 'summary']. Default is 'heatmap'.
title
:str
- Title of plot, default is 'default' which means that title will be automatically generated for selected visualization type.
figsize
:(float, float)
- Size of plot. Default is (8, 6).
**kwargs
:Other Parameters
- Additional parameters for plot. Passed to suitable matplotlib or seaborn functions. For 'summary' visualization parameters for respective plots should be in dict with keys corresponding to visualization name. See key parameters below.
Other Parameters
interaction_color_map
:matplotlib colormap name
or object,
or list
of colors
- Used for 'heatmap' visualization. The mapping from interaction values to color space. Default is 'Purples' or 'Purples_r', depending on whether a greater value means a greater interaction strength or vice versa.
importance_color_map
:matplotlib colormap name
or object,
or list
of colors
- Used for 'heatmap' visualization. The mapping from importance values to color space. Default is 'Greens' or 'Greens_r', depending on whether a greater value means a greater feature importance or vice versa.
annot_fmt
:str
- Used for 'heatmap' visualization. String formatting code to use when adding annotations with values. Default is '.3f'.
linewidths
:float
- Used for 'heatmap' visualization. Width of the lines that will divide each cell in matrix. Default is 0.5.
linecolor
:str
- Used for 'heatmap' visualization. Color of the lines that will divide each cell in matrix. Default is 'white'.
cbar_shrink
:float
- Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar. Default is 1.
top_k
:int
- Used for 'bar_chart' visualization. Maximum number of pairs that will be presented in plot. Default is 10.
color
:str
- Used for 'bar_chart' visualization. Color of bars. Default is 'mediumpurple'.
n_highest_with_labels
:int
- Used for 'graph' visualization. Top most important interactions to show as labels on edges. Default is 5.
edge_color
:str
- Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
node_color
:str
- Used for 'graph' visualization. Color of nodes. Default is 'green'.
node_size
:int
- Used for 'graph' visualization. Size of the nodes (networkX scale). Default is 1800.
font_color
:str
- Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
font_weight
:str
- Used for 'graph' visualization. Font weight. Default is 'bold'.
font_size
:int
- Used for 'graph' visualization. Font size (networkX scale). Default is 10.
threshold_relevant_interaction
:float
- Used for 'graph' visualization. Minimum (or maximum, depends on method) value of interaction to display corresponding edge on visualization. Default depends on the interaction method.
max_trees
:float
- Used for 'lolliplot' visualization. Fraction of trees that will be presented in plot. Default is 0.2.
colors
:List[str]
- Used for 'lolliplot' visualization. List of colors for nodes with successive depths. Default is ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33"].
shapes
:List[str]
- Used for 'lolliplot' visualization. List of shapes for nodes with successive depths. Default is ["o", ",", "v", "^", "<", ">"].
max_depth
:int
- Used for 'lolliplot' visualization. Threshold for depth of nodes that will be presented in plot. Default is 1.
label_threshold
:float
- Used for 'lolliplot' visualization. Threshold for fraction of score of nodes that will be labeled in plot. Default is 0.1.
labels
:bool
- Used for 'lolliplot' visualization. Whether to add labels to plot. Default is True.
scale
:str
- Used for 'lolliplot' visualization. Scale for x axis (trees). Default is 'linear'.
Expand source code
def plot(
    self,
    vis_type: str = VisualizationType.HEATMAP,
    title: str = "default",
    figsize: Tuple[float, float] = (8, 6),
    **kwargs,
):
    """Plot results of explanations.

    There are five types of plots available:
    - heatmap - heatmap of feature interactions values with feature importance values on the diagonal (default)
    - bar_chart - bar chart of top feature interactions values
    - graph - graph of feature interactions values
    - summary - combination of heatmap, bar chart and graph plots
    - lolliplot - lolliplot for first k decision trees with split scores values.

    Parameters
    ----------
    vis_type : str
        Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'lolliplot', 'summary'].
        Default is 'heatmap'.
    title : str
        Title of plot, default is 'default' which means that title will be automatically generated
        for selected visualization type.
    figsize : (float, float)
        Size of plot. Default is (8, 6).
    **kwargs : Other Parameters
        Additional parameters for plot. Passed to suitable matplotlib or seaborn functions.
        For 'summary' visualization parameters for respective plots should be in dict with keys
        corresponding to visualization name. See key parameters below.

    Other Parameters
    ----------------
    interaction_color_map : matplotlib colormap name or object, or list of colors
        Used for 'heatmap' visualization. The mapping from interaction values to color space.
        Default is 'Purples' or 'Purples_r', depending on whether a greater value means a greater
        interaction strength or vice versa.
    importance_color_map : matplotlib colormap name or object, or list of colors
        Used for 'heatmap' visualization. The mapping from importance values to color space.
        Default is 'Greens' or 'Greens_r', depending on whether a greater value means a greater
        feature importance or vice versa.
    annot_fmt : str
        Used for 'heatmap' visualization. String formatting code to use when adding annotations
        with values. Default is '.3f'.
    linewidths : float
        Used for 'heatmap' visualization. Width of the lines that will divide each cell in matrix.
        Default is 0.5.
    linecolor : str
        Used for 'heatmap' visualization. Color of the lines that will divide each cell in matrix.
        Default is 'white'.
    cbar_shrink : float
        Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar.
        Default is 1.
    top_k : int
        Used for 'bar_chart' visualization. Maximum number of pairs that will be presented in plot.
        Default is 10.
    color : str
        Used for 'bar_chart' visualization. Color of bars. Default is 'mediumpurple'.
    n_highest_with_labels : int
        Used for 'graph' visualization. Top most important interactions to show as labels on edges.
        Default is 5.
    edge_color : str
        Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
    node_color : str
        Used for 'graph' visualization. Color of nodes. Default is 'green'.
    node_size : int
        Used for 'graph' visualization. Size of the nodes (networkX scale). Default is 1800.
    font_color : str
        Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
    font_weight : str
        Used for 'graph' visualization. Font weight. Default is 'bold'.
    font_size : int
        Used for 'graph' visualization. Font size (networkX scale). Default is 10.
    threshold_relevant_interaction : float
        Used for 'graph' visualization. Minimum (or maximum, depends on method) value of interaction
        to display corresponding edge on visualization. Default depends on the interaction method.
    max_trees : float
        Used for 'lolliplot' visualization. Fraction of trees that will be presented in plot.
        Default is 0.2.
    colors : List[str]
        Used for 'lolliplot' visualization. List of colors for nodes with successive depths.
        Default is ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33"].
    shapes : List[str]
        Used for 'lolliplot' visualization. List of shapes for nodes with successive depths.
        Default is ["o", ",", "v", "^", "<", ">"].
    max_depth : int
        Used for 'lolliplot' visualization. Threshold for depth of nodes that will be presented
        in plot. Default is 1.
    label_threshold : float
        Used for 'lolliplot' visualization. Threshold for fraction of score of nodes that will be
        labeled in plot. Default is 0.1.
    labels : bool
        Used for 'lolliplot' visualization. Whether to add labels to plot. Default is True.
    scale : str
        Used for 'lolliplot' visualization. Scale for x axis (trees). Default is 'linear'.

    Raises
    ------
    MethodNotFittedException
        If ``ovo`` is None, i.e. no interaction values are available to plot.
    """
    interactions = self.ovo
    if interactions is None:
        raise MethodNotFittedException(self.method)

    # Delegate drawing to the visualizer; the full per-split result is forwarded
    # under the `_full_result` keyword alongside the caller's extra options.
    self.visualizer.plot(
        interactions,
        vis_type,
        feature_importance=self.feature_importance,
        title=title,
        figsize=figsize,
        interactions_ascending_order=self._interactions_ascending_order,
        importance_ascending_order=self._feature_importance_obj.importance_ascending_order,
        _full_result=self.full_result,
        **kwargs,
    )