Source code for gtda.mapper.visualization

"""Static and interactive visualisation functions for Mapper graphs."""
# License: GNU AGPLv3

import logging
import traceback

import numpy as np
import plotly.graph_objects as go
from ipywidgets import Layout, widgets
from sklearn.base import clone

from .utils._logging import OutputWidgetHandler
from .utils.visualization import (_calculate_graph_data,
                                  _get_column_color_buttons)


def plot_static_mapper_graph(
        pipeline, data, layout='kamada_kawai', layout_dim=2,
        color_variable=None, node_color_statistic=None,
        color_by_columns_dropdown=False, plotly_kwargs=None,
        clone_pipeline=True):
    """Plotting function for static Mapper graphs.

    Nodes are colored according to :attr:`color_variable`. By default, the
    hovertext displays a globally unique ID and the number of elements
    associated with a given node.

    Parameters
    ----------
    pipeline : :class:`~gtda.mapper.pipeline.MapperPipeline` object
        Mapper pipeline to act on data.

    data : array-like of shape (n_samples, n_features)
        Data used to generate the Mapper graph. Can be a pandas dataframe.

    layout : None, str or callable, optional, default: ``'kamada_kawai'``
        Layout algorithm for the graph. Can be any accepted value for the
        ``layout`` parameter in the :meth:`layout` method of
        :class:`igraph.Graph`. [1]_

    layout_dim : int, default: ``2``
        The number of dimensions for the layout. Can be 2 or 3.

    color_variable : object or None, optional, default: ``None``
        Specifies which quantity is to be used for node coloring.

            1. If a numpy ndarray or pandas dataframe, `color_variable`
               must have the same length as `data` and is interpreted as
               a quantity of interest according to which node of the Mapper
               graph is to be colored (see `node_color_statistic`).
            2. If ``None`` then equivalent to passing `data`.
            3. If an object implementing :meth:`transform` or
               :meth:`fit_transform`, e.g. a scikit-learn estimator or
               pipeline, it is applied to `data` to generate the quantity
               of interest.
            4. If an index or string, or list of indices / strings,
               equivalent to selecting a column or subset of columns from
               `data`.

    node_color_statistic : None, callable, or ndarray of shape (n_nodes,) or \
        (n_nodes, 1), optional, default: ``None``
        Specifies how to determine the colors of each node. If a numpy array,
        it must have the same length as the number of nodes in the Mapper
        graph, and its values are used directly for node coloring, ignoring
        `color_variable`. Otherwise, it can be a callable object which is used
        to obtain a summary statistic, within each Mapper node, of the
        quantity specified by `color_variable`. The default value ``None`` is
        equivalent to passing ``numpy.mean``.

    color_by_columns_dropdown : bool, optional, default: ``False``
        If ``True``, a dropdown widget is generated which allows the user to
        color Mapper nodes according to any column in `data`.

    plotly_kwargs : dict, optional, default: ``None``
        Keyword arguments to configure the Plotly Figure.

    clone_pipeline : bool, optional, default: ``True``
        If ``True``, the input `pipeline` is cloned before computing the
        Mapper graph to prevent unexpected side effects from in-place
        parameter updates.

    Returns
    -------
    fig : :class:`plotly.graph_objects.FigureWidget` object
        Figure representing the Mapper graph with appropriate node colouring
        and size.

    References
    ----------
    .. [1] `igraph.Graph.layout
            <https://igraph.org/python/doc/igraph.Graph-class.html#layout>`_
            documentation.

    """
    # Compute the graph and fetch the indices of points in each node
    if clone_pipeline:
        pipe = clone(pipeline)
    else:
        pipe = pipeline

    if node_color_statistic is not None:
        _node_color_statistic = node_color_statistic
    else:
        _node_color_statistic = np.mean

    # Simple duck typing to determine whether data is a pandas dataframe
    is_data_dataframe = hasattr(data, 'columns')

    node_trace, edge_trace, node_elements, _node_colors, plot_options = \
        _calculate_graph_data(
            pipe, data, layout, layout_dim,
            color_variable, _node_color_statistic, plotly_kwargs)

    # Define layout options that are common to 2D and 3D figures
    layout_options_common = go.Layout(
        showlegend=plot_options['layout_showlegend'],
        hovermode=plot_options['layout_hovermode'],
        margin=plot_options['layout_margin'],
        autosize=False
    )

    # A FigureWidget (rather than a plain Figure) is built so that the result
    # can be embedded in ipywidgets containers and updated in place.
    fig = go.FigureWidget(data=[edge_trace, node_trace],
                          layout=layout_options_common)

    if layout_dim == 2:
        layout_options_2d = {
            'layout_xaxis': plot_options['layout_xaxis'],
            'layout_xaxis_title': plot_options['layout_xaxis_title'],
            'layout_yaxis': plot_options['layout_yaxis'],
            'layout_yaxis_title': plot_options['layout_yaxis_title'],
            'layout_template': 'simple_white',
        }
        fig.update(layout_options_2d)

    elif layout_dim == 3:
        layout_options_3d = {
            'layout_scene': plot_options['layout_scene'],
            'layout_annotations': plot_options['layout_annotations'],
        }
        fig.update(layout_options_3d)

    # Compute node colors according to data columns only if necessary
    if color_by_columns_dropdown:
        column_color_buttons = _get_column_color_buttons(
            data, is_data_dataframe, node_elements, _node_colors,
            plot_options['node_trace_marker_colorscale'])
    else:
        column_color_buttons = None

    button_height = 1.1
    fig.update_layout(
        updatemenus=[
            go.layout.Updatemenu(
                buttons=column_color_buttons,
                direction="down",
                pad={"r": 10, "t": 10},
                showactive=True,
                x=0.11,
                xanchor='left',
                y=button_height,
                yanchor="top"
            ),
        ])

    if color_by_columns_dropdown:
        # Label placed just to the left of, and slightly below, the dropdown
        fig.add_annotation(
            go.layout.Annotation(text="Color by:", x=0, xref="paper",
                                 y=button_height - 0.045, yref="paper",
                                 align="left", showarrow=False)
        )

    return fig
def plot_interactive_mapper_graph(pipeline, data, layout='kamada_kawai',
                                  layout_dim=2, color_variable=None,
                                  node_color_statistic=None,
                                  color_by_columns_dropdown=False,
                                  plotly_kwargs=None):
    """Plotting function for interactive Mapper graphs.

    Provides functionality to interactively update parameters from the cover
    and clustering steps defined in :attr:`pipeline`. Nodes are colored
    according to :attr:`color_variable`. By default, the hovertext displays a
    globally unique ID and the number of elements associated with a given
    node.

    Parameters
    ----------
    pipeline : :class:`~gtda.mapper.pipeline.MapperPipeline` object
        Mapper pipeline to act on data.

    data : array-like of shape (n_samples, n_features)
        Data used to generate the Mapper graph. Can be a pandas dataframe.

    layout : None, str or callable, optional, default: ``'kamada_kawai'``
        Layout algorithm for the graph. Can be any accepted value for the
        ``layout`` parameter in the :meth:`layout` method of
        :class:`igraph.Graph`. [1]_

    layout_dim : int, default: ``2``
        The number of dimensions for the layout. Can be 2 or 3.

    color_variable : object or None, optional, default: ``None``
        Specifies which quantity is to be used for node coloring.

            1. If a numpy ndarray or pandas dataframe, `color_variable`
               must have the same length as `data` and is interpreted as
               a quantity of interest according to which node of the Mapper
               graph is to be colored (see `node_color_statistic`).
            2. If ``None`` then equivalent to passing `data`.
            3. If an object implementing :meth:`transform` or
               :meth:`fit_transform`, e.g. a scikit-learn estimator or
               pipeline, it is applied to `data` to generate the quantity
               of interest.
            4. If an index or string, or list of indices / strings,
               equivalent to selecting a column or subset of columns from
               `data`.

    node_color_statistic : None, callable, or ndarray of shape (n_nodes,) or \
        (n_nodes, 1), optional, default: ``None``
        Specifies how to determine the colors of each node. If a numpy array,
        it must have the same length as the number of nodes in the Mapper
        graph, and its values are used directly for node coloring, ignoring
        `color_variable`. Otherwise, it can be a callable object which is used
        to obtain a summary statistic, within each Mapper node, of the
        quantity specified by `color_variable`. The default value ``None`` is
        equivalent to passing ``numpy.mean``.

    color_by_columns_dropdown : bool, optional, default: ``False``
        If ``True``, a dropdown widget is generated which allows the user to
        color Mapper nodes according to any column in `data`.

    plotly_kwargs : dict, optional, default: ``None``
        Keyword arguments to configure the Plotly Figure.

    Returns
    -------
    box : :class:`ipywidgets.VBox` object
        A box containing the following widgets: parameters of the clustering
        algorithm, parameters for the covering scheme, a Mapper graph arising
        from those parameters, a validation box, and logs.

    References
    ----------
    .. [1] `igraph.Graph.layout
            <https://igraph.org/python/doc/igraph.Graph-class.html#layout>`_
            documentation.

    """
    # Clone pipeline to avoid side effects from in-place parameter changes
    pipe = clone(pipeline)

    if node_color_statistic is not None:
        _node_color_statistic = node_color_statistic
    else:
        _node_color_statistic = np.mean

    # Simple duck typing to determine whether data is a pandas dataframe
    is_data_dataframe = hasattr(data, 'columns')

    def get_widgets_per_param(param, value):
        """Return a ``(param, widget)`` pair for editable values, else None.

        Only bool, float, int and str values get a widget; all other
        parameters are left untouched by the interactive controls.
        """
        # Keep everything after the step prefix so that nested parameter
        # names (e.g. ``clusterer__a__b``) are not truncated.
        description = param.split('__', 1)[1]
        # ``bool`` is a subclass of ``int`` and must be tested first;
        # otherwise booleans would be edited as integers and written back
        # to the pipeline as 0/1.
        if isinstance(value, bool):
            return (param, widgets.Checkbox(
                value=value,
                description=description,
                disabled=False,
                indent=False
            ))
        elif isinstance(value, float):
            return (param, widgets.FloatText(
                value=value,
                step=0.05,
                description=description,
                continuous_update=False,
                disabled=False
            ))
        elif isinstance(value, int):
            return (param, widgets.IntText(
                value=value,
                step=1,
                description=description,
                continuous_update=False,
                disabled=False
            ))
        elif isinstance(value, str):
            return (param, widgets.Text(
                value=value,
                description=description,
                continuous_update=False,
                disabled=False
            ))
        else:
            return None

    def update_figure(figure, edge_trace, node_trace, layout_dim):
        """Copy coordinates, marker styling and hover info into `figure`."""
        figure.data[0].x = edge_trace.x
        figure.data[0].y = edge_trace.y
        figure.data[1].x = node_trace.x
        figure.data[1].y = node_trace.y

        if layout_dim == 3:
            figure.data[0].z = edge_trace.z
            figure.data[1].z = node_trace.z

        figure.data[1].marker.size = node_trace.marker.size
        figure.data[1].marker.color = node_trace.marker.color
        figure.data[1].marker.cmin = node_trace.marker.cmin
        figure.data[1].marker.cmax = node_trace.marker.cmax
        figure.data[1].marker.sizeref = node_trace.marker.sizeref
        figure.data[1].hoverlabel = node_trace.hoverlabel
        figure.data[1].hovertext = node_trace.hovertext

    def on_parameter_change(change):
        """Push widget values into `pipe` and recompute the figure.

        Any exception is logged to this widget's log output and flips the
        `valid` indicator to False instead of propagating.
        """
        handler.clear_logs()
        try:
            # The widget dicts contain exactly the editable parameters
            for param, widget in cover_params_widgets.items():
                pipe.set_params(**{param: widget.value})

            for param, widget in cluster_params_widgets.items():
                pipe.set_params(**{param: widget.value})

            logger.info("Updating figure ...")
            with fig.batch_update():
                (node_trace, edge_trace, node_elements, node_colors,
                 plot_options) = _calculate_graph_data(
                    pipe, data, layout, layout_dim,
                    color_variable, _node_color_statistic, plotly_kwargs
                )
                update_figure(fig, edge_trace, node_trace, layout_dim)

                # Update color by column buttons
                if color_by_columns_dropdown:
                    column_color_buttons = _get_column_color_buttons(
                        data, is_data_dataframe, node_elements, node_colors,
                        plot_options['node_trace_marker_colorscale'])
                else:
                    column_color_buttons = None

                button_height = 1.1
                fig.update_layout(
                    updatemenus=[
                        go.layout.Updatemenu(
                            buttons=column_color_buttons,
                            direction="down",
                            pad={"r": 10, "t": 10},
                            showactive=True,
                            x=0.11,
                            xanchor='left',
                            y=button_height,
                            yanchor="top"
                        ),
                    ])

            valid.value = True
        except Exception:
            exception_data = traceback.format_exc().splitlines()
            logger.exception(exception_data[-1])
            valid.value = False

    def observe_widgets(param_widgets):
        """Re-render the graph whenever any of `param_widgets` changes."""
        for widget in param_widgets.values():
            widget.observe(on_parameter_change, names='value')

    # Define output widget to capture logs
    out = widgets.Output()

    @out.capture()
    def click_box(change):
        if logs_box.value:
            out.clear_output()
            handler.show_logs()
        else:
            out.clear_output()

    # Initialise logging. Each interactive graph gets its own child logger:
    # attaching the handler to the shared module-level logger would make
    # handlers accumulate across calls, so every widget would display the
    # logs and exceptions of all the others. The logger keeps `handler`
    # alive, so ``id(handler)`` stays unique for the logger's lifetime.
    handler = OutputWidgetHandler()
    handler.setFormatter(logging.Formatter(
        '%(asctime)s - [%(levelname)s] %(message)s'))
    logger = logging.getLogger(__name__).getChild(str(id(handler)))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

    # Initialise cover and cluster dictionaries of parameters and widgets
    mapper_params = pipe.get_mapper_params()
    cover_params = dict(filter(lambda x: x[0].startswith('cover'),
                               mapper_params.items()))
    cover_params_widgets = dict(
        filter(None, map(lambda x: get_widgets_per_param(*x),
                         cover_params.items())))
    cluster_params = dict(filter(lambda x: x[0].startswith('clusterer'),
                                 mapper_params.items()))
    cluster_params_widgets = dict(
        filter(None, map(lambda x: get_widgets_per_param(*x),
                         cluster_params.items())))

    # Initialise widgets for validating input parameters of pipeline
    valid = widgets.Valid(
        value=True,
        description='Valid parameters',
        style={'description_width': '100px'},
    )

    # Initialise widget for showing the logs
    logs_box = widgets.Checkbox(
        description='Show logs: ',
        value=False,
        indent=False
    )

    # Initialise figure with initial pipeline and config
    if plotly_kwargs is None:
        plotly_kwargs = dict()

    fig = plot_static_mapper_graph(
        pipe, data, layout, layout_dim, color_variable, _node_color_statistic,
        color_by_columns_dropdown, plotly_kwargs, clone_pipeline=False)

    observe_widgets(cover_params_widgets)
    observe_widgets(cluster_params_widgets)

    logs_box.observe(click_box, names='value')

    # Define containers for input widgets
    container_cover = widgets.HBox(
        children=list(cover_params_widgets.values()))

    container_cluster_layout = Layout(display='flex', flex_flow='row wrap')

    container_cluster = widgets.HBox(
        children=list(cluster_params_widgets.values()),
        layout=container_cluster_layout)

    box = widgets.VBox(
        [container_cover, container_cluster, fig, valid, logs_box, out])
    return box