GenTF
TensorFlow plugin for the Gen probabilistic programming system
The Julia package GenTF allows Gen generative functions to invoke TensorFlow computations executed on the GPU by the TensorFlow runtime. Users construct a TensorFlow computation using the familiar TensorFlow Python API, and then package the TensorFlow computation in a TFFunction, which is a type of generative function provided by GenTF. Generative functions written in Gen's built-in modeling language can seamlessly call TFFunctions. GenTF integrates Gen's automatic differentiation with TensorFlow's gradients, allowing automatic differentiation of computations that combine Julia and TensorFlow code.
Installation
Installation requires Python and the tensorflow Python package. We recommend creating a Python virtual environment and installing TensorFlow via pip in that environment. In what follows, let <python> stand for the absolute path of a Python executable that has access to the tensorflow package.
From the Julia REPL, type ] to enter the Pkg REPL mode and run:
pkg> add https://github.com/probcomp/GenTF
In a Julia REPL, build the PyCall module so that it will use the correct Python environment:
using Pkg; ENV["PYTHON"] = "<python>"; Pkg.build("PyCall")
Check that the intended Python environment is indeed being used with:
using PyCall; println(PyCall.python)
If you encounter problems, see https://github.com/JuliaPy/PyCall.jl#specifying-the-python-version
Calling the TensorFlow Python API
GenTF uses the Julia package PyCall to invoke the TensorFlow Python API.
First, import PyCall:
using PyCall
Then import the tensorflow Python module:
@pyimport tensorflow as tf
To import a module from a subpackage:
@pyimport tensorflow.train as train
Then, call the TensorFlow Python API with syntax that is very close to Python syntax:
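init_W = rand(Float32, 2, 3)  # example initial value (an assumption; any k x 3 Float32 matrix works here)
W = tf.get_variable("W", dtype=tf.float32, initializer=init_W)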
x = tf.placeholder(tf.float32, shape=(3,), name="x")
y = tf.squeeze(tf.matmul(W, tf.expand_dims(x, axis=1)), axis=1)
sess = tf.Session()
sess[:run](tf.global_variables_initializer())
y_val = sess[:run](y, feed_dict=Dict(x => [1., 2., 3.]))
Here are syntax changes that are required for common situations:
Attributes of Python objects (including methods) are accessed using o[:attr] instead of o.attr. Therefore, to run something in a TensorFlow session sess, use sess[:run](..) instead of sess.run(..).
Where Python dictionaries would be used, use Julia dictionaries instead.
[1, 1, 1, 1] constructs a Julia Array, which by default gets converted to a numpy array, not a Python list. When a TensorFlow API function requires that an argument be a Python list or a tuple (e.g. the strides argument of tf.nn.conv2d), use a Julia tuple: (1, 1, 1, 1).
See the PyCall README for the complete description of syntax differences introduced when using Julia and PyCall instead of Python.
TensorFlow Generative Functions
A TensorFlow computation graph contains both the model(s) being trained and the operations that do the training. In contrast, Gen uses a more rigid separation between models (both generative models and inference models) and the operations that act on models. Specifically, models in Gen are defined as (pure functional) generative functions, and the operations that run the models or train the models are defined in separate Julia code. The GenTF package allows users to construct deterministic generative functions of type TFFunction <: GenerativeFunction from a TensorFlow computation graph in which each TensorFlow element plays one of the following roles:
Role in TFFunction | TensorFlow object type |
---|---|
Argument | tf.Tensor produced by tf.placeholder |
Trainable Parameter | tf.Variable |
Operation in Body | tf.Tensor produced by non-mutating TensorFlow operation (e.g. tf.nn.conv2d) |
N/A | tf.Tensor produced by mutating TensorFlow operation (e.g. tf.assign) |
TensorFlow placeholders play the role of arguments to the generative function. TensorFlow Variables play the role of the trainable parameters of the generative function. Their values are shared across all invocations of the generative function and are managed by the TensorFlow runtime, not Julia. We discuss how to train these parameters in the section Implementing parameter updates. Tensors produced from non-mutating operations comprise the body of the generative function. One of these elements (either an argument, parameter, or element of the body) is designated the return value of the generative function. Note that we do not currently permit TensorFlow generative functions to use randomness.
To construct a TensorFlow generative function, we first construct the TensorFlow computation graph using the TensorFlow Python API:
using Gen
using GenTF
using PyCall
@pyimport tensorflow as tf
@pyimport tensorflow.nn as nn
xs = tf.placeholder(tf.float64) # N x 784
W = tf.Variable(zeros(Float64, 784, 10))
b = tf.Variable(zeros(Float64, 10))
probs = nn.softmax(tf.add(tf.matmul(xs, W), b), axis=1) # N x 10
Then we construct a TFFunction from the TensorFlow graph objects. The arguments to TFFunction are a Vector of trainable parameters (W and b), a Vector of arguments (xs), the return value (probs), and finally the TensorFlow session.
sess = tf.Session()
tf_func = TFFunction([W, b], [xs], probs, sess)
The return value must be a differentiable function of each argument and each parameter. Note that the return value does not need to be a scalar. TensorFlow computations for gradients with respect to the arguments and with respect to the parameters are automatically generated when constructing the TFFunction.
If a session is not provided, a new session is created:
tf_func = TFFunction([W, b], [xs], probs)
Values for the parameters are managed by the TensorFlow runtime. The TensorFlow session that contains the parameter values is obtained with:
sess = get_session(tf_func)
The value of a trainable parameter can be obtained in Julia by fetching the Python Variable object (e.g. W):
W_value = sess[:run](W)
Equivalently, this can be done using a more concise syntax with the runtf method:
W_value = runtf(tf_func, W)
What happens during Gen.generate
Suppose we run generate on the TFFunction:
(trace, weight) = generate(tf_func, (xs_val,), choicemap())
The TensorFlow runtime computes the return value for the given values of the arguments and the current values of the trainable parameters.
The return value is obtained by Julia from TensorFlow and stored in the trace (it is accessible with get_retval(trace)).
The given argument values are also stored in the trace (accessible with get_args(trace)).
Note that we pass an empty choice map to generate because a TFFunction cannot make any random choices that could be constrained.
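To make the first step above concrete, here is a minimal plain-Julia sketch of the same arithmetic as the example graph (probs = softmax(xs * W + b), row by row). It is not GenTF code and does not call TensorFlow; the names row_softmax, forward, xs_val, W_val and b_val are hypothetical and chosen only for illustration.

```julia
# Plain-Julia stand-in for the example graph; it only mirrors the arithmetic.

# Row-wise softmax, matching nn.softmax(..., axis=1) on an N x K matrix.
function row_softmax(z::AbstractMatrix)
    shifted = z .- maximum(z, dims=2)   # subtract each row's max for numerical stability
    e = exp.(shifted)
    return e ./ sum(e, dims=2)
end

# Forward pass: the kind of value that get_retval(trace) would hold for these inputs.
forward(xs, W, b) = row_softmax(xs * W .+ b')

N = 4
xs_val = ones(Float64, N, 784)     # hypothetical input batch, N x 784
W_val = zeros(Float64, 784, 10)    # same initial values as W and b in the example
b_val = zeros(Float64, 10)

probs_val = forward(xs_val, W_val, b_val)   # N x 10, each row sums to 1
println(size(probs_val))
```

With all-zero parameters every class gets the same probability, which is a quick sanity check on the shapes before involving the TensorFlow runtime.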
What happens during Gen.choice_gradients
When running choice_gradients with a trace produced from a TFFunction, we must pass a gradient value for the return value. This value should be a Julia Array with the same shape as the return value.
((xs_grad,), _, _) = choice_gradients(trace, select(), retval_grad)
The gradients with respect to each argument are computed by the TensorFlow runtime.
The values of the gradient are converted to Julia values and returned.
Note that we pass an empty selection because a TFFunction does not make any random choices that could be selected.
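The role of retval_grad can be illustrated with a plain-Julia sketch. It assumes the usual reverse-mode convention, in which the returned gradients are the vector-Jacobian product of retval_grad with the function's Jacobian; the source does not spell this out, so treat it as an assumption. The function backward and all values below are hypothetical, and the model is the softmax(xs * W + b) example written by hand rather than through TensorFlow.

```julia
# Row-wise softmax on an N x K matrix.
function row_softmax(z::AbstractMatrix)
    e = exp.(z .- maximum(z, dims=2))
    return e ./ sum(e, dims=2)
end

# Vector-Jacobian product for probs = row_softmax(xs * W .+ b').
# retval_grad has the same shape as probs; the results have the shapes of xs, W and b.
function backward(xs, W, b, retval_grad)
    p = row_softmax(xs * W .+ b')
    # Softmax VJP per row: dz_i = p_i * (g_i - sum_j g_j * p_j)
    dz = p .* (retval_grad .- sum(retval_grad .* p, dims=2))
    xs_grad = dz * W'               # gradient with respect to the argument
    W_grad = xs' * dz               # gradient with respect to the parameter W
    b_grad = vec(sum(dz, dims=1))   # gradient with respect to the parameter b
    return xs_grad, W_grad, b_grad
end

xs_val = [0.5 -1.0 2.0; 1.5 0.0 -0.5]    # 2 x 3
W_val = [0.1 0.2; -0.3 0.4; 0.5 -0.6]    # 3 x 2
b_val = [0.05, -0.05]
retval_grad = [1.0 0.0; 0.0 1.0]         # same shape as the 2 x 2 return value

xs_grad, W_grad, b_grad = backward(xs_val, W_val, b_val, retval_grad)
println(size(xs_grad))
```

In this sketch xs_grad corresponds to what choice_gradients returns for the argument, while W_grad and b_grad correspond to the quantities that parameter-gradient accumulation works with.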
What happens during Gen.accumulate_param_gradients!
When running accumulate_param_gradients! with a trace produced from a TFFunction, we must pass a gradient value for the return value. This value should be a Julia Array with the same shape as the return value.
(xs_grad,) = accumulate_param_gradients!(trace, retval_grad)
Like choice_gradients, the method returns the value of the gradient with respect to the arguments.
The gradient with respect to each trainable parameter is computed by the TensorFlow runtime.
A gradient accumulator TensorFlow Variable for each trainable parameter is incremented by the corresponding gradient value.
The gradient accumulator for a parameter accumulates gradient contributions over multiple invocations of accumulate_param_gradients!. A gradient accumulator TensorFlow Variable value can be obtained from the TFFunction with get_param_grad_tf_var (see API below). The value of all gradient accumulators for a given TFFunction can be reset to zeros with reset_param_grads_tf_op (see API below).
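The accumulate-then-reset behaviour described above can be sketched in plain Julia. GenTF keeps this state in TensorFlow Variables; the GradAccumulator type and the accumulate! and reset! helpers below are hypothetical stand-ins that only illustrate the semantics.

```julia
# Hypothetical stand-in for one gradient accumulator.
struct GradAccumulator
    value::Matrix{Float64}
end

# Each call adds its gradient contribution in place; nothing is overwritten.
accumulate!(acc::GradAccumulator, grad) = (acc.value .+= grad; acc)

# Resetting sets every entry back to zero.
reset!(acc::GradAccumulator) = (fill!(acc.value, 0.0); acc)

acc = GradAccumulator(zeros(2, 2))
accumulate!(acc, [1.0 2.0; 3.0 4.0])   # contribution from a first trace
accumulate!(acc, [0.5 0.5; 0.5 0.5])   # contribution from a second trace
total = copy(acc.value)                # sum of both contributions
reset!(acc)                            # accumulator is all zeros again
println(total)
```

Because contributions add up, a minibatch gradient can be built by calling accumulate_param_gradients! once per trace and resetting only after the parameter update has been applied.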
Implementing parameter updates
Updates to the trainable parameters of a TFFunction are also defined using the TensorFlow Python API. For example, below we define a TensorFlow operation to apply one step of stochastic gradient descent, based on the current values of the gradient accumulators for all parameters:
opt = train.GradientDescentOptimizer(.00001)
grads_and_vars = []
push!(grads_and_vars, (tf.negative(get_param_grad_tf_var(tf_func, W)), W))
push!(grads_and_vars, (tf.negative(get_param_grad_tf_var(tf_func, b)), b))
update = opt[:apply_gradients](grads_and_vars)
We can then apply this update with:
sess[:run](update)
We can reset the gradient accumulators to zero when desired with:
sess[:run](reset_param_grads_tf_op(tf_func))
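The tf.negative calls in the update above are easy to overlook. A gradient descent optimizer moves each variable against the gradient it is handed, so passing the negated accumulator moves the variable along the accumulated gradient instead. The plain-Julia sketch below illustrates that sign convention only; sgd_step and the numbers are hypothetical, and reading the accumulated gradient as belonging to an objective being maximized is an assumption about typical use.

```julia
# Plain gradient descent step: var <- var - lr * grad.
sgd_step(var, grad, lr) = var .- lr .* grad

lr = 0.1
W_val = [1.0, 2.0]
accum = [4.0, -2.0]   # hypothetical accumulated gradient of an objective to maximize

# Handing the optimizer the negated accumulator turns descent into ascent:
# W_new = W_val + lr * accum
W_new = sgd_step(W_val, -accum, lr)
println(W_new)
```

After such a step the accumulators still hold the old gradients, which is why the reset operation is run before accumulating gradients for the next update.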
Examples
See the examples/ directory for examples that show TFFunctions being combined with regular Gen functions.
API
GenTF.TFFunction — Type.
gen_fn = TFFunction(params::Vector{PyObject}, inputs::Vector{PyObject}, output::PyObject, sess::PyObject=tf.compat.v1.Session())
Construct a TensorFlow generative function from elements of a TensorFlow computation graph.
GenTF.get_session — Function.
get_session(gen_fn::TFFunction)
Return the TensorFlow session associated with the given function.
GenTF.runtf — Function.
runtf(gen_fn::TFFunction, ...)
Fetch values or run operations in the TensorFlow session associated with the given function.
Syntactic sugar for get_session(gen_fn).run(args...).
GenTF.get_param_grad_tf_var — Function.
var::PyObject = get_param_grad_tf_var(gen_fn::TFFunction, param::PyObject)
Return the TensorFlow Variable that stores the gradient of the given parameter TensorFlow Variable.
GenTF.reset_param_grads_tf_op — Function.
op::PyObject = reset_param_grads_tf_op(gen_fn::TFFunction)
Return the TensorFlow operation Tensor that resets the gradients of all parameters of the given function to zero.