diff --git a/cmd/cli.go b/cmd/cli.go index 3281dc1a..d9bc09b6 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -65,11 +65,14 @@ var supportedCommands = map[string]cmdHandler{ "sunion": sUnion, // zset commands - "zadd": zAdd, - "zscore": zScore, - "zrem": zRem, - "zcard": zCard, - "zrange": zRange, + "zadd": zAdd, + "zscore": zScore, + "zrem": zRem, + "zcard": zCard, + "zrange": zRange, + "zrevrange": zRevRange, + "zrank": zRank, + "zrevrank": zRevRank, // generic commands "type": keyType, diff --git a/cmd/command.go b/cmd/command.go index 67ee6754..04f3c540 100644 --- a/cmd/command.go +++ b/cmd/command.go @@ -738,3 +738,40 @@ func zRange(cli *Client, args [][]byte) (interface{}, error) { } return cli.db.ZRange(args[0], start, stop) } + +func zRevRange(cli *Client, args [][]byte) (interface{}, error) { + if len(args) != 3 { + return nil, newWrongNumOfArgsError("zrevrange") + } + start, err := strconv.Atoi(string(args[1])) + if err != nil { + return nil, err + } + stop, err := strconv.Atoi(string(args[2])) + if err != nil { + return nil, err + } + return cli.db.ZRevRange(args[0], start, stop) +} + +func zRank(cli *Client, args [][]byte) (interface{}, error) { + if len(args) != 2 { + return nil, newWrongNumOfArgsError("zrank") + } + ok, rank := cli.db.ZRank(args[0], args[1]) + if !ok { + return nil, nil + } + return rank, nil +} + +func zRevRank(cli *Client, args [][]byte) (interface{}, error) { + if len(args) != 2 { + return nil, newWrongNumOfArgsError("zrevrank") + } + ok, rank := cli.db.ZRevRank(args[0], args[1]) + if !ok { + return nil, nil + } + return rank, nil +} diff --git a/index.go b/index.go index e8f7cd6e..bc080f78 100644 --- a/index.go +++ b/index.go @@ -133,27 +133,28 @@ func (db *RoseDB) buildSetsIndex(ent *logfile.LogEntry, pos *valuePos) { } func (db *RoseDB) buildZSetIndex(ent *logfile.LogEntry, pos *valuePos) { - idxTree := db.zsetIndex.trees[string(ent.Key)] if ent.Type == logfile.TypeDelete { db.zsetIndex.indexes.ZRem(string(ent.Key), string(ent.Value)) - if idxTree != nil { - idxTree.Delete(ent.Value) + if db.zsetIndex.trees[string(ent.Key)] != nil { + db.zsetIndex.trees[string(ent.Key)].Delete(ent.Value) } return } key, scoreBuf := db.decodeKey(ent.Key) score, _ := util.StrToFloat64(string(scoreBuf)) - - if idxTree == nil { - idxTree = art.NewART() - } if err := db.zsetIndex.murhash.Write(ent.Value); err != nil { logger.Fatalf("fail to write murmur hash: %v", err) } sum := db.zsetIndex.murhash.EncodeSum128() db.zsetIndex.murhash.Reset() + idxTree := db.zsetIndex.trees[string(key)] + if idxTree == nil { + idxTree = art.NewART() + db.zsetIndex.trees[string(key)] = idxTree + } + _, size := logfile.EncodeEntry(ent) idxNode := &indexNode{fid: pos.fid, offset: pos.offset, entrySize: size} if db.opts.IndexMode == KeyValueMemMode { diff --git a/zset.go b/zset.go index 414e2111..32b019c7 100644 --- a/zset.go +++ b/zset.go @@ -103,6 +103,28 @@ func (db *RoseDB) ZCard(key []byte) int { // ZRange returns the specified range of elements in the sorted set stored at key. func (db *RoseDB) ZRange(key []byte, start, stop int) ([][]byte, error) { + return db.zRangeInternal(key, start, stop, false) +} + +// ZRevRange returns the specified range of elements in the sorted set stored at key. +// The elements are considered to be ordered from the highest to the lowest score. +func (db *RoseDB) ZRevRange(key []byte, start, stop int) ([][]byte, error) { + return db.zRangeInternal(key, start, stop, true) +} + +// ZRank returns the rank of member in the sorted set stored at key, with the scores ordered from low to high. +// The rank (or index) is 0-based, which means that the member with the lowest score has rank 0. +func (db *RoseDB) ZRank(key []byte, member []byte) (ok bool, rank int) { + return db.zRankInternal(key, member, false) +} + +// ZRevRank returns the rank of member in the sorted set stored at key, with the scores ordered from high to low. +// The rank (or index) is 0-based, which means that the member with the highest score has rank 0. +func (db *RoseDB) ZRevRank(key []byte, member []byte) (ok bool, rank int) { + return db.zRankInternal(key, member, true) +} + +func (db *RoseDB) zRangeInternal(key []byte, start, stop int, rev bool) ([][]byte, error) { db.zsetIndex.mu.RLock() defer db.zsetIndex.mu.RUnlock() if db.zsetIndex.trees[string(key)] == nil { @@ -111,7 +133,12 @@ func (db *RoseDB) ZRange(key []byte, start, stop int) ([][]byte, error) { idxTree := db.zsetIndex.trees[string(key)] var res [][]byte - values := db.zsetIndex.indexes.ZRange(string(key), start, stop) + var values []interface{} + if rev { + values = db.zsetIndex.indexes.ZRevRange(string(key), start, stop) + } else { + values = db.zsetIndex.indexes.ZRange(string(key), start, stop) + } for _, val := range values { v, _ := val.(string) if val, err := db.getVal(idxTree, []byte(v), ZSet); err != nil { @@ -122,3 +149,29 @@ func (db *RoseDB) ZRange(key []byte, start, stop int) ([][]byte, error) { } return res, nil } + +func (db *RoseDB) zRankInternal(key []byte, member []byte, rev bool) (ok bool, rank int) { + db.zsetIndex.mu.RLock() + defer db.zsetIndex.mu.RUnlock() + if db.zsetIndex.trees[string(key)] == nil { + return + } + + if err := db.zsetIndex.murhash.Write(member); err != nil { + return + } + sum := db.zsetIndex.murhash.EncodeSum128() + db.zsetIndex.murhash.Reset() + + var result int64 + if rev { + result = db.zsetIndex.indexes.ZRevRank(string(key), string(sum)) + } else { + result = db.zsetIndex.indexes.ZRank(string(key), string(sum)) + } + if result != -1 { + ok = true + rank = int(result) + } + return +} diff --git a/zset_test.go b/zset_test.go index 4b83a5b0..e7911787 100644 --- a/zset_test.go +++ b/zset_test.go @@ -188,6 +188,75 @@ func testRoseDBZRange(t *testing.T, ioType IOType, mode DataIndexMode) { assert.Equal(t, 4, len(values)) } +func TestRoseDB_ZRevRange(t *testing.T) { + t.Run("fileio", func(t *testing.T) { + testRoseDBZRevRange(t, FileIO, KeyValueMemMode) + }) + t.Run("mmap", func(t *testing.T) { + testRoseDBZRevRange(t, MMap, KeyOnlyMemMode) + }) +} + +func testRoseDBZRevRange(t *testing.T, ioType IOType, mode DataIndexMode) { + path := filepath.Join("/tmp", "rosedb") + opts := DefaultOptions(path) + opts.IoType = ioType + opts.IndexMode = mode + db, err := Open(opts) + assert.Nil(t, err) + defer destroyDB(db) + + zsetKey := []byte("my_zset") + for i := 0; i < 100; i++ { + err := db.ZAdd(zsetKey, float64(i+100), GetKey(i)) + assert.Nil(t, err) + } + + ok, score := db.ZScore(zsetKey, GetKey(3)) + assert.True(t, ok) + assert.Equal(t, float64(103), score) + + values, err := db.ZRevRange(zsetKey, 1, 10) + assert.Nil(t, err) + assert.Equal(t, 10, len(values)) +} + +func TestRoseDB_ZRank(t *testing.T) { + t.Run("fileio", func(t *testing.T) { + testRoseDBZRank(t, FileIO, KeyValueMemMode) + }) + t.Run("mmap", func(t *testing.T) { + testRoseDBZRank(t, MMap, KeyOnlyMemMode) + }) +} + +func testRoseDBZRank(t *testing.T, ioType IOType, mode DataIndexMode) { + path := filepath.Join("/tmp", "rosedb") + opts := DefaultOptions(path) + opts.IoType = ioType + opts.IndexMode = mode + db, err := Open(opts) + assert.Nil(t, err) + defer destroyDB(db) + + zsetKey := []byte("my_zset") + for i := 0; i < 100; i++ { + err := db.ZAdd(zsetKey, float64(i+100), GetKey(i)) + assert.Nil(t, err) + } + + ok, r1 := db.ZRank(zsetKey, GetKey(-1)) + assert.False(t, ok) + assert.Equal(t, 0, r1) + + ok, r2 := db.ZRank(zsetKey, GetKey(3)) + assert.True(t, ok) + assert.Equal(t, 3, r2) + ok, r3 := db.ZRevRank(zsetKey, GetKey(1)) + assert.True(t, ok) + assert.Equal(t, 98, r3) +} + func TestRoseDB_ZSetGC(t *testing.T) { path := filepath.Join("/tmp", "rosedb") opts := DefaultOptions(path)