entropy: Categorical entropy#

entropy#

hypercoil.loss.entropy(X: Tensor, *, axis: int | Sequence[int] = -1, keepdims: bool = True, reduce: bool = True, key: PRNGKey | None = None) -> Tensor

Entropy of a categorical distribution or set of categorical distributions.

This function operates on probability tensors. For a version that operates on logits, see entropy_logit().

Parameters:
X: Tensor

Input tensor containing probabilities for each category.

axis: int or sequence of ints, optional (default: -1)

Axis or axes over which to compute the entropy.

keepdims: bool, optional (default: True)

As in jax.numpy.sum.

reduce: bool, optional (default: True)

If this is False, then the unsummed probability-weighted surprise is computed for each element of the input tensor. Otherwise, the entropy is computed over the specified axis or axes.

Returns:
Tensor

Entropy score for each set of observations.
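
A minimal sketch (not from the hypercoil documentation) of calling entropy() on a batch of categorical distributions stored as probabilities along the last axis:

    # One entropy value per row; keepdims=True retains the reduced axis.
    import jax.numpy as jnp

    from hypercoil.loss import entropy  # import path taken from this page

    probs = jnp.asarray([
        [0.25, 0.25, 0.25, 0.25],   # uniform distribution: entropy log(4)
        [0.97, 0.01, 0.01, 0.01],   # near-deterministic: entropy close to 0
    ])
    scores = entropy(probs, axis=-1)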

hypercoil.loss.entropy_logit(X: Tensor, *, temperature: float = 1.0, axis: int | Sequence[int] = -1, keepdims: bool = True, reduce: bool = True, key: PRNGKey | None = None) -> Tensor

Project logits in the input matrix onto the probability simplex, and then compute the entropy of the resulting categorical distribution.

This function operates on logit tensors. For a version that operates on probabilities, see entropy().

Parameters:
X: Tensor

Input tensor containing logits for each category.

axis: int or sequence of ints, optional (default: -1)

Axis or axes over which to compute the entropy.

keepdims: bool, optional (default: True)

As in jax.numpy.sum.

reduce: bool, optional (default: True)

If this is False, then the unsummed probability-weighted surprise is computed for each element of the input tensor. Otherwise, the entropy is computed over the specified axis or axes.

Returns:
Tensor

Entropy score for each set of observations.
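
A minimal sketch (not from the hypercoil documentation): entropy_logit() accepts unnormalised logits and projects them onto the probability simplex before scoring them:

    import jax.numpy as jnp

    from hypercoil.loss import entropy_logit  # import path taken from this page

    logits = jnp.asarray([
        [0.0, 0.0, 0.0, 0.0],   # projects to the uniform distribution (maximal entropy)
        [8.0, 0.0, 0.0, 0.0],   # projects to a near-deterministic distribution
    ])
    # `temperature` (default 1.0) presumably rescales the logits before projection.
    scores = entropy_logit(logits, axis=-1)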

EntropyLoss#

class hypercoil.loss.EntropyLoss(nu: float = 1.0, name: str | None = None, *, axis: int | Tuple[int, ...] = -1, keepdims: bool = False, reduce: bool = True, scalarisation: Callable | None = None, key: 'jax.random.PRNGKey' | None = None)

Loss based on the entropy of a categorical distribution.

This operates on probability tensors. For a version that operates on logits, see EntropyLogitLoss.

Entropy

The entropy of a categorical distribution \(A\) is defined as

\(-\mathbf{1}^\intercal \left(A \circ \log A\right) \mathbf{1}\)

(where \(\log\) denotes the elementwise logarithm).
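
For example, for a single distribution over \(k\) categories, the uniform vector \(\left(\frac{1}{k}, \ldots, \frac{1}{k}\right)\) attains the maximum entropy \(-\sum_{i=1}^{k} \frac{1}{k} \log \frac{1}{k} = \log k\), while any one-hot (deterministic) vector attains the minimum of 0 on the simplex.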

[Figure: entropysimplex.svg]

Cartoon schematic of the contours of an entropy-like function over categorical distributions. The function attains its maximum for the distribution in which all outcomes are equiprobable. The function can become smaller without bound away from this maximum. The superposed triangle represents the probability simplex. By pre-transforming the penalised weights to constrain them to the simplex, the entropy function is bounded and attains a separate minimum for each deterministic distribution.

Penalising the entropy promotes concentration of weight into a single category. This has applications in problem settings such as parcellation, when more deterministic parcel assignments are desired.

Warning

Entropy is a concave function. Minimising it without constraint affords an unbounded capacity for reducing the loss. This is almost certainly undesirable. For this reason, it is recommended that some constraint be imposed on the input set when placing a penalty on entropy. One possibility is using a probability simplex parameter mapper to first project the input weights onto the probability simplex.
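
As a rough sketch of this recommendation (using jax.nn.softmax as a stand-in for a simplex parameter mapper; hypercoil's own mapper may differ), the raw weights can be projected onto the simplex before the penalty is evaluated:

    # Hypothetical sketch: map unconstrained weights onto the probability simplex
    # before applying the entropy penalty, so the loss is bounded below by zero.
    import jax
    import jax.numpy as jnp

    from hypercoil.loss import entropy  # import path taken from this page

    raw_weights = jnp.asarray([[2.0, -1.0, 0.5],
                               [0.1,  0.1, 0.1]])
    probs = jax.nn.softmax(raw_weights, axis=-1)  # each row now lies on the simplex
    penalty = entropy(probs, axis=-1)             # bounded in [0, log(3)] per row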

Parameters:
name: str

Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier. This is a scalar multiplier that is applied to the loss value before it is returned. This can be used to modulate the relative contributions of different loss functions to the overall loss value. It can also be used to implement a schedule for the loss function, by dynamically adjusting the multiplier over the course of training.

axis: int or sequence of ints, optional (default: -1)

Axis or axes over which to compute the entropy.

keepdims: bool, optional (default: False)

As in jax.numpy.sum.

reduce: bool, optional (default: True)

If this is False, then the unsummed probability-weighted surprise is computed for each element of the input tensor. Otherwise, the entropy is computed over the specified axis or axes.

scalarisation: Callable

The scalarisation function to be used to aggregate the values returned by the scoring function. This function should take a single argument, which is a tensor of arbitrary shape, and return a single scalar value. By default, the mean scalarisation is used.
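
A hedged usage sketch follows. The constructor arguments are taken from the signature above; the assumption that the loss object can be called directly on a probability tensor is mine and may not match hypercoil's exact call convention:

    # Hypothetical usage sketch for EntropyLoss; the call convention is assumed.
    import jax.numpy as jnp

    from hypercoil.loss import EntropyLoss

    loss = EntropyLoss(nu=0.1, name='parcel_entropy', axis=-1)
    probs = jnp.full((10, 4), 0.25)   # ten uniform distributions over four categories
    value = loss(probs)               # scalar after scalarisation, scaled by nu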

class hypercoil.loss.EntropyLogitLoss(nu: float = 1.0, name: str | None = None, *, axis: int | Tuple[int, ...] = -1, keepdims: bool = False, reduce: bool = True, scalarisation: Callable | None = None, key: 'jax.random.PRNGKey' | None = None)

Loss based on the entropy of a categorical distribution.

This operates on logit tensors. For a version that operates on probabilities, see EntropyLoss.

Entropy

The entropy of a categorical distribution \(A\) is defined as

\(-\mathbf{1}^\intercal \left(A \circ \log A\right) \mathbf{1}\)

(where \(\log\) denotes the elementwise logarithm).

[Figure: entropysimplex.svg]

Cartoon schematic of the contours of an entropy-like function over categorical distributions. The function attains its maximum for the distribution in which all outcomes are equiprobable. The function can become smaller without bound away from this maximum. The superposed triangle represents the probability simplex. By pre-transforming the penalised weights to constrain them to the simplex, the entropy function is bounded and attains a separate minimum for each deterministic distribution.

Penalising the entropy promotes concentration of weight into a single category. This has applications in problem settings such as parcellation, when more deterministic parcel assignments are desired.

Warning

Entropy is a concave function. Minimising it without constraint affords an unbounded capacity for reducing the loss. This is almost certainly undesirable. For this reason, it is recommended that some constraint be imposed on the input set when placing a penalty on entropy. One possibility is using a probability simplex parameter mapper to first project the input weights onto the probability simplex.

Parameters:
name: str

Designated name of the loss function. It is not required that this be specified, but it is recommended to ensure that the loss function can be identified in the context of reporting utilities. If not explicitly specified, the name will be inferred from the class name and the name of the scoring function.

nu: float

Loss strength multiplier. This is a scalar multiplier that is applied to the loss value before it is returned. This can be used to modulate the relative contributions of different loss functions to the overall loss value. It can also be used to implement a schedule for the loss function, by dynamically adjusting the multiplier over the course of training.

axis: int or sequence of ints, optional (default: -1)

Axis or axes over which to compute the entropy.

keepdims: bool, optional (default: False)

As in jax.numpy.sum.

reduce: bool, optional (default: True)

If this is False, then the unsummed probability-weighted surprise is computed for each element of the input tensor. Otherwise, the entropy is computed over the specified axis or axes.

scalarisation: Callable

The scalarisation function to be used to aggregate the values returned by the scoring function. This function should take a single argument, which is a tensor of arbitrary shape, and return a single scalar value. By default, the mean scalarisation is used.
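
An analogous sketch for the logit version, under the same assumption about the call convention:

    # Hypothetical usage sketch for EntropyLogitLoss on raw, unnormalised logits.
    import jax.numpy as jnp

    from hypercoil.loss import EntropyLogitLoss

    loss = EntropyLogitLoss(nu=0.01)
    logits = jnp.zeros((10, 4))   # all-zero logits project to uniform distributions
    value = loss(logits)          # uniform rows give the maximal entropy penalty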