diff --git a/Tiltfile b/Tiltfile
index 92530e9f3c5..e67e262b5ca 100644
--- a/Tiltfile
+++ b/Tiltfile
@@ -20,14 +20,14 @@ if config.tilt_subcommand == "ci":
   custom_build(
     'rust-log-service',
     'docker image tag rust-log-service:ci $EXPECTED_REF',
-    ['./rust/', './idl/', './Cargo.toml', './Cargo.lock'],
+    ['./rust/', './idl/', './Cargo.toml', './Cargo.lock', './go/pkg/sysdb/metastore/db/dbmodel/constants.go'],  # NOTE(review): the Go constants file is now an input of every Rust image below, presumably because the Rust build reads it to keep operator constants in sync - confirm
     disable_push=True
   )
 else:
   docker_build(
     'rust-log-service',
     '.',
-    only=["rust/", "idl/", "Cargo.toml", "Cargo.lock"],
+    only=["rust/", "idl/", "Cargo.toml", "Cargo.lock", "go/pkg/sysdb/metastore/db/dbmodel/constants.go"],
     dockerfile='./rust/Dockerfile',
     target='log_service'
   )
@@ -68,14 +68,14 @@ if config.tilt_subcommand == "ci":
   custom_build(
     'rust-frontend-service',
     'docker image tag rust-frontend-service:ci $EXPECTED_REF',
-    ['./rust/', './idl/', './Cargo.toml', './Cargo.lock'],
+    ['./rust/', './idl/', './Cargo.toml', './Cargo.lock', './go/pkg/sysdb/metastore/db/dbmodel/constants.go'],
     disable_push=True
   )
 else:
   docker_build(
     'rust-frontend-service',
     '.',
-    only=["rust/", "idl/", "Cargo.toml", "Cargo.lock"],
+    only=["rust/", "idl/", "Cargo.toml", "Cargo.lock", "go/pkg/sysdb/metastore/db/dbmodel/constants.go"],
     dockerfile='./rust/Dockerfile',
     target='cli'
   )
@@ -84,14 +84,14 @@ if config.tilt_subcommand == "ci":
   custom_build(
     'query-service',
     'docker image tag query-service:ci $EXPECTED_REF',
-    ['./rust/', './idl/', './Cargo.toml', './Cargo.lock'],
+    ['./rust/', './idl/', './Cargo.toml', './Cargo.lock', './go/pkg/sysdb/metastore/db/dbmodel/constants.go'],
     disable_push=True
   )
 else:
   docker_build(
     'query-service',
     '.',
-    only=["rust/", "idl/", "Cargo.toml", "Cargo.lock"],
+    only=["rust/", "idl/", "Cargo.toml", "Cargo.lock", "go/pkg/sysdb/metastore/db/dbmodel/constants.go"],
     dockerfile='./rust/Dockerfile',
     target='query_service'
   )
@@ -100,14 +100,14 @@ if config.tilt_subcommand == "ci":
   custom_build(
     'compaction-service',
     'docker image tag compactor-service:ci $EXPECTED_REF',
-    ['./rust/', './idl/', './Cargo.toml', './Cargo.lock'],
+    ['./rust/', './idl/', './Cargo.toml', './Cargo.lock', './go/pkg/sysdb/metastore/db/dbmodel/constants.go'],
     disable_push=True
   )
 else:
   docker_build(
     'compaction-service',
     '.',
-    only=["rust/", "idl/", "Cargo.toml", "Cargo.lock"],
+    only=["rust/", "idl/", "Cargo.toml", "Cargo.lock", "go/pkg/sysdb/metastore/db/dbmodel/constants.go"],
     dockerfile='./rust/Dockerfile',
     target='compaction_service'
   )
@@ -116,14 +116,14 @@ if config.tilt_subcommand == "ci":
   custom_build(
     'garbage-collector',
     'docker image tag garbage-collector:ci $EXPECTED_REF',
-    ['./rust/', './idl/', './Cargo.toml', './Cargo.lock'],
+    ['./rust/', './idl/', './Cargo.toml', './Cargo.lock', './go/pkg/sysdb/metastore/db/dbmodel/constants.go'],
     disable_push=True
   )
 else:
   docker_build(
     'garbage-collector',
     '.',
-    only=["rust/", "idl/", "Cargo.toml", "Cargo.lock"],
+    only=["rust/", "idl/", "Cargo.toml", "Cargo.lock", "go/pkg/sysdb/metastore/db/dbmodel/constants.go"],
     dockerfile='./rust/Dockerfile',
     target='garbage_collector'
   )
@@ -132,14 +132,14 @@ if config.tilt_subcommand == "ci":
   custom_build(
     'load-service',
     'docker image tag load-service:ci $EXPECTED_REF',
-    ['./rust/', './idl/', './Cargo.toml', './Cargo.lock'],
+    ['./rust/', './idl/', './Cargo.toml', './Cargo.lock', './go/pkg/sysdb/metastore/db/dbmodel/constants.go'],
     disable_push=True
   )
 else:
   docker_build(
     'load-service',
     '.',
-    only=["rust/", "idl/", "Cargo.toml", "Cargo.lock"],
+    only=["rust/", "idl/", "Cargo.toml", "Cargo.lock", "go/pkg/sysdb/metastore/db/dbmodel/constants.go"],
     dockerfile='./rust/Dockerfile',
     target='load_service'
   )
diff --git a/go/pkg/common/errors.go b/go/pkg/common/errors.go
index aaed9c6637c..c9ccb56ad64 100644
--- a/go/pkg/common/errors.go
+++ b/go/pkg/common/errors.go
@@ -48,6 +48,14 @@ var (
 	// Segment metadata errors
 	ErrUnknownSegmentMetadataType = errors.New("segment metadata value type not supported")
 
+	// Task errors
+	ErrTaskAlreadyExists = 
errors.New("the task that was being created already exists for this collection")
+	ErrTaskNotFound      = errors.New("the requested task was not found")
+	ErrInvalidTaskName   = errors.New("task name cannot start with reserved prefix '_deleted_'")
+
+	// Operator errors
+	ErrOperatorNotFound = errors.New("operator not found")
+
 	// Others
 	ErrCompactionOffsetSomehowAhead = errors.New("system invariant was violated. Compaction offset in sysdb should always be behind or equal to offset in log")
 )
diff --git a/go/pkg/sysdb/coordinator/task.go b/go/pkg/sysdb/coordinator/task.go
new file mode 100644
index 00000000000..6165b5df944
--- /dev/null
+++ b/go/pkg/sysdb/coordinator/task.go
@@ -0,0 +1,288 @@
+package coordinator
+
+import (
+	"context"
+	"errors"
+	"strings"
+	"time"
+
+	"github.com/chroma-core/chroma/go/pkg/common"
+	"github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb"
+	"github.com/chroma-core/chroma/go/pkg/sysdb/coordinator/model"
+	"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
+	"github.com/chroma-core/chroma/go/pkg/types"
+	"github.com/google/uuid"
+	"github.com/pingcap/log"
+	"go.uber.org/zap"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+	"google.golang.org/protobuf/proto"
+)
+
+// CreateTask registers a new task on an input collection and returns its ID.
+//
+// The call is idempotent on (input collection, task name): when a non-deleted
+// task with that name already exists, its ID is returned and nothing is written.
+// Otherwise the database, operator and input collection must exist and the
+// output collection name must be unused. All lookups and the insert run inside
+// one transaction.
+//
+// Errors: common.ErrInvalidTaskName, common.ErrDatabaseNotFound,
+// common.ErrOperatorNotFound, common.ErrCollectionNotFound,
+// common.ErrCollectionUniqueConstraintViolation, common.ErrTaskAlreadyExists
+// (a concurrent create won the race), or any error from the underlying stores.
+func (s *Coordinator) CreateTask(ctx context.Context, req *coordinatorpb.CreateTaskRequest) (*coordinatorpb.CreateTaskResponse, error) {
+	// Validate task name doesn't start with soft-deletion reserved prefix
+	if strings.HasPrefix(req.Name, "_deleted_") {
+		log.Error("CreateTask: task name cannot start with _deleted_")
+		return nil, common.ErrInvalidTaskName
+	}
+
+	var taskID uuid.UUID
+
+	// Execute all database operations in a transaction
+	err := s.catalog.txImpl.Transaction(ctx, func(txCtx context.Context) error {
+		// Check if task already exists
+		existingTask, err := s.catalog.metaDomain.TaskDb(txCtx).GetByName(req.InputCollectionId, req.Name)
+		if err != nil {
+			log.Error("CreateTask: failed to check task", zap.Error(err))
+			return err
+		}
+		if existingTask != nil {
+			log.Info("CreateTask: task already exists, returning existing")
+			taskID = existingTask.ID
+			return nil
+		}
+
+		// Generate new task UUID
+		taskID = uuid.New()
+		outputCollectionName := req.OutputCollectionName
+
+		// Look up database_id from databases table using database name and tenant
+		databases, err := s.catalog.metaDomain.DatabaseDb(txCtx).GetDatabases(req.TenantId, req.Database)
+		if err != nil {
+			log.Error("CreateTask: failed to get database", zap.Error(err))
+			return err
+		}
+		if len(databases) == 0 {
+			log.Error("CreateTask: database not found")
+			return common.ErrDatabaseNotFound
+		}
+
+		// Look up operator by name from the operators table
+		operator, err := s.catalog.metaDomain.OperatorDb(txCtx).GetByName(req.OperatorName)
+		if err != nil {
+			log.Error("CreateTask: failed to get operator", zap.Error(err))
+			return err
+		}
+		if operator == nil {
+			log.Error("CreateTask: operator not found", zap.String("operator_name", req.OperatorName))
+			return common.ErrOperatorNotFound
+		}
+		operatorID := operator.OperatorID
+
+		// Generate UUIDv7 for time-ordered nonce
+		nextNonce, err := uuid.NewV7()
+		if err != nil {
+			return err
+		}
+
+		// TODO(tanujnay112): Can combine the two collection checks into one
+		// Check if input collection exists
+		collections, err := s.catalog.metaDomain.CollectionDb(txCtx).GetCollections([]string{req.InputCollectionId}, nil, req.TenantId, req.Database, nil, nil, false)
+		if err != nil {
+			log.Error("CreateTask: failed to get input collection", zap.Error(err))
+			return err
+		}
+		if len(collections) == 0 {
+			log.Error("CreateTask: input collection not found")
+			return common.ErrCollectionNotFound
+		}
+
+		// Check if output collection already exists
+		existingOutputCollections, err := s.catalog.metaDomain.CollectionDb(txCtx).GetCollections(nil, &outputCollectionName, req.TenantId, req.Database, nil, nil, false)
+		if err != nil {
+			log.Error("CreateTask: failed to check output collection", zap.Error(err))
+			return err
+		}
+		if len(existingOutputCollections) > 0 {
+			log.Error("CreateTask: output collection already exists")
+			return common.ErrCollectionUniqueConstraintViolation
+		}
+
+		now := time.Now()
+		task := &dbmodel.Task{
+			ID:                   taskID,
+			Name:                 req.Name,
+			TenantID:             req.TenantId,
+			DatabaseID:           databases[0].ID,
+			InputCollectionID:    req.InputCollectionId,
+			OutputCollectionName: req.OutputCollectionName,
+			OperatorID:           operatorID,
+			OperatorParams:       req.Params,
+			CompletionOffset:     0,
+			LastRun:              nil,
+			NextRun:              nil, // left unset (nil) here; presumably populated later by the task scheduler - TODO confirm
+			MinRecordsForTask:    int64(req.MinRecordsForTask),
+			CurrentAttempts:      0,
+			CreatedAt:            now,
+			UpdatedAt:            now,
+			NextNonce:            nextNonce,
+			OldestWrittenNonce:   nil,
+		}
+
+		// Try to insert task into database
+		err = s.catalog.metaDomain.TaskDb(txCtx).Insert(task)
+		if err != nil {
+			// A unique-constraint violation means a concurrent request created the task first.
+			if errors.Is(err, common.ErrTaskAlreadyExists) {
+				log.Error("CreateTask: task already exists")
+				return common.ErrTaskAlreadyExists
+			}
+			log.Error("CreateTask: failed to insert task", zap.Error(err))
+			return err
+		}
+
+		log.Info("Task created successfully", zap.String("task_id", taskID.String()), zap.String("name", req.Name), zap.String("output_collection_name", outputCollectionName))
+		return nil
+	})
+
+	if err != nil {
+		return nil, err
+	}
+
+	return &coordinatorpb.CreateTaskResponse{
+		TaskId: taskID.String(),
+	}, nil
+}
+
+// GetTaskByName looks up a non-deleted task by input collection and name and
+// returns it together with the name of its operator.
+//
+// It returns common.ErrTaskNotFound when no such task exists and
+// common.ErrOperatorNotFound when the task's operator is missing from the
+// operators table.
+func (s *Coordinator) GetTaskByName(ctx context.Context, req *coordinatorpb.GetTaskByNameRequest) (*coordinatorpb.GetTaskByNameResponse, error) {
+	// TODO: the task lookup and the operator lookup below could be a single JOIN
+	task, err := s.catalog.metaDomain.TaskDb(ctx).GetByName(req.InputCollectionId, req.TaskName)
+	if err != nil {
+		return nil, err
+	}
+
+	// A missing task is reported as an error, not as an empty response
+	if task == nil {
+		return nil, common.ErrTaskNotFound
+	}
+
+	// Look up operator name from operators table
+	operator, err := s.catalog.metaDomain.OperatorDb(ctx).GetByID(task.OperatorID)
+	if err != nil {
+		log.Error("GetTaskByName: failed to get operator", zap.Error(err))
+		return nil, err
+	}
+	if operator == nil {
+		log.Error("GetTaskByName: operator not found", zap.String("operator_id", task.OperatorID.String()))
+		return nil, common.ErrOperatorNotFound
+	}
+
+	// Record which task was resolved (emitted at info level)
+	log.Info("Found task", zap.String("task_id", task.ID.String()), zap.String("name", task.Name), zap.String("input_collection_id", task.InputCollectionID), zap.String("output_collection_name", task.OutputCollectionName))
+
+	// Convert task to response
+	response := &coordinatorpb.GetTaskByNameResponse{
+		TaskId:               proto.String(task.ID.String()),
+		Name:                 proto.String(task.Name),
+		OperatorName:         proto.String(operator.OperatorName),
+		InputCollectionId:    proto.String(task.InputCollectionID),
+		OutputCollectionName: proto.String(task.OutputCollectionName),
+		Params:               proto.String(task.OperatorParams),
+		CompletionOffset:     proto.Int64(task.CompletionOffset),
+		MinRecordsForTask:    proto.Uint64(uint64(task.MinRecordsForTask)),
+	}
+	// Add output_collection_id if it's set
+	if task.OutputCollectionID != nil {
+		response.OutputCollectionId = task.OutputCollectionID
+	}
+	return response, nil
+}
+
+// DeleteTask soft-deletes the named task. When req.DeleteOutput is set and the
+// task has an output collection ID, that collection is soft-deleted first.
+//
+// A failure to delete the output collection is logged and ignored so that the
+// task is still removed; the two deletes do not share a transaction. A missing
+// task is reported as a gRPC NotFound status.
+func (s *Coordinator) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRequest) (*coordinatorpb.DeleteTaskResponse, error) {
+	// First get the task to check if we need to delete the output collection
+	task, err := s.catalog.metaDomain.TaskDb(ctx).GetByName(req.InputCollectionId, req.TaskName)
+	if err != nil {
+		log.Error("DeleteTask: failed to get task", zap.Error(err))
+		return nil, err
+	}
+	if task == nil {
+		log.Error("DeleteTask: task not found")
+		return nil, status.Errorf(codes.NotFound, "task not found")
+	}
+
+	// If delete_output is true and output_collection_id is set, soft-delete the output collection
+	if req.DeleteOutput && task.OutputCollectionID != nil && *task.OutputCollectionID != "" {
+		collectionUUID, err := types.ToUniqueID(task.OutputCollectionID)
+		if err != nil {
+			log.Error("DeleteTask: invalid output_collection_id", zap.Error(err))
+			return nil, status.Errorf(codes.InvalidArgument, "invalid output_collection_id: %v", err)
+		}
+
+		deleteCollection := &model.DeleteCollection{
+			ID:       collectionUUID,
+			TenantID: task.TenantID,
+			// NOTE(review): this passes the task's database ID in a field named
+			// DatabaseName - confirm SoftDeleteCollection resolves it correctly.
+			DatabaseName: task.DatabaseID,
+		}
+
+		err = s.SoftDeleteCollection(ctx, deleteCollection)
+		if err != nil {
+			// Log but don't fail - we still want to delete the task
+			log.Warn("DeleteTask: failed to delete output collection", zap.Error(err), zap.String("collection_id", *task.OutputCollectionID))
+		} else {
+			log.Info("DeleteTask: deleted output collection", zap.String("collection_id", *task.OutputCollectionID))
+		}
+	}
+
+	// Now soft-delete the task
+	err = s.catalog.metaDomain.TaskDb(ctx).SoftDelete(req.InputCollectionId, req.TaskName)
+	if err != nil {
+		log.Error("DeleteTask failed", zap.Error(err))
+		return nil, err
+	}
+
+	log.Info("Task deleted", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName))
+
+	return &coordinatorpb.DeleteTaskResponse{
+		Success: true,
+	}, nil
+}
+
+// GetOperators returns the ID and name of every operator stored in the operators table.
+func (s *Coordinator) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) {
+	operators, err := s.catalog.metaDomain.OperatorDb(ctx).GetAll()
+	if err != nil {
+		log.Error("GetOperators failed", zap.Error(err))
+		return nil, err
+	}
+
+	// Convert to proto response
+	protoOperators := make([]*coordinatorpb.Operator, len(operators))
+	for i, op := range operators {
+		protoOperators[i] = &coordinatorpb.Operator{
+			Id:   op.OperatorID.String(),
+			Name: op.OperatorName,
+		}
+	}
+
+	log.Info("GetOperators succeeded", zap.Int("count", len(operators)))
+
+	return &coordinatorpb.GetOperatorsResponse{
+		Operators: protoOperators,
+	}, nil
+}
diff --git a/go/pkg/sysdb/grpc/task_service.go b/go/pkg/sysdb/grpc/task_service.go
new file mode 100644
index 00000000000..9f96f7faf2c
--- /dev/null
+++ b/go/pkg/sysdb/grpc/task_service.go
@@ -0,0 +1,61 @@
+package grpc
+
+import (
+	"context"
+
+	"github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb"
+	"github.com/pingcap/log"
+	"go.uber.org/zap"
+)
+
+// CreateTask is the gRPC handler for task creation: it logs the call and delegates to the coordinator.
+func (s *Server) CreateTask(ctx context.Context, req *coordinatorpb.CreateTaskRequest) (*coordinatorpb.CreateTaskResponse, error) {
+	log.Info("CreateTask", zap.String("name", req.Name), zap.String("operator_name", req.OperatorName))
+
+	res, err := s.coordinator.CreateTask(ctx, req)
+	if err != nil {
+		log.Error("CreateTask failed", zap.Error(err))
+		return nil, err
+	}
+
+	return res, nil
+}
+
+// GetTaskByName is the gRPC handler for task lookup: it logs the call and delegates to the coordinator.
+func (s *Server) GetTaskByName(ctx context.Context, req *coordinatorpb.GetTaskByNameRequest) (*coordinatorpb.GetTaskByNameResponse, error) {
+	log.Info("GetTaskByName", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName))
+
+	res, err := s.coordinator.GetTaskByName(ctx, req)
+	if err != nil {
+		log.Error("GetTaskByName failed", zap.Error(err))
+		return nil, err
+	}
+
+	return res, nil
+}
+
+// DeleteTask is the gRPC handler for task deletion: it logs the call and delegates to the coordinator.
+func (s *Server) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRequest) (*coordinatorpb.DeleteTaskResponse, error) {
+	log.Info("DeleteTask", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName))
+
+	res, err := s.coordinator.DeleteTask(ctx, req)
+	if err != nil {
+		log.Error("DeleteTask failed", zap.Error(err))
+		return nil, err
+	}
+
+	return res, nil
+}
+
+// GetOperators is the gRPC handler for listing operators: it logs the call and delegates to the coordinator.
+func (s *Server) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) {
+	log.Info("GetOperators")
+
+	res, err := s.coordinator.GetOperators(ctx, req)
+	if err != nil {
+		log.Error("GetOperators failed", zap.Error(err))
+		return nil, 
err
+	}
+
+	return res, nil
+}
diff --git a/go/pkg/sysdb/metastore/db/dao/common.go b/go/pkg/sysdb/metastore/db/dao/common.go
index 2962f01209d..ea3ec6a7ebc 100644
--- a/go/pkg/sysdb/metastore/db/dao/common.go
+++ b/go/pkg/sysdb/metastore/db/dao/common.go
@@ -36,3 +36,13 @@ func (*MetaDomain) SegmentDb(ctx context.Context) dbmodel.ISegmentDb {
 func (*MetaDomain) SegmentMetadataDb(ctx context.Context) dbmodel.ISegmentMetadataDb {
 	return &segmentMetadataDb{dbcore.GetDB(ctx)}
 }
+
+// TaskDb returns an ITaskDb backed by the gorm handle that dbcore.GetDB yields for ctx.
+func (*MetaDomain) TaskDb(ctx context.Context) dbmodel.ITaskDb {
+	return &taskDb{dbcore.GetDB(ctx)}
+}
+
+// OperatorDb returns an IOperatorDb backed by the gorm handle that dbcore.GetDB yields for ctx.
+func (*MetaDomain) OperatorDb(ctx context.Context) dbmodel.IOperatorDb {
+	return &operatorDb{dbcore.GetDB(ctx)}
+}
diff --git a/go/pkg/sysdb/metastore/db/dao/operator.go b/go/pkg/sysdb/metastore/db/dao/operator.go
new file mode 100644
index 00000000000..ad91890587a
--- /dev/null
+++ b/go/pkg/sysdb/metastore/db/dao/operator.go
@@ -0,0 +1,52 @@
+package dao
+
+import (
+	"errors"
+
+	"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
+	"github.com/google/uuid"
+	"github.com/pingcap/log"
+	"go.uber.org/zap"
+	"gorm.io/gorm"
+)
+
+// operatorDb is the gorm-backed implementation of dbmodel.IOperatorDb.
+type operatorDb struct {
+	db *gorm.DB
+}
+
+var _ dbmodel.IOperatorDb = &operatorDb{}
+
+// firstMatch returns the first operator selected by query. A missing row yields
+// (nil, nil); any other failure is logged under logMsg and returned.
+func (s *operatorDb) firstMatch(query *gorm.DB, logMsg string) (*dbmodel.Operator, error) {
+	var operator dbmodel.Operator
+	if err := query.First(&operator).Error; err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return nil, nil
+		}
+		log.Error(logMsg, zap.Error(err))
+		return nil, err
+	}
+	return &operator, nil
+}
+
+// GetByName returns the operator with the given name, or (nil, nil) when there is none.
+func (s *operatorDb) GetByName(operatorName string) (*dbmodel.Operator, error) {
+	return s.firstMatch(s.db.Where("operator_name = ?", operatorName), "GetOperatorByName failed")
+}
+
+// GetByID returns the operator with the given ID, or (nil, nil) when there is none.
+func (s *operatorDb) GetByID(operatorID uuid.UUID) (*dbmodel.Operator, error) {
+	return s.firstMatch(s.db.Where("operator_id = ?", operatorID), "GetOperatorByID failed")
+}
+
+// GetAll returns every row of the operators table.
+func (s *operatorDb) GetAll() ([]*dbmodel.Operator, error) {
+	var operators []*dbmodel.Operator
+	if err := s.db.Find(&operators).Error; err != nil {
+		log.Error("GetAllOperators failed", zap.Error(err))
+		return nil, err
+	}
+	return operators, nil
+}
diff --git a/go/pkg/sysdb/metastore/db/dao/task.go b/go/pkg/sysdb/metastore/db/dao/task.go
new file mode 100644
index 00000000000..f60c64b7c49
--- /dev/null
+++ b/go/pkg/sysdb/metastore/db/dao/task.go
@@ -0,0 +1,84 @@
+package dao
+
+import (
+	"errors"
+
+	"github.com/chroma-core/chroma/go/pkg/common"
+	"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
+	"github.com/jackc/pgx/v5/pgconn"
+	"github.com/pingcap/log"
+	"go.uber.org/zap"
+	"gorm.io/gorm"
+)
+
+// taskDb is the gorm-backed implementation of dbmodel.ITaskDb.
+type taskDb struct {
+	db *gorm.DB
+}
+
+var _ dbmodel.ITaskDb = &taskDb{}
+
+// DeleteAll removes every task row. The always-true WHERE clause is there because
+// gorm refuses a batch delete that has no conditions.
+func (s *taskDb) DeleteAll() error {
+	return s.db.Where("1 = 1").Delete(&dbmodel.Task{}).Error
+}
+
+// Insert creates the task row. A PostgreSQL unique-constraint violation
+// (SQLSTATE 23505) is translated to common.ErrTaskAlreadyExists; every other
+// error is returned unchanged.
+func (s *taskDb) Insert(task *dbmodel.Task) error {
+	err := s.db.Create(task).Error
+	if err == nil {
+		return nil
+	}
+	log.Error("insert task failed", zap.Error(err))
+
+	var pgErr *pgconn.PgError
+	if errors.As(err, &pgErr) && pgErr.Code == "23505" { // unique_violation
+		return common.ErrTaskAlreadyExists
+	}
+	return err
+}
+
+// GetByName returns the non-deleted task with the given name on the given input
+// collection, or (nil, nil) when there is none.
+func (s *taskDb) GetByName(inputCollectionID string, taskName string) (*dbmodel.Task, error) {
+	var task dbmodel.Task
+	err := s.db.
+		Where("input_collection_id = ?", inputCollectionID).
+		Where("task_name = ?", taskName).
+		Where("is_deleted = ?", false).
+		First(&task).Error
+	switch {
+	case err == nil:
+		return &task, nil
+	case errors.Is(err, gorm.ErrRecordNotFound):
+		return nil, nil
+	default:
+		log.Error("GetTaskByName failed", zap.Error(err))
+		return nil, err
+	}
+}
+
+// SoftDelete marks the named task as deleted and renames it to
+// _deleted_<task_name>_<input_collection_id>_<task_id>, all in one UPDATE.
+//
+// The call is idempotent: when no non-deleted task matches (never existed or
+// already deleted) zero rows are updated and nil is returned.
+func (s *taskDb) SoftDelete(inputCollectionID string, taskName string) error {
+	result := s.db.Exec(`
+		UPDATE tasks
+		SET task_name = CONCAT('_deleted_', task_name, '_', input_collection_id, '_', task_id::text),
+			is_deleted = true, updated_at = NOW()
+		WHERE input_collection_id = ?
+			AND task_name = ?
+			AND is_deleted = false
+	`, inputCollectionID, taskName)
+	if result.Error != nil {
+		log.Error("SoftDelete failed", zap.Error(result.Error))
+		return result.Error
+	}
+	// result.RowsAffected is deliberately not inspected: zero rows is a successful no-op.
+	return nil
+}
diff --git a/go/pkg/sysdb/metastore/db/dao/task_test.go b/go/pkg/sysdb/metastore/db/dao/task_test.go
new file mode 100644
index 00000000000..94de1147608
--- /dev/null
+++ b/go/pkg/sysdb/metastore/db/dao/task_test.go
@@ -0,0 +1,356 @@
+package dao
+
+import (
+	"testing"
+
+	"github.com/chroma-core/chroma/go/pkg/common"
+	"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbcore"
+	"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
+	"github.com/google/uuid"
+	"github.com/pingcap/log"
+	"github.com/stretchr/testify/suite"
+	"gorm.io/gorm"
+)
+
+type TaskDbTestSuite struct {
+	suite.Suite
+	db *gorm.DB
+	Db *taskDb
+	t  *testing.T
+}
+
+func (suite *TaskDbTestSuite) SetupSuite() {
+	log.Info("setup suite")
+	suite.db, _ = dbcore.ConfigDatabaseForTesting()
+	suite.Db = &taskDb{
+		db: suite.db,
+	}
+
+	// Seed operators for tests - these must match dbmodel/constants.go
+	// This also serves as a validation that constants are correct
+	operators := []dbmodel.Operator{
+		{
+			OperatorID:    
dbmodel.OperatorRecordCounter,
+			OperatorName:  dbmodel.OperatorNameRecordCounter,
+			IsIncremental: dbmodel.OperatorRecordCounterIsIncremental,
+			ReturnType:    dbmodel.OperatorRecordCounterReturnType,
+		},
+	}
+	for _, op := range operators {
+		suite.Require().NoError(suite.db.Where(dbmodel.Operator{OperatorID: op.OperatorID}).FirstOrCreate(&op).Error) // fail fast when seeding fails
+	}
+}
+
+func (suite *TaskDbTestSuite) SetupTest() {
+	log.Info("setup test")
+}
+
+func (suite *TaskDbTestSuite) TearDownTest() {
+	log.Info("teardown test")
+}
+
+func (suite *TaskDbTestSuite) TestTaskDb_Insert() {
+	taskID := uuid.New()
+	operatorID := dbmodel.OperatorRecordCounter
+	nextNonce := uuid.Must(uuid.NewV7())
+
+	task := &dbmodel.Task{
+		ID:                   taskID,
+		Name:                 "test-insert-task",
+		OperatorID:           operatorID,
+		InputCollectionID:    "input_col_id",
+		OutputCollectionName: "output_col_name",
+		OperatorParams:       "{}",
+		TenantID:             "tenant1",
+		DatabaseID:           "db1",
+		MinRecordsForTask:    100,
+		NextNonce:            nextNonce,
+	}
+
+	err := suite.Db.Insert(task)
+	suite.Require().NoError(err)
+
+	// Verify task was inserted
+	var retrieved dbmodel.Task
+	err = suite.db.Where("task_name = ? AND tenant_id = ? AND database_id = ?", "test-insert-task", "tenant1", "db1").First(&retrieved).Error
+	suite.Require().NoError(err)
+	suite.Require().Equal(task.Name, retrieved.Name)
+	suite.Require().Equal(task.OperatorID, retrieved.OperatorID)
+	suite.Require().False(retrieved.IsDeleted)
+
+	// Cleanup
+	suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
+}
+
+func (suite *TaskDbTestSuite) TestTaskDb_Insert_DuplicateName() {
+	taskID1 := uuid.New()
+	operatorID1 := dbmodel.OperatorRecordCounter
+	nextNonce1 := uuid.Must(uuid.NewV7())
+
+	task1 := &dbmodel.Task{
+		ID:                   taskID1,
+		Name:                 "test-task-1",
+		OperatorID:           operatorID1,
+		InputCollectionID:    "input1",
+		OutputCollectionName: "output1",
+		OperatorParams:       "{}",
+		TenantID:             "tenant1",
+		DatabaseID:           "db1",
+		MinRecordsForTask:    100,
+		NextNonce:            nextNonce1,
+	}
+
+	err := suite.Db.Insert(task1)
+	suite.Require().NoError(err)
+
+	// Try to insert duplicate (same tenant, database, and name)
+	taskID2 := uuid.New()
+	operatorID2 := dbmodel.OperatorRecordCounter
+	nextNonce2 := uuid.Must(uuid.NewV7())
+
+	task2 := &dbmodel.Task{
+		ID:                   taskID2,
+		Name:                 "test-task-1", // Same name as task1
+		OperatorID:           operatorID2,
+		InputCollectionID:    "input1",
+		OutputCollectionName: "output1",
+		OperatorParams:       "{}",
+		TenantID:             "tenant1",
+		DatabaseID:           "db1",
+		MinRecordsForTask:    100,
+		NextNonce:            nextNonce2,
+	}
+
+	err = suite.Db.Insert(task2)
+	suite.Require().Error(err)
+	suite.Require().Equal(common.ErrTaskAlreadyExists, err)
+
+	// Cleanup
+	suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task1.ID)
+}
+
+func (suite *TaskDbTestSuite) TestTaskDb_GetByName() {
+	taskID := uuid.New()
+	operatorID := dbmodel.OperatorRecordCounter
+	nextNonce := uuid.Must(uuid.NewV7())
+
+	// Insert a task
+	task := &dbmodel.Task{
+		ID:                   taskID,
+		Name:                 "test-get-task",
+		OperatorID:           operatorID,
+		InputCollectionID:    "input_col_id",
+		OutputCollectionName: "output_col_name",
+		OperatorParams:       "{}",
+		TenantID:             "tenant1",
+		DatabaseID:           "db1",
+		MinRecordsForTask:    100,
+		NextNonce:            nextNonce,
+	}
+
+	err := suite.Db.Insert(task)
+	suite.Require().NoError(err)
+
+	// Retrieve by name
+	retrieved, err := suite.Db.GetByName("input_col_id", "test-get-task")
+	suite.Require().NoError(err)
+	suite.Require().NotNil(retrieved)
+	suite.Require().Equal(task.ID, retrieved.ID)
+	suite.Require().Equal(task.Name, retrieved.Name)
+	suite.Require().Equal(task.OperatorID, retrieved.OperatorID)
+
+	// Cleanup
+	suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
+}
+
+func (suite *TaskDbTestSuite) TestTaskDb_GetByName_NotFound() {
+	// Try to get non-existent task
+	retrieved, err := suite.Db.GetByName("input_col_id", "nonexistent-task")
+	suite.Require().NoError(err)
+	suite.Require().Nil(retrieved)
+}
+
+func (suite *TaskDbTestSuite) TestTaskDb_GetByName_IgnoresDeleted() {
+	taskID := uuid.New()
+	operatorID := dbmodel.OperatorRecordCounter
+	nextNonce := uuid.Must(uuid.NewV7())
+
+	// Insert a task
+	task := &dbmodel.Task{
+		ID:                   taskID,
+		Name:                 "test-deleted-task",
+		OperatorID:           operatorID,
+		InputCollectionID:    "input1",
+		OutputCollectionName: "output1",
+		OperatorParams:       "{}",
+		TenantID:             "tenant1",
+		DatabaseID:           "db1",
+		MinRecordsForTask:    100,
+		NextNonce:            nextNonce,
+	}
+
+	err := suite.Db.Insert(task)
+	suite.Require().NoError(err)
+
+	// Soft delete it
+	err = suite.Db.SoftDelete("input1", "test-deleted-task")
+	suite.Require().NoError(err)
+
+	// GetByName should not return it
+	retrieved, err := suite.Db.GetByName("input1", "test-deleted-task")
+	suite.Require().NoError(err)
+	suite.Require().Nil(retrieved)
+
+	// Cleanup
+	suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
+}
+
+func (suite *TaskDbTestSuite) TestTaskDb_SoftDelete() {
+	taskID := uuid.New()
+	operatorID := dbmodel.OperatorRecordCounter
+	nextNonce := uuid.Must(uuid.NewV7())
+
+	// Insert a task
+	task := &dbmodel.Task{
+		ID:                   taskID,
+		Name:                 "test-soft-delete",
+		OperatorID:           operatorID,
+		InputCollectionID:    "input1",
+		OutputCollectionName: "output1",
+		OperatorParams:       "{}",
+		TenantID:             "tenant1",
+		DatabaseID:           "db1",
+		MinRecordsForTask:    100,
+		NextNonce:            nextNonce,
+	}
+
+	err := suite.Db.Insert(task)
+	suite.Require().NoError(err)
+
+	// Soft delete
+	err = suite.Db.SoftDelete("input1", "test-soft-delete")
+	suite.Require().NoError(err)
+
+	// Verify task is marked as deleted in DB
+	var retrieved dbmodel.Task
+	err = suite.db.Unscoped().Where("task_id = ?", task.ID).First(&retrieved).Error
+	suite.Require().NoError(err)
+	suite.Require().True(retrieved.IsDeleted)
+
+	// Cleanup
+	suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
+}
+
+func (suite *TaskDbTestSuite) TestTaskDb_SoftDelete_NotFound() {
+	// Try to delete non-existent task - should succeed but do nothing
+	err := suite.Db.SoftDelete("input1", "nonexistent-task")
+	suite.Require().NoError(err)
+}
+
+func (suite *TaskDbTestSuite) TestTaskDb_DeleteAll() {
+	operatorID := dbmodel.OperatorRecordCounter
+
+	// Insert multiple tasks
+	tasks := []*dbmodel.Task{
+		{
+			ID:                   uuid.New(),
+			Name:                 "task1",
+			OperatorID:           operatorID,
+			InputCollectionID:    "input1",
+			OutputCollectionName: "output1",
+			OperatorParams:       "{}",
+			TenantID:             "tenant1",
+			DatabaseID:           "db-delete-all",
+			MinRecordsForTask:    100,
+			NextNonce:            uuid.Must(uuid.NewV7()),
+		},
+		{
+			ID:                   uuid.New(),
+			Name:                 "task2",
+			OperatorID:           operatorID,
+			InputCollectionID:    "input2",
+			OutputCollectionName: "output2",
+			OperatorParams:       "{}",
+			TenantID:             "tenant1",
+			DatabaseID:           "db-delete-all",
+			MinRecordsForTask:    100,
+			NextNonce:            uuid.Must(uuid.NewV7()),
+		},
+		{
+			ID:                   uuid.New(),
+			Name:                 "task3",
+			OperatorID:           operatorID,
+			InputCollectionID:    "input3",
+			OutputCollectionName: "output3",
+			OperatorParams:       "{}",
+			TenantID:             "tenant1",
+			DatabaseID:           "db-delete-all",
+			MinRecordsForTask:    100,
+			NextNonce:            uuid.Must(uuid.NewV7()),
+		},
+	}
+
+	for _, task := range tasks {
+		err := suite.Db.Insert(task)
+		suite.Require().NoError(err)
+	}
+
+	// Delete all tasks
+	err := suite.Db.DeleteAll()
+	suite.Require().NoError(err)
+
+	// Verify all tasks are deleted
+	for _, task := range tasks {
+		retrieved, err := suite.Db.GetByName(task.InputCollectionID, task.Name)
+		suite.Require().NoError(err)
+		suite.Require().Nil(retrieved)
+	}
+
+	// Cleanup
+	for _, task := range tasks {
+		suite.db.Unscoped().Delete(&dbmodel.Task{}, "task_id = ?", task.ID)
+	}
+}
+
+// TestOperatorConstantsMatchSeededDatabase verifies that operator constants in
+// dbmodel/constants.go match what we seed in the test database (which should match migrations).
+// This catches drift between constants and migrations at test time.
+func (suite *TaskDbTestSuite) TestOperatorConstantsMatchSeededDatabase() {
+	// Map of operator names to expected UUIDs from constants.go
+	// When you add a new operator:
+	// 1. Add to migration
+	// 2. Add to dbmodel/constants.go
+	// 3. Add to SetupSuite() seed list
+	// 4. Add here for validation
+	expectedOperators := map[string]uuid.UUID{
+		dbmodel.OperatorNameRecordCounter: dbmodel.OperatorRecordCounter,
+	}
+
+	// Verify count matches
+	var actualCount int64
+	err := suite.db.Model(&dbmodel.Operator{}).Count(&actualCount).Error
+	suite.Require().NoError(err, "Failed to count operators")
+
+	expectedCount := int64(len(expectedOperators))
+	suite.Require().Equal(expectedCount, actualCount,
+		"Operator count mismatch. Expected: %d, Actual: %d. "+
+			"Did you forget to update SetupSuite() after adding a new operator?",
+		expectedCount, actualCount)
+
+	// Verify each operator
+	for operatorName, expectedUUID := range expectedOperators {
+		var operator dbmodel.Operator
+		err := suite.db.Where("operator_name = ?", operatorName).First(&operator).Error
+		suite.Require().NoError(err, "Operator '%s' not found", operatorName)
+
+		suite.Require().Equal(expectedUUID, operator.OperatorID,
+			"Operator '%s' UUID mismatch. Constant: %s, DB: %s",
+			operatorName, expectedUUID, operator.OperatorID)
+	}
+}
+
+func TestTaskDbTestSuite(t *testing.T) {
+	testSuite := new(TaskDbTestSuite)
+	testSuite.t = t
+	suite.Run(t, testSuite)
+}
diff --git a/go/pkg/sysdb/metastore/db/dbcore/core.go b/go/pkg/sysdb/metastore/db/dbcore/core.go
index 7c3e22a1b55..22bd80776a6 100644
--- a/go/pkg/sysdb/metastore/db/dbcore/core.go
+++ b/go/pkg/sysdb/metastore/db/dbcore/core.go
@@ -225,6 +225,14 @@ func CreateTestTables(db *gorm.DB) {
 	if !tableExist {
 		db.Migrator().CreateTable(&dbmodel.Segment{})
 	}
+	tableExist = db.Migrator().HasTable(&dbmodel.Operator{})
+	if !tableExist {
+		db.Migrator().CreateTable(&dbmodel.Operator{})
+	}
+	tableExist = db.Migrator().HasTable(&dbmodel.Task{})
+	if !tableExist {
+		db.Migrator().CreateTable(&dbmodel.Task{})
+	}
 
 	// create default tenant and database
 	CreateDefaultTenantAndDatabase(db)
diff --git a/go/pkg/sysdb/metastore/db/dbmodel/common.go b/go/pkg/sysdb/metastore/db/dbmodel/common.go
index 3ad50e2933c..cdee7a976cd 100644
--- a/go/pkg/sysdb/metastore/db/dbmodel/common.go
+++ b/go/pkg/sysdb/metastore/db/dbmodel/common.go
@@ -14,6 +14,8 @@ type IMetaDomain interface {
 	CollectionMetadataDb(ctx context.Context) ICollectionMetadataDb
 	SegmentDb(ctx context.Context) ISegmentDb
 	SegmentMetadataDb(ctx context.Context) ISegmentMetadataDb
+	TaskDb(ctx context.Context) ITaskDb
+	OperatorDb(ctx context.Context) IOperatorDb
 }
 
 //go:generate mockery --name=ITransaction
diff --git a/go/pkg/sysdb/metastore/db/dbmodel/constants.go b/go/pkg/sysdb/metastore/db/dbmodel/constants.go
new file mode 100644
index 00000000000..a1f8610d3c1
--- /dev/null
+++ b/go/pkg/sysdb/metastore/db/dbmodel/constants.go
@@ -0,0 +1,35 @@
+package dbmodel
+
+import "github.com/google/uuid"
+
+// Operator IDs that are pre-populated in the database.
+//
+// IMPORTANT: These constants must stay in sync with:
+// 1. Database migrations that insert operators (go/pkg/sysdb/metastore/db/migrations/*.sql)
+// 2. Rust constants (rust/types/src/operators.rs)
+//
+// When adding a new operator:
+// 1. Create a migration to INSERT the operator with a UUID
+// 2. Add the UUID constant here
+// 3. Add the name constant below
+// 4. Add matching constants to rust/types/src/operators.rs
+var (
+	// OperatorRecordCounter is the UUID for the built-in record_counter operator
+	// Must match: migration 20250930122132.sql and rust/types/src/operators.rs::OPERATOR_RECORD_COUNTER_ID
+	OperatorRecordCounter = uuid.MustParse("ccf2e3ba-633e-43ba-9394-46b0c54c61e3")
+)
+
+// Names of the pre-populated operators.
+// Must stay in sync with database and Rust constants.
+const (
+	// OperatorNameRecordCounter must match rust/types/src/operators.rs::OPERATOR_RECORD_COUNTER_NAME
+	OperatorNameRecordCounter = "record_counter"
+)
+
+// Operator metadata
+const (
+	// OperatorRecordCounterIsIncremental indicates record_counter is an incremental operator
+	OperatorRecordCounterIsIncremental = true
+	// OperatorRecordCounterReturnType is the JSON schema for record_counter's return type
+	OperatorRecordCounterReturnType = `{"type": "object", "properties": {"count": {"type": "integer", "description": "Number of records processed"}}}`
+)
diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go
index 97143e3b556..931245a64e5 100644
--- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go
+++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go
@@ -1,4 +1,4 @@
-// Code generated by mockery v2.53.3. DO NOT EDIT.
+// Code generated by mockery v2.53.5. DO NOT EDIT.
 
package mocks diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionMetadataDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionMetadataDb.go index 5aa9ca82b83..5de747d50ae 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionMetadataDb.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionMetadataDb.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.46.2. DO NOT EDIT. +// Code generated by mockery v2.53.5. DO NOT EDIT. package mocks @@ -12,7 +12,7 @@ type ICollectionMetadataDb struct { mock.Mock } -// DeleteAll provides a mock function with given fields: +// DeleteAll provides a mock function with no fields func (_m *ICollectionMetadataDb) DeleteAll() error { ret := _m.Called() diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/IDatabaseDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/IDatabaseDb.go index 90d4d8dc782..e4d5cae2e39 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/IDatabaseDb.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/IDatabaseDb.go @@ -1,10 +1,12 @@ -// Code generated by mockery v2.46.2. DO NOT EDIT. +// Code generated by mockery v2.53.5. DO NOT EDIT. 
package mocks import ( dbmodel "github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel" mock "github.com/stretchr/testify/mock" + + time "time" ) // IDatabaseDb is an autogenerated mock type for the IDatabaseDb type @@ -12,25 +14,7 @@ type IDatabaseDb struct { mock.Mock } -// Delete provides a mock function with given fields: databaseID -func (_m *IDatabaseDb) Delete(databaseID string) error { - ret := _m.Called(databaseID) - - if len(ret) == 0 { - panic("no return value specified for Delete") - } - - var r0 error - if rf, ok := ret.Get(0).(func(string) error); ok { - r0 = rf(databaseID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// DeleteAll provides a mock function with given fields: +// DeleteAll provides a mock function with no fields func (_m *IDatabaseDb) DeleteAll() error { ret := _m.Called() @@ -48,29 +32,27 @@ func (_m *IDatabaseDb) DeleteAll() error { return r0 } -// GetAllDatabases provides a mock function with given fields: -func (_m *IDatabaseDb) GetAllDatabases() ([]*dbmodel.Database, error) { - ret := _m.Called() +// FinishDatabaseDeletion provides a mock function with given fields: cutoffTime +func (_m *IDatabaseDb) FinishDatabaseDeletion(cutoffTime time.Time) (uint64, error) { + ret := _m.Called(cutoffTime) if len(ret) == 0 { - panic("no return value specified for GetAllDatabases") + panic("no return value specified for FinishDatabaseDeletion") } - var r0 []*dbmodel.Database + var r0 uint64 var r1 error - if rf, ok := ret.Get(0).(func() ([]*dbmodel.Database, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(time.Time) (uint64, error)); ok { + return rf(cutoffTime) } - if rf, ok := ret.Get(0).(func() []*dbmodel.Database); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(time.Time) uint64); ok { + r0 = rf(cutoffTime) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]*dbmodel.Database) - } + r0 = ret.Get(0).(uint64) } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := 
ret.Get(1).(func(time.Time) error); ok { + r1 = rf(cutoffTime) } else { r1 = ret.Error(1) } @@ -156,6 +138,24 @@ func (_m *IDatabaseDb) ListDatabases(limit *int32, offset *int32, tenantID strin return r0, r1 } +// SoftDelete provides a mock function with given fields: databaseID +func (_m *IDatabaseDb) SoftDelete(databaseID string) error { + ret := _m.Called(databaseID) + + if len(ret) == 0 { + panic("no return value specified for SoftDelete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(databaseID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // NewIDatabaseDb creates a new instance of IDatabaseDb. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewIDatabaseDb(t interface { diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/IMetaDomain.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/IMetaDomain.go index be2268f126e..503a8046f09 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/IMetaDomain.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/IMetaDomain.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.46.2. DO NOT EDIT. +// Code generated by mockery v2.53.5. DO NOT EDIT. 
package mocks @@ -74,6 +74,26 @@ func (_m *IMetaDomain) DatabaseDb(ctx context.Context) dbmodel.IDatabaseDb { return r0 } +// OperatorDb provides a mock function with given fields: ctx +func (_m *IMetaDomain) OperatorDb(ctx context.Context) dbmodel.IOperatorDb { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for OperatorDb") + } + + var r0 dbmodel.IOperatorDb + if rf, ok := ret.Get(0).(func(context.Context) dbmodel.IOperatorDb); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(dbmodel.IOperatorDb) + } + } + + return r0 +} + // SegmentDb provides a mock function with given fields: ctx func (_m *IMetaDomain) SegmentDb(ctx context.Context) dbmodel.ISegmentDb { ret := _m.Called(ctx) @@ -114,6 +134,26 @@ func (_m *IMetaDomain) SegmentMetadataDb(ctx context.Context) dbmodel.ISegmentMe return r0 } +// TaskDb provides a mock function with given fields: ctx +func (_m *IMetaDomain) TaskDb(ctx context.Context) dbmodel.ITaskDb { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for TaskDb") + } + + var r0 dbmodel.ITaskDb + if rf, ok := ret.Get(0).(func(context.Context) dbmodel.ITaskDb); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(dbmodel.ITaskDb) + } + } + + return r0 +} + // TenantDb provides a mock function with given fields: ctx func (_m *IMetaDomain) TenantDb(ctx context.Context) dbmodel.ITenantDb { ret := _m.Called(ctx) diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/IOperatorDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/IOperatorDb.go new file mode 100644 index 00000000000..a28e06555e8 --- /dev/null +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/IOperatorDb.go @@ -0,0 +1,89 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. 
+ +package mocks + +import ( + dbmodel "github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel" + mock "github.com/stretchr/testify/mock" + + uuid "github.com/google/uuid" +) + +// IOperatorDb is an autogenerated mock type for the IOperatorDb type +type IOperatorDb struct { + mock.Mock +} + +// GetByID provides a mock function with given fields: operatorID +func (_m *IOperatorDb) GetByID(operatorID uuid.UUID) (*dbmodel.Operator, error) { + ret := _m.Called(operatorID) + + if len(ret) == 0 { + panic("no return value specified for GetByID") + } + + var r0 *dbmodel.Operator + var r1 error + if rf, ok := ret.Get(0).(func(uuid.UUID) (*dbmodel.Operator, error)); ok { + return rf(operatorID) + } + if rf, ok := ret.Get(0).(func(uuid.UUID) *dbmodel.Operator); ok { + r0 = rf(operatorID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*dbmodel.Operator) + } + } + + if rf, ok := ret.Get(1).(func(uuid.UUID) error); ok { + r1 = rf(operatorID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetByName provides a mock function with given fields: operatorName +func (_m *IOperatorDb) GetByName(operatorName string) (*dbmodel.Operator, error) { + ret := _m.Called(operatorName) + + if len(ret) == 0 { + panic("no return value specified for GetByName") + } + + var r0 *dbmodel.Operator + var r1 error + if rf, ok := ret.Get(0).(func(string) (*dbmodel.Operator, error)); ok { + return rf(operatorName) + } + if rf, ok := ret.Get(0).(func(string) *dbmodel.Operator); ok { + r0 = rf(operatorName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*dbmodel.Operator) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(operatorName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewIOperatorDb creates a new instance of IOperatorDb. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewIOperatorDb(t interface { + mock.TestingT + Cleanup(func()) +}) *IOperatorDb { + mock := &IOperatorDb{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go index 040bf1de60c..4a1b8f8879f 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.46.2. DO NOT EDIT. +// Code generated by mockery v2.53.5. DO NOT EDIT. package mocks @@ -16,7 +16,7 @@ type ISegmentDb struct { mock.Mock } -// DeleteAll provides a mock function with given fields: +// DeleteAll provides a mock function with no fields func (_m *ISegmentDb) DeleteAll() error { ret := _m.Called() diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentMetadataDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentMetadataDb.go index bd6ba121373..3a90281ebe5 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentMetadataDb.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentMetadataDb.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.46.2. DO NOT EDIT. +// Code generated by mockery v2.53.5. DO NOT EDIT. package mocks @@ -12,7 +12,7 @@ type ISegmentMetadataDb struct { mock.Mock } -// DeleteAll provides a mock function with given fields: +// DeleteAll provides a mock function with no fields func (_m *ISegmentMetadataDb) DeleteAll() error { ret := _m.Called() diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ITaskDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ITaskDb.go new file mode 100644 index 00000000000..8a919b0a5fa --- /dev/null +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ITaskDb.go @@ -0,0 +1,111 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. 
+ +package mocks + +import ( + dbmodel "github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel" + mock "github.com/stretchr/testify/mock" +) + +// ITaskDb is an autogenerated mock type for the ITaskDb type +type ITaskDb struct { + mock.Mock +} + +// DeleteAll provides a mock function with no fields +func (_m *ITaskDb) DeleteAll() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for DeleteAll") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetByName provides a mock function with given fields: inputCollectionID, taskName +func (_m *ITaskDb) GetByName(inputCollectionID string, taskName string) (*dbmodel.Task, error) { + ret := _m.Called(inputCollectionID, taskName) + + if len(ret) == 0 { + panic("no return value specified for GetByName") + } + + var r0 *dbmodel.Task + var r1 error + if rf, ok := ret.Get(0).(func(string, string) (*dbmodel.Task, error)); ok { + return rf(inputCollectionID, taskName) + } + if rf, ok := ret.Get(0).(func(string, string) *dbmodel.Task); ok { + r0 = rf(inputCollectionID, taskName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*dbmodel.Task) + } + } + + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(inputCollectionID, taskName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Insert provides a mock function with given fields: task +func (_m *ITaskDb) Insert(task *dbmodel.Task) error { + ret := _m.Called(task) + + if len(ret) == 0 { + panic("no return value specified for Insert") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*dbmodel.Task) error); ok { + r0 = rf(task) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SoftDelete provides a mock function with given fields: inputCollectionID, taskName +func (_m *ITaskDb) SoftDelete(inputCollectionID string, taskName string) error { + ret := _m.Called(inputCollectionID, taskName) + + if len(ret) == 0 { 
+ panic("no return value specified for SoftDelete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, string) error); ok { + r0 = rf(inputCollectionID, taskName) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewITaskDb creates a new instance of ITaskDb. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewITaskDb(t interface { + mock.TestingT + Cleanup(func()) +}) *ITaskDb { + mock := &ITaskDb{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ITenantDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ITenantDb.go index f76282d3a44..92a8a7e9f4b 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ITenantDb.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ITenantDb.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.53.4. DO NOT EDIT. +// Code generated by mockery v2.53.5. DO NOT EDIT. package mocks diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ITransaction.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ITransaction.go index d46815fdca9..e0700c379e3 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ITransaction.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ITransaction.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.46.2. DO NOT EDIT. +// Code generated by mockery v2.53.5. DO NOT EDIT. 
package mocks diff --git a/go/pkg/sysdb/metastore/db/dbmodel/operator.go b/go/pkg/sysdb/metastore/db/dbmodel/operator.go new file mode 100644 index 00000000000..3f624cef4ea --- /dev/null +++ b/go/pkg/sysdb/metastore/db/dbmodel/operator.go @@ -0,0 +1,23 @@ +package dbmodel + +import ( + "github.com/google/uuid" +) + +type Operator struct { + OperatorID uuid.UUID `gorm:"operator_id;primaryKey;unique"` + OperatorName string `gorm:"operator_name;type:text;not null;unique"` + IsIncremental bool `gorm:"is_incremental;type:bool;not null"` + ReturnType string `gorm:"return_type;type:jsonb;not null"` +} + +func (v Operator) TableName() string { + return "operators" +} + +//go:generate mockery --name=IOperatorDb +type IOperatorDb interface { + GetByName(operatorName string) (*Operator, error) + GetByID(operatorID uuid.UUID) (*Operator, error) + GetAll() ([]*Operator, error) +} diff --git a/go/pkg/sysdb/metastore/db/dbmodel/task.go b/go/pkg/sysdb/metastore/db/dbmodel/task.go new file mode 100644 index 00000000000..9dd79f9ce79 --- /dev/null +++ b/go/pkg/sysdb/metastore/db/dbmodel/task.go @@ -0,0 +1,43 @@ +package dbmodel + +import ( + "time" + + "github.com/google/uuid" +) + +type Task struct { + ID uuid.UUID `gorm:"column:task_id;primaryKey"` + Name string `gorm:"column:task_name;type:text;not null;uniqueIndex:unique_task_per_collection,priority:2"` + TenantID string `gorm:"column:tenant_id;type:text;not null"` + DatabaseID string `gorm:"column:database_id;type:text;not null"` + InputCollectionID string `gorm:"column:input_collection_id;type:text;not null;uniqueIndex:unique_task_per_collection,priority:1"` + OutputCollectionName string `gorm:"column:output_collection_name;type:text;not null"` + OutputCollectionID *string `gorm:"column:output_collection_id;type:text;default:null"` + OperatorID uuid.UUID `gorm:"column:operator_id;type:uuid;not null"` + OperatorParams string `gorm:"column:operator_params;type:jsonb;not null"` + CompletionOffset int64 
`gorm:"column:completion_offset;type:bigint;not null;default:0"` + LastRun *time.Time `gorm:"column:last_run;type:timestamp"` + NextRun *time.Time `gorm:"column:next_run;type:timestamp"` + MinRecordsForTask int64 `gorm:"column:min_records_for_task;type:bigint;not null;default:100"` + CurrentAttempts int32 `gorm:"column:current_attempts;type:integer;not null;default:0"` + IsAlive bool `gorm:"column:is_alive;type:boolean;not null;default:true"` + IsDeleted bool `gorm:"column:is_deleted;type:boolean;not null;default:false"` + CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP"` + UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP"` + TaskTemplateParent *uuid.UUID `gorm:"column:task_template_parent;type:uuid;default:null"` + NextNonce uuid.UUID `gorm:"column:next_nonce;type:uuid;not null"` + OldestWrittenNonce *uuid.UUID `gorm:"column:oldest_written_nonce;type:uuid;default:null"` +} + +func (v Task) TableName() string { + return "tasks" +} + +//go:generate mockery --name=ITaskDb +type ITaskDb interface { + Insert(task *Task) error + GetByName(inputCollectionID string, taskName string) (*Task, error) + SoftDelete(inputCollectionID string, taskName string) error + DeleteAll() error +} diff --git a/go/pkg/sysdb/metastore/db/migrations/20251001073000.sql b/go/pkg/sysdb/metastore/db/migrations/20251001073000.sql new file mode 100644 index 00000000000..4ad106c6c19 --- /dev/null +++ b/go/pkg/sysdb/metastore/db/migrations/20251001073000.sql @@ -0,0 +1,59 @@ +-- Create "operators" table +CREATE TABLE "public"."operators" ( + "operator_id" uuid NOT NULL, + "operator_name" text NOT NULL UNIQUE, + "is_incremental" boolean NOT NULL, + "return_type" jsonb NOT NULL, + PRIMARY KEY ("operator_id") +); + +-- Insert sample operator: record counter +INSERT INTO "public"."operators" ("operator_id", "operator_name", "is_incremental", "return_type") +VALUES ( + 'ccf2e3ba-633e-43ba-9394-46b0c54c61e3', -- 
Randomly generated + 'record_counter', + true, + '{"type": "object", "properties": {"count": {"type": "integer", "description": "Number of records processed"}}}' +); + +-- Create "tasks" table +CREATE TABLE "public"."tasks" ( + "task_id" uuid NOT NULL, + "task_name" text NOT NULL, + "tenant_id" text NOT NULL, + "database_id" text NOT NULL, + "input_collection_id" text NOT NULL, -- Keeping these as text instead of UUID until collections.id becomes a UUID + "output_collection_name" text NOT NULL, + "output_collection_id" text DEFAULT NULL, -- Lazily filled in after output collection is created + "operator_id" uuid NOT NULL, + "operator_params" jsonb NOT NULL, + "completion_offset" bigint NOT NULL DEFAULT 0, + "last_run" timestamp NULL DEFAULT NULL, + "next_run" timestamp NULL DEFAULT NULL, + "min_records_for_task" bigint NOT NULL DEFAULT 100, + "current_attempts" integer NOT NULL DEFAULT 0, + "is_alive" boolean NOT NULL DEFAULT true, + "is_deleted" boolean NOT NULL DEFAULT false, + "created_at" timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + "task_template_parent" uuid NULL, + "next_nonce" UUID NOT NULL, -- UUIDv7 + "oldest_written_nonce" UUID DEFAULT NULL, -- UUIDv7 + PRIMARY KEY ("task_id"), + CONSTRAINT "unique_task_per_collection" UNIQUE ("input_collection_id", "task_name") +); + +-- Create "task_templates" table +CREATE TABLE "public"."task_templates" ( + "template_id" uuid NOT NULL, + "tenant_id" text NOT NULL, + "database_id" text NOT NULL, + "template_name" text NOT NULL, + "operator_id" text NOT NULL, + "params" jsonb NOT NULL DEFAULT '{}'::jsonb, + "output_collection_pattern" text NOT NULL, + "created_at" timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY ("template_id"), + CONSTRAINT "unique_template_per_tenant_db" UNIQUE ("tenant_id", "database_id", "template_name") +); diff --git 
a/go/pkg/sysdb/metastore/db/migrations/atlas.sum b/go/pkg/sysdb/metastore/db/migrations/atlas.sum index ec19326a1c6..10a745ddb7e 100644 --- a/go/pkg/sysdb/metastore/db/migrations/atlas.sum +++ b/go/pkg/sysdb/metastore/db/migrations/atlas.sum @@ -1,4 +1,4 @@ -h1:oTWk6ETg+Z0bi5kIQcODs0oNIvOLYPBpI83E7vWGIx4= +h1:Jk3VaF1qoRNVAB7cCxgSDFiiH9Y6r1zSIIj1SxhCklc= 20240313233558.sql h1:Gv0TiSYsqGoOZ2T2IWvX4BOasauxool8PrBOIjmmIdg= 20240321194713.sql h1:kVkNpqSFhrXGVGFFvL7JdK3Bw31twFcEhI6A0oCFCkg= 20240327075032.sql h1:nlr2J74XRU8erzHnKJgMr/tKqJxw9+R6RiiEBuvuzgo= @@ -19,3 +19,4 @@ h1:oTWk6ETg+Z0bi5kIQcODs0oNIvOLYPBpI83E7vWGIx4= 20250716123832.sql h1:2zRLgINX+VGrCVmHHg5LUi9VZmkLF+vien/7W6t9YEk= 20250806213245.sql h1:OgEOd3bL+rKdQ2x/Hcm3f0/yyrWirJkPm14V5N4sgKE= 20250930122132.sql h1:ch67SU2K5X4gV5E1knOEk/yprnn9FrbZsJCkmUnAbqo= +20251001073000.sql h1:pdl+M9f46vz7rbXZtJjOWTXlbSBpL2a0nVHl5VUOOsg= diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index d1302a7fb9f..4b65dfee292 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -524,6 +524,61 @@ message BatchGetCollectionSoftDeleteStatusResponse { map collection_id_to_is_soft_deleted = 1; } +message CreateTaskRequest { + string name = 1; + string operator_name = 2; + string input_collection_id = 3; + string output_collection_name = 4; + string params = 5; + string tenant_id = 6; + string database = 7; + uint64 min_records_for_task = 8; +} + +message CreateTaskResponse { + string task_id = 1; +} + +message GetTaskByNameRequest { + string input_collection_id = 1; + string task_name = 2; +} + +message GetTaskByNameResponse { + optional string task_id = 1; + optional string name = 2; + optional string operator_name = 3; + optional string input_collection_id = 4; + optional string output_collection_name = 5; + optional string output_collection_id = 6; + optional string params = 7; + optional int64 completion_offset = 8; + optional uint64 
min_records_for_task = 9; +} + +message DeleteTaskRequest { + string input_collection_id = 1; + string task_name = 2; + bool delete_output = 3; // If true and output_collection_id is not null, atomically soft-delete the output collection +} + +message DeleteTaskResponse { + bool success = 1; +} + +message Operator { + string id = 1; + string name = 2; +} + +message GetOperatorsRequest { + // Empty request - returns all operators +} + +message GetOperatorsResponse { + repeated Operator operators = 1; +} + service SysDB { rpc CreateDatabase(CreateDatabaseRequest) returns (CreateDatabaseResponse) {} rpc GetDatabase(GetDatabaseRequest) returns (GetDatabaseResponse) {} @@ -561,4 +616,8 @@ service SysDB { rpc DeleteCollectionVersion(DeleteCollectionVersionRequest) returns (DeleteCollectionVersionResponse) {} rpc BatchGetCollectionVersionFilePaths(BatchGetCollectionVersionFilePathsRequest) returns (BatchGetCollectionVersionFilePathsResponse) {} rpc BatchGetCollectionSoftDeleteStatus(BatchGetCollectionSoftDeleteStatusRequest) returns (BatchGetCollectionSoftDeleteStatusResponse) {} + rpc CreateTask(CreateTaskRequest) returns (CreateTaskResponse) {} + rpc GetTaskByName(GetTaskByNameRequest) returns (GetTaskByNameResponse) {} + rpc DeleteTask(DeleteTaskRequest) returns (DeleteTaskResponse) {} + rpc GetOperators(GetOperatorsRequest) returns (GetOperatorsResponse) {} } diff --git a/rust/Dockerfile b/rust/Dockerfile index 3d88b2a6db8..f10ae0998b2 100644 --- a/rust/Dockerfile +++ b/rust/Dockerfile @@ -25,6 +25,8 @@ COPY idl/ idl/ COPY Cargo.toml Cargo.toml COPY Cargo.lock Cargo.lock COPY rust/ rust/ +# Copy Go constants file needed by Rust build script for operator code generation +COPY go/pkg/sysdb/metastore/db/dbmodel/constants.go go/pkg/sysdb/metastore/db/dbmodel/constants.go # Skip building these as they're not needed by images (and if Python bindings are built, the final binaries are unnecessarily linked against Python). 
ENV EXCLUDED_PACKAGES="chromadb_rust_bindings chromadb-js-bindings chroma-benchmark " diff --git a/rust/cli/src/commands/login.rs b/rust/cli/src/commands/login.rs index 1ee2bf9827a..5a32fcdeb17 100644 --- a/rust/cli/src/commands/login.rs +++ b/rust/cli/src/commands/login.rs @@ -269,7 +269,7 @@ pub async fn headless_login(args: LoginArgs) -> Result<(), CliError> { config.current_profile = profile_name.clone(); write_config(&config)?; } - + if !config.current_profile.eq(&profile_name) { println!("{}", set_profile_message(&profile_name)); } diff --git a/rust/frontend/src/impls/service_based_frontend.rs b/rust/frontend/src/impls/service_based_frontend.rs index 9c388056a54..6712ab490e1 100644 --- a/rust/frontend/src/impls/service_based_frontend.rs +++ b/rust/frontend/src/impls/service_based_frontend.rs @@ -1926,6 +1926,66 @@ mod tests { .any(|s| s.r#type == SegmentType::BlockfileRecord && s.scope == SegmentScope::RECORD)); } + #[tokio::test] + async fn test_k8s_integration_operator_constants() { + // Validate that hardcoded Rust operator constants match the live database. + // This prevents drift between constants and database migrations. 
+ use chroma_types::{OPERATOR_RECORD_COUNTER_ID, OPERATOR_RECORD_COUNTER_NAME}; + use std::collections::HashMap; + + // Map of operator names to their expected UUID constants + // Add new operators here as they are added to rust/types/src/operators.rs + let expected_operators: HashMap<&str, uuid::Uuid> = + [(OPERATOR_RECORD_COUNTER_NAME, OPERATOR_RECORD_COUNTER_ID)] + .iter() + .cloned() + .collect(); + + // Connect to sysdb via gRPC + let registry = Registry::new(); + let sysdb_config = chroma_sysdb::SysDbConfig::Grpc(GrpcSysDbConfig { + host: "localhost".to_string(), + port: 50051, + ..Default::default() + }); + let mut sysdb = SysDb::try_from_config(&sysdb_config, &registry) + .await + .unwrap(); + + // Get all operators from the database via gRPC + let operators = sysdb.get_all_operators().await.unwrap(); + + // Verify count matches expectations + assert_eq!( + operators.len(), + expected_operators.len(), + "Operator count mismatch. If you added a new operator to migrations, \ + rebuild Rust (cargo build -p chroma-types) to auto-generate constants and update this test. \ + Expected: {}, Actual: {}", + expected_operators.len(), + operators.len() + ); + + // Verify each operator constant matches the database + for (operator_name, expected_uuid) in &expected_operators { + let db_operator = operators + .iter() + .find(|(name, _)| name == operator_name) + .unwrap_or_else(|| panic!("Operator '{}' not found in database", operator_name)); + + assert_eq!( + *expected_uuid, db_operator.1, + "Operator '{}' UUID mismatch.
Code: {}, DB: {}", + operator_name, expected_uuid, db_operator.1 + ); + } + + println!( + "Verified {} operator(s) match database", + expected_operators.len() + ); + } + #[test] fn test_crn_parsing() { use chroma_types::GetCollectionByCrnRequest; diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index 4c5ce6552f7..9679b820a3a 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -475,6 +475,17 @@ impl SysDb { } } + // Only meant for testing. + pub async fn get_all_operators( + &mut self, + ) -> Result<Vec<(String, uuid::Uuid)>, Box<dyn std::error::Error>> { + match self { + SysDb::Grpc(grpc) => grpc.get_all_operators().await, + SysDb::Sqlite(_) => unimplemented!("get_all_operators not implemented for sqlite"), + SysDb::Test(_) => unimplemented!("get_all_operators not implemented for test"), + } + } + pub async fn batch_get_collection_version_file_paths( &mut self, collection_ids: Vec, @@ -1363,6 +1374,23 @@ impl GrpcSysDb { }) } + async fn get_all_operators( + &mut self, + ) -> Result<Vec<(String, uuid::Uuid)>, Box<dyn std::error::Error>> { + let res = self + .client + .get_operators(chroma_proto::GetOperatorsRequest {}) + .await?; + + let operators = res.into_inner().operators; + let mut result = Vec::new(); + for op in operators { + let id = uuid::Uuid::parse_str(&op.id)?; + result.push((op.name, id)); + } + Ok(result) + } + async fn batch_get_collection_version_file_paths( &mut self, collection_ids: Vec, diff --git a/rust/types/build.rs b/rust/types/build.rs index 486b5b79f0d..31cf06deddd 100644 --- a/rust/types/build.rs +++ b/rust/types/build.rs @@ -1,3 +1,5 @@ +mod operator_codegen; + fn main() -> Result<(), Box<dyn std::error::Error>> { // Compile the protobuf files in the chromadb proto directory.
/// Module for generating Rust operator constants from Go source code.
///
/// This module is used by the build script to automatically generate operator constants
/// by parsing the Go constants file at build time.
use std::collections::HashMap;
use std::fmt::Write as _;
use std::fs;
use std::path::Path;

/// Reads the Go constants file and writes `operators_generated.rs` into `OUT_DIR`.
///
/// The Go file is located relative to `CARGO_MANIFEST_DIR` (two directories up, then
/// `go/pkg/sysdb/metastore/db/dbmodel/constants.go`). A `cargo:rerun-if-changed` line is
/// printed for it so Cargo reruns the build script when the Go file changes.
///
/// # Errors
///
/// Returns an error if `CARGO_MANIFEST_DIR` or `OUT_DIR` is unset, the workspace root
/// cannot be derived, the Go file cannot be read, a UUID literal is malformed, or the
/// generated file cannot be written.
pub fn generate_operator_constants() -> Result<(), Box<dyn std::error::Error>> {
    // Get the workspace root (two levels up from rust/types)
    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")?;
    let workspace_root = Path::new(&manifest_dir)
        .parent()
        .and_then(Path::parent)
        .ok_or("Failed to find workspace root")?;

    let go_constants_path = workspace_root.join("go/pkg/sysdb/metastore/db/dbmodel/constants.go");
    let out_dir = std::env::var("OUT_DIR")?;
    let dest_path = Path::new(&out_dir).join("operators_generated.rs");

    // Tell Cargo to rerun if the Go file changes
    println!("cargo:rerun-if-changed={}", go_constants_path.display());

    let go_content = fs::read_to_string(&go_constants_path)
        .map_err(|e| format!("Failed to read {}: {}", go_constants_path.display(), e))?;

    let rust_code = render_operator_constants(&go_content)?;

    fs::write(&dest_path, rust_code)
        .map_err(|e| format!("Failed to write generated file: {}", e))?;

    Ok(())
}

/// Turns the text of the Go constants file into the Rust source for the generated module.
///
/// Two kinds of Go lines are recognised (after trimming):
///
/// * `Operator<X> = uuid.MustParse("<uuid>")` yields `OPERATOR_<SNAKE_X>_ID`;
/// * `OperatorName<X> = "<name>"` supplies the value of `OPERATOR_<SNAKE_X>_NAME`.
///
/// When no `OperatorName<X>` line exists, the snake_case form of `<X>` is used as the name.
fn render_operator_constants(go_content: &str) -> Result<String, Box<dyn std::error::Error>> {
    // (Go constant suffix, snake_case base name, UUID literal), in file order.
    let mut operators: Vec<(String, String, String)> = Vec::new();
    // Go constant suffix -> operator name string.
    let mut name_map: HashMap<String, String> = HashMap::new();

    for line in go_content.lines() {
        let rest = match line.trim().strip_prefix("Operator") {
            Some(rest) => rest,
            None => continue,
        };

        // Only lines that contain uuid.MustParse are UUID constants; everything else that
        // starts with `OperatorName` is treated as a name constant.
        if rest.contains("uuid.MustParse") {
            if let Some(uuid_str) = extract_uuid_from_line(rest) {
                // OperatorRecordCounter -> RecordCounter -> record_counter
                let const_name = rest.split('=').next().unwrap_or(rest).trim();
                let snake_name = camel_to_snake_case(const_name);
                operators.push((const_name.to_string(), snake_name, uuid_str));
            }
        } else if let Some(name_line) = rest.strip_prefix("Name") {
            if let Some((const_name, name_str)) = extract_name_from_line(name_line) {
                name_map.insert(const_name, name_str);
            }
        }
    }

    let mut rust_code = String::from(
        "/// Well-known operator IDs and names that are pre-populated in the database\n",
    );
    rust_code.push_str("/// \n");
    rust_code.push_str("/// GENERATED CODE - DO NOT EDIT MANUALLY\n");
    rust_code.push_str(
        "/// This file is auto-generated from go/pkg/sysdb/metastore/db/dbmodel/constants.go\n",
    );
    rust_code.push_str("/// by the build script in rust/types/build.rs\n");
    rust_code.push_str("use uuid::Uuid;\n\n");

    for (go_const_name, rust_name_base, uuid_str) in &operators {
        let uuid_bytes = parse_uuid_to_bytes(uuid_str)?;

        // `name_map` is keyed by the suffix left after stripping `OperatorName`, which is
        // exactly `go_const_name`; looking it up under any other key never matches.
        let name_value = name_map
            .get(go_const_name)
            .map(String::as_str)
            .unwrap_or(rust_name_base.as_str());
        let const_prefix = rust_name_base.to_uppercase();

        writeln!(rust_code, "/// UUID for the built-in {} operator", name_value)?;
        writeln!(
            rust_code,
            "pub const OPERATOR_{}_ID: Uuid = Uuid::from_bytes([",
            const_prefix
        )?;
        writeln!(rust_code, "    {}", format_uuid_bytes(&uuid_bytes))?;
        rust_code.push_str("]);\n");

        writeln!(rust_code, "/// Name of the built-in {} operator", name_value)?;
        writeln!(
            rust_code,
            "pub const OPERATOR_{}_NAME: &str = \"{}\";\n",
            const_prefix, name_value
        )?;
    }

    Ok(rust_code)
}

/// Returns the text inside the first pair of double quotes on `line`, if any.
///
/// Example input: `RecordCounter = uuid.MustParse("ccf2e3ba-633e-43ba-9394-46b0c54c61e3")`.
fn extract_uuid_from_line(line: &str) -> Option<String> {
    line.split('"').nth(1).map(str::to_string)
}

/// Splits `RecordCounter = "record_counter"` into `("RecordCounter", "record_counter")`.
///
/// Only the first `=` separates the constant name from its value, so a value (or a
/// trailing comment) that itself contains `=` is still parsed.
fn extract_name_from_line(line: &str) -> Option<(String, String)> {
    let (const_name, value) = line.split_once('=')?;
    let name = value.split('"').nth(1)?;
    Some((const_name.trim().to_string(), name.to_string()))
}

/// Converts `RecordCounter` to `record_counter`.
///
/// An underscore is inserted before every uppercase character except the first one.
fn camel_to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + 4);
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            result.push(ch.to_ascii_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}

/// Parses a UUID string like `ccf2e3ba-633e-43ba-9394-46b0c54c61e3` into its 16 bytes.
///
/// Hyphens are ignored wherever they appear; the remaining characters must be exactly
/// 32 ASCII hex digits.
///
/// # Errors
///
/// Returns an error (never panics) when the length is wrong or any character is not a
/// hex digit, including signs and non-ASCII characters.
fn parse_uuid_to_bytes(uuid_str: &str) -> Result<[u8; 16], Box<dyn std::error::Error>> {
    // Work on bytes so that multi-byte characters cannot cause an out-of-boundary slice.
    let hex: Vec<u8> = uuid_str.bytes().filter(|&b| b != b'-').collect();
    if hex.len() != 32 {
        return Err(format!("Invalid UUID length: {}", uuid_str).into());
    }

    let mut bytes = [0u8; 16];
    for (byte, pair) in bytes.iter_mut().zip(hex.chunks_exact(2)) {
        let high = char::from(pair[0]).to_digit(16);
        let low = char::from(pair[1]).to_digit(16);
        match (high, low) {
            // Both nibbles are < 16, so the combined value always fits in a u8.
            (Some(high), Some(low)) => *byte = ((high << 4) | low) as u8,
            _ => {
                return Err(format!(
                    "Failed to parse hex byte {}: invalid digit found in string",
                    String::from_utf8_lossy(pair)
                )
                .into())
            }
        }
    }

    Ok(bytes)
}

/// Formats the bytes as a comma-separated list of `0x..` literals for the generated array.
fn format_uuid_bytes(bytes: &[u8; 16]) -> String {
    bytes
        .iter()
        .map(|b| format!("0x{:02x}", b))
        .collect::<Vec<_>>()
        .join(", ")
}
The constants will be automatically available in Rust: +//! - `OPERATOR_MY_OPERATOR_ID` +//! - `OPERATOR_MY_OPERATOR_NAME` +//! +//! See `rust/types/README_OPERATORS.md` for more details. + +// Include the auto-generated constants from the build script +include!(concat!(env!("OUT_DIR"), "/operators_generated.rs")); diff --git a/rust/types/src/task.rs b/rust/types/src/task.rs new file mode 100644 index 00000000000..d9721f18339 --- /dev/null +++ b/rust/types/src/task.rs @@ -0,0 +1,87 @@ +use serde::{Deserialize, Serialize}; +use std::time::SystemTime; +use utoipa::ToSchema; +use uuid::Uuid; + +use crate::CollectionUuid; + +/// TaskUuid is a wrapper around Uuid to provide a type for task identifiers. +#[derive( + Copy, + Clone, + Debug, + Default, + Deserialize, + Eq, + PartialEq, + Ord, + PartialOrd, + Hash, + Serialize, + ToSchema, +)] +pub struct TaskUuid(pub Uuid); + +impl TaskUuid { + pub fn new() -> Self { + TaskUuid(Uuid::new_v4()) + } +} + +impl std::str::FromStr for TaskUuid { + type Err = uuid::Error; + + fn from_str(s: &str) -> Result { + match Uuid::parse_str(s) { + Ok(uuid) => Ok(TaskUuid(uuid)), + Err(err) => Err(err), + } + } +} + +impl std::fmt::Display for TaskUuid { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Task represents an asynchronous task that is triggered by collection writes +/// to map records from a source collection to a target collection. 
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Task {
    /// Unique identifier for the task
    pub id: TaskUuid,
    /// Human-readable name for the task instance
    pub name: String,
    /// Identifier for the operator/built-in definition this task uses
    pub operator_id: String,
    /// Source collection that triggers the task
    pub input_collection_id: CollectionUuid,
    /// Name of target collection where task output is stored
    pub output_collection_name: String,
    /// ID of the output collection (lazily filled in after creation)
    pub output_collection_id: Option,
    /// Optional JSON parameters for the operator
    pub params: Option,
    /// Tenant this task belongs to
    pub tenant_id: String,
    /// Database this task belongs to
    pub database_id: String,
    /// Timestamp of the last successful task run
    // Not part of the serialized form: omitted when serializing and set to its
    // `Default` value when deserializing.
    #[serde(skip, default)]
    pub last_run: Option,
    /// Timestamp when the task should next run (None if not yet scheduled)
    // Not part of the serialized form (see `last_run`).
    #[serde(skip, default)]
    pub next_run: Option,
    /// Completion offset: the WAL position up to which the task has processed records
    pub completion_offset: u64,
    /// Minimum number of new records required before the task runs again
    pub min_records_for_task: u64,
    /// Whether the task has been soft-deleted
    // Not part of the serialized form, so a deserialized `Task` always starts with
    // `is_deleted == false`.
    #[serde(skip, default)]
    pub is_deleted: bool,
    /// Timestamp when the task was created
    pub created_at: SystemTime,
    /// Timestamp when the task was last updated
    pub updated_at: SystemTime,
}