diff --git a/parser/src/parser/ast.rs b/parser/src/parser/ast.rs index 5f3e765c..158a2920 100644 --- a/parser/src/parser/ast.rs +++ b/parser/src/parser/ast.rs @@ -40,6 +40,7 @@ pub struct Module { pub body: Vec, } +// Use box to reduce the enum size #[derive(Debug, Clone)] pub enum Statement { AssignStatement(Assign), diff --git a/typechecker/src/build.rs b/typechecker/src/build.rs index d4966da8..7ad5f3d7 100644 --- a/typechecker/src/build.rs +++ b/typechecker/src/build.rs @@ -111,6 +111,10 @@ mod tests { "c,d = 1,2", "a: int = 1", "a += b", + "def f(): + a = 1 + return +", ]; for source in sources { let path = write_temp_source(source); diff --git a/typechecker/src/semantic_analyzer.rs b/typechecker/src/semantic_analyzer.rs index 27c3d26d..448ccd5b 100644 --- a/typechecker/src/semantic_analyzer.rs +++ b/typechecker/src/semantic_analyzer.rs @@ -1,4 +1,4 @@ -use std::ops::Deref; +use std::collections::HashMap; use parser::ast::Expression; @@ -6,7 +6,8 @@ use crate::{ ast_visitor::TraversalVisitor, nodes::EnderpyFile, symbol_table::{ - Declaration, DeclarationPath, SymbolScope, SymbolTable, SymbolTableNode, Variable, + Declaration, DeclarationPath, Function, SymbolScope, SymbolTable, SymbolTableNode, + SymbolTableScope, SymbolTableType, Variable, }, }; @@ -15,6 +16,8 @@ pub struct SemanticAnalyzer { // TODO: Replace errors with another type file: Box, errors: Vec, + + // TOD: Not needed? scope: SymbolScope, } @@ -36,7 +39,6 @@ impl SemanticAnalyzer { module_public: false, module_hidden: false, implicit: false, - scope: self.scope, }; self.globals.add_symbol(symbol_node) } @@ -46,8 +48,12 @@ impl SemanticAnalyzer { .push(String::from(format!("cannot resolve reference {}", ""))) } - fn current_scope(&self) -> SymbolScope { - SymbolScope::Global + fn current_scope(&self) -> &SymbolTableType { + return self.globals.current_scope_type(); + } + + fn is_inside_class(&self) -> bool { + return matches!(self.current_scope(), SymbolTableType::Class); } fn create_variable_declaration_symbol( @@ -61,7 +67,8 @@ impl SemanticAnalyzer { Expression::Name(n) => { let decl = Declaration::Variable(Box::new(Variable { declaration_path, - scope: self.current_scope(), + // TODO: Hacky way + scope: SymbolScope::Global, type_annotation, inferred_type_source: value, is_constant: false, @@ -224,9 +231,39 @@ impl TraversalVisitor for SemanticAnalyzer { } fn visit_function_def(&mut self, f: &parser::ast::FunctionDef) { + let declaration_path = DeclarationPath { + module_name: self.file.module_name.clone(), + node: f.node, + }; + self.globals.enter_scope(SymbolTableScope::new( + crate::symbol_table::SymbolTableType::Function, + )); + let mut return_statements = vec![]; + let mut yeild_statements = vec![]; + let mut raise_statements = vec![]; for stmt in &f.body { self.visit_stmt(&stmt); + match &stmt { + parser::ast::Statement::Raise(r) => raise_statements.push(r.clone()), + parser::ast::Statement::Return(r) => return_statements.push(r.clone()), + parser::ast::Statement::ExpressionStatement(e) => match e { + parser::ast::Expression::Yield(y) => yeild_statements.push(*y.clone()), + _ => (), + }, + _ => (), + } } + self.globals.exit_scope(); + + let function_declaration = Declaration::Function(Box::new(Function { + declaration_path, + is_method: self.is_inside_class(), + is_generator: !yeild_statements.is_empty(), + return_statements, + yeild_statements, + raise_statements, + })); + self.create_symbol(f.name.clone(), function_declaration); } fn visit_class_def(&mut self, c: &parser::ast::ClassDef) { @@ -407,9 +444,7 @@ impl TraversalVisitor for SemanticAnalyzer { todo!() } - fn visit_return(&mut self, r: &parser::ast::Return) { - todo!() - } + fn visit_return(&mut self, r: &parser::ast::Return) {} fn visit_raise(&mut self, r: &parser::ast::Raise) { todo!() diff --git a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-2.snap b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-2.snap index edca698f..2fccd604 100644 --- a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-2.snap +++ b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-2.snap @@ -3,62 +3,65 @@ source: typechecker/src/build.rs description: b = a + 1 --- SymbolTable { - symbol_table_type: Module, - symbols: { - "b": SymbolTableNode { - name: "b", - declarations: [ - Variable( - Variable { - declaration_path: DeclarationPath { - module_name: "test", - node: Node { - start: 0, - end: 9, - }, - }, - scope: Global, - type_annotation: None, - inferred_type_source: Some( - BinOp( - BinOp { + scopes: [ + SymbolTableScope { + symbol_table_type: Module, + symbols: { + "b": SymbolTableNode { + name: "b", + declarations: [ + Variable( + Variable { + declaration_path: DeclarationPath { + module_name: "test", node: Node { - start: 4, + start: 0, end: 9, }, - op: Add, - left: Name( - Name { + }, + scope: Global, + type_annotation: None, + inferred_type_source: Some( + BinOp( + BinOp { node: Node { start: 4, - end: 5, - }, - id: "a", - }, - ), - right: Constant( - Constant { - node: Node { - start: 8, end: 9, }, - value: Int( - "1", + op: Add, + left: Name( + Name { + node: Node { + start: 4, + end: 5, + }, + id: "a", + }, + ), + right: Constant( + Constant { + node: Node { + start: 8, + end: 9, + }, + value: Int( + "1", + ), + }, ), }, ), - }, - ), + ), + is_constant: false, + }, ), - is_constant: false, - }, - ), - ], - module_public: false, - module_hidden: false, - implicit: false, - scope: Global, + ], + module_public: false, + module_hidden: false, + implicit: false, + }, + }, }, - }, - start_line_number: 0, + ], + all_scopes: [], } diff --git a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-3.snap b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-3.snap index 4e08920d..1d0e505d 100644 --- a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-3.snap +++ b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-3.snap @@ -3,122 +3,124 @@ source: typechecker/src/build.rs description: "c,d = 1,2" --- SymbolTable { - symbol_table_type: Module, - symbols: { - "c": SymbolTableNode { - name: "c", - declarations: [ - Variable( - Variable { - declaration_path: DeclarationPath { - module_name: "test", - node: Node { - start: 0, - end: 9, - }, - }, - scope: Global, - type_annotation: None, - inferred_type_source: Some( - Tuple( - Tuple { + scopes: [ + SymbolTableScope { + symbol_table_type: Module, + symbols: { + "d": SymbolTableNode { + name: "d", + declarations: [ + Variable( + Variable { + declaration_path: DeclarationPath { + module_name: "test", node: Node { - start: 6, + start: 0, end: 9, }, - elements: [ - Constant( - Constant { - node: Node { - start: 6, - end: 7, - }, - value: Int( - "1", - ), + }, + scope: Global, + type_annotation: None, + inferred_type_source: Some( + Tuple( + Tuple { + node: Node { + start: 6, + end: 9, }, - ), - Constant( - Constant { - node: Node { - start: 8, - end: 9, - }, - value: Int( - "2", + elements: [ + Constant( + Constant { + node: Node { + start: 6, + end: 7, + }, + value: Int( + "1", + ), + }, ), - }, - ), - ], - }, - ), - ), - is_constant: false, - }, - ), - ], - module_public: false, - module_hidden: false, - implicit: false, - scope: Global, - }, - "d": SymbolTableNode { - name: "d", - declarations: [ - Variable( - Variable { - declaration_path: DeclarationPath { - module_name: "test", - node: Node { - start: 0, - end: 9, + Constant( + Constant { + node: Node { + start: 8, + end: 9, + }, + value: Int( + "2", + ), + }, + ), + ], + }, + ), + ), + is_constant: false, }, - }, - scope: Global, - type_annotation: None, - inferred_type_source: Some( - Tuple( - Tuple { + ), + ], + module_public: false, + module_hidden: false, + implicit: false, + }, + "c": SymbolTableNode { + name: "c", + declarations: [ + Variable( + Variable { + declaration_path: DeclarationPath { + module_name: "test", node: Node { - start: 6, + start: 0, end: 9, }, - elements: [ - Constant( - Constant { - node: Node { - start: 6, - end: 7, - }, - value: Int( - "1", - ), + }, + scope: Global, + type_annotation: None, + inferred_type_source: Some( + Tuple( + Tuple { + node: Node { + start: 6, + end: 9, }, - ), - Constant( - Constant { - node: Node { - start: 8, - end: 9, - }, - value: Int( - "2", + elements: [ + Constant( + Constant { + node: Node { + start: 6, + end: 7, + }, + value: Int( + "1", + ), + }, ), - }, - ), - ], - }, - ), + Constant( + Constant { + node: Node { + start: 8, + end: 9, + }, + value: Int( + "2", + ), + }, + ), + ], + }, + ), + ), + is_constant: false, + }, ), - is_constant: false, - }, - ), - ], - module_public: false, - module_hidden: false, - implicit: false, - scope: Global, + ], + module_public: false, + module_hidden: false, + implicit: false, + }, + }, }, - }, - start_line_number: 0, + ], + all_scopes: [], } diff --git a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-4.snap b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-4.snap index efa19177..9fa5d625 100644 --- a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-4.snap +++ b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-4.snap @@ -3,54 +3,57 @@ source: typechecker/src/build.rs description: "a: int = 1" --- SymbolTable { - symbol_table_type: Module, - symbols: { - "a": SymbolTableNode { - name: "a", - declarations: [ - Variable( - Variable { - declaration_path: DeclarationPath { - module_name: "test", - node: Node { - start: 0, - end: 10, - }, - }, - scope: Global, - type_annotation: Some( - Name( - Name { - node: Node { - start: 3, - end: 6, - }, - id: "int", - }, - ), - ), - inferred_type_source: Some( - Constant( - Constant { + scopes: [ + SymbolTableScope { + symbol_table_type: Module, + symbols: { + "a": SymbolTableNode { + name: "a", + declarations: [ + Variable( + Variable { + declaration_path: DeclarationPath { + module_name: "test", node: Node { - start: 9, + start: 0, end: 10, }, - value: Int( - "1", - ), }, - ), + scope: Global, + type_annotation: Some( + Name( + Name { + node: Node { + start: 3, + end: 6, + }, + id: "int", + }, + ), + ), + inferred_type_source: Some( + Constant( + Constant { + node: Node { + start: 9, + end: 10, + }, + value: Int( + "1", + ), + }, + ), + ), + is_constant: false, + }, ), - is_constant: false, - }, - ), - ], - module_public: false, - module_hidden: false, - implicit: false, - scope: Global, + ], + module_public: false, + module_hidden: false, + implicit: false, + }, + }, }, - }, - start_line_number: 0, + ], + all_scopes: [], } diff --git a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-5.snap b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-5.snap index 2160df9f..45ec782e 100644 --- a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-5.snap +++ b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-5.snap @@ -3,7 +3,11 @@ source: typechecker/src/build.rs description: a += b --- SymbolTable { - symbol_table_type: Module, - symbols: {}, - start_line_number: 0, + scopes: [ + SymbolTableScope { + symbol_table_type: Module, + symbols: {}, + }, + ], + all_scopes: [], } diff --git a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-6.snap b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-6.snap new file mode 100644 index 00000000..6ce0592b --- /dev/null +++ b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt-6.snap @@ -0,0 +1,87 @@ +--- +source: typechecker/src/build.rs +description: "def f():\n a = 1\n return\n" +--- +SymbolTable { + scopes: [ + SymbolTableScope { + symbol_table_type: Module, + symbols: { + "f": SymbolTableNode { + name: "f", + declarations: [ + Function( + Function { + declaration_path: DeclarationPath { + module_name: "test", + node: Node { + start: 0, + end: 28, + }, + }, + is_method: false, + is_generator: false, + return_statements: [ + Return { + node: Node { + start: 18, + end: 27, + }, + value: None, + }, + ], + yeild_statements: [], + raise_statements: [], + }, + ), + ], + module_public: false, + module_hidden: false, + implicit: false, + }, + }, + }, + ], + all_scopes: [ + SymbolTableScope { + symbol_table_type: Function, + symbols: { + "a": SymbolTableNode { + name: "a", + declarations: [ + Variable( + Variable { + declaration_path: DeclarationPath { + module_name: "test", + node: Node { + start: 12, + end: 17, + }, + }, + scope: Global, + type_annotation: None, + inferred_type_source: Some( + Constant( + Constant { + node: Node { + start: 16, + end: 17, + }, + value: Int( + "1", + ), + }, + ), + ), + is_constant: false, + }, + ), + ], + module_public: false, + module_hidden: false, + implicit: false, + }, + }, + }, + ], +} diff --git a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt.snap b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt.snap index 0ae2a9cb..4008af7e 100644 --- a/typechecker/src/snapshots/typechecker__build__tests__assign_stmt.snap +++ b/typechecker/src/snapshots/typechecker__build__tests__assign_stmt.snap @@ -3,44 +3,47 @@ source: typechecker/src/build.rs description: "a = 'hello world'" --- SymbolTable { - symbol_table_type: Module, - symbols: { - "a": SymbolTableNode { - name: "a", - declarations: [ - Variable( - Variable { - declaration_path: DeclarationPath { - module_name: "test", - node: Node { - start: 0, - end: 17, - }, - }, - scope: Global, - type_annotation: None, - inferred_type_source: Some( - Constant( - Constant { + scopes: [ + SymbolTableScope { + symbol_table_type: Module, + symbols: { + "a": SymbolTableNode { + name: "a", + declarations: [ + Variable( + Variable { + declaration_path: DeclarationPath { + module_name: "test", node: Node { - start: 4, + start: 0, end: 17, }, - value: Str( - "hello world", - ), }, - ), + scope: Global, + type_annotation: None, + inferred_type_source: Some( + Constant( + Constant { + node: Node { + start: 4, + end: 17, + }, + value: Str( + "hello world", + ), + }, + ), + ), + is_constant: false, + }, ), - is_constant: false, - }, - ), - ], - module_public: false, - module_hidden: false, - implicit: false, - scope: Global, + ], + module_public: false, + module_hidden: false, + implicit: false, + }, + }, }, - }, - start_line_number: 0, + ], + all_scopes: [], } diff --git a/typechecker/src/symbol_table.rs b/typechecker/src/symbol_table.rs index ca52784f..af4bd473 100644 --- a/typechecker/src/symbol_table.rs +++ b/typechecker/src/symbol_table.rs @@ -3,13 +3,25 @@ use std::collections::HashMap; #[derive(Debug)] pub struct SymbolTable { + // Sub tables are scopes inside the current scope + scopes: Vec, + // When a symbol goes out of scope we save it here to be able to look it up later + all_scopes: Vec, +} + +#[derive(Debug)] +pub struct SymbolTableScope { pub symbol_table_type: SymbolTableType, symbols: HashMap, - pub start_line_number: u8, - // all sub tables have to be valid until the top level scope is valid - // sub_tables: Vec<&'a SymbolTable<'a>>, - // index of current scope in this table where we insert new symbols - // current_scope: u8, +} + +impl SymbolTableScope { + pub fn new(symbol_table_type: SymbolTableType) -> Self { + SymbolTableScope { + symbol_table_type, + symbols: HashMap::new(), + } + } } #[derive(Debug)] @@ -26,7 +38,6 @@ pub struct SymbolTableNode { pub module_public: bool, pub module_hidden: bool, pub implicit: bool, - pub scope: SymbolScope, } #[derive(Debug, Clone)] @@ -38,6 +49,7 @@ pub struct DeclarationPath { #[derive(Debug)] pub enum Declaration { Variable(Box), + Function(Box), } #[derive(Debug)] @@ -49,6 +61,17 @@ pub struct Variable { pub is_constant: bool, } +#[derive(Debug)] +pub struct Function { + pub declaration_path: DeclarationPath, + pub is_method: bool, + pub is_generator: bool, + pub return_statements: Vec, + pub yeild_statements: Vec, + // helpful to later type check exceptions + pub raise_statements: Vec, +} + #[derive(Debug, Clone, Copy)] pub enum SymbolScope { Global, @@ -59,24 +82,52 @@ pub enum SymbolScope { impl SymbolTable { pub fn new(symbol_table_type: SymbolTableType, start_line_number: u8) -> Self { - SymbolTable { + let global_scope = SymbolTableScope { symbol_table_type, symbols: HashMap::new(), - start_line_number, + }; + SymbolTable { + scopes: vec![global_scope], + all_scopes: vec![], } } + + fn current_scope(&self) -> &SymbolTableScope { + if let Some(scope) = self.scopes.last() { + return &scope; + } else { + panic!("no scopes") + } + } + + pub fn current_scope_type(&self) -> &SymbolTableType { + return &self.current_scope().symbol_table_type; + } + pub fn lookup_in_scope(&self, name: &str) -> Option<&SymbolTableNode> { - return self.symbols.get(name); + let cur_scope = self.current_scope(); + return cur_scope.symbols.get(name); + } + + pub fn enter_scope(&mut self, new_scope: SymbolTableScope) { + self.scopes.push(new_scope); } - // - // pub fn enter_scope(&mut self, new_symbol_table: &'a SymbolTable<'a>) { - // self.sub_tables.push(new_symbol_table); - // } - pub fn exit_scope(&self) {} + pub fn exit_scope(&mut self) { + let finished_scope = self.scopes.pop(); + match finished_scope { + Some(scope) => self.all_scopes.push(scope), + None => panic!("tried to exit non-existent scope"), + } + } pub fn add_symbol(&mut self, symbol_node: SymbolTableNode) { - self.symbols.insert(symbol_node.name.clone(), symbol_node); + match self.scopes.last_mut() { + Some(scope) => { + scope.symbols.insert(symbol_node.name.clone(), symbol_node); + } + None => panic!("no current scope, there must be a global scope"), + }; } }