Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 116 additions & 135 deletions doc_build/ast_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import json

from collections.abc import Callable
from typing import List, Dict, Any, Optional, Tuple

from doc_build.filters.shared_filter_utils import HASH_ATTR_KEY
Expand Down Expand Up @@ -184,63 +185,74 @@ def find_longest_common_subsequence(list_a: NodeList, list_b: NodeList) -> NodeL
return dp[m][n]


def _pair_adjacent_changes(blocks: NodeList) -> NodeList:
"""Pair adjacent deletion+insertion runs into substitution Divs.

The LCS-based diff emits all deletions before all insertions within a
changed section (because nodes not in the LCS are drained from 'before'
first). This pass pairs them 1-to-1 as substitution Divs so the render
filter can produce per-word inline diffs.

Excess deletions or insertions (when counts differ) remain as-is.
def _pair_adjacent_changes(
raw: List[Tuple[str, Any]],
*,
pair: Callable[[Any, Any], List[Any]],
wrap_deletion: Callable[[Any], Any],
wrap_insertion: Callable[[Any], Any],
) -> List[Any]:
"""Walk an ``(op, element)`` stream from `_merge_with_lcs` and fold
adjacent deletion+insertion runs into per-pair recursive diffs (or
substitutions).

``"equal"`` elements pass through unchanged. Each contiguous deletion run
is gathered together with its immediately following insertion run; pairs
are converted via ``pair(deletion, insertion)``, which returns a (possibly
empty or multi-element) list of result elements to splice in. Excess
unpaired deletions / insertions are emitted via ``wrap_deletion`` /
``wrap_insertion`` respectively.

Generic over element type: callers pass blocks (``PandocNode``) or list
items (``List[PandocNode]``) and supply matching callbacks. The LCS-based
diff in `_merge_with_lcs` always emits deletions before insertions within
each gap, which is what this function expects.
"""

def _div_classes(block: PandocNode) -> List[str]:
if block.get("t") == "Div" and block.get("c"):
return block["c"][0][1]
return []

result: NodeList = []
result: List[Any] = []
i = 0
while i < len(blocks):
if "deletion" not in _div_classes(blocks[i]):
result.append(blocks[i])
while i < len(raw):
op, element = raw[i]
if op == "equal":
result.append(element)
i += 1
continue

# Collect a consecutive run of deletions.
deletions: NodeList = []
while i < len(blocks) and "deletion" in _div_classes(blocks[i]):
deletions.append(blocks[i]["c"][1][0]) # unwrap inner block
if op == "insertion":
# Insertion not preceded by a deletion run: emit bare.
result.append(wrap_insertion(element))
i += 1

# Collect the immediately following run of insertions.
insertions: NodeList = []
while i < len(blocks) and "insertion" in _div_classes(blocks[i]):
insertions.append(blocks[i]["c"][1][0]) # unwrap inner block
continue
# op == "deletion". Gather the deletion run, then any immediately
# following insertion run.
deletions: List[Any] = []
while i < len(raw) and raw[i][0] == "deletion":
deletions.append(raw[i][1])
i += 1
insertions: List[Any] = []
while i < len(raw) and raw[i][0] == "insertion":
insertions.append(raw[i][1])
i += 1

# Pair 1-to-1 as substitutions; excess remain as bare deletions/insertions.
n_pairs = min(len(deletions), len(insertions))
for j in range(n_pairs):
d, ins = deletions[j], insertions[j]
if _is_list_node(d) and _is_list_node(ins) and d.get("t") == ins.get("t"):
result.append(diff_list_nodes(d, ins))
elif d.get("t") == "BlockQuote" and ins.get("t") == "BlockQuote":
result.append(diff_block_quote_nodes(d, ins))
elif d.get("t") == "LineBlock" and ins.get("t") == "LineBlock":
result.extend(diff_line_block_nodes(d, ins))
else:
extra_kv = _image_substitution_kv(d, ins)
result.append(make_substitution_div(d, ins, extra_kv=extra_kv))
for node in deletions[n_pairs:]:
result.append(add_diff_meta(node, "deletion"))
for node in insertions[n_pairs:]:
result.append(add_diff_meta(node, "insertion"))

result.extend(pair(deletions[j], insertions[j]))
for d in deletions[n_pairs:]:
result.append(wrap_deletion(d))
for ins in insertions[n_pairs:]:
result.append(wrap_insertion(ins))
return result


def _pair_blocks(d: PandocNode, ins: PandocNode) -> NodeList:
    """Pairing callback used by `diff_block_lists`.

    Given a deleted block ``d`` and the inserted block ``ins`` it was paired
    with, pick how to represent the change:

    * two list nodes with the same ``"t"`` tag -> `diff_list_nodes`;
    * two ``BlockQuote`` nodes -> `diff_block_quote_nodes`;
    * two ``LineBlock`` nodes -> whatever `diff_line_block_nodes` returns
      (it is returned as-is, so it must already be a list of blocks);
    * anything else -> a single substitution Div built by
      `make_substitution_div`, with extra key/values supplied by
      `_image_substitution_kv`.

    Returns the list of blocks to splice into the diff output.
    """
    old_tag = d.get("t")
    new_tag = ins.get("t")

    if _is_list_node(d) and _is_list_node(ins) and old_tag == new_tag:
        return [diff_list_nodes(d, ins)]
    if old_tag == new_tag == "BlockQuote":
        return [diff_block_quote_nodes(d, ins)]
    if old_tag == new_tag == "LineBlock":
        return diff_line_block_nodes(d, ins)

    extra_kv = _image_substitution_kv(d, ins)
    return [make_substitution_div(d, ins, extra_kv=extra_kv)]


# Node type tags for the two list container kinds.
# NOTE(review): presumably the set `_is_list_node` tests a node's "t" value
# against -- confirm against that helper.
LIST_TYPES = frozenset({"BulletList", "OrderedList"})


Expand Down Expand Up @@ -287,69 +299,18 @@ def diff_list_nodes(old_node: PandocNode, new_node: PandocNode) -> PandocNode:
old_items = _get_list_items(old_node)
new_items = _get_list_items(new_node)

# find_longest_common_subsequence serializes each element via json.dumps, so
# it works on list items (List[PandocNode]) just as well as on PandocNode.
# find_longest_common_subsequence and _merge_with_lcs both serialize each
# element via json.dumps, so they work on list items (List[PandocNode])
# just as well as on PandocNode.
lcs = find_longest_common_subsequence(old_items, new_items) # type: ignore
lcs_strs = {json.dumps(item, sort_keys=True) for item in lcs}

# Walk like diff_block_lists to produce an ordered stream of (op, item) pairs.
raw: List[Tuple[str, List[PandocNode]]] = []
ptr_a, ptr_b = 0, 0
while ptr_a < len(old_items) or ptr_b < len(new_items):
a = old_items[ptr_a] if ptr_a < len(old_items) else None
b = new_items[ptr_b] if ptr_b < len(new_items) else None
a_str = json.dumps(a, sort_keys=True) if a is not None else None
b_str = json.dumps(b, sort_keys=True) if b is not None else None

if a is not None and a_str not in lcs_strs:
raw.append(("deletion", a))
ptr_a += 1
elif b is not None and b_str not in lcs_strs:
raw.append(("insertion", b))
ptr_b += 1
elif a is not None and b is not None:
raw.append(("equal", a))
ptr_a += 1
ptr_b += 1
elif ptr_a < len(old_items):
raw.append(("deletion", old_items[ptr_a]))
ptr_a += 1
else:
raw.append(("insertion", new_items[ptr_b]))
ptr_b += 1

# Pair consecutive deletion+insertion runs into substitution items.
result_items: List[List[PandocNode]] = []
i = 0
while i < len(raw):
op, item = raw[i]
if op == "equal":
result_items.append(item)
i += 1
continue
if op == "insertion":
result_items.append([add_diff_meta(_item_to_block(item), "insertion")])
i += 1
continue

# Collect a run of deletions then the immediately following insertions.
deletions: List[List[PandocNode]] = []
while i < len(raw) and raw[i][0] == "deletion":
deletions.append(raw[i][1])
i += 1
insertions: List[List[PandocNode]] = []
while i < len(raw) and raw[i][0] == "insertion":
insertions.append(raw[i][1])
i += 1

n_pairs = min(len(deletions), len(insertions))
for j in range(n_pairs):
result_items.append(diff_block_lists(deletions[j], insertions[j]))
for del_item in deletions[n_pairs:]:
result_items.append([add_diff_meta(_item_to_block(del_item), "deletion")])
for ins_item in insertions[n_pairs:]:
result_items.append([add_diff_meta(_item_to_block(ins_item), "insertion")])

raw = _merge_with_lcs(old_items, new_items, lcs) # type: ignore

result_items = _pair_adjacent_changes(
raw,
pair=lambda d, ins: [diff_block_lists(d, ins)],
wrap_deletion=lambda item: [add_diff_meta(_item_to_block(item), "deletion")],
wrap_insertion=lambda item: [add_diff_meta(_item_to_block(item), "insertion")],
)
return _build_list_with_items(old_node, result_items)


Expand Down Expand Up @@ -397,41 +358,61 @@ def diff_block_lists(before_blocks: NodeList, after_blocks: NodeList) -> NodeLis
groups as substitution Divs for per-word inline diffing.
"""
lcs_nodes = find_longest_common_subsequence(before_blocks, after_blocks)
lcs_set = {json.dumps(node, sort_keys=True) for node in lcs_nodes}

merged_blocks: NodeList = []

ptr_a, ptr_b = 0, 0
while ptr_a < len(before_blocks) or ptr_b < len(after_blocks):
node_a = before_blocks[ptr_a] if ptr_a < len(before_blocks) else None
node_b = after_blocks[ptr_b] if ptr_b < len(after_blocks) else None
raw = _merge_with_lcs(before_blocks, after_blocks, lcs_nodes)
return _pair_adjacent_changes(
raw,
pair=_pair_blocks,
wrap_deletion=lambda b: add_diff_meta(b, "deletion"),
wrap_insertion=lambda b: add_diff_meta(b, "insertion"),
)


def _merge_with_lcs(
before: List[Any],
after: List[Any],
lcs_elements: List[Any],
) -> List[Tuple[str, Any]]:
"""Walk `before` and `after` in lockstep with the LCS, yielding an ordered
stream of ``(op, element)`` tuples in document order.

`op` is one of ``"equal"``, ``"deletion"`` or ``"insertion"``. Elements
not in the LCS are emitted as deletions (from `before`) or insertions
(from `after`); LCS elements are emitted as ``"equal"``. Within each
changed gap, deletions are emitted before insertions, which is what
`_pair_adjacent_changes` needs to fold them into substitution-style
annotations.

Elements may be Pandoc blocks (``PandocNode``) or list items
(``List[PandocNode]``); the only requirement is that ``json.dumps(element,
sort_keys=True)`` gives a stable equality key.
"""
lcs_set = {json.dumps(n, sort_keys=True) for n in lcs_elements}

node_a_str = json.dumps(node_a, sort_keys=True) if node_a else None
node_b_str = json.dumps(node_b, sort_keys=True) if node_b else None
raw: List[Tuple[str, Any]] = []
ptr_a = ptr_b = 0
while ptr_a < len(before) or ptr_b < len(after):
a = before[ptr_a] if ptr_a < len(before) else None
b = after[ptr_b] if ptr_b < len(after) else None
a_str = json.dumps(a, sort_keys=True) if a is not None else None
b_str = json.dumps(b, sort_keys=True) if b is not None else None

if node_a and node_a_str not in lcs_set:
# This node from 'before' is not in the LCS, so it was removed.
merged_blocks.append(add_diff_meta(node_a, "deletion"))
if a is not None and a_str not in lcs_set:
raw.append(("deletion", a))
ptr_a += 1
elif node_b and node_b_str not in lcs_set:
# This node from 'after' is not in the LCS, so it was added.
merged_blocks.append(add_diff_meta(node_b, "insertion"))
elif b is not None and b_str not in lcs_set:
raw.append(("insertion", b))
ptr_b += 1
elif node_a and node_b:
# Both nodes are in the LCS (and therefore identical).
merged_blocks.append(node_a)
elif a is not None and b is not None:
raw.append(("equal", a))
ptr_a += 1
ptr_b += 1
elif ptr_a < len(before_blocks):
# Exhausted 'after_blocks', remaining 'before' blocks are removals.
merged_blocks.append(add_diff_meta(before_blocks[ptr_a], "deletion"))
elif ptr_a < len(before):
raw.append(("deletion", before[ptr_a]))
ptr_a += 1
elif ptr_b < len(after_blocks):
# Exhausted 'before_blocks', remaining 'after' blocks are additions.
merged_blocks.append(add_diff_meta(after_blocks[ptr_b], "insertion"))
else:
raw.append(("insertion", after[ptr_b]))
ptr_b += 1

return _pair_adjacent_changes(merged_blocks)
return raw


def diff_ast_files(before_path, after_path, output_path):
Expand Down
Loading