Use with Oryx models

Oryx is a probabilistic programming library written in JAX, and it is therefore natively compatible with Blackjax. In this notebook we show how to use Oryx as a modeling language together with Blackjax as an inference library.

We reproduce the example in Oryx’s documentation and train a Bayesian Neural Network (BNN) on the iris dataset:

from sklearn import datasets

iris = datasets.load_iris()
features, labels = iris['data'], iris['target']
num_features = features.shape[-1]
num_classes = len(iris.target_names)

print(f"Number of features: {num_features}")
print(f"Number of classes: {num_classes}")
print(f"Number of data points: {features.shape[0]}")
Number of features: 4
Number of classes: 3
Number of data points: 150
import jax
import jax.numpy as jnp

from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

Oryx’s approach, like Aesara’s, is to implement probabilistic models as generative models and then apply transformations to obtain the log-probability density function. We begin by implementing a dense layer with a standard normal prior on the weights and biases, using the function random_variable to define the random variables:

from oryx.core.ppl import random_variable

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions


def dense(dim_out, activation=jax.nn.relu):

    def forward(key, x):
        dim_in = x.shape[-1]
        w_key, b_key = jax.random.split(key)
        w = random_variable(
            tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out, dim_in)),
            name='w'
        )(w_key)
        b = random_variable(
            tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out,)),
            name='b'
        )(b_key)

        return activation(jnp.dot(w, x) + b)

    return forward
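
To make the behavior of random_variable concrete, here is a small sanity check that is not part of the original example: calling the layer with a PRNG key draws w and b from their priors and returns the activations. A minimal sketch, assuming a 3-unit layer:

# Illustrative sketch: one forward pass with weights sampled from their priors.
layer = dense(3)
out = layer(jax.random.key(0), jnp.ones(num_features))
print(out.shape)  # expected: (3,)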

We now use this layer to build a multi-layer perceptron. The nest function creates “scope tags” which, in this context, allow us to reuse our dense layer several times without name collisions in the dictionary that will contain the parameters:

from oryx.core.ppl import nest

def mlp(hidden_sizes, num_classes):
    num_hidden = len(hidden_sizes)

    def forward(key, x):
        keys = jax.random.split(key, num_hidden + 1)
        for i, (subkey, hidden_size) in enumerate(zip(keys[:-1], hidden_sizes)):
            x = nest(dense(hidden_size), scope=f'layer_{i + 1}')(subkey, x)
        logits = nest(dense(num_classes, activation=lambda x: x),
                        scope=f'layer_{num_hidden + 1}')(keys[-1], x)
        return logits

    return forward
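
As a side note, and assuming the usual Oryx semantics for nest, the latent variables of a scoped program are grouped under the scope’s name when we sample them jointly, which is what prevents the name collisions mentioned above. A minimal sketch, not part of the original example:

from oryx.core.ppl import joint_sample  # also used further below

# Illustrative: the scope name becomes a key in the sampled dictionary,
# so two layers built from the same dense factory do not collide.
scoped = nest(dense(2), scope='layer_1')
sample = joint_sample(scoped)(jax.random.key(0), jnp.ones(num_features))
print(sample.keys())  # expected: dict_keys(['layer_1'])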

Finally, we model the labels as categorical random variables:

import functools

def predict(mlp):
    def forward(key, xs):
        mlp_key, label_key = jax.random.split(key)
        logits = jax.vmap(functools.partial(mlp, mlp_key))(xs)
        return random_variable(
            tfd.Independent(tfd.Categorical(logits=logits), 1), name='y')(label_key)

    return forward
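
As an illustrative aside, a prior-predictive draw samples one label per data point from a network whose weights come from the prior; the following sketch, not part of the original notebook, uses a fixed key so as not to disturb rng_key:

# Illustrative prior-predictive check: one sampled label per data point.
prior_predictive = predict(mlp([50, 50], num_classes))
prior_labels = prior_predictive(jax.random.key(0), features)
print(prior_labels.shape)  # expected: (150,)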

We can now build the BNN and sample an initial position for the inference algorithm using joint_sample:

from oryx.core.ppl import joint_sample

bnn = mlp([50, 50], num_classes)
rng_key, init_key = jax.random.split(rng_key)
initial_weights = joint_sample(bnn)(init_key, jnp.ones(num_features))

print(initial_weights.keys())

num_parameters = sum(leaf.size for leaf in jax.tree_util.tree_leaves(initial_weights))
print(f"Number of parameters in the model: {num_parameters}")

To sample from this model we need the log-probability of its joint distribution, which we obtain with joint_log_prob:

from oryx.core.ppl import joint_log_prob

def logdensity_fn(weights):
    return joint_log_prob(predict(bnn))(dict(weights, y=labels), features)
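
Before running inference it can be worth checking that the log density evaluates to a finite scalar at the initial position sampled above, for instance:

# Illustrative sanity check, not part of the original notebook.
print(logdensity_fn(initial_weights))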

We can now run window adaptation to find good values for the parameters of the NUTS algorithm:

%%time
import blackjax

rng_key, warmup_key = jax.random.split(rng_key)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(warmup_key, initial_weights, 100)
kernel = blackjax.nuts(logdensity_fn, **parameters).step
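
The parameters dictionary holds the values found during adaptation; for Blackjax’s NUTS these are typically the step size and the inverse mass matrix. A quick, illustrative check:

# Illustrative: inspect the adapted parameters passed to the NUTS kernel.
print(parameters.keys())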

We can now sample from the model’s posterior distribution:

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos
%%time

rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop(sample_key, kernel, last_state, 100)
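
The info pytree returned at each step can be used for quick diagnostics. For instance, assuming the NUTS info exposes an is_divergent field, as it does in recent Blackjax releases, we can count divergent transitions:

# Illustrative diagnostic; the field name may differ in older Blackjax versions.
print("Number of divergent transitions:", infos.is_divergent.sum())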

We can now use our samples to compute an estimate of the accuracy averaged over the posterior distribution. We use intervene to “inject” the posterior values of the weights instead of sampling them from the prior distribution:

from oryx.core.ppl import intervene

posterior_weights = states.position

rng_key, pred_key = jax.random.split(rng_key)
output_logits = jax.vmap(
    lambda weights: jax.vmap(lambda x: intervene(bnn, **weights)(
        pred_key, x)
    )(features)
)(posterior_weights)

output_probs = jax.nn.softmax(output_logits)
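
output_probs stacks one set of class probabilities per posterior sample and per data point, so with 100 samples, 150 data points and 3 classes its shape is (100, 150, 3); the two accuracy estimates below differ in whether the argmax is taken for each sample or after averaging the probabilities. A quick, illustrative check:

# Illustrative shape check: (posterior samples, data points, classes).
print(output_probs.shape)  # expected: (100, 150, 3)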

print('Average sample accuracy:', (
    output_probs.argmax(axis=-1) == labels[None]).mean())

print('BMA accuracy:', (
    output_probs.mean(axis=0).argmax(axis=-1) == labels[None]).mean())