Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix Beam Search for CPU and CUDA... Also include it in our API (#669)
### Summary

Beam Search was hanging and not outputting correct results. Furthermore, it did not fit into our API design. This PR addresses the correctness and API problems with Beam Search. We plan to improve both CPU and CUDA performance and memory efficiency in the near future.

### Issues Addressed

Here is a quick summary of the issues addressed by this PR. These apply to both the CPU and CUDA implementations:

- No log-softmax normalization was performed before adding beam scores. This caused faulty outputs which did not match the ORT implementation.
- The `is_done` flag was not set or checked properly in the case of an EOS token or `max_sequence_length`. This caused hanging, infinite looping, and memory buffer overflow. This sometimes gave the impression of bad performance, while in reality it was a correctness issue.
- `Finalize` was not called automatically. If a user didn't call it manually, this could cause a floating point exception or other fault.
- There was no easy way to get output from Beam Search. `Finalize` was clunky and unintuitive as it didn't fit with our API.
- Our testing file was not up to date with our latest APIs.

### API Changes

Given the issues with `Finalize`, this PR introduces an update to the way Beam Search fits into our API. The user no longer has to manually call `Finalize` in order to access the Beam Search results. These are returned automatically by the `Generate()` function and can be accessed using batch beam indexing.

---------

Co-authored-by: Baiju Meswani <[email protected]>
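To illustrate the first two issues listed above, here is a minimal Python sketch of a beam search loop for a single batch entry. It is not the code from this PR: `log_softmax`, `beam_step`, `beam_search`, and the `next_logits` model callback are hypothetical names chosen for illustration. It assumes raw logits are normalized with log-softmax before being added to the running beam scores, and that the loop stops once every beam has hit EOS or the maximum length.

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def beam_step(beam_scores, beam_logits, num_beams):
    # Expand every beam by every token. Scores are sums of log-probabilities,
    # so logits are normalized first; adding raw logits would make scores
    # from different beams incomparable.
    candidates = []
    for b, (score, logits) in enumerate(zip(beam_scores, beam_logits)):
        for tok, lp in enumerate(log_softmax(logits)):
            candidates.append((score + lp, b, tok))
    candidates.sort(key=lambda c: c[0], reverse=True)
    return candidates[:num_beams]


def beam_search(next_logits, num_beams, eos_token_id, max_length):
    # next_logits(sequence) -> list of logits (hypothetical model callback).
    beams = [(0.0, [])]   # live hypotheses: (score, tokens)
    finished = []         # hypotheses that ended with EOS or hit max_length
    is_done = False
    while not is_done:
        scores = [s for s, _ in beams]
        logits = [next_logits(seq) for _, seq in beams]
        new_beams = []
        for score, b, tok in beam_step(scores, logits, num_beams):
            seq = beams[b][1] + [tok]
            if tok == eos_token_id or len(seq) >= max_length:
                finished.append((score, seq))
            else:
                new_beams.append((score, seq))
        beams = new_beams
        # Without this check the loop would never terminate.
        is_done = not beams or len(finished) >= num_beams
    finished.sort(key=lambda f: f[0], reverse=True)
    return finished[:num_beams]
```

Because every sequence grows by one token per iteration and is retired at `max_length`, the loop is guaranteed to end. This sketch applies no length penalty and keeps all state in Python lists, so it says nothing about the CPU or CUDA memory layout used in the repository.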
1 parent 6166902, commit e41fb2c
Showing 16 changed files with 213 additions and 204 deletions.