RFC: add support for scalar arguments in result_type #805

Open

shoyer opened this issue May 15, 2024 · 11 comments · May be fixed by #873
Labels: API change (Changes to existing functions or objects in the API.), topic: Type Promotion (Type promotion.)
Milestone: v2024

Comments

@shoyer (Contributor) commented May 15, 2024

The array API's type promotion rules support mixed scalar/array operations, e.g., 1 + xp.arange(3).

For Xarray, we would like to be able to figure out the resulting dtype from this sort of operation before actually doing it (pydata/xarray#8946).

Ideally, we could use xp.result_type() for this purpose, but as documented, result_type only supports arrays and dtype objects. Could we potentially extend result_type to also handle Python scalars? It is worth noting that this already works today in NumPy, e.g.,

>>> np.result_type(1, np.arange(3))
dtype('int64')
@asmeurer (Member)

This makes sense to me. torch seems to support this as well. What should the result be if there are multiple Python scalars? Undefined?
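
For reference, torch.result_type takes exactly two arguments, each of which may be a tensor or a Python number, so a quick check (output from a recent PyTorch; shown only as an illustration) looks like:

>>> import torch
>>> torch.result_type(torch.arange(3), 1.5)
torch.float32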

@shoyer (Contributor, Author) commented May 15, 2024

What should the result be if there are multiple Python scalars? Undefined?

This should indeed probably be undefined by the spec.

In most cases I imagine array libraries will have a default dtype, but different libraries will make different choices (e.g., int32 in JAX vs int64 in NumPy):

>>> np.result_type(1, 2)
dtype('int64')
>>> jnp.result_type(1, 2)
dtype('int32')

@asmeurer (Member)

One concern I see with this is that libraries need not support Python scalars in functions, only for operators. So result_type(a, b) working does not imply that func(a, b) will work.
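
To make that concrete, a minimal sketch with a hypothetical standard-compliant namespace xp (operators must accept Python scalars, functions are not required to):

>>> x = xp.arange(3)
>>> 1.5 + x          # guaranteed to work by the operator promotion rules
>>> xp.add(1.5, x)   # not guaranteed; a strict implementation may reject the scalar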

@shoyer (Contributor, Author) commented May 15, 2024

One concern I see with this is that libraries need not support Python scalars in functions, only for operators. So result_type(a, b) working does not imply that func(a, b) will work.

In Xarray, we are thinking of defining something like:

def as_shared_dtype(scalars_or_arrays):
    # Resolve the array namespace from the array arguments.
    xp = get_array_namespace(scalars_or_arrays)
    # Apply the namespace's promotion rules to scalars and arrays alike.
    dtype = xp.result_type(*scalars_or_arrays)
    # dtype is keyword-only in the array API's asarray signature.
    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)
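
As a hypothetical usage sketch (the scalar becomes a 0-d array of the shared dtype, so a downstream function such as where never sees a Python scalar):

>>> x = xp.arange(3)
>>> data, fill = as_shared_dtype([x, 0.5])
>>> xp.where(x > 1, data, fill)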

@asmeurer (Member) commented May 15, 2024

Does xarray automatically call asarray on scalar function arguments like NumPy does? The standard's recommendation is certainly not to do that, since avoiding it is cleaner from a typing perspective. Implicitly calling asarray at the top of every function is considered a historical NumPy antipattern. It's not disallowed, but we should probably also avoid standardizing things that encourage it.

@keewis commented May 15, 2024

The only time we call that function is when preparing arguments for where (and for concat / stack, though there we don't expect to encounter Python scalars). As far as I can tell, where doesn't support Python scalars.

@shoyer (Contributor, Author) commented May 15, 2024

Xarray objects always contain array objects, but indeed there are functions like where() for which it's convenient to be able to use scalars.

I opened a separate issue to discuss: #807

@rgommers added the topic: Type Promotion label May 17, 2024
@rgommers (Member)

This sounds like a useful change to me.

What should the result be if there are multiple Python scalars? Undefined?

This should indeed probably be undefined by the spec.

What is the problem? It seems well-defined to allow multiple. If multiple arrays and dtype objects are allowed, why not multiple Python scalars?

@keewis commented May 17, 2024

I'm not sure, but I think that was referring to a situation where you have no explicit dtypes, just (compatible) Python scalars. In that case, we'd have to make an arbitrary choice (or raise an error).

@rgommers (Member)

Ah of course. Agreed, there must be at least one array or dtype object.

@rgommers (Member)

Making this change to result_type seemed fair to everyone in the discussion we just had. Given that our type promotion rules include Python scalars, the function that can be used to apply those promotion rules should support them as well.
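
As a sketch of the expected behavior once extended (NumPy already behaves this way; per the discussion above, at least one array or dtype argument would still be required):

>>> x = np.arange(3, dtype=np.int32)
>>> np.result_type(x, 1.0)
dtype('float64')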

@rgommers added this to the v2024 milestone Dec 11, 2024
@kgryte changed the title from "result_type() for mixed arrays/Python scalars" to "RFC: add support for scalar arguments in result_type" Dec 12, 2024
@kgryte added the API change label Dec 12, 2024
@kgryte added a commit to kgryte/array-api that referenced this issue Dec 12, 2024
@kgryte linked pull request #873 Dec 12, 2024 that will close this issue