"""Utilities for input validation."""
# License: GNU AGPLv3
from functools import reduce
from operator import and_
from warnings import warn
import numpy as np
from sklearn.utils.validation import check_array
def check_diagrams(X, copy=False):
    """Input validation for collections of persistence diagrams.

    Basic type and sanity checks are run on the input collection and the
    array is converted to float type before returning. In particular,
    the input is checked to be an ndarray of shape ``(n_samples, n_points,
    3)``.

    Parameters
    ----------
    X : object
        Input object to check/convert.

    copy : bool, optional, default: ``False``
        Whether a forced copy should be triggered.

    Returns
    -------
    X_validated : ndarray of shape (n_samples, n_points, 3)
        The converted and validated array of persistence diagrams.

    Raises
    ------
    ValueError
        If the input is not a 3D array with 3 components along the last
        axis, if a homology dimension found in the first diagram is not a
        non-negative integer (``np.inf`` is accepted only when it is the
        sole dimension), or if any point lies below the diagonal.

    """
    X_array = np.asarray(X)
    if X_array.ndim == 0:
        raise ValueError(
            f"Expected 3D array, got scalar array instead:\narray={X_array}.")
    if X_array.ndim != 3:
        raise ValueError(
            f"Input should be a 3D ndarray, the shape is {X_array.shape}.")
    if X_array.shape[2] != 3:
        raise ValueError(
            f"Input should be a 3D ndarray with a 3rd dimension of 3 "
            f"components, but there are {X_array.shape[2]} components.")

    X_array = X_array.astype(float, copy=False)

    # Homology dimensions are read from the first diagram only. Slicing with
    # [:1] instead of indexing with [0] keeps a collection with zero samples
    # from raising IndexError.
    homology_dimensions = sorted(set(X_array[:1, :, 2].ravel()))
    for dim in homology_dimensions:
        if dim == np.inf:
            if len(homology_dimensions) != 1:
                raise ValueError(
                    f"np.inf is a valid homology dimension for a stacked "
                    f"diagram but it should be the only one: "
                    f"homology_dimensions = {homology_dimensions}.")
            continue
        # float.is_integer() is False for nan and -inf, so those are reported
        # as ValueError rather than leaking int()'s own exceptions.
        if not float(dim).is_integer():
            raise ValueError(
                f"All homology dimensions should be integer valued: "
                f"{dim} can't be cast to an int of the same value.")
        if dim < 0:
            raise ValueError(
                f"All homology dimensions should be non-negative: "
                f"{dim} is negative.")

    # Counting the points that satisfy death >= birth (rather than those
    # with death < birth) makes nan coordinates count as invalid.
    n_points_above_diag = np.sum(X_array[:, :, 1] >= X_array[:, :, 0])
    n_points_global = X_array.shape[0] * X_array.shape[1]
    if n_points_above_diag != n_points_global:
        raise ValueError(
            f"All points of all persistence diagrams should be above the "
            f"diagonal, i.e. X[:,:,1] >= X[:,:,0]. "
            f"{n_points_global - n_points_above_diag} points are under the "
            f"diagonal.")

    if copy:
        X_array = np.copy(X_array)
    return X_array
def check_graph(X):
    """Placeholder for graph input validation.

    No checks are implemented yet: the input is handed back unchanged.
    """
    # TODO
    return X
def _validate_params_single(parameter, reference, name):
    """Check a single parameter value against its reference specification.

    Parameters
    ----------
    parameter : object
        Value to validate.

    reference : dict or None
        Specification which may contain the keys ``'type'``, ``'in'`` and
        ``'other'``. ``None`` disables all checks.

    name : str
        Parameter name, used only in error messages.

    Returns
    -------
    result : object
        ``reference['type']`` when it is one of ``list``, ``tuple``,
        ``numpy.ndarray`` or ``dict``; otherwise whatever the ``'other'``
        callable returns when one is given; ``None`` in all other cases.

    Raises
    ------
    TypeError
        If `parameter` is not an instance of ``reference['type']``.

    ValueError
        If `parameter` is not ``None`` and is not contained in
        ``reference['in']``.

    """
    if reference is None:
        return None

    expected_type = reference.get('type', None)
    type_mismatch = (expected_type is not None
                     and not isinstance(parameter, expected_type))
    if type_mismatch:
        raise TypeError(
            f"Parameter `{name}` is of type {type(parameter)} while "
            f"it should be of type {expected_type}.")

    # Container types are reported back to the caller, which is responsible
    # for validating their contents.
    if expected_type in [list, tuple, np.ndarray, dict]:
        return expected_type

    # Anything else is checked directly on the parameter object.
    allowed_values = reference.get('in', None)
    custom_check = reference.get('other', None)
    if parameter is not None and allowed_values is not None:
        if parameter not in allowed_values:
            raise ValueError(
                f"Parameter `{name}` is {parameter}, which is not in "
                f"{allowed_values}.")

    # Custom checks run even when the parameter is None.
    if custom_check is not None:
        return custom_check(parameter)
    return None
def _validate_params(parameters, references, rec_name=None):
    """Validate every entry of `parameters` against `references`.

    Parameters
    ----------
    parameters : dict
        Maps parameter names to the values to validate.

    references : dict
        Maps parameter names to reference specifications, as consumed by
        ``_validate_params_single``.

    rec_name : str or None, optional, default: ``None``
        Name of the enclosing dict-valued parameter when called recursively;
        used only to make the ``KeyError`` message more specific.

    Raises
    ------
    KeyError
        If a name in `parameters` is not a key of `references`.

    TypeError, ValueError
        Propagated from ``_validate_params_single``.

    """
    container_types = (list, tuple, np.ndarray, dict)
    for name, parameter in parameters.items():
        if name not in references:
            name_extras = "" if rec_name is None else f" in `{rec_name}`"
            raise KeyError(
                f"`{name}`{name_extras} is not an available parameter. "
                f"Available parameters are in {list(references.keys())}.")

        reference = references[name]
        ref_type = _validate_params_single(parameter, reference, name)

        # The helper may also hand back the result of a custom 'other' check.
        # Only an actual container type means there are contents to validate.
        # Identity is used because that result can be any object, including
        # one whose `==` is not a plain boolean.
        if not any(ref_type is kind for kind in container_types):
            continue

        # 'of' is optional: without it there is nothing to check the
        # contents against.
        ref_of = reference.get('of', None)
        if ref_of is None:
            continue

        if ref_type is dict:
            _validate_params(parameter, ref_of, rec_name=name)
        else:  # List, tuple or ndarray type
            for i, parameter_elem in enumerate(parameter):
                _validate_params_single(
                    parameter_elem, ref_of, f"{name}[{i}]")
def validate_params(parameters, references, exclude=None):
    """Function to automate the validation of (hyper)parameters.

    Parameters
    ----------
    parameters : dict, required
        Dictionary in which the keys are parameter names (as strings) and the
        corresponding values are parameter values. Unless `exclude` (see
        below) contains some of the keys in this dictionary, all parameters
        are checked against `references`.

    references : dict, required
        Dictionary in which the keys are parameter names (as strings). Let
        ``name`` and ``parameter`` denote a key-value pair in `parameters`.
        Since ``name`` should also be a key in `references`, let ``reference``
        be the corresponding value there. Then, ``reference`` must be a
        dictionary containing any of the following keys:

        - ``'type'``, mapping to a class or tuple of classes. ``parameter``
          is checked to be an instance of this class or tuple of classes.
        - ``'in'``, mapping to a container supporting membership tests, when
          the value of ``'type'`` is not one of ``list``, ``tuple``,
          ``numpy.ndarray`` or ``dict``. Letting ``ref_in`` denote that
          container, the following check is performed unless ``parameter``
          is ``None``: ``parameter in ref_in``.
        - ``'of'``, mapping to a dictionary, when the value of ``'type'``
          is one of ``list``, ``tuple``, ``numpy.ndarray`` or ``dict``.
          Let ``ref_of`` denote that dictionary. Then:

          a) If ``reference['type'] == dict`` – meaning that ``parameter``
             should be a dictionary – ``ref_of`` should have a similar
             structure as `references`, and the validation is applied
             recursively to ``(parameter, ref_of)``.
          b) Otherwise, ``ref_of`` should have a similar structure as
             ``reference`` and each entry in ``parameter`` is checked to
             satisfy the constraints in ``ref_of``.

        - ``'other'``, which should map to a callable defining custom checks
          on ``parameter``.

    exclude : list of str, or None, optional, default: ``None``
        List of parameter names which are among the keys in `parameters` but
        should be excluded from validation. ``None`` is equivalent to
        passing the empty list.

    Raises
    ------
    KeyError
        If a non-excluded key of `parameters` is not a key of `references`.

    TypeError
        If a parameter is not an instance of the type(s) in its reference.

    ValueError
        If a parameter is not contained in the ``'in'`` entry of its
        reference.

    """
    excluded = [] if exclude is None else exclude
    to_validate = {
        name: value for name, value in parameters.items()
        if name not in excluded
    }
    return _validate_params(to_validate, references)
def check_point_clouds(X, distance_matrix=False, **kwargs):
    """Input validation on an array or list representing a collection of point
    clouds or distance matrices.

    The input is checked to be either a single 3D array using a single call
    to :func:`~sklearn.utils.validation.check_array`, or a list of 2D arrays by
    calling :func:`~sklearn.utils.validation.check_array` on each entry. In
    the latter case, a warning is issued when not all point clouds are in
    the same Euclidean space.

    Conversions and copies may be triggered as per
    :func:`~sklearn.utils.validation.check_array`.

    Parameters
    ----------
    X : object
        Input object to check / convert.

    distance_matrix : bool, optional, default: ``False``
        Whether the input represents a collection of distance matrices or of
        concrete point clouds in Euclidean space. In the first case, entries
        are allowed to be infinite unless otherwise specified in `kwargs`.

    kwargs
        Keyword arguments forwarded to
        :func:`~sklearn.utils.validation.check_array`. They take precedence
        over the ``force_all_finite`` value derived from `distance_matrix`.

    Returns
    -------
    Xnew : ndarray or list
        The converted and validated object.

    Raises
    ------
    ValueError
        If `X` has a ``shape`` attribute but is not 3D, or if validation
        fails. For list-like input, the failures of all entries are
        collected into a single message, each prefixed by the entry's index.

    """
    kwargs_ = {'force_all_finite': not distance_matrix}
    kwargs_.update(kwargs)

    if hasattr(X, 'shape'):
        if X.ndim != 3:
            raise ValueError("ndarray input must be 3D.")
        return check_array(X, allow_nd=True, **kwargs_)

    # Validate every entry and collect all failures, so that a single error
    # reports every offending entry instead of only some of them.
    Xnew = []
    messages = []
    for i, x in enumerate(X):
        try:
            Xnew.append(check_array(x, **kwargs_))
        except ValueError as e:
            messages.append(f"Entry {i}: {e}")
    if messages:
        raise ValueError(
            "The following errors were raised by the inputs: \n"
            + "\n".join(messages))

    # The embedding dimensions are compared on the validated arrays, so that
    # empty input, nested lists and other array-likes do not break the check.
    if not distance_matrix:
        embedding_dims = {x.shape[1] for x in Xnew if len(x.shape) > 1}
        if len(embedding_dims) > 1:
            warn("Not all point clouds have the same embedding dimension.")

    return Xnew