From d449988b334543218c3e4418c446432d43361eba Mon Sep 17 00:00:00 2001 From: Hisham Muhammad Date: Tue, 7 Jan 2025 16:35:14 -0300 Subject: [PATCH] refactor: generalize map_typevars into map_type --- tl.lua | 137 +++++++++++++++++++++++++++++-------------------------- tl.tl | 141 +++++++++++++++++++++++++++++++-------------------------- 2 files changed, 150 insertions(+), 128 deletions(-) diff --git a/tl.lua b/tl.lua index 544a8d2a..de02916b 100644 --- a/tl.lua +++ b/tl.lua @@ -7431,26 +7431,28 @@ do end - local map_typevars - local function fresh_typevar(_, t) - if t.typename == "typevar" then + local map_type + + local fresh_typevar_fns = { + ["typevar"] = function(_, t) return a_type(t, "typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, - }), false - else + }), true + end, + ["typearg"] = function(_, t) return a_type(t, "typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, }), true - end - end + end, + } local function fresh_typeargs(self, g) fresh_typevar_ctr = fresh_typevar_ctr + 1 - local newg, errs = map_typevars(nil, g, fresh_typevar) + local newg, errs = map_type(nil, g, fresh_typevar_fns) if newg.typename == "invalid" then self.errs:collect(errs) return g @@ -7684,23 +7686,61 @@ do ["unknown"] = true, } - local function clear_resolved_typeargs(copy, resolved) - for i = #copy.typeargs, 1, -1 do - local r = resolved[copy.typeargs[i].typearg] - if r then - table.remove(copy.typeargs, i) + local resolve_typevars + do + + + + + + local resolve_typevar_fns = { + ["typevar"] = function(s, t) + local rt = s.tc:find_var_type(t.typevar) + if not rt then + return t, false + end + + rt = drop_constant_value(rt) + s.resolved[t.typevar] = rt + + return rt, true + end, + } + + local function clear_resolved_typeargs(copy, resolved) + for i = #copy.typeargs, 1, -1 do + local r = resolved[copy.typeargs[i].typearg] + if r then + table.remove(copy.typeargs, i) + end + end + if not copy.typeargs[1] then + return copy.t end + return copy end - if not copy.typeargs[1] then - return copy.t + + resolve_typevars = function(self, t) + local state = { + tc = self, + resolved = {}, + } + local rt, errs = map_type(state, t, resolve_typevar_fns) + if errs then + return rt, errs + end + + if rt.typename == "generic" then + rt = clear_resolved_typeargs(rt, state.resolved) + end + + return rt end - return copy end - map_typevars = function(self, ty, fn_tv) + map_type = function(self, ty, fns) local errs local seen = {} - local resolved = {} local resolve resolve = function(t, all_same) @@ -7716,18 +7756,15 @@ do end local orig_t = t - if t.typename == "typevar" then - local rt, is_resolved = fn_tv(self, t) - if rt then + local fn = fns[t.typename] + if fn then + local rt, is_resolved = fn(self, t) + if rt ~= t then if is_resolved then - resolved[t.typevar] = rt - end - if no_nested_types[rt.typename] or (rt.typename == "nominal" and not rt.typevals) then - seen[orig_t] = rt + seen[t] = rt return rt, false end - all_same = false - t = rt + return resolve(rt, false) end end @@ -7755,17 +7792,10 @@ do copy.elements, same = resolve(t.elements, same) elseif t.typename == "typearg" then - local rt, is_resolved = fn_tv(self, t) - if is_resolved then - resolved[t.typearg] = rt - copy = rt - same = false - else - assert(copy.typename == "typearg") - copy.typearg = t.typearg - if t.constraint then - copy.constraint, same = resolve(t.constraint, same) - end + assert(copy.typename == "typearg") + copy.typearg = t.typearg + if t.constraint then + copy.constraint, same = resolve(t.constraint, same) end elseif t.typename == "unresolvable_typearg" then assert(copy.typename == "unresolvable_typearg") @@ -7877,26 +7907,9 @@ do return a_type(ty, "invalid", {}), errs end - if copy.typename == "generic" then - copy = clear_resolved_typeargs(copy, resolved) - end - return copy end - local function resolve_typevar(tc, t) - if t.typename == "typevar" then - local rt = tc:find_var_type(t.typevar) - if not rt then - return nil - end - - return drop_constant_value(rt), true - else - return t, false - end - end - function TypeChecker:infer_emptytable(emptytable, fresh_t) local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") local nst = is_global and 1 or #self.st @@ -7933,16 +7946,14 @@ do end end - local function type_at(w, t) t.x = w.x t.y = w.y return t end - function TypeChecker:resolve_typevars_at(w, t) - assert(w) - local ret, errs = map_typevars(self, t, resolve_typevar) + function TypeChecker:assert_resolved_typevars_at(w, t) + local ret, errs = resolve_typevars(self, t) if errs then assert(w.y) self.errs:add_prefixing(w, errs, "") @@ -7955,7 +7966,7 @@ do end function TypeChecker:infer_at(w, t) - local ret = self:resolve_typevars_at(w, t) + local ret = self:assert_resolved_typevars_at(w, t) if ret.typename == "invalid" then ret = t end @@ -8160,7 +8171,7 @@ do for i, ta in ipairs(g.typeargs) do self:add_var(nil, ta.typearg, typeargs[i]) end - local applied, errs = map_typevars(self, g, resolve_typevar) + local applied, errs = resolve_typevars(self, g) if errs then self.errs:add_prefixing(w, errs, "") return nil @@ -8729,7 +8740,7 @@ do end end - local r, errs = map_typevars(self, other, resolve_typevar) + local r, errs = resolve_typevars(self, other) if errs then return false, errs end @@ -9900,7 +9911,7 @@ a.types[i], b.types[i]), } mark_invalid_typeargs(self, g) end - ret = self:resolve_typevars_at(node, ret) + ret = self:assert_resolved_typevars_at(node, ret) self:end_scope() if self.collector then diff --git a/tl.tl b/tl.tl index 41be0dca..1a6de454 100644 --- a/tl.tl +++ b/tl.tl @@ -7430,27 +7430,29 @@ do }, nil end - local type ResolveTypeVars = function(S, TypeVarType | TypeArgType): Type, boolean - local map_typevars: function(s: S, ty: Type, fn_tv: ResolveTypeVars): Type, {Error} + local type TypeFunction = function(S, Type): Type, boolean + local type TypeFunctionMap = {TypeName: TypeFunction} + local map_type: function(s: S, ty: Type, fns: TypeFunctionMap): Type, {Error} - local function fresh_typevar(_: nil, t: TypeVarType | TypeArgType): Type, boolean - if t is TypeVarType then + local fresh_typevar_fns: TypeFunctionMap = { + ["typevar"] = function(_: nil, t: TypeVarType): Type, boolean return a_type(t, "typevar", { typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, - } as TypeVarType), false - else + } as TypeVarType), true + end, + ["typearg"] = function(_: nil, t: TypeArgType): Type, boolean return a_type(t, "typearg", { typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr, constraint = t.constraint, } as TypeArgType), true - end - end + end, + } local function fresh_typeargs(self: TypeChecker, g: GenericType): GenericType fresh_typevar_ctr = fresh_typevar_ctr + 1 - local newg, errs = map_typevars(nil, g, fresh_typevar) + local newg, errs = map_type(nil, g, fresh_typevar_fns) if newg is InvalidType then self.errs:collect(errs) return g @@ -7684,23 +7686,61 @@ do ["unknown"] = true, } - local function clear_resolved_typeargs(copy: GenericType, resolved: {string:Type}): Type - for i = #copy.typeargs, 1, -1 do - local r = resolved[copy.typeargs[i].typearg] - if r then - table.remove(copy.typeargs, i) + local resolve_typevars: function(self: TypeChecker, t: Type): Type, {Error} + do + local record ResolveTypeVarState + tc: TypeChecker + resolved: {string:Type} + end + + local resolve_typevar_fns: TypeFunctionMap = { + ["typevar"] = function(s: ResolveTypeVarState, t: TypeVarType): Type, boolean + local rt = s.tc:find_var_type(t.typevar) + if not rt then + return t, false + end + + rt = drop_constant_value(rt) + s.resolved[t.typevar] = rt + + return rt, true + end, + } + + local function clear_resolved_typeargs(copy: GenericType, resolved: {string:Type}): Type + for i = #copy.typeargs, 1, -1 do + local r = resolved[copy.typeargs[i].typearg] + if r then + table.remove(copy.typeargs, i) + end + end + if not copy.typeargs[1] then + return copy.t end + return copy end - if not copy.typeargs[1] then - return copy.t + + resolve_typevars = function(self: TypeChecker, t: Type): Type, {Error} + local state: ResolveTypeVarState = { + tc = self, + resolved = {}, + } + local rt, errs = map_type(state, t, resolve_typevar_fns) + if errs then + return rt, errs + end + + if rt is GenericType then + rt = clear_resolved_typeargs(rt, state.resolved) + end + + return rt end - return copy end - map_typevars = function(self: S, ty: Type, fn_tv: ResolveTypeVars): Type, {Error} + map_type = function(self: S, ty: Type, fns: TypeFunctionMap): Type, {Error} local errs: {Error} local seen: {Type:Type} = {} - local resolved: {string:Type} = {} local resolve: function(t: T, all_same: boolean): T, boolean resolve = function(t: T, all_same: boolean): T, boolean @@ -7716,18 +7756,15 @@ do end local orig_t = t - if t is TypeVarType then - local rt, is_resolved = fn_tv(self, t) - if rt then + local fn = fns[t.typename] + if fn then + local rt, is_resolved = fn(self, t) + if rt ~= t then if is_resolved then - resolved[t.typevar] = rt - end - if no_nested_types[rt.typename] or (rt is NominalType and not rt.typevals) then - seen[orig_t] = rt + seen[t] = rt return rt, false end - all_same = false - t = rt + return resolve(rt, false) end end @@ -7755,17 +7792,10 @@ do copy.elements, same = resolve(t.elements, same) -- inferred_len is not propagated elseif t is TypeArgType then - local rt, is_resolved = fn_tv(self, t) - if is_resolved then - resolved[t.typearg] = rt - copy = rt - same = false - else - assert(copy is TypeArgType) - copy.typearg = t.typearg - if t.constraint then - copy.constraint, same = resolve(t.constraint, same) - end + assert(copy is TypeArgType) + copy.typearg = t.typearg + if t.constraint then + copy.constraint, same = resolve(t.constraint, same) end elseif t is UnresolvableTypeArgType then assert(copy is UnresolvableTypeArgType) @@ -7877,26 +7907,9 @@ do return an_invalid(ty), errs end - if copy is GenericType then - copy = clear_resolved_typeargs(copy, resolved) - end - return copy end - local function resolve_typevar(tc: TypeChecker, t: TypeVarType | TypeArgType): Type, boolean - if t is TypeVarType then - local rt = tc:find_var_type(t.typevar) - if not rt then - return nil - end - - return drop_constant_value(rt), true - else - return t, false - end - end - function TypeChecker:infer_emptytable(emptytable: EmptyTableType, fresh_t: Type) local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration") local nst = is_global and 1 or #self.st @@ -7933,16 +7946,14 @@ do end end - local function type_at(w: Where, t: T): T t.x = w.x t.y = w.y return t end - function TypeChecker:resolve_typevars_at(w: Where, t: Type): Type - assert(w) - local ret, errs = map_typevars(self, t, resolve_typevar) + function TypeChecker:assert_resolved_typevars_at(w: Where, t: Type): Type + local ret, errs = resolve_typevars(self, t) if errs then assert(w.y) self.errs:add_prefixing(w, errs, "") @@ -7955,9 +7966,9 @@ do end function TypeChecker:infer_at(w: Where, t: T): T - local ret = self:resolve_typevars_at(w, t) + local ret = self:assert_resolved_typevars_at(w, t) if ret is InvalidType then - ret = t -- errors are produced by resolve_typevars_at + ret = t -- errors are produced by assert_resolved_typevars_at end if ret == t or t is TypeVarType then @@ -8160,7 +8171,7 @@ do for i, ta in ipairs(g.typeargs) do self:add_var(nil, ta.typearg, typeargs[i]) end - local applied, errs = map_typevars(self, g, resolve_typevar) + local applied, errs = resolve_typevars(self, g) if errs then self.errs:add_prefixing(w, errs, "") return nil @@ -8729,7 +8740,7 @@ do end end - local r, errs = map_typevars(self, other, resolve_typevar) + local r, errs = resolve_typevars(self, other) if errs then return false, errs end @@ -9900,7 +9911,7 @@ do mark_invalid_typeargs(self, g) end - ret = self:resolve_typevars_at(node, ret) as InvalidOrTupleType + ret = self:assert_resolved_typevars_at(node, ret) as InvalidOrTupleType self:end_scope() if self.collector then