# Source code for blackjax.progress_bar

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Progress bar decorators for use with step functions.
Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`.
"""
from fastprogress.fastprogress import progress_bar
from jax import lax
from jax.experimental import io_callback


def progress_bar_scan(num_samples, print_rate=None):
    """Build a decorator that displays a progress bar for a `lax.scan` body.

    Parameters
    ----------
    num_samples
        Total number of iterations of the scan. The bar is created over
        ``range(num_samples)`` on iteration ``0`` and closed on iteration
        ``num_samples - 1``.
    print_rate
        The bar is refreshed on every iteration that is a multiple of
        ``print_rate``, and on the last iteration. When ``None``, it defaults
        to ``int(num_samples / 20)`` if ``num_samples > 20`` and to ``1``
        otherwise.

    Returns
    -------
    A decorator taking a scan body ``func(carry, x)`` and returning a function
    with the same call signature that updates the progress bar before
    delegating to ``func``.

    """
    # Mutable container shared by the host callbacks below: the bar is created
    # inside a callback, so it has to be stored somewhere the later callbacks
    # can reach it.
    progress_bars = {}

    if print_rate is None:
        if num_samples > 20:
            print_rate = int(num_samples / 20)
        else:
            # Fewer than 20 iterations: refresh on every iteration.
            print_rate = 1

    def _define_bar(arg):
        """Host callback: create the bar and draw it at position 0."""
        del arg
        progress_bars[0] = progress_bar(range(num_samples))
        progress_bars[0].update(0)

    def _update_bar(arg):
        """Host callback: move the bar to iteration ``arg + 1``."""
        progress_bars[0].update_bar(arg + 1)

    def _close_bar(arg):
        """Host callback: finalize the bar."""
        del arg
        progress_bars[0].on_iter_end()

    def _update_progress_bar(iter_num):
        """Update the progress bar of a JAX scan or loop.

        Each step is wrapped in `lax.cond` so that the decision is made on the
        (possibly traced) value of ``iter_num``; the side effects themselves
        go through `io_callback`.
        """
        is_last = iter_num == num_samples - 1

        # Create the bar on the first iteration.
        lax.cond(
            iter_num == 0,
            lambda _: io_callback(_define_bar, None, iter_num),
            lambda _: None,
            operand=None,
        )

        # Refresh on every multiple of `print_rate`, and always on the last
        # iteration so that the bar ends up full.
        lax.cond(
            (iter_num % print_rate == 0) | is_last,
            lambda _: io_callback(_update_bar, None, iter_num),
            lambda _: None,
            operand=None,
        )

        # Close the bar once the last iteration has been reported.
        lax.cond(
            is_last,
            lambda _: io_callback(_close_bar, None, None),
            lambda _: None,
            operand=None,
        )

    def _progress_bar_scan(func):
        """Decorator that adds a progress bar to `body_fun` used in `lax.scan`.

        Note that `body_fun` must either be looping over
        `np.arange(num_samples)`, or be looping over a tuple whose first
        element is `np.arange(num_samples)`. This means that `iter_num` is the
        current iteration number.
        """

        def wrapper_progress_bar(carry, x):
            # `isinstance` (rather than an exact type check) so that tuple
            # subclasses such as named tuples are unpacked as well.
            if isinstance(x, tuple):
                iter_num, *_ = x
            else:
                iter_num = x
            _update_progress_bar(iter_num)
            return func(carry, x)

        return wrapper_progress_bar

    return _progress_bar_scan