diff --git a/dot/network/config.go b/dot/network/config.go index ca73afb7c6..4f5042bf10 100644 --- a/dot/network/config.go +++ b/dot/network/config.go @@ -13,6 +13,7 @@ import ( "github.com/adrg/xdg" "github.com/libp2p/go-libp2p/core/crypto" + "github.com/ChainSafe/gossamer/dot/network/ratelimiters" "github.com/ChainSafe/gossamer/internal/log" "github.com/ChainSafe/gossamer/internal/metrics" "github.com/ChainSafe/gossamer/lib/common" @@ -113,6 +114,9 @@ type Config struct { Telemetry Telemetry Metrics metrics.IntervalConfig + + // Spam limiters configuration + warpSyncSpamLimiter RateLimiter } // build checks the configuration, sets up the private key for the network service, @@ -154,6 +158,14 @@ func (c *Config) build() error { c.telemetryInterval = time.Second * 5 } + // set warp sync spam limiter to default + if c.warpSyncSpamLimiter == nil { + c.warpSyncSpamLimiter = ratelimiters.NewSlidingWindowRateLimiter( + ratelimiters.DefaultMaxCachedRequestSize, + ratelimiters.DefaultMaxSlidingWindowTime, + ) + } + return nil } diff --git a/dot/network/interfaces.go b/dot/network/interfaces.go index 977fa27f6e..6d64d70de8 100644 --- a/dot/network/interfaces.go +++ b/dot/network/interfaces.go @@ -6,6 +6,8 @@ package network import ( "encoding/json" "io" + + "github.com/ChainSafe/gossamer/lib/common" ) // Telemetry is the telemetry client to send telemetry messages. @@ -27,3 +29,9 @@ type MDNS interface { Start() error io.Closer } + +// RateLimiter is the interface for rate limiting requests. +type RateLimiter interface { + AddRequest(id common.Hash) + IsLimitExceeded(id common.Hash) bool +} diff --git a/dot/network/ratelimiters/sliding_window.go b/dot/network/ratelimiters/sliding_window.go new file mode 100644 index 0000000000..24ee655f4f --- /dev/null +++ b/dot/network/ratelimiters/sliding_window.go @@ -0,0 +1,76 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package ratelimiters + +import ( + "sync" + "time" + + "github.com/ChainSafe/gossamer/lib/common" + lrucache "github.com/ChainSafe/gossamer/lib/utils/lru-cache" +) + +const DefaultMaxSlidingWindowTime = 1 * time.Minute +const DefaultMaxCachedRequestSize = 500 + +// SlidingWindowRateLimiter is a rate limiter implementation designed to prevent +// more than `maxReqs` requests from being processed within a `windowSize` time window +type SlidingWindowRateLimiter struct { + mu sync.Mutex + limits *lrucache.LRUCache[common.Hash, []time.Time] + maxReqs uint32 + windowSize time.Duration +} + +// NewSlidingWindowRateLimiter creates a new SlidingWindowRateLimiter with the given maximum number of requests +func NewSlidingWindowRateLimiter(maxReqs uint32, windowSize time.Duration) *SlidingWindowRateLimiter { + return &SlidingWindowRateLimiter{ + limits: lrucache.NewLRUCache[common.Hash, []time.Time](DefaultMaxCachedRequestSize), + maxReqs: maxReqs, + windowSize: windowSize, + } +} + +// AddRequest adds a request to the SlidingWindowRateLimiter +func (rl *SlidingWindowRateLimiter) AddRequest(id common.Hash) { + rl.mu.Lock() + defer rl.mu.Unlock() + + recentRequests := rl.recentRequests(id) + + // Add the current request and update the cache + recentRequests = append(recentRequests, time.Now()) + rl.limits.Put(id, recentRequests) +} + +// IsLimitExceeded returns true if the limit is exceeded for the given peer and hash +func (rl *SlidingWindowRateLimiter) IsLimitExceeded(id common.Hash) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + recentRequests := rl.recentRequests(id) + rl.limits.Put(id, recentRequests) + + return uint32(len(recentRequests)) > rl.maxReqs +} + +func (rl *SlidingWindowRateLimiter) recentRequests(id common.Hash) []time.Time { + // Get the timestamps for the hash + timestamps := rl.limits.Get(id) + if timestamps == nil { + return []time.Time{} + } + + now := time.Now() + + // Filter requests that are within the time window + var recentRequests []time.Time + for _, t := range timestamps { + if now.Sub(t) <= rl.windowSize { + recentRequests = append(recentRequests, t) + } + } + + return recentRequests +} diff --git a/dot/network/ratelimiters/sliding_window_test.go b/dot/network/ratelimiters/sliding_window_test.go new file mode 100644 index 0000000000..781b2deb5a --- /dev/null +++ b/dot/network/ratelimiters/sliding_window_test.go @@ -0,0 +1,86 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package ratelimiters + +import ( + "testing" + "time" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/stretchr/testify/assert" +) + +func TestSlidingWindowRateLimiter_AddRequestAndCheckLimitExceeded(t *testing.T) { + t.Parallel() + + // Create a SlidingWindowRateLimiter with a limit of 5 requests and a time window of 10 seconds + limiter := NewSlidingWindowRateLimiter(5, 10*time.Second) + + hash := common.Hash{0x01} + + // Add 5 requests for the same hash + for i := 0; i < 5; i++ { + limiter.AddRequest(hash) + } + + // Limit should not be exceeded after 5 requests + assert.False(t, limiter.IsLimitExceeded(hash)) + + // Add one more request and check that the limit is exceeded + limiter.AddRequest(hash) + assert.True(t, limiter.IsLimitExceeded(hash)) +} + +func TestSlidingWindowRateLimiter_WindowExpiry(t *testing.T) { + t.Parallel() + + // Create a SlidingWindowRateLimiter with a limit of 3 requests and a time window of 1 second + limiter := NewSlidingWindowRateLimiter(3, 1*time.Second) + + hash := common.Hash{0x02} + + // Add 3 requests + for i := 0; i < 3; i++ { + limiter.AddRequest(hash) + } + + // Limit should not be exceeded + assert.False(t, limiter.IsLimitExceeded(hash)) + + // Wait for the time window to expire + time.Sleep(2 * time.Second) + + // Add another request, should be considered as the first in a new window + limiter.AddRequest(hash) + assert.False(t, limiter.IsLimitExceeded(hash)) +} + +func TestSlidingWindowRateLimiter_DifferentHashes(t *testing.T) { + t.Parallel() + + // Create a SlidingWindowRateLimiter with a limit of 2 requests and a time window of 5 seconds + limiter := NewSlidingWindowRateLimiter(2, 5*time.Second) + + hash1 := common.Hash{0x01} + hash2 := common.Hash{0x02} + + // Add requests for hash1 + limiter.AddRequest(hash1) + limiter.AddRequest(hash1) + + // Add requests for hash2 + limiter.AddRequest(hash2) + limiter.AddRequest(hash2) + + // No limit should be exceeded yet + assert.False(t, limiter.IsLimitExceeded(hash1)) + assert.False(t, limiter.IsLimitExceeded(hash2)) + + // Add another request for each and check that the limit is exceeded + limiter.AddRequest(hash1) + assert.True(t, limiter.IsLimitExceeded(hash1)) + + limiter.AddRequest(hash2) + assert.True(t, limiter.IsLimitExceeded(hash2)) +} diff --git a/dot/network/service.go b/dot/network/service.go index 258a964b9e..f916e18359 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -145,6 +145,9 @@ type Service struct { closeCh chan struct{} telemetry Telemetry + + // Spam control + warpSyncSpamLimiter RateLimiter } // NewService creates a new network service from the configuration and message channels @@ -226,6 +229,7 @@ func NewService(cfg *Config) (*Service, error) { streamManager: newStreamManager(ctx), telemetry: cfg.Telemetry, Metrics: cfg.Metrics, + warpSyncSpamLimiter: cfg.warpSyncSpamLimiter, } return network, nil diff --git a/dot/network/warp_sync.go b/dot/network/warp_sync.go index 49a9a32490..8c6978163e 100644 --- a/dot/network/warp_sync.go +++ b/dot/network/warp_sync.go @@ -5,6 +5,7 @@ package network import ( "errors" + "fmt" "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/lib/common" @@ -12,6 +13,8 @@ import ( "github.com/libp2p/go-libp2p/core/peer" ) +const MaxAllowedSameRequestPerPeer = 5 + // WarpSyncProvider is an interface for generating warp sync proofs type WarpSyncProvider interface { // Generate proof starting at given block hash. The proof is accumulated until maximum proof @@ -55,7 +58,20 @@ func (s *Service) handleWarpSyncMessage(stream libp2pnetwork.Stream, msg message } }() + reqId := fmt.Sprintf("%s-%s", stream.Conn().RemotePeer(), msg.String()) + hashedreqId := common.MustBlake2bHash([]byte(reqId)) + if req, ok := msg.(*messages.WarpProofRequest); ok { + // Check if this peer has exceeded the limit of requests + if s.warpSyncSpamLimiter.IsLimitExceeded(hashedreqId) { + logger.Debugf("same warp sync request exceeded for peer: %s", stream.Conn().RemotePeer()) + return nil + } + + // Add the request to the spam limiter + s.warpSyncSpamLimiter.AddRequest(hashedreqId) + + // Handle request resp, err := s.handleWarpSyncRequest(*req) if err != nil { logger.Debugf("cannot create response for request: %s", err)