diff --git a/doc_build/ast_diff.py b/doc_build/ast_diff.py index da4d7b7..dc18c3d 100644 --- a/doc_build/ast_diff.py +++ b/doc_build/ast_diff.py @@ -17,6 +17,7 @@ import json +from collections.abc import Callable from typing import List, Dict, Any, Optional, Tuple from doc_build.filters.shared_filter_utils import HASH_ATTR_KEY @@ -184,63 +185,74 @@ def find_longest_common_subsequence(list_a: NodeList, list_b: NodeList) -> NodeL return dp[m][n] -def _pair_adjacent_changes(blocks: NodeList) -> NodeList: - """Pair adjacent deletion+insertion runs into substitution Divs. - - The LCS-based diff emits all deletions before all insertions within a - changed section (because nodes not in the LCS are drained from 'before' - first). This pass pairs them 1-to-1 as substitution Divs so the render - filter can produce per-word inline diffs. - - Excess deletions or insertions (when counts differ) remain as-is. +def _pair_adjacent_changes( + raw: List[Tuple[str, Any]], + *, + pair: Callable[[Any, Any], List[Any]], + wrap_deletion: Callable[[Any], Any], + wrap_insertion: Callable[[Any], Any], +) -> List[Any]: + """Walk an ``(op, element)`` stream from `_merge_with_lcs` and fold + adjacent deletion+insertion runs into per-pair recursive diffs (or + substitutions). + + ``"equal"`` elements pass through unchanged. Each contiguous deletion run + is gathered together with its immediately following insertion run; pairs + are converted via ``pair(deletion, insertion)``, which returns a (possibly + empty or multi-element) list of result elements to splice in. Excess + unpaired deletions / insertions are emitted via ``wrap_deletion`` / + ``wrap_insertion`` respectively. + + Generic over element type: callers pass blocks (``PandocNode``) or list + items (``List[PandocNode]``) and supply matching callbacks. The LCS-based + diff in `_merge_with_lcs` always emits deletions before insertions within + each gap, which is what this function expects. """ - - def _div_classes(block: PandocNode) -> List[str]: - if block.get("t") == "Div" and block.get("c"): - return block["c"][0][1] - return [] - - result: NodeList = [] + result: List[Any] = [] i = 0 - while i < len(blocks): - if "deletion" not in _div_classes(blocks[i]): - result.append(blocks[i]) + while i < len(raw): + op, element = raw[i] + if op == "equal": + result.append(element) i += 1 continue - - # Collect a consecutive run of deletions. - deletions: NodeList = [] - while i < len(blocks) and "deletion" in _div_classes(blocks[i]): - deletions.append(blocks[i]["c"][1][0]) # unwrap inner block + if op == "insertion": + # Insertion not preceded by a deletion run: emit bare. + result.append(wrap_insertion(element)) i += 1 - - # Collect the immediately following run of insertions. - insertions: NodeList = [] - while i < len(blocks) and "insertion" in _div_classes(blocks[i]): - insertions.append(blocks[i]["c"][1][0]) # unwrap inner block + continue + # op == "deletion". Gather the deletion run, then any immediately + # following insertion run. + deletions: List[Any] = [] + while i < len(raw) and raw[i][0] == "deletion": + deletions.append(raw[i][1]) + i += 1 + insertions: List[Any] = [] + while i < len(raw) and raw[i][0] == "insertion": + insertions.append(raw[i][1]) i += 1 - - # Pair 1-to-1 as substitutions; excess remain as bare deletions/insertions. n_pairs = min(len(deletions), len(insertions)) for j in range(n_pairs): - d, ins = deletions[j], insertions[j] - if _is_list_node(d) and _is_list_node(ins) and d.get("t") == ins.get("t"): - result.append(diff_list_nodes(d, ins)) - elif d.get("t") == "BlockQuote" and ins.get("t") == "BlockQuote": - result.append(diff_block_quote_nodes(d, ins)) - elif d.get("t") == "LineBlock" and ins.get("t") == "LineBlock": - result.extend(diff_line_block_nodes(d, ins)) - else: - extra_kv = _image_substitution_kv(d, ins) - result.append(make_substitution_div(d, ins, extra_kv=extra_kv)) - for node in deletions[n_pairs:]: - result.append(add_diff_meta(node, "deletion")) - for node in insertions[n_pairs:]: - result.append(add_diff_meta(node, "insertion")) - + result.extend(pair(deletions[j], insertions[j])) + for d in deletions[n_pairs:]: + result.append(wrap_deletion(d)) + for ins in insertions[n_pairs:]: + result.append(wrap_insertion(ins)) return result +def _pair_blocks(d: PandocNode, ins: PandocNode) -> NodeList: + """Pair strategy for `diff_block_lists`: recurse for like containers, + otherwise produce a substitution Div.""" + if _is_list_node(d) and _is_list_node(ins) and d.get("t") == ins.get("t"): + return [diff_list_nodes(d, ins)] + if d.get("t") == "BlockQuote" and ins.get("t") == "BlockQuote": + return [diff_block_quote_nodes(d, ins)] + if d.get("t") == "LineBlock" and ins.get("t") == "LineBlock": + return diff_line_block_nodes(d, ins) + return [make_substitution_div(d, ins, extra_kv=_image_substitution_kv(d, ins))] + + LIST_TYPES = frozenset({"BulletList", "OrderedList"}) @@ -287,69 +299,18 @@ def diff_list_nodes(old_node: PandocNode, new_node: PandocNode) -> PandocNode: old_items = _get_list_items(old_node) new_items = _get_list_items(new_node) - # find_longest_common_subsequence serializes each element via json.dumps, so - # it works on list items (List[PandocNode]) just as well as on PandocNode. + # find_longest_common_subsequence and _merge_with_lcs both serialize each + # element via json.dumps, so they work on list items (List[PandocNode]) + # just as well as on PandocNode. lcs = find_longest_common_subsequence(old_items, new_items) # type: ignore - lcs_strs = {json.dumps(item, sort_keys=True) for item in lcs} - - # Walk like diff_block_lists to produce an ordered stream of (op, item) pairs. - raw: List[Tuple[str, List[PandocNode]]] = [] - ptr_a, ptr_b = 0, 0 - while ptr_a < len(old_items) or ptr_b < len(new_items): - a = old_items[ptr_a] if ptr_a < len(old_items) else None - b = new_items[ptr_b] if ptr_b < len(new_items) else None - a_str = json.dumps(a, sort_keys=True) if a is not None else None - b_str = json.dumps(b, sort_keys=True) if b is not None else None - - if a is not None and a_str not in lcs_strs: - raw.append(("deletion", a)) - ptr_a += 1 - elif b is not None and b_str not in lcs_strs: - raw.append(("insertion", b)) - ptr_b += 1 - elif a is not None and b is not None: - raw.append(("equal", a)) - ptr_a += 1 - ptr_b += 1 - elif ptr_a < len(old_items): - raw.append(("deletion", old_items[ptr_a])) - ptr_a += 1 - else: - raw.append(("insertion", new_items[ptr_b])) - ptr_b += 1 - - # Pair consecutive deletion+insertion runs into substitution items. - result_items: List[List[PandocNode]] = [] - i = 0 - while i < len(raw): - op, item = raw[i] - if op == "equal": - result_items.append(item) - i += 1 - continue - if op == "insertion": - result_items.append([add_diff_meta(_item_to_block(item), "insertion")]) - i += 1 - continue - - # Collect a run of deletions then the immediately following insertions. - deletions: List[List[PandocNode]] = [] - while i < len(raw) and raw[i][0] == "deletion": - deletions.append(raw[i][1]) - i += 1 - insertions: List[List[PandocNode]] = [] - while i < len(raw) and raw[i][0] == "insertion": - insertions.append(raw[i][1]) - i += 1 - - n_pairs = min(len(deletions), len(insertions)) - for j in range(n_pairs): - result_items.append(diff_block_lists(deletions[j], insertions[j])) - for del_item in deletions[n_pairs:]: - result_items.append([add_diff_meta(_item_to_block(del_item), "deletion")]) - for ins_item in insertions[n_pairs:]: - result_items.append([add_diff_meta(_item_to_block(ins_item), "insertion")]) - + raw = _merge_with_lcs(old_items, new_items, lcs) # type: ignore + + result_items = _pair_adjacent_changes( + raw, + pair=lambda d, ins: [diff_block_lists(d, ins)], + wrap_deletion=lambda item: [add_diff_meta(_item_to_block(item), "deletion")], + wrap_insertion=lambda item: [add_diff_meta(_item_to_block(item), "insertion")], + ) return _build_list_with_items(old_node, result_items) @@ -397,41 +358,61 @@ def diff_block_lists(before_blocks: NodeList, after_blocks: NodeList) -> NodeLis groups as substitution Divs for per-word inline diffing. """ lcs_nodes = find_longest_common_subsequence(before_blocks, after_blocks) - lcs_set = {json.dumps(node, sort_keys=True) for node in lcs_nodes} - - merged_blocks: NodeList = [] - - ptr_a, ptr_b = 0, 0 - while ptr_a < len(before_blocks) or ptr_b < len(after_blocks): - node_a = before_blocks[ptr_a] if ptr_a < len(before_blocks) else None - node_b = after_blocks[ptr_b] if ptr_b < len(after_blocks) else None + raw = _merge_with_lcs(before_blocks, after_blocks, lcs_nodes) + return _pair_adjacent_changes( + raw, + pair=_pair_blocks, + wrap_deletion=lambda b: add_diff_meta(b, "deletion"), + wrap_insertion=lambda b: add_diff_meta(b, "insertion"), + ) + + +def _merge_with_lcs( + before: List[Any], + after: List[Any], + lcs_elements: List[Any], +) -> List[Tuple[str, Any]]: + """Walk `before` and `after` in lockstep with the LCS, yielding an ordered + stream of ``(op, element)`` tuples in document order. + + `op` is one of ``"equal"``, ``"deletion"`` or ``"insertion"``. Elements + not in the LCS are emitted as deletions (from `before`) or insertions + (from `after`); LCS elements are emitted as ``"equal"``. Within each + changed gap, deletions are emitted before insertions, which is what + `_pair_adjacent_changes` needs to fold them into substitution-style + annotations. + + Elements may be Pandoc blocks (``PandocNode``) or list items + (``List[PandocNode]``); the only requirement is that ``json.dumps(element, + sort_keys=True)`` gives a stable equality key. + """ + lcs_set = {json.dumps(n, sort_keys=True) for n in lcs_elements} - node_a_str = json.dumps(node_a, sort_keys=True) if node_a else None - node_b_str = json.dumps(node_b, sort_keys=True) if node_b else None + raw: List[Tuple[str, Any]] = [] + ptr_a = ptr_b = 0 + while ptr_a < len(before) or ptr_b < len(after): + a = before[ptr_a] if ptr_a < len(before) else None + b = after[ptr_b] if ptr_b < len(after) else None + a_str = json.dumps(a, sort_keys=True) if a is not None else None + b_str = json.dumps(b, sort_keys=True) if b is not None else None - if node_a and node_a_str not in lcs_set: - # This node from 'before' is not in the LCS, so it was removed. - merged_blocks.append(add_diff_meta(node_a, "deletion")) + if a is not None and a_str not in lcs_set: + raw.append(("deletion", a)) ptr_a += 1 - elif node_b and node_b_str not in lcs_set: - # This node from 'after' is not in the LCS, so it was added. - merged_blocks.append(add_diff_meta(node_b, "insertion")) + elif b is not None and b_str not in lcs_set: + raw.append(("insertion", b)) ptr_b += 1 - elif node_a and node_b: - # Both nodes are in the LCS (and therefore identical). - merged_blocks.append(node_a) + elif a is not None and b is not None: + raw.append(("equal", a)) ptr_a += 1 ptr_b += 1 - elif ptr_a < len(before_blocks): - # Exhausted 'after_blocks', remaining 'before' blocks are removals. - merged_blocks.append(add_diff_meta(before_blocks[ptr_a], "deletion")) + elif ptr_a < len(before): + raw.append(("deletion", before[ptr_a])) ptr_a += 1 - elif ptr_b < len(after_blocks): - # Exhausted 'before_blocks', remaining 'after' blocks are additions. - merged_blocks.append(add_diff_meta(after_blocks[ptr_b], "insertion")) + else: + raw.append(("insertion", after[ptr_b])) ptr_b += 1 - - return _pair_adjacent_changes(merged_blocks) + return raw def diff_ast_files(before_path, after_path, output_path):