AtlasLinear: Linear atlas layer

class hypercoil.nn.atlas.AtlasLinear(n_locations: Dict[str, int], n_labels: Dict[str, int], limits: Dict[str, Tuple[int, int]] = None, decoder: Dict[str, Tensor] | None = None, normalisation: Literal['mean', 'absmean', 'zscore', 'psc'] | None = 'mean', forward_mode: Literal['map', 'project'] = 'map', concatenate: bool = True, *, key: jax.random.PRNGKey)

Time series extraction from an atlas via a linear map.

Dimension:
Input : \((N, *, V, T)\)

N denotes batch size, * denotes any number of intervening dimensions, V denotes total number of voxels or spatial locations, T denotes number of time points or observations.

Output : \((N, *, L, T)\)

L denotes number of labels in the provided atlas.
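The shape contract above can be illustrated with a plain NumPy sketch. This is not the hypercoil implementation: the array names and sizes are made up for the example, and a single intervening dimension `C` stands in for `*`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: batch N, one intervening dimension C, V locations,
# T time points, L labels.
N, C, V, T, L = 2, 3, 10, 20, 4

atlas = rng.random((L, V))              # one row of location weights per label
x = rng.standard_normal((N, C, V, T))   # input of shape (N, *, V, T)

# Matrix multiplication broadcasts over the leading dimensions,
# so only the V axis is contracted and replaced by L.
out = atlas @ x
print(out.shape)
```

The leading dimensions pass through unchanged; only the location axis is replaced by the label axis.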

Note

To initialise the atlas linear module from a pre-defined atlas, use the class method from_atlas() or the hypercoil.init.atlas.AtlasInitialiser class after defining the atlas as a hypercoil.init.atlas.BaseAtlas instance.

If the module is initialised without an atlas, it is initialised from a Dirichlet distribution with concentration \(50\) for each label.
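As a rough illustration of what a Dirichlet initialisation with concentration 50 looks like, the NumPy sketch below draws one sample over the labels for every location. The axis over which the distribution is defined is an assumption made for this example; the source does not state it, and the library's own initialiser may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
n_labels, n_locations = 5, 100   # hypothetical sizes

# Assumption: each location receives a distribution over labels.
concentration = np.full(n_labels, 50.0)
weights = rng.dirichlet(concentration, size=n_locations).T   # shape (L, V)

# Every column is a probability vector over the labels.
print(weights.shape, weights.sum(axis=0)[:3])
```

A large, equal concentration produces samples that cluster tightly around the uniform distribution, so under this reading every label starts with similar weight at every location.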

Parameters:
n_locations : Dict[str, int]

Number of locations (e.g., voxels or vertices) in each compartment.

n_labels : Dict[str, int]

Number of labels in each compartment.

limits : Dict[str, Tuple[int, int]], optional (default=None)

Limits of each compartment. The first element of the tuple denotes the lower limit and the second element denotes the size, i.e., the number of locations in the compartment – not the upper limit. If None, the limits are set to the default values, which are defined using the cumulative sum of the number of locations in each compartment.
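The default described above can be sketched in a few lines. The helper name `cumulative_limits` and the compartment names are hypothetical, and the sketch assumes that compartments are laid out in dictionary order starting from offset 0.

```python
from itertools import accumulate

def cumulative_limits(n_locations):
    """Return {compartment: (lower limit, size)} from per-compartment sizes."""
    sizes = list(n_locations.values())
    # The lower limit of each compartment is the running total of the sizes
    # of all compartments that precede it.
    offsets = [0] + list(accumulate(sizes))[:-1]
    return {
        name: (offset, size)
        for name, offset, size in zip(n_locations, offsets, sizes)
    }

n_locations = {"cortex_L": 4, "cortex_R": 5, "subcortex": 3}
limits = cumulative_limits(n_locations)
print(limits)
```

Note that the second element of each tuple is the compartment size, so the compartment covers locations `lower` to `lower + size - 1`.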

decoder : Optional[Dict[str, Tensor]], optional (default=None)

Decoder for labels in each compartment. The decoder is an integer-valued tensor that defines the map from row numbers to label numbers. If None, the decoder corresponds to the identity map – i.e., the row numbers are the same as the label numbers.
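A small NumPy sketch of how such a decoder can be read, using made-up label numbers and a toy output array. It only illustrates the row-to-label lookup and does not reproduce how the module applies the decoder internally.

```python
import numpy as np

# Hypothetical decoder for one compartment:
# row 0 holds label 3, row 1 holds label 7, row 2 holds label 12.
decoder = np.array([3, 7, 12])

# Toy output with 3 rows (one per atlas row) and 4 time points.
row_ts = np.arange(12.0).reshape(3, 4)

# Look up time series by label number rather than by row number.
ts_by_label = {int(label): row_ts[row] for row, label in enumerate(decoder)}

# Inverse lookup: which row holds a given label.
row_of_label = {int(label): row for row, label in enumerate(decoder)}
print(row_of_label)
```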

normalisation : 'mean', 'absmean', 'zscore', 'psc', or None (default 'mean')

Strategy for normalising across voxels and generating a representative time series for each label.

  • None or sum: No normalisation, i.e., use the weighted sum over voxel time series.

  • mean: Compute the weighted mean over voxel time series.

  • absmean: Compute the weighted mean over voxel time series, treating any negative voxel weights as though they were positive.

  • zscore: Transform the sum of time series such that its temporal mean is 0 and its temporal standard deviation is 1.

  • psc: Transform the time series such that its value indicates the percent signal change from the mean.
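The strategies listed above can be sketched in NumPy as follows. This is one reading of the descriptions, not the library's code: the function name `label_time_series` is hypothetical, `absmean` is taken to replace every weight by its absolute value, `zscore` uses the population standard deviation, and `psc` is taken to be 100 times the relative deviation from the temporal mean.

```python
import numpy as np

def label_time_series(atlas, x, normalisation="mean"):
    """Sketch: atlas has shape (L, V), x has shape (..., V, T)."""
    if normalisation == "absmean":
        # Treat negative weights as though they were positive.
        atlas = np.abs(atlas)
    ts = atlas @ x   # weighted sum over locations, shape (..., L, T)
    if normalisation is None:
        return ts
    if normalisation in ("mean", "absmean"):
        # Divide by the total weight of each label to get a weighted mean.
        return ts / atlas.sum(axis=-1, keepdims=True)
    mean = ts.mean(axis=-1, keepdims=True)   # temporal mean
    if normalisation == "zscore":
        return (ts - mean) / ts.std(axis=-1, keepdims=True)
    if normalisation == "psc":
        return 100.0 * (ts - mean) / mean
    raise ValueError(f"unknown normalisation: {normalisation!r}")

rng = np.random.default_rng(0)
atlas = rng.random((3, 8))
x = 100.0 + rng.standard_normal((8, 50))   # positive signal, so psc is well defined
z = label_time_series(atlas, x, "zscore")
p = label_time_series(atlas, x, "psc")
```

The `psc` branch divides by the temporal mean, so it is only meaningful when that mean is well away from zero.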

forward_mode : 'map' or 'project' (default 'map')

Strategy for extracting regional time series from parcels.

  • 'map': Simple linear map. Given a compartment atlas \(A \in \mathbb{R}^{(L \times V)}\) and a vertex-wise or voxel-wise input time series \(T_{in} \in \mathbb{R}^{(V \times T)}\), returns

    \(T_{out} = A T_{in}\).

  • 'project': Projection using a linear least-squares fit. Given a compartment atlas \(A \in \mathbb{R}^{(L \times V)}\) and a vertex-wise or voxel-wise input time series \(T_{in} \in \mathbb{R}^{(V \times T)}\), returns

    \[ \begin{aligned} T_{out} &= \underset{X \in \mathbb{R}^{(L \times T)}}{\arg\min} \| A^\intercal X - T_{in} \|_F \\ &= \left(A A^\intercal\right)^{-1} A T_{in} \end{aligned} \]
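The two forward modes can be compared directly in NumPy. The sketch below uses random data and omits normalisation; it checks that the closed form \(\left(A A^\intercal\right)^{-1} A T_{in}\) agrees with a generic least-squares solver applied to \(A^\intercal X \approx T_{in}\). Variable names are chosen for the example only.

```python
import numpy as np

rng = np.random.default_rng(0)
L, V, T = 3, 12, 30                   # hypothetical sizes
A = rng.random((L, V))                # compartment atlas
T_in = rng.standard_normal((V, T))    # location-wise input time series

# 'map': simple linear map.
T_map = A @ T_in

# 'project': closed-form least-squares solution via the normal equations.
T_proj = np.linalg.solve(A @ A.T, A @ T_in)

# Cross-check with a generic least-squares solver for A^T X ~ T_in.
T_lstsq, *_ = np.linalg.lstsq(A.T, T_in, rcond=None)
print(np.allclose(T_proj, T_lstsq))
```

The closed form requires \(A A^\intercal\) to be invertible, which holds when the rows of the atlas are linearly independent.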
concatenate : bool, optional (default=True)

Whether to concatenate the output time series across compartments.
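A NumPy sketch of per-compartment extraction followed by concatenation. The compartment names and sizes are hypothetical, and the sketch assumes that the non-concatenated result is a mapping from compartment name to time series and that concatenation happens along the label axis; neither detail is confirmed by the text above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical compartments; limits are (lower limit, size) pairs.
limits = {"cortex": (0, 6), "subcortex": (6, 4)}
n_labels = {"cortex": 3, "subcortex": 2}
atlases = {
    name: rng.random((n_labels[name], size))
    for name, (_, size) in limits.items()
}

x = rng.standard_normal((2, 10, 15))   # (N, V, T) with V = 6 + 4

# Apply each compartment's atlas to its own slice of the location axis.
per_compartment = {
    name: atlases[name] @ x[..., start:start + size, :]
    for name, (start, size) in limits.items()
}

# Assumed behaviour of concatenate=True: stack along the label axis.
concatenated = np.concatenate(list(per_compartment.values()), axis=-2)
print(concatenated.shape)
```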

Attributes:
normalisation

Methods

__call__(input[, normalisation, ...])

Call self as a function.

from_atlas(atlas[, normalisation, ...])

Initialise the atlas module from an instance of BaseAtlas.

set_default_limits