OrthogonalParameter
- class hypercoil.init.mapparam.OrthogonalParameter(model: PyTree, *, where: Callable = <function where_weight>)
Parameter whose constituent slices are orthogonal vectors.
Currently, this is implemented in a crude manner using a QR decomposition.
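The QR-based idea can be illustrated with a minimal JAX sketch. This is not hypercoil's implementation. The helper name orthogonalise_rows is hypothetical, and it assumes that the "slices" are the rows of a 2-D weight with no more rows than columns. Note that QR yields orthonormal (unit-length) rows, which is stricter than orthogonal.

```python
import jax
import jax.numpy as jnp


def orthogonalise_rows(weight):
    """Return a matrix with orthonormal rows spanning the rows of `weight`.

    Hypothetical helper for illustration only; assumes `weight` is 2-D
    with shape (n_rows, n_cols), n_rows <= n_cols, and full row rank.
    """
    # Reduced QR of the transpose: the columns of q are orthonormal,
    # so the rows of q.T are orthonormal.
    q, _ = jnp.linalg.qr(weight.T)
    return q.T


key = jax.random.PRNGKey(0)
weight = jax.random.normal(key, (3, 5))

ortho = orthogonalise_rows(weight)
# The Gram matrix of the rows should be close to the 3 x 3 identity.
gram = ortho @ ortho.T
```

Taking the QR decomposition of the transpose is one way to treat rows as the slices; if the slices lie along a different axis, the transposes would change accordingly.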
- Parameters:
- model : PyTree
The model to which the parameter belongs.
- where : Callable
As in equinox.tree_at: a function that takes a model (or generally a PyTree) and returns the parameter tensor to be mapped. For example: where = lambda mlp: mlp.layers[-1].linear.weight. By default, the weight attribute of the model is retrieved.
Methods
image_map
preimage_map