"""Feature extraction from curves."""
# License: GNU AGPLv3
from copy import deepcopy
from types import FunctionType
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted, check_array
from ._functions import _AVAILABLE_FUNCTIONS, _implemented_function_recipes, \
    _parallel_featurization
from ..utils._docs import adapt_fit_transform_docs
from ..utils.validation import validate_params
[docs]@adapt_fit_transform_docs
class StandardFeatures(BaseEstimator, TransformerMixin):
    """Standard features from multi-channel curves.
    A multi-channel (integer sampled) curve is a 2D array of shape
    ``(n_channels, n_bins)``, where each row represents the y-values in one of
    the channels. This transformer applies scalar or vector-valued functions
    channel-wise to extract features from each multi-channel curve in a
    collection. The output is always a 2D array such that row ``i`` is the
    concatenation of the outputs of the chosen functions on the channels in the
    ``i``-th (multi-)curve in the collection.
    Parameters
    ----------
    function : string, callable, list or tuple, optional, default: ``"max"``
        Function or list/tuple of functions to apply to each channel of each
        multi-channel curve. Functions can map to scalars or to 1D arrays. If a
        string (see below) or a callable, then the same function is applied to
        all channels. Otherwise, `function` is a list/tuple of the same length
        as the number of entries along axis 1 in the collection passed to
        :meth:`fit`. Lists/tuples may contain allowed strings (see below),
        callables, and ``None`` in some positions to indicate that no feature
        should be extracted from the corresponding channel. Available strings
        are ``"identity"``, ``"argmin"``, ``"argmax"``, ``"min"``, ``"max"``,
        ``"mean"``, ``"std"``, ``"median"`` and ``"average"``.
    function_params : dict, None, list or tuple, optional, default: ``None``
        Additional keyword arguments for the function or functions in
        `function`. Passing ``None`` is equivalent to passing no arguments.
        Otherwise, if `function` is a single string or callable then
        `function_params` must be a dictionary. For functions encoded by
        allowed strings, the dictionary keys are as follows:
        - If ``function == "average"``, the only key is ``"weights"``
          (np.ndarray or None, default: ``None``).
        - Otherwise, there are no allowed keys.
        If `function` is a list or tuple, `function_params` must be a list or
        tuple of dictionaries (or ``None``) as above, of the same length as
        `function`.
    n_jobs : int or None, optional, default: ``None``
        The number of jobs to use for the computation. ``None`` means 1 unless
        in a :obj:`joblib.parallel_backend` context. ``-1`` means using all
        processors. Ignored if `function` is one of the allowed string options.
    Attributes
    ----------
    n_channels_ : int
        Number of channels present in the 3D array passed to :meth:`fit`. Must
        match the number of channels in the 3D array passed to
        :meth:`transform`.
    effective_function_ : callable or tuple
        Callable, or tuple of callables or ``None``, describing the function(s)
        used to compute features in each available channel. It is a single
        callable only when `function` was passed as a string.
    effective_function_params_ : dict or tuple
        Dictionary or tuple of dictionaries containing all information present
        in `function_params` as well as relevant quantities computed in
        :meth:`fit`. It is a single dict only when `function` was passed as a
        string. ``None``s are converted to empty dictionaries.
    """
    _hyperparameters = {
        "function": {"type": (str, FunctionType, list, tuple),
                     "in": tuple(_AVAILABLE_FUNCTIONS.keys()),
                     "of": {"type": (str, FunctionType, type(None)),
                            "in": tuple(_AVAILABLE_FUNCTIONS.keys())}},
        "function_params": {"type": (dict, type(None), list, tuple)},
        }
[docs]    def __init__(self, function="max", function_params=None, n_jobs=None):
        self.function = function
        self.function_params = function_params
        self.n_jobs = n_jobs 
    def _validate_params(self):
        params = self.get_params().copy()
        _hyperparameters = deepcopy(self._hyperparameters)
        if not isinstance(self.function, str):
            _hyperparameters["function"].pop("in")
        try:
            validate_params(params, _hyperparameters, exclude=["n_jobs"])
        # Another go if we fail because function is a list/tuple containing
        # instances of FunctionType and the "in" key checks fail
        except ValueError as ve:
            end_string = f"which is not in " \
                         
f"{tuple(_AVAILABLE_FUNCTIONS.keys())}."
            function = params["function"]
            if ve.args[0].endswith(end_string) \
                    
and isinstance(function, (list, tuple)):
                params["function"] = [f for f in function
                                      if isinstance(f, str)]
                validate_params(params, _hyperparameters, exclude=["n_jobs"])
            else:
                raise ve
        if isinstance(self.function, (list, tuple)) \
                
and isinstance(self.function_params, dict):
            raise TypeError("If `function` is a list/tuple then "
                            "`function_params` must be a list/tuple of dict, "
                            "or None.")
        elif isinstance(self.function, (str, FunctionType)) \
                
and isinstance(self.function_params, (list, tuple)):
            raise TypeError("If `function` is a string or a callable "
                            "function then `function_params` must be a dict "
                            "or None.")
[docs]    def fit(self, X, y=None):
        """Compute :attr:`n_channels_` and :attr:`effective_function_params_`.
        Then, return the estimator.
        This function is here to implement the usual scikit-learn API and hence
        work in pipelines.
        Parameters
        ----------
        X : ndarray of shape (n_samples, n_channels, n_bins)
            Input data. Collection of multi-channel curves.
        y : None
            There is no need for a target in a transformer, yet the pipeline
            API requires this parameter.
        Returns
        -------
        self : object
        """
        check_array(X, ensure_2d=False, allow_nd=True)
        if X.ndim != 3:
            raise ValueError("Input must be 3-dimensional.")
        self._validate_params()
        self.n_channels_ = X.shape[1]
        if isinstance(self.function, str):
            self.effective_function_ = \
                
_implemented_function_recipes[self.function]
            if self.function_params is None:
                self.effective_function_params_ = {}
            else:
                validate_params(self.function_params,
                                _AVAILABLE_FUNCTIONS[self.function])
                self.effective_function_params_ = self.function_params.copy()
        elif isinstance(self.function, FunctionType):
            self.effective_function_ = \
                
tuple([self.function] * self.n_channels_)
            if self.function_params is None:
                self.effective_function_params_ = \
                    
tuple([{}] * self.n_channels_)
            else:
                self.effective_function_params_ = \
                    
tuple([self.function_params.copy()] * self.n_channels_)
        else:
            n_functions = len(self.function)
            if len(self.function) != self.n_channels_:
                raise ValueError(
                    f"`function` has length {n_functions} while curves in `X` "
                    f"have {self.n_channels_} channels."
                    )
            if self.function_params is None:
                self._effective_function_params = [{}] * self.n_channels_
            else:
                self._effective_function_params = self.function_params
                n_function_params = len(self._effective_function_params)
                if n_function_params != self.n_channels_:
                    raise ValueError(f"`function_params` has length "
                                     f"{n_function_params} while curves in "
                                     f"`X` have {self.n_channels_} channels.")
            self.effective_function_ = []
            self.effective_function_params_ = []
            for f, p in zip(self.function, self._effective_function_params):
                if isinstance(f, str):
                    validate_params(p, _AVAILABLE_FUNCTIONS[f])
                    self.effective_function_.\
                        
append(_implemented_function_recipes[f])
                else:
                    self.effective_function_.append(f)
                self.effective_function_params_.append({} if p is None
                                                       else p.copy())
            self.effective_function_ = tuple(self.effective_function_)
            self.effective_function_params_ = \
                
tuple(self.effective_function_params_)
        return self