NormedLoss: Normed parameter regularisation

class hypercoil.loss.NormedLoss(nu: float = 1.0, name: Optional[str] = None, score: Callable = <function identity>, *, p: float = 2.0, axis: Union[int, Sequence[int]] = None, outer_scalarise: Callable = <function mean_scalarise>, key: Optional['jax.random.PRNGKey'] = None)

\(L_p\) norm regulariser.

An example of how to compose elements to define a loss function. By default, this function flattens the input tensor and computes the \(L_2\) norm of the resulting vector. The axes to be flattened and the norm order can be specified using the axis and p arguments, respectively. If the norm is computed over only a subset of axes, the remaining axes can be further reduced by a scalarisation function passed as the outer_scalarise argument. By default, the outer scalarisation function is the mean function. Setting it to an identity function instead yields a loss function that returns one value per observation rather than a single scalar.
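The behaviour described above can be sketched in plain jax.numpy. This is an illustration only and not hypercoil's implementation: the function name lp_norm_loss and its internals are hypothetical, and it assumes the norm is taken over the axes named in axis, the remaining axes are reduced by an outer function, and the result is scaled by nu.

```python
import jax.numpy as jnp


def lp_norm_loss(X, nu=1.0, p=2.0, axis=None, outer=jnp.mean):
    """Hypothetical stand-in for the computation described above."""
    # Normalise `axis` to a tuple; None means "flatten every axis".
    if axis is None:
        axis = tuple(range(X.ndim))
    elif isinstance(axis, int):
        axis = (axis,)
    else:
        axis = tuple(axis)
    abs_x = jnp.abs(X)
    if p == float("inf"):
        # L_inf norm: largest absolute entry over the flattened axes.
        norms = abs_x.max(axis=axis)
    else:
        # General L_p norm: (sum |x|^p)^(1/p) over the flattened axes.
        norms = (abs_x ** p).sum(axis=axis) ** (1.0 / p)
    # Reduce whatever axes remain, then apply the strength multiplier.
    return nu * outer(norms)


X = jnp.array([[3.0, 4.0], [6.0, 8.0]])

# Default: flatten everything, L2 norm = sqrt(9 + 16 + 36 + 64) = sqrt(125).
total = lp_norm_loss(X)

# Norm over the last axis only gives row norms 5 and 10; their mean is 7.5.
mean_of_rows = lp_norm_loss(X, axis=-1)
```

With axis=None the outer function receives a single scalar, so it has no visible effect in this sketch.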

Parameters:
name: str

Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier. This is a scalar multiplier that is applied to the loss value before it is returned. This can be used to modulate the relative contributions of different loss functions to the overall loss value. It can also be used to implement a schedule for the loss function, by dynamically adjusting the multiplier over the course of training.
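As a rough illustration of how a multiplier of this kind can follow a schedule, the sketch below scales a fixed \(L_2\) penalty by a linearly warmed-up value. The helper names l2_penalty and linear_warmup, the warm-up shape, and the constants are all assumptions made for the example. How nu is updated on a NormedLoss instance during training is not described here and is not shown.

```python
import jax.numpy as jnp


def l2_penalty(params):
    # Unweighted L2 norm of a parameter vector.
    return jnp.sqrt(jnp.sum(params ** 2))


def linear_warmup(step, max_nu=0.1, warmup_steps=100):
    # Hypothetical schedule: ramp nu linearly from 0 to max_nu, then hold.
    return max_nu * min(step / warmup_steps, 1.0)


params = jnp.array([3.0, 4.0])
base = l2_penalty(params)  # 5.0

# Weighted penalty at a few training steps: nu(step) * base.
weighted = [linear_warmup(step) * base for step in (0, 50, 100, 200)]
```

The same unweighted penalty contributes nothing at step 0 and reaches its full weight of 0.1 from step 100 onwards.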

score: Callable

The scoring function to be used to compute the loss value. This function should take a single argument, which is a tensor of arbitrary shape, and return a score value for each (potentially multivariate) observation in the tensor.

p: float

The order of the norm to be computed. If p = 1, the function computes the \(L_1\) Manhattan / city block norm. If p = 2, the function computes the \(L_2\) Euclidean norm. If p = inf, the function computes the \(L_\infty\) maximum norm.
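To make the three norm orders concrete, the following sketch evaluates them on a small vector using jax.numpy.linalg.norm. It illustrates the norms themselves and makes no claim about how NormedLoss computes them internally.

```python
import jax.numpy as jnp

v = jnp.array([3.0, -4.0])

# L1 (Manhattan / city block): |3| + |-4| = 7
l1 = jnp.linalg.norm(v, ord=1)

# L2 (Euclidean): sqrt(3^2 + 4^2) = 5
l2 = jnp.linalg.norm(v, ord=2)

# L_inf (maximum): max(|3|, |-4|) = 4
linf = jnp.linalg.norm(v, ord=jnp.inf)
```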

axis: Optional[Union[int, Sequence[int]]]

The axes to be flattened. If None, all axes are flattened.

outer_scalarise: Optional[Callable]

The scalarisation function to be applied to any dimensions that are not flattened (i.e., those not specified in axis). If None, the mean function is used. If axis is None, this argument is ignored. To return a vector of values for each observation, explicitly set this to an identity function.
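The interaction between axis and outer_scalarise can be sketched on a small batch of matrices. The example assumes the first axis indexes observations and that the last two axes are the ones flattened. It uses plain jax.numpy rather than any hypercoil function, and the identity helper is defined locally for illustration.

```python
import jax.numpy as jnp

# Two observations, each a 2 x 2 matrix.
X = jnp.array([
    [[3.0, 0.0], [0.0, 4.0]],
    [[6.0, 0.0], [0.0, 8.0]],
])

# Flatten the last two axes and take the L2 norm of each observation.
per_observation = jnp.sqrt(jnp.sum(X ** 2, axis=(-2, -1)))  # [5, 10]


def identity(x):
    # Outer scalarisation that leaves the per-observation values untouched.
    return x


as_vector = identity(per_observation)  # shape (2,): one value per observation
as_mean = jnp.mean(per_observation)    # scalar: (5 + 10) / 2 = 7.5
```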