AtlasLinear: Linear atlas layer

class hypercoil.nn.atlas.AtlasLinear(n_locations: Dict[str, int], n_labels: Dict[str, int], limits: Dict[str, Tuple[int, int]] = None, decoder: Dict[str, Tensor] | None = None, normalisation: Literal['mean', 'absmean', 'zscore', 'psc'] | None = 'mean', forward_mode: Literal['map', 'project'] = 'map', concatenate: bool = True, *, key: jax.random.PRNGKey)

Time series extraction from an atlas via a linear map.

Dimension:
Input : (N, *, V, T)

N denotes batch size, * denotes any number of intervening dimensions, V denotes the total number of voxels or spatial locations, and T denotes the number of time points or observations.

Output : (N, *, L, T)

L denotes the number of labels in the provided atlas.
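As a minimal sketch of this shape contract (in NumPy for brevity; hypercoil itself is built on JAX), the `'map'` mode for a single compartment can be pictured as one matrix product in which an (L, V) weight matrix is broadcast over the leading (N, *) dimensions. The weight and sizes here are arbitrary placeholders, not hypercoil internals.

```python
import numpy as np

rng = np.random.default_rng(0)

N, K, V, T, L = 2, 3, 100, 50, 7       # K stands in for one intervening "*" dimension
weight = rng.random((L, V))            # placeholder compartment atlas, shape (L, V)
x = rng.standard_normal((N, K, V, T))  # input, shape (N, *, V, T)

# matmul treats the last two axes as matrices and broadcasts the rest,
# so an (L, V) weight maps (N, *, V, T) to (N, *, L, T).
y = np.matmul(weight, x)
print(y.shape)  # (2, 3, 7, 50)
```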

Note

To initialise the atlas linear module from a pre-defined atlas, use the class method from_atlas() or the hypercoil.init.atlas.AtlasInitialiser class after defining the atlas as a hypercoil.init.atlas.BaseAtlas instance.

If the module is initialised without an atlas, its weights are initialised from a Dirichlet distribution with concentration 50 for each label.
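The sketch below illustrates one plausible reading of this initialisation. It assumes the concentration vector has one entry (50) per label and that each location's weights over the labels are drawn independently, so every column of the (L, V) weight matrix sums to 1. The helper `dirichlet_atlas` is hypothetical, and NumPy stands in for hypercoil's JAX code.

```python
import numpy as np

def dirichlet_atlas(n_labels: int, n_locations: int,
                    concentration: float = 50.0, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: sample an (L, V) soft atlas whose columns are Dirichlet draws."""
    rng = np.random.default_rng(seed)
    alpha = np.full(n_labels, concentration)       # concentration 50 for each label
    # rng.dirichlet returns shape (n_locations, n_labels); each row sums to 1.
    samples = rng.dirichlet(alpha, size=n_locations)
    return samples.T                               # (L, V): each location's weights over labels sum to 1

A = dirichlet_atlas(n_labels=4, n_locations=10)
print(A.shape)                           # (4, 10)
print(np.allclose(A.sum(axis=0), 1.0))   # True
```

With a concentration this large, draws cluster tightly around 1/L, so the initial atlas is close to a uniform soft assignment of every location to every label.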

Parameters:
n_locations : Dict[str, int]

Number of locations (e.g., voxels or vertices) in each compartment.

n_labels : Dict[str, int]

Number of labels in each compartment.

limits : Dict[str, Tuple[int, int]], optional (default=None)

Limits of each compartment. The first element of the tuple denotes the lower limit and the second element denotes the size, i.e., the number of locations in the compartment (not the upper limit). If None, the limits are set to default values computed from the cumulative sum of the number of locations in each compartment.
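A minimal sketch of the default limits, assuming compartments are laid out contiguously along the location axis in the order of the `n_locations` dictionary. The helper `default_limits` and the compartment names are hypothetical.

```python
from itertools import accumulate
from typing import Dict, Tuple

def default_limits(n_locations: Dict[str, int]) -> Dict[str, Tuple[int, int]]:
    """Hypothetical helper: (lower limit, size) for each compartment, in dict order."""
    sizes = list(n_locations.values())
    # Lower limit of compartment i = total size of compartments 0..i-1.
    starts = [0] + list(accumulate(sizes))[:-1]
    return {name: (start, size) for name, start, size in zip(n_locations, starts, sizes)}

limits = default_limits({"cortex_L": 3, "cortex_R": 4, "subcortex": 2})
print(limits)  # {'cortex_L': (0, 3), 'cortex_R': (3, 4), 'subcortex': (7, 2)}

# A compartment's locations would then be x[..., start:start + size, :].
```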

decoder : Optional[Dict[str, Tensor]], optional (default=None)

Decoder for labels in each compartment. The decoder is an integer-valued tensor that defines the map from row numbers to label numbers. If None, the decoder corresponds to the identity map, i.e., the row numbers are the same as the label numbers.
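To illustrate the decoder as a data structure, the sketch below assumes a compartment whose three rows carry the hypothetical label numbers 1, 5 and 9, as might happen when an atlas's label values are not contiguous. The helper `rows_by_label` is illustrative only and is not part of hypercoil.

```python
import numpy as np

# Hypothetical compartment output: 3 rows (labels) x 5 time points.
ts = np.arange(15.0).reshape(3, 5)

identity_decoder = np.arange(ts.shape[0])  # default: row i is label i
decoder = np.array([1, 5, 9])              # row i is label decoder[i]

def rows_by_label(ts: np.ndarray, decoder: np.ndarray) -> dict:
    """Look up each row's time series by the label number the decoder assigns to it."""
    return {int(label): ts[row] for row, label in enumerate(decoder)}

by_label = rows_by_label(ts, decoder)
print(sorted(by_label))  # [1, 5, 9]
```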

normalisation : 'mean', 'absmean', 'zscore', 'psc', or None (default 'mean')

Strategy for normalising across voxels and generating a representative time series for each label.

  • None or sum: No normalisation, i.e., use the weighted sum over voxel time series.

  • mean: Compute the weighted mean over voxel time series.

  • absmean: Compute the weighted mean over voxel time series, treating any negative voxel weights as though they were positive.

  • zscore: Transform the weighted sum of voxel time series such that its temporal mean is 0 and its temporal standard deviation is 1.

  • psc: Transform the time series such that its value indicates the percent signal change from the mean.
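The strategies above can be sketched for a single compartment as follows. This is a NumPy approximation under stated assumptions, not hypercoil's implementation: `'mean'` divides the weighted sum by each label's total weight, `'absmean'` is assumed to use absolute weights only in that denominator, and `'psc'` is assumed to be `100 * (x - mean) / mean` along the time axis. The function `normalise` is hypothetical.

```python
import numpy as np

def normalise(weight: np.ndarray, x: np.ndarray, mode=None) -> np.ndarray:
    """Hypothetical sketch of the normalisation strategies for one compartment.

    weight: (L, V) atlas; x: (..., V, T) voxel time series.
    """
    ts = weight @ x                                   # weighted sum, (..., L, T)
    if mode in (None, "sum"):
        return ts
    if mode == "mean":
        return ts / weight.sum(axis=-1, keepdims=True)
    if mode == "absmean":
        # Assumption: absolute weights are used only in the denominator.
        return ts / np.abs(weight).sum(axis=-1, keepdims=True)
    if mode == "zscore":
        mu = ts.mean(axis=-1, keepdims=True)
        sd = ts.std(axis=-1, keepdims=True)
        return (ts - mu) / sd
    if mode == "psc":
        mu = ts.mean(axis=-1, keepdims=True)
        return 100 * (ts - mu) / mu
    raise ValueError(f"unknown normalisation: {mode!r}")

# Usage: a hard parcellation of 20 voxels into 4 labels (5 voxels each).
rng = np.random.default_rng(0)
labels = np.arange(20) % 4
A = np.eye(4)[labels].T                   # (4, 20) binary atlas
x = rng.standard_normal((20, 30)) + 10.0  # positive baseline so psc is well defined
mean_ts = normalise(A, x, "mean")         # plain average of each label's voxels
```

Because `'zscore'` and `'psc'` are invariant to positive rescaling, applying them to the weighted sum or to the weighted mean gives the same result in this sketch.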

forward_mode : 'map' or 'project' (default 'map')

Strategy for extracting regional time series from parcels.

  • 'map': Simple linear map. Given a compartment atlas $A \in \mathbb{R}^{L \times V}$ and a vertex-wise or voxel-wise input time series $T_{in} \in \mathbb{R}^{V \times T}$, returns

    $T_{out} = A T_{in}$.

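  • 'project': Projection using a linear least-squares fit. Given a compartment atlas $A \in \mathbb{R}^{L \times V}$ and a vertex-wise or voxel-wise input time series $T_{in} \in \mathbb{R}^{V \times T}$, returns

    $T_{out} = \arg\min_{X \in \mathbb{R}^{L \times T}} \|A^\top X - T_{in}\|_F = (A A^\top)^{-1} A T_{in}$
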
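A NumPy sketch of both forward modes for one compartment, taking $A \in \mathbb{R}^{L \times V}$ so that the least-squares problem is $\min_X \|A^\top X - T_{in}\|_F$. The function names are hypothetical; `np.linalg.lstsq` solves the problem directly, and the normal-equation form is included as a cross-check.

```python
import numpy as np

def map_forward(A: np.ndarray, t_in: np.ndarray) -> np.ndarray:
    """'map' mode: T_out = A T_in."""
    return A @ t_in

def project_forward(A: np.ndarray, t_in: np.ndarray) -> np.ndarray:
    """'project' mode: X minimising ||A^T X - T_in||_F, via a least-squares solver."""
    # A has shape (L, V), so A.T has shape (V, L): one equation per location,
    # one unknown row of X per label.
    X, *_ = np.linalg.lstsq(A.T, t_in, rcond=None)
    return X

def project_closed_form(A: np.ndarray, t_in: np.ndarray) -> np.ndarray:
    """Normal-equation solution (A A^T)^{-1} A T_in; requires A to have full row rank."""
    return np.linalg.solve(A @ A.T, A @ t_in)

rng = np.random.default_rng(0)
L, V, T = 4, 20, 30
A_soft = rng.random((L, V))           # generic soft atlas
t_in = rng.standard_normal((V, T))

labels = np.arange(V) % L
A_hard = np.eye(L)[labels].T          # each location belongs to exactly one label
```

For a hard parcellation, $A A^\top$ is diagonal with the label sizes on the diagonal, so `'project'` reduces to the unweighted mean over each label's locations. The two modes differ meaningfully only when parcels overlap or carry non-uniform weights.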
concatenate : bool, optional (default=True)

Whether to concatenate the output time series across compartments.
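How limits and concatenation fit together can be sketched as below, assuming each compartment's weight is applied (in `'map'` mode, without normalisation) to its own slice of the location axis and that outputs are stacked along the label axis in dictionary order. The function `extract` and the compartment names are hypothetical.

```python
import numpy as np

def extract(weights: dict, limits: dict, x: np.ndarray, concatenate: bool = True):
    """Hypothetical sketch: apply each (L_c, V_c) weight to its slice of x, shape (..., V, T)."""
    out = {}
    for name, A in weights.items():
        start, size = limits[name]
        out[name] = A @ x[..., start:start + size, :]  # (..., L_c, T)
    if concatenate:
        # Stack compartments along the label axis, in dict order.
        return np.concatenate(list(out.values()), axis=-2)
    return out

weights = {"left": np.ones((2, 3)), "right": np.ones((4, 5))}
limits = {"left": (0, 3), "right": (3, 5)}
x = np.arange(80.0).reshape(8, 10)       # 8 locations, 10 time points

y = extract(weights, limits, x)                            # shape (6, 10)
separate = extract(weights, limits, x, concatenate=False)  # dict of (2, 10) and (4, 10)
```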

Attributes:
normalisation

Methods

__call__(input[, normalisation, ...])

Call self as a function.

from_atlas(atlas[, normalisation, ...])

Initialise the atlas module from an instance of BaseAtlas.

set_default_limits