Source code for gtda.mapper.visualization

"""Static and interactive visualisation functions for Mapper graphs."""
# License: GNU AGPLv3

import logging
import traceback
from copy import deepcopy
from warnings import warn

import numpy as np
import plotly.graph_objects as go
from ipywidgets import Layout, widgets
from sklearn.base import clone

from .utils._logging import OutputWidgetHandler
from .utils._visualization import (
    _calculate_graph_data,
    _get_column_color_buttons,
    _get_colors_for_vals,
    PLOT_OPTIONS_LAYOUT_DEFAULTS
)


def plot_static_mapper_graph(
        pipeline, data, layout="kamada_kawai", layout_dim=2,
        color_variable=None, node_color_statistic=None,
        color_by_columns_dropdown=False, clone_pipeline=True, n_sig_figs=3,
        node_scale=12, plotly_params=None
):
    """Plotting function for static Mapper graphs.

    Nodes are colored according to `color_variable` and
    `node_color_statistic`. By default, the hovertext on each node displays a
    globally unique ID for the node, the number of data points associated with
    the node, and the summary statistic which determines its color.

    Parameters
    ----------
    pipeline : :class:`~gtda.mapper.pipeline.MapperPipeline` object
        Mapper pipeline to act on `data`.

    data : array-like of shape (n_samples, n_features)
        Data used to generate the Mapper graph. Can be a pandas dataframe.

    layout : None, str or callable, optional, default: ``"kamada_kawai"``
        Layout algorithm for the graph. Can be any accepted value for the
        ``layout`` parameter in the :meth:`layout` method of
        :class:`igraph.Graph`. [1]_

    layout_dim : int, default: ``2``
        The number of dimensions for the layout. Can be 2 or 3.

    color_variable : object or None, optional, default: ``None``
        Specifies a feature of interest to be used, together with
        `node_color_statistic`, to determine node colors.

            1. If a numpy array or pandas dataframe, it must have the same
               length as `data`.
            2. ``None`` is equivalent to passing `data`.
            3. If an object implementing :meth:`transform` or
               :meth:`fit_transform`, it is applied to `data` to generate the
               feature of interest.
            4. If an index or string, or list of indices/strings, it is
               equivalent to selecting a column or subset of columns from
               `data`.

    node_color_statistic : None, callable, or ndarray of shape (n_nodes,) or \
        (n_nodes, 1), optional, default: ``None``
        If a callable, node colors will be computed as summary statistics from
        the feature array ``Y`` determined by `color_variable` - specifically,
        the color of a node representing the entries of `data` whose row
        indices are in ``I`` will be ``node_color_statistic(Y[I])``. ``None``
        is equivalent to passing :func:`numpy.mean`. If a numpy array, it must
        have the same length as the number of nodes in the Mapper graph and
        its values are used directly as node colors (`color_variable` is
        ignored).

    color_by_columns_dropdown : bool, optional, default: ``False``
        If ``True``, a dropdown widget is generated which allows the user to
        color Mapper nodes according to any column in `data` (still using
        `node_color_statistic`) in addition to `color_variable`.

    clone_pipeline : bool, optional, default: ``True``
        If ``True``, the input `pipeline` is cloned before computing the
        Mapper graph to prevent unexpected side effects from in-place
        parameter updates.

    n_sig_figs : int or None, optional, default: ``3``
        If not ``None``, number of significant figures to which to round node
        summary statistics. If ``None``, no rounding is performed.

    node_scale : int or float, optional, default: ``12``
        Sets the scale factor used to determine the rendered size of the
        nodes. Increase for larger nodes. Implements a formula in the `Plotly
        documentation
        <https://plotly.com/python/bubble-charts/#scaling-the-size-of-bubble-charts>`_.

    plotly_params : dict or None, optional, default: ``None``
        Custom parameters to configure the plotly figure. Allowed keys are
        ``"node_trace"``, ``"edge_trace"`` and ``"layout"``, and the
        corresponding values should be dictionaries containing keyword
        arguments as would be fed to the :meth:`update_traces` and
        :meth:`update_layout` methods of
        :class:`plotly.graph_objects.Figure`. The input dictionary is deep
        copied and never modified.

    Returns
    -------
    fig : :class:`plotly.graph_objects.FigureWidget` object
        Figure representing the Mapper graph with appropriate node colouring
        and size.

    Examples
    --------
    Setting a colorscale different from the default one:

    >>> import numpy as np
    >>> from gtda.mapper import make_mapper_pipeline, plot_static_mapper_graph
    >>> pipeline = make_mapper_pipeline()
    >>> data = np.random.random((100, 3))
    >>> plotly_params = {"node_trace": {"marker_colorscale": "Blues"}}
    >>> fig = plot_static_mapper_graph(pipeline, data,
    ...                                plotly_params=plotly_params)

    See also
    --------
    :func:`~gtda.mapper.visualization.plot_interactive_mapper_graph`,
    :func:`~gtda.mapper.pipeline.make_mapper_pipeline`

    References
    ----------
    .. [1] `igraph.Graph.layout
            <https://igraph.org/python/doc/igraph.Graph-class.html#layout>`_
            documentation.

    """
    # Compute the graph and fetch the indices of points in each node
    _pipeline = clone(pipeline) if clone_pipeline else pipeline

    # An explicit ``None`` check is needed here: `node_color_statistic` may be
    # a numpy array, and ``array or np.mean`` raises ``ValueError`` because
    # the truth value of a multi-element array is ambiguous.
    if node_color_statistic is None:
        _node_color_statistic = np.mean
    else:
        _node_color_statistic = node_color_statistic

    # Simple duck typing to determine whether data is likely a pandas dataframe
    is_data_dataframe = hasattr(data, "columns")

    edge_trace, node_trace, node_elements, node_colors_color_variable = \
        _calculate_graph_data(
            _pipeline, data, is_data_dataframe, layout, layout_dim,
            color_variable, _node_color_statistic, n_sig_figs, node_scale
        )

    # Define layout options
    layout_options = go.Layout(
        **PLOT_OPTIONS_LAYOUT_DEFAULTS["common"],
        **PLOT_OPTIONS_LAYOUT_DEFAULTS[layout_dim]
    )

    # The node trace is the second trace, hence ``fig.data[1]`` below
    fig = go.FigureWidget(data=[edge_trace, node_trace], layout=layout_options)

    # Work on a copy since entries are popped as they are consumed
    _plotly_params = deepcopy(plotly_params)

    # When laying out the graph in 3D, plotly does not automatically give
    # the background hoverlabel the same color as the respective marker,
    # so we do this by hand here.
    # TODO: Extract logic so as to avoid repetitions in interactive version
    colorscale_for_hoverlabel = None
    if layout_dim == 3:
        compute_hoverlabel_bgcolor = True
        if _plotly_params and "node_trace" in _plotly_params:
            node_trace_params = _plotly_params["node_trace"]
            if "hoverlabel_bgcolor" in node_trace_params:
                # A user-provided background color takes precedence over the
                # data-dependent one
                fig.update_traces(
                    hoverlabel_bgcolor=node_trace_params.pop(
                        "hoverlabel_bgcolor"
                    ),
                    selector={"name": "node_trace"}
                )
                compute_hoverlabel_bgcolor = False
            if "marker_colorscale" in node_trace_params:
                # Apply the colorscale now so that the hoverlabel colors
                # computed below are consistent with it
                fig.update_traces(
                    marker_colorscale=node_trace_params.pop(
                        "marker_colorscale"
                    ),
                    selector={"name": "node_trace"}
                )

        if compute_hoverlabel_bgcolor:
            colorscale_for_hoverlabel = fig.data[1].marker.colorscale
            node_colors_color_variable = np.asarray(node_colors_color_variable)
            min_col = np.min(node_colors_color_variable)
            max_col = np.max(node_colors_color_variable)
            try:
                hoverlabel_bgcolor = _get_colors_for_vals(
                    node_colors_color_variable, min_col, max_col,
                    colorscale_for_hoverlabel
                )
            except Exception as e:
                # Best-effort feature: fall back to white hoverlabels instead
                # of failing. ``e.args`` may be empty, so guard the indexing.
                if e.args and e.args[0] == "This colorscale is not supported.":
                    warn("Data-dependent background hoverlabel colors cannot "
                         "be generated with this choice of colorscale. Please "
                         "use a standard hex- or RGB-formatted colorscale.")
                else:
                    warn("Something went wrong in generating data-dependent "
                         "background hoverlabel colors. All background "
                         "hoverlabel colors will be set to white.")
                hoverlabel_bgcolor = "white"
                colorscale_for_hoverlabel = None
            fig.update_traces(
                hoverlabel_bgcolor=hoverlabel_bgcolor,
                selector={"name": "node_trace"}
            )

    # Compute node colors according to data columns only if necessary
    if color_by_columns_dropdown:
        hovertext_color_variable = node_trace.hovertext
        column_color_buttons = _get_column_color_buttons(
            data, is_data_dataframe, node_elements, node_colors_color_variable,
            _node_color_statistic, hovertext_color_variable,
            colorscale_for_hoverlabel, n_sig_figs
        )
        # Avoid recomputing hoverlabel bgcolor for top button
        column_color_buttons[0]["args"][0]["hoverlabel.bgcolor"] = \
            [None, fig.data[1].hoverlabel.bgcolor]
    else:
        column_color_buttons = None

    button_height = 1.1
    fig.update_layout(
        updatemenus=[
            go.layout.Updatemenu(
                buttons=column_color_buttons,
                direction="down",
                pad={"r": 10, "t": 10},
                showactive=True,
                x=0.11,
                xanchor="left",
                y=button_height,
                yanchor="top"
            ),
        ])

    if color_by_columns_dropdown:
        fig.add_annotation(
            go.layout.Annotation(
                text="Color by:",
                x=0,
                xref="paper",
                y=button_height - 0.045,
                yref="paper",
                align="left",
                showarrow=False
            )
        )

    # Update traces and layout according to user input
    if _plotly_params:
        for key in ["node_trace", "edge_trace"]:
            fig.update_traces(
                _plotly_params.pop(key, None),
                selector={"name": key}
            )
        fig.update_layout(_plotly_params.pop("layout", None))

    return fig
def plot_interactive_mapper_graph(
        pipeline, data, layout="kamada_kawai", layout_dim=2,
        color_variable=None, node_color_statistic=None, clone_pipeline=True,
        color_by_columns_dropdown=False, n_sig_figs=3, node_scale=12,
        plotly_params=None
):
    """Plotting function for interactive Mapper graphs.

    Provides functionality to interactively update parameters from the cover
    and clustering steps defined in `pipeline`. Nodes are colored according to
    `color_variable` and `node_color_statistic`. By default, the hovertext on
    each node displays a globally unique ID for the node, the number of data
    points associated with the node, and the summary statistic which
    determines its color.

    Parameters
    ----------
    pipeline : :class:`~gtda.mapper.pipeline.MapperPipeline` object
        Mapper pipeline to act on `data`.

    data : array-like of shape (n_samples, n_features)
        Data used to generate the Mapper graph. Can be a pandas dataframe.

    layout : None, str or callable, optional, default: ``"kamada_kawai"``
        Layout algorithm for the graph. Can be any accepted value for the
        ``layout`` parameter in the :meth:`layout` method of
        :class:`igraph.Graph`. [1]_

    layout_dim : int, default: ``2``
        The number of dimensions for the layout. Can be 2 or 3.

    color_variable : object or None, optional, default: ``None``
        Specifies a feature of interest to be used, together with
        `node_color_statistic`, to determine node colors.

            1. If a numpy array or pandas dataframe, it must have the same
               length as `data`.
            2. ``None`` is equivalent to passing `data`.
            3. If an object implementing :meth:`transform` or
               :meth:`fit_transform`, it is applied to `data` to generate the
               feature of interest.
            4. If an index or string, or list of indices/strings, it is
               equivalent to selecting a column or subset of columns from
               `data`.

    node_color_statistic : callable or None, optional, default: ``None``
        If a callable, node colors will be computed as summary statistics from
        the feature array ``Y`` determined by `color_variable` - specifically,
        the color of a node representing the entries of `data` whose row
        indices are in ``I`` will be ``node_color_statistic(Y[I])``. ``None``
        is equivalent to passing :func:`numpy.mean`.

    clone_pipeline : bool, optional, default: ``True``
        If ``True``, the input `pipeline` is cloned before computing the
        Mapper graph to prevent unexpected side effects from in-place
        parameter updates.

    color_by_columns_dropdown : bool, optional, default: ``False``
        If ``True``, a dropdown widget is generated which allows the user to
        color Mapper nodes according to any column in `data` (still using
        `node_color_statistic`) in addition to `color_variable`.

    n_sig_figs : int or None, optional, default: ``3``
        If not ``None``, number of significant figures to which to round node
        summary statistics. If ``None``, no rounding is performed.

    node_scale : int or float, optional, default: ``12``
        Sets the scale factor used to determine the rendered size of the
        nodes. Increase for larger nodes. Implements a formula in the `Plotly
        documentation
        <https://plotly.com/python/bubble-charts/#scaling-the-size-of-bubble-charts>`_.

    plotly_params : dict or None, optional, default: ``None``
        Custom parameters to configure the plotly figure. Allowed keys are
        ``"node_trace"``, ``"edge_trace"`` and ``"layout"``, and the
        corresponding values should be dictionaries containing keyword
        arguments as would be fed to the :meth:`update_traces` and
        :meth:`update_layout` methods of
        :class:`plotly.graph_objects.Figure`.

    Returns
    -------
    box : :class:`ipywidgets.VBox` object
        A box containing the following widgets: parameters of the clustering
        algorithm, parameters for the covering scheme, a Mapper graph arising
        from those parameters, a validation box, and logs.

    See also
    --------
    :func:`~gtda.mapper.visualization.plot_static_mapper_graph`,
    :func:`~gtda.mapper.pipeline.make_mapper_pipeline`

    References
    ----------
    .. [1] `igraph.Graph.layout
            <https://igraph.org/python/doc/igraph.Graph-class.html#layout>`_
            documentation.

    """
    # Clone pipeline to avoid side effects from in-place parameter changes
    _pipeline = clone(pipeline) if clone_pipeline else pipeline

    # Explicit ``None`` check rather than relying on the truthiness of the
    # user-supplied object
    if node_color_statistic is None:
        _node_color_statistic = np.mean
    else:
        _node_color_statistic = node_color_statistic

    def get_widgets_per_param(param, value):
        """Return a ``(param, widget)`` pair for int, float or str values,
        and ``None`` for any other type of value."""
        # The widget label is the second "__"-separated token of the
        # parameter name
        if isinstance(value, float):
            return (param, widgets.FloatText(
                value=value,
                step=0.05,
                description=param.split("__")[1],
                continuous_update=False,
                disabled=False
            ))
        elif isinstance(value, int):
            return (param, widgets.IntText(
                value=value,
                step=1,
                description=param.split("__")[1],
                continuous_update=False,
                disabled=False
            ))
        elif isinstance(value, str):
            return (param, widgets.Text(
                value=value,
                description=param.split("__")[1],
                continuous_update=False,
                disabled=False
            ))
        else:
            return None

    def on_parameter_change(change):
        """Push the widget values into the pipeline, recompute the Mapper
        graph and refresh the figure in place. Any failure is logged and
        flagged through the validation widget instead of being raised."""
        handler.clear_logs()
        try:
            for param, value in cover_params.items():
                if isinstance(value, (int, float, str)):
                    _pipeline.set_params(
                        **{param: cover_params_widgets[param].value}
                    )
            for param, value in cluster_params.items():
                if isinstance(value, (int, float, str)):
                    _pipeline.set_params(
                        **{param: cluster_params_widgets[param].value}
                    )

            logger.info("Updating figure...")
            with fig.batch_update():
                (
                    edge_trace, node_trace, node_elements,
                    node_colors_color_variable
                ) = _calculate_graph_data(
                    _pipeline, data, is_data_dataframe, layout, layout_dim,
                    color_variable, _node_color_statistic, n_sig_figs,
                    node_scale
                )
                # Data-dependent hoverlabel colors are only recomputed when
                # a colorscale was stored for that purpose at initialisation
                if colorscale_for_hoverlabel is not None:
                    node_colors_color_variable = np.asarray(
                        node_colors_color_variable
                    )
                    min_col = np.min(node_colors_color_variable)
                    max_col = np.max(node_colors_color_variable)
                    hoverlabel_bgcolor = _get_colors_for_vals(
                        node_colors_color_variable, min_col, max_col,
                        colorscale_for_hoverlabel
                    )
                    fig.update_traces(
                        hoverlabel_bgcolor=hoverlabel_bgcolor,
                        selector={"name": "node_trace"}
                    )

                fig.update_traces(
                    x=node_trace.x,
                    y=node_trace.y,
                    marker_color=node_trace.marker.color,
                    marker_size=node_trace.marker.size,
                    marker_sizeref=node_trace.marker.sizeref,
                    hovertext=node_trace.hovertext,
                    **({"z": node_trace.z} if layout_dim == 3 else dict()),
                    selector={"name": "node_trace"}
                )
                fig.update_traces(
                    x=edge_trace.x,
                    y=edge_trace.y,
                    **({"z": edge_trace.z} if layout_dim == 3 else dict()),
                    selector={"name": "edge_trace"}
                )

                # Update color by column buttons
                if color_by_columns_dropdown:
                    hovertext_color_variable = node_trace.hovertext
                    column_color_buttons = _get_column_color_buttons(
                        data, is_data_dataframe, node_elements,
                        node_colors_color_variable, _node_color_statistic,
                        hovertext_color_variable, colorscale_for_hoverlabel,
                        n_sig_figs
                    )
                    # Avoid recomputing hoverlabel bgcolor for top button
                    if colorscale_for_hoverlabel is not None:
                        column_color_buttons[0]["args"][0][
                            "hoverlabel.bgcolor"] = [None, hoverlabel_bgcolor]
                else:
                    column_color_buttons = None

                button_height = 1.1
                fig.update_layout(
                    updatemenus=[
                        go.layout.Updatemenu(
                            buttons=column_color_buttons,
                            direction="down",
                            pad={"r": 10, "t": 10},
                            showactive=True,
                            x=0.11,
                            xanchor="left",
                            y=button_height,
                            yanchor="top"
                        ),
                    ])

            valid.value = True
        except Exception:
            # Deliberate catch-all: invalid parameter values entered in the
            # widgets must not crash the UI. The last traceback line is used
            # as the log message.
            exception_data = traceback.format_exc().splitlines()
            logger.exception(exception_data[-1])
            valid.value = False

    def observe_widgets(params, param_widgets):
        """Register `on_parameter_change` on the widget of every int, float
        or str parameter."""
        # The second argument is not called ``widgets`` to avoid shadowing
        # the imported ``ipywidgets.widgets`` module
        for param, value in params.items():
            if isinstance(value, (int, float, str)):
                param_widgets[param].observe(
                    on_parameter_change, names="value"
                )

    # Define output widget to capture logs
    out = widgets.Output()

    @out.capture()
    def click_box(change):
        # Show the logs when the checkbox is ticked, hide them otherwise
        if logs_box.value:
            out.clear_output()
            handler.show_logs()
        else:
            out.clear_output()

    # Initialise logging
    logger = logging.getLogger(__name__)
    handler = OutputWidgetHandler()
    handler.setFormatter(logging.Formatter(
        "%(asctime)s - [%(levelname)s] %(message)s"))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

    # Initialise cover and cluster dictionaries of parameters and widgets
    mapper_params = _pipeline.get_mapper_params()

    cover_params = {
        name: value for name, value in mapper_params.items()
        if name.startswith("cover")
    }
    cover_params_widgets = dict(
        entry for entry in (
            get_widgets_per_param(name, value)
            for name, value in cover_params.items()
        ) if entry is not None
    )

    cluster_params = {
        name: value for name, value in mapper_params.items()
        if name.startswith("clusterer")
    }
    cluster_params_widgets = dict(
        entry for entry in (
            get_widgets_per_param(name, value)
            for name, value in cluster_params.items()
        ) if entry is not None
    )

    # Initialise widgets for validating input parameters of pipeline
    valid = widgets.Valid(
        value=True,
        description="Valid parameters",
        style={"description_width": "100px"},
    )

    # Initialise widget for showing the logs
    logs_box = widgets.Checkbox(
        description="Show logs: ",
        value=False,
        indent=False
    )

    # Initialise figure with initial pipeline and config. The pipeline has
    # already been cloned above if requested, hence ``clone_pipeline=False``.
    fig = plot_static_mapper_graph(
        _pipeline, data, layout=layout, layout_dim=layout_dim,
        color_variable=color_variable,
        node_color_statistic=_node_color_statistic,
        color_by_columns_dropdown=color_by_columns_dropdown,
        clone_pipeline=False, n_sig_figs=n_sig_figs, node_scale=node_scale,
        plotly_params=plotly_params
    )

    # Store variables for later updates
    is_data_dataframe = hasattr(data, "columns")

    colorscale_for_hoverlabel = None
    if layout_dim == 3:
        # In plot_static_mapper_graph, hoverlabel bgcolors are set to white if
        # something goes wrong computing them according to the colorscale.
        is_bgcolor_not_white = fig.data[1].hoverlabel.bgcolor != "white"
        user_hoverlabel_bgcolor = False
        if plotly_params:
            if "node_trace" in plotly_params:
                if "hoverlabel_bgcolor" in plotly_params["node_trace"]:
                    user_hoverlabel_bgcolor = True
        if is_bgcolor_not_white and not user_hoverlabel_bgcolor:
            colorscale_for_hoverlabel = fig.data[1].marker.colorscale

    observe_widgets(cover_params, cover_params_widgets)
    observe_widgets(cluster_params, cluster_params_widgets)

    logs_box.observe(click_box, names="value")

    # Define containers for input widgets
    container_cover = widgets.HBox(
        children=list(cover_params_widgets.values())
    )

    container_cluster_layout = Layout(display="flex", flex_flow="row wrap")

    container_cluster = widgets.HBox(
        children=list(cluster_params_widgets.values()),
        layout=container_cluster_layout
    )

    box = widgets.VBox(
        [container_cover, container_cluster, fig, valid, logs_box, out]
    )

    return box