diff --git a/spec/api/gen_spec.lua b/spec/api/gen_spec.lua index a53538d46..baee93bcc 100644 --- a/spec/api/gen_spec.lua +++ b/spec/api/gen_spec.lua @@ -69,7 +69,7 @@ describe("tl.gen", function() print(math.floor(2)) ]] - local env = tl.init_env(true, true) + local env = tl.init_env(true, false) local output, result = tl.gen(input, env) assert.equal('print(math.floor(2))', output) @@ -83,7 +83,7 @@ describe("tl.gen", function() print(math.floor(2))]] - local env = tl.init_env(true, true) + local env = tl.init_env(true, false) local output, result = tl.gen(input, env) assert.equal(input, output) diff --git a/spec/api/get_types_spec.lua b/spec/api/get_types_spec.lua index 26bbf1d05..b6f55ec83 100644 --- a/spec/api/get_types_spec.lua +++ b/spec/api/get_types_spec.lua @@ -8,7 +8,7 @@ describe("tl.get_types", function() local function a() ::continue:: end - ]], false, env)) + ]], env)) local tr, trenv = tl.get_types(result) assert(tr) @@ -25,7 +25,7 @@ describe("tl.get_types", function() end R.f("hello") - ]], false, env)) + ]], env)) local tr, trenv = tl.get_types(result) local y = 6 diff --git a/spec/api/pretty_print_ast.lua b/spec/api/pretty_print_ast.lua index d87d1ea86..d1d149786 100644 --- a/spec/api/pretty_print_ast.lua +++ b/spec/api/pretty_print_ast.lua @@ -4,7 +4,7 @@ local util = require("spec.util") describe("tl.pretty_print_ast", function() it("returns error for attribute on non 5.4 target", function() local input = [[local x = io.open("foobar", "r")]] - local result = tl.process_string(input, false, tl.init_env(false, "off", "5.4"), "foo.tl") + local result = tl.process_string(input, tl.init_env(false, "off", "5.4"), "foo.tl") local output, err = tl.pretty_print_ast(result.ast, "5.3") assert.is_nil(output) diff --git a/spec/call/generic_function_spec.lua b/spec/call/generic_function_spec.lua index 2fb8cf4d6..ec68bb3ff 100644 --- a/spec/call/generic_function_spec.lua +++ b/spec/call/generic_function_spec.lua @@ -370,7 +370,7 @@ describe("generic function", function() recurse_node(ast, visit_node, visit_type) end ]], { - { x = 40, msg = "argument 3: in map value: type parameter : got number, expected string" } + { y = 16, x = 40, msg = "argument 3: in map value: type parameter : got number, expected string" } })) it("inference trickles down to function arguments, pass", util.check([[ diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index a792f8433..c94b51d7a 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -199,7 +199,6 @@ describe("tl types works like check", function() local by_pos = types.by_pos[next(types.by_pos)] assert(by_pos["1"]) assert(by_pos["1"]["13"]) -- require - assert(by_pos["1"]["20"]) -- ( assert(by_pos["1"]["21"]) -- "os" assert(by_pos["1"]["26"]) -- . end) @@ -217,18 +216,17 @@ describe("tl types works like check", function() assert(types.by_pos) local by_pos = types.by_pos[next(types.by_pos)] assert.same({ - ["19"] = 2, - ["20"] = 5, - ["22"] = 2, - ["39"] = 6, - ["41"] = 2, + ["19"] = 8, + ["22"] = 8, + ["23"] = 6, + ["30"] = 2, + ["41"] = 8, }, by_pos["1"]) assert.same({ - ["17"] = 3, - ["20"] = 4, - ["25"] = 17, - ["30"] = 16, - ["31"] = 2, + ["17"] = 6, + ["20"] = 2, + ["25"] = 9, + ["31"] = 8, }, by_pos["2"]) end) end) diff --git a/spec/declaration/record_method_spec.lua b/spec/declaration/record_method_spec.lua index 20cbde3dc..7f8cf2db6 100644 --- a/spec/declaration/record_method_spec.lua +++ b/spec/declaration/record_method_spec.lua @@ -239,8 +239,8 @@ describe("record method", function() return "hello" end ]], { - { msg = "in assignment: incompatible number of returns: got 0 (), expected 1 (string)" }, - { msg = "excess return values, expected 0 (), got 1 (string \"hello\")" }, + { y = 5, msg = "in assignment: incompatible number of returns: got 0 (), expected 1 (string)" }, + { y = 6, msg = "excess return values, expected 0 (), got 1 (string \"hello\")" }, })) it("allows functions declared on method tables (#27)", function() diff --git a/spec/parser/parser_error_spec.lua b/spec/parser/parser_error_spec.lua index ed50e80c9..cfd2e077c 100644 --- a/spec/parser/parser_error_spec.lua +++ b/spec/parser/parser_error_spec.lua @@ -2,7 +2,7 @@ local tl = require("tl") describe("parser errors", function() it("parse errors include filename", function () - local result = tl.process_string("local x 1", false, nil, "foo.tl") + local result = tl.process_string("local x 1", nil, "foo.tl") assert.same("foo.tl", result.syntax_errors[1].filename, "parse errors should contain .filename property") end) @@ -30,7 +30,7 @@ describe("parser errors", function() local code = [[ local bar = require "bar" ]] - local result = tl.process_string(code, true, nil, "foo.tl") + local result = tl.process_string(code, nil, "foo.tl") assert.is_not_nil(string.match(result.env.loaded["./bar.tl"].syntax_errors[1].filename, "bar.tl$"), "errors should contain .filename property") end) end) diff --git a/spec/parser/parser_spec.lua b/spec/parser/parser_spec.lua index d1e66fb38..870260f90 100644 --- a/spec/parser/parser_spec.lua +++ b/spec/parser/parser_spec.lua @@ -19,6 +19,7 @@ describe("parser", function() assert.same({ kind = "statements", tk = "$EOF$", + f = "", x = 1, y = 1, xend = 5, diff --git a/spec/stdlib/require_spec.lua b/spec/stdlib/require_spec.lua index 999cedfd2..17d0d1ba0 100644 --- a/spec/stdlib/require_spec.lua +++ b/spec/stdlib/require_spec.lua @@ -401,7 +401,7 @@ describe("require", function() local result, err = tl.process("foo.tl") assert.same(0, #result.syntax_errors) - assert.same(0, #result.env.loaded["foo.tl"].type_errors) + assert.same({}, result.env.loaded["foo.tl"].type_errors) assert.same(1, #result.env.loaded["./box.tl"].type_errors) assert.match("cannot use operator ..", result.env.loaded["./box.tl"].type_errors[1].msg) end) diff --git a/spec/stdlib/xpcall_spec.lua b/spec/stdlib/xpcall_spec.lua index 87089f162..16e911a7f 100644 --- a/spec/stdlib/xpcall_spec.lua +++ b/spec/stdlib/xpcall_spec.lua @@ -105,7 +105,7 @@ describe("xpcall", function() { msg = "xyz: got boolean, expected number" } })) - it("type checks the message handler", util.check_type_error([[ + it("#only type checks the message handler", util.check_type_error([[ local function f(a: string, b: number) end diff --git a/spec/util.lua b/spec/util.lua index fb9aeeab3..ccaf59e7f 100644 --- a/spec/util.lua +++ b/spec/util.lua @@ -435,7 +435,7 @@ local function check(lax, code, unknowns, gen_target) if gen_target == "5.4" then gen_compat = "off" end - local result = tl.type_check(ast, { filename = "foo.lua", lax = lax, gen_target = gen_target, gen_compat = gen_compat }) + local result = tl.type_check(ast, "foo.lua", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) batch:add(assert.same, {}, result.type_errors) if unknowns then @@ -456,7 +456,7 @@ local function check_type_error(lax, code, type_errors, gen_target) if gen_target == "5.4" then gen_compat = "off" end - local result = tl.type_check(ast, { filename = "foo.tl", lax = lax, gen_target = gen_target, gen_compat = gen_compat }) + local result = tl.type_check(ast, "foo.tl", { feat_lax = lax and "on" or "off", gen_target = gen_target, gen_compat = gen_compat }) local result_type_errors = combine_result(result, "type_errors") batch_compare(batch, "type errors", type_errors, result_type_errors) @@ -525,7 +525,7 @@ function util.check_syntax_error(code, syntax_errors) local batch = batch_assertions() batch_compare(batch, "syntax errors", syntax_errors, errors) batch:assert() - tl.type_check(ast, { filename = "foo.tl", lax = false }) + tl.type_check(ast, "foo.tl", { feat_lax = "off" }) end end @@ -564,7 +564,7 @@ function util.check_types(code, types) local batch = batch_assertions() local env = tl.init_env() env.report_types = true - local result = tl.type_check(ast, { filename = "foo.tl", env = env, lax = false }) + local result = tl.type_check(ast, "foo.tl", { feat_lax = "off" }, env) batch:add(assert.same, {}, result.type_errors, "Code was not expected to have type errors") local tr = env.reporter:get_report() @@ -596,7 +596,7 @@ local function gen(lax, code, expected, gen_target) return function() local ast, syntax_errors = tl.parse(code, "foo.tl") assert.same({}, syntax_errors, "Code was not expected to have syntax errors") - local result = tl.type_check(ast, { filename = "foo.tl", lax = lax, gen_target = gen_target }) + local result = tl.type_check(ast, "foo.tl", { feat_lax = lax and "on" or "off", gen_target = gen_target }) assert.same({}, result.type_errors) local output_code = tl.pretty_print_ast(ast) diff --git a/tl b/tl index 8892d1831..d8516b8f5 100755 --- a/tl +++ b/tl @@ -163,10 +163,12 @@ local function setup_env(tlconfig, filename) end local opts = { - lax_mode = lax_mode, - feat_arity = tlconfig["feat_arity"], - gen_compat = tlconfig["gen_compat"], - gen_target = tlconfig["gen_target"], + defaults = { + feat_lax = lax_mode and "on" or "off", + feat_arity = tlconfig["feat_arity"], + gen_compat = tlconfig["gen_compat"], + gen_target = tlconfig["gen_target"], + }, predefined_modules = tlconfig._init_env_modules, } diff --git a/tl.lua b/tl.lua index 47281d4da..f87e65196 100644 --- a/tl.lua +++ b/tl.lua @@ -1,4 +1,4 @@ -local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local debug = _tl_compat and _tl_compat.debug or debug; local io = _tl_compat and _tl_compat.io or io; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local load = _tl_compat and _tl_compat.load or load; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local os = _tl_compat and _tl_compat.os or os; local package = _tl_compat and _tl_compat.package or package; local pairs = _tl_compat and _tl_compat.pairs or pairs; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table; local _tl_table_unpack = unpack or table.unpack +local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local debug = _tl_compat and _tl_compat.debug or debug; local io = _tl_compat and _tl_compat.io or io; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local load = _tl_compat and _tl_compat.load or load; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local os = _tl_compat and _tl_compat.os or os; local package = _tl_compat and _tl_compat.package or package; local pairs = _tl_compat and _tl_compat.pairs or pairs; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table local VERSION = "0.15.3+dev" local stdlib = [=====[ @@ -481,10 +481,16 @@ end -local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, EnvOptions = {}, } +local Errors = {} + + + +local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, EnvOptions = {}, } + + @@ -629,6 +635,7 @@ local TypeReporter = {} + tl.version = function() return VERSION end @@ -699,6 +706,12 @@ tl.typecodes = { +local DEFAULT_GEN_COMPAT = "optional" +local DEFAULT_GEN_TARGET = "5.3" + + + + @@ -1517,7 +1530,6 @@ end - local table_types = { @@ -1552,7 +1564,6 @@ local table_types = { ["any"] = false, ["unknown"] = false, ["invalid"] = false, - ["unresolved"] = false, ["none"] = false, ["*"] = false, } @@ -1577,6 +1588,9 @@ local table_types = { +local function is_numeric_type(t) + return t.typename == "number" or t.typename == "integer" +end @@ -1852,14 +1866,12 @@ local table_types = { -local TruthyFact = {} -local NotFact = {} @@ -1868,7 +1880,6 @@ local NotFact = {} -local AndFact = {} @@ -1878,33 +1889,34 @@ local AndFact = {} -local OrFact = {} +local TruthyFact = {} +local NotFact = {} -local EqFact = {} +local AndFact = {} -local IsFact = {} +local OrFact = {} @@ -1914,22 +1926,17 @@ local IsFact = {} +local EqFact = {} -local attributes = { - ["const"] = true, - ["close"] = true, - ["total"] = true, -} -local is_attribute = attributes -local Node = {ExpectedContext = {}, } +local IsFact = {} @@ -1951,6 +1958,15 @@ local Node = {ExpectedContext = {}, } +local attributes = { + ["const"] = true, + ["close"] = true, + ["total"] = true, +} +local is_attribute = attributes + +local Node = {ExpectedContext = {}, } + @@ -2032,9 +2048,6 @@ local Node = {ExpectedContext = {}, } -local function is_number_type(t) - return t.typename == "number" or t.typename == "integer" -end @@ -2051,95 +2064,34 @@ end -local parse_type_list -local parse_expression -local parse_expression_and_tk -local parse_statements -local parse_argument_list -local parse_argument_type_list -local parse_type -local parse_newtype -local parse_interface_name -local parse_enum_body -local parse_record_body -local parse_type_body_fns -local function fail(ps, i, msg) - if not ps.tokens[i] then - local eof = ps.tokens[#ps.tokens] - table.insert(ps.errs, { filename = ps.filename, y = eof.y, x = eof.x, msg = msg or "unexpected end of file" }) - return #ps.tokens - end - table.insert(ps.errs, { filename = ps.filename, y = ps.tokens[i].y, x = ps.tokens[i].x, msg = assert(msg, "syntax error, but no error message provided") }) - return math.min(#ps.tokens, i + 1) -end -local function end_at(node, tk) - node.yend = tk.y - node.xend = tk.x + #tk.tk - 1 -end -local function verify_tk(ps, i, tk) - if ps.tokens[i].tk == tk then - return i + 1 - end - return fail(ps, i, "syntax error, expected '" .. tk .. "'") -end -local function verify_end(ps, i, istart, node) - if ps.tokens[i].tk == "end" then - local endy, endx = ps.tokens[i].y, ps.tokens[i].x - node.yend = endy - node.xend = endx + 2 - if node.kind ~= "function" and endy ~= node.y and endx ~= node.x then - if not ps.end_alignment_hint then - ps.end_alignment_hint = { filename = ps.filename, y = node.y, x = node.x, msg = "syntax error hint: construct starting here is not aligned with its 'end' at " .. ps.filename .. ":" .. endy .. ":" .. endx .. ":" } - end - end - return i + 1 - end - end_at(node, ps.tokens[i]) - if ps.end_alignment_hint then - table.insert(ps.errs, ps.end_alignment_hint) - ps.end_alignment_hint = nil - end - return fail(ps, i, "syntax error, expected 'end' to close construct started at " .. ps.filename .. ":" .. ps.tokens[istart].y .. ":" .. ps.tokens[istart].x .. ":") -end -local function new_node(tokens, i, kind) - local t = tokens[i] - return { y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind) } -end -local function a_type(typename, t) + +local function a_type(w, typename, t) t.typeid = new_typeid() + t.f = w.f + t.x = w.x + t.y = w.y t.typename = typename return t end -local function edit_type(t, typename) +local function edit_type(w, t, typename) t.typeid = new_typeid() + t.f = w.f + t.x = w.x + t.y = w.y t.typename = typename return t end -local function new_type(ps, i, typename) - local token = ps.tokens[i] - return a_type(typename, { - filename = ps.filename, - y = token.y, - x = token.x, - - }) -end -local function new_typedecl(ps, i, def) - local t = new_type(ps, i, "typedecl") - t.def = def - return t -end @@ -2151,20 +2103,28 @@ end +local function a_function(w, t) + assert(t.min_arity) + return a_type(w, "function", t) +end +local function a_vararg(w, t) + local typ = a_type(w, "tuple", { tuple = t }) + typ.is_va = true + return typ +end -local function a_function(t) - assert(t.min_arity) - return a_type("function", t) -end +local function a_nominal(n, names) + return a_type(n, "nominal", { names = names }) +end @@ -2174,16 +2134,63 @@ end +local an_operator +local function shallow_copy_new_type(t) + local copy = {} + for k, v in pairs(t) do + copy[k] = v + end + copy.typeid = new_typeid() + return copy +end +local function shallow_copy_table(t) + local copy = {} + for k, v in pairs(t) do + copy[k] = v + end + return copy +end -local function va_args(args) - args.is_va = true - return args +local function clear_redundant_errors(errors) + local redundant = {} + local lastx, lasty = 0, 0 + for i, err in ipairs(errors) do + err.i = i + end + table.sort(errors, function(a, b) + local af = assert(a.filename) + local bf = assert(b.filename) + return af < bf or + (af == bf and (a.y < b.y or + (a.y == b.y and (a.x < b.x or + (a.x == b.x and (a.i < b.i)))))) + end) + for i, err in ipairs(errors) do + err.i = nil + if err.x == lastx and err.y == lasty then + table.insert(redundant, i) + end + lastx, lasty = err.x, err.y + end + for i = #redundant, 1, -1 do + table.remove(errors, redundant[i]) + end end +local simple_types = { + ["nil"] = true, + ["any"] = true, + ["number"] = true, + ["string"] = true, + ["thread"] = true, + ["boolean"] = true, + ["integer"] = true, +} +do @@ -2191,194 +2198,232 @@ end -local function a_fn(f) - local args_t = a_type("tuple", { tuple = {} }) - local tup = args_t.tuple - args_t.is_va = f.args.is_va - local min_arity = f.args.is_va and -1 or 0 - for _, a in ipairs(f.args) do - if a.opttype then - table.insert(tup, a.opttype) - else - table.insert(tup, a) - min_arity = min_arity + 1 - end - end - local rets_t = a_type("tuple", { tuple = {} }) - tup = rets_t.tuple - rets_t.is_va = f.rets.is_va - for _, a in ipairs(f.rets) do - assert(a.typename) - table.insert(tup, a) - end - return a_type("function", { - args = args_t, - rets = rets_t, - min_arity = min_arity, - needs_compat = f.needs_compat, - typeargs = f.typeargs, - }) -end -local function a_vararg(t) - local typ = a_type("tuple", { tuple = t }) - typ.is_va = true - return typ -end + local parse_type_list + local parse_expression + local parse_expression_and_tk + local parse_statements + local parse_argument_list + local parse_argument_type_list + local parse_type + local parse_newtype + local parse_interface_name + local parse_enum_body + local parse_record_body + local parse_type_body_fns -local NIL = a_type("nil", {}) -local ANY = a_type("any", {}) -local TABLE = a_type("map", { keys = ANY, values = ANY }) -local NUMBER = a_type("number", {}) -local STRING = a_type("string", {}) -local THREAD = a_type("thread", {}) -local BOOLEAN = a_type("boolean", {}) -local INTEGER = a_type("integer", {}) + local function fail(ps, i, msg) + if not ps.tokens[i] then + local eof = ps.tokens[#ps.tokens] + table.insert(ps.errs, { filename = ps.filename, y = eof.y, x = eof.x, msg = msg or "unexpected end of file" }) + return #ps.tokens + end + table.insert(ps.errs, { filename = ps.filename, y = ps.tokens[i].y, x = ps.tokens[i].x, msg = assert(msg, "syntax error, but no error message provided") }) + return math.min(#ps.tokens, i + 1) + end -local function shallow_copy_new_type(t) - local copy = {} - for k, v in pairs(t) do - copy[k] = v + local function end_at(node, tk) + node.yend = tk.y + node.xend = tk.x + #tk.tk - 1 end - copy.typeid = new_typeid() - return copy -end -local function shallow_copy_table(t) - local copy = {} - for k, v in pairs(t) do - copy[k] = v + local function verify_tk(ps, i, tk) + if ps.tokens[i].tk == tk then + return i + 1 + end + return fail(ps, i, "syntax error, expected '" .. tk .. "'") + end + + local function verify_end(ps, i, istart, node) + if ps.tokens[i].tk == "end" then + local endy, endx = ps.tokens[i].y, ps.tokens[i].x + node.yend = endy + node.xend = endx + 2 + if node.kind ~= "function" and endy ~= node.y and endx ~= node.x then + if not ps.end_alignment_hint then + ps.end_alignment_hint = { filename = ps.filename, y = node.y, x = node.x, msg = "syntax error hint: construct starting here is not aligned with its 'end' at " .. ps.filename .. ":" .. endy .. ":" .. endx .. ":" } + end + end + return i + 1 + end + end_at(node, ps.tokens[i]) + if ps.end_alignment_hint then + table.insert(ps.errs, ps.end_alignment_hint) + ps.end_alignment_hint = nil + end + return fail(ps, i, "syntax error, expected 'end' to close construct started at " .. ps.filename .. ":" .. ps.tokens[istart].y .. ":" .. ps.tokens[istart].x .. ":") end - return copy -end -local function verify_kind(ps, i, kind, node_kind) - if ps.tokens[i].kind == kind then - return i + 1, new_node(ps.tokens, i, node_kind) + local function new_node(ps, i, kind) + local t = ps.tokens[i] + return { f = ps.filename, y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind) } end - return fail(ps, i, "syntax error, expected " .. kind) -end + local function new_type(ps, i, typename) + local token = ps.tokens[i] + local t = {} + t.typeid = new_typeid() + t.f = ps.filename + t.x = token.x + t.y = token.y + t.typename = typename + return t + end + local function new_typedecl(ps, i, def) + local t = new_type(ps, i, "typedecl") + t.def = def + return t + end -local function skip(ps, i, skip_fn) - local err_ps = { - filename = ps.filename, - tokens = ps.tokens, - errs = {}, - required_modules = {}, - } - return skip_fn(err_ps, i) -end + local function new_tuple(ps, i, types, is_va) + local t = new_type(ps, i, "tuple") + t.is_va = is_va + t.tuple = types or {} + return t, t.tuple + end -local function failskip(ps, i, msg, skip_fn, starti) - local skip_i = skip(ps, starti or i, skip_fn) - fail(ps, i, msg) - return skip_i -end + local function new_typealias(ps, i, alias_to) + local t = new_type(ps, i, "typealias") + t.alias_to = alias_to + return t + end -local function skip_type_body(ps, i) - local tn = ps.tokens[i].tk - i = i + 1 - assert(parse_type_body_fns[tn], tn .. " has no parse body function") - return parse_type_body_fns[tn](ps, i, {}, { kind = "function" }) -end + local function new_nominal(ps, i, name) + local t = new_type(ps, i, "nominal") + if name then + t.names = { name } + end + return t + end -local function parse_table_value(ps, i) - local next_word = ps.tokens[i].tk - if next_word == "record" or next_word == "interface" then - local skip_i, e = skip(ps, i, skip_type_body) - if e then - fail(ps, i, next_word == "record" and - "syntax error: this syntax is no longer valid; declare nested record inside a record" or - "syntax error: cannot declare interface inside a table; use a statement") - return skip_i, new_node(ps.tokens, i, "error_node") + local function verify_kind(ps, i, kind, node_kind) + if ps.tokens[i].kind == kind then + return i + 1, new_node(ps, i, node_kind) end - elseif next_word == "enum" and ps.tokens[i + 1].kind == "string" then - i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested enum inside a record", skip_type_body) - return i, new_node(ps.tokens, i - 1, "error_node") + return fail(ps, i, "syntax error, expected " .. kind) end - local e - i, e = parse_expression(ps, i) - if not e then - e = new_node(ps.tokens, i - 1, "error_node") + + + local function skip(ps, i, skip_fn) + local err_ps = { + filename = ps.filename, + tokens = ps.tokens, + errs = {}, + required_modules = {}, + } + return skip_fn(err_ps, i) end - return i, e -end -local function parse_table_item(ps, i, n) - local node = new_node(ps.tokens, i, "literal_table_item") - if ps.tokens[i].kind == "$EOF$" then - return fail(ps, i, "unexpected eof") + local function failskip(ps, i, msg, skip_fn, starti) + local skip_i = skip(ps, starti or i, skip_fn) + fail(ps, i, msg) + return skip_i end - if ps.tokens[i].tk == "[" then - node.key_parsed = "long" + local function skip_type_body(ps, i) + local tn = ps.tokens[i].tk i = i + 1 - i, node.key = parse_expression_and_tk(ps, i, "]") - i = verify_tk(ps, i, "=") - i, node.value = parse_table_value(ps, i) - return i, node, n - elseif ps.tokens[i].kind == "identifier" then - if ps.tokens[i + 1].tk == "=" then - node.key_parsed = "short" - i, node.key = verify_kind(ps, i, "identifier", "string") - node.key.conststr = node.key.tk - node.key.tk = '"' .. node.key.tk .. '"' + assert(parse_type_body_fns[tn], tn .. " has no parse body function") + return parse_type_body_fns[tn](ps, i, {}, { kind = "function" }) + end + + local function parse_table_value(ps, i) + local next_word = ps.tokens[i].tk + if next_word == "record" or next_word == "interface" then + local skip_i, e = skip(ps, i, skip_type_body) + if e then + fail(ps, i, next_word == "record" and + "syntax error: this syntax is no longer valid; declare nested record inside a record" or + "syntax error: cannot declare interface inside a table; use a statement") + return skip_i, new_node(ps, i, "error_node") + end + elseif next_word == "enum" and ps.tokens[i + 1].kind == "string" then + i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested enum inside a record", skip_type_body) + return i, new_node(ps, i - 1, "error_node") + end + + local e + i, e = parse_expression(ps, i) + if not e then + e = new_node(ps, i - 1, "error_node") + end + return i, e + end + + local function parse_table_item(ps, i, n) + local node = new_node(ps, i, "literal_table_item") + if ps.tokens[i].kind == "$EOF$" then + return fail(ps, i, "unexpected eof") + end + + if ps.tokens[i].tk == "[" then + node.key_parsed = "long" + i = i + 1 + i, node.key = parse_expression_and_tk(ps, i, "]") i = verify_tk(ps, i, "=") i, node.value = parse_table_value(ps, i) return i, node, n - elseif ps.tokens[i + 1].tk == ":" then - node.key_parsed = "short" - local orig_i = i - local try_ps = { - filename = ps.filename, - tokens = ps.tokens, - errs = {}, - required_modules = ps.required_modules, - } - i, node.key = verify_kind(try_ps, i, "identifier", "string") - node.key.conststr = node.key.tk - node.key.tk = '"' .. node.key.tk .. '"' - i = verify_tk(try_ps, i, ":") - i, node.itemtype = parse_type(try_ps, i) - if node.itemtype and ps.tokens[i].tk == "=" then - i = verify_tk(try_ps, i, "=") - i, node.value = parse_table_value(try_ps, i) - if node.value then - for _, e in ipairs(try_ps.errs) do - table.insert(ps.errs, e) + elseif ps.tokens[i].kind == "identifier" then + if ps.tokens[i + 1].tk == "=" then + node.key_parsed = "short" + i, node.key = verify_kind(ps, i, "identifier", "string") + node.key.conststr = node.key.tk + node.key.tk = '"' .. node.key.tk .. '"' + i = verify_tk(ps, i, "=") + i, node.value = parse_table_value(ps, i) + return i, node, n + elseif ps.tokens[i + 1].tk == ":" then + node.key_parsed = "short" + local orig_i = i + local try_ps = { + filename = ps.filename, + tokens = ps.tokens, + errs = {}, + required_modules = ps.required_modules, + } + i, node.key = verify_kind(try_ps, i, "identifier", "string") + node.key.conststr = node.key.tk + node.key.tk = '"' .. node.key.tk .. '"' + i = verify_tk(try_ps, i, ":") + i, node.itemtype = parse_type(try_ps, i) + if node.itemtype and ps.tokens[i].tk == "=" then + i = verify_tk(try_ps, i, "=") + i, node.value = parse_table_value(try_ps, i) + if node.value then + for _, e in ipairs(try_ps.errs) do + table.insert(ps.errs, e) + end + return i, node, n end - return i, node, n end - end - node.itemtype = nil - i = orig_i + node.itemtype = nil + i = orig_i + end end - end - node.key = new_node(ps.tokens, i, "integer") - node.key_parsed = "implicit" - node.key.constnum = n - node.key.tk = tostring(n) - i, node.value = parse_expression(ps, i) - if not node.value then - return fail(ps, i, "expected an expression") + node.key = new_node(ps, i, "integer") + node.key_parsed = "implicit" + node.key.constnum = n + node.key.tk = tostring(n) + i, node.value = parse_expression(ps, i) + if not node.value then + return fail(ps, i, "expected an expression") + end + return i, node, n + 1 end - return i, node, n + 1 -end @@ -2387,786 +2432,772 @@ end -local function parse_list(ps, i, list, close, sep, parse_item) - local n = 1 - while ps.tokens[i].kind ~= "$EOF$" do - if close[ps.tokens[i].tk] then - end_at(list, ps.tokens[i]) - break - end - local item - local oldn = n - i, item, n = parse_item(ps, i, n) - n = n or oldn - table.insert(list, item) - if ps.tokens[i].tk == "," then - i = i + 1 - if sep == "sep" and close[ps.tokens[i].tk] then - fail(ps, i, "unexpected '" .. ps.tokens[i].tk .. "'") - return i, list - end - elseif sep == "term" and ps.tokens[i].tk == ";" then - i = i + 1 - elseif not close[ps.tokens[i].tk] then - local options = {} - for k, _ in pairs(close) do - table.insert(options, "'" .. k .. "'") - end - table.sort(options) - local first = options[1]:sub(2, -2) - local msg - - if first == ")" and ps.tokens[i].tk == "=" then - msg = "syntax error, cannot perform an assignment here (did you mean '=='?)" - i = failskip(ps, i, msg, parse_expression, i + 1) - else - table.insert(options, "','") - msg = "syntax error, expected one of: " .. table.concat(options, ", ") - fail(ps, i, msg) + local function parse_list(ps, i, list, close, sep, parse_item) + local n = 1 + while ps.tokens[i].kind ~= "$EOF$" do + if close[ps.tokens[i].tk] then + end_at(list, ps.tokens[i]) + break end + local item + local oldn = n + i, item, n = parse_item(ps, i, n) + n = n or oldn + table.insert(list, item) + if ps.tokens[i].tk == "," then + i = i + 1 + if sep == "sep" and close[ps.tokens[i].tk] then + fail(ps, i, "unexpected '" .. ps.tokens[i].tk .. "'") + return i, list + end + elseif sep == "term" and ps.tokens[i].tk == ";" then + i = i + 1 + elseif not close[ps.tokens[i].tk] then + local options = {} + for k, _ in pairs(close) do + table.insert(options, "'" .. k .. "'") + end + table.sort(options) + local first = options[1]:sub(2, -2) + local msg + + if first == ")" and ps.tokens[i].tk == "=" then + msg = "syntax error, cannot perform an assignment here (did you mean '=='?)" + i = failskip(ps, i, msg, parse_expression, i + 1) + else + table.insert(options, "','") + msg = "syntax error, expected one of: " .. table.concat(options, ", ") + fail(ps, i, msg) + end - if first ~= "}" and ps.tokens[i].y ~= ps.tokens[i - 1].y then + if first ~= "}" and ps.tokens[i].y ~= ps.tokens[i - 1].y then - table.insert(ps.tokens, i, { tk = first, y = ps.tokens[i - 1].y, x = ps.tokens[i - 1].x + 1, kind = "keyword" }) - return i, list + table.insert(ps.tokens, i, { tk = first, y = ps.tokens[i - 1].y, x = ps.tokens[i - 1].x + 1, kind = "keyword" }) + return i, list + end end end + return i, list end - return i, list -end -local function parse_bracket_list(ps, i, list, open, close, sep, parse_item) - i = verify_tk(ps, i, open) - i = parse_list(ps, i, list, { [close] = true }, sep, parse_item) - i = verify_tk(ps, i, close) - return i, list -end + local function parse_bracket_list(ps, i, list, open, close, sep, parse_item) + i = verify_tk(ps, i, open) + i = parse_list(ps, i, list, { [close] = true }, sep, parse_item) + i = verify_tk(ps, i, close) + return i, list + end -local function parse_table_literal(ps, i) - local node = new_node(ps.tokens, i, "literal_table") - return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) -end + local function parse_table_literal(ps, i) + local node = new_node(ps, i, "literal_table") + return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) + end -local function parse_trying_list(ps, i, list, parse_item) - local try_ps = { - filename = ps.filename, - tokens = ps.tokens, - errs = {}, - required_modules = ps.required_modules, - } - local tryi, item = parse_item(try_ps, i) - if not item then + local function parse_trying_list(ps, i, list, parse_item) + local try_ps = { + filename = ps.filename, + tokens = ps.tokens, + errs = {}, + required_modules = ps.required_modules, + } + local tryi, item = parse_item(try_ps, i) + if not item then + return i, list + end + for _, e in ipairs(try_ps.errs) do + table.insert(ps.errs, e) + end + i = tryi + table.insert(list, item) + if ps.tokens[i].tk == "," then + while ps.tokens[i].tk == "," do + i = i + 1 + i, item = parse_item(ps, i) + table.insert(list, item) + end + end return i, list end - for _, e in ipairs(try_ps.errs) do - table.insert(ps.errs, e) - end - i = tryi - table.insert(list, item) - if ps.tokens[i].tk == "," then - while ps.tokens[i].tk == "," do + + local function parse_anglebracket_list(ps, i, parse_item) + if ps.tokens[i + 1].tk == ">" then + return fail(ps, i + 1, "type argument list cannot be empty") + end + local types = {} + i = verify_tk(ps, i, "<") + i = parse_list(ps, i, types, { [">"] = true, [">>"] = true }, "sep", parse_item) + if ps.tokens[i].tk == ">" then i = i + 1 - i, item = parse_item(ps, i) - table.insert(list, item) + elseif ps.tokens[i].tk == ">>" then + + ps.tokens[i].tk = ">" + else + return fail(ps, i, "syntax error, expected '>'") end + return i, types end - return i, list -end -local function parse_anglebracket_list(ps, i, parse_item) - if ps.tokens[i + 1].tk == ">" then - return fail(ps, i + 1, "type argument list cannot be empty") + local function parse_typearg(ps, i) + local name = ps.tokens[i].tk + local constraint + i = verify_kind(ps, i, "identifier") + if ps.tokens[i].tk == "is" then + i = i + 1 + i, constraint = parse_interface_name(ps, i) + end + local t = new_type(ps, i, "typearg") + t.typearg = name + t.constraint = constraint + return i, t end - local types = {} - i = verify_tk(ps, i, "<") - i = parse_list(ps, i, types, { [">"] = true, [">>"] = true }, "sep", parse_item) - if ps.tokens[i].tk == ">" then - i = i + 1 - elseif ps.tokens[i].tk == ">>" then - ps.tokens[i].tk = ">" - else - return fail(ps, i, "syntax error, expected '>'") + local function parse_return_types(ps, i) + local iprev = i - 1 + local t + i, t = parse_type_list(ps, i, "rets") + if #t.tuple == 0 then + t.x = ps.tokens[iprev].x + t.y = ps.tokens[iprev].y + end + return i, t end - return i, types -end -local function parse_typearg(ps, i) - local name = ps.tokens[i].tk - local constraint - i = verify_kind(ps, i, "identifier") - if ps.tokens[i].tk == "is" then + local function parse_function_type(ps, i) + local typ = new_type(ps, i, "function") i = i + 1 - i, constraint = parse_interface_name(ps, i) + if ps.tokens[i].tk == "<" then + i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) + end + if ps.tokens[i].tk == "(" then + i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) + i, typ.rets = parse_return_types(ps, i) + else + typ.args = new_tuple(ps, i, { new_type(ps, i, "any") }, true) + typ.rets = new_tuple(ps, i, { new_type(ps, i, "any") }, true) + end + return i, typ end - return i, a_type("typearg", { - y = ps.tokens[i - 2].y, - x = ps.tokens[i - 2].x, - typearg = name, - constraint = constraint, - }) -end - -local function parse_return_types(ps, i) - return parse_type_list(ps, i, "rets") -end -local function parse_function_type(ps, i) - local typ = new_type(ps, i, "function") - i = i + 1 - if ps.tokens[i].tk == "<" then - i, typ.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end - if ps.tokens[i].tk == "(" then - i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) - i, typ.rets = parse_return_types(ps, i) - else - typ.args = a_vararg({ ANY }) - typ.rets = a_vararg({ ANY }) - end - return i, typ -end + local function parse_simple_type_or_nominal(ps, i) + local tk = ps.tokens[i].tk + local st = simple_types[tk] + if st then + return i + 1, new_type(ps, i, tk) + elseif tk == "table" then + local typ = new_type(ps, i, "map") + typ.keys = new_type(ps, i, "any") + typ.values = new_type(ps, i, "any") + return i + 1, typ + end -local simple_types = { - ["nil"] = NIL, - ["any"] = ANY, - ["table"] = TABLE, - ["number"] = NUMBER, - ["string"] = STRING, - ["thread"] = THREAD, - ["boolean"] = BOOLEAN, - ["integer"] = INTEGER, -} + local typ = new_nominal(ps, i, tk) + i = i + 1 + while ps.tokens[i].tk == "." do + i = i + 1 + if ps.tokens[i].kind == "identifier" then + table.insert(typ.names, ps.tokens[i].tk) + i = i + 1 + else + return fail(ps, i, "syntax error, expected identifier") + end + end -local function parse_simple_type_or_nominal(ps, i) - local tk = ps.tokens[i].tk - local st = simple_types[tk] - if st then - return i + 1, st + if ps.tokens[i].tk == "<" then + i, typ.typevals = parse_anglebracket_list(ps, i, parse_type) + end + return i, typ end - local typ = new_type(ps, i, "nominal") - typ.names = { tk } - i = i + 1 - while ps.tokens[i].tk == "." do - i = i + 1 + + local function parse_base_type(ps, i) + local tk = ps.tokens[i].tk if ps.tokens[i].kind == "identifier" then - table.insert(typ.names, ps.tokens[i].tk) + return parse_simple_type_or_nominal(ps, i) + elseif tk == "{" then + local istart = i i = i + 1 - else - return fail(ps, i, "syntax error, expected identifier") + local t + i, t = parse_type(ps, i) + if not t then + return i + end + if ps.tokens[i].tk == "}" then + local decl = new_type(ps, istart, "array") + decl.elements = t + end_at(decl, ps.tokens[i]) + i = verify_tk(ps, i, "}") + return i, decl + elseif ps.tokens[i].tk == "," then + local decl = new_type(ps, istart, "tupletable") + decl.types = { t } + local n = 2 + repeat + i = i + 1 + i, decl.types[n] = parse_type(ps, i) + if not decl.types[n] then + break + end + n = n + 1 + until ps.tokens[i].tk ~= "," + end_at(decl, ps.tokens[i]) + i = verify_tk(ps, i, "}") + return i, decl + elseif ps.tokens[i].tk == ":" then + local decl = new_type(ps, istart, "map") + i = i + 1 + decl.keys = t + i, decl.values = parse_type(ps, i) + if not decl.values then + return i + end + end_at(decl, ps.tokens[i]) + i = verify_tk(ps, i, "}") + return i, decl + end + return fail(ps, i, "syntax error; did you forget a '}'?") + elseif tk == "function" then + return parse_function_type(ps, i) + elseif tk == "nil" then + return i + 1, new_type(ps, i, "nil") end + return fail(ps, i, "expected a type") end - if ps.tokens[i].tk == "<" then - i, typ.typevals = parse_anglebracket_list(ps, i, parse_type) - end - return i, typ -end + parse_type = function(ps, i) + if ps.tokens[i].tk == "(" then + i = i + 1 + local t + i, t = parse_type(ps, i) + i = verify_tk(ps, i, ")") + return i, t + end -local function parse_base_type(ps, i) - local tk = ps.tokens[i].tk - if ps.tokens[i].kind == "identifier" then - return parse_simple_type_or_nominal(ps, i) - elseif tk == "{" then + local bt local istart = i - i = i + 1 - local t - i, t = parse_type(ps, i) - if not t then + i, bt = parse_base_type(ps, i) + if not bt then return i end - if ps.tokens[i].tk == "}" then - local decl = new_type(ps, istart, "array") - decl.elements = t - end_at(decl, ps.tokens[i]) - i = verify_tk(ps, i, "}") - return i, decl - elseif ps.tokens[i].tk == "," then - local decl = new_type(ps, istart, "tupletable") - decl.types = { t } - local n = 2 - repeat + if ps.tokens[i].tk == "|" then + local u = new_type(ps, istart, "union") + u.types = { bt } + while ps.tokens[i].tk == "|" do i = i + 1 - i, decl.types[n] = parse_type(ps, i) - if not decl.types[n] then - break + i, bt = parse_base_type(ps, i) + if not bt then + return i end - n = n + 1 - until ps.tokens[i].tk ~= "," - end_at(decl, ps.tokens[i]) - i = verify_tk(ps, i, "}") - return i, decl - elseif ps.tokens[i].tk == ":" then - local decl = new_type(ps, istart, "map") - i = i + 1 - decl.keys = t - i, decl.values = parse_type(ps, i) - if not decl.values then - return i + table.insert(u.types, bt) end - end_at(decl, ps.tokens[i]) - i = verify_tk(ps, i, "}") - return i, decl - end - return fail(ps, i, "syntax error; did you forget a '}'?") - elseif tk == "function" then - return parse_function_type(ps, i) - elseif tk == "nil" then - return i + 1, simple_types["nil"] - elseif tk == "table" then - local typ = new_type(ps, i, "map") - typ.keys = ANY - typ.values = ANY - return i + 1, typ - end - return fail(ps, i, "expected a type") -end - -parse_type = function(ps, i) - if ps.tokens[i].tk == "(" then - i = i + 1 - local t - i, t = parse_type(ps, i) - i = verify_tk(ps, i, ")") - return i, t + bt = u + end + return i, bt end - local bt - local istart = i - i, bt = parse_base_type(ps, i) - if not bt then - return i - end - if ps.tokens[i].tk == "|" then - local u = new_type(ps, istart, "union") - u.types = { bt } - while ps.tokens[i].tk == "|" do - i = i + 1 - i, bt = parse_base_type(ps, i) - if not bt then - return i + parse_type_list = function(ps, i, mode) + local t, list = new_tuple(ps, i) + + local first_token = ps.tokens[i].tk + if mode == "rets" or mode == "decltuple" then + if first_token == ":" then + i = i + 1 + else + return i, t end - table.insert(u.types, bt) end - bt = u - end - return i, bt -end -local function new_tuple(ps, i) - local t = new_type(ps, i, "tuple") - t.tuple = {} - return t, t.tuple -end + local optional_paren = false + if ps.tokens[i].tk == "(" then + optional_paren = true + i = i + 1 + end -parse_type_list = function(ps, i, mode) - local t, list = new_tuple(ps, i) + local prev_i = i + i = parse_trying_list(ps, i, list, parse_type) + if i == prev_i and ps.tokens[i].tk ~= ")" then + fail(ps, i - 1, "expected a type list") + end - local first_token = ps.tokens[i].tk - if mode == "rets" or mode == "decltuple" then - if first_token == ":" then + if mode == "rets" and ps.tokens[i].tk == "..." then i = i + 1 - else - return i, t + local nrets = #list + if nrets > 0 then + t.is_va = true + else + fail(ps, i, "unexpected '...'") + end end - end - local optional_paren = false - if ps.tokens[i].tk == "(" then - optional_paren = true - i = i + 1 - end + if optional_paren then + i = verify_tk(ps, i, ")") + end - local prev_i = i - i = parse_trying_list(ps, i, list, parse_type) - if i == prev_i and ps.tokens[i].tk ~= ")" then - fail(ps, i - 1, "expected a type list") + return i, t end - if mode == "rets" and ps.tokens[i].tk == "..." then - i = i + 1 - local nrets = #list - if nrets > 0 then - t.is_va = true - else - fail(ps, i, "unexpected '...'") + local function parse_function_args_rets_body(ps, i, node) + local istart = i - 1 + if ps.tokens[i].tk == "<" then + i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) end + i, node.args, node.min_arity = parse_argument_list(ps, i) + i, node.rets = parse_return_types(ps, i) + i, node.body = parse_statements(ps, i) + end_at(node, ps.tokens[i]) + i = verify_end(ps, i, istart, node) + return i, node end - if optional_paren then - i = verify_tk(ps, i, ")") + local function parse_function_value(ps, i) + local node = new_node(ps, i, "function") + i = verify_tk(ps, i, "function") + return parse_function_args_rets_body(ps, i, node) end - return i, t -end - -local function parse_function_args_rets_body(ps, i, node) - local istart = i - 1 - if ps.tokens[i].tk == "<" then - i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end - i, node.args, node.min_arity = parse_argument_list(ps, i) - i, node.rets = parse_return_types(ps, i) - i, node.body = parse_statements(ps, i) - end_at(node, ps.tokens[i]) - i = verify_end(ps, i, istart, node) - return i, node -end - -local function parse_function_value(ps, i) - local node = new_node(ps.tokens, i, "function") - i = verify_tk(ps, i, "function") - return parse_function_args_rets_body(ps, i, node) -end - -local function unquote(str) - local f = str:sub(1, 1) - if f == '"' or f == "'" then - return str:sub(2, -2), false + local function unquote(str) + local f = str:sub(1, 1) + if f == '"' or f == "'" then + return str:sub(2, -2), false + end + f = str:match("^%[=*%[") + local l = #f + 1 + return str:sub(l, -l), true end - f = str:match("^%[=*%[") - local l = #f + 1 - return str:sub(l, -l), true -end - -local function parse_literal(ps, i) - local tk = ps.tokens[i].tk - local kind = ps.tokens[i].kind - if kind == "identifier" then - return verify_kind(ps, i, "identifier", "variable") - elseif kind == "string" then - local node = new_node(ps.tokens, i, "string") - node.conststr, node.is_longstring = unquote(tk) - return i + 1, node - elseif kind == "number" or kind == "integer" then - local n = tonumber(tk) - local node - i, node = verify_kind(ps, i, kind) - node.constnum = n - return i, node - elseif tk == "true" then - return verify_kind(ps, i, "keyword", "boolean") - elseif tk == "false" then - return verify_kind(ps, i, "keyword", "boolean") - elseif tk == "nil" then - return verify_kind(ps, i, "keyword", "nil") - elseif tk == "function" then - return parse_function_value(ps, i) - elseif tk == "{" then - return parse_table_literal(ps, i) - elseif kind == "..." then - return verify_kind(ps, i, "...") - elseif kind == "$ERR invalid_string$" then - return fail(ps, i, "malformed string") - elseif kind == "$ERR invalid_number$" then - return fail(ps, i, "malformed number") - end - return fail(ps, i, "syntax error") -end -local function node_is_require_call(n) - if n.e1 and n.e2 and - n.e1.kind == "variable" and n.e1.tk == "require" and - n.e2.kind == "expression_list" and #n.e2 == 1 and - n.e2[1].kind == "string" then - - return n.e2[1].conststr - elseif n.op and n.op.op == "@funcall" and - n.e1 and n.e1.tk == "pcall" and - n.e2 and #n.e2 == 2 and - n.e2[1].kind == "variable" and n.e2[1].tk == "require" and - n.e2[2].kind == "string" and n.e2[2].conststr then - - return n.e2[2].conststr - else - return nil + local function parse_literal(ps, i) + local tk = ps.tokens[i].tk + local kind = ps.tokens[i].kind + if kind == "identifier" then + return verify_kind(ps, i, "identifier", "variable") + elseif kind == "string" then + local node = new_node(ps, i, "string") + node.conststr, node.is_longstring = unquote(tk) + return i + 1, node + elseif kind == "number" or kind == "integer" then + local n = tonumber(tk) + local node + i, node = verify_kind(ps, i, kind) + node.constnum = n + return i, node + elseif tk == "true" then + return verify_kind(ps, i, "keyword", "boolean") + elseif tk == "false" then + return verify_kind(ps, i, "keyword", "boolean") + elseif tk == "nil" then + return verify_kind(ps, i, "keyword", "nil") + elseif tk == "function" then + return parse_function_value(ps, i) + elseif tk == "{" then + return parse_table_literal(ps, i) + elseif kind == "..." then + return verify_kind(ps, i, "...") + elseif kind == "$ERR invalid_string$" then + return fail(ps, i, "malformed string") + elseif kind == "$ERR invalid_number$" then + return fail(ps, i, "malformed number") + end + return fail(ps, i, "syntax error") + end + + local function node_is_require_call(n) + if n.e1 and n.e2 and + n.e1.kind == "variable" and n.e1.tk == "require" and + n.e2.kind == "expression_list" and #n.e2 == 1 and + n.e2[1].kind == "string" then + + return n.e2[1].conststr + elseif n.op and n.op.op == "@funcall" and + n.e1 and n.e1.tk == "pcall" and + n.e2 and #n.e2 == 2 and + n.e2[1].kind == "variable" and n.e2[1].tk == "require" and + n.e2[2].kind == "string" and n.e2[2].conststr then + + return n.e2[2].conststr + else + return nil + end end -end - -local an_operator -do - local precedences = { - [1] = { - ["not"] = 11, - ["#"] = 11, - ["-"] = 11, - ["~"] = 11, - }, - [2] = { - ["or"] = 1, - ["and"] = 2, - ["is"] = 3, - ["<"] = 3, - [">"] = 3, - ["<="] = 3, - [">="] = 3, - ["~="] = 3, - ["=="] = 3, - ["|"] = 4, - ["~"] = 5, - ["&"] = 6, - ["<<"] = 7, - [">>"] = 7, - [".."] = 8, - ["+"] = 9, - ["-"] = 9, - ["*"] = 10, - ["/"] = 10, - ["//"] = 10, - ["%"] = 10, - ["^"] = 12, - ["as"] = 50, - ["@funcall"] = 100, - ["@index"] = 100, - ["."] = 100, - [":"] = 100, - }, - } - - local is_right_assoc = { - ["^"] = true, - [".."] = true, - } + do + local precedences = { + [1] = { + ["not"] = 11, + ["#"] = 11, + ["-"] = 11, + ["~"] = 11, + }, + [2] = { + ["or"] = 1, + ["and"] = 2, + ["is"] = 3, + ["<"] = 3, + [">"] = 3, + ["<="] = 3, + [">="] = 3, + ["~="] = 3, + ["=="] = 3, + ["|"] = 4, + ["~"] = 5, + ["&"] = 6, + ["<<"] = 7, + [">>"] = 7, + [".."] = 8, + ["+"] = 9, + ["-"] = 9, + ["*"] = 10, + ["/"] = 10, + ["//"] = 10, + ["%"] = 10, + ["^"] = 12, + ["as"] = 50, + ["@funcall"] = 100, + ["@index"] = 100, + ["."] = 100, + [":"] = 100, + }, + } - local function new_operator(tk, arity, op) - return { y = tk.y, x = tk.x, arity = arity, op = op, prec = precedences[arity][op] } - end + local is_right_assoc = { + ["^"] = true, + [".."] = true, + } - an_operator = function(node, arity, op) - return { y = node.y, x = node.x, arity = arity, op = op, prec = precedences[arity][op] } - end + local function new_operator(tk, arity, op) + return { y = tk.y, x = tk.x, arity = arity, op = op, prec = precedences[arity][op] } + end - local args_starters = { - ["("] = true, - ["{"] = true, - ["string"] = true, - } + an_operator = function(node, arity, op) + return { y = node.y, x = node.x, arity = arity, op = op, prec = precedences[arity][op] } + end - local E + local args_starters = { + ["("] = true, + ["{"] = true, + ["string"] = true, + } - local function after_valid_prefixexp(ps, prevnode, i) - return ps.tokens[i - 1].kind == ")" or - (prevnode.kind == "op" and - (prevnode.op.op == "@funcall" or - prevnode.op.op == "@index" or - prevnode.op.op == "." or - prevnode.op.op == ":")) or + local E - prevnode.kind == "identifier" or - prevnode.kind == "variable" - end + local function after_valid_prefixexp(ps, prevnode, i) + return ps.tokens[i - 1].kind == ")" or + (prevnode.kind == "op" and + (prevnode.op.op == "@funcall" or + prevnode.op.op == "@index" or + prevnode.op.op == "." or + prevnode.op.op == ":")) or + prevnode.kind == "identifier" or + prevnode.kind == "variable" + end - local function failstore(tkop, e1) - return { y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true } - end - local function P(ps, i) - if ps.tokens[i].kind == "$EOF$" then - return i + local function failstore(ps, tkop, e1) + return { f = ps.filename, y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true } end - local e1 - local t1 = ps.tokens[i] - if precedences[1][t1.tk] ~= nil then - local op = new_operator(t1, 1, t1.tk) - i = i + 1 - local prev_i = i - i, e1 = P(ps, i) - if not e1 then - fail(ps, prev_i, "expected an expression") + + local function P(ps, i) + if ps.tokens[i].kind == "$EOF$" then return i end - e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } - elseif ps.tokens[i].tk == "(" then - i = i + 1 - local prev_i = i - i, e1 = parse_expression_and_tk(ps, i, ")") + local e1 + local t1 = ps.tokens[i] + if precedences[1][t1.tk] ~= nil then + local op = new_operator(t1, 1, t1.tk) + i = i + 1 + local prev_i = i + i, e1 = P(ps, i) + if not e1 then + fail(ps, prev_i, "expected an expression") + return i + end + e1 = { f = ps.filename, y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } + elseif ps.tokens[i].tk == "(" then + i = i + 1 + local prev_i = i + i, e1 = parse_expression_and_tk(ps, i, ")") + if not e1 then + fail(ps, prev_i, "expected an expression") + return i + end + e1 = { f = ps.filename, y = t1.y, x = t1.x, kind = "paren", e1 = e1 } + else + i, e1 = parse_literal(ps, i) + end + if not e1 then - fail(ps, prev_i, "expected an expression") return i end - e1 = { y = t1.y, x = t1.x, kind = "paren", e1 = e1 } - else - i, e1 = parse_literal(ps, i) - end - - if not e1 then - return i - end - while true do - local tkop = ps.tokens[i] - if tkop.kind == "," or tkop.kind == ")" then - break - end - if tkop.tk == "." or tkop.tk == ":" then - local op = new_operator(tkop, 2, tkop.tk) + while true do + local tkop = ps.tokens[i] + if tkop.kind == "," or tkop.kind == ")" then + break + end + if tkop.tk == "." or tkop.tk == ":" then + local op = new_operator(tkop, 2, tkop.tk) - local prev_i = i + local prev_i = i - local key - i = i + 1 - if ps.tokens[i].kind ~= "identifier" then - local skipped = skip(ps, i, parse_type) - if skipped > i + 1 then - fail(ps, i, "syntax error, cannot declare a type here (missing 'local' or 'global'?)") - return skipped, failstore(tkop, e1) + local key + i = i + 1 + if ps.tokens[i].kind ~= "identifier" then + local skipped = skip(ps, i, parse_type) + if skipped > i + 1 then + fail(ps, i, "syntax error, cannot declare a type here (missing 'local' or 'global'?)") + return skipped, failstore(ps, tkop, e1) + end + end + i, key = verify_kind(ps, i, "identifier") + if not key then + return i, failstore(ps, tkop, e1) end - end - i, key = verify_kind(ps, i, "identifier") - if not key then - return i, failstore(tkop, e1) - end - if op.op == ":" then - if not args_starters[ps.tokens[i].kind] then - if ps.tokens[i].tk == "=" then - fail(ps, i, "syntax error, cannot perform an assignment here (missing 'local' or 'global'?)") - else - fail(ps, i, "expected a function call for a method") + if op.op == ":" then + if not args_starters[ps.tokens[i].kind] then + if ps.tokens[i].tk == "=" then + fail(ps, i, "syntax error, cannot perform an assignment here (missing 'local' or 'global'?)") + else + fail(ps, i, "expected a function call for a method") + end + return i, failstore(ps, tkop, e1) end - return i, failstore(tkop, e1) - end - if not after_valid_prefixexp(ps, e1, prev_i) then - fail(ps, prev_i, "cannot call a method on this expression") - return i, failstore(tkop, e1) + if not after_valid_prefixexp(ps, e1, prev_i) then + fail(ps, prev_i, "cannot call a method on this expression") + return i, failstore(ps, tkop, e1) + end end - end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key } - elseif tkop.tk == "(" then - local op = new_operator(tkop, 2, "@funcall") + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key } + elseif tkop.tk == "(" then + local op = new_operator(tkop, 2, "@funcall") - local prev_i = i + local prev_i = i - local args = new_node(ps.tokens, i, "expression_list") - i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression) + local args = new_node(ps, i, "expression_list") + i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression) - if not after_valid_prefixexp(ps, e1, prev_i) then - fail(ps, prev_i, "cannot call this expression") - return i, failstore(tkop, e1) - end - - e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } + if not after_valid_prefixexp(ps, e1, prev_i) then + fail(ps, prev_i, "cannot call this expression") + return i, failstore(ps, tkop, e1) + end - table.insert(ps.required_modules, node_is_require_call(e1)) - elseif tkop.tk == "[" then - local op = new_operator(tkop, 2, "@index") + e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } - local prev_i = i + table.insert(ps.required_modules, node_is_require_call(e1)) + elseif tkop.tk == "[" then + local op = new_operator(tkop, 2, "@index") - local idx - i = i + 1 - i, idx = parse_expression_and_tk(ps, i, "]") + local prev_i = i - if not after_valid_prefixexp(ps, e1, prev_i) then - fail(ps, prev_i, "cannot index this expression") - return i, failstore(tkop, e1) - end + local idx + i = i + 1 + i, idx = parse_expression_and_tk(ps, i, "]") - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx } - elseif tkop.kind == "string" or tkop.kind == "{" then - local op = new_operator(tkop, 2, "@funcall") + if not after_valid_prefixexp(ps, e1, prev_i) then + fail(ps, prev_i, "cannot index this expression") + return i, failstore(ps, tkop, e1) + end - local prev_i = i + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx } + elseif tkop.kind == "string" or tkop.kind == "{" then + local op = new_operator(tkop, 2, "@funcall") - local args = new_node(ps.tokens, i, "expression_list") - local argument - if tkop.kind == "string" then - argument = new_node(ps.tokens, i) - argument.conststr = unquote(tkop.tk) - i = i + 1 - else - i, argument = parse_table_literal(ps, i) - end + local prev_i = i - if not after_valid_prefixexp(ps, e1, prev_i) then + local args = new_node(ps, i, "expression_list") + local argument if tkop.kind == "string" then - fail(ps, prev_i, "cannot use a string here; if you're trying to call the previous expression, wrap it in parentheses") + argument = new_node(ps, i) + argument.conststr = unquote(tkop.tk) + i = i + 1 else - fail(ps, prev_i, "cannot use a table here; if you're trying to call the previous expression, wrap it in parentheses") + i, argument = parse_table_literal(ps, i) + end + + if not after_valid_prefixexp(ps, e1, prev_i) then + if tkop.kind == "string" then + fail(ps, prev_i, "cannot use a string here; if you're trying to call the previous expression, wrap it in parentheses") + else + fail(ps, prev_i, "cannot use a table here; if you're trying to call the previous expression, wrap it in parentheses") + end + return i, failstore(ps, tkop, e1) end - return i, failstore(tkop, e1) - end - table.insert(args, argument) - e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } + table.insert(args, argument) + e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } - table.insert(ps.required_modules, node_is_require_call(e1)) - elseif tkop.tk == "as" or tkop.tk == "is" then - local op = new_operator(tkop, 2, tkop.tk) + table.insert(ps.required_modules, node_is_require_call(e1)) + elseif tkop.tk == "as" or tkop.tk == "is" then + local op = new_operator(tkop, 2, tkop.tk) - i = i + 1 - local cast = new_node(ps.tokens, i, "cast") - if ps.tokens[i].tk == "(" then - i, cast.casttype = parse_type_list(ps, i, "casttype") + i = i + 1 + local cast = new_node(ps, i, "cast") + if ps.tokens[i].tk == "(" then + i, cast.casttype = parse_type_list(ps, i, "casttype") + else + i, cast.casttype = parse_type(ps, i) + end + if not cast.casttype then + return i, failstore(ps, tkop, e1) + end + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } else - i, cast.casttype = parse_type(ps, i) - end - if not cast.casttype then - return i, failstore(tkop, e1) + break end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } - else - break end - end - return i, e1 - end + return i, e1 + end - E = function(ps, i, lhs, min_precedence) - local lookahead = ps.tokens[i].tk - while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do - local t1 = ps.tokens[i] - local op = new_operator(t1, 2, t1.tk) - i = i + 1 - local rhs - i, rhs = P(ps, i) - if not rhs then - fail(ps, i, "expected an expression") - return i - end - lookahead = ps.tokens[i].tk - while precedences[2][lookahead] and ((precedences[2][lookahead] > (precedences[2][op.op])) or - (is_right_assoc[lookahead] and (precedences[2][lookahead] == precedences[2][op.op]))) do - i, rhs = E(ps, i, rhs, precedences[2][lookahead]) + E = function(ps, i, lhs, min_precedence) + local lookahead = ps.tokens[i].tk + while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do + local t1 = ps.tokens[i] + local op = new_operator(t1, 2, t1.tk) + i = i + 1 + local rhs + i, rhs = P(ps, i) if not rhs then fail(ps, i, "expected an expression") return i end lookahead = ps.tokens[i].tk + while precedences[2][lookahead] and ((precedences[2][lookahead] > (precedences[2][op.op])) or + (is_right_assoc[lookahead] and (precedences[2][lookahead] == precedences[2][op.op]))) do + i, rhs = E(ps, i, rhs, precedences[2][lookahead]) + if not rhs then + fail(ps, i, "expected an expression") + return i + end + lookahead = ps.tokens[i].tk + end + lhs = { f = ps.filename, y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs } end - lhs = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs } + return i, lhs end - return i, lhs - end - parse_expression = function(ps, i) - local lhs - local istart = i - i, lhs = P(ps, i) - if lhs then - i, lhs = E(ps, i, lhs, 0) - end - if lhs then - return i, lhs, 0 - end + parse_expression = function(ps, i) + local lhs + local istart = i + i, lhs = P(ps, i) + if lhs then + i, lhs = E(ps, i, lhs, 0) + end + if lhs then + return i, lhs, 0 + end - if i == istart then - i = fail(ps, i, "expected an expression") + if i == istart then + i = fail(ps, i, "expected an expression") + end + return i end - return i end -end -parse_expression_and_tk = function(ps, i, tk) - local e - i, e = parse_expression(ps, i) - if not e then - e = new_node(ps.tokens, i - 1, "error_node") - end - if ps.tokens[i].tk == tk then - i = i + 1 - else - local msg = "syntax error, expected '" .. tk .. "'" - if ps.tokens[i].tk == "=" then - msg = "syntax error, cannot perform an assignment here (did you mean '=='?)" + parse_expression_and_tk = function(ps, i, tk) + local e + i, e = parse_expression(ps, i) + if not e then + e = new_node(ps, i - 1, "error_node") end + if ps.tokens[i].tk == tk then + i = i + 1 + else + local msg = "syntax error, expected '" .. tk .. "'" + if ps.tokens[i].tk == "=" then + msg = "syntax error, cannot perform an assignment here (did you mean '=='?)" + end - for n = 0, 19 do - local t = ps.tokens[i + n] - if t.kind == "$EOF$" then - break - end - if t.tk == tk then - fail(ps, i, msg) - return i + n + 1, e + for n = 0, 19 do + local t = ps.tokens[i + n] + if t.kind == "$EOF$" then + break + end + if t.tk == tk then + fail(ps, i, msg) + return i + n + 1, e + end end + i = fail(ps, i, msg) end - i = fail(ps, i, msg) + return i, e end - return i, e -end -local function parse_variable_name(ps, i) - local node - i, node = verify_kind(ps, i, "identifier") - if not node then - return i - end - if ps.tokens[i].tk == "<" then - i = i + 1 - local annotation - i, annotation = verify_kind(ps, i, "identifier") - if annotation then - if not is_attribute[annotation.tk] then - fail(ps, i, "unknown variable annotation: " .. annotation.tk) + local function parse_variable_name(ps, i) + local node + i, node = verify_kind(ps, i, "identifier") + if not node then + return i + end + if ps.tokens[i].tk == "<" then + i = i + 1 + local annotation + i, annotation = verify_kind(ps, i, "identifier") + if annotation then + if not is_attribute[annotation.tk] then + fail(ps, i, "unknown variable annotation: " .. annotation.tk) + end + node.attribute = annotation.tk + else + fail(ps, i, "expected a variable annotation") end - node.attribute = annotation.tk - else - fail(ps, i, "expected a variable annotation") + i = verify_tk(ps, i, ">") end - i = verify_tk(ps, i, ">") + return i, node end - return i, node -end -local function parse_argument(ps, i) - local node - if ps.tokens[i].tk == "..." then - i, node = verify_kind(ps, i, "...", "argument") - node.opt = true - else - i, node = verify_kind(ps, i, "identifier", "argument") - end - if ps.tokens[i].tk == "..." then - fail(ps, i, "'...' needs to be declared as a typed argument") - end - if ps.tokens[i].tk == "?" then - i = i + 1 - node.opt = true - end - if ps.tokens[i].tk == ":" then - i = i + 1 - local argtype + local function parse_argument(ps, i) + local node + if ps.tokens[i].tk == "..." then + i, node = verify_kind(ps, i, "...", "argument") + node.opt = true + else + i, node = verify_kind(ps, i, "identifier", "argument") + end + if ps.tokens[i].tk == "..." then + fail(ps, i, "'...' needs to be declared as a typed argument") + end + if ps.tokens[i].tk == "?" then + i = i + 1 + node.opt = true + end + if ps.tokens[i].tk == ":" then + i = i + 1 + local argtype - i, argtype = parse_type(ps, i) + i, argtype = parse_type(ps, i) - if node then - node.argtype = argtype + if node then + node.argtype = argtype + end end + return i, node, 0 end - return i, node, 0 -end -parse_argument_list = function(ps, i) - local node = new_node(ps.tokens, i, "argument_list") - i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) - local opts = false - local min_arity = 0 - for a, fnarg in ipairs(node) do - if fnarg.tk == "..." then - if a ~= #node then - fail(ps, i, "'...' can only be last argument") - break + parse_argument_list = function(ps, i) + local node = new_node(ps, i, "argument_list") + i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) + local opts = false + local min_arity = 0 + for a, fnarg in ipairs(node) do + if fnarg.tk == "..." then + if a ~= #node then + fail(ps, i, "'...' can only be last argument") + break + end + elseif fnarg.opt then + opts = true + elseif opts then + return fail(ps, i, "non-optional arguments cannot follow optional arguments") + else + min_arity = min_arity + 1 end - elseif fnarg.opt then - opts = true - elseif opts then - return fail(ps, i, "non-optional arguments cannot follow optional arguments") - else - min_arity = min_arity + 1 end + return i, node, min_arity end - return i, node, min_arity -end @@ -3176,1014 +3207,982 @@ end -local function parse_argument_type(ps, i) - local opt = false - local is_va = false - local is_self = false - local argument_name = nil + local function parse_argument_type(ps, i) + local opt = false + local is_va = false + local is_self = false + local argument_name = nil - if ps.tokens[i].kind == "identifier" then - argument_name = ps.tokens[i].tk - if ps.tokens[i + 1].tk == "?" then + if ps.tokens[i].kind == "identifier" then + argument_name = ps.tokens[i].tk + if ps.tokens[i + 1].tk == "?" then + opt = true + if ps.tokens[i + 2].tk == ":" then + i = i + 3 + end + elseif ps.tokens[i + 1].tk == ":" then + i = i + 2 + end + elseif ps.tokens[i].kind == "?" then opt = true - if ps.tokens[i + 2].tk == ":" then - i = i + 3 + i = i + 1 + elseif ps.tokens[i].tk == "..." then + if ps.tokens[i + 1].tk == ":" then + i = i + 2 + is_va = true + else + return fail(ps, i, "cannot have untyped '...' when declaring the type of an argument") end - elseif ps.tokens[i + 1].tk == ":" then - i = i + 2 end - elseif ps.tokens[i].kind == "?" then - opt = true - i = i + 1 - elseif ps.tokens[i].tk == "..." then - if ps.tokens[i + 1].tk == ":" then - i = i + 2 - is_va = true - else - return fail(ps, i, "cannot have untyped '...' when declaring the type of an argument") - end - end - local typ; i, typ = parse_type(ps, i) - if typ then - if not is_va and ps.tokens[i].tk == "..." then - i = i + 1 - is_va = true - end + local typ; i, typ = parse_type(ps, i) + if typ then + if not is_va and ps.tokens[i].tk == "..." then + i = i + 1 + is_va = true + end - if argument_name == "self" then - is_self = true + if argument_name == "self" then + is_self = true + end end - end - return i, { i = i, type = typ, is_va = is_va, is_self = is_self, opt = opt or is_va }, 0 -end + return i, { i = i, type = typ, is_va = is_va, is_self = is_self, opt = opt or is_va }, 0 + end -parse_argument_type_list = function(ps, i) - local ars = {} - i = parse_bracket_list(ps, i, ars, "(", ")", "sep", parse_argument_type) - local t, list = new_tuple(ps, i) - local n = #ars - local min_arity = 0 - for l, ar in ipairs(ars) do - list[l] = ar.type - if ar.is_va and l < n then - fail(ps, ar.i, "'...' can only be last argument") + parse_argument_type_list = function(ps, i) + local ars = {} + i = parse_bracket_list(ps, i, ars, "(", ")", "sep", parse_argument_type) + local t, list = new_tuple(ps, i) + local n = #ars + local min_arity = 0 + for l, ar in ipairs(ars) do + list[l] = ar.type + if ar.is_va and l < n then + fail(ps, ar.i, "'...' can only be last argument") + end + if not ar.opt then + min_arity = min_arity + 1 + end end - if not ar.opt then - min_arity = min_arity + 1 + if n > 0 and ars[n].is_va then + t.is_va = true end + return i, t, (n > 0 and ars[1].is_self), min_arity end - if n > 0 and ars[n].is_va then - t.is_va = true + + local function parse_identifier(ps, i) + if ps.tokens[i].kind == "identifier" then + return i + 1, new_node(ps, i, "identifier") + end + i = fail(ps, i, "syntax error, expected identifier") + return i, new_node(ps, i, "error_node") end - return i, t, (n > 0 and ars[1].is_self), min_arity -end -local function parse_identifier(ps, i) - if ps.tokens[i].kind == "identifier" then - return i + 1, new_node(ps.tokens, i, "identifier") + local function parse_local_function(ps, i) + i = verify_tk(ps, i, "local") + i = verify_tk(ps, i, "function") + local node = new_node(ps, i - 2, "local_function") + i, node.name = parse_identifier(ps, i) + return parse_function_args_rets_body(ps, i, node) end - i = fail(ps, i, "syntax error, expected identifier") - return i, new_node(ps.tokens, i, "error_node") -end -local function parse_local_function(ps, i) - i = verify_tk(ps, i, "local") - i = verify_tk(ps, i, "function") - local node = new_node(ps.tokens, i - 2, "local_function") - i, node.name = parse_identifier(ps, i) - return parse_function_args_rets_body(ps, i, node) -end + local function parse_function(ps, i, fk) + local orig_i = i + i = verify_tk(ps, i, "function") + local fn = new_node(ps, i - 1, "global_function") + local names = {} + i, names[1] = parse_identifier(ps, i) + while ps.tokens[i].tk == "." do + i = i + 1 + i, names[#names + 1] = parse_identifier(ps, i) + end + if ps.tokens[i].tk == ":" then + i = i + 1 + i, names[#names + 1] = parse_identifier(ps, i) + fn.is_method = true + end -local function parse_function(ps, i, fk) - local orig_i = i - i = verify_tk(ps, i, "function") - local fn = new_node(ps.tokens, i - 1, "global_function") - local names = {} - i, names[1] = parse_identifier(ps, i) - while ps.tokens[i].tk == "." do - i = i + 1 - i, names[#names + 1] = parse_identifier(ps, i) - end - if ps.tokens[i].tk == ":" then - i = i + 1 - i, names[#names + 1] = parse_identifier(ps, i) - fn.is_method = true - end + if #names > 1 then + fn.kind = "record_function" + local owner = names[1] + owner.kind = "type_identifier" + for i2 = 2, #names - 1 do + local dot = an_operator(names[i2], 2, ".") + names[i2].kind = "identifier" + owner = { f = ps.filename, y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } + end + fn.fn_owner = owner + end + fn.name = names[#names] - if #names > 1 then - fn.kind = "record_function" - local owner = names[1] - owner.kind = "type_identifier" - for i2 = 2, #names - 1 do - local dot = an_operator(names[i2], 2, ".") - names[i2].kind = "identifier" - owner = { y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } + local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y + i = parse_function_args_rets_body(ps, i, fn) + if fn.is_method and fn.args then + table.insert(fn.args, 1, { f = ps.filename, x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) + fn.min_arity = fn.min_arity + 1 end - fn.fn_owner = owner - end - fn.name = names[#names] - local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y - i = parse_function_args_rets_body(ps, i, fn) - if fn.is_method then - table.insert(fn.args, 1, { x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) - fn.min_arity = fn.min_arity + 1 - end + if not fn.name then + return orig_i + 1 + end - if not fn.name then - return orig_i + 1 - end + if fn.kind == "record_function" and fk == "global" then + fail(ps, orig_i, "record functions cannot be annotated as 'global'") + elseif fn.kind == "global_function" and fk == "record" then + fn.implicit_global_function = true + end - if fn.kind == "record_function" and fk == "global" then - fail(ps, orig_i, "record functions cannot be annotated as 'global'") - elseif fn.kind == "global_function" and fk == "record" then - fn.implicit_global_function = true + return i, fn end - return i, fn -end - -local function parse_if_block(ps, i, n, node, is_else) - local block = new_node(ps.tokens, i, "if_block") - i = i + 1 - block.if_parent = node - block.if_block_n = n - if not is_else then - i, block.exp = parse_expression_and_tk(ps, i, "then") - if not block.exp then + local function parse_if_block(ps, i, n, node, is_else) + local block = new_node(ps, i, "if_block") + i = i + 1 + block.if_parent = node + block.if_block_n = n + if not is_else then + i, block.exp = parse_expression_and_tk(ps, i, "then") + if not block.exp then + return i + end + end + i, block.body = parse_statements(ps, i) + if not block.body then return i end + end_at(block.body, ps.tokens[i - 1]) + block.yend, block.xend = block.body.yend, block.body.xend + table.insert(node.if_blocks, block) + return i, node end - i, block.body = parse_statements(ps, i) - if not block.body then - return i - end - end_at(block.body, ps.tokens[i - 1]) - block.yend, block.xend = block.body.yend, block.body.xend - table.insert(node.if_blocks, block) - return i, node -end -local function parse_if(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "if") - node.if_blocks = {} - i, node = parse_if_block(ps, i, 1, node) - if not node then - return i - end - local n = 2 - while ps.tokens[i].tk == "elseif" do - i, node = parse_if_block(ps, i, n, node) + local function parse_if(ps, i) + local istart = i + local node = new_node(ps, i, "if") + node.if_blocks = {} + i, node = parse_if_block(ps, i, 1, node) if not node then return i end - n = n + 1 + local n = 2 + while ps.tokens[i].tk == "elseif" do + i, node = parse_if_block(ps, i, n, node) + if not node then + return i + end + n = n + 1 + end + if ps.tokens[i].tk == "else" then + i, node = parse_if_block(ps, i, n, node, true) + if not node then + return i + end + end + i = verify_end(ps, i, istart, node) + return i, node + end + + local function parse_while(ps, i) + local istart = i + local node = new_node(ps, i, "while") + i = verify_tk(ps, i, "while") + i, node.exp = parse_expression_and_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + i = verify_end(ps, i, istart, node) + return i, node end - if ps.tokens[i].tk == "else" then - i, node = parse_if_block(ps, i, n, node, true) - if not node then - return i + + local function parse_fornum(ps, i) + local istart = i + local node = new_node(ps, i, "fornum") + i = i + 1 + i, node.var = parse_identifier(ps, i) + i = verify_tk(ps, i, "=") + i, node.from = parse_expression_and_tk(ps, i, ",") + i, node.to = parse_expression(ps, i) + if ps.tokens[i].tk == "," then + i = i + 1 + i, node.step = parse_expression_and_tk(ps, i, "do") + else + i = verify_tk(ps, i, "do") end + i, node.body = parse_statements(ps, i) + i = verify_end(ps, i, istart, node) + return i, node end - i = verify_end(ps, i, istart, node) - return i, node -end - -local function parse_while(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "while") - i = verify_tk(ps, i, "while") - i, node.exp = parse_expression_and_tk(ps, i, "do") - i, node.body = parse_statements(ps, i) - i = verify_end(ps, i, istart, node) - return i, node -end -local function parse_fornum(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "fornum") - i = i + 1 - i, node.var = parse_identifier(ps, i) - i = verify_tk(ps, i, "=") - i, node.from = parse_expression_and_tk(ps, i, ",") - i, node.to = parse_expression(ps, i) - if ps.tokens[i].tk == "," then + local function parse_forin(ps, i) + local istart = i + local node = new_node(ps, i, "forin") i = i + 1 - i, node.step = parse_expression_and_tk(ps, i, "do") - else + node.vars = new_node(ps, i, "variable_list") + i, node.vars = parse_list(ps, i, node.vars, { ["in"] = true }, "sep", parse_identifier) + i = verify_tk(ps, i, "in") + node.exps = new_node(ps, i, "expression_list") + i = parse_list(ps, i, node.exps, { ["do"] = true }, "sep", parse_expression) + if #node.exps < 1 then + return fail(ps, i, "missing iterator expression in generic for") + elseif #node.exps > 3 then + return fail(ps, i, "too many expressions in generic for") + end i = verify_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + i = verify_end(ps, i, istart, node) + return i, node end - i, node.body = parse_statements(ps, i) - i = verify_end(ps, i, istart, node) - return i, node -end - -local function parse_forin(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "forin") - i = i + 1 - node.vars = new_node(ps.tokens, i, "variable_list") - i, node.vars = parse_list(ps, i, node.vars, { ["in"] = true }, "sep", parse_identifier) - i = verify_tk(ps, i, "in") - node.exps = new_node(ps.tokens, i, "expression_list") - i = parse_list(ps, i, node.exps, { ["do"] = true }, "sep", parse_expression) - if #node.exps < 1 then - return fail(ps, i, "missing iterator expression in generic for") - elseif #node.exps > 3 then - return fail(ps, i, "too many expressions in generic for") - end - i = verify_tk(ps, i, "do") - i, node.body = parse_statements(ps, i) - i = verify_end(ps, i, istart, node) - return i, node -end -local function parse_for(ps, i) - if ps.tokens[i + 1].kind == "identifier" and ps.tokens[i + 2].tk == "=" then - return parse_fornum(ps, i) - else - return parse_forin(ps, i) + local function parse_for(ps, i) + if ps.tokens[i + 1].kind == "identifier" and ps.tokens[i + 2].tk == "=" then + return parse_fornum(ps, i) + else + return parse_forin(ps, i) + end end -end -local function parse_repeat(ps, i) - local node = new_node(ps.tokens, i, "repeat") - i = verify_tk(ps, i, "repeat") - i, node.body = parse_statements(ps, i) - node.body.is_repeat = true - i = verify_tk(ps, i, "until") - i, node.exp = parse_expression(ps, i) - end_at(node, ps.tokens[i - 1]) - return i, node -end + local function parse_repeat(ps, i) + local node = new_node(ps, i, "repeat") + i = verify_tk(ps, i, "repeat") + i, node.body = parse_statements(ps, i) + node.body.is_repeat = true + i = verify_tk(ps, i, "until") + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i - 1]) + return i, node + end -local function parse_do(ps, i) - local istart = i - local node = new_node(ps.tokens, i, "do") - i = verify_tk(ps, i, "do") - i, node.body = parse_statements(ps, i) - i = verify_end(ps, i, istart, node) - return i, node -end + local function parse_do(ps, i) + local istart = i + local node = new_node(ps, i, "do") + i = verify_tk(ps, i, "do") + i, node.body = parse_statements(ps, i) + i = verify_end(ps, i, istart, node) + return i, node + end -local function parse_break(ps, i) - local node = new_node(ps.tokens, i, "break") - i = verify_tk(ps, i, "break") - return i, node -end + local function parse_break(ps, i) + local node = new_node(ps, i, "break") + i = verify_tk(ps, i, "break") + return i, node + end -local function parse_goto(ps, i) - local node = new_node(ps.tokens, i, "goto") - i = verify_tk(ps, i, "goto") - node.label = ps.tokens[i].tk - i = verify_kind(ps, i, "identifier") - return i, node -end + local function parse_goto(ps, i) + local node = new_node(ps, i, "goto") + i = verify_tk(ps, i, "goto") + node.label = ps.tokens[i].tk + i = verify_kind(ps, i, "identifier") + return i, node + end -local function parse_label(ps, i) - local node = new_node(ps.tokens, i, "label") - i = verify_tk(ps, i, "::") - node.label = ps.tokens[i].tk - i = verify_kind(ps, i, "identifier") - i = verify_tk(ps, i, "::") - return i, node -end + local function parse_label(ps, i) + local node = new_node(ps, i, "label") + i = verify_tk(ps, i, "::") + node.label = ps.tokens[i].tk + i = verify_kind(ps, i, "identifier") + i = verify_tk(ps, i, "::") + return i, node + end -local stop_statement_list = { - ["end"] = true, - ["else"] = true, - ["elseif"] = true, - ["until"] = true, -} + local stop_statement_list = { + ["end"] = true, + ["else"] = true, + ["elseif"] = true, + ["until"] = true, + } -local stop_return_list = { - [";"] = true, - ["$EOF$"] = true, -} + local stop_return_list = { + [";"] = true, + ["$EOF$"] = true, + } -for k, v in pairs(stop_statement_list) do - stop_return_list[k] = v -end + for k, v in pairs(stop_statement_list) do + stop_return_list[k] = v + end -local function parse_return(ps, i) - local node = new_node(ps.tokens, i, "return") - i = verify_tk(ps, i, "return") - node.exps = new_node(ps.tokens, i, "expression_list") - i = parse_list(ps, i, node.exps, stop_return_list, "sep", parse_expression) - if ps.tokens[i].kind == ";" then - i = i + 1 + local function parse_return(ps, i) + local node = new_node(ps, i, "return") + i = verify_tk(ps, i, "return") + node.exps = new_node(ps, i, "expression_list") + i = parse_list(ps, i, node.exps, stop_return_list, "sep", parse_expression) + if ps.tokens[i].kind == ";" then + i = i + 1 + end + return i, node end - return i, node -end -local function store_field_in_record(ps, i, field_name, t, fields, field_order) - if not fields[field_name] then - fields[field_name] = t - table.insert(field_order, field_name) - else - local prev_t = fields[field_name] - if t.typename == "function" and prev_t.typename == "function" then - local p = new_type(ps, i, "poly") - p.types = { prev_t, t } - fields[field_name] = p - elseif t.typename == "function" and prev_t.typename == "poly" then - table.insert(prev_t.types, t) + local function store_field_in_record(ps, i, field_name, t, fields, field_order) + if not fields[field_name] then + fields[field_name] = t + table.insert(field_order, field_name) else - fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)") - return false + local prev_t = fields[field_name] + if t.typename == "function" and prev_t.typename == "function" then + local p = new_type(ps, i, "poly") + p.types = { prev_t, t } + fields[field_name] = p + elseif t.typename == "function" and prev_t.typename == "poly" then + table.insert(prev_t.types, t) + else + fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)") + return false + end end + return true end - return true -end -local function parse_nested_type(ps, i, def, typename, parse_body) - i = i + 1 - local iv = i + local function parse_nested_type(ps, i, def, typename, parse_body) + i = i + 1 + local iv = i + + local v + i, v = verify_kind(ps, i, "identifier", "type_identifier") + if not v then + return fail(ps, i, "expected a variable name") + end - local v - i, v = verify_kind(ps, i, "identifier", "type_identifier") - if not v then - return fail(ps, i, "expected a variable name") - end + local nt = new_node(ps, i - 2, "newtype") + local ndef = new_type(ps, i, typename) + local itype = i + local iok = parse_body(ps, i, ndef, nt) + if iok then + i = iok + nt.newtype = new_typedecl(ps, itype, ndef) + end - local nt = new_node(ps.tokens, i - 2, "newtype") - local ndef = new_type(ps, i, typename) - local iok = parse_body(ps, i, ndef, nt) - if iok then - i = iok - nt.newtype = new_typedecl(ps, i, ndef) + store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) + return i end - store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) - return i -end - -parse_enum_body = function(ps, i, def, node) - local istart = i - 1 - def.enumset = {} - while ps.tokens[i].tk ~= "$EOF$" and ps.tokens[i].tk ~= "end" do - local item - i, item = verify_kind(ps, i, "string", "enum_item") - if item then - table.insert(node, item) - def.enumset[unquote(item.tk)] = true + parse_enum_body = function(ps, i, def, node) + local istart = i - 1 + def.enumset = {} + while ps.tokens[i].tk ~= "$EOF$" and ps.tokens[i].tk ~= "end" do + local item + i, item = verify_kind(ps, i, "string", "enum_item") + if item then + table.insert(node, item) + def.enumset[unquote(item.tk)] = true + end end + i = verify_end(ps, i, istart, node) + return i, node end - i = verify_end(ps, i, istart, node) - return i, node -end - -local metamethod_names = { - ["__add"] = true, - ["__sub"] = true, - ["__mul"] = true, - ["__div"] = true, - ["__mod"] = true, - ["__pow"] = true, - ["__unm"] = true, - ["__idiv"] = true, - ["__band"] = true, - ["__bor"] = true, - ["__bxor"] = true, - ["__bnot"] = true, - ["__shl"] = true, - ["__shr"] = true, - ["__concat"] = true, - ["__len"] = true, - ["__eq"] = true, - ["__lt"] = true, - ["__le"] = true, - ["__index"] = true, - ["__newindex"] = true, - ["__call"] = true, - ["__tostring"] = true, - ["__pairs"] = true, - ["__gc"] = true, - ["__close"] = true, - ["__is"] = true, -} - -local function parse_macroexp(ps, istart, iargs) + local metamethod_names = { + ["__add"] = true, + ["__sub"] = true, + ["__mul"] = true, + ["__div"] = true, + ["__mod"] = true, + ["__pow"] = true, + ["__unm"] = true, + ["__idiv"] = true, + ["__band"] = true, + ["__bor"] = true, + ["__bxor"] = true, + ["__bnot"] = true, + ["__shl"] = true, + ["__shr"] = true, + ["__concat"] = true, + ["__len"] = true, + ["__eq"] = true, + ["__lt"] = true, + ["__le"] = true, + ["__index"] = true, + ["__newindex"] = true, + ["__call"] = true, + ["__tostring"] = true, + ["__pairs"] = true, + ["__gc"] = true, + ["__close"] = true, + ["__is"] = true, + } + local function parse_macroexp(ps, istart, iargs) - local node = new_node(ps.tokens, istart, "macroexp") - local i - i, node.args, node.min_arity = parse_argument_list(ps, iargs) - i, node.rets = parse_return_types(ps, i) - i = verify_tk(ps, i, "return") - i, node.exp = parse_expression(ps, i) - end_at(node, ps.tokens[i]) - i = verify_end(ps, i, istart, node) - return i, node -end -local function parse_where_clause(ps, i) - local node = new_node(ps.tokens, i, "macroexp") - - local selftype = new_type(ps, i, "nominal") - selftype.names = { "@self" } - - node.args = new_node(ps.tokens, i, "argument_list") - node.args[1] = new_node(ps.tokens, i, "argument") - node.args[1].tk = "self" - node.args[1].argtype = selftype - node.min_arity = 1 - node.rets = new_tuple(ps, i) - node.rets.tuple[1] = BOOLEAN - i, node.exp = parse_expression(ps, i) - end_at(node, ps.tokens[i - 1]) - return i, node -end -parse_interface_name = function(ps, i) - local istart = i - local typ - i, typ = parse_simple_type_or_nominal(ps, i) - if not (typ.typename == "nominal") then - return fail(ps, istart, "expected an interface") + local node = new_node(ps, istart, "macroexp") + local i + i, node.args, node.min_arity = parse_argument_list(ps, iargs) + i, node.rets = parse_return_types(ps, i) + i = verify_tk(ps, i, "return") + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i]) + i = verify_end(ps, i, istart, node) + return i, node end - return i, typ -end -local function parse_array_interface_type(ps, i, def) - if def.interface_list then - local first = def.interface_list[1] - if first.typename == "array" then - return failskip(ps, i, "duplicated declaration of array element type", parse_type) - end - end - local t - i, t = parse_base_type(ps, i) - if not t then - return i - end - if not (t.typename == "array") then - fail(ps, i, "expected an array declaration") - return i + local function parse_where_clause(ps, i) + local node = new_node(ps, i, "macroexp") + node.args = new_node(ps, i, "argument_list") + node.args[1] = new_node(ps, i, "argument") + node.args[1].tk = "self" + node.args[1].argtype = new_nominal(ps, i, "@self") + node.min_arity = 1 + node.rets = new_tuple(ps, i) + node.rets.tuple[1] = new_type(ps, i, "boolean") + i, node.exp = parse_expression(ps, i) + end_at(node, ps.tokens[i - 1]) + return i, node end - def.elements = t.elements - return i, t -end - -parse_record_body = function(ps, i, def, node) - local istart = i - 1 - def.fields = {} - def.field_order = {} - if ps.tokens[i].tk == "<" then - i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) + parse_interface_name = function(ps, i) + local istart = i + local typ + i, typ = parse_simple_type_or_nominal(ps, i) + if not (typ.typename == "nominal") then + return fail(ps, istart, "expected an interface") + end + return i, typ end - if ps.tokens[i].tk == "{" then - local atype - i, atype = parse_array_interface_type(ps, i, def) - if atype then - def.interface_list = { atype } + local function parse_array_interface_type(ps, i, def) + if def.interface_list then + local first = def.interface_list[1] + if first.typename == "array" then + return failskip(ps, i, "duplicated declaration of array element type", parse_type) + end + end + local t + i, t = parse_base_type(ps, i) + if not t then + return i + end + if not (t.typename == "array") then + fail(ps, i, "expected an array declaration") + return i end + def.elements = t.elements + return i, t end - if ps.tokens[i].tk == "is" then - i = i + 1 + parse_record_body = function(ps, i, def, node) + local istart = i - 1 + def.fields = {} + def.field_order = {} + + if ps.tokens[i].tk == "<" then + i, def.typeargs = parse_anglebracket_list(ps, i, parse_typearg) + end if ps.tokens[i].tk == "{" then local atype i, atype = parse_array_interface_type(ps, i, def) - if ps.tokens[i].tk == "," then - i = i + 1 - i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) - else - def.interface_list = {} - end if atype then - table.insert(def.interface_list, 1, atype) + def.interface_list = { atype } end - else - i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) end - end - if ps.tokens[i].tk == "where" then - local wstart = i - i = i + 1 - local where_macroexp - i, where_macroexp = parse_where_clause(ps, i) - - local typ = new_type(ps, wstart, "function") - typ.is_method = true - typ.min_arity = 1 - typ.args = a_type("tuple", { tuple = { - a_type("nominal", { - y = typ.y, - x = typ.x, - filename = ps.filename, - names = { "@self" }, - }), - } }) - typ.rets = a_type("tuple", { tuple = { BOOLEAN } }) - typ.macroexp = where_macroexp - - def.meta_fields = {} - def.meta_field_order = {} - store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) - end - - while not (ps.tokens[i].kind == "$EOF$" or ps.tokens[i].tk == "end") do - local tn = ps.tokens[i].tk - if ps.tokens[i].tk == "userdata" and ps.tokens[i + 1].tk ~= ":" then - if def.is_userdata then - fail(ps, i, "duplicated 'userdata' declaration") + if ps.tokens[i].tk == "is" then + i = i + 1 + + if ps.tokens[i].tk == "{" then + local atype + i, atype = parse_array_interface_type(ps, i, def) + if ps.tokens[i].tk == "," then + i = i + 1 + i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) + else + def.interface_list = {} + end + if atype then + table.insert(def.interface_list, 1, atype) + end else - def.is_userdata = true + i, def.interface_list = parse_trying_list(ps, i, {}, parse_interface_name) end + end + + if ps.tokens[i].tk == "where" then + local wstart = i i = i + 1 - elseif ps.tokens[i].tk == "{" then - return fail(ps, i, "syntax error: this syntax is no longer valid; declare array interface at the top with 'is {...}'") - elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then - i = i + 1 - local iv = i - local v - i, v = verify_kind(ps, i, "identifier", "type_identifier") - if not v then - return fail(ps, i, "expected a variable name") - end - i = verify_tk(ps, i, "=") - local nt - i, nt = parse_newtype(ps, i) - if not nt or not nt.newtype then - return fail(ps, i, "expected a type definition") - end + local where_macroexp + i, where_macroexp = parse_where_clause(ps, i) + + local typ = new_type(ps, wstart, "function") + typ.is_method = true + typ.min_arity = 1 + typ.args = new_tuple(ps, wstart, { + a_nominal(where_macroexp, { "@self" }), + }) + typ.rets = new_tuple(ps, wstart, { new_type(ps, wstart, "boolean") }) + typ.macroexp = where_macroexp - local ntt = nt.newtype - if ntt.typename == "typealias" then - ntt.is_nested_alias = true - end + def.meta_fields = {} + def.meta_field_order = {} + store_field_in_record(ps, i, "__is", typ, def.meta_fields, def.meta_field_order) + end - store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) - elseif parse_type_body_fns[tn] and ps.tokens[i + 1].tk ~= ":" then - i = parse_nested_type(ps, i, def, tn, parse_type_body_fns[tn]) - else - local is_metamethod = false - if ps.tokens[i].tk == "metamethod" and ps.tokens[i + 1].tk ~= ":" then - is_metamethod = true + while not (ps.tokens[i].kind == "$EOF$" or ps.tokens[i].tk == "end") do + local tn = ps.tokens[i].tk + if ps.tokens[i].tk == "userdata" and ps.tokens[i + 1].tk ~= ":" then + if def.is_userdata then + fail(ps, i, "duplicated 'userdata' declaration") + else + def.is_userdata = true + end i = i + 1 - end + elseif ps.tokens[i].tk == "{" then + return fail(ps, i, "syntax error: this syntax is no longer valid; declare array interface at the top with 'is {...}'") + elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then + i = i + 1 + local iv = i + local v + i, v = verify_kind(ps, i, "identifier", "type_identifier") + if not v then + return fail(ps, i, "expected a variable name") + end + i = verify_tk(ps, i, "=") + local nt + i, nt = parse_newtype(ps, i) + if not nt or not nt.newtype then + return fail(ps, i, "expected a type definition") + end - local v - if ps.tokens[i].tk == "[" then - i, v = parse_literal(ps, i + 1) - if v and not v.conststr then - return fail(ps, i, "expected a string literal") + local ntt = nt.newtype + if ntt.typename == "typealias" then + ntt.is_nested_alias = true end - i = verify_tk(ps, i, "]") + + store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) + elseif parse_type_body_fns[tn] and ps.tokens[i + 1].tk ~= ":" then + i = parse_nested_type(ps, i, def, tn, parse_type_body_fns[tn]) else - i, v = verify_kind(ps, i, "identifier", "variable") - end - local iv = i - if not v then - return fail(ps, i, "expected a variable name") - end + local is_metamethod = false + if ps.tokens[i].tk == "metamethod" and ps.tokens[i + 1].tk ~= ":" then + is_metamethod = true + i = i + 1 + end - if ps.tokens[i].tk == ":" then - i = i + 1 - local t - i, t = parse_type(ps, i) - if not t then - return fail(ps, i, "expected a type") + local v + if ps.tokens[i].tk == "[" then + i, v = parse_literal(ps, i + 1) + if v and not v.conststr then + return fail(ps, i, "expected a string literal") + end + i = verify_tk(ps, i, "]") + else + i, v = verify_kind(ps, i, "identifier", "variable") + end + local iv = i + if not v then + return fail(ps, i, "expected a variable name") end - local field_name = v.conststr or v.tk - local fields = def.fields - local field_order = def.field_order - if is_metamethod then - if not def.meta_fields then - def.meta_fields = {} - def.meta_field_order = {} + if ps.tokens[i].tk == ":" then + i = i + 1 + local t + i, t = parse_type(ps, i) + if not t then + return fail(ps, i, "expected a type") end - fields = def.meta_fields - field_order = def.meta_field_order - if not metamethod_names[field_name] then - fail(ps, i - 1, "not a valid metamethod: " .. field_name) + + local field_name = v.conststr or v.tk + local fields = def.fields + local field_order = def.field_order + if is_metamethod then + if not def.meta_fields then + def.meta_fields = {} + def.meta_field_order = {} + end + fields = def.meta_fields + field_order = def.meta_field_order + if not metamethod_names[field_name] then + fail(ps, i - 1, "not a valid metamethod: " .. field_name) + end end - end - if ps.tokens[i].tk == "=" and ps.tokens[i + 1].tk == "macroexp" then - if not (t.typename == "function") then - fail(ps, i + 1, "macroexp must have a function type") - else - i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) + if ps.tokens[i].tk == "=" and ps.tokens[i + 1].tk == "macroexp" then + if not (t.typename == "function") then + fail(ps, i + 1, "macroexp must have a function type") + else + i, t.macroexp = parse_macroexp(ps, i + 1, i + 2) + end end - end - store_field_in_record(ps, iv, field_name, t, fields, field_order) - elseif ps.tokens[i].tk == "=" then - local next_word = ps.tokens[i + 1].tk - if next_word == "record" or next_word == "enum" then - return fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. next_word .. " " .. v.tk .. "'") - elseif next_word == "functiontype" then - return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = function('...") + store_field_in_record(ps, iv, field_name, t, fields, field_order) + elseif ps.tokens[i].tk == "=" then + local next_word = ps.tokens[i + 1].tk + if next_word == "record" or next_word == "enum" then + return fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. next_word .. " " .. v.tk .. "'") + elseif next_word == "functiontype" then + return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = function('...") + else + return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = '...") + end else - return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = '...") + fail(ps, i, "syntax error: expected ':' for an attribute or '=' for a nested type") end - else - fail(ps, i, "syntax error: expected ':' for an attribute or '=' for a nested type") end end + i = verify_end(ps, i, istart, node) + return i, node end - i = verify_end(ps, i, istart, node) - return i, node -end - -parse_type_body_fns = { - ["interface"] = parse_record_body, - ["record"] = parse_record_body, - ["enum"] = parse_enum_body, -} -parse_newtype = function(ps, i) - local node = new_node(ps.tokens, i, "newtype") - local def - local tn = ps.tokens[i].tk - local itype = i - if parse_type_body_fns[tn] then - def = new_type(ps, i, tn) - i = i + 1 - i = parse_type_body_fns[tn](ps, i, def, node) - if not def then - return fail(ps, i, "expected a type") - end + parse_type_body_fns = { + ["interface"] = parse_record_body, + ["record"] = parse_record_body, + ["enum"] = parse_enum_body, + } - node.newtype = new_typedecl(ps, itype, def) - return i, node - else - i, def = parse_type(ps, i) - if not def then - return fail(ps, i, "expected a type") - end + parse_newtype = function(ps, i) + local node = new_node(ps, i, "newtype") + local def + local tn = ps.tokens[i].tk + local itype = i + if parse_type_body_fns[tn] then + def = new_type(ps, i, tn) + i = i + 1 + i = parse_type_body_fns[tn](ps, i, def, node) + if not def then + return fail(ps, i, "expected a type") + end - if def.typename == "nominal" then - local typealias = new_type(ps, itype, "typealias") - typealias.alias_to = def - node.newtype = typealias - else node.newtype = new_typedecl(ps, itype, def) - end - - return i, node - end -end + return i, node + else + i, def = parse_type(ps, i) + if not def then + return fail(ps, i, "expected a type") + end -local function parse_assignment_expression_list(ps, i, asgn) - asgn.exps = new_node(ps.tokens, i, "expression_list") - repeat - i = i + 1 - local val - i, val = parse_expression(ps, i) - if not val then - if #asgn.exps == 0 then - asgn.exps = nil + if def.typename == "nominal" then + node.newtype = new_typealias(ps, itype, def) + else + node.newtype = new_typedecl(ps, itype, def) end - return i - end - table.insert(asgn.exps, val) - until ps.tokens[i].tk ~= "," - return i, asgn -end -local parse_call_or_assignment -do - local function is_lvalue(node) - node.is_lvalue = node.kind == "variable" or - (node.kind == "op" and - (node.op.op == "@index" or node.op.op == ".")) - return node.is_lvalue + return i, node + end end - local function parse_variable(ps, i) - local node - i, node = parse_expression(ps, i) - if not (node and is_lvalue(node)) then - return fail(ps, i, "expected a variable") - end - return i, node + local function parse_assignment_expression_list(ps, i, asgn) + asgn.exps = new_node(ps, i, "expression_list") + repeat + i = i + 1 + local val + i, val = parse_expression(ps, i) + if not val then + if #asgn.exps == 0 then + asgn.exps = nil + end + return i + end + table.insert(asgn.exps, val) + until ps.tokens[i].tk ~= "," + return i, asgn end - parse_call_or_assignment = function(ps, i) - local exp - local istart = i - i, exp = parse_expression(ps, i) - if not exp then - return i + local parse_call_or_assignment + do + local function is_lvalue(node) + node.is_lvalue = node.kind == "variable" or + (node.kind == "op" and + (node.op.op == "@index" or node.op.op == ".")) + return node.is_lvalue end - if (exp.op and exp.op.op == "@funcall") or exp.failstore then - return i, exp + local function parse_variable(ps, i) + local node + i, node = parse_expression(ps, i) + if not (node and is_lvalue(node)) then + return fail(ps, i, "expected a variable") + end + return i, node end - if not is_lvalue(exp) then - return fail(ps, i, "syntax error") - end + parse_call_or_assignment = function(ps, i) + local exp + local istart = i + i, exp = parse_expression(ps, i) + if not exp then + return i + end - local asgn = new_node(ps.tokens, istart, "assignment") - asgn.vars = new_node(ps.tokens, istart, "variable_list") - asgn.vars[1] = exp - if ps.tokens[i].tk == "," then - i = i + 1 - i = parse_trying_list(ps, i, asgn.vars, parse_variable) - if #asgn.vars < 2 then - return fail(ps, i, "syntax error") + if (exp.op and exp.op.op == "@funcall") or exp.failstore then + return i, exp end - end - if ps.tokens[i].tk ~= "=" then - verify_tk(ps, i, "=") - return i - end + if not is_lvalue(exp) then + return fail(ps, i, "syntax error") + end - i, asgn = parse_assignment_expression_list(ps, i, asgn) - return i, asgn - end -end + local asgn = new_node(ps, istart, "assignment") + asgn.vars = new_node(ps, istart, "variable_list") + asgn.vars[1] = exp + if ps.tokens[i].tk == "," then + i = i + 1 + i = parse_trying_list(ps, i, asgn.vars, parse_variable) + if #asgn.vars < 2 then + return fail(ps, i, "syntax error") + end + end -local function parse_variable_declarations(ps, i, node_name) - local asgn = new_node(ps.tokens, i, node_name) + if ps.tokens[i].tk ~= "=" then + verify_tk(ps, i, "=") + return i + end - asgn.vars = new_node(ps.tokens, i, "variable_list") - i = parse_trying_list(ps, i, asgn.vars, parse_variable_name) - if #asgn.vars == 0 then - return fail(ps, i, "expected a local variable definition") + i, asgn = parse_assignment_expression_list(ps, i, asgn) + return i, asgn + end end - i, asgn.decltuple = parse_type_list(ps, i, "decltuple") + local function parse_variable_declarations(ps, i, node_name) + local asgn = new_node(ps, i, node_name) - if ps.tokens[i].tk == "=" then - - local next_word = ps.tokens[i + 1].tk - local tn = next_word - if parse_type_body_fns[tn] then - local scope = node_name == "local_declaration" and "local" or "global" - return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " " .. next_word .. " " .. asgn.vars[1].tk .. "'", skip_type_body) - elseif next_word == "functiontype" then - local scope = node_name == "local_declaration" and "local" or "global" - return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " type " .. asgn.vars[1].tk .. " = function('...", parse_function_type) + asgn.vars = new_node(ps, i, "variable_list") + i = parse_trying_list(ps, i, asgn.vars, parse_variable_name) + if #asgn.vars == 0 then + return fail(ps, i, "expected a local variable definition") end - i, asgn = parse_assignment_expression_list(ps, i, asgn) - end - return i, asgn -end + i, asgn.decltuple = parse_type_list(ps, i, "decltuple") -local function parse_type_declaration(ps, i, node_name) - i = i + 2 + if ps.tokens[i].tk == "=" then - local asgn = new_node(ps.tokens, i, node_name) - i, asgn.var = parse_variable_name(ps, i) - if not asgn.var then - return fail(ps, i, "expected a type name") - end + local next_word = ps.tokens[i + 1].tk + local tn = next_word + if parse_type_body_fns[tn] then + local scope = node_name == "local_declaration" and "local" or "global" + return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " " .. next_word .. " " .. asgn.vars[1].tk .. "'", skip_type_body) + elseif next_word == "functiontype" then + local scope = node_name == "local_declaration" and "local" or "global" + return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " type " .. asgn.vars[1].tk .. " = function('...", parse_function_type) + end - if node_name == "global_type" and ps.tokens[i].tk ~= "=" then + i, asgn = parse_assignment_expression_list(ps, i, asgn) + end return i, asgn end - i = verify_tk(ps, i, "=") + local function parse_type_declaration(ps, i, node_name) + i = i + 2 - if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then - local istart = i - i, asgn.value = parse_call_or_assignment(ps, i) - if asgn.value and not node_is_require_call(asgn.value) then - fail(ps, istart, "require() for type declarations must have a literal argument") + local asgn = new_node(ps, i, node_name) + i, asgn.var = parse_variable_name(ps, i) + if not asgn.var then + return fail(ps, i, "expected a type name") end - return i, asgn - end - i, asgn.value = parse_newtype(ps, i) - if not asgn.value then - return i - end + if node_name == "global_type" and ps.tokens[i].tk ~= "=" then + return i, asgn + end - local nt = asgn.value.newtype - if nt.typename == "typedecl" then - local def = nt.def - if def.fields or def.typename == "enum" then - if not def.declname then - def.declname = asgn.var.tk + i = verify_tk(ps, i, "=") + + if ps.tokens[i].kind == "identifier" and ps.tokens[i].tk == "require" then + local istart = i + i, asgn.value = parse_call_or_assignment(ps, i) + if asgn.value and not node_is_require_call(asgn.value) then + fail(ps, istart, "require() for type declarations must have a literal argument") end + return i, asgn end - end - - return i, asgn -end -local function parse_type_constructor(ps, i, node_name, type_name, parse_body) - local asgn = new_node(ps.tokens, i, node_name) - local nt = new_node(ps.tokens, i, "newtype") - asgn.value = nt - local itype = i - local def = new_type(ps, i, type_name) + i, asgn.value = parse_newtype(ps, i) + if not asgn.value then + return i + end - i = i + 2 + local nt = asgn.value.newtype + if nt.typename == "typedecl" then + local def = nt.def + if def.fields or def.typename == "enum" then + if not def.declname then + def.declname = asgn.var.tk + end + end + end - i, asgn.var = verify_kind(ps, i, "identifier") - if not asgn.var then - return fail(ps, i, "expected a type name") + return i, asgn end - assert(def.typename == "record" or def.typename == "interface" or def.typename == "enum") - def.declname = asgn.var.tk + local function parse_type_constructor(ps, i, node_name, type_name, parse_body) + local asgn = new_node(ps, i, node_name) + local nt = new_node(ps, i, "newtype") + asgn.value = nt + local itype = i + local def = new_type(ps, i, type_name) - i = parse_body(ps, i, def, nt) + i = i + 2 - nt.newtype = new_typedecl(ps, itype, def) + i, asgn.var = verify_kind(ps, i, "identifier") + if not asgn.var then + return fail(ps, i, "expected a type name") + end - return i, asgn -end + assert(def.typename == "record" or def.typename == "interface" or def.typename == "enum") + def.declname = asgn.var.tk -local function skip_type_declaration(ps, i) - return parse_type_declaration(ps, i - 1, "local_type") -end + i = parse_body(ps, i, def, nt) -local function parse_local_macroexp(ps, i) - local istart = i - i = i + 2 - local node = new_node(ps.tokens, i, "local_macroexp") - i, node.name = parse_identifier(ps, i) - i, node.macrodef = parse_macroexp(ps, istart, i) - end_at(node, ps.tokens[i - 1]) - return i, node -end + nt.newtype = new_typedecl(ps, itype, def) -local function parse_local(ps, i) - local ntk = ps.tokens[i + 1].tk - local tn = ntk - if ntk == "function" then - return parse_local_function(ps, i) - elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then - return parse_type_declaration(ps, i, "local_type") - elseif ntk == "macroexp" and ps.tokens[i + 2].kind == "identifier" then - return parse_local_macroexp(ps, i) - elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then - return parse_type_constructor(ps, i, "local_type", tn, parse_type_body_fns[tn]) - end - return parse_variable_declarations(ps, i + 1, "local_declaration") -end + return i, asgn + end -local function parse_global(ps, i) - local ntk = ps.tokens[i + 1].tk - local tn = ntk - if ntk == "function" then - return parse_function(ps, i + 1, "global") - elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then - return parse_type_declaration(ps, i, "global_type") - elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then - return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn]) - elseif ps.tokens[i + 1].kind == "identifier" then - return parse_variable_declarations(ps, i + 1, "global_declaration") - end - return parse_call_or_assignment(ps, i) -end + local function skip_type_declaration(ps, i) + return parse_type_declaration(ps, i - 1, "local_type") + end -local function parse_record_function(ps, i) - return parse_function(ps, i, "record") -end + local function parse_local_macroexp(ps, i) + local istart = i + i = i + 2 + local node = new_node(ps, i, "local_macroexp") + i, node.name = parse_identifier(ps, i) + i, node.macrodef = parse_macroexp(ps, istart, i) + end_at(node, ps.tokens[i - 1]) + return i, node + end -local parse_statement_fns = { - ["::"] = parse_label, - ["do"] = parse_do, - ["if"] = parse_if, - ["for"] = parse_for, - ["goto"] = parse_goto, - ["local"] = parse_local, - ["while"] = parse_while, - ["break"] = parse_break, - ["global"] = parse_global, - ["repeat"] = parse_repeat, - ["return"] = parse_return, - ["function"] = parse_record_function, -} + local function parse_local(ps, i) + local ntk = ps.tokens[i + 1].tk + local tn = ntk + if ntk == "function" then + return parse_local_function(ps, i) + elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_declaration(ps, i, "local_type") + elseif ntk == "macroexp" and ps.tokens[i + 2].kind == "identifier" then + return parse_local_macroexp(ps, i) + elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then + return parse_type_constructor(ps, i, "local_type", tn, parse_type_body_fns[tn]) + end + return parse_variable_declarations(ps, i + 1, "local_declaration") + end + + local function parse_global(ps, i) + local ntk = ps.tokens[i + 1].tk + local tn = ntk + if ntk == "function" then + return parse_function(ps, i + 1, "global") + elseif ntk == "type" and ps.tokens[i + 2].kind == "identifier" then + return parse_type_declaration(ps, i, "global_type") + elseif parse_type_body_fns[tn] and ps.tokens[i + 2].kind == "identifier" then + return parse_type_constructor(ps, i, "global_type", tn, parse_type_body_fns[tn]) + elseif ps.tokens[i + 1].kind == "identifier" then + return parse_variable_declarations(ps, i + 1, "global_declaration") + end + return parse_call_or_assignment(ps, i) + end + + local function parse_record_function(ps, i) + return parse_function(ps, i, "record") + end + + local parse_statement_fns = { + ["::"] = parse_label, + ["do"] = parse_do, + ["if"] = parse_if, + ["for"] = parse_for, + ["goto"] = parse_goto, + ["local"] = parse_local, + ["while"] = parse_while, + ["break"] = parse_break, + ["global"] = parse_global, + ["repeat"] = parse_repeat, + ["return"] = parse_return, + ["function"] = parse_record_function, + } -local function type_needs_local_or_global(ps, i) - local tk = ps.tokens[i].tk - return failskip(ps, i, ("%s needs to be declared with 'local %s' or 'global %s'"):format(tk, tk, tk), skip_type_body) -end + local function type_needs_local_or_global(ps, i) + local tk = ps.tokens[i].tk + return failskip(ps, i, ("%s needs to be declared with 'local %s' or 'global %s'"):format(tk, tk, tk), skip_type_body) + end -local needs_local_or_global = { - ["type"] = function(ps, i) - return failskip(ps, i, "types need to be declared with 'local type' or 'global type'", skip_type_declaration) - end, - ["record"] = type_needs_local_or_global, - ["enum"] = type_needs_local_or_global, -} + local needs_local_or_global = { + ["type"] = function(ps, i) + return failskip(ps, i, "types need to be declared with 'local type' or 'global type'", skip_type_declaration) + end, + ["record"] = type_needs_local_or_global, + ["enum"] = type_needs_local_or_global, + } -parse_statements = function(ps, i, toplevel) - local node = new_node(ps.tokens, i, "statements") - local item - while true do - while ps.tokens[i].kind == ";" do - i = i + 1 - if item then - item.semicolon = true + parse_statements = function(ps, i, toplevel) + local node = new_node(ps, i, "statements") + local item + while true do + while ps.tokens[i].kind == ";" do + i = i + 1 + if item then + item.semicolon = true + end end - end - if ps.tokens[i].kind == "$EOF$" then - break - end - local tk = ps.tokens[i].tk - if (not toplevel) and stop_statement_list[tk] then - break - end + if ps.tokens[i].kind == "$EOF$" then + break + end + local tk = ps.tokens[i].tk + if (not toplevel) and stop_statement_list[tk] then + break + end - local fn = parse_statement_fns[tk] - if not fn then - local skip_fn = needs_local_or_global[tk] - if skip_fn and ps.tokens[i + 1].kind == "identifier" then - fn = skip_fn - else - fn = parse_call_or_assignment + local fn = parse_statement_fns[tk] + if not fn then + local skip_fn = needs_local_or_global[tk] + if skip_fn and ps.tokens[i + 1].kind == "identifier" then + fn = skip_fn + else + fn = parse_call_or_assignment + end end - end - i, item = fn(ps, i) + i, item = fn(ps, i) - if item then - table.insert(node, item) - elseif i > 1 then + if item then + table.insert(node, item) + elseif i > 1 then - local lasty = ps.tokens[i - 1].y - while ps.tokens[i].kind ~= "$EOF$" and ps.tokens[i].y == lasty do - i = i + 1 + local lasty = ps.tokens[i - 1].y + while ps.tokens[i].kind ~= "$EOF$" and ps.tokens[i].y == lasty do + i = i + 1 + end end end - end - - end_at(node, ps.tokens[i]) - return i, node -end -local function clear_redundant_errors(errors) - local redundant = {} - local lastx, lasty = 0, 0 - for i, err in ipairs(errors) do - err.i = i + end_at(node, ps.tokens[i]) + return i, node end - table.sort(errors, function(a, b) - local af = a.filename or "" - local bf = b.filename or "" - return af < bf or - (af == bf and (a.y < b.y or - (a.y == b.y and (a.x < b.x or - (a.x == b.x and (a.i < b.i)))))) - end) - for i, err in ipairs(errors) do - err.i = nil - if err.x == lastx and err.y == lasty then - table.insert(redundant, i) + + function tl.parse_program(tokens, errs, filename) + errs = errs or {} + local ps = { + tokens = tokens, + errs = errs, + filename = filename or "", + required_modules = {}, + } + local i = 1 + local hashbang + if ps.tokens[i].kind == "hashbang" then + hashbang = ps.tokens[i].tk + i = i + 1 + end + local _, node = parse_statements(ps, i, true) + if hashbang then + node.hashbang = hashbang end - lastx, lasty = err.x, err.y - end - for i = #redundant, 1, -1 do - table.remove(errors, redundant[i]) - end -end -function tl.parse_program(tokens, errs, filename) - errs = errs or {} - local ps = { - tokens = tokens, - errs = errs, - filename = filename or "", - required_modules = {}, - } - local i = 1 - local hashbang - if ps.tokens[i].kind == "hashbang" then - hashbang = ps.tokens[i].tk - i = i + 1 + clear_redundant_errors(errs) + return node, ps.required_modules end - local _, node = parse_statements(ps, i, true) - if hashbang then - node.hashbang = hashbang + + function tl.parse(input, filename) + local tokens, errs = tl.lex(input, filename) + local node, required_modules = tl.parse_program(tokens, errs, filename) + return node, errs, required_modules end - clear_redundant_errors(errs) - return node, ps.required_modules end -function tl.parse(input, filename) - local tokens, errs = tl.lex(input, filename) - local node, required_modules = tl.parse_program(tokens, errs, filename) - return node, errs, required_modules -end + @@ -4296,7 +4295,7 @@ local function tl_debug_indent_pop(mark, single, y, x, fmt, ...) end end -local function recurse_type(ast, visit) +local function recurse_type(s, ast, visit) local kind = ast.typename if TL_DEBUG then @@ -4308,7 +4307,7 @@ local function recurse_type(ast, visit) if cbkind then local cbkind_before = cbkind.before if cbkind_before then - cbkind_before(ast) + cbkind_before(s, ast) end end @@ -4316,90 +4315,90 @@ local function recurse_type(ast, visit) if ast.typename == "tuple" then for i, child in ipairs(ast.tuple) do - xs[i] = recurse_type(child, visit) + xs[i] = recurse_type(s, child, visit) end elseif ast.types then for _, child in ipairs(ast.types) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end elseif ast.typename == "map" then - table.insert(xs, recurse_type(ast.keys, visit)) - table.insert(xs, recurse_type(ast.values, visit)) + table.insert(xs, recurse_type(s, ast.keys, visit)) + table.insert(xs, recurse_type(s, ast.values, visit)) elseif ast.fields then if ast.typeargs then for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.interface_list then for _, child in ipairs(ast.interface_list) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) + table.insert(xs, recurse_type(s, ast.elements, visit)) end if ast.fields then for _, child in fields_of(ast) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.meta_fields then for _, child in fields_of(ast, "meta") do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast.typename == "function" then if ast.typeargs then for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.args then for _, child in ipairs(ast.args.tuple) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.rets then for _, child in ipairs(ast.rets.tuple) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast.typename == "nominal" then if ast.typevals then for _, child in ipairs(ast.typevals) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast.typename == "typearg" then if ast.constraint then - table.insert(xs, recurse_type(ast.constraint, visit)) + table.insert(xs, recurse_type(s, ast.constraint, visit)) end elseif ast.typename == "array" then if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) + table.insert(xs, recurse_type(s, ast.elements, visit)) end elseif ast.typename == "literal_table_item" then if ast.ktype then - table.insert(xs, recurse_type(ast.ktype, visit)) + table.insert(xs, recurse_type(s, ast.ktype, visit)) end if ast.vtype then - table.insert(xs, recurse_type(ast.vtype, visit)) + table.insert(xs, recurse_type(s, ast.vtype, visit)) end elseif ast.typename == "typealias" then - table.insert(xs, recurse_type(ast.alias_to, visit)) + table.insert(xs, recurse_type(s, ast.alias_to, visit)) elseif ast.typename == "typedecl" then - table.insert(xs, recurse_type(ast.def, visit)) + table.insert(xs, recurse_type(s, ast.def, visit)) end local ret local cbkind_after = cbkind and cbkind.after if cbkind_after then - ret = cbkind_after(ast, xs) + ret = cbkind_after(s, ast, xs) end local visit_after = visit.after if visit_after then - ret = visit_after(ast, xs, ret) + ret = visit_after(s, ast, xs, ret) end if TL_DEBUG then @@ -4409,15 +4408,16 @@ local function recurse_type(ast, visit) return ret end -local function recurse_typeargs(ast, visit_type) +local function recurse_typeargs(s, ast, visit_type) if ast.typeargs then for _, typearg in ipairs(ast.typeargs) do - recurse_type(typearg, visit_type) + recurse_type(s, typearg, visit_type) end end end local function extra_callback(name, + s, ast, xs, visit_node) @@ -4427,7 +4427,7 @@ local function extra_callback(name, if not nbs then return end local bs = nbs[name] if not bs then return end - bs(ast, xs) + bs(s, ast, xs) end local no_recurse_node = { @@ -4447,7 +4447,7 @@ local no_recurse_node = { ["type_identifier"] = true, } -local function recurse_node(root, +local function recurse_node(s, root, visit_node, visit_type) if not root then @@ -4466,9 +4466,9 @@ local function recurse_node(root, local function walk_vars_exps(ast, xs) xs[1] = recurse(ast.vars) if ast.decltuple then - xs[2] = recurse_type(ast.decltuple, visit_type) + xs[2] = recurse_type(s, ast.decltuple, visit_type) end - extra_callback("before_exp", ast, xs, visit_node) + extra_callback("before_exp", s, ast, xs, visit_node) if ast.exps then xs[3] = recurse(ast.exps) end @@ -4480,11 +4480,11 @@ local function recurse_node(root, end local function walk_named_function(ast, xs) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.name) xs[2] = recurse(ast.args) - xs[3] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[3] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[4] = recurse(ast.body) end @@ -4497,9 +4497,9 @@ local function recurse_node(root, end xs[2] = p1 if ast.op.arity == 2 then - extra_callback("before_e2", ast, xs, visit_node) + extra_callback("before_e2", s, ast, xs, visit_node) if ast.op.op == "is" or ast.op.op == "as" then - xs[3] = recurse_type(ast.e2.casttype, visit_type) + xs[3] = recurse_type(s, ast.e2.casttype, visit_type) else xs[3] = recurse(ast.e2) end @@ -4517,7 +4517,7 @@ local function recurse_node(root, xs[1] = recurse(ast.key) xs[2] = recurse(ast.value) if ast.itemtype then - xs[3] = recurse_type(ast.itemtype, visit_type) + xs[3] = recurse_type(s, ast.itemtype, visit_type) end end, @@ -4543,13 +4543,13 @@ local function recurse_node(root, if ast.exp then xs[1] = recurse(ast.exp) end - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[2] = recurse(ast.body) end, ["while"] = function(ast, xs) xs[1] = recurse(ast.exp) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[2] = recurse(ast.body) end, @@ -4559,45 +4559,45 @@ local function recurse_node(root, end, ["macroexp"] = function(ast, xs) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.args) - xs[2] = recurse_type(ast.rets, visit_type) - extra_callback("before_exp", ast, xs, visit_node) + xs[2] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_exp", s, ast, xs, visit_node) xs[3] = recurse(ast.exp) end, ["function"] = function(ast, xs) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.args) - xs[2] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[2] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[3] = recurse(ast.body) end, ["local_function"] = walk_named_function, ["global_function"] = walk_named_function, ["record_function"] = function(ast, xs) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.fn_owner) xs[2] = recurse(ast.name) - extra_callback("before_arguments", ast, xs, visit_node) + extra_callback("before_arguments", s, ast, xs, visit_node) xs[3] = recurse(ast.args) - xs[4] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[4] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[5] = recurse(ast.body) end, ["local_macroexp"] = function(ast, xs) xs[1] = recurse(ast.name) xs[2] = recurse(ast.macrodef.args) - xs[3] = recurse_type(ast.macrodef.rets, visit_type) - extra_callback("before_exp", ast, xs, visit_node) + xs[3] = recurse_type(s, ast.macrodef.rets, visit_type) + extra_callback("before_exp", s, ast, xs, visit_node) xs[4] = recurse(ast.macrodef.exp) end, ["forin"] = function(ast, xs) xs[1] = recurse(ast.vars) xs[2] = recurse(ast.exps) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[3] = recurse(ast.body) end, @@ -4606,7 +4606,7 @@ local function recurse_node(root, xs[2] = recurse(ast.from) xs[3] = recurse(ast.to) xs[4] = ast.step and recurse(ast.step) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[5] = recurse(ast.body) end, @@ -4623,12 +4623,12 @@ local function recurse_node(root, end, ["newtype"] = function(ast, xs) - xs[1] = recurse_type(ast.newtype, visit_type) + xs[1] = recurse_type(s, ast.newtype, visit_type) end, ["argument"] = function(ast, xs) if ast.argtype then - xs[1] = recurse_type(ast.argtype, visit_type) + xs[1] = recurse_type(s, ast.argtype, visit_type) end end, } @@ -4647,7 +4647,7 @@ local function recurse_node(root, local cbkind = cbs and cbs[kind] if cbkind then if cbkind.before then - cbkind.before(ast) + cbkind.before(s, ast) end end @@ -4671,10 +4671,10 @@ local function recurse_node(root, local ret local cbkind_after = cbkind and cbkind.after if cbkind_after then - ret = cbkind_after(ast, xs) + ret = cbkind_after(s, ast, xs) end if visit_after then - ret = visit_after(ast, xs, ret) + ret = visit_after(s, ast, xs, ret) end if TL_DEBUG then @@ -4778,7 +4778,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) local save_indent = {} - local function increment_indent(node) + local function increment_indent(_, node) local child = node.body or node[1] if not child then return @@ -4881,7 +4881,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_node.cbs = { ["statements"] = { - after = function(node, children) + after = function(_, node, children) local out if opts.preserve_hashbang and node.hashbang then out = { y = 1, h = 0 } @@ -4903,7 +4903,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["local_declaration"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "local ") for i, var in ipairs(node.vars) do @@ -4929,7 +4929,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["local_type"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if not node.var.elide_type then table.insert(out, "local") @@ -4941,7 +4941,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["global_type"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if children[2] then add_child(out, children[1]) @@ -4952,7 +4952,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["global_declaration"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if children[3] then add_child(out, children[1]) @@ -4963,7 +4963,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["assignment"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } add_child(out, children[1]) table.insert(out, " =") @@ -4972,7 +4972,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["if"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } for i, child in ipairs(children) do add_child(out, child, i > 1 and " ", child.y ~= node.y and indent) @@ -4983,7 +4983,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["if_block"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if node.if_block_n == 1 then table.insert(out, "if") @@ -5003,7 +5003,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["while"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "while") add_child(out, children[1], " ") @@ -5016,7 +5016,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["repeat"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "repeat") add_child(out, children[1], " ") @@ -5028,7 +5028,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["do"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "do") add_child(out, children[1], " ") @@ -5039,7 +5039,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["forin"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") @@ -5054,7 +5054,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["fornum"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") @@ -5074,7 +5074,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["return"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "return") if #children[1] > 0 then @@ -5084,14 +5084,14 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["break"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } table.insert(out, "break") return out end, }, ["variable_list"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } local space for i, child in ipairs(children) do @@ -5106,7 +5106,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["literal_table"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if #children == 0 then table.insert(out, "{}") @@ -5126,7 +5126,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["literal_table_item"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if node.key_parsed ~= "implicit" then if node.key_parsed == "short" then @@ -5149,13 +5149,13 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["local_macroexp"] = { before = increment_indent, - after = function(node, _children) + after = function(_, node, _children) return { y = node.y, h = 0 } end, }, ["local_function"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "local function") add_child(out, children[1], " ") @@ -5170,7 +5170,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["global_function"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") @@ -5185,7 +5185,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["record_function"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") @@ -5210,7 +5210,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) }, ["function"] = { before = increment_indent, - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "function(") add_child(out, children[1]) @@ -5224,7 +5224,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) ["cast"] = {}, ["paren"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } table.insert(out, "(") add_child(out, children[1], "", indent) @@ -5233,7 +5233,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["op"] = { - after = function(node, children) + after = function(_, node, children) local out = { y = node.y, h = 0 } if node.op.op == "@funcall" then add_child(out, children[1], "", indent) @@ -5294,14 +5294,14 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["variable"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } add_string(out, node.tk) return out end, }, ["newtype"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } local nt = node.newtype if nt.typename == "typealias" then @@ -5318,7 +5318,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["goto"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } table.insert(out, "goto ") table.insert(out, node.label) @@ -5326,7 +5326,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) end, }, ["label"] = { - after = function(node, _children) + after = function(_, node, _children) local out = { y = node.y, h = 0 } table.insert(out, "::") table.insert(out, node.label) @@ -5339,7 +5339,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) local visit_type = {} visit_type.cbs = {} local default_type_visitor = { - after = function(typ, _children) + after = function(_, typ, _children) local out = { y = typ.y or -1, h = 0 } local r = typ.typename == "nominal" and typ.resolved or typ local lua_type = primitive[r.typename] or "table" @@ -5377,7 +5377,6 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_type.cbs["any"] = default_type_visitor visit_type.cbs["unknown"] = default_type_visitor visit_type.cbs["invalid"] = default_type_visitor - visit_type.cbs["unresolved"] = default_type_visitor visit_type.cbs["none"] = default_type_visitor visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] @@ -5392,7 +5391,7 @@ function tl.pretty_print_ast(ast, gen_target, mode) visit_node.cbs["argument"] = visit_node.cbs["variable"] visit_node.cbs["type_identifier"] = visit_node.cbs["variable"] - local out = recurse_node(ast, visit_node, visit_type) + local out = recurse_node(nil, ast, visit_node, visit_type) if err then return nil, err end @@ -5442,7 +5441,6 @@ local typename_to_typecode = { ["none"] = tl.typecodes.UNKNOWN, ["tuple"] = tl.typecodes.UNKNOWN, ["literal_table_item"] = tl.typecodes.UNKNOWN, - ["unresolved"] = tl.typecodes.UNKNOWN, ["typedecl"] = tl.typecodes.UNKNOWN, ["typealias"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, @@ -5450,8 +5448,8 @@ local typename_to_typecode = { local skip_types = { ["none"] = true, + ["tuple"] = true, ["literal_table_item"] = true, - ["unresolved"] = true, } local function sorted_keys(m) @@ -5474,6 +5472,7 @@ function tl.new_type_reporter() local self = { next_num = 1, typeid_to_num = {}, + typename_to_num = {}, tr = { by_pos = {}, types = {}, @@ -5481,6 +5480,24 @@ function tl.new_type_reporter() globals = {}, }, } + + local names = {} + for name, _ in pairs(simple_types) do + table.insert(names, name) + end + table.sort(names) + + for _, name in ipairs(names) do + local ti = { + t = assert(typename_to_typecode[name]), + str = name, + } + local n = self.next_num + self.typename_to_num[name] = n + self.tr.types[n] = ti + self.next_num = self.next_num + 1 + end + return setmetatable(self, { __index = TypeReporter }) end @@ -5500,9 +5517,15 @@ function TypeReporter:store_function(ti, rt) end function TypeReporter:get_typenum(t) + + local n = self.typename_to_num[t.typename] + if n then + return n + end + assert(t.typeid) - local n = self.typeid_to_num[t.typeid] + n = self.typeid_to_num[t.typeid] if n then return n end @@ -5526,7 +5549,7 @@ function TypeReporter:get_typenum(t) local ti = { t = assert(typename_to_typecode[rt.typename]), str = show_type(t, true), - file = t.filename, + file = t.f, y = t.y, x = t.x, } @@ -5596,7 +5619,7 @@ end function TypeReporter:get_collector(filename) - local tc = { + local collector = { filename = filename, symbol_list = {}, } @@ -5604,10 +5627,10 @@ function TypeReporter:get_collector(filename) local ft = {} self.tr.by_pos[filename] = ft - local symbol_list = tc.symbol_list + local symbol_list = collector.symbol_list local symbol_list_n = 0 - tc.store_type = function(y, x, typ) + collector.store_type = function(y, x, typ) if not typ or skip_types[typ.typename] then return end @@ -5621,12 +5644,12 @@ function TypeReporter:get_collector(filename) yt[x] = self:get_typenum(typ) end - tc.reserve_symbol_list_slot = function(node) + collector.reserve_symbol_list_slot = function(node) symbol_list_n = symbol_list_n + 1 node.symbol_list_slot = symbol_list_n end - tc.add_to_symbol_list = function(node, name, t) + collector.add_to_symbol_list = function(node, name, t) if not node then return end @@ -5640,12 +5663,12 @@ function TypeReporter:get_collector(filename) symbol_list[slot] = { y = node.y, x = node.x, name = name, typ = t } end - tc.begin_symbol_list_scope = function(node) + collector.begin_symbol_list_scope = function(node) symbol_list_n = symbol_list_n + 1 symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" } end - tc.end_symbol_list_scope = function(node) + collector.end_symbol_list_scope = function(node) if symbol_list[symbol_list_n].name == "@{" then symbol_list[symbol_list_n] = nil symbol_list_n = symbol_list_n - 1 @@ -5655,14 +5678,14 @@ function TypeReporter:get_collector(filename) end end - return tc + return collector end -function TypeReporter:store_result(tc, globals) +function TypeReporter:store_result(collector, globals) local tr = self.tr - local filename = tc.filename - local symbol_list = tc.symbol_list + local filename = collector.filename + local symbol_list = collector.symbol_list tr.by_pos[filename][0] = nil @@ -5731,143 +5754,445 @@ function TypeReporter:get_report() end -function tl.get_types(result) - return result.env.reporter:get_report(), result.env.reporter + + + + +function tl.symbols_in_scope(tr, y, x) + local function find(symbols, at_y, at_x) + local function le(a, b) + return a[1] < b[1] or + (a[1] == b[1] and a[2] <= b[2]) + end + return binary_search(symbols, { at_y, at_x }, le) or 0 + end + + local ret = {} + + local n = find(tr.symbols, y, x) + + local symbols = tr.symbols + while n >= 1 do + local s = symbols[n] + if s[3] == "@{" then + n = n - 1 + elseif s[3] == "@}" then + n = s[4] + else + ret[s[3]] = s[4] + n = n - 1 + end + end + + return ret +end + + + + + +function Errors.new(filename) + local self = { + errors = {}, + warnings = {}, + unknown_dots = {}, + filename = filename, + } + return setmetatable(self, { __index = Errors }) +end + +local function Err(msg, t1, t2, t3) + if t1 then + local s1, s2, s3 + if t1.typename == "invalid" then + return nil + end + s1 = show_type(t1) + if t2 then + if t2.typename == "invalid" then + return nil + end + s2 = show_type(t2) + end + if t3 then + if t3.typename == "invalid" then + return nil + end + s3 = show_type(t3) + end + msg = msg:format(s1, s2, s3) + return { + msg = msg, + x = t1.x, + y = t1.y, + filename = t1.f, + } + end + + return { + msg = msg, + } +end + +local function insert_error(self, y, x, err) + err.y = assert(y) + err.x = assert(x) + err.filename = self.filename + + if TL_DEBUG then + io.stderr:write("ERROR:" .. err.y .. ":" .. err.x .. ": " .. err.msg .. "\n") + end + + table.insert(self.errors, err) +end + +function Errors:add(w, msg, ...) + local e = Err(msg, ...) + if e then + insert_error(self, w.y, w.x, e) + end +end + +local context_name = { + ["local_declaration"] = "in local declaration", + ["global_declaration"] = "in global declaration", + ["assignment"] = "in assignment", + ["literal_table_item"] = "in table item", +} + +function Errors:get_context(ctx, name) + if not ctx then + return "" + end + local ec = (ctx.kind ~= nil) and ctx.expected_context + local cn = (type(ctx) == "string") and ctx or + (ctx.kind ~= nil) and context_name[ec and ec.kind or ctx.kind] + return (cn and cn .. ": " or "") .. (ec and ec.name and ec.name .. ": " or "") .. (name and name .. ": " or "") +end + +function Errors:add_in_context(w, ctx, msg, ...) + local prefix = self:get_context(ctx) + msg = prefix .. msg + + local e = Err(msg, ...) + if e then + insert_error(self, w.y, w.x, e) + end +end + + +function Errors:collect(errs) + for _, e in ipairs(errs) do + insert_error(self, e.y, e.x, e) + end +end + +function Errors:add_warning(tag, w, fmt, ...) + assert(w.y) + table.insert(self.warnings, { + y = w.y, + x = w.x, + msg = fmt:format(...), + filename = self.filename, + tag = tag, + }) +end + +function Errors:invalid_at(w, msg, ...) + self:add(w, msg, ...) + return a_type(w, "invalid", {}) +end + +function Errors:add_unknown(node, name) + self:add_warning("unknown", node, "unknown variable: %s", name) +end + +function Errors:redeclaration_warning(node, old_var) + if node.tk:sub(1, 1) == "_" then return end + + local var_kind = "variable" + local var_name = node.tk + if node.kind == "local_function" or node.kind == "record_function" then + var_kind = "function" + var_name = node.name.tk + end + + local short_error = "redeclaration of " .. var_kind .. " '%s'" + if old_var and old_var.declared_at then + self:add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) + else + self:add_warning("redeclaration", node, short_error, var_name) + end +end + +function Errors:unused_warning(name, var) + local prefix = name:sub(1, 1) + if var.declared_at and + var.is_narrowed ~= "narrow" and + prefix ~= "_" and + prefix ~= "@" then + + local t = var.t + self:add_warning( + "unused", + var.declared_at, + "unused %s %s: %s", + var.is_func_arg and "argument" or + t.typename == "function" and "function" or + t.typename == "typedecl" and "type" or + t.typename == "typealias" and "type" or + "variable", + name, + show_type(var.t)) + + end +end + +function Errors:add_prefixing(w, src, prefix, dst) + if not src then + return + end + + for _, err in ipairs(src) do + err.msg = prefix .. err.msg + if w and ( + (err.filename ~= w.f) or + (not err.y) or + (w.y > err.y or (w.y == err.y and w.x > err.x))) then + + err.y = w.y + err.x = w.x + err.filename = w.f + end + + if dst then + table.insert(dst, err) + else + insert_error(self, err.y, err.x, err) + end + end +end + + + + + + + + +local function check_for_unused_vars(scope, is_global) + local vars = scope.vars + if not next(vars) then + return + end + local list + for name, var in pairs(vars) do + local t = var.t + if var.declared_at and not var.used then + if var.used_as_type then + var.declared_at.elide_type = true + else + if (t.typename == "typedecl" or t.typename == "typealias") and not is_global then + var.declared_at.elide_type = true + end + list = list or {} + table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) + end + elseif var.used and (t.typename == "typedecl" or t.typename == "typealias") and var.aliasing then + var.aliasing.used = true + var.aliasing.declared_at.elide_type = false + end + end + if list then + table.sort(list, function(a, b) + return a.y < b.y or (a.y == b.y and a.x < b.x) + end) + end + return list +end + +function Errors:warn_unused_vars(scope, is_global) + local unused = check_for_unused_vars(scope, is_global) + if unused then + for _, u in ipairs(unused) do + self:unused_warning(u.name, u.var) + end + end + + if scope.labels then + for name, node in pairs(scope.labels) do + if not node.used_label then + self:add_warning("unused", node, "unused label ::%s::", name) + end + end + end end +function Errors:add_unknown_dot(node, name) + if not self.unknown_dots[name] then + self.unknown_dots[name] = true + self:add_unknown(node, name) + end +end +function Errors:fail_unresolved_labels(scope) + if scope.pending_labels then + for name, nodes in pairs(scope.pending_labels) do + for _, node in ipairs(nodes) do + self:add(node, "no visible label '" .. name .. "' for goto") + end + end + end +end +function Errors:fail_unresolved_nominals(scope, global_scope) + if global_scope and scope.pending_nominals then + for name, types in pairs(scope.pending_nominals) do + if not global_scope.pending_global_types[name] then + for _, typ in ipairs(types) do + assert(typ.x) + assert(typ.y) + self:add(typ, "unknown type %s", typ) + end + end + end + end +end -local NONE = a_type("none", {}) -local INVALID = a_type("invalid", {}) -local UNKNOWN = a_type("unknown", {}) -local CIRCULAR_REQUIRE = a_type("circular_require", {}) -local FUNCTION = a_fn({ args = va_args({ ANY }), rets = va_args({ ANY }) }) +function Errors:check_redeclared_key(w, ctx, seen_keys, key) + if key ~= nil then + local s = seen_keys[key] + if s then + self:add_in_context(w, ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. self.filename .. ":" .. s.y .. ":" .. s.x .. ")") + else + seen_keys[key] = w + end + end +end -local XPCALL_MSGH_FUNCTION = a_fn({ args = { ANY }, rets = {} }) local numeric_binop = { ["number"] = { - ["number"] = NUMBER, - ["integer"] = NUMBER, + ["number"] = "number", + ["integer"] = "number", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = NUMBER, + ["integer"] = "integer", + ["number"] = "number", }, } local float_binop = { ["number"] = { - ["number"] = NUMBER, - ["integer"] = NUMBER, + ["number"] = "number", + ["integer"] = "number", }, ["integer"] = { - ["integer"] = NUMBER, - ["number"] = NUMBER, + ["integer"] = "number", + ["number"] = "number", }, } local integer_binop = { ["number"] = { - ["number"] = INTEGER, - ["integer"] = INTEGER, + ["number"] = "integer", + ["integer"] = "integer", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = INTEGER, + ["integer"] = "integer", + ["number"] = "integer", }, } local relational_binop = { ["number"] = { - ["integer"] = BOOLEAN, - ["number"] = BOOLEAN, + ["integer"] = "boolean", + ["number"] = "boolean", }, ["integer"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", }, ["string"] = { - ["string"] = BOOLEAN, + ["string"] = "boolean", }, ["boolean"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, } local equality_binop = { ["number"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", + ["nil"] = "boolean", }, ["integer"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", + ["nil"] = "boolean", }, ["string"] = { - ["string"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["string"] = "boolean", + ["nil"] = "boolean", }, ["boolean"] = { - ["boolean"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["boolean"] = "boolean", + ["nil"] = "boolean", }, ["record"] = { - ["emptytable"] = BOOLEAN, - ["record"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["record"] = "boolean", + ["nil"] = "boolean", }, ["array"] = { - ["emptytable"] = BOOLEAN, - ["array"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["array"] = "boolean", + ["nil"] = "boolean", }, ["map"] = { - ["emptytable"] = BOOLEAN, - ["map"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["map"] = "boolean", + ["nil"] = "boolean", }, ["thread"] = { - ["thread"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["thread"] = "boolean", + ["nil"] = "boolean", }, } local unop_types = { ["#"] = { - ["string"] = INTEGER, - ["array"] = INTEGER, - ["tupletable"] = INTEGER, - ["map"] = INTEGER, - ["emptytable"] = INTEGER, + ["string"] = "integer", + ["array"] = "integer", + ["tupletable"] = "integer", + ["map"] = "integer", + ["emptytable"] = "integer", }, ["-"] = { - ["number"] = NUMBER, - ["integer"] = INTEGER, + ["number"] = "number", + ["integer"] = "integer", }, ["~"] = { - ["number"] = INTEGER, - ["integer"] = INTEGER, + ["number"] = "integer", + ["integer"] = "integer", }, ["not"] = { - ["string"] = BOOLEAN, - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["boolean"] = BOOLEAN, - ["record"] = BOOLEAN, - ["array"] = BOOLEAN, - ["tupletable"] = BOOLEAN, - ["map"] = BOOLEAN, - ["emptytable"] = BOOLEAN, - ["thread"] = BOOLEAN, + ["string"] = "boolean", + ["number"] = "boolean", + ["integer"] = "boolean", + ["boolean"] = "boolean", + ["record"] = "boolean", + ["array"] = "boolean", + ["tupletable"] = "boolean", + ["map"] = "boolean", + ["emptytable"] = "boolean", + ["thread"] = "boolean", }, } @@ -5898,67 +6223,66 @@ local binop_types = { [">"] = relational_binop, ["or"] = { ["boolean"] = { - ["boolean"] = BOOLEAN, - ["function"] = FUNCTION, + ["boolean"] = "boolean", }, ["number"] = { - ["integer"] = NUMBER, - ["number"] = NUMBER, - ["boolean"] = BOOLEAN, + ["integer"] = "number", + ["number"] = "number", + ["boolean"] = "boolean", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = NUMBER, - ["boolean"] = BOOLEAN, + ["integer"] = "integer", + ["number"] = "number", + ["boolean"] = "boolean", }, ["string"] = { - ["string"] = STRING, - ["boolean"] = BOOLEAN, - ["enum"] = STRING, + ["string"] = "string", + ["boolean"] = "boolean", + ["enum"] = "string", }, ["function"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["array"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["record"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["map"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["enum"] = { - ["string"] = STRING, + ["string"] = "string", }, ["thread"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, }, [".."] = { ["string"] = { - ["string"] = STRING, - ["enum"] = STRING, - ["number"] = STRING, - ["integer"] = STRING, + ["string"] = "string", + ["enum"] = "string", + ["number"] = "string", + ["integer"] = "string", }, ["number"] = { - ["integer"] = STRING, - ["number"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["integer"] = "string", + ["number"] = "string", + ["string"] = "string", + ["enum"] = "string", }, ["integer"] = { - ["integer"] = STRING, - ["number"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["integer"] = "string", + ["number"] = "string", + ["string"] = "string", + ["enum"] = "string", }, ["enum"] = { - ["number"] = STRING, - ["integer"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["number"] = "string", + ["integer"] = "string", + ["string"] = "string", + ["enum"] = "string", }, }, } @@ -6166,8 +6490,8 @@ local function show_type_base(t, short, seen) end end -local function inferred_msg(t) - return " (inferred at " .. t.inferred_at.filename .. ":" .. t.inferred_at.y .. ":" .. t.inferred_at.x .. ")" +local function inferred_msg(t, prefix) + return " (" .. (prefix or "") .. "inferred at " .. t.inferred_at.f .. ":" .. t.inferred_at.y .. ":" .. t.inferred_at.x .. ")" end show_type = function(t, short, seen) @@ -6219,28 +6543,29 @@ function tl.search_module(module_name, search_dtl) return nil, nil, tried end -local function require_module(module_name, lax, env) +local function require_module(w, module_name, feat_lax, env) local mod = env.modules[module_name] if mod then - return mod, true + return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (lax or found:match("tl$")) then + if found and (feat_lax or found:match("tl$")) then - env.modules[module_name] = a_type("typedecl", { def = CIRCULAR_REQUIRE }) + env.module_filenames[module_name] = found + env.modules[module_name] = a_type(w, "typedecl", { def = a_type(w, "circular_require", {}) }) local found_result, err = tl.process(found, env, fd) assert(found_result, err) env.modules[module_name] = found_result.type - return found_result.type, true + return found_result.type, found elseif fd then fd:close() end - return INVALID, found ~= nil + return a_type(w, "invalid", {}), found end local compat_code_cache = {} @@ -6262,7 +6587,7 @@ local function add_compat_entries(program, used_set, gen_compat) local code = compat_code_cache[name] if not code then code = tl.parse(text, "@internal") - tl.type_check(code, { filename = "", lax = false, gen_compat = "off" }) + tl.type_check(code, "@internal", { feat_lax = "off", gen_compat = "off" }) compat_code_cache[name] = code end for _, c in ipairs(code) do @@ -6301,32 +6626,26 @@ local function add_compat_entries(program, used_set, gen_compat) TL_DEBUG = tl_debug end -local function get_stdlib_compat(lax) - if lax then - return { - ["utf8"] = true, - } - else - return { - ["io"] = true, - ["math"] = true, - ["string"] = true, - ["table"] = true, - ["utf8"] = true, - ["coroutine"] = true, - ["os"] = true, - ["package"] = true, - ["debug"] = true, - ["load"] = true, - ["loadfile"] = true, - ["assert"] = true, - ["pairs"] = true, - ["ipairs"] = true, - ["pcall"] = true, - ["xpcall"] = true, - ["rawlen"] = true, - } - end +local function get_stdlib_compat() + return { + ["io"] = true, + ["math"] = true, + ["string"] = true, + ["table"] = true, + ["utf8"] = true, + ["coroutine"] = true, + ["os"] = true, + ["package"] = true, + ["debug"] = true, + ["load"] = true, + ["loadfile"] = true, + ["assert"] = true, + ["pairs"] = true, + ["ipairs"] = true, + ["pcall"] = true, + ["xpcall"] = true, + ["rawlen"] = true, + } end local bit_operators = { @@ -6337,14 +6656,21 @@ local bit_operators = { ["<<"] = "lshift", } +local function node_at(w, n) + n.f = assert(w.f) + n.x = w.x + n.y = w.y + return n +end + local function convert_node_to_compat_call(node, mod_name, fn_name, e1, e2) node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 - node.e1 = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, ".") } - node.e1.e1 = { y = node.y, x = node.x, kind = "identifier", tk = mod_name } - node.e1.e2 = { y = node.y, x = node.x, kind = "identifier", tk = fn_name } - node.e2 = { y = node.y, x = node.x, kind = "expression_list" } + node.e1 = node_at(node, { kind = "op", op = an_operator(node, 2, ".") }) + node.e1.e1 = node_at(node, { kind = "identifier", tk = mod_name }) + node.e1.e2 = node_at(node, { kind = "identifier", tk = fn_name }) + node.e2 = node_at(node, { kind = "expression_list" }) node.e2[1] = e1 node.e2[2] = e2 end @@ -6353,10 +6679,10 @@ local function convert_node_to_compat_mt_call(node, mt_name, which_self, e1, e2) node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 - node.e1 = { y = node.y, x = node.x, kind = "identifier", tk = "_tl_mt" } - node.e2 = { y = node.y, x = node.x, kind = "expression_list" } - node.e2[1] = { y = node.y, x = node.x, kind = "string", tk = "\"" .. mt_name .. "\"" } - node.e2[2] = { y = node.y, x = node.x, kind = "integer", tk = tostring(which_self) } + node.e1 = node_at(node, { kind = "identifier", tk = "_tl_mt" }) + node.e2 = node_at(node, { kind = "expression_list" }) + node.e2[1] = node_at(node, { kind = "string", tk = "\"" .. mt_name .. "\"" }) + node.e2[2] = node_at(node, { kind = "integer", tk = tostring(which_self) }) node.e2[3] = e1 node.e2[4] = e2 end @@ -6365,25 +6691,6 @@ local stdlib_globals = nil local globals_typeid = new_typeid() local fresh_typevar_ctr = 1 -local function set_feat(feat, default) - if feat then - return (feat == "on") - else - return default - end -end - -tl.new_env = function(opts) - local env, err = tl.init_env(opts.lax_mode, opts.gen_compat, opts.gen_target, opts.predefined_modules) - if not env then - return nil, err - end - - env.feat_arity = set_feat(opts.feat_arity, true) - - return env -end - local function assert_no_stdlib_errors(errors, name) if #errors ~= 0 then local out = {} @@ -6394,46 +6701,31 @@ local function assert_no_stdlib_errors(errors, name) end end -tl.init_env = function(lax, gen_compat, gen_target, predefined) - if gen_compat == true or gen_compat == nil then - gen_compat = "optional" - elseif gen_compat == false then - gen_compat = "off" - end - gen_compat = gen_compat - - if not gen_target then - if _VERSION == "Lua 5.1" or _VERSION == "Lua 5.2" then - gen_target = "5.1" - else - gen_target = "5.3" - end - end - - if gen_target == "5.4" and gen_compat ~= "off" then - return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" - end +tl.new_env = function(opts) + opts = opts or {} local env = { modules = {}, + module_filenames = {}, loaded = {}, loaded_order = {}, globals = {}, - gen_compat = gen_compat, - gen_target = gen_target, + defaults = opts.defaults or {}, } + if env.defaults.gen_target == "5.4" and env.defaults.gen_compat ~= "off" then + return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" + end + + local w = { f = "@stdlib", x = 1, y = 1 } + if not stdlib_globals then local tl_debug = TL_DEBUG TL_DEBUG = nil local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") assert_no_stdlib_errors(syntax_errors, "syntax errors") - - local result = tl.type_check(program, { - filename = "@stdlib", - env = env, - }) + local result = tl.type_check(program, "@stdlib", {}, env) assert_no_stdlib_errors(result.type_errors, "type errors") stdlib_globals = env.globals @@ -6442,21 +6734,20 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) local math_t = (stdlib_globals["math"].t).def local table_t = (stdlib_globals["table"].t).def - local integer_compat = a_type("integer", { needs_compat = true }) - math_t.fields["maxinteger"] = integer_compat - math_t.fields["mininteger"] = integer_compat + math_t.fields["maxinteger"].needs_compat = true + math_t.fields["mininteger"].needs_compat = true table_t.fields["unpack"].needs_compat = true - stdlib_globals["..."] = { t = a_vararg({ STRING }) } - stdlib_globals["@is_va"] = { t = ANY } + stdlib_globals["..."] = { t = a_vararg(w, { a_type(w, "string", {}) }) } + stdlib_globals["@is_va"] = { t = a_type(w, "any", {}) } env.globals = {} end - local stdlib_compat = get_stdlib_compat(lax) + local stdlib_compat = get_stdlib_compat() for name, var in pairs(stdlib_globals) do env.globals[name] = var var.needs_compat = stdlib_compat[name] @@ -6467,53 +6758,40 @@ tl.init_env = function(lax, gen_compat, gen_target, predefined) end end - if predefined then - for _, name in ipairs(predefined) do - local module_type = require_module(name, lax, env) + if opts.predefined_modules then + for _, name in ipairs(opts.predefined_modules) do + local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) - if module_type == INVALID then + if module_type.typename == "invalid" then return nil, string.format("Error: could not predefine module '%s'", name) end end end - env.feat_arity = true - return env end -tl.type_check = function(ast, opts) - opts = opts or {} - local env = opts.env - if not env then - local err - env, err = tl.init_env(opts.lax, opts.gen_compat, opts.gen_target) - if err then - return nil, err - end - end +do + + + + local TypeChecker = {} + + + + + + + + - local lax = opts.lax - local feat_arity = env.feat_arity - local filename = opts.filename - local st = { env.globals } - local all_needs_compat = {} - local dependencies = {} - local warnings = {} - local errors = {} - local module_type - local tc - if env.report_types then - env.reporter = env.reporter or tl.new_type_reporter() - tc = env.reporter:get_collector(filename or "?") - end @@ -6522,10 +6800,24 @@ tl.type_check = function(ast, opts) - local function find_var(name, use) - for i = #st, 1, -1 do - local scope = st[i] - local var = scope[name] + + + + + + + + + + + + + + + function TypeChecker:find_var(name, use) + for i = #self.st, 1, -1 do + local scope = self.st[i] + local var = scope.vars[name] if var then if use == "lvalue" and var.is_narrowed then if var.narrowed_from then @@ -6534,7 +6826,7 @@ tl.type_check = function(ast, opts) end else if i == 1 and var.needs_compat then - all_needs_compat[name] = true + self.all_needs_compat[name] = true end if use == "use_type" then var.used_as_type = true @@ -6547,10 +6839,10 @@ tl.type_check = function(ast, opts) end end - local function simulate_g() + function TypeChecker:simulate_g() local globals = {} - for k, v in pairs(st[1]) do + for k, v in pairs(self.st[1].vars) do if k:sub(1, 1) ~= "@" then globals[k] = v.t end @@ -6564,100 +6856,60 @@ tl.type_check = function(ast, opts) end - local resolve_typevars + local typevar_resolver - local function fresh_typevar(t) - return a_type("typevar", { + local function fresh_typevar(_, t) + return a_type(t, "typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, }) end - local function fresh_typearg(t) - return a_type("typearg", { + local function fresh_typearg(_, t) + return a_type(t, "typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, }) end - local function ensure_fresh_typeargs(t) + function TypeChecker:ensure_fresh_typeargs(t) if not t.typeargs then return t end fresh_typevar_ctr = fresh_typevar_ctr + 1 local ok - ok, t = resolve_typevars(t, fresh_typevar, fresh_typearg) + ok, t = typevar_resolver(nil, t, fresh_typevar, fresh_typearg) assert(ok, "Internal Compiler Error: error creating fresh type variables") return t end - local function find_var_type(name, use) - local var = find_var(name, use) + function TypeChecker:find_var_type(name, use) + local var = self:find_var(name, use) if var then local t = var.t if t.typename == "unresolved_typearg" then return nil, nil, t.constraint end - t = ensure_fresh_typeargs(t) + t = self:ensure_fresh_typeargs(t) return t, var.attribute end end - local function Err(where, msg, ...) - local n = select("#", ...) - if n > 0 then - local showt = {} - for i = 1, n do - local t = select(i, ...) - if t then - if t.typename == "invalid" then - return nil - end - showt[i] = show_type(t) - end - end - msg = msg:format(_tl_table_unpack(showt)) - end - local name = where.filename or filename - - if TL_DEBUG then - io.stderr:write("ERROR:" .. (where.y or -1) .. ":" .. (where.x or -1) .. ": " .. msg .. "\n") - end - - return { - y = where.y, - x = where.x, - msg = msg, - filename = name, - } - end - - local function error_at(w, msg, ...) - assert(w.y) - - local e = Err(w, msg, ...) - if e then - table.insert(errors, e) - return true - else - return false - end - end - - local function ensure_not_abstract(where, t) + local function ensure_not_abstract(t) if t.typename == "function" and t.macroexp then - error_at(where, "macroexps are abstract; consider using a concrete function") + return nil, "macroexps are abstract; consider using a concrete function" elseif t.typename == "typedecl" then local def = t.def if def.typename == "interface" then - error_at(where, "interfaces are abstract; consider using a concrete record") + return nil, "interfaces are abstract; consider using a concrete record" end end + return true end - local function find_type(names, accept_typearg) - local typ = find_var_type(names[1], "use_type") + function TypeChecker:find_type(names, accept_typearg) + local typ = self:find_var_type(names[1], "use_type") if not typ then return nil end @@ -6679,7 +6931,7 @@ tl.type_check = function(ast, opts) return nil end - typ = ensure_fresh_typeargs(typ) + typ = self:ensure_fresh_typeargs(typ) if typ.typename == "nominal" and typ.found then typ = typ.found end @@ -6691,19 +6943,19 @@ tl.type_check = function(ast, opts) end end - local function union_type(t) + local function type_for_union(t) if t.typename == "typedecl" then - return union_type(t.def), t.def + return type_for_union(t.def), t.def elseif t.typename == "typealias" then - return union_type(t.alias_to), t.alias_to + return type_for_union(t.alias_to), t.alias_to elseif t.typename == "tuple" then - return union_type(t.tuple[1]), t.tuple[1] + return type_for_union(t.tuple[1]), t.tuple[1] elseif t.typename == "nominal" then local typedecl = t.found if not typedecl then return "invalid" end - return union_type(typedecl) + return type_for_union(typedecl) elseif t.fields then if t.is_userdata then return "userdata", t @@ -6727,7 +6979,7 @@ tl.type_check = function(ast, opts) local n_string_enum = 0 local has_primitive_string_type = false for _, t in ipairs(typ.types) do - local ut, rt = union_type(t) + local ut, rt = type_for_union(t) if ut == "userdata" then assert(rt.fields) if rt.meta_fields and rt.meta_fields["__is"] then @@ -6808,24 +7060,11 @@ tl.type_check = function(ast, opts) ["unknown"] = true, } - local function default_resolve_typevars_callback(t) - local rt = find_var_type(t.typevar) - if not rt then - return nil - elseif rt.typename == "string" then - - return STRING - end - return rt - end - - resolve_typevars = function(typ, fn_var, fn_arg) + typevar_resolver = function(self, typ, fn_var, fn_arg) local errs local seen = {} local resolved = {} - fn_var = fn_var or default_resolve_typevars_callback - local function resolve(t, all_same) local same = true @@ -6840,7 +7079,7 @@ tl.type_check = function(ast, opts) local orig_t = t if t.typename == "typevar" then - local rt = fn_var(t) + local rt = fn_var(self, t) if rt then resolved[t.typevar] = true if no_nested_types[rt.typename] or (rt.typename == "nominal" and not rt.typevals) then @@ -6856,7 +7095,7 @@ tl.type_check = function(ast, opts) seen[orig_t] = copy copy.typename = t.typename - copy.filename = t.filename + copy.f = t.f copy.x = t.x copy.y = t.y @@ -6867,7 +7106,7 @@ tl.type_check = function(ast, opts) elseif t.typename == "typearg" then if fn_arg then - copy = fn_arg(t) + copy = fn_arg(self, t) else assert(copy.typename == "typearg") copy.typearg = t.typearg @@ -6960,7 +7199,7 @@ tl.type_check = function(ast, opts) local _, err = is_valid_union(copy) if err then errs = errs or {} - table.insert(errs, Err(t, err, copy)) + table.insert(errs, Err(err, copy)) end elseif t.typename == "poly" then assert(copy.typename == "poly") @@ -6970,6 +7209,7 @@ tl.type_check = function(ast, opts) end elseif t.typename == "tupletable" then assert(copy.typename == "tupletable") + copy.inferred_at = t.inferred_at copy.types = {} for i, tf in ipairs(t.types) do copy.types[i], same = resolve(tf, same) @@ -6989,7 +7229,7 @@ tl.type_check = function(ast, opts) local copy, same = resolve(typ, true) if errs then - return false, INVALID, errs + return false, a_type(typ, "invalid", {}), errs end if (not same) and @@ -7008,144 +7248,72 @@ tl.type_check = function(ast, opts) return true, copy end - local function infer_emptytable(emptytable, fresh_t) - local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") - local nst = is_global and 1 or #st - for i = nst, 1, -1 do - local scope = st[i] - if scope[emptytable.assigned_to] then - scope[emptytable.assigned_to] = { t = fresh_t } - end - end - end + local function resolve_typevar(tc, t) + local rt = tc:find_var_type(t.typevar) + if not rt then + return nil + elseif rt.typename == "string" then - local function resolve_tuple(t) - if t.typename == "tuple" then - t = t.tuple[1] - end - if t == nil then - return NIL + return a_type(rt, "string", {}) end - return t - end - - local function add_warning(tag, where, fmt, ...) - table.insert(warnings, { - y = where.y, - x = where.x, - msg = fmt:format(...), - filename = where.filename or filename, - tag = tag, - }) - end - - local function invalid_at(where, msg, ...) - error_at(where, msg, ...) - return INVALID - end - - local function add_unknown(node, name) - add_warning("unknown", node, "unknown variable: %s", name) + return rt end - local function redeclaration_warning(node, old_var) - if node.tk:sub(1, 1) == "_" then return end - local var_kind = "variable" - local var_name = node.tk - if node.kind == "local_function" or node.kind == "record_function" then - var_kind = "function" - var_name = node.name.tk - end - local short_error = "redeclaration of " .. var_kind .. " '%s'" - if old_var and old_var.declared_at then - add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) - else - add_warning("redeclaration", node, short_error, var_name) + function TypeChecker:infer_emptytable(emptytable, fresh_t) + local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") + local nst = is_global and 1 or #self.st + for i = nst, 1, -1 do + local scope = self.st[i] + if scope.vars[emptytable.assigned_to] then + scope.vars[emptytable.assigned_to] = { t = fresh_t } + end end end - local function check_if_redeclaration(new_name, at) - local old = find_var(new_name, "check_only") - if old then - redeclaration_warning(at, old) + local function resolve_tuple(t) + local rt = t + if rt.typename == "tuple" then + rt = rt.tuple[1] end - end - - local function unused_warning(name, var) - local prefix = name:sub(1, 1) - if var.declared_at and - var.is_narrowed ~= "narrow" and - prefix ~= "_" and - prefix ~= "@" then - - if name:sub(1, 2) == "::" then - add_warning("unused", var.declared_at, "unused label %s", name) - else - local t = var.t - add_warning( - "unused", - var.declared_at, - "unused %s %s: %s", - var.is_func_arg and "argument" or - t.typename == "function" and "function" or - t.typename == "typedecl" and "type" or - t.typename == "typealias" and "type" or - "variable", - name, - show_type(var.t)) - - end + if rt == nil then + return a_type(t, "nil", {}) end + return rt end - local function add_errs_prefixing(where, src, dst, prefix) - assert(where == nil or where.y ~= nil) - - if not src then - return - end - for _, err in ipairs(src) do - err.msg = prefix .. err.msg - - if where and ( - (err.filename ~= filename) or - (not err.y) or - (where.y > err.y or (where.y == err.y and where.x > err.x))) then - - err.y = where.y - err.x = where.x - err.filename = filename - end - table.insert(dst, err) + function TypeChecker:check_if_redeclaration(new_name, at) + local old = self:find_var(new_name, "check_only") + if old then + self.errs:redeclaration_warning(at, old) end end + local function type_at(w, t) t.x = w.x t.y = w.y - t.filename = filename return t end - local function resolve_typevars_at(where, t) - assert(where) - local ok, ret, errs = resolve_typevars(t) + function TypeChecker:resolve_typevars_at(w, t) + assert(w) + local ok, ret, errs = typevar_resolver(self, t, resolve_typevar) if not ok then - assert(where.y) - add_errs_prefixing(where, errs, errors, "") + assert(w.y) + self.errs:add_prefixing(w, errs, "") end if ret == t or t.typename == "typevar" then ret = shallow_copy_table(ret) end - return type_at(where, ret) + return type_at(w, ret) end - local function infer_at(where, t) - local ret = resolve_typevars_at(where, t) + function TypeChecker:infer_at(w, t) + local ret = self:resolve_typevars_at(w, t) if ret.typename == "invalid" then ret = t end @@ -7153,8 +7321,8 @@ tl.type_check = function(ast, opts) if ret == t or t.typename == "typevar" then ret = shallow_copy_table(ret) end - ret.inferred_at = where - ret.inferred_at.filename = filename + assert(w.f) + ret.inferred_at = w return ret end @@ -7167,12 +7335,9 @@ tl.type_check = function(ast, opts) return t end - local get_unresolved - local find_unresolved - - local function add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - local scope = st[#st] - local var = scope[name] + function TypeChecker:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) + local scope = self.st[#self.st] + local var = scope.vars[name] if narrow then if var then if var.is_narrowed then @@ -7185,11 +7350,11 @@ tl.type_check = function(ast, opts) var.t = t else var = { t = t, attribute = attribute, is_narrowed = narrow, declared_at = node } - scope[name] = var + scope.vars[name] = var end - local unresolved = get_unresolved(scope) - unresolved.narrows[name] = true + scope.narrows = scope.narrows or {} + scope.narrows[name] = true return var end @@ -7200,37 +7365,33 @@ tl.type_check = function(ast, opts) name ~= "..." and name:sub(1, 1) ~= "@" then - check_if_redeclaration(name, node) + self:check_if_redeclaration(name, node) end if var and not var.used then - unused_warning(name, var) + self.errs:unused_warning(name, var) end var = { t = t, attribute = attribute, is_narrowed = nil, declared_at = node } - scope[name] = var + scope.vars[name] = var return var end - local function add_var(node, name, t, attribute, narrow, dont_check_redeclaration) - if lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then - add_unknown(node, name) + function TypeChecker:add_var(node, name, t, attribute, narrow, dont_check_redeclaration) + if self.feat_lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then + self.errs:add_unknown(node, name) end if not attribute then t = drop_constant_value(t) end - local var = add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - - if t.typename == "unresolved" or t.typename == "none" then - return var - end + local var = self:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - if tc and node then - tc.add_to_symbol_list(node, name, t) + if self.collector and node then + self.collector.add_to_symbol_list(node, name, t) end return var @@ -7238,8 +7399,6 @@ tl.type_check = function(ast, opts) - local same_type - local is_a @@ -7253,39 +7412,38 @@ tl.type_check = function(ast, opts) - - local function arg_check(where, all_errs, a, b, v, mode, n) + function TypeChecker:arg_check(w, all_errs, a, b, v, mode, n) local ok, errs if v == "covariant" then - ok, errs = is_a(a, b) + ok, errs = self:is_a(a, b) elseif v == "contravariant" then - ok, errs = is_a(b, a) + ok, errs = self:is_a(b, a) elseif v == "bivariant" then - ok, errs = is_a(a, b) + ok, errs = self:is_a(a, b) if ok then return true end - ok = is_a(b, a) + ok = self:is_a(b, a) if ok then return true end elseif v == "invariant" then - ok, errs = same_type(a, b) + ok, errs = self:same_type(a, b) end if not ok then - add_errs_prefixing(where, errs, all_errs, mode .. (n and " " .. n or "") .. ": ") + self.errs:add_prefixing(w, errs, mode .. (n and " " .. n or "") .. ": ", all_errs) return false end return true end - local function has_all_types_of(t1s, t2s) + function TypeChecker:has_all_types_of(t1s, t2s) for _, t1 in ipairs(t1s) do local found = false for _, t2 in ipairs(t2s) do - if same_type(t2, t1) then + if self:same_type(t2, t1) then found = true break end @@ -7317,8 +7475,8 @@ tl.type_check = function(ast, opts) end end - local function close_types(vars) - for _, var in pairs(vars) do + local function close_types(scope) + for _, var in pairs(scope.vars) do local t = var.t if t.typename == "typedecl" then t.closed = true @@ -7330,161 +7488,96 @@ tl.type_check = function(ast, opts) end end + function TypeChecker:begin_scope(node) + table.insert(self.st, { vars = {} }) - - - - - - - local function check_for_unused_vars(vars, is_global) - if not next(vars) then - return - end - local list = {} - for name, var in pairs(vars) do - local t = var.t - if var.declared_at and not var.used then - if var.used_as_type then - var.declared_at.elide_type = true - else - if (t.typename == "typedecl" or t.typename == "typealias") and not is_global then - var.declared_at.elide_type = true - end - table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) - end - elseif var.used and (t.typename == "typedecl" or t.typename == "typealias") and var.aliasing then - var.aliasing.used = true - var.aliasing.declared_at.elide_type = false - end - end - if list[1] then - table.sort(list, function(a, b) - return a.y < b.y or (a.y == b.y and a.x < b.x) - end) - for _, u in ipairs(list) do - unused_warning(u.name, u.var) - end - end - end - - get_unresolved = function(scope) - local unresolved - if scope then - local unr = scope["@unresolved"] - unresolved = unr and unr.t - else - unresolved = find_var_type("@unresolved") - end - if not unresolved then - unresolved = a_type("unresolved", { - labels = {}, - nominals = {}, - global_types = {}, - narrows = {}, - }) - add_var(nil, "@unresolved", unresolved) - end - return unresolved - end - - find_unresolved = function(level) - local u = st[level or #st]["@unresolved"] - if u then - return u.t - end - end - - local function begin_scope(node) - table.insert(st, {}) - - if tc and node then - tc.begin_symbol_list_scope(node) + if self.collector and node then + self.collector.begin_symbol_list_scope(node) end end - local function end_scope(node) + function TypeChecker:end_scope(node) + local st = self.st local scope = st[#st] - local unresolved = scope["@unresolved"] - if unresolved then - local unrt = unresolved.t - local next_scope = st[#st - 1] - local upper = next_scope["@unresolved"] - if upper then - local uppert = upper.t - for name, nodes in pairs(unrt.labels) do + local next_scope = st[#st - 1] + + if next_scope then + if scope.pending_labels then + next_scope.pending_labels = next_scope.pending_labels or {} + for name, nodes in pairs(scope.pending_labels) do for _, n in ipairs(nodes) do - uppert.labels[name] = uppert.labels[name] or {} - table.insert(uppert.labels[name], n) + next_scope.pending_labels[name] = next_scope.pending_labels[name] or {} + table.insert(next_scope.pending_labels[name], n) end end - for name, types in pairs(unrt.nominals) do + scope.pending_labels = nil + end + if scope.pending_nominals then + next_scope.pending_nominals = next_scope.pending_nominals or {} + for name, types in pairs(scope.pending_nominals) do for _, typ in ipairs(types) do - uppert.nominals[name] = uppert.nominals[name] or {} - table.insert(uppert.nominals[name], typ) + next_scope.pending_nominals[name] = next_scope.pending_nominals[name] or {} + table.insert(next_scope.pending_nominals[name], typ) end end - for name, _ in pairs(unrt.global_types) do - uppert.global_types[name] = true - end - else - next_scope["@unresolved"] = unresolved - unrt.narrows = {} + scope.pending_nominals = nil end end + close_types(scope) - check_for_unused_vars(scope) + self.errs:warn_unused_vars(scope) + table.remove(st) - if tc and node then - tc.end_symbol_list_scope(node) + if self.collector and node then + self.collector.end_symbol_list_scope(node) end end - local end_scope_and_none_type = function(node, _children) - end_scope(node) + + local NONE = a_type({ f = "@none", x = -1, y = -1 }, "none", {}) + + local function end_scope_and_none_type(self, node, _children) + self:end_scope(node) return NONE end - local resolve_nominal - local resolve_typealias do - local function match_typevals(t, def) + local function match_typevals(self, t, def) if t.typevals and def.typeargs then if #t.typevals ~= #def.typeargs then - error_at(t, "mismatch in number of type arguments") + self.errs:add(t, "mismatch in number of type arguments") return nil end - begin_scope() + self:begin_scope() for i, tt in ipairs(t.typevals) do - add_var(nil, def.typeargs[i].typearg, tt) + self:add_var(nil, def.typeargs[i].typearg, tt) end - local ret = resolve_typevars_at(t, def) - end_scope() + local ret = self:resolve_typevars_at(t, def) + self:end_scope() return ret elseif t.typevals then - error_at(t, "spurious type arguments") + self.errs:add(t, "spurious type arguments") return nil elseif def.typeargs then - error_at(t, "missing type arguments in %s", def) + self.errs:add(t, "missing type arguments in %s", def) return nil else return def end end - local function find_nominal_type_decl(t) + local function find_nominal_type_decl(self, t) if t.resolved then return t.resolved end - local found = t.found or find_type(t.names) + local found = t.found or self:find_type(t.names) if not found then - error_at(t, "unknown type %s", t) - return INVALID + return self.errs:invalid_at(t, "unknown type %s", t) end if found.typename == "typealias" then @@ -7492,8 +7585,7 @@ tl.type_check = function(ast, opts) end if not (found.typename == "typedecl") then - error_at(t, table.concat(t.names, ".") .. " is not a type") - return INVALID + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " is not a type") end local def = found.def @@ -7508,44 +7600,35 @@ tl.type_check = function(ast, opts) return nil, found end - local function resolve_decl_into_nominal(t, found) + local function resolve_decl_into_nominal(self, t, found) local def = found.def local resolved if def.typename == "record" or def.typename == "function" then - resolved = match_typevals(t, def) + resolved = match_typevals(self, t, def) if not resolved then - error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") - return INVALID + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") end else resolved = def end - if not t.filename then - t.filename = resolved.filename - if t.x == nil and t.y == nil then - t.x = resolved.x - t.y = resolved.y - end - end - t.resolved = resolved return resolved end - resolve_nominal = function(t) - local immediate, found = find_nominal_type_decl(t) + function TypeChecker:resolve_nominal(t) + local immediate, found = find_nominal_type_decl(self, t) if immediate then return immediate end - return resolve_decl_into_nominal(t, found) + return resolve_decl_into_nominal(self, t, found) end - resolve_typealias = function(typealias) + function TypeChecker:resolve_typealias(typealias) local t = typealias.alias_to - local immediate, found = find_nominal_type_decl(t) + local immediate, found = find_nominal_type_decl(self, t) if immediate then return immediate end @@ -7554,90 +7637,92 @@ tl.type_check = function(ast, opts) return found end - local resolved = resolve_decl_into_nominal(t, found) + local resolved = resolve_decl_into_nominal(self, t, found) - local typedecl = a_type("typedecl", { def = resolved }) + local typedecl = a_type(typealias, "typedecl", { def = resolved }) t.resolved = typedecl return typedecl end end - local function are_same_unresolved_global_type(t1, t2) - if t1.names[1] == t2.names[1] then - local unresolved = get_unresolved() - if unresolved.global_types[t1.names[1]] then - return true + do + local function are_same_unresolved_global_type(self, t1, t2) + if t1.names[1] == t2.names[1] then + local global_scope = self.st[1] + if global_scope.pending_global_types[t1.names[1]] then + return true + end end + return false end - return false - end - local function fail_nominals(t1, t2) - local t1name = show_type(t1) - local t2name = show_type(t2) - if t1name == t2name then - local t1r = resolve_nominal(t1) - if t1r.filename then - t1name = t1name .. " (defined in " .. t1r.filename .. ":" .. t1r.y .. ")" - end - local t2r = resolve_nominal(t2) - if t2r.filename then - t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" + local function fail_nominals(self, t1, t2) + local t1name = show_type(t1) + local t2name = show_type(t2) + if t1name == t2name then + self:resolve_nominal(t1) + if t1.found then + t1name = t1name .. " (defined in " .. t1.found.f .. ":" .. t1.found.y .. ")" + end + self:resolve_nominal(t2) + if t2.found then + t2name = t2name .. " (defined in " .. t2.found.f .. ":" .. t2.found.y .. ")" + end end + return false, { Err(t1name .. " is not a " .. t2name) } end - return false, { Err(t1, t1name .. " is not a " .. t2name) } - end - local function are_same_nominals(t1, t2) - local same_names - if t1.found and t2.found then - same_names = t1.found.typeid == t2.found.typeid - else - local ft1 = t1.found or find_type(t1.names) - local ft2 = t2.found or find_type(t2.names) - if ft1 and ft2 then - same_names = ft1.typeid == ft2.typeid + function TypeChecker:are_same_nominals(t1, t2) + local same_names + if t1.found and t2.found then + same_names = t1.found.typeid == t2.found.typeid else - if are_same_unresolved_global_type(t1, t2) then - return true - end + local ft1 = t1.found or self:find_type(t1.names) + local ft2 = t2.found or self:find_type(t2.names) + if ft1 and ft2 then + same_names = ft1.typeid == ft2.typeid + else + if are_same_unresolved_global_type(self, t1, t2) then + return true + end - if not ft1 then - error_at(t1, "unknown type %s", t1) - end - if not ft2 then - error_at(t2, "unknown type %s", t2) + if not ft1 then + self.errs:add(t1, "unknown type %s", t1) + end + if not ft2 then + self.errs:add(t2, "unknown type %s", t2) + end + return false, {} end - return false, {} end - end - if not same_names then - return fail_nominals(t1, t2) - elseif t1.typevals == nil and t2.typevals == nil then - return true - elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then - local errs = {} - for i = 1, #t1.typevals do - local _, typeval_errs = same_type(t1.typevals[i], t2.typevals[i]) - add_errs_prefixing(t1, typeval_errs, errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ") + if not same_names then + return fail_nominals(self, t1, t2) + elseif t1.typevals == nil and t2.typevals == nil then + return true + elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then + local errs = {} + for i = 1, #t1.typevals do + local _, typeval_errs = self:same_type(t1.typevals[i], t2.typevals[i]) + self.errs:add_prefixing(nil, typeval_errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ", errs) + end + return any_errors(errs) end - return any_errors(errs) + return true end - return true end local is_lua_table_type - local function to_structural(t) + function TypeChecker:to_structural(t) assert(not (t.typename == "tuple")) if t.typename == "nominal" then - return resolve_nominal(t) + return self:resolve_nominal(t) end return t end - local function unite(types, flatten_constants) + local function unite(w, types, flatten_constants) if #types == 1 then return types[1] end @@ -7648,7 +7733,6 @@ tl.type_check = function(ast, opts) local types_seen = {} - types_seen[NIL.typeid] = true types_seen["nil"] = true local i = 1 @@ -7684,14 +7768,14 @@ tl.type_check = function(ast, opts) end end - if types_seen[INVALID.typeid] then - return INVALID + if types_seen["invalid"] then + return a_type(w, "invalid", {}) end if #ts == 1 then return ts[1] else - return a_type("union", { types = ts }) + return a_type(w, "union", { types = ts }) end end @@ -7711,21 +7795,20 @@ tl.type_check = function(ast, opts) end end - local expand_type - local function arraytype_from_tuple(where, tupletype) + function TypeChecker:arraytype_from_tuple(w, tupletype) - local element_type = unite(tupletype.types, true) + local element_type = unite(w, tupletype.types, true) local valid = (not (element_type.typename == "union")) and true or is_valid_union(element_type) if valid then - return a_type("array", { elements = element_type }) + return a_type(w, "array", { elements = element_type }) end - local arr_type = a_type("array", { elements = tupletype.types[1] }) + local arr_type = a_type(w, "array", { elements = tupletype.types[1] }) for i = 2, #tupletype.types do - local expanded = expand_type(where, arr_type, a_type("array", { elements = tupletype.types[i] })) + local expanded = self:expand_type(w, arr_type, a_type(w, "array", { elements = tupletype.types[i] })) if not (expanded.typename == "array") then - return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } + return nil, { Err("unable to convert tuple %s to array", tupletype) } end arr_type = expanded end @@ -7736,33 +7819,33 @@ tl.type_check = function(ast, opts) return t.typename == "nominal" and t.names[1] == "@self" end - local function compare_true(_, _) + local function compare_true(_, _, _) return true end - local function subtype_nominal(a, b) + function TypeChecker:subtype_nominal(a, b) if is_self(a) and is_self(b) then return true end - local ra = a.typename == "nominal" and resolve_nominal(a) or a - local rb = b.typename == "nominal" and resolve_nominal(b) or b - local ok, errs = is_a(ra, rb) + local ra = a.typename == "nominal" and self:resolve_nominal(a) or a + local rb = b.typename == "nominal" and self:resolve_nominal(b) or b + local ok, errs = self:is_a(ra, rb) if errs and #errs == 1 and errs[1].msg:match("^got ") then return false end return ok, errs end - local function subtype_array(a, b) - if (not a.elements) or (not is_a(a.elements, b.elements)) then + function TypeChecker:subtype_array(a, b) + if (not a.elements) or (not self:is_a(a.elements, b.elements)) then return false end if a.consttypes and #a.consttypes > 1 then for _, e in ipairs(a.consttypes) do - if not is_a(e, b.elements) then - return false, { Err(a, "%s is not a member of %s", e, b.elements) } + if not self:is_a(e, b.elements) then + return false, { Err("%s is not a member of %s", e, b.elements) } end end end @@ -7784,16 +7867,16 @@ tl.type_check = function(ast, opts) return nil end - local function subtype_record(a, b) + function TypeChecker:subtype_record(a, b) if a.elements and b.elements then - if not is_a(a.elements, b.elements) then - return false, { Err(a, "array parts have incompatible element types") } + if not self:is_a(a.elements, b.elements) then + return false, { Err("array parts have incompatible element types") } end end if a.is_userdata ~= b.is_userdata then - return false, { Err(a, a.is_userdata and "userdata is not a record" or + return false, { Err(a.is_userdata and "userdata is not a record" or "record is not a userdata"), } end @@ -7802,9 +7885,9 @@ tl.type_check = function(ast, opts) local ak = a.fields[k] local bk = b.fields[k] if bk then - local ok, fielderrs = is_a(ak, bk) + local ok, fielderrs = self:is_a(ak, bk) if not ok then - add_errs_prefixing(nil, fielderrs, errs, "record field doesn't match: " .. k .. ": ") + self.errs:add_prefixing(nil, fielderrs, "record field doesn't match: " .. k .. ": ", errs) end end end @@ -7818,32 +7901,32 @@ tl.type_check = function(ast, opts) return true end - local eqtype_record = function(a, b) + function TypeChecker:eqtype_record(a, b) if (a.elements ~= nil) ~= (b.elements ~= nil) then - return false, { Err(a, "types do not have the same array interface") } + return false, { Err("types do not have the same array interface") } end if a.elements then - local ok, errs = same_type(a.elements, b.elements) + local ok, errs = self:same_type(a.elements, b.elements) if not ok then return ok, errs end end - local ok, errs = subtype_record(a, b) + local ok, errs = self:subtype_record(a, b) if not ok then return ok, errs end - ok, errs = subtype_record(b, a) + ok, errs = self:subtype_record(b, a) if not ok then return ok, errs end return true end - local function compare_map(ak, bk, av, bv, no_hack) - local ok1, errs_k = same_type(ak, bk) - local ok2, errs_v = same_type(av, bv) + local function compare_map(self, ak, bk, av, bv, no_hack) + local ok1, errs_k = self:same_type(ak, bk) + local ok2, errs_v = self:same_type(av, bv) if bk.typename == "any" and not no_hack then @@ -7873,25 +7956,25 @@ tl.type_check = function(ast, opts) return false, errs_k or errs_v end - local function compare_or_infer_typevar(typevar, a, b, cmp) + function TypeChecker:compare_or_infer_typevar(typevar, a, b, cmp) - local vt, _, constraint = find_var_type(typevar) + local vt, _, constraint = self:find_var_type(typevar) if vt then - return cmp(a or vt, b or vt) + return cmp(self, a or vt, b or vt) else local other = a or b if constraint then - if not is_a(other, constraint) then - return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } + if not self:is_a(other, constraint) then + return false, { Err("given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } end - if same_type(other, constraint) then + if self:same_type(other, constraint) then @@ -7899,22 +7982,22 @@ tl.type_check = function(ast, opts) end end - local ok, r, errs = resolve_typevars(other) + local ok, r, errs = typevar_resolver(self, other, resolve_typevar) if not ok then return false, errs end if r.typename == "typevar" and r.typevar == typevar then return true end - add_var(nil, typevar, r) + self:add_var(nil, typevar, r) return true end end - local function exists_supertype_in(t, xs) + function TypeChecker:exists_supertype_in(t, xs) for _, x in ipairs(xs.types) do - if is_a(t, x) then + if self:is_a(t, x) then return x end end @@ -7925,143 +8008,139 @@ tl.type_check = function(ast, opts) ["array"] = compare_true, ["map"] = compare_true, ["tupletable"] = compare_true, - ["interface"] = function(_a, b) + ["interface"] = function(_self, _a, b) return not b.is_userdata end, - ["record"] = function(_a, b) + ["record"] = function(_self, _a, b) return not b.is_userdata end, } - - - local eqtype_relations - eqtype_relations = { + TypeChecker.eqtype_relations = { ["typevar"] = { - ["typevar"] = function(a, b) + ["typevar"] = function(self, a, b) if a.typevar == b.typevar then return true end - return compare_or_infer_typevar(b.typevar, a, nil, same_type) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, - ["*"] = function(a, b) - return compare_or_infer_typevar(a.typevar, nil, b, same_type) + ["*"] = function(self, a, b) + return self:compare_or_infer_typevar(a.typevar, nil, b, self.same_type) end, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a, b) + ["tupletable"] = function(self, a, b) for i = 1, math.min(#a.types, #b.types) do - if not same_type(a.types[i], b.types[i]) then - return false, { Err(a, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } + if not self:same_type(a.types[i], b.types[i]) then + return false, { Err("in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } end end if #a.types ~= #b.types then - return false, { Err(a, "tuples have different size", a, b) } + return false, { Err("tuples have different size", a, b) } end return true end, }, ["array"] = { - ["array"] = function(a, b) - return same_type(a.elements, b.elements) + ["array"] = function(self, a, b) + return self:same_type(a.elements, b.elements) end, }, ["map"] = { - ["map"] = function(a, b) - return compare_map(a.keys, b.keys, a.values, b.values, true) + ["map"] = function(self, a, b) + return compare_map(self, a.keys, b.keys, a.values, b.values, true) end, }, ["union"] = { - ["union"] = function(a, b) - return (has_all_types_of(a.types, b.types) and - has_all_types_of(b.types, a.types)) + ["union"] = function(self, a, b) + return (self:has_all_types_of(a.types, b.types) and + self:has_all_types_of(b.types, a.types)) end, }, ["nominal"] = { - ["nominal"] = are_same_nominals, + ["nominal"] = TypeChecker.are_same_nominals, }, ["record"] = { - ["record"] = eqtype_record, + ["record"] = TypeChecker.eqtype_record, }, ["interface"] = { - ["interface"] = function(a, b) + ["interface"] = function(_self, a, b) return a.typeid == b.typeid end, }, ["function"] = { - ["function"] = function(a, b) + ["function"] = function(self, a, b) local argdelta = a.is_method and 1 or 0 local naargs, nbargs = #a.args.tuple, #b.args.tuple if naargs ~= nbargs then if (not not a.is_method) ~= (not not b.is_method) then - return false, { Err(a, "different number of input arguments: method and non-method are not the same type") } + return false, { Err("different number of input arguments: method and non-method are not the same type") } end - return false, { Err(a, "different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } + return false, { Err("different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } end local narets, nbrets = #a.rets.tuple, #b.rets.tuple if narets ~= nbrets then - return false, { Err(a, "different number of return values: got " .. narets .. ", expected " .. nbrets) } + return false, { Err("different number of return values: got " .. narets .. ", expected " .. nbrets) } end local errs = {} for i = 1, naargs do - arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) + self:arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) end for i = 1, narets do - arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) + self:arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) end return any_errors(errs) end, }, ["*"] = { - ["typevar"] = function(a, b) - return compare_or_infer_typevar(b.typevar, a, nil, same_type) + ["typevar"] = function(self, a, b) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, }, } - local subtype_relations - subtype_relations = { + TypeChecker.subtype_relations = { ["tuple"] = { - ["tuple"] = function(a, b) + ["tuple"] = function(self, a, b) local at, bt = a.tuple, b.tuple if #at ~= #bt then return false end for i = 1, #at do - if not is_a(at[i], bt[i]) then + if not self:is_a(at[i], bt[i]) then return false end end return true end, - ["*"] = function(a, b) - return is_a(resolve_tuple(a), b) + ["*"] = function(self, a, b) + return self:is_a(resolve_tuple(a), b) end, }, ["typevar"] = { - ["typevar"] = function(a, b) + ["typevar"] = function(self, a, b) if a.typevar == b.typevar then return true end - return compare_or_infer_typevar(b.typevar, a, nil, is_a) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.is_a) end, - ["*"] = function(a, b) - return compare_or_infer_typevar(a.typevar, nil, b, is_a) + ["*"] = function(self, a, b) + return self:compare_or_infer_typevar(a.typevar, nil, b, self.is_a) end, }, ["nil"] = { ["*"] = compare_true, }, ["union"] = { - ["union"] = function(a, b) + ["union"] = function(self, a, b) local used = {} for _, t in ipairs(a.types) do - begin_scope() - local u = exists_supertype_in(t, b) - end_scope() + self:begin_scope() + local u = self:exists_supertype_in(t, b) + self:end_scope() if not u then return false end @@ -8070,13 +8149,13 @@ tl.type_check = function(ast, opts) end end for u, t in pairs(used) do - is_a(t, u) + self:is_a(t, u) end return true end, - ["*"] = function(a, b) + ["*"] = function(self, a, b) for _, t in ipairs(a.types) do - if not is_a(t, b) then + if not self:is_a(t, b) then return false end end @@ -8084,212 +8163,212 @@ tl.type_check = function(ast, opts) end, }, ["poly"] = { - ["*"] = function(a, b) - if exists_supertype_in(b, a) then + ["*"] = function(self, a, b) + if self:exists_supertype_in(b, a) then return true end - return false, { Err(a, "cannot match against any alternatives of the polymorphic type") } + return false, { Err("cannot match against any alternatives of the polymorphic type") } end, }, ["nominal"] = { - ["nominal"] = function(a, b) - local ok, errs = are_same_nominals(a, b) + ["nominal"] = function(self, a, b) + local ok, errs = self:are_same_nominals(a, b) if ok then return true end - local rb = resolve_nominal(b) + local rb = self:resolve_nominal(b) if rb.typename == "interface" then - return is_a(a, rb) + return self:is_a(a, rb) end - local ra = resolve_nominal(a) + local ra = self:resolve_nominal(a) if ra.typename == "union" or rb.typename == "union" then - return is_a(ra, rb) + return self:is_a(ra, rb) end return ok, errs end, - ["*"] = subtype_nominal, + ["*"] = TypeChecker.subtype_nominal, }, ["enum"] = { ["string"] = compare_true, }, ["string"] = { - ["enum"] = function(a, b) + ["enum"] = function(_self, a, b) if not a.literal then - return false, { Err(a, "string is not a %s", b) } + return false, { Err("%s is not a %s", a, b) } end if b.enumset[a.literal] then return true end - return false, { Err(a, "%s is not a member of %s", a, b) } + return false, { Err("%s is not a member of %s", a, b) } end, }, ["integer"] = { ["number"] = compare_true, }, ["interface"] = { - ["interface"] = function(a, b) - if find_in_interface_list(a, function(t) return (is_a(t, b)) end) then + ["interface"] = function(self, a, b) + if find_in_interface_list(a, function(t) return (self:is_a(t, b)) end) then return true end - return same_type(a, b) + return self:same_type(a, b) end, - ["array"] = subtype_array, - ["record"] = subtype_record, - ["tupletable"] = function(a, b) - return subtype_relations["record"]["tupletable"](a, b) + ["array"] = TypeChecker.subtype_array, + ["record"] = TypeChecker.subtype_record, + ["tupletable"] = function(self, a, b) + return self.subtype_relations["record"]["tupletable"](self, a, b) end, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a, b) + ["tupletable"] = function(self, a, b) for i = 1, math.min(#a.types, #b.types) do - if not is_a(a.types[i], b.types[i]) then - return false, { Err(a, "in tuple entry " .. + if not self:is_a(a.types[i], b.types[i]) then + return false, { Err("in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]), } end end if #a.types > #b.types then - return false, { Err(a, "tuple %s is too big for tuple %s", a, b) } + return false, { Err("tuple %s is too big for tuple %s", a, b) } end return true end, - ["record"] = function(a, b) + ["record"] = function(self, a, b) if b.elements then - return subtype_relations["tupletable"]["array"](a, b) + return self.subtype_relations["tupletable"]["array"](self, a, b) end end, - ["array"] = function(a, b) + ["array"] = function(self, a, b) if b.inferred_len and b.inferred_len > #a.types then - return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } + return false, { Err("incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } end - local aa, err = arraytype_from_tuple(a.inferred_at, a) + local aa, err = self:arraytype_from_tuple(a.inferred_at or a, a) if not aa then return false, err end - if not is_a(aa, b) then - return false, { Err(a, "got %s (from %s), expected %s", aa, a, b) } + if not self:is_a(aa, b) then + return false, { Err("got %s (from %s), expected %s", aa, a, b) } end return true end, - ["map"] = function(a, b) - local aa = arraytype_from_tuple(a.inferred_at, a) + ["map"] = function(self, a, b) + local aa = self:arraytype_from_tuple(a.inferred_at or a, a) if not aa then - return false, { Err(a, "Unable to convert tuple %s to map", a) } + return false, { Err("Unable to convert tuple %s to map", a) } end - return compare_map(INTEGER, b.keys, aa.elements, b.values) + return compare_map(self, a_type(a, "integer", {}), b.keys, aa.elements, b.values) end, }, ["record"] = { - ["record"] = subtype_record, - ["interface"] = function(a, b) - if find_in_interface_list(a, function(t) return (is_a(t, b)) end) then + ["record"] = TypeChecker.subtype_record, + ["interface"] = function(self, a, b) + if find_in_interface_list(a, function(t) return (self:is_a(t, b)) end) then return true end if not a.declname then - return subtype_record(a, b) + return self:subtype_record(a, b) end end, - ["array"] = subtype_array, - ["map"] = function(a, b) - if not is_a(b.keys, STRING) then - return false, { Err(a, "can't match a record to a map with non-string keys") } + ["array"] = TypeChecker.subtype_array, + ["map"] = function(self, a, b) + if not self:is_a(b.keys, a_type(b, "string", {})) then + return false, { Err("can't match a record to a map with non-string keys") } end for _, k in ipairs(a.field_order) do local bk = b.keys if bk.typename == "enum" and not bk.enumset[k] then - return false, { Err(a, "key is not an enum value: " .. k) } + return false, { Err("key is not an enum value: " .. k) } end - if not is_a(a.fields[k], b.values) then - return false, { Err(a, "record is not a valid map; not all fields have the same type") } + if not self:is_a(a.fields[k], b.values) then + return false, { Err("record is not a valid map; not all fields have the same type") } end end return true end, - ["tupletable"] = function(a, b) + ["tupletable"] = function(self, a, b) if a.elements then - return subtype_relations["array"]["tupletable"](a, b) + return self.subtype_relations["array"]["tupletable"](self, a, b) end end, }, ["array"] = { - ["array"] = subtype_array, - ["record"] = function(a, b) + ["array"] = TypeChecker.subtype_array, + ["record"] = function(self, a, b) if b.elements then - return subtype_array(a, b) + return self:subtype_array(a, b) end end, - ["map"] = function(a, b) - return compare_map(INTEGER, b.keys, a.elements, b.values) + ["map"] = function(self, a, b) + return compare_map(self, a_type(a, "integer", {}), b.keys, a.elements, b.values) end, - ["tupletable"] = function(a, b) + ["tupletable"] = function(self, a, b) local alen = a.inferred_len or 0 if alen > #b.types then - return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } + return false, { Err("incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } end for i = 1, (alen > 0) and alen or #b.types do - if not is_a(a.elements, b.types[i]) then - return false, { Err(a, "tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } + if not self:is_a(a.elements, b.types[i]) then + return false, { Err("tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } end end return true end, }, ["map"] = { - ["map"] = function(a, b) - return compare_map(a.keys, b.keys, a.values, b.values) + ["map"] = function(self, a, b) + return compare_map(self, a.keys, b.keys, a.values, b.values) end, - ["array"] = function(a, b) - return compare_map(a.keys, INTEGER, a.values, b.elements) + ["array"] = function(self, a, b) + return compare_map(self, a.keys, a_type(b, "integer", {}), a.values, b.elements) end, }, ["typedecl"] = { - ["record"] = function(a, b) + ["record"] = function(self, a, b) local def = a.def if def.fields then - return subtype_record(def, b) + return self:subtype_record(def, b) end end, }, ["function"] = { - ["function"] = function(a, b) + ["function"] = function(self, a, b) local errs = {} local aa, ba = a.args.tuple, b.args.tuple if (not b.args.is_va) and a.min_arity > b.min_arity then - table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) + table.insert(errs, Err("incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do - arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) + self:arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) end end local ar, br = a.rets.tuple, b.rets.tuple local diff_by_va = #br - #ar == 1 and b.rets.is_va if #ar < #br and not diff_by_va then - table.insert(errs, Err(a, "incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) + table.insert(errs, Err("incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) else local nrets = #br if diff_by_va then nrets = nrets - 1 end for i = 1, nrets do - arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) + self:arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) end end @@ -8297,36 +8376,36 @@ a.types[i], b.types[i]), } end, }, ["typearg"] = { - ["typearg"] = function(a, b) + ["typearg"] = function(_self, a, b) return a.typearg == b.typearg end, - ["*"] = function(a, b) + ["*"] = function(self, a, b) if a.constraint then - return is_a(a.constraint, b) + return self:is_a(a.constraint, b) end end, }, ["*"] = { ["any"] = compare_true, - ["tuple"] = function(a, b) - return is_a(a_type("tuple", { tuple = { a } }), b) + ["tuple"] = function(self, a, b) + return self:is_a(a_type(a, "tuple", { tuple = { a } }), b) end, - ["typevar"] = function(a, b) - return compare_or_infer_typevar(b.typevar, a, nil, is_a) + ["typevar"] = function(self, a, b) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.is_a) end, - ["typearg"] = function(a, b) + ["typearg"] = function(self, a, b) if b.constraint then - return is_a(a, b.constraint) + return self:is_a(a, b.constraint) end end, - ["union"] = exists_supertype_in, + ["union"] = TypeChecker.exists_supertype_in, - ["nominal"] = subtype_nominal, - ["poly"] = function(a, b) + ["nominal"] = TypeChecker.subtype_nominal, + ["poly"] = function(self, a, b) for _, t in ipairs(b.types) do - if not is_a(a, t) then - return false, { Err(a, "cannot match against all alternatives of the polymorphic type") } + if not self:is_a(a, t) then + return false, { Err("cannot match against all alternatives of the polymorphic type") } end end return true @@ -8335,7 +8414,7 @@ a.types[i], b.types[i]), } } - local type_priorities = { + TypeChecker.type_priorities = { ["tuple"] = 2, ["typevar"] = 3, @@ -8364,19 +8443,7 @@ a.types[i], b.types[i]), } ["function"] = 14, } - if lax then - type_priorities["unknown"] = 0 - - subtype_relations["unknown"] = {} - subtype_relations["unknown"]["*"] = compare_true - subtype_relations["*"]["unknown"] = compare_true - - subtype_relations["boolean"] = {} - subtype_relations["boolean"]["boolean"] = compare_true - subtype_relations["*"]["boolean"] = compare_true - end - - local function compare_types(relations, t1, t2) + local function compare_types(self, relations, t1, t2) if t1.typeid == t2.typeid then return true end @@ -8384,8 +8451,8 @@ a.types[i], b.types[i]), } local s1 = relations[t1.typename] local fn = s1 and s1[t2.typename] if not fn then - local p1 = type_priorities[t1.typename] or 999 - local p2 = type_priorities[t2.typename] or 999 + local p1 = self.type_priorities[t1.typename] or 999 + local p2 = self.type_priorities[t2.typename] or 999 fn = (p1 < p2 and (s1 and s1["*"]) or (relations["*"][t2.typename])) end @@ -8394,32 +8461,32 @@ a.types[i], b.types[i]), } if fn == compare_true then return true end - ok, err = fn(t1, t2) + ok, err = fn(self, t1, t2) else ok = t1.typename == t2.typename end if (not ok) and not err then - return false, { Err(t1, "got %s, expected %s", t1, t2) } + return false, { Err("got %s, expected %s", t1, t2) } end return ok, err end - is_a = function(t1, t2) - return compare_types(subtype_relations, t1, t2) + function TypeChecker:is_a(t1, t2) + return compare_types(self, self.subtype_relations, t1, t2) end - same_type = function(t1, t2) + function TypeChecker:same_type(t1, t2) - return compare_types(eqtype_relations, t1, t2) + return compare_types(self, self.eqtype_relations, t1, t2) end if TL_DEBUG then - local orig_is_a = is_a - is_a = function(t1, t2) + local orig_is_a = TypeChecker.is_a + TypeChecker.is_a = function(self, t1, t2) assert(type(t1) == "table") assert(type(t2) == "table") @@ -8429,14 +8496,14 @@ a.types[i], b.types[i]), } return true end - return orig_is_a(t1, t2) + return orig_is_a(self, t1, t2) end end - local function assert_is_a(where, t1, t2, context, name) + function TypeChecker:assert_is_a(w, t1, t2, ctx, name) t1 = resolve_tuple(t1) t2 = resolve_tuple(t2) - if lax and (is_unknown(t1) or is_unknown(t2)) then + if self.feat_lax and (is_unknown(t1) or is_unknown(t2)) then return true end @@ -8444,24 +8511,27 @@ a.types[i], b.types[i]), } if t1.typename == "nil" then return true elseif t2.typename == "unresolved_emptytable_value" then - if is_number_type(t2.emptytable_type.keys) then - infer_emptytable(t2.emptytable_type, infer_at(where, a_type("array", { elements = t1 }))) + local t2keys = t2.emptytable_type.keys + if is_numeric_type(t2keys) then + self:infer_emptytable(t2.emptytable_type, self:infer_at(w, a_type(w, "array", { elements = t1 }))) else - infer_emptytable(t2.emptytable_type, infer_at(where, a_type("map", { keys = t2.emptytable_type.keys, values = t1 }))) + self:infer_emptytable(t2.emptytable_type, self:infer_at(w, a_type(w, "map", { keys = t2keys, values = t1 }))) end return true elseif t2.typename == "emptytable" then if is_lua_table_type(t1) then - infer_emptytable(t2, infer_at(where, t1)) + self:infer_emptytable(t2, self:infer_at(w, t1)) elseif not (t1.typename == "emptytable") then - error_at(where, context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) + self.errs:add(w, self.errs:get_context(ctx, name) .. "assigning %s to a variable declared with {}", t1) return false end return true end - local ok, match_errs = is_a(t1, t2) - add_errs_prefixing(where, match_errs, errors, context .. ": " .. (name and (name .. ": ") or "")) + local ok, match_errs = self:is_a(t1, t2) + if not ok then + self.errs:add_prefixing(w, match_errs, self.errs:get_context(ctx, name)) + end return ok end @@ -8469,11 +8539,11 @@ a.types[i], b.types[i]), } if t.typename == "invalid" then return false end - if same_type(t, NIL) then + if t.typename == "nil" then return true end if t.typename == "nominal" then - t = resolve_nominal(t) + t = assert(t.resolved) end if t.fields then return t.meta_fields and t.meta_fields["__close"] ~= nil @@ -8491,36 +8561,27 @@ a.types[i], b.types[i]), } return definitely_not_closable_exprs[e.kind] end - local unknown_dots = {} - - local function add_unknown_dot(node, name) - if not unknown_dots[name] then - unknown_dots[name] = true - add_unknown(node, name) - end - end - - local function same_in_all_union_entries(u, check) + function TypeChecker:same_in_all_union_entries(u, check) local t1, f = check(u.types[1]) if not t1 then return nil end for i = 2, #u.types do local t2 = check(u.types[i]) - if not t2 or not same_type(t1, t2) then + if not t2 or not self:same_type(t1, t2) then return nil end end return f or t1 end - local function same_call_mt_in_all_union_entries(u) - return same_in_all_union_entries(u, function(t) - t = to_structural(t) + function TypeChecker:same_call_mt_in_all_union_entries(u) + return self:same_in_all_union_entries(u, function(t) + t = self:to_structural(t) if t.fields then local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt.typename == "function" then - local args_tuple = a_type("tuple", { tuple = {} }) + local args_tuple = a_type(u, "tuple", { tuple = {} }) for i = 2, #call_mt.args.tuple do table.insert(args_tuple.tuple, call_mt.args.tuple[i]) end @@ -8530,20 +8591,21 @@ a.types[i], b.types[i]), } end) end - local function resolve_for_call(func, args, is_method) + function TypeChecker:resolve_for_call(func, args, is_method) - if lax and is_unknown(func) then - func = a_fn({ args = va_args({ UNKNOWN }), rets = va_args({ UNKNOWN }) }) + if self.feat_lax and is_unknown(func) then + local unk = func + func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) end - func = to_structural(func) + func = self:to_structural(func) if func.typename ~= "function" and func.typename ~= "poly" then if func.typename == "union" then - local r = same_call_mt_in_all_union_entries(func) + local r = self:same_call_mt_in_all_union_entries(func) if r then table.insert(args.tuple, 1, func.types[1]) - return to_structural(r), true + return self:to_structural(r), true end end @@ -8557,7 +8619,7 @@ a.types[i], b.types[i]), } if func.fields and func.meta_fields and func.meta_fields["__call"] then table.insert(args.tuple, 1, func) func = func.meta_fields["__call"] - func = to_structural(func) + func = self:to_structural(func) is_method = true end end @@ -8577,7 +8639,7 @@ a.types[i], b.types[i]), } local visit_node = { cbs = { ["variable"] = { - after = function(node, _children) + after = function(_, node, _children) local i = argnames[node.tk] if not i then return nil @@ -8590,7 +8652,7 @@ a.types[i], b.types[i]), } after = on_node, } - return recurse_node(root, visit_node, {}) + return recurse_node(nil, root, visit_node, {}) end local function expand_macroexp(orignode, args, macroexp) @@ -8598,7 +8660,7 @@ a.types[i], b.types[i]), } return { Node, args[i] } end - local on_node = function(node, children, ret) + local on_node = function(_, node, children, ret) local orig = ret and ret[2] or node local out = shallow_copy_table(orig) @@ -8627,12 +8689,12 @@ a.types[i], b.types[i]), } orignode.expanded = p[2] end - local function check_macroexp_arg_use(macroexp) + function TypeChecker:check_macroexp_arg_use(macroexp) local used = {} local on_arg_id = function(node, _i) if used[node.tk] then - error_at(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") + self.errs:add(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") else used[node.tk] = true end @@ -8655,18 +8717,15 @@ a.types[i], b.types[i]), } orignode.known = saveknown end - - - local type_check_function_call do - local function mark_invalid_typeargs(f) + local function mark_invalid_typeargs(self, f) if f.typeargs then for _, a in ipairs(f.typeargs) do - if not find_var_type(a.typearg) then + if not self:find_var_type(a.typearg) then if a.constraint then - add_var(nil, a.typearg, a.constraint) + self:add_var(nil, a.typearg, a.constraint) else - add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { + self:add_var(nil, a.typearg, self.feat_lax and a_type(a, "unknown", {}) or a_type(a, "unresolvable_typearg", { typearg = a.typearg, })) end @@ -8675,7 +8734,7 @@ a.types[i], b.types[i]), } end end - local function infer_emptytables(where, wheres, xs, ys, delta) + local function infer_emptytables(self, w, wheres, xs, ys, delta) local xt, yt = xs.tuple, ys.tuple local n_xs = #xt local n_ys = #yt @@ -8685,9 +8744,9 @@ a.types[i], b.types[i]), } if x.typename == "emptytable" then local y = yt[i] or (ys.is_va and yt[n_ys]) if y then - local w = wheres and wheres[i + delta] or where - local inferred_y = infer_at(w, y) - infer_emptytable(x, inferred_y) + local iw = wheres and wheres[i + delta] or w + local inferred_y = self:infer_at(iw, y) + self:infer_emptytable(x, inferred_y) xt[i] = inferred_y end end @@ -8697,7 +8756,7 @@ a.types[i], b.types[i]), } local check_args_rets do - local function check_func_type_list(where, wheres, xs, ys, from, delta, v, mode) + local function check_func_type_list(self, w, wheres, xs, ys, from, delta, v, mode) assert(xs.typename == "tuple", xs.typename) assert(ys.typename == "tuple", ys.typename) @@ -8708,11 +8767,11 @@ a.types[i], b.types[i]), } for i = from, math.max(n_xs, n_ys) do local pos = i + delta - local x = xt[i] or (xs.is_va and xt[n_xs]) or NIL + local x = xt[i] or (xs.is_va and xt[n_xs]) or a_type(w, "nil", {}) local y = yt[i] or (ys.is_va and yt[n_ys]) if y then - local w = wheres and wheres[pos] or where - if not arg_check(w, errs, x, y, v, mode, pos) then + local iw = wheres and wheres[pos] or w + if not self:arg_check(iw, errs, x, y, v, mode, pos) then return nil, errs end end @@ -8721,7 +8780,7 @@ a.types[i], b.types[i]), } return true end - check_args_rets = function(where, where_args, f, args, expected_rets, argdelta) + check_args_rets = function(self, w, where_args, f, args, expected_rets, argdelta) local rets_ok = true local rets_errs local args_ok @@ -8732,19 +8791,19 @@ a.types[i], b.types[i]), } if argdelta == -1 then from = 2 local errs = {} - if (not is_self(fargs[1])) and not arg_check(where, errs, fargs[1], args.tuple[1], "contravariant", "self") then + if (not is_self(fargs[1])) and not self:arg_check(w, errs, fargs[1], args.tuple[1], "contravariant", "self") then return nil, errs end end if expected_rets then - expected_rets = infer_at(where, expected_rets) - infer_emptytables(where, nil, expected_rets, f.rets, 0) + expected_rets = self:infer_at(w, expected_rets) + infer_emptytables(self, w, nil, expected_rets, f.rets, 0) - rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, expected_rets, 1, 0, "covariant", "return") + rets_ok, rets_errs = check_func_type_list(self, w, nil, f.rets, expected_rets, 1, 0, "covariant", "return") end - args_ok, args_errs = check_func_type_list(where, where_args, f.args, args, from, argdelta, "contravariant", "argument") + args_ok, args_errs = check_func_type_list(self, w, where_args, f.args, args, from, argdelta, "contravariant", "argument") if (not args_ok) or (not rets_ok) then return nil, args_errs or {} end @@ -8752,29 +8811,29 @@ a.types[i], b.types[i]), } - infer_emptytables(where, where_args, args, f.args, argdelta) + infer_emptytables(self, w, where_args, args, f.args, argdelta) - mark_invalid_typeargs(f) + mark_invalid_typeargs(self, f) - return resolve_typevars_at(where, f.rets) + return self:resolve_typevars_at(w, f.rets) end end - local function push_typeargs(func) + local function push_typeargs(self, func) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { + self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { constraint = fnarg.constraint, })) end end end - local function pop_typeargs(func) + local function pop_typeargs(self, func) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - if st[#st][fnarg.typearg] then - st[#st][fnarg.typearg] = nil + if self.st[#self.st].vars[fnarg.typearg] then + self.st[#self.st].vars[fnarg.typearg] = nil end end end @@ -8788,12 +8847,9 @@ a.types[i], b.types[i]), } end end - local function fail_call(where, func, nargs, errs) + local function fail_call(self, w, func, nargs, errs) if errs then - - for _, err in ipairs(errs) do - table.insert(errors, err) - end + self.errs:collect(errs) else local expects = {} @@ -8810,34 +8866,34 @@ a.types[i], b.types[i]), } else table.insert(expects, show_arity(func)) end - error_at(where, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") + self.errs:add(w, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end local f = resolve_function_type(func, 1) - mark_invalid_typeargs(f) + mark_invalid_typeargs(self, f) - return resolve_typevars_at(where, f.rets) + return self:resolve_typevars_at(w, f.rets) end - local function check_call(where, where_args, func, args, expected_rets, is_typedecl_funcall, argdelta) + local function check_call(self, w, where_args, func, args, expected_rets, is_typedecl_funcall, argdelta) assert(type(func) == "table") assert(type(args) == "table") local is_method = (argdelta == -1) if not (func.typename == "function" or func.typename == "poly") then - func, is_method = resolve_for_call(func, args, is_method) + func, is_method = self:resolve_for_call(func, args, is_method) if is_method then argdelta = -1 end if not (func.typename == "function" or func.typename == "poly") then - return invalid_at(where, "not a function: %s", func) + return self.errs:invalid_at(w, "not a function: %s", func) end end if is_method and args.tuple[1] then - add_var(nil, "@self", type_at(where, a_type("typedecl", { def = args.tuple[1] }))) + self:add_var(nil, "@self", a_type(w, "typedecl", { def = args.tuple[1] })) end local passes, n = 1, 1 @@ -8854,30 +8910,30 @@ a.types[i], b.types[i]), } local f = resolve_function_type(func, i) local fargs = f.args.tuple if f.is_method and not is_method then - if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then + if args.tuple[1] and self:is_a(args.tuple[1], fargs[1]) then if not is_typedecl_funcall then - add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") + self.errs:add_warning("hint", w, "invoked method as a regular function: consider using ':' instead of '.'") end else - return invalid_at(where, "invoked method as a regular function: use ':' instead of '.'") + return self.errs:invalid_at(w, "invoked method as a regular function: use ':' instead of '.'") end end local wanted = #fargs - local min_arity = feat_arity and f.min_arity or 0 + local min_arity = self.feat_arity and f.min_arity or 0 - if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) or + if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (self.feat_lax and given <= wanted))) or (passes == 3 and ((pass == 1 and given == wanted) or - (pass == 2 and given < wanted and (lax or given >= min_arity)) or + (pass == 2 and given < wanted and (self.feat_lax or given >= min_arity)) or (pass == 3 and f.args.is_va and given > wanted))) then - push_typeargs(f) + push_typeargs(self, f) - local matched, errs = check_args_rets(where, where_args, f, args, expected_rets, argdelta) + local matched, errs = check_args_rets(self, w, where_args, f, args, expected_rets, argdelta) if matched then return matched, f @@ -8886,23 +8942,23 @@ a.types[i], b.types[i]), } if expected_rets then - infer_emptytables(where, where_args, f.rets, f.rets, argdelta) + infer_emptytables(self, w, where_args, f.rets, f.rets, argdelta) end if passes == 3 then tried = tried or {} tried[i] = true - pop_typeargs(f) + pop_typeargs(self, f) end end end end end - return fail_call(where, func, given, first_errs) + return fail_call(self, w, func, given, first_errs) end - type_check_function_call = function(node, func, args, argdelta, e1, e2) + function TypeChecker:type_check_function_call(node, func, args, argdelta, e1, e2) e1 = e1 or node.e1 e2 = e2 or node.e2 @@ -8911,14 +8967,14 @@ a.types[i], b.types[i]), } if expected and expected.typename == "tuple" then expected_rets = expected else - expected_rets = a_type("tuple", { tuple = { node.expected } }) + expected_rets = a_type(node, "tuple", { tuple = { node.expected } }) end - begin_scope() + self:begin_scope() local is_typedecl_funcall - if node.kind == "op" and node.op.op == "@funcall" and node.e1 and node.e1.receiver then - local receiver = node.e1.receiver + if node.kind == "op" and node.op.op == "@funcall" and e1 and e1.receiver then + local receiver = e1.receiver if receiver.typename == "nominal" then local resolved = receiver.resolved if resolved and resolved.typename == "typedecl" then @@ -8927,12 +8983,12 @@ a.types[i], b.types[i]), } end end - local ret, f = check_call(node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) - ret = resolve_typevars_at(node, ret) - end_scope() + local ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) + ret = self:resolve_typevars_at(node, ret) + self:end_scope() - if tc and e1 then - tc.store_type(e1.y, e1.x, f) + if self.collector then + self.collector.store_type(e1.y, e1.x, f) end if f and f.macroexp then @@ -8943,9 +8999,9 @@ a.types[i], b.types[i]), } end end - local function check_metamethod(node, method_name, a, b, orig_a, orig_b) - if lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then - return UNKNOWN, nil + function TypeChecker:check_metamethod(node, method_name, a, b, orig_a, orig_b) + if self.feat_lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then + return a_type(node, "unknown", {}), nil end local ameta = a.fields and a.meta_fields local bmeta = b and b.fields and b.meta_fields @@ -8966,26 +9022,26 @@ a.types[i], b.types[i]), } if metamethod then local e2 = { node.e1 } - local args = a_type("tuple", { tuple = { orig_a } }) + local args = a_type(node, "tuple", { tuple = { orig_a } }) if b and method_name ~= "__is" then e2[2] = node.e2 args.tuple[2] = orig_b end - return to_structural(resolve_tuple((type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator + return self:to_structural(resolve_tuple((self:type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator else return nil, nil end end - local function match_record_key(tbl, rec, key) + function TypeChecker:match_record_key(tbl, rec, key) assert(type(tbl) == "table") assert(type(rec) == "table") assert(type(key) == "string") - tbl = to_structural(tbl) + tbl = self:to_structural(tbl) if tbl.typename == "string" or tbl.typename == "enum" then - tbl = find_var_type("string") + tbl = self:find_var_type("string") end if tbl.typename == "typedecl" then @@ -8994,13 +9050,13 @@ a.types[i], b.types[i]), } if tbl.is_nested_alias then return nil, "cannot use a nested type alias as a concrete value" else - tbl = resolve_nominal(tbl.alias_to) + tbl = self:resolve_nominal(tbl.alias_to) end end if tbl.typename == "union" then - local t = same_in_all_union_entries(tbl, function(t) - return (match_record_key(t, rec, key)) + local t = self:same_in_all_union_entries(tbl, function(t) + return (self:match_record_key(t, rec, key)) end) if t then @@ -9009,7 +9065,7 @@ a.types[i], b.types[i]), } end if (tbl.typename == "typevar" or tbl.typename == "typearg") and tbl.constraint then - local t = match_record_key(tbl.constraint, rec, key) + local t = self:match_record_key(tbl.constraint, rec, key) if t then return t @@ -9023,7 +9079,8 @@ a.types[i], b.types[i]), } return tbl.fields[key] end - local meta_t = check_metamethod(rec, "__index", tbl, STRING, tbl, STRING) + local str = a_type(rec, "string", {}) + local meta_t = self:check_metamethod(rec, "__index", tbl, str, tbl, str) if meta_t then return meta_t end @@ -9034,8 +9091,8 @@ a.types[i], b.types[i]), } return nil, "invalid key '" .. key .. "' in type %s" end elseif tbl.typename == "emptytable" or is_unknown(tbl) then - if lax then - return INVALID + if self.feat_lax then + return a_type(rec, "unknown", {}) end return nil, "cannot index a value of unknown type" end @@ -9047,30 +9104,35 @@ a.types[i], b.types[i]), } end end - local function widen_in_scope(scope, var) - assert(scope[var], "no " .. var .. " in scope") - local narrow_mode = scope[var].is_narrowed - if narrow_mode and narrow_mode ~= "declaration" then - if scope[var].narrowed_from then - scope[var].t = scope[var].narrowed_from - scope[var].narrowed_from = nil - scope[var].is_narrowed = nil - else - scope[var] = nil - end + function TypeChecker:widen_in_scope(scope, var) + local v = scope.vars[var] + assert(v, "no " .. var .. " in scope") + local narrow_mode = scope.vars[var].is_narrowed + if (not narrow_mode) or narrow_mode == "declaration" then + return false + end - local unresolved = get_unresolved(scope) - unresolved.narrows[var] = nil - return true + if v.narrowed_from then + v.t = v.narrowed_from + v.narrowed_from = nil + v.is_narrowed = nil + else + scope.vars[var] = nil + end + + if scope.narrows then + scope.narrows[var] = nil end - return false + + return true end - local function widen_back_var(name) + function TypeChecker:widen_back_var(name) local widened = false - for i = #st, 1, -1 do - if st[i][name] then - if widen_in_scope(st[i], name) then + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.vars[name] then + if self:widen_in_scope(scope, name) then widened = true else break @@ -9084,7 +9146,7 @@ a.types[i], b.types[i]), } local visit_node = { cbs = { ["assignment"] = { - after = function(node, _children) + after = function(_, node, _children) for _, v in ipairs(node.vars) do if v.kind == "variable" and v.tk == name then return true @@ -9094,7 +9156,7 @@ a.types[i], b.types[i]), } end, }, }, - after = function(_node, children, ret) + after = function(_, _node, children, ret) ret = ret or false for _, c in ipairs(children) do local ca = c @@ -9112,118 +9174,82 @@ a.types[i], b.types[i]), } end, } - return recurse_node(root, visit_node, visit_type) + return recurse_node(nil, root, visit_node, visit_type) end - local function widen_all_unions(node) - for i = #st, 1, -1 do - local scope = st[i] - local unresolved = find_unresolved(i) - if unresolved and unresolved.narrows then - for name, _ in pairs(unresolved.narrows) do + function TypeChecker:widen_all_unions(node) + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.narrows then + for name, _ in pairs(scope.narrows) do if not node or assigned_anywhere(name, node) then - widen_in_scope(scope, name) + self:widen_in_scope(scope, name) end end end end end - local function add_global(node, var, valtype, is_assigning) - if lax and is_unknown(valtype) and (var ~= "self" and var ~= "...") then - add_unknown(node, var) + function TypeChecker:add_global(node, varname, valtype, is_assigning) + if self.feat_lax and is_unknown(valtype) and (varname ~= "self" and varname ~= "...") then + self.errs:add_unknown(node, varname) end local is_const = node.attribute ~= nil - local existing, scope, existing_attr = find_var(var) + local existing, scope, existing_attr = self:find_var(varname) if existing then if scope > 1 then - error_at(node, "cannot define a global when a local with the same name is in scope") + self.errs:add(node, "cannot define a global when a local with the same name is in scope") elseif is_assigning and existing_attr then - error_at(node, "cannot reassign to <" .. existing_attr .. "> global: " .. var) + self.errs:add(node, "cannot reassign to <" .. existing_attr .. "> global: " .. varname) elseif existing_attr and not is_const then - error_at(node, "global was previously declared as <" .. existing_attr .. ">: " .. var) + self.errs:add(node, "global was previously declared as <" .. existing_attr .. ">: " .. varname) elseif (not existing_attr) and is_const then - error_at(node, "global was previously declared as not <" .. node.attribute .. ">: " .. var) - elseif valtype and not same_type(existing.t, valtype) then - error_at(node, "cannot redeclare global with a different type: previous type of " .. var .. " is %s", existing.t) + self.errs:add(node, "global was previously declared as not <" .. node.attribute .. ">: " .. varname) + elseif valtype and not self:same_type(existing.t, valtype) then + self.errs:add(node, "cannot redeclare global with a different type: previous type of " .. varname .. " is %s", existing.t) end return nil end - st[1][var] = { t = valtype, attribute = is_const and "const" or nil } - - return st[1][var] - end + local var = { t = valtype, attribute = is_const and "const" or nil } + self.st[1].vars[varname] = var - local get_rets - if lax then - get_rets = function(rets) - if #rets.tuple == 0 then - return a_vararg({ UNKNOWN }) - end - return rets - end - else - get_rets = function(rets) - return rets - end + return var end - local function add_internal_function_variables(node, args) - add_var(nil, "@is_va", args.is_va and ANY or NIL) - add_var(nil, "@return", node.rets or a_type("tuple", { tuple = {} })) + function TypeChecker:add_internal_function_variables(node, args) + self:add_var(nil, "@is_va", a_type(node, args.is_va and "any" or "nil", {})) + self:add_var(nil, "@return", node.rets or a_type(node, "tuple", { tuple = {} })) if node.typeargs then for _, t in ipairs(node.typeargs) do - local v = find_var(t.typearg, "check_only") + local v = self:find_var(t.typearg, "check_only") if not v or not v.used_as_type then - error_at(t, "type argument '%s' is not used in function signature", t) - end - end - end - end - - local function add_function_definition_for_recursion(node, fnargs) - add_var(nil, node.name.tk, type_at(node, a_function({ - min_arity = node.min_arity, - typeargs = node.typeargs, - args = fnargs, - rets = get_rets(node.rets), - }))) - end - - local function fail_unresolved() - local unresolved = st[#st]["@unresolved"] - if unresolved then - st[#st]["@unresolved"] = nil - local unrt = unresolved.t - for name, nodes in pairs(unrt.labels) do - for _, node in ipairs(nodes) do - error_at(node, "no visible label '" .. name .. "' for goto") - end - end - for name, types in pairs(unrt.nominals) do - if not unrt.global_types[name] then - for _, typ in ipairs(types) do - assert(typ.x) - assert(typ.y) - error_at(typ, "unknown type %s", typ) - end + self.errs:add(t, "type argument '%s' is not used in function signature", t) end end end end - local function end_function_scope(node) - fail_unresolved() - end_scope(node) + function TypeChecker:add_function_definition_for_recursion(node, fnargs) + self:add_var(nil, node.name.tk, a_function(node, { + min_arity = node.min_arity, + typeargs = node.typeargs, + args = fnargs, + rets = self.get_rets(node.rets), + })) + end + + function TypeChecker:end_function_scope(node) + self.errs:fail_unresolved_labels(self.st[#self.st]) + self:end_scope(node) end local function flatten_tuple(vals) local vt = vals.tuple local n_vals = #vt - local ret = a_type("tuple", { tuple = {} }) + local ret = a_type(vals, "tuple", { tuple = {} }) local rt = ret.tuple if n_vals == 0 then @@ -9251,9 +9277,9 @@ a.types[i], b.types[i]), } return ret end - local function get_assignment_values(vals, wanted) + local function get_assignment_values(w, vals, wanted) if vals == nil then - return a_type("tuple", { tuple = {} }) + return a_type(w, "tuple", { tuple = {} }) end local ret = flatten_tuple(vals) @@ -9272,14 +9298,14 @@ a.types[i], b.types[i]), } return ret end - local function match_all_record_field_names(node, a, field_names, errmsg) + function TypeChecker:match_all_record_field_names(node, a, field_names, errmsg) local t for _, k in ipairs(field_names) do local f = a.fields[k] if not t then t = f else - if not same_type(f, t) then + if not self:same_type(f, t) then errmsg = errmsg .. string.format(" (types of fields '%s' and '%s' do not match)", field_names[1], k) t = nil break @@ -9289,26 +9315,26 @@ a.types[i], b.types[i]), } if t then return t else - return invalid_at(node, errmsg) + return self.errs:invalid_at(node, errmsg) end end - local function type_check_index(anode, bnode, a, b) + function TypeChecker:type_check_index(anode, bnode, a, b) assert(not (a.typename == "tuple")) assert(not (b.typename == "tuple")) - local ra = resolve_typedecl(to_structural(a)) - local rb = to_structural(b) + local ra = resolve_typedecl(self:to_structural(a)) + local rb = self:to_structural(b) - if lax and is_unknown(a) then - return UNKNOWN + if self.feat_lax and is_unknown(a) then + return a end local errm local erra local errb - if ra.typename == "tupletable" and is_a(rb, INTEGER) then + if ra.typename == "tupletable" and rb.typename == "integer" then if bnode.constnum then if bnode.constnum >= 1 and bnode.constnum <= #ra.types and bnode.constnum == math.floor(bnode.constnum) then return ra.types[bnode.constnum] @@ -9316,38 +9342,35 @@ a.types[i], b.types[i]), } errm, erra = "index " .. tostring(bnode.constnum) .. " out of range for tuple %s", ra else - local array_type = arraytype_from_tuple(bnode, ra) + local array_type = self:arraytype_from_tuple(bnode, ra) if array_type then return array_type.elements end errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime" end - elseif ra.elements and is_a(rb, INTEGER) then + elseif ra.elements and rb.typename == "integer" then return ra.elements elseif ra.typename == "emptytable" then if ra.keys == nil then - ra.keys = infer_at(anode, b) + ra.keys = self:infer_at(bnode, b) end - if is_a(b, ra.keys) then - return type_at(anode, a_type("unresolved_emptytable_value", { + if self:is_a(b, ra.keys) then + return a_type(anode, "unresolved_emptytable_value", { emptytable_type = ra, - })) + }) end - errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " .. - ra.keys.inferred_at.filename .. ":" .. - ra.keys.inferred_at.y .. ":" .. - ra.keys.inferred_at.x .. ": )", b, ra.keys + errm, erra, errb = "inconsistent index type: got %s, expected %s" .. inferred_msg(ra.keys, "type of keys "), b, ra.keys elseif ra.typename == "map" then - if is_a(b, ra.keys) then + if self:is_a(b, ra.keys) then return ra.values end errm, erra, errb = "wrong index type: got %s, expected %s", b, ra.keys elseif rb.typename == "string" and rb.literal then - local t, e = match_record_key(a, anode, rb.literal) + local t, e = self:match_record_key(a, anode, rb.literal) if t then return t end @@ -9363,10 +9386,10 @@ a.types[i], b.types[i]), } end end if not errm then - return match_all_record_field_names(bnode, ra, field_names, + return self:match_all_record_field_names(bnode, ra, field_names, "cannot index, not all enum values map to record fields of the same type") end - elseif is_a(rb, STRING) then + elseif rb.typename == "string" then errm, erra = "cannot index object of type %s with a string, consider using an enum", a else errm, erra, errb = "cannot index object of type %s with %s", a, b @@ -9375,28 +9398,28 @@ a.types[i], b.types[i]), } errm, erra, errb = "cannot index object of type %s with %s", a, b end - local meta_t = check_metamethod(anode, "__index", ra, b, a, b) + local meta_t = self:check_metamethod(anode, "__index", ra, b, a, b) if meta_t then return meta_t end - return invalid_at(bnode, errm, erra, errb) + return self.errs:invalid_at(bnode, errm, erra, errb) end - expand_type = function(where, old, new) + function TypeChecker:expand_type(w, old, new) if not old or old.typename == "nil" then return new else - if not is_a(new, old) then + if not self:is_a(new, old) then if old.typename == "map" and new.fields then local old_keys = old.keys if old_keys.typename == "string" then for _, ftype in fields_of(new) do - old.values = expand_type(where, old.values, ftype) + old.values = self:expand_type(w, old.values, ftype) end - edit_type(old, "map") + edit_type(w, old, "map") else - error_at(where, "cannot determine table literal type") + self.errs:add(w, "cannot determine table literal type") end elseif old.fields and new.fields then local values @@ -9404,14 +9427,14 @@ a.types[i], b.types[i]), } if not values then values = ftype else - values = expand_type(where, values, ftype) + values = self:expand_type(w, values, ftype) end end for _, ftype in fields_of(new) do if not values then values = ftype else - values = expand_type(where, values, ftype) + values = self:expand_type(w, values, ftype) end end old.fields = nil @@ -9419,25 +9442,25 @@ a.types[i], b.types[i]), } old.meta_fields = nil old.meta_fields = nil - edit_type(old, "map") + edit_type(w, old, "map") local map = old - map.keys = STRING + map.keys = a_type(w, "string", {}) map.values = values elseif old.typename == "union" then - edit_type(old, "union") + edit_type(w, old, "union") table.insert(old.types, drop_constant_value(new)) else - return unite({ old, new }, true) + return unite(w, { old, new }, true) end end end return old end - local function find_record_to_extend(exp) + function TypeChecker:find_record_to_extend(exp) if exp.kind == "type_identifier" then - local v = find_var(exp.tk) + local v = self:find_var(exp.tk) if not v then return nil, nil, exp.tk end @@ -9454,7 +9477,7 @@ a.types[i], b.types[i]), } return t, v, exp.tk elseif exp.kind == "op" then - local t, v, rname = find_record_to_extend(exp.e1) + local t, v, rname = self:find_record_to_extend(exp.e1) local fname = exp.e2.tk local dname = rname .. "." .. fname if not t then @@ -9475,30 +9498,29 @@ a.types[i], b.types[i]), } end end - local function typedecl_to_nominal(where, name, t, resolved) + local function typedecl_to_nominal(node, name, t, resolved) local typevals local def = t.def if def.typeargs then typevals = {} for _, a in ipairs(def.typeargs) do - table.insert(typevals, a_type("typevar", { + table.insert(typevals, a_type(a, "typevar", { typevar = a.typearg, constraint = a.constraint, })) end end - return type_at(where, a_type("nominal", { - typevals = typevals, - names = { name }, - found = t, - resolved = resolved, - })) + local nom = a_nominal(node, { name }) + nom.typevals = typevals + nom.found = t + nom.resolved = resolved + return nom end - local function get_self_type(exp) + function TypeChecker:get_self_type(exp) if exp.kind == "type_identifier" then - local t = find_var_type(exp.tk) + local t = self:find_var_type(exp.tk) if not t then return nil end @@ -9510,7 +9532,7 @@ a.types[i], b.types[i]), } end elseif exp.kind == "op" then - local t = get_self_type(exp.e1) + local t = self:get_self_type(exp.e1) if not t then return nil end @@ -9542,7 +9564,6 @@ a.types[i], b.types[i]), } local facts_and local facts_or local facts_not - local apply_facts local FACT_TRUTHY do local IsFact_mt = { @@ -9554,6 +9575,7 @@ a.types[i], b.types[i]), } setmetatable(IsFact, { __call = function(_, fact) fact.fact = "is" + assert(fact.w) return setmetatable(fact, IsFact_mt) end, }) @@ -9567,6 +9589,7 @@ a.types[i], b.types[i]), } setmetatable(EqFact, { __call = function(_, fact) fact.fact = "==" + assert(fact.w) return setmetatable(fact, EqFact_mt) end, }) @@ -9625,57 +9648,57 @@ a.types[i], b.types[i]), } FACT_TRUTHY = TruthyFact({}) - facts_and = function(where, f1, f2) - return AndFact({ f1 = f1, f2 = f2, where = where }) + facts_and = function(w, f1, f2) + return AndFact({ f1 = f1, f2 = f2, w = w }) end - facts_or = function(where, f1, f2) + facts_or = function(w, f1, f2) if f1 and f2 then - return OrFact({ f1 = f1, f2 = f2, where = where }) + return OrFact({ f1 = f1, f2 = f2, w = w }) else return nil end end - facts_not = function(where, f1) + facts_not = function(w, f1) if f1 then - return NotFact({ f1 = f1, where = where }) + return NotFact({ f1 = f1, w = w }) else return nil end end - local function unite_types(t1, t2) - return unite({ t2, t1 }) + local function unite_types(w, t1, t2) + return unite(w, { t2, t1 }) end - local function intersect_types(t1, t2) + local function intersect_types(self, w, t1, t2) if t2.typename == "union" then t1, t2 = t2, t1 end if t1.typename == "union" then local out = {} for _, t in ipairs(t1.types) do - if is_a(t, t2) then + if self:is_a(t, t2) then table.insert(out, t) end end - return unite(out) + return unite(w, out) else - if is_a(t1, t2) then + if self:is_a(t1, t2) then return t1 - elseif is_a(t2, t1) then + elseif self:is_a(t2, t1) then return t2 else - return NIL + return a_type(w, "nil", {}) end end end - local function resolve_if_union(t) - local rt = to_structural(t) + function TypeChecker:resolve_if_union(t) + local rt = self:to_structural(t) if rt.typename == "union" then return rt end @@ -9683,23 +9706,23 @@ a.types[i], b.types[i]), } end - local function subtract_types(t1, t2) + local function subtract_types(self, w, t1, t2) local types = {} - t1 = resolve_if_union(t1) + t1 = self:resolve_if_union(t1) if not (t1.typename == "union") then return t1 end - t2 = resolve_if_union(t2) + t2 = self:resolve_if_union(t2) local t2types = t2.typename == "union" and t2.types or { t2 } for _, at in ipairs(t1.types) do local not_present = true for _, bt in ipairs(t2types) do - if same_type(at, bt) then + if self:same_type(at, bt) then not_present = false break end @@ -9710,10 +9733,10 @@ a.types[i], b.types[i]), } end if #types == 0 then - return NIL + return a_type(w, "nil", {}) end - return unite(types) + return unite(w, types) end local eval_not @@ -9723,65 +9746,65 @@ a.types[i], b.types[i]), } local eval_fact local function invalid_from(f) - return IsFact({ fact = "is", var = f.var, typ = INVALID, where = f.where }) + return IsFact({ fact = "is", var = f.var, typ = a_type(f.w, "invalid", {}), w = f.w }) end - not_facts = function(fs) + not_facts = function(self, fs) local ret = {} for var, f in pairs(fs) do - local typ = find_var_type(f.var, "check_only") + local typ = self:find_var_type(f.var, "check_only") if not typ then - ret[var] = EqFact({ var = var, typ = INVALID, where = f.where }) + ret[var] = EqFact({ var = var, typ = a_type(f.w, "invalid", {}), w = f.w, no_infer = f.no_infer }) elseif f.fact == "==" then - ret[var] = EqFact({ var = var, typ = typ }) + ret[var] = EqFact({ var = var, typ = typ, w = f.w, no_infer = true }) elseif typ.typename == "typevar" then assert(f.fact == "is") - ret[var] = EqFact({ var = var, typ = typ }) - elseif not is_a(f.typ, typ) then + ret[var] = EqFact({ var = var, typ = typ, w = f.w, no_infer = true }) + elseif not self:is_a(f.typ, typ) then assert(f.fact == "is") - add_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) - ret[var] = EqFact({ var = var, typ = INVALID, where = f.where }) + self.errs:add_warning("branch", f.w, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) + ret[var] = EqFact({ var = var, typ = a_type(f.w, "invalid", {}), w = f.w, no_infer = f.no_infer }) else assert(f.fact == "is") - ret[var] = IsFact({ var = var, typ = subtract_types(typ, f.typ), where = f.where }) + ret[var] = IsFact({ var = var, typ = subtract_types(self, f.w, typ, f.typ), w = f.w, no_infer = f.no_infer }) end end return ret end - eval_not = function(f) + eval_not = function(self, f) if not f then return {} elseif f.fact == "is" then - return not_facts({ [f.var] = f }) + return not_facts(self, { [f.var] = f }) elseif f.fact == "not" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f.fact == "and" and f.f2 and f.f2.fact == "truthy" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f.fact == "or" and f.f2 and f.f2.fact == "truthy" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f.fact == "and" then - return or_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) + return or_facts(self, not_facts(self, eval_fact(self, f.f1)), not_facts(self, eval_fact(self, f.f2))) elseif f.fact == "or" then - return and_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) + return and_facts(self, not_facts(self, eval_fact(self, f.f1)), not_facts(self, eval_fact(self, f.f2))) else - return not_facts(eval_fact(f)) + return not_facts(self, eval_fact(self, f)) end end - or_facts = function(fs1, fs2) + or_facts = function(_self, fs1, fs2) local ret = {} for var, f in pairs(fs2) do if fs1[var] then - local united = unite_types(f.typ, fs1[var].typ) + local united = unite_types(f.w, f.typ, fs1[var].typ) if fs1[var].fact == "is" and f.fact == "is" then - ret[var] = IsFact({ var = var, typ = united, where = f.where }) + ret[var] = IsFact({ var = var, typ = united, w = f.w }) else - ret[var] = EqFact({ var = var, typ = united, where = f.where }) + ret[var] = EqFact({ var = var, typ = united, w = f.w }) end end end @@ -9789,7 +9812,7 @@ a.types[i], b.types[i]), } return ret end - and_facts = function(fs1, fs2) + and_facts = function(self, fs1, fs2) local ret = {} local has = {} @@ -9800,18 +9823,18 @@ a.types[i], b.types[i]), } if fs2[var].fact == "is" and f.fact == "is" then ctor = IsFact end - rt = intersect_types(f.typ, fs2[var].typ) + rt = intersect_types(self, f.w, f.typ, fs2[var].typ) else rt = f.typ end - local ff = ctor({ var = var, typ = rt, where = f.where }) + local ff = ctor({ var = var, typ = rt, w = f.w, no_infer = f.no_infer }) ret[var] = ff has[ff.fact] = true end for var, f in pairs(fs2) do if not fs1[var] then - ret[var] = EqFact({ var = var, typ = f.typ, where = f.where }) + ret[var] = EqFact({ var = var, typ = f.typ, w = f.w, no_infer = f.no_infer }) has["=="] = true end end @@ -9825,21 +9848,21 @@ a.types[i], b.types[i]), } return ret end - eval_fact = function(f) + eval_fact = function(self, f) if not f then return {} elseif f.fact == "is" then - local typ = find_var_type(f.var, "check_only") + local typ = self:find_var_type(f.var, "check_only") if not typ then return { [f.var] = invalid_from(f) } end if typ.typename ~= "typevar" then - if is_a(typ, f.typ) then + if self:is_a(typ, f.typ) then return { [f.var] = f } - elseif not is_a(f.typ, typ) then - error_at(f.where, f.var .. " (of type %s) can never be a %s", typ, f.typ) + elseif not self:is_a(f.typ, typ) then + self.errs:add(f.w, f.var .. " (of type %s) can never be a %s", typ, f.typ) return { [f.var] = invalid_from(f) } end end @@ -9847,63 +9870,60 @@ a.types[i], b.types[i]), } elseif f.fact == "==" then return { [f.var] = f } elseif f.fact == "not" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f.fact == "truthy" then return {} elseif f.fact == "and" and f.f2 and f.f2.fact == "truthy" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f.fact == "or" and f.f2 and f.f2.fact == "truthy" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f.fact == "and" then - return and_facts(eval_fact(f.f1), eval_fact(f.f2)) + return and_facts(self, eval_fact(self, f.f1), eval_fact(self, f.f2)) elseif f.fact == "or" then - return or_facts(eval_fact(f.f1), eval_fact(f.f2)) + return or_facts(self, eval_fact(self, f.f1), eval_fact(self, f.f2)) end end - apply_facts = function(where, known) + function TypeChecker:apply_facts(w, known) if not known then return end - local facts = eval_fact(known) + local facts = eval_fact(self, known) for v, f in pairs(facts) do if f.typ.typename == "invalid" then - error_at(where, "cannot resolve a type for " .. v .. " here") + self.errs:add(w, "cannot resolve a type for " .. v .. " here") end - local t = infer_at(where, f.typ) - if not f.where then + local t = f.no_infer and f.typ or self:infer_at(w, f.typ) + if f.no_infer then t.inferred_at = nil end - add_var(nil, v, t, "const", "narrow") + self:add_var(nil, v, t, "const", "narrow") end end end - local function dismiss_unresolved(name) - for i = #st, 1, -1 do - local unresolved = find_unresolved(i) - if unresolved then - local uses = unresolved.nominals[name] - if uses then - for _, t in ipairs(uses) do - resolve_nominal(t) - end - unresolved.nominals[name] = nil - return + function TypeChecker:dismiss_unresolved(name) + for i = #self.st, 1, -1 do + local scope = self.st[i] + local uses = scope.pending_nominals and scope.pending_nominals[name] + if uses then + for _, t in ipairs(uses) do + self:resolve_nominal(t) end + scope.pending_nominals[name] = nil + return end end end - local type_check_funcall - - local function special_pcall_xpcall(node, _a, b, argdelta) + local function special_pcall_xpcall(self, node, _a, b, argdelta) local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 + local bool = a_type(node, "boolean", {}) if #node.e2 < base_nargs then - error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") - return a_type("tuple", { tuple = { BOOLEAN } }) + self.errs:add(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") + return a_type(node, "tuple", { tuple = { bool } }) end @@ -9915,137 +9935,142 @@ a.types[i], b.types[i]), } ftype.is_method = false end - local fe2 = {} + local fe2 = node_at(node.e2, {}) if node.e1.tk == "xpcall" then base_nargs = 2 + local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) - assert_is_a(node.e2[2], msgh, XPCALL_MSGH_FUNCTION, "in message handler") + local msgh_type = a_function(arg2, { + min_arity = 1, + args = a_type(arg2, "tuple", { tuple = { a_type(arg2, "any", {}) } }), + rets = a_type(arg2, "tuple", { tuple = {} }), + }) + self:assert_is_a(arg2, msgh, msgh_type, "in message handler") end for i = base_nargs + 1, #node.e2 do table.insert(fe2, node.e2[i]) end - local fnode = { - y = node.y, - x = node.x, + local fnode = node_at(node, { kind = "op", op = { op = "@funcall" }, e1 = node.e2[1], e2 = fe2, - } - local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) + }) + local rets = self:type_check_funcall(fnode, ftype, b, argdelta + base_nargs) if rets.typename == "invalid" then return rets end - table.insert(rets.tuple, 1, BOOLEAN) + table.insert(rets.tuple, 1, bool) return rets end local special_functions = { - ["pairs"] = function(node, a, b, argdelta) + ["pairs"] = function(self, node, a, b, argdelta) if not b.tuple[1] then - return invalid_at(node, "pairs requires an argument") + return self.errs:invalid_at(node, "pairs requires an argument") end - local t = to_structural(b.tuple[1]) + local t = self:to_structural(b.tuple[1]) if t.elements then - add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") + self.errs:add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") end if t.typename ~= "map" then - if not (lax and is_unknown(t)) then + if not (self.feat_lax and is_unknown(t)) then if t.fields then - match_all_record_field_names(node.e2, t, t.field_order, + self:match_all_record_field_names(node.e2, t, t.field_order, "attempting pairs on a record with attributes of different types") local ct = t.typename == "record" and "{string:any}" or "{any:any}" - add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) + self.errs:add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) else - error_at(node.e2, "cannot apply pairs on values of type: %s", t) + self.errs:add(node.e2, "cannot apply pairs on values of type: %s", t) end end end - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end, - ["ipairs"] = function(node, a, b, argdelta) + ["ipairs"] = function(self, node, a, b, argdelta) if not b.tuple[1] then - return invalid_at(node, "ipairs requires an argument") + return self.errs:invalid_at(node, "ipairs requires an argument") end local orig_t = b.tuple[1] - local t = to_structural(orig_t) + local t = self:to_structural(orig_t) if t.typename == "tupletable" then - local arr_type = arraytype_from_tuple(node.e2, t) + local arr_type = self:arraytype_from_tuple(node.e2, t) if not arr_type then - return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) + return self.errs:invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) end elseif not t.elements then - if not (lax and (is_unknown(t) or t.typename == "emptytable")) then - return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) + if not (self.feat_lax and (is_unknown(t) or t.typename == "emptytable")) then + return self.errs:invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) end end - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end, - ["rawget"] = function(node, _a, b, _argdelta) + ["rawget"] = function(self, node, _a, b, _argdelta) if #b.tuple == 2 then - return a_type("tuple", { tuple = { type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) } }) + return a_type(node, "tuple", { tuple = { self:type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) } }) else - return invalid_at(node, "rawget expects two arguments") + return self.errs:invalid_at(node, "rawget expects two arguments") end end, - ["require"] = function(node, _a, b, _argdelta) + ["require"] = function(self, node, _a, b, _argdelta) if #b.tuple ~= 1 then - return invalid_at(node, "require expects one literal argument") + return self.errs:invalid_at(node, "require expects one literal argument") end if node.e2[1].kind ~= "string" then - return invalid_at(node, "don't know how to resolve a dynamic require") + return self.errs:invalid_at(node, "don't know how to resolve a dynamic require") end local module_name = assert(node.e2[1].conststr) - local t, found = require_module(module_name, lax, env) - if not found then - return invalid_at(node, "module not found: '" .. module_name .. "'") - end + local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) if t.typename == "invalid" then - if lax then - return a_type("tuple", { tuple = { UNKNOWN } }) + if not module_filename then + return self.errs:invalid_at(node, "module not found: '" .. module_name .. "'") + end + + if self.feat_lax then + return a_type(node, "tuple", { tuple = { a_type(node, "unknown", {}) } }) end - return invalid_at(node, "no type information for required module: '" .. module_name .. "'") + return self.errs:invalid_at(node, "no type information for required module: '" .. module_name .. "'") end - dependencies[module_name] = t.filename - return type_at(node, a_type("tuple", { tuple = { t } })) + self.dependencies[module_name] = module_filename + return a_type(node, "tuple", { tuple = { t } }) end, ["pcall"] = special_pcall_xpcall, ["xpcall"] = special_pcall_xpcall, - ["assert"] = function(node, a, b, argdelta) + ["assert"] = function(self, node, a, b, argdelta) node.known = FACT_TRUTHY - local r = type_check_function_call(node, a, b, argdelta) - apply_facts(node, node.e2[1].known) + local r = self:type_check_function_call(node, a, b, argdelta) + self:apply_facts(node, node.e2[1].known) return r end, } - type_check_funcall = function(node, a, b, argdelta) + function TypeChecker:type_check_funcall(node, a, b, argdelta) argdelta = argdelta or 0 if node.e1.kind == "variable" then local special = special_functions[node.e1.tk] if special then - return special(node, a, b, argdelta) + return special(self, node, a, b, argdelta) else - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end elseif node.e1.op and node.e1.op.op == ":" then table.insert(b.tuple, 1, node.e1.receiver) - return (type_check_function_call(node, a, b, -1)) + return (self:type_check_function_call(node, a, b, -1)) else - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end end @@ -10057,19 +10082,19 @@ a.types[i], b.types[i]), } node.exps[i].tk == node.vars[i].tk end - local function missing_initializer(node, i, name) - if lax then - return UNKNOWN + function TypeChecker:missing_initializer(node, i, name) + if self.feat_lax then + return a_type(node, "unknown", {}) else if node.exps then - return invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") + return self.errs:invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") else - return invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") + return self.errs:invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") end end end - local function set_expected_types_to_decltuple(node, children) + local function set_expected_types_to_decltuple(_, node, children) local decltuple = node.kind == "assignment" and children[1] or node.decltuple assert(decltuple.typename == "tuple") local decls = decltuple.tuple @@ -10081,7 +10106,7 @@ a.types[i], b.types[i]), } typ = decls[i] if typ then if i == nexps and ndecl > nexps then - typ = type_at(node, a_type("tuple", { tuple = {} })) + typ = a_type(node, "tuple", { tuple = {} }) for a = i, ndecl do table.insert(typ.tuple, decls[a]) end @@ -10097,38 +10122,7 @@ a.types[i], b.types[i]), } return n and n >= 1 and math.floor(n) == n end - local context_name = { - ["local_declaration"] = "in local declaration", - ["global_declaration"] = "in global declaration", - ["assignment"] = "in assignment", - } - - local function in_context(ctx, msg) - if not ctx then - return msg - end - local where = context_name[ctx.kind] - if where then - return where .. ": " .. (ctx.name and ctx.name .. ": " or "") .. msg - else - return msg - end - end - - - - local function check_redeclared_key(where, ctx, seen_keys, key) - if key ~= nil then - local s = seen_keys[key] - if s then - error_at(where, in_context(ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. filename .. ":" .. s.y .. ":" .. s.x .. ")")) - else - seen_keys[key] = where - end - end - end - - local function infer_table_literal(node, children) + local function infer_table_literal(self, node, children) local is_record = false local is_array = false local is_map = false @@ -10153,14 +10147,15 @@ a.types[i], b.types[i]), } for i, child in ipairs(children) do local ck = child.kname + local cktype = child.ktype local n = node[i].key.constnum local b = nil - if child.ktype.typename == "boolean" then + if cktype.typename == "boolean" then b = (node[i].key.tk == "true") end local key = ck or n or b - check_redeclared_key(node[i], nil, seen_keys, key) + self.errs:check_redeclared_key(node[i], nil, seen_keys, key) local uvtype = resolve_tuple(child.vtype) if ck then @@ -10171,7 +10166,7 @@ a.types[i], b.types[i]), } end fields[ck] = uvtype table.insert(field_order, ck) - elseif is_number_type(child.ktype) then + elseif is_numeric_type(cktype) then is_array = true if not is_not_tuple then is_tuple = true @@ -10185,25 +10180,25 @@ a.types[i], b.types[i]), } if i == #children and cv.typename == "tuple" then for _, c in ipairs(cv.tuple) do - elements = expand_type(node, elements, c) + elements = self:expand_type(node, elements, c) types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 end else types[last_array_idx] = uvtype last_array_idx = last_array_idx + 1 - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) end else if not is_positive_int(n) then - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) is_not_tuple = true elseif n then types[n] = uvtype if n > largest_array_idx then largest_array_idx = n end - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) end end @@ -10215,37 +10210,37 @@ a.types[i], b.types[i]), } end else is_map = true - keys = expand_type(node, keys, drop_constant_value(child.ktype)) - values = expand_type(node, values, uvtype) + keys = self:expand_type(node, keys, drop_constant_value(cktype)) + values = self:expand_type(node, values, uvtype) end end local t if is_array and is_map then - error_at(node, "cannot determine type of table literal") - t = a_type("map", { keys = -expand_type(node, keys, INTEGER), values = + self.errs:add(node, "cannot determine type of table literal") + t = a_type(node, "map", { keys = +self:expand_type(node, keys, a_type(node, "integer", {})), values = -expand_type(node, values, elements) }) +self:expand_type(node, values, elements) }) elseif is_record and is_array then - t = a_type("record", { + t = a_type(node, "record", { fields = fields, field_order = field_order, elements = elements, interface_list = { - type_at(node, a_type("array", { elements = elements })), + a_type(node, "array", { elements = elements }), }, }) elseif is_record and is_map then if keys.typename == "string" then for _, fname in ipairs(field_order) do - values = expand_type(node, values, fields[fname]) + values = self:expand_type(node, values, fields[fname]) end - t = a_type("map", { keys = keys, values = values }) + t = a_type(node, "map", { keys = keys, values = values }) else - error_at(node, "cannot determine type of table literal") + self.errs:add(node, "cannot determine type of table literal") end elseif is_array then local pure_array = true @@ -10253,7 +10248,7 @@ expand_type(node, values, elements) }) local last_t for _, current_t in pairs(types) do if last_t then - if not same_type(last_t, current_t) then + if not self:same_type(last_t, current_t) then pure_array = false break end @@ -10262,69 +10257,70 @@ expand_type(node, values, elements) }) end end if pure_array then - t = a_type("array", { elements = elements }) + t = a_type(node, "array", { elements = elements }) t.consttypes = types t.inferred_len = largest_array_idx - 1 else - t = a_type("tupletable", {}) + t = a_type(node, "tupletable", { inferred_at = node }) t.types = types end elseif is_record then - t = a_type("record", { + t = a_type(node, "record", { fields = fields, field_order = field_order, }) elseif is_map then - t = a_type("map", { keys = keys, values = values }) + t = a_type(node, "map", { keys = keys, values = values }) elseif is_tuple then - t = a_type("tupletable", {}) + t = a_type(node, "tupletable", { inferred_at = node }) t.types = types if not types or #types == 0 then - error_at(node, "cannot determine type of tuple elements") + self.errs:add(node, "cannot determine type of tuple elements") end end if not t then - t = a_type("emptytable", {}) + t = a_type(node, "emptytable", {}) end return type_at(node, t) end - local function infer_negation_of_if_blocks(where, ifnode, n) - local f = facts_not(where, ifnode.if_blocks[1].exp.known) + function TypeChecker:infer_negation_of_if_blocks(w, ifnode, n) + local f = facts_not(w, ifnode.if_blocks[1].exp.known) for e = 2, n do local b = ifnode.if_blocks[e] if b.exp then - f = facts_and(where, f, facts_not(where, b.exp.known)) + f = facts_and(w, f, facts_not(w, b.exp.known)) end end - apply_facts(where, f) + self:apply_facts(w, f) end - local function determine_declaration_type(var, node, infertypes, i) + function TypeChecker:determine_declaration_type(var, node, infertypes, i) local ok = true local name = var.tk local infertype = infertypes and infertypes.tuple[i] - if lax and infertype and infertype.typename == "nil" then + if self.feat_lax and infertype and infertype.typename == "nil" then infertype = nil end local decltype = node.decltuple and node.decltuple.tuple[i] if decltype then - if to_structural(decltype) == INVALID then - decltype = INVALID + local rdecltype = self:to_structural(decltype) + if rdecltype.typename == "invalid" then + decltype = rdecltype end if infertype then - ok = assert_is_a(node.vars[i], infertype, decltype, context_name[node.kind], name) + local w = node.exps and node.exps[i] or node.vars[i] + ok = self:assert_is_a(w, infertype, decltype, context_name[node.kind], name) end else if infertype then if infertype.typename == "unresolvable_typearg" then - error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") ok = false - infertype = INVALID + infertype = self.errs:invalid_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") elseif infertype.typename == "function" and infertype.is_method then @@ -10336,17 +10332,17 @@ expand_type(node, values, elements) }) end if var.attribute == "total" then - local rd = decltype and to_structural(decltype) + local rd = decltype and self:to_structural(decltype) if rd and (rd.typename ~= "map" and rd.typename ~= "record") then - error_at(var, "attribute only applies to maps and records") + self.errs:add(var, "attribute only applies to maps and records") ok = false elseif not infertype then - error_at(var, "variable declared does not declare an initialization value") + self.errs:add(var, "variable declared does not declare an initialization value") ok = false else local valnode = node.exps[i] if not valnode or valnode.kind ~= "literal_table" then - error_at(var, "attribute only applies to literal tables") + self.errs:add(var, "attribute only applies to literal tables") ok = false else if not valnode.is_total then @@ -10354,12 +10350,12 @@ expand_type(node, values, elements) }) if valnode.missing then missing = " (missing: " .. table.concat(valnode.missing, ", ") .. ")" end - local ri = to_structural(infertype) + local ri = self:to_structural(infertype) if ri.typename == "map" then - error_at(var, "map variable declared does not declare values for all possible keys" .. missing) + self.errs:add(var, "map variable declared does not declare values for all possible keys" .. missing) ok = false elseif ri.typename == "record" then - error_at(var, "record variable declared does not declare values for all fields" .. missing) + self.errs:add(var, "record variable declared does not declare values for all fields" .. missing) ok = false end end @@ -10369,34 +10365,36 @@ expand_type(node, values, elements) }) local t = decltype or infertype if t == nil then - t = missing_initializer(node, i, name) + t = self:missing_initializer(node, i, name) elseif t.typename == "emptytable" then t.declared_at = node t.assigned_to = name elseif t.elements then t.inferred_len = nil + elseif t.typename == "nominal" then + self:resolve_nominal(t) end return ok, t, infertype ~= nil end - local function get_typedecl(value) + function TypeChecker:get_typedecl(value) if value.kind == "op" and value.op.op == "@funcall" and value.e1.kind == "variable" and value.e1.tk == "require" then - local t = special_functions["require"](value, find_var_type("require"), a_type("tuple", { tuple = { STRING } }), 0) + local t = special_functions["require"](self, value, self:find_var_type("require"), a_type(value.e2, "tuple", { tuple = { a_type(value.e2[1], "string", {}) } }), 0) local ty = t.typename == "tuple" and t.tuple[1] or t - ty = (ty.typename == "typealias") and resolve_typealias(ty) or ty - local td = (ty.typename == "typedecl") and ty or a_type("typedecl", { def = ty }) + ty = (ty.typename == "typealias") and self:resolve_typealias(ty) or ty + local td = (ty.typename == "typedecl") and ty or a_type(value, "typedecl", { def = ty }) return td else local newtype = value.newtype if newtype.typename == "typealias" then - local aliasing = find_var(newtype.alias_to.names[1], "use_type") - return resolve_typealias(newtype), aliasing - else + local aliasing = self:find_var(newtype.alias_to.names[1], "use_type") + return self:resolve_typealias(newtype), aliasing + elseif newtype.typename == "typedecl" then return newtype, nil end end @@ -10427,15 +10425,14 @@ expand_type(node, values, elements) }) return is_total, missing end - local function total_map_check(t, seen_keys) - local k = to_structural(t.keys) + local function total_map_check(keys, seen_keys) local is_total = true local missing - if k.typename == "enum" then - for _, key in ipairs(sorted_keys(k.enumset)) do + if keys.typename == "enum" then + for _, key in ipairs(sorted_keys(keys.enumset)) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end - elseif k.typename == "boolean" then + elseif keys.typename == "boolean" then for _, key in ipairs({ true, false }) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end @@ -10449,35 +10446,38 @@ expand_type(node, values, elements) }) - local function check_assignment(where, vartype, valtype, varname, attr) + function TypeChecker:check_assignment(varnode, vartype, valtype) + local varname = varnode.tk + local attr = varnode.attribute + if varname then - if widen_back_var(varname) then - vartype, attr = find_var_type(varname) + if self:widen_back_var(varname) then + vartype, attr = self:find_var_type(varname) if not vartype then - error_at(where, "unknown variable") + self.errs:add(varnode, "unknown variable") return nil end end end if attr == "close" or attr == "const" or attr == "total" then - error_at(where, "cannot assign to <" .. attr .. "> variable") + self.errs:add(varnode, "cannot assign to <" .. attr .. "> variable") return nil end - local var = to_structural(vartype) + local var = self:to_structural(vartype) if var.typename == "typedecl" or var.typename == "typealias" then - error_at(where, "cannot reassign a type") + self.errs:add(varnode, "cannot reassign a type") return nil end if not valtype then - error_at(where, "variable is not being assigned a value") + self.errs:add(varnode, "variable is not being assigned a value") return nil, nil, "missing" end - assert_is_a(where, valtype, vartype, "in assignment") + self:assert_is_a(varnode, valtype, vartype, "in assignment") - local val = to_structural(valtype) + local val = self:to_structural(valtype) return var, val end @@ -10493,181 +10493,182 @@ expand_type(node, values, elements) }) visit_node.cbs = { ["statements"] = { - before = function(node) - begin_scope(node) + before = function(self, node) + self:begin_scope(node) end, - after = function(node, _children) + after = function(self, node, _children) - if #st == 2 then - fail_unresolved() + if #self.st == 2 then + self.errs:fail_unresolved_labels(self.st[2]) + self.errs:fail_unresolved_nominals(self.st[2], self.st[1]) end if not node.is_repeat then - end_scope(node) + self:end_scope(node) end return NONE end, }, ["local_type"] = { - before = function(node) + before = function(self, node) local name = node.var.tk - local resolved, aliasing = get_typedecl(node.value) - local var = add_var(node.var, name, resolved, node.var.attribute) + local resolved, aliasing = self:get_typedecl(node.value) + local var = self:add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing end end, - after = function(node, _children) - dismiss_unresolved(node.var.tk) + after = function(self, node, _children) + self:dismiss_unresolved(node.var.tk) return NONE end, }, ["global_type"] = { - before = function(node) + before = function(self, node) + local global_scope = self.st[1] local name = node.var.tk - local unresolved = get_unresolved() if node.value then - local resolved, aliasing = get_typedecl(node.value) - local added = add_global(node.var, name, resolved) + local resolved, aliasing = self:get_typedecl(node.value) + local added = self:add_global(node.var, name, resolved) node.value.newtype = resolved if aliasing then added.aliasing = aliasing end - if added and unresolved.global_types[name] then - unresolved.global_types[name] = nil + if global_scope.pending_global_types[name] then + global_scope.pending_global_types[name] = nil end else - if not st[1][name] then - unresolved.global_types[name] = true + if not self.st[1].vars[name] then + global_scope.pending_global_types[name] = true end end end, - after = function(node, _children) - dismiss_unresolved(node.var.tk) + after = function(self, node, _children) + self:dismiss_unresolved(node.var.tk) return NONE end, }, ["local_declaration"] = { - before = function(node) - if tc then + before = function(self, node) + if self.collector then for _, var in ipairs(node.vars) do - tc.reserve_symbol_list_slot(var) + self.collector.reserve_symbol_list_slot(var) end end end, before_exp = set_expected_types_to_decltuple, - after = function(node, children) + after = function(self, node, children) local valtuple = children[3] local encountered_close = false - local infertypes = get_assignment_values(valtuple, #node.vars) + local infertypes = get_assignment_values(node, valtuple, #node.vars) for i, var in ipairs(node.vars) do if var.attribute == "close" then - if opts.gen_target == "5.4" then + if self.gen_target == "5.4" then if encountered_close then - error_at(var, "only one per declaration is allowed") + self.errs:add(var, "only one per declaration is allowed") else encountered_close = true end else - error_at(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(opts.gen_target) .. ")") + self.errs:add(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(self.gen_target) .. ")") end end - local ok, t = determine_declaration_type(var, node, infertypes, i) + local ok, t = self:determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then if not type_is_closable(t) then - error_at(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) + self.errs:add(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) elseif node.exps and node.exps[i] and expr_is_definitely_not_closable(node.exps[i]) then - error_at(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") + self.errs:add(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") end end assert(var) - add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") + self:add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") local infertype = infertypes.tuple[i] if ok and infertype then - local where = node.exps[i] or node.exps + local w = node.exps[i] or node.exps - local rt = to_structural(t) + local rt = self:to_structural(t) if (not (rt.typename == "enum")) and ((not (t.typename == "nominal")) or (rt.typename == "union")) and - not same_type(t, infertype) then + not self:same_type(t, infertype) then - t = infer_at(where, infertype) - add_var(where, var.tk, t, "const", "narrowed_declaration") + t = self:infer_at(w, infertype) + self:add_var(w, var.tk, t, "const", "narrowed_declaration") end end - if tc then - tc.store_type(var.y, var.x, t) + if self.collector then + self.collector.store_type(var.y, var.x, t) end - dismiss_unresolved(var.tk) + self:dismiss_unresolved(var.tk) end return NONE end, }, ["global_declaration"] = { before_exp = set_expected_types_to_decltuple, - after = function(node, children) + after = function(self, node, children) local valtuple = children[3] - local infertypes = get_assignment_values(valtuple, #node.vars) + local infertypes = get_assignment_values(node, valtuple, #node.vars) for i, var in ipairs(node.vars) do - local _, t, is_inferred = determine_declaration_type(var, node, infertypes, i) + local _, t, is_inferred = self:determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then - error_at(var, "globals may not be ") + self.errs:add(var, "globals may not be ") end - add_global(var, var.tk, t, is_inferred) + self:add_global(var, var.tk, t, is_inferred) - dismiss_unresolved(var.tk) + self:dismiss_unresolved(var.tk) end return NONE end, }, ["assignment"] = { before_exp = set_expected_types_to_decltuple, - after = function(node, children) + after = function(self, node, children) local vartuple = children[1] assert(vartuple.typename == "tuple") local vartypes = vartuple.tuple local valtuple = children[3] assert(valtuple.typename == "tuple") - local valtypes = get_assignment_values(valtuple, #vartypes) + local valtypes = get_assignment_values(node, valtuple, #vartypes) for i, vartype in ipairs(vartypes) do local varnode = node.vars[i] local varname = varnode.tk local valtype = valtypes.tuple[i] - local rvar, rval, err = check_assignment(varnode, vartype, valtype, varname, varnode.attribute) + local rvar, rval, err = self:check_assignment(varnode, vartype, valtype) if err == "missing" then if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then local msg = #valtuple.tuple == 1 and "only 1 value is returned by the function" or ("only " .. #valtuple.tuple .. " values are returned by the function") - add_warning("hint", varnode, msg) + self.errs:add_warning("hint", varnode, msg) end end if rval and rvar then if rval.typename == "function" then - widen_all_unions() + self:widen_all_unions() end if varname and (rvar.typename == "union" or rvar.typename == "interface") then - add_var(varnode, varname, rval, nil, "narrow") + self:add_var(varnode, varname, rval, nil, "narrow") end - if tc then - tc.store_type(varnode.y, varnode.x, valtype) + if self.collector then + self.collector.store_type(varnode.y, varnode.x, valtype) end end end @@ -10676,7 +10677,7 @@ expand_type(node, values, elements) }) end, }, ["if"] = { - after = function(node, _children) + after = function(self, node, _children) local all_return = true for _, b in ipairs(node.if_blocks) do if not b.block_returns then @@ -10686,26 +10687,26 @@ expand_type(node, values, elements) }) end if all_return then node.block_returns = true - infer_negation_of_if_blocks(node, node, #node.if_blocks) + self:infer_negation_of_if_blocks(node, node, #node.if_blocks) end return NONE end, }, ["if_block"] = { - before = function(node) - begin_scope(node) + before = function(self, node) + self:begin_scope(node) if node.if_block_n > 1 then - infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) + self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end end, - before_statements = function(node) + before_statements = function(self, node) if node.exp then - apply_facts(node.exp, node.exp.known) + self:apply_facts(node.exp, node.exp.known) end end, - after = function(node, _children) - end_scope(node) + after = function(self, node, _children) + self:end_scope(node) if #node.body > 0 and node.body[#node.body].block_returns then node.block_returns = true @@ -10715,76 +10716,96 @@ expand_type(node, values, elements) }) end, }, ["while"] = { - before = function(node) + before = function(self, node) - widen_all_unions(node) + self:widen_all_unions(node) end, - before_statements = function(node) - begin_scope(node) - apply_facts(node.exp, node.exp.known) + before_statements = function(self, node) + self:begin_scope(node) + self:apply_facts(node.exp, node.exp.known) end, after = end_scope_and_none_type, }, ["label"] = { - before = function(node) - - widen_all_unions() - local label_id = "::" .. node.label .. "::" - if st[#st][label_id] then - error_at(node, "label '" .. node.label .. "' already defined at " .. filename) - end - local unresolved = find_unresolved() - local var = add_var(node, label_id, type_at(node, a_type("none", {}))) - if unresolved then - if unresolved.labels[node.label] then - var.used = true + before = function(self, node) + + self:widen_all_unions() + local label_id = node.label + do + local scope = self.st[#self.st] + scope.labels = scope.labels or {} + if scope.labels[label_id] then + self.errs:add(node, "label '" .. node.label .. "' already defined") + else + scope.labels[label_id] = node end - unresolved.labels[node.label] = nil end + + + local scope = self.st[#self.st] + if scope.pending_labels and scope.pending_labels[label_id] then + node.used_label = true + scope.pending_labels[label_id] = nil + + end + end, after = function() return NONE end, }, ["goto"] = { - after = function(node, _children) - if not find_var_type("::" .. node.label .. "::") then - local unresolved = get_unresolved(st[#st]) - unresolved.labels[node.label] = unresolved.labels[node.label] or {} - table.insert(unresolved.labels[node.label], node) + after = function(self, node, _children) + local label_id = node.label + local found_label + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.labels and scope.labels[label_id] then + found_label = scope.labels[label_id] + break + end + end + + if found_label then + found_label.used_label = true + else + local scope = self.st[#self.st] + scope.pending_labels = scope.pending_labels or {} + scope.pending_labels[label_id] = scope.pending_labels[label_id] or {} + table.insert(scope.pending_labels[label_id], node) end return NONE end, }, ["repeat"] = { - before = function(node) + before = function(self, node) - widen_all_unions(node) + self:widen_all_unions(node) end, after = end_scope_and_none_type, }, ["forin"] = { - before = function(node) - begin_scope(node) + before = function(self, node) + self:begin_scope(node) end, - before_statements = function(node, children) + before_statements = function(self, node, children) local exptuple = children[2] assert(exptuple.typename == "tuple") local exptypes = exptuple.tuple - widen_all_unions(node) + self:widen_all_unions(node) local exp1 = node.exps[1] - local args = a_type("tuple", { tuple = { + local args = a_type(node.exps, "tuple", { tuple = { node.exps[2] and exptypes[2], node.exps[3] and exptypes[3], } }) - local exp1type = resolve_for_call(exptypes[1], args, false) + local exp1type = self:resolve_for_call(exptypes[1], args, false) if exp1type.typename == "poly" then local _ - _, exp1type = type_check_function_call(exp1, exp1type, args, 0, exp1, { node.exps[2], node.exps[3] }) + _, exp1type = self:type_check_function_call(exp1, exp1type, args, 0, exp1, { node.exps[2], node.exps[3] }) end if exp1type.typename == "function" then @@ -10797,69 +10818,69 @@ expand_type(node, values, elements) }) if rets.is_va then r = last else - r = lax and UNKNOWN or INVALID + r = self.feat_lax and a_type(v, "unknown", {}) or a_type(v, "invalid", {}) end end - add_var(v, v.tk, r) + self:add_var(v, v.tk, r) - if tc then - tc.store_type(v.y, v.x, r) + if self.collector then + self.collector.store_type(v.y, v.x, r) end last = r end local nrets = #rets.tuple - if (not lax) and (not rets.is_va and #node.vars > nrets) then + if (not self.feat_lax) and (not rets.is_va and #node.vars > nrets) then local at = node.vars[nrets + 1] local n_values = nrets == 1 and "1 value" or tostring(nrets) .. " values" - error_at(at, "too many variables for this iterator; it produces " .. n_values) + self.errs:add(at, "too many variables for this iterator; it produces " .. n_values) end else - if not (lax and is_unknown(exp1type)) then - error_at(exp1, "expression in for loop does not return an iterator") + if not (self.feat_lax and is_unknown(exp1type)) then + self.errs:add(exp1, "expression in for loop does not return an iterator") end end end, after = end_scope_and_none_type, }, ["fornum"] = { - before_statements = function(node, children) - widen_all_unions(node) - begin_scope(node) - local from_t = to_structural(resolve_tuple(children[2])) - local to_t = to_structural(resolve_tuple(children[3])) - local step_t = children[4] and to_structural(children[4]) - local t = (from_t.typename == "integer" and + before_statements = function(self, node, children) + self:widen_all_unions(node) + self:begin_scope(node) + local from_t = self:to_structural(resolve_tuple(children[2])) + local to_t = self:to_structural(resolve_tuple(children[3])) + local step_t = children[4] and self:to_structural(children[4]) + local typename = (from_t.typename == "integer" and to_t.typename == "integer" and (not step_t or step_t.typename == "integer")) and - INTEGER or - NUMBER - add_var(node.var, node.var.tk, t) + "integer" or + "number" + self:add_var(node.var, node.var.tk, a_type(node.var, typename, {})) end, after = end_scope_and_none_type, }, ["return"] = { - before = function(node) - local rets = find_var_type("@return") + before = function(self, node) + local rets = self:find_var_type("@return") if rets and rets.typename == "tuple" then for i, exp in ipairs(node.exps) do exp.expected = rets.tuple[i] end end end, - after = function(node, children) + after = function(self, node, children) local got = children[1] assert(got.typename == "tuple") local got_t = got.tuple local n_got = #got_t node.block_returns = true - local expected = find_var_type("@return") + local expected = self:find_var_type("@return") if not expected then - expected = infer_at(node, got) - module_type = drop_constant_value(to_structural(resolve_tuple(expected))) - st[2]["@return"] = { t = expected } + expected = self:infer_at(node, got) + self.module_type = drop_constant_value(self:to_structural(resolve_tuple(expected))) + self.st[2].vars["@return"] = { t = expected } end local expected_t = expected.tuple @@ -10874,8 +10895,8 @@ expand_type(node, values, elements) }) vatype = expected.is_va and expected.tuple[n_expected] end - if n_got > n_expected and (not lax) and not vatype then - error_at(node, what .. ": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) + if n_got > n_expected and (not self.feat_lax) and not vatype then + self.errs:add(node, what .. ": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) end if n_expected > 1 and @@ -10883,18 +10904,18 @@ expand_type(node, values, elements) }) node.exps[1].kind == "op" and (node.exps[1].op.op == "and" or node.exps[1].op.op == "or") and node.exps[1].discarded_tuple then - add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") + self.errs:add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") end for i = 1, n_got do local e = expected_t[i] or vatype if e then e = resolve_tuple(e) - local where = (node.exps[i] and node.exps[i].x) and + local w = (node.exps[i] and node.exps[i].x) and node.exps[i] or node.exps - assert(where and where.x) - assert_is_a(where, got_t[i], e, what) + assert(w and w.x) + self:assert_is_a(w, got_t[i], e, what) end end @@ -10902,25 +10923,28 @@ expand_type(node, values, elements) }) end, }, ["variable_list"] = { - after = function(node, children) - local tuple = a_type("tuple", { tuple = children }) + after = function(self, node, children) + local tuple = a_type(node, "tuple", { tuple = children }) tuple = flatten_tuple(tuple) for i, t in ipairs(tuple.tuple) do - ensure_not_abstract(node[i], t) + local ok, err = ensure_not_abstract(t) + if not ok then + self.errs:add(node[i], err) + end end return tuple end, }, ["literal_table"] = { - before = function(node) + before = function(self, node) if node.expected then - local decltype = to_structural(node.expected) + local decltype = self:to_structural(node.expected) if decltype.typename == "typevar" and decltype.constraint then - decltype = resolve_typedecl(to_structural(decltype.constraint)) + decltype = resolve_typedecl(self:to_structural(decltype.constraint)) end if decltype.typename == "tupletable" then @@ -10952,19 +10976,19 @@ expand_type(node, values, elements) }) end end end, - after = function(node, children) + after = function(self, node, children) node.known = FACT_TRUTHY if not node.expected then - return infer_table_literal(node, children) + return infer_table_literal(self, node, children) end - local decltype = to_structural(node.expected) + local decltype = self:to_structural(node.expected) local constraint if decltype.typename == "typevar" and decltype.constraint then constraint = resolve_typedecl(decltype.constraint) - decltype = to_structural(constraint) + decltype = self:to_structural(constraint) end if decltype.typename == "union" then @@ -10972,7 +10996,7 @@ expand_type(node, values, elements) }) local single_table_rt for _, t in ipairs(decltype.types) do - local rt = to_structural(t) + local rt = self:to_structural(t) if is_lua_table_type(rt) then if single_table_type then @@ -10993,7 +11017,7 @@ expand_type(node, values, elements) }) end if not is_lua_table_type(decltype) then - return infer_table_literal(node, children) + return infer_table_literal(self, node, children) end local force_array = nil @@ -11003,73 +11027,75 @@ expand_type(node, values, elements) }) for i, child in ipairs(children) do local cvtype = resolve_tuple(child.vtype) local ck = child.kname + local cktype = child.ktype local n = node[i].key.constnum local b = nil - if child.ktype.typename == "boolean" then + if cktype.typename == "boolean" then b = (node[i].key.tk == "true") end - check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b) + self.errs:check_redeclared_key(node[i], node, seen_keys, ck or n or b) if decltype.fields and ck then local df = decltype.fields[ck] if not df then - error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) + self.errs:add_in_context(node[i], node, "unknown field " .. ck) else if df.typename == "typedecl" or df.typename == "typealias" then - error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) + self.errs:add_in_context(node[i], node, "cannot reassign a type") else - assert_is_a(node[i], cvtype, df, "in record field", ck) + self:assert_is_a(node[i], cvtype, df, "in record field", ck) end end - elseif decltype.typename == "tupletable" and is_number_type(child.ktype) then + elseif decltype.typename == "tupletable" and is_numeric_type(cktype) then local dt = decltype.types[n] if not n then - error_at(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype) + self.errs:add_in_context(node[i], node, "unknown index in tuple %s", decltype) elseif not dt then - error_at(node[i], in_context(node.expected_context, "unexpected index " .. n .. " in tuple %s"), decltype) + self.errs:add_in_context(node[i], node, "unexpected index " .. n .. " in tuple %s", decltype) else - assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) + self:assert_is_a(node[i], cvtype, dt, node, "in tuple: at index " .. tostring(n)) end - elseif decltype.elements and is_number_type(child.ktype) then + elseif decltype.elements and is_numeric_type(cktype) then local cv = child.vtype if cv.typename == "tuple" and i == #children and node[i].key_parsed == "implicit" then for ti, tt in ipairs(cv.tuple) do - assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) + self:assert_is_a(node[i], tt, decltype.elements, node, "expected an array: at index " .. tostring(i + ti - 1)) end else - assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n)) + self:assert_is_a(node[i], cvtype, decltype.elements, node, "expected an array: at index " .. tostring(n)) end elseif node[i].key_parsed == "implicit" then if decltype.typename == "map" then - assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + self:assert_is_a(node[i].key, a_type(node[i].key, "integer", {}), decltype.keys, node, "in map key") + self:assert_is_a(node[i].value, cvtype, decltype.values, node, "in map value") end - force_array = expand_type(node[i], force_array, child.vtype) + force_array = self:expand_type(node[i], force_array, child.vtype) elseif decltype.typename == "map" then force_array = nil - assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + self:assert_is_a(node[i].key, cktype, decltype.keys, node, "in map key") + self:assert_is_a(node[i].value, cvtype, decltype.values, node, "in map value") else - error_at(node[i], in_context(node.expected_context, "unexpected key of type %s in table of type %s"), child.ktype, decltype) + self.errs:add_in_context(node[i], node, "unexpected key of type %s in table of type %s", cktype, decltype) end end local t if force_array then - t = infer_at(node, a_type("array", { elements = force_array })) + t = self:infer_at(node, a_type(node, "array", { elements = force_array })) else - t = resolve_typevars_at(node, node.expected) + t = self:resolve_typevars_at(node, node.expected) end if decltype.typename == "record" then - local rt = to_structural(t) + local rt = self:to_structural(t) if rt.typename == "record" then node.is_total, node.missing = total_record_check(decltype, seen_keys) end elseif decltype.typename == "map" then - local rt = to_structural(t) + local rt = self:to_structural(t) if rt.typename == "map" then - node.is_total, node.missing = total_map_check(decltype, seen_keys) + local rk = self:to_structural(rt.keys) + node.is_total, node.missing = total_map_check(rk, seen_keys) end end @@ -11081,13 +11107,13 @@ expand_type(node, values, elements) }) end, }, ["literal_table_item"] = { - after = function(node, children) + after = function(self, node, children) local kname = node.key.conststr local ktype = children[1] local vtype = children[2] if node.itemtype then vtype = node.itemtype - assert_is_a(node.value, children[2], node.itemtype, "in table item") + self:assert_is_a(node.value, children[2], node.itemtype, node) end if vtype.typename == "function" and vtype.is_method then @@ -11096,210 +11122,210 @@ expand_type(node, values, elements) }) vtype = shallow_copy_new_type(vtype) vtype.is_method = false end - return type_at(node, a_type("literal_table_item", { + return a_type(node, "literal_table_item", { kname = kname, ktype = ktype, vtype = vtype, - })) + }) end, }, ["local_function"] = { - before = function(node) - widen_all_unions() - if tc then - tc.reserve_symbol_list_slot(node) + before = function(self, node) + self:widen_all_unions() + if self.collector then + self.collector.reserve_symbol_list_slot(node) end - begin_scope(node) + self:begin_scope(node) end, - before_statements = function(node, children) + before_statements = function(self, node, children) local args = children[2] assert(args.typename == "tuple") - add_internal_function_variables(node, args) - add_function_definition_for_recursion(node, args) + self:add_internal_function_variables(node, args) + self:add_function_definition_for_recursion(node, args) end, - after = function(node, children) + after = function(self, node, children) local args = children[2] assert(args.typename == "tuple") local rets = children[3] assert(rets.typename == "tuple") - end_function_scope(node) + self:end_function_scope(node) - local t = type_at(node, ensure_fresh_typeargs(a_function({ + local t = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), - }))) + rets = self.get_rets(rets), + })) - add_var(node, node.name.tk, t) + self:add_var(node, node.name.tk, t) return t end, }, ["local_macroexp"] = { - before = function(node) - widen_all_unions() - if tc then - tc.reserve_symbol_list_slot(node) + before = function(self, node) + self:widen_all_unions() + if self.collector then + self.collector.reserve_symbol_list_slot(node) end - begin_scope(node) + self:begin_scope(node) end, - after = function(node, children) + after = function(self, node, children) local args = children[2] assert(args.typename == "tuple") local rets = children[3] assert(rets.typename == "tuple") - end_function_scope(node) + self:end_function_scope(node) - check_macroexp_arg_use(node.macrodef) + self:check_macroexp_arg_use(node.macrodef) - local t = type_at(node, ensure_fresh_typeargs(a_function({ + local t = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.macrodef.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), macroexp = node.macrodef, - }))) + })) - add_var(node, node.name.tk, t) + self:add_var(node, node.name.tk, t) return t end, }, ["global_function"] = { - before = function(node) - widen_all_unions() - begin_scope(node) + before = function(self, node) + self:widen_all_unions() + self:begin_scope(node) if node.implicit_global_function then - local typ = find_var_type(node.name.tk) + local typ = self:find_var_type(node.name.tk) if typ then if typ.typename == "function" then node.is_predeclared_local_function = true - elseif not lax then - error_at(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) + elseif not self.feat_lax then + self.errs:add(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) end - elseif not lax then - error_at(node, "functions need an explicit 'local' or 'global' annotation") + elseif not self.feat_lax then + self.errs:add(node, "functions need an explicit 'local' or 'global' annotation") end end end, - before_statements = function(node, children) + before_statements = function(self, node, children) local args = children[2] assert(args.typename == "tuple") - add_internal_function_variables(node, args) - add_function_definition_for_recursion(node, args) + self:add_internal_function_variables(node, args) + self:add_function_definition_for_recursion(node, args) end, - after = function(node, children) + after = function(self, node, children) local args = children[2] assert(args.typename == "tuple") local rets = children[3] assert(rets.typename == "tuple") - end_function_scope(node) + self:end_function_scope(node) if node.is_predeclared_local_function then return NONE end - add_global(node, node.name.tk, type_at(node, ensure_fresh_typeargs(a_function({ + self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), - })))) + rets = self.get_rets(rets), + }))) return NONE end, }, ["record_function"] = { - before = function(node) - widen_all_unions() - begin_scope(node) + before = function(self, node) + self:widen_all_unions() + self:begin_scope(node) end, - before_arguments = function(_node, children) - local rtype = to_structural(resolve_typedecl(children[1])) + before_arguments = function(self, _node, children) + local rtype = self:to_structural(resolve_typedecl(children[1])) if rtype.fields and rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, constraint = typ.constraint, - }))) + })) end end end, - before_statements = function(node, children) + before_statements = function(self, node, children) local args = children[3] assert(args.typename == "tuple") local rets = children[4] assert(rets.typename == "tuple") - local rtype = to_structural(resolve_typedecl(children[1])) + local rtype = self:to_structural(resolve_typedecl(children[1])) - if lax and rtype.typename == "unknown" then + if self.feat_lax and rtype.typename == "unknown" then return end if rtype.typename == "emptytable" then - edit_type(rtype, "record") + edit_type(rtype, rtype, "record") local r = rtype r.fields = {} r.field_order = {} end if not rtype.fields then - error_at(node, "not a record: %s", rtype) + self.errs:add(node, "not a record: %s", rtype) return end - local selftype = get_self_type(node.fn_owner) + local selftype = self:get_self_type(node.fn_owner) if node.is_method then if not selftype then - error_at(node, "could not resolve type of self") + self.errs:add(node, "could not resolve type of self") return end args.tuple[1] = selftype - add_var(nil, "self", selftype) + self:add_var(nil, "self", selftype) end - local fn_type = type_at(node, ensure_fresh_typeargs(a_function({ + local fn_type = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, is_method = node.is_method, typeargs = node.typeargs, args = args, - rets = get_rets(rets), - }))) + rets = self.get_rets(rets), + })) - local open_t, open_v, owner_name = find_record_to_extend(node.fn_owner) + local open_t, open_v, owner_name = self:find_record_to_extend(node.fn_owner) local open_k = owner_name .. "." .. node.name.tk local rfieldtype = rtype.fields[node.name.tk] if rfieldtype then - rfieldtype = to_structural(rfieldtype) + rfieldtype = self:to_structural(rfieldtype) if open_v and open_v.implemented and open_v.implemented[open_k] then - redeclaration_warning(node) + self.errs:redeclaration_warning(node) end - local ok, err = same_type(fn_type, rfieldtype) + local ok, err = self:same_type(fn_type, rfieldtype) if not ok then if rfieldtype.typename == "poly" then - add_errs_prefixing(node, err, errors, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability)") + self.errs:add_prefixing(node, err, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability): ") return end local shortname = selftype and show_type(selftype) or owner_name local msg = "type signature of '" .. node.name.tk .. "' does not match its declaration in " .. shortname .. ": " - add_errs_prefixing(node, err, errors, msg) + self.errs:add_prefixing(node, err, msg) return end else - if lax or rtype == open_t then + if self.feat_lax or rtype == open_t then rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) else - error_at(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") + self.errs:add(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") return end @@ -11312,82 +11338,82 @@ expand_type(node, values, elements) }) open_v.implemented[open_k] = true end - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node, _children) - end_function_scope(node) + after = function(self, node, _children) + self:end_function_scope(node) return NONE end, }, ["function"] = { - before = function(node) - widen_all_unions(node) - begin_scope(node) + before = function(self, node) + self:widen_all_unions(node) + self:begin_scope(node) end, - before_statements = function(node, children) + before_statements = function(self, node, children) local args = children[1] assert(args.typename == "tuple") - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node, children) + after = function(self, node, children) local args = children[1] assert(args.typename == "tuple") local rets = children[2] assert(rets.typename == "tuple") - end_function_scope(node) - return type_at(node, ensure_fresh_typeargs(a_function({ + self:end_function_scope(node) + return self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = rets, - }))) + })) end, }, ["macroexp"] = { - before = function(node) - widen_all_unions(node) - begin_scope(node) + before = function(self, node) + self:widen_all_unions(node) + self:begin_scope(node) end, - before_exp = function(node, children) + before_exp = function(self, node, children) local args = children[1] assert(args.typename == "tuple") - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node, children) + after = function(self, node, children) local args = children[1] assert(args.typename == "tuple") local rets = children[2] assert(rets.typename == "tuple") - end_function_scope(node) - return type_at(node, ensure_fresh_typeargs(a_function({ + self:end_function_scope(node) + return self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, rets = rets, - }))) + })) end, }, ["cast"] = { - after = function(node, _children) + after = function(_self, node, _children) return node.casttype end, }, ["paren"] = { - before = function(node) + before = function(_self, node) node.e1.expected = node.expected end, - after = function(node, children) + after = function(_self, node, children) node.known = node.e1 and node.e1.known return resolve_tuple(children[1]) end, }, ["op"] = { - before = function(node) - begin_scope() + before = function(self, node) + self:begin_scope() if node.expected then if node.op.op == "and" then node.e2.expected = node.expected @@ -11399,18 +11425,19 @@ expand_type(node, values, elements) }) end end end, - before_e2 = function(node, children) + before_e2 = function(self, node, children) local e1type = children[1] if node.op.op == "and" then - apply_facts(node, node.e1.known) + self:apply_facts(node, node.e1.known) elseif node.op.op == "or" then - apply_facts(node, facts_not(node, node.e1.known)) + self:apply_facts(node, facts_not(node, node.e1.known)) elseif node.op.op == "@funcall" then if e1type.typename == "function" then local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0 if node.expected then - is_a(e1type.rets, node.expected) + + self:is_a(e1type.rets, node.expected) end local e1args = e1type.args.tuple local at = argdelta @@ -11433,8 +11460,8 @@ expand_type(node, values, elements) }) end end end, - after = function(node, children) - end_scope() + after = function(self, node, children) + self:end_scope() local ga = children[1] @@ -11445,29 +11472,34 @@ expand_type(node, values, elements) }) local ub - local ra = to_structural(ua) + local ra = self:to_structural(ua) local rb if ra.typename == "circular_require" or (ra.typename == "typedecl" and ra.def and ra.def.typename == "circular_require") then - return invalid_at(node, "cannot dereference a type from a circular require") + return self.errs:invalid_at(node, "cannot dereference a type from a circular require") end if node.op.op == "@funcall" then - if lax and is_unknown(ua) then + if self.feat_lax and is_unknown(ua) then if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then - add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) + self.errs:add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) end end - local t = type_check_funcall(node, ua, gb) + assert(gb.typename == "tuple") + assert(node.f) + local t = self:type_check_funcall(node, ua, gb) return t elseif node.op.op == "as" then return gb end - local expected = node.expected and to_structural(resolve_tuple(node.expected)) + local expected = node.expected and self:to_structural(resolve_tuple(node.expected)) - ensure_not_abstract(node.e1, ra) + local ok, err = ensure_not_abstract(ra) + if not ok then + self.errs:add(node.e1, err) + end if ra.typename == "typedecl" and ra.def.typename == "record" then ra = ra.def end @@ -11476,8 +11508,11 @@ expand_type(node, values, elements) }) if gb then ub = resolve_tuple(gb) - rb = to_structural(ub) - ensure_not_abstract(node.e2, rb) + rb = self:to_structural(ub) + ok, err = ensure_not_abstract(rb) + if not ok then + self.errs:add(node.e2, err) + end if rb.typename == "typedecl" and rb.def.typename == "record" then rb = rb.def end @@ -11487,22 +11522,20 @@ expand_type(node, values, elements) }) node.receiver = ua assert(node.e2.kind == "identifier") - local bnode = { - y = node.e2.y, - x = node.e2.x, + local bnode = node_at(node.e2, { tk = node.e2.tk, kind = "string", - } - local btype = type_at(node.e2, a_type("string", { literal = node.e2.tk })) - local t = type_check_index(node.e1, bnode, ua, btype) + }) + local btype = a_type(node.e2, "string", { literal = node.e2.tk }) + local t = self:type_check_index(node.e1, bnode, ua, btype) - if t.needs_compat and opts.gen_compat ~= "off" then + if t.needs_compat and self.gen_compat ~= "off" then if node.e1.kind == "variable" and node.e2.kind == "identifier" then local key = node.e1.tk .. "." .. node.e2.tk node.kind = "variable" node.tk = "_tl_" .. node.e1.tk .. "_" .. node.e2.tk - all_needs_compat[key] = true + self.all_needs_compat[key] = true end end @@ -11510,22 +11543,22 @@ expand_type(node, values, elements) }) end if node.op.op == "@index" then - return type_check_index(node.e1, node.e2, ua, ub) + return self:type_check_index(node.e1, node.e2, ua, ub) end if node.op.op == "is" then if rb.typename == "integer" then - all_needs_compat["math"] = true + self.all_needs_compat["math"] = true end if ra.typename == "typedecl" then - error_at(node, "can only use 'is' on variables, not types") + self.errs:add(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then - check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) - node.known = IsFact({ var = node.e1.tk, typ = ub, where = node }) + self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) + node.known = IsFact({ var = node.e1.tk, typ = ub, w = node }) else - error_at(node, "can only use 'is' on variables") + self.errs:add(node, "can only use 'is' on variables") end - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.op == ":" then @@ -11533,16 +11566,16 @@ expand_type(node, values, elements) }) - if lax and (is_unknown(ua) or ua.typename == "typevar") then + if self.feat_lax and (is_unknown(ua) or ua.typename == "typevar") then if node.e1.kind == "variable" then - add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) + self.errs:add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) end - return UNKNOWN + return a_type(node, "unknown", {}) end - local t, e = match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) + local t, e = self:match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) if not t then - return invalid_at(node.e2, e, ua) + return self.errs:invalid_at(node.e2, e, ua) end return t @@ -11550,7 +11583,7 @@ expand_type(node, values, elements) }) if node.op.op == "not" then node.known = facts_not(node, node.e1.known) - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.op == "and" then @@ -11568,33 +11601,33 @@ expand_type(node, values, elements) }) node.known = nil t = ua - elseif ((ra.typename == "enum" and rb.typename == "string" and is_a(rb, ra)) or - (ra.typename == "string" and rb.typename == "enum" and is_a(ra, rb))) then + elseif ((ra.typename == "enum" and rb.typename == "string" and self:is_a(rb, ra)) or + (ra.typename == "string" and rb.typename == "enum" and self:is_a(ra, rb))) then node.known = nil t = (ra.typename == "enum" and ra or rb) elseif expected and expected.typename == "union" then node.known = facts_or(node, node.e1.known, node.e2.known) - local u = unite({ ra, rb }, true) + local u = unite(node, { ra, rb }, true) if u.typename == "union" then - local ok, err = is_valid_union(u) + ok, err = is_valid_union(u) if not ok then - u = err and invalid_at(node, err, u) or INVALID + u = err and self.errs:invalid_at(node, err, u) or a_type(node, "invalid", {}) end end t = u else - local a_ge_b = is_a(rb, ra) - local b_ge_a = is_a(ra, rb) + local a_ge_b = self:is_a(rb, ra) + local b_ge_a = self:is_a(ra, rb) if a_ge_b or b_ge_a then node.known = facts_or(node, node.e1.known, node.e2.known) if expected then - local a_is = is_a(ua, expected) - local b_is = is_a(ub, expected) + local a_is = self:is_a(ua, expected) + local b_is = self:is_a(ub, expected) if a_is and b_is then - t = resolve_typevars_at(node, expected) + t = self:resolve_typevars_at(node, expected) end end if not t then @@ -11618,39 +11651,41 @@ expand_type(node, values, elements) }) if ra.typename == "enum" and rb.typename == "string" then if not (rb.literal and ra.enumset[rb.literal]) then - return invalid_at(node, "%s is not a member of %s", ub, ua) + return self.errs:invalid_at(node, "%s is not a member of %s", ub, ua) end elseif ra.typename == "tupletable" and rb.typename == "tupletable" and #ra.types ~= #rb.types then - return invalid_at(node, "tuples are not the same size") - elseif is_a(ub, ua) or ua.typename == "typevar" then + return self.errs:invalid_at(node, "tuples are not the same size") + elseif self:is_a(ub, ua) or ua.typename == "typevar" then if node.op.op == "==" and node.e1.kind == "variable" then - node.known = EqFact({ var = node.e1.tk, typ = ub, where = node }) + node.known = EqFact({ var = node.e1.tk, typ = ub, w = node }) end - elseif is_a(ua, ub) or ub.typename == "typevar" then + elseif self:is_a(ua, ub) or ub.typename == "typevar" then if node.op.op == "==" and node.e2.kind == "variable" then - node.known = EqFact({ var = node.e2.tk, typ = ua, where = node }) + node.known = EqFact({ var = node.e2.tk, typ = ua, w = node }) end - elseif lax and (is_unknown(ua) or is_unknown(ub)) then - return UNKNOWN + elseif self.feat_lax and (is_unknown(ua) or is_unknown(ub)) then + return a_type(node, "unknown", {}) else - return invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) + return self.errs:invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) end - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.arity == 1 and unop_types[node.op.op] then if ra.typename == "union" then - ra = unite(ra.types, true) + ra = unite(node, ra.types, true) end local types_op = unop_types[node.op.op] - local t = types_op[ra.typename] + local tn = types_op[ra.typename] + local t = tn and a_type(node, tn, {}) if not t and ra.fields then t = find_in_interface_list(ra, function(ty) - return types_op[ty.typename] + local tname = types_op[ty.typename] + return tname and a_type(node, tname, {}) end) end @@ -11658,19 +11693,18 @@ expand_type(node, values, elements) }) if not t then local mt_name = unop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, ra, nil, ua, nil) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, nil, ua, nil) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) - t = INVALID + t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) end end if ra.typename == "map" then if ra.keys.typename == "number" or ra.keys.typename == "integer" then - add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") + self.errs:add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else - error_at(node, "using the '#' operator on this map will always return 0") + self.errs:add(node, "using the '#' operator on this map will always return 0") end end @@ -11678,12 +11712,12 @@ expand_type(node, values, elements) }) node.known = FACT_TRUTHY end - if node.op.op == "~" and env.gen_target == "5.1" then + if node.op.op == "~" and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, unop_to_metamethod[node.op.op], 1, node.e1) else - all_needs_compat["bit32"] = true + self.all_needs_compat["bit32"] = true convert_node_to_compat_call(node, "bit32", "bnot", node.e1) end end @@ -11697,39 +11731,39 @@ expand_type(node, values, elements) }) end if ra.typename == "union" then - ra = unite(ra.types, true) + ra = unite(ra, ra.types, true) end if rb.typename == "union" then - rb = unite(rb.types, true) + rb = unite(rb, rb.types, true) end local types_op = binop_types[node.op.op] - local t = types_op[ra.typename] and types_op[ra.typename][rb.typename] + local tn = types_op[ra.typename] and types_op[ra.typename][rb.typename] + local t = tn and a_type(node, tn, {}) local meta_on_operator if not t then local mt_name = binop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, ra, rb, ua, ub) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, rb, ua, ub) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) - t = INVALID + t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) if node.op.op == "or" then - local u = unite({ ua, ub }) + local u = unite(node, { ua, ub }) if u.typename == "union" and is_valid_union(u) then - add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") + self.errs:add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") end end end end if ua.typename == "nominal" and ub.typename == "nominal" and not meta_on_operator then - if is_a(ua, ub) then + if self:is_a(ua, ub) then t = ua else - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) + self.errs:add(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) end end @@ -11737,20 +11771,20 @@ expand_type(node, values, elements) }) node.known = FACT_TRUTHY end - if node.op.op == "//" and env.gen_target == "5.1" then + if node.op.op == "//" and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, "__idiv", meta_on_operator, node.e1, node.e2) else - local div = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 } + local div = node_at(node, { kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 }) convert_node_to_compat_call(node, "math", "floor", div) end - elseif bit_operators[node.op.op] and env.gen_target == "5.1" then + elseif bit_operators[node.op.op] and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, binop_to_metamethod[node.op.op], meta_on_operator, node.e1, node.e2) else - all_needs_compat["bit32"] = true + self.all_needs_compat["bit32"] = true convert_node_to_compat_call(node, "bit32", bit_operators[node.op.op], node.e1, node.e2) end end @@ -11762,28 +11796,28 @@ expand_type(node, values, elements) }) end, }, ["variable"] = { - after = function(node, _children) + after = function(self, node, _children) if node.tk == "..." then - local va_sentinel = find_var_type("@is_va") + local va_sentinel = self:find_var_type("@is_va") if not va_sentinel or va_sentinel.typename == "nil" then - return invalid_at(node, "cannot use '...' outside a vararg function") + return self.errs:invalid_at(node, "cannot use '...' outside a vararg function") end end local t if node.tk == "_G" then - t, node.attribute = simulate_g() + t, node.attribute = self:simulate_g() else local use = node.is_lvalue and "lvalue" or "use" - t, node.attribute = find_var_type(node.tk, use) + t, node.attribute = self:find_var_type(node.tk, use) end if not t then - if lax then - add_unknown(node, node.tk) - return UNKNOWN + if self.feat_lax then + self.errs:add_unknown(node, node.tk) + return a_type(node, "unknown", {}) end - return invalid_at(node, "unknown variable: " .. node.tk) + return self.errs:invalid_at(node, "unknown variable: " .. node.tk) end if t.typename == "typedecl" then @@ -11794,70 +11828,70 @@ expand_type(node, values, elements) }) end, }, ["type_identifier"] = { - after = function(node, _children) - local typ, attr = find_var_type(node.tk) + after = function(self, node, _children) + local typ, attr = self:find_var_type(node.tk) node.attribute = attr if typ then return typ end - if lax then - add_unknown(node, node.tk) - return UNKNOWN + if self.feat_lax then + self.errs:add_unknown(node, node.tk) + return a_type(node, "unknown", {}) end - return invalid_at(node, "unknown variable: " .. node.tk) + return self.errs:invalid_at(node, "unknown variable: " .. node.tk) end, }, ["argument"] = { - after = function(node, children) + after = function(self, node, children) local t = children[1] if not t then - t = UNKNOWN + t = a_type(node, "unknown", {}) end if node.tk == "..." then - t = a_vararg({ t }) + t = a_vararg(node, { t }) end - add_var(node, node.tk, t).is_func_arg = true + self:add_var(node, node.tk, t).is_func_arg = true return t end, }, ["identifier"] = { - after = function(_node, _children) + after = function(_self, _node, _children) return NONE end, }, ["newtype"] = { - after = function(node, _children) + after = function(_self, node, _children) return node.newtype end, }, ["error_node"] = { - after = function(_node, _children) - return INVALID + after = function(_self, node, _children) + return a_type(node, "invalid", {}) end, }, } visit_node.cbs["break"] = { - after = function(_node, _children) + after = function(_self, _node, _children) return NONE end, } visit_node.cbs["do"] = visit_node.cbs["break"] - local function after_literal(node) + local function after_literal(_self, node) node.known = FACT_TRUTHY - return type_at(node, a_type(node.kind, {})) + return a_type(node, node.kind, {}) end visit_node.cbs["string"] = { - after = function(node, _children) - local t = after_literal(node) + after = function(self, node, _children) + local t = after_literal(self, node) t.literal = node.conststr - local expected = node.expected and to_structural(node.expected) - if expected and expected.typename == "enum" and is_a(t, expected) then + local expected = node.expected and self:to_structural(node.expected) + if expected and expected.typename == "enum" and self:is_a(t, expected) then return node.expected end @@ -11868,8 +11902,8 @@ expand_type(node, values, elements) }) visit_node.cbs["integer"] = { after = after_literal } visit_node.cbs["boolean"] = { - after = function(node, _children) - local t = after_literal(node) + after = function(self, node, _children) + local t = after_literal(self, node) node.known = (node.tk == "true") and FACT_TRUTHY or nil return t end, @@ -11880,7 +11914,7 @@ expand_type(node, values, elements) }) visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"] visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] - visit_node.after = function(node, _children, t) + visit_node.after = function(_self, node, _children, t) if node.expanded then apply_macroexp(node) end @@ -11888,13 +11922,12 @@ expand_type(node, values, elements) }) return t end - local expand_interfaces do - local function add_interface_fields(what, fields, field_order, resolved, named, list) + local function add_interface_fields(self, what, fields, field_order, resolved, named, list) for fname, ftype in fields_of(resolved, list) do if fields[fname] then - if not is_a(fields[fname], ftype) then - error_at(fields[fname], what .. " '" .. fname .. "' does not match definition in interface %s", named) + if not self:is_a(fields[fname], ftype) then + self.errs:add(fields[fname], what .. " '" .. fname .. "' does not match definition in interface %s", named) end else table.insert(field_order, fname) @@ -11903,18 +11936,21 @@ expand_type(node, values, elements) }) end end - local function collect_interfaces(list, t, seen) + local function collect_interfaces(self, list, t, seen) if t.interface_list then for _, iface in ipairs(t.interface_list) do if iface.typename == "nominal" then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) if not (ri.typename == "invalid") then - assert(ri.typename == "interface", "nominal resolved to " .. ri.typename) - if not ri.interfaces_expanded and not seen[ri] then - seen[ri] = true - collect_interfaces(list, ri, seen) + if ri.typename == "interface" then + if not ri.interfaces_expanded and not seen[ri] then + seen[ri] = true + collect_interfaces(self, list, ri, seen) + end + table.insert(list, iface) + else + self.errs:add(iface, "attempted to use %s as interface, but its type is %s", iface, ri) end - table.insert(list, iface) end else if not seen[iface] then @@ -11927,30 +11963,30 @@ expand_type(node, values, elements) }) return list end - expand_interfaces = function(t) + function TypeChecker:expand_interfaces(t) if t.interfaces_expanded then return end t.interfaces_expanded = true - t.interface_list = collect_interfaces({}, t, {}) + t.interface_list = collect_interfaces(self, {}, t, {}) for _, iface in ipairs(t.interface_list) do if iface.typename == "nominal" then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) assert(ri.typename == "interface") - add_interface_fields("field", t.fields, t.field_order, ri, iface) + add_interface_fields(self, "field", t.fields, t.field_order, ri, iface) if ri.meta_fields then t.meta_fields = t.meta_fields or {} t.meta_field_order = t.meta_field_order or {} - add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") + add_interface_fields(self, "metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") end else if not t.elements then t.elements = iface else - if not same_type(iface.elements, t.elements) then - error_at(t, "incompatible array interfaces") + if not self:same_type(iface.elements, t.elements) then + self.errs:add(t, "incompatible array interfaces") end end end @@ -11962,29 +11998,29 @@ expand_type(node, values, elements) }) visit_type = { cbs = { ["function"] = { - before = function(_typ) - begin_scope() + before = function(self, _typ) + self:begin_scope() end, - after = function(typ, _children) - end_scope() - return ensure_fresh_typeargs(typ) + after = function(self, typ, _children) + self:end_scope() + return self:ensure_fresh_typeargs(typ) end, }, ["record"] = { - before = function(typ) - begin_scope() - add_var(nil, "@self", type_at(typ, a_type("typedecl", { def = typ }))) + before = function(self, typ) + self:begin_scope() + self:add_var(nil, "@self", type_at(typ, a_type(typ, "typedecl", { def = typ }))) for fname, ftype in fields_of(typ) do if ftype.typename == "typealias" then - resolve_nominal(ftype.alias_to) - add_var(nil, fname, ftype) + self:resolve_nominal(ftype.alias_to) + self:add_var(nil, fname, ftype) elseif ftype.typename == "typedecl" then - add_var(nil, fname, ftype) + self:add_var(nil, fname, ftype) end end end, - after = function(typ, children) + after = function(self, typ, children) local i = 1 if typ.typeargs then for _, _ in ipairs(typ.typeargs) do @@ -11998,11 +12034,11 @@ expand_type(node, values, elements) }) if iface.typename == "array" then typ.interface_list[j] = iface elseif iface.typename == "nominal" then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) if ri.typename == "interface" then typ.interface_list[j] = iface else - error_at(children[i], "%s is not an interface", children[i]) + self.errs:add(children[i], "%s is not an interface", children[i]) end end i = i + 1 @@ -12042,7 +12078,7 @@ expand_type(node, values, elements) }) end end elseif ftype.typename == "typealias" then - resolve_typealias(ftype) + self:resolve_typealias(ftype) end typ.fields[name] = ftype @@ -12061,55 +12097,55 @@ expand_type(node, values, elements) }) end if typ.interface_list then - expand_interfaces(typ) + self:expand_interfaces(typ) end if fmacros then for _, t in ipairs(fmacros) do - local macroexp_type = recurse_node(t.macroexp, visit_node, visit_type) + local macroexp_type = recurse_node(self, t.macroexp, visit_node, visit_type) - check_macroexp_arg_use(t.macroexp) + self:check_macroexp_arg_use(t.macroexp) - if not is_a(macroexp_type, t) then - error_at(macroexp_type, "macroexp type does not match declaration") + if not self:is_a(macroexp_type, t) then + self.errs:add(macroexp_type, "macroexp type does not match declaration") end end end - end_scope() + self:end_scope() return typ end, }, ["typearg"] = { - after = function(typ, _children) - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + after = function(self, typ, _children) + self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, constraint = typ.constraint, - }))) + })) return typ end, }, ["typevar"] = { - after = function(typ, _children) - if not find_var_type(typ.typevar) then - error_at(typ, "undefined type variable " .. typ.typevar) + after = function(self, typ, _children) + if not self:find_var_type(typ.typevar) then + self.errs:add(typ, "undefined type variable " .. typ.typevar) end return typ end, }, ["nominal"] = { - after = function(typ, _children) + after = function(self, typ, _children) if typ.found then return typ end - local t = find_type(typ.names, true) + local t = self:find_type(typ.names, true) if t then if t.typename == "typearg" then typ.names = nil - edit_type(typ, "typevar") + edit_type(typ, typ, "typevar") local tv = typ tv.typevar = t.typearg tv.constraint = t.constraint @@ -12120,18 +12156,19 @@ expand_type(node, values, elements) }) end else local name = typ.names[1] - local unresolved = get_unresolved() - unresolved.nominals[name] = unresolved.nominals[name] or {} - table.insert(unresolved.nominals[name], typ) + local scope = self.st[#self.st] + scope.pending_nominals = scope.pending_nominals or {} + scope.pending_nominals[name] = scope.pending_nominals[name] or {} + table.insert(scope.pending_nominals[name], typ) end return typ end, }, ["union"] = { - after = function(typ, _children) + after = function(self, typ, _children) local ok, err = is_valid_union(typ) if not ok then - return err and invalid_at(typ, err, typ) or INVALID + return err and self.errs:invalid_at(typ, err, typ) or a_type(typ, "invalid", {}) end return typ end, @@ -12139,15 +12176,47 @@ expand_type(node, values, elements) }) }, } + local default_type_visitor = { + after = function(_self, typ, _children) + return typ + end, + } + + visit_type.cbs["interface"] = visit_type.cbs["record"] + + visit_type.cbs["string"] = default_type_visitor + visit_type.cbs["tupletable"] = default_type_visitor + visit_type.cbs["typedecl"] = default_type_visitor + visit_type.cbs["typealias"] = default_type_visitor + visit_type.cbs["array"] = default_type_visitor + visit_type.cbs["map"] = default_type_visitor + visit_type.cbs["enum"] = default_type_visitor + visit_type.cbs["boolean"] = default_type_visitor + visit_type.cbs["nil"] = default_type_visitor + visit_type.cbs["number"] = default_type_visitor + visit_type.cbs["integer"] = default_type_visitor + visit_type.cbs["thread"] = default_type_visitor + visit_type.cbs["emptytable"] = default_type_visitor + visit_type.cbs["literal_table_item"] = default_type_visitor + visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor + visit_type.cbs["tuple"] = default_type_visitor + visit_type.cbs["poly"] = default_type_visitor + visit_type.cbs["any"] = default_type_visitor + visit_type.cbs["unknown"] = default_type_visitor + visit_type.cbs["invalid"] = default_type_visitor + visit_type.cbs["none"] = default_type_visitor + + + local function internal_compiler_check(fn) - return function(w, children, t) - t = fn and fn(w, children, t) or t + return function(s, n, children, t) + t = fn and fn(s, n, children, t) or t if type(t) ~= "table" then - error(((w).kind or (w).typename) .. " did not produce a type") + error(((n).kind or (n).typename) .. " did not produce a type") end if type(t.typename) ~= "string" then - error(((w).kind or (w).typename) .. " type does not have a typename") + error(((n).kind or (n).typename) .. " type does not have a typename") end return t @@ -12155,13 +12224,13 @@ expand_type(node, values, elements) }) end local function store_type_after(fn) - return function(w, children, t) - t = fn and fn(w, children, t) or t + return function(self, n, children, t) + t = fn and fn(self, n, children, t) or t - local where = w + local w = n - if where.y then - tc.store_type(where.y, where.x, t) + if w.y then + self.collector.store_type(w.y, w.x, t) end return t @@ -12169,119 +12238,167 @@ expand_type(node, values, elements) }) end local function debug_type_after(fn) - return function(node, children, t) - t = fn and fn(node, children, t) or t + return function(s, node, children, t) + t = fn and fn(s, node, children, t) or t + node.debug_type = t return t end end - if opts.run_internal_compiler_checks then - visit_node.after = internal_compiler_check(visit_node.after) - visit_type.after = internal_compiler_check(visit_type.after) - end + local function patch_visitors(my_visit_node, + after_node, + my_visit_type, + after_type) - if tc then - visit_node.after = store_type_after(visit_node.after) - visit_type.after = store_type_after(visit_type.after) + + if my_visit_node == visit_node then + my_visit_node = shallow_copy_table(my_visit_node) + end + my_visit_node.after = after_node(my_visit_node.after) + if my_visit_type then + if my_visit_type == visit_type then + my_visit_type = shallow_copy_table(my_visit_type) + end + my_visit_type.after = after_type(my_visit_type.after) + else + my_visit_type = visit_type + end + return my_visit_node, my_visit_type end - if TL_DEBUG then - visit_node.after = debug_type_after(visit_node.after) + local function set_feat(feat, default) + if feat then + return (feat == "on") + else + return default + end end - local default_type_visitor = { - after = function(typ, _children) - return typ - end, - } + tl.type_check = function(ast, filename, opts, env) + assert(type(filename) == "string", "tl.type_check signature has changed, pass filename separately") + assert((not opts) or (not (opts).env), "tl.type_check signature has changed, pass env separately") - visit_type.cbs["interface"] = visit_type.cbs["record"] + filename = filename or "?" - visit_type.cbs["string"] = default_type_visitor - visit_type.cbs["tupletable"] = default_type_visitor - visit_type.cbs["typedecl"] = default_type_visitor - visit_type.cbs["typealias"] = default_type_visitor - visit_type.cbs["array"] = default_type_visitor - visit_type.cbs["map"] = default_type_visitor - visit_type.cbs["enum"] = default_type_visitor - visit_type.cbs["boolean"] = default_type_visitor - visit_type.cbs["nil"] = default_type_visitor - visit_type.cbs["number"] = default_type_visitor - visit_type.cbs["integer"] = default_type_visitor - visit_type.cbs["thread"] = default_type_visitor - visit_type.cbs["emptytable"] = default_type_visitor - visit_type.cbs["literal_table_item"] = default_type_visitor - visit_type.cbs["unresolved_emptytable_value"] = default_type_visitor - visit_type.cbs["tuple"] = default_type_visitor - visit_type.cbs["poly"] = default_type_visitor - visit_type.cbs["any"] = default_type_visitor - visit_type.cbs["unknown"] = default_type_visitor - visit_type.cbs["invalid"] = default_type_visitor - visit_type.cbs["unresolved"] = default_type_visitor - visit_type.cbs["none"] = default_type_visitor + opts = opts or {} + + if not env then + local err + env, err = tl.new_env({ defaults = opts }) + if err then + return nil, err + end + end - assert(ast.kind == "statements") - recurse_node(ast, visit_node, visit_type) + local self = { + filename = filename, + env = env, + st = { + { + vars = env.globals, + pending_global_types = {}, + }, + }, + errs = Errors.new(filename), + all_needs_compat = {}, + dependencies = {}, + subtype_relations = TypeChecker.subtype_relations, + eqtype_relations = TypeChecker.eqtype_relations, + type_priorities = TypeChecker.type_priorities, + } - close_types(st[1]) - check_for_unused_vars(st[1], true) + setmetatable(self, { __index = TypeChecker }) - clear_redundant_errors(errors) + self.feat_lax = set_feat(opts.feat_lax or env.defaults.feat_lax, false) + self.feat_arity = set_feat(opts.feat_arity or env.defaults.feat_arity, true) + self.gen_compat = opts.gen_compat or env.defaults.gen_compat or DEFAULT_GEN_COMPAT + self.gen_target = opts.gen_target or env.defaults.gen_target or DEFAULT_GEN_TARGET - add_compat_entries(ast, all_needs_compat, env.gen_compat) + if self.gen_target == "5.4" and self.gen_compat ~= "off" then + return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" + end - local result = { - ast = ast, - env = env, - type = module_type or BOOLEAN, - filename = filename, - warnings = warnings, - type_errors = errors, - dependencies = dependencies, - } + if self.feat_lax then + self.type_priorities = shallow_copy_table(self.type_priorities) + self.type_priorities["unknown"] = 0 - env.loaded[filename] = result - table.insert(env.loaded_order, filename) + self.subtype_relations = shallow_copy_table(self.subtype_relations) - if tc then - env.reporter:store_result(tc, env.globals) - end + self.subtype_relations["unknown"] = {} + self.subtype_relations["unknown"]["*"] = compare_true - return result -end + self.subtype_relations["*"] = shallow_copy_table(self.subtype_relations["*"]) + self.subtype_relations["*"]["unknown"] = compare_true + + self.subtype_relations["*"]["boolean"] = compare_true + + self.get_rets = function(rets) + if #rets.tuple == 0 then + return a_vararg(rets, { a_type(rets, "unknown", {}) }) + end + return rets + end + else + self.get_rets = function(rets) + return rets + end + end + if env.report_types then + env.reporter = env.reporter or tl.new_type_reporter() + self.collector = env.reporter:get_collector(filename) + end + local visit_node, visit_type = visit_node, visit_type + if opts.run_internal_compiler_checks then + visit_node, visit_type = patch_visitors( + visit_node, internal_compiler_check, + visit_type, internal_compiler_check) + end + if self.collector then + visit_node, visit_type = patch_visitors( + visit_node, store_type_after, + visit_type, store_type_after) + end + if TL_DEBUG then + visit_node, visit_type = patch_visitors( + visit_node, debug_type_after) -function tl.symbols_in_scope(tr, y, x) - local function find(symbols, at_y, at_x) - local function le(a, b) - return a[1] < b[1] or - (a[1] == b[1] and a[2] <= b[2]) end - return binary_search(symbols, { at_y, at_x }, le) or 0 - end - local ret = {} + assert(ast.kind == "statements") + recurse_node(self, ast, visit_node, visit_type) - local n = find(tr.symbols, y, x) + local global_scope = self.st[1] + close_types(global_scope) + self.errs:warn_unused_vars(global_scope, true) - local symbols = tr.symbols - while n >= 1 do - local s = symbols[n] - if s[3] == "@{" then - n = n - 1 - elseif s[3] == "@}" then - n = s[4] - else - ret[s[3]] = s[4] - n = n - 1 + clear_redundant_errors(self.errs.errors) + + add_compat_entries(ast, self.all_needs_compat, self.gen_compat) + + local result = { + ast = ast, + env = env, + type = self.module_type or a_type(ast, "boolean", {}), + filename = filename, + warnings = self.errs.warnings, + type_errors = self.errs.errors, + dependencies = self.dependencies, + } + + env.loaded[filename] = result + table.insert(env.loaded_order, filename or "") + + if self.collector then + env.reporter:store_result(self.collector, env.globals) end - end - return ret + return result + end end @@ -12297,9 +12414,24 @@ local function read_full_file(fd) return content, err end -tl.process = function(filename, env, fd) - assert((not fd or type(fd) ~= "string"), "fd must be a file") +local function feat_lax_heuristic(filename, input) + if filename then + local _, extension = filename:match("(.*)%.([a-z]+)$") + extension = extension and extension:lower() + + if extension == "tl" then + return "off" + elseif extension == "lua" then + return "on" + end + end + if input then + return (input:match("^#![^\n]*lua[^\n]*\n")) and "on" or "off" + end + return "off" +end +tl.process = function(filename, env, fd) if env and env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12319,23 +12451,38 @@ tl.process = function(filename, env, fd) return nil, "could not read " .. filename .. ": " .. err end - local _, extension = filename:match("(.*)%.([a-z]+)$") - extension = extension and extension:lower() + return tl.process_string(input, env, filename) +end - local is_lua - if extension == "tl" then - is_lua = false - elseif extension == "lua" then - is_lua = true - else - is_lua = input:match("^#![^\n]*lua[^\n]*\n") +function tl.target_from_lua_version(str) + if str == "Lua 5.1" or + str == "Lua 5.2" then + return "5.1" + elseif str == "Lua 5.3" then + return "5.3" + elseif str == "Lua 5.4" then + return "5.4" end +end - return tl.process_string(input, is_lua, env, filename) +local function default_env_opts(runtime, filename, input) + local gen_target = runtime and tl.target_from_lua_version(_VERSION) or DEFAULT_GEN_TARGET + local gen_compat = (gen_target == "5.4") and "off" or DEFAULT_GEN_COMPAT + return { + defaults = { + feat_lax = feat_lax_heuristic(filename, input), + gen_target = gen_target, + gen_compat = gen_compat, + run_internal_compiler_checks = false, + }, + } end -function tl.process_string(input, is_lua, env, filename) - env = env or tl.init_env(is_lua) +function tl.process_string(input, env, filename) + assert(type(env) ~= "boolean", "tl.process_string signature has changed") + + env = env or tl.new_env(default_env_opts(false, filename, input)) + if env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12347,7 +12494,7 @@ function tl.process_string(input, is_lua, env, filename) local result = { ok = false, filename = filename, - type = BOOLEAN, + type = a_type({ f = filename, y = 1, x = 1 }, "boolean", {}), type_errors = {}, syntax_errors = syntax_errors, env = env, @@ -12357,14 +12504,7 @@ function tl.process_string(input, is_lua, env, filename) return result end - local opts = { - filename = filename, - lax = is_lua, - gen_compat = env.gen_compat, - gen_target = env.gen_target, - env = env, - } - local result = tl.type_check(program, opts) + local result = tl.type_check(program, filename, env.defaults, env) result.syntax_errors = syntax_errors @@ -12372,15 +12512,15 @@ function tl.process_string(input, is_lua, env, filename) end tl.gen = function(input, env, pp) - env = env or assert(tl.init_env(), "Default environment initialization failed") - local result = tl.process_string(input, false, env) + env = env or assert(tl.new_env(default_env_opts(false, nil, input)), "Default environment initialization failed") + local result = tl.process_string(input, env) if (not result.ast) or #result.syntax_errors > 0 then return nil, result end local code - code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target, pp) + code, result.gen_error = tl.pretty_print_ast(result.ast, env.defaults.gen_target, pp) return code, result end @@ -12396,28 +12536,25 @@ local function tl_package_loader(module_name) if #errs > 0 then error(found_filename .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg) end - local lax = not not found_filename:match("lua$") local env = tl.package_loader_env if not env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = assert(tl.new_env(), "Default environment initialization failed") env = tl.package_loader_env end - env.modules[module_name] = a_type("typedecl", { def = CIRCULAR_REQUIRE }) + local opts = default_env_opts(true, found_filename) - local result = tl.type_check(program, { - lax = lax, - filename = found_filename, - env = env, - run_internal_compiler_checks = false, - }) + local w = { f = found_filename, x = 1, y = 1 } + env.modules[module_name] = a_type(w, "typedecl", { def = a_type(w, "circular_require", {}) }) + + local result = tl.type_check(program, found_filename, opts.defaults, env) env.modules[module_name] = result.type - local code = assert(tl.pretty_print_ast(program, env.gen_target, true)) + local code = assert(tl.pretty_print_ast(program, opts.defaults.gen_target, true)) local chunk, err = load(code, "@" .. found_filename, "t") if chunk then return function(modname, loader_data) @@ -12443,21 +12580,10 @@ function tl.loader() end end -function tl.target_from_lua_version(str) - if str == "Lua 5.1" or - str == "Lua 5.2" then - return "5.1" - elseif str == "Lua 5.3" then - return "5.3" - elseif str == "Lua 5.4" then - return "5.4" - end -end - -local function env_for(lax, env_tbl) +local function env_for(opts, env_tbl) if not env_tbl then if not tl.package_loader_env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = tl.new_env(opts) end return tl.package_loader_env end @@ -12466,7 +12592,7 @@ local function env_for(lax, env_tbl) tl.load_envs = setmetatable({}, { __mode = "k" }) end - tl.load_envs[env_tbl] = tl.load_envs[env_tbl] or tl.init_env(lax) + tl.load_envs[env_tbl] = tl.load_envs[env_tbl] or tl.new_env(opts) return tl.load_envs[env_tbl] end @@ -12476,17 +12602,14 @@ tl.load = function(input, chunkname, mode, ...) return nil, (chunkname or "") .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg end - local lax = chunkname and not not chunkname:match("lua$") + local opts = default_env_opts(true, chunkname) + if not tl.package_loader_env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = tl.new_env(opts) end - local result = tl.type_check(program, { - lax = lax, - filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\""), - env = env_for(lax, ...), - run_internal_compiler_checks = false, - }) + local filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\"") + local result = tl.type_check(program, filename, opts.defaults, env_for(opts, ...)) if mode and mode:match("c") then if #result.type_errors > 0 then @@ -12500,7 +12623,7 @@ tl.load = function(input, chunkname, mode, ...) mode = mode:gsub("c", "") end - local code, err = tl.pretty_print_ast(program, tl.target_from_lua_version(_VERSION), true) + local code, err = tl.pretty_print_ast(program, opts.defaults.gen_target, true) if not code then return nil, err end @@ -12508,4 +12631,29 @@ tl.load = function(input, chunkname, mode, ...) return load(code, chunkname, mode, ...) end + + + + +function tl.get_types(result) + return result.env.reporter:get_report(), result.env.reporter +end + +tl.init_env = function(lax, gen_compat, gen_target, predefined) + local opts = { + defaults = { + feat_lax = (lax and "on" or "off"), + gen_compat = ((type(gen_compat) == "string") and gen_compat) or + (gen_compat == false and "off") or + (gen_compat == true or gen_compat == nil) and "optional", + gen_target = gen_target or + ((_VERSION == "Lua 5.1" or _VERSION == "Lua 5.2") and "5.1") or + "5.3", + }, + predefined_modules = predefined, + } + + return tl.new_env(opts) +end + return tl diff --git a/tl.tl b/tl.tl index a8b612ec6..00400e2ab 100644 --- a/tl.tl +++ b/tl.tl @@ -476,9 +476,16 @@ end ]=====] local interface Where + f: string y: integer x: integer +end + +local record Errors filename: string + errors: {Error} + warnings: {Error} + unknown_dots: {string:boolean} end local record tl @@ -492,13 +499,13 @@ local record tl end type LoadFunction = function(...:any): any... - enum CompatMode + enum GenCompat "off" "optional" "required" end - enum TargetMode + enum GenTarget "5.1" "5.3" "5.4" @@ -516,25 +523,23 @@ local record tl end record TypeCheckOptions - lax: boolean - filename: string - gen_compat: CompatMode - gen_target: TargetMode - env: Env + feat_lax: Feat + feat_arity: Feat + gen_compat: GenCompat + gen_target: GenTarget run_internal_compiler_checks: boolean end record Env globals: {string:Variable} modules: {string:Type} + module_filenames: {string:string} loaded: {string:Result} loaded_order: {string} reporter: TypeReporter - gen_compat: CompatMode - gen_target: TargetMode keep_going: boolean report_types: boolean - feat_arity: boolean + defaults: TypeCheckOptions end record Result @@ -571,6 +576,8 @@ local record tl i: integer end + type errors = Errors + typecodes: {string:integer} record TypeInfo @@ -601,28 +608,28 @@ local record tl end record EnvOptions - lax_mode: boolean - gen_compat: CompatMode - gen_target: TargetMode - feat_arity: Feat + defaults: TypeCheckOptions predefined_modules: {string} end load: function(string, string, LoadMode, {any:any}): LoadFunction, string process: function(string, Env, ? FILE): (Result, string) - process_string: function(string, boolean, Env, ? string): Result + process_string: function(string, Env, ? string): Result gen: function(string, Env, PrettyPrintOptions): string, Result - type_check: function(Node, TypeCheckOptions): Result, string - new_env: function(EnvOptions): Env, string - init_env: function(? boolean, ? boolean | CompatMode, ? TargetMode, ? {string}): Env, string + type_check: function(Node, string, TypeCheckOptions, ? Env): Result, string + new_env: function(? EnvOptions): Env, string version: function(): string + -- Backwards compatibility + init_env: function(? boolean, ? boolean | GenCompat, ? GenTarget, ? {string}): Env, string + package_loader_env: Env load_envs: { {any:any} : Env } end local record TypeReporter typeid_to_num: {integer: integer} + typename_to_num: {TypeName: integer} next_num: integer tr: TypeReport @@ -684,17 +691,23 @@ tl.typecodes = { INVALID = 0x80000000, } -local type Result = tl.Result local type Env = tl.Env +local type EnvOptions = tl.EnvOptions local type Error = tl.Error -local type CompatMode = tl.CompatMode +local type Feat = tl.Feat +local type GenCompat = tl.GenCompat +local type GenTarget = tl.GenTarget +local type LoadFunction = tl.LoadFunction +local type LoadMode = tl.LoadMode local type PrettyPrintOptions = tl.PrettyPrintOptions +local type Result = tl.Result local type TypeCheckOptions = tl.TypeCheckOptions -local type LoadMode = tl.LoadMode -local type LoadFunction = tl.LoadFunction -local type TargetMode = tl.TargetMode local type TypeInfo = tl.TypeInfo local type TypeReport = tl.TypeReport +local type WarningKind = tl.WarningKind + +local DEFAULT_GEN_COMPAT : GenCompat = "optional" +local DEFAULT_GEN_TARGET : GenTarget = "5.3" local enum Narrow "narrow" @@ -1515,7 +1528,6 @@ local enum TypeName "any" "unknown" -- to be used in lax mode only "invalid" -- producing a new value of this type (not propagating) must always produce a type error - "unresolved" "none" "*" end @@ -1552,7 +1564,6 @@ local table_types : {TypeName:boolean} = { ["any"] = false, ["unknown"] = false, ["invalid"] = false, - ["unresolved"] = false, ["none"] = false, ["*"] = false, } @@ -1561,6 +1572,9 @@ local interface Type is Where where self.typename + y: integer + x: integer + typename: TypeName -- discriminator typeid: integer -- unique identifier inferred_at: Where -- for error messages @@ -1574,7 +1588,24 @@ local record StringType literal: string end -local type TypeType = TypeAliasType | TypeDeclType +local function is_numeric_type(t:Type): boolean + return t.typename == "number" or t.typename == "integer" +end + +local interface NumericType + is Type + where is_numeric_type(self) +end + +local record IntegerType + is NumericType + where self.typename == "integer" +end + +local record BooleanType + is Type + where self.typename == "boolean" +end local record TypeDeclType is Type @@ -1592,6 +1623,8 @@ local record TypeAliasType is_nested_alias: boolean end +local type TypeType = TypeDeclType | TypeAliasType + local record LiteralTableItemType is Type where self.typename == "literal_table_item" @@ -1602,13 +1635,12 @@ local record LiteralTableItemType vtype: Type end -local record UnresolvedType - is Type - where self.typename == "unresolved" - - labels: {string:{Node}} - nominals: {string:{NominalType}} - global_types: {string:boolean} +local record Scope + vars: {string:Variable} + labels: {string:Node} + pending_labels: {string:{Node}} + pending_nominals: {string:{NominalType}} + pending_global_types: {string:boolean} narrows: {string:boolean} end @@ -1675,6 +1707,11 @@ local record InvalidType where self.typename == "invalid" end +local record UnknownType + is Type + where self.typename == "unknown" +end + local record TupleType is Type where self.typename == "tuple" @@ -1849,7 +1886,8 @@ local interface Fact where self.fact fact: FactType - where: Where + w: Where + no_infer: boolean end local record TruthyFact @@ -2014,6 +2052,9 @@ local record Node -- goto label: string + -- label + used_label: boolean + casttype: Type -- variable @@ -2032,10 +2073,125 @@ local record Node debug_type: Type end -local function is_number_type(t:Type): boolean - return t.typename == "number" or t.typename == "integer" +local function a_type(w: Where, typename: TypeName, t: T): T + t.typeid = new_typeid() + t.f = w.f + t.x = w.x + t.y = w.y + t.typename = typename + return t +end + +local function edit_type(w: Where, t: Type, typename: TypeName): Type + t.typeid = new_typeid() + t.f = w.f + t.x = w.x + t.y = w.y + t.typename = typename + return t +end + +local macroexp a_typedecl(w: Where, def: Type): TypeDeclType + return a_type(w, "typedecl", { def = def } as TypeDeclType) +end + +local macroexp a_tuple(w: Where, t: {Type}): TupleType + return a_type(w, "tuple", { tuple = t } as TupleType) +end + +local macroexp a_union(w: Where, t: {Type}): UnionType + return a_type(w, "union", { types = t } as UnionType) +end + +local function a_function(w: Where, t: FunctionType): FunctionType + assert(t.min_arity) + return a_type(w, "function", t) +end + +local function a_vararg(w: Where, t: {Type}): TupleType + local typ = a_tuple(w, t) + typ.is_va = true + return typ +end + +local macroexp an_array(w: Where, t: Type): ArrayType + return a_type(w, "array", { elements = t } as ArrayType) +end + +local macroexp a_map(w: Where, k: Type, v: Type): MapType + return a_type(w, "map", { keys = k, values = v } as MapType) +end + +local function a_nominal(n: Node, names: {string}): NominalType + return a_type(n, "nominal", { names = names } as NominalType) end +local macroexp an_invalid(w: Where): InvalidType + return a_type(w, "invalid", {} as InvalidType) +end + +local macroexp an_unknown(w: Where): UnknownType + return a_type(w, "unknown", {} as UnknownType) +end + +local an_operator: function(Node, integer, string): Operator + +local function shallow_copy_new_type(t: T): T + local copy: {any:any} = {} + for k, v in pairs(t as {any:any}) do + copy[k] = v + end + copy.typeid = new_typeid() + return copy as T +end + +local function shallow_copy_table(t: T): T + local copy: {any:any} = {} + for k, v in pairs(t as {any:any}) do + copy[k] = v + end + return copy as T +end + +-- TODO move to Errors module +local function clear_redundant_errors(errors: {Error}) + local redundant: {integer} = {} + local lastx, lasty = 0, 0 + for i, err in ipairs(errors) do + err.i = i + end + table.sort(errors, function(a: Error, b: Error): boolean + local af = assert(a.filename) + local bf = assert(b.filename) + return af < bf + or (af == bf and (a.y < b.y + or (a.y == b.y and (a.x < b.x + or (a.x == b.x and (a.i < b.i)))))) + end) + for i, err in ipairs(errors) do + err.i = nil + if err.x == lastx and err.y == lasty then + table.insert(redundant, i) + end + lastx, lasty = err.x, err.y + end + for i = #redundant, 1, -1 do + table.remove(errors, redundant[i]) + end +end + +local simple_types: {TypeName:boolean} = { + ["nil"] = true, + ["any"] = true, + ["number"] = true, + ["string"] = true, + ["thread"] = true, + ["boolean"] = true, + ["integer"] = true, +} + +do ----------------------------------------------------------------------------- + local record ParseState tokens: {Token} errs: {Error} @@ -2108,163 +2264,52 @@ local function verify_end(ps: ParseState, i: integer, istart: integer, node: Nod return fail(ps, i, "syntax error, expected 'end' to close construct started at " .. ps.filename .. ":" .. ps.tokens[istart].y .. ":" .. ps.tokens[istart].x .. ":") end -local function new_node(tokens: {Token}, i: integer, kind?: NodeKind): Node - local t = tokens[i] - return { y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind as NodeKind) } -end - -local function a_type(typename: TypeName, t: T): T - t.typeid = new_typeid() - t.typename = typename - return t +local function new_node(ps: ParseState, i: integer, kind?: NodeKind): Node + local t = ps.tokens[i] + return { f = ps.filename, y = t.y, x = t.x, tk = t.tk, kind = kind or (t.kind as NodeKind) } end -local function edit_type(t: Type, typename: TypeName): Type +local function new_type(ps: ParseState, i: integer, typename: TypeName): Type + local token = ps.tokens[i] + local t: Type = {} t.typeid = new_typeid() + t.f = ps.filename + t.x = token.x + t.y = token.y t.typename = typename return t end -local function new_type(ps: ParseState, i: integer, typename: TypeName): Type - local token = ps.tokens[i] - return a_type(typename, { - filename = ps.filename, - y = token.y, - x = token.x, - --tk = token.tk - }) -end - local function new_typedecl(ps: ParseState, i: integer, def: Type): TypeDeclType local t = new_type(ps, i, "typedecl") as TypeDeclType t.def = def return t end -local macroexp a_typedecl(def: Type): TypeDeclType - return a_type("typedecl", { def = def } as TypeDeclType) -end - -local macroexp a_tuple(t: {Type}): TupleType - return a_type("tuple", { tuple = t } as TupleType) -end - -local macroexp a_union(t: {Type}): UnionType - return a_type("union", { types = t } as UnionType) -end - ---local macroexp a_poly(t: {FunctionType}): PolyType --- return a_type("poly", { types = t } as PolyType) ---end --- -local function a_function(t: FunctionType): FunctionType - assert(t.min_arity) - return a_type("function", t) -end - -local record Opt - where self.opttype - - opttype: Type -end - ---local function OPT(t: Type): Opt --- return { opttype = t } ---end --- -local record Args - is {Type|Opt} - - is_va: boolean -end - -local function va_args(args: Args): Args - args.is_va = true - return args -end - -local record FuncArgs - is HasTypeArgs - - args: Args - rets: Args - needs_compat: boolean -end - -local function a_fn(f: FuncArgs): FunctionType - local args_t = a_tuple {} - local tup = args_t.tuple - args_t.is_va = f.args.is_va - local min_arity = f.args.is_va and -1 or 0 - for _, a in ipairs(f.args) do - if a is Opt then - table.insert(tup, a.opttype) - else - table.insert(tup, a) - min_arity = min_arity + 1 - end - end - - local rets_t = a_tuple {} - tup = rets_t.tuple - rets_t.is_va = f.rets.is_va - for _, a in ipairs(f.rets) do - assert(a is Type) - table.insert(tup, a) - end - - return a_type("function", { - args = args_t, - rets = rets_t, - min_arity = min_arity, - needs_compat = f.needs_compat, - typeargs = f.typeargs, - } as FunctionType) -end - -local function a_vararg(t: {Type}): TupleType - local typ = a_tuple(t) - typ.is_va = true - return typ -end - -local macroexp an_array(t: Type): ArrayType - return a_type("array", { elements = t } as ArrayType) -end - -local macroexp a_map(k: Type, v: Type): MapType - return a_type("map", { keys = k, values = v } as MapType) +local function new_tuple(ps: ParseState, i: integer, types?: {Type}, is_va?: boolean): TupleType, {Type} + local t = new_type(ps, i, "tuple") as TupleType + t.is_va = is_va + t.tuple = types or {} + return t, t.tuple end -local NIL = a_type("nil", {}) -local ANY = a_type("any", {}) -local TABLE = a_map(ANY, ANY) -local NUMBER = a_type("number", {}) -local STRING = a_type("string", {}) -local THREAD = a_type("thread", {}) -local BOOLEAN = a_type("boolean", {}) -local INTEGER = a_type("integer", {}) - -local function shallow_copy_new_type(t: T): T - local copy: {any:any} = {} - for k, v in pairs(t as {any:any}) do - copy[k] = v - end - copy.typeid = new_typeid() - return copy as T +local function new_typealias(ps: ParseState, i: integer, alias_to: NominalType): TypeAliasType + local t = new_type(ps, i, "typealias") as TypeAliasType + t.alias_to = alias_to + return t end -local function shallow_copy_table(t: T): T - local copy: {any:any} = {} - for k, v in pairs(t as {any:any}) do - copy[k] = v +local function new_nominal(ps: ParseState, i: integer, name?: string): NominalType + local t = new_type(ps, i, "nominal") as NominalType + if name then + t.names = { name } end - return copy as T + return t end local function verify_kind(ps: ParseState, i: integer, kind: TokenKind, node_kind?: NodeKind): integer, Node if ps.tokens[i].kind == kind then - return i + 1, new_node(ps.tokens, i, node_kind) + return i + 1, new_node(ps, i, node_kind) end return fail(ps, i, "syntax error, expected " .. kind) end @@ -2302,23 +2347,23 @@ local function parse_table_value(ps: ParseState, i: integer): integer, Node, int fail(ps, i, next_word == "record" and "syntax error: this syntax is no longer valid; declare nested record inside a record" or "syntax error: cannot declare interface inside a table; use a statement") - return skip_i, new_node(ps.tokens, i, "error_node") + return skip_i, new_node(ps, i, "error_node") end elseif next_word == "enum" and ps.tokens[i + 1].kind == "string" then i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested enum inside a record", skip_type_body) - return i, new_node(ps.tokens, i - 1, "error_node") + return i, new_node(ps, i - 1, "error_node") end local e: Node i, e = parse_expression(ps, i) if not e then - e = new_node(ps.tokens, i - 1, "error_node") + e = new_node(ps, i - 1, "error_node") end return i, e end local function parse_table_item(ps: ParseState, i: integer, n?: integer): integer, Node, integer - local node = new_node(ps.tokens, i, "literal_table_item") + local node = new_node(ps, i, "literal_table_item") if ps.tokens[i].kind == "$EOF$" then return fail(ps, i, "unexpected eof") end @@ -2369,7 +2414,7 @@ local function parse_table_item(ps: ParseState, i: integer, n?: integer): intege end end - node.key = new_node(ps.tokens, i, "integer") + node.key = new_node(ps, i, "integer") node.key_parsed = "implicit" node.key.constnum = n node.key.tk = tostring(n) @@ -2445,7 +2490,7 @@ local function parse_bracket_list(ps: ParseState, i: integer, list: {T}, open end local function parse_table_literal(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "literal_table") + local node = new_node(ps, i, "literal_table") return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item) end @@ -2501,16 +2546,21 @@ local function parse_typearg(ps: ParseState, i: integer): integer, TypeArgType, i = i + 1 i, constraint = parse_interface_name(ps, i) -- FIXME what about generic interfaces end - return i, a_type("typearg", { - y = ps.tokens[i - 2].y, - x = ps.tokens[i - 2].x, - typearg = name, - constraint = constraint, - } as TypeArgType) + local t = new_type(ps, i, "typearg") as TypeArgType + t.typearg = name + t.constraint = constraint + return i, t end local function parse_return_types(ps: ParseState, i: integer): integer, TupleType - return parse_type_list(ps, i, "rets") + local iprev = i - 1 + local t: TupleType + i, t = parse_type_list(ps, i, "rets") + if #t.tuple == 0 then + t.x = ps.tokens[iprev].x + t.y = ps.tokens[iprev].y + end + return i, t end local function parse_function_type(ps: ParseState, i: integer): integer, FunctionType @@ -2523,31 +2573,25 @@ local function parse_function_type(ps: ParseState, i: integer): integer, Functio i, typ.args, typ.is_method, typ.min_arity = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) else - typ.args = a_vararg { ANY } - typ.rets = a_vararg { ANY } + typ.args = new_tuple(ps, i, { new_type(ps, i, "any") }, true) + typ.rets = new_tuple(ps, i, { new_type(ps, i, "any") }, true) end return i, typ end -local simple_types: {string:Type} = { - ["nil"] = NIL, - ["any"] = ANY, - ["table"] = TABLE, - ["number"] = NUMBER, - ["string"] = STRING, - ["thread"] = THREAD, - ["boolean"] = BOOLEAN, - ["integer"] = INTEGER, -} - local function parse_simple_type_or_nominal(ps: ParseState, i: integer): integer, Type local tk = ps.tokens[i].tk - local st = simple_types[tk] + local st = simple_types[tk as TypeName] if st then - return i + 1, st + return i + 1, new_type(ps, i, tk as TypeName) + elseif tk == "table" then + local typ = new_type(ps, i, "map") as MapType + typ.keys = new_type(ps, i, "any") + typ.values = new_type(ps, i, "any") + return i + 1, typ end - local typ = new_type(ps, i, "nominal") as NominalType - typ.names = { tk } + + local typ = new_nominal(ps, i, tk) i = i + 1 while ps.tokens[i].tk == "." do i = i + 1 @@ -2614,12 +2658,7 @@ local function parse_base_type(ps: ParseState, i: integer): integer, Type, integ elseif tk == "function" then return parse_function_type(ps, i) elseif tk == "nil" then - return i + 1, simple_types["nil"] - elseif tk == "table" then - local typ = new_type(ps, i, "map") as MapType - typ.keys = ANY - typ.values = ANY - return i + 1, typ + return i + 1, new_type(ps, i, "nil") end return fail(ps, i, "expected a type") end @@ -2655,12 +2694,6 @@ parse_type = function(ps: ParseState, i: integer): integer, Type, integer return i, bt end -local function new_tuple(ps: ParseState, i: integer): TupleType, {Type} - local t = new_type(ps, i, "tuple") as TupleType - t.tuple = {} - return t, t.tuple -end - parse_type_list = function(ps: ParseState, i: integer, mode: ParseTypeListMode): integer, TupleType local t, list = new_tuple(ps, i) @@ -2716,7 +2749,7 @@ local function parse_function_args_rets_body(ps: ParseState, i: integer, node: N end local function parse_function_value(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "function") + local node = new_node(ps, i, "function") i = verify_tk(ps, i, "function") return parse_function_args_rets_body(ps, i, node) end @@ -2737,7 +2770,7 @@ local function parse_literal(ps: ParseState, i: integer): integer, Node if kind == "identifier" then return verify_kind(ps, i, "identifier", "variable") elseif kind == "string" then - local node = new_node(ps.tokens, i, "string") + local node = new_node(ps, i, "string") node.conststr, node.is_longstring = unquote(tk) return i + 1, node elseif kind == "number" or kind == "integer" then @@ -2785,8 +2818,6 @@ local function node_is_require_call(n: Node): string end end -local an_operator: function(Node, integer, string): Operator - do local precedences: {integer:{string:integer}} = { [1] = { @@ -2861,8 +2892,8 @@ do -- small hack: for the sake of `tl types`, parse an invalid binary exp -- as a paren to produce a unary indirection on e1 and save its location. - local function failstore(tkop: Token, e1: Node): Node - return { y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true } + local function failstore(ps: ParseState, tkop: Token, e1: Node): Node + return { f = ps.filename, y = tkop.y, x = tkop.x, kind = "paren", e1 = e1, failstore = true } end local function P(ps: ParseState, i: integer): integer, Node @@ -2880,7 +2911,7 @@ do fail(ps, prev_i, "expected an expression") return i end - e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } + e1 = { f = ps.filename, y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 } elseif ps.tokens[i].tk == "(" then i = i + 1 local prev_i = i @@ -2889,7 +2920,7 @@ do fail(ps, prev_i, "expected an expression") return i end - e1 = { y = t1.y, x = t1.x, kind = "paren", e1 = e1 } + e1 = { f = ps.filename, y = t1.y, x = t1.x, kind = "paren", e1 = e1 } else i, e1 = parse_literal(ps, i) end @@ -2914,12 +2945,12 @@ do local skipped = skip(ps, i, parse_type as SkipFunction) if skipped > i + 1 then fail(ps, i, "syntax error, cannot declare a type here (missing 'local' or 'global'?)") - return skipped, failstore(tkop, e1) + return skipped, failstore(ps, tkop, e1) end end i, key = verify_kind(ps, i, "identifier") if not key then - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end if op.op == ":" then @@ -2929,30 +2960,30 @@ do else fail(ps, i, "expected a function call for a method") end - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end if not after_valid_prefixexp(ps, e1, prev_i) then fail(ps, prev_i, "cannot call a method on this expression") - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key } + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key } elseif tkop.tk == "(" then local op: Operator = new_operator(tkop, 2, "@funcall") local prev_i = i - local args = new_node(ps.tokens, i, "expression_list") + local args = new_node(ps, i, "expression_list") i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression) if not after_valid_prefixexp(ps, e1, prev_i) then fail(ps, prev_i, "cannot call this expression") - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end - e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } + e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } table.insert(ps.required_modules, node_is_require_call(e1)) elseif tkop.tk == "[" then @@ -2966,19 +2997,19 @@ do if not after_valid_prefixexp(ps, e1, prev_i) then fail(ps, prev_i, "cannot index this expression") - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx } + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx } elseif tkop.kind == "string" or tkop.kind == "{" then local op: Operator = new_operator(tkop, 2, "@funcall") local prev_i = i - local args = new_node(ps.tokens, i, "expression_list") + local args = new_node(ps, i, "expression_list") local argument: Node if tkop.kind == "string" then - argument = new_node(ps.tokens, i) + argument = new_node(ps, i) argument.conststr = unquote(tkop.tk) i = i + 1 else @@ -2991,27 +3022,27 @@ do else fail(ps, prev_i, "cannot use a table here; if you're trying to call the previous expression, wrap it in parentheses") end - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end table.insert(args, argument) - e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } + e1 = { f = ps.filename, y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args } table.insert(ps.required_modules, node_is_require_call(e1)) elseif tkop.tk == "as" or tkop.tk == "is" then local op: Operator = new_operator(tkop, 2, tkop.tk) i = i + 1 - local cast = new_node(ps.tokens, i, "cast") + local cast = new_node(ps, i, "cast") if ps.tokens[i].tk == "(" then i, cast.casttype = parse_type_list(ps, i, "casttype") else i, cast.casttype = parse_type(ps, i) end if not cast.casttype then - return i, failstore(tkop, e1) + return i, failstore(ps, tkop, e1) end - e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } + e1 = { f = ps.filename, y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr } else break end @@ -3042,7 +3073,7 @@ do end lookahead = ps.tokens[i].tk end - lhs = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs, } + lhs = { f = ps.filename, y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs, } end return i, lhs end @@ -3069,7 +3100,7 @@ parse_expression_and_tk = function(ps: ParseState, i: integer, tk: string): inte local e: Node i, e = parse_expression(ps, i) if not e then - e = new_node(ps.tokens, i - 1, "error_node") + e = new_node(ps, i - 1, "error_node") end if ps.tokens[i].tk == tk then i = i + 1 @@ -3147,7 +3178,7 @@ local function parse_argument(ps: ParseState, i: integer): integer, Node, intege end parse_argument_list = function(ps: ParseState, i: integer): integer, Node, integer - local node = new_node(ps.tokens, i, "argument_list") + local node = new_node(ps, i, "argument_list") i, node = parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument) local opts = false local min_arity = 0 @@ -3242,16 +3273,16 @@ end local function parse_identifier(ps: ParseState, i: integer): integer, Node, integer if ps.tokens[i].kind == "identifier" then - return i + 1, new_node(ps.tokens, i, "identifier") + return i + 1, new_node(ps, i, "identifier") end i = fail(ps, i, "syntax error, expected identifier") - return i, new_node(ps.tokens, i, "error_node") + return i, new_node(ps, i, "error_node") end local function parse_local_function(ps: ParseState, i: integer): integer, Node i = verify_tk(ps, i, "local") i = verify_tk(ps, i, "function") - local node = new_node(ps.tokens, i - 2, "local_function") + local node = new_node(ps, i - 2, "local_function") i, node.name = parse_identifier(ps, i) return parse_function_args_rets_body(ps, i, node) end @@ -3264,7 +3295,7 @@ end local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): integer, Node local orig_i = i i = verify_tk(ps, i, "function") - local fn = new_node(ps.tokens, i - 1, "global_function") + local fn = new_node(ps, i - 1, "global_function") local names: {Node} = {} i, names[1] = parse_identifier(ps, i) while ps.tokens[i].tk == "." do @@ -3284,7 +3315,7 @@ local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): int for i2 = 2, #names - 1 do local dot = an_operator(names[i2], 2, ".") names[i2].kind = "identifier" - owner = { y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } + owner = { f = ps.filename, y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } end fn.fn_owner = owner end @@ -3292,8 +3323,8 @@ local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): int local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y i = parse_function_args_rets_body(ps, i, fn) - if fn.is_method then - table.insert(fn.args, 1, { x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) + if fn.is_method and fn.args then + table.insert(fn.args, 1, { f = ps.filename, x = selfx, y = selfy, tk = "self", kind = "identifier", is_self = true }) fn.min_arity = fn.min_arity + 1 end @@ -3311,7 +3342,7 @@ local function parse_function(ps: ParseState, i: integer, fk: FunctionKind): int end local function parse_if_block(ps: ParseState, i: integer, n: integer, node: Node, is_else?: boolean): integer, Node - local block = new_node(ps.tokens, i, "if_block") + local block = new_node(ps, i, "if_block") i = i + 1 block.if_parent = node block.if_block_n = n @@ -3333,7 +3364,7 @@ end local function parse_if(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "if") + local node = new_node(ps, i, "if") node.if_blocks = {} i, node = parse_if_block(ps, i, 1, node) if not node then @@ -3359,7 +3390,7 @@ end local function parse_while(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "while") + local node = new_node(ps, i, "while") i = verify_tk(ps, i, "while") i, node.exp = parse_expression_and_tk(ps, i, "do") i, node.body = parse_statements(ps, i) @@ -3369,7 +3400,7 @@ end local function parse_fornum(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "fornum") + local node = new_node(ps, i, "fornum") i = i + 1 i, node.var = parse_identifier(ps, i) i = verify_tk(ps, i, "=") @@ -3388,12 +3419,12 @@ end local function parse_forin(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "forin") + local node = new_node(ps, i, "forin") i = i + 1 - node.vars = new_node(ps.tokens, i, "variable_list") + node.vars = new_node(ps, i, "variable_list") i, node.vars = parse_list(ps, i, node.vars, { ["in"] = true }, "sep", parse_identifier) i = verify_tk(ps, i, "in") - node.exps = new_node(ps.tokens, i, "expression_list") + node.exps = new_node(ps, i, "expression_list") i = parse_list(ps, i, node.exps, { ["do"] = true }, "sep", parse_expression) if #node.exps < 1 then return fail(ps, i, "missing iterator expression in generic for") @@ -3415,7 +3446,7 @@ local function parse_for(ps: ParseState, i: integer): integer, Node end local function parse_repeat(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "repeat") + local node = new_node(ps, i, "repeat") i = verify_tk(ps, i, "repeat") i, node.body = parse_statements(ps, i) node.body.is_repeat = true @@ -3427,7 +3458,7 @@ end local function parse_do(ps: ParseState, i: integer): integer, Node local istart = i - local node = new_node(ps.tokens, i, "do") + local node = new_node(ps, i, "do") i = verify_tk(ps, i, "do") i, node.body = parse_statements(ps, i) i = verify_end(ps, i, istart, node) @@ -3435,13 +3466,13 @@ local function parse_do(ps: ParseState, i: integer): integer, Node end local function parse_break(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "break") + local node = new_node(ps, i, "break") i = verify_tk(ps, i, "break") return i, node end local function parse_goto(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "goto") + local node = new_node(ps, i, "goto") i = verify_tk(ps, i, "goto") node.label = ps.tokens[i].tk i = verify_kind(ps, i, "identifier") @@ -3449,7 +3480,7 @@ local function parse_goto(ps: ParseState, i: integer): integer, Node end local function parse_label(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "label") + local node = new_node(ps, i, "label") i = verify_tk(ps, i, "::") node.label = ps.tokens[i].tk i = verify_kind(ps, i, "identifier") @@ -3474,9 +3505,9 @@ for k, v in pairs(stop_statement_list) do end local function parse_return(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "return") + local node = new_node(ps, i, "return") i = verify_tk(ps, i, "return") - node.exps = new_node(ps.tokens, i, "expression_list") + node.exps = new_node(ps, i, "expression_list") i = parse_list(ps, i, node.exps, stop_return_list, "sep", parse_expression) if ps.tokens[i].kind == ";" then i = i + 1 @@ -3514,12 +3545,13 @@ local function parse_nested_type(ps: ParseState, i: integer, def: RecordLikeType return fail(ps, i, "expected a variable name") end - local nt: Node = new_node(ps.tokens, i - 2, "newtype") + local nt: Node = new_node(ps, i - 2, "newtype") local ndef = new_type(ps, i, typename) + local itype = i local iok = parse_body(ps, i, ndef, nt) if iok then i = iok - nt.newtype = new_typedecl(ps, i, ndef) + nt.newtype = new_typedecl(ps, itype, ndef) end store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) @@ -3576,7 +3608,7 @@ local function parse_macroexp(ps: ParseState, istart: integer, iargs: integer): -- if ps.tokens[i].tk == "<" then -- i, node.typeargs = parse_anglebracket_list(ps, i, parse_typearg) -- end - local node = new_node(ps.tokens, istart, "macroexp") + local node = new_node(ps, istart, "macroexp") local i: integer i, node.args, node.min_arity = parse_argument_list(ps, iargs) i, node.rets = parse_return_types(ps, i) @@ -3588,18 +3620,14 @@ local function parse_macroexp(ps: ParseState, istart: integer, iargs: integer): end local function parse_where_clause(ps: ParseState, i: integer): integer, Node - local node = new_node(ps.tokens, i, "macroexp") - - local selftype = new_type(ps, i, "nominal") as NominalType - selftype.names = { "@self" } - - node.args = new_node(ps.tokens, i, "argument_list") - node.args[1] = new_node(ps.tokens, i, "argument") + local node = new_node(ps, i, "macroexp") + node.args = new_node(ps, i, "argument_list") + node.args[1] = new_node(ps, i, "argument") node.args[1].tk = "self" - node.args[1].argtype = selftype + node.args[1].argtype = new_nominal(ps, i, "@self") node.min_arity = 1 node.rets = new_tuple(ps, i) - node.rets.tuple[1] = BOOLEAN + node.rets.tuple[1] = new_type(ps, i, "boolean") i, node.exp = parse_expression(ps, i) end_at(node, ps.tokens[i - 1]) return i, node @@ -3681,15 +3709,10 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType, no local typ = new_type(ps, wstart, "function") as FunctionType typ.is_method = true typ.min_arity = 1 - typ.args = a_tuple { - a_type("nominal", { - y = typ.y, - x = typ.x, - filename = ps.filename, - names = { "@self" } - } as NominalType) - } - typ.rets = a_tuple { BOOLEAN } + typ.args = new_tuple(ps, wstart, { + a_nominal(where_macroexp, { "@self" }) + }) + typ.rets = new_tuple(ps, wstart, { new_type(ps, wstart, "boolean") }) typ.macroexp = where_macroexp def.meta_fields = {} @@ -3810,7 +3833,7 @@ parse_type_body_fns = { } parse_newtype = function(ps: ParseState, i: integer): integer, Node - local node: Node = new_node(ps.tokens, i, "newtype") + local node: Node = new_node(ps, i, "newtype") local def: Type local tn = ps.tokens[i].tk as TypeName local itype = i @@ -3831,9 +3854,7 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node end if def is NominalType then - local typealias = new_type(ps, itype, "typealias") as TypeAliasType - typealias.alias_to = def - node.newtype = typealias + node.newtype = new_typealias(ps, itype, def) else node.newtype = new_typedecl(ps, itype, def) end @@ -3843,7 +3864,7 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node end local function parse_assignment_expression_list(ps: ParseState, i: integer, asgn: Node): integer, Node - asgn.exps = new_node(ps.tokens, i, "expression_list") + asgn.exps = new_node(ps, i, "expression_list") repeat i = i + 1 local val: Node @@ -3893,8 +3914,8 @@ do return fail(ps, i, "syntax error") end - local asgn: Node = new_node(ps.tokens, istart, "assignment") - asgn.vars = new_node(ps.tokens, istart, "variable_list") + local asgn: Node = new_node(ps, istart, "assignment") + asgn.vars = new_node(ps, istart, "variable_list") asgn.vars[1] = exp if ps.tokens[i].tk == "," then i = i + 1 @@ -3915,9 +3936,9 @@ do end local function parse_variable_declarations(ps: ParseState, i: integer, node_name: NodeKind): integer, Node - local asgn: Node = new_node(ps.tokens, i, node_name) + local asgn: Node = new_node(ps, i, node_name) - asgn.vars = new_node(ps.tokens, i, "variable_list") + asgn.vars = new_node(ps, i, "variable_list") i = parse_trying_list(ps, i, asgn.vars, parse_variable_name) if #asgn.vars == 0 then return fail(ps, i, "expected a local variable definition") @@ -3945,7 +3966,7 @@ end local function parse_type_declaration(ps: ParseState, i: integer, node_name: NodeKind): integer, Node i = i + 2 -- skip `local` or `global`, and `type` - local asgn: Node = new_node(ps.tokens, i, node_name) + local asgn: Node = new_node(ps, i, node_name) i, asgn.var = parse_variable_name(ps, i) if not asgn.var then return fail(ps, i, "expected a type name") @@ -3985,8 +4006,8 @@ local function parse_type_declaration(ps: ParseState, i: integer, node_name: Nod end local function parse_type_constructor(ps: ParseState, i: integer, node_name: NodeKind, type_name: TypeName, parse_body: ParseBody): integer, Node - local asgn: Node = new_node(ps.tokens, i, node_name) - local nt: Node = new_node(ps.tokens, i, "newtype") + local asgn: Node = new_node(ps, i, node_name) + local nt: Node = new_node(ps, i, "newtype") asgn.value = nt local itype = i local def = new_type(ps, i, type_name) @@ -4015,7 +4036,7 @@ end local function parse_local_macroexp(ps: ParseState, i: integer): integer, Node local istart = i i = i + 2 -- skip `local` - local node = new_node(ps.tokens, i, "local_macroexp") + local node = new_node(ps, i, "local_macroexp") i, node.name = parse_identifier(ps, i) i, node.macrodef = parse_macroexp(ps, istart, i) end_at(node, ps.tokens[i - 1]) @@ -4085,7 +4106,7 @@ local needs_local_or_global: {string : function(ParseState, integer):(integer, N } parse_statements = function(ps: ParseState, i: integer, toplevel?: boolean): integer, Node - local node = new_node(ps.tokens, i, "statements") + local node = new_node(ps, i, "statements") local item: Node while true do while ps.tokens[i].kind == ";" do @@ -4130,32 +4151,6 @@ parse_statements = function(ps: ParseState, i: integer, toplevel?: boolean): int return i, node end -local function clear_redundant_errors(errors: {Error}) - local redundant: {integer} = {} - local lastx, lasty = 0, 0 - for i, err in ipairs(errors) do - err.i = i - end - table.sort(errors, function(a: Error, b: Error): boolean - local af = a.filename or "" - local bf = b.filename or "" - return af < bf - or (af == bf and (a.y < b.y - or (a.y == b.y and (a.x < b.x - or (a.x == b.x and (a.i < b.i)))))) - end) - for i, err in ipairs(errors) do - err.i = nil - if err.x == lastx and err.y == lasty then - table.insert(redundant, i) - end - lastx, lasty = err.x, err.y - end - for i = #redundant, 1, -1 do - table.remove(errors, redundant[i]) - end -end - function tl.parse_program(tokens: {Token}, errs: {Error}, filename: string): Node, {string} errs = errs or {} local ps: ParseState = { @@ -4185,17 +4180,19 @@ function tl.parse(input: string, filename: string): Node, {Error}, {string} return node, errs, required_modules end +end ---------------------------------------------------------------------------- + -------------------------------------------------------------------------------- -- AST traversal -------------------------------------------------------------------------------- -local record VisitorCallbacks - before: function(N) - before_exp: function({N}, {T}) - before_arguments: function({N}, {T}) - before_statements: function({N}, {T}) - before_e2: function({N}, {T}) - after: function(N, {T}): T +local record VisitorCallbacks + before: function(S, N) + before_exp: function(S, {N}, {T}) + before_arguments: function(S, {N}, {T}) + before_statements: function(S, {N}, {T}) + before_e2: function(S, {N}, {T}) + after: function(S, N, {T}): T end local enum VisitorExtraCallback @@ -4205,9 +4202,11 @@ local enum VisitorExtraCallback "before_e2" end -local record Visitor - cbs: {K:VisitorCallbacks} - after: function(N, {T}, T): T +local type VisitorAfter = function(S, N, {T}, T): T + +local record Visitor + cbs: {K:VisitorCallbacks} + after: VisitorAfter allow_missing_cbs: boolean end @@ -4296,7 +4295,7 @@ local function tl_debug_indent_pop(mark: string, single: string, y: integer, x: end end -local function recurse_type(ast: Type, visit: Visitor): T +local function recurse_type(s: S, ast: Type, visit: Visitor): T local kind = ast.typename if TL_DEBUG then @@ -4308,7 +4307,7 @@ local function recurse_type(ast: Type, visit: Visitor): T if cbkind then local cbkind_before = cbkind.before if cbkind_before then - cbkind_before(ast) + cbkind_before(s, ast) end end @@ -4316,90 +4315,90 @@ local function recurse_type(ast: Type, visit: Visitor): T if ast is TupleType then for i, child in ipairs(ast.tuple) do - xs[i] = recurse_type(child, visit) + xs[i] = recurse_type(s, child, visit) end elseif ast is AggregateType then for _, child in ipairs(ast.types) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end elseif ast is MapType then - table.insert(xs, recurse_type(ast.keys, visit)) - table.insert(xs, recurse_type(ast.values, visit)) + table.insert(xs, recurse_type(s, ast.keys, visit)) + table.insert(xs, recurse_type(s, ast.values, visit)) elseif ast is RecordLikeType then if ast.typeargs then for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.interface_list then for _, child in ipairs(ast.interface_list) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) + table.insert(xs, recurse_type(s, ast.elements, visit)) end if ast.fields then for _, child in fields_of(ast) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.meta_fields then for _, child in fields_of(ast, "meta") do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast is FunctionType then if ast.typeargs then for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.args then for _, child in ipairs(ast.args.tuple) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end if ast.rets then for _, child in ipairs(ast.rets.tuple) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast is NominalType then if ast.typevals then for _, child in ipairs(ast.typevals) do - table.insert(xs, recurse_type(child, visit)) + table.insert(xs, recurse_type(s, child, visit)) end end elseif ast is TypeArgType then if ast.constraint then - table.insert(xs, recurse_type(ast.constraint, visit)) + table.insert(xs, recurse_type(s, ast.constraint, visit)) end elseif ast is ArrayType then if ast.elements then - table.insert(xs, recurse_type(ast.elements, visit)) + table.insert(xs, recurse_type(s, ast.elements, visit)) end elseif ast is LiteralTableItemType then if ast.ktype then - table.insert(xs, recurse_type(ast.ktype, visit)) + table.insert(xs, recurse_type(s, ast.ktype, visit)) end if ast.vtype then - table.insert(xs, recurse_type(ast.vtype, visit)) + table.insert(xs, recurse_type(s, ast.vtype, visit)) end elseif ast is TypeAliasType then - table.insert(xs, recurse_type(ast.alias_to, visit)) + table.insert(xs, recurse_type(s, ast.alias_to, visit)) elseif ast is TypeDeclType then - table.insert(xs, recurse_type(ast.def, visit)) + table.insert(xs, recurse_type(s, ast.def, visit)) end local ret: T local cbkind_after = cbkind and cbkind.after if cbkind_after then - ret = cbkind_after(ast, xs) + ret = cbkind_after(s, ast, xs) end local visit_after = visit.after if visit_after then - ret = visit_after(ast, xs, ret) + ret = visit_after(s, ast, xs, ret) end if TL_DEBUG then @@ -4409,25 +4408,26 @@ local function recurse_type(ast: Type, visit: Visitor): T return ret end -local function recurse_typeargs(ast: Node, visit_type: Visitor) +local function recurse_typeargs(s: S, ast: Node, visit_type: Visitor) if ast.typeargs then for _, typearg in ipairs(ast.typeargs) do - recurse_type(typearg, visit_type) + recurse_type(s, typearg, visit_type) end end end -local function extra_callback(name: VisitorExtraCallback, - ast: Node, - xs: {T}, - visit_node: Visitor) +local function extra_callback(name: VisitorExtraCallback, + s: S, + ast: Node, + xs: {T}, + visit_node: Visitor) local cbs = visit_node.cbs if not cbs then return end local nbs = cbs[ast.kind] if not nbs then return end local bs = nbs[name] if not bs then return end - bs(ast, xs) + bs(s, ast, xs) end local no_recurse_node: {NodeKind : boolean} = { @@ -4447,9 +4447,9 @@ local no_recurse_node: {NodeKind : boolean} = { ["type_identifier"] = true, } -local function recurse_node(root: Node, - visit_node: Visitor, - visit_type: Visitor): T +local function recurse_node(s: S, root: Node, + visit_node: Visitor, + visit_type: Visitor): T if not root then -- parse error return @@ -4466,9 +4466,9 @@ local function recurse_node(root: Node, local function walk_vars_exps(ast: Node, xs: {T}) xs[1] = recurse(ast.vars) if ast.decltuple then - xs[2] = recurse_type(ast.decltuple, visit_type) + xs[2] = recurse_type(s, ast.decltuple, visit_type) end - extra_callback("before_exp", ast, xs, visit_node) + extra_callback("before_exp", s, ast, xs, visit_node) if ast.exps then xs[3] = recurse(ast.exps) end @@ -4480,11 +4480,11 @@ local function recurse_node(root: Node, end local function walk_named_function(ast: Node, xs: {T}) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.name) xs[2] = recurse(ast.args) - xs[3] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[3] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[4] = recurse(ast.body) end @@ -4497,9 +4497,9 @@ local function recurse_node(root: Node, end xs[2] = p1 as T if ast.op.arity == 2 then - extra_callback("before_e2", ast, xs, visit_node) + extra_callback("before_e2", s, ast, xs, visit_node) if ast.op.op == "is" or ast.op.op == "as" then - xs[3] = recurse_type(ast.e2.casttype, visit_type) + xs[3] = recurse_type(s, ast.e2.casttype, visit_type) else xs[3] = recurse(ast.e2) end @@ -4517,7 +4517,7 @@ local function recurse_node(root: Node, xs[1] = recurse(ast.key) xs[2] = recurse(ast.value) if ast.itemtype then - xs[3] = recurse_type(ast.itemtype, visit_type) + xs[3] = recurse_type(s, ast.itemtype, visit_type) end end, @@ -4543,13 +4543,13 @@ local function recurse_node(root: Node, if ast.exp then xs[1] = recurse(ast.exp) end - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[2] = recurse(ast.body) end, ["while"] = function(ast: Node, xs: {T}) xs[1] = recurse(ast.exp) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[2] = recurse(ast.body) end, @@ -4559,45 +4559,45 @@ local function recurse_node(root: Node, end, ["macroexp"] = function(ast: Node, xs: {T}) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.args) - xs[2] = recurse_type(ast.rets, visit_type) - extra_callback("before_exp", ast, xs, visit_node) + xs[2] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_exp", s, ast, xs, visit_node) xs[3] = recurse(ast.exp) end, ["function"] = function(ast: Node, xs: {T}) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.args) - xs[2] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[2] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[3] = recurse(ast.body) end, ["local_function"] = walk_named_function, ["global_function"] = walk_named_function, ["record_function"] = function(ast: Node, xs: {T}) - recurse_typeargs(ast, visit_type) + recurse_typeargs(s, ast, visit_type) xs[1] = recurse(ast.fn_owner) xs[2] = recurse(ast.name) - extra_callback("before_arguments", ast, xs, visit_node) + extra_callback("before_arguments", s, ast, xs, visit_node) xs[3] = recurse(ast.args) - xs[4] = recurse_type(ast.rets, visit_type) - extra_callback("before_statements", ast, xs, visit_node) + xs[4] = recurse_type(s, ast.rets, visit_type) + extra_callback("before_statements", s, ast, xs, visit_node) xs[5] = recurse(ast.body) end, ["local_macroexp"] = function(ast: Node, xs: {T}) -- TODO: generic macroexp xs[1] = recurse(ast.name) xs[2] = recurse(ast.macrodef.args) - xs[3] = recurse_type(ast.macrodef.rets, visit_type) - extra_callback("before_exp", ast, xs, visit_node) + xs[3] = recurse_type(s, ast.macrodef.rets, visit_type) + extra_callback("before_exp", s, ast, xs, visit_node) xs[4] = recurse(ast.macrodef.exp) end, ["forin"] = function(ast: Node, xs: {T}) xs[1] = recurse(ast.vars) xs[2] = recurse(ast.exps) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[3] = recurse(ast.body) end, @@ -4606,7 +4606,7 @@ local function recurse_node(root: Node, xs[2] = recurse(ast.from) xs[3] = recurse(ast.to) xs[4] = ast.step and recurse(ast.step) - extra_callback("before_statements", ast, xs, visit_node) + extra_callback("before_statements", s, ast, xs, visit_node) xs[5] = recurse(ast.body) end, @@ -4623,12 +4623,12 @@ local function recurse_node(root: Node, end, ["newtype"] = function(ast: Node, xs:{T}) - xs[1] = recurse_type(ast.newtype, visit_type) + xs[1] = recurse_type(s, ast.newtype, visit_type) end, ["argument"] = function(ast: Node, xs:{T}) if ast.argtype then - xs[1] = recurse_type(ast.argtype, visit_type) + xs[1] = recurse_type(s, ast.argtype, visit_type) end end, } @@ -4647,7 +4647,7 @@ local function recurse_node(root: Node, local cbkind = cbs and cbs[kind] if cbkind then if cbkind.before then - cbkind.before(ast) + cbkind.before(s, ast) end end @@ -4671,10 +4671,10 @@ local function recurse_node(root: Node, local ret: T local cbkind_after = cbkind and cbkind.after if cbkind_after then - ret = cbkind_after(ast, xs) + ret = cbkind_after(s, ast, xs) end if visit_after then - ret = visit_after(ast, xs, ret) + ret = visit_after(s, ast, xs, ret) end if TL_DEBUG then @@ -4757,7 +4757,7 @@ local primitive: {TypeName:string} = { ["thread"] = "thread", } -function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | PrettyPrintOptions): string, string +function tl.pretty_print_ast(ast: Node, gen_target: GenTarget, mode: boolean | PrettyPrintOptions): string, string local err: string local indent = 0 @@ -4778,7 +4778,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | local save_indent: {integer} = {} - local function increment_indent(node: Node) + local function increment_indent(_: nil, node: Node) local child = node.body or node[1] if not child then return @@ -4871,7 +4871,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | return table.concat(out) end - local visit_node: Visitor = {} + local visit_node: Visitor = {} local lua_54_attribute : {Attribute:string} = { ["const"] = " ", @@ -4881,7 +4881,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | visit_node.cbs = { ["statements"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output if opts.preserve_hashbang and node.hashbang then out = { y = 1, h = 0 } @@ -4903,7 +4903,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end }, ["local_declaration"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "local ") for i, var in ipairs(node.vars) do @@ -4929,7 +4929,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["local_type"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if not node.var.elide_type then table.insert(out, "local") @@ -4941,7 +4941,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["global_type"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if children[2] then add_child(out, children[1]) @@ -4952,7 +4952,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["global_declaration"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if children[3] then add_child(out, children[1]) @@ -4963,7 +4963,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["assignment"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } add_child(out, children[1]) table.insert(out, " =") @@ -4972,7 +4972,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["if"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } for i, child in ipairs(children) do add_child(out, child, i > 1 and " ", child.y ~= node.y and indent) @@ -4983,7 +4983,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["if_block"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if node.if_block_n == 1 then table.insert(out, "if") @@ -5003,7 +5003,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["while"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "while") add_child(out, children[1], " ") @@ -5016,7 +5016,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["repeat"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "repeat") add_child(out, children[1], " ") @@ -5028,7 +5028,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["do"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "do") add_child(out, children[1], " ") @@ -5039,7 +5039,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["forin"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") @@ -5054,7 +5054,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["fornum"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") @@ -5074,7 +5074,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["return"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "return") if #children[1] > 0 then @@ -5084,14 +5084,14 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["break"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "break") return out end, }, ["variable_list"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } local space: string for i, child in ipairs(children) do @@ -5106,7 +5106,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["literal_table"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if #children == 0 then table.insert(out, "{}") @@ -5126,7 +5126,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["literal_table_item"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if node.key_parsed ~= "implicit" then if node.key_parsed == "short" then @@ -5149,13 +5149,13 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["local_macroexp"] = { before = increment_indent, - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output return { y = node.y, h = 0 } end, }, ["local_function"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "local function") add_child(out, children[1], " ") @@ -5170,7 +5170,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["global_function"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") @@ -5185,7 +5185,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["record_function"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") @@ -5210,7 +5210,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, ["function"] = { before = increment_indent, - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "function(") add_child(out, children[1]) @@ -5224,7 +5224,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | ["cast"] = { }, ["paren"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "(") add_child(out, children[1], "", indent) @@ -5233,7 +5233,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["op"] = { - after = function(node: Node, children: {Output}): Output + after = function(_: nil, node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } if node.op.op == "@funcall" then add_child(out, children[1], "", indent) @@ -5294,14 +5294,14 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["variable"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } add_string(out, node.tk) return out end, }, ["newtype"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } local nt = node.newtype if nt is TypeAliasType then @@ -5318,7 +5318,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["goto"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "goto ") table.insert(out, node.label) @@ -5326,7 +5326,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | end, }, ["label"] = { - after = function(node: Node, _children: {Output}): Output + after = function(_: nil, node: Node, _children: {Output}): Output local out: Output = { y = node.y, h = 0 } table.insert(out, "::") table.insert(out, node.label) @@ -5336,10 +5336,10 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | }, } - local visit_type: Visitor = {} + local visit_type: Visitor = {} visit_type.cbs = {} local default_type_visitor = { - after = function(typ: Type, _children: {Output}): Output + after = function(_: nil, typ: Type, _children: {Output}): Output local out: Output = { y = typ.y or -1, h = 0 } local r = typ is NominalType and typ.resolved or typ local lua_type = primitive[r.typename] or "table" @@ -5377,7 +5377,6 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | visit_type.cbs["any"] = default_type_visitor visit_type.cbs["unknown"] = default_type_visitor visit_type.cbs["invalid"] = default_type_visitor - visit_type.cbs["unresolved"] = default_type_visitor visit_type.cbs["none"] = default_type_visitor visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] @@ -5392,7 +5391,7 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | visit_node.cbs["argument"] = visit_node.cbs["variable"] visit_node.cbs["type_identifier"] = visit_node.cbs["variable"] - local out = recurse_node(ast, visit_node, visit_type) + local out = recurse_node(nil, ast, visit_node, visit_type) if err then return nil, err end @@ -5442,7 +5441,6 @@ local typename_to_typecode : {TypeName:integer} = { ["none"] = tl.typecodes.UNKNOWN, ["tuple"] = tl.typecodes.UNKNOWN, ["literal_table_item"] = tl.typecodes.UNKNOWN, - ["unresolved"] = tl.typecodes.UNKNOWN, ["typedecl"] = tl.typecodes.UNKNOWN, ["typealias"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, @@ -5450,8 +5448,8 @@ local typename_to_typecode : {TypeName:integer} = { local skip_types: {TypeName: boolean} = { ["none"] = true, + ["tuple"] = true, ["literal_table_item"] = true, - ["unresolved"] = true, } local function sorted_keys(m: {A:B}):{A} @@ -5474,6 +5472,7 @@ function tl.new_type_reporter(): TypeReporter local self: TypeReporter = { next_num = 1, typeid_to_num = {}, + typename_to_num = {}, tr = { by_pos = {}, types = {}, @@ -5481,6 +5480,24 @@ function tl.new_type_reporter(): TypeReporter globals = {}, }, } + + local names = {} + for name, _ in pairs(simple_types) do + table.insert(names, name) + end + table.sort(names) + + for _, name in ipairs(names) do + local ti: TypeInfo = { + t = assert(typename_to_typecode[name]), + str = name, + } + local n = self.next_num + self.typename_to_num[name] = n + self.tr.types[n] = ti + self.next_num = self.next_num + 1 + end + return setmetatable(self, { __index = TypeReporter }) end @@ -5500,9 +5517,15 @@ function TypeReporter:store_function(ti: TypeInfo, rt: FunctionType) end function TypeReporter:get_typenum(t: Type): integer + -- try simple types first + local n = self.typename_to_num[t.typename] + if n then + return n + end + assert(t.typeid) -- try by typeid - local n = self.typeid_to_num[t.typeid] + n = self.typeid_to_num[t.typeid] if n then return n end @@ -5526,7 +5549,7 @@ function TypeReporter:get_typenum(t: Type): integer local ti: TypeInfo = { t = assert(typename_to_typecode[rt.typename]), str = show_type(t, true), - file = t.filename, + file = t.f, y = t.y, x = t.x, } @@ -5596,7 +5619,7 @@ local record TypeCollector end function TypeReporter:get_collector(filename: string): TypeCollector - local tc: TypeCollector = { + local collector: TypeCollector = { filename = filename, symbol_list = {}, } @@ -5604,10 +5627,10 @@ function TypeReporter:get_collector(filename: string): TypeCollector local ft: {integer:{integer:integer}} = {} self.tr.by_pos[filename] = ft - local symbol_list = tc.symbol_list + local symbol_list = collector.symbol_list local symbol_list_n = 0 - tc.store_type = function(y: integer, x: integer, typ: Type) + collector.store_type = function(y: integer, x: integer, typ: Type) if not typ or skip_types[typ.typename] then return end @@ -5621,12 +5644,12 @@ function TypeReporter:get_collector(filename: string): TypeCollector yt[x] = self:get_typenum(typ) end - tc.reserve_symbol_list_slot = function(node: Node) + collector.reserve_symbol_list_slot = function(node: Node) symbol_list_n = symbol_list_n + 1 node.symbol_list_slot = symbol_list_n end - tc.add_to_symbol_list = function(node: Node, name: string, t: Type) + collector.add_to_symbol_list = function(node: Node, name: string, t: Type) if not node then return end @@ -5640,12 +5663,12 @@ function TypeReporter:get_collector(filename: string): TypeCollector symbol_list[slot] = { y = node.y, x = node.x, name = name, typ = t } end - tc.begin_symbol_list_scope = function(node: Node) + collector.begin_symbol_list_scope = function(node: Node) symbol_list_n = symbol_list_n + 1 symbol_list[symbol_list_n] = { y = node.y, x = node.x, name = "@{" } end - tc.end_symbol_list_scope = function(node: Node) + collector.end_symbol_list_scope = function(node: Node) if symbol_list[symbol_list_n].name == "@{" then symbol_list[symbol_list_n] = nil symbol_list_n = symbol_list_n - 1 @@ -5655,14 +5678,14 @@ function TypeReporter:get_collector(filename: string): TypeCollector end end - return tc + return collector end -function TypeReporter:store_result(tc: TypeCollector, globals: {string:Variable}) +function TypeReporter:store_result(collector: TypeCollector, globals: {string:Variable}) local tr = self.tr - local filename = tc.filename - local symbol_list = tc.symbol_list + local filename = collector.filename + local symbol_list = collector.symbol_list tr.by_pos[filename][0] = nil @@ -5730,144 +5753,446 @@ function TypeReporter:get_report(): TypeReport return self.tr end --- backwards compatibility -function tl.get_types(result: Result): TypeReport, TypeReporter - return result.env.reporter:get_report(), result.env.reporter -end -------------------------------------------------------------------------------- --- Type check +-- Report types -------------------------------------------------------------------------------- -local NONE = a_type("none", {}) -local INVALID = a_type("invalid", {} as InvalidType) -local UNKNOWN = a_type("unknown", {}) -local CIRCULAR_REQUIRE = a_type("circular_require", {}) - -local FUNCTION = a_fn { args = va_args { ANY }, rets = va_args { ANY } } - ---local NOMINAL_FILE = a_type("nominal", { names = {"FILE"} } as NominalType) -local XPCALL_MSGH_FUNCTION = a_fn { args = { ANY }, rets = { } } - ---local USERDATA = ANY -- Placeholder for maybe having a userdata "primitive" type - -local numeric_binop = { +function tl.symbols_in_scope(tr: TypeReport, y: integer, x: integer): {string:integer} + local function find(symbols: {{integer, integer, string, integer}}, at_y: integer, at_x: integer): integer + local function le(a: {integer, integer}, b: {integer, integer}): boolean + return a[1] < b[1] + or (a[1] == b[1] and a[2] <= b[2]) + end + return binary_search(symbols, {at_y, at_x}, le) or 0 + end + + local ret: {string:integer} = {} + + local n = find(tr.symbols, y, x) + + local symbols = tr.symbols + while n >= 1 do + local s = symbols[n] + if s[3] == "@{" then + n = n - 1 + elseif s[3] == "@}" then + n = s[4] + else + ret[s[3]] = s[4] + n = n - 1 + end + end + + return ret +end + +-------------------------------------------------------------------------------- +-- Errors +-------------------------------------------------------------------------------- + +function Errors.new(filename: string): Errors + local self = { + errors = {}, + warnings = {}, + unknown_dots = {}, + filename = filename, + } + return setmetatable(self, { __index = Errors }) +end + +local function Err(msg: string, t1?: Type, t2?: Type, t3?: Type): Error + if t1 then + local s1, s2, s3: string, string, string + if t1 is InvalidType then + return nil + end + s1 = show_type(t1) + if t2 then + if t2 is InvalidType then + return nil + end + s2 = show_type(t2) + end + if t3 then + if t3 is InvalidType then + return nil + end + s3 = show_type(t3) + end + msg = msg:format(s1, s2, s3) + return { + msg = msg, + x = t1.x, + y = t1.y, + filename = t1.f, + } + end + + return { + msg = msg, + } +end + +local function insert_error(self: Errors, y: integer, x: integer, err: Error) + err.y = assert(y) + err.x = assert(x) + err.filename = self.filename + + if TL_DEBUG then + io.stderr:write("ERROR:" .. err.y .. ":" .. err.x .. ": " .. err.msg .. "\n") + end + + table.insert(self.errors, err) +end + +function Errors:add(w: Where, msg: string, ...:Type) + local e = Err(msg, ...) + if e then + insert_error(self, w.y, w.x, e) + end +end + +local context_name: {NodeKind: string} = { + ["local_declaration"] = "in local declaration", + ["global_declaration"] = "in global declaration", + ["assignment"] = "in assignment", + ["literal_table_item"] = "in table item", +} + +function Errors:get_context(ctx: Node|string, name?: string): string + if not ctx then + return "" + end + local ec = (ctx is Node) and ctx.expected_context + local cn = (ctx is string) and ctx or + (ctx is Node) and context_name[ec and ec.kind or ctx.kind] + return (cn and cn .. ": " or "") .. (ec and ec.name and ec.name .. ": " or "") .. (name and name .. ": " or "") +end + +function Errors:add_in_context(w: Where, ctx: Node, msg: string, ...:Type) + local prefix = self:get_context(ctx) + msg = prefix .. msg + + local e = Err(msg, ...) + if e then + insert_error(self, w.y, w.x, e) + end +end + + +function Errors:collect(errs: {Error}) + for _, e in ipairs(errs) do + insert_error(self, e.y, e.x, e) + end +end + +function Errors:add_warning(tag: WarningKind, w: Where, fmt: string, ...: any) + assert(w.y) + table.insert(self.warnings, { + y = w.y, + x = w.x, + msg = fmt:format(...), + filename = self.filename, + tag = tag, + }) +end + +function Errors:invalid_at(w: Where, msg: string, ...:Type): InvalidType + self:add(w, msg, ...) + return an_invalid(w) +end + +function Errors:add_unknown(node: Node, name: string) + self:add_warning("unknown", node, "unknown variable: %s", name) +end + +function Errors:redeclaration_warning(node: Node, old_var?: Variable) + if node.tk:sub(1, 1) == "_" then return end + + local var_kind = "variable" + local var_name = node.tk + if node.kind == "local_function" or node.kind == "record_function" then + var_kind = "function" + var_name = node.name.tk + end + + local short_error = "redeclaration of " .. var_kind .. " '%s'" + if old_var and old_var.declared_at then + self:add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) + else + self:add_warning("redeclaration", node, short_error, var_name) + end +end + +function Errors:unused_warning(name: string, var: Variable) + local prefix = name:sub(1,1) + if var.declared_at + and var.is_narrowed ~= "narrow" + and prefix ~= "_" + and prefix ~= "@" + then + local t = var.t + self:add_warning( + "unused", + var.declared_at, + "unused %s %s: %s", + var.is_func_arg and "argument" + or t is FunctionType and "function" + or t is TypeDeclType and "type" + or t is TypeAliasType and "type" + or "variable", + name, + show_type(var.t) + ) + end +end + +function Errors:add_prefixing(w: Where, src: {Error}, prefix: string, dst?: {Error}) + if not src then + return + end + + for _, err in ipairs(src) do + err.msg = prefix .. err.msg + if w and ( + (err.filename ~= w.f) + or (not err.y) + or (w.y > err.y or (w.y == err.y and w.x > err.x)) + ) then + err.y = w.y + err.x = w.x + err.filename = w.f + end + + if dst then + table.insert(dst, err) + else + insert_error(self, err.y, err.x, err) + end + end +end + +local record Unused + y: integer + x: integer + name: string + var: Variable +end + +local function check_for_unused_vars(scope: Scope, is_global?: boolean): {Unused} + local vars = scope.vars + if not next(vars) then + return + end + local list: {Unused} + for name, var in pairs(vars) do + local t = var.t + if var.declared_at and not var.used then + if var.used_as_type then + var.declared_at.elide_type = true + else + if (t is TypeDeclType or t is TypeAliasType) and not is_global then + var.declared_at.elide_type = true + end + list = list or {} + table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) + end + elseif var.used and (t is TypeDeclType or t is TypeAliasType) and var.aliasing then + var.aliasing.used = true + var.aliasing.declared_at.elide_type = false + end + end + if list then + table.sort(list, function(a: Unused, b: Unused): boolean + return a.y < b.y or (a.y == b.y and a.x < b.x) + end) + end + return list +end + +function Errors:warn_unused_vars(scope: Scope, is_global?: boolean) + local unused = check_for_unused_vars(scope, is_global) + if unused then + for _, u in ipairs(unused) do + self:unused_warning(u.name, u.var) + end + end + + if scope.labels then + for name, node in pairs(scope.labels) do + if not node.used_label then + self:add_warning("unused", node, "unused label ::%s::", name) + end + end + end +end + +function Errors:add_unknown_dot(node: Node, name: string) + if not self.unknown_dots[name] then + self.unknown_dots[name] = true + self:add_unknown(node, name) + end +end + +function Errors:fail_unresolved_labels(scope: Scope) + if scope.pending_labels then + for name, nodes in pairs(scope.pending_labels) do + for _, node in ipairs(nodes) do + self:add(node, "no visible label '" .. name .. "' for goto") + end + end + end +end + +function Errors:fail_unresolved_nominals(scope: Scope, global_scope: Scope) + if global_scope and scope.pending_nominals then + for name, types in pairs(scope.pending_nominals) do + if not global_scope.pending_global_types[name] then + for _, typ in ipairs(types) do + assert(typ.x) + assert(typ.y) + self:add(typ, "unknown type %s", typ) + end + end + end + end +end + +local type CheckableKey = string | number | boolean + +function Errors:check_redeclared_key(w: Where, ctx: Node, seen_keys: {CheckableKey:Where}, key: CheckableKey) + if key ~= nil then + local s = seen_keys[key] + if s then + self:add_in_context(w, ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. self.filename .. ":" .. s.y .. ":" .. s.x .. ")") + else + seen_keys[key] = w + end + end +end + +-------------------------------------------------------------------------------- +-- Type check +-------------------------------------------------------------------------------- + +local numeric_binop = { ["number"] = { - ["number"] = NUMBER, - ["integer"] = NUMBER, + ["number"] = "number", + ["integer"] = "number", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = NUMBER, + ["integer"] = "integer", + ["number"] = "number", }, } local float_binop = { ["number"] = { - ["number"] = NUMBER, - ["integer"] = NUMBER, + ["number"] = "number", + ["integer"] = "number", }, ["integer"] = { - ["integer"] = NUMBER, - ["number"] = NUMBER, + ["integer"] = "number", + ["number"] = "number", }, } local integer_binop = { ["number"] = { - ["number"] = INTEGER, - ["integer"] = INTEGER, + ["number"] = "integer", + ["integer"] = "integer", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = INTEGER, + ["integer"] = "integer", + ["number"] = "integer", }, } local relational_binop = { ["number"] = { - ["integer"] = BOOLEAN, - ["number"] = BOOLEAN, + ["integer"] = "boolean", + ["number"] = "boolean", }, ["integer"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", }, ["string"] = { - ["string"] = BOOLEAN, + ["string"] = "boolean", }, ["boolean"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, } local equality_binop = { ["number"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", + ["nil"] = "boolean", }, ["integer"] = { - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["number"] = "boolean", + ["integer"] = "boolean", + ["nil"] = "boolean", }, ["string"] = { - ["string"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["string"] = "boolean", + ["nil"] = "boolean", }, ["boolean"] = { - ["boolean"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["boolean"] = "boolean", + ["nil"] = "boolean", }, ["record"] = { - ["emptytable"] = BOOLEAN, - ["record"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["record"] = "boolean", + ["nil"] = "boolean", }, ["array"] = { - ["emptytable"] = BOOLEAN, - ["array"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["array"] = "boolean", + ["nil"] = "boolean", }, ["map"] = { - ["emptytable"] = BOOLEAN, - ["map"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["emptytable"] = "boolean", + ["map"] = "boolean", + ["nil"] = "boolean", }, ["thread"] = { - ["thread"] = BOOLEAN, - ["nil"] = BOOLEAN, + ["thread"] = "boolean", + ["nil"] = "boolean", } } -local unop_types: {string:{string:Type}} = { +local unop_types: {string:{TypeName:TypeName}} = { ["#"] = { - ["string"] = INTEGER, - ["array"] = INTEGER, - ["tupletable"] = INTEGER, - ["map"] = INTEGER, - ["emptytable"] = INTEGER, + ["string"] = "integer", + ["array"] = "integer", + ["tupletable"] = "integer", + ["map"] = "integer", + ["emptytable"] = "integer", }, ["-"] = { - ["number"] = NUMBER, - ["integer"] = INTEGER, + ["number"] = "number", + ["integer"] = "integer", }, ["~"] = { - ["number"] = INTEGER, - ["integer"] = INTEGER, + ["number"] = "integer", + ["integer"] = "integer", }, ["not"] = { - ["string"] = BOOLEAN, - ["number"] = BOOLEAN, - ["integer"] = BOOLEAN, - ["boolean"] = BOOLEAN, - ["record"] = BOOLEAN, - ["array"] = BOOLEAN, - ["tupletable"] = BOOLEAN, - ["map"] = BOOLEAN, - ["emptytable"] = BOOLEAN, - ["thread"] = BOOLEAN, + ["string"] = "boolean", + ["number"] = "boolean", + ["integer"] = "boolean", + ["boolean"] = "boolean", + ["record"] = "boolean", + ["array"] = "boolean", + ["tupletable"] = "boolean", + ["map"] = "boolean", + ["emptytable"] = "boolean", + ["thread"] = "boolean", }, } @@ -5877,7 +6202,7 @@ local unop_to_metamethod: {string:string} = { ["~"] = "__bnot", } -local binop_types: {string:{TypeName:{TypeName:Type}}} = { +local binop_types: {string:{TypeName:{TypeName:TypeName}}} = { ["+"] = numeric_binop, ["-"] = numeric_binop, ["*"] = numeric_binop, @@ -5898,67 +6223,66 @@ local binop_types: {string:{TypeName:{TypeName:Type}}} = { [">"] = relational_binop, ["or"] = { ["boolean"] = { - ["boolean"] = BOOLEAN, - ["function"] = FUNCTION, -- HACK + ["boolean"] = "boolean", }, ["number"] = { - ["integer"] = NUMBER, - ["number"] = NUMBER, - ["boolean"] = BOOLEAN, + ["integer"] = "number", + ["number"] = "number", + ["boolean"] = "boolean", }, ["integer"] = { - ["integer"] = INTEGER, - ["number"] = NUMBER, - ["boolean"] = BOOLEAN, + ["integer"] = "integer", + ["number"] = "number", + ["boolean"] = "boolean", }, ["string"] = { - ["string"] = STRING, - ["boolean"] = BOOLEAN, - ["enum"] = STRING, + ["string"] = "string", + ["boolean"] = "boolean", + ["enum"] = "string", }, ["function"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["array"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["record"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["map"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", }, ["enum"] = { - ["string"] = STRING, + ["string"] = "string", }, ["thread"] = { - ["boolean"] = BOOLEAN, + ["boolean"] = "boolean", } }, [".."] = { ["string"] = { - ["string"] = STRING, - ["enum"] = STRING, - ["number"] = STRING, - ["integer"] = STRING, + ["string"] = "string", + ["enum"] = "string", + ["number"] = "string", + ["integer"] = "string", }, ["number"] = { - ["integer"] = STRING, - ["number"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["integer"] = "string", + ["number"] = "string", + ["string"] = "string", + ["enum"] = "string", }, ["integer"] = { - ["integer"] = STRING, - ["number"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["integer"] = "string", + ["number"] = "string", + ["string"] = "string", + ["enum"] = "string", }, ["enum"] = { - ["number"] = STRING, - ["integer"] = STRING, - ["string"] = STRING, - ["enum"] = STRING, + ["number"] = "string", + ["integer"] = "string", + ["string"] = "string", + ["enum"] = "string", } }, } @@ -6166,8 +6490,8 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str end end -local function inferred_msg(t: Type): string - return " (inferred at "..t.inferred_at.filename..":"..t.inferred_at.y..":"..t.inferred_at.x..")" +local function inferred_msg(t: Type, prefix?: string): string + return " (" .. (prefix or "") .. "inferred at "..t.inferred_at.f..":"..t.inferred_at.y..":"..t.inferred_at.x..")" end show_type = function(t: Type, short?: boolean, seen?: {Type:string}): string @@ -6219,33 +6543,34 @@ function tl.search_module(module_name: string, search_dtl: boolean): string, FIL return nil, nil, tried end -local function require_module(module_name: string, lax: boolean, env: Env): Type, boolean +local function require_module(w: Where, module_name: string, feat_lax: boolean, env: Env): Type, string local mod = env.modules[module_name] if mod then - return mod, true + return mod, env.module_filenames[module_name] end local found, fd = tl.search_module(module_name, true) - if found and (lax or found:match("tl$") as boolean) then + if found and (feat_lax or found:match("tl$") as boolean) then - env.modules[module_name] = a_typedecl(CIRCULAR_REQUIRE) + env.module_filenames[module_name] = found + env.modules[module_name] = a_typedecl(w, a_type(w, "circular_require", {})) local found_result, err: Result, string = tl.process(found, env, fd) assert(found_result, err) env.modules[module_name] = found_result.type - return found_result.type, true + return found_result.type, found elseif fd then fd:close() end - return INVALID, found ~= nil + return an_invalid(w), found end local compat_code_cache: {string:Node} = {} -local function add_compat_entries(program: Node, used_set: {string: boolean}, gen_compat: CompatMode) +local function add_compat_entries(program: Node, used_set: {string: boolean}, gen_compat: GenCompat) if gen_compat == "off" or not next(used_set) then return end @@ -6262,7 +6587,7 @@ local function add_compat_entries(program: Node, used_set: {string: boolean}, ge local code: Node = compat_code_cache[name] if not code then code = tl.parse(text, "@internal") - tl.type_check(code, { filename = "", lax = false, gen_compat = "off" }) + tl.type_check(code, "@internal", { feat_lax = "off", gen_compat = "off" }) compat_code_cache[name] = code end for _, c in ipairs(code) do @@ -6301,32 +6626,26 @@ local function add_compat_entries(program: Node, used_set: {string: boolean}, ge TL_DEBUG = tl_debug end -local function get_stdlib_compat(lax: boolean): {string:boolean} - if lax then - return { - ["utf8"] = true, - } - else - return { - ["io"] = true, - ["math"] = true, - ["string"] = true, - ["table"] = true, - ["utf8"] = true, - ["coroutine"] = true, - ["os"] = true, - ["package"] = true, - ["debug"] = true, - ["load"] = true, - ["loadfile"] = true, - ["assert"] = true, - ["pairs"] = true, - ["ipairs"] = true, - ["pcall"] = true, - ["xpcall"] = true, - ["rawlen"] = true, - } - end +local function get_stdlib_compat(): {string:boolean} + return { + ["io"] = true, + ["math"] = true, + ["string"] = true, + ["table"] = true, + ["utf8"] = true, + ["coroutine"] = true, + ["os"] = true, + ["package"] = true, + ["debug"] = true, + ["load"] = true, + ["loadfile"] = true, + ["assert"] = true, + ["pairs"] = true, + ["ipairs"] = true, + ["pcall"] = true, + ["xpcall"] = true, + ["rawlen"] = true, + } end local bit_operators: {string:string} = { @@ -6337,14 +6656,21 @@ local bit_operators: {string:string} = { ["<<"] = "lshift", } +local function node_at(w: Where, n: Node): Node + n.f = assert(w.f) + n.x = w.x + n.y = w.y + return n +end + local function convert_node_to_compat_call(node: Node, mod_name: string, fn_name: string, e1: Node, e2?: Node) node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 - node.e1 = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, ".") } - node.e1.e1 = { y = node.y, x = node.x, kind = "identifier", tk = mod_name } - node.e1.e2 = { y = node.y, x = node.x, kind = "identifier", tk = fn_name } - node.e2 = { y = node.y, x = node.x, kind = "expression_list" } + node.e1 = node_at(node, { kind = "op", op = an_operator(node, 2, ".") }) + node.e1.e1 = node_at(node, { kind = "identifier", tk = mod_name }) + node.e1.e2 = node_at(node, { kind = "identifier", tk = fn_name }) + node.e2 = node_at(node, { kind = "expression_list" }) node.e2[1] = e1 node.e2[2] = e2 end @@ -6353,10 +6679,10 @@ local function convert_node_to_compat_mt_call(node: Node, mt_name: string, which node.op.op = "@funcall" node.op.arity = 2 node.op.prec = 100 - node.e1 = { y = node.y, x = node.x, kind = "identifier", tk = "_tl_mt" } - node.e2 = { y = node.y, x = node.x, kind = "expression_list" } - node.e2[1] = { y = node.y, x = node.x, kind = "string", tk = "\"" .. mt_name .. "\"" } - node.e2[2] = { y = node.y, x = node.x, kind = "integer", tk = tostring(which_self) } + node.e1 = node_at(node, { kind = "identifier", tk = "_tl_mt" }) + node.e2 = node_at(node, { kind = "expression_list" }) + node.e2[1] = node_at(node, { kind = "string", tk = "\"" .. mt_name .. "\"" }) + node.e2[2] = node_at(node, { kind = "integer", tk = tostring(which_self) }) node.e2[3] = e1 node.e2[4] = e2 end @@ -6365,25 +6691,6 @@ local stdlib_globals: {string:Variable} = nil local globals_typeid = new_typeid() local fresh_typevar_ctr = 1 -local function set_feat(feat: tl.Feat, default: boolean): boolean - if feat then - return (feat == "on") - else - return default - end -end - -tl.new_env = function(opts: tl.EnvOptions): Env, string - local env, err = tl.init_env(opts.lax_mode, opts.gen_compat, opts.gen_target, opts.predefined_modules) - if not env then - return nil, err - end - - env.feat_arity = set_feat(opts.feat_arity, true) - - return env -end - local function assert_no_stdlib_errors(errors: {Error}, name: string) if #errors ~= 0 then local out = {} @@ -6394,46 +6701,31 @@ local function assert_no_stdlib_errors(errors: {Error}, name: string) end end -tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_target?: TargetMode, predefined?: {string}): Env, string - if gen_compat == true or gen_compat == nil then - gen_compat = "optional" - elseif gen_compat == false then - gen_compat = "off" - end - gen_compat = gen_compat as CompatMode - - if not gen_target then - if _VERSION == "Lua 5.1" or _VERSION == "Lua 5.2" then - gen_target = "5.1" - else - gen_target = "5.3" - end - end - - if gen_target == "5.4" and gen_compat ~= "off" then - return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" - end +tl.new_env = function(opts?: EnvOptions): Env, string + opts = opts or {} local env: Env = { modules = {}, + module_filenames = {}, loaded = {}, loaded_order = {}, globals = {}, - gen_compat = gen_compat, - gen_target = gen_target, + defaults = opts.defaults or {}, } + if env.defaults.gen_target == "5.4" and env.defaults.gen_compat ~= "off" then + return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" + end + + local w: Where = { f = "@stdlib", x = 1, y = 1 } + if not stdlib_globals then local tl_debug = TL_DEBUG TL_DEBUG = nil local program, syntax_errors = tl.parse(stdlib, "stdlib.d.tl") assert_no_stdlib_errors(syntax_errors, "syntax errors") - - local result = tl.type_check(program, { - filename = "@stdlib", - env = env - }) + local result = tl.type_check(program, "@stdlib", {}, env) assert_no_stdlib_errors(result.type_errors, "type errors") stdlib_globals = env.globals @@ -6442,21 +6734,20 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar -- special cases for compatibility local math_t = (stdlib_globals["math"].t as TypeDeclType).def as RecordType local table_t = (stdlib_globals["table"].t as TypeDeclType).def as RecordType - local integer_compat = a_type("integer", { needs_compat = true }) - math_t.fields["maxinteger"] = integer_compat - math_t.fields["mininteger"] = integer_compat + math_t.fields["maxinteger"].needs_compat = true + math_t.fields["mininteger"].needs_compat = true table_t.fields["unpack"].needs_compat = true -- only global scope and vararg functions accept `...`: -- `@is_va` is an internal sentinel value which is -- `any` if `...` is accepted in this scope or `nil` if it isn't. - stdlib_globals["..."] = { t = a_vararg { STRING } } - stdlib_globals["@is_va"] = { t = ANY } + stdlib_globals["..."] = { t = a_vararg(w, { a_type(w, "string", {}) }) } + stdlib_globals["@is_va"] = { t = a_type(w, "any", {}) } env.globals = {} end - local stdlib_compat = get_stdlib_compat(lax) + local stdlib_compat = get_stdlib_compat() for name, var in pairs(stdlib_globals) do env.globals[name] = var var.needs_compat = stdlib_compat[name] @@ -6467,52 +6758,53 @@ tl.init_env = function(lax?: boolean, gen_compat?: boolean | CompatMode, gen_tar end end - if predefined then - for _, name in ipairs(predefined) do - local module_type = require_module(name, lax, env) + if opts.predefined_modules then + for _, name in ipairs(opts.predefined_modules) do + local module_type = require_module(w, name, env.defaults.feat_lax == "on", env) - if module_type == INVALID then + if module_type is InvalidType then return nil, string.format("Error: could not predefine module '%s'", name) end end end - env.feat_arity = true - return env end -tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string - opts = opts or {} - local env = opts.env - if not env then - local err: string - env, err = tl.init_env(opts.lax, opts.gen_compat, opts.gen_target) - if err then - return nil, err - end - end +do + local type TypeRelations = {TypeName:{TypeName:CompareTypes}} + local type InvalidOrTupleType = InvalidType | TupleType - local lax = opts.lax - local feat_arity = env.feat_arity - local filename = opts.filename + local record TypeChecker + env: Env + st: {Scope} + + filename: string + errs: Errors + module_type: Type - local type Scope = {string:Variable} - local st: {Scope} = { env.globals } + subtype_relations: TypeRelations + eqtype_relations: TypeRelations + type_priorities: {TypeName:integer} - local all_needs_compat = {} + all_needs_compat: {string:boolean} + dependencies: {string:string} + collector: TypeCollector + + gen_compat: GenCompat + gen_target: GenTarget + feat_arity: boolean + feat_lax: boolean - local dependencies: {string:string} = {} - local warnings: {Error} = {} - local errors: {Error} = {} + same_type: function(TypeChecker, Type, Type): boolean, {Error} + is_a: function(TypeChecker, Type, Type): boolean, {Error} - local module_type: Type + type_check_funcall: function(TypeChecker, node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType - local tc: TypeCollector - if env.report_types then - env.reporter = env.reporter or tl.new_type_reporter() - tc = env.reporter:get_collector(filename or "?") + expand_type: function(TypeChecker, w: Where, old: Type, new: Type): Type + + get_rets: function(TupleType): TupleType end local enum VarUse @@ -6522,10 +6814,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string "check_only" end - local function find_var(name: string, use?: VarUse): Variable, integer, Attribute - for i = #st, 1, -1 do - local scope = st[i] - local var = scope[name] + function TypeChecker:find_var(name: string, use?: VarUse): Variable, integer, Attribute + for i = #self.st, 1, -1 do + local scope = self.st[i] + local var = scope.vars[name] if var then if use == "lvalue" and var.is_narrowed then if var.narrowed_from then @@ -6534,7 +6826,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else if i == 1 and var.needs_compat then - all_needs_compat[name] = true + self.all_needs_compat[name] = true end if use == "use_type" then var.used_as_type = true @@ -6547,10 +6839,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function simulate_g(): RecordType, Attribute + function TypeChecker:simulate_g(): RecordType, Attribute -- this is a static approximation of _G local globals: {string:Type} = {} - for k, v in pairs(st[1]) do + for k, v in pairs(self.st[1].vars) do if k:sub(1,1) ~= "@" then globals[k] = v.t end @@ -6563,101 +6855,61 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, nil end - local type ResolveType = function(Type): Type - local resolve_typevars: function (typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} + local type ResolveType = function(S, Type): Type + local typevar_resolver: function(s: S, typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} - local function fresh_typevar(t: TypeVarType): Type, Type, boolean - return a_type("typevar", { + local function fresh_typevar(_: nil, t: TypeVarType): Type, Type, boolean + return a_type(t, "typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, } as TypeVarType) end - local function fresh_typearg(t: TypeArgType): Type - return a_type("typearg", { + local function fresh_typearg(_: nil, t: TypeArgType): Type + return a_type(t, "typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, } as TypeArgType) end - local function ensure_fresh_typeargs(t: T): T + function TypeChecker:ensure_fresh_typeargs(t: T): T if not t is HasTypeArgs then return t end fresh_typevar_ctr = fresh_typevar_ctr + 1 local ok: boolean - ok, t = resolve_typevars(t, fresh_typevar, fresh_typearg) + ok, t = typevar_resolver(nil, t, fresh_typevar, fresh_typearg) assert(ok, "Internal Compiler Error: error creating fresh type variables") return t end - local function find_var_type(name: string, use?: VarUse): Type, Attribute, Type - local var = find_var(name, use) + function TypeChecker:find_var_type(name: string, use?: VarUse): Type, Attribute, Type + local var = self:find_var(name, use) if var then local t = var.t if t is UnresolvedTypeArgType then return nil, nil, t.constraint end - t = ensure_fresh_typeargs(t) + t = self:ensure_fresh_typeargs(t) return t, var.attribute end end - local function Err(where: Where, msg: string, ...: Type): Error - local n = select("#", ...) - if n > 0 then - local showt = {} - for i = 1, n do - local t = select(i, ...) - if t then - if t.typename == "invalid" then - return nil - end - showt[i] = show_type(t) - end - end - msg = msg:format(table.unpack(showt)) - end - local name = where.filename or filename - - if TL_DEBUG then - io.stderr:write("ERROR:" .. (where.y or -1) .. ":" .. (where.x or -1) .. ": " .. msg .. "\n") - end - - return { - y = where.y, - x = where.x, - msg = msg, - filename = name, - } - end - - local function error_at(w: Where, msg: string, ...:Type): boolean - assert(w.y) - - local e = Err(w, msg, ...) - if e then - table.insert(errors, e) - return true - else - return false - end - end - - local function ensure_not_abstract(where: Where, t: Type) + local function ensure_not_abstract(t: Type): boolean, string if t is FunctionType and t.macroexp then - error_at(where, "macroexps are abstract; consider using a concrete function") + return nil, "macroexps are abstract; consider using a concrete function" elseif t is TypeDeclType then local def = t.def if def is InterfaceType then - error_at(where, "interfaces are abstract; consider using a concrete record") + return nil, "interfaces are abstract; consider using a concrete record" end end + return true end - local function find_type(names: {string}, accept_typearg?: boolean): Type - local typ = find_var_type(names[1], "use_type") + function TypeChecker:find_type(names: {string}, accept_typearg?: boolean): Type + local typ = self:find_var_type(names[1], "use_type") if not typ then return nil end @@ -6679,7 +6931,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end - typ = ensure_fresh_typeargs(typ) + typ = self:ensure_fresh_typeargs(typ) if typ is NominalType and typ.found then typ = typ.found end @@ -6691,19 +6943,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function union_type(t: Type): string, Type + local function type_for_union(t: Type): string, Type if t is TypeDeclType then - return union_type(t.def), t.def + return type_for_union(t.def), t.def elseif t is TypeAliasType then - return union_type(t.alias_to), t.alias_to + return type_for_union(t.alias_to), t.alias_to elseif t is TupleType then - return union_type(t.tuple[1]), t.tuple[1] + return type_for_union(t.tuple[1]), t.tuple[1] elseif t is NominalType then local typedecl = t.found if not typedecl then return "invalid" end - return union_type(typedecl) + return type_for_union(typedecl) elseif t is RecordLikeType then if t.is_userdata then return "userdata", t @@ -6727,7 +6979,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local n_string_enum = 0 local has_primitive_string_type = false for _, t in ipairs(typ.types) do - local ut, rt = union_type(t) + local ut, rt = type_for_union(t) if ut == "userdata" then -- must be tested before table_types assert(rt is RecordLikeType) if rt.meta_fields and rt.meta_fields["__is"] then @@ -6808,24 +7060,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["unknown"] = true, } - local function default_resolve_typevars_callback(t: TypeVarType): Type - local rt = find_var_type(t.typevar) - if not rt then - return nil - elseif rt is StringType then - -- tk is not propagated - return STRING - end - return rt - end - - resolve_typevars = function(typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} + typevar_resolver = function(self: S, typ: Type, fn_var?: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} local errs: {Error} local seen: {Type:Type} = {} local resolved: {string:boolean} = {} - fn_var = fn_var or default_resolve_typevars_callback - local function resolve(t: T, all_same: boolean): T, boolean local same = true @@ -6840,7 +7079,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local orig_t = t if t is TypeVarType then - local rt = fn_var(t) + local rt = fn_var(self, t) if rt then resolved[t.typevar] = true if no_nested_types[rt.typename] or (rt is NominalType and not rt.typevals) then @@ -6856,7 +7095,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string seen[orig_t] = copy copy.typename = t.typename - copy.filename = t.filename + copy.f = t.f copy.x = t.x copy.y = t.y @@ -6867,7 +7106,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- inferred_len is not propagated elseif t is TypeArgType then if fn_arg then - copy = fn_arg(t) + copy = fn_arg(self, t) else assert(copy is TypeArgType) copy.typearg = t.typearg @@ -6960,7 +7199,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local _, err = is_valid_union(copy) if err then errs = errs or {} - table.insert(errs, Err(t, err, copy)) + table.insert(errs, Err(err, copy)) end elseif t is PolyType then assert(copy is PolyType) @@ -6970,6 +7209,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end elseif t is TupleTableType then assert(copy is TupleTableType) + copy.inferred_at = t.inferred_at copy.types = {} for i, tf in ipairs(t.types) do copy.types[i], same = resolve(tf, same) @@ -6989,7 +7229,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local copy, same = resolve(typ, true) if errs then - return false, INVALID, errs + return false, an_invalid(typ), errs end if (not same) and @@ -7008,153 +7248,81 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true, copy end - local function infer_emptytable(emptytable: EmptyTableType, fresh_t: Type) + local function resolve_typevar(tc: TypeChecker, t: TypeVarType): Type + local rt = tc:find_var_type(t.typevar) + if not rt then + return nil + elseif rt is StringType then + -- tk is not propagated + return a_type(rt, "string", {}) + end + return rt + end + + + + function TypeChecker:infer_emptytable(emptytable: EmptyTableType, fresh_t: Type) local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") - local nst = is_global and 1 or #st + local nst = is_global and 1 or #self.st for i = nst, 1, -1 do - local scope = st[i] - if scope[emptytable.assigned_to] then - scope[emptytable.assigned_to] = { t = fresh_t } + local scope = self.st[i] + if scope.vars[emptytable.assigned_to] then + scope.vars[emptytable.assigned_to] = { t = fresh_t } end end end local function resolve_tuple(t: Type): Type - if t is TupleType then - t = t.tuple[1] + local rt = t + if rt is TupleType then + rt = rt.tuple[1] end - if t == nil then - return NIL + if rt == nil then + return a_type(t, "nil", {}) end - return t - end - - local function add_warning(tag: tl.WarningKind, where: Where, fmt: string, ...: any) - table.insert(warnings, { - y = where.y, - x = where.x, - msg = fmt:format(...), - filename = where.filename or filename, - tag = tag, - }) - end - - local function invalid_at(where: Where, msg: string, ...:Type): InvalidType - error_at(where, msg, ...) - return INVALID - end - - local function add_unknown(node: Node, name: string) - add_warning("unknown", node, "unknown variable: %s", name) + return rt end - local function redeclaration_warning(node: Node, old_var?: Variable) - if node.tk:sub(1, 1) == "_" then return end - - local var_kind = "variable" - local var_name = node.tk - if node.kind == "local_function" or node.kind == "record_function" then - var_kind = "function" - var_name = node.name.tk - end - - local short_error = "redeclaration of " .. var_kind .. " '%s'" - if old_var and old_var.declared_at then - add_warning("redeclaration", node, short_error .. " (originally declared at %d:%d)", var_name, old_var.declared_at.y, old_var.declared_at.x) - else - add_warning("redeclaration", node, short_error, var_name) - end - end - local function check_if_redeclaration(new_name: string, at: Node) - local old = find_var(new_name, "check_only") + function TypeChecker:check_if_redeclaration(new_name: string, at: Node) + local old = self:find_var(new_name, "check_only") if old then - redeclaration_warning(at, old) + self.errs:redeclaration_warning(at, old) end end - local function unused_warning(name: string, var: Variable) - local prefix = name:sub(1,1) - if var.declared_at - and var.is_narrowed ~= "narrow" - and prefix ~= "_" - and prefix ~= "@" - then - if name:sub(1, 2) == "::" then - add_warning("unused", var.declared_at, "unused label %s", name) - else - local t = var.t - add_warning( - "unused", - var.declared_at, - "unused %s %s: %s", - var.is_func_arg and "argument" - or t is FunctionType and "function" - or t is TypeDeclType and "type" - or t is TypeAliasType and "type" - or "variable", - name, - show_type(var.t) - ) - end - end - end - - local function add_errs_prefixing(where: Where, src: {Error}, dst: {Error}, prefix: string) - assert(where == nil or where.y ~= nil) - - if not src then - return - end - for _, err in ipairs(src) do - err.msg = prefix .. err.msg - - if where and ( - (err.filename ~= filename) - or (not err.y) - or (where.y > err.y or (where.y == err.y and where.x > err.x)) - ) then - err.y = where.y - err.x = where.x - err.filename = filename - end - - table.insert(dst, err) - end - end local function type_at(w: Where, t: T): T t.x = w.x t.y = w.y - t.filename = filename return t end - local function resolve_typevars_at(where: Where, t: T): T - assert(where) - local ok, ret, errs = resolve_typevars(t) + function TypeChecker:resolve_typevars_at(w: Where, t: T): T + assert(w) + local ok, ret, errs = typevar_resolver(self, t, resolve_typevar) if not ok then - assert(where.y) - add_errs_prefixing(where, errs, errors, "") + assert(w.y) + self.errs:add_prefixing(w, errs, "") end - if ret == t or t.typename == "typevar" then + if ret == t or t is TypeVarType then ret = shallow_copy_table(ret) end - return type_at(where, ret) + return type_at(w, ret) end - local function infer_at(where: Where, t: T): T - local ret = resolve_typevars_at(where, t) - if ret.typename == "invalid" then + function TypeChecker:infer_at(w: Where, t: T): T + local ret = self:resolve_typevars_at(w, t) + if ret is InvalidType then ret = t -- errors are produced by resolve_typevars_at end - if ret == t or t.typename == "typevar" then + if ret == t or t is TypeVarType then ret = shallow_copy_table(ret) end - ret.inferred_at = where - ret.inferred_at.filename = filename + assert(w.f) + ret.inferred_at = w return ret end @@ -7167,12 +7335,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local get_unresolved: function(scope?: Scope): UnresolvedType - local find_unresolved: function(level?: integer): UnresolvedType - - local function add_to_scope(node: Node, name: string, t: Type, attribute: Attribute, narrow: Narrow, dont_check_redeclaration: boolean): Variable - local scope = st[#st] - local var = scope[name] + function TypeChecker:add_to_scope(node: Node, name: string, t: Type, attribute: Attribute, narrow: Narrow, dont_check_redeclaration: boolean): Variable + local scope = self.st[#self.st] + local var = scope.vars[name] if narrow then if var then if var.is_narrowed then @@ -7185,11 +7350,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string var.t = t else var = { t = t, attribute = attribute, is_narrowed = narrow, declared_at = node } - scope[name] = var + scope.vars[name] = var end - local unresolved = get_unresolved(scope) - unresolved.narrows[name] = true + scope.narrows = scope.narrows or {} + scope.narrows[name] = true return var end @@ -7200,46 +7365,39 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and name ~= "..." and name:sub(1, 1) ~= "@" then - check_if_redeclaration(name, node) + self:check_if_redeclaration(name, node) end if var and not var.used then -- the old var is removed from the scope and won't be checked when it closes, -- so check it here - unused_warning(name, var) + self.errs:unused_warning(name, var) end var = { t = t, attribute = attribute, is_narrowed = nil, declared_at = node } - scope[name] = var + scope.vars[name] = var return var end - local function add_var(node: Node, name: string, t: Type, attribute?: Attribute, narrow?: Narrow, dont_check_redeclaration?: boolean): Variable - if lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then - add_unknown(node, name) + function TypeChecker:add_var(node: Node, name: string, t: Type, attribute?: Attribute, narrow?: Narrow, dont_check_redeclaration?: boolean): Variable + if self.feat_lax and node and is_unknown(t) and (name ~= "self" and name ~= "...") and not narrow then + self.errs:add_unknown(node, name) end if not attribute then t = drop_constant_value(t) end - local var = add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - - if t is UnresolvedType or t.typename == "none" then - return var - end + local var = self:add_to_scope(node, name, t, attribute, narrow, dont_check_redeclaration) - if tc and node then - tc.add_to_symbol_list(node, name, t) + if self.collector and node then + self.collector.add_to_symbol_list(node, name, t) end return var end - local type CompareTypes = function(Type, Type): boolean, {Error} - - local same_type: function(t1: Type, t2: Type): boolean, {Error} - local is_a: function(Type, Type): boolean, {Error} + local type CompareTypes = function(TypeChecker, Type, Type): boolean, {Error} local enum ArgCheckMode "argument" @@ -7254,38 +7412,38 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string "invariant" end - local function arg_check(where: Where, all_errs: {Error}, a: Type, b: Type, v: VarianceMode, mode: ArgCheckMode, n?: integer): boolean + function TypeChecker:arg_check(w: Where, all_errs: {Error}, a: Type, b: Type, v: VarianceMode, mode: ArgCheckMode, n?: integer): boolean local ok, errs: boolean, {Error} if v == "covariant" then - ok, errs = is_a(a, b) + ok, errs = self:is_a(a, b) elseif v == "contravariant" then - ok, errs = is_a(b, a) + ok, errs = self:is_a(b, a) elseif v == "bivariant" then - ok, errs = is_a(a, b) + ok, errs = self:is_a(a, b) if ok then return true end - ok = is_a(b, a) + ok = self:is_a(b, a) if ok then return true end elseif v == "invariant" then - ok, errs = same_type(a, b) + ok, errs = self:same_type(a, b) end if not ok then - add_errs_prefixing(where, errs, all_errs, mode .. (n and " " .. n or "") .. ": ") + self.errs:add_prefixing(w, errs, mode .. (n and " " .. n or "") .. ": ", all_errs) return false end return true end - local function has_all_types_of(t1s: {Type}, t2s: {Type}): boolean + function TypeChecker:has_all_types_of(t1s: {Type}, t2s: {Type}): boolean for _, t1 in ipairs(t1s) do local found = false for _, t2 in ipairs(t2s) do - if same_type(t2, t1) then + if self:same_type(t2, t1) then found = true break end @@ -7317,8 +7475,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function close_types(vars: {string:Variable}) - for _, var in pairs(vars) do + local function close_types(scope: Scope) + for _, var in pairs(scope.vars) do local t = var.t if t is TypeDeclType then t.closed = true @@ -7330,161 +7488,96 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local record Unused - y: integer - x: integer - name: string - var: Variable - end - - local function check_for_unused_vars(vars: {string:Variable}, is_global?: boolean) - if not next(vars) then - return - end - local list: {Unused} = {} - for name, var in pairs(vars) do - local t = var.t - if var.declared_at and not var.used then - if var.used_as_type then - var.declared_at.elide_type = true - else - if (t is TypeDeclType or t is TypeAliasType) and not is_global then - var.declared_at.elide_type = true - end - table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var }) - end - elseif var.used and (t is TypeDeclType or t is TypeAliasType) and var.aliasing then - var.aliasing.used = true - var.aliasing.declared_at.elide_type = false - end - end - if list[1] then - table.sort(list, function(a: Unused, b: Unused): boolean - return a.y < b.y or (a.y == b.y and a.x < b.x) - end) - for _, u in ipairs(list) do - unused_warning(u.name, u.var) - end - end - end - - get_unresolved = function(scope?: Scope): UnresolvedType - local unresolved: UnresolvedType - if scope then - local unr = scope["@unresolved"] - unresolved = unr and unr.t as UnresolvedType - else - unresolved = find_var_type("@unresolved") as UnresolvedType - end - if not unresolved then - unresolved = a_type("unresolved", { - labels = {}, - nominals = {}, - global_types = {}, - narrows = {}, - } as UnresolvedType) - add_var(nil, "@unresolved", unresolved) - end - return unresolved - end - - find_unresolved = function(level?: integer): UnresolvedType - local u = st[level or #st]["@unresolved"] - if u then - return u.t as UnresolvedType - end - end - - local function begin_scope(node?: Node) - table.insert(st, {}) + function TypeChecker:begin_scope(node?: Node) + table.insert(self.st, { vars = {} }) - if tc and node then - tc.begin_symbol_list_scope(node) + if self.collector and node then + self.collector.begin_symbol_list_scope(node) end end - local function end_scope(node?: Node) + function TypeChecker:end_scope(node?: Node) + local st = self.st local scope = st[#st] - local unresolved = scope["@unresolved"] - if unresolved then - local unrt = unresolved.t as UnresolvedType - local next_scope = st[#st - 1] - local upper = next_scope["@unresolved"] - if upper then - local uppert = upper.t as UnresolvedType - for name, nodes in pairs(unrt.labels) do + local next_scope = st[#st - 1] + + if next_scope then + if scope.pending_labels then + next_scope.pending_labels = next_scope.pending_labels or {} + for name, nodes in pairs(scope.pending_labels) do for _, n in ipairs(nodes) do - uppert.labels[name] = uppert.labels[name] or {} - table.insert(uppert.labels[name], n) + next_scope.pending_labels[name] = next_scope.pending_labels[name] or {} + table.insert(next_scope.pending_labels[name], n) end end - for name, types in pairs(unrt.nominals) do + scope.pending_labels = nil + end + if scope.pending_nominals then + next_scope.pending_nominals = next_scope.pending_nominals or {} + for name, types in pairs(scope.pending_nominals) do for _, typ in ipairs(types) do - uppert.nominals[name] = uppert.nominals[name] or {} - table.insert(uppert.nominals[name], typ) + next_scope.pending_nominals[name] = next_scope.pending_nominals[name] or {} + table.insert(next_scope.pending_nominals[name], typ) end end - for name, _ in pairs(unrt.global_types) do - uppert.global_types[name] = true - end - else - next_scope["@unresolved"] = unresolved - unrt.narrows = {} + scope.pending_nominals = nil end end + close_types(scope) - check_for_unused_vars(scope) + self.errs:warn_unused_vars(scope) + table.remove(st) - if tc and node then - tc.end_symbol_list_scope(node) + if self.collector and node then + self.collector.end_symbol_list_scope(node) end end - local end_scope_and_none_type = function(node: Node, _children: {Type}): Type - end_scope(node) + -- This type must never be used for any values + local NONE = a_type({ f = "@none", x = -1, y = -1 }, "none", {}) + + local function end_scope_and_none_type(self: TypeChecker, node: Node, _children: {Type}): Type + self:end_scope(node) return NONE end local type InvalidOrTypeDeclType = InvalidType | TypeDeclType - local resolve_nominal: function(t: NominalType): Type - local resolve_typealias: function(t: TypeAliasType): InvalidOrTypeDeclType do - local function match_typevals(t: NominalType, def: RecordLikeType | FunctionType): Type + local function match_typevals(self: TypeChecker, t: NominalType, def: RecordLikeType | FunctionType): Type if t.typevals and def.typeargs then if #t.typevals ~= #def.typeargs then - error_at(t, "mismatch in number of type arguments") + self.errs:add(t, "mismatch in number of type arguments") return nil end - begin_scope() + self:begin_scope() for i, tt in ipairs(t.typevals) do - add_var(nil, def.typeargs[i].typearg, tt) + self:add_var(nil, def.typeargs[i].typearg, tt) end - local ret = resolve_typevars_at(t, def) - end_scope() + local ret = self:resolve_typevars_at(t, def) + self:end_scope() return ret elseif t.typevals then - error_at(t, "spurious type arguments") + self.errs:add(t, "spurious type arguments") return nil elseif def.typeargs then - error_at(t, "missing type arguments in %s", def) + self.errs:add(t, "missing type arguments in %s", def) return nil else return def end end - local function find_nominal_type_decl(t: NominalType): Type, TypeDeclType + local function find_nominal_type_decl(self: TypeChecker, t: NominalType): Type, TypeDeclType if t.resolved then return t.resolved end - local found = t.found or find_type(t.names) + local found = t.found or self:find_type(t.names) if not found then - error_at(t, "unknown type %s", t) - return INVALID + return self.errs:invalid_at(t, "unknown type %s", t) end if found is TypeAliasType then @@ -7492,8 +7585,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if not found is TypeDeclType then - error_at(t, table.concat(t.names, ".") .. " is not a type") - return INVALID + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " is not a type") end local def = found.def @@ -7508,44 +7600,35 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil, found end - local function resolve_decl_into_nominal(t: NominalType, found: TypeDeclType): Type + local function resolve_decl_into_nominal(self: TypeChecker, t: NominalType, found: TypeDeclType): Type local def = found.def local resolved: Type if def is RecordType or def is FunctionType then - resolved = match_typevals(t, def) + resolved = match_typevals(self, t, def) if not resolved then - error_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") - return INVALID + return self.errs:invalid_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") end else resolved = def end - if not t.filename then - t.filename = resolved.filename - if t.x == nil and t.y == nil then - t.x = resolved.x - t.y = resolved.y - end - end - t.resolved = resolved return resolved end - resolve_nominal = function(t: NominalType): Type - local immediate, found = find_nominal_type_decl(t) + function TypeChecker:resolve_nominal(t: NominalType): Type + local immediate, found = find_nominal_type_decl(self, t) if immediate then return immediate end - return resolve_decl_into_nominal(t, found) + return resolve_decl_into_nominal(self, t, found) end - resolve_typealias = function(typealias: TypeAliasType): InvalidOrTypeDeclType + function TypeChecker:resolve_typealias(typealias: TypeAliasType): InvalidOrTypeDeclType local t = typealias.alias_to - local immediate, found = find_nominal_type_decl(t) + local immediate, found = find_nominal_type_decl(self, t) if immediate then return immediate end @@ -7554,90 +7637,92 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return found end - local resolved = resolve_decl_into_nominal(t, found) + local resolved = resolve_decl_into_nominal(self, t, found) - local typedecl = a_type("typedecl", { def = resolved } as TypeDeclType) + local typedecl = a_type(typealias, "typedecl", { def = resolved } as TypeDeclType) t.resolved = typedecl return typedecl end end - local function are_same_unresolved_global_type(t1: NominalType, t2: NominalType): boolean - if t1.names[1] == t2.names[1] then - local unresolved = get_unresolved() - if unresolved.global_types[t1.names[1]] then - return true + do + local function are_same_unresolved_global_type(self: TypeChecker, t1: NominalType, t2: NominalType): boolean + if t1.names[1] == t2.names[1] then + local global_scope = self.st[1] + if global_scope.pending_global_types[t1.names[1]] then + return true + end end + return false end - return false - end - local function fail_nominals(t1: NominalType, t2: NominalType): boolean, {Error} - local t1name = show_type(t1) - local t2name = show_type(t2) - if t1name == t2name then - local t1r = resolve_nominal(t1) - if t1r.filename then - t1name = t1name .. " (defined in " .. t1r.filename .. ":" .. t1r.y .. ")" - end - local t2r = resolve_nominal(t2) - if t2r.filename then - t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" + local function fail_nominals(self: TypeChecker, t1: NominalType, t2: NominalType): boolean, {Error} + local t1name = show_type(t1) + local t2name = show_type(t2) + if t1name == t2name then + self:resolve_nominal(t1) + if t1.found then + t1name = t1name .. " (defined in " .. t1.found.f .. ":" .. t1.found.y .. ")" + end + self:resolve_nominal(t2) + if t2.found then + t2name = t2name .. " (defined in " .. t2.found.f .. ":" .. t2.found.y .. ")" + end end + return false, { Err(t1name .. " is not a " .. t2name) } end - return false, { Err(t1, t1name .. " is not a " .. t2name) } - end - local function are_same_nominals(t1: NominalType, t2: NominalType): boolean, {Error} - local same_names: boolean - if t1.found and t2.found then - same_names = t1.found.typeid == t2.found.typeid - else - local ft1 = t1.found or find_type(t1.names) - local ft2 = t2.found or find_type(t2.names) - if ft1 and ft2 then - same_names = ft1.typeid == ft2.typeid + function TypeChecker:are_same_nominals(t1: NominalType, t2: NominalType): boolean, {Error} + local same_names: boolean + if t1.found and t2.found then + same_names = t1.found.typeid == t2.found.typeid else - if are_same_unresolved_global_type(t1, t2) then - return true - end + local ft1 = t1.found or self:find_type(t1.names) + local ft2 = t2.found or self:find_type(t2.names) + if ft1 and ft2 then + same_names = ft1.typeid == ft2.typeid + else + if are_same_unresolved_global_type(self, t1, t2) then + return true + end - if not ft1 then - error_at(t1, "unknown type %s", t1) - end - if not ft2 then - error_at(t2, "unknown type %s", t2) + if not ft1 then + self.errs:add(t1, "unknown type %s", t1) + end + if not ft2 then + self.errs:add(t2, "unknown type %s", t2) + end + return false, {} -- errors were already produced end - return false, {} -- errors were already produced end - end - if not same_names then - return fail_nominals(t1, t2) - elseif t1.typevals == nil and t2.typevals == nil then - return true - elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then - local errs = {} - for i = 1, #t1.typevals do - local _, typeval_errs = same_type(t1.typevals[i], t2.typevals[i]) - add_errs_prefixing(t1, typeval_errs, errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ") + if not same_names then + return fail_nominals(self, t1, t2) + elseif t1.typevals == nil and t2.typevals == nil then + return true + elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then + local errs = {} + for i = 1, #t1.typevals do + local _, typeval_errs = self:same_type(t1.typevals[i], t2.typevals[i]) + self.errs:add_prefixing(nil, typeval_errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ", errs) + end + return any_errors(errs) end - return any_errors(errs) + return true end - return true end local is_lua_table_type: function(t: Type): boolean - local function to_structural(t: Type): Type + function TypeChecker:to_structural(t: Type): Type assert(not t is TupleType) if t is NominalType then - return resolve_nominal(t) + return self:resolve_nominal(t) end return t end - local function unite(types: {Type}, flatten_constants?: boolean): Type + local function unite(w: Where, types: {Type}, flatten_constants?: boolean): Type if #types == 1 then return types[1] end @@ -7648,7 +7733,6 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- Make things like number | number resolve to number local types_seen: {(integer|string):boolean} = {} -- but never add nil as a type in the union - types_seen[NIL.typeid] = true types_seen["nil"] = true local i = 1 @@ -7684,14 +7768,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - if types_seen[INVALID.typeid] then - return INVALID + if types_seen["invalid"] then + return a_type(w, "invalid", {}) end if #ts == 1 then return ts[1] else - return a_union(ts) + return a_union(w, ts) end end @@ -7711,21 +7795,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local expand_type: function(where: Where, old: Type, new: Type): Type - local function arraytype_from_tuple(where: Where, tupletype: TupleTableType): ArrayType, {Error} + function TypeChecker:arraytype_from_tuple(w: Where, tupletype: TupleTableType): ArrayType, {Error} -- first just try a basic union - local element_type = unite(tupletype.types, true) + local element_type = unite(w, tupletype.types, true) local valid = (not element_type is UnionType) and true or is_valid_union(element_type) if valid then - return an_array(element_type) + return an_array(w, element_type) end -- failing a basic union, expand the types - local arr_type = an_array(tupletype.types[1]) + local arr_type = an_array(w, tupletype.types[1]) for i = 2, #tupletype.types do - local expanded = expand_type(where, arr_type, an_array(tupletype.types[i])) + local expanded = self:expand_type(w, arr_type, an_array(w, tupletype.types[i])) if not expanded is ArrayType then - return nil, { Err(tupletype, "unable to convert tuple %s to array", tupletype) } + return nil, { Err("unable to convert tuple %s to array", tupletype) } end arr_type = expanded end @@ -7736,33 +7819,33 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t is NominalType and t.names[1] == "@self" end - local function compare_true(_: Type, _: Type): boolean, {Error} + local function compare_true(_: TypeChecker, _: Type, _: Type): boolean, {Error} return true end - local function subtype_nominal(a: Type, b: Type): boolean, {Error} + function TypeChecker:subtype_nominal(a: Type, b: Type): boolean, {Error} if is_self(a) and is_self(b) then return true end - local ra = a is NominalType and resolve_nominal(a) or a - local rb = b is NominalType and resolve_nominal(b) or b - local ok, errs = is_a(ra, rb) + local ra = a is NominalType and self:resolve_nominal(a) or a + local rb = b is NominalType and self:resolve_nominal(b) or b + local ok, errs = self:is_a(ra, rb) if errs and #errs == 1 and errs[1].msg:match("^got ") then return false -- translate to got-expected error with unresolved types end return ok, errs end - local function subtype_array(a: ArrayLikeType, b: ArrayLikeType): boolean, {Error} - if (not a.elements) or (not is_a(a.elements, b.elements)) then + function TypeChecker:subtype_array(a: ArrayLikeType, b: ArrayLikeType): boolean, {Error} + if (not a.elements) or (not self:is_a(a.elements, b.elements)) then return false end if a.consttypes and #a.consttypes > 1 then -- constant array, check elements (useful for array of enums) for _, e in ipairs(a.consttypes) do - if not is_a(e, b.elements) then - return false, { Err(a, "%s is not a member of %s", e, b.elements) } + if not self:is_a(e, b.elements) then + return false, { Err("%s is not a member of %s", e, b.elements) } end end end @@ -7784,16 +7867,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil end - local function subtype_record(a: RecordLikeType, b: RecordLikeType): boolean, {Error} + function TypeChecker:subtype_record(a: RecordLikeType, b: RecordLikeType): boolean, {Error} -- assert(b.typename == "record") if a.elements and b.elements then - if not is_a(a.elements, b.elements) then - return false, { Err(a, "array parts have incompatible element types") } + if not self:is_a(a.elements, b.elements) then + return false, { Err("array parts have incompatible element types") } end end if a.is_userdata ~= b.is_userdata then - return false, { Err(a, a.is_userdata and "userdata is not a record" + return false, { Err(a.is_userdata and "userdata is not a record" or "record is not a userdata") } end @@ -7802,9 +7885,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local ak = a.fields[k] local bk = b.fields[k] if bk then - local ok, fielderrs = is_a(ak, bk) + local ok, fielderrs = self:is_a(ak, bk) if not ok then - add_errs_prefixing(nil, fielderrs, errs, "record field doesn't match: " .. k .. ": ") + self.errs:add_prefixing(nil, fielderrs, "record field doesn't match: " .. k .. ": ", errs) end end end @@ -7818,32 +7901,32 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - local eqtype_record = function(a: RecordType, b: RecordType): boolean, {Error} + function TypeChecker:eqtype_record(a: RecordType, b: RecordType): boolean, {Error} -- checking array interface if (a.elements ~= nil) ~= (b.elements ~= nil) then - return false, { Err(a, "types do not have the same array interface") } + return false, { Err("types do not have the same array interface") } end if a.elements then - local ok, errs = same_type(a.elements, b.elements) + local ok, errs = self:same_type(a.elements, b.elements) if not ok then return ok, errs end end - local ok, errs = subtype_record(a, b) + local ok, errs = self:subtype_record(a, b) if not ok then return ok, errs end - ok, errs = subtype_record(b, a) + ok, errs = self:subtype_record(b, a) if not ok then return ok, errs end return true end - local function compare_map(ak: Type, bk: Type, av: Type, bv: Type, no_hack?: boolean): boolean, {Error} - local ok1, errs_k = same_type(ak, bk) - local ok2, errs_v = same_type(av, bv) + local function compare_map(self: TypeChecker, ak: Type, bk: Type, av: Type, bv: Type, no_hack?: boolean): boolean, {Error} + local ok1, errs_k = self:same_type(ak, bk) + local ok2, errs_v = self:same_type(av, bv) -- FIXME hack for {any:any} if bk.typename == "any" and not no_hack then @@ -7873,25 +7956,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return false, errs_k or errs_v end - local function compare_or_infer_typevar(typevar: string, a: Type, b: Type, cmp: CompareTypes): boolean, {Error} + function TypeChecker:compare_or_infer_typevar(typevar: string, a: Type, b: Type, cmp: CompareTypes): boolean, {Error} -- assert((a == nil and b ~= nil) or (a ~= nil and b == nil)) -- does the typevar currently match to a type? - local vt, _, constraint = find_var_type(typevar) + local vt, _, constraint = self:find_var_type(typevar) if vt then -- If so, compare it to the other type - return cmp(a or vt, b or vt) + return cmp(self, a or vt, b or vt) else -- otherwise, infer it to the other type local other = a or b -- but check interface constraint first if present if constraint then - if not is_a(other, constraint) then - return false, { Err(other, "given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } + if not self:is_a(other, constraint) then + return false, { Err("given type %s does not satisfy %s constraint in type variable " .. display_typevar(typevar), other, constraint) } end - if same_type(other, constraint) then + if self:same_type(other, constraint) then -- do not infer to some type as constraint right away, -- to give a chance to more specific inferences -- in other arguments/returns @@ -7899,22 +7982,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local ok, r, errs = resolve_typevars(other) + local ok, r, errs = typevar_resolver(self, other, resolve_typevar) if not ok then return false, errs end if r is TypeVarType and r.typevar == typevar then return true end - add_var(nil, typevar, r) + self:add_var(nil, typevar, r) return true end end -- ∃ x ∈ xs. t <: x - local function exists_supertype_in(t: Type, xs: AggregateType): Type + function TypeChecker:exists_supertype_in(t: Type, xs: AggregateType): Type for _, x in ipairs(xs.types) do - if is_a(t, x) then + if self:is_a(t, x) then return x end end @@ -7925,143 +8008,139 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["array"] = compare_true, ["map"] = compare_true, ["tupletable"] = compare_true, - ["interface"] = function(_a: Type, b: InterfaceType): boolean, {Error} + ["interface"] = function(_self: TypeChecker, _a: Type, b: InterfaceType): boolean, {Error} return not b.is_userdata end, - ["record"] = function(_a: Type, b: RecordType): boolean, {Error} + ["record"] = function(_self: TypeChecker, _a: Type, b: RecordType): boolean, {Error} return not b.is_userdata end, } - local type TypeRelations = {TypeName:{TypeName:CompareTypes}} - - local eqtype_relations: TypeRelations - eqtype_relations = { + TypeChecker.eqtype_relations = { ["typevar"] = { - ["typevar"] = function(a: TypeVarType, b: TypeVarType): boolean, {Error} + ["typevar"] = function(self: TypeChecker, a: TypeVarType, b: TypeVarType): boolean, {Error} if a.typevar == b.typevar then return true end - return compare_or_infer_typevar(b.typevar, a, nil, same_type) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, - ["*"] = function(a: TypeVarType, b: Type): boolean, {Error} - return compare_or_infer_typevar(a.typevar, nil, b, same_type) + ["*"] = function(self: TypeChecker, a: TypeVarType, b: Type): boolean, {Error} + return self:compare_or_infer_typevar(a.typevar, nil, b, self.same_type) end, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a: TupleTableType, b: TupleTableType): boolean, {Error} + ["tupletable"] = function(self: TypeChecker, a: TupleTableType, b: TupleTableType): boolean, {Error} for i = 1, math.min(#a.types, #b.types) do - if not same_type(a.types[i], b.types[i]) then - return false, { Err(a, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } + if not self:same_type(a.types[i], b.types[i]) then + return false, { Err("in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } end end if #a.types ~= #b.types then - return false, { Err(a, "tuples have different size", a, b) } + return false, { Err("tuples have different size", a, b) } end return true end, }, ["array"] = { - ["array"] = function(a: ArrayType, b: ArrayType): boolean, {Error} - return same_type(a.elements, b.elements) + ["array"] = function(self: TypeChecker, a: ArrayType, b: ArrayType): boolean, {Error} + return self:same_type(a.elements, b.elements) end, }, ["map"] = { - ["map"] = function(a: MapType, b: MapType): boolean, {Error} - return compare_map(a.keys, b.keys, a.values, b.values, true) + ["map"] = function(self: TypeChecker, a: MapType, b: MapType): boolean, {Error} + return compare_map(self, a.keys, b.keys, a.values, b.values, true) end, }, ["union"] = { - ["union"] = function(a: UnionType, b: UnionType): boolean, {Error} - return (has_all_types_of(a.types, b.types) - and has_all_types_of(b.types, a.types)) + ["union"] = function(self: TypeChecker, a: UnionType, b: UnionType): boolean, {Error} + return (self:has_all_types_of(a.types, b.types) + and self:has_all_types_of(b.types, a.types)) end, }, ["nominal"] = { - ["nominal"] = are_same_nominals, + ["nominal"] = TypeChecker.are_same_nominals, }, ["record"] = { - ["record"] = eqtype_record, + ["record"] = TypeChecker.eqtype_record, }, ["interface"] = { - ["interface"] = function(a: InterfaceType, b: InterfaceType): boolean, {Error} + ["interface"] = function(_self:TypeChecker, a: InterfaceType, b: InterfaceType): boolean, {Error} return a.typeid == b.typeid end, }, ["function"] = { - ["function"] = function(a: FunctionType, b: FunctionType): boolean, {Error} + ["function"] = function(self:TypeChecker, a: FunctionType, b: FunctionType): boolean, {Error} local argdelta = a.is_method and 1 or 0 local naargs, nbargs = #a.args.tuple, #b.args.tuple if naargs ~= nbargs then if (not not a.is_method) ~= (not not b.is_method) then - return false, { Err(a, "different number of input arguments: method and non-method are not the same type") } + return false, { Err("different number of input arguments: method and non-method are not the same type") } end - return false, { Err(a, "different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } + return false, { Err("different number of input arguments: got " .. naargs - argdelta .. ", expected " .. nbargs - argdelta) } end local narets, nbrets = #a.rets.tuple, #b.rets.tuple if narets ~= nbrets then - return false, { Err(a, "different number of return values: got " .. narets .. ", expected " .. nbrets) } + return false, { Err("different number of return values: got " .. narets .. ", expected " .. nbrets) } end local errs = {} for i = 1, naargs do - arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) + self:arg_check(a, errs, a.args.tuple[i], b.args.tuple[i], "invariant", "argument", i - argdelta) end for i = 1, narets do - arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) + self:arg_check(a, errs, a.rets.tuple[i], b.rets.tuple[i], "invariant", "return", i) end return any_errors(errs) end, }, ["*"] = { - ["typevar"] = function(a: Type, b: TypeVarType): boolean, {Error} - return compare_or_infer_typevar(b.typevar, a, nil, same_type) + ["typevar"] = function(self: TypeChecker, a: Type, b: TypeVarType): boolean, {Error} + return self:compare_or_infer_typevar(b.typevar, a, nil, self.same_type) end, }, } - local subtype_relations: TypeRelations - subtype_relations = { + TypeChecker.subtype_relations = { ["tuple"] = { - ["tuple"] = function(a: TupleType, b: TupleType): boolean, {Error} -- ∀ a[i] ∈ a, b[i] ∈ b. a[i] <: b[i] + ["tuple"] = function(self: TypeChecker, a: TupleType, b: TupleType): boolean, {Error} -- ∀ a[i] ∈ a, b[i] ∈ b. a[i] <: b[i] local at, bt = a.tuple, b.tuple -- ────────────────────────────────── if #at ~= #bt then -- a tuple <: b tuple return false end for i = 1, #at do - if not is_a(at[i], bt[i]) then + if not self:is_a(at[i], bt[i]) then return false end end return true end, - ["*"] = function(a: Type, b: Type): boolean, {Error} - return is_a(resolve_tuple(a), b) + ["*"] = function(self: TypeChecker, a: Type, b: Type): boolean, {Error} + return self:is_a(resolve_tuple(a), b) end, }, ["typevar"] = { - ["typevar"] = function(a: TypeVarType, b: TypeVarType): boolean, {Error} + ["typevar"] = function(self: TypeChecker, a: TypeVarType, b: TypeVarType): boolean, {Error} if a.typevar == b.typevar then return true end - return compare_or_infer_typevar(b.typevar, a, nil, is_a) + return self:compare_or_infer_typevar(b.typevar, a, nil, self.is_a) end, - ["*"] = function(a: TypeVarType, b: Type): boolean, {Error} - return compare_or_infer_typevar(a.typevar, nil, b, is_a) + ["*"] = function(self: TypeChecker, a: TypeVarType, b: Type): boolean, {Error} + return self:compare_or_infer_typevar(a.typevar, nil, b, self.is_a) end, }, ["nil"] = { ["*"] = compare_true, }, ["union"] = { - ["union"] = function(a: UnionType, b: UnionType): boolean, {Error} -- ∀ t ∈ a. ∃ u ∈ b. t <: u + ["union"] = function(self: TypeChecker, a: UnionType, b: UnionType): boolean, {Error} -- ∀ t ∈ a. ∃ u ∈ b. t <: u local used = {} -- ──────────────────────── for _, t in ipairs(a.types) do -- a union <: b union - begin_scope() - local u = exists_supertype_in(t, b) - end_scope() -- don't preserve failed inferences + self:begin_scope() + local u = self:exists_supertype_in(t, b) + self:end_scope() -- don't preserve failed inferences if not u then return false end @@ -8070,13 +8149,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end for u, t in pairs(used) do - is_a(t, u) -- preserve valid inferences + self:is_a(t, u) -- preserve valid inferences end return true end, - ["*"] = function(a: UnionType, b: Type): boolean, {Error} -- ∀ t ∈ a, t <: b - for _, t in ipairs(a.types) do -- ──────────────── - if not is_a(t, b) then -- a union <: b + ["*"] = function(self: TypeChecker, a: UnionType, b: Type): boolean, {Error} -- ∀ t ∈ a, t <: b + for _, t in ipairs(a.types) do -- ──────────────── + if not self:is_a(t, b) then -- a union <: b return false end end @@ -8084,212 +8163,212 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["poly"] = { - ["*"] = function(a: PolyType, b: Type): boolean, {Error} -- ∃ t ∈ a, t <: b - if exists_supertype_in(b, a) then -- ─────────────── - return true -- a poly <: b + ["*"] = function(self: TypeChecker, a: PolyType, b: Type): boolean, {Error} -- ∃ t ∈ a, t <: b + if self:exists_supertype_in(b, a) then -- ─────────────── + return true -- a poly <: b end - return false, { Err(a, "cannot match against any alternatives of the polymorphic type") } + return false, { Err("cannot match against any alternatives of the polymorphic type") } end, }, ["nominal"] = { - ["nominal"] = function(a: NominalType, b: NominalType): boolean, {Error} - local ok, errs = are_same_nominals(a, b) + ["nominal"] = function(self: TypeChecker, a: NominalType, b: NominalType): boolean, {Error} + local ok, errs = self:are_same_nominals(a, b) if ok then return true end - local rb = resolve_nominal(b) + local rb = self:resolve_nominal(b) if rb is InterfaceType then -- match interface subtyping - return is_a(a, rb) + return self:is_a(a, rb) end - local ra = resolve_nominal(a) + local ra = self:resolve_nominal(a) if ra is UnionType or rb is UnionType then -- match unions structurally - return is_a(ra, rb) + return self:is_a(ra, rb) end -- all other types nominally return ok, errs end, - ["*"] = subtype_nominal, + ["*"] = TypeChecker.subtype_nominal, }, ["enum"] = { ["string"] = compare_true, }, ["string"] = { - ["enum"] = function(a: StringType, b: EnumType): boolean, {Error} + ["enum"] = function(_self: TypeChecker, a: StringType, b: EnumType): boolean, {Error} if not a.literal then - return false, { Err(a, "string is not a %s", b) } + return false, { Err("%s is not a %s", a, b) } end if b.enumset[a.literal] then return true end - return false, { Err(a, "%s is not a member of %s", a, b) } + return false, { Err("%s is not a member of %s", a, b) } end, }, ["integer"] = { ["number"] = compare_true, }, ["interface"] = { - ["interface"] = function(a: InterfaceType, b: InterfaceType): boolean, {Error} - if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then + ["interface"] = function(self: TypeChecker, a: InterfaceType, b: InterfaceType): boolean, {Error} + if find_in_interface_list(a, function(t: Type): boolean return (self:is_a(t, b)) end) then return true end - return same_type(a, b) + return self:same_type(a, b) end, - ["array"] = subtype_array, - ["record"] = subtype_record, - ["tupletable"] = function(a: Type, b: Type): boolean, {Error} - return subtype_relations["record"]["tupletable"](a, b) + ["array"] = TypeChecker.subtype_array, + ["record"] = TypeChecker.subtype_record, + ["tupletable"] = function(self: TypeChecker, a: Type, b: Type): boolean, {Error} + return self.subtype_relations["record"]["tupletable"](self, a, b) end, }, ["emptytable"] = emptytable_relations, ["tupletable"] = { - ["tupletable"] = function(a: TupleTableType, b: TupleTableType): boolean, {Error} + ["tupletable"] = function(self: TypeChecker, a: TupleTableType, b: TupleTableType): boolean, {Error} for i = 1, math.min(#a.types, #b.types) do - if not is_a(a.types[i], b.types[i]) then - return false, { Err(a, "in tuple entry " + if not self:is_a(a.types[i], b.types[i]) then + return false, { Err("in tuple entry " .. tostring(i) .. ": got %s, expected %s", a.types[i], b.types[i]) } end end if #a.types > #b.types then - return false, { Err(a, "tuple %s is too big for tuple %s", a, b) } + return false, { Err("tuple %s is too big for tuple %s", a, b) } end return true end, - ["record"] = function(a: Type, b: RecordType): boolean, {Error} + ["record"] = function(self: TypeChecker, a: Type, b: RecordType): boolean, {Error} if b.elements then - return subtype_relations["tupletable"]["array"](a, b) + return self.subtype_relations["tupletable"]["array"](self, a, b) end end, - ["array"] = function(a: TupleTableType, b: ArrayType): boolean, {Error} + ["array"] = function(self: TypeChecker, a: TupleTableType, b: ArrayType): boolean, {Error} if b.inferred_len and b.inferred_len > #a.types then - return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } + return false, { Err("incompatible length, expected maximum length of " .. tostring(#a.types) .. ", got " .. tostring(b.inferred_len)) } end - local aa, err = arraytype_from_tuple(a.inferred_at, a) + local aa, err = self:arraytype_from_tuple(a.inferred_at or a, a) if not aa then return false, err end - if not is_a(aa, b) then - return false, { Err(a, "got %s (from %s), expected %s", aa, a, b) } + if not self:is_a(aa, b) then + return false, { Err("got %s (from %s), expected %s", aa, a, b) } end return true end, - ["map"] = function(a: TupleTableType, b: MapType): boolean, {Error} - local aa = arraytype_from_tuple(a.inferred_at, a) + ["map"] = function(self: TypeChecker, a: TupleTableType, b: MapType): boolean, {Error} + local aa = self:arraytype_from_tuple(a.inferred_at or a, a) if not aa then - return false, { Err(a, "Unable to convert tuple %s to map", a) } + return false, { Err("Unable to convert tuple %s to map", a) } end - return compare_map(INTEGER, b.keys, aa.elements, b.values) + return compare_map(self, a_type(a, "integer", {}), b.keys, aa.elements, b.values) end, }, ["record"] = { - ["record"] = subtype_record, - ["interface"] = function(a: RecordType, b: InterfaceType): boolean, {Error} - if find_in_interface_list(a, function(t: Type): boolean return (is_a(t, b)) end) then + ["record"] = TypeChecker.subtype_record, + ["interface"] = function(self: TypeChecker, a: RecordType, b: InterfaceType): boolean, {Error} + if find_in_interface_list(a, function(t: Type): boolean return (self:is_a(t, b)) end) then return true end if not a.declname then -- match inferred table (anonymous record) structurally to interface - return subtype_record(a, b) + return self:subtype_record(a, b) end end, - ["array"] = subtype_array, - ["map"] = function(a: RecordType, b: MapType): boolean, {Error} - if not is_a(b.keys, STRING) then - return false, { Err(a, "can't match a record to a map with non-string keys") } + ["array"] = TypeChecker.subtype_array, + ["map"] = function(self: TypeChecker, a: RecordType, b: MapType): boolean, {Error} + if not self:is_a(b.keys, a_type(b, "string", {})) then + return false, { Err("can't match a record to a map with non-string keys") } end for _, k in ipairs(a.field_order) do local bk = b.keys if bk is EnumType and not bk.enumset[k] then - return false, { Err(a, "key is not an enum value: " .. k) } + return false, { Err("key is not an enum value: " .. k) } end - if not is_a(a.fields[k], b.values) then - return false, { Err(a, "record is not a valid map; not all fields have the same type") } + if not self:is_a(a.fields[k], b.values) then + return false, { Err("record is not a valid map; not all fields have the same type") } end end return true end, - ["tupletable"] = function(a: RecordType, b: Type): boolean, {Error} + ["tupletable"] = function(self: TypeChecker, a: RecordType, b: Type): boolean, {Error} if a.elements then - return subtype_relations["array"]["tupletable"](a, b) + return self.subtype_relations["array"]["tupletable"](self, a, b) end end, }, ["array"] = { - ["array"] = subtype_array, - ["record"] = function(a: ArrayType, b: RecordType): boolean, {Error} + ["array"] = TypeChecker.subtype_array, + ["record"] = function(self: TypeChecker, a: ArrayType, b: RecordType): boolean, {Error} if b.elements then - return subtype_array(a, b) + return self:subtype_array(a, b) end end, - ["map"] = function(a: ArrayType, b: MapType): boolean, {Error} - return compare_map(INTEGER, b.keys, a.elements, b.values) + ["map"] = function(self: TypeChecker, a: ArrayType, b: MapType): boolean, {Error} + return compare_map(self, a_type(a, "integer", {}), b.keys, a.elements, b.values) end, - ["tupletable"] = function(a: ArrayType, b: TupleTableType): boolean, {Error} + ["tupletable"] = function(self: TypeChecker, a: ArrayType, b: TupleTableType): boolean, {Error} local alen = a.inferred_len or 0 if alen > #b.types then - return false, { Err(a, "incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } + return false, { Err("incompatible length, expected maximum length of " .. tostring(#b.types) .. ", got " .. tostring(alen)) } end -- for array literals (which is the only case where inferred_len is defined), -- only check the entries that are present for i = 1, (alen > 0) and alen or #b.types do - if not is_a(a.elements, b.types[i]) then - return false, { Err(a, "tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } + if not self:is_a(a.elements, b.types[i]) then + return false, { Err("tuple entry " .. i .. " of type %s does not match type of array elements, which is %s", b.types[i], a.elements) } end end return true end, }, ["map"] = { - ["map"] = function(a: MapType, b: MapType): boolean, {Error} - return compare_map(a.keys, b.keys, a.values, b.values) + ["map"] = function(self: TypeChecker, a: MapType, b: MapType): boolean, {Error} + return compare_map(self, a.keys, b.keys, a.values, b.values) end, - ["array"] = function(a: MapType, b: ArrayType): boolean, {Error} - return compare_map(a.keys, INTEGER, a.values, b.elements) + ["array"] = function(self: TypeChecker, a: MapType, b: ArrayType): boolean, {Error} + return compare_map(self, a.keys, a_type(b, "integer", {}), a.values, b.elements) end, }, ["typedecl"] = { - ["record"] = function(a: TypeDeclType, b: RecordType): boolean, {Error} + ["record"] = function(self: TypeChecker, a: TypeDeclType, b: RecordType): boolean, {Error} local def = a.def if def is RecordLikeType then - return subtype_record(def, b) -- record as prototype + return self:subtype_record(def, b) -- record as prototype end end, }, ["function"] = { - ["function"] = function(a: FunctionType, b: FunctionType): boolean, {Error} + ["function"] = function(self: TypeChecker, a: FunctionType, b: FunctionType): boolean, {Error} local errs = {} local aa, ba = a.args.tuple, b.args.tuple if (not b.args.is_va) and a.min_arity > b.min_arity then - table.insert(errs, Err(a, "incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) + table.insert(errs, Err("incompatible number of arguments: got " .. show_arity(a) .. " %s, expected " .. show_arity(b) .. " %s", a.args, b.args)) else for i = ((a.is_method or b.is_method) and 2 or 1), #aa do - arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) + self:arg_check(nil, errs, aa[i], ba[i] or ba[#ba], "bivariant", "argument", i) end end local ar, br = a.rets.tuple, b.rets.tuple local diff_by_va = #br - #ar == 1 and b.rets.is_va if #ar < #br and not diff_by_va then - table.insert(errs, Err(a, "incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) + table.insert(errs, Err("incompatible number of returns: got " .. #ar .. " %s, expected " .. #br .. " %s", a.rets, b.rets)) else local nrets = #br if diff_by_va then nrets = nrets - 1 end for i = 1, nrets do - arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) + self:arg_check(nil, errs, ar[i], br[i], "bivariant", "return", i) end end @@ -8297,36 +8376,36 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["typearg"] = { - ["typearg"] = function(a: TypeArgType, b: TypeArgType): boolean, {Error} + ["typearg"] = function(_self: TypeChecker, a: TypeArgType, b: TypeArgType): boolean, {Error} return a.typearg == b.typearg end, - ["*"] = function(a: TypeArgType, b: Type): boolean, {Error} + ["*"] = function(self: TypeChecker, a: TypeArgType, b: Type): boolean, {Error} if a.constraint then - return is_a(a.constraint, b) + return self:is_a(a.constraint, b) end end, }, ["*"] = { ["any"] = compare_true, - ["tuple"] = function(a: Type, b: Type): boolean, {Error} - return is_a(a_tuple({a}), b) + ["tuple"] = function(self: TypeChecker, a: Type, b: Type): boolean, {Error} + return self:is_a(a_tuple(a, {a}), b) end, - ["typevar"] = function(a: Type, b: TypeVarType): boolean, {Error} - return compare_or_infer_typevar(b.typevar, a, nil, is_a) + ["typevar"] = function(self: TypeChecker, a: Type, b: TypeVarType): boolean, {Error} + return self:compare_or_infer_typevar(b.typevar, a, nil, self.is_a) end, - ["typearg"] = function(a: Type, b: TypeArgType): boolean, {Error} + ["typearg"] = function(self: TypeChecker, a: Type, b: TypeArgType): boolean, {Error} if b.constraint then - return is_a(a, b.constraint) + return self:is_a(a, b.constraint) end end, - ["union"] = exists_supertype_in as CompareTypes, -- ∃ t ∈ b, a <: t - -- ─────────────── - -- a <: b union - ["nominal"] = subtype_nominal, - ["poly"] = function(a: Type, b: PolyType): boolean, {Error} -- ∀ t ∈ b, a <: t - for _, t in ipairs(b.types) do -- ─────────────── - if not is_a(a, t) then -- a <: b poly - return false, { Err(a, "cannot match against all alternatives of the polymorphic type") } + ["union"] = TypeChecker.exists_supertype_in as CompareTypes, -- ∃ t ∈ b, a <: t + -- ─────────────── + -- a <: b union + ["nominal"] = TypeChecker.subtype_nominal, + ["poly"] = function(self: TypeChecker, a: Type, b: PolyType): boolean, {Error} -- ∀ t ∈ b, a <: t + for _, t in ipairs(b.types) do -- ─────────────── + if not self:is_a(a, t) then -- a <: b poly + return false, { Err("cannot match against all alternatives of the polymorphic type") } end end return true @@ -8335,7 +8414,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string } -- evaluation strategy - local type_priorities: {TypeName:integer} = { + TypeChecker.type_priorities = { -- types that have catch-all rules evaluate first ["tuple"] = 2, ["typevar"] = 3, @@ -8364,19 +8443,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["function"] = 14, } - if lax then - type_priorities["unknown"] = 0 - - subtype_relations["unknown"] = {} - subtype_relations["unknown"]["*"] = compare_true - subtype_relations["*"]["unknown"] = compare_true - -- in .lua files, all values can be used in a boolean context - subtype_relations["boolean"] = {} - subtype_relations["boolean"]["boolean"] = compare_true - subtype_relations["*"]["boolean"] = compare_true - end - - local function compare_types(relations: TypeRelations, t1: Type, t2: Type): boolean, {Error} + local function compare_types(self: TypeChecker, relations: TypeRelations, t1: Type, t2: Type): boolean, {Error} if t1.typeid == t2.typeid then return true end @@ -8384,8 +8451,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local s1 = relations[t1.typename] local fn = s1 and s1[t2.typename] if not fn then - local p1 = type_priorities[t1.typename] or 999 - local p2 = type_priorities[t2.typename] or 999 + local p1 = self.type_priorities[t1.typename] or 999 + local p2 = self.type_priorities[t2.typename] or 999 fn = (p1 < p2 and (s1 and s1["*"]) or (relations["*"][t2.typename])) end @@ -8394,32 +8461,32 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if fn == compare_true then return true end - ok, err = fn(t1, t2) + ok, err = fn(self, t1, t2) else ok = t1.typename == t2.typename end if (not ok) and not err then - return false, { Err(t1, "got %s, expected %s", t1, t2) } + return false, { Err("got %s, expected %s", t1, t2) } end return ok, err end -- subtyping comparison - is_a = function(t1: Type, t2: Type): boolean, {Error} - return compare_types(subtype_relations, t1, t2) + function TypeChecker:is_a(t1: Type, t2: Type): boolean, {Error} + return compare_types(self, self.subtype_relations, t1, t2) end -- invariant type comparison - same_type = function(t1: Type, t2: Type): boolean, {Error} + function TypeChecker:same_type(t1: Type, t2: Type): boolean, {Error} -- except for error messages, behavior is the same as - -- `return (is_a(t1, t2) and is_a(t2, t1))` - return compare_types(eqtype_relations, t1, t2) + -- `return (is_a(t1, t2) and self:is_a(t2, t1))` + return compare_types(self, self.eqtype_relations, t1, t2) end if TL_DEBUG then - local orig_is_a = is_a - is_a = function(t1: Type, t2: Type): boolean, {Error} + local orig_is_a = TypeChecker.is_a + TypeChecker.is_a = function(self: TypeChecker, t1: Type, t2: Type): boolean, {Error} assert(type(t1) == "table") assert(type(t2) == "table") @@ -8429,14 +8496,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - return orig_is_a(t1, t2) + return orig_is_a(self, t1, t2) end end - local function assert_is_a(where: Where, t1: Type, t2: Type, context: string, name?: string): boolean + function TypeChecker:assert_is_a(w: Where, t1: Type, t2: Type, ctx?: string | Node, name?: string): boolean t1 = resolve_tuple(t1) t2 = resolve_tuple(t2) - if lax and (is_unknown(t1) or is_unknown(t2)) then + if self.feat_lax and (is_unknown(t1) or is_unknown(t2)) then return true end @@ -8444,24 +8511,27 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t1.typename == "nil" then return true elseif t2 is UnresolvedEmptyTableValueType then - if is_number_type(t2.emptytable_type.keys) then -- ideally integer only - infer_emptytable(t2.emptytable_type, infer_at(where, an_array(t1))) + local t2keys = t2.emptytable_type.keys + if t2keys is NumericType then -- ideally integer only + self:infer_emptytable(t2.emptytable_type, self:infer_at(w, an_array(w, t1))) else - infer_emptytable(t2.emptytable_type, infer_at(where, a_map(t2.emptytable_type.keys, t1))) + self:infer_emptytable(t2.emptytable_type, self:infer_at(w, a_map(w, t2keys, t1))) end return true elseif t2 is EmptyTableType then if is_lua_table_type(t1) then - infer_emptytable(t2, infer_at(where, t1)) + self:infer_emptytable(t2, self:infer_at(w, t1)) elseif not t1 is EmptyTableType then - error_at(where, context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) + self.errs:add(w, self.errs:get_context(ctx, name) .. "assigning %s to a variable declared with {}", t1) return false end return true end - local ok, match_errs = is_a(t1, t2) - add_errs_prefixing(where, match_errs, errors, context .. ": ".. (name and (name .. ": ") or "")) + local ok, match_errs = self:is_a(t1, t2) + if not ok then + self.errs:add_prefixing(w, match_errs, self.errs:get_context(ctx, name)) + end return ok end @@ -8469,11 +8539,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t is InvalidType then return false end - if same_type(t, NIL) then + if t.typename == "nil" then return true end if t is NominalType then - t = resolve_nominal(t) + t = assert(t.resolved) end if t is RecordLikeType then return t.meta_fields and t.meta_fields["__close"] ~= nil @@ -8487,40 +8557,31 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ["boolean"] = true, ["literal_table"] = true, } - local function expr_is_definitely_not_closable(e: Node): boolean - return definitely_not_closable_exprs[e.kind] - end - - local unknown_dots: {string:boolean} = {} - - local function add_unknown_dot(node: Node, name: string) - if not unknown_dots[name] then - unknown_dots[name] = true - add_unknown(node, name) - end + local function expr_is_definitely_not_closable(e: Node): boolean + return definitely_not_closable_exprs[e.kind] end - local function same_in_all_union_entries(u: UnionType, check: function(Type): (Type, Type)): Type + function TypeChecker:same_in_all_union_entries(u: UnionType, check: function(Type): (Type, Type)): Type local t1, f = check(u.types[1]) if not t1 then return nil end for i = 2, #u.types do local t2 = check(u.types[i]) - if not t2 or not same_type(t1, t2) then + if not t2 or not self:same_type(t1, t2) then return nil end end return f or t1 end - local function same_call_mt_in_all_union_entries(u: UnionType): Type - return same_in_all_union_entries(u, function(t: Type): (Type, Type) - t = to_structural(t) + function TypeChecker:same_call_mt_in_all_union_entries(u: UnionType): Type + return self:same_in_all_union_entries(u, function(t: Type): (Type, Type) + t = self:to_structural(t) if t is RecordLikeType then local call_mt = t.meta_fields and t.meta_fields["__call"] if call_mt is FunctionType then - local args_tuple = a_tuple({}) + local args_tuple = a_tuple(u, {}) for i = 2, #call_mt.args.tuple do table.insert(args_tuple.tuple, call_mt.args.tuple[i]) end @@ -8530,20 +8591,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end) end - local function resolve_for_call(func: Type, args: TupleType, is_method: boolean): Type, boolean + function TypeChecker:resolve_for_call(func: Type, args: TupleType, is_method: boolean): Type, boolean -- resolve unknown in lax mode, produce a general unknown function - if lax and is_unknown(func) then - func = a_fn { args = va_args { UNKNOWN }, rets = va_args { UNKNOWN } } + if self.feat_lax and is_unknown(func) then + local unk = func + func = a_function(func, { min_arity = 0, args = a_vararg(func, { unk }), rets = a_vararg(func, { unk }) }) end -- unwrap if tuple, resolve if nominal - func = to_structural(func) + func = self:to_structural(func) if func.typename ~= "function" and func.typename ~= "poly" then -- resolve if union if func is UnionType then - local r = same_call_mt_in_all_union_entries(func) + local r = self:same_call_mt_in_all_union_entries(func) if r then table.insert(args.tuple, 1, func.types[1]) -- FIXME: is this right? - return to_structural(r), true + return self:to_structural(r), true end end -- resolve if prototype @@ -8557,7 +8619,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if func is RecordLikeType and func.meta_fields and func.meta_fields["__call"] then table.insert(args.tuple, 1, func) func = func.meta_fields["__call"] - func = to_structural(func) + func = self:to_structural(func) is_method = true end end @@ -8565,19 +8627,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local type OnArgId = function(node: Node, i: integer): T - local type OnNode = function(node: Node, children: {T}, ret: T): T + local type OnNode = function(s: S, node: Node, children: {T}, ret: T): T - local function traverse_macroexp(macroexp: Node, on_arg_id: OnArgId, on_node: OnNode): T + local function traverse_macroexp(macroexp: Node, on_arg_id: OnArgId, on_node: OnNode): T local root = macroexp.exp local argnames = {} for i, a in ipairs(macroexp.args) do argnames[a.tk] = i end - local visit_node: Visitor = { + local visit_node: Visitor = { cbs = { ["variable"] = { - after = function(node: Node, _children: {T}): T + after = function(_: nil, node: Node, _children: {T}): T local i = argnames[node.tk] if not i then return nil @@ -8587,10 +8649,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end } }, - after = on_node, + after = on_node as VisitorAfter, } - return recurse_node(root, visit_node, {}) + return recurse_node(nil, root, visit_node, {}) end local function expand_macroexp(orignode: Node, args: {Node}, macroexp: Node) @@ -8598,7 +8660,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return { Node, args[i] } end - local on_node = function(node: Node, children: {{Node, Node}}, ret: {Node, Node}): {Node, Node} + local on_node = function(_: nil, node: Node, children: {{Node, Node}}, ret: {Node, Node}): {Node, Node} local orig = ret and ret[2] or node local out = shallow_copy_table(orig) @@ -8627,12 +8689,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.expanded = p[2] end - local function check_macroexp_arg_use(macroexp: Node) + function TypeChecker:check_macroexp_arg_use(macroexp: Node) local used: {string:boolean} = {} local on_arg_id = function(node: Node, _i: integer): {Node, Node} if used[node.tk] then - error_at(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") + self.errs:add(node, "cannot use argument '" .. node.tk .. "' multiple times in macroexp") else used[node.tk] = true end @@ -8655,18 +8717,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string orignode.known = saveknown end - local type InvalidOrTupleType = InvalidType | TupleType - - local type_check_function_call: function(Node, Type, TupleType, ? integer, ? Node, ? {Node}): InvalidOrTupleType, FunctionType do - local function mark_invalid_typeargs(f: FunctionType) + local function mark_invalid_typeargs(self: TypeChecker, f: FunctionType) if f.typeargs then for _, a in ipairs(f.typeargs) do - if not find_var_type(a.typearg) then + if not self:find_var_type(a.typearg) then if a.constraint then - add_var(nil, a.typearg, a.constraint) + self:add_var(nil, a.typearg, a.constraint) else - add_var(nil, a.typearg, lax and UNKNOWN or a_type("unresolvable_typearg", { + self:add_var(nil, a.typearg, self.feat_lax and an_unknown(a) or a_type(a, "unresolvable_typearg", { typearg = a.typearg } as UnresolvableTypeArgType)) end @@ -8675,7 +8734,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function infer_emptytables(where: Where, wheres: {Where}, xs: TupleType, ys: TupleType, delta: integer) + local function infer_emptytables(self: TypeChecker, w: Where, wheres: {Where}, xs: TupleType, ys: TupleType, delta: integer) local xt, yt = xs.tuple, ys.tuple local n_xs = #xt local n_ys = #yt @@ -8685,19 +8744,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if x is EmptyTableType then local y = yt[i] or (ys.is_va and yt[n_ys]) if y then -- y may not be present when inferring returns - local w = wheres and wheres[i + delta] or where -- for self, a + argdelta is 0 - local inferred_y = infer_at(w, y) - infer_emptytable(x, inferred_y) + local iw = wheres and wheres[i + delta] or w -- for self, a + argdelta is 0 + local inferred_y = self:infer_at(iw, y) + self:infer_emptytable(x, inferred_y) xt[i] = inferred_y end end end end - local check_args_rets: function(where: Where, where_args: {Node}, f: Type, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} + local check_args_rets: function(TypeChecker, w: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} do -- check if a tuple `xs` matches tuple `ys` - local function check_func_type_list(where: Where, wheres: {Where}, xs: TupleType, ys: TupleType, from: integer, delta: integer, v: VarianceMode, mode: ArgCheckMode): boolean, {Error} + local function check_func_type_list(self: TypeChecker, w: Where, wheres: {Where}, xs: TupleType, ys: TupleType, from: integer, delta: integer, v: VarianceMode, mode: ArgCheckMode): boolean, {Error} assert(xs.typename == "tuple", xs.typename) assert(ys.typename == "tuple", ys.typename) @@ -8708,11 +8767,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i = from, math.max(n_xs, n_ys) do local pos = i + delta - local x = xt[i] or (xs.is_va and xt[n_xs]) or NIL + local x = xt[i] or (xs.is_va and xt[n_xs]) or a_type(w, "nil", {}) local y = yt[i] or (ys.is_va and yt[n_ys]) if y then - local w = wheres and wheres[pos] or where - if not arg_check(w, errs, x, y, v, mode, pos) then + local iw = wheres and wheres[pos] or w + if not self:arg_check(iw, errs, x, y, v, mode, pos) then return nil, errs end end @@ -8721,7 +8780,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return true end - check_args_rets = function(where: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} + check_args_rets = function(self: TypeChecker, w: Where, where_args: {Node}, f: FunctionType, args: TupleType, expected_rets: TupleType, argdelta: integer): TupleType, {Error} local rets_ok = true local rets_errs: {Error} local args_ok: boolean @@ -8732,19 +8791,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if argdelta == -1 then from = 2 local errs = {} - if (not is_self(fargs[1])) and not arg_check(where, errs, fargs[1], args.tuple[1], "contravariant", "self") then + if (not is_self(fargs[1])) and not self:arg_check(w, errs, fargs[1], args.tuple[1], "contravariant", "self") then return nil, errs end end if expected_rets then - expected_rets = infer_at(where, expected_rets) - infer_emptytables(where, nil, expected_rets, f.rets, 0) + expected_rets = self:infer_at(w, expected_rets) + infer_emptytables(self, w, nil, expected_rets, f.rets, 0) - rets_ok, rets_errs = check_func_type_list(where, nil, f.rets, expected_rets, 1, 0, "covariant", "return") + rets_ok, rets_errs = check_func_type_list(self, w, nil, f.rets, expected_rets, 1, 0, "covariant", "return") end - args_ok, args_errs = check_func_type_list(where, where_args, f.args, args, from, argdelta, "contravariant", "argument") + args_ok, args_errs = check_func_type_list(self, w, where_args, f.args, args, from, argdelta, "contravariant", "argument") if (not args_ok) or (not rets_ok) then return nil, args_errs or {} end @@ -8752,29 +8811,29 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- if we got to this point without returning, -- we got a valid function match - infer_emptytables(where, where_args, args, f.args, argdelta) + infer_emptytables(self, w, where_args, args, f.args, argdelta) - mark_invalid_typeargs(f) + mark_invalid_typeargs(self, f) - return resolve_typevars_at(where, f.rets) + return self:resolve_typevars_at(w, f.rets) end end - local function push_typeargs(func: FunctionType) + local function push_typeargs(self: TypeChecker, func: FunctionType) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - add_var(nil, fnarg.typearg, a_type("unresolved_typearg", { + self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { constraint = fnarg.constraint, } as UnresolvedTypeArgType)) end end end - local function pop_typeargs(func: FunctionType) + local function pop_typeargs(self: TypeChecker, func: FunctionType) if func.typeargs then for _, fnarg in ipairs(func.typeargs) do - if st[#st][fnarg.typearg] then - st[#st][fnarg.typearg] = nil + if self.st[#self.st].vars[fnarg.typearg] then + self.st[#self.st].vars[fnarg.typearg] = nil end end end @@ -8788,12 +8847,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function fail_call(where: Where, func: FunctionType | PolyType, nargs: integer, errs: {Error}): TupleType + local function fail_call(self: TypeChecker, w: Where, func: FunctionType | PolyType, nargs: integer, errs: {Error}): TupleType if errs then - -- report the errors from the first match - for _, err in ipairs(errs) do - table.insert(errors, err) - end + self.errs:collect(errs) else -- found no arity match to try local expects: {string} = {} @@ -8810,34 +8866,34 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string else table.insert(expects, show_arity(func)) end - error_at(where, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") + self.errs:add(w, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") end local f = resolve_function_type(func, 1) - mark_invalid_typeargs(f) + mark_invalid_typeargs(self, f) - return resolve_typevars_at(where, f.rets) + return self:resolve_typevars_at(w, f.rets) end - local function check_call(where: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, argdelta: integer): InvalidOrTupleType, FunctionType + local function check_call(self: TypeChecker, w: Where, where_args: {Node}, func: Type, args: TupleType, expected_rets: TupleType, is_typedecl_funcall: boolean, argdelta: integer): InvalidOrTupleType, FunctionType assert(type(func) == "table") assert(type(args) == "table") local is_method = (argdelta == -1) if not (func is FunctionType or func is PolyType) then - func, is_method = resolve_for_call(func, args, is_method) + func, is_method = self:resolve_for_call(func, args, is_method) if is_method then argdelta = -1 end if not (func is FunctionType or func is PolyType) then - return invalid_at(where, "not a function: %s", func) + return self.errs:invalid_at(w, "not a function: %s", func) end end if is_method and args.tuple[1] then - add_var(nil, "@self", type_at(where, a_typedecl(args.tuple[1]))) + self:add_var(nil, "@self", a_typedecl(w, args.tuple[1])) end local passes, n = 1, 1 @@ -8854,30 +8910,30 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local f = resolve_function_type(func, i) local fargs = f.args.tuple if f.is_method and not is_method then - if args.tuple[1] and is_a(args.tuple[1], fargs[1]) then + if args.tuple[1] and self:is_a(args.tuple[1], fargs[1]) then -- a non-"@funcall" means a synthesized call, e.g. from a metamethod if not is_typedecl_funcall then - add_warning("hint", where, "invoked method as a regular function: consider using ':' instead of '.'") + self.errs:add_warning("hint", w, "invoked method as a regular function: consider using ':' instead of '.'") end else - return invalid_at(where, "invoked method as a regular function: use ':' instead of '.'") + return self.errs:invalid_at(w, "invoked method as a regular function: use ':' instead of '.'") end end local wanted = #fargs - local min_arity = feat_arity and f.min_arity or 0 + local min_arity = self.feat_arity and f.min_arity or 0 -- simple functions: - if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (lax and given <= wanted))) + if (passes == 1 and ((given <= wanted and given >= min_arity) or (f.args.is_va and given > wanted) or (self.feat_lax and given <= wanted))) -- poly, pass 1: try exact arity matches first or (passes == 3 and ((pass == 1 and given == wanted) -- poly, pass 2: then try adjusting with nils to missing arguments or using '...' - or (pass == 2 and given < wanted and (lax or given >= min_arity)) + or (pass == 2 and given < wanted and (self.feat_lax or given >= min_arity)) -- poly, pass 3: then finally try vararg functions or (pass == 3 and f.args.is_va and given > wanted))) then - push_typeargs(f) + push_typeargs(self, f) - local matched, errs = check_args_rets(where, where_args, f, args, expected_rets, argdelta) + local matched, errs = check_args_rets(self, w, where_args, f, args, expected_rets, argdelta) if matched then -- success! return matched, f @@ -8886,23 +8942,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if expected_rets then -- revert inferred returns - infer_emptytables(where, where_args, f.rets, f.rets, argdelta) + infer_emptytables(self, w, where_args, f.rets, f.rets, argdelta) end if passes == 3 then tried = tried or {} tried[i] = true - pop_typeargs(f) + pop_typeargs(self, f) end end end end end - return fail_call(where, func, given, first_errs) + return fail_call(self, w, func, given, first_errs) end - type_check_function_call = function(node: Node, func: Type, args: TupleType, argdelta?: integer, e1?: Node, e2?: {Node}): InvalidOrTupleType, FunctionType + function TypeChecker:type_check_function_call(node: Node, func: Type, args: TupleType, argdelta?: integer, e1?: Node, e2?: {Node}): InvalidOrTupleType, FunctionType e1 = e1 or node.e1 e2 = e2 or node.e2 @@ -8911,14 +8967,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if expected and expected is TupleType then expected_rets = expected else - expected_rets = a_tuple { node.expected } + expected_rets = a_tuple(node, { node.expected }) end - begin_scope() + self:begin_scope() local is_typedecl_funcall: boolean - if node.kind == "op" and node.op.op == "@funcall" and node.e1 and node.e1.receiver then - local receiver = node.e1.receiver + if node.kind == "op" and node.op.op == "@funcall" and e1 and e1.receiver then + local receiver = e1.receiver if receiver is NominalType then local resolved = receiver.resolved if resolved and resolved is TypeDeclType then @@ -8927,12 +8983,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local ret, f = check_call(node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) - ret = resolve_typevars_at(node, ret) - end_scope() + local ret, f = check_call(self, node, e2, func, args, expected_rets, is_typedecl_funcall, argdelta or 0) + ret = self:resolve_typevars_at(node, ret) + self:end_scope() - if tc and e1 then - tc.store_type(e1.y, e1.x, f) + if self.collector then + self.collector.store_type(e1.y, e1.x, f) end if f and f.macroexp then @@ -8943,9 +8999,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function check_metamethod(node: Node, method_name: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer - if lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then - return UNKNOWN, nil + function TypeChecker:check_metamethod(node: Node, method_name: string, a: Type, b: Type, orig_a: Type, orig_b: Type): Type, integer + if self.feat_lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then + return an_unknown(node), nil end local ameta = a is RecordLikeType and a.meta_fields local bmeta = b and b is RecordLikeType and b.meta_fields @@ -8966,26 +9022,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if metamethod then local e2 = { node.e1 } - local args = a_tuple { orig_a } + local args = a_tuple(node, { orig_a }) if b and method_name ~= "__is" then e2[2] = node.e2 args.tuple[2] = orig_b end - return to_structural(resolve_tuple((type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator + return self:to_structural(resolve_tuple((self:type_check_function_call(node, metamethod, args, -1, node, e2)))), meta_on_operator else return nil, nil end end - local function match_record_key(tbl: Type, rec: Node, key: string): Type, string + function TypeChecker:match_record_key(tbl: Type, rec: Node, key: string): Type, string assert(type(tbl) == "table") assert(type(rec) == "table") assert(type(key) == "string") - tbl = to_structural(tbl) + tbl = self:to_structural(tbl) if tbl is StringType or tbl is EnumType then - tbl = find_var_type("string") -- simulate string metatable + tbl = self:find_var_type("string") -- simulate string metatable end if tbl is TypeDeclType then @@ -8994,13 +9050,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if tbl.is_nested_alias then return nil, "cannot use a nested type alias as a concrete value" else - tbl = resolve_nominal(tbl.alias_to) + tbl = self:resolve_nominal(tbl.alias_to) end end if tbl is UnionType then - local t = same_in_all_union_entries(tbl, function(t: Type): (Type, Type) - return (match_record_key(t, rec, key)) + local t = self:same_in_all_union_entries(tbl, function(t: Type): (Type, Type) + return (self:match_record_key(t, rec, key)) end) if t then @@ -9009,7 +9065,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if (tbl is TypeVarType or tbl is TypeArgType) and tbl.constraint then - local t = match_record_key(tbl.constraint, rec, key) + local t = self:match_record_key(tbl.constraint, rec, key) if t then return t @@ -9023,7 +9079,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return tbl.fields[key] end - local meta_t = check_metamethod(rec, "__index", tbl, STRING, tbl, STRING) + local str = a_type(rec, "string", {}) + local meta_t = self:check_metamethod(rec, "__index", tbl, str, tbl, str) if meta_t then return meta_t end @@ -9034,8 +9091,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return nil, "invalid key '" .. key .. "' in type %s" end elseif tbl is EmptyTableType or is_unknown(tbl) then - if lax then - return INVALID + if self.feat_lax then + return an_unknown(rec) end return nil, "cannot index a value of unknown type" end @@ -9047,30 +9104,35 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function widen_in_scope(scope: Scope, var: string): boolean - assert(scope[var], "no " .. var .. " in scope") - local narrow_mode = scope[var].is_narrowed - if narrow_mode and narrow_mode ~= "declaration" then - if scope[var].narrowed_from then - scope[var].t = scope[var].narrowed_from - scope[var].narrowed_from = nil - scope[var].is_narrowed = nil - else - scope[var] = nil - end + function TypeChecker:widen_in_scope(scope: Scope, var: string): boolean + local v = scope.vars[var] + assert(v, "no " .. var .. " in scope") + local narrow_mode = scope.vars[var].is_narrowed + if (not narrow_mode) or narrow_mode == "declaration" then + return false + end - local unresolved = get_unresolved(scope) - unresolved.narrows[var] = nil - return true + if v.narrowed_from then + v.t = v.narrowed_from + v.narrowed_from = nil + v.is_narrowed = nil + else + scope.vars[var] = nil + end + + if scope.narrows then + scope.narrows[var] = nil end - return false + + return true end - local function widen_back_var(name: string): boolean + function TypeChecker:widen_back_var(name: string): boolean local widened = false - for i = #st, 1, -1 do - if st[i][name] then - if widen_in_scope(st[i], name) then + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.vars[name] then + if self:widen_in_scope(scope, name) then widened = true else break @@ -9081,10 +9143,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end local function assigned_anywhere(name: string, root: Node): boolean - local visit_node: Visitor = { + local visit_node: Visitor = { cbs = { ["assignment"] = { - after = function(node: Node, _children: {boolean}): boolean + after = function(_: nil, node: Node, _children: {boolean}): boolean for _, v in ipairs(node.vars) do if v.kind == "variable" and v.tk == name then return true @@ -9094,7 +9156,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end } }, - after = function(_node: Node, children: {boolean}, ret: boolean): boolean + after = function(_: nil, _node: Node, children: {boolean}, ret: boolean): boolean ret = ret or false for _, c in ipairs(children) do local ca = c as any @@ -9106,124 +9168,88 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end } - local visit_type: Visitor = { + local visit_type: Visitor = { after = function(): boolean return false end } - return recurse_node(root, visit_node, visit_type) + return recurse_node(nil, root, visit_node, visit_type) end - local function widen_all_unions(node?: Node) - for i = #st, 1, -1 do - local scope = st[i] - local unresolved = find_unresolved(i) - if unresolved and unresolved.narrows then - for name, _ in pairs(unresolved.narrows) do + function TypeChecker:widen_all_unions(node?: Node) + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.narrows then + for name, _ in pairs(scope.narrows) do if not node or assigned_anywhere(name, node) then - widen_in_scope(scope, name) + self:widen_in_scope(scope, name) end end end end end - local function add_global(node: Node, var: string, valtype: Type, is_assigning?: boolean): Variable - if lax and is_unknown(valtype) and (var ~= "self" and var ~= "...") then - add_unknown(node, var) + function TypeChecker:add_global(node: Node, varname: string, valtype: Type, is_assigning?: boolean): Variable + if self.feat_lax and is_unknown(valtype) and (varname ~= "self" and varname ~= "...") then + self.errs:add_unknown(node, varname) end local is_const = node.attribute ~= nil - local existing, scope, existing_attr = find_var(var) + local existing, scope, existing_attr = self:find_var(varname) if existing then if scope > 1 then - error_at(node, "cannot define a global when a local with the same name is in scope") + self.errs:add(node, "cannot define a global when a local with the same name is in scope") elseif is_assigning and existing_attr then - error_at(node, "cannot reassign to <" .. existing_attr .. "> global: " .. var) + self.errs:add(node, "cannot reassign to <" .. existing_attr .. "> global: " .. varname) elseif existing_attr and not is_const then - error_at(node, "global was previously declared as <" .. existing_attr .. ">: " .. var) + self.errs:add(node, "global was previously declared as <" .. existing_attr .. ">: " .. varname) elseif (not existing_attr) and is_const then - error_at(node, "global was previously declared as not <" .. node.attribute .. ">: " .. var) - elseif valtype and not same_type(existing.t, valtype) then - error_at(node, "cannot redeclare global with a different type: previous type of " .. var .. " is %s", existing.t) + self.errs:add(node, "global was previously declared as not <" .. node.attribute .. ">: " .. varname) + elseif valtype and not self:same_type(existing.t, valtype) then + self.errs:add(node, "cannot redeclare global with a different type: previous type of " .. varname .. " is %s", existing.t) end return nil end - st[1][var] = { t = valtype, attribute = is_const and "const" or nil } - - return st[1][var] - end + local var = { t = valtype, attribute = is_const and "const" or nil } + self.st[1].vars[varname] = var - local get_rets: function(TupleType): TupleType - if lax then - get_rets = function(rets: TupleType): TupleType - if #rets.tuple == 0 then - return a_vararg { UNKNOWN } - end - return rets - end - else - get_rets = function(rets: TupleType): TupleType - return rets - end + return var end - local function add_internal_function_variables(node: Node, args: TupleType) - add_var(nil, "@is_va", args.is_va and ANY or NIL) - add_var(nil, "@return", node.rets or a_tuple({})) + function TypeChecker:add_internal_function_variables(node: Node, args: TupleType) + self:add_var(nil, "@is_va", a_type(node, args.is_va and "any" or "nil", {})) + self:add_var(nil, "@return", node.rets or a_tuple(node, {})) if node.typeargs then for _, t in ipairs(node.typeargs) do - local v = find_var(t.typearg, "check_only") + local v = self:find_var(t.typearg, "check_only") if not v or not v.used_as_type then - error_at(t, "type argument '%s' is not used in function signature", t) + self.errs:add(t, "type argument '%s' is not used in function signature", t) end end end end - local function add_function_definition_for_recursion(node: Node, fnargs: TupleType) - add_var(nil, node.name.tk, type_at(node, a_function { + function TypeChecker:add_function_definition_for_recursion(node: Node, fnargs: TupleType) + self:add_var(nil, node.name.tk, a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = fnargs, - rets = get_rets(node.rets), + rets = self.get_rets(node.rets), })) end - local function fail_unresolved() - local unresolved = st[#st]["@unresolved"] - if unresolved then - st[#st]["@unresolved"] = nil - local unrt = unresolved.t as UnresolvedType - for name, nodes in pairs(unrt.labels) do - for _, node in ipairs(nodes) do - error_at(node, "no visible label '" .. name .. "' for goto") - end - end - for name, types in pairs(unrt.nominals) do - if not unrt.global_types[name] then - for _, typ in ipairs(types) do - assert(typ.x) - assert(typ.y) - error_at(typ, "unknown type %s", typ) - end - end - end - end - end - - local function end_function_scope(node: Node) - fail_unresolved() - end_scope(node) + function TypeChecker:end_function_scope(node: Node) + self.errs:fail_unresolved_labels(self.st[#self.st]) + self:end_scope(node) end local function flatten_tuple(vals: TupleType): TupleType local vt = vals.tuple local n_vals = #vt - local ret = a_tuple {} + local ret = a_tuple(vals, {}) local rt = ret.tuple if n_vals == 0 then @@ -9251,9 +9277,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - local function get_assignment_values(vals: TupleType, wanted: integer): TupleType + local function get_assignment_values(w: Where, vals: TupleType, wanted: integer): TupleType if vals == nil then - return a_tuple {} + return a_tuple(w, {}) end local ret = flatten_tuple(vals) @@ -9272,14 +9298,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - local function match_all_record_field_names(node: Node, a: RecordLikeType, field_names: {string}, errmsg: string): Type + function TypeChecker:match_all_record_field_names(node: Node, a: RecordLikeType, field_names: {string}, errmsg: string): Type local t: Type for _, k in ipairs(field_names) do local f = a.fields[k] if not t then t = f else - if not same_type(f, t) then + if not self:same_type(f, t) then errmsg = errmsg .. string.format(" (types of fields '%s' and '%s' do not match)", field_names[1], k) t = nil break @@ -9289,26 +9315,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if t then return t else - return invalid_at(node, errmsg) + return self.errs:invalid_at(node, errmsg) end end - local function type_check_index(anode: Node, bnode: Node, a: Type, b: Type): Type + function TypeChecker:type_check_index(anode: Node, bnode: Node, a: Type, b: Type): Type assert(not a is TupleType) assert(not b is TupleType) - local ra = resolve_typedecl(to_structural(a)) - local rb = to_structural(b) + local ra = resolve_typedecl(self:to_structural(a)) + local rb = self:to_structural(b) - if lax and is_unknown(a) then - return UNKNOWN + if self.feat_lax and is_unknown(a) then + return a end local errm: string local erra: Type local errb: Type - if ra is TupleTableType and is_a(rb, INTEGER) then + if ra is TupleTableType and rb is IntegerType then if bnode.constnum then if bnode.constnum >= 1 and bnode.constnum <= #ra.types and bnode.constnum == math.floor(bnode.constnum) then return ra.types[bnode.constnum as integer] @@ -9316,38 +9342,35 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string errm, erra = "index " .. tostring(bnode.constnum) .. " out of range for tuple %s", ra else - local array_type = arraytype_from_tuple(bnode, ra) + local array_type = self:arraytype_from_tuple(bnode, ra) if array_type then return array_type.elements end errm = "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime" end - elseif ra is ArrayLikeType and is_a(rb, INTEGER) then + elseif ra is ArrayLikeType and rb is IntegerType then return ra.elements elseif ra is EmptyTableType then if ra.keys == nil then - ra.keys = infer_at(anode, b) + ra.keys = self:infer_at(bnode, b) end - if is_a(b, ra.keys) then - return type_at(anode, a_type("unresolved_emptytable_value", { + if self:is_a(b, ra.keys) then + return a_type(anode, "unresolved_emptytable_value", { emptytable_type = ra - } as UnresolvedEmptyTableValueType)) + } as UnresolvedEmptyTableValueType) end - errm, erra, errb = "inconsistent index type: got %s, expected %s (type of keys inferred at " - .. ra.keys.inferred_at.filename .. ":" - .. ra.keys.inferred_at.y .. ":" - .. ra.keys.inferred_at.x .. ": )", b, ra.keys + errm, erra, errb = "inconsistent index type: got %s, expected %s" .. inferred_msg(ra.keys, "type of keys "), b, ra.keys elseif ra is MapType then - if is_a(b, ra.keys) then + if self:is_a(b, ra.keys) then return ra.values end errm, erra, errb = "wrong index type: got %s, expected %s", b, ra.keys elseif rb is StringType and rb.literal then - local t, e = match_record_key(a, anode, rb.literal) + local t, e = self:match_record_key(a, anode, rb.literal) if t then return t end @@ -9363,10 +9386,10 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end if not errm then - return match_all_record_field_names(bnode, ra, field_names, + return self:match_all_record_field_names(bnode, ra, field_names, "cannot index, not all enum values map to record fields of the same type") end - elseif is_a(rb, STRING) then + elseif rb is StringType then errm, erra = "cannot index object of type %s with a string, consider using an enum", a else errm, erra, errb = "cannot index object of type %s with %s", a, b @@ -9375,28 +9398,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string errm, erra, errb = "cannot index object of type %s with %s", a, b end - local meta_t = check_metamethod(anode, "__index", ra, b, a, b) + local meta_t = self:check_metamethod(anode, "__index", ra, b, a, b) if meta_t then return meta_t end - return invalid_at(bnode, errm, erra, errb) + return self.errs:invalid_at(bnode, errm, erra, errb) end - expand_type = function(where: Where, old: Type, new: Type): Type + function TypeChecker:expand_type(w: Where, old: Type, new: Type): Type if not old or old.typename == "nil" then return new else - if not is_a(new, old) then + if not self:is_a(new, old) then if old is MapType and new is RecordLikeType then local old_keys = old.keys if old_keys is StringType then for _, ftype in fields_of(new) do - old.values = expand_type(where, old.values, ftype) + old.values = self:expand_type(w, old.values, ftype) end - edit_type(old, "map") -- map changed, refresh typeid + edit_type(w, old, "map") -- map changed, refresh typeid else - error_at(where, "cannot determine table literal type") + self.errs:add(w, "cannot determine table literal type") end elseif old is RecordLikeType and new is RecordLikeType then local values: Type @@ -9404,14 +9427,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not values then values = ftype else - values = expand_type(where, values, ftype) + values = self:expand_type(w, values, ftype) end end for _, ftype in fields_of(new) do if not values then values = ftype else - values = expand_type(where, values, ftype) + values = self:expand_type(w, values, ftype) end end old.fields = nil @@ -9419,25 +9442,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string old.meta_fields = nil old.meta_fields = nil - edit_type(old, "map") + edit_type(w, old, "map") local map = old as MapType - map.keys = STRING + map.keys = a_type(w, "string", {}) map.values = values elseif old is UnionType then - edit_type(old, "union") + edit_type(w, old, "union") table.insert(old.types, drop_constant_value(new)) else - return unite({ old, new }, true) + return unite(w, { old, new }, true) end end end return old end - local function find_record_to_extend(exp: Node): Type, Variable, string + function TypeChecker:find_record_to_extend(exp: Node): Type, Variable, string -- base if exp.kind == "type_identifier" then - local v = find_var(exp.tk) + local v = self:find_var(exp.tk) if not v then return nil, nil, exp.tk end @@ -9454,7 +9477,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t, v, exp.tk -- recurse elseif exp.kind == "op" then -- assert(exp.op.op == ".") - local t, v, rname = find_record_to_extend(exp.e1) + local t, v, rname = self:find_record_to_extend(exp.e1) local fname = exp.e2.tk local dname = rname .. "." .. fname if not t then @@ -9475,30 +9498,29 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function typedecl_to_nominal(where: Where, name: string, t: TypeDeclType, resolved?: Type): Type + local function typedecl_to_nominal(node: Node, name: string, t: TypeDeclType, resolved?: Type): Type local typevals: {Type} local def = t.def if def is HasTypeArgs then typevals = {} for _, a in ipairs(def.typeargs) do - table.insert(typevals, a_type("typevar", { + table.insert(typevals, a_type(a, "typevar", { typevar = a.typearg, constraint = a.constraint, } as TypeVarType)) end end - return type_at(where, a_type("nominal", { - typevals = typevals, - names = { name }, - found = t, - resolved = resolved, - } as NominalType)) + local nom = a_nominal(node, { name }) + nom.typevals = typevals + nom.found = t + nom.resolved = resolved + return nom end - local function get_self_type(exp: Node): Type + function TypeChecker:get_self_type(exp: Node): Type -- base if exp.kind == "type_identifier" then - local t = find_var_type(exp.tk) + local t = self:find_var_type(exp.tk) if not t then return nil end @@ -9510,7 +9532,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- recurse elseif exp.kind == "op" then -- assert(exp.op.op == ".") - local t = get_self_type(exp.e1) + local t = self:get_self_type(exp.e1) if not t then return nil end @@ -9539,10 +9561,9 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- Inference engine for 'is' operator - local facts_and: function(where: Where, f1: Fact, f2: Fact): Fact - local facts_or: function(where: Where, f1: Fact, f2: Fact): Fact - local facts_not: function(where: Where, f1: Fact): Fact - local apply_facts: function(where: Where, known: Fact) + local facts_and: function(w: Where, f1: Fact, f2: Fact): Fact + local facts_or: function(w: Where, f1: Fact, f2: Fact): Fact + local facts_not: function(w: Where, f1: Fact): Fact local FACT_TRUTHY: Fact do local IsFact_mt: metatable = { @@ -9554,6 +9575,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string setmetatable(IsFact, { __call = function(_: IsFact, fact: Fact): IsFact fact.fact = "is" + assert(fact.w) return setmetatable(fact as IsFact, IsFact_mt) end, }) @@ -9567,6 +9589,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string setmetatable(EqFact, { __call = function(_: EqFact, fact: Fact): EqFact fact.fact = "==" + assert(fact.w) return setmetatable(fact as EqFact, EqFact_mt) end, }) @@ -9625,57 +9648,57 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string FACT_TRUTHY = TruthyFact {} - facts_and = function(where: Where, f1: Fact, f2: Fact): Fact - return AndFact { f1 = f1, f2 = f2, where = where } + facts_and = function(w: Where, f1: Fact, f2: Fact): Fact + return AndFact { f1 = f1, f2 = f2, w = w } end - facts_or = function(where: Where, f1: Fact, f2: Fact): Fact + facts_or = function(w: Where, f1: Fact, f2: Fact): Fact if f1 and f2 then - return OrFact { f1 = f1, f2 = f2, where = where } + return OrFact { f1 = f1, f2 = f2, w = w } else return nil end end - facts_not = function(where: Where, f1: Fact): Fact + facts_not = function(w: Where, f1: Fact): Fact if f1 then - return NotFact { f1 = f1, where = where } + return NotFact { f1 = f1, w = w } else return nil end end -- t1 ∪ t2 - local function unite_types(t1: Type, t2: Type): Type, string - return unite({t2, t1}) + local function unite_types(w: Where, t1: Type, t2: Type): Type, string + return unite(w, {t2, t1}) end -- t1 ∩ t2 - local function intersect_types(t1: Type, t2: Type): Type, string + local function intersect_types(self: TypeChecker, w: Where, t1: Type, t2: Type): Type, string if t2 is UnionType then t1, t2 = t2, t1 end if t1 is UnionType then local out = {} for _, t in ipairs(t1.types) do - if is_a(t, t2) then + if self:is_a(t, t2) then table.insert(out, t) end end - return unite(out) + return unite(w, out) else - if is_a(t1, t2) then + if self:is_a(t1, t2) then return t1 - elseif is_a(t2, t1) then + elseif self:is_a(t2, t1) then return t2 else - return NIL -- because of implicit nil in all unions + return a_type(w, "nil", {}) -- because of implicit nil in all unions end end end - local function resolve_if_union(t: Type): Type - local rt = to_structural(t) + function TypeChecker:resolve_if_union(t: Type): Type + local rt = self:to_structural(t) if rt is UnionType then return rt end @@ -9683,23 +9706,23 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end -- t1 - t2 - local function subtract_types(t1: Type, t2: Type): Type + local function subtract_types(self: TypeChecker, w: Where, t1: Type, t2: Type): Type local types: {Type} = {} - t1 = resolve_if_union(t1) + t1 = self:resolve_if_union(t1) -- poly are not first-class, so we don't handle them here if not t1 is UnionType then return t1 end - t2 = resolve_if_union(t2) + t2 = self:resolve_if_union(t2) local t2types = t2 is UnionType and t2.types or { t2 } for _, at in ipairs(t1.types) do local not_present = true for _, bt in ipairs(t2types) do - if same_type(at, bt) then + if self:same_type(at, bt) then not_present = false break end @@ -9710,78 +9733,78 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if #types == 0 then - return NIL -- because of implicit nil in all unions + return a_type(w, "nil", {}) -- because of implicit nil in all unions end - return unite(types) + return unite(w, types) end - local eval_not: function(f: Fact): {string:IsFact|EqFact} - local not_facts: function(fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} - local or_facts: function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} - local and_facts: function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} - local eval_fact: function(f: Fact): {string:IsFact|EqFact} + local eval_not: function(TypeChecker, f: Fact): {string:IsFact|EqFact} + local not_facts: function(TypeChecker, fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local or_facts: function(TypeChecker, fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local and_facts: function(TypeChecker, fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + local eval_fact: function(TypeChecker, f: Fact): {string:IsFact|EqFact} local function invalid_from(f: IsFact): IsFact - return IsFact { fact = "is", var = f.var, typ = INVALID, where = f.where } + return IsFact { fact = "is", var = f.var, typ = a_type(f.w, "invalid", {}), w = f.w } end - not_facts = function(fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} + not_facts = function(self: TypeChecker, fs: {string:IsFact|EqFact}): {string:IsFact|EqFact} local ret: {string:IsFact|EqFact} = {} for var, f in pairs(fs) do - local typ = find_var_type(f.var, "check_only") + local typ = self:find_var_type(f.var, "check_only") if not typ then - ret[var] = EqFact { var = var, typ = INVALID, where = f.where } + ret[var] = EqFact { var = var, typ = an_invalid(f.w), w = f.w, no_infer = f.no_infer } elseif f is EqFact then -- nothing is known from negation of equality; widen back - ret[var] = EqFact { var = var, typ = typ } - elseif typ.typename == "typevar" then + ret[var] = EqFact { var = var, typ = typ, w = f.w, no_infer = true } + elseif typ is TypeVarType then assert(f.fact == "is") - -- nothing is known from negation on typeargs; widen back (no 'where') - ret[var] = EqFact { var = var, typ = typ } - elseif not is_a(f.typ, typ) then + -- nothing is known from negation on typeargs; widen back + ret[var] = EqFact { var = var, typ = typ, w = f.w, no_infer = true } + elseif not self:is_a(f.typ, typ) then assert(f.fact == "is") - add_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) - ret[var] = EqFact { var = var, typ = INVALID, where = f.where } + self.errs:add_warning("branch", f.w, f.var .. " (of type %s) can never be a %s", show_type(typ), show_type(f.typ)) + ret[var] = EqFact { var = var, typ = an_invalid(f.w), w = f.w, no_infer = f.no_infer } else assert(f.fact == "is") - ret[var] = IsFact { var = var, typ = subtract_types(typ, f.typ), where = f.where } + ret[var] = IsFact { var = var, typ = subtract_types(self, f.w, typ, f.typ), w = f.w, no_infer = f.no_infer } end end return ret end - eval_not = function(f: Fact): {string:IsFact|EqFact} + eval_not = function(self: TypeChecker, f: Fact): {string:IsFact|EqFact} if not f then return {} elseif f is IsFact then - return not_facts({[f.var] = f}) + return not_facts(self, {[f.var] = f}) elseif f is NotFact then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f is AndFact and f.f2 and f.f2.fact == "truthy" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f is OrFact and f.f2 and f.f2.fact == "truthy" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f is AndFact then - return or_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) + return or_facts(self, not_facts(self, eval_fact(self, f.f1)), not_facts(self, eval_fact(self, f.f2))) elseif f is OrFact then - return and_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2))) + return and_facts(self, not_facts(self, eval_fact(self, f.f1)), not_facts(self, eval_fact(self, f.f2))) else - return not_facts(eval_fact(f)) + return not_facts(self, eval_fact(self, f)) end end - or_facts = function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + or_facts = function(_self: TypeChecker, fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} local ret: {string:IsFact|EqFact} = {} for var, f in pairs(fs2) do if fs1[var] then - local united = unite_types(f.typ, fs1[var].typ) + local united = unite_types(f.w, f.typ, fs1[var].typ) if fs1[var].fact == "is" and f.fact == "is" then - ret[var] = IsFact { var = var, typ = united, where = f.where } + ret[var] = IsFact { var = var, typ = united, w = f.w } else - ret[var] = EqFact { var = var, typ = united, where = f.where } + ret[var] = EqFact { var = var, typ = united, w = f.w } end end end @@ -9789,7 +9812,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - and_facts = function(fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} + and_facts = function(self: TypeChecker, fs1: {string:IsFact|EqFact}, fs2: {string:IsFact|EqFact}): {string:IsFact|EqFact} local ret: {string:IsFact|EqFact} = {} local has: {FactType:boolean} = {} @@ -9800,18 +9823,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if fs2[var].fact == "is" and f.fact == "is" then ctor = IsFact end - rt = intersect_types(f.typ, fs2[var].typ) + rt = intersect_types(self, f.w, f.typ, fs2[var].typ) else rt = f.typ end - local ff = ctor { var = var, typ = rt, where = f.where } + local ff = ctor { var = var, typ = rt, w = f.w, no_infer = f.no_infer } ret[var] = ff has[ff.fact] = true end for var, f in pairs(fs2) do if not fs1[var] then - ret[var] = EqFact { var = var, typ = f.typ, where = f.where } + ret[var] = EqFact { var = var, typ = f.typ, w = f.w, no_infer = f.no_infer } has["=="] = true end end @@ -9825,21 +9848,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return ret end - eval_fact = function(f: Fact): {string:IsFact|EqFact} + eval_fact = function(self: TypeChecker, f: Fact): {string:IsFact|EqFact} if not f then return {} elseif f is IsFact then - local typ = find_var_type(f.var, "check_only") + local typ = self:find_var_type(f.var, "check_only") if not typ then return { [f.var] = invalid_from(f) } end if typ.typename ~= "typevar" then - if is_a(typ, f.typ) then + if self:is_a(typ, f.typ) then -- drop this warning because of implicit nil in all unions - -- add_warning("branch", f.where, f.var .. " (of type %s) is always a %s", show_type(typ), show_type(f.typ)) + -- self.errs:add_warning("branch", f.w, f.var .. " (of type %s) is always a %s", show_type(typ), show_type(f.typ)) return { [f.var] = f } - elseif not is_a(f.typ, typ) then - error_at(f.where, f.var .. " (of type %s) can never be a %s", typ, f.typ) + elseif not self:is_a(f.typ, typ) then + self.errs:add(f.w, f.var .. " (of type %s) can never be a %s", typ, f.typ) return { [f.var] = invalid_from(f) } end end @@ -9847,63 +9870,60 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string elseif f is EqFact then return { [f.var] = f } elseif f is NotFact then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f is TruthyFact then return {} elseif f is AndFact and f.f2 and f.f2.fact == "truthy" then - return eval_fact(f.f1) + return eval_fact(self, f.f1) elseif f is OrFact and f.f2 and f.f2.fact == "truthy" then - return eval_not(f.f1) + return eval_not(self, f.f1) elseif f is AndFact then - return and_facts(eval_fact(f.f1), eval_fact(f.f2)) + return and_facts(self, eval_fact(self, f.f1), eval_fact(self, f.f2)) elseif f is OrFact then - return or_facts(eval_fact(f.f1), eval_fact(f.f2)) + return or_facts(self, eval_fact(self, f.f1), eval_fact(self, f.f2)) end end - apply_facts = function(where: Where, known: Fact) + function TypeChecker:apply_facts(w: Where, known: Fact) if not known then return end - local facts = eval_fact(known) + local facts = eval_fact(self, known) for v, f in pairs(facts) do if f.typ.typename == "invalid" then - error_at(where, "cannot resolve a type for " .. v .. " here") + self.errs:add(w, "cannot resolve a type for " .. v .. " here") end - local t = infer_at(where, f.typ) - if not f.where then + local t = f.no_infer and f.typ or self:infer_at(w, f.typ) + if f.no_infer then t.inferred_at = nil end - add_var(nil, v, t, "const", "narrow") + self:add_var(nil, v, t, "const", "narrow") end end end - local function dismiss_unresolved(name: string) - for i = #st, 1, -1 do - local unresolved = find_unresolved(i) - if unresolved then - local uses = unresolved.nominals[name] - if uses then - for _, t in ipairs(uses) do - resolve_nominal(t) - end - unresolved.nominals[name] = nil - return + function TypeChecker:dismiss_unresolved(name: string) + for i = #self.st, 1, -1 do + local scope = self.st[i] + local uses = scope.pending_nominals and scope.pending_nominals[name] + if uses then + for _, t in ipairs(uses) do + self:resolve_nominal(t) end + scope.pending_nominals[name] = nil + return end end end - local type_check_funcall: function(node: Node, a: Type, b: Type, argdelta?: integer): InvalidOrTupleType - - local function special_pcall_xpcall(node: Node, _a: Type, b: TupleType, argdelta: integer): Type + local function special_pcall_xpcall(self: TypeChecker, node: Node, _a: Type, b: TupleType, argdelta: integer): Type local base_nargs = (node.e1.tk == "xpcall") and 2 or 1 + local bool = a_type(node, "boolean", {}) if #node.e2 < base_nargs then - error_at(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") - return a_tuple { BOOLEAN } + self.errs:add(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")") + return a_tuple(node, { bool }) end -- The function called by pcall/xpcall is invoked as a regular function, @@ -9915,137 +9935,142 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string ftype.is_method = false end - local fe2: Node = {} + local fe2: Node = node_at(node.e2, {}) if node.e1.tk == "xpcall" then base_nargs = 2 + local arg2 = node.e2[2] local msgh = table.remove(b.tuple, 1) - assert_is_a(node.e2[2], msgh, XPCALL_MSGH_FUNCTION, "in message handler") + local msgh_type = a_function(arg2, { + min_arity = 1, + args = a_tuple(arg2, { a_type(arg2, "any", {}) }), + rets = a_tuple(arg2, {}) + }) + self:assert_is_a(arg2, msgh, msgh_type, "in message handler") end for i = base_nargs + 1, #node.e2 do table.insert(fe2, node.e2[i]) end - local fnode: Node = { - y = node.y, - x = node.x, + local fnode: Node = node_at(node, { kind = "op", op = { op = "@funcall" }, e1 = node.e2[1], e2 = fe2, - } - local rets = type_check_funcall(fnode, ftype, b, argdelta + base_nargs) + }) + local rets = self:type_check_funcall(fnode, ftype, b, argdelta + base_nargs) if rets is InvalidType then return rets end - table.insert(rets.tuple, 1, BOOLEAN) + table.insert(rets.tuple, 1, bool) return rets end - local special_functions: {string : function(Node,Type,TupleType,integer):InvalidOrTupleType } = { - ["pairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType + local special_functions: {string : function(TypeChecker, Node,Type,TupleType,integer):InvalidOrTupleType } = { + ["pairs"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType if not b.tuple[1] then - return invalid_at(node, "pairs requires an argument") + return self.errs:invalid_at(node, "pairs requires an argument") end - local t = to_structural(b.tuple[1]) + local t = self:to_structural(b.tuple[1]) if t is ArrayLikeType then - add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") + self.errs:add_warning("hint", node, "hint: applying pairs on an array: did you intend to apply ipairs?") end if t.typename ~= "map" then - if not (lax and is_unknown(t)) then + if not (self.feat_lax and is_unknown(t)) then if t is RecordLikeType then - match_all_record_field_names(node.e2, t, t.field_order, + self:match_all_record_field_names(node.e2, t, t.field_order, "attempting pairs on a record with attributes of different types") local ct = t.typename == "record" and "{string:any}" or "{any:any}" - add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) + self.errs:add_warning("hint", node.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct) else - error_at(node.e2, "cannot apply pairs on values of type: %s", t) + self.errs:add(node.e2, "cannot apply pairs on values of type: %s", t) end end end - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end, - ["ipairs"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType + ["ipairs"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType if not b.tuple[1] then - return invalid_at(node, "ipairs requires an argument") + return self.errs:invalid_at(node, "ipairs requires an argument") end local orig_t = b.tuple[1] - local t = to_structural(orig_t) + local t = self:to_structural(orig_t) if t is TupleTableType then - local arr_type = arraytype_from_tuple(node.e2, t) + local arr_type = self:arraytype_from_tuple(node.e2, t) if not arr_type then - return invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) + return self.errs:invalid_at(node.e2, "attempting ipairs on tuple that's not a valid array: %s", orig_t) end elseif not t is ArrayLikeType then - if not (lax and (is_unknown(t) or t is EmptyTableType)) then - return invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) + if not (self.feat_lax and (is_unknown(t) or t is EmptyTableType)) then + return self.errs:invalid_at(node.e2, "attempting ipairs on something that's not an array: %s", orig_t) end end - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end, - ["rawget"] = function(node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType + ["rawget"] = function(self: TypeChecker, node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType -- TODO should those offsets be fixed by _argdelta? if #b.tuple == 2 then - return a_tuple({ type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) }) + return a_tuple(node, { self:type_check_index(node.e2[1], node.e2[2], b.tuple[1], b.tuple[2]) }) else - return invalid_at(node, "rawget expects two arguments") + return self.errs:invalid_at(node, "rawget expects two arguments") end end, - ["require"] = function(node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType + ["require"] = function(self: TypeChecker, node: Node, _a: Type, b: TupleType, _argdelta: integer): InvalidOrTupleType if #b.tuple ~= 1 then - return invalid_at(node, "require expects one literal argument") + return self.errs:invalid_at(node, "require expects one literal argument") end if node.e2[1].kind ~= "string" then - return invalid_at(node, "don't know how to resolve a dynamic require") + return self.errs:invalid_at(node, "don't know how to resolve a dynamic require") end local module_name = assert(node.e2[1].conststr) - local t, found = require_module(module_name, lax, env) - if not found then - return invalid_at(node, "module not found: '" .. module_name .. "'") - end + local t, module_filename = require_module(node, module_name, self.feat_lax, self.env) if t.typename == "invalid" then - if lax then - return a_tuple({ UNKNOWN }) + if not module_filename then + return self.errs:invalid_at(node, "module not found: '" .. module_name .. "'") + end + + if self.feat_lax then + return a_tuple(node, { an_unknown(node) }) end - return invalid_at(node, "no type information for required module: '" .. module_name .. "'") + return self.errs:invalid_at(node, "no type information for required module: '" .. module_name .. "'") end - dependencies[module_name] = t.filename - return type_at(node, a_tuple({ t })) + self.dependencies[module_name] = module_filename + return a_tuple(node, { t }) end, ["pcall"] = special_pcall_xpcall, ["xpcall"] = special_pcall_xpcall, - ["assert"] = function(node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType + ["assert"] = function(self: TypeChecker, node: Node, a: Type, b: TupleType, argdelta: integer): InvalidOrTupleType node.known = FACT_TRUTHY - local r = type_check_function_call(node, a, b, argdelta) - apply_facts(node, node.e2[1].known) + local r = self:type_check_function_call(node, a, b, argdelta) + self:apply_facts(node, node.e2[1].known) return r end, } - type_check_funcall = function(node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType + function TypeChecker:type_check_funcall(node: Node, a: Type, b: TupleType, argdelta?: integer): InvalidOrTupleType argdelta = argdelta or 0 if node.e1.kind == "variable" then local special = special_functions[node.e1.tk] if special then - return special(node, a, b, argdelta) + return special(self, node, a, b, argdelta) else - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end elseif node.e1.op and node.e1.op.op == ":" then table.insert(b.tuple, 1, node.e1.receiver) - return (type_check_function_call(node, a, b, -1)) + return (self:type_check_function_call(node, a, b, -1)) else - return (type_check_function_call(node, a, b, argdelta)) + return (self:type_check_function_call(node, a, b, argdelta)) end end @@ -10057,19 +10082,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and node.exps[i].tk == node.vars[i].tk end - local function missing_initializer(node: Node, i: integer, name: string): Type - if lax then - return UNKNOWN + function TypeChecker:missing_initializer(node: Node, i: integer, name: string): (InvalidType | UnknownType) + if self.feat_lax then + return an_unknown(node) else if node.exps then - return invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") + return self.errs:invalid_at(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. name .. "'") else - return invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") + return self.errs:invalid_at(node.vars[i], "variable '" .. name .. "' has no type or initial value") end end end - local function set_expected_types_to_decltuple(node: Node, children: {Type}) + local function set_expected_types_to_decltuple(_: TypeChecker, node: Node, children: {Type}) local decltuple = node.kind == "assignment" and children[1] or node.decltuple assert(decltuple is TupleType) local decls = decltuple.tuple @@ -10081,7 +10106,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string typ = decls[i] if typ then if i == nexps and ndecl > nexps then - typ = type_at(node, a_tuple {}) + typ = a_tuple(node, {}) for a = i, ndecl do table.insert(typ.tuple, decls[a]) end @@ -10097,38 +10122,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return n and n >= 1 and math.floor(n) == n end - local context_name: {NodeKind: string} = { - ["local_declaration"] = "in local declaration", - ["global_declaration"] = "in global declaration", - ["assignment"] = "in assignment", - } - - local function in_context(ctx: Node.ExpectedContext, msg: string): string - if not ctx then - return msg - end - local where = context_name[ctx.kind] - if where then - return where .. ": " .. (ctx.name and ctx.name .. ": " or "") .. msg - else - return msg - end - end - - local type CheckableKey = string | number | boolean - - local function check_redeclared_key(where: Where, ctx: Node.ExpectedContext, seen_keys: {CheckableKey:Where}, key: CheckableKey) - if key ~= nil then - local s = seen_keys[key] - if s then - error_at(where, in_context(ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. filename .. ":" .. s.y .. ":" .. s.x .. ")")) - else - seen_keys[key] = where - end - end - end - - local function infer_table_literal(node: Node, children: {LiteralTableItemType}): Type + local function infer_table_literal(self: TypeChecker, node: Node, children: {LiteralTableItemType}): Type local is_record = false local is_array = false local is_map = false @@ -10153,14 +10147,15 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i, child in ipairs(children) do local ck = child.kname + local cktype = child.ktype local n = node[i].key.constnum local b: boolean = nil - if child.ktype.typename == "boolean" then + if cktype is BooleanType then b = (node[i].key.tk == "true") end local key: CheckableKey = ck or n or b - check_redeclared_key(node[i], nil, seen_keys, key) + self.errs:check_redeclared_key(node[i], nil, seen_keys, key) local uvtype = resolve_tuple(child.vtype) if ck then @@ -10171,7 +10166,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end fields[ck] = uvtype table.insert(field_order, ck) - elseif is_number_type(child.ktype) then + elseif cktype is NumericType then is_array = true if not is_not_tuple then is_tuple = true @@ -10185,25 +10180,25 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if i == #children and cv is TupleType then -- need to expand last item in an array (e.g { 1, 2, 3, f() }) for _, c in ipairs(cv.tuple) do - elements = expand_type(node, elements, c) + elements = self:expand_type(node, elements, c) types[last_array_idx] = resolve_tuple(c) last_array_idx = last_array_idx + 1 end else types[last_array_idx] = uvtype last_array_idx = last_array_idx + 1 - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) end else -- explicit if not is_positive_int(n) then - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) is_not_tuple = true elseif n then types[n as integer] = uvtype if n > largest_array_idx then largest_array_idx = n as integer end - elements = expand_type(node, elements, uvtype) + elements = self:expand_type(node, elements, uvtype) end end @@ -10215,37 +10210,37 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else is_map = true - keys = expand_type(node, keys, drop_constant_value(child.ktype)) - values = expand_type(node, values, uvtype) + keys = self:expand_type(node, keys, drop_constant_value(cktype)) + values = self:expand_type(node, values, uvtype) end end local t: Type if is_array and is_map then - error_at(node, "cannot determine type of table literal") - t = a_map( - expand_type(node, keys, INTEGER), - expand_type(node, values, elements) + self.errs:add(node, "cannot determine type of table literal") + t = a_map(node, + self:expand_type(node, keys, a_type(node, "integer", {})), + self:expand_type(node, values, elements) ) elseif is_record and is_array then - t = a_type("record", { + t = a_type(node, "record", { fields = fields, field_order = field_order, elements = elements, interface_list = { - type_at(node, an_array(elements)) + an_array(node, elements) } } as RecordType) - -- TODO adopt logic from is_array below when we accept tupletable as an interface + -- TODO adopt logic from self:is_array below when we accept tupletable as an interface elseif is_record and is_map then if keys is StringType then for _, fname in ipairs(field_order) do - values = expand_type(node, values, fields[fname]) + values = self:expand_type(node, values, fields[fname]) end - t = a_map(keys, values) + t = a_map(node, keys, values) else - error_at(node, "cannot determine type of table literal") + self.errs:add(node, "cannot determine type of table literal") end elseif is_array then local pure_array = true @@ -10253,7 +10248,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local last_t: Type for _, current_t in pairs(types as {integer:Type}) do if last_t then - if not same_type(last_t, current_t) then + if not self:same_type(last_t, current_t) then pure_array = false break end @@ -10262,69 +10257,70 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end if pure_array then - t = an_array(elements) + t = an_array(node, elements) t.consttypes = types t.inferred_len = largest_array_idx - 1 else - t = a_type("tupletable", {}) as TupleTableType + t = a_type(node, "tupletable", { inferred_at = node }) as TupleTableType t.types = types end elseif is_record then - t = a_type("record", { + t = a_type(node, "record", { fields = fields, field_order = field_order, } as RecordType) elseif is_map then - t = a_map(keys, values) + t = a_map(node, keys, values) elseif is_tuple then - t = a_type("tupletable", {}) as TupleTableType + t = a_type(node, "tupletable", { inferred_at = node }) as TupleTableType t.types = types if not types or #types == 0 then - error_at(node, "cannot determine type of tuple elements") + self.errs:add(node, "cannot determine type of tuple elements") end end if not t then - t = a_type("emptytable", {}) + t = a_type(node, "emptytable", {}) end return type_at(node, t) end - local function infer_negation_of_if_blocks(where: Where, ifnode: Node, n: integer) - local f = facts_not(where, ifnode.if_blocks[1].exp.known) + function TypeChecker:infer_negation_of_if_blocks(w: Where, ifnode: Node, n: integer) + local f = facts_not(w, ifnode.if_blocks[1].exp.known) for e = 2, n do local b = ifnode.if_blocks[e] if b.exp then - f = facts_and(where, f, facts_not(where, b.exp.known)) + f = facts_and(w, f, facts_not(w, b.exp.known)) end end - apply_facts(where, f) + self:apply_facts(w, f) end - local function determine_declaration_type(var: Node, node: Node, infertypes: TupleType, i: integer): boolean, Type, boolean + function TypeChecker:determine_declaration_type(var: Node, node: Node, infertypes: TupleType, i: integer): boolean, Type, boolean local ok = true local name = var.tk local infertype = infertypes and infertypes.tuple[i] - if lax and infertype and infertype.typename == "nil" then + if self.feat_lax and infertype and infertype.typename == "nil" then infertype = nil end local decltype = node.decltuple and node.decltuple.tuple[i] if decltype then - if to_structural(decltype) == INVALID then - decltype = INVALID + local rdecltype = self:to_structural(decltype) + if rdecltype is InvalidType then + decltype = rdecltype end if infertype then - ok = assert_is_a(node.vars[i], infertype, decltype, context_name[node.kind], name) + local w = node.exps and node.exps[i] or node.vars[i] + ok = self:assert_is_a(w, infertype, decltype, context_name[node.kind], name) end else if infertype then if infertype is UnresolvableTypeArgType then - error_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") ok = false - infertype = INVALID + infertype = self.errs:invalid_at(node.vars[i], "cannot infer declaration type; an explicit type annotation is necessary") elseif infertype is FunctionType and infertype.is_method then -- If we assign a method to a variable, e.g: -- `local myfunc = myobj.dothing`, @@ -10336,17 +10332,17 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if var.attribute == "total" then - local rd = decltype and to_structural(decltype) + local rd = decltype and self:to_structural(decltype) if rd and (rd.typename ~= "map" and rd.typename ~= "record") then - error_at(var, "attribute only applies to maps and records") + self.errs:add(var, "attribute only applies to maps and records") ok = false elseif not infertype then - error_at(var, "variable declared does not declare an initialization value") + self.errs:add(var, "variable declared does not declare an initialization value") ok = false else local valnode = node.exps[i] if not valnode or valnode.kind ~= "literal_table" then - error_at(var, "attribute only applies to literal tables") + self.errs:add(var, "attribute only applies to literal tables") ok = false else if not valnode.is_total then @@ -10354,12 +10350,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if valnode.missing then missing = " (missing: " .. table.concat(valnode.missing, ", ") .. ")" end - local ri = to_structural(infertype) + local ri = self:to_structural(infertype) if ri is MapType then - error_at(var, "map variable declared does not declare values for all possible keys" .. missing) + self.errs:add(var, "map variable declared does not declare values for all possible keys" .. missing) ok = false elseif ri is RecordType then - error_at(var, "record variable declared does not declare values for all fields" .. missing) + self.errs:add(var, "record variable declared does not declare values for all fields" .. missing) ok = false end end @@ -10369,34 +10365,36 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local t = decltype or infertype if t == nil then - t = missing_initializer(node, i, name) + t = self:missing_initializer(node, i, name) elseif t is EmptyTableType then t.declared_at = node t.assigned_to = name elseif t is ArrayLikeType then t.inferred_len = nil + elseif t is NominalType then + self:resolve_nominal(t) end return ok, t, infertype ~= nil end - local function get_typedecl(value: Node): TypeDeclType, Variable + function TypeChecker:get_typedecl(value: Node): TypeDeclType, Variable if value.kind == "op" and value.op.op == "@funcall" and value.e1.kind == "variable" and value.e1.tk == "require" then - local t = special_functions["require"](value, find_var_type("require"), a_tuple { STRING }, 0) + local t = special_functions["require"](self, value, self:find_var_type("require"), a_tuple(value.e2, { a_type(value.e2[1], "string", {}) }), 0) local ty = t is TupleType and t.tuple[1] or t - ty = (ty is TypeAliasType) and resolve_typealias(ty) or ty - local td = (ty is TypeDeclType) and ty or a_type("typedecl", { def = ty } as TypeDeclType) + ty = (ty is TypeAliasType) and self:resolve_typealias(ty) or ty + local td = (ty is TypeDeclType) and ty or a_type(value, "typedecl", { def = ty } as TypeDeclType) return td else local newtype = value.newtype if newtype is TypeAliasType then - local aliasing = find_var(newtype.alias_to.names[1], "use_type") - return resolve_typealias(newtype), aliasing - else + local aliasing = self:find_var(newtype.alias_to.names[1], "use_type") + return self:resolve_typealias(newtype), aliasing + elseif newtype is TypeDeclType then return newtype, nil end end @@ -10427,15 +10425,14 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return is_total, missing end - local function total_map_check(t: MapType, seen_keys: {CheckableKey:Where}): boolean, {string} - local k = to_structural(t.keys) + local function total_map_check(keys: Type, seen_keys: {CheckableKey:Where}): boolean, {string} local is_total = true local missing: {string} - if k is EnumType then - for _, key in ipairs(sorted_keys(k.enumset)) do + if keys is EnumType then + for _, key in ipairs(sorted_keys(keys.enumset)) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end - elseif k.typename == "boolean" then + elseif keys.typename == "boolean" then for _, key in ipairs({ true, false }) do is_total, missing = total_check_key(key, seen_keys, is_total, missing) end @@ -10449,35 +10446,38 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string "missing" end - local function check_assignment(where: Where, vartype: Type, valtype: Type, varname: string, attr: Attribute): Type, Type, MissingError + function TypeChecker:check_assignment(varnode: Node, vartype: Type, valtype: Type): Type, Type, MissingError + local varname = varnode.tk + local attr = varnode.attribute + if varname then - if widen_back_var(varname) then - vartype, attr = find_var_type(varname) + if self:widen_back_var(varname) then + vartype, attr = self:find_var_type(varname) if not vartype then - error_at(where, "unknown variable") + self.errs:add(varnode, "unknown variable") return nil end end end if attr == "close" or attr == "const" or attr == "total" then - error_at(where, "cannot assign to <" .. attr .. "> variable") + self.errs:add(varnode, "cannot assign to <" .. attr .. "> variable") return nil end - local var = to_structural(vartype) + local var = self:to_structural(vartype) if var is TypeDeclType or var is TypeAliasType then - error_at(where, "cannot reassign a type") + self.errs:add(varnode, "cannot reassign a type") return nil end if not valtype then - error_at(where, "variable is not being assigned a value") + self.errs:add(varnode, "variable is not being assigned a value") return nil, nil, "missing" end - assert_is_a(where, valtype, vartype, "in assignment") + self:assert_is_a(varnode, valtype, vartype, "in assignment") - local val = to_structural(valtype) + local val = self:to_structural(valtype) return var, val end @@ -10489,185 +10489,186 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return resolve_tuple(t) end - local visit_node: Visitor = {} + local visit_node: Visitor = {} visit_node.cbs = { ["statements"] = { - before = function(node: Node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:begin_scope(node) end, - after = function(node: Node, _children: {Type}): Type + after = function(self: TypeChecker, node: Node, _children: {Type}): Type -- if at the top level - if #st == 2 then - fail_unresolved() + if #self.st == 2 then + self.errs:fail_unresolved_labels(self.st[2]) + self.errs:fail_unresolved_nominals(self.st[2], self.st[1]) end if not node.is_repeat then - end_scope(node) + self:end_scope(node) end - -- TODO extract node type from `return` + return NONE end }, ["local_type"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) local name = node.var.tk - local resolved, aliasing = get_typedecl(node.value) - local var = add_var(node.var, name, resolved, node.var.attribute) + local resolved, aliasing = self:get_typedecl(node.value) + local var = self:add_var(node.var, name, resolved, node.var.attribute) if aliasing then var.aliasing = aliasing end end, - after = function(node: Node, _children: {Type}): Type - dismiss_unresolved(node.var.tk) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + self:dismiss_unresolved(node.var.tk) return NONE end, }, ["global_type"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) + local global_scope = self.st[1] local name = node.var.tk - local unresolved = get_unresolved() if node.value then - local resolved, aliasing = get_typedecl(node.value) - local added = add_global(node.var, name, resolved) + local resolved, aliasing = self:get_typedecl(node.value) + local added = self:add_global(node.var, name, resolved) node.value.newtype = resolved if aliasing then added.aliasing = aliasing end - if added and unresolved.global_types[name] then - unresolved.global_types[name] = nil + if global_scope.pending_global_types[name] then + global_scope.pending_global_types[name] = nil end else - if not st[1][name] then - unresolved.global_types[name] = true + if not self.st[1].vars[name] then + global_scope.pending_global_types[name] = true end end end, - after = function(node: Node, _children: {Type}): Type - dismiss_unresolved(node.var.tk) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + self:dismiss_unresolved(node.var.tk) return NONE end, }, ["local_declaration"] = { - before = function(node: Node) - if tc then + before = function(self: TypeChecker, node: Node) + if self.collector then for _, var in ipairs(node.vars) do - tc.reserve_symbol_list_slot(var) + self.collector.reserve_symbol_list_slot(var) end end end, before_exp = set_expected_types_to_decltuple, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local valtuple = children[3] as TupleType -- may be nil local encountered_close = false - local infertypes = get_assignment_values(valtuple, #node.vars) + local infertypes = get_assignment_values(node, valtuple, #node.vars) for i, var in ipairs(node.vars) do if var.attribute == "close" then - if opts.gen_target == "5.4" then + if self.gen_target == "5.4" then if encountered_close then - error_at(var, "only one per declaration is allowed") + self.errs:add(var, "only one per declaration is allowed") else encountered_close = true end else - error_at(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(opts.gen_target) .. ")") + self.errs:add(var, " attribute is only valid for Lua 5.4 (current target is " .. tostring(self.gen_target) .. ")") end end - local ok, t = determine_declaration_type(var, node, infertypes, i) + local ok, t = self:determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then if not type_is_closable(t) then - error_at(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) + self.errs:add(var, "to-be-closed variable " .. var.tk .. " has a non-closable type %s", t) elseif node.exps and node.exps[i] and expr_is_definitely_not_closable(node.exps[i]) then - error_at(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") + self.errs:add(var, "to-be-closed variable " .. var.tk .. " assigned a non-closable value") end end assert(var) - add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") + self:add_var(var, var.tk, t, var.attribute, is_localizing_a_variable(node, i) and "declaration") local infertype = infertypes.tuple[i] if ok and infertype then - local where = node.exps[i] or node.exps + local w = node.exps[i] or node.exps - local rt = to_structural(t) + local rt = self:to_structural(t) if (not rt is EnumType) and ((not t is NominalType) or (rt is UnionType)) - and not same_type(t, infertype) + and not self:same_type(t, infertype) then - t = infer_at(where, infertype) - add_var(where, var.tk, t, "const", "narrowed_declaration") + t = self:infer_at(w, infertype) + self:add_var(w, var.tk, t, "const", "narrowed_declaration") end end - if tc then - tc.store_type(var.y, var.x, t) + if self.collector then + self.collector.store_type(var.y, var.x, t) end - dismiss_unresolved(var.tk) + self:dismiss_unresolved(var.tk) end return NONE end, }, ["global_declaration"] = { before_exp = set_expected_types_to_decltuple, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local valtuple = children[3] as TupleType -- may be nil - local infertypes = get_assignment_values(valtuple, #node.vars) + local infertypes = get_assignment_values(node, valtuple, #node.vars) for i, var in ipairs(node.vars) do - local _, t, is_inferred = determine_declaration_type(var, node, infertypes, i) + local _, t, is_inferred = self:determine_declaration_type(var, node, infertypes, i) if var.attribute == "close" then - error_at(var, "globals may not be ") + self.errs:add(var, "globals may not be ") end - add_global(var, var.tk, t, is_inferred) + self:add_global(var, var.tk, t, is_inferred) - dismiss_unresolved(var.tk) + self:dismiss_unresolved(var.tk) end return NONE end, }, ["assignment"] = { before_exp = set_expected_types_to_decltuple, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local vartuple = children[1] assert(vartuple is TupleType) local vartypes = vartuple.tuple local valtuple = children[3] assert(valtuple is TupleType) - local valtypes = get_assignment_values(valtuple, #vartypes) + local valtypes = get_assignment_values(node, valtuple, #vartypes) for i, vartype in ipairs(vartypes) do local varnode = node.vars[i] local varname = varnode.tk local valtype = valtypes.tuple[i] - local rvar, rval, err = check_assignment(varnode, vartype, valtype, varname, varnode.attribute) + local rvar, rval, err = self:check_assignment(varnode, vartype, valtype) if err == "missing" then if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then local msg = #valtuple.tuple == 1 and "only 1 value is returned by the function" or ("only " .. #valtuple.tuple .. " values are returned by the function") - add_warning("hint", varnode, msg) + self.errs:add_warning("hint", varnode, msg) end end if rval and rvar then -- assigning a function if rval is FunctionType then - widen_all_unions() + self:widen_all_unions() end if varname and (rvar is UnionType or rvar is InterfaceType) then -- narrow unions and interfaces - add_var(varnode, varname, rval, nil, "narrow") + self:add_var(varnode, varname, rval, nil, "narrow") end - if tc then - tc.store_type(varnode.y, varnode.x, valtype) + if self.collector then + self.collector.store_type(varnode.y, varnode.x, valtype) end end end @@ -10676,7 +10677,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["if"] = { - after = function(node: Node, _children: {Type}): Type + after = function(self: TypeChecker, node: Node, _children: {Type}): Type local all_return = true for _, b in ipairs(node.if_blocks) do if not b.block_returns then @@ -10686,26 +10687,26 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if all_return then node.block_returns = true - infer_negation_of_if_blocks(node, node, #node.if_blocks) + self:infer_negation_of_if_blocks(node, node, #node.if_blocks) end return NONE end, }, ["if_block"] = { - before = function(node: Node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:begin_scope(node) if node.if_block_n > 1 then - infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) + self:infer_negation_of_if_blocks(node, node.if_parent, node.if_block_n - 1) end end, - before_statements = function(node: Node) + before_statements = function(self: TypeChecker, node: Node) if node.exp then - apply_facts(node.exp, node.exp.known) + self:apply_facts(node.exp, node.exp.known) end end, - after = function(node: Node, _children: {Type}): Type - end_scope(node) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + self:end_scope(node) if #node.body > 0 and node.body[#node.body].block_returns then node.block_returns = true @@ -10715,76 +10716,96 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end }, ["while"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet - widen_all_unions(node) + self:widen_all_unions(node) end, - before_statements = function(node: Node) - begin_scope(node) - apply_facts(node.exp, node.exp.known) + before_statements = function(self: TypeChecker, node: Node) + self:begin_scope(node) + self:apply_facts(node.exp, node.exp.known) end, after = end_scope_and_none_type, }, ["label"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet - widen_all_unions() - local label_id = "::" .. node.label .. "::" - if st[#st][label_id] then - error_at(node, "label '" .. node.label .. "' already defined at " .. filename ) - end - local unresolved = find_unresolved() - local var = add_var(node, label_id, type_at(node, a_type("none", {}))) - if unresolved then - if unresolved.labels[node.label] then - var.used = true + self:widen_all_unions() + local label_id = node.label + do + local scope = self.st[#self.st] + scope.labels = scope.labels or {} + if scope.labels[label_id] then + self.errs:add(node, "label '" .. node.label .. "' already defined") + else + scope.labels[label_id] = node end - unresolved.labels[node.label] = nil end + + --for i = #self.st, 1, -1 do + local scope = self.st[#self.st] + if scope.pending_labels and scope.pending_labels[label_id] then + node.used_label = true + scope.pending_labels[label_id] = nil + --break + end + --end end, after = function(): Type return NONE end }, ["goto"] = { - after = function(node: Node, _children: {Type}): Type - if not find_var_type("::" .. node.label .. "::") then - local unresolved = get_unresolved(st[#st]) - unresolved.labels[node.label] = unresolved.labels[node.label] or {} - table.insert(unresolved.labels[node.label], node) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + local label_id = node.label + local found_label: Node + for i = #self.st, 1, -1 do + local scope = self.st[i] + if scope.labels and scope.labels[label_id] then + found_label = scope.labels[label_id] + break + end + end + + if found_label then + found_label.used_label = true + else + local scope = self.st[#self.st] + scope.pending_labels = scope.pending_labels or {} + scope.pending_labels[label_id] = scope.pending_labels[label_id] or {} + table.insert(scope.pending_labels[label_id], node) end return NONE end, }, ["repeat"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) -- widen all narrowed variables because we don't calculate a fixpoint yet - widen_all_unions(node) + self:widen_all_unions(node) end, -- only end scope after checking `until`, `statements` in repeat body has is_repeat == true after = end_scope_and_none_type, }, ["forin"] = { - before = function(node: Node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:begin_scope(node) end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local exptuple = children[2] assert(exptuple is TupleType) local exptypes = exptuple.tuple - widen_all_unions(node) + self:widen_all_unions(node) local exp1 = node.exps[1] - local args = a_tuple { + local args = a_tuple(node.exps, { node.exps[2] and exptypes[2], node.exps[3] and exptypes[3] - } - local exp1type = resolve_for_call(exptypes[1], args, false) + }) + local exp1type = self:resolve_for_call(exptypes[1], args, false) if exp1type is PolyType then local _: Type - _, exp1type = type_check_function_call(exp1, exp1type, args, 0, exp1, {node.exps[2], node.exps[3]}) + _, exp1type = self:type_check_function_call(exp1, exp1type, args, 0, exp1, {node.exps[2], node.exps[3]}) end if exp1type is FunctionType then @@ -10797,69 +10818,69 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if rets.is_va then r = last else - r = lax and UNKNOWN or INVALID + r = self.feat_lax and an_unknown(v) or an_invalid(v) end end - add_var(v, v.tk, r) + self:add_var(v, v.tk, r) - if tc then - tc.store_type(v.y, v.x, r) + if self.collector then + self.collector.store_type(v.y, v.x, r) end last = r end local nrets = #rets.tuple - if (not lax) and (not rets.is_va and #node.vars > nrets) then + if (not self.feat_lax) and (not rets.is_va and #node.vars > nrets) then local at = node.vars[nrets + 1] local n_values = nrets == 1 and "1 value" or tostring(nrets) .. " values" - error_at(at, "too many variables for this iterator; it produces " .. n_values) + self.errs:add(at, "too many variables for this iterator; it produces " .. n_values) end else - if not (lax and is_unknown(exp1type)) then - error_at(exp1, "expression in for loop does not return an iterator") + if not (self.feat_lax and is_unknown(exp1type)) then + self.errs:add(exp1, "expression in for loop does not return an iterator") end end end, after = end_scope_and_none_type, }, ["fornum"] = { - before_statements = function(node: Node, children: {Type}) - widen_all_unions(node) - begin_scope(node) - local from_t = to_structural(resolve_tuple(children[2])) - local to_t = to_structural(resolve_tuple(children[3])) - local step_t = children[4] and to_structural(children[4]) - local t = (from_t.typename == "integer" and - to_t.typename == "integer" and - (not step_t or step_t.typename == "integer")) - and INTEGER - or NUMBER - add_var(node.var, node.var.tk, t) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) + self:widen_all_unions(node) + self:begin_scope(node) + local from_t = self:to_structural(resolve_tuple(children[2])) + local to_t = self:to_structural(resolve_tuple(children[3])) + local step_t = children[4] and self:to_structural(children[4]) + local typename: TypeName = (from_t.typename == "integer" and + to_t.typename == "integer" and + (not step_t or step_t.typename == "integer")) + and "integer" + or "number" + self:add_var(node.var, node.var.tk, a_type(node.var, typename, {})) end, after = end_scope_and_none_type, }, ["return"] = { - before = function(node: Node) - local rets = find_var_type("@return") + before = function(self: TypeChecker, node: Node) + local rets = self:find_var_type("@return") if rets and rets is TupleType then for i, exp in ipairs(node.exps) do exp.expected = rets.tuple[i] end end end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local got = children[1] assert(got is TupleType) local got_t = got.tuple local n_got = #got_t node.block_returns = true - local expected = find_var_type("@return") as TupleType + local expected = self:find_var_type("@return") as TupleType if not expected then -- if at the toplevel - expected = infer_at(node, got) - module_type = drop_constant_value(to_structural(resolve_tuple(expected))) - st[2]["@return"] = { t = expected } + expected = self:infer_at(node, got) + self.module_type = drop_constant_value(self:to_structural(resolve_tuple(expected))) + self.st[2].vars["@return"] = { t = expected } end local expected_t = expected.tuple @@ -10874,8 +10895,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string vatype = expected.is_va and expected.tuple[n_expected] end - if n_got > n_expected and (not lax) and not vatype then - error_at(node, what ..": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) + if n_got > n_expected and (not self.feat_lax) and not vatype then + self.errs:add(node, what ..": excess return values, expected " .. n_expected .. " %s, got " .. n_got .. " %s", expected, got) end if n_expected > 1 @@ -10883,18 +10904,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string and node.exps[1].kind == "op" and (node.exps[1].op.op == "and" or node.exps[1].op.op == "or") and node.exps[1].discarded_tuple then - add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") + self.errs:add_warning("hint", node.exps[1].e2, "additional return values are being discarded due to '" .. node.exps[1].op.op .. "' expression; suggest parentheses if intentional") end for i = 1, n_got do local e = expected_t[i] or vatype if e then e = resolve_tuple(e) - local where = (node.exps[i] and node.exps[i].x) - and node.exps[i] - or node.exps - assert(where and where.x) - assert_is_a(where, got_t[i], e, what) + local w = (node.exps[i] and node.exps[i].x) + and node.exps[i] + or node.exps + assert(w and w.x) + self:assert_is_a(w, got_t[i], e, what) end end @@ -10902,25 +10923,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["variable_list"] = { - after = function(node: Node, children: {Type}): Type - local tuple = a_tuple(children) + after = function(self: TypeChecker, node: Node, children: {Type}): Type + local tuple = a_tuple(node, children) tuple = flatten_tuple(tuple) for i, t in ipairs(tuple.tuple) do - ensure_not_abstract(node[i], t) + local ok, err = ensure_not_abstract(t) + if not ok then + self.errs:add(node[i], err) + end end return tuple end, }, ["literal_table"] = { - before = function(node: Node) + before = function(self: TypeChecker, node: Node) if node.expected then - local decltype = to_structural(node.expected) + local decltype = self:to_structural(node.expected) if decltype is TypeVarType and decltype.constraint then - decltype = resolve_typedecl(to_structural(decltype.constraint)) + decltype = resolve_typedecl(self:to_structural(decltype.constraint)) end if decltype is TupleTableType then @@ -10952,19 +10976,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - after = function(node: Node, children: {LiteralTableItemType}): Type + after = function(self: TypeChecker, node: Node, children: {LiteralTableItemType}): Type node.known = FACT_TRUTHY if not node.expected then - return infer_table_literal(node, children) + return infer_table_literal(self, node, children) end - local decltype = to_structural(node.expected) + local decltype = self:to_structural(node.expected) local constraint: Type if decltype is TypeVarType and decltype.constraint then constraint = resolve_typedecl(decltype.constraint) - decltype = to_structural(constraint) + decltype = self:to_structural(constraint) end if decltype is UnionType then @@ -10972,7 +10996,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local single_table_rt: Type for _, t in ipairs(decltype.types) do - local rt = to_structural(t) + local rt = self:to_structural(t) if is_lua_table_type(rt) then if single_table_type then -- multiple table types in union, give up @@ -10993,7 +11017,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if not is_lua_table_type(decltype) then - return infer_table_literal(node, children) + return infer_table_literal(self, node, children) end local force_array: Type = nil @@ -11003,73 +11027,75 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string for i, child in ipairs(children) do local cvtype = resolve_tuple(child.vtype) local ck = child.kname + local cktype = child.ktype local n = node[i].key.constnum local b: boolean = nil - if child.ktype.typename == "boolean" then + if cktype is BooleanType then b = (node[i].key.tk == "true") end - check_redeclared_key(node[i], node.expected_context, seen_keys, ck or n or b) + self.errs:check_redeclared_key(node[i], node, seen_keys, ck or n or b) if decltype is RecordLikeType and ck then local df = decltype.fields[ck] if not df then - error_at(node[i], in_context(node.expected_context, "unknown field " .. ck)) + self.errs:add_in_context(node[i], node, "unknown field " .. ck) else if df is TypeDeclType or df is TypeAliasType then - error_at(node[i], in_context(node.expected_context, "cannot reassign a type")) + self.errs:add_in_context(node[i], node, "cannot reassign a type") else - assert_is_a(node[i], cvtype, df, "in record field", ck) + self:assert_is_a(node[i], cvtype, df, "in record field", ck) end end - elseif decltype is TupleTableType and is_number_type(child.ktype) then + elseif decltype is TupleTableType and cktype is NumericType then local dt = decltype.types[n as integer] if not n then - error_at(node[i], in_context(node.expected_context, "unknown index in tuple %s"), decltype) + self.errs:add_in_context(node[i], node, "unknown index in tuple %s", decltype) elseif not dt then - error_at(node[i], in_context(node.expected_context, "unexpected index " .. n .. " in tuple %s"), decltype) + self.errs:add_in_context(node[i], node, "unexpected index " .. n .. " in tuple %s", decltype) else - assert_is_a(node[i], cvtype, dt, in_context(node.expected_context, "in tuple"), "at index " .. tostring(n)) + self:assert_is_a(node[i], cvtype, dt, node, "in tuple: at index " .. tostring(n)) end - elseif decltype is ArrayLikeType and is_number_type(child.ktype) then + elseif decltype is ArrayLikeType and cktype is NumericType then local cv = child.vtype if cv is TupleType and i == #children and node[i].key_parsed == "implicit" then -- need to expand last item in an array (e.g { 1, 2, 3, f() }) for ti, tt in ipairs(cv.tuple) do - assert_is_a(node[i], tt, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(i + ti - 1)) + self:assert_is_a(node[i], tt, decltype.elements, node, "expected an array: at index " .. tostring(i + ti - 1)) end else - assert_is_a(node[i], cvtype, decltype.elements, in_context(node.expected_context, "expected an array"), "at index " .. tostring(n)) + self:assert_is_a(node[i], cvtype, decltype.elements, node, "expected an array: at index " .. tostring(n)) end elseif node[i].key_parsed == "implicit" then if decltype is MapType then - assert_is_a(node[i], INTEGER, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + self:assert_is_a(node[i].key, a_type(node[i].key, "integer", {}), decltype.keys, node, "in map key") + self:assert_is_a(node[i].value, cvtype, decltype.values, node, "in map value") end - force_array = expand_type(node[i], force_array, child.vtype) + force_array = self:expand_type(node[i], force_array, child.vtype) elseif decltype is MapType then force_array = nil - assert_is_a(node[i], child.ktype, decltype.keys, in_context(node.expected_context, "in map key")) - assert_is_a(node[i], cvtype, decltype.values, in_context(node.expected_context, "in map value")) + self:assert_is_a(node[i].key, cktype, decltype.keys, node, "in map key") + self:assert_is_a(node[i].value, cvtype, decltype.values, node, "in map value") else - error_at(node[i], in_context(node.expected_context, "unexpected key of type %s in table of type %s"), child.ktype, decltype) + self.errs:add_in_context(node[i], node, "unexpected key of type %s in table of type %s", cktype, decltype) end end local t: Type if force_array then - t = infer_at(node, an_array(force_array)) + t = self:infer_at(node, an_array(node, force_array)) else - t = resolve_typevars_at(node, node.expected) + t = self:resolve_typevars_at(node, node.expected) end if decltype is RecordType then - local rt = to_structural(t) + local rt = self:to_structural(t) if rt is RecordType then node.is_total, node.missing = total_record_check(decltype, seen_keys) end elseif decltype is MapType then - local rt = to_structural(t) + local rt = self:to_structural(t) if rt is MapType then - node.is_total, node.missing = total_map_check(decltype, seen_keys) + local rk = self:to_structural(rt.keys) + node.is_total, node.missing = total_map_check(rk, seen_keys) end end @@ -11081,13 +11107,13 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["literal_table_item"] = { - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local kname = node.key.conststr local ktype = children[1] local vtype = children[2] if node.itemtype then vtype = node.itemtype - assert_is_a(node.value, children[2], node.itemtype, "in table item") + self:assert_is_a(node.value, children[2], node.itemtype, node) end if vtype is FunctionType and vtype.is_method then -- If we assign a method to a table item, e.g. @@ -11096,210 +11122,210 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string vtype = shallow_copy_new_type(vtype) vtype.is_method = false end - return type_at(node, a_type("literal_table_item", { + return a_type(node, "literal_table_item", { kname = kname, ktype = ktype, vtype = vtype, - } as LiteralTableItemType)) + } as LiteralTableItemType) end, }, ["local_function"] = { - before = function(node: Node) - widen_all_unions() - if tc then - tc.reserve_symbol_list_slot(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions() + if self.collector then + self.collector.reserve_symbol_list_slot(node) end - begin_scope(node) + self:begin_scope(node) end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local args = children[2] assert(args is TupleType) - add_internal_function_variables(node, args) - add_function_definition_for_recursion(node, args) + self:add_internal_function_variables(node, args) + self:add_function_definition_for_recursion(node, args) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] assert(args is TupleType) local rets = children[3] assert(rets is TupleType) - end_function_scope(node) + self:end_function_scope(node) - local t = type_at(node, ensure_fresh_typeargs(a_function { + local t = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), })) - add_var(node, node.name.tk, t) + self:add_var(node, node.name.tk, t) return t end, }, ["local_macroexp"] = { - before = function(node: Node) - widen_all_unions() - if tc then - tc.reserve_symbol_list_slot(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions() + if self.collector then + self.collector.reserve_symbol_list_slot(node) end - begin_scope(node) + self:begin_scope(node) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] assert(args is TupleType) local rets = children[3] assert(rets is TupleType) - end_function_scope(node) + self:end_function_scope(node) - check_macroexp_arg_use(node.macrodef) + self:check_macroexp_arg_use(node.macrodef) - local t = type_at(node, ensure_fresh_typeargs(a_function { + local t = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.macrodef.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), macroexp = node.macrodef, })) - add_var(node, node.name.tk, t) + self:add_var(node, node.name.tk, t) return t end, }, ["global_function"] = { - before = function(node: Node) - widen_all_unions() - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions() + self:begin_scope(node) if node.implicit_global_function then - local typ = find_var_type(node.name.tk) + local typ = self:find_var_type(node.name.tk) if typ then if typ is FunctionType then node.is_predeclared_local_function = true - elseif not lax then - error_at(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) + elseif not self.feat_lax then + self.errs:add(node, "cannot declare function: type of " .. node.name.tk .. " is %s", typ) end - elseif not lax then - error_at(node, "functions need an explicit 'local' or 'global' annotation") + elseif not self.feat_lax then + self.errs:add(node, "functions need an explicit 'local' or 'global' annotation") end end end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local args = children[2] assert(args is TupleType) - add_internal_function_variables(node, args) - add_function_definition_for_recursion(node, args) + self:add_internal_function_variables(node, args) + self:add_function_definition_for_recursion(node, args) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[2] assert(args is TupleType) local rets = children[3] assert(rets is TupleType) - end_function_scope(node) + self:end_function_scope(node) if node.is_predeclared_local_function then return NONE end - add_global(node, node.name.tk, type_at(node, ensure_fresh_typeargs(a_function { + self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), }))) return NONE end, }, ["record_function"] = { - before = function(node: Node) - widen_all_unions() - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions() + self:begin_scope(node) end, - before_arguments = function(_node: Node, children: {Type}) - local rtype = to_structural(resolve_typedecl(children[1])) + before_arguments = function(self: TypeChecker, _node: Node, children: {Type}) + local rtype = self:to_structural(resolve_typedecl(children[1])) -- add type arguments from the record implicitly if rtype is RecordLikeType and rtype.typeargs then for _, typ in ipairs(rtype.typeargs) do - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, constraint = typ.constraint, - } as TypeArgType))) + } as TypeArgType)) end end end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local args = children[3] assert(args is TupleType) local rets = children[4] assert(rets is TupleType) - local rtype = to_structural(resolve_typedecl(children[1])) + local rtype = self:to_structural(resolve_typedecl(children[1])) - if lax and rtype.typename == "unknown" then + if self.feat_lax and rtype is UnknownType then return end if rtype is EmptyTableType then - edit_type(rtype, "record") + edit_type(rtype, rtype, "record") local r = rtype as RecordType r.fields = {} r.field_order = {} end if not rtype is RecordLikeType then - error_at(node, "not a record: %s", rtype) + self.errs:add(node, "not a record: %s", rtype) return end - local selftype = get_self_type(node.fn_owner) + local selftype = self:get_self_type(node.fn_owner) if node.is_method then if not selftype then - error_at(node, "could not resolve type of self") + self.errs:add(node, "could not resolve type of self") return end args.tuple[1] = selftype - add_var(nil, "self", selftype) + self:add_var(nil, "self", selftype) end - local fn_type = type_at(node, ensure_fresh_typeargs(a_function { + local fn_type = self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, is_method = node.is_method, typeargs = node.typeargs, args = args, - rets = get_rets(rets), + rets = self.get_rets(rets), })) - local open_t, open_v, owner_name = find_record_to_extend(node.fn_owner) + local open_t, open_v, owner_name = self:find_record_to_extend(node.fn_owner) local open_k = owner_name .. "." .. node.name.tk local rfieldtype = rtype.fields[node.name.tk] if rfieldtype then - rfieldtype = to_structural(rfieldtype) + rfieldtype = self:to_structural(rfieldtype) if open_v and open_v.implemented and open_v.implemented[open_k] then - redeclaration_warning(node) + self.errs:redeclaration_warning(node) end - local ok, err = same_type(fn_type, rfieldtype) + local ok, err = self:same_type(fn_type, rfieldtype) if not ok then if rfieldtype is PolyType then - add_errs_prefixing(node, err, errors, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability)") + self.errs:add_prefixing(node, err, "type signature does not match declaration: field has multiple function definitions (such polymorphic declarations are intended for Lua module interoperability): ") return end local shortname = selftype and show_type(selftype) or owner_name local msg = "type signature of '" .. node.name.tk .. "' does not match its declaration in " .. shortname .. ": " - add_errs_prefixing(node, err, errors, msg) + self.errs:add_prefixing(node, err, msg) return end else - if lax or rtype == open_t then + if self.feat_lax or rtype == open_t then rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) else - error_at(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") + self.errs:add(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. owner_name .. "' was originally declared") return end @@ -11312,32 +11338,32 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string open_v.implemented[open_k] = true end - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node: Node, _children: {Type}): Type - end_function_scope(node) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + self:end_function_scope(node) return NONE end, }, ["function"] = { - before = function(node: Node) - widen_all_unions(node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions(node) + self:begin_scope(node) end, - before_statements = function(node: Node, children: {Type}) + before_statements = function(self: TypeChecker, node: Node, children: {Type}) local args = children[1] assert(args is TupleType) - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[1] assert(args is TupleType) local rets = children[2] assert(rets is TupleType) - end_function_scope(node) - return type_at(node, ensure_fresh_typeargs(a_function { + self:end_function_scope(node) + return self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, @@ -11346,24 +11372,24 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["macroexp"] = { - before = function(node: Node) - widen_all_unions(node) - begin_scope(node) + before = function(self: TypeChecker, node: Node) + self:widen_all_unions(node) + self:begin_scope(node) end, - before_exp = function(node: Node, children: {Type}) + before_exp = function(self: TypeChecker, node: Node, children: {Type}) local args = children[1] assert(args is TupleType) - add_internal_function_variables(node, args) + self:add_internal_function_variables(node, args) end, - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local args = children[1] assert(args is TupleType) local rets = children[2] assert(rets is TupleType) - end_function_scope(node) - return type_at(node, ensure_fresh_typeargs(a_function { + self:end_function_scope(node) + return self:ensure_fresh_typeargs(a_function(node, { min_arity = node.min_arity, typeargs = node.typeargs, args = args, @@ -11372,22 +11398,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["cast"] = { - after = function(node: Node, _children: {Type}): Type + after = function(_self: TypeChecker, node: Node, _children: {Type}): Type return node.casttype end }, ["paren"] = { - before = function(node: Node) + before = function(_self: TypeChecker, node: Node) node.e1.expected = node.expected end, - after = function(node: Node, children: {Type}): Type + after = function(_self: TypeChecker, node: Node, children: {Type}): Type node.known = node.e1 and node.e1.known return resolve_tuple(children[1]) end, }, ["op"] = { - before = function(node: Node) - begin_scope() + before = function(self: TypeChecker, node: Node) + self:begin_scope() if node.expected then if node.op.op == "and" then node.e2.expected = node.expected @@ -11399,18 +11425,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - before_e2 = function(node: Node, children: {Type}) + before_e2 = function(self: TypeChecker, node: Node, children: {Type}) local e1type = children[1] if node.op.op == "and" then - apply_facts(node, node.e1.known) + self:apply_facts(node, node.e1.known) elseif node.op.op == "or" then - apply_facts(node, facts_not(node, node.e1.known)) + self:apply_facts(node, facts_not(node, node.e1.known)) elseif node.op.op == "@funcall" then if e1type is FunctionType then local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0 if node.expected then - is_a(e1type.rets, node.expected) + -- this forces typevars in function return types + self:is_a(e1type.rets, node.expected) end local e1args = e1type.args.tuple local at = argdelta @@ -11433,8 +11460,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end end, - after = function(node: Node, children: {Type}): Type - end_scope() + after = function(self: TypeChecker, node: Node, children: {Type}): Type + self:end_scope() -- given a and b: may be TupleType local ga: Type = children[1] @@ -11445,29 +11472,34 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string local ub: Type -- resolved a and b: not NominalType - local ra: Type = to_structural(ua) + local ra: Type = self:to_structural(ua) local rb: Type if ra.typename == "circular_require" or (ra is TypeDeclType and ra.def and ra.def.typename == "circular_require") then - return invalid_at(node, "cannot dereference a type from a circular require") + return self.errs:invalid_at(node, "cannot dereference a type from a circular require") end if node.op.op == "@funcall" then - if lax and is_unknown(ua) then + if self.feat_lax and is_unknown(ua) then if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then - add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) + self.errs:add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) end end - local t = type_check_funcall(node, ua, gb) + assert(gb is TupleType) +assert(node.f) + local t = self:type_check_funcall(node, ua, gb) return t elseif node.op.op == "as" then return gb end - local expected = node.expected and to_structural(resolve_tuple(node.expected)) + local expected = node.expected and self:to_structural(resolve_tuple(node.expected)) - ensure_not_abstract(node.e1, ra) + local ok, err = ensure_not_abstract(ra) + if not ok then + self.errs:add(node.e1, err) + end if ra is TypeDeclType and ra.def.typename == "record" then ra = ra.def end @@ -11476,8 +11508,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- after they are handled above, we can resolve b's tuple and only use that instead. if gb then ub = resolve_tuple(gb) - rb = to_structural(ub) - ensure_not_abstract(node.e2, rb) + rb = self:to_structural(ub) + ok, err = ensure_not_abstract(rb) + if not ok then + self.errs:add(node.e2, err) + end if rb is TypeDeclType and rb.def.typename == "record" then rb = rb.def end @@ -11487,22 +11522,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.receiver = ua assert(node.e2.kind == "identifier") - local bnode: Node = { - y = node.e2.y, - x = node.e2.x, + local bnode: Node = node_at(node.e2, { tk = node.e2.tk, kind = "string", - } - local btype = type_at(node.e2, a_type("string", { literal = node.e2.tk } as StringType)) - local t = type_check_index(node.e1, bnode, ua, btype) + }) + local btype = a_type(node.e2, "string", { literal = node.e2.tk } as StringType) + local t = self:type_check_index(node.e1, bnode, ua, btype) - if t.needs_compat and opts.gen_compat ~= "off" then + if t.needs_compat and self.gen_compat ~= "off" then -- only apply to a literal use, not a propagated type if node.e1.kind == "variable" and node.e2.kind == "identifier" then local key = node.e1.tk .. "." .. node.e2.tk node.kind = "variable" node.tk = "_tl_" .. node.e1.tk .. "_" .. node.e2.tk - all_needs_compat[key] = true + self.all_needs_compat[key] = true end end @@ -11510,22 +11543,22 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if node.op.op == "@index" then - return type_check_index(node.e1, node.e2, ua, ub) + return self:type_check_index(node.e1, node.e2, ua, ub) end if node.op.op == "is" then if rb.typename == "integer" then - all_needs_compat["math"] = true + self.all_needs_compat["math"] = true end if ra is TypeDeclType then - error_at(node, "can only use 'is' on variables, not types") + self.errs:add(node, "can only use 'is' on variables, not types") elseif node.e1.kind == "variable" then - check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) - node.known = IsFact { var = node.e1.tk, typ = ub, where = node } + self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub) + node.known = IsFact { var = node.e1.tk, typ = ub, w = node } else - error_at(node, "can only use 'is' on variables") + self.errs:add(node, "can only use 'is' on variables") end - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.op == ":" then @@ -11533,16 +11566,16 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string -- we handle ':' separately from '.' because ':' is specific to records, -- so we produce different error messages - if lax and (is_unknown(ua) or ua.typename == "typevar") then + if self.feat_lax and (is_unknown(ua) or ua is TypeVarType) then if node.e1.kind == "variable" then - add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) + self.errs:add_unknown_dot(node.e1, node.e1.tk .. "." .. node.e2.tk) end - return UNKNOWN + return an_unknown(node) end - local t, e = match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) + local t, e = self:match_record_key(ra, node.e1, node.e2.conststr or node.e2.tk) if not t then - return invalid_at(node.e2, e, ua) + return self.errs:invalid_at(node.e2, e, ua) end return t @@ -11550,7 +11583,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.op.op == "not" then node.known = facts_not(node, node.e1.known) - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.op == "and" then @@ -11568,33 +11601,33 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = nil t = ua - elseif ((ra is EnumType and rb is StringType and is_a(rb, ra)) - or (ra is StringType and rb is EnumType and is_a(ra, rb))) then + elseif ((ra is EnumType and rb is StringType and self:is_a(rb, ra)) + or (ra is StringType and rb is EnumType and self:is_a(ra, rb))) then node.known = nil t = (ra is EnumType and ra or rb) elseif expected and expected is UnionType then -- must be checked after string/enum above node.known = facts_or(node, node.e1.known, node.e2.known) - local u = unite({ra, rb}, true) + local u = unite(node, {ra, rb}, true) if u is UnionType then - local ok, err = is_valid_union(u) + ok, err = is_valid_union(u) if not ok then - u = err and invalid_at(node, err, u) or INVALID + u = err and self.errs:invalid_at(node, err, u) or an_invalid(node) end end t = u else - local a_ge_b = is_a(rb, ra) - local b_ge_a = is_a(ra, rb) + local a_ge_b = self:is_a(rb, ra) + local b_ge_a = self:is_a(ra, rb) if a_ge_b or b_ge_a then node.known = facts_or(node, node.e1.known, node.e2.known) if expected then - local a_is = is_a(ua, expected) - local b_is = is_a(ub, expected) + local a_is = self:is_a(ua, expected) + local b_is = self:is_a(ub, expected) if a_is and b_is then - t = resolve_typevars_at(node, expected) + t = self:resolve_typevars_at(node, expected) end end if not t then @@ -11613,44 +11646,46 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if node.op.op == "==" or node.op.op == "~=" then -- if is_lua_table_type(ra) and is_lua_table_type(rb) then --- check_metamethod(node, binop_to_metamethod[node.op.op], ra, rb) +-- self:check_metamethod(node, binop_to_metamethod[node.op.op], ra, rb) -- end if ra is EnumType and rb is StringType then if not (rb.literal and ra.enumset[rb.literal]) then - return invalid_at(node, "%s is not a member of %s", ub, ua) + return self.errs:invalid_at(node, "%s is not a member of %s", ub, ua) end elseif ra is TupleTableType and rb is TupleTableType and #ra.types ~= #rb.types then - return invalid_at(node, "tuples are not the same size") - elseif is_a(ub, ua) or ua.typename == "typevar" then + return self.errs:invalid_at(node, "tuples are not the same size") + elseif self:is_a(ub, ua) or ua is TypeVarType then if node.op.op == "==" and node.e1.kind == "variable" then - node.known = EqFact { var = node.e1.tk, typ = ub, where = node } + node.known = EqFact { var = node.e1.tk, typ = ub, w = node } end - elseif is_a(ua, ub) or ub.typename == "typevar" then + elseif self:is_a(ua, ub) or ub is TypeVarType then if node.op.op == "==" and node.e2.kind == "variable" then - node.known = EqFact { var = node.e2.tk, typ = ua, where = node } + node.known = EqFact { var = node.e2.tk, typ = ua, w = node } end - elseif lax and (is_unknown(ua) or is_unknown(ub)) then - return UNKNOWN + elseif self.feat_lax and (is_unknown(ua) or is_unknown(ub)) then + return an_unknown(node) else - return invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) + return self.errs:invalid_at(node, "types are not comparable for equality: %s and %s", ua, ub) end - return BOOLEAN + return a_type(node, "boolean", {}) end if node.op.arity == 1 and unop_types[node.op.op] then if ra is UnionType then - ra = unite(ra.types, true) -- squash unions of string constants + ra = unite(node, ra.types, true) -- squash unions of string constants end local types_op = unop_types[node.op.op] - local t = types_op[ra.typename] + local tn = types_op[ra.typename] + local t = tn and a_type(node, tn, {}) if not t and ra is RecordLikeType then t = find_in_interface_list(ra, function(ty: Type): Type - return types_op[ty.typename] + local tname = types_op[ty.typename] + return tname and a_type(node, tname, {}) end) end @@ -11658,19 +11693,18 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if not t then local mt_name = unop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, ra, nil, ua, nil) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, nil, ua, nil) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) - t = INVALID + t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' on type %s", ua) end end if ra is MapType then if ra.keys.typename == "number" or ra.keys.typename == "integer" then - add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") + self.errs:add_warning("hint", node, "using the '#' operator on a map with numeric key type may produce unexpected results") else - error_at(node, "using the '#' operator on this map will always return 0") + self.errs:add(node, "using the '#' operator on this map will always return 0") end end @@ -11678,12 +11712,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = FACT_TRUTHY end - if node.op.op == "~" and env.gen_target == "5.1" then + if node.op.op == "~" and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, unop_to_metamethod[node.op.op], 1, node.e1) else - all_needs_compat["bit32"] = true + self.all_needs_compat["bit32"] = true convert_node_to_compat_call(node, "bit32", "bnot", node.e1) end end @@ -11697,39 +11731,39 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if ra is UnionType then - ra = unite(ra.types, true) -- squash unions of string constants + ra = unite(ra, ra.types, true) -- squash unions of string constants end if rb is UnionType then - rb = unite(rb.types, true) -- squash unions of string constants + rb = unite(rb, rb.types, true) -- squash unions of string constants end local types_op = binop_types[node.op.op] - local t = types_op[ra.typename] and types_op[ra.typename][rb.typename] + local tn = types_op[ra.typename] and types_op[ra.typename][rb.typename] + local t = tn and a_type(node, tn, {}) local meta_on_operator: integer if not t then local mt_name = binop_to_metamethod[node.op.op] if mt_name then - t, meta_on_operator = check_metamethod(node, mt_name, ra, rb, ua, ub) + t, meta_on_operator = self:check_metamethod(node, mt_name, ra, rb, ua, ub) end if not t then - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) - t = INVALID + t = self.errs:invalid_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", ua, ub) if node.op.op == "or" then - local u = unite({ua, ub}) + local u = unite(node, {ua, ub}) if u is UnionType and is_valid_union(u) then - add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") + self.errs:add_warning("hint", node, "if a union type was intended, consider declaring it explicitly") end end end end if ua is NominalType and ub is NominalType and not meta_on_operator then - if is_a(ua, ub) then + if self:is_a(ua, ub) then t = ua else - error_at(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) + self.errs:add(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for distinct nominal types %s and %s", ua, ub) end end @@ -11737,20 +11771,20 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string node.known = FACT_TRUTHY end - if node.op.op == "//" and env.gen_target == "5.1" then + if node.op.op == "//" and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, "__idiv", meta_on_operator, node.e1, node.e2) else - local div: Node = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 } + local div: Node = node_at(node, { kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 }) convert_node_to_compat_call(node, "math", "floor", div) end - elseif bit_operators[node.op.op] and env.gen_target == "5.1" then + elseif bit_operators[node.op.op] and self.gen_target == "5.1" then if meta_on_operator then - all_needs_compat["mt"] = true + self.all_needs_compat["mt"] = true convert_node_to_compat_mt_call(node, binop_to_metamethod[node.op.op], meta_on_operator, node.e1, node.e2) else - all_needs_compat["bit32"] = true + self.all_needs_compat["bit32"] = true convert_node_to_compat_call(node, "bit32", bit_operators[node.op.op], node.e1, node.e2) end end @@ -11762,28 +11796,28 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["variable"] = { - after = function(node: Node, _children: {Type}): Type + after = function(self: TypeChecker, node: Node, _children: {Type}): Type if node.tk == "..." then - local va_sentinel = find_var_type("@is_va") + local va_sentinel = self:find_var_type("@is_va") if not va_sentinel or va_sentinel.typename == "nil" then - return invalid_at(node, "cannot use '...' outside a vararg function") + return self.errs:invalid_at(node, "cannot use '...' outside a vararg function") end end local t: Type if node.tk == "_G" then - t, node.attribute = simulate_g() + t, node.attribute = self:simulate_g() else local use: VarUse = node.is_lvalue and "lvalue" or "use" - t, node.attribute = find_var_type(node.tk, use) + t, node.attribute = self:find_var_type(node.tk, use) end if not t then - if lax then - add_unknown(node, node.tk) - return UNKNOWN + if self.feat_lax then + self.errs:add_unknown(node, node.tk) + return an_unknown(node) end - return invalid_at(node, "unknown variable: " .. node.tk) + return self.errs:invalid_at(node, "unknown variable: " .. node.tk) end if t is TypeDeclType then @@ -11794,70 +11828,70 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end, }, ["type_identifier"] = { - after = function(node: Node, _children: {Type}): Type - local typ, attr = find_var_type(node.tk) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + local typ, attr = self:find_var_type(node.tk) node.attribute = attr if typ then return typ end - if lax then - add_unknown(node, node.tk) - return UNKNOWN + if self.feat_lax then + self.errs:add_unknown(node, node.tk) + return an_unknown(node) end - return invalid_at(node, "unknown variable: " .. node.tk) + return self.errs:invalid_at(node, "unknown variable: " .. node.tk) end, }, ["argument"] = { - after = function(node: Node, children: {Type}): Type + after = function(self: TypeChecker, node: Node, children: {Type}): Type local t = children[1] if not t then - t = UNKNOWN + t = an_unknown(node) end if node.tk == "..." then - t = a_vararg { t } + t = a_vararg(node, { t }) end - add_var(node, node.tk, t).is_func_arg = true + self:add_var(node, node.tk, t).is_func_arg = true return t end, }, ["identifier"] = { - after = function(_node: Node, _children: {Type}): Type + after = function(_self: TypeChecker, _node: Node, _children: {Type}): Type return NONE -- type is resolved elsewhere end, }, ["newtype"] = { - after = function(node: Node, _children: {Type}): Type + after = function(_self: TypeChecker, node: Node, _children: {Type}): Type return node.newtype end, }, ["error_node"] = { - after = function(_node: Node, _children: {Type}): Type - return INVALID + after = function(_self: TypeChecker, node: Node, _children: {Type}): Type + return an_invalid(node) end, } } visit_node.cbs["break"] = { - after = function(_node: Node, _children: {Type}): Type + after = function(_self: TypeChecker, _node: Node, _children: {Type}): Type return NONE end, } visit_node.cbs["do"] = visit_node.cbs["break"] - local function after_literal(node: Node): Type + local function after_literal(_self: TypeChecker, node: Node): Type node.known = FACT_TRUTHY - return type_at(node, a_type(node.kind as TypeName, {})) + return a_type(node, node.kind as TypeName, {}) end visit_node.cbs["string"] = { - after = function(node: Node, _children: {Type}): Type - local t = after_literal(node) as StringType + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + local t = after_literal(self, node) as StringType t.literal = node.conststr - local expected = node.expected and to_structural(node.expected) - if expected and expected is EnumType and is_a(t, expected) then + local expected = node.expected and self:to_structural(node.expected) + if expected and expected is EnumType and self:is_a(t, expected) then return node.expected end @@ -11868,8 +11902,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_node.cbs["integer"] = { after = after_literal } visit_node.cbs["boolean"] = { - after = function(node: Node, _children: {Type}): Type - local t = after_literal(node) + after = function(self: TypeChecker, node: Node, _children: {Type}): Type + local t = after_literal(self, node) node.known = (node.tk == "true") and FACT_TRUTHY or nil return t end, @@ -11880,7 +11914,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_node.cbs["argument_list"] = visit_node.cbs["variable_list"] visit_node.cbs["expression_list"] = visit_node.cbs["variable_list"] - visit_node.after = function(node: Node, _children: {Type}, t: Type): Type + visit_node.after = function(_self: TypeChecker, node: Node, _children: {Type}, t: Type): Type if node.expanded then apply_macroexp(node) end @@ -11888,13 +11922,12 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return t end - local expand_interfaces: function(Type) do - local function add_interface_fields(what: string, fields: {string:Type}, field_order: {string}, resolved: RecordLikeType, named: NominalType, list?: MetaMode) + local function add_interface_fields(self: TypeChecker, what: string, fields: {string:Type}, field_order: {string}, resolved: RecordLikeType, named: NominalType, list?: MetaMode) for fname, ftype in fields_of(resolved, list) do if fields[fname] then - if not is_a(fields[fname], ftype) then - error_at(fields[fname], what .." '" .. fname .. "' does not match definition in interface %s", named) + if not self:is_a(fields[fname], ftype) then + self.errs:add(fields[fname], what .." '" .. fname .. "' does not match definition in interface %s", named) end else table.insert(field_order, fname) @@ -11903,18 +11936,21 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local function collect_interfaces(list: {ArrayType | NominalType}, t: RecordLikeType, seen:{Type:boolean}): {ArrayType | NominalType} + local function collect_interfaces(self: TypeChecker, list: {ArrayType | NominalType}, t: RecordLikeType, seen:{Type:boolean}): {ArrayType | NominalType} if t.interface_list then for _, iface in ipairs(t.interface_list) do if iface is NominalType then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) if not (ri.typename == "invalid") then - assert(ri is InterfaceType, "nominal resolved to " .. ri.typename) - if not ri.interfaces_expanded and not seen[ri] then - seen[ri] = true - collect_interfaces(list, ri, seen) + if ri is InterfaceType then + if not ri.interfaces_expanded and not seen[ri] then + seen[ri] = true + collect_interfaces(self, list, ri, seen) + end + table.insert(list, iface) + else + self.errs:add(iface, "attempted to use %s as interface, but its type is %s", iface, ri) end - table.insert(list, iface) end else if not seen[iface] then @@ -11927,30 +11963,30 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string return list end - expand_interfaces = function(t: RecordLikeType) + function TypeChecker:expand_interfaces(t: RecordLikeType) if t.interfaces_expanded then return end t.interfaces_expanded = true - t.interface_list = collect_interfaces({}, t, {}) + t.interface_list = collect_interfaces(self, {}, t, {}) for _, iface in ipairs(t.interface_list) do if iface is NominalType then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) assert(ri is InterfaceType) - add_interface_fields("field", t.fields, t.field_order, ri, iface) + add_interface_fields(self, "field", t.fields, t.field_order, ri, iface) if ri.meta_fields then t.meta_fields = t.meta_fields or {} t.meta_field_order = t.meta_field_order or {} - add_interface_fields("metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") + add_interface_fields(self, "metamethod", t.meta_fields, t.meta_field_order, ri, iface, "meta") end else if not t.elements then t.elements = iface else - if not same_type(iface.elements, t.elements) then - error_at(t, "incompatible array interfaces") + if not self:same_type(iface.elements, t.elements) then + self.errs:add(t, "incompatible array interfaces") end end end @@ -11958,33 +11994,33 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end - local visit_type: Visitor + local visit_type: Visitor visit_type = { cbs = { ["function"] = { - before = function(_typ: Type) - begin_scope() + before = function(self: TypeChecker, _typ: Type) + self:begin_scope() end, - after = function(typ: Type, _children: {Type}): Type - end_scope() - return ensure_fresh_typeargs(typ) + after = function(self: TypeChecker, typ: Type, _children: {Type}): Type + self:end_scope() + return self:ensure_fresh_typeargs(typ) end, }, ["record"] = { - before = function(typ: RecordType) - begin_scope() - add_var(nil, "@self", type_at(typ, a_typedecl(typ))) + before = function(self: TypeChecker, typ: RecordType) + self:begin_scope() + self:add_var(nil, "@self", type_at(typ, a_typedecl(typ, typ))) for fname, ftype in fields_of(typ) do if ftype is TypeAliasType then - resolve_nominal(ftype.alias_to) - add_var(nil, fname, ftype) + self:resolve_nominal(ftype.alias_to) + self:add_var(nil, fname, ftype) elseif ftype is TypeDeclType then - add_var(nil, fname, ftype) + self:add_var(nil, fname, ftype) end end end, - after = function(typ: RecordType, children: {Type}): Type + after = function(self: TypeChecker, typ: RecordType, children: {Type}): Type local i = 1 if typ.typeargs then for _, _ in ipairs(typ.typeargs) do @@ -11998,11 +12034,11 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string if iface is ArrayType then typ.interface_list[j] = iface elseif iface is NominalType then - local ri = resolve_nominal(iface) + local ri = self:resolve_nominal(iface) if ri is InterfaceType then typ.interface_list[j] = iface else - error_at(children[i], "%s is not an interface", children[i]) + self.errs:add(children[i], "%s is not an interface", children[i]) end end i = i + 1 @@ -12042,7 +12078,7 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end end elseif ftype is TypeAliasType then - resolve_typealias(ftype) + self:resolve_typealias(ftype) end typ.fields[name] = ftype @@ -12061,55 +12097,55 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end if typ.interface_list then - expand_interfaces(typ) + self:expand_interfaces(typ) end if fmacros then for _, t in ipairs(fmacros) do - local macroexp_type = recurse_node(t.macroexp, visit_node, visit_type) + local macroexp_type = recurse_node(self, t.macroexp, visit_node, visit_type) - check_macroexp_arg_use(t.macroexp) + self:check_macroexp_arg_use(t.macroexp) - if not is_a(macroexp_type, t) then - error_at(macroexp_type, "macroexp type does not match declaration") + if not self:is_a(macroexp_type, t) then + self.errs:add(macroexp_type, "macroexp type does not match declaration") end end end - end_scope() + self:end_scope() return typ end, }, ["typearg"] = { - after = function(typ: TypeArgType, _children: {Type}): Type - add_var(nil, typ.typearg, type_at(typ, a_type("typearg", { + after = function(self: TypeChecker, typ: TypeArgType, _children: {Type}): Type + self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, constraint = typ.constraint, - } as TypeArgType))) + } as TypeArgType)) return typ end, }, ["typevar"] = { - after = function(typ: TypeVarType, _children: {Type}): Type - if not find_var_type(typ.typevar) then - error_at(typ, "undefined type variable " .. typ.typevar) + after = function(self: TypeChecker, typ: TypeVarType, _children: {Type}): Type + if not self:find_var_type(typ.typevar) then + self.errs:add(typ, "undefined type variable " .. typ.typevar) end return typ end, }, ["nominal"] = { - after = function(typ: NominalType, _children: {Type}): Type + after = function(self: TypeChecker, typ: NominalType, _children: {Type}): Type if typ.found then return typ end - local t = find_type(typ.names, true) + local t = self:find_type(typ.names, true) if t then if t is TypeArgType then -- convert nominal into a typevar typ.names = nil - edit_type(typ, "typevar") + edit_type(typ, typ, "typevar") local tv = typ as TypeVarType tv.typevar = t.typearg tv.constraint = t.constraint @@ -12120,18 +12156,19 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string end else local name = typ.names[1] - local unresolved = get_unresolved() - unresolved.nominals[name] = unresolved.nominals[name] or {} - table.insert(unresolved.nominals[name], typ) + local scope = self.st[#self.st] + scope.pending_nominals = scope.pending_nominals or {} + scope.pending_nominals[name] = scope.pending_nominals[name] or {} + table.insert(scope.pending_nominals[name], typ) end return typ end, }, ["union"] = { - after = function(typ: UnionType, _children: {Type}): Type + after = function(self: TypeChecker, typ: UnionType, _children: {Type}): Type local ok, err = is_valid_union(typ) if not ok then - return err and invalid_at(typ, err, typ) or INVALID + return err and self.errs:invalid_at(typ, err, typ) or an_invalid(typ) end return typ end @@ -12139,59 +12176,8 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string }, } - local function internal_compiler_check(fn: function(W, {Type}, Type): (Type)): (function(W, {Type}, Type): (Type)) - return function(w: W, children: {Type}, t: Type): Type - t = fn and fn(w, children, t) or t - - if type(t) ~= "table" then - error(((w as Node).kind or (w as Type).typename) .. " did not produce a type") - end - if type(t.typename) ~= "string" then - error(((w as Node).kind or (w as Type).typename) .. " type does not have a typename") - end - - return t - end - end - - local function store_type_after(fn: function(W, {Type}, Type): (Type)): (function(W, {Type}, Type): (Type)) - return function(w: W, children: {Type}, t: Type): Type - t = fn and fn(w, children, t) or t - - local where = w as Where - - if where.y then - tc.store_type(where.y, where.x, t) - end - - return t - end - end - - local function debug_type_after(fn: function(Node, {Type}, Type): (Type)): (function(Node, {Type}, Type): (Type)) - return function(node: Node, children: {Type}, t: Type): Type - t = fn and fn(node, children, t) or t - node.debug_type = t - return t - end - end - - if opts.run_internal_compiler_checks then - visit_node.after = internal_compiler_check(visit_node.after) - visit_type.after = internal_compiler_check(visit_type.after) - end - - if tc then - visit_node.after = store_type_after(visit_node.after) - visit_type.after = store_type_after(visit_type.after) - end - - if TL_DEBUG then - visit_node.after = debug_type_after(visit_node.after) - end - local default_type_visitor = { - after = function(typ: Type, _children: {Type}): Type + after = function(_self: TypeChecker, typ: Type, _children: {Type}): Type return typ end, } @@ -12218,70 +12204,201 @@ tl.type_check = function(ast: Node, opts: TypeCheckOptions): Result, string visit_type.cbs["any"] = default_type_visitor visit_type.cbs["unknown"] = default_type_visitor visit_type.cbs["invalid"] = default_type_visitor - visit_type.cbs["unresolved"] = default_type_visitor visit_type.cbs["none"] = default_type_visitor - assert(ast.kind == "statements") - recurse_node(ast, visit_node, visit_type) + local type VisitorAfterPatcher = function(VisitorAfter): VisitorAfter - close_types(st[1]) - check_for_unused_vars(st[1], true) + local function internal_compiler_check(fn: VisitorAfter): VisitorAfter + return function(s: S, n: N, children: {Type}, t: Type): Type + t = fn and fn(s, n, children, t) or t + + if type(t) ~= "table" then + error(((n as Node).kind or (n as Type).typename) .. " did not produce a type") + end + if type(t.typename) ~= "string" then + error(((n as Node).kind or (n as Type).typename) .. " type does not have a typename") + end - clear_redundant_errors(errors) + return t + end + end - add_compat_entries(ast, all_needs_compat, env.gen_compat) + local function store_type_after(fn: VisitorAfter): VisitorAfter + return function(self: TypeChecker, n: N, children: {Type}, t: Type): Type + t = fn and fn(self, n, children, t) or t - local result = { - ast = ast, - env = env, - type = module_type or BOOLEAN, - filename = filename, - warnings = warnings, - type_errors = errors, - dependencies = dependencies, - } + local w = n as Where - env.loaded[filename] = result - table.insert(env.loaded_order, filename) + if w.y then + self.collector.store_type(w.y, w.x, t) + end - if tc then - env.reporter:store_result(tc, env.globals) + return t + end end - return result -end + local function debug_type_after(fn: VisitorAfter): VisitorAfter + return function(s: S, node: Node, children: {Type}, t: Type): Type + t = fn and fn(s, node, children, t) or t --------------------------------------------------------------------------------- --- Report types --------------------------------------------------------------------------------- + node.debug_type = t + return t + end + end -function tl.symbols_in_scope(tr: TypeReport, y: integer, x: integer): {string:integer} - local function find(symbols: {{integer, integer, string, integer}}, at_y: integer, at_x: integer): integer - local function le(a: {integer, integer}, b: {integer, integer}): boolean - return a[1] < b[1] - or (a[1] == b[1] and a[2] <= b[2]) + local function patch_visitors(my_visit_node: Visitor, + after_node: VisitorAfterPatcher, + my_visit_type?: Visitor, + after_type?: VisitorAfterPatcher): + Visitor, + Visitor + if my_visit_node == visit_node then + my_visit_node = shallow_copy_table(my_visit_node) end - return binary_search(symbols, {at_y, at_x}, le) or 0 + my_visit_node.after = after_node(my_visit_node.after) + if my_visit_type then + if my_visit_type == visit_type then + my_visit_type = shallow_copy_table(my_visit_type) + end + my_visit_type.after = after_type(my_visit_type.after) + else + my_visit_type = visit_type + end + return my_visit_node, my_visit_type end - local ret: {string:integer} = {} + local function set_feat(feat: Feat, default: boolean): boolean + if feat then + return (feat == "on") + else + return default + end + end - local n = find(tr.symbols, y, x) + tl.type_check = function(ast: Node, filename: string, opts: TypeCheckOptions, env?: Env): Result, string + assert(filename is string, "tl.type_check signature has changed, pass filename separately") + assert((not opts) or (not (opts as {any:any}).env), "tl.type_check signature has changed, pass env separately") - local symbols = tr.symbols - while n >= 1 do - local s = symbols[n] - if s[3] == "@{" then - n = n - 1 - elseif s[3] == "@}" then - n = s[4] + filename = filename or "?" + + opts = opts or {} + + if not env then + local err: string + env, err = tl.new_env({ defaults = opts }) + if err then + return nil, err + end + end + + local self: TypeChecker = { + filename = filename, + env = env, + st = { + { + vars = env.globals, + pending_global_types = {}, + }, + }, + errs = Errors.new(filename), + all_needs_compat = {}, + dependencies = {}, + subtype_relations = TypeChecker.subtype_relations, + eqtype_relations = TypeChecker.eqtype_relations, + type_priorities = TypeChecker.type_priorities, + } + + setmetatable(self, { __index = TypeChecker }) + + self.feat_lax = set_feat(opts.feat_lax or env.defaults.feat_lax, false) + self.feat_arity = set_feat(opts.feat_arity or env.defaults.feat_arity, true) + self.gen_compat = opts.gen_compat or env.defaults.gen_compat or DEFAULT_GEN_COMPAT + self.gen_target = opts.gen_target or env.defaults.gen_target or DEFAULT_GEN_TARGET + + if self.gen_target == "5.4" and self.gen_compat ~= "off" then + return nil, "gen-compat must be explicitly 'off' when gen-target is '5.4'" + end + + if self.feat_lax then + self.type_priorities = shallow_copy_table(self.type_priorities) + self.type_priorities["unknown"] = 0 + + self.subtype_relations = shallow_copy_table(self.subtype_relations) + + self.subtype_relations["unknown"] = {} + self.subtype_relations["unknown"]["*"] = compare_true + + self.subtype_relations["*"] = shallow_copy_table(self.subtype_relations["*"]) + self.subtype_relations["*"]["unknown"] = compare_true + -- in .lua files, all values can be used in a boolean context + self.subtype_relations["*"]["boolean"] = compare_true + + self.get_rets = function(rets: TupleType): TupleType + if #rets.tuple == 0 then + return a_vararg(rets, { an_unknown(rets) }) + end + return rets + end else - ret[s[3]] = s[4] - n = n - 1 + self.get_rets = function(rets: TupleType): TupleType + return rets + end end - end - return ret + if env.report_types then + env.reporter = env.reporter or tl.new_type_reporter() + self.collector = env.reporter:get_collector(filename) + end + + local visit_node, visit_type = visit_node, visit_type + if opts.run_internal_compiler_checks then + visit_node, visit_type = patch_visitors( + visit_node, internal_compiler_check, + visit_type, internal_compiler_check + ) + end + if self.collector then + visit_node, visit_type = patch_visitors( + visit_node, store_type_after, + visit_type, store_type_after + ) + end + if TL_DEBUG then + visit_node, visit_type = patch_visitors( + visit_node, debug_type_after + ) + end + + assert(ast.kind == "statements") + recurse_node(self, ast, visit_node, visit_type) + + local global_scope = self.st[1] + close_types(global_scope) + self.errs:warn_unused_vars(global_scope, true) + + clear_redundant_errors(self.errs.errors) + + add_compat_entries(ast, self.all_needs_compat, self.gen_compat) + + local result = { + ast = ast, + env = env, + type = self.module_type or a_type(ast, "boolean", {}), + filename = filename, + warnings = self.errs.warnings, + type_errors = self.errs.errors, + dependencies = self.dependencies, + } + + env.loaded[filename] = result + table.insert(env.loaded_order, filename or "") + + if self.collector then + env.reporter:store_result(self.collector, env.globals) + end + + return result + end end -------------------------------------------------------------------------------- @@ -12297,9 +12414,24 @@ local function read_full_file(fd: FILE): string, string return content, err end -tl.process = function(filename: string, env: Env, fd?: FILE): Result, string - assert((not fd or type(fd) ~= "string"), "fd must be a file") +local function feat_lax_heuristic(filename?: string, input?: string): Feat + if filename then + local _, extension = filename:match("(.*)%.([a-z]+)$") + extension = extension and extension:lower() + + if extension == "tl" then + return "off" + elseif extension == "lua" then + return "on" + end + end + if input then + return (input:match("^#![^\n]*lua[^\n]*\n")) and "on" or "off" + end + return "off" +end +tl.process = function(filename: string, env: Env, fd?: FILE): Result, string if env and env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12319,23 +12451,38 @@ tl.process = function(filename: string, env: Env, fd?: FILE): Result, string return nil, "could not read " .. filename .. ": " .. err end - local _, extension = filename:match("(.*)%.([a-z]+)$") - extension = extension and extension:lower() + return tl.process_string(input, env, filename) +end - local is_lua: boolean - if extension == "tl" then - is_lua = false - elseif extension == "lua" then - is_lua = true - else - is_lua = input:match("^#![^\n]*lua[^\n]*\n") as boolean +function tl.target_from_lua_version(str: string): GenTarget + if str == "Lua 5.1" + or str == "Lua 5.2" then + return "5.1" + elseif str == "Lua 5.3" then + return "5.3" + elseif str == "Lua 5.4" then + return "5.4" end +end - return tl.process_string(input, is_lua, env, filename) +local function default_env_opts(runtime: boolean, filename?: string, input?: string): EnvOptions + local gen_target = runtime and tl.target_from_lua_version(_VERSION) or DEFAULT_GEN_TARGET + local gen_compat: GenCompat = (gen_target == "5.4") and "off" or DEFAULT_GEN_COMPAT + return { + defaults = { + feat_lax = feat_lax_heuristic(filename, input), + gen_target = gen_target, + gen_compat = gen_compat, + run_internal_compiler_checks = false, + } + } end -function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: string): Result - env = env or tl.init_env(is_lua) +function tl.process_string(input: string, env?: Env, filename?: string): Result + assert(type(env) ~= "boolean", "tl.process_string signature has changed") + + env = env or tl.new_env(default_env_opts(false, filename, input)) + if env.loaded and env.loaded[filename] then return env.loaded[filename] end @@ -12347,7 +12494,7 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: local result = { ok = false, filename = filename, - type = BOOLEAN, + type = a_type({ f = filename, y = 1, x = 1 }, "boolean", {}), type_errors = {}, syntax_errors = syntax_errors, env = env, @@ -12357,14 +12504,7 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: return result end - local opts: TypeCheckOptions = { - filename = filename, - lax = is_lua, - gen_compat = env.gen_compat, - gen_target = env.gen_target, - env = env, - } - local result = tl.type_check(program, opts) + local result = tl.type_check(program, filename, env.defaults, env) result.syntax_errors = syntax_errors @@ -12372,15 +12512,15 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename?: end tl.gen = function(input: string, env: Env, pp: PrettyPrintOptions): string, Result - env = env or assert(tl.init_env(), "Default environment initialization failed") - local result = tl.process_string(input, false, env) + env = env or assert(tl.new_env(default_env_opts(false, nil, input)), "Default environment initialization failed") + local result = tl.process_string(input, env) if (not result.ast) or #result.syntax_errors > 0 then return nil, result end local code: string - code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target, pp) + code, result.gen_error = tl.pretty_print_ast(result.ast, env.defaults.gen_target, pp) return code, result end @@ -12396,28 +12536,25 @@ local function tl_package_loader(module_name: string): any, any if #errs > 0 then error(found_filename .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg) end - local lax = not not found_filename:match("lua$") local env = tl.package_loader_env if not env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = assert(tl.new_env(), "Default environment initialization failed") env = tl.package_loader_env end - env.modules[module_name] = a_typedecl(CIRCULAR_REQUIRE) + local opts = default_env_opts(true, found_filename) - local result = tl.type_check(program, { - lax = lax, - filename = found_filename, - env = env, - run_internal_compiler_checks = false, - }) + local w = { f = found_filename, x = 1, y = 1 } + env.modules[module_name] = a_typedecl(w, a_type(w, "circular_require", {})) + + local result = tl.type_check(program, found_filename, opts.defaults, env) env.modules[module_name] = result.type -- TODO: should this be a hard error? this seems analogous to -- finding a lua file with a syntax error in it - local code = assert(tl.pretty_print_ast(program, env.gen_target, true)) + local code = assert(tl.pretty_print_ast(program, opts.defaults.gen_target, true)) local chunk, err = load(code, "@" .. found_filename, "t") if chunk then return function(modname: string, loader_data: string): any @@ -12443,21 +12580,10 @@ function tl.loader() end end -function tl.target_from_lua_version(str: string): TargetMode - if str == "Lua 5.1" - or str == "Lua 5.2" then - return "5.1" - elseif str == "Lua 5.3" then - return "5.3" - elseif str == "Lua 5.4" then - return "5.4" - end -end - -local function env_for(lax: boolean, env_tbl: {any:any}): Env +local function env_for(opts: EnvOptions, env_tbl: {any:any}): Env if not env_tbl then if not tl.package_loader_env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = tl.new_env(opts) end return tl.package_loader_env end @@ -12466,7 +12592,7 @@ local function env_for(lax: boolean, env_tbl: {any:any}): Env tl.load_envs = setmetatable({}, { __mode = "k" }) end - tl.load_envs[env_tbl] = tl.load_envs[env_tbl] or tl.init_env(lax) + tl.load_envs[env_tbl] = tl.load_envs[env_tbl] or tl.new_env(opts) return tl.load_envs[env_tbl] end @@ -12476,17 +12602,14 @@ tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:a return nil, (chunkname or "") .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg end - local lax = chunkname and not not chunkname:match("lua$") + local opts = default_env_opts(true, chunkname) + if not tl.package_loader_env then - tl.package_loader_env = tl.init_env(lax) + tl.package_loader_env = tl.new_env(opts) end - local result = tl.type_check(program, { - lax = lax, - filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\""), - env = env_for(lax, ...), - run_internal_compiler_checks = false, - }) + local filename = chunkname or ("string \"" .. input:sub(45) .. (#input > 45 and "..." or "") .. "\"") + local result = tl.type_check(program, filename, opts.defaults, env_for(opts, ...)) if mode and mode:match("c") then if #result.type_errors > 0 then @@ -12500,7 +12623,7 @@ tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:a mode = mode:gsub("c", "") as LoadMode end - local code, err = tl.pretty_print_ast(program, tl.target_from_lua_version(_VERSION), true) + local code, err = tl.pretty_print_ast(program, opts.defaults.gen_target, true) if not code then return nil, err end @@ -12508,4 +12631,29 @@ tl.load = function(input: string, chunkname: string, mode: LoadMode, ...: {any:a return load(code, chunkname, mode, ...) end +-------------------------------------------------------------------------------- +-- Backwards compatibility +-------------------------------------------------------------------------------- + +function tl.get_types(result: Result): TypeReport, TypeReporter + return result.env.reporter:get_report(), result.env.reporter +end + +tl.init_env = function(lax?: boolean, gen_compat?: boolean | GenCompat, gen_target?: GenTarget, predefined?: {string}): Env, string + local opts = { + defaults = { + feat_lax = (lax and "on" or "off") as Feat, + gen_compat = ((gen_compat is GenCompat) and gen_compat) or + (gen_compat == false and "off") or + (gen_compat == true or gen_compat == nil) and "optional", + gen_target = gen_target or + ((_VERSION == "Lua 5.1" or _VERSION == "Lua 5.2") and "5.1") or + "5.3", + }, + predefined_modules = predefined, + } + + return tl.new_env(opts) +end + return tl