From d186061f0cdc717ce813e52d91c29130e2d1a6bc Mon Sep 17 00:00:00 2001 From: Sarthak Gupta Date: Mon, 23 Aug 2021 07:18:29 +0100 Subject: [PATCH] field: fix missing optional support (#95) --- .travis.yml | 5 +- Makefile | 17 +++ field.go | 46 ++++++- field_test.go | 72 +++++++++++ field_type.go | 3 +- field_type_test.go | 21 ++++ go.mod | 1 + init_option.go | 8 ++ lang/go/Makefile | 13 ++ lang/go/testdata/presence/types/params | 0 lang/go/testdata/presence/types/proto3.proto | 67 ++++++++++ lang/go/type_name.go | 2 +- lang/go/type_name_p2_presence_test.go | 121 +++++++++++++++++++ lang/go/type_name_p3_presence_test.go | 90 ++++++++++++++ lang/go/type_name_test.go | 114 +---------------- message.go | 29 +++++ message_test.go | 46 +++++++ oneof.go | 10 ++ oneof_test.go | 14 +++ persister.go | 12 +- proto.go | 8 +- protoc-gen-debug/main.go | 9 +- 22 files changed, 577 insertions(+), 131 deletions(-) create mode 100644 lang/go/testdata/presence/types/params create mode 100644 lang/go/testdata/presence/types/proto3.proto create mode 100644 lang/go/type_name_p2_presence_test.go create mode 100644 lang/go/type_name_p3_presence_test.go diff --git a/.travis.yml b/.travis.yml index 7ab01b8..f05c01d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,8 +4,9 @@ go_import_path: github.com/lyft/protoc-gen-star env: matrix: - - PROTOC_VER="3.5.1" - - PROTOC_VER="3.6.1" + - PROTOC_VER="3.5.0" + - PROTOC_VER="3.6.0" + - PROTOC_VER="3.17.0" before_install: - mkdir -p $GOPATH/bin diff --git a/Makefile b/Makefile index 7a6d4b9..c679944 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,6 @@ # the name of this package PKG := $(shell go list .) +PROTOC_VER := $(shell protoc --version | cut -d' ' -f2) .PHONY: bootstrap bootstrap: testdata # set up the project for development @@ -16,15 +17,27 @@ lint: # lints the package for common code smells .PHONY: quick quick: testdata # runs all tests without the race detector or coverage +ifeq ($(PROTOC_VER), 3.17.0) + go test $(PKGS) --tags=proto3_presence +else go test $(PKGS) +endif .PHONY: tests tests: testdata # runs all tests against the package with race detection and coverage percentage +ifeq ($(PROTOC_VER), 3.17.0) + go test -race -cover ./... --tags=proto3_presence +else go test -race -cover ./... +endif .PHONY: cover cover: testdata # runs all tests against the package, generating a coverage report and opening it in the browser +ifeq ($(PROTOC_VER), 3.17.0) + go test -race -covermode=atomic -coverprofile=cover.out ./... --tags=proto3_presence || true +else go test -race -covermode=atomic -coverprofile=cover.out ./... || true +endif go tool cover -html cover.out -o cover.html open cover.html @@ -76,6 +89,10 @@ testdata-go: protoc-gen-go bin/protoc-gen-debug # generate go-specific testdata testdata-names \ testdata-packages \ testdata-outputs +ifeq ($(PROTOC_VER), 3.17.0) + cd lang/go && $(MAKE) \ + testdata-presence +endif vendor: # install project dependencies which glide || (curl https://glide.sh/get | sh) diff --git a/field.go b/field.go index d1fa9d0..38bc1b4 100644 --- a/field.go +++ b/field.go @@ -17,16 +17,29 @@ type Field interface { Message() Message // InOneOf returns true if the field is in a OneOf of the parent Message. + // This will return true for synthetic oneofs (proto3 field presence) as well. InOneOf() bool - // OneOf returns the OneOf that this field is apart of. Nil is returned if + // InRealOneOf returns true if the field is in a OneOf of the parent Message. + // This will return false for synthetic oneofs, and will only include 'real' oneofs. + // See: https://github.com/protocolbuffers/protobuf/blob/v3.17.0/docs/field_presence.md + InRealOneOf() bool + + // OneOf returns the OneOf that this field is a part of. Nil is returned if // the field is not within a OneOf. OneOf() OneOf // Type returns the FieldType of this Field. Type() FieldType - // Required returns whether or not the field is labeled as required. This + // HasPresence returns true for all fields that have explicit presence as defined by: + // See: https://github.com/protocolbuffers/protobuf/blob/v3.17.0/docs/field_presence.md + HasPresence() bool + + // HasOptionalKeyword returns whether the field is labeled as optional. + HasOptionalKeyword() bool + + // Required returns whether the field is labeled as required. This // will only be true if the syntax is proto2. Required() bool @@ -61,6 +74,35 @@ func (f *field) Type() FieldType { return f.typ } func (f *field) setMessage(m Message) { f.msg = m } func (f *field) setOneOf(o OneOf) { f.oneof = o } +func (f *field) InRealOneOf() bool { + return f.InOneOf() && !f.desc.GetProto3Optional() +} + +func (f *field) HasPresence() bool { + if f.InOneOf() { + return true + } + + if f.Type().IsEmbed() { + return true + } + + if !f.Type().IsRepeated() && !f.Type().IsMap() { + if f.Syntax() == Proto2 { + return true + } + return f.HasOptionalKeyword() + } + return false +} + +func (f *field) HasOptionalKeyword() bool { + if f.Syntax() == Proto3 { + return f.desc.GetProto3Optional() + } + return f.desc.GetLabel() == descriptor.FieldDescriptorProto_LABEL_OPTIONAL +} + func (f *field) Required() bool { return f.Syntax().SupportsRequiredPrefix() && f.desc.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REQUIRED diff --git a/field_test.go b/field_test.go index 6915a34..a243cbc 100644 --- a/field_test.go +++ b/field_test.go @@ -99,6 +99,58 @@ func TestField_OneOf(t *testing.T) { assert.True(t, f.InOneOf()) } +func TestField_InRealOneOf(t *testing.T) { + t.Parallel() + + f := dummyField() + assert.False(t, f.InRealOneOf()) + + f = dummyOneOfField(false) + assert.True(t, f.InRealOneOf()) + + f = dummyOneOfField(true) + assert.False(t, f.InRealOneOf()) +} + +func TestField_HasPresence(t *testing.T) { + t.Parallel() + + f := dummyField() + f.addType(&repT{scalarT: &scalarT{}}) + assert.False(t, f.HasPresence()) + + f.addType(&mapT{repT: &repT{scalarT: &scalarT{}}}) + assert.False(t, f.HasPresence()) + + f.addType(&scalarT{}) + assert.False(t, f.HasPresence()) + + opt := true + f.desc = &descriptor.FieldDescriptorProto{Proto3Optional: &opt} + assert.True(t, f.HasPresence()) +} + +func TestField_HasOptionalKeyword(t *testing.T) { + t.Parallel() + + optLabel := descriptor.FieldDescriptorProto_LABEL_OPTIONAL + + f := &field{msg: &msg{parent: dummyFile()}} + assert.False(t, f.HasOptionalKeyword()) + + f.desc = &descriptor.FieldDescriptorProto{Label: &optLabel} + assert.False(t, f.HasOptionalKeyword()) + + f = dummyField() + assert.False(t, f.HasOptionalKeyword()) + + f = dummyOneOfField(false) + assert.False(t, f.HasOptionalKeyword()) + + f = dummyOneOfField(true) + assert.True(t, f.HasOptionalKeyword()) +} + func TestField_Type(t *testing.T) { t.Parallel() @@ -194,3 +246,23 @@ func dummyField() *field { f.addType(t) return f } + +func dummyOneOfField(synthetic bool) *field { + m := dummyMsg() + o := dummyOneof() + str := descriptor.FieldDescriptorProto_TYPE_STRING + var oIndex int32 + oIndex = 1 + f := &field{desc: &descriptor.FieldDescriptorProto{ + Name: proto.String("field"), + Type: &str, + OneofIndex: &oIndex, + Proto3Optional: &synthetic, + }} + o.addField(f) + m.addField(f) + m.addOneOf(o) + t := &scalarT{} + f.addType(t) + return f +} diff --git a/field_type.go b/field_type.go index d752f02..771341c 100644 --- a/field_type.go +++ b/field_type.go @@ -22,8 +22,7 @@ type FieldType interface { // repeated fields containing embeds will still return false. IsEmbed() bool - // IsOptional returns true if the message's syntax is not Proto2 or - // the field is prefixed as optional. + // IsOptional returns true if the field is prefixed as optional. IsOptional() bool // IsRequired returns true if and only if the field is prefixed as required. diff --git a/field_type_test.go b/field_type_test.go index 6388e8e..66a043d 100644 --- a/field_type_test.go +++ b/field_type_test.go @@ -88,6 +88,27 @@ func TestScalarT_Key(t *testing.T) { func TestScalarT_IsOptional(t *testing.T) { t.Parallel() + s := &scalarT{} + f := dummyOneOfField(true) + f.addType(s) + + assert.True(t, s.IsOptional()) + + fl := dummyFile() + fl.desc.Syntax = nil + f.Message().setParent(fl) + + assert.True(t, s.IsOptional()) + + req := descriptor.FieldDescriptorProto_LABEL_REQUIRED + f.desc.Label = &req + + assert.False(t, s.IsOptional()) +} + +func TestScalarT_IsNotOptional(t *testing.T) { + t.Parallel() + s := &scalarT{} f := dummyField() f.addType(s) diff --git a/go.mod b/go.mod index 3c40def..d68d685 100644 --- a/go.mod +++ b/go.mod @@ -6,4 +6,5 @@ require ( github.com/golang/protobuf v1.5.2 github.com/spf13/afero v1.3.3 github.com/stretchr/testify v1.6.1 + google.golang.org/protobuf v1.26.0 // indirect ) diff --git a/init_option.go b/init_option.go index aba7d86..5c92c2f 100644 --- a/init_option.go +++ b/init_option.go @@ -47,3 +47,11 @@ func FileSystem(fs afero.Fs) InitOption { return func(g *Generator) { g.persiste func BiDirectional() InitOption { return func(g *Generator) { g.workflow = &onceWorkflow{workflow: &standardWorkflow{BiDi: true}} } } + +// SupportedFeatures allows defining protoc features to enable / disable. +// See: https://github.com/protocolbuffers/protobuf/blob/v3.17.0/docs/implementing_proto3_presence.md#signaling-that-your-code-generator-supports-proto3-optional +func SupportedFeatures(feat *uint64) InitOption { + return func(g *Generator) { + g.persister.SetSupportedFeatures(feat) + } +} diff --git a/lang/go/Makefile b/lang/go/Makefile index cf308ad..565dc2c 100644 --- a/lang/go/Makefile +++ b/lang/go/Makefile @@ -38,5 +38,18 @@ testdata-outputs: ../../bin/protoc-gen-debug cd -; \ done +testdata-presence: ../../bin/protoc-gen-debug + cd testdata/presence && \ + set -e; for subdir in `find . -mindepth 1 -maxdepth 1 -type d`; do \ + cd $$subdir; \ + params=`cat params`; \ + protoc -I . -I .. \ + --plugin=protoc-gen-debug=../../../../../bin/protoc-gen-debug \ + --debug_out=".:." \ + --go_out="$$params:." \ + `find . -name "*.proto"`; \ + cd -; \ + done + ../../bin/protoc-gen-debug: cd ../.. && $(MAKE) bin/protoc-gen-debug diff --git a/lang/go/testdata/presence/types/params b/lang/go/testdata/presence/types/params new file mode 100644 index 0000000..e69de29 diff --git a/lang/go/testdata/presence/types/proto3.proto b/lang/go/testdata/presence/types/proto3.proto new file mode 100644 index 0000000..2a565f5 --- /dev/null +++ b/lang/go/testdata/presence/types/proto3.proto @@ -0,0 +1,67 @@ +syntax="proto3"; +package names.types; +option go_package = "example.com/foo/bar"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/type.proto"; + +message Proto3 { + double double = 1; + float float = 2; + int64 int64 = 3; + sfixed64 sfixed64 = 4; + sint64 sint64 = 5; + uint64 uint64 = 6; + fixed64 fixed64 = 7; + int32 int32 = 8; + sfixed32 sfixed32 = 9; + sint32 sint32 = 10; + uint32 uint32 = 11; + fixed32 fixed32 = 12; + bool bool = 13; + string string = 14; + bytes bytes = 15; + + Enum enum = 16; + google.protobuf.Syntax ext_enum = 17; + Message msg = 18; + google.protobuf.Duration ext_msg = 19; + + repeated double repeated_scalar = 20; + repeated Enum repeated_enum = 21; + repeated google.protobuf.Syntax repeated_ext_enum = 22; + repeated Message repeated_msg = 23; + repeated google.protobuf.Duration repeated_ext_msg = 24; + + map map_scalar = 25; + map map_enum = 26; + map map_ext_enum = 27; + map map_msg = 28; + map map_ext_msg = 29; + + enum Enum {VALUE = 0;} + + message Message {} + + message Optional { + optional double double = 1; + optional float float = 2; + optional int64 int64 = 3; + optional sfixed64 sfixed64 = 4; + optional sint64 sint64 = 5; + optional uint64 uint64 = 6; + optional fixed64 fixed64 = 7; + optional int32 int32 = 8; + optional sfixed32 sfixed32 = 9; + optional sint32 sint32 = 10; + optional uint32 uint32 = 11; + optional fixed32 fixed32 = 12; + optional bool bool = 13; + optional string string = 14; + optional bytes bytes = 15; + optional Enum enum = 16; + optional google.protobuf.Syntax ext_enum = 17; + optional Optional msg = 18; + optional google.protobuf.Duration ext_msg = 19; + } +} diff --git a/lang/go/type_name.go b/lang/go/type_name.go index 9f24821..9f3a1e2 100644 --- a/lang/go/type_name.go +++ b/lang/go/type_name.go @@ -25,7 +25,7 @@ func (c context) Type(f pgs.Field) TypeName { t = scalarType(ft.ProtoType()) } - if f.Syntax() == pgs.Proto2 { + if f.HasPresence() { return t.Pointer() } diff --git a/lang/go/type_name_p2_presence_test.go b/lang/go/type_name_p2_presence_test.go new file mode 100644 index 0000000..c09cd7d --- /dev/null +++ b/lang/go/type_name_p2_presence_test.go @@ -0,0 +1,121 @@ +// +build !proto3_presence + +package pgsgo + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + pgs "github.com/lyft/protoc-gen-star" +) + +func TestType(t *testing.T) { + t.Parallel() + + ast := buildGraph(t, "names", "types") + ctx := loadContext(t, "names", "types") + + tests := []struct { + field string + expected TypeName + }{ + // proto2 syntax, optional + {"Proto2.double", "*float64"}, + {"Proto2.float", "*float32"}, + {"Proto2.int64", "*int64"}, + {"Proto2.sfixed64", "*int64"}, + {"Proto2.sint64", "*int64"}, + {"Proto2.uint64", "*uint64"}, + {"Proto2.fixed64", "*uint64"}, + {"Proto2.int32", "*int32"}, + {"Proto2.sfixed32", "*int32"}, + {"Proto2.sint32", "*int32"}, + {"Proto2.uint32", "*uint32"}, + {"Proto2.fixed32", "*uint32"}, + {"Proto2.bool", "*bool"}, + {"Proto2.string", "*string"}, + {"Proto2.bytes", "[]byte"}, + {"Proto2.enum", "*Proto2_Enum"}, + {"Proto2.ext_enum", "*ptype.Syntax"}, + {"Proto2.msg", "*Proto2_Required"}, + {"Proto2.ext_msg", "*duration.Duration"}, + {"Proto2.repeated_scalar", "[]float64"}, + {"Proto2.repeated_enum", "[]Proto2_Enum"}, + {"Proto2.repeated_ext_enum", "[]ptype.Syntax"}, + {"Proto2.repeated_msg", "[]*Proto2_Required"}, + {"Proto2.repeated_ext_msg", "[]*duration.Duration"}, + {"Proto2.map_scalar", "map[string]float32"}, + {"Proto2.map_enum", "map[int32]Proto2_Enum"}, + {"Proto2.map_ext_enum", "map[uint64]ptype.Syntax"}, + {"Proto2.map_msg", "map[uint32]*Proto2_Required"}, + {"Proto2.map_ext_msg", "map[int64]*duration.Duration"}, + + // proto2 syntax, required + {"Proto2.Required.double", "*float64"}, + {"Proto2.Required.float", "*float32"}, + {"Proto2.Required.int64", "*int64"}, + {"Proto2.Required.sfixed64", "*int64"}, + {"Proto2.Required.sint64", "*int64"}, + {"Proto2.Required.uint64", "*uint64"}, + {"Proto2.Required.fixed64", "*uint64"}, + {"Proto2.Required.int32", "*int32"}, + {"Proto2.Required.sfixed32", "*int32"}, + {"Proto2.Required.sint32", "*int32"}, + {"Proto2.Required.uint32", "*uint32"}, + {"Proto2.Required.fixed32", "*uint32"}, + {"Proto2.Required.bool", "*bool"}, + {"Proto2.Required.string", "*string"}, + {"Proto2.Required.bytes", "[]byte"}, + {"Proto2.Required.enum", "*Proto2_Enum"}, + {"Proto2.Required.ext_enum", "*ptype.Syntax"}, + {"Proto2.Required.msg", "*Proto2_Required"}, + {"Proto2.Required.ext_msg", "*duration.Duration"}, + + {"Proto3.double", "float64"}, + {"Proto3.float", "float32"}, + {"Proto3.int64", "int64"}, + {"Proto3.sfixed64", "int64"}, + {"Proto3.sint64", "int64"}, + {"Proto3.uint64", "uint64"}, + {"Proto3.fixed64", "uint64"}, + {"Proto3.int32", "int32"}, + {"Proto3.sfixed32", "int32"}, + {"Proto3.sint32", "int32"}, + {"Proto3.uint32", "uint32"}, + {"Proto3.fixed32", "uint32"}, + {"Proto3.bool", "bool"}, + {"Proto3.string", "string"}, + {"Proto3.bytes", "[]byte"}, + {"Proto3.enum", "Proto3_Enum"}, + {"Proto3.ext_enum", "ptype.Syntax"}, + {"Proto3.msg", "*Proto3_Message"}, + {"Proto3.ext_msg", "*duration.Duration"}, + {"Proto3.repeated_scalar", "[]float64"}, + {"Proto3.repeated_enum", "[]Proto3_Enum"}, + {"Proto3.repeated_ext_enum", "[]ptype.Syntax"}, + {"Proto3.repeated_msg", "[]*Proto3_Message"}, + {"Proto3.repeated_ext_msg", "[]*duration.Duration"}, + {"Proto3.map_scalar", "map[string]float32"}, + {"Proto3.map_enum", "map[int32]Proto3_Enum"}, + {"Proto3.map_ext_enum", "map[uint64]ptype.Syntax"}, + {"Proto3.map_msg", "map[uint32]*Proto3_Message"}, + {"Proto3.map_ext_msg", "map[int64]*duration.Duration"}, + } + + for _, test := range tests { + tc := test + t.Run(tc.field, func(t *testing.T) { + t.Parallel() + + e, ok := ast.Lookup(".names.types." + tc.field) + require.True(t, ok, "could not find field") + + fld, ok := e.(pgs.Field) + require.True(t, ok, "entity is not a field") + + assert.Equal(t, tc.expected, ctx.Type(fld)) + }) + } +} diff --git a/lang/go/type_name_p3_presence_test.go b/lang/go/type_name_p3_presence_test.go new file mode 100644 index 0000000..2312591 --- /dev/null +++ b/lang/go/type_name_p3_presence_test.go @@ -0,0 +1,90 @@ +// +build proto3_presence + +package pgsgo + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + pgs "github.com/lyft/protoc-gen-star" +) + +func TestType(t *testing.T) { + t.Parallel() + + ast := buildGraph(t, "presence", "types") + ctx := loadContext(t, "presence", "types") + + tests := []struct { + field string + expected TypeName + }{ + {"Proto3.double", "float64"}, + {"Proto3.float", "float32"}, + {"Proto3.int64", "int64"}, + {"Proto3.sfixed64", "int64"}, + {"Proto3.sint64", "int64"}, + {"Proto3.uint64", "uint64"}, + {"Proto3.fixed64", "uint64"}, + {"Proto3.int32", "int32"}, + {"Proto3.sfixed32", "int32"}, + {"Proto3.sint32", "int32"}, + {"Proto3.uint32", "uint32"}, + {"Proto3.fixed32", "uint32"}, + {"Proto3.bool", "bool"}, + {"Proto3.string", "string"}, + {"Proto3.bytes", "[]byte"}, + {"Proto3.enum", "Proto3_Enum"}, + {"Proto3.ext_enum", "typepb.Syntax"}, + {"Proto3.msg", "*Proto3_Message"}, + {"Proto3.ext_msg", "*durationpb.Duration"}, + {"Proto3.repeated_scalar", "[]float64"}, + {"Proto3.repeated_enum", "[]Proto3_Enum"}, + {"Proto3.repeated_ext_enum", "[]typepb.Syntax"}, + {"Proto3.repeated_msg", "[]*Proto3_Message"}, + {"Proto3.repeated_ext_msg", "[]*durationpb.Duration"}, + {"Proto3.map_scalar", "map[string]float32"}, + {"Proto3.map_enum", "map[int32]Proto3_Enum"}, + {"Proto3.map_ext_enum", "map[uint64]typepb.Syntax"}, + {"Proto3.map_msg", "map[uint32]*Proto3_Message"}, + {"Proto3.map_ext_msg", "map[int64]*durationpb.Duration"}, + + // proto3 syntax optional + {"Proto3.Optional.double", "*float64"}, + {"Proto3.Optional.float", "*float32"}, + {"Proto3.Optional.int64", "*int64"}, + {"Proto3.Optional.sfixed64", "*int64"}, + {"Proto3.Optional.sint64", "*int64"}, + {"Proto3.Optional.uint64", "*uint64"}, + {"Proto3.Optional.fixed64", "*uint64"}, + {"Proto3.Optional.int32", "*int32"}, + {"Proto3.Optional.sfixed32", "*int32"}, + {"Proto3.Optional.sint32", "*int32"}, + {"Proto3.Optional.uint32", "*uint32"}, + {"Proto3.Optional.fixed32", "*uint32"}, + {"Proto3.Optional.bool", "*bool"}, + {"Proto3.Optional.string", "*string"}, + {"Proto3.Optional.bytes", "[]byte"}, + {"Proto3.Optional.enum", "*Proto3_Enum"}, + {"Proto3.Optional.ext_enum", "*typepb.Syntax"}, + {"Proto3.Optional.msg", "*Proto3_Optional"}, + {"Proto3.Optional.ext_msg", "*durationpb.Duration"}, + } + + for _, test := range tests { + tc := test + t.Run(tc.field, func(t *testing.T) { + t.Parallel() + + e, ok := ast.Lookup(".names.types." + tc.field) + require.True(t, ok, "could not find field") + + fld, ok := e.(pgs.Field) + require.True(t, ok, "entity is not a field") + + assert.Equal(t, tc.expected, ctx.Type(fld)) + }) + } +} diff --git a/lang/go/type_name_test.go b/lang/go/type_name_test.go index 133e8b5..b1282ee 100644 --- a/lang/go/type_name_test.go +++ b/lang/go/type_name_test.go @@ -4,120 +4,10 @@ import ( "fmt" "testing" - pgs "github.com/lyft/protoc-gen-star" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/assert" -) - -func TestType(t *testing.T) { - t.Parallel() - - ast := buildGraph(t, "names", "types") - ctx := loadContext(t, "names", "types") - - tests := []struct { - field string - expected TypeName - }{ - // proto2 syntax, optional - {"Proto2.double", "*float64"}, - {"Proto2.float", "*float32"}, - {"Proto2.int64", "*int64"}, - {"Proto2.sfixed64", "*int64"}, - {"Proto2.sint64", "*int64"}, - {"Proto2.uint64", "*uint64"}, - {"Proto2.fixed64", "*uint64"}, - {"Proto2.int32", "*int32"}, - {"Proto2.sfixed32", "*int32"}, - {"Proto2.sint32", "*int32"}, - {"Proto2.uint32", "*uint32"}, - {"Proto2.fixed32", "*uint32"}, - {"Proto2.bool", "*bool"}, - {"Proto2.string", "*string"}, - {"Proto2.bytes", "[]byte"}, - {"Proto2.enum", "*Proto2_Enum"}, - {"Proto2.ext_enum", "*ptype.Syntax"}, - {"Proto2.msg", "*Proto2_Required"}, - {"Proto2.ext_msg", "*duration.Duration"}, - {"Proto2.repeated_scalar", "[]float64"}, - {"Proto2.repeated_enum", "[]Proto2_Enum"}, - {"Proto2.repeated_ext_enum", "[]ptype.Syntax"}, - {"Proto2.repeated_msg", "[]*Proto2_Required"}, - {"Proto2.repeated_ext_msg", "[]*duration.Duration"}, - {"Proto2.map_scalar", "map[string]float32"}, - {"Proto2.map_enum", "map[int32]Proto2_Enum"}, - {"Proto2.map_ext_enum", "map[uint64]ptype.Syntax"}, - {"Proto2.map_msg", "map[uint32]*Proto2_Required"}, - {"Proto2.map_ext_msg", "map[int64]*duration.Duration"}, - - // proto2 syntax, required - {"Proto2.Required.double", "*float64"}, - {"Proto2.Required.float", "*float32"}, - {"Proto2.Required.int64", "*int64"}, - {"Proto2.Required.sfixed64", "*int64"}, - {"Proto2.Required.sint64", "*int64"}, - {"Proto2.Required.uint64", "*uint64"}, - {"Proto2.Required.fixed64", "*uint64"}, - {"Proto2.Required.int32", "*int32"}, - {"Proto2.Required.sfixed32", "*int32"}, - {"Proto2.Required.sint32", "*int32"}, - {"Proto2.Required.uint32", "*uint32"}, - {"Proto2.Required.fixed32", "*uint32"}, - {"Proto2.Required.bool", "*bool"}, - {"Proto2.Required.string", "*string"}, - {"Proto2.Required.bytes", "[]byte"}, - {"Proto2.Required.enum", "*Proto2_Enum"}, - {"Proto2.Required.ext_enum", "*ptype.Syntax"}, - {"Proto2.Required.msg", "*Proto2_Required"}, - {"Proto2.Required.ext_msg", "*duration.Duration"}, - - {"Proto3.double", "float64"}, - {"Proto3.float", "float32"}, - {"Proto3.int64", "int64"}, - {"Proto3.sfixed64", "int64"}, - {"Proto3.sint64", "int64"}, - {"Proto3.uint64", "uint64"}, - {"Proto3.fixed64", "uint64"}, - {"Proto3.int32", "int32"}, - {"Proto3.sfixed32", "int32"}, - {"Proto3.sint32", "int32"}, - {"Proto3.uint32", "uint32"}, - {"Proto3.fixed32", "uint32"}, - {"Proto3.bool", "bool"}, - {"Proto3.string", "string"}, - {"Proto3.bytes", "[]byte"}, - {"Proto3.enum", "Proto3_Enum"}, - {"Proto3.ext_enum", "ptype.Syntax"}, - {"Proto3.msg", "*Proto3_Message"}, - {"Proto3.ext_msg", "*duration.Duration"}, - {"Proto3.repeated_scalar", "[]float64"}, - {"Proto3.repeated_enum", "[]Proto3_Enum"}, - {"Proto3.repeated_ext_enum", "[]ptype.Syntax"}, - {"Proto3.repeated_msg", "[]*Proto3_Message"}, - {"Proto3.repeated_ext_msg", "[]*duration.Duration"}, - {"Proto3.map_scalar", "map[string]float32"}, - {"Proto3.map_enum", "map[int32]Proto3_Enum"}, - {"Proto3.map_ext_enum", "map[uint64]ptype.Syntax"}, - {"Proto3.map_msg", "map[uint32]*Proto3_Message"}, - {"Proto3.map_ext_msg", "map[int64]*duration.Duration"}, - } - for _, test := range tests { - tc := test - t.Run(tc.field, func(t *testing.T) { - t.Parallel() - - e, ok := ast.Lookup(".names.types." + tc.field) - require.True(t, ok, "could not find field") - - fld, ok := e.(pgs.Field) - require.True(t, ok, "entity is not a field") - - assert.Equal(t, tc.expected, ctx.Type(fld)) - }) - } -} + pgs "github.com/lyft/protoc-gen-star" +) func TestTypeName(t *testing.T) { t.Parallel() diff --git a/message.go b/message.go index fa15a2c..13a2f88 100644 --- a/message.go +++ b/message.go @@ -29,9 +29,18 @@ type Message interface { // OneOfFields returns only the fields contained within OneOf blocks. OneOfFields() []Field + // SyntheticOneOfFields returns only the fields contained within synthetic OneOf blocks. + // See: https://github.com/protocolbuffers/protobuf/blob/v3.17.0/docs/field_presence.md + SyntheticOneOfFields() []Field + // OneOfs returns the OneOfs contained within this Message. OneOfs() []OneOf + // RealOneOfs returns the OneOfs contained within this Message. + // This excludes synthetic OneOfs. + // See: https://github.com/protocolbuffers/protobuf/blob/v3.17.0/docs/field_presence.md + RealOneOfs() []OneOf + // Extensions returns all of the Extensions applied to this Message. Extensions() []Extension @@ -139,6 +148,26 @@ func (m *msg) OneOfFields() (f []Field) { return f } +func (m *msg) SyntheticOneOfFields() (f []Field) { + for _, o := range m.oneofs { + if o.IsSynthetic() { + f = append(f, o.Fields()...) + } + } + + return f +} + +func (m *msg) RealOneOfs() (r []OneOf) { + for _, o := range m.oneofs { + if !o.IsSynthetic() { + r = append(r, o) + } + } + + return r +} + func (m *msg) Imports() (i []File) { // Mapping for avoiding duplicate entries mp := make(map[string]File, len(m.fields)) diff --git a/message_test.go b/message_test.go index 65ab251..b6eec25 100644 --- a/message_test.go +++ b/message_test.go @@ -211,6 +211,52 @@ func TestMsg_OneOfs(t *testing.T) { assert.Len(t, m.OneOfs(), 1) } +func TestMsg_SyntheticOneOfFields_And_RealOneOfs(t *testing.T) { + t.Parallel() + + oSyn := &oneof{} + oSyn.flds = []Field{dummyOneOfField(true)} + oSyn.flds[0].setOneOf(oSyn) + + oReal := &oneof{} + oReal.flds = []Field{dummyField(), dummyField()} + oReal.flds[0].setOneOf(oReal) + oReal.flds[1].setOneOf(oReal) + + // no one offs + m := dummyMsg() + assert.Len(t, m.OneOfFields(), 0, "oneof fields") + assert.Len(t, m.SyntheticOneOfFields(), 0, "synthetic oneof fields") + assert.Len(t, m.OneOfs(), 0, "oneofs") + assert.Len(t, m.RealOneOfs(), 0, "real oneofs") + + // one real oneof + m.addField(oReal.flds[0]) + m.addField(oReal.flds[1]) + m.addOneOf(oReal) + assert.Len(t, m.OneOfFields(), 2, "oneof fields") + assert.Len(t, m.SyntheticOneOfFields(), 0, "synthetic oneof fields") + assert.Len(t, m.OneOfs(), 1, "oneofs") + assert.Len(t, m.RealOneOfs(), 1, "real oneofs") + + // one real, one synthetic oneof + m.addField(oSyn.flds[0]) + m.addOneOf(oSyn) + assert.Len(t, m.OneOfFields(), 3, "oneof fields") + assert.Len(t, m.SyntheticOneOfFields(), 1, "synthetic oneof fields") + assert.Len(t, m.OneOfs(), 2, "oneofs") + assert.Len(t, m.RealOneOfs(), 1, "real oneofs") + + // one synthetic oneof + m = dummyMsg() + m.addField(oSyn.flds[0]) + m.addOneOf(oSyn) + assert.Len(t, m.OneOfFields(), 1, "oneof fields") + assert.Len(t, m.SyntheticOneOfFields(), 1, "synthetic oneof fields") + assert.Len(t, m.OneOfs(), 1, "oneofs") + assert.Len(t, m.RealOneOfs(), 0, "real oneofs") +} + func TestMsg_Extension(t *testing.T) { // cannot be parallel m := &msg{desc: &descriptor.DescriptorProto{}} diff --git a/oneof.go b/oneof.go index c0d7a42..34970af 100644 --- a/oneof.go +++ b/oneof.go @@ -19,6 +19,10 @@ type OneOf interface { // Fields returns all fields contained within this OneOf. Fields() []Field + // IsSynthetic returns true if this is a proto3 synthetic oneof. + // See: https://github.com/protocolbuffers/protobuf/blob/v3.17.0/docs/field_presence.md + IsSynthetic() bool + setMessage(m Message) addField(f Field) } @@ -52,6 +56,12 @@ func (o *oneof) Descriptor() *descriptor.OneofDescriptorProto { return o.desc } func (o *oneof) Message() Message { return o.msg } func (o *oneof) setMessage(m Message) { o.msg = m } +func (o *oneof) IsSynthetic() bool { + return o.Syntax() == Proto3 && + len(o.flds) == 1 && + !o.flds[0].InRealOneOf() +} + func (o *oneof) Imports() (i []File) { // Mapping for avoiding duplicate entries mp := make(map[string]File, len(o.flds)) diff --git a/oneof_test.go b/oneof_test.go index a03ab3e..d0eef9c 100644 --- a/oneof_test.go +++ b/oneof_test.go @@ -119,6 +119,20 @@ func TestOneof_Fields(t *testing.T) { assert.Len(t, o.Fields(), 1) } +func TestOneof_IsSynthetic(t *testing.T) { + t.Parallel() + + o := &oneof{msg: &msg{parent: dummyFile()}} + assert.False(t, o.IsSynthetic()) + + o.flds = []Field{dummyField()} + o.flds[0].setOneOf(o) + assert.False(t, o.IsSynthetic()) + + o.flds = []Field{dummyOneOfField(true)} + assert.True(t, o.IsSynthetic()) +} + func TestOneof_Accept(t *testing.T) { t.Parallel() diff --git a/persister.go b/persister.go index 5cbcece..21d48b7 100644 --- a/persister.go +++ b/persister.go @@ -8,12 +8,12 @@ import ( "github.com/golang/protobuf/proto" plugin_go "github.com/golang/protobuf/protoc-gen-go/plugin" "github.com/spf13/afero" - "google.golang.org/protobuf/types/pluginpb" ) type persister interface { SetDebugger(d Debugger) SetFS(fs afero.Fs) + SetSupportedFeatures(f *uint64) AddPostProcessor(proc ...PostProcessor) Persist(a ...Artifact) *plugin_go.CodeGeneratorResponse } @@ -21,21 +21,21 @@ type persister interface { type stdPersister struct { Debugger - fs afero.Fs - procs []PostProcessor + fs afero.Fs + procs []PostProcessor + supportedFeatures *uint64 } func newPersister() *stdPersister { return &stdPersister{fs: afero.NewOsFs()} } func (p *stdPersister) SetDebugger(d Debugger) { p.Debugger = d } func (p *stdPersister) SetFS(fs afero.Fs) { p.fs = fs } +func (p *stdPersister) SetSupportedFeatures(f *uint64) { p.supportedFeatures = f } func (p *stdPersister) AddPostProcessor(proc ...PostProcessor) { p.procs = append(p.procs, proc...) } func (p *stdPersister) Persist(arts ...Artifact) *plugin_go.CodeGeneratorResponse { resp := new(plugin_go.CodeGeneratorResponse) - - supportedFeatures := uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) - resp.SupportedFeatures = &supportedFeatures + resp.SupportedFeatures = p.supportedFeatures for _, a := range arts { switch a := a.(type) { diff --git a/proto.go b/proto.go index 3ce2dd9..a3b78df 100644 --- a/proto.go +++ b/proto.go @@ -13,10 +13,10 @@ const ( // See: https://developers.google.com/protocol-buffers/docs/proto Proto2 Syntax = "" - // Proto3 syntax only allows for optional fields, but defaults to the zero - // value of that particular type. Most of the field types in the generated go - // structs are value types. - // See: https://developers.google.com/protocol-buffers/docs/proto3 + // Proto3 syntax permits the use of "optional" field presence. Non optional fields default to the zero + // value of that particular type if not defined. + // Most of the field types in the generated go structs are value types. + // See: https://github.com/protocolbuffers/protobuf/blob/v3.17.0/docs/field_presence.md#presence-in-proto3-apis Proto3 Syntax = "proto3" ) diff --git a/protoc-gen-debug/main.go b/protoc-gen-debug/main.go index 20b8755..289738d 100644 --- a/protoc-gen-debug/main.go +++ b/protoc-gen-debug/main.go @@ -11,6 +11,8 @@ import ( "os" "path/filepath" + "google.golang.org/protobuf/types/pluginpb" + "github.com/golang/protobuf/proto" plugin_go "github.com/golang/protobuf/protoc-gen-go/plugin" ) @@ -41,8 +43,11 @@ func main() { log.Fatal("unable to write request to disk: ", err) } - data, err = proto.Marshal(&plugin_go.CodeGeneratorResponse{}) - if err != nil { + // protoc-gen-debug supports proto3 field presence for testing purposes + var supportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) + if data, err = proto.Marshal(&plugin_go.CodeGeneratorResponse{ + SupportedFeatures: &supportedFeatures, + }); err != nil { log.Fatal("unable to marshal response payload: ", err) }