trace_spspmm

hypercoil.functional.sparse.trace_spspmm(lhs: BCOO, rhs: BCOO, threshold: float = 0.0, threshold_type: Literal['abs>', 'abs<', '>', '<'] = 'abs>', top_k: bool = True, top_k_reduction: Literal['mean'] | None = 'mean', fix_indices_over_channel_dims: bool = True) → Tensor

Trace the matrix multiplication of two sparse matrices in top-k format to determine the indices of the nonzero entries of their product.
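Tracing a product can be understood through the symbolic phase of sparse matrix multiplication (as in Gustavson's algorithm): the nonzero pattern of a product depends only on the nonzero patterns of its factors, so no values are needed. The sketch below assumes, for illustration only, that each row of a top-k matrix is stored as a plain list of its k column indices; this is not hypercoil's internal representation, and `symbolic_spspmm` is a hypothetical helper.

```python
def symbolic_spspmm(lhs_rows, rhs_rows):
    """Return the (row, col) indices of structural nonzeros of lhs @ rhs.

    lhs_rows[i] lists the column indices stored for row i of lhs;
    rhs_rows[k] does the same for row k of rhs. Values are never inspected.
    """
    pattern = []
    for i, cols in enumerate(lhs_rows):
        # Row i of the product can be nonzero only in columns that are
        # nonzero in some rhs row k selected by a nonzero lhs[i, k].
        out_cols = set()
        for k in cols:
            out_cols.update(rhs_rows[k])
        pattern.extend((i, j) for j in sorted(out_cols))
    return pattern


# Top-k layout with k = 2: lhs is 2 x 3, rhs is 3 x 4.
lhs_rows = [[0, 2], [1, 2]]
rhs_rows = [[1, 3], [0, 1], [3, 2]]
print(symbolic_spspmm(lhs_rows, rhs_rows))
# [(0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]
```

Note that the product of two top-k matrices is generally not itself top-k: in this example the two output rows have three and four structural nonzeros.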

If the inputs are boolean matrices, the output contains the indices of the True entries of their product. If the inputs carry values, the output contains the indices of the entries that survive the thresholding operation.
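As a dense reference for these two cases (not how trace_spspmm is implemented), the sketch below computes the product with NumPy. The helper names are hypothetical, and the valued case assumes that the default ‘abs>’ threshold type keeps entries whose absolute value exceeds the threshold. The inputs are chosen so that one structurally nonzero entry cancels to zero.

```python
import numpy as np


def boolean_product_indices(lhs, rhs):
    """(row, col) indices of True entries of the boolean product lhs @ rhs."""
    # Integer matmul counts matching index paths; any positive count means True.
    counts = lhs.astype(np.int64) @ rhs.astype(np.int64)
    return np.argwhere(counts > 0)


def thresholded_product_indices(lhs, rhs, threshold=0.0):
    """(row, col) indices of entries of lhs @ rhs with |value| > threshold."""
    # Assumes 'abs>' semantics: keep entries whose magnitude exceeds threshold.
    product = lhs @ rhs
    return np.argwhere(np.abs(product) > threshold)


lhs = np.array([[1.0, 0.0, -2.0],
                [0.0, 3.0, 0.0]])
rhs = np.array([[0.5, 0.0],
                [0.0, 1.0],
                [0.25, 0.0]])
print(boolean_product_indices(lhs != 0, rhs != 0).tolist())  # [[0, 0], [1, 1]]
print(thresholded_product_indices(lhs, rhs).tolist())        # [[1, 1]]
```

Entry (0, 0) is structurally nonzero, so it appears in the boolean result, but its value cancels (1.0 * 0.5 - 2.0 * 0.25 = 0), so the default threshold of 0.0 under the assumed ‘abs>’ rule removes it.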

Warning

This function is not compatible with JIT compilation.
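A likely reason, inferred here rather than stated in the source, is that the number of surviving entries depends on the input values, whereas jax.jit requires every array shape to be known from input shapes alone. The NumPy sketch below shows two inputs of identical shape whose nonzero-index arrays have different shapes.

```python
import numpy as np

a = np.array([[1.0, 0.0], [0.0, 0.0]])
b = np.array([[1.0, 2.0], [3.0, 0.0]])

# Same input shape (2, 2), but the index arrays differ in length,
# so the output shape cannot be determined from input shapes alone.
print(np.argwhere(a).shape)  # (1, 2)
print(np.argwhere(b).shape)  # (3, 2)
```

JAX's own jnp.nonzero and jnp.argwhere can run under jit only when a static size argument fixes the output length in advance.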

Warning

If the input is batched or has multiple channels, the top_k option returns separate indices for each channel and each batch element. Set top_k_reduction to 'mean' to obtain a single set of indices across all batch elements (and, depending on fix_indices_over_channel_dims, across channels).
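One plausible reading of the 'mean' reduction, offered as an assumption because the source does not specify it, is to average the selection scores over the batch axis and then choose the top k entries once. The NumPy sketch below contrasts per-element selection with that averaged selection; `topk_indices` and the score values are hypothetical, and ranking by raw value (rather than by magnitude) is also assumed.

```python
import numpy as np


def topk_indices(scores, k):
    """Sorted column indices of the k largest entries in each row (last axis)."""
    # argsort is ascending, so the last k positions hold the k largest entries.
    return np.sort(np.argsort(scores, axis=-1)[..., -k:], axis=-1)


# Hypothetical scores for a batch of 2 matrices, each 2 x 4.
scores = np.array([
    [[0.9, 0.2, 0.8, 0.0],
     [0.2, 0.7, 0.0, 0.6]],
    [[0.1, 0.9, 0.7, 0.0],
     [0.3, 0.8, 0.0, 0.1]],
])

# Without reduction: one index set per batch element, and they disagree.
per_element = topk_indices(scores, k=2)
# With a 'mean'-style reduction: average over the batch axis first,
# then select once, giving indices shared by every batch element.
shared = topk_indices(scores.mean(axis=0), k=2)

print(per_element.tolist())  # [[[0, 2], [1, 3]], [[1, 2], [0, 1]]]
print(shared.tolist())       # [[1, 2], [1, 3]]
```

Under the same assumption, the averaging would also run over channel axes when fix_indices_over_channel_dims is True.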

Parameters:
lhs : TopKTensor

The left-hand side sparse matrix.

rhs : TopKTensor

The right-hand side sparse matrix.

threshold : float (default: 0.0)

The threshold value. Used only if the inputs carry values rather than booleans.

threshold_type : one of ‘abs>’, ‘abs<’, ‘>’, ‘<’ (default: ‘abs>’)

The comparison against threshold that decides which entries of the product survive the thresholding operation.

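top_k : bool (default: True)

If True, the threshold value must be an integer k, and the thresholding operation is replaced by selection of the top k entries.

top_k_reduction : ‘mean’ or None (default: ‘mean’)

Reduction applied to top-k selections across batch elements (and, if fix_indices_over_channel_dims is True, across channels). With ‘mean’, a single set of indices is returned; with None, separate indices are returned for each batch element and channel.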

fix_indices_over_channel_dims : bool (default: True)

If True, the returned indices of nonzero entries are shared across all channel dimensions. If False, the returned indices may differ between channels.
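The four values of threshold_type are not spelled out above. Judging only from their names, a natural reading is sketched below with NumPy; the mapping `KEEP` and the helper `surviving_indices` are hypothetical, and the comparison semantics (strict inequalities applied to entries of the product) are assumptions rather than documented behavior.

```python
import numpy as np

# Hypothetical mapping from threshold_type to an elementwise keep-mask.
# Assumed semantics: 'abs>' keeps |x| > t, 'abs<' keeps |x| < t,
# '>' keeps x > t, '<' keeps x < t.
KEEP = {
    "abs>": lambda x, t: np.abs(x) > t,
    "abs<": lambda x, t: np.abs(x) < t,
    ">": lambda x, t: x > t,
    "<": lambda x, t: x < t,
}


def surviving_indices(values, threshold=0.0, threshold_type="abs>"):
    """(row, col) indices of entries kept under the chosen threshold rule."""
    return np.argwhere(KEEP[threshold_type](values, threshold))


product = np.array([[-2.0, 0.5],
                    [0.0, 3.0]])
for mode in KEEP:
    print(mode, surviving_indices(product, threshold=1.0, threshold_type=mode).tolist())
# abs> [[0, 0], [1, 1]]
# abs< [[0, 1], [1, 0]]
# > [[1, 1]]
# < [[0, 0], [0, 1], [1, 0]]
```

Under this reading, the signed modes distinguish large negative entries from large positive ones, while the ‘abs’ modes treat them alike.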