From b286e25ff151b6ae0ca5828509469c937e670402 Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Wed, 18 Dec 2024 13:35:02 -0300 Subject: [PATCH] generalize internal representation of generic types All types that have type variables are now represented as a GenericType record, which holds a non-generic Type and an array of type arguments. This change is because originally we only cared about generic records and generic functions, but once we have the `local type MyType = ...` syntax, other types can be generic as well (in particular, unions). Instead of replicating generic support logic in the implementation of each type, we factor it out into a type-level term which encapsulates the application of type variables, which is something more like second-order lambda calculus. See https://en.wikipedia.org/wiki/System_F and the ensuing rabbit hole. --- spec/cli/types_spec.lua | 2 +- spec/lang/call/generic_function_spec.lua | 38 +- spec/lang/code_gen/local_type_spec.lua | 34 ++ tl.lua | 711 +++++++++++++-------- tl.tl | 745 ++++++++++++++--------- 5 files changed, 987 insertions(+), 543 deletions(-) diff --git a/spec/cli/types_spec.lua b/spec/cli/types_spec.lua index 0b38170e6..15c41cfc9 100644 --- a/spec/cli/types_spec.lua +++ b/spec/cli/types_spec.lua @@ -377,7 +377,7 @@ describe("tl types works like check", function() assert.same({ ["17"] = 7, ["20"] = 2, - ["25"] = 10, + ["25"] = 17, ["31"] = 9, }, by_pos["2"]) end) diff --git a/spec/lang/call/generic_function_spec.lua b/spec/lang/call/generic_function_spec.lua index d512c307c..2b72f7e90 100644 --- a/spec/lang/call/generic_function_spec.lua +++ b/spec/lang/call/generic_function_spec.lua @@ -424,6 +424,19 @@ describe("generic function", function() { y = 6, msg = "cannot infer declaration type; an explicit type annotation is necessary" }, })) + it("can use record type variable in record function", util.check_type_error([[ + local record Container + try_resolve: function(Container):T + end + + function Container:resolve():U + local t = self:try_resolve() + return t + end + ]], { + { y = 7, msg = "got T, expected U" }, + })) + it("works when an annotation is given", util.check([[ local record Container try_resolve: function(Container):T @@ -514,13 +527,36 @@ describe("generic function", function() return #s end - local pok1, pok2, msg = pcall2(pcall1, greet, "hello") + local pok1, pok2, msg: boolean, boolean, number = pcall2(pcall1, greet, "hello") print(pok1) print(pok2) print(msg) ]])) + it("nested uses of generic functions using the same names for type variables don't cause conflicts", util.check_type_error([[ + local function pcall1(f: function(A):(B), a: A): boolean, B + return true, f(a) + end + + local function pcall2(f: function(A, A2):(B, B2), a: A, a2: A2): boolean, B, B2 + return true, f(a, a2) + end + + local function greet(s: string): number + print(s .. "!") + return #s + end + + local pok1, pok2, msg: boolean, boolean, string = pcall2(pcall1, greet, "hello") + + print(pok1) + print(pok2) + print(msg) + ]], { + { y = 14, msg = "argument 2: return 1: got number, expected string" } + })) + it("nested uses of generic record functions using the same names for type variables don't cause conflicts (#560)", util.check([[ local M = {} diff --git a/spec/lang/code_gen/local_type_spec.lua b/spec/lang/code_gen/local_type_spec.lua index 86036871f..cfbf0b48f 100644 --- a/spec/lang/code_gen/local_type_spec.lua +++ b/spec/lang/code_gen/local_type_spec.lua @@ -135,6 +135,40 @@ describe("local type code generation", function() local lunchbox = L2.new({ "apple", "peach" }) ]])) + it("alias for a type that shouldn't be elided, with function generics", util.gen([[ + local type List2 = record + new: function(initialItems: {T}, u: U): List2 + end + + function List2.new(initialItems: {T}, u: Y): List2 + end + + local type Fruit2 = enum + "apple" + "peach" + "banana" + end + + local type L2 = List2 + local lunchbox = L2.new({"apple", "peach"}, true) + ]], [[ + local List2 = {} + + + + function List2.new(initialItems, u) + end + + + + + + + + local L2 = List2 + local lunchbox = L2.new({ "apple", "peach" }, true) + ]])) + it("if alias shouldn't be elided, type shouldn't be elided either", util.gen([[ local type List = record new: function(initialItems: {T}): List diff --git a/tl.lua b/tl.lua index e58f55b80..f5e95630b 100644 --- a/tl.lua +++ b/tl.lua @@ -1623,6 +1623,7 @@ end + local table_types = { @@ -1634,6 +1635,7 @@ local table_types = { ["emptytable"] = true, ["tupletable"] = true, + ["generic"] = false, ["typedecl"] = false, ["typevar"] = false, ["typearg"] = false, @@ -1997,6 +1999,8 @@ end + + @@ -2379,7 +2383,6 @@ do local parse_argument_type_list local parse_type local parse_type_declaration - local parse_newtype local parse_interface_name @@ -2446,6 +2449,13 @@ do return t end + local function new_generic(ps, i, typeargs, typ) + local gt = new_type(ps, i, "generic") + gt.typeargs = typeargs + gt.t = typ + return gt + end + local function new_typedecl(ps, i, def) local t = new_type(ps, i, "typedecl") t.def = def @@ -2500,12 +2510,6 @@ do def = new_type(ps, istart, tn) - if typeargs then - if def.typename == "record" or def.typename == "interface" then - def.typeargs = typeargs - end - end - local ok i, ok = parse_type_body_fns[tn](ps, i, def) if not ok then @@ -2514,6 +2518,10 @@ do i = verify_end(ps, i, istart, node) + if typeargs then + def = new_generic(ps, istart, typeargs, def) + end + return i, def end @@ -2761,10 +2769,11 @@ do end local function parse_function_type(ps, i) + local typeargs local typ = new_type(ps, i, "function") i = i + 1 - i, typ.typeargs = parse_typeargs_if_any(ps, i) + i, typeargs = parse_typeargs_if_any(ps, i) if ps.tokens[i].tk == "(" then i, typ.args, typ.maybe_method, typ.min_arity = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) @@ -2774,6 +2783,11 @@ do typ.is_method = false typ.min_arity = 0 end + + if typeargs then + return i, new_generic(ps, i, typeargs, typ) + end + return i, typ end @@ -3720,9 +3734,11 @@ do end local oldt = fields[field_name] + local oldf = oldt.typename == "generic" and oldt.t or oldt + local newf = newt.typename == "generic" and newt.t or newt - if newt.typename == "function" then - if oldt.typename == "function" then + if newf.typename == "function" then + if oldf.typename == "function" then local p = new_type(ps, i, "poly") p.types = { oldt, newt } fields[field_name] = p @@ -3737,6 +3753,10 @@ do end local function set_declname(def, declname) + if def.typename == "generic" then + def = def.t + end + if def.typename == "record" or def.typename == "interface" or def.typename == "enum" then if not def.declname then def.declname = declname @@ -3879,15 +3899,16 @@ do return i, t end - local function clone_typeargs(ps, i, typeargs) - local copy = {} - for a, ta in ipairs(typeargs) do - local cta = new_type(ps, i, "typearg") - cta.typearg = ta.typearg - copy[a] = cta - end - return copy - end + + + + + + + + + + local function extract_userdata_from_interface_list(ps, i, def) for j = #def.interface_list, 1, -1 do @@ -3945,9 +3966,7 @@ do i, where_macroexp = parse_where_clause(ps, i, def) local typ = new_type(ps, wstart, "function") - if def.typeargs then - typ.typeargs = clone_typeargs(ps, i, def.typeargs) - end + typ.is_method = true typ.min_arity = 1 typ.args = new_tuple(ps, wstart, { @@ -4080,7 +4099,7 @@ do ["enum"] = parse_enum_body, } - parse_newtype = function(ps, i) + local function parse_newtype(ps, i) local node = new_node(ps, i, "newtype") local def local tn = ps.tokens[i].tk @@ -4099,6 +4118,11 @@ do if def.typename == "nominal" then node.newtype.is_alias = true + elseif def.typename == "generic" then + local deft = def.t + if deft.typename == "nominal" then + node.newtype.is_alias = true + end end return i, node @@ -4241,9 +4265,8 @@ do end local typeargs local itypeargs = i - if ps.tokens[i].tk == "<" then - i, typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end + i, typeargs = parse_typeargs_if_any(ps, i) + asgn.var = var if node_name == "global_type" and ps.tokens[i].tk ~= "=" then @@ -4251,6 +4274,7 @@ do end i = verify_tk(ps, i, "=") + local istart = i if ps.tokens[i].kind == "identifier" then local done @@ -4267,23 +4291,16 @@ do local nt = asgn.value.newtype if nt.typename == "typedecl" then - local def = nt.def - if typeargs then - if def.typeargs then - if def.typeargs then - fail(ps, itypeargs, "cannot declare type arguments twice in type declaration") - else - def.typeargs = typeargs - end + local def = nt.def + if def.typename == "generic" then + fail(ps, itypeargs, "cannot declare type arguments twice in type declaration") else - - - nt.typeargs = typeargs + nt.def = new_generic(ps, istart, typeargs, def) end end - set_declname(def, asgn.var.tk) + set_declname(nt.def, asgn.var.tk) end return i, asgn @@ -4618,15 +4635,12 @@ local function recurse_type(s, ast, visit) local xs = {} - if ast.typeargs then - if ast.typeargs then - for _, child in ipairs(ast.typeargs) do - table.insert(xs, recurse_type(s, child, visit)) - end + if ast.typename == "generic" then + for _, child in ipairs(ast.typeargs) do + table.insert(xs, recurse_type(s, child, visit)) end - end - - if ast.typename == "tuple" then + table.insert(xs, recurse_type(s, ast.t, visit)) + elseif ast.typename == "tuple" then for i, child in ipairs(ast.tuple) do xs[i] = recurse_type(s, child, visit) end @@ -5805,6 +5819,7 @@ local typename_to_typecode = { ["tuple"] = tl.typecodes.UNKNOWN, ["literal_table_item"] = tl.typecodes.UNKNOWN, ["typedecl"] = tl.typecodes.UNKNOWN, + ["generic"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, } @@ -5906,6 +5921,11 @@ function TypeReporter:get_typenum(t) return self:get_typenum(rt.def) end + + if rt.typename == "generic" then + rt = rt.t + end + local ti = { t = assert(typename_to_typecode[rt.typename]), str = show_type(t, true), @@ -6366,19 +6386,26 @@ function Errors:add_prefixing(w, src, prefix, dst) end end +local function ensure_not_abstract_type(def, node) + if def.typename == "record" then + return true + elseif def.typename == "generic" then + return ensure_not_abstract_type(def.t) + elseif node and node_is_require_call(node) then + return nil, "module type is abstract: " .. tostring(def) + elseif def.typename == "interface" then + return nil, "interfaces are abstract; consider using a concrete record" + end + return nil, "cannot use a type definition as a concrete value" +end + local function ensure_not_abstract(t, node) if t.typename == "function" and t.macroexp then return nil, "macroexps are abstract; consider using a concrete function" + elseif t.typename == "generic" then + return ensure_not_abstract(t.t, node) elseif t.typename == "typedecl" then - local def = t.def - if def.typename == "record" then - return true - elseif node and node_is_require_call(node) then - return nil, "module type is abstract: " .. tostring(t.def) - elseif def.typename == "interface" then - return nil, "interfaces are abstract; consider using a concrete record" - end - return nil, "cannot use a type definition as a concrete value" + return ensure_not_abstract_type(t.def, node) end return true end @@ -6774,27 +6801,11 @@ local function display_typevar(typevar, what) end local function show_fields(t, show) - if t.declname and not t.typeargs then + if t.declname then return " " .. t.declname end local out = {} - if t.declname and not t.typeargs then - table.insert(out, " " .. t.declname) - end - if t.typeargs then - table.insert(out, "<") - local typeargs = {} - for _, v in ipairs(t.typeargs) do - table.insert(typeargs, show(v)) - end - table.insert(out, table.concat(typeargs, ", ")) - table.insert(out, ">") - end - if t.declname then - return table.concat(out) - end - table.insert(out, " (") if t.elements then table.insert(out, "{" .. show(t.elements) .. "}") @@ -6886,17 +6897,7 @@ local function show_type_base(t, short, seen) elseif t.fields then return short and (t.declname or t.typename) or t.typename .. show_fields(t, show) elseif t.typename == "function" then - local out = { "function" } - if t.typeargs then - table.insert(out, "<") - local typeargs = {} - for _, v in ipairs(t.typeargs) do - table.insert(typeargs, show(v)) - end - table.insert(out, table.concat(typeargs, ", ")) - table.insert(out, ">") - end - table.insert(out, "(") + local out = { "function(" } local args = {} for i, v in ipairs(t.args.tuple) do table.insert(args, ((i == #t.args.tuple and t.args.is_va) and "...: " or @@ -6920,6 +6921,26 @@ local function show_type_base(t, short, seen) end end return table.concat(out) + elseif t.typename == "generic" then + local out = {} + local name, rest + local tt = t.t + if tt.typename == "record" or tt.typename == "interface" or tt.typename == "function" then + name, rest = show(tt):match("^(%a+)(.*)") + table.insert(out, name) + else + rest = " " .. show(tt) + table.insert(out, "generic") + end + table.insert(out, "<") + local typeargs = {} + for _, v in ipairs(t.typeargs) do + table.insert(typeargs, show(v)) + end + table.insert(out, table.concat(typeargs, ", ")) + table.insert(out, ">") + table.insert(out, rest) + return table.concat(out) elseif t.typename == "number" or t.typename == "integer" or t.typename == "boolean" or @@ -7350,34 +7371,45 @@ do end - local typevar_resolver + local map_typevars local function fresh_typevar(_, t) - return a_type(t, "typevar", { - typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, - constraint = t.constraint, - }) + if t.typename == "typevar" then + return a_type(t, "typevar", { + typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, + constraint = t.constraint, + }), false + else + return a_type(t, "typearg", { + typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, + constraint = t.constraint, + }), true + end end - local function fresh_typearg(_, t) - return a_type(t, "typearg", { - typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, - constraint = t.constraint, - }) + local function fresh_typeargs(self, g) + fresh_typevar_ctr = fresh_typevar_ctr + 1 + + local newg, errs = map_typevars(nil, g, fresh_typevar) + if newg.typename == "invalid" then + self.errs:collect(errs) + return g + end + assert(newg.typename == "generic", "Internal Compiler Error: error creating fresh type variables") + assert(newg ~= g) + newg.fresh = true + + return newg end - function TypeChecker:ensure_fresh_typeargs(t) - if not t.typeargs then + local function wrap_generic_if_typeargs(typeargs, t) + if not typeargs then return t end - fresh_typevar_ctr = fresh_typevar_ctr + 1 - local ok, errs - ok, t, errs = typevar_resolver(nil, t, fresh_typevar, fresh_typearg) - if not ok then - self.errs:collect(errs) - end - return t + local gt = a_type(t, "generic", { t = t }) + gt.typeargs = typeargs + return gt end function TypeChecker:find_var_type(name, use) @@ -7387,13 +7419,26 @@ do if t.typename == "unresolved_typearg" then return nil, nil, t.constraint end - t = self:ensure_fresh_typeargs(t) + + if t.typename == "generic" then + t = fresh_typeargs(self, t) + end + return t, var.attribute end end local function ensure_not_method(t) + if t.typename == "generic" then + local tt = ensure_not_method(t.t) + if tt ~= t.t then + local gg = shallow_copy_new_type(t) + gg.t = tt + return gg + end + end + if t.typename == "function" and t.is_method then t = shallow_copy_new_type(t) t.is_method = false @@ -7401,6 +7446,17 @@ do return t end + local function unwrap_for_find_type(typ) + if typ.typename == "nominal" and typ.found then + return unwrap_for_find_type(typ.found) + elseif typ.typename == "typedecl" then + return unwrap_for_find_type(typ.def) + elseif typ.typename == "generic" then + return unwrap_for_find_type(typ.t) + end + return typ + end + function TypeChecker:find_type(names) local typ = self:find_var_type(names[1], "use_type") if not typ then @@ -7409,13 +7465,8 @@ do end return nil end - if typ.typename == "nominal" and typ.found then - typ = typ.found - end for i = 2, #names do - if typ.typename == "typedecl" then - typ = typ.def - end + typ = unwrap_for_find_type(typ) local fields = typ.fields and typ.fields if not fields then @@ -7423,15 +7474,14 @@ do end typ = fields[names[i]] + if typ and typ.typename == "nominal" then + typ = typ.found + end if typ == nil then return nil end - - typ = self:ensure_fresh_typeargs(typ) - if typ.typename == "nominal" and typ.found then - typ = typ.found - end end + if typ.typename == "typedecl" then return typ elseif typ.typename == "typearg" then @@ -7441,7 +7491,7 @@ do local function type_for_union(t) if t.typename == "typedecl" then - return type_for_union(t.def), t.def + return type_for_union(t.def) elseif t.typename == "tuple" then return type_for_union(t.tuple[1]), t.tuple[1] elseif t.typename == "nominal" then @@ -7455,6 +7505,8 @@ do return "userdata", t end return "table", t + elseif t.typename == "generic" then + return type_for_union(t.t) elseif table_types[t.typename] then return "table", t else @@ -7568,10 +7620,6 @@ do } local function clear_resolved_typeargs(copy, resolved) - if not copy.typeargs then - return - end - for i = #copy.typeargs, 1, -1 do local r = resolved[copy.typeargs[i].typearg] if r then @@ -7579,26 +7627,17 @@ do end end if not copy.typeargs[1] then - copy.typeargs = nil + return copy.t end - - return + return copy end - typevar_resolver = function(self, typ, fn_var, fn_arg) + map_typevars = function(self, ty, fn_tv) local errs local seen = {} local resolved = {} local resolve - local function copy_typeargs(t, same) - local copy = {} - for i, tf in ipairs(t) do - copy[i], same = resolve(tf, same) - end - return copy, same - end - resolve = function(t, all_same) local same = true @@ -7613,9 +7652,11 @@ do local orig_t = t if t.typename == "typevar" then - local rt = fn_var(self, t) + local rt, is_resolved = fn_tv(self, t) if rt then - resolved[t.typevar] = rt + if is_resolved then + resolved[t.typevar] = rt + end if no_nested_types[rt.typename] or (rt.typename == "nominal" and not rt.typevals) then seen[orig_t] = rt return rt, false @@ -7634,18 +7675,26 @@ do copy.x = t.x copy.y = t.y - if t.typeargs then - (copy).typeargs, same = copy_typeargs(t.typeargs, same) - end + if t.typename == "generic" then + assert(copy.typename == "generic") - if t.typename == "array" then + local ct = {} + for i, tf in ipairs(t.typeargs) do + ct[i], same = resolve(tf, same) + end + copy.typeargs = ct + copy.t, same = resolve(t.t, same) + elseif t.typename == "array" then assert(copy.typename == "array") copy.elements, same = resolve(t.elements, same) elseif t.typename == "typearg" then - if fn_arg then - copy = fn_arg(self, t) + local rt, is_resolved = fn_tv(self, t) + if is_resolved then + resolved[t.typearg] = rt + copy = rt + same = false else assert(copy.typename == "typearg") copy.typearg = t.typearg @@ -7653,12 +7702,6 @@ do copy.constraint, same = resolve(t.constraint, same) end end - - - local rt = fn_var(self, a_type(t, "typevar", { typevar = t.typearg })) - if rt then - resolved[t.typearg] = rt - end elseif t.typename == "unresolvable_typearg" then assert(copy.typename == "unresolvable_typearg") copy.typearg = t.typearg @@ -7764,23 +7807,29 @@ do return copy, same and all_same end - local copy = resolve(typ, true) + local copy = resolve(ty, true) if errs then - return false, a_type(typ, "invalid", {}), errs + return a_type(ty, "invalid", {}), errs end - clear_resolved_typeargs(copy, resolved) + if copy.typename == "generic" then + copy = clear_resolved_typeargs(copy, resolved) + end - return true, copy, nil + return copy end local function resolve_typevar(tc, t) - local rt = tc:find_var_type(t.typevar) - if not rt then - return nil - end + if t.typename == "typevar" then + local rt = tc:find_var_type(t.typevar) + if not rt then + return nil + end - return drop_constant_value(rt) + return drop_constant_value(rt), true + else + return t, false + end end function TypeChecker:infer_emptytable(emptytable, fresh_t) @@ -7828,8 +7877,8 @@ do function TypeChecker:resolve_typevars_at(w, t) assert(w) - local ok, ret, errs = typevar_resolver(self, t, resolve_typevar) - if not ok then + local ret, errs = map_typevars(self, t, resolve_typevar) + if errs then assert(w.y) self.errs:add_prefixing(w, errs, "") end @@ -8022,6 +8071,43 @@ do return NONE end + local function unresolved_typeargs_for(g) + local ts = {} + for _, ta in ipairs(g.typeargs) do + table.insert(ts, a_type(ta, "unresolved_typearg", { + constraint = ta.constraint, + })) + end + return ts + end + + function TypeChecker:apply_generic(w, g, typeargs) + if not g.fresh then + g = fresh_typeargs(self, g) + end + + if not typeargs then + typeargs = unresolved_typeargs_for(g) + end + + assert(#g.typeargs == #typeargs) + + for i, ta in ipairs(g.typeargs) do + self:add_var(nil, ta.typearg, typeargs[i]) + end + local applied, errs = map_typevars(self, g, resolve_typevar) + if errs then + self.errs:add_prefixing(w, errs, "") + return nil + end + + if applied.typename == "generic" then + return applied.t + else + return applied + end + end + do @@ -8048,12 +8134,7 @@ do end local function match_typevals(self, t, def) - if not t.typevals and not def.typeargs then - return def - elseif not def.typeargs then - self.errs:add(t, "unexpected type argument", t) - return nil - elseif not t.typevals then + if not t.typevals then self.errs:add(t, "missing type arguments in %s", def) return nil elseif #t.typevals ~= #def.typeargs then @@ -8063,10 +8144,7 @@ do self:begin_scope() - for i, tt in ipairs(t.typevals) do - self:add_var(nil, def.typeargs[i].typearg, tt) - end - local ret = self:resolve_typevars_at(t, def) + local ret = self:apply_generic(t, def, t.typevals) if def == self.cache_std_metatable_type then check_metatable_contract(self, t.typevals[1], ret) end @@ -8088,8 +8166,10 @@ do if found.typename == "typedecl" and found.is_alias then local def = found.def - assert(def.typename == "nominal") - found = def.found + if def.typename == "nominal" then + found = def.found + end + end if not found then @@ -8117,14 +8197,16 @@ do return nil, found end - local function resolve_decl_into_nominal(self, t, found) + local function resolve_decl_in_nominal(self, t, found) local def = found.def local resolved - if def.fields or def.typename == "function" then + if def.typename == "generic" then resolved = match_typevals(self, t, def) if not resolved then - return self.errs:invalid_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") + resolved = a_type(t, "invalid", {}) end + elseif t.typevals then + resolved = self.errs:invalid_at(t, "unexpected type argument") else resolved = def end @@ -8140,13 +8222,21 @@ do return immediate end - return resolve_decl_into_nominal(self, t, found) + return resolve_decl_in_nominal(self, t, found) end function TypeChecker:resolve_typealias(ta) + local def = ta.def + + local nom = def + if def.typename == "generic" then + nom = def.t + end + + if not (nom.typename == "nominal") then + return ta + end - local nom = ta.def - assert(nom.typename == "nominal") local immediate, found = find_nominal_type_decl(self, nom) @@ -8163,7 +8253,11 @@ do - local struc = resolve_decl_into_nominal(self, nom, found or nom.found) + local struc = resolve_decl_in_nominal(self, nom, found or nom.found) + + if def.typename == "generic" then + struc = wrap_generic_if_typeargs(def.typeargs, struc) + end local td = a_type(ta, "typedecl", { def = struc }) @@ -8259,6 +8353,10 @@ do end local t = typedecl.def + if t.typename == "generic" then + t = t.t + end + return t end @@ -8563,8 +8661,8 @@ do end end - local ok, r, errs = typevar_resolver(self, other, resolve_typevar) - if not ok then + local r, errs = map_typevars(self, other, resolve_typevar) + if errs then return false, errs end if r.typename == "typevar" and r.typevar == typevar then @@ -8739,6 +8837,19 @@ do ["boolean_context"] = { ["boolean"] = compare_true, }, + ["generic"] = { + ["generic"] = function(self, a, b) + if #a.typeargs ~= #b.typeargs then + return false + end + for i = 1, #a.typeargs do + if not self:same_type(a.typeargs[i], b.typeargs[i]) then + return false + end + end + return self:same_type(a.t, b.t) + end, + }, ["*"] = { ["boolean_context"] = compare_true, ["self"] = function(self, a, b) @@ -9068,6 +9179,16 @@ a.types[i], b.types[i]), } ["boolean_context"] = { ["boolean"] = compare_true, }, + ["generic"] = { + ["*"] = function(self, a, b) + + + local aa = self:apply_generic(a, a) + local ok, errs = self:is_a(aa, b) + + return ok, errs + end, + }, ["*"] = { ["any"] = compare_true, ["boolean_context"] = compare_true, @@ -9078,6 +9199,14 @@ a.types[i], b.types[i]), } infer_emptytable_from_unresolved_value(self, b, b, a) return true end, + ["generic"] = function(self, a, b) + + + local bb = self:apply_generic(b, b) + local ok, errs = self:is_a(a, bb) + + return ok, errs + end, ["self"] = function(self, a, b) return self:is_a(a, self:type_of_self(b)) end, @@ -9113,6 +9242,7 @@ a.types[i], b.types[i]), } TypeChecker.type_priorities = { + ["generic"] = -1, ["nil"] = 0, ["unresolved_emptytable_value"] = 1, ["emptytable"] = 2, @@ -9286,6 +9416,11 @@ a.types[i], b.types[i]), } end func = self:to_structural(func) + + if func.typename == "generic" then + func = self:apply_generic(func, func) + end + if func.typename ~= "function" and func.typename ~= "poly" then if func.typename == "union" then @@ -9405,17 +9540,15 @@ a.types[i], b.types[i]), } end do - local function mark_invalid_typeargs(self, f) - if f.typeargs then - for _, a in ipairs(f.typeargs) do - if not self:find_var_type(a.typearg) then - if a.constraint then - self:add_var(nil, a.typearg, a.constraint) - else - self:add_var(nil, a.typearg, self.feat_lax and a_type(a, "unknown", {}) or a_type(a, "unresolvable_typearg", { - typearg = a.typearg, - })) - end + local function mark_invalid_typeargs(self, gt) + for _, a in ipairs(gt.typeargs) do + if not self:find_var_type(a.typearg) then + if a.constraint then + self:add_var(nil, a.typearg, a.constraint) + else + self:add_var(nil, a.typearg, self.feat_lax and a_type(a, "unknown", {}) or a_type(a, "unresolvable_typearg", { + typearg = a.typearg, + })) end end end @@ -9518,16 +9651,6 @@ a.types[i], b.types[i]), } return false end - local function add_call_typeargs(self, func) - if func.typeargs then - for _, fnarg in ipairs(func.typeargs) do - self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { - constraint = fnarg.constraint, - })) - end - end - end - check_call = function(self, w, wargs, f, args, expected_rets, cm, argdelta) local arg1 = args.tuple[1] if cm == "method" and arg1 then @@ -9551,17 +9674,30 @@ a.types[i], b.types[i]), } return nil, { Err_at(w, "wrong number of arguments (given " .. given .. ", expects " .. show_arity(f) .. ")") } end - add_call_typeargs(self, f) - return check_args_rets(self, w, wargs, f, args, expected_rets, argdelta) end end + function TypeChecker:iterate_poly(p) + local i = 0 + return function() + i = i + 1 + local fg = p.types[i] + if not fg then + return + elseif fg.typename == "function" then + return i, fg + elseif fg.typename == "generic" then + return i, self:apply_generic(p, fg) + end + end + end + local check_poly_call do - local function fail_poly_call_arity(w, p, given) + local function fail_poly_call_arity(self, w, p, given) local expects = {} - for _, f in ipairs(p.types) do + for _, f in self:iterate_poly(p) do table.insert(expects, show_arity(f)) end table.sort(expects) @@ -9588,7 +9724,9 @@ a.types[i], b.types[i]), } local first_errs for pass = 1, 3 do - for i, f in ipairs(p.types) do + for i, f in self:iterate_poly(p) do + assert(f.typename == "function", f.typename) + assert(f.args) first_rets = first_rets or f.rets local wanted = #f.args.tuple @@ -9619,7 +9757,7 @@ a.types[i], b.types[i]), } end if not first_errs then - return nil, first_rets, fail_poly_call_arity(w, p, given) + return nil, first_rets, fail_poly_call_arity(self, w, p, given) end return nil, first_rets, first_errs @@ -9654,6 +9792,14 @@ a.types[i], b.types[i]), } expected_rets = a_type(node, "tuple", { tuple = { node.expected } }) end + self:begin_scope() + + local g + if func.typename == "generic" then + g = func + func = self:apply_generic(node, func) + end + local is_method = (argdelta == -1) if not (func.typename == "function" or func.typename == "poly") then @@ -9668,8 +9814,6 @@ a.types[i], b.types[i]), } local errs local f, ret - self:begin_scope() - if func.typename == "poly" then f, ret, errs = check_poly_call(self, node, e2, func, args, expected_rets, cm, argdelta) elseif func.typename == "function" then @@ -9684,8 +9828,8 @@ a.types[i], b.types[i]), } self.errs:collect(errs) end - if f then - mark_invalid_typeargs(self, f) + if g then + mark_invalid_typeargs(self, g) end ret = self:resolve_typevars_at(node, ret) @@ -9721,6 +9865,7 @@ a.types[i], b.types[i]), } if self.feat_lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then return a_type(node, "unknown", {}), nil end + local ameta = a.fields and a.meta_fields local bmeta = b and b.fields and b.meta_fields @@ -9821,6 +9966,10 @@ a.types[i], b.types[i]), } end end + if tbl.typename == "generic" then + tbl = self:apply_generic(tbl, tbl) + end + if tbl.typename == "union" then local t = self:same_in_all_union_entries(tbl, function(t) return (self:match_record_key(t, rec, key)) @@ -10004,12 +10153,11 @@ a.types[i], b.types[i]), } end function TypeChecker:add_function_definition_for_recursion(node, fnargs, feat_arity) - self:add_var(nil, node.name.tk, a_function(node, { + self:add_var(nil, node.name.tk, wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = feat_arity and node.min_arity or 0, - typeargs = node.typeargs, args = fnargs, rets = self.get_rets(node.rets), - })) + }))) end function TypeChecker:end_function_scope(node) @@ -10297,7 +10445,7 @@ a.types[i], b.types[i]), } local function typedecl_to_nominal(node, name, t, resolved) local typevals local def = t.def - if def.typeargs then + if def.typename == "generic" then typevals = {} for _, a in ipairs(def.typeargs) do table.insert(typevals, a_type(a, "typevar", { @@ -10923,7 +11071,7 @@ a.types[i], b.types[i]), } local typ typ = decls[i] if typ then - if node.kind == "assignment" and i == nexps and ndecl > nexps then + if i == nexps and ndecl > nexps and node_is_funcall(node.exps[i]) then typ = a_type(node, "tuple", { tuple = {} }) for a = i, ndecl do table.insert(typ.tuple, decls[a]) @@ -11214,6 +11362,13 @@ self:expand_type(node, values, elements) }) if def.typename == "nominal" then return (self:find_var(def.names[1], "use_type")) end + + if def.typename == "generic" then + local nom = def.t + if nom.typename == "nominal" then + return (self:find_var(nom.names[1], "use_type")) + end + end end local function recurse_type_declaration(self, n) @@ -11265,14 +11420,13 @@ self:expand_type(node, values, elements) }) local resolved, aliasing = recurse_type_declaration(self, value) local nt = value.newtype if nt and nt.is_alias and resolved.typename == "typedecl" then - if nt.typeargs then - local def = resolved.def + local ntdef = nt.def + local rdef = resolved.def + if ntdef.typename == "generic" and rdef.typename == "generic" then - if def.typename == "record" or def.typename == "function" or def.typename == "interface" then - def.typeargs = nt.typeargs - end + ntdef.typeargs = rdef.typeargs end end return resolved, aliasing @@ -11848,6 +12002,10 @@ self:expand_type(node, values, elements) }) decltype = resolve_typedecl(self:to_structural(decltype.constraint)) end + if decltype.typename == "generic" then + decltype = self:apply_generic(node, decltype) + end + if decltype.typename == "tupletable" then for _, child in ipairs(node) do local n = child.key.constnum @@ -11892,6 +12050,10 @@ self:expand_type(node, values, elements) }) decltype = self:to_structural(constraint) end + if decltype.typename == "generic" then + decltype = self:apply_generic(node, decltype) + end + if decltype.typename == "union" then local single_table_type local single_table_rt @@ -12046,9 +12208,8 @@ self:expand_type(node, values, elements) }) self:end_function_scope(node) - local t = self:ensure_fresh_typeargs(a_function(node, { + local t = wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = self.feat_arity and node.min_arity or 0, - typeargs = node.typeargs, args = args, rets = self.get_rets(rets), })) @@ -12075,9 +12236,8 @@ self:expand_type(node, values, elements) }) self:check_macroexp_arg_use(node.macrodef) - local t = self:ensure_fresh_typeargs(a_function(node, { + local t = wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = self.feat_arity and node.macrodef.min_arity or 0, - typeargs = node.typeargs, args = args, rets = self.get_rets(rets), macroexp = node.macrodef, @@ -12122,9 +12282,8 @@ self:expand_type(node, values, elements) }) return NONE end - self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { + self:add_global(node, node.name.tk, wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = self.feat_arity and node.min_arity or 0, - typeargs = node.typeargs, args = args, rets = self.get_rets(rets), }))) @@ -12141,7 +12300,7 @@ self:expand_type(node, values, elements) }) local rtype = self:to_structural(resolve_typedecl(children[1])) - if rtype.fields and rtype.typeargs then + if rtype.typename == "generic" then for _, typ in ipairs(rtype.typeargs) do self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, @@ -12159,6 +12318,10 @@ self:expand_type(node, values, elements) }) local t = children[1] local rtype = self:to_structural(resolve_typedecl(t)) + if rtype.typename == "generic" then + rtype = rtype.t + end + do local ok, err = ensure_not_abstract(t) if not ok then @@ -12196,10 +12359,9 @@ self:expand_type(node, values, elements) }) end end - local fn_type = self:ensure_fresh_typeargs(a_function(node, { + local fn_type = wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = self.feat_arity and node.min_arity or 0, is_method = node.is_method, - typeargs = node.typeargs, args = args, rets = self.get_rets(rets), is_record_function = true, @@ -12215,6 +12377,12 @@ self:expand_type(node, values, elements) }) self.errs:redeclaration_warning(node, node.name.tk, "function") end + if fn_type.typename == "generic" and not (rfieldtype.typename == "generic") then + self:begin_scope() + fn_type = self:apply_generic(node, fn_type) + self:end_scope() + end + local ok, err = self:same_type(fn_type, rfieldtype) if not ok then if rfieldtype.typename == "poly" then @@ -12228,7 +12396,15 @@ self:expand_type(node, values, elements) }) return end else + if open_t and open_t.typename == "generic" then + open_t = open_t.t + end if self.feat_lax or rtype == open_t then + + + + + rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) @@ -12283,9 +12459,9 @@ self:expand_type(node, values, elements) }) assert(rets.typename == "tuple") self:end_function_scope(node) - return self:ensure_fresh_typeargs(a_function(node, { + + return wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = self.feat_arity and node.min_arity or 0, - typeargs = node.typeargs, args = args, rets = self.get_rets(rets), })) @@ -12309,9 +12485,8 @@ self:expand_type(node, values, elements) }) assert(rets.typename == "tuple") self:end_function_scope(node) - return self:ensure_fresh_typeargs(a_function(node, { + return wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = self.feat_arity and node.min_arity or 0, - typeargs = node.typeargs, args = args, rets = rets, })) @@ -12353,6 +12528,9 @@ self:expand_type(node, values, elements) }) elseif node.op.op == "or" then self:apply_facts(node, facts_not(node, node.e1.known)) elseif node.op.op == "@funcall" then + if e1type.typename == "generic" then + e1type = self:apply_generic(node, e1type) + end if e1type.typename == "function" then local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0 if node.expected then @@ -12973,16 +13151,6 @@ self:expand_type(node, values, elements) }) end end - local visit_type_with_typeargs = { - before = function(self, _typ) - self:begin_scope() - end, - after = function(self, typ, _children) - self:end_scope() - return self:ensure_fresh_typeargs(typ) - end, - } - function TypeChecker:begin_temporary_record_types(typ) self:add_var(nil, "@self", a_type(typ, "typedecl", { def = typ })) @@ -13010,26 +13178,35 @@ self:expand_type(node, values, elements) }) end end - local function ensure_is_method_self(typ, fargs) - assert(typ.declname) - local selfarg = fargs[1] + local function ensure_is_method_self(typ, selfarg, g) if selfarg.typename == "self" then return true end if not (selfarg.typename == "nominal") then return false end - if selfarg.names[1] ~= typ.declname or (typ.typeargs and not selfarg.typevals) then + + if #selfarg.names ~= 1 or selfarg.names[1] ~= typ.declname then return false end - if typ.typeargs then - for j = 1, #typ.typeargs do + + if g then + if not selfarg.typevals then + return false + end + + if g.t.typeid ~= typ.typeid then + return false + end + + for j = 1, #g.typeargs do local tv = selfarg.typevals[j] - if not (tv and tv.typename == "typevar" and tv.typevar == typ.typeargs[j].typearg) then + if not (tv and tv.typename == "typevar" and tv.typevar == g.typeargs[j].typearg) then return false end end end + return true end @@ -13050,13 +13227,22 @@ self:expand_type(node, values, elements) }) local visit_type visit_type = { cbs = { + ["generic"] = { + before = function(self, typ) + self:begin_scope() + self:add_var(nil, "@generic", typ) + end, + after = function(self, typ, _children) + self:end_scope() + return fresh_typeargs(self, typ) + end, + }, ["function"] = { - before = visit_type_with_typeargs.before, - after = function(self, typ, children) + after = function(self, typ, _children) if self.feat_arity == false then typ.min_arity = 0 end - return visit_type_with_typeargs.after(self, typ, children) + return typ end, }, ["record"] = { @@ -13066,12 +13252,6 @@ self:expand_type(node, values, elements) }) end, after = function(self, typ, children) local i = 1 - if typ.typeargs then - for _, _ in ipairs(typ.typeargs) do - typ.typeargs[i] = children[i] - i = i + 1 - end - end if typ.interface_list then for j, _ in ipairs(typ.interface_list) do local iface = children[i] @@ -13093,6 +13273,7 @@ self:expand_type(node, values, elements) }) i = i + 1 end local fmacros + local g for name, _ in fields_of(typ) do local ftype = children[i] if ftype.typename == "function" then @@ -13104,7 +13285,10 @@ self:expand_type(node, values, elements) }) if ftype.is_method then local fargs = ftype.args.tuple if fargs[1] then - ftype.is_method = ensure_is_method_self(typ, fargs) + if not g then + g = self:find_var("@generic") + end + ftype.is_method = ensure_is_method_self(typ, fargs[1], g and g.t) if ftype.is_method then fargs[1] = a_type(fargs[1], "self", { display_type = typ }) end @@ -13194,8 +13378,12 @@ self:expand_type(node, values, elements) }) if t then local def = t.def if t.is_alias then - assert(def.typename == "nominal") - typ.found = def.found + if def.typename == "generic" then + def = def.t + end + if def.typename == "nominal" then + typ.found = def.found + end elseif def.typename ~= "circular_require" then typ.found = t end @@ -13236,8 +13424,7 @@ self:expand_type(node, values, elements) }) visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["typedecl"] = visit_type_with_typeargs - + visit_type.cbs["typedecl"] = default_type_visitor visit_type.cbs["self"] = default_type_visitor visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor diff --git a/tl.tl b/tl.tl index f72568ea3..793da8605 100644 --- a/tl.tl +++ b/tl.tl @@ -1590,6 +1590,7 @@ local function new_typeid(): integer end local enum TypeName + "generic" "typedecl" "typevar" "typearg" @@ -1634,6 +1635,7 @@ local table_types : {TypeName:boolean} = { ["emptytable"] = true, ["tupletable"] = true, + ["generic"] = false, ["typedecl"] = false, ["typevar"] = false, ["typearg"] = false, @@ -1709,15 +1711,17 @@ local record BooleanContextType where self.typename == "boolean_context" end -local interface HasTypeArgs +local record GenericType is Type - where self.typeargs + where self.typename == "generic" typeargs: {TypeArgType} + t: Type + fresh: boolean end local record TypeDeclType - is Type, HasTypeArgs + is Type where self.typename == "typedecl" def: Type @@ -1777,7 +1781,7 @@ local interface ArrayLikeType end local interface RecordLikeType - is Type, HasTypeArgs, HasDeclName, ArrayLikeType + is Type, HasDeclName, ArrayLikeType where self.fields interface_list: {ArrayType | NominalType} @@ -1882,7 +1886,7 @@ local record UnresolvedEmptyTableValueType end local record FunctionType - is Type, HasTypeArgs + is Type where self.typename == "function" is_method: boolean @@ -1917,7 +1921,7 @@ local record PolyType is AggregateType where self.typename == "poly" - types: {FunctionType} + types: {FunctionType | GenericType} end local record EnumType @@ -2379,7 +2383,6 @@ local parse_argument_list: function(ParseState, integer): integer, Node, integer local parse_argument_type_list: function(ParseState, integer): integer, TupleType, boolean, integer local parse_type: function(ParseState, integer): integer, Type, integer local parse_type_declaration: function(ps: ParseState, i: integer, node_name: NodeKind): integer, Node -local parse_newtype: function(ps: ParseState, i: integer): integer, Node local parse_interface_name: function(ps: ParseState, i: integer): integer, Type, integer local type ParseBody = function(ps: ParseState, i: integer, def: Type): integer, boolean @@ -2446,6 +2449,13 @@ local function new_type(ps: ParseState, i: integer, typename: TypeName): Type return t end +local function new_generic(ps: ParseState, i: integer, typeargs: {TypeArgType}, typ: Type): GenericType + local gt = new_type(ps, i, "generic") as GenericType + gt.typeargs = typeargs + gt.t = typ + return gt +end + local function new_typedecl(ps: ParseState, i: integer, def: Type): TypeDeclType local t = new_type(ps, i, "typedecl") as TypeDeclType t.def = def @@ -2500,12 +2510,6 @@ local function parse_type_body(ps: ParseState, i: integer, istart: integer, node def = new_type(ps, istart, tn) - if typeargs then - if def is RecordType or def is InterfaceType then - def.typeargs = typeargs - end - end - local ok: boolean i, ok = parse_type_body_fns[tn](ps, i, def) if not ok then @@ -2514,6 +2518,10 @@ local function parse_type_body(ps: ParseState, i: integer, istart: integer, node i = verify_end(ps, i, istart, node) + if typeargs then + def = new_generic(ps, istart, typeargs, def) + end + return i, def end @@ -2760,11 +2768,12 @@ parse_typeargs_if_any = function(ps: ParseState, i: integer): integer, {TypeArgT return i end -local function parse_function_type(ps: ParseState, i: integer): integer, FunctionType +local function parse_function_type(ps: ParseState, i: integer): integer, GenericType | FunctionType + local typeargs: {TypeArgType} local typ = new_type(ps, i, "function") as FunctionType i = i + 1 - i, typ.typeargs = parse_typeargs_if_any(ps, i) + i, typeargs = parse_typeargs_if_any(ps, i) if ps.tokens[i].tk == "(" then i, typ.args, typ.maybe_method, typ.min_arity = parse_argument_type_list(ps, i) i, typ.rets = parse_return_types(ps, i) @@ -2774,6 +2783,11 @@ local function parse_function_type(ps: ParseState, i: integer): integer, Functio typ.is_method = false typ.min_arity = 0 end + + if typeargs then + return i, new_generic(ps, i, typeargs, typ) + end + return i, typ end @@ -3720,15 +3734,17 @@ local function store_field_in_record(ps: ParseState, i: integer, field_name: str end local oldt = fields[field_name] + local oldf = oldt is GenericType and oldt.t or oldt + local newf = newt is GenericType and newt.t or newt - if newt is FunctionType then - if oldt is FunctionType then + if newf is FunctionType then + if oldf is FunctionType then local p = new_type(ps, i, "poly") as PolyType - p.types = { oldt, newt } + p.types = { oldt as FunctionType, newt as FunctionType } fields[field_name] = p return true elseif oldt is PolyType then - table.insert(oldt.types, newt) + table.insert(oldt.types, newt as FunctionType) return true end end @@ -3737,6 +3753,10 @@ local function store_field_in_record(ps: ParseState, i: integer, field_name: str end local function set_declname(def: Type, declname: string) + if def is GenericType then + def = def.t + end + if def is RecordType or def is InterfaceType or def is EnumType then if not def.declname then def.declname = declname @@ -3879,15 +3899,16 @@ local function parse_array_interface_type(ps: ParseState, i: integer, def: Recor return i, t end -local function clone_typeargs(ps: ParseState, i: integer, typeargs: {TypeArgType}): {TypeArgType} - local copy = {} - for a, ta in ipairs(typeargs) do - local cta = new_type(ps, i, "typearg") as TypeArgType - cta.typearg = ta.typearg - copy[a] = cta - end - return copy -end +-- FIXME (a) GenericType do we need to patch the where-generated __is function into a generic? (see (b)) +--local function clone_typeargs(ps: ParseState, i: integer, typeargs: {TypeArgType}): {TypeArgType} +-- local copy = {} +-- for a, ta in ipairs(typeargs) do +-- local cta = new_type(ps, i, "typearg") as TypeArgType +-- cta.typearg = ta.typearg +-- copy[a] = cta +-- end +-- return copy +--end local function extract_userdata_from_interface_list(ps: ParseState, i: integer, def: RecordLikeType) for j = #def.interface_list, 1, -1 do @@ -3945,9 +3966,7 @@ parse_record_body = function(ps: ParseState, i: integer, def: RecordLikeType): i i, where_macroexp = parse_where_clause(ps, i, def) local typ = new_type(ps, wstart, "function") as FunctionType - if def.typeargs then - typ.typeargs = clone_typeargs(ps, i, def.typeargs) - end + -- FIXME (b) GenericType do we need to patch the where-generated __is function into a generic? (see (a)) typ.is_method = true typ.min_arity = 1 typ.args = new_tuple(ps, wstart, { @@ -4080,7 +4099,7 @@ parse_type_body_fns = { ["enum"] = parse_enum_body, } -parse_newtype = function(ps: ParseState, i: integer): integer, Node +local function parse_newtype(ps: ParseState, i: integer): integer, Node local node: Node = new_node(ps, i, "newtype") local def: Type local tn = ps.tokens[i].tk as TypeName @@ -4099,6 +4118,11 @@ parse_newtype = function(ps: ParseState, i: integer): integer, Node if def is NominalType then node.newtype.is_alias = true + elseif def is GenericType then + local deft = def.t + if deft is NominalType then + node.newtype.is_alias = true + end end return i, node @@ -4241,9 +4265,8 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin end local typeargs: {TypeArgType} local itypeargs = i - if ps.tokens[i].tk == "<" then - i, typeargs = parse_anglebracket_list(ps, i, parse_typearg) - end + i, typeargs = parse_typeargs_if_any(ps, i) + asgn.var = var if node_name == "global_type" and ps.tokens[i].tk ~= "=" then @@ -4251,6 +4274,7 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin end i = verify_tk(ps, i, "=") + local istart = i if ps.tokens[i].kind == "identifier" then local done: boolean @@ -4267,23 +4291,16 @@ parse_type_declaration = function(ps: ParseState, i: integer, node_name: NodeKin local nt = asgn.value.newtype if nt is TypeDeclType then - local def = nt.def - if typeargs then - if def is HasTypeArgs then - if def.typeargs then - fail(ps, itypeargs, "cannot declare type arguments twice in type declaration") - else - def.typeargs = typeargs - end + local def = nt.def + if def is GenericType then + fail(ps, itypeargs, "cannot declare type arguments twice in type declaration") else - -- FIXME how to resolve type arguments in unions properly - -- fail(ps, itypeargs, def.typename .. " does not accept type arguments") - nt.typeargs = typeargs + nt.def = new_generic(ps, istart, typeargs, def) end end - set_declname(def, asgn.var.tk) + set_declname(nt.def, asgn.var.tk) end return i, asgn @@ -4618,15 +4635,12 @@ local function recurse_type(s: S, ast: Type, visit: Visitor: {TypeName:integer} = { ["tuple"] = tl.typecodes.UNKNOWN, ["literal_table_item"] = tl.typecodes.UNKNOWN, ["typedecl"] = tl.typecodes.UNKNOWN, + ["generic"] = tl.typecodes.UNKNOWN, ["*"] = tl.typecodes.UNKNOWN, } @@ -5906,6 +5921,11 @@ function TypeReporter:get_typenum(t: Type): integer return self:get_typenum(rt.def) end + -- CHECK is this sufficient? + if rt is GenericType then + rt = rt.t + end + local ti: TypeInfo = { t = assert(typename_to_typecode[rt.typename]), str = show_type(t, true), @@ -6366,19 +6386,26 @@ function Errors:add_prefixing(w: Where, src: {Error}, prefix: string, dst?: {Err end end +local function ensure_not_abstract_type(def: Type, node?: Node): boolean, string + if def is RecordType then + return true + elseif def is GenericType then + return ensure_not_abstract_type(def.t) + elseif node and node_is_require_call(node) then + return nil, "module type is abstract: " .. tostring(def) + elseif def is InterfaceType then + return nil, "interfaces are abstract; consider using a concrete record" + end + return nil, "cannot use a type definition as a concrete value" +end + local function ensure_not_abstract(t: Type, node?: Node): boolean, string if t is FunctionType and t.macroexp then return nil, "macroexps are abstract; consider using a concrete function" + elseif t is GenericType then + return ensure_not_abstract(t.t, node) elseif t is TypeDeclType then - local def = t.def - if def is RecordType then - return true - elseif node and node_is_require_call(node) then - return nil, "module type is abstract: " .. tostring(t.def) - elseif def is InterfaceType then - return nil, "interfaces are abstract; consider using a concrete record" - end - return nil, "cannot use a type definition as a concrete value" + return ensure_not_abstract_type(t.def, node) end return true end @@ -6774,27 +6801,11 @@ local function display_typevar(typevar: string, what: TypeName): string end local function show_fields(t: RecordLikeType, show: function(Type):(string)): string - if t.declname and not t.typeargs then + if t.declname then return " " .. t.declname end local out: {string} = {} - if t.declname and not t.typeargs then - table.insert(out, " " .. t.declname) - end - if t.typeargs then - table.insert(out, "<") - local typeargs = {} - for _, v in ipairs(t.typeargs) do - table.insert(typeargs, show(v)) - end - table.insert(out, table.concat(typeargs, ", ")) - table.insert(out, ">") - end - if t.declname then - return table.concat(out) - end - table.insert(out, " (") if t.elements then table.insert(out, "{" .. show(t.elements) .. "}") @@ -6886,17 +6897,7 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str elseif t is RecordLikeType then return short and (t.declname or t.typename) or t.typename .. show_fields(t, show) elseif t is FunctionType then - local out: {string} = {"function"} - if t.typeargs then - table.insert(out, "<") - local typeargs = {} - for _, v in ipairs(t.typeargs) do - table.insert(typeargs, show(v)) - end - table.insert(out, table.concat(typeargs, ", ")) - table.insert(out, ">") - end - table.insert(out, "(") + local out: {string} = {"function("} local args = {} for i, v in ipairs(t.args.tuple) do table.insert(args, ((i == #t.args.tuple and t.args.is_va) and "...: " @@ -6920,6 +6921,26 @@ local function show_type_base(t: Type, short: boolean, seen: {Type:string}): str end end return table.concat(out) + elseif t is GenericType then + local out: {string} = {} + local name, rest: string, string + local tt = t.t + if tt is RecordType or tt is InterfaceType or tt is FunctionType then + name, rest = show(tt):match("^(%a+)(.*)") + table.insert(out, name) + else + rest = " " .. show(tt) + table.insert(out, "generic") + end + table.insert(out, "<") + local typeargs = {} + for _, v in ipairs(t.typeargs) do + table.insert(typeargs, show(v)) + end + table.insert(out, table.concat(typeargs, ", ")) + table.insert(out, ">") + table.insert(out, rest) + return table.concat(out) elseif t.typename == "number" or t.typename == "integer" or t.typename == "boolean" @@ -7349,35 +7370,46 @@ do }, nil end - local type ResolveType = function(S, Type): Type - local typevar_resolver: function(s: S, typ: Type, fn_var: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} + local type ResolveTypeVars = function(S, TypeVarType | TypeArgType): Type, boolean + local map_typevars: function(s: S, ty: Type, fn_tv: ResolveTypeVars): Type, {Error} - local function fresh_typevar(_: nil, t: TypeVarType): Type, Type, boolean - return a_type(t, "typevar", { - typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, - constraint = t.constraint, - } as TypeVarType) + local function fresh_typevar(_: nil, t: TypeVarType | TypeArgType): Type, boolean + if t is TypeVarType then + return a_type(t, "typevar", { + typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, + constraint = t.constraint, + } as TypeVarType), false + else + return a_type(t, "typearg", { + typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, + constraint = t.constraint, + } as TypeArgType), true + end end - local function fresh_typearg(_: nil, t: TypeArgType): Type - return a_type(t, "typearg", { - typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, - constraint = t.constraint, - } as TypeArgType) + local function fresh_typeargs(self: TypeChecker, g: GenericType): GenericType + fresh_typevar_ctr = fresh_typevar_ctr + 1 + + local newg, errs = map_typevars(nil, g, fresh_typevar) + if newg is InvalidType then + self.errs:collect(errs) + return g + end + assert(newg is GenericType, "Internal Compiler Error: error creating fresh type variables") + assert(newg ~= g) + newg.fresh = true + + return newg end - function TypeChecker:ensure_fresh_typeargs(t: T): T - if not t is HasTypeArgs then + local function wrap_generic_if_typeargs(typeargs: {TypeArgType}, t: T): T | GenericType + if not typeargs then return t end - fresh_typevar_ctr = fresh_typevar_ctr + 1 - local ok, errs: boolean, {Error} - ok, t, errs = typevar_resolver(nil, t, fresh_typevar, fresh_typearg) - if not ok then - self.errs:collect(errs) - end - return t + local gt = a_type(t, "generic", { t = t } as GenericType) + gt.typeargs = typeargs + return gt end function TypeChecker:find_var_type(name: string, use?: VarUse): Type, Attribute, Type @@ -7387,13 +7419,26 @@ do if t is UnresolvedTypeArgType then return nil, nil, t.constraint end - t = self:ensure_fresh_typeargs(t) + + if t is GenericType then + t = fresh_typeargs(self, t) + end + return t, var.attribute end end local function ensure_not_method(t: Type): Type + if t is GenericType then + local tt = ensure_not_method(t.t) + if tt ~= t.t then + local gg = shallow_copy_new_type(t) + gg.t = tt + return gg + end + end + if t is FunctionType and t.is_method then t = shallow_copy_new_type(t) t.is_method = false @@ -7401,6 +7446,17 @@ do return t end + local function unwrap_for_find_type(typ: Type): Type + if typ is NominalType and typ.found then + return unwrap_for_find_type(typ.found) + elseif typ is TypeDeclType then + return unwrap_for_find_type(typ.def) + elseif typ is GenericType then + return unwrap_for_find_type(typ.t) + end + return typ + end + function TypeChecker:find_type(names: {string}): TypeDeclType, TypeArgType local typ = self:find_var_type(names[1], "use_type") if not typ then @@ -7409,13 +7465,8 @@ do end return nil end - if typ is NominalType and typ.found then - typ = typ.found - end for i = 2, #names do - if typ is TypeDeclType then - typ = typ.def - end + typ = unwrap_for_find_type(typ) local fields = typ is RecordLikeType and typ.fields if not fields then @@ -7423,15 +7474,14 @@ do end typ = fields[names[i]] + if typ and typ is NominalType then + typ = typ.found + end if typ == nil then return nil end - - typ = self:ensure_fresh_typeargs(typ) - if typ is NominalType and typ.found then - typ = typ.found - end end + if typ is TypeDeclType then return typ elseif typ is TypeArgType then @@ -7441,7 +7491,7 @@ do local function type_for_union(t: Type): string, Type if t is TypeDeclType then - return type_for_union(t.def), t.def + return type_for_union(t.def) elseif t is TupleType then return type_for_union(t.tuple[1]), t.tuple[1] elseif t is NominalType then @@ -7455,6 +7505,8 @@ do return "userdata", t end return "table", t + elseif t is GenericType then + return type_for_union(t.t) elseif table_types[t.typename] then return "table", t else @@ -7567,11 +7619,7 @@ do ["unknown"] = true, } - local function clear_resolved_typeargs(copy: Type, resolved: {string:Type}) - if not copy is HasTypeArgs then - return - end - + local function clear_resolved_typeargs(copy: GenericType, resolved: {string:Type}): Type for i = #copy.typeargs, 1, -1 do local r = resolved[copy.typeargs[i].typearg] if r then @@ -7579,26 +7627,17 @@ do end end if not copy.typeargs[1] then - copy.typeargs = nil + return copy.t end - - return + return copy end - typevar_resolver = function(self: S, typ: Type, fn_var: ResolveType, fn_arg?: ResolveType): boolean, Type, {Error} + map_typevars = function(self: S, ty: Type, fn_tv: ResolveTypeVars): Type, {Error} local errs: {Error} local seen: {Type:Type} = {} local resolved: {string:Type} = {} local resolve: function(t: T, all_same: boolean): T, boolean - local function copy_typeargs(t: {TypeArgType}, same: boolean): {TypeArgType}, boolean - local copy = {} - for i, tf in ipairs(t) do - copy[i], same = resolve(tf, same) as (TypeArgType, boolean) - end - return copy, same - end - resolve = function(t: T, all_same: boolean): T, boolean local same = true @@ -7613,9 +7652,11 @@ do local orig_t = t if t is TypeVarType then - local rt = fn_var(self, t) + local rt, is_resolved = fn_tv(self, t) if rt then - resolved[t.typevar] = rt + if is_resolved then + resolved[t.typevar] = rt + end if no_nested_types[rt.typename] or (rt is NominalType and not rt.typevals) then seen[orig_t] = rt return rt, false @@ -7634,18 +7675,26 @@ do copy.x = t.x copy.y = t.y - if t is HasTypeArgs then - (copy as HasTypeArgs).typeargs, same = copy_typeargs(t.typeargs, same) - end + if t is GenericType then + assert(copy is GenericType) - if t is ArrayType then + local ct = {} + for i, tf in ipairs(t.typeargs) do + ct[i], same = resolve(tf, same) + end + copy.typeargs = ct + copy.t, same = resolve(t.t, same) + elseif t is ArrayType then assert(copy is ArrayType) copy.elements, same = resolve(t.elements, same) -- inferred_len is not propagated elseif t is TypeArgType then - if fn_arg then - copy = fn_arg(self, t) + local rt, is_resolved = fn_tv(self, t) + if is_resolved then + resolved[t.typearg] = rt + copy = rt + same = false else assert(copy is TypeArgType) copy.typearg = t.typearg @@ -7653,12 +7702,6 @@ do copy.constraint, same = resolve(t.constraint, same) end end - - -- eager resolution of type argument variables - local rt = fn_var(self, a_type(t, "typevar", { typevar = t.typearg } as TypeVarType)) - if rt then - resolved[t.typearg] = rt - end elseif t is UnresolvableTypeArgType then assert(copy is UnresolvableTypeArgType) copy.typearg = t.typearg @@ -7737,7 +7780,7 @@ do assert(copy is PolyType) copy.types = {} for i, tf in ipairs(t.types) do - copy.types[i], same = resolve(tf, same) as (FunctionType, boolean) + copy.types[i], same = resolve(tf, same) end elseif t is TupleTableType then assert(copy is TupleTableType) @@ -7764,23 +7807,29 @@ do return copy, same and all_same end - local copy = resolve(typ, true) + local copy = resolve(ty, true) if errs then - return false, an_invalid(typ), errs + return an_invalid(ty), errs end - clear_resolved_typeargs(copy, resolved) + if copy is GenericType then + copy = clear_resolved_typeargs(copy, resolved) + end - return true, copy, nil + return copy end - local function resolve_typevar(tc: TypeChecker, t: TypeVarType): Type - local rt = tc:find_var_type(t.typevar) - if not rt then - return nil - end + local function resolve_typevar(tc: TypeChecker, t: TypeVarType | TypeArgType): Type, boolean + if t is TypeVarType then + local rt = tc:find_var_type(t.typevar) + if not rt then + return nil + end - return drop_constant_value(rt) + return drop_constant_value(rt), true + else + return t, false + end end function TypeChecker:infer_emptytable(emptytable: EmptyTableType, fresh_t: Type) @@ -7826,10 +7875,10 @@ do return t end - function TypeChecker:resolve_typevars_at(w: Where, t: T): T + function TypeChecker:resolve_typevars_at(w: Where, t: Type): Type assert(w) - local ok, ret, errs = typevar_resolver(self, t, resolve_typevar) - if not ok then + local ret, errs = map_typevars(self, t, resolve_typevar) + if errs then assert(w.y) self.errs:add_prefixing(w, errs, "") end @@ -8022,6 +8071,43 @@ do return NONE end + local function unresolved_typeargs_for(g: GenericType): {Type} + local ts = {} + for _, ta in ipairs(g.typeargs) do + table.insert(ts, a_type(ta, "unresolved_typearg", { + constraint = ta.constraint + } as UnresolvedTypeArgType)) + end + return ts + end + + function TypeChecker:apply_generic(w: Where, g: GenericType, typeargs?: {Type}): Type + if not g.fresh then + g = fresh_typeargs(self, g) + end + + if not typeargs then + typeargs = unresolved_typeargs_for(g) + end + + assert(#g.typeargs == #typeargs) + + for i, ta in ipairs(g.typeargs) do + self:add_var(nil, ta.typearg, typeargs[i]) + end + local applied, errs = map_typevars(self, g, resolve_typevar) + if errs then + self.errs:add_prefixing(w, errs, "") + return nil + end + + if applied is GenericType then + return applied.t + else + return applied + end + end + local type InvalidOrTypeDeclType = InvalidType | TypeDeclType do @@ -8047,13 +8133,8 @@ do end end - local function match_typevals(self: TypeChecker, t: NominalType, def: HasTypeArgs): Type - if not t.typevals and not def.typeargs then - return def - elseif not def.typeargs then - self.errs:add(t, "unexpected type argument", t) - return nil - elseif not t.typevals then + local function match_typevals(self: TypeChecker, t: NominalType, def: GenericType): Type + if not t.typevals then self.errs:add(t, "missing type arguments in %s", def) return nil elseif #t.typevals ~= #def.typeargs then @@ -8063,10 +8144,7 @@ do self:begin_scope() - for i, tt in ipairs(t.typevals) do - self:add_var(nil, def.typeargs[i].typearg, tt) - end - local ret = self:resolve_typevars_at(t, def) + local ret = self:apply_generic(t, def, t.typevals) if def == self.cache_std_metatable_type then check_metatable_contract(self, t.typevals[1], ret) end @@ -8088,8 +8166,10 @@ do if found is TypeDeclType and found.is_alias then local def = found.def - assert(def is NominalType) - found = def.found + if def is NominalType then + found = def.found + end + -- if found.def is GenericType, return found as-is end if not found then @@ -8117,14 +8197,16 @@ do return nil, found end - local function resolve_decl_into_nominal(self: TypeChecker, t: NominalType, found: TypeDeclType): Type + local function resolve_decl_in_nominal(self: TypeChecker, t: NominalType, found: TypeDeclType): Type local def = found.def local resolved: Type - if def is RecordLikeType or def is FunctionType then + if def is GenericType then resolved = match_typevals(self, t, def) if not resolved then - return self.errs:invalid_at(t, table.concat(t.names, ".") .. " cannot be resolved in scope") + resolved = an_invalid(t) end + elseif t.typevals then + resolved = self.errs:invalid_at(t, "unexpected type argument") else resolved = def end @@ -8140,13 +8222,21 @@ do return immediate end - return resolve_decl_into_nominal(self, t, found) + return resolve_decl_in_nominal(self, t, found) end function TypeChecker:resolve_typealias(ta: TypeDeclType): InvalidOrTypeDeclType + local def = ta.def + + local nom = def + if def is GenericType then + nom = def.t + end + + if not nom is NominalType then + return ta + end -- given a typealias that points to a nominal, - local nom = ta.def - assert(nom is NominalType) local immediate, found = find_nominal_type_decl(self, nom) -- if it was previously resolved (or a circular require, or an error), return that; @@ -8163,7 +8253,11 @@ do -- otherwise, this can't be an alias. -- resolve the nominal into a structural type - local struc = resolve_decl_into_nominal(self, nom, found or nom.found) + local struc = resolve_decl_in_nominal(self, nom, found or nom.found) + + if def is GenericType then + struc = wrap_generic_if_typeargs(def.typeargs, struc) + end -- wrap it into a new non-alias typedecl local td = a_type(ta, "typedecl", { def = struc } as TypeDeclType) @@ -8259,6 +8353,10 @@ do end local t = typedecl.def + if t is GenericType then + t = t.t + end + return t end @@ -8563,8 +8661,8 @@ do end end - local ok, r, errs = typevar_resolver(self, other, resolve_typevar) - if not ok then + local r, errs = map_typevars(self, other, resolve_typevar) + if errs then return false, errs end if r is TypeVarType and r.typevar == typevar then @@ -8739,6 +8837,19 @@ do ["boolean_context"] = { ["boolean"] = compare_true, }, + ["generic"] = { + ["generic"] = function(self: TypeChecker, a: GenericType, b: GenericType): boolean, {Error} + if #a.typeargs ~= #b.typeargs then + return false + end + for i = 1, #a.typeargs do + if not self:same_type(a.typeargs[i], b.typeargs[i]) then + return false + end + end + return self:same_type(a.t, b.t) + end, + }, ["*"] = { ["boolean_context"] = compare_true, ["self"] = function(self: TypeChecker, a: Type, b: SelfType): boolean, {Error} @@ -9068,6 +9179,16 @@ do ["boolean_context"] = { ["boolean"] = compare_true, }, + ["generic"] = { + ["*"] = function(self: TypeChecker, a: GenericType, b: Type): boolean, {Error} + -- TODO check if commenting this out causes variable leaks anywhere + -- self:begin_scope() + local aa = self:apply_generic(a, a) + local ok, errs = self:is_a(aa, b) + -- self:end_scope() + return ok, errs + end, + }, ["*"] = { ["any"] = compare_true, ["boolean_context"] = compare_true, @@ -9078,6 +9199,14 @@ do infer_emptytable_from_unresolved_value(self, b, b, a) return true end, + ["generic"] = function(self: TypeChecker, a: Type, b: GenericType): boolean, {Error} + -- TODO check if commenting this out causes variable leaks anywhere + -- self:begin_scope() + local bb = self:apply_generic(b, b) + local ok, errs = self:is_a(a, bb) + -- self:end_scope() + return ok, errs + end, ["self"] = function(self: TypeChecker, a: Type, b: SelfType): boolean, {Error} return self:is_a(a, self:type_of_self(b)) end, @@ -9113,6 +9242,7 @@ do -- evaluation strategy TypeChecker.type_priorities = { -- types that have catch-all rules evaluate first + ["generic"] = -1, ["nil"] = 0, ["unresolved_emptytable_value"] = 1, ["emptytable"] = 2, @@ -9286,6 +9416,11 @@ do end -- unwrap if tuple, resolve if nominal func = self:to_structural(func) + + if func is GenericType then + func = self:apply_generic(func, func) + end + if func.typename ~= "function" and func.typename ~= "poly" then -- resolve if union if func is UnionType then @@ -9405,17 +9540,15 @@ do end do - local function mark_invalid_typeargs(self: TypeChecker, f: FunctionType) - if f.typeargs then - for _, a in ipairs(f.typeargs) do - if not self:find_var_type(a.typearg) then - if a.constraint then - self:add_var(nil, a.typearg, a.constraint) - else - self:add_var(nil, a.typearg, self.feat_lax and an_unknown(a) or a_type(a, "unresolvable_typearg", { - typearg = a.typearg - } as UnresolvableTypeArgType)) - end + local function mark_invalid_typeargs(self: TypeChecker, gt: GenericType) + for _, a in ipairs(gt.typeargs) do + if not self:find_var_type(a.typearg) then + if a.constraint then + self:add_var(nil, a.typearg, a.constraint) + else + self:add_var(nil, a.typearg, self.feat_lax and an_unknown(a) or a_type(a, "unresolvable_typearg", { + typearg = a.typearg + } as UnresolvableTypeArgType)) end end end @@ -9518,16 +9651,6 @@ do return false end - local function add_call_typeargs(self: TypeChecker, func: FunctionType) - if func.typeargs then - for _, fnarg in ipairs(func.typeargs) do - self:add_var(nil, fnarg.typearg, a_type(fnarg, "unresolved_typearg", { - constraint = fnarg.constraint, - } as UnresolvedTypeArgType)) - end - end - end - check_call = function(self: TypeChecker, w: Where, wargs: {Where}, f: FunctionType, args: TupleType, expected_rets: TupleType, cm: CallMode, argdelta: integer): boolean, {Error} local arg1 = args.tuple[1] if cm == "method" and arg1 then @@ -9551,17 +9674,30 @@ do return nil, { Err_at(w, "wrong number of arguments (given " .. given .. ", expects " .. show_arity(f) .. ")") } end - add_call_typeargs(self, f) - return check_args_rets(self, w, wargs, f, args, expected_rets, argdelta) end end + function TypeChecker:iterate_poly(p: PolyType): function(): integer, FunctionType + local i = 0 + return function(): integer, FunctionType + i = i + 1 + local fg = p.types[i] + if not fg then + return + elseif fg is FunctionType then + return i, fg + elseif fg is GenericType then + return i, self:apply_generic(p, fg) as FunctionType + end + end + end + local check_poly_call: function(self: TypeChecker, w: Where, wargs: {Where}, p: PolyType, args: TupleType, expected_rets: TupleType, cm: CallMode, argdelta: integer): FunctionType, TupleType, {Error} do - local function fail_poly_call_arity(w: Where, p: PolyType, given: integer): {Error} + local function fail_poly_call_arity(self: TypeChecker, w: Where, p: PolyType, given: integer): {Error} local expects: {string} = {} - for _, f in ipairs(p.types) do + for _, f in self:iterate_poly(p) do table.insert(expects, show_arity(f)) end table.sort(expects) @@ -9588,7 +9724,9 @@ do local first_errs: {Error} for pass = 1, 3 do - for i, f in ipairs(p.types) do + for i, f in self:iterate_poly(p) do + assert(f is FunctionType, f.typename) + assert(f.args) first_rets = first_rets or f.rets local wanted = #f.args.tuple @@ -9619,7 +9757,7 @@ do end if not first_errs then - return nil, first_rets, fail_poly_call_arity(w, p, given) + return nil, first_rets, fail_poly_call_arity(self, w, p, given) end return nil, first_rets, first_errs @@ -9654,6 +9792,14 @@ do expected_rets = a_tuple(node, { node.expected }) end + self:begin_scope() + + local g: GenericType + if func is GenericType then + g = func + func = self:apply_generic(node, func) as FunctionType + end + local is_method = (argdelta == -1) if not (func is FunctionType or func is PolyType) then @@ -9668,8 +9814,6 @@ do local errs: {Error} local f, ret: FunctionType, InvalidOrTupleType - self:begin_scope() - if func is PolyType then f, ret, errs = check_poly_call(self, node, e2, func, args, expected_rets, cm, argdelta) elseif func is FunctionType then @@ -9684,11 +9828,11 @@ do self.errs:collect(errs) end - if f then - mark_invalid_typeargs(self, f) + if g then + mark_invalid_typeargs(self, g) end - ret = self:resolve_typevars_at(node, ret) + ret = self:resolve_typevars_at(node, ret) as InvalidOrTupleType self:end_scope() if self.collector then @@ -9721,6 +9865,7 @@ do if self.feat_lax and ((a and is_unknown(a)) or (b and is_unknown(b))) then return an_unknown(node), nil end + local ameta = a is RecordLikeType and a.meta_fields local bmeta = b and b is RecordLikeType and b.meta_fields @@ -9821,6 +9966,10 @@ do end end + if tbl is GenericType then + tbl = self:apply_generic(tbl, tbl) + end + if tbl is UnionType then local t = self:same_in_all_union_entries(tbl, function(t: Type): (Type, Type) return (self:match_record_key(t, rec, key)) @@ -10004,12 +10153,11 @@ do end function TypeChecker:add_function_definition_for_recursion(node: Node, fnargs: TupleType, feat_arity: boolean) - self:add_var(nil, node.name.tk, a_function(node, { + self:add_var(nil, node.name.tk, wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = feat_arity and node.min_arity or 0, - typeargs = node.typeargs, args = fnargs, rets = self.get_rets(node.rets), - })) + }))) end function TypeChecker:end_function_scope(node: Node) @@ -10297,7 +10445,7 @@ do local function typedecl_to_nominal(node: Node, name: string, t: TypeDeclType, resolved?: Type): Type local typevals: {Type} local def = t.def - if def is HasTypeArgs then + if def is GenericType then typevals = {} for _, a in ipairs(def.typeargs) do table.insert(typevals, a_type(a, "typevar", { @@ -10923,7 +11071,7 @@ do local typ: Type typ = decls[i] if typ then - if node.kind == "assignment" and i == nexps and ndecl > nexps then + if i == nexps and ndecl > nexps and node_is_funcall(node.exps[i]) then typ = a_tuple(node, {}) for a = i, ndecl do table.insert(typ.tuple, decls[a]) @@ -11214,6 +11362,13 @@ do if def is NominalType then return (self:find_var(def.names[1], "use_type")) end + + if def is GenericType then + local nom = def.t + if nom is NominalType then + return (self:find_var(nom.names[1], "use_type")) + end + end end local function recurse_type_declaration(self: TypeChecker, n: Node): InvalidOrTypeDeclType, Variable @@ -11265,14 +11420,13 @@ do local resolved, aliasing = recurse_type_declaration(self, value) local nt = value.newtype if nt and nt.is_alias and resolved is TypeDeclType then - if nt.typeargs then - local def = resolved.def + local ntdef = nt.def + local rdef = resolved.def + if ntdef is GenericType and rdef is GenericType then -- FIXME this looks sketchy; not sure if just overwriting the -- type variables in a resolved alias won't have bad side-effects. -- Is it guaranteed to be a fresh type? - if def is RecordType or def is FunctionType or def is InterfaceType then - def.typeargs = nt.typeargs - end + ntdef.typeargs = rdef.typeargs end end return resolved, aliasing @@ -11848,6 +12002,10 @@ do decltype = resolve_typedecl(self:to_structural(decltype.constraint)) end + if decltype is GenericType then + decltype = self:apply_generic(node, decltype) + end + if decltype is TupleTableType then for _, child in ipairs(node) do local n = child.key.constnum @@ -11892,6 +12050,10 @@ do decltype = self:to_structural(constraint) end + if decltype is GenericType then + decltype = self:apply_generic(node, decltype) + end + if decltype is UnionType then local single_table_type: Type local single_table_rt: Type @@ -12046,9 +12208,8 @@ do self:end_function_scope(node) - local t = self:ensure_fresh_typeargs(a_function(node, { + local t = wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = self.feat_arity and node.min_arity or 0, - typeargs = node.typeargs, args = args, rets = self.get_rets(rets), })) @@ -12075,9 +12236,8 @@ do self:check_macroexp_arg_use(node.macrodef) - local t = self:ensure_fresh_typeargs(a_function(node, { + local t = wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = self.feat_arity and node.macrodef.min_arity or 0, - typeargs = node.typeargs, args = args, rets = self.get_rets(rets), macroexp = node.macrodef, @@ -12122,9 +12282,8 @@ do return NONE end - self:add_global(node, node.name.tk, self:ensure_fresh_typeargs(a_function(node, { + self:add_global(node, node.name.tk, wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = self.feat_arity and node.min_arity or 0, - typeargs = node.typeargs, args = args, rets = self.get_rets(rets), }))) @@ -12141,7 +12300,7 @@ do local rtype = self:to_structural(resolve_typedecl(children[1])) -- add type arguments from the record implicitly - if rtype is RecordLikeType and rtype.typeargs then + if rtype is GenericType then for _, typ in ipairs(rtype.typeargs) do self:add_var(nil, typ.typearg, a_type(typ, "typearg", { typearg = typ.typearg, @@ -12159,6 +12318,10 @@ do local t = children[1] local rtype = self:to_structural(resolve_typedecl(t)) + if rtype is GenericType then + rtype = rtype.t + end + do local ok, err = ensure_not_abstract(t) if not ok then @@ -12196,10 +12359,9 @@ do end end - local fn_type = self:ensure_fresh_typeargs(a_function(node, { + local fn_type = wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = self.feat_arity and node.min_arity or 0, is_method = node.is_method, - typeargs = node.typeargs, args = args, rets = self.get_rets(rets), is_record_function = true, @@ -12215,6 +12377,12 @@ do self.errs:redeclaration_warning(node, node.name.tk, "function") end + if fn_type is GenericType and not rfieldtype is GenericType then + self:begin_scope() + fn_type = self:apply_generic(node, fn_type) as FunctionType + self:end_scope() + end + local ok, err = self:same_type(fn_type, rfieldtype) if not ok then if rfieldtype is PolyType then @@ -12228,7 +12396,15 @@ do return end else + if open_t and open_t is GenericType then + open_t = open_t.t + end if self.feat_lax or rtype == open_t then + -- TODO is this needed? + -- if fn_type is GenericType then + -- fn_type = fresh_typeargs(self, fn_type) + -- end + rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) @@ -12283,9 +12459,9 @@ do assert(rets is TupleType) self:end_function_scope(node) - return self:ensure_fresh_typeargs(a_function(node, { + + return wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = self.feat_arity and node.min_arity or 0, - typeargs = node.typeargs, args = args, rets = self.get_rets(rets), })) @@ -12309,9 +12485,8 @@ do assert(rets is TupleType) self:end_function_scope(node) - return self:ensure_fresh_typeargs(a_function(node, { + return wrap_generic_if_typeargs(node.typeargs, a_function(node, { min_arity = self.feat_arity and node.min_arity or 0, - typeargs = node.typeargs, args = args, rets = rets, })) @@ -12353,6 +12528,9 @@ do elseif node.op.op == "or" then self:apply_facts(node, facts_not(node, node.e1.known)) elseif node.op.op == "@funcall" then + if e1type is GenericType then + e1type = self:apply_generic(node, e1type) + end if e1type is FunctionType then local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0 if node.expected then @@ -12973,16 +13151,6 @@ do end end - local visit_type_with_typeargs = { - before = function(self: TypeChecker, _typ: Type) - self:begin_scope() - end, - after = function(self: TypeChecker, typ: Type, _children: {Type}): Type - self:end_scope() - return self:ensure_fresh_typeargs(typ) - end, - } - function TypeChecker:begin_temporary_record_types(typ: RecordType) self:add_var(nil, "@self", a_typedecl(typ, typ)) @@ -13010,26 +13178,35 @@ do end end - local function ensure_is_method_self(typ: RecordLikeType, fargs: {Type}): boolean - assert(typ.declname) - local selfarg = fargs[1] + local function ensure_is_method_self(typ: RecordLikeType, selfarg: Type, g?: GenericType): boolean if selfarg is SelfType then return true end if not selfarg is NominalType then return false end - if selfarg.names[1] ~= typ.declname or (typ.typeargs and not selfarg.typevals) then + + if #selfarg.names ~= 1 or selfarg.names[1] ~= typ.declname then return false end - if typ.typeargs then - for j=1,#typ.typeargs do + + if g then + if not selfarg.typevals then + return false + end + + if g.t.typeid ~= typ.typeid then + return false + end + + for j=1,#g.typeargs do local tv = selfarg.typevals[j] - if not (tv and tv is TypeVarType and tv.typevar == typ.typeargs[j].typearg) then + if not (tv and tv is TypeVarType and tv.typevar == g.typeargs[j].typearg) then return false end end end + return true end @@ -13050,13 +13227,22 @@ do local visit_type: Visitor visit_type = { cbs = { + ["generic"] = { + before = function(self: TypeChecker, typ: GenericType) + self:begin_scope() + self:add_var(nil, "@generic", typ) + end, + after = function(self: TypeChecker, typ: GenericType, _children: {Type}): Type + self:end_scope() + return fresh_typeargs(self, typ) + end, + }, ["function"] = { - before = visit_type_with_typeargs.before, - after = function(self: TypeChecker, typ: FunctionType, children: {Type}): Type + after = function(self: TypeChecker, typ: FunctionType, _children: {Type}): Type if self.feat_arity == false then typ.min_arity = 0 end - return visit_type_with_typeargs.after(self, typ, children) + return typ end }, ["record"] = { @@ -13066,12 +13252,6 @@ do end, after = function(self: TypeChecker, typ: RecordType, children: {Type}): Type local i = 1 - if typ.typeargs then - for _, _ in ipairs(typ.typeargs) do - typ.typeargs[i] = children[i] as TypeArgType - i = i + 1 - end - end if typ.interface_list then for j, _ in ipairs(typ.interface_list) do local iface = children[i] @@ -13093,6 +13273,7 @@ do i = i + 1 end local fmacros: {FunctionType} + local g: Variable for name, _ in fields_of(typ) do local ftype = children[i] if ftype is FunctionType then @@ -13104,7 +13285,10 @@ do if ftype.is_method then local fargs = ftype.args.tuple if fargs[1] then - ftype.is_method = ensure_is_method_self(typ, fargs) + if not g then + g = self:find_var("@generic") + end + ftype.is_method = ensure_is_method_self(typ, fargs[1], g and g.t as GenericType) if ftype.is_method then fargs[1] = a_self(fargs[1], typ) end @@ -13194,8 +13378,12 @@ do if t then local def = t.def if t.is_alias then - assert(def is NominalType) - typ.found = def.found + if def is GenericType then + def = def.t + end + if def is NominalType then + typ.found = def.found + end elseif def.typename ~= "circular_require" then typ.found = t end @@ -13236,8 +13424,7 @@ do visit_type.cbs["interface"] = visit_type.cbs["record"] - visit_type.cbs["typedecl"] = visit_type_with_typeargs - + visit_type.cbs["typedecl"] = default_type_visitor visit_type.cbs["self"] = default_type_visitor visit_type.cbs["string"] = default_type_visitor visit_type.cbs["tupletable"] = default_type_visitor