HN.zip

Defining Statistical Models in Jax?

102 points by hackandthink - 18 comments
JHonaker 5 mins ago
I'm very excited by the work being put in to make Bayesian inference more manageable. It's in a spot that feels very similar to deep learning circa mid-2010s when Caffe, Torch, and hand-written gradients were the options. We can do it, but doing anything more complicated than common model structures like hierarchical Gaussian linear models requires dropping out of the nice places and into the guts.
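For readers who haven't written one by hand, here is a minimal plain-Python sketch of what a "common model structure" like a hierarchical Gaussian linear model reduces to: an unnormalized log joint density. The specific model (a varying-intercept regression), its priors, and every function name below are hypothetical choices for illustration. A PPL such as Stan or NumPyro derives an equivalent function from the model code and hands it, along with its gradient, to a sampler.

```python
import math


def normal_logpdf(x, mean, sd):
    """Log density of N(mean, sd^2) at x."""
    return -0.5 * ((x - mean) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2.0 * math.pi)


def half_normal_logpdf(x, sd):
    """Log density of a half-normal distribution on x > 0 (-inf outside the support)."""
    if x <= 0:
        return -math.inf
    return math.log(2.0) + normal_logpdf(x, 0.0, sd)


def hierarchical_log_joint(mu, tau, beta, sigma, alphas, groups):
    """Unnormalized log joint density of a varying-intercept Gaussian linear model.

    Hypothetical model, for illustration only:
        mu ~ N(0, 10^2), beta ~ N(0, 10^2), tau ~ HalfNormal(1), sigma ~ HalfNormal(1)
        alpha_j ~ N(mu, tau^2)                      (one intercept per group j)
        y_ij ~ N(alpha_j + beta * x_ij, sigma^2)    (observation i in group j)
    `groups` is a list of (xs, ys) pairs aligned with `alphas`.
    """
    if len(alphas) != len(groups):
        raise ValueError("need exactly one intercept per group")
    lp = normal_logpdf(mu, 0.0, 10.0) + normal_logpdf(beta, 0.0, 10.0)
    lp += half_normal_logpdf(tau, 1.0) + half_normal_logpdf(sigma, 1.0)
    if lp == -math.inf:
        return lp  # tau or sigma outside its support
    for alpha, (xs, ys) in zip(alphas, groups):
        lp += normal_logpdf(alpha, mu, tau)  # group-level (partial pooling) prior
        for x, y in zip(xs, ys):
            lp += normal_logpdf(y, alpha + beta * x, sigma)  # likelihood
    return lp
```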

I've had a lot of success with NumPyro (a JAX library), and have used quite a few tools that are simpler interfaces to Stan. I've also had to write quite a few model-specific things by hand (more for sequential Monte Carlo than MCMC). I'm very excited for a world where PPLs become scalable and easier to use and customize.
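Since sequential Monte Carlo comes up, here is a minimal plain-Python sketch of its simplest form, a bootstrap particle filter. It is an illustration of the technique, not the commenter's code, and it assumes a hypothetical one-dimensional local-level model (a random walk observed with Gaussian noise). The function name and parameters are made up for this example.

```python
import math
import random


def bootstrap_particle_filter(ys, n_particles=1000, q_sd=1.0, r_sd=1.0, x0_sd=1.0, seed=0):
    """Bootstrap particle filter, the simplest sequential Monte Carlo method.

    Hypothetical local-level model, chosen for illustration:
        x_0 ~ N(0, x0_sd^2),  x_t = x_{t-1} + N(0, q_sd^2),  y_t = x_t + N(0, r_sd^2)
    Returns (filtered means E[x_t | y_1..t], estimate of log p(y_1..T)).
    """
    rng = random.Random(seed)
    particles = [rng.gauss(0.0, x0_sd) for _ in range(n_particles)]
    log_norm = math.log(r_sd) + 0.5 * math.log(2.0 * math.pi)
    means, log_evidence = [], 0.0
    for y in ys:
        # 1. Propagate every particle through the transition model (the "bootstrap" proposal).
        particles = [x + rng.gauss(0.0, q_sd) for x in particles]
        # 2. Weight by the observation likelihood, working in log space for stability.
        log_w = [-0.5 * ((y - x) / r_sd) ** 2 - log_norm for x in particles]
        m = max(log_w)
        w = [math.exp(lw - m) for lw in log_w]
        total = sum(w)
        # The mean unnormalized weight estimates p(y_t | y_1..t-1).
        log_evidence += m + math.log(total / n_particles)
        weights = [wi / total for wi in w]
        means.append(sum(wi * x for wi, x in zip(weights, particles)))
        # 3. Multinomial resampling: duplicate likely particles, drop unlikely ones.
        particles = rng.choices(particles, weights=weights, k=n_particles)
    return means, log_evidence


ys = [0.5, 1.0, 1.8, 2.1, 3.0]
means, log_evidence = bootstrap_particle_filter(ys, n_particles=5000)
```

Because this toy model is linear-Gaussian, a Kalman filter gives the exact answer, which makes it a convenient sanity check. SMC earns its keep on nonlinear or non-Gaussian models where no such closed form exists.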

> I think there is a good chance that normalizing flow-based variational inference will displace MCMC as the go-to method for Bayesian posterior inference as soon as everyone gets access to good GPUs.

Wow. This is incredibly surprising. I'm only tangentially aware of normalizing flows, but apparently I need to look at the intersection of them and Bayesian statistics now! Any sources from anyone would be most appreciated!

sarosh 5 mins ago
I defer to other experts, but briefly: normalizing flows are a method for constructing complex distributions by transforming a simple probability density through a series of invertible transformations. Because the change-of-variables formula gives the transformed density exactly, normalizing flows can be trained by directly maximizing log-likelihood, and they are capable of exact density evaluation and efficient sampling. See:

Danilo Rezende and Shakir Mohamed. Variational inference with normalizing flows. In ICML, 2015. Link: https://bigdata.duke.edu/wp-content/uploads/2022/08/1505.057...

Laurent Dinh, David Krueger, and Yoshua Bengio. NICE: Non-linear Independent Components Estimation. In ICLR Workshop, 2015. Link: https://arxiv.org/pdf/1410.8516

And for your direct question, the following paper "Efficient Bayesian Sampling Using Normalizing Flows to Assist Markov Chain Monte Carlo Methods" appears upon a superficial glance to be relevant. Link: https://arxiv.org/pdf/2107.08001
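A toy one-dimensional sketch in plain Python of the idea described above: push a standard normal through a stack of invertible layers, then get an exact log density by inverting the stack and subtracting each layer's log-Jacobian. The layer classes (`Affine`, `Sinh`) and the `Flow` wrapper are hypothetical illustrations, not any library's API. Practical flows such as NICE and RealNVP use learned, high-dimensional coupling layers whose Jacobians are triangular, which keeps the log-determinant cheap.

```python
import math
import random


class Affine:
    """y = scale * x + shift, invertible whenever scale != 0."""

    def __init__(self, scale, shift):
        if scale == 0:
            raise ValueError("scale must be non-zero")
        self.scale, self.shift = scale, shift

    def forward(self, x):
        return self.scale * x + self.shift

    def inverse(self, y):
        return (y - self.shift) / self.scale

    def log_abs_det_jacobian(self, x):
        return math.log(abs(self.scale))


class Sinh:
    """y = sinh(x): a smooth, strictly increasing bijection of the real line."""

    def forward(self, x):
        return math.sinh(x)

    def inverse(self, y):
        return math.asinh(y)

    def log_abs_det_jacobian(self, x):
        return math.log(math.cosh(x))  # d sinh(x)/dx = cosh(x) >= 1


class Flow:
    """Pushes a standard normal base density through a sequence of invertible layers."""

    def __init__(self, layers):
        self.layers = layers

    def sample(self, rng):
        x = rng.gauss(0.0, 1.0)
        for layer in self.layers:
            x = layer.forward(x)
        return x

    def log_prob(self, y):
        # Change of variables: undo the layers in reverse order, accumulating log|f'(x)|.
        log_det = 0.0
        for layer in reversed(self.layers):
            x = layer.inverse(y)
            log_det += layer.log_abs_det_jacobian(x)
            y = x
        base = -0.5 * y * y - 0.5 * math.log(2.0 * math.pi)  # standard normal log density
        return base - log_det


flow = Flow([Affine(0.5, 0.0), Sinh()])
rng = random.Random(0)
draws = [flow.sample(rng) for _ in range(5)]
```

Fitting a flow means adjusting the layer parameters (here `scale` and `shift`) to maximize `log_prob` on data, or, in variational inference, to maximize the ELBO, with gradients supplied by autodiff.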

1980phipsi 5 mins ago
So it's like converting a normal distribution to a log-normal one (and then back), but a more general way of thinking about it.

Where does the name "normalizing flows" come from?

hotstickyballs 5 mins ago
It comes from the Jacobian, which you can get from autodiff. Its determinant measures how much the function stretches or compresses space, and correcting for it keeps the transformed density normalized, so it still integrates to 1.
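The normal-to-log-normal example makes the Jacobian's role easy to check numerically. Below is a minimal plain-Python sketch (all function names are hypothetical). It derives the log-normal density from the normal one by change of variables. Dropping the log|dz/dx| term leaves a function that no longer integrates to 1: for a standard normal Z it integrates to E[e^Z] = e^{1/2} instead.

```python
import math


def normal_logpdf(z, mean=0.0, sd=1.0):
    """Log density of N(mean, sd^2) at z."""
    return -0.5 * ((z - mean) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2.0 * math.pi)


def lognormal_logpdf(x, mean=0.0, sd=1.0):
    """Density of X = exp(Z), Z ~ N(mean, sd^2), by change of variables.

    The inverse map is z = log(x) with |dz/dx| = 1/x, so
    log p_X(x) = log p_Z(log x) + log|dz/dx| = log p_Z(log x) - log x.
    """
    return normal_logpdf(math.log(x), mean, sd) - math.log(x)


def naive_logpdf(x, mean=0.0, sd=1.0):
    """The same pushforward with the Jacobian term deliberately left out."""
    return normal_logpdf(math.log(x), mean, sd)


def integrate(logpdf, lo, hi, n=200_000):
    """Trapezoid-rule approximation of the integral of exp(logpdf) over [lo, hi]."""
    h = (hi - lo) / n
    total = 0.5 * (math.exp(logpdf(lo)) + math.exp(logpdf(hi)))
    total += sum(math.exp(logpdf(lo + i * h)) for i in range(1, n))
    return total * h


with_jacobian = integrate(lognormal_logpdf, 1e-6, 200.0)    # ≈ 1.0
without_jacobian = integrate(naive_logpdf, 1e-6, 200.0)     # ≈ e**0.5 ≈ 1.6487
```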
theGnuMe 5 mins ago
I mean the whole thing sounds like a deep neural network…
JHonaker 5 mins ago
Thanks! I've read the first one before. I'll take a look at the other two!
szvsw 5 mins ago
> make Bayesian inference more manageable

Discovering PyMC and the excellent accompanying textbook was game-changing for me! Being able to write full hierarchical models in a handful of lines of code, hooked up directly to pandas DataFrames, is already wonderful.

The more tools for this the better!

legobmw99 5 mins ago
The author links to https://arxiv.org/abs/2006.10343, which seems like a good place to start on normalizing flows for Bayesian inference.
nextos 5 mins ago
Pyro has a nice normalizing flows tutorial: https://pyro.ai/examples/normalizing_flows_intro.html
JHonaker 5 mins ago
Ah, I did not realize that the `realNVP` was a link! Thanks.
gnulinux 5 mins ago
Reading this post and reviewing the NumPyro/Pyro documentation, I'm not following the crucial difference between NumPyro and Pyro. I understand that Pyro uses PyTorch as its backend and NumPyro uses JAX, but other than that I'm not sure about the critical differences. If their frontends are about the same (which seems to be the case here), why is JAX mentioned in this post? Could we simply not replace Pyro with Stan for statistical modelling (whether with PyTorch or JAX backend)?
nextos 5 mins ago
> Could we simply not replace Pyro with Stan for statistical modelling (whether with PyTorch or JAX backend)?

Stan has a fantastic implementation of NUTS (the No-U-Turn Sampler, a Hamiltonian Monte Carlo method). Pyro & NumPyro are more focused on variational inference. For a third alternative that IMHO doesn't get the attention it deserves, take a look at Infer.NET, which excels at expectation propagation and uses factor graphs underneath. These three offer very different tradeoffs.
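Since variational inference keeps coming up, here is a minimal plain-Python sketch of the objective VI optimizes: a Monte Carlo estimate of the evidence lower bound (ELBO) for a Gaussian approximation q. The conjugate model (theta ~ N(0, 1), y_i ~ N(theta, 1)) and all names are hypothetical, chosen because the exact posterior is known in closed form. This is not Pyro's or NumPyro's API; those libraries build this objective from a model and a guide and optimize it with autodiff.

```python
import math
import random


def normal_logpdf(x, mean, sd):
    """Log density of N(mean, sd^2) at x."""
    return -0.5 * ((x - mean) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2.0 * math.pi)


def log_joint(theta, ys):
    """log p(theta, y) for the hypothetical model theta ~ N(0, 1), y_i ~ N(theta, 1)."""
    return normal_logpdf(theta, 0.0, 1.0) + sum(normal_logpdf(y, theta, 1.0) for y in ys)


def elbo_estimate(m, s, ys, n_samples=2000, seed=0):
    """Monte Carlo estimate of the ELBO for the Gaussian approximation q(theta) = N(m, s^2).

    ELBO(q) = E_q[log p(theta, y) - log q(theta)] = log p(y) - KL(q || posterior),
    so it is a lower bound on log p(y) and is tight exactly when q is the posterior.
    Draws use the reparameterization theta = m + s * eps, eps ~ N(0, 1), which is what
    lets autodiff frameworks push gradients with respect to (m, s) through the samples.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        theta = m + s * rng.gauss(0.0, 1.0)
        total += log_joint(theta, ys) - normal_logpdf(theta, m, s)
    return total / n_samples


ys = [0.8, 1.2, 0.5, 1.9, 1.1]
# For this conjugate model the exact posterior is N(sum(ys) / (n + 1), 1 / (n + 1)).
post_m, post_s = sum(ys) / (len(ys) + 1), 1.0 / math.sqrt(len(ys) + 1)
elbo_exact_q = elbo_estimate(post_m, post_s, ys)
elbo_shifted_q = elbo_estimate(post_m + 2.0, post_s, ys)
```

A Gaussian q is the simplest flow (one affine layer applied to a standard normal). Flow-based VI, as in the quote at the top of the thread, replaces it with a richer stack of invertible layers and maximizes the same objective.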

Stan is less expressive than Pyro/NumPyro. But for the models it can handle (generally medium-sized multilevel models), I find it extremely easy to work with; in particular, it's much easier to diagnose model and sampling issues.
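One of the standard checks behind diagnosing sampling issues is comparing several chains. Below is a minimal plain-Python sketch of the classic split-R-hat statistic (function and variable names are hypothetical). Newer tools typically report a rank-normalized refinement of it, so treat this as the textbook version rather than a reproduction of any particular tool's output.

```python
import math
import random


def split_rhat(chains):
    """Classic split-R-hat for equal-length chains of scalar draws.

    Splitting each chain in half means a trend within a single chain also inflates R-hat.
    Values near 1 suggest the chains agree; values well above 1 flag a sampling problem.
    """
    halves = []
    for chain in chains:
        half = len(chain) // 2
        halves.extend([chain[:half], chain[half:2 * half]])
    n = len(halves[0])  # draws per half-chain (needs n >= 2)
    means = [sum(h) / n for h in halves]
    grand_mean = sum(means) / len(means)
    # B: between-chain variance (scaled by n); W: average within-chain variance.
    b = n * sum((mu - grand_mean) ** 2 for mu in means) / (len(halves) - 1)
    w = sum(sum((x - mu) ** 2 for x in h) / (n - 1) for h, mu in zip(halves, means)) / len(halves)
    var_plus = (n - 1) / n * w + b / n
    return math.sqrt(var_plus / w)


rng = random.Random(1)
# Four well-mixed chains targeting the same N(0, 1) distribution.
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
# Four "stuck" chains, each centred on a different value.
stuck = [[rng.gauss(float(k), 1.0) for _ in range(1000)] for k in range(4)]
```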

Myrmornis 5 mins ago
I'm curious about the involvement of tech companies here. Obviously, approximating posterior distributions of explicit statistical models via simulation is common in the academic scientific literature, but I'd like to hear about examples of it being done in "production" settings, i.e. not just as a one-off analysis. I've long had a vague belief that production settings usually opt for heuristics, point estimates, etc., but I haven't been involved with this sort of thing for a while.
nextos 5 mins ago
Pyro was created by Uber AI Labs, or more precisely by Geometric Intelligence, which was eventually acquired by Uber. Geometric Intelligence was founded by Gary Marcus, Zoubin Ghahramani and others. They also had Noah Goodman on board.

AFAIK, Pyro was used in production to make predictions of demand with careful consideration of uncertainty. I was contacted by one of their recruiters when I was doing work in this area, and this was the application they showcased.

Meta is also doing a lot of related work on time-series forecasting with Prophet, which uses Stan under the hood. In both cases, Bayesian methods are important for making inference robust; it's not just an academic exercise.

techwizrd 5 mins ago
This is coming at the perfect time! I was recently trying to decide whether to implement a model in Stan or Pyro/NumPyro, and I've been eyeing implementing it in JAX. I would love to write a tutorial comparing Stan to JAX.
helltone 5 mins ago
Off topic: I think there are some opportunities for making Bayesian inference technology more accessible, and I'd love to chat with other people in this space. Email in my profile.
Iwan-Zotow 5 mins ago
This is a great development!