symexp

hypercoil.functional.symmap.symexp(input: Tensor) → Tensor

Matrix exponential of a batch of symmetric matrices. The input need not be positive definite; the output is always symmetric positive definite.

Computed by diagonalising the matrix \(X = Q_X \Lambda_X Q_X^\intercal\), computing the exponential of the eigenvalues, and recomposing.

\(\exp X = Q_X \exp \Lambda_X Q_X^\intercal\)
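A minimal sketch of the diagonalise, exponentiate, recompose recipe above. It is written with NumPy for portability (jax.numpy provides `linalg.eigh`, `exp` and `swapaxes` under the same names). The name `symexp_sketch` is hypothetical, and this is not hypercoil's source.

```python
import numpy as np


def symexp_sketch(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in: exp of symmetric matrices with shape (..., D, D)."""
    # Diagonalise X = Q diag(lam) Q^T; eigh accepts any leading batch dims.
    lam, q = np.linalg.eigh(x)
    # Scaling column i of Q by exp(lam_i) equals Q @ diag(exp(lam)),
    # so the diagonal matrix never has to be built explicitly.
    q_scaled = q * np.exp(lam)[..., None, :]
    # Recompose Q diag(exp(lam)) Q^T.
    return q_scaled @ np.swapaxes(q, -1, -2)


# Here X @ X = I, so exp(X) = cosh(1) I + sinh(1) X.
x = np.array([[0.0, 1.0], [1.0, 0.0]])
out = symexp_sketch(x)
```

Because `np.linalg.eigh` already operates on stacked matrices, no explicit loop over the batch dimensions is needed.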

Note

This approach can in principle be faster than the general matrix exponential in JAX, but it is less robust and less general than the JAX implementation (jax.scipy.linalg.expm), which does not require symmetric input.
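To illustrate the loss of generality, the hypothetical NumPy sketch below compares the eigendecomposition route with a naive truncated Taylor series on a non-symmetric matrix. The Taylor series is a simple general-purpose reference only. Library routines such as jax.scipy.linalg.expm typically use scaling and squaring with Padé approximants instead. By default, `numpy.linalg.eigh` reads only the lower triangle, so the non-symmetric part is silently ignored and no error is raised.

```python
import numpy as np


def symexp_sketch(x):
    # Eigendecomposition-based exponential; only valid for symmetric x.
    lam, q = np.linalg.eigh(x)
    return (q * np.exp(lam)[..., None, :]) @ np.swapaxes(q, -1, -2)


def expm_taylor(x, terms=30):
    # General reference: sum_{k < terms} X^k / k!.
    # Adequate only for small-norm matrices; not a production expm.
    out = np.eye(x.shape[-1])
    term = np.eye(x.shape[-1])
    for k in range(1, terms):
        term = term @ x / k  # term now holds X^k / k!
        out = out + term
    return out


# Nilpotent, non-symmetric matrix: N @ N = 0, so exp(N) = I + N exactly.
n = np.array([[0.0, 1.0], [0.0, 0.0]])
reference = expm_taylor(n)   # [[1, 1], [0, 1]]
spectral = symexp_sketch(n)  # eigh sees only the zero lower triangle -> I
```

The two results disagree. Unless symmetry is guaranteed upstream, a general routine is therefore the safer choice.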

Dimension:
Input : \((N, *, D, D)\)

N denotes batch size, * denotes any number of intervening dimensions, D denotes matrix row and column dimension.

Output : \((N, *, D, D)\)

As above.
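The (N, *, D, D) convention means that every axis except the last two is treated as a batch axis. The following hypothetical NumPy sketch is again a stand-in rather than hypercoil's code. It passes a random symmetric array of shape (4, 3, 5, 5) through the eigendecomposition and checks the output shape. It also checks that each output block is symmetric positive definite, as expected because its eigenvalues are exp(λ) > 0.

```python
import numpy as np


def symexp_sketch(x):
    # Eigendecomposition-based exponential over leading batch dims.
    lam, q = np.linalg.eigh(x)
    return (q * np.exp(lam)[..., None, :]) @ np.swapaxes(q, -1, -2)


rng = np.random.default_rng(0)
N, C, D = 4, 3, 5  # batch size, one intervening dimension, matrix size
a = rng.standard_normal((N, C, D, D))
x = (a + np.swapaxes(a, -1, -2)) / 2  # symmetrise each D x D block

out = symexp_sketch(x)
print(out.shape)  # (4, 3, 5, 5)

# Each output block is symmetric with strictly positive eigenvalues.
assert np.allclose(out, np.swapaxes(out, -1, -2))
assert np.all(np.linalg.eigvalsh(out) > 0)
```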

Parameters:
input : Tensor

Batch of symmetric tensors to transform using the matrix exponential.

Returns:
output : Tensor

Exponential of each matrix in the input batch.

Warning

jax.scipy.linalg.expm is generally more numerically stable and is recommended over this function.