diff --git a/lib/timewheel/timewheel.go b/lib/timewheel/timewheel.go index ba3fe89d..e8ba2bfc 100644 --- a/lib/timewheel/timewheel.go +++ b/lib/timewheel/timewheel.go @@ -3,6 +3,7 @@ package timewheel import ( "container/list" "github.com/hdt3213/godis/lib/logger" + "sync" "time" ) @@ -11,7 +12,7 @@ type location struct { etask *list.Element } -// TimeWheel can execute job after waiting given duration +// TimeWheel can execute jobs after a given delay type TimeWheel struct { interval time.Duration ticker *time.Ticker @@ -23,6 +24,8 @@ type TimeWheel struct { addTaskChannel chan task removeTaskChannel chan string stopChannel chan bool + + mu sync.RWMutex } type task struct { @@ -48,7 +51,6 @@ func New(interval time.Duration, slotNum int) *TimeWheel { stopChannel: make(chan bool), } tw.initSlots() - return tw } @@ -58,7 +60,7 @@ func (tw *TimeWheel) initSlots() { } } -// Start starts ticker for time wheel +// Start starts the time wheel func (tw *TimeWheel) Start() { tw.ticker = time.NewTicker(tw.interval) go tw.start() @@ -69,7 +71,7 @@ func (tw *TimeWheel) Stop() { tw.stopChannel <- true } -// AddJob add new job into pending queue +// AddJob adds a new job to the pending queue func (tw *TimeWheel) AddJob(delay time.Duration, key string, job func()) { if delay < 0 { return @@ -103,16 +105,21 @@ func (tw *TimeWheel) start() { } func (tw *TimeWheel) tickHandler() { + tw.mu.Lock() l := tw.slots[tw.currentPos] if tw.currentPos == tw.slotNum-1 { tw.currentPos = 0 } else { tw.currentPos++ } + tw.mu.Unlock() + go tw.scanAndRunTask(l) } func (tw *TimeWheel) scanAndRunTask(l *list.List) { + var tasksToRemove []string + tw.mu.RLock() // Read lock for accessing the list for e := l.Front(); e != nil; { task := e.Value.(*task) if task.circle > 0 { @@ -121,52 +128,68 @@ func (tw *TimeWheel) scanAndRunTask(l *list.List) { continue } - go func() { + go func(job func()) { defer func() { if err := recover(); err != nil { logger.Error(err) } }() - job := task.job job() - }() - next := e.Next() - l.Remove(e) + }(task.job) + if task.key != "" { - delete(tw.timer, task.key) + tasksToRemove = append(tasksToRemove, task.key) } + next := e.Next() + l.Remove(e) // Safe as this is a local operation e = next } + tw.mu.RUnlock() + + // Remove tasks from the timer after the scan + tw.mu.Lock() + for _, key := range tasksToRemove { + delete(tw.timer, key) + } + tw.mu.Unlock() } func (tw *TimeWheel) addTask(task *task) { pos, circle := tw.getPositionAndCircle(task.delay) task.circle = circle + tw.mu.Lock() + defer tw.mu.Unlock() + + if task.key != "" { + if _, ok := tw.timer[task.key]; ok { + tw.removeTaskInternal(task.key) // Internal version avoids double lock + } + } + e := tw.slots[pos].PushBack(task) loc := &location{ slot: pos, etask: e, } - if task.key != "" { - _, ok := tw.timer[task.key] - if ok { - tw.removeTask(task.key) - } - } tw.timer[task.key] = loc } func (tw *TimeWheel) getPositionAndCircle(d time.Duration) (pos int, circle int) { delaySeconds := int(d.Seconds()) intervalSeconds := int(tw.interval.Seconds()) - circle = int(delaySeconds / intervalSeconds / tw.slotNum) - pos = int(tw.currentPos+delaySeconds/intervalSeconds) % tw.slotNum - + circle = delaySeconds / intervalSeconds / tw.slotNum + pos = (tw.currentPos + delaySeconds/intervalSeconds) % tw.slotNum return } func (tw *TimeWheel) removeTask(key string) { + tw.mu.Lock() + defer tw.mu.Unlock() + tw.removeTaskInternal(key) +} + +func (tw *TimeWheel) removeTaskInternal(key string) { pos, ok := tw.timer[key] if !ok { return diff --git a/lib/timewheel/timewheel_test.go b/lib/timewheel/timewheel_test.go new file mode 100644 index 00000000..60c599b0 --- /dev/null +++ b/lib/timewheel/timewheel_test.go @@ -0,0 +1,122 @@ +package timewheel + +import ( + "fmt" + "sync" + "testing" + "time" +) + +func TestTimeWheelConcurrency(t *testing.T) { + // Initialize the time wheel + tw := New(time.Second, 3600) + tw.Start() + defer tw.Stop() + + var wg sync.WaitGroup + const jobCount = 1000 + + // Function to simulate a job + job := func(id int) { + fmt.Printf("Job %d executed at %v\n", id, time.Now()) + } + + // Add jobs concurrently + wg.Add(jobCount) + for i := 0; i < jobCount; i++ { + go func(id int) { + defer wg.Done() + delay := time.Duration(id%10) * time.Second // Randomize delays + Delay(delay, fmt.Sprintf("job-%d", id), func() { job(id) }) + }(i) + } + + // Remove jobs concurrently + wg.Add(jobCount / 10) + for i := 0; i < jobCount; i += 10 { + go func(id int) { + defer wg.Done() + time.Sleep(2 * time.Second) // Ensure some jobs are added before canceling + Cancel(fmt.Sprintf("job-%d", id)) + }(i) + } + + // Add timed jobs with specific `At` time + wg.Add(jobCount / 10) + for i := 0; i < jobCount; i += 10 { + go func(id int) { + defer wg.Done() + at := time.Now().Add(5 * time.Second) + At(at, fmt.Sprintf("timed-job-%d", id), func() { + fmt.Printf("Timed Job %d executed at %v\n", id, time.Now()) + }) + }(i) + } + + // Wait for all goroutines to complete + wg.Wait() + fmt.Println("All tasks submitted and executed/cancelled successfully.") +} + +func TestTimeWheelConcurrentAddRunRemove(t *testing.T) { + // Initialize the time wheel + tw := New(time.Millisecond*100, 360) + tw.Start() + defer tw.Stop() + + var wg sync.WaitGroup + const totalJobs = 1000 + + // Function to simulate a job + job := func(id int) { + fmt.Printf("Job %d executed at %v\n", id, time.Now()) + } + + // Concurrently add jobs + wg.Add(totalJobs) + for i := 0; i < totalJobs; i++ { + go func(id int) { + defer wg.Done() + delay := time.Duration(id%50) * time.Millisecond // Randomize delays + Delay(delay, fmt.Sprintf("job-%d", id), func() { job(id) }) + }(i) + } + + // Concurrently remove some jobs + wg.Add(totalJobs / 5) + for i := 0; i < totalJobs; i += 5 { + go func(id int) { + defer wg.Done() + time.Sleep(time.Millisecond * 10) // Allow some jobs to be added first + Cancel(fmt.Sprintf("job-%d", id)) + }(i) + } + + // Concurrently add and execute timed jobs + wg.Add(totalJobs / 10) + for i := 0; i < totalJobs; i += 10 { + go func(id int) { + defer wg.Done() + at := time.Now().Add(time.Millisecond * time.Duration(20+id%30)) + At(at, fmt.Sprintf("timed-job-%d", id), func() { + fmt.Printf("Timed Job %d executed at %v\n", id, time.Now()) + }) + }(i) + } + + // Concurrently add long-duration jobs and immediately remove them + wg.Add(totalJobs / 20) + for i := 0; i < totalJobs; i += 20 { + go func(id int) { + defer wg.Done() + key := fmt.Sprintf("long-job-%d", id) + Delay(time.Second*5, key, func() { fmt.Printf("Long Job %d executed\n", id) }) + time.Sleep(time.Millisecond * 50) + Cancel(key) + }(i) + } + + // Wait for all operations to complete + wg.Wait() + fmt.Println("Concurrent Add, Run, and Remove Test completed successfully.") +}