kl_divergence: Kullback-Leibler divergence#

kl_divergence#
- hypercoil.loss.kl_divergence(P: Tensor, Q: Tensor, *, axis: int | Sequence[int] = -1, keepdims: bool = True, reduce: bool = True, key: PRNGKey | None = None) Tensor [source]#
Kullback-Leibler divergence between two categorical distributions.
This function operates on probability tensors. For a version that operates on logits, see kl_divergence_logit(). Adapted from distrax.

Note

The KL divergence is not symmetric, so this function returns \(KL(P || Q)\). For a symmetric measure, see js_divergence().

\[KL(P || Q) = \sum_{x \in \mathcal{X}} P_x \log \frac{P_x}{Q_x}\]

- Parameters:
- P : Tensor
Input tensor parameterising the first categorical distribution.
- Q : Tensor
Input tensor parameterising the second categorical distribution.
- axis : int or sequence of ints, optional (default: -1)
Axis or axes over which to compute the KL divergence.
- keepdims : bool, optional (default: True)
As in jax.numpy.sum.
- reduce : bool, optional (default: True)
If False, the elementwise (unsummed) KL divergence is returned for each element of the input tensor. Otherwise, the KL divergence is summed over the specified axis or axes.
- Returns:
- Tensor
KL divergence between the two distributions.
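The formula above can be checked against a short reference implementation. The sketch below uses NumPy rather than JAX for portability, and makes the usual convention \(0 \log 0 = 0\) explicit; the function name is illustrative, not part of hypercoil's API:

```python
import numpy as np

def kl_divergence_ref(P, Q, axis=-1, keepdims=True, reduce=True):
    """Reference KL(P || Q) for categorical distributions,
    following the formula above."""
    P = np.asarray(P, dtype=float)
    Q = np.asarray(Q, dtype=float)
    # Guard zero-probability entries: 0 * log(0 / q) contributes 0.
    safe_P = np.where(P > 0, P, 1.0)
    elementwise = np.where(P > 0, P * np.log(safe_P / Q), 0.0)
    if not reduce:
        return elementwise
    return elementwise.sum(axis=axis, keepdims=keepdims)

P = np.array([0.5, 0.25, 0.25])
Q = np.array([0.1, 0.4, 0.5])
value = kl_divergence_ref(P, Q, keepdims=False)
```

With `reduce=False` the per-element terms \(P_x \log (P_x / Q_x)\) are returned unsummed, matching the description of the `reduce` parameter.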
- hypercoil.loss.kl_divergence_logit(P: Tensor, Q: Tensor, *, axis: int | Sequence[int] = -1, keepdims: bool = True, reduce: bool = True, key: PRNGKey | None = None) Tensor [source]#
Kullback-Leibler divergence between two categorical distributions.
This function operates on logits. For a version that operates on probabilities, see kl_divergence(). Adapted from distrax.

Note

The KL divergence is not symmetric, so this function returns \(KL(P || Q)\). For a symmetric measure, see js_divergence().

\[KL(P || Q) = \sum_{x \in \mathcal{X}} P_x \log \frac{P_x}{Q_x}\]

- Parameters:
- P : Tensor
Input tensor parameterising the first categorical distribution.
- Q : Tensor
Input tensor parameterising the second categorical distribution.
- axis : int or sequence of ints, optional (default: -1)
Axis or axes over which to compute the KL divergence.
- keepdims : bool, optional (default: True)
As in jax.numpy.sum.
- reduce : bool, optional (default: True)
If False, the elementwise (unsummed) KL divergence is returned for each element of the input tensor. Otherwise, the KL divergence is summed over the specified axis or axes.
- Returns:
- Tensor
KL divergence between the two distributions.
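Working in the log domain avoids forming an explicit softmax, which can underflow for very negative logits. The NumPy sketch below illustrates one standard way to compute the divergence directly from logits via log-softmax; it is an assumption about the technique, not a transcription of hypercoil's internals:

```python
import numpy as np

def log_softmax(logits, axis=-1):
    # Numerically stable log-softmax: shift by the max before exponentiating.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def kl_divergence_logit_ref(P_logits, Q_logits, axis=-1,
                            keepdims=True, reduce=True):
    """KL(P || Q) computed directly from logits.

    Equivalent to normalising both inputs with softmax and applying the
    probability-space formula, but computed via log-softmax for stability.
    """
    log_p = log_softmax(np.asarray(P_logits, dtype=float), axis=axis)
    log_q = log_softmax(np.asarray(Q_logits, dtype=float), axis=axis)
    elementwise = np.exp(log_p) * (log_p - log_q)
    if not reduce:
        return elementwise
    return elementwise.sum(axis=axis, keepdims=keepdims)
```

Applying this to raw logits gives the same result as normalising them with softmax and calling the probability-space formula.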
KLDivergenceLoss#
- class hypercoil.loss.KLDivergenceLoss(nu: float = 1.0, name: str | None = None, *, axis: int | Tuple[int, ...] = -1, keepdims: bool = False, reduce: bool = True, scalarisation: Callable | None = None, key: 'jax.random.PRNGKey' | None = None)[source]#
Loss based on the Kullback-Leibler divergence between two categorical distributions.
This loss operates on probability tensors. For a version that operates on logits, see KLDivergenceLogitLoss. Adapted from distrax.

Note

The KL divergence is not symmetric, so this loss returns \(KL(P || Q)\). For a symmetric measure, see js_divergence().

\[KL(P || Q) = \sum_{x \in \mathcal{X}} P_x \log \frac{P_x}{Q_x}\]

- Parameters:
- name : str
Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.
- nu : float
Loss strength multiplier. This scalar multiplier is applied to the loss value before it is returned. It can be used to modulate the relative contributions of different loss functions to the overall loss value, or to implement a schedule for the loss by dynamically adjusting the multiplier over the course of training.
- axis : int or sequence of ints, optional (default: -1)
Axis or axes over which to compute the KL divergence.
- keepdims : bool, optional (default: False)
As in jax.numpy.sum.
- reduce : bool, optional (default: True)
If False, the elementwise (unsummed) KL divergence is returned for each element of the input tensor. Otherwise, the KL divergence is summed over the specified axis or axes.
- scalarisation : Callable
The scalarisation function used to aggregate the values returned by the scoring function. This function should take a single argument, a tensor of arbitrary shape, and return a single scalar value. By default, the mean scalarisation is used.
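The interaction of `nu` and `scalarisation` can be illustrated with a minimal sketch of how such a loss composes its pieces: score each distribution pair with KL, aggregate with the scalarisation (mean by default), then scale by `nu`. The class below is a hypothetical analogue for illustration only, not hypercoil's actual implementation:

```python
import numpy as np

def kl_score(P, Q, axis=-1):
    # Zero-safe KL(P || Q), summed over the distribution axis.
    safe_P = np.where(P > 0, P, 1.0)
    return np.where(P > 0, P * np.log(safe_P / Q), 0.0).sum(axis=axis)

class KLDivergenceLossSketch:
    """Hypothetical analogue of KLDivergenceLoss (illustrative only)."""

    def __init__(self, nu=1.0, *, axis=-1, scalarisation=None):
        self.nu = nu                                   # loss strength multiplier
        self.axis = axis                               # distribution axis
        self.scalarisation = scalarisation or np.mean  # aggregator

    def __call__(self, P, Q):
        # nu * scalarisation(per-pair KL divergences)
        return self.nu * self.scalarisation(kl_score(P, Q, axis=self.axis))

# A batch of two distribution pairs; the loss collapses to one scalar.
P = np.array([[0.5, 0.5], [0.9, 0.1]])
Q = np.array([[0.4, 0.6], [0.8, 0.2]])
loss = KLDivergenceLossSketch(nu=0.5)
value = loss(P, Q)
```

Swapping `scalarisation` for, say, `np.sum` or a weighted mean changes only the aggregation step, which is the design rationale for exposing it as a parameter.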
- class hypercoil.loss.KLDivergenceLogitLoss(nu: float = 1.0, name: str | None = None, *, axis: int | Tuple[int, ...] = -1, keepdims: bool = False, reduce: bool = True, scalarisation: Callable | None = None, key: 'jax.random.PRNGKey' | None = None)[source]#
Loss based on the Kullback-Leibler divergence between two categorical distributions.
This loss operates on logit tensors. For a version that operates on probabilities, see KLDivergenceLoss. Adapted from distrax.

Note

The KL divergence is not symmetric, so this loss returns \(KL(P || Q)\). For a symmetric measure, see js_divergence().

\[KL(P || Q) = \sum_{x \in \mathcal{X}} P_x \log \frac{P_x}{Q_x}\]

- Parameters:
- name : str
Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.
- nu : float
Loss strength multiplier. This scalar multiplier is applied to the loss value before it is returned. It can be used to modulate the relative contributions of different loss functions to the overall loss value, or to implement a schedule for the loss by dynamically adjusting the multiplier over the course of training.
- axis : int or sequence of ints, optional (default: -1)
Axis or axes over which to compute the KL divergence.
- keepdims : bool, optional (default: False)
As in jax.numpy.sum.
- reduce : bool, optional (default: True)
If False, the elementwise (unsummed) KL divergence is returned for each element of the input tensor. Otherwise, the KL divergence is summed over the specified axis or axes.
- scalarisation : Callable
The scalarisation function used to aggregate the values returned by the scoring function. This function should take a single argument, a tensor of arbitrary shape, and return a single scalar value. By default, the mean scalarisation is used.
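The asymmetry flagged in the Note, and the symmetric alternative that js_divergence() provides, can be verified numerically. The `js` helper below follows the textbook Jensen-Shannon definition (KL of each argument against their mixture); it is stated here as an illustration of the standard definition, not as hypercoil's code:

```python
import numpy as np

def kl(P, Q):
    # KL(P || Q) for strictly positive categorical distributions.
    return float(np.sum(P * np.log(P / Q)))

def js(P, Q):
    # Jensen-Shannon divergence: symmetrised KL against the mixture M.
    M = 0.5 * (P + Q)
    return 0.5 * kl(P, M) + 0.5 * kl(Q, M)

P = np.array([0.5, 0.25, 0.25])
Q = np.array([0.1, 0.4, 0.5])
forward, backward = kl(P, Q), kl(Q, P)  # differ: KL is asymmetric
sym_ab, sym_ba = js(P, Q), js(Q, P)     # equal: JS is symmetric
```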