Skip to content

Commit ac784f7

Browse files
committed
Add binary operation to infer expr type
1 parent b7ccf8b commit ac784f7

File tree

7 files changed

+55
-14
lines changed

7 files changed

+55
-14
lines changed

cli/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ fn symbols(path: &PathBuf) -> std::result::Result<(), anyhow::Error> {
3434
manager.build();
3535
let module = manager.modules.values().last().unwrap();
3636

37-
format!("{}", module.symbol_table);
37+
println!("{}", module.symbol_table);
3838

3939
Ok(())
4040
}

typechecker/src/symbol_table.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ impl SymbolTableNode {
177177

178178
impl std::fmt::Display for SymbolTable {
179179
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180+
writeln!(f, "-------------------")?;
180181
writeln!(f, "global scope:")?;
181182
let mut sorted_scopes = self.scopes.iter().collect::<Vec<&SymbolTableScope>>();
182183
sorted_scopes.sort_by(|a, b| a.name.cmp(&b.name));
@@ -185,14 +186,14 @@ impl std::fmt::Display for SymbolTable {
185186
writeln!(f, "{}", scope)?;
186187
}
187188

188-
writeln!(f, "-------------------")?;
189189
writeln!(f, "all scopes:")?;
190190

191191
let mut sorted_all_scopes = self.all_scopes.iter().collect::<Vec<&SymbolTableScope>>();
192192
sorted_all_scopes.sort_by(|a, b| a.name.cmp(&b.name));
193193
for scope in sorted_all_scopes {
194194
writeln!(f, "{}", scope)?;
195195
}
196+
writeln!(f, "-------------------")?;
196197
Ok(())
197198
}
198199
}

typechecker/src/type_check/checker.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::{
66
};
77

88
use super::{
9-
type_inference::{self, type_check_bin_op},
9+
type_inference::{self, bin_op_result_type, type_check_bin_op},
1010
types::Type,
1111
};
1212

@@ -60,14 +60,12 @@ impl<'a> TypeChecker<'a> {
6060
Some(symbol) => match symbol.last_declaration() {
6161
Some(declaration) => self.infer_declaration_type(declaration),
6262
None => {
63-
self.errors
64-
.push(format!("Symbol '{}' not found", name));
63+
self.errors.push(format!("Symbol '{}' not found", name));
6564
Type::Unknown
6665
}
6766
},
6867
None => {
69-
self.errors
70-
.push(format!("Symbol '{}' not found", name));
68+
self.errors.push(format!("Symbol '{}' not found", name));
7169
Type::Unknown
7270
}
7371
}
@@ -92,6 +90,11 @@ impl<'a> TypeChecker<'a> {
9290
_ => panic!("TODO: infer type from call"),
9391
}
9492
}
93+
ast::Expression::BinOp(b) => bin_op_result_type(
94+
&self.infer_expr_type(&b.left),
95+
&self.infer_expr_type(&b.right),
96+
&b.op,
97+
),
9598
_ => Type::Unknown,
9699
}
97100
}
@@ -293,10 +296,6 @@ impl<'a> TraversalVisitor for TypeChecker<'a> {
293296
let l_type = self.infer_expr_type(&b.left);
294297
let r_type = self.infer_expr_type(&b.right);
295298

296-
println!("{}", self.module.symbol_table);
297-
298-
println!("{} {} {}", l_type, b.op, r_type);
299-
300299
if !type_check_bin_op(&l_type, &r_type, &b.op) {
301300
self.errors.push(format!(
302301
"Operator '{}' not supported for types '{}' and '{}'",

typechecker/src/type_check/type_inference.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,41 @@ pub fn type_check_bin_op(t1: &Type, t2: &Type, op: &BinaryOperator) -> bool {
5959

6060
false
6161
}
62+
63+
pub fn bin_op_result_type(t1: &Type, t2: &Type, op: &BinaryOperator) -> Type {
64+
if !type_check_bin_op(t1, t2, op) {
65+
return Type::Unknown;
66+
}
67+
68+
match op {
69+
BinaryOperator::Add
70+
| BinaryOperator::Sub
71+
| BinaryOperator::Mult
72+
| BinaryOperator::MatMult
73+
| BinaryOperator::Div
74+
| BinaryOperator::Mod
75+
| BinaryOperator::Pow
76+
| BinaryOperator::LShift
77+
| BinaryOperator::RShift
78+
| BinaryOperator::BitOr
79+
| BinaryOperator::BitXor
80+
| BinaryOperator::BitAnd
81+
| BinaryOperator::FloorDiv => {
82+
if type_equal(t1, &Type::Float) || type_equal(t2, &Type::Float) {
83+
return Type::Float;
84+
}
85+
if type_equal(t1, &Type::Int) || type_equal(t2, &Type::Int) {
86+
return Type::Int;
87+
}
88+
match t1 {
89+
Type::Str => Type::Str,
90+
Type::None => Type::None,
91+
Type::Unknown => Type::Unknown,
92+
Type::Bool => Type::Bool,
93+
Type::Int => Type::Int,
94+
Type::Float => Type::Float,
95+
_ => Type::Unknown,
96+
}
97+
}
98+
}
99+
}

typechecker/testdata/output/typechecker__build__tests__class_def.snap

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ source: typechecker/src/build.rs
33
description: "class c:\n def __init__(self):\n a = 1\n"
44
expression: result
55
---
6+
-------------------
67
global scope:
78
Symbols:
89
c
910
- Declarations:
1011
--: Class { declaration_path: DeclarationPath { module_name: "test", node: Node { start: 0, end: 47 } }, methods: ["__init__"] }
1112

12-
-------------------
1313
all scopes:
1414
Symbols:
1515
a
@@ -24,4 +24,5 @@ __init__
2424
- Declarations:
2525
--: Function { declaration_path: DeclarationPath { module_name: "test", node: Node { start: 13, end: 47 } }, function_node: FunctionDef { node: Node { start: 13, end: 47 }, name: "__init__", args: Arguments { node: Node { start: 26, end: 30 }, posonlyargs: [], args: [Arg { node: Node { start: 26, end: 30 }, arg: "self", annotation: None }], vararg: None, kwonlyargs: [], kw_defaults: [], kwarg: None, defaults: [] }, body: [AssignStatement(Assign { node: Node { start: 41, end: 46 }, targets: [Name(Name { node: Node { start: 41, end: 42 }, id: "a" })], value: Constant(Constant { node: Node { start: 45, end: 46 }, value: Int("1") }) })], decorator_list: [], returns: None, type_comment: None }, is_method: true, is_generator: false, return_statements: [], yeild_statements: [], raise_statements: [] }
2626

27+
-------------------
2728

typechecker/testdata/output/typechecker__build__tests__function_def.snap

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ source: typechecker/src/build.rs
33
description: "def func(a ,b , /, c = 2, **e): pass\n"
44
expression: result
55
---
6+
-------------------
67
global scope:
78
Symbols:
89
func
910
- Declarations:
1011
--: Function { declaration_path: DeclarationPath { module_name: "test", node: Node { start: 0, end: 37 } }, function_node: FunctionDef { node: Node { start: 0, end: 37 }, name: "func", args: Arguments { node: Node { start: 9, end: 29 }, posonlyargs: [Arg { node: Node { start: 9, end: 10 }, arg: "a", annotation: None }, Arg { node: Node { start: 12, end: 13 }, arg: "b", annotation: None }], args: [Arg { node: Node { start: 19, end: 24 }, arg: "c", annotation: None }], vararg: None, kwonlyargs: [], kw_defaults: [], kwarg: Some(Arg { node: Node { start: 28, end: 29 }, arg: "e", annotation: None }), defaults: [Constant(Constant { node: Node { start: 23, end: 24 }, value: Int("2") })] }, body: [Pass(Pass { node: Node { start: 32, end: 36 } })], decorator_list: [], returns: None, type_comment: None }, is_method: false, is_generator: false, return_statements: [], yeild_statements: [], raise_statements: [] }
1112

12-
-------------------
1313
all scopes:
1414
Symbols:
1515
a
@@ -25,4 +25,5 @@ e
2525
- Declarations:
2626
--: Paramter { declaration_path: DeclarationPath { module_name: "test", node: Node { start: 28, end: 29 } }, parameter_node: Arg { node: Node { start: 28, end: 29 }, arg: "e", annotation: None }, type_annotation: None, default_value: None }
2727

28+
-------------------
2829

typechecker/testdata/output/typechecker__build__tests__simple_var_assignments.snap

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ source: typechecker/src/build.rs
33
description: "a = 'hello world'\nb = a + \"!\"\n\nc: int = 1\n\nf: str = \"hello\"\n"
44
expression: result
55
---
6+
-------------------
67
global scope:
78
Symbols:
89
a
@@ -18,6 +19,6 @@ f
1819
- Declarations:
1920
--: Variable { declaration_path: DeclarationPath { module_name: "test", node: Node { start: 43, end: 59 } }, scope: Global, type_annotation: Some(Name(Name { node: Node { start: 46, end: 49 }, id: "str" })), inferred_type_source: Some(Constant(Constant { node: Node { start: 52, end: 59 }, value: Str("hello") })), is_constant: false }
2021

21-
-------------------
2222
all scopes:
23+
-------------------
2324

0 commit comments

Comments
 (0)