A couple of years ago I made the jump from PyTorch to JAX. Now, the skill of writing autodifferentiable code turns out to translate pretty smoothly between different frameworks. In this case, PyTorch and JAX really aren’t that different: replace torch.foo(...) with jax.numpy.foo(...) and you’re 95% of the way there! What about the other 5%? That’s the purpose of this article! Assuming you already know PyTorch, here’s what I’ve found you need to know to get up to speed with JAX.
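As a rough illustration of that 95%, here is a small sketch of a computation you might write in PyTorch (shown in the comments) next to its JAX equivalent, where the only change is swapping the `torch` namespace for `jax.numpy`. The specific function (a sum of squared sines) is just an arbitrary example chosen for illustration. The last lines take a gradient with `jax.grad`, which hints at the kind of difference the rest of this article covers: JAX differentiates functions rather than tensors.

```python
import jax
import jax.numpy as jnp

# PyTorch version, for comparison (not executed here):
#   x = torch.linspace(0.0, 1.0, 5)
#   y = torch.sum(torch.sin(x) ** 2)

# JAX version: same calls, with `torch` replaced by `jax.numpy` (imported as jnp).
x = jnp.linspace(0.0, 1.0, 5)
y = jnp.sum(jnp.sin(x) ** 2)

# Gradients: instead of calling y.backward() and reading x.grad,
# JAX transforms the function itself into its gradient function.
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

dy_dx = jax.grad(f)(x)  # analytically, d/dx sin(x)^2 = 2 sin(x) cos(x) = sin(2x)
```

Since d/dx sin²(x) = sin(2x), `dy_dx` should match `jnp.sin(2 * x)` up to floating-point error.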