Skip to content

Commit

Permalink
Schema: return error rather than caller having to check Fields == nil
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Dec 27, 2024
1 parent bc56003 commit 9674a51
Show file tree
Hide file tree
Showing 14 changed files with 83 additions and 49 deletions.
11 changes: 6 additions & 5 deletions flow/connectors/bigquery/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ func (c *BigQueryConnector) SyncQRepRecords(
) (int, error) {
// Ensure the destination table is available.
destTable := config.DestinationTableIdentifier
srcSchema := stream.Schema()
if srcSchema.Fields == nil {
return 0, stream.Err()
srcSchema, err := stream.Schema()
if err != nil {
return 0, err
}

tblMetadata, err := c.replayTableSchemaDeltasQRep(ctx, config, partition, srcSchema)
Expand Down Expand Up @@ -83,8 +83,9 @@ func (c *BigQueryConnector) replayTableSchemaDeltasQRep(
}
}

err = c.ReplayTableSchemaDeltas(ctx, config.Env, config.FlowJobName, []*protos.TableSchemaDelta{tableSchemaDelta})
if err != nil {
if err := c.ReplayTableSchemaDeltas(
ctx, config.Env, config.FlowJobName, []*protos.TableSchemaDelta{tableSchemaDelta},
); err != nil {
return nil, fmt.Errorf("failed to add columns to destination table: %w", err)
}
dstTableMetadata, err = bqTable.Metadata(ctx)
Expand Down
13 changes: 9 additions & 4 deletions flow/connectors/clickhouse/qrep_avro_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ func (s *ClickHouseAvroSyncMethod) SyncRecords(
) (int, error) {
dstTableName := s.config.DestinationTableIdentifier

schema := stream.Schema()
if schema.Fields == nil {
return 0, stream.Err()
schema, err := stream.Schema()
if err != nil {
return 0, err
}
s.logger.Info("sync function called and schema acquired",
slog.String("dstTable", dstTableName))
Expand Down Expand Up @@ -109,7 +109,12 @@ func (s *ClickHouseAvroSyncMethod) SyncQRepRecords(
stagingPath := s.credsProvider.BucketPath
startTime := time.Now()

avroSchema, err := s.getAvroSchema(ctx, config.Env, dstTableName, stream.Schema())
schema, err := stream.Schema()
if err != nil {
return 0, err
}

avroSchema, err := s.getAvroSchema(ctx, config.Env, dstTableName, schema)
if err != nil {
return 0, err
}
Expand Down
6 changes: 3 additions & 3 deletions flow/connectors/elasticsearch/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ func (esc *ElasticsearchConnector) SyncQRepRecords(ctx context.Context, config *
) (int, error) {
startTime := time.Now()

schema := stream.Schema()
if schema.Fields == nil {
return 0, stream.Err()
schema, err := stream.Schema()
if err != nil {
return 0, err
}

var bulkIndexFatalError error
Expand Down
6 changes: 3 additions & 3 deletions flow/connectors/kafka/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ func (c *KafkaConnector) SyncQRepRecords(
) (int, error) {
startTime := time.Now()
numRecords := atomic.Int64{}
schema := stream.Schema()
if schema.Fields == nil {
return 0, stream.Err()
schema, err := stream.Schema()
if err != nil {
return 0, err
}

queueCtx, queueErr := context.WithCancelCause(ctx)
Expand Down
10 changes: 6 additions & 4 deletions flow/connectors/postgres/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type QRepPullSink interface {
}

type QRepSyncSink interface {
GetColumnNames() []string
GetColumnNames() ([]string, error)
CopyInto(context.Context, *PostgresConnector, pgx.Tx, pgx.Identifier) (int64, error)
}

Expand Down Expand Up @@ -550,7 +550,10 @@ func syncQRepRecords(
upsertMatchCols[col] = struct{}{}
}

columnNames := sink.GetColumnNames()
columnNames, err := sink.GetColumnNames()
if err != nil {
return -1, fmt.Errorf("faild to get column names: %w", err)
}
setClauseArray := make([]string, 0, len(upsertMatchColsList)+1)
selectStrArray := make([]string, 0, len(columnNames))
for _, col := range columnNames {
Expand Down Expand Up @@ -578,8 +581,7 @@ func syncQRepRecords(
setClause,
)
c.logger.Info("Performing upsert operation", slog.String("upsertStmt", upsertStmt), syncLog)
_, err := tx.Exec(ctx, upsertStmt)
if err != nil {
if _, err := tx.Exec(ctx, upsertStmt); err != nil {
return -1, fmt.Errorf("failed to perform upsert operation: %w", err)
}
}
Expand Down
20 changes: 13 additions & 7 deletions flow/connectors/postgres/qrep_query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQuery(
args ...interface{},
) (*model.QRecordBatch, error) {
stream := model.NewQRecordStream(1024)
errors := make(chan error, 1)
errors := make(chan struct{})
var errorsError error
qe.logger.Info("Executing and processing query", slog.String("query", query))

// must wait on errors to close before returning to maintain qe.conn exclusion
Expand All @@ -233,23 +234,28 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQuery(
_, err := qe.ExecuteAndProcessQueryStream(ctx, stream, query, args...)
if err != nil {
qe.logger.Error("[pg_query_executor] failed to execute and process query stream", slog.Any("error", err))
errors <- err
errorsError = err
}
}()

select {
case err := <-errors:
return nil, err
case <-errors:
return nil, errorsError
case <-stream.SchemaChan():
schema, err := stream.Schema()
if err != nil {
return nil, err
}
batch := &model.QRecordBatch{
Schema: stream.Schema(),
Schema: schema,
Records: nil,
}
for record := range stream.Records {
batch.Records = append(batch.Records, record)
}
if err := <-errors; err != nil {
return nil, err
<-errors
if errorsError != nil {
return nil, errorsError
}
if err := stream.Err(); err != nil {
return nil, fmt.Errorf("[pg] failed to get record from stream: %w", err)
Expand Down
12 changes: 9 additions & 3 deletions flow/connectors/postgres/sink_pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

type PgCopyShared struct {
schemaLatch chan struct{}
err error
schema []string
schemaSet bool
}
Expand Down Expand Up @@ -109,15 +110,20 @@ func (p PgCopyWriter) ExecuteQueryWithTx(

func (p PgCopyWriter) Close(err error) {
p.PipeWriter.CloseWithError(err)
p.schema.err = err
p.SetSchema(nil)
}

func (p PgCopyReader) GetColumnNames() []string {
func (p PgCopyReader) GetColumnNames() ([]string, error) {
<-p.schema.schemaLatch
return p.schema.schema
return p.schema.schema, p.schema.err
}

func (p PgCopyReader) CopyInto(ctx context.Context, c *PostgresConnector, tx pgx.Tx, table pgx.Identifier) (int64, error) {
cols := p.GetColumnNames()
cols, err := p.GetColumnNames()
if err != nil {
return 0, err
}
quotedCols := make([]string, 0, len(cols))
for _, col := range cols {
quotedCols = append(quotedCols, QuoteIdentifier(col))
Expand Down
14 changes: 11 additions & 3 deletions flow/connectors/postgres/sink_q.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,17 @@ func (stream RecordStreamSink) ExecuteQueryWithTx(
}

func (stream RecordStreamSink) CopyInto(ctx context.Context, _ *PostgresConnector, tx pgx.Tx, table pgx.Identifier) (int64, error) {
return tx.CopyFrom(ctx, table, stream.GetColumnNames(), model.NewQRecordCopyFromSource(stream.QRecordStream))
columnNames, err := stream.GetColumnNames()
if err != nil {
return 0, err
}
return tx.CopyFrom(ctx, table, columnNames, model.NewQRecordCopyFromSource(stream.QRecordStream))
}

func (stream RecordStreamSink) GetColumnNames() []string {
return stream.Schema().GetColumnNames()
func (stream RecordStreamSink) GetColumnNames() ([]string, error) {
schema, err := stream.Schema()
if err != nil {
return nil, err
}
return schema.GetColumnNames(), nil
}
6 changes: 3 additions & 3 deletions flow/connectors/pubsub/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ func (c *PubSubConnector) SyncQRepRecords(
stream *model.QRecordStream,
) (int, error) {
startTime := time.Now()
schema := stream.Schema()
if schema.Fields == nil {
return 0, stream.Err()
schema, err := stream.Schema()
if err != nil {
return 0, err
}
topiccache := topicCache{cache: make(map[string]*pubsub.Topic)}
publish := make(chan publishResult, 32)
Expand Down
6 changes: 3 additions & 3 deletions flow/connectors/s3/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ func (c *S3Connector) SyncQRepRecords(
partition *protos.QRepPartition,
stream *model.QRecordStream,
) (int, error) {
schema := stream.Schema()
if schema.Fields == nil {
return 0, stream.Err()
schema, err := stream.Schema()
if err != nil {
return 0, err
}

dstTableName := config.DestinationTableIdentifier
Expand Down
12 changes: 7 additions & 5 deletions flow/connectors/snowflake/qrep_avro_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ func (s *SnowflakeAvroSyncHandler) SyncRecords(
tableLog := slog.String("destinationTable", s.config.DestinationTableIdentifier)
dstTableName := s.config.DestinationTableIdentifier

schema := stream.Schema()
if schema.Fields == nil {
schema, err := stream.Schema()
if err != nil {
return 0, stream.Err()
}

Expand Down Expand Up @@ -98,11 +98,13 @@ func (s *SnowflakeAvroSyncHandler) SyncQRepRecords(
startTime := time.Now()
dstTableName := config.DestinationTableIdentifier

schema := stream.Schema()
schema, err := stream.Schema()
if err != nil {
return 0, err
}
s.logger.Info("sync function called and schema acquired", partitionLog)

err := s.addMissingColumns(ctx, config.Env, schema, dstTableSchema, dstTableName, partition)
if err != nil {
if err := s.addMissingColumns(ctx, config.Env, schema, dstTableSchema, dstTableName, partition); err != nil {
return 0, err
}

Expand Down
6 changes: 3 additions & 3 deletions flow/connectors/utils/avro/avro_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ func (p *peerDBOCFWriter) createOCFWriter(w io.Writer) (*goavro.OCFWriter, error

func (p *peerDBOCFWriter) writeRecordsToOCFWriter(ctx context.Context, env map[string]string, ocfWriter *goavro.OCFWriter) (int64, error) {
logger := shared.LoggerFromCtx(ctx)
schema := p.stream.Schema()
if schema.Fields == nil {
return 0, p.stream.Err()
schema, err := p.stream.Schema()
if err != nil {
return 0, err
}

avroConverter, err := model.NewQRecordAvroConverter(
Expand Down
4 changes: 2 additions & 2 deletions flow/model/qrecord_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ func NewQRecordStream(buffer int) *QRecordStream {
}
}

func (s *QRecordStream) Schema() qvalue.QRecordSchema {
func (s *QRecordStream) Schema() (qvalue.QRecordSchema, error) {
<-s.schemaLatch
return s.schema
return s.schema, s.Err()
}

func (s *QRecordStream) SetSchema(schema qvalue.QRecordSchema) {
Expand Down
6 changes: 5 additions & 1 deletion flow/pua/stream_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ import (
func AttachToStream(ls *lua.LState, lfn *lua.LFunction, stream *model.QRecordStream) *model.QRecordStream {
output := model.NewQRecordStream(0)
go func() {
schema := stream.Schema()
schema, err := stream.Schema()
if err != nil {
output.Close(err)
return
}
output.SetSchema(schema)
for record := range stream.Records {
row := model.NewRecordItems(len(record))
Expand Down

0 comments on commit 9674a51

Please sign in to comment.