identity_init

hypercoil.init.base.identity_init(*, shape: Tuple[int], scale: float = 1, shift: float = 0, key: PRNGKey | None = None) → Tensor

Initialise a tensor such that each of its slices is an identity matrix. Currently, a slice is the matrix spanned by the last two axes, so every such matrix is set to the identity. If there is a use case for slicing along other axes, this can be made more flexible in the future.
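The sketch below illustrates the described behaviour: one identity matrix over the last two axes, repeated across all leading axes. It is a hypothetical re-implementation, not hypercoil's source. It is written with NumPy for portability; `jax.numpy.eye` and `jax.numpy.broadcast_to` exist with the same usage. The name `identity_init_sketch` is made up for illustration. How `scale` and `shift` are applied is an assumption: here every slice becomes `scale * I + shift`. `key` is accepted but ignored, because an identity initialisation needs no randomness.

```python
from typing import Optional, Tuple

import numpy as np


def identity_init_sketch(
    *,
    shape: Tuple[int, ...],
    scale: float = 1.0,
    shift: float = 0.0,
    key: Optional[object] = None,  # unused: the result is deterministic
) -> np.ndarray:
    """Hypothetical sketch of identity_init (not the library's code).

    ASSUMPTION: each slice is computed as scale * I + shift, i.e. the
    diagonal becomes scale + shift and off-diagonal entries become shift.
    """
    if len(shape) < 2:
        raise ValueError("shape must have at least two axes")
    rows, cols = shape[-2], shape[-1]
    # One slice over the last two axes; np.eye puts ones on the main
    # diagonal (a true identity when rows == cols).
    eye = np.eye(rows, cols) * scale + shift
    # Repeat that slice over every leading axis; copy so the result is
    # writable and does not alias the single slice.
    return np.broadcast_to(eye, tuple(shape)).copy()


# Usage: a batch of two 3 x 3 identity matrices.
w = identity_init_sketch(shape=(2, 3, 3))
print(w.shape)  # (2, 3, 3)
```

Broadcasting a single slice, rather than looping over the leading axes, keeps the construction independent of how many batch axes `shape` has.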