diff --git a/pkg/statestore/redis.go b/pkg/statestore/redis.go index 3b83971..cee5a6a 100644 --- a/pkg/statestore/redis.go +++ b/pkg/statestore/redis.go @@ -15,12 +15,9 @@ import ( ) const ( - DefaultTicketTTL = 10 * time.Minute - DefaultPendingReleaseTimeout = 1 * time.Minute - DefaultAssignedDeleteTimeout = 1 * time.Minute - redisKeyTicketIndex = "allTickets" - redisKeyPendingTicketIndex = "proposed_ticket_ids" - redisKeyFetchTicketsLock = "fetchTicketsLock" + defaultTicketTTL = 10 * time.Minute + defaultPendingReleaseTimeout = 1 * time.Minute + defaultAssignedDeleteTimeout = 1 * time.Minute ) type RedisStore struct { @@ -35,14 +32,16 @@ type redisOpts struct { assignedDeleteTimeout time.Duration // Optional: Assignment is stored in a separate keyspace to distribute the load. assignmentSpaceClient rueidis.Client + keyPrefix string } func defaultRedisOpts() *redisOpts { return &redisOpts{ - ticketTTL: DefaultTicketTTL, - pendingReleaseTimeout: DefaultPendingReleaseTimeout, - assignedDeleteTimeout: DefaultAssignedDeleteTimeout, + ticketTTL: defaultTicketTTL, + pendingReleaseTimeout: defaultPendingReleaseTimeout, + assignedDeleteTimeout: defaultAssignedDeleteTimeout, assignmentSpaceClient: nil, + keyPrefix: "", } } @@ -80,6 +79,12 @@ func WithSeparatedAssignmentRedis(client rueidis.Client) RedisOption { }) } +func WithRedisKeyPrefix(prefix string) RedisOption { + return RedisOptionFunc(func(opts *redisOpts) { + opts.keyPrefix = prefix + }) +} + func NewRedisStore(client rueidis.Client, locker rueidislock.Locker, opts ...RedisOption) *RedisStore { ro := defaultRedisOpts() for _, o := range opts { @@ -98,8 +103,15 @@ func (s *RedisStore) CreateTicket(ctx context.Context, ticket *pb.Ticket) error return err } queries := []rueidis.Completed{ - s.client.B().Set().Key(redisKeyTicketData(ticket.Id)).Value(rueidis.BinaryString(data)).Ex(s.opts.ticketTTL).Build(), - s.client.B().Sadd().Key(redisKeyTicketIndex).Member(ticket.Id).Build(), + s.client.B().Set(). + Key(redisKeyTicketData(s.opts.keyPrefix, ticket.Id)). + Value(rueidis.BinaryString(data)). + Ex(s.opts.ticketTTL). + Build(), + s.client.B().Sadd(). + Key(redisKeyTicketIndex(s.opts.keyPrefix)). + Member(ticket.Id). + Build(), } for _, resp := range s.client.DoMulti(ctx, queries...) { if err := resp.Error(); err != nil { @@ -111,8 +123,8 @@ func (s *RedisStore) CreateTicket(ctx context.Context, ticket *pb.Ticket) error func (s *RedisStore) DeleteTicket(ctx context.Context, ticketID string) error { queries := []rueidis.Completed{ - s.client.B().Del().Key(redisKeyTicketData(ticketID)).Build(), - s.client.B().Srem().Key(redisKeyTicketIndex).Member(ticketID).Build(), + s.client.B().Del().Key(redisKeyTicketData(s.opts.keyPrefix, ticketID)).Build(), + s.client.B().Srem().Key(redisKeyTicketIndex(s.opts.keyPrefix)).Member(ticketID).Build(), } for _, resp := range s.client.DoMulti(ctx, queries...) { if err := resp.Error(); err != nil { @@ -140,7 +152,7 @@ func (s *RedisStore) GetAssignment(ctx context.Context, ticketID string) (*pb.As func (s *RedisStore) GetActiveTicketIDs(ctx context.Context, limit int64) ([]string, error) { // Acquire a lock to prevent multiple backends from fetching the same Ticket - lockedCtx, unlock, err := s.locker.WithContext(ctx, redisKeyFetchTicketsLock) + lockedCtx, unlock, err := s.locker.WithContext(ctx, redisKeyFetchTicketsLock(s.opts.keyPrefix)) if err != nil { return nil, fmt.Errorf("failed to acquire fetch tickets lock: %w", err) } @@ -168,7 +180,7 @@ func (s *RedisStore) GetActiveTicketIDs(ctx context.Context, limit int64) ([]str } func (s *RedisStore) getAllTicketIDs(ctx context.Context, limit int64) ([]string, error) { - resp := s.client.Do(ctx, s.client.B().Srandmember().Key(redisKeyTicketIndex).Count(limit).Build()) + resp := s.client.Do(ctx, s.client.B().Srandmember().Key(redisKeyTicketIndex(s.opts.keyPrefix)).Count(limit).Build()) if err := resp.Error(); err != nil { if rueidis.IsRedisNil(err) { return nil, nil @@ -185,7 +197,7 @@ func (s *RedisStore) getAllTicketIDs(ctx context.Context, limit int64) ([]string func (s *RedisStore) getPendingTicketIDs(ctx context.Context) ([]string, error) { rangeMin := strconv.FormatInt(time.Now().Add(-s.opts.pendingReleaseTimeout).Unix(), 10) rangeMax := strconv.FormatInt(time.Now().Add(1*time.Hour).Unix(), 10) - resp := s.client.Do(ctx, s.client.B().Zrangebyscore().Key(redisKeyPendingTicketIndex).Min(rangeMin).Max(rangeMax).Build()) + resp := s.client.Do(ctx, s.client.B().Zrangebyscore().Key(redisKeyPendingTicketIndex(s.opts.keyPrefix)).Min(rangeMin).Max(rangeMax).Build()) if err := resp.Error(); err != nil { if rueidis.IsRedisNil(err) { return nil, nil @@ -200,7 +212,7 @@ func (s *RedisStore) getPendingTicketIDs(ctx context.Context) ([]string, error) } func (s *RedisStore) setTicketsToPending(ctx context.Context, ticketIDs []string) error { - query := s.client.B().Zadd().Key(redisKeyPendingTicketIndex).ScoreMember() + query := s.client.B().Zadd().Key(redisKeyPendingTicketIndex(s.opts.keyPrefix)).ScoreMember() score := float64(time.Now().Unix()) for _, ticketID := range ticketIDs { query = query.ScoreMember(score, ticketID) @@ -213,7 +225,7 @@ func (s *RedisStore) setTicketsToPending(ctx context.Context, ticketIDs []string } func (s *RedisStore) ReleaseTickets(ctx context.Context, ticketIDs []string) error { - resp := s.client.Do(ctx, s.client.B().Zrem().Key(redisKeyPendingTicketIndex).Member(ticketIDs...).Build()) + resp := s.client.Do(ctx, s.client.B().Zrem().Key(redisKeyPendingTicketIndex(s.opts.keyPrefix)).Member(ticketIDs...).Build()) if err := resp.Error(); err != nil { return fmt.Errorf("failed to release tickets: %w", err) } @@ -251,7 +263,7 @@ func (s *RedisStore) AssignTickets(ctx context.Context, asgs []*pb.AssignmentGro } func (s *RedisStore) getTicket(ctx context.Context, ticketID string) (*pb.Ticket, error) { - resp := s.client.Do(ctx, s.client.B().Get().Key(redisKeyTicketData(ticketID)).Build()) + resp := s.client.Do(ctx, s.client.B().Get().Key(redisKeyTicketData(s.opts.keyPrefix, ticketID)).Build()) if err := resp.Error(); err != nil { if rueidis.IsRedisNil(err) { return nil, ErrTicketNotFound @@ -270,7 +282,7 @@ func (s *RedisStore) getTicket(ctx context.Context, ticketID string) (*pb.Ticket } func (s *RedisStore) getAssignment(ctx context.Context, redis rueidis.Client, ticketID string) (*pb.Assignment, error) { - resp := redis.Do(ctx, s.client.B().Get().Key(redisKeyAssignmentData(ticketID)).Build()) + resp := redis.Do(ctx, s.client.B().Get().Key(redisKeyAssignmentData(s.opts.keyPrefix, ticketID)).Build()) if err := resp.Error(); err != nil { if rueidis.IsRedisNil(err) { return nil, ErrAssignmentNotFound @@ -296,7 +308,7 @@ func (s *RedisStore) setAssignmentToTickets(ctx context.Context, redis rueidis.C return fmt.Errorf("failed to encode assignemnt: %w", err) } queries[i] = redis.B().Set(). - Key(redisKeyAssignmentData(ticketID)). + Key(redisKeyAssignmentData(s.opts.keyPrefix, ticketID)). Value(rueidis.BinaryString(data)). Ex(s.opts.assignedDeleteTimeout).Build() } @@ -311,7 +323,7 @@ func (s *RedisStore) setAssignmentToTickets(ctx context.Context, redis rueidis.C func (s *RedisStore) getTickets(ctx context.Context, ticketIDs []string) ([]*pb.Ticket, error) { keys := make([]string, len(ticketIDs)) for i, tid := range ticketIDs { - keys[i] = redisKeyTicketData(tid) + keys[i] = redisKeyTicketData(s.opts.keyPrefix, tid) } mgetMap, err := rueidis.MGet(s.client, ctx, keys) if err != nil { @@ -342,7 +354,7 @@ func (s *RedisStore) getTickets(ctx context.Context, ticketIDs []string) ([]*pb. func (s *RedisStore) setTicketsExpiration(ctx context.Context, ticketIDs []string, expiration time.Duration) error { queries := make([]rueidis.Completed, len(ticketIDs)) for i, ticketID := range ticketIDs { - queries[i] = s.client.B().Expire().Key(redisKeyTicketData(ticketID)).Seconds(int64(expiration.Seconds())).Build() + queries[i] = s.client.B().Expire().Key(redisKeyTicketData(s.opts.keyPrefix, ticketID)).Seconds(int64(expiration.Seconds())).Build() } for _, resp := range s.client.DoMulti(ctx, queries...) { if err := resp.Error(); err != nil { @@ -354,8 +366,8 @@ func (s *RedisStore) setTicketsExpiration(ctx context.Context, ticketIDs []strin func (s *RedisStore) deIndexTickets(ticketIDs []string) []rueidis.Completed { return []rueidis.Completed{ - s.client.B().Zrem().Key(redisKeyPendingTicketIndex).Member(ticketIDs...).Build(), - s.client.B().Srem().Key(redisKeyTicketIndex).Member(ticketIDs...).Build(), + s.client.B().Zrem().Key(redisKeyPendingTicketIndex(s.opts.keyPrefix)).Member(ticketIDs...).Build(), + s.client.B().Srem().Key(redisKeyTicketIndex(s.opts.keyPrefix)).Member(ticketIDs...).Build(), } } @@ -367,7 +379,7 @@ func (s *RedisStore) releaseTimeoutTicketsByNow(ctx context.Context) error { func (s *RedisStore) releaseTimeoutTickets(ctx context.Context, before time.Time) error { rangeMin := "0" rangeMax := strconv.FormatInt(before.Unix(), 10) - query := s.client.B().Zremrangebyscore().Key(redisKeyPendingTicketIndex).Min(rangeMin).Max(rangeMax).Build() + query := s.client.B().Zremrangebyscore().Key(redisKeyPendingTicketIndex(s.opts.keyPrefix)).Min(rangeMin).Max(rangeMax).Build() resp := s.client.Do(ctx, query) if err := resp.Error(); err != nil { return err @@ -412,12 +424,24 @@ func decodeAssignment(b []byte) (*pb.Assignment, error) { return &as, nil } -func redisKeyTicketData(ticketID string) string { - return ticketID +func redisKeyTicketIndex(prefix string) string { + return fmt.Sprintf("%sallTickets", prefix) +} + +func redisKeyPendingTicketIndex(prefix string) string { + return fmt.Sprintf("%sproposed_ticket_ids", prefix) +} + +func redisKeyFetchTicketsLock(prefix string) string { + return fmt.Sprintf("%sfetchTicketsLock", prefix) +} + +func redisKeyTicketData(prefix, ticketID string) string { + return fmt.Sprintf("%s%s", prefix, ticketID) } -func redisKeyAssignmentData(ticketID string) string { - return fmt.Sprintf("assign:%s", ticketID) +func redisKeyAssignmentData(prefix, ticketID string) string { + return fmt.Sprintf("%sassign:%s", prefix, ticketID) } // difference returns the elements in `a` that aren't in `b`. diff --git a/pkg/statestore/redis_test.go b/pkg/statestore/redis_test.go index 4660d59..25dac85 100644 --- a/pkg/statestore/redis_test.go +++ b/pkg/statestore/redis_test.go @@ -132,7 +132,7 @@ func TestAssignedDeleteTimeout(t *testing.T) { } // assigned delete timeout - mr.FastForward(DefaultAssignedDeleteTimeout + 1*time.Second) + mr.FastForward(defaultAssignedDeleteTimeout + 1*time.Second) _, err = store.GetTicket(ctx, "test1") require.Error(t, err, ErrTicketNotFound)