Skip to content

Implement changes of function reference proposal #2562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/wabt/binary-reader-logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class BinaryReaderLogging : public BinaryReaderDelegate {
Result OnCatchExpr(Index tag_index) override;
Result OnCatchAllExpr() override;
Result OnCallIndirectExpr(Index sig_index, Index table_index) override;
Result OnCallRefExpr() override;
Result OnCallRefExpr(Type sig_type) override;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think type sounds better than Index, although it could be an Index.

Result OnCompareExpr(Opcode opcode) override;
Result OnConvertExpr(Opcode opcode) override;
Result OnDelegateExpr(Index depth) override;
Expand Down
2 changes: 1 addition & 1 deletion include/wabt/binary-reader-nop.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class BinaryReaderNop : public BinaryReaderDelegate {
Result OnCallIndirectExpr(Index sig_index, Index table_index) override {
return Result::Ok;
}
Result OnCallRefExpr() override { return Result::Ok; }
Result OnCallRefExpr(Type sig_type) override { return Result::Ok; }
Result OnCatchExpr(Index tag_index) override { return Result::Ok; }
Result OnCatchAllExpr() override { return Result::Ok; }
Result OnCompareExpr(Opcode opcode) override { return Result::Ok; }
Expand Down
2 changes: 1 addition & 1 deletion include/wabt/binary-reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ class BinaryReaderDelegate {
Index default_target_depth) = 0;
virtual Result OnCallExpr(Index func_index) = 0;
virtual Result OnCallIndirectExpr(Index sig_index, Index table_index) = 0;
virtual Result OnCallRefExpr() = 0;
virtual Result OnCallRefExpr(Type sig_type) = 0;
virtual Result OnCatchExpr(Index tag_index) = 0;
virtual Result OnCatchAllExpr() = 0;
virtual Result OnCompareExpr(Opcode opcode) = 0;
Expand Down
42 changes: 33 additions & 9 deletions include/wabt/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,41 @@ namespace wabt {

struct Module;

enum class VarType {
enum class VarType : uint16_t {
Index,
Name,
};

struct Var {
// Var can represent variables or types.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a major internal change. It does not increase the size of Var. I think it is better than constructing a Var/Type pair.


// Represent a variable:
// has_opt_type() is false
// Only used by wast-parser

// Represent a type:
// has_opt_type() is true, is_index() is true
// type can be get by to_type()
// Binary reader only constructs this variant

// Represent both a variable and a type:
// has_opt_type() is true, is_name() is true
// A reference, which index is unknown
// Only used by wast-parser

explicit Var();
explicit Var(Index index, const Location& loc);
explicit Var(std::string_view name, const Location& loc);
explicit Var(Type type, const Location& loc);
Var(Var&&);
Var(const Var&);
Var& operator=(const Var&);
Var& operator=(Var&&);
~Var();

VarType type() const { return type_; }
bool is_index() const { return type_ == VarType::Index; }
bool is_name() const { return type_ == VarType::Name; }
bool has_opt_type() const { return opt_type_ < 0; }

Index index() const {
assert(is_index());
Expand All @@ -63,17 +80,25 @@ struct Var {
assert(is_name());
return name_;
}
Type::Enum opt_type() const {
assert(has_opt_type());
return static_cast<Type::Enum>(opt_type_);
}

void set_index(Index);
void set_name(std::string&&);
void set_name(std::string_view);
void set_opt_type(Type::Enum);
Type to_type() const;

Location loc;

private:
void Destroy();

VarType type_;
// Can be set to Type::Enum types, Type::Any represent no optional type.
int16_t opt_type_;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This range should be enough for a long time.

union {
Index index_;
std::string name_;
Expand Down Expand Up @@ -155,6 +180,7 @@ struct Const {
}
void set_funcref() { From<uintptr_t>(Type::FuncRef, 0); }
void set_externref(uintptr_t x) { From(Type::ExternRef, x); }
void set_extern(uintptr_t x) { From(Type(Type::ExternRef, Type::ReferenceNonNull), x); }
void set_null(Type type) { From<uintptr_t>(type, kRefNullBits); }

bool is_expected_nan(int lane = 0) const {
Expand Down Expand Up @@ -537,10 +563,10 @@ using MemoryCopyExpr = MemoryBinaryExpr<ExprType::MemoryCopy>;
template <ExprType TypeEnum>
class RefTypeExpr : public ExprMixin<TypeEnum> {
public:
RefTypeExpr(Type type, const Location& loc = Location())
RefTypeExpr(Var type, const Location& loc = Location())
: ExprMixin<TypeEnum>(loc), type(type) {}

Type type;
Var type;
};

using RefNullExpr = RefTypeExpr<ExprType::RefNull>;
Expand Down Expand Up @@ -662,8 +688,8 @@ using MemoryInitExpr = MemoryVarExpr<ExprType::MemoryInit>;

class SelectExpr : public ExprMixin<ExprType::Select> {
public:
SelectExpr(TypeVector type, const Location& loc = Location())
: ExprMixin<ExprType::Select>(loc), result_type(type) {}
SelectExpr(const Location& loc = Location())
: ExprMixin<ExprType::Select>(loc) {}
TypeVector result_type;
};

Expand Down Expand Up @@ -727,9 +753,7 @@ class CallRefExpr : public ExprMixin<ExprType::CallRef> {
explicit CallRefExpr(const Location& loc = Location())
: ExprMixin<ExprType::CallRef>(loc) {}

// This field is setup only during Validate phase,
// so keep that in mind when you use it.
Var function_type_index;
Var sig_type;
};

template <ExprType TypeEnum>
Expand Down
1 change: 1 addition & 0 deletions include/wabt/leb128.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ void WriteS32Leb128(Stream* stream, T value, const char* desc) {
size_t ReadU32Leb128(const uint8_t* p, const uint8_t* end, uint32_t* out_value);
size_t ReadU64Leb128(const uint8_t* p, const uint8_t* end, uint64_t* out_value);
size_t ReadS32Leb128(const uint8_t* p, const uint8_t* end, uint32_t* out_value);
size_t ReadS33Leb128(const uint8_t* p, const uint8_t* end, uint64_t* out_value);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The spec adds this new type, the output is uint64_t. It could be merged with ReadS32Leb128 by adding a bool* argument.

size_t ReadS64Leb128(const uint8_t* p, const uint8_t* end, uint64_t* out_value);

} // namespace wabt
Expand Down
42 changes: 28 additions & 14 deletions include/wabt/shared-validator.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ struct ValidateOptions {
class SharedValidator {
public:
WABT_DISALLOW_COPY_AND_ASSIGN(SharedValidator);
using FuncType = TypeChecker::FuncType;
SharedValidator(Errors*, const ValidateOptions& options);

// TODO: Move into SharedValidator?
Expand Down Expand Up @@ -141,7 +142,7 @@ class SharedValidator {
Result EndBrTable(const Location&);
Result OnCall(const Location&, Var func_var);
Result OnCallIndirect(const Location&, Var sig_var, Var table_var);
Result OnCallRef(const Location&, Index* function_type_index);
Result OnCallRef(const Location&, Var function_type_var);
Result OnCatch(const Location&, Var tag_var, bool is_catch_all);
Result OnCompare(const Location&, Opcode);
Result OnConst(const Location&, Type);
Expand Down Expand Up @@ -178,7 +179,7 @@ class SharedValidator {
Result OnNop(const Location&);
Result OnRefFunc(const Location&, Var func_var);
Result OnRefIsNull(const Location&);
Result OnRefNull(const Location&, Type type);
Result OnRefNull(const Location&, Var func_type_var);
Result OnRethrow(const Location&, Var depth);
Result OnReturnCall(const Location&, Var func_var);
Result OnReturnCallIndirect(const Location&, Var sig_var, Var table_var);
Expand Down Expand Up @@ -221,18 +222,6 @@ class SharedValidator {
Result OnUnreachable(const Location&);

private:
struct FuncType {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is moved to type checker.

FuncType() = default;
FuncType(const TypeVector& params,
const TypeVector& results,
Index type_index)
: params(params), results(results), type_index(type_index) {}

TypeVector params;
TypeVector results;
Index type_index;
};

struct StructType {
StructType() = default;
StructType(const TypeMutVector& fields) : fields(fields) {}
Expand Down Expand Up @@ -289,6 +278,25 @@ class SharedValidator {
Index end;
};

struct LocalReferenceMap {
Type type;
Index bit_index;
};

struct RecursionDetector {
RecursionDetector(SharedValidator *validator)
: validator(validator) {}

SharedValidator *validator;
bool recursion_found = false;
Index iteration = 0;
std::map<Index, Index> processed_func_types;
std::vector<Index> visited_func_types;

Result CheckRecursion(Type type,
const char* desc);
};

bool ValidInitOpcode(Opcode opcode) const;
Result CheckInstr(Opcode opcode, const Location& loc);
Result CheckType(const Location&,
Expand Down Expand Up @@ -336,6 +344,10 @@ class SharedValidator {

TypeVector ToTypeVector(Index count, const Type* types);

void SaveLocalRefs();
void RestoreLocalRefs(Result result);
void IgnoreLocalRefs();

ValidateOptions options_;
Errors* errors_;
TypeChecker typechecker_; // TODO: Move into SharedValidator.
Expand All @@ -361,6 +373,8 @@ class SharedValidator {
// Includes parameters, since this is only used for validating
// local.{get,set,tee} instructions.
std::vector<LocalDecl> locals_;
std::map<Index, LocalReferenceMap> local_refs_map_;
std::vector<bool> local_ref_is_set_;

std::set<std::string> export_names_; // Used to check for duplicates.
std::set<Index> declared_funcs_; // TODO: optimize?
Expand Down
1 change: 1 addition & 0 deletions include/wabt/token.def
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ WABT_TOKEN(Module, "module")
WABT_TOKEN(Mut, "mut")
WABT_TOKEN(NanArithmetic, "nan:arithmetic")
WABT_TOKEN(NanCanonical, "nan:canonical")
WABT_TOKEN(Null, "null")
WABT_TOKEN(Offset, "offset")
WABT_TOKEN(Output, "output")
WABT_TOKEN(PageSize, "pagesize")
Expand Down
22 changes: 19 additions & 3 deletions include/wabt/type-checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <functional>
#include <type_traits>
#include <vector>
#include <map>

#include "wabt/common.h"
#include "wabt/feature.h"
Expand All @@ -31,6 +32,18 @@ class TypeChecker {
public:
using ErrorCallback = std::function<void(const char* msg)>;

struct FuncType {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something that could be improved. We could create a separate comparison class, which contains only the CheckType method. Then we could do comparisons without instantiating a type checker.

FuncType() = default;
FuncType(const TypeVector& params,
const TypeVector& results,
Index type_index)
: params(params), results(results), type_index(type_index) {}

TypeVector params;
TypeVector results;
Index type_index;
};

struct Label {
Label(LabelType,
const TypeVector& param_types,
Expand All @@ -46,9 +59,11 @@ class TypeChecker {
TypeVector result_types;
size_t type_stack_limit;
bool unreachable;
std::vector<bool> local_ref_is_set_;
};

explicit TypeChecker(const Features& features) : features_(features) {}
explicit TypeChecker(const Features& features, std::map<Index, FuncType>& func_types)
: features_(features), func_types_(func_types) {}

void set_error_callback(const ErrorCallback& error_callback) {
error_callback_ = error_callback;
Expand Down Expand Up @@ -80,7 +95,7 @@ class TypeChecker {
Result OnCallIndirect(const TypeVector& param_types,
const TypeVector& result_types,
const Limits& table_limits);
Result OnIndexedFuncRef(Index* out_index);
Result OnCallRef(Type);
Result OnReturnCall(const TypeVector& param_types,
const TypeVector& result_types);
Result OnReturnCallIndirect(const TypeVector& param_types,
Expand Down Expand Up @@ -141,7 +156,7 @@ class TypeChecker {
Result BeginInitExpr(Type type);
Result EndInitExpr();

static Result CheckType(Type actual, Type expected);
Result CheckType(Type actual, Type expected);

private:
void WABT_PRINTF_FORMAT(2, 3) PrintError(const char* fmt, ...);
Expand Down Expand Up @@ -210,6 +225,7 @@ class TypeChecker {
// to represent "any".
TypeVector* br_table_sig_ = nullptr;
Features features_;
std::map<Index, FuncType>& func_types_;
};

} // namespace wabt
Expand Down
Loading
Loading