NormSphereParameter
class hypercoil.init.mapparam.NormSphereParameter(model: PyTree, *, where: Callable = <function where_weight>, handler: Callable = None, loc: float = 0.0, scale: float = 1.0, norm: float = 2.0, axis: Union[int, Tuple[int, ...]] = -1)
Parameter whose constituent vectors are projected onto a sphere of some fixed norm.
Note that this works only for proper convex norms. For example, it will not work for the L0 "norm", which is not a true norm.
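The projection can be sketched in a few lines of JAX. This is a minimal sketch, assuming the map shifts each slice by loc, divides it by its own norm along axis, rescales it by scale, and shifts it back, mirroring the loc, scale, norm and axis arguments in the signature. The ord argument plays the role of norm. project_to_norm_sphere is a hypothetical helper, not hypercoil's image_map_impl.

```python
import jax.numpy as jnp


def project_to_norm_sphere(w, loc=0.0, scale=1.0, ord=2, axis=-1):
    # Hypothetical helper, not part of hypercoil: shift each slice by loc,
    # divide it by its own norm over `axis`, rescale to radius `scale`,
    # and shift back, so that norm(slice - loc) == scale.
    centred = w - loc
    # keepdims=True lets the per-slice norm broadcast back over `axis`.
    n = jnp.linalg.norm(centred, ord=ord, axis=axis, keepdims=True)
    return loc + scale * centred / n


# Row vectors (axis=-1) projected onto the L2 sphere of radius 2.
w = jnp.array([[3.0, 4.0], [1.0, 0.0], [-2.0, 2.0]])
p = project_to_norm_sphere(w, scale=2.0)
print(jnp.linalg.norm(p, axis=-1))  # approximately 2 for every row

# An axis tuple normalises whole matrices, here in the Frobenius norm.
mats = jnp.arange(18.0).reshape(2, 3, 3)
q = project_to_norm_sphere(mats, ord="fro", axis=(-2, -1))
print(jnp.linalg.norm(q, ord="fro", axis=(-2, -1)))  # approximately 1 per matrix
```

A slice exactly equal to loc has zero norm, so this sketch would return NaN for it. How the library itself handles that case is not covered here.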
Warning
The normalisation is a simple division by the computed norm, so it will not work for many matrix norms.
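Because the normalisation is a division by the computed norm, it can place a slice on the sphere only when the norm is absolutely homogeneous, i.e. ||c x|| = |c| ||x||, so that ||x / ||x|| || = 1. The short check below is plain jax.numpy, not hypercoil code, and l0_count is a hypothetical helper. It shows the property holding for the L1, L2 and max norms and failing for the L0 count mentioned in the note, which rescaling leaves unchanged.

```python
import jax.numpy as jnp


def l0_count(x):
    # Hypothetical helper: the "L0 norm", i.e. the number of nonzero entries.
    # It is not a true norm because scaling x by c != 0 leaves it unchanged.
    return jnp.sum(x != 0)


x = jnp.array([3.0, 0.0, -4.0])

# L1, L2 and max norms are absolutely homogeneous: dividing by the norm
# gives a vector of norm 1, so a plain division reaches the unit sphere.
for p in (1, 2, jnp.inf):
    n = jnp.linalg.norm(x, ord=p)
    print(p, jnp.linalg.norm(x / n, ord=p))  # approximately 1.0 each time

# The L0 count ignores scaling, so the same division cannot fix it.
print(l0_count(x), l0_count(x / l0_count(x)))  # 2 2
```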
- Parameters:
  - model : PyTree
    The model to which the parameter belongs.
  - where : Callable
    As in equinox.tree_at: a function that takes a model (or, more generally, a PyTree) and returns the parameter tensor to be mapped. For example: where = lambda mlp: mlp.layers[-1].linear.weight. By default, the weight attribute of the model is retrieved.
  - handler : OutOfDomainHandler object (default Clip)
    Object specifying a method for handling out-of-domain entries.
  - loc : float or tensor
    Location of the centre of the norm sphere. By default, the sphere is centred at the origin (loc=0).
  - scale : float (default 1)
    Scale (radius) of the norm sphere. By default, the unit norm sphere (scale=1) is used.
  - norm : int, str, or tensor
    Norm order argument passed to jnp.linalg.norm. If this is a tensor of symmetric positive semidefinite matrices, then a Mahalanobis-like ellipse norm is computed using those matrices. (No matrix inversion is performed.) By default, the Euclidean norm is used.
  - axis : int or tuple(int)
    Axis or axis tuple over which the norm is computed. Every slice of the mapped tensor along the specified axis or axis tuple is rescaled so that its norm, measured from loc, equals scale in the specified norm.
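For a tensor-valued norm argument, one plausible reading of the "Mahalanobis-like ellipse norm" is sqrt(x^T M x), with M used directly rather than inverted. The sketch below assumes that definition and a vector axis of -1. ellipse_norm and project_to_ellipse are hypothetical helpers, not hypercoil functions.

```python
import jax.numpy as jnp


def ellipse_norm(x, M):
    # Hypothetical helper: sqrt(x^T M x) over the last axis, with M used as
    # given (no inversion). A trailing axis is kept so the result broadcasts.
    quad = jnp.sum((x @ M) * x, axis=-1, keepdims=True)
    return jnp.sqrt(quad)


def project_to_ellipse(w, M, loc=0.0, scale=1.0):
    # Same shift / divide / rescale pattern, with the ellipse norm in place
    # of jnp.linalg.norm.
    centred = w - loc
    return loc + scale * centred / ellipse_norm(centred, M)


M = jnp.array([[4.0, 0.0], [0.0, 1.0]])  # symmetric positive definite
w = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
p = project_to_ellipse(w, M)
print(ellipse_norm(p, M)[..., 0])  # approximately 1 for every row
print(p[0])  # first row becomes (0.5, 0), since 4 * 0.5**2 == 1
```

Because M weights the first coordinate by 4, slices are shrunk more along that direction. If M is only positive semidefinite, a nonzero slice lying in its null space has zero ellipse norm, and this sketch would divide by zero.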
Methods
image_map_impl
preimage_map_impl