Skip to content

Commit

Permalink
[WIP] Hints for runtime errors
Browse files Browse the repository at this point in the history
  • Loading branch information
SquidDev committed Feb 9, 2025
1 parent 88cb03b commit e419318
Show file tree
Hide file tree
Showing 6 changed files with 393 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public CobaltLuaMachine(MachineEnvironment environment, InputStream bios) throws
var globals = state.globals();
CoreLibraries.debugGlobals(state);
Bit32Lib.add(state, globals);
ErrorContextLib.add(state);
globals.rawset("_HOST", ValueFactory.valueOf(environment.hostString()));
globals.rawset("_CC_DEFAULT_SETTINGS", ValueFactory.valueOf(CoreConfig.defaultComputerSettings));

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// SPDX-FileCopyrightText: 2025 The CC: Tweaked Developers
//
// SPDX-License-Identifier: MPL-2.0

package dan200.computercraft.core.lua;

import com.google.common.annotations.VisibleForTesting;
import org.squiddev.cobalt.*;
import org.squiddev.cobalt.debug.DebugFrame;
import org.squiddev.cobalt.function.LuaFunction;
import org.squiddev.cobalt.function.RegisteredFunction;

import javax.annotation.Nullable;
import java.util.List;

import static org.squiddev.cobalt.Lua.*;
import static org.squiddev.cobalt.debug.DebugFrame.FLAG_ANY_HOOK;

public class ErrorContextLib {
private static final int MAX_DEPTH = 8;

private static final RegisteredFunction[] functions = new RegisteredFunction[]{
RegisteredFunction.ofV("context", ErrorContextLib::getContext),
};

public static void add(LuaState state) throws LuaError {
state.registry().getSubTable(Constants.LOADED).rawset("cc.internal.error_context", RegisteredFunction.bind(functions));
}

private static Varargs getContext(LuaState state, Varargs args) throws LuaError {
var thread = args.arg(1).checkThread();
var level = args.arg(2).checkInteger();

var context = getContext(state, thread, level);
return context == null ? Constants.NIL : ValueFactory.varargsOf(ValueFactory.valueOf(context.op()), listOf(context.source()));
}

@VisibleForTesting
static @Nullable OpContext getContext(LuaState state, LuaThread thread, int level) {
var frame = thread.getDebugState().getFrame(level);
if (frame == null || frame.closure == null || (frame.flags & FLAG_ANY_HOOK) != 0) return null;

var prototype = frame.closure.getPrototype();
var pc = frame.pc;
var insn = prototype.code[pc];

// Find the register we're operating on.
return switch (GET_OPCODE(insn)) {
case OP_CALL, OP_TAILCALL ->
OpContext.of("call", resolveValueSource(state, frame, prototype, pc, GETARG_A(insn), 0));
case OP_GETTABLE, OP_SETTABLE ->
OpContext.of("index", resolveValueSource(state, frame, prototype, pc, GETARG_A(insn), 0));
default -> null;
};
}

@VisibleForTesting
record OpContext(String op, List<LuaValue> source) {
public static @Nullable OpContext of(String op, @Nullable List<LuaValue> values) {
return values == null ? null : new OpContext(op, values);
}
}

private static LuaTable listOf(List<? extends LuaValue> values) {
var table = new LuaTable(values.size(), 0);
for (var i = 0; i < values.size(); i++) table.rawset(i + 1, values.get(i));
return table;
}

@SuppressWarnings("NullTernary")
private static @Nullable List<LuaValue> resolveValueSource(LuaState state, DebugFrame frame, Prototype prototype, int pc, int register, int depth) {
if (depth > MAX_DEPTH) return null;
if (prototype.getLocalName(register + 1, pc) != null) {
return List.of(frame.stack[register]);
}

// Find where this register was set. If unknown, then abort.
pc = findSetReg(prototype, pc, register);
if (pc == -1) return null;

var insn = prototype.code[pc];
return switch (GET_OPCODE(insn)) {
case OP_MOVE -> {
var a = GETARG_A(insn);
var b = GETARG_B(insn); // move from `b' to `a'
yield b < a ? resolveValueSource(state, frame, prototype, pc, register, depth + 1) : null; // Resolve 'b' .
}
case OP_GETTABUP, OP_GETTABLE -> {
var table = GETARG_B(insn);
var key = GETARG_C(insn);
if (!ISK(key)) yield null;

var keyValue = prototype.constants[INDEXK(key)];
if (keyValue.type() != Constants.TSTRING) yield null;

var tbl = GET_OPCODE(insn) == OP_GETTABUP
? frame.closure.getUpvalue(table).getValue()
: evaluate(state, frame, prototype, pc, table, depth);
yield tbl == null ? null : List.of(tbl, keyValue);
}
default -> {
var value = evaluate(state, frame, prototype, pc, register, depth);
yield value == null ? null : List.of(value);
}
};
}

@SuppressWarnings("NullTernary")
private static @Nullable LuaValue evaluate(LuaState state, DebugFrame frame, Prototype prototype, int pc, int register, int depth) {
if (depth >= MAX_DEPTH) return null;

// If this is a local, then return its contents.
if (prototype.getLocalName(register + 1, pc) != null) return frame.stack[register];

// Otherwise find where this register was set. If unknown, then abort.
pc = findSetReg(prototype, pc, register);
if (pc == -1) return null;

var insn = prototype.code[pc];
return switch (GET_OPCODE(insn)) {
case OP_MOVE -> {
var a = GETARG_A(insn);
var b = GETARG_B(insn); // move from `b' to `a'
yield b < a ? evaluate(state, frame, prototype, pc, register, depth + 1) : null; // Resolve 'b'.
}
// Load constants
case OP_LOADK -> prototype.constants[GETARG_Bx(insn)];
case OP_LOADKX -> prototype.constants[GETARG_Ax(prototype.code[pc + 1])];
case OP_LOADBOOL -> GETARG_B(insn) == 0 ? Constants.FALSE : Constants.TRUE;
case OP_LOADNIL -> Constants.NIL;
// Upvalues and tables.
case OP_GETUPVAL -> frame.closure.getUpvalue(GETARG_B(insn)).getValue();
case OP_GETTABUP -> {
var table = frame.closure.getUpvalue(GETARG_B(insn)).getValue();
if (table == null) yield null;

var key = evaluateK(state, frame, prototype, pc, GETARG_C(insn), depth + 1);
yield key == null ? null : safeIndex(state, table, key);
}
case OP_GETTABLE -> {
var table = evaluate(state, frame, prototype, pc, GETARG_B(insn), depth + 1);
if (table == null) yield null;
var key = evaluateK(state, frame, prototype, pc, GETARG_C(insn), depth + 1);
yield key == null ? null : safeIndex(state, table, key);
}
default -> null;
};
}

private static @Nullable LuaValue evaluateK(LuaState state, DebugFrame frame, Prototype prototype, int pc, int registerOrConstant, int depth) {
return ISK(registerOrConstant) ? prototype.constants[INDEXK(registerOrConstant)] : evaluate(state, frame, prototype, pc, registerOrConstant, depth + 1);
}

private static @Nullable LuaValue safeIndex(LuaState state, LuaValue table, LuaValue key) {
var loop = 0;
do {
LuaValue metatable;
if (table instanceof LuaTable tbl) {
var res = tbl.rawget(key);
if (!res.isNil() || (metatable = tbl.metatag(state, CachedMetamethod.INDEX)).isNil()) return res;
} else if ((metatable = table.metatag(state, CachedMetamethod.INDEX)).isNil()) {
return null;
}

if (metatable instanceof LuaFunction) return null;

table = metatable;
}
while (++loop < Constants.MAXTAGLOOP);

return null;
}

// TODO: The below code is copied from Cobalt (and so is MIT). We should either make part of the public API
// (probably not), or put into a separate file with proper licensing.

private static int filterPc(int pc, int jumpTarget) {
return pc < jumpTarget ? -1 : pc;
}

private static int findSetReg(Prototype pt, int lastPc, int reg) {
var lastInsn = -1; // Last instruction that changed "reg";
var jumpTarget = 0; // Any code before this address is conditional

for (var pc = 0; pc < lastPc; pc++) {
var i = pt.code[pc];
var op = GET_OPCODE(i);
var a = GETARG_A(i);
switch (op) {
case OP_LOADNIL -> {
var b = GETARG_B(i);
if (a <= reg && reg <= a + b) lastInsn = filterPc(pc, jumpTarget);
}
case OP_TFORCALL -> {
if (a >= a + 2) lastInsn = filterPc(pc, jumpTarget);
}
case OP_CALL, OP_TAILCALL -> {
if (reg >= a) lastInsn = filterPc(pc, jumpTarget);
}
case OP_JMP -> {
var dest = pc + 1 + GETARG_sBx(i);
// If jump is forward and doesn't skip lastPc, update jump target
if (pc < dest && dest <= lastPc && dest > jumpTarget) jumpTarget = dest;
}
default -> {
if (testAMode(op) && reg == a) lastInsn = filterPc(pc, jumpTarget);
}
}
}
return lastInsn;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
-- SPDX-FileCopyrightText: 2025 The CC: Tweaked Developers
--
-- SPDX-License-Identifier: MPL-2.0

--[[- Internal tools for diagnosing errors and suggesting fixes.
> [!DANGER]
> This is an internal module and SHOULD NOT be used in your own code. It may
> be removed or changed at any time.
@local
]]

local debug, type, rawget = debug, type, rawget
local byte, floor, min, max = string.byte, math.floor, math.min, math.max

local function jaro_winkler(str_a, str_b)
local len_a, len_b = #str_a, #str_b
if len_a < 0 or len_a > 20 or len_b < 0 or len_b > 20 then return 0 end
if str_a == str_b then return 1 end

local max_dist = floor(max(len_a, len_b) / 2)

local common_chars = 0
local matches_a = {}
local matches_b = {}

for i = 1, len_a do
local char_a = byte(str_a, i)
for j = max(1, i - max_dist), min(len_b, i + max_dist) do
if char_a == byte(str_b, j) and not matches_b[j] then
matches_a[i] = true
matches_b[j] = true
common_chars = common_chars + 1
break
end
end
end

if common_chars == 0 then return 0 end

local transpositions = 0
local k = 1

for i = 1, len_a do
if matches_a[i] then
while not matches_b[k] do k = k + 1 end

if byte(str_a, i) ~= byte(str_b, k) then
transpositions = transpositions + 1
end

k = k + 1
end
end

-- Compute the Jaro similarity
local sim = (
(common_chars - floor(transpositions / 2)) / common_chars +
(common_chars / len_a) + (common_chars / len_b)
) / 3

local prefix = 0
for i = 1, 4 do
if byte(str_a, i) ~= byte(str_b, i) then break end
prefix = i
end

return min(1, sim + (prefix * 0.1 * (1 - sim)))
end

local function get_suggestions(source)
if #source ~= 2 then return end

local value, key = source[1], source[2]

-- Find all items in the table, and see if they seem similar.
local suggestions = {}
while type(value) == "table" do
for k in next, value do
if type(k) == "string" then
local similarity = jaro_winkler(k, key)
if similarity >= 0.9 then
suggestions[#suggestions + 1] = { value = k, sim = similarity }
end
end
end

local mt = debug.getmetatable(value)
if mt == nil then break end
value = rawget(mt, "__index")
end

table.sort(suggestions, function(a, b) return a.sim > b.sim end)

return suggestions
end

--[[- Get a tip to display at the end of an error.
@tparam string err The error message.
@tparam coroutine thread The current thread.
@tparam number frame_offset The offset into the thread where the current frame exists
@return An optional message to append to the error.
]]
local function get_tip(err, thread, frame_offset)
local nil_op = err:match("^attempt to (%l+) .* %(a nil value%)")
if not nil_op then return end

local has_error_context, error_context = pcall(require, "cc.internal.error_context")
if not has_error_context then return end
local op, source = error_context.context(thread, frame_offset)
if op == nil or op ~= nil_op then return end

local suggestions = get_suggestions(source)
if not suggestions or next(suggestions) == nil then return end

local pretty = require "cc.pretty"
local msg = "Did you mean: "

local n_suggestions = min(3, #suggestions)
for i = 1, n_suggestions do
if i > 1 then
if i == n_suggestions then msg = msg .. " or " else msg = msg .. ", " end
end
msg = msg .. pretty.text(suggestions[i].value, colours.lightGrey)
end
return msg .. "?"
end

return { get_tip = get_tip }
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ local function find_frame(thread, file, line)
if not frame then break end

if frame.short_src == file and frame.what ~= "C" and frame.currentline == line then
return frame
return offset, frame
end
end
end
Expand Down Expand Up @@ -111,11 +111,11 @@ local function report(err, thread, source_map)

if type(err) ~= "string" then return end

local file, line = err:match("^([^:]+):(%d+):")
local file, line, err = err:match("^([^:]+):(%d+): (.*)")
if not file then return end
line = tonumber(line)

local frame = find_frame(thread, file, line)
local frame_offset, frame = find_frame(thread, file, line)
if not frame or not frame.currentcolumn then return end

local column = frame.currentcolumn
Expand Down Expand Up @@ -157,6 +157,7 @@ local function report(err, thread, source_map)
get_line = function() return line_contents end,
}, {
{ tag = "annotate", start_pos = column, end_pos = column, msg = "" },
require "cc.internal.error_hints".get_tip(err, thread, frame_offset),
})
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ setmetatable(tEnv, { __index = _ENV })
do
local make_package = require "cc.require".make
local dir = shell.dir()
_ENV.require, _ENV.package = make_package(_ENV, dir)
tEnv.require, tEnv.package = make_package(tEnv, dir)
end

if term.isColour() then
Expand Down
Loading

0 comments on commit e419318

Please sign in to comment.