diff --git a/src/azul/indexer/__init__.py b/src/azul/indexer/__init__.py index 7699275666..2b6a618b38 100644 --- a/src/azul/indexer/__init__.py +++ b/src/azul/indexer/__init__.py @@ -19,13 +19,9 @@ Optional, Type, TypeVar, - get_args, ) import attr -from more_itertools import ( - one, -) from azul import ( reject, @@ -35,6 +31,7 @@ MutableJSON, MutableJSONs, SupportsLessThan, + get_generic_type_params, ) from azul.uuids import ( UUIDPartition, @@ -350,8 +347,7 @@ def from_json(cls, ref: JSON) -> 'SourceRef': @classmethod def spec_cls(cls) -> Type[SourceSpec]: - base_cls = one(getattr(cls, '__orig_bases__')) - spec_cls, ref_cls = get_args(base_cls) + spec_cls, ref_cls = get_generic_type_params(cls, SourceSpec, SourceRef) return spec_cls diff --git a/src/azul/plugins/__init__.py b/src/azul/plugins/__init__.py index 2f5d7a20e2..2869746420 100644 --- a/src/azul/plugins/__init__.py +++ b/src/azul/plugins/__init__.py @@ -24,19 +24,14 @@ TypeVar, TypedDict, Union, - get_args, ) import attr -from more_itertools import ( - one, -) from azul import ( CatalogName, cached_property, config, - require, ) from azul.chalice import ( Authentication, @@ -51,8 +46,6 @@ Bundle, SOURCE_REF, SOURCE_SPEC, - SourceRef, - SourceSpec, SourcedBundleFQID, ) from azul.indexer.document import ( @@ -65,6 +58,7 @@ JSON, JSONs, MutableJSON, + get_generic_type_params, ) if TYPE_CHECKING: @@ -446,10 +440,7 @@ def list_sources(self, @cached_property def _source_ref_cls(self) -> Type[SOURCE_REF]: cls = type(self) - base_cls = one(getattr(cls, '__orig_bases__')) - spec_cls, ref_cls = get_args(base_cls) - require(issubclass(spec_cls, SourceSpec)) - require(issubclass(ref_cls, SourceRef)) + spec_cls, ref_cls = get_generic_type_params(cls) assert ref_cls.spec_cls() is spec_cls return ref_cls diff --git a/src/azul/types.py b/src/azul/types.py index 2c8c1b0993..b667bdcc3c 100644 --- a/src/azul/types.py +++ b/src/azul/types.py @@ -4,6 +4,7 @@ ) from typing import ( Any, + Generic, Optional, Protocol, TYPE_CHECKING, @@ -12,6 +13,8 @@ get_origin, ) +from more_itertools import one + PrimitiveJSON = Union[str, int, float, bool, None] # Not every instance of Mapping or Sequence can be fed to json.dump() but those @@ -158,6 +161,14 @@ def f(t): return tuple(set(f(t))) +def get_generic_type_params(cls: type[Generic], *required_types: type): + base_cls = one(getattr(cls, '__orig_bases__')) + types = get_args(base_cls) + for required_type, type_ in zip(required_types, types): + assert issubclass(type_, required_type) + return types + + # FIXME: Remove hacky import of SupportsLessThan # https://github.com/DataBiosphere/azul/issues/2783 if TYPE_CHECKING: