Source code for gtda.mapper.utils.pipeline
"""Utility functions for scikit-learn pipelines."""
# License: GNU AGPLv3
from functools import partial
from inspect import signature
import numpy as np
from sklearn.preprocessing import FunctionTransformer
def _make_func_apply_along_axis_1(func):
    return partial(np.apply_along_axis, func, 1)
def _reshape_after_apply(func, arr):
    if func(arr).ndim == 1:
        return func(arr).reshape(-1, 1)
    return func(arr)
def transformer_from_callable_on_rows(func, validate=True):
    """Construct a transformer from a callable acting on 1D arrays.

    Given a callable which can act on 1D arrays, this function returns a
    fit-transformer which applies the callable to slices of 2D arrays along
    axis 1. When possible, the array output by the transformer's
    :meth:`fit_transform` is two-dimensional.

    If `func` accepts both ``axis`` and ``keepdims`` arguments, it is called
    directly with ``axis=1, keepdims=True``. Otherwise it is applied slice by
    slice via :func:`numpy.apply_along_axis`, and a one-dimensional result is
    reshaped into a single column.

    Parameters
    ----------
    func : callable or None
        A callable object. If ``None``, no transformer is built.

    validate : bool, optional, default: ``True``
        Whether the output transformer should implement input validation.

    Returns
    -------
    function_transformer : :class:`sklearn.preprocessing.FunctionTransformer` \
        object or None
        Output fit-transformer, or ``None`` if `func` is ``None``.

    Examples
    --------
    >>> import numpy as np
    >>> from gtda.mapper import transformer_from_callable_on_rows
    >>> function_transformer = transformer_from_callable_on_rows(np.sum)
    >>> X = np.array([[0, 1], [2, 3]])
    >>> print(function_transformer.fit_transform(X))
    [[1]
     [5]]

    """
    if func is None:
        return None

    try:
        func_params = signature(func).parameters
    except ValueError:
        # Some callables (e.g. certain builtins or C-implemented objects)
        # expose no introspectable signature. The generic row-by-row path
        # below does not need one, so fall back to it instead of failing.
        func_params = {}

    # ``keepdims=True`` is only safe to pass if the callable names it
    # explicitly or swallows arbitrary keyword arguments.
    accepts_keepdims = 'keepdims' in func_params or any(
        param.kind is param.VAR_KEYWORD for param in func_params.values()
    )

    if 'axis' in func_params and accepts_keepdims:
        # Use native (faster) numpy implementation
        func_along_axis = partial(func, axis=1, keepdims=True)
    else:
        func_along_axis = partial(_reshape_after_apply,
                                  _make_func_apply_along_axis_1(func))

    return FunctionTransformer(func=func_along_axis, validate=validate)
def identity(validate=False):
    """Construct a function transformer with no callable attached.

    Parameters
    ----------
    validate : bool, optional, default: ``False``
        Forwarded as the ``validate`` argument of
        :class:`sklearn.preprocessing.FunctionTransformer`.

    Returns
    -------
    transformer : :class:`sklearn.preprocessing.FunctionTransformer` object
        Transformer created with the default ``func``.

    """
    transformer = FunctionTransformer(validate=validate)
    return transformer