Skip to content

Commit

Permalink
server: add support for fully-programmatic data sources (#24)
Browse files Browse the repository at this point in the history
Allow a caller to hook up databases that do not require use of the database/sql
package directly.

The substance of the change is to add Queryable and RowSet interfaces, and to
rework the query plumbing to use them. Existing use of *sql.DB is shimmed with
a wrapper implementing Queryable.

I wanted to use a generic interface so that a shim would not be needed, but
this turned out to be more trouble than it was worth, since it infects the rest of
the package with type parameters. It was (much) simpler to use the wrapper.
  • Loading branch information
creachadair authored Jul 15, 2024
1 parent 6fb9904 commit f453151
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 72 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ opts := tailsql.Options{
}
```

Any number of sources can be configured this way. It is also possible to add new data sources dynamically at runtime using the `SetDB` method of the server. It is _not_ currently possible to remove data sources once added, however.
Any number of sources can be configured this way. It is also possible to add new data sources dynamically at runtime using the `SetDB` and `SetSource` methods of the server. It is _not_ currently possible to remove data sources once added, however.

### Tailscale Integration

Expand Down
7 changes: 7 additions & 0 deletions server/tailsql/internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package tailsql

import (
"database/sql"
"os"
"testing"

Expand All @@ -13,6 +14,12 @@ import (
"tailscale.com/tailcfg"
)

// Interface satisfaction checks.
var (
_ Queryable = sqlDB{}
_ RowSet = (*sql.Rows)(nil)
)

func TestCheckQuerySyntax(t *testing.T) {
tests := []struct {
query string
Expand Down
16 changes: 16 additions & 0 deletions server/tailsql/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,19 @@ func (s *localState) LogQuery(ctx context.Context, user string, q Query, elapsed

return tx.Commit()
}

// Query satisfies part of the Queryable interface. It supports only read queries.
func (s *localState) Query(ctx context.Context, query string, params ...any) (RowSet, error) {
s.txmu.RLock()
defer s.txmu.RUnlock()
tx, err := s.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
return nil, err
}
defer tx.Rollback()
return tx.QueryContext(ctx, query, params...)
}

// Close satisfies part of the Queryable interface. For this database the
// implementation is a no-op without error.
func (*localState) Close() error { return nil }
160 changes: 109 additions & 51 deletions server/tailsql/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@ import (
)

// Options describes settings for a Server.
//
// The fields marked as "tsnet" are not used directly by tailsql, but are
// provided for the convenience of a main program that wants to run the server
// under tsnet.
type Options struct {
// The tailnet hostname the server should run on (required).
// The tailnet hostname the server should run on (tsnet).
Hostname string `json:"hostname,omitempty"`

// The directory for tailscale state and configurations (optional).
// If omitted or empty, the default location is used.
// The directory for tailscale state and configurations (tsnet).
StateDir string `json:"stateDir,omitempty"`

// If true, serve HTTPS instead of HTTP.
// If true, serve HTTPS instead of HTTP (tsnet).
ServeHTTPS bool `json:"serveHTTPS,omitempty"`

// If non-empty, a SQLite database URL to use for local state.
Expand Down Expand Up @@ -116,6 +119,19 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) {
spec.Label = "(unidentified database)"
}

// Case 1: A programmatic source.
if spec.DB != nil {
srcs[i] = &dbHandle{
src: spec.Source,
label: spec.Label,
named: spec.Named,
db: spec.DB,
}
continue
}

// Case 2: A database managed by database/sql.
//
// Resolve the connection string.
var connString string
var w setec.Watcher
Expand All @@ -135,6 +151,7 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) {
panic("unexpected: no connection source is defined after validation")
}

// Open and ping the database to ensure it is approximately usable.
db, err := openAndPing(spec.Driver, connString)
if err != nil {
return nil, err
Expand All @@ -144,7 +161,7 @@ func (o Options) openSources(store *setec.Store) ([]*dbHandle, error) {
driver: spec.Driver,
label: spec.Label,
named: spec.Named,
db: db,
db: sqlDB{DB: db},
}
if spec.Secret != "" {
go srcs[i].handleUpdates(spec.Secret, w, o.logf())
Expand Down Expand Up @@ -194,14 +211,6 @@ func (o Options) localState() (*localState, error) {
return newLocalState(db)
}

func (o Options) readOnlyLocalState() (*sql.DB, error) {
if o.LocalState == "" {
return nil, errors.New("no local state")
}
url := "file:" + os.ExpandEnv(o.LocalState) + "?mode=ro"
return sql.Open("sqlite", url)
}

func (o Options) logf() logger.Logf {
if o.Logf == nil {
return log.Printf
Expand Down Expand Up @@ -306,7 +315,7 @@ type dbHandle struct {
// Hold exclusive to replace or close db or to update label.
mu sync.RWMutex
label string
db *sql.DB
db Queryable
named map[string]string
}

Expand All @@ -332,7 +341,7 @@ func (h *dbHandle) handleUpdates(name string, w setec.Watcher, logf logger.Logf)
if up := h.checkUpdate(); up != nil {
up.newDB.Close()
}
h.db = db
h.db = sqlDB{DB: db}
h.mu.Unlock()
}
}
Expand Down Expand Up @@ -382,13 +391,12 @@ func (h *dbHandle) Named() map[string]string {
return h.named
}

// Tx calls f with a connection to the wrapped database while holding the lock.
// Any error reported by f is returned to the caller of Tx.
// Multiple callers can safely invoke Tx concurrently.
// Tx reports an error without calling f if h is closed.
// WithLock calls f with the wrapped database while holding the lock.
// If f reports an error is returned to the caller of WithLock.
// WithLock reports an error without calling f if h is closed.
// The context passed to f can be used to look up named queries on h using
// lookupNamedQuery.
func (h *dbHandle) Tx(ctx context.Context, f func(context.Context, *sql.Tx) error) error {
func (h *dbHandle) WithLock(ctx context.Context, f func(context.Context, Queryable) error) error {
h.mu.RLock()
defer h.mu.RUnlock()
if h.db == nil {
Expand All @@ -399,20 +407,11 @@ func (h *dbHandle) Tx(ctx context.Context, f func(context.Context, *sql.Tx) erro
// safe, but to prevent the handle from being swapped (and the database
// closed) while connections are in-flight.
//
// For our uses we could mark transactions ReadOnly, but not all database
// drivers support that option (notably Snowflake does not).

tx, err := h.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()

// Attach the handle to the context during the lifetime of f. This ensures
// that f has access to named queries and other options from h while holding
// the lock on h.
fctx := context.WithValue(ctx, dbHandleKey{}, h)
return f(fctx, tx) // we only read, no commit is needed
return f(fctx, h.db)
}

type dbHandleKey struct{}
Expand All @@ -434,7 +433,7 @@ func lookupNamedQuery(ctx context.Context, name string) (string, bool) {
// and newLabel, and closes the original value. The caller is responsible for
// closing a database handle when it is no longer in use. It will panic if
// newDB == nil, or if h is closed.
func (h *dbHandle) swap(newDB *sql.DB, newOpts *DBOptions) {
func (h *dbHandle) swap(newDB Queryable, newOpts *DBOptions) {
if newDB == nil {
panic("new database is nil")
}
Expand Down Expand Up @@ -468,7 +467,7 @@ func (h *dbHandle) swap(newDB *sql.DB, newOpts *DBOptions) {
// A dbUpdate is an open database handle, label, and set of named queries that
// are ready to be installed in a database handle.
type dbUpdate struct {
newDB *sql.DB
newDB Queryable
label string
named map[string]string
}
Expand Down Expand Up @@ -517,28 +516,38 @@ func (d *Duration) UnmarshalText(data []byte) error {
}

// A DBSpec describes a database that the server should use.
//
// The Source must be non-empty, and exactly one of URL, KeyFile, Secret, or DB
// must be set.
//
// - If DB is set, it is used directly as the database to query, and no
// connection is established.
//
// Otherwise, the Driver must be non-empty and the [database/sql] library is
// used to open a connection to the specified database:
//
// - If URL is set, it is used directly as the connection string.
//
// - If KeyFile is set, it names the location of a file containing the
// connection string. If set, KeyFile is expanded by os.ExpandEnv.
//
// - Otherwise, Secret is the name of a secret to fetch from the secrets
// service, whose value is the connection string. This requires that a
// secrets server be configured in the options.
type DBSpec struct {
Source string `json:"source"` // UI slug
Source string `json:"source"` // UI slug (required)
Label string `json:"label,omitempty"` // descriptive label
Driver string `json:"driver,omitempty"` // e.g., "sqlite", "snowflake"

// Named is an optional map of named SQL queries the database should expose.
Named map[string]string `json:"named,omitempty"`

// Exactly one of the following fields must be set.
//
// If URL is set, it is used directly as the connection string.
//
// If KeyFile is set, it names the location of a file containing the
// connection string. If set, KeyFile is expanded by os.ExpandEnv.
//
// Otherwise, Secret is the name of a secret to fetch from the secrets
// service, whose value is the connection string. This requires that a
// secrets server be configured in the options.
// Exactly one of the fields below must be set.

URL string `json:"url,omitempty"` // path or connection URL
KeyFile string `json:"keyFile,omitempty"` // path to key file
Secret string `json:"secret,omitempty"` // name of secret
URL string `json:"url,omitempty"` // path or connection URL
KeyFile string `json:"keyFile,omitempty"` // path to key file
Secret string `json:"secret,omitempty"` // name of secret
DB Queryable `json:"-"` // programmatic data source
}

func (d *DBSpec) countFields() (n int) {
Expand All @@ -551,12 +560,22 @@ func (d *DBSpec) countFields() (n int) {
}

func (d *DBSpec) checkValid() error {
switch {
case d.Source == "":
return errors.New("missing source")
case d.Driver == "":
if d.Source == "" {
return errors.New("missing source name")
}

// Case 1: A programmatic data source.
if d.DB != nil {
if d.countFields() != 0 {
return errors.New("no connection string is allowed when DB is set")
}
return nil
}

// Case 2: A database/sql database.
if d.Driver == "" {
return errors.New("missing driver name")
case d.countFields() != 1:
} else if d.countFields() != 1 {
return errors.New("exactly one connection source must be set")
}
return nil
Expand Down Expand Up @@ -610,3 +629,42 @@ func DefaultCheckQuery(q Query) (Query, error) {
}
return q, nil
}

// Queryable is the interface used to issue SQL queries to a database.
type Queryable interface {
// Query issues the specified SQL query in a transaction and returns the
// matching result set, if any.
Query(ctx context.Context, sql string, params ...any) (RowSet, error)

// Close closes the database.
Close() error
}

// A RowSet is a sequence of rows reported by a query. It is a subset of the
// interface exposed by [database/sql.Rows], and the implementation must
// provide the same semantics for each of these methods.
type RowSet interface {
// Columns reports the names of the columns requested by the query.
Columns() ([]string, error)

// Close closes the row set, preventing further enumeration.
Close() error

// Err returns the error, if any, that was encountered during iteration.
Err() error

// Next prepares the next result row for reading by the Scan method, and
// reports true if this was successful or false if there was an error or no
// more rows are available.
Next() bool

// Scan copies the columns of the currently-selected row into the values
// pointed to by its arguments.
Scan(...any) error
}

type sqlDB struct{ *sql.DB }

func (s sqlDB) Query(ctx context.Context, query string, params ...any) (RowSet, error) {
return s.DB.QueryContext(ctx, query, params...)
}
30 changes: 15 additions & 15 deletions server/tailsql/tailsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,10 @@ func NewServer(opts Options) (*Server, error) {
return nil, fmt.Errorf("local state: %w", err)
}
if state != nil && opts.LocalSource != "" {
db, err := opts.readOnlyLocalState()
if err != nil {
return nil, fmt.Errorf("read-only local state: %w", err)
}
dbs = append(dbs, &dbHandle{
src: opts.LocalSource,
label: "tailsql local state",
db: db,
db: state,
named: map[string]string{
"schema": `select * from sqlite_schema`,
},
Expand All @@ -174,14 +170,20 @@ func NewServer(opts Options) (*Server, error) {
}

// SetDB adds or replaces the database associated with the specified source in
// s with the given open db and options.
// s with the given open db and options. See [SetSource].
func (s *Server) SetDB(source string, db *sql.DB, opts *DBOptions) bool {
return s.SetSource(source, sqlDB{DB: db}, opts)
}

// SetSource adds or replaces the database associated with the specified source
// in s with the given open db and options.
//
// If a database was already open for the given source, its value is replaced,
// the old database handle is closed, and SetDB reports true.
//
// If no database was already open for the given source, a new source is added
// and SetDB reports false.
func (s *Server) SetDB(source string, db *sql.DB, opts *DBOptions) bool {
func (s *Server) SetSource(source string, db Queryable, opts *DBOptions) bool {
if db == nil {
panic("new database is nil")
}
Expand Down Expand Up @@ -432,8 +434,8 @@ func (s *Server) queryContext(ctx context.Context, caller string, q Query) (*dbR
defer cancel()
}

return runQueryInTx(ctx, h,
func(fctx context.Context, tx *sql.Tx) (_ *dbResult, err error) {
return runQuery(ctx, h,
func(fctx context.Context, db Queryable) (_ *dbResult, err error) {
start := time.Now()
var out dbResult
defer func() {
Expand Down Expand Up @@ -461,19 +463,17 @@ func (s *Server) queryContext(ctx context.Context, caller string, q Query) (*dbR
q.Query = real
}

rows, err := tx.QueryContext(fctx, q.Query)
rows, err := db.Query(fctx, q.Query)
if err != nil {
return nil, err
}
defer rows.Close()

cols, err := rows.ColumnTypes()
cols, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("listing column types: %w", err)
}
for _, col := range cols {
out.Columns = append(out.Columns, col.Name())
return nil, fmt.Errorf("listing column names: %w", err)
}
out.Columns = cols

var tooMany bool
for rows.Next() && !tooMany {
Expand Down
Loading

0 comments on commit f453151

Please sign in to comment.