The Generative Cookbook

(Probabilistic programming language) GenJAX is a probabilistic programming language (PPL): a system which provides automation for writing programs which perform computations on probability distributions, including sampling, variational approximation, gradient estimation, and more.

(With programmable inference) The design of GenJAX is centered on programmable inference (Mansinghka et al. 2018): automation which allows users to express and customize Bayesian inference algorithms (algorithms for computing with posterior distributions: “x affects y, and I observe y, what are my new beliefs about x?”). Programmable inference supports advanced forms of Monte Carlo and variational inference methods.

Following (Cusumano-Towner et al. 2019), GenJAX’s automation is based on two key concepts: parallel generative functions (GenJAX’s version of probabilistic programs) and traces (samples from probabilistic programs). GenJAX provides:

(Fully vectorized & compatible with JAX) All of GenJAX’s automation is compatible with JAX, so any program written in GenJAX can be vmap’d and jit-compiled.

Prelude
import genstudio.plot as Plot
import jax.numpy as jnp
import jax.random as jrand
import jax.tree_util as jtu
import treescope
from jax import make_jaxpr, vmap
from jax.lax import cond, scan
from jax.numpy import array, sum

from genjax import (
    GFI,
    Trace,
    flip,
    gen,
    normal,
    normal_reinforce,
    normal_reparam,
    seed,
    trace,
)
from genjax import modular_vmap as pjax_vmap
from genjax.adev import Dual, expectation, flip_enum

treescope.basic_interactive_setup(autovisualize_arrays=False)


def dot(x, y, aspect_ratio=None):
    points = list(zip(x, y))
    plot = (
        Plot.dot(
            points, fill="black", stroke="white", opacity=1.0, strokeWidth=0.5, r=2.5
        )
        + Plot.frame()
    )
    plot = plot + Plot.aspectRatio(aspect_ratio) if aspect_ratio else plot
    return plot


def histogram(data, **kwargs):
    return Plot.histogram(data, **kwargs)


def grid(*plots):
    return Plot.Grid(*plots)

Modeling & inference with GenJAX

Writing a probabilistic model involves telling a story about how the data might have been generated. Inference is the process of inverting that story, to attempt to construct a representation of the elements of the story which coheres with the data.

Another way to think about it: you’re authoring a world whose behavior can give rise to the data, and we’re exploring queries like “I see the data, now what was the behavior, probably?”

Even a regression model follows this pattern. GenJAX supports convenient syntax to express programs that denote probability distributions (over these worlds!). The program below defines a polynomial regression model with a prior over coefficients ("alpha").

# "Authoring a world" in GenJAX:
# * convenient syntax for denoting random variables
# * modeling constructs to build larger distributions from
#   small ones
# * compatibility with JAX computations
@gen
def regression(x):
    # Addresses like "alpha" denote random variables.
    coefficients = normal.repeat(n=3)(0.0, 1.0) @ "alpha"

    # The `@gen` decorator creates a probabilistic program
    # from JAX-compatible Python source code.
    @gen
    def generate_y(x):
        basis_value = array([1.0, x, x**2])
        polynomial_value = sum(
            basis_value * coefficients,
        )
        y = normal(polynomial_value, 0.2) @ "v"
        return y

    # Probabilistic programs can be transformed
    # into new ones: here, `generate_y.vmap` creates
    # a new probabilistic program which applies itself
    # independently to the elements of `x`.
    return generate_y.vmap(in_axes=0)(x) @ "y"


# Sample a curve.
x_range = jnp.linspace(-1, 1, 100)
y_samples = regression.simulate((x_range,)).get_retval()
dot(x_range, y_samples, aspect_ratio=1)

Samples from a probabilistic program defining a distribution which can be used as a regression model. Points sampled noisily near polynomial curves.

The @gen decorator creates a parallel generative function, a probabilistic program datatype which implements an interface that provides automation for sampling and density computation. The interface is called the generative function interface, or GFI for short.

isinstance(regression, GFI)

GenJAX’s GFI consists of 3 methods (simulate, assess, and update), which are shown below.

GFI.simulate

# Sample a trace.
trace = regression.simulate((x_range,))
trace

A trace is a recording of the execution of a parallel generative function. It contains the random choices which were sampled during the execution, as well as other data associated with the execution (the arguments, the return value).

Importantly, a trace also contains a quantity called the score, which is a recording of \(\log (1 / P(\text{random choices}))\), the negated log density of the sampled choices.

trace.get_score()
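To make the contents of a trace concrete, here is a minimal plain-Python sketch of what a simulate call records for a model with two addressed choices. It is not GenJAX's implementation: `ToyTrace`, `normal_logpdf`, and `toy_simulate` are hypothetical names, and the score convention (negated log density, accumulated per choice) is an assumption read off the `neg`/`add` pattern in the jaxpr of `model.simulate` in the PJAX section.

```python
import math
import random
from dataclasses import dataclass


def normal_logpdf(v, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at v."""
    z = (v - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


@dataclass
class ToyTrace:
    args: tuple
    choices: dict  # address -> sampled value
    retval: float
    score: float  # negated log density of the choices


def toy_simulate(rng):
    # Sample each addressed choice and accumulate the score as we go.
    x = rng.gauss(0.0, 1.0)
    y = rng.gauss(0.0, 1.0)
    score = -normal_logpdf(x, 0.0, 1.0) - normal_logpdf(y, 0.0, 1.0)
    return ToyTrace(args=(), choices={"x": x, "y": y}, retval=x + y, score=score)


toy_trace = toy_simulate(random.Random(0))
```

The trace carries everything needed to revisit the execution later: the choices by address, the return value, and a single scalar summarizing how probable those choices were.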

GFI.assess

# Evaluate the density of random choices, and
# the return value given those choices.
choices = trace.get_choices()
density, retval = regression.assess((x_range,), choices)
density

The assess method gives you access to \(P(\text{random choices})\). simulate and assess pair together to allow you to implement importance samplers, which we’ll see in a moment.
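As a rough illustration of how a sampler with a score and a density evaluator combine, the following plain-Python sketch implements a self-normalized importance sampler. It does not call GenJAX: `propose` stands in for the role of simulate, `model_logpdf` for the role of assess, and the toy model (x ~ Normal(0, 1), y ~ Normal(x, 0.3), observed y = 3.0) and the proposal centered on the observation are assumptions chosen so the exact posterior mean is available for comparison.

```python
import math
import random


def normal_logpdf(v, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at v."""
    z = (v - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def model_logpdf(x, y):
    # Joint log density of the model: x ~ N(0, 1), y ~ N(x, 0.3).
    return normal_logpdf(x, 0.0, 1.0) + normal_logpdf(y, x, 0.3)


def propose(rng, y):
    # Draw x from the proposal N(y, 1) and report its log density.
    x = rng.gauss(y, 1.0)
    return x, normal_logpdf(x, y, 1.0)


def posterior_mean(y_obs, n_particles, seed=0):
    rng = random.Random(seed)
    xs, log_ws = [], []
    for _ in range(n_particles):
        x, log_q = propose(rng, y_obs)
        xs.append(x)
        # Importance weight in log space: log p(x, y) - log q(x).
        log_ws.append(model_logpdf(x, y_obs) - log_q)
    # Subtract the max before exponentiating for numerical stability.
    m = max(log_ws)
    ws = [math.exp(lw - m) for lw in log_ws]
    return sum(w * x for w, x in zip(ws, xs)) / sum(ws)


estimate = posterior_mean(3.0, 5000)
# Conjugate normal-normal posterior mean, for comparison.
exact = (3.0 / 0.3**2) / (1.0 + 1.0 / 0.3**2)
```

The weights only ever need the ratio of a model density to a proposal density, which is why a sampling method and a density method are enough to express the whole algorithm.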

GFI.update

# Reweight a trace by changing the arguments to the
# probabilistic program which generated it, or
# change the values of random choices (or both!),
# and compute a density ratio for the change.
new_choices = {"alpha": jnp.array([0.3, 1.0, 2.0])}
new_trace, w, _ = trace.update((x_range,), new_choices)
new_trace["alpha"]

The update method allows you to modify or reweight an existing trace. This method is useful when you wish to implement algorithms whose logic involves making changes to samples, like Markov chain Monte Carlo (MCMC), or variants of sequential Monte Carlo (SMC).
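To show why a density ratio is the quantity MCMC needs, here is a small plain-Python sketch of random-walk Metropolis-Hastings driven by an update-style weight. Nothing here is GenJAX API: `toy_update` is a hypothetical helper returning the new value together with the log density ratio, and the target (x ~ Normal(0, 1), y ~ Normal(x, 0.3), observed y = 3.0) and step size are assumptions.

```python
import math
import random


def normal_logpdf(v, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at v."""
    z = (v - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def log_joint(x, y):
    return normal_logpdf(x, 0.0, 1.0) + normal_logpdf(y, x, 0.3)


def toy_update(x_old, x_new, y):
    # Return the changed choice and the log density ratio for the change.
    return x_new, log_joint(x_new, y) - log_joint(x_old, y)


def random_walk_mh(y, n_steps, step=0.5, seed=0):
    rng = random.Random(seed)
    x = 0.0
    samples = []
    for _ in range(n_steps):
        proposed = x + rng.gauss(0.0, step)
        x_new, log_w = toy_update(x, proposed, y)
        # The proposal is symmetric, so accept with probability min(1, exp(log_w)).
        if rng.random() < math.exp(min(0.0, log_w)):
            x = x_new
        samples.append(x)
    return samples


samples = random_walk_mh(3.0, 5000)
kept = samples[1000:]  # discard burn-in
chain_mean = sum(kept) / len(kept)
```

The accept/reject step never evaluates a normalized posterior; the ratio returned by the update is all it consumes.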

Why the GFI?

The methods of the GFI are focused on expressing approximations to inference problems using Monte Carlo sampling and properly weighted approximations. That sentence is already an “if you know, you know” description. The gist: inference problems are often intractable for analytical methods, and Monte Carlo is a broad class of approximation methods. The GFI focuses on automation support for a subclass of Monte Carlo that the creators of GenJAX have found to be particularly useful in their own work.

Marginalization of random choices

Marginalization provides a way to hide random choices. Exact marginalization involves computing integrals, which is often intractable for complex distributions. GenJAX supports pseudo-marginalization via stochastic probabilities (Lew, Ghavamizadeh, et al. 2023). To support pseudo-marginalization, constructing a marginal requires that you provide a proposal:

@gen
def model():
    x = normal(0.0, 10.0) @ "x"
    y = normal(0.0, 10.0) @ "y"
    rs = x**2 + y**2
    z = normal(rs, 0.1 + (rs / 100.0)) @ "z"


@gen
def proposal(*args):
    x = normal(0.0, 10.0) @ "x"
    y = normal(0.0, 10.0) @ "y"


# Tell me the model, what address to marginalize to,
# and a proposal for the other addresses given the
# remaining one.
# vmap(marginal(model, Importance(proposal, 5), "z").simulate, axis_size=1000)(())["z"]
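The role of the proposal can be illustrated with a plain-Python sketch of an importance-sampling estimate of the marginal density of "z". This is an assumption about the kind of estimate a 5-particle proposal enables, not GenJAX's `marginal` implementation; `normal_pdf` and `estimate_marginal_density` are hypothetical helpers, and the evaluation point z = 50 is arbitrary.

```python
import math
import random


def normal_pdf(v, mu, sigma):
    """Density of Normal(mu, sigma) evaluated at v."""
    z = (v - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def estimate_marginal_density(z, n_particles, rng):
    # The proposal matches the prior over (x, y), so each importance
    # weight p(x, y, z) / q(x, y) reduces to the likelihood p(z | x, y).
    total = 0.0
    for _ in range(n_particles):
        x = rng.gauss(0.0, 10.0)
        y = rng.gauss(0.0, 10.0)
        rs = x**2 + y**2
        total += normal_pdf(z, rs, 0.1 + rs / 100.0)
    return total / n_particles


rng = random.Random(0)
# A single 5-particle estimate is noisy (often close to zero).
single = estimate_marginal_density(50.0, 5, rng)
# Averaging many independent estimates approaches the marginal density.
n_repeats = 40000
averaged = sum(estimate_marginal_density(50.0, 5, rng) for _ in range(n_repeats)) / n_repeats
```

Each individual estimate is a random quantity whose expectation is the marginal density, which is the sense in which the probability is stochastic. A proposal that looks at "z" would concentrate particles where the likelihood is large and reduce the noise.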

Automatic differentiation of expected values

GenJAX also exposes functionality to support unbiased gradient estimation of expected value objectives (Lew, Huot, et al. 2023).

@expectation
def flip_exact_loss(p):
    b = flip_enum(p)
    return cond(
        b,
        lambda _: 0.0,
        lambda p: -p / 2.0,
        p,
    )


flip_exact_loss

Using the @expectation decorator creates an Expectation object, which supports jvp_estimate and grad_estimate methods.

For the above @expectation-decorated program, the meaning corresponds to the following expectation: \[\mathcal{L}(p) := \mathbb{E}_{v \sim Ber(\cdot; p)}[\textbf{if}~v~\textbf{then}~0.0~\textbf{else}~\frac{-p}{2}]\] which we can evaluate analytically: \[\mathcal{L}(p) = (1 - p) * \frac{-p}{2} = \frac{p^2 - p}{2}\] and whose \(\nabla_p\) we can also evaluate analytically: \[\nabla_p\mathcal{L}(p) = p - \frac{1}{2}\]

The methods jvp_estimate and grad_estimate provide access to gradient estimators for the expected value objective \(\mathcal{L}(p)\).

In the @expectation-decorated program, users can inform the automation what gradient estimator they’d like to construct by using samplers equipped with estimation strategies (flip_enum is a Bernoulli sampler with an annotation which directs ADEV’s automation to use enumeration, exactly evaluating the expectation, to construct a gradient estimator).
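The closed forms for \(\mathcal{L}(p)\) and \(\nabla_p\mathcal{L}(p)\) can be sanity-checked without ADEV. The sketch below enumerates both outcomes of the Bernoulli draw to compute the expectation exactly, then differentiates it with a central finite difference. The helper names are hypothetical, and the finite difference is only a checking device, not how ADEV constructs its estimators.

```python
def loss_value(b, p):
    # The two branches of the `cond` in `flip_exact_loss`.
    return 0.0 if b else -p / 2.0


def expected_loss(p):
    # Enumerate both outcomes, weighting each by its probability.
    return p * loss_value(True, p) + (1.0 - p) * loss_value(False, p)


def finite_difference(f, p, eps=1e-6):
    # Central difference approximation of df/dp.
    return (f(p + eps) - f(p - eps)) / (2.0 * eps)


checks = [
    (p, expected_loss(p), finite_difference(expected_loss, p))
    for p in [0.1, 0.3, 0.5, 0.7, 0.9]
]
for p, value, grad in checks:
    # Differences from the analytic expressions (p^2 - p) / 2 and p - 1/2.
    print(p, value - (p**2 - p) / 2.0, grad - (p - 0.5))
```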

# Compare ADEV's derived derivatives with the exact value.
for p in [0.1, 0.3, 0.5, 0.7, 0.9]:
    p_dual = flip_exact_loss.jvp_estimate(Dual(p, 1.0))
    treescope.show(p_dual.tangent - (p - 0.5))

Programmable variational inference

ADEV provides access to unbiased gradient estimators of expected value objectives, and expected value objectives occur often in variational inference, where users define optimization problems over spaces of distributions, often using some notion of closeness defined via an expected value.

For instance, given the density of an unnormalized measure \(P\) and a parametric variational approximation \(Q(\cdot; \theta)\), the evidence lower bound objective \[\mathbb{E}_{x \sim Q(\cdot; \theta)}[\log P(x) - \log Q(x; \theta)]\] is often used to pose inference as optimization: updating \(\theta \mapsto \theta'\) to maximize the objective squeezes the KL divergence between \(Q\) and the normalized version of \(P\).
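A short derivation makes the link between this objective and the KL divergence explicit. The normalizing constant \(Z\) and the normalized density \(\bar{P}\) are notation introduced here for illustration, and the sketch assumes \(Z\) is finite: \[Z := \int P(x)\,dx, \qquad \bar{P}(x) := \frac{P(x)}{Z}\] Substituting \(\log P(x) = \log Z + \log \bar{P}(x)\) into the objective gives \[\mathbb{E}_{x \sim Q(\cdot; \theta)}[\log P(x) - \log Q(x; \theta)] = \log Z + \mathbb{E}_{x \sim Q(\cdot; \theta)}[\log \bar{P}(x) - \log Q(x; \theta)] = \log Z - \mathrm{KL}(Q(\cdot; \theta) \,\|\, \bar{P})\] Because \(\log Z\) does not depend on \(\theta\), any increase in the objective is an equal decrease in the KL divergence.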

@gen
def variational_model():
    x = normal(0.0, 1.0) @ "x"
    y = normal(x, 0.3) @ "y"


@gen
def variational_family(theta):
    # Use distribution with a gradient strategy!
    x = normal_reinforce(theta, 1.0) @ "x"


@expectation
def elbo(family, data: dict, theta):
    # Use GFI methods to structure the objective function!
    tr = family.simulate((theta,))
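    q = tr.get_score()  # the score is the negated log density, -log Q(x; theta)
    p, _ = variational_model.assess((), {**data, **tr.get_choices()})
    # log P(x) - log Q(x; theta): add the score, since it is already negated.
    return p + q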


def optimize(family, data, init_theta):
    def update(theta, _):
        _, _, theta_grad = elbo.grad_estimate(family, data, theta)
        theta += 1e-3 * theta_grad
        return theta, theta

    final_theta, intermediate_thetas = scan(
        update,
        init_theta,
        length=500,
    )
    return final_theta, intermediate_thetas


# `seed`: seed any sampling with fresh random keys.
# (GenJAX will send you a warning if you need to do this)
_, thetas = seed(optimize)(
    jrand.key(1),
    variational_family,
    {"y": 3.0},
    0.01,
)
dot(jnp.arange(500), thetas)

The theta parameter over the course of training with a REINFORCE gradient estimator.

What’s programmable about it?

In programmable variational inference (Becker et al. 2024), users are allowed to change their objective function (by writing programs which denote objectives), and they are also allowed to change the unbiased gradient estimator strategy for the objective.

For instance, instead of using normal_reinforce, we could use normal_reparam.

@gen
def reparam_variational_family(theta):
    # Use distribution with a gradient strategy!
    x = normal_reparam(theta, 1.0) @ "x"


_, thetas = seed(optimize)(
    jrand.key(1),
    reparam_variational_family,
    {"y": 3.0},
    0.01,
)
dot(jnp.arange(500), thetas)

The theta parameter over the course of training with a reparametrization gradient estimator.

Switching to normal_reparam leads to a significantly less noisy training process. Trying out different objective functions and unbiased gradient estimators is an important part of designing variational inference algorithms, and programmable variational inference tries to make this more convenient.
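The difference in noise between the two strategies can be seen in a standalone plain-Python sketch. It compares a score-function (REINFORCE) estimator and a reparameterization estimator for the gradient of \(\mathbb{E}_{x \sim \mathcal{N}(\theta, 1)}[x^2]\), whose exact value is \(2\theta\). The objective, the helper names, and the sample count are assumptions for illustration; this is not the code path that normal_reinforce or normal_reparam take inside GenJAX.

```python
import random
import statistics


def reinforce_sample(theta, rng):
    # Score-function estimator: f(x) * d/dtheta log N(x; theta, 1).
    x = rng.gauss(theta, 1.0)
    return (x**2) * (x - theta)


def reparam_sample(theta, rng):
    # Pathwise estimator: write x = theta + eps and differentiate f(theta + eps).
    eps = rng.gauss(0.0, 1.0)
    return 2.0 * (theta + eps)


rng = random.Random(0)
theta = 1.0
reinforce = [reinforce_sample(theta, rng) for _ in range(20000)]
reparam = [reparam_sample(theta, rng) for _ in range(20000)]

# Both sample means estimate the same gradient; the variances differ.
print("REINFORCE:", statistics.mean(reinforce), statistics.variance(reinforce))
print("reparam:  ", statistics.mean(reparam), statistics.variance(reparam))
```

Both estimators are unbiased, so either can drive the optimizer. The choice affects how many samples or steps are needed to see a clean trend.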

Programmable Monte Carlo

GenJAX’s GFI is designed to provide users with the ability to construct customized Monte Carlo algorithms, producing better quality approximations in difficult inference settings. Paired with expressive modeling syntax, GenJAX allows users to construct complex distributions concisely, and develop effective sampling (and variational) approximations.

Case study: inference in the Game of Life

The Game of Life is a computational system which gives rise to a bewildering array of interesting phenomena. One interesting theoretical question: given a fixed Life configuration, is it possible to find a configuration that precedes it? This is known as atavising. Exact atavisation is a complex, high-dimensional discrete search problem. By relaxing the problem to noisily atavise (by adding a bit of probability), we can construct algorithms that build approximate predecessor states almost instantaneously on modern hardware.
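One way to read "adding a bit of probability" is as a per-cell noise model layered on the deterministic Life rule. The sketch below is plain Python rather than GenJAX. `life_step`, `noisy_log_likelihood`, the dead-border convention, and the flip probability of 0.03 are all illustrative assumptions, not the case study's actual model.

```python
import math


def life_step(board):
    """One Game of Life step on a list-of-lists grid; cells off the grid are dead."""
    rows, cols = len(board), len(board[0])
    new = [[0] * cols for _ in range(rows)]
    for r in range(rows):
        for c in range(cols):
            neighbors = sum(
                board[rr][cc]
                for rr in range(max(0, r - 1), min(rows, r + 2))
                for cc in range(max(0, c - 1), min(cols, c + 2))
                if (rr, cc) != (r, c)
            )
            alive = board[r][c] == 1
            # Birth on exactly 3 neighbors, survival on 2 or 3.
            new[r][c] = int(neighbors == 3 or (alive and neighbors == 2))
    return new


def noisy_log_likelihood(candidate, target, flip_prob=0.03):
    """Log probability of `target` if each cell of life_step(candidate) flips independently."""
    stepped = life_step(candidate)
    logp = 0.0
    for row_s, row_t in zip(stepped, target):
        for s, t in zip(row_s, row_t):
            logp += math.log(1.0 - flip_prob if s == t else flip_prob)
    return logp


# A blinker: the vertical bar is an exact predecessor of the horizontal bar.
vertical = [[1 if (c == 2 and 1 <= r <= 3) else 0 for c in range(5)] for r in range(5)]
horizontal = [[1 if (r == 2 and 1 <= c <= 3) else 0 for c in range(5)] for r in range(5)]
good = noisy_log_likelihood(vertical, horizontal)
bad = noisy_log_likelihood(horizontal, horizontal)
```

Under such a relaxation every candidate predecessor gets a nonzero score, so a sampler can move gradually toward configurations that explain the target instead of searching for an exact match.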

Under the hood: PJAX

All of GenJAX’s functionality is constructed on top of a vectorizable probabilistic intermediate representation called PJAX.

PJAX is a modular extension to JAX that explicitly represents operations on probability distributions as first class primitives.

def sampler():
    v = normal.sample(0.0, 1.0)
    return jnp.exp(v)


make_jaxpr(sampler)()
{ lambda ; . let
    a:f32[] = pjax.assume[name=Normal] 0.0 1.0
    b:f32[] = exp a
  in (b,) }

PJAX supports a vmap-like transformation, which is an extension of jax.vmap to natively work with operations on probability distributions.

make_jaxpr(pjax_vmap(sampler, axis_size=10))()
{ lambda ; . let
    a:f32[10] = pjax.assume[name=Normal] 0.0 1.0
    b:f32[10] = exp a
  in (b,) }

GenJAX’s GFI is implemented in terms of PJAX:

@gen
def model():
    x = normal(0.0, 1.0) @ "x"
    y = normal(0.0, 1.0) @ "y"
    return x + y


make_jaxpr(model.simulate)(())
{ lambda ; . let
    a:f32[] = pjax.assume[name=Normal] 0.0 1.0
    b:f32[] = pjax.log_density[name=None] a 0.0 1.0
    c:f32[] = neg b
    d:f32[] = add 0.0 c
    e:f32[] = pjax.assume[name=Normal] 0.0 1.0
    f:f32[] = pjax.log_density[name=None] e 0.0 1.0
    g:f32[] = neg f
    h:f32[] = add d g
    i:f32[] = add a e
  in (0.0, 1.0, a, a, c, 0.0, 1.0, e, e, g, i, h) }

Even ADEV’s unbiased gradient estimator programs are implemented in terms of PJAX:

@expectation
def loss(mu):
    x = normal_reparam.sample(mu, 1.0)
    return x**2


make_jaxpr(loss.grad_estimate)(0.1)
{ lambda ; a:f32[]. let
    b:f32[] = pjax.assume[name=Normal] 0.0 1.0
    c:f32[] = mul 1.0 b
    _:f32[] = mul 0.0 b
    d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
    e:f32[] = add d c
    _:f32[] = integer_pow[y=2] e
    f:f32[] = integer_pow[y=1] e
    g:f32[] = mul 2.0 f
    h:f32[] = mul 1.0 g
    i:f32[] = convert_element_type[new_dtype=float32 weak_type=True] h
  in (i,) }

Advanced: GFI theory

Future work & sharp edges

The creators of GenJAX have intended for GenJAX to be a useful and somewhat general design for GPU-accelerated probabilistic programming! Of course, that’s not always possible: there are known sharp edges when using GenJAX’s automation, which we tabulate below.

Known incompatibilities between features

vmap within ADEV programs

The semantics of vmap within ADEV programs is a direction of future research. vmap is a vectorization transformation that converts primitives into batched versions of themselves. For ADEV’s samplers, the primitives come equipped with gradient estimation strategies. Therefore, using vmap on code which contains ADEV samplers requires care. Not all primitives support “batched” gradient estimation strategies.

References

Becker, McCoy R., Alexander K. Lew, Xiaoyan Wang, Matin Ghavami, Mathieu Huot, Martin C. Rinard, and Vikash K. Mansinghka. 2024. “Probabilistic Programming with Programmable Variational Inference.” Proc. ACM Program. Lang. 8 (PLDI): 233:2123–47. https://doi.org/10.1145/3656463.
Cusumano-Towner, Marco F., Feras A. Saad, Alexander K. Lew, and Vikash K. Mansinghka. 2019. “Gen: A General-Purpose Probabilistic Programming System with Programmable Inference.” In Proceedings of the 40th ACM SIGPLAN Conference on Programming Language Design and Implementation, 221–36. PLDI 2019. New York, NY, USA: Association for Computing Machinery. https://doi.org/10.1145/3314221.3314642.
Lew, Alexander K., Matin Ghavamizadeh, Martin C. Rinard, and Vikash K. Mansinghka. 2023. “Probabilistic Programming with Stochastic Probabilities.” Proc. ACM Program. Lang. 7 (PLDI): 176:1708–32. https://doi.org/10.1145/3591290.
Lew, Alexander K., Mathieu Huot, Sam Staton, and Vikash K. Mansinghka. 2023. “ADEV: Sound Automatic Differentiation of Expected Values of Probabilistic Programs.” Proc. ACM Program. Lang. 7 (POPL): 5:121–53. https://doi.org/10.1145/3571198.
Mansinghka, Vikash K., Ulrich Schaechtle, Shivam Handa, Alexey Radul, Yutian Chen, and Martin Rinard. 2018. “Probabilistic Programming with Programmable Inference.” In Proceedings of the 39th ACM SIGPLAN Conference on Programming Language Design and Implementation, 603–16. PLDI 2018. New York, NY, USA: Association for Computing Machinery. https://doi.org/10.1145/3192366.3192409.