Regime switching Hidden Markov model
This example replicates the case study analyzing financial time series, specifically the daily difference in log price data of Google’s stock, referred to as returns \(r_t\).
We assume that at any given time \(t\) the stock’s returns follow one of two regimes: an independent random walk regime where \(r_t \sim \mathcal{N}(\alpha_1, \sigma^2_1)\), and an autoregressive regime where \(r_t \sim \mathcal{N}(\alpha_2 + \rho r_{t-1}, \sigma_2^2)\). The regime at time \(t\), \(s_t\in \{1, 2\}\), depends only on the regime at the previous time, \(s_{t-1}\); we call the transition probabilities \(p_{s_{t-1}, s_{t}}\) for \(s_{t-1}, s_t \in \{1, 2\}\). We take \(p_{1,1}\) and \(p_{2,2}\) as parameters of the model and obtain the remaining probabilities as their complements: \(p_{1,2} = 1-p_{1,1}\) and \(p_{2,1} = 1-p_{2,2}\). Since the regime at any time is unobserved, we instead carry forward in time the probabilities \(\xi_{1t}\) and \(\xi_{2t}\) of being in each regime, with \(\xi_{1t} + \xi_{2t} = 1\). Finally, we need to model the initial values, both for the returns, \(r_0\), and for the probability of being in the first regime, \(\xi_{10}\).
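Overall, our regime-switching model is defined by the likelihood
\[
\begin{split}
L(\mathbf{r} \mid \alpha, \rho, \sigma^2, \mathbf{p}, r_0, \xi_{10}) &= \prod_t \left[\xi_{1,t-1}\,\eta_{1t} + (1-\xi_{1,t-1})\,\eta_{2t}\right] \\
\xi_{1t} &= \frac{\xi_{1,t-1}\,\eta_{1t}}{\xi_{1,t-1}\,\eta_{1t} + (1-\xi_{1,t-1})\,\eta_{2t}},
\end{split}
\]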
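where \(\eta_{jt} = p_{j,1}\,\mathcal{N}(r_t;\alpha_1, \sigma_1^2) + p_{j,2}\,\mathcal{N}(r_t; \alpha_2 + \rho r_{t-1}, \sigma_2^2)\) for \(j\in\{1, 2\}\). The priors of the parameters are:
\[
\begin{split}
\alpha_1, \alpha_2 &\sim \mathcal{N}(0, 0.1^2) \\
\rho &\sim \mathcal{N}^0(1, 0.1^2) \\
\sigma_1, \sigma_2 &\sim \mathcal{C}^+(1) \\
p_{1,1}, p_{2,2} &\sim \mathrm{Beta}(10, 2) \\
r_0 &\sim \mathcal{N}(0, 1) \\
\xi_{10} &\sim \mathrm{Beta}(2, 2),
\end{split}
\]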
where \(\mathcal{N}^0\) denotes the Gaussian distribution truncated at 0 (restricted to nonnegative values) and \(\mathcal{C}^+\) denotes the half-Cauchy distribution.
Show code cell content Hide code cell content
# Plotting setup: hide the top and right axes spines on every matplotlib
# figure and raise ArviZ's cap on the number of subplots to 50.
import matplotlib.pyplot as plt
import arviz as az
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
az.rcParams["plot.max_subplots"] = 50
import jax
from datetime import date
# Seed the PRNG key with today's date written as the integer YYYYMMDD, so the
# key differs from one day to the next.
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as distrib
import pandas as pd
from jax.scipy.stats import norm
from numpyro.diagnostics import print_summary
from numpyro.infer.util import initialize_model
import blackjax
class RegimeMixtureDistribution(distrib.Distribution):
    """Joint distribution of a length-``T`` series under a two-regime model.

    The first regime scores each observation with
    ``Normal(alpha[0], sigma[0])``; the second scores it with
    ``Normal(alpha[1] + rho * previous_observation, sigma[1])``. The regime is
    never observed: ``log_prob`` carries the probability of being in the first
    regime through the series and mixes the two densities with it.

    Args:
        alpha: Pair of regime intercepts, indexed ``[0]`` and ``[1]``.
        rho: Autoregressive coefficient of the second regime.
        sigma: Pair of regime scales, indexed ``[0]`` and ``[1]``.
        p: Pair of probabilities; ``p[0]`` weights the first regime's density
            when mixing from the first regime, ``p[1]`` weights the second
            regime's density when mixing from the second regime.
        xi_0: Initial probability of the first regime.
        y_0: Value used as the "previous observation" at the first time step.
        T: Series length; sets ``event_shape`` to ``(T,)``.
        validate_args: Forwarded to the numpyro base class.
    """

    arg_constraints = {
        "alpha": distrib.constraints.real,
        "rho": distrib.constraints.positive,
        "sigma": distrib.constraints.positive,
        "p": distrib.constraints.interval(0, 1),
        "xi_0": distrib.constraints.interval(0, 1),
        "y_0": distrib.constraints.real,
        "T": distrib.constraints.positive_integer,
    }
    support = distrib.constraints.real

    def __init__(self, alpha, rho, sigma, p, xi_0, y_0, T, validate_args=True):
        # Attributes are set before calling the base constructor, which is
        # handed ``validate_args`` and may inspect them.
        self.alpha = alpha
        self.rho = rho
        self.sigma = sigma
        self.p = p
        self.xi_0 = xi_0
        self.y_0 = y_0
        self.T = T
        super().__init__(event_shape=(T,), validate_args=validate_args)

    def log_prob(self, value):
        """Return the log-likelihood of the series ``value``, summed over time.

        The series is scanned once; the carry holds the previous observation
        and the current probability of the first regime.
        """

        def obs_t(carry, y):
            y_prev, xi_1 = carry
            # Densities of the current observation under each regime.
            eta_1 = norm.pdf(y, loc=self.alpha[0], scale=self.sigma[0])
            eta_2 = norm.pdf(
                y, loc=self.alpha[1] + y_prev * self.rho, scale=self.sigma[1]
            )
            # Mixtures of the two densities, weighted by the probabilities
            # attached to each starting regime.
            lik_1 = self.p[0] * eta_1 + (1 - self.p[0]) * eta_2
            lik_2 = (1 - self.p[1]) * eta_1 + self.p[1] * eta_2
            lik = xi_1 * lik_1 + (1 - xi_1) * lik_2
            # Floor the likelihood so the log and the division below stay
            # finite. ``jnp.maximum`` is the same lower-bound clip as
            # ``jnp.clip(lik, a_min=...)`` without the deprecated ``a_min``
            # keyword.
            lik = jnp.maximum(lik, 1e-6)
            # NOTE(review): this update weights by the *starting* regime
            # (xi_1 * lik_1 / lik). A standard Hamilton filter would instead
            # carry eta_1 * (p[0] * xi_1 + (1 - p[1]) * (1 - xi_1)) / lik --
            # confirm which recursion is intended before changing it.
            return (y, xi_1 * lik_1 / lik), jnp.log(lik)

        _, log_liks = jax.lax.scan(obs_t, (self.y_0, self.xi_0), value)
        return jnp.sum(log_liks)

    def sample(self, key, sample_shape=()):
        """Placeholder sampler: ignores ``key`` and returns an array of zeros.

        The result has shape ``sample_shape + event_shape``; no draw from the
        model is implemented here.
        """
        return jnp.zeros(sample_shape + self.event_shape)
class RegimeSwitchHMM:
    """Bundle the regime-switching numpyro model with its data.

    Args:
        T: Number of observations; passed to the likelihood as series length.
        y: Observed series, supplied to the model as ``obs``.
    """

    def __init__(self, T, y) -> None:
        self.T = T
        self.y = y

    def model(self, y=None):
        """Numpyro model: priors on all parameters plus the series likelihood.

        Args:
            y: Observed series conditioned on via ``obs``; ``None`` leaves the
                ``"obs"`` site unobserved.
        """
        rho = numpyro.sample("rho", distrib.TruncatedNormal(1.0, 0.1, low=0.0))
        alpha = numpyro.sample("alpha", distrib.Normal(0.0, 0.1).expand([2]))
        sigma = numpyro.sample("sigma", distrib.HalfCauchy(1.0).expand([2]))
        p = numpyro.sample("p", distrib.Beta(10.0, 2.0).expand([2]))
        xi_0 = numpyro.sample("xi_0", distrib.Beta(2.0, 2.0))
        y_0 = numpyro.sample("y_0", distrib.Normal(0.0, 1.0))
        numpyro.sample(
            "obs",
            RegimeMixtureDistribution(alpha, rho, sigma, p, xi_0, y_0, self.T),
            obs=y,
        )

    def initialize_model(self, rng_key, n_chain):
        """Build the potential function and draw one start point per chain.

        Sets ``self.potential_fn`` and ``self.init_params``. The initial
        values returned by numpyro are used only for their structure (flat
        size and unravel function); the actual start points are standard
        normal draws, with ``p`` and ``sigma`` shifted by ``+1``.

        Args:
            rng_key: PRNG key passed to numpyro's ``initialize_model`` and
                also split into one key per chain.
            n_chain: Number of chains to create start points for.
        """
        # Import the submodule explicitly: ``import jax`` alone does not
        # guarantee that ``jax.flatten_util`` is available as an attribute.
        from jax.flatten_util import ravel_pytree

        # Inside a method the bare name resolves to the module-level numpyro
        # helper, not to this method.
        (init_params, *_), self.potential_fn, *_ = initialize_model(
            rng_key,
            self.model,
            model_kwargs={"y": self.y},
            dynamic_args=True,
        )
        kchain = jax.random.split(rng_key, n_chain)
        flat, unravel_fn = ravel_pytree(init_params)
        # One standard-normal vector per chain, reshaped back into the
        # parameter structure.
        draws = jax.vmap(
            lambda k: unravel_fn(jax.random.normal(k, flat.shape))
        )(kchain)
        # Shift the starting values of "p" and "sigma" by +1.
        # NOTE(review): these appear to be unconstrained-space values --
        # confirm against numpyro's initialize_model documentation.
        self.init_params = {
            name: 1.0 + value if name in ("p", "sigma") else value
            for name, value in draws.items()
        }

    def logdensity_fn(self, params):
        """Return the log-density at ``params``: the negated potential for ``self.y``."""
        return -self.potential_fn(self.y)(params)
def inference_loop(rng, init_state, kernel, n_iter):
    """Apply ``kernel`` ``n_iter`` times and collect every state and info.

    Args:
        rng: PRNG key, split into one sub-key per iteration.
        init_state: State handed to the first kernel call.
        kernel: Callable ``(key, state) -> (new_state, info)``.
        n_iter: Number of kernel applications.

    Returns:
        ``(states, info)``, each stacked along a leading axis of length
        ``n_iter``; ``states`` holds the state *after* each step.
    """

    def _advance(current, step_key):
        updated, step_info = kernel(step_key, current)
        # The new state is both the carry and part of the recorded output.
        return updated, (updated, step_info)

    step_keys = jax.random.split(rng, n_iter)
    _final, trace = jax.lax.scan(_advance, init_state, step_keys)
    all_states, all_info = trace
    return all_states, all_info
# Load the Google stock data set from the BlackJAX repository.
url = "https://raw.githubusercontent.com/blackjax-devs/blackjax/main/docs/examples/data/google.csv"
data = pd.read_csv(url)
# Scale the "dl_ac" column by 100 (presumably log-returns expressed in percent
# -- confirm against the data set's description).
y = data["dl_ac"].values * 100
# T is the number of rows; the column count is discarded.
T, _ = data.shape
dist = RegimeSwitchHMM(T, y)
# Number of chains, adaptation (warm-up) steps, and sampling steps per chain.
n_chain, n_warm, n_iter = 128, 5000, 200
# A fixed seed (0) is used here instead of the date-based ``rng_key``.
ksam, kinit = jax.random.split(jax.random.key(0), 2)
dist.initialize_model(kinit, n_chain)
# Start of the timed section: it covers both the adaptation and the sampling.
tic1 = pd.Timestamp.now()
k_warm, k_sample = jax.random.split(ksam)
# MEADS adaptation across all chains; it returns the adapted chain states and
# the kernel parameters that are forwarded to the GHMC kernel below.
warmup = blackjax.meads_adaptation(dist.logdensity_fn, n_chain)
(init_state, parameters), _ = warmup.run(k_warm, dist.init_params, n_warm)
kernel = blackjax.ghmc(dist.logdensity_fn, **parameters).step
def one_chain(k_sam, init_state):
    """Run one chain for ``n_iter`` steps of ``kernel``.

    Args:
        k_sam: PRNG key for this chain.
        init_state: Starting state of this chain.

    Returns:
        ``(positions, info)``: the ``position`` attribute of the collected
        states and the per-step info returned by ``inference_loop``.
    """
    chain_states, chain_info = inference_loop(k_sam, init_state, kernel, n_iter)
    return chain_states.position, chain_info
# One PRNG key per chain; ``one_chain`` is vectorised over the keys and the
# adapted chain states with ``jax.vmap``.
k_sample = jax.random.split(k_sample, n_chain)
samples, infos = jax.vmap(one_chain)(k_sample, init_state)
# NOTE(review): JAX dispatches work asynchronously, so without blocking on the
# result the elapsed time printed below may under-report -- confirm.
tic2 = pd.Timestamp.now()
print("Runtime for MEADS", tic2 - tic1)
Runtime for MEADS 0 days 00:00:24.842825
# NOTE(review): in the summary below, p and xi_0 take values outside (0, 1),
# so ``samples`` appear to be the sampler's unconstrained positions rather
# than constrained parameter values -- confirm before interpreting them.
print_summary(samples)
mean std median 5.0% 95.0% n_eff r_hat
alpha[0] 0.06 0.08 0.06 -0.08 0.20 4644.68 1.02
alpha[1] 0.01 0.10 0.01 -0.15 0.17 4273.45 1.02
p[0] 2.70 0.39 2.68 2.04 3.32 3296.83 1.04
p[1] 1.85 0.88 1.76 0.37 3.18 4249.95 1.02
rho -0.02 0.10 -0.01 -0.18 0.15 3522.40 1.03
sigma[0] 1.04 0.05 1.04 0.97 1.12 4382.62 1.03
sigma[1] 2.16 0.18 2.14 1.85 2.44 3270.76 1.03
xi_0 0.51 1.01 0.47 -1.11 2.16 4141.67 1.02
y_0 0.01 1.00 0.00 -1.59 1.68 4428.51 1.02
import arviz as az

# Wrap the sampled positions as an ArviZ posterior and draw a KDE pair plot
# with marginal densities on the diagonal.
idata = az.from_dict(posterior=samples)
az.plot_pair(
    idata,
    kind="kde",
    marginals=True,
    textsize=20,
    marginal_kwargs={"plot_kwargs": {"lw": 5}},
)
plt.tight_layout()