Skip to content

Commit e7ce88d

Browse files
committed
Add models and migration for Create/Delete/Get Tasks
1 parent 2c9797a commit e7ce88d

33 files changed

+1671
-46
lines changed

go/pkg/common/errors.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ var (
4848
// Segment metadata errors
4949
ErrUnknownSegmentMetadataType = errors.New("segment metadata value type not supported")
5050

51+
// Task errors
52+
ErrTaskAlreadyExists = errors.New("the task that was being created already exists for this collection")
53+
ErrTaskNotFound = errors.New("the requested task was not found")
54+
ErrInvalidTaskName = errors.New("task name cannot start with reserved prefix '_deleted_'")
55+
56+
// Operator errors
57+
ErrOperatorNotFound = errors.New("operator not found")
58+
5159
// Others
5260
ErrCompactionOffsetSomehowAhead = errors.New("system invariant was violated. Compaction offset in sysdb should always be behind or equal to offset in log")
5361
)

go/pkg/sysdb/coordinator/task.go

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
package coordinator
2+
3+
import (
4+
"context"
5+
"strings"
6+
"time"
7+
8+
"github.com/chroma-core/chroma/go/pkg/common"
9+
"github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb"
10+
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
11+
"github.com/google/uuid"
12+
"github.com/pingcap/log"
13+
"go.uber.org/zap"
14+
"google.golang.org/protobuf/proto"
15+
)
16+
17+
// CreateTask creates a new task in the database
18+
func (s *Coordinator) CreateTask(ctx context.Context, req *coordinatorpb.CreateTaskRequest) (*coordinatorpb.CreateTaskResponse, error) {
19+
// Validate task name doesn't start with soft-deletion reserved prefix
20+
if strings.HasPrefix(req.Name, "_deleted_") {
21+
log.Error("CreateTask: task name cannot start with _deleted_")
22+
return nil, common.ErrInvalidTaskName
23+
}
24+
25+
// Generate new task UUID
26+
taskID := uuid.New()
27+
outputCollectionName := req.OutputCollectionName
28+
29+
// Look up database_id from databases table using database name and tenant
30+
databases, err := s.catalog.metaDomain.DatabaseDb(ctx).GetDatabases(req.TenantId, req.Database)
31+
if err != nil {
32+
log.Error("CreateTask: failed to get database", zap.Error(err))
33+
return nil, err
34+
}
35+
if len(databases) == 0 {
36+
log.Error("CreateTask: database not found")
37+
return nil, common.ErrDatabaseNotFound
38+
}
39+
40+
// Look up operator by name from the operators table
41+
operator, err := s.catalog.metaDomain.OperatorDb(ctx).GetByName(req.OperatorName)
42+
if err != nil {
43+
log.Error("CreateTask: failed to get operator", zap.Error(err))
44+
return nil, err
45+
}
46+
if operator == nil {
47+
log.Error("CreateTask: operator not found", zap.String("operator_name", req.OperatorName))
48+
return nil, common.ErrOperatorNotFound
49+
}
50+
operatorID := operator.OperatorID
51+
52+
// Generate UUIDv7 for time-ordered nonce
53+
nextNonce, err := uuid.NewV7()
54+
if err != nil {
55+
return nil, err
56+
}
57+
58+
// TODO(tanujnay112): Can combine the two collection checks into one
59+
// Check if input collection exists
60+
collections, err := s.catalog.metaDomain.CollectionDb(ctx).GetCollections([]string{req.InputCollectionId}, nil, req.TenantId, req.Database, nil, nil, false)
61+
if err != nil {
62+
log.Error("CreateTask: failed to get input collection", zap.Error(err))
63+
return nil, err
64+
}
65+
if len(collections) == 0 {
66+
log.Error("CreateTask: input collection not found")
67+
return nil, common.ErrCollectionNotFound
68+
}
69+
70+
// Check if output collection already exists
71+
existingOutputCollections, err := s.catalog.metaDomain.CollectionDb(ctx).GetCollections(nil, &outputCollectionName, req.TenantId, req.Database, nil, nil, false)
72+
if err != nil {
73+
log.Error("CreateTask: failed to check output collection", zap.Error(err))
74+
return nil, err
75+
}
76+
if len(existingOutputCollections) > 0 {
77+
log.Error("CreateTask: output collection already exists")
78+
return nil, common.ErrCollectionUniqueConstraintViolation
79+
}
80+
81+
// Check if task already exists
82+
existingTask, err := s.catalog.metaDomain.TaskDb(ctx).GetByName(req.InputCollectionId, req.Name)
83+
if err != nil {
84+
log.Error("CreateTask: failed to check task", zap.Error(err))
85+
return nil, err
86+
}
87+
if existingTask != nil {
88+
log.Info("CreateTask: task already exists, returning existing")
89+
return &coordinatorpb.CreateTaskResponse{
90+
TaskId: existingTask.ID.String(),
91+
}, nil
92+
}
93+
94+
now := time.Now()
95+
task := &dbmodel.Task{
96+
ID: taskID,
97+
Name: req.Name,
98+
TenantID: req.TenantId,
99+
DatabaseID: databases[0].ID,
100+
InputCollectionID: req.InputCollectionId,
101+
OutputCollectionName: req.OutputCollectionName,
102+
OperatorID: operatorID,
103+
OperatorParams: req.Params,
104+
CompletionOffset: 0,
105+
LastRun: nil,
106+
NextRun: nil, // Will be set to zero initially, scheduled by task scheduler
107+
MinRecordsForTask: int64(req.MinRecordsForTask),
108+
CurrentAttempts: 0,
109+
CreatedAt: now,
110+
UpdatedAt: now,
111+
NextNonce: nextNonce,
112+
OldestWrittenNonce: nil,
113+
}
114+
115+
// Try to insert task into database
116+
err = s.catalog.metaDomain.TaskDb(ctx).Insert(task)
117+
if err != nil {
118+
// Check if it's a unique constraint violation (concurrent creation)
119+
if err == common.ErrTaskAlreadyExists {
120+
log.Error("CreateTask: task already exists")
121+
return nil, common.ErrTaskAlreadyExists
122+
}
123+
log.Error("CreateTask: failed to insert task", zap.Error(err))
124+
return nil, err
125+
}
126+
127+
log.Info("Task created successfully", zap.String("task_id", taskID.String()), zap.String("name", req.Name), zap.String("output_collection_name", outputCollectionName))
128+
return &coordinatorpb.CreateTaskResponse{
129+
TaskId: taskID.String(),
130+
}, nil
131+
}
132+
133+
// GetTaskByName retrieves a task by name from the database
134+
func (s *Coordinator) GetTaskByName(ctx context.Context, req *coordinatorpb.GetTaskByNameRequest) (*coordinatorpb.GetTaskByNameResponse, error) {
135+
// Can do both calls with a JOIN
136+
task, err := s.catalog.metaDomain.TaskDb(ctx).GetByName(req.InputCollectionId, req.TaskName)
137+
if err != nil {
138+
return nil, err
139+
}
140+
141+
// If task not found, return empty response
142+
if task == nil {
143+
return nil, common.ErrTaskNotFound
144+
}
145+
146+
// Look up operator name from operators table
147+
operator, err := s.catalog.metaDomain.OperatorDb(ctx).GetByID(task.OperatorID)
148+
if err != nil {
149+
log.Error("GetTaskByName: failed to get operator", zap.Error(err))
150+
return nil, err
151+
}
152+
if operator == nil {
153+
log.Error("GetTaskByName: operator not found", zap.String("operator_id", task.OperatorID.String()))
154+
return nil, common.ErrOperatorNotFound
155+
}
156+
157+
// Debug logging
158+
log.Info("Found task", zap.String("task_id", task.ID.String()), zap.String("name", task.Name), zap.String("input_collection_id", task.InputCollectionID), zap.String("output_collection_name", task.OutputCollectionName))
159+
160+
// Convert task to response
161+
return &coordinatorpb.GetTaskByNameResponse{
162+
TaskId: proto.String(task.ID.String()),
163+
Name: proto.String(task.Name),
164+
OperatorName: proto.String(operator.OperatorName),
165+
InputCollectionId: proto.String(task.InputCollectionID),
166+
OutputCollectionName: proto.String(task.OutputCollectionName),
167+
Params: proto.String(task.OperatorParams),
168+
CompletionOffset: proto.Int64(task.CompletionOffset),
169+
MinRecordsForTask: proto.Uint64(uint64(task.MinRecordsForTask)),
170+
}, nil
171+
}
172+
173+
// DeleteTask soft deletes a task by name
174+
func (s *Coordinator) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRequest) (*coordinatorpb.DeleteTaskResponse, error) {
175+
err := s.catalog.metaDomain.TaskDb(ctx).SoftDelete(req.InputCollectionId, req.TaskName)
176+
if err != nil {
177+
log.Error("DeleteTask failed", zap.Error(err))
178+
return nil, err
179+
}
180+
181+
log.Info("Task deleted", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName))
182+
183+
return &coordinatorpb.DeleteTaskResponse{
184+
Success: true,
185+
}, nil
186+
}
187+
188+
// GetOperators retrieves all operators from the database
189+
func (s *Coordinator) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) {
190+
operators, err := s.catalog.metaDomain.OperatorDb(ctx).GetAll()
191+
if err != nil {
192+
log.Error("GetOperators failed", zap.Error(err))
193+
return nil, err
194+
}
195+
196+
// Convert to proto response
197+
protoOperators := make([]*coordinatorpb.Operator, len(operators))
198+
for i, op := range operators {
199+
protoOperators[i] = &coordinatorpb.Operator{
200+
Id: op.OperatorID.String(),
201+
Name: op.OperatorName,
202+
}
203+
}
204+
205+
log.Info("GetOperators succeeded", zap.Int("count", len(operators)))
206+
207+
return &coordinatorpb.GetOperatorsResponse{
208+
Operators: protoOperators,
209+
}, nil
210+
}

go/pkg/sysdb/grpc/task_service.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package grpc
2+
3+
import (
4+
"context"
5+
6+
"github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb"
7+
"github.com/pingcap/log"
8+
"go.uber.org/zap"
9+
)
10+
11+
func (s *Server) CreateTask(ctx context.Context, req *coordinatorpb.CreateTaskRequest) (*coordinatorpb.CreateTaskResponse, error) {
12+
log.Info("CreateTask", zap.String("name", req.Name), zap.String("operator_name", req.OperatorName))
13+
14+
res, err := s.coordinator.CreateTask(ctx, req)
15+
if err != nil {
16+
log.Error("CreateTask failed", zap.Error(err))
17+
return nil, err
18+
}
19+
20+
return res, nil
21+
}
22+
23+
func (s *Server) GetTaskByName(ctx context.Context, req *coordinatorpb.GetTaskByNameRequest) (*coordinatorpb.GetTaskByNameResponse, error) {
24+
log.Info("GetTaskByName", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName))
25+
26+
res, err := s.coordinator.GetTaskByName(ctx, req)
27+
if err != nil {
28+
log.Error("GetTaskByName failed", zap.Error(err))
29+
return nil, err
30+
}
31+
32+
return res, nil
33+
}
34+
35+
func (s *Server) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRequest) (*coordinatorpb.DeleteTaskResponse, error) {
36+
log.Info("DeleteTask", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName))
37+
38+
res, err := s.coordinator.DeleteTask(ctx, req)
39+
if err != nil {
40+
log.Error("DeleteTask failed", zap.Error(err))
41+
return nil, err
42+
}
43+
44+
return res, nil
45+
}
46+
47+
func (s *Server) GetOperators(ctx context.Context, req *coordinatorpb.GetOperatorsRequest) (*coordinatorpb.GetOperatorsResponse, error) {
48+
log.Info("GetOperators")
49+
50+
res, err := s.coordinator.GetOperators(ctx, req)
51+
if err != nil {
52+
log.Error("GetOperators failed", zap.Error(err))
53+
return nil, err
54+
}
55+
56+
return res, nil
57+
}

go/pkg/sysdb/metastore/db/dao/common.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,11 @@ func (*MetaDomain) SegmentDb(ctx context.Context) dbmodel.ISegmentDb {
3636
// SegmentMetadataDb returns an ISegmentMetadataDb wrapping the database
// handle that dbcore.GetDB yields for ctx.
func (*MetaDomain) SegmentMetadataDb(ctx context.Context) dbmodel.ISegmentMetadataDb {
	handle := dbcore.GetDB(ctx)
	return &segmentMetadataDb{handle}
}
39+
40+
func (*MetaDomain) TaskDb(ctx context.Context) dbmodel.ITaskDb {
41+
return &taskDb{dbcore.GetDB(ctx)}
42+
}
43+
44+
func (*MetaDomain) OperatorDb(ctx context.Context) dbmodel.IOperatorDb {
45+
return &operatorDb{dbcore.GetDB(ctx)}
46+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package dao
2+
3+
import (
4+
"errors"
5+
6+
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
7+
"github.com/google/uuid"
8+
"github.com/pingcap/log"
9+
"go.uber.org/zap"
10+
"gorm.io/gorm"
11+
)
12+
13+
type operatorDb struct {
14+
db *gorm.DB
15+
}
16+
17+
var _ dbmodel.IOperatorDb = &operatorDb{}
18+
19+
func (s *operatorDb) GetByName(operatorName string) (*dbmodel.Operator, error) {
20+
var operator dbmodel.Operator
21+
err := s.db.
22+
Where("operator_name = ?", operatorName).
23+
First(&operator).Error
24+
25+
if err != nil {
26+
if errors.Is(err, gorm.ErrRecordNotFound) {
27+
return nil, nil
28+
}
29+
log.Error("GetOperatorByName failed", zap.Error(err))
30+
return nil, err
31+
}
32+
return &operator, nil
33+
}
34+
35+
func (s *operatorDb) GetByID(operatorID uuid.UUID) (*dbmodel.Operator, error) {
36+
var operator dbmodel.Operator
37+
err := s.db.
38+
Where("operator_id = ?", operatorID).
39+
First(&operator).Error
40+
41+
if err != nil {
42+
if errors.Is(err, gorm.ErrRecordNotFound) {
43+
return nil, nil
44+
}
45+
log.Error("GetOperatorByID failed", zap.Error(err))
46+
return nil, err
47+
}
48+
return &operator, nil
49+
}
50+
51+
func (s *operatorDb) GetAll() ([]*dbmodel.Operator, error) {
52+
var operators []*dbmodel.Operator
53+
err := s.db.Find(&operators).Error
54+
55+
if err != nil {
56+
log.Error("GetAllOperators failed", zap.Error(err))
57+
return nil, err
58+
}
59+
return operators, nil
60+
}

0 commit comments

Comments
 (0)