Source code for gtda.plotting.persistence_diagrams

"""Persistent-homology–related plotting functions and classes."""
# License: GNU AGPLv3

import numpy as np
import plotly.graph_objs as gobj


def plot_diagram(diagram, homology_dimensions=None, **input_layout):
    """Plot a single persistence diagram.

    Parameters
    ----------
    diagram : ndarray of shape (n_points, 3)
        The persistence diagram to plot, where the third dimension along axis 1
        contains homology dimensions, and the first two contain (birth, death)
        pairs to be used as coordinates in the two-dimensional plot.

    homology_dimensions : list of int or None, optional, default: ``None``
        Homology dimensions which will appear on the plot. If ``None``, all
        homology dimensions which appear in `diagram` will be plotted.

    **input_layout
        Extra entries for the plotly layout. They are merged into the default
        layout with :meth:`dict.update`, so a key given here (e.g. ``xaxis1``)
        replaces the corresponding default entry entirely.

    Notes
    -----
    Both axes share the same range, computed from the finite (birth, death)
    values only; infinite values and the homology-dimension column are
    ignored. For non-negative data the range is ``[0, 1.1 * max_value]``;
    negative values extend the range below zero. An empty diagram (or one
    with no finite coordinate) is shown on the range ``[0, 1]``. Points with
    birth equal to death are not drawn.

    """
    from ..diagrams._utils import _subdiagrams  # To avoid circular imports

    # TODO: increase the marker size
    diagram = np.asarray(diagram)
    if homology_dimensions is None:
        homology_dimensions = np.unique(diagram[:, 2])

    # Only the (birth, death) columns determine the plotted extent: the third
    # column holds homology dimensions, which are not coordinates. Infinite
    # values cannot be placed on a finite axis, so they are excluded as well.
    coordinates = diagram[:, :2]
    finite_values = coordinates[np.isfinite(coordinates)]
    if finite_values.size:
        low = min(0., finite_values.min())
        high = finite_values.max()
    else:
        # Empty diagram, or no finite coordinate: fall back to a unit range.
        low, high = 0., 0.
    pad = 0.1 * (high - low) if high > low else 1.
    # Keep the axes anchored at 0 for non-negative data; only pad below zero
    # when negative values are actually present.
    axis_range = [low - pad if low < 0 else low, high + pad]

    # The diagonal extends far beyond the visible range so that it stays
    # visible when the user zooms out or pans. The width is always > 0.
    axis_width = axis_range[1] - axis_range[0]
    diagonal = np.array([axis_range[0] - 100 * axis_width,
                         axis_range[1] + 100 * axis_width])

    layout = dict(
        width=500,
        height=500,
        xaxis1=dict(
            title='Birth',
            side='bottom',
            type='linear',
            range=list(axis_range),
            ticks='outside',
            anchor='y1',
            showline=True,
            zeroline=True,
            showexponent='all',
            exponentformat='e'
        ),
        yaxis1=dict(
            title='Death',
            side='left',
            type='linear',
            range=list(axis_range),
            ticks='outside',
            anchor='x1',
            showline=True,
            zeroline=True,
            showexponent='all',
            exponentformat='e'
        ),
        plot_bgcolor='white'
    )

    layout.update(input_layout)

    fig = gobj.Figure(layout=layout)
    fig.update_xaxes(zeroline=True, linewidth=1, linecolor='black',
                     mirror=False)
    fig.update_yaxes(zeroline=True, linewidth=1, linecolor='black',
                     mirror=False)

    fig.add_trace(gobj.Scatter(x=diagonal,
                               y=diagonal,
                               mode='lines',
                               line=dict(dash='dash', width=1, color='black'),
                               showlegend=False, hoverinfo='none'))

    for dim in homology_dimensions:
        name = f'H{int(dim)}' if dim != np.inf else 'Any homology dimension'
        subdiagram = _subdiagrams(np.asarray([diagram]), [dim],
                                  remove_dim=True)[0]
        # Drop zero-persistence points (birth equal to death).
        subdiagram = subdiagram[subdiagram[:, 1] != subdiagram[:, 0]]
        fig.add_trace(gobj.Scatter(x=subdiagram[:, 0], y=subdiagram[:, 1],
                                   mode='markers', name=name))

    fig.show()