Skip to content

Commit

Permalink
feat(dot/network): Add warp sync spam limiter (#4219)
Browse files Browse the repository at this point in the history
  • Loading branch information
dimartiro authored Oct 7, 2024
1 parent 6e7a351 commit e1dd783
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 0 deletions.
12 changes: 12 additions & 0 deletions dot/network/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/adrg/xdg"
"github.com/libp2p/go-libp2p/core/crypto"

"github.com/ChainSafe/gossamer/dot/network/ratelimiters"
"github.com/ChainSafe/gossamer/internal/log"
"github.com/ChainSafe/gossamer/internal/metrics"
"github.com/ChainSafe/gossamer/lib/common"
Expand Down Expand Up @@ -113,6 +114,9 @@ type Config struct {

Telemetry Telemetry
Metrics metrics.IntervalConfig

// Spam limiters configuration
warpSyncSpamLimiter RateLimiter
}

// build checks the configuration, sets up the private key for the network service,
Expand Down Expand Up @@ -154,6 +158,14 @@ func (c *Config) build() error {
c.telemetryInterval = time.Second * 5
}

// set warp sync spam limiter to default
if c.warpSyncSpamLimiter == nil {
c.warpSyncSpamLimiter = ratelimiters.NewSlidingWindowRateLimiter(
ratelimiters.DefaultMaxCachedRequestSize,
ratelimiters.DefaultMaxSlidingWindowTime,
)
}

return nil
}

Expand Down
8 changes: 8 additions & 0 deletions dot/network/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package network
import (
"encoding/json"
"io"

"github.com/ChainSafe/gossamer/lib/common"
)

// Telemetry is the telemetry client to send telemetry messages.
Expand All @@ -27,3 +29,9 @@ type MDNS interface {
Start() error
io.Closer
}

// RateLimiter is the interface for rate limiting requests.
type RateLimiter interface {
AddRequest(id common.Hash)
IsLimitExceeded(id common.Hash) bool
}
76 changes: 76 additions & 0 deletions dot/network/ratelimiters/sliding_window.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright 2024 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package ratelimiters

import (
"sync"
"time"

"github.com/ChainSafe/gossamer/lib/common"
lrucache "github.com/ChainSafe/gossamer/lib/utils/lru-cache"
)

const DefaultMaxSlidingWindowTime = 1 * time.Minute
const DefaultMaxCachedRequestSize = 500

// SlidingWindowRateLimiter is a rate limiter implementation designed to prevent
// more than `maxReqs` requests from being processed within a `windowSize` time window
type SlidingWindowRateLimiter struct {
mu sync.Mutex
limits *lrucache.LRUCache[common.Hash, []time.Time]
maxReqs uint32
windowSize time.Duration
}

// NewSlidingWindowRateLimiter creates a new SlidingWindowRateLimiter with the given maximum number of requests
func NewSlidingWindowRateLimiter(maxReqs uint32, windowSize time.Duration) *SlidingWindowRateLimiter {
return &SlidingWindowRateLimiter{
limits: lrucache.NewLRUCache[common.Hash, []time.Time](DefaultMaxCachedRequestSize),
maxReqs: maxReqs,
windowSize: windowSize,
}
}

// AddRequest adds a request to the SlidingWindowRateLimiter
func (rl *SlidingWindowRateLimiter) AddRequest(id common.Hash) {
rl.mu.Lock()
defer rl.mu.Unlock()

recentRequests := rl.recentRequests(id)

// Add the current request and update the cache
recentRequests = append(recentRequests, time.Now())
rl.limits.Put(id, recentRequests)
}

// IsLimitExceeded returns true if the limit is exceeded for the given peer and hash
func (rl *SlidingWindowRateLimiter) IsLimitExceeded(id common.Hash) bool {
rl.mu.Lock()
defer rl.mu.Unlock()

recentRequests := rl.recentRequests(id)
rl.limits.Put(id, recentRequests)

return uint32(len(recentRequests)) > rl.maxReqs
}

func (rl *SlidingWindowRateLimiter) recentRequests(id common.Hash) []time.Time {
// Get the timestamps for the hash
timestamps := rl.limits.Get(id)
if timestamps == nil {
return []time.Time{}
}

now := time.Now()

// Filter requests that are within the time window
var recentRequests []time.Time
for _, t := range timestamps {
if now.Sub(t) <= rl.windowSize {
recentRequests = append(recentRequests, t)
}
}

return recentRequests
}
86 changes: 86 additions & 0 deletions dot/network/ratelimiters/sliding_window_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright 2024 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package ratelimiters

import (
"testing"
"time"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/stretchr/testify/assert"
)

func TestSlidingWindowRateLimiter_AddRequestAndCheckLimitExceeded(t *testing.T) {
t.Parallel()

// Create a SlidingWindowRateLimiter with a limit of 5 requests and a time window of 10 seconds
limiter := NewSlidingWindowRateLimiter(5, 10*time.Second)

hash := common.Hash{0x01}

// Add 5 requests for the same hash
for i := 0; i < 5; i++ {
limiter.AddRequest(hash)
}

// Limit should not be exceeded after 5 requests
assert.False(t, limiter.IsLimitExceeded(hash))

// Add one more request and check that the limit is exceeded
limiter.AddRequest(hash)
assert.True(t, limiter.IsLimitExceeded(hash))
}

func TestSlidingWindowRateLimiter_WindowExpiry(t *testing.T) {
t.Parallel()

// Create a SlidingWindowRateLimiter with a limit of 3 requests and a time window of 1 second
limiter := NewSlidingWindowRateLimiter(3, 1*time.Second)

hash := common.Hash{0x02}

// Add 3 requests
for i := 0; i < 3; i++ {
limiter.AddRequest(hash)
}

// Limit should not be exceeded
assert.False(t, limiter.IsLimitExceeded(hash))

// Wait for the time window to expire
time.Sleep(2 * time.Second)

// Add another request, should be considered as the first in a new window
limiter.AddRequest(hash)
assert.False(t, limiter.IsLimitExceeded(hash))
}

func TestSlidingWindowRateLimiter_DifferentHashes(t *testing.T) {
t.Parallel()

// Create a SlidingWindowRateLimiter with a limit of 2 requests and a time window of 5 seconds
limiter := NewSlidingWindowRateLimiter(2, 5*time.Second)

hash1 := common.Hash{0x01}
hash2 := common.Hash{0x02}

// Add requests for hash1
limiter.AddRequest(hash1)
limiter.AddRequest(hash1)

// Add requests for hash2
limiter.AddRequest(hash2)
limiter.AddRequest(hash2)

// No limit should be exceeded yet
assert.False(t, limiter.IsLimitExceeded(hash1))
assert.False(t, limiter.IsLimitExceeded(hash2))

// Add another request for each and check that the limit is exceeded
limiter.AddRequest(hash1)
assert.True(t, limiter.IsLimitExceeded(hash1))

limiter.AddRequest(hash2)
assert.True(t, limiter.IsLimitExceeded(hash2))
}
4 changes: 4 additions & 0 deletions dot/network/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ type Service struct {
closeCh chan struct{}

telemetry Telemetry

// Spam control
warpSyncSpamLimiter RateLimiter
}

// NewService creates a new network service from the configuration and message channels
Expand Down Expand Up @@ -226,6 +229,7 @@ func NewService(cfg *Config) (*Service, error) {
streamManager: newStreamManager(ctx),
telemetry: cfg.Telemetry,
Metrics: cfg.Metrics,
warpSyncSpamLimiter: cfg.warpSyncSpamLimiter,
}

return network, nil
Expand Down
16 changes: 16 additions & 0 deletions dot/network/warp_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ package network

import (
"errors"
"fmt"

"github.com/ChainSafe/gossamer/dot/network/messages"
"github.com/ChainSafe/gossamer/lib/common"
libp2pnetwork "github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
)

const MaxAllowedSameRequestPerPeer = 5

// WarpSyncProvider is an interface for generating warp sync proofs
type WarpSyncProvider interface {
// Generate proof starting at given block hash. The proof is accumulated until maximum proof
Expand Down Expand Up @@ -55,7 +58,20 @@ func (s *Service) handleWarpSyncMessage(stream libp2pnetwork.Stream, msg message
}
}()

reqId := fmt.Sprintf("%s-%s", stream.Conn().RemotePeer(), msg.String())
hashedreqId := common.MustBlake2bHash([]byte(reqId))

if req, ok := msg.(*messages.WarpProofRequest); ok {
// Check if this peer has exceeded the limit of requests
if s.warpSyncSpamLimiter.IsLimitExceeded(hashedreqId) {
logger.Debugf("same warp sync request exceeded for peer: %s", stream.Conn().RemotePeer())
return nil
}

// Add the request to the spam limiter
s.warpSyncSpamLimiter.AddRequest(hashedreqId)

// Handle request
resp, err := s.handleWarpSyncRequest(*req)
if err != nil {
logger.Debugf("cannot create response for request: %s", err)
Expand Down

0 comments on commit e1dd783

Please sign in to comment.