ParameterisedLoss: Extensible class for custom parameterised losses

class hypercoil.loss.ParameterisedLoss(score: Callable, scalarisation: Callable, nu: float = 1.0, name: str | None = None, *, params: Mapping | None = None, key: 'jax.random.PRNGKey' | None = None)

Extensible class for loss functions with simple parameterisations.

This class is intended as a base class for loss functions with a simple parameterisation, i.e., a fixed set of parameters that are passed to the scoring function. The parameters are specified with the params argument, which should be a mapping from parameter names to values. Instances of this class are immutable, so their parameters cannot be changed after instantiation.
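As a minimal sketch of what immutability means in practice, the following uses a frozen Python dataclass as a stand-in. `SketchParameterisedLoss`, `squared_deviation`, and `mean` are hypothetical names, plain lists stand in for tensors, and this is not how hypercoil itself implements the class. Reassigning a field or editing the stored parameters fails, so a variant with different parameters is obtained by deriving a new instance.

```python
from dataclasses import FrozenInstanceError, dataclass, field, replace
from types import MappingProxyType
from typing import Any, Callable, Mapping


@dataclass(frozen=True)
class SketchParameterisedLoss:
    """Hypothetical immutable container mirroring the documented fields.

    This is not hypercoil's implementation; it only illustrates the idea.
    """

    score: Callable
    scalarisation: Callable
    nu: float = 1.0
    params: Mapping[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Wrap params in a read-only view so the stored mapping cannot be
        # edited in place either (frozen=True only blocks reassignment).
        object.__setattr__(self, "params", MappingProxyType(dict(self.params)))


def squared_deviation(x, *, centre):
    # Hypothetical score: one value per (scalar) observation.
    return [(xi - centre) ** 2 for xi in x]


def mean(values):
    return sum(values) / len(values)


loss = SketchParameterisedLoss(squared_deviation, mean, params={"centre": 0.0})

try:
    loss.nu = 2.0  # reassigning a field of a frozen dataclass fails
except FrozenInstanceError:
    print("cannot reassign nu")

try:
    loss.params["centre"] = 1.0  # mappingproxy does not support assignment
except TypeError:
    print("cannot modify params")

# "Changing" a parameter means deriving a new instance.
shifted = replace(loss, params={"centre": 1.0})
X = [1.0, 3.0]
print(loss.scalarisation(loss.score(X, **loss.params)))           # 5.0
print(shifted.scalarisation(shifted.score(X, **shifted.params)))  # 2.0
```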

Parameters:
name: str

Designated name of the loss function. Specifying a name is optional but recommended, so that the loss function can be identified by reporting utilities. If no name is specified, one is inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier (default 1.0). This scalar is applied to the loss value before it is returned, and can be used to modulate the relative contributions of different loss functions to the overall loss. It can also be used to implement a schedule for the loss function by adjusting the multiplier over the course of training.

score: Callable

The scoring function used to compute the loss value. It should accept a tensor of arbitrary shape, along with any parameters specified in params, and return a score for each (potentially multivariate) observation in the tensor.

scalarisation: Callable

The scalarisation function used to aggregate the scores returned by the scoring function. It should take a tensor of arbitrary shape and return a single scalar value. By default, the mean scalarisation is used.

params: Mapping[str, Any]

A mapping from parameter names to values. These will be passed to the scoring function when the loss function is called.
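To illustrate how score, scalarisation, params, and nu fit together, here is a plain-Python sketch. The functions `lp_norm_score` and `mean_scalarisation` are hypothetical, each row of the nested list `X` is assumed to be one observation, and the order of operations (score, then scalarise, then multiply by nu) follows the parameter descriptions above rather than the library source.

```python
from typing import List, Sequence


def lp_norm_score(X: Sequence[Sequence[float]], *, p: float = 2.0) -> List[float]:
    """Hypothetical score: the L_p norm of each observation (row) of X."""
    return [sum(abs(x) ** p for x in row) ** (1.0 / p) for row in X]


def mean_scalarisation(scores: Sequence[float]) -> float:
    """Aggregate the per-observation scores into a single scalar."""
    return sum(scores) / len(scores)


X = [[3.0, 4.0], [6.0, 8.0], [0.0, 3.0]]  # three bivariate observations
params = {"p": 2.0}  # forwarded to the score as keyword arguments
nu = 0.5             # loss strength multiplier

scores = lp_norm_score(X, **params)           # one score per observation
loss_value = nu * mean_scalarisation(scores)  # scalarise, then scale by nu
print(scores)      # [5.0, 10.0, 3.0]
print(loss_value)  # 3.0
```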

Methods

__call__(*pparams[, key])

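Evaluate the loss: the scoring function is applied to the inputs with the stored params, the resulting scores are aggregated by the scalarisation function, and the result is scaled by nu.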
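The behaviour of `__call__`, together with a schedule for nu, can be sketched as follows. The class `SketchLoss` and the helpers `absolute_deviation`, `mean`, and `linear_warmup` are hypothetical; the sketch accepts a single input and omits the `key` argument. It handles the schedule by deriving a new immutable instance at each step, which is one possible approach rather than a documented one.

```python
from dataclasses import dataclass, field, replace
from typing import Any, Callable, Mapping


@dataclass(frozen=True)
class SketchLoss:
    """Hypothetical immutable loss; not hypercoil's actual implementation."""

    score: Callable
    scalarisation: Callable
    nu: float = 1.0
    params: Mapping[str, Any] = field(default_factory=dict)

    def __call__(self, X):
        # Score with the stored params, scalarise, then scale by nu.
        return self.nu * self.scalarisation(self.score(X, **self.params))


def absolute_deviation(x, *, centre=0.0):
    return [abs(xi - centre) for xi in x]


def mean(values):
    return sum(values) / len(values)


def linear_warmup(step, n_warmup=4, nu_max=1.0):
    """Hypothetical schedule: ramp nu linearly from 0 up to nu_max."""
    return nu_max * min(step, n_warmup) / n_warmup


base = SketchLoss(absolute_deviation, mean, params={"centre": 1.0})
X = [-1.0, 3.0]  # mean absolute deviation from 1.0 is 2.0

history = []
for step in range(6):
    # Instances are immutable, so derive a new one with the scheduled nu.
    loss = replace(base, nu=linear_warmup(step))
    history.append(loss(X))

print(history)  # [0.0, 0.5, 1.0, 1.5, 2.0, 2.0]
```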