diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index 5051e22..df3119a 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -18,9 +18,12 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
name: Go
-on: [push, pull_request]
+on:
+ push:
+ branches:
+ - master
+ pull_request: {}
jobs:
-
build:
name: Build
runs-on: ubuntu-latest
diff --git a/.gitignore b/.gitignore
index 9ad3db2..273973f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,5 @@ certs/
*.out
.idea/
vendor/
+
+.vscode
\ No newline at end of file
diff --git a/.vscode/settings.json b/.vscode/settings.json
deleted file mode 100644
index 1395127..0000000
--- a/.vscode/settings.json
+++ /dev/null
@@ -1,5 +0,0 @@
-{
- "cSpell.words": [
- "gogm"
- ]
-}
\ No newline at end of file
diff --git a/README.md b/README.md
index 6709201..9832ca9 100644
--- a/README.md
+++ b/README.md
@@ -32,8 +32,8 @@ go get -u github.com/mindstand/gogm/v2
Primary key strategies allow more customization over primary keys. A strategy is provided to gogm on initialization.
Built in primary key strategies are:
-- gogm.DefaultPrimaryKeyStrategy -- just use the graph id from neo4j as the primary key
-- gogm.UUIDPrimaryKeyStrategy -- uuid's as primary keys
+- `gogm.DefaultPrimaryKeyStrategy` -- just use the graph id from neo4j as the primary key
+- `gogm.UUIDPrimaryKeyStrategy` -- uuid's as primary keys
```go
// Example of the internal UUID strategy
PrimaryKeyStrategy{
@@ -52,6 +52,17 @@ PrimaryKeyStrategy{
}
```
+### Load Strategy
+Load strategies allow control over the queries generated by `Load` operations.
+Different strategies change the size of the queries sent to the database as well as the amount of work the database has to do.
+A load strategy is provided to gomg on initialization.
+
+The defined load strategies are:
+- `gogm.PATH_LOAD_STRATEGY` -- Use cypher path queries to generate simple queries for load operations.
+- `gogm.SCHEMA_LOAD_STRATEGY` -- Leverage the GoGM schema to generate more complex queries for load operations which results in less work for the database.
+
+Depending on your use case, `PATH_LOAD_STRATEGY` may result in higher latency.
+
### Struct Configuration
##### text notates deprecation
@@ -167,6 +178,8 @@ func main() {
EnableLogParams: false,
// enable open tracing. Ensure contexts have spans already. GoGM does not make root spans, only child spans
OpentracingEnabled: false,
+ // specify the method gogm will use to generate Load queries
+ LoadStrategy: gogm.PATH_LOAD_STRATEGY // set to SCHEMA_LOAD_STRATEGY for schema-aware queries which may reduce load on the database
}
// register all vertices and edges
@@ -285,7 +298,7 @@ sess, err := gogm.G().NewSessionV2(gogm.SessionConfig{AccessMode: gogm.AccessMod
## CLI Installation
```
-go get -u github.com/mindstand/gogm/v2/cli/gogmcli
+go get -u github.com/mindstand/gogm/v2/cmd/gogmcli
```
## CLI Usage
diff --git a/config.go b/config.go
index d6ce9b8..83c6fa6 100644
--- a/config.go
+++ b/config.go
@@ -69,6 +69,8 @@ type Config struct {
EnableLogParams bool `json:"enable_log_properties" yaml:"enable_log_properties" mapstructure:"enable_log_properties"`
OpentracingEnabled bool `json:"opentracing_enabled" yaml:"opentracing_enabled" mapstructure:"opentracing_enabled"`
+
+ LoadStrategy LoadStrategy `json:"load_strategy" yaml:"load_strategy" mapstructure:"load_strategy"`
}
func (c *Config) validate() error {
@@ -93,6 +95,14 @@ func (c *Config) validate() error {
c.TargetDbs = []string{"neo4j"}
}
+ if err := c.IndexStrategy.validate(); err != nil {
+ return err
+ }
+
+ if err := c.LoadStrategy.validate(); err != nil {
+ return err
+ }
+
return nil
}
@@ -126,3 +136,12 @@ const (
// IGNORE_INDEX skips the index step of setup
IGNORE_INDEX IndexStrategy = 2
)
+
+func (is IndexStrategy) validate() error {
+ switch is {
+ case ASSERT_INDEX, VALIDATE_INDEX, IGNORE_INDEX:
+ return nil
+ default:
+ return fmt.Errorf("invalid index strategy %d", is)
+ }
+}
diff --git a/decoder.go b/decoder.go
index 2dab2f3..fadf7f2 100644
--- a/decoder.go
+++ b/decoder.go
@@ -22,11 +22,41 @@ package gogm
import (
"errors"
"fmt"
- "github.com/neo4j/neo4j-go-driver/v4/neo4j"
"reflect"
"strings"
+
+ "github.com/neo4j/neo4j-go-driver/v4/neo4j"
)
+func traverseResultRecordValues(values []interface{}) ([]neo4j.Path, []neo4j.Relationship, []neo4j.Node) {
+ var paths []neo4j.Path
+ var strictRels []neo4j.Relationship
+ var isolatedNodes []neo4j.Node
+
+ for _, value := range values {
+ switch ct := value.(type) {
+ case neo4j.Path:
+ paths = append(paths, ct)
+ case neo4j.Relationship:
+ strictRels = append(strictRels, ct)
+ case neo4j.Node:
+ isolatedNodes = append(isolatedNodes, ct)
+ case []interface{}:
+ v, ok := value.([]interface{})
+ if ok {
+ p, r, n := traverseResultRecordValues(v)
+ paths = append(paths, p...)
+ strictRels = append(strictRels, r...)
+ isolatedNodes = append(isolatedNodes, n...)
+ }
+ default:
+ continue
+ }
+ }
+
+ return paths, strictRels, isolatedNodes
+}
+
//decodes raw path response from driver
//example query `match p=(n)-[*0..5]-() return p`
func decode(gogm *Gogm, result neo4j.Result, respObj interface{}) (err error) {
@@ -61,21 +91,10 @@ func decode(gogm *Gogm, result neo4j.Result, respObj interface{}) (err error) {
var isolatedNodes []neo4j.Node
for result.Next() {
- for _, value := range result.Record().Values {
- switch ct := value.(type) {
- case neo4j.Path:
- paths = append(paths, ct)
- break
- case neo4j.Relationship:
- strictRels = append(strictRels, ct)
- break
- case neo4j.Node:
- isolatedNodes = append(isolatedNodes, ct)
- break
- default:
- continue
- }
- }
+ p, r, n := traverseResultRecordValues(result.Record().Values)
+ paths = append(paths, p...)
+ strictRels = append(strictRels, r...)
+ isolatedNodes = append(isolatedNodes, n...)
}
nodeLookup := make(map[int64]*reflect.Value)
@@ -84,21 +103,21 @@ func decode(gogm *Gogm, result neo4j.Result, respObj interface{}) (err error) {
rels := make(map[int64]*neoEdgeConfig)
labelLookup := map[int64]string{}
- if paths != nil && len(paths) != 0 {
+ if len(paths) != 0 {
err = sortPaths(gogm, paths, &nodeLookup, &rels, &pks, primaryLabel, &relMaps)
if err != nil {
return err
}
}
- if isolatedNodes != nil && len(isolatedNodes) != 0 {
+ if len(isolatedNodes) != 0 {
err = sortIsolatedNodes(gogm, isolatedNodes, &labelLookup, &nodeLookup, &pks, primaryLabel, &relMaps)
if err != nil {
return err
}
}
- if strictRels != nil && len(strictRels) != 0 {
+ if len(strictRels) != 0 {
err = sortStrictRels(strictRels, &labelLookup, &rels)
if err != nil {
return err
@@ -232,14 +251,14 @@ func decode(gogm *Gogm, result neo4j.Result, respObj interface{}) (err error) {
//can ensure that it implements proper interface if it made it this far
res := val.MethodByName("SetStartNode").Call([]reflect.Value{startCall})
- if res == nil || len(res) == 0 {
+ if len(res) == 0 {
return fmt.Errorf("invalid response from edge callback - %w", err)
} else if !res[0].IsNil() {
return fmt.Errorf("failed call to SetStartNode - %w", res[0].Interface().(error))
}
res = val.MethodByName("SetEndNode").Call([]reflect.Value{endCall})
- if res == nil || len(res) == 0 {
+ if len(res) == 0 {
return fmt.Errorf("invalid response from edge callback - %w", err)
} else if !res[0].IsNil() {
return fmt.Errorf("failed call to SetEndNode - %w", res[0].Interface().(error))
@@ -359,7 +378,7 @@ func sortIsolatedNodes(gogm *Gogm, isolatedNodes []neo4j.Node, labelLookup *map[
}
//set label map
- if _, ok := (*labelLookup)[node.Id]; !ok && len(node.Labels) != 0 && node.Labels[0] == pkLabel {
+ if _, ok := (*labelLookup)[node.Id]; !ok && len(node.Labels) != 0 { //&& node.Labels[0] == pkLabel {
(*labelLookup)[node.Id] = node.Labels[0]
}
}
diff --git a/decoder_test.go b/decoder_test.go
index abe67a5..e4c28dd 100644
--- a/decoder_test.go
+++ b/decoder_test.go
@@ -21,28 +21,86 @@ package gogm
import (
"errors"
- "github.com/cornelk/hashmap"
- "github.com/neo4j/neo4j-go-driver/v4/neo4j"
- "github.com/stretchr/testify/require"
"reflect"
"testing"
"time"
+
+ "github.com/cornelk/hashmap"
+ "github.com/neo4j/neo4j-go-driver/v4/neo4j"
+ "github.com/stretchr/testify/require"
)
-type TestStruct struct {
- Id *int64
- UUID string
- OtherField string
-}
+func TestTraverseResultRecordValues(t *testing.T) {
+ req := require.New(t)
-func toHashmap(m map[string]interface{}) *hashmap.HashMap {
- h := &hashmap.HashMap{}
+ // empty case
+ pArr, rArr, nArr := traverseResultRecordValues([]interface{}{})
+ req.Len(pArr, 0)
+ req.Len(rArr, 0)
+ req.Len(nArr, 0)
+
+ // garbage record case
+ pArr, rArr, nArr = traverseResultRecordValues([]interface{}{"hello", []interface{}{"there"}})
+ req.Len(pArr, 0)
+ req.Len(rArr, 0)
+ req.Len(nArr, 0)
+
+ // define our test paths, rels, and nodes
+ p1 := neo4j.Path{
+ Nodes: []neo4j.Node{
+ {
+ Id: 1,
+ Labels: []string{"start"},
+ },
+ {
+ Id: 2,
+ Labels: []string{"end"},
+ },
+ },
+ Relationships: []neo4j.Relationship{
+ {
+ Id: 3,
+ StartId: 1,
+ EndId: 2,
+ Type: "someType",
+ },
+ },
+ }
- for k, v := range m {
- h.Set(k, v)
+ n1 := neo4j.Node{
+ Id: 4,
+ Labels: []string{"start"},
}
- return h
+ n2 := neo4j.Node{
+ Id: 5,
+ Labels: []string{"end"},
+ }
+
+ r1 := neo4j.Relationship{
+ Id: 6,
+ StartId: 4,
+ EndId: 5,
+ Type: "someType",
+ }
+
+ // normal case (paths, nodes, and rels, but no nested results)
+ pArr, rArr, nArr = traverseResultRecordValues([]interface{}{p1, n1, n2, r1})
+ req.Equal(pArr[0], p1)
+ req.Equal(rArr[0], r1)
+ req.ElementsMatch(nArr, []interface{}{n1, n2})
+
+ // case with nested nodes and rels
+ pArr, rArr, nArr = traverseResultRecordValues([]interface{}{p1, []interface{}{n1, n2, r1}})
+ req.Equal(pArr[0], p1)
+ req.Equal(rArr[0], r1)
+ req.ElementsMatch(nArr, []interface{}{n1, n2})
+}
+
+type TestStruct struct {
+ Id *int64
+ UUID string
+ OtherField string
}
func toHashmapStructdecconf(m map[string]structDecoratorConfig) *hashmap.HashMap {
@@ -868,4 +926,70 @@ func TestDecode2(t *testing.T) {
req.Equal(int64(55), *readin9.Id)
req.Equal("dasdfas", readin9.UUID)
+ // decode should be able to handle queries that return nested lists of paths, relationships, and nodes
+ decodeResultNested := [][]interface{}{
+ {
+ neo4j.Node{
+ Id: 18,
+ Labels: []string{"a"},
+ Props: map[string]interface{}{
+ "uuid": "2588baca-7561-43f8-9ddb-9c7aecf87284",
+ },
+ },
+ },
+ {
+ []interface{}{
+ []interface{}{
+ []interface{}{
+ []interface{}{
+ neo4j.Relationship{
+ Id: 0,
+ StartId: 19,
+ EndId: 18,
+ Type: "testm2o",
+ },
+ neo4j.Node{
+ Id: 19,
+ Labels: []string{"b"},
+ Props: map[string]interface{}{
+ "test_fielda": "1234",
+ "uuid": "b6d8c2ab-06c2-43d0-8452-89d6c4ec5d40",
+ },
+ },
+ },
+ },
+ []interface{}{},
+ []interface{}{
+ []interface{}{
+ neo4j.Relationship{
+ Id: 1,
+ StartId: 18,
+ EndId: 19,
+ Type: "special_single",
+ Props: map[string]interface{}{
+ "test": "testing",
+ },
+ },
+ neo4j.Node{
+ Id: 19,
+ Labels: []string{"b"},
+ Props: map[string]interface{}{
+ "test_fielda": "1234",
+ "uuid": "b6d8c2ab-06c2-43d0-8452-89d6c4ec5d40",
+ },
+ },
+ },
+ },
+ },
+ []interface{}{},
+ []interface{}{},
+ },
+ },
+ }
+ var readinNested a
+ req.Nil(decode(gogm, newMockResult(decodeResultNested), &readinNested))
+ req.Equal("2588baca-7561-43f8-9ddb-9c7aecf87284", readinNested.UUID)
+ req.Len(readinNested.ManyA, 1)
+ req.Equal("b6d8c2ab-06c2-43d0-8452-89d6c4ec5d40", readinNested.ManyA[0].UUID)
+ req.Equal(readinNested.ManyA[0], readinNested.SingleSpecA.End, "Two rels should have the same node instance")
}
diff --git a/decorator.go b/decorator.go
index 81b0f4e..c875f84 100644
--- a/decorator.go
+++ b/decorator.go
@@ -22,10 +22,11 @@ package gogm
import (
"errors"
"fmt"
- dsl "github.com/mindstand/go-cypherdsl"
"reflect"
"strings"
"time"
+
+ dsl "github.com/mindstand/go-cypherdsl"
)
// defined the decorator name for struct tag
@@ -525,7 +526,7 @@ func getStructDecoratorConfig(gogm *Gogm, i interface{}, mappedRelations *relati
fields := getFields(t)
- if fields == nil || len(fields) == 0 {
+ if len(fields) == 0 {
return nil, errors.New("failed to parse fields")
}
@@ -563,40 +564,9 @@ func getStructDecoratorConfig(gogm *Gogm, i interface{}, mappedRelations *relati
endType = field.Type
}
- endTypeName := ""
- if reflect.PtrTo(endType).Implements(edgeType) {
- gogm.logger.Debug(endType.Name())
- endVal := reflect.New(endType)
- var endTypeVal []reflect.Value
-
- //log.Info(endVal.String())
-
- if config.Direction == dsl.DirectionOutgoing {
- endTypeVal = endVal.MethodByName("GetEndNodeType").Call(nil)
- } else {
- endTypeVal = endVal.MethodByName("GetStartNodeType").Call(nil)
- }
-
- if len(endTypeVal) != 1 {
- return nil, errors.New("GetEndNodeType failed")
- }
-
- if endTypeVal[0].IsNil() {
- return nil, errors.New("GetEndNodeType() can not return a nil value")
- }
-
- convertedType, ok := endTypeVal[0].Interface().(reflect.Type)
- if !ok {
- return nil, errors.New("cannot convert to type reflect.Type")
- }
-
- if convertedType.Kind() == reflect.Ptr {
- endTypeName = convertedType.Elem().Name()
- } else {
- endTypeName = convertedType.Name()
- }
- } else {
- endTypeName = endType.Name()
+ endTypeName, err := traverseRelType(gogm, endType, config.Direction)
+ if err != nil {
+ return nil, err
}
mappedRelations.Add(toReturn.Label, config.Relationship, endTypeName, *config)
diff --git a/delete_test.go b/delete_test.go
index cc849ae..246b5ab 100644
--- a/delete_test.go
+++ b/delete_test.go
@@ -19,24 +19,20 @@
package gogm
-import (
- "github.com/stretchr/testify/require"
-)
+// func testDelete(req *require.Assertions) {
+// conn, err := driver.Session(neo4j.AccessModeWrite)
+// if err != nil {
+// req.Nil(err)
+// }
+// defer conn.Close()
-func testDelete(req *require.Assertions) {
- //conn, err := driver.Session(neo4j.AccessModeWrite)
- //if err != nil {
- // req.Nil(err)
- //}
- //defer conn.Close()
- //
- //del := a{
- // BaseNode: BaseNode{
- // Id: 0,
- // UUID: "5334ee8c-6231-40fd-83e5-16c8016ccde6",
- // },
- //}
- //
- //err = deleteNode(runWrap(conn), &del)
- //req.Nil(err)
-}
+// del := a{
+// BaseNode: BaseNode{
+// Id: 0,
+// UUID: "5334ee8c-6231-40fd-83e5-16c8016ccde6",
+// },
+// }
+
+// err = deleteNode(runWrap(conn), &del)
+// req.Nil(err)
+// }
diff --git a/gogm.go b/gogm.go
index 4b88003..29a90a6 100644
--- a/gogm.go
+++ b/gogm.go
@@ -24,12 +24,13 @@ import (
"crypto/x509"
"errors"
"fmt"
- "github.com/cornelk/hashmap"
- "github.com/neo4j/neo4j-go-driver/v4/neo4j"
"io/ioutil"
"reflect"
"strconv"
"strings"
+
+ "github.com/cornelk/hashmap"
+ "github.com/neo4j/neo4j-go-driver/v4/neo4j"
)
var globalGogm = &Gogm{isNoOp: true, logger: GetDefaultLogger()}
@@ -71,7 +72,7 @@ func NewContext(ctx context.Context, config *Config, pkStrategy *PrimaryKeyStrat
return nil, errors.New("pk strategy can not be nil")
}
- if mapTypes == nil || len(mapTypes) == 0 {
+ if len(mapTypes) == 0 {
return nil, errors.New("no types to map")
}
diff --git a/index_v3.go b/index_v3.go
index b3012f0..131bc9b 100644
--- a/index_v3.go
+++ b/index_v3.go
@@ -23,6 +23,7 @@ import (
"context"
"errors"
"fmt"
+
"github.com/adam-hanna/arrayOperations"
"github.com/cornelk/hashmap"
dsl "github.com/mindstand/go-cypherdsl"
@@ -39,7 +40,7 @@ func resultToStringArrV3(result [][]interface{}) ([]string, error) {
for _, res := range result {
val := res
// nothing to parse
- if val == nil || len(val) == 0 {
+ if len(val) == 0 {
continue
}
@@ -70,7 +71,7 @@ func dropAllIndexesAndConstraintsV3(ctx context.Context, gogm *Gogm) error {
return err
}
- if vals == nil || len(vals) == 0 {
+ if len(vals) == 0 {
// nothing to drop if no constraints exist
return nil
}
@@ -282,9 +283,7 @@ func verifyAllIndexesAndConstraintsV3(ctx context.Context, gogm *Gogm, mappedTyp
var founds []string
- for _, constraint := range foundConstraints {
- founds = append(founds, constraint)
- }
+ founds = append(founds, foundConstraints...)
delta, found = arrayOperations.Difference(founds, constraints)
if !found {
diff --git a/index_v4.go b/index_v4.go
index 2c5fc6b..289249c 100644
--- a/index_v4.go
+++ b/index_v4.go
@@ -23,6 +23,7 @@ import (
"context"
"errors"
"fmt"
+
"github.com/adam-hanna/arrayOperations"
"github.com/cornelk/hashmap"
"github.com/neo4j/neo4j-go-driver/v4/neo4j"
@@ -76,7 +77,7 @@ func resultToStringArrV4(isConstraint bool, result [][]interface{}) ([]string, e
for _, res := range result {
val := res
// nothing to parse
- if val == nil || len(val) == 0 {
+ if len(val) == 0 {
continue
}
@@ -108,7 +109,7 @@ func dropAllIndexesAndConstraintsV4(ctx context.Context, gogm *Gogm) error {
return err
}
- if res == nil || len(res) == 0 {
+ if len(res) == 0 {
// no constraints to kill off, return from here
return nil
}
@@ -334,28 +335,28 @@ func verifyAllIndexesAndConstraintsV4(ctx context.Context, gogm *Gogm, mappedTyp
//verify from there
delta, found := arrayOperations.Difference(foundIndexes, indexes)
if !found {
+ err = fmt.Errorf("found differences in remote vs ogm for found indexes, %v", delta)
_err := sess.Close()
- if err != nil {
+ if _err != nil {
err = fmt.Errorf("%s: %w", err, _err)
}
- return fmt.Errorf("found differences in remote vs ogm for found indexes, %v", delta)
+ return err
}
gogm.logger.Debugf("%+v", delta)
var founds []string
- for _, constraint := range foundConstraints {
- founds = append(founds, constraint)
- }
+ founds = append(founds, foundConstraints...)
delta, found = arrayOperations.Difference(founds, constraints)
if !found {
+ err = fmt.Errorf("found differences in remote vs ogm for found constraints, %v", delta)
_err := sess.Close()
- if err != nil {
+ if _err != nil {
err = fmt.Errorf("%s: %w", err, _err)
}
- return fmt.Errorf("found differences in remote vs ogm for found constraints, %v", delta)
+ return err
}
gogm.logger.Debugf("%+v", delta)
diff --git a/integration_test.go b/integration_test.go
index 9b0e97f..2168c1a 100644
--- a/integration_test.go
+++ b/integration_test.go
@@ -22,12 +22,14 @@ package gogm
import (
"context"
"fmt"
- uuid2 "github.com/google/uuid"
- assert2 "github.com/stretchr/testify/assert"
"log"
"os"
"sync"
+ uuid2 "github.com/google/uuid"
+ "github.com/neo4j/neo4j-go-driver/v4/neo4j"
+ assert2 "github.com/stretchr/testify/assert"
+
"testing"
"time"
@@ -600,3 +602,92 @@ func testSaveV2(sess SessionV2, req *require.Assertions) {
req.EqualValues(prop1.TdMapOfTdSlice, prop2.TdMapOfTdSlice)
req.EqualValues(prop1.TdMapTdSliceOfTd, prop2.TdMapTdSliceOfTd)
}
+
+const testUuid = "f64953a5-8b40-4a87-a26b-6427e661570c"
+
+func (i *IntegrationTestSuite) TestSchemaLoadStrategy() {
+ req := i.Require()
+
+ i.gogm.config.LoadStrategy = SCHEMA_LOAD_STRATEGY
+
+ // create required nodes
+ testSchemaLoadStrategy_Setup(i.gogm, req)
+
+ sess, err := i.gogm.NewSessionV2(SessionConfig{AccessMode: AccessModeRead})
+ req.Nil(err)
+ defer req.Nil(sess.Close())
+
+ ctx := context.Background()
+ req.Nil(sess.Begin(ctx))
+ defer req.Nil(sess.Close())
+
+ // test raw query (verify SchemaLoadStrategy + Neo driver decoding)
+ query, err := SchemaLoadStrategyOne(i.gogm, "n", "a", "uuid", "uuid", false, 1, nil)
+ req.Nil(err, "error generating SchemaLoadStrategy query")
+
+ cypher, err := query.ToCypher()
+ req.Nil(err, "error decoding cypher from generated SchemaLoadStrategy query")
+ raw, _, err := sess.QueryRaw(ctx, cypher, map[string]interface{}{"uuid": "f64953a5-8b40-4a87-a26b-6427e661570c"})
+ req.Nil(err)
+
+ req.Len(raw, 1, "Raw result should have one record")
+ req.Len(raw[0], 2, "Raw record should have two items")
+
+ // inspecting first node
+ node, ok := raw[0][0].(neo4j.Node)
+ req.True(ok)
+ req.ElementsMatch(node.Labels, []string{"a"})
+
+ // inspecting nested query result
+ req.Len(raw[0][1], 5)
+
+ var res a
+ err = sess.LoadDepth(ctx, &res, testUuid, 2)
+ req.Nil(err, "Load should not fail")
+
+ req.Len(res.ManyA, 1, "B node should be loaded properly")
+ req.True(res.SingleSpecA.Test == "testing", "C spec rel should be loaded properly")
+ req.True(res.SingleSpecA.End.TestField == "dasdfasd", "B node should be loaded through spec rel")
+}
+
+func testSchemaLoadStrategy_Setup(gogm *Gogm, req *require.Assertions) {
+ sess, err := gogm.NewSessionV2(SessionConfig{AccessMode: AccessModeWrite})
+ req.Nil(err)
+ defer req.Nil(sess.Close())
+
+ a1 := &a{
+ TestField: "test",
+ PropTest0: map[string]interface{}{
+ "test.test": "test",
+ "test2": 1,
+ },
+ PropTest1: map[string]string{
+ "test": "test",
+ },
+ PropsTest2: []string{"test", "test"},
+ PropsTest3: []int{1, 2},
+ }
+
+ b1 := &b{
+ TestField: "dasdfasd",
+ }
+
+ c1 := &c{
+ Start: a1,
+ End: b1,
+ Test: "testing",
+ }
+
+ a1.SingleSpecA = c1
+ a1.ManyA = []*b{b1}
+ b1.SingleSpec = c1
+ b1.ManyB = a1
+
+ a1.UUID = testUuid
+
+ ctx := context.Background()
+ req.Nil(sess.Begin(ctx))
+
+ req.Nil(sess.SaveDepth(ctx, a1, 3))
+ req.Nil(sess.Commit(ctx))
+}
diff --git a/interface.go b/interface.go
index ff44d01..1b8909a 100644
--- a/interface.go
+++ b/interface.go
@@ -20,8 +20,9 @@
package gogm
import (
- dsl "github.com/mindstand/go-cypherdsl"
"reflect"
+
+ dsl "github.com/mindstand/go-cypherdsl"
)
// Edge specifies required functions for special edge nodes
@@ -56,7 +57,7 @@ type ISession interface {
LoadDepth(respObj interface{}, id string, depth int) error
//load with depth and filter
- LoadDepthFilter(respObj interface{}, id string, depth int, filter *dsl.ConditionBuilder, params map[string]interface{}) error
+ LoadDepthFilter(respObj interface{}, id string, depth int, filter dsl.ConditionOperator, params map[string]interface{}) error
//load with depth, filter and pagination
LoadDepthFilterPagination(respObj interface{}, id string, depth int, filter dsl.ConditionOperator, params map[string]interface{}, pagination *Pagination) error
diff --git a/interfacev2.go b/interfacev2.go
index f972ce9..8f46dfe 100644
--- a/interfacev2.go
+++ b/interfacev2.go
@@ -21,6 +21,7 @@ package gogm
import (
"context"
+
dsl "github.com/mindstand/go-cypherdsl"
"github.com/neo4j/neo4j-go-driver/v4/neo4j"
)
@@ -61,7 +62,7 @@ type ogmFunctions interface {
LoadDepth(ctx context.Context, respObj, id interface{}, depth int) error
//load with depth and filter
- LoadDepthFilter(ctx context.Context, respObj, id interface{}, depth int, filter *dsl.ConditionBuilder, params map[string]interface{}) error
+ LoadDepthFilter(ctx context.Context, respObj, id interface{}, depth int, filter dsl.ConditionOperator, params map[string]interface{}) error
//load with depth, filter and pagination
LoadDepthFilterPagination(ctx context.Context, respObj, id interface{}, depth int, filter dsl.ConditionOperator, params map[string]interface{}, pagination *Pagination) error
diff --git a/load_strategy.go b/load_strategy.go
index 74176e4..6938a40 100644
--- a/load_strategy.go
+++ b/load_strategy.go
@@ -22,6 +22,8 @@ package gogm
import (
"errors"
"fmt"
+ "reflect"
+
dsl "github.com/mindstand/go-cypherdsl"
)
@@ -35,6 +37,15 @@ const (
SCHEMA_LOAD_STRATEGY
)
+func (ls LoadStrategy) validate() error {
+ switch ls {
+ case PATH_LOAD_STRATEGY, SCHEMA_LOAD_STRATEGY:
+ return nil
+ default:
+ return fmt.Errorf("invalid load strategy %d", ls)
+ }
+}
+
// PathLoadStrategyMany loads many using path strategy
func PathLoadStrategyMany(variable, label string, depth int, additionalConstraints dsl.ConditionOperator) (dsl.Cypher, error) {
if variable == "" {
@@ -168,3 +179,203 @@ func PathLoadStrategyEdgeConstraint(startVariable, startLabel, endLabel, endTarg
return builder.Return(false, dsl.ReturnPart{Name: "p"}), nil
}
+
+func getRelationshipsForLabel(gogm *Gogm, label string) ([]decoratorConfig, error) {
+ raw, ok := gogm.mappedTypes.Get(label)
+ if !ok {
+ return nil, fmt.Errorf("struct config not found type (%s)", label)
+ }
+
+ config, ok := raw.(structDecoratorConfig)
+ if !ok {
+ return nil, errors.New("unable to cast into struct decorator config")
+ }
+
+ fields := []decoratorConfig{}
+ for _, field := range config.Fields {
+ if !field.Ignore && field.Relationship != "" {
+ fields = append(fields, field)
+ }
+ }
+
+ return fields, nil
+}
+
+func expandBootstrap(gogm *Gogm, variable, label string, depth int) (string, error) {
+ clause := ""
+ rels, err := getRelationshipsForLabel(gogm, label)
+ if err != nil {
+ return "", err
+ }
+
+ if depth > 0 {
+ if len(rels) > 0 {
+ clause += ", ["
+ }
+
+ expanded, err := expand(gogm, variable, label, rels, 1, depth-1)
+ if err != nil {
+ return "", err
+ }
+ clause += expanded
+
+ if len(rels) > 0 {
+ clause += "]"
+ }
+ }
+
+ return clause, nil
+}
+
+func expand(gogm *Gogm, variable, label string, rels []decoratorConfig, level, depth int) (string, error) {
+ clause := ""
+
+ for i, rel := range rels {
+ // check if a seperator is needed
+ if i > 0 {
+ clause += ", "
+ }
+
+ ret, err := listComprehension(gogm, variable, label, rel, level, depth)
+ if err != nil {
+ return "", err
+ }
+ clause += ret
+ }
+
+ return clause, nil
+}
+
+func relString(variable string, rel decoratorConfig) string {
+ start := "-"
+ end := "-"
+
+ if rel.Direction == dsl.DirectionIncoming {
+ start = "<-"
+ } else if rel.Direction == dsl.DirectionOutgoing {
+ end = "->"
+ }
+
+ return fmt.Sprintf("%s[%s:%s]%s", start, variable, rel.Relationship, end)
+}
+
+func listComprehension(gogm *Gogm, fromNodeVar, label string, rel decoratorConfig, level, depth int) (string, error) {
+ relVar := fmt.Sprintf("r_%c_%d", rel.Relationship[0], level)
+
+ toNodeType := rel.Type.Elem()
+ if rel.Type.Kind() == reflect.Slice {
+ toNodeType = toNodeType.Elem()
+ }
+
+ toNodeLabel, err := traverseRelType(gogm, toNodeType, rel.Direction)
+ if err != nil {
+ return "", err
+ }
+
+ toNodeVar := fmt.Sprintf("n_%c_%d", toNodeLabel[0], level)
+
+ clause := fmt.Sprintf("[(%s)%s(%s:%s) | [%s, %s", fromNodeVar, relString(relVar, rel), toNodeVar, toNodeLabel, relVar, toNodeVar)
+
+ if depth > 0 {
+ toNodeRels, err := getRelationshipsForLabel(gogm, label)
+ if err != nil {
+ return "", err
+ }
+
+ if len(toNodeRels) > 0 {
+ toNodeExpansion, err := expand(gogm, toNodeVar, toNodeLabel, toNodeRels, level+1, depth-1)
+ if err != nil {
+ return "", err
+ }
+ clause += fmt.Sprintf(", [%s]", toNodeExpansion)
+ }
+ }
+
+ clause += "]]"
+ return clause, nil
+}
+
+// SchemaLoadStrategyMany loads many using schema strategy
+func SchemaLoadStrategyMany(gogm *Gogm, variable, label string, depth int, additionalConstraints dsl.ConditionOperator) (dsl.Cypher, error) {
+ if variable == "" {
+ return nil, errors.New("variable name cannot be empty")
+ }
+
+ if label == "" {
+ return nil, errors.New("label can not be empty")
+ }
+
+ if depth < 0 {
+ return nil, errors.New("depth can not be less than 0")
+ }
+
+ builder := dsl.QB().Cypher(fmt.Sprintf("MATCH (%s:%s)", variable, label))
+
+ if additionalConstraints != nil {
+ builder = builder.Where(additionalConstraints)
+ }
+
+ builder = builder.Cypher("RETURN " + variable)
+
+ if depth > 0 {
+ clause, err := expandBootstrap(gogm, variable, label, depth)
+ if err != nil {
+ return nil, err
+ }
+ builder = builder.Cypher(clause)
+ }
+
+ return builder, nil
+}
+
+// SchemaLoadStrategyOne loads one object using schema strategy
+func SchemaLoadStrategyOne(gogm *Gogm, variable, label, fieldOn, paramName string, isGraphId bool, depth int, additionalConstraints dsl.ConditionOperator) (dsl.Cypher, error) {
+ if variable == "" {
+ return nil, errors.New("variable name cannot be empty")
+ }
+
+ if label == "" {
+ return nil, errors.New("label can not be empty")
+ }
+
+ if depth < 0 {
+ return nil, errors.New("depth can not be less than 0")
+ }
+
+ builder := dsl.QB().Cypher(fmt.Sprintf("MATCH (%s:%s)", variable, label))
+
+ var condition *dsl.ConditionConfig
+ if isGraphId {
+ condition = &dsl.ConditionConfig{
+ FieldManipulationFunction: "ID",
+ Name: variable,
+ ConditionOperator: dsl.EqualToOperator,
+ Check: dsl.ParamString("$" + paramName),
+ }
+ } else {
+ condition = &dsl.ConditionConfig{
+ Name: variable,
+ Field: fieldOn,
+ ConditionOperator: dsl.EqualToOperator,
+ Check: dsl.ParamString("$" + paramName),
+ }
+ }
+
+ if additionalConstraints != nil {
+ builder = builder.Where(additionalConstraints.And(condition))
+ } else {
+ builder = builder.Where(dsl.C(condition))
+ }
+
+ builder = builder.Cypher("RETURN " + variable)
+
+ if depth > 0 {
+ clause, err := expandBootstrap(gogm, variable, label, depth)
+ if err != nil {
+ return nil, err
+ }
+ builder = builder.Cypher(clause)
+ }
+
+ return builder, nil
+}
diff --git a/load_strategy_test.go b/load_strategy_test.go
index 01beaee..07b0665 100644
--- a/load_strategy_test.go
+++ b/load_strategy_test.go
@@ -18,3 +18,68 @@
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package gogm
+
+import (
+ "testing"
+
+ dsl "github.com/mindstand/go-cypherdsl"
+ "github.com/stretchr/testify/require"
+)
+
+func TestSchemaLoadStrategyMany(t *testing.T) {
+ req := require.New(t)
+
+ // reusing structs from decode_test
+ gogm, err := getTestGogm()
+ req.Nil(err)
+ req.NotNil(gogm)
+
+ // test base case with no schema expansion
+ cypher, err := SchemaLoadStrategyMany(gogm, "n", "a", 0, nil)
+ req.Nil(err)
+ cypherStr, err := cypher.ToCypher()
+ req.Nil(err)
+ req.Equal(cypherStr, "MATCH (n:a) RETURN n")
+
+ // test base case with no schema expansion
+ cypher, err = SchemaLoadStrategyMany(gogm, "n", "a", 0, dsl.C(&dsl.ConditionConfig{
+ Name: "n",
+ ConditionOperator: dsl.EqualToOperator,
+ Field: "test_field",
+ Check: dsl.ParamString("$someParam"),
+ }))
+ req.Nil(err)
+ cypherStr, err = cypher.ToCypher()
+ req.Nil(err)
+ req.Equal(cypherStr, "MATCH (n:a) WHERE n.test_field = $someParam RETURN n")
+
+ // test more complex case with schema expansion
+ cypher, err = SchemaLoadStrategyMany(gogm, "n", "a", 2, nil)
+ req.Nil(err)
+ req.NotNil(cypher)
+ cypherStr, err = cypher.ToCypher()
+ req.Nil(err)
+ req.NotContains(cypherStr, ":c)", "Spec edge should not be treated as a node")
+ req.Regexp("\\[[^\\(\\)\\[\\]]+:special[^\\(\\)\\[\\]]+]..\\([^\\(\\)\\[\\]]+:b\\)", cypherStr, "Spec edge rels should properly link to b")
+
+ // test fail condition of non-existing label
+ cypher, err = SchemaLoadStrategyMany(gogm, "n", "nonexisting", 2, nil)
+ req.NotNil(err, "Should fail due to non-existing label")
+ req.Nil(cypher)
+}
+
+func TestSchemaLoadStrategyOne(t *testing.T) {
+ req := require.New(t)
+
+ // reusing structs from decode_test
+ gogm, err := getTestGogm()
+ req.Nil(err)
+ req.NotNil(gogm)
+
+ // test base case with no schema expansion
+ cypher, err := SchemaLoadStrategyOne(gogm, "n", "a", "uuid", "uuid", false, 0, nil)
+ req.Nil(err)
+ cypherStr, err := cypher.ToCypher()
+ req.Nil(err)
+ req.Equal(cypherStr, "MATCH (n:a) WHERE n.uuid = $uuid RETURN n")
+}
diff --git a/mocks/ISession.go b/mocks/ISession.go
index 20eea85..55c8849 100644
--- a/mocks/ISession.go
+++ b/mocks/ISession.go
@@ -1,4 +1,4 @@
-// Code generated by mockery v0.0.0-dev. DO NOT EDIT.
+// Code generated by mockery v2.9.4. DO NOT EDIT.
package mocks
@@ -183,11 +183,11 @@ func (_m *ISession) LoadDepth(respObj interface{}, id string, depth int) error {
}
// LoadDepthFilter provides a mock function with given fields: respObj, id, depth, filter, params
-func (_m *ISession) LoadDepthFilter(respObj interface{}, id string, depth int, filter *go_cypherdsl.ConditionBuilder, params map[string]interface{}) error {
+func (_m *ISession) LoadDepthFilter(respObj interface{}, id string, depth int, filter *go_cypherdsl.ConditionOperator, params map[string]interface{}) error {
ret := _m.Called(respObj, id, depth, filter, params)
var r0 error
- if rf, ok := ret.Get(0).(func(interface{}, string, int, *go_cypherdsl.ConditionBuilder, map[string]interface{}) error); ok {
+ if rf, ok := ret.Get(0).(func(interface{}, string, int, *go_cypherdsl.ConditionOperator, map[string]interface{}) error); ok {
r0 = rf(respObj, id, depth, filter, params)
} else {
r0 = ret.Error(0)
@@ -316,3 +316,4 @@ func (_m *ISession) SaveDepth(saveObj interface{}, depth int) error {
return r0
}
+
diff --git a/mocks/SessionV2.go b/mocks/SessionV2.go
index c2304a4..cc878a3 100644
--- a/mocks/SessionV2.go
+++ b/mocks/SessionV2.go
@@ -1,4 +1,4 @@
-// Code generated by mockery v0.0.0-dev. DO NOT EDIT.
+// Code generated by mockery v2.9.4. DO NOT EDIT.
package mocks
@@ -173,11 +173,11 @@ func (_m *SessionV2) LoadDepth(ctx context.Context, respObj interface{}, id inte
}
// LoadDepthFilter provides a mock function with given fields: ctx, respObj, id, depth, filter, params
-func (_m *SessionV2) LoadDepthFilter(ctx context.Context, respObj interface{}, id interface{}, depth int, filter *go_cypherdsl.ConditionBuilder, params map[string]interface{}) error {
+func (_m *SessionV2) LoadDepthFilter(ctx context.Context, respObj interface{}, id interface{}, depth int, filter *go_cypherdsl.ConditionOperator, params map[string]interface{}) error {
ret := _m.Called(ctx, respObj, id, depth, filter, params)
var r0 error
- if rf, ok := ret.Get(0).(func(context.Context, interface{}, interface{}, int, *go_cypherdsl.ConditionBuilder, map[string]interface{}) error); ok {
+ if rf, ok := ret.Get(0).(func(context.Context, interface{}, interface{}, int, *go_cypherdsl.ConditionOperator, map[string]interface{}) error); ok {
r0 = rf(ctx, respObj, id, depth, filter, params)
} else {
r0 = ret.Error(0)
diff --git a/mocks/TransactionV2.go b/mocks/TransactionV2.go
index 68c1381..d7c618c 100644
--- a/mocks/TransactionV2.go
+++ b/mocks/TransactionV2.go
@@ -145,11 +145,11 @@ func (_m *TransactionV2) LoadDepth(ctx context.Context, respObj interface{}, id
}
// LoadDepthFilter provides a mock function with given fields: ctx, respObj, id, depth, filter, params
-func (_m *TransactionV2) LoadDepthFilter(ctx context.Context, respObj interface{}, id interface{}, depth int, filter *go_cypherdsl.ConditionBuilder, params map[string]interface{}) error {
+func (_m *TransactionV2) LoadDepthFilter(ctx context.Context, respObj interface{}, id interface{}, depth int, filter *go_cypherdsl.ConditionOperator, params map[string]interface{}) error {
ret := _m.Called(ctx, respObj, id, depth, filter, params)
var r0 error
- if rf, ok := ret.Get(0).(func(context.Context, interface{}, interface{}, int, *go_cypherdsl.ConditionBuilder, map[string]interface{}) error); ok {
+ if rf, ok := ret.Get(0).(func(context.Context, interface{}, interface{}, int, *go_cypherdsl.ConditionOperator, map[string]interface{}) error); ok {
r0 = rf(ctx, respObj, id, depth, filter, params)
} else {
r0 = ret.Error(0)
diff --git a/save.go b/save.go
index 9c8e43a..9463b06 100644
--- a/save.go
+++ b/save.go
@@ -22,10 +22,11 @@ package gogm
import (
"errors"
"fmt"
- dsl "github.com/mindstand/go-cypherdsl"
- "github.com/neo4j/neo4j-go-driver/v4/neo4j"
"reflect"
"strconv"
+
+ dsl "github.com/mindstand/go-cypherdsl"
+ "github.com/neo4j/neo4j-go-driver/v4/neo4j"
)
// nodeCreate holds configuration for creating new nodes
@@ -155,7 +156,7 @@ func saveDepth(gogm *Gogm, obj interface{}, depth int) neo4j.TransactionWork {
// relateNodes connects nodes together using edge config
func relateNodes(transaction neo4j.Transaction, relations map[string][]*relCreate, lookup map[uintptr]int64) error {
- if relations == nil || len(relations) == 0 {
+ if len(relations) == 0 {
return errors.New("relations can not be nil or empty")
}
@@ -240,6 +241,9 @@ func relateNodes(transaction neo4j.Transaction, relations map[string][]*relCreat
}).
Cypher("SET rel += row.props").
ToCypher()
+ if err != nil {
+ return fmt.Errorf("failed to build query, %w", err)
+ }
res, err := transaction.Run(cyp, map[string]interface{}{
"rows": params,
@@ -256,7 +260,7 @@ func relateNodes(transaction neo4j.Transaction, relations map[string][]*relCreat
// removes relationships between specified nodes
func removeRelations(transaction neo4j.Transaction, dels map[int64][]int64) error {
- if dels == nil || len(dels) == 0 {
+ if len(dels) == 0 {
return nil
}
@@ -533,6 +537,9 @@ func createNodes(transaction neo4j.Transaction, crNodes map[string]map[uintptr]*
Alias: "id",
}).
ToCypher()
+ if err != nil {
+ return fmt.Errorf("failed to build query, %w", err)
+ }
res, err := transaction.Run(cyp, map[string]interface{}{
"rows": newRows,
@@ -596,6 +603,9 @@ func createNodes(transaction neo4j.Transaction, crNodes map[string]map[uintptr]*
Cypher("WHERE ID(n) = row.id").
Cypher("SET n += row.obj").
ToCypher()
+ if err != nil {
+ return fmt.Errorf("failed to build query, %w", err)
+ }
res, err := transaction.Run(cyp, map[string]interface{}{
"rows": updateRows,
diff --git a/save_test.go b/save_test.go
index 12c2720..a484562 100644
--- a/save_test.go
+++ b/save_test.go
@@ -20,11 +20,12 @@
package gogm
import (
- dsl "github.com/mindstand/go-cypherdsl"
- "github.com/stretchr/testify/require"
"reflect"
"testing"
"time"
+
+ dsl "github.com/mindstand/go-cypherdsl"
+ "github.com/stretchr/testify/require"
)
func TestParseStruct(t *testing.T) {
@@ -553,6 +554,7 @@ func TestCalculateDels(t *testing.T) {
uintptr(1): 1,
uintptr(2): 2,
})
+ req.Nil(err)
req.EqualValues(map[int64][]int64{
1: {2},
diff --git a/session.go b/session.go
index 5b1f710..5209739 100644
--- a/session.go
+++ b/session.go
@@ -41,7 +41,6 @@ type Session struct {
neoSess neo4j.Session
tx neo4j.Transaction
DefaultDepth int
- LoadStrategy LoadStrategy
mode neo4j.AccessMode
}
@@ -65,8 +64,7 @@ func newSession(gogm *Gogm, readonly bool) (*Session, error) {
}
session := &Session{
- gogm: gogm,
- LoadStrategy: PATH_LOAD_STRATEGY,
+ gogm: gogm,
}
var mode neo4j.AccessMode
@@ -116,7 +114,6 @@ func newSessionWithConfig(gogm *Gogm, conf SessionConfig) (*Session, error) {
DefaultDepth: defaultDepth,
mode: conf.AccessMode,
gogm: gogm,
- LoadStrategy: PATH_LOAD_STRATEGY,
}, nil
}
@@ -192,7 +189,7 @@ func (s *Session) LoadDepth(respObj interface{}, id string, depth int) error {
return s.LoadDepthFilterPagination(respObj, id, depth, nil, nil, nil)
}
-func (s *Session) LoadDepthFilter(respObj interface{}, id string, depth int, filter *dsl.ConditionBuilder, params map[string]interface{}) error {
+func (s *Session) LoadDepthFilter(respObj interface{}, id string, depth int, filter dsl.ConditionOperator, params map[string]interface{}) error {
return s.LoadDepthFilterPagination(respObj, id, depth, filter, params, nil)
}
@@ -217,14 +214,17 @@ func (s *Session) LoadDepthFilterPagination(respObj interface{}, id string, dept
var err error
//make the query based off of the load strategy
- switch s.LoadStrategy {
+ switch s.gogm.config.LoadStrategy {
case PATH_LOAD_STRATEGY:
query, err = PathLoadStrategyOne(varName, respObjName, "uuid", "uuid", false, depth, filter)
if err != nil {
return err
}
case SCHEMA_LOAD_STRATEGY:
- return errors.New("schema load strategy not supported yet")
+ query, err = SchemaLoadStrategyOne(s.gogm, varName, respObjName, "uuid", "uuid", false, depth, filter)
+ if err != nil {
+ return err
+ }
default:
return errors.New("unknown load strategy")
}
@@ -308,14 +308,17 @@ func (s *Session) LoadAllDepthFilterPagination(respObj interface{}, depth int, f
var err error
//make the query based off of the load strategy
- switch s.LoadStrategy {
+ switch s.gogm.config.LoadStrategy {
case PATH_LOAD_STRATEGY:
query, err = PathLoadStrategyMany(varName, respObjName, depth, filter)
if err != nil {
return err
}
case SCHEMA_LOAD_STRATEGY:
- return errors.New("schema load strategy not supported yet")
+ query, err = SchemaLoadStrategyMany(s.gogm, varName, respObjName, depth, filter)
+ if err != nil {
+ return err
+ }
default:
return errors.New("unknown load strategy")
}
@@ -378,17 +381,10 @@ func (s *Session) LoadAllEdgeConstraint(respObj interface{}, endNodeType, endNod
var query dsl.Cypher
var err error
- //make the query based off of the load strategy
- switch s.LoadStrategy {
- case PATH_LOAD_STRATEGY:
- query, err = PathLoadStrategyEdgeConstraint(varName, respObjName, endNodeType, endNodeField, minJumps, maxJumps, depth, filter)
- if err != nil {
- return err
- }
- case SCHEMA_LOAD_STRATEGY:
- return errors.New("schema load strategy not supported yet")
- default:
- return errors.New("unknown load strategy")
+ // there is no Schema Load Strategy implementation of EdgeConstraint as it would involve pathfinding within the schema (which would be expensive)
+ query, err = PathLoadStrategyEdgeConstraint(varName, respObjName, endNodeType, endNodeField, minJumps, maxJumps, depth, filter)
+ if err != nil {
+ return err
}
// handle if in transaction
@@ -565,16 +561,12 @@ func (s *Session) parseResult(res neo4j.Result) [][]interface{} {
switch v := val.(type) {
case neo4j.Path:
vals[i] = v
- break
case neo4j.Relationship:
vals[i] = v
- break
case neo4j.Node:
vals[i] = v
- break
default:
vals[i] = v
- continue
}
}
result = append(result, vals)
diff --git a/sessionv2.go b/sessionv2.go
index 7ff6247..f2c5a78 100644
--- a/sessionv2.go
+++ b/sessionv2.go
@@ -23,11 +23,12 @@ import (
"context"
"errors"
"fmt"
- "github.com/opentracing/opentracing-go"
"reflect"
"strings"
"time"
+ "github.com/opentracing/opentracing-go"
+
dsl "github.com/mindstand/go-cypherdsl"
"github.com/neo4j/neo4j-go-driver/v4/neo4j"
)
@@ -37,7 +38,6 @@ type SessionV2Impl struct {
neoSess neo4j.Session
tx neo4j.Transaction
DefaultDepth int
- LoadStrategy LoadStrategy
conf SessionConfig
lastBookmark string
}
@@ -67,7 +67,6 @@ func newSessionWithConfigV2(gogm *Gogm, conf SessionConfig) (*SessionV2Impl, err
DefaultDepth: defaultDepth,
conf: conf,
gogm: gogm,
- LoadStrategy: PATH_LOAD_STRATEGY,
}, nil
}
func (s *SessionV2Impl) Begin(ctx context.Context) error {
@@ -199,7 +198,7 @@ func (s *SessionV2Impl) LoadDepth(ctx context.Context, respObj, id interface{},
return s.LoadDepthFilterPagination(ctx, respObj, id, depth, nil, nil, nil)
}
-func (s *SessionV2Impl) LoadDepthFilter(ctx context.Context, respObj, id interface{}, depth int, filter *dsl.ConditionBuilder, params map[string]interface{}) error {
+func (s *SessionV2Impl) LoadDepthFilter(ctx context.Context, respObj, id interface{}, depth int, filter dsl.ConditionOperator, params map[string]interface{}) error {
var span opentracing.Span
if ctx != nil && s.gogm.config.OpentracingEnabled {
span, ctx = opentracing.StartSpanFromContext(ctx, "gogm.SessionV2Impl.LoadDepthFilter")
@@ -243,14 +242,17 @@ func (s *SessionV2Impl) LoadDepthFilterPagination(ctx context.Context, respObj,
isGraphId := s.gogm.pkStrategy.StrategyName == DefaultPrimaryKeyStrategy.StrategyName
field := s.gogm.pkStrategy.DBName
//make the query based off of the load strategy
- switch s.LoadStrategy {
+ switch s.gogm.config.LoadStrategy {
case PATH_LOAD_STRATEGY:
query, err = PathLoadStrategyOne(varName, respObjName, field, paramName, isGraphId, depth, filter)
if err != nil {
return err
}
case SCHEMA_LOAD_STRATEGY:
- return errors.New("schema load strategy not supported yet")
+ query, err = SchemaLoadStrategyOne(s.gogm, varName, respObjName, field, paramName, isGraphId, depth, filter)
+ if err != nil {
+ return err
+ }
default:
return errors.New("unknown load strategy")
}
@@ -362,14 +364,17 @@ func (s *SessionV2Impl) LoadAllDepthFilterPagination(ctx context.Context, respOb
var err error
//make the query based off of the load strategy
- switch s.LoadStrategy {
+ switch s.gogm.config.LoadStrategy {
case PATH_LOAD_STRATEGY:
query, err = PathLoadStrategyMany(varName, respObjName, depth, filter)
if err != nil {
return err
}
case SCHEMA_LOAD_STRATEGY:
- return errors.New("schema load strategy not supported yet")
+ query, err = SchemaLoadStrategyMany(s.gogm, varName, respObjName, depth, filter)
+ if err != nil {
+ return err
+ }
default:
return errors.New("unknown load strategy")
}
@@ -436,7 +441,7 @@ func (s *SessionV2Impl) runReadOnly(ctx context.Context, cyp string, params map[
}
return nil, decode(s.gogm, res, respObj)
- }, neo4j.WithTxTimeout(s.getDeadline(ctx).Sub(time.Now())))
+ }, neo4j.WithTxTimeout(time.Until(s.getDeadline(ctx))))
if err != nil {
return fmt.Errorf("failed auto read tx, %w", err)
}
@@ -535,7 +540,7 @@ func (s *SessionV2Impl) runWrite(ctx context.Context, work neo4j.TransactionWork
}
s.gogm.logger.Debug("running in managed write transaction")
- _, err := s.neoSess.WriteTransaction(work, neo4j.WithTxTimeout(s.getDeadline(ctx).Sub(time.Now())))
+ _, err := s.neoSess.WriteTransaction(work, neo4j.WithTxTimeout(time.Until(s.getDeadline(ctx))))
if err != nil {
return fmt.Errorf("failed to save in auto transaction, %w", err)
}
@@ -653,16 +658,12 @@ func (s *SessionV2Impl) parseResult(res neo4j.Result) [][]interface{} {
switch v := val.(type) {
case neo4j.Path:
vals[i] = v
- break
case neo4j.Relationship:
vals[i] = v
- break
case neo4j.Node:
vals[i] = v
- break
default:
vals[i] = v
- continue
}
}
result = append(result, vals)
@@ -742,7 +743,7 @@ func (s *SessionV2Impl) ManagedTransaction(ctx context.Context, work Transaction
deadline := s.getDeadline(ctx)
if s.conf.AccessMode == AccessModeWrite {
- _, err := s.neoSess.WriteTransaction(txWork, neo4j.WithTxTimeout(deadline.Sub(time.Now())))
+ _, err := s.neoSess.WriteTransaction(txWork, neo4j.WithTxTimeout(time.Until(deadline)))
if err != nil {
return fmt.Errorf("failed managed write tx, %w", err)
}
diff --git a/util.go b/util.go
index 728a68b..bb85266 100644
--- a/util.go
+++ b/util.go
@@ -22,10 +22,11 @@ package gogm
import (
"errors"
"fmt"
- go_cypherdsl "github.com/mindstand/go-cypherdsl"
"reflect"
"strings"
"sync"
+
+ dsl "github.com/mindstand/go-cypherdsl"
)
// checks if integer is in slice
@@ -229,12 +230,12 @@ func (r *relationConfigs) GetConfigs(startNodeType, startNodeFieldType, endNodeT
return nil, nil, errors.New("no configs provided")
}
- start, err = r.getConfig(startNodeType, relationship, startNodeFieldType, go_cypherdsl.DirectionOutgoing)
+ start, err = r.getConfig(startNodeType, relationship, startNodeFieldType, dsl.DirectionOutgoing)
if err != nil {
return nil, nil, err
}
- end, err = r.getConfig(endNodeType, relationship, endNodeFieldType, go_cypherdsl.DirectionIncoming)
+ end, err = r.getConfig(endNodeType, relationship, endNodeFieldType, dsl.DirectionIncoming)
if err != nil {
return nil, nil, err
}
@@ -242,7 +243,7 @@ func (r *relationConfigs) GetConfigs(startNodeType, startNodeFieldType, endNodeT
return start, end, nil
}
-func (r *relationConfigs) getConfig(nodeType, relationship, fieldType string, direction go_cypherdsl.Direction) (*decoratorConfig, error) {
+func (r *relationConfigs) getConfig(nodeType, relationship, fieldType string, direction dsl.Direction) (*decoratorConfig, error) {
if r.configs == nil {
return nil, errors.New("no configs provided")
}
@@ -310,18 +311,14 @@ func (r *relationConfigs) Validate() error {
validate := checkMap[relType]
switch config.Direction {
- case go_cypherdsl.DirectionIncoming:
+ case dsl.DirectionIncoming:
validate.Incoming = append(validate.Incoming, field)
- break
- case go_cypherdsl.DirectionOutgoing:
+ case dsl.DirectionOutgoing:
validate.Outgoing = append(validate.Outgoing, field)
- break
- case go_cypherdsl.DirectionNone:
+ case dsl.DirectionNone:
validate.None = append(validate.None, field)
- break
- case go_cypherdsl.DirectionBoth:
+ case dsl.DirectionBoth:
validate.Both = append(validate.Both, field)
- break
default:
return fmt.Errorf("unrecognized direction [%s], %w", config.Direction.ToString(), ErrValidation)
}
@@ -345,7 +342,7 @@ func (r *relationConfigs) Validate() error {
//check none direction
if len(validateConfig.None) != 0 {
if len(validateConfig.None)%2 != 0 {
- return fmt.Errorf("invalid length for 'both' validation, %w", ErrValidation)
+ return fmt.Errorf("invalid length for 'none' validation, %w", ErrValidation)
}
}
}
@@ -418,3 +415,40 @@ func getPrimitiveType(k reflect.Kind) (reflect.Type, error) {
func int64Ptr(n int64) *int64 {
return &n
}
+
+// traverseRelType finds the label of a node from a relationship (decoratorConfig).
+// if a special edge is passed in, the linked node's label is returned.
+func traverseRelType(gogm *Gogm, endType reflect.Type, direction dsl.Direction) (string, error) {
+ if !reflect.PtrTo(endType).Implements(edgeType) {
+ return endType.Name(), nil
+ }
+
+ gogm.logger.Debug(endType.Name())
+ endVal := reflect.New(endType)
+ var endTypeVal []reflect.Value
+
+ if direction == dsl.DirectionOutgoing {
+ endTypeVal = endVal.MethodByName("GetEndNodeType").Call(nil)
+ } else {
+ endTypeVal = endVal.MethodByName("GetStartNodeType").Call(nil)
+ }
+
+ if len(endTypeVal) != 1 {
+ return "", errors.New("GetEndNodeType failed")
+ }
+
+ if endTypeVal[0].IsNil() {
+ return "", errors.New("GetEndNodeType() can not return a nil value")
+ }
+
+ convertedType, ok := endTypeVal[0].Interface().(reflect.Type)
+ if !ok {
+ return "", errors.New("cannot convert to type reflect.Type")
+ }
+
+ if convertedType.Kind() == reflect.Ptr {
+ return convertedType.Elem().Name(), nil
+ } else {
+ return convertedType.Name(), nil
+ }
+}