feature request: add GP classification example (pref. multi-class) and Poisson likelihood example #22
Hey @murphyk! Thank you for opening an issue. I think a multi-class GP classification example and a Poisson likelihood example would be a super good addition. There is a similar request in #24. I'm currently a little overloaded, but I'd be very happy to put together these examples in a week or two.
Hi @murphyk! I've put together a non-optimised example of three-class GP classification using Stheno + JAX. A Poisson likelihood example should be similar. Also tagging @patel-zeel here in case he might be interested in this.

For simplicity, the example performs variational inference: the inducing values of a latent MOGP with two outputs get a correlated (full-covariance) Gaussian approximate posterior. Two outputs suffice for three classes because the logit of class 0 is fixed to zero. Alternatively, one can use a Laplace approximation. Perhaps the simplest approach, however, is this.

How does this look to you? Are there any improvements or changes that you would like to see, or does this suffice?

```python
from varz.jax import Vars, minimise_adam
from wbml.plot import tweak
import jax.numpy as jnp
import lab.jax as B
import matplotlib.pyplot as plt
from stheno.jax import Measure, Normal, GP, EQ, cross

B.set_random_seed(0)
B.epsilon = 1e-8

# Set up some data: class 2 on [0, 2), class 1 on [2, 4), and class 0 on [4, 5).
x_obs = 5 * B.rand(jnp.float64, 200)
y_obs = B.cast(jnp.float64, x_obs < 2) + B.cast(jnp.float64, x_obs < 4)

# Set up the inducing point locations.
x_ind = B.linspace(jnp.float64, 0, 5, 30)
```
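The label rule above counts how many of the thresholds 2 and 4 lie above `x`, which gives class 2 on [0, 2), class 1 on [2, 4) and class 0 on [4, 5). A minimal NumPy sketch of the same rule, with a hypothetical helper `label` that is not part of the example, makes the class regions explicit:

```python
import numpy as np

def label(x):
    """Hypothetical helper: class 2 on [0, 2), class 1 on [2, 4), class 0 on [4, 5)."""
    x = np.asarray(x, dtype=np.float64)
    # Each comparison contributes 1 when `x` lies below that threshold.
    return (x < 2).astype(np.float64) + (x < 4).astype(np.float64)

print(label([0.5, 1.99, 2.0, 3.5, 4.0, 4.9]))  # → [2. 2. 1. 1. 0. 0.]
```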
```python
def build_model(vs):
    params = vs.struct

    # Define GPs for the unnormalised log-probabilities (logits) of classes 1 and 2.
    # The logit of class 0 is fixed to `f0 = 0`.
    with Measure() as prior:
        f1 = GP(EQ().stretch(0.5))
        f2 = GP(EQ().stretch(0.5))

    # Set up the inducing point approximation.
    f = cross(f1, f2)  # Cartesian product of `f1` and `f2`, a MOGP
    p_u = f(x_ind)  # Prior over the inducing points

    # Set up the variational approximation. The whitened mean starts at the prior
    # mean and the whitened covariance starts at the identity.
    mean = params.q.mean.unbounded(p_u.mean)
    var = params.q.var.pd(B.dense(B.eye(p_u.var)))
    chol = B.chol(p_u.var)  # This whitening greatly helps learning.
    q_u = Normal(B.matmul(chol, mean), B.matmul(chol, var, chol, tr_c=True))

    return prior, f1, f2, p_u, q_u
```
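The whitened parametrisation in `build_model` writes the approximate posterior as `q(u) = N(L m, L S L^T)`, where `L` is the Cholesky factor of the prior covariance `K` of the inducing values, and optimises `m` and `S` instead. Since `m` starts at the (zero) prior mean and `S` at the identity, `q(u)` starts out equal to the prior, and the KL term equals `KL(N(m, S) || N(0, I))` for any `m` and `S`, so its scale does not depend on the kernel. The NumPy sketch below checks that identity numerically. It assumes an EQ kernel `exp(-(x - y)^2 / (2 * 0.5^2))`, which is our reading of `EQ().stretch(0.5)`, uses a single output, and uses hypothetical helper names:

```python
import numpy as np

def eq_kernel(x, y, lengthscale=0.5):
    # Exponentiated quadratic kernel exp(-(x - y)^2 / (2 * lengthscale^2)).
    return np.exp(-0.5 * ((x[:, None] - y[None, :]) / lengthscale) ** 2)

def gaussian_kl(mu0, cov0, mu1, cov1):
    # KL(N(mu0, cov0) || N(mu1, cov1)) between full-covariance Gaussians.
    k = len(mu0)
    cov1_inv = np.linalg.inv(cov1)
    diff = mu1 - mu0
    _, logdet0 = np.linalg.slogdet(cov0)
    _, logdet1 = np.linalg.slogdet(cov1)
    return 0.5 * (np.trace(cov1_inv @ cov0) + diff @ cov1_inv @ diff - k + logdet1 - logdet0)

# 10 inducing points (instead of 30) keep K well conditioned for this numerical check.
n = 10
x_ind = np.linspace(0, 5, n)
K = eq_kernel(x_ind, x_ind) + 1e-6 * np.eye(n)  # small jitter
L = np.linalg.cholesky(K)

# Arbitrary whitened variational parameters m and S (S positive definite).
rng = np.random.default_rng(0)
m = 0.1 * rng.standard_normal(n)
A = 0.1 * rng.standard_normal((n, n))
S = A @ A.T + 0.5 * np.eye(n)

# Map to the original coordinates, as `build_model` does with `chol`.
q_mean, q_cov = L @ m, L @ S @ L.T

kl_original = gaussian_kl(q_mean, q_cov, np.zeros(n), K)
kl_whitened = gaussian_kl(m, S, np.zeros(n), np.eye(n))
print(np.isclose(kl_original, kl_whitened))  # True
```

With `m = 0` and `S = I`, `L S L^T` reduces to `L L^T = K`, which is why the initial approximation matches the prior.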
```python
def nelbo(vs, state):
    prior, f1, f2, p_u, q_u = build_model(vs)

    # Sample a minibatch of size 20.
    state, perm = B.randperm(state, jnp.int64, len(x_obs))
    inds = perm[:20]
    x_batch = B.take(x_obs, inds)
    y_batch = B.take(y_obs, inds)

    # Sample the latent functions `f1` and `f2` at the data under the approximate
    # posterior; `f0` is fixed to zero. Take 10 samples to average over.
    state, u = q_u.sample(state, 10)
    q = prior | (p_u, B.transpose(u)[:, None, :, None])
    state, f1, f2 = q.sample(
        state,
        f1(x_batch[None, :, None, None]),
        f2(x_batch[None, :, None, None]),
    )
    f0 = B.zeros(f1)

    # Compute the ELBO: the softmax log-likelihood of the minibatch, rescaled by
    # `len(x_obs) / len(x_batch)` and averaged over samples, minus the KL.
    rec = len(x_obs) / len(x_batch) * B.sum(
        (y_batch == 0)[None, :, None, None] * f0
        + (y_batch == 1)[None, :, None, None] * f1
        + (y_batch == 2)[None, :, None, None] * f2
        - B.logsumexp(B.concat(f0, f1, f2, axis=-1), axis=-1, squeeze=False),
        axis=(1, 2, 3),
    )
    elbo = B.mean(rec) - q_u.kl(p_u)
    return -elbo, state
```
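The reconstruction term in `nelbo` is the softmax log-likelihood `log p(y = k | f) = f_k - log(exp(f_0) + exp(f_1) + exp(f_2))` with `f_0 = 0`, summed over a minibatch and rescaled by `len(x_obs) / len(x_batch)` (here 200 / 20) so that it estimates the full-data sum. A NumPy sketch with a hypothetical helper `softmax_log_lik` and made-up logits (no posterior samples) checks both ingredients:

```python
import numpy as np

def softmax_log_lik(y, f1, f2):
    """Hypothetical helper: log p(y | f) for three classes, with the class-0 logit fixed to 0."""
    logits = np.stack([np.zeros_like(f1), f1, f2], axis=-1)
    # Stable log-sum-exp: subtract the per-point maximum before exponentiating.
    m = logits.max(axis=-1, keepdims=True)
    lse = m[..., 0] + np.log(np.exp(logits - m).sum(axis=-1))
    # Pick out the logit of the observed class for every data point.
    picked = np.take_along_axis(logits, y[..., None], axis=-1)[..., 0]
    return picked - lse

# With all logits equal to zero, every class has probability 1/3.
y = np.array([0, 1, 2, 2])
zero_ll = softmax_log_lik(y, np.zeros(4), np.zeros(4))  # every entry equals -log(3)

# Minibatch estimate as in `nelbo`: rescale a batch sum by n_data / batch_size.
rng = np.random.default_rng(0)
n_data, batch_size = 200, 20
y_all = rng.integers(0, 3, size=n_data)
f1_all = rng.standard_normal(n_data)
f2_all = rng.standard_normal(n_data)
full_sum = softmax_log_lik(y_all, f1_all, f2_all).sum()

# Averaging the rescaled sums over disjoint batches that cover the data recovers the full sum.
batches = np.split(rng.permutation(n_data), n_data // batch_size)
estimates = [
    n_data / batch_size * softmax_log_lik(y_all[b], f1_all[b], f2_all[b]).sum()
    for b in batches
]
print(np.isclose(np.mean(estimates), full_sum))  # True
```

For a single randomly drawn batch, as in `nelbo`, the rescaled sum is an unbiased but noisy estimate of the same quantity.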
```python
def predict(vs, x, n_samples=100):
    prior, f1, f2, p_u, q_u = build_model(vs)

    # Sample `n_samples` samples of `f1` and `f2` under the approximate posterior;
    # `f0` is fixed to zero.
    u = q_u.sample(n_samples)
    u = B.reshape(B.transpose(u), n_samples, -1, 1)
    q = prior | (p_u, u)
    f1, f2 = q.sample(f1(x), f2(x))
    f0 = B.zeros(f1)

    # Turn those samples into probabilities with a softmax. Subtracting the
    # log-sum-exp before exponentiating avoids overflow.
    logits = B.concat(f0, f1, f2, axis=-1)
    probs = B.exp(logits - B.logsumexp(logits, axis=-1, squeeze=False))
    return B.mean(probs, axis=0), B.std(probs, axis=0)
```
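In `predict`, the class probabilities are computed per posterior sample and then averaged, which is a Monte Carlo estimate of the predictive probability `E[softmax(f)]`. Because the softmax is nonlinear, this generally differs from applying the softmax to the averaged logits. A small NumPy sketch with two made-up logit samples at a single input (not output of the fitted model) shows the difference:

```python
import numpy as np

def softmax(logits):
    # Numerically stable softmax over the last axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Two hypothetical posterior samples of (f0, f1, f2) at one input, with f0 = 0.
samples = np.array([[0.0, 4.0, 0.0],
                    [0.0, -4.0, 0.0]])

mc_probs = softmax(samples).mean(axis=0)      # estimate of E[softmax(f)], ≈ [0.257, 0.487, 0.257]
plugin_probs = softmax(samples.mean(axis=0))  # softmax(E[f]), uniform [1/3, 1/3, 1/3]
```

The averaged probabilities keep the information that class 1 is very likely under one sample, which the plug-in estimate discards.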
```python
# Learn variational approximation.
vs = Vars(jnp.float64)
state = B.create_random_state(jnp.float64, seed=0)
_, state = minimise_adam(
    nelbo,
    (vs, state),
    trace=True,
    jit=True,
    rate=5e-2,
    iters=10_000,
)
vs.print()
```
```python
# Make predictions. `predict` is called with inputs of shape `(1, n, 1)`, so keep
# a flat copy of the inputs for plotting.
x = B.linspace(jnp.float64, 0, 5, 100)
mean, std = predict(vs, x[None, :, None])

# Plot predictions. The labels are rescaled to [0, 1] to share the axis with the
# probabilities.
plt.figure()
plt.scatter(x_obs, y_obs / B.max(y_obs), style="train")
for i, c in enumerate(["tab:blue", "tab:green", "tab:red"]):
    plt.plot(x, mean[:, i], ls="-", color=c)
    plt.fill_between(
        x,
        mean[:, i] - 1.96 * std[:, i],
        mean[:, i] + 1.96 * std[:, i],
        ls="-",
        facecolor=c,
        alpha=0.2,
        edgecolor="none",
    )
tweak()
plt.show()
```
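For the Poisson likelihood example mentioned above, one plausible adaptation (an assumption, not part of the example) keeps the same variational machinery with a single latent GP `f` and swaps the softmax reconstruction term for the Poisson log-density under a log link, `log p(y | f) = y * f - exp(f) - log(y!)`. A NumPy sketch of that term, with a hypothetical helper `poisson_log_lik`, checked against a direct evaluation of the Poisson probability mass function:

```python
import math
import numpy as np

def poisson_log_lik(y, f):
    """Hypothetical helper: log Poisson(y | rate = exp(f)), i.e. a log link."""
    y = np.asarray(y, dtype=np.float64)
    f = np.asarray(f, dtype=np.float64)
    # log(y!) = lgamma(y + 1), applied elementwise.
    log_factorial = np.vectorize(math.lgamma)(y + 1)
    return y * f - np.exp(f) - log_factorial

y = np.array([0, 1, 3, 7])
f = np.log(np.array([0.5, 1.0, 2.0, 6.0]))  # log-rates

# Compare with the probability mass function rate^y * exp(-rate) / y!.
rate = np.exp(f)
direct = np.log([r ** int(k) * math.exp(-r) / math.factorial(int(k)) for r, k in zip(rate, y)])
print(np.allclose(poisson_log_lik(y, f), direct))  # True
```

In `nelbo`, such a term would replace the softmax expression inside `B.sum`, keeping the same `len(x_obs) / len(x_batch)` rescaling and the same KL term.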