From f453151cf73b663b740b8601ff7d32f22a15922a Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Sun, 14 Jul 2024 19:29:10 -0700 Subject: [PATCH] server: add support for fully-programmatic data sources (#24) Allow a caller to hook up databases that do not require use of the database/sql package directly. The substance of the change is to add Queryable and RowSet interfaces, and to rework the query plumbing to use them. Existing use of *sql.DB is shimmed with a wrapper implementing Queryable. I wanted to use a generic interface so that a shim would not be needed, but this turned out to be more trouble than it was worth, since it infects the rest of the package with type parameters. It was (much) simpler to use the wrapper. --- README.md | 2 +- server/tailsql/internal_test.go | 7 ++ server/tailsql/local.go | 16 ++++ server/tailsql/options.go | 160 ++++++++++++++++++++++---------- server/tailsql/tailsql.go | 30 +++--- server/tailsql/tailsql_test.go | 41 ++++++++ server/tailsql/utils.go | 9 +- 7 files changed, 193 insertions(+), 72 deletions(-) diff --git a/README.md b/README.md index 73d703f..a56feaa 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,7 @@ opts := tailsql.Options{ } ``` -Any number of sources can be configured this way. It is also possible to add new data sources dynamically at runtime using the `SetDB` method of the server. It is _not_ currently possible to remove data sources once added, however. +Any number of sources can be configured this way. It is also possible to add new data sources dynamically at runtime using the `SetDB` and `SetSource` methods of the server. It is _not_ currently possible to remove data sources once added, however. ### Tailscale Integration diff --git a/server/tailsql/internal_test.go b/server/tailsql/internal_test.go index a64c5c1..4d9f266 100644 --- a/server/tailsql/internal_test.go +++ b/server/tailsql/internal_test.go @@ -4,6 +4,7 @@ package tailsql import ( + "database/sql" "os" "testing" @@ -13,6 +14,12 @@ import ( "tailscale.com/tailcfg" ) +// Interface satisfaction checks. +var ( + _ Queryable = sqlDB{} + _ RowSet = (*sql.Rows)(nil) +) + func TestCheckQuerySyntax(t *testing.T) { tests := []struct { query string diff --git a/server/tailsql/local.go b/server/tailsql/local.go index fc09577..c594a4e 100644 --- a/server/tailsql/local.go +++ b/server/tailsql/local.go @@ -93,3 +93,19 @@ func (s *localState) LogQuery(ctx context.Context, user string, q Query, elapsed return tx.Commit() } + +// Query satisfies part of the Queryable interface. It supports only read queries. +func (s *localState) Query(ctx context.Context, query string, params ...any) (RowSet, error) { + s.txmu.RLock() + defer s.txmu.RUnlock() + tx, err := s.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + return nil, err + } + defer tx.Rollback() + return tx.QueryContext(ctx, query, params...) +} + +// Close satisfies part of the Queryable interface. For this database the +// implementation is a no-op without error. +func (*localState) Close() error { return nil } diff --git a/server/tailsql/options.go b/server/tailsql/options.go index 524e7b4..8353c8c 100644 --- a/server/tailsql/options.go +++ b/server/tailsql/options.go @@ -25,15 +25,18 @@ import ( ) // Options describes settings for a Server. +// +// The fields marked as "tsnet" are not used directly by tailsql, but are +// provided for the convenience of a main program that wants to run the server +// under tsnet. type Options struct { - // The tailnet hostname the server should run on (required). + // The tailnet hostname the server should run on (tsnet). Hostname string `json:"hostname,omitempty"` - // The directory for tailscale state and configurations (optional). - // If omitted or empty, the default location is used. + // The directory for tailscale state and configurations (tsnet). StateDir string `json:"stateDir,omitempty"` - // If true, serve HTTPS instead of HTTP. + // If true, serve HTTPS instead of HTTP (tsnet). ServeHTTPS bool `json:"serveHTTPS,omitempty"` // If non-empty, a SQLite database URL to use for local state. @@ -116,6 +119,19 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) { spec.Label = "(unidentified database)" } + // Case 1: A programmatic source. + if spec.DB != nil { + srcs[i] = &dbHandle{ + src: spec.Source, + label: spec.Label, + named: spec.Named, + db: spec.DB, + } + continue + } + + // Case 2: A database managed by database/sql. + // // Resolve the connection string. var connString string var w setec.Watcher @@ -135,6 +151,7 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) { panic("unexpected: no connection source is defined after validation") } + // Open and ping the database to ensure it is approximately usable. db, err := openAndPing(spec.Driver, connString) if err != nil { return nil, err @@ -144,7 +161,7 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) { driver: spec.Driver, label: spec.Label, named: spec.Named, - db: db, + db: sqlDB{DB: db}, } if spec.Secret != "" { go srcs[i].handleUpdates(spec.Secret, w, o.logf()) @@ -194,14 +211,6 @@ func (o Options) localState() (*localState, error) { return newLocalState(db) } -func (o Options) readOnlyLocalState() (*sql.DB, error) { - if o.LocalState == "" { - return nil, errors.New("no local state") - } - url := "file:" + os.ExpandEnv(o.LocalState) + "?mode=ro" - return sql.Open("sqlite", url) -} - func (o Options) logf() logger.Logf { if o.Logf == nil { return log.Printf @@ -306,7 +315,7 @@ type dbHandle struct { // Hold exclusive to replace or close db or to update label. mu sync.RWMutex label string - db *sql.DB + db Queryable named map[string]string } @@ -332,7 +341,7 @@ func (h *dbHandle) handleUpdates(name string, w setec.Watcher, logf logger.Logf) if up := h.checkUpdate(); up != nil { up.newDB.Close() } - h.db = db + h.db = sqlDB{DB: db} h.mu.Unlock() } } @@ -382,13 +391,12 @@ func (h *dbHandle) Named() map[string]string { return h.named } -// Tx calls f with a connection to the wrapped database while holding the lock. -// Any error reported by f is returned to the caller of Tx. -// Multiple callers can safely invoke Tx concurrently. -// Tx reports an error without calling f if h is closed. +// WithLock calls f with the wrapped database while holding the lock. +// If f reports an error is returned to the caller of WithLock. +// WithLock reports an error without calling f if h is closed. // The context passed to f can be used to look up named queries on h using // lookupNamedQuery. -func (h *dbHandle) Tx(ctx context.Context, f func(context.Context, *sql.Tx) error) error { +func (h *dbHandle) WithLock(ctx context.Context, f func(context.Context, Queryable) error) error { h.mu.RLock() defer h.mu.RUnlock() if h.db == nil { @@ -399,20 +407,11 @@ func (h *dbHandle) Tx(ctx context.Context, f func(context.Context, *sql.Tx) erro // safe, but to prevent the handle from being swapped (and the database // closed) while connections are in-flight. // - // For our uses we could mark transactions ReadOnly, but not all database - // drivers support that option (notably Snowflake does not). - - tx, err := h.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - // Attach the handle to the context during the lifetime of f. This ensures // that f has access to named queries and other options from h while holding // the lock on h. fctx := context.WithValue(ctx, dbHandleKey{}, h) - return f(fctx, tx) // we only read, no commit is needed + return f(fctx, h.db) } type dbHandleKey struct{} @@ -434,7 +433,7 @@ func lookupNamedQuery(ctx context.Context, name string) (string, bool) { // and newLabel, and closes the original value. The caller is responsible for // closing a database handle when it is no longer in use. It will panic if // newDB == nil, or if h is closed. -func (h *dbHandle) swap(newDB *sql.DB, newOpts *DBOptions) { +func (h *dbHandle) swap(newDB Queryable, newOpts *DBOptions) { if newDB == nil { panic("new database is nil") } @@ -468,7 +467,7 @@ func (h *dbHandle) swap(newDB *sql.DB, newOpts *DBOptions) { // A dbUpdate is an open database handle, label, and set of named queries that // are ready to be installed in a database handle. type dbUpdate struct { - newDB *sql.DB + newDB Queryable label string named map[string]string } @@ -517,28 +516,38 @@ func (d *Duration) UnmarshalText(data []byte) error { } // A DBSpec describes a database that the server should use. +// +// The Source must be non-empty, and exactly one of URL, KeyFile, Secret, or DB +// must be set. +// +// - If DB is set, it is used directly as the database to query, and no +// connection is established. +// +// Otherwise, the Driver must be non-empty and the [database/sql] library is +// used to open a connection to the specified database: +// +// - If URL is set, it is used directly as the connection string. +// +// - If KeyFile is set, it names the location of a file containing the +// connection string. If set, KeyFile is expanded by os.ExpandEnv. +// +// - Otherwise, Secret is the name of a secret to fetch from the secrets +// service, whose value is the connection string. This requires that a +// secrets server be configured in the options. type DBSpec struct { - Source string `json:"source"` // UI slug + Source string `json:"source"` // UI slug (required) Label string `json:"label,omitempty"` // descriptive label Driver string `json:"driver,omitempty"` // e.g., "sqlite", "snowflake" // Named is an optional map of named SQL queries the database should expose. Named map[string]string `json:"named,omitempty"` - // Exactly one of the following fields must be set. - // - // If URL is set, it is used directly as the connection string. - // - // If KeyFile is set, it names the location of a file containing the - // connection string. If set, KeyFile is expanded by os.ExpandEnv. - // - // Otherwise, Secret is the name of a secret to fetch from the secrets - // service, whose value is the connection string. This requires that a - // secrets server be configured in the options. + // Exactly one of the fields below must be set. - URL string `json:"url,omitempty"` // path or connection URL - KeyFile string `json:"keyFile,omitempty"` // path to key file - Secret string `json:"secret,omitempty"` // name of secret + URL string `json:"url,omitempty"` // path or connection URL + KeyFile string `json:"keyFile,omitempty"` // path to key file + Secret string `json:"secret,omitempty"` // name of secret + DB Queryable `json:"-"` // programmatic data source } func (d *DBSpec) countFields() (n int) { @@ -551,12 +560,22 @@ func (d *DBSpec) countFields() (n int) { } func (d *DBSpec) checkValid() error { - switch { - case d.Source == "": - return errors.New("missing source") - case d.Driver == "": + if d.Source == "" { + return errors.New("missing source name") + } + + // Case 1: A programmatic data source. + if d.DB != nil { + if d.countFields() != 0 { + return errors.New("no connection string is allowed when DB is set") + } + return nil + } + + // Case 2: A database/sql database. + if d.Driver == "" { return errors.New("missing driver name") - case d.countFields() != 1: + } else if d.countFields() != 1 { return errors.New("exactly one connection source must be set") } return nil @@ -610,3 +629,42 @@ func DefaultCheckQuery(q Query) (Query, error) { } return q, nil } + +// Queryable is the interface used to issue SQL queries to a database. +type Queryable interface { + // Query issues the specified SQL query in a transaction and returns the + // matching result set, if any. + Query(ctx context.Context, sql string, params ...any) (RowSet, error) + + // Close closes the database. + Close() error +} + +// A RowSet is a sequence of rows reported by a query. It is a subset of the +// interface exposed by [database/sql.Rows], and the implementation must +// provide the same semantics for each of these methods. +type RowSet interface { + // Columns reports the names of the columns requested by the query. + Columns() ([]string, error) + + // Close closes the row set, preventing further enumeration. + Close() error + + // Err returns the error, if any, that was encountered during iteration. + Err() error + + // Next prepares the next result row for reading by the Scan method, and + // reports true if this was successful or false if there was an error or no + // more rows are available. + Next() bool + + // Scan copies the columns of the currently-selected row into the values + // pointed to by its arguments. + Scan(...any) error +} + +type sqlDB struct{ *sql.DB } + +func (s sqlDB) Query(ctx context.Context, query string, params ...any) (RowSet, error) { + return s.DB.QueryContext(ctx, query, params...) +} diff --git a/server/tailsql/tailsql.go b/server/tailsql/tailsql.go index 6a8cf92..0266ee8 100644 --- a/server/tailsql/tailsql.go +++ b/server/tailsql/tailsql.go @@ -142,14 +142,10 @@ func NewServer(opts Options) (*Server, error) { return nil, fmt.Errorf("local state: %w", err) } if state != nil && opts.LocalSource != "" { - db, err := opts.readOnlyLocalState() - if err != nil { - return nil, fmt.Errorf("read-only local state: %w", err) - } dbs = append(dbs, &dbHandle{ src: opts.LocalSource, label: "tailsql local state", - db: db, + db: state, named: map[string]string{ "schema": `select * from sqlite_schema`, }, @@ -174,14 +170,20 @@ func NewServer(opts Options) (*Server, error) { } // SetDB adds or replaces the database associated with the specified source in -// s with the given open db and options. +// s with the given open db and options. See [SetSource]. +func (s *Server) SetDB(source string, db *sql.DB, opts *DBOptions) bool { + return s.SetSource(source, sqlDB{DB: db}, opts) +} + +// SetSource adds or replaces the database associated with the specified source +// in s with the given open db and options. // // If a database was already open for the given source, its value is replaced, // the old database handle is closed, and SetDB reports true. // // If no database was already open for the given source, a new source is added // and SetDB reports false. -func (s *Server) SetDB(source string, db *sql.DB, opts *DBOptions) bool { +func (s *Server) SetSource(source string, db Queryable, opts *DBOptions) bool { if db == nil { panic("new database is nil") } @@ -432,8 +434,8 @@ func (s *Server) queryContext(ctx context.Context, caller string, q Query) (*dbR defer cancel() } - return runQueryInTx(ctx, h, - func(fctx context.Context, tx *sql.Tx) (_ *dbResult, err error) { + return runQuery(ctx, h, + func(fctx context.Context, db Queryable) (_ *dbResult, err error) { start := time.Now() var out dbResult defer func() { @@ -461,19 +463,17 @@ func (s *Server) queryContext(ctx context.Context, caller string, q Query) (*dbR q.Query = real } - rows, err := tx.QueryContext(fctx, q.Query) + rows, err := db.Query(fctx, q.Query) if err != nil { return nil, err } defer rows.Close() - cols, err := rows.ColumnTypes() + cols, err := rows.Columns() if err != nil { - return nil, fmt.Errorf("listing column types: %w", err) - } - for _, col := range cols { - out.Columns = append(out.Columns, col.Name()) + return nil, fmt.Errorf("listing column names: %w", err) } + out.Columns = cols var tooMany bool for rows.Next() && !tooMany { diff --git a/server/tailsql/tailsql_test.go b/server/tailsql/tailsql_test.go index 867d805..40b98bc 100644 --- a/server/tailsql/tailsql_test.go +++ b/server/tailsql/tailsql_test.go @@ -483,3 +483,44 @@ func TestQueryTimeout(t *testing.T) { t.Fatal("Timeout waiting for query to end") } } + +func TestQueryable(t *testing.T) { + _, db := mustInitSQLite(t) + + s, err := tailsql.NewServer(tailsql.Options{ + Sources: []tailsql.DBSpec{{ + Source: "quux", + Label: "Happy fun database", + DB: sqlDB{DB: db}, + + // Note: Omit Driver to exercise that this is allowed when a + // programmatic data source is given. + }}, + }) + if err != nil { + t.Fatalf("NewServer: %v", err) + } + defer s.Close() + + htest := httptest.NewServer(s.NewMux()) + defer htest.Close() + cli := htest.Client() + + t.Run("Smoke", func(t *testing.T) { + const testProbe = "mindlesprocket" + q := url.Values{ + "src": {"quux"}, + "q": {fmt.Sprintf(`select '%s'`, testProbe)}, + } + rsp := string(mustGet(t, cli, htest.URL+"/csv?"+q.Encode())) + if !strings.Contains(rsp, testProbe) { + t.Errorf("Query failed: got %q, want %q", rsp, testProbe) + } + }) +} + +type sqlDB struct{ *sql.DB } + +func (s sqlDB) Query(ctx context.Context, query string, params ...any) (tailsql.RowSet, error) { + return s.DB.QueryContext(ctx, query, params...) +} diff --git a/server/tailsql/utils.go b/server/tailsql/utils.go index b8e5033..3996de3 100644 --- a/server/tailsql/utils.go +++ b/server/tailsql/utils.go @@ -5,7 +5,6 @@ package tailsql import ( "context" - "database/sql" "encoding/base64" "errors" "fmt" @@ -148,12 +147,12 @@ func isBinaryData(data []byte) bool { return false } -// runQueryInTx executes query using h.Tx, and returns its results. -func runQueryInTx[T any](ctx context.Context, h *dbHandle, query func(context.Context, *sql.Tx) (T, error)) (T, error) { +// runQuery executes query using h.WithLock, and returns its results. +func runQuery[T any](ctx context.Context, h *dbHandle, run func(context.Context, Queryable) (T, error)) (T, error) { var out T - err := h.Tx(ctx, func(fctx context.Context, tx *sql.Tx) error { + err := h.WithLock(ctx, func(fctx context.Context, q Queryable) error { var err error - out, err = query(fctx, tx) + out, err = run(fctx, q) return err })