From d8f200debe69843ab73e495016f556b015fd851e Mon Sep 17 00:00:00 2001 From: thisLight Date: Sat, 6 Aug 2022 16:08:19 +0800 Subject: [PATCH 01/15] fix __eq__() and __ne__() type hints --- pypika/queries.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pypika/queries.py b/pypika/queries.py index 2adc1b15..12d5a3b3 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -72,7 +72,7 @@ def get_sql(self, **kwargs: Any) -> str: return self.name return self.query.get_sql(**kwargs) - def __eq__(self, other: "AliasedQuery") -> bool: + def __eq__(self, other: Any) -> bool: return isinstance(other, AliasedQuery) and self.name == other.name def __hash__(self) -> int: @@ -84,10 +84,10 @@ def __init__(self, name: str, parent: Optional["Schema"] = None) -> None: self._name = name self._parent = parent - def __eq__(self, other: "Schema") -> bool: + def __eq__(self, other: Any) -> bool: return isinstance(other, Schema) and self._name == other._name and self._parent == other._parent - def __ne__(self, other: "Schema") -> bool: + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) @ignore_copy @@ -181,7 +181,7 @@ def for_portion(self, period_criterion: PeriodCriterion) -> "Table": def __str__(self) -> str: return self.get_sql(quote_char='"') - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: if not isinstance(other, Table): return False @@ -1193,7 +1193,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.__str__() - def __eq__(self, other: "QueryBuilder") -> bool: + def __eq__(self, other: Any) -> bool: if not isinstance(other, QueryBuilder): return False @@ -1202,7 +1202,7 @@ def __eq__(self, other: "QueryBuilder") -> bool: return True - def __ne__(self, other: "QueryBuilder") -> bool: + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) def __hash__(self) -> int: From 2c668c2b861dd58a581596cf4ebbb14cfbc679db Mon Sep 17 00:00:00 2001 From: thisLight Date: Sat, 6 Aug 2022 17:14:06 +0800 Subject: [PATCH 02/15] fix @builder type hint --- pypika/dialects.py | 48 +++++++++---------- pypika/queries.py | 112 ++++++++++++++++++++++----------------------- pypika/terms.py | 42 ++++++++--------- pypika/utils.py | 22 ++++++++- 4 files changed, 121 insertions(+), 103 deletions(-) diff --git a/pypika/dialects.py b/pypika/dialects.py index 6e151d68..7642ab8d 100644 --- a/pypika/dialects.py +++ b/pypika/dialects.py @@ -105,14 +105,14 @@ def __copy__(self) -> "MySQLQueryBuilder": @builder def for_update( self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = () - ) -> "QueryBuilder": + ): self._for_update = True self._for_update_skip_locked = skip_locked self._for_update_nowait = nowait self._for_update_of = set(of) @builder - def on_duplicate_key_update(self, field: Union[Field, str], value: Any) -> "MySQLQueryBuilder": + def on_duplicate_key_update(self, field: Union[Field, str], value: Any): if self._ignore_duplicates: raise QueryException("Can not have two conflict handlers") @@ -120,7 +120,7 @@ def on_duplicate_key_update(self, field: Union[Field, str], value: Any) -> "MySQ self._duplicate_updates.append((field, ValueWrapper(value))) @builder - def on_duplicate_key_ignore(self) -> "MySQLQueryBuilder": + def on_duplicate_key_ignore(self): if self._duplicate_updates: raise QueryException("Can not have two conflict handlers") @@ -162,7 +162,7 @@ def _on_duplicate_key_ignore_sql(self) -> str: return " ON DUPLICATE KEY IGNORE" @builder - def modifier(self, value: str) -> "MySQLQueryBuilder": + def modifier(self, value: str): """ Adds a modifier such as SQL_CALC_FOUND_ROWS to the query. https://dev.mysql.com/doc/refman/5.7/en/select.html @@ -191,11 +191,11 @@ def __init__(self) -> None: self._into_table = None @builder - def load(self, fp: str) -> "MySQLLoadQueryBuilder": + def load(self, fp: str): self._load_file = fp @builder - def into(self, table: Union[str, Table]) -> "MySQLLoadQueryBuilder": + def into(self, table: Union[str, Table]): self._into_table = table if isinstance(table, Table) else Table(table) def get_sql(self, *args: Any, **kwargs: Any) -> str: @@ -254,7 +254,7 @@ def __init__(self, **kwargs: Any) -> None: self._hint = None @builder - def hint(self, label: str) -> "VerticaQueryBuilder": + def hint(self, label: str): self._hint = label def get_sql(self, *args: Any, **kwargs: Any) -> str: @@ -275,14 +275,14 @@ def __init__(self) -> None: self._preserve_rows = False @builder - def local(self) -> "VerticaCreateQueryBuilder": + def local(self): if not self._temporary: raise AttributeError("'Query' object has no attribute temporary") self._local = True @builder - def preserve_rows(self) -> "VerticaCreateQueryBuilder": + def preserve_rows(self): if not self._temporary: raise AttributeError("'Query' object has no attribute temporary") @@ -318,11 +318,11 @@ def __init__(self) -> None: self._from_file = None @builder - def from_file(self, fp: str) -> "VerticaCopyQueryBuilder": + def from_file(self, fp: str): self._from_file = fp @builder - def copy_(self, table: Union[str, Table]) -> "VerticaCopyQueryBuilder": + def copy_(self, table: Union[str, Table]): self._copy_table = table if isinstance(table, Table) else Table(table) def get_sql(self, *args: Any, **kwargs: Any) -> str: @@ -410,7 +410,7 @@ def __copy__(self) -> "PostgreSQLQueryBuilder": return newone @builder - def distinct_on(self, *fields: Union[str, Term]) -> "PostgreSQLQueryBuilder": + def distinct_on(self, *fields: Union[str, Term]): for field in fields: if isinstance(field, str): self._distinct_on.append(Field(field)) @@ -420,7 +420,7 @@ def distinct_on(self, *fields: Union[str, Term]) -> "PostgreSQLQueryBuilder": @builder def for_update( self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = () - ) -> "QueryBuilder": + ): self._for_update = True self._for_update_skip_locked = skip_locked self._for_update_nowait = nowait @@ -440,7 +440,7 @@ def on_conflict(self, *target_fields: Union[str, Term]) -> "PostgreSQLQueryBuild self._on_conflict_fields.append(target_field) @builder - def do_nothing(self) -> "PostgreSQLQueryBuilder": + def do_nothing(self): if len(self._on_conflict_do_updates) > 0: raise QueryException("Can not have two conflict handlers") self._on_conflict_do_nothing = True @@ -448,7 +448,7 @@ def do_nothing(self) -> "PostgreSQLQueryBuilder": @builder def do_update( self, update_field: Union[str, Field], update_value: Optional[Any] = None - ) -> "PostgreSQLQueryBuilder": + ): if self._on_conflict_do_nothing: raise QueryException("Can not have two conflict handlers") @@ -465,7 +465,7 @@ def do_update( self._on_conflict_do_updates.append((field, None)) @builder - def where(self, criterion: Criterion) -> "PostgreSQLQueryBuilder": + def where(self, criterion: Criterion): if not self._on_conflict: return super().where(criterion) @@ -489,7 +489,7 @@ def where(self, criterion: Criterion) -> "PostgreSQLQueryBuilder": raise QueryException('Can not have fieldless ON CONFLICT WHERE') @builder - def using(self, table: Union[Selectable, str]) -> "QueryBuilder": + def using(self, table: Union[Selectable, str]): self._using.append(table) def _distinct_sql(self, **kwargs: Any) -> str: @@ -567,7 +567,7 @@ def _on_conflict_action_sql(self, **kwargs: Any) -> str: return '' @builder - def returning(self, *terms: Any) -> "PostgreSQLQueryBuilder": + def returning(self, *terms: Any): for term in terms: if isinstance(term, Field): self._return_field(term) @@ -680,7 +680,7 @@ def __init__(self, **kwargs: Any) -> None: self._top_percent: bool = False @builder - def top(self, value: Union[str, int], percent: bool = False, with_ties: bool = False) -> "MSSQLQueryBuilder": + def top(self, value: Union[str, int], percent: bool = False, with_ties: bool = False): """ Implements support for simple TOP clauses. https://docs.microsoft.com/en-us/sql/t-sql/queries/top-transact-sql?view=sql-server-2017 @@ -696,7 +696,7 @@ def top(self, value: Union[str, int], percent: bool = False, with_ties: bool = F self._top_with_ties: bool = with_ties @builder - def fetch_next(self, limit: int) -> "MSSQLQueryBuilder": + def fetch_next(self, limit: int): # Overridden to provide a more domain-specific API for T-SQL users self._limit = limit @@ -813,15 +813,15 @@ def __init__(self): self._cluster_name = None @builder - def drop_dictionary(self, dictionary: str) -> "ClickHouseDropQueryBuilder": + def drop_dictionary(self, dictionary: str): super()._set_target('DICTIONARY', dictionary) @builder - def drop_quota(self, quota: str) -> "ClickHouseDropQueryBuilder": + def drop_quota(self, quota: str): super()._set_target('QUOTA', quota) @builder - def on_cluster(self, cluster: str) -> "ClickHouseDropQueryBuilder": + def on_cluster(self, cluster: str): if self._cluster_name: raise AttributeError("'DropQuery' object already has attribute cluster_name") self._cluster_name = cluster @@ -860,7 +860,7 @@ def __init__(self, **kwargs: Any) -> None: self._insert_or_replace = False @builder - def insert_or_replace(self, *terms: Any) -> "SQLLiteQueryBuilder": + def insert_or_replace(self, *terms: Any): self._apply_terms(*terms) self._replace = True self._insert_or_replace = True diff --git a/pypika/queries.py b/pypika/queries.py index 12d5a3b3..b076e6b1 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -39,7 +39,7 @@ def __init__(self, alias: str) -> None: self.alias = alias @builder - def as_(self, alias: str) -> "Selectable": + def as_(self, alias: str): self.alias = alias def field(self, name: str) -> Field: @@ -163,7 +163,7 @@ def get_sql(self, **kwargs: Any) -> str: return format_alias_sql(table_sql, self.alias, **kwargs) @builder - def for_(self, temporal_criterion: Criterion) -> "Table": + def for_(self, temporal_criterion: Criterion): if self._for: raise AttributeError("'Query' object already has attribute for_") if self._for_portion: @@ -171,7 +171,7 @@ def for_(self, temporal_criterion: Criterion) -> "Table": self._for = temporal_criterion @builder - def for_portion(self, period_criterion: PeriodCriterion) -> "Table": + def for_portion(self, period_criterion: PeriodCriterion): if self._for_portion: raise AttributeError("'Query' object already has attribute for_portion") if self._for: @@ -542,7 +542,7 @@ def __init__( self._wrapper_cls = wrapper_cls @builder - def orderby(self, *fields: Field, **kwargs: Any) -> "_SetOperation": + def orderby(self, *fields: Field, **kwargs: Any): for field in fields: field = ( Field(field, table=self.base_query._from[0]) @@ -553,31 +553,31 @@ def orderby(self, *fields: Field, **kwargs: Any) -> "_SetOperation": self._orderbys.append((field, kwargs.get("order"))) @builder - def limit(self, limit: int) -> "_SetOperation": + def limit(self, limit: int): self._limit = limit @builder - def offset(self, offset: int) -> "_SetOperation": + def offset(self, offset: int): self._offset = offset @builder - def union(self, other: Selectable) -> "_SetOperation": + def union(self, other: Selectable): self._set_operation.append((SetOperation.union, other)) @builder - def union_all(self, other: Selectable) -> "_SetOperation": + def union_all(self, other: Selectable): self._set_operation.append((SetOperation.union_all, other)) @builder - def intersect(self, other: Selectable) -> "_SetOperation": + def intersect(self, other: Selectable): self._set_operation.append((SetOperation.intersect, other)) @builder - def except_of(self, other: Selectable) -> "_SetOperation": + def except_of(self, other: Selectable): self._set_operation.append((SetOperation.except_of, other)) @builder - def minus(self, other: Selectable) -> "_SetOperation": + def minus(self, other: Selectable): self._set_operation.append((SetOperation.minus, other)) def __add__(self, other: Selectable) -> "_SetOperation": @@ -757,7 +757,7 @@ def __copy__(self) -> "QueryBuilder": return newone @builder - def from_(self, selectable: Union[Selectable, Query, str]) -> "QueryBuilder": + def from_(self, selectable: Union[Selectable, Query, str]): """ Adds a table to the query. This function can only be called once and will raise an AttributeError if called a second time. @@ -784,7 +784,7 @@ def from_(self, selectable: Union[Selectable, Query, str]) -> "QueryBuilder": self._subquery_count = sub_query_count + 1 @builder - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> "QueryBuilder": + def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -821,12 +821,12 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl self._select_star_tables.add(new_table) @builder - def with_(self, selectable: Selectable, name: str) -> "QueryBuilder": + def with_(self, selectable: Selectable, name: str): t = AliasedQuery(name, selectable) self._with.append(t) @builder - def into(self, table: Union[str, Table]) -> "QueryBuilder": + def into(self, table: Union[str, Table]): if self._insert_table is not None: raise AttributeError("'Query' object has no attribute '%s'" % "into") @@ -836,7 +836,7 @@ def into(self, table: Union[str, Table]) -> "QueryBuilder": self._insert_table = table if isinstance(table, Table) else Table(table) @builder - def select(self, *terms: Any) -> "QueryBuilder": + def select(self, *terms: Any): for term in terms: if isinstance(term, Field): self._select_field(term) @@ -848,21 +848,21 @@ def select(self, *terms: Any) -> "QueryBuilder": self._select_other(self.wrap_constant(term, wrapper_cls=self._wrapper_cls)) @builder - def delete(self) -> "QueryBuilder": + def delete(self): if self._delete_from or self._selects or self._update_table: raise AttributeError("'Query' object has no attribute '%s'" % "delete") self._delete_from = True @builder - def update(self, table: Union[str, Table]) -> "QueryBuilder": + def update(self, table: Union[str, Table]): if self._update_table is not None or self._selects or self._delete_from: raise AttributeError("'Query' object has no attribute '%s'" % "update") self._update_table = table if isinstance(table, Table) else Table(table) @builder - def columns(self, *terms: Any) -> "QueryBuilder": + def columns(self, *terms: Any): if self._insert_table is None: raise AttributeError("'Query' object has no attribute '%s'" % "insert") @@ -875,17 +875,17 @@ def columns(self, *terms: Any) -> "QueryBuilder": self._columns.append(term) @builder - def insert(self, *terms: Any) -> "QueryBuilder": + def insert(self, *terms: Any): self._apply_terms(*terms) self._replace = False @builder - def replace(self, *terms: Any) -> "QueryBuilder": + def replace(self, *terms: Any): self._apply_terms(*terms) self._replace = True @builder - def force_index(self, term: Union[str, Index], *terms: Union[str, Index]) -> "QueryBuilder": + def force_index(self, term: Union[str, Index], *terms: Union[str, Index]): for t in (term, *terms): if isinstance(t, Index): self._force_indexes.append(t) @@ -893,7 +893,7 @@ def force_index(self, term: Union[str, Index], *terms: Union[str, Index]) -> "Qu self._force_indexes.append(Index(t)) @builder - def use_index(self, term: Union[str, Index], *terms: Union[str, Index]) -> "QueryBuilder": + def use_index(self, term: Union[str, Index], *terms: Union[str, Index]): for t in (term, *terms): if isinstance(t, Index): self._use_indexes.append(t) @@ -901,19 +901,19 @@ def use_index(self, term: Union[str, Index], *terms: Union[str, Index]) -> "Quer self._use_indexes.append(Index(t)) @builder - def distinct(self) -> "QueryBuilder": + def distinct(self): self._distinct = True @builder - def for_update(self) -> "QueryBuilder": + def for_update(self): self._for_update = True @builder - def ignore(self) -> "QueryBuilder": + def ignore(self): self._ignore = True @builder - def prewhere(self, criterion: Criterion) -> "QueryBuilder": + def prewhere(self, criterion: Criterion): if not self._validate_table(criterion): self._foreign_table = True @@ -923,7 +923,7 @@ def prewhere(self, criterion: Criterion) -> "QueryBuilder": self._prewheres = criterion @builder - def where(self, criterion: Union[Term, EmptyCriterion]) -> "QueryBuilder": + def where(self, criterion: Union[Term, EmptyCriterion]): if isinstance(criterion, EmptyCriterion): return @@ -936,7 +936,7 @@ def where(self, criterion: Union[Term, EmptyCriterion]) -> "QueryBuilder": self._wheres = criterion @builder - def having(self, criterion: Union[Term, EmptyCriterion]) -> "QueryBuilder": + def having(self, criterion: Union[Term, EmptyCriterion]): if isinstance(criterion, EmptyCriterion): return @@ -946,7 +946,7 @@ def having(self, criterion: Union[Term, EmptyCriterion]) -> "QueryBuilder": self._havings = criterion @builder - def groupby(self, *terms: Union[str, int, Term]) -> "QueryBuilder": + def groupby(self, *terms: Union[str, int, Term]): for term in terms: if isinstance(term, str): term = Field(term, table=self._from[0]) @@ -956,11 +956,11 @@ def groupby(self, *terms: Union[str, int, Term]) -> "QueryBuilder": self._groupbys.append(term) @builder - def with_totals(self) -> "QueryBuilder": + def with_totals(self): self._with_totals = True @builder - def rollup(self, *terms: Union[list, tuple, set, Term], **kwargs: Any) -> "QueryBuilder": + def rollup(self, *terms: Union[list, tuple, set, Term], **kwargs: Any): for_mysql = "mysql" == kwargs.get("vendor") if self._mysql_rollup: @@ -986,7 +986,7 @@ def rollup(self, *terms: Union[list, tuple, set, Term], **kwargs: Any) -> "Query self._groupbys.append(Rollup(*terms)) @builder - def orderby(self, *fields: Any, **kwargs: Any) -> "QueryBuilder": + def orderby(self, *fields: Any, **kwargs: Any): for field in fields: field = Field(field, table=self._from[0]) if isinstance(field, str) else self.wrap_constant(field) @@ -1040,11 +1040,11 @@ def hash_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner return self.join(item, JoinType.hash) @builder - def limit(self, limit: int) -> "QueryBuilder": + def limit(self, limit: int): self._limit = limit @builder - def offset(self, offset: int) -> "QueryBuilder": + def offset(self, offset: int): self._offset = offset @builder @@ -1068,7 +1068,7 @@ def minus(self, other: "QueryBuilder") -> _SetOperation: return _SetOperation(self, other, SetOperation.minus, wrapper_cls=self._wrapper_cls) @builder - def set(self, field: Union[Field, str], value: Any) -> "QueryBuilder": + def set(self, field: Union[Field, str], value: Any): field = Field(field) if not isinstance(field, Field) else field self._updates.append((field, self._wrapper_cls(value))) @@ -1082,7 +1082,7 @@ def __sub__(self, other: "QueryBuilder") -> _SetOperation: return self.minus(other) @builder - def slice(self, slice: slice) -> "QueryBuilder": + def slice(self, slice: slice): self._offset = slice.start self._limit = slice.stop @@ -1602,7 +1602,7 @@ def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: pass @builder - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> "Join": + def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1644,7 +1644,7 @@ def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: ) @builder - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> "JoinOn": + def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1676,7 +1676,7 @@ def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: pass @builder - def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]) -> "JoinUsing": + def replace_table(self, current_table: Optional[Table], new_table: Optional[Table]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1726,7 +1726,7 @@ def _set_kwargs_defaults(self, kwargs: dict) -> None: kwargs.setdefault("dialect", self.dialect) @builder - def create_table(self, table: Union[Table, str]) -> "CreateQueryBuilder": + def create_table(self, table: Union[Table, str]): """ Creates the table. @@ -1745,7 +1745,7 @@ def create_table(self, table: Union[Table, str]) -> "CreateQueryBuilder": self._create_table = table if isinstance(table, Table) else Table(table) @builder - def temporary(self) -> "CreateQueryBuilder": + def temporary(self): """ Makes the table temporary. @@ -1755,7 +1755,7 @@ def temporary(self) -> "CreateQueryBuilder": self._temporary = True @builder - def unlogged(self) -> "CreateQueryBuilder": + def unlogged(self): """ Makes the table unlogged. @@ -1765,7 +1765,7 @@ def unlogged(self) -> "CreateQueryBuilder": self._unlogged = True @builder - def with_system_versioning(self) -> "CreateQueryBuilder": + def with_system_versioning(self): """ Adds system versioning. @@ -1775,7 +1775,7 @@ def with_system_versioning(self) -> "CreateQueryBuilder": self._with_system_versioning = True @builder - def columns(self, *columns: Union[str, TypedTuple[str, str], Column]) -> "CreateQueryBuilder": + def columns(self, *columns: Union[str, TypedTuple[str, str], Column]): """ Adds the columns. @@ -1803,7 +1803,7 @@ def columns(self, *columns: Union[str, TypedTuple[str, str], Column]) -> "Create @builder def period_for( self, name, start_column: Union[str, Column], end_column: Union[str, Column] - ) -> "CreateQueryBuilder": + ): """ Adds a PERIOD FOR clause. @@ -1822,7 +1822,7 @@ def period_for( self._period_fors.append(PeriodFor(name, start_column, end_column)) @builder - def unique(self, *columns: Union[str, Column]) -> "CreateQueryBuilder": + def unique(self, *columns: Union[str, Column]): """ Adds a UNIQUE constraint. @@ -1837,7 +1837,7 @@ def unique(self, *columns: Union[str, Column]) -> "CreateQueryBuilder": self._uniques.append(self._prepare_columns_input(columns)) @builder - def primary_key(self, *columns: Union[str, Column]) -> "CreateQueryBuilder": + def primary_key(self, *columns: Union[str, Column]): """ Adds a primary key constraint. @@ -1864,7 +1864,7 @@ def foreign_key( reference_columns: List[Union[str, Column]], on_delete: ReferenceOption = None, on_update: ReferenceOption = None, - ) -> "CreateQueryBuilder": + ): """ Adds a foreign key constraint. @@ -1908,7 +1908,7 @@ def foreign_key( self._foreign_key_on_update = on_update @builder - def as_select(self, query_builder: QueryBuilder) -> "CreateQueryBuilder": + def as_select(self, query_builder: QueryBuilder): """ Creates the table from a select statement. @@ -1930,7 +1930,7 @@ def as_select(self, query_builder: QueryBuilder) -> "CreateQueryBuilder": self._as_select = query_builder @builder - def if_not_exists(self) -> "CreateQueryBuilder": + def if_not_exists(self): self._if_not_exists = True def get_sql(self, **kwargs: Any) -> str: @@ -2064,25 +2064,25 @@ def _set_kwargs_defaults(self, kwargs: dict) -> None: kwargs.setdefault("dialect", self.dialect) @builder - def drop_database(self, database: Union[Database, str]) -> "DropQueryBuilder": + def drop_database(self, database: Union[Database, str]): target = database if isinstance(database, Database) else Database(database) self._set_target('DATABASE', target) @builder - def drop_table(self, table: Union[Table, str]) -> "DropQueryBuilder": + def drop_table(self, table: Union[Table, str]): target = table if isinstance(table, Table) else Table(table) self._set_target('TABLE', target) @builder - def drop_user(self, user: str) -> "DropQueryBuilder": + def drop_user(self, user: str): self._set_target('USER', user) @builder - def drop_view(self, view: str) -> "DropQueryBuilder": + def drop_view(self, view: str): self._set_target('VIEW', view) @builder - def if_exists(self) -> "DropQueryBuilder": + def if_exists(self): self._if_exists = True def _set_target(self, kind: str, target: Union[Database, Table, str]) -> None: diff --git a/pypika/terms.py b/pypika/terms.py index c522550a..5b30b8d8 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -44,7 +44,7 @@ def __init__(self, alias: Optional[str] = None) -> None: self.alias = alias @builder - def as_(self, alias: str) -> "Term": + def as_(self, alias: str): self.alias = alias @property @@ -547,7 +547,7 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.table.nodes_() @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Field": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -628,7 +628,7 @@ def is_aggregate(self) -> bool: return resolve_is_aggregate([val.is_aggregate for val in self.values]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Tuple": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -687,7 +687,7 @@ def is_aggregate(self) -> Optional[bool]: return resolve_is_aggregate([term.is_aggregate for term in [self.left, self.right, self.nested]]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "NestedCriterion": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -747,7 +747,7 @@ def is_aggregate(self) -> Optional[bool]: return resolve_is_aggregate([term.is_aggregate for term in [self.left, self.right]]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "BasicCriterion": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -799,7 +799,7 @@ def is_aggregate(self) -> Optional[bool]: return self.term.is_aggregate @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "ContainsCriterion": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -821,7 +821,7 @@ def get_sql(self, subquery: Any = None, **kwargs: Any) -> str: return format_alias_sql(sql, self.alias, **kwargs) @builder - def negate(self) -> "ContainsCriterion": + def negate(self): self._is_negated = True @@ -862,7 +862,7 @@ def is_aggregate(self) -> Optional[bool]: class BetweenCriterion(RangeCriterion): @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "BetweenCriterion": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -907,7 +907,7 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.value.nodes_() @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "BitwiseAndCriterion": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -938,7 +938,7 @@ def nodes_(self) -> Iterator[NodeT]: yield from self.term.nodes_() @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "NullCriterion": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1023,7 +1023,7 @@ def is_aggregate(self) -> Optional[bool]: return resolve_is_aggregate([self.left.is_aggregate, self.right.is_aggregate]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "ArithmeticExpression": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1124,11 +1124,11 @@ def is_aggregate(self) -> Optional[bool]: ) @builder - def when(self, criterion: Any, term: Any) -> "Case": + def when(self, criterion: Any, term: Any): self._cases.append((criterion, self.wrap_constant(term))) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Case": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1205,7 +1205,7 @@ def inner(inner_self, *args, **kwargs): return inner @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Not": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1283,7 +1283,7 @@ def is_aggregate(self) -> Optional[bool]: return resolve_is_aggregate([arg.is_aggregate for arg in self.args]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Function": + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. @@ -1348,7 +1348,7 @@ def __init__(self, name, *args, **kwargs): self._include_filter = False @builder - def filter(self, *filters: Any) -> "AnalyticFunction": + def filter(self, *filters: Any): self._include_filter = True self._filters += filters @@ -1379,12 +1379,12 @@ def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: self._include_over = False @builder - def over(self, *terms: Any) -> "AnalyticFunction": + def over(self, *terms: Any): self._include_over = True self._partition += terms @builder - def orderby(self, *terms: Any, **kwargs: Any) -> "AnalyticFunction": + def orderby(self, *terms: Any, **kwargs: Any): self._include_over = True self._orderbys += [(term, kwargs.get("order")) for term in terms] @@ -1453,11 +1453,11 @@ def _set_frame_and_bounds(self, frame: str, bound: str, and_bound: Optional[Edge self.bound = (bound, and_bound) if and_bound else bound @builder - def rows(self, bound: Union[str, EdgeT], and_bound: Optional[EdgeT] = None) -> "WindowFrameAnalyticFunction": + def rows(self, bound: Union[str, EdgeT], and_bound: Optional[EdgeT] = None): self._set_frame_and_bounds("ROWS", bound, and_bound) @builder - def range(self, bound: Union[str, EdgeT], and_bound: Optional[EdgeT] = None) -> "WindowFrameAnalyticFunction": + def range(self, bound: Union[str, EdgeT], and_bound: Optional[EdgeT] = None): self._set_frame_and_bounds("RANGE", bound, and_bound) def get_frame_sql(self) -> str: @@ -1486,7 +1486,7 @@ def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: self._ignore_nulls = False @builder - def ignore_nulls(self) -> "IgnoreNullsAnalyticFunction": + def ignore_nulls(self): self._ignore_nulls = True def get_special_params_sql(self, **kwargs: Any) -> Optional[str]: diff --git a/pypika/utils.py b/pypika/utils.py index 1506704b..7f661b0d 100644 --- a/pypika/utils.py +++ b/pypika/utils.py @@ -1,4 +1,11 @@ -from typing import Any, Callable, List, Optional, Type +from typing import Any, Callable, List, Optional, Protocol, Type, TYPE_CHECKING +if TYPE_CHECKING: + import sys + from typing import overload, TypeVar + if sys.version_info >= (3, 10): + from typing import ParamSpec, Concatenate + else: + from typing_extensions import ParamSpec, Concatenate __author__ = "Timothy Heys" __email__ = "theys@kayak.com" @@ -36,7 +43,18 @@ class FunctionException(Exception): pass -def builder(func: Callable) -> Callable: +if TYPE_CHECKING: + _T = TypeVar('_T') + _S = TypeVar('_S') + _P = ParamSpec('_P') + +if TYPE_CHECKING: + @overload + def builder(func: Callable[Concatenate[_S, _P], None]) -> Callable[Concatenate[_S, _P], _S]: ... + @overload + def builder(func: Callable[Concatenate[_S, _P], _T]) -> Callable[Concatenate[_S, _P], _T]: ... + +def builder(func): """ Decorator for wrapper "builder" functions. These are functions on the Query class or other classes used for building queries which mutate the query and return self. To make the build functions immutable, this decorator is From ea9dd15135eb43795fd012c513befd753bd8e2f8 Mon Sep 17 00:00:00 2001 From: thisLight Date: Sat, 6 Aug 2022 17:15:03 +0800 Subject: [PATCH 03/15] fix RangeCriterion.__init__ type hint --- pypika/terms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypika/terms.py b/pypika/terms.py index 5b30b8d8..fe520dd2 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -843,7 +843,7 @@ def negate(self): class RangeCriterion(Criterion): - def __init__(self, term: Term, start: Any, end: Any, alias: Optional[str] = None) -> str: + def __init__(self, term: Term, start: Any, end: Any, alias: Optional[str] = None) -> None: super().__init__(alias) self.term = term self.start = start From 417beae8e1b81f2451763b0c2eb03bb339ac83f4 Mon Sep 17 00:00:00 2001 From: thisLight Date: Sat, 6 Aug 2022 17:32:22 +0800 Subject: [PATCH 04/15] fix invaild type comments --- pypika/clickhouse/array.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pypika/clickhouse/array.py b/pypika/clickhouse/array.py index 67929f16..9e9a9147 100644 --- a/pypika/clickhouse/array.py +++ b/pypika/clickhouse/array.py @@ -1,4 +1,5 @@ import abc +from typing import Union from pypika.terms import ( Field, @@ -32,8 +33,8 @@ def get_sql(self): class HasAny(Function): def __init__( self, - left_array: Array or Field, - right_array: Array or Field, + left_array: Union[Array, Field], + right_array: Union[Array, Field], alias: str = None, schema: str = None, ): @@ -56,7 +57,7 @@ def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, diale class _AbstractArrayFunction(Function, metaclass=abc.ABCMeta): - def __init__(self, array: Array or Field, alias: str = None, schema: str = None): + def __init__(self, array: Union[Array, Field], alias: str = None, schema: str = None): self.schema = schema self.alias = alias self.name = self.clickhouse_function() From 720f7d4fd57055474a44cacf8688f150a60d2db4 Mon Sep 17 00:00:00 2001 From: thisLight Date: Sat, 6 Aug 2022 22:49:17 +0800 Subject: [PATCH 05/15] fix problems in terms.py --- pypika/enums.py | 2 +- pypika/terms.py | 196 +++++++++++++++++++++++++++--------------------- 2 files changed, 111 insertions(+), 87 deletions(-) diff --git a/pypika/enums.py b/pypika/enums.py index 751889c4..0db191b6 100644 --- a/pypika/enums.py +++ b/pypika/enums.py @@ -145,7 +145,7 @@ class Dialects(Enum): SNOWFLAKE = "snowflake" -class JSONOperators(Enum): +class JSONOperators(Comparator): HAS_KEY = "?" CONTAINS = "@>" CONTAINED_BY = "<@" diff --git a/pypika/terms.py b/pypika/terms.py index fe520dd2..f861f7e6 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -1,9 +1,10 @@ import inspect import re +import typing import uuid from datetime import date from enum import Enum -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( @@ -28,20 +29,25 @@ class Node: - is_aggregate = None + @property + def is_aggregate(self) -> Optional[bool]: + return None - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator["Node"]: yield self def find_(self, type: Type[NodeT]) -> List[NodeT]: return [node for node in self.nodes_() if isinstance(node, type)] +WrappedConstant = Union[Node, "LiteralValue", "Array", "Tuple", "ValueWrapper"] class Term(Node): - is_aggregate = False - def __init__(self, alias: Optional[str] = None) -> None: self.alias = alias + + @property + def is_aggregate(self) -> Optional[bool]: + return False @builder def as_(self, alias: str): @@ -59,7 +65,7 @@ def fields_(self) -> Set["Field"]: @staticmethod def wrap_constant( val, wrapper_cls: Optional[Type["Term"]] = None - ) -> Union[ValueError, NodeT, "LiteralValue", "Array", "Tuple", "ValueWrapper"]: + ) -> WrappedConstant: """ Used for wrapping raw inputs such as numbers in Criterions and Operator. @@ -149,28 +155,28 @@ def ne(self, other: Any) -> "BasicCriterion": return self != other def glob(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.glob, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.glob, self, Term._assert_guard(self.wrap_constant(expr))) def like(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.like, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.like, self, Term._assert_guard(self.wrap_constant(expr))) def not_like(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.not_like, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.not_like, self, Term._assert_guard(self.wrap_constant(expr))) def ilike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.ilike, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.ilike, self, Term._assert_guard(self.wrap_constant(expr))) def not_ilike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.not_ilike, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.not_ilike, self, Term._assert_guard(self.wrap_constant(expr))) def rlike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.rlike, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.rlike, self, Term._assert_guard(self.wrap_constant(expr))) def regex(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.regex, self, self.wrap_constant(pattern)) + return BasicCriterion(Matching.regex, self, Term._assert_guard(self.wrap_constant(pattern))) def regexp(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.regexp, self, self.wrap_constant(pattern)) + return BasicCriterion(Matching.regexp, self, Term._assert_guard(self.wrap_constant(pattern))) def between(self, lower: Any, upper: Any) -> "BetweenCriterion": return BetweenCriterion(self, self.wrap_constant(lower), self.wrap_constant(upper)) @@ -179,7 +185,7 @@ def from_to(self, start: Any, end: Any) -> "PeriodCriterion": return PeriodCriterion(self, self.wrap_constant(start), self.wrap_constant(end)) def as_of(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.as_of, self, self.wrap_constant(expr)) + return BasicCriterion(Matching.as_of, self, Term._assert_guard(self.wrap_constant(expr))) def all_(self) -> "All": return All(self) @@ -193,7 +199,7 @@ def notin(self, arg: Union[list, tuple, set, "Term"]) -> "ContainsCriterion": return self.isin(arg).negate() def bin_regex(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.bin_regex, self, self.wrap_constant(pattern)) + return BasicCriterion(Matching.bin_regex, self, Term._assert_guard(self.wrap_constant(pattern))) def negate(self) -> "Not": return Not(self) @@ -255,23 +261,23 @@ def __rlshift__(self, other: Any) -> "ArithmeticExpression": def __rrshift__(self, other: Any) -> "ArithmeticExpression": return ArithmeticExpression(Arithmetic.rshift, self.wrap_constant(other), self) - def __eq__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.eq, self, self.wrap_constant(other)) + def __eq__(self, other: Any) -> "BasicCriterion": # type: ignore + return BasicCriterion(Equality.eq, self, Term._assert_guard(self.wrap_constant(other))) - def __ne__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.ne, self, self.wrap_constant(other)) + def __ne__(self, other: Any) -> "BasicCriterion": # type: ignore + return BasicCriterion(Equality.ne, self, Term._assert_guard(self.wrap_constant(other))) def __gt__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.gt, self, self.wrap_constant(other)) + return BasicCriterion(Equality.gt, self, Term._assert_guard(self.wrap_constant(other))) def __ge__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.gte, self, self.wrap_constant(other)) + return BasicCriterion(Equality.gte, self, Term._assert_guard(self.wrap_constant(other))) def __lt__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.lt, self, self.wrap_constant(other)) + return BasicCriterion(Equality.lt, self, Term._assert_guard(self.wrap_constant(other))) def __le__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.lte, self, self.wrap_constant(other)) + return BasicCriterion(Equality.lte, self, Term._assert_guard(self.wrap_constant(other))) def __getitem__(self, item: slice) -> "BetweenCriterion": if not isinstance(item, slice): @@ -286,17 +292,26 @@ def __hash__(self) -> int: def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() + + @classmethod + def _assert_guard(cls, v: Any) -> "Term": + if isinstance(v, cls): + return v + else: + raise TypeError("expect Term object, got {}".format(type(v).__name__)) class Parameter(Term): - is_aggregate = None - def __init__(self, placeholder: Union[str, int]) -> None: super().__init__() self.placeholder = placeholder def get_sql(self, **kwargs: Any) -> str: return str(self.placeholder) + + @property + def is_aggregate(self) -> Optional[bool]: + return None class QmarkParameter(Parameter): @@ -344,7 +359,7 @@ class Negative(Term): def __init__(self, term: Term) -> None: super().__init__() self.term = term - + @property def is_aggregate(self) -> Optional[bool]: return self.term.is_aggregate @@ -354,11 +369,13 @@ def get_sql(self, **kwargs: Any) -> str: class ValueWrapper(Term): - is_aggregate = None - def __init__(self, value: Any, alias: Optional[str] = None) -> None: super().__init__(alias) self.value = value + + @property + def is_aggregate(self) -> Optional[bool]: + return None def get_value_sql(self, **kwargs: Any) -> str: return self.get_formatted_value(self.value, **kwargs) @@ -391,11 +408,10 @@ def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = class JSON(Term): - table = None - def __init__(self, value: Any = None, alias: Optional[str] = None) -> None: super().__init__(alias) self.value = value + self.table: Optional[Union[str, "Selectable"]] = None def _recursive_get_sql(self, value: Any, **kwargs: Any) -> str: if isinstance(value, dict): @@ -429,10 +445,10 @@ def get_sql(self, secondary_quote_char: str = "'", **kwargs: Any) -> str: return format_alias_sql(sql, self.alias, **kwargs) def get_json_value(self, key_or_index: Union[str, int]) -> "BasicCriterion": - return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, self.wrap_constant(key_or_index)) + return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, Term._assert_guard(self.wrap_constant(key_or_index))) def get_text_value(self, key_or_index: Union[str, int]) -> "BasicCriterion": - return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, self.wrap_constant(key_or_index)) + return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, Term._assert_guard(self.wrap_constant(key_or_index))) def get_path_json_value(self, path_json: str) -> "BasicCriterion": return BasicCriterion(JSONOperators.GET_PATH_JSON_VALUE, self, self.wrap_json(path_json)) @@ -512,14 +528,11 @@ def all(terms: Iterable[Any] = ()) -> "EmptyCriterion": return crit - def get_sql(self) -> str: + def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() class EmptyCriterion(Criterion): - is_aggregate = None - tables_ = set() - def fields_(self) -> Set["Field"]: return set() @@ -531,6 +544,14 @@ def __or__(self, other: Any) -> Any: def __xor__(self, other: Any) -> Any: return other + + @property + def is_aggregate(self) -> Optional[bool]: + return None + + @property + def tables_(self) -> Set: + return set() class Field(Criterion, JSON): @@ -541,9 +562,9 @@ def __init__( self.name = name self.table = table - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self - if self.table is not None: + if self.table is not None and not isinstance(self.table, str): yield from self.table.nodes_() @builder @@ -560,7 +581,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T """ self.table = new_table if self.table == current_table else self.table - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, **kwargs: Any) -> str: # type: ignore with_alias = kwargs.pop("with_alias", False) with_namespace = kwargs.pop("with_namespace", False) quote_char = kwargs.pop("quote_char", None) @@ -568,8 +589,8 @@ def get_sql(self, **kwargs: Any) -> str: field_sql = format_quotes(self.name, quote_char) # Need to add namespace if the table has an alias - if self.table and (with_namespace or self.table.alias): - table_name = self.table.get_table_name() + if self.table and (with_namespace or (not isinstance(self.table, str) and self.table.alias)): + table_name = self.table.get_table_name() if not isinstance(self.table, str) else self.table field_sql = "{namespace}.{name}".format( namespace=format_quotes(table_name, quote_char), name=field_sql, @@ -594,16 +615,16 @@ class Star(Field): def __init__(self, table: Optional[Union[str, "Selectable"]] = None) -> None: super().__init__("*", table=table) - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self - if self.table is not None: + if self.table is not None and not isinstance(self.table, str): yield from self.table.nodes_() - def get_sql( + def get_sql( # type: ignore self, with_alias: bool = False, with_namespace: bool = False, quote_char: Optional[str] = None, **kwargs: Any ) -> str: - if self.table and (with_namespace or self.table.alias): - namespace = self.table.alias or getattr(self.table, "_table_name") + if self.table and (with_namespace or (not isinstance(self.table, str) and self.table.alias)): + namespace = (self.table.alias if not isinstance(self.table, str) else self.table) or getattr(self.table, "_table_name") return "{}.*".format(format_quotes(namespace, quote_char)) return "*" @@ -614,17 +635,17 @@ def __init__(self, *values: Any) -> None: super().__init__() self.values = [self.wrap_constant(value) for value in values] - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self for value in self.values: yield from value.nodes_() def get_sql(self, **kwargs: Any) -> str: - sql = "({})".format(",".join(term.get_sql(**kwargs) for term in self.values)) + sql = "({})".format(",".join(Term._assert_guard(term).get_sql(**kwargs) for term in self.values)) return format_alias_sql(sql, self.alias, **kwargs) @property - def is_aggregate(self) -> bool: + def is_aggregate(self) -> Optional[bool]: return resolve_is_aggregate([val.is_aggregate for val in self.values]) @builder @@ -639,13 +660,13 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T :return: A copy of the field with the tables replaced. """ - self.values = [value.replace_table(current_table, new_table) for value in self.values] + self.values = [Term._assert_guard(value).replace_table(current_table, new_table) for value in self.values] class Array(Tuple): def get_sql(self, **kwargs: Any) -> str: dialect = kwargs.get("dialect", None) - values = ",".join(term.get_sql(**kwargs) for term in self.values) + values = ",".join(Term._assert_guard(term).get_sql(**kwargs) for term in self.values) sql = "[{}]".format(values) if dialect in (Dialects.POSTGRESQL, Dialects.REDSHIFT): @@ -676,7 +697,7 @@ def __init__( self.right = right self.nested = nested - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.right.nodes_() yield from self.left.nodes_() @@ -707,7 +728,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: left=self.left.get_sql(**kwargs), comparator=self.comparator.value, right=self.right.get_sql(**kwargs), - nested_comparator=self.nested_comparator.value, + nested_comparator=self.nested_comparator.comparator.value, nested=self.nested.get_sql(**kwargs), ) @@ -737,7 +758,7 @@ def __init__(self, comparator: Comparator, left: Term, right: Term, alias: Optio self.left = left self.right = right - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.right.nodes_() yield from self.left.nodes_() @@ -789,7 +810,7 @@ def __init__(self, term: Any, container: Term, alias: Optional[str] = None) -> N self.container = container self._is_negated = False - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.term.nodes_() yield from self.container.nodes_() @@ -849,7 +870,7 @@ def __init__(self, term: Term, start: Any, end: Any, alias: Optional[str] = None self.start = start self.end = end - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.term.nodes_() yield from self.start.nodes_() @@ -901,7 +922,7 @@ def __init__(self, term: Term, value: Any, alias: Optional[str] = None) -> None: self.term = term self.value = value - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.term.nodes_() yield from self.value.nodes_() @@ -933,7 +954,7 @@ def __init__(self, term: Term, alias: Optional[str] = None) -> None: super().__init__(alias) self.term = term - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.term.nodes_() @@ -967,7 +988,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class ComplexCriterion(BasicCriterion): - def get_sql(self, subcriterion: bool = False, **kwargs: Any) -> str: + def get_sql(self, subcriterion: bool = False, **kwargs: Any) -> str: # type: ignore sql = "{left} {comparator} {right}".format( comparator=self.comparator.value, left=self.left.get_sql(subcriterion=self.needs_brackets(self.left), **kwargs), @@ -1012,7 +1033,7 @@ def __init__(self, operator: Arithmetic, left: Any, right: Any, alias: Optional[ self.left = left self.right = right - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.left.nodes_() yield from self.right.nodes_() @@ -1102,10 +1123,10 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class Case(Criterion): def __init__(self, alias: Optional[str] = None) -> None: super().__init__(alias=alias) - self._cases = [] - self._else = None + self._cases: List[typing.Tuple[Any, Any]] = [] + self._else: WrappedConstant| None = None - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self for criterion, term in self._cases: @@ -1140,13 +1161,13 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T A copy of the term with the tables replaced. """ self._cases = [ - [ + ( criterion.replace_table(current_table, new_table), term.replace_table(current_table, new_table), - ] + ) for criterion, term in self._cases ] - self._else = self._else.replace_table(current_table, new_table) if self._else else None + self._else = Term._assert_guard(self._else).replace_table(current_table, new_table) if self._else else None @builder def else_(self, term: Any) -> "Case": @@ -1161,7 +1182,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: "WHEN {when} THEN {then}".format(when=criterion.get_sql(**kwargs), then=term.get_sql(**kwargs)) for criterion, term in self._cases ) - else_ = " ELSE {}".format(self._else.get_sql(**kwargs)) if self._else else "" + else_ = " ELSE {}".format(Term._assert_guard(self._else).get_sql(**kwargs)) if self._else else "" case_sql = "CASE {cases}{else_} END".format(cases=cases, else_=else_) @@ -1176,7 +1197,7 @@ def __init__(self, term: Any, alias: Optional[str] = None) -> None: super().__init__(alias=alias) self.term = term - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.term.nodes_() @@ -1224,7 +1245,7 @@ def __init__(self, term: Any, alias: Optional[str] = None) -> None: super().__init__(alias=alias) self.term = term - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self yield from self.term.nodes_() @@ -1246,7 +1267,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "Function": raise FunctionException( "Function {name} require these arguments ({params}), ({args}) passed".format( name=self.name, - params=", ".join(str(p) for p in self.params), + params=", ".join(str(p) for p in self.params) if self.params else "", args=", ".join(str(p) for p in args), ) ) @@ -1259,7 +1280,6 @@ def _has_params(self): def _is_valid_function_call(self, *args): return len(args) == len(self.params) - class Function(Criterion): def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: super().__init__(kwargs.get("alias")) @@ -1267,7 +1287,7 @@ def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: self.args = [self.wrap_constant(param) for param in args] self.schema = kwargs.get("schema") - def nodes_(self) -> Iterator[NodeT]: + def nodes_(self) -> Iterator[Node]: yield self for arg in self.args: yield from arg.nodes_() @@ -1294,7 +1314,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T :return: A copy of the criterion with the tables replaced. """ - self.args = [param.replace_table(current_table, new_table) for param in self.args] + self.args = [Term._assert_guard(param).replace_table(current_table, new_table) for param in self.args] def get_special_params_sql(self, **kwargs: Any) -> Any: pass @@ -1309,7 +1329,7 @@ def get_function_sql(self, **kwargs: Any) -> str: return "{name}({args}{special})".format( name=self.name, args=",".join( - p.get_sql(with_alias=False, subquery=True, **kwargs) + Term._assert_guard(p).get_sql(with_alias=False, subquery=True, **kwargs) if hasattr(p, "get_sql") else self.get_arg_sql(p, **kwargs) for p in self.args @@ -1352,9 +1372,10 @@ def filter(self, *filters: Any): self._include_filter = True self._filters += filters - def get_filter_sql(self, **kwargs: Any) -> str: + def get_filter_sql(self, **kwargs: Any) -> Optional[str]: if self._include_filter: return "WHERE {criterions}".format(criterions=Criterion.all(self._filters).get_sql(**kwargs)) + return None def get_function_sql(self, **kwargs: Any): sql = super(AggregateFunction, self).get_function_sql(**kwargs) @@ -1373,8 +1394,8 @@ class AnalyticFunction(AggregateFunction): def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: super().__init__(name, *args, **kwargs) self._filters = [] - self._partition = [] - self._orderbys = [] + self._partition: List[Any] = [] + self._orderbys: List[Any] = [] self._include_filter = False self._include_over = False @@ -1428,24 +1449,26 @@ def get_function_sql(self, **kwargs: Any) -> str: EdgeT = TypeVar("EdgeT", bound="WindowFrameAnalyticFunction.Edge") +AnyEdge = Union[str, "WindowFrameAnalyticFunction.Edge"] class WindowFrameAnalyticFunction(AnalyticFunction): class Edge: + modifier: ClassVar[Optional[str]] = None def __init__(self, value: Optional[Union[str, int]] = None) -> None: self.value = value def __str__(self) -> str: return "{value} {modifier}".format( value=self.value or "UNBOUNDED", - modifier=self.modifier, + modifier=self.modifier or "", ) def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: super().__init__(name, *args, **kwargs) - self.frame = None - self.bound = None + self.frame: Optional[str] = None + self.bound: Optional[Union[typing.Tuple[AnyEdge, AnyEdge], AnyEdge]] = None - def _set_frame_and_bounds(self, frame: str, bound: str, and_bound: Optional[EdgeT]) -> None: + def _set_frame_and_bounds(self, frame: str, bound: AnyEdge, and_bound: Optional[AnyEdge]) -> None: if self.frame or self.bound: raise AttributeError() @@ -1453,11 +1476,11 @@ def _set_frame_and_bounds(self, frame: str, bound: str, and_bound: Optional[Edge self.bound = (bound, and_bound) if and_bound else bound @builder - def rows(self, bound: Union[str, EdgeT], and_bound: Optional[EdgeT] = None): + def rows(self, bound: AnyEdge, and_bound: Optional[AnyEdge] = None): self._set_frame_and_bounds("ROWS", bound, and_bound) @builder - def range(self, bound: Union[str, EdgeT], and_bound: Optional[EdgeT] = None): + def range(self, bound: AnyEdge, and_bound: Optional[AnyEdge] = None): self._set_frame_and_bounds("RANGE", bound, and_bound) def get_frame_sql(self) -> str: @@ -1497,7 +1520,7 @@ def get_special_params_sql(self, **kwargs: Any) -> Optional[str]: return None -class Interval(Node): +class Interval(Term): templates = { # PostgreSQL, Redshift and Vertica require quotes around the expr and unit e.g. INTERVAL '1 week' Dialects.POSTGRESQL: "INTERVAL '{expr} {unit}'", @@ -1558,6 +1581,7 @@ def __str__(self) -> str: def get_sql(self, **kwargs: Any) -> str: dialect = self.dialect or kwargs.get("dialect") + unit: Optional[str] if self.largest == "MICROSECOND": expr = getattr(self, "microseconds") unit = "MICROSECOND" @@ -1598,7 +1622,7 @@ def get_sql(self, **kwargs: Any) -> str: if unit is None: unit = "DAY" - return self.templates.get(dialect, "INTERVAL '{expr} {unit}'").format(expr=expr, unit=unit) + return self.templates.get(dialect, "INTERVAL '{expr} {unit}'").format(expr=expr, unit=unit) # type: ignore class Pow(Function): From 0e76c2d193784a9d41107533c49627d96e569fec Mon Sep 17 00:00:00 2001 From: thisLight Date: Sat, 6 Aug 2022 23:08:58 +0800 Subject: [PATCH 06/15] validate: remove Optional from exc --- pypika/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypika/utils.py b/pypika/utils.py index 7f661b0d..58bcc166 100644 --- a/pypika/utils.py +++ b/pypika/utils.py @@ -136,7 +136,7 @@ def format_alias_sql( ) -def validate(*args: Any, exc: Optional[Exception] = None, type: Optional[Type] = None) -> None: +def validate(*args: Any, exc: Exception, type: Optional[Type] = None) -> None: if type is not None: for arg in args: if not isinstance(arg, type): From 47591f7d0f6a8d671a31f208d1c88f5ee6439912 Mon Sep 17 00:00:00 2001 From: thisLight Date: Mon, 19 Sep 2022 09:03:30 +0800 Subject: [PATCH 07/15] fix typo --- pypika/clickhouse/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypika/clickhouse/array.py b/pypika/clickhouse/array.py index 9e9a9147..40de67c5 100644 --- a/pypika/clickhouse/array.py +++ b/pypika/clickhouse/array.py @@ -42,7 +42,7 @@ def __init__( self._right_array = right_array self.alias = alias self.schema = schema - self.args = () + self.args = tuple() self.name = "hasAny" def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, dialect=None, **kwargs): From 707184f937485f9af34e28339aa6a25aea116851 Mon Sep 17 00:00:00 2001 From: thisLight Date: Mon, 19 Sep 2022 09:12:29 +0800 Subject: [PATCH 08/15] dialects: fix type comments --- pypika/dialects.py | 67 ++++++++++++++++++++---------------- pypika/terms.py | 2 +- pypika/tests/test_selects.py | 2 +- 3 files changed, 40 insertions(+), 31 deletions(-) diff --git a/pypika/dialects.py b/pypika/dialects.py index 7642ab8d..8ede83cd 100644 --- a/pypika/dialects.py +++ b/pypika/dialects.py @@ -1,6 +1,6 @@ import itertools from copy import copy -from typing import Any, Optional, Union, Tuple as TypedTuple +from typing import Any, List, Optional, Set, Union, Tuple as TypedTuple, cast from pypika.enums import Dialects from pypika.queries import ( @@ -88,23 +88,23 @@ class MySQLQueryBuilder(QueryBuilder): def __init__(self, **kwargs: Any) -> None: super().__init__(dialect=Dialects.MYSQL, wrap_set_operation_queries=False, **kwargs) - self._duplicate_updates = [] + self._duplicate_updates: List[TypedTuple[Field, ValueWrapper]] = [] self._ignore_duplicates = False - self._modifiers = [] + self._modifiers: List[str] = [] self._for_update_nowait = False self._for_update_skip_locked = False - self._for_update_of = set() + self._for_update_of: Set[str] = set() def __copy__(self) -> "MySQLQueryBuilder": - newone = super().__copy__() + newone = cast(MySQLQueryBuilder, super().__copy__()) newone._duplicate_updates = copy(self._duplicate_updates) newone._ignore_duplicates = copy(self._ignore_duplicates) return newone @builder def for_update( - self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = () + self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = tuple() ): self._for_update = True self._for_update_skip_locked = skip_locked @@ -126,7 +126,7 @@ def on_duplicate_key_ignore(self): self._ignore_duplicates = True - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, **kwargs: Any) -> str: # type: ignore self._set_kwargs_defaults(kwargs) querystring = super(MySQLQueryBuilder, self).get_sql(**kwargs) if querystring: @@ -187,8 +187,8 @@ class MySQLLoadQueryBuilder: QUERY_CLS = MySQLQuery def __init__(self) -> None: - self._load_file = None - self._into_table = None + self._load_file: Optional[str] = None + self._into_table: Optional[Table] = None @builder def load(self, fp: str): @@ -211,6 +211,7 @@ def _load_file_sql(self, **kwargs: Any) -> str: return "LOAD DATA LOCAL INFILE '{}'".format(self._load_file) def _into_table_sql(self, **kwargs: Any) -> str: + assert self._into_table is not None return " INTO TABLE `{}`".format(self._into_table.get_sql(**kwargs)) def _options_sql(self, **kwargs: Any) -> str: @@ -251,7 +252,7 @@ class VerticaQueryBuilder(QueryBuilder): def __init__(self, **kwargs: Any) -> None: super().__init__(dialect=Dialects.VERTICA, **kwargs) - self._hint = None + self._hint: Optional[str] = None @builder def hint(self, label: str): @@ -289,6 +290,7 @@ def preserve_rows(self): self._preserve_rows = True def _create_table_sql(self, **kwargs: Any) -> str: + assert self._create_table is not None return "CREATE {local}{temporary}TABLE {table}".format( local="LOCAL " if self._local else "", temporary="TEMPORARY " if self._temporary else "", @@ -301,6 +303,7 @@ def _table_options_sql(self, **kwargs) -> str: return table_options def _as_select_sql(self, **kwargs: Any) -> str: + assert self._as_select is not None return "{preserve_rows} AS ({query})".format( preserve_rows=self._preserve_rows_sql(), query=self._as_select.get_sql(**kwargs), @@ -314,8 +317,8 @@ class VerticaCopyQueryBuilder: QUERY_CLS = VerticaQuery def __init__(self) -> None: - self._copy_table = None - self._from_file = None + self._copy_table: Optional[Table] = None + self._from_file: Optional[str] = None @builder def from_file(self, fp: str): @@ -335,6 +338,7 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str: return querystring def _copy_table_sql(self, **kwargs: Any) -> str: + assert self._copy_table return 'COPY "{}"'.format(self._copy_table.get_sql(**kwargs)) def _from_file_sql(self, **kwargs: Any) -> str: @@ -387,24 +391,24 @@ class PostgreSQLQueryBuilder(QueryBuilder): def __init__(self, **kwargs: Any) -> None: super().__init__(dialect=Dialects.POSTGRESQL, **kwargs) - self._returns = [] + self._returns: List[Term] = [] self._return_star = False self._on_conflict = False - self._on_conflict_fields = [] + self._on_conflict_fields: List[Term] = [] self._on_conflict_do_nothing = False - self._on_conflict_do_updates = [] - self._on_conflict_wheres = None - self._on_conflict_do_update_wheres = None + self._on_conflict_do_updates: List[TypedTuple[Field, Optional[ValueWrapper]]] = [] + self._on_conflict_wheres: Optional[Criterion] = None + self._on_conflict_do_update_wheres: Optional[Criterion] = None - self._distinct_on = [] + self._distinct_on: List[Term] = [] self._for_update_nowait = False self._for_update_skip_locked = False - self._for_update_of = set() + self._for_update_of: Set[str] = set() def __copy__(self) -> "PostgreSQLQueryBuilder": - newone = super().__copy__() + newone = cast(PostgreSQLQueryBuilder, super().__copy__()) newone._returns = copy(self._returns) newone._on_conflict_do_updates = copy(self._on_conflict_do_updates) return newone @@ -419,7 +423,7 @@ def distinct_on(self, *fields: Union[str, Term]): @builder def for_update( - self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = () + self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = tuple() ): self._for_update = True self._for_update_skip_locked = skip_locked @@ -454,6 +458,7 @@ def do_update( if isinstance(update_field, str): field = self._conflict_field_str(update_field) + assert field is not None elif isinstance(update_field, Field): field = update_field else: @@ -502,6 +507,7 @@ def _distinct_sql(self, **kwargs: Any) -> str: def _conflict_field_str(self, term: str) -> Optional[Field]: if self._insert_table: return Field(term, table=self._insert_table) + return None def _on_conflict_sql(self, **kwargs: Any) -> str: if not self._on_conflict_do_nothing and len(self._on_conflict_do_updates) == 0: @@ -578,7 +584,9 @@ def returning(self, *terms: Any): raise QueryException("Aggregate functions are not allowed in returning") self._return_other(term) else: - self._return_other(self.wrap_constant(term, self._wrapper_cls)) + constant = self.wrap_constant(term, self._wrapper_cls) + assert isinstance(constant, Term) + self._return_other(constant) def _validate_returning_term(self, term: Term) -> None: for field in term.fields_(): @@ -596,7 +604,7 @@ def _set_returns_for_star(self) -> None: self._returns = [returning for returning in self._returns if not hasattr(returning, "table")] self._return_star = True - def _return_field(self, term: Union[str, Field]) -> None: + def _return_field(self, term: Field) -> None: if self._return_star: # Do not add select terms after a star is selected return @@ -615,11 +623,11 @@ def _return_field_str(self, term: Union[str, Field]) -> None: return if self._insert_table: - self._return_field(Field(term, table=self._insert_table)) + self._return_field(Field(term, table=self._insert_table) if isinstance(term, str) else term) elif self._update_table: - self._return_field(Field(term, table=self._update_table)) + self._return_field(Field(term, table=self._update_table) if isinstance(term, str) else term) elif self._delete_from: - self._return_field(Field(term, table=self._from[0])) + self._return_field(Field(term, table=self._from[0]) if isinstance(term, str) else term) else: raise QueryException("Returning can't be used in this query") @@ -692,8 +700,8 @@ def top(self, value: Union[str, int], percent: bool = False, with_ties: bool = F if percent and not (0 <= int(value) <= 100): raise QueryException("TOP value must be between 0 and 100 when `percent`" " is specified") - self._top_percent: bool = percent - self._top_with_ties: bool = with_ties + self._top_percent = percent + self._top_with_ties = with_ties @builder def fetch_next(self, limit: int): @@ -754,7 +762,7 @@ def _builder(cls, **kwargs: Any) -> "ClickHouseQueryBuilder": ) @classmethod - def drop_database(self, database: Union[Database, str]) -> "ClickHouseDropQueryBuilder": + def drop_database(cls, database: Union[Database, str]) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_database(database) @classmethod @@ -786,6 +794,7 @@ def _delete_sql(**kwargs: Any) -> str: return 'ALTER TABLE' def _update_sql(self, **kwargs: Any) -> str: + assert self._update_table return "ALTER TABLE {table}".format(table=self._update_table.get_sql(**kwargs)) def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: diff --git a/pypika/terms.py b/pypika/terms.py index f861f7e6..992b89b2 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -1284,7 +1284,7 @@ class Function(Criterion): def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: super().__init__(kwargs.get("alias")) self.name = name - self.args = [self.wrap_constant(param) for param in args] + self.args: Iterable[WrappedConstant] = [self.wrap_constant(param) for param in args] self.schema = kwargs.get("schema") def nodes_(self) -> Iterator[Node]: diff --git a/pypika/tests/test_selects.py b/pypika/tests/test_selects.py index 1ce04937..26078b4a 100644 --- a/pypika/tests/test_selects.py +++ b/pypika/tests/test_selects.py @@ -346,7 +346,7 @@ class MyEnum(Enum): INT = 0 BOOL = True DATE = date(2020, 2, 2) - NONE = None + NONE: None = None class WhereTests(unittest.TestCase): From f624ab78bca74c32eec6b2e359694f860712dabd Mon Sep 17 00:00:00 2001 From: thisLight Date: Mon, 19 Sep 2022 09:15:56 +0800 Subject: [PATCH 09/15] test_internals: fix a test - fix criterion_replace_table_with_value tests Field == Table, which is unintended. --- pypika/tests/test_internals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypika/tests/test_internals.py b/pypika/tests/test_internals.py index 921e381b..fdf217cc 100644 --- a/pypika/tests/test_internals.py +++ b/pypika/tests/test_internals.py @@ -13,7 +13,7 @@ def test__criterion_replace_table_with_value(self): table = Table("a") c = (Field("foo") > 1).replace_table(None, table) - self.assertEqual(c.left, table) + self.assertEqual(c.left.tables_, {table}) self.assertEqual(c.tables_, {table}) def test__case_tables(self): From 3c4f936c3012057a29b07a0bf88afe4f065c09ce Mon Sep 17 00:00:00 2001 From: thisLight Date: Mon, 19 Sep 2022 09:28:52 +0800 Subject: [PATCH 10/15] BasicCriterion.replace_table: add return type --- pypika/terms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypika/terms.py b/pypika/terms.py index 992b89b2..fb9fe2ca 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -768,7 +768,7 @@ def is_aggregate(self) -> Optional[bool]: return resolve_is_aggregate([term.is_aggregate for term in [self.left, self.right]]) @builder - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]): + def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> None: """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. From 414b729d5fc376b0fe090e0714156b4ba6345b4e Mon Sep 17 00:00:00 2001 From: thisLight Date: Tue, 20 Sep 2022 21:28:26 +0800 Subject: [PATCH 11/15] Fix errors in queries.py - fix mypy errors in queries.py - add Protocol class SQLPart --- pypika/queries.py | 396 +++++++++++++++++++++++++++------------------- pypika/utils.py | 9 +- 2 files changed, 245 insertions(+), 160 deletions(-) diff --git a/pypika/queries.py b/pypika/queries.py index b076e6b1..d59a72cd 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -1,8 +1,10 @@ from copy import copy from functools import reduce -from typing import Any, List, Optional, Sequence, Tuple as TypedTuple, Type, Union, Set +from itertools import chain +import operator +from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple as TypedTuple, Type, Union, Set, cast, TypeVar -from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation +from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation, Order from pypika.terms import ( ArithmeticExpression, Criterion, @@ -18,6 +20,7 @@ ValueWrapper, Criterion, PeriodCriterion, + WrappedConstant, ) from pypika.utils import ( JoinException, @@ -28,14 +31,18 @@ format_alias_sql, format_quotes, ignore_copy, + SQLPart, ) __author__ = "Timothy Heys" __email__ = "theys@kayak.com" +_T = TypeVar("_T") + + class Selectable(Node): - def __init__(self, alias: str) -> None: + def __init__(self, alias: Optional[str]) -> None: self.alias = alias @builder @@ -58,10 +65,12 @@ def __getitem__(self, name: str) -> Field: return self.field(name) def get_table_name(self) -> str: + if not self.alias: + raise TypeError("expect str, got None") return self.alias -class AliasedQuery(Selectable): +class AliasedQuery(Selectable, SQLPart): def __init__(self, name: str, query: Optional[Selectable] = None) -> None: super().__init__(alias=name) self.name = name @@ -79,7 +88,7 @@ def __hash__(self) -> int: return hash(str(self.name)) -class Schema: +class Schema(SQLPart): def __init__(self, name: str, parent: Optional["Schema"] = None) -> None: self._name = name self._parent = parent @@ -115,7 +124,7 @@ def __getattr__(self, item: str) -> Schema: class Table(Selectable): @staticmethod - def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Union[str, list, tuple, Schema, None]: + def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Optional[Schema]: # This is a bit complicated in order to support backwards compatibility. It should probably be cleaned up for # the next major release. Schema is accepted as a string, list/tuple, Schema instance, or None if isinstance(schema, Schema): @@ -137,8 +146,8 @@ def __init__( self._table_name = name self._schema = self._init_schema(schema) self._query_cls = query_cls or Query - self._for = None - self._for_portion = None + self._for: Optional[Criterion] = None + self._for_portion: Optional[PeriodCriterion] = None if not issubclass(self._query_cls, Query): raise TypeError("Expected 'query_cls' to be subclass of Query") @@ -250,13 +259,16 @@ def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List """ tables = [] for name in names: - if isinstance(name, tuple) and len(name) == 2: - t = Table( - name=name[0], - alias=name[1], - schema=kwargs.get("schema"), - query_cls=kwargs.get("query_cls"), - ) + if isinstance(name, tuple): + if len(name) == 2: + t = Table( + name=name[0], + alias=name[1], + schema=kwargs.get("schema"), + query_cls=kwargs.get("query_cls"), + ) + else: + raise TypeError("expect tuple[str, str] or str, got a tuple with {} element(s)".format(len(name))) else: t = Table( name=name, @@ -267,7 +279,7 @@ def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List return tables -class Column: +class Column(SQLPart): """Represents a column.""" def __init__( @@ -313,8 +325,11 @@ def make_columns(*names: Union[TypedTuple[str, str], str]) -> List[Column]: """ columns = [] for name in names: - if isinstance(name, tuple) and len(name) == 2: - column = Column(column_name=name[0], column_type=name[1]) + if isinstance(name, tuple): + if len(name) == 2: + column = Column(column_name=name[0], column_type=name[1]) + else: + raise TypeError("expect tuple[str, str] or str, got a tuple with {} element(s)".format(len(name))) else: column = Column(column_name=name) columns.append(column) @@ -322,7 +337,7 @@ def make_columns(*names: Union[TypedTuple[str, str], str]) -> List[Column]: return columns -class PeriodFor: +class PeriodFor(SQLPart): def __init__(self, name: str, start_column: Union[str, Column], end_column: Union[str, Column]) -> None: self.name = name self.start_column = start_column if isinstance(start_column, Column) else Column(start_column) @@ -385,7 +400,7 @@ def create_table(cls, table: Union[str, Table]) -> "CreateQueryBuilder": return CreateQueryBuilder().create_table(table) @classmethod - def drop_database(cls, database: Union[Database, Table]) -> "DropQueryBuilder": + def drop_database(cls, database: Union[Database, str]) -> "DropQueryBuilder": """ Query builder entry point. Initializes query building and sets the table name to be dropped. When using this function, the query becomes a DROP statement. @@ -514,7 +529,7 @@ def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List return make_tables(*names, **kwargs) -class _SetOperation(Selectable, Term): +class _SetOperation(Selectable, Term, SQLPart): """ A Query class wrapper for a all set operations, Union DISTINCT or ALL, Intersect, Except or Minus @@ -533,24 +548,29 @@ def __init__( ): super().__init__(alias) self.base_query = base_query - self._set_operation = [(set_operation, set_operation_query)] - self._orderbys = [] + self._set_operation: List[TypedTuple[SetOperation, Union[QueryBuilder, Selectable]]] = [(set_operation, set_operation_query)] + self._orderbys: List[TypedTuple[Union[Field, WrappedConstant, None], Optional[Order]]] = [] - self._limit = None - self._offset = None + self._limit: Optional[int] = None + self._offset: Optional[int] = None self._wrapper_cls = wrapper_cls @builder - def orderby(self, *fields: Field, **kwargs: Any): - for field in fields: - field = ( - Field(field, table=self.base_query._from[0]) - if isinstance(field, str) - else self.base_query.wrap_constant(field) - ) - - self._orderbys.append((field, kwargs.get("order"))) + def orderby(self, *fields: Union[Field, str], order: Optional[Order] = None): + field: Union[None, Field, WrappedConstant] + if fields: + field_val = fields[-1] + if isinstance(field_val, str): + table = self.base_query._from[0] + if not isinstance(table, Table): + raise TypeError("expect the first \"from\" selectable is table, got {}".format(type(table).__name__)) + field = Field(field_val, table=table) + else: + field = self.base_query.wrap_constant(field_val) + else: + field = None + self._orderbys.append((field, order)) @builder def limit(self, limit: int): @@ -580,13 +600,13 @@ def except_of(self, other: Selectable): def minus(self, other: Selectable): self._set_operation.append((SetOperation.minus, other)) - def __add__(self, other: Selectable) -> "_SetOperation": + def __add__(self, other: Selectable) -> "_SetOperation": # type: ignore return self.union(other) - def __mul__(self, other: Selectable) -> "_SetOperation": + def __mul__(self, other: Selectable) -> "_SetOperation": # type: ignore return self.union_all(other) - def __sub__(self, other: "QueryBuilder") -> "_SetOperation": + def __sub__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore return self.minus(other) def __str__(self) -> str: @@ -647,12 +667,12 @@ def _orderby_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: the alias, otherwise the field will be rendered as SQL. """ clauses = [] - selected_aliases = {s.alias for s in self.base_query._selects} + selected_aliases = {s.alias for s in self.base_query._selects if isinstance(s, Term)} for field, directionality in self._orderbys: term = ( - format_quotes(field.alias, quote_char) - if field.alias and field.alias in selected_aliases - else field.get_sql(quote_char=quote_char, **kwargs) + format_quotes(field.alias, quote_char) # type: ignore + if field.alias and (field.alias in selected_aliases) # type: ignore + else field.get_sql(quote_char=quote_char, **kwargs) # type: ignore ) clauses.append( @@ -668,16 +688,16 @@ def _limit_sql(self) -> str: return " LIMIT {limit}".format(limit=self._limit) -class QueryBuilder(Selectable, Term): +class QueryBuilder(Selectable, Term, SQLPart): """ Query Builder is the main class in pypika which stores the state of a query and offers functions which allow the state to be branched immutably. """ - QUOTE_CHAR = '"' - SECONDARY_QUOTE_CHAR = "'" - ALIAS_QUOTE_CHAR = None - QUERY_ALIAS_QUOTE_CHAR = None + QUOTE_CHAR: Optional[str] = '"' + SECONDARY_QUOTE_CHAR: Optional[str] = "'" + ALIAS_QUOTE_CHAR: Optional[str] = None + QUERY_ALIAS_QUOTE_CHAR: Optional[str] = None QUERY_CLS = Query def __init__( @@ -690,40 +710,40 @@ def __init__( ): super().__init__(None) - self._from = [] - self._insert_table = None - self._update_table = None + self._from: List[Union[Selectable, QueryBuilder, None]] = [] + self._insert_table: Optional[Table] = None + self._update_table: Optional[Table] = None self._delete_from = False self._replace = False - self._with = [] - self._selects = [] - self._force_indexes = [] - self._use_indexes = [] - self._columns = [] - self._values = [] + self._with: List[AliasedQuery] = [] + self._selects: List[Term] = [] + self._force_indexes: List[Index] = [] + self._use_indexes: List[Index] = [] + self._columns: List[Term] = [] + self._values: List[Sequence[Union[Term, WrappedConstant]]] = [] self._distinct = False self._ignore = False self._for_update = False - self._wheres = None - self._prewheres = None - self._groupbys = [] + self._wheres: Optional[Union[Term, Criterion]] = None + self._prewheres: Optional[Criterion] = None + self._groupbys: List[Union[Term, WrappedConstant]] = [] self._with_totals = False - self._havings = None - self._orderbys = [] - self._joins = [] - self._unions = [] - self._using = [] + self._havings: Optional[Union[Term, Criterion]] = None + self._orderbys: List[TypedTuple[Union[Field, WrappedConstant], Optional[Order]]] = [] + self._joins: List[Join] = [] + self._unions: List[None] = [] + self._using: List[Union[Selectable, str]] = [] - self._limit = None - self._offset = None + self._limit: Optional[int] = None + self._offset: Optional[int] = None - self._updates = [] + self._updates: List[TypedTuple[Field, ValueWrapper]] = [] self._select_star = False - self._select_star_tables = set() + self._select_star_tables: Set[Optional[Union[str, Selectable]]] = set() self._mysql_rollup = False self._select_into = False @@ -757,7 +777,7 @@ def __copy__(self) -> "QueryBuilder": return newone @builder - def from_(self, selectable: Union[Selectable, Query, str]): + def from_(self, selectable: Union[Selectable, "QueryBuilder", str]): """ Adds a table to the query. This function can only be called once and will raise an AttributeError if called a second time. @@ -801,18 +821,18 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl self._update_table = new_table if self._update_table == current_table else self._update_table self._with = [alias_query.replace_table(current_table, new_table) for alias_query in self._with] - self._selects = [select.replace_table(current_table, new_table) for select in self._selects] + self._selects = [select.replace_table(current_table, new_table) if isinstance(select, Term) else select for select in self._selects] self._columns = [column.replace_table(current_table, new_table) for column in self._columns] self._values = [ - [value.replace_table(current_table, new_table) for value in value_list] for value_list in self._values + [(value.replace_table(current_table, new_table) if isinstance(value, Term) else value) for value in value_list] for value_list in self._values ] self._wheres = self._wheres.replace_table(current_table, new_table) if self._wheres else None self._prewheres = self._prewheres.replace_table(current_table, new_table) if self._prewheres else None - self._groupbys = [groupby.replace_table(current_table, new_table) for groupby in self._groupbys] + self._groupbys = [groupby.replace_table(current_table, new_table) if isinstance(groupby, Term) else groupby for groupby in self._groupbys] self._havings = self._havings.replace_table(current_table, new_table) if self._havings else None self._orderbys = [ - (orderby[0].replace_table(current_table, new_table), orderby[1]) for orderby in self._orderbys + (orderby[0].replace_table(current_table, new_table), orderby[1]) if isinstance(orderby[0], Term) else orderby for orderby in self._orderbys ] self._joins = [join.replace_table(current_table, new_table) for join in self._joins] @@ -845,7 +865,8 @@ def select(self, *terms: Any): elif isinstance(term, (Function, ArithmeticExpression)): self._select_other(term) else: - self._select_other(self.wrap_constant(term, wrapper_cls=self._wrapper_cls)) + value = self.wrap_constant(term, wrapper_cls=self._wrapper_cls) + self._select_other(Term._assert_guard(value)) @builder def delete(self): @@ -862,14 +883,18 @@ def update(self, table: Union[str, Table]): self._update_table = table if isinstance(table, Table) else Table(table) @builder - def columns(self, *terms: Any): + def columns(self, *terms: Union[str, Field, List[Union[str, Field]], TypedTuple[Union[str, Field], ...]]) -> None: if self._insert_table is None: raise AttributeError("'Query' object has no attribute '%s'" % "insert") + columns: Iterable[Union[str, Field]] if terms and isinstance(terms[0], (list, tuple)): - terms = terms[0] + columns = terms[0] # FIXME: should not sliently ignore rest arguments + # Alternative solution: fix the type comment to tell use here only accepts one sequence. + else: + columns = cast(TypedTuple[Union[str, Field]], terms) - for term in terms: + for term in columns: if isinstance(term, str): term = Field(term, table=self._insert_table) self._columns.append(term) @@ -931,7 +956,7 @@ def where(self, criterion: Union[Term, EmptyCriterion]): self._foreign_table = True if self._wheres: - self._wheres &= criterion + self._wheres &= criterion # type: ignore else: self._wheres = criterion @@ -941,19 +966,25 @@ def having(self, criterion: Union[Term, EmptyCriterion]): return if self._havings: - self._havings &= criterion + self._havings &= criterion # type: ignore else: self._havings = criterion @builder def groupby(self, *terms: Union[str, int, Term]): + table = self._from[0] + if not isinstance(table, Selectable): + raise TypeError("expect table is a Selectable, got {}".format(type(table).__name__)) for term in terms: + new_term: Union[WrappedConstant, Field] if isinstance(term, str): - term = Field(term, table=self._from[0]) + new_term = Field(term, table=table) elif isinstance(term, int): - term = Field(str(term), table=self._from[0]).wrap_constant(term) + new_term = Field(str(term), table=table).wrap_constant(term) + else: + new_term = term - self._groupbys.append(term) + self._groupbys.append(new_term) @builder def with_totals(self): @@ -966,40 +997,43 @@ def rollup(self, *terms: Union[list, tuple, set, Term], **kwargs: Any): if self._mysql_rollup: raise AttributeError("'Query' object has no attribute '%s'" % "rollup") - terms = [Tuple(*term) if isinstance(term, (list, tuple, set)) else term for term in terms] + wrapped_terms = [Tuple(*term) if isinstance(term, (list, tuple, set)) else term for term in terms] if for_mysql: # MySQL rolls up all of the dimensions always - if not terms and not self._groupbys: + if not wrapped_terms and not self._groupbys: raise RollupException( "At least one group is required. Call Query.groupby(term) or pass" "as parameter to rollup." ) self._mysql_rollup = True - self._groupbys += terms + self._groupbys += wrapped_terms elif 0 < len(self._groupbys) and isinstance(self._groupbys[-1], Rollup): # If a rollup was added last, then append the new terms to the previous rollup - self._groupbys[-1].args += terms + self._groupbys[-1].args += wrapped_terms else: - self._groupbys.append(Rollup(*terms)) + self._groupbys.append(Rollup(*wrapped_terms)) @builder - def orderby(self, *fields: Any, **kwargs: Any): + def orderby(self, *fields: Union[str, Field], order: Optional[Order] = None): + table = self._from[0] + if not isinstance(table, Selectable): + raise TypeError("expect table is a Selectable, got {}".format(type(table).__name__)) for field in fields: - field = Field(field, table=self._from[0]) if isinstance(field, str) else self.wrap_constant(field) + target_field = Field(field, table=table) if isinstance(field, str) else self.wrap_constant(field) - self._orderbys.append((field, kwargs.get("order"))) + self._orderbys.append((target_field, order)) @builder def join( - self, item: Union[Table, "QueryBuilder", AliasedQuery, Selectable], how: JoinType = JoinType.inner + self, item: Union[Table, "QueryBuilder", AliasedQuery, _SetOperation], how: JoinType = JoinType.inner ) -> "Joiner": if isinstance(item, Table): return Joiner(self, item, how, type_label="table") - elif isinstance(item, QueryBuilder): + elif isinstance(item, (QueryBuilder, _SetOperation)): if item.alias is None: self._tag_subquery(item) return Joiner(self, item, how, type_label="subquery") @@ -1007,9 +1041,6 @@ def join( elif isinstance(item, AliasedQuery): return Joiner(self, item, how, type_label="table") - elif isinstance(item, Selectable): - return Joiner(self, item, how, type_label="subquery") - raise ValueError("Cannot join on type '%s'" % type(item)) def inner_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": @@ -1072,13 +1103,13 @@ def set(self, field: Union[Field, str], value: Any): field = Field(field) if not isinstance(field, Field) else field self._updates.append((field, self._wrapper_cls(value))) - def __add__(self, other: "QueryBuilder") -> _SetOperation: + def __add__(self, other: "QueryBuilder") -> _SetOperation: # type: ignore return self.union(other) - def __mul__(self, other: "QueryBuilder") -> _SetOperation: + def __mul__(self, other: "QueryBuilder") -> _SetOperation: # type: ignore return self.union_all(other) - def __sub__(self, other: "QueryBuilder") -> _SetOperation: + def __sub__(self, other: "QueryBuilder") -> _SetOperation: # type: ignore return self.minus(other) @builder @@ -1086,7 +1117,7 @@ def slice(self, slice: slice): self._offset = slice.start self._limit = slice.stop - def __getitem__(self, item: Any) -> Union["QueryBuilder", Field]: + def __getitem__(self, item: Any) -> Union["QueryBuilder", Field]: # type: ignore if not isinstance(item, slice): return super().__getitem__(item) return self.slice(item) @@ -1103,8 +1134,10 @@ def _select_field_str(self, term: str) -> None: self._select_star = True self._selects = [Star()] return - - self._select_field(Field(term, table=self._from[0])) + table = self._from[0] + if not isinstance(table, Selectable): + raise TypeError("expect table is a Selectable, got {}".format(type(table).__name__)) + self._select_field(Field(term, table=table)) def _select_field(self, term: Field) -> None: if self._select_star: @@ -1117,25 +1150,43 @@ def _select_field(self, term: Field) -> None: if isinstance(term, Star): self._selects = [ - select for select in self._selects if not hasattr(select, "table") or term.table != select.table + select for select in self._selects if (not hasattr(select, "table")) or (isinstance(select, Field) and term.table != select.table) ] self._select_star_tables.add(term.table) self._selects.append(term) - def _select_other(self, function: Function) -> None: + def _select_other(self, function: Term) -> None: self._selects.append(function) - def fields_(self) -> List[Field]: + def fields_(self) -> Set[Field]: # Don't return anything here. Subqueries have their own fields. - return [] + return set() def do_join(self, join: "Join") -> None: - base_tables = self._from + [self._update_table] + self._with + def _assert_not_none(v): + if v is not None: + return v + else: + raise TypeError("expect Selectable, got None") + base_tables = tuple( + map(_assert_not_none, chain( + self._from, + (self._update_table, ) if self._update_table else tuple(), + self._with + )) + ) join.validate(base_tables, self._joins) - table_in_query = any(isinstance(clause, Table) and join.item in base_tables for clause in base_tables) - if isinstance(join.item, Table) and join.item.alias is None and table_in_query: + table_in_query = reduce( + operator.add, + ( + clause._table_name == join.item._table_name + for clause in base_tables if isinstance(clause, Table) + ), + 0 + ) + if isinstance(join.item, Table) and (join.item.alias is None) and (table_in_query > 0): # On the odd chance that we join the same table as the FROM table and don't set an alias # FIXME only works once join.item.alias = join.item._table_name + "2" @@ -1166,7 +1217,7 @@ def _validate_table(self, term: Term) -> bool: return False return True - def _tag_subquery(self, subquery: "QueryBuilder") -> None: + def _tag_subquery(self, subquery: Union["QueryBuilder", _SetOperation]) -> None: subquery.alias = "sq%d" % self._subquery_count self._subquery_count += 1 @@ -1182,10 +1233,10 @@ def _apply_terms(self, *terms: Any) -> None: return if not isinstance(terms[0], (list, tuple, set)): - terms = [terms] + terms = (terms, ) for values in terms: - self._values.append([value if isinstance(value, Term) else self.wrap_constant(value) for value in values]) + self._values.append([(value if isinstance(value, Term) else self.wrap_constant(value)) for value in values]) def __str__(self) -> str: return self.get_sql(dialect=self.dialect) @@ -1193,7 +1244,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.__str__() - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: Any) -> bool: # type: ignore if not isinstance(other, QueryBuilder): return False @@ -1202,7 +1253,7 @@ def __eq__(self, other: Any) -> bool: return True - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: Any) -> bool: # type: ignore return not self.__eq__(other) def __hash__(self) -> int: @@ -1384,12 +1435,14 @@ def _select_sql(self, **kwargs: Any) -> str: ) def _insert_sql(self, **kwargs: Any) -> str: + assert self._insert_table is not None return "INSERT {ignore}INTO {table}".format( table=self._insert_table.get_sql(**kwargs), ignore="IGNORE " if self._ignore else "", ) def _replace_sql(self, **kwargs: Any) -> str: + assert self._insert_table is not None return "REPLACE INTO {table}".format( table=self._insert_table.get_sql(**kwargs), ) @@ -1399,6 +1452,7 @@ def _delete_sql(**kwargs: Any) -> str: return "DELETE" def _update_sql(self, **kwargs: Any) -> str: + assert self._update_table is not None return "UPDATE {table}".format(table=self._update_table.get_sql(**kwargs)) def _columns_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: @@ -1410,27 +1464,35 @@ def _columns_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: return " ({columns})".format( columns=",".join(term.get_sql(with_namespace=False, **kwargs) for term in self._columns) ) + + @classmethod + def _assert_type_fn(cls, klass: Type[_T]) -> Callable[[Any], _T]: + def _assert_type(val: Any): + assert isinstance(val, klass) + return val + return _assert_type def _values_sql(self, **kwargs: Any) -> str: return " VALUES ({values})".format( values="),(".join( - ",".join(term.get_sql(with_alias=True, subquery=True, **kwargs) for term in row) for row in self._values + ",".join(term.get_sql(with_alias=True, subquery=True, **kwargs) for term in map(self._assert_type_fn(Term), row)) for row in self._values ) ) def _into_sql(self, **kwargs: Any) -> str: + assert self._insert_table is not None return " INTO {table}".format( table=self._insert_table.get_sql(with_alias=False, **kwargs), ) def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: return " FROM {selectable}".format( - selectable=",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) for clause in self._from) + selectable=",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) for clause in self._from) # type: ignore ) def _using_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: return " USING {selectable}".format( - selectable=",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) for clause in self._using) + selectable=",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) if isinstance(clause, SQLPart) else clause for clause in self._using) ) def _force_index_sql(self, **kwargs: Any) -> str: @@ -1444,11 +1506,13 @@ def _use_index_sql(self, **kwargs: Any) -> str: ) def _prewhere_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: + assert self._prewheres is not None return " PREWHERE {prewhere}".format( prewhere=self._prewheres.get_sql(quote_char=quote_char, subquery=True, **kwargs) ) def _where_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: + assert self._wheres is not None return " WHERE {where}".format(where=self._wheres.get_sql(quote_char=quote_char, subquery=True, **kwargs)) def _group_sql( @@ -1470,7 +1534,8 @@ def _group_sql( clauses = [] selected_aliases = {s.alias for s in self._selects} for field in self._groupbys: - if groupby_alias and field.alias and field.alias in selected_aliases: + assert isinstance(field, Term) + if groupby_alias and field.alias and (field.alias in selected_aliases): clauses.append(format_quotes(field.alias, alias_quote_char or quote_char)) else: clauses.append(field.get_sql(quote_char=quote_char, alias_quote_char=alias_quote_char, **kwargs)) @@ -1502,6 +1567,7 @@ def _orderby_sql( clauses = [] selected_aliases = {s.alias for s in self._selects} for field, directionality in self._orderbys: + assert isinstance(field, Term) term = ( format_quotes(field.alias, alias_quote_char or quote_char) if orderby_alias and field.alias and field.alias in selected_aliases @@ -1518,7 +1584,7 @@ def _rollup_sql(self) -> str: return " WITH ROLLUP" def _having_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: - return " HAVING {having}".format(having=self._havings.get_sql(quote_char=quote_char, **kwargs)) + return " HAVING {having}".format(having=self._havings.get_sql(quote_char=quote_char, **kwargs)) # type: ignore def _offset_sql(self) -> str: return " OFFSET {offset}".format(offset=self._offset) @@ -1537,9 +1603,12 @@ def _set_sql(self, **kwargs: Any) -> str: ) +JoinableTerm = Union[Table, "QueryBuilder", AliasedQuery, _SetOperation] + + class Joiner: def __init__( - self, query: QueryBuilder, item: Union[Table, "QueryBuilder", AliasedQuery], how: JoinType, type_label: str + self, query: QueryBuilder, item: JoinableTerm, how: JoinType, type_label: str ) -> None: self.query = query self.item = item @@ -1562,12 +1631,12 @@ def on_field(self, *fields: Any) -> QueryBuilder: "Parameter 'fields' is required for a " "{type} JOIN but was not supplied.".format(type=self.type_label) ) - criterion = None + criterion: Optional[Criterion] = None for field in fields: consituent = Field(field, table=self.query._from[0]) == Field(field, table=self.item) - criterion = consituent if criterion is None else criterion & consituent + criterion = (criterion & consituent) if (criterion is not None) else consituent - self.query.do_join(JoinOn(self.item, self.how, criterion)) + self.query.do_join(JoinOn(self.item, self.how, cast(Criterion, criterion))) return self.query def using(self, *fields: Any) -> QueryBuilder: @@ -1584,8 +1653,8 @@ def cross(self) -> QueryBuilder: return self.query -class Join: - def __init__(self, item: Term, how: JoinType) -> None: +class Join(SQLPart): + def __init__(self, item: JoinableTerm, how: JoinType) -> None: self.item = item self.how = how @@ -1598,7 +1667,7 @@ def get_sql(self, **kwargs: Any) -> str: return "{type} {join}".format(join=sql, type=self.how.value) return sql - def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: + def validate(self, _from: Iterable[Selectable], _joins: Iterable["Join"]) -> None: pass @builder @@ -1618,7 +1687,7 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl class JoinOn(Join): - def __init__(self, item: Term, how: JoinType, criteria: QueryBuilder, collate: Optional[str] = None) -> None: + def __init__(self, item: JoinableTerm, how: JoinType, criteria: Criterion, collate: Optional[str] = None) -> None: super().__init__(item, how) self.criterion = criteria self.collate = collate @@ -1631,7 +1700,7 @@ def get_sql(self, **kwargs: Any) -> str: collate=" COLLATE {}".format(self.collate) if self.collate else "", ) - def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: + def validate(self, _from: Iterable[Selectable], _joins: Iterable[Join]) -> None: criterion_tables = set([f.table for f in self.criterion.fields_()]) available_tables = set(_from) | {join.item for join in _joins} | {self.item} missing_tables = criterion_tables - available_tables @@ -1656,12 +1725,15 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl :return: A copy of the join with the tables replaced. """ - self.item = new_table if self.item == current_table else self.item - self.criterion = self.criterion.replace_table(current_table, new_table) + if new_table is not None: + self.item = new_table if self.item == current_table else self.item + self.criterion = self.criterion.replace_table(current_table, new_table) + else: + raise ValueError("new_table should not be None for {}".format(type(self).__name__)) class JoinUsing(Join): - def __init__(self, item: Term, how: JoinType, fields: Sequence[Field]) -> None: + def __init__(self, item: JoinableTerm, how: JoinType, fields: Sequence[Field]) -> None: super().__init__(item, how) self.fields = fields @@ -1672,7 +1744,7 @@ def get_sql(self, **kwargs: Any) -> str: fields=",".join(field.get_sql(**kwargs) for field in self.fields), ) - def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: + def validate(self, _from: Iterable[Selectable], _joins: Iterable[Join]) -> None: pass @builder @@ -1688,37 +1760,40 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl :return: A copy of the join with the tables replaced. """ - self.item = new_table if self.item == current_table else self.item - self.fields = [field.replace_table(current_table, new_table) for field in self.fields] + if new_table is not None: + self.item = new_table if self.item == current_table else self.item + self.fields = [field.replace_table(current_table, new_table) for field in self.fields] + else: + raise ValueError("new_table should not be None for {}".format(type(self).__name__)) -class CreateQueryBuilder: +class CreateQueryBuilder(SQLPart): """ Query builder used to build CREATE queries. """ - QUOTE_CHAR = '"' - SECONDARY_QUOTE_CHAR = "'" - ALIAS_QUOTE_CHAR = None + QUOTE_CHAR: Optional[str] = '"' + SECONDARY_QUOTE_CHAR: Optional[str] = "'" + ALIAS_QUOTE_CHAR: Optional[str] = None QUERY_CLS = Query def __init__(self, dialect: Optional[Dialects] = None) -> None: - self._create_table = None + self._create_table: Optional[Table] = None self._temporary = False self._unlogged = False - self._as_select = None - self._columns = [] - self._period_fors = [] + self._as_select: Optional[QueryBuilder] = None + self._columns: List[Column] = [] + self._period_fors: List[PeriodFor] = [] self._with_system_versioning = False - self._primary_key = None - self._uniques = [] + self._primary_key: Optional[List[Column]] = [] + self._uniques: List[Iterable[Column]] = [] self._if_not_exists = False self.dialect = dialect - self._foreign_key = None - self._foreign_key_reference_table = None - self._foreign_key_reference = None - self._foreign_key_on_update: ReferenceOption = None - self._foreign_key_on_delete: ReferenceOption = None + self._foreign_key: Optional[List[Column]] = None + self._foreign_key_reference_table: Optional[Union[Table, str]] = None + self._foreign_key_reference: Optional[List[Column]] = None + self._foreign_key_on_update: Optional[ReferenceOption] = None + self._foreign_key_on_delete: Optional[ReferenceOption] = None def _set_kwargs_defaults(self, kwargs: dict) -> None: kwargs.setdefault("quote_char", self.QUOTE_CHAR) @@ -1974,7 +2049,7 @@ def _create_table_sql(self, **kwargs: Any) -> str: return "CREATE {table_type}TABLE {if_not_exists}{table}".format( table_type=table_type, if_not_exists=if_not_exists, - table=self._create_table.get_sql(**kwargs), + table=self._create_table.get_sql(**kwargs), # type: ignore ) def _table_options_sql(self, **kwargs) -> str: @@ -1999,14 +2074,17 @@ def _unique_key_clauses(self, **kwargs) -> List[str]: def _primary_key_clause(self, **kwargs) -> str: return "PRIMARY KEY ({columns})".format( - columns=",".join(column.get_name_sql(**kwargs) for column in self._primary_key) + columns=",".join(column.get_name_sql(**kwargs) for column in self._primary_key) # type: ignore ) def _foreign_key_clause(self, **kwargs) -> str: clause = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format( - columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key), - table_name=self._foreign_key_reference_table.get_sql(**kwargs), - reference_columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key_reference), + columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key), # type: ignore + table_name=( + self._foreign_key_reference_table.get_sql(**kwargs) + if isinstance(self._foreign_key_reference_table, Table) + else Table(self._foreign_key_reference_table).get_sql()), # type: ignore + reference_columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key_reference), # type: ignore ) if self._foreign_key_on_delete: clause += " ON DELETE " + self._foreign_key_on_delete.value @@ -2029,10 +2107,10 @@ def _body_sql(self, **kwargs) -> str: def _as_select_sql(self, **kwargs: Any) -> str: return " AS ({query})".format( - query=self._as_select.get_sql(**kwargs), + query=self._as_select.get_sql(**kwargs), # type: ignore ) - def _prepare_columns_input(self, columns: List[Union[str, Column]]) -> List[Column]: + def _prepare_columns_input(self, columns: Iterable[Union[str, Column]]) -> List[Column]: return [(column if isinstance(column, Column) else Column(column)) for column in columns] def __str__(self) -> str: @@ -2042,18 +2120,18 @@ def __repr__(self) -> str: return self.__str__() -class DropQueryBuilder: +class DropQueryBuilder(SQLPart): """ Query builder used to build DROP queries. """ - QUOTE_CHAR = '"' - SECONDARY_QUOTE_CHAR = "'" - ALIAS_QUOTE_CHAR = None + QUOTE_CHAR: Optional[str] = '"' + SECONDARY_QUOTE_CHAR: Optional[str] = "'" + ALIAS_QUOTE_CHAR: Optional[str] = None QUERY_CLS = Query def __init__(self, dialect: Optional[Dialects] = None) -> None: - self._drop_target_kind = None + self._drop_target_kind: Optional[str] = None self._drop_target: Union[Database, Table, str] = "" self._if_exists = None self.dialect = dialect diff --git a/pypika/utils.py b/pypika/utils.py index 58bcc166..f902249e 100644 --- a/pypika/utils.py +++ b/pypika/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional, Protocol, Type, TYPE_CHECKING +from typing import Any, Callable, List, Optional, Protocol, Type, TYPE_CHECKING, runtime_checkable if TYPE_CHECKING: import sys from typing import overload, TypeVar @@ -141,3 +141,10 @@ def validate(*args: Any, exc: Exception, type: Optional[Type] = None) -> None: for arg in args: if not isinstance(arg, type): raise exc + + +@runtime_checkable +class SQLPart(Protocol): + """This protocol indicates the class can generate a part of SQL""" + def get_sql(self, **kwargs) -> str: + ... From 97e5f7fcf5843c25138689afa7ebb633b0f351b5 Mon Sep 17 00:00:00 2001 From: thisLight Date: Tue, 20 Sep 2022 21:31:38 +0800 Subject: [PATCH 12/15] random fixes in terms.py --- pypika/terms.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/pypika/terms.py b/pypika/terms.py index fb9fe2ca..6cb87284 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -4,7 +4,7 @@ import uuid from datetime import date from enum import Enum -from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Iterator, List, MutableSequence, Optional, Sequence, Set, Type, TypeVar, Union from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( @@ -15,10 +15,12 @@ format_quotes, ignore_copy, resolve_is_aggregate, + SQLPart, ) if TYPE_CHECKING: from pypika.queries import QueryBuilder, Selectable, Table + from _typeshed import Self __author__ = "Timothy Heys" @@ -39,9 +41,13 @@ def nodes_(self) -> Iterator["Node"]: def find_(self, type: Type[NodeT]) -> List[NodeT]: return [node for node in self.nodes_() if isinstance(node, type)] -WrappedConstant = Union[Node, "LiteralValue", "Array", "Tuple", "ValueWrapper"] -class Term(Node): +WrappedConstantStrict = Union["LiteralValue", "Array", "Tuple", "ValueWrapper"] + + +WrappedConstant = Union[Node, WrappedConstantStrict] + +class Term(Node, SQLPart): def __init__(self, alias: Optional[str] = None) -> None: self.alias = alias @@ -110,7 +116,7 @@ def wrap_json( return JSON(val) - def replace_table(self, current_table: Optional["Table"], new_table: Optional["Table"]) -> "Term": + def replace_table(self: "Self", current_table: Optional["Table"], new_table: Optional["Table"]) -> "Self": """ Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries. The base implementation returns self because not all terms have a table property. @@ -123,6 +129,10 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T Self. """ return self + + # FIXME: separate all operator override to another class, + # some term does not have these operators overrides, for example Table, + # cause inconsistent behaviour def eq(self, other: Any) -> "BasicCriterion": return self == other @@ -596,7 +606,7 @@ def get_sql(self, **kwargs: Any) -> str: # type: ignore name=field_sql, ) - field_alias = getattr(self, "alias", None) + field_alias = self.alias if with_alias: return format_alias_sql(field_sql, field_alias, quote_char=quote_char, **kwargs) return field_sql @@ -1284,7 +1294,7 @@ class Function(Criterion): def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: super().__init__(kwargs.get("alias")) self.name = name - self.args: Iterable[WrappedConstant] = [self.wrap_constant(param) for param in args] + self.args: MutableSequence[WrappedConstant] = [self.wrap_constant(param) for param in args] self.schema = kwargs.get("schema") def nodes_(self) -> Iterator[Node]: @@ -1654,7 +1664,7 @@ def get_sql(self, **kwargs: Any) -> str: return self.name -class AtTimezone(Term): +class AtTimezone(Term, SQLPart): """ Generates AT TIME ZONE SQL. Examples: From 9079ea0323c40ecc4342ae4ccf4745cc8e802307 Mon Sep 17 00:00:00 2001 From: thisLight Date: Tue, 20 Sep 2022 22:05:35 +0800 Subject: [PATCH 13/15] fix random errors in dialects.py --- pypika/dialects.py | 17 +++++++++++------ pypika/queries.py | 2 +- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/pypika/dialects.py b/pypika/dialects.py index 8ede83cd..ed885710 100644 --- a/pypika/dialects.py +++ b/pypika/dialects.py @@ -1,6 +1,6 @@ import itertools from copy import copy -from typing import Any, List, Optional, Set, Union, Tuple as TypedTuple, cast +from typing import Any, Iterable, List, NoReturn, Optional, Set, Union, Tuple as TypedTuple, cast from pypika.enums import Dialects from pypika.queries import ( @@ -11,6 +11,7 @@ Table, Query, QueryBuilder, + JoinOn, ) from pypika.terms import ArithmeticExpression, Criterion, EmptyCriterion, Field, Function, Star, Term, ValueWrapper from pypika.utils import QueryException, builder, format_quotes @@ -431,7 +432,7 @@ def for_update( self._for_update_of = set(of) @builder - def on_conflict(self, *target_fields: Union[str, Term]) -> "PostgreSQLQueryBuilder": + def on_conflict(self, *target_fields: Union[str, Term]) -> None: if not self._insert_table: raise QueryException("On conflict only applies to insert query") @@ -439,7 +440,9 @@ def on_conflict(self, *target_fields: Union[str, Term]) -> "PostgreSQLQueryBuild for target_field in target_fields: if isinstance(target_field, str): - self._on_conflict_fields.append(self._conflict_field_str(target_field)) + field = self._conflict_field_str(target_field) + assert field is not None + self._on_conflict_fields.append(field) elif isinstance(target_field, Term): self._on_conflict_fields.append(target_field) @@ -594,8 +597,8 @@ def _validate_returning_term(self, term: Term) -> None: raise QueryException("Returning can't be used in this query") table_is_insert_or_update_table = field.table in {self._insert_table, self._update_table} - join_tables = set(itertools.chain.from_iterable([j.criterion.tables_ for j in self._joins])) - join_and_base_tables = set(self._from) | join_tables + join_tables = set(itertools.chain.from_iterable([j.criterion.tables_ for j in self._joins if isinstance(j, JoinOn)])) + join_and_base_tables = set(cast(Iterable[Table], filter(lambda v: isinstance(v, Table), self._from))) | join_tables table_not_base_or_join = bool(term.tables_ - join_and_base_tables) if not table_is_insert_or_update_table and table_not_base_or_join: raise QueryException("You can't return from other tables") @@ -798,7 +801,9 @@ def _update_sql(self, **kwargs: Any) -> str: return "ALTER TABLE {table}".format(table=self._update_table.get_sql(**kwargs)) def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: - selectable = ",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) for clause in self._from) + def _error_none(v) -> NoReturn: + raise TypeError("expect Selectable or QueryBuilder, got {}".format(type(v).__name__)) + selectable = ",".join((clause.get_sql(subquery=True, with_alias=True, **kwargs) if clause is not None else _error_none(clause)) for clause in self._from) if self._delete_from: return " {selectable} DELETE".format(selectable=selectable) return " FROM {selectable}".format(selectable=selectable) diff --git a/pypika/queries.py b/pypika/queries.py index d59a72cd..4b0add17 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -2,7 +2,7 @@ from functools import reduce from itertools import chain import operator -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple as TypedTuple, Type, Union, Set, cast, TypeVar +from typing import Any, Callable, Generic, Iterable, List, Optional, Sequence, Tuple as TypedTuple, Type, Union, Set, cast, TypeVar from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation, Order from pypika.terms import ( From 46f6bfe3e936dc88dad55f53e5d6a58cd2c2a93f Mon Sep 17 00:00:00 2001 From: thisLight Date: Tue, 20 Sep 2022 22:06:35 +0800 Subject: [PATCH 14/15] reformatted files using black --- pypika/dialects.py | 28 ++++----- pypika/queries.py | 139 ++++++++++++++++++++++++++++----------------- pypika/terms.py | 58 ++++++++++++------- pypika/utils.py | 12 +++- 4 files changed, 150 insertions(+), 87 deletions(-) diff --git a/pypika/dialects.py b/pypika/dialects.py index ed885710..ab3f1f20 100644 --- a/pypika/dialects.py +++ b/pypika/dialects.py @@ -104,9 +104,7 @@ def __copy__(self) -> "MySQLQueryBuilder": return newone @builder - def for_update( - self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = tuple() - ): + def for_update(self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = tuple()): self._for_update = True self._for_update_skip_locked = skip_locked self._for_update_nowait = nowait @@ -127,7 +125,7 @@ def on_duplicate_key_ignore(self): self._ignore_duplicates = True - def get_sql(self, **kwargs: Any) -> str: # type: ignore + def get_sql(self, **kwargs: Any) -> str: # type: ignore self._set_kwargs_defaults(kwargs) querystring = super(MySQLQueryBuilder, self).get_sql(**kwargs) if querystring: @@ -423,9 +421,7 @@ def distinct_on(self, *fields: Union[str, Term]): self._distinct_on.append(field) @builder - def for_update( - self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = tuple() - ): + def for_update(self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = tuple()): self._for_update = True self._for_update_skip_locked = skip_locked self._for_update_nowait = nowait @@ -453,9 +449,7 @@ def do_nothing(self): self._on_conflict_do_nothing = True @builder - def do_update( - self, update_field: Union[str, Field], update_value: Optional[Any] = None - ): + def do_update(self, update_field: Union[str, Field], update_value: Optional[Any] = None): if self._on_conflict_do_nothing: raise QueryException("Can not have two conflict handlers") @@ -597,8 +591,12 @@ def _validate_returning_term(self, term: Term) -> None: raise QueryException("Returning can't be used in this query") table_is_insert_or_update_table = field.table in {self._insert_table, self._update_table} - join_tables = set(itertools.chain.from_iterable([j.criterion.tables_ for j in self._joins if isinstance(j, JoinOn)])) - join_and_base_tables = set(cast(Iterable[Table], filter(lambda v: isinstance(v, Table), self._from))) | join_tables + join_tables = set( + itertools.chain.from_iterable([j.criterion.tables_ for j in self._joins if isinstance(j, JoinOn)]) + ) + join_and_base_tables = ( + set(cast(Iterable[Table], filter(lambda v: isinstance(v, Table), self._from))) | join_tables + ) table_not_base_or_join = bool(term.tables_ - join_and_base_tables) if not table_is_insert_or_update_table and table_not_base_or_join: raise QueryException("You can't return from other tables") @@ -803,7 +801,11 @@ def _update_sql(self, **kwargs: Any) -> str: def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: def _error_none(v) -> NoReturn: raise TypeError("expect Selectable or QueryBuilder, got {}".format(type(v).__name__)) - selectable = ",".join((clause.get_sql(subquery=True, with_alias=True, **kwargs) if clause is not None else _error_none(clause)) for clause in self._from) + + selectable = ",".join( + (clause.get_sql(subquery=True, with_alias=True, **kwargs) if clause is not None else _error_none(clause)) + for clause in self._from + ) if self._delete_from: return " {selectable} DELETE".format(selectable=selectable) return " FROM {selectable}".format(selectable=selectable) diff --git a/pypika/queries.py b/pypika/queries.py index 4b0add17..81d235c2 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -2,7 +2,21 @@ from functools import reduce from itertools import chain import operator -from typing import Any, Callable, Generic, Iterable, List, Optional, Sequence, Tuple as TypedTuple, Type, Union, Set, cast, TypeVar +from typing import ( + Any, + Callable, + Generic, + Iterable, + List, + Optional, + Sequence, + Tuple as TypedTuple, + Type, + Union, + Set, + cast, + TypeVar, +) from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation, Order from pypika.terms import ( @@ -548,7 +562,9 @@ def __init__( ): super().__init__(alias) self.base_query = base_query - self._set_operation: List[TypedTuple[SetOperation, Union[QueryBuilder, Selectable]]] = [(set_operation, set_operation_query)] + self._set_operation: List[TypedTuple[SetOperation, Union[QueryBuilder, Selectable]]] = [ + (set_operation, set_operation_query) + ] self._orderbys: List[TypedTuple[Union[Field, WrappedConstant, None], Optional[Order]]] = [] self._limit: Optional[int] = None @@ -564,7 +580,9 @@ def orderby(self, *fields: Union[Field, str], order: Optional[Order] = None): if isinstance(field_val, str): table = self.base_query._from[0] if not isinstance(table, Table): - raise TypeError("expect the first \"from\" selectable is table, got {}".format(type(table).__name__)) + raise TypeError( + "expect the first \"from\" selectable is table, got {}".format(type(table).__name__) + ) field = Field(field_val, table=table) else: field = self.base_query.wrap_constant(field_val) @@ -600,13 +618,13 @@ def except_of(self, other: Selectable): def minus(self, other: Selectable): self._set_operation.append((SetOperation.minus, other)) - def __add__(self, other: Selectable) -> "_SetOperation": # type: ignore + def __add__(self, other: Selectable) -> "_SetOperation": # type: ignore return self.union(other) - def __mul__(self, other: Selectable) -> "_SetOperation": # type: ignore + def __mul__(self, other: Selectable) -> "_SetOperation": # type: ignore return self.union_all(other) - def __sub__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore + def __sub__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore return self.minus(other) def __str__(self) -> str: @@ -670,9 +688,9 @@ def _orderby_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: selected_aliases = {s.alias for s in self.base_query._selects if isinstance(s, Term)} for field, directionality in self._orderbys: term = ( - format_quotes(field.alias, quote_char) # type: ignore - if field.alias and (field.alias in selected_aliases) # type: ignore - else field.get_sql(quote_char=quote_char, **kwargs) # type: ignore + format_quotes(field.alias, quote_char) # type: ignore + if field.alias and (field.alias in selected_aliases) # type: ignore + else field.get_sql(quote_char=quote_char, **kwargs) # type: ignore ) clauses.append( @@ -821,18 +839,31 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl self._update_table = new_table if self._update_table == current_table else self._update_table self._with = [alias_query.replace_table(current_table, new_table) for alias_query in self._with] - self._selects = [select.replace_table(current_table, new_table) if isinstance(select, Term) else select for select in self._selects] + self._selects = [ + select.replace_table(current_table, new_table) if isinstance(select, Term) else select + for select in self._selects + ] self._columns = [column.replace_table(current_table, new_table) for column in self._columns] self._values = [ - [(value.replace_table(current_table, new_table) if isinstance(value, Term) else value) for value in value_list] for value_list in self._values + [ + (value.replace_table(current_table, new_table) if isinstance(value, Term) else value) + for value in value_list + ] + for value_list in self._values ] self._wheres = self._wheres.replace_table(current_table, new_table) if self._wheres else None self._prewheres = self._prewheres.replace_table(current_table, new_table) if self._prewheres else None - self._groupbys = [groupby.replace_table(current_table, new_table) if isinstance(groupby, Term) else groupby for groupby in self._groupbys] + self._groupbys = [ + groupby.replace_table(current_table, new_table) if isinstance(groupby, Term) else groupby + for groupby in self._groupbys + ] self._havings = self._havings.replace_table(current_table, new_table) if self._havings else None self._orderbys = [ - (orderby[0].replace_table(current_table, new_table), orderby[1]) if isinstance(orderby[0], Term) else orderby for orderby in self._orderbys + (orderby[0].replace_table(current_table, new_table), orderby[1]) + if isinstance(orderby[0], Term) + else orderby + for orderby in self._orderbys ] self._joins = [join.replace_table(current_table, new_table) for join in self._joins] @@ -889,7 +920,7 @@ def columns(self, *terms: Union[str, Field, List[Union[str, Field]], TypedTuple[ columns: Iterable[Union[str, Field]] if terms and isinstance(terms[0], (list, tuple)): - columns = terms[0] # FIXME: should not sliently ignore rest arguments + columns = terms[0] # FIXME: should not sliently ignore rest arguments # Alternative solution: fix the type comment to tell use here only accepts one sequence. else: columns = cast(TypedTuple[Union[str, Field]], terms) @@ -956,7 +987,7 @@ def where(self, criterion: Union[Term, EmptyCriterion]): self._foreign_table = True if self._wheres: - self._wheres &= criterion # type: ignore + self._wheres &= criterion # type: ignore else: self._wheres = criterion @@ -966,7 +997,7 @@ def having(self, criterion: Union[Term, EmptyCriterion]): return if self._havings: - self._havings &= criterion # type: ignore + self._havings &= criterion # type: ignore else: self._havings = criterion @@ -1103,13 +1134,13 @@ def set(self, field: Union[Field, str], value: Any): field = Field(field) if not isinstance(field, Field) else field self._updates.append((field, self._wrapper_cls(value))) - def __add__(self, other: "QueryBuilder") -> _SetOperation: # type: ignore + def __add__(self, other: "QueryBuilder") -> _SetOperation: # type: ignore return self.union(other) - def __mul__(self, other: "QueryBuilder") -> _SetOperation: # type: ignore + def __mul__(self, other: "QueryBuilder") -> _SetOperation: # type: ignore return self.union_all(other) - def __sub__(self, other: "QueryBuilder") -> _SetOperation: # type: ignore + def __sub__(self, other: "QueryBuilder") -> _SetOperation: # type: ignore return self.minus(other) @builder @@ -1117,7 +1148,7 @@ def slice(self, slice: slice): self._offset = slice.start self._limit = slice.stop - def __getitem__(self, item: Any) -> Union["QueryBuilder", Field]: # type: ignore + def __getitem__(self, item: Any) -> Union["QueryBuilder", Field]: # type: ignore if not isinstance(item, slice): return super().__getitem__(item) return self.slice(item) @@ -1150,7 +1181,9 @@ def _select_field(self, term: Field) -> None: if isinstance(term, Star): self._selects = [ - select for select in self._selects if (not hasattr(select, "table")) or (isinstance(select, Field) and term.table != select.table) + select + for select in self._selects + if (not hasattr(select, "table")) or (isinstance(select, Field) and term.table != select.table) ] self._select_star_tables.add(term.table) @@ -1169,23 +1202,20 @@ def _assert_not_none(v): return v else: raise TypeError("expect Selectable, got None") + base_tables = tuple( - map(_assert_not_none, chain( - self._from, - (self._update_table, ) if self._update_table else tuple(), - self._with - )) + map( + _assert_not_none, + chain(self._from, (self._update_table,) if self._update_table else tuple(), self._with), ) + ) join.validate(base_tables, self._joins) table_in_query = reduce( operator.add, - ( - clause._table_name == join.item._table_name - for clause in base_tables if isinstance(clause, Table) - ), - 0 - ) + (clause._table_name == join.item._table_name for clause in base_tables if isinstance(clause, Table)), + 0, + ) if isinstance(join.item, Table) and (join.item.alias is None) and (table_in_query > 0): # On the odd chance that we join the same table as the FROM table and don't set an alias # FIXME only works once @@ -1233,7 +1263,7 @@ def _apply_terms(self, *terms: Any) -> None: return if not isinstance(terms[0], (list, tuple, set)): - terms = (terms, ) + terms = (terms,) for values in terms: self._values.append([(value if isinstance(value, Term) else self.wrap_constant(value)) for value in values]) @@ -1244,7 +1274,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.__str__() - def __eq__(self, other: Any) -> bool: # type: ignore + def __eq__(self, other: Any) -> bool: # type: ignore if not isinstance(other, QueryBuilder): return False @@ -1253,7 +1283,7 @@ def __eq__(self, other: Any) -> bool: # type: ignore return True - def __ne__(self, other: Any) -> bool: # type: ignore + def __ne__(self, other: Any) -> bool: # type: ignore return not self.__eq__(other) def __hash__(self) -> int: @@ -1464,18 +1494,23 @@ def _columns_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: return " ({columns})".format( columns=",".join(term.get_sql(with_namespace=False, **kwargs) for term in self._columns) ) - + @classmethod def _assert_type_fn(cls, klass: Type[_T]) -> Callable[[Any], _T]: def _assert_type(val: Any): assert isinstance(val, klass) return val + return _assert_type def _values_sql(self, **kwargs: Any) -> str: return " VALUES ({values})".format( values="),(".join( - ",".join(term.get_sql(with_alias=True, subquery=True, **kwargs) for term in map(self._assert_type_fn(Term), row)) for row in self._values + ",".join( + term.get_sql(with_alias=True, subquery=True, **kwargs) + for term in map(self._assert_type_fn(Term), row) + ) + for row in self._values ) ) @@ -1487,12 +1522,15 @@ def _into_sql(self, **kwargs: Any) -> str: def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: return " FROM {selectable}".format( - selectable=",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) for clause in self._from) # type: ignore + selectable=",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) for clause in self._from) # type: ignore ) def _using_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: return " USING {selectable}".format( - selectable=",".join(clause.get_sql(subquery=True, with_alias=True, **kwargs) if isinstance(clause, SQLPart) else clause for clause in self._using) + selectable=",".join( + clause.get_sql(subquery=True, with_alias=True, **kwargs) if isinstance(clause, SQLPart) else clause + for clause in self._using + ) ) def _force_index_sql(self, **kwargs: Any) -> str: @@ -1584,7 +1622,7 @@ def _rollup_sql(self) -> str: return " WITH ROLLUP" def _having_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: - return " HAVING {having}".format(having=self._havings.get_sql(quote_char=quote_char, **kwargs)) # type: ignore + return " HAVING {having}".format(having=self._havings.get_sql(quote_char=quote_char, **kwargs)) # type: ignore def _offset_sql(self) -> str: return " OFFSET {offset}".format(offset=self._offset) @@ -1607,9 +1645,7 @@ def _set_sql(self, **kwargs: Any) -> str: class Joiner: - def __init__( - self, query: QueryBuilder, item: JoinableTerm, how: JoinType, type_label: str - ) -> None: + def __init__(self, query: QueryBuilder, item: JoinableTerm, how: JoinType, type_label: str) -> None: self.query = query self.item = item self.how = how @@ -1876,9 +1912,7 @@ def columns(self, *columns: Union[str, TypedTuple[str, str], Column]): self._columns.append(column) @builder - def period_for( - self, name, start_column: Union[str, Column], end_column: Union[str, Column] - ): + def period_for(self, name, start_column: Union[str, Column], end_column: Union[str, Column]): """ Adds a PERIOD FOR clause. @@ -2049,7 +2083,7 @@ def _create_table_sql(self, **kwargs: Any) -> str: return "CREATE {table_type}TABLE {if_not_exists}{table}".format( table_type=table_type, if_not_exists=if_not_exists, - table=self._create_table.get_sql(**kwargs), # type: ignore + table=self._create_table.get_sql(**kwargs), # type: ignore ) def _table_options_sql(self, **kwargs) -> str: @@ -2074,17 +2108,18 @@ def _unique_key_clauses(self, **kwargs) -> List[str]: def _primary_key_clause(self, **kwargs) -> str: return "PRIMARY KEY ({columns})".format( - columns=",".join(column.get_name_sql(**kwargs) for column in self._primary_key) # type: ignore + columns=",".join(column.get_name_sql(**kwargs) for column in self._primary_key) # type: ignore ) def _foreign_key_clause(self, **kwargs) -> str: clause = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format( - columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key), # type: ignore + columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key), # type: ignore table_name=( self._foreign_key_reference_table.get_sql(**kwargs) if isinstance(self._foreign_key_reference_table, Table) - else Table(self._foreign_key_reference_table).get_sql()), # type: ignore - reference_columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key_reference), # type: ignore + else Table(self._foreign_key_reference_table).get_sql() + ), # type: ignore + reference_columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key_reference), # type: ignore ) if self._foreign_key_on_delete: clause += " ON DELETE " + self._foreign_key_on_delete.value @@ -2107,7 +2142,7 @@ def _body_sql(self, **kwargs) -> str: def _as_select_sql(self, **kwargs: Any) -> str: return " AS ({query})".format( - query=self._as_select.get_sql(**kwargs), # type: ignore + query=self._as_select.get_sql(**kwargs), # type: ignore ) def _prepare_columns_input(self, columns: Iterable[Union[str, Column]]) -> List[Column]: diff --git a/pypika/terms.py b/pypika/terms.py index 6cb87284..a831c216 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -4,7 +4,21 @@ import uuid from datetime import date from enum import Enum -from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Iterator, List, MutableSequence, Optional, Sequence, Set, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Iterable, + Iterator, + List, + MutableSequence, + Optional, + Sequence, + Set, + Type, + TypeVar, + Union, +) from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( @@ -47,10 +61,11 @@ def find_(self, type: Type[NodeT]) -> List[NodeT]: WrappedConstant = Union[Node, WrappedConstantStrict] + class Term(Node, SQLPart): def __init__(self, alias: Optional[str] = None) -> None: self.alias = alias - + @property def is_aggregate(self) -> Optional[bool]: return False @@ -69,9 +84,7 @@ def fields_(self) -> Set["Field"]: return set(self.find_(Field)) @staticmethod - def wrap_constant( - val, wrapper_cls: Optional[Type["Term"]] = None - ) -> WrappedConstant: + def wrap_constant(val, wrapper_cls: Optional[Type["Term"]] = None) -> WrappedConstant: """ Used for wrapping raw inputs such as numbers in Criterions and Operator. @@ -129,7 +142,7 @@ def replace_table(self: "Self", current_table: Optional["Table"], new_table: Opt Self. """ return self - + # FIXME: separate all operator override to another class, # some term does not have these operators overrides, for example Table, # cause inconsistent behaviour @@ -271,10 +284,10 @@ def __rlshift__(self, other: Any) -> "ArithmeticExpression": def __rrshift__(self, other: Any) -> "ArithmeticExpression": return ArithmeticExpression(Arithmetic.rshift, self.wrap_constant(other), self) - def __eq__(self, other: Any) -> "BasicCriterion": # type: ignore + def __eq__(self, other: Any) -> "BasicCriterion": # type: ignore return BasicCriterion(Equality.eq, self, Term._assert_guard(self.wrap_constant(other))) - def __ne__(self, other: Any) -> "BasicCriterion": # type: ignore + def __ne__(self, other: Any) -> "BasicCriterion": # type: ignore return BasicCriterion(Equality.ne, self, Term._assert_guard(self.wrap_constant(other))) def __gt__(self, other: Any) -> "BasicCriterion": @@ -302,7 +315,7 @@ def __hash__(self) -> int: def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() - + @classmethod def _assert_guard(cls, v: Any) -> "Term": if isinstance(v, cls): @@ -318,7 +331,7 @@ def __init__(self, placeholder: Union[str, int]) -> None: def get_sql(self, **kwargs: Any) -> str: return str(self.placeholder) - + @property def is_aggregate(self) -> Optional[bool]: return None @@ -369,7 +382,7 @@ class Negative(Term): def __init__(self, term: Term) -> None: super().__init__() self.term = term - + @property def is_aggregate(self) -> Optional[bool]: return self.term.is_aggregate @@ -382,7 +395,7 @@ class ValueWrapper(Term): def __init__(self, value: Any, alias: Optional[str] = None) -> None: super().__init__(alias) self.value = value - + @property def is_aggregate(self) -> Optional[bool]: return None @@ -554,11 +567,11 @@ def __or__(self, other: Any) -> Any: def __xor__(self, other: Any) -> Any: return other - + @property def is_aggregate(self) -> Optional[bool]: return None - + @property def tables_(self) -> Set: return set() @@ -591,7 +604,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T """ self.table = new_table if self.table == current_table else self.table - def get_sql(self, **kwargs: Any) -> str: # type: ignore + def get_sql(self, **kwargs: Any) -> str: # type: ignore with_alias = kwargs.pop("with_alias", False) with_namespace = kwargs.pop("with_namespace", False) quote_char = kwargs.pop("quote_char", None) @@ -630,11 +643,13 @@ def nodes_(self) -> Iterator[Node]: if self.table is not None and not isinstance(self.table, str): yield from self.table.nodes_() - def get_sql( # type: ignore + def get_sql( # type: ignore self, with_alias: bool = False, with_namespace: bool = False, quote_char: Optional[str] = None, **kwargs: Any ) -> str: if self.table and (with_namespace or (not isinstance(self.table, str) and self.table.alias)): - namespace = (self.table.alias if not isinstance(self.table, str) else self.table) or getattr(self.table, "_table_name") + namespace = (self.table.alias if not isinstance(self.table, str) else self.table) or getattr( + self.table, "_table_name" + ) return "{}.*".format(format_quotes(namespace, quote_char)) return "*" @@ -998,7 +1013,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class ComplexCriterion(BasicCriterion): - def get_sql(self, subcriterion: bool = False, **kwargs: Any) -> str: # type: ignore + def get_sql(self, subcriterion: bool = False, **kwargs: Any) -> str: # type: ignore sql = "{left} {comparator} {right}".format( comparator=self.comparator.value, left=self.left.get_sql(subcriterion=self.needs_brackets(self.left), **kwargs), @@ -1134,7 +1149,7 @@ class Case(Criterion): def __init__(self, alias: Optional[str] = None) -> None: super().__init__(alias=alias) self._cases: List[typing.Tuple[Any, Any]] = [] - self._else: WrappedConstant| None = None + self._else: WrappedConstant | None = None def nodes_(self) -> Iterator[Node]: yield self @@ -1290,6 +1305,7 @@ def _has_params(self): def _is_valid_function_call(self, *args): return len(args) == len(self.params) + class Function(Criterion): def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: super().__init__(kwargs.get("alias")) @@ -1461,9 +1477,11 @@ def get_function_sql(self, **kwargs: Any) -> str: AnyEdge = Union[str, "WindowFrameAnalyticFunction.Edge"] + class WindowFrameAnalyticFunction(AnalyticFunction): class Edge: modifier: ClassVar[Optional[str]] = None + def __init__(self, value: Optional[Union[str, int]] = None) -> None: self.value = value @@ -1632,7 +1650,7 @@ def get_sql(self, **kwargs: Any) -> str: if unit is None: unit = "DAY" - return self.templates.get(dialect, "INTERVAL '{expr} {unit}'").format(expr=expr, unit=unit) # type: ignore + return self.templates.get(dialect, "INTERVAL '{expr} {unit}'").format(expr=expr, unit=unit) # type: ignore class Pow(Function): diff --git a/pypika/utils.py b/pypika/utils.py index f902249e..07e63e8d 100644 --- a/pypika/utils.py +++ b/pypika/utils.py @@ -1,7 +1,9 @@ from typing import Any, Callable, List, Optional, Protocol, Type, TYPE_CHECKING, runtime_checkable + if TYPE_CHECKING: import sys from typing import overload, TypeVar + if sys.version_info >= (3, 10): from typing import ParamSpec, Concatenate else: @@ -49,10 +51,15 @@ class FunctionException(Exception): _P = ParamSpec('_P') if TYPE_CHECKING: + @overload - def builder(func: Callable[Concatenate[_S, _P], None]) -> Callable[Concatenate[_S, _P], _S]: ... + def builder(func: Callable[Concatenate[_S, _P], None]) -> Callable[Concatenate[_S, _P], _S]: + ... + @overload - def builder(func: Callable[Concatenate[_S, _P], _T]) -> Callable[Concatenate[_S, _P], _T]: ... + def builder(func: Callable[Concatenate[_S, _P], _T]) -> Callable[Concatenate[_S, _P], _T]: + ... + def builder(func): """ @@ -146,5 +153,6 @@ def validate(*args: Any, exc: Exception, type: Optional[Type] = None) -> None: @runtime_checkable class SQLPart(Protocol): """This protocol indicates the class can generate a part of SQL""" + def get_sql(self, **kwargs) -> str: ... From 150da3574f9b69a0a3411d27302c057cf9a138cb Mon Sep 17 00:00:00 2001 From: Rubicon Rowe Date: Mon, 27 Feb 2023 22:27:53 +0800 Subject: [PATCH 15/15] Added type checking CI flow (#2) * Added type checking CI flow * ci/typechecking: Removed parallel limit * ci/typechecking: Limit triggers * ci/typechecking: Set trigger to push * ci/typechecking: fixed CI using python 3.1 instead of 3.10 --- .github/workflows/typechecking.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 .github/workflows/typechecking.yml diff --git a/.github/workflows/typechecking.yml b/.github/workflows/typechecking.yml new file mode 100644 index 00000000..ec3ad083 --- /dev/null +++ b/.github/workflows/typechecking.yml @@ -0,0 +1,23 @@ +name: "Type Checking" + +on: push + +jobs: + typechecking: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"] + steps: + - uses: actions/checkout@v2 + - name: Setup Python + uses: actions/setup-python@v4.5.0 + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-dev.txt + pip install mypy + + - name: Run type checking + run: | + mypy -p pypika --python-version ${{ matrix.python-version }}