MappedParameter
- class hypercoil.init.mapparam.MappedParameter(model: PyTree, *, where: Callable = <function where_weight>)
A transformed version of a parameter tensor.
A MappedParameter wraps and replaces a standard parameter in a model. Subclasses can implement image (forward transformation) and preimage (often right inverse transformation) maps to transform the parameter tensor. At instantiation, the original parameter is mapped under the preimage map. It can thereafter be accessed in the original field of the MappedParameter. Accessing the parameter instead accesses the transformation of the original parameter under the image map.
Note
Rather than first instantiating a new MappedParameter and then creating an updated model, it is also possible to directly create an updated model that immediately contains the MappedParameter using the map class method.
- Parameters:
- model: PyTree
The model to which the parameter belongs.
- where: Callable
As in equinox.tree_at: a function that takes a model (or generally a PyTree) and returns the parameter tensor to be mapped. For example: where = lambda mlp: mlp.layers[-1].linear.weight. By default, the weight attribute of the model is retrieved.
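The selector pattern described for where can be sketched in plain Python. This is only an illustration: the Linear, Layer, and MLP classes and the where_weight_sketch helper are hypothetical stand-ins written for this example (real models would typically be Equinox modules), and the helper merely mimics the documented default of retrieving the weight attribute.

```python
from dataclasses import dataclass
from typing import List


# Hypothetical stand-in classes; they are not part of hypercoil or Equinox.
@dataclass
class Linear:
    weight: List[List[float]]


@dataclass
class Layer:
    linear: Linear


@dataclass
class MLP:
    layers: List[Layer]


def where_weight_sketch(model):
    # Assumed to mirror the documented default: retrieve the model's weight.
    return model.weight


mlp = MLP(
    layers=[
        Layer(Linear([[1.0, 0.0], [0.0, 1.0]])),
        Layer(Linear([[0.5, -0.5]])),
    ]
)

# A selector takes the model and returns the parameter to be mapped.
where = lambda m: m.layers[-1].linear.weight

selected = where(mlp)
default_selected = where_weight_sketch(mlp.layers[0].linear)
```

Here selected is the weight of the last layer, while default_selected is what the default-style selector returns when applied directly to a module that has a weight attribute.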
Methods
map(model, *pparams[, where])
Create an updated version of a model that contains the mapped parameter.
image_map
preimage_map
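The relationship between the two maps can be illustrated with a small standard-library sketch. It does not use hypercoil: PositiveParameterSketch is a hypothetical class, the exp/log pair is an assumed choice of image and preimage map, and scalars stand in for tensors. It shows the idea described above: the preimage of the initial value is stored in original, and reading the parameter applies the image map.

```python
import math


class PositiveParameterSketch:
    """Hypothetical illustration only; not hypercoil's MappedParameter."""

    def __init__(self, value: float):
        # At instantiation, the initial value is mapped under the preimage map.
        self.original = self.preimage_map(value)

    @staticmethod
    def preimage_map(value: float) -> float:
        # Right inverse of image_map on the positive reals.
        return math.log(value)

    @staticmethod
    def image_map(original: float) -> float:
        # Forward transformation: any real number maps to a positive value.
        return math.exp(original)

    @property
    def value(self) -> float:
        # Accessing the parameter returns the image of the stored original.
        return self.image_map(self.original)


param = PositiveParameterSketch(2.0)
initial_value = param.value

# An unconstrained update to the stored original still yields a positive value.
param.original -= 5.0
updated_value = param.value
```

Because updates act on the unconstrained original, the visible parameter in this sketch stays in the range of the image map regardless of the step taken.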