Source code for gtda.plotting.point_clouds

"""Point-cloud–related plotting functions and classes."""
# License: GNU AGPLv3

import numpy as np
import plotly.graph_objs as gobj


def plot_point_cloud(point_cloud, dimension=None):
    """Plot the first 2 or 3 coordinates of a point cloud.

    The figure is built with plotly and displayed immediately via
    ``Figure.show``; the function returns ``None``. This function will not
    work on 1D arrays.

    Parameters
    ----------
    point_cloud : array-like of shape (n_samples, n_dimensions)
        Data points to be represented in a 2D or 3D scatter plot. Only the
        first 2 or 3 dimensions will be considered for plotting. Any input
        accepted by :func:`numpy.asarray` (e.g. an ndarray or nested lists)
        can be passed.

    dimension : int or None, default: ``None``
        Sets the dimension of the resulting plot. If ``None``, the dimension
        will be chosen between 2 and 3 depending on the shape of
        `point_cloud`.

    Raises
    ------
    ValueError
        If `dimension` (explicit or inferred) is neither 2 nor 3, or if
        `point_cloud` has fewer columns than `dimension`.

    """
    # TODO: increase the marker size
    # Normalise array-likes to an ndarray so that ``.shape`` and the
    # ``[:, i]`` column indexing used below are always available.
    point_cloud = np.asarray(point_cloud)

    if dimension is None:
        dimension = min(3, point_cloud.shape[1])

    # Reject unsupported plot dimensions first, so that e.g. ``dimension=4``
    # is reported as unsupported rather than as a shortage of columns.
    if dimension not in (2, 3):
        raise ValueError("The value of the dimension is different from 2 or 3")

    # Check consistency between point_cloud and dimension
    if point_cloud.shape[1] < dimension:
        raise ValueError("Not enough dimensions available in the input point "
                         "cloud.")

    # Shared by both the 2D and 3D traces: points are coloured by their row
    # index (0 .. n_samples - 1) on the Viridis colour scale.
    marker = dict(size=4,
                  color=list(range(point_cloud.shape[0])),
                  colorscale='Viridis',
                  opacity=0.8)

    if dimension == 2:
        layout = {
            "width": 800,
            "height": 800,
            "xaxis1": {
                "title": "0th",
                "side": "bottom",
                "type": "linear",
                "ticks": "outside",
                "anchor": "x1",
                "showline": True,
                "zeroline": True,
                "showexponent": "all",
                "exponentformat": "e"
            },
            "yaxis1": {
                "title": "1st",
                "side": "left",
                "type": "linear",
                "ticks": "outside",
                "anchor": "y1",
                "showline": True,
                "zeroline": True,
                "showexponent": "all",
                "exponentformat": "e"
            },
            "plot_bgcolor": "white"
        }

        fig = gobj.Figure(layout=layout)
        fig.update_xaxes(zeroline=True, linewidth=1, linecolor='black',
                         mirror=False)
        fig.update_yaxes(zeroline=True, linewidth=1, linecolor='black',
                         mirror=False)
        fig.add_trace(gobj.Scatter(x=point_cloud[:, 0],
                                   y=point_cloud[:, 1],
                                   mode='markers',
                                   marker=marker))
    else:
        # dimension == 3, guaranteed by the validation above.
        scene = {
            "xaxis": {
                "title": "0th",
                "type": "linear",
                "showexponent": "all",
                "exponentformat": "e"
            },
            "yaxis": {
                "title": "1st",
                "type": "linear",
                "showexponent": "all",
                "exponentformat": "e"
            },
            "zaxis": {
                "title": "2nd",
                "type": "linear",
                "showexponent": "all",
                "exponentformat": "e"
            }
        }

        fig = gobj.Figure()
        fig.update_layout(scene=scene)
        fig.add_trace(gobj.Scatter3d(x=point_cloud[:, 0],
                                     y=point_cloud[:, 1],
                                     z=point_cloud[:, 2],
                                     mode='markers',
                                     marker=marker))

    fig.show()