Sylo

class hypercoil.nn.sylo.Sylo(in_channels: int, out_channels: int, dim: int, rank: int = 1, bias: bool = True, symmetry: Optional[Literal['psd', 'cross', 'skew']] = 'psd', coupling: Optional[Union[Literal['+', '-', 'split'], int, float]] = None, fixed_coupling: bool = False, similarity: Callable = <function crosshair_similarity>, remove_diagonal: bool = False, *, key: Optional['jax.random.PRNGKey'] = None)

Layer that learns a set of (possibly) symmetric, low-rank representations of a dataset.

Parameters:
in_channels: int

Number of channels or layers in the input graph or matrix.

out_channels: int

Number of channels or layers in the output graph or matrix. This is equal to the number of learnable templates.

dim: int or tuple(int)

Number of vertices in the graph or number of columns in the matrix. If the graph is bipartite or the matrix is nonsquare, then this should be a 2-tuple.

rank: int

Rank of the templates learned by the sylo module. Default: 1.

bias: bool

If True, adds a learnable bias to the output. Default: True

symmetry: ``'psd'``, ``'cross'``, ``'skew'``, or None (default ``'psd'``)

Symmetry constraints to impose on learnable templates.

  • If None, no symmetry constraints are placed on the templates learned by the module.

  • If 'psd', the module is constrained to learn symmetric representations of the graph or matrix: the left and right generators of each template are constrained to be identical.

  • If 'cross', the module is also constrained to learn symmetric representations of the graph or matrix. However, in this case, the left and right generators can be different, and the template is defined as the average of the expansion and its transpose: \(\frac{1}{2} \left(L R^{\intercal} + R L^{\intercal}\right)\)

  • If 'skew', the module is constrained to learn skew-symmetric representations of the graph or matrix. The template is defined as the difference between the expansion and its transpose: \(\frac{1}{2} \left(L R^{\intercal} - R L^{\intercal}\right)\)

Symmetry constraints are not available for nonsquare matrices or bipartite graphs. Note that the parameter count doubles if this is None, since separate left and right generators must then be learned.
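The three symmetry options can be illustrated with a small NumPy sketch (hypothetical shapes, not hypercoil's internal implementation), assuming rank-``R`` generators ``L`` and ``R`` of shape ``(dim, rank)``:

```python
import numpy as np

rng = np.random.default_rng(0)
dim, rank = 5, 2
L = rng.standard_normal((dim, rank))
R = rng.standard_normal((dim, rank))

# 'psd': left and right generators are identical, so the template
# T = L @ L.T is symmetric positive semidefinite (with coupling +1).
T_psd = L @ L.T

# 'cross': generators may differ; the template is the symmetrised
# average of the expansion and its transpose.
T_cross = 0.5 * (L @ R.T + R @ L.T)

# 'skew': the template is the antisymmetrised difference.
T_skew = 0.5 * (L @ R.T - R @ L.T)

assert np.allclose(T_psd, T_psd.T)                   # symmetric
assert np.linalg.eigvalsh(T_psd).min() >= -1e-10     # positive semidefinite
assert np.allclose(T_cross, T_cross.T)               # symmetric
assert np.allclose(T_skew, -T_skew.T)                # skew-symmetric
```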

coupling: None, ``'+'``, ``'-'``, ``'split'``, int, or float

Coupling parameter when expanding outer-product template banks.

  • A value of None disables the coupling parameter.

  • '+' is equivalent to None, fixing coupling to positive 1. Under a symmetry constraint, this enforces positive semidefinite templates.

  • '-' fixes coupling parameters to negative 1. Under a symmetry constraint, this enforces negative semidefinite templates.

  • 'split' splits channels such that approximately equal numbers have coupling parameters fixed to +1 and to -1. Under a symmetry constraint, this divides channels between positive and negative semidefinite templates. This option can also be useful when imposing a unilateral normed penalty to favour nonnegative weights: the template bank can simultaneously satisfy the soft nonnegativity constraint and respond with positive activations to features of either sign, enabling these activations to survive downstream rectifiers.

  • A float value in (0, 1) behaves like 'split' but fixes the fraction of negative coupling parameters to approximately the specified value.

  • Similarly, an int value fixes the number of negative output channels to the specified value.

  • 'learnable' sets the diagonal terms of the coupling parameter (the coupling between vector 0 of the left generator and vector 0 of the right generator, for instance, but not between vector 0 of the left generator and vector 1 of the right generator) to be learnable.

  • 'learnable_all' sets all terms of the coupling parameter to be learnable.
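As an illustrative sketch (a hypothetical helper, not hypercoil's internal code), the fixed coupling options can be read as rules for assigning a sign to each output channel:

```python
import numpy as np

def coupling_signs(coupling, out_channels):
    """Hypothetical helper: map a fixed coupling spec to per-channel signs."""
    if coupling in (None, '+'):
        return np.ones(out_channels)                 # all +1 (or disabled)
    if coupling == '-':
        return -np.ones(out_channels)                # all -1
    if coupling == 'split':
        n_neg = out_channels // 2                    # roughly half negative
    elif isinstance(coupling, float):
        n_neg = round(coupling * out_channels)       # fraction negative
    else:                                            # int
        n_neg = coupling                             # exact count negative
    signs = np.ones(out_channels)
    signs[:n_neg] = -1
    return signs

print(coupling_signs('split', 5))  # [-1. -1.  1.  1.  1.]
```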

similarity: function

Definition of the similarity metric. This must be a function whose inputs and outputs are:

  • input 0: reference matrix (N x C_in x H x W)

  • input 1: left template generator (C_out x C_in x H x R)

  • input 2: right template generator (C_out x C_in x W x R)

  • input 3: symmetry constraint ('cross', 'skew', or other)

  • output (N x C_out x H x W)

Similarity is computed between each of the N matrices in the first input stack and the (low-rank) matrix derived from the outer-product expansion of the second and third inputs. Default: crosshair_similarity
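A minimal sketch of a function honouring this contract (a simplified crosshair-style kernel, not hypercoil's crosshair_similarity itself; the right generator is assumed to have shape ``(C_out, C_in, W, rank)`` so that nonsquare inputs work, and the symmetry argument is ignored for brevity):

```python
import numpy as np

def crosshair_like_similarity(X, L, R, symmetry=None):
    """X: (N, C_in, H, W); L: (C_out, C_in, H, rank); R: (C_out, C_in, W, rank).
    Returns (N, C_out, H, W). The symmetry flag is ignored in this sketch."""
    # Outer-product expansion of each low-rank template: (C_out, C_in, H, W)
    T = np.einsum('ochr,ocwr->ochw', L, R)
    # Row term: for output element (i, j), products accumulated along row i.
    row = np.einsum('ncik,ocik->noi', X, T)
    # Column term: products accumulated along column j.
    col = np.einsum('nckj,ockj->noj', X, T)
    # The centre element (i, j) appears in both terms; subtract one copy.
    centre = np.einsum('ncij,ocij->noij', X, T)
    return row[:, :, :, None] + col[:, :, None, :] - centre

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 3, 4, 5))   # N=2, C_in=3, H=4, W=5
L = rng.standard_normal((6, 3, 4, 1))   # C_out=6, rank=1
R = rng.standard_normal((6, 3, 5, 1))
print(crosshair_like_similarity(X, L, R).shape)  # (2, 6, 4, 5)
```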

remove_diagonal: bool

If True, the diagonal of the output is removed. Default: False.

Attributes:
weight: Tensor

The learnable weights of the module of shape out_channels x in_channels x dim x rank.

bias: Tensor

The learnable bias of the module of shape out_channels.

Methods

__call__(input, *[, key])

Call self as a function.