-
Notifications
You must be signed in to change notification settings - Fork 124
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
NEOS-1643: Integrates JobHooks into Data Sync Workflow (#2991)
- Loading branch information
Showing
11 changed files
with
567 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
package license | ||
|
||
type CascadeLicense struct { | ||
isValid bool | ||
} | ||
|
||
// Checks multiple licenses in input order to see if any are valid | ||
func NewCascadeLicense(licenses ...EEInterface) *CascadeLicense { | ||
isValid := false | ||
for _, l := range licenses { | ||
if l.IsValid() { | ||
isValid = true | ||
break | ||
} | ||
} | ||
return &CascadeLicense{isValid: isValid} | ||
} | ||
|
||
func (c *CascadeLicense) IsValid() bool { | ||
return c.isValid | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package license | ||
|
||
import "github.com/spf13/viper" | ||
|
||
// Conforms to the EE License for Neosync Cloud | ||
type CloudLicense struct { | ||
isCloud bool | ||
} | ||
|
||
var _ EEInterface = (*CloudLicense)(nil) | ||
|
||
func NewCloudLicense(isCloud bool) *CloudLicense { | ||
return &CloudLicense{isCloud: isCloud} | ||
} | ||
|
||
func NewCloudLicenseFromEnv() *CloudLicense { | ||
return &CloudLicense{isCloud: viper.GetBool("NEOSYNC_CLOUD")} | ||
} | ||
|
||
func (c *CloudLicense) IsValid() bool { | ||
return c.isCloud | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
185 changes: 185 additions & 0 deletions
185
worker/pkg/workflows/datasync/activities/jobhooks-by-timing/activity.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
package jobhooks_by_timing_activity | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"log/slog" | ||
"time" | ||
|
||
"connectrpc.com/connect" | ||
mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" | ||
"github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect" | ||
neosynclogger "github.com/nucleuscloud/neosync/backend/pkg/logger" | ||
"github.com/nucleuscloud/neosync/backend/pkg/sqlmanager" | ||
"go.temporal.io/sdk/activity" | ||
"go.temporal.io/sdk/log" | ||
) | ||
|
||
type License interface { | ||
IsValid() bool | ||
} | ||
|
||
type Activity struct { | ||
jobclient mgmtv1alpha1connect.JobServiceClient | ||
connclient mgmtv1alpha1connect.ConnectionServiceClient | ||
sqlmanagerclient sqlmanager.SqlManagerClient | ||
license License | ||
} | ||
|
||
func New( | ||
jobclient mgmtv1alpha1connect.JobServiceClient, | ||
connclient mgmtv1alpha1connect.ConnectionServiceClient, | ||
sqlmanagerclient sqlmanager.SqlManagerClient, | ||
license License, | ||
) *Activity { | ||
return &Activity{jobclient: jobclient, connclient: connclient, sqlmanagerclient: sqlmanagerclient, license: license} | ||
} | ||
|
||
type RunJobHooksByTimingRequest struct { | ||
JobId string | ||
Timing mgmtv1alpha1.GetActiveJobHooksByTimingRequest_Timing | ||
} | ||
|
||
type RunJobHooksByTimingResponse struct { | ||
ExecCount uint | ||
} | ||
|
||
// Runs active job hooks by the provided timing value | ||
func (a *Activity) RunJobHooksByTiming( | ||
ctx context.Context, | ||
req *RunJobHooksByTimingRequest, | ||
) (*RunJobHooksByTimingResponse, error) { | ||
activityInfo := activity.GetInfo(ctx) | ||
timingName, ok := mgmtv1alpha1.GetActiveJobHooksByTimingRequest_Timing_name[int32(req.Timing)] | ||
if !ok { | ||
return nil, fmt.Errorf("timing was invalid and not resolvable: %d", req.Timing) | ||
} | ||
loggerKeyVals := []any{ | ||
"WorkflowID", activityInfo.WorkflowExecution.ID, | ||
"RunID", activityInfo.WorkflowExecution.RunID, | ||
"jobId", req.JobId, | ||
"timing", timingName, | ||
} | ||
logger := log.With( | ||
activity.GetLogger(ctx), | ||
loggerKeyVals..., | ||
) | ||
slogger := neosynclogger.NewJsonSLogger().With(loggerKeyVals...) | ||
|
||
go func() { | ||
for { | ||
select { | ||
case <-time.After(1 * time.Second): | ||
activity.RecordHeartbeat(ctx) | ||
case <-activity.GetWorkerStopChannel(ctx): | ||
return | ||
case <-ctx.Done(): | ||
return | ||
} | ||
} | ||
}() | ||
if !a.license.IsValid() { | ||
logger.Debug("skipping job hooks due to EE license not being active") | ||
return &RunJobHooksByTimingResponse{ExecCount: 0}, nil | ||
} | ||
|
||
logger.Debug(fmt.Sprintf("retrieving job hooks by timing %q", req.Timing)) | ||
|
||
resp, err := a.jobclient.GetActiveJobHooksByTiming(ctx, connect.NewRequest(&mgmtv1alpha1.GetActiveJobHooksByTimingRequest{ | ||
JobId: req.JobId, | ||
Timing: req.Timing, | ||
})) | ||
if err != nil { | ||
return nil, fmt.Errorf("unable to retrieve active hooks by timing: %w", err) | ||
} | ||
hooks := resp.Msg.GetHooks() | ||
logger.Debug(fmt.Sprintf("found %d active hooks", len(hooks))) | ||
|
||
connections := make(map[string]*sqlmanager.SqlConnection) | ||
defer func() { | ||
for _, conn := range connections { | ||
conn.Db.Close() | ||
} | ||
}() | ||
|
||
execCount := uint(0) | ||
|
||
for _, hook := range hooks { | ||
logger.Debug(fmt.Sprintf("running hook %q", hook.GetName())) | ||
logger := log.With(logger, "hookName", hook.GetName()) | ||
|
||
switch hookConfig := hook.GetConfig().GetConfig().(type) { | ||
case *mgmtv1alpha1.JobHookConfig_Sql: | ||
logger.Debug("running SQL hook") | ||
if hookConfig.Sql == nil { | ||
return nil, errors.New("SQL hook config has undefined SQL configuration") | ||
} | ||
if err := a.executeSqlHook( | ||
ctx, | ||
hookConfig.Sql, | ||
a.getCachedConnectionFn(connections, logger, slogger), | ||
); err != nil { | ||
return nil, fmt.Errorf("unable to execute sql hook: %w", err) | ||
} | ||
execCount++ | ||
default: | ||
logger.Warn(fmt.Sprintf("hook config with type %T is not currently supported!", hookConfig)) | ||
} | ||
} | ||
|
||
return &RunJobHooksByTimingResponse{ExecCount: execCount}, nil | ||
} | ||
|
||
// Given a connection id, returns an initialized sql database connection | ||
type getSqlDbFromConnectionId = func(ctx context.Context, connectionId string) (sqlmanager.SqlDatabase, error) | ||
|
||
func (a *Activity) getCachedConnectionFn( | ||
connections map[string]*sqlmanager.SqlConnection, | ||
logger log.Logger, | ||
slogger *slog.Logger, | ||
) getSqlDbFromConnectionId { | ||
return func(ctx context.Context, connectionId string) (sqlmanager.SqlDatabase, error) { | ||
conn, ok := connections[connectionId] | ||
if ok { | ||
logger.Debug("found cached connection when running hook") | ||
return conn.Db, nil | ||
} | ||
logger.Debug("initializing connection for hook") | ||
connectionResp, err := a.connclient.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ | ||
Id: connectionId, | ||
})) | ||
if err != nil { | ||
return nil, err | ||
} | ||
connection := connectionResp.Msg.GetConnection() | ||
sqlconnection, err := a.sqlmanagerclient.NewPooledSqlDb( | ||
ctx, | ||
slogger.With( | ||
"connectionId", connection.GetId(), | ||
"accountId", connection.GetAccountId(), | ||
), | ||
connection, | ||
) | ||
if err != nil { | ||
return nil, fmt.Errorf("unable to initialize pooled sql connection: %W", err) | ||
} | ||
connections[connectionId] = sqlconnection | ||
return sqlconnection.Db, nil | ||
} | ||
} | ||
|
||
func (a *Activity) executeSqlHook( | ||
ctx context.Context, | ||
hook *mgmtv1alpha1.JobHookConfig_JobSqlHook, | ||
getSqlConnection getSqlDbFromConnectionId, | ||
) error { | ||
db, err := getSqlConnection(ctx, hook.GetConnectionId()) | ||
if err != nil { | ||
return err | ||
} | ||
if err := db.Exec(ctx, hook.GetQuery()); err != nil { | ||
return fmt.Errorf("unable to execute SQL hook statement: %w", err) | ||
} | ||
return nil | ||
} |
Oops, something went wrong.