I'm releasing a minimal (<150 lines) JAX implementation of the "Gradients without Backpropagation" paper. The paper proposes a simple addition to forward-mode AD that estimates unbiased gradients during a single inference pass. It's a quick project, so it could likely be optimized further.

https://github.com/YigitDemirag/forward-gradients
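For anyone curious about the core trick, here is a minimal sketch of the idea. It is not the code from the repo, and it assumes the parameters are a single array and the loss is a scalar. You sample a random tangent direction v ~ N(0, I), use `jax.jvp` to get the directional derivative ∇f·v in one forward pass, and return (∇f·v)·v. This is unbiased because E[v vᵀ] = I. The name `forward_gradient` is hypothetical.

```python
import jax
import jax.numpy as jnp


def forward_gradient(f, params, key):
    """Forward-gradient estimate of grad f at params (hypothetical helper).

    Assumes params is a single array and f returns a scalar.
    """
    # Random tangent direction with the same shape as params, v ~ N(0, I)
    v = jax.random.normal(key, params.shape)
    # One forward-mode pass returns f(params) and the directional derivative (grad f . v)
    value, dir_deriv = jax.jvp(f, (params,), (v,))
    # Scale the direction by the directional derivative: E[(grad f . v) v] = grad f
    return value, dir_deriv * v


# Usage: one plain SGD step on a toy quadratic loss
params = jnp.array([1.0, -2.0, 0.5])
loss, grad_est = forward_gradient(lambda p: jnp.sum(p ** 2), params, jax.random.PRNGKey(0))
params = params - 0.1 * grad_est
```

No backward pass is needed: the single `jax.jvp` call yields both the loss and the directional derivative. The trade-off is variance, since each estimate relies on one random direction, and that variance grows with the number of parameters.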
