Skip to content

Commit

Permalink
refactor: generalize map_typevars into map_type
Browse files Browse the repository at this point in the history
  • Loading branch information
hishamhm committed Jan 7, 2025
1 parent 63a6efc commit d449988
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 128 deletions.
137 changes: 74 additions & 63 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7431,26 +7431,28 @@ do
end


local map_typevars

local function fresh_typevar(_, t)
if t.typename == "typevar" then
local map_type

local fresh_typevar_fns = {
["typevar"] = function(_, t)
return a_type(t, "typevar", {
typevar = (t.typevar:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr,
constraint = t.constraint,
}), false
else
}), true
end,
["typearg"] = function(_, t)
return a_type(t, "typearg", {
typearg = (t.typearg:gsub("@.*", "")) .. "@" .. fresh_typevar_ctr,
constraint = t.constraint,
}), true
end
end
end,
}

local function fresh_typeargs(self, g)
fresh_typevar_ctr = fresh_typevar_ctr + 1

local newg, errs = map_typevars(nil, g, fresh_typevar)
local newg, errs = map_type(nil, g, fresh_typevar_fns)
if newg.typename == "invalid" then
self.errs:collect(errs)
return g
Expand Down Expand Up @@ -7684,23 +7686,61 @@ do
["unknown"] = true,
}

local function clear_resolved_typeargs(copy, resolved)
for i = #copy.typeargs, 1, -1 do
local r = resolved[copy.typeargs[i].typearg]
if r then
table.remove(copy.typeargs, i)
local resolve_typevars
do





local resolve_typevar_fns = {
["typevar"] = function(s, t)
local rt = s.tc:find_var_type(t.typevar)
if not rt then
return t, false
end

rt = drop_constant_value(rt)
s.resolved[t.typevar] = rt

return rt, true
end,
}

local function clear_resolved_typeargs(copy, resolved)
for i = #copy.typeargs, 1, -1 do
local r = resolved[copy.typeargs[i].typearg]
if r then
table.remove(copy.typeargs, i)
end
end
if not copy.typeargs[1] then
return copy.t
end
return copy
end
if not copy.typeargs[1] then
return copy.t

resolve_typevars = function(self, t)
local state = {
tc = self,
resolved = {},
}
local rt, errs = map_type(state, t, resolve_typevar_fns)
if errs then
return rt, errs
end

if rt.typename == "generic" then
rt = clear_resolved_typeargs(rt, state.resolved)
end

return rt
end
return copy
end

map_typevars = function(self, ty, fn_tv)
map_type = function(self, ty, fns)
local errs
local seen = {}
local resolved = {}
local resolve

resolve = function(t, all_same)
Expand All @@ -7716,18 +7756,15 @@ do
end

local orig_t = t
if t.typename == "typevar" then
local rt, is_resolved = fn_tv(self, t)
if rt then
local fn = fns[t.typename]
if fn then
local rt, is_resolved = fn(self, t)
if rt ~= t then
if is_resolved then
resolved[t.typevar] = rt
end
if no_nested_types[rt.typename] or (rt.typename == "nominal" and not rt.typevals) then
seen[orig_t] = rt
seen[t] = rt
return rt, false
end
all_same = false
t = rt
return resolve(rt, false)
end
end

Expand Down Expand Up @@ -7755,17 +7792,10 @@ do
copy.elements, same = resolve(t.elements, same)

elseif t.typename == "typearg" then
local rt, is_resolved = fn_tv(self, t)
if is_resolved then
resolved[t.typearg] = rt
copy = rt
same = false
else
assert(copy.typename == "typearg")
copy.typearg = t.typearg
if t.constraint then
copy.constraint, same = resolve(t.constraint, same)
end
assert(copy.typename == "typearg")
copy.typearg = t.typearg
if t.constraint then
copy.constraint, same = resolve(t.constraint, same)
end
elseif t.typename == "unresolvable_typearg" then
assert(copy.typename == "unresolvable_typearg")
Expand Down Expand Up @@ -7877,26 +7907,9 @@ do
return a_type(ty, "invalid", {}), errs
end

if copy.typename == "generic" then
copy = clear_resolved_typeargs(copy, resolved)
end

return copy
end

local function resolve_typevar(tc, t)
if t.typename == "typevar" then
local rt = tc:find_var_type(t.typevar)
if not rt then
return nil
end

return drop_constant_value(rt), true
else
return t, false
end
end

function TypeChecker:infer_emptytable(emptytable, fresh_t)
local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration")
local nst = is_global and 1 or #self.st
Expand Down Expand Up @@ -7933,16 +7946,14 @@ do
end
end


local function type_at(w, t)
t.x = w.x
t.y = w.y
return t
end

function TypeChecker:resolve_typevars_at(w, t)
assert(w)
local ret, errs = map_typevars(self, t, resolve_typevar)
function TypeChecker:assert_resolved_typevars_at(w, t)
local ret, errs = resolve_typevars(self, t)
if errs then
assert(w.y)
self.errs:add_prefixing(w, errs, "")
Expand All @@ -7955,7 +7966,7 @@ do
end

function TypeChecker:infer_at(w, t)
local ret = self:resolve_typevars_at(w, t)
local ret = self:assert_resolved_typevars_at(w, t)
if ret.typename == "invalid" then
ret = t
end
Expand Down Expand Up @@ -8160,7 +8171,7 @@ do
for i, ta in ipairs(g.typeargs) do
self:add_var(nil, ta.typearg, typeargs[i])
end
local applied, errs = map_typevars(self, g, resolve_typevar)
local applied, errs = resolve_typevars(self, g)
if errs then
self.errs:add_prefixing(w, errs, "")
return nil
Expand Down Expand Up @@ -8729,7 +8740,7 @@ do
end
end

local r, errs = map_typevars(self, other, resolve_typevar)
local r, errs = resolve_typevars(self, other)
if errs then
return false, errs
end
Expand Down Expand Up @@ -9900,7 +9911,7 @@ a.types[i], b.types[i]), }
mark_invalid_typeargs(self, g)
end

ret = self:resolve_typevars_at(node, ret)
ret = self:assert_resolved_typevars_at(node, ret)
self:end_scope()

if self.collector then
Expand Down
Loading

0 comments on commit d449988

Please sign in to comment.