Skip to content

Commit

Permalink
Merge pull request #13 from castaneai/fix-data-race
Browse files Browse the repository at this point in the history
fix: fix data race
  • Loading branch information
castaneai authored Dec 25, 2023
2 parents 325bab5 + f43a896 commit 62548a6
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ jobs:
- name: Get dependencies
run: go mod download
- name: Test
run: go test -v ./...
run: go test -v -race ./...
21 changes: 15 additions & 6 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ package minimatch
import (
"context"
"fmt"
"sync"
"time"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/metric"
"golang.org/x/sync/errgroup"
"open-match.dev/open-match/pkg/pb"

"github.com/castaneai/minimatch/pkg/mmlog"
"github.com/castaneai/minimatch/pkg/statestore"
)

Expand All @@ -21,6 +21,7 @@ const (
type Backend struct {
store statestore.StateStore
mmfs map[*pb.MatchProfile]MatchFunction
mmfMu sync.RWMutex
assigner Assigner
options *backendOptions
metrics *backendMetrics
Expand Down Expand Up @@ -80,25 +81,30 @@ func NewBackend(store statestore.StateStore, assigner Assigner, opts ...BackendO
return &Backend{
store: store,
mmfs: map[*pb.MatchProfile]MatchFunction{},
mmfMu: sync.RWMutex{},
assigner: newAssignerWithMetrics(assigner, metrics),
options: options,
metrics: metrics,
}, nil
}

func (b *Backend) AddMatchFunction(profile *pb.MatchProfile, mmf MatchFunction) {
b.mmfMu.Lock()
defer b.mmfMu.Unlock()
b.mmfs[profile] = newMatchFunctionWithMetrics(mmf, b.metrics)
}

func (b *Backend) Start(ctx context.Context, tickRate time.Duration) error {
ticker := time.NewTicker(tickRate)
defer ticker.Stop()

profiles := make([]string, 0, len(b.mmfs))
for profile := range b.mmfs {
b.mmfMu.RLock()
mmfs := b.mmfs
b.mmfMu.RUnlock()
profiles := make([]string, 0, len(mmfs))
for profile := range mmfs {
profiles = append(profiles, profile.Name)
}
mmlog.Infof("minimatch backend started (matchProfile: %v, tickRate: %s)", profiles, tickRate)
for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -156,9 +162,12 @@ func (b *Backend) fetchActiveTickets(ctx context.Context) ([]*pb.Ticket, error)
}

func (b *Backend) makeMatches(ctx context.Context, activeTickets []*pb.Ticket) ([]*pb.Match, error) {
resCh := make(chan []*pb.Match, len(b.mmfs))
b.mmfMu.RLock()
mmfs := b.mmfs
b.mmfMu.RUnlock()
resCh := make(chan []*pb.Match, len(mmfs))
eg, ctx := errgroup.WithContext(ctx)
for profile, mmf := range b.mmfs {
for profile, mmf := range mmfs {
profile := profile
mmf := mmf
eg.Go(func() error {
Expand Down
3 changes: 1 addition & 2 deletions metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ type matchFunctionWithMetrics struct {
func (m *matchFunctionWithMetrics) MakeMatches(ctx context.Context, profile *pb.MatchProfile, poolTickets map[string][]*pb.Ticket) ([]*pb.Match, error) {
start := time.Now()
defer func() {
m.metrics.matchFunctionLatency.Record(ctx, time.Since(start).Seconds(),
metric.WithAttributes(matchProfileKey.String(profile.Name)))
m.metrics.recordMatchFunctionLatency(ctx, time.Since(start).Seconds(), profile)
}()
return m.mmf.MakeMatches(ctx, profile, poolTickets)
}
Expand Down
16 changes: 12 additions & 4 deletions minimatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"sync"
"time"

"github.com/alicebob/miniredis/v2"
Expand All @@ -19,6 +20,7 @@ type MiniMatch struct {
store statestore.StateStore
mmfs map[*pb.MatchProfile]MatchFunction
backend *Backend
mu sync.RWMutex
}

func NewMiniMatchWithRedis(opts ...statestore.RedisOption) (*MiniMatch, error) {
Expand All @@ -45,6 +47,7 @@ func NewMiniMatch(store statestore.StateStore) *MiniMatch {
return &MiniMatch{
store: store,
mmfs: map[*pb.MatchProfile]MatchFunction{},
mu: sync.RWMutex{},
}
}

Expand All @@ -71,19 +74,24 @@ func (m *MiniMatch) StartBackend(ctx context.Context, assigner Assigner, tickRat
if err != nil {
return fmt.Errorf("failed to create minimatch backend: %w", err)
}
m.mu.Lock()
m.backend = backend
m.mu.Unlock()
for profile, mmf := range m.mmfs {
m.backend.AddMatchFunction(profile, mmf)
backend.AddMatchFunction(profile, mmf)
}
return m.backend.Start(ctx, tickRate)
return backend.Start(ctx, tickRate)
}

// for testing
func (m *MiniMatch) TickBackend(ctx context.Context) error {
if m.backend == nil {
m.mu.RLock()
backend := m.backend
m.mu.RUnlock()
if backend == nil {
return fmt.Errorf("backend has not been started")
}
return m.backend.Tick(ctx)
return backend.Tick(ctx)
}

var MatchFunctionSimple1vs1 = MatchFunctionFunc(func(ctx context.Context, profile *pb.MatchProfile, poolTickets map[string][]*pb.Ticket) ([]*pb.Match, error) {
Expand Down
26 changes: 4 additions & 22 deletions pkg/statestore/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,14 @@ func (s *RedisStore) GetActiveTickets(ctx context.Context, limit int64) ([]*pb.T
// Acquire a lock to prevent multiple backends from fetching the same Ticket
lockedCtx, unlock, err := s.locker.WithContext(ctx, redisKeyFetchTicketsLock)
if err != nil {
return nil, fmt.Errorf("failed to acquire fetch tickets lock")
return nil, fmt.Errorf("failed to acquire fetch tickets lock: %w", err)
}
defer unlock()

allTicketIDs, err := s.getAllTicketIDs(lockedCtx, limit)
if err != nil {
return nil, fmt.Errorf("failed to get all ticket IDs: %w", err)
}
if len(allTicketIDs) == 0 {
return nil, nil
}
Expand Down Expand Up @@ -333,27 +336,6 @@ func (s *RedisStore) getTickets(ctx context.Context, ticketIDs []string) ([]*pb.
return tickets, nil
}

func (s *RedisStore) setTickets(ctx context.Context, tickets []*pb.Ticket) error {
queries := make([]rueidis.Completed, len(tickets))
for i, ticket := range tickets {
data, err := encodeTicket(ticket)
if err != nil {
return fmt.Errorf("failed to encode ticket to update: %w", err)
}
queries[i] = s.client.B().Set().
Key(redisKeyTicketData(ticket.Id)).
Value(rueidis.BinaryString(data)).
Ex(s.opts.assignedDeleteTimeout).
Build()
}
for _, resp := range s.client.DoMulti(ctx, queries...) {
if err := resp.Error(); err != nil {
return fmt.Errorf("failed to update assigned tickets: %w", err)
}
}
return nil
}

func (s *RedisStore) setTicketsExpiration(ctx context.Context, ticketIDs []string, expiration time.Duration) error {
queries := make([]rueidis.Completed, len(ticketIDs))
for i, ticketID := range ticketIDs {
Expand Down
37 changes: 37 additions & 0 deletions pkg/statestore/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package statestore
import (
"context"
"fmt"
"sync"
"testing"
"time"

"github.com/alicebob/miniredis/v2"
"github.com/redis/rueidis"
"github.com/redis/rueidis/rueidislock"
"github.com/rs/xid"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"open-match.dev/open-match/pkg/pb"
)

Expand Down Expand Up @@ -168,3 +171,37 @@ func ticketIDs(tickets []*pb.Ticket) []string {
}
return ids
}

func TestConcurrentFetchActiveTickets(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
mr := miniredis.RunT(t)
store := newTestRedisStore(t, mr.Addr())

for i := 0; i < 1000; i++ {
require.NoError(t, store.CreateTicket(ctx, &pb.Ticket{Id: xid.New().String()}))
}

eg, _ := errgroup.WithContext(ctx)
var mu sync.Mutex
ticketIDs := map[string]struct{}{}
for i := 0; i < 1000; i++ {
eg.Go(func() error {
tickets, err := store.GetActiveTickets(ctx, 1000)
if err != nil {
return err
}
for _, ticket := range tickets {
mu.Lock()
if _, ok := ticketIDs[ticket.Id]; ok {
mu.Unlock()
return fmt.Errorf("duplicated! ticket id: %s", ticket.Id)
}
ticketIDs[ticket.Id] = struct{}{}
mu.Unlock()
}
return nil
})
}
require.NoError(t, eg.Wait())
}
1 change: 1 addition & 0 deletions tests/intergration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func TestFrontend(t *testing.T) {
ctx := context.Background()

resp, err := c.GetTicket(ctx, &pb.GetTicketRequest{TicketId: "invalid"})
require.Error(t, err)
requireErrorCode(t, err, codes.NotFound)

t1 := mustCreateTicket(ctx, t, c, &pb.Ticket{})
Expand Down

0 comments on commit 62548a6

Please sign in to comment.