diff --git a/main.go b/main.go index b95ff61..3b14678 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "math/big" "net/http" "strings" + "sync" "time" "github.com/prometheus/client_golang/prometheus" @@ -17,8 +18,8 @@ import ( ) var ( - listen = flag.String("l", ":8080", "listen address") - metricsListen = flag.String("m", ":8081", "listen address") + listen = flag.String("l", ":8080", "API listen address") + metricsListen = flag.String("m", ":8081", "metrics listen address") tokenValidityDuration = flag.Int("t", 60, "token validity duration in minutes") tokenValidationWait = flag.Int("w", 60, "how long to wait for a token to be validated before deleting it in seconds") verbose = flag.Bool("v", false, "enable verbose logging") @@ -40,7 +41,10 @@ type cacheEntry struct { validated bool // Has this hash been validated by a client? } -var cache = make(map[string]*cacheEntry) // server hash to expiration timestamp +var ( + cache = make(map[string]*cacheEntry) // server hash to expiration timestamp + cacheMutex sync.RWMutex +) const hexLetters = "0123456789abcdef" @@ -60,7 +64,9 @@ func randomString(length int) (string, error) { // validate checks that a client provided token matches the given server hash func validate(token, hash string) bool { + cacheMutex.RLock() entry, found := cache[hash] + cacheMutex.RUnlock() if !found { return false } @@ -69,7 +75,9 @@ func validate(token, hash string) bool { // Check if server hash is expired if time.Now().After(entry.created.Add(time.Duration(*tokenValidityDuration) * time.Minute)) { log.Debugf("Server hash %s expired, removing from cache", hash) + cacheMutex.Lock() delete(cache, hash) + cacheMutex.Unlock() return false } @@ -91,18 +99,29 @@ func main() { purgeTicker := time.NewTicker(time.Second * time.Duration(*tokenValidationWait/2)) go func() { for range purgeTicker.C { - for hash, entry := range cache { + // Clone cache + cacheMutex.RLock() + cacheCopy := make(map[string]*cacheEntry) + for k, v := range cache { + cacheCopy[k] = v + } + cacheMutex.RUnlock() + + for hash, entry := range cacheCopy { if !entry.validated && time.Now().After(entry.created.Add(time.Duration(*tokenValidationWait)*time.Second)) { log.Debugf("Purging expired server hash %s", hash) + cacheMutex.Lock() delete(cache, hash) + cacheMutex.Unlock() } } } }() - metricUpdateTicker := time.NewTicker(1 * time.Second) + metricUpdateTicker := time.NewTicker(10 * time.Second) go func() { for range metricUpdateTicker.C { + cacheMutex.RLock() metricIssuedTokens.Set(float64(len(cache))) validated := 0 for _, token := range cache { @@ -110,6 +129,7 @@ func main() { validated++ } } + cacheMutex.RUnlock() metricValidatedTokens.Set(float64(validated)) } }() @@ -139,17 +159,21 @@ func main() { _, _ = w.Write([]byte("Error")) return } + cacheMutex.Lock() cache[newHash] = &cacheEntry{ created: time.Now(), validated: false, } + cacheMutex.Unlock() log.Debugf("Generated new hash %s", newHash) _, _ = w.Write([]byte(newHash)) }) http.HandleFunc("/invalidate", func(w http.ResponseWriter, r *http.Request) { log.Debug("Invalidating all hashes") + cacheMutex.Lock() cache = make(map[string]*cacheEntry) + cacheMutex.Unlock() w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("OK")) })