bregman_divergence: Bregman divergences#

bregman_divergence#

hypercoil.loss.bregman_divergence(X: Tensor, Y: Tensor, *, f: Callable, f_dim: int, key: PRNGKey | None = None) → Tensor[source]#

Bregman divergence score function.

This function computes the Bregman divergence between the input tensor and the target tensor, induced according to the convex function f. For a convex function f, the divergence is defined as D_f(X, Y) = f(X) - f(Y) - ⟨∇f(Y), X - Y⟩.

For a version of this function that operates on logits, see bregman_divergence_logit().

Parameters:
X: Tensor

Input tensor.

Y: Tensor

Target tensor.

f: Callable

Convex function to induce the Bregman divergence.

f_dim: int

Dimension of arguments to f.

Returns:
Tensor

Bregman divergence score for each set of observations.
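
Example (a minimal sketch, not from the library's own documentation: it assumes f is a convex scalar-valued function of each observation and that f_dim counts the trailing dimensions f consumes; verify both conventions against the source):

>>> import jax.numpy as jnp
>>> from hypercoil.loss import bregman_divergence
>>> # Negative entropy is convex on the probability simplex; the Bregman
>>> # divergence it induces is the (generalised) KL divergence.
>>> def neg_entropy(x):
...     return jnp.sum(x * jnp.log(x))
>>> X = jnp.asarray([[0.7, 0.2, 0.1], [0.3, 0.3, 0.4]])
>>> Y = jnp.asarray([[0.6, 0.3, 0.1], [0.25, 0.25, 0.5]])
>>> # One divergence score per row of observations.
>>> scores = bregman_divergence(X, Y, f=neg_entropy, f_dim=1)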

hypercoil.loss.bregman_divergence_logit(X: Tensor, Y: Tensor, *, f: Callable, f_dim: int, key: PRNGKey | None = None) → Tensor[source]#

Bregman divergence score function for logits.

This function computes the Bregman divergence between the input tensor and the target tensor, induced according to the convex function f.

This function operates on logits. For the standard version of this function, see bregman_divergence().

Parameters:
X: Tensor

Input tensor.

Y: Tensor

Target tensor.

f: Callable

Convex function to induce the Bregman divergence.

f_dim: int

Dimension of arguments to f.

Returns:
Tensor

Bregman divergence score for each set of observations.
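
Example (a sketch under the assumption that the function accepts unnormalised logits and maps them onto the probability simplex internally; the exact mapping is not specified here):

>>> import jax.numpy as jnp
>>> from hypercoil.loss import bregman_divergence_logit
>>> def neg_entropy(x):
...     return jnp.sum(x * jnp.log(x))
>>> # Unnormalised logits rather than probabilities.
>>> X_logit = jnp.asarray([[2.0, 0.5, -1.0]])
>>> Y_logit = jnp.asarray([[1.5, 1.0, -0.5]])
>>> scores = bregman_divergence_logit(X_logit, Y_logit, f=neg_entropy, f_dim=1)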

BregmanDivergenceLoss#

class hypercoil.loss.BregmanDivergenceLoss(nu: float = 1.0, name: str | None = None, *, f: Callable, f_dim: int, scalarisation: Callable | None = None, key: 'jax.random.PRNGKey' | None = None)[source]#

Loss based on the Bregman divergence between two categorical distributions.

This operates on unmapped tensors. For a version that operates on logits, see BregmanDivergenceLogitLoss.

This function computes the Bregman divergence between the input tensor and the target tensor, induced according to the convex function f.

Parameters:
name: str

Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier. This is a scalar multiplier that is applied to the loss value before it is returned. This can be used to modulate the relative contributions of different loss functions to the overall loss value. It can also be used to implement a schedule for the loss function, by dynamically adjusting the multiplier over the course of training.

f: Callable

Convex function to induce the Bregman divergence.

f_dim: int

Dimension of arguments to f.

The remaining arguments of the scoring function, the input tensor X and the target tensor Y, are supplied when the loss is applied; the score computed for each set of observations is their Bregman divergence.

scalarisation: Callable

The scalarisation function to be used to aggregate the values returned by the scoring function. This function should take a single argument, which is a tensor of arbitrary shape, and return a single scalar value. By default, the mean scalarisation is used.
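
Example (a sketch assuming the loss object is callable on an input-target pair and returns a single scalar after scalarisation and scaling by nu; the call convention is inferred, not taken from the library's documentation):

>>> import jax.numpy as jnp
>>> from hypercoil.loss import BregmanDivergenceLoss
>>> def neg_entropy(x):
...     return jnp.sum(x * jnp.log(x))
>>> loss = BregmanDivergenceLoss(nu=1.0, name='KLDiv', f=neg_entropy, f_dim=1)
>>> X = jnp.asarray([[0.7, 0.2, 0.1], [0.3, 0.3, 0.4]])
>>> Y = jnp.asarray([[0.6, 0.3, 0.1], [0.25, 0.25, 0.5]])
>>> # Mean scalarisation by default; nu rescales the aggregate.
>>> value = loss(X, Y)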

class hypercoil.loss.BregmanDivergenceLogitLoss(nu: float = 1.0, name: str | None = None, *, f: Callable, f_dim: int, scalarisation: Callable | None = None, key: 'jax.random.PRNGKey' | None = None)[source]#

Loss based on the Bregman divergence between two categorical distributions.

This operates on logits. For a version that operates on unmapped probabilities, see BregmanDivergenceLoss.

This function computes the Bregman divergence between the input tensor and the target tensor, induced according to the convex function f.

Parameters:
name: str

Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier. This is a scalar multiplier that is applied to the loss value before it is returned. This can be used to modulate the relative contributions of different loss functions to the overall loss value. It can also be used to implement a schedule for the loss function, by dynamically adjusting the multiplier over the course of training.

f: Callable

Convex function to induce the Bregman divergence.

f_dim: int

Dimension of arguments to f.

The remaining arguments of the scoring function, the input tensor X and the target tensor Y, are supplied as logits when the loss is applied; the score computed for each set of observations is their Bregman divergence.

scalarisation: Callable

The scalarisation function to be used to aggregate the values returned by the scoring function. This function should take a single argument, which is a tensor of arbitrary shape, and return a single scalar value. By default, the mean scalarisation is used.
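
Example (same assumed call convention as BregmanDivergenceLoss above, but with unnormalised logits as inputs):

>>> import jax.numpy as jnp
>>> from hypercoil.loss import BregmanDivergenceLogitLoss
>>> def neg_entropy(x):
...     return jnp.sum(x * jnp.log(x))
>>> loss = BregmanDivergenceLogitLoss(nu=1.0, f=neg_entropy, f_dim=1)
>>> X_logit = jnp.asarray([[2.0, 0.5, -1.0]])
>>> Y_logit = jnp.asarray([[1.5, 1.0, -0.5]])
>>> value = loss(X_logit, Y_logit)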