symexp
- hypercoil.functional.symmap.symexp(input: Tensor) → Tensor
Matrix exponential of a batch of symmetric matrices.
Computed by diagonalising the matrix \(X = Q_X \Lambda_X Q_X^\intercal\), computing the exponential of the eigenvalues, and recomposing.
\(\exp X = Q_X \exp \Lambda_X Q_X^\intercal\)
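The recomposition above can be sketched in a few lines of JAX. This is only an illustrative sketch, not hypercoil's actual source: the name `symexp_sketch` is hypothetical, and it assumes the input is symmetric over its last two axes.

```python
import jax
import jax.numpy as jnp


def symexp_sketch(X):
    """Hypothetical stand-in for symexp; assumes X is symmetric over its last two axes."""
    # eigh returns real eigenvalues (..., D) and orthonormal eigenvectors (..., D, D)
    # for every matrix in the batch.
    L, Q = jnp.linalg.eigh(X)
    # Q exp(Lambda) Q^T: scale the columns of Q by exp(eigenvalue), then recompose.
    return (Q * jnp.exp(L)[..., None, :]) @ jnp.swapaxes(Q, -1, -2)


# Build a batch of symmetric matrices with shape (N, *, D, D).
A = jax.random.normal(jax.random.PRNGKey(0), (4, 3, 5, 5))
X = (A + jnp.swapaxes(A, -1, -2)) / 2

Y = symexp_sketch(X)
print(Y.shape)  # (4, 3, 5, 5)
```

Scaling the eigenvector columns by the exponentiated eigenvalues avoids materialising the diagonal matrix \(\exp \Lambda_X\) explicitly; the two forms are mathematically equivalent.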
Note
This approach is in principle faster than the general matrix exponential in JAX (jax.scipy.linalg.expm), but it is not as robust or general as that implementation.
- Dimension:
- Input : \((N, *, D, D)\)
N denotes batch size, * denotes any number of intervening dimensions, and D denotes the matrix row and column dimension.
- Output : \((N, *, D, D)\)
As above.
- Parameters:
- input : Tensor
Batch of symmetric tensors to transform using the matrix exponential.
- Returns:
- output : Tensor
Exponential of each matrix in the input batch.
Warning
jax.scipy.linalg.expm is generally more stable and is recommended over this function.
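One way to decide between the two routes is to check them against each other on representative inputs. The sketch below is an assumption-laden illustration: `symexp_sketch` is a hypothetical eigendecomposition-based helper (not hypercoil's own code), and `jax.vmap` is used to apply `jax.scipy.linalg.expm` to each matrix in the batch.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def symexp_sketch(X):
    # Hypothetical helper: eigendecomposition-based exponential for symmetric input.
    L, Q = jnp.linalg.eigh(X)
    return (Q * jnp.exp(L)[..., None, :]) @ jnp.swapaxes(Q, -1, -2)


# A batch of 8 symmetric 4 x 4 matrices with moderate eigenvalues.
A = jax.random.normal(jax.random.PRNGKey(0), (8, 4, 4))
X = 0.25 * (A + jnp.swapaxes(A, -1, -2))

Y_eig = symexp_sketch(X)     # eigendecomposition route
Y_ref = jax.vmap(expm)(X)    # reference route, mapped over the batch axis

# Largest elementwise disagreement between the two results.
max_abs_diff = jnp.max(jnp.abs(Y_eig - Y_ref))
print(max_abs_diff)
```

On well-conditioned symmetric inputs the two results should agree to within floating-point tolerance; a large discrepancy suggests that the input is not symmetric or is poorly conditioned, in which case the reference routine is the safer choice.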