MappedLogits
- class hypercoil.init.mapparam.MappedLogits(model: PyTree, *, where: Callable = <function where_weight>, preimage_bound: Tuple[float, float] = (-4.5, 4.5), handler: Callable = None, loc: Optional[float] = None, scale: float = 1.0)
Parameter mapped through a logistic function. The tensor values are thus constrained to lie between 0 and 1 or, more generally, between two arbitrary finite real numbers.
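The page does not spell out the map itself. One reading consistent with the loc and scale descriptions below (zero maps to loc; scale is the width of the image interval) is an affine rescaling of the standard logistic function. This is an assumption, not the library's documented formula; here \ell stands for loc and s for scale.

```latex
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad
f(x) = \ell + s\left(\sigma(x) - \tfrac{1}{2}\right), \qquad
f(0) = \ell

f^{-1}(y) = \log\frac{p}{1 - p}, \qquad
p = \frac{y - \ell}{s} + \tfrac{1}{2}

\text{image of } f: \left(\ell - \tfrac{s}{2},\; \ell + \tfrac{s}{2}\right)
```

Under this assumption, choosing \ell = s/2 when loc is None would give the image (0, s), which for the default s = 1 matches the (0, 1) range stated above.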
- Parameters:
- model : PyTree
The model to which the parameter belongs.
- where : Callable
As in equinox.tree_at: a function that takes a model (or, more generally, a PyTree) and returns the parameter tensor to be mapped. For example: where = lambda mlp: mlp.layers[-1].linear.weight. By default, the weight attribute of the model is retrieved.
- preimage_bound : (float min, float max) (default (-4.5, 4.5))
Minimum and maximum prescribed values for the preimage map. Note that these are not necessarily the same as the minimum and maximum (or infimum and supremum) attainable by the map itself. For example, it might be desirable to further truncate values to a range where the magnitude of the gradient is non-negligible.
- handler : OutOfDomainHandler object (default Clip)
Object specifying a method for handling out-of-domain entries.
- loc : float (default None)
Location parameter for the logistic map. Zero is mapped to this value under the logistic map.
- scale : float (default 1)
Width of the interval onto which the logistic map sends its inputs (the image of the map).
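The preimage_bound note says it may be worth truncating preimage values to a range where the gradient is non-negligible. The sketch below illustrates why: the derivative of the standard logistic function is 0.25 at zero but only about 0.011 at 4.5 and far smaller beyond that. The helper bounded_preimage is hypothetical; it assumes the bound is enforced by clipping the logit of an image-space value, which is one plausible reading rather than the library's confirmed behavior.

```python
import math

def sigmoid(x: float) -> float:
    # Standard logistic function 1 / (1 + e^{-x}).
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x: float) -> float:
    # Derivative of the logistic function: sigma(x) * (1 - sigma(x)).
    s = sigmoid(x)
    return s * (1.0 - s)

def logit(p: float) -> float:
    # Inverse of the logistic function, defined for 0 < p < 1.
    return math.log(p / (1.0 - p))

def bounded_preimage(p: float, bound=(-4.5, 4.5)) -> float:
    # Hypothetical helper: clip the preimage (logit) of p to preimage_bound.
    lo, hi = bound
    return min(max(logit(p), lo), hi)

# Gradient magnitude shrinks rapidly as the preimage moves away from zero.
for x in (0.0, 3.0, 4.5, 10.0):
    print(f"x = {x:5.1f}  sigmoid'(x) = {sigmoid_grad(x):.2e}")

# A value very close to 1 has a large logit (about 6.9), which gets clipped.
print(bounded_preimage(0.999))
```

Clipping to (-4.5, 4.5) keeps every mapped entry where the gradient is still above roughly 0.01, at the cost of restricting the unit-scale image to about (0.011, 0.989).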
Methods
image_map_impl
preimage_map_impl
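The two methods listed above presumably implement the forward (image) and inverse (preimage) directions of the map. The class below is a hypothetical, scalar-only stand-in that mirrors those method names. It assumes image = loc + scale * (sigmoid(x) - 0.5), with loc defaulting to scale / 2. It omits preimage_bound and the out-of-domain handler, and it does not reflect hypercoil's actual implementation, which operates on JAX arrays.

```python
import math
from dataclasses import dataclass
from typing import Optional

@dataclass
class LogisticMapSketch:
    """Hypothetical forward/inverse pair illustrating a scaled logistic map.

    Assumption: image = loc + scale * (sigmoid(x) - 0.5), so x = 0 maps to
    loc and the image is the open interval (loc - scale/2, loc + scale/2).
    """
    scale: float = 1.0
    loc: Optional[float] = None

    def _loc(self) -> float:
        # Assumed default: centre the image at scale / 2, giving (0, scale).
        return self.scale / 2 if self.loc is None else self.loc

    def image_map_impl(self, x: float) -> float:
        # Squash with the logistic function, then shift and scale.
        s = 1.0 / (1.0 + math.exp(-x))
        return self._loc() + self.scale * (s - 0.5)

    def preimage_map_impl(self, y: float) -> float:
        # Undo the shift and scale, then apply the logit.
        # y must lie strictly inside the image; no out-of-domain handling here.
        p = (y - self._loc()) / self.scale + 0.5
        return math.log(p / (1.0 - p))

m = LogisticMapSketch()
print(m.image_map_impl(0.0))                          # zero maps to loc
print(m.image_map_impl(m.preimage_map_impl(0.3)))     # round trip

shifted = LogisticMapSketch(scale=4.0, loc=-1.0)      # image (-3, 1)
print(shifted.image_map_impl(0.0))
```

Keeping the two directions as separate methods lets an initial value be pulled back into unconstrained preimage space once, after which optimisation happens on the preimage and the constrained value is recovered with the forward map.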