second_moment: Second moments#

second_moment#

hypercoil.loss.second_moment(X: Tensor, weight: Tensor, *, standardise: bool = False, skip_normalise: bool = False, key: PRNGKey | None = None) → Tensor[source]#

Compute the second moment of a dataset.

Given an input matrix \(T\) and a weight matrix \(A\), the second moment is computed as

\(\left[ A \circ \left (T - \frac{AT}{A\mathbf{1}} \right )^2 \right] \frac{\mathbf{1}}{A \mathbf{1}}\)

The term \(\frac{AT}{A\mathbf{1}}\) can also be precomputed and passed as the mu argument to the second_moment_centred() function. If the mean is already known, it is more efficient to use that function. Otherwise, the second_moment() function will compute the mean internally.

Parameters:
X: Tensor

A tensor of observations.

weight: Tensor

A tensor of weights.

standardise: bool, optional

If True, z-score the input matrix before computing the second moment. The default is False.

skip_normalise: bool, optional

If True, do not include normalisation by the sum of the weights in the computation. Rather than the second moment proper, this corresponds to computing a weighted mean squared error about the mean, which in practice often works better than the true second moment. The default is False.
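
A minimal usage sketch (hypothetical shapes are assumed: X as a voxels × time matrix and weight as a parcels × voxels matrix; the function name and argument order follow the signature above):

```python
import jax
from hypercoil.loss import second_moment

key = jax.random.PRNGKey(0)
kX, kW = jax.random.split(key)
X = jax.random.normal(kX, (200, 50))              # assumed shape: 200 voxels x 50 time points
weight = jax.nn.softmax(                          # assumed soft assignment of voxels to 10 parcels
    jax.random.normal(kW, (10, 200)), axis=0
)
sm = second_moment(X, weight)                     # weighted second moment about the weighted mean
```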

hypercoil.loss.second_moment_centred(X: Tensor, weight: Tensor, mu: Tensor, *, standardise_data: bool = False, standardise_mu: bool = False, skip_normalise: bool = False, key: PRNGKey | None = None) → Tensor[source]#

Compute the second moment of a dataset about a specified mean.

Given an input matrix \(T\) and a weight matrix \(A\), the second moment is computed as

\(\left[ A \circ \left (T - \frac{AT}{A\mathbf{1}} \right )^2 \right] \frac{\mathbf{1}}{A \mathbf{1}}\)

The term \(\frac{AT}{A\mathbf{1}}\) can also be precomputed and passed as the mu argument to the second_moment_centred() function. If the mean is already known, it is more efficient to use that function. Otherwise, the second_moment() function will compute the mean internally.

Parameters:
X: Tensor

A tensor of observations.

weight: Tensor

A tensor of weights.

mu: Tensor

The precomputed mean about which the second moment is computed, e.g. the weighted mean \(\frac{AT}{A\mathbf{1}}\).

standardise_data: bool, optional

If True, z-score the input matrix before computing the second moment. The default is False.

standardise_mu: bool, optional

If True, z-score the mean matrix mu before computing the second moment. The default is False.

skip_normalise: bool, optional

If True, do not include normalisation by the sum of the weights in the computation. Rather than the second moment proper, this corresponds to computing a weighted mean squared error about the mean, which in practice often works better than the true second moment. The default is False.
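
A sketch of precomputing the weighted mean and passing it as mu, under the same hypothetical shape assumptions as above; the expression for mu transcribes \(\frac{AT}{A\mathbf{1}}\) from the formula:

```python
import jax
from hypercoil.loss import second_moment_centred

key = jax.random.PRNGKey(0)
kX, kW = jax.random.split(key)
X = jax.random.normal(kX, (200, 50))                               # assumed (voxels, time)
weight = jax.nn.softmax(jax.random.normal(kW, (10, 200)), axis=0)  # assumed (parcels, voxels)

# Precompute the weighted mean AT / A1 once, then reuse it across calls.
mu = (weight @ X) / weight.sum(-1, keepdims=True)
sm = second_moment_centred(X, weight, mu)
```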

SecondMomentLoss#

class hypercoil.loss.SecondMomentLoss(nu: float = 1.0, name: str | None = None, *, standardise: bool = False, skip_normalise: bool = False, scalarisation: Callable | None = None, key: 'jax.random.PRNGKey' | None = None)[source]#

Second moment loss.

Given an input matrix \(T\) and a weight matrix \(A\), the second moment is computed as

\(\left[ A \circ \left (T - \frac{AT}{A\mathbf{1}} \right )^2 \right] \frac{\mathbf{1}}{A \mathbf{1}}\)

The term \(\frac{AT}{A\mathbf{1}}\) can also be precomputed and passed as the mu argument to the second_moment_centred() function. If the mean is already known, it is more efficient to use that function. Otherwise, the second_moment() function will compute the mean internally.

Regularise the second moment, e.g. to favour a dimension reduction mapping that is internally homogeneous.

Second Moment

Second moment losses are based on a reduction of the second moment quantity

\(\left[ A \circ \left (T - \frac{AT}{A\mathbf{1}} \right )^2 \right] \frac{\mathbf{1}}{A \mathbf{1}}\)

where the difference and division operators are applied elementwise with broadcasting. The broadcasting operations involved in the core computation – estimating a weighted mean and then computing the weighted sum of squares about that mean – are illustrated in the cartoon below.

[Figure: hypercoil/_images/secondmomentloss.svg]

Illustration of the most memory-intensive stage of loss computation. The lavender tensor represents the weighted mean, the blue tensor the original observations, and the green tensor the weights (which might correspond to a dimension reduction mapping such as a parcellation).
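
To make the broadcasting concrete, here is a literal plain-JAX transcription of the documented quantity. The shapes are assumptions (A as parcels × voxels, T as voxels × time), and this sketch is not necessarily identical to the library's internal implementation:

```python
import jax.numpy as jnp

def second_moment_sketch(T, A, skip_normalise=False):
    wsum = A.sum(-1, keepdims=True)            # A 1, shape (parcels, 1)
    mu = (A @ T) / wsum                        # weighted means, shape (parcels, time)
    diff = T[None, :, :] - mu[:, None, :]      # broadcast difference, (parcels, voxels, time)
    weighted_sq = A[:, :, None] * diff ** 2    # A o (T - mu)^2: the large intermediate tensor
    sigma = weighted_sq.sum(1)                 # reduce over voxels, (parcels, time)
    if not skip_normalise:
        sigma = sigma / wsum                   # normalise by A 1
    return sigma

A = jnp.full((3, 6), 1 / 6.)                   # toy weights: 3 parcels over 6 voxels
T = jnp.arange(24.).reshape(6, 4)              # toy data: 6 voxels x 4 time points
print(second_moment_sketch(T, A).shape)        # (3, 4)
```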

Note

In practice, we’ve found that using the true second moment loss often results in large and uneven parcels. Accordingly, an unnormalised extension of the second moment (which omits the normalisation \(\frac{1}{A \mathbf{1}}\)) is also available. This unnormalised quantity is equivalent to the weighted mean squared error about each weighted mean, and it has worked better for most of our use cases.

Warning

This loss can have a very large memory footprint, because it requires computing an intermediate tensor with dimensions equal to the number of rows in the linear mapping, multiplied by the number of columns in the linear mapping, multiplied by the number of columns in the dataset.

When using this loss to learn a parcellation on voxelwise time series, the full computation will certainly be much too large to fit in GPU memory. Fortunately, because much of the computation is elementwise, it can be broken down along multiple axes without affecting the result. This tensor slicing is implemented automatically in the ReactiveTerminal class. Use extreme caution with ReactiveTerminals, as improper use can result in destruction of the computational graph.
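
As an illustration only: because each parcel’s mean and sum of squares depend only on its own row of the weight matrix, the loss can be evaluated over row-chunks and reassembled. The sketch below is a manual stand-in for that slicing and does not reproduce the ReactiveTerminal API; the chunk size and the assumed per-parcel leading axis of the output are hypothetical:

```python
import jax.numpy as jnp
from hypercoil.loss import second_moment

def chunked_second_moment(X, weight, chunk_size=64):
    # Evaluate the second moment over slices of parcel rows so that the
    # (chunk x voxels x time) intermediate stays small, then reassemble.
    out = []
    for start in range(0, weight.shape[0], chunk_size):
        out.append(second_moment(X, weight[start:start + chunk_size]))
    return jnp.concatenate(out, axis=0)
```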

Parameters:
name: str

Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier. This is a scalar multiplier that is applied to the loss value before it is returned. This can be used to modulate the relative contributions of different loss functions to the overall loss value. It can also be used to implement a schedule for the loss function, by dynamically adjusting the multiplier over the course of training.

standardise: bool, optional

If True, z-score the input matrix before computing the second moment. The default is False.

skip_normalise: bool, optional

If True, do not include normalisation by the sum of the weights in the computation. Rather than the second moment proper, this corresponds to computing a weighted mean squared error about the mean, which in practice often works better than the true second moment. The default is False.

scalarisation: Callable

The scalarisation function to be used to aggregate the values returned by the scoring function. This function should take a single argument, which is a tensor of arbitrary shape, and return a single scalar value. By default, the mean scalarisation is used.

Methods

__call__(X, weight, *[, key])

Call self as a function.
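
A hypothetical construction-and-call sketch; the constructor arguments follow the documented signature, while the data shapes and the choices of name and nu are assumptions:

```python
import jax
from hypercoil.loss import SecondMomentLoss

key = jax.random.PRNGKey(0)
kX, kW = jax.random.split(key)
X = jax.random.normal(kX, (200, 50))                               # assumed (voxels, time)
weight = jax.nn.softmax(jax.random.normal(kW, (10, 200)), axis=0)  # assumed soft parcellation

loss = SecondMomentLoss(nu=0.01, name='ParcelHomogeneity', skip_normalise=True)
value = loss(X, weight)   # scalar: nu times the scalarised second moment
```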

class hypercoil.loss.SecondMomentCentredLoss(nu: float = 1.0, name: str | None = None, *, standardise_data: bool = False, standardise_mu: bool = False, skip_normalise: bool = False, scalarisation: Callable | None = None, key: 'jax.random.PRNGKey' | None = None)[source]#

Second moment loss centred on a precomputed mean.

Given an input matrix \(T\) and a weight matrix \(A\), the second moment is computed as

\(\left[ A \circ \left (T - \frac{AT}{A\mathbf{1}} \right )^2 \right] \frac{\mathbf{1}}{A \mathbf{1}}\)

The term \(\frac{AT}{A\mathbf{1}}\) can also be precomputed and passed as the mu argument to the second_moment_centred() function. If the mean is already known, it is more efficient to use that function. Otherwise, the second_moment() function will compute the mean internally.

Regularise the second moment, e.g. to favour a dimension reduction mapping that is internally homogeneous.

Second Moment

Second moment losses are based on a reduction of the second moment quantity

\(\left[ A \circ \left (T - \frac{AT}{A\mathbf{1}} \right )^2 \right] \frac{\mathbf{1}}{A \mathbf{1}}\)

where the difference and division operators are applied elementwise with broadcasting. The broadcasting operations involved in the core computation – estimating a weighted mean and then computing the weighted sum of squares about that mean – are illustrated in the cartoon below.

[Figure: hypercoil/_images/secondmomentloss.svg]

Illustration of the most memory-intensive stage of loss computation. The lavender tensor represents the weighted mean, the blue tensor the original observations, and the green tensor the weights (which might correspond to a dimension reduction mapping such as a parcellation).

Note

In practice, we’ve found that using the true second moment loss often results in large and uneven parcels. Accordingly, an unnormalised extension of the second moment (which omits the normalisation \(\frac{1}{A \mathbf{1}}\)) is also available. This unnormalised quantity is equivalent to the weighted mean squared error about each weighted mean, and it has worked better for most of our use cases.

Warning

This loss can have a very large memory footprint, because it requires computing an intermediate tensor with dimensions equal to the number of rows in the linear mapping, multiplied by the number of columns in the linear mapping, multiplied by the number of columns in the dataset.

When using this loss to learn a parcellation on voxelwise time series, the full computation will certainly be much too large to fit in GPU memory. Fortunately, because much of the computation is elementwise, it can be broken down along multiple axes without affecting the result. This tensor slicing is implemented automatically in the ReactiveTerminal class. Use extreme caution with ReactiveTerminals, as improper use can result in destruction of the computational graph.

Parameters:
name: str

Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier. This is a scalar multiplier that is applied to the loss value before it is returned. This can be used to modulate the relative contributions of different loss functions to the overall loss value. It can also be used to implement a schedule for the loss function, by dynamically adjusting the multiplier over the course of training.

standardise_data: bool, optional

If True, z-score the input matrix before computing the second moment. The default is False.

standardise_mu: bool, optional

If True, z-score the mean matrix mu before computing the second moment. The default is False.

skip_normalise: bool, optional

If True, do not include normalisation by the sum of the weights in the computation. Rather than the second moment proper, this corresponds to computing a weighted mean squared error about the mean, which in practice often works better than the true second moment. The default is False.

scalarisation: Callable

The scalarisation function to be used to aggregate the values returned by the scoring function. This function should take a single argument, which is a tensor of arbitrary shape, and return a single scalar value. By default, the mean scalarisation is used.

Methods

__call__(X, weight, mu, *[, key])

Call self as a function.
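
A hypothetical sketch for the centred variant, passing a precomputed mean at call time; shapes and the choice of mean are assumptions:

```python
import jax
from hypercoil.loss import SecondMomentCentredLoss

key = jax.random.PRNGKey(0)
kX, kW = jax.random.split(key)
X = jax.random.normal(kX, (200, 50))                               # assumed (voxels, time)
weight = jax.nn.softmax(jax.random.normal(kW, (10, 200)), axis=0)  # assumed soft parcellation
mu = (weight @ X) / weight.sum(-1, keepdims=True)                  # or any fixed reference mean

loss = SecondMomentCentredLoss(nu=0.01)
value = loss(X, weight, mu)
```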