diff --git a/config.go b/config.go index 9958ec9..d6ce9b8 100644 --- a/config.go +++ b/config.go @@ -54,8 +54,8 @@ type Config struct { Realm string `yaml:"realm" json:"realm" mapstructure:"realm"` // these security configurations will be ignored if the protocol does not contain +s - UseSystemCertPool bool `yaml:"use_system_cert_pool" mapstructure:"use_system_cert_pool"` - CAFileLocation string `yaml:"ca_file_location" mapstructure:"ca_file_location"` + UseSystemCertPool bool `yaml:"use_system_cert_pool" mapstructure:"use_system_cert_pool"` + CAFileLocation string `yaml:"ca_file_location" mapstructure:"ca_file_location"` // Index Strategy defines the index strategy for GoGM IndexStrategy IndexStrategy `yaml:"index_strategy" json:"index_strategy" mapstructure:"index_strategy"` diff --git a/gogm.go b/gogm.go index 1174ed3..4b88003 100644 --- a/gogm.go +++ b/gogm.go @@ -20,6 +20,7 @@ package gogm import ( + "context" "crypto/x509" "errors" "fmt" @@ -58,6 +59,10 @@ type Gogm struct { } func New(config *Config, pkStrategy *PrimaryKeyStrategy, mapTypes ...interface{}) (*Gogm, error) { + return NewContext(context.Background(), config, pkStrategy, mapTypes...) +} + +func NewContext(ctx context.Context, config *Config, pkStrategy *PrimaryKeyStrategy, mapTypes ...interface{}) (*Gogm, error) { if config == nil { return nil, errors.New("config can not be nil") } @@ -81,7 +86,7 @@ func New(config *Config, pkStrategy *PrimaryKeyStrategy, mapTypes ...interface{} pkStrategy: pkStrategy, } - err := g.init() + err := g.init(ctx) if err != nil { return nil, fmt.Errorf("failed to init gogm instance, %w", err) } @@ -89,7 +94,7 @@ func New(config *Config, pkStrategy *PrimaryKeyStrategy, mapTypes ...interface{} return g, nil } -func (g *Gogm) init() error { +func (g *Gogm) init(ctx context.Context) error { err := g.validate() if err != nil { return err @@ -101,13 +106,14 @@ func (g *Gogm) init() error { } g.logger.Debug("establishing neo connection") - err = g.initDriver() + + err = g.initDriver(ctx) if err != nil { return err } g.logger.Debug("initializing indices") - return g.initIndex() + return g.initIndex(ctx) } func (g *Gogm) validate() error { @@ -159,7 +165,7 @@ func (g *Gogm) parseOgmTypes() error { return nil } -func (g *Gogm) initDriver() error { +func (g *Gogm) initDriver(ctx context.Context) error { var certPool *x509.CertPool isEncrypted := strings.Contains(g.config.Protocol, "+s") @@ -184,7 +190,6 @@ func (g *Gogm) initDriver() error { } } - neoConfig := func(neoConf *neo4j.Config) { if g.config.EnableDriverLogs { neoConf.Log = wrapLogger(g.logger) @@ -197,14 +202,45 @@ func (g *Gogm) initDriver() error { } } + doneChan := make(chan error, 1) + + _, hasDeadline := ctx.Deadline() + + go g.initDriverRoutine(neoConfig, doneChan) + + if hasDeadline { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + select { + case err := <-doneChan: + if err != nil { + return fmt.Errorf("failed to init driver, %w", err) + } + return nil + case <-ctx.Done(): + return errors.New("timed out initializing driver") + } + } else { + err := <-doneChan + if err != nil { + return fmt.Errorf("failed to init driver, %w", err) + } + return nil + } +} + +func (g *Gogm) initDriverRoutine(neoConfig func(neoConf *neo4j.Config), doneChan chan error) { driver, err := neo4j.NewDriver(g.config.ConnectionString(), neo4j.BasicAuth(g.config.Username, g.config.Password, g.config.Realm), neoConfig) if err != nil { - return fmt.Errorf("failed to create driver, %w", err) + doneChan <- fmt.Errorf("failed to create driver, %w", err) + return } err = driver.VerifyConnectivity() if err != nil { - return fmt.Errorf("failed to verify connectivity, %w", err) + doneChan <- fmt.Errorf("failed to verify connectivity, %w", err) + return } // set driver @@ -218,43 +254,46 @@ func (g *Gogm) initDriver() error { res, err := sess.Run("return 1", nil) if err != nil { - return err + doneChan <- err + return } else if err = res.Err(); err != nil { - return err + doneChan <- err + return } sum, err := res.Consume() if err != nil { - return err + doneChan <- err + return } version := strings.Split(strings.Replace(strings.ToLower(sum.Server().Version()), "neo4j/", "", -1), ".") g.neoVersion, err = strconv.ParseFloat(version[0], 64) if err != nil { - return err + doneChan <- err + return } - - return nil + doneChan <- nil } -func (g *Gogm) initIndex() error { +func (g *Gogm) initIndex(ctx context.Context) error { switch g.config.IndexStrategy { case ASSERT_INDEX: g.logger.Debug("chose ASSERT_INDEX strategy") g.logger.Debug("dropping all known indexes") - err := dropAllIndexesAndConstraints(g) + err := dropAllIndexesAndConstraints(ctx, g) if err != nil { return err } g.logger.Debug("creating all mapped indexes") - err = createAllIndexesAndConstraints(g, g.mappedTypes) + err = createAllIndexesAndConstraints(ctx, g, g.mappedTypes) if err != nil { return err } g.logger.Debug("verifying all indexes") - err = verifyAllIndexesAndConstraints(g, g.mappedTypes) + err = verifyAllIndexesAndConstraints(ctx, g, g.mappedTypes) if err != nil { return err } @@ -262,7 +301,7 @@ func (g *Gogm) initIndex() error { case VALIDATE_INDEX: g.logger.Debug("chose VALIDATE_INDEX strategy") g.logger.Debug("verifying all indexes") - err := verifyAllIndexesAndConstraints(g, g.mappedTypes) + err := verifyAllIndexesAndConstraints(ctx, g, g.mappedTypes) if err != nil { return err } diff --git a/index.go b/index.go index 88506a9..acc23dd 100644 --- a/index.go +++ b/index.go @@ -20,32 +20,33 @@ package gogm import ( + "context" "github.com/cornelk/hashmap" ) //drops all known indexes -func dropAllIndexesAndConstraints(gogm *Gogm) error { +func dropAllIndexesAndConstraints(ctx context.Context, gogm *Gogm) error { if gogm.neoVersion >= 4 { - return dropAllIndexesAndConstraintsV4(gogm) + return dropAllIndexesAndConstraintsV4(ctx, gogm) } - return dropAllIndexesAndConstraintsV3(gogm) + return dropAllIndexesAndConstraintsV3(ctx, gogm) } //creates all indexes -func createAllIndexesAndConstraints(gogm *Gogm, mappedTypes *hashmap.HashMap) error { +func createAllIndexesAndConstraints(ctx context.Context, gogm *Gogm, mappedTypes *hashmap.HashMap) error { if gogm.neoVersion >= 4 { - return createAllIndexesAndConstraintsV4(gogm, mappedTypes) + return createAllIndexesAndConstraintsV4(ctx, gogm, mappedTypes) } - return createAllIndexesAndConstraintsV3(gogm, mappedTypes) + return createAllIndexesAndConstraintsV3(ctx, gogm, mappedTypes) } //verifies all indexes -func verifyAllIndexesAndConstraints(gogm *Gogm, mappedTypes *hashmap.HashMap) error { +func verifyAllIndexesAndConstraints(ctx context.Context, gogm *Gogm, mappedTypes *hashmap.HashMap) error { if gogm.neoVersion >= 4 { - return verifyAllIndexesAndConstraintsV4(gogm, mappedTypes) + return verifyAllIndexesAndConstraintsV4(ctx, gogm, mappedTypes) } - return verifyAllIndexesAndConstraintsV3(gogm, mappedTypes) + return verifyAllIndexesAndConstraintsV3(ctx, gogm, mappedTypes) } diff --git a/index_v3.go b/index_v3.go index f9300d4..b3012f0 100644 --- a/index_v3.go +++ b/index_v3.go @@ -55,7 +55,7 @@ func resultToStringArrV3(result [][]interface{}) ([]string, error) { } //drops all known indexes -func dropAllIndexesAndConstraintsV3(gogm *Gogm) error { +func dropAllIndexesAndConstraintsV3(ctx context.Context, gogm *Gogm) error { sess, err := gogm.NewSessionV2(SessionConfig{ AccessMode: neo4j.AccessModeWrite, }) @@ -64,8 +64,6 @@ func dropAllIndexesAndConstraintsV3(gogm *Gogm) error { } defer sess.Close() - ctx := context.Background() - return sess.ManagedTransaction(ctx, func(tx TransactionV2) error { vals, _, err := tx.QueryRaw(ctx, "CALL db.constraints", nil) if err != nil { @@ -124,7 +122,7 @@ func dropAllIndexesAndConstraintsV3(gogm *Gogm) error { } //creates all indexes -func createAllIndexesAndConstraintsV3(gogm *Gogm, mappedTypes *hashmap.HashMap) error { +func createAllIndexesAndConstraintsV3(ctx context.Context, gogm *Gogm, mappedTypes *hashmap.HashMap) error { sess, err := gogm.NewSessionV2(SessionConfig{ AccessMode: neo4j.AccessModeWrite, }) @@ -133,8 +131,6 @@ func createAllIndexesAndConstraintsV3(gogm *Gogm, mappedTypes *hashmap.HashMap) } defer sess.Close() - ctx := context.Background() - //validate that we have to do anything if mappedTypes == nil || mappedTypes.Len() == 0 { return errors.New("must have types to map") @@ -203,7 +199,7 @@ func createAllIndexesAndConstraintsV3(gogm *Gogm, mappedTypes *hashmap.HashMap) } //verifies all indexes -func verifyAllIndexesAndConstraintsV3(gogm *Gogm, mappedTypes *hashmap.HashMap) error { +func verifyAllIndexesAndConstraintsV3(ctx context.Context, gogm *Gogm, mappedTypes *hashmap.HashMap) error { sess, err := gogm.NewSessionV2(SessionConfig{ AccessMode: neo4j.AccessModeRead, }) @@ -212,8 +208,6 @@ func verifyAllIndexesAndConstraintsV3(gogm *Gogm, mappedTypes *hashmap.HashMap) } defer sess.Close() - ctx := context.Background() - //validate that we have to do anything if mappedTypes == nil || mappedTypes.Len() == 0 { return errors.New("must have types to map") diff --git a/index_v4.go b/index_v4.go index 10aa23a..0e8ac02 100644 --- a/index_v4.go +++ b/index_v4.go @@ -62,7 +62,7 @@ func resultToStringArrV4(isConstraint bool, result [][]interface{}) ([]string, e } //drops all known indexes -func dropAllIndexesAndConstraintsV4(gogm *Gogm) error { +func dropAllIndexesAndConstraintsV4(ctx context.Context, gogm *Gogm) error { for _, db := range gogm.config.TargetDbs { sess, err := gogm.NewSessionV2(SessionConfig{ AccessMode: neo4j.AccessModeWrite, @@ -71,9 +71,6 @@ func dropAllIndexesAndConstraintsV4(gogm *Gogm) error { if err != nil { return err } - defer sess.Close() - - ctx := context.Background() err = sess.ManagedTransaction(ctx, func(tx TransactionV2) error { res, _, err := tx.QueryRaw(ctx, "CALL db.constraints()", nil) @@ -128,14 +125,23 @@ func dropAllIndexesAndConstraintsV4(gogm *Gogm) error { return nil }) if err != nil { + _err := sess.Close() + if err != nil { + err = fmt.Errorf("%s: %w", err, _err) + } return fmt.Errorf("drop index transaction failed, %w", err) } + + err = sess.Close() + if err != nil { + return err + } } return nil } //creates all indexes -func createAllIndexesAndConstraintsV4(gogm *Gogm, mappedTypes *hashmap.HashMap) error { +func createAllIndexesAndConstraintsV4(ctx context.Context, gogm *Gogm, mappedTypes *hashmap.HashMap) error { for _, db := range gogm.config.TargetDbs { sess, err := gogm.NewSessionV2(SessionConfig{ AccessMode: neo4j.AccessModeWrite, @@ -144,9 +150,6 @@ func createAllIndexesAndConstraintsV4(gogm *Gogm, mappedTypes *hashmap.HashMap) if err != nil { return err } - defer sess.Close() - - ctx := context.Background() //validate that we have to do anything if mappedTypes == nil || mappedTypes.Len() == 0 { @@ -211,12 +214,23 @@ func createAllIndexesAndConstraintsV4(gogm *Gogm, mappedTypes *hashmap.HashMap) gogm.logger.Debugf("created (%v) indexes", numIndexCreated) return nil }) + if err != nil { + _err := sess.Close() + if err != nil { + err = fmt.Errorf("%s: %w", err, _err) + } + return err + } + err = sess.Close() + if err != nil { + return err + } } return nil } //verifies all indexes -func verifyAllIndexesAndConstraintsV4(gogm *Gogm, mappedTypes *hashmap.HashMap) error { +func verifyAllIndexesAndConstraintsV4(ctx context.Context, gogm *Gogm, mappedTypes *hashmap.HashMap) error { for _, db := range gogm.config.TargetDbs { sess, err := gogm.NewSessionV2(SessionConfig{ AccessMode: neo4j.AccessModeWrite, @@ -225,9 +239,6 @@ func verifyAllIndexesAndConstraintsV4(gogm *Gogm, mappedTypes *hashmap.HashMap) if err != nil { return err } - defer sess.Close() - - ctx := context.Background() //validate that we have to do anything if mappedTypes == nil || mappedTypes.Len() == 0 { @@ -275,27 +286,47 @@ func verifyAllIndexesAndConstraintsV4(gogm *Gogm, mappedTypes *hashmap.HashMap) //get whats there now foundResult, _, err := sess.QueryRaw(ctx, "CALL db.constraints", nil) if err != nil { + _err := sess.Close() + if err != nil { + err = fmt.Errorf("%s: %w", err, _err) + } return fmt.Errorf("no constraints found, %w", err) } foundConstraints, err := resultToStringArrV4(true, foundResult) if err != nil { + _err := sess.Close() + if err != nil { + err = fmt.Errorf("%s: %w", err, _err) + } return fmt.Errorf("failed to convert result to string array, %w", err) } foundInxdexResult, _, err := sess.QueryRaw(ctx, "CALL db.indexes()", nil) if err != nil { + _err := sess.Close() + if err != nil { + err = fmt.Errorf("%s: %w", err, _err) + } return fmt.Errorf("no indices found, %w", err) } foundIndexes, err := resultToStringArrV4(false, foundInxdexResult) if err != nil { + _err := sess.Close() + if err != nil { + err = fmt.Errorf("%s: %w", err, _err) + } return fmt.Errorf("failed to convert result to array, %w", err) } //verify from there delta, found := arrayOperations.Difference(foundIndexes, indexes) if !found { + _err := sess.Close() + if err != nil { + err = fmt.Errorf("%s: %w", err, _err) + } return fmt.Errorf("found differences in remote vs ogm for found indexes, %v", delta) } @@ -309,10 +340,18 @@ func verifyAllIndexesAndConstraintsV4(gogm *Gogm, mappedTypes *hashmap.HashMap) delta, found = arrayOperations.Difference(founds, constraints) if !found { + _err := sess.Close() + if err != nil { + err = fmt.Errorf("%s: %w", err, _err) + } return fmt.Errorf("found differences in remote vs ogm for found constraints, %v", delta) } gogm.logger.Debugf("%+v", delta) + err = sess.Close() + if err != nil { + return err + } } return nil