Use with Oryx models

Use with Oryx models

Oryx is a probabilistic programming library written in JAX, so it is natively compatible with Blackjax. In this notebook we show how to use Oryx as a modeling language together with Blackjax as an inference library.

We reproduce the example from Oryx’s documentation and train a Bayesian neural network (BNN) on the iris dataset:

from sklearn import datasets

# Load the iris dataset: `features` holds the measurements, `labels` the
# integer class of each flower.
iris = datasets.load_iris()
features = iris['data']
labels = iris['target']

# Model dimensions derived from the data.
num_features = features.shape[-1]
num_classes = len(iris.target_names)
Hide code cell source
# One line per summary statistic; `sep="\n"` reproduces three separate prints.
print(
    f"Number of features: {num_features}",
    f"Number of classes: {num_classes}",
    f"Number of data points: {features.shape[0]}",
    sep="\n",
)
Number of features: 4
Number of classes: 3
Number of data points: 150
import jax
import jax.numpy as jnp

from datetime import date
# Seed the PRNG with today's date as a YYYYMMDD integer, so the seed (and hence
# every random draw below) changes from one day to the next.
rng_key = jax.random.key(int(format(date.today(), "%Y%m%d")))

Oryx’s approach, like Aesara’s, is to implement probabilistic models as generative models and then apply transformations to obtain the log-probability density function. We begin by implementing a dense layer with standard normal priors on its weights and biases, using the function random_variable to define the random variables:

from oryx.core.ppl import random_variable

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions


def dense(dim_out, activation=jax.nn.relu):
    """Build a Bayesian dense layer with standard normal priors.

    The returned ``forward(key, x)`` declares two named random variables,
    which Oryx transformations such as ``joint_sample`` or ``intervene`` can
    collect or replace:

    * ``'w'`` of shape ``(dim_out, dim_in)``, where ``dim_in = x.shape[-1]``;
    * ``'b'`` of shape ``(dim_out,)``.

    Both have i.i.d. ``Normal(0, 1)`` entries. The function returns
    ``activation(w @ x + b)``.

    Args:
        dim_out: Number of output units.
        activation: Elementwise nonlinearity applied to the affine output
            (``jax.nn.relu`` by default).

    Returns:
        A function ``forward(key, x)``. ``x`` is expected to be a single 1-D
        feature vector. Batches must be handled by the caller (e.g. with
        ``jax.vmap``), because ``jnp.dot(w, x)`` contracts ``w`` with the
        leading axis of ``x``.
    """

    def forward(key, x):
        weight_key, bias_key = jax.random.split(key)

        weight_prior = tfd.Sample(
            tfd.Normal(0., 1.), sample_shape=(dim_out, x.shape[-1])
        )
        w = random_variable(weight_prior, name='w')(weight_key)

        bias_prior = tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out,))
        b = random_variable(bias_prior, name='b')(bias_key)

        return activation(jnp.dot(w, x) + b)

    return forward

We now use this layer to build a multi-layer perceptron. The nest function creates “scope tags” which, in this context, allow us to reuse our dense layer multiple times without name collisions in the dictionary that will contain the parameters:

from oryx.core.ppl import nest

def mlp(hidden_sizes, num_classes):
    """Build a Bayesian multi-layer perceptron out of ``dense`` layers.

    Args:
        hidden_sizes: Widths of the hidden layers, in order. Each hidden layer
            uses ``dense``'s default activation.
        num_classes: Width of the final layer. It uses the identity
            activation, so the network outputs raw logits.

    Returns:
        A function ``forward(key, x)`` that maps one input vector to
        ``num_classes`` logits. Layer ``i`` (1-based) is wrapped with
        ``nest(..., scope=f'layer_{i}')`` so that each layer's ``'w'`` and
        ``'b'`` random variables get distinct names.
    """
    num_hidden = len(hidden_sizes)

    def forward(key, x):
        # One independent key per layer: every hidden layer plus the output.
        layer_keys = jax.random.split(key, num_hidden + 1)
        hidden_keys, output_key = layer_keys[:-1], layer_keys[-1]

        h = x
        for idx, width in enumerate(hidden_sizes):
            layer = nest(dense(width), scope=f'layer_{idx + 1}')
            h = layer(hidden_keys[idx], h)

        output_layer = nest(dense(num_classes, activation=lambda z: z),
                            scope=f'layer_{num_hidden + 1}')
        return output_layer(output_key, h)

    return forward

Finally, we model the labels as categorical random variables:

import functools

def predict(mlp):
    """Turn a per-example network into a model over a batch of class labels.

    Note that the parameter ``mlp`` shadows the module-level ``mlp`` builder.
    It is the network function itself (e.g. the value returned by
    ``mlp(...)``), called as ``mlp(key, x)`` on a single input.

    Args:
        mlp: Function ``(key, x) -> logits`` for one input.

    Returns:
        A function ``forward(key, xs)``. It evaluates ``mlp`` on every row of
        ``xs`` (mapping over the leading axis) and returns a draw of the
        random variable ``'y'``: one categorical label per row. The labels
        are grouped into a single event by ``tfd.Independent(..., 1)``.
    """

    def forward(key, xs):
        network_key, label_key = jax.random.split(key)
        # The same key is bound for every row, so all examples are evaluated
        # with the same sampled weights rather than with independent draws.
        shared_weights_net = functools.partial(mlp, network_key)
        logits = jax.vmap(shared_weights_net)(xs)
        labels_dist = tfd.Independent(tfd.Categorical(logits=logits), 1)
        return random_variable(labels_dist, name='y')(label_key)

    return forward

We can now build the BNN and sample an initial position for the inference algorithm using joint_sample:

from oryx.core.ppl import joint_sample

# Two hidden layers of 50 units each, followed by a `num_classes`-unit output layer.
bnn = mlp([50, 50], num_classes)
rng_key, init_key = jax.random.split(rng_key)
# `joint_sample` collects every named random variable drawn while running
# `bnn` on a dummy all-ones input. The input's values do not affect the draws
# (the weights depend only on the keys); its length only fixes the first
# layer's input dimension. These prior draws are the initial position for
# inference.
initial_weights = joint_sample(bnn)(init_key, jnp.ones(num_features))

print(initial_weights.keys())
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[7], line 5
      3 bnn = mlp([50, 50], num_classes)
      4 rng_key, init_key = jax.random.split(rng_key)
----> 5 initial_weights = joint_sample(bnn)(init_key, jnp.ones(num_features))
      7 print(initial_weights.keys())

File /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/oryx/core/interpreters/harvest.py:768, in reap.<locals>.wrapped(*args, **kwargs)
    767 def wrapped(*args, **kwargs):
--> 768   return call_and_reap(
    769       f,
    770       tag=tag,
    771       allowlist=allowlist,
    772       blocklist=blocklist,
    773       exclusive=exclusive)(*args, **kwargs)[1]

File /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/oryx/core/interpreters/harvest.py:703, in call_and_reap.<locals>.wrapped(*args, **kwargs)
    702 def wrapped(*args, **kwargs):
--> 703   out, reaps, preds = _call_and_reap(
    704       f,
    705       tag=tag,
    706       allowlist=allowlist,
    707       blocklist=blocklist,
    708       exclusive=exclusive,
    709   )(*args, **kwargs)
    711   def select_from_pred(pred, value):
    712     if pred is None:

File /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/oryx/core/interpreters/harvest.py:739, in _call_and_reap.<locals>.wrapped(*args, **kwargs)
    737 flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
    738 flat_fun = reap_function(flat_fun, settings, False)
--> 739 out_flat, reaps, preds = flat_fun.call_wrapped(flat_args)
    740 return tree_util.tree_unflatten(out_tree(), out_flat), reaps, preds

File /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax/_src/linear_util.py:211, in WrappedFun.call_wrapped(self, *args, **kwargs)
    209 def call_wrapped(self, *args, **kwargs):
    210   """Calls the transformed function"""
--> 211   return self.f_transformed(*args, **kwargs)

File /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax/_src/linear_util.py:245, in transformation.<locals>.gen2(f, *args, **kwargs)
    243 def gen2(f, *args, **kwargs):
    244   gen_inst = gen(*args, **kwargs)
--> 245   args_, kwargs_ = next(gen_inst)
    246   return gen_inst.send(f(*args_, **kwargs_))

File /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/oryx/core/interpreters/harvest.py:646, in reap_function(settings, return_metadata, args)
    644 """A function transformation that returns reap values and predicates."""
    645 context = ReapContext(settings, {})
--> 646 with harvest_trace(context):
    647   out_values = yield args, {}
    648   reap_values = tree_util.tree_map(lambda x: x.value, context.reaps)

File /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/contextlib.py:137, in _GeneratorContextManager.__enter__(self)
    135 del self.args, self.kwds, self.func
    136 try:
--> 137     return next(self.gen)
    138 except StopIteration:
    139     raise RuntimeError("generator didn't yield") from None

File /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/oryx/core/interpreters/harvest.py:1232, in harvest_trace(context)
   1230 with jax_core.take_current_trace() as parent_trace:
   1231   trace = HarvestTrace(parent_trace, context)
-> 1232   with jax_core.set_current_trace(trace):
   1233     yield

File /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax/_src/core.py:1155, in SetCurrentTraceContextManager.__enter__(self)
   1153 def __enter__(self):
   1154   self.prev = trace_ctx.trace
-> 1155   trace_ctx.set_trace(self.trace)

File /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/jax/_src/core.py:1117, in TracingContext.set_trace(self, trace)
   1115 def set_trace(self, trace):
   1116   self.trace = trace
-> 1117   ts = trace._weakref if trace is not None else None
   1118   config.trace_state.set_local(ts)

AttributeError: 'HarvestTrace' object has no attribute '_weakref'
Hide code cell source
# Count scalar parameters: sum the sizes of every array leaf (each layer's
# weights and biases) in the sampled parameter dict.
num_parameters = sum(leaf.size for leaf in jax.tree_util.tree_leaves(initial_weights))
print(f"Number of parameters in the model: {num_parameters}")

To sample from this model, we need its joint log-probability, which we obtain with joint_log_prob:

from oryx.core.ppl import joint_log_prob

def logdensity_fn(weights):
    """Return the unnormalized log posterior density of the BNN weights.

    Evaluates the joint log-probability of ``predict(bnn)`` on ``features``,
    with the observed ``labels`` fixed as the value of the ``'y'`` random
    variable. As a function of ``weights``, this is the log prior plus the
    log likelihood.

    Args:
        weights: Mapping from parameter names (as produced by
            ``joint_sample(bnn)``) to arrays.

    Returns:
        The joint log-probability as a scalar.
    """
    # Merge the observed labels into the parameter dict, so `joint_log_prob`
    # scores the likelihood of `y` together with the prior on the weights.
    observed = dict(weights, y=labels)
    return joint_log_prob(predict(bnn))(observed, features)

We can now run the window adaptation to get good values for the parameters of the NUTS algorithm:

%%time
import blackjax

rng_key, warmup_key = jax.random.split(rng_key)
# Window adaptation for NUTS, starting from the prior draw `initial_weights`
# (100 is the number of adaptation steps). It returns the final adapted state
# and `parameters`, the tuned keyword arguments passed to `blackjax.nuts` below.
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(warmup_key, initial_weights, 100)
# NUTS transition function `kernel(key, state) -> (state, info)` built with the
# adapted parameters. Sampling later continues from `last_state`.
kernel = blackjax.nuts(logdensity_fn, **parameters).step

We then sample from the model’s posterior distribution:

Hide code cell content
def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Run a Markov chain for ``num_samples`` steps with ``jax.lax.scan``.

    Args:
        rng_key: PRNG key, split into one fresh key per step.
        kernel: Transition function ``kernel(key, state) -> (state, info)``.
        initial_state: State the chain starts from.
        num_samples: Number of transitions to perform.

    Returns:
        ``(states, infos)``: the state *after* each step and the info returned
        by that step, each stacked along a new leading axis of length
        ``num_samples``. The initial state is not included.
    """
    step_keys = jax.random.split(rng_key, num_samples)

    def _transition(current, step_key):
        updated, step_info = kernel(step_key, current)
        # Carry the new state forward and also record it in the trace.
        return updated, (updated, step_info)

    _, trace = jax.lax.scan(_transition, initial_state, step_keys)
    states, infos = trace
    return states, infos
%%time

# Draw 100 NUTS samples with the adapted kernel, continuing from the final
# warmup state `last_state`.
rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop(sample_key, kernel, last_state, 100)

We can now use our samples to estimate the accuracy averaged over the posterior distribution. We use intervene to “inject” the posterior values of the weights instead of sampling them from the prior distribution:

from oryx.core.ppl import intervene

posterior_weights = states.position
rng_key, pred_key = jax.random.split(rng_key)

# Outer vmap: one iteration per posterior sample. That sample's weights are
# substituted into `bnn` with `intervene`. Inner vmap: evaluate the network on
# every data point. Since the weights are intervened on, the shared `pred_key`
# presumably draws nothing here — TODO confirm.
output_logits = jax.vmap(
    lambda sample: jax.vmap(
        lambda x: intervene(bnn, **sample)(pred_key, x)
    )(features)
)(posterior_weights)

# Convert logits to class probabilities along the class (last) axis.
output_probs = jax.nn.softmax(output_logits, axis=-1)
Hide code cell source
# Accuracy of the individual posterior samples: compare each sample's argmax
# prediction with the true labels and average over samples and data points.
print(
    'Average sample accuracy:',
    (output_probs.argmax(axis=-1) == labels[None]).mean(),
)

# Bayesian model averaging (BMA): average the class probabilities over the
# posterior samples first, then take the argmax.
print(
    'BMA accuracy:',
    (output_probs.mean(axis=0).argmax(axis=-1) == labels[None]).mean(),
)