Skip to content

Commit

Permalink
added timewheel concurrency support
Browse files Browse the repository at this point in the history
  • Loading branch information
derkan committed Nov 26, 2024
1 parent 7503040 commit 2e3c05a
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 19 deletions.
61 changes: 42 additions & 19 deletions lib/timewheel/timewheel.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package timewheel
import (
"container/list"
"github.com/hdt3213/godis/lib/logger"
"sync"
"time"
)

Expand All @@ -11,7 +12,7 @@ type location struct {
etask *list.Element
}

// TimeWheel can execute job after waiting given duration
// TimeWheel can execute jobs after a given delay
type TimeWheel struct {
interval time.Duration
ticker *time.Ticker
Expand All @@ -23,6 +24,8 @@ type TimeWheel struct {
addTaskChannel chan task
removeTaskChannel chan string
stopChannel chan bool

mu sync.RWMutex
}

type task struct {
Expand All @@ -48,7 +51,6 @@ func New(interval time.Duration, slotNum int) *TimeWheel {
stopChannel: make(chan bool),
}
tw.initSlots()

return tw
}

Expand All @@ -58,7 +60,7 @@ func (tw *TimeWheel) initSlots() {
}
}

// Start starts ticker for time wheel
// Start starts the time wheel
func (tw *TimeWheel) Start() {
tw.ticker = time.NewTicker(tw.interval)
go tw.start()
Expand All @@ -69,7 +71,7 @@ func (tw *TimeWheel) Stop() {
tw.stopChannel <- true
}

// AddJob add new job into pending queue
// AddJob adds a new job to the pending queue
func (tw *TimeWheel) AddJob(delay time.Duration, key string, job func()) {
if delay < 0 {
return
Expand Down Expand Up @@ -103,16 +105,21 @@ func (tw *TimeWheel) start() {
}

func (tw *TimeWheel) tickHandler() {
tw.mu.Lock()
l := tw.slots[tw.currentPos]
if tw.currentPos == tw.slotNum-1 {
tw.currentPos = 0
} else {
tw.currentPos++
}
tw.mu.Unlock()

go tw.scanAndRunTask(l)
}

func (tw *TimeWheel) scanAndRunTask(l *list.List) {
var tasksToRemove []string
tw.mu.RLock() // Read lock for accessing the list
for e := l.Front(); e != nil; {
task := e.Value.(*task)
if task.circle > 0 {
Expand All @@ -121,52 +128,68 @@ func (tw *TimeWheel) scanAndRunTask(l *list.List) {
continue
}

go func() {
go func(job func()) {
defer func() {
if err := recover(); err != nil {
logger.Error(err)
}
}()
job := task.job
job()
}()
next := e.Next()
l.Remove(e)
}(task.job)

if task.key != "" {
delete(tw.timer, task.key)
tasksToRemove = append(tasksToRemove, task.key)
}
next := e.Next()
l.Remove(e) // Safe as this is a local operation
e = next
}
tw.mu.RUnlock()

// Remove tasks from the timer after the scan
tw.mu.Lock()
for _, key := range tasksToRemove {
delete(tw.timer, key)
}
tw.mu.Unlock()
}

func (tw *TimeWheel) addTask(task *task) {
pos, circle := tw.getPositionAndCircle(task.delay)
task.circle = circle

tw.mu.Lock()
defer tw.mu.Unlock()

if task.key != "" {
if _, ok := tw.timer[task.key]; ok {
tw.removeTaskInternal(task.key) // Internal version avoids double lock
}
}

e := tw.slots[pos].PushBack(task)
loc := &location{
slot: pos,
etask: e,
}
if task.key != "" {
_, ok := tw.timer[task.key]
if ok {
tw.removeTask(task.key)
}
}
tw.timer[task.key] = loc
}

func (tw *TimeWheel) getPositionAndCircle(d time.Duration) (pos int, circle int) {
delaySeconds := int(d.Seconds())
intervalSeconds := int(tw.interval.Seconds())
circle = int(delaySeconds / intervalSeconds / tw.slotNum)
pos = int(tw.currentPos+delaySeconds/intervalSeconds) % tw.slotNum

circle = delaySeconds / intervalSeconds / tw.slotNum
pos = (tw.currentPos + delaySeconds/intervalSeconds) % tw.slotNum
return
}

func (tw *TimeWheel) removeTask(key string) {
tw.mu.Lock()
defer tw.mu.Unlock()
tw.removeTaskInternal(key)
}

func (tw *TimeWheel) removeTaskInternal(key string) {
pos, ok := tw.timer[key]
if !ok {
return
Expand Down
122 changes: 122 additions & 0 deletions lib/timewheel/timewheel_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package timewheel

import (
"fmt"
"sync"
"testing"
"time"
)

func TestTimeWheelConcurrency(t *testing.T) {
// Initialize the time wheel
tw := New(time.Second, 3600)
tw.Start()
defer tw.Stop()

var wg sync.WaitGroup
const jobCount = 1000

// Function to simulate a job
job := func(id int) {
fmt.Printf("Job %d executed at %v\n", id, time.Now())
}

// Add jobs concurrently
wg.Add(jobCount)
for i := 0; i < jobCount; i++ {
go func(id int) {
defer wg.Done()
delay := time.Duration(id%10) * time.Second // Randomize delays
Delay(delay, fmt.Sprintf("job-%d", id), func() { job(id) })
}(i)
}

// Remove jobs concurrently
wg.Add(jobCount / 10)
for i := 0; i < jobCount; i += 10 {
go func(id int) {
defer wg.Done()
time.Sleep(2 * time.Second) // Ensure some jobs are added before canceling
Cancel(fmt.Sprintf("job-%d", id))
}(i)
}

// Add timed jobs with specific `At` time
wg.Add(jobCount / 10)
for i := 0; i < jobCount; i += 10 {
go func(id int) {
defer wg.Done()
at := time.Now().Add(5 * time.Second)
At(at, fmt.Sprintf("timed-job-%d", id), func() {
fmt.Printf("Timed Job %d executed at %v\n", id, time.Now())
})
}(i)
}

// Wait for all goroutines to complete
wg.Wait()
fmt.Println("All tasks submitted and executed/cancelled successfully.")
}

func TestTimeWheelConcurrentAddRunRemove(t *testing.T) {
// Initialize the time wheel
tw := New(time.Millisecond*100, 360)
tw.Start()
defer tw.Stop()

var wg sync.WaitGroup
const totalJobs = 1000

// Function to simulate a job
job := func(id int) {
fmt.Printf("Job %d executed at %v\n", id, time.Now())
}

// Concurrently add jobs
wg.Add(totalJobs)
for i := 0; i < totalJobs; i++ {
go func(id int) {
defer wg.Done()
delay := time.Duration(id%50) * time.Millisecond // Randomize delays
Delay(delay, fmt.Sprintf("job-%d", id), func() { job(id) })
}(i)
}

// Concurrently remove some jobs
wg.Add(totalJobs / 5)
for i := 0; i < totalJobs; i += 5 {
go func(id int) {
defer wg.Done()
time.Sleep(time.Millisecond * 10) // Allow some jobs to be added first
Cancel(fmt.Sprintf("job-%d", id))
}(i)
}

// Concurrently add and execute timed jobs
wg.Add(totalJobs / 10)
for i := 0; i < totalJobs; i += 10 {
go func(id int) {
defer wg.Done()
at := time.Now().Add(time.Millisecond * time.Duration(20+id%30))
At(at, fmt.Sprintf("timed-job-%d", id), func() {
fmt.Printf("Timed Job %d executed at %v\n", id, time.Now())
})
}(i)
}

// Concurrently add long-duration jobs and immediately remove them
wg.Add(totalJobs / 20)
for i := 0; i < totalJobs; i += 20 {
go func(id int) {
defer wg.Done()
key := fmt.Sprintf("long-job-%d", id)
Delay(time.Second*5, key, func() { fmt.Printf("Long Job %d executed\n", id) })
time.Sleep(time.Millisecond * 50)
Cancel(key)
}(i)
}

// Wait for all operations to complete
wg.Wait()
fmt.Println("Concurrent Add, Run, and Remove Test completed successfully.")
}

0 comments on commit 2e3c05a

Please sign in to comment.