"""Point-cloud–related plotting functions and classes."""
# License: GNU AGPLv3
import numpy as np
import plotly.graph_objs as gobj
def plot_point_cloud(point_cloud, dimension=None):
    """Plot the first 2 or 3 coordinates of a point cloud.

    The figure is displayed via ``fig.show()``; nothing is returned. Points
    are coloured by their row index using the ``'Viridis'`` colorscale.

    Parameters
    ----------
    point_cloud : array-like of shape (n_samples, n_dimensions)
        Data points to be represented in a 2D or 3D scatter plot. Only the
        first 2 or 3 dimensions will be considered for plotting. The input
        is converted with :func:`numpy.asarray`, so nested lists are
        accepted. Inputs that are not two-dimensional (e.g. 1D arrays) are
        rejected.

    dimension : int or None, default: ``None``
        Sets the dimension of the resulting plot. If ``None``, the dimension
        will be chosen between 2 and 3 depending on the shape of
        `point_cloud`, namely ``min(3, n_dimensions)``.

    Raises
    ------
    ValueError
        If `point_cloud` is not two-dimensional, if it has fewer than
        `dimension` columns, or if the (possibly inferred) dimension is
        neither 2 nor 3.

    """
    # TODO: increase the marker size
    point_cloud = np.asarray(point_cloud)
    # Without this check a 1D input would fail with an opaque IndexError on
    # ``shape[1]``; a >2D input would produce meaningless column slices.
    if point_cloud.ndim != 2:
        raise ValueError("The input point cloud must be a 2D array of shape "
                         "(n_samples, n_dimensions); got an array with "
                         "ndim={}.".format(point_cloud.ndim))
    n_samples, n_dimensions = point_cloud.shape

    if dimension is None:
        dimension = min(3, n_dimensions)
    # Check consistency between point_cloud and dimension
    if n_dimensions < dimension:
        raise ValueError("Not enough dimensions available in the input point "
                         "cloud.")
    if dimension not in (2, 3):
        raise ValueError("The value of the dimension is different from 2 "
                         "or 3")

    # Marker style shared by the 2D and 3D traces: each point is coloured by
    # its row index.
    marker = dict(size=4,
                  color=list(range(n_samples)),
                  colorscale='Viridis',
                  opacity=0.8)

    if dimension == 2:
        layout = {
            "width": 800,
            "height": 800,
            "xaxis1": {
                "title": "0th",
                "side": "bottom",
                "type": "linear",
                "ticks": "outside",
                # An axis must be anchored to an opposite-letter axis id.
                "anchor": "y1",
                "showline": True,
                "zeroline": True,
                "showexponent": "all",
                "exponentformat": "e"
            },
            "yaxis1": {
                "title": "1st",
                "side": "left",
                "type": "linear",
                "ticks": "outside",
                "anchor": "x1",
                "showline": True,
                "zeroline": True,
                "showexponent": "all",
                "exponentformat": "e"
            },
            "plot_bgcolor": "white"
        }
        fig = gobj.Figure(layout=layout)
        fig.update_xaxes(zeroline=True, linewidth=1, linecolor='black',
                         mirror=False)
        fig.update_yaxes(zeroline=True, linewidth=1, linecolor='black',
                         mirror=False)
        fig.add_trace(gobj.Scatter(x=point_cloud[:, 0],
                                   y=point_cloud[:, 1],
                                   mode='markers',
                                   marker=marker))
    else:  # dimension == 3, guaranteed by the validation above
        scene = {
            "xaxis": {
                "title": "0th",
                "type": "linear",
                "showexponent": "all",
                "exponentformat": "e"
            },
            "yaxis": {
                "title": "1st",
                "type": "linear",
                "showexponent": "all",
                "exponentformat": "e"
            },
            "zaxis": {
                "title": "2nd",
                "type": "linear",
                "showexponent": "all",
                "exponentformat": "e"
            }
        }
        fig = gobj.Figure()
        fig.update_layout(scene=scene)
        fig.add_trace(gobj.Scatter3d(x=point_cloud[:, 0],
                                     y=point_cloud[:, 1],
                                     z=point_cloud[:, 2],
                                     mode='markers',
                                     marker=marker))
    fig.show()