ProbabilitySimplexParameter
class hypercoil.init.mapparam.ProbabilitySimplexParameter(model: PyTree, *, where: Callable = <function where_weight>, handler: Callable = None, axis: int = -1, minimum: float = 0.001, smoothing: float = 0, temperature: Union[float, Literal['auto']] = 1.0)
Parameter whose constituent slices are projected onto the probability simplex.
The forward function is a softmax. Note that the softmax function does not have a unique inverse, because adding the same constant to every entry of a slice leaves its output unchanged; here we use the elementwise natural logarithm as an ‘inverse’. For a relatively well-behaved map, pair this with Dirichlet initialisation.
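The relationship between the softmax forward map and the logarithmic ‘inverse’ can be checked numerically. The sketch below uses plain NumPy rather than hypercoil itself. The softmax helper is written here for illustration and is not part of the library's API, and it assumes a temperature of 1. It draws Dirichlet samples (in the spirit of the suggested Dirichlet initialisation), maps them through the logarithm and back, and shows that a preimage is only determined up to a per-slice constant.

```python
import numpy as np

def softmax(x, axis=-1):
    # Illustrative helper (temperature 1). Subtracting the slice-wise maximum
    # improves numerical stability and does not change the result, because
    # softmax is invariant to constant shifts within a slice.
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

rng = np.random.default_rng(0)

# Dirichlet samples lie on the probability simplex: each row is positive and sums to 1.
p = rng.dirichlet(alpha=np.ones(4), size=3)  # shape (3, 4)

# Elementwise log as an 'inverse': applying softmax to log(p) recovers p.
logits = np.log(p)
p_roundtrip = softmax(logits)

# The inverse is not unique: shifting every entry of a row by the same
# constant leaves the softmax output unchanged.
shifted = logits + np.array([[1.0], [-2.5], [10.0]])
p_shifted = softmax(shifted)

# Conversely, log(softmax(x)) differs from x by a constant within each row
# (that constant is the row's log-sum-exp).
x = rng.normal(size=(3, 4))
offset = x - np.log(softmax(x))

print(np.allclose(p_roundtrip, p))          # True
print(np.allclose(p_shifted, p))            # True
print(np.allclose(offset, offset[:, :1]))   # True
```

The per-row offset in the last check is why the logarithm can only serve as an ‘inverse’ up to a constant shift of each slice.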
- Parameters:
- model : PyTree
The model to which the parameter belongs.
- where : Callable
As in equinox.tree_at: a function that takes a model (or, more generally, a PyTree) and returns the parameter tensor to be mapped. For example: where = lambda mlp: mlp.layers[-1].linear.weight. By default, the weight attribute of the model is retrieved.
- handler : OutOfDomainHandler object (default Clip)
Object specifying a method for handling out-of-domain entries.
- axis : int (default -1)
Axis of tensors in the domain along which 1D slices are mapped to the probability simplex.
- minimum : nonnegative float (default 1e-3)
Prescribed lower bound on inputs to the elementwise natural logarithm.
- smoothing : nonnegative float (default 0)
For use when configuring the original parameter. Increasing the smoothing will result in a higher-entropy (more equiprobable) image.
- temperature : nonnegative float or 'auto' (default 1)
Softmax temperature.
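A minimal sketch of how axis, minimum and temperature might enter the two maps. It assumes the conventional temperature-scaled softmax, softmax(x / temperature), and assumes the lower bound is applied by clipping before the logarithm. The functions image_map, preimage_map and entropy are hypothetical stand-ins written for this example, not hypercoil's image_map_impl and preimage_map_impl. The smoothing option and the 'auto' temperature are not modelled.

```python
import numpy as np

def image_map(x, axis=-1, temperature=1.0):
    # Hypothetical forward map: temperature-scaled softmax along `axis`.
    z = x / temperature
    z = z - np.max(z, axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

def preimage_map(p, minimum=1e-3):
    # Hypothetical 'inverse': clip entries from below, then take the log.
    # The log acts elementwise, so no axis argument is needed here.
    return np.log(np.maximum(p, minimum))

def entropy(p, axis=-1):
    # Shannon entropy (in nats) of each slice; assumes strictly positive entries.
    return -np.sum(p * np.log(p), axis=axis)

x = np.array([[2.0, 0.5, -1.0],
              [0.0, 3.0, 1.0]])

# axis=0 maps each column to the simplex; axis=-1 maps each row.
cols = image_map(x, axis=0)
rows = image_map(x, axis=-1)

# A higher temperature flattens each slice toward the uniform distribution.
cold = image_map(x, temperature=0.5)
hot = image_map(x, temperature=5.0)

# A zero entry would send the log to -inf; the lower bound keeps it finite.
p = np.array([0.0, 0.25, 0.75])
logits = preimage_map(p, minimum=1e-3)

print(np.allclose(cols.sum(axis=0), 1.0), np.allclose(rows.sum(axis=-1), 1.0))  # True True
print(np.all(entropy(hot) > entropy(cold)))  # True
print(np.isfinite(logits).all())             # True
```

Under these assumptions, clipping makes the round trip inexact for entries below the bound. Those entries are raised to minimum before the logarithm, so mapping back gives a slightly different distribution, here roughly [0.001, 0.2498, 0.7493] instead of [0, 0.25, 0.75].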
Methods
image_map_impl
preimage_map_impl