AmplitudeProbabilitySimplexParameter

class hypercoil.init.mapparam.AmplitudeProbabilitySimplexParameter(model: PyTree, *, where: Callable = <function where_weight>, handler: Callable = None, axis: int = -1, minimum: float = 0.001, smoothing: float = 0, temperature: Union[float, Literal['auto']] = 1.0)

Complex-valued parameter whose amplitudes are projected onto the probability simplex.

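The forward map applies a softmax to the amplitudes (moduli) of the complex entries. Because the softmax is invariant to adding the same constant to every input along the mapped axis, it has no unique inverse; here the elementwise natural logarithm is used as an ‘inverse’. At unit temperature this is a right inverse: for any p with positive entries summing to 1, softmax(log p) = p.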
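A minimal NumPy sketch of the forward map and of why the logarithm works as an ‘inverse’. It assumes a single 1D slice, unit temperature, and that the phases of the complex entries are reattached unchanged after the softmax. The helper names `softmax` and `amplitude_simplex` are hypothetical and are not part of hypercoil's API; the round trip is checked on the real amplitudes only.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtracting the max leaves the result unchanged (softmax is
    # shift-invariant along `axis`) but avoids overflow in exp.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def amplitude_simplex(z: np.ndarray, axis: int = -1) -> np.ndarray:
    # Forward map (sketch): softmax over the amplitudes |z|, with the
    # phases of z reattached unchanged (an assumption of this sketch).
    return softmax(np.abs(z), axis=axis) * np.exp(1j * np.angle(z))

z = np.array([1 + 1j, -2.0, 0.5j])
w = amplitude_simplex(z)  # amplitudes of w are nonnegative and sum to 1

# No unique inverse: shifting all amplitudes by a constant
# produces exactly the same softmax output.
a = np.abs(z)
assert np.allclose(softmax(a), softmax(a + 3.0))

# The elementwise log is a right inverse on the simplex:
# softmax(log p) == p for any p with positive entries summing to 1.
p = softmax(a)
assert np.allclose(softmax(np.log(p)), p)
```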

Parameters:
model : PyTree

The model to which the parameter belongs.

where : Callable (default where_weight)

As in equinox.tree_at: a function that takes a model (or, more generally, a PyTree) and returns the parameter tensor to be mapped, for example where = lambda mlp: mlp.layers[-1].linear.weight. By default, the model's weight attribute is retrieved.

handler : OutOfDomainHandler object (default Clip)

Object specifying a method for handling out-of-domain entries.

axis : int (default -1)

Axis of the parameter tensor along which 1D slices are mapped onto the probability simplex, so that the amplitudes within each slice are nonnegative and sum to 1.
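A NumPy sketch of how `axis` selects the 1D slices, assuming (as described above) a softmax over amplitudes along that axis. It uses plain NumPy rather than hypercoil, and the `softmax` helper is a hypothetical stand-in.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax along `axis`.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
# A 3 x 4 complex parameter tensor.
z = rng.standard_normal((3, 4)) + 1j * rng.standard_normal((3, 4))
ampl = np.abs(z)

rows = softmax(ampl, axis=-1)  # each of the 3 rows lies on the simplex
cols = softmax(ampl, axis=0)   # each of the 4 columns lies on the simplex

print(rows.sum(axis=-1))  # three values, each equal to 1 up to rounding
print(cols.sum(axis=0))   # four values, each equal to 1 up to rounding
```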

minimum : nonnegative float (default 1e-3)

Prescribed lower bound on inputs to the elementwise natural logarithm. A positive value keeps the logarithm finite even where an amplitude is zero.
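A sketch of why the bound matters, assuming it is applied as an elementwise floor before the logarithm; the exact clamping used by hypercoil may differ, and `clamped_log` is a hypothetical helper.

```python
import numpy as np

def clamped_log(p: np.ndarray, minimum: float = 1e-3) -> np.ndarray:
    # Raise every entry to at least `minimum` before taking the log,
    # so zero amplitudes map to log(minimum) instead of -inf.
    return np.log(np.maximum(p, minimum))

p = np.array([0.0, 0.25, 0.75])
with np.errstate(divide="ignore"):
    print(np.log(p))     # first entry is -inf
print(clamped_log(p))    # first entry is log(1e-3), about -6.908
```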

smoothing : nonnegative float (default 0)

Used when configuring the original parameter. Increasing the smoothing yields a higher-entropy, more nearly equiprobable image.
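The text above does not specify how smoothing is computed. Purely as an assumption for illustration, the sketch below realizes it as a convex mixture with the uniform distribution (in the style of label smoothing) applied to a point on the simplex; since softmax(log q) = q at unit temperature, the smoothed q would be the resulting image. The helpers `smooth` and `entropy` are hypothetical.

```python
import numpy as np

def entropy(p: np.ndarray) -> float:
    # Shannon entropy in nats; zero entries contribute nothing.
    nz = p[p > 0]
    return float(-(nz * np.log(nz)).sum())

def smooth(p: np.ndarray, smoothing: float) -> np.ndarray:
    # Hypothetical smoothing: mix p with the uniform distribution.
    # smoothing = 0 leaves p unchanged; smoothing = 1 gives the uniform distribution.
    n = p.shape[-1]
    return (1.0 - smoothing) * p + smoothing / n

p = np.array([0.7, 0.2, 0.1])
for s in (0.0, 0.3, 0.9):
    q = smooth(p, s)
    # Entropy grows as the mixture moves toward the uniform distribution.
    print(s, q.round(3), round(entropy(q), 3))
```

Because entropy is concave and maximized by the uniform distribution, it is nondecreasing along this mixture path, which matches the qualitative behavior described for the smoothing parameter.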

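temperature : nonnegative float or 'auto' (default 1)

Softmax temperature. Higher temperatures yield flatter (higher-entropy) distributions over each slice; lower temperatures concentrate the mass on the largest amplitudes.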
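A NumPy sketch of the effect of temperature, assuming the standard convention of dividing the amplitudes by the temperature before the softmax. The 'auto' option is not modeled here, and the `softmax` helper is hypothetical.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1, temperature: float = 1.0) -> np.ndarray:
    # Divide by the temperature, then apply a numerically stable softmax.
    y = x / temperature
    e = np.exp(y - y.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

ampl = np.array([1.0, 2.0, 3.0])
for t in (0.1, 1.0, 10.0):
    # Low t: nearly all mass on the largest amplitude.
    # High t: close to equiprobable.
    print(t, softmax(ampl, temperature=t).round(3))
```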