-
Notifications
You must be signed in to change notification settings - Fork 651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix torch.linspace #2416
base: main
Are you sure you want to change the base?
fix torch.linspace #2416
Conversation
Thanks for looking into this. Please add your unit test to the To get things running locally, see our Building from Source Document. |
This existing test fails: Here is a minimal example when the inf start coming in arange = mb.range_1d(start=0.0, end=2_000_000.0, step=1.0) # no infs
res = mb.add(x=arange, y=1.0) # now we have infs! (same also if we do mul instead of add) Do you have an idea how to avoid that? Based on the test name "fp16", this sounds somewhat expected: 2_000_000 > 65_504? |
The |
Yes, it works on main and fails on this PR:
This PR is closer to the pytorch and numpy behavior: import numpy as np
import torch
torch.linspace(0, 2_000_000, 2_000_000).to(dtype=torch.float16)
# tensor([0., 1., 2., ..., inf, inf, inf], dtype=torch.float16)
torch.linspace(0, 2_000_000, 2_000_000, dtype=torch.float16)
# RuntimeError: value cannot be converted to type at::Half without overflow
np.linspace(0, 2_000_000, 2_000_000, dtype=np.float16)
# array([ 0., 1., 2., ..., inf, inf, inf], dtype=float16)
np.linspace(0, 2_000_000, 2_000_000).astype(np.float16)
# array([ 0., 1., 2., ..., inf, inf, inf], dtype=float16) It migt be good to explicilty cast the actual and expected output in edit: this static test ends up in the dynamic code because |
I extended an existing test to trigger the reported bug on main (these tests pass on this PR):
|
I xfailed the test on fp16 for now. Converting the pytorch expected results to float16 in
|
fixes #2412
@TobyRoseman where should I add testcases for this? I grep'ed through the tests, there are currently none for
torch.linspace
. And having quick tutorial how to get the test locally running would be great, too :) GettingRuntimeError: BlobWriter not loaded
when callingct.convert
.