trace_spspmm
- hypercoil.functional.sparse.trace_spspmm(lhs: BCOO, rhs: BCOO, threshold: float = 0.0, threshold_type: Literal['abs>', 'abs<', '>', '<'] = 'abs>', top_k: bool = True, top_k_reduction: Literal['mean'] | None = 'mean', fix_indices_over_channel_dims: bool = True) → Tensor
Trace the matrix multiplication of two top-k format sparse matrices to determine the indices of nonzero entries in their product.
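The idea behind tracing can be sketched without any values: an entry (i, j) of a product can be nonzero only if some k has a stored lhs entry at (i, k) and a stored rhs entry at (k, j). The minimal pure-Python sketch below works on plain lists of (row, col) coordinates and assumes the ordinary product lhs @ rhs. `trace_structure` is a hypothetical helper; it does not reproduce hypercoil's BCOO or top-k handling.

```python
from collections import defaultdict


def trace_structure(lhs_coords, rhs_coords):
    """Return sorted (row, col) coordinates that can be nonzero in lhs @ rhs.

    Hypothetical helper: inputs are iterables of (row, col) pairs marking
    stored entries; no values are involved, only the sparsity pattern.
    """
    # Group rhs entries by row so each lhs entry (i, k) can find every (k, j).
    rhs_by_row = defaultdict(set)
    for k, j in rhs_coords:
        rhs_by_row[k].add(j)

    out = set()
    for i, k in lhs_coords:
        # .get avoids inserting empty rows into the defaultdict.
        for j in rhs_by_row.get(k, ()):
            out.add((i, j))
    return sorted(out)


lhs = [(0, 0), (0, 2), (1, 1)]
rhs = [(0, 1), (2, 0), (2, 2)]
print(trace_structure(lhs, rhs))  # [(0, 0), (0, 1), (0, 2)]
```

Row 1 of lhs only touches column 1, and rhs has no entries in row 1, so no output index is produced for row 1.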
The inputs can be boolean matrices, in which case the output contains the indices of True entries. The inputs can also be matrices with values, in which case the output contains the indices of entries that survive the thresholding operation.
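As a dense reference for the two input kinds, the NumPy sketch below computes which indices survive. `traced_indices` is hypothetical, and the meaning assigned to each `threshold_type` (keep the entries that satisfy the comparison, considering only structurally nonzero positions of the product) is an assumption read from the option names, not taken from the library.

```python
import numpy as np


def traced_indices(lhs, rhs, threshold=0.0, threshold_type="abs>"):
    """Dense reference for the indices a traced product would keep (hypothetical)."""
    # Structural pattern: (i, j) is a candidate if some k has lhs[i, k] and rhs[k, j] nonzero.
    structural = ((lhs != 0).astype(np.int64) @ (rhs != 0).astype(np.int64)) > 0
    if lhs.dtype == bool and rhs.dtype == bool:
        # Boolean inputs: every True entry of the boolean product is kept.
        mask = structural
    else:
        prod = lhs @ rhs
        # Assumed semantics: each option names the comparison an entry must pass.
        comparisons = {
            "abs>": lambda x: np.abs(x) > threshold,
            "abs<": lambda x: np.abs(x) < threshold,
            ">": lambda x: x > threshold,
            "<": lambda x: x < threshold,
        }
        mask = structural & comparisons[threshold_type](prod)
    return np.argwhere(mask)


A = np.array([[1.0, 0.0], [0.0, 2.0]])
B = np.array([[0.5, -3.0], [0.0, 0.1]])
# A @ B = [[0.5, -3.0], [0.0, 0.2]]
print(traced_indices(A, B, 0.4, "abs>").tolist())  # [[0, 0], [0, 1]]
print(traced_indices(A != 0, B != 0).tolist())     # [[0, 0], [0, 1], [1, 1]]
```

Restricting the comparison to structural candidates matters for options such as '<': without it, positions that are zero in the product would pass any positive threshold.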
Warning
This function is not compatible with JIT compilation.
Warning
If the input is batched or contains multiple channels, the top_k option will return separate indices for each channel and each batch element. Ensure that top_k_reduction is set to 'mean' to obtain a single set of indices across all batch elements (and potentially across channels, according to fix_indices_over_channel_dims).
- Parameters:
- lhs : TopKTensor
The left-hand side sparse matrix.
- rhs : TopKTensor
The right-hand side sparse matrix.
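- threshold : float (default: 0.0)
The threshold value. Used only if the inputs are matrices with values rather than boolean matrices. If top_k is True, this is instead the integer number of entries k to select.
- threshold_type : one of ‘abs>’, ‘abs<’, ‘>’, ‘<’ (default: ‘abs>’)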
The type of thresholding operation to perform.
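- top_k : bool (default: True)
If True, then the threshold value must be an integer, and the thresholding operation will be replaced by selection of the top k entries.
- top_k_reduction : ‘mean’ or None (default: ‘mean’)
How top-k selections are combined when the input is batched or has multiple channels. With ‘mean’, a single set of indices is returned across all batch elements (and across channels if fix_indices_over_channel_dims is True). With None, separate indices are returned for each batch element and channel.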
- fix_indices_over_channel_dims : bool (default: True)
If True, the returned indices of nonzero entries are fixed across all channel dimensions. If False, the returned indices may vary across channel dimensions.
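The interaction of top_k, top_k_reduction and fix_indices_over_channel_dims can be illustrated on a dense product with shape (batch, channel, rows, cols). In this hypothetical sketch, entries are ranked by absolute value, and ‘mean’ reduction is modelled as averaging those scores over the batch axis (and also over the channel axis when indices are fixed over channels) before each row's top k columns are chosen. Both choices are assumptions for illustration, not hypercoil's documented behavior, and `top_k_columns` is not a library function.

```python
import numpy as np


def top_k_columns(prod, k, reduce_batch=True, fix_over_channels=True):
    """Per-row top-k column indices of a (batch, channel, rows, cols) array.

    Hypothetical helper: scores are absolute values, and 'mean' reduction
    is modelled as averaging scores before the selection step.
    """
    scores = np.abs(prod)
    if reduce_batch:
        # Average over batch, and over channels too if indices are fixed across them.
        axes = (0, 1) if fix_over_channels else (0,)
        scores = scores.mean(axis=axes, keepdims=True)
    # Rank columns by descending score along the last axis and keep the first k.
    order = np.argsort(-scores, axis=-1, kind="stable")[..., :k]
    # Sort the kept column indices so the output is deterministic and readable.
    return np.sort(order, axis=-1)


# Two batch elements, one channel, one row, four columns.
prod = np.zeros((2, 1, 1, 4))
prod[0, 0, 0] = [5, 1, 0, 2]
prod[1, 0, 0] = [0, 1, 6, 2]
print(top_k_columns(prod, 2, reduce_batch=False).tolist())  # [[[[0, 3]]], [[[2, 3]]]]
print(top_k_columns(prod, 2).tolist())                      # [[[[0, 2]]]]
```

Without reduction, each batch element selects its own columns; averaging first yields one shared set of indices, which is the single-index behavior the warning above describes for ‘mean’.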