AmplitudeTanhMappedParameter

class hypercoil.init.mapparam.AmplitudeTanhMappedParameter(model: PyTree, *, where: Callable = <function where_weight>, preimage_bound: Tuple[float, float] = (-3.0, 3.0), handler: Callable = None, scale: float = 1.0)

Complex-valued parameter whose amplitude is mapped through a hyperbolic tangent function. The amplitude is thus constrained to lie between -scale and scale, where scale is a finite value.
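A minimal pure-Python sketch of the amplitude mapping, assuming the image map is scale * tanh(amplitude) and the preimage map is its inverse, atanh(amplitude / scale), with the phase carried through unchanged. The helpers image_map and preimage_map are hypothetical names used only for illustration; they are not the hypercoil API, which operates on JAX arrays.

```python
import cmath
import math


def image_map(z: complex, scale: float = 1.0) -> complex:
    """Hypothetical image map: amplitude -> scale * tanh(amplitude), phase kept."""
    amplitude, phase = cmath.polar(z)
    return cmath.rect(scale * math.tanh(amplitude), phase)


def preimage_map(w: complex, scale: float = 1.0) -> complex:
    """Hypothetical preimage map: amplitude -> atanh(amplitude / scale), phase kept."""
    amplitude, phase = cmath.polar(w)
    return cmath.rect(math.atanh(amplitude / scale), phase)


scale = 2.0
z = complex(3.0, 4.0)                 # unconstrained value with amplitude 5
w = image_map(z, scale)               # amplitude becomes 2 * tanh(5)
print(abs(w) < scale)                 # True: amplitude stays below scale
print(cmath.isclose(preimage_map(w, scale), z))  # True: round trip recovers z
```

Because only the amplitude passes through tanh, a large unconstrained value is squashed toward the bound while its direction in the complex plane is preserved.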

Parameters:
model : PyTree

The model to which the parameter belongs.

where : Callable

As in equinox.tree_at: a function that takes a model (or, more generally, a PyTree) and returns the parameter tensor to be mapped. For example, where = lambda mlp: mlp.layers[-1].linear.weight selects the weight of the last layer's linear module. By default (where_weight), the model's weight attribute is retrieved.
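The where convention can be illustrated without equinox by a hypothetical model built from frozen dataclasses. The classes Linear, Layer and MLP below are assumptions made for this sketch; where_weight mirrors the default named in the signature, and dataclasses.replace stands in for the out-of-place update that equinox.tree_at performs for an arbitrary selector.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class Linear:          # hypothetical module holding a weight
    weight: Tuple[complex, ...]


@dataclass(frozen=True)
class Layer:           # hypothetical layer wrapping a linear module
    linear: Linear


@dataclass(frozen=True)
class MLP:             # hypothetical model with a top-level weight
    weight: Tuple[complex, ...]
    layers: Tuple[Layer, ...]


def where_weight(model):
    # Default selector: return the model's top-level `weight` attribute.
    return model.weight


# A custom selector, as in the docstring example.
where_last = lambda mlp: mlp.layers[-1].linear.weight

mlp = MLP(
    weight=(1 + 1j, 2j),
    layers=(Layer(Linear((0.5 + 0j,))), Layer(Linear((3 - 1j, 1j)))),
)
print(where_weight(mlp))
print(where_last(mlp))

# Out-of-place update of the leaf picked by where_weight (tree_at does this
# generically for any selector); the original model is left unchanged.
updated = replace(mlp, weight=(0j, 0j))
print(where_weight(updated), where_weight(mlp))
```

The selector only reads the tree; the mapped parameter is then written back out of place, which is why immutable PyTrees work with this convention.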

preimage_bound : (float min, float max) (default (-3, 3))

Minimum and maximum prescribed values for the preimage map. Note that these are not necessarily the same as the minimum and maximum (or infimum and supremum) attainable by the map itself. For example, it might be desirable to further truncate values to a range where the magnitude of the gradient is non-negligible.
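A small sketch of why the default bound of (-3, 3) is reasonable, assuming the image map is tanh with scale 1 and that the bound is enforced by clamping preimage values (clamp_preimage is a hypothetical helper, not the library's mechanism). The derivative of tanh(x) is 1 - tanh(x)**2, so it shrinks by roughly two orders of magnitude between x = 0 and x = 3.

```python
import math


def tanh_grad(x: float, scale: float = 1.0) -> float:
    """Derivative of scale * tanh(x) with respect to x."""
    return scale * (1.0 - math.tanh(x) ** 2)


LO, HI = -3.0, 3.0  # the default preimage_bound


def clamp_preimage(x: float, bound=(LO, HI)) -> float:
    """Hypothetical truncation of a preimage value to [bound[0], bound[1]]."""
    return min(max(x, bound[0]), bound[1])


for x in (0.0, 1.0, HI, 6.0):
    print(f"x={x:4.1f}  image={math.tanh(x):.5f}  grad={tanh_grad(x):.2e}")
```

Beyond the bound the image barely moves while the gradient keeps collapsing, so truncating the preimage there gives up very little range in exchange for usable gradients.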

handler : OutOfDomainHandler object (default Clip)

Object specifying a method for handling out-of-domain entries.
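The need for a handler shows up in the preimage map: assuming the preimage of an amplitude is atanh(amplitude / scale), any amplitude at or above scale has no finite preimage. The clip_handler below is a hypothetical stand-in for the default Clip handler; its exact behaviour (for instance, the margin eps) is an assumption of this sketch.

```python
import math


def clip_handler(amplitude: float, scale: float = 1.0, eps: float = 1e-6) -> float:
    """Hypothetical Clip-style handler: pull an amplitude into [0, scale - eps]."""
    return min(max(amplitude, 0.0), scale - eps)


def preimage_amplitude(amplitude: float, scale: float = 1.0) -> float:
    """Preimage of an amplitude, clipped first so that atanh stays finite."""
    return math.atanh(clip_handler(amplitude, scale) / scale)


try:
    math.atanh(1.2 / 1.0)       # amplitude above scale: no real preimage
except ValueError:
    print("out of domain without a handler")

print(preimage_amplitude(1.2))  # finite once the amplitude is clipped
print(preimage_amplitude(0.5))  # in-domain amplitudes are not altered by the clip
```

Clipping trades exactness for robustness: out-of-domain entries are mapped to a large but finite preimage instead of raising an error or producing infinities.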

scale : float (default 1)

Supremum/infimum (scale and -scale) of the hyperbolic tangent map. These bounds are approached but, mathematically, never attained.