Skip to content

Commit

Permalink
fix: lock cache
Browse files Browse the repository at this point in the history
  • Loading branch information
natesales committed May 24, 2023
1 parent 120145d commit f4d3bf5
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"math/big"
"net/http"
"strings"
"sync"
"time"

"github.com/prometheus/client_golang/prometheus"
Expand All @@ -17,8 +18,8 @@ import (
)

var (
listen = flag.String("l", ":8080", "listen address")
metricsListen = flag.String("m", ":8081", "listen address")
listen = flag.String("l", ":8080", "API listen address")
metricsListen = flag.String("m", ":8081", "metrics listen address")
tokenValidityDuration = flag.Int("t", 60, "token validity duration in minutes")
tokenValidationWait = flag.Int("w", 60, "how long to wait for a token to be validated before deleting it in seconds")
verbose = flag.Bool("v", false, "enable verbose logging")
Expand All @@ -40,7 +41,10 @@ type cacheEntry struct {
validated bool // Has this hash been validated by a client?
}

var cache = make(map[string]*cacheEntry) // server hash to expiration timestamp
var (
cache = make(map[string]*cacheEntry) // server hash to expiration timestamp
cacheMutex sync.RWMutex
)

const hexLetters = "0123456789abcdef"

Expand All @@ -60,7 +64,9 @@ func randomString(length int) (string, error) {

// validate checks that a client provided token matches the given server hash
func validate(token, hash string) bool {
cacheMutex.RLock()
entry, found := cache[hash]
cacheMutex.RUnlock()
if !found {
return false
}
Expand All @@ -69,7 +75,9 @@ func validate(token, hash string) bool {
// Check if server hash is expired
if time.Now().After(entry.created.Add(time.Duration(*tokenValidityDuration) * time.Minute)) {
log.Debugf("Server hash %s expired, removing from cache", hash)
cacheMutex.Lock()
delete(cache, hash)
cacheMutex.Unlock()
return false
}

Expand All @@ -91,25 +99,37 @@ func main() {
purgeTicker := time.NewTicker(time.Second * time.Duration(*tokenValidationWait/2))
go func() {
for range purgeTicker.C {
for hash, entry := range cache {
// Clone cache
cacheMutex.RLock()
cacheCopy := make(map[string]*cacheEntry)
for k, v := range cache {
cacheCopy[k] = v
}
cacheMutex.RUnlock()

for hash, entry := range cacheCopy {
if !entry.validated && time.Now().After(entry.created.Add(time.Duration(*tokenValidationWait)*time.Second)) {
log.Debugf("Purging expired server hash %s", hash)
cacheMutex.Lock()
delete(cache, hash)
cacheMutex.Unlock()
}
}
}
}()

metricUpdateTicker := time.NewTicker(1 * time.Second)
metricUpdateTicker := time.NewTicker(10 * time.Second)
go func() {
for range metricUpdateTicker.C {
cacheMutex.RLock()
metricIssuedTokens.Set(float64(len(cache)))
validated := 0
for _, token := range cache {
if token.validated {
validated++
}
}
cacheMutex.RUnlock()
metricValidatedTokens.Set(float64(validated))
}
}()
Expand Down Expand Up @@ -139,17 +159,21 @@ func main() {
_, _ = w.Write([]byte("Error"))
return
}
cacheMutex.Lock()
cache[newHash] = &cacheEntry{
created: time.Now(),
validated: false,
}
cacheMutex.Unlock()
log.Debugf("Generated new hash %s", newHash)
_, _ = w.Write([]byte(newHash))
})

http.HandleFunc("/invalidate", func(w http.ResponseWriter, r *http.Request) {
log.Debug("Invalidating all hashes")
cacheMutex.Lock()
cache = make(map[string]*cacheEntry)
cacheMutex.Unlock()
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
})
Expand Down

0 comments on commit f4d3bf5

Please sign in to comment.