TensorFlow plugin for the Gen probabilistic programming system

The Julia package GenTF allows Gen generative functions to invoke TensorFlow computations that the TensorFlow runtime executes on the GPU. Users construct a TensorFlow computation using the familiar TensorFlow Python API, and then package it in a TFFunction, a type of generative function provided by GenTF. Generative functions written in Gen's built-in modeling language can call TFFunctions like any other generative function. GenTF integrates Gen's automatic differentiation with TensorFlow's gradients, so computations that combine Julia and TensorFlow code can be differentiated automatically.


Installation

GenTF requires Python and the tensorflow Python package. We recommend creating a Python virtual environment and installing TensorFlow via pip in that environment. In what follows, let <python> stand for the absolute path of a Python executable that has access to the tensorflow package.

From the Julia REPL, type ] to enter the Pkg REPL mode and run:

pkg> add https://github.com/probcomp/GenTF

In a Julia REPL, build the PyCall module so that it will use the correct Python environment:

using Pkg; ENV["PYTHON"] = "<python>"; Pkg.build("PyCall")

Check that the intended Python environment is indeed being used:

using PyCall; println(PyCall.python)

If you encounter problems, see https://github.com/JuliaPy/PyCall.jl#specifying-the-python-version

Calling the TensorFlow Python API

GenTF uses the Julia package PyCall to invoke the TensorFlow Python API.

First, import PyCall:

using PyCall

Then import the tensorflow Python module:

@pyimport tensorflow as tf

To import a module from a subpackage:

@pyimport tensorflow.train as train

Then, call the TensorFlow Python API with syntax that is very close to Python syntax:

init_W = rand(Float32, 2, 3) # example initial value for W (any Float32 matrix with 3 columns works here)
W = tf.get_variable("W", dtype=tf.float32, initializer=init_W)
x = tf.placeholder(tf.float32, shape=(3,), name="x")
y = tf.squeeze(tf.matmul(W, tf.expand_dims(x, axis=1)), axis=1)
sess = tf.Session()
y_val = sess[:run](y, feed_dict=Dict(x => [1., 2., 3.]))

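The example above shows the syntax changes needed in two common situations: a method of a Python object is called as sess[:run](...) rather than sess.run(...), and a Python dict is written as a Julia Dict.

See the PyCall README for the complete description of syntax differences introduced when using Julia and PyCall instead of Python.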

TensorFlow Generative Functions

A TensorFlow computation graph contains both the model(s) being trained and the operations that do the training. In contrast, Gen uses a more rigid separation between models (both generative models and inference models) and the operations that act on models. Specifically, models in Gen are defined as (pure functional) generative functions, and the operations that run or train the models are defined in separate Julia code. The GenTF package allows users to construct deterministic generative functions of type TFFunction <: GenerativeFunction from a TensorFlow computation graph in which each TensorFlow element plays one of the following roles:

| Role in TFFunction | TensorFlow object type |
| --- | --- |
| Argument | tf.Tensor produced by tf.placeholder |
| Trainable Parameter | tf.Variable |
| Operation in Body | tf.Tensor produced by non-mutating TensorFlow operation (e.g. tf.nn.conv2d) |
| N/A | tf.Tensor produced by mutating TensorFlow operation (e.g. tf.assign) |

TensorFlow placeholders play the role of arguments to the generative function. TensorFlow Variables play the role of its trainable parameters. Their values are shared across all invocations of the generative function and are managed by the TensorFlow runtime, not Julia. Training these parameters is discussed in the section Implementing parameter updates. Tensors produced by non-mutating operations comprise the body of the generative function. One of these elements (an argument, a parameter, or an element of the body) is designated the return value of the generative function. Note that we do not currently permit TensorFlow generative functions to use randomness.

To construct a TensorFlow generative function, we first construct the TensorFlow computation graph using the TensorFlow Python API:

using Gen
using GenTF
using PyCall

@pyimport tensorflow as tf
@pyimport tensorflow.nn as nn

xs = tf.placeholder(tf.float64) # N x 784
W = tf.Variable(zeros(Float64, 784, 10))
b = tf.Variable(zeros(Float64, 10))
probs = nn.softmax(tf.add(tf.matmul(xs, W), b), axis=1) # N x 10

Then we construct a TFFunction from the TensorFlow graph objects. The first argument to TFFunction is a Vector of trainable parameters (W and b), followed by a Vector of arguments (xs), the return value (probs), and finally the TensorFlow session.

sess = tf.Session()
tf_func = TFFunction([W, b], [xs], probs, sess)

The return value must be a differentiable function of each argument and each parameter. Note that the return value does not need to be a scalar. TensorFlow computations for gradients with respect to the arguments and with respect to the parameters are automatically generated when constructing the TFFunction.

If a session is not provided, a new session is created:

tf_func = TFFunction([W, b], [xs], probs)

Values for the parameters are managed by the TensorFlow runtime. The TensorFlow session that contains the parameter values is obtained with:

sess = get_session(tf_func)

The value of a trainable parameter can be obtained in Julia by fetching the Python Variable object (e.g. 'W'):

W_value = sess[:run](W)

Equivalently, this can be done using a more concise syntax with the runtf method:

W_value = runtf(tf_func, W)

What happens during Gen.generate

Suppose we run generate on the TFFunction:

(trace, weight) = generate(tf_func, (xs_val,), choicemap())

Note that we pass an empty choice map to generate because a TFFunction cannot make any random choices that could be constrained.
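As a sanity check on the return value that generate computes for the model above, the same forward computation can be written in plain Julia. The sketch below does not use GenTF or TensorFlow; softmax_rows and forward are hypothetical helper names, and the input is made up for illustration. With W and b initialized to zeros as above, every class probability should come out as 1/10.

```julia
# Plain-Julia reference for probs = softmax(xs * W + b); not part of GenTF.

# Row-wise softmax, shifted by the row maximum for numerical stability.
function softmax_rows(z::Matrix{Float64})
    e = exp.(z .- maximum(z, dims=2))
    return e ./ sum(e, dims=2)
end

# Same computation as the TensorFlow graph above: an N x 10 matrix of probabilities.
forward(xs, W, b) = softmax_rows(xs * W .+ b')

W0 = zeros(Float64, 784, 10)    # same initial values as the tf.Variable W
b0 = zeros(Float64, 10)         # same initial values as the tf.Variable b
xs_val = ones(Float64, 2, 784)  # made-up input with N = 2 rows

probs_ref = forward(xs_val, W0, b0)
println(size(probs_ref))
```

Comparing such a reference against the value returned by the TFFunction is one way to confirm that the placeholder, parameters, and return value were wired up as intended.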

What happens during Gen.choice_gradients

When running choice_gradients with a trace produced from a TFFunction, we must pass a gradient value for the return value. This value should be a Julia Array with the same shape as the return value.

((xs_grad,), _, _) = choice_gradients(trace, select(), retval_grad)

Note that we pass an empty selection because a TFFunction does not make any random choices that could be selected.
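The role of retval_grad can be illustrated without TensorFlow. If L is some scalar computed downstream from the return value and retval_grad holds the gradient of L with respect to probs, then the argument gradient is the gradient of L with respect to xs. The plain-Julia sketch below works this out by hand for the softmax model above on a small made-up instance and compares one entry against a finite difference. All names (softmax_rows, forward, input_grad) are hypothetical helpers, not GenTF functions.

```julia
# Plain-Julia sketch of the gradient that flows back to xs; not part of GenTF.

function softmax_rows(z)
    e = exp.(z .- maximum(z, dims=2))
    return e ./ sum(e, dims=2)
end

forward(xs, W, b) = softmax_rows(xs * W .+ b')

# Given G = dL/d(probs), return dL/d(xs) for probs = softmax(xs * W + b).
function input_grad(xs, W, b, G)
    P = forward(xs, W, b)
    dZ = P .* (G .- sum(P .* G, dims=2))  # back through the row-wise softmax
    return dZ * W'                         # back through the matrix multiply
end

# Small made-up instance: N = 2 rows, 3 features, 4 classes.
xs = reshape(collect(1.0:6.0), 2, 3) ./ 10
W = reshape(collect(1.0:12.0), 3, 4) ./ 20
b = [0.1, -0.2, 0.3, 0.0]
G = reshape(collect(1.0:8.0), 2, 4)       # plays the role of retval_grad

xs_grad = input_grad(xs, W, b, G)

# Finite-difference check of one entry, using L = sum(G .* probs).
L(x) = sum(G .* forward(x, W, b))
h = 1e-6
xp = copy(xs); xp[1, 2] += h
xm = copy(xs); xm[1, 2] -= h
fd = (L(xp) - L(xm)) / (2h)
println(abs(fd - xs_grad[1, 2]) < 1e-6)
```

Note that G has the shape of the return value (2 x 4 here) and the result has the shape of xs (2 x 3 here), which matches the shape requirement stated above.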

What happens during Gen.accumulate_param_gradients!

When running accumulate_param_gradients! with a trace produced from a TFFunction, we must pass a gradient value for the return value. This value should be a Julia Array with the same shape as the return value.

(xs_grad,) = accumulate_param_gradients!(trace, retval_grad)

The gradient accumulator for a parameter accumulates gradient contributions over multiple invocations of accumulate_param_gradients!. The TensorFlow Variable that holds the gradient accumulator for a given parameter can be obtained from the TFFunction with get_param_grad_tf_var (see API below). The gradient accumulators of all parameters of a given TFFunction can be reset to zeros by running the operation returned by reset_param_grads_tf_op (see API below).
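The accumulate-then-reset behavior described above can be pictured with an ordinary Julia array standing in for the accumulator Variable. This is only an illustration of the semantics, with made-up numbers; it does not call GenTF or TensorFlow, and b_acc is a hypothetical name.

```julia
# Plain-Julia stand-in for a gradient accumulator; not part of GenTF.
b_acc = zeros(3)                      # accumulator starts at zero

# Made-up gradient contributions from three separate calls.
contributions = [[1.0, 2.0, 0.0], [0.5, -1.0, 1.0], [2.5, 0.0, -1.0]]
for g in contributions
    b_acc .+= g                       # each call adds its contribution in place
end
total = copy(b_acc)                   # snapshot of the accumulated gradient

fill!(b_acc, 0.0)                     # resetting returns the accumulator to zeros
println(total)
```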

Implementing parameter updates

Updates to the trainable parameters of a TFFunction are also defined using the TensorFlow Python API. For example, below we define a TensorFlow operation to apply one step of stochastic gradient descent, based on the current values of the gradient accumulators for all parameters:

opt = train.GradientDescentOptimizer(.00001)
grads_and_vars = []
push!(grads_and_vars, (tf.negative(get_param_grad_tf_var(tf_func, W)), W))
push!(grads_and_vars, (tf.negative(get_param_grad_tf_var(tf_func, b)), b))
update = opt[:apply_gradients](grads_and_vars)
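To make the sign convention concrete: a gradient descent step subtracts the learning rate times each supplied gradient from its variable, so passing the negated accumulator moves each parameter in the direction of the accumulated gradient. Presumably this is because the accumulated gradients are for a quantity to be maximized, while the optimizer minimizes. The plain-Julia sketch below mimics that arithmetic on ordinary arrays with made-up numbers; it does not call TensorFlow, and sgd_step! is a hypothetical helper name.

```julia
# Plain-Julia stand-in for one optimizer step; not part of GenTF or TensorFlow.
# Mirrors the gradient descent rule var <- var - learning_rate * grad.
function sgd_step!(var::Vector{Float64}, grad::Vector{Float64}, learning_rate::Float64)
    var .-= learning_rate .* grad
    return var
end

W_val = [1.0, 2.0]        # stand-in for the current value of a parameter
W_acc = [10.0, -20.0]     # stand-in for its gradient accumulator

# Passing the negated accumulator, as in the snippet above, yields an ascent step.
sgd_step!(W_val, -W_acc, 0.1)
println(W_val)
```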

We can then apply this update with:

runtf(tf_func, update)

We can reset the gradient accumulators to zero when desired with:

runtf(tf_func, reset_param_grads_tf_op(tf_func))


See the examples/ directory for examples that show TFFunctions being combined with regular Gen functions.


API

gen_fn = TFFunction(params::Vector{PyObject},
                    inputs::Vector{PyObject}, output::PyObject,
                    sess::PyObject=tf.Session())

Construct a TensorFlow generative function from elements of a TensorFlow computation graph.


sess::PyObject = get_session(gen_fn::TFFunction)

Return the TensorFlow session associated with the given function.

runtf(gen_fn::TFFunction, ...)

Fetch values or run operations in the TensorFlow session associated with the given function.

Syntactic sugar for get_session(gen_fn).run(args...)

var::PyObject = get_param_grad_tf_var(gen_fn::TFFunction, param::PyObject)

Return the TensorFlow Variable that stores the gradient of the given parameter TensorFlow Variable.

op::PyObject = reset_param_grads_tf_op(gen_fn::TFFunction)

Return the TensorFlow operation Tensor that resets the gradients of all parameters of the given function to zero.