diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py
index 1186b4d9..3cb6371e 100644
--- a/coconut/compiler/compiler.py
+++ b/coconut/compiler/compiler.py
@@ -172,10 +172,13 @@
     split_leading_indent,
     split_trailing_indent,
     split_leading_trailing_indent,
-    match_in,
-    transform,
     parse,
+    transform,
     cached_parse,
+    cached_try_parse,
+    cached_does_parse,
+    cached_parse_where,
+    cached_match_in,
     get_target_info_smart,
     split_leading_comments,
     compile_regex,
@@ -184,14 +187,11 @@
     handle_indentation,
     tuple_str_of,
     join_args,
-    parse_where,
     get_highest_parse_loc,
     literal_eval,
     should_trim_arity,
     rem_and_count_indents,
     normalize_indent_markers,
-    try_parse,
-    does_parse,
     prep_grammar,
     ordered,
     tuple_str_of_str,
@@ -1266,7 +1266,7 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor
         causes = dictset()
         for check_loc in dictset((loc, endpoint, startpoint)):
             if check_loc is not None:
-                cause = try_parse(self.parse_err_msg, original[check_loc:], inner=True)
+                cause = self.cached_try_parse("make_err", self.parse_err_msg, original[check_loc:], inner=True, cache_prefixes=True)
                 if cause:
                     causes.add(cause)
         if causes:
@@ -1633,11 +1633,19 @@ def cached_parse(self, call_site_name, parser, text, **kwargs):

     def cached_try_parse(self, call_site_name, parser, text, **kwargs):
         """Call cached_try_parse using self.computation_graph_caches."""
-        return try_parse(parser, text, computation_graph_cache=self.computation_graph_caches[(call_site_name, parser)], **kwargs)
+        return cached_try_parse(self.computation_graph_caches[(call_site_name, parser)], parser, text, **kwargs)

     def cached_does_parse(self, call_site_name, parser, text, **kwargs):
         """Call cached_does_parse using self.computation_graph_caches."""
-        return does_parse(parser, text, computation_graph_cache=self.computation_graph_caches[(call_site_name, parser)], **kwargs)
+        return cached_does_parse(self.computation_graph_caches[(call_site_name, parser)], parser, text, **kwargs)
+
+    def cached_parse_where(self, call_site_name, parser, text, **kwargs):
+        """Call cached_parse_where using self.computation_graph_caches."""
+        return cached_parse_where(self.computation_graph_caches[(call_site_name, parser)], parser, text, **kwargs)
+
+    def cached_match_in(self, call_site_name, parser, text, **kwargs):
+        """Call cached_match_in using self.computation_graph_caches."""
+        return cached_match_in(self.computation_graph_caches[(call_site_name, parser)], parser, text, **kwargs)
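Note on the wrapper methods above: each call site gets its own cache, selected by (call_site_name, parser), and within a cache entries are keyed by (prefix, is_at_end), where is_at_end records whether the cached parse consumed its entire input. A minimal sketch of that lookup discipline, using a plain dict in place of Coconut's computation_graph_caches (the helper name and dict structure here are illustrative, not from the patch):

from collections import defaultdict

# illustrative stand-in for self.computation_graph_caches
computation_graph_caches = defaultdict(dict)

def cache_lookup(call_site_name, parser, text):
    """Return the cached entry whose prefix covers text, else None."""
    cache = computation_graph_caches[(call_site_name, parser)]
    for prefix, is_at_end in cache:
        # a parse that consumed its whole input only matches that exact
        # input; otherwise any text sharing the cached prefix matches
        if (text == prefix) if is_at_end else text.startswith(prefix):
            return cache[(prefix, is_at_end)]
    return None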
@@ -1657,6 +1665,7 @@ def parse_line_by_line(self, init_parser, line_parser, original):
                     parser,
                     self.remaining_original,
                     inner=False,
+                    cache_prefixes=True,
                 )
             if len(results) == 1:
                 got_loc, = results
@@ -1800,7 +1809,7 @@ def str_proc(self, inputstring, **kwargs):
                 if hold.get("in_expr", False):
                     internal_assert(is_f, "in_expr should only be for f string holds, not", hold)
                     remaining_text = inputstring[i:]
-                    str_start, str_stop = parse_where(self.string_start, remaining_text)
+                    str_start, str_stop = self.cached_parse_where("str_proc", self.string_start, remaining_text, cache_prefixes=True)
                     if str_start is not None:  # str_start >= 0; if > 0 means there is whitespace before the string
                         hold["exprs"][-1] += remaining_text[:str_stop]
                         # add any skips from where we're fast-forwarding (except don't include c since we handle that below)
@@ -1813,7 +1822,7 @@ def str_proc(self, inputstring, **kwargs):
                         hold["exprs"][-1] += c
                 elif hold["paren_level"] > 0:
                     raise self.make_err(CoconutSyntaxError, "imbalanced parentheses in format string expression", inputstring, i, reformat=False)
-                elif self.cached_does_parse("str_proc", self.end_f_str_expr, remaining_text):
+                elif self.cached_does_parse("str_proc", self.end_f_str_expr, remaining_text, cache_prefixes=True):
                     hold["in_expr"] = False
                     hold["str_parts"].append(c)
                 else:
@@ -2393,7 +2402,7 @@ def split_docstring(self, block):
             pass
         else:
             raw_first_line = split_leading_trailing_indent(rem_comment(first_line))[1]
-            if match_in(self.just_a_string, raw_first_line, inner=True):
+            if self.cached_match_in("split_docstring", self.just_a_string, raw_first_line, inner=True):
                 return first_line, rest_of_lines

         return None, block
@@ -2527,7 +2536,7 @@ def transform_returns(self, original, loc, raw_lines, tre_return_grammar=None, i

             # check if there is anything that stores a scope reference, and if so,
             # disable TRE, since it can't handle that
-            if attempt_tre and match_in(self.stores_scope, line):
+            if attempt_tre and self.cached_match_in("transform_returns", self.stores_scope, line):
                 attempt_tre = False

             # attempt tco/tre/async universalization
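Why cache_prefixes=True shows up at the call sites above: make_err, parse_line_by_line, and str_proc all re-run the same parser over successive suffixes of one input (e.g. remaining_text = inputstring[i:]), so caching a parse by the prefix it actually consumed lets later, longer inputs hit the cache. A runnable toy model of the prefix-key semantics that util.py's cached_parse implements below (plain dicts and a fake digit-run grammar; nothing here is Coconut API):

cache = {}

def toy_cached_match(text):
    """Pretend grammar: does text start with a run of digits?"""
    for (prefix, is_at_end), hit in cache.items():
        if (text == prefix) if is_at_end else text.startswith(prefix):
            return hit  # the cached prefix already decides the answer
    stop = len(text)  # "parse": count the leading digits
    for j, c in enumerate(text):
        if not c.isdigit():
            stop = j
            break
    matched = stop > 0
    # key includes one character past the parse end, mirroring
    # prefix = text[:stop + 1] in cached_parse
    cache[(text[:stop + 1], stop >= len(text))] = matched
    return matched

assert toy_cached_match("12ab") is True    # parses and caches prefix "12a"
assert toy_cached_match("12axyz") is True  # served by the prefix cache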
diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py
index d146f7e9..5dcfa3e4 100644
--- a/coconut/compiler/util.py
+++ b/coconut/compiler/util.py
@@ -250,7 +250,7 @@ def evaluate_tokens(tokens, **kwargs):
     )

     # base cases (performance sensitive; should be in likelihood order):
-    if isinstance(tokens, str):
+    if isinstance(tokens, (str, bool)) or tokens is None:
         return tokens

     elif isinstance(tokens, ComputationNode):
@@ -712,19 +712,41 @@ def prep_grammar(grammar, for_scan, streamline=False, add_unpack=False):
     return grammar.parseWithTabs()


-def parse(grammar, text, inner=None, eval_parse_tree=True):
+def parse(grammar, text, inner=None, eval_parse_tree=True, **kwargs):
     """Parse text using grammar."""
     with parsing_context(inner):
-        result = prep_grammar(grammar, for_scan=False).parseString(text)
+        result = prep_grammar(grammar, for_scan=False).parseString(text, **kwargs)
     if eval_parse_tree:
         result = unpack(result)
     return result


-def cached_parse(computation_graph_cache, grammar, text, inner=None, eval_parse_tree=True):
+def all_matches(grammar, text, inner=None, eval_parse_tree=True):
+    """Find all matches for grammar in text."""
+    kwargs = {}
+    if CPYPARSING and isinstance(grammar, StartOfStrGrammar):
+        grammar = grammar.grammar
+        kwargs["maxStartLoc"] = 0
+    with parsing_context(inner):
+        for tokens, start, stop in prep_grammar(grammar, for_scan=True).scanString(text, **kwargs):
+            if eval_parse_tree:
+                tokens = unpack(tokens)
+            yield tokens, start, stop
+
+
+def cached_parse(
+    computation_graph_cache,
+    grammar,
+    text,
+    inner=None,
+    eval_parse_tree=True,
+    scan_string=False,
+    include_tokens=True,
+    cache_prefixes=False,
+):
     """Version of parse that caches the result when it's a pure ComputationNode."""
     if not CPYPARSING:  # caching is only supported on cPyparsing
-        return parse(grammar, text, inner)
+        return (parse_where if scan_string else parse)(grammar, text, inner)

     # only iterate over keys, not items, so we don't mark everything as alive
     for key in computation_graph_cache:
@@ -736,7 +758,10 @@ def cached_parse(computation_graph_cache, grammar, text, inner=None, eval_parse_
             is_at_end and text == prefix
             or not is_at_end and text.startswith(prefix)
         ):
-            tokens = computation_graph_cache[key]
+            if scan_string:
+                tokens, start, stop = computation_graph_cache[key]
+            else:
+                tokens = computation_graph_cache[key]
             if DEVELOP:
                 logger.record_stat("cached_parse", True)
                 logger.log_tag("cached_parse hit", (prefix, text[len(prefix):], tokens))
@@ -749,12 +774,23 @@ def cached_parse(computation_graph_cache, grammar, text, inner=None, eval_parse_
     # parse and tells us that something greedy happened so we can't cache
     final_evaluate_tokens.enabled = False
    try:
-        with parsing_context(inner):
-            loc, tokens = prep_grammar(grammar, for_scan=False).parseString(text, returnLoc=True)
+        if scan_string:
+            for tokens, start, stop in all_matches(grammar, text, inner, eval_parse_tree=False):
+                break
+            else:  # no break
+                tokens = start = stop = None
+        else:
+            stop, tokens = parse(grammar, text, inner, eval_parse_tree=False, returnLoc=True)
+        if not include_tokens:
+            tokens = bool(tokens)
         if not final_evaluate_tokens.enabled:
-            prefix = text[:loc + 1]
-            is_at_end = loc >= len(text)
-            computation_graph_cache[(prefix, is_at_end)] = tokens
+            is_at_end = True if stop is None else stop >= len(text)
+            if cache_prefixes or is_at_end:
+                prefix = text if stop is None else text[:stop + 1]
+                if scan_string:
+                    computation_graph_cache[(prefix, is_at_end)] = tokens, start, stop
+                else:
+                    computation_graph_cache[(prefix, is_at_end)] = tokens
     finally:
         if DEVELOP:
             logger.record_stat("cached_parse", False)
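One subtlety in the scan_string branch above: it leans on Python's for ... else, whose else clause runs only when the loop finishes without break, i.e. only when all_matches yields nothing. A self-contained illustration of the idiom (not Coconut code):

def first_match(matches):
    """Return the first (tokens, start, stop), or Nones if there are none."""
    for tokens, start, stop in matches:
        break  # keep only the first match
    else:  # no break: the iterable was empty
        tokens = start = stop = None
    return tokens, start, stop

assert first_match([("tok", 0, 3), ("later", 5, 8)]) == ("tok", 0, 3)
assert first_match([]) == (None, None, None)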
@@ -765,38 +801,40 @@ def cached_parse(computation_graph_cache, grammar, text, inner=None, eval_parse_
             )
         final_evaluate_tokens.enabled = True

-    if eval_parse_tree:
+    if include_tokens and eval_parse_tree:
         tokens = unpack(tokens)
-    return tokens
+    if scan_string:
+        return tokens, start, stop
+    else:
+        return tokens


-def try_parse(grammar, text, inner=None, eval_parse_tree=True, computation_graph_cache=None):
+def try_parse(grammar, text, inner=None, eval_parse_tree=True):
     """Attempt to parse text using grammar else None."""
     try:
-        if computation_graph_cache is None:
-            return parse(grammar, text, inner, eval_parse_tree)
-        else:
-            return cached_parse(computation_graph_cache, grammar, text, inner, eval_parse_tree)
+        return parse(grammar, text, inner, eval_parse_tree)
     except ParseBaseException:
         return None


-def does_parse(grammar, text, inner=None, **kwargs):
+def cached_try_parse(cache, grammar, text, inner=None, eval_parse_tree=True, **kwargs):
+    """Cached version of try_parse."""
+    if not CPYPARSING:  # scan_string on StartOfStrGrammar is only fast on cPyparsing
+        return try_parse(grammar, text, inner, eval_parse_tree)
+    if not isinstance(grammar, StartOfStrGrammar):
+        grammar = StartOfStrGrammar(grammar)
+    tokens, start, stop = cached_parse(cache, grammar, text, inner, eval_parse_tree, scan_string=True, **kwargs)
+    return tokens
+
+
+def does_parse(grammar, text, inner=None):
     """Determine if text can be parsed using grammar."""
-    return try_parse(grammar, text, inner, eval_parse_tree=False, **kwargs)
+    return try_parse(grammar, text, inner, eval_parse_tree=False)


-def all_matches(grammar, text, inner=None, eval_parse_tree=True):
-    """Find all matches for grammar in text."""
-    kwargs = {}
-    if CPYPARSING and isinstance(grammar, StartOfStrGrammar):
-        grammar = grammar.grammar
-        kwargs["maxStartLoc"] = 0
-    with parsing_context(inner):
-        for tokens, start, stop in prep_grammar(grammar, for_scan=True).scanString(text, **kwargs):
-            if eval_parse_tree:
-                tokens = unpack(tokens)
-            yield tokens, start, stop
+def cached_does_parse(cache, grammar, text, inner=None, **kwargs):
+    """Cached version of does_parse."""
+    return cached_try_parse(cache, grammar, text, inner, eval_parse_tree=False, include_tokens=False, **kwargs)


 def parse_where(grammar, text, inner=None):
@@ -806,6 +844,12 @@ def parse_where(grammar, text, inner=None):
     return None, None


+def cached_parse_where(cache, grammar, text, inner=None, **kwargs):
+    """Cached version of parse_where."""
+    tokens, start, stop = cached_parse(cache, grammar, text, inner, scan_string=True, eval_parse_tree=False, include_tokens=False, **kwargs)
+    return start, stop
+
+
 def match_in(grammar, text, inner=None):
     """Determine if there is a match for grammar anywhere in text."""
     start, stop = parse_where(grammar, text, inner)
@@ -813,6 +857,13 @@ def match_in(grammar, text, inner=None):
     return start is not None


+def cached_match_in(cache, grammar, text, inner=None, **kwargs):
+    """Cached version of match_in."""
+    start, stop = cached_parse_where(cache, grammar, text, inner, **kwargs)
+    internal_assert((start is None) == (stop is None), "invalid cached_parse_where results", (start, stop))
+    return start is not None
+
+
 def transform(grammar, text, inner=None):
     """Transform text by replacing matches to grammar."""
     kwargs = {}
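For reference, the new cached_parse_where / cached_match_in pair preserves the same contract as the uncached originals: parse_where reports where the first match sits, and match_in reduces that to a boolean while asserting that start and stop are None together. A toy model of that contract with a regex standing in for a pyparsing grammar (illustrative only; the real helpers go through prep_grammar and the computation graph cache):

import re

def toy_parse_where(pattern, text):
    """(start, stop) of the first match, or (None, None) if none."""
    m = re.search(pattern, text)
    return (m.start(), m.end()) if m else (None, None)

def toy_match_in(pattern, text):
    """Whether pattern matches anywhere in text."""
    start, stop = toy_parse_where(pattern, text)
    assert (start is None) == (stop is None)  # cached_match_in's invariant
    return start is not None

assert toy_parse_where(r"\d+", "abc123def") == (3, 6)
assert toy_match_in(r"\d+", "abc123def")
assert not toy_match_in(r"\d+", "abcdef")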