Skip to content

Commit

Permalink
Merge pull request #189 from ashleym1972/licensed-product
Browse files Browse the repository at this point in the history
Licensed product updates
  • Loading branch information
fiorix authored Oct 26, 2016
2 parents 1f6c482 + f3cb0c5 commit 6a6cfcf
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 0 deletions.
13 changes: 13 additions & 0 deletions apiserver/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,17 @@ func (rr *responseRecord) String() string {

// openDB opens and returns the IP database file or URL.
func openDB(c *Config) (*freegeoip.DB, error) {
// This is a paid product. Get the updates URL.
if len(c.UserID) > 0 && len(c.LicenseKey) > 0 {
var err error
c.DB, err = freegeoip.GeoIPUpdateURL(c.UpdatesHost, c.UserID, c.LicenseKey, c.ProductID)
if err != nil {
return nil, err
} else {
log.Println("Using updates URL:", c.DB)
}
}

u, err := url.Parse(c.DB)
if err != nil || len(u.Scheme) == 0 {
return freegeoip.Open(c.DB)
Expand All @@ -283,6 +294,8 @@ func watchEvents(db *freegeoip.DB) {
case err := <-db.NotifyError():
log.Println("database error:", err)
dbEventCounter.WithLabelValues("failed").Inc()
case msg := <-db.NotifyInfo():
log.Println("database info:", msg)
case <-db.NotifyClose():
return
}
Expand Down
10 changes: 10 additions & 0 deletions apiserver/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ type Config struct {
RateLimitLimit uint64
RateLimitInterval time.Duration
InternalServerAddr string
UpdatesHost string
LicenseKey string
UserID string
ProductID string

errorLog *log.Logger
accessLog *log.Logger
Expand All @@ -65,6 +69,8 @@ func NewConfig() *Config {
MemcacheTimeout: time.Second,
RateLimitBackend: "redis",
RateLimitInterval: time.Hour,
UpdatesHost: "updates.maxmind.com",
ProductID: "GeoIP2-City",
}
}

Expand Down Expand Up @@ -94,6 +100,10 @@ func (c *Config) AddFlags(fs *flag.FlagSet) {
fs.Uint64Var(&c.RateLimitLimit, "quota-max", c.RateLimitLimit, "Max requests per source IP per interval; set 0 to turn quotas off")
fs.DurationVar(&c.RateLimitInterval, "quota-interval", c.RateLimitInterval, "Quota expiration interval, per source IP querying the API")
fs.StringVar(&c.InternalServerAddr, "internal-server", c.InternalServerAddr, "Address in form of ip:port to listen on for metrics and pprof")
fs.StringVar(&c.UpdatesHost, "updates-host", c.UpdatesHost, "MaxMind Updates Host")
fs.StringVar(&c.LicenseKey, "license-key", c.LicenseKey, "MaxMind License key")
fs.StringVar(&c.UserID, "user-id", c.UserID, "MaxMind User ID")
fs.StringVar(&c.ProductID, "product-id", c.ProductID, "MaxMinf product id (e.g GeoIP2-City)")
}

func (c *Config) logWriter() io.Writer {
Expand Down
62 changes: 62 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package freegeoip

import (
"compress/gzip"
"crypto/md5"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -42,6 +43,7 @@ type DB struct {
notifyQuit chan struct{} // Stop auto-update and watch goroutines.
notifyOpen chan string // Notify when a db file is open.
notifyError chan error // Notify when an error occurs.
notifyInfo chan string // Notify random actions for logging
closed bool // Mark this db as closed.
lastUpdated time.Time // Last time the db was updated.
mu sync.RWMutex // Protects all the above.
Expand All @@ -60,6 +62,7 @@ func Open(dsn string) (db *DB, err error) {
notifyQuit: make(chan struct{}),
notifyOpen: make(chan string, 1),
notifyError: make(chan error, 1),
notifyInfo: make(chan string, 1),
}
err = db.openFile()
if err != nil {
Expand All @@ -74,6 +77,41 @@ func Open(dsn string) (db *DB, err error) {
return db, nil
}

// Calculate geoipupdate URL
// The auto update URL for paid products has a fun scheme.
// Use this function to calculate that URL from various information
func GeoIPUpdateURL(hostName string, userID string, licenseKey string, productID string) (url string, err error) {
// Get the file name from the product
url = fmt.Sprintf("%s://%s/app/update_getfilename?product_id=%s", "https", hostName, productID)
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
hexDigest := fmt.Sprintf("%x", md5.Sum(body))

// Get our client IP address
url = fmt.Sprintf("%s://%s/app/update_getipaddr", "https", hostName)
resp, err = http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err = ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
challenge := []byte(fmt.Sprintf("%s%s", licenseKey, body))
hexDigest2 := fmt.Sprintf("%x", md5.Sum(challenge))

// Create the URL
return fmt.Sprintf("%s://%s/app/update_secure?db_md5=%s&challenge_md5=%s&user_id=%s&edition_id=%s", "https", hostName, hexDigest, hexDigest2, userID, productID), nil
}

// OpenURL creates and initializes a DB from a URL.
// It automatically downloads and updates the file in background, and
// keeps a local copy on $TMPDIR.
Expand All @@ -83,6 +121,7 @@ func OpenURL(url string, updateInterval, maxRetryInterval time.Duration) (db *DB
notifyQuit: make(chan struct{}),
notifyOpen: make(chan string, 1),
notifyError: make(chan error, 1),
notifyInfo: make(chan string, 1),
updateInterval: updateInterval,
maxRetryInterval: maxRetryInterval,
}
Expand Down Expand Up @@ -196,6 +235,7 @@ func (db *DB) autoUpdate(url string) {
}

func (db *DB) runUpdate(url string) error {
db.sendInfo("starting update")
yes, err := db.needUpdate(url)
if err != nil {
return err
Expand All @@ -212,6 +252,7 @@ func (db *DB) runUpdate(url string) error {
// Cleanup the tempfile if renaming failed.
os.RemoveAll(tmpfile)
}
db.sendInfo("finished update")
return err
}

Expand All @@ -232,6 +273,7 @@ func (db *DB) needUpdate(url string) (bool, error) {
}

func (db *DB) download(url string) (tmpfile string, err error) {
db.sendInfo("starting download")
resp, err := http.Get(url)
if err != nil {
return "", err
Expand All @@ -248,6 +290,7 @@ func (db *DB) download(url string) (tmpfile string, err error) {
if err != nil {
return "", err
}
db.sendInfo("finished download")
return tmpfile, nil
}

Expand Down Expand Up @@ -298,6 +341,12 @@ func (db *DB) NotifyError() (errChan <-chan error) {
return db.notifyError
}

// NotifyInfo returns a channel that notifies informational messages
// while downloading or reloading.
func (db *DB) NotifyInfo() <-chan string {
return db.notifyInfo
}

func (db *DB) sendError(err error) {
db.mu.RLock()
defer db.mu.RUnlock()
Expand All @@ -310,6 +359,18 @@ func (db *DB) sendError(err error) {
}
}

func (db *DB) sendInfo(message string) {
db.mu.RLock()
defer db.mu.RUnlock()
if db.closed {
return
}
select {
case db.notifyInfo <- message:
default:
}
}

// Lookup takes an IP address and a pointer to the result value to decode
// into. The result value pointed to must be a data value that corresponds
// to a record in the database. This may include a struct representation
Expand Down Expand Up @@ -368,6 +429,7 @@ func (db *DB) Close() {
close(db.notifyQuit)
close(db.notifyOpen)
close(db.notifyError)
close(db.notifyInfo)
}
if db.reader != nil {
db.reader.Close()
Expand Down
26 changes: 26 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,32 @@ import (

var testFile = "testdata/db.gz"

func TestGeoIPUpdateURL(t *testing.T) {
t.Skip("Updates information required")
licenseKey := ""
UserID := ""
url, err := GeoIPUpdateURL("updates.maxmind.com", licenseKey, UserID, "GeoIP2-City")
if err != nil {
t.Fatal(err)
}

db := &DB{}
dbfile, err := db.download(url)
if err != nil {
t.Fatal(err)
}
if _, err := os.Stat(testFile); err == nil {
err := os.Remove(testFile)
if err != nil {
t.Fatal(err)
}
}
err = os.Rename(dbfile, testFile)
if err != nil {
t.Fatal(err)
}
}

func TestDownload(t *testing.T) {
if _, err := os.Stat(testFile); err == nil {
t.Skip("Test database already exists:", testFile)
Expand Down

0 comments on commit 6a6cfcf

Please sign in to comment.