freqfilter: Frequency-domain filter layer

class hypercoil.nn.freqfilter.FrequencyDomainFilter(num_channels: int, clamp_points: Optional[Tensor] = None, clamp_values: Optional[Tensor] = None, freq_dim: Optional[int] = None, time_dim: Optional[int] = None, filter: Callable = <function product_filtfilt>, *, key: jax.random.PRNGKey)

Filtering or convolution via transfer function multiplication in the frequency domain.

Each time series in the input dataset is transformed into the frequency domain, where it is multiplied by the complex-valued transfer function of each filter in the module’s bank. Each filtered frequency spectrum is then transformed back into the time domain. To ensure a zero-phase filter, the filtered time series are reversed, the process is repeated, and the result is reversed again to restore the original time order.
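The forward-backward procedure can be sketched in a few lines. This is a minimal NumPy illustration, not the library's product_filtfilt: the name filtfilt_freq is hypothetical, and the sketch assumes a real-valued FFT (numpy.fft.rfft) over the last axis. jax.numpy.fft provides the same rfft and irfft calls. The final check shows why the result is zero-phase: in this sketch, the two passes together amount to multiplying the spectrum once by \(|H|^2\), which is real and nonnegative.

```python
import numpy as np


def filtfilt_freq(x: np.ndarray, transfer: np.ndarray) -> np.ndarray:
    """Hypothetical zero-phase filter: forward-backward multiplication in the
    frequency domain (a sketch, not hypercoil's product_filtfilt)."""
    n = x.shape[-1]

    def one_pass(signal: np.ndarray) -> np.ndarray:
        # Real FFT -> multiply by the transfer function -> inverse real FFT.
        return np.fft.irfft(np.fft.rfft(signal, axis=-1) * transfer, n=n, axis=-1)

    forward = one_pass(x)
    # Reverse in time, filter again, and reverse back: the phase introduced
    # by the first pass is cancelled by the second.
    backward = one_pass(forward[..., ::-1])
    return backward[..., ::-1]


rng = np.random.default_rng(0)
T = 64
D = T // 2 + 1  # number of real-FFT bins
x = rng.standard_normal(T)

# A transfer function with arbitrary amplitude and phase. The phase is kept at
# zero in the DC and Nyquist bins, which are real for a real-valued signal.
amplitude = rng.uniform(0.0, 1.0, D)
phase = rng.uniform(-np.pi, np.pi, D)
phase[0] = phase[-1] = 0.0
H = amplitude * np.exp(1j * phase)

y = filtfilt_freq(x, H)
# Equivalent single multiplication by |H|^2, a real (zero-phase) response.
y_equiv = np.fft.irfft(np.fft.rfft(x) * np.abs(H) ** 2, n=T)
print(np.allclose(y, y_equiv))  # True
```

In this sketch the effective gain is the squared amplitude, so a bin passed at amplitude 0.5 in one pass is attenuated to 0.25 overall.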

Dimension:
Input : \((N, *, C, T)\)

\(N\) denotes the batch size, \(*\) denotes any number of intervening dimensions, \(C\) denotes the number of data channels or variables, and \(T\) denotes the number of time points or observations per channel.

Output : \((N, *, F, C, T)\)

\(F\) denotes the number of filters.
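The shape change from \((N, *, C, T)\) to \((N, *, F, C, T)\) can be illustrated with broadcasting: a new filter axis is inserted just before the channel axis, and the \((F, D)\) bank is broadcast across all channels. This is a hypothetical single-pass NumPy sketch (apply_filter_bank is not a library function), assuming \(D = \lfloor T/2 \rfloor + 1\) bins from a real FFT.

```python
import numpy as np


def apply_filter_bank(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: apply F transfer functions to input (N, *, C, T).

    weight has shape (F, D) with D = T // 2 + 1. Returns (N, *, F, C, T).
    Only a single forward pass is shown, for brevity.
    """
    T = x.shape[-1]
    Xf = np.fft.rfft(x, axis=-1)       # (N, *, C, D)
    Xf = np.expand_dims(Xf, axis=-3)   # (N, *, 1, C, D): new filter axis
    bank = weight[:, None, :]          # (F, 1, D): broadcast over channels
    return np.fft.irfft(Xf * bank, n=T, axis=-1)  # (N, *, F, C, T)


N, extra, C, T, F = 2, 3, 4, 50, 5
x = np.random.default_rng(0).standard_normal((N, extra, C, T))
weight = np.ones((F, T // 2 + 1), dtype=complex)  # F all-pass filters
out = apply_filter_bank(x, weight)
print(out.shape)  # (2, 3, 5, 4, 50)
```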

Parameters:
filter_specs : list(FreqFilterSpec)

A list of filter specifications implemented as FreqFilterSpec objects. These determine the filter bank that is applied to the input. Consult the FreqFilterSpec documentation for further details.

freq_dim : int or None

Number of frequency bins. This must be consistent with the length of the time series supplied as input. If the dimension in the frequency domain is not known in advance, the time_dim argument (the length of the time series) can be provided instead. Exactly one of freq_dim and time_dim must be specified.

time_dim : int or None

Number of time points in the input time series. Exactly one of freq_dim and time_dim must be specified.

filter : callable (default product_filtfilt)

A callable that takes as its arguments an input time series and a set of transfer functions. It transforms the input into the frequency domain, multiplies it by the transfer function bank, and transforms the product back into the time domain. By default, the product_filtfilt function is used to ensure a zero-phase filter.

domain : Domain object (default AmplitudeAtanh)

A domain object from hypercoil.init.domain, used to specify the domain of the filter spectrum. An Identity object yields the raw transfer function, while an AmplitudeAtanh object stores the amplitude of each bin in the inverse tanh (atanh) domain, so that tanh is applied to the amplitudes before the filter is applied to the input. Using the AmplitudeAtanh domain thereby constrains transfer function amplitudes to [0, 1) and prevents explosive gain. An AmplitudeMultiLogit domain can be used to instantiate and learn a parcellation over frequencies.
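Assuming the transform is a real FFT, as with numpy.fft.rfft, freq_dim and time_dim are related by \(D = \lfloor T/2 \rfloor + 1\). The hypothetical helper below (resolve_freq_dim is not part of hypercoil) resolves the number of bins from whichever argument is given and enforces that exactly one of them is specified.

```python
from typing import Optional

import numpy as np


def resolve_freq_dim(freq_dim: Optional[int] = None,
                     time_dim: Optional[int] = None) -> int:
    """Hypothetical helper: return the number of real-FFT frequency bins."""
    if (freq_dim is None) == (time_dim is None):
        raise ValueError("Specify exactly one of freq_dim and time_dim.")
    if freq_dim is not None:
        return freq_dim
    # A real FFT of T samples keeps only the non-negative frequencies:
    # bins 0 (DC) through T // 2 (the Nyquist bin when T is even).
    return time_dim // 2 + 1


print(resolve_freq_dim(time_dim=100))        # 51
print(resolve_freq_dim(time_dim=101))        # 51
print(np.fft.rfft(np.zeros(100)).shape[-1])  # 51
```

The mapping is not invertible: 100 and 101 samples both give 51 bins, so an inverse real FFT needs the original length (the n argument of numpy.fft.irfft).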
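A callable following the same convention can, in principle, be passed as filter. The sketch below assumes the callable is invoked as filter(input, weight); both functions are hypothetical stand-ins, not library code. With a pure circular delay as the transfer function, a single pass shifts the signal, whereas the forward-backward version cancels the shift. This illustrates what would be lost if the default zero-phase filter were replaced with a single-pass one.

```python
import numpy as np


def single_pass_filter(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Hypothetical filter callable: one forward multiplication, no reversal."""
    n = x.shape[-1]
    return np.fft.irfft(np.fft.rfft(x, axis=-1) * weight, n=n, axis=-1)


def forward_backward_filter(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Hypothetical zero-phase callable: filter, reverse, filter, reverse."""
    once = single_pass_filter(x, weight)
    return single_pass_filter(once[..., ::-1], weight)[..., ::-1]


T, delay = 32, 3
k = np.arange(T // 2 + 1)
# Transfer function of a pure circular delay of `delay` samples (|H| = 1).
H = np.exp(-2j * np.pi * k * delay / T)
x = np.random.default_rng(1).standard_normal(T)

print(np.allclose(single_pass_filter(x, H), np.roll(x, delay)))  # True
print(np.allclose(forward_backward_filter(x, H), x))             # True
```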
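The effect of the AmplitudeAtanh domain can be sketched as a pair of maps that act on the amplitude of each complex bin and leave its phase untouched. This assumes that the stored preweight holds the atanh of each amplitude and that tanh is applied when mapping out of the domain. amplitude_tanh and amplitude_atanh are hypothetical names, not the hypercoil.init.domain API.

```python
import numpy as np


def amplitude_tanh(preweight: np.ndarray) -> np.ndarray:
    """Hypothetical image map: squash each amplitude with tanh, keep phase."""
    return np.tanh(np.abs(preweight)) * np.exp(1j * np.angle(preweight))


def amplitude_atanh(weight: np.ndarray) -> np.ndarray:
    """Hypothetical preimage map: atanh of each amplitude (< 1), keep phase."""
    return np.arctanh(np.abs(weight)) * np.exp(1j * np.angle(weight))


# Unconstrained preweights, including a negative real value and a large amplitude.
preweight = np.array([0.0, 0.5 + 0.5j, -3.0, 2j, 5.0 - 5.0j])
weight = amplitude_tanh(preweight)
print(np.all(np.abs(weight) < 1.0))                    # True
print(np.allclose(amplitude_atanh(weight), preweight))  # True
```

Because \(\tanh(r) < 1\) for every finite \(r \ge 0\), any preweight yields a gain strictly below 1 in exact arithmetic. In floating point, tanh rounds to exactly 1.0 for very large amplitudes.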

Attributes:
preweight : Tensor \((F, D)\)

Filter bank transfer functions in the module’s domain. F denotes the total number of filters in the bank, and D denotes the dimension of the input dataset in the frequency domain. The weights are initialised to emulate each of the filters specified in the filter_specs parameter following the freqfilter_init_ function.

weight : Tensor \((F, D)\)

The transfer function weights as seen by the input dataset in the frequency domain. These are obtained by mapping the preweight out of the specified predomain and then applying any clamps declared in the input specifications.

clamp_points : Tensor \((F, D)\)

Boolean-valued tensor mask indexing points in the transfer function that should be clamped to particular values. Any points so indexed will not be learnable. If this is None, then no clamp is applied.

clamp_values : Tensor \((V)\)

Tensor containing values to which the transfer functions are clamped. V denotes the total number of values to be clamped across all transfer functions. If this is None, then no clamp is applied.
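The last step in computing weight, applying the clamps, can be sketched as follows. The sketch assumes the clamps are applied after the domain map by overwriting the masked entries, with clamp_values consumed in row-major order of the True entries in clamp_points. apply_clamps is a hypothetical helper, not a library function. Because clamped entries are overwritten, their preweights cannot influence the output, which is why they are not learnable.

```python
from typing import Optional

import numpy as np


def apply_clamps(
    weight: np.ndarray,
    clamp_points: Optional[np.ndarray] = None,
    clamp_values: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Hypothetical helper: overwrite masked transfer-function entries.

    weight:       (F, D) transfer functions already mapped out of the predomain.
    clamp_points: (F, D) boolean mask of entries to clamp.
    clamp_values: (V,) values, where V is the number of True entries; they fill
                  the masked positions in row-major order.
    """
    if clamp_points is None or clamp_values is None:
        return weight  # no clamp is applied
    if clamp_points.sum() != clamp_values.shape[0]:
        raise ValueError("clamp_values must have one entry per clamped point")
    # Copy first because NumPy assignment is in place. (JAX arrays are
    # immutable, so a JAX version would need an out-of-place update instead.)
    clamped = weight.copy()
    clamped[clamp_points] = clamp_values
    return clamped


F, D = 2, 5
# Stand-in for the mapped preweight: every bin passes at half amplitude.
weight = np.full((F, D), 0.5 + 0.0j)
clamp_points = np.zeros((F, D), dtype=bool)
clamp_points[0, 0] = True   # pin the DC bin of filter 0
clamp_points[1, -1] = True  # pin the highest-frequency bin of filter 1
clamp_values = np.array([0.0, 1.0])
clamped = apply_clamps(weight, clamp_points, clamp_values)
```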

Methods

__call__(input, *[, key])

Transform the input into the frequency domain, filter it, and transform the filtered signal back.

filter(weight, **params)

Perform zero-phase digital filtering of a signal via multiplication in the frequency domain.

from_specs