apply_mask
- hypercoil.functional.utils.apply_mask(tensor: Tensor, msk: Tensor, axis: int) → Tensor
Mask a tensor along an axis.
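As a rough illustration, one-dimensional masking along an axis can be sketched in NumPy with `np.compress`. This is not hypercoil's implementation. `np.compress` keeps the slices along `axis` where the mask is True. The helper name `apply_mask_sketch` is hypothetical. The sketch also assumes that `msk` is a boolean vector whose length equals `tensor.shape[axis]`.

```python
import numpy as np


def apply_mask_sketch(tensor: np.ndarray, msk: np.ndarray, axis: int) -> np.ndarray:
    """Hypothetical stand-in for apply_mask: keep slices along `axis` where `msk` is True."""
    msk = np.asarray(msk, dtype=bool)
    # Assumption: the mask is 1-D and its length matches the masked axis.
    if msk.ndim != 1 or msk.shape[0] != tensor.shape[axis]:
        raise ValueError("msk must be 1-D with length tensor.shape[axis]")
    # np.compress keeps the entries of `tensor` along `axis` where msk is True.
    return np.compress(msk, tensor, axis=axis)


x = np.arange(12).reshape(3, 4)
print(apply_mask_sketch(x, np.array([True, False, True, False]), axis=1))
# [[ 0  2]
#  [ 4  6]
#  [ 8 10]]
```

Masking along `axis=0` with a length-3 mask would drop rows in the same way.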
Warning
This function will only work if the mask is one-dimensional. For multi-dimensional masks, use conform_mask().
Warning
Use of this function is strongly discouraged. It is incompatible with jax.jit.
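The source does not say why this function cannot be used with jax.jit; what follows is our reading. jax.jit needs every array shape to be known when the function is traced. Dropping masked-out slices breaks that requirement, because the output length along `axis` depends on how many mask entries are True, not just on the mask's shape.

The NumPy sketch below shows the problem: two masks of identical shape produce outputs of different shapes. Next to it is a shape-preserving alternative that zeros out masked slices instead of removing them. Both helper names are hypothetical, and zero-filling is a common workaround rather than something the source prescribes.

```python
import numpy as np


def drop_masked(tensor, msk, axis):
    # Output length along `axis` equals msk.sum(), so it depends on mask *values*.
    return np.compress(np.asarray(msk, dtype=bool), tensor, axis=axis)


def zero_masked(tensor, msk, axis):
    # Hypothetical shape-static alternative: broadcast the 1-D mask along `axis`
    # and replace masked-out entries with zero instead of removing them.
    shape = [1] * tensor.ndim
    shape[axis] = -1
    keep = np.asarray(msk, dtype=bool).reshape(shape)
    return np.where(keep, tensor, 0)


x = np.arange(12).reshape(3, 4)
m1 = np.array([True, False, True, False])
m2 = np.array([True, True, True, False])

# Same input shapes, different output shapes: jax.jit needs shapes fixed at trace time.
print(drop_masked(x, m1, axis=1).shape)  # (3, 2)
print(drop_masked(x, m2, axis=1).shape)  # (3, 3)

# Output shape is always x.shape, whatever the mask values are.
print(zero_masked(x, m1, axis=1).shape)  # (3, 4)
```

In JAX, the three-argument form of `jax.numpy.where` follows the same shape-static pattern and can be used inside jitted functions.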
See also