Skip to content

Commit

Permalink
server: add a query check option (#22)
Browse files Browse the repository at this point in the history
Add an optional function allowing the caller to preprocess a query presented by
the user. This can be used to add syntax checks, normalization, and aliases.
  • Loading branch information
creachadair authored Jul 11, 2024
1 parent 820559f commit c03fe0b
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 19 deletions.
34 changes: 34 additions & 0 deletions server/tailsql/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,25 @@ type Options struct {
// by the rule replaces the original string.
UIRewriteRules []UIRewriteRule `json:"-"`

// If non-nil, call this function with each query presented to the API. If
// the function reports an error, the query fails; otherwise the returned
// query state is used to service the query. If nil, DefaultCheckQuery is
// used.
CheckQuery func(Query) (Query, error) `json:"-"`

// If non-nil, send logs to this logger. If nil, use log.Printf.
Logf logger.Logf `json:"-"`
}

// checkQuery returns the query check function specified by options, or a
// default that accepts all queries as given.
func (o Options) checkQuery() func(Query) (Query, error) {
if o.CheckQuery == nil {
return DefaultCheckQuery
}
return o.CheckQuery
}

// openSources opens database handles to each of the sources defined by o.
// Sources that require secrets will get them from store.
// Precondition: All the sources of o have already been validated.
Expand Down Expand Up @@ -576,3 +591,22 @@ func (o *DBOptions) namedQueries() map[string]string {
}
return o.NamedQueries
}

// A Query carries the parameters of a query presented to the API.
type Query struct {
Source string // the data source requested
Query string // the text of the query
}

// DefaultCheckQuery is the default query check function used if another is not
// specified in the Options. It accepts all queries for all sources, as long as
// the query text does not exceed 4000 bytes.
func DefaultCheckQuery(q Query) (Query, error) {
// Reject query strings that are egregiously too long.
const maxQueryBytes = 4000

if len(q.Query) > maxQueryBytes {
return q, errors.New("query too long")
}
return q, nil
}
33 changes: 16 additions & 17 deletions server/tailsql/tailsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ type Server struct {
links []UILink
rules []UIRewriteRule
authorize func(string, *apitype.WhoIsResponse) error
qcheck func(Query) (Query, error)
qtimeout time.Duration
logf logger.Logf

Expand Down Expand Up @@ -165,6 +166,7 @@ func NewServer(opts Options) (*Server, error) {
links: opts.UILinks,
rules: opts.UIRewriteRules,
authorize: opts.authorize(),
qcheck: opts.checkQuery(),
qtimeout: opts.QueryTimeout.Duration(),
logf: opts.logf(),
dbs: dbs,
Expand Down Expand Up @@ -225,40 +227,37 @@ func (s *Server) serveUI(w http.ResponseWriter, r *http.Request) {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
src := r.FormValue("src")
if src == "" {
q, err := s.qcheck(Query{
Source: r.FormValue("src"),
Query: strings.TrimSpace(r.FormValue("q")),
})
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if q.Source == "" {
dbs := s.getHandles()
if len(dbs) != 0 {
src = dbs[0].Source() // default to the first source
q.Source = dbs[0].Source() // default to the first source
}
}

// Reject query strings that are egregiously too long.
const maxQueryBytes = 4000

query := strings.TrimSpace(r.FormValue("q"))
if len(query) > maxQueryBytes {
http.Error(w, "query too long", http.StatusBadRequest)
return
}

caller, isAuthorized := s.checkAuth(w, r, src, query)
caller, isAuthorized := s.checkAuth(w, r, q.Source, q.Query)
if !isAuthorized {
authErrorCount.Add(1)
return
}

var err error
switch r.URL.Path {
case "/":
htmlRequestCount.Add(1)
err = s.serveUIInternal(w, r, caller, src, query)
err = s.serveUIInternal(w, r, caller, q.Source, q.Query)
case "/csv":
csvRequestCount.Add(1)
err = s.serveCSVInternal(w, r, caller, src, query)
err = s.serveCSVInternal(w, r, caller, q.Source, q.Query)
case "/json":
jsonRequestCount.Add(1)
err = s.serveJSONInternal(w, r, caller, src, query)
err = s.serveJSONInternal(w, r, caller, q.Source, q.Query)
case "/meta":
metaRequestCount.Add(1)
err = s.serveMetaInternal(w, r)
Expand Down
17 changes: 15 additions & 2 deletions server/tailsql/tailsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,15 @@ func TestServer(t *testing.T) {
UILinks: []tailsql.UILink{
{Anchor: testAnchor, URL: testURL},
},
CheckQuery: func(q tailsql.Query) (tailsql.Query, error) {
// Rewrite a source named "alias" as a spelling for "main" and add a
// comment at the front.
if q.Source == "alias" {
q.Source = "main"
q.Query = "-- Hello, world\n" + q.Query
}
return tailsql.DefaultCheckQuery(q)
},
UIRewriteRules: testUIRules,
Authorize: authorizer.ACLGrants(nil),
Logf: t.Logf,
Expand Down Expand Up @@ -225,13 +234,17 @@ func TestServer(t *testing.T) {
})

t.Run("UIDecoration", func(t *testing.T) {
q := make(url.Values)
q.Set("q", "select * from misc")
q := url.Values{
"q": {"select * from misc"},
"src": {"alias"},
}
url := htest.URL + "?" + q.Encode()
ui := string(mustGet(t, cli, url))

// As a rough smoke test, look for expected substrings.
for _, want := range []string{
// The query should include its injected comment from the check function.
`-- Hello, world`,
// Stripe IDs should get wrapped in links.
`<a href="https://dashboard.stripe.com/customers/cus_Fak3Cu6t0m3rId"`,
`<a href="https://dashboard.stripe.com/invoices/in_1f4k31nv0Ic3Num83r"`,
Expand Down

0 comments on commit c03fe0b

Please sign in to comment.