From 989e6d4183f78c18a28964d4b204d91945e49934 Mon Sep 17 00:00:00 2001 From: Alexandre Fiori Date: Wed, 4 Oct 2017 11:30:52 +0100 Subject: [PATCH] Enforce HSTS on all endpoints --- apiserver/api.go | 81 ++++++++++++++++++++++++++++------------------- apiserver/main.go | 2 +- 2 files changed, 49 insertions(+), 34 deletions(-) diff --git a/apiserver/api.go b/apiserver/api.go index 249553a..3ee028a 100644 --- a/apiserver/api.go +++ b/apiserver/api.go @@ -68,16 +68,17 @@ func NewHandler(c *Config) (http.Handler, error) { func (f *apiHandler) config(mc *httpmux.Config) error { mc.Prefix = f.conf.APIPrefix - if f.conf.PublicDir != "" { - mc.NotFound = f.publicDir() - } + mc.NotFound = newPublicDirHandler(f.conf.PublicDir) if f.conf.UseXForwardedFor { mc.UseFunc(httplog.UseXForwardedFor) } if !f.conf.Silent { mc.UseFunc(httplog.ApacheCombinedFormat(f.conf.accessLogger())) } - mc.UseFunc(f.metrics) + if f.conf.HSTS != "" { + mc.UseFunc(hstsMiddleware(f.conf.HSTS)) + } + mc.UseFunc(clientMetricsMiddleware(f.db)) if f.conf.RateLimitLimit > 0 { rl, err := newRateLimiter(f.conf) if err != nil { @@ -96,40 +97,57 @@ func (f *apiHandler) config(mc *httpmux.Config) error { return nil } -func (f *apiHandler) publicDir() http.HandlerFunc { - fs := http.FileServer(http.Dir(f.conf.PublicDir)) - return prometheus.InstrumentHandler("frontend", fs) +func newPublicDirHandler(path string) http.HandlerFunc { + handler := http.NotFoundHandler() + if path != "" { + handler = http.FileServer(http.Dir(path)) + } + return prometheus.InstrumentHandler("frontend", handler) } -func (f *apiHandler) metrics(next http.HandlerFunc) http.HandlerFunc { +func hstsMiddleware(policy string) httpmux.MiddlewareFunc { + return func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.TLS == nil { + return + } + w.Header().Set("Strict-Transport-Security", policy) + next(w, r) + } + } +} + +func clientMetricsMiddleware(db *freegeoip.DB) httpmux.MiddlewareFunc { type query struct { Country struct { ISOCode string `maxminddb:"iso_code"` } `maxminddb:"country"` } - return func(w http.ResponseWriter, r *http.Request) { - next(w, r) - // Collect metrics after serving the request. - host, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return - } - ip := net.ParseIP(host) - if ip == nil { - return - } - if ip.To4() != nil { - clientIPProtoCounter.WithLabelValues("4").Inc() - } else { - clientIPProtoCounter.WithLabelValues("6").Inc() - } - var q query - err = f.db.Lookup(ip, &q) - if err != nil || q.Country.ISOCode == "" { - clientCountryCounter.WithLabelValues("unknown").Inc() - return + return func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + next(w, r) + // Collect metrics after serving the request. + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return + } + ip := net.ParseIP(host) + if ip == nil { + return + } + if ip.To4() != nil { + clientIPProtoCounter.WithLabelValues("4").Inc() + } else { + clientIPProtoCounter.WithLabelValues("6").Inc() + } + var q query + err = db.Lookup(ip, &q) + if err != nil || q.Country.ISOCode == "" { + clientCountryCounter.WithLabelValues("unknown").Inc() + return + } + clientCountryCounter.WithLabelValues(q.Country.ISOCode).Inc() } - clientCountryCounter.WithLabelValues(q.Country.ISOCode).Inc() } } @@ -169,9 +187,6 @@ func (f *apiHandler) iplookup(writer writerFunc) http.HandlerFunc { http.Error(w, "Try again later.", http.StatusServiceUnavailable) return } - if r.TLS != nil && f.conf.HSTS != "" { - w.Header().Set("Strict-Transport-Security", f.conf.HSTS) - } w.Header().Set("X-Database-Date", f.db.Date().Format(http.TimeFormat)) resp := q.Record(ip, r.Header.Get("Accept-Language")) writer(w, r, resp) diff --git a/apiserver/main.go b/apiserver/main.go index 6e6abb8..458a85c 100644 --- a/apiserver/main.go +++ b/apiserver/main.go @@ -21,7 +21,7 @@ import ( ) // Version tag. -var Version = "3.4" +var Version = "3.4.1" // Run is the entrypoint for the freegeoip server. func Run() {