@@ -268,6 +268,29 @@ impl SymbolTableBuilder {
268268            . last ( ) 
269269            . expect ( "Scope stack should never be empty" ) 
270270    } 
271+ 
272+     fn  with_type_params ( 
273+         & mut  self , 
274+         name :  & str , 
275+         params :  & Option < Box < ast:: TypeParams > > , 
276+         nested :  impl  FnOnce ( & mut  Self ) , 
277+     )  { 
278+         if  let  Some ( type_params)  = params { 
279+             self . push_scope ( self . cur_scope ( ) ,  name,  ScopeKind :: Annotation ) ; 
280+             for  type_param in  & type_params. type_params  { 
281+                 let  name = match  type_param { 
282+                     ast:: TypeParam :: TypeVar ( ast:: TypeParamTypeVar  {  name,  .. } )  => name, 
283+                     ast:: TypeParam :: ParamSpec ( ast:: TypeParamParamSpec  {  name,  .. } )  => name, 
284+                     ast:: TypeParam :: TypeVarTuple ( ast:: TypeParamTypeVarTuple  {  name,  .. } )  => name, 
285+                 } ; 
286+                 self . add_symbol ( name) ; 
287+             } 
288+         } 
289+         nested ( self ) ; 
290+         if  params. is_some ( )  { 
291+             self . pop_scope ( ) ; 
292+         } 
293+     } 
271294} 
272295
273296impl < ' a >  PreorderVisitor < ' a >  for  SymbolTableBuilder  { 
@@ -280,17 +303,25 @@ impl<'a> PreorderVisitor<'a> for SymbolTableBuilder {
280303
281304    fn  visit_stmt ( & mut  self ,  stmt :  & ' a  ast:: Stmt )  { 
282305        match  stmt { 
283-             ast:: Stmt :: ClassDef ( ast:: StmtClassDef  {  name,  .. } )  => { 
306+             ast:: Stmt :: ClassDef ( ast:: StmtClassDef  { 
307+                 name,  type_params,  ..
308+             } )  => { 
284309                self . add_symbol ( name) ; 
285-                 self . push_scope ( self . cur_scope ( ) ,  name,  ScopeKind :: Class ) ; 
286-                 ast:: visitor:: preorder:: walk_stmt ( self ,  stmt) ; 
287-                 self . pop_scope ( ) ; 
310+                 self . with_type_params ( name,  type_params,  |builder| { 
311+                     builder. push_scope ( builder. cur_scope ( ) ,  name,  ScopeKind :: Class ) ; 
312+                     ast:: visitor:: preorder:: walk_stmt ( builder,  stmt) ; 
313+                     builder. pop_scope ( ) ; 
314+                 } ) ; 
288315            } 
289-             ast:: Stmt :: FunctionDef ( ast:: StmtFunctionDef  {  name,  .. } )  => { 
316+             ast:: Stmt :: FunctionDef ( ast:: StmtFunctionDef  { 
317+                 name,  type_params,  ..
318+             } )  => { 
290319                self . add_symbol ( name) ; 
291-                 self . push_scope ( self . cur_scope ( ) ,  name,  ScopeKind :: Function ) ; 
292-                 ast:: visitor:: preorder:: walk_stmt ( self ,  stmt) ; 
293-                 self . pop_scope ( ) ; 
320+                 self . with_type_params ( name,  type_params,  |builder| { 
321+                     builder. push_scope ( builder. cur_scope ( ) ,  name,  ScopeKind :: Function ) ; 
322+                     ast:: visitor:: preorder:: walk_stmt ( builder,  stmt) ; 
323+                     builder. pop_scope ( ) ; 
324+                 } ) ; 
294325            } 
295326            ast:: Stmt :: Import ( ast:: StmtImport  {  names,  .. } )  => { 
296327                for  alias in  names { 
@@ -424,6 +455,56 @@ mod tests {
424455            assert_eq ! ( names( table. symbols_for_scope( func_scope_1) ) ,  vec![ "x" ] ) ; 
425456            assert_eq ! ( names( table. symbols_for_scope( func_scope_2) ) ,  vec![ "y" ] ) ; 
426457        } 
458+ 
459+         #[ test]  
460+         fn  generic_func ( )  { 
461+             let  table = build ( 
462+                 " 
463+                 def func[T](): 
464+                     x = 1 
465+                 " , 
466+             ) ; 
467+             assert_eq ! ( names( table. root_symbols( ) ) ,  vec![ "func" ] ) ; 
468+             let  scopes = table. root_child_scopes ( ) ; 
469+             assert_eq ! ( scopes. len( ) ,  1 ) ; 
470+             let  ann_scope_id = scopes[ 0 ] ; 
471+             let  ann_scope = & table. scopes_by_id [ ann_scope_id] ; 
472+             assert_eq ! ( ann_scope. kind,  ScopeKind :: Annotation ) ; 
473+             assert_eq ! ( ann_scope. name. as_str( ) ,  "func" ) ; 
474+             assert_eq ! ( names( table. symbols_for_scope( ann_scope_id) ) ,  vec![ "T" ] ) ; 
475+             let  scopes = table. child_scopes_of ( ann_scope_id) ; 
476+             assert_eq ! ( scopes. len( ) ,  1 ) ; 
477+             let  func_scope_id = scopes[ 0 ] ; 
478+             let  func_scope = & table. scopes_by_id [ func_scope_id] ; 
479+             assert_eq ! ( func_scope. kind,  ScopeKind :: Function ) ; 
480+             assert_eq ! ( func_scope. name. as_str( ) ,  "func" ) ; 
481+             assert_eq ! ( names( table. symbols_for_scope( func_scope_id) ) ,  vec![ "x" ] ) ; 
482+         } 
483+ 
484+         #[ test]  
485+         fn  generic_class ( )  { 
486+             let  table = build ( 
487+                 " 
488+                 class C[T]: 
489+                     x = 1 
490+                 " , 
491+             ) ; 
492+             assert_eq ! ( names( table. root_symbols( ) ) ,  vec![ "C" ] ) ; 
493+             let  scopes = table. root_child_scopes ( ) ; 
494+             assert_eq ! ( scopes. len( ) ,  1 ) ; 
495+             let  ann_scope_id = scopes[ 0 ] ; 
496+             let  ann_scope = & table. scopes_by_id [ ann_scope_id] ; 
497+             assert_eq ! ( ann_scope. kind,  ScopeKind :: Annotation ) ; 
498+             assert_eq ! ( ann_scope. name. as_str( ) ,  "C" ) ; 
499+             assert_eq ! ( names( table. symbols_for_scope( ann_scope_id) ) ,  vec![ "T" ] ) ; 
500+             let  scopes = table. child_scopes_of ( ann_scope_id) ; 
501+             assert_eq ! ( scopes. len( ) ,  1 ) ; 
502+             let  func_scope_id = scopes[ 0 ] ; 
503+             let  func_scope = & table. scopes_by_id [ func_scope_id] ; 
504+             assert_eq ! ( func_scope. kind,  ScopeKind :: Class ) ; 
505+             assert_eq ! ( func_scope. name. as_str( ) ,  "C" ) ; 
506+             assert_eq ! ( names( table. symbols_for_scope( func_scope_id) ) ,  vec![ "x" ] ) ; 
507+         } 
427508    } 
428509
429510    #[ test]  
0 commit comments