From d01425da9884c3f2c417095e4bd285ef47cfcd4e Mon Sep 17 00:00:00 2001 From: Milan Sreckovic <39279049+milanatshopify@users.noreply.github.com> Date: Fri, 22 Aug 2025 20:26:47 -0400 Subject: [PATCH] POC composite key pagination --- config.go | 28 ++++++ cursor.go | 214 ++++++++++++++++++++++++++++++++++++++---- data_iterator.go | 57 +++++++++-- inline_verifier.go | 14 +++ row_batch.go | 5 + state_tracker.go | 61 +++++++++++- table_schema_cache.go | 179 +++++++++++++++++++++++++++++++---- 7 files changed, 511 insertions(+), 47 deletions(-) diff --git a/config.go b/config.go index d9351f01a..23addf359 100644 --- a/config.go +++ b/config.go @@ -377,6 +377,10 @@ func (c ForceIndexConfig) IndexFor(schemaName, tableName string) string { // used. The term `Cascading` to denote that greater specificity takes // precedence. type CascadingPaginationColumnConfig struct { + // PerTableComposite has highest specificity for composite keys (max 3 columns) + // SchemaName => TableName => [ColumnName1, ColumnName2, ColumnName3] + PerTableComposite map[string]map[string][]string + // PerTable has greatest specificity and takes precedence over the other options PerTable map[string]map[string]string // SchemaName => TableName => ColumnName @@ -404,6 +408,30 @@ func (c *CascadingPaginationColumnConfig) PaginationColumnFor(schemaName, tableN return column, true } +// CompositePaginationColumnsFor retrieves composite pagination columns for a table +func (c *CascadingPaginationColumnConfig) CompositePaginationColumnsFor(schemaName, tableName string) ([]string, bool) { + if c == nil || c.PerTableComposite == nil { + return nil, false + } + + tableConfig, found := c.PerTableComposite[schemaName] + if !found { + return nil, false + } + + columns, found := tableConfig[tableName] + if !found { + return nil, false + } + + // Validate max 3 columns + if len(columns) > 3 || len(columns) == 0 { + return nil, false + } + + return columns, true +} + // FallbackPaginationColumnName retreives the column name specified as a fallback when the Primary Key isn't suitable for pagination func (c *CascadingPaginationColumnConfig) FallbackPaginationColumnName() (string, bool) { if c == nil || c.FallbackColumn == "" { diff --git a/cursor.go b/cursor.go index 9a7a72ed1..20bea3a21 100644 --- a/cursor.go +++ b/cursor.go @@ -64,6 +64,29 @@ func (c *CursorConfig) NewCursorWithoutRowLock(table *TableSchema, startPaginati return cursor } +// NewCompositeCursor creates a cursor for composite key pagination +func (c *CursorConfig) NewCompositeCursor(table *TableSchema, startKeys, maxKeys CompositeKey) *Cursor { + cursor := &Cursor{ + CursorConfig: *c, + Table: table, + RowLock: true, + isComposite: true, + lastSuccessfulCompositeKey: startKeys, + maxCompositeKey: maxKeys, + // Set single key values from first column for backward compatibility + MaxPaginationKey: maxKeys.Values[0].(uint64), + lastSuccessfulPaginationKey: startKeys.Values[0].(uint64), + } + return cursor +} + +// NewCompositeCursorWithoutRowLock creates a cursor for composite key pagination without row locks +func (c *CursorConfig) NewCompositeCursorWithoutRowLock(table *TableSchema, startKeys, maxKeys CompositeKey) *Cursor { + cursor := c.NewCompositeCursor(table, startKeys, maxKeys) + cursor.RowLock = false + return cursor +} + func (c CursorConfig) GetBatchSize(schemaName string, tableName string) uint64 { if c.BatchSizePerTableOverride != nil { if batchSize, found := c.BatchSizePerTableOverride.TableOverride[schemaName][tableName]; found { @@ -73,6 +96,39 @@ func (c CursorConfig) GetBatchSize(schemaName string, tableName string) uint64 { return *c.BatchSize } +// CompositeKey represents a composite pagination key +type CompositeKey struct { + Values []interface{} // Can be uint64 or string +} + +// NewCompositeKey creates a CompositeKey from values +func NewCompositeKey(values ...interface{}) CompositeKey { + return CompositeKey{Values: values} +} + +// IsLessThan compares two composite keys +func (c CompositeKey) IsLessThan(other CompositeKey) bool { + for i := 0; i < len(c.Values) && i < len(other.Values); i++ { + switch v1 := c.Values[i].(type) { + case uint64: + v2 := other.Values[i].(uint64) + if v1 < v2 { + return true + } else if v1 > v2 { + return false + } + case string: + v2 := other.Values[i].(string) + if v1 < v2 { + return true + } else if v1 > v2 { + return false + } + } + } + return false +} + type Cursor struct { CursorConfig @@ -80,23 +136,42 @@ type Cursor struct { MaxPaginationKey uint64 RowLock bool + // Single column pagination (backward compatibility) paginationKeyColumn *schema.TableColumn lastSuccessfulPaginationKey uint64 + + // Composite key pagination + isComposite bool + maxCompositeKey CompositeKey + lastSuccessfulCompositeKey CompositeKey + logger *logrus.Entry } +// shouldContinue checks if cursor should continue iterating +func (c *Cursor) shouldContinue() bool { + if c.isComposite { + return c.lastSuccessfulCompositeKey.IsLessThan(c.maxCompositeKey) + } + return c.lastSuccessfulPaginationKey < c.MaxPaginationKey +} + func (c *Cursor) Each(f func(*RowBatch) error) error { c.logger = logrus.WithFields(logrus.Fields{ "table": c.Table.String(), "tag": "cursor", }) - c.paginationKeyColumn = c.Table.GetPaginationColumn() + + if !c.isComposite { + c.paginationKeyColumn = c.Table.GetPaginationColumn() + } if len(c.ColumnsToSelect) == 0 { c.ColumnsToSelect = []string{"*"} } - for c.lastSuccessfulPaginationKey < c.MaxPaginationKey { + // Use appropriate loop condition based on pagination type + for c.shouldContinue() { var tx SqlPreparerAndRollbacker var batch *RowBatch var paginationKeypos uint64 @@ -153,17 +228,50 @@ func (c *Cursor) Each(f func(*RowBatch) error) error { tx.Rollback() - c.lastSuccessfulPaginationKey = paginationKeypos + // Update pagination position + if c.isComposite { + c.updateCompositePosition(batch) + } else { + c.lastSuccessfulPaginationKey = paginationKeypos + } } return nil } +// updateCompositePosition updates the last successful composite key from the batch +func (c *Cursor) updateCompositePosition(batch *RowBatch) { + if batch.Size() == 0 { + return + } + + lastRow := batch.Values()[batch.Size()-1] + newKeys := make([]interface{}, len(c.Table.CompositePaginationIndexes)) + + for i, idx := range c.Table.CompositePaginationIndexes { + switch v := lastRow[idx].(type) { + case int64: + newKeys[i] = uint64(v) + default: + newKeys[i] = v + } + } + + c.lastSuccessfulCompositeKey = NewCompositeKey(newKeys...) + // Update single key for backward compatibility + if v, ok := newKeys[0].(uint64); ok { + c.lastSuccessfulPaginationKey = v + } +} + func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64, err error) { var selectBuilder squirrel.SelectBuilder batchSize := c.CursorConfig.GetBatchSize(c.Table.Schema, c.Table.Name) - if c.BuildSelect != nil { + if c.isComposite { + // Use composite pagination + selectBuilder = DefaultBuildSelectComposite(c.ColumnsToSelect, c.Table, c.lastSuccessfulCompositeKey.Values, batchSize) + } else if c.BuildSelect != nil { selectBuilder, err = c.BuildSelect(c.ColumnsToSelect, c.Table, c.lastSuccessfulPaginationKey, batchSize) if err != nil { c.logger.WithError(err).Error("failed to apply filter for select") @@ -229,17 +337,43 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64 } var paginationKeyIndex int = -1 - for idx, col := range columns { - if col == c.paginationKeyColumn.Name { - paginationKeyIndex = idx - break + var compositePaginationIndexes []int + + if c.isComposite { + // Find all composite pagination column indexes + compositePaginationIndexes = make([]int, len(c.Table.CompositePaginationColumns)) + for i, paginationCol := range c.Table.CompositePaginationColumns { + found := false + for idx, col := range columns { + if col == paginationCol.Name { + compositePaginationIndexes[i] = idx + if i == 0 { + paginationKeyIndex = idx // First column for backward compatibility + } + found = true + break + } + } + if !found { + err = fmt.Errorf("composite paginationKey column %s not found in columns: %v", paginationCol.Name, columns) + logger.WithError(err).Error("failed to get composite paginationKey index") + return + } + } + } else { + // Single pagination key + for idx, col := range columns { + if col == c.paginationKeyColumn.Name { + paginationKeyIndex = idx + break + } + } + + if paginationKeyIndex < 0 { + err = fmt.Errorf("paginationKey is not found during iteration with columns: %v", columns) + logger.WithError(err).Error("failed to get paginationKey index") + return } - } - - if paginationKeyIndex < 0 { - err = fmt.Errorf("paginationKey is not found during iteration with columns: %v", columns) - logger.WithError(err).Error("failed to get paginationKey index") - return } var rowData RowData @@ -269,10 +403,12 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos uint64 } batch = &RowBatch{ - values: batchData, - paginationKeyIndex: paginationKeyIndex, - table: c.Table, - columns: columns, + values: batchData, + paginationKeyIndex: paginationKeyIndex, + isCompositePagination: c.isComposite, + compositePaginationIndexes: compositePaginationIndexes, + table: c.Table, + columns: columns, } logger.Debugf("found %d rows", batch.Size()) @@ -304,6 +440,29 @@ func ScanByteRow(rows *sqlorig.Rows, columnCount int) ([][]byte, error) { return values, err } +// BuildCompositeTupleComparison creates a WHERE clause for composite key pagination +// For columns (a,b,c) and values (x,y,z), generates: +// WHERE a > x OR (a = x AND b > y) OR (a = x AND b = y AND c > z) +func BuildCompositeTupleComparison(columns []string, values []interface{}) squirrel.Or { + conditions := make(squirrel.Or, 0, len(columns)) + + for i := 0; i < len(columns); i++ { + condition := squirrel.And{} + + // Add equality conditions for all columns before the current one + for j := 0; j < i; j++ { + condition = append(condition, squirrel.Eq{columns[j]: values[j]}) + } + + // Add greater than condition for the current column + condition = append(condition, squirrel.Gt{columns[i]: values[i]}) + + conditions = append(conditions, condition) + } + + return conditions +} + func DefaultBuildSelect(columns []string, table *TableSchema, lastPaginationKey, batchSize uint64) squirrel.SelectBuilder { quotedPaginationKey := QuoteField(table.GetPaginationColumn().Name) @@ -313,3 +472,22 @@ func DefaultBuildSelect(columns []string, table *TableSchema, lastPaginationKey, Limit(batchSize). OrderBy(quotedPaginationKey) } + +// DefaultBuildSelectComposite builds a SELECT query for composite pagination +func DefaultBuildSelectComposite(columns []string, table *TableSchema, lastPaginationKeys []interface{}, batchSize uint64) squirrel.SelectBuilder { + quotedColumns := make([]string, len(table.CompositePaginationColumns)) + orderByColumns := make([]string, len(table.CompositePaginationColumns)) + + for i, col := range table.CompositePaginationColumns { + quotedColumns[i] = QuoteField(col.Name) + orderByColumns[i] = QuoteField(col.Name) + } + + whereClause := BuildCompositeTupleComparison(quotedColumns, lastPaginationKeys) + + return squirrel.Select(columns...). + From(QuotedTableName(table)). + Where(whereClause). + Limit(batchSize). + OrderBy(orderByColumns...) +} diff --git a/data_iterator.go b/data_iterator.go index 5621a24e0..9142cc741 100644 --- a/data_iterator.go +++ b/data_iterator.go @@ -42,7 +42,7 @@ func (d *DataIterator) Run(tables []*TableSchema) { } d.logger.WithField("tablesCount", len(tables)).Info("starting data iterator run") - tablesWithData, emptyTables, err := MaxPaginationKeys(d.DB, tables, d.logger) + paginationData, emptyTables, err := MaxPaginationKeysWithComposite(d.DB, tables, d.logger) if err != nil { d.ErrorHandler.Fatal("data_iterator", err) } @@ -51,6 +51,7 @@ func (d *DataIterator) Run(tables []*TableSchema) { d.StateTracker.MarkTableAsCompleted(table.String()) } + tablesWithData := paginationData.SingleKeys for table, maxPaginationKey := range tablesWithData { tableName := table.String() if d.StateTracker.IsTableComplete(tableName) { @@ -59,6 +60,11 @@ func (d *DataIterator) Run(tables []*TableSchema) { delete(tablesWithData, table) } else { d.TargetPaginationKeys.Store(tableName, maxPaginationKey) + + // Store composite keys if present + if compositeKeys, found := paginationData.CompositeKeys[table]; found { + d.TargetPaginationKeys.Store(tableName+"_composite", NewCompositeKey(compositeKeys...)) + } } } @@ -86,15 +92,48 @@ func (d *DataIterator) Run(tables []*TableSchema) { return } - startPaginationKey := d.StateTracker.LastSuccessfulPaginationKey(table.String()) - if startPaginationKey == math.MaxUint64 { - err := fmt.Errorf("%v has been marked as completed but a table iterator has been spawned, this is likely a programmer error which resulted in the inconsistent starting state", table.String()) - logger.WithError(err).Error("this is definitely a bug") - d.ErrorHandler.Fatal("data_iterator", err) - return - } + var cursor *Cursor + + // Check if this table uses composite pagination + if table.IsCompositePagination { + // Get composite pagination keys + startKeys, found := d.StateTracker.LastSuccessfulCompositePaginationKey(table.String()) + if !found { + // Initialize with zero values for each column + startKeys = make([]interface{}, len(table.CompositePaginationColumns)) + for i, col := range table.CompositePaginationColumns { + if col.Type == schema.TYPE_STRING { + startKeys[i] = "" + } else { + startKeys[i] = uint64(0) + } + } + } + + // Get max composite keys from TargetPaginationKeys + maxKeysInterface, found := d.TargetPaginationKeys.Load(table.String() + "_composite") + if !found { + err := fmt.Errorf("%s composite keys not found in TargetPaginationKeys", table.String()) + logger.WithError(err).Error("composite pagination keys missing") + d.ErrorHandler.Fatal("data_iterator", err) + return + } + + cursor = d.CursorConfig.NewCompositeCursor(table, + NewCompositeKey(startKeys...), + maxKeysInterface.(CompositeKey)) + } else { + // Single column pagination + startPaginationKey := d.StateTracker.LastSuccessfulPaginationKey(table.String()) + if startPaginationKey == math.MaxUint64 { + err := fmt.Errorf("%v has been marked as completed but a table iterator has been spawned, this is likely a programmer error which resulted in the inconsistent starting state", table.String()) + logger.WithError(err).Error("this is definitely a bug") + d.ErrorHandler.Fatal("data_iterator", err) + return + } - cursor := d.CursorConfig.NewCursor(table, startPaginationKey, targetPaginationKeyInterface.(uint64)) + cursor = d.CursorConfig.NewCursor(table, startPaginationKey, targetPaginationKeyInterface.(uint64)) + } if d.SelectFingerprint { if len(cursor.ColumnsToSelect) == 0 { cursor.ColumnsToSelect = []string{"*"} diff --git a/inline_verifier.go b/inline_verifier.go index 552c88e6e..facbd5640 100644 --- a/inline_verifier.go +++ b/inline_verifier.go @@ -28,6 +28,7 @@ type BinlogVerifyStore struct { mutex *sync.Mutex // db => table => paginationKey => number of times it changed. + // For composite keys, we use a string representation of the keys // // We need to store the number of times the row has changed because of the // following series of events: @@ -101,6 +102,19 @@ type BinlogVerifyBatch struct { SchemaName string TableName string PaginationKeys []uint64 + + // Composite key support + IsComposite bool + CompositePaginationKeys []CompositeKey +} + +// Helper function to create composite key string for map storage +func compositeKeyString(keys []interface{}) string { + parts := make([]string, len(keys)) + for i, k := range keys { + parts[i] = fmt.Sprintf("%v", k) + } + return strings.Join(parts, ":") } func NewBinlogVerifyStore() *BinlogVerifyStore { diff --git a/row_batch.go b/row_batch.go index 4426fc127..e962406f1 100644 --- a/row_batch.go +++ b/row_batch.go @@ -8,6 +8,11 @@ import ( type RowBatch struct { values []RowData paginationKeyIndex int + + // Composite pagination support + isCompositePagination bool + compositePaginationIndexes []int + table *TableSchema fingerprints map[uint64][]byte columns []string diff --git a/state_tracker.go b/state_tracker.go index 760481a80..1cb0b0d1a 100644 --- a/state_tracker.go +++ b/state_tracker.go @@ -34,7 +34,12 @@ type SerializableState struct { GhostferryVersion string LastKnownTableSchemaCache TableSchemaCache + // Single column pagination (backward compatibility) LastSuccessfulPaginationKeys map[string]uint64 + + // Composite pagination support + LastSuccessfulCompositePaginationKeys map[string][]interface{} + CompletedTables map[string]bool LastWrittenBinlogPosition mysql.Position BinlogVerifyStore BinlogVerifySerializedStore @@ -92,7 +97,12 @@ type StateTracker struct { lastStoredBinlogPositionForInlineVerifier mysql.Position lastStoredBinlogPositionForTargetVerifier mysql.Position + // Single column pagination (backward compatibility) lastSuccessfulPaginationKeys map[string]uint64 + + // Composite pagination support + lastSuccessfulCompositePaginationKeys map[string][]interface{} + completedTables map[string]bool // TODO: Performance tracking should be refactored out of the state tracker, @@ -106,10 +116,11 @@ func NewStateTracker(speedLogCount int) *StateTracker { BinlogRWMutex: &sync.RWMutex{}, CopyRWMutex: &sync.RWMutex{}, - lastSuccessfulPaginationKeys: make(map[string]uint64), - completedTables: make(map[string]bool), - iterationSpeedLog: newSpeedLogRing(speedLogCount), - rowStatsWrittenPerTable: make(map[string]RowStats), + lastSuccessfulPaginationKeys: make(map[string]uint64), + lastSuccessfulCompositePaginationKeys: make(map[string][]interface{}), + completedTables: make(map[string]bool), + iterationSpeedLog: newSpeedLogRing(speedLogCount), + rowStatsWrittenPerTable: make(map[string]RowStats), } } @@ -118,6 +129,9 @@ func NewStateTracker(speedLogCount int) *StateTracker { func NewStateTrackerFromSerializedState(speedLogCount int, serializedState *SerializableState) *StateTracker { s := NewStateTracker(speedLogCount) s.lastSuccessfulPaginationKeys = serializedState.LastSuccessfulPaginationKeys + if serializedState.LastSuccessfulCompositePaginationKeys != nil { + s.lastSuccessfulCompositePaginationKeys = serializedState.LastSuccessfulCompositePaginationKeys + } s.completedTables = serializedState.CompletedTables s.lastWrittenBinlogPosition = serializedState.LastWrittenBinlogPosition s.lastStoredBinlogPositionForInlineVerifier = serializedState.LastStoredBinlogPositionForInlineVerifier @@ -162,6 +176,25 @@ func (s *StateTracker) UpdateLastSuccessfulPaginationKey(table string, paginatio s.updateSpeedLog(deltaPaginationKey) } +// UpdateLastSuccessfulCompositePaginationKey updates composite pagination key progress +func (s *StateTracker) UpdateLastSuccessfulCompositePaginationKey(table string, keys []interface{}, rowStats RowStats) { + s.CopyRWMutex.Lock() + defer s.CopyRWMutex.Unlock() + + s.lastSuccessfulCompositePaginationKeys[table] = keys + + // Also update single key for backward compatibility if first key is uint64 + if len(keys) > 0 { + if firstKey, ok := keys[0].(uint64); ok { + deltaPaginationKey := firstKey - s.lastSuccessfulPaginationKeys[table] + s.lastSuccessfulPaginationKeys[table] = firstKey + s.updateSpeedLog(deltaPaginationKey) + } + } + + s.updateRowStatsForTable(table, rowStats) +} + func (s *StateTracker) RowStatsWrittenPerTable() map[string]RowStats { s.CopyRWMutex.RLock() defer s.CopyRWMutex.RUnlock() @@ -191,6 +224,20 @@ func (s *StateTracker) LastSuccessfulPaginationKey(table string) uint64 { return paginationKey } +// LastSuccessfulCompositePaginationKey gets the last successful composite pagination key +func (s *StateTracker) LastSuccessfulCompositePaginationKey(table string) ([]interface{}, bool) { + s.CopyRWMutex.RLock() + defer s.CopyRWMutex.RUnlock() + + _, found := s.completedTables[table] + if found { + return nil, false + } + + keys, found := s.lastSuccessfulCompositePaginationKeys[table] + return keys, found +} + func (s *StateTracker) MarkTableAsCompleted(table string) { s.CopyRWMutex.Lock() defer s.CopyRWMutex.Unlock() @@ -264,6 +311,7 @@ func (s *StateTracker) Serialize(lastKnownTableSchemaCache TableSchemaCache, bin GhostferryVersion: VersionString, LastKnownTableSchemaCache: lastKnownTableSchemaCache, LastSuccessfulPaginationKeys: make(map[string]uint64), + LastSuccessfulCompositePaginationKeys: make(map[string][]interface{}), CompletedTables: make(map[string]bool), LastWrittenBinlogPosition: s.lastWrittenBinlogPosition, LastStoredBinlogPositionForInlineVerifier: s.lastStoredBinlogPositionForInlineVerifier, @@ -280,6 +328,11 @@ func (s *StateTracker) Serialize(lastKnownTableSchemaCache TableSchemaCache, bin for k, v := range s.lastSuccessfulPaginationKeys { state.LastSuccessfulPaginationKeys[k] = v } + + // Copy composite pagination keys + for k, v := range s.lastSuccessfulCompositePaginationKeys { + state.LastSuccessfulCompositePaginationKeys[k] = v + } for k, v := range s.completedTables { state.CompletedTables[k] = v diff --git a/table_schema_cache.go b/table_schema_cache.go index ca5b1df81..e8210fc51 100644 --- a/table_schema_cache.go +++ b/table_schema_cache.go @@ -40,8 +40,15 @@ type TableSchema struct { CompressedColumnsForVerification map[string]string // Map of column name => compression type IgnoredColumnsForVerification map[string]struct{} // Set of column name ForcedIndexForVerification string // Forced index name + + // Single column pagination (backward compatibility) PaginationKeyColumn *schema.TableColumn PaginationKeyIndex int + + // Composite pagination support (max 3 columns) + IsCompositePagination bool + CompositePaginationColumns []*schema.TableColumn + CompositePaginationIndexes []int rowMd5Query string } @@ -126,29 +133,74 @@ func QuotedTableNameFromString(database, table string) string { return fmt.Sprintf("`%s`.`%s`", database, table) } +// MaxPaginationKeysData holds both single and composite pagination keys +type MaxPaginationKeysData struct { + SingleKeys map[*TableSchema]uint64 + CompositeKeys map[*TableSchema][]interface{} +} + func MaxPaginationKeys(db *sql.DB, tables []*TableSchema, logger *logrus.Entry) (map[*TableSchema]uint64, []*TableSchema, error) { - tablesWithData := make(map[*TableSchema]uint64) + data, emptyTables, err := MaxPaginationKeysWithComposite(db, tables, logger) + if err != nil { + return nil, nil, err + } + return data.SingleKeys, emptyTables, nil +} + +func MaxPaginationKeysWithComposite(db *sql.DB, tables []*TableSchema, logger *logrus.Entry) (*MaxPaginationKeysData, []*TableSchema, error) { + data := &MaxPaginationKeysData{ + SingleKeys: make(map[*TableSchema]uint64), + CompositeKeys: make(map[*TableSchema][]interface{}), + } emptyTables := make([]*TableSchema, 0, len(tables)) for _, table := range tables { logger := logger.WithField("table", table.String()) - maxPaginationKey, maxPaginationKeyExists, err := maxPaginationKey(db, table) - if err != nil { - logger.WithError(err).Errorf("failed to get max primary key %s", table.GetPaginationColumn().Name) - return tablesWithData, emptyTables, err - } + if table.IsCompositePagination { + // Handle composite pagination + maxKeys, maxKeysExist, err := maxCompositePaginationKey(db, table) + if err != nil { + logger.WithError(err).Error("failed to get max composite pagination keys") + return nil, emptyTables, err + } - if !maxPaginationKeyExists { - emptyTables = append(emptyTables, table) - logger.Warn("no data in this table, skipping") - continue - } + if !maxKeysExist { + emptyTables = append(emptyTables, table) + logger.Warn("no data in this table, skipping") + continue + } - tablesWithData[table] = maxPaginationKey + data.CompositeKeys[table] = maxKeys + + // For backward compatibility, store first key as uint64 if possible + if len(maxKeys) > 0 { + if firstKey, ok := maxKeys[0].(uint64); ok { + data.SingleKeys[table] = firstKey + } else { + // Use MaxUint64-1 to indicate composite key table with data + data.SingleKeys[table] = math.MaxUint64 - 1 + } + } + } else { + // Handle single pagination key + maxPaginationKey, maxPaginationKeyExists, err := maxPaginationKey(db, table) + if err != nil { + logger.WithError(err).Errorf("failed to get max primary key %s", table.GetPaginationColumn().Name) + return nil, emptyTables, err + } + + if !maxPaginationKeyExists { + emptyTables = append(emptyTables, table) + logger.Warn("no data in this table, skipping") + continue + } + + data.SingleKeys[table] = maxPaginationKey + } } - return tablesWithData, emptyTables, nil + return data, emptyTables, nil } func LoadTables(db *sql.DB, tableFilter TableFilter, columnCompressionConfig ColumnCompressionConfig, columnIgnoreConfig ColumnIgnoreConfig, forceIndexConfig ForceIndexConfig, cascadingPaginationColumnConfig *CascadingPaginationColumnConfig) (TableSchemaCache, error) { @@ -216,13 +268,11 @@ func LoadTables(db *sql.DB, tableFilter TableFilter, columnCompressionConfig Col tableLog := dbLog.WithField("table", tableName) tableLog.Debug("caching table schema") - paginationKeyColumn, paginationKeyIndex, err := tableSchema.paginationKeyColumn(cascadingPaginationColumnConfig) + err := tableSchema.setupPaginationColumns(cascadingPaginationColumnConfig) if err != nil { logger.WithError(err).Error("invalid table") return tableSchemaCache, err } - tableSchema.PaginationKeyColumn = paginationKeyColumn - tableSchema.PaginationKeyIndex = paginationKeyIndex tableSchemaCache[tableSchema.String()] = tableSchema } @@ -257,6 +307,48 @@ func NonNumericPaginationKeyError(schema, table, paginationKey string) error { return fmt.Errorf("Pagination Key `%s` for %s is non-numeric", paginationKey, QuotedTableNameFromString(schema, table)) } +// Validates if column type is supported for pagination (bigint or varchar(255)) +func isValidPaginationColumnType(col *schema.TableColumn) bool { + return col.Type == schema.TYPE_NUMBER || + col.Type == schema.TYPE_MEDIUM_INT || + (col.Type == schema.TYPE_STRING && col.ColumnType == "varchar(255)") +} + +func (t *TableSchema) setupPaginationColumns(cascadingPaginationColumnConfig *CascadingPaginationColumnConfig) error { + // First check for composite pagination config + if compositeColumns, found := cascadingPaginationColumnConfig.CompositePaginationColumnsFor(t.Schema, t.Name); found { + t.IsCompositePagination = true + t.CompositePaginationColumns = make([]*schema.TableColumn, len(compositeColumns)) + t.CompositePaginationIndexes = make([]int, len(compositeColumns)) + + for i, colName := range compositeColumns { + col, idx, err := t.findColumnByName(colName) + if err != nil { + return err + } + if !isValidPaginationColumnType(col) { + return fmt.Errorf("Composite pagination column `%s` for %s must be BIGINT or VARCHAR(255)", colName, QuotedTableNameFromString(t.Schema, t.Name)) + } + t.CompositePaginationColumns[i] = col + t.CompositePaginationIndexes[i] = idx + } + // For backward compatibility, set the first column as the primary pagination column + t.PaginationKeyColumn = t.CompositePaginationColumns[0] + t.PaginationKeyIndex = t.CompositePaginationIndexes[0] + return nil + } + + // Fall back to single column pagination + col, idx, err := t.paginationKeyColumn(cascadingPaginationColumnConfig) + if err != nil { + return err + } + t.PaginationKeyColumn = col + t.PaginationKeyIndex = idx + t.IsCompositePagination = false + return nil +} + func (t *TableSchema) paginationKeyColumn(cascadingPaginationColumnConfig *CascadingPaginationColumnConfig) (*schema.TableColumn, int, error) { var err error var paginationKeyColumn *schema.TableColumn @@ -424,3 +516,58 @@ func maxPaginationKey(db *sql.DB, table *TableSchema) (uint64, bool, error) { return maxPaginationKey, true, nil } } + +func maxCompositePaginationKey(db *sql.DB, table *TableSchema) ([]interface{}, bool, error) { + // Build SELECT columns + columns := make([]string, len(table.CompositePaginationColumns)) + orderBy := make([]string, len(table.CompositePaginationColumns)) + + for i, col := range table.CompositePaginationColumns { + columns[i] = QuoteField(col.Name) + orderBy[i] = fmt.Sprintf("%s DESC", QuoteField(col.Name)) + } + + query, args, err := sq. + Select(columns...). + From(QuotedTableName(table)). + OrderBy(orderBy...). + Limit(1). + ToSql() + + if err != nil { + return nil, false, err + } + + // Prepare scan destinations + scanDest := make([]interface{}, len(columns)) + result := make([]interface{}, len(columns)) + + for i, col := range table.CompositePaginationColumns { + if col.Type == schema.TYPE_STRING { + var s string + scanDest[i] = &s + } else { + var n uint64 + scanDest[i] = &n + } + } + + err = db.QueryRow(query, args...).Scan(scanDest...) + + switch { + case err == sqlorig.ErrNoRows: + return nil, false, nil + case err != nil: + return nil, false, err + default: + // Copy scanned values to result + for i, col := range table.CompositePaginationColumns { + if col.Type == schema.TYPE_STRING { + result[i] = *(scanDest[i].(*string)) + } else { + result[i] = *(scanDest[i].(*uint64)) + } + } + return result, true, nil + } +}