TangentProject
- hypercoil.nn.semidefinite.TangentProject(out_channels: int, matrix_size: int, psi: float = 0.0, scale: float = 1.0, dest: Literal['tangent', 'cone'] = 'tangent', *, key: jax.random.PRNGKey)
Tangent/cone projection with a learnable or fixed point of tangency.
At initialisation, a data sample is required to set the point of tangency. The tangency point is initialised as a mean of the dataset, which can be either the standard Euclidean mean or a measure of central tendency derived specifically for positive semidefinite matrices. Data passed through the module is projected from the positive semidefinite cone into a proper subspace tangent to the cone at the reference point, which is stored as the module weight. Given a tangency point \(\Omega\), each input \(\Theta\) is projected as:
\(\vec{\Theta} = \log \left( \Omega^{-1/2} \Theta \Omega^{-1/2} \right)\)
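The projection formula above can be sketched in plain NumPy. This is an illustrative reading of the equation, not hypercoil's implementation: the helper names `symmetric_apply` and `tangent_project` are hypothetical, and the matrix functions are evaluated through an eigendecomposition, which assumes that both \(\Theta\) and \(\Omega\) are symmetric positive definite.

```python
import numpy as np


def symmetric_apply(matrix, fn):
    """Apply a scalar function to the eigenvalues of a symmetric matrix."""
    eigvals, eigvecs = np.linalg.eigh(matrix)
    # V diag(fn(w)) V^T: scaling the columns of V by fn(w) before multiplying
    return (eigvecs * fn(eigvals)) @ eigvecs.T


def tangent_project(theta, omega):
    """Hypothetical helper: log(omega^{-1/2} theta omega^{-1/2})."""
    omega_inv_sqrt = symmetric_apply(omega, lambda w: w ** -0.5)
    whitened = omega_inv_sqrt @ theta @ omega_inv_sqrt
    # Symmetrise to suppress floating-point asymmetry before the next eigh
    whitened = (whitened + whitened.T) / 2
    return symmetric_apply(whitened, np.log)


rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4))
theta = a @ a.T + 4 * np.eye(4)   # symmetric positive definite input
omega = 2.0 * np.eye(4)           # assumed reference point of tangency

projected = tangent_project(theta, omega)
at_reference = tangent_project(omega, omega)
```

In this sketch the reference point itself maps to the zero matrix, and every projected matrix is symmetric, which is what one would expect of coordinates in the tangent space at \(\Omega\).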
Alternatively, the module destination can be set to the semidefinite cone, in which case symmetric matrices are projected into the cone using the same reference point:
\(\Theta = \Omega^{1/2} \exp \left( \vec{\Theta} \right) \Omega^{1/2}\)
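The cone projection can be sketched the same way. Again this is a NumPy illustration under stated assumptions rather than the library's code: `symmetric_apply` and `cone_project` are hypothetical names, the input is assumed to be symmetric, and \(\Omega\) is assumed to be positive definite.

```python
import numpy as np


def symmetric_apply(matrix, fn):
    """Apply a scalar function to the eigenvalues of a symmetric matrix."""
    eigvals, eigvecs = np.linalg.eigh(matrix)
    return (eigvecs * fn(eigvals)) @ eigvecs.T


def cone_project(tangent, omega):
    """Hypothetical helper: omega^{1/2} exp(tangent) omega^{1/2}."""
    omega_sqrt = symmetric_apply(omega, np.sqrt)
    return omega_sqrt @ symmetric_apply(tangent, np.exp) @ omega_sqrt


rng = np.random.default_rng(1)
b = rng.normal(size=(3, 3))
omega = b @ b.T + 3 * np.eye(3)   # assumed positive definite reference point
s = rng.normal(size=(3, 3))
tangent = (s + s.T) / 2           # arbitrary symmetric matrix

theta = cone_project(tangent, omega)
eigenvalues = np.linalg.eigvalsh((theta + theta.T) / 2)
zero_maps_to_omega = cone_project(np.zeros((3, 3)), omega)
```

Because the matrix exponential of a symmetric matrix is positive definite, the result has strictly positive eigenvalues, and the zero matrix maps back to the reference point \(\Omega\).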
After initialisation, the tangency point can be learned by optimising any differentiable loss.
- Dimension:
- Input : \((*, N, N)\)
* denotes any number of preceding dimensions and N denotes the size of each square symmetric matrix.
- Output : \((*, C, N, N)\)
C denotes the number of output channels (points of tangency).
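One way to read the shape convention above is that each input matrix is projected at each of the C reference points, so a channel axis is inserted before the matrix axes. The NumPy sketch below illustrates that reading with broadcasting. It is an assumption about the layout, not a description of how hypercoil computes it, and `tangent_project_multi` is a hypothetical name.

```python
import numpy as np


def symmetric_apply(matrices, fn):
    """Apply a scalar function to the eigenvalues of stacked symmetric matrices."""
    eigvals, eigvecs = np.linalg.eigh(matrices)
    scaled = eigvecs * fn(eigvals)[..., None, :]
    return scaled @ np.swapaxes(eigvecs, -1, -2)


def tangent_project_multi(theta, omegas):
    """Hypothetical helper: theta (*, N, N), omegas (C, N, N) -> (*, C, N, N)."""
    inv_sqrt = symmetric_apply(omegas, lambda w: w ** -0.5)   # (C, N, N)
    # Insert a channel axis so the input broadcasts against all C reference points
    whitened = inv_sqrt @ theta[..., None, :, :] @ inv_sqrt   # (*, C, N, N)
    whitened = (whitened + np.swapaxes(whitened, -1, -2)) / 2
    return symmetric_apply(whitened, np.log)


rng = np.random.default_rng(2)
a = rng.normal(size=(5, 4, 4))
theta = a @ np.swapaxes(a, -1, -2) + 4 * np.eye(4)         # batch of 5 SPD matrices
omegas = np.stack([k * np.eye(4) for k in (1.0, 2.0, 3.0)])  # C = 3 reference points

out = tangent_project_multi(theta, omegas)
```

Here a batch of shape (5, 4, 4) with three reference points yields an output of shape (5, 3, 4, 4).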
- Parameters:
- init_data : Tensor
Data sample whose central tendency initialises the reference point of tangency.
- mean_specs : list(_SemidefiniteMean object)
Objects encoding a measure of central tendency in the positive semidefinite cone. Used to initialise the reference points of tangency. Selected from means on the semidefinite cone.
- psi : float in [0, 1]
Conditioning factor to promote positive definiteness. If this is in (0, 1], the original input will be replaced with a convex combination of the input and an identity matrix.
\(\hat{X} = (1 - \psi) X + \psi I\)
For a positive semidefinite input, any \(\psi > 0\) makes every eigenvalue of \(\hat{X}\) at least \(\psi\), which guarantees that the matrix lies in the domain of the projection operations.
- scale : float
Scaling factor for the initialisation of the reference point of tangency.
- dest : 'tangent' or 'cone'
Target space/manifold of the projection operation.
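The conditioning controlled by `psi` can be illustrated with a small NumPy sketch. The function name `condition` is hypothetical and only mirrors the convex-combination formula given for `psi`. The example uses a rank-deficient matrix to show how the smallest eigenvalue is lifted away from zero.

```python
import numpy as np


def condition(x, psi):
    """Hypothetical helper: (1 - psi) * X + psi * I."""
    if not 0.0 <= psi <= 1.0:
        raise ValueError("psi must lie in [0, 1]")
    return (1.0 - psi) * x + psi * np.eye(x.shape[-1])


v = np.array([[1.0], [2.0]])
x = v @ v.T                      # rank-1 matrix with eigenvalues 0 and 5
conditioned = condition(x, 0.1)

before = np.linalg.eigvalsh(x)
after = np.linalg.eigvalsh(conditioned)   # eigenvalues 0.1 and 4.6
```

Each eigenvalue \(\lambda\) of the input becomes \((1 - \psi) \lambda + \psi\), so the zero eigenvalue moves to 0.1 and the matrix logarithm used by the tangent projection becomes well defined.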
- Attributes:
- weight : Tensor \((*, C, N, N)\)
Reference point of tangency for the projection between the semidefinite cone and a tangent subspace.