dirichlet

Initialise a tensor such that elements along a given axis are Dirichlet samples.

dirichlet_init

hypercoil.init.dirichlet.dirichlet_init(*, shape: Tuple[int], distr: Distribution, axis: int = -1, key: PRNGKey) -> Tensor

Dirichlet sample initialisation.

Initialise a tensor such that any 1D slice through that tensor along a given axis is a sample from a specified Dirichlet distribution. Each 1D slice can therefore be understood as encoding a categorical probability distribution.

Parameters:
shape : tuple

Shape of the tensor to initialise.

distr : instance of torch.distributions.Dirichlet

Parametrised Dirichlet distribution from which all 1D slices of the initialised tensor along the specified axis are sampled.

axis : int (default -1)

Axis along which slices are sampled from the specified Dirichlet distribution.

key : jax.random.PRNGKey

Pseudo-random number generator key for sampling the Dirichlet distribution.
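
The behaviour described above can be sketched in plain JAX. This is a minimal illustration, not hypercoil's implementation. The function name `dirichlet_init_sketch` and its `concentration` argument are hypothetical, and the sketch calls `jax.random.dirichlet` directly instead of accepting a distribution object as `distr`.

```python
import jax
import jax.numpy as jnp


def dirichlet_init_sketch(*, shape, concentration, axis=-1, key):
    """Hypothetical stand-in: Dirichlet samples along `axis` of a tensor."""
    concentration = jnp.asarray(concentration, dtype=jnp.float32)
    axis = axis % len(shape)  # normalise negative axes
    if shape[axis] != concentration.shape[0]:
        raise ValueError(
            "Size of the sampled axis must equal the number of "
            "concentration parameters."
        )
    # Every axis other than the sampled one acts as a batch dimension.
    batch_shape = tuple(s for i, s in enumerate(shape) if i != axis)
    # jax.random.dirichlet places the sampled (event) axis last.
    samples = jax.random.dirichlet(key, concentration, shape=batch_shape)
    # Move the event axis to the requested position.
    return jnp.moveaxis(samples, -1, axis)


key = jax.random.PRNGKey(0)
x = dirichlet_init_sketch(
    shape=(2, 5, 3),
    concentration=[1.0] * 5,
    axis=1,
    key=key,
)
# Each 1D slice along axis 1 is non-negative and sums to 1.
slice_sums = x.sum(axis=1)
```

Because every slice along the chosen axis lies on the probability simplex, each slice can be read as a categorical distribution over the 5 entries of that axis.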

DirichletInitialiser

class hypercoil.init.dirichlet.DirichletInitialiser(concentration: Sequence[float], num_classes: int | None = None, axis: int = -1, mapper: Type[MappedParameter] | None = None)

Initialise a parameter such that all slices along a given axis (the final axis by default) are samples from a specified Dirichlet distribution.

See dirichlet_init() and MappedInitialiser for argument details.
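
As a rough guide to choosing `concentration`, the sketch below compares two symmetric Dirichlet distributions using `jax.random.dirichlet` directly. It does not call hypercoil. The variable names and the values 0.5 and 50.0 are illustrative assumptions, and how `DirichletInitialiser` combines `concentration` with `num_classes` is not shown here.

```python
import jax
import jax.numpy as jnp

num_classes = 4
key_sparse, key_dense = jax.random.split(jax.random.PRNGKey(0))

# Symmetric concentration vectors, one entry per class.
low = jnp.full((num_classes,), 0.5)    # below 1: mass piles onto few classes
high = jnp.full((num_classes,), 50.0)  # well above 1: close to uniform

sparse = jax.random.dirichlet(key_sparse, low, shape=(1000,))
dense = jax.random.dirichlet(key_dense, high, shape=(1000,))

# Average size of the largest entry in each sampled slice.
sparse_peak = float(sparse.max(axis=-1).mean())
dense_peak = float(dense.max(axis=-1).mean())
print(f"mean largest entry, concentration 0.5:  {sparse_peak:.3f}")
print(f"mean largest entry, concentration 50.0: {dense_peak:.3f}")
```

Small concentration values tend to give slices dominated by one or two classes, whereas large values give slices close to uniform across classes.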

Methods

init(model, *[, mapper, num_classes, axis, ...])

Initialise a parameter using the specified initialiser and mapper.