From ad480dbc69aba29937eb5c4fda757df599a50b87 Mon Sep 17 00:00:00 2001 From: lhpqaq <657407891@qq.com> Date: Wed, 17 Jul 2024 10:31:49 +0800 Subject: [PATCH] add scan command --- database/keys.go | 83 +++++++++++++++++++++++++---- database/keys_test.go | 82 ++++++++++++++++++++++++++++ datastruct/dict/concurrent.go | 45 ++++++++++++++++ datastruct/dict/concurrent_test.go | 50 ++++++++++++++++- datastruct/dict/dict.go | 1 + datastruct/dict/simple.go | 4 ++ lib/utils/utils.go | 17 ++++++ test.rdb | Bin 202 -> 0 bytes 8 files changed, 271 insertions(+), 11 deletions(-) delete mode 100644 test.rdb diff --git a/database/keys.go b/database/keys.go index abdbf422..12dd2259 100644 --- a/database/keys.go +++ b/database/keys.go @@ -59,26 +59,36 @@ func execFlushDB(db *DB, args [][]byte) redis.Reply { return &protocol.OkReply{} } -// execType returns the type of entity, including: string, list, hash, set and zset -func execType(db *DB, args [][]byte) redis.Reply { - key := string(args[0]) +// returns the type of entity, including: string, list, hash, set and zset +func getType(db *DB, key string) string { entity, exists := db.GetEntity(key) if !exists { - return protocol.MakeStatusReply("none") + return "none" } switch entity.Data.(type) { case []byte: - return protocol.MakeStatusReply("string") + return "string" case list.List: - return protocol.MakeStatusReply("list") + return "list" case dict.Dict: - return protocol.MakeStatusReply("hash") + return "hash" case *set.Set: - return protocol.MakeStatusReply("set") + return "set" case *sortedset.SortedSet: - return protocol.MakeStatusReply("zset") + return "zset" + } + return "" +} + +// execType returns the type of entity, including: string, list, hash, set and zset +func execType(db *DB, args [][]byte) redis.Reply { + key := string(args[0]) + result := getType(db, key) + if len(result) > 0 { + return protocol.MakeStatusReply(result) + } else { + return &protocol.UnknownErrReply{} } - return &protocol.UnknownErrReply{} } func prepareRename(args [][]byte) ([]string, []string) { @@ -413,6 +423,57 @@ func execCopy(mdb *Server, conn redis.Connection, args [][]byte) redis.Reply { return protocol.MakeIntReply(1) } +// execScan return the result of the scan +func execScan(db *DB, args [][]byte) redis.Reply { + var count int = 10 + var pattern string = "*" + var scanType string = "" + if len(args) > 1 { + for i := 1; i < len(args); i++ { + arg := strings.ToLower(string(args[i])) + if arg == "count" { + count0, err := strconv.Atoi(string(args[i+1])) + if err != nil { + return &protocol.SyntaxErrReply{} + } + count = count0 + i++ + } else if arg == "match" { + pattern = string(args[i+1]) + i++ + } else if arg == "type" { + scanType = strings.ToLower(string(args[i+1])) + i++ + } else { + return &protocol.SyntaxErrReply{} + } + } + } + cursor, err := strconv.Atoi(string(args[0])) + if err != nil { + return protocol.MakeErrReply("ERR invalid cursor") + } + keysReply, nextCursor := db.data.DictScan(cursor, count, pattern) + if nextCursor < 0 { + return protocol.MakeErrReply("Invalid argument") + } + + if len(scanType) != 0 { + for i := 0; i < len(keysReply); { + if getType(db, string(keysReply[i])) != scanType { + keysReply = append(keysReply[:i], keysReply[i+1:]...) + } else { + i++ + } + } + } + result := make([]redis.Reply, 2) + result[0] = protocol.MakeBulkReply([]byte(strconv.FormatInt(int64(nextCursor), 10))) + result[1] = protocol.MakeMultiBulkReply(keysReply) + + return protocol.MakeMultiRawReply(result) +} + func init() { registerCommand("Del", execDel, writeAllKeys, undoDel, -2, flagWrite). attachCommandExtra([]string{redisFlagWrite}, 1, -1, 1) @@ -444,4 +505,6 @@ func init() { attachCommandExtra([]string{redisFlagWrite, redisFlagFast}, 1, 1, 1) registerCommand("Keys", execKeys, noPrepare, nil, 2, flagReadOnly). attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1) + registerCommand("Scan", execScan, readAllKeys, nil, -2, flagReadOnly). + attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1) } diff --git a/database/keys_test.go b/database/keys_test.go index 2b220f57..0376c62c 100644 --- a/database/keys_test.go +++ b/database/keys_test.go @@ -313,3 +313,85 @@ func TestCopy(t *testing.T) { result = testMDB.Exec(conn, utils.ToCmdLine("ttl", destKey)) asserts.AssertIntReplyGreaterThan(t, result, 0) } + +func TestScan(t *testing.T) { + testDB.Flush() + for i := 0; i < 3; i++ { + key := string(rune(i)) + value := key + testDB.Exec(nil, utils.ToCmdLine("set", "a:"+key, value)) + } + for i := 0; i < 3; i++ { + key := string(rune(i)) + value := key + testDB.Exec(nil, utils.ToCmdLine("set", "b:"+key, value)) + } + + // test scan 0 when keys < 10 + result := testDB.Exec(nil, utils.ToCmdLine("scan", "0")) + cursorStr := string(result.(*protocol.MultiRawReply).Replies[0].(*protocol.BulkReply).Arg) + cursor, err := strconv.Atoi(cursorStr) + if err == nil { + if cursor != 0 { + t.Errorf("expect cursor 0, actually %d", cursor) + return + } + } else { + t.Errorf("get scan result error") + return + } + + // test scan 0 match a* + result = testDB.Exec(nil, utils.ToCmdLine("scan", "0", "match", "a*")) + returnKeys := result.(*protocol.MultiRawReply).Replies[1].(*protocol.MultiBulkReply).Args + for i := range returnKeys { + key := string(returnKeys[i]) + if key[0] != 'a' { + t.Errorf("The key %s should match a*", key) + return + } + } + + // test scan 0 type string + testDB.Exec(nil, utils.ToCmdLine("hset", "hashkey", "hashkey", "1")) + result = testDB.Exec(nil, utils.ToCmdLine("scan", "0", "type", "string")) + returnKeys = result.(*protocol.MultiRawReply).Replies[1].(*protocol.MultiBulkReply).Args + for i := range returnKeys { + key := string(returnKeys[i]) + if key == "hashkey" { + t.Errorf("expect type string, found hash") + return + } + } + + // test returned cursor + testDB.Flush() + for i := 0; i < 100; i++ { + key := string(rune(i)) + value := key + testDB.Exec(nil, utils.ToCmdLine("set", "a"+key, value)) + } + cursor = 0 + resultByte := make([][]byte, 0) + for { + scanCursor := strconv.Itoa(cursor) + result = testDB.Exec(nil, utils.ToCmdLine("scan", scanCursor, "count", "20")) + cursorStr := string(result.(*protocol.MultiRawReply).Replies[0].(*protocol.BulkReply).Arg) + returnKeys = result.(*protocol.MultiRawReply).Replies[1].(*protocol.MultiBulkReply).Args + resultByte = append(resultByte, returnKeys...) + cursor, err = strconv.Atoi(cursorStr) + if err == nil { + if cursor == 0 { + break + } + } else { + t.Errorf("get scan result error") + return + } + } + resultByte = utils.RemoveDuplicates(resultByte) + if len(resultByte) != 100 { + t.Errorf("expect result num 100, actually %d", len(resultByte)) + return + } +} diff --git a/datastruct/dict/concurrent.go b/datastruct/dict/concurrent.go index 217143bf..2b791224 100644 --- a/datastruct/dict/concurrent.go +++ b/datastruct/dict/concurrent.go @@ -1,6 +1,7 @@ package dict import ( + "github.com/hdt3213/godis/lib/wildcard" "math" "math/rand" "sort" @@ -435,3 +436,47 @@ func (dict *ConcurrentDict) RWUnLocks(writeKeys []string, readKeys []string) { } } } + +func stringsToBytes(strSlice []string) [][]byte { + byteSlice := make([][]byte, len(strSlice)) + for i, str := range strSlice { + byteSlice[i] = []byte(str) + } + return byteSlice +} + +func (dict *ConcurrentDict) DictScan(cursor int, count int, pattern string) ([][]byte, int) { + size := dict.Len() + result := make([][]byte, 0) + + if pattern == "*" && count >= size { + return stringsToBytes(dict.Keys()), 0 + } + + matchKey, err := wildcard.CompilePattern(pattern) + if err != nil { + return result, -1 + } + + shardCount := len(dict.table) + shardIndex := cursor + + for shardIndex < shardCount { + shard := dict.table[shardIndex] + shard.mutex.RLock() + if len(result)+len(shard.m) > count && shardIndex > cursor { + shard.mutex.RUnlock() + return result, shardIndex + } + + for key := range shard.m { + if pattern == "*" || matchKey.IsMatch(key) { + result = append(result, []byte(key)) + } + } + shard.mutex.RUnlock() + shardIndex++ + } + + return result, 0 +} diff --git a/datastruct/dict/concurrent_test.go b/datastruct/dict/concurrent_test.go index 0581ef58..c59efd57 100644 --- a/datastruct/dict/concurrent_test.go +++ b/datastruct/dict/concurrent_test.go @@ -465,7 +465,7 @@ func TestConcurrentRemoveWithLock(t *testing.T) { } } -//change t.Error remove->forEach +// change t.Error remove->forEach func TestConcurrentForEach(t *testing.T) { d := MakeConcurrent(0) size := 100 @@ -524,3 +524,51 @@ func TestConcurrentDict_Keys(t *testing.T) { t.Errorf("expect %d keys, actual: %d", size, len(d.Keys())) } } + +func TestDictScan(t *testing.T) { + d := MakeConcurrent(0) + count := 100 + for i := 0; i < count; i++ { + key := "kkk" + strconv.Itoa(i) + d.Put(key, i) + } + for i := 0; i < count; i++ { + key := "key" + strconv.Itoa(i) + d.Put(key, i) + } + cursor := 0 + matchKey := "*" + c := 20 + result := make([][]byte, 0) + var returnKeys [][]byte + for { + returnKeys, cursor = d.DictScan(cursor, c, matchKey) + result = append(result, returnKeys...) + if cursor == 0 { + break + } + } + result = utils.RemoveDuplicates(result) + if len(result) != count*2 { + t.Errorf("scan command result number error: %d, should be %d ", len(result), count*2) + } + matchKey = "key*" + cursor = 0 + mresult := make([][]byte, 0) + for { + returnKeys, cursor = d.DictScan(cursor, c, matchKey) + mresult = append(mresult, returnKeys...) + if cursor == 0 { + break + } + } + mresult = utils.RemoveDuplicates(mresult) + if len(mresult) != count { + t.Errorf("scan command result number error: %d, should be %d ", len(mresult), count) + } + matchKey = "no*" + returnKeys, _ = d.DictScan(cursor, c, matchKey) + if len(returnKeys) != 0 { + t.Errorf("returnKeys should be empty") + } +} diff --git a/datastruct/dict/dict.go b/datastruct/dict/dict.go index 28ecd4cd..4fdf2793 100644 --- a/datastruct/dict/dict.go +++ b/datastruct/dict/dict.go @@ -16,4 +16,5 @@ type Dict interface { RandomKeys(limit int) []string RandomDistinctKeys(limit int) []string Clear() + DictScan(cursor int, count int, pattern string) ([][]byte, int) } diff --git a/datastruct/dict/simple.go b/datastruct/dict/simple.go index 49be8f33..187a708d 100644 --- a/datastruct/dict/simple.go +++ b/datastruct/dict/simple.go @@ -120,3 +120,7 @@ func (dict *SimpleDict) RandomDistinctKeys(limit int) []string { func (dict *SimpleDict) Clear() { *dict = *MakeSimple() } + +func (dict *SimpleDict) DictScan(cursor int, count int, pattern string) ([][]byte, int) { + return stringsToBytes(dict.Keys()), 0 +} diff --git a/lib/utils/utils.go b/lib/utils/utils.go index a912e3e6..c5c1b412 100644 --- a/lib/utils/utils.go +++ b/lib/utils/utils.go @@ -84,3 +84,20 @@ func ConvertRange(start int64, end int64, size int64) (int, int) { } return int(start), int(end) } + +// RemoveDuplicates removes duplicate byte slices from a 2D byte slice +func RemoveDuplicates(input [][]byte) [][]byte { + uniqueMap := make(map[string]struct{}) + var result [][]byte + + for _, item := range input { + // Use bytes.Buffer to convert byte slice to string + key := string(item) + if _, exists := uniqueMap[key]; !exists { + uniqueMap[key] = struct{}{} + result = append(result, item) + } + } + + return result +} diff --git a/test.rdb b/test.rdb deleted file mode 100644 index 93f1662b620c3f9cb84d3da2ce392b2f1a442583..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 202 zcmWG?b@2=~Ffg$E#aWb^l3A=DZ)S z9HqsnDZ06-xrZ)93o`uTNz6~vEhtJ&%uUKkJ;3mf;WsNIH*;}n2^SLs0|O%%^RQF_ z1^D@a93CJBN_=Aa^q-F