diff --git a/pkg/graph/helpers_test.go b/pkg/graph/helpers_test.go index 7ca27d0..2509944 100644 --- a/pkg/graph/helpers_test.go +++ b/pkg/graph/helpers_test.go @@ -18,6 +18,14 @@ type rule = korrel8r.Rule func c(i int) korrel8r.Class { return Domain.Class(strconv.Itoa(i)) } +func nodesToInts(nodes []*Node) (ret []int) { + for _, n := range nodes { + i,_ := strconv.Atoi(n.Class.Name()) + ret = append(ret, i) + } + return ret +} + func testGraph(rules []korrel8r.Rule) *Graph { d := NewData() for _, r := range rules { diff --git a/pkg/graph/traverse.go b/pkg/graph/traverse.go index b0c9c65..48cd2e4 100644 --- a/pkg/graph/traverse.go +++ b/pkg/graph/traverse.go @@ -32,6 +32,7 @@ func (g *Graph) Traverse(start korrel8r.Class, goals []korrel8r.Class, f func(*L // Returns the subset of the graph that was traversed. func (g *Graph) Neighbours(start korrel8r.Class, depth int, f func(*Line) bool) (*Graph, error) { sub := g.Data.EmptyGraph() + sub.AddNode(g.NodeFor(start)) atDepth := 0 current := unique.Set[int64]{} // Nodes at the current depth or above. diff --git a/pkg/graph/traverse_test.go b/pkg/graph/traverse_test.go index e56db1f..03f62fc 100644 --- a/pkg/graph/traverse_test.go +++ b/pkg/graph/traverse_test.go @@ -7,22 +7,13 @@ import ( "testing" "github.com/korrel8r/korrel8r/pkg/korrel8r" - "github.com/korrel8r/korrel8r/pkg/unique" "github.com/stretchr/testify/assert" ) -type collecter struct { - rules []string - classes unique.Set[string] -} +type ruleCollecter struct { rules []string } -func (c *collecter) Traverse(l *Line) bool { +func (c *ruleCollecter) Traverse(l *Line) bool { c.rules = append(c.rules, RuleFor(l).Name()) - if c.classes == nil { - c.classes = unique.NewSet[string]() - } - c.classes.Add(ClassFor(l.From()).String()) - c.classes.Add(ClassFor(l.To()).String()) return true } @@ -34,31 +25,34 @@ func TestTraverse(t *testing.T) { name string graph []rule rules [][]string // inner slices are are unordered components. - classes unique.Set[string] + nodes []int }{ { name: "multipath", graph: []rule{r(1, 11), r(1, 12), r(11, 99), r(12, 99)}, rules: [][]string{{"1_11", "1_12"}, {"11_99", "12_99"}}, - classes: unique.NewSet("1", "11", "12", "99"), + nodes: []int{1, 11, 12, 99}, }, { name: "simple", graph: []rule{r(1, 2), r(2, 3), r(3, 4), r(4, 5)}, rules: [][]string{{"1_2"}, {"2_3"}, {"3_4"}, {"4_5"}}, + nodes: []int{1,2,3,4,5}, }, { name: "cycle", // cycle of 2,3,4 graph: []rule{r(1, 2), r(2, 3), r(3, 4), r(4, 2), r(4, 5)}, rules: [][]string{{"1_2"}, {"2_3", "3_4", "4_2", "4_5"}}, + nodes: []int{1,2,3,4,5}, }, } { t.Run(x.name, func(t *testing.T) { g := testGraph(x.graph) - var got collecter + var got ruleCollecter _, err := g.Traverse(x.graph[0].Start()[0], x.graph[len(x.graph)-1].Goal(), got.Traverse) assert.NoError(t, err) assertComponentOrder(t, x.rules, got.rules) + assert.ElementsMatch(t, x.nodes, nodesToInts(g.AllNodes())) }) } } @@ -70,26 +64,37 @@ func TestNeighbours(t *testing.T) { g := testGraph([]rule{r(1, 11), r(11, 1), r(1, 12), r(1, 13), r(11, 22), r(12, 22), r(12, 13), r(22, 99)}) for _, x := range []struct { depth int - want [][]string + rules [][]string + nodes []int }{ + { + depth: 0, + rules: nil, + nodes: []int{1}, + }, { depth: 1, - want: [][]string{{"1_11", "1_12", "1_13"}}, + rules: [][]string{{"1_11", "1_12", "1_13"}}, + nodes: []int{1,11,12,13}, }, { depth: 2, - want: [][]string{{"1_11", "1_12", "1_13"}, {"11_22", "12_22"}}, + rules: [][]string{{"1_11", "1_12", "1_13"}, {"11_22", "12_22"}}, + nodes: []int{1,11,12,13,22}, }, { depth: 3, - want: [][]string{{"1_11", "1_12", "1_13"}, {"11_22", "12_22"}, {"22_99"}}, + rules: [][]string{{"1_11", "1_12", "1_13"}, {"11_22", "12_22"}, {"22_99"}}, + nodes: []int{1,11,12,13,22,99}, }, } { t.Run(fmt.Sprintf("depth=%v", x.depth), func(t *testing.T) { - var got collecter - _, err := g.Neighbours(c(1), x.depth, got.Traverse) + var got ruleCollecter + g2, err := g.Neighbours(c(1), x.depth, got.Traverse) assert.NoError(t, err) - assertComponentOrder(t, x.want, got.rules) + assertComponentOrder(t, x.rules, got.rules) + assert.ElementsMatch(t, x.nodes, nodesToInts(g2.AllNodes())) }) } } +