Skip to content

Commit

Permalink
Propagate exceptions from parallel where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
SquidDev committed Feb 9, 2025
1 parent 88cb03b commit 00d9569
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ the other.
@since 1.2
]]

local exception = dofile("rom/modules/main/cc/internal/tiny_require.lua")("cc.internal.exception")

local function create(...)
local barrier_ctx = { co = coroutine.running() }

local functions = table.pack(...)
local threads = {}
for i = 1, functions.n, 1 do
Expand All @@ -48,7 +52,7 @@ local function create(...)
error("bad argument #" .. i .. " (function expected, got " .. type(fn) .. ")", 3)
end

threads[i] = { co = coroutine.create(fn), filter = nil }
threads[i] = { co = coroutine.create(function() return exception.try_barrier(barrier_ctx, fn) end), filter = nil }
end

return threads
Expand All @@ -65,11 +69,14 @@ local function runUntilLimit(threads, limit)
local thread = threads[i]
if thread and (thread.filter == nil or thread.filter == event[1] or event[1] == "terminate") then
local ok, param = coroutine.resume(thread.co, table.unpack(event, 1, event.n))
if not ok then
error(param, 0)
else
if ok then
thread.filter = param
elseif type(param) == "string" and exception.can_wrap_errors() then
error(exception.make_exception(param, thread.co))
else
error(param, 0)
end

if coroutine.status(thread.co) == "dead" then
threads[i] = false
living = living - 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
-- @module textutils
-- @since 1.2

local pgk_env = setmetatable({}, { __index = _ENV })
pgk_env.require = dofile("rom/modules/main/cc/require.lua").make(pgk_env, "rom/modules/main")
local require = pgk_env.require
local require = dofile("rom/modules/main/cc/internal/tiny_require.lua")

local expect = require("cc.expect")
local expect, field = expect.expect, expect.field
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
]]

local expect = require "cc.expect".expect
local error_printer = require "cc.internal.error_printer"
local type, debug, coroutine = type, debug, coroutine

local function find_frame(thread, file, line)
-- Scan the first 16 frames for something interesting.
Expand All @@ -28,7 +28,7 @@ end

--[[- Check whether this error is an exception.
Currently we don't provide a stable API for throwing (and propogating) rich
Currently we don't provide a stable API for throwing (and propagating) rich
errors, like those supported by this module. In lieu of that, we describe the
exception protocol, which may be used by user-written coroutine managers to
throw exceptions which are pretty-printed by the shell:
Expand Down Expand Up @@ -64,6 +64,86 @@ local function is_exception(exn)
return mt and mt.__name == "exception" and type(rawget(exn, "message")) == "string" and type(rawget(exn, "thread")) == "thread"
end

local exn_mt = {
__name = "exception",
__tostring = function(self) return self.message end,
}

--[[- Create a new exception from a message and thread.
@tparam string message The exception message.
@tparam coroutine thread The coroutine the error occurred on.
@return The constructed exception.
]]
local function make_exception(message, thread)
return setmetatable({ message = message, thread = thread }, exn_mt)
end

--[[- A marker function for [`try`] and the wider exception machinery.
This function is typically the first function on the call stack. It acts as both
a signifier that this function is exception aware, and allows us to store
additional information for the exception machinery on the call stack.
@see can_wrap_errors
]]
local try_barrier = debug.getregistry().cc_try_barrier
if not try_barrier then
-- We define an extra "bounce" function to prevent f(...) being treated as a
-- tail call, and so ensure the barrier remains on the stack.
local function bounce(...) return ... end

--- @tparam { co = coroutine, can_wrap ?= boolean } parent The parent coroutine.
-- @tparam function f The function to call.
-- @param ... The arguments to this function.
try_barrier = function(parent, f, ...) return bounce(f(...)) end

debug.getregistry().cc_try_barrier = try_barrier
end

-- Functions that act as a barrier for exceptions.
local pcall_functions = { [pcall] = true, [xpcall] = true, [load] = true }

--[[- Check to see whether we can wrap errors into an exception.
This scans the current thread (up to a limit), and any parent threads, to
determine if there is a pcall anywhere on the callstack. If not, then we know
the error message is not observed by user code, and so may be wrapped into an
exception.
@tparam[opt] coroutine The thread to check. Defaults to the current thread.
@treturn boolean Whether we can wrap errors into exceptions.
]]
local function can_wrap_errors(thread)
if not thread then thread = coroutine.running() end

for offset = 0, 31 do
local frame = debug.getinfo(thread, offset, "f")
if not frame then return false end

local func = frame.func
if func == try_barrier then
-- If we've a try barrier, then extract the parent coroutine and
-- check if it can wrap errors.
local _, parent = debug.getlocal(thread, offset, 1)
if type(parent) ~= "table" or type(parent.co) ~= "thread" then return false end

local result = parent.can_wrap
if result == nil then
result = can_wrap_errors(parent.co)
parent.can_wrap = result
end

return result
elseif pcall_functions[func] then
-- If we're a pcall, then abort.
return false
end
end

return false
end

--[[- Attempt to call the provided function `func` with the provided arguments.
@tparam function func The function to call.
Expand All @@ -79,8 +159,8 @@ end
local function try(func, ...)
expect(1, func, "function")

local co = coroutine.create(func)
local result = table.pack(coroutine.resume(co, ...))
local co = coroutine.create(try_barrier)
local result = table.pack(coroutine.resume(co, { co = co, can_wrap = true }, func, ...))

while coroutine.status(co) ~= "dead" do
local event = table.pack(os.pullEventRaw(result[2]))
Expand Down Expand Up @@ -152,7 +232,7 @@ local function report(err, thread, source_map)
-- Could not determine the line. Bail.
if not line_contents or #line_contents == "" then return end

error_printer({
require("cc.internal.error_printer")({
get_pos = function() return line, column end,
get_line = function() return line_contents end,
}, {
Expand All @@ -162,6 +242,11 @@ end


return {
make_exception = make_exception,

try_barrier = try_barrier,
can_wrap_errors = can_wrap_errors,

try = try,
report = report,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
-- SPDX-FileCopyrightText: 2025 The CC: Tweaked Developers
--
-- SPDX-License-Identifier: MPL-2.0

--[[- A minimal implementation of require.
This is intended for use with APIs, and other internal code which is not run in
the [`shell`] environment. This allows us to avoid some of the overhead of
loading the full [`cc.require`] module.
> [!DANGER]
> This is an internal module and SHOULD NOT be used in your own code. It may
> be removed or changed at any time.
@local
@tparam string name The module to require.
@return The required module.
]]

local loaded = {}
local env = setmetatable({}, { __index = _G })
local function require(name)
local result = loaded[name]
if result then return result end

local path = "rom/modules/main/" .. name:gsub("%.", "/")
if fs.exists(path .. ".lua") then
result = assert(loadfile(path .. ".lua", nil, env))()
else
result = assert(loadfile(path .. "/init.lua", nil, env))()
end
loaded[name] = result
return result
end
env.require = require
return require
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ public final void submit(Map<?, ?> tbl) {
var wholeMessage = new StringBuilder();
if (message != null) wholeMessage.append(message);
if (trace != null) {
if (wholeMessage.length() != 0) wholeMessage.append('\n');
if (!wholeMessage.isEmpty()) wholeMessage.append('\n');
wholeMessage.append(trace);
}

Expand Down
4 changes: 2 additions & 2 deletions projects/core/src/test/resources/test-rom/mcfly.lua
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ local function format(value)
return "\"" .. escaped .. "\""
else
local ok, res = pcall(textutils.serialise, value)
if ok then return res else return tostring(value) end
if ok then return (res:gsub("\\\n", "\\n")) else return tostring(value) end
end
end

Expand Down Expand Up @@ -379,7 +379,7 @@ end
function expect_mt:str_match(pattern)
local actual_type = type(self.value)
if actual_type ~= "string" then
self:_fail(("Expected value of type string\nbut got %s"):format(actual_type))
self:_fail(("Expected value of type string\nbut got %s (of type %s)"):format(format(self.value), actual_type))
end
if not self.value:find(pattern) then
self:_fail(("Expected %q\n to match pattern %q"):format(self.value, pattern))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,63 @@ describe("The parallel library", function()
expect(exitCount):eq(3)
end)
end)

describe("exceptions", function()
local try = require "cc.internal.exception".try
local function check_failure(fn, ...)
local ok, message, thread = try(fn, ...)
expect(ok):eq(false)
expect(message):str_match("/parallel_spec.lua:%d+: Oh no$")
return thread
end

it("throws an exception when within a try", function()
local expected_thread
local thread = check_failure(parallel.waitForAny, function()
expected_thread = coroutine.running()
error("Oh no")
end)

expect(thread):eq(expected_thread)
end)

it("throws an exception when within a try (nested)", function()
local expected_thread
local thread = check_failure(parallel.waitForAny, function()
parallel.waitForAny(function()
expected_thread = coroutine.running()
error("Oh no")
end)
end)
expect(thread):eq(expected_thread)
end)

it("throws the raw error when within a pcall", function()
local expected_thread
local thread = check_failure(function()
expected_thread = coroutine.running()

local ok, err = pcall(parallel.waitForAny, function() error("Oh no") end)
expect(ok):eq(false)
expect(err):str_match("/parallel_spec.lua:%d+: Oh no$")
error(err, 0)
end)
expect(thread):eq(expected_thread)
end)

it("throws the raw error when within a pcall (nested)", function()
local expected_thread
local thread = check_failure(function()
expected_thread = coroutine.running()

local ok, err = pcall(parallel.waitForAny, function()
parallel.waitForAny(function() error("Oh no") end)
end)
expect(ok):eq(false)
expect(err):str_match("/parallel_spec.lua:%d+: Oh no$")
error(err, 0)
end)
expect(thread):eq(expected_thread)
end)
end)
end)

0 comments on commit 00d9569

Please sign in to comment.