Loss: Base class for scalar-valued losses
- class hypercoil.loss.Loss(score: Callable, scalarisation: Callable, nu: float = 1.0, name: str | None = None, *, key: 'jax.random.PRNGKey' | None = None)
Base class for loss functions.
A loss function is the composition of a score function and a scalarisation map (which might itself be the composition of different tensor rank reduction maps). It also includes a multiplier, specified using the nu parameter, that scales its contribution to the overall loss.
The API for dimension reduction is subject to change. We will likely make scalarisations more flexible with regard to both compositionality and the number and specification of the dimensions they reduce to.
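The composition described above can be sketched in a few lines of plain Python: the loss value is nu × scalarisation(score(X)). This is a hedged illustration, not hypercoil's implementation. The names make_loss, squared_norm_score, and mean_scalarisation are hypothetical, and nested lists stand in for JAX arrays.

```python
from statistics import fmean
from typing import Callable, Sequence

# Hypothetical stand-in for a tensor: a list of observations (rows).
Observation = Sequence[float]


def squared_norm_score(x: Sequence[Observation]) -> list[float]:
    """Hypothetical score: one value per observation (squared L2 norm of each row)."""
    return [sum(v * v for v in row) for row in x]


def mean_scalarisation(scores: Sequence[float]) -> float:
    """Reduce the per-observation scores to a single scalar by averaging."""
    return fmean(scores)


def make_loss(
    score: Callable[[Sequence[Observation]], list[float]],
    scalarisation: Callable[[Sequence[float]], float],
    nu: float = 1.0,
) -> Callable[[Sequence[Observation]], float]:
    """Compose score and scalarisation, then scale the result by the multiplier nu."""
    def loss(x: Sequence[Observation]) -> float:
        return nu * scalarisation(score(x))
    return loss


loss = make_loss(squared_norm_score, mean_scalarisation, nu=0.5)
X = [[1.0, 2.0], [3.0, 0.0]]  # scores [5.0, 9.0] -> mean 7.0
print(loss(X))  # 3.5
```

Because nu is applied after scalarisation, doubling nu doubles the loss value without changing the score or the reduction.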
- Parameters:
- name: str
Designated name of the loss function. Specifying a name is optional but recommended, so that the loss function can be identified by reporting utilities. If no name is given, it is inferred from the class name and the name of the scoring function.
- nu: float
Loss strength multiplier. This is a scalar multiplier that is applied to the loss value before it is returned. This can be used to modulate the relative contributions of different loss functions to the overall loss value. It can also be used to implement a schedule for the loss function, by dynamically adjusting the multiplier over the course of training.
- score: Callable
The scoring function to be used to compute the loss value. This function should take a single argument, which is a tensor of arbitrary shape, and return a score value for each (potentially multivariate) observation in the tensor.
- scalarisation: Callable
The scalarisation function to be used to aggregate the values returned by the scoring function. This function should take a single argument, which is a tensor of arbitrary shape, and return a single scalar value. By default, the mean scalarisation is used.
- Attributes:
- nu
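The parameter list above states two things about scalarisation. It takes a tensor of arbitrary shape and returns a single scalar. The overview adds that a scalarisation may itself be a composition of rank reduction maps. The sketch below illustrates both points under stated assumptions. Nested Python lists serve as tensors, and flatten, mean_scalarise, and compose are hypothetical helpers that are not part of hypercoil.

```python
from statistics import fmean
from typing import Any, Callable


def flatten(t: Any) -> list[float]:
    """Recursively flatten a nested-list 'tensor' of any rank into a flat list."""
    if isinstance(t, (list, tuple)):
        out: list[float] = []
        for item in t:
            out.extend(flatten(item))
        return out
    return [float(t)]


def mean_scalarise(t: Any) -> float:
    """Hypothetical mean scalarisation: average over every element, whatever the rank."""
    return fmean(flatten(t))


def compose(*maps: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Chain rank reduction maps, applying them left to right."""
    def composed(t: Any) -> Any:
        for m in maps:
            t = m(t)
        return t
    return composed


# Rank 2 -> rank 1 (max over each row) -> rank 0 (mean over rows).
max_then_mean = compose(lambda t: [max(row) for row in t], fmean)

scores = [[1.0, 4.0], [9.0, 0.0]]
print(mean_scalarise(scores))  # 3.5
print(max_then_mean(scores))   # 6.5
```

The two scalarisations reduce the same scores to different scalars. The choice of reduction, and not just the score, therefore determines what the loss emphasises.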
Methods
- __call__(*pparams[, key]): Call self as a function.
- cfg(value[, where]): Return a copy of the loss function with the specified attribute modified.
- step([count]): If the loss multiplier is a schedule, this will advance the schedule by one step.
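The following sketch shows how a scheduled multiplier, step, and cfg might fit together. It rests on several assumptions and is not hypercoil's code:

- LinearSchedule and SketchLoss are hypothetical.
- step is assumed to return an updated copy rather than mutate in place.
- where is assumed to default to "nu".

```python
from dataclasses import dataclass, replace
from statistics import fmean
from typing import Callable, Sequence, Union


@dataclass(frozen=True)
class LinearSchedule:
    """Hypothetical schedule: ramps linearly from start to stop over n_steps, then holds."""
    start: float
    stop: float
    n_steps: int
    t: int = 0

    def value(self) -> float:
        frac = min(self.t, self.n_steps) / self.n_steps
        return self.start + frac * (self.stop - self.start)

    def step(self, count: int = 1) -> "LinearSchedule":
        return replace(self, t=self.t + count)


@dataclass(frozen=True)
class SketchLoss:
    """Hypothetical stand-in for a loss whose multiplier is a constant or a schedule."""
    score: Callable[[Sequence[float]], Sequence[float]]
    scalarisation: Callable[[Sequence[float]], float]
    nu: Union[float, LinearSchedule] = 1.0

    def _nu(self) -> float:
        return self.nu.value() if isinstance(self.nu, LinearSchedule) else self.nu

    def __call__(self, x: Sequence[float]) -> float:
        return self._nu() * self.scalarisation(self.score(x))

    def cfg(self, value: object, where: str = "nu") -> "SketchLoss":
        # Assumption: 'where' names the attribute and defaults to "nu".
        # Returns a modified copy; the original instance is left untouched.
        return replace(self, **{where: value})

    def step(self, count: int = 1) -> "SketchLoss":
        # Assumption: returns a copy. Only a scheduled multiplier advances;
        # a constant nu leaves the loss unchanged.
        if isinstance(self.nu, LinearSchedule):
            return replace(self, nu=self.nu.step(count))
        return self


loss = SketchLoss(
    score=lambda x: [v * v for v in x],
    scalarisation=fmean,
    nu=LinearSchedule(start=0.0, stop=1.0, n_steps=4),
)
x = [1.0, 3.0]  # scores [1.0, 9.0] -> mean 5.0
print(loss(x))           # 0.0 (schedule at t=0)
print(loss.step(2)(x))   # 2.5 (schedule at t=2, nu=0.5)
print(loss.cfg(2.0)(x))  # 10.0 (constant nu=2.0)
```

Returning copies instead of mutating keeps the loss object immutable. This is consistent with the copy-returning behaviour documented for cfg.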