# Source code for gtda.mapper.utils.pipeline

"""Utility functions for scikit-learn pipelines."""
# License: GNU AGPLv3

from functools import partial
from inspect import signature

import numpy as np
from sklearn.preprocessing import FunctionTransformer


def _make_func_apply_along_axis_1(func):
    """Bind ``func`` and axis 1 into :func:`numpy.apply_along_axis`.

    Parameters
    ----------
    func : callable
        Callable forwarded to :func:`numpy.apply_along_axis` as the function
        to apply to 1D slices.

    Returns
    -------
    apply_to_rows : :class:`functools.partial` object
        Calling ``apply_to_rows(arr)`` evaluates
        ``np.apply_along_axis(func, 1, arr)``.

    """
    row_axis = 1
    apply_to_rows = partial(np.apply_along_axis, func, row_axis)
    return apply_to_rows


def _reshape_after_apply(func, arr):
    """Apply ``func`` to ``arr`` and return one-dimensional results as a column.

    ``func`` is evaluated exactly once; its result is reused both for the
    dimensionality check and for the returned value.

    Parameters
    ----------
    func : callable
        Callable taking ``arr`` and returning an object with ``ndim`` and
        ``reshape`` (e.g. a :class:`numpy.ndarray`).

    arr : ndarray
        Input passed unchanged to ``func``.

    Returns
    -------
    result : ndarray
        ``func(arr)`` reshaped to ``(-1, 1)`` if it is one-dimensional,
        otherwise ``func(arr)`` as returned.

    """
    result = func(arr)
    if result.ndim == 1:
        # Turn a flat vector of per-row values into a single-column 2D array
        return result.reshape(-1, 1)
    return result


def transformer_from_callable_on_rows(func, validate=True):
    """Construct a transformer from a callable acting on 1D arrays.

    Given a callable which can act on 1D arrays, this function returns a
    fit-transformer which applies the callable to slices of 2D arrays along
    axis 1. When possible, the array output by the transformer's
    :meth:`fit_transform` is two-dimensional.

    Parameters
    ----------
    func : callable or None
        A callable object. If ``None``, the transformer is constructed with
        ``func=None``, which scikit-learn documents as the identity function.

    validate : bool, optional, default: ``True``
        Whether the output transformer should implement input validation.

    Returns
    -------
    function_transformer : :class:`sklearn.preprocessing.FunctionTransformer` \
        object
        Output fit-transformer.

    Examples
    --------
    >>> import numpy as np
    >>> from gtda.mapper import transformer_from_callable_on_rows
    >>> function_transformer = transformer_from_callable_on_rows(np.sum)
    >>> X = np.array([[0, 1], [2, 3]])
    >>> print(function_transformer.fit_transform(X))
    [[1]
     [5]]

    """
    if func is None:
        # Nothing to wrap: previously this path hit an unbound local variable
        return FunctionTransformer(func=None, validate=validate)

    try:
        func_params = signature(func).parameters
    except (TypeError, ValueError):
        # Some callables (e.g. implemented in C) cannot be introspected;
        # treat them as plain 1D callables and use the generic path below.
        func_params = {}

    has_axis = 'axis' in func_params
    # ``keepdims`` is passed by keyword, so it is accepted either as a named
    # parameter or through a ``**kwargs`` catch-all.
    accepts_keepdims = 'keepdims' in func_params or any(
        param.kind is param.VAR_KEYWORD for param in func_params.values())

    if has_axis and accepts_keepdims:
        # Use native (faster) numpy implementation
        func_along_axis = partial(func, axis=1, keepdims=True)
    elif has_axis:
        # Native implementation without ``keepdims``: make 1D output 2D
        func_along_axis = partial(_reshape_after_apply,
                                  partial(func, axis=1))
    else:
        func_along_axis = partial(_reshape_after_apply,
                                  _make_func_apply_along_axis_1(func))

    return FunctionTransformer(func=func_along_axis, validate=validate)
def identity(validate=False):
    """Construct a transformer with no function attached.

    The returned :class:`sklearn.preprocessing.FunctionTransformer` is created
    without a ``func`` argument; scikit-learn documents that case as the
    identity function.

    Parameters
    ----------
    validate : bool, optional, default: ``False``
        Forwarded as the ``validate`` argument of the transformer.

    Returns
    -------
    transformer : :class:`sklearn.preprocessing.FunctionTransformer` object
        The newly constructed transformer.

    """
    passthrough = FunctionTransformer(validate=validate)
    return passthrough