Unable to jit logpdf #21

Open
patel-zeel opened this issue Sep 7, 2022 · 10 comments

Comments

@patel-zeel
Contributor

patel-zeel commented Sep 7, 2022

Description of the bug

Hi @wesselb,

I am trying to write some GP code in JAX and accelerate it with jax.jit, but it fails because of a NumPy conversion that happens in the process. One workaround is to comment out the code that checks for NaN values in the logpdf function (this works), but you may be able to suggest a better solution. Also, Chex says that it supports testing code with and without jitting; it could be used for testing at some point in the future.
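A minimal sketch of the failure mode, independent of Stheno (the function `count_available` is hypothetical): converting a JAX value to NumPy works in eager mode, but under `jax.jit` the NaN mask is an abstract tracer without concrete values, so the conversion raises `TracerArrayConversionError`. This mirrors the `B.jit_to_numpy(~B.isnan(x[:, 0]))` line in Stheno's `logpdf` that appears in the traceback.

```python
import jax
import jax.numpy as jnp
import numpy as np


def count_available(y):
    # Data-dependent check in the spirit of Stheno's missing-data handling:
    # np.asarray needs concrete values, which a jit tracer does not have.
    available = np.asarray(~jnp.isnan(y))
    return int(available.sum())


y = jnp.array([1.0, jnp.nan, 2.0])

print(count_available(y))  # → 2 (eager mode: the mask is concrete)

try:
    jax.jit(count_available)(y)
except jax.errors.TracerArrayConversionError as e:
    print("fails under jit:", type(e).__name__)
```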

Code

import jax
import jax.numpy as jnp

from stheno.jax import GP, EQ

x = jnp.arange(10)
y = jnp.arange(10)
lengthscale = jnp.array(1.0)
loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
grad_fn = jax.jit(jax.grad(loss_fn))
grad_fn(lengthscale)

Output

TracerArrayConversionError                Traceback (most recent call last)
<ipython-input-6-9ed89be9714a> in <module>
      9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
     10 grad_fn = jax.jit(jax.grad(loss_fn))
---> 11 grad_fn(lengthscale)

<ipython-input-6-9ed89be9714a> in <lambda>(lengthscale)
      7 y = jnp.arange(10)
      8 lengthscale = jnp.array(1.0)
----> 9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
     10 grad_fn = jax.jit(jax.grad(loss_fn))
     11 grad_fn(lengthscale)

/usr/local/lib/python3.7/dist-packages/stheno/random.py in logpdf(self, x)
    260         # Handle missing data. We don't handle missing data for batched computation.
    261         if B.rank(x) == 2 and B.shape(x, 1) == 1:
--> 262             available = B.jit_to_numpy(~B.isnan(x[:, 0]))
    263             if not B.all(available):
    264                 # Take the elements of the mean, variance, and inputs corresponding to

/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
    582             # to speed up the common case.
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:
    586                 return _convert(method(*args, **kw_args), return_type)

/usr/local/lib/python3.7/dist-packages/lab/generic.py in jit_to_numpy(*args)
   1531         return B.control_flow.get_outcome("to_numpy")
   1532     else:
-> 1533         res = B.to_numpy(*args)
   1534         if B.control_flow.caching:
   1535             B.control_flow.set_outcome("to_numpy", res)

/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
    582             # to speed up the common case.
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:
    586                 return _convert(method(*args, **kw_args), return_type)

/usr/local/lib/python3.7/dist-packages/lab/generic.py in to_numpy(a)
   1495         `np.ndarray`: `a` as NumPy.
   1496     """
-> 1497     return convert(a, NPOrNum)
   1498 
   1499 

/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
    582             # to speed up the common case.
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:
    586                 return _convert(method(*args, **kw_args), return_type)

/usr/local/lib/python3.7/dist-packages/plum/promotion.py in convert(obj, type_to)
     30         object: `obj` converted to type `type_to`.
     31     """
---> 32     return _convert.invoke(type_of(obj), type_to)(obj, type_to)
     33 
     34 

/usr/local/lib/python3.7/dist-packages/plum/function.py in wrapped_method(*args, **kw_args)
    605         @wraps(self._f)
    606         def wrapped_method(*args, **kw_args):
--> 607             return _convert(method(*args, **kw_args), return_type)
    608 
    609         return wrapped_method

/usr/local/lib/python3.7/dist-packages/plum/promotion.py in perform_conversion(obj, _)
     59     @_convert.dispatch
     60     def perform_conversion(obj: type_from, _: type_to):
---> 61         return f(obj)
     62 
     63 

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[10])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function <lambda> at <ipython-input-6-9ed89be9714a>:9 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i64[10,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1)] b
    from line /usr/local/lib/python3.7/dist-packages/lab/jax/shaping.py:17 (_expand_dims)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /usr/local/lib/python3.7/dist-packages/stheno/random.py:262 (logpdf)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Description of your environment

I tried this in Google Colab.

@wesselb
Owner

wesselb commented Sep 7, 2022

Hey @patel-zeel! Nice to hear from you. :)

This fails because, as you've noticed, there is some logic that checks for missing values, and this logic unfortunately doesn't work well with the JIT.

The recommended solution is to use B.jit instead, where B comes from import lab.jax as B. B.jit is a thin wrapper around jax.jit that runs an additional "compilation step" to take care of code like the check for NaNs:

import jax
import jax.numpy as jnp

from stheno.jax import GP, EQ
import lab.jax as B

x = jnp.arange(10)
y = jnp.arange(10)
lengthscale = jnp.array(1.0)
loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
grad_fn = B.jit(jax.grad(loss_fn))
grad_fn(lengthscale)

Output:

DeviceArray(50.12888031, dtype=float64)
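A minimal sketch of the idea behind such an extra "compilation step", written under our own assumptions and not LAB's actual implementation: call the function once eagerly so that data-dependent checks (here, the NaN mask) can be converted to NumPy and recorded, then let `jax.jit` replay the recorded value while tracing. This loosely mirrors the `jit_to_numpy` frame in the traceback, which returns a cached outcome instead of calling `B.to_numpy` when caching is active. The names `to_numpy_cached`, `_outcomes`, `logpdf_missing` and `jit_with_eager_pass` are hypothetical, and `logpdf_missing` uses a standard normal rather than a GP.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical store of data-dependent outcomes recorded during the eager pass.
_outcomes = {}


def to_numpy_cached(x, key):
    """Return `x` as NumPy when concrete; under tracing, replay the recorded value."""
    try:
        out = np.asarray(x)  # only succeeds for concrete (non-traced) arrays
    except jax.errors.TracerArrayConversionError:
        return _outcomes[key]  # tracing under jit: reuse the eager-pass result
    _outcomes[key] = out
    return out


def logpdf_missing(y):
    """Log-density of `y` under N(0, I), ignoring NaN (missing) entries."""
    available = to_numpy_cached(~jnp.isnan(y), "available")
    y_obs = y[available]  # concrete NumPy mask, so the shape is static
    n = y_obs.shape[0]
    return -0.5 * (jnp.sum(y_obs**2) + n * jnp.log(2 * jnp.pi))


def jit_with_eager_pass(f):
    """Hypothetical B.jit-like wrapper: an eager pass, then a jitted call."""
    compiled = {}

    def wrapper(*args):
        f(*args)  # eager pass with concrete inputs fills `_outcomes`
        # Recompile whenever the recorded outcomes change, so a new NaN pattern
        # is never served by a function that was traced with an old mask.
        key = tuple((k, v.shape, v.tobytes()) for k, v in sorted(_outcomes.items()))
        if key not in compiled:
            # A fresh lambda per key keeps JAX's trace cache from reusing an old trace.
            compiled[key] = jax.jit(lambda *a: f(*a))
        return compiled[key](*args)

    return wrapper


y = jnp.array([1.0, jnp.nan, 2.0])
print(jit_with_eager_pass(logpdf_missing)(y))
```

For simplicity this wrapper repeats the eager pass on every call and recompiles whenever the recorded mask changes; that keeps the results correct but gives up part of the speed-up a real implementation would try to preserve.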

Would you be able to check if this also works on your end?

Thanks for mentioning Chex! I wasn't aware of the library. It looks useful—I'm going to have a closer look!

@patel-zeel
Contributor Author

patel-zeel commented Sep 7, 2022

Thank you for the quick response, @wesselb! This works, and the depth of your customization is amazing :)

When I try a slightly different variant (passing a dictionary instead of a single value), it raises plum.function.NotFoundLookupError:

import jax
import jax.numpy as jnp

from stheno.jax import GP, EQ
import lab.jax as B

x = jnp.arange(10)
y = jnp.arange(10)
params = {"lengthscale": jnp.array(1.0)}
loss_fn = lambda params: GP(EQ().stretch(params["lengthscale"]))(x).logpdf(y)
grad_fn = B.jit(jax.grad(loss_fn))
print(grad_fn(params))

Output

Traceback (most recent call last):
  File "/home/patel_zeel/gpax/testbed.py", line 12, in <module>
    print(grad_fn(params))
  File "/home/patel_zeel/miniconda3/envs/ajax/lib/python3.9/site-packages/lab/generic.py", line 146, in __call__
    return _jit_run(
  File "/home/patel_zeel/miniconda3/envs/ajax/lib/python3.9/site-packages/plum/function.py", line 591, in __call__
    method, return_type = self.resolve_method(*sig_types)
  File "/home/patel_zeel/miniconda3/envs/ajax/lib/python3.9/site-packages/plum/function.py", line 556, in resolve_method
    method, return_type = self._methods[self.resolve_signature(signature)]
  File "/home/patel_zeel/miniconda3/envs/ajax/lib/python3.9/site-packages/plum/function.py", line 492, in resolve_signature
    raise NotFoundLookupError(
plum.function.NotFoundLookupError: For function "_jit_run", signature Signature(builtins.function, builtins.dict, builtins.dict, builtins.dict) could not be resolved.

Also, I wonder how to use B.jit in places like the jax.lax.scan function, which applies jax.jit internally.
Update: I checked with jax.lax.scan, and it works! It looks like lax.scan does not explicitly use jax.jit internally. Now only the plum.function.NotFoundLookupError problem remains.

@wesselb
Owner

wesselb commented Sep 9, 2022

I think the problem now is that the typing is a little restrictive. In particular, a dictionary with JAX-valued values isn't recognised as a JAX object, which is where the NotFoundLookupError comes from. It is possible to add the missing method manually, as follows:

from types import FunctionType

import jax
import jax.numpy as jnp

import lab.jax as B
from plum import Dict
from stheno.jax import GP, EQ


@B.generic._jit_run.dispatch
def _jit_run(
    f: FunctionType,
    compilation_cache: dict,
    jit_kw_args: dict,
    *args: Dict[object, B.JAXNumeric],
    **kw_args,
):
    return B.generic._jit_run.invoke(
        FunctionType, dict, dict, B.JAXNumeric,
    )(f, compilation_cache, jit_kw_args, *args, **kw_args)


x = jnp.arange(10)
y = jnp.arange(10)
params = {"lengthscale": jnp.array(1.0)}
loss_fn = lambda params: GP(EQ().stretch(params["lengthscale"]))(x).logpdf(y)
grad_fn = B.jit(jax.grad(loss_fn))
print(grad_fn(params))
{'lengthscale': DeviceArray(50.12888031, dtype=float64)}

It would be possible to make a small amendment to LAB so that this is the default behaviour, if that would be desirable. :)

@patel-zeel
Contributor Author

patel-zeel commented Sep 9, 2022

Thanks for the solution, @wesselb. As the PyTree concept gains popularity in JAX, parameters are sometimes passed as dictionaries, lists, or tuples, or as nested mixtures of these. Is it possible to make B.jit work with any PyTree? Maybe the PyTree type from jaxtyping could be used somehow, but I am not sure how.

@wesselb
Owner

wesselb commented Sep 9, 2022

I see! Hmm, this might be challenging. Dispatch currently relies heavily on types, and the type of a PyTree is hard to express. You're right that jaxtyping offers a PyTree type, but that type seems to only perform instance checking rather than providing the recursive type definition that we would like. I'll have to think about this! I agree that it would be super useful to support PyTrees.

@patel-zeel
Contributor Author

Yes, from the links you have shared, it looks like a harder problem. Maybe there could be a hotfix specifically for LAB and Stheno until a more generic solution is available.

@wesselb
Owner

wesselb commented Sep 11, 2022

A hacky halfway-house solution is to check for JAX-like objects in the first 10 (or so) layers of a PyTree. Then it wouldn't detect deeply nested things like ((((((((((jnp.array(1),)))))))))), but maybe that's okay as a temporary solution. Do you think that would be reasonable? Or do you perhaps have another fix in mind?
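A minimal sketch of such a depth-limited check, assuming JAX 0.4.1 or later (for `jax.Array`) and treating only dicts, lists and tuples as containers. The function name `contains_jax_array` is hypothetical; this is not LAB's code, and because it only inspects values it does not by itself make Plum dispatch on PyTrees.

```python
import jax
import jax.numpy as jnp


def contains_jax_array(tree, max_depth=10):
    """Return True if a JAX array occurs within `max_depth` levels of dicts/lists/tuples."""
    if isinstance(tree, jax.Array):
        return True
    if max_depth == 0:
        return False  # give up: deeper levels are not inspected
    if isinstance(tree, dict):
        children = tree.values()
    elif isinstance(tree, (list, tuple)):
        children = tree
    else:
        return False  # other leaves (floats, strings, ...) are not JAX arrays
    return any(contains_jax_array(child, max_depth - 1) for child in children)


params = {"lengthscale": jnp.array(1.0), "extra": [(jnp.array(2.0),)]}
print(contains_jax_array(params))  # True

# Each nesting level needs a trailing comma to actually be a tuple.
deep = jnp.array(1.0)
for _ in range(11):
    deep = (deep,)
print(contains_jax_array(deep))  # False: the array sits 11 levels down
```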

@patel-zeel
Contributor Author

patel-zeel commented Sep 11, 2022

I think 10 or so layers seems practical for most applications. However, how difficult would it be to replace the static depth of 10 with a dynamic depth by recursively checking the PyTree?
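One way to get a dynamic depth, sketched under assumptions (JAX 0.4.1 or later for `jax.Array`; the function name `contains_jax_array` is hypothetical), is to let `jax.tree_util.tree_leaves` do the recursion: it flattens dicts, lists, tuples and any registered PyTree node into their leaves at arbitrary depth. Like any value-level check, this does not by itself address the type-level dispatch restriction in Plum.

```python
import jax
import jax.numpy as jnp


def contains_jax_array(tree):
    """Return True if any leaf of the PyTree `tree` is a JAX array, at any depth."""
    # tree_leaves recurses through dicts, lists, tuples, None and any container
    # registered with jax.tree_util, so no depth limit is needed.
    return any(isinstance(leaf, jax.Array) for leaf in jax.tree_util.tree_leaves(tree))


deep = jnp.array(1.0)
for _ in range(100):
    deep = (deep,)
print(contains_jax_array(deep))  # True
print(contains_jax_array({"a": 1.0, "b": [2, (3,)]}))  # False
```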

@wesselb
Owner

wesselb commented Oct 18, 2022

I think that it would be possible to convert the static number to a dynamic depth. However, perhaps the right solution here is to see if we can give PyTrees first-class support. I'll soon be working on version 2.0 of Plum, which is where the current restrictions come from. I will put PyTree support on the list of desired improvements!

@wesselb
Owner

wesselb commented Apr 8, 2023

With the new version of Plum, I believe it should be possible to use the PyTree type from jaxtyping to get the desired behaviour. I'll open issues in the appropriate places.
