Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove QuerySet alias hacks via PEP 696 TypeVar defaults #2104

Merged
merged 5 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,31 +254,6 @@ func(MyModel.objects.annotate(foo=Value("")).get(id=1)) # OK
func(MyModel.objects.annotate(bar=Value("")).get(id=1)) # Error
```

### How do I check if something is an instance of QuerySet in runtime?

A limitation of making `QuerySet` generic is that you can not use
it for `isinstance` checks.

```python
from django.db.models.query import QuerySet

def foo(obj: object) -> None:
if isinstance(obj, QuerySet): # Error: Parameterized generics cannot be used with class or instance checks
...
```

To get around with this issue without making `QuerySet` non-generic,
Django-stubs provides `django_stubs_ext.QuerySetAny`, a non-generic
variant of `QuerySet` suitable for runtime type checking:

```python
from django_stubs_ext import QuerySetAny

def foo(obj: object) -> None:
if isinstance(obj, QuerySetAny): # OK
...
```

### Why am I getting incompatible argument type mentioning `_StrPromise`?

The lazy translation functions of Django (such as `gettext_lazy`) return a `Promise` instead of `str`. These two types [cannot be used interchangeably](https://github.com/typeddjango/django-stubs/pull/1139#issuecomment-1232167698). The return type of these functions was therefore [changed](https://github.com/typeddjango/django-stubs/pull/689) to reflect that.
Expand Down
12 changes: 4 additions & 8 deletions django-stubs/db/models/manager.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ from django.db.models.expressions import Combinable, OrderBy
from django.db.models.query import QuerySet, RawQuerySet
from typing_extensions import Self

from django_stubs_ext import ValuesQuerySet

_T = TypeVar("_T", bound=Model, covariant=True)

class BaseManager(Generic[_T]):
Expand Down Expand Up @@ -107,15 +105,13 @@ class BaseManager(Generic[_T]):
using: str | None = ...,
) -> RawQuerySet: ...
# The type of values may be overridden to be more specific in the mypy plugin, depending on the fields param
def values(self, *fields: str | Combinable, **expressions: Any) -> ValuesQuerySet[_T, dict[str, Any]]: ...
def values(self, *fields: str | Combinable, **expressions: Any) -> QuerySet[_T, dict[str, Any]]: ...
# The type of values_list may be overridden to be more specific in the mypy plugin, depending on the fields param
def values_list(
self, *fields: str | Combinable, flat: bool = ..., named: bool = ...
) -> ValuesQuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> ValuesQuerySet[_T, datetime.date]: ...
def values_list(self, *fields: str | Combinable, flat: bool = ..., named: bool = ...) -> QuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet[_T, datetime.date]: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: datetime.tzinfo | None = ...
) -> ValuesQuerySet[_T, datetime.datetime]: ...
) -> QuerySet[_T, datetime.datetime]: ...
def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
Expand Down
96 changes: 49 additions & 47 deletions django-stubs/db/models/query.pyi
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
import datetime
from collections.abc import AsyncIterator, Collection, Iterable, Iterator, MutableMapping, Sequence, Sized
from typing import Any, Generic, NamedTuple, TypeVar, overload
from typing import Any, Generic, NamedTuple, overload

from django.db.backends.utils import _ExecuteQuery
from django.db.models import Manager
from django.db.models.base import Model
from django.db.models.expressions import Combinable, OrderBy
from django.db.models.sql.query import Query, RawQuery
from django.utils.functional import cached_property
from typing_extensions import Self, TypeAlias
from typing_extensions import Self, TypeAlias, TypeVar

_T = TypeVar("_T", bound=Model, covariant=True)
_Row = TypeVar("_Row", covariant=True)
_T = TypeVar("_T", covariant=True)
_Model = TypeVar("_Model", bound=Model, covariant=True)
_Row = TypeVar("_Row", covariant=True, default=_Model) # ONLY use together with _Model
_QS = TypeVar("_QS", bound=_QuerySet)
_TupleT = TypeVar("_TupleT", bound=tuple[Any, ...], covariant=True)

MAX_GET_RESULTS: int
REPR_OUTPUT_SIZE: int

class BaseIterable(Generic[_Row]):
class BaseIterable(Generic[_T]):
queryset: QuerySet[Model]
chunked_fetch: bool
chunk_size: int
def __init__(self, queryset: QuerySet[Model], chunked_fetch: bool = ..., chunk_size: int = ...) -> None: ...
def __aiter__(self) -> AsyncIterator[_Row]: ...
def __aiter__(self) -> AsyncIterator[_T]: ...

class ModelIterable(Generic[_T], BaseIterable[_T]):
def __iter__(self) -> Iterator[_T]: ...
class ModelIterable(Generic[_Model], BaseIterable[_Model]):
def __iter__(self) -> Iterator[_Model]: ...

class RawModelIterable(BaseIterable[dict[str, Any]]):
def __iter__(self) -> Iterator[dict[str, Any]]: ...
Expand All @@ -40,11 +41,11 @@ class ValuesListIterable(BaseIterable[_TupleT]):
class NamedValuesListIterable(ValuesListIterable[NamedTuple]):
def __iter__(self) -> Iterator[NamedTuple]: ...

class FlatValuesListIterable(BaseIterable[_Row]):
def __iter__(self) -> Iterator[_Row]: ...
class FlatValuesListIterable(BaseIterable[_T]):
def __iter__(self) -> Iterator[_T]: ...

class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
model: type[_T]
class QuerySet(Generic[_Model, _Row], Iterable[_Row], Sized):
model: type[_Model]
query: Query
_iterable_class: type[BaseIterable]
_result_cache: list[_Row] | None
Expand All @@ -56,14 +57,14 @@ class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
hints: dict[str, Model] | None = ...,
) -> None: ...
@classmethod
def as_manager(cls) -> Manager[_T]: ...
def as_manager(cls) -> Manager[_Model]: ...
def __len__(self) -> int: ...
def __bool__(self) -> bool: ...
def __class_getitem__(cls: type[_QS], item: type[_T]) -> type[_QS]: ...
def __class_getitem__(cls: type[_QS], item: type[_Model]) -> type[_QS]: ...
def __getstate__(self) -> dict[str, Any]: ...
# Technically, the other QuerySet must be of the same type _T, but _T is covariant
def __and__(self, other: _QuerySet[_T, _Row]) -> Self: ...
def __or__(self, other: _QuerySet[_T, _Row]) -> Self: ...
def __and__(self, other: QuerySet[_Model, _Row]) -> Self: ...
def __or__(self, other: QuerySet[_Model, _Row]) -> Self: ...
# IMPORTANT: When updating any of the following methods' signatures, please ALSO modify
# the corresponding method in BaseManager.
def iterator(self, chunk_size: int | None = ...) -> Iterator[_Row]: ...
Expand All @@ -72,44 +73,46 @@ class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
async def aaggregate(self, *args: Any, **kwargs: Any) -> dict[str, Any]: ...
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
async def aget(self, *args: Any, **kwargs: Any) -> _Row: ...
def create(self, **kwargs: Any) -> _T: ...
async def acreate(self, **kwargs: Any) -> _T: ...
def create(self, **kwargs: Any) -> _Model: ...
async def acreate(self, **kwargs: Any) -> _Model: ...
def bulk_create(
self,
objs: Iterable[_T],
objs: Iterable[_Model],
batch_size: int | None = ...,
ignore_conflicts: bool = ...,
update_conflicts: bool = ...,
update_fields: Collection[str] | None = ...,
unique_fields: Collection[str] | None = ...,
) -> list[_T]: ...
) -> list[_Model]: ...
async def abulk_create(
self,
objs: Iterable[_T],
objs: Iterable[_Model],
batch_size: int | None = ...,
ignore_conflicts: bool = ...,
update_conflicts: bool = ...,
update_fields: Collection[str] | None = ...,
unique_fields: Collection[str] | None = ...,
) -> list[_T]: ...
def bulk_update(self, objs: Iterable[_T], fields: Iterable[str], batch_size: int | None = ...) -> int: ...
async def abulk_update(self, objs: Iterable[_T], fields: Iterable[str], batch_size: int | None = ...) -> int: ...
def get_or_create(self, defaults: MutableMapping[str, Any] | None = ..., **kwargs: Any) -> tuple[_T, bool]: ...
) -> list[_Model]: ...
def bulk_update(self, objs: Iterable[_Model], fields: Iterable[str], batch_size: int | None = ...) -> int: ...
async def abulk_update(
self, objs: Iterable[_Model], fields: Iterable[str], batch_size: int | None = ...
) -> int: ...
def get_or_create(self, defaults: MutableMapping[str, Any] | None = ..., **kwargs: Any) -> tuple[_Model, bool]: ...
async def aget_or_create(
self, defaults: MutableMapping[str, Any] | None = ..., **kwargs: Any
) -> tuple[_T, bool]: ...
) -> tuple[_Model, bool]: ...
def update_or_create(
self,
defaults: MutableMapping[str, Any] | None = ...,
create_defaults: MutableMapping[str, Any] | None = ...,
**kwargs: Any,
) -> tuple[_T, bool]: ...
) -> tuple[_Model, bool]: ...
async def aupdate_or_create(
self,
defaults: MutableMapping[str, Any] | None = ...,
create_defaults: MutableMapping[str, Any] | None = ...,
**kwargs: Any,
) -> tuple[_T, bool]: ...
) -> tuple[_Model, bool]: ...
def earliest(self, *fields: str | OrderBy) -> _Row: ...
async def aearliest(self, *fields: str | OrderBy) -> _Row: ...
def latest(self, *fields: str | OrderBy) -> _Row: ...
Expand All @@ -118,8 +121,8 @@ class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
async def afirst(self) -> _Row | None: ...
def last(self) -> _Row | None: ...
async def alast(self) -> _Row | None: ...
def in_bulk(self, id_list: Iterable[Any] | None = ..., *, field_name: str = ...) -> dict[Any, _T]: ...
async def ain_bulk(self, id_list: Iterable[Any] | None = ..., *, field_name: str = ...) -> dict[Any, _T]: ...
def in_bulk(self, id_list: Iterable[Any] | None = ..., *, field_name: str = ...) -> dict[Any, _Model]: ...
async def ain_bulk(self, id_list: Iterable[Any] | None = ..., *, field_name: str = ...) -> dict[Any, _Model]: ...
def delete(self) -> tuple[int, dict[str, int]]: ...
async def adelete(self) -> tuple[int, dict[str, int]]: ...
def update(self, **kwargs: Any) -> int: ...
Expand All @@ -138,13 +141,13 @@ class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
using: str | None = ...,
) -> RawQuerySet: ...
# The type of values may be overridden to be more specific in the mypy plugin, depending on the fields param
def values(self, *fields: str | Combinable, **expressions: Any) -> _QuerySet[_T, dict[str, Any]]: ...
def values(self, *fields: str | Combinable, **expressions: Any) -> QuerySet[_Model, dict[str, Any]]: ...
# The type of values_list may be overridden to be more specific in the mypy plugin, depending on the fields param
def values_list(self, *fields: str | Combinable, flat: bool = ..., named: bool = ...) -> _QuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> _QuerySet[_T, datetime.date]: ...
def values_list(self, *fields: str | Combinable, flat: bool = ..., named: bool = ...) -> QuerySet[_Model, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet[_Model, datetime.date]: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: datetime.tzinfo | None = ...
) -> _QuerySet[_T, datetime.datetime]: ...
) -> QuerySet[_Model, datetime.datetime]: ...
def none(self) -> Self: ...
def all(self) -> Self: ...
def filter(self, *args: Any, **kwargs: Any) -> Self: ...
Expand Down Expand Up @@ -173,7 +176,7 @@ class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
tables: Sequence[str] | None = ...,
order_by: Sequence[str] | None = ...,
select_params: Sequence[Any] | None = ...,
) -> _QuerySet[Any, Any]: ...
) -> QuerySet[Any, Any]: ...
def reverse(self) -> Self: ...
def defer(self, *fields: Any) -> Self: ...
def only(self, *fields: Any) -> Self: ...
Expand All @@ -192,7 +195,7 @@ class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
def __getitem__(self, s: slice) -> Self: ...
def __reversed__(self) -> Iterator[_Row]: ...

class RawQuerySet(Iterable[_T], Sized):
class RawQuerySet(Iterable[_Model], Sized):
query: RawQuery
def __init__(
self,
Expand All @@ -205,28 +208,27 @@ class RawQuerySet(Iterable[_T], Sized):
hints: dict[str, Model] | None = ...,
) -> None: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T]: ...
def __iter__(self) -> Iterator[_Model]: ...
def __bool__(self) -> bool: ...
@overload
def __getitem__(self, k: int) -> _T: ...
def __getitem__(self, k: int) -> _Model: ...
@overload
def __getitem__(self, k: str) -> Any: ...
@overload
def __getitem__(self, k: slice) -> RawQuerySet[_T]: ...
def __getitem__(self, k: slice) -> RawQuerySet[_Model]: ...
@cached_property
def columns(self) -> list[str]: ...
@property
def db(self) -> str: ...
def iterator(self) -> Iterator[_T]: ...
def iterator(self) -> Iterator[_Model]: ...
@cached_property
def model_fields(self) -> dict[str, str]: ...
def prefetch_related(self, *lookups: Any) -> RawQuerySet[_T]: ...
def prefetch_related(self, *lookups: Any) -> RawQuerySet[_Model]: ...
def resolve_model_init_order(self) -> tuple[list[str], list[int], list[tuple[str, int]]]: ...
def using(self, alias: str | None) -> RawQuerySet[_T]: ...

_QuerySetAny: TypeAlias = _QuerySet # noqa: PYI047
def using(self, alias: str | None) -> RawQuerySet[_Model]: ...

QuerySet: TypeAlias = _QuerySet[_T, _T]
# Deprecated alias of QuerySet, for compatibility only.
_QuerySet: TypeAlias = QuerySet

class Prefetch:
prefetch_through: str
Expand All @@ -240,8 +242,8 @@ class Prefetch:
def get_current_to_attr(self, level: int) -> tuple[str, str]: ...
def get_current_queryset(self, level: int) -> QuerySet | None: ...

def prefetch_related_objects(model_instances: Iterable[_T], *related_lookups: str | Prefetch) -> None: ...
async def aprefetch_related_objects(model_instances: Iterable[_T], *related_lookups: str | Prefetch) -> None: ...
def prefetch_related_objects(model_instances: Iterable[_Model], *related_lookups: str | Prefetch) -> None: ...
async def aprefetch_related_objects(model_instances: Iterable[_Model], *related_lookups: str | Prefetch) -> None: ...
def get_prefetcher(instance: Model, through_attr: str, to_attr: str) -> tuple[Any, Any, bool, bool]: ...

class InstanceCheckMeta(type): ...
Expand Down
4 changes: 3 additions & 1 deletion ext/django_stubs_ext/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
from django.utils.functional import _StrOrPromise as StrOrPromise
from django.utils.functional import _StrPromise as StrPromise

# Deprecated type aliases. Use the QuerySet class directly instead.
QuerySetAny = _QuerySet
ValuesQuerySet = _QuerySet
else:
from django.db.models.query import QuerySet
from django.utils.functional import Promise as StrPromise

StrOrPromise = typing.Union[str, StrPromise]
# Deprecated type aliases. Use the QuerySet class directly instead.
QuerySetAny = QuerySet
ValuesQuerySet = QuerySet
StrOrPromise = typing.Union[str, StrPromise]

__all__ = ["StrOrPromise", "StrPromise", "QuerySetAny", "ValuesQuerySet"]
2 changes: 1 addition & 1 deletion mypy_django_plugin/lib/fullnames.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
DUMMY_SETTINGS_BASE_CLASS = "django.conf._DjangoConfLazyObject"
AUTH_USER_MODEL_FULLNAME = "django.conf.settings.AUTH_USER_MODEL"

QUERYSET_CLASS_FULLNAME = "django.db.models.query._QuerySet"
QUERYSET_CLASS_FULLNAME = "django.db.models.query.QuerySet"
BASE_MANAGER_CLASS_FULLNAME = "django.db.models.manager.BaseManager"
MANAGER_CLASS_FULLNAME = "django.db.models.manager.Manager"
RELATED_MANAGER_CLASS = "django.db.models.fields.related_descriptors.RelatedManager"
Expand Down
2 changes: 1 addition & 1 deletion mypy_django_plugin/transformers/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def example(self, a: T2) -> T_2: ...
return False

if type_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
# If it is a subclass of _QuerySet, it is compatible.
# If it is a subclass of QuerySet, it is compatible.
return True
# check that at least one base is a subclass of queryset with Generic type vars
return any(_has_compatible_type_vars(sub_base.type) for sub_base in type_info.bases)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def find_stub_files(name: str) -> List[str]:
"django-stubs-ext>=5.0.0",
"tomli; python_version < '3.11'",
# Types:
"typing-extensions",
"typing-extensions>=4.11.0",
"types-PyYAML",
]

Expand Down
8 changes: 4 additions & 4 deletions tests/typecheck/contrib/admin/test_decorators.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
def method_action_invalid_fancy(self, request: HttpRequest, queryset: int) -> None: ...

def method(self) -> None:
reveal_type(self.method_action_bare) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query._QuerySet[main.MyModel, main.MyModel])"
reveal_type(self.method_action_fancy) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query._QuerySet[main.MyModel, main.MyModel])"
reveal_type(self.method_action_http_response) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query._QuerySet[main.MyModel, main.MyModel]) -> django.http.response.HttpResponse"
reveal_type(self.method_action_file_response) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query._QuerySet[main.MyModel, main.MyModel]) -> django.http.response.FileResponse"
reveal_type(self.method_action_bare) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query.QuerySet[main.MyModel, main.MyModel])"
reveal_type(self.method_action_fancy) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query.QuerySet[main.MyModel, main.MyModel])"
reveal_type(self.method_action_http_response) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query.QuerySet[main.MyModel, main.MyModel]) -> django.http.response.HttpResponse"
reveal_type(self.method_action_file_response) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query.QuerySet[main.MyModel, main.MyModel]) -> django.http.response.FileResponse"
2 changes: 1 addition & 1 deletion tests/typecheck/contrib/admin/test_options.yml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@
pass

class A(admin.ModelAdmin):
actions = [an_action] # E: List item 0 has incompatible type "Callable[[None], None]"; expected "Union[Callable[[Any, HttpRequest, _QuerySet[Any, Any]], Optional[HttpResponseBase]], str]" [list-item]
actions = [an_action] # E: List item 0 has incompatible type "Callable[[None], None]"; expected "Union[Callable[[Any, HttpRequest, QuerySet[Any, Any]], Optional[HttpResponseBase]], str]" [list-item]
- case: errors_for_invalid_model_admin_generic
main: |
from django.contrib.admin import ModelAdmin
Expand Down
2 changes: 1 addition & 1 deletion tests/typecheck/contrib/sitemaps/test_generic_sitemap.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
main:26: error: Argument 1 of "location" is incompatible with supertype "Sitemap"; supertype defines the argument type as "Offer" [override]
main:26: note: This violates the Liskov substitution principle
main:26: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
main:40: error: Argument 1 to "GenericSitemap" has incompatible type "Dict[str, List[int]]"; expected "Mapping[str, Union[datetime, _QuerySet[Offer, Offer], str]]" [arg-type]
main:40: error: Argument 1 to "GenericSitemap" has incompatible type "Dict[str, List[int]]"; expected "Mapping[str, Union[datetime, QuerySet[Offer, Offer], str]]" [arg-type]

installed_apps:
- myapp
Expand Down
Loading
Loading