diff --git a/config.go b/config.go index 83c6fa6..e741344 100644 --- a/config.go +++ b/config.go @@ -20,6 +20,7 @@ package gogm import ( + "crypto/tls" "errors" "fmt" "time" @@ -53,9 +54,13 @@ type Config struct { Realm string `yaml:"realm" json:"realm" mapstructure:"realm"` - // these security configurations will be ignored if the protocol does not contain +s - UseSystemCertPool bool `yaml:"use_system_cert_pool" mapstructure:"use_system_cert_pool"` - CAFileLocation string `yaml:"ca_file_location" mapstructure:"ca_file_location"` + // deprecated: in favor of tls config + //these security configurations will be ignored if the protocol does not contain +s + UseSystemCertPool bool `yaml:"use_system_cert_pool" mapstructure:"use_system_cert_pool"` + // deprecated: in favor of tls config + CAFileLocation string `yaml:"ca_file_location" mapstructure:"ca_file_location"` + + TLSConfig *tls.Config `yaml:"tls_config" mapstructure:"tls_config"` // Index Strategy defines the index strategy for GoGM IndexStrategy IndexStrategy `yaml:"index_strategy" json:"index_strategy" mapstructure:"index_strategy"` diff --git a/go.mod b/go.mod index 3fab09c..09269be 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/mindstand/gogm/v2 -go 1.13 +go 1.16 require ( github.com/adam-hanna/arrayOperations v0.2.6 @@ -11,7 +11,7 @@ require ( github.com/gorilla/mux v1.8.0 github.com/kr/pretty v0.1.0 // indirect github.com/mindstand/go-cypherdsl v0.2.0 - github.com/neo4j/neo4j-go-driver/v4 v4.3.3 + github.com/neo4j/neo4j-go-driver/v4 v4.4.2-0.20220317151800-1a19fb114732 github.com/opentracing/opentracing-go v1.2.0 github.com/sirupsen/logrus v1.6.0 // indirect github.com/stretchr/objx v0.2.0 // indirect diff --git a/go.sum b/go.sum index 8931dbe..f2986fb 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,7 @@ github.com/dchest/siphash v1.2.1 h1:4cLinnzVJDKxTCl9B01807Yiy+W7ZzVHj/KIroQRvT4= github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= @@ -44,16 +45,26 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mindstand/go-cypherdsl v0.2.0 h1:/B6A8DhWk2RksdJxruy3+ii3Hvrr5JU+2vL3/oJMLrI= github.com/mindstand/go-cypherdsl v0.2.0/go.mod h1:swzbrSTuq3CRgFglg3aVThG9GBQmHXz6AY81q9mRMto= +github.com/mindstand/neo4j-go-driver/v4 v4.3.4-0.20220222141004-33255dc85550 h1:K5JAVs3c5c2n1+jzLejM1vNd8qD8wgLvIyE4WD5tcu8= +github.com/mindstand/neo4j-go-driver/v4 v4.3.4-0.20220222141004-33255dc85550/go.mod h1:HQcN7yHl5H0hJKaYuF3TBH9m/Ym0C0ua4NHzKUr9uSc= +github.com/mindstand/neo4j-go-driver/v4 v4.4.0-beta1.0.20220222145400-d3c6fddab3cc h1:PY2MtEh34XJ/r4rHgxqnsby9lpmijQLWVcC7h0s+pyU= +github.com/mindstand/neo4j-go-driver/v4 v4.4.0-beta1.0.20220222145400-d3c6fddab3cc/go.mod h1:NexOfrm4c317FVjekrhVV8pHBXgtMG5P6GeweJWCyo4= +github.com/mindstand/neo4j-go-driver/v4 v4.4.0-beta1.0.20220314193855-8ea056066b14 h1:+C3FJ2je/N1/ejuLCg6W2/NM5V5D4T6NC6Pf2bAe8t4= +github.com/mindstand/neo4j-go-driver/v4 v4.4.0-beta1.0.20220314193855-8ea056066b14/go.mod h1:NexOfrm4c317FVjekrhVV8pHBXgtMG5P6GeweJWCyo4= github.com/neo4j/neo4j-go-driver/v4 v4.3.3 h1:QwM0IN1L6q1+N9cNqjv9Pmj4J4qCVauczQZdFsDafv8= github.com/neo4j/neo4j-go-driver/v4 v4.3.3/go.mod h1:G+DuMWSR9Auvbm6tk+fHNIegnfswAsmXgP/ibvwOY2Q= +github.com/neo4j/neo4j-go-driver/v4 v4.4.2-0.20220317151800-1a19fb114732 h1:nHwy/xxNFLe9MrQy0ub1JCyrh2p/bAW79uFDX3exjHQ= +github.com/neo4j/neo4j-go-driver/v4 v4.4.2-0.20220317151800-1a19fb114732/go.mod h1:NexOfrm4c317FVjekrhVV8pHBXgtMG5P6GeweJWCyo4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.14.0/go.mod h1:cIuvLEne0aoVhAgh/O6ac0Op8WWw9H6eYCriF+tEHG0= +github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -104,6 +115,8 @@ golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 h1:TyHqChC80pFkXWraUUf6RuB5IqFdQieMLwwCJokV2pc= +golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/gogm.go b/gogm.go index f9273d6..7c1f8f4 100644 --- a/gogm.go +++ b/gogm.go @@ -21,6 +21,7 @@ package gogm import ( "context" + "crypto/tls" "crypto/x509" "errors" "fmt" @@ -97,23 +98,28 @@ func NewContext(ctx context.Context, config *Config, pkStrategy *PrimaryKeyStrat func (g *Gogm) init(ctx context.Context) error { err := g.validate() if err != nil { - return err + return fmt.Errorf("failed to validate config, %w", err) } err = g.parseOgmTypes() if err != nil { - return err + return fmt.Errorf("failed to parse ogm types, %w", err) } g.logger.Debug("establishing neo connection") err = g.initDriver(ctx) if err != nil { - return err + return fmt.Errorf("failed to initialize driver, %w", err) } g.logger.Debug("initializing indices") - return g.initIndex(ctx) + err = g.initIndex(ctx) + if err != nil { + return fmt.Errorf("failed to init indices, %w", err) + } + + return nil } func (g *Gogm) validate() error { @@ -147,7 +153,7 @@ func (g *Gogm) parseOgmTypes() error { name := reflect.TypeOf(t).Elem().Name() dc, err := getStructDecoratorConfig(g, t, g.mappedRelations) if err != nil { - return err + return fmt.Errorf("failed to get structDecoratorConfig for %s, %w", name, err) } g.logger.Debugf("mapped type %s", name) @@ -166,27 +172,39 @@ func (g *Gogm) parseOgmTypes() error { } func (g *Gogm) initDriver(ctx context.Context) error { - var certPool *x509.CertPool isEncrypted := strings.Contains(g.config.Protocol, "+s") if isEncrypted { - if g.config.UseSystemCertPool { - var err error - certPool, err = x509.SystemCertPool() - if err != nil { - return fmt.Errorf("failed to get system cert pool") - } - } else { - certPool = x509.NewCertPool() + if g.config.TLSConfig == nil { + g.config.TLSConfig = &tls.Config{} } + // handle deprecated config support if g.config.CAFileLocation != "" { + g.logger.Debugf("loading ca file at location `%s`", g.config.CAFileLocation) bytes, err := ioutil.ReadFile(g.config.CAFileLocation) if err != nil { return fmt.Errorf("failed to open ca file, %w", err) } + g.logger.Debugf("successfully loaded ca file") + + var certPool *x509.CertPool + if g.config.UseSystemCertPool { + g.logger.Debug("loading system cert pool") + var err error + certPool, err = x509.SystemCertPool() + if err != nil { + return fmt.Errorf("failed to get system cert pool") + } + g.logger.Debug("successfully loaded system cert pool") + } else { + certPool = x509.NewCertPool() + } - certPool.AppendCertsFromPEM(bytes) + if !certPool.AppendCertsFromPEM(bytes) { + return errors.New("failed to load CA into cert pool") + } + g.config.TLSConfig.RootCAs = certPool } } @@ -198,7 +216,9 @@ func (g *Gogm) initDriver(ctx context.Context) error { neoConf.MaxConnectionPoolSize = g.config.PoolSize if isEncrypted { - neoConf.RootCAs = certPool + if g.config.TLSConfig.RootCAs != nil { + neoConf.RootCAs = g.config.TLSConfig.RootCAs + } } } @@ -231,7 +251,9 @@ func (g *Gogm) initDriver(ctx context.Context) error { } func (g *Gogm) initDriverRoutine(neoConfig func(neoConf *neo4j.Config), doneChan chan error) { - driver, err := neo4j.NewDriver(g.config.ConnectionString(), neo4j.BasicAuth(g.config.Username, g.config.Password, g.config.Realm), neoConfig) + connStr := g.config.ConnectionString() + g.logger.Debugf("connection string: %s\n", connStr) + driver, err := neo4j.NewDriver(connStr, neo4j.BasicAuth(g.config.Username, g.config.Password, g.config.Realm), neoConfig) if err != nil { doneChan <- fmt.Errorf("failed to create driver, %w", err) return @@ -249,21 +271,20 @@ func (g *Gogm) initDriverRoutine(neoConfig func(neoConf *neo4j.Config), doneChan // get neoversion sess := driver.NewSession(neo4j.SessionConfig{ AccessMode: neo4j.AccessModeRead, - // DatabaseName: "neo4j", }) res, err := sess.Run("return 1", nil) if err != nil { - doneChan <- err + doneChan <- fmt.Errorf("failed to run test query, %w", err) return } else if err = res.Err(); err != nil { - doneChan <- err + doneChan <- fmt.Errorf("failed to run test query, %w", err) return } sum, err := res.Consume() if err != nil { - doneChan <- err + doneChan <- fmt.Errorf("failed to consume test query, %w", err) return } @@ -278,19 +299,19 @@ func (g *Gogm) initIndex(ctx context.Context) error { g.logger.Debug("dropping all known indexes") err := dropAllIndexesAndConstraints(ctx, g) if err != nil { - return err + return fmt.Errorf("failed to drop all known indexes, %w", err) } g.logger.Debug("creating all mapped indexes") err = createAllIndexesAndConstraints(ctx, g, g.mappedTypes) if err != nil { - return err + return fmt.Errorf("failed t create all indexes and constraints, %w", err) } g.logger.Debug("verifying all indexes") err = verifyAllIndexesAndConstraints(ctx, g, g.mappedTypes) if err != nil { - return err + return fmt.Errorf("failed to verify all indexes and contraints, %w", err) } return nil case VALIDATE_INDEX: @@ -298,7 +319,7 @@ func (g *Gogm) initIndex(ctx context.Context) error { g.logger.Debug("verifying all indexes") err := verifyAllIndexesAndConstraints(ctx, g, g.mappedTypes) if err != nil { - return err + return fmt.Errorf("failed to verify all indexes and contraints, %w", err) } return nil case IGNORE_INDEX: diff --git a/index.go b/index.go index e917bca..77b4e86 100644 --- a/index.go +++ b/index.go @@ -21,32 +21,62 @@ package gogm import ( "context" + "fmt" "github.com/cornelk/hashmap" ) //drops all known indexes func dropAllIndexesAndConstraints(ctx context.Context, gogm *Gogm) error { if gogm.boltMajorVersion >= 4 { - return dropAllIndexesAndConstraintsV4(ctx, gogm) + for _, db := range gogm.config.TargetDbs { + err := dropAllIndexesAndConstraintsV4(ctx, gogm, db) + if err != nil { + return fmt.Errorf("failed to drop indexes and constraints for db %s on db version 4+, %w", db, err) + } + } + } else { + err := dropAllIndexesAndConstraintsV3(ctx, gogm) + if err != nil { + return fmt.Errorf("failed to drop indexes and constraints on db version 3, %w", err) + } } - return dropAllIndexesAndConstraintsV3(ctx, gogm) + return nil } //creates all indexes func createAllIndexesAndConstraints(ctx context.Context, gogm *Gogm, mappedTypes *hashmap.HashMap) error { if gogm.boltMajorVersion >= 4 { - return createAllIndexesAndConstraintsV4(ctx, gogm, mappedTypes) + for _, db := range gogm.config.TargetDbs { + err := createAllIndexesAndConstraintsV4(ctx, gogm, mappedTypes, db) + if err != nil { + return fmt.Errorf("failed to create indexes and constraints for db %s on db version 4+, %w", db, err) + } + } + } else { + err := createAllIndexesAndConstraintsV3(ctx, gogm, mappedTypes) + if err != nil { + return fmt.Errorf("failed to create indexes and constraints on db version 3, %w", err) + } } - - return createAllIndexesAndConstraintsV3(ctx, gogm, mappedTypes) + return nil } //verifies all indexes func verifyAllIndexesAndConstraints(ctx context.Context, gogm *Gogm, mappedTypes *hashmap.HashMap) error { if gogm.boltMajorVersion >= 4 { - return verifyAllIndexesAndConstraintsV4(ctx, gogm, mappedTypes) + for _, db := range gogm.config.TargetDbs { + err := verifyAllIndexesAndConstraintsV4(ctx, gogm, mappedTypes, db) + if err != nil { + return fmt.Errorf("failed to verify indexes and constraints for db %s on db version 4+, %w", db, err) + } + } + } else { + err := verifyAllIndexesAndConstraintsV3(ctx, gogm, mappedTypes) + if err != nil { + return fmt.Errorf("failed to verify all indexes and contraints on db version 3, %w", err) + } } - return verifyAllIndexesAndConstraintsV3(ctx, gogm, mappedTypes) + return nil } diff --git a/index_v4.go b/index_v4.go index 289249c..d066345 100644 --- a/index_v4.go +++ b/index_v4.go @@ -93,278 +93,232 @@ func resultToStringArrV4(isConstraint bool, result [][]interface{}) ([]string, e } //drops all known indexes -func dropAllIndexesAndConstraintsV4(ctx context.Context, gogm *Gogm) error { - for _, db := range gogm.config.TargetDbs { - sess, err := gogm.NewSessionV2(SessionConfig{ - AccessMode: neo4j.AccessModeWrite, - DatabaseName: db, - }) - if err != nil { - return err - } - - err = sess.ManagedTransaction(ctx, func(tx TransactionV2) error { - res, _, err := tx.QueryRaw(ctx, "CALL db.constraints()", nil) - if err != nil { - return err - } - - if len(res) == 0 { - // no constraints to kill off, return from here - return nil - } - - constraints, err := resultToStringArrV4(true, res) - if err != nil { - return err - } +func dropAllIndexesAndConstraintsV4(ctx context.Context, gogm *Gogm, db string) error { + sess, err := gogm.NewSessionV2(SessionConfig{ + AccessMode: neo4j.AccessModeWrite, + DatabaseName: db, + }) + if err != nil { + return fmt.Errorf("failed to open session to db %s, %w", db, err) + } - //if there is anything, get rid of it - if len(constraints) != 0 { - for _, constraint := range constraints { - gogm.logger.Debugf("dropping constraint '%s'", constraint) - _, _, err := tx.QueryRaw(ctx, fmt.Sprintf("DROP CONSTRAINT %s IF EXISTS", constraint), nil) - if err != nil { - return err - } - } - } + defer sess.Close() - res, _, err = tx.QueryRaw(ctx, "CALL db.indexes()", nil) - if err != nil { - return tx.RollbackWithError(ctx, err) - } + res, _, err := sess.QueryRaw(ctx, "CALL db.constraints()", nil) + if err != nil { + return fmt.Errorf("failed to query current constraints, %w", err) + } - indexes, err := resultToStringArrV4(false, res) - if err != nil { - return err - } + if len(res) == 0 { + // no constraints to kill off, return from here + return nil + } - //if there is anything, get rid of it - if len(indexes) != 0 { - for _, index := range indexes { - if len(index) == 0 { - return errors.New("invalid index config") - } + constraints, err := resultToStringArrV4(true, res) + if err != nil { + return fmt.Errorf("failed to convert constraints to string array, %w", err) + } - _, _, err := tx.QueryRaw(ctx, fmt.Sprintf("DROP INDEX %s IF EXISTS", index), nil) - if err != nil { - return tx.RollbackWithError(ctx, err) - } + err = sess.ManagedTransaction(ctx, func(tx TransactionV2) error { + //if there is anything, get rid of it + if len(constraints) != 0 { + for _, constraint := range constraints { + gogm.logger.Debugf("dropping constraint '%s'", constraint) + _, _, err := tx.QueryRaw(ctx, fmt.Sprintf("DROP CONSTRAINT %s IF EXISTS", constraint), nil) + if err != nil { + return fmt.Errorf("failed to drop contraint `%s`, %w", constraint, err) } } - return nil - }) - if err != nil { - _err := sess.Close() - if err != nil { - err = fmt.Errorf("%s: %w", err, _err) - } - return fmt.Errorf("drop index transaction failed, %w", err) } - err = sess.Close() + res, _, err = tx.QueryRaw(ctx, "CALL db.indexes()", nil) if err != nil { - return err + return fmt.Errorf("failed to call db.indexes(), %w", err) } - } - return nil -} -//creates all indexes -func createAllIndexesAndConstraintsV4(ctx context.Context, gogm *Gogm, mappedTypes *hashmap.HashMap) error { - for _, db := range gogm.config.TargetDbs { - sess, err := gogm.NewSessionV2(SessionConfig{ - AccessMode: neo4j.AccessModeWrite, - DatabaseName: db, - }) + indexes, err := resultToStringArrV4(false, res) if err != nil { - return err - } - - //validate that we have to do anything - if mappedTypes == nil || mappedTypes.Len() == 0 { - return errors.New("must have types to map") + return fmt.Errorf("failed to convert result to string array, %w", err) } - numIndexCreated := 0 - //index and/or create unique constraints wherever necessary - //for node, structConfig := range mappedTypes{ - err = sess.ManagedTransaction(ctx, func(tx TransactionV2) error { - for nodes := range mappedTypes.Iter() { - node := nodes.Key.(string) - structConfig := nodes.Value.(structDecoratorConfig) - if structConfig.Fields == nil || len(structConfig.Fields) == 0 { - continue + //if there is anything, get rid of it + if len(indexes) != 0 { + for _, index := range indexes { + if len(index) == 0 { + return errors.New("invalid index config") } - var indexFields []string - - for _, config := range structConfig.Fields { - //pk is a special unique key - if config.PrimaryKey != "" || config.Unique { - numIndexCreated++ - _, _, err = tx.QueryRaw(ctx, buildConstraintQuery(true, node, structConfig.Label, config.Name), nil) - if err != nil { - return err - } - } else if config.Index { - indexFields = append(indexFields, config.Name) - } - } - - //create composite index - if len(indexFields) > 0 { - numIndexCreated++ - _, _, err = tx.QueryRaw(ctx, buildIndexQuery(structConfig.Label, indexFields...), nil) - if err != nil { - return err - } + _, _, err := tx.QueryRaw(ctx, fmt.Sprintf("DROP INDEX %s IF EXISTS", index), nil) + if err != nil { + return fmt.Errorf("failed to drop index %s, %w", index, err) } } - - gogm.logger.Debugf("created (%v) indexes", numIndexCreated) - return nil - }) - if err != nil { - _err := sess.Close() - if err != nil { - err = fmt.Errorf("%s: %w", err, _err) - } - return err - } - err = sess.Close() - if err != nil { - return err } + return nil + }) + if err != nil { + return fmt.Errorf("drop index and constraint transaction failed, %w", err) } + return nil } -//verifies all indexes -func verifyAllIndexesAndConstraintsV4(ctx context.Context, gogm *Gogm, mappedTypes *hashmap.HashMap) error { - for _, db := range gogm.config.TargetDbs { - sess, err := gogm.NewSessionV2(SessionConfig{ - AccessMode: neo4j.AccessModeWrite, - DatabaseName: db, - }) - if err != nil { - return err - } +//creates all indexes +func createAllIndexesAndConstraintsV4(ctx context.Context, gogm *Gogm, mappedTypes *hashmap.HashMap, db string) error { + sess, err := gogm.NewSessionV2(SessionConfig{ + AccessMode: neo4j.AccessModeWrite, + DatabaseName: db, + }) + if err != nil { + return fmt.Errorf("failed to open session to %s, %w", db, err) + } - //validate that we have to do anything - if mappedTypes == nil || mappedTypes.Len() == 0 { - return errors.New("must have types to map") - } + defer sess.Close() - var constraints []string - var indexes []string + //validate that we have to do anything + if mappedTypes == nil || mappedTypes.Len() == 0 { + return errors.New("must have types to map") + } - //build constraint strings + numIndexCreated := 0 + //index and/or create unique constraints wherever necessary + //for node, structConfig := range mappedTypes{ + err = sess.ManagedTransaction(ctx, func(tx TransactionV2) error { for nodes := range mappedTypes.Iter() { node := nodes.Key.(string) structConfig := nodes.Value.(structDecoratorConfig) - if structConfig.Fields == nil || len(structConfig.Fields) == 0 { continue } - fields := []string{} + var indexFields []string for _, config := range structConfig.Fields { - + //pk is a special unique key if config.PrimaryKey != "" || config.Unique { - t := fmt.Sprintf("CONSTRAINT ON (%s:%s) ASSERT %s.%s IS UNIQUE", node, structConfig.Label, node, config.Name) - constraints = append(constraints, t) - - indexes = append(indexes, fmt.Sprintf("INDEX ON :%s(%s)", structConfig.Label, config.Name)) - + numIndexCreated++ + _, _, err = tx.QueryRaw(ctx, buildConstraintQuery(true, node, structConfig.Label, config.Name), nil) + if err != nil { + return fmt.Errorf("failed to add constraint, %w", err) + } } else if config.Index { - fields = append(fields, config.Name) + indexFields = append(indexFields, config.Name) } } - f := "(" - for _, field := range fields { - f += field + //create composite index + if len(indexFields) > 0 { + numIndexCreated++ + _, _, err = tx.QueryRaw(ctx, buildIndexQuery(structConfig.Label, indexFields...), nil) + if err != nil { + return fmt.Errorf("failed to add index, %w", err) + } } + } - f += ")" + gogm.logger.Debugf("created (%v) indexes", numIndexCreated) + return nil + }) + if err != nil { + return fmt.Errorf("tx to add indexes and contraints failed, %w", err) + } - indexes = append(indexes, fmt.Sprintf("INDEX ON :%s%s", structConfig.Label, f)) + return nil +} - } +//verifies all indexes +func verifyAllIndexesAndConstraintsV4(ctx context.Context, gogm *Gogm, mappedTypes *hashmap.HashMap, db string) error { + sess, err := gogm.NewSessionV2(SessionConfig{ + AccessMode: neo4j.AccessModeWrite, + DatabaseName: db, + }) + if err != nil { + return fmt.Errorf("failed to open session %s, %w", db, err) + } - //get whats there now - foundResult, _, err := sess.QueryRaw(ctx, "CALL db.constraints", nil) - if err != nil { - _err := sess.Close() - if err != nil { - err = fmt.Errorf("%s: %w", err, _err) - } - return fmt.Errorf("no constraints found, %w", err) - } + defer sess.Close() - foundConstraints, err := resultToStringArrV4(true, foundResult) - if err != nil { - _err := sess.Close() - if err != nil { - err = fmt.Errorf("%s: %w", err, _err) - } - return fmt.Errorf("failed to convert result to string array, %w", err) - } + //validate that we have to do anything + if mappedTypes == nil || mappedTypes.Len() == 0 { + return errors.New("must have types to map") + } - foundInxdexResult, _, err := sess.QueryRaw(ctx, "CALL db.indexes()", nil) - if err != nil { - _err := sess.Close() - if err != nil { - err = fmt.Errorf("%s: %w", err, _err) - } - return fmt.Errorf("no indices found, %w", err) - } + var constraints []string + var indexes []string - foundIndexes, err := resultToStringArrV4(false, foundInxdexResult) - if err != nil { - _err := sess.Close() - if err != nil { - err = fmt.Errorf("%s: %w", err, _err) - } - return fmt.Errorf("failed to convert result to array, %w", err) - } + //build constraint strings + for nodes := range mappedTypes.Iter() { + node := nodes.Key.(string) + structConfig := nodes.Value.(structDecoratorConfig) - //verify from there - delta, found := arrayOperations.Difference(foundIndexes, indexes) - if !found { - err = fmt.Errorf("found differences in remote vs ogm for found indexes, %v", delta) - _err := sess.Close() - if _err != nil { - err = fmt.Errorf("%s: %w", err, _err) - } - return err + if structConfig.Fields == nil || len(structConfig.Fields) == 0 { + continue } - gogm.logger.Debugf("%+v", delta) + fields := []string{} + + for _, config := range structConfig.Fields { - var founds []string + if config.PrimaryKey != "" || config.Unique { + t := fmt.Sprintf("CONSTRAINT ON (%s:%s) ASSERT %s.%s IS UNIQUE", node, structConfig.Label, node, config.Name) + constraints = append(constraints, t) - founds = append(founds, foundConstraints...) + indexes = append(indexes, fmt.Sprintf("INDEX ON :%s(%s)", structConfig.Label, config.Name)) - delta, found = arrayOperations.Difference(founds, constraints) - if !found { - err = fmt.Errorf("found differences in remote vs ogm for found constraints, %v", delta) - _err := sess.Close() - if _err != nil { - err = fmt.Errorf("%s: %w", err, _err) + } else if config.Index { + fields = append(fields, config.Name) } - return err } - gogm.logger.Debugf("%+v", delta) - err = sess.Close() - if err != nil { - return err + f := "(" + for _, field := range fields { + f += field } + + f += ")" + + indexes = append(indexes, fmt.Sprintf("INDEX ON :%s%s", structConfig.Label, f)) + + } + + //get what's there now + foundResult, _, err := sess.QueryRaw(ctx, "CALL db.constraints", nil) + if err != nil { + return fmt.Errorf("no constraints found, %w", err) + } + + foundConstraints, err := resultToStringArrV4(true, foundResult) + if err != nil { + return fmt.Errorf("failed to convert result to string array, %w", err) + } + + foundInxdexResult, _, err := sess.QueryRaw(ctx, "CALL db.indexes()", nil) + if err != nil { + return fmt.Errorf("no indices found, %w", err) + } + + foundIndexes, err := resultToStringArrV4(false, foundInxdexResult) + if err != nil { + return fmt.Errorf("failed to convert result to array, %w", err) + } + + //verify from there + delta, found := arrayOperations.Difference(foundIndexes, indexes) + if !found { + return fmt.Errorf("found differences in remote vs ogm for found indexes, %v", delta) + } + + gogm.logger.Debugf("%+v", delta) + + var founds []string + + founds = append(founds, foundConstraints...) + + delta, found = arrayOperations.Difference(founds, constraints) + if !found { + return fmt.Errorf("found differences in remote vs ogm for found constraints, %v", delta) } + gogm.logger.Debugf("%+v", delta) return nil } diff --git a/integration_test.go b/integration_test.go index aa0d7d5..bca1ce4 100644 --- a/integration_test.go +++ b/integration_test.go @@ -68,19 +68,19 @@ func (integrationTest *IntegrationTestSuite) TearDownSuite() { func (integrationTest *IntegrationTestSuite) SetupSuite() { conf := Config{ - Username: "neo4j", - Password: "changeme", - Host: "0.0.0.0", - IsCluster: false, - Port: 7687, - PoolSize: 15, + Username: "neo4j", + Password: "changeme", + Host: "0.0.0.0", + Protocol: "bolt", + Port: 7687, + PoolSize: 15, // this is ignore because index management is part of the test IndexStrategy: IGNORE_INDEX, - EnableDriverLogs: true, + EnableDriverLogs: false, DefaultTransactionTimeout: 2 * time.Minute, } - gogm, err := New(&conf, UUIDPrimaryKeyStrategy, &a{}, &b{}, &c{}, &propTest{}, &narcissisticTestNode{}) + gogm, err := New(&conf, UUIDPrimaryKeyStrategy, &a{}, &b{}, &c{}, &propTest{}, &narcissisticTestNode{}, &Sides{}, &Middle{}, &Bottom{}) integrationTest.Require().Nil(err) integrationTest.Require().NotNil(gogm) integrationTest.gogm = gogm @@ -315,6 +315,194 @@ func (integrationTest *IntegrationTestSuite) TestIntegration() { integrationTest.Require().Nil(sess.Close()) } +type Sides struct { + BaseUUIDNode + Name string `gogm:"name=name"` + + MatchIncoming []*Middle `gogm:"direction=incoming;relationship=outgoing_test"` +} + +type Bottom struct { + BaseUUIDNode + Name string `gogm:"name=name"` + Middle []*Middle `gogm:"direction=incoming;relationship=bottom"` +} + +type Middle struct { + BaseUUIDNode + + IncomingSides []*Sides `gogm:"direction=outgoing;relationship=outgoing_test"` + Bottom []*Bottom `gogm:"direction=outgoing;relationship=bottom"` +} + +func (integrationTest *IntegrationTestSuite) TestMultiSaveEdgeCase() { + // skipping multidb integration test for v3 + if integrationTest.gogm.boltMajorVersion < 4 { + integrationTest.T().Log("skipping because of incompatible version", integrationTest.gogm.boltMajorVersion) + integrationTest.T().Skip() + return + } + + /* + (left)--(middle)--(right) + | + (Bottom) + SaveDepth(left, 1) + SaveDepth(right,1) + SaveDepth(bottom, 1) + + Problem is only (middle)--(right) is saved, not (left)--(middle) + */ + numMiddles := 30 + + for _, testCase := range []struct { + TestFunction func(req *require.Assertions, db string) + Name string + }{ + { + Name: "incoming multi non transaction test", + TestFunction: func(req *require.Assertions, db string) { + left, right := &Sides{Name: "left"}, &Sides{Name: "right"} + bottom := &Bottom{} + middles := make([]*Middle, numMiddles) + for i := 0; i < numMiddles; i++ { + middles[i] = &Middle{} + middles[i].IncomingSides = []*Sides{left, right} + middles[i].Bottom = []*Bottom{bottom} + } + + bottom.Middle = middles + left.MatchIncoming = middles + right.MatchIncoming = middles + + sess, err := integrationTest.gogm.NewSessionV2(SessionConfig{ + AccessMode: neo4j.AccessModeWrite, + DatabaseName: db, + }) + req.Nil(err) + req.NotNil(sess) + + req.Nil(sess.SaveDepth(context.Background(), left, 1)) + req.Nil(sess.SaveDepth(context.Background(), right, 1)) + req.Nil(sess.SaveDepth(context.Background(), bottom, 1)) + req.Nil(sess.Close()) + + sess, err = integrationTest.gogm.NewSessionV2(SessionConfig{ + AccessMode: neo4j.AccessModeRead, + DatabaseName: db, + }) + req.Nil(err) + req.NotNil(sess) + defer sess.Close() + var checkLeft, checkRight Sides + var checkBottom Bottom + + req.Nil(sess.LoadDepth(context.Background(), &checkLeft, left.UUID, 1)) + req.Equal(len(checkLeft.MatchIncoming), numMiddles) + + req.Nil(sess.LoadDepth(context.Background(), &checkBottom, bottom.UUID, 1)) + req.Equal(len(checkBottom.Middle), numMiddles) + + req.Nil(sess.LoadDepth(context.Background(), &checkRight, right.UUID, 1)) + req.Equal(len(checkRight.MatchIncoming), numMiddles) + }, + }, + { + Name: "incoming multi transaction test", + TestFunction: func(req *require.Assertions, db string) { + left, right := &Sides{Name: "left"}, &Sides{Name: "right"} + bottom := &Bottom{} + + middles := make([]*Middle, numMiddles) + for i := 0; i < numMiddles; i++ { + middles[i] = &Middle{} + middles[i].IncomingSides = []*Sides{left, right} + middles[i].Bottom = []*Bottom{bottom} + } + + bottom.Middle = middles + left.MatchIncoming = middles + right.MatchIncoming = middles + + sess, err := integrationTest.gogm.NewSessionV2(SessionConfig{ + AccessMode: neo4j.AccessModeWrite, + DatabaseName: db, + }) + req.Nil(err) + req.NotNil(sess) + + ctx := context.Background() + req.Nil(sess.ManagedTransaction(ctx, func(tx TransactionV2) error { + err = tx.SaveDepth(context.Background(), left, 1) + if err != nil { + return err + } + + err = tx.SaveDepth(context.Background(), right, 1) + if err != nil { + return err + } + + return tx.SaveDepth(context.Background(), bottom, 1) + })) + req.Nil(sess.Close()) + + sess, err = integrationTest.gogm.NewSessionV2(SessionConfig{ + AccessMode: neo4j.AccessModeRead, + DatabaseName: db, + }) + req.Nil(err) + req.NotNil(sess) + defer sess.Close() + var checkLeft, checkRight Sides + var checkBottom Bottom + + req.Nil(sess.LoadDepth(context.Background(), &checkLeft, left.UUID, 1)) + req.Equal(len(checkLeft.MatchIncoming), numMiddles) + + req.Nil(sess.LoadDepth(context.Background(), &checkBottom, bottom.UUID, 1)) + req.Equal(len(checkBottom.Middle), numMiddles) + + req.Nil(sess.LoadDepth(context.Background(), &checkRight, right.UUID, 1)) + req.Equal(len(checkRight.MatchIncoming), numMiddles) + }, + }, + } { + integrationTest.T().Run(testCase.Name, func(t *testing.T) { + db := fmt.Sprintf("db-%s", uuid2.New().String()) + req := require.New(integrationTest.T()) + sess, err := integrationTest.gogm.NewSessionV2(SessionConfig{ + AccessMode: neo4j.AccessModeWrite, + DatabaseName: "system", + }) + req.NotNil(sess) + req.Nil(err) + ctx := context.Background() + _, info, err := sess.QueryRaw(ctx, "CREATE DATABASE $DB IF NOT EXISTS", map[string]interface{}{ + "DB": db, + }) + req.Nil(err) + req.NotNil(info) + req.NotNil(info.Counters()) + req.Equal(1, info.Counters().SystemUpdates()) + + time.Sleep(10 * time.Second) + + defer func() { + _, info, err := sess.QueryRaw(ctx, "DROP DATABASE $DB", map[string]interface{}{ + "DB": db, + }) + req.Nil(err) + req.NotNil(info) + req.NotNil(info.Counters()) + req.Equal(1, info.Counters().SystemUpdates()) + }() + + testCase.TestFunction(req, db) + }) + } +} + func (integrationTest *IntegrationTestSuite) TestIntegrationV2() { req := integrationTest.Require() log.Println("testIndexManagement") @@ -633,15 +821,15 @@ func testSaveV2(sess SessionV2, req *require.Assertions) { const testUuid1 = "f64953a5-8b40-4a87-a26b-6427e661570c" -func (i *IntegrationTestSuite) TestSchemaLoadStrategy() { - req := i.Require() +func (integrationTest *IntegrationTestSuite) TestSchemaLoadStrategy() { + req := integrationTest.Require() - i.gogm.config.LoadStrategy = SCHEMA_LOAD_STRATEGY + integrationTest.gogm.config.LoadStrategy = SCHEMA_LOAD_STRATEGY // create required nodes - testSchemaLoadStrategy_Setup(i.gogm, req) + testSchemaLoadStrategy_Setup(integrationTest.gogm, req) - sess, err := i.gogm.NewSessionV2(SessionConfig{AccessMode: AccessModeRead}) + sess, err := integrationTest.gogm.NewSessionV2(SessionConfig{AccessMode: AccessModeRead}) req.Nil(err) defer req.Nil(sess.Close()) @@ -650,7 +838,7 @@ func (i *IntegrationTestSuite) TestSchemaLoadStrategy() { defer req.Nil(sess.Close()) // test raw query (verify SchemaLoadStrategy + Neo driver decoding) - query, err := SchemaLoadStrategyOne(i.gogm, "n", "a", "uuid", "uuid", false, 1, nil) + query, err := SchemaLoadStrategyOne(integrationTest.gogm, "n", "a", "uuid", "uuid", false, 1, nil) req.Nil(err, "error generating SchemaLoadStrategy query") cypher, err := query.ToCypher() @@ -730,12 +918,12 @@ const testUuid2 = "f64953a5-8b40-4a87-a26b-6427e661570d" const testUuid3 = "f64953a5-8b40-4a87-a26b-6427e661571d" const testUuid4 = "f64953a5-8b40-4a87-a26b-6427e661572d" -func (i *IntegrationTestSuite) TestRelationshipWithinSingleType() { - req := i.Require() +func (integrationTest *IntegrationTestSuite) TestRelationshipWithinSingleType() { + req := integrationTest.Require() - testRelationshipWithinSingleType_Setup(i.gogm, req) + testRelationshipWithinSingleType_Setup(integrationTest.gogm, req) - sess, err := i.gogm.NewSessionV2(SessionConfig{AccessMode: AccessModeRead}) + sess, err := integrationTest.gogm.NewSessionV2(SessionConfig{AccessMode: AccessModeRead}) req.Nil(err) defer req.Nil(sess.Close()) diff --git a/save.go b/save.go index 8779813..d779f37 100644 --- a/save.go +++ b/save.go @@ -161,7 +161,7 @@ func relateNodes(transaction neo4j.Transaction, relations map[string][]*relCreat } for label, rels := range relations { - var params []interface{} + var _params []interface{} if len(rels) == 0 { continue @@ -184,7 +184,7 @@ func relateNodes(transaction neo4j.Transaction, relations map[string][]*relCreat rel.Params = map[string]interface{}{} } - params = append(params, map[string]interface{}{ + _params = append(_params, map[string]interface{}{ "startNodeId": startId, "endNodeId": endId, "props": rel.Params, @@ -246,7 +246,7 @@ func relateNodes(transaction neo4j.Transaction, relations map[string][]*relCreat } res, err := transaction.Run(cyp, map[string]interface{}{ - "rows": params, + "rows": _params, }) if err != nil { return fmt.Errorf("failed to relate nodes, %w", err) @@ -435,13 +435,10 @@ func generateCurRels(gogm *Gogm, parentPtr uintptr, current *reflect.Value, curr var followId int64 if !followIdVal.IsNil() { followIdVal = followIdVal.Elem() - if followIdVal.IsZero() { - followId = 0 - } else { - followId = followIdVal.Int() - } + followId = followIdVal.Int() } else { - followId = 0 + // should not be nil, just skip this one + continue } //check the config is there for the specified field @@ -471,13 +468,10 @@ func generateCurRels(gogm *Gogm, parentPtr uintptr, current *reflect.Value, curr var followId int64 if !followIdVal.IsNil() { followIdVal = followIdVal.Elem() - if followIdVal.IsZero() { - followId = 0 - } else { - followId = followIdVal.Int() - } + followId = followIdVal.Int() } else { - followId = 0 + // should not be nil, just skip this one + continue } //check the config is there for the specified field