AmplitudeProbabilitySimplexParameter
class hypercoil.init.mapparam.AmplitudeProbabilitySimplexParameter(model: PyTree, *, where: Callable = <function where_weight>, handler: Callable = None, axis: int = -1, minimum: float = 0.001, smoothing: float = 0, temperature: Union[float, Literal['auto']] = 1.0)
Complex-valued parameter whose amplitudes are projected onto the probability simplex.
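The forward function applies a softmax to the amplitudes along axis and leaves the phases unchanged. Softmax has no unique inverse, because adding the same constant to every entry of a slice does not change its output. Here, the elementwise natural logarithm of the amplitudes (clamped from below at minimum) serves as an 'inverse'. At unit temperature this choice is a right inverse: for any slice p on the probability simplex whose entries are at least minimum, softmax(log p) = p.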
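The mapping described above can be sketched in plain NumPy as follows. This is a minimal illustration under stated assumptions, not the hypercoil implementation. The helper names (softmax, amplitude_simplex_forward, log_amplitude_inverse) are hypothetical, and the sketch assumes that temperature divides the amplitudes before the softmax, following the usual convention.

```python
import numpy as np

# Hypothetical helpers illustrating the amplitude -> simplex map;
# these are NOT the hypercoil API.


def softmax(x, axis=-1):
    # Softmax is invariant to adding a constant along `axis`, so
    # subtracting the max changes nothing but avoids overflow in exp.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)


def amplitude_simplex_forward(z, axis=-1, temperature=1.0):
    # Split into amplitude and phase, pass the amplitudes through a
    # temperature-scaled softmax (assumed convention), reattach the phases.
    amplitude, phase = np.abs(z), np.angle(z)
    return softmax(amplitude / temperature, axis=axis) * np.exp(1j * phase)


def log_amplitude_inverse(z, minimum=1e-3):
    # 'Inverse' of the amplitude map: clamp amplitudes from below at
    # `minimum` so the log stays finite, then take the elementwise log.
    # Log-amplitudes and phases are returned separately (see note below).
    amplitude, phase = np.abs(z), np.angle(z)
    return np.log(np.maximum(amplitude, minimum)), phase


rng = np.random.default_rng(0)
z = rng.normal(size=(3, 5)) + 1j * rng.normal(size=(3, 5))
w = amplitude_simplex_forward(z, axis=-1)
print(np.abs(w).sum(axis=-1))  # each row's amplitudes sum to 1, up to rounding
```

The sketch returns log-amplitudes and phases separately for a specific reason. A log-amplitude of a simplex entry is at most 0, so it cannot serve directly as the modulus of a complex number. The sketch therefore makes no assumption about how the library recombines the two parts.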
- Parameters:
- model (PyTree)
The model to which the parameter belongs.
- where (Callable)
As in equinox.tree_at: a function that takes a model (or, more generally, a PyTree) and returns the parameter tensor to be mapped. For example: where = lambda mlp: mlp.layers[-1].linear.weight. By default, the weight attribute of the model is retrieved.
- handler (OutOfDomainHandler object, default Clip)
Object specifying how out-of-domain entries are handled.
- axis (int, default -1)
Axis of tensors in the domain along which 1D slices are mapped onto the probability simplex.
- minimum (nonnegative float, default 1e-3)
Lower bound applied to amplitudes before the elementwise natural logarithm is taken, so that the logarithm stays finite for zero amplitudes.
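- smoothing (nonnegative float, default 0)
For use when configuring the original parameter. Increasing the smoothing results in a higher-entropy (more nearly equiprobable) image.
- temperature (nonnegative float or 'auto', default 1)
Softmax temperature. Amplitudes are divided by the temperature before the softmax, so larger values give a flatter, higher-entropy output.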
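The sketch below illustrates the effect of axis, temperature, and minimum in isolation, using plain NumPy on real-valued amplitudes. It is a hedged illustration, not the library's code. It assumes the standard convention that temperature divides the amplitudes before the softmax, and the helpers softmax and entropy are hypothetical names introduced here.

```python
import numpy as np

# Hypothetical helpers for illustration only; not the hypercoil API.


def softmax(x, axis=-1):
    # Shift by the max along `axis` for numerical stability.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)


def entropy(p, axis=-1):
    # Shannon entropy (in nats) of each 1D slice along `axis`.
    return -np.sum(p * np.log(p), axis=axis)


amplitudes = np.array([[0.2, 1.0, 3.0],
                       [0.5, 0.5, 2.5]])

# `axis` selects which 1D slices are normalized:
# each row (axis=-1) versus each column (axis=0).
rows = softmax(amplitudes, axis=-1)
cols = softmax(amplitudes, axis=0)

# A larger temperature divides the amplitudes by a larger number,
# pushing each slice toward the uniform (maximum-entropy) distribution.
entropies = [entropy(softmax(amplitudes / t, axis=-1)) for t in (0.5, 1.0, 4.0)]
for t, h in zip((0.5, 1.0, 4.0), entropies):
    print(t, h)

# `minimum` keeps the log finite when an amplitude is exactly zero.
zero_safe = np.log(np.maximum(np.array([0.0, 0.25, 0.75]), 1e-3))
print(zero_safe)
```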