From 3fa26af448d32ccbb4e9a6ef862edebb4ef806aa Mon Sep 17 00:00:00 2001 From: roseduan Date: Sat, 16 Sep 2023 12:42:25 +0800 Subject: [PATCH] add Persist to remove the ttl --- batch.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ db.go | 21 ++++++++++++++++++++- db_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 1 deletion(-) diff --git a/batch.go b/batch.go index d046a869..1c4267ba 100644 --- a/batch.go +++ b/batch.go @@ -360,6 +360,57 @@ func (b *Batch) TTL(key []byte) (time.Duration, error) { return -1, nil } +// Persist removes the ttl of the key. +func (b *Batch) Persist(key []byte) error { + if len(key) == 0 { + return ErrKeyIsEmpty + } + if b.db.closed { + return ErrDBClosed + } + if b.options.ReadOnly { + return ErrReadOnlyBatch + } + + b.mu.Lock() + defer b.mu.Unlock() + + // if the key exists in pendingWrites, update the expiry time directly + pendingRecord := b.pendingWrites[string(key)] + if pendingRecord != nil && pendingRecord.Type != LogRecordDeleted { + pendingRecord.Expire = 0 + } else { + // check if the key exists in index + position := b.db.index.Get(key) + if position == nil { + return ErrKeyNotFound + } + chunk, err := b.db.dataFiles.Read(position) + if err != nil { + return err + } + + record := decodeLogRecord(chunk) + now := time.Now().UnixNano() + // check if the record is deleted or expired + if record.Type == LogRecordDeleted || record.IsExpired(now) { + b.db.index.Delete(record.Key) + return ErrKeyNotFound + } + // if the expiration time is 0, it means that the key has no expiration time, + // so we can return directly + if record.Expire == 0 { + return nil + } + + // set the expiration time to 0, and rewrite the record to wal + record.Expire = 0 + b.pendingWrites[string(key)] = record + } + + return nil +} + // Commit commits the batch, if the batch is readonly or empty, it will return directly. // // It will iterate the pendingWrites and write the data to the database, diff --git a/db.go b/db.go index c07c898f..60cab43a 100644 --- a/db.go +++ b/db.go @@ -313,7 +313,7 @@ func (db *DB) Expire(key []byte, ttl time.Duration) error { batch.reset() db.batchPool.Put(batch) }() - // This is a single put operation, we can set Sync to false. + // This is a single expire operation, we can set Sync to false. // Because the data will be written to the WAL, // and the WAL file will be synced to disk according to the DB options. batch.init(false, false, db).withPendingWrites() @@ -336,6 +336,25 @@ func (db *DB) TTL(key []byte) (time.Duration, error) { return batch.TTL(key) } +// Persist removes the ttl of the key. +// If the key does not exist or expired, it will return ErrKeyNotFound. +func (db *DB) Persist(key []byte) error { + batch := db.batchPool.Get().(*Batch) + defer func() { + batch.reset() + db.batchPool.Put(batch) + }() + // This is a single persist operation, we can set Sync to false. + // Because the data will be written to the WAL, + // and the WAL file will be synced to disk according to the DB options. + batch.init(false, false, db).withPendingWrites() + if err := batch.Persist(key); err != nil { + _ = batch.Rollback() + return err + } + return batch.Commit() +} + func (db *DB) Watch() (chan *Event, error) { if db.options.WatchQueueSize <= 0 { return nil, ErrWatchDisabled diff --git a/db_test.go b/db_test.go index 5d8b1c4e..fb8fadeb 100644 --- a/db_test.go +++ b/db_test.go @@ -661,3 +661,46 @@ func TestDB_Multi_DeleteExpiredKeys(t *testing.T) { assert.Equal(t, 10000, db.Stat().KeysNum) } } + +func TestDB_Persist(t *testing.T) { + options := DefaultOptions + db, err := Open(options) + assert.Nil(t, err) + defer destroyDB(db) + + // not exist + err = db.Persist(utils.GetTestKey(1)) + assert.Equal(t, err, ErrKeyNotFound) + + err = db.PutWithTTL(utils.GetTestKey(1), utils.RandomValue(10), time.Second*1) + assert.Nil(t, err) + + // exist + err = db.Persist(utils.GetTestKey(1)) + assert.Nil(t, err) + time.Sleep(time.Second * 2) + // check ttl + ttl, err := db.TTL(utils.GetTestKey(1)) + assert.Nil(t, err) + assert.Equal(t, ttl, time.Duration(-1)) + val1, err := db.Get(utils.GetTestKey(1)) + assert.Nil(t, err) + assert.NotNil(t, val1) + + // restart + err = db.Close() + assert.Nil(t, err) + + db2, err := Open(options) + assert.Nil(t, err) + defer func() { + _ = db2.Close() + }() + + ttl2, err := db2.TTL(utils.GetTestKey(1)) + assert.Nil(t, err) + assert.Equal(t, ttl2, time.Duration(-1)) + val2, err := db2.Get(utils.GetTestKey(1)) + assert.Nil(t, err) + assert.NotNil(t, val2) +}