Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
936 changes: 176 additions & 760 deletions cmd/root.go

Large diffs are not rendered by default.

153 changes: 133 additions & 20 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ import (

dockerContainer "github.com/docker/docker/api/types/container"

"github.com/nicholas-fedor/watchtower/internal/api"
"github.com/nicholas-fedor/watchtower/internal/flags"
"github.com/nicholas-fedor/watchtower/internal/logging"
"github.com/nicholas-fedor/watchtower/internal/scheduling"
"github.com/nicholas-fedor/watchtower/internal/util"
"github.com/nicholas-fedor/watchtower/pkg/api/update"
containerMock "github.com/nicholas-fedor/watchtower/pkg/container/mocks"
"github.com/nicholas-fedor/watchtower/pkg/metrics"
Expand Down Expand Up @@ -364,7 +368,7 @@ func TestFormatDuration(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatDuration(tt.duration)
result := util.FormatDuration(tt.duration)
assert.Equal(t, tt.expected, result)
})
}
Expand Down Expand Up @@ -415,11 +419,7 @@ func TestFormatTimeUnit(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatTimeUnit(struct {
value int64
singular string
plural string
}{tt.value, tt.singular, tt.plural}, tt.forceInclude)
result := util.FormatTimeUnit(tt.value, tt.singular, tt.plural, tt.forceInclude)
assert.Equal(t, tt.expected, result)
})
}
Expand Down Expand Up @@ -455,7 +455,7 @@ func TestFilterEmpty(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterEmpty(tt.input)
result := util.FilterEmpty(tt.input)
assert.Equal(t, tt.expected, result)
})
}
Expand Down Expand Up @@ -545,7 +545,7 @@ func TestGetAPIAddr(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getAPIAddr(tt.host, tt.port)
result := api.GetAPIAddr(tt.host, tt.port)
assert.Equal(t, tt.expected, result)

// Verify the formatted address is a valid TCP address
Expand Down Expand Up @@ -629,6 +629,12 @@ func TestUpdateLockSerialization(t *testing.T) {
// TestConcurrentScheduledAndAPIUpdate verifies that API-triggered updates wait for scheduled updates to complete,
// ensuring proper serialization and preventing race conditions between periodic updates and HTTP API calls.
func TestConcurrentScheduledAndAPIUpdate(t *testing.T) {
// Enable debug logging to see lock acquisition logs
originalLevel := logrus.GetLevel()

logrus.SetLevel(logrus.DebugLevel)
defer logrus.SetLevel(originalLevel)

// Initialize the update lock channel with the same pattern as in runMain
updateLock := make(chan bool, 1)
updateLock <- true
Expand All @@ -653,20 +659,29 @@ func TestConcurrentScheduledAndAPIUpdate(t *testing.T) {

// Simulate scheduled update (longer duration)
go func() {
t.Log("Scheduled: trying to acquire lock")

select {
case v := <-updateLock:
t.Log("Scheduled: acquired lock")
close(scheduledStarted)
time.Sleep(200 * time.Millisecond) // Simulate scheduled update work (longer than API)
close(scheduledCompleted)
t.Log("Scheduled: releasing lock")

updateLock <- v
default:
t.Error("Scheduled update should have acquired the lock")
}
}()

// Wait for scheduled update to start
<-scheduledStarted

// Simulate API update request
go func() {
t.Log("API: creating request")

req, err := http.NewRequestWithContext(
context.Background(),
http.MethodPost,
Expand All @@ -680,12 +695,12 @@ func TestConcurrentScheduledAndAPIUpdate(t *testing.T) {
}

w := httptest.NewRecorder()

t.Log("API: calling handler.Handle")
handler.Handle(w, req)
t.Log("API: handler.Handle completed")
}()

// Wait for scheduled update to start
<-scheduledStarted

// Verify API update has not started yet (should be blocked by lock)
select {
case <-apiStarted:
Expand Down Expand Up @@ -747,7 +762,21 @@ func TestUpdateOnStartTriggersImmediateUpdate(t *testing.T) {
filterDesc := testFilterDesc

// The function should trigger immediate update and then start scheduler
err = runUpgradesOnSchedule(ctx, cmd, filter, filterDesc, updateLock, false)
err = scheduling.RunUpgradesOnSchedule(
ctx,
cmd,
filter,
filterDesc,
updateLock,
false,
"",
logging.WriteStartupMessage,
runUpdatesWithNotifications,
nil,
"",
nil,
"",
)

// Should not return an error (context cancellation is expected)
require.NoError(t, err)
Expand Down Expand Up @@ -816,7 +845,21 @@ func TestUpdateOnStartIntegratesWithCronScheduling(t *testing.T) {
filterDesc := testFilterDesc

startTime := time.Now()
err = runUpgradesOnSchedule(ctx, cmd, filter, filterDesc, updateLock, false)
err = scheduling.RunUpgradesOnSchedule(
ctx,
cmd,
filter,
filterDesc,
updateLock,
false,
"",
logging.WriteStartupMessage,
runUpdatesWithNotifications,
nil,
"",
nil,
"",
)

// Should not return an error (context cancellation is expected)
require.NoError(t, err)
Expand Down Expand Up @@ -888,7 +931,21 @@ func TestUpdateOnStartLockingBehavior(t *testing.T) {
filter := func(_ types.FilterableContainer) bool { return false }
filterDesc := testFilterDesc

err = runUpgradesOnSchedule(ctx, cmd, filter, filterDesc, updateLock, false)
err = scheduling.RunUpgradesOnSchedule(
ctx,
cmd,
filter,
filterDesc,
updateLock,
false,
"",
logging.WriteStartupMessage,
runUpdatesWithNotifications,
nil,
"",
nil,
"",
)

// Should not return an error
require.NoError(t, err)
Expand Down Expand Up @@ -939,7 +996,21 @@ func TestUpdateOnStartSelfUpdateScenario(t *testing.T) {
filter := func(_ types.FilterableContainer) bool { return true }
filterDesc := testFilterDesc

err = runUpgradesOnSchedule(ctx, cmd, filter, filterDesc, updateLock, false)
err = scheduling.RunUpgradesOnSchedule(
ctx,
cmd,
filter,
filterDesc,
updateLock,
false,
"",
logging.WriteStartupMessage,
runUpdatesWithNotifications,
nil,
"",
nil,
"",
)

// Should not return an error
require.NoError(t, err)
Expand Down Expand Up @@ -1001,7 +1072,21 @@ func TestUpdateOnStartMultiInstanceScenario(t *testing.T) {
filter := func(_ types.FilterableContainer) bool { return false }
filterDesc := "instance1"

err := runUpgradesOnSchedule(ctx, cmd1, filter, filterDesc, updateLock, false)
err := scheduling.RunUpgradesOnSchedule(
ctx,
cmd1,
filter,
filterDesc,
updateLock,
false,
"",
logging.WriteStartupMessage,
runUpdatesWithNotifications,
nil,
"",
nil,
"",
)
assert.NoError(t, err)
atomic.AddInt32(&completed, 1)
close(instance1Called)
Expand All @@ -1014,7 +1099,21 @@ func TestUpdateOnStartMultiInstanceScenario(t *testing.T) {
filter := func(_ types.FilterableContainer) bool { return false }
filterDesc := "instance2"

err := runUpgradesOnSchedule(ctx, cmd2, filter, filterDesc, updateLock, false)
err := scheduling.RunUpgradesOnSchedule(
ctx,
cmd2,
filter,
filterDesc,
updateLock,
false,
"",
logging.WriteStartupMessage,
runUpdatesWithNotifications,
nil,
"",
nil,
"",
)
assert.NoError(t, err)
atomic.AddInt32(&completed, 1)
close(instance2Called)
Expand Down Expand Up @@ -1059,7 +1158,7 @@ func TestWaitForRunningUpdate_NoUpdateRunning(t *testing.T) {
ctx := context.Background()
start := time.Now()

waitForRunningUpdate(ctx, lock)
scheduling.WaitForRunningUpdate(ctx, lock)

elapsed := time.Since(start)

Expand All @@ -1077,7 +1176,7 @@ func TestWaitForRunningUpdate_UpdateRunning(t *testing.T) {
waitCompleted := make(chan bool, 1)

go func() {
waitForRunningUpdate(ctx, lock)
scheduling.WaitForRunningUpdate(ctx, lock)

waitCompleted <- true
}()
Expand Down Expand Up @@ -1141,7 +1240,21 @@ func TestRunUpgradesOnSchedule_ShutdownWaitsForRunningUpdate(t *testing.T) {
filterDesc := testFilterDesc

// This should start and wait for context cancellation
err := runUpgradesOnSchedule(ctx, cmd, filter, filterDesc, updateLock, false)
err := scheduling.RunUpgradesOnSchedule(
ctx,
cmd,
filter,
filterDesc,
updateLock,
false,
"",
logging.WriteStartupMessage,
runUpdatesWithNotifications,
nil,
"",
nil,
"",
)
assert.NoError(t, err)

shutdownCompleted <- true
Expand Down
Loading
Loading