From b870fc998deb0c3adb89b9328e011037745b2b06 Mon Sep 17 00:00:00 2001 From: Chris Roche Date: Tue, 4 Sep 2018 12:42:19 -0700 Subject: [PATCH] Remove protoc-gen-go dependency (#30) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Removal of Plugins * protoc-gen-debug + a bunch of testdata * - Gatherer + AST graph parser * Move some isolated Go-specific stuff * SourceCodeInfo * WKTs * Imports * Move go specific stuff into subpackage * Mock Debugger * Bye PGGo * Buncha tests… * Generator to use new AST + workflows * Update module interface * Update persister * proto related helpers * update example plugin * Readme + Travis --- .gitignore | 10 +- .travis.yml | 14 +- Makefile | 100 +- README.md | 367 +++----- artifact.go | 29 +- ast.go | 351 ++++++++ ast_test.go | 260 ++++++ build_context_test.go | 72 +- comment.go | 6 - debug.go | 112 ++- debug_test.go | 134 +-- entity.go | 22 +- enum.go | 64 +- enum_test.go | 30 +- enum_value.go | 15 +- enum_value_test.go | 8 + extension.go | 6 +- field.go | 26 +- field_test.go | 39 +- field_type.go | 38 +- field_type_elem.go | 24 +- field_type_elem_test.go | 33 +- field_type_test.go | 86 +- file.go | 112 ++- file_test.go | 55 +- gatherer.go | 569 ------------ gatherer_test.go | 852 ------------------ generator.go | 97 +- generator_test.go | 71 +- init_option.go | 28 +- init_option_test.go | 38 +- lang/go/Makefile | 42 + lang/go/context.go | 68 ++ lang/go/context_test.go | 19 + lang/go/docs.go | 2 + lang/go/gofmt.go | 36 + lang/go/gofmt_test.go | 54 ++ lang/go/helpers_test.go | 45 + lang/go/name.go | 90 ++ lang/go/name_test.go | 204 +++++ lang/go/package.go | 94 ++ lang/go/package_test.go | 131 +++ lang/go/parameters.go | 132 +++ lang/go/parameters_test.go | 134 +++ .../go/testdata/names/entities/entities.proto | 108 +++ lang/go/testdata/names/entities/params | 0 lang/go/testdata/names/import/import.proto | 4 + lang/go/testdata/names/import/params | 0 .../names/import_path/import_path.proto 
| 4 + lang/go/testdata/names/import_path/params | 1 + lang/go/testdata/names/keyword/keyword.proto | 6 + lang/go/testdata/names/keyword/params | 0 lang/go/testdata/names/mapped/mapped.proto | 5 + lang/go/testdata/names/mapped/params | 1 + .../names/none/NO.pack--age.name$.proto | 3 + lang/go/testdata/names/none/params | 0 .../go/testdata/names/override/override.proto | 4 + lang/go/testdata/names/override/params | 0 lang/go/testdata/names/package/package.proto | 6 + lang/go/testdata/names/package/params | 0 lang/go/testdata/names/types/params | 0 lang/go/testdata/names/types/proto2.proto | 65 ++ lang/go/testdata/names/types/proto3.proto | 45 + lang/go/testdata/names/unnamed/params | 0 lang/go/testdata/names/unnamed/unnamed.proto | 5 + lang/go/testdata/outputs/import_prefix/params | 1 + .../outputs/import_prefix/prefix.proto | 4 + .../outputs/import_prefix_srcrel/params | 2 + .../outputs/import_prefix_srcrel/prefix.proto | 4 + lang/go/testdata/outputs/mapped/mapped.proto | 5 + lang/go/testdata/outputs/mapped/params | 1 + .../outputs/mapped_srcrel/mapped.proto | 5 + lang/go/testdata/outputs/mapped_srcrel/params | 1 + lang/go/testdata/outputs/none/none.proto | 3 + lang/go/testdata/outputs/none/params | 0 .../testdata/outputs/none_srcrel/none.proto | 3 + lang/go/testdata/outputs/none_srcrel/params | 1 + lang/go/testdata/outputs/qualified/params | 0 .../outputs/qualified/qualified.proto | 5 + .../testdata/outputs/qualified_srcrel/params | 1 + .../outputs/qualified_srcrel/qualified.proto | 5 + lang/go/testdata/outputs/unqualified/params | 0 .../outputs/unqualified/unqualified.proto | 5 + .../outputs/unqualified_srcrel/params | 1 + .../unqualified_srcrel/unqualified.proto | 5 + .../import_prefix/import_prefix.proto | 13 + .../go/testdata/packages/import_prefix/params | 1 + lang/go/testdata/packages/mapped/mapped.proto | 13 + lang/go/testdata/packages/mapped/params | 1 + .../packages/no_options/no_options.proto | 13 + lang/go/testdata/packages/no_options/params | 0 
.../fully_qualified/fully_qualified.proto | 5 + .../testdata/packages/targets/none/none.proto | 4 + .../targets/unqualified/unqualified.proto | 5 + lang/go/type_name.go | 132 +++ lang/go/type_name_test.go | 345 +++++++ message.go | 149 +-- message_test.go | 63 +- method.go | 29 +- method_test.go | 17 +- module.go | 44 +- module_test.go | 24 +- name.go | 100 +- name_test.go | 279 +----- node_nilvisitor_test.go | 4 +- oneof.go | 15 +- oneof_test.go | 14 +- package.go | 43 +- package_test.go | 40 +- parameters.go | 137 +-- parameters_test.go | 112 --- path.go | 61 -- path_test.go | 78 -- persister.go | 47 +- persister_test.go | 224 ++--- plugin.go | 207 ----- plugin_test.go | 229 ----- post_process.go | 34 - post_process_test.go | 51 -- proto.go | 66 +- proto_test.go | 27 +- protoc-gen-debug/README.md | 52 ++ protoc-gen-debug/main.go | 48 + protoc_gen_go.go | 62 -- protoc_gen_go_test.go | 36 - service.go | 25 +- service_test.go | 15 +- source_code_info.go | 56 ++ source_code_info_test.go | 26 + testdata/graph/README.md | 13 + testdata/graph/info/info.proto | 56 ++ testdata/graph/messages/embedded.proto | 24 + testdata/graph/messages/enums.proto | 24 + testdata/graph/messages/enums_ext.proto | 4 + testdata/graph/messages/maps.proto | 44 + testdata/graph/messages/oneofs.proto | 13 + testdata/graph/messages/recursive.proto | 6 + testdata/graph/messages/repeated.proto | 41 + testdata/graph/messages/scalars.proto | 20 + testdata/graph/nested/nested.proto | 27 + testdata/graph/services/services.proto | 27 + testdata/protoc-gen-example/jsonify.go | 114 +++ testdata/protoc-gen-example/jsonify_plugin.go | 124 --- testdata/protoc-gen-example/main.go | 18 +- .../{printer_module.go => printer.go} | 36 +- testdata/protos/kitchen/kitchen.proto | 2 +- testdata/protos/kitchen/sink.proto | 2 +- .../protos/multipackage/bar/baz/quux.proto | 1 + testdata/protos/multipackage/bar/buzz.proto | 1 + testdata/protos/multipackage/foo/fizz.proto | 1 + wkt.go | 73 ++ wkt_test.go | 58 ++ 
workflow.go | 136 +-- workflow_multipackage.go | 306 ------- workflow_multipackage_test.go | 319 ------- workflow_test.go | 96 +- 156 files changed, 4610 insertions(+), 4950 deletions(-) create mode 100644 ast.go create mode 100644 ast_test.go delete mode 100644 gatherer.go delete mode 100644 gatherer_test.go create mode 100644 lang/go/Makefile create mode 100644 lang/go/context.go create mode 100644 lang/go/context_test.go create mode 100644 lang/go/docs.go create mode 100644 lang/go/gofmt.go create mode 100644 lang/go/gofmt_test.go create mode 100644 lang/go/helpers_test.go create mode 100644 lang/go/name.go create mode 100644 lang/go/name_test.go create mode 100644 lang/go/package.go create mode 100644 lang/go/package_test.go create mode 100644 lang/go/parameters.go create mode 100644 lang/go/parameters_test.go create mode 100644 lang/go/testdata/names/entities/entities.proto create mode 100644 lang/go/testdata/names/entities/params create mode 100644 lang/go/testdata/names/import/import.proto create mode 100644 lang/go/testdata/names/import/params create mode 100644 lang/go/testdata/names/import_path/import_path.proto create mode 100644 lang/go/testdata/names/import_path/params create mode 100644 lang/go/testdata/names/keyword/keyword.proto create mode 100644 lang/go/testdata/names/keyword/params create mode 100644 lang/go/testdata/names/mapped/mapped.proto create mode 100644 lang/go/testdata/names/mapped/params create mode 100644 lang/go/testdata/names/none/NO.pack--age.name$.proto create mode 100644 lang/go/testdata/names/none/params create mode 100644 lang/go/testdata/names/override/override.proto create mode 100644 lang/go/testdata/names/override/params create mode 100644 lang/go/testdata/names/package/package.proto create mode 100644 lang/go/testdata/names/package/params create mode 100644 lang/go/testdata/names/types/params create mode 100644 lang/go/testdata/names/types/proto2.proto create mode 100644 lang/go/testdata/names/types/proto3.proto create mode 
100644 lang/go/testdata/names/unnamed/params create mode 100644 lang/go/testdata/names/unnamed/unnamed.proto create mode 100644 lang/go/testdata/outputs/import_prefix/params create mode 100644 lang/go/testdata/outputs/import_prefix/prefix.proto create mode 100644 lang/go/testdata/outputs/import_prefix_srcrel/params create mode 100644 lang/go/testdata/outputs/import_prefix_srcrel/prefix.proto create mode 100644 lang/go/testdata/outputs/mapped/mapped.proto create mode 100644 lang/go/testdata/outputs/mapped/params create mode 100644 lang/go/testdata/outputs/mapped_srcrel/mapped.proto create mode 100644 lang/go/testdata/outputs/mapped_srcrel/params create mode 100644 lang/go/testdata/outputs/none/none.proto create mode 100644 lang/go/testdata/outputs/none/params create mode 100644 lang/go/testdata/outputs/none_srcrel/none.proto create mode 100644 lang/go/testdata/outputs/none_srcrel/params create mode 100644 lang/go/testdata/outputs/qualified/params create mode 100644 lang/go/testdata/outputs/qualified/qualified.proto create mode 100644 lang/go/testdata/outputs/qualified_srcrel/params create mode 100644 lang/go/testdata/outputs/qualified_srcrel/qualified.proto create mode 100644 lang/go/testdata/outputs/unqualified/params create mode 100644 lang/go/testdata/outputs/unqualified/unqualified.proto create mode 100644 lang/go/testdata/outputs/unqualified_srcrel/params create mode 100644 lang/go/testdata/outputs/unqualified_srcrel/unqualified.proto create mode 100644 lang/go/testdata/packages/import_prefix/import_prefix.proto create mode 100644 lang/go/testdata/packages/import_prefix/params create mode 100644 lang/go/testdata/packages/mapped/mapped.proto create mode 100644 lang/go/testdata/packages/mapped/params create mode 100644 lang/go/testdata/packages/no_options/no_options.proto create mode 100644 lang/go/testdata/packages/no_options/params create mode 100644 lang/go/testdata/packages/targets/fully_qualified/fully_qualified.proto create mode 100644 
lang/go/testdata/packages/targets/none/none.proto create mode 100644 lang/go/testdata/packages/targets/unqualified/unqualified.proto create mode 100644 lang/go/type_name.go create mode 100644 lang/go/type_name_test.go delete mode 100644 path.go delete mode 100644 path_test.go delete mode 100644 plugin.go delete mode 100644 plugin_test.go create mode 100644 protoc-gen-debug/README.md create mode 100644 protoc-gen-debug/main.go delete mode 100644 protoc_gen_go.go delete mode 100644 protoc_gen_go_test.go create mode 100644 source_code_info.go create mode 100644 source_code_info_test.go create mode 100644 testdata/graph/README.md create mode 100644 testdata/graph/info/info.proto create mode 100644 testdata/graph/messages/embedded.proto create mode 100644 testdata/graph/messages/enums.proto create mode 100644 testdata/graph/messages/enums_ext.proto create mode 100644 testdata/graph/messages/maps.proto create mode 100644 testdata/graph/messages/oneofs.proto create mode 100644 testdata/graph/messages/recursive.proto create mode 100644 testdata/graph/messages/repeated.proto create mode 100644 testdata/graph/messages/scalars.proto create mode 100644 testdata/graph/nested/nested.proto create mode 100644 testdata/graph/services/services.proto create mode 100644 testdata/protoc-gen-example/jsonify.go delete mode 100644 testdata/protoc-gen-example/jsonify_plugin.go rename testdata/protoc-gen-example/{printer_module.go => printer.go} (81%) create mode 100644 wkt.go create mode 100644 wkt_test.go delete mode 100644 workflow_multipackage.go delete mode 100644 workflow_multipackage_test.go diff --git a/.gitignore b/.gitignore index e5f9139..9c73d95 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,13 @@ +# vendored code vendor/ -testdata/generated/ + +# binaries +bin/ # coverage reports cover.* + + +testdata/generated/ +**/*.pb.go +**/code_generator_request.pb.bin diff --git a/.travis.yml b/.travis.yml index 614b986..37fa3ae 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,15 
+1,21 @@ language: go -go: "1.10" +go: "1.11" env: global: - GLIDE_VER="v0.13.1" - - GLIDE_ARCH="linux-amd64" + matrix: + - PROTOC_VER="3.5.1" + - PROTOC_VER="3.6.1" before_install: - mkdir -p $GOPATH/bin - - wget "https://github.com/Masterminds/glide/releases/download/${GLIDE_VER}/glide-${GLIDE_VER}-${GLIDE_ARCH}.tar.gz" -O /tmp/glide.tar.gz + - wget "https://github.com/Masterminds/glide/releases/download/${GLIDE_VER}/glide-${GLIDE_VER}-linux-amd64.tar.gz" -O /tmp/glide.tar.gz - tar -xvf /tmp/glide.tar.gz --strip-components 1 -C ${GOPATH}/bin + - wget "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VER}/protoc-${PROTOC_VER}-linux-x86_64.zip" -O /tmp/protoc.zip + - unzip /tmp/protoc.zip -d /tmp + - sudo mv /tmp/bin/protoc /usr/local/bin/protoc + - sudo mv /tmp/include/google /usr/local/include/google -install: make install +install: make testdata script: make lint tests diff --git a/Makefile b/Makefile index df7410e..3b77ee1 100644 --- a/Makefile +++ b/Makefile @@ -1,29 +1,28 @@ -# the name of this package -PKG := $(shell go list .) - -.PHONY: install -install: # downloads dependencies (including test deps) for the package - which glide || (curl https://glide.sh/get | sh) - glide install +# the name of this package & all subpackages +PKG := $(shell go list .) +PKGS := $(shell go list ./...) .PHONY: lint lint: # lints the package for common code smells + set -e; for f in `find . 
-name "*.go" -not -name "*.pb.go" | grep -v vendor`; do \ + out=`gofmt -s -d $$f`; \ + test -z "$$out" || (echo $$out && exit 1); \ + done which golint || go get -u golang.org/x/lint/golint - test -z "$(gofmt -d -s ./*.go)" || (gofmt -d -s ./*.go && exit 1) - golint -set_exit_status - go tool vet -all -shadow -shadowstrict *.go + golint -set_exit_status $(PKGS) + go vet -all -shadow -shadowstrict $(PKGS) .PHONY: quick -quick: # runs all tests without the race detector or coverage percentage - go test +quick: vendor # runs all tests without the race detector or coverage + go test $(PKGS) .PHONY: tests -tests: # runs all tests against the package with race detection and coverage percentage - go test -race -cover +tests: vendor # runs all tests against the package with race detection and coverage percentage + go test -race -cover $(PKGS) .PHONY: cover -cover: # runs all tests against the package, generating a coverage report and opening it in the browser - go test -race -covermode=atomic -coverprofile=cover.out +cover: vendor # runs all tests against the package, generating a coverage report and opening it in the browser + go test -race -covermode=atomic -coverprofile=cover.out $(PKGS) || true go tool cover -html cover.out -o cover.html open cover.html @@ -32,11 +31,66 @@ docs: # starts a doc server and opens a browser window to this package (sleep 2 && open http://localhost:6060/pkg/$(PKG)/) & godoc -http=localhost:6060 -.PHONY: demo -demo: # compiles, installs, and runs the demo protoc-plugin - go install $(PKG)/testdata/protoc-gen-example - rm -r ./testdata/generated || true - mkdir -p ./testdata/generated - set -e; cd ./testdata/protos; for subdir in `find . -type d -mindepth 1 -maxdepth 1`; do \ - protoc -I . 
--example_out="plugins:../generated" `find $$subdir -name "*.proto"`; \ +.PHONY: testdata +testdata: testdata-graph testdata-go testdata/generated # generate all testdata + +.PHONY: testdata-graph +testdata-graph: bin/protoc-gen-debug # parses the proto file sets in testdata/graph and renders binary CodeGeneratorRequest + set -e; for subdir in `find ./testdata/graph -type d -mindepth 1 -maxdepth 1`; do \ + protoc -I ./testdata/graph \ + --plugin=protoc-gen-debug=./bin/protoc-gen-debug \ + --debug_out="$$subdir:$$subdir" \ + `find $$subdir -name "*.proto"`; \ + done + +testdata/generated: protoc-gen-go bin/protoc-gen-example + which protoc-gen-go || (go install github.com/golang/protobuf/protoc-gen-go) + rm -rf ./testdata/generated && mkdir -p ./testdata/generated + # generate the official go code, must be one directory at a time + set -e; for subdir in `find ./testdata/protos -type d -mindepth 1`; do \ + files=`find $$subdir -name "*.proto" -maxdepth 1`; \ + [ ! -z "$$files" ] && \ + protoc -I ./testdata/protos \ + --go_out="$$GOPATH/src" \ + $$files; \ + done + # generate using our demo plugin, don't need to go directory at a time + set -e; for subdir in `find ./testdata/protos -type d -mindepth 1 -maxdepth 1`; do \ + protoc -I ./testdata/protos \ + --plugin=protoc-gen-example=./bin/protoc-gen-example \ + --example_out="paths=source_relative:./testdata/generated" \ + `find $$subdir -name "*.proto"`; \ + done + +.PHONY: testdata-go +testdata-go: protoc-gen-go bin/protoc-gen-debug # generate go-specific testdata + cd lang/go && $(MAKE) \ + testdata-names \ + testdata-packages \ + testdata-outputs + +vendor: # install project dependencies + which glide || (curl https://glide.sh/get | sh) + glide install + +.PHONY: protoc-gen-go +protoc-gen-go: + which protoc-gen-go || (go get -u github.com/golang/protobuf/protoc-gen-go) + +bin/protoc-gen-example: vendor # creates the demo protoc plugin for demonstrating uses of PG* + go build -o ./bin/protoc-gen-example 
./testdata/protoc-gen-example + +bin/protoc-gen-debug: vendor # creates the protoc-gen-debug protoc plugin for output ProtoGeneratorRequest messages + go build -o ./bin/protoc-gen-debug ./protoc-gen-debug + +.PHONY: clean +clean: + rm -rf vendor + rm -rf bin + rm -rf testdata/generated + set -e; for f in `find . -name *.pb.bin`; do \ + rm $$f; \ + done + set -e; for f in `find . -name *.pb.go`; do \ + rm $$f; \ done diff --git a/README.md b/README.md index ff7df83..3c58d2a 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -# protoc-gen-star (PGS) [![Build Status](https://travis-ci.org/lyft/protoc-gen-star.svg?branch=master)](https://travis-ci.org/lyft/protoc-gen-star) +# protoc-gen-star (PG*) [![Build Status](https://travis-ci.org/lyft/protoc-gen-star.svg?branch=master)](https://travis-ci.org/lyft/protoc-gen-star) [![GoDoc](https://godoc.org/github.com/lyft/protoc-gen-star?status.svg)](https://godoc.org/github.com/lyft/protoc-gen-star) **!!! THIS PROJECT IS A WORK-IN-PROGRESS | THE API SHOULD BE CONSIDERED UNSTABLE !!!** -_PGS is a protoc plugin library for efficient proto-based code generation_ +_PG* is a protoc plugin library for efficient proto-based code generation_ ```go package main @@ -10,234 +10,128 @@ package main import "github.com/lyft/protoc-gen-star" func main() { - pgs.Init(pgs.IncludeGo()). - RegisterPlugin(&myProtocGenGoPlugin{}). - RegisterModule(&myPGSModule{}). - RegisterPostProcessor(&myPostProcessor{}). - Render() + pgs.Init(pgs.DebugEnv("DEBUG")). + RegisterModule(&myPGSModule{}). + RegisterPostProcessor(&myPostProcessor{}). + Render() } ``` -`protoc-gen-star` (PGS) is built on top of the official [`protoc-gen-go`][pgg] (PGG) protoc plugin. PGG contains a mechanism for extending its behavior with plugins (for instance, gRPC support via a plugin). However, this feature is not accessible from the outside and requires either forking PGG or replicating its behavior using its library code. 
Further still, the PGG plugins are designed specifically for extending the officially generated Go code, not creating other new files or packages. - -PGS leverages the existing PGG library code to properly build up the [Protocol Buffer][pb] (PB) descriptors' dependency graph before handing it off to custom [`Modules`][module] to generate anything above-and-beyond the officially generated code. In fact, by default PGS does not generate the official Go code. While PGS is written in Go and relies on PGG, this library can be used to generate code in any language. - ## Features ### Documentation -While this README seeks to describe many of the nuances of `protoc` plugin development and using PGS, the true documentation source is the code itself. The Go language is self-documenting and provides tools for easily reading through it and viewing examples. Until this package is open sourced, the documentation can be viewed locally by running `make docs`, which will start a `godoc` server and open the documentation in the default browser. +While this README seeks to describe many of the nuances of `protoc` plugin development and using PG*, the true documentation source is the code itself. The Go language is self-documenting and provides tools for easily reading through it and viewing examples. The docs can be viewed on [GoDoc](https://godoc.org/github.com/lyft/protoc-gen-star) or locally by running `make docs`, which will start a `godoc` server and open them in the default browser. 
### Roadmap -- [x] Full support for official Go PB output and `protoc-gen-go` plugins, can replace `protoc-gen-go` - [x] Interface-based and fully-linked dependency graph with access to raw descriptors -- [x] Built-in context-aware debugging capabilities +- [x] Built-in context-aware debugging capabilities - [x] Exhaustive, near 100% unit test coverage -- [x] End-to-end testable via overrideable IO +- [x] End-to-end testable via overrideable IO & Interface based API - [x] [`Visitor`][visitor] pattern and helpers for efficiently walking the dependency graph - [x] [`BuildContext`][context] to facilitate complex generation - [x] Parsed, typed command-line [`Parameters`][params] access -- [x] Extensible `PluginBase` for quickly creating `protoc-gen-go` plugins - [x] Extensible `ModuleBase` for quickly creating `Modules` and facilitating code generation -- [x] Configurable post-processing (eg, gofmt/goimports) of generated files -- [x] Support processing proto files from multiple packages (normally disallowed by `protoc-gen-go`) +- [x] Configurable post-processing (eg, gofmt) of generated files +- [x] Support processing proto files from multiple packages +- [x] Load comments (via SourceCodeInfo) from proto files into gathered AST for easy access +- [x] Language-specific helper subpackages for handling common, nuanced generation tasks - [ ] Load plugins/modules at runtime using Go shared libraries -- [x] Load comments from proto files into gathered AST for easy access -- [ ] More intelligent Go import path resolution ### Examples -[`protoc-gen-example`][pge], can be found in the `testdata` directory. It includes both `Plugin` and `Module` implementations using a variety of the features available. It's `protoc` execution is included in the `demo` [Makefile][make] target. Test examples are also accessible via the documentation by running `make docs`. +[`protoc-gen-example`][pge], can be found in the `testdata` directory. 
It includes two `Module` implementations using a variety of the features available. It's `protoc` execution is included in the `testdata/generated` [Makefile][make] target. Examples are also accessible via the documentation by running `make docs`. ## How It Works ### The `protoc` Flow -Because the process is somewhat confusing, this section will cover the entire flow of how proto files are converted to generated code, using a hypothetical PGS plugin: `protoc-gen-myplugin`. A typical execution looks like this: +Because the process is somewhat confusing, this section will cover the entire flow of how proto files are converted to generated code, using a hypothetical PG* plugin: `protoc-gen-myplugin`. A typical execution looks like this: ```sh protoc \ - -I . \ - --myplugin_out="plugins=grpc:../generated" \ - ./pkg/*.proto + -I . \ + --myplugin_out="foo=bar:../generated" \ + ./pkg/*.proto ``` -`protoc`, the PB compiler, is configured using a set of flags (documented under `protoc -h`) and handed a set of files as arguments. In this case, the `I` flag can be specified multiple times and is the lookup path it should use for imported dependencies in a proto file. By default, the official descriptor protos are already included. +`protoc`, the PB compiler, is configured using a set of flags (documented under `protoc -h`) and handed a set of files as arguments. In this case, the `I` flag can be specified multiple times and is the lookup path it uses for imported dependencies in a proto file. By default, the official descriptor protos are already included. -`myplugin_out` tells `protoc` to use the `protoc-gen-myplugin` protoc-plugin. These plugins are automatically resolved from the system's `PATH` environment variable, or can be explicitly specified with another flag. The official protoc-plugins (eg, `protoc-gen-python`) are already registered with `protoc`. The flag's value is specific to the particular plugin, with the exception of the `:../generated` suffix. 
This suffix indicates the root directory in which `protoc` will place the generated files from that package (relative to the current working directory). This generated output directory is _not_ propagated to `protoc-gen-myplugin`, however, so it needs to be duplicated in the left-hand side of the flag. PGS supports this via an `output_path` parameter. +`myplugin_out` tells `protoc` to use the `protoc-gen-myplugin` protoc-plugin. These plugins are automatically resolved from the system's `PATH` environment variable, or can be explicitly specified with another flag. The official protoc-plugins (eg, `protoc-gen-python`) are already registered with `protoc`. The flag's value is specific to the particular plugin, with the exception of the `:../generated` suffix. This suffix indicates the root directory in which `protoc` will place the generated files from that package (relative to the current working directory). This generated output directory is _not_ propagated to `protoc-gen-myplugin`, however, so it needs to be duplicated in the left-hand side of the flag. PG* supports this via an `output_path` parameter. `protoc` parses the passed in proto files, ensures they are syntactically correct, and loads any imported dependencies. It converts these files and the dependencies into descriptors (which are themselves PB messages) and creates a `CodeGeneratorRequest` (yet another PB). `protoc` serializes this request and then executes each configured protoc-plugin, sending the payload via `stdin`. -`protoc-gen-myplugin` starts up, receiving the request payload, which it unmarshals. There are two phases to a PGS-based protoc-plugin. First, the standard PGG process is executed against the input. This allows for generation of the official Go code if desired, as well as applying any PGG plugins we've specified in its protoc flag (in this case, we opted to use `grpc`). 
PGS, also injects a plugin here called the `gatherer`, which constructs a dependency graph from the incoming descriptors. This process populates a `CodeGeneratorResponse` PB message containing the files to generate. +`protoc-gen-myplugin` starts up, receiving the request payload, which it unmarshals. There are two phases to a PG*-based protoc-plugin. First, PG* unmarshals the `CodeGeneratorRequest` received from `protoc`, and creates a fully connected abstract syntax tree (AST) of each file and all its contained entities. Any parameters specified for this plugin are also parsed for later consumption. -When this step is complete, PGS then executes any registered `Modules`, handing it the constructed graph. `Modules` can be written to generate more files, adding them to the response PB, writing them to disk directly, or just performing some form of validation over the provided graph without any other side effects. `Modules` provide the most flexibility in terms of operating against the PBs. +When this step is complete, PG* then executes any registered `Modules`, handing it the constructed AST. `Modules` can be written to generate artifacts (eg, files) or just performing some form of validation over the provided graph without any other side effects. `Modules` provide the great flexibility in terms of operating against the PBs. -Once all `Modules` are complete, `protoc-gen-myplugin` serializes the `CodeGeneratorResponse` and writes the data to its `stdout`. `protoc` receives this payload, unmarshals it, and writes any requested files to disk after all its plugins have returned. This whole flow looked something like this: +Once all `Modules` are run, PG* writes any custom artifacts to the file system or serializes generator-specific ones into a `CodeGeneratorResponse` and sends the data to its `stdout`. `protoc` receives this payload, unmarshals it, and persists any requested files to disk after all its plugins have returned. 
This whole flow looks something like this: ``` foo.proto → protoc → CodeGeneratorRequest → protoc-gen-myplugin → CodeGeneratorResponse → protoc → foo.pb.go ``` -The PGS library hides away nearly all of this complexity required to implement a protoc-plugin! - -### Plugins - -Plugins in this context refer to libraries registered with the PGG library to extend the functionality of the offiically generated Go code. The only officially supported extension is `grpc`, which generates the server and client code for services defined in the proto files. This is one of the best ways to extend the behavior of the already generated code within its package. - -PGS provides a `PluginBase` struct to simplify development of these plugins. Out of the box, it satisfies the interface for a `generator.Plugin`, only requiring the creation of the `Name` and `Generate` methods. `PluginBase` is best used as an anonymous embedded field of a wrapping `Plugin` implementation. A minimal plugin would look like the following: - -```go -// GraffitiPlugin tags the generated Go source -type graffitiPlugin struct { - *pgs.PluginBase - tag string -} - -// New configures the plugin with an instance of PluginBase and captures the tag -// which will be used during code generation. -func New(tag string) pgs.Plugin { - return &graffitiPlugin{ - PluginBase: new(pgs.PluginBase), - tag: tag, - } -} - -// Name is the identifier used in the protoc execution to enable this plugin for -// code generation. -func (p *graffitiPlugin) Name() string { return "graffiti" } - -// Generate is handed each file descriptor loaded by protoc, including -// dependencies not targeted for building. Don't worry though, the underlying -// library ensures that writes only occur for those specified by protoc. 
-func (p *graffitiPlugin) Generate(f *generator.FileDescriptor) { - p.Push(f.GetName()).Debug("tagging") - p.C(80, p.tag) - p.Pop() -} -``` - -`PluginBase` exposes a PGS [`BuildContext`][context] instance, already prefixed with the plugin's name. Calling `Push` and `Pop` allows adding further information to error and debugging messages. Above, the name of the file being generated is pushed onto the context before logging the "tagging" debug message. - -The base also provides helper methods for rendering into the output file. `p.P` prints a line with arbitrary arguments, similar to `fmt.Println`. `p.C` renders a comment in a similar fashion to `p.P` but intelligently wraps the comment to multiple lines at the specified width (above, 80 is used to wrap the supplied tag value). While `p.P` and `p.C` are very procedural, sometimes smarter generation is required: `p.T` renders Go templates (of either the `text` or `html` variety). - -Typically, plugins are registered globally, usually within an `init` method on the plugin's package, but PGS provides some utilities to facilitate development. When registering it with a PGS `Generator`, however, the `init` methodology should be avoided in favor of the following: - -```go -g := pgs.Init(pgs.IncludeGo()) -g.RegisterPlugin(graffiti.New("rodaine was here")) -``` - -`IncludeGo` must be specified or none of the official Go code will be generated. If the plugin also implements the PGS `Plugin` interface (which is achieved for free by composing over `PluginBase`), a shared pre-configured BuildContext will be provided to the plugin for consistent logging and error handling mechanisms. +The PG* library hides away nearly all of this complexity required to implement a protoc-plugin! 
### Modules -While plugins allow for injecting into the PGG generated code file-by-file, some code generation tasks require knowing the entire dependency graph of a proto file first or intend to create files on disk outside of the output directory specified by `protoc` (or with custom permissions). `Modules` fill this gap. - -PGS `Modules` are evaluated after the normal PGG flow and are handed a complete graph of the PB entities from the `gatherer` that are targeted for generation as well as all dependencies. A `Module` can then add files to the protoc `CodeGeneratorResponse` or write files directly to disk as `Artifacts`. +PG* `Modules` are handed a complete AST for those files that are targeted for generation as well as all dependencies. A `Module` can then add files to the protoc `CodeGeneratorResponse` or write files directly to disk as `Artifacts`. -PGS provides a `ModuleBase` struct to simplify developing modules. Out of the box, it satisfies the interface for a `Module`, only requiring the creation of `Name` and `Execute` methods. `ModuleBase` is best used as an anonyomous embedded field of a wrapping `Module` implementation. A minimal module would look like the following: +PG* provides a `ModuleBase` struct to simplify developing modules. Out of the box, it satisfies the interface for a `Module`, only requiring the creation of `Name` and `Execute` methods. `ModuleBase` is best used as an anonyomous embedded field of a wrapping `Module` implementation. A minimal module would look like the following: ```go -// ReportModule creates a report of all the target messages generated by the +// ReportModule creates a report of all the target messages generated by the // protoc run, writing the file into the /tmp directory. type reportModule struct { - *pgs.ModuleBase + *pgs.ModuleBase } // New configures the module with an instance of ModuleBase func New() pgs.Module { return &reportModule{&pgs.ModuleBase{}} } -// Name is the identifier used to identify the module. 
This value is +// Name is the identifier used to identify the module. This value is // automatically attached to the BuildContext associated with the ModuleBase. func (m *reportModule) Name() string { return "reporter" } -// Execute is passed the target pkg as well as its dependencies in the pkgs map. -// The implementation should return a slice of Artifacts that represent the -// files to be generated. In this case, "/tmp/report.txt" will be created +// Execute is passed the target files as well as its dependencies in the pkgs +// map. The implementation should return a slice of Artifacts that represent +// the files to be generated. In this case, "/tmp/report.txt" will be created // outside of the normal protoc flow. -func (m *reportModule) Execute(pkg pgs.Package, pkgs map[string]pgs.Package) []pgs.Artifact { - buf := &bytes.Buffer{} +func (m *reportModule) Execute(targets map[string]pgs.File, pkgs map[string]pgs.Package) []pgs.Artifact { + buf := &bytes.Buffer{} - for _, f := range pkg.Files() { - m.Push(f.Name().String()).Debug("reporting") + for _, f := range targets { + m.Push(f.Name().String()).Debug("reporting") - fmt.Fprintf(buf, "--- %v ---", f.Name()) - - for i, msg := range f.AllMessages() { - fmt.Fprintf(buf, "%03d. %v", msg.Name()) - } + fmt.Fprintf(buf, "--- %v ---", f.Name()) - m.Pop() - } + for i, msg := range f.AllMessages() { + fmt.Fprintf(buf, "%03d. %v\n", i, msg.Name()) + } - m.OverwriteCustomFile( - "/tmp/report.txt", - buf.String(), - 0644, - ) + m.Pop() + } - return m.Artifacts() + m.OverwriteCustomFile( + "/tmp/report.txt", + buf.String(), + 0644, + ) + + return m.Artifacts() } ``` -`ModuleBase` exposes a PGS [`BuildContext`][context] instance, already prefixed with the module's name. Calling `Push` and `Pop` allows adding further information to error and debugging messages. Above, each file from the target package is pushed onto the context before logging the "reporting" debug message. 
+`ModuleBase` exposes a PG* [`BuildContext`][context] instance, already prefixed with the module's name. Calling `Push` and `Pop` allows adding further information to error and debugging messages. Above, each file from the target package is pushed onto the context before logging the "reporting" debug message. The base also provides helper methods for adding or overwriting both protoc-generated and custom files. The above execute method creates a custom file at `/tmp/report.txt` specifying that it should overwrite an existing file with that name. If it instead called `AddCustomFile` and the file existed, no file would have been generated (though a debug message would be logged out). Similar methods exist for adding generator files, appends, and injections. Likewise, methods such as `AddCustomTemplateFile` allows for `Templates` to be rendered instead. -After all modules have been executed, the returned `Artifacts` are either placed into the `CodeGenerationResponse` payload for protoc or written out to the file system. For testing purposes, the file system has been abstracted such that a custom one (such as an in-memory FS) can be provided to the PGS generator with the `FileSystem` `InitOption`. - -Modules are registered with PGS similar to `Plugins`: - - ```go -g := pgs.Init(pgs.IncludeGo()) -g.RegisterModule(reporter.New()) -``` - -#### Multi-Package Aware Modules - -If the `MultiPackage` `InitOption` is enabled and multiple packages are passed into the PGS plugin, a `Module` can be upgraded to a `MultiModule` interface to support handling more than one target package simultaneously. Implementing this on the `reportModule` above might look like the following: - -```go -// MultiExecute satisfies the MultiModule interface. Instead of calling Execute -// and generating a file for each target package, the report can be written -// including all files from all packages in one. 
-func (m *reportModule) MultiExecute(targets map[string]Package, pkgs map[string]Package) []Artifact { - buf := &bytes.Buffer{} - - for _, pkg := range targets { - m.Push(pkg.Name().String()) - for _, f := range pkg.Files() { - m.Push(f.Name().String()).Debug("reporting") - - fmt.Fprintf(buf, "--- %v ---", f.Name()) - - for i, msg := range f.AllMessages() { - fmt.Fprintf(buf, "%03d. %v", msg.Name()) - } - - m.Pop() - } - m.Pop() - } - - m.OverwriteCustomFile( - "/tmp/report.txt", - buf.String(), - 0644, - ) - - return m.Artifacts() -} -``` - -Without `MultiExecute`, the module's `Execute` method would be called for each individual target `Package` processed. In the above example, the report file would be created for each, possibly overwriting each other. If a `Module` implements `MultiExecute`, however, the method recieves all target packages at once and can choose how to process them, in this case, creating a single report file for all. - -See the **Multi-Package Workflow** section below for more details. +After all modules have been executed, the returned `Artifacts` are either placed into the `CodeGenerationResponse` payload for protoc or written out to the file system. For testing purposes, the file system has been abstracted such that a custom one (such as an in-memory FS) can be provided to the PG* generator with the `FileSystem` `InitOption`. #### Post Processing -`Artifacts` generated by `Modules` sometimes require some mutations prior to writing to disk or sending in the reponse to protoc. This could range from running `gofmt` against Go source or adding copyright headers to all generated source files. To simplify this task in PGS, a `PostProcessor` can be utilized. A minimal looking `PostProcessor` implementation might look like this: +`Artifacts` generated by `Modules` sometimes require some mutations prior to writing to disk or sending in the response to protoc. 
This could range from running `gofmt` against Go source or adding copyright headers to all generated source files. To simplify this task in PG*, a `PostProcessor` can be utilized. A minimal looking `PostProcessor` implementation might look like this: ```go // New returns a PostProcessor that adds a copyright comment to the top @@ -245,76 +139,76 @@ See the **Multi-Package Workflow** section below for more details. func New(owner string) pgs.PostProcessor { return copyrightPostProcessor{owner} } type copyrightPostProcessor struct { - owner string + owner string } // Match returns true only for Custom and Generated files (including templates). func (cpp copyrightPostProcessor) Match(a pgs.Artifact) bool { - switch a := a.(type) { - case pgs.GeneratorFile, pgs.GeneratorTemplateFile, - pgs.CustomFile, pgs.CustomTemplateFile: - return true - default: - return false - } + switch a := a.(type) { + case pgs.GeneratorFile, pgs.GeneratorTemplateFile, + pgs.CustomFile, pgs.CustomTemplateFile: + return true + default: + return false + } } // Process attaches the copyright header to the top of the input bytes func (cpp copyrightPostProcessor) Process(in []byte) (out []byte, err error) { - cmt := fmt.Sprintf("// Copyright © %d %s. All rights reserved\n", - time.Now().Year(), - cpp.owner) + cmt := fmt.Sprintf("// Copyright © %d %s. All rights reserved\n", + time.Now().Year(), + cpp.owner) - return append([]byte(cmt), in...), nil + return append([]byte(cmt), in...), nil } -``` +``` -The `copyrightPostProcessor` struct satisfies the `PostProcessor` interface by implementing the `Match` and `Process` methods. After PGS recieves all `Artifacts`, each is handed in turn to each registered processor's `Match` method. In the above case, we return `true` if the file is a part of the targeted Artifact types. If `true` is returned, `Process` is immediately called with the rendered contents of the file. 
This method mutates the input, returning the modified value to out or an error if something goes wrong. Above, the notice is prepended to the input. +The `copyrightPostProcessor` struct satisfies the `PostProcessor` interface by implementing the `Match` and `Process` methods. After PG* receives all `Artifacts`, each is handed in turn to each registered processor's `Match` method. In the above case, we return `true` if the file is a part of the targeted Artifact types. If `true` is returned, `Process` is immediately called with the rendered contents of the file. This method mutates the input, returning the modified value to out or an error if something goes wrong. Above, the notice is prepended to the input. -PostProcessors are registered with PGS similar to `Plugins` and `Modules`: +PostProcessors are registered with PG* similar to `Modules`: ```go g := pgs.Init(pgs.IncludeGo()) g.RegisterModule(some.NewModule()) -g.RegisterPostProcessor(copyright.New("PGS Authors")) +g.RegisterPostProcessor(copyright.New("PG* Authors")) ``` ## Protocol Buffer AST -While `protoc` ensures that all the dependencies required to generate a proto file are loaded in as descriptors, it's up to the protoc-plugins to recognize the relationships between them. PGG handles this to some extent, but does not expose it in a easily accessible or testable manner outside of its sub-plugins and standard generation. To get around this, PGS uses the `gatherer` plugin to construct an abstract syntax tree (AST) of all the `Entities` loaded into the plugin. This AST is provided to every `Module` to facilitate code generation. +While `protoc` ensures that all the dependencies required to generate a proto file are loaded in as descriptors, it's up to the protoc-plugins to recognize the relationships between them. To get around this, PG* constructs an abstract syntax tree (AST) of all the `Entities` loaded into the plugin. This AST is provided to every `Module` to facilitate code generation. 
### Hierarchy -The hierarchy generated by the PGS `gatherer` is fully linked, starting at a top-level `Package` down to each individual `Field` of a `Message`. The AST can be represented with the following digraph: +The hierarchy generated by the PG* AST is fully linked, starting at a top-level `Package` down to each individual `Field` of a `Message`. The AST can be represented with the following digraph:

-A `Package` describes a set of `Files` loaded within the same namespace. As would be expected, a `File` represents a single proto file, which contains any number of `Message`, `Enum` or `Service` entities. An `Enum` describes an integer-based enumeration type, containing each individual `EnumValue`. A `Service` describes a set of RPC `Methods`, which in turn refer to their input and output `Messages`. +A `Package` describes a set of `Files` loaded within the same namespace. As would be expected, a `File` represents a single proto file, which contains any number of `Message`, `Enum` or `Service` entities. An `Enum` describes an integer-based enumeration type, containing each individual `EnumValue`. A `Service` describes a set of RPC `Methods`, which in turn refer to their input and output `Messages`. A `Message` can contain other nested `Messages` and `Enums` as well as each of its `Fields`. For non-scalar types, a `Field` may also reference its `Message` or `Enum` type. As a mechanism for achieving union types, a `Message` can also contain `OneOf` entities that refer to some of its `Fields`. ### Visitor Pattern -The structure of the AST can be fairly complex and unpredictable. Likewise, `Module's` are typically concerned with only a subset of the entities in the graph. To separate the `Module's` algorithm from understanding and traversing the structure of the AST, PGS implements the `Visitor` pattern to decouple the two. Implementing this interface is straightforward and can greatly simplify code generation. +The structure of the AST can be fairly complex and unpredictable. Likewise, `Module's` are typically concerned with only a subset of the entities in the graph. To separate the `Module's` algorithm from understanding and traversing the structure of the AST, PG* implements the `Visitor` pattern to decouple the two. Implementing this interface is straightforward and can greatly simplify code generation. 
-Two base `Visitor` structs are provided by PGS to simplify developing implementations. First, the `NilVisitor` returns an instance that short-circuits execution for all Entity types. This is useful when certain branches of the AST are not interesting to code generation. For instance, if the `Module` is only concerned with `Services`, it can use a `NilVisitor` as an anonymous field and only implement the desired interface methods: +Two base `Visitor` structs are provided by PG* to simplify developing implementations. First, the `NilVisitor` returns an instance that short-circuits execution for all Entity types. This is useful when certain branches of the AST are not interesting to code generation. For instance, if the `Module` is only concerned with `Services`, it can use a `NilVisitor` as an anonymous field and only implement the desired interface methods: ```go // ServiceVisitor logs out each Method's name type serviceVisitor struct { - pgs.Visitor - pgs.DebuggerCommon + pgs.Visitor + pgs.DebuggerCommon } -func New(d pgs.DebuggerCommon) pgs.Visitor { - return serviceVistor{ - Visitor: pgs.NilVisitor(), - DebuggerCommon: d, - } +func New(d pgs.DebuggerCommon) pgs.Visitor { + return serviceVisitor{ + Visitor: pgs.NilVisitor(), + DebuggerCommon: d, + } } -// Passthrough Packages, Files, and Services. All other methods can be -// ignored since Services can only live in Files and Files can only live in a +// Passthrough Packages, Files, and Services. All other methods can be +// ignored since Services can only live in Files and Files can only live in a // Package. func (v serviceVisitor) VisitPackage(pgs.Package) (pgs.Visitor, error) { return v, nil } func (v serviceVisitor) VisitFile(pgs.File) (pgs.Visitor, error) { return v, nil } @@ -322,8 +216,8 @@ func (v serviceVisitor) VisitService(pgs.Service) (pgs.Visitor, error) { return // VisitMethod logs out ServiceName#MethodName for m. 
func (v serviceVisitor) VisitMethod(m pgs.Method) (pgs.Vistitor, error) { - v.Logf("%v#%v", m.Service().Name(), m.Name()) - return nil, nil + v.Logf("%v#%v", m.Service().Name(), m.Name()) + return nil, nil } ``` @@ -331,19 +225,19 @@ If access to deeply nested `Nodes` is desired, a `PassthroughVisitor` can be use ```go type fieldVisitor struct { - pgs.Visitor - pgs.DebuggerCommon + pgs.Visitor + pgs.DebuggerCommon } func New(d pgs.DebuggerCommon) pgs.Visitor { - v := &fieldVisitor{DebuggerCommon: d} - v.Visitor = pgs.PassThroughVisitor(v) - return v + v := &fieldVisitor{DebuggerCommon: d} + v.Visitor = pgs.PassThroughVisitor(v) + return v } func (v *fieldVisitor) VisitField(f pgs.Field) (pgs.Visitor, error) { - v.Logf("%v.%v", f.Message().Name(), f.Name()) - return nil, nil + v.Logf("%v.%v", f.Message().Name(), f.Name()) + return nil, nil } ``` @@ -358,13 +252,13 @@ All `Entity` types and `Package` can be passed into `Walk`, allowing for startin ## Build Context -`Plugins` and `Modules` registered with the PGS `Generator` are initialized with an instance of `BuildContext` that encapsulates contextual paths, debugging, and parameter information. +`Modules` registered with the PG* `Generator` are initialized with an instance of `BuildContext` that encapsulates contextual paths, debugging, and parameter information. ### Output Paths -The `BuildContext's` `OutputPath` method returns the output directory that the PGS plugin is targeting. For `Plugins`, this path is initially `.` and is relative to the generation output directory specified in the protoc execution. For `Modules`, this path is also initially `.` but refers to the directory in which `protoc` is executed. This default behavior can be overridden for `Modules` by providing an `output_path` in the flag. +The `BuildContext's` `OutputPath` method returns the output directory that the PG* plugin is targeting. This path is also initially `.` but refers to the directory in which `protoc` is executed. 
This default behavior can be overridden by providing an `output_path` in the flag. -This value can be used to create file names for `Artifacts`, using `JoinPath(name ...string)` which is essentially an alias for `filepath.Join(ctx.Outpath, name...)`. Manually tracking directories relative to the `OutputPath` can be tedious, especially if the names are dynamic. Instead, a `BuildContext` can manage these, via `PushDir` and `PopDir`. +The `OutputPath` can be used to create file names for `Artifacts`, using `JoinPath(name ...string)` which is essentially an alias for `filepath.Join(ctx.OutputPath(), name...)`. Manually tracking directories relative to the `OutputPath` can be tedious, especially if the names are dynamic. Instead, a `BuildContext` can manage these, via `PushDir` and `PopDir`. ```go ctx.OutputPath() // foo @@ -378,65 +272,47 @@ ctx = ctx.PopDir() ctx.OutputPath() // foo ``` -Both `PluginBase` and `ModuleBase` wrap these methods to mutate their underlying `BuildContexts`. Those methods should be used instead of the ones on the contained `BuildContext` directly. +`ModuleBase` wraps these methods to mutate their underlying `BuildContexts`. Those methods should be used instead of the ones on the contained `BuildContext` directly. ### Debugging -The `BuildContext` exposes a `DebuggerCommon` interface which provides utilities for logging, error checking, and assertions. `Log` and the formatted `Logf` print messages to `os.Stderr`, typically prefixed with the `Plugin` or `Module` name. `Debug` and `Debugf` behave the same, but only print if enabled via the `DebugMode` or `DebugEnv` `InitOptions`. +The `BuildContext` exposes a `DebuggerCommon` interface which provides utilities for logging, error checking, and assertions. `Log` and the formatted `Logf` print messages to `os.Stderr`, typically prefixed with the `Module` name. `Debug` and `Debugf` behave the same, but only print if enabled via the `DebugMode` or `DebugEnv` `InitOptions`. 
`Fail` and `Failf` immediately stops execution of the protoc-plugin and causes `protoc` to fail generation with the provided message. `CheckErr` and `Assert` also fail with the provided messages if an error is passed in or if an expression evaluates to false, respectively. -Additional contextual prefixes can be provided by calling `Push` and `Pop` on the `BuildContext`. This behavior is similar to `PushDir` and `PopDir` but only impacts log messages. Both `PluginBase` and `ModuleBase` wrap these methods to mutate their underlying `BuildContexts`. Those methods should be used instead of the ones on the contained `BuildContext` directly. +Additional contextual prefixes can be provided by calling `Push` and `Pop` on the `BuildContext`. This behavior is similar to `PushDir` and `PopDir` but only impacts log messages. `ModuleBase` wraps these methods to mutate their underlying `BuildContexts`. Those methods should be used instead of the ones on the contained `BuildContext` directly. ### Parameters -The `BuildContext` also provides access to the pre-processed `Parameters` from the specified protoc flag. PGG allows for certain KV pairs in the parameters body, such as "plugins", "import_path", and "import_prefix" as well as import maps with the "M" prefix. PGS exposes these plus typed access to any other KV pairs passed in. The only PGS-specific key expected is "output_path", which is utilized by the a module's `BuildContext` for its `OutputPath`. - -PGS permits mutating the `Parameters` via the `MutateParams` `InitOption`. By passing in a `ParamMutator` function here, these KV pairs can be modified or verified prior to the PGG workflow begins. +The `BuildContext` also provides access to the pre-processed `Parameters` from the specified protoc flag. The only PG*-specific key expected is "output_path", which is utilized by a module's `BuildContext` for its `OutputPath`. -## Execution Workflows +PG* permits mutating the `Parameters` via the `MutateParams` `InitOption`. 
By passing in a `ParamMutator` function here, these KV pairs can be modified or verified before the PG* workflow begins. -## Execution Workflows +## Language-Specific Subpackages -Internally, PGS determines its behavior based off workflows. These are not publicly exposed to the API but can be modified based off `InitOptions` when initializing the `Generator`. -### Standard Workflow +While implemented in Go, PG* seeks to be language agnostic in what it can do. Therefore, beyond the pre-generated base descriptor types, PG* has no dependencies on the protoc-gen-go (PGG) package. However, there are many nuances that each language's protoc-plugin introduces that can be generalized. For instance, PGG package naming, import paths, and output paths are a complex interaction of the proto package name, the `go_package` file option, and parameters passed to protoc. While PG*'s core API should not be overloaded with many language-specific methods, subpackages can be provided that can operate on `Parameters` and `Entities` to derive the appropriate results. -The standard workflow follows the steps described above in **The `protoc` Flow**. This is the out-of-the-box behavior of PGS-based plugins. +PG* currently implements the [pgsgo](https://godoc.org/github.com/lyft/protoc-gen-star/lang/go/) subpackage to provide these utilities to plugins targeting the Go language. Future subpackages are planned to support a variety of languages. -### Multi-Package Workflow +## PG* Development & Make Targets -Due to [purely philosophical reasons][single], PGG does not support passing in more than one package (ie, directory) of proto files at a time. In most circumstances, this is OK (if a bit annoying), however there are some generation patterns that may require loading in multiple packages/directories of protos simultaneously. By enabling this workflow, a PGS plugin will support running against multiple packages. 
- -This is achieved by splitting the `CodeGeneratorRequest` into multiple sub-requests, spawning a handful of child processes of the PGS plugin, and executing the PGG workflow against each sub-request independently. The parent process acts like `protoc` in this case, and captures the response of these before merging them together into a single `CodeGeneratorResponse`. `Modules` are not executed in the child processes; instead, the parent process executes them. If a `Module` implements the `MultiModule` interface, the `MultiExecute` method will be called with _all_ target `Packages` simultaneously. Otherwise, the `Execute` method is called separately for each target `Package`. - -**CAVEATS:** This workflow significantly changes the behavior from the Standard workflow and should be considered experimental. Also, the `ProtocInput` `InitOption` cannot be specified alongside this workflow. Changing the input will prevent the sub-requests from being properly executed. (A future update may make this possible.) Only enable this option if your plugin necessitates multi-package support. - -To enable this workflow, pass the `MultiPackage` `InitOption` to `Init`. - -### Exclude Go Workflow - -It is not always desirable for a PGS plugin to also generate the official Go source code coming from the PGG library (eg, when not generating Go code). In fact, by default, these files are not generated by PGS plugins. This is achieved by this workflow which decorates another workflow (typically, Standard or Multi-Package) to remove these files from the set of generated files. - -To disable this workflow, pass the `IncludeGo` `InitOption` to `Init`. - -## PGS Development & Make Targets - -PGS seeks to provide all the tools necessary to rapidly and ergonomically extend and build on top of the Protocol Buffer IDL. 
Whether the goal is to modify the official protoc-gen-go output or create entirely new files and packages, this library should offer a user-friendly wrapper around the complexities of the PB descriptors and the protoc-plugin workflow. +PG* seeks to provide all the tools necessary to rapidly and ergonomically extend and build on top of the Protocol Buffer IDL. Whether the goal is to modify the official protoc-gen-go output or create entirely new files and packages, this library should offer a user-friendly wrapper around the complexities of the PB descriptors and the protoc-plugin workflow. ### Setup -For developing on PGS, you should install the package within the `GOPATH`. PGS uses [glide][glide] for dependency management. +For developing on PG*, you should install the package within the `GOPATH`. PG* uses [glide][glide] for dependency management. ```sh go get -u github.com/lyft/protoc-gen-star cd $GOPATH/github.com/lyft/protoc-gen-star -make install +make vendor ``` To upgrade dependencies, please make the necessary modifications in `glide.yaml` and run `glide update`. ### Linting & Static Analysis -To avoid style nits and also to enforce some best practices for Go packages, PGS requires passing `golint`, `go vet`, and `go fmt -s` for all code changes. +To avoid style nits and also to enforce some best practices for Go packages, PG* requires passing `golint`, `go vet`, and `go fmt -s` for all code changes. ```sh make lint @@ -444,38 +320,49 @@ make lint ### Testing -PGS strives to have near 100% code coverage by unit tests. Most unit tests are run in parallel to catch potential race conditions. There are three ways of running unit tests, each taking longer than the next but providing more insight into test coverage: +PG* strives to have near 100% code coverage by unit tests. Most unit tests are run in parallel to catch potential race conditions. 
There are three ways of running unit tests, each taking longer than the next but providing more insight into test coverage: ```sh +# run code generation for the data used by the tests +make testdata + # run unit tests without race detection or code coverage reporting -make quick +make quick # run unit tests with race detection and code coverage -make tests +make tests # run unit tests with race detection and generates a code coverage report, opening in a browser -make cover +make cover ``` +#### protoc-gen-debug + +PG* comes with a specialized protoc-plugin, `protoc-gen-debug`. This plugin captures the CodeGeneratorRequest from a protoc execution and saves the serialized PB to disk. These files can be used as inputs to prevent calling protoc from tests. + ### Documentation -As PGS is intended to be an open-source utility, good documentation is important for consumers. Go is a self-documenting language, and provides a built in utility to view locally: `godoc`. The following command starts a godoc server and opens a browser window to this package's documentation. If you see a 404 or unavailable page initially, just refresh. +Go is a self-documenting language, and provides a built in utility to view locally: `godoc`. The following command starts a godoc server and opens a browser window to this package's documentation. If you see a 404 or unavailable page initially, just refresh. ```sh make docs ``` -#### Demo +### Demo -PGS comes with a "kitchen sink" example: [`protoc-gen-example`][pge]. This protoc plugin built on top of PGS prints out the target package's AST as a tree to stderr. This provides an end-to-end way of validating each of the nuanced types and nesting in PB descriptors: +PG* comes with a "kitchen sink" example: [`protoc-gen-example`][pge]. This protoc plugin built on top of PG* prints out the target package's AST as a tree to stderr. 
This provides an end-to-end way of validating each of the nuanced types and nesting in PB descriptors: ```sh -make demo +# create the example PG*-based plugin +make bin/protoc-gen-example + +# run protoc-gen-example against the demo protos +make testdata/generated ``` #### CI -PGS uses [TravisCI][travis] to validate all code changes. Please view the [configuration][travis.yml] for what tests are involved in the validation. +PG* uses [TravisCI][travis] to validate all code changes. Please view the [configuration][travis.yml] for what tests are involved in the validation. [glide]: http://glide.sh [pgg]: https://github.com/golang/protobuf/tree/master/protoc-gen-go diff --git a/artifact.go b/artifact.go index b95ea4b..91d4fa5 100644 --- a/artifact.go +++ b/artifact.go @@ -3,6 +3,7 @@ package pgs import ( "bytes" "errors" + "io" "os" "path/filepath" "strings" @@ -17,7 +18,15 @@ type Artifact interface { artifact() } +// A Template to use for rendering artifacts. Either text/template or +// html/template Template types satisfy this interface. +type Template interface { + Execute(w io.Writer, data interface{}) error +} + // GeneratorArtifact describes an Artifact that uses protoc for code generation. +// GeneratorArtifacts must be valid UTF8. To create binary files, use one of +// the "custom" Artifact types. type GeneratorArtifact interface { Artifact @@ -115,8 +124,8 @@ func (f GeneratorTemplateFile) ProtoFile() (*plugin_go.CodeGeneratorResponse_Fil } // A GeneratorAppend Artifact appends content to the end of the specified protoc -// generated file. This Artifact can only be used if another Plugin or Module -// generates a file with the same name. +// generated file. This Artifact can only be used if another Module generates a +// file with the same name. 
type GeneratorAppend struct { GeneratorArtifact @@ -141,7 +150,7 @@ func (f GeneratorAppend) ProtoFile() (*plugin_go.CodeGeneratorResponse_File, err }, nil } -// A GeneratorTemplateAppend appends content to a protoc generated file from a +// A GeneratorTemplateAppend appends content to a protoc-generated file from a // Template. See GeneratorAppend for limitations. type GeneratorTemplateAppend struct { GeneratorArtifact @@ -170,10 +179,10 @@ func (f GeneratorTemplateAppend) ProtoFile() (*plugin_go.CodeGeneratorResponse_F }, nil } -// A GeneratorInjection Artifact inserts content into a protoc generated file -// at the specified insertion point. The target file does not need to generated -// by this protoc-plugin but must be generated by an prior plugin executed by -// protoc. +// A GeneratorInjection Artifact inserts content into a protoc-generated file +// at the specified insertion point. The target file does not need to be +// generated by this protoc-plugin but must be generated by a prior plugin +// executed by protoc. type GeneratorInjection struct { GeneratorArtifact @@ -205,9 +214,9 @@ func (f GeneratorInjection) ProtoFile() (*plugin_go.CodeGeneratorResponse_File, } // A GeneratorTemplateInjection Artifact inserts content rendered from a -// Template into a protoc generated file at the specified insertion point. The -// target file does not need to generated by this protoc-plugin but must be -// generated by an prior plugin executed by protoc. +// Template into protoc-generated file at the specified insertion point. The +// target file does not need to be generated by this protoc-plugin but must be +// generated by a prior plugin executed by protoc. 
type GeneratorTemplateInjection struct { GeneratorArtifact TemplateArtifact diff --git a/ast.go b/ast.go new file mode 100644 index 0000000..65a6edf --- /dev/null +++ b/ast.go @@ -0,0 +1,351 @@ +package pgs + +import ( + "github.com/golang/protobuf/protoc-gen-go/descriptor" + "github.com/golang/protobuf/protoc-gen-go/plugin" +) + +// AST encapsulates the entirety of the input CodeGeneratorRequest from protoc, +// parsed to build the Entity graph used by PG*. +type AST interface { + // Targets returns a map of the files specified in the protoc execution. For + // all Entities contained in these files, BuildTarget will return true. + Targets() map[string]File + + // Packages returns all the imported packages (including those for the target + // Files). This is limited to just the files that were imported by target + // protos, either directly or transitively. + Packages() map[string]Package + + // Lookup allows getting an Entity from the graph by its fully-qualified name + // (FQN). The FQN uses dot notation of the form ".{package}.{entity}", or the + // input path for Files. + Lookup(name string) (Entity, bool) +} + +type graph struct { + d Debugger + + targets map[string]File + packages map[string]Package + entities map[string]Entity +} + +func (g *graph) Targets() map[string]File { return g.targets } + +func (g *graph) Packages() map[string]Package { return g.packages } + +func (g *graph) Lookup(name string) (Entity, bool) { + e, ok := g.entities[name] + return e, ok +} + +// ProcessDescriptors converts a CodeGeneratorRequest from protoc into a fully +// connected AST entity graph. An error is returned if the input is malformed. 
+func ProcessDescriptors(debug Debugger, req *plugin_go.CodeGeneratorRequest) AST { + g := &graph{ + d: debug, + targets: make(map[string]File, len(req.GetFileToGenerate())), + packages: make(map[string]Package), + entities: make(map[string]Entity), + } + + for _, f := range req.GetFileToGenerate() { + g.targets[f] = nil + } + + for _, f := range req.GetProtoFile() { + pkg := g.hydratePackage(f) + pkg.addFile(g.hydrateFile(pkg, f)) + } + + return g +} + +func (g *graph) hydratePackage(f *descriptor.FileDescriptorProto) Package { + lookup := f.GetPackage() + if pkg, exists := g.packages[lookup]; exists { + return pkg + } + + p := &pkg{fd: f} + g.packages[lookup] = p + + return p +} + +func (g *graph) hydrateFile(pkg Package, f *descriptor.FileDescriptorProto) File { + fl := &file{ + pkg: pkg, + desc: f, + } + g.add(fl) + + if _, fl.buildTarget = g.targets[f.GetName()]; fl.buildTarget { + g.targets[f.GetName()] = fl + } + + enums := f.GetEnumType() + fl.enums = make([]Enum, 0, len(enums)) + for _, e := range enums { + fl.addEnum(g.hydrateEnum(fl, e)) + } + + msgs := f.GetMessageType() + fl.msgs = make([]Message, 0, len(f.GetMessageType())) + for _, msg := range msgs { + fl.addMessage(g.hydrateMessage(fl, msg)) + } + + srvs := f.GetService() + fl.srvs = make([]Service, 0, len(srvs)) + for _, sd := range srvs { + fl.addService(g.hydrateService(fl, sd)) + } + + for _, m := range fl.AllMessages() { + for _, me := range m.MapEntries() { + for _, fld := range me.Fields() { + fld.addType(g.hydrateFieldType(fld)) + } + } + + for _, fld := range m.Fields() { + fld.addType(g.hydrateFieldType(fld)) + } + } + + g.hydrateSourceCodeInfo(fl, f) + + return fl +} + +func (g *graph) hydrateSourceCodeInfo(f File, fd *descriptor.FileDescriptorProto) { + for _, loc := range fd.GetSourceCodeInfo().GetLocation() { + info := sci{desc: loc} + path := loc.GetPath() + + if len(path) == 1 { + switch path[0] { + case syntaxPath: + f.addSourceCodeInfo(info) + case packagePath: + 
f.addPackageSourceCodeInfo(info) + default: + continue + } + } + + if e := f.childAtPath(path); e != nil { + e.addSourceCodeInfo(info) + } + } +} + +func (g *graph) hydrateEnum(p ParentEntity, ed *descriptor.EnumDescriptorProto) Enum { + e := &enum{ + desc: ed, + parent: p, + } + g.add(e) + + vals := ed.GetValue() + e.vals = make([]EnumValue, 0, len(vals)) + for _, vd := range vals { + e.addValue(g.hydrateEnumValue(e, vd)) + } + + return e +} + +func (g *graph) hydrateEnumValue(e Enum, vd *descriptor.EnumValueDescriptorProto) EnumValue { + ev := &enumVal{ + desc: vd, + enum: e, + } + g.add(ev) + + return ev +} + +func (g *graph) hydrateService(f File, sd *descriptor.ServiceDescriptorProto) Service { + s := &service{ + desc: sd, + file: f, + } + g.add(s) + + for _, md := range sd.GetMethod() { + s.addMethod(g.hydrateMethod(s, md)) + } + + return s +} + +func (g *graph) hydrateMethod(s Service, md *descriptor.MethodDescriptorProto) Method { + m := &method{ + desc: md, + service: s, + } + g.add(m) + + m.in = g.mustSeen(md.GetInputType()).(Message) + m.out = g.mustSeen(md.GetOutputType()).(Message) + + return m +} + +func (g *graph) hydrateMessage(p ParentEntity, md *descriptor.DescriptorProto) Message { + m := &msg{ + desc: md, + parent: p, + } + g.add(m) + + for _, ed := range md.GetEnumType() { + m.addEnum(g.hydrateEnum(m, ed)) + } + + m.preservedMsgs = make([]Message, len(md.GetNestedType())) + for i, nmd := range md.GetNestedType() { + nm := g.hydrateMessage(m, nmd) + if nm.IsMapEntry() { + m.addMapEntry(nm) + } else { + m.addMessage(nm) + } + m.preservedMsgs[i] = nm + } + + for _, od := range md.GetOneofDecl() { + m.addOneOf(g.hydrateOneOf(m, od)) + } + + for _, fd := range md.GetField() { + fld := g.hydrateField(m, fd) + m.addField(fld) + + if idx := fld.Descriptor().OneofIndex; idx != nil { + m.oneofs[*idx].addField(fld) + } + } + + return m +} + +func (g *graph) hydrateField(m Message, fd *descriptor.FieldDescriptorProto) Field { + f := &field{ + desc: fd, + 
msg: m, + } + g.add(f) + + return f +} + +func (g *graph) hydrateOneOf(m Message, od *descriptor.OneofDescriptorProto) OneOf { + o := &oneof{ + desc: od, + msg: m, + } + g.add(o) + + return o +} + +func (g *graph) hydrateFieldType(fld Field) FieldType { + s := &scalarT{fld: fld} + + switch { + case s.ProtoType() == GroupT: + g.d.Fail("group types are deprecated and unsupported. Use an embedded message instead.") + return nil + case s.ProtoLabel() == Repeated: + return g.hydrateRepeatedFieldType(s) + case s.ProtoType() == EnumT: + return g.hydrateEnumFieldType(s) + case s.ProtoType() == MessageT: + return g.hydrateEmbedFieldType(s) + default: + return s + } +} + +func (g *graph) hydrateEnumFieldType(s *scalarT) FieldType { + return &enumT{ + scalarT: s, + enum: g.mustSeen(s.fld.Descriptor().GetTypeName()).(Enum), + } +} + +func (g *graph) hydrateEmbedFieldType(s *scalarT) FieldType { + return &embedT{ + scalarT: s, + msg: g.mustSeen(s.fld.Descriptor().GetTypeName()).(Message), + } +} + +func (g *graph) hydrateRepeatedFieldType(s *scalarT) FieldType { + r := &repT{ + scalarT: s, + } + r.el = &scalarE{ + typ: r, + ptype: r.ProtoType(), + } + + switch s.ProtoType() { + case EnumT: + r.el = &enumE{ + scalarE: r.el.(*scalarE), + enum: g.mustSeen(s.fld.Descriptor().GetTypeName()).(Enum), + } + case MessageT: + m := g.mustSeen(s.fld.Descriptor().GetTypeName()).(Message) + if m.IsMapEntry() { + return g.hydrateMapFieldType(r, m) + } + + r.el = &embedE{ + scalarE: r.el.(*scalarE), + msg: m, + } + } + + return r +} + +func (g *graph) hydrateMapFieldType(r *repT, m Message) FieldType { + mt := &mapT{repT: r} + + mt.key = m.Fields()[0].Type().toElem() + mt.key.setType(mt) + + mt.el = m.Fields()[1].Type().toElem() + mt.el.setType(mt) + + return mt +} + +func (g *graph) mustSeen(fqn string) Entity { + if existing, seen := g.entities[fqn]; seen { + return existing + } + + g.d.Failf("expected entity %q has not been hydrated", fqn) + return nil +} + +func (g *graph) add(e Entity) { 
+ g.entities[g.resolveFQN(e)] = e +} + +func (g *graph) resolveFQN(e Entity) string { + if f, ok := e.(File); ok { + return f.Name().String() + } + + return e.FullyQualifiedName() +} + +var _ AST = (*graph)(nil) diff --git a/ast_test.go b/ast_test.go new file mode 100644 index 0000000..4edbfc4 --- /dev/null +++ b/ast_test.go @@ -0,0 +1,260 @@ +package pgs + +import ( + "io/ioutil" + "path/filepath" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/protoc-gen-go/plugin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func readCodeGenReq(t *testing.T, dir string) *plugin_go.CodeGeneratorRequest { + filename := filepath.Join("testdata", "graph", dir, "code_generator_request.pb.bin") + + data, err := ioutil.ReadFile(filename) + require.NoError(t, err, "unable to read CDR at %q", filename) + + req := &plugin_go.CodeGeneratorRequest{} + err = proto.Unmarshal(data, req) + require.NoError(t, err, "unable to unmarshal CDR data at %q", filename) + + return req +} + +func buildGraph(t *testing.T, dir string) AST { + d := InitMockDebugger() + ast := ProcessDescriptors(d, readCodeGenReq(t, dir)) + require.False(t, d.Failed(), "failed to build graph (see previous log statements)") + return ast +} + +func TestGraph_Messages(t *testing.T) { + t.Parallel() + g := buildGraph(t, "messages") + + tests := []struct { + lookup string + fldCt int + isMap, isRepeated, isEmbed, isEnum bool + }{ + { + lookup: ".graph.messages.Scalars", + fldCt: 15, + }, + { + lookup: ".graph.messages.Embedded", + fldCt: 6, + isEmbed: true, + }, + { + lookup: ".graph.messages.Enums", + fldCt: 6, + isEnum: true, + }, + { + lookup: ".graph.messages.Repeated", + fldCt: 13, + isRepeated: true, + }, + { + lookup: ".graph.messages.Maps", + fldCt: 13, + isMap: true, + }, + { + lookup: ".graph.messages.Recursive", + fldCt: 1, + isEmbed: true, + }, + } + + for _, test := range tests { + tc := test + t.Run(tc.lookup, func(t *testing.T) { + t.Parallel() 
+ + ent, ok := g.Lookup(tc.lookup) + require.True(t, ok, "unknown entity lookup") + msg, ok := ent.(Message) + require.True(t, ok, "entity is not a message") + + flds := msg.Fields() + assert.Len(t, flds, tc.fldCt, "unexpected number of fields on the message") + + for _, fld := range flds { + t.Run(fld.Name().String(), func(t *testing.T) { + typ := fld.Type() + assert.Equal(t, tc.isMap, typ.IsMap(), "should not be a map") + assert.Equal(t, tc.isRepeated, typ.IsRepeated(), "should not be repeated") + assert.Equal(t, tc.isEmbed, typ.IsEmbed(), "should not be embedded") + assert.Equal(t, tc.isEnum, typ.IsEnum(), "should not be an enum") + }) + } + }) + } + + t.Run("oneof", func(t *testing.T) { + t.Parallel() + + ent, ok := g.Lookup(".graph.messages.OneOfs") + require.True(t, ok) + msg, ok := ent.(Message) + require.True(t, ok) + + flds := msg.Fields() + oneofFlds := msg.OneOfFields() + notOneofFlds := msg.NonOneOfFields() + + assert.Len(t, flds, 3) + assert.Len(t, oneofFlds, 1) + assert.Len(t, notOneofFlds, 2) + + oneofs := msg.OneOfs() + require.Len(t, oneofs, 1) + + oo := oneofs[0] + require.Len(t, oo.Fields(), 1) + assert.Equal(t, int32(2), oo.Fields()[0].Descriptor().GetNumber()) + assert.Equal(t, oneofFlds, oo.Fields()) + }) +} + +func TestGraph_Services(t *testing.T) { + t.Parallel() + + g := buildGraph(t, "services") + + t.Run("empty", func(t *testing.T) { + t.Parallel() + + ent, ok := g.Lookup(".graph.services.Empty") + require.True(t, ok) + svc, ok := ent.(Service) + require.True(t, ok) + + assert.Empty(t, svc.Methods()) + }) + + t.Run("unary", func(t *testing.T) { + t.Parallel() + + ent, ok := g.Lookup(".graph.services.Unary") + require.True(t, ok) + svc, ok := ent.(Service) + require.True(t, ok) + + mtds := svc.Methods() + assert.Len(t, mtds, 2) + + for _, mtd := range mtds { + assert.False(t, mtd.ClientStreaming(), mtd.FullyQualifiedName()) + assert.False(t, mtd.ServerStreaming(), mtd.FullyQualifiedName()) + } + }) + + t.Run("streaming", func(t *testing.T) 
{ + t.Parallel() + + ent, ok := g.Lookup(".graph.services.Streaming") + require.True(t, ok) + svc, ok := ent.(Service) + require.True(t, ok) + + mtds := svc.Methods() + assert.Len(t, mtds, 3) + + tests := []struct{ client, server bool }{ + {true, false}, + {false, true}, + {true, true}, + } + + for i, mtd := range mtds { + assert.Equal(t, tests[i].client, mtd.ClientStreaming(), mtd.FullyQualifiedName()) + assert.Equal(t, tests[i].server, mtd.ServerStreaming(), mtd.FullyQualifiedName()) + } + }) +} + +func TestGraph_SourceCodeInfo(t *testing.T) { + t.Parallel() + + g := buildGraph(t, "info") + + tests := map[string]string{ + "Info": "root message", + "Info.Before": "before message", + "Info.BeforeEnum.BEFORE": "before enum value", + "Info.field": "field", + "Info.Middle": "middle message", + "Info.Middle.inner": "inner field", + "Info.other_field": "other field", + "Info.After": "after message", + "Info.AfterEnum": "after enum", + "Info.AfterEnum.AFTER": "after enum value", + "Info.OneOf": "oneof", + "Info.oneof_field": "oneof field", + "Enum": "root enum comment", + "Enum.ROOT": "root enum value", + "Service": "service", + "Service.Method": "method", + } + + for lookup, expected := range tests { + t.Run(lookup, func(t *testing.T) { + lo := ".graph.info." 
+ lookup + ent, ok := g.Lookup(lo) + require.True(t, ok, "cannot find entity: %s", lo) + info := ent.SourceCodeInfo() + require.NotNil(t, info, "source code info is nil") + assert.Contains(t, info.LeadingComments(), expected, "invalid leading comment") + }) + } + + t.Run("file", func(t *testing.T) { + f, ok := g.Targets()["info/info.proto"] + require.True(t, ok, "cannot find file") + + info := f.SyntaxSourceCodeInfo() + require.NotNil(t, info, "no source code info on syntax") + assert.Contains(t, info.LeadingComments(), "syntax") + assert.Equal(t, info, f.SourceCodeInfo(), "SourceCodeInfo should return SyntaxSourceCodeInfo") + + info = f.PackageSourceCodeInfo() + require.NotNil(t, info, "no source code info on package") + assert.Contains(t, info.LeadingComments(), "package") + }) +} + +func TestGraph_MustSeen(t *testing.T) { + t.Parallel() + + md := InitMockDebugger() + g := &graph{ + d: md, + entities: make(map[string]Entity), + } + + f := dummyFile() + g.add(f) + + assert.Equal(t, f, g.mustSeen(g.resolveFQN(f))) + assert.Nil(t, g.mustSeen(".foo.bar.baz")) + assert.True(t, md.Failed()) +} + +func TestGraph_HydrateFieldType_Group(t *testing.T) { + t.Parallel() + + md := InitMockDebugger() + g := &graph{d: md} + + f := dummyField() + f.Descriptor().Type = GroupT.ProtoPtr() + + assert.Nil(t, g.hydrateFieldType(f)) + assert.True(t, md.Failed()) +} diff --git a/build_context_test.go b/build_context_test.go index 63824c6..f651a35 100644 --- a/build_context_test.go +++ b/build_context_test.go @@ -50,62 +50,62 @@ func TestPrefixContext_Debugf(t *testing.T) { func TestPrefixContext_Fail(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() c := initPrefixContext(nil, d, "foo") c.Fail("bar") - assert.True(t, d.failed) + assert.True(t, d.Failed()) } func TestPrefixContext_Failf(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() c := initPrefixContext(nil, d, "foo") c.Failf("bar %s", "baz") - assert.True(t, d.failed) + 
assert.True(t, d.Failed()) } func TestPrefixContext_CheckErr(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() c := initPrefixContext(nil, d, "foo") c.CheckErr(nil) - assert.False(t, d.failed) + assert.False(t, d.Failed()) err := errors.New("bar") c.CheckErr(err) - assert.True(t, d.failed) - assert.Equal(t, d.err, err) + assert.True(t, d.Exited()) + assert.Equal(t, d.Err(), err) } func TestPrefixContext_Assert(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() c := initPrefixContext(nil, d, "foo") c.Assert(true) - assert.False(t, d.failed) + assert.False(t, d.Failed()) c.Assert(false) - assert.True(t, d.failed) + assert.True(t, d.Failed()) } func TestPrefixContext_OutputPath(t *testing.T) { t.Parallel() - d := Context(newMockDebugger(t), Parameters{}, "foo/bar") - c := initPrefixContext(d, newMockDebugger(t), "") + d := Context(InitMockDebugger(), Parameters{}, "foo/bar") + c := initPrefixContext(d, InitMockDebugger(), "") assert.Equal(t, c.OutputPath(), d.OutputPath()) } func TestPrefixContext_PushPop(t *testing.T) { t.Parallel() - r := Context(newMockDebugger(t), Parameters{}, "foo/bar") - p := initPrefixContext(r, newMockDebugger(t), "baz") + r := Context(InitMockDebugger(), Parameters{}, "foo/bar") + p := initPrefixContext(r, InitMockDebugger(), "baz") c := p.Push("fizz") assert.IsType(t, prefixContext{}, c) @@ -115,8 +115,8 @@ func TestPrefixContext_PushPop(t *testing.T) { func TestPrefixContext_PushPopDir(t *testing.T) { t.Parallel() - r := Context(newMockDebugger(t), Parameters{}, "foo/bar") - p := initPrefixContext(r, newMockDebugger(t), "fizz") + r := Context(InitMockDebugger(), Parameters{}, "foo/bar") + p := initPrefixContext(r, InitMockDebugger(), "fizz") c := p.PushDir("baz") assert.Equal(t, "foo/bar/baz", c.OutputPath()) @@ -127,7 +127,7 @@ func TestPrefixContext_Parameters(t *testing.T) { t.Parallel() p := Parameters{"foo": "bar"} - r := Context(newMockDebugger(t), p, ".") + r := 
Context(InitMockDebugger(), p, ".") c := r.Push("foo") assert.Equal(t, p, c.Parameters()) @@ -136,16 +136,16 @@ func TestPrefixContext_Parameters(t *testing.T) { func TestDirContext_OutputPath(t *testing.T) { t.Parallel() - r := Context(newMockDebugger(t), Parameters{}, "foo/bar") - d := initDirContext(r, newMockDebugger(t), "baz") + r := Context(InitMockDebugger(), Parameters{}, "foo/bar") + d := initDirContext(r, InitMockDebugger(), "baz") assert.Equal(t, "foo/bar/baz", d.OutputPath()) } func TestDirContext_Push(t *testing.T) { t.Parallel() - r := Context(newMockDebugger(t), Parameters{}, "foo/bar") - d := initDirContext(r, newMockDebugger(t), "baz") + r := Context(InitMockDebugger(), Parameters{}, "foo/bar") + d := initDirContext(r, InitMockDebugger(), "baz") c := d.Push("fizz") assert.Equal(t, d.OutputPath(), c.OutputPath()) @@ -155,8 +155,8 @@ func TestDirContext_Push(t *testing.T) { func TestDirContext_PushPopDir(t *testing.T) { t.Parallel() - r := Context(newMockDebugger(t), Parameters{}, "foo") - d := initDirContext(r, newMockDebugger(t), "bar") + r := Context(InitMockDebugger(), Parameters{}, "foo") + d := initDirContext(r, InitMockDebugger(), "bar") c := d.PushDir("baz") assert.Equal(t, "foo/bar/baz", c.OutputPath()) @@ -169,29 +169,29 @@ func TestDirContext_PushPopDir(t *testing.T) { func TestRootContext_OutputPath(t *testing.T) { t.Parallel() - r := Context(newMockDebugger(t), Parameters{}, "foo") + r := Context(InitMockDebugger(), Parameters{}, "foo") assert.Equal(t, "foo", r.OutputPath()) } func TestRootContext_PushPop(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() r := Context(d, Parameters{}, "foo") c := r.Push("bar") assert.Equal(t, "foo", c.OutputPath()) c = c.Pop() - assert.False(t, d.failed) + assert.False(t, d.Failed()) c.Pop() - assert.True(t, d.failed) + assert.True(t, d.Failed()) } func TestRootContext_PushPopDir(t *testing.T) { t.Parallel() - r := Context(newMockDebugger(t), Parameters{}, "foo") + r := 
Context(InitMockDebugger(), Parameters{}, "foo") c := r.PushDir("bar") assert.Equal(t, "foo/bar", c.OutputPath()) @@ -206,21 +206,21 @@ func TestRootContext_Parameters(t *testing.T) { t.Parallel() p := Parameters{"foo": "bar"} - r := Context(newMockDebugger(t), p, "foo") + r := Context(InitMockDebugger(), p, "foo") assert.Equal(t, p, r.Parameters()) } func TestRootContext_JoinPath(t *testing.T) { t.Parallel() - r := Context(newMockDebugger(t), Parameters{}, "foo") + r := Context(InitMockDebugger(), Parameters{}, "foo") assert.Equal(t, "foo/bar", r.JoinPath("bar")) } func TestDirContext_JoinPath(t *testing.T) { t.Parallel() - r := Context(newMockDebugger(t), Parameters{}, "foo") + r := Context(InitMockDebugger(), Parameters{}, "foo") c := r.PushDir("bar") assert.Equal(t, "foo/bar/baz", c.JoinPath("baz")) @@ -229,7 +229,7 @@ func TestDirContext_JoinPath(t *testing.T) { func TestPrefixContext_JoinPath(t *testing.T) { t.Parallel() - r := Context(newMockDebugger(t), Parameters{}, "foo") + r := Context(InitMockDebugger(), Parameters{}, "foo") c := r.Push("baz") assert.Equal(t, "foo/bar", c.JoinPath("bar")) @@ -238,10 +238,10 @@ func TestPrefixContext_JoinPath(t *testing.T) { func TestPrefixContext_Exit(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() r := Context(d, Parameters{}, "") r.Exit(123) - assert.True(t, d.exited) - assert.Equal(t, 123, d.exitCode) + assert.True(t, d.Exited()) + assert.Equal(t, 123, d.ExitCode()) } diff --git a/comment.go b/comment.go index e87c0bc..2ca50f5 100644 --- a/comment.go +++ b/comment.go @@ -11,12 +11,6 @@ import ( const commentPrefix = "//" -// Commenter is a interface used by any node or entity which could have -// comments. -type Commenter interface { - Comments() string -} - // C returns a comment block, wrapping when the line's length will exceed wrap. func C(wrap int, args ...interface{}) string { s := commentScanner(wrap, args...) 
diff --git a/debug.go b/debug.go index a43e88c..5955d3e 100644 --- a/debug.go +++ b/debug.go @@ -1,7 +1,10 @@ package pgs import ( + "bytes" "fmt" + "io" + "log" "os" "strings" ) @@ -66,31 +69,50 @@ type logger interface { Printf(string, ...interface{}) } -type errFunc func(err error, msgs ...string) +type errFunc func(err error, msgs ...interface{}) -type failFunc func(msgs ...string) +type failFunc func(msgs ...interface{}) + +type exitFunc func(code int) type rootDebugger struct { err errFunc fail failFunc + exit exitFunc l logger logDebugs bool } -func initDebugger(g *Generator, l logger) Debugger { - return rootDebugger{ - err: g.pgg.Error, - fail: g.pgg.Fail, - logDebugs: g.debug, +func initDebugger(d bool, l logger) Debugger { + rd := rootDebugger{ + logDebugs: d, l: l, + exit: os.Exit, + } + + rd.fail = failFunc(rd.defaultFail) + rd.err = errFunc(rd.defaultErr) + + return rd +} + +func (d rootDebugger) defaultErr(err error, msg ...interface{}) { + if err != nil { + d.l.Printf("[error] %s: %v\n", fmt.Sprint(msg...), err) + d.exit(1) } } +func (d rootDebugger) defaultFail(msg ...interface{}) { + d.l.Println(msg...) + d.exit(1) +} + func (d rootDebugger) Log(v ...interface{}) { d.l.Println(v...) } func (d rootDebugger) Logf(format string, v ...interface{}) { d.l.Printf(format, v...) } func (d rootDebugger) Fail(v ...interface{}) { d.fail(fmt.Sprint(v...)) } func (d rootDebugger) Failf(format string, v ...interface{}) { d.fail(fmt.Sprintf(format, v...)) } -func (d rootDebugger) Exit(code int) { os.Exit(code) } +func (d rootDebugger) Exit(code int) { d.exit(code) } func (d rootDebugger) Debug(v ...interface{}) { if d.logDebugs { @@ -190,5 +212,75 @@ func (d prefixedDebugger) Pop() Debugger { return d.parent } -var _ Debugger = rootDebugger{} -var _ Debugger = prefixedDebugger{} +// MockDebugger serves as a root Debugger instance for usage in tests. 
Unlike +// an actual Debugger, MockDebugger will not exit the program, but will track +// failures, checked errors, and exit codes. +type MockDebugger interface { + Debugger + + // Output returns a reader of all logged data. + Output() io.Reader + + // Failed returns true if Fail or Failf has been called on this debugger or a + // descendant of it (via Push). + Failed() bool + + // Err returns the error passed to CheckErr. + Err() error + + // Exited returns true if this Debugger (or a descendant of it) would have + // called os.Exit. + Exited() bool + + // ExitCode returns the code this Debugger (or a descendant of it) passed to + // os.Exit. If Exited() returns false, this value is meaningless. + ExitCode() int +} + +type mockDebugger struct { + Debugger + + buf bytes.Buffer + failed bool + err error + exited bool + code int +} + +// InitMockDebugger creates a new MockDebugger for usage in tests. +func InitMockDebugger() MockDebugger { + md := &mockDebugger{} + d := initDebugger(true, log.New(&md.buf, "", 0)).(rootDebugger) + + d.fail = func(msgs ...interface{}) { + md.failed = true + d.defaultFail(msgs...) + } + + d.err = func(err error, msgs ...interface{}) { + if err != nil { + md.err = err + } + d.defaultErr(err, msgs...) 
+ } + + d.exit = func(code int) { + md.exited = true + md.code = code + } + + md.Debugger = d + return md +} + +func (d *mockDebugger) Output() io.Reader { return &d.buf } +func (d *mockDebugger) Failed() bool { return d.failed } +func (d *mockDebugger) Err() error { return d.err } +func (d *mockDebugger) Exited() bool { return d.exited } +func (d *mockDebugger) ExitCode() int { return d.code } + +var ( + _ Debugger = rootDebugger{} + _ Debugger = prefixedDebugger{} + _ MockDebugger = &mockDebugger{} +) diff --git a/debug_test.go b/debug_test.go index 3312d55..c5e8a37 100644 --- a/debug_test.go +++ b/debug_test.go @@ -3,11 +3,11 @@ package pgs import ( "bytes" "fmt" + "io/ioutil" "testing" "errors" - "github.com/golang/protobuf/protoc-gen-go/generator" "github.com/stretchr/testify/assert" ) @@ -44,7 +44,7 @@ func TestRootDebugger_Fail(t *testing.T) { var failed bool - fail := func(msgs ...string) { + fail := func(msgs ...interface{}) { assert.Equal(t, "foobar", msgs[0]) failed = true } @@ -60,7 +60,7 @@ func TestRootDebugger_Failf(t *testing.T) { var failed bool - fail := func(msgs ...string) { + fail := func(msgs ...interface{}) { assert.Equal(t, "fizz buzz", msgs[0]) failed = true } @@ -107,7 +107,7 @@ func TestRootDebugger_CheckErr(t *testing.T) { e := errors.New("bad error") errd := false - errfn := func(err error, msg ...string) { + errfn := func(err error, msg ...interface{}) { assert.Equal(t, e, err) assert.Equal(t, "foo", msg[0]) errd = true @@ -126,19 +126,29 @@ func TestRootDebugger_Assert(t *testing.T) { failed := false - fail := func(msgs ...string) { + fail := func(msgs ...interface{}) { assert.Equal(t, "foo", msgs[0]) failed = true } rd := rootDebugger{fail: fail} - rd.Assert(1 == 1, "fizz") + rd.Assert(true, "fizz") assert.False(t, failed) - rd.Assert(1 == 0, "foo") + rd.Assert(false, "foo") assert.True(t, failed) } +func TestRootDebugger_Exit(t *testing.T) { + t.Parallel() + + var code int + + rd := rootDebugger{exit: func(c int) { code = c }} + 
rd.Exit(123) + assert.Equal(t, 123, code) +} + func TestRootDebugger_Push(t *testing.T) { t.Parallel() @@ -156,6 +166,51 @@ func TestRootDebugger_Pop(t *testing.T) { assert.Panics(t, func() { rd.Pop() }) } +func TestRootDebugger_DefaultErr(t *testing.T) { + t.Parallel() + + exited := false + code := 0 + l := newMockLogger() + rd := rootDebugger{ + l: l, + exit: func(c int) { + code = c + exited = true + }, + } + + rd.defaultErr(nil, "nothing") + + assert.False(t, exited) + assert.Empty(t, l.buf.String()) + + rd.defaultErr(errors.New("some error"), "something") + assert.True(t, exited) + assert.Equal(t, 1, code) + assert.Contains(t, l.buf.String(), "something") +} + +func TestRootDebugger_DefaultFail(t *testing.T) { + t.Parallel() + + exited := false + code := 0 + l := newMockLogger() + rd := rootDebugger{ + l: l, + exit: func(c int) { + code = c + exited = true + }, + } + + rd.defaultFail("something") + assert.True(t, exited) + assert.Equal(t, 1, code) + assert.Contains(t, l.buf.String(), "something") +} + func TestPrefixedDebugger_Log(t *testing.T) { t.Parallel() @@ -181,7 +236,7 @@ func TestPrefixedDebugger_Fail(t *testing.T) { var failed bool - fail := func(msgs ...string) { + fail := func(msgs ...interface{}) { assert.Contains(t, msgs[0], "FIZZ") assert.Contains(t, msgs[0], "foobar") failed = true @@ -198,7 +253,7 @@ func TestPrefixedDebugger_Failf(t *testing.T) { var failed bool - fail := func(msgs ...string) { + fail := func(msgs ...interface{}) { assert.Contains(t, msgs[0], "FIZZ") assert.Contains(t, msgs[0], "foo bar") failed = true @@ -252,7 +307,7 @@ func TestPrefixedDebugger_CheckErr(t *testing.T) { e := errors.New("bad error") errd := false - errfn := func(err error, msg ...string) { + errfn := func(err error, msg ...interface{}) { assert.Equal(t, e, err) assert.Contains(t, msg[0], "foo") assert.Contains(t, msg[0], "FIZZ") @@ -272,7 +327,7 @@ func TestPrefixedDebugger_Assert(t *testing.T) { failed := false - fail := func(msgs ...string) { + fail := 
func(msgs ...interface{}) { assert.Contains(t, msgs[0], "FIZZ") assert.Contains(t, msgs[0], "foo") failed = true @@ -318,62 +373,25 @@ func TestPrefixedDebugger_Push_Format(t *testing.T) { func TestPrefixedDebugger_Exit(t *testing.T) { t.Parallel() - md := newMockDebugger(t) + md := InitMockDebugger() d := &prefixedDebugger{parent: md} d.Exit(123) - assert.True(t, md.exited) - assert.Equal(t, 123, md.exitCode) + assert.True(t, md.Exited()) + assert.Equal(t, 123, md.ExitCode()) } func TestInitDebugger(t *testing.T) { t.Parallel() - - d := initDebugger(&Generator{ - pgg: Wrap(generator.New()), - gatherer: &gatherer{}, - }, nil) - + d := initDebugger(true, nil) assert.NotNil(t, d) } -type mockDebugger struct { - Debugger - - failed bool - err error - - exited bool - exitCode int -} - -func (d *mockDebugger) Exit(code int) { - if d.exited { - return - } - - d.exited = true - d.exitCode = code -} - -func newMockDebugger(t *testing.T) *mockDebugger { - d := &mockDebugger{} - d.Debugger = &rootDebugger{ - l: newMockLogger(), - err: func(err error, msgs ...string) { - d.err = err - d.failed = true - if t != nil { - t.Log(msgs) - } - }, - fail: func(msgs ...string) { - d.failed = true - if t != nil { - t.Log(msgs) - } - }, - } +func TestMockDebugger_Output(t *testing.T) { + t.Parallel() - return d + md := InitMockDebugger() + md.Log("foobar") + b, _ := ioutil.ReadAll(md.Output()) + assert.Equal(t, "foobar\n", string(b)) } diff --git a/entity.go b/entity.go index 13a6b24..4b2531c 100644 --- a/entity.go +++ b/entity.go @@ -1,18 +1,19 @@ package pgs -import "github.com/golang/protobuf/proto" +import ( + "github.com/golang/protobuf/proto" +) // Entity describes any member of the proto AST that is extensible via -// options. All nodes file and below are considered entities. +// options. All components of a File are considered entities. type Entity interface { Node - Commenter // The Name of the entity Name() Name // The fully qualified name of the entity. 
For example, a message - // 'HelloRequest' in a 'helloworld' package it takes the form of + // 'HelloRequest' in a 'helloworld' package takes the form of // '.helloworld.HelloRequest'. FullyQualifiedName() string @@ -23,8 +24,8 @@ type Entity interface { // Package returns the container package for this entity. Package() Package - // Imports includes all external packages required by this entity. - Imports() []Package + // Imports includes all external files required by this entity. + Imports() []File // File returns the File containing this entity. File() File @@ -39,6 +40,13 @@ type Entity interface { // this entity. Use this flag to determine if the file was targeted in the // protoc run or if it was loaded as an external dependency. BuildTarget() bool + + // SourceCodeInfo returns the SourceCodeInfo associated with the entity. + // Primarily, this struct contains the comments associated with the Entity. + SourceCodeInfo() SourceCodeInfo + + childAtPath(path []int32) Entity + addSourceCodeInfo(info SourceCodeInfo) } // A ParentEntity is any Entity type that can contain messages and/or enums. @@ -55,7 +63,7 @@ type ParentEntity interface { // MapEntries returns the MapEntry message types contained within this // Entity. These messages are not returned by the Messages or AllMessages - // methods. + // methods. Map Entry messages are typically not exposed to the end user. MapEntries() []Message // Enums returns the top-level enums from this entity. Nested enums diff --git a/enum.go b/enum.go index e646031..fb876a8 100644 --- a/enum.go +++ b/enum.go @@ -1,11 +1,8 @@ package pgs import ( - "strings" - "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/descriptor" - "github.com/golang/protobuf/protoc-gen-go/generator" ) // Enum describes an enumeration type. Its parent can be either a Message or a @@ -13,12 +10,8 @@ import ( type Enum interface { Entity - // TypeName returns the type of this enum as it would be created in Go. 
- // This value will only differ from Name for nested enums. - TypeName() TypeName - // Descriptor returns the proto descriptor for this Enum - Descriptor() *generator.EnumDescriptor + Descriptor() *descriptor.EnumDescriptorProto // Parent resolves to either a Message or File that directly contains this // Enum. @@ -32,36 +25,26 @@ type Enum interface { } type enum struct { - rawDesc *descriptor.EnumDescriptorProto - genDesc *generator.EnumDescriptor - + desc *descriptor.EnumDescriptorProto parent ParentEntity - - vals []EnumValue - - comments string + vals []EnumValue + info SourceCodeInfo } -func (e *enum) Name() Name { return Name(e.rawDesc.GetName()) } -func (e *enum) FullyQualifiedName() string { return fullyQualifiedName(e.parent, e) } -func (e *enum) Syntax() Syntax { return e.parent.Syntax() } -func (e *enum) Package() Package { return e.parent.Package() } -func (e *enum) File() File { return e.parent.File() } -func (e *enum) BuildTarget() bool { return e.parent.BuildTarget() } -func (e *enum) Comments() string { return e.comments } -func (e *enum) Descriptor() *generator.EnumDescriptor { return e.genDesc } -func (e *enum) Parent() ParentEntity { return e.parent } -func (e *enum) Imports() []Package { return nil } -func (e *enum) TypeName() TypeName { return TypeName(strings.Join(e.genDesc.TypeName(), "_")) } - -func (e *enum) Values() []EnumValue { - ev := make([]EnumValue, len(e.vals)) - copy(ev, e.vals) - return ev -} +func (e *enum) Name() Name { return Name(e.desc.GetName()) } +func (e *enum) FullyQualifiedName() string { return fullyQualifiedName(e.parent, e) } +func (e *enum) Syntax() Syntax { return e.parent.Syntax() } +func (e *enum) Package() Package { return e.parent.Package() } +func (e *enum) File() File { return e.parent.File() } +func (e *enum) BuildTarget() bool { return e.parent.BuildTarget() } +func (e *enum) SourceCodeInfo() SourceCodeInfo { return e.info } +func (e *enum) Descriptor() *descriptor.EnumDescriptorProto { return e.desc } 
+func (e *enum) Parent() ParentEntity { return e.parent } +func (e *enum) Imports() []File { return nil } +func (e *enum) Values() []EnumValue { return e.vals } func (e *enum) Extension(desc *proto.ExtensionDesc, ext interface{}) (bool, error) { - return extension(e.rawDesc.GetOptions(), desc, &ext) + return extension(e.desc.GetOptions(), desc, &ext) } func (e *enum) accept(v Visitor) (err error) { @@ -89,4 +72,19 @@ func (e *enum) addValue(v EnumValue) { func (e *enum) setParent(p ParentEntity) { e.parent = p } +func (e *enum) childAtPath(path []int32) Entity { + switch { + case len(path) == 0: + return e + case len(path)%2 != 0: + return nil + case path[0] == enumTypeValuePath: + return e.vals[path[1]].childAtPath(path[2:]) + default: + return nil + } +} + +func (e *enum) addSourceCodeInfo(info SourceCodeInfo) { e.info = info } + var _ Enum = (*enum)(nil) diff --git a/enum_test.go b/enum_test.go index 8e2ca90..2f38ddb 100644 --- a/enum_test.go +++ b/enum_test.go @@ -6,34 +6,26 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/descriptor" - "github.com/golang/protobuf/protoc-gen-go/generator" "github.com/stretchr/testify/assert" ) func TestEnum_Name(t *testing.T) { t.Parallel() - e := &enum{rawDesc: &descriptor.EnumDescriptorProto{Name: proto.String("foo")}} + e := &enum{desc: &descriptor.EnumDescriptorProto{Name: proto.String("foo")}} assert.Equal(t, "foo", e.Name().String()) } func TestEnum_FullyQualifiedName(t *testing.T) { t.Parallel() - e := &enum{rawDesc: &descriptor.EnumDescriptorProto{Name: proto.String("enum")}} + e := &enum{desc: &descriptor.EnumDescriptorProto{Name: proto.String("enum")}} f := dummyFile() f.addEnum(e) assert.Equal(t, f.FullyQualifiedName()+".enum", e.FullyQualifiedName()) } -func TestEnum_TypeName(t *testing.T) { - t.Parallel() - - e := dummyEnum() - assert.Equal(t, e.Name().String(), e.TypeName().String()) -} - func TestEnum_Syntax(t *testing.T) { t.Parallel() @@ -81,8 +73,8 @@ func 
TestEnum_BuildTarget(t *testing.T) { func TestEnum_Descriptor(t *testing.T) { t.Parallel() - e := &enum{genDesc: &generator.EnumDescriptor{}} - assert.Equal(t, e.genDesc, e.Descriptor()) + e := &enum{desc: &descriptor.EnumDescriptorProto{}} + assert.Equal(t, e.desc, e.Descriptor()) } func TestEnum_Parent(t *testing.T) { @@ -112,7 +104,7 @@ func TestEnum_Values(t *testing.T) { func TestEnum_Extension(t *testing.T) { // cannot be parallel - e := &enum{rawDesc: &descriptor.EnumDescriptorProto{}} + e := &enum{desc: &descriptor.EnumDescriptorProto{}} assert.NotPanics(t, func() { e.Extension(nil, nil) }) } @@ -148,6 +140,15 @@ func TestEnum_Accept(t *testing.T) { assert.Equal(t, 2, v.enumvalue) } +func TestEnum_ChildAtPath(t *testing.T) { + t.Parallel() + + e := &enum{} + assert.Equal(t, e, e.childAtPath(nil)) + assert.Nil(t, e.childAtPath([]int32{1})) + assert.Nil(t, e.childAtPath([]int32{999, 123})) +} + type mockEnum struct { Enum p ParentEntity @@ -166,8 +167,7 @@ func (e *mockEnum) accept(v Visitor) error { func dummyEnum() *enum { f := dummyFile() - e := &enum{rawDesc: &descriptor.EnumDescriptorProto{Name: proto.String("enum")}} - e.genDesc = &generator.EnumDescriptor{EnumDescriptorProto: e.rawDesc} + e := &enum{desc: &descriptor.EnumDescriptorProto{Name: proto.String("enum")}} f.addEnum(e) return e } diff --git a/enum_value.go b/enum_value.go index a962f97..db536ae 100644 --- a/enum_value.go +++ b/enum_value.go @@ -25,7 +25,7 @@ type enumVal struct { desc *descriptor.EnumValueDescriptorProto enum Enum - comments string + info SourceCodeInfo } func (ev *enumVal) Name() Name { return Name(ev.desc.GetName()) } @@ -34,11 +34,11 @@ func (ev *enumVal) Syntax() Syntax { return ev func (ev *enumVal) Package() Package { return ev.enum.Package() } func (ev *enumVal) File() File { return ev.enum.File() } func (ev *enumVal) BuildTarget() bool { return ev.enum.BuildTarget() } -func (ev *enumVal) Comments() string { return ev.comments } +func (ev *enumVal) SourceCodeInfo() 
SourceCodeInfo { return ev.info } func (ev *enumVal) Descriptor() *descriptor.EnumValueDescriptorProto { return ev.desc } func (ev *enumVal) Enum() Enum { return ev.enum } func (ev *enumVal) Value() int32 { return ev.desc.GetNumber() } -func (ev *enumVal) Imports() []Package { return nil } +func (ev *enumVal) Imports() []File { return nil } func (ev *enumVal) Extension(desc *proto.ExtensionDesc, ext interface{}) (bool, error) { return extension(ev.desc.GetOptions(), desc, &ext) @@ -56,4 +56,13 @@ func (ev *enumVal) accept(v Visitor) (err error) { func (ev *enumVal) setEnum(e Enum) { ev.enum = e } +func (ev *enumVal) childAtPath(path []int32) Entity { + if len(path) == 0 { + return ev + } + return nil +} + +func (ev *enumVal) addSourceCodeInfo(info SourceCodeInfo) { ev.info = info } + var _ EnumValue = (*enumVal)(nil) diff --git a/enum_value_test.go b/enum_value_test.go index 0fb3358..571e3e4 100644 --- a/enum_value_test.go +++ b/enum_value_test.go @@ -104,6 +104,14 @@ func TestEnumVal_Accept(t *testing.T) { assert.Equal(t, 1, v.enumvalue) } +func TestEnumVal_ChildAtPath(t *testing.T) { + t.Parallel() + + ev := &enumVal{} + assert.Equal(t, ev, ev.childAtPath(nil)) + assert.Nil(t, ev.childAtPath([]int32{1})) +} + type mockEnumValue struct { EnumValue e Enum diff --git a/extension.go b/extension.go index 7d329cb..01f8d17 100644 --- a/extension.go +++ b/extension.go @@ -1,11 +1,9 @@ package pgs import ( - "reflect" - - "fmt" - "errors" + "fmt" + "reflect" "github.com/golang/protobuf/proto" ) diff --git a/field.go b/field.go index 9a7143d..0095aa5 100644 --- a/field.go +++ b/field.go @@ -26,6 +26,10 @@ type Field interface { // Type returns the FieldType of this Field. Type() FieldType + // Required returns whether or not the field is labeled as required. This + // will only be true if the syntax is proto2. 
+ Required() bool + setMessage(m Message) setOneOf(o OneOf) addType(t FieldType) @@ -37,17 +41,17 @@ type field struct { oneof OneOf typ FieldType - comments string + info SourceCodeInfo } func (f *field) Name() Name { return Name(f.desc.GetName()) } func (f *field) FullyQualifiedName() string { return fullyQualifiedName(f.msg, f) } func (f *field) Syntax() Syntax { return f.msg.Syntax() } func (f *field) Package() Package { return f.msg.Package() } -func (f *field) Imports() []Package { return f.typ.Imports() } +func (f *field) Imports() []File { return f.typ.Imports() } func (f *field) File() File { return f.msg.File() } func (f *field) BuildTarget() bool { return f.msg.BuildTarget() } -func (f *field) Comments() string { return f.comments } +func (f *field) SourceCodeInfo() SourceCodeInfo { return f.info } func (f *field) Descriptor() *descriptor.FieldDescriptorProto { return f.desc } func (f *field) Message() Message { return f.msg } func (f *field) InOneOf() bool { return f.oneof != nil } @@ -56,6 +60,11 @@ func (f *field) Type() FieldType { return f.typ } func (f *field) setMessage(m Message) { f.msg = m } func (f *field) setOneOf(o OneOf) { f.oneof = o } +func (f *field) Required() bool { + return f.Syntax().SupportsRequiredPrefix() && + f.desc.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REQUIRED +} + func (f *field) addType(t FieldType) { t.setField(f) f.typ = t @@ -74,4 +83,13 @@ func (f *field) accept(v Visitor) (err error) { return } -var _ (Field) = (*field)(nil) +func (f *field) childAtPath(path []int32) Entity { + if len(path) == 0 { + return f + } + return nil +} + +func (f *field) addSourceCodeInfo(info SourceCodeInfo) { f.info = info } + +var _ Field = (*field)(nil) diff --git a/field_test.go b/field_test.go index 86ed787..4e0ca48 100644 --- a/field_test.go +++ b/field_test.go @@ -1,9 +1,8 @@ package pgs import ( - "testing" - "errors" + "testing" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/descriptor" @@ 
-138,18 +137,46 @@ func TestField_Imports(t *testing.T) { f.addType(&scalarT{}) assert.Empty(t, f.Imports()) - f.addType(&mockT{i: []Package{&pkg{}, &pkg{}}}) + f.addType(&mockT{i: []File{&file{}, &file{}}}) assert.Len(t, f.Imports(), 2) } +func TestField_Required(t *testing.T) { + t.Parallel() + + msg := dummyMsg() + + lbl := descriptor.FieldDescriptorProto_LABEL_REQUIRED + + f := &field{desc: &descriptor.FieldDescriptorProto{Label: &lbl}} + f.setMessage(msg) + + assert.False(t, f.Required(), "proto3 messages can never be marked required") + + f.File().(*file).desc.Syntax = proto.String(string(Proto2)) + assert.True(t, f.Required(), "proto2 + required") + + lbl = descriptor.FieldDescriptorProto_LABEL_OPTIONAL + f.desc.Label = &lbl + assert.False(t, f.Required(), "proto2 + optional") +} + +func TestField_ChildAtPath(t *testing.T) { + t.Parallel() + + f := &field{} + assert.Equal(t, f, f.childAtPath(nil)) + assert.Nil(t, f.childAtPath([]int32{1})) +} + type mockField struct { Field - i []Package + i []File m Message err error } -func (f *mockField) Imports() []Package { return f.i } +func (f *mockField) Imports() []File { return f.i } func (f *mockField) setMessage(m Message) { f.m = m } @@ -166,7 +193,7 @@ func dummyField() *field { str := descriptor.FieldDescriptorProto_TYPE_STRING f := &field{desc: &descriptor.FieldDescriptorProto{Name: proto.String("field"), Type: &str}} m.addField(f) - t := &scalarT{name: "string"} + t := &scalarT{} f.addType(t) return f } diff --git a/field_type.go b/field_type.go index 48035fb..d752f02 100644 --- a/field_type.go +++ b/field_type.go @@ -6,10 +6,6 @@ type FieldType interface { // equivalent, each instance of a FieldType is tied to its Field. Field() Field - // Name returns the TypeName for this Field, which represents the type of the - // field as it would exist in Go source code. - Name() TypeName - // IsRepeated returns true if and only if the field is marked as "repeated". 
// While map fields may be labeled as repeated, this method will not return // true for them. @@ -33,18 +29,14 @@ type FieldType interface { // IsRequired returns true if and only if the field is prefixed as required. IsRequired() bool - // IsSlice returns true if the field is represented in Go as a slice. This - // method returns true only for repeated and bytes-type fields. - IsSlice() bool - // ProtoType returns the ProtoType value for this field. ProtoType() ProtoType // ProtoLabel returns the ProtoLabel value for this field. ProtoLabel() ProtoLabel - // Imports includes all external packages required by this field. - Imports() []Package + // Imports includes all external proto files required by this field. + Imports() []File // Enum returns the Enum associated with this FieldType. If IsEnum returns // false, this value will be nil. @@ -75,21 +67,16 @@ type FieldType interface { toElem() FieldTypeElem } -type scalarT struct { - fld Field - name TypeName -} +type scalarT struct{ fld Field } func (s *scalarT) Field() Field { return s.fld } func (s *scalarT) IsRepeated() bool { return false } func (s *scalarT) IsMap() bool { return false } func (s *scalarT) IsEnum() bool { return false } func (s *scalarT) IsEmbed() bool { return false } -func (s *scalarT) Name() TypeName { return s.name } -func (s *scalarT) IsSlice() bool { return s.ProtoType().IsSlice() } func (s *scalarT) ProtoType() ProtoType { return ProtoType(s.fld.Descriptor().GetType()) } func (s *scalarT) ProtoLabel() ProtoLabel { return ProtoLabel(s.fld.Descriptor().GetLabel()) } -func (s *scalarT) Imports() []Package { return nil } +func (s *scalarT) Imports() []File { return nil } func (s *scalarT) setField(f Field) { s.fld = f } func (s *scalarT) Enum() Enum { return nil } func (s *scalarT) Embed() Message { return nil } @@ -108,7 +95,6 @@ func (s *scalarT) toElem() FieldTypeElem { return &scalarE{ typ: s, ptype: s.ProtoType(), - name: s.name, } } @@ -120,9 +106,9 @@ type enumT struct { func (e *enumT) 
Enum() Enum { return e.enum } func (e *enumT) IsEnum() bool { return true } -func (e *enumT) Imports() []Package { - if pkg := e.enum.Package(); pkg.GoName() != e.fld.Package().GoName() { - return []Package{pkg} +func (e *enumT) Imports() []File { + if f := e.enum.File(); f.Name() != e.fld.File().Name() { + return []File{f} } return nil } @@ -142,9 +128,9 @@ type embedT struct { func (e *embedT) Embed() Message { return e.msg } func (e *embedT) IsEmbed() bool { return true } -func (e *embedT) Imports() []Package { - if pkg := e.msg.Package(); pkg.GoName() != e.fld.Package().GoName() { - return []Package{pkg} +func (e *embedT) Imports() []File { + if f := e.msg.File(); f.Name() != e.fld.File().Name() { + return []File{f} } return nil } @@ -163,9 +149,8 @@ type repT struct { func (r *repT) IsRepeated() bool { return true } func (r *repT) Element() FieldTypeElem { return r.el } -func (r *repT) IsSlice() bool { return true } -func (r *repT) Imports() []Package { return r.el.Imports() } +func (r *repT) Imports() []File { return r.el.Imports() } func (r *repT) toElem() FieldTypeElem { panic("cannot convert repeated FieldType to FieldTypeElem") } @@ -176,7 +161,6 @@ type mapT struct { func (m *mapT) IsRepeated() bool { return false } func (m *mapT) IsMap() bool { return true } -func (m *mapT) IsSlice() bool { return false } func (m *mapT) Key() FieldTypeElem { return m.key } var ( diff --git a/field_type_elem.go b/field_type_elem.go index 734d015..5b4e93b 100644 --- a/field_type_elem.go +++ b/field_type_elem.go @@ -15,12 +15,8 @@ type FieldTypeElem interface { // IsEnum returns true if the component is an enum value. IsEnum() bool - // Name returns the TypeName describing this component (independent of the - // parent FieldType). - Name() TypeName - - // Imports includes all external packages required by this field. - Imports() []Package + // Imports includes all external Files required by this field. 
+ Imports() []File // Enum returns the Enum associated with this FieldTypeElem. If IsEnum // returns false, this value will be nil. @@ -36,16 +32,14 @@ type FieldTypeElem interface { type scalarE struct { typ FieldType ptype ProtoType - name TypeName } func (s *scalarE) ParentType() FieldType { return s.typ } func (s *scalarE) ProtoType() ProtoType { return s.ptype } func (s *scalarE) IsEmbed() bool { return false } func (s *scalarE) IsEnum() bool { return false } -func (s *scalarE) Name() TypeName { return s.name } func (s *scalarE) setType(t FieldType) { s.typ = t } -func (s *scalarE) Imports() []Package { return nil } +func (s *scalarE) Imports() []File { return nil } func (s *scalarE) Enum() Enum { return nil } func (s *scalarE) Embed() Message { return nil } @@ -57,9 +51,9 @@ type enumE struct { func (e *enumE) IsEnum() bool { return true } func (e *enumE) Enum() Enum { return e.enum } -func (e *enumE) Imports() []Package { - if pkg := e.enum.Package(); pkg.GoName() != e.ParentType().Field().Package().GoName() { - return []Package{pkg} +func (e *enumE) Imports() []File { + if f := e.enum.File(); f.Name() != e.ParentType().Field().File().Name() { + return []File{f} } return nil } @@ -72,9 +66,9 @@ type embedE struct { func (e *embedE) IsEmbed() bool { return true } func (e *embedE) Embed() Message { return e.msg } -func (e *embedE) Imports() []Package { - if pkg := e.msg.Package(); pkg.GoName() != e.ParentType().Field().Package().GoName() { - return []Package{pkg} +func (e *embedE) Imports() []File { + if f := e.msg.File(); f.Name() != e.ParentType().Field().File().Name() { + return []File{f} } return nil } diff --git a/field_type_elem_test.go b/field_type_elem_test.go index 6a30f7b..d286cd2 100644 --- a/field_type_elem_test.go +++ b/field_type_elem_test.go @@ -3,6 +3,7 @@ package pgs import ( "testing" + "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/descriptor" "github.com/stretchr/testify/assert" ) @@ -31,12 +32,6 @@ func 
TestScalarE_IsEnum(t *testing.T) { assert.False(t, (&scalarE{}).IsEnum()) } -func TestScalarE_Name(t *testing.T) { - t.Parallel() - s := &scalarE{name: TypeName("foobar")} - assert.Equal(t, s.name, s.Name()) -} - func TestScalarE_Imports(t *testing.T) { t.Parallel() assert.Nil(t, (&scalarE{}).Imports()) @@ -66,15 +61,18 @@ func TestEnumE_Enum(t *testing.T) { func TestEnumE_Imports(t *testing.T) { t.Parallel() - e := &enumE{scalarE: &scalarE{}, enum: dummyEnum()} - f := dummyField() - e.typ = f.typ + en := dummyEnum() + f := dummyFile() + en.parent = f + e := &enumE{scalarE: &scalarE{}, enum: en} + fld := dummyField() + e.typ = fld.typ assert.Empty(t, e.Imports()) - e.enum.File().setPackage(&pkg{name: "not_the_same"}) + f.desc.Name = proto.String("some/other/file.proto") assert.Len(t, e.Imports(), 1) - assert.Equal(t, e.Enum().Package(), e.Imports()[0]) + assert.Equal(t, e.Enum().File(), e.Imports()[0]) } func TestEmbedE_IsEmbed(t *testing.T) { @@ -91,13 +89,16 @@ func TestEmbedE_Embed(t *testing.T) { func TestEmbedE_Imports(t *testing.T) { t.Parallel() - e := &embedE{scalarE: &scalarE{}, msg: dummyMsg()} - f := dummyField() - e.typ = f.typ + f := dummyFile() + msg := dummyMsg() + msg.parent = f + e := &embedE{scalarE: &scalarE{}, msg: msg} + fld := dummyField() + e.typ = fld.typ assert.Empty(t, e.Imports()) + f.desc.Name = proto.String("some/other/file.proto") - e.Embed().File().setPackage(&pkg{name: "not_the_same"}) assert.Len(t, e.Imports(), 1) - assert.Equal(t, e.Embed().Package(), e.Imports()[0]) + assert.Equal(t, e.Embed().File(), e.Imports()[0]) } diff --git a/field_type_test.go b/field_type_test.go index f0a6778..6388e8e 100644 --- a/field_type_test.go +++ b/field_type_test.go @@ -42,24 +42,6 @@ func TestScalarT_IsEmbed(t *testing.T) { assert.False(t, s.IsEmbed()) } -func TestScalarT_Name(t *testing.T) { - t.Parallel() - s := &scalarT{name: TypeName("foo")} - assert.Equal(t, "foo", s.Name().String()) -} - -func TestScalarT_IsSlice(t *testing.T) { - 
t.Parallel() - f := dummyField() - s := &scalarT{} - f.addType(s) - assert.False(t, s.IsSlice()) - - b := descriptor.FieldDescriptorProto_TYPE_BYTES - f.desc.Type = &b - assert.True(t, s.IsSlice()) -} - func TestScalarT_ProtoType(t *testing.T) { t.Parallel() f := dummyField() @@ -113,7 +95,7 @@ func TestScalarT_IsOptional(t *testing.T) { assert.True(t, s.IsOptional()) fl := dummyFile() - fl.desc.Syntax = proto.String("proto2") + fl.desc.Syntax = nil f.Message().setParent(fl) assert.True(t, s.IsOptional()) @@ -134,7 +116,7 @@ func TestScalarT_IsRequired(t *testing.T) { assert.False(t, s.IsRequired()) fl := dummyFile() - fl.desc.Syntax = proto.String("proto2") + fl.desc.Syntax = nil f.Message().setParent(fl) assert.False(t, s.IsRequired()) @@ -148,14 +130,13 @@ func TestScalarT_IsRequired(t *testing.T) { func TestScalarT_ToElem(t *testing.T) { t.Parallel() - s := &scalarT{name: TypeName("foo")} + s := &scalarT{} f := dummyField() f.addType(s) el := s.toElem() assert.Equal(t, s, el.ParentType()) assert.Equal(t, s.ProtoType(), el.ProtoType()) - assert.Equal(t, s.Name(), el.Name()) } func TestEnumT_Enum(t *testing.T) { @@ -173,23 +154,25 @@ func TestEnumT_IsEnum(t *testing.T) { func TestEnumT_Imports(t *testing.T) { t.Parallel() - e := &enumT{scalarT: &scalarT{}, enum: dummyEnum()} - f := dummyField() - f.addType(e) + f := dummyFile() + en := dummyEnum() + en.parent = f + e := &enumT{scalarT: &scalarT{}, enum: en} + fld := dummyField() + fld.addType(e) assert.Empty(t, e.Imports()) - e.enum.File().setPackage(&pkg{name: "not_the_same"}) - + f.desc.Name = proto.String("some/other/file.proto") assert.Len(t, e.Imports(), 1) - assert.Equal(t, e.enum.Package(), e.Imports()[0]) + assert.Equal(t, e.enum.File(), e.Imports()[0]) } func TestEnumT_ToElem(t *testing.T) { t.Parallel() e := &enumT{ - scalarT: &scalarT{name: TypeName("foo")}, + scalarT: &scalarT{}, enum: dummyEnum(), } f := dummyField() @@ -198,7 +181,6 @@ func TestEnumT_ToElem(t *testing.T) { el := e.toElem() 
assert.True(t, el.IsEnum()) assert.Equal(t, e.enum, el.Enum()) - assert.Equal(t, e.Name(), el.Name()) assert.Equal(t, e.ProtoType(), el.ProtoType()) } @@ -217,23 +199,24 @@ func TestEmbedT_Embed(t *testing.T) { func TestEmbedT_Imports(t *testing.T) { t.Parallel() - e := &embedT{scalarT: &scalarT{}, msg: dummyMsg()} - f := dummyField() - f.addType(e) + msg := dummyMsg() + f := dummyFile() + msg.parent = f + e := &embedT{scalarT: &scalarT{}, msg: msg} + dummyField().addType(e) assert.Empty(t, e.Imports()) - e.msg.File().setPackage(&pkg{name: "not_the_same"}) - + f.desc.Name = proto.String("some/other/file.proto") assert.Len(t, e.Imports(), 1) - assert.Equal(t, e.msg.Package(), e.Imports()[0]) + assert.Equal(t, e.msg.File(), e.Imports()[0]) } func TestEmbedT_ToElem(t *testing.T) { t.Parallel() e := &embedT{ - scalarT: &scalarT{name: TypeName("foo")}, + scalarT: &scalarT{}, msg: dummyMsg(), } f := dummyField() @@ -242,7 +225,6 @@ func TestEmbedT_ToElem(t *testing.T) { el := e.toElem() assert.True(t, el.IsEmbed()) assert.Equal(t, e.msg, el.Embed()) - assert.Equal(t, e.Name(), el.Name()) assert.Equal(t, e.ProtoType(), el.ProtoType()) } @@ -252,12 +234,6 @@ func TestRepT_IsRepeated(t *testing.T) { assert.True(t, r.IsRepeated()) } -func TestRepT_IsSlice(t *testing.T) { - t.Parallel() - r := &repT{} - assert.True(t, r.IsSlice()) -} - func TestRepT_Element(t *testing.T) { t.Parallel() r := &repT{el: &scalarE{}} @@ -267,18 +243,21 @@ func TestRepT_Element(t *testing.T) { func TestRepT_Imports(t *testing.T) { t.Parallel() - e := &embedT{scalarT: &scalarT{}, msg: dummyMsg()} + msg := dummyMsg() + f := dummyFile() + msg.parent = f + e := &embedT{scalarT: &scalarT{}, msg: msg} dummyField().addType(e) - f := dummyField() + fld := dummyField() r := &repT{scalarT: &scalarT{}, el: e.toElem()} - f.addType(r) + fld.addType(r) assert.Empty(t, r.Imports()) - r.el.Embed().File().setPackage(&pkg{name: "not_the_same"}) + f.desc.Name = proto.String("some/other/file.proto") assert.Len(t, 
r.Imports(), 1) - assert.Equal(t, r.el.Embed().Package(), r.Imports()[0]) + assert.Equal(t, r.el.Embed().File(), r.Imports()[0]) } func TestRepT_ToElem(t *testing.T) { @@ -296,11 +275,6 @@ func TestMapT_IsMap(t *testing.T) { assert.True(t, (&mapT{}).IsMap()) } -func TestMapT_IsSlice(t *testing.T) { - t.Parallel() - assert.False(t, (&mapT{}).IsSlice()) -} - func TestMapT_Key(t *testing.T) { t.Parallel() m := &mapT{key: &scalarE{}} @@ -309,11 +283,11 @@ func TestMapT_Key(t *testing.T) { type mockT struct { FieldType - i []Package + i []File f Field err error } -func (t *mockT) Imports() []Package { return t.i } +func (t *mockT) Imports() []File { return t.i } func (t *mockT) setField(f Field) { t.f = f } diff --git a/file.go b/file.go index 4156312..92424cd 100644 --- a/file.go +++ b/file.go @@ -2,60 +2,63 @@ package pgs import ( "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/protoc-gen-go/generator" + "github.com/golang/protobuf/protoc-gen-go/descriptor" ) // File describes the contents of a single proto file. type File interface { ParentEntity - // InputPath returns the input FilePath of the generated Go code. This is - // equivalent to the value returned by Name. + // InputPath returns the input FilePath. This is equivalent to the value + // returned by Name. InputPath() FilePath - // OutputPath returns the output filepath of the generated Go code - OutputPath() FilePath - // Descriptor returns the underlying descriptor for the proto file - Descriptor() *generator.FileDescriptor + Descriptor() *descriptor.FileDescriptorProto - // Services returns the top-level services from this proto file. + // Services returns the services from this proto file. Services() []Service + // SyntaxSourceCodeInfo returns the comment info attached to the `syntax` + // stanza of the file. This method is an alias of the SourceCodeInfo method. 
+ SyntaxSourceCodeInfo() SourceCodeInfo + + // PackageSourceCodeInfo returns the comment info attached to the `package` + // stanza of the file. + PackageSourceCodeInfo() SourceCodeInfo + setPackage(p Package) addService(s Service) - lookupComments(name string) string + addPackageSourceCodeInfo(info SourceCodeInfo) } type file struct { - desc *generator.FileDescriptor - pkg Package - outputPath FilePath - enums []Enum - msgs []Message - srvs []Service - buildTarget bool - comments map[string]string -} - -func (f *file) Name() Name { return Name(f.desc.GetName()) } -func (f *file) FullyQualifiedName() string { return "." + f.desc.GetPackage() } -func (f *file) Syntax() Syntax { return Syntax(f.desc.GetSyntax()) } -func (f *file) Package() Package { return f.pkg } -func (f *file) File() File { return f } -func (f *file) BuildTarget() bool { return f.buildTarget } -func (f *file) Comments() string { return "" } -func (f *file) Descriptor() *generator.FileDescriptor { return f.desc } -func (f *file) InputPath() FilePath { return FilePath(f.Name().String()) } -func (f *file) OutputPath() FilePath { return f.outputPath } -func (f *file) MapEntries() (me []Message) { return nil } + desc *descriptor.FileDescriptorProto + pkg Package + enums []Enum + msgs []Message + srvs []Service + buildTarget bool + syntaxInfo, packageInfo SourceCodeInfo +} + +func (f *file) Name() Name { return Name(f.desc.GetName()) } +func (f *file) FullyQualifiedName() string { return "." 
+ f.desc.GetPackage() } +func (f *file) Syntax() Syntax { return Syntax(f.desc.GetSyntax()) } +func (f *file) Package() Package { return f.pkg } +func (f *file) File() File { return f } +func (f *file) BuildTarget() bool { return f.buildTarget } +func (f *file) Descriptor() *descriptor.FileDescriptorProto { return f.desc } +func (f *file) InputPath() FilePath { return FilePath(f.Name().String()) } +func (f *file) MapEntries() (me []Message) { return nil } +func (f *file) SourceCodeInfo() SourceCodeInfo { return f.SyntaxSourceCodeInfo() } +func (f *file) SyntaxSourceCodeInfo() SourceCodeInfo { return f.syntaxInfo } +func (f *file) PackageSourceCodeInfo() SourceCodeInfo { return f.packageInfo } func (f *file) Enums() []Enum { - es := make([]Enum, len(f.enums)) - copy(es, f.enums) - return es + return f.enums } func (f *file) AllEnums() []Enum { @@ -67,9 +70,7 @@ func (f *file) AllEnums() []Enum { } func (f *file) Messages() []Message { - msgs := make([]Message, len(f.msgs)) - copy(msgs, f.msgs) - return msgs + return f.msgs } func (f *file) AllMessages() []Message { @@ -81,12 +82,10 @@ func (f *file) AllMessages() []Message { } func (f *file) Services() []Service { - s := make([]Service, len(f.srvs)) - copy(s, f.srvs) - return s + return f.srvs } -func (f *file) Imports() (i []Package) { +func (f *file) Imports() (i []File) { for _, m := range f.AllMessages() { i = append(i, m.Imports()...) 
} @@ -149,4 +148,35 @@ func (f *file) addService(s Service) { func (f *file) addMapEntry(m Message) { panic("cannot add map entry directly to file") } -func (f *file) lookupComments(name string) string { return f.comments[name] } +func (f *file) childAtPath(path []int32) Entity { + switch { + case len(path) == 0: + return f + case len(path)%2 == 1: // all declaration paths are multiples of two + return nil + } + + var child Entity + switch path[0] { + case messageTypePath: + child = f.msgs[path[1]] + case enumTypePath: + child = f.enums[path[1]] + case servicePath: + child = f.srvs[path[1]] + default: + return nil + } + + return child.childAtPath(path[2:]) +} + +func (f *file) addSourceCodeInfo(info SourceCodeInfo) { + f.syntaxInfo = info +} + +func (f *file) addPackageSourceCodeInfo(info SourceCodeInfo) { + f.packageInfo = info +} + +var _ File = (*file)(nil) diff --git a/file_test.go b/file_test.go index a876a7b..74dd541 100644 --- a/file_test.go +++ b/file_test.go @@ -13,10 +13,8 @@ import ( func TestFile_Name(t *testing.T) { t.Parallel() - f := &file{desc: &generator.FileDescriptor{ - FileDescriptorProto: &descriptor.FileDescriptorProto{ - Name: proto.String("foobar"), - }, + f := &file{desc: &descriptor.FileDescriptorProto{ + Name: proto.String("foobar"), }} assert.Equal(t, Name("foobar"), f.Name()) @@ -25,8 +23,8 @@ func TestFile_Name(t *testing.T) { func TestFile_FullyQualifiedName(t *testing.T) { t.Parallel() - f := &file{desc: &generator.FileDescriptor{ - FileDescriptorProto: &descriptor.FileDescriptorProto{Package: proto.String("foo")}, + f := &file{desc: &descriptor.FileDescriptorProto{ + Package: proto.String("foo"), }} assert.Equal(t, ".foo", f.FullyQualifiedName()) @@ -35,11 +33,7 @@ func TestFile_FullyQualifiedName(t *testing.T) { func TestFile_Syntax(t *testing.T) { t.Parallel() - f := &file{desc: &generator.FileDescriptor{ - FileDescriptorProto: &descriptor.FileDescriptorProto{ - Syntax: proto.String("proto2"), - }, - }} + f := &file{desc: 
&descriptor.FileDescriptorProto{}} assert.Equal(t, Proto2, f.Syntax()) } @@ -47,14 +41,14 @@ func TestFile_Syntax(t *testing.T) { func TestFile_Package(t *testing.T) { t.Parallel() - f := &file{pkg: &pkg{importPath: "fizz/buzz"}} + f := &file{pkg: &pkg{comments: "fizz/buzz"}} assert.Equal(t, f.pkg, f.Package()) } func TestFile_File(t *testing.T) { t.Parallel() - f := &file{outputPath: "foobar"} + f := &file{buildTarget: true} assert.Equal(t, f, f.File()) } @@ -70,24 +64,17 @@ func TestFile_BuildTarget(t *testing.T) { func TestFile_Descriptor(t *testing.T) { t.Parallel() - f := &file{desc: &generator.FileDescriptor{}} + f := &file{desc: &descriptor.FileDescriptorProto{}} assert.Equal(t, f.desc, f.Descriptor()) } func TestFile_InputPath(t *testing.T) { t.Parallel() - f := &file{desc: &generator.FileDescriptor{FileDescriptorProto: &descriptor.FileDescriptorProto{Name: proto.String("foo.bar")}}} + f := &file{desc: &descriptor.FileDescriptorProto{Name: proto.String("foo.bar")}} assert.Equal(t, "foo.bar", f.InputPath().String()) } -func TestFile_OutputPath(t *testing.T) { - t.Parallel() - - f := &file{outputPath: "foobar"} - assert.Equal(t, "foobar", f.OutputPath().String()) -} - func TestFile_Enums(t *testing.T) { t.Parallel() @@ -171,8 +158,8 @@ func TestFile_Imports(t *testing.T) { t.Parallel() m := &msg{} - m.addMessage(&mockMessage{i: []Package{&pkg{}}, Message: &msg{}}) - svc := &mockService{i: []Package{&pkg{}}, Service: &service{}} + m.addMessage(&mockMessage{i: []File{&file{}}, Message: &msg{}}) + svc := &mockService{i: []File{&file{}}, Service: &service{}} f := &file{} assert.Empty(t, f.Imports()) @@ -242,7 +229,7 @@ func TestFile_Extension(t *testing.T) { assert.NotPanics(t, func() { (&file{ - desc: &generator.FileDescriptor{FileDescriptorProto: &descriptor.FileDescriptorProto{}}, + desc: &descriptor.FileDescriptorProto{}, }).Extension(nil, nil) }) } @@ -273,17 +260,19 @@ func (f *mockFile) accept(v Visitor) error { func dummyFile() *file { pkg := dummyPkg() f 
:= &file{ - pkg: pkg, - outputPath: "output/path.pb.go", - desc: &generator.FileDescriptor{ - FileDescriptorProto: &descriptor.FileDescriptorProto{ - Package: proto.String(pkg.ProtoName().String()), - Syntax: proto.String("proto3"), - Name: proto.String("file.proto"), - }, + pkg: pkg, + desc: &descriptor.FileDescriptorProto{ + Package: proto.String(pkg.ProtoName().String()), + Syntax: proto.String(string(Proto3)), + Name: proto.String("file.proto"), }, } pkg.addFile(f) return f } + +func dummyGenFile() (*file, *generator.FileDescriptor) { + f := dummyFile() + return f, &generator.FileDescriptor{FileDescriptorProto: f.desc} +} diff --git a/gatherer.go b/gatherer.go deleted file mode 100644 index 771b9f1..0000000 --- a/gatherer.go +++ /dev/null @@ -1,569 +0,0 @@ -package pgs - -import ( - "errors" - "fmt" - "strings" - - "github.com/golang/protobuf/protoc-gen-go/descriptor" - "github.com/golang/protobuf/protoc-gen-go/generator" -) - -const gathererPluginName = "gatherer" - -type gatherer struct { - *PluginBase - entities map[string]Entity - pkgs map[string]Package - targets map[string]Package -} - -func newGatherer() *gatherer { return &gatherer{PluginBase: &PluginBase{}} } - -func (g *gatherer) Name() string { return gathererPluginName } - -func (g *gatherer) Init(gen *generator.Generator) { - g.PluginBase.Init(gen) - g.targets = make(map[string]Package) - g.pkgs = make(map[string]Package) - g.entities = make(map[string]Entity) -} - -func (g *gatherer) Generate(f *generator.FileDescriptor) { - comments := make(map[string]string) - for _, loc := range f.GetSourceCodeInfo().GetLocation() { - if loc.LeadingComments == nil { - continue - } - - name, err := g.nameByPath(f.FileDescriptorProto, loc.Path) - if err != nil { - g.Debug("unable to convert path to name:", err.Error()) - } - - comments[name] = strings.TrimSuffix(loc.GetLeadingComments(), "\n") - } - - pkg := g.hydratePackage(f, comments) - pkg.addFile(g.hydrateFile(pkg, f, comments)) -} - -func (g *gatherer) 
hydratePackage(f *generator.FileDescriptor, comments map[string]string) Package { - // TODO(btc): perhaps return error with specific info about failure - importPath := goImportPath(g.Generator.Unwrap(), f) - name := string(g.Generator.GoPackageName(importPath)) - if p, n, found := goPackageOption(f); found { - if p != "" { - importPath = generator.GoImportPath(p) - } - if n != "" { - name = n - } - } - - g.push("package:" + name) - defer g.pop() - - // have we already hydrated this package. In case we already did, and if - // current file contains comments in the package statement, concatenate it - // so that we don't give any precedence to whatsoever file. - pcomments := comments[fmt.Sprintf(".%s", name)] - if p, ok := g.pkgs[name]; ok { - c := make([]string, 0, 2) - - ccomments := p.Comments() - if ccomments != "" { - c = append(c, ccomments) - } - - if pcomments != "" { - c = append(c, pcomments) - } - - p.setComments(strings.Join(c, "\n")) - return p - } - - p := &pkg{ - fd: f, - name: name, - importPath: string(importPath), - comments: pcomments, - } - - g.pkgs[name] = p - return p -} - -func (g *gatherer) hydrateFile(pkg Package, f *generator.FileDescriptor, comments map[string]string) File { - fl := &file{ - pkg: pkg, - desc: f, - outputPath: FilePath(goFileName(f, g.Parameters().Paths())), - } - - if out, ok := g.seen(fl); ok { - return out.(*file) - } - g.add(fl) - - g.push("file:" + fl.Name().String()) - defer g.pop() - - g.Assert(f.GetPackage() == pkg.ProtoName().String(), - "proto package names should not be mixed in the same directory (", - pkg.ProtoName().String(), ", ", f.GetPackage(), ")") - - fl.buildTarget = g.BuildTarget(f.GetName()) - fl.comments = comments - - if _, seen := g.targets[fl.pkg.GoName().String()]; fl.buildTarget && !seen { - g.Debug("adding target package:", fl.pkg.GoName()) - g.targets[fl.pkg.GoName().String()] = fl.pkg - } - - fl.msgs = make([]Message, 0, len(f.GetMessageType())) - fl.enums = make([]Enum, 0, len(f.GetEnumType())) 
- fl.srvs = make([]Service, 0, len(f.GetService())) - - // populate all enum types - for _, ed := range f.GetEnumType() { - fl.addEnum(g.hydrateEnum(fl, ed)) - } - - // populate all message types - for _, md := range f.GetMessageType() { - fl.addMessage(g.hydrateMessage(fl, md)) - } - - // populates all field types. This must come after all messages to permit - // hydrating all types prior to hydration - for _, m := range fl.AllMessages() { - // This must come after all messages but before normal message fields to - // permit the later hydration. - for _, me := range m.MapEntries() { - for _, fld := range me.Fields() { - fld.addType(g.hydrateFieldType(fld)) - } - } - - for _, fld := range m.Fields() { - fld.addType(g.hydrateFieldType(fld)) - } - } - - // populate all services - for _, sd := range f.GetService() { - fl.addService(g.hydrateService(fl, sd)) - } - - return fl -} - -func (g *gatherer) hydrateMessage(parent ParentEntity, md *descriptor.DescriptorProto) Message { - m := &msg{ - rawDesc: md, - parent: parent, - } - - if out, ok := g.seen(m); ok { - return out.(*msg) - } - g.add(m) - - g.push("msg:" + m.Name().String()) - defer g.pop() - - name := m.FullyQualifiedName() - m.genDesc = g.Generator.ObjectNamed(name).(*generator.Descriptor) - m.comments = m.File().lookupComments(name) - - // populate all nested enums - for _, ed := range md.GetEnumType() { - m.addEnum(g.hydrateEnum(m, ed)) - } - - // populate all nested messages. If the message is a map entry type, stash it. - for _, smd := range md.GetNestedType() { - if sm := g.hydrateMessage(m, smd); sm.IsMapEntry() { - m.addMapEntry(sm) - } else { - m.addMessage(sm) - } - } - - // populate all fields - for _, fd := range md.GetField() { - m.addField(g.hydrateField(m, fd)) - } - - // populate all oneofs. 
This must come after the fields to properly associate - // the field relationships - for i, od := range md.GetOneofDecl() { - m.addOneOf(g.hydrateOneOf(m, int32(i), od)) - } - - return m -} - -func (g *gatherer) hydrateField(msg Message, fd *descriptor.FieldDescriptorProto) Field { - f := &field{ - desc: fd, - msg: msg, - } - - if out, ok := g.seen(f); ok { - return out.(*field) - } - g.add(f) - - f.comments = f.File().lookupComments(f.FullyQualifiedName()) - - return f -} - -func (g *gatherer) hydrateFieldType(fld Field) FieldType { - g.push("field-type:" + fld.FullyQualifiedName()) - defer g.pop() - - msg := fld.Message().Descriptor() - name, _ := g.Generator.GoType(msg, fld.Descriptor()) - - s := &scalarT{ - fld: fld, - name: TypeName(name), - } - - switch { - case s.ProtoType() == GroupT: - g.Fail("group types are deprecated and unsupported. Use an embedded message instead.") - return nil - case s.ProtoLabel() == Repeated: - return g.hydrateRepeatedFieldType(s) - case s.ProtoType() == EnumT: - return g.hydrateEnumFieldType(s) - case s.ProtoType() == MessageT: - return g.hydrateEmbedFieldType(s) - default: - return s - } -} - -func (g *gatherer) hydrateEnumFieldType(s *scalarT) FieldType { - e := &enumT{scalarT: s} - - ent, ok := g.seenObj(g.Generator.ObjectNamed(s.fld.Descriptor().GetTypeName())) - g.Assert(ok, "enum type not seen") - - en, ok := ent.(Enum) - g.Assert(ok, "unexpected entity type") - e.enum = en - - return e -} - -func (g *gatherer) hydrateEmbedFieldType(s *scalarT) FieldType { - e := &embedT{scalarT: s} - - ent, ok := g.seenObj(g.Generator.ObjectNamed(s.fld.Descriptor().GetTypeName())) - g.Assert(ok, "embed type not seen") - - m, ok := ent.(Message) - g.Assert(ok, "unexpected entity type") - e.msg = m - - return e -} - -func (g *gatherer) hydrateRepeatedFieldType(s *scalarT) FieldType { - r := &repT{scalarT: s} - r.el = &scalarE{ - typ: r, - ptype: r.ProtoType(), - name: r.Name().Element(), - } - - switch s.ProtoType() { - case EnumT: - ent, ok 
:= g.seenObj(g.Generator.ObjectNamed(s.fld.Descriptor().GetTypeName())) - g.Assert(ok, "enum type not seen") - - en, ok := ent.(Enum) - g.Assert(ok, "unexpected entity type") - - r.el = &enumE{ - scalarE: r.el.(*scalarE), - enum: en, - } - case MessageT: - ent, ok := g.seenObj(g.Generator.ObjectNamed(s.fld.Descriptor().GetTypeName())) - g.Assert(ok, "embed type not seen") - - m, ok := ent.(Message) - g.Assert(ok, "unexpected entity type") - - if m.IsMapEntry() { - return g.hydrateMapFieldType(r, m) - } - - r.el = &embedE{ - scalarE: r.el.(*scalarE), - msg: m, - } - - } - - return r -} - -func (g *gatherer) hydrateMapFieldType(r *repT, m Message) FieldType { - mt := &mapT{repT: r} - - mt.key = m.Fields()[0].Type().toElem() - mt.key.setType(mt) - - mt.el = m.Fields()[1].Type().toElem() - mt.el.setType(mt) - - mt.name = TypeName(fmt.Sprintf( - "map[%s]%s", - mt.key.Name(), - mt.el.Name())) - - return mt -} - -func (g *gatherer) hydrateOneOf(msg Message, idx int32, od *descriptor.OneofDescriptorProto) OneOf { - o := &oneof{ - desc: od, - msg: msg, - } - - if out, ok := g.seen(o); ok { - return out.(*oneof) - } - g.add(o) - - g.push("oneof:" + o.Name().String()) - defer g.pop() - - o.comments = o.File().lookupComments(o.FullyQualifiedName()) - - for _, f := range msg.Fields() { - if i := f.Descriptor().OneofIndex; i != nil && idx == *i { - o.addField(f) - } - } - - return o -} - -func (g *gatherer) hydrateEnum(parent ParentEntity, ed *descriptor.EnumDescriptorProto) Enum { - e := &enum{ - rawDesc: ed, - parent: parent, - } - - if out, ok := g.seen(e); ok { - return out.(*enum) - } - g.add(e) - - g.push("enum:" + e.Name().String()) - defer g.pop() - - name := e.FullyQualifiedName() - e.genDesc = g.Generator.ObjectNamed(name).(*generator.EnumDescriptor) - e.comments = e.File().lookupComments(name) - - for _, vd := range ed.GetValue() { - e.addValue(g.hydrateEnumValue(e, vd)) - } - - return e -} - -func (g *gatherer) hydrateEnumValue(parent Enum, vd 
*descriptor.EnumValueDescriptorProto) EnumValue { - ev := &enumVal{ - desc: vd, - enum: parent, - } - - if out, ok := g.seen(ev); ok { - return out.(*enumVal) - } - g.add(ev) - - ev.comments = ev.File().lookupComments(ev.FullyQualifiedName()) - - return ev -} - -func (g *gatherer) hydrateService(parent File, sd *descriptor.ServiceDescriptorProto) Service { - s := &service{ - desc: sd, - file: parent, - } - - if out, ok := g.seen(s); ok { - return out.(*service) - } - g.add(s) - - g.push("service:" + s.Name().String()) - defer g.pop() - - s.comments = s.File().lookupComments(s.FullyQualifiedName()) - - for _, md := range sd.GetMethod() { - s.addMethod(g.hydrateMethod(s, md)) - } - - return s -} - -func (g *gatherer) hydrateMethod(parent Service, md *descriptor.MethodDescriptorProto) Method { - m := &method{ - desc: md, - service: parent, - } - - if out, ok := g.seen(m); ok { - return out.(*method) - } - g.add(m) - - g.push("method:" + m.Name().String()) - defer g.pop() - - m.comments = m.File().lookupComments(m.FullyQualifiedName()) - - in, ok := g.seenName(md.GetInputType()) - g.Assert(ok, "input type", md.GetInputType(), "not hydrated") - m.in = in.(*msg) - - out, ok := g.seenName(md.GetOutputType()) - g.Assert(ok, "output type", md.GetOutputType(), "not hydrated") - m.out = out.(*msg) - - return m -} - -func (g *gatherer) push(prefix string) { g.BuildContext = g.Push(prefix) } - -func (g *gatherer) pop() { g.BuildContext = g.Pop() } - -func (g *gatherer) seen(e Entity) (Entity, bool) { return g.seenName(g.resolveFullyQualifiedName(e)) } - -func (g *gatherer) seenName(ln string) (Entity, bool) { - out, ok := g.entities[ln] - return out, ok -} - -func (g *gatherer) seenObj(o generator.Object) (Entity, bool) { - ent, ok := g.seenName(o.File().GetName()) - g.Assert(ok, "dependent proto file not seen:", o.File().GetName()) - fl := ent.File() - - return g.seenName(fl.FullyQualifiedName() + "." 
+ strings.Join(o.TypeName(), ".")) -} - -func (g *gatherer) add(e Entity) { g.entities[g.resolveFullyQualifiedName(e)] = e } - -func (g *gatherer) resolveFullyQualifiedName(e Entity) string { - if f, ok := e.(File); ok { - return f.Name().String() - } - - return e.FullyQualifiedName() -} - -func (g *gatherer) nameByPath(f *descriptor.FileDescriptorProto, path []int32) (string, error) { - const ( - packagePath = 2 // FileDescriptorProto.Package - messageTypePath = 4 // FileDescriptorProto.MessageType - enumTypePath = 5 // FileDescriptorProto.EnumType - servicePath = 6 // FileDescriptorProto.Service - - messageTypeFieldPath = 2 // DescriptorProto.Field - messageTypeNestedTypePath = 3 // DescriptorProto.NestedType - messageTypeEnumTypePath = 4 // DescriptorProto.EnumType - messageTypeOneofDeclPath = 8 // DescriptorProto.OneofDecl - ) - - // return fast in case it's the package leading comment - packageName := f.GetPackage() - if path[0] == packagePath { - return fmt.Sprintf(".%s", packageName), nil - } - - // as we're refering to concrete entities, entity type should be followed by - // an index number thus always leading to even paths. 
- if len(path)%2 != 0 { - return "", errors.New("path must have even elements") - } - - // tail-call recursive path to name conversion functor - var fn func(interface { - GetName() string - }, []int32, *[]string) error - fn = func(parent interface { - GetName() string - }, path []int32, names *[]string) error { - if len(path) == 0 { - return nil - } - - t := path[0] - n := path[1] - switch td := parent.(type) { - case *descriptor.FileDescriptorProto: - switch t { - case messageTypePath: - parent = td.MessageType[n] - case enumTypePath: - parent = td.EnumType[n] - case servicePath: - parent = td.Service[n] - } - case *descriptor.ServiceDescriptorProto: - parent = td.Method[n] - case *descriptor.EnumDescriptorProto: - parent = td.Value[n] - case *descriptor.DescriptorProto: - switch t { - case messageTypeFieldPath: - parent = td.Field[n] - case messageTypeNestedTypePath: - parent = td.NestedType[n] - case messageTypeEnumTypePath: - parent = td.EnumType[n] - case messageTypeOneofDeclPath: - parent = td.OneofDecl[n] - } - } - - *names = append(*names, parent.GetName()) - return fn(parent, path[2:], names) - } - - // reserve exactly the required capacity - var names []string - namesLen := uint(len(path) / 2) - if packageName != "" { - names = make([]string, 0, namesLen+1) - names = append(names, packageName) - } else { - names = make([]string, 0, namesLen) - } - - // start the conversion - err := fn(f, path, &names) - if err != nil { - return "", err - } - - return fmt.Sprintf(".%s", strings.Join(names, ".")), nil -} - -var _ generator.Plugin = (*gatherer)(nil) diff --git a/gatherer_test.go b/gatherer_test.go deleted file mode 100644 index 40aca1a..0000000 --- a/gatherer_test.go +++ /dev/null @@ -1,852 +0,0 @@ -package pgs - -import ( - "testing" - - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/protoc-gen-go/descriptor" - "github.com/golang/protobuf/protoc-gen-go/generator" - "github.com/golang/protobuf/protoc-gen-go/plugin" - 
"github.com/stretchr/testify/assert" -) - -func initTestGatherer(t *testing.T) *gatherer { - gen := generator.New() - g := &gatherer{PluginBase: &PluginBase{}} - g.Init(gen) - return g -} - -func TestGatherer_Generate(t *testing.T) { - t.Skip("generator needs to be initialized first") - t.Parallel() - - f := &generator.FileDescriptor{ - FileDescriptorProto: &descriptor.FileDescriptorProto{ - Name: proto.String("file.proto"), - Package: proto.String("pkg"), - }, - } - - g := initTestGatherer(t) - gen := generator.New() - gen.Request.FileToGenerate = []string{f.GetName()} - g.Generator = Wrap(gen) - pgg := initGathererPGG(g) - pgg.name = "pkg" - - g.Generate(f) - - assert.Len(t, g.targets, 1) - assert.Equal(t, "pkg", g.targets["pkg"].GoName().String()) - assert.Len(t, g.pkgs, 1) - assert.Equal(t, g.targets["pkg"], g.pkgs[g.targets["pkg"].GoName().String()]) - assert.Len(t, g.targets["pkg"].Files(), 1) - - assert.Equal(t, g.targets["pkg"], g.hydratePackage(f, map[string]string{})) -} - -func TestGatherer_HydrateFile(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - pgg := initGathererPGG(g) - - typ := StringT.Proto() - - me := &descriptor.DescriptorProto{ - Name: proto.String("MapEntry"), - Options: &descriptor.MessageOptions{MapEntry: proto.Bool(true)}, - Field: []*descriptor.FieldDescriptorProto{ - { - Name: proto.String("map_entry_field"), - Type: &typ, - TypeName: proto.String("string"), - }, - }, - } - - m := &descriptor.DescriptorProto{ - Name: proto.String("Msg"), - Field: []*descriptor.FieldDescriptorProto{ - { - Name: proto.String("msg_field"), - Type: &typ, - TypeName: proto.String("string"), - }, - }, - NestedType: []*descriptor.DescriptorProto{me}, - } - - e := &descriptor.EnumDescriptorProto{Name: proto.String("Enum")} - - s := &descriptor.ServiceDescriptorProto{Name: proto.String("Svc")} - - df := dummyFile() - desc := df.Descriptor() - desc.MessageType = []*descriptor.DescriptorProto{m} - desc.EnumType = []*descriptor.EnumDescriptorProto{e} 
- desc.Service = []*descriptor.ServiceDescriptorProto{s} - - pkg := df.Package() - - comments := map[string]string{} - - pgg.objs = map[string]generator.Object{ - df.FullyQualifiedName() + ".Msg": &generator.Descriptor{DescriptorProto: m}, - df.FullyQualifiedName() + ".Msg.MapEntry": &generator.Descriptor{DescriptorProto: me}, - df.FullyQualifiedName() + ".Enum": &generator.EnumDescriptor{EnumDescriptorProto: e}, - } - - f := g.hydrateFile(pkg, desc, comments) - assert.Equal(t, pkg, f.Package()) - assert.Equal(t, desc, f.Descriptor()) - assert.Equal(t, goFileName(desc, g.Parameters().Paths()), f.OutputPath().String()) - assert.Len(t, f.AllMessages(), 1) - assert.Len(t, f.Enums(), 1) - assert.Len(t, f.Services(), 1) - - _, ok := g.seen(f) - assert.True(t, ok) - assert.Equal(t, f, g.hydrateFile(pkg, desc, comments)) -} - -func TestGatherer_HydrateFile_PackageMismatch(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - initGathererPGG(g) - md := newMockDebugger(t) - g.BuildContext = Context(md, Parameters{}, "") - - df := dummyFile() - dp := df.Package() - desc := df.Descriptor() - desc.Package = proto.String("not_the_same_as_dp") - - g.hydrateFile(dp, desc, map[string]string{}) - assert.True(t, md.failed) -} - -func TestGatherer_HydrateMessage(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - pgg := initGathererPGG(g) - - me := &descriptor.DescriptorProto{ - Name: proto.String("MapEntry"), - Options: &descriptor.MessageOptions{MapEntry: proto.Bool(true)}, - } - - nm := &descriptor.DescriptorProto{ - Name: proto.String("NestedMsg"), - } - - de := dummyEnum() - ne := de.Descriptor().EnumDescriptorProto - - fld := dummyField().Descriptor() - - o := dummyOneof().Descriptor() - - dm := dummyMsg() - desc := dm.rawDesc - desc.Field = []*descriptor.FieldDescriptorProto{fld} - desc.EnumType = []*descriptor.EnumDescriptorProto{ne} - desc.NestedType = []*descriptor.DescriptorProto{nm, me} - desc.OneofDecl = []*descriptor.OneofDescriptorProto{o} - - f := 
dm.File() - - pgg.objs = map[string]generator.Object{ - fullyQualifiedName(f, dm): dm.Descriptor(), - fullyQualifiedName(dm, de): de.Descriptor(), - dm.FullyQualifiedName() + ".NestedMsg": &generator.Descriptor{DescriptorProto: nm}, - dm.FullyQualifiedName() + ".MapEntry": &generator.Descriptor{DescriptorProto: me}, - } - - m := g.hydrateMessage(f, desc) - assert.Equal(t, dm.Descriptor(), m.Descriptor()) - assert.Equal(t, f, m.Parent()) - assert.Len(t, m.Enums(), 1) - assert.Len(t, m.Messages(), 1) - assert.Len(t, m.MapEntries(), 1) - assert.Len(t, m.Fields(), 1) - assert.Len(t, m.OneOfs(), 1) - - _, ok := g.seen(m) - assert.True(t, ok) - assert.Equal(t, m, g.hydrateMessage(f, desc)) -} - -func TestGatherer_HydrateField(t *testing.T) { - t.Parallel() - - df := dummyField() - desc := df.Descriptor() - m := dummyMsg() - - g := initTestGatherer(t) - - f := g.hydrateField(m, desc) - assert.Equal(t, desc, f.Descriptor()) - assert.Equal(t, m, f.Message()) - - _, ok := g.seen(f) - assert.True(t, ok) - assert.Equal(t, f, g.hydrateField(m, desc)) -} - -func TestGatherer_HydrateFieldType_Scalar(t *testing.T) { - t.Skip("common file access for proto3 method is impossible to mock") - t.Parallel() - - g := initTestGatherer(t) - - typ := StringT.Proto() - fld := &field{ - msg: dummyMsg(), - desc: &descriptor.FieldDescriptorProto{ - Name: proto.String("scalar"), - Type: &typ, - TypeName: proto.String("*string"), - }, - } - - g.add(fld) - - ft := g.hydrateFieldType(fld) - assert.Equal(t, "*string", ft.Name().String()) -} - -func TestGatherer_HydrateFieldType_Enum(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - pgg := initGathererPGG(g) - - emb := &enum{ - parent: dummyFile(), - rawDesc: &descriptor.EnumDescriptorProto{Name: proto.String("EmbeddedEnum")}, - } - emb.genDesc = &generator.EnumDescriptor{EnumDescriptorProto: emb.rawDesc} - - typ := EnumT.Proto() - fld := &field{ - msg: dummyMsg(), - desc: &descriptor.FieldDescriptorProto{ - Name: proto.String("enum"), - 
Type: &typ, - TypeName: proto.String("EmbeddedEnum"), - }, - } - - g.add(emb) - g.add(emb.File()) - g.add(fld) - - pgg.types[fld.desc.GetName()] = fld.desc.GetTypeName() - pgg.objs[fld.desc.GetTypeName()] = &mockObject{ - file: emb.File().Descriptor(), - name: []string{emb.Name().String()}, - } - - ft := g.hydrateFieldType(fld) - assert.True(t, ft.IsEnum()) - assert.Equal(t, "EmbeddedEnum", ft.Name().String()) -} - -func TestGatherer_HydrateFieldType_Embed(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - pgg := initGathererPGG(g) - - emb := &msg{ - parent: dummyFile(), - rawDesc: &descriptor.DescriptorProto{Name: proto.String("EmbeddedMessage")}, - } - emb.genDesc = &generator.Descriptor{DescriptorProto: emb.rawDesc} - - typ := MessageT.Proto() - fld := &field{ - msg: dummyMsg(), - desc: &descriptor.FieldDescriptorProto{ - Name: proto.String("embeded"), - Type: &typ, - TypeName: proto.String("*EmbeddedMessage"), - }, - } - - g.add(emb) - g.add(emb.File()) - g.add(fld) - - pgg.types[fld.desc.GetName()] = fld.desc.GetTypeName() - pgg.objs[fld.desc.GetTypeName()] = &mockObject{ - file: emb.File().Descriptor(), - name: []string{emb.Name().String()}, - } - - ft := g.hydrateFieldType(fld) - assert.True(t, ft.IsEmbed()) - assert.Equal(t, "*EmbeddedMessage", ft.Name().String()) -} - -func TestGatherer_HydrateFieldType_Group(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - initGathererPGG(g) - - d := newMockDebugger(t) - g.PluginBase.BuildContext = Context(d, Parameters{}, ".") - - typ := GroupT.Proto() - fld := &field{ - msg: dummyMsg(), - desc: &descriptor.FieldDescriptorProto{ - Name: proto.String("deprecated_group"), - Type: &typ, - }, - } - - g.add(fld) - g.hydrateFieldType(fld) - assert.True(t, d.failed) -} - -func TestGatherer_HydrateFieldType_RepeatedScalar(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - - lbl := Repeated.Proto() - typ := StringT.Proto() - fld := &field{ - msg: dummyMsg(), - desc: 
&descriptor.FieldDescriptorProto{ - Name: proto.String("scalar_repeated"), - Label: &lbl, - Type: &typ, - TypeName: proto.String("[]string"), - }, - } - - g.add(fld) - - ft := g.hydrateFieldType(fld) - assert.True(t, ft.IsRepeated()) - assert.Equal(t, "[]string", ft.Name().String()) - assert.Equal(t, "string", ft.Element().Name().String()) -} - -func TestGatherer_HydrateFieldType_RepeatedEnum(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - pgg := initGathererPGG(g) - - el := &enum{ - parent: dummyFile(), - rawDesc: &descriptor.EnumDescriptorProto{Name: proto.String("EmbeddedEnum")}, - } - el.genDesc = &generator.EnumDescriptor{EnumDescriptorProto: el.rawDesc} - - lbl := Repeated.Proto() - typ := EnumT.Proto() - fld := &field{ - msg: dummyMsg(), - desc: &descriptor.FieldDescriptorProto{ - Name: proto.String("enum_repeated"), - Label: &lbl, - Type: &typ, - TypeName: proto.String("[]EmbeddedEnum"), - }, - } - - g.add(el) - g.add(el.File()) - g.add(fld) - - pgg.types[fld.desc.GetName()] = fld.desc.GetTypeName() - pgg.objs[fld.desc.GetTypeName()] = &mockObject{ - file: el.File().Descriptor(), - name: []string{el.Name().String()}, - } - - ft := g.hydrateFieldType(fld) - assert.True(t, ft.IsRepeated()) - assert.True(t, ft.Element().IsEnum()) - assert.Equal(t, "[]EmbeddedEnum", ft.Name().String()) - assert.Equal(t, "EmbeddedEnum", ft.Element().Name().String()) -} - -func TestGatherer_HydrateFieldType_RepeatedEmbed(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - pgg := initGathererPGG(g) - - el := &msg{ - parent: dummyFile(), - rawDesc: &descriptor.DescriptorProto{Name: proto.String("EmbeddedMessage")}, - } - el.genDesc = &generator.Descriptor{DescriptorProto: el.rawDesc} - - lbl := Repeated.Proto() - typ := MessageT.Proto() - fld := &field{ - msg: dummyMsg(), - desc: &descriptor.FieldDescriptorProto{ - Name: proto.String("embeded_repeated"), - Label: &lbl, - Type: &typ, - TypeName: proto.String("[]EmbeddedMessage"), - }, - } - - g.add(el) - 
g.add(el.File()) - g.add(fld) - - pgg.types[fld.desc.GetName()] = fld.desc.GetTypeName() - pgg.objs[fld.desc.GetTypeName()] = &mockObject{ - file: el.File().Descriptor(), - name: []string{el.Name().String()}, - } - - ft := g.hydrateFieldType(fld) - assert.True(t, ft.IsRepeated()) - assert.True(t, ft.Element().IsEmbed()) - assert.Equal(t, "[]EmbeddedMessage", ft.Name().String()) - assert.Equal(t, "EmbeddedMessage", ft.Element().Name().String()) -} - -func TestGatherer_HydrateFieldType_Map(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - pgg := initGathererPGG(g) - - key := &field{ - desc: &descriptor.FieldDescriptorProto{ - Name: proto.String("key"), - TypeName: proto.String("string"), - }, - } - key.addType(&scalarT{name: TypeName("string")}) - - val := &field{ - desc: &descriptor.FieldDescriptorProto{ - Name: proto.String("value"), - TypeName: proto.String("int64"), - }, - } - val.addType(&scalarT{name: TypeName("int64")}) - - me := &msg{ - parent: dummyFile(), - rawDesc: &descriptor.DescriptorProto{ - Name: proto.String("FooBarEntry"), - Options: &descriptor.MessageOptions{MapEntry: proto.Bool(true)}, - }, - } - me.genDesc = &generator.Descriptor{DescriptorProto: me.rawDesc} - me.addField(key) - me.addField(val) - - lbl := Repeated.Proto() - typ := MessageT.Proto() - fld := &field{ - msg: dummyMsg(), - desc: &descriptor.FieldDescriptorProto{ - Name: proto.String("map_field"), - Label: &lbl, - Type: &typ, - TypeName: proto.String("FooBarEntry"), - }, - } - - g.add(key) - g.add(val) - g.add(me) - g.add(fld) - g.add(me.File()) - - pgg.types[fld.desc.GetName()] = me.Name().String() - pgg.objs[fld.desc.GetTypeName()] = &mockObject{ - file: me.File().Descriptor(), - name: []string{me.Name().String()}, - } - - ft := g.hydrateFieldType(fld) - assert.True(t, ft.IsMap()) - assert.Equal(t, "map[string]int64", ft.Name().String()) - assert.Equal(t, "string", ft.Key().Name().String()) - assert.Equal(t, "int64", ft.Element().Name().String()) -} - -func 
TestGatherer_HydrateOneOf(t *testing.T) { - t.Parallel() - - do := dummyOneof() - desc := do.Descriptor() - - m := do.Message() - m.addField(dummyField()) - - f := dummyField() - f.desc.OneofIndex = proto.Int32(123) - m.addField(f) - - g := initTestGatherer(t) - - o := g.hydrateOneOf(m, 123, desc) - assert.Equal(t, desc, o.Descriptor()) - assert.Equal(t, m, o.Message()) - assert.Len(t, o.Fields(), 1) - assert.Equal(t, f, o.Fields()[0]) - - _, ok := g.seen(o) - assert.True(t, ok) - assert.Equal(t, o, g.hydrateOneOf(m, 123, desc)) -} - -func TestGatherer_HydrateEnum(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - pgg := initGathererPGG(g) - - de := dummyEnum() - pgg.objs[de.FullyQualifiedName()] = de.genDesc - p := de.Parent() - - desc := de.rawDesc - desc.Value = []*descriptor.EnumValueDescriptorProto{{}} - - e := g.hydrateEnum(p, desc) - assert.Equal(t, de.genDesc, e.Descriptor()) - assert.Equal(t, p, e.Parent()) - assert.Len(t, e.Values(), 1) - - _, ok := g.seen(e) - assert.True(t, ok) - assert.Equal(t, e, g.hydrateEnum(p, desc)) -} - -func TestGatherer_HydrateEnumValue(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - e := dummyEnum() - desc := &descriptor.EnumValueDescriptorProto{} - - ev := g.hydrateEnumValue(e, desc) - assert.Equal(t, desc, ev.Descriptor()) - assert.Equal(t, e, ev.Enum()) - - _, ok := g.seen(ev) - assert.True(t, ok) - assert.Equal(t, ev, g.hydrateEnumValue(e, desc)) -} - -func TestGatherer_HydrateService(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - io := dummyMsg() - g.add(io) - f := dummyFile() - - desc := &descriptor.ServiceDescriptorProto{ - Method: []*descriptor.MethodDescriptorProto{{ - InputType: proto.String(io.FullyQualifiedName()), - OutputType: proto.String(io.FullyQualifiedName()), - }}, - } - - s := g.hydrateService(f, desc) - assert.Equal(t, desc, s.Descriptor()) - assert.Equal(t, f, s.File()) - assert.Len(t, s.Methods(), 1) - - _, ok := g.seen(s) - assert.True(t, ok) - assert.Equal(t, s, 
g.hydrateService(f, desc)) -} - -func TestGatherer_HydrateMethod(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - io := dummyMsg() - g.add(io) - - s := dummyService() - desc := &descriptor.MethodDescriptorProto{ - InputType: proto.String(io.FullyQualifiedName()), - OutputType: proto.String(io.FullyQualifiedName()), - } - - m := g.hydrateMethod(s, desc) - assert.Equal(t, io, m.Input()) - assert.Equal(t, io, m.Output()) - assert.Equal(t, s, m.Service()) - assert.Equal(t, desc, m.Descriptor()) - - _, ok := g.seen(m) - assert.True(t, ok) - assert.Equal(t, m, g.hydrateMethod(s, desc)) -} - -func TestGatherer_Name(t *testing.T) { - t.Parallel() - g := initTestGatherer(t) - assert.Equal(t, gathererPluginName, g.Name()) -} - -func TestGatherer_GenerateImports(t *testing.T) { - t.Parallel() - g := initTestGatherer(t) - assert.NotPanics(t, func() { g.GenerateImports(nil) }) -} - -func TestGatherer_Init(t *testing.T) { - t.Parallel() - - gen := &generator.Generator{Request: &plugin_go.CodeGeneratorRequest{}} - g := &gatherer{PluginBase: &PluginBase{}} - - assert.NotPanics(t, func() { g.Init(gen) }) - assert.Equal(t, gen, g.Generator.Unwrap()) - assert.NotNil(t, g.pkgs) - assert.NotNil(t, g.entities) -} - -func TestGatherer_ResolveFullyQualifiedName(t *testing.T) { - t.Parallel() - - f := dummyFile() - m := dummyMsg() - - g := initTestGatherer(t) - assert.Equal(t, f.Name().String(), g.resolveFullyQualifiedName(f)) - assert.Equal(t, m.FullyQualifiedName(), g.resolveFullyQualifiedName(m)) -} - -func TestGatherer_Add(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - m := dummyMsg() - g.add(m) - assert.Contains(t, g.entities, g.resolveFullyQualifiedName(m)) - assert.Equal(t, m, g.entities[g.resolveFullyQualifiedName(m)]) -} - -func TestGatherer_SeenName(t *testing.T) { - t.Parallel() - - m := dummyMsg() - g := initTestGatherer(t) - g.entities = map[string]Entity{ - "foo": m, - } - - e, ok := g.seenName("foo") - assert.True(t, ok) - assert.Equal(t, m, e) - 
- e, ok = g.seenName("bar") - assert.False(t, ok) - assert.Nil(t, e) -} - -func TestGatherer_Seen(t *testing.T) { - t.Parallel() - - g := initTestGatherer(t) - m := dummyMsg() - g.add(m) - - e, ok := g.seen(m) - assert.True(t, ok) - assert.Equal(t, m, e) - - e, ok = g.seen(dummyEnum()) - assert.False(t, ok) - assert.Nil(t, e) -} - -func TestGatherer_SeenObj(t *testing.T) { - t.Parallel() - - m := dummyMsg() - o := mockObject{ - file: m.File().Descriptor(), - name: dummyMsg().Name().Split(), - } - - g := initTestGatherer(t) - g.add(m.File()) - g.add(m) - - e, ok := g.seenObj(o) - assert.True(t, ok) - assert.Equal(t, m, e) -} - -func TestGatherer_NameByPath(t *testing.T) { - t.Parallel() - - file := &descriptor.FileDescriptorProto{ - Package: proto.String("my.package"), - Name: proto.String("file.proto"), - MessageType: []*descriptor.DescriptorProto{ - &descriptor.DescriptorProto{ - Name: proto.String("MyMessage"), - Field: []*descriptor.FieldDescriptorProto{ - &descriptor.FieldDescriptorProto{Name: proto.String("my_field")}, - &descriptor.FieldDescriptorProto{Name: proto.String("my_oneof_field")}, - }, - NestedType: []*descriptor.DescriptorProto{ - &descriptor.DescriptorProto{Name: proto.String("MyNestedMessage")}, - }, - OneofDecl: []*descriptor.OneofDescriptorProto{ - &descriptor.OneofDescriptorProto{Name: proto.String("my_oneof")}, - }, - }, - }, - EnumType: []*descriptor.EnumDescriptorProto{ - &descriptor.EnumDescriptorProto{ - Name: proto.String("MyEnum"), - Value: []*descriptor.EnumValueDescriptorProto{ - &descriptor.EnumValueDescriptorProto{Name: proto.String("FIRST")}, - &descriptor.EnumValueDescriptorProto{Name: proto.String("SECOND")}, - }, - }, - }, - Service: []*descriptor.ServiceDescriptorProto{ - &descriptor.ServiceDescriptorProto{ - Name: proto.String("MyService"), - Method: []*descriptor.MethodDescriptorProto{ - &descriptor.MethodDescriptorProto{Name: proto.String("MyMethod")}, - }, - }, - }, - } - - g := initTestGatherer(t) - - testCases := []struct 
{ - name string - path []int32 - want string - }{ - { - name: "Package", - path: []int32{2}, - want: ".my.package", - }, - { - name: "Message", - path: []int32{4, 0}, - want: ".my.package.MyMessage", - }, - { - name: "Field in Message", - path: []int32{4, 0, 2, 0}, - want: ".my.package.MyMessage.my_field", - }, - { - name: "OneOf Field in Message", - path: []int32{4, 0, 2, 1}, - want: ".my.package.MyMessage.my_oneof_field", - }, - { - name: "NestedMessage in Message", - path: []int32{4, 0, 3, 0}, - want: ".my.package.MyMessage.MyNestedMessage", - }, - { - name: "OneOf in Message", - path: []int32{4, 0, 8, 0}, - want: ".my.package.MyMessage.my_oneof", - }, - { - name: "Enum", - path: []int32{5, 0}, - want: ".my.package.MyEnum", - }, - { - name: "EnumValue1 in Enum", - path: []int32{5, 0, 2, 0}, - want: ".my.package.MyEnum.FIRST", - }, - { - name: "EnumValue2 in Enum", - path: []int32{5, 0, 2, 1}, - want: ".my.package.MyEnum.SECOND", - }, - { - name: "Service", - path: []int32{6, 0}, - want: ".my.package.MyService", - }, - { - name: "Method in Service", - path: []int32{6, 0, 2, 0}, - want: ".my.package.MyService.MyMethod", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - v, err := g.nameByPath(file, tc.path) - assert.NoError(t, err) - assert.Equal(t, tc.want, v) - }) - } - -} - -type mockObject struct { - generator.Object - file *generator.FileDescriptor - name []string -} - -func (o mockObject) File() *generator.FileDescriptor { return o.file } -func (o mockObject) TypeName() []string { return o.name } - -type mockGathererPGG struct { - ProtocGenGo - objs map[string]generator.Object - types map[string]string - name string -} - -func initGathererPGG(g *gatherer) *mockGathererPGG { - pgg := &mockGathererPGG{ - ProtocGenGo: g.Generator, - objs: map[string]generator.Object{}, - types: map[string]string{}, - } - g.Generator = pgg - return pgg -} - -func (pgg *mockGathererPGG) GoType(m *generator.Descriptor, f 
*descriptor.FieldDescriptorProto) (string, string) { - return pgg.types[f.GetName()], "" -} - -func (pgg *mockGathererPGG) ObjectNamed(s string) generator.Object { - return pgg.objs[s] -} - -func (pgg *mockGathererPGG) packageName(fd packageFD) string { return pgg.name } diff --git a/generator.go b/generator.go index 9b6ad41..9442411 100644 --- a/generator.go +++ b/generator.go @@ -4,30 +4,21 @@ import ( "io" "log" "os" - - "github.com/golang/protobuf/protoc-gen-go/generator" ) -// Generator replaces the standard protoc-gen-go generator.Generator. It -// permits the registration of both standard protoc-gen-go plugins (eg, grpc) -// that are in-band with the officially generated code as well as Modules which -// enable creating out-of-band code using a computed protobuf AST. +// Generator configures and executes a protoc plugin's lifecycle. type Generator struct { Debugger - pgg ProtocGenGo // protoc-gen-go generator - gatherer *gatherer // gatherer pgg plugin - persister persister // handles writing artifacts to their output - workflow workflow // handles the actual code generation execution + persister persister // handles writing artifacts to their output + workflow workflow - plugins []Plugin // registered pgg plugins - mods []Module // registered pg* modules + mods []Module // registered pg* modules in io.Reader // protoc input reader out io.Writer // protoc output writer - debug bool // whether or not to print debug messages - includeGo bool // whether or not to gen official go code + debug bool // whether or not to print debug messages params Parameters // CLI parameters passed in from protoc paramMutators []ParamMutator // registered param mutators @@ -37,54 +28,24 @@ type Generator struct { // modify the behavior of the generator. 
func Init(opts ...InitOption) *Generator { g := &Generator{ - pgg: Wrap(generator.New()), - gatherer: newGatherer(), in: os.Stdin, out: os.Stdout, persister: newPersister(), - workflow: new(standardWorkflow), + workflow: &onceWorkflow{workflow: &standardWorkflow{}}, } - g.persister.SetPGG(g.pgg) - for _, opt := range opts { opt(g) } - g.Debugger = initDebugger(g, log.New(os.Stderr, "", 0)) + g.Debugger = initDebugger(g.debug, log.New(os.Stderr, "", 0)) g.persister.SetDebugger(g.Debugger) - if !g.includeGo { - g.workflow = &excludeGoWorkflow{workflow: g.workflow} - } - g.workflow = &onceWorkflow{workflow: g.workflow} - - return g -} - -// RegisterPlugin attaches protoc-gen-go plugins to the Generator. If p -// implements the protoc-gen-star Plugin interface, a Debugger will be passed -// in. This method is solely a wrapper around generator.RegisterPlugin. When -// designing these, all context should be cleared when Init is called. Note -// that these are currently global in scope and not specific to this generator -// instance. -func (g *Generator) RegisterPlugin(p ...generator.Plugin) *Generator { - for _, pl := range p { - g.Assert(pl != nil, "nil plugin provided") - g.Debug("registering plugin:", pl.Name()) - - if ppl, ok := pl.(Plugin); ok { - g.plugins = append(g.plugins, ppl) - } - - generator.RegisterPlugin(pl) - } - return g } -// RegisterModule should be called after Init but before Render to attach a -// custom Module to the Generator. This method can be called multiple times. +// RegisterModule should be called before Render to attach a custom Module to +// the Generator. This method can be called multiple times. func (g *Generator) RegisterModule(m ...Module) *Generator { for _, mod := range m { g.Assert(mod != nil, "nil module provided") @@ -95,10 +56,10 @@ func (g *Generator) RegisterModule(m ...Module) *Generator { return g } -// RegisterPostProcessor should be called after Init but before Render to -// attach PostProcessors to the Generator. 
This method can be called multiple -// times. PostProcessors are executed against their matches in the order in -// which they are registered. Only Artifacts generated by Modules are processed. +// RegisterPostProcessor should be called before Render to attach +// PostProcessors to the Generator. This method can be called multiple times. +// PostProcessors are executed against their matches in the order in which they +// are registered. func (g *Generator) RegisterPostProcessor(p ...PostProcessor) *Generator { for _, pp := range p { g.Assert(pp != nil, "nil post-processor provided") @@ -107,27 +68,23 @@ func (g *Generator) RegisterPostProcessor(p ...PostProcessor) *Generator { return g } -// AST returns the target Packages as well as all loaded Packages from the -// gatherer. Calling this method will trigger running the underlying -// protoc-gen-go workflow, including any registered plugins. Render can be -// safely called before or after this method without impacting module execution -// or writing the output to the output io.Writer. This method is particularly -// useful for integration-type tests. -func (g *Generator) AST() (targets map[string]Package, pkgs map[string]Package) { - g.workflow.Init(g) - g.workflow.Go() - return g.gatherer.targets, g.gatherer.pkgs +// AST returns the constructed AST graph from the gatherer. This method is +// idempotent, can be called multiple times (before and after calls to Render, +// even), and is particularly useful in testing. +func (g *Generator) AST() AST { + return g.workflow.Init(g) } -// Render emits all generated files from the plugins and Modules to the output -// io.Writer. If out is nil, os.Stdout is used. Render can only be called once -// and should be preceded by Init and any number of RegisterPlugin and/or -// RegisterModule calls. 
+// Render executes the protoc plugin flow, gathering the AST from the input +// io.Reader (typically stdin via protoc), running all the registered modules, +// and persisting the generated artifacts to the output io.Writer (typically +// stdout to protoc + direct file system writes for custom artifacts). This +// method is idempotent, in that subsequent calls to Render will have no +// effect. func (g *Generator) Render() { - g.workflow.Init(g) - g.workflow.Go() - g.workflow.Star() - g.workflow.Persist() + ast := g.workflow.Init(g) + arts := g.workflow.Run(ast) + g.workflow.Persist(arts) } func (g *Generator) push(prefix string) { g.Debugger = g.Push(prefix) } diff --git a/generator_test.go b/generator_test.go index 3ca8395..cb46d8f 100644 --- a/generator_test.go +++ b/generator_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/protoc-gen-go/descriptor" "github.com/golang/protobuf/protoc-gen-go/plugin" "github.com/stretchr/testify/assert" ) @@ -29,90 +30,75 @@ func TestInit(t *testing.T) { assert.True(t, ok) } -func TestGenerator_RegisterPlugin(t *testing.T) { - t.Parallel() - - d := newMockDebugger(t) - g := &Generator{Debugger: d} - p := mockPlugin{PluginBase: &PluginBase{}, name: "foo"} - g.RegisterPlugin(p) - - assert.False(t, d.failed) - assert.Len(t, g.plugins, 1) - assert.Equal(t, p, g.plugins[0]) - - assert.Panics(t, func() { g.RegisterPlugin(nil) }) - assert.True(t, d.failed) -} - func TestGenerator_RegisterModule(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() g := &Generator{Debugger: d} assert.Empty(t, g.mods) g.RegisterModule(&mockModule{name: "foo"}) - assert.False(t, d.failed) + assert.False(t, d.Failed()) assert.Len(t, g.mods, 1) assert.Panics(t, func() { g.RegisterModule(nil) }) - assert.True(t, d.failed) + assert.True(t, d.Failed()) } func TestGenerator_RegisterPostProcessor(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() p 
:= newPersister() g := &Generator{Debugger: d, persister: p} + pp := &mockPP{} + assert.Empty(t, p.procs) - g.RegisterPostProcessor(GoFmt()) + g.RegisterPostProcessor(pp) - assert.False(t, d.failed) + assert.False(t, d.Failed()) assert.Len(t, p.procs, 1) g.RegisterPostProcessor(nil) - assert.True(t, d.failed) + assert.True(t, d.Failed()) } func TestGenerator_AST(t *testing.T) { t.Parallel() g := Init() - g.workflow = &dummyWorkflow{} - - pkg := dummyPkg() - pkgName := pkg.GoName().String() - g.gatherer.targets = map[string]Package{pkgName: pkg} - g.gatherer.pkgs = map[string]Package{"foo": nil} + wf := &dummyWorkflow{AST: new(graph)} + g.workflow = wf - targets, pkgs := g.AST() - assert.Equal(t, g.gatherer.targets[pkgName], targets[pkgName]) - assert.Equal(t, g.gatherer.pkgs, pkgs) + assert.Equal(t, wf.AST, g.AST()) + assert.True(t, wf.initted) } func TestGenerator_Render(t *testing.T) { // cannot be parallel - req := &plugin_go.CodeGeneratorRequest{FileToGenerate: []string{"foo"}} + req := &plugin_go.CodeGeneratorRequest{ + FileToGenerate: []string{"foo"}, + ProtoFile: []*descriptor.FileDescriptorProto{ + { + Name: proto.String("foo"), + Syntax: proto.String("proto2"), + Package: proto.String("bar"), + }, + }, + } b, err := proto.Marshal(req) assert.NoError(t, err) buf := &bytes.Buffer{} g := Init(ProtocInput(bytes.NewReader(b)), ProtocOutput(buf)) - g.pgg = mockGeneratorPGG{ProtocGenGo: g.pgg} - g.gatherer.targets = map[string]Package{"foo": &pkg{}} - g.pgg.response().File = []*plugin_go.CodeGeneratorResponse_File{{}} - assert.NotPanics(t, g.Render) var res plugin_go.CodeGeneratorResponse assert.NoError(t, proto.Unmarshal(buf.Bytes(), &res)) - assert.True(t, proto.Equal(g.pgg.response(), &res)) } func TestGenerator_PushPop(t *testing.T) { @@ -130,12 +116,3 @@ func TestGenerator_PushPop(t *testing.T) { _, ok = g.Debugger.(rootDebugger) assert.True(t, ok) } - -type mockGeneratorPGG struct { - ProtocGenGo -} - -func (pgg mockGeneratorPGG) Error(err error, msgs 
...string) {} -func (pgg mockGeneratorPGG) Fail(msgs ...string) {} -func (pgg mockGeneratorPGG) prepare(param Parameters) {} -func (pgg mockGeneratorPGG) generate() {} diff --git a/init_option.go b/init_option.go index fbbc639..aca34d3 100644 --- a/init_option.go +++ b/init_option.go @@ -1,9 +1,8 @@ package pgs import ( - "os" - "io" + "os" "github.com/spf13/afero" ) @@ -26,11 +25,6 @@ func ProtocInput(r io.Reader) InitOption { return func(g *Generator) { g.in = r // os.Stdout is used. func ProtocOutput(w io.Writer) InitOption { return func(g *Generator) { g.out = w } } -// IncludeGo permits generation of the standard Go protocol buffer code -// alongside any custom modules. By default, none of the standard protoc-gen-go -// code is generated. -func IncludeGo() InitOption { return func(g *Generator) { g.includeGo = true } } - // DebugMode enables verbose logging for module development and debugging. func DebugMode() InitOption { return func(g *Generator) { g.debug = true } } @@ -38,26 +32,12 @@ func DebugMode() InitOption { return func(g *Generator) { g.debug = true } } // is non-empty. func DebugEnv(f string) InitOption { return func(g *Generator) { g.debug = os.Getenv(f) != "" } } -// RequirePlugin force-enables any plugins with name, regardless of the -// parameters passed in from protoc. -func RequirePlugin(name ...string) InitOption { - return MutateParams(func(p Parameters) { p.AddPlugin(name...) }) -} - -// MutateParams applies pm to the parameters passed in from protoc. The -// ParamMutator is applied prior to executing the protoc-gen-go workflow. -func MutateParams(pm ParamMutator) InitOption { - return func(g *Generator) { g.paramMutators = append(g.paramMutators, pm) } +// MutateParams applies pm to the parameters passed in from protoc. +func MutateParams(pm ...ParamMutator) InitOption { + return func(g *Generator) { g.paramMutators = append(g.paramMutators, pm...) } } // FileSystem overrides the default file system used to write Artifacts to // disk. 
By default, the OS's file system is used. This option currently only // impacts CustomFile and CustomTemplateFile artifacts generated by modules. func FileSystem(fs afero.Fs) InitOption { return func(g *Generator) { g.persister.SetFS(fs) } } - -// MultiPackage indicates that the Generator should expect files from multiple -// packages simultaneously. Normally, protoc-gen-go disallows running against -// files from more than one package at a time. -func MultiPackage() InitOption { - return func(g *Generator) { g.workflow = &multiPackageWorkflow{workflow: g.workflow} } -} diff --git a/init_option_test.go b/init_option_test.go index af0abfa..0db61c5 100644 --- a/init_option_test.go +++ b/init_option_test.go @@ -1,27 +1,16 @@ package pgs import ( + "bytes" "math/rand" "os" "strconv" "testing" - "bytes" - "github.com/spf13/afero" "github.com/stretchr/testify/assert" ) -func TestIncludeGo(t *testing.T) { - t.Parallel() - - g := &Generator{} - assert.False(t, g.includeGo) - - IncludeGo()(g) - assert.True(t, g.includeGo) -} - func TestDebugMode(t *testing.T) { t.Parallel() @@ -51,7 +40,7 @@ func TestDebugEnv(t *testing.T) { func TestFileSystem(t *testing.T) { t.Parallel() - p := dummyPersister(newMockDebugger(t)) + p := dummyPersister(InitMockDebugger()) g := &Generator{persister: p} fs := afero.NewMemMapFs() @@ -81,26 +70,3 @@ func TestProtocOutput(t *testing.T) { ProtocOutput(b)(g) assert.Equal(t, b, g.out) } - -func TestMultiPackage(t *testing.T) { - t.Parallel() - - g := &Generator{workflow: &dummyWorkflow{}} - - MultiPackage()(g) - _, ok := g.workflow.(*multiPackageWorkflow) - assert.True(t, ok) -} - -func TestRequirePlugin(t *testing.T) { - t.Parallel() - - g := Init(RequirePlugin("foo", "bar")) - - p := Parameters{} - for _, pm := range g.paramMutators { - pm(p) - } - - assert.Equal(t, "plugins=foo+bar", p.String()) -} diff --git a/lang/go/Makefile b/lang/go/Makefile new file mode 100644 index 0000000..a4dac5c --- /dev/null +++ b/lang/go/Makefile @@ -0,0 +1,42 @@ 
+.PHONY: testdata-go-names +testdata-names: ../../bin/protoc-gen-debug # parse the proto file sets in testdata/names and renders binary CodeGeneratorRequest + official go codegen + cd testdata/names && \ + set -e; for subdir in `find . -type d -mindepth 1 -maxdepth 1`; do \ + cd $$subdir; \ + params=`cat params`; \ + protoc -I . \ + --plugin=protoc-gen-debug=../../../../../bin/protoc-gen-debug \ + --debug_out=".:." \ + --go_out="plugins,paths=source_relative,$$params:." \ + `find . -name "*.proto"`; \ + cd -; \ + done + +testdata-packages: ../../bin/protoc-gen-debug + cd testdata/packages && \ + set -e; for subdir in `find . -type d -mindepth 1 -maxdepth 1 | grep -v targets`; do \ + cd $$subdir; \ + params=`cat params`; \ + protoc -I . -I .. \ + --plugin=protoc-gen-debug=../../../../../bin/protoc-gen-debug \ + --debug_out=".:." \ + --go_out="paths=source_relative,$$params:." \ + `find . -name "*.proto"`; \ + cd -; \ + done + +testdata-outputs: ../../bin/protoc-gen-debug + cd testdata/outputs && \ + set -e; for subdir in `find . -type d -mindepth 1 -maxdepth 1`; do \ + cd $$subdir; \ + params=`cat params`; \ + protoc -I . -I .. \ + --plugin=protoc-gen-debug=../../../../../bin/protoc-gen-debug \ + --debug_out=".:." \ + --go_out="$$params:." \ + `find . -name "*.proto"`; \ + cd -; \ + done + +../../bin/protoc-gen-debug: + cd ../.. && $(MAKE) bin/protoc-gen-debug diff --git a/lang/go/context.go b/lang/go/context.go new file mode 100644 index 0000000..bb9db23 --- /dev/null +++ b/lang/go/context.go @@ -0,0 +1,68 @@ +package pgsgo + +import "github.com/lyft/protoc-gen-star" + +// Context resolves Go-specific language for Packages & Entities generated by +// protoc-gen-go. The rules that drive the naming behavior are complicated, and +// result from an interplay of the go_package file option, the proto package, +// and the proto filename itself. 
Therefore, it is recommended that all proto +// files that are targeting Go should include a fully qualified go_package +// option. These must be consistent for all proto files that are intended to be +// in the same Go package. +type Context interface { + // Params returns the Parameters associated with this context. + Params() pgs.Parameters + + // Name returns the name of a Node as it would appear in the generation output + // of protoc-gen-go. For each type, the following is returned: + // + // - Package: the Go package name + // - File: the Go package name + // - Message: the struct name + // - Field: the field name on the Message struct + // - OneOf: the field name on the Message struct + // - Enum: the type name + // - EnumValue: the constant name + // - Service: the server interface name + // - Method: the method name on the server and client interface + // + Name(node pgs.Node) pgs.Name + + // ServerName returns the name of the server interface for the Service. + ServerName(service pgs.Service) pgs.Name + + // ClientName returns the name of the client interface for the Service. + ClientName(service pgs.Service) pgs.Name + + // OneofOption returns the struct name that wraps a OneOf option's value. These + // messages contain one field, matching the value returned by Name for this + // Field. + OneofOption(field pgs.Field) pgs.Name + + // TypeName returns the type name of a Field as it would appear in the + // generated message struct from protoc-gen-go. Fields from imported + // packages will be prefixed with the package name. + Type(field pgs.Field) TypeName + + // PackageName returns the name of the Node's package as it would appear in + // Go source generated by the official protoc-gen-go plugin. + PackageName(node pgs.Node) pgs.Name + + // ImportPath returns the Go import path for an entity as it would be + // included in an import block in a Go file. This value is only appropriate + // for Entities imported into a target file/package. 
+ ImportPath(entity pgs.Entity) pgs.FilePath + + // OutputPath returns the output path relative to the plugin's output destination + OutputPath(entity pgs.Entity) pgs.FilePath +} + +type context struct{ p pgs.Parameters } + +// InitContext configures a Context that should be used for deriving Go names +// for all Packages and Entities. +func InitContext(params pgs.Parameters) Context { + return context{params} +} + +func (c context) Params() pgs.Parameters { return c.p } diff --git a/lang/go/context_test.go b/lang/go/context_test.go new file mode 100644 index 0000000..034417f --- /dev/null +++ b/lang/go/context_test.go @@ -0,0 +1,19 @@ +package pgsgo + +import ( + "testing" + + "github.com/lyft/protoc-gen-star" + "github.com/stretchr/testify/assert" +) + +func TestContext_Params(t *testing.T) { + t.Parallel() + + p := pgs.Parameters{} + p.SetStr("foo", "bar") + ctx := InitContext(p) + + params := ctx.Params() + assert.Equal(t, "bar", params.Str("foo")) +} diff --git a/lang/go/docs.go b/lang/go/docs.go new file mode 100644 index 0000000..84f50ab --- /dev/null +++ b/lang/go/docs.go @@ -0,0 +1,2 @@ +// Package pgsgo contains Go-specific helpers for use with PG* based protoc-plugins +package pgsgo diff --git a/lang/go/gofmt.go b/lang/go/gofmt.go new file mode 100644 index 0000000..c736ed5 --- /dev/null +++ b/lang/go/gofmt.go @@ -0,0 +1,36 @@ +package pgsgo + +import ( + "go/format" + "strings" + + "github.com/lyft/protoc-gen-star" +) + +type goFmt struct{} + +// GoFmt returns a PostProcessor that runs gofmt on any files ending in ".go" +func GoFmt() pgs.PostProcessor { return goFmt{} } + +func (p goFmt) Match(a pgs.Artifact) bool { + var n string + + switch a := a.(type) { + case pgs.GeneratorFile: + n = a.Name + case pgs.GeneratorTemplateFile: + n = a.Name + case pgs.CustomFile: + n = a.Name + case pgs.CustomTemplateFile: + n = a.Name + default: + return false + } + + return strings.HasSuffix(n, ".go") +} + +func (p goFmt) Process(in []byte) ([]byte, error) { return 
format.Source(in) } + +var _ pgs.PostProcessor = goFmt{} diff --git a/lang/go/gofmt_test.go b/lang/go/gofmt_test.go new file mode 100644 index 0000000..f123254 --- /dev/null +++ b/lang/go/gofmt_test.go @@ -0,0 +1,54 @@ +package pgsgo + +import ( + "testing" + + "github.com/lyft/protoc-gen-star" + + "github.com/stretchr/testify/assert" +) + +func TestGoFmt_Match(t *testing.T) { + t.Parallel() + + pp := GoFmt() + + tests := []struct { + n string + a pgs.Artifact + m bool + }{ + {"GenFile", pgs.GeneratorFile{Name: "foo.go"}, true}, + {"GenFileNonGo", pgs.GeneratorFile{Name: "bar.txt"}, false}, + + {"GenTplFile", pgs.GeneratorTemplateFile{Name: "foo.go"}, true}, + {"GenTplFileNonGo", pgs.GeneratorTemplateFile{Name: "bar.txt"}, false}, + + {"CustomFile", pgs.CustomFile{Name: "foo.go"}, true}, + {"CustomFileNonGo", pgs.CustomFile{Name: "bar.txt"}, false}, + + {"CustomTplFile", pgs.CustomTemplateFile{Name: "foo.go"}, true}, + {"CustomTplFileNonGo", pgs.CustomTemplateFile{Name: "bar.txt"}, false}, + + {"NonMatch", pgs.GeneratorAppend{FileName: "foo.go"}, false}, + } + + for _, test := range tests { + tc := test + t.Run(tc.n, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.m, pp.Match(tc.a)) + }) + } +} + +func TestGoFmt_Process(t *testing.T) { + t.Parallel() + + src := []byte("// test\n package foo\n\nvar bar int = 123\n") + exp := []byte("// test\npackage foo\n\nvar bar int = 123\n") + + out, err := GoFmt().Process(src) + assert.NoError(t, err) + assert.Equal(t, exp, out) +} diff --git a/lang/go/helpers_test.go b/lang/go/helpers_test.go new file mode 100644 index 0000000..16234b4 --- /dev/null +++ b/lang/go/helpers_test.go @@ -0,0 +1,45 @@ +package pgsgo + +import ( + "io/ioutil" + "path/filepath" + "strings" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/protoc-gen-go/plugin" + "github.com/lyft/protoc-gen-star" + "github.com/stretchr/testify/require" +) + +func readCodeGenReq(t *testing.T, dir ...string) 
*plugin_go.CodeGeneratorRequest { + dirs := append(append([]string{"testdata"}, dir...), "code_generator_request.pb.bin") + filename := filepath.Join(dirs...) + + data, err := ioutil.ReadFile(filename) + require.NoError(t, err, "unable to read CDR at %q", filename) + + req := &plugin_go.CodeGeneratorRequest{} + err = proto.Unmarshal(data, req) + require.NoError(t, err, "unable to unmarshal CDR data at %q", filename) + + return req +} + +func buildGraph(t *testing.T, dir ...string) pgs.AST { + d := pgs.InitMockDebugger() + ast := pgs.ProcessDescriptors(d, readCodeGenReq(t, dir...)) + require.False(t, d.Failed(), "failed to build graph (see previous log statements)") + return ast +} + +func loadContext(t *testing.T, dir ...string) Context { + dirs := append(append([]string{"testdata"}, dir...), "params") + filename := filepath.Join(dirs...) + + data, err := ioutil.ReadFile(filename) + require.NoError(t, err, "unable to read params at %q", filename) + + params := pgs.ParseParameters(strings.TrimSpace(string(data))) + return InitContext(params) +} diff --git a/lang/go/name.go b/lang/go/name.go new file mode 100644 index 0000000..72ffbc9 --- /dev/null +++ b/lang/go/name.go @@ -0,0 +1,90 @@ +package pgsgo + +import ( + "fmt" + + "github.com/golang/protobuf/protoc-gen-go/generator" + "github.com/lyft/protoc-gen-star" +) + +func (c context) Name(node pgs.Node) pgs.Name { + // Message or Enum + type ChildEntity interface { + Name() pgs.Name + Parent() pgs.ParentEntity + } + + switch en := node.(type) { + case pgs.Package: // the package name for the first file (should be consistent) + return c.PackageName(en) + case pgs.File: // the package name for this file + return c.PackageName(en) + case ChildEntity: // Message or Enum types, which may be nested + n := pggUpperCamelCase(en.Name()) + if p, ok := en.Parent().(pgs.Message); ok { + n = pgs.Name(joinNames(c.Name(p), n)) + } + return n + case pgs.Field: // field names cannot conflict with other generated methods + return 
replaceProtected(pggUpperCamelCase(en.Name())) + case pgs.OneOf: // oneof field names cannot conflict with other generated methods + return replaceProtected(pggUpperCamelCase(en.Name())) + case pgs.EnumValue: // EnumValue are prefixed with the enum name + return pgs.Name(joinNames(c.Name(en.Enum()), en.Name())) + case pgs.Service: // always return the server name + return c.ServerName(en) + case pgs.Entity: // any other entity should be just upper-camel-cased + return pggUpperCamelCase(en.Name()) + default: + panic("unreachable") + } +} + +func (c context) OneofOption(field pgs.Field) pgs.Name { + return pgs.Name(joinNames(c.Name(field.Message()), c.Name(field))) +} + +func (c context) ServerName(s pgs.Service) pgs.Name { + n := pggUpperCamelCase(s.Name()) + return pgs.Name(fmt.Sprintf("%sServer", n)) +} + +func (c context) ClientName(s pgs.Service) pgs.Name { + n := pggUpperCamelCase(s.Name()) + return pgs.Name(fmt.Sprintf("%sClient", n)) +} + +// pggUpperCamelCase converts Name n to the protoc-gen-go defined upper +// camelcase. The rules are slightly different from pgs.UpperCamelCase in that +// leading underscores are converted to 'X', mid-string underscores followed by +// lowercase letters are removed and the letter is capitalized, all other +// punctuation is preserved. This method should be used when deriving names of +// protoc-gen-go generated code (ie, message/service struct names and field +// names). 
+// +// See: https://godoc.org/github.com/golang/protobuf/protoc-gen-go/generator#CamelCase +func pggUpperCamelCase(n pgs.Name) pgs.Name { + return pgs.Name(generator.CamelCase(n.String())) +} + +var protectedNames = map[pgs.Name]pgs.Name{ + "Reset": "Reset_", + "String": "String_", + "ProtoMessage": "ProtoMessage_", + "Marshal": "Marshal_", + "Unmarshal": "Unmarshal_", + "ExtensionRangeArray": "ExtensionRangeArray_", + "ExtensionMap": "ExtensionMap_", + "Descriptor": "Descriptor_", +} + +func replaceProtected(n pgs.Name) pgs.Name { + if use, protected := protectedNames[n]; protected { + return use + } + return n +} + +func joinNames(a, b pgs.Name) pgs.Name { + return pgs.Name(fmt.Sprintf("%s_%s", a, b)) +} diff --git a/lang/go/name_test.go b/lang/go/name_test.go new file mode 100644 index 0000000..d18a46d --- /dev/null +++ b/lang/go/name_test.go @@ -0,0 +1,204 @@ +package pgsgo + +import ( + "testing" + + "github.com/lyft/protoc-gen-star" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPGGUpperCamelCase(t *testing.T) { + t.Parallel() + + tests := []struct { + in string + ex string + }{ + {"foo_bar", "FooBar"}, + {"myJSON", "MyJSON"}, + {"PDFTemplate", "PDFTemplate"}, + {"_my_field_name_2", "XMyFieldName_2"}, + {"my.field", "My.field"}, + {"my_Field", "My_Field"}, + } + + for _, tc := range tests { + assert.Equal(t, tc.ex, pggUpperCamelCase(pgs.Name(tc.in)).String()) + } +} + +func TestName(t *testing.T) { + t.Parallel() + + ast := buildGraph(t, "names", "entities") + ctx := loadContext(t, "names", "entities") + + f := ast.Targets()["entities.proto"] + assert.Equal(t, pgs.Name("entities"), ctx.Name(f)) + assert.Equal(t, pgs.Name("entities"), ctx.Name(f.Package())) + + assert.Panics(t, func() { + ctx.Name(nil) + }) + + tests := []struct { + entity string + expected pgs.Name + }{ + // Top-Level Messages + {"UpperCamelCaseMessage", "UpperCamelCaseMessage"}, + {"lowerCamelCaseMessage", "LowerCamelCaseMessage"}, + 
{"SCREAMING_SNAKE_CASE", "SCREAMING_SNAKE_CASE"}, + {"Upper_Snake_Case", "Upper_Snake_Case"}, + {"lower_snake_case", "LowerSnakeCase"}, + {"lowercase", "Lowercase"}, + {"UPPERCASE", "UPPERCASE"}, + {"_underscore", "XUnderscore"}, + {"__DoubleUnderscore", "X_DoubleUnderscore"}, + {"String", "String"}, + + // Nested Messages + {"Nested.Message", "Nested_Message"}, + {"Nested._underscore", "Nested_XUnderscore"}, + {"Nested.String", "Nested_String"}, + {"Nested.Message.Message", "Nested_Message_Message"}, + + // Enums + {"UpperCamelCaseEnum", "UpperCamelCaseEnum"}, + {"lowerCamelCaseEnum", "LowerCamelCaseEnum"}, + {"SCREAMING_SNAKE_ENUM", "SCREAMING_SNAKE_ENUM"}, + {"lower_snake_enum", "LowerSnakeEnum"}, + {"Upper_Snake_Enum", "Upper_Snake_Enum"}, + + // EnumValues + {"UpperCamelCaseEnum.SCREAMING_SNAKE_CASE_ENUM_VALUE", "UpperCamelCaseEnum_SCREAMING_SNAKE_CASE_ENUM_VALUE"}, + {"UpperCamelCaseEnum.lower_snake_case_enum_value", "UpperCamelCaseEnum_lower_snake_case_enum_value"}, + {"UpperCamelCaseEnum.Upper_Snake_Case_Enum_Value", "UpperCamelCaseEnum_Upper_Snake_Case_Enum_Value"}, + {"UpperCamelCaseEnum.UpperCamelCaseEnumValue", "UpperCamelCaseEnum_UpperCamelCaseEnumValue"}, + {"UpperCamelCaseEnum.lowerCamelCaseEnumValue", "UpperCamelCaseEnum_lowerCamelCaseEnumValue"}, + {"lowerCamelCaseEnum.LCC_Value", "LowerCamelCaseEnum_LCC_Value"}, + {"SCREAMING_SNAKE_ENUM.SS_Value", "SCREAMING_SNAKE_ENUM_SS_Value"}, + {"lower_snake_enum.LS_Value", "LowerSnakeEnum_LS_Value"}, + {"Upper_Snake_Enum.US_Value", "Upper_Snake_Enum_US_Value"}, + + // Nested Enums + {"Nested.Enum", "Nested_Enum"}, + {"Nested.Enum.VALUE", "Nested_Enum_VALUE"}, + {"Nested.Message.Enum", "Nested_Message_Enum"}, + {"Nested.Message.Enum.VALUE", "Nested_Message_Enum_VALUE"}, + + // Field Names + {"Fields.lower_snake_case", "LowerSnakeCase"}, + {"Fields.Upper_Snake_Case", "Upper_Snake_Case"}, + {"Fields.SCREAMING_SNAKE_CASE", "SCREAMING_SNAKE_CASE"}, + {"Fields.lowerCamelCase", "LowerCamelCase"}, + 
{"Fields.UpperCamelCase", "UpperCamelCase"}, + {"Fields.string", "String_"}, + + // OneOfs + {"Oneofs.lower_snake_case", "LowerSnakeCase"}, + {"Oneofs.Upper_Snake_Case", "Upper_Snake_Case"}, + {"Oneofs.SCREAMING_SNAKE_CASE", "SCREAMING_SNAKE_CASE"}, + {"Oneofs.lowerCamelCase", "LowerCamelCase"}, + {"Oneofs.UpperCamelCase", "UpperCamelCase"}, + {"Oneofs.string", "String_"}, + {"Oneofs.oneof", "Oneof"}, + + // Services (always the Server name) + {"UpperCamelService", "UpperCamelServiceServer"}, + {"lowerCamelService", "LowerCamelServiceServer"}, + {"lower_snake_service", "LowerSnakeServiceServer"}, + {"Upper_Snake_Service", "Upper_Snake_ServiceServer"}, + {"SCREAMING_SNAKE_SERVICE", "SCREAMING_SNAKE_SERVICEServer"}, + {"reset", "ResetServer"}, + + // Methods + {"Service.UpperCamel", "UpperCamel"}, + {"Service.lowerCamel", "LowerCamel"}, + {"Service.lower_snake", "LowerSnake"}, + {"Service.Upper_Snake", "Upper_Snake"}, + {"Service.SCREAMING_SNAKE", "SCREAMING_SNAKE"}, + {"Service.Reset", "Reset"}, + } + + for _, test := range tests { + tc := test + t.Run(tc.entity, func(t *testing.T) { + t.Parallel() + + e, ok := ast.Lookup(".names.entities." 
+ tc.entity) + require.True(t, ok, "could not locate entity") + assert.Equal(t, tc.expected, ctx.Name(e)) + }) + } +} + +func TestContext_OneofOption(t *testing.T) { + t.Parallel() + + ast := buildGraph(t, "names", "entities") + ctx := loadContext(t, "names", "entities") + + tests := []struct { + field string + expected pgs.Name + }{ + {"LS", "Oneofs_LS"}, + {"US", "Oneofs_US"}, + {"SS", "Oneofs_SS"}, + {"LC", "Oneofs_LC"}, + {"UC", "Oneofs_UC"}, + {"S", "Oneofs_S"}, + {"lower_snake_case_o", "Oneofs_LowerSnakeCaseO"}, + {"Upper_Snake_Case_O", "Oneofs_Upper_Snake_Case_O"}, + {"SCREAMING_SNAKE_CASE_O", "Oneofs_SCREAMING_SNAKE_CASE_O"}, + {"lowerCamelCaseO", "Oneofs_LowerCamelCaseO"}, + {"UpperCamelCaseO", "Oneofs_UpperCamelCaseO"}, + {"reset", "Oneofs_Reset_"}, + } + + for _, test := range tests { + tc := test + t.Run(tc.field, func(t *testing.T) { + t.Parallel() + + e, ok := ast.Lookup(".names.entities.Oneofs." + tc.field) + require.True(t, ok, "could not find field") + f := e.(pgs.Field) + assert.Equal(t, tc.expected, ctx.OneofOption(f)) + }) + } + +} + +func TestContext_ClientName(t *testing.T) { + t.Parallel() + + ast := buildGraph(t, "names", "entities") + ctx := loadContext(t, "names", "entities") + + tests := []struct { + service string + expected pgs.Name + }{ + {"UpperCamelService", "UpperCamelServiceClient"}, + {"lowerCamelService", "LowerCamelServiceClient"}, + {"lower_snake_service", "LowerSnakeServiceClient"}, + {"Upper_Snake_Service", "Upper_Snake_ServiceClient"}, + {"SCREAMING_SNAKE_SERVICE", "SCREAMING_SNAKE_SERVICEClient"}, + {"reset", "ResetClient"}, + } + + for _, test := range tests { + tc := test + t.Run(tc.service, func(t *testing.T) { + t.Parallel() + + e, ok := ast.Lookup(".names.entities." 
+ tc.service) + require.True(t, ok, "could not find service") + s := e.(pgs.Service) + assert.Equal(t, tc.expected, ctx.ClientName(s)) + }) + } +} diff --git a/lang/go/package.go b/lang/go/package.go new file mode 100644 index 0000000..27aa70a --- /dev/null +++ b/lang/go/package.go @@ -0,0 +1,94 @@ +package pgsgo + +import ( + "go/token" + "regexp" + "strings" + + "github.com/lyft/protoc-gen-star" +) + +var nonAlphaNumPattern = regexp.MustCompile("[^a-zA-Z0-9]") + +func (c context) PackageName(node pgs.Node) pgs.Name { + e, ok := node.(pgs.Entity) + if !ok { + e = node.(pgs.Package).Files()[0] + } + + _, pkg := c.optionPackage(e) + + if ip := c.p.Str("import_path"); ip != "" { + pkg = ip + } + + // if the package name is a Go keyword, prefix with '_' + if token.Lookup(pkg).IsKeyword() { + pkg = "_" + pkg + } + + // package name is kosher + return pgs.Name(pkg) +} + +func (c context) ImportPath(e pgs.Entity) pgs.FilePath { + path, _ := c.optionPackage(e) + path = c.p.Str("import_prefix") + path + return pgs.FilePath(path) +} + +func (c context) OutputPath(e pgs.Entity) pgs.FilePath { + out := e.File().InputPath().SetExt(".pb.go") + + // source relative doesn't try to be fancy + if Paths(c.p) == SourceRelative { + return out + } + + path, _ := c.optionPackage(e) + + // Import relative ignores the existing file structure + return pgs.FilePath(path).Push(out.Base()) +} + +func (c context) optionPackage(e pgs.Entity) (path, pkg string) { + // M mapping param overrides everything IFF the entity is not a build target + if override, ok := c.p["M"+e.File().InputPath().String()]; ok && !e.BuildTarget() { + path = override + pkg = override + if idx := strings.LastIndex(pkg, "/"); idx > -1 { + pkg = pkg[idx+1:] + } + return + } + + // check if there's a go_package option specified + pkg = e.File().Descriptor().GetOptions().GetGoPackage() + path = e.File().InputPath().Dir().String() + + if pkg == "" { + // have a proto package name, so use that + if n := 
e.Package().ProtoName(); n != "" { + pkg = n.SnakeCase().String() + } else { // no other info, then replace all non-alphanumerics from the input file name + pkg = nonAlphaNumPattern.ReplaceAllString(e.File().InputPath().BaseName(), "_") + } + return + } + + // go_package="example.com/foo/bar;baz" should have a package name of `baz` + if idx := strings.LastIndex(pkg, ";"); idx > -1 { + path = pkg[:idx] + pkg = pkg[idx+1:] + return + } + + // go_package="example.com/foo/bar" should have a package name of `bar` + if idx := strings.LastIndex(pkg, "/"); idx > -1 { + path = pkg + pkg = pkg[idx+1:] + return + } + + return +} diff --git a/lang/go/package_test.go b/lang/go/package_test.go new file mode 100644 index 0000000..59598c9 --- /dev/null +++ b/lang/go/package_test.go @@ -0,0 +1,131 @@ +package pgsgo + +import ( + "testing" + + "github.com/lyft/protoc-gen-star" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPackageName(t *testing.T) { + t.Parallel() + + tests := []struct { + dir string + expected pgs.Name + }{ + {"keyword", "_package"}, // go keywords are prefixed with _ + {"none", "NO_pack__age_name_"}, // if there is no package or go_package option, use the input filepath + {"package", "my_package"}, // use the go_package option + {"unnamed", "names_unnamed"}, // use the proto package if no go_package + {"import", "bar"}, // uses the basename if go_package contains a / + {"override", "baz"}, // if go_package contains ;, use everything to the right + {"import_path", "_package"}, // import_path param used if no go_package option + {"mapped", "unaffected"}, // M mapped params are ignored for build targets + } + + for _, test := range tests { + tc := test + t.Run(tc.dir, func(t *testing.T) { + t.Parallel() + + ast := buildGraph(t, "names", tc.dir) + ctx := loadContext(t, "names", tc.dir) + + for _, target := range ast.Targets() { + assert.Equal(t, tc.expected, ctx.PackageName(target)) + } + }) + } +} + +func 
TestImportPath(t *testing.T) { + t.Parallel() + + tests := []struct { + dir string + + fully pgs.FilePath + unqualified pgs.FilePath + none pgs.FilePath + }{ + { // no params changing the behavior of the import paths + "no_options", + "example.com/packages/targets/fully_qualified", + "targets/unqualified", + "targets/none", + }, + { // M params provided for each imported package + "mapped", + "example.com/foo/bar", + "example.com/fizz/buzz", + "example.com/quux", + }, + { // import_prefix param prefixes everything...pretty much doesn't work since it also prefixes the proto package + "import_prefix", + "foo.bar/example.com/packages/targets/fully_qualified", + "foo.bar/targets/unqualified", + "foo.bar/fizz/buzz", + }, + } + + for _, test := range tests { + tc := test + t.Run(tc.dir, func(t *testing.T) { + t.Parallel() + + ast := buildGraph(t, "packages", tc.dir) + ctx := loadContext(t, "packages", tc.dir) + + pkgs := map[string]pgs.FilePath{ + "packages.targets.fully_qualified": tc.fully, + "packages.targets.unqualified": tc.unqualified, + "packages.targets.none": tc.none, + } + + for pkg, expected := range pkgs { + t.Run(pkg, func(t *testing.T) { + p, ok := ast.Packages()[pkg] + require.True(t, ok, "package not found") + f := p.Files()[0] + assert.Equal(t, expected, ctx.ImportPath(f)) + }) + } + }) + } +} + +func TestOutputPath(t *testing.T) { + t.Parallel() + + tests := []struct { + dir, file string + expected pgs.FilePath + }{ + {"none", "none.proto", "none.pb.go"}, + {"none_srcrel", "none.proto", "none.pb.go"}, + {"unqualified", "unqualified.proto", "unqualified.pb.go"}, + {"unqualified_srcrel", "unqualified.proto", "unqualified.pb.go"}, + {"qualified", "qualified.proto", "example.com/qualified/qualified.pb.go"}, + {"qualified_srcrel", "qualified.proto", "qualified.pb.go"}, + {"import_prefix", "prefix.proto", "example.com/import_prefix/prefix.pb.go"}, + {"import_prefix_srcrel", "prefix.proto", "prefix.pb.go"}, + {"mapped", "mapped.proto", "mapped.pb.go"}, + 
{"mapped_srcrel", "mapped.proto", "mapped.pb.go"}, + } + + for _, test := range tests { + tc := test + t.Run(tc.dir, func(t *testing.T) { + t.Parallel() + + ast := buildGraph(t, "outputs", tc.dir) + ctx := loadContext(t, "outputs", tc.dir) + f, ok := ast.Lookup(tc.file) + require.True(t, ok, "file not found") + assert.Equal(t, tc.expected, ctx.OutputPath(f)) + }) + } + +} diff --git a/lang/go/parameters.go b/lang/go/parameters.go new file mode 100644 index 0000000..23b0c6f --- /dev/null +++ b/lang/go/parameters.go @@ -0,0 +1,132 @@ +package pgsgo + +import ( + "fmt" + "strings" + + "github.com/lyft/protoc-gen-star" +) + +const ( + importPrefixKey = "import_prefix" + importPathKey = "import_path" + importMapKeyPrefix = "M" + pathTypeKey = "paths" + pluginsKey = "plugins" + pluginsSep = "+" +) + +// PathType describes how the generated output file paths should be constructed. +type PathType string + +const ( + // ImportPathRelative is the default and outputs the file based off the go + // import path defined in the go_package option. + ImportPathRelative PathType = "" + + // SourceRelative indicates files should be output relative to the path of + // the source file. + SourceRelative PathType = "source_relative" +) + +// Plugins returns the sub-plugins enabled for this protoc plugin. If the all +// value is true, all registered plugins are considered enabled (ie, protoc was +// called with an empty "plugins" parameter). Otherwise, plugins contains the +// list of plugins enabled by name. +func Plugins(p pgs.Parameters) (plugins []string, all bool) { + s, ok := p[pluginsKey] + if !ok { + return + } + + if all = s == ""; all { + return + } + + plugins = strings.Split(s, pluginsSep) + return +} + +// HasPlugin returns true if the plugin name is enabled in the parameters. This +// method will always return true if all plugins are enabled. 
+func HasPlugin(p pgs.Parameters, name string) bool { + plugins, all := Plugins(p) + if all { + return true + } + + for _, pl := range plugins { + if pl == name { + return true + } + } + + return false +} + +// AddPlugin adds name to the list of plugins in the parameters. If all plugins +// are enabled, this method is a noop. +func AddPlugin(p pgs.Parameters, name ...string) { + if len(name) == 0 { + return + } + + plugins, all := Plugins(p) + if all { + return + } + + p.SetStr(pluginsKey, strings.Join(append(plugins, name...), pluginsSep)) +} + +// EnableAllPlugins changes the parameters to enable all registered sub-plugins. +func EnableAllPlugins(p pgs.Parameters) { p.SetStr(pluginsKey, "") } + +// ImportPrefix returns the protoc-gen-go parameter. This prefix is added onto +// the beginning of all Go import paths. This is useful for things like +// generating protos in a subdirectory, or regenerating vendored protobufs +// in-place. By default, this method returns an empty string. +// +// See: https://github.com/golang/protobuf#parameters +func ImportPrefix(p pgs.Parameters) string { return p.Str(importPrefixKey) } + +// SetImportPrefix sets the protoc-gen-go ImportPrefix parameter. This is +// useful for overriding the behavior of the ImportPrefix at runtime. +func SetImportPrefix(p pgs.Parameters, prefix string) { p.SetStr(importPrefixKey, prefix) } + +// ImportPath returns the protoc-gen-go parameter. This value is used as the +// package if the input proto files do not declare a go_package option. If it +// contains slashes, everything up to the rightmost slash is ignored. +// +// See: https://github.com/golang/protobuf#parameters +func ImportPath(p pgs.Parameters) string { return p.Str(importPathKey) } + +// SetImportPath sets the protoc-gen-go ImportPath parameter. This is useful +// for overriding the behavior of the ImportPath at runtime. 
+func SetImportPath(p pgs.Parameters, path string) { p.SetStr(importPathKey, path) } + +// Paths returns the protoc-gen-go parameter. This value is used to switch the +// mode used to determine the output paths of the generated code. By default, +// paths are derived from the import path specified by go_package. It can be +// overridden to be "source_relative", ignoring the import path using the +// source path exclusively. +func Paths(p pgs.Parameters) PathType { return PathType(p.Str(pathTypeKey)) } + +// SetPaths sets the protoc-gen-go Paths parameter. This is useful for +// overriding the behavior of Paths at runtime. +func SetPaths(p pgs.Parameters, pt PathType) { p.SetStr(pathTypeKey, string(pt)) } + +// MappedImport returns the protoc-gen-go import overrides for the specified proto +// file. Each entry in the map keys off a proto file (as loaded by protoc) with +// values of the Go package to use. These values will be prefixed with the +// value of ImportPrefix when generating the Go code. +func MappedImport(p pgs.Parameters, proto string) (string, bool) { + imp, ok := p[fmt.Sprintf("%s%s", importMapKeyPrefix, proto)] + return imp, ok +} + +// AddImportMapping adds a proto file to Go package import mapping to the +// parameters. 
+func AddImportMapping(p pgs.Parameters, proto, pkg string) { + p[fmt.Sprintf("%s%s", importMapKeyPrefix, proto)] = pkg +} diff --git a/lang/go/parameters_test.go b/lang/go/parameters_test.go new file mode 100644 index 0000000..f5a9a6b --- /dev/null +++ b/lang/go/parameters_test.go @@ -0,0 +1,134 @@ +package pgsgo + +import ( + "testing" + + "github.com/lyft/protoc-gen-star" + "github.com/stretchr/testify/assert" +) + +func TestParameters_Plugins(t *testing.T) { + t.Parallel() + + p := pgs.Parameters{} + plugins, all := Plugins(p) + assert.Empty(t, plugins) + assert.False(t, all) + + p[pluginsKey] = "foo+bar" + plugins, all = Plugins(p) + assert.Equal(t, []string{"foo", "bar"}, plugins) + assert.False(t, all) + + p[pluginsKey] = "" + plugins, all = Plugins(p) + assert.Empty(t, plugins) + assert.True(t, all) +} + +func TestParameters_HasPlugin(t *testing.T) { + t.Parallel() + + p := pgs.Parameters{} + assert.False(t, HasPlugin(p, "foo")) + + p[pluginsKey] = "foo" + assert.True(t, HasPlugin(p, "foo")) + + p[pluginsKey] = "" + assert.True(t, HasPlugin(p, "foo")) + + p[pluginsKey] = "bar" + assert.False(t, HasPlugin(p, "foo")) +} + +func TestParameters_AddPlugin(t *testing.T) { + t.Parallel() + + p := pgs.Parameters{} + AddPlugin(p, "foo", "bar") + assert.Equal(t, "foo+bar", p[pluginsKey]) + + AddPlugin(p, "baz") + assert.Equal(t, "foo+bar+baz", p[pluginsKey]) + + AddPlugin(p) + assert.Equal(t, "foo+bar+baz", p[pluginsKey]) + + p[pluginsKey] = "" + AddPlugin(p, "fizz", "buzz") + assert.Equal(t, "", p[pluginsKey]) +} + +func TestParameters_EnableAllPlugins(t *testing.T) { + t.Parallel() + + p := pgs.Parameters{pluginsKey: "foo"} + _, all := Plugins(p) + assert.False(t, all) + + EnableAllPlugins(p) + _, all = Plugins(p) + assert.True(t, all) +} + +func TestParameters_ImportPrefix(t *testing.T) { + t.Parallel() + + p := pgs.Parameters{} + assert.Empty(t, ImportPrefix(p)) + SetImportPrefix(p, "foo") + assert.Equal(t, "foo", ImportPrefix(p)) +} + +func 
TestParameters_ImportPath(t *testing.T) { + t.Parallel() + + p := pgs.Parameters{} + assert.Empty(t, ImportPath(p)) + SetImportPath(p, "foo") + assert.Equal(t, "foo", ImportPath(p)) +} + +func TestParameters_ImportMap(t *testing.T) { + t.Parallel() + + p := pgs.Parameters{ + "Mfoo.proto": "bar", + "Mfizz/buzz.proto": "baz", + } + + AddImportMapping(p, "quux.proto", "shme") + + tests := []struct { + proto, path string + exists bool + }{ + {"quux.proto", "shme", true}, + {"foo.proto", "bar", true}, + {"fizz/buzz.proto", "baz", true}, + {"abcde.proto", "", false}, + } + + for _, test := range tests { + t.Run(test.proto, func(t *testing.T) { + path, ok := MappedImport(p, test.proto) + if test.exists { + assert.True(t, ok) + assert.Equal(t, test.path, path) + } else { + assert.False(t, ok) + } + }) + } +} + +func TestParameters_Paths(t *testing.T) { + t.Parallel() + + p := pgs.Parameters{} + + assert.Equal(t, ImportPathRelative, Paths(p)) + SetPaths(p, SourceRelative) + assert.Equal(t, SourceRelative, Paths(p)) +} diff --git a/lang/go/testdata/names/entities/entities.proto b/lang/go/testdata/names/entities/entities.proto new file mode 100644 index 0000000..ce2cbe4 --- /dev/null +++ b/lang/go/testdata/names/entities/entities.proto @@ -0,0 +1,108 @@ +syntax = "proto3"; + +package names.entities; +option go_package = "entities"; + +message UpperCamelCaseMessage {} + +message lowerCamelCaseMessage {} + +message SCREAMING_SNAKE_CASE {} + +message Upper_Snake_Case {} + +message lower_snake_case {} + +message lowercase {} + +message UPPERCASE {} + +message _underscore {} + +message __DoubleUnderscore {} + +message String {} // protected name + +message Nested { + message Message { + message Message {} + + enum Enum { VALUE = 0; } + } + + message _underscore {} + + message String {} // protected name + + enum Enum { VALUE = 0; } +} + +enum UpperCamelCaseEnum { + SCREAMING_SNAKE_CASE_ENUM_VALUE = 0; + lower_snake_case_enum_value = 1; + Upper_Snake_Case_Enum_Value = 2; + 
UpperCamelCaseEnumValue = 3; + lowerCamelCaseEnumValue = 4; +} + +enum lowerCamelCaseEnum {LCC_Value = 0;} + +enum SCREAMING_SNAKE_ENUM {SS_Value = 0;} + +enum lower_snake_enum {LS_Value = 0;} + +enum Upper_Snake_Enum {US_Value = 0;} + +message Fields { + bool lower_snake_case = 1; + bool Upper_Snake_Case = 2; + bool SCREAMING_SNAKE_CASE = 3; + bool lowerCamelCase = 4; + bool UpperCamelCase = 5; + bool string = 6; // protected name +} + +message Oneofs { + oneof lower_snake_case {bool LS = 1;} + + oneof Upper_Snake_Case {bool US = 2;} + + oneof SCREAMING_SNAKE_CASE {bool SS = 3;} + + oneof lowerCamelCase {bool LC = 4;} + + oneof UpperCamelCase {bool UC = 5;} + + // protected + oneof string {bool S = 6;} + + oneof oneof { + bool lower_snake_case_o = 7; + bool Upper_Snake_Case_O = 8; + bool SCREAMING_SNAKE_CASE_O = 9; + bool lowerCamelCaseO = 10; + bool UpperCamelCaseO = 11; + bool reset = 12; // protected + } +} + +service UpperCamelService {} + +service lowerCamelService {} + +service lower_snake_service {} + +service Upper_Snake_Service {} + +service SCREAMING_SNAKE_SERVICE {} + +service reset {} + +service Service { + rpc UpperCamel(stream String) returns (String); + rpc lowerCamel(String) returns (stream String); + rpc lower_snake(stream String) returns (stream String); + rpc Upper_Snake(String) returns (String); + rpc SCREAMING_SNAKE(String) returns (String); + rpc Reset(String) returns (String); +} diff --git a/lang/go/testdata/names/entities/params b/lang/go/testdata/names/entities/params new file mode 100644 index 0000000..e69de29 diff --git a/lang/go/testdata/names/import/import.proto b/lang/go/testdata/names/import/import.proto new file mode 100644 index 0000000..00d77a7 --- /dev/null +++ b/lang/go/testdata/names/import/import.proto @@ -0,0 +1,4 @@ +syntax="proto3"; + +package names.import; +option go_package = "example.com/foo/bar"; diff --git a/lang/go/testdata/names/import/params b/lang/go/testdata/names/import/params new file mode 100644 index 
0000000..e69de29 diff --git a/lang/go/testdata/names/import_path/import_path.proto b/lang/go/testdata/names/import_path/import_path.proto new file mode 100644 index 0000000..245b082 --- /dev/null +++ b/lang/go/testdata/names/import_path/import_path.proto @@ -0,0 +1,4 @@ +syntax="proto3"; +package names.import_path; + +message ImportPath {} diff --git a/lang/go/testdata/names/import_path/params b/lang/go/testdata/names/import_path/params new file mode 100644 index 0000000..e4626a4 --- /dev/null +++ b/lang/go/testdata/names/import_path/params @@ -0,0 +1 @@ +Mnames/import_path/import_path.proto=foobar,import_path=package diff --git a/lang/go/testdata/names/keyword/keyword.proto b/lang/go/testdata/names/keyword/keyword.proto new file mode 100644 index 0000000..2606940 --- /dev/null +++ b/lang/go/testdata/names/keyword/keyword.proto @@ -0,0 +1,6 @@ +syntax="proto3"; + +package names.keyword; +option go_package = "package"; + +message Package {} diff --git a/lang/go/testdata/names/keyword/params b/lang/go/testdata/names/keyword/params new file mode 100644 index 0000000..e69de29 diff --git a/lang/go/testdata/names/mapped/mapped.proto b/lang/go/testdata/names/mapped/mapped.proto new file mode 100644 index 0000000..b00f493 --- /dev/null +++ b/lang/go/testdata/names/mapped/mapped.proto @@ -0,0 +1,5 @@ +syntax="proto3"; +package names.mapped; +option go_package="unaffected"; + +message Mapped {} diff --git a/lang/go/testdata/names/mapped/params b/lang/go/testdata/names/mapped/params new file mode 100644 index 0000000..b47ee5a --- /dev/null +++ b/lang/go/testdata/names/mapped/params @@ -0,0 +1 @@ +Mmapped.proto=foobar diff --git a/lang/go/testdata/names/none/NO.pack--age.name$.proto b/lang/go/testdata/names/none/NO.pack--age.name$.proto new file mode 100644 index 0000000..c99805e --- /dev/null +++ b/lang/go/testdata/names/none/NO.pack--age.name$.proto @@ -0,0 +1,3 @@ +syntax="proto3"; + +message None {} diff --git a/lang/go/testdata/names/none/params 
b/lang/go/testdata/names/none/params new file mode 100644 index 0000000..e69de29 diff --git a/lang/go/testdata/names/override/override.proto b/lang/go/testdata/names/override/override.proto new file mode 100644 index 0000000..86a0289 --- /dev/null +++ b/lang/go/testdata/names/override/override.proto @@ -0,0 +1,4 @@ +syntax="proto3"; + +package names.override; +option go_package = "example.com/foo/bar;baz"; diff --git a/lang/go/testdata/names/override/params b/lang/go/testdata/names/override/params new file mode 100644 index 0000000..e69de29 diff --git a/lang/go/testdata/names/package/package.proto b/lang/go/testdata/names/package/package.proto new file mode 100644 index 0000000..ad7c2b1 --- /dev/null +++ b/lang/go/testdata/names/package/package.proto @@ -0,0 +1,6 @@ +syntax="proto3"; + +package names.package; +option go_package = "my_package"; + +message Package {} diff --git a/lang/go/testdata/names/package/params b/lang/go/testdata/names/package/params new file mode 100644 index 0000000..e69de29 diff --git a/lang/go/testdata/names/types/params b/lang/go/testdata/names/types/params new file mode 100644 index 0000000..e69de29 diff --git a/lang/go/testdata/names/types/proto2.proto b/lang/go/testdata/names/types/proto2.proto new file mode 100644 index 0000000..cae3f31 --- /dev/null +++ b/lang/go/testdata/names/types/proto2.proto @@ -0,0 +1,65 @@ +syntax="proto2"; +package names.types; +option go_package = "example.com/foo/bar"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/type.proto"; + +message Proto2 { + optional double double = 1; + optional float float = 2; + optional int64 int64 = 3; + optional sfixed64 sfixed64 = 4; + optional sint64 sint64 = 5; + optional uint64 uint64 = 6; + optional fixed64 fixed64 = 7; + optional int32 int32 = 8; + optional sfixed32 sfixed32 = 9; + optional sint32 sint32 = 10; + optional uint32 uint32 = 11; + optional fixed32 fixed32 = 12; + optional bool bool = 13; + optional string string = 14; + optional bytes 
bytes = 15; + + optional Enum enum = 16; + optional google.protobuf.Syntax ext_enum = 17; + optional Required msg = 18; + optional google.protobuf.Duration ext_msg = 19; + + repeated double repeated_scalar = 20; + repeated Enum repeated_enum = 21; + repeated google.protobuf.Syntax repeated_ext_enum = 22; + repeated Required repeated_msg = 23; + repeated google.protobuf.Duration repeated_ext_msg = 24; + + map map_scalar = 25; + map map_enum = 26; + map map_ext_enum = 27; + map map_msg = 28; + map map_ext_msg = 29; + + enum Enum {VALUE = 0;} + + message Required { + required double double = 1; + required float float = 2; + required int64 int64 = 3; + required sfixed64 sfixed64 = 4; + required sint64 sint64 = 5; + required uint64 uint64 = 6; + required fixed64 fixed64 = 7; + required int32 int32 = 8; + required sfixed32 sfixed32 = 9; + required sint32 sint32 = 10; + required uint32 uint32 = 11; + required fixed32 fixed32 = 12; + required bool bool = 13; + required string string = 14; + required bytes bytes = 15; + required Enum enum = 16; + required google.protobuf.Syntax ext_enum = 17; + required Required msg = 18; + required google.protobuf.Duration ext_msg = 19; + } +} diff --git a/lang/go/testdata/names/types/proto3.proto b/lang/go/testdata/names/types/proto3.proto new file mode 100644 index 0000000..a797dd9 --- /dev/null +++ b/lang/go/testdata/names/types/proto3.proto @@ -0,0 +1,45 @@ +syntax="proto3"; +package names.types; +option go_package = "example.com/foo/bar"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/type.proto"; + +message Proto3 { + double double = 1; + float float = 2; + int64 int64 = 3; + sfixed64 sfixed64 = 4; + sint64 sint64 = 5; + uint64 uint64 = 6; + fixed64 fixed64 = 7; + int32 int32 = 8; + sfixed32 sfixed32 = 9; + sint32 sint32 = 10; + uint32 uint32 = 11; + fixed32 fixed32 = 12; + bool bool = 13; + string string = 14; + bytes bytes = 15; + + Enum enum = 16; + google.protobuf.Syntax ext_enum = 17; + Message msg = 18; + 
google.protobuf.Duration ext_msg = 19; + + repeated double repeated_scalar = 20; + repeated Enum repeated_enum = 21; + repeated google.protobuf.Syntax repeated_ext_enum = 22; + repeated Message repeated_msg = 23; + repeated google.protobuf.Duration repeated_ext_msg = 24; + + map map_scalar = 25; + map map_enum = 26; + map map_ext_enum = 27; + map map_msg = 28; + map map_ext_msg = 29; + + enum Enum {VALUE = 0;} + + message Message {} +} diff --git a/lang/go/testdata/names/unnamed/params b/lang/go/testdata/names/unnamed/params new file mode 100644 index 0000000..e69de29 diff --git a/lang/go/testdata/names/unnamed/unnamed.proto b/lang/go/testdata/names/unnamed/unnamed.proto new file mode 100644 index 0000000..a674e38 --- /dev/null +++ b/lang/go/testdata/names/unnamed/unnamed.proto @@ -0,0 +1,5 @@ +syntax = "proto3"; + +package names.unnamed; + +message Unnamed {} diff --git a/lang/go/testdata/outputs/import_prefix/params b/lang/go/testdata/outputs/import_prefix/params new file mode 100644 index 0000000..187385e --- /dev/null +++ b/lang/go/testdata/outputs/import_prefix/params @@ -0,0 +1 @@ +import_prefix=foo/ diff --git a/lang/go/testdata/outputs/import_prefix/prefix.proto b/lang/go/testdata/outputs/import_prefix/prefix.proto new file mode 100644 index 0000000..eeca6bd --- /dev/null +++ b/lang/go/testdata/outputs/import_prefix/prefix.proto @@ -0,0 +1,4 @@ +syntax="proto3"; +package outputs.import_prefix; +option go_package = "example.com/import_prefix"; +message ImportPrefix {} diff --git a/lang/go/testdata/outputs/import_prefix_srcrel/params b/lang/go/testdata/outputs/import_prefix_srcrel/params new file mode 100644 index 0000000..b2b1469 --- /dev/null +++ b/lang/go/testdata/outputs/import_prefix_srcrel/params @@ -0,0 +1,2 @@ +paths=source_relative,import_prefix=foo/ + diff --git a/lang/go/testdata/outputs/import_prefix_srcrel/prefix.proto b/lang/go/testdata/outputs/import_prefix_srcrel/prefix.proto new file mode 100644 index 0000000..eeca6bd --- /dev/null +++ 
b/lang/go/testdata/outputs/import_prefix_srcrel/prefix.proto @@ -0,0 +1,4 @@ +syntax="proto3"; +package outputs.import_prefix; +option go_package = "example.com/import_prefix"; +message ImportPrefix {} diff --git a/lang/go/testdata/outputs/mapped/mapped.proto b/lang/go/testdata/outputs/mapped/mapped.proto new file mode 100644 index 0000000..8970d7c --- /dev/null +++ b/lang/go/testdata/outputs/mapped/mapped.proto @@ -0,0 +1,5 @@ +syntax="proto3"; +package outputs.mapped; +option go_package="unaffected"; + +message Mapped {} diff --git a/lang/go/testdata/outputs/mapped/params b/lang/go/testdata/outputs/mapped/params new file mode 100644 index 0000000..b47ee5a --- /dev/null +++ b/lang/go/testdata/outputs/mapped/params @@ -0,0 +1 @@ +Mmapped.proto=foobar diff --git a/lang/go/testdata/outputs/mapped_srcrel/mapped.proto b/lang/go/testdata/outputs/mapped_srcrel/mapped.proto new file mode 100644 index 0000000..8970d7c --- /dev/null +++ b/lang/go/testdata/outputs/mapped_srcrel/mapped.proto @@ -0,0 +1,5 @@ +syntax="proto3"; +package outputs.mapped; +option go_package="unaffected"; + +message Mapped {} diff --git a/lang/go/testdata/outputs/mapped_srcrel/params b/lang/go/testdata/outputs/mapped_srcrel/params new file mode 100644 index 0000000..cf3c0e4 --- /dev/null +++ b/lang/go/testdata/outputs/mapped_srcrel/params @@ -0,0 +1 @@ +paths=source_relative,Mmapped.proto=foobar diff --git a/lang/go/testdata/outputs/none/none.proto b/lang/go/testdata/outputs/none/none.proto new file mode 100644 index 0000000..a7c41d7 --- /dev/null +++ b/lang/go/testdata/outputs/none/none.proto @@ -0,0 +1,3 @@ +syntax="proto3"; +package outputs.none; +message None {} diff --git a/lang/go/testdata/outputs/none/params b/lang/go/testdata/outputs/none/params new file mode 100644 index 0000000..e69de29 diff --git a/lang/go/testdata/outputs/none_srcrel/none.proto b/lang/go/testdata/outputs/none_srcrel/none.proto new file mode 100644 index 0000000..c99805e --- /dev/null +++ 
b/lang/go/testdata/outputs/none_srcrel/none.proto @@ -0,0 +1,3 @@ +syntax="proto3"; + +message None {} diff --git a/lang/go/testdata/outputs/none_srcrel/params b/lang/go/testdata/outputs/none_srcrel/params new file mode 100644 index 0000000..3d387d0 --- /dev/null +++ b/lang/go/testdata/outputs/none_srcrel/params @@ -0,0 +1 @@ +paths=source_relative diff --git a/lang/go/testdata/outputs/qualified/params b/lang/go/testdata/outputs/qualified/params new file mode 100644 index 0000000..e69de29 diff --git a/lang/go/testdata/outputs/qualified/qualified.proto b/lang/go/testdata/outputs/qualified/qualified.proto new file mode 100644 index 0000000..e821f85 --- /dev/null +++ b/lang/go/testdata/outputs/qualified/qualified.proto @@ -0,0 +1,5 @@ +syntax="proto3"; +package outputs.qualified; +option go_package="example.com/qualified"; + +message FullyQualified{} diff --git a/lang/go/testdata/outputs/qualified_srcrel/params b/lang/go/testdata/outputs/qualified_srcrel/params new file mode 100644 index 0000000..3d387d0 --- /dev/null +++ b/lang/go/testdata/outputs/qualified_srcrel/params @@ -0,0 +1 @@ +paths=source_relative diff --git a/lang/go/testdata/outputs/qualified_srcrel/qualified.proto b/lang/go/testdata/outputs/qualified_srcrel/qualified.proto new file mode 100644 index 0000000..e821f85 --- /dev/null +++ b/lang/go/testdata/outputs/qualified_srcrel/qualified.proto @@ -0,0 +1,5 @@ +syntax="proto3"; +package outputs.qualified; +option go_package="example.com/qualified"; + +message FullyQualified{} diff --git a/lang/go/testdata/outputs/unqualified/params b/lang/go/testdata/outputs/unqualified/params new file mode 100644 index 0000000..e69de29 diff --git a/lang/go/testdata/outputs/unqualified/unqualified.proto b/lang/go/testdata/outputs/unqualified/unqualified.proto new file mode 100644 index 0000000..866369f --- /dev/null +++ b/lang/go/testdata/outputs/unqualified/unqualified.proto @@ -0,0 +1,5 @@ +syntax="proto3"; +package outputs.unqualified; +option go_package="unqualified"; 
+ +message Unqualified{} diff --git a/lang/go/testdata/outputs/unqualified_srcrel/params b/lang/go/testdata/outputs/unqualified_srcrel/params new file mode 100644 index 0000000..3d387d0 --- /dev/null +++ b/lang/go/testdata/outputs/unqualified_srcrel/params @@ -0,0 +1 @@ +paths=source_relative diff --git a/lang/go/testdata/outputs/unqualified_srcrel/unqualified.proto b/lang/go/testdata/outputs/unqualified_srcrel/unqualified.proto new file mode 100644 index 0000000..866369f --- /dev/null +++ b/lang/go/testdata/outputs/unqualified_srcrel/unqualified.proto @@ -0,0 +1,5 @@ +syntax="proto3"; +package outputs.unqualified; +option go_package="unqualified"; + +message Unqualified{} diff --git a/lang/go/testdata/packages/import_prefix/import_prefix.proto b/lang/go/testdata/packages/import_prefix/import_prefix.proto new file mode 100644 index 0000000..b6fc199 --- /dev/null +++ b/lang/go/testdata/packages/import_prefix/import_prefix.proto @@ -0,0 +1,13 @@ +syntax="proto3"; +package packages.import_prefix; +option go_package="example.com/packages/import_prefix"; + +import "targets/fully_qualified/fully_qualified.proto"; +import "targets/unqualified/unqualified.proto"; +import "targets/none/none.proto"; + +message ImportPrefixed { + targets.fully_qualified.FullyQualified fully = 1; + targets.unqualified.Unqualified unqualified = 2; + targets.none.None none = 3; +} diff --git a/lang/go/testdata/packages/import_prefix/params b/lang/go/testdata/packages/import_prefix/params new file mode 100644 index 0000000..e25d003 --- /dev/null +++ b/lang/go/testdata/packages/import_prefix/params @@ -0,0 +1 @@ +Mtargets/none/none.proto=fizz/buzz,import_prefix=foo.bar/ diff --git a/lang/go/testdata/packages/mapped/mapped.proto b/lang/go/testdata/packages/mapped/mapped.proto new file mode 100644 index 0000000..8053ce0 --- /dev/null +++ b/lang/go/testdata/packages/mapped/mapped.proto @@ -0,0 +1,13 @@ +syntax="proto3"; +package packages.mapped; +option go_package="example.com/packages/mapped"; + 
+import "targets/fully_qualified/fully_qualified.proto"; +import "targets/unqualified/unqualified.proto"; +import "targets/none/none.proto"; + +message Mapped { + targets.fully_qualified.FullyQualified fully = 1; + targets.unqualified.Unqualified unqualified = 2; + targets.none.None none = 3; +} diff --git a/lang/go/testdata/packages/mapped/params b/lang/go/testdata/packages/mapped/params new file mode 100644 index 0000000..63595e8 --- /dev/null +++ b/lang/go/testdata/packages/mapped/params @@ -0,0 +1 @@ +Mtargets/fully_qualified/fully_qualified.proto=example.com/foo/bar,Mtargets/unqualified/unqualified.proto=example.com/fizz/buzz,Mtargets/none/none.proto=example.com/quux diff --git a/lang/go/testdata/packages/no_options/no_options.proto b/lang/go/testdata/packages/no_options/no_options.proto new file mode 100644 index 0000000..6df4ab5 --- /dev/null +++ b/lang/go/testdata/packages/no_options/no_options.proto @@ -0,0 +1,13 @@ +syntax="proto3"; +package packages.no_options; +option go_package="example.com/packages/no_options"; + +import "targets/fully_qualified/fully_qualified.proto"; +import "targets/unqualified/unqualified.proto"; +import "targets/none/none.proto"; + +message NoOptions { + targets.fully_qualified.FullyQualified fully = 1; + targets.unqualified.Unqualified unqualified = 2; + targets.none.None none = 3; +} diff --git a/lang/go/testdata/packages/no_options/params b/lang/go/testdata/packages/no_options/params new file mode 100644 index 0000000..e69de29 diff --git a/lang/go/testdata/packages/targets/fully_qualified/fully_qualified.proto b/lang/go/testdata/packages/targets/fully_qualified/fully_qualified.proto new file mode 100644 index 0000000..154694d --- /dev/null +++ b/lang/go/testdata/packages/targets/fully_qualified/fully_qualified.proto @@ -0,0 +1,5 @@ +syntax="proto3"; +package packages.targets.fully_qualified; +option go_package="example.com/packages/targets/fully_qualified"; + +message FullyQualified{} diff --git 
a/lang/go/testdata/packages/targets/none/none.proto b/lang/go/testdata/packages/targets/none/none.proto new file mode 100644 index 0000000..5e9eb79 --- /dev/null +++ b/lang/go/testdata/packages/targets/none/none.proto @@ -0,0 +1,4 @@ +syntax="proto3"; +package packages.targets.none; + +message None{} diff --git a/lang/go/testdata/packages/targets/unqualified/unqualified.proto b/lang/go/testdata/packages/targets/unqualified/unqualified.proto new file mode 100644 index 0000000..d955772 --- /dev/null +++ b/lang/go/testdata/packages/targets/unqualified/unqualified.proto @@ -0,0 +1,5 @@ +syntax="proto3"; +package packages.targets.unqualified; +option go_package="unqualified"; + +message Unqualified{} diff --git a/lang/go/type_name.go b/lang/go/type_name.go new file mode 100644 index 0000000..c324197 --- /dev/null +++ b/lang/go/type_name.go @@ -0,0 +1,132 @@ +package pgsgo + +import ( + "fmt" + "log" + "strings" + + "github.com/lyft/protoc-gen-star" +) + +func (c context) Type(f pgs.Field) TypeName { + ft := f.Type() + + var t TypeName + switch { + case ft.IsMap(): + key := scalarType(ft.Key().ProtoType()) + return TypeName(fmt.Sprintf("map[%s]%s", key, c.elType(ft))) + case ft.IsRepeated(): + return TypeName(fmt.Sprintf("[]%s", c.elType(ft))) + case ft.IsEmbed(): + return c.importableTypeName(f, ft.Embed()).Pointer() + case ft.IsEnum(): + t = c.importableTypeName(f, ft.Enum()) + default: + t = scalarType(ft.ProtoType()) + } + + log.Println(f.File().Descriptor().GetSyntax()) + if f.Syntax() == pgs.Proto2 { + return t.Pointer() + } + + return t +} + +func (c context) importableTypeName(f pgs.Field, e pgs.Entity) TypeName { + t := TypeName(c.Name(e)) + + if c.ImportPath(e) == c.ImportPath(f) { + return t + } + + return TypeName(fmt.Sprintf("%s.%s", c.PackageName(e), t)) +} + +func (c context) elType(ft pgs.FieldType) TypeName { + el := ft.Element() + switch { + case el.IsEnum(): + return c.importableTypeName(ft.Field(), el.Enum()) + case el.IsEmbed(): + return 
// A TypeName describes the name of a type (type on a field, or method signature)
type TypeName string

// String satisfies the strings.Stringer interface.
func (n TypeName) String() string { return string(n) }

// Element returns the TypeName of the element of n. For types other than
// slices and maps, this just returns n.
func (n TypeName) Element() TypeName {
	if i := strings.Index(string(n), "]"); i >= 0 {
		return n[i+1:]
	}
	return n
}

// Key returns the TypeName of the key of n. For slices, the returned TypeName
// is always "int", and for non slice/map types an empty TypeName is returned.
func (n TypeName) Key() TypeName {
	s := string(n)

	end := strings.Index(s, "]")
	if end < 0 {
		return ""
	}

	start := strings.Index(s[:end], "[")
	if start < 0 {
		return ""
	}

	// "[]T" has no key between the brackets; slices index by int
	if k := s[start+1 : end]; k != "" {
		return TypeName(k)
	}
	return "int"
}

// Pointer converts TypeName n to its pointer type. If n is already a pointer,
// slice, or map, it is returned unmodified.
func (n TypeName) Pointer() TypeName {
	switch ns := string(n); {
	case strings.HasPrefix(ns, "*"),
		strings.HasPrefix(ns, "["),
		strings.HasPrefix(ns, "map["):
		return n
	default:
		return TypeName("*" + ns)
	}
}

// Value converts TypeName n to its value type. If n is already a value type,
// slice, or map it is returned unmodified.
func (n TypeName) Value() TypeName {
	if len(n) > 0 && n[0] == '*' {
		return n[1:]
	}
	return n
}
{"Proto2.Required.int64", "*int64"}, + {"Proto2.Required.sfixed64", "*int64"}, + {"Proto2.Required.sint64", "*int64"}, + {"Proto2.Required.uint64", "*uint64"}, + {"Proto2.Required.fixed64", "*uint64"}, + {"Proto2.Required.int32", "*int32"}, + {"Proto2.Required.sfixed32", "*int32"}, + {"Proto2.Required.sint32", "*int32"}, + {"Proto2.Required.uint32", "*uint32"}, + {"Proto2.Required.fixed32", "*uint32"}, + {"Proto2.Required.bool", "*bool"}, + {"Proto2.Required.string", "*string"}, + {"Proto2.Required.bytes", "[]byte"}, + {"Proto2.Required.enum", "*Proto2_Enum"}, + {"Proto2.Required.ext_enum", "*ptype.Syntax"}, + {"Proto2.Required.msg", "*Proto2_Required"}, + {"Proto2.Required.ext_msg", "*duration.Duration"}, + + {"Proto3.double", "float64"}, + {"Proto3.float", "float32"}, + {"Proto3.int64", "int64"}, + {"Proto3.sfixed64", "int64"}, + {"Proto3.sint64", "int64"}, + {"Proto3.uint64", "uint64"}, + {"Proto3.fixed64", "uint64"}, + {"Proto3.int32", "int32"}, + {"Proto3.sfixed32", "int32"}, + {"Proto3.sint32", "int32"}, + {"Proto3.uint32", "uint32"}, + {"Proto3.fixed32", "uint32"}, + {"Proto3.bool", "bool"}, + {"Proto3.string", "string"}, + {"Proto3.bytes", "[]byte"}, + {"Proto3.enum", "Proto3_Enum"}, + {"Proto3.ext_enum", "ptype.Syntax"}, + {"Proto3.msg", "*Proto3_Message"}, + {"Proto3.ext_msg", "*duration.Duration"}, + {"Proto3.repeated_scalar", "[]float64"}, + {"Proto3.repeated_enum", "[]Proto3_Enum"}, + {"Proto3.repeated_ext_enum", "[]ptype.Syntax"}, + {"Proto3.repeated_msg", "[]*Proto3_Message"}, + {"Proto3.repeated_ext_msg", "[]*duration.Duration"}, + {"Proto3.map_scalar", "map[string]float32"}, + {"Proto3.map_enum", "map[int32]Proto3_Enum"}, + {"Proto3.map_ext_enum", "map[uint64]ptype.Syntax"}, + {"Proto3.map_msg", "map[uint32]*Proto3_Message"}, + {"Proto3.map_ext_msg", "map[int64]*duration.Duration"}, + } + + for _, test := range tests { + tc := test + t.Run(tc.field, func(t *testing.T) { + t.Parallel() + + e, ok := ast.Lookup(".names.types." 
+ tc.field) + require.True(t, ok, "could not find field") + + fld, ok := e.(pgs.Field) + require.True(t, ok, "entity is not a field") + + assert.Equal(t, tc.expected, ctx.Type(fld)) + }) + } +} + +func TestTypeName(t *testing.T) { + t.Parallel() + + tests := []struct { + in string + el string + key string + ptr string + val string + }{ + { + in: "int", + el: "int", + ptr: "*int", + val: "int", + }, + { + in: "*int", + el: "*int", + ptr: "*int", + val: "int", + }, + { + in: "foo.bar", + el: "foo.bar", + ptr: "*foo.bar", + val: "foo.bar", + }, + { + in: "*foo.bar", + el: "*foo.bar", + ptr: "*foo.bar", + val: "foo.bar", + }, + { + in: "[]string", + el: "string", + key: "int", + ptr: "[]string", + val: "[]string", + }, + { + in: "[]*string", + el: "*string", + key: "int", + ptr: "[]*string", + val: "[]*string", + }, + { + in: "[]foo.bar", + el: "foo.bar", + key: "int", + ptr: "[]foo.bar", + val: "[]foo.bar", + }, + { + in: "[]*foo.bar", + el: "*foo.bar", + key: "int", + ptr: "[]*foo.bar", + val: "[]*foo.bar", + }, + { + in: "map[string]float64", + el: "float64", + key: "string", + ptr: "map[string]float64", + val: "map[string]float64", + }, + { + in: "map[string]*float64", + el: "*float64", + key: "string", + ptr: "map[string]*float64", + val: "map[string]*float64", + }, + { + in: "map[string]foo.bar", + el: "foo.bar", + key: "string", + ptr: "map[string]foo.bar", + val: "map[string]foo.bar", + }, + { + in: "map[string]*foo.bar", + el: "*foo.bar", + key: "string", + ptr: "map[string]*foo.bar", + val: "map[string]*foo.bar", + }, + { + in: "[][]byte", + el: "[]byte", + key: "int", + ptr: "[][]byte", + val: "[][]byte", + }, + { + in: "map[int64][]byte", + el: "[]byte", + key: "int64", + ptr: "map[int64][]byte", + val: "map[int64][]byte", + }, + } + + for _, test := range tests { + tc := test + t.Run(tc.in, func(t *testing.T) { + tn := TypeName(tc.in) + t.Parallel() + + t.Run("Element", func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.el, tn.Element().String()) + 
}) + + t.Run("Key", func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.key, tn.Key().String()) + }) + + t.Run("Pointer", func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.ptr, tn.Pointer().String()) + }) + + t.Run("Value", func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.val, tn.Value().String()) + }) + }) + } +} + +func TestTypeName_Key_Malformed(t *testing.T) { + t.Parallel() + tn := TypeName("]malformed") + assert.Empty(t, tn.Key().String()) +} + +func TestScalarType_Invalid(t *testing.T) { + t.Parallel() + assert.Panics(t, func() { + scalarType(pgs.ProtoType(0)) + }) +} + +func ExampleTypeName_Element() { + types := []string{ + "int", + "*my.Type", + "[]string", + "map[string]*io.Reader", + } + + for _, t := range types { + fmt.Println(TypeName(t).Element()) + } + + // Output: + // int + // *my.Type + // string + // *io.Reader +} + +func ExampleTypeName_Key() { + types := []string{ + "int", + "*my.Type", + "[]string", + "map[string]*io.Reader", + } + + for _, t := range types { + fmt.Println(TypeName(t).Key()) + } + + // Output: + // + // + // int + // string +} + +func ExampleTypeName_Pointer() { + types := []string{ + "int", + "*my.Type", + "[]string", + "map[string]*io.Reader", + } + + for _, t := range types { + fmt.Println(TypeName(t).Pointer()) + } + + // Output: + // *int + // *my.Type + // []string + // map[string]*io.Reader +} + +func ExampleTypeName_Value() { + types := []string{ + "int", + "*my.Type", + "[]string", + "map[string]*io.Reader", + } + + for _, t := range types { + fmt.Println(TypeName(t).Value()) + } + + // Output: + // int + // my.Type + // []string + // map[string]*io.Reader +} diff --git a/message.go b/message.go index 2745021..cdb27e0 100644 --- a/message.go +++ b/message.go @@ -1,26 +1,19 @@ package pgs import ( - "strings" - "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/descriptor" - "github.com/golang/protobuf/protoc-gen-go/generator" ) -// Message describes a proto message, akin 
to a struct in Go. Messages can be -// contained in either another Message or File, and may house further Messages -// and/or Enums. While all Fields technically live on the Message, some may be -// contained within OneOf blocks. +// Message describes a proto message. Messages can be contained in either +// another Message or File, and may house further Messages and/or Enums. While +// all Fields technically live on the Message, some may be contained within +// OneOf blocks. type Message interface { ParentEntity - // TypeName returns the type of this message as it would be created in Go. - // This value will only differ from Name for nested messages. - TypeName() TypeName - // Descriptor returns the underlying proto descriptor for this message - Descriptor() *generator.Descriptor + Descriptor() *descriptor.DescriptorProto // Parent returns either the File or Message that directly contains this // Message. @@ -44,45 +37,58 @@ type Message interface { // to the wire format. IsMapEntry() bool + // IsWellKnown identifies whether or not this Message is a WKT from the + // `google.protobuf` package. Most official plugins special case these types + // and they usually need to be handled differently. + IsWellKnown() bool + + // WellKnownType returns the WellKnownType associated with this field. If + // IsWellKnown returns false, UnknownWKT is returned. + WellKnownType() WellKnownType + setParent(p ParentEntity) addField(f Field) addOneOf(o OneOf) } -// An MessageParent is any Entity type that can contain messages. File and -// Message types implement MessageParent. 
- type msg struct { + desc *descriptor.DescriptorProto parent ParentEntity - msgs []Message - enums []Enum - fields []Field - oneofs []OneOf - mapEntries []Message - - rawDesc *descriptor.DescriptorProto - genDesc *generator.Descriptor - - comments string -} - -func (m *msg) Name() Name { return Name(m.rawDesc.GetName()) } -func (m *msg) FullyQualifiedName() string { return fullyQualifiedName(m.parent, m) } -func (m *msg) Syntax() Syntax { return m.parent.Syntax() } -func (m *msg) Package() Package { return m.parent.Package() } -func (m *msg) File() File { return m.parent.File() } -func (m *msg) BuildTarget() bool { return m.parent.BuildTarget() } -func (m *msg) Comments() string { return m.comments } -func (m *msg) Descriptor() *generator.Descriptor { return m.genDesc } -func (m *msg) Parent() ParentEntity { return m.parent } -func (m *msg) IsMapEntry() bool { return m.rawDesc.GetOptions().GetMapEntry() } -func (m *msg) TypeName() TypeName { return TypeName(strings.Join(m.genDesc.TypeName(), "_")) } - -func (m *msg) Enums() []Enum { - es := make([]Enum, len(m.enums)) - copy(es, m.enums) - return es + msgs, preservedMsgs []Message + enums []Enum + fields []Field + oneofs []OneOf + maps []Message + + info SourceCodeInfo +} + +func (m *msg) Name() Name { return Name(m.desc.GetName()) } +func (m *msg) FullyQualifiedName() string { return fullyQualifiedName(m.parent, m) } +func (m *msg) Syntax() Syntax { return m.parent.Syntax() } +func (m *msg) Package() Package { return m.parent.Package() } +func (m *msg) File() File { return m.parent.File() } +func (m *msg) BuildTarget() bool { return m.parent.BuildTarget() } +func (m *msg) SourceCodeInfo() SourceCodeInfo { return m.info } +func (m *msg) Descriptor() *descriptor.DescriptorProto { return m.desc } +func (m *msg) Parent() ParentEntity { return m.parent } +func (m *msg) IsMapEntry() bool { return m.desc.GetOptions().GetMapEntry() } +func (m *msg) Enums() []Enum { return m.enums } +func (m *msg) Messages() []Message { 
return m.msgs } +func (m *msg) Fields() []Field { return m.fields } +func (m *msg) OneOfs() []OneOf { return m.oneofs } +func (m *msg) MapEntries() []Message { return m.maps } + +func (m *msg) WellKnownType() WellKnownType { + if m.Package().ProtoName() == WellKnownTypePackage { + return LookupWKT(m.Name()) + } + return UnknownWKT +} + +func (m *msg) IsWellKnown() bool { + return m.WellKnownType().Valid() } func (m *msg) AllEnums() []Enum { @@ -93,12 +99,6 @@ func (m *msg) AllEnums() []Enum { return es } -func (m *msg) Messages() []Message { - msgs := make([]Message, len(m.msgs)) - copy(msgs, m.msgs) - return msgs -} - func (m *msg) AllMessages() []Message { msgs := m.Messages() for _, sm := range m.msgs { @@ -107,18 +107,6 @@ func (m *msg) AllMessages() []Message { return msgs } -func (m *msg) MapEntries() []Message { - me := make([]Message, len(m.mapEntries)) - copy(me, m.mapEntries) - return me -} - -func (m *msg) Fields() []Field { - f := make([]Field, len(m.fields)) - copy(f, m.fields) - return f -} - func (m *msg) NonOneOfFields() (f []Field) { for _, fld := range m.fields { if !fld.InOneOf() { @@ -136,13 +124,7 @@ func (m *msg) OneOfFields() (f []Field) { return f } -func (m *msg) OneOfs() []OneOf { - o := make([]OneOf, len(m.oneofs)) - copy(o, m.oneofs) - return o -} - -func (m *msg) Imports() (i []Package) { +func (m *msg) Imports() (i []File) { for _, f := range m.fields { i = append(i, f.Imports()...) 
} @@ -150,7 +132,7 @@ func (m *msg) Imports() (i []Package) { } func (m *msg) Extension(desc *proto.ExtensionDesc, ext interface{}) (bool, error) { - return extension(m.rawDesc.GetOptions(), desc, &ext) + return extension(m.desc.GetOptions(), desc, &ext) } func (m *msg) accept(v Visitor) (err error) { @@ -213,7 +195,34 @@ func (m *msg) addOneOf(o OneOf) { func (m *msg) addMapEntry(me Message) { me.setParent(m) - m.mapEntries = append(m.mapEntries, me) + m.maps = append(m.maps, me) +} + +func (m *msg) childAtPath(path []int32) Entity { + switch { + case len(path) == 0: + return m + case len(path)%2 != 0: + return nil + } + + var child Entity + switch path[0] { + case messageTypeFieldPath: + child = m.fields[path[1]] + case messageTypeNestedTypePath: + child = m.preservedMsgs[path[1]] + case messageTypeEnumTypePath: + child = m.enums[path[1]] + case messageTypeOneofDeclPath: + child = m.oneofs[path[1]] + default: + return nil + } + + return child.childAtPath(path[2:]) } +func (m *msg) addSourceCodeInfo(info SourceCodeInfo) { m.info = info } + var _ Message = (*msg)(nil) diff --git a/message_test.go b/message_test.go index 3e7a992..472b14e 100644 --- a/message_test.go +++ b/message_test.go @@ -4,16 +4,17 @@ import ( "errors" "testing" + desc "github.com/golang/protobuf/descriptor" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/descriptor" - "github.com/golang/protobuf/protoc-gen-go/generator" + "github.com/golang/protobuf/ptypes/any" "github.com/stretchr/testify/assert" ) func TestMsg_Name(t *testing.T) { t.Parallel() - m := &msg{rawDesc: &descriptor.DescriptorProto{Name: proto.String("msg")}} + m := &msg{desc: &descriptor.DescriptorProto{Name: proto.String("msg")}} assert.Equal(t, "msg", m.Name().String()) } @@ -21,20 +22,13 @@ func TestMsg_Name(t *testing.T) { func TestMsg_FullyQualifiedName(t *testing.T) { t.Parallel() - m := &msg{rawDesc: &descriptor.DescriptorProto{Name: proto.String("msg")}} + m := &msg{desc: 
&descriptor.DescriptorProto{Name: proto.String("msg")}} f := dummyFile() f.addMessage(m) assert.Equal(t, f.FullyQualifiedName()+".msg", m.FullyQualifiedName()) } -func TestMsg_TypeName(t *testing.T) { - t.Parallel() - - m := dummyMsg() - assert.Equal(t, m.Name().String(), m.TypeName().String()) -} - func TestMsg_Syntax(t *testing.T) { t.Parallel() @@ -82,8 +76,8 @@ func TestMsg_BuildTarget(t *testing.T) { func TestMsg_Descriptor(t *testing.T) { t.Parallel() - m := &msg{genDesc: &generator.Descriptor{}} - assert.Equal(t, m.genDesc, m.Descriptor()) + m := &msg{desc: &descriptor.DescriptorProto{}} + assert.Equal(t, m.desc, m.Descriptor()) } func TestMsg_Parent(t *testing.T) { @@ -99,10 +93,10 @@ func TestMsg_Parent(t *testing.T) { func TestMsg_IsMapEntry(t *testing.T) { t.Parallel() - m := &msg{rawDesc: &descriptor.DescriptorProto{}} + m := &msg{desc: &descriptor.DescriptorProto{}} assert.False(t, m.IsMapEntry()) - m.rawDesc.Options = &descriptor.MessageOptions{ + m.desc.Options = &descriptor.MessageOptions{ MapEntry: proto.Bool(true), } assert.True(t, m.IsMapEntry()) @@ -222,7 +216,7 @@ func TestMsg_OneOfs(t *testing.T) { func TestMsg_Extension(t *testing.T) { // cannot be parallel - m := &msg{rawDesc: &descriptor.DescriptorProto{}} + m := &msg{desc: &descriptor.DescriptorProto{}} assert.NotPanics(t, func() { m.Extension(nil, nil) }) } @@ -299,18 +293,48 @@ func TestMsg_Imports(t *testing.T) { m := &msg{} assert.Empty(t, m.Imports()) - m.addField(&mockField{i: []Package{&pkg{}, &pkg{}}}) + m.addField(&mockField{i: []File{&file{}, &file{}}}) assert.Len(t, m.Imports(), 2) } +func TestMsg_ChildAtPath(t *testing.T) { + t.Parallel() + + m := &msg{} + assert.Equal(t, m, m.childAtPath(nil)) + assert.Nil(t, m.childAtPath([]int32{1})) + assert.Nil(t, m.childAtPath([]int32{999, 456})) +} + +func TestMsg_WellKnownType(t *testing.T) { + fd, md := desc.ForMessage(&any.Any{}) + p := &pkg{fd: fd} + f := &file{desc: fd} + m := &msg{desc: md} + f.addMessage(m) + p.addFile(f) + + 
assert.True(t, m.IsWellKnown()) + assert.Equal(t, AnyWKT, m.WellKnownType()) + + m.desc.Name = proto.String("Foobar") + assert.False(t, m.IsWellKnown()) + assert.Equal(t, UnknownWKT, m.WellKnownType()) + + m.desc.Name = proto.String("Any") + f.desc.Package = proto.String("fizz.buzz") + assert.False(t, m.IsWellKnown()) + assert.Equal(t, UnknownWKT, m.WellKnownType()) +} + type mockMessage struct { Message - i []Package + i []File p ParentEntity err error } -func (m *mockMessage) Imports() []Package { return m.i } +func (m *mockMessage) Imports() []File { return m.i } func (m *mockMessage) setParent(p ParentEntity) { m.p = p } @@ -326,9 +350,8 @@ func dummyMsg() *msg { f := dummyFile() m := &msg{ - rawDesc: &descriptor.DescriptorProto{Name: proto.String("message")}, + desc: &descriptor.DescriptorProto{Name: proto.String("message")}, } - m.genDesc = &generator.Descriptor{DescriptorProto: m.rawDesc} f.addMessage(m) return m diff --git a/method.go b/method.go index 2798a18..14af5e4 100644 --- a/method.go +++ b/method.go @@ -18,13 +18,13 @@ type Method interface { // Input returns the Message representing the input type for this. Input() Message - // Output returns the Message representing this output type for this. + // Output returns the Message representing the output type for this. Output() Message - // ClientStreaming indicates if this method allows clients to stream inputs + // ClientStreaming indicates if this method allows clients to stream inputs. ClientStreaming() bool - // ServerStreaming indicates if this method allows servers to stream outputs + // ServerStreaming indicates if this method allows servers to stream outputs. 
ServerStreaming() bool setService(Service) @@ -36,7 +36,7 @@ type method struct { in, out Message - comments string + info SourceCodeInfo } func (m *method) Name() Name { return Name(m.desc.GetName()) } @@ -45,7 +45,7 @@ func (m *method) Syntax() Syntax { return m.servi func (m *method) Package() Package { return m.service.Package() } func (m *method) File() File { return m.service.File() } func (m *method) BuildTarget() bool { return m.service.BuildTarget() } -func (m *method) Comments() string { return m.comments } +func (m *method) SourceCodeInfo() SourceCodeInfo { return m.info } func (m *method) Descriptor() *descriptor.MethodDescriptorProto { return m.desc } func (m *method) Service() Service { return m.service } func (m *method) Input() Message { return m.in } @@ -54,14 +54,14 @@ func (m *method) ClientStreaming() bool { return m.desc. func (m *method) ServerStreaming() bool { return m.desc.GetServerStreaming() } func (m *method) BiDirStreaming() bool { return m.ClientStreaming() && m.ServerStreaming() } -func (m *method) Imports() (i []Package) { - mine := m.Package().GoName() +func (m *method) Imports() (i []File) { + mine := m.File().Name() - if input := m.Input().Package(); mine != input.GoName() { + if input := m.Input().File(); mine != input.Name() { i = append(i, input) } - if output := m.Output().Package(); mine != output.GoName() { + if output := m.Output().File(); mine != output.Name() { i = append(i, output) } @@ -82,3 +82,14 @@ func (m *method) accept(v Visitor) (err error) { } func (m *method) setService(s Service) { m.service = s } + +func (m *method) childAtPath(path []int32) Entity { + if len(path) == 0 { + return m + } + return nil +} + +func (m *method) addSourceCodeInfo(info SourceCodeInfo) { m.info = info } + +var m Method = (*method)(nil) diff --git a/method_test.go b/method_test.go index 82c729b..53fc465 100644 --- a/method_test.go +++ b/method_test.go @@ -135,9 +135,9 @@ func TestMethod_Imports(t *testing.T) { s.addMethod(m) 
assert.Empty(t, m.Imports()) - m.in = &msg{parent: &file{pkg: &pkg{name: "not_the_same"}}} + m.in = &msg{parent: &file{pkg: &pkg{comments: "not_the_same"}}} assert.Len(t, m.Imports(), 1) - m.out = &msg{parent: &file{pkg: &pkg{name: "other_import"}}} + m.out = &msg{parent: &file{pkg: &pkg{comments: "other_import"}}} assert.Len(t, m.Imports(), 2) } @@ -160,14 +160,23 @@ func TestMethod_Accept(t *testing.T) { assert.Equal(t, 1, v.method) } +func TestMethod_ChildAtPath(t *testing.T) { + t.Parallel() + + m := &method{} + + assert.Equal(t, m, m.childAtPath(nil)) + assert.Nil(t, m.childAtPath([]int32{1})) +} + type mockMethod struct { Method - i []Package + i []File s Service err error } -func (m *mockMethod) Imports() []Package { return m.i } +func (m *mockMethod) Imports() []File { return m.i } func (m *mockMethod) setService(s Service) { m.s = s } diff --git a/module.go b/module.go index cf85207..069cf2a 100644 --- a/module.go +++ b/module.go @@ -3,9 +3,7 @@ package pgs import "os" // Module describes the interface for a domain-specific code generation module -// that can be registered with the pgs generator. A module should be used over -// a generator.Plugin if generating code NOT included in the *.pg.go file is -// desired. +// that can be registered with the PG* generator. type Module interface { // The Name of the Module, used when establishing the build context and used // as the base prefix for all debugger output. @@ -15,27 +13,27 @@ type Module interface { // should be stored and used by the Module. InitContext(c BuildContext) - // Execute is called on the module with the target Package as well as all + // Execute is called on the module with the target Files as well as all // loaded Packages from the gatherer. The module should return a slice of - // Artifacts that it would like to be generated. If a Module is used in - // multi-package mode and this module does not implement MultiModule, Execute - // will be called multiple times for each target Package. 
- Execute(target Package, packages map[string]Package) []Artifact -} - -// MultiModule adds special behavior to a Module that expects multi-package -// mode to be enabled for this protoc-plugin. -type MultiModule interface { - // MultiExecute is called instead of Execute for multi-package protoc - // executions. If this method is not present on a Module, Execute is called - // individually for each target package. - MultiExecute(targets map[string]Package, packages map[string]Package) []Artifact + // Artifacts that it would like to be generated. + Execute(targets map[string]File, packages map[string]Package) []Artifact } // ModuleBase provides utility methods and a base implementation for a // protoc-gen-star Module. ModuleBase should be used as an anonymously embedded // field of an actual Module implementation. The only methods that need to be // overridden are Name and Execute. +// +// type MyModule { +// *pgs.ModuleBase +// } +// +// func InitMyModule() *MyModule { return &MyModule{ &pgs.ModuleBase{} } } +// +// func (m *MyModule) Name() string { return "MyModule" } +// +// func (m *MyModule) Execute(...) []pgs.Artifact { ... } +// type ModuleBase struct { BuildContext artifacts []Artifact @@ -58,7 +56,7 @@ func (m *ModuleBase) Name() string { // Execute satisfies the Module interface, however this method will fail and // must be overridden by a parent struct. -func (m *ModuleBase) Execute(target Package, packages map[string]Package) []Artifact { +func (m *ModuleBase) Execute(targets map[string]File, packages map[string]Package) []Artifact { m.Fail("Execute method is not implemented for this module") return m.Artifacts() } @@ -101,9 +99,9 @@ func (m *ModuleBase) Artifacts() []Artifact { return out } -// AddArtifact adds a to this Module's collection of generation artifacts. This -// method is available as a convenience but the other Add/Overwrite methods -// should be used preferentially. 
+// AddArtifact adds an Artifact to this Module's collection of generation +// artifacts. This method is available as a convenience but the other Add & +// Overwrite methods should be used preferentially. func (m *ModuleBase) AddArtifact(a ...Artifact) { m.artifacts = append(m.artifacts, a...) } // AddGeneratorFile adds a file with the provided name and contents to the code @@ -119,8 +117,8 @@ func (m *ModuleBase) AddGeneratorFile(name, content string) { } // OverwriteGeneratorFile behaves the same as AddGeneratorFile, however if a -// previously executed Plugin or Module has created a file with the same name, -// it will be overwritten with this one. +// previously executed Module has created a file with the same name, it will be +// overwritten with this one. func (m *ModuleBase) OverwriteGeneratorFile(name, content string) { m.AddArtifact(GeneratorFile{ Name: name, diff --git a/module_test.go b/module_test.go index bb1ae0c..a857ae1 100644 --- a/module_test.go +++ b/module_test.go @@ -18,29 +18,17 @@ func newMockModule() *mockModule { return &mockModule{ModuleBase: &ModuleBase{}} func (m *mockModule) Name() string { return m.name } -func (m *mockModule) Execute(pkg Package, pkgs map[string]Package) []Artifact { +func (m *mockModule) Execute(targets map[string]File, packages map[string]Package) []Artifact { m.executed = true return nil } -type multiMockModule struct { - *mockModule - multiExecuted bool -} - -func newMultiMockModule() *multiMockModule { return &multiMockModule{mockModule: newMockModule()} } - -func (m *multiMockModule) MultiExecute(targets map[string]Package, packages map[string]Package) []Artifact { - m.multiExecuted = true - return nil -} - func TestModuleBase_InitContext(t *testing.T) { t.Parallel() m := new(ModuleBase) assert.Nil(t, m.BuildContext) - bc := Context(newMockDebugger(t), Parameters{}, ".") + bc := Context(InitMockDebugger(), Parameters{}, ".") m.InitContext(bc) assert.NotNil(t, m.BuildContext) } @@ -56,18 +44,18 @@ func 
TestModuleBase_Execute(t *testing.T) { t.Parallel() m := new(ModuleBase) - d := newMockDebugger(t) + d := InitMockDebugger() m.InitContext(Context(d, Parameters{}, ".")) assert.NotPanics(t, func() { m.Execute(nil, nil) }) - assert.True(t, d.failed) + assert.True(t, d.Failed()) } func TestModuleBase_PushPop(t *testing.T) { t.Parallel() m := new(ModuleBase) - m.InitContext(Context(newMockDebugger(t), Parameters{}, ".")) + m.InitContext(Context(InitMockDebugger(), Parameters{}, ".")) m.Push("foo") m.Pop() } @@ -76,7 +64,7 @@ func TestModuleBase_PushPopDir(t *testing.T) { t.Parallel() m := new(ModuleBase) - m.InitContext(Context(newMockDebugger(t), Parameters{}, "foo")) + m.InitContext(Context(InitMockDebugger(), Parameters{}, "foo")) m.PushDir("bar") assert.Equal(t, "foo/bar", m.OutputPath()) m.PopDir() diff --git a/name.go b/name.go index cfb16e3..54ad386 100644 --- a/name.go +++ b/name.go @@ -8,23 +8,10 @@ import ( "unicode/utf8" "path/filepath" - - "github.com/golang/protobuf/protoc-gen-go/generator" ) -var protectedNames = map[Name]Name{ - "Reset": "Reset_", - "String": "String_", - "ProtoMessage": "ProtoMessage_", - "Marshal": "Marshal_", - "Unmarshal": "Unmarshal_", - "ExtensionRangeArray": "ExtensionRangeArray_", - "ExtensionMap": "ExtensionMap_", - "Descriptor": "Descriptor_", -} - -// A Name describes a symbol (Message, Field, Enum, Service, Field) of the -// Entity. It can be converted to multiple forms using the provided helper +// A Name describes an identifier of an Entity (Message, Field, Enum, Service, +// Field). It can be converted to multiple forms using the provided helper // methods, or a custom transform can be used to modify its behavior. type Name string @@ -35,27 +22,6 @@ func (n Name) String() string { return string(n) } // title-cased and concatenated with no separator. 
func (n Name) UpperCamelCase() Name { return n.Transform(strings.Title, strings.Title, "") } -// PGGUpperCamelCase converts Name n to the protoc-gen-go defined upper -// camelcase. The rules are slightly different from UpperCamelCase in that -// leading underscores are converted to 'X', mid-string underscores followed by -// lowercase letters are removed and the letter is capitalized, all other -// punctuation is preserved. This method should be used when deriving names of -// protoc-gen-go generated code (ie, message/service struct names and field -// names). In addition, this method ensures the Name does not conflict with one -// of the generated method names, appending the fields with an underscore in -// the same manner as protoc-gen-go. -// -// See: https://godoc.org/github.com/golang/protobuf/protoc-gen-go/generator#CamelCase -func (n Name) PGGUpperCamelCase() Name { - out := Name(generator.CamelCase(n.String())) - - if use, protected := protectedNames[out]; protected { - return use - } - - return out -} - // LowerCamelCase converts Name n to lower camelcase, where each part is // title-cased and concatenated with no separator except the first which is // lower-cased. @@ -73,6 +39,10 @@ func (n Name) LowerSnakeCase() Name { return n.Transform(strings.ToLower, string // title-cased and concatenated with underscores. func (n Name) UpperSnakeCase() Name { return n.Transform(strings.Title, strings.Title, "_") } +// SnakeCase converts Name n to snake-case, where each part preserves its +// capitalization and concatenated with underscores. +func (n Name) SnakeCase() Name { return n.Transform(ID, ID, "_") } + // LowerDotNotation converts Name n to lower dot notation, where each part is // lower-cased and concatenated with periods. func (n Name) LowerDotNotation() Name { return n.Transform(strings.ToLower, strings.ToLower, ".") } @@ -142,6 +112,9 @@ func (n Name) Split() (parts []string) { // the standard strings package satisfy this signature. 
type NameTransformer func(string) string +// ID is a NameTransformer that does not mutate the string. +func ID(s string) string { return s } + // Chain combines the behavior of two Transformers into one. If multiple // transformations need to be performed on a Name, this method should be used // to reduce it to a single transformation before applying. @@ -168,56 +141,6 @@ func (n Name) Transform(mod, first NameTransformer, sep string) Name { return Name(strings.Join(parts, sep)) } -// A TypeName describes the name of a type (type on a field, or method signature) -type TypeName string - -// String satisfies the strings.Stringer interface. -func (n TypeName) String() string { return string(n) } - -// Element returns the TypeName of the element of n. For types other than -// slices and maps, this just returns n. -func (n TypeName) Element() TypeName { - parts := strings.SplitN(string(n), "]", 2) - return TypeName(parts[len(parts)-1]) -} - -// Key returns the TypeName of the key of n. For slices, the return TypeName is -// always "int", and for non slice/map types an empty TypeName is returned. -func (n TypeName) Key() TypeName { - parts := strings.SplitN(string(n), "]", 2) - if len(parts) == 1 { - return TypeName("") - } - - parts = strings.SplitN(parts[0], "[", 2) - if len(parts) != 2 { - return TypeName("") - } else if parts[1] == "" { - return TypeName("int") - } - - return TypeName(parts[1]) -} - -// Pointer converts TypeName n to it's pointer type. If n is already a pointer, -// slice, or map, it is returned unmodified. -func (n TypeName) Pointer() TypeName { - ns := string(n) - if strings.HasPrefix(ns, "*") || - strings.HasPrefix(ns, "[") || - strings.HasPrefix(ns, "map[") { - return n - } - - return TypeName("*" + ns) -} - -// Value converts TypeName n to it's value type. If n is already a value type, -// slice, or map it is returned unmodified. 
-func (n TypeName) Value() TypeName { - return TypeName(strings.TrimPrefix(string(n), "*")) -} - // A FilePath describes the name of a file or directory. This type simplifies // path related operations. type FilePath string @@ -250,8 +173,9 @@ func (n FilePath) SetExt(ext string) FilePath { return n.SetBase(n.BaseName() + // SetBase returns a new FilePath with the base element replaced with base. func (n FilePath) SetBase(base string) FilePath { return n.Dir().Push(base) } -// Pop returns a new FilePath with the last element removed -func (n FilePath) Pop() FilePath { return JoinPaths(n.String(), "..") } +// Pop returns a new FilePath with the last element removed. Pop is an alias +// for the Dir method. +func (n FilePath) Pop() FilePath { return n.Dir() } // Push returns a new FilePath with elem added to the end func (n FilePath) Push(elem string) FilePath { return JoinPaths(n.String(), elem) } diff --git a/name_test.go b/name_test.go index 159bc5c..0ed1f02 100644 --- a/name_test.go +++ b/name_test.go @@ -153,169 +153,6 @@ func TestName(t *testing.T) { } } -func TestName_PGGUpperCamelCase(t *testing.T) { - t.Parallel() - - tests := []struct { - in string - ex string - }{ - {"foo_bar", "FooBar"}, - {"myJSON", "MyJSON"}, - {"PDFTemplate", "PDFTemplate"}, - {"_my_field_name_2", "XMyFieldName_2"}, - {"my.field", "My.field"}, - {"my_Field", "My_Field"}, - {"string", "String_"}, - {"String", "String_"}, - } - - for _, tc := range tests { - assert.Equal(t, tc.ex, Name(tc.in).PGGUpperCamelCase().String()) - } -} - -func TestTypeName(t *testing.T) { - t.Parallel() - - tests := []struct { - in string - el string - key string - ptr string - val string - }{ - { - in: "int", - el: "int", - ptr: "*int", - val: "int", - }, - { - in: "*int", - el: "*int", - ptr: "*int", - val: "int", - }, - { - in: "foo.bar", - el: "foo.bar", - ptr: "*foo.bar", - val: "foo.bar", - }, - { - in: "*foo.bar", - el: "*foo.bar", - ptr: "*foo.bar", - val: "foo.bar", - }, - { - in: "[]string", - el: 
"string", - key: "int", - ptr: "[]string", - val: "[]string", - }, - { - in: "[]*string", - el: "*string", - key: "int", - ptr: "[]*string", - val: "[]*string", - }, - { - in: "[]foo.bar", - el: "foo.bar", - key: "int", - ptr: "[]foo.bar", - val: "[]foo.bar", - }, - { - in: "[]*foo.bar", - el: "*foo.bar", - key: "int", - ptr: "[]*foo.bar", - val: "[]*foo.bar", - }, - { - in: "map[string]float64", - el: "float64", - key: "string", - ptr: "map[string]float64", - val: "map[string]float64", - }, - { - in: "map[string]*float64", - el: "*float64", - key: "string", - ptr: "map[string]*float64", - val: "map[string]*float64", - }, - { - in: "map[string]foo.bar", - el: "foo.bar", - key: "string", - ptr: "map[string]foo.bar", - val: "map[string]foo.bar", - }, - { - in: "map[string]*foo.bar", - el: "*foo.bar", - key: "string", - ptr: "map[string]*foo.bar", - val: "map[string]*foo.bar", - }, - { - in: "[][]byte", - el: "[]byte", - key: "int", - ptr: "[][]byte", - val: "[][]byte", - }, - { - in: "map[int64][]byte", - el: "[]byte", - key: "int64", - ptr: "map[int64][]byte", - val: "map[int64][]byte", - }, - } - - for _, test := range tests { - tc := test - t.Run(tc.in, func(t *testing.T) { - tn := TypeName(tc.in) - t.Parallel() - - t.Run("Element", func(t *testing.T) { - t.Parallel() - assert.Equal(t, tc.el, tn.Element().String()) - }) - - t.Run("Key", func(t *testing.T) { - t.Parallel() - assert.Equal(t, tc.key, tn.Key().String()) - }) - - t.Run("Pointer", func(t *testing.T) { - t.Parallel() - assert.Equal(t, tc.ptr, tn.Pointer().String()) - }) - - t.Run("Value", func(t *testing.T) { - t.Parallel() - assert.Equal(t, tc.val, tn.Value().String()) - }) - }) - } -} - -func TestTypeName_Key_Malformed(t *testing.T) { - t.Parallel() - tn := TypeName("]malformed") - assert.Empty(t, tn.Key().String()) -} - func TestNameTransformer_Chain(t *testing.T) { t.Parallel() @@ -357,29 +194,6 @@ func ExampleName_UpperCamelCase() { // PDFTemplate } -func ExampleName_PGGUpperCamelCase() { - names := 
[]string{ - "foo_bar", - "myJSON", - "PDFTemplate", - "_my_field_name_2", - "my.field", - "my_Field", - } - - for _, n := range names { - fmt.Println(Name(n).PGGUpperCamelCase()) - } - - // Output: - // FooBar - // MyJSON - // PDFTemplate - // XMyFieldName_2 - // My.field - // My_Field -} - func ExampleName_LowerCamelCase() { names := []string{ "foo_bar", @@ -448,6 +262,23 @@ func ExampleName_UpperSnakeCase() { // PDF_Template } +func ExampleName_SnakeCase() { + names := []string{ + "foo_bar", + "myJSON", + "PDFTemplate", + } + + for _, n := range names { + fmt.Println(Name(n).SnakeCase()) + } + + // Output: + // foo_bar + // my_JSON + // PDF_Template +} + func ExampleName_LowerDotNotation() { names := []string{ "foo_bar", @@ -481,79 +312,3 @@ func ExampleName_UpperDotNotation() { // My.JSON // PDF.Template } - -func ExampleTypeName_Element() { - types := []string{ - "int", - "*my.Type", - "[]string", - "map[string]*io.Reader", - } - - for _, t := range types { - fmt.Println(TypeName(t).Element()) - } - - // Output: - // int - // *my.Type - // string - // *io.Reader -} - -func ExampleTypeName_Key() { - types := []string{ - "int", - "*my.Type", - "[]string", - "map[string]*io.Reader", - } - - for _, t := range types { - fmt.Println(TypeName(t).Key()) - } - - // Output: - // - // - // int - // string -} - -func ExampleTypeName_Pointer() { - types := []string{ - "int", - "*my.Type", - "[]string", - "map[string]*io.Reader", - } - - for _, t := range types { - fmt.Println(TypeName(t).Pointer()) - } - - // Output: - // *int - // *my.Type - // []string - // map[string]*io.Reader -} - -func ExampleTypeName_Value() { - types := []string{ - "int", - "*my.Type", - "[]string", - "map[string]*io.Reader", - } - - for _, t := range types { - fmt.Println(TypeName(t).Value()) - } - - // Output: - // int - // my.Type - // []string - // map[string]*io.Reader -} diff --git a/node_nilvisitor_test.go b/node_nilvisitor_test.go index 87954f4..dfbe871 100644 --- a/node_nilvisitor_test.go 
+++ b/node_nilvisitor_test.go @@ -54,11 +54,11 @@ func enumNode() Node { // } sm := &msg{} - sm.addEnum(&enum{rawDesc: &descriptor.EnumDescriptorProto{Name: proto.String("Foo")}}) + sm.addEnum(&enum{desc: &descriptor.EnumDescriptorProto{Name: proto.String("Foo")}}) m := &msg{} m.addMessage(sm) - m.addEnum(&enum{rawDesc: &descriptor.EnumDescriptorProto{Name: proto.String("Bar")}}) + m.addEnum(&enum{desc: &descriptor.EnumDescriptorProto{Name: proto.String("Bar")}}) return m } diff --git a/oneof.go b/oneof.go index 00ce99e..25a2200 100644 --- a/oneof.go +++ b/oneof.go @@ -28,7 +28,7 @@ type oneof struct { msg Message flds []Field - comments string + info SourceCodeInfo } func (o *oneof) accept(v Visitor) (err error) { @@ -46,12 +46,12 @@ func (o *oneof) Syntax() Syntax { return o.msg.Syn func (o *oneof) Package() Package { return o.msg.Package() } func (o *oneof) File() File { return o.msg.File() } func (o *oneof) BuildTarget() bool { return o.msg.BuildTarget() } -func (o *oneof) Comments() string { return o.comments } +func (o *oneof) SourceCodeInfo() SourceCodeInfo { return o.info } func (o *oneof) Descriptor() *descriptor.OneofDescriptorProto { return o.desc } func (o *oneof) Message() Message { return o.msg } func (o *oneof) setMessage(m Message) { o.msg = m } -func (o *oneof) Imports() (i []Package) { +func (o *oneof) Imports() (i []File) { for _, f := range o.flds { i = append(i, f.Imports()...) 
} @@ -73,4 +73,13 @@ func (o *oneof) addField(f Field) { o.flds = append(o.flds, f) } +func (o *oneof) childAtPath(path []int32) Entity { + if len(path) == 0 { + return o + } + return nil +} + +func (o *oneof) addSourceCodeInfo(info SourceCodeInfo) { o.info = info } + var _ OneOf = (*oneof)(nil) diff --git a/oneof_test.go b/oneof_test.go index e51d4be..74da82d 100644 --- a/oneof_test.go +++ b/oneof_test.go @@ -95,7 +95,7 @@ func TestOneof_Imports(t *testing.T) { o := &oneof{} assert.Empty(t, o.Imports()) - o.addField(&mockField{i: []Package{&pkg{}, &pkg{}}, Field: &field{}}) + o.addField(&mockField{i: []File{&file{}, &file{}}, Field: &field{}}) assert.Len(t, o.Imports(), 2) } @@ -127,14 +127,22 @@ func TestOneof_Accept(t *testing.T) { assert.Equal(t, 1, v.oneof) } +func TestOneof_ChildAtPath(t *testing.T) { + t.Parallel() + + o := &oneof{} + assert.Equal(t, o, o.childAtPath(nil)) + assert.Nil(t, o.childAtPath([]int32{1})) +} + type mockOneOf struct { OneOf - i []Package + i []File m Message err error } -func (o *mockOneOf) Imports() []Package { return o.i } +func (o *mockOneOf) Imports() []File { return o.i } func (o *mockOneOf) setMessage(m Message) { o.m = m } diff --git a/package.go b/package.go index 265dee3..e6f6b28 100644 --- a/package.go +++ b/package.go @@ -1,24 +1,15 @@ package pgs +import "github.com/golang/protobuf/protoc-gen-go/descriptor" + // Package is a container that encapsulates all the files under a single -// package namespace. Specifically, this would be all the proto files loaded -// within the same directory (not recursively). While a proto file's package -// technically can differ from its sibling files, PGS will throw an error as -// this is typically a mistake or bad practice. +// package namespace. type Package interface { Node - Commenter - // The name of the proto package. This may or may not be the same as the Go - // package name. + // The name of the proto package. ProtoName() Name - // The name of the Go package. 
This is guaranteed to be unique. - GoName() Name - - // The fully qualified import path for this Go Package - ImportPath() string - // All the files loaded for this Package Files() []File @@ -28,24 +19,16 @@ type Package interface { } type pkg struct { - fd packageFD - importPath string - name string - files []File + fd *descriptor.FileDescriptorProto + files []File comments string } -func (p *pkg) ProtoName() Name { return Name(p.fd.GetPackage()) } -func (p *pkg) GoName() Name { return Name(p.name) } -func (p *pkg) ImportPath() string { return p.importPath } -func (p *pkg) Comments() string { return p.comments } +func (p *pkg) ProtoName() Name { return Name(p.fd.GetPackage()) } +func (p *pkg) Comments() string { return p.comments } -func (p *pkg) Files() []File { - fs := make([]File, len(p.files)) - copy(fs, p.files) - return fs -} +func (p *pkg) Files() []File { return p.files } func (p *pkg) accept(v Visitor) (err error) { if v == nil { @@ -73,11 +56,3 @@ func (p *pkg) addFile(f File) { func (p *pkg) setComments(comments string) { p.comments = comments } - -// packageFD stands in for a *generator.FileDescriptor. The FileDescriptor -// cannot be used directly as its PackageName method calls out to a global map. 
-type packageFD interface { - GetPackage() string -} - -var _ Package = (*pkg)(nil) diff --git a/package_test.go b/package_test.go index 4c13378..b3b60fd 100644 --- a/package_test.go +++ b/package_test.go @@ -3,6 +3,9 @@ package pgs import ( "testing" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/protoc-gen-go/descriptor" + "errors" "github.com/stretchr/testify/assert" @@ -11,22 +14,8 @@ import ( func TestPkg_ProtoName(t *testing.T) { t.Parallel() - p := &pkg{fd: mockPackageFD{gp: "foobar"}} - assert.Equal(t, Name("foobar"), p.ProtoName()) -} - -func TestPkg_GoName(t *testing.T) { - t.Parallel() - - p := &pkg{name: "foobar"} - assert.Equal(t, Name("foobar"), p.GoName()) -} - -func TestPkg_ImportPath(t *testing.T) { - t.Parallel() - - p := &pkg{importPath: "fizz/buzz"} - assert.Equal(t, "fizz/buzz", p.ImportPath()) + p := dummyPkg() + assert.Equal(t, p.fd.GetPackage(), p.ProtoName().String()) } func TestPkg_Files(t *testing.T) { @@ -84,21 +73,16 @@ func TestPkg_Accept(t *testing.T) { assert.Equal(t, 2, v.file) } -type mockPackageFD struct { - packageFD - pn string - gp string -} +func TestPackage_Comments(t *testing.T) { + t.Parallel() -func (mp mockPackageFD) PackageName() string { return mp.pn } -func (mp mockPackageFD) GetPackage() string { return mp.gp } + pkg := dummyPkg() + pkg.setComments("foobar") + assert.Equal(t, "foobar", pkg.Comments()) +} func dummyPkg() *pkg { return &pkg{ - fd: &mockPackageFD{ - pn: "pkg_name", - gp: "get_pkg", - }, - importPath: "import/path", + fd: &descriptor.FileDescriptorProto{Package: proto.String("pkg_name")}, } } diff --git a/parameters.go b/parameters.go index 55e2976..d6bbbac 100644 --- a/parameters.go +++ b/parameters.go @@ -8,37 +8,13 @@ import ( "time" ) -const ( - importPrefixKey = "import_prefix" - importPathKey = "import_path" - outputPathKey = "output_path" - importMapKeyPrefix = "M" - pluginsKey = "plugins" - pluginsSep = "+" -) - -// PathType describes how the generated file paths should be 
constructed. -type PathType string - -const ( - // PathTypeParam is the plugin param that allows specifying the path type - // mode used in code generation. - pathTypeKey = "paths" - - // ImportPath is the default and outputs the file based off the go import - // path defined in the go_package option. - ImportPath PathType = "" - - // SourceRelative indicates files should be output relative to the path of - // the source file. - SourceRelative PathType = "source_relative" -) +const outputPathKey = "output_path" // Parameters provides a convenience for accessing and modifying the parameters // passed into the protoc-gen-star plugin. type Parameters map[string]string -// ParseParameters converts the raw params string provided to protoc into a +// ParseParameters converts the raw params string provided by protoc into a // representative mapping. func ParseParameters(p string) (params Parameters) { parts := strings.Split(p, ",") @@ -55,115 +31,6 @@ func ParseParameters(p string) (params Parameters) { return } -// Plugins returns the sub-plugins enabled for this protoc plugin. If the all -// value is true, all registered plugins are considered enabled (ie, protoc was -// called with an empty "plugins" parameter). Otherwise, plugins contains the -// list of plugins enabled by name. -func (p Parameters) Plugins() (plugins []string, all bool) { - s, ok := p[pluginsKey] - if !ok { - return - } - - if all = s == ""; all { - return - } - - plugins = strings.Split(s, pluginsSep) - return -} - -// HasPlugin returns true if the plugin name is enabled in the parameters. This -// method will always return true if all plugins are enabled. -func (p Parameters) HasPlugin(name string) bool { - plugins, all := p.Plugins() - if all { - return true - } - - for _, pl := range plugins { - if pl == name { - return true - } - } - - return false -} - -// AddPlugin adds name to the list of plugins in the parameters. If all plugins -// are enabled, this method is a noop. 
-func (p Parameters) AddPlugin(name ...string) { - if len(name) == 0 { - return - } - - plugins, all := p.Plugins() - if all { - return - } - - p.SetStr(pluginsKey, strings.Join(append(plugins, name...), pluginsSep)) -} - -// EnableAllPlugins changes the parameters to enable all registered sub-plugins. -func (p Parameters) EnableAllPlugins() { p.SetStr(pluginsKey, "") } - -// ImportPrefix returns the protoc-gen-go parameter. This prefix is added onto -// the beginning of all Go import paths. This is useful for things like -// generating protos in a subdirectory, or regenerating vendored protobufs -// in-place. By default, this method returns an empty string. -// -// See: https://github.com/golang/protobuf#parameters -func (p Parameters) ImportPrefix() string { return p.Str(importPrefixKey) } - -// SetImportPrefix sets the protoc-gen-go ImportPrefix parameter. This is -// useful for overriding the behavior of the ImportPrefix at runtime. -func (p Parameters) SetImportPrefix(prefix string) { p.SetStr(importPrefixKey, prefix) } - -// ImportPath returns the protoc-gen-go parameter. This value is used as the -// package if the input proto files do not declare a go_package option. If it -// contains slashes, everything up to the rightmost slash is ignored. -// -// See: https://github.com/golang/protobuf#parameters -func (p Parameters) ImportPath() string { return p.Str(importPathKey) } - -// SetImportPath sets the protoc-gen-go ImportPath parameter. This is useful -// for overriding the behavior of the ImportPath at runtime. -func (p Parameters) SetImportPath(path string) { p.SetStr(importPathKey, path) } - -// Paths returns the protoc-gen-go parameter. This value is used to switch the -// mode used to determine the output paths of the generated code. By default, -// paths are derived from the import path specified by go_package. It can be -// overridden to be "source_relative", ignoring the import path using the -// source path exclusively. 
-func (p Parameters) Paths() PathType { return PathType(p.Str(pathTypeKey)) } - -// SetPaths sets the protoc-gen-go Paths parameter. This is useful for -// overriding the behavior of Paths at runtime. -func (p Parameters) SetPaths(pt PathType) { p.SetStr(pathTypeKey, string(pt)) } - -// ImportMap returns the protoc-gen-go import map overrides. Each entry in the -// map keys off a proto file (as loaded by protoc) with values of the Go -// package to use. These values will be prefixed with the value of ImportPrefix -// when generating the Go code. -func (p Parameters) ImportMap() map[string]string { - out := map[string]string{} - - for k, v := range p { - if strings.HasPrefix(k, importMapKeyPrefix) { - out[k[1:]] = v - } - } - - return out -} - -// AddImportMapping adds a proto file to Go package import mapping to the -// parameters. -func (p Parameters) AddImportMapping(proto, pkg string) { - p[fmt.Sprintf("%s%s", importMapKeyPrefix, proto)] = pkg -} - // OutputPath returns the protoc-gen-star special parameter. If not set in the // execution of protoc, "." 
is returned, indicating that output is relative to // the (unknown) output location for sub-plugins or the directory where protoc diff --git a/parameters_test.go b/parameters_test.go index a5b8a4f..ef50737 100644 --- a/parameters_test.go +++ b/parameters_test.go @@ -7,108 +7,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestParameters_Plugins(t *testing.T) { - t.Parallel() - - p := Parameters{} - plugins, all := p.Plugins() - assert.Empty(t, plugins) - assert.False(t, all) - - p[pluginsKey] = "foo+bar" - plugins, all = p.Plugins() - assert.Equal(t, []string{"foo", "bar"}, plugins) - assert.False(t, all) - - p[pluginsKey] = "" - plugins, all = p.Plugins() - assert.Empty(t, plugins) - assert.True(t, all) -} - -func TestParameters_HasPlugin(t *testing.T) { - t.Parallel() - - p := Parameters{} - assert.False(t, p.HasPlugin("foo")) - - p[pluginsKey] = "foo" - assert.True(t, p.HasPlugin("foo")) - - p[pluginsKey] = "" - assert.True(t, p.HasPlugin("foo")) - - p[pluginsKey] = "bar" - assert.False(t, p.HasPlugin("foo")) -} - -func TestParameters_AddPlugin(t *testing.T) { - t.Parallel() - - p := Parameters{} - p.AddPlugin("foo", "bar") - assert.Equal(t, "foo+bar", p[pluginsKey]) - - p.AddPlugin("baz") - assert.Equal(t, "foo+bar+baz", p[pluginsKey]) - - p.AddPlugin() - assert.Equal(t, "foo+bar+baz", p[pluginsKey]) - - p[pluginsKey] = "" - p.AddPlugin("fizz", "buzz") - assert.Equal(t, "", p[pluginsKey]) -} - -func TestParameters_EnableAllPlugins(t *testing.T) { - t.Parallel() - - p := Parameters{pluginsKey: "foo"} - _, all := p.Plugins() - assert.False(t, all) - - p.EnableAllPlugins() - _, all = p.Plugins() - assert.True(t, all) -} - -func TestParameters_ImportPrefix(t *testing.T) { - t.Parallel() - - p := Parameters{} - assert.Empty(t, p.ImportPrefix()) - p.SetImportPrefix("foo") - assert.Equal(t, "foo", p.ImportPrefix()) -} - -func TestParameters_ImportPath(t *testing.T) { - t.Parallel() - - p := Parameters{} - assert.Empty(t, p.ImportPath()) - 
p.SetImportPath("foo") - assert.Equal(t, "foo", p.ImportPath()) -} - -func TestParameters_ImportMap(t *testing.T) { - t.Parallel() - - p := Parameters{ - "Mfoo.proto": "bar", - "Mfizz/buzz.proto": "baz", - } - - im := p.ImportMap() - assert.Len(t, p.ImportMap(), 2) - - p.AddImportMapping("quux.proto", "shme") - im = p.ImportMap() - assert.Len(t, im, 3) - assert.Equal(t, "shme", im["quux.proto"]) - assert.Equal(t, "bar", im["foo.proto"]) - assert.Equal(t, "baz", im["fizz/buzz.proto"]) -} - func TestParameters_OutputPath(t *testing.T) { t.Parallel() @@ -326,13 +224,3 @@ func TestParameters_Duration(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 789*time.Second, out) } - -func TestParameters_Paths(t *testing.T) { - t.Parallel() - - p := Parameters{} - - assert.Equal(t, ImportPath, p.Paths()) - p.SetPaths(SourceRelative) - assert.Equal(t, SourceRelative, p.Paths()) -} diff --git a/path.go b/path.go deleted file mode 100644 index e9e3ae5..0000000 --- a/path.go +++ /dev/null @@ -1,61 +0,0 @@ -package pgs - -import ( - "path" - "strings" - - "github.com/golang/protobuf/protoc-gen-go/generator" -) - -func goPackageOption(f *generator.FileDescriptor) (impPath, pkg string, ok bool) { - pkg = f.GetOptions().GetGoPackage() - if pkg == "" { - return - } - ok = true - - slash := strings.LastIndex(pkg, "/") - if slash < 0 { - return - } - - impPath, pkg = pkg, pkg[slash+1:] - sc := strings.IndexByte(impPath, ';') - if sc < 0 { - return - } - - impPath, pkg = impPath[:sc], impPath[sc+1:] - return -} - -func goFileName(f *generator.FileDescriptor, pathType PathType) string { - name := f.GetName() - if ext := path.Ext(name); ext == ".proto" || ext == ".protodevel" { - name = name[:len(name)-len(ext)] - } - name += ".pb.go" - - if pathType == SourceRelative { - return name - } - - if impPath, _, ok := goPackageOption(f); ok && impPath != "" { - _, name = path.Split(name) - name = path.Join(impPath, name) - } - - return name -} - -func goImportPath(g *generator.Generator, f 
*generator.FileDescriptor) generator.GoImportPath { - fn := goFileName(f, Parameters(g.Param).Paths()) - - importPath := path.Dir(fn) - if sub, ok := g.ImportMap[f.GetName()]; ok { - importPath = sub - } - importPath = path.Join(g.ImportPrefix, importPath) - - return generator.GoImportPath(importPath) -} diff --git a/path_test.go b/path_test.go deleted file mode 100644 index fd6981a..0000000 --- a/path_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package pgs - -import ( - "testing" - - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/protoc-gen-go/descriptor" - "github.com/golang/protobuf/protoc-gen-go/generator" - "github.com/stretchr/testify/assert" -) - -func TestGoPackageOption(t *testing.T) { - t.Parallel() - - fd := &generator.FileDescriptor{ - FileDescriptorProto: &descriptor.FileDescriptorProto{ - Options: &descriptor.FileOptions{}}} - - impPath, pkg, ok := goPackageOption(fd) - assert.Empty(t, impPath) - assert.Empty(t, pkg) - assert.False(t, ok) - - fd.Options.GoPackage = proto.String("foobar") - impPath, pkg, ok = goPackageOption(fd) - assert.Empty(t, impPath) - assert.Equal(t, "foobar", pkg) - assert.True(t, ok) - - fd.Options.GoPackage = proto.String("fizz/buzz") - impPath, pkg, ok = goPackageOption(fd) - assert.Equal(t, "fizz/buzz", impPath) - assert.Equal(t, "buzz", pkg) - assert.True(t, ok) - - fd.Options.GoPackage = proto.String("foo/bar;baz") - impPath, pkg, ok = goPackageOption(fd) - assert.Equal(t, "foo/bar", impPath) - assert.Equal(t, "baz", pkg) - assert.True(t, ok) -} - -func TestGoFileName(t *testing.T) { - t.Parallel() - - fd := &generator.FileDescriptor{ - FileDescriptorProto: &descriptor.FileDescriptorProto{ - Name: proto.String("dir/file.proto"), - Options: &descriptor.FileOptions{}, - }, - } - - assert.Equal(t, "dir/file.pb.go", goFileName(fd, ImportPath)) - - fd.FileDescriptorProto.Options.GoPackage = proto.String("other/path") - assert.Equal(t, "other/path/file.pb.go", goFileName(fd, ImportPath)) - assert.Equal(t, 
"dir/file.pb.go", goFileName(fd, SourceRelative)) -} - -func TestGoImportPath(t *testing.T) { - t.Parallel() - - fd := &generator.FileDescriptor{ - FileDescriptorProto: &descriptor.FileDescriptorProto{ - Name: proto.String("dir/file.proto"), - Options: &descriptor.FileOptions{}, - }, - } - - g := &generator.Generator{ImportMap: map[string]string{}} - - assert.Equal(t, generator.GoImportPath("dir"), goImportPath(g, fd)) - - g.ImportMap[fd.GetName()] = "other/pkg" - g.ImportPrefix = "github.com/example" - - assert.Equal(t, generator.GoImportPath("github.com/example/other/pkg"), goImportPath(g, fd)) -} diff --git a/persister.go b/persister.go index 5e9088d..94efb0c 100644 --- a/persister.go +++ b/persister.go @@ -11,16 +11,14 @@ import ( type persister interface { SetDebugger(d Debugger) - SetPGG(pgg ProtocGenGo) SetFS(fs afero.Fs) AddPostProcessor(proc ...PostProcessor) - Persist(a ...Artifact) + Persist(a ...Artifact) *plugin_go.CodeGeneratorResponse } type stdPersister struct { Debugger - pgg ProtocGenGo fs afero.Fs procs []PostProcessor } @@ -28,45 +26,46 @@ type stdPersister struct { func newPersister() *stdPersister { return &stdPersister{fs: afero.NewOsFs()} } func (p *stdPersister) SetDebugger(d Debugger) { p.Debugger = d } -func (p *stdPersister) SetPGG(pgg ProtocGenGo) { p.pgg = pgg } func (p *stdPersister) SetFS(fs afero.Fs) { p.fs = fs } func (p *stdPersister) AddPostProcessor(proc ...PostProcessor) { p.procs = append(p.procs, proc...) 
} -func (p *stdPersister) Persist(arts ...Artifact) { +func (p *stdPersister) Persist(arts ...Artifact) *plugin_go.CodeGeneratorResponse { + resp := new(plugin_go.CodeGeneratorResponse) + for _, a := range arts { switch a := a.(type) { case GeneratorFile: f, err := a.ProtoFile() p.CheckErr(err, "unable to convert ", a.Name, " to proto") f.Content = proto.String(p.postProcess(a, f.GetContent())) - p.insertFile(f, a.Overwrite) + p.insertFile(resp, f, a.Overwrite) case GeneratorTemplateFile: f, err := a.ProtoFile() p.CheckErr(err, "unable to convert ", a.Name, " to proto") f.Content = proto.String(p.postProcess(a, f.GetContent())) - p.insertFile(f, a.Overwrite) + p.insertFile(resp, f, a.Overwrite) case GeneratorAppend: f, err := a.ProtoFile() p.CheckErr(err, "unable to convert append for ", a.FileName, " to proto") f.Content = proto.String(p.postProcess(a, f.GetContent())) n, _ := cleanGeneratorFileName(a.FileName) - p.insertAppend(n, f) + p.insertAppend(resp, n, f) case GeneratorTemplateAppend: f, err := a.ProtoFile() p.CheckErr(err, "unable to convert append for ", a.FileName, " to proto") f.Content = proto.String(p.postProcess(a, f.GetContent())) n, _ := cleanGeneratorFileName(a.FileName) - p.insertAppend(n, f) + p.insertAppend(resp, n, f) case GeneratorInjection: f, err := a.ProtoFile() p.CheckErr(err, "unable to convert injection ", a.InsertionPoint, " for ", a.FileName, " to proto") f.Content = proto.String(p.postProcess(a, f.GetContent())) - p.insertFile(f, false) + p.insertFile(resp, f, false) case GeneratorTemplateInjection: f, err := a.ProtoFile() p.CheckErr(err, "unable to convert injection ", a.InsertionPoint, " for ", a.FileName, " to proto") f.Content = proto.String(p.postProcess(a, f.GetContent())) - p.insertFile(f, false) + p.insertFile(resp, f, false) case CustomFile: p.writeFile( a.Name, @@ -88,10 +87,12 @@ func (p *stdPersister) Persist(arts ...Artifact) { p.Failf("unrecognized artifact type: %T", a) } } + + return resp } -func (p *stdPersister) 
indexOfFile(name string) int { - for i, f := range p.pgg.response().GetFile() { +func (p *stdPersister) indexOfFile(resp *plugin_go.CodeGeneratorResponse, name string) int { + for i, f := range resp.GetFile() { if f.GetName() == name && f.InsertionPoint == nil { return i } @@ -100,26 +101,28 @@ func (p *stdPersister) indexOfFile(name string) int { return -1 } -func (p *stdPersister) insertFile(f *plugin_go.CodeGeneratorResponse_File, overwrite bool) { +func (p *stdPersister) insertFile(resp *plugin_go.CodeGeneratorResponse, + f *plugin_go.CodeGeneratorResponse_File, overwrite bool) { if overwrite { - if i := p.indexOfFile(f.GetName()); i >= 0 { - p.pgg.response().File[i] = f + if i := p.indexOfFile(resp, f.GetName()); i >= 0 { + resp.File[i] = f return } } - p.pgg.response().File = append(p.pgg.response().File, f) + resp.File = append(resp.File, f) } -func (p *stdPersister) insertAppend(name string, f *plugin_go.CodeGeneratorResponse_File) { - i := p.indexOfFile(name) +func (p *stdPersister) insertAppend(resp *plugin_go.CodeGeneratorResponse, + name string, f *plugin_go.CodeGeneratorResponse_File) { + i := p.indexOfFile(resp, name) p.Assert(i > -1, "append target ", name, " missing") - p.pgg.response().File = append( - p.pgg.response().File[:i+1], + resp.File = append( + resp.File[:i+1], append( []*plugin_go.CodeGeneratorResponse_File{f}, - p.pgg.response().File[i+1:]..., + resp.File[i+1:]..., )..., ) } diff --git a/persister_test.go b/persister_test.go index 8bf9059..ca4da16 100644 --- a/persister_test.go +++ b/persister_test.go @@ -6,7 +6,6 @@ import ( "errors" - "github.com/golang/protobuf/protoc-gen-go/generator" "github.com/spf13/afero" "github.com/stretchr/testify/assert" ) @@ -14,46 +13,40 @@ import ( func TestPersister_Persist_Unrecognized(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() p := dummyPersister(d) p.Persist(nil) - assert.True(t, d.failed) + assert.True(t, d.Failed()) } func TestPersister_Persist_GeneratorFile(t 
*testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() p := dummyPersister(d) fs := afero.NewMemMapFs() p.SetFS(fs) - p.Persist(GeneratorFile{ - Name: "foo", - Contents: "bar", - }) - - assert.Len(t, p.pgg.response().File, 1) - assert.Equal(t, "foo", p.pgg.response().File[0].GetName()) - assert.Equal(t, "bar", p.pgg.response().File[0].GetContent()) - - p.Persist(GeneratorFile{ - Name: "foo", - Contents: "baz", - }) - - assert.Len(t, p.pgg.response().File, 2) - - p.Persist(GeneratorFile{ - Name: "foo", - Contents: "fizz", - Overwrite: true, - }) - - assert.Len(t, p.pgg.response().File, 2) - assert.Equal(t, "fizz", p.pgg.response().File[0].GetContent()) + resp := p.Persist( + GeneratorFile{ + Name: "foo", + Contents: "bar", + }, + GeneratorFile{ + Name: "quux", + Contents: "baz", + }, + GeneratorFile{ + Name: "foo", + Contents: "fizz", + Overwrite: true, + }) + + assert.Len(t, resp.File, 2) + assert.Equal(t, "foo", resp.File[0].GetName()) + assert.Equal(t, "fizz", resp.File[0].GetContent()) } var genTpl = template.Must(template.New("good").Parse("{{ . }}")) @@ -61,122 +54,106 @@ var genTpl = template.Must(template.New("good").Parse("{{ . 
}}")) func TestPersister_Persist_GeneratorTemplateFile(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() p := dummyPersister(d) fs := afero.NewMemMapFs() p.SetFS(fs) - p.Persist(GeneratorTemplateFile{ - Name: "foo", - TemplateArtifact: TemplateArtifact{ - Template: genTpl, - Data: "bar", + resp := p.Persist( + GeneratorTemplateFile{ + Name: "foo", + TemplateArtifact: TemplateArtifact{ + Template: genTpl, + Data: "bar", + }, }, - }) - - assert.Len(t, p.pgg.response().File, 1) - assert.Equal(t, "foo", p.pgg.response().File[0].GetName()) - assert.Equal(t, "bar", p.pgg.response().File[0].GetContent()) - - p.Persist(GeneratorTemplateFile{ - Name: "foo", - TemplateArtifact: TemplateArtifact{ - Template: genTpl, - Data: "baz", + GeneratorTemplateFile{ + Name: "quux", + TemplateArtifact: TemplateArtifact{ + Template: genTpl, + Data: "baz", + }, }, - }) - - assert.Len(t, p.pgg.response().File, 2) - - p.Persist(GeneratorTemplateFile{ - Name: "foo", - TemplateArtifact: TemplateArtifact{ - Template: genTpl, - Data: "fizz", + GeneratorTemplateFile{ + Name: "foo", + TemplateArtifact: TemplateArtifact{ + Template: genTpl, + Data: "fizz", + }, + Overwrite: true, }, - Overwrite: true, - }) + ) - assert.Len(t, p.pgg.response().File, 2) - assert.Equal(t, "fizz", p.pgg.response().File[0].GetContent()) + assert.Len(t, resp.File, 2) + assert.Equal(t, "foo", resp.File[0].GetName()) + assert.Equal(t, "fizz", resp.File[0].GetContent()) } func TestPersister_Persist_GeneratorAppend(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() p := dummyPersister(d) fs := afero.NewMemMapFs() p.SetFS(fs) - p.Persist( + resp := p.Persist( GeneratorFile{Name: "foo"}, GeneratorFile{Name: "bar"}, + GeneratorAppend{ + FileName: "foo", + Contents: "baz", + }, + GeneratorAppend{ + FileName: "bar", + Contents: "quux", + }, ) - p.Persist(GeneratorAppend{ - FileName: "foo", - Contents: "baz", - }) + assert.Len(t, resp.File, 4) + assert.Equal(t, "", 
resp.File[1].GetName()) + assert.Equal(t, "baz", resp.File[1].GetContent()) + assert.Equal(t, "", resp.File[3].GetName()) + assert.Equal(t, "quux", resp.File[3].GetContent()) - assert.Len(t, p.pgg.response().File, 3) - assert.Equal(t, "", p.pgg.response().File[1].GetName()) - assert.Equal(t, "baz", p.pgg.response().File[1].GetContent()) + p.Persist(GeneratorAppend{FileName: "doesNotExist"}) - p.Persist(GeneratorAppend{ - FileName: "bar", - Contents: "quux", - }) - - assert.Len(t, p.pgg.response().File, 4) - assert.Equal(t, "", p.pgg.response().File[3].GetName()) - assert.Equal(t, "quux", p.pgg.response().File[3].GetContent()) - - p.Persist(GeneratorAppend{ - FileName: "doesNotExist", - }) - - assert.True(t, d.failed) + assert.True(t, d.Failed()) } func TestPersister_Persist_GeneratorTemplateAppend(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() p := dummyPersister(d) fs := afero.NewMemMapFs() p.SetFS(fs) - p.Persist( + resp := p.Persist( GeneratorFile{Name: "foo"}, GeneratorFile{Name: "bar"}, - ) - - p.Persist(GeneratorTemplateAppend{ - FileName: "foo", - TemplateArtifact: TemplateArtifact{ - Template: genTpl, - Data: "baz", - }, - }) - - assert.Len(t, p.pgg.response().File, 3) - assert.Equal(t, "", p.pgg.response().File[1].GetName()) - assert.Equal(t, "baz", p.pgg.response().File[1].GetContent()) - - p.Persist(GeneratorTemplateAppend{ - FileName: "bar", - TemplateArtifact: TemplateArtifact{ - Template: genTpl, - Data: "quux", + GeneratorTemplateAppend{ + FileName: "foo", + TemplateArtifact: TemplateArtifact{ + Template: genTpl, + Data: "baz", + }, + }, GeneratorTemplateAppend{ + FileName: "bar", + TemplateArtifact: TemplateArtifact{ + Template: genTpl, + Data: "quux", + }, }, - }) + ) - assert.Len(t, p.pgg.response().File, 4) - assert.Equal(t, "", p.pgg.response().File[3].GetName()) - assert.Equal(t, "quux", p.pgg.response().File[3].GetContent()) + assert.Len(t, resp.File, 4) + assert.Equal(t, "", resp.File[1].GetName()) + 
assert.Equal(t, "baz", resp.File[1].GetContent()) + assert.Equal(t, "", resp.File[3].GetName()) + assert.Equal(t, "quux", resp.File[3].GetContent()) - p.Persist(GeneratorTemplateAppend{ + resp = p.Persist(GeneratorTemplateAppend{ FileName: "doesNotExist", TemplateArtifact: TemplateArtifact{ Template: genTpl, @@ -184,38 +161,38 @@ func TestPersister_Persist_GeneratorTemplateAppend(t *testing.T) { }, }) - assert.True(t, d.failed) + assert.True(t, d.Failed()) } func TestPersister_Persist_GeneratorInjection(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() p := dummyPersister(d) fs := afero.NewMemMapFs() p.SetFS(fs) - p.Persist(GeneratorInjection{ + resp := p.Persist(GeneratorInjection{ FileName: "foo", InsertionPoint: "bar", Contents: "baz", }) - assert.Len(t, p.pgg.response().File, 1) - assert.Equal(t, "foo", p.pgg.response().File[0].GetName()) - assert.Equal(t, "bar", p.pgg.response().File[0].GetInsertionPoint()) - assert.Equal(t, "baz", p.pgg.response().File[0].GetContent()) + assert.Len(t, resp.File, 1) + assert.Equal(t, "foo", resp.File[0].GetName()) + assert.Equal(t, "bar", resp.File[0].GetInsertionPoint()) + assert.Equal(t, "baz", resp.File[0].GetContent()) } func TestPersister_Persist_GeneratorTemplateInjection(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() p := dummyPersister(d) fs := afero.NewMemMapFs() p.SetFS(fs) - p.Persist(GeneratorTemplateInjection{ + resp := p.Persist(GeneratorTemplateInjection{ FileName: "foo", InsertionPoint: "bar", TemplateArtifact: TemplateArtifact{ @@ -224,16 +201,16 @@ func TestPersister_Persist_GeneratorTemplateInjection(t *testing.T) { }, }) - assert.Len(t, p.pgg.response().File, 1) - assert.Equal(t, "foo", p.pgg.response().File[0].GetName()) - assert.Equal(t, "bar", p.pgg.response().File[0].GetInsertionPoint()) - assert.Equal(t, "baz", p.pgg.response().File[0].GetContent()) + assert.Len(t, resp.File, 1) + assert.Equal(t, "foo", resp.File[0].GetName()) + 
assert.Equal(t, "bar", resp.File[0].GetInsertionPoint()) + assert.Equal(t, "baz", resp.File[0].GetContent()) } func TestPersister_Persist_CustomFile(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() p := dummyPersister(d) fs := afero.NewMemMapFs() p.SetFS(fs) @@ -273,7 +250,7 @@ func TestPersister_Persist_CustomFile(t *testing.T) { func TestPersister_Persist_CustomTemplateFile(t *testing.T) { t.Parallel() - d := newMockDebugger(t) + d := InitMockDebugger() p := dummyPersister(d) fs := afero.NewMemMapFs() p.SetFS(fs) @@ -322,7 +299,7 @@ func TestPersister_Persist_CustomTemplateFile(t *testing.T) { func TestPersister_AddPostProcessor(t *testing.T) { t.Parallel() - p := dummyPersister(newMockDebugger(t)) + p := dummyPersister(InitMockDebugger()) good := &mockPP{match: true, out: []byte("good")} bad := &mockPP{err: errors.New("should not be called")} @@ -335,7 +312,6 @@ func TestPersister_AddPostProcessor(t *testing.T) { func dummyPersister(d Debugger) *stdPersister { return &stdPersister{ Debugger: d, - pgg: mockGeneratorPGG{ProtocGenGo: Wrap(generator.New())}, fs: afero.NewMemMapFs(), } } diff --git a/plugin.go b/plugin.go deleted file mode 100644 index 60a1711..0000000 --- a/plugin.go +++ /dev/null @@ -1,207 +0,0 @@ -package pgs - -import ( - "io" - "log" - "os" - "strconv" - "text/template" - - "github.com/golang/protobuf/protoc-gen-go/generator" -) - -// Plugin describes an official protoc-gen-go plugin that will also be passed a -// pre-configured debugger for use. The plugin must be registered via -// Generator.RegisterPlugin for it to be properly initialized. -type Plugin interface { - generator.Plugin - - // InitContext is called before the Plugin's Init method and is passed a - // pre-configured BuildContext instance. - InitContext(c BuildContext) -} - -// Template describes a template used to render content. Both the text/template -// and html/template packages satisfy this interface. 
-type Template interface { - Name() string - Execute(wr io.Writer, data interface{}) error -} - -// PluginBase provides utility methods and a base implementation for the -// protoc-gen-go sub-plugin workflow. -type PluginBase struct { - BuildContext - - Generator ProtocGenGo - - seenImports map[string]string - Imports map[string]string - - buildTargets map[string]struct{} -} - -// Name satisfies the protoc-gen-go plugin interface, however this method will -// fail and must be overridden by a parent struct. PluginBase should be used as -// an anonymously embedded field of an actual Plugin implementation. The only -// methods that need to be overridden are Name and Generate. -func (p *PluginBase) Name() string { - p.Fail("Name method is not implemented for this plugin") - return "unimplemented" -} - -// InitContext populates this Plugin with the BuildContext from the parent -// Generator, allowing for easy debug logging, and error checking. This -// method is called prior to Init for plugins registered directly with the -// generator. -func (p *PluginBase) InitContext(c BuildContext) { p.BuildContext = c } - -// Init sets up the plugin with a reference to the generator. This method -// satisfies the Init method for the protoc-gen-go plugin. -func (p *PluginBase) Init(g *generator.Generator) { - if p.BuildContext == nil { - d := initDebugger( - &Generator{pgg: Wrap(g)}, - log.New(os.Stderr, "", 0)).Push("unregistered plugin") - p.BuildContext = Context(d, Parameters{}, ".") - } - - p.Debug("Initializing") - p.Generator = Wrap(g) -} - -// Generate satisfies the protoc-gen-go plugin interface, however this method -// will fail and must be overridden by a parent struct. -func (p *PluginBase) Generate(file *generator.FileDescriptor) { - p.Fail("Generate method is not implemented for this plugin") -} - -var importsTmpl = template.Must(template.New("imports").Parse(`import({{ range $path, $pkg := . 
}} - {{ $pkg }} "{{ $path }}" -{{- end }} -)`)) - -// GenerateImports adds the imported packages to the top of the file to be -// generated, using the packages included in b.Imports. This method satisfies -// the GenerateImports method for the protoc-gen-go plugin, and is called after -// Generate for each particular FileDescriptor. The added Imports are cleared -// after this call is completed. -func (p *PluginBase) GenerateImports(file *generator.FileDescriptor) { - if p == nil || len(p.Imports) == 0 { - return - } - p.T(importsTmpl, p.Imports) - p.Imports = nil -} - -// AddImport safely registers an import at path with the target pkg name. The -// returned uniquePkg should be used within the code to avoid naming collisions. -// If referencing an entity from a protocol buffer, provide its FileDescriptor -// fd, otherwise leave it as nil. The Imports are cleared after GenerateImports -// is called. -func (p *PluginBase) AddImport(pkg, path string, fd *generator.FileDescriptor) (uniquePkg string) { - if p.seenImports == nil { - p.seenImports = map[string]string{} - } - - if p.Imports == nil { - p.Imports = map[string]string{} - } - - if existing, ok := p.seenImports[path]; ok { - p.Imports[path] = existing - return existing - } - - uniquePkg = generator.RegisterUniquePackageName(pkg, fd) - p.seenImports[path] = uniquePkg - p.Imports[path] = uniquePkg - - return -} - -// P wraps the generator's P method, printing the arguments to the generated -// output. It handles strings and int32s, plus handling indirections because -// they may be *string, etc. -func (p *PluginBase) P(args ...interface{}) { p.Generator.P(args...) } - -// In wraps the generator's In command, indenting the output by one tab. -func (p *PluginBase) In() { p.Generator.In() } - -// Out wraps the generator's Out command, outdenting the output by one tab. -func (p *PluginBase) Out() { p.Generator.Out() } - -// C behaves like the P method, but prints a comment block. 
-// The wrap parameter indicates what width to wrap the comment at. -func (p *PluginBase) C(wrap int, args ...interface{}) { - s := commentScanner(wrap, args...) - for s.Scan() { - p.P("// ", s.Text()) - } -} - -// C80 curries the C method with the traditional width of 80 characters, -// calling p.C(80, args...). -func (p *PluginBase) C80(args ...interface{}) { p.C(80, args...) } - -// T renders tpl into the target file, using data. The plugin is terminated if -// there is an error executing the template. Both text/template and -// html/template packages are compatible with this method. -func (p *PluginBase) T(tpl Template, data interface{}) { - p.CheckErr( - tpl.Execute(p.Generator, data), - "unable to render template: ", - strconv.Quote(tpl.Name())) -} - -// Push adds a prefix to the plugin's BuildContext. Pop should be called when -// that context is complete. -func (p *PluginBase) Push(prefix string) BuildContext { - p.BuildContext = p.BuildContext.Push(prefix) - return p.BuildContext -} - -// PushDir changes the OutputPath of the plugin's BuildContext. Pop (or PopDir) -// should be called when that context is complete. -func (p *PluginBase) PushDir(dir string) BuildContext { - p.BuildContext = p.BuildContext.PushDir(dir) - return p.BuildContext -} - -// Pop removes the last push from the plugin's BuildContext. This method should -// only be called after a paired Push or PushDir. -func (p *PluginBase) Pop() BuildContext { - p.BuildContext = p.BuildContext.Pop() - return p.BuildContext -} - -// PopDir removes the last PushDir from the plugin's BuildContext. This method -// should only be called after a paired PushDir. -func (p *PluginBase) PopDir() BuildContext { - p.BuildContext = p.BuildContext.PopDir() - return p.BuildContext -} - -// BuildTarget returns true if the specified proto filename was an input to -// protoc. This method is useful to determine if generation logic should be -// executed against it or if it is only loaded as a dependency. 
This method -// expects the value returned by generator.FileDescriptor.GetName or -// descriptor.FileDescriptorProto.GetName methods. -func (p *PluginBase) BuildTarget(proto string) bool { - if p.buildTargets == nil { - files := p.Generator.request().GetFileToGenerate() - p.buildTargets = make(map[string]struct{}, len(files)) - for _, f := range files { - p.buildTargets[f] = struct{}{} - } - } - - _, ok := p.buildTargets[proto] - return ok -} - -// BuildTargetObj returns whether or not a generator.Object was loaded from a -// BuildTarget file. -func (p *PluginBase) BuildTargetObj(o generator.Object) bool { return p.BuildTarget(o.File().GetName()) } - -var _ Plugin = (*PluginBase)(nil) diff --git a/plugin_test.go b/plugin_test.go deleted file mode 100644 index 29fe6b5..0000000 --- a/plugin_test.go +++ /dev/null @@ -1,229 +0,0 @@ -package pgs - -import ( - "bytes" - "strconv" - "strings" - "testing" - "text/template" - - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/protoc-gen-go/generator" - "github.com/stretchr/testify/assert" -) - -func TestPluginBase_Name(t *testing.T) { - t.Parallel() - assert.Panics(t, func() { new(PluginBase).Name() }) -} - -func TestPluginBase_InitDebugger(t *testing.T) { - t.Parallel() - - pb := new(PluginBase) - g := Init() - pb.InitContext(Context(g.Debugger, Parameters{}, ".")) - - assert.NotNil(t, pb.BuildContext) -} - -func TestPluginBase_Init(t *testing.T) { - t.Parallel() - - g := generator.New() - pb := new(PluginBase) - pb.Init(g) - assert.Equal(t, g, pb.Generator.Unwrap()) -} - -func TestPluginBase_Generate(t *testing.T) { - t.Parallel() - assert.Panics(t, func() { new(PluginBase).Generate(new(generator.FileDescriptor)) }) -} - -func TestPluginBase_Imports(t *testing.T) { - t.Parallel() - - pb := new(PluginBase) - pb.Init(generator.New()) - - f1 := pb.AddImport("foo", "bar", nil) - f2 := pb.AddImport("foo", "bar", nil) - f3 := pb.AddImport("foo", "baz", nil) - - assert.Equal(t, pb.seenImports, pb.Imports) - 
assert.Len(t, pb.Imports, 2) - - assert.Equal(t, f1, f2) - assert.NotEqual(t, f2, f3) - assert.Equal(t, f1, pb.Imports["bar"]) - assert.Equal(t, f3, pb.Imports["baz"]) - - assert.NotPanics(t, func() { pb.GenerateImports(nil) }) - assert.Empty(t, pb.Imports) - assert.Len(t, pb.seenImports, 2) - - assert.NotPanics(t, func() { pb.GenerateImports(nil) }) -} - -func TestPluginBase_P(t *testing.T) { - t.Parallel() - - pb := new(PluginBase) - pb.Init(generator.New()) - pgg := &pluginProtocGenGo{ProtocGenGo: pb.Generator} - pb.Generator = pgg - - pb.P("foo", 123) - assert.Len(t, pgg.p, 2) - assert.Equal(t, "foo", pgg.p[0]) - assert.Equal(t, 123, pgg.p[1]) -} - -func TestPluginBase_In(t *testing.T) { - t.Parallel() - - pb := new(PluginBase) - pb.Init(generator.New()) - pgg := &pluginProtocGenGo{ProtocGenGo: pb.Generator} - pb.Generator = pgg - - pb.In() - assert.Equal(t, 1, pgg.in) - pb.Out() - assert.Equal(t, 0, pgg.in) -} - -func TestPluginBase_C(t *testing.T) { - t.Parallel() - - tests := []struct { - in []interface{} - ex []interface{} - }{ - { - []interface{}{"foo", " bar", " baz"}, - []interface{}{"// ", "foo bar baz"}, - }, - { - in: []interface{}{"the quick brown fox jumps over the lazy dog"}, - ex: []interface{}{"// ", "the quick brown", "// ", "fox jumps over", "// ", "the lazy dog"}, - }, - { - in: []interface{}{"supercalifragilisticexpialidocious"}, - ex: []interface{}{"// ", "supercalifragilisticexpialidocious"}, - }, - { - in: []interface{}{"1234567890123456789012345 foo"}, - ex: []interface{}{"// ", "1234567890123456789012345", "// ", "foo"}, - }, - } - - pb := new(PluginBase) - pb.Init(generator.New()) - pgg := &pluginProtocGenGo{ProtocGenGo: pb.Generator} - pb.Generator = pgg - - for i, test := range tests { - tc := test - t.Run(strconv.Itoa(i), func(t *testing.T) { - pgg.p = pgg.p[:0] - pb.C(20, tc.in...) 
- assert.Equal(t, tc.ex, pgg.p) - }) - } -} - -func TestPluginBase_C80(t *testing.T) { - t.Parallel() - - pb := new(PluginBase) - pb.Init(generator.New()) - pgg := &pluginProtocGenGo{ProtocGenGo: pb.Generator} - pb.Generator = pgg - - pb.C80(strings.Repeat("foo ", 20)) - assert.Equal(t, []interface{}{ - "// ", strings.TrimSpace(strings.Repeat("foo ", 19)), - "// ", "foo", - }, pgg.p) -} - -func TestPluginBase_T(t *testing.T) { - t.Parallel() - - tpl := template.Must(template.New("tpl").Parse(`foo{{ . }}`)) - - pb := new(PluginBase) - g := generator.New() - g.Buffer = new(bytes.Buffer) - pb.Init(g) - - assert.NotPanics(t, func() { pb.T(tpl, "bar") }) - assert.Contains(t, g.Buffer.String(), "foobar") -} - -func TestPluginBase_PushPop(t *testing.T) { - t.Parallel() - - pb := new(PluginBase) - pb.Init(generator.New()) - - pb.Push("foo") - pb.Pop() -} - -func TestPluginBase_PushPopDir(t *testing.T) { - t.Parallel() - - pb := new(PluginBase) - pb.Init(generator.New()) - - pb.PushDir("foo/bar") - assert.Equal(t, "foo/bar", pb.OutputPath()) - pb.PopDir() - assert.Equal(t, ".", pb.OutputPath()) -} - -func TestPluginBase_BuildTarget(t *testing.T) { - t.Parallel() - - g := generator.New() - g.Request.FileToGenerate = []string{"file.proto"} - - pb := new(PluginBase) - pb.Init(g) - - o := mockGeneratorObj{f: dummyFile().Descriptor()} - - assert.True(t, pb.BuildTarget("file.proto")) - assert.True(t, pb.BuildTargetObj(o)) - - o.f.Name = proto.String("bar") - assert.False(t, pb.BuildTargetObj(o)) -} - -type mockPlugin struct { - *PluginBase - name string -} - -func (p mockPlugin) Name() string { return p.name } - -type pluginProtocGenGo struct { - ProtocGenGo - p []interface{} - in int -} - -func (p *pluginProtocGenGo) Name() string { return "pluginProtocGenGo" } -func (p *pluginProtocGenGo) P(args ...interface{}) { p.p = append(p.p, args...) 
} -func (p *pluginProtocGenGo) In() { p.in++ } -func (p *pluginProtocGenGo) Out() { p.in-- } - -type mockGeneratorObj struct { - generator.Object - f *generator.FileDescriptor -} - -func (o mockGeneratorObj) File() *generator.FileDescriptor { return o.f } diff --git a/post_process.go b/post_process.go index c09524f..3f406d7 100644 --- a/post_process.go +++ b/post_process.go @@ -1,12 +1,6 @@ package pgs -import ( - "go/format" - "strings" -) - // A PostProcessor modifies the output of an Artifact before final rendering. -// PostProcessors are only applied to Artifacts created by Modules. type PostProcessor interface { // Match returns true if the PostProcess should be applied to the Artifact. // Process is called immediately after Match for the same Artifact. @@ -16,31 +10,3 @@ type PostProcessor interface { // an error if something goes wrong. Process(in []byte) ([]byte, error) } - -type goFmt struct{} - -// GoFmt returns a PostProcessor that runs gofmt on any files ending in ".go" -func GoFmt() PostProcessor { return goFmt{} } - -func (p goFmt) Match(a Artifact) bool { - var n string - - switch a := a.(type) { - case GeneratorFile: - n = a.Name - case GeneratorTemplateFile: - n = a.Name - case CustomFile: - n = a.Name - case CustomTemplateFile: - n = a.Name - default: - return false - } - - return strings.HasSuffix(n, ".go") -} - -func (p goFmt) Process(in []byte) ([]byte, error) { return format.Source(in) } - -var _ PostProcessor = goFmt{} diff --git a/post_process_test.go b/post_process_test.go index d432672..3408d36 100644 --- a/post_process_test.go +++ b/post_process_test.go @@ -1,56 +1,5 @@ package pgs -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestGoFmt_Match(t *testing.T) { - t.Parallel() - - pp := GoFmt() - - tests := []struct { - n string - a Artifact - m bool - }{ - {"GenFile", GeneratorFile{Name: "foo.go"}, true}, - {"GenFileNonGo", GeneratorFile{Name: "bar.txt"}, false}, - - {"GenTplFile", GeneratorTemplateFile{Name: 
"foo.go"}, true}, - {"GenTplFileNonGo", GeneratorTemplateFile{Name: "bar.txt"}, false}, - - {"CustomFile", CustomFile{Name: "foo.go"}, true}, - {"CustomFileNonGo", CustomFile{Name: "bar.txt"}, false}, - - {"CustomTplFile", CustomTemplateFile{Name: "foo.go"}, true}, - {"CustomTplFileNonGo", CustomTemplateFile{Name: "bar.txt"}, false}, - - {"NonMatch", GeneratorAppend{FileName: "foo.go"}, false}, - } - - for _, test := range tests { - tc := test - t.Run(tc.n, func(t *testing.T) { - t.Parallel() - assert.Equal(t, tc.m, pp.Match(tc.a)) - }) - } -} - -func TestGoFmt_Process(t *testing.T) { - t.Parallel() - - src := []byte("// test\n package foo\n\nvar bar int = 123\n") - exp := []byte("// test\npackage foo\n\nvar bar int = 123\n") - - out, err := GoFmt().Process(src) - assert.NoError(t, err) - assert.Equal(t, exp, out) -} - type mockPP struct { match bool out []byte diff --git a/proto.go b/proto.go index 5672f9a..a6c59e9 100644 --- a/proto.go +++ b/proto.go @@ -9,7 +9,7 @@ const ( // Proto2 syntax permits the use of "optional" and "required" prefixes on // fields. Most of the field types in the generated go structs are pointers. // See: https://developers.google.com/protocol-buffers/docs/proto - Proto2 Syntax = "proto2" + Proto2 Syntax = "" // Proto3 syntax only allows for optional fields, but defaults to the zero // value of that particular type. Most of the field types in the generated go @@ -31,18 +31,18 @@ const ( // Optional (in the context of Proto2 syntax) identifies that the field may // be unset in the proto message. In Proto3 syntax, all fields are considered // Optional and default to their zero value. - Optional ProtoLabel = ProtoLabel(descriptor.FieldDescriptorProto_LABEL_OPTIONAL) + Optional = ProtoLabel(descriptor.FieldDescriptorProto_LABEL_OPTIONAL) // Required (in the context of Proto2 syntax) identifies that the field must // be set in the proto message. In Proto3 syntax, no fields can be identified // as Required. 
- Required ProtoLabel = ProtoLabel(descriptor.FieldDescriptorProto_LABEL_REQUIRED) + Required = ProtoLabel(descriptor.FieldDescriptorProto_LABEL_REQUIRED) // Repeated identifies that the field either permits multiple entries // (repeated) or is a map (map). Determining which requires further // evaluation of the descriptor and whether or not the embedded message is // identified as a MapEntry (see IsMap on FieldType). - Repeated ProtoLabel = ProtoLabel(descriptor.FieldDescriptorProto_LABEL_REPEATED) + Repeated = ProtoLabel(descriptor.FieldDescriptorProto_LABEL_REPEATED) ) // Proto returns the FieldDescriptorProto_Label for this ProtoLabel. This @@ -52,6 +52,13 @@ func (pl ProtoLabel) Proto() descriptor.FieldDescriptorProto_Label { return descriptor.FieldDescriptorProto_Label(pl) } +// ProtoPtr returns a pointer to the FieldDescriptorProto_Label for this +// ProtoLabel. +func (pl ProtoLabel) ProtoPtr() *descriptor.FieldDescriptorProto_Label { + l := pl.Proto() + return &l +} + // ProtoType wraps the FieldDescriptorProto_Type enum for better readability // and utility methods. It is a 1-to-1 conversion. type ProtoType descriptor.FieldDescriptorProto_Type @@ -59,28 +66,28 @@ type ProtoType descriptor.FieldDescriptorProto_Type // 1-to-1 mapping of FieldDescriptorProto_Type enum to ProtoType. While all are // listed here, group types are not supported by this library. 
const ( - DoubleT ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_DOUBLE) - FloatT ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_FLOAT) - Int64T ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_INT64) - UInt64T ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_UINT64) - Int32T ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_INT32) - Fixed64T ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_FIXED64) - Fixed32T ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_FIXED32) - BoolT ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_BOOL) - StringT ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_STRING) - GroupT ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_GROUP) - MessageT ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_MESSAGE) - BytesT ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_BYTES) - UInt32T ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_UINT32) - EnumT ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_ENUM) - SFixed32 ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_SFIXED32) - SFixed64 ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_SFIXED64) - SInt32 ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_SINT32) - SInt64 ProtoType = ProtoType(descriptor.FieldDescriptorProto_TYPE_SINT64) + DoubleT = ProtoType(descriptor.FieldDescriptorProto_TYPE_DOUBLE) + FloatT = ProtoType(descriptor.FieldDescriptorProto_TYPE_FLOAT) + Int64T = ProtoType(descriptor.FieldDescriptorProto_TYPE_INT64) + UInt64T = ProtoType(descriptor.FieldDescriptorProto_TYPE_UINT64) + Int32T = ProtoType(descriptor.FieldDescriptorProto_TYPE_INT32) + Fixed64T = ProtoType(descriptor.FieldDescriptorProto_TYPE_FIXED64) + Fixed32T = ProtoType(descriptor.FieldDescriptorProto_TYPE_FIXED32) + BoolT = ProtoType(descriptor.FieldDescriptorProto_TYPE_BOOL) + StringT = ProtoType(descriptor.FieldDescriptorProto_TYPE_STRING) + GroupT = 
ProtoType(descriptor.FieldDescriptorProto_TYPE_GROUP) + MessageT = ProtoType(descriptor.FieldDescriptorProto_TYPE_MESSAGE) + BytesT = ProtoType(descriptor.FieldDescriptorProto_TYPE_BYTES) + UInt32T = ProtoType(descriptor.FieldDescriptorProto_TYPE_UINT32) + EnumT = ProtoType(descriptor.FieldDescriptorProto_TYPE_ENUM) + SFixed32 = ProtoType(descriptor.FieldDescriptorProto_TYPE_SFIXED32) + SFixed64 = ProtoType(descriptor.FieldDescriptorProto_TYPE_SFIXED64) + SInt32 = ProtoType(descriptor.FieldDescriptorProto_TYPE_SINT32) + SInt64 = ProtoType(descriptor.FieldDescriptorProto_TYPE_SINT64) ) // IsInt returns true if pt maps to an integer-like type. While EnumT types in -// Go are aliases of uint32, to correctly accomodate other languages with +// Go are aliases of uint32, to correctly accommodate other languages with // non-numeric enums, IsInt returns false for EnumT. func (pt ProtoType) IsInt() bool { switch pt { @@ -93,17 +100,20 @@ func (pt ProtoType) IsInt() bool { } // IsNumeric returns true if pt maps to a numeric type. While EnumT types in Go -// are aliases of uint32, to correctly accomodate other languages with non-numeric +// are aliases of uint32, to correctly accommodate other languages with non-numeric // enums, IsNumeric returns false for EnumT. func (pt ProtoType) IsNumeric() bool { return pt == DoubleT || pt == FloatT || pt.IsInt() } -// IsSlice returns true if the type is represented as a slice/array. At this -// time, only BytesT satisfies this condition. -func (pt ProtoType) IsSlice() bool { return pt == BytesT } - // Proto returns the FieldDescriptorProto_Type for this ProtoType. This // method is exclusively used to improve readability without having to switch // the types. func (pt ProtoType) Proto() descriptor.FieldDescriptorProto_Type { return descriptor.FieldDescriptorProto_Type(pt) } + +// ProtoPtr returns a pointer to the FieldDescriptorProto_Type for this +// ProtoType. 
+func (pt ProtoType) ProtoPtr() *descriptor.FieldDescriptorProto_Type { + t := pt.Proto() + return &t +} diff --git a/proto_test.go b/proto_test.go index b1b38af..e9ad983 100644 --- a/proto_test.go +++ b/proto_test.go @@ -3,6 +3,7 @@ package pgs import ( "testing" + "github.com/golang/protobuf/protoc-gen-go/descriptor" "github.com/stretchr/testify/assert" ) @@ -57,23 +58,21 @@ func TestProtoType_IsNumeric(t *testing.T) { } } -func TestProtoType_IsSlice(t *testing.T) { +func TestProtoType_Proto(t *testing.T) { t.Parallel() - yes := []ProtoType{BytesT} + pt := BytesT.Proto() + ptPtr := BytesT.ProtoPtr() + assert.Equal(t, descriptor.FieldDescriptorProto_TYPE_BYTES, pt) + assert.Equal(t, pt, *ptPtr) +} - no := []ProtoType{ - Int64T, UInt64T, SFixed64, SInt64, Fixed64T, - Int32T, UInt32T, SFixed32, SInt32, Fixed32T, - DoubleT, FloatT, BoolT, StringT, GroupT, - MessageT, EnumT, - } +func TestProtoLabel_Proto(t *testing.T) { + t.Parallel() - for _, pt := range yes { - assert.True(t, pt.IsSlice()) - } + pl := Repeated.Proto() + plPtr := Repeated.ProtoPtr() - for _, pt := range no { - assert.False(t, pt.IsSlice()) - } + assert.Equal(t, descriptor.FieldDescriptorProto_LABEL_REPEATED, pl) + assert.Equal(t, pl, *plPtr) } diff --git a/protoc-gen-debug/README.md b/protoc-gen-debug/README.md new file mode 100644 index 0000000..d4582ee --- /dev/null +++ b/protoc-gen-debug/README.md @@ -0,0 +1,52 @@ +# protoc-gen-debug + +This plugin can be used to create test files containing the entire encoded CodeGeneratorRequest passed from a protoc execution. This is useful for testing plugins programmatically without having to run protoc. For an example usage, check out [`ast_test.go`](../ast_test.go) in the project root as well as [`testdata/graph`](../testdata/graph) for the test cases. + +Executing the plugin will place a `code_generator_request.pb.bin` file in the specified output location which can be fed directly into a PG* plugin via the `ProtocInput` init option. 
+ +## Installation + +For a local install: + +```bash +make bin/protoc-gen-debug +``` + +For a global install into `$GOPATH/bin`: + +```bash +go install github.com/lyft/protoc-gen-star/protoc-gen-debug +``` + +## Usage + +To create the `code_generator_request.pb.bin` file for all protos in the current directory: + +```bash +protoc \ + --plugin=protoc-gen-debug=path/to/protoc-gen-debug \ + --debug_out=".:." \ + *.proto +``` + +To use the `code_generator_request.pb.bin` in PG*: + +```go +func TestModule(t *testing.T) { + req, err := os.Open("./code_generator_request.pb.bin") + if err != nil { + t.Fatal(err) + } + + fs := afero.NewMemMapFs() + res := &bytes.Buffer{} + + pgs.Init( + pgs.ProtocInput(req), // use the pre-generated request + pgs.ProtocOutput(res), // capture CodeGeneratorResponse + pgs.FileSystem(fs), // capture any custom files written directly to disk + ).RegisterModule(&MyModule{}).Render() + + // check res and the fs for output +} +``` diff --git a/protoc-gen-debug/main.go b/protoc-gen-debug/main.go new file mode 100644 index 0000000..87a1d64 --- /dev/null +++ b/protoc-gen-debug/main.go @@ -0,0 +1,48 @@ +// protoc-gen-debug emits the raw encoded CodeGeneratorRequest from a protoc +// execution to a file. This is particularly useful for testing (see the +// testdata/graph package for test cases). 
+package main
+
+import (
+	"bytes"
+	"io"
+	"io/ioutil"
+	"log"
+	"os"
+	"path/filepath"
+
+	"github.com/golang/protobuf/proto"
+	"github.com/golang/protobuf/protoc-gen-go/plugin"
+)
+
+func main() {
+	data, err := ioutil.ReadAll(os.Stdin)
+	if err != nil {
+		log.Fatal("unable to read input: ", err)
+	}
+
+	req := &plugin_go.CodeGeneratorRequest{}
+	if err = proto.Unmarshal(data, req); err != nil {
+		log.Fatal("unable to unmarshal request: ", err)
+	}
+
+	path := req.GetParameter()
+	if path == "" {
+		log.Fatal(`please execute the plugin with the output path to properly write the output file: --debug_out="{PATH}:{PATH}"`)
+	}
+
+	err = ioutil.WriteFile(filepath.Join(path, "code_generator_request.pb.bin"), data, 0644)
+	if err != nil {
+		log.Fatal("unable to write request to disk: ", err)
+	}
+
+	data, err = proto.Marshal(&plugin_go.CodeGeneratorResponse{})
+	if err != nil {
+		log.Fatal("unable to marshal response payload: ", err)
+	}
+
+	_, err = io.Copy(os.Stdout, bytes.NewReader(data))
+	if err != nil {
+		log.Fatal("unable to write response to stdout: ", err)
+	}
+}
diff --git a/protoc_gen_go.go b/protoc_gen_go.go
deleted file mode 100644
index 30a695d..0000000
--- a/protoc_gen_go.go
+++ /dev/null
@@ -1,62 +0,0 @@
-package pgs
-
-import (
-	"io"
-
-	"github.com/golang/protobuf/protoc-gen-go/descriptor"
-	"github.com/golang/protobuf/protoc-gen-go/generator"
-	"github.com/golang/protobuf/protoc-gen-go/plugin"
-)
-
-// ProtocGenGo is a superset of the generator.Generator API from the
-// protoc-gen-go library. It exposes many of the members of the original struct,
-// but also exposes others that permit easier testing of code that relies upon
-// accessing protected members.
-type ProtocGenGo interface {
-	// Unwrap returns the underlying generator.Generator instance. Typically this
-	// is called to access public fields off this struct.
- Unwrap() *generator.Generator - - // The following methods/interfaces match the interface of a protoc-gen-go - // generator.Generator struct. - io.Writer - Error(err error, msgs ...string) - Fail(msgs ...string) - ObjectNamed(n string) generator.Object - GoType(message *generator.Descriptor, field *descriptor.FieldDescriptorProto) (typ string, wire string) - GoPackageName(importPath generator.GoImportPath) generator.GoPackageName - P(args ...interface{}) - In() - Out() - - // The following methods simplify execution in the protoc-gen-star Generator & Gatherer - prepare(params Parameters) - generate() - request() *plugin_go.CodeGeneratorRequest - setRequest(req *plugin_go.CodeGeneratorRequest) - response() *plugin_go.CodeGeneratorResponse - setResponse(res *plugin_go.CodeGeneratorResponse) -} - -// Wrap converts a generator.Generator instance into a type that satisfies the -// ProtocGenGo interface. -func Wrap(g *generator.Generator) ProtocGenGo { return &wrappedPGG{g} } - -type wrappedPGG struct{ *generator.Generator } - -func (pgg *wrappedPGG) Unwrap() *generator.Generator { return pgg.Generator } -func (pgg *wrappedPGG) request() *plugin_go.CodeGeneratorRequest { return pgg.Request } -func (pgg *wrappedPGG) setRequest(req *plugin_go.CodeGeneratorRequest) { pgg.Request = req } -func (pgg *wrappedPGG) response() *plugin_go.CodeGeneratorResponse { return pgg.Response } -func (pgg *wrappedPGG) setResponse(res *plugin_go.CodeGeneratorResponse) { pgg.Response = res } - -func (pgg *wrappedPGG) prepare(params Parameters) { - pgg.CommandLineParameters(params.String()) - pgg.WrapTypes() - pgg.SetPackageNames() - pgg.BuildTypeNameMap() -} - -func (pgg *wrappedPGG) generate() { pgg.GenerateAllFiles() } - -var _ ProtocGenGo = (*wrappedPGG)(nil) diff --git a/protoc_gen_go_test.go b/protoc_gen_go_test.go deleted file mode 100644 index 5625357..0000000 --- a/protoc_gen_go_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package pgs - -import ( - "testing" - - 
"github.com/golang/protobuf/proto" - "github.com/golang/protobuf/protoc-gen-go/generator" - "github.com/golang/protobuf/protoc-gen-go/plugin" - "github.com/stretchr/testify/assert" -) - -func TestWrappedPGG_SetRequest(t *testing.T) { - t.Parallel() - - wrapped := Wrap(&generator.Generator{}) - - assert.Nil(t, wrapped.request()) - - req := &plugin_go.CodeGeneratorRequest{FileToGenerate: []string{"foo"}} - wrapped.setRequest(req) - - assert.Equal(t, req, wrapped.request()) -} - -func TestWrappedPGG_SetResponse(t *testing.T) { - t.Parallel() - - wrapped := Wrap(&generator.Generator{}) - - assert.Nil(t, wrapped.response()) - - res := &plugin_go.CodeGeneratorResponse{Error: proto.String("foo")} - wrapped.setResponse(res) - - assert.Equal(t, res, wrapped.response()) -} diff --git a/service.go b/service.go index 12a9374..a4a9f1a 100644 --- a/service.go +++ b/service.go @@ -5,7 +5,7 @@ import ( "github.com/golang/protobuf/protoc-gen-go/descriptor" ) -// Service describes an proto service +// Service describes a proto service definition (typically, gRPC) type Service interface { Entity @@ -24,7 +24,7 @@ type service struct { methods []Method file File - comments string + info SourceCodeInfo } func (s *service) Name() Name { return Name(s.desc.GetName()) } @@ -33,14 +33,14 @@ func (s *service) Syntax() Syntax { return s.fil func (s *service) Package() Package { return s.file.Package() } func (s *service) File() File { return s.file } func (s *service) BuildTarget() bool { return s.file.BuildTarget() } -func (s *service) Comments() string { return s.comments } +func (s *service) SourceCodeInfo() SourceCodeInfo { return s.info } func (s *service) Descriptor() *descriptor.ServiceDescriptorProto { return s.desc } func (s *service) Extension(desc *proto.ExtensionDesc, ext interface{}) (bool, error) { return extension(s.desc.GetOptions(), desc, &ext) } -func (s *service) Imports() (i []Package) { +func (s *service) Imports() (i []File) { for _, m := range s.methods { i = append(i, 
m.Imports()...) } @@ -77,3 +77,20 @@ func (s *service) accept(v Visitor) (err error) { return } + +func (s *service) childAtPath(path []int32) Entity { + switch { + case len(path) == 0: + return s + case len(path)%2 != 0: + return nil + case path[0] == serviceTypeMethodPath: + return s.methods[path[1]].childAtPath(path[2:]) + default: + return nil + } +} + +func (s *service) addSourceCodeInfo(info SourceCodeInfo) { s.info = info } + +var _ Service = (*service)(nil) diff --git a/service_test.go b/service_test.go index 627aad3..2d396cc 100644 --- a/service_test.go +++ b/service_test.go @@ -90,7 +90,7 @@ func TestService_Imports(t *testing.T) { s := &service{} assert.Empty(t, s.Imports()) - s.addMethod(&mockMethod{i: []Package{&pkg{}}}) + s.addMethod(&mockMethod{i: []File{&file{}}}) assert.Len(t, s.Imports(), 1) } @@ -135,14 +135,23 @@ func TestService_Accept(t *testing.T) { assert.Equal(t, 2, v.method) } +func TestService_ChildAtPath(t *testing.T) { + t.Parallel() + + s := &service{} + assert.Equal(t, s, s.childAtPath(nil)) + assert.Nil(t, s.childAtPath([]int32{0})) + assert.Nil(t, s.childAtPath([]int32{0, 0})) +} + type mockService struct { Service - i []Package + i []File f File err error } -func (s *mockService) Imports() []Package { return s.i } +func (s *mockService) Imports() []File { return s.i } func (s *mockService) setFile(f File) { s.f = f } diff --git a/source_code_info.go b/source_code_info.go new file mode 100644 index 0000000..815e2ba --- /dev/null +++ b/source_code_info.go @@ -0,0 +1,56 @@ +package pgs + +import ( + "github.com/golang/protobuf/protoc-gen-go/descriptor" +) + +const ( + packagePath int32 = 2 // FileDescriptorProto.Package + messageTypePath int32 = 4 // FileDescriptorProto.MessageType + enumTypePath int32 = 5 // FileDescriptorProto.EnumType + servicePath int32 = 6 // FileDescriptorProto.Service + syntaxPath int32 = 12 // FileDescriptorProto.Syntax + messageTypeFieldPath int32 = 2 // DescriptorProto.Field + messageTypeNestedTypePath int32 
= 3 // DescriptorProto.NestedType + messageTypeEnumTypePath int32 = 4 // DescriptorProto.EnumType + messageTypeOneofDeclPath int32 = 8 // DescriptorProto.OneofDecl + enumTypeValuePath int32 = 2 // EnumDescriptorProto.Value + serviceTypeMethodPath int32 = 2 // ServiceDescriptorProto.Method +) + +// SourceCodeInfo represents data about an entity from the source. Currently +// this only contains information about comments protoc associates with +// entities. +// +// All comments have their // or /* */ stripped by protoc. See the +// SourceCodeInfo documentation for more details about how comments are +// associated with entities. +type SourceCodeInfo interface { + // Location returns the SourceCodeInfo_Location from the file descriptor. + Location() *descriptor.SourceCodeInfo_Location + + // LeadingComments returns any comment immediately preceding the entity, + // without any whitespace between it and the comment. + LeadingComments() string + + // LeadingDetachedComments returns each comment block or line above the + // entity but separated by whitespace. + LeadingDetachedComments() []string + + // TrailingComments returns any comment immediately following the entity, + // without any whitespace between it and the comment. If the comment would be + // a leading comment for another entity, it won't be considered a trailing + // comment. 
+ TrailingComments() string +} + +type sci struct { + desc *descriptor.SourceCodeInfo_Location +} + +func (info sci) Location() *descriptor.SourceCodeInfo_Location { return info.desc } +func (info sci) LeadingComments() string { return info.desc.GetLeadingComments() } +func (info sci) LeadingDetachedComments() []string { return info.desc.GetLeadingDetachedComments() } +func (info sci) TrailingComments() string { return info.desc.GetTrailingComments() } + +var _ SourceCodeInfo = sci{} diff --git a/source_code_info_test.go b/source_code_info_test.go new file mode 100644 index 0000000..a10764e --- /dev/null +++ b/source_code_info_test.go @@ -0,0 +1,26 @@ +package pgs + +import ( + "testing" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/protoc-gen-go/descriptor" + "github.com/stretchr/testify/assert" +) + +func TestSourceCodeInfo(t *testing.T) { + t.Parallel() + + desc := &descriptor.SourceCodeInfo_Location{ + LeadingComments: proto.String("leading"), + TrailingComments: proto.String("trailing"), + LeadingDetachedComments: []string{"detached"}, + } + + info := sci{desc} + + assert.Equal(t, desc, info.Location()) + assert.Equal(t, "leading", info.LeadingComments()) + assert.Equal(t, "trailing", info.TrailingComments()) + assert.Equal(t, []string{"detached"}, info.LeadingDetachedComments()) +} diff --git a/testdata/graph/README.md b/testdata/graph/README.md new file mode 100644 index 0000000..289049c --- /dev/null +++ b/testdata/graph/README.md @@ -0,0 +1,13 @@ +# AST Graph Test Data + +This directory contains various test proto file sets for black-box testing of the AST gatherer `graph`. + +Proto files are preprocessed to their descriptors, imported directly into the `ast_test.go` tests, and unmarshaled as a `DescriptorFileSet`. 
+ +## To Generate + +From the project root: + +```sh +make testdata-graph +``` diff --git a/testdata/graph/info/info.proto b/testdata/graph/info/info.proto new file mode 100644 index 0000000..37b08d3 --- /dev/null +++ b/testdata/graph/info/info.proto @@ -0,0 +1,56 @@ +// syntax +syntax="proto3"; + +// package +package graph.info; + +// root message +message Info { + // before message + message Before {} + + // before enum + enum BeforeEnum { + // before enum value + BEFORE = 0; + } + + // field + map field = 1; + + // middle message + message Middle { + // inner field + bool inner = 1; + } + + // other field + repeated int32 other_field = 2; + + // after message + message After {} + + // after enum + enum AfterEnum { + // after enum value + AFTER = 0; + } + + // oneof + oneof OneOf { + // oneof field + After oneof_field = 3; + } +} + +// root enum comment +enum Enum { + // root enum value + ROOT = 0; +} + +// service +service Service { + // method + rpc Method(Info) returns (Info); +} diff --git a/testdata/graph/messages/embedded.proto b/testdata/graph/messages/embedded.proto new file mode 100644 index 0000000..bbb39fe --- /dev/null +++ b/testdata/graph/messages/embedded.proto @@ -0,0 +1,24 @@ +syntax="proto3"; +package graph.messages; + +import "messages/scalars.proto"; +import "google/protobuf/duration.proto"; + +message Before {} + +message Embedded { + message NestedBefore {} + + Before local_before = 1; + After local_after = 2; + + NestedBefore nested_before = 3; + NestedAfter nested_after = 4; + + Scalars external_in_package = 5; + google.protobuf.Duration external_3rd_party = 6; + + message NestedAfter {} +} + +message After {} diff --git a/testdata/graph/messages/enums.proto b/testdata/graph/messages/enums.proto new file mode 100644 index 0000000..345a32c --- /dev/null +++ b/testdata/graph/messages/enums.proto @@ -0,0 +1,24 @@ +syntax="proto3"; +package graph.messages; + +import "messages/enums_ext.proto"; +import "google/protobuf/type.proto"; + +enum 
BeforeEnum { BEFORE_VALUE = 0; } + +message Enums { + enum NestedBefore { BEFORE_VALUE = 0; } + + BeforeEnum before = 1; + AfterEnum after = 2; + + NestedBefore nested_before = 3; + NestedAfter nested_after = 4; + + External external_in_package = 5; + google.protobuf.Syntax external_3rd_party = 6; + + enum NestedAfter { AFTER_VALUE = 0; } +} + +enum AfterEnum { AFTER_VALUE = 0; } diff --git a/testdata/graph/messages/enums_ext.proto b/testdata/graph/messages/enums_ext.proto new file mode 100644 index 0000000..eb7d853 --- /dev/null +++ b/testdata/graph/messages/enums_ext.proto @@ -0,0 +1,4 @@ +syntax="proto3"; +package graph.messages; + +enum External { EXT_VALUE = 0; } diff --git a/testdata/graph/messages/maps.proto b/testdata/graph/messages/maps.proto new file mode 100644 index 0000000..4e94084 --- /dev/null +++ b/testdata/graph/messages/maps.proto @@ -0,0 +1,44 @@ +syntax="proto3"; +package graph.messages; + +import "messages/scalars.proto"; +import "messages/enums_ext.proto"; +import "google/protobuf/duration.proto"; +import "google/protobuf/type.proto"; + +message BeforeMapMsg {} +enum BeforeMapEnum { BME_BEFORE = 0; } + +message Maps { + message NestedBeforeMsg {} + enum NestedBeforeEnum { BME_BEFORE = 0; } + + map scalar = 1; + + map before_msg = 2; + map after_msg = 3; + + map before_enum = 4; + map after_enum = 5; + + map nested_before_msg = 6; + map nested_after_msg = 7; + + map nested_before_enum = 8; + map nested_after_enum = 9; + + map external_in_package_msg = 10; + map external_in_package_enum = 11; + + map external_3rd_party_msg = 12; + map external_3rd_party_enum = 13; + + // this is a message! + message NestedAfterMsg {} + + // this is an enum! 
+ enum NestedAfterEnum { AME_AFTER = 0; } +} + +message AfterMapMsg {} +enum AfterMapEnum { AME_AFTER = 0; } diff --git a/testdata/graph/messages/oneofs.proto b/testdata/graph/messages/oneofs.proto new file mode 100644 index 0000000..919451c --- /dev/null +++ b/testdata/graph/messages/oneofs.proto @@ -0,0 +1,13 @@ +syntax="proto3"; +package graph.messages; + +message OneOfs { + string before = 1; + + oneof oneof { + int32 inside = 2; + } + + bool after = 3; +} + diff --git a/testdata/graph/messages/recursive.proto b/testdata/graph/messages/recursive.proto new file mode 100644 index 0000000..0c84e73 --- /dev/null +++ b/testdata/graph/messages/recursive.proto @@ -0,0 +1,6 @@ +syntax="proto3"; +package graph.messages; + +message Recursive { + Recursive recurse = 1; +} diff --git a/testdata/graph/messages/repeated.proto b/testdata/graph/messages/repeated.proto new file mode 100644 index 0000000..50ed7f6 --- /dev/null +++ b/testdata/graph/messages/repeated.proto @@ -0,0 +1,41 @@ +syntax="proto3"; +package graph.messages; + +import "messages/scalars.proto"; +import "messages/enums_ext.proto"; +import "google/protobuf/duration.proto"; +import "google/protobuf/type.proto"; + +message BeforeRepMsg {} +enum BeforeRepEnum { BRE_BEFORE = 0; } + +message Repeated { + message NestedBeforeMsg {} + enum NestedBeforeEnum { BME_BEFORE = 0; } + + repeated string scalar = 1; + + repeated BeforeRepMsg before_msg = 2; + repeated AfterRepMsg after_msg = 3; + + repeated BeforeRepEnum before_enum = 4; + repeated AfterRepEnum after_enum = 5; + + repeated NestedBeforeMsg nested_before_msg = 6; + repeated NestedAfterMsg nested_after_msg = 7; + + repeated NestedBeforeEnum nested_before_enum = 8; + repeated NestedAfterEnum nested_after_enum = 9; + + repeated Scalars external_in_package_msg = 10; + repeated External external_in_package_enum = 11; + + repeated google.protobuf.Duration external_3rd_party_msg = 12; + repeated google.protobuf.Syntax external_3rd_party_enum = 13; + + message 
NestedAfterMsg {} + enum NestedAfterEnum { AME_AFTER = 0; } +} + +message AfterRepMsg {} +enum AfterRepEnum { ARE_AFTER = 0; } diff --git a/testdata/graph/messages/scalars.proto b/testdata/graph/messages/scalars.proto new file mode 100644 index 0000000..d42e744 --- /dev/null +++ b/testdata/graph/messages/scalars.proto @@ -0,0 +1,20 @@ +syntax="proto3"; +package graph.messages; + +message Scalars { + double double = 1; + float float = 2; + int32 int32 = 3; + int64 int64 = 4; + uint32 uint32 = 5; + uint64 uint64 = 6; + sint32 sint32 = 7; + sint64 sint64 = 8; + fixed32 fixed32 = 9; + fixed64 fixed64 = 10; + sfixed32 sfixed32 = 11; + sfixed64 sfixed64 = 12; + bool bool = 13; + string string = 14; + bytes bytes = 15; +} diff --git a/testdata/graph/nested/nested.proto b/testdata/graph/nested/nested.proto new file mode 100644 index 0000000..a961891 --- /dev/null +++ b/testdata/graph/nested/nested.proto @@ -0,0 +1,27 @@ +syntax="proto3"; +package graph.nested; + +message Foo { + Bar x = 1; // usage before declaration + + // nested message + message Bar { + Baz a = 1; // usage before declaration + + // doubly nested enum + enum Baz {VALUE = 0;} + Baz b = 2; // usage after declaration + + // doubly nested message + message Quux {} + Quux c = 3; + } + + Bar y = 2; // usage after declaration + + // same name, different scope + enum Baz {VALUE = 0;} + Baz shallow = 3; + + Bar.Baz deep = 4; // usage of deeply nested child enum +} diff --git a/testdata/graph/services/services.proto b/testdata/graph/services/services.proto new file mode 100644 index 0000000..b8d3f0c --- /dev/null +++ b/testdata/graph/services/services.proto @@ -0,0 +1,27 @@ +syntax="proto3"; +package graph.services; + +message BeforeRequest {} +message BeforeResponse { + int32 foo = 99; // comment +} + +service Empty {} + +// unary only methods +service Unary { + // message come before + rpc UnaryBefore(BeforeRequest) returns (BeforeResponse); + + // messages come after + rpc UnaryAfter(AfterRequest) returns 
(AfterResponse); +} + +service Streaming { + rpc ClientStream(stream BeforeRequest) returns (BeforeResponse); + rpc ServerStream(AfterRequest) returns (stream AfterResponse); + rpc BiDiStream(stream BeforeRequest) returns (stream AfterResponse); +} + +message AfterRequest {} +message AfterResponse {} diff --git a/testdata/protoc-gen-example/jsonify.go b/testdata/protoc-gen-example/jsonify.go new file mode 100644 index 0000000..243aeb8 --- /dev/null +++ b/testdata/protoc-gen-example/jsonify.go @@ -0,0 +1,114 @@ +package main + +import ( + "text/template" + + "github.com/lyft/protoc-gen-star/lang/go" + + "github.com/lyft/protoc-gen-star" +) + +// JSONifyPlugin adds encoding/json Marshaler and Unmarshaler methods on PB +// messages that utilizes the more correct jsonpb package. +// See: https://godoc.org/github.com/golang/protobuf/jsonpb +type JSONifyModule struct { + *pgs.ModuleBase + ctx pgsgo.Context + tpl *template.Template +} + +// JSONify returns an initialized JSONifyPlugin +func JSONify() *JSONifyModule { return &JSONifyModule{ModuleBase: &pgs.ModuleBase{}} } + +func (p *JSONifyModule) InitContext(c pgs.BuildContext) { + p.ModuleBase.InitContext(c) + p.ctx = pgsgo.InitContext(c.Parameters()) + + tpl := template.New("jsonify").Funcs(map[string]interface{}{ + "package": p.ctx.PackageName, + "name": p.ctx.Name, + "marshaler": p.marshaler, + "unmarshaler": p.unmarshaler, + }) + + p.tpl = template.Must(tpl.Parse(jsonifyTpl)) +} + +// Name satisfies the generator.Plugin interface. 
+func (p *JSONifyModule) Name() string { return "jsonify" } + +func (p *JSONifyModule) Execute(targets map[string]pgs.File, pkgs map[string]pgs.Package) []pgs.Artifact { + + for _, t := range targets { + p.generate(t) + } + + return p.Artifacts() +} + +func (p *JSONifyModule) generate(f pgs.File) { + if len(f.Messages()) == 0 { + return + } + + name := p.ctx.OutputPath(f).SetExt(".json.go") + p.AddGeneratorTemplateFile(name.String(), p.tpl, f) +} + +func (p *JSONifyModule) marshaler(m pgs.Message) pgs.Name { + return p.ctx.Name(m) + "JSONMarshaler" +} + +func (p *JSONifyModule) unmarshaler(m pgs.Message) pgs.Name { + return p.ctx.Name(m) + "JSONUnmarshaler" +} + +const jsonifyTpl = `package {{ package . }} + +import ( + "bytes" + "encoding/json" + + "github.com/golang/protobuf/jsonpb" +) + +{{ range .AllMessages }} + +// {{ marshaler . }} describes the default jsonpb.Marshaler used by all +// instances of {{ name . }}. This struct is safe to replace or modify but +// should not be done so concurrently. +var {{ marshaler . }} = new(jsonpb.Marshaler) + +// MarshalJSON satisfies the encoding/json Marshaler interface. This method +// uses the more correct jsonpb package to correctly marshal the message. +func (m *{{ name . }}) MarshalJSON() ([]byte, error) { + if m == nil { + return json.Marshal(nil) + } + + + buf := &bytes.Buffer{} + if err := {{ marshaler . }}.Marshal(buf, m); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +var _ json.Marshaler = (*{{ name . }})(nil) + +// {{ unmarshaler . }} describes the default jsonpb.Unmarshaler used by all +// instances of {{ name . }}. This struct is safe to replace or modify but +// should not be done so concurrently. +var {{ unmarshaler . }} = new(jsonpb.Unmarshaler) + +// UnmarshalJSON satisfies the encoding/json Unmarshaler interface. This method +// uses the more correct jsonpb package to correctly unmarshal the message. +func (m *{{ name . }}) UnmarshalJSON(b []byte) error { + return {{ unmarshaler . 
}}.Unmarshal(bytes.NewReader(b), m) +} + +var _ json.Unmarshaler = (*{{ name . }})(nil) + +{{ end }} +` diff --git a/testdata/protoc-gen-example/jsonify_plugin.go b/testdata/protoc-gen-example/jsonify_plugin.go deleted file mode 100644 index b1fbd6c..0000000 --- a/testdata/protoc-gen-example/jsonify_plugin.go +++ /dev/null @@ -1,124 +0,0 @@ -package main - -import ( - "text/template" - - "fmt" - - "github.com/golang/protobuf/protoc-gen-go/descriptor" - "github.com/golang/protobuf/protoc-gen-go/generator" - "github.com/lyft/protoc-gen-star" -) - -var ( - marshalJSONTpl = template.Must(template.New("marshalJSON").Parse(marshalJSON)) - unmarshalJSONTpl = template.Must(template.New("unmarshalJSON").Parse(unmarshalJSON)) -) - -// JSONifyPlugin adds encoding/json Marshaler and Unmarshaler methods on PB -// messages that utilizes the more correct jsonpb package. -// See: https://godoc.org/github.com/golang/protobuf/jsonpb -type JSONifyPlugin struct { - *pgs.PluginBase -} - -// JSONify returns an initialized JSONifyPlugin -func JSONify() *JSONifyPlugin { return &JSONifyPlugin{&pgs.PluginBase{}} } - -// Name satisfies the generator.Plugin interface. -func (p *JSONifyPlugin) Name() string { return "jsonify" } - -// Generate satisfies the generator.Plugin interface. 
-func (p *JSONifyPlugin) Generate(file *generator.FileDescriptor) { - if !p.BuildTarget(file.GetName()) || len(file.GetMessageType()) == 0 { - return - } - - p.Push(file.GetName()) - defer p.Pop() - - jpb := p.AddImport("jsonpb", "github.com/golang/protobuf/jsonpb", nil) - js := p.AddImport("json", "encoding/json", nil) - bytes := p.AddImport("bytes", "bytes", nil) - - for _, m := range file.GetMessageType() { - p.generateMessage(msgData{ - DescriptorProto: m, - Bytes: bytes, - JSON: js, - JSONPB: jpb, - }) - } -} - -func (p *JSONifyPlugin) generateMessage(m msgData) { - if m.GetOptions().GetMapEntry() { - return - } - - p.Push(m.GetName()).Debug("implementing json.Marshaler/Unmarshaler interface") - defer p.Pop() - - p.C80(m.Name(), "Marshaler describes the default jsonpb.Marshaler used by all instances of ", - m.Name(), ". This struct is safe to replace or modify but should not be done so concurrently.") - p.P("var ", m.Name(), "Marshaler = new(", m.JSONPB, ".Marshaler)") - - p.C80("MarshalJSON satisfies the encoding/json Marshaler interface. This method uses the more correct jsonpb package to correctly marshal the message.") - p.T(marshalJSONTpl, m) - - p.C80(m.Name(), "Unmarshaler describes the default jsonpb.Unmarshaler used by all instances of ", - m.Name(), ". This struct is safe to replace or modify but should not be done so concurrently.") - p.P("var ", m.Name(), "Unmarshaler = new(", m.JSONPB, ".Unmarshaler)") - - p.C80("UnmarshalJSON satisfies the encoding/json Unmarshaler interface. 
This method uses the more correct jsonpb package to correctly unmarshal the message.") - p.T(unmarshalJSONTpl, m) - - for _, nm := range m.GetNestedType() { - p.generateMessage(msgData{ - DescriptorProto: nm, - Parent: m.Name(), - Bytes: m.Bytes, - JSON: m.JSON, - JSONPB: m.JSONPB}) - } -} - -type msgData struct { - *descriptor.DescriptorProto - Parent string - Bytes, JSON, JSONPB string -} - -func (d msgData) Name() string { - if d.Parent == "" { - return d.GetName() - } - - return fmt.Sprintf("%s_%s", d.Parent, d.GetName()) -} - -const marshalJSON = `func (m *{{ .Name }}) MarshalJSON() ([]byte, error) { - if m == nil { - return {{ .JSON }}.Marshal(nil) - } - - - buf := &{{ .Bytes }}.Buffer{} - if err := {{ .Name }}Marshaler.Marshal(buf, m); err != nil { - return nil, err - } - - return buf.Bytes(), nil -} - -var _ {{ .JSON }}.Marshaler = (*{{ .Name }})(nil) -` - -const unmarshalJSON = `func (m *{{ .Name }}) UnmarshalJSON(b []byte) error { - return {{ .Name }}Unmarshaler.Unmarshal({{ .Bytes }}.NewReader(b), m) -} - -var _ {{ .JSON }}.Unmarshaler = (*{{ .Name }})(nil) -` - -var _ pgs.Plugin = (*JSONifyPlugin)(nil) diff --git a/testdata/protoc-gen-example/main.go b/testdata/protoc-gen-example/main.go index 98a49b3..ca0d487 100644 --- a/testdata/protoc-gen-example/main.go +++ b/testdata/protoc-gen-example/main.go @@ -1,11 +1,17 @@ package main -import "github.com/lyft/protoc-gen-star" +import ( + "github.com/lyft/protoc-gen-star" + "github.com/lyft/protoc-gen-star/lang/go" +) func main() { - pgs.Init(pgs.IncludeGo(), pgs.DebugEnv("DEBUG"), pgs.MultiPackage()). - RegisterPlugin(JSONify()). - RegisterModule(ASTPrinter()). - RegisterPostProcessor(pgs.GoFmt()). 
- Render() + pgs.Init( + pgs.DebugEnv("DEBUG"), + ).RegisterModule( + ASTPrinter(), + JSONify(), + ).RegisterPostProcessor( + pgsgo.GoFmt(), + ).Render() } diff --git a/testdata/protoc-gen-example/printer_module.go b/testdata/protoc-gen-example/printer.go similarity index 81% rename from testdata/protoc-gen-example/printer_module.go rename to testdata/protoc-gen-example/printer.go index 89df4b8..1d04367 100644 --- a/testdata/protoc-gen-example/printer_module.go +++ b/testdata/protoc-gen-example/printer.go @@ -1,11 +1,12 @@ package main import ( - "bytes" "fmt" "io" "strings" + "bytes" + "github.com/lyft/protoc-gen-star" ) @@ -17,25 +18,34 @@ func ASTPrinter() *PrinterModule { return &PrinterModule{ModuleBase: &pgs.Module func (p *PrinterModule) Name() string { return "printer" } -func (p *PrinterModule) Execute(pkg pgs.Package, pkgs map[string]pgs.Package) []pgs.Artifact { - p.PushDir(pkg.Files()[0].OutputPath().Dir().String()) +func (p *PrinterModule) Execute(targets map[string]pgs.File, packages map[string]pgs.Package) []pgs.Artifact { + buf := &bytes.Buffer{} + + for _, f := range targets { + p.printFile(f, buf) + } + + return p.Artifacts() +} + +func (p *PrinterModule) printFile(f pgs.File, buf *bytes.Buffer) { + p.Push(f.Name().String()) defer p.Pop() - p.Debug("printing:", pkg.GoName()) - buf := &bytes.Buffer{} + buf.Reset() v := initPrintVisitor(buf, "") - p.CheckErr(pgs.Walk(v, pkg), "unable to print AST tree") + p.CheckErr(pgs.Walk(v, f), "unable to print AST tree") + + out := buf.String() if ok, _ := p.Parameters().Bool("log_tree"); ok { - p.Logf("Proto Tree:\n%s", buf.String()) + p.Logf("Proto Tree:\n%s", out) } p.AddGeneratorFile( - p.JoinPath(pkg.GoName().LowerSnakeCase().String()+".tree.txt"), - buf.String(), + f.InputPath().SetExt(".tree.txt").String(), + out, ) - - return p.Artifacts() } const ( @@ -76,10 +86,6 @@ func (v PrinterVisitor) writeLeaf(str string) { fmt.Fprintf(v.w, "%s%s%s\n", v.leafPrefix(), leafNodeSpacer, str) } -func (v 
PrinterVisitor) VisitPackage(p pgs.Package) (pgs.Visitor, error) { - return v.writeSubNode("Package: " + p.GoName().String()), nil -} - func (v PrinterVisitor) VisitFile(f pgs.File) (pgs.Visitor, error) { return v.writeSubNode("File: " + f.Name().String()), nil } diff --git a/testdata/protos/kitchen/kitchen.proto b/testdata/protos/kitchen/kitchen.proto index 804c268..5c47d68 100644 --- a/testdata/protos/kitchen/kitchen.proto +++ b/testdata/protos/kitchen/kitchen.proto @@ -1,7 +1,7 @@ syntax = "proto3"; package kitchen; -option go_package = "kitchen"; +option go_package = "github.com/lyft/protoc-gen-star/testdata/generated/kitchen"; import "kitchen/sink.proto"; import "google/protobuf/timestamp.proto"; diff --git a/testdata/protos/kitchen/sink.proto b/testdata/protos/kitchen/sink.proto index 35a91c9..b78a0f1 100644 --- a/testdata/protos/kitchen/sink.proto +++ b/testdata/protos/kitchen/sink.proto @@ -1,7 +1,7 @@ syntax = "proto3"; package kitchen; -option go_package = "kitchen"; +option go_package = "github.com/lyft/protoc-gen-star/testdata/generated/kitchen"; import "google/protobuf/timestamp.proto"; diff --git a/testdata/protos/multipackage/bar/baz/quux.proto b/testdata/protos/multipackage/bar/baz/quux.proto index 0f89ce0..5fd2f7e 100644 --- a/testdata/protos/multipackage/bar/baz/quux.proto +++ b/testdata/protos/multipackage/bar/baz/quux.proto @@ -1,6 +1,7 @@ syntax = "proto3"; package baz; +option go_package = "github.com/lyft/protoc-gen-star/testdata/generated/multipackage/bar/baz"; message Quux { oneof id { diff --git a/testdata/protos/multipackage/bar/buzz.proto b/testdata/protos/multipackage/bar/buzz.proto index 38db1e9..cb2860a 100644 --- a/testdata/protos/multipackage/bar/buzz.proto +++ b/testdata/protos/multipackage/bar/buzz.proto @@ -1,6 +1,7 @@ syntax = "proto3"; package bar; +option go_package = "github.com/lyft/protoc-gen-star/testdata/generated/multipackage/bar"; import "multipackage/bar/baz/quux.proto"; diff --git 
a/testdata/protos/multipackage/foo/fizz.proto b/testdata/protos/multipackage/foo/fizz.proto index 04eea68..1d6a9ae 100644 --- a/testdata/protos/multipackage/foo/fizz.proto +++ b/testdata/protos/multipackage/foo/fizz.proto @@ -1,6 +1,7 @@ syntax = "proto3"; package foo; +option go_package = "github.com/lyft/protoc-gen-star/testdata/generated/multipackage/foo"; import "multipackage/bar/buzz.proto"; diff --git a/wkt.go b/wkt.go new file mode 100644 index 0000000..5d6c86e --- /dev/null +++ b/wkt.go @@ -0,0 +1,73 @@ +package pgs + +// WellKnownTypePackage is the proto package name where all Well Known Types +// currently reside. +const WellKnownTypePackage Name = "google.protobuf" + +// WellKnownType (WKT) encapsulates the Name of a Message from the +// `google.protobuf` package. Most official protoc plugins special case code +// generation on these messages. +type WellKnownType Name + +// 1-to-1 mapping of the WKT names to WellKnownTypes. +const ( + // UnknownWKT indicates that the type is not a known WKT. This value may be + // returned erroneously mapping a Name to a WellKnownType or if a WKT is + // added to the `google.protobuf` package but this library is outdated. 
+ UnknownWKT WellKnownType = "Unknown" + + AnyWKT WellKnownType = "Any" + DurationWKT WellKnownType = "Duration" + EmptyWKT WellKnownType = "Empty" + StructWKT WellKnownType = "Struct" + TimestampWKT WellKnownType = "Timestamp" + ValueWKT WellKnownType = "Value" + ListValueWKT WellKnownType = "ListValue" + DoubleValueWKT WellKnownType = "DoubleValue" + FloatValueWKT WellKnownType = "FloatValue" + Int64ValueWKT WellKnownType = "Int64Value" + UInt64ValueWKT WellKnownType = "UInt64Value" + Int32ValueWKT WellKnownType = "Int32Value" + UInt32ValueWKT WellKnownType = "UInt32Value" + BoolValueWKT WellKnownType = "BoolValue" + StringValueWKT WellKnownType = "StringValue" + BytesValueWKT WellKnownType = "BytesValue" +) + +var wktLookup = map[Name]WellKnownType{ + "Any": AnyWKT, + "Duration": DurationWKT, + "Empty": EmptyWKT, + "Struct": StructWKT, + "Timestamp": TimestampWKT, + "Value": ValueWKT, + "ListValue": ListValueWKT, + "DoubleValue": DoubleValueWKT, + "FloatValue": FloatValueWKT, + "Int64Value": Int64ValueWKT, + "UInt64Value": UInt64ValueWKT, + "Int32Value": Int32ValueWKT, + "UInt32Value": UInt32ValueWKT, + "BoolValue": BoolValueWKT, + "StringValue": StringValueWKT, + "BytesValue": BytesValueWKT, +} + +// LookupWKT returns the WellKnownType related to the provided Name. If the +// name is not recognized, UnknownWKT is returned. +func LookupWKT(n Name) WellKnownType { + if wkt, ok := wktLookup[n]; ok { + return wkt + } + + return UnknownWKT +} + +// Name converts the WellKnownType to a Name. This is a convenience method. +func (wkt WellKnownType) Name() Name { return Name(wkt) } + +// Valid returns true if the WellKnownType is recognized by this library. 
+func (wkt WellKnownType) Valid() bool { + _, ok := wktLookup[wkt.Name()] + return ok +} diff --git a/wkt_test.go b/wkt_test.go new file mode 100644 index 0000000..c67f39a --- /dev/null +++ b/wkt_test.go @@ -0,0 +1,58 @@ +package pgs + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLookupWKT(t *testing.T) { + t.Parallel() + + tests := []struct { + name Name + expected WellKnownType + }{ + {"Any", AnyWKT}, + {"Duration", DurationWKT}, + {"Empty", EmptyWKT}, + {"Foobar", UnknownWKT}, + } + + for _, test := range tests { + tc := test + t.Run(tc.name.String(), func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, LookupWKT(tc.name)) + }) + } +} + +func TestWellKnownType_Name(t *testing.T) { + t.Parallel() + + wkt := WellKnownType("Foobar") + assert.Equal(t, Name("Foobar"), wkt.Name()) +} + +func TestWellKnownType_Valid(t *testing.T) { + t.Parallel() + + tests := []struct { + wkt WellKnownType + expected bool + }{ + {AnyWKT, true}, + {Int64ValueWKT, true}, + {UnknownWKT, false}, + {WellKnownType("Foobar"), false}, + } + + for _, test := range tests { + tc := test + t.Run(tc.wkt.Name().String(), func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, tc.wkt.Valid()) + }) + } +} diff --git a/workflow.go b/workflow.go index 2662f87..0e7d4a1 100644 --- a/workflow.go +++ b/workflow.go @@ -2,7 +2,6 @@ package pgs import ( "io/ioutil" - "sync" "github.com/golang/protobuf/proto" @@ -10,62 +9,39 @@ import ( ) type workflow interface { - Init(g *Generator) - Go() - Star() - Persist() + Init(*Generator) AST + Run(AST) []Artifact + Persist([]Artifact) } -// standardWorkflow uses a close-to-official execution pattern for PGGo. PG* -// modules are executed once the PGGo execution has completed and files are -// persisted the Generator's Persister instance (typically sending back to -// protoc and creating files on disk). 
-type standardWorkflow struct { - *Generator - arts []Artifact -} +// standardWorkflow describes a typical protoc-plugin flow, with the only +// exception being the behavior of the persistor directly writing custom file +// artifacts to disk (instead of via the plugin's output to protoc). +type standardWorkflow struct{ *Generator } -func (wf *standardWorkflow) Init(g *Generator) { - *wf = standardWorkflow{} +func (wf *standardWorkflow) Init(g *Generator) AST { wf.Generator = g wf.Debug("reading input") - data, err := ioutil.ReadAll(wf.in) + data, err := ioutil.ReadAll(g.in) wf.CheckErr(err, "reading input") wf.Debug("parsing input proto") - err = proto.Unmarshal(data, wf.pgg.request()) + req := new(plugin_go.CodeGeneratorRequest) + err = proto.Unmarshal(data, req) wf.CheckErr(err, "parsing input proto") - wf.Assert(len(wf.pgg.request().FileToGenerate) > 0, "no files to generate") + wf.Assert(len(req.FileToGenerate) > 0, "no files to generate") wf.Debug("parsing command-line params") - wf.params = ParseParameters(wf.pgg.request().GetParameter()) + wf.params = ParseParameters(req.GetParameter()) for _, pm := range wf.paramMutators { pm(wf.params) } -} - -func (wf *standardWorkflow) Go() { - wf.RegisterPlugin(wf.gatherer) - wf.params.AddPlugin(wf.gatherer.Name()) - - wf.Debug("initializing plugins") - for _, p := range wf.plugins { - p.InitContext(Context( - wf.Debugger.Push(p.Name()), - wf.params, - ".", - )) - } - wf.Debug("preparing official generator") - wf.pgg.prepare(wf.params) - - wf.Debug("generating official PGG PBs and gathering PG* AST") - wf.pgg.generate() + return ProcessDescriptors(g, req) } -func (wf *standardWorkflow) Star() { +func (wf *standardWorkflow) Run(ast AST) (arts []Artifact) { ctx := Context(wf.Debugger, wf.params, wf.params.OutputPath()) wf.Debug("initializing modules") @@ -75,20 +51,16 @@ func (wf *standardWorkflow) Star() { wf.Debug("executing modules") for _, m := range wf.mods { - if mm, ok := m.(MultiModule); ok { - wf.arts = 
append(wf.arts, mm.MultiExecute(wf.gatherer.targets, wf.gatherer.pkgs)...) - } else { - for _, pkg := range wf.gatherer.targets { - wf.arts = append(wf.arts, m.Execute(pkg, wf.gatherer.pkgs)...) - } - } + arts = append(arts, m.Execute(ast.Targets(), ast.Packages())...) } + + return } -func (wf *standardWorkflow) Persist() { - wf.persister.Persist(wf.arts...) +func (wf *standardWorkflow) Persist(arts []Artifact) { + resp := wf.persister.Persist(arts...) - data, err := proto.Marshal(wf.pgg.response()) + data, err := proto.Marshal(resp) wf.CheckErr(err, "marshaling output proto") n, err := wf.out.Write(data) @@ -98,59 +70,33 @@ func (wf *standardWorkflow) Persist() { wf.Debug("rendering successful") } -// onceWorkflow wraps an existing workflow, executing its methods only once. -// This is required to keep the Generator AST & Render methods idempotent. +// onceWorkflow wraps an existing workflow, executing its methods exactly +// once. Subsequent calls will ignore their inputs and use the previously +// provided values. type onceWorkflow struct { workflow - initOnce sync.Once - goOnce sync.Once - starOnce sync.Once - persistOnce sync.Once -} -func (wf *onceWorkflow) Init(g *Generator) { wf.initOnce.Do(func() { wf.workflow.Init(g) }) } -func (wf *onceWorkflow) Go() { wf.goOnce.Do(wf.workflow.Go) } -func (wf *onceWorkflow) Star() { wf.starOnce.Do(wf.workflow.Star) } -func (wf *onceWorkflow) Persist() { wf.persistOnce.Do(wf.workflow.Persist) } + initOnce sync.Once + ast AST -// excludeGoWorkflow wraps an existing workflow, stripping any PGGo generated -// files from the response. This workflow is used when the IncludeGo InitOption -// is not applied to the Generator. 
-type excludeGoWorkflow struct { - *Generator - workflow -} + runOnce sync.Once + arts []Artifact -func (wf *excludeGoWorkflow) Init(g *Generator) { - wf.Generator = g - wf.workflow.Init(g) + persistOnce sync.Once } -func (wf *excludeGoWorkflow) Go() { - wf.workflow.Go() - - scrubbed := make( - []*plugin_go.CodeGeneratorResponse_File, - 0, len(wf.pgg.response().File)) - - toScrub := make(map[string]struct{}, len(wf.pgg.response().File)) - el := struct{}{} - - for _, pkg := range wf.gatherer.targets { - for _, f := range pkg.Files() { - if f.BuildTarget() { - toScrub[f.OutputPath().String()] = el - } - } - } +func (wf *onceWorkflow) Init(g *Generator) AST { + wf.initOnce.Do(func() { wf.ast = wf.workflow.Init(g) }) + return wf.ast +} - for _, f := range wf.pgg.response().File { - if _, scrub := toScrub[f.GetName()]; !scrub { - scrubbed = append(scrubbed, f) - } else { - wf.Debug("excluding official Go PB:", f.GetName()) - } - } +func (wf *onceWorkflow) Run(ast AST) []Artifact { + wf.runOnce.Do(func() { + wf.arts = wf.workflow.Run(ast) + }) + return wf.arts +} - wf.pgg.response().File = scrubbed +func (wf *onceWorkflow) Persist(artifacts []Artifact) { + wf.persistOnce.Do(func() { wf.workflow.Persist(artifacts) }) } diff --git a/workflow_multipackage.go b/workflow_multipackage.go deleted file mode 100644 index 80ec5d0..0000000 --- a/workflow_multipackage.go +++ /dev/null @@ -1,306 +0,0 @@ -package pgs - -import ( - "bufio" - "context" - "io" - "io/ioutil" - "os" - "os/exec" - "path/filepath" - "sort" - "sync" - - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/protoc-gen-go/descriptor" - protoc "github.com/golang/protobuf/protoc-gen-go/plugin" - "golang.org/x/sync/errgroup" -) - -const multiPackageSubProcessParam = "pgs_multipkg" - -type multiPackageWorkflow struct { - *Generator - workflow - - stdout io.Writer - idxLookup map[string]int - - spoofFanout *protoc.CodeGeneratorResponse -} - -func (wf *multiPackageWorkflow) Init(g *Generator) { - 
wf.Generator = g - wf.stdout = os.Stdout - wf.workflow.Init(g) -} - -func (wf *multiPackageWorkflow) Go() { - wf.Debug("evaluating multi-package mode") - - if set, _ := wf.params.Bool(multiPackageSubProcessParam); set { - wf.Debug("multi-package sub-process") - wf.subGo() - return - } - - subReqs := wf.splitRequest() - if len(subReqs) <= 1 { - wf.Debug("single package run") - wf.workflow.Go() - return - } - - wf.push("multi-package mode") - defer wf.pop() - wf.Debug("multiple packages detected") - - res := wf.fanoutSubReqs(subReqs) - origReq := wf.pgg.request() - - wf.pgg.setRequest(&protoc.CodeGeneratorRequest{ - FileToGenerate: subReqs[0].FileToGenerate, - ProtoFile: wf.pgg.request().ProtoFile, - }) - - wf.RegisterPlugin(wf.gatherer) - wf.gatherer.InitContext(Context( - wf.Debugger.Push(wf.gatherer.Name()), - wf.params, - ".", - )) - - params := ParseParameters(wf.params.String()) - params.SetStr(pluginsKey, wf.gatherer.Name()) - - wf.pgg.prepare(params) - wf.pgg.setRequest(origReq) - - wf.pgg.generate() - wf.pgg.setResponse(res) -} - -func (wf *multiPackageWorkflow) subGo() { - wf.workflow.Go() - - data, err := proto.Marshal(wf.pgg.response()) - wf.CheckErr(err, "marshaling output proto") - - n, err := wf.stdout.Write(data) - wf.CheckErr(err, "failed to write output") - wf.Assert(n == len(data), "failed to write all output") - - wf.Debug("sub-process execution successful, forwarding back to main process") - wf.Exit(0) -} - -// splitRequest identifies sub-requests in the original PGG Request by -// individual directories. Since PGG expects only single-package requests, this -// identifies how many independent runs of PGG would be required. 
-func (wf *multiPackageWorkflow) splitRequest() (subReqs []*protoc.CodeGeneratorRequest) { - wf.idxLookup = make(map[string]int, len(wf.pgg.request().ProtoFile)) - for i, f := range wf.pgg.request().ProtoFile { - wf.idxLookup[f.GetName()] = i - } - - params := ParseParameters(wf.params.String()) - params.SetBool(multiPackageSubProcessParam, true) - - fSets := wf.splitFileSets() - subReqs = make([]*protoc.CodeGeneratorRequest, len(fSets)) - for i, fs := range fSets { - subReqs[i] = &protoc.CodeGeneratorRequest{ - FileToGenerate: fs, - ProtoFile: wf.filterDeps(fs), - Parameter: proto.String(params.String()), - } - } - - return -} - -// splitFileSets segments the FileToGenerate on the original PGG Request by -// directory, maintaining the order of execution. -func (wf *multiPackageWorkflow) splitFileSets() (out [][]string) { - lu := map[string]int{} - - for _, f := range wf.pgg.request().FileToGenerate { - dir := filepath.Dir(f) - - if i, ok := lu[dir]; ok { - out[i] = append(out[i], f) - continue - } - - out = append(out, []string{f}) - lu[dir] = len(out) - 1 - } - - return -} - -// filterDeps resolves the dependencies of just the files listed in fs from the -// ProtoFile slice on the original request, maintaining the order from the -// original. -func (wf *multiPackageWorkflow) filterDeps(fs []string) []*descriptor.FileDescriptorProto { - var idxs []int - - for _, f := range fs { - idxs = append(idxs, wf.resolveIndexes(f)...) - } - - return wf.resolveProtos(idxs) -} - -// resolveIndexes identifies the indexes of ProtoFile elements that are -// dependencies of the file f. -func (wf *multiPackageWorkflow) resolveIndexes(f string) []int { - idx := wf.idxLookup[f] - pb := wf.pgg.request().ProtoFile[idx] - - out := []int{idx} - for _, d := range pb.Dependency { - out = append(out, wf.resolveIndexes(d)...) - } - - return out -} - -// resolveProtos converts ProtoFile indexes into a subset of the ProtoFile. 
-// files are included in the output in the same order they appear in the -// original slice and duplicates are automatically removed. -func (wf *multiPackageWorkflow) resolveProtos(idxs []int) (out []*descriptor.FileDescriptorProto) { - sort.Ints(idxs) - last := -1 - - for _, i := range idxs { - if last == i { - continue - } - - out = append(out, wf.pgg.request().ProtoFile[i]) - last = i - } - - return -} - -// fanoutSubReqs spawns sub processes to individually execute each sub request -// provided. The resulting response is merged together if all requests are -// successful. -func (wf *multiPackageWorkflow) fanoutSubReqs(subReqs []*protoc.CodeGeneratorRequest) *protoc.CodeGeneratorResponse { - if wf.spoofFanout != nil { - return wf.spoofFanout - } - - grp, ctx := errgroup.WithContext(context.Background()) - procs := wf.prepareProcesses(ctx, len(subReqs)) - return wf.handleProcesses(grp, procs, subReqs) -} - -// prepareProcesses sets up n SubProcess instances for use in the workflow -func (wf *multiPackageWorkflow) prepareProcesses(ctx context.Context, n int) []subProcess { - procs := make([]subProcess, n) - for i := 0; i < n; i++ { - procs[i] = exec.CommandContext(ctx, os.Args[0]) - } - return procs -} - -// handleProcesses multiplexes each of the sub-requests onto the provided procs -// and merges their responses into a single Response. -func (wf *multiPackageWorkflow) handleProcesses( - grp *errgroup.Group, - procs []subProcess, - subReqs []*protoc.CodeGeneratorRequest, -) *protoc.CodeGeneratorResponse { - outs := make([]*protoc.CodeGeneratorResponse, len(procs)) - - for i, proc := range procs { - p := proc - req := subReqs[i] - out := new(protoc.CodeGeneratorResponse) - outs[i] = out - - grp.Go(func() error { return wf.handleProcess(p, req, out) }) - } - - wf.CheckErr(grp.Wait(), "execution of sub-processes failed") - - res := new(protoc.CodeGeneratorResponse) - for _, out := range outs { - res.File = append(res.File, out.File...) 
- } - - return res -} - -// handleProcess handles a single SubProcess execution -func (wf *multiPackageWorkflow) handleProcess( - proc subProcess, - req *protoc.CodeGeneratorRequest, - res *protoc.CodeGeneratorResponse, -) error { - stdin, err := proc.StdinPipe() - if err != nil { - return err - } - - stdout, err := proc.StdoutPipe() - if err != nil { - return err - } - - stderr, err := proc.StderrPipe() - if err != nil { - return err - } - - wg := &sync.WaitGroup{} - wg.Add(3) - - var b []byte - - go func() { - in, _ := proto.Marshal(req) - stdin.Write(in) - stdin.Close() - wg.Done() - }() - - go func() { - b, _ = ioutil.ReadAll(stdout) - wg.Done() - }() - - go func() { - sc := bufio.NewScanner(stderr) - l := wf.Push(filepath.Dir(req.FileToGenerate[0])) - for sc.Scan() { - l.Log(sc.Text()) - } - wg.Done() - }() - - if err = proc.Start(); err != nil { - return err - } - - wg.Wait() - - if err = proc.Wait(); err != nil { - return err - } - - return proto.Unmarshal(b, res) -} - -// subProcess is the interface used by Multi-Package workflow -type subProcess interface { - Start() error - Wait() error - - StdinPipe() (io.WriteCloser, error) - StdoutPipe() (io.ReadCloser, error) - StderrPipe() (io.ReadCloser, error) -} diff --git a/workflow_multipackage_test.go b/workflow_multipackage_test.go deleted file mode 100644 index c14b333..0000000 --- a/workflow_multipackage_test.go +++ /dev/null @@ -1,319 +0,0 @@ -package pgs - -import ( - "bytes" - "errors" - "io" - "io/ioutil" - "testing" - - "context" - - "os" - - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/protoc-gen-go/descriptor" - protoc "github.com/golang/protobuf/protoc-gen-go/plugin" - "github.com/stretchr/testify/assert" - "golang.org/x/sync/errgroup" -) - -func multiPackageReq() *protoc.CodeGeneratorRequest { - return &protoc.CodeGeneratorRequest{ - FileToGenerate: []string{ - "foo", - "bar", - "bar/quux", - "bar/baz", - "bar/fizz", - "bar/fizz/buzz", - }, - ProtoFile: 
[]*descriptor.FileDescriptorProto{ - {Name: proto.String("bar/fizz/buzz")}, - {Name: proto.String("bar/fizz"), Dependency: []string{"bar/fizz/buzz"}}, - {Name: proto.String("bar/baz"), Dependency: []string{"bar/fizz"}}, - {Name: proto.String("bar/quux")}, - {Name: proto.String("bar"), Dependency: []string{"bar/baz"}}, - {Name: proto.String("foo"), Dependency: []string{"bar"}}, - }, - } -} - -func TestMultiPackageWorkflow_Init(t *testing.T) { - wf := &multiPackageWorkflow{workflow: &dummyWorkflow{}} - - g := Init() - wf.Init(g) - - assert.Equal(t, g, wf.Generator) - assert.Equal(t, os.Stdout, wf.stdout) -} - -func TestMultiPackageWorkflow_Go(t *testing.T) { - d := newMockDebugger(t) - g := Init() - g.Debugger = d - g.pgg = mockGeneratorPGG{g.pgg} - - req := multiPackageReq() - res := &protoc.CodeGeneratorResponse{Error: proto.String("foo")} - - g.pgg.setRequest(req) - - dwf := &dummyWorkflow{} - wf := &multiPackageWorkflow{workflow: dwf, spoofFanout: res} - wf.Init(g) - wf.Go() - - assert.Equal(t, req, g.pgg.request()) - assert.Equal(t, res, g.pgg.response()) - assert.False(t, dwf.goed) -} - -func TestMultiPackageWorkflow_Go_SinglePackage(t *testing.T) { - d := newMockDebugger(t) - g := Init() - g.Debugger = d - g.pgg = mockGeneratorPGG{g.pgg} - - req := &protoc.CodeGeneratorRequest{ - FileToGenerate: []string{"foo", "bar"}, - ProtoFile: []*descriptor.FileDescriptorProto{ - {Name: proto.String("bar")}, - {Name: proto.String("foo")}, - }, - } - - g.pgg.setRequest(req) - - dwf := &dummyWorkflow{} - wf := &multiPackageWorkflow{workflow: dwf} - wf.Init(g) - wf.Go() - - assert.True(t, dwf.goed) -} - -func TestMultiPackageWorkflow_Go_SubProcess(t *testing.T) { - d := newMockDebugger(t) - g := Init() - g.Debugger = d - g.pgg = mockGeneratorPGG{g.pgg} - g.params = Parameters{multiPackageSubProcessParam: "true"} - - req := &protoc.CodeGeneratorRequest{ - FileToGenerate: []string{"foo", "bar"}, - ProtoFile: []*descriptor.FileDescriptorProto{ - {Name: proto.String("bar")}, - 
{Name: proto.String("foo")}, - }, - } - - g.pgg.setRequest(req) - - dwf := &dummyWorkflow{} - wf := &multiPackageWorkflow{workflow: dwf} - wf.Init(g) - wf.stdout = &bytes.Buffer{} - wf.Go() - - assert.True(t, dwf.goed) - assert.NoError(t, d.err) - assert.True(t, d.exited) -} - -func TestMultiPackageWorkflow_SplitRequests(t *testing.T) { - g := Init() - g.pgg.setRequest(multiPackageReq()) - wf := &multiPackageWorkflow{Generator: g} - - subreqs := wf.splitRequest() - assert.Len(t, subreqs, 3) - - assert.Len(t, subreqs[0].FileToGenerate, 2) - assert.Equal(t, "foo", subreqs[0].FileToGenerate[0]) - assert.Len(t, subreqs[0].ProtoFile, 5, "all proto files except bar/quux") - - assert.Len(t, subreqs[1].FileToGenerate, 3) - assert.Equal(t, "bar/quux", subreqs[1].FileToGenerate[0]) - assert.Len(t, subreqs[1].ProtoFile, 4, "all files except foo and bar") - - assert.Len(t, subreqs[2].FileToGenerate, 1) - assert.Equal(t, "bar/fizz/buzz", subreqs[2].FileToGenerate[0]) - assert.Len(t, subreqs[2].ProtoFile, 1, "only the file to gen") - - set, err := ParseParameters(subreqs[0].GetParameter()).BoolDefault(multiPackageSubProcessParam, false) - assert.NoError(t, err) - assert.True(t, set) -} - -func TestMultiPackageWorkflow(t *testing.T) { - g := Init() - g.Debugger = newMockDebugger(t) - wf := &multiPackageWorkflow{Generator: g} - - assert.NotPanics(t, func() { - wf.fanoutSubReqs(nil) - }) -} - -func TestMultiPackageWorkflow_PrepareProcesses(t *testing.T) { - wf := &multiPackageWorkflow{} - procs := wf.prepareProcesses(context.Background(), 3) - - assert.Len(t, procs, 3) - for _, p := range procs { - assert.NotNil(t, p) - } -} - -func TestMultiPackageWorkflow_HandleProcesses(t *testing.T) { - g := Init() - d := newMockDebugger(t) - g.Debugger = d - wf := &multiPackageWorkflow{Generator: g} - - reqs := []*protoc.CodeGeneratorRequest{ - {FileToGenerate: []string{"alpha"}}, - {FileToGenerate: []string{"beta"}}, - } - - res0, _ := proto.Marshal(&protoc.CodeGeneratorResponse{File: 
[]*protoc.CodeGeneratorResponse_File{ - {Name: proto.String("foo"), Content: proto.String("bar")}, - }}) - - res1, _ := proto.Marshal(&protoc.CodeGeneratorResponse{File: []*protoc.CodeGeneratorResponse_File{ - {Name: proto.String("fizz"), Content: proto.String("buzz")}, - }}) - - procs := []subProcess{ - &mockSubProcess{out: bytes.NewReader(res0)}, - &mockSubProcess{out: bytes.NewReader(res1)}, - } - - res := wf.handleProcesses(&errgroup.Group{}, procs, reqs) - - if !assert.Nil(t, d.err) { - return - } - - assert.Len(t, res.File, 2) - - assert.Equal(t, "bar", res.File[0].GetContent()) - assert.Equal(t, "buzz", res.File[1].GetContent()) -} - -func TestMultiPackageWorkflow_HandleProcess_Success(t *testing.T) { - g := Init() - g.Debugger = newMockDebugger(t) - wf := &multiPackageWorkflow{Generator: g} - - req := &protoc.CodeGeneratorRequest{FileToGenerate: []string{"foo"}} - res := &protoc.CodeGeneratorResponse{Error: proto.String("bar")} - b, _ := proto.Marshal(res) - - sp := &mockSubProcess{ - out: bytes.NewReader(b), - err: bytes.NewBufferString("some line\n"), - } - - out := new(protoc.CodeGeneratorResponse) - assert.NoError(t, wf.handleProcess(sp, req, out)) - assert.True(t, proto.Equal(res, out)) - - b, _ = proto.Marshal(req) - assert.Equal(t, b, sp.in.Bytes()) -} - -func TestMultiPackageWorkflow_HandleProcess_BrokenIn(t *testing.T) { - g := Init() - g.Debugger = newMockDebugger(t) - wf := &multiPackageWorkflow{Generator: g} - - sp := &mockSubProcess{inErr: errors.New("pipe error")} - req := &protoc.CodeGeneratorRequest{FileToGenerate: []string{"foo"}} - assert.Equal(t, sp.inErr, wf.handleProcess(sp, req, new(protoc.CodeGeneratorResponse))) -} - -func TestMultiPackageWorkflow_HandleProcess_BrokenOut(t *testing.T) { - g := Init() - g.Debugger = newMockDebugger(t) - wf := &multiPackageWorkflow{Generator: g} - - sp := &mockSubProcess{outErr: errors.New("pipe error")} - req := &protoc.CodeGeneratorRequest{FileToGenerate: []string{"foo"}} - assert.Equal(t, sp.outErr, 
wf.handleProcess(sp, req, new(protoc.CodeGeneratorResponse))) -} - -func TestMultiPackageWorkflow_HandleProcess_BrokenErr(t *testing.T) { - g := Init() - g.Debugger = newMockDebugger(t) - wf := &multiPackageWorkflow{Generator: g} - - sp := &mockSubProcess{errErr: errors.New("pipe error")} - req := &protoc.CodeGeneratorRequest{FileToGenerate: []string{"foo"}} - assert.Equal(t, sp.errErr, wf.handleProcess(sp, req, new(protoc.CodeGeneratorResponse))) -} - -func TestMultiPackageWorkflow_HandleProcess_StartErr(t *testing.T) { - g := Init() - g.Debugger = newMockDebugger(t) - wf := &multiPackageWorkflow{Generator: g} - - sp := &mockSubProcess{startErr: errors.New("start error")} - req := &protoc.CodeGeneratorRequest{FileToGenerate: []string{"foo"}} - assert.Equal(t, sp.startErr, wf.handleProcess(sp, req, new(protoc.CodeGeneratorResponse))) -} - -func TestMultiPackageWorkflow_HandleProcess_WaitErr(t *testing.T) { - g := Init() - g.Debugger = newMockDebugger(t) - wf := &multiPackageWorkflow{Generator: g} - - sp := &mockSubProcess{waitErr: errors.New("wait error")} - req := &protoc.CodeGeneratorRequest{FileToGenerate: []string{"foo"}} - assert.Equal(t, sp.waitErr, wf.handleProcess(sp, req, new(protoc.CodeGeneratorResponse))) -} - -func TestMultiPackageWorkflow_HandleProcess_UnmarshalErr(t *testing.T) { - g := Init() - g.Debugger = newMockDebugger(t) - wf := &multiPackageWorkflow{Generator: g} - - sp := &mockSubProcess{out: bytes.NewReader([]byte("not a valid proto"))} - req := &protoc.CodeGeneratorRequest{FileToGenerate: []string{"foo"}} - - assert.Error(t, wf.handleProcess(sp, req, new(protoc.CodeGeneratorResponse))) -} - -type mockSubProcess struct { - startErr, waitErr error - inErr, outErr, errErr error - in bytes.Buffer - out, err io.Reader -} - -func (sp *mockSubProcess) Start() error { return sp.startErr } -func (sp *mockSubProcess) Wait() error { return sp.waitErr } - -func (sp *mockSubProcess) StdinPipe() (io.WriteCloser, error) { - return NopWriteCloser{&sp.in}, 
sp.inErr -} - -func (sp *mockSubProcess) StdoutPipe() (io.ReadCloser, error) { - if sp.out == nil { - sp.out = bytes.NewReader([]byte{}) - } - return ioutil.NopCloser(sp.out), sp.outErr -} - -func (sp *mockSubProcess) StderrPipe() (io.ReadCloser, error) { - if sp.err == nil { - sp.err = bytes.NewReader([]byte{}) - } - return ioutil.NopCloser(sp.err), sp.errErr -} - -type NopWriteCloser struct{ io.Writer } - -func (w NopWriteCloser) Close() error { return nil } diff --git a/workflow_test.go b/workflow_test.go index 63aa6bd..4290ef5 100644 --- a/workflow_test.go +++ b/workflow_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/protoc-gen-go/generator" "github.com/golang/protobuf/protoc-gen-go/plugin" "github.com/stretchr/testify/assert" ) @@ -23,42 +22,23 @@ func TestStandardWorkflow_Init(t *testing.T) { g := Init(ProtocInput(bytes.NewReader(b)), MutateParams(func(p Parameters) { mutated = true })) g.workflow.Init(g) - assert.True(t, proto.Equal(req, g.pgg.request())) assert.True(t, mutated) } -func TestStandardWorkflow_Go(t *testing.T) { - t.Parallel() - - g := Init() - g.workflow = &standardWorkflow{Generator: g} - g.pgg = mockGeneratorPGG{} - g.params = Parameters{"foo": "bar"} - - g.workflow.Go() - assert.Equal(t, g.params, g.gatherer.BuildContext.Parameters()) -} - -func TestStandardWorkflow_Star(t *testing.T) { +func TestStandardWorkflow_Run(t *testing.T) { t.Parallel() g := Init() g.workflow = &standardWorkflow{Generator: g} g.params = Parameters{} - g.gatherer.targets = map[string]Package{"baz": dummyPkg()} m := newMockModule() m.name = "foo" - mm := newMultiMockModule() - mm.name = "bar" - - g.RegisterModule(m, mm) - - g.workflow.Star() + g.RegisterModule(m) + g.workflow.Run(&graph{}) assert.True(t, m.executed) - assert.True(t, mm.multiExecuted) } func TestStandardWorkflow_Persist(t *testing.T) { @@ -68,75 +48,45 @@ func TestStandardWorkflow_Persist(t *testing.T) { g.workflow = 
&standardWorkflow{Generator: g} g.persister = dummyPersister(g.Debugger) - assert.NotPanics(t, g.workflow.Persist) + assert.NotPanics(t, func() { g.workflow.Persist(nil) }) } func TestOnceWorkflow(t *testing.T) { t.Parallel() - d := &dummyWorkflow{} + d := &dummyWorkflow{ + AST: &graph{}, + Artifacts: []Artifact{&CustomFile{}}, + } wf := &onceWorkflow{workflow: d} - wf.Init(nil) - wf.Go() - wf.Star() - wf.Persist() + ast := wf.Init(nil) + arts := wf.Run(ast) + wf.Persist(arts) assert.True(t, d.initted) - assert.True(t, d.goed) - assert.True(t, d.starred) + assert.True(t, d.run) assert.True(t, d.persisted) d = &dummyWorkflow{} wf.workflow = d - wf.Init(nil) - wf.Go() - wf.Star() - wf.Persist() + assert.Equal(t, ast, wf.Init(nil)) + assert.Equal(t, arts, wf.Run(ast)) + wf.Persist(arts) assert.False(t, d.initted) - assert.False(t, d.goed) - assert.False(t, d.starred) + assert.False(t, d.run) assert.False(t, d.persisted) } -func TestExcludeGoWorkflow_Go(t *testing.T) { - t.Parallel() - - g := &Generator{ - Debugger: newMockDebugger(t), - pgg: Wrap(&generator.Generator{Response: &plugin_go.CodeGeneratorResponse{ - File: []*plugin_go.CodeGeneratorResponse_File{ - {Name: proto.String("fizz/buzz.pb.go")}, - {Name: proto.String("foo/bar.pb.go")}, - {Name: proto.String("foo/baz.pb.go")}, - }, - }}), - gatherer: &gatherer{ - targets: map[string]Package{"quux": &pkg{ - files: []File{ - &file{buildTarget: true, outputPath: "foo/bar.pb.go"}, - &file{buildTarget: false, outputPath: "fizz/buzz.pb.go"}, - }, - }}, - }, - } - - wf := &excludeGoWorkflow{Generator: g, workflow: &dummyWorkflow{}} - wf.Go() - - resp := g.pgg.response() - assert.Len(t, resp.File, 2) - assert.Equal(t, "fizz/buzz.pb.go", resp.File[0].GetName()) - assert.Equal(t, "foo/baz.pb.go", resp.File[1].GetName()) -} - type dummyWorkflow struct { - initted, goed, starred, persisted bool + AST AST + Artifacts []Artifact + + initted, run, persisted bool } -func (wf *dummyWorkflow) Init(g *Generator) { wf.initted = true 
} -func (wf *dummyWorkflow) Go() { wf.goed = true } -func (wf *dummyWorkflow) Star() { wf.starred = true } -func (wf *dummyWorkflow) Persist() { wf.persisted = true } +func (wf *dummyWorkflow) Init(g *Generator) AST { wf.initted = true; return wf.AST } +func (wf *dummyWorkflow) Run(ast AST) []Artifact { wf.run = true; return wf.Artifacts } +func (wf *dummyWorkflow) Persist(arts []Artifact) { wf.persisted = true }