selfwmean_scalarise: Self-weighted mean scalarisation#

hypercoil.loss.scalarise.selfwmean_scalarise(*, inner: Callable | None = None, axis: int | Sequence[int] | None = None, gradpath: Literal['weight', 'input'] | None = 'input', softmax_axis: Sequence[int] | int | bool | None = False, softmax_invert: bool = False, keepdims: bool = False, key: PRNGKey | None = None) → Callable[[Callable[[...], Tensor]], Callable[[...], float]][source]#

Transform a tensor-valued function to a scalar-valued function by taking the self-weighted mean of the tensor along an axis or set of axes: a weighted mean in which the weights are derived from the tensor itself (the raw values by default, or a softmax over them).

Note

This function is a scalarisation map, which can be used either on its own or composed with other scalarisation maps to form a scalarisation.

Composition is performed by passing the scalarisation map to be composed as the inner argument to the outer scalarisation map. This can be chained indefinitely.

The output of a scalarisation map (or composition thereof) is a function that maps a tensor-valued function to a scalar-valued function. Be mindful of the signature of scalarisation maps: they do not themselves perform this map, but rather return a function that does.

For example, the following are all valid scalarisations of a function f:

max_scalarise()(f)
mean_scalarise(inner=max_scalarise(axis=-1))(f)
mean_scalarise(inner=vnorm_scalarise(p=1, axis=(0, 2)))(f)

Calling the scalarised function will return a scalar value.
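
The same pattern applies to the self-weighted mean scalarisation. For instance (the axis values here are purely illustrative):

selfwmean_scalarise(axis=-1)(f)
mean_scalarise(inner=selfwmean_scalarise(softmax_axis=-1, softmax_invert=True))(f)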

Parameters:
inner: Callable, optional

The inner scalarisation map. If not specified, the identity map is used. For many scalarisation maps, the default settings of the inner map and the scalarisation axis amount to applying the scalarisation map over the entire tensor; verify these defaults to confirm the behaviour.

axis: Union[int, Sequence[int]], optional

The axis or axes over which to apply the scalarisation map. If not specified, the scalarisation map is applied over the entire tensor, except in the cases of maps that are only defined over vectors or matrices (e.g., norm scalarisations). Check the default arguments to verify the default behaviour.

Warning

When composing scalarisation maps, the value of the axis argument will refer to the axis or axes of the reduced-rank tensor that is the output of the inner scalarisation map. For example, consider the following composition:

mean_scalarise(inner=max_scalarise(axis=-1), axis=-2)(f)

The axis argument of the outer scalarisation map refers to the third from last axis of the original tensor, which is the second from last axis of the tensor after the inner scalarisation map has been applied.

Setting keepdims=True for the inner scalarisations will result in the (perhaps more intuitive) behaviour of the axis argument referring to the axis or axes of the original tensor.
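
For instance, assuming the inner map accepts a keepdims argument like this one does, the following composition applies the outer mean over the second-from-last axis of the original tensor:

mean_scalarise(inner=max_scalarise(axis=-1, keepdims=True), axis=-2)(f)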

keepdims: bool, optional

Whether to keep the reduced dimensions in the output. If True, the output will have the same number of dimensions as the input, with each reduced dimension retained as a singleton. If False, the output will have one fewer dimension than the input for each dimension over which the scalarisation map is applied.

Note

It can be useful to set keepdims=True when composing scalarisation maps, as the axis argument can then be specified with reference to the original tensor rather than the reduced tensor, only setting keepdims=False for the outermost scalarisation map.

gradpath: Optional[Literal['weight', 'input']] (default: 'input')

If 'weight', the gradient of the scalarisation function will be backpropagated through the weights only. If 'input', the gradient of the scalarisation function will be backpropagated through the input only. If None, the gradient will be backpropagated through both.
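
As a rough sketch of the idea (not the library's implementation), the gradient path can be thought of as a stop-gradient applied to whichever factor should not receive gradients, assuming the weight is derived from the input itself:

import jax
import jax.numpy as jnp

def selfwmean_sketch(x, axis=None, gradpath='input'):
    # Illustrative only: the input doubles as its own weight.
    weight, inp = x, x
    if gradpath == 'input':
        weight = jax.lax.stop_gradient(weight)  # gradient flows through the input only
    elif gradpath == 'weight':
        inp = jax.lax.stop_gradient(inp)  # gradient flows through the weight only
    return jnp.sum(weight * inp, axis=axis) / jnp.sum(weight, axis=axis)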

softmax_axis: Optional[Union[Sequence[int], int, bool]] (default: False)

If not False, the input is passed through a softmax function to create the weight, rather than being used as the weight directly. If True, the softmax is taken over all axes; if an integer or sequence of integers, it is taken over those axes only. If False, the input itself is used as the weight.

softmax_invert: bool (default: False)

If True, the input is negated before passing it through the softmax. In this way, the softmax can be used to upweight the minimum instead of the maximum.
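
Continuing the sketch above (again illustrative rather than the library's implementation), when softmax_axis is an axis index such as -1, the weight would instead be computed as:

weight = jax.nn.softmax(-x if softmax_invert else x, axis=-1)  # negation upweights minima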

key: Optional[jax.random.PRNGKey]

An optional random number generator key. Unused; exists for conformance with potential future scalarisation maps that could inject randomness.

Returns:
Callable[[Callable[…, Tensor]], Callable[…, float]]

The scalarisation transformation. This is a function that takes a tensor-valued function and returns a scalar-valued function.