From 75030407cb8ff7a88a0f5b51326ef0bab2e0c775 Mon Sep 17 00:00:00 2001 From: lhpqaq <657407891@qq.com> Date: Tue, 23 Jul 2024 21:48:31 +0800 Subject: [PATCH] Add sscan, hscan, zscan --- database/hash.go | 52 ++++++++++++++++++++++++ database/hash_test.go | 45 +++++++++++++++++++++ database/keys.go | 2 +- database/set.go | 50 +++++++++++++++++++++++ database/set_test.go | 37 +++++++++++++++++ database/sortedset.go | 56 ++++++++++++++++++++++++-- database/sortedset_test.go | 52 ++++++++++++++++++++++-- datastruct/dict/simple.go | 21 +++++++++- datastruct/dict/simple_test.go | 27 +++++++++++++ datastruct/set/set.go | 18 +++++++++ datastruct/set/set_test.go | 28 +++++++++++++ datastruct/sortedset/sortedset.go | 21 ++++++++++ datastruct/sortedset/sortedset_test.go | 33 ++++++++++++++- 13 files changed, 432 insertions(+), 10 deletions(-) diff --git a/database/hash.go b/database/hash.go index d2f08fe3..777c4d7f 100644 --- a/database/hash.go +++ b/database/hash.go @@ -496,6 +496,56 @@ func execHRandField(db *DB, args [][]byte) redis.Reply { return &protocol.EmptyMultiBulkReply{} } +func execHScan(db *DB, args [][]byte) redis.Reply { + var count int = 10 + var pattern string = "*" + if len(args) > 2 { + for i := 2; i < len(args); i++ { + arg := strings.ToLower(string(args[i])) + if arg == "count" { + count0, err := strconv.Atoi(string(args[i+1])) + if err != nil { + return &protocol.SyntaxErrReply{} + } + count = count0 + i++ + } else if arg == "match" { + pattern = string(args[i+1]) + i++ + } else { + return &protocol.SyntaxErrReply{} + } + } + } + if len(args) < 2 { + return &protocol.SyntaxErrReply{} + } + key := string(args[0]) + // get entity + dict, errReply := db.getAsDict(key) + if errReply != nil { + return errReply + } + if dict == nil { + return &protocol.NullBulkReply{} + } + cursor, err := strconv.Atoi(string(args[1])) + if err != nil { + return protocol.MakeErrReply("ERR invalid cursor") + } + + keysReply, nextCursor := dict.DictScan(cursor, count, pattern) + if nextCursor < 0 { + return protocol.MakeErrReply("Invalid argument") + } + + result := make([]redis.Reply, 2) + result[0] = protocol.MakeBulkReply([]byte(strconv.FormatInt(int64(nextCursor), 10))) + result[1] = protocol.MakeMultiBulkReply(keysReply) + + return protocol.MakeMultiRawReply(result) +} + func init() { registerCommand("HSet", execHSet, writeFirstKey, undoHSet, 4, flagWrite). attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM, redisFlagFast}, 1, 1, 1) @@ -529,4 +579,6 @@ func init() { attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM, redisFlagFast}, 1, 1, 1) registerCommand("HRandField", execHRandField, readFirstKey, nil, -2, flagReadOnly). attachCommandExtra([]string{redisFlagRandom, redisFlagReadonly}, 1, 1, 1) + registerCommand("HScan", execHScan, readFirstKey, nil, -2, flagReadOnly). + attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1) } diff --git a/database/hash_test.go b/database/hash_test.go index 8c73bbcc..7a5b2ea7 100644 --- a/database/hash_test.go +++ b/database/hash_test.go @@ -346,3 +346,48 @@ func TestUndoHIncr(t *testing.T) { result := testDB.Exec(nil, utils.ToCmdLine("hget", key, field)) asserts.AssertBulkReply(t, result, "1") } + +func TestHScan(t *testing.T) { + testDB.Flush() + hashKey := "test:hash" + for i := 0; i < 3; i++ { + key := string(rune(i)) + value := key + testDB.Exec(nil, utils.ToCmdLine("hset", hashKey, "a"+key, value)) + } + for i := 0; i < 3; i++ { + key := string(rune(i)) + value := key + testDB.Exec(nil, utils.ToCmdLine("hset", hashKey, "b"+key, value)) + } + + result := testDB.Exec(nil, utils.ToCmdLine("hscan", hashKey, "0", "count", "10")) + cursorStr := string(result.(*protocol.MultiRawReply).Replies[0].(*protocol.BulkReply).Arg) + cursor, err := strconv.Atoi(cursorStr) + if err == nil { + if cursor != 0 { + t.Errorf("expect cursor 0, actually %d", cursor) + return + } + } else { + t.Errorf("get scan result error") + return + } + + // test hscan 0 match a* + result = testDB.Exec(nil, utils.ToCmdLine("hscan", hashKey, "0", "match", "a*")) + returnKeys := result.(*protocol.MultiRawReply).Replies[1].(*protocol.MultiBulkReply).Args + i := 0 + for i < len(returnKeys) { + if i%2 != 0 { + i++ + continue // pass value + } + key := string(returnKeys[i]) + i++ + if key[0] != 'a' { + t.Errorf("The key %s should match a*", key) + return + } + } +} diff --git a/database/keys.go b/database/keys.go index 12dd2259..c9e3dad5 100644 --- a/database/keys.go +++ b/database/keys.go @@ -505,6 +505,6 @@ func init() { attachCommandExtra([]string{redisFlagWrite, redisFlagFast}, 1, 1, 1) registerCommand("Keys", execKeys, noPrepare, nil, 2, flagReadOnly). attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1) - registerCommand("Scan", execScan, readAllKeys, nil, -2, flagReadOnly). + registerCommand("Scan", execScan, noPrepare, nil, -2, flagReadOnly). attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1) } diff --git a/database/set.go b/database/set.go index b283a0d6..5246fc06 100644 --- a/database/set.go +++ b/database/set.go @@ -7,6 +7,7 @@ import ( "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/protocol" "strconv" + "strings" ) func (db *DB) getAsSet(key string) (*HashSet.Set, protocol.ErrorReply) { @@ -354,6 +355,53 @@ func execSRandMember(db *DB, args [][]byte) redis.Reply { return &protocol.EmptyMultiBulkReply{} } +func execSScan(db *DB, args [][]byte) redis.Reply { + var count int = 10 + var pattern string = "*" + if len(args) > 2 { + for i := 2; i < len(args); i++ { + arg := strings.ToLower(string(args[i])) + if arg == "count" { + count0, err := strconv.Atoi(string(args[i+1])) + if err != nil { + return &protocol.SyntaxErrReply{} + } + count = count0 + i++ + } else if arg == "match" { + pattern = string(args[i+1]) + i++ + } else { + return &protocol.SyntaxErrReply{} + } + } + } + key := string(args[0]) + // get entity + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + return &protocol.EmptyMultiBulkReply{} + } + cursor, err := strconv.Atoi(string(args[1])) + if err != nil { + return protocol.MakeErrReply("ERR invalid cursor") + } + + keysReply, nextCursor := set.SetScan(cursor, count, pattern) + if nextCursor < 0 { + return protocol.MakeErrReply("Invalid argument") + } + + result := make([]redis.Reply, 2) + result[0] = protocol.MakeBulkReply([]byte(strconv.FormatInt(int64(nextCursor), 10))) + result[1] = protocol.MakeMultiBulkReply(keysReply) + + return protocol.MakeMultiRawReply(result) +} + func init() { registerCommand("SAdd", execSAdd, writeFirstKey, undoSetChange, -3, flagWrite). attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM, redisFlagFast}, 1, 1, 1) @@ -381,4 +429,6 @@ func init() { attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM}, 1, 1, 1) registerCommand("SRandMember", execSRandMember, readFirstKey, nil, -2, flagReadOnly). attachCommandExtra([]string{redisFlagReadonly, redisFlagRandom}, 1, 1, 1) + registerCommand("SScan", execSScan, readFirstKey, nil, -2, flagReadOnly). + attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1) } diff --git a/database/set_test.go b/database/set_test.go index 8ccc4f61..b8fc062c 100644 --- a/database/set_test.go +++ b/database/set_test.go @@ -248,3 +248,40 @@ func TestSRandMember(t *testing.T) { result = testDB.Exec(nil, utils.ToCmdLine("SRandMember", key, "-110")) asserts.AssertMultiBulkReplySize(t, result, 110) } + +func TestSScan(t *testing.T) { + testDB.Flush() + setKey := "test:set" + for i := 0; i < 3; i++ { + key := string(rune(i)) + testDB.Exec(nil, utils.ToCmdLine("sadd", setKey, "a"+key)) + } + for i := 0; i < 3; i++ { + key := string(rune(i)) + testDB.Exec(nil, utils.ToCmdLine("sadd", setKey, "b"+key)) + } + + result := testDB.Exec(nil, utils.ToCmdLine("sscan", setKey, "0", "count", "10")) + cursorStr := string(result.(*protocol.MultiRawReply).Replies[0].(*protocol.BulkReply).Arg) + cursor, err := strconv.Atoi(cursorStr) + if err == nil { + if cursor != 0 { + t.Errorf("expect cursor 0, actually %d", cursor) + return + } + } else { + t.Errorf("get scan result error") + return + } + + // test sscan 0 match a* + result = testDB.Exec(nil, utils.ToCmdLine("sscan", setKey, "0", "match", "a*")) + returnKeys := result.(*protocol.MultiRawReply).Replies[1].(*protocol.MultiBulkReply).Args + for i := range returnKeys { + key := string(returnKeys[i]) + if key[0] != 'a' { + t.Errorf("The key %s should match a*", key) + return + } + } +} diff --git a/database/sortedset.go b/database/sortedset.go index 9602c1a3..ab16aed1 100644 --- a/database/sortedset.go +++ b/database/sortedset.go @@ -1,15 +1,14 @@ package database import ( - "math" - "strconv" - "strings" - SortedSet "github.com/hdt3213/godis/datastruct/sortedset" "github.com/hdt3213/godis/interface/database" "github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/protocol" + "math" + "strconv" + "strings" ) func (db *DB) getAsSortedSet(key string) (*SortedSet.SortedSet, protocol.ErrorReply) { @@ -796,6 +795,53 @@ func execZRevRangeByLex(db *DB, args [][]byte) redis.Reply { return protocol.MakeMultiBulkReply(result) } +func execZScan(db *DB, args [][]byte) redis.Reply { + var count int = 10 + var pattern string = "*" + if len(args) > 2 { + for i := 2; i < len(args); i++ { + arg := strings.ToLower(string(args[i])) + if arg == "count" { + count0, err := strconv.Atoi(string(args[i+1])) + if err != nil { + return &protocol.SyntaxErrReply{} + } + count = count0 + i++ + } else if arg == "match" { + pattern = string(args[i+1]) + i++ + } else { + return &protocol.SyntaxErrReply{} + } + } + } + key := string(args[0]) + // get entity + set, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if set == nil { + return &protocol.EmptyMultiBulkReply{} + } + cursor, err := strconv.Atoi(string(args[1])) + if err != nil { + return protocol.MakeErrReply("ERR invalid cursor") + } + + keysReply, nextCursor := set.ZSetScan(cursor, count, pattern) + if nextCursor < 0 { + return protocol.MakeErrReply("Invalid argument") + } + + result := make([]redis.Reply, 2) + result[0] = protocol.MakeBulkReply([]byte(strconv.FormatInt(int64(nextCursor), 10))) + result[1] = protocol.MakeMultiBulkReply(keysReply) + + return protocol.MakeMultiRawReply(result) +} + func init() { registerCommand("ZAdd", execZAdd, writeFirstKey, undoZAdd, -4, flagWrite). attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM, redisFlagFast}, 1, 1, 1) @@ -835,4 +881,6 @@ func init() { attachCommandExtra([]string{redisFlagWrite}, 1, 1, 1) registerCommand("ZRevRangeByLex", execZRevRangeByLex, readFirstKey, nil, -4, flagReadOnly). attachCommandExtra([]string{redisFlagReadonly}, 1, 1, 1) + registerCommand("ZScan", execZScan, readFirstKey, nil, -2, flagReadOnly). + attachCommandExtra([]string{redisFlagReadonly}, 1, 1, 1) } diff --git a/database/sortedset_test.go b/database/sortedset_test.go index 657b93fe..ad3b609f 100644 --- a/database/sortedset_test.go +++ b/database/sortedset_test.go @@ -1,12 +1,12 @@ package database import ( + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/protocol" + "github.com/hdt3213/godis/redis/protocol/asserts" "math/rand" "strconv" "testing" - - "github.com/hdt3213/godis/lib/utils" - "github.com/hdt3213/godis/redis/protocol/asserts" ) func TestZAdd(t *testing.T) { @@ -762,3 +762,49 @@ func TestZRevRangeByLex(t *testing.T) { result30 := testDB.Exec(nil, utils.ToCmdLine("ZRevRangeByLex", key, "+", "-", "limit", "2", "2")) asserts.AssertMultiBulkReply(t, result30, []string{"c", "b"}) } + +func TestZScan(t *testing.T) { + testDB.Flush() + zsetKey := "zsetkey" + expectKeyScore := make(map[string]float64) + for i := 0; i < 3; i++ { + key := string(rune(i)) + expectKeyScore[key] = float64(i) + testDB.Exec(nil, utils.ToCmdLine("zadd", zsetKey, strconv.FormatInt(int64(i), 10), "a"+key)) + } + for i := 0; i < 3; i++ { + key := string(rune(i)) + expectKeyScore[key] = float64(i + 3) + testDB.Exec(nil, utils.ToCmdLine("zadd", zsetKey, strconv.FormatInt(int64(i), 10), "b"+key)) + } + + result := testDB.Exec(nil, utils.ToCmdLine("zscan", zsetKey, "0", "count", "10")) + cursorStr := string(result.(*protocol.MultiRawReply).Replies[0].(*protocol.BulkReply).Arg) + cursor, err := strconv.Atoi(cursorStr) + if err == nil { + if cursor != 0 { + t.Errorf("expect cursor 0, actually %d", cursor) + return + } + } else { + t.Errorf("get scan result error") + return + } + + // test zscan 0 match a* + result = testDB.Exec(nil, utils.ToCmdLine("zscan", zsetKey, "0", "match", "a*")) + returnKeys := result.(*protocol.MultiRawReply).Replies[1].(*protocol.MultiBulkReply).Args + i := 0 + for i < len(returnKeys) { + if i%2 != 0 { + i++ + continue // pass score + } + key := string(returnKeys[i]) + i++ + if key[0] != 'a' { + t.Errorf("The key %s should match a*", key) + return + } + } +} diff --git a/datastruct/dict/simple.go b/datastruct/dict/simple.go index 187a708d..74282c34 100644 --- a/datastruct/dict/simple.go +++ b/datastruct/dict/simple.go @@ -1,5 +1,9 @@ package dict +import ( + "github.com/hdt3213/godis/lib/wildcard" +) + // SimpleDict wraps a map, it is not thread safe type SimpleDict struct { m map[string]interface{} @@ -122,5 +126,20 @@ func (dict *SimpleDict) Clear() { } func (dict *SimpleDict) DictScan(cursor int, count int, pattern string) ([][]byte, int) { - return stringsToBytes(dict.Keys()), 0 + result := make([][]byte, 0) + matchKey, err := wildcard.CompilePattern(pattern) + if err != nil { + return result, -1 + } + for k := range dict.m { + if pattern == "*" || matchKey.IsMatch(k) { + raw, exists := dict.Get(k) + if !exists { + continue + } + result = append(result, []byte(k)) + result = append(result, raw.([]byte)) + } + } + return result, 0 } diff --git a/datastruct/dict/simple_test.go b/datastruct/dict/simple_test.go index 74f4a4ab..98ee8d5a 100644 --- a/datastruct/dict/simple_test.go +++ b/datastruct/dict/simple_test.go @@ -53,3 +53,30 @@ func TestSimpleDict_PutIfExists(t *testing.T) { return } } + +func TestSimpleDict_Scan(t *testing.T) { + d := MakeSimple() + size := 10 + for i := 0; i < size; i++ { + str := "a" + utils.RandString(5) + d.Put(str, []byte(str)) + } + keys, nextCursor := d.DictScan(0, size, "*") + if len(keys) != size*2 { + t.Errorf("expect %d keys, actual: %d", size*2, len(keys)) + return + } + if nextCursor != 0 { + t.Errorf("expect 0, actual: %d", nextCursor) + return + } + for i := 0; i < size; i++ { + str := "b" + utils.RandString(5) + d.Put(str, str) + } + keys, _ = d.DictScan(0, size*2, "a*") + if len(keys) != size*2 { + t.Errorf("expect %d keys, actual: %d", size*2, len(keys)) + return + } +} diff --git a/datastruct/set/set.go b/datastruct/set/set.go index 11236e62..af3fa0a8 100644 --- a/datastruct/set/set.go +++ b/datastruct/set/set.go @@ -2,6 +2,7 @@ package set import ( "github.com/hdt3213/godis/datastruct/dict" + "github.com/hdt3213/godis/lib/wildcard" ) // Set is a set of elements based on hash table @@ -149,3 +150,20 @@ func (set *Set) RandomMembers(limit int) []string { func (set *Set) RandomDistinctMembers(limit int) []string { return set.dict.RandomDistinctKeys(limit) } + +// Scan set with cursor and pattern +func (set *Set) SetScan(cursor int, count int, pattern string) ([][]byte, int) { + result := make([][]byte, 0) + matchKey, err := wildcard.CompilePattern(pattern) + if err != nil { + return result, -1 + } + set.ForEach(func(member string) bool { + if pattern == "*" || matchKey.IsMatch(member) { + result = append(result, []byte(member)) + } + return true + }) + + return result, 0 +} diff --git a/datastruct/set/set_test.go b/datastruct/set/set_test.go index fd95aefc..6b50b8e1 100644 --- a/datastruct/set/set_test.go +++ b/datastruct/set/set_test.go @@ -1,6 +1,7 @@ package set import ( + "github.com/hdt3213/godis/lib/utils" "strconv" "testing" ) @@ -30,3 +31,30 @@ func TestSet(t *testing.T) { } } } + +func TestSetScan(t *testing.T) { + set := Make() + size := 10 + for i := 0; i < size; i++ { + str := "a" + utils.RandString(5) + set.Add(str) + } + keys, nextCursor := set.SetScan(0, size, "*") + if len(keys) != size { + t.Errorf("expect %d keys, actual: %d", size, len(keys)) + return + } + if nextCursor != 0 { + t.Errorf("expect 0, actual: %d", nextCursor) + return + } + for i := 0; i < size; i++ { + str := "b" + utils.RandString(5) + set.Add(str) + } + keys, _ = set.SetScan(0, size*2, "a*") + if len(keys) != size { + t.Errorf("expect %d keys, actual: %d", size, len(keys)) + return + } +} diff --git a/datastruct/sortedset/sortedset.go b/datastruct/sortedset/sortedset.go index 52eacf3a..769c519e 100644 --- a/datastruct/sortedset/sortedset.go +++ b/datastruct/sortedset/sortedset.go @@ -2,6 +2,8 @@ package sortedset import ( "strconv" + + "github.com/hdt3213/godis/lib/wildcard" ) // SortedSet is a set which keys sorted by bound score @@ -236,3 +238,22 @@ func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64) int64 { } return int64(len(removed)) } + +func (sortedSet *SortedSet) ZSetScan(cursor int, count int, pattern string) ([][]byte, int) { + result := make([][]byte, 0) + matchKey, err := wildcard.CompilePattern(pattern) + if err != nil { + return result, -1 + } + for k := range sortedSet.dict { + if pattern == "*" || matchKey.IsMatch(k) { + elem, exists := sortedSet.dict[k] + if !exists { + continue + } + result = append(result, []byte(k)) + result = append(result, []byte(strconv.FormatFloat(elem.Score, 'f', 10, 64))) + } + } + return result, 0 +} diff --git a/datastruct/sortedset/sortedset_test.go b/datastruct/sortedset/sortedset_test.go index 10bb68ce..d36550a7 100644 --- a/datastruct/sortedset/sortedset_test.go +++ b/datastruct/sortedset/sortedset_test.go @@ -1,6 +1,10 @@ package sortedset -import "testing" +import ( + "testing" + + "github.com/hdt3213/godis/lib/utils" +) func TestSortedSet_PopMin(t *testing.T) { var set = Make() @@ -14,3 +18,30 @@ func TestSortedSet_PopMin(t *testing.T) { t.Fail() } } + +func TestSetScan(t *testing.T) { + set := Make() + size := 10 + for i := 0; i < size; i++ { + str := "a" + utils.RandString(5) + set.Add(str, float64(i)) + } + keys, nextCursor := set.ZSetScan(0, size, "*") + if len(keys) != size*2 { + t.Errorf("expect %d keys, actual: %d", size*2, len(keys)) + return + } + if nextCursor != 0 { + t.Errorf("expect 0, actual: %d", nextCursor) + return + } + for i := 0; i < size; i++ { + str := "b" + utils.RandString(5) + set.Add(str, float64(i+size)) + } + keys, _ = set.ZSetScan(0, size*2, "a*") + if len(keys) != size*2 { + t.Errorf("expect %d keys, actual: %d", size*2, len(keys)) + return + } +}