DomainMappedParameter
class hypercoil.init.mapparam.DomainMappedParameter(model: PyTree, *, where: Callable = <function where_weight>, image_bound: Any = None, preimage_bound: Any = None, handler: Callable = None)
A parameter that is mapped between different domains.
This extends MappedParameter in that it admits, for both the forward (image) and backward (preimage) maps, (i) a set of bounds on the domain and (ii) a mechanism for handling out-of-domain values.
- Parameters:
- model : PyTree
The model to which the parameter belongs.
- where : Callable
As in equinox.tree_at: a function that takes a model (or, more generally, a PyTree) and returns the parameter tensor to be mapped, for example where = lambda mlp: mlp.layers[-1].linear.weight. By default, the weight attribute of the model is retrieved.
- image_bound : Tuple[float, float]
Minimum and maximum prescribed values for the image map. Note that these are not necessarily the same as the minimum and maximum (or infimum and supremum) attainable by the map itself. For example, it might be desirable to further truncate some maps to a range where the magnitude of the gradient is non-negligible. However, in general, the prescribed bounds should not admit any values outside the domain of the map.
- preimage_bound : Tuple[float, float]
Minimum and maximum prescribed values for the preimage map.
- handler : OutOfDomainHandler object (default: Clip)
The handler to use for imputing out-of-domain values.
- Attributes:
- handler
- image_bound
- preimage_bound
Methods
- handle_ood(param): Apply an out-of-domain handler to ensure that all tensor entries are within bounds.
- image_map(param): Map a tensor to its image under the transformation.
- preimage_map(param): Map a tensor to its preimage under the transformation.
- test(param): Evaluate whether each entry in a tensor falls within the prescribed bounds of the image map.
- image_map_impl
- preimage_map_impl
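The method list suggests a division of labour: image_map_impl and preimage_map_impl hold the raw transformations, while image_map, preimage_map, test, and handle_ood wrap them with bound checking. The sketch below is a hypothetical, library-independent stand-in built on tanh and arctanh. It assumes that the handler is applied to the input of the preimage map (where arctanh diverges at ±1); hypercoil's actual ordering and subclass names may differ.

```python
import jax.numpy as jnp


class TanhDomainSketch:
    """Hypothetical stand-in for a DomainMappedParameter-style object.

    Image map: tanh, taking the real line into (-1, 1).
    Preimage map: arctanh, finite only on the open interval (-1, 1).
    """

    def __init__(self, image_bound=(-0.99, 0.99)):
        # Illustrative bound, deliberately tighter than the supremum of tanh.
        self.image_bound = image_bound

    def image_map_impl(self, param):
        return jnp.tanh(param)

    def preimage_map_impl(self, param):
        return jnp.arctanh(param)

    def test(self, param):
        # Elementwise check against the prescribed image bound.
        lo, hi = self.image_bound
        return (param >= lo) & (param <= hi)

    def handle_ood(self, param):
        # Clip-style handling: move out-of-bound entries onto the bound.
        lo, hi = self.image_bound
        return jnp.clip(param, lo, hi)

    def image_map(self, param):
        return self.image_map_impl(param)

    def preimage_map(self, param):
        # Assumption: out-of-domain entries are handled before arctanh,
        # which would otherwise return inf at +/-1 and nan beyond.
        return self.preimage_map_impl(self.handle_ood(param))


p = TanhDomainSketch()
x = jnp.array([-2.0, 0.0, 0.5, 1.0])
in_bounds = p.test(x)         # [False, True, True, False]
pre = p.preimage_map(x)       # every entry is finite
roundtrip = p.image_map(pre)  # approximately [-0.99, 0.0, 0.5, 0.99]
```

Choosing 0.99 rather than 1 echoes the note on image_bound: because d tanh(x)/dx = 1 - tanh(x)^2, the derivative of the image map where its output reaches the bound is 1 - 0.99^2 ≈ 0.02 instead of vanishing.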