Welcome to part six!!! of our ongoing series on making sparse linear algebra differentiable in JAX, with the eventual hope of being able to do some cool statistical shit. We are nowhere near done. Last time, we looked at making JAX primitives. We built four of them. Today we are going to implement the corresponding differentiation rules! For three[1] of them. So strap yourselves in. This is gonna be detailed. If you're interested in the code[2], the git repo for this post is linked at the bottom an...