"""
Module for signal processing.
"""
import cmath
import numba
import numpy as np
import scipy.fftpack
from scipy.signal import butter, filtfilt, hilbert
# Public names exported by ``from <this module> import *``.
__all__ = [
    "Filter",
    "ButterworthBandpass",
    "ComposedFilter",
    "Hilbert",
    "NoFilter",
    "Abs",
    "Gaussian",
    "Hanning",
    "rfft_to_hilbert",
    "timeshift_spectra",
]
[docs]
class Filter:
    """
    Abstract base class for filters.

    To implement a new filter, derive from this class and implement:
      - ``__init__``, which initialises the filter (taking as many arguments as required),
      - ``__call__``, which processes the data given as argument and returns the result,
      - ``__str__``, which returns a human-readable description of the filter.

    Filters can be composed with the ``+`` operator: ``(f + g)(x)`` applies ``g``
    first, then ``f``.
    """

    def __add__(self, inner_filter):
        """Compose this filter (applied last) with ``inner_filter`` (applied first)."""
        composed = ComposedFilter(self, inner_filter)
        return composed

    def __call__(self, *args, **kwargs):
        """Process data; derived classes must override this method."""
        raise NotImplementedError

    def __str__(self):
        """Fallback description, used when a derived class does not override it."""
        description = "Unspecified filter"
        return description
[docs]
class NoFilter(Filter):
    """
    Identity filter: hands its input back unchanged.
    """

    def __str__(self):
        return "No filter"

    def __call__(self, arr):
        # Pass-through: the very same object is returned, no copy is made.
        return arr
[docs]
class ComposedFilter(Filter):
    """
    Composed filter.

    Holds a flat list of sub-filters in ``ops``. When called, the sub-filters are
    applied from the last one to the first one, so that
    ``ComposedFilter(f, g)(arr)`` computes ``f(g(arr))``. Composed filters given
    to the constructor are flattened into ``ops``.

    Parameters
    ----------
    outer_filters : Filter
        Filter (or composed filter) applied last.
    inner_filters : Filter
        Filter (or composed filter) applied first.
    """

    def __init__(self, outer_filters, inner_filters):
        try:
            # outer_filters is a composed filter: reuse its list of sub-filters
            outer_ops = outer_filters.ops
        except AttributeError:
            # outer_filters is a single filter
            outer_ops = [outer_filters]
        try:
            inner_ops = inner_filters.ops
        except AttributeError:
            inner_ops = [inner_filters]
        # ``+`` builds a new list, so the operands' own ``ops`` lists are not shared
        self.ops = outer_ops + inner_ops

    def __len__(self):
        return len(self.ops)

    def __call__(self, arr, *, filter_kwargs=None, **kwargs):
        """
        Apply the sub-filters on ``arr``, innermost (last in ``ops``) first.

        Parameters
        ----------
        arr
            Array to process.
        filter_kwargs : mapping, optional
            Extra keyword arguments for the ``__call__`` method of some of the
            sub-filters, indexed by the filter instance itself, e.g.
            ``{butterworth_filter: {"axis": 0}}``. Sub-filters without an entry
            (or whose entry is ``None``) are called with the data only. A filter
            that appears several times in ``ops`` receives its arguments each time.
        kwargs
            Not supported. Keyword argument names are always strings, so they can
            never designate a filter instance; any such argument raises
            ``ValueError``. It is kept for backward compatibility.

        Returns
        -------
        filtered_arr

        Raises
        ------
        ValueError
            If ``kwargs`` is not empty, or if a key of ``filter_kwargs`` is not one
            of the sub-filters. Raised before any sub-filter is run.
        """
        # Copy so that the caller's mapping is never modified
        per_filter = {} if filter_kwargs is None else dict(filter_kwargs)

        # Validate everything up front instead of after running all filters.
        unexpected = list(kwargs)
        unexpected.extend(key for key in per_filter if key not in self.ops)
        if unexpected:
            raise ValueError(f"Unexpected keys: {unexpected}")

        out = arr
        for op in reversed(self.ops):
            # ``per_filter`` is only queried when non-empty, so sub-filters do not
            # have to be hashable in the common no-argument case.
            op_kwargs = per_filter.get(op) if per_filter else None
            if op_kwargs is None:
                out = op(out)
            else:
                out = op(out, **op_kwargs)
        return out

    def __str__(self):
        return "\n".join([str(op) for op in self.ops])
[docs]
class ButterworthBandpass(Filter):
    """
    Butterworth bandpass filter, applied with zero phase (``scipy.signal.filtfilt``).

    Parameters
    ----------
    order : int
        Order of the Butterworth design passed to ``scipy.signal.butter``.
    cutoff_min, cutoff_max : float
        Lower and upper cutoff frequencies, in Hz.
    time : arim.Time
        Time object. This filter can be used only on data sampled consistently
        with the attribute ``time``; ``time.step`` is used as the sampling period.
    """

    def __init__(self, order, cutoff_min, cutoff_max, time):
        nyquist_freq = 0.5 / time.step
        # Multiplying by 1.0 promotes integer cutoffs to floats
        low, high = cutoff_min * 1.0, cutoff_max * 1.0
        self.order = order
        self.cutoff_min = low
        self.cutoff_max = high
        # scipy expects the band edges normalised by the Nyquist frequency
        band = np.array([low, high]) / nyquist_freq
        self.b, self.a = butter(order, band, btype="bandpass")

    def __str__(self):
        low_mhz = self.cutoff_min * 1e-6
        high_mhz = self.cutoff_max * 1e-6
        return f"{type(self).__qualname__} [{low_mhz:.1f}, {high_mhz:.1f}] MHz order {self.order}"

    def __call__(self, arr, axis=-1, **kwargs):
        """
        Filter ``arr`` forward and backward with ``scipy.signal.filtfilt`` (zero phase).

        Parameters
        ----------
        arr
            Data to filter.
        axis
            Axis along which the filter is applied. Default: -1.
        kwargs
            Extra keyword arguments forwarded to ``scipy.signal.filtfilt``.

        Returns
        -------
        filtered_arr
            C-contiguous filtered array.
        """
        filtered = filtfilt(self.b, self.a, arr, axis=axis, **kwargs)
        return np.ascontiguousarray(filtered)

    def __repr__(self):
        return "<{} at {}>".format(self, hex(id(self)))
[docs]
class Hanning(Filter):
    """
    Hanning filter - multiply the spectrum by a Hann window.

    The window is defined over the ``nsamples`` FFT bins. It peaks at
    ``centre_freq`` and is exactly zero outside
    ``[centre_freq - half_bandwidth, centre_freq + half_bandwidth]``.
    The output is complex. When this band lies between 0 and the Nyquist
    frequency, the negative frequencies are removed and the output is the
    analytic signal of the band-passed data. Unlike ``scipy.signal.hilbert``,
    the positive frequencies are not doubled.

    Parameters
    ----------
    nsamples : int
        ``len(time)``; the filter is meant to be applied on arrays whose last
        axis has this length.
    centre_freq : float
        Centre frequency of the window, in Hz.
    half_bandwidth : float
        Half-width of the window, in Hz.
    time : arim.Time
        Time object. This filter can be used only on data sampled consistently
        with the attribute ``time`` (only ``time.step`` is read).
    """

    def __init__(self, nsamples, centre_freq, half_bandwidth, time):
        # Sampling frequency; frequencies below are expressed as fractions of it
        max_freq = 1.0 / (time.step)
        peak_pos_fract = centre_freq / max_freq
        half_width_fract = half_bandwidth / max_freq
        # NOTE(review): bin k is mapped to k / (nsamples - 1) of the sampling
        # frequency, whereas the FFT bin frequency is k / nsamples of it. Kept
        # as is to preserve existing results.
        r = np.arange(nsamples) / (nsamples - 1)
        r1 = 0.5 * (1 + np.cos((r - peak_pos_fract) / half_width_fract * np.pi))
        self.samples = nsamples
        self.centre_freq = centre_freq
        self.half_bandwidth = half_bandwidth
        self.max_freq = max_freq
        # Keep only the main lobe of the raised cosine
        self.filter_window = r1 * np.logical_and(
            (r >= (peak_pos_fract - half_width_fract)),
            (r <= peak_pos_fract + half_width_fract),
        )

    def __str__(self):
        # Previously printed max_freq twice (once as "order") and never the centre.
        return "{} centre {:.1f} MHz, half-bandwidth {:.1f} MHz".format(
            self.__class__.__qualname__,
            self.centre_freq * 1e-6,
            self.half_bandwidth * 1e-6,
        )

    def __call__(self, arr):
        arr = np.asarray(arr)
        # broadcast window to (1, 1, ..., numsamples)
        window = np.array(self.filter_window, ndmin=arr.ndim)
        # FFT/IFFT along the last axis
        return np.fft.ifft(np.fft.fft(arr) * window)
[docs]
class Hilbert(Filter):
    """
    Analytic-signal filter: returns ``signal + i * hilbert_signal``, where
    ``hilbert_signal`` is the Hilbert transform of ``signal`` (computed with
    ``scipy.signal.hilbert``).
    """

    def __str__(self):
        return "Hilbert transform"

    def __call__(self, arr, axis=-1):
        analytic = hilbert(arr, axis=axis)
        return analytic
[docs]
class Abs(Filter):
    """
    Magnitude filter: returns the absolute value (modulus for complex data) of a signal.
    """

    def __str__(self):
        return "Absolute value"

    def __call__(self, arr):
        magnitude = np.abs(arr)
        return magnitude
[docs]
class Gaussian(Filter):
    """
    Gaussian filter - multiply the spectrum by a Gaussian window.

    Same window as in BRAIN, **but** by default (``force_zero=True``) the window
    is set to exactly zero where it falls below ``-db_down``; BRAIN does not
    zero it.

    The window is defined over the ``nsamples`` FFT bins. It peaks at
    ``centre_freq`` and its amplitude is ``10 ** (-db_down / 20)`` (i.e.
    ``-db_down`` dB) at ``centre_freq ± half_bandwidth``. The output is complex.
    When the non-zero part of the window lies between 0 and the Nyquist
    frequency, the negative frequencies are removed and the output is the
    analytic signal of the band-passed data. Unlike ``scipy.signal.hilbert``,
    the positive frequencies are not doubled.

    Parameters
    ----------
    nsamples : int
        ``len(time)``; the filter is meant to be applied on arrays whose last
        axis has this length.
    centre_freq : float
        Centre frequency of the window, in Hz.
    half_bandwidth : float
        Offset from ``centre_freq``, in Hz, at which the window reaches ``-db_down``.
    time : arim.Time
        Time object. This filter can be used only on data sampled consistently
        with the attribute ``time`` (only ``time.step`` is read).
    force_zero : bool
        If True (default), window values below ``10 ** (-db_down / 20)`` are
        replaced by exactly zero.
    db_down : float
        Attenuation in dB at ``centre_freq ± half_bandwidth``, also used as the
        threshold of ``force_zero``. Default: 40.
    """

    def __init__(
        self, nsamples, centre_freq, half_bandwidth, time, force_zero=True, db_down=40.0
    ):
        # Linear amplitude corresponding to -db_down dB
        fract = np.power(10, -db_down / 20.0)
        # Sampling frequency; frequencies below are expressed as fractions of it
        max_freq = 1.0 / (time.step)
        peak_pos_fract = centre_freq / max_freq
        half_width_fract = half_bandwidth / max_freq
        # NOTE(review): bin k is mapped to k / (nsamples - 1) of the sampling
        # frequency, whereas the FFT bin frequency is k / nsamples of it. Kept
        # as is to preserve existing (BRAIN-compatible) results.
        r = np.arange(nsamples) / (nsamples - 1) - peak_pos_fract
        # Width chosen so that the window equals ``fract`` at r = ±half_width_fract
        r1 = half_width_fract / (np.sqrt(-np.log(fract)))
        self.samples = nsamples
        self.centre_freq = centre_freq
        self.half_bandwidth = half_bandwidth
        self.max_freq = max_freq
        self.filter_window = np.exp(-np.power(r / r1, 2))
        if force_zero:
            self.filter_window[self.filter_window < fract] = 0

    def __str__(self):
        # Previously printed max_freq twice (once as "order") and never the centre.
        return "{} centre {:.1f} MHz, half-bandwidth {:.1f} MHz".format(
            self.__class__.__qualname__,
            self.centre_freq * 1e-6,
            self.half_bandwidth * 1e-6,
        )

    def __call__(self, arr):
        arr = np.asarray(arr)
        # broadcast window to (1, 1, ..., numsamples)
        window = np.array(self.filter_window, ndmin=arr.ndim)
        # FFT/IFFT along the last axis
        return np.fft.ifft(np.fft.fft(arr) * window)
[docs]
def rfft_to_hilbert(xf, n, axis=-1):
    """
    Convert the Fourier transform of a real signal to the analytic signal.

    This is equivalent to, but faster than::

        scipy.signal.hilbert(np.fft.irfft(xf, n))

    where typically::

        xf = np.fft.rfft(x)
        n = len(x)

    ``xf`` holds the non-negative frequency part of the spectrum, as obtained with
    ``numpy.fft.rfft``. Frequencies missing at the end of ``xf`` (``xf`` shorter
    than ``n // 2 + 1`` along ``axis``) are treated as zeros. Extra trailing
    entries, for example a full spectrum from ``numpy.fft.fft``, are ignored.

    Parameters
    ----------
    xf : array_like
        Input spectrum.
    n : int
        Length of the time domain signal.
    axis : int
        Frequency axis of ``xf``. Default: -1

    Returns
    -------
    out : complex ndarray
        Analytic signal, of length ``n`` along ``axis``.
    """
    # cf. implementation of
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.hilbert.html
    xf = np.asarray(xf)
    if xf.ndim == 0:
        h = 1.0
    else:
        num_freq = xf.shape[axis]
        # Analytic-signal weights: 1 for DC (and for Nyquist when n is even),
        # 2 for the other positive frequencies, 0 beyond. Slices clip silently,
        # so spectra shorter than n // 2 + 1 are handled for odd and even n.
        h = np.zeros(num_freq)
        h[:1] = 1
        h[1 : (n + 1) // 2] = 2
        if n % 2 == 0 and n // 2 < num_freq:
            h[n // 2] = 1
    if xf.ndim > 1:
        # Reshape h so that it broadcasts along ``axis`` only
        ind = [np.newaxis] * xf.ndim
        ind[axis] = slice(None)
        h = h[tuple(ind)]
    # ifft with length n zero-pads (or truncates) the weighted spectrum
    return scipy.fftpack.ifft(h * xf, n, axis)
# Kernel used by ``timeshift_spectra`` when the spectrum has a single frequency.
# Layout "(),(),(numfreq)->(numfreq)": ``delays`` and ``unshifted_x`` are scalar
# core arguments, which the kernel receives as 1-element arrays (hence the ``[0]``
# indexing below). ``out`` has one value per entry of ``freq_array``; any leading
# (loop) dimensions are broadcast by the gufunc machinery.
@numba.guvectorize(
    [(numba.float64[:], numba.complex128[:], numba.float64[:], numba.complex128[:])],
    "(),(),(numfreq)->(numfreq)",
    nopython=True,
    target="parallel",
    cache=True,
)
def _timeshift_spectra_singlef(delays, unshifted_x, freq_array, out=None):
    """Write ``exp(-2j pi f delay) * x`` into ``out`` for every frequency ``f``."""
    for freq_idx in range(freq_array.shape[0]):
        out[freq_idx] = (
            cmath.exp(-2j * np.pi * freq_array[freq_idx] * delays[0]) * unshifted_x[0]
        )
# Kernel used by ``timeshift_spectra`` when the spectrum has one value per frequency.
# Layout "(),(numfreq),(numfreq)->(numfreq)": ``delays`` is a scalar core argument,
# which the kernel receives as a 1-element array (hence ``delays[0]``).
# ``unshifted_x``, ``freq_array`` and ``out`` share the frequency dimension; any
# leading (loop) dimensions are broadcast by the gufunc machinery.
@numba.guvectorize(
    [(numba.float64[:], numba.complex128[:], numba.float64[:], numba.complex128[:])],
    "(),(numfreq),(numfreq)->(numfreq)",
    nopython=True,
    target="parallel",
    cache=True,
)
def _timeshift_spectra_multif(delays, unshifted_x, freq_array, out=None):
    """Write ``exp(-2j pi f_k delay) * x_k`` into ``out[k]`` for every frequency index ``k``."""
    for freq_idx in range(freq_array.shape[0]):
        out[freq_idx] = (
            cmath.exp(-2j * np.pi * freq_array[freq_idx] * delays[0])
            * unshifted_x[freq_idx]
        )
[docs]
def timeshift_spectra(unshifted_x, delays, freq_array):
    """Time-shift spectra in the frequency domain.

    Multiplies the spectra by the phase ramp ``exp(-i omega delay)``, with
    ``omega = 2 pi freq_array``:

    - if ``num_x_freq == numfreq``, returns ``X(omega) exp(-i omega delay)``;
    - if ``num_x_freq == 1``, returns ``X(omega_0) exp(-i omega delay)``, i.e.
      the single spectral value is used for every frequency of ``freq_array``.

    Parameters
    ----------
    unshifted_x : ndarray
        Shape (shape1, num_x_freq)
    delays : ndarray
        Shape (shape1)
    freq_array : ndarray
        Shape (numfreq)

    Returns
    -------
    shifted_x
        Shape (shape1, numfreq)
    """
    if unshifted_x.shape[-1] != 1:
        return _timeshift_spectra_multif(delays, unshifted_x, freq_array)
    # Single frequency: drop the trailing axis so the kernel sees a scalar spectrum.
    return _timeshift_spectra_singlef(delays, unshifted_x[..., 0], freq_array)