BatchTangentProject

hypercoil.nn.semidefinite.BatchTangentProject(out_channels: int, matrix_size: int, mean_specs: Sequence[_SemidefiniteMean], psi: float = 0.0, scale: float = 1.0, inertia: float = 0.0, dest: Literal['tangent', 'cone'] = 'tangent', *, key: jax.random.PRNGKey)

Tangent/cone projection with a new tangency point computed for each batch.

Warning

Initialise this only using the from_specs class method.

Data transported through the module are projected from the positive semidefinite cone into a proper subspace tangent to the cone at the reference point, which is the module weight. Given a tangency point \(\Omega\), each input \(\Theta\) is projected as:

\(\vec{\Theta} = \log \left( \Omega^{-1/2} \Theta \Omega^{-1/2} \right)\)
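The tangent projection can be sketched with eigendecompositions, since matrix powers and the matrix logarithm of a symmetric positive definite matrix act on its eigenvalues. This is a minimal NumPy sketch for illustration, not hypercoil's JAX implementation; the names `_sym_apply` and `tangent_project` are hypothetical.

```python
import numpy as np


def _sym_apply(mat, fn):
    """Apply a scalar function to the eigenvalues of symmetric matrices (batched)."""
    vals, vecs = np.linalg.eigh(mat)
    # V diag(fn(w)) V^T, broadcasting over any leading batch dimensions
    return (vecs * fn(vals)[..., None, :]) @ np.swapaxes(vecs, -1, -2)


def tangent_project(theta, omega):
    """Hypothetical helper: log(Omega^{-1/2} Theta Omega^{-1/2})."""
    omega_isqrt = _sym_apply(omega, lambda w: w ** -0.5)
    whitened = omega_isqrt @ theta @ omega_isqrt
    # Symmetrise to remove round-off asymmetry before the eigendecomposition
    whitened = 0.5 * (whitened + np.swapaxes(whitened, -1, -2))
    return _sym_apply(whitened, np.log)
```

Projecting the reference point onto its own tangent space gives the zero matrix, and with \(\Omega = I\) the projection reduces to the ordinary matrix logarithm.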

Alternatively, the module destination can be set to the semidefinite cone, in which case symmetric matrices are projected into the cone using the same reference point:

\(\Theta = \Omega^{1/2} \exp \vec{\Theta} \Omega^{1/2}\)
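The inverse map back into the cone follows the same eigenvalue pattern. Again this is a NumPy sketch under the same assumptions, with hypothetical names (`_sym_apply`, `cone_project`), rather than the library's own code.

```python
import numpy as np


def _sym_apply(mat, fn):
    """Apply a scalar function to the eigenvalues of symmetric matrices (batched)."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * fn(vals)[..., None, :]) @ np.swapaxes(vecs, -1, -2)


def cone_project(vec_theta, omega):
    """Hypothetical helper: Omega^{1/2} exp(vec_theta) Omega^{1/2}."""
    omega_sqrt = _sym_apply(omega, np.sqrt)
    return omega_sqrt @ _sym_apply(vec_theta, np.exp) @ omega_sqrt
```

Mapping the zero matrix of the tangent space returns the reference point \(\Omega\) itself, which is the expected behaviour of the inverse of the tangent projection.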

Here, the tangency point is computed as a convex combination of the previous tangency point and some measure of central tendency in the current data batch. The tangency point is not learnable. This module is almost definitely a bad idea, but it might somehow be helpful for regularisation, augmentation, or increasing the model’s robustness to different views on the input data.
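The measure of central tendency is left to the module's mean specifications. As one illustration only, the following NumPy sketch computes a log-Euclidean mean, a commonly used mean on the positive definite cone. It is not claimed to match any particular mean that hypercoil provides, and the function names are hypothetical.

```python
import numpy as np


def _sym_apply(mat, fn):
    """Apply a scalar function to the eigenvalues of symmetric matrices (batched)."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * fn(vals)[..., None, :]) @ np.swapaxes(vecs, -1, -2)


def log_euclidean_mean(batch):
    """Hypothetical helper: exp(mean_i log(X_i)) over the leading batch axis."""
    log_mean = _sym_apply(batch, np.log).mean(axis=0)
    return _sym_apply(log_mean, np.exp)
```

For commuting matrices such as diagonal ones, this reduces to the geometric mean of corresponding eigenvalues, so it generally differs from the arithmetic mean of the batch.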

The weight is updated only during projection into tangent space. Given an inertial parameter \(\eta\) and a measure of central tendency \(\bar{\Theta}\), the weight is updated as:

\(\Omega_t := \eta \Omega_{t-1} + (1 - \eta) \bar{\Theta}\)
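The update rule can be sketched as a single convex combination. This minimal NumPy sketch assumes one reference point and uses the arithmetic mean over the batch axis as \(\bar{\Theta}\); the module itself delegates that choice to its mean specifications, and `update_reference` is a hypothetical name.

```python
import numpy as np


def update_reference(omega_prev, batch, inertia):
    """Hypothetical helper: Omega_t = eta * Omega_{t-1} + (1 - eta) * mean(batch).

    Assumes the arithmetic mean over the leading batch axis as the
    measure of central tendency.
    """
    if not 0.0 <= inertia <= 1.0:
        raise ValueError("inertia must lie in [0, 1]")
    batch_mean = batch.mean(axis=0)
    return inertia * omega_prev + (1.0 - inertia) * batch_mean
```

Because the positive definite cone is convex, a convex combination of two positive definite matrices is again positive definite, so the updated tangency point remains a valid reference point.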

Dimension:
Input : \((*, N, N)\)

* denotes any number of preceding dimensions and N denotes the size of each square symmetric matrix.

Output : \((*, C, N, N)\)

C denotes the number of output channels (points of tangency).
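The shape rule above can be illustrated with broadcasting alone. This NumPy sketch assumes the weight holds one reference point per channel with shape (C, N, N) and inserts a channel axis into the input so that every input matrix meets every reference point; the arrays are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, channels, n = 4, 3, 5
theta = rng.normal(size=(batch, n, n))     # input: (*, N, N) with * = (batch,)
omega = rng.normal(size=(channels, n, n))  # assumed weight layout: (C, N, N)

# Insert a channel axis: (batch, 1, N, N) broadcasts against (C, N, N)
out = theta[..., None, :, :] @ omega
print(out.shape)  # (4, 3, 5, 5)
```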

Parameters:
mean_specs : list(_SemidefiniteMean object)

Objects encoding a measure of central tendency in the positive semidefinite cone. Used to initialise the reference points of tangency. Selected from means on the semidefinite cone.

psi : float in [0, 1]

Conditioning factor to promote positive definiteness. If this is in (0, 1], the original input will be replaced with a convex combination of the input and an identity matrix.

\(\hat{X} = (1 - \psi) X + \psi I\)

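Each eigenvalue \(\lambda\) of \(X\) becomes \((1 - \psi) \lambda + \psi\), so any \(\psi > 0\) makes a positive semidefinite input strictly positive definite, and a suitable value can also offset small negative eigenvalues caused by numerical error. This guarantees that the matrix is in the domain of the projection operations.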

inertia : float in [0, 1]

Parameter \(\eta\) describing the relative weighting of the historical tangency point and the current batch mean. Zero inertia uses only the current batch mean. High inertia weights the history heavily, which keeps the tangency point from jumping abruptly between batches.

dest : 'tangent' or 'cone'

Target space/manifold of the projection operation.
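The psi conditioning described under Parameters can be sketched in a few lines. This NumPy version is an illustration rather than hypercoil's code, and `condition` is a hypothetical name; it broadcasts the identity over any leading dimensions.

```python
import numpy as np


def condition(x, psi):
    """Hypothetical helper: replace X with (1 - psi) X + psi I."""
    if not 0.0 <= psi <= 1.0:
        raise ValueError("psi must lie in [0, 1]")
    eye = np.eye(x.shape[-1])
    return (1.0 - psi) * x + psi * eye


# A rank-deficient PSD matrix with eigenvalues 0 and 2
x = np.array([[1.0, 1.0], [1.0, 1.0]])
print(np.linalg.eigvalsh(condition(x, 0.1)))  # eigenvalues 0.1 and 1.9
```

With \(\psi = 0.1\), the zero eigenvalue is lifted to \(\psi\) and the eigenvalue 2 becomes \(0.9 \cdot 2 + 0.1 = 1.9\), so the conditioned matrix is safely inside the domain of the logarithm.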

Attributes:
weight : Tensor \((*, C, N, N)\)

Current reference point of tangency \(\Omega\) for the projection between the semidefinite cone and a tangent subspace.