Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TypeIs for types with type params in Unions #17232

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 41 additions & 31 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def __init__(
always_covariant: bool = False,
ignore_promotions: bool = False,
# Proper subtype flags
erase_instances: bool = False,
keep_erased_types: bool = False,
options: Options | None = None,
) -> None:
Expand All @@ -104,7 +103,6 @@ def __init__(
self.ignore_declared_variance = ignore_declared_variance
self.always_covariant = always_covariant
self.ignore_promotions = ignore_promotions
self.erase_instances = erase_instances
self.keep_erased_types = keep_erased_types
self.options = options

Expand All @@ -114,7 +112,7 @@ def check_context(self, proper_subtype: bool) -> None:
if proper_subtype:
assert not self.ignore_pos_arg_names and not self.ignore_declared_variance
else:
assert not self.erase_instances and not self.keep_erased_types
assert not self.keep_erased_types


def is_subtype(
Expand Down Expand Up @@ -191,27 +189,23 @@ def is_proper_subtype(
*,
subtype_context: SubtypeContext | None = None,
ignore_promotions: bool = False,
erase_instances: bool = False,
keep_erased_types: bool = False,
) -> bool:
"""Is left a proper subtype of right?

For proper subtypes, there's no need to rely on compatibility due to
Any types. Every usable type is a proper subtype of itself.

If erase_instances is True, erase left instance *after* mapping it to supertype
(this is useful for runtime isinstance() checks). If keep_erased_types is True,
do not consider ErasedType a subtype of all types (used by type inference against unions).
If keep_erased_types is True, do not consider ErasedType a subtype
of all types (used by type inference against unions).
"""
if subtype_context is None:
subtype_context = SubtypeContext(
ignore_promotions=ignore_promotions,
erase_instances=erase_instances,
keep_erased_types=keep_erased_types,
ignore_promotions=ignore_promotions, keep_erased_types=keep_erased_types
)
else:
assert not any(
{ignore_promotions, erase_instances, keep_erased_types}
{ignore_promotions, keep_erased_types}
), "Don't pass both context and individual flags"
if type_state.is_assumed_proper_subtype(left, right):
return True
Expand Down Expand Up @@ -403,7 +397,6 @@ def build_subtype_kind(subtype_context: SubtypeContext, proper_subtype: bool) ->
subtype_context.ignore_declared_variance,
subtype_context.always_covariant,
subtype_context.ignore_promotions,
subtype_context.erase_instances,
subtype_context.keep_erased_types,
)

Expand Down Expand Up @@ -527,10 +520,6 @@ def visit_instance(self, left: Instance) -> bool:
) and not self.subtype_context.ignore_declared_variance:
# Map left type to corresponding right instances.
t = map_instance_to_supertype(left, right.type)
if self.subtype_context.erase_instances:
erased = erase_type(t)
assert isinstance(erased, Instance)
t = erased
nominal = True
if right.type.has_type_var_tuple_type:
# For variadic instances we simply find the correct type argument mappings,
Expand Down Expand Up @@ -1929,7 +1918,8 @@ def restrict_subtype_away(t: Type, s: Type) -> Type:
ideal result (just t is a valid result).

This is used for type inference of runtime type checks such as
isinstance(). Currently, this just removes elements of a union type.
isinstance() or TypeIs. Currently, this just removes elements
of a union type.
"""
p_t = get_proper_type(t)
if isinstance(p_t, UnionType):
Expand All @@ -1938,46 +1928,66 @@ def restrict_subtype_away(t: Type, s: Type) -> Type:
new_items = [
restrict_subtype_away(item, s)
for item in p_t.relevant_items()
if (isinstance(get_proper_type(item), AnyType) or not covers_at_runtime(item, s))
if isinstance(get_proper_type(item), UnionType) or not covers_type(item, s)
]
return UnionType.make_union(new_items)
elif isinstance(p_t, TypeVarType):
return p_t.copy_modified(upper_bound=restrict_subtype_away(p_t.upper_bound, s))
elif covers_at_runtime(t, s):
elif covers_type(t, s):
return UninhabitedType()
else:
return t


def covers_at_runtime(item: Type, supertype: Type) -> bool:
"""Will isinstance(item, supertype) always return True at runtime?"""
def covers_type(item: Type, supertype: Type) -> bool:
"""Returns if item is covered by supertype.

Any types (or fallbacks to any) should never cover or be covered.

Assumes that item is not a Union type.

Examples:
int covered by int
List[int] covered by List[Any]
A covered by Union[A, Any]
Any NOT covered by int
int NOT covered by Any
"""
item = get_proper_type(item)
supertype = get_proper_type(supertype)

# Since runtime type checks will ignore type arguments, erase the types.
supertype = erase_type(supertype)
if is_proper_subtype(
erase_type(item), supertype, ignore_promotions=True, erase_instances=True
assert not isinstance(item, UnionType)

# Handle possible Any types that should not be covered:
if isinstance(item, AnyType) or isinstance(supertype, AnyType):
return False
elif (isinstance(item, Instance) and item.type.fallback_to_any) or (
isinstance(supertype, Instance) and supertype.type.fallback_to_any
):
return True
if isinstance(supertype, Instance):
return is_same_type(item, supertype)

if isinstance(supertype, UnionType):
# Special case that cannot be handled by is_subtype, because it would
# not ignore the Any types:
return any(covers_type(item, t) for t in supertype.relevant_items())
elif isinstance(supertype, Instance):
if supertype.type.is_protocol:
# TODO: Implement more robust support for runtime isinstance() checks, see issue #3827.
if is_proper_subtype(item, supertype, ignore_promotions=True):
if is_proper_subtype(item, erase_type(supertype), ignore_promotions=True):
return True
if isinstance(item, TypedDictType):
# Special case useful for selecting TypedDicts from unions using isinstance(x, dict).
if supertype.type.fullname == "builtins.dict":
return True
elif isinstance(item, TypeVarType):
if is_proper_subtype(item.upper_bound, supertype, ignore_promotions=True):
if is_proper_subtype(item.upper_bound, erase_type(supertype), ignore_promotions=True):
return True
elif isinstance(item, Instance) and supertype.type.fullname == "builtins.int":
# "int" covers all native int types
if item.type.fullname in MYPYC_NATIVE_INT_NAMES:
return True
# TODO: Add more special cases.
return False

return is_subtype(item, supertype, ignore_promotions=True)


def is_more_precise(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool:
Expand Down
17 changes: 17 additions & 0 deletions test-data/unit/check-isinstance.test
Original file line number Diff line number Diff line change
Expand Up @@ -2161,6 +2161,23 @@ else:
reveal_type(z) # N: Revealed type is "Any"
[builtins fixtures/isinstance.pyi]

[case testIsinstanceSubclassAny]
from typing import Any, Union
X: Any
class BadParent(X): pass
class GoodParent(object): pass
a: Union[GoodParent, BadParent]
if isinstance(a, BadParent):
reveal_type(a) # N: Revealed type is "__main__.BadParent"
else:
reveal_type(a) # N: Revealed type is "Union[__main__.GoodParent, __main__.BadParent]"
b: Union[int, BadParent]
if isinstance(b, (X, GoodParent)):
reveal_type(b) # N: Revealed type is "Union[Any, __main__.BadParent]"
else:
reveal_type(b) # N: Revealed type is "Union[builtins.int, __main__.BadParent]"
[builtins fixtures/isinstance.pyi]

[case testIsInstanceInitialNoneCheckSkipsImpossibleCasesNoStrictOptional]
from typing import Optional, Union

Expand Down
35 changes: 35 additions & 0 deletions test-data/unit/check-typeis.test
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ def main(a: object) -> None:
reveal_type(a) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/tuple.pyi]

[case testTypeIsUnionWithTypeParams]
from typing_extensions import TypeIs
from typing import Iterable, List, Union
def is_iterable_int(val: object) -> TypeIs[Iterable[int]]: pass
def main(a: Union[List[int], List[str]]) -> None:
if is_iterable_int(a):
reveal_type(a) # N: Revealed type is "builtins.list[builtins.int]"
else:
reveal_type(a) # N: Revealed type is "builtins.list[builtins.str]"
[builtins fixtures/tuple.pyi]

[case testTypeIsNonzeroFloat]
from typing_extensions import TypeIs
def is_nonzero(a: object) -> TypeIs[float]: pass
Expand Down Expand Up @@ -155,6 +166,30 @@ class C:
def is_float(self, a: object) -> TypeIs[float]: pass
[builtins fixtures/tuple.pyi]

[case testTypeIsTypeAny]
from typing_extensions import TypeIs
from typing import Any, Type, Union
class A: ...
def is_class(x: object) -> TypeIs[Type[Any]]: ...
def main(a: Union[A, Type[A]]) -> None:
if is_class(a):
reveal_type(a) # N: Revealed type is "Type[Any]"
else:
reveal_type(a) # N: Revealed type is "__main__.A"
[builtins fixtures/tuple.pyi]

[case testTypeIsAwaitableAny]
from typing_extensions import TypeIs
from typing import Any, Awaitable, TypeVar, Union
T = TypeVar('T')
def is_awaitable(val: object) -> TypeIs[Awaitable[Any]]: pass
def main(a: Union[Awaitable[T], T]) -> None:
if is_awaitable(a):
reveal_type(a) # N: Revealed type is "Union[typing.Awaitable[T`-1], typing.Awaitable[Any]]"
else:
reveal_type(a) # N: Revealed type is "T`-1"
[builtins fixtures/tuple.pyi]

[case testTypeIsCrossModule]
import guard
from points import Point
Expand Down
Loading