Sylo
- class hypercoil.nn.sylo.Sylo(in_channels: int, out_channels: int, dim: int, rank: int = 1, bias: bool = True, symmetry: Optional[Literal['psd', 'cross', 'skew']] = 'psd', coupling: Optional[Union[Literal['+', '-', 'split'], int, float]] = None, fixed_coupling: bool = False, similarity: Callable = <function crosshair_similarity>, remove_diagonal: bool = False, *, key: Optional['jax.random.PRNGKey'] = None)
Layer that learns a set of (possibly) symmetric, low-rank representations of a dataset.
- Parameters:
- in_channels: int
Number of channels or layers in the input graph or matrix.
- out_channels: int
Number of channels or layers in the output graph or matrix. This is equal to the number of learnable templates.
- dim: int or tuple(int)
Number of vertices in the graph or number of columns in the matrix. If the graph is bipartite or the matrix is nonsquare, then this should be a 2-tuple.
- rank: int
Rank of the templates learned by the sylo module. Default: 1.
- bias: bool
If True, adds a learnable bias to the output. Default: True.
- symmetry: ``'psd'``, ``'cross'``, ``'skew'``, or None (default ``'psd'``)
Symmetry constraints to impose on learnable templates.
If None, no symmetry constraints are placed on the templates learned by the module.
If 'psd', the module is constrained to learn symmetric representations of the graph or matrix: the left and right generators of each template are constrained to be identical.
If 'cross', the module is also constrained to learn symmetric representations of the graph or matrix. However, in this case the left and right generators can be different, and the template is defined as the average of the expansion and its transpose: \(\frac{1}{2} \left(L R^{\intercal} + R L^{\intercal}\right)\).
If 'skew', the module is constrained to learn skew-symmetric representations of the graph or matrix. The template is defined as half the difference between the expansion and its transpose: \(\frac{1}{2} \left(L R^{\intercal} - R L^{\intercal}\right)\).
Symmetry constraints are not available for nonsquare matrices or bipartite graphs. Note that the parameter count doubles for any setting other than 'psd', because separate left and right generators must then be learned.
- coupling: None, ``'+'``, ``'-'``, ``'split'``, int, or float
Coupling parameter when expanding outer-product template banks.
A value of None disables the coupling parameter.
'+' is equivalent to None, fixing coupling to positive 1. For symmetry='psd', this enforces positive semidefinite templates.
'-' fixes coupling parameters to negative 1. For symmetry='psd', this enforces negative semidefinite templates.
'split' splits channels such that approximately an equal number have coupling parameters fixed to +1 and -1. For symmetry='psd', this splits channels among positive and negative semidefinite templates. This option can also be useful when imposing a unilateral normed penalty to favour nonnegative weights: the template bank can simultaneously satisfy the soft nonnegativity constraint and respond with positive activations to features of either sign, enabling these activations to survive downstream rectifiers.
A float value in (0, 1) is just like 'split' but fixes the fraction of negative parameters to approximately the specified value.
Similarly, an int value fixes the number of negative output channels to the specified value.
'learnable' sets the diagonal terms of the coupling parameter to be learnable (for instance, the coupling between vector 0 of the left generator and vector 0 of the right generator, but not between vector 0 of the left generator and vector 1 of the right generator).
'learnable_all' sets all terms of the coupling parameter to be learnable.
- similarity: function
Definition of the similarity metric. This must be a function whose inputs and outputs are:
input 0: reference matrix (N x C_in x H x W)
input 1: left template generator (C_out x C_in x H x R)
input 2: right template generator (C_out x C_in x W x R)
input 3: symmetry constraint ('cross', 'skew', or other)
output: (N x C_out x H x W)
Similarity is computed between each of the N matrices in the first input stack and the (low-rank) matrix derived from the outer-product expansion of the second and third inputs. Default: crosshair_similarity.
- remove_diagonal: bool
Delete the diagonal of the output. Default: False.
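The template definitions given for the `symmetry` options can be checked numerically. The following is a minimal NumPy sketch, not the module's implementation: the helper name `expand_template` is hypothetical, and it works on a single channel pair with generators of shape H x R rather than on the full C_out x C_in template bank.

```python
import numpy as np


def expand_template(left, right=None, symmetry="psd"):
    """Expand one pair of low-rank generators (H x R) into an H x H template.

    Hypothetical helper for illustration; for 'psd' the right generator is
    ignored and the left generator is reused on both sides.
    """
    if symmetry == "psd":
        # Identical left and right generators: L @ L.T is symmetric PSD.
        return left @ left.T
    expansion = left @ right.T  # L R^T; its transpose is R L^T
    if symmetry == "cross":
        # Average of the expansion and its transpose.
        return 0.5 * (expansion + expansion.T)
    if symmetry == "skew":
        # Half the difference between the expansion and its transpose.
        return 0.5 * (expansion - expansion.T)
    # symmetry=None: the unconstrained expansion.
    return expansion


rng = np.random.default_rng(0)
L = rng.standard_normal((5, 2))  # H = 5 vertices, rank R = 2
R = rng.standard_normal((5, 2))

psd = expand_template(L, symmetry="psd")
cross = expand_template(L, R, symmetry="cross")
skew = expand_template(L, R, symmetry="skew")
print(np.allclose(cross, cross.T), np.allclose(skew, -skew.T))  # True True
```

Because the 'psd' setting reuses a single generator, it stores half as many weights as the other settings, which is where the doubled parameter count comes from.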
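The fixed `coupling` options can be pictured as per-channel signs applied between matching left and right generator vectors. The sketch below uses a hypothetical NumPy helper, `coupling_signs`, which is not part of hypercoil. It assumes that negative channels are placed last and that a float fraction is rounded to the nearest integer; neither detail is specified in the documentation. It also leaves out the 'learnable' and 'learnable_all' options.

```python
import numpy as np


def coupling_signs(out_channels, coupling=None):
    """Hypothetical helper: per-channel coupling signs (+1 or -1).

    Assumptions for illustration only: negative channels are placed last,
    and float fractions are rounded to the nearest integer.
    """
    if coupling is None or coupling == "+":
        n_neg = 0
    elif coupling == "-":
        n_neg = out_channels
    elif coupling == "split":
        n_neg = out_channels // 2
    elif isinstance(coupling, float):
        # Fraction of negative channels, in (0, 1).
        n_neg = round(coupling * out_channels)
    elif isinstance(coupling, int):
        # Absolute number of negative channels.
        n_neg = coupling
    else:
        raise ValueError(f"unsupported coupling: {coupling!r}")
    signs = np.ones(out_channels)
    signs[out_channels - n_neg:] = -1.0
    return signs


def coupled_template(left, sign):
    # Diagonal coupling only (vector k of the left generator couples with
    # vector k of the right generator): template = L @ (sign * I) @ L.T
    rank = left.shape[1]
    return left @ (sign * np.eye(rank)) @ left.T


rng = np.random.default_rng(0)
L = rng.standard_normal((4, 3))  # shared generator, as with symmetry='psd'
signs = coupling_signs(6, "split")
templates = [coupled_template(L, s) for s in signs]
print(signs)  # [ 1.  1.  1. -1. -1. -1.]
```

With a shared generator, the channels with sign +1 give positive semidefinite templates and those with sign -1 give negative semidefinite templates, matching the description of 'split'.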
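Any function that follows the `similarity` interface can, in principle, be passed in place of the default. As an illustration only, the hypothetical `template_product_similarity` below expands each generator pair into a full template. It then scores each input matrix by its elementwise product with that template, summed over input channels. It is a simple NumPy stand-in, not a reimplementation of `crosshair_similarity`. It assumes the right generator has W rows so that nonsquare inputs are supported.

```python
import numpy as np


def template_product_similarity(reference, left, right, symmetry=None):
    """Hypothetical similarity function (NOT crosshair_similarity).

    reference: (N, C_in, H, W); left: (C_out, C_in, H, R);
    right: (C_out, C_in, W, R); returns (N, C_out, H, W).
    """
    # Outer-product expansion of each generator pair: (C_out, C_in, H, W).
    template = np.einsum("oihr,oiwr->oihw", left, right)
    if symmetry == "cross":
        # Average of the expansion and its transpose (requires H == W).
        template = 0.5 * (template + np.swapaxes(template, -1, -2))
    elif symmetry == "skew":
        # Half the difference between the expansion and its transpose (requires H == W).
        template = 0.5 * (template - np.swapaxes(template, -1, -2))
    # Elementwise product with each template, summed over input channels.
    return np.einsum("nihw,oihw->nohw", reference, template)


rng = np.random.default_rng(0)
N, C_in, C_out, H, R = 2, 3, 4, 5, 2
X = rng.standard_normal((N, C_in, H, H))
left = rng.standard_normal((C_out, C_in, H, R))
right = rng.standard_normal((C_out, C_in, H, R))
out = template_product_similarity(X, left, right, symmetry="cross")
print(out.shape)  # (2, 4, 5, 5)
```

Keeping the symmetry handling inside the similarity function mirrors the interface above, where the symmetry constraint is passed as the fourth input.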
- Attributes:
- weight: Tensor
The learnable weights of the module of shape out_channels x in_channels x dim x rank.
- bias: Tensor
The learnable bias of the module of shape out_channels.
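A rough parameter count follows from the weight and bias shapes. The hypothetical helper below assumes one generator of shape out_channels x in_channels x dim x rank when `symmetry='psd'`, plus a second (right) generator for any other setting, and one bias term per output channel. Any learnable coupling terms are ignored. It is a back-of-envelope sketch, not a query of the actual module.

```python
def sylo_param_count(in_channels, out_channels, dim, rank=1, bias=True, symmetry="psd"):
    """Hypothetical estimate of the number of learnable parameters."""
    rows, cols = (dim, dim) if isinstance(dim, int) else dim
    if rows != cols and symmetry is not None:
        # Symmetry constraints are unavailable for nonsquare matrices.
        raise ValueError("symmetry constraints require a square matrix")
    # Left generator: out_channels x in_channels x rows x rank.
    count = out_channels * in_channels * rows * rank
    if symmetry != "psd":
        # Assumed separate right generator: out_channels x in_channels x cols x rank.
        count += out_channels * in_channels * cols * rank
    if bias:
        count += out_channels
    return count


print(sylo_param_count(3, 8, 100, rank=2))                    # 4808
print(sylo_param_count(3, 8, 100, rank=2, symmetry="cross"))  # 9608
```

The weight term doubles from 4800 to 9600 when moving from 'psd' to 'cross', while the bias term stays at one value per output channel.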
Methods
__call__(input, *[, key]): Call self as a function.