Skip to content

Commit

Permalink
continue presence
Browse files Browse the repository at this point in the history
  • Loading branch information
thesamet committed Jul 4, 2020
1 parent 0b67d06 commit 3a63244
Show file tree
Hide file tree
Showing 12 changed files with 77 additions and 47 deletions.
5 changes: 2 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ import com.typesafe.tools.mima.core._
import BuildHelper._
import Dependencies._

// Different version for compiler-plugin since >=3.8.0 is not binary
// compatible with 3.7.x. When loaded inside SBT (which has its own old
// version), the binary incompatibility surfaces.
val protobufCompilerVersion = "3.12.2"

val MimaPreviousVersion = "0.10.0"
Expand Down Expand Up @@ -345,6 +342,8 @@ lazy val e2eNoJava = (project in file("e2e-nojava"))
.settings(e2eCommonSettings)
.settings(
Compile / PB.protocVersion := "-v" + versions.protobuf,
Compile / PB.protocOptions += "--experimental_allow_proto3_optional",
Compile / PB.generate := ((Compile / PB.generate) dependsOn (protocGenScalaUnix / Compile / assembly)).value,
Compile / PB.targets := Seq(
(
PB.gens.plugin(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class DescriptorImplicits(params: GeneratorParams, files: Seq[FileDescriptor]) {
message <- file.allMessages if message.isSealedOneofType
} yield SealedOneof(
message,
message.getOneofs.get(0).getFields.asScala.map(_.getMessageType).toVector
message.getRealOneofs.get(0).getFields.asScala.map(_.getMessageType).toVector
)
new SealedOneofsCache(sealedOneof)
}
Expand Down Expand Up @@ -185,7 +185,7 @@ class DescriptorImplicits(params: GeneratorParams, files: Seq[FileDescriptor]) {
implicit class FieldDescriptorPimp(val fd: FieldDescriptor) {
import NameUtils._

def containingOneOf: Option[OneofDescriptor] = Option(fd.getContainingOneof)
def containingOneOf: Option[OneofDescriptor] = Option(fd.getRealContainingOneof())

def isInOneof: Boolean = containingOneOf.isDefined

Expand Down Expand Up @@ -230,7 +230,7 @@ class DescriptorImplicits(params: GeneratorParams, files: Seq[FileDescriptor]) {

def oneOfTypeName: ScalaName = {
assert(isInOneof)
fd.getContainingOneof.scalaType / upperScalaName
fd.getRealContainingOneof.scalaType / upperScalaName
}

def noBox =
Expand All @@ -241,12 +241,16 @@ class DescriptorImplicits(params: GeneratorParams, files: Seq[FileDescriptor]) {
// Is this field boxed inside an Option in Scala. Equivalent, does the Java API
// support hasX methods for this field.
def supportsPresence: Boolean =
fd.isOptional && !fd.isInOneof && (!fd.getFile.isProto3 || fd.isMessage) &&
fd.isOptional && !fd.isInOneof && (!fd.getFile.isProto3 || fd.isMessage || fd
.toProto()
.getProto3Optional()) &&
!noBox && !fd.isSealedOneofType

// Is the Scala representation of this field a singular type.
def isSingular =
fd.isRequired || (fd.getFile.isProto3 && !fd.isInOneof && fd.isOptional && !fd.isMessage) || (
fd.isRequired || (fd.getFile.isProto3 && !fd.isInOneof && fd.isOptional && !fd.isMessage && !fd
.toProto()
.getProto3Optional()) || (
fd.isOptional && (noBox || (fd.isSealedOneofType && !fd.isInOneof))
)

Expand Down Expand Up @@ -431,6 +435,7 @@ class DescriptorImplicits(params: GeneratorParams, files: Seq[FileDescriptor]) {
}

implicit class OneofDescriptorPimp(val oneof: OneofDescriptor) {

def javaEnumName = {
val name = NameUtils.snakeCaseToCamelCase(oneof.getName, true)
s"get${name}Case"
Expand Down Expand Up @@ -476,16 +481,17 @@ class DescriptorImplicits(params: GeneratorParams, files: Seq[FileDescriptor]) {

def sealedOneofStyle: SealedOneofStyle = {
assert(isSealedOneofType)
if (message.getOneofs.asScala.exists(_.getName == "sealed_value")) SealedOneofStyle.Default
else if (message.getOneofs.asScala.exists(_.getName == "sealed_value_optional"))
if (message.getRealOneofs.asScala.exists(_.getName == "sealed_value"))
SealedOneofStyle.Default
else if (message.getRealOneofs.asScala.exists(_.getName == "sealed_value_optional"))
SealedOneofStyle.Optional
else throw new RuntimeException("Unexpected oneof style")
}

// every message that passes this filter must be a sealed oneof. The check that it actually
// obeys the rules is done in ProtoValidation.
def isSealedOneofType: Boolean = {
message.getOneofs.asScala
message.getRealOneofs.asScala
.exists(o => o.getName == "sealed_value" || o.getName == "sealed_value_optional")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class ProtoValidation(implicits: DescriptorImplicits) {
m.getNestedTypes.asScala.foreach(validateMessage)
m.getFields.asScala.foreach(validateField)
if (m.isSealedOneofType) {
val oneof = m.getOneofs.get(0)
val oneof = m.getRealOneofs.get(0)
if (m.getContainingType != null) {
throw new GeneratorException(s"${m.getFullName}: sealed oneofs must be top-level messages")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ class ProtobufGenerator(
annotations = annotations(field)
)
}
val oneOfFields = message.getOneofs.asScala.map { oneOf =>
val oneOfFields = message.getRealOneofs.asScala.map { oneOf =>
val ctorDefaultValue: Option[String] =
if (message.getFile.noDefaultValuesInConstructor) None
else Some(oneOf.empty.fullNameWithMaybeRoot(message))
Expand Down Expand Up @@ -751,7 +751,7 @@ class ProtobufGenerator(
fp.add(s"var __requiredFields$index: _root_.scala.Long = $bn")
}
}
.print(message.getOneofs.asScala)((printer, oneof) =>
.print(message.getRealOneofs.asScala)((printer, oneof) =>
printer.add(s"var __${oneof.scalaName.name} = `_message__`.${oneof.scalaName.nameSymbol}")
)
.add(s"""var _done__ = false
Expand Down Expand Up @@ -846,7 +846,7 @@ class ProtobufGenerator(
.add(s"$myFullScalaName(")
.indent
.addWithDelimiter(",")(
(message.fieldsWithoutOneofs ++ message.getOneofs.asScala).map {
(message.fieldsWithoutOneofs ++ message.getRealOneofs.asScala).map {
case e: FieldDescriptor if e.isRepeated =>
s" ${e.scalaName.asSymbol} = __${e.scalaName}.result()"
case e: FieldDescriptor =>
Expand Down Expand Up @@ -895,7 +895,7 @@ class ProtobufGenerator(
else javaFieldToScala("javaPbSource", field)
Seq(s"${field.scalaName.asSymbol} = $conversion")
}
val oneOfs = message.getOneofs.asScala.map {
val oneOfs = message.getRealOneofs.asScala.map {
case oneOf =>
val head =
s"${oneOf.scalaName.nameSymbol} = javaPbSource.${oneOf.javaEnumName}.getNumber match {"
Expand Down Expand Up @@ -976,7 +976,7 @@ class ProtobufGenerator(
enclosingType = field.enclosingType
)
}
val oneOfs = message.getOneofs.asScala.map { oneOf =>
val oneOfs = message.getRealOneofs.asScala.map { oneOf =>
val elems = oneOf.fields.map { field =>
val value =
s"__fieldsMap.get(scalaDescriptor.findFieldByNumber(${field.getNumber}).get)"
Expand Down Expand Up @@ -1022,7 +1022,7 @@ class ProtobufGenerator(
case field if !field.isInOneof =>
val default = defaultValueForDefaultInstance(field)
s"${field.scalaName.asSymbol} = $default"
} ++ message.getOneofs.asScala.map { oneof =>
} ++ message.getRealOneofs.asScala.map { oneof =>
s"${oneof.scalaName.nameSymbol} = ${oneof.empty.fullName}"
})
.outdent
Expand Down Expand Up @@ -1062,7 +1062,7 @@ class ProtobufGenerator(
)
}
}
.print(message.getOneofs.asScala) {
.print(message.getRealOneofs.asScala) {
case (printer, oneof) =>
val oneofName = oneof.scalaName.nameSymbol
printer
Expand Down Expand Up @@ -1353,7 +1353,7 @@ class ProtobufGenerator(
.call(generateEnumCompanionForField(message))
.call(generateDefaultInstance(message))
.print(message.getEnumTypes.asScala)(printEnum)
.print(message.getOneofs.asScala)(printOneof)
.print(message.getRealOneofs.asScala)(printOneof)
.print(message.nestedTypes)(printMessage)
.print(message.getExtensions.asScala)(printExtension)
.when(message.generateLenses)(generateMessageLens(message))
Expand Down Expand Up @@ -1457,7 +1457,7 @@ class ProtobufGenerator(
)
}
}
.print(message.getOneofs.asScala) {
.print(message.getRealOneofs.asScala) {
case (printer, oneof) =>
printer.add(
s"""def clear${oneof.scalaType.name}: ${message.scalaType.nameSymbol} = copy(${oneof.scalaName.nameSymbol} = ${oneof.empty
Expand Down Expand Up @@ -1736,22 +1736,17 @@ object ProtobufGenerator {
def handleCodeGeneratorRequest(request: CodeGenRequest): CodeGenResponse = {
parseParameters(request.parameter) match {
case Right(params) =>
try {
val implicits = new DescriptorImplicits(params, request.allProtos)
val generator = new ProtobufGenerator(params, implicits)
val validator = new ProtoValidation(implicits)
validator.validateFiles(request.allProtos)
import implicits.FileDescriptorPimp
val files = request.filesToGenerate.flatMap { file =>
if (file.scalaOptions.getSingleFile)
generator.generateSingleScalaFileForFileDescriptor(file)
else generator.generateMultipleScalaFilesForFileDescriptor(file)
}
CodeGenResponse.succeed(files)
} catch {
case e: GeneratorException =>
CodeGenResponse.fail(e.message)
val implicits = new DescriptorImplicits(params, request.allProtos)
val generator = new ProtobufGenerator(params, implicits)
val validator = new ProtoValidation(implicits)
validator.validateFiles(request.allProtos)
import implicits.FileDescriptorPimp
val files = request.filesToGenerate.flatMap { file =>
if (file.scalaOptions.getSingleFile)
generator.generateSingleScalaFileForFileDescriptor(file)
else generator.generateMultipleScalaFilesForFileDescriptor(file)
}
CodeGenResponse.succeed(files, Set(CodeGeneratorResponse.Feature.FEATURE_PROTO3_OPTIONAL))
case Left(error) =>
CodeGenResponse.fail(error)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class SealedOneofsGenerator(message: Descriptor, implicits: DescriptorImplicits)
val sealedOneofType = message.sealedOneofScalaType
val sealedOneofName = message.sealedOneofTraitScalaType.nameSymbol
val typeMapper = s"_root_.scalapb.TypeMapper[${baseType}, ${sealedOneofType}]"
val oneof = message.getOneofs.get(0)
val oneof = message.getRealOneofs.get(0)
val typeMapperName = message.sealedOneofTypeMapper.name

if (message.sealedOneofStyle != SealedOneofStyle.Optional) {
Expand Down
14 changes: 14 additions & 0 deletions e2e-nojava/src/main/protobuf/proto3_presence.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
syntax = "proto3";

package com.thesamet.proto.e2e;

enum PEnum {
PENUM_UNKNOWN = 0;
PENUM_A = 1;
}

message Presence {
optional int32 i32 = 1;
optional PEnum penum = 2;
optional Presence presence = 3;
}
3 changes: 3 additions & 0 deletions e2e/src/main/protobuf/google/protobuf/unittest.proto
Original file line number Diff line number Diff line change
Expand Up @@ -834,3 +834,6 @@ service TestService {
message BarRequest {}
message BarResponse {}

message TestMap {
map<int32, string> int32_to_string = 1;
}
4 changes: 2 additions & 2 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import Keys._
object Dependencies {
object versions {
val grpc = "1.30.2"
val protobuf = "3.11.4"
val protobuf = "3.12.2"
val utest = "0.7.4"
val fastparse = "2.3.0"
val silencer = "1.6.0"
val collectionCompat = "2.1.6"
val coursier = "2.0.0-RC6-21"
val mockito = "3.2.0"
val protocGen = "0.9.0-RC2"
val protocGen = "0.9.0-RC3"
val protocJar = "3.11.4"
val protobufRuntimeScala = "0.8.6"
val commonsCodec = "1.14"
Expand Down
21 changes: 15 additions & 6 deletions proptest/src/test/scala/GenTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ object GenTypes {
val REPEATED = Value("repeated")
}

case class FieldOptions(modifier: FieldModifier.Value, isPacked: Boolean)
case class FieldOptions(modifier: FieldModifier.Value, isPacked: Boolean, proto3Presence: Boolean)

def genFieldModifier(allowRequired: Boolean): Gen[FieldModifier.Value] =
if (allowRequired)
Expand Down Expand Up @@ -188,19 +188,28 @@ object GenTypes {
protoSyntax: ProtoSyntax,
inOneof: Boolean
): Gen[FieldOptions] =
if (inOneof) Gen.const(FieldOptions(FieldModifier.OPTIONAL, isPacked = false))
if (inOneof)
Gen.const(FieldOptions(FieldModifier.OPTIONAL, isPacked = false, proto3Presence = false))
else
fieldType match {
case MessageReference(id) =>
genFieldModifier(allowRequired = protoSyntax.isProto2 && id < messageId)
.map(mod => FieldOptions(mod, isPacked = false))
case MapType(_, _) => Gen.const(FieldOptions(FieldModifier.REPEATED, isPacked = false))
for {
mod <- genFieldModifier(allowRequired = protoSyntax.isProto2 && id < messageId)
proto3Presence <- if (mod == FieldModifier.OPTIONAL && protoSyntax.isProto3)
Gen.oneOf(true, false)
else Gen.const(false)
} yield FieldOptions(mod, isPacked = false, proto3Presence = proto3Presence)
case MapType(_, _) =>
Gen.const(FieldOptions(FieldModifier.REPEATED, isPacked = false, proto3Presence = false))
case _ =>
for {
mod <- genFieldModifier(allowRequired = protoSyntax.isProto2)
packed <- if (fieldType.packable && mod == FieldModifier.REPEATED)
Gen.oneOf(true, false)
else Gen.const(false)
} yield FieldOptions(mod, isPacked = packed)
proto3Presence <- if (mod == FieldModifier.OPTIONAL && protoSyntax.isProto3)
Gen.oneOf(true, false)
else Gen.const(false)
} yield FieldOptions(mod, isPacked = packed, proto3Presence = proto3Presence)
}
}
4 changes: 3 additions & 1 deletion proptest/src/test/scala/GraphGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,14 @@ object GraphGen {
if (protoPackage.isEmpty) state else state.closeNamespace
)

def genRootNode: Gen[RootNode] =
def genRootNode: Gen[RootNode] = {
listWithStatefulGen(State(), maxSize = 10)(genFileNode)
.map {
case (files, state) =>
assert(state.namespace.parent.isEmpty)
RootNode(files)
}
.suchThat(_.maxMessageId.isDefined)

}
}
3 changes: 2 additions & 1 deletion proptest/src/test/scala/Nodes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ object Nodes {
assert(
fieldOptions.modifier == FieldModifier.OPTIONAL || fieldOptions.modifier == FieldModifier.REPEATED
)
if (fieldOptions.modifier == FieldModifier.OPTIONAL || fieldType.isMap) ""
if (fieldOptions.proto3Presence) "optional "
else if (fieldOptions.modifier == FieldModifier.OPTIONAL || fieldType.isMap) ""
else if (fieldOptions.modifier == FieldModifier.REPEATED) "repeated "
else throw new RuntimeException("Unexpected modifier")
}
Expand Down
1 change: 1 addition & 0 deletions proptest/src/test/scala/SchemaGenerators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ object SchemaGenerators {
file.getAbsolutePath
}
val args = Seq(
"--experimental_allow_proto3_optional",
"--proto_path",
(tmpDir.toString + ":protobuf:third_party"),
"--java_out",
Expand Down

0 comments on commit 3a63244

Please sign in to comment.