RFC: add support for scalar arguments in result_type
#805
Comments
This makes sense to me. `torch` seems to support this as well. What should the result be if there are multiple Python scalars? Undefined?
This should probably be left undefined by the spec. In most cases I imagine array libraries will have a default dtype, but different libraries make different choices (e.g., int32 in JAX vs. int64 in NumPy).
One concern I see is that libraries need not support Python scalars as function arguments, only in operators. So …
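In Xarray, we are thinking of defining something like:

```python
def as_shared_dtype(scalars_or_arrays):
    # Look up the array API namespace shared by the inputs.
    xp = get_array_namespace(scalars_or_arrays)
    # Relies on the proposed extension: result_type accepting Python scalars.
    dtype = xp.result_type(*scalars_or_arrays)
    # In the array API, asarray's dtype parameter is keyword-only.
    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)
```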
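NumPy's `result_type` already accepts Python scalars, so NumPy can stand in for `xp` to show what such a helper would return. A minimal sketch, assuming the namespace is passed in explicitly (instead of being looked up with Xarray's `get_array_namespace`) and defaults to NumPy:

```python
import numpy as np


def as_shared_dtype(scalars_or_arrays, xp=np):
    """Cast a mix of arrays and Python scalars to one common dtype.

    Sketch only: `xp` is passed explicitly instead of being inferred,
    and defaults to NumPy, whose result_type already accepts scalars.
    """
    # Resolve the common dtype from arrays, dtypes and Python scalars.
    dtype = xp.result_type(*scalars_or_arrays)
    # Pass dtype by keyword, as the array API's asarray requires.
    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)


arr, scalar = as_shared_dtype([np.arange(3, dtype=np.int16), 7])
print(arr.dtype, scalar.dtype)  # int16 int16
```

The Python scalar `7` does not widen the result: both outputs share the array's `int16` dtype.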
Does Xarray automatically call `asarray` on scalar function arguments, as NumPy does? The standard recommends not doing that, because it is cleaner from a typing perspective. Implicitly calling …
The only time we call that function is when preparing arguments for …
Xarray objects always contain array objects, but indeed there are functions like … I opened a separate issue to discuss this: #807
This sounds like a useful change to me.
What is the problem? It seems well-defined to allow multiple Python scalars. If multiple arrays and dtype objects are allowed, why not multiple Python scalars?
I'm not sure, but I think that was referring to a situation where you have no explicit dtypes, just (compatible) Python scalars. In that case, we'd have to make an arbitrary choice (or raise an error).
Ah, of course. Agreed: there must be at least one array or dtype object.
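The rule agreed on here (at least one array or dtype object must be present) could be enforced as a first validation step. A minimal, hypothetical sketch: `split_result_type_args` is an illustrative name, not part of any library, and dtype names given as strings (e.g. `"int8"`) stand in for real dtype objects:

```python
# Python's built-in scalar types. An exact type check is used because
# some array-library scalars subclass these (e.g. NumPy's float64
# subclasses float) and should not count as "Python scalars" here.
PYTHON_SCALAR_TYPES = (bool, int, float, complex)


def split_result_type_args(*args):
    """Hypothetical first step of a scalar-aware result_type.

    Returns (arrays_and_dtypes, python_scalars) and rejects calls that
    contain only Python scalars, since there is no dtype to anchor on.
    """
    scalars = [a for a in args if type(a) in PYTHON_SCALAR_TYPES]
    others = [a for a in args if type(a) not in PYTHON_SCALAR_TYPES]
    if not others:
        raise TypeError(
            "result_type() requires at least one array or dtype; "
            "got only Python scalars"
        )
    return others, scalars


print(split_result_type_args("int8", 1, 2.5))  # (['int8'], [1, 2.5])
```

Raising, rather than falling back to a library's default dtype, avoids the arbitrary choice discussed above.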
Making this change to `result_type` …
The array API's type promotion rules support mixed scalar/array operations, e.g., `1 + xp.arange(3)`. For Xarray, we would like to be able to figure out the resulting dtype from this sort of operation before actually doing it (pydata/xarray#8946).
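To illustrate what predicting the dtype before doing the operation could look like for one array and one Python scalar, here is a hypothetical pure-Python sketch. It roughly follows the array API standard's rules for Python scalar operands: the scalar adopts the array's dtype when its type is compatible with the array's kind, and integers must fit in range. Dtype names are plain strings, only a subset of dtypes is covered, and mixed-kind cases the standard leaves to libraries (e.g. a Python float with an integer array) simply raise here. `scalar_array_result_dtype` is an illustrative name.

```python
# A subset of array API integer dtypes with their value ranges.
INT_BOUNDS = {
    "int8": (-2**7, 2**7 - 1),
    "int16": (-2**15, 2**15 - 1),
    "int32": (-2**31, 2**31 - 1),
    "int64": (-2**63, 2**63 - 1),
}
# Kind of each supported dtype.
KIND = {"bool": "bool", "float32": "float", "float64": "float",
        "complex64": "complex", "complex128": "complex"}
KIND.update({name: "int" for name in INT_BOUNDS})

# Python scalar types that may be combined with arrays of each kind.
ALLOWED_SCALARS = {
    "bool": (bool,),
    "int": (int,),
    "float": (int, float),
    "complex": (int, float, complex),
}


def scalar_array_result_dtype(array_dtype, scalar):
    """Predict the dtype of `scalar <op> array` without computing it.

    The scalar is "weak": it adopts the array's dtype, provided its type
    is compatible with the array's kind and (for integers) it fits.
    """
    kind = KIND[array_dtype]
    # Exact type check, so True is not accepted as an int.
    if type(scalar) not in ALLOWED_SCALARS[kind]:
        raise TypeError(f"{type(scalar).__name__} scalar with {array_dtype} array")
    if kind == "int":
        lo, hi = INT_BOUNDS[array_dtype]
        if not lo <= scalar <= hi:
            raise OverflowError(f"{scalar} does not fit in {array_dtype}")
    return array_dtype
```

For `1 + xp.arange(3)`, the answer depends on the library's default integer dtype for `arange`; with int64, `scalar_array_result_dtype("int64", 1)` returns `"int64"`.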
Ideally, we could use `xp.result_type()` for this purpose, but as documented, `result_type` only supports arrays and dtype objects. Could we potentially extend `result_type` to also handle Python scalars? It is worth noting that this already works today in NumPy.
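For reference, a minimal illustration of that NumPy behavior. The values are chosen for illustration (the first call mirrors an example in NumPy's own `result_type` docstring); in both cases the Python scalar does not widen the array's dtype:

```python
import numpy as np

# A Python int combined with an int8 array stays int8.
print(np.result_type(3, np.arange(7, dtype=np.int8)))     # int8
# A Python float combined with a float32 array stays float32.
print(np.result_type(2.5, np.ones(3, dtype=np.float32)))  # float32
```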