equilibrium: Equilibrium loss

equilibrium

hypercoil.loss.equilibrium(X: Tensor, *, level_axis: int | Sequence[int] = -1, instance_axes: int | Sequence[int] = (-1, -2), key: PRNGKey | None = None) -> Tensor

Compute the parcel equilibrium.

The equilibrium scores the deviation of the total weight assigned to each parcel or level from the mean weight assigned to each parcel or level. It can be used to encourage the model to learn parcels that are balanced in size.

This function operates on probability tensors. For a version that operates on logits, see equilibrium_logit().

Parameters:
X: Tensor

A tensor of probabilities (or masses of another kind).

level_axis: int or sequence of ints, optional

The axis or axes over which to compute the equilibrium. Within each data instance or weight channel, all elements along the specified axis or axes should correspond to a single level or parcel. The default is -1.

prob_axis: int or sequence of ints, optional

The axis or axes over which to compute the probabilities (logit version only). The default is -2. In general the union of level_axis and prob_axis should be the same as instance_axes.

instance_axes: int or sequence of ints, optional

The axis or axes corresponding to a single data instance or weight channel. This should be a superset of level_axis. The default is (-1, -2).

keepdims: bool, optional

As in jax.numpy.sum(). The default is True.

Returns:
Tensor

A tensor of equilibrium scores for each parcel or level.
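As a rough illustration of the description above, the following JAX sketch scores how far each parcel's mean weight lies from the instance-wide mean weight. The name `equilibrium_sketch` and the choice of an absolute deviation are assumptions made for this example; the library's own implementation may differ in detail.

```python
import jax.numpy as jnp

def equilibrium_sketch(X, level_axis=-1, instance_axes=(-1, -2)):
    # Hypothetical re-implementation, for illustration only.
    # Mean weight per parcel: average along the axis that spans one parcel.
    parcel = X.mean(axis=level_axis, keepdims=True)
    # Mean weight over the whole data instance.
    total = X.mean(axis=instance_axes, keepdims=True)
    # Deviation of each parcel from the instance-wide mean.
    return jnp.abs(parcel - total)

# 2 parcels (rows) x 4 features (columns); every column sums to 1.
balanced = jnp.array([[1., 1., 0., 0.],
                      [0., 0., 1., 1.]])
skewed = jnp.array([[1., 1., 1., 0.],
                    [0., 0., 0., 1.]])

balanced_score = equilibrium_sketch(balanced)  # one score per parcel
skewed_score = equilibrium_sketch(skewed)      # larger when parcels differ in size
```

With equally sized parcels the score vanishes, while the skewed assignment yields a nonzero deviation for both parcels.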

hypercoil.loss.equilibrium_logit(X: Tensor, *, level_axis: int | Sequence[int] = -1, prob_axis: int | Sequence[int] = -2, instance_axes: int | Sequence[int] = (-1, -2), key: PRNGKey | None = None) -> Tensor

Project logits in the input matrix onto the probability simplex, and then compute the parcel equilibrium.

The equilibrium scores the deviation of the total weight assigned to each parcel or level from the mean weight assigned to each parcel or level. It can be used to encourage the model to learn parcels that are balanced in size.

This function operates on logits. For a version that operates on probabilities, see equilibrium().

Parameters:
X: Tensor

A tensor of logits. These are projected onto the probability simplex along prob_axis before the equilibrium is computed.

level_axis: int or sequence of ints, optional

The axis or axes over which to compute the equilibrium. Within each data instance or weight channel, all elements along the specified axis or axes should correspond to a single level or parcel. The default is -1.

prob_axis: int or sequence of ints, optional

The axis or axes over which to compute the probabilities (logit version only). The default is -2. In general the union of level_axis and prob_axis should be the same as instance_axes.

instance_axes: int or sequence of ints, optional

The axis or axes corresponding to a single data instance or weight channel. This should be a superset of level_axis. The default is (-1, -2).

keepdims: bool, optional

As in jax.numpy.sum(). The default is True.

Returns:
Tensor

A tensor of equilibrium scores for each parcel or level.
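The two-stage behaviour described for the logit version (project onto the simplex, then score) can be sketched as follows. This assumes the projection is a softmax along `prob_axis`; the helper name `equilibrium_logit_sketch` and the absolute deviation are illustrative assumptions, not the library's code.

```python
import jax
import jax.numpy as jnp

def equilibrium_logit_sketch(logits, level_axis=-1, prob_axis=-2,
                             instance_axes=(-1, -2)):
    # Hypothetical sketch: softmax over prob_axis turns each feature's
    # logits into a probability distribution over parcels.
    probs = jax.nn.softmax(logits, axis=prob_axis)
    parcel = probs.mean(axis=level_axis, keepdims=True)
    total = probs.mean(axis=instance_axes, keepdims=True)
    return jnp.abs(parcel - total)

# 2 parcels (rows) x 4 features (columns): three features favour parcel 0.
logits = jnp.array([[ 4.,  4.,  4., -4.],
                    [-4., -4., -4.,  4.]])

probs = jax.nn.softmax(logits, axis=-2)
column_sums = probs.sum(axis=-2)          # each feature's probabilities sum to 1
score = equilibrium_logit_sketch(logits)  # nonzero: parcel 0 is larger

# All-zero logits give equiprobable assignments and a zero score.
uniform_score = equilibrium_logit_sketch(jnp.zeros((2, 4)))
```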

EquilibriumLoss

class hypercoil.loss.EquilibriumLoss(nu: float = 1.0, name: str | None = None, *, level_axis: int | Tuple[int, ...] = -1, instance_axes: int | Tuple[int, ...] = (-2, -1), scalarisation: Callable | None = None, key: 'jax.random.PRNGKey' | None = None)

Mass equilibrium loss.

This loss operates on unmapped mass tensors. For a version that operates on logits, see EquilibriumLogitLoss.

The equilibrium scores the deviation of the total weight assigned to each parcel or level from the mean weight assigned to each parcel or level. It can be used to encourage the model to learn parcels that are balanced in size.

Loss functions to favour equal weight across one dimension of a tensor whose slices are masses.

Equilibrium

The equilibrium loss of a mass tensor \(A\) is defined as

\(\mathbf{1}^\intercal \left[\left(A \mathbf{1}\right) \circ \left(A \mathbf{1}\right) \right]\)

The equilibrium loss has applications in the context of parcellation tensors. A parcellation tensor is one whose rows correspond to parcels and whose columns correspond to features (e.g., voxels, time points, frequency bins, or network nodes). Element \(i, j\) in this tensor accordingly indexes the assignment of feature \(j\) to parcel \(i\), so that \(A \mathbf{1}\) collects the total weight of each parcel. Examples of parcellation tensors might include atlases that map voxels to regions or affiliation matrices that map graph vertices to communities. It is often desirable to constrain feature-parcel assignments to \([0, k]\) for some \(k\) and ensure that the sum over each feature’s assignment is always \(k\). (Otherwise, the unnormalised loss could be improved by simply shrinking all weights.) For this reason, we can either normalise the loss or situate the parcellation tensor in the probability simplex using a multi-logit (softmax) domain mapper.

The equilibrium loss attains a minimum when parcels are equal in their total weight. It has a trivial and uninteresting minimum where all parcel assignments are equiprobable for all features. Other minima, which might be of greater interest, occur where each feature is deterministically assigned to a single parcel. These minima can be favoured by using the equilibrium in conjunction with a penalty on the entropy.
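The formula above can be evaluated directly on small examples. The sketch below assumes, consistently with the formula, that \(A \mathbf{1}\) yields one total per parcel (each row of \(A\) collects one parcel's weights); the function name `equilibrium_formula` is hypothetical.

```python
import jax.numpy as jnp

def equilibrium_formula(A):
    # Literal evaluation of 1^T [(A 1) o (A 1)]:
    # sum of the squared per-parcel totals.
    ones = jnp.ones(A.shape[-1])
    parcel_totals = A @ ones                       # A 1
    return jnp.sum(parcel_totals * parcel_totals)  # 1^T [(A 1) o (A 1)]

# 2 parcels x 4 features, each feature's assignment sums to 1.
balanced = jnp.array([[1., 1., 0., 0.],
                      [0., 0., 1., 1.]])  # parcel totals [2, 2]
skewed = jnp.array([[1., 1., 1., 0.],
                    [0., 0., 0., 1.]])    # parcel totals [3, 1]

loss_balanced = equilibrium_formula(balanced)      # 2^2 + 2^2 = 8
loss_skewed = equilibrium_formula(skewed)          # 3^2 + 1^2 = 10

# Without a constraint on total mass, shrinking every weight lowers the loss.
loss_shrunk = equilibrium_formula(0.5 * balanced)  # 1^2 + 1^2 = 2
```

For a fixed total mass the balanced assignment attains the lower value, and the last line shows why the mass has to be constrained or the loss normalised.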

Parameters:
name: str

Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier. This is a scalar multiplier that is applied to the loss value before it is returned. This can be used to modulate the relative contributions of different loss functions to the overall loss value. It can also be used to implement a schedule for the loss function, by dynamically adjusting the multiplier over the course of training.

level_axis: int or sequence of ints, optional

The axis or axes over which to compute the equilibrium. Within each data instance or weight channel, all elements along the specified axis or axes should correspond to a single level or parcel. The default is -1.

prob_axis: int or sequence of ints, optional

The axis or axes over which to compute the probabilities (logit version only). The default is -2. In general the union of level_axis and prob_axis should be the same as instance_axes.

instance_axes: int or sequence of ints, optional

The axis or axes corresponding to a single data instance or weight channel. This should be a superset of level_axis. The default is (-2, -1).

keepdims: bool, optional

As in jax.numpy.sum(). The default is True.

scalarisation: Callable

The scalarisation function to be used to aggregate the values returned by the scoring function. This function should take a single argument, which is a tensor of arbitrary shape, and return a single scalar value. By default, the mean square scalarisation is used.

Methods

__call__(X, *[, key])

Call self as a function.
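How `nu` and `scalarisation` might combine with the scoring function can be sketched as below. This is an assumption-laden illustration: `make_equilibrium_loss` and `mean_square` are hypothetical names, and the composition `nu * scalarisation(score)` is inferred from the parameter descriptions rather than taken from the library source.

```python
import jax.numpy as jnp

def mean_square(values):
    # Stand-in for the mean square scalarisation named in the parameter docs.
    return jnp.mean(values ** 2)

def make_equilibrium_loss(nu=1.0, scalarisation=mean_square,
                          level_axis=-1, instance_axes=(-2, -1)):
    # Hypothetical factory mirroring the documented constructor arguments.
    def loss(X):
        parcel = X.mean(axis=level_axis, keepdims=True)
        total = X.mean(axis=instance_axes, keepdims=True)
        score = jnp.abs(parcel - total)     # per-parcel deviation
        return nu * scalarisation(score)    # aggregate, then scale by nu
    return loss

skewed = jnp.array([[1., 1., 1., 0.],
                    [0., 0., 0., 1.]])

loss_unit = make_equilibrium_loss(nu=1.0)(skewed)
loss_double = make_equilibrium_loss(nu=2.0)(skewed)  # twice loss_unit
```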

class hypercoil.loss.EquilibriumLogitLoss(nu: float = 1.0, name: str | None = None, *, level_axis: int | Tuple[int, ...] = -1, prob_axis: int | Tuple[int, ...] = -2, instance_axes: int | Tuple[int, ...] = (-2, -1), scalarisation: Callable | None = None, key: 'jax.random.PRNGKey' | None = None)

Mass equilibrium loss.

This loss operates on logits. For a version that operates on unmapped mass tensors, see EquilibriumLoss.

The equilibrium scores the deviation of the total weight assigned to each parcel or level from the mean weight assigned to each parcel or level. It can be used to encourage the model to learn parcels that are balanced in size.

Loss functions to favour equal weight across one dimension of a tensor whose slices are masses.

Equilibrium

The equilibrium loss of a mass tensor \(A\) is defined as

\(\mathbf{1}^\intercal \left[\left(A \mathbf{1}\right) \circ \left(A \mathbf{1}\right) \right]\)

The equilibrium loss has applications in the context of parcellation tensors. A parcellation tensor is one whose rows correspond to parcels and whose columns correspond to features (e.g., voxels, time points, frequency bins, or network nodes). Element \(i, j\) in this tensor accordingly indexes the assignment of feature \(j\) to parcel \(i\), so that \(A \mathbf{1}\) collects the total weight of each parcel. Examples of parcellation tensors might include atlases that map voxels to regions or affiliation matrices that map graph vertices to communities. It is often desirable to constrain feature-parcel assignments to \([0, k]\) for some \(k\) and ensure that the sum over each feature’s assignment is always \(k\). (Otherwise, the unnormalised loss could be improved by simply shrinking all weights.) For this reason, we can either normalise the loss or situate the parcellation tensor in the probability simplex using a multi-logit (softmax) domain mapper.

The equilibrium loss attains a minimum when parcels are equal in their total weight. It has a trivial and uninteresting minimum where all parcel assignments are equiprobable for all features. Other minima, which might be of greater interest, occur where each feature is deterministically assigned to a single parcel. These minima can be favoured by using the equilibrium in conjunction with a penalty on the entropy.
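The distinction between the trivial minimum and the deterministic minima can be checked numerically. In the sketch below, both assignments give equal parcel totals, and only the entropy of each feature's assignment separates them. The helper names are hypothetical, and rows are assumed to index parcels.

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy

def parcel_totals(P):
    # Total weight per parcel (sum over the feature axis).
    return P.sum(axis=-1)

def mean_assignment_entropy(P):
    # Shannon entropy of each feature's distribution over parcels,
    # averaged over features; xlogy(0, 0) evaluates to 0.
    return jnp.mean(-jnp.sum(xlogy(P, P), axis=-2))

# Trivial minimum: every feature is equiprobable across both parcels.
uniform = jnp.full((2, 4), 0.5)
# Deterministic minimum: each feature belongs to exactly one parcel.
one_hot = jnp.array([[1., 1., 0., 0.],
                     [0., 0., 1., 1.]])

totals_uniform = parcel_totals(uniform)   # equal parcel totals
totals_one_hot = parcel_totals(one_hot)   # equal parcel totals as well

entropy_uniform = mean_assignment_entropy(uniform)  # log(2) per feature
entropy_one_hot = mean_assignment_entropy(one_hot)  # zero
```

Adding an entropy penalty to the equilibrium therefore leaves the deterministic solution untouched while penalising the equiprobable one.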

Parameters:
name: str

Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier. This is a scalar multiplier that is applied to the loss value before it is returned. This can be used to modulate the relative contributions of different loss functions to the overall loss value. It can also be used to implement a schedule for the loss function, by dynamically adjusting the multiplier over the course of training.

level_axis: int or sequence of ints, optional

The axis or axes over which to compute the equilibrium. Within each data instance or weight channel, all elements along the specified axis or axes should correspond to a single level or parcel. The default is -1.

prob_axis: int or sequence of ints, optional

The axis or axes over which to compute the probabilities (logit version only). The default is -2. In general the union of level_axis and prob_axis should be the same as instance_axes.

instance_axes: int or sequence of ints, optional

The axis or axes corresponding to a single data instance or weight channel. This should be a superset of level_axis. The default is (-2, -1).

keepdims: bool, optional

As in jax.numpy.sum(). The default is True.

scalarisation: Callable

The scalarisation function to be used to aggregate the values returned by the scoring function. This function should take a single argument, which is a tensor of arbitrary shape, and return a single scalar value. By default, the mean scalarisation is used.

Methods

__call__(X, *[, key])

Call self as a function.
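The note on scheduling `nu` can be illustrated with a differentiable logit objective and a simple warm-up. Everything here is a sketch under stated assumptions: `equilibrium_logit_objective` and `linear_warmup` are hypothetical helpers, the softmax projection and mean scalarisation follow the descriptions above, and the schedule shape is arbitrary.

```python
import jax
import jax.numpy as jnp

def equilibrium_logit_objective(logits, nu):
    # Hypothetical objective: softmax over parcels (axis -2), per-parcel
    # deviation from the instance-wide mean, mean scalarisation, scaled by nu.
    probs = jax.nn.softmax(logits, axis=-2)
    parcel = probs.mean(axis=-1, keepdims=True)
    total = probs.mean(axis=(-2, -1), keepdims=True)
    return nu * jnp.mean(jnp.abs(parcel - total))

def linear_warmup(step, warmup_steps=100, nu_max=1.0):
    # Illustrative schedule: ramp nu linearly from 0 to nu_max.
    return nu_max * min(step / warmup_steps, 1.0)

logits = jnp.array([[ 2.,  2.,  2., -2.],
                    [-2., -2., -2.,  2.]])

# Multiplier and loss value at a few training steps.
nu_values = [linear_warmup(s) for s in (0, 50, 100, 200)]
losses = [float(equilibrium_logit_objective(logits, nu)) for nu in nu_values]

# Gradient with respect to the logits at full strength.
grads = jax.grad(equilibrium_logit_objective)(logits, linear_warmup(100))
```

Because `nu` multiplies the scalarised score, the gradient with respect to the logits scales by the same factor as the loss value.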