diff --git a/db.go b/db.go index 8c28150..463e260 100644 --- a/db.go +++ b/db.go @@ -95,7 +95,7 @@ type DB struct { defaultHandle Handle } -type dbRoot []tableEntry +type dbRoot = []tableEntry func NewDB(tables []TableMeta, metrics Metrics) (*DB, error) { db := &DB{ diff --git a/http.go b/http.go new file mode 100644 index 0000000..e463686 --- /dev/null +++ b/http.go @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package statedb + +import ( + "encoding/base64" + "encoding/json" + "io" + "net/http" + + "github.com/cilium/statedb/part" +) + +func (db *DB) HTTPHandler() http.Handler { + h := dbHandler{db} + mux := http.NewServeMux() + mux.HandleFunc("GET /dump", h.dumpAll) + mux.HandleFunc("GET /dump/{table}", h.dumpTable) + mux.HandleFunc("/query", h.query) + return mux +} + +type dbHandler struct { + db *DB +} + +func (h dbHandler) dumpAll(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + h.db.ReadTxn().WriteJSON(w) +} + +func (h dbHandler) dumpTable(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + var err error + if table := r.PathValue("table"); table != "" { + err = h.db.ReadTxn().WriteJSON(w, r.PathValue("table")) + } else { + err = h.db.ReadTxn().WriteJSON(w) + } + if err != nil { + panic(err) + } +} + +func (h dbHandler) query(w http.ResponseWriter, r *http.Request) { + enc := json.NewEncoder(w) + + var req QueryRequest + body, err := io.ReadAll(r.Body) + r.Body.Close() + if err != nil { + w.WriteHeader(http.StatusBadRequest) + enc.Encode(QueryResponse{Err: err.Error()}) + return + } + + if err := json.Unmarshal(body, &req); err != nil { + w.WriteHeader(http.StatusBadRequest) + enc.Encode(QueryResponse{Err: err.Error()}) + return + } + + queryKey, err := base64.StdEncoding.DecodeString(req.Key) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + enc.Encode(QueryResponse{Err: err.Error()}) + return + } + + txn := h.db.ReadTxn().getTxn() + + // Look up the table + var table TableMeta + for _, e := range txn.root { + if e.meta.Name() == req.Table { + table = e.meta + } + } + if table == nil { + w.WriteHeader(http.StatusNotFound) + enc.Encode(QueryResponse{Err: err.Error()}) + return + } + + indexPos := table.indexPos(req.Index) + + indexTxn, err := txn.indexReadTxn(table, indexPos) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + enc.Encode(QueryResponse{Err: err.Error()}) + return + } + + w.WriteHeader(http.StatusOK) + onObject := func(obj object) error { + return enc.Encode(QueryResponse{ + Rev: obj.revision, + Obj: obj.data, + }) + } + runQuery(indexTxn, req.LowerBound, queryKey, onObject) +} + +type QueryRequest struct { + Key string `json:"key"` // Base64 encoded query key + Table string `json:"table"` + Index string `json:"index"` + LowerBound bool `json:"lowerbound"` +} + +type QueryResponse struct { + Rev uint64 `json:"rev"` + Obj any `json:"obj"` + Err string `json:"err,omitempty"` +} + +func runQuery(indexTxn indexReadTxn, lowerbound bool, queryKey []byte, onObject func(object) error) { + var iter *part.Iterator[object] + if lowerbound { + iter = indexTxn.LowerBound(queryKey) + } else { + iter, _ = indexTxn.Prefix(queryKey) + } + var match func([]byte) bool + switch { + case lowerbound: + match = func([]byte) bool { return true } + case indexTxn.unique: + match = func(k []byte) bool { return len(k) == len(queryKey) } + default: + match = func(k []byte) bool { + _, secondary := decodeNonUniqueKey(k) + return len(secondary) == len(queryKey) + } + } + for key, obj, ok := iter.Next(); ok; key, obj, ok = iter.Next() { + if !match(key) { + continue + } + if err := onObject(obj); err != nil { + panic(err) + } + } +} diff --git a/http_client.go b/http_client.go new file mode 100644 index 0000000..ee57522 --- /dev/null +++ b/http_client.go @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package statedb + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "io" + "net/http" + "net/url" +) + +// NewRemoteTable creates a new handle for querying a remote StateDB table over the HTTP. +// Example usage: +// +// devices := statedb.NewRemoteTable[*tables.Device](url.Parse("http://localhost:8080/db"), "devices") +// +// // Get all devices ordered by name. +// iter, errs := devices.LowerBound(ctx, tables.DeviceByName("")) +// for device, revision, ok := iter.Next(); ok; device, revision, ok = iter.Next() { ... } +// +// // Get device by name. +// iter, errs := devices.Get(ctx, tables.DeviceByName("eth0")) +// if dev, revision, ok := iter.Next(); ok { ... } +// +// // Get devices in revision order, e.g. oldest changed devices first. +// iter, errs = devices.LowerBound(ctx, statedb.ByRevision(0)) +func NewRemoteTable[Obj any](base *url.URL, table TableName) *RemoteTable[Obj] { + return &RemoteTable[Obj]{base: base, tableName: table} +} + +type RemoteTable[Obj any] struct { + base *url.URL + tableName TableName +} + +func (t *RemoteTable[Obj]) query(ctx context.Context, lowerBound bool, q Query[Obj]) (iter Iterator[Obj], errChan <-chan error) { + // Use a channel to return errors so we can use the same Iterator[Obj] interface as StateDB does. + errChanSend := make(chan error, 1) + errChan = errChanSend + + key := base64.StdEncoding.EncodeToString(q.key) + queryReq := QueryRequest{ + Key: key, + Table: t.tableName, + Index: q.index, + LowerBound: lowerBound, + } + bs, err := json.Marshal(&queryReq) + if err != nil { + errChanSend <- err + return + } + + url := t.base.JoinPath("/query") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), bytes.NewBuffer(bs)) + if err != nil { + errChanSend <- err + return + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + errChanSend <- err + return + } + return &remoteGetIterator[Obj]{json.NewDecoder(resp.Body), errChanSend}, errChan +} +func (t *RemoteTable[Obj]) Get(ctx context.Context, q Query[Obj]) (Iterator[Obj], <-chan error) { + return t.query(ctx, false, q) +} + +func (t *RemoteTable[Obj]) LowerBound(ctx context.Context, q Query[Obj]) (Iterator[Obj], <-chan error) { + return t.query(ctx, true, q) +} + +type remoteGetIterator[Obj any] struct { + decoder *json.Decoder + errChan chan error +} + +// responseObject is a typed counterpart of [queryResponseObject] +type responseObject[Obj any] struct { + Rev uint64 `json:"rev"` + Obj Obj `json:"obj"` + Err string `json:"err,omitempty"` +} + +func (it *remoteGetIterator[Obj]) Next() (obj Obj, revision Revision, ok bool) { + if it.decoder == nil { + return + } + + var resp responseObject[Obj] + err := it.decoder.Decode(&resp) + errString := "" + if err != nil { + if errors.Is(err, io.EOF) { + close(it.errChan) + return + } + errString = err.Error() + } else { + errString = resp.Err + } + if errString != "" { + it.decoder = nil + it.errChan <- errors.New(errString) + return + } + + obj = resp.Obj + revision = resp.Rev + ok = true + return +} diff --git a/http_test.go b/http_test.go new file mode 100644 index 0000000..dc343dd --- /dev/null +++ b/http_test.go @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package statedb + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cilium/statedb/index" +) + +func httpFixture(t *testing.T) (*DB, Table[testObject], *httptest.Server) { + db, table, _ := newTestDB(t, tagsIndex) + + ts := httptest.NewServer(db.HTTPHandler()) + t.Cleanup(ts.Close) + + wtxn := db.WriteTxn(table) + table.Insert(wtxn, testObject{1, []string{"foo"}}) + table.Insert(wtxn, testObject{2, []string{"foo"}}) + table.Insert(wtxn, testObject{3, []string{"foobar"}}) + table.Insert(wtxn, testObject{4, []string{"baz"}}) + wtxn.Commit() + + return db, table, ts +} + +func Test_http_dump(t *testing.T) { + _, _, ts := httpFixture(t) + + resp, err := http.Get(ts.URL + "/dump") + require.NoError(t, err, "Get(/dump)") + require.Equal(t, http.StatusOK, resp.StatusCode) + + dump, err := io.ReadAll(resp.Body) + resp.Body.Close() + require.NoError(t, err, "ReadAll") + fmt.Printf("%s", dump) + + resp, err = http.Get(ts.URL + "/dump/test") + require.NoError(t, err, "Get(/dump/test)") + require.Equal(t, http.StatusOK, resp.StatusCode) + + dump, err = io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + t.Fatal(err) + } + fmt.Printf("%s", dump) + +} + +func Test_runQuery(t *testing.T) { + db, table, _ := httpFixture(t) + txn := db.ReadTxn() + + // idIndex, unique + indexTxn, err := txn.getTxn().indexReadTxn(table, table.indexPos(idIndex.Name)) + require.NoError(t, err) + items := []object{} + onObject := func(obj object) error { + items = append(items, obj) + return nil + } + runQuery(indexTxn, false, index.Uint64(1), onObject) + if assert.Len(t, items, 1) { + assert.EqualValues(t, items[0].data.(testObject).ID, 1) + } + + // tagsIndex, non-unique + indexTxn, err = txn.getTxn().indexReadTxn(table, table.indexPos(tagsIndex.Name)) + require.NoError(t, err) + items = nil + runQuery(indexTxn, false, index.String("foo"), onObject) + + if assert.Len(t, items, 2) { + assert.EqualValues(t, items[0].data.(testObject).ID, 1) + assert.EqualValues(t, items[1].data.(testObject).ID, 2) + } + + // lower-bound on revision index + indexTxn, err = txn.getTxn().indexReadTxn(table, RevisionIndexPos) + require.NoError(t, err) + items = nil + runQuery(indexTxn, true, index.Uint64(0), onObject) + if assert.Len(t, items, 4) { + // Items are in revision (creation) order + assert.EqualValues(t, items[0].data.(testObject).ID, 1) + assert.EqualValues(t, items[1].data.(testObject).ID, 2) + assert.EqualValues(t, items[2].data.(testObject).ID, 3) + assert.EqualValues(t, items[3].data.(testObject).ID, 4) + } +} + +func Test_RemoteTable(t *testing.T) { + ctx := context.TODO() + _, table, ts := httpFixture(t) + + base, err := url.Parse(ts.URL) + require.NoError(t, err, "ParseURL") + + remoteTable := NewRemoteTable[testObject](base, table.Name()) + + iter, errs := remoteTable.Get(ctx, idIndex.Query(1)) + items := Collect(iter) + assert.NoError(t, <-errs, "Get(1)") + if assert.Len(t, items, 1) { + assert.EqualValues(t, 1, items[0].ID) + } + + iter, errs = remoteTable.LowerBound(ctx, idIndex.Query(0)) + items = Collect(iter) + assert.NoError(t, <-errs, "LowerBound(0)") + if assert.Len(t, items, 4) { + assert.EqualValues(t, 1, items[0].ID) + assert.EqualValues(t, 2, items[1].ID) + assert.EqualValues(t, 3, items[2].ID) + assert.EqualValues(t, 4, items[3].ID) + } +} diff --git a/txn.go b/txn.go index d325cb9..e0475c4 100644 --- a/txn.go +++ b/txn.go @@ -465,39 +465,53 @@ func (txn *txn) Commit() { *txn = zeroTxn } -// WriteJSON marshals out the whole database as JSON into the given writer. -func (txn *txn) WriteJSON(w io.Writer) error { +func writeTableAsJSON(buf *bufio.Writer, txn *txn, table *tableEntry) error { + indexTxn := txn.mustIndexReadTxn(table.meta, PrimaryIndexPos) + iter := indexTxn.Iterator() + + buf.WriteString(" \"" + table.meta.Name() + "\": [\n") + + _, obj, ok := iter.Next() + for ok { + buf.WriteString(" ") + bs, err := json.Marshal(obj.data) + if err != nil { + return err + } + buf.Write(bs) + _, obj, ok = iter.Next() + if ok { + buf.WriteString(",\n") + } else { + buf.WriteByte('\n') + } + } + buf.WriteString(" ]") + return nil +} + +// WriteJSON marshals out the database as JSON into the given writer. +// If tables are given then only these tables are written. +func (txn *txn) WriteJSON(w io.Writer, tables ...string) error { buf := bufio.NewWriter(w) buf.WriteString("{\n") first := true + for _, table := range txn.root { + if len(tables) > 0 && !slices.Contains(tables, table.meta.Name()) { + continue + } + if !first { buf.WriteString(",\n") } else { first = false } - indexTxn := txn.getTxn().mustIndexReadTxn(table.meta, PrimaryIndexPos) - iter := indexTxn.Iterator() - - buf.WriteString(" \"" + table.meta.Name() + "\": [\n") - - _, obj, ok := iter.Next() - for ok { - buf.WriteString(" ") - bs, err := json.Marshal(obj.data) - if err != nil { - return err - } - buf.Write(bs) - _, obj, ok = iter.Next() - if ok { - buf.WriteString(",\n") - } else { - buf.WriteByte('\n') - } + err := writeTableAsJSON(buf, txn, &table) + if err != nil { + return err } - buf.WriteString(" ]") } buf.WriteString("\n}\n") return buf.Flush() diff --git a/types.go b/types.go index d55ec04..ae2247d 100644 --- a/types.go +++ b/types.go @@ -176,7 +176,7 @@ type ReadTxn interface { getTxn() *txn // WriteJSON writes the contents of the database as JSON. - WriteJSON(io.Writer) error + WriteJSON(w io.Writer, tables ...string) error } type WriteTxn interface {