Skip to content

Commit

Permalink
fix: always include start node in neighbours search
Browse files Browse the repository at this point in the history
Previously was returning an empty graph if there were no neighbours.
  • Loading branch information
alanconway committed Sep 6, 2024
1 parent 89dfa95 commit 75da48f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 21 deletions.
8 changes: 8 additions & 0 deletions pkg/graph/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ type rule = korrel8r.Rule

func c(i int) korrel8r.Class { return Domain.Class(strconv.Itoa(i)) }

func nodesToInts(nodes []*Node) (ret []int) {
for _, n := range nodes {
i,_ := strconv.Atoi(n.Class.Name())
ret = append(ret, i)
}
return ret
}

func testGraph(rules []korrel8r.Rule) *Graph {
d := NewData()
for _, r := range rules {
Expand Down
1 change: 1 addition & 0 deletions pkg/graph/traverse.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func (g *Graph) Traverse(start korrel8r.Class, goals []korrel8r.Class, f func(*L
// Returns the subset of the graph that was traversed.
func (g *Graph) Neighbours(start korrel8r.Class, depth int, f func(*Line) bool) (*Graph, error) {
sub := g.Data.EmptyGraph()
sub.AddNode(g.NodeFor(start))
atDepth := 0
current := unique.Set[int64]{} // Nodes at the current depth or above.

Expand Down
47 changes: 26 additions & 21 deletions pkg/graph/traverse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,13 @@ import (
"testing"

"github.com/korrel8r/korrel8r/pkg/korrel8r"
"github.com/korrel8r/korrel8r/pkg/unique"
"github.com/stretchr/testify/assert"
)

type collecter struct {
rules []string
classes unique.Set[string]
}
type ruleCollecter struct { rules []string }

func (c *collecter) Traverse(l *Line) bool {
func (c *ruleCollecter) Traverse(l *Line) bool {
c.rules = append(c.rules, RuleFor(l).Name())
if c.classes == nil {
c.classes = unique.NewSet[string]()
}
c.classes.Add(ClassFor(l.From()).String())
c.classes.Add(ClassFor(l.To()).String())
return true
}

Expand All @@ -34,31 +25,34 @@ func TestTraverse(t *testing.T) {
name string
graph []rule
rules [][]string // inner slices are are unordered components.
classes unique.Set[string]
nodes []int
}{
{
name: "multipath",
graph: []rule{r(1, 11), r(1, 12), r(11, 99), r(12, 99)},
rules: [][]string{{"1_11", "1_12"}, {"11_99", "12_99"}},
classes: unique.NewSet("1", "11", "12", "99"),
nodes: []int{1, 11, 12, 99},
},
{
name: "simple",
graph: []rule{r(1, 2), r(2, 3), r(3, 4), r(4, 5)},
rules: [][]string{{"1_2"}, {"2_3"}, {"3_4"}, {"4_5"}},
nodes: []int{1,2,3,4,5},
},
{
name: "cycle", // cycle of 2,3,4
graph: []rule{r(1, 2), r(2, 3), r(3, 4), r(4, 2), r(4, 5)},
rules: [][]string{{"1_2"}, {"2_3", "3_4", "4_2", "4_5"}},
nodes: []int{1,2,3,4,5},
},
} {
t.Run(x.name, func(t *testing.T) {
g := testGraph(x.graph)
var got collecter
var got ruleCollecter
_, err := g.Traverse(x.graph[0].Start()[0], x.graph[len(x.graph)-1].Goal(), got.Traverse)
assert.NoError(t, err)
assertComponentOrder(t, x.rules, got.rules)
assert.ElementsMatch(t, x.nodes, nodesToInts(g.AllNodes()))
})
}
}
Expand All @@ -70,26 +64,37 @@ func TestNeighbours(t *testing.T) {
g := testGraph([]rule{r(1, 11), r(11, 1), r(1, 12), r(1, 13), r(11, 22), r(12, 22), r(12, 13), r(22, 99)})
for _, x := range []struct {
depth int
want [][]string
rules [][]string
nodes []int
}{
{
depth: 0,
rules: nil,
nodes: []int{1},
},
{
depth: 1,
want: [][]string{{"1_11", "1_12", "1_13"}},
rules: [][]string{{"1_11", "1_12", "1_13"}},
nodes: []int{1,11,12,13},
},
{
depth: 2,
want: [][]string{{"1_11", "1_12", "1_13"}, {"11_22", "12_22"}},
rules: [][]string{{"1_11", "1_12", "1_13"}, {"11_22", "12_22"}},
nodes: []int{1,11,12,13,22},
},
{
depth: 3,
want: [][]string{{"1_11", "1_12", "1_13"}, {"11_22", "12_22"}, {"22_99"}},
rules: [][]string{{"1_11", "1_12", "1_13"}, {"11_22", "12_22"}, {"22_99"}},
nodes: []int{1,11,12,13,22,99},
},
} {
t.Run(fmt.Sprintf("depth=%v", x.depth), func(t *testing.T) {
var got collecter
_, err := g.Neighbours(c(1), x.depth, got.Traverse)
var got ruleCollecter
g2, err := g.Neighbours(c(1), x.depth, got.Traverse)
assert.NoError(t, err)
assertComponentOrder(t, x.want, got.rules)
assertComponentOrder(t, x.rules, got.rules)
assert.ElementsMatch(t, x.nodes, nodesToInts(g2.AllNodes()))
})
}
}

0 comments on commit 75da48f

Please sign in to comment.