AtlasLinear: Linear atlas layer
- class hypercoil.nn.atlas.AtlasLinear(n_locations: Dict[str, int], n_labels: Dict[str, int], limits: Dict[str, Tuple[int, int]] = None, decoder: Dict[str, Tensor] | None = None, normalisation: Literal['mean', 'absmean', 'zscore', 'psc'] | None = 'mean', forward_mode: Literal['map', 'project'] = 'map', concatenate: bool = True, *, key: jax.random.PRNGKey)
Time series extraction from an atlas via a linear map.
- Dimension:
- Input : \((N, *, V, T)\)
N denotes batch size, * denotes any number of intervening dimensions, V denotes total number of voxels or spatial locations, T denotes number of time points or observations.
- Output : \((N, *, L, T)\)
L denotes number of labels in the provided atlas.
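The shape contract above can be pictured with a short NumPy sketch that applies a dense \((L, V)\) weight matrix to a batched input. It assumes 'map' mode, no normalisation, and a single dense atlas without compartments; the sizes are arbitrary, and this is an illustration rather than the module's own code.

```python
import numpy as np

# Arbitrary sizes: N batch, one intervening dimension, V locations, T time points, L labels
N, EXTRA, V, T, L = 4, 2, 10, 50, 3

rng = np.random.default_rng(0)
atlas = rng.random((L, V))                   # (L, V) atlas weights
x = rng.standard_normal((N, EXTRA, V, T))    # (N, *, V, T) input

# matmul broadcasts the 2-D (L, V) atlas over the leading (N, *) dimensions
y = atlas @ x                                # (N, *, L, T) output
print(y.shape)  # (4, 2, 3, 50)
```

Because the atlas is a plain matrix on the left, any number of intervening dimensions is handled by broadcasting, which is why the signature only constrains the last two axes.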
Note
To initialise the atlas linear module from a pre-defined atlas, use the class method from_atlas() or the hypercoil.init.atlas.AtlasInitialiser class after defining the atlas as a hypercoil.init.atlas.BaseAtlas instance. If the module is initialised without an atlas, its weights are initialised from a Dirichlet distribution with concentration \(50\) for each label.
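A minimal sketch of a Dirichlet-initialised soft atlas follows. It assumes that each location's weights over the \(L\) labels are drawn from a symmetric Dirichlet distribution with concentration 50, so that every column of the \((L, V)\) atlas sums to 1; this reading of "concentration \(50\) for each label" is an assumption, and `dirichlet_atlas_init` is a hypothetical helper, not part of hypercoil.

```python
import numpy as np


def dirichlet_atlas_init(n_locations: int, n_labels: int,
                         concentration: float = 50.0, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: random soft atlas of shape (n_labels, n_locations).

    Assumes each location's weights over the labels follow a symmetric
    Dirichlet distribution with the given concentration.
    """
    rng = np.random.default_rng(seed)
    alpha = np.full(n_labels, concentration)
    # Generator.dirichlet returns shape (n_locations, n_labels); each row sums to 1
    weights = rng.dirichlet(alpha, size=n_locations)
    return weights.T  # transpose to the (L, V) layout used on this page


atlas = dirichlet_atlas_init(n_locations=100, n_labels=5)
print(atlas.shape)  # (5, 100)
```

With a concentration as large as 50, each sampled weight stays close to \(1/L\), so the initial atlas is nearly uniform and leaves the label structure to be learned.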
- Parameters:
- n_locations : Dict[str, int]
Number of locations (e.g., voxels or vertices) in each compartment.
- n_labels : Dict[str, int]
Number of labels in each compartment.
- limits : Dict[str, Tuple[int, int]], optional (default=None)
Limits of each compartment. The first element of the tuple denotes the lower limit and the second element denotes the size, i.e., the number of locations in the compartment (not the upper limit). If None, the limits are set to the default values, which are defined using the cumulative sum of the number of locations in each compartment.
- decoder : Optional[Dict[str, Tensor]], optional (default=None)
Decoder for labels in each compartment. The decoder is an integer-valued tensor that defines the map from row numbers to label numbers. If None, the decoder corresponds to the identity map, i.e., the row numbers are the same as the label numbers.
- normalisation : 'mean', 'absmean', 'zscore', 'psc', or None (default 'mean')
Strategy for normalising across voxels and generating a representative time series for each label.
  - None or 'sum': No normalisation, i.e., use the weighted sum over voxel time series.
  - 'mean': Compute the weighted mean over voxel time series.
  - 'absmean': Compute the weighted mean over voxel time series, treating any negative voxel weights as though they were positive.
  - 'zscore': Transform the sum of time series such that its temporal mean is 0 and its temporal standard deviation is 1.
  - 'psc': Transform the time series such that its value indicates the percent signal change from the mean.
- forward_mode : 'map' or 'project'
Strategy for extracting regional time series from parcels.
  - 'map': Simple linear map. Given a compartment atlas \(A \in \mathbb{R}^{(L \times V)}\) and a vertex-wise or voxel-wise input time series \(T_{in} \in \mathbb{R}^{(V \times T)}\), returns \(T_{out} = A T_{in}\).
  - 'project': Projection using a linear least-squares fit. Given a compartment atlas \(A \in \mathbb{R}^{(L \times V)}\) and a vertex-wise or voxel-wise input time series \(T_{in} \in \mathbb{R}^{(V \times T)}\), returns
\[ \begin{aligned} T_{out} &= \arg\min_{X \in \mathbb{R}^{(L \times T)}} \| A^\intercal X - T_{in} \|_F \\ &= \left(A A^\intercal\right)^{-1} A T_{in}, \end{aligned} \]
where the closed form assumes that \(A\) has full row rank, so that \(A A^\intercal\) is invertible.
- concatenate : bool, optional (default=True)
Whether to concatenate the output time series across compartments.
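The `limits`, `decoder`, and `concatenate` parameters can be illustrated with a NumPy sketch. The helpers `default_limits` and `extract` are hypothetical (they are not hypercoil's `set_default_limits` or `__call__`), the compartment names and label numbers are made up, and the sketch assumes that compartments are laid out and concatenated in dictionary order along the label axis.

```python
import numpy as np
from typing import Dict, Tuple


def default_limits(n_locations: Dict[str, int]) -> Dict[str, Tuple[int, int]]:
    """Hypothetical helper: (start, size) per compartment from a cumulative sum."""
    limits, start = {}, 0
    for name, size in n_locations.items():  # dict order defines the layout
        limits[name] = (start, size)
        start += size
    return limits


def extract(x: np.ndarray, atlases: Dict[str, np.ndarray],
            limits: Dict[str, Tuple[int, int]], concatenate: bool = True):
    """Apply each compartment atlas (L_c, V_c) to its slice of x (..., V, T)."""
    out = {}
    for name, atlas in atlases.items():
        start, size = limits[name]
        # The second element is a size, so the slice ends at start + size
        out[name] = atlas @ x[..., start:start + size, :]
    if concatenate:
        return np.concatenate(list(out.values()), axis=-2)  # stack label rows
    return out


n_locations = {"cortex_L": 4, "cortex_R": 4, "subcortical": 2}
n_labels = {"cortex_L": 2, "cortex_R": 2, "subcortical": 3}
limits = default_limits(n_locations)
print(limits)  # {'cortex_L': (0, 4), 'cortex_R': (4, 4), 'subcortical': (8, 2)}

rng = np.random.default_rng(0)
atlases = {k: rng.random((n_labels[k], n_locations[k])) for k in n_locations}
x = rng.standard_normal((10, 20))  # (V, T) with V = 4 + 4 + 2
y = extract(x, atlases, limits)
print(y.shape)  # (7, 20)

# Decoder: row i of a compartment's output belongs to label decoder[i].
# The label numbers 16, 17, 26 are hypothetical.
decoder = {"subcortical": np.array([16, 17, 26])}
per_compartment = extract(x, atlases, limits, concatenate=False)
by_label = {int(label): ts for label, ts in
            zip(decoder["subcortical"], per_compartment["subcortical"])}
print(sorted(by_label))  # [16, 17, 26]
```

Treating the second element of each limit as a size rather than an end index means a compartment can be moved by changing only its start.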
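The `normalisation` strategies can be sketched in NumPy for a single atlas \(A\) of shape \((L, V)\) and input of shape \((\ldots, V, T)\). The function `normalise` is hypothetical. For 'absmean' it assumes one reading of "treating negative weights as positive" (absolute weights in both the sum and the normaliser), and for 'psc' it assumes \(100 (s - \bar{s}) / \bar{s}\) with \(\bar{s}\) the temporal mean; neither is confirmed as hypercoil's exact formula.

```python
import numpy as np
from typing import Optional


def normalise(atlas: np.ndarray, x: np.ndarray,
              mode: Optional[str] = "mean") -> np.ndarray:
    """Hypothetical sketch of the strategies for atlas (L, V) and x (..., V, T)."""
    if mode == "absmean":
        # Assumed reading: use |weights| both in the sum and in the normaliser
        w = np.abs(atlas)
        return (w @ x) / w.sum(axis=-1, keepdims=True)
    s = atlas @ x  # weighted sum over locations, shape (..., L, T)
    if mode is None or mode == "sum":
        return s
    if mode == "mean":
        # Divide each label's sum by that label's total weight
        return s / atlas.sum(axis=-1, keepdims=True)
    if mode == "zscore":
        mu = s.mean(axis=-1, keepdims=True)
        sd = s.std(axis=-1, keepdims=True)
        return (s - mu) / sd
    if mode == "psc":
        # Undefined when the temporal mean is zero
        mu = s.mean(axis=-1, keepdims=True)
        return 100.0 * (s - mu) / mu
    raise ValueError(f"Unknown normalisation: {mode}")


A = np.array([[1.0, 1.0, 0.0],
              [0.0, 0.0, 2.0]])      # (L=2, V=3)
X = np.array([[1.0, 3.0],
              [3.0, 5.0],
              [2.0, 4.0]])           # (V=3, T=2)
print(normalise(A, X, "sum").tolist())     # [[4.0, 8.0], [4.0, 8.0]]
print(normalise(A, X, "mean").tolist())    # [[2.0, 4.0], [2.0, 4.0]]
print(normalise(A, X, "zscore").tolist())  # [[-1.0, 1.0], [-1.0, 1.0]]
```

Since 'zscore' and 'psc' rescale each label's series by its own temporal statistics, they give the same result whether applied to the weighted sum or the weighted mean.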
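The two `forward_mode` options can be compared numerically. The sketch below uses plain NumPy with a random atlas (which has full row rank with probability 1). It computes 'map' as \(A T_{in}\) and 'project' via the normal equations \((A A^\intercal)^{-1} A T_{in}\), then checks the result against `numpy.linalg.lstsq` applied to \(A^\intercal X \approx T_{in}\). It illustrates the formulas, not hypercoil's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
L, V, T = 3, 20, 50
atlas = rng.random((L, V))          # A, shape (L, V)
x = rng.standard_normal((V, T))     # T_in, shape (V, T)

# 'map': plain linear map, T_out = A T_in
y_map = atlas @ x                   # (L, T)

# 'project': normal-equations solution (A A^T)^{-1} A T_in,
# using solve rather than forming the explicit inverse
y_proj = np.linalg.solve(atlas @ atlas.T, atlas @ x)

# The same least-squares fit of A^T X to T_in from a dedicated solver
y_lstsq, *_ = np.linalg.lstsq(atlas.T, x, rcond=None)
print(np.allclose(y_proj, y_lstsq))  # True

# If T_in is exactly A^T Y, projection recovers Y
y_true = rng.standard_normal((L, T))
recovered = np.linalg.solve(atlas @ atlas.T, atlas @ (atlas.T @ y_true))
print(np.allclose(recovered, y_true))  # True
```

'map' mixes in overlap between parcels, whereas 'project' undoes it: when the input is generated as a weighted combination of parcel signals, projection returns those signals exactly.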
- Attributes:
- normalisation
Methods
- __call__(input[, normalisation, ...]): Call self as a function.
- from_atlas(atlas[, normalisation, ...]): Initialise the atlas module from an instance of a BaseAtlas.
- set_default_limits