From a6e1e9afdf2433ecde2aa3c65578dcb088356ca6 Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Mon, 13 Apr 2026 15:55:42 -0700 Subject: [PATCH 1/3] [NFC] Simplify lexer and move to header The lexer previously used its own internal `LexCtx` abstraction that allowed it to consume the characters that made up a token without changing the lexer state, then update the state at once when committing to consuming the characters. However, manually resetting the lexer to the original position when giving up on parsing a token is simple enough that this abstraction was not holding its weight. Simplify the lexer by removing internal contexts, and move the simplified method bodies to lexer.h. Generally we try to avoid putting lots of code in headers, but in this case making the code available to the inliner, along with removing the extra layer of abstraction, makes the parser about 20% faster. --- src/parser/CMakeLists.txt | 1 - src/parser/contexts.h | 2 +- src/parser/lexer.cpp | 1188 ------------------------------------- src/parser/lexer.h | 1093 +++++++++++++++++++++++++++++++--- test/gtest/wat-lexer.cpp | 11 + 5 files changed, 1036 insertions(+), 1259 deletions(-) delete mode 100644 src/parser/lexer.cpp diff --git a/src/parser/CMakeLists.txt b/src/parser/CMakeLists.txt index 8b7846ca9e9..7d4704dba24 100644 --- a/src/parser/CMakeLists.txt +++ b/src/parser/CMakeLists.txt @@ -2,7 +2,6 @@ FILE(GLOB parser_HEADERS *.h) set(parser_SOURCES context-decls.cpp context-defs.cpp - lexer.cpp parse-1-decls.cpp parse-2-typedefs.cpp parse-3-implicit-types.cpp diff --git a/src/parser/contexts.h b/src/parser/contexts.h index fbdd8d0505a..eb09a0bb3b0 100644 --- a/src/parser/contexts.h +++ b/src/parser/contexts.h @@ -1985,7 +1985,7 @@ struct ParseDefsCtx : TypeParserCtx, AnnotationParserCtx { void setSrcLoc(const std::vector& annotations) { const Annotation* annotation = nullptr; for (auto& a : annotations) { - if (a.kind == srcAnnotationKind) { + if (a.kind.str == std::string_view("src")) { 
annotation = &a; } } diff --git a/src/parser/lexer.cpp b/src/parser/lexer.cpp deleted file mode 100644 index 5d2aedabe66..00000000000 --- a/src/parser/lexer.cpp +++ /dev/null @@ -1,1188 +0,0 @@ -/* - * Copyright 2023 WebAssembly Community Group participants - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include - -#include "lexer.h" -#include "support/bits.h" -#include "support/string.h" - -using namespace std::string_view_literals; - -namespace wasm::WATParser { - -Name srcAnnotationKind("src"); - -namespace { - -// ================ -// Lexical Analysis -// ================ - -// The result of lexing a token fragment. -struct LexResult { - std::string_view span; -}; - -// Lexing context that accumulates lexed input to produce a token fragment. -struct LexCtx { -private: - // The input we are lexing. - std::string_view input; - - // How much of the input we have already lexed. - size_t lexedSize = 0; - -public: - explicit LexCtx(std::string_view in) : input(in) {} - - // Return the fragment that has been lexed so far. - std::optional lexed() const { - if (lexedSize > 0) { - return {LexResult{input.substr(0, lexedSize)}}; - } - return {}; - } - - // The next input that has not already been lexed. - std::string_view next() const { return input.substr(lexedSize); } - - // Get the next character without consuming it. 
- uint8_t peek() const { return next()[0]; } - - // The size of the unlexed input. - size_t size() const { return input.size() - lexedSize; } - - // Whether there is no more input. - bool empty() const { return size() == 0; } - - // Tokens must be separated by spaces or parentheses. - bool canFinish() const; - - // Whether the unlexed input starts with prefix `sv`. - size_t startsWith(std::string_view sv) const { - return next().substr(0, sv.size()) == sv; - } - - // Consume the next `n` characters. - void take(size_t n) { lexedSize += n; } - - // Consume an additional lexed fragment. - void take(const LexResult& res) { lexedSize += res.span.size(); } - - // Consume the prefix and return true if possible. - bool takePrefix(std::string_view sv) { - if (startsWith(sv)) { - take(sv.size()); - return true; - } - return false; - } - - // Consume the rest of the input. - void takeAll() { lexedSize = input.size(); } -}; - -enum OverflowBehavior { DisallowOverflow, IgnoreOverflow }; - -std::optional getDigit(char c) { - if ('0' <= c && c <= '9') { - return c - '0'; - } - return {}; -} - -std::optional getHexDigit(char c) { - if ('0' <= c && c <= '9') { - return c - '0'; - } - if ('A' <= c && c <= 'F') { - return 10 + c - 'A'; - } - if ('a' <= c && c <= 'f') { - return 10 + c - 'a'; - } - return {}; -} - -enum Sign { NoSign, Pos, Neg }; - -// The result of lexing an integer token fragment. -struct LexIntResult : LexResult { - uint64_t n; - Sign sign; - - template bool isUnsigned() { - static_assert(std::is_integral_v && std::is_unsigned_v); - return sign == NoSign && n <= std::numeric_limits::max(); - } - - template bool isSigned() { - static_assert(std::is_integral_v && std::is_signed_v); - if (sign == Neg) { - return uint64_t(std::numeric_limits::min()) <= n || n == 0; - } - return n <= uint64_t(std::numeric_limits::max()); - } -}; - -// Lexing context that accumulates lexed input to produce an integer token -// fragment. 
-struct LexIntCtx : LexCtx { - using LexCtx::take; - -private: - uint64_t n = 0; - Sign sign = NoSign; - bool overflow = false; - -public: - explicit LexIntCtx(std::string_view in) : LexCtx(in) {} - - // Lex only the underlying span, ignoring the overflow and value. - std::optional lexedRaw() { - if (auto basic = LexCtx::lexed()) { - return LexIntResult{*basic, 0, NoSign}; - } - return {}; - } - - std::optional lexed() { - if (overflow) { - return {}; - } - if (auto basic = LexCtx::lexed()) { - return LexIntResult{*basic, sign == Neg ? -n : n, sign}; - } - return {}; - } - - void takeSign() { - if (takePrefix("+"sv)) { - sign = Pos; - } else if (takePrefix("-"sv)) { - sign = Neg; - } else { - sign = NoSign; - } - } - - bool takeDigit() { - if (!empty()) { - if (auto d = getDigit(peek())) { - take(1); - uint64_t newN = n * 10 + *d; - if (newN < n) { - overflow = true; - } - n = newN; - return true; - } - } - return false; - } - - bool takeHexdigit() { - if (!empty()) { - if (auto h = getHexDigit(peek())) { - take(1); - uint64_t newN = n * 16 + *h; - if (newN < n) { - overflow = true; - } - n = newN; - return true; - } - } - return false; - } - - void take(const LexIntResult& res) { - LexCtx::take(res); - n = res.n; - } -}; - -struct LexFloatResult : LexResult { - // The payload if we lexed a nan with payload. We cannot store the payload - // directly in `d` because we do not know at this point whether we are parsing - // an f32 or f64 and therefore we do not know what the allowable payloads are. - // No payload with NaN means to use the default payload for the expected float - // width. 
- std::optional nanPayload; - double d; -}; - -struct LexFloatCtx : LexCtx { - std::optional nanPayload; - - LexFloatCtx(std::string_view in) : LexCtx(in) {} - - std::optional lexed() { - const double posNan = std::copysign(NAN, 1.0); - const double negNan = std::copysign(NAN, -1.0); - assert(!std::signbit(posNan) && "expected positive NaN to be positive"); - assert(std::signbit(negNan) && "expected negative NaN to be negative"); - auto basic = LexCtx::lexed(); - if (!basic) { - return {}; - } - // strtod does not return NaNs with the expected signs on all platforms. - // TODO: use starts_with once we have C++20. - if (basic->span.substr(0, 3) == "nan"sv || - basic->span.substr(0, 4) == "+nan"sv) { - return LexFloatResult{*basic, nanPayload, posNan}; - } - if (basic->span.substr(0, 4) == "-nan"sv) { - return LexFloatResult{*basic, nanPayload, negNan}; - } - // Do not try to implement fully general and precise float parsing - // ourselves. Instead, call out to std::strtod to do our parsing. This means - // we need to strip any underscores since `std::strtod` does not understand - // them. - std::stringstream ss; - for (const char *curr = basic->span.data(), - *end = curr + basic->span.size(); - curr != end; - ++curr) { - if (*curr != '_') { - ss << *curr; - } - } - std::string str = ss.str(); - char* last; - double d = std::strtod(str.data(), &last); - assert(last == str.data() + str.size() && "could not parse float"); - return LexFloatResult{*basic, {}, d}; - } -}; - -struct LexStrResult : LexResult { - // Allocate a string only if there are escape sequences, otherwise just use - // the original string_view. - std::optional str; - - std::string_view getStr() { - if (str) { - return *str; - } - return span; - } -}; - -struct LexStrCtx : LexCtx { -private: - // Used to build a string with resolved escape sequences. Only used when the - // parsed string contains escape sequences, otherwise we can just use the - // parsed string directly. 
- std::optional escapeBuilder; - -public: - LexStrCtx(std::string_view in) : LexCtx(in) {} - - std::optional lexed() { - if (auto basic = LexCtx::lexed()) { - if (escapeBuilder) { - return LexStrResult{*basic, {escapeBuilder->str()}}; - } else { - return LexStrResult{*basic, {}}; - } - } - return {}; - } - - void takeChar() { - if (escapeBuilder) { - *escapeBuilder << peek(); - } - LexCtx::take(1); - } - - void ensureBuildingEscaped() { - if (escapeBuilder) { - return; - } - // Drop the opening '"'. - escapeBuilder = std::stringstream{}; - *escapeBuilder << LexCtx::lexed()->span.substr(1); - } - - void appendEscaped(char c) { *escapeBuilder << c; } - - bool appendUnicode(uint64_t u) { - if ((0xd800 <= u && u < 0xe000) || 0x110000 <= u) { - return false; - } - String::writeWTF8CodePoint(*escapeBuilder, u); - return true; - } -}; - -struct LexIdResult : LexResult { - bool isStr = false; - std::optional str; -}; - -struct LexIdCtx : LexCtx { - bool isStr = false; - std::optional str; - - LexIdCtx(std::string_view in) : LexCtx(in) {} - - std::optional lexed() { - if (auto basic = LexCtx::lexed()) { - return LexIdResult{*basic, isStr, str}; - } - return {}; - } -}; - -struct LexAnnotationResult : LexResult { - Annotation annotation; -}; - -struct LexAnnotationCtx : LexCtx { - std::string_view kind; - size_t kindSize = 0; - std::string_view contents; - size_t contentsSize = 0; - - explicit LexAnnotationCtx(std::string_view in) : LexCtx(in) {} - - void startKind() { kind = next(); } - - void takeKind(size_t size) { - kindSize += size; - take(size); - } - - void setKind(std::string_view kind) { - this->kind = kind; - kindSize = kind.size(); - } - - void startContents() { contents = next(); } - - void takeContents(size_t size) { - contentsSize += size; - take(size); - } - - std::optional lexed() { - if (auto basic = LexCtx::lexed()) { - return LexAnnotationResult{ - *basic, - {Name(kind.substr(0, kindSize)), contents.substr(0, contentsSize)}}; - } - return std::nullopt; - } 
-}; - -std::optional idchar(std::string_view); -std::optional space(std::string_view); -std::optional keyword(std::string_view); -std::optional integer(std::string_view); -std::optional float_(std::string_view); -std::optional str(std::string_view); -std::optional ident(std::string_view); - -// annotation ::= ';;@' [^\n]* | '(@'idchar+ annotelem* ')' -// annotelem ::= keyword | reserved | uN | sN | fN | string | id -// | '(' annotelem* ')' | '(@'idchar+ annotelem* ')' -std::optional annotation(std::string_view in) { - LexAnnotationCtx ctx(in); - if (ctx.takePrefix(";;@"sv)) { - ctx.setKind(srcAnnotationKind.str); - ctx.startContents(); - if (auto size = ctx.next().find('\n'); size != ""sv.npos) { - ctx.takeContents(size); - } else { - ctx.takeContents(ctx.next().size()); - } - } else if (ctx.takePrefix("(@"sv)) { - ctx.startKind(); - bool hasIdchar = false; - while (auto lexed = idchar(ctx.next())) { - ctx.takeKind(1); - hasIdchar = true; - } - if (!hasIdchar) { - return std::nullopt; - } - ctx.startContents(); - size_t depth = 1; - while (true) { - if (ctx.empty()) { - return std::nullopt; - } - if (auto lexed = space(ctx.next())) { - ctx.takeContents(lexed->span.size()); - continue; - } - if (auto lexed = keyword(ctx.next())) { - ctx.takeContents(lexed->span.size()); - continue; - } - if (auto lexed = integer(ctx.next())) { - ctx.takeContents(lexed->span.size()); - continue; - } - if (auto lexed = float_(ctx.next())) { - ctx.takeContents(lexed->span.size()); - continue; - } - if (auto lexed = str(ctx.next())) { - ctx.takeContents(lexed->span.size()); - continue; - } - if (auto lexed = ident(ctx.next())) { - ctx.takeContents(lexed->span.size()); - continue; - } - if (ctx.startsWith("(@"sv)) { - ctx.takeContents(2); - bool hasIdchar = false; - while (auto lexed = idchar(ctx.next())) { - ctx.takeContents(1); - hasIdchar = true; - } - if (!hasIdchar) { - return std::nullopt; - } - ++depth; - continue; - } - if (ctx.startsWith("("sv)) { - ctx.takeContents(1); - 
++depth; - continue; - } - if (ctx.startsWith(")"sv)) { - --depth; - if (depth == 0) { - ctx.take(1); - break; - } - ctx.takeContents(1); - continue; - } - // Unrecognized token. - return std::nullopt; - } - } - return ctx.lexed(); -} - -// comment ::= linecomment | blockcomment -// linecomment ::= ';;' linechar* ('\n' | eof) -// linechar ::= c:char (if c != '\n') -// blockcomment ::= '(;' blockchar* ';)' -// blockchar ::= c:char (if c != ';' and c != '(') -// | ';' (if the next char is not ')') -// | '(' (if the next char is not ';') -// | blockcomment -std::optional comment(std::string_view in) { - LexCtx ctx(in); - if (ctx.size() < 2) { - return {}; - } - - // Line comment - if (!ctx.startsWith(";;@"sv) && ctx.takePrefix(";;"sv)) { - if (auto size = ctx.next().find('\n'); size != ""sv.npos) { - ctx.take(size); - } else { - ctx.takeAll(); - } - return ctx.lexed(); - } - - // Block comment (possibly nested!) - if (ctx.takePrefix("(;"sv)) { - size_t depth = 1; - while (depth > 0 && ctx.size() >= 2) { - if (ctx.takePrefix("(;"sv)) { - ++depth; - } else if (ctx.takePrefix(";)"sv)) { - --depth; - } else { - ctx.take(1); - } - } - if (depth > 0) { - // TODO: Add error production for non-terminated block comment. - return {}; - } - return ctx.lexed(); - } - - return {}; -} - -std::optional spacechar(std::string_view in) { - LexCtx ctx(in); - ctx.takePrefix(" "sv) || ctx.takePrefix("\n"sv) || ctx.takePrefix("\r"sv) || - ctx.takePrefix("\t"sv); - return ctx.lexed(); -} - -// space ::= (' ' | format | comment)* -// format ::= '\t' | '\n' | '\r' -std::optional space(std::string_view in) { - LexCtx ctx(in); - while (ctx.size()) { - if (auto lexed = spacechar(ctx.next())) { - ctx.take(*lexed); - } else if (auto lexed = comment(ctx.next())) { - ctx.take(*lexed); - } else { - break; - } - } - return ctx.lexed(); -} - -bool LexCtx::canFinish() const { - // Logically we want to check for eof, parens, and space. 
But we don't - // actually want to parse more than a couple characters of space, so check for - // individual space chars or comment starts instead. - return empty() || startsWith("("sv) || startsWith(")"sv) || - spacechar(next()) || startsWith(";;"sv); -} - -// num ::= d:digit => d -// | n:num '_'? d:digit => 10*n + d -// digit ::= '0' => 0 | ... | '9' => 9 -std::optional num(std::string_view in, - OverflowBehavior overflow = DisallowOverflow) { - LexIntCtx ctx(in); - if (ctx.empty()) { - return {}; - } - if (!ctx.takeDigit()) { - return {}; - } - while (true) { - bool under = ctx.takePrefix("_"sv); - if (!ctx.takeDigit()) { - if (!under) { - return overflow == DisallowOverflow ? ctx.lexed() : ctx.lexedRaw(); - } - // TODO: Add error production for trailing underscore. - return {}; - } - } -} - -// hexnum ::= h:hexdigit => h -// | n:hexnum '_'? h:hexdigit => 16*n + h -// hexdigit ::= d:digit => d -// | 'A' => 10 | ... | 'F' => 15 -// | 'a' => 10 | ... | 'f' => 15 -std::optional -hexnum(std::string_view in, OverflowBehavior overflow = DisallowOverflow) { - LexIntCtx ctx(in); - if (!ctx.takeHexdigit()) { - return {}; - } - while (true) { - bool under = ctx.takePrefix("_"sv); - if (!ctx.takeHexdigit()) { - if (!under) { - return overflow == DisallowOverflow ? ctx.lexed() : ctx.lexedRaw(); - } - // TODO: Add error production for trailing underscore. - return {}; - } - } -} - -// uN ::= n:num => n (if n < 2^N) -// | '0x' n:hexnum => n (if n < 2^N) -// sN ::= s:sign n:num => [s]n (if -2^(N-1) <= [s]n < 2^(N-1)) -// | s:sign '0x' n:hexnum => [s]n (if -2^(N-1) <= [s]n < 2^(N-1)) -// sign ::= {} => + | '+' => + | '-' => - -// -// Note: Defer bounds and sign checking until we know what kind of integer we -// expect. 
-std::optional integer(std::string_view in) { - LexIntCtx ctx(in); - ctx.takeSign(); - if (ctx.takePrefix("0x"sv)) { - if (auto lexed = hexnum(ctx.next())) { - ctx.take(*lexed); - if (ctx.canFinish()) { - return ctx.lexed(); - } - } - // TODO: Add error production for unrecognized hexnum. - return {}; - } - if (auto lexed = num(ctx.next())) { - ctx.take(*lexed); - if (ctx.canFinish()) { - return ctx.lexed(); - } - } - return {}; -} - -// float ::= p:num '.'? => p -// | p:num '.' q:frac => p + q -// | p:num '.'? ('E'|'e') s:sign e:num => p * 10^([s]e) -// | p:num '.' q:frac ('E'|'e') s:sign e:num => (p + q) * 10^([s]e) -// frac ::= d:digit => d/10 -// | d:digit '_'? p:frac => (d + p/10) / 10 -std::optional decfloat(std::string_view in) { - LexCtx ctx(in); - if (auto lexed = num(ctx.next(), IgnoreOverflow)) { - ctx.take(*lexed); - } else { - return {}; - } - // Optional '.' followed by optional frac - if (ctx.takePrefix("."sv)) { - if (auto lexed = num(ctx.next(), IgnoreOverflow)) { - ctx.take(*lexed); - } - } - if (ctx.takePrefix("E"sv) || ctx.takePrefix("e"sv)) { - // Optional sign - ctx.takePrefix("+"sv) || ctx.takePrefix("-"sv); - if (auto lexed = num(ctx.next(), IgnoreOverflow)) { - ctx.take(*lexed); - } else { - // TODO: Add error production for missing exponent. - return {}; - } - } - return ctx.lexed(); -} - -// hexfloat ::= '0x' p:hexnum '.'? => p -// | '0x' p:hexnum '.' q:hexfrac => p + q -// | '0x' p:hexnum '.'? ('P'|'p') s:sign e:num => p * 2^([s]e) -// | '0x' p:hexnum '.' q:hexfrac ('P'|'p') s:sign e:num -// => (p + q) * 2^([s]e) -// hexfrac ::= h:hexdigit => h/16 -// | h:hexdigit '_'? p:hexfrac => (h + p/16) / 16 -std::optional hexfloat(std::string_view in) { - LexCtx ctx(in); - if (!ctx.takePrefix("0x"sv)) { - return {}; - } - if (auto lexed = hexnum(ctx.next(), IgnoreOverflow)) { - ctx.take(*lexed); - } else { - return {}; - } - // Optional '.' 
followed by optional hexfrac - if (ctx.takePrefix("."sv)) { - if (auto lexed = hexnum(ctx.next(), IgnoreOverflow)) { - ctx.take(*lexed); - } - } - if (ctx.takePrefix("P"sv) || ctx.takePrefix("p"sv)) { - // Optional sign - ctx.takePrefix("+"sv) || ctx.takePrefix("-"sv); - if (auto lexed = num(ctx.next(), IgnoreOverflow)) { - ctx.take(*lexed); - } else { - // TODO: Add error production for missing exponent. - return {}; - } - } - return ctx.lexed(); -} - -// fN ::= s:sign z:fNmag => [s]z -// fNmag ::= z:float => float_N(z) (if float_N(z) != +/-infinity) -// | z:hexfloat => float_N(z) (if float_N(z) != +/-infinity) -// | 'inf' => infinity -// | 'nan' => nan(2^(signif(N)-1)) -// | 'nan:0x' n:hexnum => nan(n) (if 1 <= n < 2^signif(N)) -std::optional float_(std::string_view in) { - LexFloatCtx ctx(in); - // Optional sign - ctx.takePrefix("+"sv) || ctx.takePrefix("-"sv); - if (auto lexed = hexfloat(ctx.next())) { - ctx.take(*lexed); - } else if (auto lexed = decfloat(ctx.next())) { - ctx.take(*lexed); - } else if (ctx.takePrefix("inf"sv)) { - // nop - } else if (ctx.takePrefix("nan"sv)) { - if (ctx.takePrefix(":0x"sv)) { - if (auto lexed = hexnum(ctx.next())) { - ctx.take(*lexed); - ctx.nanPayload = lexed->n; - } else { - // TODO: Add error production for malformed NaN payload. - return {}; - } - } else { - // No explicit payload necessary; we will inject the default payload - // later. - } - } else { - return {}; - } - if (ctx.canFinish()) { - return ctx.lexed(); - } - return {}; -} - -// idchar ::= '0' | ... | '9' -// | 'A' | ... | 'Z' -// | 'a' | ... | 'z' -// | '!' | '#' | '$' | '%' | '&' | ''' | '*' | '+' -// | '-' | '.' | '/' | ':' | '<' | '=' | '>' | '?' -// | '@' | '\' | '^' | '_' | '`' | '|' | '~' -std::optional idchar(std::string_view in) { - LexCtx ctx(in); - if (ctx.empty()) { - return {}; - } - uint8_t c = ctx.peek(); - // All the allowed characters lie in the range '!' 
to '~', and within that - // range the vast majority of characters are allowed, so it is significantly - // faster to check for the disallowed characters instead. - if (c < '!' || c > '~') { - return ctx.lexed(); - } - switch (c) { - case '"': - case '(': - case ')': - case ',': - case ';': - case '[': - case ']': - case '{': - case '}': - return ctx.lexed(); - } - ctx.take(1); - return ctx.lexed(); -} - -// string ::= '"' (b*:stringelem)* '"' => concat((b*)*) -// (if |concat((b*)*)| < 2^32) -// stringelem ::= c:stringchar => utf8(c) -// | '\' n:hexdigit m:hexdigit => 16*n + m -// stringchar ::= c:char => c -// (if c >= U+20 && c != U+7f && c != '"' && c != '\') -// | '\t' => \t | '\n' => \n | '\r' => \r -// | '\\' => \ | '\"' => " | '\'' => ' -// | '\u{' n:hexnum '}' => U+(n) -// (if n < 0xD800 and 0xE000 <= n <= 0x110000) -std::optional str(std::string_view in) { - LexStrCtx ctx(in); - if (!ctx.takePrefix("\""sv)) { - return {}; - } - while (!ctx.takePrefix("\""sv)) { - if (ctx.empty()) { - // TODO: Add error production for unterminated string. - return {}; - } - if (ctx.startsWith("\\"sv)) { - // Escape sequences - ctx.ensureBuildingEscaped(); - ctx.take(1); - if (ctx.takePrefix("t"sv)) { - ctx.appendEscaped('\t'); - } else if (ctx.takePrefix("n"sv)) { - ctx.appendEscaped('\n'); - } else if (ctx.takePrefix("r"sv)) { - ctx.appendEscaped('\r'); - } else if (ctx.takePrefix("\\"sv)) { - ctx.appendEscaped('\\'); - } else if (ctx.takePrefix("\""sv)) { - ctx.appendEscaped('"'); - } else if (ctx.takePrefix("'"sv)) { - ctx.appendEscaped('\''); - } else if (ctx.takePrefix("u{"sv)) { - auto lexed = hexnum(ctx.next()); - if (!lexed) { - // TODO: Add error production for malformed unicode escapes. - return {}; - } - ctx.take(*lexed); - if (!ctx.takePrefix("}"sv)) { - // TODO: Add error production for malformed unicode escapes. - return {}; - } - if (!ctx.appendUnicode(lexed->n)) { - // TODO: Add error production for invalid unicode values. 
- return {}; - } - } else { - LexIntCtx ictx(ctx.next()); - if (!ictx.takeHexdigit() || !ictx.takeHexdigit()) { - // TODO: Add error production for unrecognized escape sequence. - return {}; - } - auto lexed = *ictx.lexed(); - ctx.take(lexed); - ctx.appendEscaped(char(lexed.n)); - } - } else { - // Normal characters - if (uint8_t c = ctx.peek(); c >= 0x20 && c != 0x7F) { - ctx.takeChar(); - } else { - // TODO: Add error production for unescaped control characters. - return {}; - } - } - } - return ctx.lexed(); -} - -// id ::= '$' idchar+ | '$' str -std::optional ident(std::string_view in) { - LexIdCtx ctx(in); - if (!ctx.takePrefix("$"sv)) { - return {}; - } - // Quoted identifier e.g. $"foo" - if (auto s = str(ctx.next())) { - if (!String::isUTF8(s->getStr())) { - return {}; - } - - // empty names, including $"" are not allowed. - if (s->span == "\"\"") { - return {}; - } - - ctx.isStr = true; - ctx.str = s->str; - ctx.take(*s); - } else if (auto lexed = idchar(ctx.next())) { - ctx.take(*lexed); - while (auto lexed = idchar(ctx.next())) { - ctx.take(*lexed); - } - } else { - return {}; - } - if (ctx.canFinish()) { - return ctx.lexed(); - } - return {}; -} - -// keyword ::= ( 'a' | ... | 'z' ) idchar* (if literal terminal in grammar) -// reserved ::= idchar+ -// -// The "keyword" token we lex here covers both keywords as well as any reserved -// tokens that match the keyword format. This saves us from having to enumerate -// all the valid keywords here. These invalid keywords will still produce -// errors, just at a higher level of the parser. 
-std::optional keyword(std::string_view in) { - LexCtx ctx(in); - if (ctx.empty()) { - return {}; - } - uint8_t start = ctx.peek(); - if ('a' <= start && start <= 'z') { - ctx.take(1); - } else { - return {}; - } - while (auto lexed = idchar(ctx.next())) { - ctx.take(*lexed); - } - return ctx.lexed(); -} - -} // anonymous namespace - -void Lexer::skipSpace() { - while (true) { - if (auto ctx = annotation(next())) { - pos += ctx->span.size(); - annotations.push_back(ctx->annotation); - continue; - } - if (auto ctx = space(next())) { - pos += ctx->span.size(); - continue; - } - break; - } -} - -std::optional Lexer::peekChar() const { - auto n = next(); - if (n.empty()) { - return std::nullopt; - } - - return n[0]; -} - -bool Lexer::takeLParen() { - if (LexCtx(next()).startsWith("("sv)) { - ++pos; - advance(); - return true; - } - return false; -} - -bool Lexer::takeRParen() { - if (LexCtx(next()).startsWith(")"sv)) { - ++pos; - advance(); - return true; - } - return false; -} - -std::optional Lexer::takeString() { - if (auto result = str(next())) { - pos += result->span.size(); - advance(); - if (result->str) { - return result->str; - } - // Remove quotes. - return std::string(result->span.substr(1, result->span.size() - 2)); - } - return std::nullopt; -} - -std::optional Lexer::takeID() { - if (auto result = ident(next())) { - pos += result->span.size(); - advance(); - if (result->str) { - return Name(*result->str); - } - if (result->isStr) { - // Remove '$' and quotes. - return Name(result->span.substr(2, result->span.size() - 3)); - } - // Remove '$'. 
- return Name(result->span.substr(1)); - } - return std::nullopt; -} - -std::optional Lexer::takeKeyword() { - if (auto result = keyword(next())) { - pos += result->span.size(); - advance(); - return result->span; - } - return std::nullopt; -} - -bool Lexer::takeKeyword(std::string_view expected) { - if (auto result = keyword(next()); result && result->span == expected) { - pos += expected.size(); - advance(); - return true; - } - return false; -} - -std::optional Lexer::takeOffset() { - if (auto result = keyword(next())) { - if (result->span.substr(0, 7) != "offset="sv) { - return std::nullopt; - } - Lexer subLexer(result->span.substr(7)); - if (auto o = subLexer.takeU64()) { - pos += result->span.size(); - advance(); - return o; - } - } - return std::nullopt; -} - -std::optional Lexer::takeAlign() { - if (auto result = keyword(next())) { - if (result->span.substr(0, 6) != "align="sv) { - return std::nullopt; - } - Lexer subLexer(result->span.substr(6)); - if (auto o = subLexer.takeU32()) { - if (Bits::popCount(*o) != 1) { - return std::nullopt; - } - pos += result->span.size(); - advance(); - return o; - } - } - return std::nullopt; -} - -template std::optional Lexer::takeU() { - static_assert(std::is_integral_v && std::is_unsigned_v); - if (auto result = integer(next()); result && result->isUnsigned()) { - pos += result->span.size(); - advance(); - return T(result->n); - } - // TODO: Add error production for unsigned overflow. 
- return std::nullopt; -} - -template std::optional Lexer::takeS() { - static_assert(std::is_integral_v && std::is_signed_v); - if (auto result = integer(next()); result && result->isSigned()) { - pos += result->span.size(); - advance(); - return T(result->n); - } - return std::nullopt; -} - -template std::optional Lexer::takeI() { - static_assert(std::is_integral_v && std::is_unsigned_v); - if (auto result = integer(next())) { - if (result->isUnsigned() || result->isSigned>()) { - pos += result->span.size(); - advance(); - return T(result->n); - } - } - return std::nullopt; -} - -template std::optional Lexer::takeU(); -template std::optional Lexer::takeS(); -template std::optional Lexer::takeI(); -template std::optional Lexer::takeU(); -template std::optional Lexer::takeS(); -template std::optional Lexer::takeI(); -template std::optional Lexer::takeU(); -template std::optional Lexer::takeS(); -template std::optional Lexer::takeI(); -template std::optional Lexer::takeU(); -template std::optional Lexer::takeS(); -template std::optional Lexer::takeI(); - -std::optional Lexer::takeF64() { - constexpr int signif = 52; - constexpr uint64_t payloadMask = (1ull << signif) - 1; - constexpr uint64_t nanDefault = 1ull << (signif - 1); - if (auto result = float_(next())) { - double d = result->d; - if (std::isnan(d)) { - // Inject payload. - uint64_t payload = result->nanPayload ? *result->nanPayload : nanDefault; - if (payload == 0 || payload > payloadMask) { - // TODO: Add error production for out-of-bounds payload. 
- return std::nullopt; - } - uint64_t bits; - static_assert(sizeof(bits) == sizeof(d)); - memcpy(&bits, &d, sizeof(bits)); - bits = (bits & ~payloadMask) | payload; - memcpy(&d, &bits, sizeof(bits)); - } - pos += result->span.size(); - advance(); - return d; - } - if (auto result = integer(next())) { - pos += result->span.size(); - advance(); - if (result->sign == Neg) { - if (result->n == 0) { - return -0.0; - } - return double(int64_t(result->n)); - } - return double(result->n); - } - return std::nullopt; -} - -std::optional Lexer::takeF32() { - constexpr int signif = 23; - constexpr uint32_t payloadMask = (1u << signif) - 1; - constexpr uint64_t nanDefault = 1ull << (signif - 1); - if (auto result = float_(next())) { - float f = result->d; - if (std::isnan(f)) { - // Validate and inject payload. - uint64_t payload = result->nanPayload ? *result->nanPayload : nanDefault; - if (payload == 0 || payload > payloadMask) { - // TODO: Add error production for out-of-bounds payload. - return std::nullopt; - } - uint32_t bits; - static_assert(sizeof(bits) == sizeof(f)); - memcpy(&bits, &f, sizeof(bits)); - bits = (bits & ~payloadMask) | payload; - memcpy(&f, &bits, sizeof(bits)); - } - pos += result->span.size(); - advance(); - return f; - } - if (auto result = integer(next())) { - pos += result->span.size(); - advance(); - if (result->sign == Neg) { - if (result->n == 0) { - return -0.0f; - } - return float(int64_t(result->n)); - } - return float(result->n); - } - return std::nullopt; -} - -TextPos Lexer::position(const char* c) const { - assert(size_t(c - buffer.data()) <= buffer.size()); - TextPos pos{1, 0}; - for (const char* p = buffer.data(); p != c; ++p) { - if (*p == '\n') { - pos.line++; - pos.col = 0; - } else { - pos.col++; - } - } - return pos; -} - -bool TextPos::operator==(const TextPos& other) const { - return line == other.line && col == other.col; -} - -std::ostream& operator<<(std::ostream& os, const TextPos& pos) { - return os << pos.line << ":" << 
pos.col; -} - -} // namespace wasm::WATParser diff --git a/src/parser/lexer.h b/src/parser/lexer.h index ac6549f0de8..d6ae0f429eb 100644 --- a/src/parser/lexer.h +++ b/src/parser/lexer.h @@ -14,34 +14,42 @@ * limitations under the License. */ +#ifndef parser_lexer_h +#define parser_lexer_h + +#include +#include #include #include #include -#include #include #include +#include #include #include +#include +#include "support/bits.h" #include "support/name.h" #include "support/result.h" #include "support/string.h" -#ifndef parser_lexer_h -#define parser_lexer_h - namespace wasm::WATParser { struct TextPos { size_t line; size_t col; - bool operator==(const TextPos& other) const; + bool operator==(const TextPos& other) const { + return line == other.line && col == other.col; + } bool operator!=(const TextPos& other) const { return !(*this == other); } - - friend std::ostream& operator<<(std::ostream& os, const TextPos& pos); }; +inline std::ostream& operator<<(std::ostream& os, const TextPos& pos) { + return os << pos.line << ":" << pos.col; +} + // =========== // Annotations // =========== @@ -51,8 +59,6 @@ struct Annotation { std::string_view contents; }; -extern Name srcAnnotationKind; - // ===== // Lexer // ===== @@ -66,10 +72,8 @@ struct Lexer { public: std::string_view buffer; - Lexer(std::string_view buffer, std::optional file = std::nullopt) - : file(file), buffer(buffer) { - setPos(0); - } + Lexer(std::string_view buffer, + std::optional file = std::nullopt); size_t getPos() const { return pos; } @@ -80,40 +84,23 @@ struct Lexer { std::optional peekChar() const; + bool peekLParen() { return !empty() && peek() == '('; } + bool takeLParen(); - bool peekLParen() { return Lexer(*this).takeLParen(); } + bool peekRParen() { return !empty() && peek() == ')'; } bool takeRParen(); - bool peekRParen() { return Lexer(*this).takeRParen(); } - - bool takeUntilParen() { - while (true) { - if (empty()) { - return false; - } - if (peekLParen() || peekRParen()) { - return true; - 
} - // Do not count the parentheses in strings. - if (takeString()) { - continue; - } - ++pos; - advance(); - } - } + bool takeUntilParen(); std::optional takeID(); + std::optional peekKeyword(); + std::optional takeKeyword(); bool takeKeyword(std::string_view expected); - std::optional peekKeyword() { - return Lexer(*this).takeKeyword(); - } - std::optional takeOffset(); std::optional takeAlign(); @@ -125,62 +112,38 @@ struct Lexer { std::optional takeU8() { return takeU(); } std::optional takeI8() { return takeI(); } - std::optional takeF64(); std::optional takeF32(); + std::optional takeF64(); std::optional takeString(); - std::optional takeName() { - auto str = takeString(); - if (!str || !String::isUTF8(*str)) { - return std::nullopt; - } - return Name(*str); - } + std::optional takeName(); - bool takeSExprStart(std::string_view expected) { - auto original = *this; - if (takeLParen() && takeKeyword(expected)) { - return true; - } - *this = original; - return false; - } + bool takeSExprStart(std::string_view expected); - bool peekSExprStart(std::string_view expected) { - auto original = *this; - if (!takeLParen()) { - return false; - } - bool ret = takeKeyword(expected); - *this = original; - return ret; - } + bool peekSExprStart(std::string_view expected); std::string_view next() const { return buffer.substr(pos); } + uint8_t peek() const { return buffer[pos]; } + void advance() { annotations.clear(); skipSpace(); } bool empty() const { return pos == buffer.size(); } + size_t remaining() const { return buffer.size() - pos; } TextPos position(const char* c) const; + TextPos position(size_t i) const { return position(buffer.data() + i); } TextPos position(std::string_view span) const { return position(span.data()); } TextPos position() const { return position(getPos()); } - [[nodiscard]] Err err(size_t pos, std::string reason) { - std::stringstream msg; - if (file) { - msg << *file << ":"; - } - msg << position(pos) << ": error: " << reason; - return 
Err{msg.str()}; - } + [[nodiscard]] Err err(size_t pos, std::string reason); [[nodiscard]] Err err(std::string reason) { return err(getPos(), reason); } @@ -192,13 +155,1005 @@ struct Lexer { } private: + // Whether the unlexed input starts with prefix `sv`. + size_t startsWith(std::string_view sv) const { + return next().starts_with(sv); + } + + // Consume the next `n` characters. + void take(size_t n) { pos += n; } + void takeAll() { pos = buffer.size(); } + + std::optional getDigit(char c); + + std::optional getHexDigit(char c); + + // Consume the prefix and return true if possible. + bool takePrefix(std::string_view sv); + + std::optional takeDigit(); + + std::optional takeHexdigit(); + + enum OverflowBehavior { DisallowOverflow, IgnoreOverflow }; + + std::optional takeNum(OverflowBehavior behavior = DisallowOverflow); + + std::optional + takeHexnum(OverflowBehavior behavior = DisallowOverflow); + + enum Sign { NoSign, Pos, Neg }; + + Sign takeSign(); + + struct LexedInteger { + uint64_t n; + Sign sign; + + template bool isUnsigned(); + template bool isSigned(); + }; + + std::optional takeInteger(); + template std::optional takeU(); + template std::optional takeS(); + template std::optional takeI(); + std::optional takeDecfloat(); + + std::optional takeHexfloat(); + + struct LexedFloat { + std::optional nanPayload; + double d; + }; + + std::optional takeFloat(); + + struct StringOrView : std::variant { + using std::variant::variant; + std::string_view str() const { + return std::visit([](auto& s) -> std::string_view { return s; }, *this); + } + }; + + std::optional takeStr(); + + bool idchar(); + + std::optional takeIdent(); + + bool spacechar(); + + bool takeSpacechar(); + + bool takeComment(); + + bool takeSpace(); + + std::optional takeAnnotation(); + void skipSpace(); + + bool canFinish(); }; +inline Lexer::Lexer(std::string_view buffer, std::optional file) + : file(file), buffer(buffer) { + setPos(0); +} + +inline std::optional Lexer::peekChar() const { + 
if (!empty()) { + return peek(); + } + return std::nullopt; +} + +inline bool Lexer::takeLParen() { + if (peekLParen()) { + take(1); + advance(); + return true; + } + return false; +} + +inline bool Lexer::takeRParen() { + if (peekRParen()) { + take(1); + advance(); + return true; + } + return false; +} + +inline bool Lexer::takeUntilParen() { + while (true) { + if (empty()) { + return false; + } + if (peekLParen() || peekRParen()) { + return true; + } + // Do not count the parentheses in strings. + if (takeString()) { + continue; + } + ++pos; + advance(); + } +} + +inline std::optional Lexer::takeID() { + if (auto result = takeIdent()) { + auto name = Name(result->str()); + advance(); + return name; + } + return std::nullopt; +} + +inline std::optional Lexer::peekKeyword() { + if (empty()) { + return std::nullopt; + } + auto startPos = pos; + uint8_t start = peek(); + if ('a' <= start && start <= 'z') { + take(1); + } else { + return std::nullopt; + } + while (idchar()) { + take(1); + } + auto ret = buffer.substr(startPos, pos - startPos); + pos = startPos; + return ret; +} + +inline std::optional Lexer::takeKeyword() { + auto keyword = peekKeyword(); + if (keyword) { + take(keyword->size()); + advance(); + } + return keyword; +} + +inline bool Lexer::takeKeyword(std::string_view expected) { + if (!startsWith(expected)) { + return false; + } + auto startPos = pos; + take(expected.size()); + if (canFinish()) { + advance(); + return true; + } + pos = startPos; + return false; +} + +inline std::optional Lexer::takeOffset() { + using namespace std::string_view_literals; + auto startPos = pos; + if (auto offset = takeKeyword()) { + if (!offset->starts_with("offset="sv)) { + pos = startPos; + return std::nullopt; + } + Lexer subLexer(offset->substr(7)); + if (auto o = subLexer.takeU64()) { + advance(); + return o; + } + } + return std::nullopt; +} + +inline std::optional Lexer::takeAlign() { + using namespace std::string_view_literals; + auto startPos = pos; + if (auto 
result = takeKeyword()) { + if (!result->starts_with("align="sv)) { + pos = startPos; + return std::nullopt; + } + Lexer subLexer(result->substr(6)); + if (auto o = subLexer.takeU32()) { + if (Bits::popCount(*o) != 1) { + pos = startPos; + return std::nullopt; + } + advance(); + return o; + } + } + return std::nullopt; +} + +inline std::optional Lexer::takeF32() { + constexpr int signif = 23; + constexpr uint32_t payloadMask = (1u << signif) - 1; + constexpr uint64_t nanDefault = 1ull << (signif - 1); + auto startPos = pos; + if (auto result = takeFloat()) { + float f = result->d; + if (std::isnan(f)) { + // Validate and inject payload. + uint64_t payload = result->nanPayload ? *result->nanPayload : nanDefault; + if (payload == 0 || payload > payloadMask) { + // TODO: Add error production for out-of-bounds payload. + pos = startPos; + return std::nullopt; + } + uint32_t bits; + static_assert(sizeof(bits) == sizeof(f)); + memcpy(&bits, &f, sizeof(bits)); + bits = (bits & ~payloadMask) | payload; + memcpy(&f, &bits, sizeof(bits)); + } + advance(); + return f; + } + if (auto result = takeInteger()) { + advance(); + if (result->sign == Neg) { + if (result->n == 0) { + return -0.0f; + } + return static_cast(static_cast(result->n)); + } + return static_cast(result->n); + } + return std::nullopt; +} + +inline std::optional Lexer::takeF64() { + constexpr int signif = 52; + constexpr uint64_t payloadMask = (1ull << signif) - 1; + constexpr uint64_t nanDefault = 1ull << (signif - 1); + auto startPos = pos; + if (auto result = takeFloat()) { + double d = result->d; + if (std::isnan(d)) { + // Inject payload. + uint64_t payload = result->nanPayload ? *result->nanPayload : nanDefault; + if (payload == 0 || payload > payloadMask) { + // TODO: Add error production for out-of-bounds payload. 
+ pos = startPos; + return std::nullopt; + } + uint64_t bits; + static_assert(sizeof(bits) == sizeof(d)); + memcpy(&bits, &d, sizeof(bits)); + bits = (bits & ~payloadMask) | payload; + memcpy(&d, &bits, sizeof(bits)); + } + advance(); + return d; + } + if (auto result = takeInteger()) { + advance(); + if (result->sign == Neg) { + if (result->n == 0) { + return -0.0; + } + return static_cast(static_cast(result->n)); + } + return static_cast(result->n); + } + return std::nullopt; +} + +inline std::optional Lexer::takeString() { + if (auto str = takeStr()) { + advance(); + if (auto* s = std::get_if(&*str)) { + return std::move(*s); + } + auto view = std::get(*str); + return std::string(view); + } + return std::nullopt; +} + +inline std::optional Lexer::takeName() { + auto str = takeString(); + if (!str || !String::isUTF8(*str)) { + return std::nullopt; + } + return Name(*str); +} + +inline bool Lexer::takeSExprStart(std::string_view expected) { + auto original = *this; + if (takeLParen() && takeKeyword(expected)) { + return true; + } + *this = original; + return false; +} + +inline bool Lexer::peekSExprStart(std::string_view expected) { + auto original = *this; + if (!takeLParen()) { + return false; + } + bool ret = takeKeyword(expected); + *this = original; + return ret; +} + +inline TextPos Lexer::position(const char* c) const { + assert(size_t(c - buffer.data()) <= buffer.size()); + TextPos pos{1, 0}; + for (const char* p = buffer.data(); p != c; ++p) { + if (*p == '\n') { + pos.line++; + pos.col = 0; + } else { + pos.col++; + } + } + return pos; +} + +inline Err Lexer::err(size_t pos, std::string reason) { + std::stringstream msg; + if (file) { + msg << *file << ":"; + } + msg << position(pos) << ": error: " << reason; + return Err{msg.str()}; +} + +inline std::optional Lexer::getDigit(char c) { + if ('0' <= c && c <= '9') { + return c - '0'; + } + return std::nullopt; +} + +inline std::optional Lexer::getHexDigit(char c) { + if (auto d = getDigit(c)) { + return 
d; + } + if ('A' <= c && c <= 'F') { + return 10 + c - 'A'; + } + if ('a' <= c && c <= 'f') { + return 10 + c - 'a'; + } + return std::nullopt; +} + +inline bool Lexer::takePrefix(std::string_view sv) { + if (startsWith(sv)) { + take(sv.size()); + return true; + } + return false; +} + +inline std::optional Lexer::takeDigit() { + if (empty()) { + return std::nullopt; + } + if (auto d = getDigit(peek())) { + take(1); + return d; + } + return std::nullopt; +} + +inline std::optional Lexer::takeHexdigit() { + if (empty()) { + return std::nullopt; + } + if (auto h = getHexDigit(peek())) { + take(1); + return h; + } + return std::nullopt; +} + +inline std::optional Lexer::takeNum(OverflowBehavior behavior) { + using namespace std::string_view_literals; + auto startPos = pos; + bool overflow = false; + uint64_t n = 0; + if (auto d = takeDigit()) { + n = *d; + } else { + return std::nullopt; + } + while (true) { + bool under = takePrefix("_"sv); + if (auto d = takeDigit()) { + uint64_t newN = n * 10 + *d; + if (newN < n) { + overflow = true; + } + n = newN; + continue; + } + if (!under && (!overflow || behavior == IgnoreOverflow)) { + return n; + } + // TODO: Add error productions for trailing underscore and overflow. + pos = startPos; + return std::nullopt; + } +} + +inline std::optional Lexer::takeHexnum(OverflowBehavior behavior) { + using namespace std::string_view_literals; + auto startPos = pos; + bool overflow = false; + uint64_t n = 0; + if (auto d = takeHexdigit()) { + n = *d; + } else { + return std::nullopt; + } + while (true) { + bool under = takePrefix("_"sv); + if (auto d = takeHexdigit()) { + uint64_t newN = n * 16 + *d; + if (newN < n) { + overflow = true; + } + n = newN; + continue; + } + if (!under && (!overflow || behavior == IgnoreOverflow)) { + return n; + } + // TODO: Add error productions for trailing underscore and overflow. 
+ pos = startPos; + return std::nullopt; + } +} + +inline Lexer::Sign Lexer::takeSign() { + auto c = peek(); + if (c == '+') { + take(1); + return Pos; + } + if (c == '-') { + take(1); + return Neg; + } + return NoSign; +} + +template bool Lexer::LexedInteger::isUnsigned() { + static_assert(std::is_integral_v && std::is_unsigned_v); + return sign == NoSign && n <= std::numeric_limits::max(); +} + +template bool Lexer::LexedInteger::isSigned() { + static_assert(std::is_integral_v && std::is_signed_v); + if (sign == Neg) { + // Absolute value of min() for two's complement integers is max() + 1. + uint64_t absMin = uint64_t(std::numeric_limits::max()) + 1; + return n <= absMin; + } + return n <= uint64_t(std::numeric_limits::max()); +} + +inline std::optional Lexer::takeInteger() { + using namespace std::string_view_literals; + auto startPos = pos; + auto sign = takeSign(); + if (takePrefix("0x"sv)) { + if (auto n = takeHexnum()) { + if (canFinish()) { + return LexedInteger{*n, sign}; + } + } + // TODO: Add error production for unrecognized hexnum. + pos = startPos; + return std::nullopt; + } + if (auto n = takeNum()) { + if (canFinish()) { + return LexedInteger{*n, sign}; + } + } + pos = startPos; + return std::nullopt; +} + +template std::optional Lexer::takeU() { + static_assert(std::is_integral_v && std::is_unsigned_v); + auto startPos = pos; + if (auto result = takeInteger(); result && result->isUnsigned()) { + advance(); + return static_cast(result->n); + } + // TODO: Add error production for unsigned overflow. 
+ pos = startPos; + return std::nullopt; +} + +template std::optional Lexer::takeS() { + static_assert(std::is_integral_v && std::is_signed_v); + auto startPos = pos; + if (auto result = takeInteger(); result && result->isSigned()) { + advance(); + if (result->sign == Neg) { + return static_cast(-result->n); + } + return static_cast(result->n); + } + pos = startPos; + return std::nullopt; +} + +template std::optional Lexer::takeI() { + static_assert(std::is_integral_v && std::is_unsigned_v); + auto startPos = pos; + if (auto result = takeInteger()) { + if (result->isUnsigned() || result->isSigned>()) { + advance(); + if (result->sign == Neg) { + return static_cast(-result->n); + } + return static_cast(result->n); + } + } + pos = startPos; + return std::nullopt; +} + +inline std::optional Lexer::takeDecfloat() { + using namespace std::string_view_literals; + auto startPos = pos; + if (!takeNum(IgnoreOverflow)) { + return std::nullopt; + } + // Optional '.' followed by optional frac + if (takePrefix("."sv)) { + takeNum(IgnoreOverflow); + } + if (takePrefix("E"sv) || takePrefix("e"sv)) { + // Optional sign + takeSign(); + if (!takeNum(IgnoreOverflow)) { + // TODO: Add error production for missing exponent. + pos = startPos; + return std::nullopt; + } + } + return buffer.substr(startPos, pos - startPos); +} + +inline std::optional Lexer::takeHexfloat() { + using namespace std::string_view_literals; + auto startPos = pos; + if (!takePrefix("0x"sv)) { + return std::nullopt; + } + if (!takeHexnum(IgnoreOverflow)) { + pos = startPos; + return std::nullopt; + } + // Optional '.' followed by optional hexfrac + if (takePrefix("."sv)) { + takeHexnum(IgnoreOverflow); + } + if (takePrefix("P"sv) || takePrefix("p"sv)) { + // Optional sign + takeSign(); + if (!takeNum(IgnoreOverflow)) { + // TODO: Add error production for missing exponent. 
+ pos = startPos; + return std::nullopt; + } + } + return buffer.substr(startPos, pos - startPos); +} + +inline std::optional Lexer::takeFloat() { + using namespace std::string_view_literals; + auto startPos = pos; + std::optional nanPayload; + bool isNan = false; + // Optional sign + auto sign = takeSign(); + if (takeHexfloat() || takeDecfloat() || takePrefix("inf"sv)) { + // nop. + } else if (takePrefix("nan"sv)) { + isNan = true; + if (takePrefix(":0x"sv)) { + if (auto n = takeHexnum()) { + nanPayload = n; + } else { + // TODO: Add error production for malformed NaN payload. + pos = startPos; + return std::nullopt; + } + } else { + // No explicit payload necessary; we will inject the default payload + // later. + } + } else { + pos = startPos; + return std::nullopt; + } + if (!canFinish()) { + pos = startPos; + return std::nullopt; + } + // strtod does not return NaNs with the expected signs on all platforms. + if (isNan) { + if (sign == Neg) { + const double negNan = std::copysign(NAN, -1.0); + assert(std::signbit(negNan) && "expected negative NaN to be negative"); + return LexedFloat{nanPayload, negNan}; + } else { + const double posNan = std::copysign(NAN, 1.0); + assert(!std::signbit(posNan) && "expected positive NaN to be positive"); + return LexedFloat{nanPayload, posNan}; + } + } + // Do not try to implement fully general and precise float parsing + // ourselves. Instead, call out to std::strtod to do our parsing. This means + // we need to strip any underscores since `std::strtod` does not understand + // them. 
+ std::stringstream ss; + for (const char *curr = &buffer[startPos], *end = &buffer[pos]; curr != end; + ++curr) { + if (*curr != '_') { + ss << *curr; + } + } + std::string str = ss.str(); + char* last; + double d = std::strtod(str.data(), &last); + assert(last == str.data() + str.size() && "could not parse float"); + return LexedFloat{std::nullopt, d}; +} + +inline std::optional Lexer::takeStr() { + using namespace std::string_view_literals; + auto startPos = pos; + if (!takePrefix("\""sv)) { + return std::nullopt; + } + // Used to build a string with resolved escape sequences. Only used when the + // parsed string contains escape sequences, otherwise we can just use the + // parsed string directly. + std::optional escapeBuilder; + auto ensureBuildingEscaped = [&]() { + if (escapeBuilder) { + return; + } + // Drop the opening '"'. + escapeBuilder = std::stringstream{}; + *escapeBuilder << buffer.substr(startPos + 1, pos - startPos - 1); + }; + while (!takePrefix("\""sv)) { + if (empty()) { + // TODO: Add error production for unterminated string. + pos = startPos; + return std::nullopt; + } + if (startsWith("\\"sv)) { + // Escape sequences + ensureBuildingEscaped(); + take(1); + auto c = peek(); + take(1); + switch (c) { + case 't': + *escapeBuilder << '\t'; + break; + case 'n': + *escapeBuilder << '\n'; + break; + case 'r': + *escapeBuilder << '\r'; + break; + case '\\': + *escapeBuilder << '\\'; + break; + case '"': + *escapeBuilder << '"'; + break; + case '\'': + *escapeBuilder << '\''; + break; + case 'u': { + if (!takePrefix("{"sv)) { + pos = startPos; + return std::nullopt; + } + auto code = takeHexnum(); + if (!code) { + // TODO: Add error production for malformed unicode escapes. + pos = startPos; + return std::nullopt; + } + if (!takePrefix("}"sv)) { + // TODO: Add error production for malformed unicode escapes. 
+ pos = startPos; + return std::nullopt; + } + if ((0xd800 <= *code && *code < 0xe000) || 0x110000 <= *code) { + // TODO: Add error production for invalid unicode values. + pos = startPos; + return std::nullopt; + } + String::writeWTF8CodePoint(*escapeBuilder, *code); + break; + } + default: { + // Byte escape: \hh + // We already took the first h as c. + auto first = getHexDigit(c); + auto second = takeHexdigit(); + if (!first || !second) { + // TODO: Add error production for unrecognized escape sequence. + pos = startPos; + return std::nullopt; + } + *escapeBuilder << char(*first * 16 + *second); + } + } + } else { + // Normal characters + if (uint8_t c = peek(); c >= 0x20 && c != 0x7F) { + if (escapeBuilder) { + *escapeBuilder << c; + } + take(1); + } else { + // TODO: Add error production for unescaped control characters. + pos = startPos; + return std::nullopt; + } + } + } + if (escapeBuilder) { + return escapeBuilder->str(); + } + // Drop the quotes. + return buffer.substr(startPos + 1, pos - startPos - 2); +} + +inline bool Lexer::idchar() { + if (empty()) { + return false; + } + uint8_t c = peek(); + // All the allowed characters lie in the range '!' to '~', and within that + // range the vast majority of characters are allowed, so it is significantly + // faster to check for the disallowed characters instead. + if (c < '!' || c > '~') { + return false; + } + switch (c) { + case '"': + case '(': + case ')': + case ',': + case ';': + case '[': + case ']': + case '{': + case '}': + return false; + } + return true; +} + +inline std::optional Lexer::takeIdent() { + using namespace std::string_view_literals; + auto startPos = pos; + if (!takePrefix("$"sv)) { + return {}; + } + // Quoted identifier e.g. 
$"foo" + std::optional str; + if ((str = takeStr())) { + if (str->str().empty() || !String::isUTF8(str->str())) { + pos = startPos; + return std::nullopt; + } + } else if (idchar()) { + take(1); + while (idchar()) { + take(1); + } + } else { + pos = startPos; + return std::nullopt; + } + if (canFinish()) { + if (str) { + return str; + } + // Drop the "$". + return buffer.substr(startPos + 1, pos - startPos - 1); + } + pos = startPos; + return std::nullopt; +} + +inline bool Lexer::spacechar() { + if (empty()) { + return false; + } + switch (peek()) { + case ' ': + case '\n': + case '\r': + case '\t': + return true; + default: + return false; + } +} + +inline bool Lexer::takeSpacechar() { + if (spacechar()) { + take(1); + return true; + } + return false; +} + +inline bool Lexer::takeComment() { + using namespace std::string_view_literals; + + if (remaining() < 2) { + return false; + } + + // Line comment + if (!startsWith(";;@"sv) && takePrefix(";;"sv)) { + if (auto size = next().find('\n'); size != ""sv.npos) { + take(size); + } else { + takeAll(); + } + return true; + } + + // Block comment (possibly nested!) + if (takePrefix("(;"sv)) { + size_t depth = 1; + while (depth > 0 && remaining() >= 2) { + if (takePrefix("(;"sv)) { + ++depth; + } else if (takePrefix(";)"sv)) { + --depth; + } else { + take(1); + } + } + if (depth > 0) { + // TODO: Add error production for non-terminated block comment. 
+ return false; + } + return true; + } + + return false; +} + +inline bool Lexer::takeSpace() { + bool taken = false; + while (remaining() && (takeSpacechar() || takeComment())) { + taken = true; + continue; + } + return taken; +} + +inline std::optional Lexer::takeAnnotation() { + using namespace std::string_view_literals; + auto startPos = pos; + std::string_view kind; + std::string_view contents; + if (takePrefix(";;@"sv)) { + kind = "src"sv; + auto contentPos = pos; + if (auto size = next().find('\n'); size != ""sv.npos) { + take(size); + } else { + takeAll(); + } + contents = buffer.substr(contentPos, pos - contentPos); + } else if (takePrefix("(@"sv)) { + auto kindPos = pos; + bool hasIdchar = false; + while (idchar()) { + take(1); + hasIdchar = true; + } + if (!hasIdchar) { + pos = startPos; + return std::nullopt; + } + kind = buffer.substr(kindPos, pos - kindPos); + auto contentPos = pos; + size_t depth = 1; + while (true) { + if (empty()) { + pos = startPos; + return std::nullopt; + } + if (takeSpace() || takeKeyword() || takeInteger() || takeFloat() || + takeStr() || takeIdent()) { + continue; + } + if (takePrefix("(@"sv)) { + bool hasIdchar = false; + while (idchar()) { + take(1); + hasIdchar = true; + } + if (!hasIdchar) { + pos = startPos; + return std::nullopt; + } + ++depth; + continue; + } + if (takeLParen()) { + ++depth; + continue; + } + if (takePrefix(")"sv)) { + --depth; + if (depth == 0) { + break; + } + continue; + } + // Unrecognized token. + pos = startPos; + return std::nullopt; + } + contents = buffer.substr(contentPos, pos - contentPos - 1); + } else { + return std::nullopt; + } + return Annotation{Name(kind), contents}; +} + +inline void Lexer::skipSpace() { + while (true) { + if (auto annotation = takeAnnotation()) { + annotations.emplace_back(*std::move(annotation)); + continue; + } + if (takeSpace()) { + continue; + } + break; + } +} + +inline bool Lexer::canFinish() { + // Logically we want to check for eof, parens, and space. 
But we don't + // actually want to parse more than a couple characters of space, so check + // for individual space chars or comment starts instead. + using namespace std::string_view_literals; + return empty() || spacechar() || peek() == '(' || peek() == ')' || + startsWith(";;"sv); +} + } // namespace wasm::WATParser #endif // parser_lexer_h diff --git a/test/gtest/wat-lexer.cpp b/test/gtest/wat-lexer.cpp index a6f4f6d6cfe..3a4cd49e246 100644 --- a/test/gtest/wat-lexer.cpp +++ b/test/gtest/wat-lexer.cpp @@ -931,6 +931,17 @@ TEST(LexerTest, LexString) { EXPECT_FALSE(Lexer("\"too big \\u{110000}\""sv).takeString()); } +TEST(LexerTest, Annotations) { + Lexer lexer( + " (@metadata.code.branch_hint \"\\01\")\n (@metadata.code.branch_hint \"\\00\")\n (br_if $out"sv); + // Trigger advance/skipSpace which parses annotations. + lexer.takeID(); + auto annotations = lexer.takeAnnotations(); + ASSERT_EQ(annotations.size(), 2u); + EXPECT_EQ(annotations[0].contents, " \"\\01\""sv); + EXPECT_EQ(annotations[1].contents, " \"\\00\""sv); +} + TEST(LexerTest, LexKeywords) { Lexer lexer("module type func import rEsErVeD"); ASSERT_EQ(lexer.takeKeyword(), "module"sv); From b2528adfd129a6b545720f3e4d6d32612c62ce1e Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Mon, 13 Apr 2026 22:21:41 -0700 Subject: [PATCH 2/3] fixes --- src/parser/lexer.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/parser/lexer.h b/src/parser/lexer.h index d6ae0f429eb..3a9fd74e27a 100644 --- a/src/parser/lexer.h +++ b/src/parser/lexer.h @@ -353,6 +353,7 @@ inline std::optional Lexer::takeOffset() { return o; } } + pos = startPos; return std::nullopt; } @@ -374,6 +375,7 @@ inline std::optional Lexer::takeAlign() { return o; } } + pos = startPos; return std::nullopt; } @@ -407,7 +409,7 @@ inline std::optional Lexer::takeF32() { if (result->n == 0) { return -0.0f; } - return static_cast(static_cast(result->n)); + return -static_cast(result->n); } return static_cast(result->n); } @@ 
-444,7 +446,7 @@ inline std::optional Lexer::takeF64() { if (result->n == 0) { return -0.0; } - return static_cast(static_cast(result->n)); + return -static_cast(result->n); } return static_cast(result->n); } From 6c3f3c93423d30a65de7ce986356402b6480753f Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Wed, 8 Apr 2026 12:13:13 -0700 Subject: [PATCH 3/3] [NFC] Skip parsing instructions in first parser pass The first parser pass is responsible for two things: finding the locations of definitions of top-level module items like globals and functions and finding the locations of implicit function type definitions. It previously accomplished the latter by fully parsing every instruction in each function. But the IR is not constructed in this phase of parsing, so fully parsing every instruction was largely wasted work. Optimize the parser by parsing only the instructions that might have implicit type definitions and otherwise just blindly match parentheses to skip the function body. Combined with #8597, this speeds up parsing by 30-40%. 
--- src/parser/context-decls.cpp | 33 +++++++++++++++++++++++++++++++++ src/parser/contexts.h | 2 ++ src/parser/lexer.h | 18 +++++++++--------- src/parser/parsers.h | 6 ++++-- 4 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/parser/context-decls.cpp b/src/parser/context-decls.cpp index 252185d6634..73327bd1640 100644 --- a/src/parser/context-decls.cpp +++ b/src/parser/context-decls.cpp @@ -15,6 +15,7 @@ */ #include "contexts.h" +#include "parsers.h" namespace wasm::WATParser { @@ -302,4 +303,36 @@ Result<> ParseDeclsCtx::addTag(Name name, return Ok{}; } +bool ParseDeclsCtx::skipFunctionBody() { + using namespace std::string_view_literals; + size_t depth = 1; + while (depth > 0 && !in.empty()) { + if (in.takeLParen()) { + ++depth; + continue; + } + if (in.takeRParen()) { + --depth; + continue; + } + if (auto kw = in.takeKeyword()) { + if (*kw == "block"sv || *kw == "loop"sv || *kw == "if"sv || + *kw == "try"sv || *kw == "try_table"sv) { + in.takeID(); + (void)typeuse(*this); + continue; + } + if (*kw == "call_indirect"sv || *kw == "return_call_indirect"sv) { + (void)maybeTableidx(*this); + (void)typeuse(*this, false); + continue; + } + continue; + } + in.take(1); + in.advance(); + } + return true; +} + } // namespace wasm::WATParser diff --git a/src/parser/contexts.h b/src/parser/contexts.h index eb09a0bb3b0..88b7d8a941c 100644 --- a/src/parser/contexts.h +++ b/src/parser/contexts.h @@ -1074,6 +1074,8 @@ struct ParseDeclsCtx : NullTypeParserCtx, NullInstrParserCtx { recTypeDefs.push_back({{}, pos, Index(recTypeDefs.size()), {}}); } + bool skipFunctionBody(); + Limits makeLimits(uint64_t n, std::optional m) { return Limits{n, m}; } diff --git a/src/parser/lexer.h b/src/parser/lexer.h index 3a9fd74e27a..2f5cb7a0291 100644 --- a/src/parser/lexer.h +++ b/src/parser/lexer.h @@ -82,6 +82,15 @@ struct Lexer { advance(); } + // Consume the next `n` characters. 
+ void take(size_t n) { pos += n; } + void takeAll() { pos = buffer.size(); } + + // Whether the unlexed input starts with prefix `sv`. + size_t startsWith(std::string_view sv) const { + return next().starts_with(sv); + } + std::optional peekChar() const; bool peekLParen() { return !empty() && peek() == '('; } @@ -155,15 +164,6 @@ struct Lexer { } private: - // Whether the unlexed input starts with prefix `sv`. - size_t startsWith(std::string_view sv) const { - return next().starts_with(sv); - } - - // Consume the next `n` characters. - void take(size_t n) { pos += n; } - void takeAll() { pos = buffer.size(); } - std::optional getDigit(char c); std::optional getHexDigit(char c); diff --git a/src/parser/parsers.h b/src/parser/parsers.h index db41a2534cc..4a5fb71b771 100644 --- a/src/parser/parsers.h +++ b/src/parser/parsers.h @@ -3491,6 +3491,7 @@ template MaybeResult<> func(Ctx& ctx) { typename Ctx::TypeUseT type; Exactness exact = Exact; std::optional localVars; + bool skipped = false; if (import) { auto use = exacttypeuse(ctx); @@ -3505,13 +3506,14 @@ template MaybeResult<> func(Ctx& ctx) { CHECK_ERR(l); localVars = *l; } - if (!ctx.skipFunctionBody()) { + skipped = ctx.skipFunctionBody(); + if (!skipped) { CHECK_ERR(instrs(ctx)); ctx.setSrcLoc(ctx.in.takeAnnotations()); } } - if (!ctx.skipFunctionBody() && !ctx.in.takeRParen()) { + if ((import || !skipped) && !ctx.in.takeRParen()) { return ctx.in.err("expected end of function"); }