# Source code for gtda.mapper.utils.decorators
"""Convenience class decorators for use in a Mapper context."""
# License: GNU AGPLv3
from sklearn.base import TransformerMixin
def method_to_transform(cls, method_name):
    """Wrap a class to add a :meth:`transform` method as an alias to an
    existing method.

    An example of use is for classes possessing a :meth:`score` method such as
    kernel density estimators and anomaly/novelty detection estimators, to
    allow for these estimators to be used as steps in a pipeline.

    Note that 1D array outputs are reshaped into 2D column vectors before
    being returned by the new :meth:`transform`.

    Parameters
    ----------
    cls : object
        Class to be wrapped. If `method_name` is not one of its methods,
        :meth:`transform` always returns ``None``.

    method_name : str
        Name of the method in `cls` to which :meth:`transform` will be
        an alias. The first argument of this method (after ``self``) becomes
        the ``X`` input for :meth:`transform`.

    Returns
    -------
    wrapped_cls : object
        New class inheriting from :class:`sklearn.base.TransformerMixin`, so
        that both :meth:`transform` and :meth:`fit_transform` are available.
        Its name is the name of `cls` prepended with ``'Extended'``.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.neighbors import KernelDensity
    >>> from gtda.mapper import method_to_transform
    >>> X = np.random.random((100, 2))
    >>> kde = KernelDensity()
    >>> kde_extended = method_to_transform(
    ...     KernelDensity, 'score_samples')()
    >>> Xt = kde.fit(X).score_samples(X)
    >>> print(Xt.shape)
    (100,)
    >>> Xt_extended = kde_extended.fit_transform(X)
    >>> print(Xt_extended.shape)
    (100, 1)
    >>> np.array_equal(Xt, Xt_extended.flatten())
    True

    """
    # NOTE(review): scikit-learn's developer guide recommends listing mixins
    # to the left of the estimator base class; the original order
    # (cls, TransformerMixin) is kept here so that any fit_transform already
    # defined by `cls` keeps taking precedence, as before.
    class ExtendedEstimator(cls, TransformerMixin):
        def transform(self, X, y=None):
            """Apply ``getattr(self, method_name)`` to `X`.

            ``y`` is ignored; it is accepted only so the signature matches
            the usual pipeline ``transform(X, y=None)`` form. Returns
            ``None`` when the instance has no attribute `method_name`
            (documented behaviour of :func:`method_to_transform`).
            """
            if not hasattr(self, method_name):
                return None
            Xt = getattr(self, method_name)(X)
            # reshape 1D outputs to have shape (n_samples, 1)
            if Xt.ndim == 1:
                Xt = Xt[:, None]
            return Xt

    # Set both __name__ and __qualname__ so the class repr and error messages
    # show the advertised 'Extended...' name instead of the local definition
    # path (method_to_transform.<locals>.ExtendedEstimator).
    extended_name = 'Extended' + cls.__name__
    ExtendedEstimator.__name__ = extended_name
    ExtendedEstimator.__qualname__ = extended_name
    return ExtendedEstimator