constraint_violation: Soft constraints#

constraint_violation#

hypercoil.loss.constraint_violation(X: Tensor, *, constraints: Sequence[Callable[[Tensor], Tensor]], broadcast_against_input: bool = False, key: PRNGKey | None = None) → Tensor#

Constraint violation score function.

This loss uses a set of constraint functions and evaluates them on its input. If a constraint evaluates to 0 or less, then it is considered to be satisfied and no penalty is applied. Otherwise, a score is returned in proportion to the maximum violation of any constraint.

For example, using the identity function lambda x: x as a constraint penalises only positive elements (equivalent to unilateral_loss()), while lambda x: -x penalises only negative elements. Similarly, lambda x: tensor([1, 3, 0, -2]) @ x - 2 applies the specified affine function as a constraint.
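
A minimal usage sketch of the constraints described above (the input values and jax.numpy usage are illustrative, not taken from the package's examples):

    import jax.numpy as jnp
    from hypercoil.loss import constraint_violation

    X = jnp.array([[-2.0, 1.0], [0.5, -0.5]])

    # Identity constraint: positive elements are violations and are penalised
    # in proportion to their magnitude; non-positive elements score 0.
    score_pos = constraint_violation(X, constraints=(lambda x: x,))

    # Negation constraint: only negative elements are penalised.
    score_neg = constraint_violation(X, constraints=(lambda x: -x,))

When both constraints are supplied together, each element is scored by its maximum violation across the two.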

Warning

Because of broadcasting rules, the results of constraint computations are not necessarily the same shape as the input, and the output of this function will take the shape of the largest constraint output. This can lead to unexpected relative scaling of different constraints, so the broadcast_against_input option is provided to broadcast all constraint outputs against the input shape. In the future, we might add an option that normalises each constraint violation by the number of elements in its output.
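
The hypothetical sketch below illustrates the shape caveat. The affine constraint contracts the leading axis of its input, so its output is smaller than X; setting broadcast_against_input=True re-expands the result to the input shape:

    import jax.numpy as jnp
    from hypercoil.loss import constraint_violation

    X = jnp.ones((4, 5))
    affine = lambda x: jnp.array([1.0, 3.0, 0.0, -2.0]) @ x - 2  # output shape (5,)

    # Follows the constraint output shape, per the warning above.
    out = constraint_violation(X, constraints=(affine,))

    # Broadcast back to the shape of X, i.e. (4, 5).
    out_b = constraint_violation(
        X, constraints=(affine,), broadcast_against_input=True
    )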

Parameters:
X: Tensor

Input tensor.

constraints: Sequence[Callable[[Tensor], Tensor]]

Iterable containing constraint functions.

broadcast_against_input: bool, optional (default: False)

If True, broadcast all constraint outputs against the input shape.

Returns:
Tensor

Maximum constraint violation score for each element.

ConstraintViolationLoss#

class hypercoil.loss.ConstraintViolationLoss(nu: float = 1.0, name: str | None = None, *, constraints: Sequence[Callable], broadcast_against_input: bool = False, scalarisation: Callable | None = None, key: 'jax.random.PRNGKey' | None = None)#

Loss function for constraint violations.

This loss uses a set of constraint functions and evaluates them on its input. If a constraint evaluates to 0 or less, then it is considered to be satisfied and no penalty is applied. Otherwise, a score is returned in proportion to the maximum violation of any constraint.

For example, using the identity function lambda x: x as a constraint penalises only positive elements (equivalent to unilateral_loss()), while lambda x: -x penalises only negative elements. Similarly, lambda x: tensor([1, 3, 0, -2]) @ x - 2 applies the specified affine function as a constraint.
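
A minimal sketch of the class-based interface (the input values and the name argument are illustrative):

    import jax.numpy as jnp
    from hypercoil.loss import ConstraintViolationLoss

    loss = ConstraintViolationLoss(
        nu=1.0,
        name='NonNegative',
        constraints=(lambda x: -x,),  # penalise negative elements
    )
    X = jnp.array([[-2.0, 1.0], [0.5, -0.5]])
    value = loss(X)  # scalarised, nu-weighted violation score

Unlike the functional form above, which returns an elementwise score tensor, the class scalarises the scores and applies the nu multiplier before returning.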

Warning

Because of broadcasting rules, the results of constraint computations are not necessarily the same shape as the input, and the output of this function will take the shape of the largest constraint output. This can lead to unexpected relative scaling of different constraints, so the broadcast_against_input option is provided to broadcast all constraint outputs against the input shape. In the future, we might add an option that normalises each constraint violation by the number of elements in its output.

Parameters:
name: str

Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier. This is a scalar multiplier that is applied to the loss value before it is returned. It can be used to modulate the relative contributions of different loss functions to the overall loss value, or to implement a schedule for the loss function by dynamically adjusting the multiplier over the course of training. A sketch following this parameter list illustrates weighting multiple loss terms.

constraints: Sequence[Callable[[Tensor], Tensor]]

Iterable containing constraint functions.

broadcast_against_input: bool, optional (default: False)

If True, broadcast all constraint outputs against the input shape.

scalarisation: Callable

The scalarisation function to be used to aggregate the values returned by the scoring function. This function should take a single argument, which is a tensor of arbitrary shape, and return a single scalar value. By default, the mean scalarisation is used.
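
The sketch below (with illustrative constraints and weights) demonstrates the two aggregation knobs documented above: nu sets each term's contribution to a composite loss, and scalarisation replaces the default mean reduction. Passing jnp.max assumes that any tensor-to-scalar callable satisfies the stated contract:

    import jax.numpy as jnp
    from hypercoil.loss import ConstraintViolationLoss

    # Two differently weighted terms; each call returns its nu-scaled scalar.
    nonneg = ConstraintViolationLoss(nu=1.0, constraints=(lambda x: -x,))
    bounded = ConstraintViolationLoss(nu=0.1, constraints=(lambda x: x - 1,))

    def total_loss(X):
        return nonneg(X) + bounded(X)

    # Custom scalarisation: jnp.max penalises only the single worst violation
    # rather than the mean over all elements.
    worst_case = ConstraintViolationLoss(
        constraints=(lambda x: x,),
        scalarisation=jnp.max,
    )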

Methods

__call__(X, *[, key])

Call self as a function.
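
The optional key argument mirrors the key parameter of the functional form. A hypothetical invocation, reusing loss and X from the sketch above:

    import jax
    value = loss(X, key=jax.random.PRNGKey(0))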