1- from typing import List , Union
1+ from typing import List , Optional , Union
22
33from redis .commands .search .dialect import DEFAULT_DIALECT
44
@@ -26,10 +26,10 @@ class Reducer:
2626
2727 NAME = None
2828
29- def __init__ (self , * args : List [ str ] ) -> None :
30- self ._args = args
31- self ._field = None
32- self ._alias = None
29+ def __init__ (self , * args : str ) -> None :
30+ self ._args : tuple [ str , ...] = args
31+ self ._field : Optional [ str ] = None
32+ self ._alias : Optional [ str ] = None
3333
3434 def alias (self , alias : str ) -> "Reducer" :
3535 """
@@ -49,13 +49,14 @@ def alias(self, alias: str) -> "Reducer":
4949 if alias is FIELDNAME :
5050 if not self ._field :
5151 raise ValueError ("Cannot use FIELDNAME alias with no field" )
52- # Chop off initial '@'
53- alias = self ._field [1 :]
52+ else :
53+ # Chop off initial '@'
54+ alias = self ._field [1 :]
5455 self ._alias = alias
5556 return self
5657
5758 @property
58- def args (self ) -> List [str ]:
59+ def args (self ) -> tuple [str , ... ]:
5960 return self ._args
6061
6162
@@ -64,7 +65,7 @@ class SortDirection:
6465 This special class is used to indicate sort direction.
6566 """
6667
67- DIRSTRING = None
68+ DIRSTRING : Optional [ str ] = None
6869
6970 def __init__ (self , field : str ) -> None :
7071 self .field = field
@@ -104,19 +105,19 @@ def __init__(self, query: str = "*") -> None:
104105 All member methods (except `build_args()`)
105106 return the object itself, making them useful for chaining.
106107 """
107- self ._query = query
108- self ._aggregateplan = []
109- self ._loadfields = []
110- self ._loadall = False
111- self ._max = 0
112- self ._with_schema = False
113- self ._verbatim = False
114- self ._cursor = []
115- self ._dialect = DEFAULT_DIALECT
116- self ._add_scores = False
117- self ._scorer = "TFIDF"
118-
119- def load (self , * fields : List [ str ] ) -> "AggregateRequest" :
108+ self ._query : str = query
109+ self ._aggregateplan : List [ str ] = []
110+ self ._loadfields : List [ str ] = []
111+ self ._loadall : bool = False
112+ self ._max : int = 0
113+ self ._with_schema : bool = False
114+ self ._verbatim : bool = False
115+ self ._cursor : List [ str ] = []
116+ self ._dialect : int = DEFAULT_DIALECT
117+ self ._add_scores : bool = False
118+ self ._scorer : str = "TFIDF"
119+
120+ def load (self , * fields : str ) -> "AggregateRequest" :
120121 """
121122 Indicate the fields to be returned in the response. These fields are
122123 returned in addition to any others implicitly specified.
@@ -133,7 +134,7 @@ def load(self, *fields: List[str]) -> "AggregateRequest":
133134 return self
134135
135136 def group_by (
136- self , fields : List [str ], * reducers : Union [ Reducer , List [ Reducer ]]
137+ self , fields : Union [ str , List [str ]] , * reducers : Reducer
137138 ) -> "AggregateRequest" :
138139 """
139140 Specify by which fields to group the aggregation.
@@ -147,7 +148,6 @@ def group_by(
147148 `aggregation` module.
148149 """
149150 fields = [fields ] if isinstance (fields , str ) else fields
150- reducers = [reducers ] if isinstance (reducers , Reducer ) else reducers
151151
152152 ret = ["GROUPBY" , str (len (fields )), * fields ]
153153 for reducer in reducers :
@@ -223,7 +223,7 @@ def limit(self, offset: int, num: int) -> "AggregateRequest":
223223 self ._aggregateplan .extend (_limit .build_args ())
224224 return self
225225
226- def sort_by (self , * fields : List [ str ] , ** kwargs ) -> "AggregateRequest" :
226+ def sort_by (self , * fields : str , ** kwargs ) -> "AggregateRequest" :
227227 """
228228 Indicate how the results should be sorted. This can also be used for
229229 *top-N* style queries
@@ -251,12 +251,10 @@ def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest":
251251 .sort_by(Desc("@paid"), max=10)
252252 ```
253253 """
254- if isinstance (fields , (str , SortDirection )):
255- fields = [fields ]
256254
257255 fields_args = []
258256 for f in fields :
259- if isinstance (f , SortDirection ):
257+ if isinstance (f , ( Asc , Desc ) ):
260258 fields_args += [f .field , f .DIRSTRING ]
261259 else :
262260 fields_args += [f ]
@@ -356,7 +354,7 @@ def build_args(self) -> List[str]:
356354 ret .extend (self ._loadfields )
357355
358356 if self ._dialect :
359- ret .extend (["DIALECT" , self ._dialect ])
357+ ret .extend (["DIALECT" , str ( self ._dialect ) ])
360358
361359 ret .extend (self ._aggregateplan )
362360
@@ -393,7 +391,7 @@ def __init__(self, rows, cursor: Cursor, schema) -> None:
393391 self .cursor = cursor
394392 self .schema = schema
395393
396- def __repr__ (self ) -> ( str , str ) :
394+ def __repr__ (self ) -> str :
397395 cid = self .cursor .cid if self .cursor else - 1
398396 return (
399397 f"<{ self .__class__ .__name__ } at 0x{ id (self ):x} "
0 commit comments