From 7759d3d43d4a18af0dc6aed720f8a681e4e556c5 Mon Sep 17 00:00:00 2001 From: Sudipto Baral Date: Mon, 16 Sep 2024 13:00:26 -0400 Subject: [PATCH] Implement UpdateQueue() repo method. (#205) * Implement UpdateQueue() repo method. --- .../database/repository/mock_repository.go | 14 ++ internal/database/repository/queue.go | 76 +++++++ .../database/repository/queue_int_test.go | 192 ++++++++++++++++++ internal/database/repository/repository.go | 1 + 4 files changed, 283 insertions(+) diff --git a/internal/database/repository/mock_repository.go b/internal/database/repository/mock_repository.go index 9826600c..1ca5ee0e 100644 --- a/internal/database/repository/mock_repository.go +++ b/internal/database/repository/mock_repository.go @@ -248,6 +248,20 @@ func (mr *MockRepositoryMockRecorder) UpdateHistory(arg0, arg1, arg2 any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHistory", reflect.TypeOf((*MockRepository)(nil).UpdateHistory), arg0, arg1, arg2) } +// UpdateQueue mocks base method. +func (m *MockRepository) UpdateQueue(arg0 context.Context, arg1 *dao.PartitionQueueDAOInfo) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateQueue", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateQueue indicates an expected call of UpdateQueue. +func (mr *MockRepositoryMockRecorder) UpdateQueue(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateQueue", reflect.TypeOf((*MockRepository)(nil).UpdateQueue), arg0, arg1) +} + // UpsertApplications mocks base method. func (m *MockRepository) UpsertApplications(arg0 context.Context, arg1 []*dao.ApplicationDAOInfo) error { m.ctrl.T.Helper() diff --git a/internal/database/repository/queue.go b/internal/database/repository/queue.go index 83b8d8a3..310f5cdb 100644 --- a/internal/database/repository/queue.go +++ b/internal/database/repository/queue.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/G-Research/yunikorn-history-server/internal/model" + "github.com/G-Research/yunikorn-history-server/internal/util" ) func (s *PostgresRepository) UpsertQueues(ctx context.Context, queues []*dao.PartitionQueueDAOInfo) error { @@ -161,6 +162,81 @@ func (s *PostgresRepository) AddQueues(ctx context.Context, parentId *string, qu return nil } +// UpdateQueue updates the queue based on the queue_name and partition. +// If the queue has children, the function will recursively update them. +// If provided child queue does not exist, the function will add it. +// The function returns an error if the update operation fails. +func (s *PostgresRepository) UpdateQueue(ctx context.Context, queue *dao.PartitionQueueDAOInfo) error { + updateSQL := ` + UPDATE queues SET + status = @status, + partition = @partition, + pending_resource = @pending_resource, + max_resource = @max_resource, + guaranteed_resource = @guaranteed_resource, + allocated_resource = @allocated_resource, + preempting_resource = @preempting_resource, + head_room = @head_room, + is_leaf = @is_leaf, + is_managed = @is_managed, + properties = @properties, + parent = @parent, + template_info = @template_info, + abs_used_capacity = @abs_used_capacity, + max_running_apps = @max_running_apps, + running_apps = @running_apps, + current_priority = @current_priority, + allocating_accepted_apps = @allocating_accepted_apps + WHERE queue_name = @queue_name AND partition = @partition AND deleted_at IS NULL +` + + result, err := s.dbpool.Exec(ctx, updateSQL, + pgx.NamedArgs{ + "queue_name": queue.QueueName, + "status": queue.Status, + "partition": queue.Partition, + "pending_resource": queue.PendingResource, + "max_resource": queue.MaxResource, + "guaranteed_resource": queue.GuaranteedResource, + "allocated_resource": queue.AllocatedResource, + "preempting_resource": queue.PreemptingResource, + "head_room": queue.HeadRoom, + "is_leaf": queue.IsLeaf, + "is_managed": queue.IsManaged, + "properties": queue.Properties, + "parent": queue.Parent, + "template_info": queue.TemplateInfo, + "abs_used_capacity": queue.AbsUsedCapacity, + "max_running_apps": queue.MaxRunningApps, + "running_apps": queue.RunningApps, + "current_priority": queue.CurrentPriority, + "allocating_accepted_apps": queue.AllocatingAcceptedApps, + }, + ) + if err != nil { + return fmt.Errorf("could not update queue in DB: %v", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("queue not found: %s", queue.QueueName) + } + + // If there are children, recursively update them + if len(queue.Children) > 0 { + for _, child := range queue.Children { + err := s.UpdateQueue(ctx, util.ToPtr(child)) + // if the child queue does not exist, we should add it + if err != nil { + err := s.AddQueues(ctx, nil, []*dao.PartitionQueueDAOInfo{&child}) + if err != nil { + return fmt.Errorf("could not add child queue %s into DB: %v", child.QueueName, err) + } + } + } + } + return nil +} + // GetAllQueues returns all queues from the database as a flat list // child queues are not nested in the parent queue.Children field func (s *PostgresRepository) GetAllQueues(ctx context.Context) ([]*model.PartitionQueueDAOInfo, error) { diff --git a/internal/database/repository/queue_int_test.go b/internal/database/repository/queue_int_test.go index 0784e68a..0804fc5b 100644 --- a/internal/database/repository/queue_int_test.go +++ b/internal/database/repository/queue_int_test.go @@ -495,6 +495,198 @@ func TestAddQueues_Integration(t *testing.T) { } +func TestUpdateQueue_Integration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + ctx := context.Background() + connPool := database.NewTestConnectionPool(ctx, t) + + repo, err := NewPostgresRepository(connPool) + require.NoError(t, err) + + tests := []struct { + name string + existingQueues []*dao.PartitionQueueDAOInfo + queueToUpdate *dao.PartitionQueueDAOInfo + expectedError bool + }{ + { + name: "Update root queue when root queue exists", + existingQueues: []*dao.PartitionQueueDAOInfo{ + { + Partition: "default", + QueueName: "root", + CurrentPriority: 0, + }, + }, + queueToUpdate: &dao.PartitionQueueDAOInfo{ + Partition: "default", + QueueName: "root", + CurrentPriority: 1, + }, + expectedError: false, + }, + { + name: "Update root queue when root queue does not exist", + existingQueues: nil, + queueToUpdate: &dao.PartitionQueueDAOInfo{ + Partition: "default", + QueueName: "root", + CurrentPriority: 1, + }, + expectedError: true, + }, + { + name: "Update when child queues has changed", + existingQueues: []*dao.PartitionQueueDAOInfo{ + { + Partition: "default", + QueueName: "root", + CurrentPriority: 0, + Children: []dao.PartitionQueueDAOInfo{ + { + Partition: "default", + QueueName: "root.org", + Parent: "root", + IsLeaf: true, + CurrentPriority: 100, + }, + { + Partition: "default", + QueueName: "root.system", + Parent: "root", + IsLeaf: true, + CurrentPriority: 150, + }, + }, + }, + }, + queueToUpdate: &dao.PartitionQueueDAOInfo{ + Partition: "default", + QueueName: "root", + CurrentPriority: 0, + Children: []dao.PartitionQueueDAOInfo{ + { + Partition: "default", + QueueName: "root.org", + Parent: "root", + IsLeaf: true, + CurrentPriority: 200, + }, + { + Partition: "default", + QueueName: "root.system", + Parent: "root", + IsLeaf: true, + CurrentPriority: 200, + }, + }, + }, + expectedError: false, + }, + { + name: "Update when new child queues has been added", + existingQueues: []*dao.PartitionQueueDAOInfo{ + { + Partition: "default", + QueueName: "root", + CurrentPriority: 0, + }, + }, + queueToUpdate: &dao.PartitionQueueDAOInfo{ + Partition: "default", + QueueName: "root", + CurrentPriority: 0, + Children: []dao.PartitionQueueDAOInfo{ + { + Partition: "default", + QueueName: "root.org", + Parent: "root", + IsLeaf: true, + CurrentPriority: 200, + }, + { + Partition: "default", + QueueName: "root.system", + Parent: "root", + IsLeaf: true, + CurrentPriority: 200, + }, + }, + }, + expectedError: false, + }, + { + name: "Update when both parent queue changed and new child queues has been added", + existingQueues: []*dao.PartitionQueueDAOInfo{ + { + Partition: "default", + QueueName: "root", + CurrentPriority: 0, + }, + }, + queueToUpdate: &dao.PartitionQueueDAOInfo{ + Partition: "default", + QueueName: "root", + CurrentPriority: 100, + Children: []dao.PartitionQueueDAOInfo{ + { + Partition: "default", + QueueName: "root.org", + Parent: "root", + IsLeaf: true, + CurrentPriority: 200, + }, + { + Partition: "default", + QueueName: "root.system", + Parent: "root", + IsLeaf: true, + CurrentPriority: 200, + }, + }, + }, + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // clean up the table after the test + t.Cleanup(func() { + _, err := connPool.Exec(ctx, "DELETE FROM queues") + require.NoError(t, err) + }) + // seed the existing queues + if tt.existingQueues != nil { + if err := repo.AddQueues(ctx, nil, tt.existingQueues); err != nil { + t.Fatalf("could not seed queue: %v", err) + } + } + // update the new queue + err := repo.UpdateQueue(ctx, tt.queueToUpdate) + if tt.expectedError { + require.Error(t, err) + return + } + require.NoError(t, err) + // check if the queue is updated along with its children + queueFromDB, err := repo.GetQueue(ctx, tt.queueToUpdate.Partition, tt.queueToUpdate.QueueName) + require.NoError(t, err) + assert.Equal(t, tt.queueToUpdate.QueueName, queueFromDB.QueueName) + assert.Equal(t, tt.queueToUpdate.Partition, queueFromDB.Partition) + assert.Equal(t, tt.queueToUpdate.CurrentPriority, queueFromDB.CurrentPriority) + // compare the children + for i, child := range tt.queueToUpdate.Children { + assert.Equal(t, child.CurrentPriority, queueFromDB.Children[i].CurrentPriority) + } + + }) + } +} + func seedQueues(t *testing.T, repo *PostgresRepository) { t.Helper() diff --git a/internal/database/repository/repository.go b/internal/database/repository/repository.go index f7204408..b15a4285 100644 --- a/internal/database/repository/repository.go +++ b/internal/database/repository/repository.go @@ -28,6 +28,7 @@ type Repository interface { UpsertPartitions(ctx context.Context, partitions []*dao.PartitionInfo) error GetAllPartitions(ctx context.Context) ([]*dao.PartitionInfo, error) AddQueues(ctx context.Context, parentId *string, queues []*dao.PartitionQueueDAOInfo) error + UpdateQueue(ctx context.Context, queue *dao.PartitionQueueDAOInfo) error UpsertQueues(ctx context.Context, queues []*dao.PartitionQueueDAOInfo) error GetAllQueues(ctx context.Context) ([]*model.PartitionQueueDAOInfo, error) GetQueuesPerPartition(ctx context.Context, partition string) ([]*model.PartitionQueueDAOInfo, error)