wmean

hypercoil.loss.scalarise.wmean(input: Tensor, weight: Tensor, axis: int | Sequence[int] | None = None, keepdims: bool = False) → Tensor

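Rank-reducing function for scalarisation maps: weighted mean. Computes sum(weight * input) / sum(weight) over the axes given by axis (by default, over all axes). The weight is aligned with the reduced axes: in the examples below, the same length-3 weight is applied across rows when axis=0 and across columns when axis=1.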

>>> wmean(jnp.array([1, 2, 3]), jnp.array([1, 0, 1]))
Array(2., dtype=float32)
>>> wmean(
...     jnp.array([[1, 2, 3],
...                [1, 2, 3],
...                [1, 2, 3]]),
...     jnp.array([1, 0, 1]),
...     axis=0
... )
Array([1., 2., 3.], dtype=float32)
>>> wmean(
...     jnp.array([[1, 2, 3],
...                [1, 2, 3],
...                [1, 2, 3]]),
...     jnp.array([1, 0, 1]),
...     axis=1,
...     keepdims=True
... )
Array([[2.],
       [2.],
       [2.]], dtype=float32)
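The examples imply that weight is lined up with the reduced axes rather than broadcast against the trailing axes of input. Ordinary broadcasting would put [1, 0, 1] across columns in the axis=0 case, giving a zero weight sum (and a NaN) in the middle column instead of [1., 2., 3.]. The following is a minimal sketch of that alignment behaviour. It is not hypercoil's implementation: the name wmean_sketch is hypothetical, and the rule that the weight has one dimension per reduced axis, in ascending axis order, is an assumption inferred from the examples. It uses NumPy so it runs without JAX; the same calls exist in jax.numpy.

```python
import numpy as np


def wmean_sketch(x, weight, axis=None, keepdims=False):
    """Hypothetical weighted mean; weight is aligned with the reduced axes."""
    x = np.asarray(x, dtype=float)
    weight = np.asarray(weight, dtype=float)
    # Normalise `axis` to a sorted tuple of non-negative axis indices.
    if axis is None:
        axis = tuple(range(x.ndim))
    elif isinstance(axis, int):
        axis = (axis,)
    axis = tuple(sorted(a % x.ndim for a in axis))
    # Assumption: one weight dimension per reduced axis, in ascending axis order.
    if weight.ndim != len(axis):
        raise ValueError("weight must have one dimension per reduced axis")
    # Insert singleton dimensions at the retained axes so that the weight
    # broadcasts along them and lines up with the reduced axes.
    for i in range(x.ndim):
        if i not in axis:
            weight = np.expand_dims(weight, i)
    numerator = (weight * x).sum(axis=axis, keepdims=keepdims)
    denominator = weight.sum(axis=axis, keepdims=keepdims)
    return numerator / denominator
```

Under this sketch's assumptions, a reduction over axis=(0, 2) of a 3-D input would take a 2-D weight whose shape matches those two axes; the library's own handling of such cases may differ.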