From a74ea67a349ee3a30e97086d2eadf38adf374c61 Mon Sep 17 00:00:00 2001 From: Glyphack Date: Sun, 30 Jul 2023 11:57:59 +0200 Subject: [PATCH] Basic implementation for handling variable declarations --- parser/src/parser/ast.rs | 69 ++++++++++ typechecker/src/ast_visitor.rs | 4 +- typechecker/src/build.rs | 48 ++++--- typechecker/src/nodes.rs | 7 + typechecker/src/semantic_analyzer.rs | 86 ++++++++---- ...checker__build__tests__assign_stmt-2.snap} | 54 ++------ ...echecker__build__tests__assign_stmt-3.snap | 124 ++++++++++++++++++ ...echecker__build__tests__assign_stmt-4.snap | 56 ++++++++ ...echecker__build__tests__assign_stmt-5.snap | 9 ++ ...ypechecker__build__tests__assign_stmt.snap | 46 +++++++ typechecker/src/symbol_table.rs | 2 +- 11 files changed, 414 insertions(+), 91 deletions(-) rename typechecker/src/snapshots/{typechecker__build__tests__create_symbol_table.snap => typechecker__build__tests__assign_stmt-2.snap} (54%) create mode 100644 typechecker/src/snapshots/typechecker__build__tests__assign_stmt-3.snap create mode 100644 typechecker/src/snapshots/typechecker__build__tests__assign_stmt-4.snap create mode 100644 typechecker/src/snapshots/typechecker__build__tests__assign_stmt-5.snap create mode 100644 typechecker/src/snapshots/typechecker__build__tests__assign_stmt.snap diff --git a/parser/src/parser/ast.rs b/parser/src/parser/ast.rs index 46dadf59..5f3e765c 100644 --- a/parser/src/parser/ast.rs +++ b/parser/src/parser/ast.rs @@ -19,6 +19,10 @@ impl Node { } } +trait GetNode { + fn get_node(&self) -> Node; +} + impl From for SourceSpan { fn from(val: Node) -> Self { Self::new( @@ -64,6 +68,37 @@ pub enum Statement { Match(Match), } +impl GetNode for Statement { + fn get_node(&self) -> Node { + match self { + Statement::AssignStatement(s) => s.node, + Statement::AnnAssignStatement(s) => s.node, + Statement::AugAssignStatement(s) => s.node, + Statement::ExpressionStatement(s) => s.get_node(), + Statement::Assert(s) => s.node, + Statement::Pass(s) => s.node, + Statement::Delete(s) => s.node, + Statement::Return(s) => s.node, + Statement::Raise(s) => s.node, + Statement::Break(s) => s.node, + Statement::Continue(s) => s.node, + Statement::Import(s) => s.node, + Statement::ImportFrom(s) => s.node, + Statement::Global(s) => s.node, + Statement::Nonlocal(s) => s.node, + Statement::IfStatement(s) => s.node, + Statement::WhileStatement(s) => s.node, + Statement::ForStatement(s) => s.node, + Statement::WithStatement(s) => s.node, + Statement::TryStatement(s) => s.node, + Statement::TryStarStatement(s) => s.node, + Statement::FunctionDef(s) => s.node, + Statement::ClassDef(s) => s.node, + Statement::Match(s) => s.node, + } + } +} + #[derive(Debug, Clone)] pub struct Assign { pub node: Node, @@ -218,6 +253,40 @@ pub enum Expression { FormattedValue(Box), } +impl GetNode for Expression { + fn get_node(&self) -> Node { + match self { + Expression::Constant(c) => c.node, + Expression::List(l) => l.node, + Expression::Tuple(t) => t.node, + Expression::Dict(d) => d.node, + Expression::Set(s) => s.node, + Expression::Name(n) => n.node, + Expression::BoolOp(b) => b.node, + Expression::UnaryOp(u) => u.node, + Expression::BinOp(b) => b.node, + Expression::NamedExpr(n) => n.node, + Expression::Yield(y) => y.node, + Expression::YieldFrom(y) => y.node, + Expression::Starred(s) => s.node, + Expression::Generator(g) => g.node, + Expression::ListComp(l) => l.node, + Expression::SetComp(s) => s.node, + Expression::DictComp(d) => d.node, + Expression::Attribute(a) => a.node, + Expression::Subscript(s) => s.node, + Expression::Slice(s) => s.node, + Expression::Call(c) => c.node, + Expression::Await(a) => a.node, + Expression::Compare(c) => c.node, + Expression::Lambda(l) => l.node, + Expression::IfExp(i) => i.node, + Expression::JoinedStr(j) => j.node, + Expression::FormattedValue(f) => f.node, + } + } +} + // https://docs.python.org/3/reference/expressions.html#atom-identifiers #[derive(Debug, Clone)] pub struct Name { diff --git a/typechecker/src/ast_visitor.rs b/typechecker/src/ast_visitor.rs index 5844f954..7888b3c3 100644 --- a/typechecker/src/ast_visitor.rs +++ b/typechecker/src/ast_visitor.rs @@ -248,9 +248,7 @@ pub trait TraversalVisitor { fn visit_set(&mut self, s: &Set) { todo!() } - fn visit_name(&mut self, n: &Name) { - todo!() - } + fn visit_name(&mut self, n: &Name) {} fn visit_bool_op(&mut self, b: &BoolOperation) { todo!() } diff --git a/typechecker/src/build.rs b/typechecker/src/build.rs index cddb9345..d4966da8 100644 --- a/typechecker/src/build.rs +++ b/typechecker/src/build.rs @@ -104,25 +104,33 @@ mod tests { } #[test] - fn create_symbol_table() { - let source = "a = 'hello world'\nb = a + 1"; - let path = write_temp_source(source); - let mut manager = BuildManager::new( - vec![BuildSource { - path, - module: String::from("test"), - source: source.to_string(), - followed: false, - }], - Settings::test_settings(), - ); - manager.build(); - let module = manager.modules.values().last().unwrap(); - insta::with_settings!({ - description => "simple assignment", // the template source code - omit_expression => true // do not include the default expression - }, { - assert_debug_snapshot!(module.symbol_table); - }); + fn assign_stmt() { + let sources = vec![ + "a = 'hello world'", + "b = a + 1", + "c,d = 1,2", + "a: int = 1", + "a += b", + ]; + for source in sources { + let path = write_temp_source(source); + let mut manager = BuildManager::new( + vec![BuildSource { + path, + module: String::from("test"), + source: source.to_string(), + followed: false, + }], + Settings::test_settings(), + ); + manager.build(); + let module = manager.modules.values().last().unwrap(); + insta::with_settings!({ + description => source, // the template source code + omit_expression => true // do not include the default expression + }, { + assert_debug_snapshot!(module.symbol_table); + }); + } } } diff --git a/typechecker/src/nodes.rs b/typechecker/src/nodes.rs index f982ff9c..9e5db330 100755 --- a/typechecker/src/nodes.rs +++ b/typechecker/src/nodes.rs @@ -46,6 +46,13 @@ impl<'a> TraversalVisitor for EnderpyFile { self.defs.push(Statement::AssignStatement(stmt)); } + fn visit_ann_assign(&mut self, a: &parser::ast::AnnAssign) { + let stmt = a.clone(); + self.defs.push(Statement::AnnAssignStatement(stmt)); + } + + fn visit_aug_assign(&mut self, a: &parser::ast::AugAssign) {} + fn visit_import(&mut self, i: &Import) { let import = i.clone(); self.imports.push(ImportKinds::Import(import)); diff --git a/typechecker/src/semantic_analyzer.rs b/typechecker/src/semantic_analyzer.rs index dad06585..27c3d26d 100644 --- a/typechecker/src/semantic_analyzer.rs +++ b/typechecker/src/semantic_analyzer.rs @@ -29,7 +29,7 @@ impl SemanticAnalyzer { }; } - fn add_declaration_to_symbol_table(&mut self, name: String, decl: Declaration) { + fn create_symbol(&mut self, name: String, decl: Declaration) { let symbol_node = SymbolTableNode { name, declarations: vec![decl], @@ -49,11 +49,42 @@ impl SemanticAnalyzer { fn current_scope(&self) -> SymbolScope { SymbolScope::Global } + + fn create_variable_declaration_symbol( + &mut self, + target: &Expression, + value: Option, + declaration_path: DeclarationPath, + type_annotation: Option, + ) { + match target { + Expression::Name(n) => { + let decl = Declaration::Variable(Box::new(Variable { + declaration_path, + scope: self.current_scope(), + type_annotation, + inferred_type_source: value, + is_constant: false, + })); + self.create_symbol(n.id.clone(), decl) + } + Expression::Tuple(t) => { + for elm in t.elements.iter() { + self.create_variable_declaration_symbol( + elm, + value.clone(), + declaration_path.clone(), + type_annotation.clone(), + ) + } + } + _ => panic!("cannot assign to {:?} is not supported", target), + } + } } impl TraversalVisitor for SemanticAnalyzer { fn visit_stmt(&mut self, s: &parser::ast::Statement) { - // map all statements and call visit match s { parser::ast::Statement::ExpressionStatement(e) => self.visit_expr(e), parser::ast::Statement::Import(i) => self.visit_import(i), @@ -219,7 +250,7 @@ impl TraversalVisitor for SemanticAnalyzer { } fn visit_tuple(&mut self, t: &parser::ast::Tuple) { - todo!() + return; } fn visit_dict(&mut self, d: &parser::ast::Dict) { @@ -321,36 +352,47 @@ impl TraversalVisitor for SemanticAnalyzer { } fn visit_assign(&mut self, assign: &parser::ast::Assign) { + let value = &assign.value; if assign.targets.len() > 1 { - panic!("assignment to multiple targets not implemented") + panic!("multiple assignment not suported"); } - let name_node = match assign.targets.last().unwrap() { - Expression::Name(n) => n, - _ => panic!("assignment to other than name node"), + let target = assign.targets.last().unwrap(); + let declaration_path = DeclarationPath { + module_name: self.file.module_name.clone(), + node: assign.node, }; - - let declared_name = name_node.id.clone(); - let decl = Declaration::Variable(Box::new(Variable { - declaration_path: DeclarationPath { - module_name: self.file.module_name.clone(), - node: assign.node, - }, - scope: self.current_scope(), - type_annotation: None, - inferred_type_source: Some(assign.value.clone()), - is_constant: false, - })); - self.add_declaration_to_symbol_table(declared_name, decl); + self.create_variable_declaration_symbol( + target, + Some(value.clone()), + declaration_path, + None, + ); self.visit_expr(&assign.value); } fn visit_ann_assign(&mut self, a: &parser::ast::AnnAssign) { - todo!() + let value = &a.value; + let target = &a.target; + let declaration_path = DeclarationPath { + module_name: self.file.module_name.clone(), + node: a.node, + }; + self.create_variable_declaration_symbol( + target, + value.clone(), + declaration_path, + Some(a.annotation.clone()), + ); + + if let Some(val) = &a.value { + self.visit_expr(&val); + } } fn visit_aug_assign(&mut self, a: &parser::ast::AugAssign) { - todo!() + self.visit_expr(&a.target); + self.visit_expr(&a.value); } fn visit_assert(&mut self, a: &parser::ast::Assert) { diff --git a/typechecker/src/snapshots/typechecker__build__tests__create_symbol_table.snap b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-2.snap similarity index 54% rename from typechecker/src/snapshots/typechecker__build__tests__create_symbol_table.snap rename to typechecker/src/snapshots/typechecker__build__tests__assign_stmt-2.snap index 9b9f4d23..edca698f 100644 --- a/typechecker/src/snapshots/typechecker__build__tests__create_symbol_table.snap +++ b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-2.snap @@ -1,46 +1,10 @@ --- source: typechecker/src/build.rs -description: simple assignment +description: b = a + 1 --- SymbolTable { symbol_table_type: Module, symbols: { - "a": SymbolTableNode { - name: "a", - declarations: [ - Variable( - Variable { - declaration_path: DeclarationPath { - module_name: "test", - node: Node { - start: 0, - end: 17, - }, - }, - scope: Global, - type_annotation: None, - inferred_type_source: Some( - Constant( - Constant { - node: Node { - start: 4, - end: 17, - }, - value: Str( - "hello world", - ), - }, - ), - ), - is_constant: false, - }, - ), - ], - module_public: false, - module_hidden: false, - implicit: false, - scope: Global, - }, "b": SymbolTableNode { name: "b", declarations: [ @@ -49,8 +13,8 @@ SymbolTable { declaration_path: DeclarationPath { module_name: "test", node: Node { - start: 18, - end: 27, + start: 0, + end: 9, }, }, scope: Global, @@ -59,15 +23,15 @@ SymbolTable { BinOp( BinOp { node: Node { - start: 22, - end: 27, + start: 4, + end: 9, }, op: Add, left: Name( Name { node: Node { - start: 22, - end: 23, + start: 4, + end: 5, }, id: "a", }, @@ -75,8 +39,8 @@ SymbolTable { right: Constant( Constant { node: Node { - start: 26, - end: 27, + start: 8, + end: 9, }, value: Int( "1", diff --git a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-3.snap b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-3.snap new file mode 100644 index 00000000..4e08920d --- /dev/null +++ b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-3.snap @@ -0,0 +1,124 @@ +--- +source: typechecker/src/build.rs +description: "c,d = 1,2" +--- +SymbolTable { + symbol_table_type: Module, + symbols: { + "c": SymbolTableNode { + name: "c", + declarations: [ + Variable( + Variable { + declaration_path: DeclarationPath { + module_name: "test", + node: Node { + start: 0, + end: 9, + }, + }, + scope: Global, + type_annotation: None, + inferred_type_source: Some( + Tuple( + Tuple { + node: Node { + start: 6, + end: 9, + }, + elements: [ + Constant( + Constant { + node: Node { + start: 6, + end: 7, + }, + value: Int( + "1", + ), + }, + ), + Constant( + Constant { + node: Node { + start: 8, + end: 9, + }, + value: Int( + "2", + ), + }, + ), + ], + }, + ), + ), + is_constant: false, + }, + ), + ], + module_public: false, + module_hidden: false, + implicit: false, + scope: Global, + }, + "d": SymbolTableNode { + name: "d", + declarations: [ + Variable( + Variable { + declaration_path: DeclarationPath { + module_name: "test", + node: Node { + start: 0, + end: 9, + }, + }, + scope: Global, + type_annotation: None, + inferred_type_source: Some( + Tuple( + Tuple { + node: Node { + start: 6, + end: 9, + }, + elements: [ + Constant( + Constant { + node: Node { + start: 6, + end: 7, + }, + value: Int( + "1", + ), + }, + ), + Constant( + Constant { + node: Node { + start: 8, + end: 9, + }, + value: Int( + "2", + ), + }, + ), + ], + }, + ), + ), + is_constant: false, + }, + ), + ], + module_public: false, + module_hidden: false, + implicit: false, + scope: Global, + }, + }, + start_line_number: 0, +} diff --git a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-4.snap b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-4.snap new file mode 100644 index 00000000..efa19177 --- /dev/null +++ b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-4.snap @@ -0,0 +1,56 @@ +--- +source: typechecker/src/build.rs +description: "a: int = 1" +--- +SymbolTable { + symbol_table_type: Module, + symbols: { + "a": SymbolTableNode { + name: "a", + declarations: [ + Variable( + Variable { + declaration_path: DeclarationPath { + module_name: "test", + node: Node { + start: 0, + end: 10, + }, + }, + scope: Global, + type_annotation: Some( + Name( + Name { + node: Node { + start: 3, + end: 6, + }, + id: "int", + }, + ), + ), + inferred_type_source: Some( + Constant( + Constant { + node: Node { + start: 9, + end: 10, + }, + value: Int( + "1", + ), + }, + ), + ), + is_constant: false, + }, + ), + ], + module_public: false, + module_hidden: false, + implicit: false, + scope: Global, + }, + }, + start_line_number: 0, +} diff --git a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-5.snap b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-5.snap new file mode 100644 index 00000000..2160df9f --- /dev/null +++ b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-5.snap @@ -0,0 +1,9 @@ +--- +source: typechecker/src/build.rs +description: a += b +--- +SymbolTable { + symbol_table_type: Module, + symbols: {}, + start_line_number: 0, +} diff --git a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt.snap b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt.snap new file mode 100644 index 00000000..0ae2a9cb --- /dev/null +++ b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt.snap @@ -0,0 +1,46 @@ +--- +source: typechecker/src/build.rs +description: "a = 'hello world'" +--- +SymbolTable { + symbol_table_type: Module, + symbols: { + "a": SymbolTableNode { + name: "a", + declarations: [ + Variable( + Variable { + declaration_path: DeclarationPath { + module_name: "test", + node: Node { + start: 0, + end: 17, + }, + }, + scope: Global, + type_annotation: None, + inferred_type_source: Some( + Constant( + Constant { + node: Node { + start: 4, + end: 17, + }, + value: Str( + "hello world", + ), + }, + ), + ), + is_constant: false, + }, + ), + ], + module_public: false, + module_hidden: false, + implicit: false, + scope: Global, + }, + }, + start_line_number: 0, +} diff --git a/typechecker/src/symbol_table.rs b/typechecker/src/symbol_table.rs index 62a95724..ca52784f 100644 --- a/typechecker/src/symbol_table.rs +++ b/typechecker/src/symbol_table.rs @@ -29,7 +29,7 @@ pub struct SymbolTableNode { pub scope: SymbolScope, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct DeclarationPath { pub module_name: String, pub node: Node,