Commit: Improve comp graph caching

evhub committed Oct 20, 2024
1 parent 87f5a4b commit 237593d
Showing 2 changed files with 103 additions and 43 deletions.
33 changes: 21 additions & 12 deletions coconut/compiler/compiler.py
@@ -172,10 +172,13 @@
split_leading_indent,
split_trailing_indent,
split_leading_trailing_indent,
match_in,
transform,
parse,
transform,
cached_parse,
cached_try_parse,
cached_does_parse,
cached_parse_where,
cached_match_in,
get_target_info_smart,
split_leading_comments,
compile_regex,
@@ -184,14 +187,11 @@
handle_indentation,
tuple_str_of,
join_args,
parse_where,
get_highest_parse_loc,
literal_eval,
should_trim_arity,
rem_and_count_indents,
normalize_indent_markers,
try_parse,
does_parse,
prep_grammar,
ordered,
tuple_str_of_str,
Expand Down Expand Up @@ -1266,7 +1266,7 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor
causes = dictset()
for check_loc in dictset((loc, endpoint, startpoint)):
if check_loc is not None:
cause = try_parse(self.parse_err_msg, original[check_loc:], inner=True)
cause = self.cached_try_parse("make_err", self.parse_err_msg, original[check_loc:], inner=True, cache_prefixes=True)
if cause:
causes.add(cause)
if causes:
@@ -1633,11 +1633,19 @@ def cached_parse(self, call_site_name, parser, text, **kwargs):

def cached_try_parse(self, call_site_name, parser, text, **kwargs):
"""Call cached_try_parse using self.computation_graph_caches."""
return try_parse(parser, text, computation_graph_cache=self.computation_graph_caches[(call_site_name, parser)], **kwargs)
return cached_try_parse(self.computation_graph_caches[(call_site_name, parser)], parser, text, **kwargs)

def cached_does_parse(self, call_site_name, parser, text, **kwargs):
"""Call cached_does_parse using self.computation_graph_caches."""
return does_parse(parser, text, computation_graph_cache=self.computation_graph_caches[(call_site_name, parser)], **kwargs)
return cached_does_parse(self.computation_graph_caches[(call_site_name, parser)], parser, text, **kwargs)

def cached_parse_where(self, call_site_name, parser, text, **kwargs):
"""Call cached_parse_where using self.computation_graph_caches."""
return cached_parse_where(self.computation_graph_caches[(call_site_name, parser)], parser, text, **kwargs)

def cached_match_in(self, call_site_name, parser, text, **kwargs):
"""Call cached_match_in using self.computation_graph_caches."""
return cached_match_in(self.computation_graph_caches[(call_site_name, parser)], parser, text, **kwargs)

def parse_line_by_line(self, init_parser, line_parser, original):
"""Apply init_parser then line_parser repeatedly."""
@@ -1657,6 +1665,7 @@ def parse_line_by_line(self, init_parser, line_parser, original):
parser,
self.remaining_original,
inner=False,
cache_prefixes=True,
)
if len(results) == 1:
got_loc, = results
@@ -1800,7 +1809,7 @@ def str_proc(self, inputstring, **kwargs):
if hold.get("in_expr", False):
internal_assert(is_f, "in_expr should only be for f string holds, not", hold)
remaining_text = inputstring[i:]
str_start, str_stop = parse_where(self.string_start, remaining_text)
str_start, str_stop = self.cached_parse_where("str_proc", self.string_start, remaining_text, cache_prefixes=True)
if str_start is not None: # str_start >= 0; if > 0 means there is whitespace before the string
hold["exprs"][-1] += remaining_text[:str_stop]
# add any skips from where we're fast-forwarding (except don't include c since we handle that below)
@@ -1813,7 +1822,7 @@
hold["exprs"][-1] += c
elif hold["paren_level"] > 0:
raise self.make_err(CoconutSyntaxError, "imbalanced parentheses in format string expression", inputstring, i, reformat=False)
elif self.cached_does_parse("str_proc", self.end_f_str_expr, remaining_text):
elif self.cached_does_parse("str_proc", self.end_f_str_expr, remaining_text, cache_prefixes=True):
hold["in_expr"] = False
hold["str_parts"].append(c)
else:
@@ -2393,7 +2402,7 @@ def split_docstring(self, block):
pass
else:
raw_first_line = split_leading_trailing_indent(rem_comment(first_line))[1]
if match_in(self.just_a_string, raw_first_line, inner=True):
if self.cached_match_in("split_docstring", self.just_a_string, raw_first_line, inner=True):
return first_line, rest_of_lines
return None, block

@@ -2527,7 +2536,7 @@ def transform_returns(self, original, loc, raw_lines, tre_return_grammar=None, i

# check if there is anything that stores a scope reference, and if so,
# disable TRE, since it can't handle that
if attempt_tre and match_in(self.stores_scope, line):
if attempt_tre and self.cached_match_in("transform_returns", self.stores_scope, line):
attempt_tre = False

# attempt tco/tre/async universalization
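Note: each of the wrapper methods added above follows the same pattern — look up (creating on first use) a cache dict keyed by the (call_site_name, parser) pair, then pass it as the first argument to the matching helper in coconut/compiler/util.py. Below is a minimal runnable sketch of that keying scheme, assuming computation_graph_caches behaves like a defaultdict(dict); the names are hypothetical stand-ins, not the Coconut source:

from collections import defaultdict

def cached_does_parse(cache, parser, text):
    """Stand-in for coconut.compiler.util.cached_does_parse (hypothetical)."""
    key = (text, True)  # the real keys are (prefix, is_at_end); simplified here
    if key not in cache:
        cache[key] = parser(text)
    return cache[key]

class CompilerSketch:
    """Sketch of the Compiler methods in the diff above (illustrative only)."""

    def __init__(self):
        # one independent cache dict per (call_site_name, parser) pair, so
        # results cached at one call site can never collide with another's
        self.computation_graph_caches = defaultdict(dict)

    def cached_does_parse(self, call_site_name, parser, text):
        cache = self.computation_graph_caches[(call_site_name, parser)]
        return cached_does_parse(cache, parser, text)

compiler = CompilerSketch()

def is_digit(text):
    return text.isdigit()  # toy "parser" standing in for a grammar object

assert compiler.cached_does_parse("str_proc", is_digit, "123") is True  # miss, computes
assert compiler.cached_does_parse("str_proc", is_digit, "123") is True  # hit, cached

Keying on the parser as well as the call-site name means that two different grammars used at the same call site still get fully separate caches.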
113 changes: 82 additions & 31 deletions coconut/compiler/util.py
@@ -250,7 +250,7 @@ def evaluate_tokens(tokens, **kwargs):
)

# base cases (performance sensitive; should be in likelihood order):
if isinstance(tokens, str):
if isinstance(tokens, (str, bool)) or tokens is None:
return tokens

elif isinstance(tokens, ComputationNode):
@@ -712,19 +712,41 @@ def prep_grammar(grammar, for_scan, streamline=False, add_unpack=False):
return grammar.parseWithTabs()


def parse(grammar, text, inner=None, eval_parse_tree=True):
def parse(grammar, text, inner=None, eval_parse_tree=True, **kwargs):
"""Parse text using grammar."""
with parsing_context(inner):
result = prep_grammar(grammar, for_scan=False).parseString(text)
result = prep_grammar(grammar, for_scan=False).parseString(text, **kwargs)
if eval_parse_tree:
result = unpack(result)
return result


def cached_parse(computation_graph_cache, grammar, text, inner=None, eval_parse_tree=True):
def all_matches(grammar, text, inner=None, eval_parse_tree=True):
"""Find all matches for grammar in text."""
kwargs = {}
if CPYPARSING and isinstance(grammar, StartOfStrGrammar):
grammar = grammar.grammar
kwargs["maxStartLoc"] = 0
with parsing_context(inner):
for tokens, start, stop in prep_grammar(grammar, for_scan=True).scanString(text, **kwargs):
if eval_parse_tree:
tokens = unpack(tokens)
yield tokens, start, stop


def cached_parse(
computation_graph_cache,
grammar,
text,
inner=None,
eval_parse_tree=True,
scan_string=False,
include_tokens=True,
cache_prefixes=False,
):
"""Version of parse that caches the result when it's a pure ComputationNode."""
if not CPYPARSING: # caching is only supported on cPyparsing
return parse(grammar, text, inner)
return (parse_where if scan_string else parse)(grammar, text, inner)

# only iterate over keys, not items, so we don't mark everything as alive
for key in computation_graph_cache:
@@ -736,7 +758,10 @@ def cached_parse(computation_graph_cache, grammar, text, inner=None, eval_parse_
is_at_end and text == prefix
or not is_at_end and text.startswith(prefix)
):
tokens = computation_graph_cache[key]
if scan_string:
tokens, start, stop = computation_graph_cache[key]
else:
tokens = computation_graph_cache[key]
if DEVELOP:
logger.record_stat("cached_parse", True)
logger.log_tag("cached_parse hit", (prefix, text[len(prefix):], tokens))
@@ -749,12 +774,23 @@ def cached_parse(computation_graph_cache, grammar, text, inner=None, eval_parse_
# parse and tells us that something greedy happened so we can't cache
final_evaluate_tokens.enabled = False
try:
with parsing_context(inner):
loc, tokens = prep_grammar(grammar, for_scan=False).parseString(text, returnLoc=True)
if scan_string:
for tokens, start, stop in all_matches(grammar, text, inner, eval_parse_tree=False):
break
else: # no break
tokens = start = stop = None
else:
stop, tokens = parse(grammar, text, inner, eval_parse_tree=False, returnLoc=True)
if not include_tokens:
tokens = bool(tokens)
if not final_evaluate_tokens.enabled:
prefix = text[:loc + 1]
is_at_end = loc >= len(text)
computation_graph_cache[(prefix, is_at_end)] = tokens
is_at_end = True if stop is None else stop >= len(text)
if cache_prefixes or is_at_end:
prefix = text if stop is None else text[:stop + 1]
if scan_string:
computation_graph_cache[(prefix, is_at_end)] = tokens, start, stop
else:
computation_graph_cache[(prefix, is_at_end)] = tokens
finally:
if DEVELOP:
logger.record_stat("cached_parse", False)
@@ -765,38 +801,40 @@ def cached_parse(computation_graph_cache, grammar, text, inner=None, eval_parse_
)
final_evaluate_tokens.enabled = True

if eval_parse_tree:
if include_tokens and eval_parse_tree:
tokens = unpack(tokens)
return tokens
if scan_string:
return tokens, start, stop
else:
return tokens


def try_parse(grammar, text, inner=None, eval_parse_tree=True, computation_graph_cache=None):
def try_parse(grammar, text, inner=None, eval_parse_tree=True):
"""Attempt to parse text using grammar else None."""
try:
if computation_graph_cache is None:
return parse(grammar, text, inner, eval_parse_tree)
else:
return cached_parse(computation_graph_cache, grammar, text, inner, eval_parse_tree)
return parse(grammar, text, inner, eval_parse_tree)
except ParseBaseException:
return None


def does_parse(grammar, text, inner=None, **kwargs):
def cached_try_parse(cache, grammar, text, inner=None, eval_parse_tree=True, **kwargs):
"""Cached version of try_parse."""
if not CPYPARSING: # scan_string on StartOfStrGrammar is only fast on cPyparsing
return try_parse(grammar, text, inner, eval_parse_tree)
if not isinstance(grammar, StartOfStrGrammar):
grammar = StartOfStrGrammar(grammar)
tokens, start, stop = cached_parse(cache, grammar, text, inner, eval_parse_tree, scan_string=True, **kwargs)
return tokens


def does_parse(grammar, text, inner=None):
"""Determine if text can be parsed using grammar."""
return try_parse(grammar, text, inner, eval_parse_tree=False, **kwargs)
return try_parse(grammar, text, inner, eval_parse_tree=False)


def all_matches(grammar, text, inner=None, eval_parse_tree=True):
"""Find all matches for grammar in text."""
kwargs = {}
if CPYPARSING and isinstance(grammar, StartOfStrGrammar):
grammar = grammar.grammar
kwargs["maxStartLoc"] = 0
with parsing_context(inner):
for tokens, start, stop in prep_grammar(grammar, for_scan=True).scanString(text, **kwargs):
if eval_parse_tree:
tokens = unpack(tokens)
yield tokens, start, stop
def cached_does_parse(cache, grammar, text, inner=None, **kwargs):
"""Cached version of does_parse."""
return cached_try_parse(cache, grammar, text, inner, eval_parse_tree=False, include_tokens=False, **kwargs)


def parse_where(grammar, text, inner=None):
@@ -806,13 +844,26 @@ def parse_where(grammar, text, inner=None):
return None, None


def cached_parse_where(cache, grammar, text, inner=None, **kwargs):
"""Cached version of parse_where."""
tokens, start, stop = cached_parse(cache, grammar, text, inner, scan_string=True, eval_parse_tree=False, include_tokens=False, **kwargs)
return start, stop


def match_in(grammar, text, inner=None):
"""Determine if there is a match for grammar anywhere in text."""
start, stop = parse_where(grammar, text, inner)
internal_assert((start is None) == (stop is None), "invalid parse_where results", (start, stop))
return start is not None


def cached_match_in(cache, grammar, text, inner=None, **kwargs):
"""Cached version of match_in."""
start, stop = cached_parse_where(cache, grammar, text, inner, **kwargs)
internal_assert((start is None) == (stop is None), "invalid cached_parse_where results", (start, stop))
return start is not None


def transform(grammar, text, inner=None):
"""Transform text by replacing matches to grammar."""
kwargs = {}
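The heart of the new scheme is cached_parse's (prefix, is_at_end) key format: a cached result applies to new input when the input equals the cached prefix exactly (the original parse consumed the whole string) or merely starts with it (the parse stopped earlier, so any suffix is irrelevant). A self-contained sketch of just that lookup rule, assuming a plain dict cache (illustrative only; as the comments in the diff note, the real code iterates over keys rather than items to avoid marking every entry as alive):

def cache_lookup(cache, text):
    """Return the cached value whose (prefix, is_at_end) key covers text, else None."""
    for key in cache:  # keys only, mirroring the liveness concern above
        prefix, is_at_end = key
        if (is_at_end and text == prefix) or (not is_at_end and text.startswith(prefix)):
            return cache[key]
    return None

# a parse of "abcdef" that stopped at index 2 would cache prefix "abc" with
# is_at_end=False, making the entry valid for any input starting with "abc"
cache = {("abc", False): "cached-tokens"}
assert cache_lookup(cache, "abcxyz") == "cached-tokens"  # prefix match, suffix ignored
assert cache_lookup(cache, "ab") is None                 # no key covers this input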
