Source code for gtda.base

"""Implements a TransformerResamplerMixin for transformers that have a resample
method and a PlotterMixin for transformers that have a plot method."""
# License: GNU AGPLv3


class TransformerResamplerMixin:
    """Mixin class for all transformers-resamplers in giotto-tda.

    The methods defined here only combine ``fit``, ``transform`` and
    ``resample``, which the host class is expected to provide.
    """

    _estimator_type = 'transformer_resampler'

    def fit_transform(self, X, y=None, **fit_params):
        """Fit to data, then transform it.

        Calls ``fit`` and then ``transform`` on the fitted object, and returns
        the transformed version of `X`.

        Parameters
        ----------
        X : ndarray of shape (n_samples, ...)
            Input data.

        y : None
            There is no need for a target in a transformer, yet the pipeline
            API requires this parameter. When it is not None it is forwarded
            to both ``fit`` and ``transform``.

        **fit_params
            Optional keyword arguments forwarded to ``fit``.

        Returns
        -------
        Xt : numpy array of shape (n_samples, ...)
            Transformed input.

        """
        # Non-optimized default implementation; override when a better method
        # is possible. The target is passed positionally only when one was
        # supplied, so ``fit``/``transform`` are called with arity 1
        # (unsupervised) or arity 2 (supervised) accordingly.
        target = () if y is None else (y,)
        fitted = self.fit(X, *target, **fit_params)
        return fitted.transform(X, *target)

    def transform_resample(self, X, y):
        """Transform the input and resample the target.

        No fitting takes place here: ``transform`` is called on `X` first,
        then ``resample`` is called on `y` (with `X` as second argument).

        Parameters
        ----------
        X : ndarray of shape (n_samples, ...)
            Input data.

        y : ndarray of shape (n_samples,)
            Target data.

        Returns
        -------
        Xt : ndarray of shape (n_samples, ...)
            Transformed input.

        yr : ndarray of shape (n_samples, ...)
            Resampled target.

        """
        Xt = self.transform(X)
        yr = self.resample(y, X)
        return Xt, yr

    def fit_transform_resample(self, X, y, **fit_params):
        """Fit to data, then transform the input and resample the target.

        Calls ``fit`` on `X` and `y` with the optional parameters
        `fit_params`, then returns the result of :meth:`transform_resample`
        called on the fitted object.

        Parameters
        ----------
        X : ndarray of shape (n_samples, ...)
            Input data.

        y : ndarray of shape (n_samples,)
            Target data.

        **fit_params
            Optional keyword arguments forwarded to ``fit``.

        Returns
        -------
        Xt : ndarray of shape (n_samples, ...)
            Transformed input.

        yr : ndarray of shape (n_samples, ...)
            Resampled target.

        """
        fitted = self.fit(X, y, **fit_params)
        return fitted.transform_resample(X, y)
class PlotterMixin:
    """Mixin class for all plotters in giotto-tda.

    The methods defined here rely on ``fit``, ``transform`` and ``plot``,
    which the host class is expected to provide.
    """

    def fit_transform_plot(self, X, y=None, sample=0, **plot_params):
        """Fit to data, then apply :meth:`transform_plot`.

        Parameters
        ----------
        X : ndarray of shape (n_samples, ...)
            Input data.

        y : ndarray of shape (n_samples,) or None
            Target values for supervised problems.

        sample : int
            Sample to be plotted.

        **plot_params
            Optional plotting parameters.

        Returns
        -------
        Xt : ndarray of shape (1, ...)
            Transformed one-sample slice from the input.

        """
        self.fit(X, y)
        Xt = self.transform_plot(X, sample=sample, **plot_params)
        return Xt

    def transform_plot(self, X, sample=0, **plot_params):
        """Take a one-sample slice from the input collection and transform it.

        Before returning the transformed object, plot the transformed sample.

        Parameters
        ----------
        X : ndarray of shape (n_samples, ...)
            Input data.

        sample : int
            Sample to be plotted. Negative values count from the end of `X`,
            so ``-1`` selects the last sample.

        **plot_params
            Optional plotting parameters.

        Returns
        -------
        Xt : ndarray of shape (1, ...)
            Transformed one-sample slice from the input.

        """
        # A slice (rather than an index) keeps the leading sample axis.
        # ``X[-1:0]`` would be empty, so the last sample needs an open-ended
        # stop instead of ``sample + 1``.
        stop = None if sample == -1 else sample + 1
        Xt = self.transform(X[sample:stop])
        # The slice holds a single sample, hence it is plotted at position 0.
        self.plot(Xt, sample=0, **plot_params)
        return Xt