apply_mask

hypercoil.functional.utils.apply_mask(tensor: Tensor, msk: Tensor, axis: int) → Tensor

Mask a tensor along an axis.
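As a rough illustration of the one-dimensional case, the sketch below selects entries along one axis with jax.numpy. It assumes that `msk` is a boolean vector whose length equals `tensor.shape[axis]` and that masking keeps the entries where `msk` is True. `apply_mask_sketch` is a hypothetical name, not part of hypercoil, and this is not the library's implementation.

```python
import jax.numpy as jnp


def apply_mask_sketch(tensor, msk, axis):
    # Hypothetical illustration, not hypercoil's implementation.
    # Assumes masking means "keep entries where msk is True".
    msk = jnp.asarray(msk, dtype=bool)
    if msk.ndim != 1:
        raise ValueError("mask must be one-dimensional")
    if msk.shape[0] != tensor.shape[axis]:
        raise ValueError("mask length must equal tensor.shape[axis]")
    # Indices of the True entries (needs a concrete mask, so not jit-safe).
    idx = jnp.nonzero(msk)[0]
    # Keep only those positions along `axis`; other axes are unchanged.
    return jnp.take(tensor, idx, axis=axis)


x = jnp.arange(12).reshape(3, 4)
m = jnp.array([True, False, True, False])
y = apply_mask_sketch(x, m, axis=-1)  # keeps columns 0 and 2
```

The length of the output along `axis` equals the number of True entries in the mask, so it depends on the mask's values rather than only on its shape.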

Warning

This function works only with one-dimensional masks. For multi-dimensional masks, use conform_mask().

Warning

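Use of this function is strongly discouraged because it is incompatible with jax.jit: the shape of its output depends on the values in the mask, which are not known when the function is traced.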
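The sketch below illustrates the jax.jit problem, assuming the selection is implemented by locating the True entries of a boolean mask. Under tracing, `jnp.nonzero` cannot determine how many entries to return and raises a `ConcretizationTypeError`. Two jit-compatible workarounds follow. Both are hypothetical helpers (`mask_select`, `zero_unmasked` and `select_fixed_size` are not hypercoil API): one zeroes out unmasked entries and keeps the input shape; the other selects a fixed, static number of entries via the `size` argument of `jnp.nonzero`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def mask_select(tensor, msk, axis):
    # Hypothetical data-dependent selection: the output length along
    # `axis` equals the number of True entries in `msk`.
    idx = jnp.nonzero(msk)[0]
    return jnp.take(tensor, idx, axis=axis)


x = jnp.arange(12.0).reshape(3, 4)
m = jnp.array([True, False, True, False])

# Eager execution works because `m` is a concrete array.
eager = mask_select(x, m, -1)

# Under jit, `msk` is an abstract tracer, so jnp.nonzero cannot
# determine the output size and raises ConcretizationTypeError.
try:
    jax.jit(mask_select, static_argnums=2)(x, m, -1)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True


@partial(jax.jit, static_argnames="axis")
def zero_unmasked(tensor, msk, axis):
    # Shape-preserving alternative: reshape the mask so it broadcasts
    # along `axis`, then replace unmasked entries with zero.
    shape = [1] * tensor.ndim
    shape[axis] = -1
    return jnp.where(msk.reshape(tuple(shape)), tensor, 0)


@partial(jax.jit, static_argnames=("axis", "size"))
def select_fixed_size(tensor, msk, axis, size):
    # Fixed-size selection: `size` must be known at trace time. If `msk`
    # has fewer than `size` True entries, the leftover slots use index 0
    # (fill_value), i.e. they repeat position 0 along `axis`.
    idx = jnp.nonzero(msk, size=size, fill_value=0)[0]
    return jnp.take(tensor, idx, axis=axis)


zeroed = zero_unmasked(x, m, axis=-1)
selected = select_fixed_size(x, m, axis=-1, size=2)
```

Zeroing keeps every shape static at the cost of retaining the masked positions, while fixed-size selection returns a smaller array but needs an upper bound on the number of kept entries before tracing.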