From 7926aeaad64bfb5fa736df0aa7ef1f9fbc2f4f53 Mon Sep 17 00:00:00 2001 From: jxsl13 Date: Sat, 29 Jul 2023 20:54:58 +0200 Subject: [PATCH] refactor traversal --- go.mod | 3 +- go.sum | 8 +- names/deduplicate.go | 30 +++++++ names/deduplicate_test.go | 27 ++++++ testdata/004_callbacks.yaml | 164 ++++++++++++++++++++++++++++++++++++ traverse/callback.go | 4 +- traverse/components.go | 14 +-- traverse/header.go | 9 +- traverse/media_type.go | 4 +- traverse/operation.go | 12 +-- traverse/operation_test.go | 40 +++++++++ traverse/parameter.go | 11 ++- traverse/parameter_test.go | 60 +++++++++++++ traverse/path_item.go | 7 +- traverse/request_body.go | 10 +-- traverse/responses.go | 28 +++--- traverse/traverse.go | 63 ++++++++++++-- traverse/traverse_test.go | 62 +++++++++----- 18 files changed, 472 insertions(+), 84 deletions(-) create mode 100644 names/deduplicate.go create mode 100644 names/deduplicate_test.go create mode 100644 testdata/004_callbacks.yaml create mode 100644 traverse/operation_test.go create mode 100644 traverse/parameter_test.go diff --git a/go.mod b/go.mod index 31731b4..ca71153 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,6 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fatih/structs v1.1.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/go-openapi/jsonpointer v0.20.0 // indirect github.com/go-openapi/swag v0.22.4 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -36,7 +35,7 @@ require ( github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/net v0.8.0 // indirect + github.com/rogpeppe/go-internal v1.11.0 // indirect golang.org/x/sys v0.10.0 // indirect golang.org/x/text v0.11.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 9521cc2..9052310 100644 --- a/go.sum +++ b/go.sum @@ -4,11 +4,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= -github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= -github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/getkin/kin-openapi v0.118.0 h1:z43njxPmJ7TaPpMSCQb7PN0dEYno4tyBPQcrFdHoLuM= github.com/getkin/kin-openapi v0.118.0/go.mod h1:l5e9PaFUo9fyLJCPGQeXI2ML8c3P8BHOEV2VaAVf/pc= -github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonpointer v0.20.0 h1:ESKJdU9ASRfaPNOPRx12IUyA1vn3R9GiE3KYD14BXdQ= github.com/go-openapi/jsonpointer v0.20.0/go.mod h1:6PGzBjjIIumbLYysB73Klnms1mwnU4G3YHOECG3CedA= @@ -43,7 +40,6 @@ github.com/knadh/koanf/providers/structs v0.1.0 h1:wJRteCNn1qvLtE5h8KQBvLJovidSd github.com/knadh/koanf/providers/structs v0.1.0/go.mod h1:sw2YZ3txUcqA3Z27gPlmmBzWn1h8Nt9O6EP/91MkcWE= github.com/knadh/koanf/v2 v2.0.1 h1:1dYGITt1I23x8cfx8ZnldtezdyaZtfAuRtIFOiRzK7g= github.com/knadh/koanf/v2 v2.0.1/go.mod h1:ZeiIlIDXTE7w1lMT6UVcNiRAS2/rCeLn/GdLNvY1Dus= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -71,6 +67,8 @@ github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= @@ -89,8 +87,6 @@ github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= diff --git a/names/deduplicate.go b/names/deduplicate.go new file mode 100644 index 0000000..b329829 --- /dev/null +++ b/names/deduplicate.go @@ -0,0 +1,30 @@ +package names + +import "strings" + +func Deduplicate(names []string) []string { + dups := make(map[int]bool, len(names)-1) + for i := range names { + for j := range names { + if i == j { + continue + } + ni := strings.ToLower(names[i]) + nj := strings.ToLower(names[j]) + if strings.Contains(ni, nj) { + dups[j] = true + } else if strings.Contains(nj, ni) { + dups[i] = true + } + } + } + + result := make([]string, 0, len(names)-len(dups)) + for idx, name := range names { + if !dups[idx] { + result = append(result, name) + } + } + + return result +} diff --git a/names/deduplicate_test.go b/names/deduplicate_test.go new file mode 100644 index 0000000..99d6eeb --- /dev/null +++ b/names/deduplicate_test.go @@ -0,0 +1,27 @@ +package names_test + +import ( + "testing" + + "github.com/jxsl13/openapi-typegen/names" + "github.com/stretchr/testify/assert" +) + +func TestDeduplicate(t *testing.T) { + in := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"} + expected := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"} + out := names.Deduplicate(in) + + assert.Equal(t, expected, out) + + in = []string{"a", "b", "c", "d", "ef", "f", "g", "h", "i", "j"} + expected = []string{"a", "b", "c", "d", "ef", "g", "h", "i", "j"} + out = names.Deduplicate(in) + assert.Equal(t, expected, out) + + in = []string{"a", "bd", "c", "d", "ef", "f", "g", "gh", "i", "j"} + expected = []string{"a", "bd", "c", "ef", "gh", "i", "j"} + out = names.Deduplicate(in) + assert.Equal(t, expected, out) + +} diff --git a/testdata/004_callbacks.yaml b/testdata/004_callbacks.yaml new file mode 100644 index 0000000..1521903 --- /dev/null +++ b/testdata/004_callbacks.yaml @@ -0,0 +1,164 @@ +openapi: 3.0.1 + +info: + title: OpenAPI-CodeGen Test + description: 'This is a test OpenAPI Spec' + version: 1.0.0 + +servers: +- url: https://test.oapi-codegen.com/v2 +- url: http://test.oapi-codegen.com/v2 + +paths: + /test: + get: + operationId: doesNothing + summary: does nothing + tags: [nothing] + responses: + default: + description: returns nothing + content: + application/json: + schema: + type: object +components: + schemas: + Object1: + type: object + properties: + object: + $ref: "#/components/schemas/Object2" + Object2: + type: object + properties: + object: + $ref: "#/components/schemas/Object3" + Object3: + type: object + properties: + object: + $ref: "#/components/schemas/Object4" + Object4: + type: object + properties: + object: + $ref: "#/components/schemas/Object5" + Object5: + type: object + properties: + object: + $ref: "#/components/schemas/Object6" + Object6: + type: object + Pet: + type: object + required: + - id + - name + properties: + id: + type: integer + format: int64 + name: + type: string + tag: + type: string + Error: + required: + - code + - message + properties: + code: + type: integer + format: int32 + description: Error code + message: + type: string + description: Error message + parameters: + offsetParam: + name: offset + in: query + description: Number of items to skip before returning the results. + required: false + schema: + type: integer + format: int32 + minimum: 0 + default: 0 + securitySchemes: + BasicAuth: + type: http + scheme: basic + BearerAuth: + type: http + scheme: bearer + requestBodies: + PetBody: + description: A JSON object containing pet information + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Pet' + responses: + NotFound: + description: The specified resource was not found + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + Unauthorized: + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + headers: + X-RateLimit-Limit: + schema: + type: integer + description: Request limit per hour. + X-RateLimit-Remaining: + schema: + type: integer + description: The number of requests left for the time window. + X-RateLimit-Reset: + schema: + type: string + format: date-time + description: The UTC date/time at which the current rate limit window resets + examples: + objectExample: + value: + id: 1 + name: new object + summary: A sample object + links: + GetUserByUserId: + description: > + The id value returned in the response can be used as + the userId parameter in GET /users/{userId}. + operationId: getUser + parameters: + userId: '$response.body#/id' + callbacks: + MyCallback: + '{$request.body#/callbackUrl}': + post: + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + message: + type: string + example: Some event happened + required: + - message + responses: + '200': + description: Your server returns this code if it accepts the callback \ No newline at end of file diff --git a/traverse/callback.go b/traverse/callback.go index 0979839..332c887 100644 --- a/traverse/callback.go +++ b/traverse/callback.go @@ -3,7 +3,7 @@ package traverse import "github.com/getkin/kin-openapi/openapi3" // Callback traverses all unique non-reference schemas in the given callback. -func Callback(callback *openapi3.CallbackRef, visitor SchemaVisitor, levelNames ...string) error { +func Callback(callback *openapi3.CallbackRef, visitor SchemaVisitor, levelNames map[string][]string) error { if callback == nil { return nil } @@ -23,7 +23,7 @@ func Callback(callback *openapi3.CallbackRef, visitor SchemaVisitor, levelNames continue } - err = PathItem(pathItem, visitor, append(levelNames, callbackName)...) + err = PathItem(pathItem, visitor, add(levelNames, NameKey, callbackName)) if err != nil { return err } diff --git a/traverse/components.go b/traverse/components.go index fbebe8a..d673bb3 100644 --- a/traverse/components.go +++ b/traverse/components.go @@ -3,7 +3,7 @@ package traverse import "github.com/getkin/kin-openapi/openapi3" // Components traverses the given components and all unique non-reference schemas in it. -func Components(components *openapi3.Components, visitor SchemaVisitor, levelNames ...string) error { +func Components(components *openapi3.Components, visitor SchemaVisitor, levelNames map[string][]string) error { if components == nil { return nil } @@ -18,7 +18,7 @@ func Components(components *openapi3.Components, visitor SchemaVisitor, levelNam if schema.Ref != "" { continue } - err = visitor(schema, append(levelNames, schemaName, SchemaSuffix)...) + err = visitor(schema, add(levelNames, NameKey, schemaName, TypeKey, SchemaType)) if err != nil { return err } @@ -33,7 +33,7 @@ func Components(components *openapi3.Components, visitor SchemaVisitor, levelNam continue } - err = Header(header, visitor, append(levelNames, headerName)...) + err = Header(header, visitor, add(levelNames, NameKey, headerName)) if err != nil { return err } @@ -48,7 +48,7 @@ func Components(components *openapi3.Components, visitor SchemaVisitor, levelNam if parameter.Ref != "" { continue } - err = Parameter(parameter, visitor, append(levelNames, parameterName)...) + err = Parameter(parameter, visitor, add(levelNames, NameKey, parameterName)) if err != nil { return err } @@ -62,7 +62,7 @@ func Components(components *openapi3.Components, visitor SchemaVisitor, levelNam if requestBody.Ref != "" { continue } - err = RequestBody(requestBody, visitor, append(levelNames, requestBodyName)...) + err = RequestBody(requestBody, visitor, add(levelNames, NameKey, requestBodyName)) if err != nil { return err } @@ -76,7 +76,7 @@ func Components(components *openapi3.Components, visitor SchemaVisitor, levelNam if response.Ref != "" { continue } - err = Response(response, visitor, append(levelNames, responseName)...) + err = Response(response, visitor, add(levelNames, NameKey, responseName)) if err != nil { return err } @@ -90,7 +90,7 @@ func Components(components *openapi3.Components, visitor SchemaVisitor, levelNam if callback.Ref != "" { continue } - err = Callback(callback, visitor, append(levelNames, callbackName)...) + err = Callback(callback, visitor, add(levelNames, NameKey, callbackName)) if err != nil { return err } diff --git a/traverse/header.go b/traverse/header.go index f5cfba0..fe25dbb 100644 --- a/traverse/header.go +++ b/traverse/header.go @@ -3,7 +3,7 @@ package traverse import "github.com/getkin/kin-openapi/openapi3" // Header traverses the given header and all unique non-reference schemas in it. -func Header(header *openapi3.HeaderRef, visitor SchemaVisitor, levelNames ...string) error { +func Header(header *openapi3.HeaderRef, visitor SchemaVisitor, levelNames map[string][]string) error { if header == nil { return nil } @@ -13,8 +13,7 @@ func Header(header *openapi3.HeaderRef, visitor SchemaVisitor, levelNames ...str if header.Value == nil { return nil } - if header.Value.Schema == nil { - return nil - } - return visitor(header.Value.Schema, append(levelNames, HeaderSuffix)...) + + // we want to handle component header definitions like any other parameter + return ParameterSchema(&header.Value.Parameter, visitor, add(levelNames, InKey, openapi3.ParameterInHeader)) } diff --git a/traverse/media_type.go b/traverse/media_type.go index 7c00691..b0301b5 100644 --- a/traverse/media_type.go +++ b/traverse/media_type.go @@ -3,7 +3,7 @@ package traverse import "github.com/getkin/kin-openapi/openapi3" // MediaType traverses the given media type and all unique non-reference schemas in it. -func MediaType(mediaType *openapi3.MediaType, visitor SchemaVisitor, levelNames ...string) error { +func MediaType(mediaType *openapi3.MediaType, visitor SchemaVisitor, levelNames map[string][]string) error { if mediaType == nil { return nil } @@ -15,5 +15,5 @@ func MediaType(mediaType *openapi3.MediaType, visitor SchemaVisitor, levelNames return nil } - return visitor(mediaType.Schema, levelNames...) + return visitor(mediaType.Schema, levelNames) } diff --git a/traverse/operation.go b/traverse/operation.go index 0f23799..4c1f910 100644 --- a/traverse/operation.go +++ b/traverse/operation.go @@ -2,7 +2,7 @@ package traverse import "github.com/getkin/kin-openapi/openapi3" -func Operation(operation *openapi3.Operation, visitor SchemaVisitor, levelNames ...string) error { +func Operation(operation *openapi3.Operation, visitor SchemaVisitor, levelNames map[string][]string) error { if operation == nil { return nil } @@ -16,19 +16,21 @@ func Operation(operation *openapi3.Operation, visitor SchemaVisitor, levelNames if parameter.Ref != "" { continue } - err = Parameter(parameter, visitor, levelNames...) + err = Parameter(parameter, visitor, levelNames) if err != nil { return err } } if operation.RequestBody != nil && operation.RequestBody.Ref == "" { - if err := RequestBody(operation.RequestBody, visitor, append(levelNames, RequestSuffix)...); err != nil { + err = RequestBody(operation.RequestBody, visitor, levelNames) + if err != nil { return err } } if operation.Responses != nil { - if err := Responses(operation.Responses, visitor, levelNames...); err != nil { + err = Responses(operation.Responses, visitor, levelNames) + if err != nil { return err } } @@ -41,7 +43,7 @@ func Operation(operation *openapi3.Operation, visitor SchemaVisitor, levelNames if callback.Ref != "" { continue } - err = Callback(callback, visitor, append(levelNames, callbackName)...) + err = Callback(callback, visitor, add(levelNames, NameKey, callbackName)) if err != nil { return err } diff --git a/traverse/operation_test.go b/traverse/operation_test.go new file mode 100644 index 0000000..4453e18 --- /dev/null +++ b/traverse/operation_test.go @@ -0,0 +1,40 @@ +package traverse_test + +import ( + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/jxsl13/openapi-typegen/traverse" + "github.com/stretchr/testify/assert" +) + +func TestSingleOperationMustHavePath(t *testing.T) { + doc := Documents["002_parameters.yaml"] + + cnt := 0 + traverse.Document(doc, func(schemaRef *openapi3.SchemaRef, levelNames map[string][]string) error { + if values, ok := levelNames[traverse.OperationKey]; ok && len(values) > 0 && values[0] != "" { + assert.Contains(t, levelNames, traverse.PathKey) + assert.NotEmpty(t, levelNames[traverse.PathKey]) + cnt++ + } + return nil + }) + t.Logf("found %d operations that have a path key", cnt) +} + +func TestAllOperationMustHavePath(t *testing.T) { + + cnt := 0 + for _, doc := range OrderedDocuments { + traverse.Document(doc.Doc, func(schemaRef *openapi3.SchemaRef, levelNames map[string][]string) error { + if values, ok := levelNames[traverse.OperationKey]; ok && len(values) > 0 && values[0] != "" { + assert.Contains(t, levelNames, traverse.PathKey) + assert.NotEmpty(t, levelNames[traverse.PathKey]) + cnt++ + } + return nil + }) + } + t.Logf("found %d operations that have a path key", cnt) +} diff --git a/traverse/parameter.go b/traverse/parameter.go index 0b9c391..67a6b03 100644 --- a/traverse/parameter.go +++ b/traverse/parameter.go @@ -3,7 +3,7 @@ package traverse import "github.com/getkin/kin-openapi/openapi3" // Parameter traverses the given parameter and all unique non-reference schemas in it. -func Parameter(parameter *openapi3.ParameterRef, visitor SchemaVisitor, levelNames ...string) error { +func Parameter(parameter *openapi3.ParameterRef, visitor SchemaVisitor, levelNames map[string][]string) error { if parameter == nil { return nil } @@ -13,18 +13,21 @@ func Parameter(parameter *openapi3.ParameterRef, visitor SchemaVisitor, levelNam if parameter.Value == nil { return nil } - if err := ParameterSchema(parameter.Value, visitor, append(levelNames, parameter.Value.Name, parameter.Value.In)...); err != nil { + err := ParameterSchema(parameter.Value, visitor, levelNames) + if err != nil { return err } return nil } -func ParameterSchema(parameter *openapi3.Parameter, visitor SchemaVisitor, levelNames ...string) error { +func ParameterSchema(parameter *openapi3.Parameter, visitor SchemaVisitor, levelNames map[string][]string) error { if parameter == nil { return nil } if parameter.Schema == nil { return nil } - return visitor(parameter.Schema, append(levelNames, ParameterSuffix)...) + + // for headers the IN value can become an empty string + return visitor(parameter.Schema, add(levelNames, NameKey, parameter.Name, InKey, parameter.In, TypeKey, ParameterType)) } diff --git a/traverse/parameter_test.go b/traverse/parameter_test.go new file mode 100644 index 0000000..4c5ac24 --- /dev/null +++ b/traverse/parameter_test.go @@ -0,0 +1,60 @@ +package traverse_test + +import ( + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/jxsl13/openapi-typegen/traverse" + "github.com/stretchr/testify/assert" +) + +func TestSingleParameterTypeMustHaveInKey(t *testing.T) { + doc := Documents["004_callbacks.yaml"] + + traverse.Document(doc, func(schemaRef *openapi3.SchemaRef, levelNames map[string][]string) error { + + if values, ok := levelNames[traverse.TypeKey]; ok && len(values) > 0 && values[0] == traverse.ParameterType { + assert.Contains(t, levelNames, traverse.InKey) + } + return nil + }) +} + +func TestAllParameterTypeMustHaveInKey(t *testing.T) { + + for _, doc := range OrderedDocuments { + traverse.Document(doc.Doc, func(schemaRef *openapi3.SchemaRef, levelNames map[string][]string) error { + if values, ok := levelNames[traverse.TypeKey]; ok && len(values) > 0 && values[0] == traverse.ParameterType { + assert.Contains(t, levelNames, traverse.InKey) + } + return nil + }) + } +} + +func TestSingleParameterMustHaveExactlyOneInValue(t *testing.T) { + doc := Documents["002_parameters.yaml"] + + cnt := 0 + traverse.Document(doc, func(schemaRef *openapi3.SchemaRef, levelNames map[string][]string) error { + + if values, ok := levelNames[traverse.InKey]; ok { + assert.Len(t, values, 1) + cnt++ + } + return nil + }) + t.Logf("found %d parameters that have a in key", cnt) +} + +func TestAllParameterMustHaveExactlyOneInValue(t *testing.T) { + + for _, doc := range OrderedDocuments { + traverse.Document(doc.Doc, func(schemaRef *openapi3.SchemaRef, levelNames map[string][]string) error { + if values, ok := levelNames[traverse.TypeKey]; ok { + assert.Len(t, values, 1) + } + return nil + }) + } +} diff --git a/traverse/path_item.go b/traverse/path_item.go index 012d41e..4eac545 100644 --- a/traverse/path_item.go +++ b/traverse/path_item.go @@ -3,7 +3,7 @@ package traverse import "github.com/getkin/kin-openapi/openapi3" // PathItem traverses the given path item and calls the visitor for each schema. -func PathItem(pathItem *openapi3.PathItem, visitor SchemaVisitor, levelNames ...string) error { +func PathItem(pathItem *openapi3.PathItem, visitor SchemaVisitor, levelNames map[string][]string) error { if pathItem == nil { return nil } @@ -20,7 +20,7 @@ func PathItem(pathItem *openapi3.PathItem, visitor SchemaVisitor, levelNames ... if parameter.Ref != "" { continue } - err = Parameter(parameter, visitor, levelNames...) + err = Parameter(parameter, visitor, levelNames) if err != nil { return err } @@ -31,7 +31,8 @@ func PathItem(pathItem *openapi3.PathItem, visitor SchemaVisitor, levelNames ... continue } - if err := Operation(operation, visitor, append(levelNames, method, operation.OperationID)...); err != nil { + err = Operation(operation, visitor, add(levelNames, MethodKey, method, OperationKey, operation.OperationID)) + if err != nil { return err } } diff --git a/traverse/request_body.go b/traverse/request_body.go index 748221c..d20a473 100644 --- a/traverse/request_body.go +++ b/traverse/request_body.go @@ -3,7 +3,7 @@ package traverse import "github.com/getkin/kin-openapi/openapi3" // RequestBody traverses the given request body and all unique non-reference schemas in it. -func RequestBody(requestBody *openapi3.RequestBodyRef, visitor SchemaVisitor, levelNames ...string) error { +func RequestBody(requestBody *openapi3.RequestBodyRef, visitor SchemaVisitor, levelNames map[string][]string) error { if requestBody == nil { return nil } @@ -13,14 +13,14 @@ func RequestBody(requestBody *openapi3.RequestBodyRef, visitor SchemaVisitor, le if requestBody.Value == nil { return nil } - if requestBody.Value.Content == nil { - return nil - } + + var err error for contentType, mediaType := range requestBody.Value.Content { if mediaType == nil { continue } - if err := MediaType(mediaType, visitor, append(levelNames, contentType, RequestSuffix)...); err != nil { + err = MediaType(mediaType, visitor, add(levelNames, ContentKey, contentType, TypeKey, RequestType)) + if err != nil { return err } } diff --git a/traverse/responses.go b/traverse/responses.go index 14e41d0..177b3eb 100644 --- a/traverse/responses.go +++ b/traverse/responses.go @@ -3,12 +3,14 @@ package traverse import "github.com/getkin/kin-openapi/openapi3" // Responses traverses the given responses and all unique non-reference schemas in it. -func Responses(responses openapi3.Responses, visitor SchemaVisitor, levelNames ...string) error { +func Responses(responses openapi3.Responses, visitor SchemaVisitor, levelNames map[string][]string) error { + var err error for statusCode, response := range responses { if response == nil { continue } - if err := Response(response, visitor, append(levelNames, statusCode)...); err != nil { + err = Response(response, visitor, add(levelNames, StatusKey, statusCode)) + if err != nil { return err } } @@ -16,40 +18,42 @@ func Responses(responses openapi3.Responses, visitor SchemaVisitor, levelNames . } // Response traverses the given response and all unique non-reference schemas in it. -func Response(response *openapi3.ResponseRef, visitor SchemaVisitor, levelNames ...string) error { +func Response(response *openapi3.ResponseRef, visitor SchemaVisitor, levelNames map[string][]string) error { if response == nil { return nil } if response.Ref != "" { return nil } - if response.Value == nil { + + return ResponseSchema(response.Value, visitor, levelNames) +} + +func ResponseSchema(response *openapi3.Response, visitor SchemaVisitor, levelNames map[string][]string) error { + if response == nil { return nil } - //traverse headers var err error - for headerName, header := range response.Value.Headers { + for headerName, header := range response.Headers { if header == nil { continue } if header.Ref != "" { continue } - err = Header(header, visitor, append(levelNames, headerName)...) + err = Header(header, visitor, add(levelNames, NameKey, headerName)) if err != nil { return err } } - if response.Value.Content == nil { - return nil - } - for contentType, mediaType := range response.Value.Content { + for contentType, mediaType := range response.Content { if mediaType == nil { continue } - if err := MediaType(mediaType, visitor, append(levelNames, contentType, ResponseSuffix)...); err != nil { + err = MediaType(mediaType, visitor, add(levelNames, ContentKey, contentType, TypeKey, ResponseType)) + if err != nil { return err } } diff --git a/traverse/traverse.go b/traverse/traverse.go index dbad8fe..6ef37e0 100644 --- a/traverse/traverse.go +++ b/traverse/traverse.go @@ -5,17 +5,25 @@ import ( ) var ( - RequestSuffix = "Request" - ResponseSuffix = "Response" - HeaderSuffix = "Header" - ParameterSuffix = "Parameter" - SchemaSuffix = "Schema" + RequestType = "request" + ResponseType = "response" + ParameterType = "parameter" + SchemaType = "schema" + + TypeKey = "type" + NameKey = "name" + InKey = "in" + PathKey = "path" + MethodKey = "method" + OperationKey = "operation" + ContentKey = "content" + StatusKey = "status" ) -type SchemaVisitor func(schemaRef *openapi3.SchemaRef, levelNames ...string) error +type SchemaVisitor func(schemaRef *openapi3.SchemaRef, levelNames map[string][]string) error // Document traverses the given document and calls the visitor for each schema. -func Document(t *openapi3.T, visitor SchemaVisitor, levelNames ...string) error { +func Document(t *openapi3.T, visitor SchemaVisitor, levelNames ...map[string][]string) error { if t == nil { return nil } @@ -29,7 +37,7 @@ func Document(t *openapi3.T, visitor SchemaVisitor, levelNames ...string) error continue } - err = PathItem(pathItem, visitor, append(levelNames, pathName)...) + err = PathItem(pathItem, visitor, add(merge(levelNames...), PathKey, pathName)) if err != nil { return err } @@ -37,7 +45,7 @@ func Document(t *openapi3.T, visitor SchemaVisitor, levelNames ...string) error // traverse components if t.Components != nil { - err = Components(t.Components, visitor, levelNames...) + err = Components(t.Components, visitor, merge(levelNames...)) if err != nil { return err } @@ -45,3 +53,40 @@ func Document(t *openapi3.T, visitor SchemaVisitor, levelNames ...string) error return nil } + +// merge any number of maps into a new map +func merge(maps ...map[string][]string) map[string][]string { + size := 0 + for _, m := range maps { + size += len(m) + } + + m := make(map[string][]string, size) + for _, m1 := range maps { + for k, vs := range m1 { + m[k] = append(m[k], vs...) + } + } + return m +} + +// add clones the map and adds all new +// empty key values are not added +func add(m map[string][]string, keyValue ...string) map[string][]string { + if len(keyValue)%2 != 0 { + panic("keyValue must be even") + } + + m2 := make(map[string][]string, len(m)+len(keyValue)/2) + for k, v := range m { + m2[k] = append(m2[k], v...) + } + + for i := 0; i < len(keyValue); i += 2 { + if keyValue[i+1] != "" { + m2[keyValue[i]] = append(m2[keyValue[i]], keyValue[i+1]) // only add non empty keys + } + } + + return m2 +} diff --git a/traverse/traverse_test.go b/traverse/traverse_test.go index 8ce308f..2e0c02b 100644 --- a/traverse/traverse_test.go +++ b/traverse/traverse_test.go @@ -1,26 +1,20 @@ package traverse_test import ( + "os" "sort" "testing" "github.com/getkin/kin-openapi/openapi3" "github.com/jxsl13/openapi-typegen/testutils" "github.com/jxsl13/openapi-typegen/traverse" - "github.com/stretchr/testify/require" + "github.com/k0kubun/pp/v3" + "github.com/stretchr/testify/assert" ) var ( Documents = testutils.LoadSpecs(`\d+.*.yaml`, "../testdata/") OrderedDocuments = mapToOrderedTupleList(Documents) - - SuffixMap = map[string]bool{ - traverse.RequestSuffix: true, - traverse.ResponseSuffix: true, - traverse.HeaderSuffix: true, - traverse.ParameterSuffix: true, - traverse.SchemaSuffix: true, - } ) type Tuple struct { @@ -47,23 +41,47 @@ func mapToOrderedTupleList(m map[string]*openapi3.T) []Tuple { return tuples } -func TestTraverse(t *testing.T) { - for _, tuple := range OrderedDocuments { - cnt := 0 - //unique := make(map[string]bool, 64) - lenBuckets := make(map[int]int, 64) +func TestMain(m *testing.M) { + pp.Default.SetColoringEnabled(false) + os.Exit(m.Run()) +} - err := traverse.Document(tuple.Doc, func(schemaRef *openapi3.SchemaRef, levelNames ...string) error { - t.Logf("document: %s, levelNames: %v", tuple.Name, levelNames) - require.Greater(t, len(levelNames), 0) - cnt++ +func TestSingleMustContainTypeKey(t *testing.T) { + doc := Documents["004_callbacks.yaml"] - lenBuckets[len(levelNames)]++ + traverse.Document(doc, func(schemaRef *openapi3.SchemaRef, levelNames map[string][]string) error { + assert.Contains(t, levelNames, traverse.TypeKey) + return nil + }) +} + +func TestAllMustContainTypeKey(t *testing.T) { + + for _, doc := range OrderedDocuments { + traverse.Document(doc.Doc, func(schemaRef *openapi3.SchemaRef, levelNames map[string][]string) error { + assert.Contains(t, levelNames, traverse.TypeKey) return nil }) - t.Logf("document: %s, cnt: %d", tuple.Name, cnt) - t.Logf("document: %s, lenBuckets: %v", tuple.Name, lenBuckets) - require.NoError(t, err) + } +} + +func TestSingleMustContainOneTypeKey(t *testing.T) { + doc := Documents["004_callbacks.yaml"] + + traverse.Document(doc, func(schemaRef *openapi3.SchemaRef, levelNames map[string][]string) error { + //must only contain one type name + assert.Equal(t, len(levelNames[traverse.TypeKey]), 1) + return nil + }) +} +func TestAllMustContainOneTypeKey(t *testing.T) { + + for _, doc := range OrderedDocuments { + traverse.Document(doc.Doc, func(schemaRef *openapi3.SchemaRef, levelNames map[string][]string) error { + //must only contain one type name + assert.Equal(t, len(levelNames[traverse.TypeKey]), 1) + return nil + }) } }