OrthogonalParameter

class hypercoil.init.mapparam.OrthogonalParameter(model: PyTree, *, where: Callable = <function where_weight>)

Parameter whose constituent slices are orthogonal vectors.

Currently, this is implemented in a crude manner using a QR decomposition.
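The QR approach can be illustrated with a short sketch. This is a minimal NumPy example, not the library's implementation (hypercoil itself is built on JAX). It assumes that the orthogonal "slices" are the columns of a tall 2-D matrix. The helper name `qr_orthogonalise` is hypothetical, and the sign correction on the diagonal of R is an added convention that makes the map deterministic for full-rank inputs.

```python
import numpy as np


def qr_orthogonalise(w: np.ndarray) -> np.ndarray:
    """Map a tall matrix (n >= k) to one with orthonormal columns via reduced QR.

    Hypothetical helper; assumes the "slices" are the columns of a 2-D array.
    """
    q, r = np.linalg.qr(w)  # reduced mode: q is (n, k), r is (k, k)
    # Flip column signs so that diag(r) is non-negative. This makes the result
    # unique for full-rank w; q's first column is then w's first column, normalised.
    signs = np.sign(np.diag(r))
    signs[signs == 0] = 1.0
    return q * signs


rng = np.random.default_rng(0)
w = rng.standard_normal((5, 3))
q = qr_orthogonalise(w)
print(np.allclose(q.T @ q, np.eye(3)))  # True
```

Because the output already has orthonormal columns, applying the map a second time returns it unchanged (up to floating-point error).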

Parameters:
model : PyTree

The model to which the parameter belongs.

where : Callable

As in equinox.tree_at: a function that takes a model (or, more generally, a PyTree) and returns the parameter tensor to be mapped; for example, where = lambda mlp: mlp.layers[-1].linear.weight. By default (where_weight), the model's weight attribute is retrieved.
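The `where` argument is simply a selector function. The sketch below illustrates what such a selector returns, using toy frozen dataclasses (`Linear`, `Block`, `MLP`, all hypothetical stand-ins for Equinox modules) so that only NumPy is needed. It also shows the kind of out-of-place replacement that equinox.tree_at performs at the selected location, written by hand here as the hypothetical helper `replace_last_weight`. It is not how hypercoil performs the mapping internally.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class Linear:  # hypothetical layer holding a weight tensor
    weight: np.ndarray


@dataclass(frozen=True)
class Block:  # hypothetical wrapper, so that `layers[-1].linear` exists
    linear: Linear


@dataclass(frozen=True)
class MLP:  # hypothetical model with a tuple of blocks
    layers: tuple


def where_weight(model):
    # Mirrors the documented default: retrieve the model's `weight` attribute.
    return model.weight


def where_last(mlp):
    # The selector from the parameter description above.
    return mlp.layers[-1].linear.weight


def replace_last_weight(mlp, new_weight):
    # Hand-written equivalent of equinox.tree_at(where_last, mlp, replace=new_weight)
    # for these toy classes: return a copy with only the selected leaf swapped.
    last = mlp.layers[-1]
    new_last = replace(last, linear=replace(last.linear, weight=new_weight))
    return replace(mlp, layers=mlp.layers[:-1] + (new_last,))


mlp = MLP(layers=(Block(Linear(np.ones((4, 2)))), Block(Linear(np.zeros((3, 3))))))
selected = where_last(mlp)
print(selected.shape)  # (3, 3)

updated = replace_last_weight(mlp, np.eye(3))
print(np.array_equal(where_last(updated), np.eye(3)))  # True
print(np.array_equal(where_last(mlp), np.zeros((3, 3))))  # True: original untouched
```

Keeping the selector separate from the replacement lets one parameter class target any tensor in a nested model without knowing the model's structure in advance.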

Methods

image_map

preimage_map