Skip to content

Commit 5ed0e1c

Browse files
Copilothzhangxyz
andauthored
Add ThrowingErrorListener to Python ANTLR4 parser to match JavaScript behavior (#110)
* Initial plan * Add ThrowingErrorListener to Python ANTLR4 parser matching JavaScript implementation Co-authored-by: hzhangxyz <[email protected]> * Improve error tests to use pytest.raises for better test structure Co-authored-by: hzhangxyz <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: hzhangxyz <[email protected]>
1 parent 0cc98df commit 5ed0e1c

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

bnf/apyds_bnf/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__all__ = ["parse", "unparse"]
22

33
from antlr4 import InputStream, CommonTokenStream
4+
from antlr4.error.ErrorListener import ErrorListener
45
from .DspLexer import DspLexer
56
from .DspParser import DspParser
67
from .DspVisitor import DspVisitor
@@ -9,6 +10,11 @@
910
from .DsVisitor import DsVisitor
1011

1112

13+
class ThrowingErrorListener(ErrorListener):
14+
def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
15+
raise Exception(f"line {line}:{column} {msg}")
16+
17+
1218
class ParseVisitor(DspVisitor):
1319
def visitRule_pool(self, ctx):
1420
return "\n\n".join(self.visit(r) for r in ctx.rule_())
@@ -71,8 +77,12 @@ def visitBinary(self, ctx):
7177
def parse(input: str) -> str:
7278
chars = InputStream(input)
7379
lexer = DspLexer(chars)
80+
lexer.removeErrorListeners()
81+
lexer.addErrorListener(ThrowingErrorListener())
7482
tokens = CommonTokenStream(lexer)
7583
parser = DspParser(tokens)
84+
parser.removeErrorListeners()
85+
parser.addErrorListener(ThrowingErrorListener())
7686
tree = parser.rule_pool()
7787
visitor = ParseVisitor()
7888
return visitor.visit(tree)
@@ -81,8 +91,12 @@ def parse(input: str) -> str:
8191
def unparse(input: str) -> str:
8292
chars = InputStream(input)
8393
lexer = DsLexer(chars)
94+
lexer.removeErrorListeners()
95+
lexer.addErrorListener(ThrowingErrorListener())
8496
tokens = CommonTokenStream(lexer)
8597
parser = DsParser(tokens)
98+
parser.removeErrorListeners()
99+
parser.addErrorListener(ThrowingErrorListener())
86100
tree = parser.rule_pool()
87101
visitor = UnparseVisitor()
88102
return visitor.visit(tree)

bnf/tests/test_parse_unparse.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
from apyds_bnf import parse, unparse
23

34

@@ -143,3 +144,38 @@ def test_roundtrip_unparse_parse() -> None:
143144
dsp_intermediate = unparse(ds_original)
144145
ds_result = parse(dsp_intermediate)
145146
assert ds_result == ds_original
147+
148+
149+
def test_parse_error_missing_closing_parenthesis() -> None:
150+
"""Test that parse throws error on missing closing parenthesis"""
151+
dsp_input = "(a + b -> c"
152+
with pytest.raises(Exception, match=r"line 1:7.*no viable alternative"):
153+
parse(dsp_input)
154+
155+
156+
def test_parse_error_bad_syntax() -> None:
157+
"""Test that parse throws error on bad syntax"""
158+
dsp_input = "a b c -> -> d"
159+
with pytest.raises(Exception, match=r"line 1:2.*mismatched input"):
160+
parse(dsp_input)
161+
162+
163+
def test_parse_error_malformed_parentheses() -> None:
164+
"""Test that parse throws error on malformed parentheses"""
165+
dsp_input = "()()()"
166+
with pytest.raises(Exception, match=r"line 1:1.*no viable alternative"):
167+
parse(dsp_input)
168+
169+
170+
def test_unparse_error_incomplete_binary() -> None:
171+
"""Test that unparse throws error on incomplete binary expression"""
172+
ds_input = "(binary"
173+
with pytest.raises(Exception, match=r"line 1:7.*mismatched input"):
174+
unparse(ds_input)
175+
176+
177+
def test_unparse_error_malformed_function() -> None:
178+
"""Test that unparse throws error on malformed function"""
179+
ds_input = "(function"
180+
with pytest.raises(Exception, match=r"line 1:9.*mismatched input"):
181+
unparse(ds_input)

0 commit comments

Comments
 (0)