log_det_gram: Gram log-determinant loss

log_det_gram

hypercoil.loss.log_det_gram(X: Tensor, theta: Tensor | None = None, *, op: Callable | None = corr_kernel, psi: float | None = 0.0, xi: float | None = 0.0, key: jax.random.PRNGKey | None = None)

Gramian log-determinant score function.

This function computes the negative log-determinant of the Gram matrix of the input tensor, defined according to the kernel function op. The kernel function should always be positive semi-definite, and additional arguments are provided to ensure a non-singular (i.e., strictly positive definite) matrix.

Parameters:
X: Tensor

Input tensor.

theta: Tensor, optional (default: None)

Kernel parameter tensor. If None, then the kernel is assumed to be isotropic.

op: Callable, optional (default: corr_kernel)

Kernel function. By default, the Pearson correlation kernel is used.

psi: float, optional (default: 0.0)

Kernel regularisation parameter. If psi > 0, then the kernel matrix is regularised by adding psi to the diagonal. This can be used to ensure that the matrix is strictly positive definite.

xi: float, optional (default: 0.0)

Kernel regularisation parameter. If xi > 0, then the kernel matrix is regularised by stochastically adding samples from a uniform distribution with support \([\psi - \xi, \psi]\) to the diagonal. This can be used to ensure that the matrix does not have degenerate eigenvalues. If xi > 0, then psi must also be greater than xi and a key must be provided.

key: PRNGKey, optional (default: None)

Random number generator key. This is only required if xi > 0.

Returns:
Tensor

Gramian log-determinant score for each set of observations.
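
A minimal usage sketch, following the signature above. The data shape and parameter values are illustrative assumptions; the default Pearson correlation kernel is used as op:

    import jax
    from hypercoil.loss import log_det_gram

    key = jax.random.PRNGKey(0)
    xkey, rkey = jax.random.split(key)
    X = jax.random.normal(xkey, (10, 100))  # 10 observations x 100 samples

    # xi > 0 stochastically jitters the diagonal of the kernel matrix,
    # so it requires both psi > xi and a PRNG key.
    score = log_det_gram(X, psi=0.01, xi=0.005, key=rkey)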

GramLogDeterminantLoss

class hypercoil.loss.GramLogDeterminantLoss(nu: float = 1.0, name: Optional[str] = None, *, op: Callable = corr_kernel, theta: Optional[Tensor] = None, psi: float = 0.0, xi: float = 0.0, scalarisation: Optional[Callable] = None, key: Optional[jax.random.PRNGKey] = None)

Loss based on the log-determinant of the Gram matrix.

This loss computes the negative log-determinant of the Gram matrix of the input tensor, defined according to the kernel function op. The kernel function should always be positive semi-definite, and additional arguments are provided to ensure a non-singular (i.e., strictly positive definite) matrix.

Log-det-Gram

The log-det-Gram loss among a set of vectors \(X\) is defined as the negative log-determinant of the Gram matrix of those vectors.

\(-\log \det \mathbf{K}(X)\)

[Figure: determinant.svg]

Penalising the negative log-determinant of a Gram matrix can promote a degree of independence among the vectors being correlated.

One example of the log-det-Gram loss is the log-det-corr loss, which penalises the negative log-determinant of the correlation matrix of a set of vectors. This has a number of desirable properties and applications outlined below.

Correlation matrices, which occur frequently in time series analysis, have several properties that make them well-suited for loss functions based on the Gram determinant.

First, correlation matrices are positive semidefinite, and accordingly their determinants will always be nonnegative. For positive semidefinite matrices, the log-determinant is a concave function and accordingly has a global maximum that can be identified using convex optimisation methods.

Second, correlation matrices are normalised such that their determinant attains a maximum value of 1. This maximum corresponds to an identity correlation matrix, which in turn occurs when the vectors or time series input to the correlation are orthogonal. Thus, a strong determinant-based loss applied to a correlation matrix will seek an orthogonal basis of input vectors.
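
A small numerical check of this property, using plain jax.numpy and independent of the library:

    import jax.numpy as jnp

    # Orthogonal inputs yield an identity correlation matrix:
    # det = 1, so the negative log-determinant attains its minimum, 0.
    R_ortho = jnp.eye(3)
    print(-jnp.linalg.slogdet(R_ortho)[1])  # 0.0

    # Correlated inputs shrink the determinant below 1, so the
    # negative log-determinant (the loss) grows.
    R_corr = jnp.array([[1.0, 0.9], [0.9, 1.0]])
    print(-jnp.linalg.slogdet(R_corr)[1])  # -log(0.19) ≈ 1.66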

In the parcellation setting, a weaker log-det-corr loss can be leveraged to promote relative independence of parcels. Combined with a second-moment loss, a log-det-corr loss can be interpreted as inducing a clustering: the second moment loss favours internal similarity of clusters, while the log-det-corr loss favours separation of different clusters.

Warning

Determinant-based losses use JAX's determinant functionality, which itself might use the singular value decomposition in certain cases. Differentiation through the SVD involves terms whose denominators include the differences between pairs of singular values. Thus, if two singular values of the input matrix are close together, the gradient can become unstable (and is undefined if the singular values are identical). A simple matrix reconditioning procedure is available for all operations involving the determinant to reduce the occurrence of degenerate eigenvalues.
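
The reconditioning amounts to augmenting the diagonal of the kernel matrix. Below is a sketch of one plausible form, following the psi and xi semantics given in the parameter descriptions; it is an assumption for illustration, not the library's exact implementation:

    import jax
    import jax.numpy as jnp

    def recondition(K, psi=0.0, xi=0.0, key=None):
        # Sketch only: assumed semantics, not hypercoil's internal routine.
        d = K.shape[-1]
        if xi > 0:
            # Uniform samples on [psi - xi, psi] break ties between
            # eigenvalues while keeping the diagonal shift positive.
            jitter = jax.random.uniform(key, (d,), minval=psi - xi, maxval=psi)
            return K + jnp.diag(jitter)
        # Deterministic shift: ensures strict positive definiteness.
        return K + psi * jnp.eye(d)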

Parameters:
name: str

Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier. This is a scalar multiplier that is applied to the loss value before it is returned. This can be used to modulate the relative contributions of different loss functions to the overall loss value. It can also be used to implement a schedule for the loss function, by dynamically adjusting the multiplier over the course of training.

theta: Tensor, optional (default: None)

Kernel parameter tensor. If None, then the kernel is assumed to be isotropic.

op: Callable, optional (default: corr_kernel)

Kernel function. By default, the Pearson correlation kernel is used.

psi: float, optional (default: 0.0)

Kernel regularisation parameter. If psi > 0, then the kernel matrix is regularised by adding psi to the diagonal. This can be used to ensure that the matrix is strictly positive definite.

xi: float, optional (default: 0.0)

Kernel regularisation parameter. If xi > 0, then the kernel matrix is regularised by stochastically adding samples from a uniform distribution with support \([\psi - \xi, \psi]\) to the diagonal. This can be used to ensure that the matrix does not have degenerate eigenvalues. If xi > 0, then psi must also be greater than xi and a key must be provided.

key: PRNGKey, optional (default: None)

Random number generator key. This is only required if xi > 0.

scalarisation: Callable

The scalarisation function to be used to aggregate the values returned by the scoring function. This function should take a single argument, which is a tensor of arbitrary shape, and return a single scalar value. By default, the mean scalarisation is used.
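
A usage sketch for the loss class, assuming the loss object is applied directly to the input tensor as is conventional for hypercoil losses; shapes and parameter values are illustrative:

    import jax
    from hypercoil.loss import GramLogDeterminantLoss

    loss = GramLogDeterminantLoss(nu=0.1, psi=0.01, xi=0.005)
    key = jax.random.PRNGKey(0)
    xkey, lkey = jax.random.split(key)
    X = jax.random.normal(xkey, (10, 100))

    # nu scales the scalarised score; the key is needed because xi > 0.
    value = loss(X, key=lkey)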