I'm releasing a minimal (<150 lines) JAX implementation of the "Gradients without Backpropagation" paper. The paper proposes a simple addition to forward-mode AD that produces unbiased gradient estimates in a single forward (inference) pass. (Quick project; it could be optimized further.)
https://github.com/YigitDemirag/forward-gradients
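For anyone new to the idea, here is a minimal sketch of the forward-gradient estimator as I understand it from the paper. It is not the code from the repo. The sketch samples a random tangent v ~ N(0, I), gets the directional derivative ∇f(θ)·v from a single `jax.jvp` call, and uses (∇f(θ)·v) v as the gradient estimate. The function name `forward_gradient` and the toy quadratic loss are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def forward_gradient(f, params, key):
    """Forward-gradient estimate of grad(f)(params) from one jvp call.

    Illustrative helper (not from the repo); `params` may be any pytree of float arrays.
    """
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(key, len(leaves))
    # Random tangent v ~ N(0, I) with the same structure and shapes as params.
    v = jax.tree_util.tree_unflatten(
        treedef,
        [jax.random.normal(k, leaf.shape, leaf.dtype) for k, leaf in zip(keys, leaves)],
    )
    # One forward-mode pass returns f(params) and the directional derivative grad(f) . v.
    value, dir_deriv = jax.jvp(f, (params,), (v,))
    # Scale v by the directional derivative; since E[v v^T] = I, E[(grad . v) v] = grad.
    grad_est = jax.tree_util.tree_map(lambda t: dir_deriv * t, v)
    return value, grad_est


# Usage: compare the averaged estimate with the exact gradient of a toy loss.
def loss(w):
    return jnp.sum(w ** 2)  # exact gradient is 2 * w


w = jnp.array([1.0, -2.0, 3.0])
keys = jax.random.split(jax.random.PRNGKey(0), 20000)
_, grads = jax.vmap(lambda k: forward_gradient(loss, w, k))(keys)
print(grads.mean(axis=0))  # should be close to 2 * w (Monte Carlo noise remains)
print(jax.grad(loss)(w))   # exact reference from reverse mode
```

A single sample is noisy. The averaging over many keys above is only there to show that the estimate is unbiased. In training, you would typically draw one direction per step and use the estimate in place of the backprop gradient.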