Skip to content

Commit

Permalink
Merge pull request #247 from mkscrg/safer-traversor
Browse files Browse the repository at this point in the history
Support NodeVisitor.tail() removing Node
  • Loading branch information
scinfu committed Jun 29, 2023
2 parents 0e96a20 + 8cfe30a commit 213d22a
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 20 deletions.
37 changes: 19 additions & 18 deletions Sources/NodeTraversor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,29 @@ class NodeTraversor {
* Start a depth-first traverse of the root and all of its descendants.
* @param root the root node point to traverse.
*/
open func traverse(_ root: Node?)throws {
// NOTE(review): this span is a rendered diff without +/- markers. The signature above and the
// first copy of the loop body below appear to be the pre-change lines; the signature below and
// the second copy of the loop body appear to be the post-change lines.
open func traverse(_ root: Node?) throws {
// `node` is the cursor of the walk; `depth` is its depth relative to `root` (the root is 0).
var node: Node? = root
var depth: Int = 0

// A nil root never enters the loop, so no visitor callback is invoked for it.
while (node != nil) {
// --- First copy of the body: parent and next sibling are read AFTER visitor.tail() returns. ---
try visitor.head(node!, depth)
if (node!.childNodeSize() > 0) {
node = node!.childNode(0)
depth+=1
} else {
while (node!.nextSibling() == nil && depth > 0) {
try visitor.tail(node!, depth)
node = node!.getParentNode()
depth-=1
}
try visitor.tail(node!, depth)
if (node === root) {
break
}
node = node!.nextSibling()
}
// --- Second copy of the body: parent and next sibling are captured BEFORE visitor.tail() is
// called, so the walk no longer queries the node after tail() has run (per the commit title,
// this is what lets tail() remove the node). ---
try visitor.head(node!, depth)
if (node!.childNodeSize() > 0) {
// The node has children: descend to the first child.
node = node!.childNode(0)
depth+=1
} else {
// Leaf: while the current node has no next sibling and is below the root, call tail() on it
// and climb to its parent.
while (node!.nextSibling() == nil && depth > 0) {
let parent = node!.getParentNode()
try visitor.tail(node!, depth)
node = parent
depth-=1
}
// Read the next sibling first, then call tail() on the node itself.
let nextSib = node!.nextSibling()
try visitor.tail(node!, depth)
// Identity comparison: once tail() has been called on the root, the traversal is finished.
if (node === root) {
break
}
node = nextSib
}
}
}

}
4 changes: 2 additions & 2 deletions Sources/NodeVisitor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Foundation
*/
public protocol NodeVisitor {
/**
* Callback for when a node is first visited.
* Callback for when a node is first visited. {@code head} cannot safely call {@code node.remove()}.
*
* @param node the node being visited.
* @param depth the depth of the node, relative to the root node. E.g., the root node has depth 0, and a child node
Expand All @@ -27,7 +27,7 @@ public protocol NodeVisitor {
func head(_ node: Node, _ depth: Int)throws

/**
* Callback for when a node is last visited, after all of its descendants have been visited.
* Callback for when a node is last visited, after all of its descendants have been visited. {@code tail} can safely call {@code node.remove()}.
*
* @param node the node being visited.
* @param depth the depth of the node, relative to the root node. E.g., the root node has depth 0, and a child node
Expand Down
86 changes: 86 additions & 0 deletions Tests/SwiftSoupTests/NodeTraversorTest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import XCTest
@testable import SwiftSoup

/// Tests for node traversal, exercised through `traverse(_:)` on a parsed document body.
class NodeTraversorTest: XCTestCase {
    /// Checks the order in which `head` and `tail` are invoked for a small document.
    ///
    /// - Throws: Any parsing or traversal error, which XCTest reports as a test failure.
    func testTraverseOrder() throws {
        /// Records every node passed to `head` and `tail`, in call order.
        class TestVisitor: NodeVisitor {
            var heads: [Node] = []
            var tails: [Node] = []

            func head(_ node: Node, _ depth: Int) throws {
                heads.append(node)
            }

            func tail(_ node: Node, _ depth: Int) throws {
                tails.append(node)
            }
        }

        let html = "<p id=1><b id=2>3</b>4</p><p id=5>6</p>"
        let doc = try SwiftSoup.parse(html)
        // Fail the test (rather than trap) if the parsed document has no body.
        let body = try XCTUnwrap(doc.body())

        let tv = TestVisitor()
        try body.traverse(tv)

        assertNodeDescsMatch(
            [.e(""), .e("1"), .e("2"), .t("3"), .t("4"), .e("5"), .t("6")],
            tv.heads,
            "head() order"
        )
        assertNodeDescsMatch(
            [.t("3"), .e("2"), .t("4"), .e("1"), .t("6"), .e("5"), .e("")],
            tv.tails,
            "tail() order"
        )
    }

    /// Checks that a visitor may remove the visited node from inside `tail` and that the
    /// remaining siblings are left in place.
    ///
    /// - Throws: Any parsing, traversal or serialization error, which XCTest reports as a test failure.
    func testTailCanRemoveNode() throws {
        /// Removes the element whose id is "3" when it is visited in `tail`.
        class TestVisitor: NodeVisitor {
            func head(_ node: Node, _ depth: Int) throws {
                // no-op
            }

            func tail(_ node: Node, _ depth: Int) throws {
                if let elt = node as? Element, elt.id() == "3" {
                    try elt.remove()
                }
            }
        }

        let html = "<p id=1>2</p><p id=3>4</p><p id=5>6</p>"
        let doc = try SwiftSoup.parse(html)
        let body = try XCTUnwrap(doc.body())

        try body.traverse(TestVisitor())

        let expectedHtml = "<p id=1>2</p><p id=5>6</p>"
        let expectedDoc = try SwiftSoup.parse(expectedHtml)
        let expectedBody = try XCTUnwrap(expectedDoc.body())
        XCTAssertEqual(try expectedBody.html(), try body.html())
    }

    /// Asserts that `nodes` matches `descs` position by position.
    ///
    /// - Parameters:
    ///   - descs: The expected node descriptions, in order.
    ///   - nodes: The nodes actually recorded.
    ///   - label: A prefix added to every failure message.
    private func assertNodeDescsMatch(_ descs: [NodeDesc], _ nodes: [Node], _ label: String) {
        XCTAssertEqual(nodes.count, descs.count, "\(label): nodes.count == descs.count")
        // zip stops at the shorter sequence, so a count mismatch is reported by the assertion
        // above without indexing past the end of either array.
        for (i, (desc, node)) in zip(descs, nodes).enumerated() {
            switch desc {
            case .element(let id):
                // A conditional cast records a failure instead of trapping on a type mismatch.
                guard let elt = node as? Element else {
                    XCTFail("\(label): nodes[\(i)] is Element")
                    continue
                }
                XCTAssertEqual(id, elt.id(), "\(label): nodes[\(i)].id()")
            case .text(let text):
                guard let tnode = node as? TextNode else {
                    XCTFail("\(label): nodes[\(i)] is TextNode")
                    continue
                }
                XCTAssertEqual(text, tnode.text(), "\(label): nodes[\(i)].text()")
            }
        }
    }
}

/// Compact description of an expected node, compared against real nodes by `assertNodeDescsMatch`.
private enum NodeDesc {
    /// An element; the payload is compared against the element's `id()`.
    case element(_ id: String)
    /// A text node; the payload is compared against the node's `text()`.
    case text(_ text: String)

    /// Shorthand for `.element(id)`.
    static func e(_ id: String) -> NodeDesc {
        return .element(id)
    }

    /// Shorthand for `.text(text)`.
    static func t(_ text: String) -> NodeDesc {
        return .text(text)
    }
}

0 comments on commit 213d22a

Please sign in to comment.