Skip to content

Commit

Permalink
Propagate exceptions from parallel where possible (#2095)
Browse files Browse the repository at this point in the history
In the original implementation of our prettier runtime errors (#1320), we
wrapped the errors thrown within parallel functions into an exception
object. This means the call-stack is available to the catching-code, and
so is able to report a pretty exception message.

Unfortunately, this was a breaking change, and so we had to roll that
back. Some people were pcalling the parallel function, and matching on
the result of the error.

This is a second attempt at this, using a technique I've affectionately
dubbed "magic throws". The parallel API is now aware of whether it is
being pcalled or not, and thus able to decide whether to wrap the error
into an exception or not:

 - Add a new `cc.internal.tiny_require` module. This is a tiny
   reimplementation of require, for use in our global APIs.

 - Add a new (global, in the debug registry) `cc_try_barrier` function.
   This acts as a marker function, and is used to store additional
   information about the current coroutine.

   Currently this stores the parent coroutine (used to walk the full call
   stack) and a cache of whether any `pcall`-like function is on the
   stack.

   Both `parallel` and `cc.internal.exception.try` add this function to
   the root of the call stack.

 - When an error occurs within `parallel`, we walk up the call stack,
   using `cc_try_barrier` to traverse up the parent coroutine's stack
   too. If we do not find any `pcall`-like functions, then we know the
   error is never intercepted by user code, and so its safe to throw a
   full exception.
  • Loading branch information
SquidDev authored Feb 13, 2025
1 parent 2e2f308 commit 051c70a
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ the other.
@since 1.2
]]

local exception = dofile("rom/modules/main/cc/internal/tiny_require.lua")("cc.internal.exception")

local function create(...)
local barrier_ctx = { co = coroutine.running() }

local functions = table.pack(...)
local threads = {}
for i = 1, functions.n, 1 do
Expand All @@ -48,7 +52,7 @@ local function create(...)
error("bad argument #" .. i .. " (function expected, got " .. type(fn) .. ")", 3)
end

threads[i] = { co = coroutine.create(fn), filter = nil }
threads[i] = { co = coroutine.create(function() return exception.try_barrier(barrier_ctx, fn) end), filter = nil }
end

return threads
Expand All @@ -65,11 +69,14 @@ local function runUntilLimit(threads, limit)
local thread = threads[i]
if thread and (thread.filter == nil or thread.filter == event[1] or event[1] == "terminate") then
local ok, param = coroutine.resume(thread.co, table.unpack(event, 1, event.n))
if not ok then
error(param, 0)
else
if ok then
thread.filter = param
elseif type(param) == "string" and exception.can_wrap_errors() then
error(exception.make_exception(param, thread.co))
else
error(param, 0)
end

if coroutine.status(thread.co) == "dead" then
threads[i] = false
living = living - 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
-- @module textutils
-- @since 1.2

local pgk_env = setmetatable({}, { __index = _ENV })
pgk_env.require = dofile("rom/modules/main/cc/require.lua").make(pgk_env, "rom/modules/main")
local require = pgk_env.require
local require = dofile("rom/modules/main/cc/internal/tiny_require.lua")

local expect = require("cc.expect")
local expect, field = expect.expect, expect.field
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
]]

local expect = require "cc.expect".expect
local error_printer = require "cc.internal.error_printer"
local type, debug, coroutine = type, debug, coroutine

local function find_frame(thread, file, line)
-- Scan the first 16 frames for something interesting.
Expand All @@ -28,7 +28,7 @@ end

--[[- Check whether this error is an exception.
Currently we don't provide a stable API for throwing (and propogating) rich
Currently we don't provide a stable API for throwing (and propagating) rich
errors, like those supported by this module. In lieu of that, we describe the
exception protocol, which may be used by user-written coroutine managers to
throw exceptions which are pretty-printed by the shell:
Expand Down Expand Up @@ -64,6 +64,86 @@ local function is_exception(exn)
return mt and mt.__name == "exception" and type(rawget(exn, "message")) == "string" and type(rawget(exn, "thread")) == "thread"
end

local exn_mt = {
__name = "exception",
__tostring = function(self) return self.message end,
}

--[[- Create a new exception from a message and thread.
@tparam string message The exception message.
@tparam coroutine thread The coroutine the error occurred on.
@return The constructed exception.
]]
local function make_exception(message, thread)
return setmetatable({ message = message, thread = thread }, exn_mt)
end

--[[- A marker function for [`try`] and the wider exception machinery.
This function is typically the first function on the call stack. It acts as both
a signifier that this function is exception aware, and allows us to store
additional information for the exception machinery on the call stack.
@see can_wrap_errors
]]
local try_barrier = debug.getregistry().cc_try_barrier
if not try_barrier then
-- We define an extra "bounce" function to prevent f(...) being treated as a
-- tail call, and so ensure the barrier remains on the stack.
local function bounce(...) return ... end

--- @tparam { co = coroutine, can_wrap ?= boolean } parent The parent coroutine.
-- @tparam function f The function to call.
-- @param ... The arguments to this function.
try_barrier = function(parent, f, ...) return bounce(f(...)) end

debug.getregistry().cc_try_barrier = try_barrier
end

-- Functions that act as a barrier for exceptions.
local pcall_functions = { [pcall] = true, [xpcall] = true, [load] = true }

--[[- Check to see whether we can wrap errors into an exception.
This scans the current thread (up to a limit), and any parent threads, to
determine if there is a pcall anywhere on the callstack. If not, then we know
the error message is not observed by user code, and so may be wrapped into an
exception.
@tparam[opt] coroutine The thread to check. Defaults to the current thread.
@treturn boolean Whether we can wrap errors into exceptions.
]]
local function can_wrap_errors(thread)
if not thread then thread = coroutine.running() end

for offset = 0, 31 do
local frame = debug.getinfo(thread, offset, "f")
if not frame then return false end

local func = frame.func
if func == try_barrier then
-- If we've a try barrier, then extract the parent coroutine and
-- check if it can wrap errors.
local _, parent = debug.getlocal(thread, offset, 1)
if type(parent) ~= "table" or type(parent.co) ~= "thread" then return false end

local result = parent.can_wrap
if result == nil then
result = can_wrap_errors(parent.co)
parent.can_wrap = result
end

return result
elseif pcall_functions[func] then
-- If we're a pcall, then abort.
return false
end
end

return false
end

--[[- Attempt to call the provided function `func` with the provided arguments.
@tparam function func The function to call.
Expand All @@ -79,8 +159,8 @@ end
local function try(func, ...)
expect(1, func, "function")

local co = coroutine.create(func)
local result = table.pack(coroutine.resume(co, ...))
local co = coroutine.create(try_barrier)
local result = table.pack(coroutine.resume(co, { co = co, can_wrap = true }, func, ...))

while coroutine.status(co) ~= "dead" do
local event = table.pack(os.pullEventRaw(result[2]))
Expand Down Expand Up @@ -152,7 +232,7 @@ local function report(err, thread, source_map)
-- Could not determine the line. Bail.
if not line_contents or #line_contents == "" then return end

error_printer({
require("cc.internal.error_printer")({
get_pos = function() return line, column end,
get_line = function() return line_contents end,
}, {
Expand All @@ -162,6 +242,11 @@ end


return {
make_exception = make_exception,

try_barrier = try_barrier,
can_wrap_errors = can_wrap_errors,

try = try,
report = report,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
-- SPDX-FileCopyrightText: 2025 The CC: Tweaked Developers
--
-- SPDX-License-Identifier: MPL-2.0

--[[- A minimal implementation of require.
This is intended for use with APIs, and other internal code which is not run in
the [`shell`] environment. This allows us to avoid some of the overhead of
loading the full [`cc.require`] module.
> [!DANGER]
> This is an internal module and SHOULD NOT be used in your own code. It may
> be removed or changed at any time.
@local
@tparam string name The module to require.
@return The required module.
]]

local loaded = {}
local env = setmetatable({}, { __index = _G })
local function require(name)
local result = loaded[name]
if result then return result end

local path = "rom/modules/main/" .. name:gsub("%.", "/")
if fs.exists(path .. ".lua") then
result = assert(loadfile(path .. ".lua", nil, env))()
else
result = assert(loadfile(path .. "/init.lua", nil, env))()
end
loaded[name] = result
return result
end
env.require = require
return require
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ public final void submit(Map<?, ?> tbl) {
var wholeMessage = new StringBuilder();
if (message != null) wholeMessage.append(message);
if (trace != null) {
if (wholeMessage.length() != 0) wholeMessage.append('\n');
if (!wholeMessage.isEmpty()) wholeMessage.append('\n');
wholeMessage.append(trace);
}

Expand Down
4 changes: 2 additions & 2 deletions projects/core/src/test/resources/test-rom/mcfly.lua
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ local function format(value)
return "\"" .. escaped .. "\""
else
local ok, res = pcall(textutils.serialise, value)
if ok then return res else return tostring(value) end
if ok then return (res:gsub("\\\n", "\\n")) else return tostring(value) end
end
end

Expand Down Expand Up @@ -379,7 +379,7 @@ end
function expect_mt:str_match(pattern)
local actual_type = type(self.value)
if actual_type ~= "string" then
self:_fail(("Expected value of type string\nbut got %s"):format(actual_type))
self:_fail(("Expected value of type string\nbut got %s (of type %s)"):format(format(self.value), actual_type))
end
if not self.value:find(pattern) then
self:_fail(("Expected %q\n to match pattern %q"):format(self.value, pattern))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,63 @@ describe("The parallel library", function()
expect(exitCount):eq(3)
end)
end)

describe("exceptions", function()
local try = require "cc.internal.exception".try
local function check_failure(fn, ...)
local ok, message, thread = try(fn, ...)
expect(ok):eq(false)
expect(message):str_match("/parallel_spec.lua:%d+: Oh no$")
return thread
end

it("throws an exception when within a try", function()
local expected_thread
local thread = check_failure(parallel.waitForAny, function()
expected_thread = coroutine.running()
error("Oh no")
end)

expect(thread):eq(expected_thread)
end)

it("throws an exception when within a try (nested)", function()
local expected_thread
local thread = check_failure(parallel.waitForAny, function()
parallel.waitForAny(function()
expected_thread = coroutine.running()
error("Oh no")
end)
end)
expect(thread):eq(expected_thread)
end)

it("throws the raw error when within a pcall", function()
local expected_thread
local thread = check_failure(function()
expected_thread = coroutine.running()

local ok, err = pcall(parallel.waitForAny, function() error("Oh no") end)
expect(ok):eq(false)
expect(err):str_match("/parallel_spec.lua:%d+: Oh no$")
error(err, 0)
end)
expect(thread):eq(expected_thread)
end)

it("throws the raw error when within a pcall (nested)", function()
local expected_thread
local thread = check_failure(function()
expected_thread = coroutine.running()

local ok, err = pcall(parallel.waitForAny, function()
parallel.waitForAny(function() error("Oh no") end)
end)
expect(ok):eq(false)
expect(err):str_match("/parallel_spec.lua:%d+: Oh no$")
error(err, 0)
end)
expect(thread):eq(expected_thread)
end)
end)
end)

0 comments on commit 051c70a

Please sign in to comment.