TangentProject

hypercoil.nn.semidefinite.TangentProject(out_channels: int, matrix_size: int, psi: float = 0.0, scale: float = 1.0, dest: Literal['tangent', 'cone'] = 'tangent', *, key: jax.random.PRNGKey)

Tangent/cone projection with a learnable or fixed point of tangency.

At initialisation, a data sample is required to set the point of tangency. The tangency point is initialised as a mean of that sample: either the standard Euclidean mean or a measure of central tendency derived specifically for positive semidefinite matrices. Data passed through the module are projected from the positive semidefinite cone into the subspace tangent to the cone at the reference point, which is the module weight. Given a tangency point \(\Omega\), each input \(\Theta\) is projected as:

\(\vec{\Theta} = \log \left( \Omega^{-1/2} \Theta \Omega^{-1/2} \right)\)

Alternatively, the module destination can be set to the semidefinite cone, in which case symmetric matrices are projected into the cone using the same reference point:

\(\Theta = \Omega^{1/2} \exp \left( \vec{\Theta} \right) \Omega^{1/2}\)
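
The two projections above are inverses of each other at a fixed \(\Omega\). The following is a minimal NumPy sketch of both formulas, not the module's own implementation: the helper names `sym_apply`, `tangent_project` and `cone_project` are hypothetical, and the matrix functions are computed through an eigendecomposition, which assumes symmetric inputs and a positive definite \(\Omega\).

```python
import numpy as np


def sym_apply(matrix, fn):
    """Apply a scalar function to the eigenvalues of a symmetric matrix."""
    matrix = 0.5 * (matrix + matrix.T)  # remove floating-point asymmetry
    eigvals, eigvecs = np.linalg.eigh(matrix)
    return (eigvecs * fn(eigvals)) @ eigvecs.T


def tangent_project(theta, omega):
    """log(omega^{-1/2} theta omega^{-1/2}): cone -> tangent space at omega."""
    omega_inv_sqrt = sym_apply(omega, lambda w: w ** -0.5)
    return sym_apply(omega_inv_sqrt @ theta @ omega_inv_sqrt, np.log)


def cone_project(theta_vec, omega):
    """omega^{1/2} exp(theta_vec) omega^{1/2}: tangent space -> cone."""
    omega_sqrt = sym_apply(omega, np.sqrt)
    return omega_sqrt @ sym_apply(theta_vec, np.exp) @ omega_sqrt


# Hypothetical positive definite test matrices (illustration only).
rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4))
b = rng.normal(size=(4, 4))
theta = a @ a.T + 4 * np.eye(4)   # input matrix
omega = b @ b.T + 4 * np.eye(4)   # point of tangency

theta_vec = tangent_project(theta, omega)    # symmetric, not necessarily PSD
theta_back = cone_project(theta_vec, omega)  # recovers theta up to round-off
at_reference = tangent_project(omega, omega) # the reference point maps to zero
```

In this sketch the tangency point itself maps to the zero matrix, since \(\Omega^{-1/2} \Omega \Omega^{-1/2} = I\) and \(\log I = 0\).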

After initialisation, the tangency point can be learned to optimise any differentiable loss.

Dimension:
Input : \((*, N, N)\)

* denotes any number of preceding dimensions and N denotes the size of each square symmetric matrix.

Output : \((*, C, N, N)\)

C denotes the number of output channels (points of tangency).
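
The shape convention above can be illustrated with a batched NumPy sketch. It assumes, for illustration only, that the reference points are stored as an array of shape \((C, N, N)\) and that a new channel axis is inserted into the input before broadcasting; `sym_apply` and `random_spd` are hypothetical helpers, and the module's internal layout may differ.

```python
import numpy as np


def sym_apply(matrix, fn):
    """Apply a scalar function to the eigenvalues of batched symmetric matrices."""
    eigvals, eigvecs = np.linalg.eigh(matrix)
    scaled = eigvecs * fn(eigvals)[..., None, :]  # scale each eigenvector column
    return scaled @ np.swapaxes(eigvecs, -1, -2)


rng = np.random.default_rng(0)
batch, C, N = 5, 3, 4


def random_spd(*leading):
    """Random symmetric positive definite matrices with the given leading shape."""
    a = rng.normal(size=(*leading, N, N))
    return a @ np.swapaxes(a, -1, -2) + N * np.eye(N)


theta = random_spd(batch)  # input, shape (batch, N, N)
omega = random_spd(C)      # assumed reference points, shape (C, N, N)

omega_inv_sqrt = sym_apply(omega, lambda w: w ** -0.5)  # (C, N, N)

# Insert a channel axis so every input is paired with every reference point.
whitened = omega_inv_sqrt @ theta[..., None, :, :] @ omega_inv_sqrt
whitened = 0.5 * (whitened + np.swapaxes(whitened, -1, -2))  # re-symmetrise

theta_vec = sym_apply(whitened, np.log)  # shape (batch, C, N, N)
```

Each of the `batch` inputs yields `C` tangent-space matrices, one per point of tangency, which matches the \((*, C, N, N)\) output shape.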

Parameters:
init_data : Tensor

Data sample whose central tendency initialises the reference point of tangency.

mean_specs : list(_SemidefiniteMean object)

Objects encoding a measure of central tendency in the positive semidefinite cone. Used to initialise the reference points of tangency. Selected from means on the semidefinite cone.

psi : float in [0, 1]

Conditioning factor to promote positive definiteness. If this is in (0, 1], the original input will be replaced with a convex combination of the input and an identity matrix.

\(\hat{X} = (1 - \psi) X + \psi I\)

A suitable value can be used to ensure that all eigenvalues are positive and therefore guarantee that the matrix is in the domain of projection operations.
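
The effect of the conditioning factor can be checked numerically. Under the convex combination above, each eigenvalue \(\lambda\) of \(X\) becomes \((1 - \psi) \lambda + \psi\), so a positive semidefinite input has all eigenvalues of \(\hat{X}\) at least \(\psi\). A minimal NumPy sketch, where the helper name `condition` is hypothetical and not part of the library:

```python
import numpy as np


def condition(x, psi):
    """Convex combination of x and the identity: (1 - psi) * x + psi * I."""
    n = x.shape[-1]
    return (1.0 - psi) * x + psi * np.eye(n)


# A rank-1 (singular) positive semidefinite matrix: eigenvalues 0, 0, 14.
v = np.array([[1.0], [2.0], [3.0]])
x = v @ v.T

x_hat = condition(x, psi=0.1)

eigs_before = np.linalg.eigvalsh(x)     # smallest eigenvalue is numerically zero
eigs_after = np.linalg.eigvalsh(x_hat)  # approximately [0.1, 0.1, 12.7]
```

Because the matrix logarithm is undefined at a zero eigenvalue, the singular `x` lies outside the domain of the tangent projection, whereas `x_hat` is strictly positive definite.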

scale : float

Scaling factor for the initialisation of the reference point of tangency.

dest : 'tangent' or 'cone'

Target space/manifold of the projection operation.

Attributes:
weight : Tensor \((*, C, N, N)\)

Reference point of tangency for the projection between the semidefinite cone and a tangent subspace.