From 5ac9c0e3b226e0e4783492f39f47c396c2082251 Mon Sep 17 00:00:00 2001 From: david may <1301201+wass3r@users.noreply.github.com> Date: Wed, 5 Jun 2024 13:06:26 -0500 Subject: [PATCH] feat(db/log): use logrus as logger for db (#1135) * use logrus for gorm * use witherror * add knobs for db logging * false is default * show normal set of logrus levels * extend test struct with new fields * update test function with new fields --- database/context.go | 4 + database/database.go | 61 ++++++++++++- database/database_test.go | 28 ++++++ database/flags.go | 27 ++++++ database/logger.go | 79 +++++++++++++++++ database/logger_test.go | 51 +++++++++++ database/opts.go | 40 +++++++++ database/opts_test.go | 176 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 462 insertions(+), 4 deletions(-) create mode 100644 database/logger.go create mode 100644 database/logger_test.go diff --git a/database/context.go b/database/context.go index fd4cae217..a11863e7d 100644 --- a/database/context.go +++ b/database/context.go @@ -49,6 +49,10 @@ func FromCLIContext(c *cli.Context) (Interface, error) { WithConnectionOpen(c.Int("database.connection.open")), WithDriver(c.String("database.driver")), WithEncryptionKey(c.String("database.encryption.key")), + WithLogLevel(c.String("database.log.level")), + WithLogSkipNotFound(c.Bool("database.log.skip_notfound")), + WithLogSlowThreshold(c.Duration("database.log.slow_threshold")), + WithLogShowSQL(c.Bool("database.log.show_sql")), WithSkipCreation(c.Bool("database.skip_creation")), ) } diff --git a/database/database.go b/database/database.go index 9ff3a2855..f9e2658ed 100644 --- a/database/database.go +++ b/database/database.go @@ -48,6 +48,14 @@ type ( Driver string // specifies the encryption key to use for the database engine EncryptionKey string + // specifies the database engine specific log level + LogLevel string + // specifies to skip logging when a record is not found + LogSkipNotFound bool + // specifies the threshold for slow queries in the database engine + LogSlowThreshold time.Duration + // specifies whether to log SQL queries in the database engine + LogShowSQL bool // specifies to skip creating tables and indexes for the database engine SkipCreation bool } @@ -113,21 +121,62 @@ func New(opts ...EngineOpt) (Interface, error) { return nil, err } - // update the logger with additional metadata + // by default use the global logger with additional metadata e.logger = logrus.NewEntry(logrus.StandardLogger()).WithField("database", e.Driver()) + // translate the log level to logrus level for the database engine + var dbLogLevel logrus.Level + + switch e.config.LogLevel { + case "t", "trace", "Trace", "TRACE": + dbLogLevel = logrus.TraceLevel + case "d", "debug", "Debug", "DEBUG": + dbLogLevel = logrus.DebugLevel + case "i", "info", "Info", "INFO": + dbLogLevel = logrus.InfoLevel + case "w", "warn", "Warn", "WARN": + dbLogLevel = logrus.WarnLevel + case "e", "error", "Error", "ERROR": + dbLogLevel = logrus.ErrorLevel + case "f", "fatal", "Fatal", "FATAL": + dbLogLevel = logrus.FatalLevel + case "p", "panic", "Panic", "PANIC": + dbLogLevel = logrus.PanicLevel + } + + // if the log level for the database engine is different than + // the global log level, create a new logrus instance + if dbLogLevel != logrus.GetLevel() { + log := logrus.New() + + // set the custom log level + log.Level = dbLogLevel + + // copy the formatter from the global logger to + // retain the same format for the database engine + log.Formatter = logrus.StandardLogger().Formatter + + // update the logger with additional metadata + e.logger = logrus.NewEntry(log).WithField("database", e.Driver()) + } + e.logger.Trace("creating database engine from configuration") - // process the database driver being provided + + // configure gorm to use logrus as internal logger + gormConfig := &gorm.Config{ + Logger: NewGormLogger(e.logger, e.config.LogSlowThreshold, e.config.LogSkipNotFound, e.config.LogShowSQL), + } + switch e.config.Driver { case constants.DriverPostgres: // create the new Postgres database client - e.client, err = gorm.Open(postgres.Open(e.config.Address), &gorm.Config{}) + e.client, err = gorm.Open(postgres.Open(e.config.Address), gormConfig) if err != nil { return nil, err } case constants.DriverSqlite: // create the new Sqlite database client - e.client, err = gorm.Open(sqlite.Open(e.config.Address), &gorm.Config{}) + e.client, err = gorm.Open(sqlite.Open(e.config.Address), gormConfig) if err != nil { return nil, err } @@ -177,5 +226,9 @@ func NewTest() (Interface, error) { WithDriver("sqlite3"), WithEncryptionKey("A1B2C3D4E5G6H7I8J9K0LMNOPQRSTUVW"), WithSkipCreation(false), + WithLogLevel("warn"), + WithLogShowSQL(false), + WithLogSkipNotFound(true), + WithLogSlowThreshold(200*time.Millisecond), ) } diff --git a/database/database_test.go b/database/database_test.go index 26ea7e471..a8b70111a 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -32,6 +32,10 @@ func TestDatabase_New(t *testing.T) { ConnectionOpen: 20, EncryptionKey: "A1B2C3D4E5G6H7I8J9K0LMNOPQRSTUVW", SkipCreation: false, + LogLevel: "info", + LogSkipNotFound: true, + LogSlowThreshold: 100 * time.Millisecond, + LogShowSQL: false, }, }, { @@ -46,6 +50,10 @@ func TestDatabase_New(t *testing.T) { ConnectionOpen: 20, EncryptionKey: "A1B2C3D4E5G6H7I8J9K0LMNOPQRSTUVW", SkipCreation: false, + LogLevel: "info", + LogSkipNotFound: true, + LogSlowThreshold: 100 * time.Millisecond, + LogShowSQL: false, }, }, { @@ -60,6 +68,10 @@ func TestDatabase_New(t *testing.T) { ConnectionOpen: 20, EncryptionKey: "A1B2C3D4E5G6H7I8J9K0LMNOPQRSTUVW", SkipCreation: false, + LogLevel: "info", + LogSkipNotFound: true, + LogSlowThreshold: 100 * time.Millisecond, + LogShowSQL: false, }, }, { @@ -74,6 +86,10 @@ func TestDatabase_New(t *testing.T) { ConnectionOpen: 20, EncryptionKey: "A1B2C3D4E5G6H7I8J9K0LMNOPQRSTUVW", SkipCreation: false, + LogLevel: "info", + LogSkipNotFound: true, + LogSlowThreshold: 100 * time.Millisecond, + LogShowSQL: false, }, }, } @@ -88,6 +104,10 @@ func TestDatabase_New(t *testing.T) { WithConnectionIdle(test.config.ConnectionIdle), WithConnectionOpen(test.config.ConnectionOpen), WithDriver(test.config.Driver), + WithLogLevel(test.config.LogLevel), + WithLogShowSQL(test.config.LogShowSQL), + WithLogSkipNotFound(test.config.LogSkipNotFound), + WithLogSlowThreshold(test.config.LogSlowThreshold), WithEncryptionKey(test.config.EncryptionKey), WithSkipCreation(test.config.SkipCreation), ) @@ -119,6 +139,10 @@ func testPostgres(t *testing.T) (*engine, sqlmock.Sqlmock) { Driver: "postgres", EncryptionKey: "A1B2C3D4E5G6H7I8J9K0LMNOPQRSTUVW", SkipCreation: false, + LogLevel: "info", + LogSkipNotFound: true, + LogSlowThreshold: 100 * time.Millisecond, + LogShowSQL: false, }, logger: logrus.NewEntry(logrus.StandardLogger()), } @@ -161,6 +185,10 @@ func testSqlite(t *testing.T) *engine { Driver: "sqlite3", EncryptionKey: "A1B2C3D4E5G6H7I8J9K0LMNOPQRSTUVW", SkipCreation: false, + LogLevel: "info", + LogSkipNotFound: true, + LogSlowThreshold: 100 * time.Millisecond, + LogShowSQL: false, }, logger: logrus.NewEntry(logrus.StandardLogger()), } diff --git a/database/flags.go b/database/flags.go index 8655c1808..129143655 100644 --- a/database/flags.go +++ b/database/flags.go @@ -60,6 +60,33 @@ var Flags = []cli.Flag{ Name: "database.encryption.key", Usage: "AES-256 key for encrypting and decrypting values in the database", }, + &cli.StringFlag{ + EnvVars: []string{"VELA_DATABASE_LOG_LEVEL", "DATABASE_LOG_LEVEL"}, + FilePath: "/vela/database/log_level", + Name: "database.log.level", + Usage: "set log level - options: (trace|debug|info|warn|error|fatal|panic)", + Value: "warn", + }, + &cli.BoolFlag{ + EnvVars: []string{"VELA_DATABASE_LOG_SHOW_SQL", "DATABASE_LOG_SHOW_SQL"}, + FilePath: "/vela/database/log_show_sql", + Name: "database.log.show_sql", + Usage: "show the SQL query in the logs", + }, + &cli.BoolFlag{ + EnvVars: []string{"VELA_DATABASE_LOG_SKIP_NOTFOUND", "DATABASE_LOG_SKIP_NOTFOUND"}, + FilePath: "/vela/database/log_skip_notfound", + Name: "database.log.skip_notfound", + Usage: "skip logging when a resource is not found in the database", + Value: true, + }, + &cli.DurationFlag{ + EnvVars: []string{"VELA_DATABASE_LOG_SLOW_THRESHOLD", "DATABASE_LOG_SLOW_THRESHOLD"}, + FilePath: "/vela/database/log_slow_threshold", + Name: "database.log.slow_threshold", + Usage: "queries that take longer than this threshold are considered slow and will be logged", + Value: 200 * time.Millisecond, + }, &cli.BoolFlag{ EnvVars: []string{"VELA_DATABASE_SKIP_CREATION", "DATABASE_SKIP_CREATION"}, FilePath: "/vela/database/skip_creation", diff --git a/database/logger.go b/database/logger.go new file mode 100644 index 000000000..819574704 --- /dev/null +++ b/database/logger.go @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: Apache-2.0 + +package database + +import ( + "context" + "errors" + "time" + + "github.com/sirupsen/logrus" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "gorm.io/gorm/utils" +) + +// GormLogger is a custom logger for Gorm. +type GormLogger struct { + slowThreshold time.Duration + skipErrRecordNotFound bool + showSQL bool + entry *logrus.Entry +} + +// NewGormLogger creates a new Gorm logger. +func NewGormLogger(logger *logrus.Entry, slowThreshold time.Duration, skipNotFound, showSQL bool) *GormLogger { + return &GormLogger{ + skipErrRecordNotFound: skipNotFound, + slowThreshold: slowThreshold, + showSQL: showSQL, + entry: logger, + } +} + +// LogMode sets the log mode for the logger. +func (l *GormLogger) LogMode(logger.LogLevel) logger.Interface { + return l +} + +// Info implements the logger.Interface. +func (l *GormLogger) Info(ctx context.Context, msg string, args ...interface{}) { + l.entry.WithContext(ctx).Info(msg, args) +} + +// Warn implements the logger.Interface. +func (l *GormLogger) Warn(ctx context.Context, msg string, args ...interface{}) { + l.entry.WithContext(ctx).Warn(msg, args) +} + +// Error implements the logger.Interface. +func (l *GormLogger) Error(ctx context.Context, msg string, args ...interface{}) { + l.entry.WithContext(ctx).Error(msg, args) +} + +// Trace implements the logger.Interface. +func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + elapsed := time.Since(begin) + sql, rows := fc() + fields := logrus.Fields{ + "rows": rows, + "elapsed": elapsed, + "source": utils.FileWithLineNum(), + } + + if l.showSQL { + fields["sql"] = sql + } + + if err != nil && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.skipErrRecordNotFound) { + l.entry.WithContext(ctx).WithError(err).WithFields(fields).Error("gorm error") + return + } + + if l.slowThreshold != 0 && elapsed > l.slowThreshold { + l.entry.WithContext(ctx).WithFields(fields).Warnf("gorm warn SLOW QUERY >= %s, took %s", l.slowThreshold, elapsed) + return + } + + l.entry.WithContext(ctx).WithFields(fields).Infof("gorm info") +} diff --git a/database/logger_test.go b/database/logger_test.go new file mode 100644 index 000000000..ce64513bf --- /dev/null +++ b/database/logger_test.go @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 + +package database + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sirupsen/logrus" +) + +func TestNewGormLogger(t *testing.T) { + logger := logrus.NewEntry(logrus.New()) + + type args struct { + logger *logrus.Entry + slowThreshold time.Duration + skipNotFound bool + showSQL bool + } + tests := []struct { + name string + args args + want *GormLogger + }{ + { + name: "logger set", + args: args{ + logger: logger, + slowThreshold: time.Second, + skipNotFound: false, + showSQL: true, + }, + want: &GormLogger{ + slowThreshold: time.Second, + skipErrRecordNotFound: false, + showSQL: true, + entry: logger, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if diff := cmp.Diff(NewGormLogger(tt.args.logger, tt.args.slowThreshold, tt.args.skipNotFound, tt.args.showSQL), tt.want, cmpopts.EquateComparable(GormLogger{})); diff != "" { + t.Errorf("NewGormLogger() mismatch (-got +want):\n%s", diff) + } + }) + } +} diff --git a/database/opts.go b/database/opts.go index 2e758d802..3a35d751e 100644 --- a/database/opts.go +++ b/database/opts.go @@ -80,6 +80,46 @@ func WithEncryptionKey(encryptionKey string) EngineOpt { } } +// WithLogLevel sets the log level in the database engine. +func WithLogLevel(logLevel string) EngineOpt { + return func(e *engine) error { + // set the log level for the database engine + e.config.LogLevel = logLevel + + return nil + } +} + +// WithLogSkipNotFound sets the log skip not found option in the database engine. +func WithLogSkipNotFound(logSkipNotFound bool) EngineOpt { + return func(e *engine) error { + // set the log skip not found option for the database engine + e.config.LogSkipNotFound = logSkipNotFound + + return nil + } +} + +// WithLogSlowThreshold sets the log slow query threshold in the database engine. +func WithLogSlowThreshold(logSlowThreshold time.Duration) EngineOpt { + return func(e *engine) error { + // set the slow query threshold for the database engine + e.config.LogSlowThreshold = logSlowThreshold + + return nil + } +} + +// WithLogShowSQL sets the log show SQL option in the database engine. +func WithLogShowSQL(logShowSQL bool) EngineOpt { + return func(e *engine) error { + // set the log show SQL option for the database engine + e.config.LogShowSQL = logShowSQL + + return nil + } +} + // WithSkipCreation sets the skip creation logic in the database engine. func WithSkipCreation(skipCreation bool) EngineOpt { return func(e *engine) error { diff --git a/database/opts_test.go b/database/opts_test.go index 7089d9271..abd65ad94 100644 --- a/database/opts_test.go +++ b/database/opts_test.go @@ -405,3 +405,179 @@ func TestDatabase_EngineOpt_WithSkipCreation(t *testing.T) { }) } } + +func TestDatabase_EngineOpt_WithLogLevel(t *testing.T) { + e := &engine{config: new(config)} + + tests := []struct { + failure bool + name string + logLevel string + want string + }{ + { + failure: false, + name: "log level set to debug", + logLevel: "debug", + want: "debug", + }, + { + failure: false, + name: "log level set to info", + logLevel: "info", + want: "info", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := WithLogLevel(test.logLevel)(e) + + if test.failure { + if err == nil { + t.Errorf("WithLogLevel for %s should have returned err", test.name) + } + + return + } + + if err != nil { + t.Errorf("WithLogLevel returned err: %v", err) + } + + if !reflect.DeepEqual(e.config.LogLevel, test.want) { + t.Errorf("WithLogLevel is %v, want %v", e.config.SkipCreation, test.want) + } + }) + } +} + +func TestDatabase_EngineOpt_WithLogSkipNotFound(t *testing.T) { + e := &engine{config: new(config)} + + tests := []struct { + failure bool + name string + skip bool + want bool + }{ + { + failure: false, + name: "log skip not found set to true", + skip: true, + want: true, + }, + { + failure: false, + name: "log skip not found set to false", + skip: false, + want: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := WithLogSkipNotFound(test.skip)(e) + + if test.failure { + if err == nil { + t.Errorf("WithLogSkipNotFound for %s should have returned err", test.name) + } + + if err != nil { + t.Errorf("WithLogSkipNotFound for %s returned err: %v", test.name, err) + } + + if !reflect.DeepEqual(e.config.LogSkipNotFound, test.want) { + t.Errorf("WithLogSkipNotFound for %s is %v, want %v", test.name, e.config.LogSkipNotFound, test.want) + } + } + }) + } +} + +func TestDatabase_EngineOpt_WithLogSlowThreshold(t *testing.T) { + e := &engine{config: new(config)} + + tests := []struct { + failure bool + name string + threshold time.Duration + want time.Duration + }{ + { + failure: false, + name: "log slow threshold set to 1ms", + threshold: 1 * time.Millisecond, + want: 1 * time.Millisecond, + }, + { + failure: false, + name: "log slow threshold set to 2ms", + threshold: 2 * time.Millisecond, + want: 2 * time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := WithLogSlowThreshold(test.threshold)(e) + + if test.failure { + if err == nil { + t.Errorf("WithLogSlowThreshold for %s should have returned err", test.name) + } + + if err != nil { + t.Errorf("WithLogSlowThreshold for %s returned err: %v", test.name, err) + } + + if !reflect.DeepEqual(e.config.LogSlowThreshold, test.want) { + t.Errorf("WithLogSlowThreshold for %s is %v, want %v", test.name, e.config.LogSlowThreshold, test.want) + } + } + }) + } +} + +func TestDatabase_EngineOpt_WithLogShowSQL(t *testing.T) { + e := &engine{config: new(config)} + + tests := []struct { + failure bool + name string + show bool + want bool + }{ + { + failure: false, + name: "log show SQL set to true", + show: true, + want: true, + }, + { + failure: false, + name: "log show SQL set to false", + show: false, + want: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := WithLogShowSQL(test.show)(e) + + if test.failure { + if err == nil { + t.Errorf("WithLogShowSQL for %s should have returned err", test.name) + } + + if err != nil { + t.Errorf("WithLogShowSQL for %s returned err: %v", test.name, err) + } + + if !reflect.DeepEqual(e.config.LogShowSQL, test.want) { + t.Errorf("WithLogShowSQL for %s is %v, want %v", test.name, e.config.LogShowSQL, test.want) + } + } + }) + } +}