-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[ENH]: Add models and migration for Create/Delete/Get Tasks (#5546) #5573
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,264 @@ | ||
package coordinator | ||
|
||
import (
	"context"
	"errors"
	"strings"
	"time"

	"github.com/chroma-core/chroma/go/pkg/common"
	"github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb"
	"github.com/chroma-core/chroma/go/pkg/sysdb/coordinator/model"
	"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
	"github.com/chroma-core/chroma/go/pkg/types"
	"github.com/google/uuid"
	"github.com/pingcap/log"
	"go.uber.org/zap"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
)
|
||
// CreateTask creates a new task in the database | ||
func (s *Coordinator) CreateTask(ctx context.Context, req *coordinatorpb.CreateTaskRequest) (*coordinatorpb.CreateTaskResponse, error) { | ||
// Validate task name doesn't start with soft-deletion reserved prefix | ||
if strings.HasPrefix(req.Name, "_deleted_") { | ||
log.Error("CreateTask: task name cannot start with _deleted_") | ||
return nil, common.ErrInvalidTaskName | ||
} | ||
|
||
var taskID uuid.UUID | ||
|
||
// Execute all database operations in a transaction | ||
err := s.catalog.txImpl.Transaction(ctx, func(txCtx context.Context) error { | ||
// Check if task already exists | ||
existingTask, err := s.catalog.metaDomain.TaskDb(txCtx).GetByName(req.InputCollectionId, req.Name) | ||
if err != nil { | ||
log.Error("CreateTask: failed to check task", zap.Error(err)) | ||
return err | ||
} | ||
if existingTask != nil { | ||
log.Info("CreateTask: task already exists, returning existing") | ||
taskID = existingTask.ID | ||
return nil | ||
} | ||
|
||
// Generate new task UUID | ||
taskID = uuid.New() | ||
outputCollectionName := req.OutputCollectionName | ||
|
||
// Look up database_id from databases table using database name and tenant | ||
databases, err := s.catalog.metaDomain.DatabaseDb(txCtx).GetDatabases(req.TenantId, req.Database) | ||
if err != nil { | ||
log.Error("CreateTask: failed to get database", zap.Error(err)) | ||
return err | ||
} | ||
if len(databases) == 0 { | ||
log.Error("CreateTask: database not found") | ||
return common.ErrDatabaseNotFound | ||
} | ||
|
||
// Look up operator by name from the operators table | ||
operator, err := s.catalog.metaDomain.OperatorDb(txCtx).GetByName(req.OperatorName) | ||
if err != nil { | ||
log.Error("CreateTask: failed to get operator", zap.Error(err)) | ||
return err | ||
} | ||
if operator == nil { | ||
log.Error("CreateTask: operator not found", zap.String("operator_name", req.OperatorName)) | ||
return common.ErrOperatorNotFound | ||
} | ||
operatorID := operator.OperatorID | ||
|
||
// Generate UUIDv7 for time-ordered nonce | ||
nextNonce, err := uuid.NewV7() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// TODO(tanujnay112): Can combine the two collection checks into one | ||
// Check if input collection exists | ||
collections, err := s.catalog.metaDomain.CollectionDb(txCtx).GetCollections([]string{req.InputCollectionId}, nil, req.TenantId, req.Database, nil, nil, false) | ||
if err != nil { | ||
log.Error("CreateTask: failed to get input collection", zap.Error(err)) | ||
return err | ||
} | ||
if len(collections) == 0 { | ||
log.Error("CreateTask: input collection not found") | ||
return common.ErrCollectionNotFound | ||
} | ||
|
||
// Check if output collection already exists | ||
existingOutputCollections, err := s.catalog.metaDomain.CollectionDb(txCtx).GetCollections(nil, &outputCollectionName, req.TenantId, req.Database, nil, nil, false) | ||
if err != nil { | ||
log.Error("CreateTask: failed to check output collection", zap.Error(err)) | ||
return err | ||
} | ||
if len(existingOutputCollections) > 0 { | ||
log.Error("CreateTask: output collection already exists") | ||
return common.ErrCollectionUniqueConstraintViolation | ||
} | ||
|
||
now := time.Now() | ||
task := &dbmodel.Task{ | ||
ID: taskID, | ||
Name: req.Name, | ||
TenantID: req.TenantId, | ||
DatabaseID: databases[0].ID, | ||
InputCollectionID: req.InputCollectionId, | ||
OutputCollectionName: req.OutputCollectionName, | ||
OperatorID: operatorID, | ||
OperatorParams: req.Params, | ||
CompletionOffset: 0, | ||
LastRun: nil, | ||
NextRun: nil, // Will be set to zero initially, scheduled by task scheduler | ||
MinRecordsForTask: int64(req.MinRecordsForTask), | ||
CurrentAttempts: 0, | ||
CreatedAt: now, | ||
UpdatedAt: now, | ||
NextNonce: nextNonce, | ||
OldestWrittenNonce: nil, | ||
} | ||
|
||
// Try to insert task into database | ||
err = s.catalog.metaDomain.TaskDb(txCtx).Insert(task) | ||
if err != nil { | ||
// Check if it's a unique constraint violation (concurrent creation) | ||
if err == common.ErrTaskAlreadyExists { | ||
log.Error("CreateTask: task already exists") | ||
return common.ErrTaskAlreadyExists | ||
} | ||
log.Error("CreateTask: failed to insert task", zap.Error(err)) | ||
return err | ||
} | ||
|
||
log.Info("Task created successfully", zap.String("task_id", taskID.String()), zap.String("name", req.Name), zap.String("output_collection_name", outputCollectionName)) | ||
return nil | ||
}) | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return &coordinatorpb.CreateTaskResponse{ | ||
TaskId: taskID.String(), | ||
}, nil | ||
} | ||
|
||
// GetTaskByName retrieves a task by name from the database | ||
func (s *Coordinator) GetTaskByName(ctx context.Context, req *coordinatorpb.GetTaskByNameRequest) (*coordinatorpb.GetTaskByNameResponse, error) { | ||
// Can do both calls with a JOIN | ||
task, err := s.catalog.metaDomain.TaskDb(ctx).GetByName(req.InputCollectionId, req.TaskName) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// If task not found, return empty response | ||
if task == nil { | ||
return nil, common.ErrTaskNotFound | ||
} | ||
|
||
// Look up operator name from operators table | ||
operator, err := s.catalog.metaDomain.OperatorDb(ctx).GetByID(task.OperatorID) | ||
if err != nil { | ||
log.Error("GetTaskByName: failed to get operator", zap.Error(err)) | ||
return nil, err | ||
} | ||
if operator == nil { | ||
log.Error("GetTaskByName: operator not found", zap.String("operator_id", task.OperatorID.String())) | ||
return nil, common.ErrOperatorNotFound | ||
} | ||
|
||
// Debug logging | ||
log.Info("Found task", zap.String("task_id", task.ID.String()), zap.String("name", task.Name), zap.String("input_collection_id", task.InputCollectionID), zap.String("output_collection_name", task.OutputCollectionName)) | ||
|
||
// Convert task to response | ||
response := &coordinatorpb.GetTaskByNameResponse{ | ||
TaskId: proto.String(task.ID.String()), | ||
Name: proto.String(task.Name), | ||
OperatorName: proto.String(operator.OperatorName), | ||
InputCollectionId: proto.String(task.InputCollectionID), | ||
OutputCollectionName: proto.String(task.OutputCollectionName), | ||
Params: proto.String(task.OperatorParams), | ||
CompletionOffset: proto.Int64(task.CompletionOffset), | ||
MinRecordsForTask: proto.Uint64(uint64(task.MinRecordsForTask)), | ||
} | ||
// Add output_collection_id if it's set | ||
if task.OutputCollectionID != nil { | ||
response.OutputCollectionId = task.OutputCollectionID | ||
} | ||
return response, nil | ||
} | ||
|
||
// DeleteTask soft deletes a task by name | ||
func (s *Coordinator) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRequest) (*coordinatorpb.DeleteTaskResponse, error) { | ||
// First get the task to check if we need to delete the output collection | ||
task, err := s.catalog.metaDomain.TaskDb(ctx).GetByName(req.InputCollectionId, req.TaskName) | ||
if err != nil { | ||
log.Error("DeleteTask: failed to get task", zap.Error(err)) | ||
return nil, err | ||
} | ||
if task == nil { | ||
log.Error("DeleteTask: task not found") | ||
return nil, status.Errorf(codes.NotFound, "task not found") | ||
} | ||
|
||
// If delete_output is true and output_collection_id is set, soft-delete the output collection | ||
if req.DeleteOutput && task.OutputCollectionID != nil && *task.OutputCollectionID != "" { | ||
collectionUUID, err := types.ToUniqueID(task.OutputCollectionID) | ||
if err != nil { | ||
log.Error("DeleteTask: invalid output_collection_id", zap.Error(err)) | ||
return nil, status.Errorf(codes.InvalidArgument, "invalid output_collection_id: %v", err) | ||
} | ||
|
||
deleteCollection := &model.DeleteCollection{ | ||
ID: collectionUUID, | ||
TenantID: task.TenantID, | ||
DatabaseName: task.DatabaseID, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [BestPractice] There appears to be a naming inconsistency here. You are assigning Context for Agents
|
||
} | ||
|
||
err = s.SoftDeleteCollection(ctx, deleteCollection) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there not a way to do this transactionally? |
||
if err != nil { | ||
// Log but don't fail - we still want to delete the task | ||
log.Warn("DeleteTask: failed to delete output collection", zap.Error(err), zap.String("collection_id", *task.OutputCollectionID)) | ||
} else { | ||
log.Info("DeleteTask: deleted output collection", zap.String("collection_id", *task.OutputCollectionID)) | ||
} | ||
} | ||
|
||
// Now soft-delete the task | ||
err = s.catalog.metaDomain.TaskDb(ctx).SoftDelete(req.InputCollectionId, req.TaskName) | ||
if err != nil { | ||
log.Error("DeleteTask failed", zap.Error(err)) | ||
return nil, err | ||
} | ||
|
||
log.Info("Task deleted", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName)) | ||
|
||
return &coordinatorpb.DeleteTaskResponse{ | ||
Success: true, | ||
}, nil | ||
} | ||
|
||
// GetOperators retrieves all operators from the database | ||
func (s *Coordinator) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) { | ||
operators, err := s.catalog.metaDomain.OperatorDb(ctx).GetAll() | ||
if err != nil { | ||
log.Error("GetOperators failed", zap.Error(err)) | ||
return nil, err | ||
} | ||
|
||
// Convert to proto response | ||
protoOperators := make([]*coordinatorpb.Operator, len(operators)) | ||
for i, op := range operators { | ||
protoOperators[i] = &coordinatorpb.Operator{ | ||
Id: op.OperatorID.String(), | ||
Name: op.OperatorName, | ||
} | ||
} | ||
|
||
log.Info("GetOperators succeeded", zap.Int("count", len(operators))) | ||
|
||
return &coordinatorpb.GetOperatorsResponse{ | ||
Operators: protoOperators, | ||
}, nil | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package grpc | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb" | ||
"github.com/pingcap/log" | ||
"go.uber.org/zap" | ||
) | ||
|
||
func (s *Server) CreateTask(ctx context.Context, req *coordinatorpb.CreateTaskRequest) (*coordinatorpb.CreateTaskResponse, error) { | ||
log.Info("CreateTask", zap.String("name", req.Name), zap.String("operator_name", req.OperatorName)) | ||
|
||
res, err := s.coordinator.CreateTask(ctx, req) | ||
if err != nil { | ||
log.Error("CreateTask failed", zap.Error(err)) | ||
return nil, err | ||
} | ||
|
||
return res, nil | ||
} | ||
|
||
func (s *Server) GetTaskByName(ctx context.Context, req *coordinatorpb.GetTaskByNameRequest) (*coordinatorpb.GetTaskByNameResponse, error) { | ||
log.Info("GetTaskByName", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName)) | ||
|
||
res, err := s.coordinator.GetTaskByName(ctx, req) | ||
if err != nil { | ||
log.Error("GetTaskByName failed", zap.Error(err)) | ||
return nil, err | ||
} | ||
|
||
return res, nil | ||
} | ||
|
||
func (s *Server) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRequest) (*coordinatorpb.DeleteTaskResponse, error) { | ||
log.Info("DeleteTask", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName)) | ||
|
||
res, err := s.coordinator.DeleteTask(ctx, req) | ||
if err != nil { | ||
log.Error("DeleteTask failed", zap.Error(err)) | ||
return nil, err | ||
} | ||
|
||
return res, nil | ||
} | ||
|
||
func (s *Server) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) { | ||
log.Info("GetOperators") | ||
|
||
res, err := s.coordinator.GetOperators(ctx, req) | ||
if err != nil { | ||
log.Error("GetOperators failed", zap.Error(err)) | ||
return nil, err | ||
} | ||
|
||
return res, nil | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[BestPractice]
The comment indicates this is for debug logging, but it's using `log.Info`. To avoid potentially noisy logs in a production environment for a read operation, it would be more appropriate to use `log.Debug` to align with the stated intent.
⚡ Committable suggestion
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
Context for Agents