diff --git a/share/shwap/p2p/bitswap/getter.go b/share/shwap/p2p/bitswap/getter.go index eaba3a0f39..308185a36c 100644 --- a/share/shwap/p2p/bitswap/getter.go +++ b/share/shwap/p2p/bitswap/getter.go @@ -3,6 +3,7 @@ package bitswap import ( "context" "fmt" + "sync" "time" "github.com/ipfs/boxo/blockstore" @@ -10,7 +11,6 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/trace" "github.com/celestiaorg/celestia-app/v3/pkg/wrapper" libshare "github.com/celestiaorg/go-square/v2/share" @@ -30,8 +30,8 @@ type Getter struct { bstore blockstore.Blockstore availWndw time.Duration - availableSession exchange.Fetcher - archivalSession exchange.Fetcher + availablePool *pool + archivalPool *pool cancel context.CancelFunc } @@ -42,7 +42,13 @@ func NewGetter( bstore blockstore.Blockstore, availWndw time.Duration, ) *Getter { - return &Getter{exchange: exchange, bstore: bstore, availWndw: availWndw} + return &Getter{ + exchange: exchange, + bstore: bstore, + availWndw: availWndw, + availablePool: newPool(exchange), + archivalPool: newPool(exchange), + } } // Start kicks off internal fetching sessions. @@ -57,12 +63,13 @@ func NewGetter( // with regular full node peers. func (g *Getter) Start() { ctx, cancel := context.WithCancel(context.Background()) - g.availableSession = g.exchange.NewSession(ctx) - g.archivalSession = g.exchange.NewSession(ctx) g.cancel = cancel + + g.availablePool.ctx = ctx + g.availablePool.ctx = ctx } -// Stop shuts down Getter's internal fetching session. +// Stop shuts down Getter's internal fetching getSession. func (g *Getter) Stop() { g.cancel() } @@ -97,7 +104,12 @@ func (g *Getter) GetShares( blks[i] = sid } - ses := g.session(ctx, hdr) + isArchival := g.isArchival(hdr) + span.SetAttributes(attribute.Bool("is_archival", isArchival)) + + ses, release := g.getSession(isArchival) + defer release() + err := Fetch(ctx, g.exchange, hdr.DAH, blks, WithStore(g.bstore), WithFetcher(ses)) if err != nil { span.RecordError(err) @@ -156,7 +168,12 @@ func (g *Getter) GetEDS( blks[i] = blk } - ses := g.session(ctx, hdr) + isArchival := g.isArchival(hdr) + span.SetAttributes(attribute.Bool("is_archival", isArchival)) + + ses, release := g.getSession(isArchival) + defer release() + err := Fetch(ctx, g.exchange, hdr.DAH, blks, WithFetcher(ses)) if err != nil { span.RecordError(err) @@ -210,7 +227,12 @@ func (g *Getter) GetNamespaceData( blks[i] = rndblk } - ses := g.session(ctx, hdr) + isArchival := g.isArchival(hdr) + span.SetAttributes(attribute.Bool("is_archival", isArchival)) + + ses, release := g.getSession(isArchival) + defer release() + if err = Fetch(ctx, g.exchange, hdr.DAH, blks, WithFetcher(ses)); err != nil { span.RecordError(err) span.SetStatus(codes.Error, "Fetch") @@ -230,17 +252,19 @@ func (g *Getter) GetNamespaceData( return nsShrs, nil } -// session decides which fetching session to use for the given header. -func (g *Getter) session(ctx context.Context, hdr *header.ExtendedHeader) exchange.Fetcher { - session := g.archivalSession +// isArchival reports whether the header is for archival data +func (g *Getter) isArchival(hdr *header.ExtendedHeader) bool { + return !availability.IsWithinWindow(hdr.Time(), g.availWndw) +} - isWithinAvailability := availability.IsWithinWindow(hdr.Time(), g.availWndw) - if isWithinAvailability { - session = g.availableSession +// getSession takes a session out of the respective session pool +func (g *Getter) getSession(isArchival bool) (ses exchange.Fetcher, release func()) { + if isArchival { + ses = g.archivalPool.get() + return ses, func() { g.archivalPool.put(ses) } } - - trace.SpanFromContext(ctx).SetAttributes(attribute.Bool("within_availability", isWithinAvailability)) - return session + ses = g.availablePool.get() + return ses, func() { g.availablePool.put(ses) } } // edsFromRows imports given Rows and computes EDS out of them, assuming enough Rows were provided. @@ -274,3 +298,40 @@ func edsFromRows(roots *share.AxisRoots, rows []shwap.Row) (*rsmt2d.ExtendedData return square, nil } + +// pool is a pool of Bitswap sessions. +type pool struct { + lock sync.Mutex + sessions []exchange.Fetcher + ctx context.Context + exchange exchange.SessionExchange +} + +func newPool(ex exchange.SessionExchange) *pool { + return &pool{ + exchange: ex, + sessions: make([]exchange.Fetcher, 0), + } +} + +// get returns a session from the pool or creates a new one if the pool is empty. +func (p *pool) get() exchange.Fetcher { + p.lock.Lock() + defer p.lock.Unlock() + + if len(p.sessions) == 0 { + return p.exchange.NewSession(p.ctx) + } + + ses := p.sessions[len(p.sessions)-1] + p.sessions = p.sessions[:len(p.sessions)-1] + return ses +} + +// put returns a session to the pool. +func (p *pool) put(ses exchange.Fetcher) { + p.lock.Lock() + defer p.lock.Unlock() + + p.sessions = append(p.sessions, ses) +} diff --git a/share/shwap/p2p/bitswap/getter_test.go b/share/shwap/p2p/bitswap/getter_test.go index cdabcb73be..b0e8d633cc 100644 --- a/share/shwap/p2p/bitswap/getter_test.go +++ b/share/shwap/p2p/bitswap/getter_test.go @@ -1,8 +1,11 @@ package bitswap import ( + "context" + "sync" "testing" + "github.com/ipfs/boxo/exchange" "github.com/stretchr/testify/require" libshare "github.com/celestiaorg/go-square/v2/share" @@ -28,3 +31,86 @@ func TestEDSFromRows(t *testing.T) { require.NoError(t, err) require.True(t, edsIn.Equals(edsOut)) } + +// mockSessionExchange is a mock implementation of exchange.SessionExchange +type mockSessionExchange struct { + exchange.SessionExchange + sessionCount int + mu sync.Mutex +} + +func (m *mockSessionExchange) NewSession(ctx context.Context) exchange.Fetcher { + m.mu.Lock() + defer m.mu.Unlock() + m.sessionCount++ + return &mockFetcher{id: m.sessionCount} +} + +// mockFetcher is a mock implementation of exchange.Fetcher +type mockFetcher struct { + exchange.Fetcher + id int +} + +func TestPoolGetFromEmptyPool(t *testing.T) { + ex := &mockSessionExchange{} + p := newPool(ex) + ctx := context.Background() + p.ctx = ctx + + ses := p.get().(*mockFetcher) + require.NotNil(t, ses) + require.Equal(t, 1, ses.id) +} + +func TestPoolPutAndGet(t *testing.T) { + ex := &mockSessionExchange{} + p := newPool(ex) + ctx := context.Background() + p.ctx = ctx + + // Get a session + ses := p.get().(*mockFetcher) + + // Put it back + p.put(ses) + + // Get again + ses2 := p.get().(*mockFetcher) + + require.Equal(t, ses.id, ses2.id) +} + +func TestPoolConcurrency(t *testing.T) { + ex := &mockSessionExchange{} + p := newPool(ex) + ctx := context.Background() + p.ctx = ctx + + const numGoroutines = 50 + var wg sync.WaitGroup + + sessionIDSet := make(map[int]struct{}) + lock := sync.Mutex{} + + // Start multiple goroutines to get sessions + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ses := p.get() + mockSes := ses.(*mockFetcher) + p.put(ses) + lock.Lock() + sessionIDSet[mockSes.id] = struct{}{} + lock.Unlock() + }() + } + wg.Wait() + + // Since the pool reuses sessions, the number of unique session IDs should be less than or equal to numGoroutines + if len(sessionIDSet) > numGoroutines { + t.Fatalf("expected number of unique sessions to be less than or equal to %d, got %d", + numGoroutines, len(sessionIDSet)) + } +}