ProbabilitySimplexParameter

class hypercoil.init.mapparam.ProbabilitySimplexParameter(model: PyTree, *, where: Callable = <function where_weight>, handler: Callable = None, axis: int = -1, minimum: float = 0.001, smoothing: float = 0, temperature: Union[float, Literal['auto']] = 1.0)

Parameter whose constituent slices are projected onto the probability simplex.

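The forward (image) map is a softmax. Because the softmax is unchanged when the same constant is added to every entry of a slice, it has no unique inverse; here the elementwise natural logarithm serves as the ‘inverse’ (preimage) map. For a relatively well-behaved map, pair this parameter with Dirichlet initialisation, whose samples already lie on the probability simplex.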
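A minimal plain-Python sketch of the forward/inverse pair described above, assuming the softmax acts on a single 1D slice at temperature 1. The helper names softmax and log_inverse are illustrative and are not hypercoil functions. The sketch shows that the logarithm recovers a strictly positive probability vector, and that shifting every logit by the same constant yields the same image, which is why the inverse is not unique.

```python
import math


def softmax(x, temperature=1.0):
    # Illustrative softmax over one 1D slice (not the library implementation).
    scaled = [v / temperature for v in x]
    # Subtracting the max is a standard stability trick; the result is unchanged.
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def log_inverse(p):
    # Elementwise natural log, used as the 'inverse' of the softmax.
    return [math.log(v) for v in p]


p = [0.2, 0.3, 0.5]          # a strictly positive point on the simplex
x = log_inverse(p)           # preimage (logits)
recovered = softmax(x)       # image of the preimage: equals p up to rounding

# Non-uniqueness: adding a constant to every logit gives the same image.
shifted = softmax([v + 3.0 for v in x])
```

Any vector of the form log(p) + c maps back to p, so the logarithm is just one convenient choice of preimage.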

Parameters:
model : PyTree

The model to which the parameter belongs.

where : Callable

As in equinox.tree_at: a function that takes a model (or generally a PyTree) and returns the parameter tensor to be mapped. For example: where = lambda mlp: mlp.layers[-1].linear.weight. By default, the weight attribute of the model is retrieved.

handler : OutOfDomainHandler object (default Clip)

Object specifying a method for handling out-of-domain entries.

axis : int (default -1)

Axis of tensors in the domain along which 1D slices are mapped to the probability simplex.
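To illustrate axis, the plain-Python sketch below applies the softmax independently to each 1D slice of a 2D input. It maps rows onto the simplex for axis=-1 and columns for axis=0. The helper softmax_along_axis is hypothetical and stands in for the library's vectorised implementation.

```python
import math


def softmax(x):
    # Softmax of one 1D slice, stabilised by subtracting the max.
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_along_axis(matrix, axis=-1):
    # Hypothetical helper: `matrix` is a list of rows (a 2D input).
    # axis=-1 (or 1) maps each row onto the simplex; axis=0 (or -2) maps each column.
    if axis in (-1, 1):
        return [softmax(row) for row in matrix]
    if axis in (0, -2):
        cols = [softmax(list(col)) for col in zip(*matrix)]
        # Transpose back so the output has the same row/column layout as the input.
        return [list(row) for row in zip(*cols)]
    raise ValueError("axis must be 0, 1, -1 or -2 for a 2D input")


logits = [[0.0, 1.0, 2.0],
          [3.0, 0.0, -1.0]]
rows = softmax_along_axis(logits, axis=-1)  # each row sums to 1
cols = softmax_along_axis(logits, axis=0)   # each column sums to 1
```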

minimum : nonnegative float (default 1e-3)

Lower bound imposed on inputs to the elementwise natural logarithm (the preimage map). A positive value keeps zero entries from producing log(0) = -inf.
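The effect of minimum can be sketched as a floor applied before the logarithm, which is how the description above reads. The function name clamped_log is hypothetical, and the library may vectorise this differently. An exact zero, which would otherwise give -inf, is mapped to log(minimum).

```python
import math


def clamped_log(p, minimum=1e-3):
    # Hypothetical helper: raise entries below `minimum` to `minimum` before the log,
    # so an exact zero maps to log(minimum) instead of -inf.
    return [math.log(max(v, minimum)) for v in p]


p = [0.0, 0.25, 0.75]
logits = clamped_log(p)  # first entry is log(1e-3), the rest are log(0.25), log(0.75)
```

One consequence of the floor is that the softmax of these logits no longer reproduces the exact zero. Entries that were 0 come back as small positive probabilities.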

smoothing : nonnegative float (default 0)

Used when configuring the original parameter. Increasing the smoothing yields a higher-entropy, more nearly equiprobable image.
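The exact smoothing rule is not stated here. The sketch below therefore assumes a hypothetical additive rule: add smoothing to every entry, then renormalise. It illustrates only the documented behaviour that larger smoothing gives a higher-entropy, more nearly equiprobable distribution. The names additive_smooth and entropy are illustrative and are not hypercoil functions.

```python
import math


def additive_smooth(p, smoothing):
    # HYPOTHETICAL rule (assumption, not the library's): add `smoothing` to every
    # entry and renormalise. smoothing = 0 leaves p unchanged; a large value
    # approaches the uniform distribution.
    total = sum(p) + smoothing * len(p)
    return [(v + smoothing) / total for v in p]


def entropy(p):
    # Shannon entropy in nats; zero entries contribute nothing.
    return -sum(v * math.log(v) for v in p if v > 0)


p = [0.9, 0.05, 0.05]
levels = (0.0, 0.1, 1.0, 100.0)
entropies = [entropy(additive_smooth(p, s)) for s in levels]
# Entropy rises with smoothing and approaches log(3), the uniform maximum for 3 entries.
```

This rule is defined for any nonnegative smoothing value, which matches the parameter's stated range. It is still only one possible choice.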

temperature : nonnegative float or 'auto' (default 1)

Softmax temperature.
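A sketch of how a softmax temperature usually behaves, assuming the common convention that logits are divided by the temperature before exponentiation. This sketch does not cover how the 'auto' setting chooses a value. Lower temperatures sharpen the image toward the largest logit, and higher temperatures flatten it toward uniform. Under this convention the temperature must be strictly positive, since zero would divide by zero.

```python
import math


def softmax(x, temperature=1.0):
    # Assumed convention: divide logits by the temperature, then normalise.
    scaled = [v / temperature for v in x]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
cold = softmax(logits, temperature=0.5)   # more peaked than temperature 1
warm = softmax(logits, temperature=5.0)   # closer to uniform than temperature 1
```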

Methods

image_map_impl

preimage_map_impl