Use with Oryx models#
Oryx is a probabilistic programming library written in JAX and is thus natively compatible with Blackjax. In this notebook we show how to use Oryx as a modeling language together with Blackjax as an inference library.
We reproduce the example in Oryx’s documentation and train a Bayesian Neural Network (BNN) on the iris dataset:
from sklearn import datasets
iris = datasets.load_iris()
features, labels = iris['data'], iris['target']
num_features = features.shape[-1]
num_classes = len(iris.target_names)
print(f"Number of features: {num_features}")
print(f"Number of classes: {num_classes}")
print(f"Number of data points: {len(features)}")
Number of features: 4
Number of classes: 3
Number of data points: 150
import jax
import jax.numpy as jnp
from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
Oryx’s approach, like Aesara’s, is to implement probabilistic models as generative models and then apply transformations to obtain the log-probability density function. We begin by implementing a dense layer with a normal prior on the weights and biases, using the random_variable function to define random variables:
from oryx.core.ppl import random_variable
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
def dense(dim_out, activation=jax.nn.relu):
def forward(key, x):
dim_in = x.shape[-1]
w_key, b_key = jax.random.split(key)
w = random_variable(
tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out, dim_in)),
name='w'
)(w_key)
b = random_variable(
tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out,)),
name='b'
)(b_key)
return activation(jnp.dot(w, x) + b)
return forward
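As a quick illustration (this call is not part of the original example), a layer built with dense can be applied directly; each call samples fresh weights and biases from their priors:

layer = dense(5)
# Apply the layer to a dummy input; the output has shape (5,).
out = layer(jax.random.key(0), jnp.ones(num_features))
print(out.shape)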
We now use this layer to build a multi-layer perceptron. The nest function creates “scope tags” that allow us to re-use the dense layer several times without name collisions in the dictionary that will contain the parameters:
from oryx.core.ppl import nest
def mlp(hidden_sizes, num_classes):
num_hidden = len(hidden_sizes)
def forward(key, x):
keys = jax.random.split(key, num_hidden + 1)
for i, (subkey, hidden_size) in enumerate(zip(keys[:-1], hidden_sizes)):
x = nest(dense(hidden_size), scope=f'layer_{i + 1}')(subkey, x)
logits = nest(dense(num_classes, activation=lambda x: x),
scope=f'layer_{num_hidden + 1}')(keys[-1], x)
return logits
return forward
Finally, we model the labels as categorical random variables:
import functools
def predict(mlp):
def forward(key, xs):
mlp_key, label_key = jax.random.split(key)
logits = jax.vmap(functools.partial(mlp, mlp_key))(xs)
return random_variable(
tfd.Independent(tfd.Categorical(logits=logits), 1), name='y')(label_key)
return forward
We can now build the BNN and sample an initial position for the inference algorithm using joint_sample:
from oryx.core.ppl import joint_sample
bnn = mlp([50, 50], num_classes)
rng_key, init_key = jax.random.split(rng_key)
initial_weights = joint_sample(bnn)(init_key, jnp.ones(num_features))
print(initial_weights.keys())
To sample from this model we need its joint log-probability function, which we obtain with joint_log_prob:
from oryx.core.ppl import joint_log_prob
def logdensity_fn(weights):
return joint_log_prob(predict(bnn))(dict(weights, y=labels), features)
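As a quick sanity check (this call is not part of the original notebook), the log-density can be evaluated at the initial weights sampled above:

# Evaluate the joint log-probability at the prior sample of the weights.
print(logdensity_fn(initial_weights))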
We can now run the window adaptation to get good values for the parameters of the NUTS algorithm:
%%time
import blackjax
rng_key, warmup_key = jax.random.split(rng_key)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = adapt.run(warmup_key, initial_weights, 100)
kernel = blackjax.nuts(logdensity_fn, **parameters).step
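The sampling cell below calls an inference_loop helper that is not defined in this section. A minimal sketch of such a loop based on jax.lax.scan, with the signature assumed from the call below:

def inference_loop(rng_key, kernel, initial_state, num_samples):
    # Apply the kernel once per key and accumulate the states and infos.
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos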
We can now sample from the model’s posterior distribution:
%%time
rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop(sample_key, kernel, last_state, 100)
We can now use the samples to compute an estimate of the accuracy averaged over the posterior distribution. We use intervene to “inject” the posterior values of the weights instead of sampling from the prior:
from oryx.core.ppl import intervene
posterior_weights = states.position
rng_key, pred_key = jax.random.split(rng_key)
output_logits = jax.vmap(
lambda weights: jax.vmap(lambda x: intervene(bnn, **weights)(
pred_key, x)
)(features)
)(posterior_weights)
output_probs = jax.nn.softmax(output_logits)
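To finish the accuracy estimate described above, one option (the variable names here are illustrative and not taken from the original) is to average the class probabilities over the posterior samples and compare the most probable class with the true labels:

# Average class probabilities over posterior samples, then score against labels.
mean_probs = output_probs.mean(axis=0)
predicted_classes = jnp.argmax(mean_probs, axis=-1)
accuracy = jnp.mean(predicted_classes == labels)
print(f"Posterior-averaged accuracy: {accuracy:.2%}")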