dirichlet
Initialise a tensor such that elements along a given axis are Dirichlet samples.
dirichlet_init
- hypercoil.init.dirichlet.dirichlet_init(*, shape: Tuple[int], distr: Distribution, axis: int = -1, key: PRNGKey) -> Tensor
Dirichlet sample initialisation.
Initialise a tensor such that any 1D slice through that tensor along a given axis is a sample from a specified Dirichlet distribution. Each 1D slice can therefore be understood as encoding a categorical probability distribution.
- Parameters:
- shape : tuple
Shape of the tensor to initialise.
- distr : instance of torch.distributions.Dirichlet
Parametrised Dirichlet distribution from which all 1D slices of the initialised tensor along the specified axis are sampled.
- axis : int (default -1)
Axis along which slices are sampled from the specified Dirichlet distribution.
- key : jax.random.PRNGKey
Pseudo-random number generator key for sampling the Dirichlet distribution.
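The behaviour described above can be illustrated with a minimal sketch in plain JAX. This is not hypercoil's implementation: the helper name `dirichlet_along_axis` and its `concentration` argument are assumptions made for illustration, and only `jax.random.dirichlet` and `jnp.moveaxis` are used. It draws one Dirichlet sample per 1D slice and moves the sampled dimension onto the requested axis.

```python
import jax
import jax.numpy as jnp


def dirichlet_along_axis(key, shape, concentration, axis=-1):
    """Hypothetical helper (not part of hypercoil): fill a tensor of `shape`
    so that every 1D slice along `axis` is a Dirichlet sample."""
    concentration = jnp.asarray(concentration, dtype=jnp.float32)
    axis = axis % len(shape)  # normalise negative axes
    if shape[axis] != concentration.shape[0]:
        raise ValueError(
            "shape[axis] must equal the number of concentration parameters"
        )
    # Batch shape: every dimension except the one being sampled.
    batch_shape = tuple(s for i, s in enumerate(shape) if i != axis)
    # jax.random.dirichlet puts the sampled dimension last.
    samples = jax.random.dirichlet(key, concentration, shape=batch_shape)
    # Move the sampled dimension to the requested axis.
    return jnp.moveaxis(samples, -1, axis)


key = jax.random.PRNGKey(0)
x = dirichlet_along_axis(
    key, shape=(4, 3, 5), concentration=[1.0, 1.0, 1.0], axis=1
)
```

Under these assumptions, `x` has shape `(4, 3, 5)`, and each slice `x[i, :, j]` is non-negative and sums to 1 up to floating-point error, so it can be read as a categorical probability distribution over three classes.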
DirichletInitialiser
- class hypercoil.init.dirichlet.DirichletInitialiser(concentration: Sequence[float], num_classes: int | None = None, axis: int = -1, mapper: Type[MappedParameter] | None = None)
Initialise a parameter such that all slices along the specified axis (the final axis by default) are samples from a specified Dirichlet distribution.
See dirichlet_init() and MappedInitialiser for argument details.
Methods
- init(model, *[, mapper, num_classes, axis, ...])
Initialise a parameter using the specified initialiser and mapper.