Skip to content

Commit

Permalink
NEOS-1643: Integrates JobHooks into Data Sync Workflow (#2991)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickzelei authored Nov 25, 2024
1 parent 0f405b0 commit 65cc55e
Show file tree
Hide file tree
Showing 11 changed files with 567 additions and 55 deletions.
1 change: 1 addition & 0 deletions backend/internal/cmds/mgmt/serve/connect/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ func serve(ctx context.Context) error {
mgmtv1alpha1connect.UserAccountServiceSetBillingMeterEventProcedure,
mgmtv1alpha1connect.MetricsServiceGetDailyMetricCountProcedure,
mgmtv1alpha1connect.AnonymizationServiceAnonymizeManyProcedure,
mgmtv1alpha1connect.JobServiceGetActiveJobHooksByTimingProcedure,
})
stdAuthInterceptors = append(
stdAuthInterceptors,
Expand Down
21 changes: 21 additions & 0 deletions internal/ee/license/cascade.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package license

type CascadeLicense struct {
isValid bool
}

// Checks multiple licenses in input order to see if any are valid
func NewCascadeLicense(licenses ...EEInterface) *CascadeLicense {
isValid := false
for _, l := range licenses {
if l.IsValid() {
isValid = true
break
}
}
return &CascadeLicense{isValid: isValid}
}

func (c *CascadeLicense) IsValid() bool {
return c.isValid
}
22 changes: 22 additions & 0 deletions internal/ee/license/cloud.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package license

import "github.com/spf13/viper"

// Conforms to the EE License for Neosync Cloud
type CloudLicense struct {
isCloud bool
}

var _ EEInterface = (*CloudLicense)(nil)

func NewCloudLicense(isCloud bool) *CloudLicense {
return &CloudLicense{isCloud: isCloud}
}

func NewCloudLicenseFromEnv() *CloudLicense {
return &CloudLicense{isCloud: viper.GetBool("NEOSYNC_CLOUD")}
}

func (c *CloudLicense) IsValid() bool {
return c.isCloud
}
10 changes: 9 additions & 1 deletion worker/internal/cmds/worker/serve/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
neosyncotel "github.com/nucleuscloud/neosync/internal/otel"
accountstatus_activity "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/account-status"
genbenthosconfigs_activity "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/gen-benthos-configs"
jobhooks_by_timing_activity "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/jobhooks-by-timing"
posttablesync_activity "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/post-table-sync"
runsqlinittablestmts_activity "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts"
"github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/shared"
Expand Down Expand Up @@ -257,6 +258,11 @@ func serve(ctx context.Context) error {
w := worker.New(temporalClient, taskQueue, worker.Options{})
_ = w

cascadelicense := license.NewCascadeLicense(
license.NewCloudLicenseFromEnv(),
eelicense,
)

pgpoolmap := &sync.Map{}
mysqlpoolmap := &sync.Map{}
mssqlpoolmap := &sync.Map{}
Expand Down Expand Up @@ -286,9 +292,10 @@ func serve(ctx context.Context) error {
disableReaper := false
syncActivity := sync_activity.New(connclient, jobclient, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, temporalClient, syncActivityMeter, sync_activity.NewBenthosStreamManager(), disableReaper)
retrieveActivityOpts := syncactivityopts_activity.New(jobclient)
runSqlInitTableStatements := runsqlinittablestmts_activity.New(jobclient, connclient, sqlmanager, eelicense, isNeosyncCloud)
runSqlInitTableStatements := runsqlinittablestmts_activity.New(jobclient, connclient, sqlmanager, cascadelicense)
accountStatusActivity := accountstatus_activity.New(userclient)
runPostTableSyncActivity := posttablesync_activity.New(jobclient, sqlmanager, connclient)
jobhookByTimingActivity := jobhooks_by_timing_activity.New(jobclient, connclient, sqlmanager, cascadelicense)

w.RegisterWorkflow(datasync_workflow.Workflow)
w.RegisterActivity(syncActivity.Sync)
Expand All @@ -298,6 +305,7 @@ func serve(ctx context.Context) error {
w.RegisterActivity(genbenthosActivity.GenerateBenthosConfigs)
w.RegisterActivity(accountStatusActivity.CheckAccountStatus)
w.RegisterActivity(runPostTableSyncActivity.RunPostTableSync)
w.RegisterActivity(jobhookByTimingActivity.RunJobHooksByTiming)

if err := w.Start(); err != nil {
return fmt.Errorf("unable to start temporal worker: %w", err)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package jobhooks_by_timing_activity

import (
"context"
"errors"
"fmt"
"log/slog"
"time"

"connectrpc.com/connect"
mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
neosynclogger "github.com/nucleuscloud/neosync/backend/pkg/logger"
"github.com/nucleuscloud/neosync/backend/pkg/sqlmanager"
"go.temporal.io/sdk/activity"
"go.temporal.io/sdk/log"
)

type License interface {
IsValid() bool
}

type Activity struct {
jobclient mgmtv1alpha1connect.JobServiceClient
connclient mgmtv1alpha1connect.ConnectionServiceClient
sqlmanagerclient sqlmanager.SqlManagerClient
license License
}

func New(
jobclient mgmtv1alpha1connect.JobServiceClient,
connclient mgmtv1alpha1connect.ConnectionServiceClient,
sqlmanagerclient sqlmanager.SqlManagerClient,
license License,
) *Activity {
return &Activity{jobclient: jobclient, connclient: connclient, sqlmanagerclient: sqlmanagerclient, license: license}
}

type RunJobHooksByTimingRequest struct {
JobId string
Timing mgmtv1alpha1.GetActiveJobHooksByTimingRequest_Timing
}

type RunJobHooksByTimingResponse struct {
ExecCount uint
}

// Runs active job hooks by the provided timing value
func (a *Activity) RunJobHooksByTiming(
ctx context.Context,
req *RunJobHooksByTimingRequest,
) (*RunJobHooksByTimingResponse, error) {
activityInfo := activity.GetInfo(ctx)
timingName, ok := mgmtv1alpha1.GetActiveJobHooksByTimingRequest_Timing_name[int32(req.Timing)]
if !ok {
return nil, fmt.Errorf("timing was invalid and not resolvable: %d", req.Timing)
}
loggerKeyVals := []any{
"WorkflowID", activityInfo.WorkflowExecution.ID,
"RunID", activityInfo.WorkflowExecution.RunID,
"jobId", req.JobId,
"timing", timingName,
}
logger := log.With(
activity.GetLogger(ctx),
loggerKeyVals...,
)
slogger := neosynclogger.NewJsonSLogger().With(loggerKeyVals...)

go func() {
for {
select {
case <-time.After(1 * time.Second):
activity.RecordHeartbeat(ctx)
case <-activity.GetWorkerStopChannel(ctx):
return
case <-ctx.Done():
return
}
}
}()
if !a.license.IsValid() {
logger.Debug("skipping job hooks due to EE license not being active")
return &RunJobHooksByTimingResponse{ExecCount: 0}, nil
}

logger.Debug(fmt.Sprintf("retrieving job hooks by timing %q", req.Timing))

resp, err := a.jobclient.GetActiveJobHooksByTiming(ctx, connect.NewRequest(&mgmtv1alpha1.GetActiveJobHooksByTimingRequest{
JobId: req.JobId,
Timing: req.Timing,
}))
if err != nil {
return nil, fmt.Errorf("unable to retrieve active hooks by timing: %w", err)
}
hooks := resp.Msg.GetHooks()
logger.Debug(fmt.Sprintf("found %d active hooks", len(hooks)))

connections := make(map[string]*sqlmanager.SqlConnection)
defer func() {
for _, conn := range connections {
conn.Db.Close()
}
}()

execCount := uint(0)

for _, hook := range hooks {
logger.Debug(fmt.Sprintf("running hook %q", hook.GetName()))
logger := log.With(logger, "hookName", hook.GetName())

switch hookConfig := hook.GetConfig().GetConfig().(type) {
case *mgmtv1alpha1.JobHookConfig_Sql:
logger.Debug("running SQL hook")
if hookConfig.Sql == nil {
return nil, errors.New("SQL hook config has undefined SQL configuration")
}
if err := a.executeSqlHook(
ctx,
hookConfig.Sql,
a.getCachedConnectionFn(connections, logger, slogger),
); err != nil {
return nil, fmt.Errorf("unable to execute sql hook: %w", err)
}
execCount++
default:
logger.Warn(fmt.Sprintf("hook config with type %T is not currently supported!", hookConfig))
}
}

return &RunJobHooksByTimingResponse{ExecCount: execCount}, nil
}

// Given a connection id, returns an initialized sql database connection
type getSqlDbFromConnectionId = func(ctx context.Context, connectionId string) (sqlmanager.SqlDatabase, error)

func (a *Activity) getCachedConnectionFn(
connections map[string]*sqlmanager.SqlConnection,
logger log.Logger,
slogger *slog.Logger,
) getSqlDbFromConnectionId {
return func(ctx context.Context, connectionId string) (sqlmanager.SqlDatabase, error) {
conn, ok := connections[connectionId]
if ok {
logger.Debug("found cached connection when running hook")
return conn.Db, nil
}
logger.Debug("initializing connection for hook")
connectionResp, err := a.connclient.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{
Id: connectionId,
}))
if err != nil {
return nil, err
}
connection := connectionResp.Msg.GetConnection()
sqlconnection, err := a.sqlmanagerclient.NewPooledSqlDb(
ctx,
slogger.With(
"connectionId", connection.GetId(),
"accountId", connection.GetAccountId(),
),
connection,
)
if err != nil {
return nil, fmt.Errorf("unable to initialize pooled sql connection: %W", err)
}
connections[connectionId] = sqlconnection
return sqlconnection.Db, nil
}
}

func (a *Activity) executeSqlHook(
ctx context.Context,
hook *mgmtv1alpha1.JobHookConfig_JobSqlHook,
getSqlConnection getSqlDbFromConnectionId,
) error {
db, err := getSqlConnection(ctx, hook.GetConnectionId())
if err != nil {
return err
}
if err := db.Exec(ctx, hook.GetQuery()); err != nil {
return fmt.Errorf("unable to execute SQL hook statement: %w", err)
}
return nil
}
Loading

0 comments on commit 65cc55e

Please sign in to comment.