Skip to content

Commit 728aefb

Browse files
committed
Add models and migration for Create/Delete/Get Tasks
1 parent 9309371 commit 728aefb

File tree

25 files changed

+1014
-45
lines changed

25 files changed

+1014
-45
lines changed

go/pkg/common/errors.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ var (
4848
// Segment metadata errors
4949
ErrUnknownSegmentMetadataType = errors.New("segment metadata value type not supported")
5050

51+
// Task errors
52+
ErrTaskAlreadyExists = errors.New("the task that was being created already exists for this collection")
53+
54+
// Operator errors
55+
ErrOperatorNotFound = errors.New("operator not found")
56+
5157
// Others
5258
ErrCompactionOffsetSomehowAhead = errors.New("system invariant was violated. Compaction offset in sysdb should always be behind or equal to offset in log")
5359
)

go/pkg/sysdb/coordinator/task.go

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
package coordinator
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"github.com/chroma-core/chroma/go/pkg/common"
8+
"github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb"
9+
"github.com/chroma-core/chroma/go/pkg/sysdb/coordinator/model"
10+
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
11+
"github.com/chroma-core/chroma/go/pkg/types"
12+
"github.com/google/uuid"
13+
"github.com/pingcap/log"
14+
"go.uber.org/zap"
15+
"google.golang.org/protobuf/proto"
16+
)
17+
18+
// CreateTask creates a new task in the database
19+
func (s *Coordinator) CreateTask(ctx context.Context, req *coordinatorpb.CreateTaskRequest) (*coordinatorpb.CreateTaskResponse, error) {
20+
// Generate new task UUID
21+
taskID := uuid.New()
22+
var outputCollectionID string
23+
24+
// Look up database_id from databases table using database name and tenant
25+
databases, err := s.catalog.metaDomain.DatabaseDb(ctx).GetDatabases(req.TenantId, req.Database)
26+
if err != nil {
27+
log.Error("CreateTask: failed to get database", zap.Error(err))
28+
return nil, err
29+
}
30+
if len(databases) == 0 {
31+
log.Error("CreateTask: database not found")
32+
return nil, common.ErrDatabaseNotFound
33+
}
34+
35+
// Look up operator by name from the operators table
36+
operator, err := s.catalog.metaDomain.OperatorDb(ctx).GetByName(req.OperatorName)
37+
if err != nil {
38+
log.Error("CreateTask: failed to get operator", zap.Error(err))
39+
return nil, err
40+
}
41+
if operator == nil {
42+
log.Error("CreateTask: operator not found", zap.String("operator_name", req.OperatorName))
43+
return nil, common.ErrOperatorNotFound
44+
}
45+
operatorID := operator.OperatorID
46+
47+
// Generate UUIDv7 for time-ordered nonce
48+
nextNonce, err := uuid.NewV7()
49+
if err != nil {
50+
return nil, err
51+
}
52+
53+
// Create output collection and task transactionally
54+
err = s.catalog.txImpl.Transaction(ctx, func(txCtx context.Context) error {
55+
// Create the output collection
56+
createCollection := &model.CreateCollection{
57+
ID: types.NewUniqueID(),
58+
Name: req.OutputCollectionName,
59+
TenantID: req.TenantId,
60+
DatabaseName: req.Database,
61+
GetOrCreate: false, // Don't get existing, must be new
62+
Ts: types.Timestamp(time.Now().Unix()),
63+
}
64+
65+
collection, _, err := s.catalog.CreateCollection(txCtx, createCollection, createCollection.Ts)
66+
if err != nil {
67+
log.Error("CreateTask: failed to create output collection", zap.Error(err))
68+
return err
69+
}
70+
outputCollectionID = collection.ID.String()
71+
72+
// Create the task model
73+
now := time.Now()
74+
task := &dbmodel.Task{
75+
ID: taskID,
76+
Name: req.Name,
77+
TenantID: req.TenantId,
78+
DatabaseID: databases[0].ID,
79+
InputCollectionID: req.InputCollectionId,
80+
OutputCollectionID: outputCollectionID,
81+
OperatorID: operatorID,
82+
OperatorParams: req.Params,
83+
CompletionOffset: 0,
84+
LastRun: nil,
85+
NextRun: nil, // Will be scheduled by task scheduler
86+
MinRecordsForTask: int64(req.MinRecordsForTask),
87+
CurrentAttempts: 0,
88+
CreatedAt: now,
89+
UpdatedAt: now,
90+
NextNonce: nextNonce,
91+
OldestWrittenNonce: nil,
92+
}
93+
94+
// Insert task into database
95+
err = s.catalog.metaDomain.TaskDb(txCtx).Insert(task)
96+
if err != nil {
97+
log.Error("CreateTask: failed to insert task", zap.Error(err))
98+
return err
99+
}
100+
101+
return nil
102+
})
103+
104+
if err != nil {
105+
return nil, err
106+
}
107+
108+
log.Info("Task created", zap.String("task_id", taskID.String()), zap.String("name", req.Name))
109+
110+
return &coordinatorpb.CreateTaskResponse{
111+
TaskId: taskID.String(),
112+
}, nil
113+
}
114+
115+
// GetTaskByName retrieves a task by name from the database
116+
func (s *Coordinator) GetTaskByName(ctx context.Context, req *coordinatorpb.GetTaskByNameRequest) (*coordinatorpb.GetTaskByNameResponse, error) {
117+
task, err := s.catalog.metaDomain.TaskDb(ctx).GetByName(req.InputCollectionId, req.TaskName)
118+
if err != nil {
119+
return nil, err
120+
}
121+
122+
// If task not found, return empty response
123+
if task == nil {
124+
return &coordinatorpb.GetTaskByNameResponse{}, nil
125+
}
126+
127+
// Look up operator name from operators table
128+
operator, err := s.catalog.metaDomain.OperatorDb(ctx).GetByID(task.OperatorID)
129+
if err != nil {
130+
log.Error("GetTaskByName: failed to get operator", zap.Error(err))
131+
return nil, err
132+
}
133+
if operator == nil {
134+
log.Error("GetTaskByName: operator not found", zap.String("operator_id", task.OperatorID.String()))
135+
return nil, common.ErrOperatorNotFound
136+
}
137+
138+
// Debug logging
139+
log.Info("Found task", zap.String("task_id", task.ID.String()), zap.String("name", task.Name), zap.String("input_collection_id", task.InputCollectionID), zap.String("output_collection_id", task.OutputCollectionID))
140+
141+
// Convert task to response
142+
return &coordinatorpb.GetTaskByNameResponse{
143+
TaskId: proto.String(task.ID.String()),
144+
Name: proto.String(task.Name),
145+
OperatorName: proto.String(operator.OperatorName),
146+
InputCollectionId: proto.String(task.InputCollectionID),
147+
OutputCollectionName: proto.String(task.OutputCollectionID),
148+
Params: proto.String(task.OperatorParams),
149+
CompletionOffset: proto.Int64(task.CompletionOffset),
150+
MinRecordsForTask: proto.Uint64(uint64(task.MinRecordsForTask)),
151+
}, nil
152+
}
153+
154+
// DeleteTask soft deletes a task by name
155+
func (s *Coordinator) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRequest) (*coordinatorpb.DeleteTaskResponse, error) {
156+
// lookup database_id from databases table using database name and tenant
157+
databases, err := s.catalog.metaDomain.DatabaseDb(ctx).GetDatabases(req.TenantId, req.Database)
158+
if err != nil {
159+
log.Error("DeleteTask failed", zap.Error(err))
160+
return nil, err
161+
}
162+
if len(databases) == 0 {
163+
log.Error("DeleteTask failed: database not found")
164+
return nil, common.ErrDatabaseNotFound
165+
}
166+
err = s.catalog.metaDomain.TaskDb(ctx).SoftDelete(req.TenantId, databases[0].ID, req.TaskName)
167+
if err != nil {
168+
log.Error("DeleteTask failed", zap.Error(err))
169+
return nil, err
170+
}
171+
172+
log.Info("Task deleted", zap.String("tenant", req.TenantId), zap.String("database", req.Database), zap.String("task_name", req.TaskName))
173+
174+
return &coordinatorpb.DeleteTaskResponse{
175+
Success: true,
176+
}, nil
177+
}

go/pkg/sysdb/grpc/task_service.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package grpc
2+
3+
import (
4+
"context"
5+
6+
"github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb"
7+
"github.com/pingcap/log"
8+
"go.uber.org/zap"
9+
)
10+
11+
func (s *Server) CreateTask(ctx context.Context, req *coordinatorpb.CreateTaskRequest) (*coordinatorpb.CreateTaskResponse, error) {
12+
log.Info("CreateTask", zap.String("name", req.Name), zap.String("operator_name", req.OperatorName))
13+
14+
res, err := s.coordinator.CreateTask(ctx, req)
15+
if err != nil {
16+
log.Error("CreateTask failed", zap.Error(err))
17+
return nil, err
18+
}
19+
20+
return res, nil
21+
}
22+
23+
func (s *Server) GetTaskByName(ctx context.Context, req *coordinatorpb.GetTaskByNameRequest) (*coordinatorpb.GetTaskByNameResponse, error) {
24+
log.Info("GetTaskByName", zap.String("input_collection_id", req.InputCollectionId), zap.String("task_name", req.TaskName))
25+
26+
res, err := s.coordinator.GetTaskByName(ctx, req)
27+
if err != nil {
28+
log.Error("GetTaskByName failed", zap.Error(err))
29+
return nil, err
30+
}
31+
32+
return res, nil
33+
}
34+
35+
func (s *Server) DeleteTask(ctx context.Context, req *coordinatorpb.DeleteTaskRequest) (*coordinatorpb.DeleteTaskResponse, error) {
36+
log.Info("DeleteTask", zap.String("tenant", req.TenantId), zap.String("database", req.Database), zap.String("task_name", req.TaskName))
37+
38+
res, err := s.coordinator.DeleteTask(ctx, req)
39+
if err != nil {
40+
log.Error("DeleteTask failed", zap.Error(err))
41+
return nil, err
42+
}
43+
44+
return res, nil
45+
}

go/pkg/sysdb/metastore/db/dao/common.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,11 @@ func (*MetaDomain) SegmentDb(ctx context.Context) dbmodel.ISegmentDb {
3636
// SegmentMetadataDb returns a segment-metadata DAO backed by the database
// handle that dbcore.GetDB yields for ctx.
func (*MetaDomain) SegmentMetadataDb(ctx context.Context) dbmodel.ISegmentMetadataDb {
	handle := dbcore.GetDB(ctx)
	return &segmentMetadataDb{handle}
}
39+
40+
func (*MetaDomain) TaskDb(ctx context.Context) dbmodel.ITaskDb {
41+
return &taskDb{dbcore.GetDB(ctx)}
42+
}
43+
44+
func (*MetaDomain) OperatorDb(ctx context.Context) dbmodel.IOperatorDb {
45+
return &operatorDb{dbcore.GetDB(ctx)}
46+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package dao
2+
3+
import (
4+
"errors"
5+
6+
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
7+
"github.com/google/uuid"
8+
"github.com/pingcap/log"
9+
"go.uber.org/zap"
10+
"gorm.io/gorm"
11+
)
12+
13+
type operatorDb struct {
14+
db *gorm.DB
15+
}
16+
17+
var _ dbmodel.IOperatorDb = &operatorDb{}
18+
19+
func (s *operatorDb) GetByName(operatorName string) (*dbmodel.Operator, error) {
20+
var operator dbmodel.Operator
21+
err := s.db.
22+
Where("operator_name = ?", operatorName).
23+
First(&operator).Error
24+
25+
if err != nil {
26+
if errors.Is(err, gorm.ErrRecordNotFound) {
27+
log.Error("GetOperatorByName: operator not found", zap.String("operator_name", operatorName))
28+
return nil, nil
29+
}
30+
log.Error("GetOperatorByName failed", zap.Error(err))
31+
return nil, err
32+
}
33+
return &operator, nil
34+
}
35+
36+
func (s *operatorDb) GetByID(operatorID uuid.UUID) (*dbmodel.Operator, error) {
37+
var operator dbmodel.Operator
38+
err := s.db.
39+
Where("operator_id = ?", operatorID).
40+
First(&operator).Error
41+
42+
if err != nil {
43+
if errors.Is(err, gorm.ErrRecordNotFound) {
44+
return nil, nil
45+
}
46+
log.Error("GetOperatorByID failed", zap.Error(err))
47+
return nil, err
48+
}
49+
return &operator, nil
50+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package dao
2+
3+
import (
4+
"errors"
5+
6+
"github.com/chroma-core/chroma/go/pkg/common"
7+
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
8+
"github.com/jackc/pgx/v5/pgconn"
9+
"github.com/pingcap/log"
10+
"go.uber.org/zap"
11+
"gorm.io/gorm"
12+
)
13+
14+
type taskDb struct {
15+
db *gorm.DB
16+
}
17+
18+
var _ dbmodel.ITaskDb = &taskDb{}
19+
20+
func (s *taskDb) DeleteAll() error {
21+
return s.db.Where("1 = 1").Delete(&dbmodel.Task{}).Error
22+
}
23+
24+
func (s *taskDb) Insert(task *dbmodel.Task) error {
25+
err := s.db.Create(task).Error
26+
if err != nil {
27+
log.Error("insert task failed", zap.Error(err))
28+
var pgErr *pgconn.PgError
29+
ok := errors.As(err, &pgErr)
30+
if ok {
31+
log.Error("Postgres Error")
32+
switch pgErr.Code {
33+
case "23505":
34+
log.Error("task already exists")
35+
return common.ErrTaskAlreadyExists
36+
default:
37+
return err
38+
}
39+
}
40+
return err
41+
}
42+
return nil
43+
}
44+
45+
func (s *taskDb) GetByName(inputCollectionID string, taskName string) (*dbmodel.Task, error) {
46+
var task dbmodel.Task
47+
err := s.db.
48+
Where("input_collection_id = ?", inputCollectionID).
49+
Where("task_name = ?", taskName).
50+
Where("is_deleted = ?", false).
51+
First(&task).Error
52+
53+
if err != nil {
54+
if errors.Is(err, gorm.ErrRecordNotFound) {
55+
return nil, nil
56+
}
57+
log.Error("GetTaskByName failed", zap.Error(err))
58+
return nil, err
59+
}
60+
return &task, nil
61+
}
62+
63+
func (s *taskDb) SoftDelete(tenantID string, databaseID string, taskName string) error {
64+
return s.db.Table("tasks").
65+
Where("tenant_id = ?", tenantID).
66+
Where("database_id = ?", databaseID).
67+
Where("task_name = ?", taskName).
68+
Updates(map[string]interface{}{
69+
"is_deleted": true,
70+
}).Error
71+
}

0 commit comments

Comments
 (0)