Unable to jit logpdf #21
Hey @patel-zeel! Nice to hear from you. :) The reason why this is failing is that there is some logic which checks for missing values, as you've noticed, and this logic unfortunately doesn't work well with the JIT. The recommended solution is to use `B.jit`:

```python
import jax
import jax.numpy as jnp
from stheno.jax import GP, EQ
import lab.jax as B

x = jnp.arange(10)
y = jnp.arange(10)
lengthscale = jnp.array(1.0)

loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
grad_fn = B.jit(jax.grad(loss_fn))
grad_fn(lengthscale)
```

Output:
Would you be able to check if this also works on your end? Thanks for mentioning Chex! I wasn't aware of the library. It looks useful—I'm going to have a closer look!
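As a side note, the failure mode described above (a missing-value check that doesn't survive the JIT) can be sketched generically. This is a minimal illustration, not Stheno's actual code: inside `jax.jit`, the argument is an abstract tracer, so using `jnp.any(jnp.isnan(y))` in a Python `if` forces a concrete boolean and raises a tracing error.

```python
import jax
import jax.numpy as jnp

def logpdf_with_nan_check(y):
    # A Python-level `if` on a traced value needs a concrete bool,
    # which is unavailable during `jit` tracing.
    if jnp.any(jnp.isnan(y)):
        raise ValueError("missing values are not supported")
    return jnp.sum(y)

try:
    jax.jit(logpdf_with_nan_check)(jnp.arange(3.0))
except jax.errors.JAXTypeError as e:
    # Tracing fails before the function body ever runs concretely.
    print("Tracing failed:", type(e).__name__)
```

This is why the check works fine eagerly but breaks as soon as the function is jitted.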
Thank you, @wesselb, for the quick response! This works, and the depth of your customization is amazing. :) When I try a slightly different variant (passing a dictionary instead of a value), it throws:

```python
import jax
import jax.numpy as jnp
from stheno.jax import GP, EQ
import lab.jax as B

x = jnp.arange(10)
y = jnp.arange(10)
params = {"lengthscale": jnp.array(1.0)}

loss_fn = lambda params: GP(EQ().stretch(params["lengthscale"]))(x).logpdf(y)
grad_fn = B.jit(jax.grad(loss_fn))
print(grad_fn(params))
```

Output
Also, I wonder how to use |
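For context on the dictionary variant above: plain `jax.grad` and `jax.jit` support dict-valued parameters natively, returning gradients in the same dict structure, so the failure here is specific to how `B.jit` dispatches on argument types. A minimal sketch with a stand-in loss (not Stheno's):

```python
import jax
import jax.numpy as jnp

# A toy quadratic loss over a dict of parameters.
params = {"lengthscale": jnp.array(2.0)}
loss = lambda p: jnp.sum(p["lengthscale"] ** 2)

# jax.grad returns a dict with the same keys as the input;
# plain jax.jit handles dict arguments out of the box.
grads = jax.jit(jax.grad(loss))(params)
print(grads["lengthscale"])  # d(l^2)/dl at l = 2 is 4.0
```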
I think the problem now is that the typing is a little restrictive. In particular, a dictionary with JAX-valued values isn't recognised as a JAX object, and that's where the method error comes from. It is possible to add that method manually, as follows:

```python
from types import FunctionType

import jax
import jax.numpy as jnp
import lab.jax as B
from plum import Union, Dict
from stheno.jax import GP, EQ


@B.generic._jit_run.dispatch
def _jit_run(
    f: FunctionType,
    compilation_cache: dict,
    jit_kw_args: dict,
    *args: Dict[object, B.JAXNumeric],
    **kw_args,
):
    return B.generic._jit_run.invoke(
        FunctionType, dict, dict, B.JAXNumeric,
    )(f, compilation_cache, jit_kw_args, *args, **kw_args)


x = jnp.arange(10)
y = jnp.arange(10)
params = {"lengthscale": jnp.array(1.0)}

loss_fn = lambda params: GP(EQ().stretch(params["lengthscale"]))(x).logpdf(y)
grad_fn = B.jit(jax.grad(loss_fn))
print(grad_fn(params))
```
It would be possible to make a small amendment to LAB so that this is the default behaviour, if that would be desirable. :)
Thanks for the solution, @wesselb. With the PyTree concept gaining popularity in JAX, dictionaries of parameters, lists and tuples of parameters, or mixtures of these are also sometimes used. Is it possible to make |
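To illustrate the PyTree concept mentioned above: JAX treats arbitrarily nested combinations of dicts, lists, and tuples as PyTrees, and `jax.tree_util` flattens or maps over them regardless of nesting. A small sketch (the parameter names are made up for illustration):

```python
import jax
import jax.numpy as jnp

# A mixed PyTree: a dict containing a list and a tuple of arrays.
params = {
    "kernel": [jnp.array(1.0), jnp.array(2.0)],
    "noise": (jnp.array(0.1),),
}

# tree_leaves flattens any nesting into a flat list of arrays...
leaves = jax.tree_util.tree_leaves(params)
print(len(leaves))  # 3

# ...and tree_map applies a function to every leaf, preserving structure.
doubled = jax.tree_util.tree_map(lambda a: 2 * a, params)
print(doubled["kernel"][1])  # 4.0
```

This structural flexibility is exactly what makes type-based dispatch on "a PyTree of JAX arrays" awkward: the container types alone say nothing about the leaf types.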
I see! Hmm, this might be challenging. Dispatch currently heavily leverages types, and the type of a PyTree is somewhat troublesome. You're right that |
Yes, from the links you have shared, it looks like a harder problem. Maybe there can be a hotfix specifically for lab and Stheno till a more generic solution is available. |
A hacky halfway house solution is to check for JAX-like objects in the first 10 (or so) layers of a PyTree... Then it wouldn't detect things like |
I think 10 or so layers seem practical for most applications. However, how difficult would it be to convert the static number 10 into a dynamic depth by recursively checking the PyTree? |
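The "dynamic depth" idea above can be sketched as a plain recursive walk over the standard PyTree containers. This is a hypothetical helper (`contains_jax_leaf` is not part of LAB or Stheno), and it only handles dicts, lists, and tuples, not custom registered PyTree nodes:

```python
import jax.numpy as jnp

def contains_jax_leaf(obj):
    """Recursively check whether a nested dict/list/tuple contains a
    JAX array at any depth (no fixed layer limit)."""
    if isinstance(obj, jnp.ndarray):
        return True
    if isinstance(obj, dict):
        return any(contains_jax_leaf(v) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return any(contains_jax_leaf(v) for v in obj)
    return False

print(contains_jax_leaf({"a": [jnp.array(1.0)]}))  # True
print(contains_jax_leaf({"a": [1.0]}))             # False (plain float leaf)
```

The trade-off, as discussed below, is that such a runtime walk sits awkwardly next to purely type-based dispatch.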
I think that it would be possible to convert the static number to a dynamic depth. However, perhaps the right solution here is to see if we can actually give PyTrees first-class support. I’ll soon be working on a 2.0 of Plum, which is where currently the restrictions derive from. I will put PyTree support on the list of desired improvements! |
With the new version of Plum, I believe it should be possible to use the |
Description of the bug
Hi @wesselb,
I am trying to write some GP code in JAX and accelerate it with `jax.jit`, but it is failing due to a NumPy conversion happening in the process. A potential solution seems to be commenting out the code checking for NaN values in the logpdf function (and it works), but you may be able to suggest a better solution. Also, `chex` mentions that it allows testing code with and without jitting; it could be used in testing at some point in the future.
Code
Output
Description of your environment
Tried this in Google colab.