MSELoss: Mean squared error

class hypercoil.loss.MSELoss(nu: float = 1.0, name: str | None = None, *, key: 'jax.random.PRNGKey' | None = None)

Mean squared error loss function.

This class is an example of how to compose elements to define a loss function. The score function is the difference between the input and the target, and the scalarisation function is the mean of the squared values.
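A minimal sketch of the composition described above, written in plain NumPy rather than with hypercoil's own building blocks. The helper names `difference_score`, `mean_of_squares` and `compose_loss` are hypothetical and are not part of the hypercoil API. The sketch shows the data flow: score, then scalarisation, then multiplication by `nu`.

```python
import numpy as np


def difference_score(Y, Y_hat):
    # Score function: elementwise difference between input and target.
    return Y - Y_hat


def mean_of_squares(x):
    # Scalarisation function: reduce the score array to a single scalar.
    return np.mean(x ** 2)


def compose_loss(score, scalarise, nu=1.0):
    # Hypothetical helper: chains score -> scalarisation -> nu scaling.
    def loss(Y, Y_hat):
        return nu * scalarise(score(Y, Y_hat))
    return loss


mse = compose_loss(difference_score, mean_of_squares, nu=1.0)
half_mse = compose_loss(difference_score, mean_of_squares, nu=0.5)

Y = np.array([1.0, 2.0, 3.0])
Y_hat = np.array([1.0, 0.0, 3.0])
print(mse(Y, Y_hat))       # mean of [0, 4, 0]
print(half_mse(Y, Y_hat))  # same value scaled by nu = 0.5
```

Because the score is squared, swapping the roles of input and target does not change the result.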

There are probably better implementations of the mean squared error loss out there.

Parameters:
name: str | None

Designated name of the loss function. It is not required, but specifying it is recommended so that the loss function can be identified by reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier. This scalar is applied to the loss value before it is returned. It can be used to modulate the relative contributions of different loss functions to the overall loss value. It can also be used to implement a schedule for the loss function by dynamically adjusting the multiplier over the course of training.
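The two uses of `nu` above (relative weighting and scheduling) can be sketched as follows. This is an illustration under stated assumptions, not hypercoil's scheduling mechanism: `linear_warmup` is a hypothetical schedule, and the second loss term (a mean-square penalty on `Y_hat`) is invented purely to show weighting between terms.

```python
import numpy as np


def mse(Y, Y_hat):
    # Mean squared error: mean of the squared difference.
    return np.mean((Y - Y_hat) ** 2)


def linear_warmup(step, warmup_steps=100, nu_max=1.0):
    # Hypothetical schedule: ramp nu linearly from 0 to nu_max, then hold.
    return nu_max * min(step / warmup_steps, 1.0)


def total_loss(Y, Y_hat, step):
    # Scheduled multiplier for the MSE term.
    nu_mse = linear_warmup(step)
    # Fixed multiplier for a hypothetical second term (illustration only).
    nu_reg = 0.01
    penalty = np.mean(Y_hat ** 2)
    return nu_mse * mse(Y, Y_hat) + nu_reg * penalty


Y = np.full(4, 2.0)
Y_hat = np.ones(4)
for step in (0, 50, 100, 200):
    print(step, total_loss(Y, Y_hat, step))
```

Early in training the MSE term contributes nothing and the penalty dominates; after the warm-up, the MSE term carries its full weight.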

Methods

__call__(Y, Y_hat, *[, key])

Call self as a function.
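In the signature `__call__(Y, Y_hat, *[, key])`, the bare `*` makes `key` keyword-only. The sketch below mirrors that calling convention with a hypothetical stand-in class, `MSELossSketch`, which is not the hypercoil implementation. It assumes that a deterministic mean squared error has no use for the random key and therefore simply ignores it.

```python
from typing import Any, Optional

import numpy as np


class MSELossSketch:
    """Hypothetical stand-in mirroring the documented call signature."""

    def __init__(self, nu: float = 1.0, name: Optional[str] = None):
        self.nu = nu
        self.name = name  # stored as given; no name inference here

    def __call__(self, Y, Y_hat, *, key: Optional[Any] = None):
        # `key` is keyword-only. Assumption: MSE is deterministic,
        # so the key is accepted but not used.
        return self.nu * np.mean((Y - Y_hat) ** 2)


loss = MSELossSketch(nu=2.0)
Y = np.array([0.0, 1.0])
Y_hat = np.array([1.0, 1.0])
print(loss(Y, Y_hat))            # Y and Y_hat passed positionally
print(loss(Y, Y_hat, key=None))  # key passed by keyword
try:
    loss(Y, Y_hat, None)         # a third positional argument is rejected
except TypeError as err:
    print("TypeError:", err)
```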