(Probabilistic programming language) GenJAX is a probabilistic programming language (PPL): a system which provides automation for writing programs which perform computations on probability distributions, including sampling, variational approximation, gradient estimation, and more.
(With programmable inference) The design of GenJAX is centered on programmable inference (Mansinghka et al. 2018): automation which allows users to express and customize Bayesian inference algorithms (algorithms for computing with posterior distributions: “x affects y, and I observe y, what are my new beliefs about x?”). Programmable inference supports advanced forms of Monte Carlo and variational inference methods.
Following (Cusumano-Towner et al. 2019), GenJAX’s automation is based on two key concepts: parallel generative functions (GenJAX’s version of probabilistic programs) and traces (samples from probabilistic programs). GenJAX provides:
Modeling language automation for constructing complex probability distributions.
Inference automation for constructing Monte Carlo samplers and variational inference algorithms, including advanced algorithms which utilize marginalization, or complex variational objectives and gradient estimation strategies.
(Fully vectorized & compatible with JAX) All of GenJAX’s automation is compatible with JAX, so any program written in GenJAX can be vmap’d and jit-compiled.
Writing a probabilistic model involves telling a story about how the data might have been generated. Inference is the process of inverting that story, to attempt to construct a representation of the elements of the story which coheres with the data.
Another way to think about it: you’re authoring a world whose behavior can give rise to the data, and we’re exploring queries like “I see the data, now what was the behavior, probably?”
Even a regression model follows this pattern. GenJAX supports convenient syntax to express programs that denote probability distributions (over these worlds!). The program below defines a polynomial regression model with a prior over coefficients ("alpha").
# "Authoring a world" in GenJAX:
# * convenient syntax for denoting random variables
# * modeling constructs to build larger distributions from
#   small ones
# * compatibility with JAX computations
@gen
def regression(x):
    # Addresses like "alpha" denote random variables.
    coefficients = normal.repeat(n=3)(0.0, 1.0) @ "alpha"

    # The `@gen` decorator creates a probabilistic program
    # from JAX-compatible Python source code.
    @gen
    def generate_y(x):
        basis_value = array([1.0, x, x**2])
        polynomial_value = sum(
            basis_value * coefficients,
        )
        y = normal(polynomial_value, 0.2) @ "v"
        return y

    # Probabilistic programs can be transformed
    # into new ones: here, `generate_y.vmap` creates
    # a new probabilistic program which applies itself
    # independently to the elements of `x`.
    return generate_y.vmap(in_axes=0)(x) @ "y"


# Sample a curve.
x_range = jnp.linspace(-1, 1, 100)
y_samples = regression.simulate((x_range,)).get_retval()
dot(x_range, y_samples, aspect_ratio=1)
Samples from a probabilistic program defining a distribution which can be used as a regression model. Points sampled noisily near polynomial curves.
The @gen decorator creates a parallel generative function, a probabilistic program datatype which implements an interface that provides automation for sampling and density computation. The interface is called the generative function interface, or GFI for short.
isinstance(regression, GFI)
GenJAX’s GFI consists of 3 methods (simulate, assess, and update), which are shown below.
GFI.simulate
# Sample a trace.
trace = regression.simulate((x_range,))
trace
A trace is a recording of the execution of a parallel generative function. It contains the random choices which were sampled during the execution, as well as other data associated with the execution (the arguments, the return value).
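Importantly, a trace also contains a quantity called the score, which is a recording of \(1 / P(\text{random choices})\). (The PJAX listing for model.simulate later on this page accumulates negated log densities, which suggests the score is stored in log space, as \(\log (1 / P(\text{random choices}))\).)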
trace.get_score()
GFI.assess
# Evaluate the density of random choices, and
# the return value given those choices.
choices = trace.get_choices()
density, retval = regression.assess((x_range,), choices)
density
The assess method gives you access to \(P(\text{random choices})\). simulate and assess pair together to allow you to implement importance samplers, which we’ll see in a moment.
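To make the pairing of sampling and density evaluation concrete, here is a minimal plain-Python sketch of self-normalized importance sampling. It does not use GenJAX: `simulate_prior` and `assess_likelihood` are hypothetical stand-ins for the two roles (draw choices, then score them), and the toy model (x ~ Normal(0, 1), y ~ Normal(x, 0.3), with y = 3.0 observed) is an assumption chosen because its posterior mean is known in closed form.

```python
import math
import random


def normal_logpdf(v, mu, sigma):
    """Log density of Normal(mu, sigma) at v."""
    z = (v - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def simulate_prior(rng):
    """Hypothetical 'simulate' role: draw the latent x from its prior."""
    return rng.gauss(0.0, 1.0)


def assess_likelihood(x, y_obs):
    """Hypothetical 'assess' role: log density of the observation given x."""
    return normal_logpdf(y_obs, x, 0.3)


def importance_posterior_mean(y_obs, num_particles, seed=0):
    """Self-normalized importance sampling with the prior as proposal."""
    rng = random.Random(seed)
    xs = [simulate_prior(rng) for _ in range(num_particles)]
    # With the prior as proposal, the weight p(x, y) / q(x) reduces to p(y | x).
    log_ws = [assess_likelihood(x, y_obs) for x in xs]
    # Subtract the max log weight before exponentiating, for stability.
    m = max(log_ws)
    ws = [math.exp(lw - m) for lw in log_ws]
    return sum(w * x for w, x in zip(ws, xs)) / sum(ws)


estimate = importance_posterior_mean(3.0, 50_000)
# Conjugate-normal posterior mean: y * (1 / 0.3^2) / (1 + 1 / 0.3^2).
exact = 3.0 * (1.0 / 0.09) / (1.0 + 1.0 / 0.09)
```

Because the observation sits three prior standard deviations from the prior mean, only a small fraction of particles carry meaningful weight. That is the kind of setting where a better proposal than the prior pays off.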
GFI.update
# Reweight a trace by changing the arguments to the
# probabilistic program which generated it, or
# change the values of random choices (or both!),
# and compute a density ratio for the change.
new_choices = {"alpha": jnp.array([0.3, 1.0, 2.0])}
new_trace, w, _ = trace.update((x_range,), new_choices)
new_trace["alpha"]
The update method allows you to modify or reweight an existing trace. This method is useful when you wish to implement algorithms whose logic involves making changes to samples, like Markov chain Monte Carlo (MCMC), or variants of sequential Monte Carlo (SMC).
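The role a density ratio plays in MCMC can be sketched in plain Python. The sketch below is not GenJAX's update method: `log_joint` and `update_weight` are hypothetical helpers, the toy model (x ~ Normal(0, 1), y ~ Normal(x, 0.3), with y = 3.0 observed) is an assumption, and the random-walk proposal is one common choice among many.

```python
import math
import random


def normal_logpdf(v, mu, sigma):
    """Log density of Normal(mu, sigma) at v."""
    z = (v - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def log_joint(x, y_obs):
    """Log density of the choices: x ~ Normal(0, 1), y ~ Normal(x, 0.3)."""
    return normal_logpdf(x, 0.0, 1.0) + normal_logpdf(y_obs, x, 0.3)


def update_weight(x_old, x_new, y_obs):
    """Hypothetical 'update' role: log density ratio for changing x."""
    return log_joint(x_new, y_obs) - log_joint(x_old, y_obs)


def metropolis_hastings(y_obs, num_steps, step_size=0.5, seed=0):
    """Random-walk Metropolis-Hastings driven by the density ratio."""
    rng = random.Random(seed)
    x = 0.0
    chain = []
    for _ in range(num_steps):
        proposal = x + rng.gauss(0.0, step_size)
        log_w = update_weight(x, proposal, y_obs)
        # Symmetric proposal: accept with probability min(1, exp(log_w)).
        # 1.0 - random() lies in (0, 1], so the log is always defined.
        if math.log(1.0 - rng.random()) < log_w:
            x = proposal
        chain.append(x)
    return chain


chain = metropolis_hastings(3.0, 20_000)
burned = chain[2_000:]  # discard early samples taken before the chain settles
posterior_mean = sum(burned) / len(burned)
```

The accept/reject step only ever needs the ratio of densities between the old and new sample, which is why an operation that returns that ratio alongside the modified sample is a natural building block.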
Why the GFI?
The methods of the GFI are focused on expressing approximations to inference problems using Monte Carlo sampling and properly weighted approximations. That sentence is very much an “if you know, you know” description. The gist: inference problems are often intractable for analytical methods, and Monte Carlo is a broad class of approximation methods. The GFI focuses on automation support for a subclass of Monte Carlo that the creators of GenJAX have found to be particularly useful in their own work.
Marginalization of random choices
Marginalization provides a way to hide random choices. Exact marginalization involves computing integrals, which is often intractable for complex distributions. GenJAX supports pseudo-marginalization via stochastic probabilities (Lew, Ghavamizadeh, et al. 2023). To support pseudo-marginalization, constructing a marginal requires that you provide a proposal:
@gen
def model():
    x = normal(0.0, 10.0) @ "x"
    y = normal(0.0, 10.0) @ "y"
    rs = x**2 + y**2
    z = normal(rs, 0.1 + (rs / 100.0)) @ "z"


@gen
def proposal(*args):
    x = normal(0.0, 10.0) @ "x"
    y = normal(0.0, 10.0) @ "y"


# Tell me the model, what address to marginalize to,
# and a proposal for the other addresses given the
# remaining one.
# vmap(marginal(model, Importance(proposal, 5), "z").simulate, axis_size=1000)(())["z"]
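The idea behind a proposal-based marginal can be sketched without GenJAX. The plain-Python example below estimates a marginal density by averaging importance weights p(x, z) / q(x) over a few proposal samples; that average is an unbiased estimate of p(z). The model (x ~ Normal(0, 1), z ~ Normal(x, 1)), the helper names, and the choice of the prior as proposal are all assumptions made for illustration, picked so that the exact marginal, Normal(0, sqrt(2)), is available for comparison.

```python
import math
import random


def normal_pdf(v, mu, sigma):
    """Density of Normal(mu, sigma) at v."""
    z = (v - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def estimate_marginal_density(z, num_particles, rng):
    """Unbiased estimate of p(z) = integral of p(x) p(z | x) dx."""
    total = 0.0
    for _ in range(num_particles):
        x = rng.gauss(0.0, 1.0)  # x ~ q(x); here q is the prior
        joint = normal_pdf(x, 0.0, 1.0) * normal_pdf(z, x, 1.0)
        proposal = normal_pdf(x, 0.0, 1.0)
        total += joint / proposal  # importance weight p(x, z) / q(x)
    return total / num_particles


rng = random.Random(0)
# Each estimate uses 5 particles; averaging many estimates checks unbiasedness.
estimates = [estimate_marginal_density(1.0, 5, rng) for _ in range(2_000)]
average_estimate = sum(estimates) / len(estimates)
exact_density = normal_pdf(1.0, 0.0, math.sqrt(2.0))
```

Each individual estimate is noisy, but its expectation is the true marginal density. A better proposal reduces the noise without changing that expectation.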
Automatic differentiation of expected values
GenJAX also exposes functionality to support unbiased gradient estimation of expected value objectives (Lew, Huot, et al. 2023).
Using the @expectation decorator creates an Expectation object, which supports jvp_estimate and grad_estimate methods.
For an @expectation-decorated program that draws \(v\) from the Bernoulli sampler flip_enum with parameter \(p\) and returns 0.0 when \(v\) is true and \(-p/2\) otherwise (flip_exact_loss in the comparison below), the meaning corresponds to the following expectation: \[\mathcal{L}(p) := \mathbb{E}_{v \sim Ber(\cdot; p)}[\textbf{if}~v~\textbf{then}~0.0~\textbf{else}~\frac{-p}{2}]\] which we can evaluate analytically: \[\mathcal{L}(p) = (1 - p) \cdot \frac{-p}{2} = \frac{p^2 - p}{2}\] and whose \(\nabla_p\) we can also evaluate analytically: \[\nabla_p\mathcal{L}(p) = p - \frac{1}{2}\]
The methods jvp_estimate and grad_estimate provide access to gradient estimators for the expected value objective \(\mathcal{L}(p)\).
In the @expectation-decorated program, users tell the automation which gradient estimator to construct by using samplers equipped with estimation strategies. For example, flip_enum is a Bernoulli sampler annotated to direct ADEV’s automation to use enumeration, which evaluates the expectation exactly, when constructing a gradient estimator.
# Compare ADEV's derived derivatives with the exact value.
for p in [0.1, 0.3, 0.5, 0.7, 0.9]:
    p_dual = flip_exact_loss.jvp_estimate(Dual(p, 1.0))
    treescope.show(p_dual.tangent - (p - 0.5))
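The enumeration strategy can be checked by hand. The plain-Python sketch below is not ADEV's implementation, and `branch` and `flip_enum_jvp` are hypothetical names. It writes the expectation over a Bernoulli variable as p * f(True, p) + (1 - p) * f(False, p) and differentiates it with the product rule, using the same integrand as the expectation above (0.0 when v is true, -p/2 otherwise).

```python
def branch(v, p):
    """Value and derivative in p of the integrand: 0.0 if v else -p / 2."""
    if v:
        return 0.0, 0.0
    return -p / 2.0, -0.5


def flip_enum_jvp(p, p_tangent):
    """Exact expectation over v ~ Bernoulli(p) and its directional derivative."""
    val_true, dval_true = branch(True, p)
    val_false, dval_false = branch(False, p)
    value = p * val_true + (1.0 - p) * val_false
    # Product rule: both the branch probabilities and the branch values depend on p.
    dvalue = (val_true - val_false) + p * dval_true + (1.0 - p) * dval_false
    return value, dvalue * p_tangent


# Difference between the enumerated derivative and the analytic p - 1/2.
errors = [flip_enum_jvp(p, 1.0)[1] - (p - 0.5) for p in [0.1, 0.3, 0.5, 0.7, 0.9]]
```

Because enumeration sums over both outcomes instead of sampling one, the derivative carries no Monte Carlo noise and the differences are zero up to floating-point rounding.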
Programmable variational inference
ADEV provides access to unbiased gradient estimators of expected value objectives, and expected value objectives occur often in variational inference, where users define optimization problems over spaces of distributions, often using some notion of closeness defined via an expected value.
For instance, given the density of an unnormalized measure \(P\) and parametric variational approximation \(Q(\cdot; \theta)\), the evidence lower bound objective \[\mathbb{E}_{x \sim Q(\cdot; \theta)}[\log P(x) - \log Q(x; \theta)]\] is often used to define inference problems as optimization, where each update \(\theta \mapsto \theta'\) aims to increase the objective, which shrinks the KL divergence from \(Q(\cdot; \theta)\) to the normalized version of \(P\).
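The connection between the objective and the KL divergence follows in one step. Writing \(Z = \int P(x)\,dx\) for the normalizing constant of \(P\) (a symbol introduced here for the derivation, and assumed finite):

```latex
\begin{aligned}
\mathbb{E}_{x \sim Q(\cdot; \theta)}[\log P(x) - \log Q(x; \theta)]
  &= \log Z + \mathbb{E}_{x \sim Q(\cdot; \theta)}\left[\log \frac{P(x)}{Z} - \log Q(x; \theta)\right] \\
  &= \log Z - \mathrm{KL}\left(Q(\cdot; \theta) \,\|\, P / Z\right)
\end{aligned}
```

Since \(\log Z\) does not depend on \(\theta\), any increase in the objective is exactly a decrease in the KL divergence, and the objective is bounded above by \(\log Z\).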
@gen
def variational_model():
    x = normal(0.0, 1.0) @ "x"
    y = normal(x, 0.3) @ "y"


@gen
def variational_family(theta):
    # Use distribution with a gradient strategy!
    x = normal_reinforce(theta, 1.0) @ "x"


@expectation
def elbo(family, data: dict, theta):
    # Use GFI methods to structure the objective function!
    tr = family.simulate((theta,))
    q = tr.get_score()
    p, _ = variational_model.assess((), {**data, **tr.get_choices()})
    return p - q


def optimize(family, data, init_theta):
    def update(theta, _):
        _, _, theta_grad = elbo.grad_estimate(family, data, theta)
        theta += 1e-3 * theta_grad
        return theta, theta

    final_theta, intermediate_thetas = scan(
        update,
        init_theta,
        length=500,
    )
    return final_theta, intermediate_thetas


# `seed`: seed any sampling with fresh random keys.
# (GenJAX will send you a warning if you need to do this)
_, thetas = seed(optimize)(
    jrand.key(1),
    variational_family,
    {"y": 3.0},
    0.01,
)
dot(jnp.arange(500), thetas)
The theta parameter over the course of training with a REINFORCE gradient estimator.
What’s programmable about it?
In programmable variational inference (Becker et al. 2024), users are allowed to change their objective function (by writing programs which denote objectives), and they are also allowed to change the unbiased gradient estimator strategy for the objective.
For instance, instead of using normal_reinforce, we could use normal_reparam.
@gen
def reparam_variational_family(theta):
    # Use distribution with a gradient strategy!
    x = normal_reparam(theta, 1.0) @ "x"


_, thetas = seed(optimize)(
    jrand.key(1),
    reparam_variational_family,
    {"y": 3.0},
    0.01,
)
dot(jnp.arange(500), thetas)
The theta parameter over the course of training with a reparametrization gradient estimator.
Switching to normal_reparam leads to a significantly less noisy training process. Trying out different objective functions and unbiased gradient estimators is an important part of designing variational inference algorithms, and programmable variational inference tries to make this more convenient.
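Why one estimator can be less noisy than another, even when both are unbiased, can be illustrated outside GenJAX. The plain-Python sketch below compares a score-function (REINFORCE) estimator with a reparametrization estimator for the gradient of E[x^2] with x ~ Normal(theta, 1), whose exact value is 2 * theta. The objective x^2 and all function names are assumptions chosen for illustration; this is not the ELBO from the example above.

```python
import random


def reinforce_grad(theta, eps):
    """Score-function estimate of d/dtheta E[x^2], with x ~ Normal(theta, 1)."""
    x = theta + eps
    # d/dtheta log Normal(x; theta, 1) = x - theta
    return (x * x) * (x - theta)


def reparam_grad(theta, eps):
    """Reparametrization estimate: x = theta + eps, so d(x^2)/dtheta = 2x."""
    x = theta + eps
    return 2.0 * x


def mean_and_variance(values):
    """Sample mean and unbiased sample variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / (len(values) - 1)
    return mean, var


rng = random.Random(0)
theta = 1.0
noise = [rng.gauss(0.0, 1.0) for _ in range(20_000)]
reinforce_mean, reinforce_var = mean_and_variance(
    [reinforce_grad(theta, e) for e in noise]
)
reparam_mean, reparam_var = mean_and_variance(
    [reparam_grad(theta, e) for e in noise]
)
```

Both sample means land near the exact gradient of 2.0, but the per-sample variance of the score-function estimator is several times larger for this objective. In a stochastic optimization loop, that shows up as a noisier parameter trajectory.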
Programmable Monte Carlo
GenJAX’s GFI is designed to provide users with the ability to construct customized Monte Carlo algorithms, producing better quality approximations in difficult inference settings. Paired with expressive modeling syntax, GenJAX allows users to construct complex distributions concisely, and develop effective sampling (and variational) approximations.
Case study: inferencing the Game of Life
The Game of Life is a computational system which gives rise to a bewildering array of interesting phenomena. One interesting theoretical question: given a fixed Life configuration, is it possible to find a configuration that precedes it? This is known as atavising. Exact atavisation is a complex, high-dimensional discrete search problem. By relaxing the problem to noisily atavise (by adding a bit of probability), we can construct algorithms that build approximate predecessor states almost instantaneously on modern hardware.
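The text does not spell out the noise model, so the following is only a minimal plain-Python sketch of what "adding a bit of probability" could look like. It assumes each cell of the stepped board is independently flipped with a hypothetical probability `flip_prob`, and that cells beyond the grid edge are dead. Under that assumption, a candidate predecessor can be scored by how well its successor matches the observed board.

```python
import math


def life_step(board):
    """One Game of Life step on a list-of-lists grid; cells off the edge are dead."""
    rows, cols = len(board), len(board[0])
    new_board = [[0] * cols for _ in range(rows)]
    for r in range(rows):
        for c in range(cols):
            neighbors = sum(
                board[rr][cc]
                for rr in range(max(r - 1, 0), min(r + 2, rows))
                for cc in range(max(c - 1, 0), min(c + 2, cols))
                if (rr, cc) != (r, c)
            )
            alive = board[r][c] == 1
            # Birth on exactly 3 neighbors; survival on 2 or 3.
            new_board[r][c] = int(neighbors == 3 or (alive and neighbors == 2))
    return new_board


def noisy_log_likelihood(candidate, observed, flip_prob=0.03):
    """Log probability of `observed` if each stepped cell flips with `flip_prob`."""
    stepped = life_step(candidate)
    total = 0.0
    for row_s, row_o in zip(stepped, observed):
        for s, o in zip(row_s, row_o):
            total += math.log(1.0 - flip_prob if s == o else flip_prob)
    return total


# A "blinker": three live cells in a row, which alternates orientation each step.
horizontal = [
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
]
vertical = life_step(horizontal)
empty = [[0] * 5 for _ in range(5)]
```

With this score, the true predecessor of the vertical blinker rates higher than an empty board, and a sampler can make local changes to a candidate and keep those that raise the score. Near misses still receive nonzero probability, which is what turns the hard discrete search into an approximate inference problem.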
Under the hood: PJAX
All of GenJAX’s functionality is constructed on top of a vectorizable probabilistic intermediate representation called PJAX.
PJAX is a modular extension to JAX that explicitly represents operations on probability distributions as first class primitives.
def sampler():
    v = normal.sample(0.0, 1.0)
    return jnp.exp(v)


make_jaxpr(sampler)()
{ lambda ; . let
a:f32[] = pjax.assume[name=Normal] 0.0 1.0
b:f32[] = exp a
in (b,) }
PJAX supports a vmap-like transformation, which is an extension of jax.vmap to natively work with operations on probability distributions.
make_jaxpr(pjax_vmap(sampler, axis_size=10))()
{ lambda ; . let
a:f32[10] = pjax.assume[name=Normal] 0.0 1.0
b:f32[10] = exp a
in (b,) }
GenJAX’s GFI is implemented in terms of PJAX:
@gen
def model():
    x = normal(0.0, 1.0) @ "x"
    y = normal(0.0, 1.0) @ "y"
    return x + y


make_jaxpr(model.simulate)(())
{ lambda ; . let
a:f32[] = pjax.assume[name=Normal] 0.0 1.0
b:f32[] = pjax.log_density[name=None] a 0.0 1.0
c:f32[] = neg b
d:f32[] = add 0.0 c
e:f32[] = pjax.assume[name=Normal] 0.0 1.0
f:f32[] = pjax.log_density[name=None] e 0.0 1.0
g:f32[] = neg f
h:f32[] = add d g
i:f32[] = add a e
in (0.0, 1.0, a, a, c, 0.0, 1.0, e, e, g, i, h) }
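Read line by line, the listing samples each choice, evaluates its log density, negates it, and adds it to a running total, then computes the return value. A plain-Python transcription makes that structure easier to follow. Python's random module stands in for pjax.assume here, and the helper names are hypothetical.

```python
import math
import random


def normal_log_density(v, mu, sigma):
    """Log density of Normal(mu, sigma) at v."""
    z = (v - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def simulate_by_hand(rng):
    """Mirror the listing: sample, take a log density, negate, accumulate."""
    score = 0.0
    x = rng.gauss(0.0, 1.0)                     # a = pjax.assume[name=Normal]
    score += -normal_log_density(x, 0.0, 1.0)   # b, c, d
    y = rng.gauss(0.0, 1.0)                     # e = pjax.assume[name=Normal]
    score += -normal_log_density(y, 0.0, 1.0)   # f, g, h
    retval = x + y                              # i
    return {"x": x, "y": y}, retval, score


choices, retval, score = simulate_by_hand(random.Random(0))
```

The remaining outputs of the listing (the constants 0.0 and 1.0 and the repeated a and e) are presumably the recorded arguments and values of each sampled choice, though the listing alone does not confirm how they are packaged into a trace.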
Even ADEV’s unbiased gradient estimator programs are implemented in terms of PJAX:
@expectation
def loss(mu):
    x = normal_reparam.sample(mu, 1.0)
    return x**2


make_jaxpr(loss.grad_estimate)(0.1)
{ lambda ; a:f32[]. let
b:f32[] = pjax.assume[name=Normal] 0.0 1.0
c:f32[] = mul 1.0 b
_:f32[] = mul 0.0 b
d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
e:f32[] = add d c
_:f32[] = integer_pow[y=2] e
f:f32[] = integer_pow[y=1] e
g:f32[] = mul 2.0 f
h:f32[] = mul 1.0 g
i:f32[] = convert_element_type[new_dtype=float32 weak_type=True] h
in (i,) }
Advanced: GFI theory
Future work & sharp edges
The creators of GenJAX have intended for GenJAX to be a useful and somewhat general design for GPU-accelerated probabilistic programming! Of course, that’s not always possible: there are known sharp edges when using GenJAX’s automation, which we tabulate below.
Known incompatibilities between features
vmap within ADEV programs
The semantics of vmap within ADEV programs is a direction of future research. vmap is a vectorization transformation that converts primitives into batched versions of themselves. For ADEV’s samplers, the primitives come equipped with gradient estimation strategies. Therefore, using vmap on code which contains ADEV samplers requires care: not all primitives support “batched” gradient estimation strategies.
References
Becker, McCoy R., Alexander K. Lew, Xiaoyan Wang, Matin Ghavami, Mathieu Huot, Martin C. Rinard, and Vikash K. Mansinghka. 2024. “Probabilistic Programming with Programmable Variational Inference.” Proc. ACM Program. Lang. 8 (PLDI): 233:2123–47. https://doi.org/10.1145/3656463.
Cusumano-Towner, Marco F., Feras A. Saad, Alexander K. Lew, and Vikash K. Mansinghka. 2019. “Gen: A General-Purpose Probabilistic Programming System with Programmable Inference.” In Proceedings of the 40th ACM SIGPLAN Conference on Programming Language Design and Implementation, 221–36. PLDI 2019. New York, NY, USA: Association for Computing Machinery. https://doi.org/10.1145/3314221.3314642.
Lew, Alexander K., Matin Ghavamizadeh, Martin C. Rinard, and Vikash K. Mansinghka. 2023. “Probabilistic Programming with Stochastic Probabilities.” Proc. ACM Program. Lang. 7 (PLDI): 176:1708–32. https://doi.org/10.1145/3591290.
Lew, Alexander K., Mathieu Huot, Sam Staton, and Vikash K. Mansinghka. 2023. “ADEV: Sound Automatic Differentiation of Expected Values of Probabilistic Programs.” Proc. ACM Program. Lang. 7 (POPL): 5:121–53. https://doi.org/10.1145/3571198.
Mansinghka, Vikash K., Ulrich Schaechtle, Shivam Handa, Alexey Radul, Yutian Chen, and Martin Rinard. 2018. “Probabilistic Programming with Programmable Inference.” In Proceedings of the 39th ACM SIGPLAN Conference on Programming Language Design and Implementation, 603–16. PLDI 2018. New York, NY, USA: Association for Computing Machinery. https://doi.org/10.1145/3192366.3192409.
Copyright
Copyright McCoy Reynolds Becker & MIT Probabilistic Computing Project. All Rights Reserved.