Skip to content

Commit 654adf8

Browse files
committed
support PEP 695 generics in symbol table
1 parent 4a7f7e0 commit 654adf8

File tree

1 file changed

+89
-8
lines changed

1 file changed

+89
-8
lines changed

crates/red_knot/src/symbols.rs

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,29 @@ impl SymbolTableBuilder {
268268
.last()
269269
.expect("Scope stack should never be empty")
270270
}
271+
272+
fn with_type_params(
273+
&mut self,
274+
name: &str,
275+
params: &Option<Box<ast::TypeParams>>,
276+
nested: impl FnOnce(&mut Self),
277+
) {
278+
if let Some(type_params) = params {
279+
self.push_scope(self.cur_scope(), name, ScopeKind::Annotation);
280+
for type_param in &type_params.type_params {
281+
let name = match type_param {
282+
ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name,
283+
ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, .. }) => name,
284+
ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, .. }) => name,
285+
};
286+
self.add_symbol(name);
287+
}
288+
}
289+
nested(self);
290+
if params.is_some() {
291+
self.pop_scope();
292+
}
293+
}
271294
}
272295

273296
impl<'a> PreorderVisitor<'a> for SymbolTableBuilder {
@@ -280,17 +303,25 @@ impl<'a> PreorderVisitor<'a> for SymbolTableBuilder {
280303

281304
fn visit_stmt(&mut self, stmt: &'a ast::Stmt) {
282305
match stmt {
283-
ast::Stmt::ClassDef(ast::StmtClassDef { name, .. }) => {
306+
ast::Stmt::ClassDef(ast::StmtClassDef {
307+
name, type_params, ..
308+
}) => {
284309
self.add_symbol(name);
285-
self.push_scope(self.cur_scope(), name, ScopeKind::Class);
286-
ast::visitor::preorder::walk_stmt(self, stmt);
287-
self.pop_scope();
310+
self.with_type_params(name, type_params, |builder| {
311+
builder.push_scope(builder.cur_scope(), name, ScopeKind::Class);
312+
ast::visitor::preorder::walk_stmt(builder, stmt);
313+
builder.pop_scope();
314+
});
288315
}
289-
ast::Stmt::FunctionDef(ast::StmtFunctionDef { name, .. }) => {
316+
ast::Stmt::FunctionDef(ast::StmtFunctionDef {
317+
name, type_params, ..
318+
}) => {
290319
self.add_symbol(name);
291-
self.push_scope(self.cur_scope(), name, ScopeKind::Function);
292-
ast::visitor::preorder::walk_stmt(self, stmt);
293-
self.pop_scope();
320+
self.with_type_params(name, type_params, |builder| {
321+
builder.push_scope(builder.cur_scope(), name, ScopeKind::Function);
322+
ast::visitor::preorder::walk_stmt(builder, stmt);
323+
builder.pop_scope();
324+
});
294325
}
295326
ast::Stmt::Import(ast::StmtImport { names, .. }) => {
296327
for alias in names {
@@ -424,6 +455,56 @@ mod tests {
424455
assert_eq!(names(table.symbols_for_scope(func_scope_1)), vec!["x"]);
425456
assert_eq!(names(table.symbols_for_scope(func_scope_2)), vec!["y"]);
426457
}
458+
459+
#[test]
460+
fn generic_func() {
461+
let table = build(
462+
"
463+
def func[T]():
464+
x = 1
465+
",
466+
);
467+
assert_eq!(names(table.root_symbols()), vec!["func"]);
468+
let scopes = table.root_child_scopes();
469+
assert_eq!(scopes.len(), 1);
470+
let ann_scope_id = scopes[0];
471+
let ann_scope = &table.scopes_by_id[ann_scope_id];
472+
assert_eq!(ann_scope.kind, ScopeKind::Annotation);
473+
assert_eq!(ann_scope.name.as_str(), "func");
474+
assert_eq!(names(table.symbols_for_scope(ann_scope_id)), vec!["T"]);
475+
let scopes = table.child_scopes_of(ann_scope_id);
476+
assert_eq!(scopes.len(), 1);
477+
let func_scope_id = scopes[0];
478+
let func_scope = &table.scopes_by_id[func_scope_id];
479+
assert_eq!(func_scope.kind, ScopeKind::Function);
480+
assert_eq!(func_scope.name.as_str(), "func");
481+
assert_eq!(names(table.symbols_for_scope(func_scope_id)), vec!["x"]);
482+
}
483+
484+
#[test]
485+
fn generic_class() {
486+
let table = build(
487+
"
488+
class C[T]:
489+
x = 1
490+
",
491+
);
492+
assert_eq!(names(table.root_symbols()), vec!["C"]);
493+
let scopes = table.root_child_scopes();
494+
assert_eq!(scopes.len(), 1);
495+
let ann_scope_id = scopes[0];
496+
let ann_scope = &table.scopes_by_id[ann_scope_id];
497+
assert_eq!(ann_scope.kind, ScopeKind::Annotation);
498+
assert_eq!(ann_scope.name.as_str(), "C");
499+
assert_eq!(names(table.symbols_for_scope(ann_scope_id)), vec!["T"]);
500+
let scopes = table.child_scopes_of(ann_scope_id);
501+
assert_eq!(scopes.len(), 1);
502+
let func_scope_id = scopes[0];
503+
let func_scope = &table.scopes_by_id[func_scope_id];
504+
assert_eq!(func_scope.kind, ScopeKind::Class);
505+
assert_eq!(func_scope.name.as_str(), "C");
506+
assert_eq!(names(table.symbols_for_scope(func_scope_id)), vec!["x"]);
507+
}
427508
}
428509

429510
#[test]

0 commit comments

Comments
 (0)