meansq_scalarise: Squared mean scalarisation

hypercoil.loss.scalarise.meansq_scalarise(*, inner: Callable | None = None, axis: int | Sequence[int] | None = None, keepdims: bool = False, key: PRNGKey | None = None) → Callable[[Callable[..., Tensor]], Callable[..., float]]

Transform a tensor-valued function to a scalar-valued function by taking the mean of the elementwise squared tensor.
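The reduction itself can be sketched in a few lines. This is a minimal illustration, not the hypercoil source: it uses NumPy for portability (`jax.numpy.square` and `jax.numpy.mean` accept the same `axis` and `keepdims` arguments), and the helper name `mean_of_squares` is hypothetical. Note that the map computes the mean of the squares, not the square of the mean.

```python
import numpy as np


def mean_of_squares(x, axis=None, keepdims=False):
    """Hypothetical helper: square each element, then average over `axis`."""
    return np.mean(np.square(x), axis=axis, keepdims=keepdims)


x = np.array([1.0, 2.0, 3.0])
print(mean_of_squares(x))  # (1 + 4 + 9) / 3 = 4.666...
print(np.mean(x) ** 2)     # square of the mean, by contrast: 4.0
```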

Note

This function is a scalarisation map, which can be used either on its own or composed with other scalarisation maps to form a scalarisation.

Composition is performed by passing the scalarisation map to be composed as the inner argument to the outer scalarisation map. This can be chained indefinitely.

The output of a scalarisation map (or composition thereof) is a function that maps a tensor-valued function to a scalar-valued function. Be mindful of the signature of scalarisation maps: they do not themselves perform this map, but rather return a function that does.
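The three levels involved (the map factory, the transformation it returns, and the scalarised function) can be sketched as follows. This is an illustrative reimplementation under stated assumptions, not the library's code: `meansq_scalarise_sketch` and `loss_terms` are hypothetical names, NumPy stands in for `jax.numpy`, and `inner` is assumed to be applied to `f` before the mean of squares is taken.

```python
from typing import Callable, Optional, Sequence, Union

import numpy as np


def meansq_scalarise_sketch(
    *,
    inner: Optional[Callable] = None,
    axis: Union[int, Sequence[int], None] = None,
    keepdims: bool = False,
):
    """Hypothetical stand-in for meansq_scalarise: returns a transformation."""

    def transform(f: Callable) -> Callable:
        # Assumed behaviour: an inner map, if given, wraps f first.
        g = f if inner is None else inner(f)

        def scalarised(*args, **kwargs):
            out = g(*args, **kwargs)  # tensor-valued output of f (or inner(f))
            # NumPy reductions expect an int or a tuple of ints for `axis`.
            ax = tuple(axis) if isinstance(axis, (list, tuple)) else axis
            return np.mean(np.square(out), axis=ax, keepdims=keepdims)

        return scalarised

    return transform


def loss_terms(x):
    # Hypothetical tensor-valued function to be scalarised.
    return x - 1.0


transform = meansq_scalarise_sketch()   # step 1: build the map (no computation yet)
loss = transform(loss_terms)            # step 2: wrap the tensor-valued function
print(loss(np.array([1.0, 2.0, 4.0])))  # step 3: evaluate, mean([0, 1, 9]) = 3.333...
```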

For example, the following are all valid scalarisations of a function f:

max_scalarise()(f)
mean_scalarise(inner=max_scalarise(axis=-1))(f)
mean_scalarise(inner=vnorm_scalarise(p=1, axis=(0, 2)))(f)

Calling the scalarised function will return a scalar value.
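To see how `inner` chains, the sketch below builds hypothetical NumPy stand-ins for `max_scalarise` and `mean_scalarise` from a shared factory and evaluates the second example above on a small array. It assumes, as the text describes, that the outer map reduces the output of the inner map; `make_scalarisation`, `max_sketch` and `mean_sketch` are illustrative names, not hypercoil's.

```python
import numpy as np


def make_scalarisation(reduce):
    """Hypothetical factory: turn a NumPy reduction into a scalarisation map."""

    def scalarisation_map(*, inner=None, axis=None, keepdims=False):
        def transform(f):
            g = f if inner is None else inner(f)  # compose: inner map wraps f first

            def scalarised(*args, **kwargs):
                return reduce(g(*args, **kwargs), axis=axis, keepdims=keepdims)

            return scalarised

        return transform

    return scalarisation_map


max_sketch = make_scalarisation(np.max)
mean_sketch = make_scalarisation(np.mean)


def f(x):
    return x  # identity, so the scalarised value is easy to check by hand


x = np.array([[1.0, 5.0], [2.0, 4.0]])
# Max over the last axis gives [5.0, 4.0]; the outer mean then gives 4.5.
print(mean_sketch(inner=max_sketch(axis=-1))(f)(x))
```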

Parameters:
inner: Callable, optional

The inner scalarisation map. If not specified, the identity map is used. For many scalarisation maps, the default settings for the inner map and the scalarisation axis amount to applying the scalarisation map over the entire tensor. Because defaults can differ between maps, verify them for each map you use.

axis: Union[int, Sequence[int]], optional

The axis or axes over which to apply the scalarisation map. If not specified, the scalarisation map is applied over the entire tensor, except for maps that are only defined over vectors or matrices (e.g., norm scalarisations). Check the default arguments to verify the default behaviour; for meansq_scalarise, the default axis=None reduces the entire tensor.
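As a shape check, the sketch below applies the mean-of-squares reduction with different `axis` values to a tensor of shape (2, 3, 4), using NumPy in place of `jax.numpy`. The helper `meansq` is hypothetical; with `axis=None`, every axis is reduced and a 0-dimensional (scalar) result remains.

```python
import numpy as np


def meansq(x, axis=None):
    # Hypothetical helper: mean of the elementwise square over `axis`.
    return np.mean(np.square(x), axis=axis)


x = np.arange(24, dtype=float).reshape(2, 3, 4)
for axis in (None, -1, (0, 2)):
    print(axis, np.shape(meansq(x, axis=axis)))
# None ()
# -1 (2, 3)
# (0, 2) (3,)
```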

Warning

When composing scalarisation maps, the value of the axis argument will refer to the axis or axes of the reduced-rank tensor that is the output of the inner scalarisation map. For example, consider the following composition:

mean_scalarise(inner=max_scalarise(axis=-1), axis=-2)(f)

The axis argument of the outer scalarisation map refers to the third from last axis of the original tensor, which is the second from last axis of the tensor after the inner scalarisation map has been applied.

Setting keepdims=True for the inner scalarisations will result in the (perhaps more intuitive) behaviour of the axis argument referring to the axis or axes of the original tensor.
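The axis shift in the example above can be traced on shapes directly. This NumPy sketch (standing in for the `jax.numpy` reductions) uses a hypothetical output of `f` with shape (2, 3, 4, 5) and applies the inner max and the outer mean by hand, in the order the composition applies them, without `keepdims`.

```python
import numpy as np

out = np.random.default_rng(0).normal(size=(2, 3, 4, 5))  # stand-in for f(...)

inner = np.max(out, axis=-1)     # shape (2, 3, 4): last axis (size 5) removed
outer = np.mean(inner, axis=-2)  # axis -2 of (2, 3, 4) is the size-3 axis
print(inner.shape, outer.shape)  # (2, 3, 4) (2, 4)

# The same reduction written against the original tensor uses axis -3.
check = np.mean(np.max(out, axis=-1, keepdims=True), axis=-3)[..., 0]
print(np.allclose(outer, check))  # True
```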

keepdims: bool, optional

Whether to keep the reduced dimensions in the output. If True, each reduced dimension is retained with size 1, so the output has the same number of dimensions as the input. If False, each reduced dimension is removed, so the output has one fewer dimension than the input for each axis over which the scalarisation map is applied.
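For example, reducing the middle axis of a (2, 3, 4) tensor with the mean of squares keeps a singleton in its place when `keepdims=True` and drops it otherwise. NumPy stands in for `jax.numpy` in this sketch; the reductions take the same `keepdims` argument.

```python
import numpy as np

x = np.ones((2, 3, 4))
kept = np.mean(np.square(x), axis=1, keepdims=True)      # reduced axis kept as size 1
dropped = np.mean(np.square(x), axis=1, keepdims=False)  # reduced axis removed
print(kept.shape, dropped.shape)  # (2, 1, 4) (2, 4)
```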

Note

It can be useful to set keepdims=True for the inner maps when composing scalarisation maps: the axis argument can then be specified with reference to the original tensor rather than the reduced tensor, and keepdims=False need only be set for the outermost scalarisation map.
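The pattern in this note can be checked on shapes with a three-step chain. In this NumPy sketch (hypothetical shapes; `jax.numpy` reductions take the same `keepdims` argument), the inner reductions keep their reduced axes, so each `axis` value indexes the original four-dimensional tensor, and only the outermost mean-of-squares reduction (here over all axes) drops dimensions.

```python
import numpy as np

out = np.random.default_rng(0).normal(size=(2, 3, 4, 5))  # stand-in for f(...)

step1 = np.max(out, axis=-1, keepdims=True)     # (2, 3, 4, 1): rank preserved
step2 = np.mean(step1, axis=-3, keepdims=True)  # -3 is still the size-3 axis of `out`
step3 = np.mean(np.square(step2))               # outermost map: reduce everything
print(step1.shape, step2.shape, np.shape(step3))  # (2, 3, 4, 1) (2, 1, 4, 1) ()
```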

key: Optional[jax.random.PRNGKey]

An optional random number generator key. Unused; exists for conformance with potential future scalarisation maps that could inject randomness.

Returns:
Callable[[Callable[…, Tensor]], Callable[…, float]]

The scalarisation transformation. This is a function that takes a tensor-valued function and returns a scalar-valued function.