Skip to content

Commit 3b40333

Browse files
committed
Add models and migration for Create/Delete/Get Tasks
1 parent 2c9797a commit 3b40333

28 files changed

+1252
-45
lines changed

go/pkg/common/errors.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ var (
4848
// Segment metadata errors
4949
ErrUnknownSegmentMetadataType = errors.New("segment metadata value type not supported")
5050

51+
// Task errors
52+
ErrTaskAlreadyExists = errors.New("the task that was being created already exists for this collection")
53+
ErrTaskNotFound = errors.New("the requested task was not found")
54+
55+
// Operator errors
56+
ErrOperatorNotFound = errors.New("operator not found")
57+
5158
// Others
5259
ErrCompactionOffsetSomehowAhead = errors.New("system invariant was violated. Compaction offset in sysdb should always be behind or equal to offset in log")
5360
)

go/pkg/sysdb/coordinator/task.go

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
package coordinator
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"github.com/chroma-core/chroma/go/pkg/common"
8+
"github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb"
9+
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
10+
"github.com/google/uuid"
11+
"github.com/pingcap/log"
12+
"go.uber.org/zap"
13+
"google.golang.org/protobuf/proto"
14+
)
15+
16+
// CreateTask creates a new task in the database
17+
func (s *Coordinator) CreateTask(ctx context.Context, req *coordinatorpb.CreateTaskRequest) (*coordinatorpb.CreateTaskResponse, error) {
18+
// Generate new task UUID
19+
taskID := uuid.New()
20+
outputCollectionName := req.OutputCollectionName
21+
22+
// Look up database_id from databases table using database name and tenant
23+
databases, err := s.catalog.metaDomain.DatabaseDb(ctx).GetDatabases(req.TenantId, req.Database)
24+
if err != nil {
25+
log.Error("CreateTask: failed to get database", zap.Error(err))
26+
return nil, err
27+
}
28+
if len(databases) == 0 {
29+
log.Error("CreateTask: database not found")
30+
return nil, common.ErrDatabaseNotFound
31+
}
32+
33+
// Look up operator by name from the operators table
34+
operator, err := s.catalog.metaDomain.OperatorDb(ctx).GetByName(req.OperatorName)
35+
if err != nil {
36+
log.Error("CreateTask: failed to get operator", zap.Error(err))
37+
return nil, err
38+
}
39+
if operator == nil {
40+
log.Error("CreateTask: operator not found", zap.String("operator_name", req.OperatorName))
41+
return nil, common.ErrOperatorNotFound
42+
}
43+
operatorID := operator.OperatorID
44+
45+
// Generate UUIDv7 for time-ordered nonce
46+
nextNonce, err := uuid.NewV7()
47+
if err != nil {
48+
return nil, err
49+
}
50+
51+
// TODO(tanujnay112): Can combine the two collection checks into one
52+
// Check if input collection exists
53+
collections, err := s.catalog.metaDomain.CollectionDb(ctx).GetCollections([]string{req.InputCollectionId}, nil, req.TenantId, req.Database, nil, nil, false)
54+
if err != nil {
55+
log.Error("CreateTask: failed to get input collection", zap.Error(err))
56+
return nil, err
57+
}
58+
if len(collections) == 0 {
59+
log.Error("CreateTask: input collection not found")
60+
return nil, common.ErrCollectionNotFound
61+
}
62+
63+
// Check if output collection already exists
64+
existingOutputCollections, err := s.catalog.metaDomain.CollectionDb(ctx).GetCollections(nil, &outputCollectionName, req.TenantId, req.Database, nil, nil, false)
65+
if err != nil {
66+
log.Error("CreateTask: failed to check output collection", zap.Error(err))
67+
return nil, err
68+
}
69+
if len(existingOutputCollections) > 0 {
70+
log.Error("CreateTask: output collection already exists")
71+
return nil, common.ErrCollectionUniqueConstraintViolation
72+
}
73+
74+
// Check if task already exists
75+
existingTask, err := s.catalog.metaDomain.TaskDb(ctx).GetByName(req.InputCollectionId, req.Name)
76+
if err != nil {
77+
log.Error("CreateTask: failed to check task", zap.Error(err))
78+
return nil, err
79+
}
80+
if existingTask != nil {
81+
log.Error("CreateTask: task already exists")
82+
return nil, common.ErrTaskAlreadyExists
83+
}
84+
85+
now := time.Now()
86+
task := &dbmodel.Task{
87+
ID: taskID,
88+
Name: req.Name,
89+
TenantID: req.TenantId,
90+
DatabaseID: databases[0].ID,
91+
InputCollectionID: req.InputCollectionId,
92+
OutputCollectionName: req.OutputCollectionName,
93+
OperatorID: operatorID,
94+
OperatorParams: req.Params,
95+
CompletionOffset: 0,
96+
LastRun: nil,
97+
NextRun: nil, // Will be set to zero initially, scheduled by task scheduler
98+
MinRecordsForTask: int64(req.MinRecordsForTask),
99+
CurrentAttempts: 0,
100+
CreatedAt: now,
101+
UpdatedAt: now,
102+
NextNonce: nextNonce,
103+
OldestWrittenNonce: nil,
104+
}
105+
106+
// Try to insert task into database
107+
err = s.catalog.metaDomain.TaskDb(ctx).Insert(task)
108+
if err != nil {
109+
// Check if it's a unique constraint violation (concurrent creation)
110+
if err == common.ErrTaskAlreadyExists {
111+
log.Error("CreateTask: task already exists")
112+
return nil, common.ErrTaskAlreadyExists
113+
}
114+
log.Error("CreateTask: failed to insert task", zap.Error(err))
115+
return nil, err
116+
}
117+
118+
log.Info("Task created successfully", zap.String("task_id", taskID.String()), zap.String("name", req.Name), zap.String("output_collection_name", outputCollectionName))
119+
return &coordinatorpb.CreateTaskResponse{
120+
TaskId: taskID.String(),
121+
}, nil
122+
}
123+
124+
// GetTaskByName retrieves a task by name from the database
125+
func (s *Coordinator) GetTaskByName(ctx context.Context, req *coordinatorpb.GetTaskByNameRequest) (*coordinatorpb.GetTaskByNameResponse, error) {
126+
// Can do both calls with a JOIN
127+
task, err := s.catalog.metaDomain.TaskDb(ctx).GetByName(req.InputCollectionId, req.TaskName)
128+
if err != nil {
129+
return nil, err
130+
}
131+
132+
// If task not found, return empty response
133+
if task == nil {
134+
return nil, common.ErrTaskNotFound
135+
}
136+
137+
// Look up operator name from operators table
138+
operator, err := s.catalog.metaDomain.OperatorDb(ctx).GetByID(task.OperatorID)
139+
if err != nil {
140+
log.Error("GetTaskByName: failed to get operator", zap.Error(err))
141+
return nil, err
142+
}
143+
if operator == nil {
144+
log.Error("GetTaskByName: operator not found", zap.String("operator_id", task.OperatorID.String()))
145+
return nil, common.ErrOperatorNotFound
146+
}
147+
148+
// Debug logging
149+
log.Info("Found task", zap.String("task_id", task.ID.String()), zap.String("name", task.Name), zap.String("input_collection_id", task.InputCollectionID), zap.String("output_collection_name", task.OutputCollectionName))
150+
151+
// Convert task to response
152+
return &coordinatorpb.GetTaskByNameResponse{
153+
TaskId: proto.String(task.ID.String()),
154+
Name: proto.String(task.Name),
155+
OperatorName: proto.String(operator.OperatorName),
156+
InputCollectionId: proto.String(task.InputCollectionID),
157+
OutputCollectionName: proto.String(task.OutputCollectionName),
158+
Params: proto.String(task.OperatorParams),
159+
CompletionOffset: proto.Int64(task.CompletionOffset),
160+
MinRecordsForTask: proto.Uint64(uint64(task.MinRecordsForTask)),
161+
}, nil
162+
}
163+
164+
// DeleteTask soft deletes a task by name
165+
func (s *Coordinator) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRequest) (*coordinatorpb.DeleteTaskResponse, error) {
166+
err := s.catalog.metaDomain.TaskDb(ctx).SoftDelete(req.InputCollectionId, req.TaskName)
167+
if err != nil {
168+
log.Error("DeleteTask failed", zap.Error(err))
169+
return nil, err
170+
}
171+
172+
log.Info("Task deleted", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName))
173+
174+
return &coordinatorpb.DeleteTaskResponse{
175+
Success: true,
176+
}, nil
177+
}

go/pkg/sysdb/grpc/task_service.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package grpc
2+
3+
import (
4+
"context"
5+
6+
"github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb"
7+
"github.com/pingcap/log"
8+
"go.uber.org/zap"
9+
)
10+
11+
func (s *Server) CreateTask(ctx context.Context, req *coordinatorpb.CreateTaskRequest) (*coordinatorpb.CreateTaskResponse, error) {
12+
log.Info("CreateTask", zap.String("name", req.Name), zap.String("operator_name", req.OperatorName))
13+
14+
res, err := s.coordinator.CreateTask(ctx, req)
15+
if err != nil {
16+
log.Error("CreateTask failed", zap.Error(err))
17+
return nil, err
18+
}
19+
20+
return res, nil
21+
}
22+
23+
func (s *Server) GetTaskByName(ctx context.Context, req *coordinatorpb.GetTaskByNameRequest) (*coordinatorpb.GetTaskByNameResponse, error) {
24+
log.Info("GetTaskByName", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName))
25+
26+
res, err := s.coordinator.GetTaskByName(ctx, req)
27+
if err != nil {
28+
log.Error("GetTaskByName failed", zap.Error(err))
29+
return nil, err
30+
}
31+
32+
return res, nil
33+
}
34+
35+
func (s *Server) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRequest) (*coordinatorpb.DeleteTaskResponse, error) {
36+
log.Info("DeleteTask", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName))
37+
38+
res, err := s.coordinator.DeleteTask(ctx, req)
39+
if err != nil {
40+
log.Error("DeleteTask failed", zap.Error(err))
41+
return nil, err
42+
}
43+
44+
return res, nil
45+
}

go/pkg/sysdb/metastore/db/dao/common.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,11 @@ func (*MetaDomain) SegmentDb(ctx context.Context) dbmodel.ISegmentDb {
3636
func (*MetaDomain) SegmentMetadataDb(ctx context.Context) dbmodel.ISegmentMetadataDb {
3737
return &segmentMetadataDb{dbcore.GetDB(ctx)}
3838
}
39+
40+
func (*MetaDomain) TaskDb(ctx context.Context) dbmodel.ITaskDb {
41+
return &taskDb{dbcore.GetDB(ctx)}
42+
}
43+
44+
func (*MetaDomain) OperatorDb(ctx context.Context) dbmodel.IOperatorDb {
45+
return &operatorDb{dbcore.GetDB(ctx)}
46+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package dao
2+
3+
import (
4+
"errors"
5+
6+
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
7+
"github.com/google/uuid"
8+
"github.com/pingcap/log"
9+
"go.uber.org/zap"
10+
"gorm.io/gorm"
11+
)
12+
13+
type operatorDb struct {
14+
db *gorm.DB
15+
}
16+
17+
var _ dbmodel.IOperatorDb = &operatorDb{}
18+
19+
func (s *operatorDb) GetByName(operatorName string) (*dbmodel.Operator, error) {
20+
var operator dbmodel.Operator
21+
err := s.db.
22+
Where("operator_name = ?", operatorName).
23+
First(&operator).Error
24+
25+
if err != nil {
26+
if errors.Is(err, gorm.ErrRecordNotFound) {
27+
return nil, nil
28+
}
29+
log.Error("GetOperatorByName failed", zap.Error(err))
30+
return nil, err
31+
}
32+
return &operator, nil
33+
}
34+
35+
func (s *operatorDb) GetByID(operatorID uuid.UUID) (*dbmodel.Operator, error) {
36+
var operator dbmodel.Operator
37+
err := s.db.
38+
Where("operator_id = ?", operatorID).
39+
First(&operator).Error
40+
41+
if err != nil {
42+
if errors.Is(err, gorm.ErrRecordNotFound) {
43+
return nil, nil
44+
}
45+
log.Error("GetOperatorByID failed", zap.Error(err))
46+
return nil, err
47+
}
48+
return &operator, nil
49+
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package dao
2+
3+
import (
4+
"errors"
5+
6+
"github.com/chroma-core/chroma/go/pkg/common"
7+
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
8+
"github.com/jackc/pgx/v5/pgconn"
9+
"github.com/pingcap/log"
10+
"go.uber.org/zap"
11+
"gorm.io/gorm"
12+
)
13+
14+
type taskDb struct {
15+
db *gorm.DB
16+
}
17+
18+
var _ dbmodel.ITaskDb = &taskDb{}
19+
20+
func (s *taskDb) DeleteAll() error {
21+
return s.db.Where("1 = 1").Delete(&dbmodel.Task{}).Error
22+
}
23+
24+
func (s *taskDb) Insert(task *dbmodel.Task) error {
25+
err := s.db.Create(task).Error
26+
if err != nil {
27+
log.Error("insert task failed", zap.Error(err))
28+
var pgErr *pgconn.PgError
29+
ok := errors.As(err, &pgErr)
30+
if ok {
31+
switch pgErr.Code {
32+
case "23505":
33+
log.Error("task already exists")
34+
return common.ErrTaskAlreadyExists
35+
default:
36+
return err
37+
}
38+
}
39+
return err
40+
}
41+
return nil
42+
}
43+
44+
func (s *taskDb) GetByName(inputCollectionID string, taskName string) (*dbmodel.Task, error) {
45+
var task dbmodel.Task
46+
err := s.db.
47+
Where("input_collection_id = ?", inputCollectionID).
48+
Where("task_name = ?", taskName).
49+
Where("is_deleted = ?", false).
50+
First(&task).Error
51+
52+
if err != nil {
53+
if errors.Is(err, gorm.ErrRecordNotFound) {
54+
return nil, nil
55+
}
56+
log.Error("GetTaskByName failed", zap.Error(err))
57+
return nil, err
58+
}
59+
return &task, nil
60+
}
61+
62+
func (s *taskDb) SoftDelete(inputCollectionID string, taskName string) error {
63+
// Update task name and is_deleted in a single query
64+
// Format: _deleted_<original_name>_<input_collection_id>_<task_id>
65+
result := s.db.Exec(`
66+
UPDATE tasks
67+
SET task_name = CONCAT('_deleted_', task_name, '_', input_collection_id, '_', task_id::text),
68+
is_deleted = true
69+
WHERE input_collection_id = ?
70+
AND task_name = ?
71+
AND is_deleted = false
72+
`, inputCollectionID, taskName)
73+
74+
if result.Error != nil {
75+
log.Error("SoftDelete failed", zap.Error(result.Error))
76+
return result.Error
77+
}
78+
79+
// If no rows were affected, task was not found (or already deleted)
80+
if result.RowsAffected == 0 {
81+
return nil // Idempotent - no error if already deleted or not found
82+
}
83+
84+
return nil
85+
}

0 commit comments

Comments
 (0)