LinearCombinationSelector

class hypercoil.nn.confound.LinearCombinationSelector(model_dim: int, num_columns: int, *, key: jax.random.PRNGKey)

Model selection as a linear combination.

Learns linear combinations of candidate vectors to produce a model. This is a thin wrapper around LinearRFNN that omits the convolutional layers used to learn response functions.
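This is our reading of the shapes stated on this page (input \((*, I, T)\), weight \((I, O)\), output \((*, O, T)\)), not a formula quoted from the implementation. Under that reading, each output vector is a weighted sum of the candidate vectors. Leading dimensions \(*\) are carried through unchanged.

\[
Y = W^\top X, \qquad y_{o,t} = \sum_{i=1}^{I} W_{i,o}\, x_{i,t}, \quad o = 1, \dots, O,\; t = 1, \dots, T
\]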

Dimension:
Input : \((*, I, T)\)

\(*\) denotes any number of preceding dimensions, \(I\) denotes the number of candidate model vectors (num_columns), and \(T\) denotes the number of time points or observations per vector.

Output : \((*, O, T)\)

\(O\) denotes the final model dimension (model_dim), that is, the number of output model vectors.
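The following is a minimal JAX sketch of the shape contract above. It assumes the weight has shape \((I, O)\), as listed under Attributes, and that the combination contracts over the candidate axis \(I\). `combine` is a hypothetical helper for illustration. It is not part of hypercoil.

```python
import jax
import jax.numpy as jnp


def combine(weight: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: form O model vectors from I candidates.

    weight: (I, O); x: (*, I, T) -> returns (*, O, T).
    """
    # Contract over the candidate axis i; any leading axes (...) pass through.
    return jnp.einsum("io,...it->...ot", weight, x)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
I, O, T = 6, 2, 100
weight = jax.random.normal(key_w, (I, O))
x = jax.random.normal(key_x, (3, 4, I, T))  # two leading batch dimensions
y = combine(weight, x)
print(y.shape)  # (3, 4, 2, 100)
```

Because the contraction uses an ellipsis, the same call works for unbatched input of shape \((I, T)\) and for any number of leading dimensions.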

Parameters:
model_dim : int

Dimension of the model to be learned.

num_columns : int

Number of input vectors to be combined linearly to form the model.

key : jax.random.PRNGKey

Pseudorandom number generator key used to initialize the weight.

Attributes:
weight : Tensor

Tensor of shape \((I, O)\), that is, (num_columns, model_dim).
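The hypothetical stand-in class below shows how the constructor signature relates to the weight attribute. The class name `LinearCombinationSketch` is an illustrative assumption. The uniform initialization in \([-1/\sqrt{I}, 1/\sqrt{I})\) is also an assumption. Neither is hypercoil's documented behaviour, so consult the library source for the actual initialization scheme.

```python
import jax
import jax.numpy as jnp


class LinearCombinationSketch:
    """Hypothetical stand-in, NOT hypercoil's implementation."""

    def __init__(self, model_dim: int, num_columns: int, *, key):
        self.model_dim = model_dim
        self.num_columns = num_columns
        # Assumed initialization: uniform in [-1/sqrt(I), 1/sqrt(I)).
        bound = 1.0 / jnp.sqrt(num_columns)
        # Weight shape (I, O) = (num_columns, model_dim), as documented above.
        self.weight = jax.random.uniform(
            key, (num_columns, model_dim), minval=-bound, maxval=bound
        )

    def __call__(self, x):
        # Map (*, I, T) to (*, O, T) by contracting over the candidate axis.
        return jnp.einsum("io,...it->...ot", self.weight, x)


model = LinearCombinationSketch(model_dim=3, num_columns=8, key=jax.random.PRNGKey(42))
print(model.weight.shape)  # (8, 3)
out = model(jnp.ones((8, 50)))
print(out.shape)  # (3, 50)
```

With an all-ones input, each output row is constant and equals the column sum of the weight. This is a quick sanity check that the contraction runs over \(I\).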