MappedParameter

class hypercoil.init.mapparam.MappedParameter(model: PyTree, *, where: Callable = <function where_weight>)

A transformed version of a parameter tensor.

A MappedParameter wraps and replaces a standard parameter in a model. Subclasses can implement image (forward transformation) and preimage (often a right inverse of the image) maps to transform the parameter tensor. At instantiation, the original parameter is mapped under the preimage map, and the result is stored in the original field of the MappedParameter. Accessing the parameter itself instead returns the transformation of that stored value under the image map.
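The relationship between the two maps can be illustrated with a toy example in plain Python. This is a minimal sketch of the concept only, not hypercoil's implementation: the class ToyPositiveParameter, its weight property, and the choice of exp/log as the image/preimage pair are all assumptions made for illustration.

```python
import math


class ToyPositiveParameter:
    """Hypothetical stand-in for a mapped parameter; not hypercoil's class."""

    def __init__(self, values):
        # At instantiation, store the preimage of the supplied values.
        self.original = [self.preimage_map(v) for v in values]

    @staticmethod
    def preimage_map(x):
        # Right inverse of image_map on the positive reals.
        return math.log(x)

    @staticmethod
    def image_map(x):
        # Forward transformation: sends any real number to a positive one.
        return math.exp(x)

    @property
    def weight(self):
        # Accessing the parameter returns the image of the stored preimage.
        return [self.image_map(v) for v in self.original]


param = ToyPositiveParameter([0.5, 1.0, 2.0])
stored = param.original    # unconstrained values (logs of the inputs)
recovered = param.weight   # approximately the inputs that were supplied
```

Because the stored value is unconstrained, an optimiser could update it freely while the value seen through the image map stays positive.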

Note

Rather than first instantiating a new MappedParameter and then creating an updated model, it is also possible to directly create an updated model that immediately contains the MappedParameter using the map class method.
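The two routes described in this note can be contrasted with a toy example. This is a sketch under stated assumptions: ToyModel, ToyMapped, the scale argument, and the image method are hypothetical, and the toy map class method only handles a top-level weight field, whereas the real method also accepts a where selector.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class ToyModel:
    weight: Any  # hypothetical model with a single parameter field


class ToyMapped:
    """Hypothetical wrapper whose preimage divides by a scale factor."""

    def __init__(self, model, scale=2.0):
        self.scale = scale
        # Preimage: divide the model's weight by the scale.
        self.original = [w / scale for w in model.weight]

    def image(self):
        # Image: multiply the stored values by the scale.
        return [w * self.scale for w in self.original]

    @classmethod
    def map(cls, model, *pparams):
        # Build the wrapper, then return a copy of the model that holds it.
        return replace(model, weight=cls(model, *pparams))


model = ToyModel(weight=[2.0, 4.0])

# Two-step route: instantiate first, then build the updated model.
two_step = replace(model, weight=ToyMapped(model, 2.0))

# One-step route via the class method.
one_step = ToyMapped.map(model, 2.0)
```

In both routes the input model is left unchanged and a new model is returned, which matches the functional update style used for PyTrees.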

Parameters:
model : PyTree

The model to which the parameter belongs.

where : Callable

As in equinox.tree_at: a function that takes a model (or generally a PyTree) and returns the parameter tensor to be mapped. For example: where = lambda mlp: mlp.layers[-1].linear.weight. By default, the weight attribute of the model is retrieved.
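A selector of this kind is just a function from a model to one of its leaves. The sketch below uses plain dataclasses as a hypothetical stand-in for a model; the classes Linear, Layer, and MLP and the helper select_weight are assumptions for illustration, and select_weight only mimics the documented default behaviour rather than reproducing the library's where_weight.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Linear:
    weight: List[float]


@dataclass
class Layer:
    linear: Linear


@dataclass
class MLP:
    layers: List[Layer]


mlp = MLP(layers=[Layer(Linear([1.0, 2.0])), Layer(Linear([3.0, 4.0]))])

# Selector in the style described above: picks the last layer's weight.
where = lambda mlp: mlp.layers[-1].linear.weight

# Analogue of the documented default: retrieve the weight attribute.
select_weight = lambda model: model.weight

selected = where(mlp)
default_selected = select_weight(mlp.layers[0].linear)
```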

Methods

map(model, *pparams[, where])

Create an updated version of a model that contains the mapped parameter.

image_map

preimage_map