topkx

hypercoil.functional.sparse.topkx(f: Callable, *, retnums: Sequence[int] = (0,), auto_index: bool = False, threshold_type: Literal['abs>', 'abs<', '>', '<'] = 'abs>', fix_indices_over_channel_dims: bool = True) → BCOO

Transform a function that produces a full matrix so that it instead produces a sparse matrix in the top-k format.
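As a rough illustration of what a top-k format holds, the sketch below keeps the k largest-magnitude entries in each row of a dense matrix and packs them into a JAX BCOO matrix. It is a minimal sketch, not hypercoil's implementation: it assumes JAX is installed and that the default threshold_type='abs>' ranks entries by absolute value. The helpers dense_to_topk and topk_to_bcoo are hypothetical, and hypercoil's own top-k layout may arrange the BCOO indices differently.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse


def dense_to_topk(x: jnp.ndarray, k: int):
    """Hypothetical helper: keep the k largest-magnitude entries per row.

    Returns column indices and signed values, each of shape (n_rows, k).
    """
    # Rank the entries of each row by absolute value.
    _, cols = jax.lax.top_k(jnp.abs(x), k)
    # Gather the original (signed) values at the selected columns.
    vals = jnp.take_along_axis(x, cols, axis=-1)
    return cols, vals


def topk_to_bcoo(cols: jnp.ndarray, vals: jnp.ndarray, shape) -> sparse.BCOO:
    """Hypothetical helper: flatten per-row picks into (row, col) coordinates."""
    n_rows, k = cols.shape
    rows = jnp.repeat(jnp.arange(n_rows), k)
    indices = jnp.stack([rows, cols.reshape(-1)], axis=-1)  # (nse, 2)
    return sparse.BCOO((vals.reshape(-1), indices), shape=shape)


x = jnp.array([[0.1, -3.0, 2.0, 0.5],
               [4.0, 0.2, -0.3, 1.0]])
cols, vals = dense_to_topk(x, k=2)
mat = topk_to_bcoo(cols, vals, x.shape)
```

Calling mat.todense() recovers a matrix in which only the two largest-magnitude entries of each row survive.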

Warning

The transformed function produces a full matrix as an intermediate result, so this transformation does not reduce the memory usage of the function.

Warning

The transformed function is compatible with JIT compilation only if auto_index is False. To use JIT compilation, obtain the indices in a separate step with the select_indices() function and pass them to the transformed function as its first argument.
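The two-step pattern recommended by this warning can be sketched as follows, assuming JAX. Here select_topk_indices is a hypothetical stand-in for select_indices(), whose actual signature is not reproduced; the point is only that the jitted step receives precomputed indices, so the shape of its output is fixed by indices.shape when JAX traces it.

```python
import jax
import jax.numpy as jnp


def select_topk_indices(x: jnp.ndarray, k: int) -> jnp.ndarray:
    # Hypothetical stand-in for an index-selection step: column indices
    # of the k largest-magnitude entries in each row.
    return jax.lax.top_k(jnp.abs(x), k)[1]


@jax.jit
def gather_at(indices: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    # Indices arrive as an input, so the output shape (n_rows, k) is
    # known at trace time and the function compiles normally.
    return jnp.take_along_axis(x, indices, axis=-1)


x = jnp.array([[0.1, -3.0, 2.0],
               [4.0, 0.2, -0.3]])
idx = select_topk_indices(x, 2)  # step 1: select indices outside the jitted call
out = gather_at(idx, x)          # step 2: JIT-compiled gather with fixed indices
```

The same indices can be reused across many calls to the compiled function, which is where separating selection from evaluation tends to pay off.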

Note

The transformed function requires an additional first argument, which depends on the value of auto_index. If auto_index is True, the first argument should be the value of k. Otherwise, the first argument should be the indices of the top k elements.
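To make this calling convention concrete, here is a heavily simplified, hypothetical analogue of the transform (topkx_sketch, not part of hypercoil). It wraps f, passes the remaining arguments through, and for each output listed in retnums returns an (indices, values) pair instead of a BCOO matrix. With auto_index=True the first argument is k; otherwise it is an array of column indices. Ranking by absolute value is an assumption corresponding to threshold_type='abs>'.

```python
from typing import Callable, Sequence

import jax
import jax.numpy as jnp


def topkx_sketch(f: Callable, *, retnums: Sequence[int] = (0,),
                 auto_index: bool = False) -> Callable:
    """Hypothetical, simplified analogue of topkx; not hypercoil's code."""

    def wrapped(k_or_indices, *args, **kwargs):
        out = f(*args, **kwargs)
        single = not isinstance(out, tuple)
        outs = (out,) if single else out
        result = list(outs)
        for i in retnums:
            full = outs[i]  # the full matrix is materialised here
            if auto_index:
                # First argument is k: choose columns by absolute value.
                idx = jax.lax.top_k(jnp.abs(full), k_or_indices)[1]
            else:
                # First argument already holds the column indices.
                idx = k_or_indices
            result[i] = (idx, jnp.take_along_axis(full, idx, axis=-1))
        return result[0] if single else tuple(result)

    return wrapped


def scale_and_sum(x):
    # Toy function with two outputs: a matrix and a scalar.
    return 2.0 * x, x.sum()


x = jnp.array([[0.5, -3.0, 1.0],
               [2.0, 0.1, -4.0]])

f_auto = topkx_sketch(scale_and_sum, auto_index=True)
(idx, vals), total = f_auto(2, x)  # first argument: k

f_idx = topkx_sketch(scale_and_sum, auto_index=False)
given = jnp.array([[0, 1], [1, 2]])
(idx2, vals2), _ = f_idx(given, x)  # first argument: indices
```

Only the first output is converted because retnums defaults to (0,); the scalar sum passes through unchanged.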

Parameters:
f : Callable

The function to transform.

retnums : Sequence[int] (default: (0,))

The positions of the return values of f that should be converted to the top-k format.

auto_index : bool (default: False)

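If True, the transformed function selects the indices of the top-k entries itself, and its first argument is the value of k. If False, the indices must be supplied as the first argument of the transformed function.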

indices : Tensor (default: None)

The indices of the nonzero entries in the top-k format. If None, the indices are selected from the output of the function according to the thresholding operation.

threshold_type : one of ‘abs>’, ‘abs<’, ‘>’, ‘<’ (default: ‘abs>’)

The type of thresholding operation to perform.

fix_indices_over_channel_dims : bool (default: True)

If True, the indices of the nonzero entries that are returned are shared across all channel dimensions. If False, they may vary across channel dimensions.
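The threshold_type names suggest the following rankings, which the sketch below encodes. Treat this mapping as an assumption inferred from the option names rather than documented behaviour: each threshold type becomes a score, and jax.lax.top_k keeps the k highest-scoring entries in each row. SCORES and topk_columns are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Assumed meaning of each threshold_type: the score that top-k maximises.
SCORES = {
    "abs>": lambda x: jnp.abs(x),   # largest magnitude
    "abs<": lambda x: -jnp.abs(x),  # smallest magnitude
    ">": lambda x: x,               # largest signed value
    "<": lambda x: -x,              # smallest signed value
}


def topk_columns(x: jnp.ndarray, k: int, threshold_type: str = "abs>"):
    """Column indices of the k entries per row preferred by threshold_type."""
    score = SCORES[threshold_type](x)
    return jax.lax.top_k(score, k)[1]


x = jnp.array([[0.5, -4.0, 2.0, -1.0]])
picks = {t: topk_columns(x, 2, t).tolist() for t in SCORES}
```

Negating the score turns "keep the smallest" into "keep the largest", so a single top-k routine can serve all four options.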

Returns:
Callable

The transformed function. Any tensor return values specified by retnums will be returned in the top-k sparse format when the transformed function is called. If auto_index is True, then the transformed function takes an additional first argument, which is the value of k to use in the top-k selection. Otherwise, the transformed function takes an additional first argument, which is the indices of the top k entries to include in the output.
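For the fix_indices_over_channel_dims parameter, the sketch below contrasts shared and per-channel index selection for an input of shape (channels, rows, cols), assuming JAX. This page does not state how channels are combined when indices are fixed; summing absolute values across the channel axis is an assumption made for illustration, and topk_columns_channels is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def topk_columns_channels(x: jnp.ndarray, k: int, fix_over_channels: bool):
    """x has shape (channels, rows, cols); returns top-k column indices."""
    if fix_over_channels:
        # Assumed aggregation: rank by summed magnitude across channels,
        # so every channel shares one index set per row.
        score = jnp.abs(x).sum(axis=0)                # (rows, cols)
        idx = jax.lax.top_k(score, k)[1]              # (rows, k)
        return jnp.broadcast_to(idx, x.shape[:-1] + (k,))
    # Otherwise each channel ranks its own entries independently.
    return jax.lax.top_k(jnp.abs(x), k)[1]            # (channels, rows, k)


x = jnp.array([
    [[5.0, 1.0, 0.0]],  # channel 0
    [[0.0, 1.0, 6.0]],  # channel 1
])
shared = topk_columns_channels(x, 1, fix_over_channels=True)
per_channel = topk_columns_channels(x, 1, fix_over_channels=False)
```

With shared indices both channels keep column 2; with per-channel indices channel 0 keeps column 0 instead.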