
feature request: add GP classification example (pref. multi-class) and Poisson likelihood example #22

Open
murphyk opened this issue Sep 8, 2022 · 2 comments

murphyk commented Sep 8, 2022

No description provided.

wesselb commented Sep 9, 2022

Hey @murphyk! Thank you for opening an issue. I think a multi-class GP classification example and a Poisson likelihood example would be a super good addition. There is a similar request in #24. I'm currently a little overloaded, but I'd be very happy to put together these examples in a week or two.

wesselb commented Oct 23, 2022

Hi @murphyk! I've put together a non-optimised example of three-class GP classification using Stheno + JAX. A Poisson likelihood example should be similar. Also tagging @patel-zeel here in case he might be interested in this.

For simplicity, the example performs variational inference with a correlated Gaussian over a latent MOGP with two outputs. Alternatively, one can use a Laplace approximation. Perhaps the simplest approach, however, is this.
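To spell out what `nelbo` below computes: it is just the standard sparse-variational objective (nothing Stheno-specific), estimated on a minibatch $B$ with Monte Carlo samples of the latent functions,

$$\mathcal{L} = \frac{N}{|B|} \sum_{n \in B} \mathbb{E}_{q(f(x_n))}\big[\log p(y_n \mid f(x_n))\big] - \mathrm{KL}\big(q(u) \,\|\, p(u)\big),$$

where $u$ are the values of the latent MOGP at the inducing points. For the three-class case, $\log p(y_n \mid f) = f_{y_n}(x_n) - \log \sum_{c=0}^{2} \exp f_c(x_n)$ with $f_0 \equiv 0$.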

How does this look to you? Are there any improvements or changes that you would like to see, or does this suffice?

[Figure: output of the script below, showing the predicted class probabilities over x with 95% credible bands for the three classes, together with the (rescaled) observations.]

from varz.jax import Vars, minimise_adam
from wbml.plot import tweak
import jax.numpy as jnp
import lab.jax as B
import matplotlib.pyplot as plt

from stheno.jax import Measure, Normal, GP, EQ, cross

B.set_random_seed(0)
B.epsilon = 1e-8

# Setup some data.
x_obs = 5 * B.rand(jnp.float64, 200)
y_obs = B.cast(jnp.float64, x_obs < 2) + B.cast(jnp.float64, x_obs < 4)

# Setup inducing points locations.
x_ind = B.linspace(jnp.float64, 0, 5, 30)


def build_model(vs):
    params = vs.struct

    # Define GPs which determine the log-probability of the classes. For class 0, we
    # set `f0 = 0`.
    with Measure() as prior:
        f1 = GP(EQ().stretch(0.5))
        f2 = GP(EQ().stretch(0.5))

    # Setup the inducing point approximation.
    f = cross(f1, f2)  # Cartesian product of `f1` and `f2`, a MOGP
    p_u = f(x_ind)  # Prior over the inducing points

    # Setup the variational approximation.
    mean = params.q.mean.unbounded(p_u.mean)
    var = params.q.var.pd(B.dense(B.eye(p_u.var)))
    chol = B.chol(p_u.var)  # This whitening greatly helps learning.
    q_u = Normal(B.matmul(chol, mean), B.matmul(chol, var, chol, tr_c=True))

    return prior, f1, f2, p_u, q_u


def nelbo(vs, state, suppress_kl=False):
    prior, f1, f2, p_u, q_u = build_model(vs)

    # Sample a minibatch of size 20.
    state, perm = B.randperm(state, jnp.int64, len(x_obs))
    inds = perm[:20]
    x_batch = B.take(x_obs, inds)
    y_batch = B.take(y_obs, inds)

    # Sample the latent functions `f0`, `f1`, and `f2` at the data under the approximate
    # posterior. Take 10 samples to average over.
    state, u = q_u.sample(state, 10)
    q = prior | (p_u, B.transpose(u)[:, None, :, None])
    state, f1, f2 = q.sample(
        state,
        f1(x_batch[None, :, None, None]),
        f2(x_batch[None, :, None, None]),
    )
    f0 = B.zeros(f1)

    # Compute the ELBO:
    rec = len(x_obs) / len(x_batch) * B.sum(
        (y_batch == 0)[None, :, None, None] * f0
        + (y_batch == 1)[None, :, None, None] * f1
        + (y_batch == 2)[None, :, None, None] * f2
        - B.logsumexp(B.concat(f0, f1, f2, axis=-1), axis=-1, squeeze=False),
        axis=(1, 2, 3)
    )
    elbo = B.mean(rec) - q_u.kl(p_u)

    return -elbo, state


def predict(vs, x, n_samples=100):
    prior, f1, f2, p_u, q_u = build_model(vs)

    # Sample `n_samples` samples from `f0`, `f1`, and `f2` under the approximate
    # posterior.
    u = q_u.sample(n_samples)
    u = B.reshape(B.transpose(u), n_samples, -1, 1)
    q = prior | (p_u, u)
    f1, f2 = q.sample(f1(x), f2(x))
    f0 = B.zeros(f1)

    # Turn those samples into probabilities.
    probs = B.exp(B.concat(f0, f1, f2, axis=-1))
    probs = probs / B.sum(probs, axis=-1, squeeze=False)

    return B.mean(probs, axis=0), B.std(probs, axis=0)


# Learn variational approximation.
vs = Vars(jnp.float64)
state = B.create_random_state(jnp.float64, seed=0)
_, state = minimise_adam(
    nelbo,
    (vs, state),
    trace=True,
    jit=True,
    rate=5e-2,
    iters=10_000,
)
vs.print()

# Make predictions.
x = B.linspace(jnp.float64, 0, 5, 100)[None, :, None]
mean, std = predict(vs, x)

# Plot predictions.
plt.figure()
plt.scatter(x_obs, y_obs / B.max(y_obs), style="train")
for i, c in enumerate(["tab:blue", "tab:green", "tab:red"]):
    # Drop the extra dimensions of `x` when plotting.
    plt.plot(x[0, :, 0], mean[:, i], ls="-", color=c)
    plt.fill_between(
        x[0, :, 0],
        mean[:, i] - 1.96 * std[:, i],
        mean[:, i] + 1.96 * std[:, i],
        ls="-",
        facecolor=c,
        alpha=0.2,
        edgecolor="none",
    )
tweak()
plt.show()
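Regarding the Poisson likelihood: below is a rough, untested sketch of how the `nelbo` above could be adapted. It assumes a hypothetical `build_model_poisson` that mirrors `build_model` but returns a single latent GP `f`, and uses a log link, so `y ~ Poisson(exp(f))`; only the reconstruction term changes.

from jax.scipy.special import gammaln


def nelbo_poisson(vs, state):
    # Hypothetical variant of `build_model` with a single latent GP `f` instead
    # of `f1` and `f2`; otherwise identical (inducing points, whitened
    # variational posterior `q_u` over `p_u = f(x_ind)`).
    prior, f, p_u, q_u = build_model_poisson(vs)

    # Sample a minibatch of size 20, as in the classification example.
    state, perm = B.randperm(state, jnp.int64, len(x_obs))
    inds = perm[:20]
    x_batch = B.take(x_obs, inds)
    y_batch = B.take(y_obs, inds)

    # Sample the latent function at the batch under the approximate posterior,
    # averaging over 10 samples.
    state, u = q_u.sample(state, 10)
    q = prior | (p_u, B.transpose(u)[:, None, :, None])
    state, f_batch = q.sample(state, f(x_batch[None, :, None, None]))

    # Poisson likelihood with a log link: y ~ Poisson(exp(f)), so
    # log p(y | f) = y * f - exp(f) - log(y!). The last term is constant in the
    # parameters, but is included so the ELBO value is correct.
    log_lik = (
        y_batch[None, :, None, None] * f_batch
        - B.exp(f_batch)
        - gammaln(y_batch + 1)[None, :, None, None]
    )
    rec = len(x_obs) / len(x_batch) * B.sum(log_lik, axis=(1, 2, 3))
    elbo = B.mean(rec) - q_u.kl(p_u)

    return -elbo, state

The optimisation loop would stay exactly the same; predictions would then report the posterior rate exp(f) instead of class probabilities.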
