TL;DR: you can use explicit type annotations of the form def f(x: Float[Tensor, "channels"], y: Float[Tensor, "channels"]): ... to specify the shape and dtype of tensors/arrays, to declare that those shapes are consistent across multiple arguments (here, x and y must have the same number of channels), and to enforce all of this with runtime type-checking. See the (now quite popular!) jaxtyping library on GitHub. Note that the name is now historical: it also supports PyTorch, TensorFlow, and NumPy, and has no JAX dependency.
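To make the "consistent across multiple arguments" idea concrete, here is a toy, standard-library-only sketch of what such a runtime check does. It is not jaxtyping's API or implementation. `Arr`, `_Spec`, and `check_shapes` are hypothetical names, and nested Python lists stand in for tensors. The core mechanism is that each dimension name gets bound to a size the first time it appears in a call, and every later argument using that name must agree.

```python
import functools
import inspect
from typing import Annotated, get_args, get_origin, get_type_hints


class _Spec:
    """Hypothetical metadata: expected element type plus named dimensions."""

    def __init__(self, dtype, dims):
        self.dtype = dtype
        self.dims = tuple(dims.split())


def Arr(dtype, dims):
    """Hypothetical stand-in for an annotation like Float[Tensor, "channels"].

    Nested Python lists play the role of tensors in this toy.
    """
    return Annotated[list, _Spec(dtype, dims)]


def _shape_and_leaves(x):
    """Return (shape, leaf elements) of a rectangular nested list."""
    if not isinstance(x, list):
        return (), [x]
    if not x:
        return (0,), []
    parts = [_shape_and_leaves(item) for item in x]
    inner = parts[0][0]
    if any(shape != inner for shape, _ in parts):
        raise TypeError("ragged nested list")
    leaves = [leaf for _, part_leaves in parts for leaf in part_leaves]
    return (len(x),) + inner, leaves


def check_shapes(fn):
    """Toy runtime checker: dimension names must agree across one call's arguments."""
    hints = get_type_hints(fn, include_extras=True)
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        sizes = {}  # dimension name -> size, shared by all arguments of this call
        for name, value in bound.arguments.items():
            hint = hints.get(name)
            if get_origin(hint) is not Annotated:
                continue  # argument without an Arr(...) annotation: not checked
            spec = get_args(hint)[1]
            if not isinstance(spec, _Spec):
                continue
            shape, leaves = _shape_and_leaves(value)
            if len(shape) != len(spec.dims):
                raise TypeError(f"{name}: expected {len(spec.dims)} dims, got shape {shape}")
            if not all(isinstance(leaf, spec.dtype) for leaf in leaves):
                raise TypeError(f"{name}: expected {spec.dtype.__name__} elements")
            for dim, size in zip(spec.dims, shape):
                # The first argument to use a name fixes its size; later ones must match.
                if sizes.setdefault(dim, size) != size:
                    raise TypeError(f"{name}: dimension {dim!r} is {size}, but was {sizes[dim]} earlier")
        return fn(*args, **kwargs)

    return wrapper


@check_shapes
def add(x: Arr(float, "channels"), y: Arr(float, "channels")) -> list:
    return [a + b for a, b in zip(x, y)]


@check_shapes
def matvec(m: Arr(float, "rows cols"), v: Arr(float, "cols")) -> list:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


print(add([1.0, 2.0], [3.0, 4.0]))  # [4.0, 6.0]
try:
    add([1.0, 2.0], [3.0, 4.0, 5.0])  # "channels" is 2 for x but 3 for y
except TypeError as err:
    print("rejected:", err)
```

Because the size table is created fresh for each call, the same name in two annotations forces equal sizes within that call, while different names (such as "rows" and "cols") are checked independently. The dtype check here is just an `isinstance` test on each element, which is much cruder than what a real array library exposes.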