NormSphereParameter

class hypercoil.init.mapparam.NormSphereParameter(model: PyTree, *, where: Callable = <function where_weight>, handler: Callable = None, loc: float = 0.0, scale: float = 1.0, norm: float = 2.0, axis: Union[int, Tuple[int, ...]] = -1)

Parameter whose constituent vectors are projected onto a sphere of some fixed norm.

Note that this only works for proper (convex) norms. It does not work for the L0 "norm", for example, because the L0 value of a vector is unchanged when the vector is rescaled.

Warning

The normalisation procedure is a simple division operation. Thus, it will not work for many matrix norms.
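The division-based normalisation can be sketched in a few lines of JAX. This is an illustrative sketch, not the library's implementation: the helper name `project_to_norm_sphere` is hypothetical, and a sphere centred at the origin is assumed. It relies on the norm being absolutely homogeneous, so that dividing a vector by its norm yields a vector of norm 1.

```python
import jax.numpy as jnp


def project_to_norm_sphere(x, scale=1.0, order=2, axis=-1):
    """Hypothetical helper: rescale slices of x so their norm equals `scale`."""
    # keepdims=True lets the norm broadcast against x during the division.
    n = jnp.linalg.norm(x, ord=order, axis=axis, keepdims=True)
    return scale * x / n


x = jnp.array([[3.0, 4.0], [1.0, -1.0]])

# Euclidean norm: each row ends up with L2 norm 1.
y2 = project_to_norm_sphere(x, order=2)
# L1 norm: each row ends up with L1 norm 1.
y1 = project_to_norm_sphere(x, order=1)

# The L0 "norm" counts nonzero entries and is unchanged by rescaling,
# so division does not place the result on an L0 sphere.
y0 = project_to_norm_sphere(x, order=0)
l0_after = jnp.linalg.norm(y0, ord=0, axis=-1)  # still 2 per row, not 1
```

The last three lines show why the L0 case fails: both rows still have two nonzero entries after division.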

Parameters:
model : PyTree

The model to which the parameter belongs.

where : Callable

As in equinox.tree_at: a function that takes a model (or generally a PyTree) and returns the parameter tensor to be mapped. For example: where = lambda mlp: mlp.layers[-1].linear.weight. By default, the weight attribute of the model is retrieved.

handler : OutOfDomainHandler object (default Clip)

Object specifying a method for handling out-of-domain entries.

loc : float or tensor

Location of the centre of the norm sphere. By default, the sphere is centred at the origin (loc=0).

scale : float (default 1)

Scale of the norm sphere. By default, the unit norm sphere (scale=1) is used.

norm : int, str, or tensor

Norm order argument passed to jnp.linalg.norm. If this is a tensor of symmetric positive semidefinite matrices, then a Mahalanobis-like ellipse norm is computed using those matrices. (No matrix inversion is performed.) By default, the Euclidean norm is used.
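As a rough illustration of the matrix-valued case, the sketch below assumes that the ellipse norm of a vector x under a symmetric positive semidefinite matrix A is sqrt(x^T A x), with A used directly rather than inverted. Both that formula and the helper name `ellipse_norm` are assumptions made for illustration; the exact expression used by the library is not stated here.

```python
import jax.numpy as jnp


def ellipse_norm(x, A):
    """Hypothetical helper: sqrt(x^T A x) for each vector along the last axis of x."""
    # einsum contracts x with A on both sides; A is used as given, with no inversion.
    quad = jnp.einsum('...i,ij,...j->...', x, A, x)
    return jnp.sqrt(quad)


A = jnp.array([[4.0, 0.0], [0.0, 1.0]])  # symmetric positive definite
x = jnp.array([[1.0, 0.0], [0.0, 3.0]])

n = ellipse_norm(x, A)   # per-row norms: [2., 3.]
y = x / n[..., None]     # division-based normalisation onto the unit ellipse
```

After the division, each row of `y` has ellipse norm 1 under `A`, so the same division step used for vector p-norms carries over to this case.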

axis : int or tuple(int)

Axis or axis tuple over which the norm is computed. Every slice along the specified axis or axis tuple of the mapped tensor is rescaled so that its norm is equal to scale in the specified norm.
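The effect of axis, scale, and loc can be illustrated with a small JAX sketch. The helper `to_sphere` is hypothetical, and the way loc enters (added after rescaling, so that the result lies on a sphere centred at loc) is an assumption based on the parameter descriptions above, not a statement of the library's implementation.

```python
import jax.numpy as jnp


def to_sphere(x, loc=0.0, scale=1.0, axis=-1):
    """Hypothetical helper: map slices of x onto the sphere ||y - loc|| = scale."""
    # keepdims=True lets the norm broadcast against x during the division.
    n = jnp.linalg.norm(x, axis=axis, keepdims=True)
    return loc + scale * x / n


x = jnp.arange(1.0, 7.0).reshape(2, 3)

rows = to_sphere(x, scale=2.0, axis=-1)   # each row has L2 norm 2
cols = to_sphere(x, scale=2.0, axis=0)    # each column has L2 norm 2
# With a tuple of two axes and no explicit order,
# jnp.linalg.norm computes the Frobenius norm.
whole = to_sphere(x, axis=(0, 1))         # the whole matrix has Frobenius norm 1
shifted = to_sphere(x, loc=1.0, axis=-1)  # rows lie on the unit sphere centred at 1
```

Choosing a single axis normalises each vector along that axis independently, whereas an axis tuple treats each slice spanning those axes as a single object to be rescaled.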

Methods

image_map_impl

preimage_map_impl