JAX is strict about buffer sizes, so users are generally expected to apply models with a batch size equal to `ModelArgs.max_batch_size`. However, we neither forbid other batch sizes nor handle them properly. Depending on the code path, a different batch size either causes a matrix shape mismatch or silently returns an array of the wrong size, among other problems.
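A minimal sketch of how both failure modes can arise, assuming (hypothetically) that the model keeps a KV cache preallocated with shape `(max_batch_size, max_seq_len, dim)`. `MAX_BATCH_SIZE`, `MAX_SEQ_LEN`, `DIM`, `cache`, and `write_cache` are illustrative names, not the project's actual code.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes; in the real code these would come from ModelArgs.
MAX_BATCH_SIZE, MAX_SEQ_LEN, DIM = 4, 8, 16

# Cache preallocated for the full batch, as is common with static shapes in JAX.
cache = jnp.zeros((MAX_BATCH_SIZE, MAX_SEQ_LEN, DIM))

def write_cache(cache, x, start_pos):
    # dynamic_update_slice accepts an update smaller than the buffer, so a
    # smaller batch is written into the first rows only, without any error.
    return jax.lax.dynamic_update_slice(cache, x, (0, start_pos, 0))

x = jnp.ones((2, 1, DIM))  # batch size 2 < MAX_BATCH_SIZE
new_cache = write_cache(cache, x, 0)
print(new_cache.shape)  # (4, 8, 16): the result still has max_batch_size rows

# Attention-style matmul between a batch-2 query and the batch-4 cache:
# the batch dimensions (2 vs 4) cannot be broadcast, so this raises.
q = jnp.ones((2, 1, DIM))
try:
    scores = jnp.matmul(q, new_cache.transpose(0, 2, 1))
except (TypeError, ValueError) as err:
    print("shape error:", err)
```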
We need to either validate the batch size of input tensors or properly handle batch sizes smaller than `max_batch_size`.
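Both options can be sketched as small wrappers around the model call. This is a minimal sketch, not the project's implementation: `MAX_BATCH_SIZE`, `validate_batch`, `pad_batch`, `apply_with_padding`, and `pad_id=0` are hypothetical, and `model_fn` stands in for whatever function applies the model to a `(batch, seq_len)` token array.

```python
import jax.numpy as jnp

MAX_BATCH_SIZE = 4  # hypothetical stand-in for ModelArgs.max_batch_size

def validate_batch(tokens, max_batch_size=MAX_BATCH_SIZE):
    """Option 1: reject inputs whose batch size differs from max_batch_size."""
    if tokens.shape[0] != max_batch_size:
        raise ValueError(
            f"expected batch size {max_batch_size}, got {tokens.shape[0]}"
        )
    return tokens

def pad_batch(tokens, max_batch_size=MAX_BATCH_SIZE, pad_id=0):
    """Option 2: pad a smaller batch up to max_batch_size along axis 0.

    Returns the padded array and the original batch size, so the caller
    can slice the model output back down afterwards.
    """
    batch = tokens.shape[0]
    if batch > max_batch_size:
        raise ValueError(
            f"batch size {batch} exceeds max_batch_size {max_batch_size}"
        )
    # Pad only the batch axis; leave every other axis untouched.
    pad_width = [(0, max_batch_size - batch)] + [(0, 0)] * (tokens.ndim - 1)
    return jnp.pad(tokens, pad_width, constant_values=pad_id), batch

def apply_with_padding(model_fn, tokens, max_batch_size=MAX_BATCH_SIZE):
    padded, batch = pad_batch(tokens, max_batch_size)
    out = model_fn(padded)
    return out[:batch]  # drop the rows that correspond to padding

# Usage with a trivial stand-in model (hypothetical).
tokens = jnp.ones((2, 5), dtype=jnp.int32)
fake_model = lambda x: x * 2
out = apply_with_padding(fake_model, tokens)
print(out.shape)  # (2, 5)
```

Padding keeps every call at the same static shape, so a jitted function is not retraced for each new batch size. The cost is wasted compute on the padding rows, and the approach is only correct if rows in a batch do not interact with each other.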