|
4 | 4 | import dataclasses |
5 | 5 | import sys |
6 | 6 |
|
7 | | -from typing import Any, Dict, List, Optional, Union |
| 7 | +from typing import Any, Dict, List, Optional, Tuple, Union |
8 | 8 |
|
| 9 | +from pytype import utils |
9 | 10 | from pytype.pyi import classdef |
10 | 11 | from pytype.pyi import metadata |
11 | 12 | from pytype.pyi import types |
|
19 | 20 | from pytype.pytd.codegen import namedtuple |
20 | 21 | from pytype.pytd.codegen import pytdgen |
21 | 22 | from pytype.pytd.parse import node as pytd_node |
| 23 | +from pytype.pytd.parse import parser_constants |
22 | 24 |
|
23 | 25 | # pylint: disable=g-import-not-at-top |
24 | 26 | if sys.version_info >= (3, 8): |
|
31 | 33 | # Typing members that represent sets of types. |
32 | 34 | _TYPING_SETS = ("typing.Intersection", "typing.Optional", "typing.Union") |
33 | 35 |
|
34 | | -# Aliases for some typing.X types |
35 | | -_ANNOTATED_TYPES = ("typing.Annotated", "typing_extensions.Annotated") |
36 | | -_CALLABLE_TYPES = ("typing.Callable", "collections.abc.Callable") |
37 | | -_CONCATENATE_TYPES = ("typing.Concatenate", "typing_extensions.Concatenate") |
38 | | -_LITERAL_TYPES = ("typing.Literal", "typing_extensions.Literal") |
39 | | -_TUPLE_TYPES = ("tuple", "builtins.tuple", "typing.Tuple") |
40 | | - |
41 | 36 |
|
42 | 37 | class StringParseError(ParseError): |
43 | 38 | pass |
@@ -493,23 +488,43 @@ def add_import(self, from_package, import_list): |
493 | 488 | if t: |
494 | 489 | self.aliases[t.new_name] = t.pytd_alias() |
495 | 490 |
|
496 | | - def _matches_full_name(self, t, full_name): |
497 | | - """Whether t.name matches full_name in format {module}.{member}.""" |
498 | | - return pytd_utils.MatchesFullName( |
499 | | - t, full_name, self.module_info.module_name, self.aliases) |
| 491 | + def _resolve_alias(self, name: str) -> str: |
| 492 | + if name in self.aliases: |
| 493 | + alias = self.aliases[name].type |
| 494 | + if isinstance(alias, pytd.NamedType): |
| 495 | + name = alias.name |
| 496 | + elif isinstance(alias, pytd.Module): |
| 497 | + name = alias.module_name |
| 498 | + return name |
| 499 | + |
| 500 | + def matches_type(self, name: str, target: Union[str, Tuple[str, ...]]): |
| 501 | + """Checks whether 'name' matches the 'target' type.""" |
| 502 | + if isinstance(target, tuple): |
| 503 | + return any(self.matches_type(name, t) for t in target) |
| 504 | + assert "." in target, "'target' must be a fully qualified type name" |
| 505 | + if "." in name: |
| 506 | + prefix, name_base = name.rsplit(".", 1) |
| 507 | + name = f"{self._resolve_alias(prefix)}.{name_base}" |
| 508 | + else: |
| 509 | + name = self._resolve_alias(name) |
| 510 | + name = utils.strip_prefix(name, parser_constants.EXTERNAL_NAME_PREFIX) |
| 511 | + if name == target: |
| 512 | + return True |
| 513 | + module, target_base = target.rsplit(".", 1) |
| 514 | + if name == target_base: |
| 515 | + return True |
| 516 | + if module == "builtins": |
| 517 | + return self.matches_type(name, f"typing.{target_base.title()}") |
| 518 | + equivalent_modules = {"typing", "collections.abc", "typing_extensions"} |
| 519 | + if module not in equivalent_modules: |
| 520 | + return False |
| 521 | + return any(name == f"{mod}.{target_base}" for mod in equivalent_modules) |
500 | 522 |
|
501 | 523 | def _matches_named_type(self, t, names): |
502 | 524 | """Whether t is a NamedType matching any of names.""" |
503 | 525 | if not isinstance(t, pytd.NamedType): |
504 | 526 | return False |
505 | | - for name in names: |
506 | | - if "." in name: |
507 | | - if self._matches_full_name(t, name): |
508 | | - return True |
509 | | - else: |
510 | | - if t.name == name: |
511 | | - return True |
512 | | - return False |
| 527 | + return self.matches_type(t.name, names) |
513 | 528 |
|
514 | 529 | def _is_empty_tuple(self, t): |
515 | 530 | return isinstance(t, pytd.TupleType) and not t.parameters |
@@ -551,22 +566,22 @@ def _remove_unsupported_features(self, parameters): |
551 | 566 |
|
552 | 567 | def _parameterized_type(self, base_type: Any, parameters): |
553 | 568 | """Return a parameterized type.""" |
554 | | - if self._matches_named_type(base_type, _LITERAL_TYPES): |
| 569 | + if self._matches_named_type(base_type, "typing.Literal"): |
555 | 570 | return pytd_literal(parameters, self.aliases) |
556 | | - elif self._matches_named_type(base_type, _ANNOTATED_TYPES): |
| 571 | + elif self._matches_named_type(base_type, "typing.Annotated"): |
557 | 572 | return pytd_annotated(parameters) |
558 | 573 | self._verify_no_literal_parameters(base_type, parameters) |
559 | 574 | arg_is_paramspec = False |
560 | | - if self._matches_named_type(base_type, _TUPLE_TYPES): |
| 575 | + if self._matches_named_type(base_type, "builtins.tuple"): |
561 | 576 | if len(parameters) == 2 and parameters[1] is self.ELLIPSIS: |
562 | 577 | parameters = parameters[:1] |
563 | 578 | builder = pytd.GenericType |
564 | 579 | else: |
565 | 580 | builder = pytdgen.heterogeneous_tuple |
566 | | - elif self._matches_named_type(base_type, _CONCATENATE_TYPES): |
| 581 | + elif self._matches_named_type(base_type, "typing.Concatenate"): |
567 | 582 | assert parameters |
568 | 583 | builder = pytd.Concatenate |
569 | | - elif self._matches_named_type(base_type, _CALLABLE_TYPES): |
| 584 | + elif self._matches_named_type(base_type, "typing.Callable"): |
570 | 585 | if parameters[0] is self.ELLIPSIS: |
571 | 586 | parameters = (pytd.AnythingType(),) + parameters[1:] |
572 | 587 | if parameters and isinstance(parameters[0], pytd.NamedType): |
@@ -661,7 +676,7 @@ def build_class( |
661 | 676 | self, class_name, bases, keywords, decorators, defs |
662 | 677 | ) -> pytd.Class: |
663 | 678 | """Build a pytd.Class from definitions collected from an ast node.""" |
664 | | - bases = classdef.get_bases(bases) |
| 679 | + bases = classdef.get_bases(bases, self.matches_type) |
665 | 680 | keywords = classdef.get_keywords(keywords) |
666 | 681 | constants, methods, aliases, slots, classes = _split_definitions(defs) |
667 | 682 |
|
|
0 commit comments