From 346f14f31ebedea607322bd4c2eb551df518a099 Mon Sep 17 00:00:00 2001 From: Jussi Maki Date: Thu, 26 Sep 2024 10:42:44 +0200 Subject: [PATCH] testutils: Add script-based test utilities Add generic testscript commands for testing against StateDB tables. This allows implementing tests as scripts, which becomes useful when tests perform multiple steps on tables and need to verify the output each step. Signed-off-by: Jussi Maki --- any_table.go | 63 +++++ db.go | 19 ++ go.mod | 4 +- go.sum | 2 + iterator.go | 27 +- table.go | 23 ++ testutils/script.go | 494 ++++++++++++++++++++++++++++++++++ testutils/script_test.go | 62 +++++ testutils/testdata/test.txtar | 87 ++++++ types.go | 52 ++-- 10 files changed, 807 insertions(+), 26 deletions(-) create mode 100644 any_table.go create mode 100644 testutils/script.go create mode 100644 testutils/script_test.go create mode 100644 testutils/testdata/test.txtar diff --git a/any_table.go b/any_table.go new file mode 100644 index 0000000..1e6efe2 --- /dev/null +++ b/any_table.go @@ -0,0 +1,63 @@ +package statedb + +import ( + "iter" +) + +// AnyTable allows any-typed access to a StateDB table. This is intended +// for building generic tooling for accessing the table and should be +// avoided if possible. +type AnyTable struct { + Meta TableMeta +} + +func (t AnyTable) All(txn ReadTxn) iter.Seq2[any, Revision] { + indexTxn := txn.getTxn().mustIndexReadTxn(t.Meta, PrimaryIndexPos) + return anySeq(indexTxn.Iterator()) +} + +func (t AnyTable) UnmarshalYAML(data []byte) (any, error) { + return t.Meta.unmarshalYAML(data) +} + +func (t AnyTable) Insert(txn WriteTxn, obj any) (old any, hadOld bool, err error) { + var iobj object + iobj, hadOld, err = txn.getTxn().insert(t.Meta, Revision(0), obj) + if hadOld { + old = iobj.data + } + return +} + +func (t AnyTable) Delete(txn WriteTxn, obj any) (old any, hadOld bool, err error) { + var iobj object + iobj, hadOld, err = txn.getTxn().delete(t.Meta, Revision(0), obj) + if hadOld { + old = iobj.data + } + return +} + +func (t AnyTable) Prefix(txn ReadTxn, key string) iter.Seq2[any, Revision] { + indexTxn := txn.getTxn().mustIndexReadTxn(t.Meta, PrimaryIndexPos) + iter, _ := indexTxn.Prefix([]byte(key)) + return anySeq(iter) +} + +func (t AnyTable) LowerBound(txn ReadTxn, key string) iter.Seq2[any, Revision] { + indexTxn := txn.getTxn().mustIndexReadTxn(t.Meta, PrimaryIndexPos) + iter := indexTxn.LowerBound([]byte(key)) + return anySeq(iter) +} + +func (t AnyTable) TableHeader() []string { + zero := t.Meta.proto() + if tw, ok := zero.(TableWritable); ok { + return tw.TableHeader() + } + return nil +} + +func (t AnyTable) Proto() any { + return t.Meta.proto() +} diff --git a/db.go b/db.go index 2ec4d62..9d8b0cb 100644 --- a/db.go +++ b/db.go @@ -247,6 +247,25 @@ func (db *DB) WriteTxn(table TableMeta, tables ...TableMeta) WriteTxn { return txn } +func (db *DB) GetTables(txn ReadTxn) (tbls []TableMeta) { + root := txn.getTxn().root + tbls = make([]TableMeta, 0, len(root)) + for _, table := range root { + tbls = append(tbls, table.meta) + } + return +} + +func (db *DB) GetTable(txn ReadTxn, name string) TableMeta { + root := txn.getTxn().root + for _, table := range root { + if table.meta.Name() == name { + return table.meta + } + } + return nil +} + // Start the background workers for the database. // // This starts the graveyard worker that deals with garbage collecting diff --git a/go.mod b/go.mod index 9726df7..7db8b26 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,13 @@ go 1.23 require ( github.com/cilium/hive v0.0.0-20240209163124-bd6ebb4ec11d github.com/cilium/stream v0.0.0-20240209152734-a0792b51812d + github.com/rogpeppe/go-internal v1.11.0 github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 go.uber.org/goleak v1.3.0 golang.org/x/time v0.5.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -34,7 +36,7 @@ require ( golang.org/x/sys v0.17.0 // indirect golang.org/x/term v0.16.0 // indirect golang.org/x/text v0.14.0 // indirect + golang.org/x/tools v0.17.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/ini.v1 v1.67.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9a471c1..a4219d4 100644 --- a/go.sum +++ b/go.sum @@ -77,6 +77,8 @@ golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= +golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/iterator.go b/iterator.go index 08b4e79..6bacef8 100644 --- a/iterator.go +++ b/iterator.go @@ -53,7 +53,7 @@ func ToSeq[A, B any](seq iter.Seq2[A, B]) iter.Seq[A] { } } -// partSeq returns a sequence of objects from a part Iterator. +// partSeq returns a casted sequence of objects from a part Iterator. func partSeq[Obj any](iter *part.Iterator[object]) iter.Seq2[Obj, Revision] { return func(yield func(Obj, Revision) bool) { // Iterate over a clone of the original iterator to allow the sequence to be iterated @@ -71,6 +71,24 @@ func partSeq[Obj any](iter *part.Iterator[object]) iter.Seq2[Obj, Revision] { } } +// anySeq returns a sequence of objects from a part Iterator. +func anySeq(iter *part.Iterator[object]) iter.Seq2[any, Revision] { + return func(yield func(any, Revision) bool) { + // Iterate over a clone of the original iterator to allow the sequence to be iterated + // from scratch multiple times. + it := iter.Clone() + for { + _, iobj, ok := it.Next() + if !ok { + break + } + if !yield(iobj.data, iobj.revision) { + break + } + } + } +} + // nonUniqueSeq returns a sequence of objects for a non-unique index. // Non-unique indexes work by concatenating the secondary key with the // primary key and then prefix searching for the items: @@ -128,6 +146,13 @@ func (it *iterator[Obj]) Next() (obj Obj, revision uint64, ok bool) { return } +// Iterator for iterating a sequence objects. +type Iterator[Obj any] interface { + // Next returns the next object and its revision if ok is true, otherwise + // zero values to mean that the iteration has finished. + Next() (obj Obj, rev Revision, ok bool) +} + func NewDualIterator[Obj any](left, right Iterator[Obj]) *DualIterator[Obj] { return &DualIterator[Obj]{ left: iterState[Obj]{iter: left}, diff --git a/table.go b/table.go index 1402e93..c1069ed 100644 --- a/table.go +++ b/table.go @@ -14,6 +14,7 @@ import ( "github.com/cilium/statedb/internal" "github.com/cilium/statedb/part" + "gopkg.in/yaml.v3" "github.com/cilium/statedb/index" ) @@ -184,6 +185,15 @@ func (t *genTable[Obj]) Name() string { return t.table } +func (t *genTable[Obj]) Indexes() []string { + idxs := make([]string, 0, 1+len(t.secondaryAnyIndexers)) + idxs = append(idxs, t.primaryAnyIndexer.name) + for k := range t.secondaryAnyIndexers { + idxs = append(idxs, k) + } + return idxs +} + func (t *genTable[Obj]) ToTable() Table[Obj] { return t } @@ -468,5 +478,18 @@ func (t *genTable[Obj]) sortableMutex() internal.SortableMutex { return t.smu } +func (t *genTable[Obj]) proto() any { + var zero Obj + return zero +} + +func (t *genTable[Obj]) unmarshalYAML(data []byte) (any, error) { + var obj Obj + if err := yaml.Unmarshal(data, &obj); err != nil { + return nil, err + } + return obj, nil +} + var _ Table[bool] = &genTable[bool]{} var _ RWTable[bool] = &genTable[bool]{} diff --git a/testutils/script.go b/testutils/script.go new file mode 100644 index 0000000..0120979 --- /dev/null +++ b/testutils/script.go @@ -0,0 +1,494 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package testutils + +import ( + "bytes" + "cmp" + "encoding/json" + "flag" + "fmt" + "iter" + "maps" + "os" + "regexp" + "slices" + "strings" + "text/tabwriter" + "time" + + "github.com/cilium/statedb" + "github.com/rogpeppe/go-internal/testscript" + "gopkg.in/yaml.v3" +) + +type Cmd = func(ts *testscript.TestScript, neg bool, args []string) + +const tsDBKey = "statedb" + +func Setup(e *testscript.Env, db *statedb.DB) { + e.Values[tsDBKey] = db +} + +func getDB(ts *testscript.TestScript) *statedb.DB { + v := ts.Value(tsDBKey) + if v == nil { + ts.Fatalf("%q not set, call testutils.Setup()", tsDBKey) + } + return v.(*statedb.DB) +} + +func getTable(ts *testscript.TestScript, tableName string) (*statedb.DB, statedb.ReadTxn, statedb.AnyTable) { + db := getDB(ts) + txn := db.ReadTxn() + meta := db.GetTable(txn, tableName) + if meta == nil { + ts.Fatalf("table %q not found", tableName) + } + tbl := statedb.AnyTable{Meta: meta} + return db, txn, tbl +} + +var ( + Commands = map[string]Cmd{ + "db": DBCmd, + } + SubCommands = map[string]Cmd{ + "tables": TablesCmd, + "show": ShowTableCmd, + "write": WriteTableCmd, + "cmp": CompareTableCmd, + "insert": InsertCmd, + "delete": DeleteCmd, + "prefix": PrefixCmd, + "lowerbound": LowerBoundCmd, + } +) + +func DBCmd(ts *testscript.TestScript, neg bool, args []string) { + if len(args) < 1 { + ts.Fatalf("usage: db args...\n is one of %v", maps.Keys(SubCommands)) + } + if cmd, ok := SubCommands[args[0]]; ok { + cmd(ts, neg, args[1:]) + } else { + ts.Fatalf("unknown db command %q, should be one of %v", args[0], maps.Keys(SubCommands)) + } +} + +func TablesCmd(ts *testscript.TestScript, neg bool, args []string) { + db := getDB(ts) + txn := db.ReadTxn() + tbls := db.GetTables(txn) + var buf bytes.Buffer + w := tabwriter.NewWriter(&buf, 5, 4, 3, ' ', 0) + fmt.Fprintf(w, "Name\tObject count\tIndexes\n") + for _, tbl := range tbls { + idxs := strings.Join(tbl.Indexes(), ", ") + fmt.Fprintf(w, "%s\t%d\t%s\n", tbl.Name(), tbl.NumObjects(txn), idxs) + } + w.Flush() + ts.Logf("%s", buf.String()) +} + +func ShowTableCmd(ts *testscript.TestScript, neg bool, args []string) { + if len(args) != 1 { + ts.Fatalf("usage: show_table ") + } + ts.Logf("%s", showTable(ts, args[0]).String()) +} + +func WriteTableCmd(ts *testscript.TestScript, neg bool, args []string) { + if len(args) < 1 || len(args) > 5 { + ts.Fatalf("usage: write_table
(-to=) (-columns=) (-format={table*,yaml})") + } + var flags flag.FlagSet + file := flags.String("to", "", "File to write to instead of stdout") + columns := flags.String("columns", "", "Comma-separated list of columns to write") + format := flags.String("format", "table", "Format to write in") + + // Sort the args to allow the table name at any position. + slices.SortFunc(args, func(a, b string) int { + switch { + case a[0] == '-': + return 1 + case b[0] == '-': + return -1 + default: + return cmp.Compare(a, b) + } + }) + + if err := flags.Parse(args[1:]); err != nil { + ts.Fatalf("bad args: %s", err) + } + tableName := args[0] + + switch *format { + case "yaml", "json": + if len(*columns) > 0 { + ts.Fatalf("-columns not supported with -format=yaml/json") + } + + _, txn, tbl := getTable(ts, tableName) + var buf bytes.Buffer + count := tbl.Meta.NumObjects(txn) + for obj := range tbl.All(txn) { + if *format == "yaml" { + out, err := yaml.Marshal(obj) + if err != nil { + ts.Fatalf("yaml.Marshal: %s", err) + } + buf.Write(out) + if count > 1 { + buf.WriteString("---\n") + } + } else { + out, err := json.Marshal(obj) + if err != nil { + ts.Fatalf("json.Marshal: %s", err) + } + buf.Write(out) + buf.WriteByte('\n') + } + count-- + } + if *file == "" { + ts.Logf("%s", buf.String()) + } else if err := os.WriteFile(ts.MkAbs(*file), buf.Bytes(), 0644); err != nil { + ts.Fatalf("WriteFile(%s): %s", *file, err) + } + default: + var cols []string + if len(*columns) > 0 { + cols = strings.Split(*columns, ",") + } + buf := showTable(ts, tableName, cols...) + if *file == "" { + ts.Logf("%s", buf.String()) + } else if err := os.WriteFile(ts.MkAbs(*file), buf.Bytes(), 0644); err != nil { + ts.Fatalf("WriteFile(%s): %s", *file, err) + } + } +} + +func CompareTableCmd(ts *testscript.TestScript, neg bool, args []string) { + var flags flag.FlagSet + timeout := flags.Duration("timeout", time.Second, "Maximum amount of time to wait for the table contents to match") + grep := flags.String("grep", "", "Grep the result rows and only compare matching ones") + + err := flags.Parse(args) + args = args[len(args)-flags.NArg():] + if err != nil || len(args) != 2 { + ts.Fatalf("usage: cmp (-timeout=) (-grep=)
") + } + + var grepRe *regexp.Regexp + if *grep != "" { + grepRe, err = regexp.Compile(*grep) + if err != nil { + ts.Fatalf("bad grep: %s", err) + } + } + + tableName := args[0] + db, _, tbl := getTable(ts, tableName) + header := tbl.TableHeader() + + data := ts.ReadFile(args[1]) + lines := strings.Split(data, "\n") + lines = slices.DeleteFunc(lines, func(line string) bool { + return strings.TrimSpace(line) == "" + }) + if len(lines) < 1 { + ts.Fatalf("%q missing header line, e.g. %q", args[1], strings.Join(header, " ")) + } + + columnNames, columnPositions := splitHeaderLine(lines[0]) + columnIndexes, err := getColumnIndexes(columnNames, header) + if err != nil { + ts.Fatalf("%s", err) + } + lines = lines[1:] + origLines := lines + tryUntil := time.Now().Add(*timeout) + + for { + lines = origLines + + // Create the diff between 'lines' and the rows in the table. + equal := true + var diff bytes.Buffer + w := tabwriter.NewWriter(&diff, 5, 4, 3, ' ', 0) + fmt.Fprintf(w, " %s\n", joinByPositions(columnNames, columnPositions)) + + for obj := range tbl.All(db.ReadTxn()) { + rowRaw := takeColumns(obj.(statedb.TableWritable).TableRow(), columnIndexes) + row := joinByPositions(rowRaw, columnPositions) + if grepRe != nil && !grepRe.Match([]byte(row)) { + continue + } + + if len(lines) == 0 { + equal = false + fmt.Fprintf(w, "- %s\n", row) + continue + } + line := lines[0] + splitLine := splitByPositions(line, columnPositions) + + if slices.Equal(rowRaw, splitLine) { + fmt.Fprintf(w, " %s\n", row) + } else { + fmt.Fprintf(w, "- %s\n", row) + fmt.Fprintf(w, "+ %s\n", line) + equal = false + } + lines = lines[1:] + } + for _, line := range lines { + fmt.Fprintf(w, "+ %s\n", line) + equal = false + } + if equal { + return + } + w.Flush() + + if time.Now().After(tryUntil) { + ts.Fatalf("table mismatch:\n%s", diff.String()) + } + time.Sleep(10 * time.Millisecond) + } +} + +func InsertCmd(ts *testscript.TestScript, neg bool, args []string) { + insertOrDeleteCmd(ts, true, args) +} + +func DeleteCmd(ts *testscript.TestScript, neg bool, args []string) { + insertOrDeleteCmd(ts, false, args) +} + +func insertOrDeleteCmd(ts *testscript.TestScript, insert bool, args []string) { + if len(args) < 2 { + if insert { + ts.Fatalf("usage: insert
path...") + } else { + ts.Fatalf("usage: delete
path...") + } + } + + db, _, tbl := getTable(ts, args[0]) + wtxn := db.WriteTxn(tbl.Meta) + defer wtxn.Commit() + + for _, arg := range args[1:] { + data := ts.ReadFile(arg) + parts := strings.Split(data, "---") + for _, part := range parts { + obj, err := tbl.UnmarshalYAML([]byte(part)) + if err != nil { + ts.Fatalf("Unmarshal(%s): %s", arg, err) + } + if insert { + _, _, err = tbl.Insert(wtxn, obj) + if err != nil { + ts.Fatalf("Insert(%s): %s", arg, err) + } + } else { + _, _, err = tbl.Delete(wtxn, obj) + if err != nil { + ts.Fatalf("Delete(%s): %s", arg, err) + } + + } + } + } +} + +func PrefixCmd(ts *testscript.TestScript, neg bool, args []string) { + prefixOrLowerboundCmd(ts, false, args) +} + +func LowerBoundCmd(ts *testscript.TestScript, neg bool, args []string) { + prefixOrLowerboundCmd(ts, true, args) +} + +func prefixOrLowerboundCmd(ts *testscript.TestScript, lowerbound bool, args []string) { + db := getDB(ts) + if len(args) < 2 { + if lowerbound { + ts.Fatalf("usage: lowerbound
(-to=)") + } else { + ts.Fatalf("usage: prefix
(-to=)") + } + } + + var flags flag.FlagSet + file := flags.String("to", "", "File to write to instead of stdout") + if err := flags.Parse(args[2:]); err != nil { + ts.Fatalf("bad args: %s", err) + } + + txn := db.ReadTxn() + meta := db.GetTable(txn, args[0]) + if meta == nil { + ts.Fatalf("table %q not found", args[0]) + } + tbl := statedb.AnyTable{Meta: meta} + var buf bytes.Buffer + w := tabwriter.NewWriter(&buf, 5, 4, 3, ' ', 0) + header := tbl.TableHeader() + fmt.Fprintf(w, "%s\n", strings.Join(header, "\t")) + + var it iter.Seq2[any, uint64] + if lowerbound { + it = tbl.LowerBound(txn, args[1]) + } else { + it = tbl.Prefix(txn, args[1]) + } + + for obj := range it { + row := obj.(statedb.TableWritable).TableRow() + fmt.Fprintf(w, "%s\n", strings.Join(row, "\t")) + } + w.Flush() + if *file == "" { + ts.Logf("%s", buf.String()) + } else if err := os.WriteFile(ts.MkAbs(*file), buf.Bytes(), 0644); err != nil { + ts.Fatalf("WriteFile(%s): %s", *file, err) + } +} + +// splitHeaderLine takes a header of column names separated by any +// number of whitespaces and returns the names and their starting positions. +// e.g. "Foo Bar Baz" would result in ([Foo,Bar,Baz],[0,5,9]). +// With this information we can take a row in the database and format it +// the same way as our test data. +func splitHeaderLine(line string) (names []string, pos []int) { + start := 0 + skip := true + for i, r := range line { + switch r { + case ' ', '\t': + if !skip { + names = append(names, line[start:i]) + pos = append(pos, start) + start = -1 + } + skip = true + default: + skip = false + if start == -1 { + start = i + } + } + } + if start >= 0 && start < len(line) { + names = append(names, line[start:]) + pos = append(pos, start) + } + return +} + +// splitByPositions takes a "row" line and the positions of the header columns +// and extracts the values. +// e.g. if we have the positions [0,5,9] (from header "Foo Bar Baz") and +// line is "1 a b", then we'd extract [1,a,b]. +// The whitespace on the right of the start position (e.g. "1 \t") is trimmed. +// This of course requires that the table is properly formatted in a way that the +// header columns are indented to fit the data exactly. +func splitByPositions(line string, positions []int) []string { + out := make([]string, 0, len(positions)) + start := 0 + for _, pos := range positions[1:] { + if start >= len(line) { + out = append(out, "") + start = len(line) + continue + } + out = append(out, strings.TrimRight(line[start:min(pos, len(line))], " \t")) + start = pos + } + out = append(out, strings.TrimRight(line[min(start, len(line)):], " \t")) + return out +} + +// joinByPositions is the reverse of splitByPositions, it takes the columns of a +// row and the starting positions of each and joins into a single line. +// e.g. [1,a,b] and positions [0,5,9] expands to "1 a b". +// NOTE: This does not deal well with mixing tabs and spaces. The test input +// data should preferably just use spaces. +func joinByPositions(row []string, positions []int) string { + var w strings.Builder + prev := 0 + for i, pos := range positions { + for pad := pos - prev; pad > 0; pad-- { + w.WriteByte(' ') + } + w.WriteString(row[i]) + prev = pos + len(row[i]) + } + return w.String() +} + +func showTable(ts *testscript.TestScript, tableName string, columns ...string) *bytes.Buffer { + db := getDB(ts) + txn := db.ReadTxn() + meta := db.GetTable(txn, tableName) + if meta == nil { + ts.Fatalf("table %q not found", tableName) + } + tbl := statedb.AnyTable{Meta: meta} + + header := tbl.TableHeader() + if header == nil { + ts.Fatalf("objects in table %q not TableWritable", meta.Name()) + } + var idxs []int + var err error + if len(columns) > 0 { + idxs, err = getColumnIndexes(columns, header) + header = columns + } else { + idxs, err = getColumnIndexes(header, header) + } + if err != nil { + ts.Fatalf("%s", err) + } + + var buf bytes.Buffer + w := tabwriter.NewWriter(&buf, 5, 4, 3, ' ', 0) + fmt.Fprintf(w, "%s\n", strings.Join(header, "\t")) + for obj := range tbl.All(db.ReadTxn()) { + row := takeColumns(obj.(statedb.TableWritable).TableRow(), idxs) + fmt.Fprintf(w, "%s\n", strings.Join(row, "\t")) + } + w.Flush() + return &buf +} + +func takeColumns[T any](xs []T, idxs []int) []T { + // Invariant: idxs is sorted so can set in-place. + for i, idx := range idxs { + xs[i] = xs[idx] + } + return xs[:len(idxs)] +} + +func getColumnIndexes(names []string, header []string) ([]int, error) { + columnIndexes := make([]int, 0, len(header)) +loop: + for _, name := range names { + for i, name2 := range header { + if strings.EqualFold(name, name2) { + columnIndexes = append(columnIndexes, i) + continue loop + } + } + return nil, fmt.Errorf("column %q not part of %v", name, header) + } + return columnIndexes, nil +} diff --git a/testutils/script_test.go b/testutils/script_test.go new file mode 100644 index 0000000..a99106f --- /dev/null +++ b/testutils/script_test.go @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package testutils_test + +import ( + "flag" + "strings" + "testing" + + "github.com/cilium/statedb" + "github.com/cilium/statedb/index" + "github.com/cilium/statedb/testutils" + "github.com/rogpeppe/go-internal/testscript" +) + +type object struct { + Name string + Tags []string +} + +func (o object) TableHeader() []string { + return []string{"Name", "Tags"} +} + +func (o object) TableRow() []string { + return []string{ + o.Name, + strings.Join(o.Tags, ", "), + } +} + +var nameIdx = statedb.Index[object, string]{ + Name: "name", + FromObject: func(obj object) index.KeySet { + return index.NewKeySet(index.String(obj.Name)) + }, + FromKey: index.String, + Unique: true, +} + +var update = flag.Bool("update", false, "update the txtar files") + +func TestScriptCommands(t *testing.T) { + testscript.Run(t, testscript.Params{ + Dir: "testdata", + Setup: func(e *testscript.Env) error { + db := statedb.New() + tbl, err := statedb.NewTable("names", nameIdx) + if err != nil { + t.Fatalf("NewTable: %s", err) + } + if err := db.RegisterTable(tbl); err != nil { + t.Fatalf("RegisterTable: %s", err) + } + testutils.Setup(e, db) + return nil + }, + Cmds: testutils.Commands, + UpdateScripts: *update, + }) +} diff --git a/testutils/testdata/test.txtar b/testutils/testdata/test.txtar new file mode 100644 index 0000000..e5e4e9e --- /dev/null +++ b/testutils/testdata/test.txtar @@ -0,0 +1,87 @@ +db tables +db show names + +db insert names data.yaml +db show names + +# Compare the contents of a table +db cmp names names.table + +# Compare against subset of the columns +db cmp names names_name.table + +# Compare the table with retries up to 10s (1s is default) +db cmp -timeout=10s names names.table + +# Compare only rows that match the grep pattern +db cmp -grep=^baz names baz.table + +# Write the table to a file with specific columns +db write names -to=out.table -columns=Name,Tags + +# Use the plain 'cmp'. You'll want to use 'UpdateScript' +# to create and update the expected output. +cmp out.table out_expected.table + +# Write the table to a file as yaml +db write names -to=out.yaml -format=yaml +cmp out.yaml out_expected.yaml + +# Prefix search the table with the primary key. Only useful +# for stringy primary keys. +db prefix names q +db prefix names ba -to=out_prefix_ba.table + +# LowerBound searches +db lowerbound names a -to=out_lb_a.table +cmp out_lb_a.table out_expected.table +db lowerbound names z -to=out_lb_z.table +cmp out_lb_z.table empty.table + +# Delete and check that it's empty. +db delete names quux-name.yaml +db cmp names baz.table +db cmp names out_prefix_ba.table + +db delete names data.yaml +db cmp names empty.table + +-- data.yaml -- +name: quux +tags: +- foo +- bar +--- +name: baz + +-- quux-name.yaml -- +name: quux + +-- names.table -- +Name Tags +baz +quux foo, bar + +-- names_name.table -- +Name +baz +quux + +-- baz.table -- +Name +baz + +-- empty.table -- +Name Tags +-- out_expected.table -- +Name Tags +baz +quux foo, bar +-- out_expected.yaml -- +name: baz +tags: [] +--- +name: quux +tags: + - foo + - bar diff --git a/types.go b/types.go index 5f817f8..ac28faa 100644 --- a/types.go +++ b/types.go @@ -28,22 +28,6 @@ type Table[Obj any] interface { // Useful for generic utilities that need access to the primary key. PrimaryIndexer() Indexer[Obj] - // NumObjects returns the number of objects stored in the table. - NumObjects(ReadTxn) int - - // Initialized returns true if in this ReadTxn (snapshot of the database) - // the registered initializers have all been completed. The returned - // watch channel will be closed when the table becomes initialized. - Initialized(ReadTxn) (bool, <-chan struct{}) - - // PendingInitializers returns the set of pending initializers that - // have not yet completed. - PendingInitializers(ReadTxn) []string - - // Revision of the table. Constant for a read transaction, but - // increments in a write transaction on each Insert and Delete. - Revision(ReadTxn) Revision - // All returns a sequence of all objects in the table. All(ReadTxn) iter.Seq2[Obj, Revision] @@ -215,8 +199,33 @@ type RWTable[Obj any] interface { // TableMeta provides information about the table that is independent of // the object type (the 'Obj' constraint). type TableMeta interface { - Name() TableName // The name of the table + // Name returns the name of the table + Name() TableName + + // Indexes returns the names of the indexes + Indexes() []string + // NumObjects returns the number of objects stored in the table. + NumObjects(ReadTxn) int + + // Initialized returns true if in this ReadTxn (snapshot of the database) + // the registered initializers have all been completed. The returned + // watch channel will be closed when the table becomes initialized. + Initialized(ReadTxn) (bool, <-chan struct{}) + + // PendingInitializers returns the set of pending initializers that + // have not yet completed. + PendingInitializers(ReadTxn) []string + + // Revision of the table. Constant for a read transaction, but + // increments in a write transaction on each Insert and Delete. + Revision(ReadTxn) Revision + + // Internal unexported methods used only internally. + tableInternal +} + +type tableInternal interface { tableEntry() tableEntry tablePos() int setTablePos(int) @@ -226,13 +235,8 @@ type TableMeta interface { secondary() map[string]anyIndexer // Secondary indexers (if any) sortableMutex() internal.SortableMutex // The sortable mutex for locking the table for writing anyChanges(txn WriteTxn) (anyChangeIterator, error) -} - -// Iterator for iterating objects returned from queries. -type Iterator[Obj any] interface { - // Next returns the next object and its revision if ok is true, otherwise - // zero values to mean that the iteration has finished. - Next() (obj Obj, rev Revision, ok bool) + proto() any // Returns the zero value of 'Obj', e.g. the prototype + unmarshalYAML(data []byte) (any, error) // Unmarshal the data into 'Obj' } type ReadTxn interface {