diff --git a/.gitignore b/.gitignore index cdbd28eca7c..b9391f86945 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ __pycache__ cmake_build/ tensorflow/contrib/cmake/_build/ .idea/** +.run /build/ [Bb]uild/ /tensorflow/core/util/version_info.cc diff --git a/tensorflow-core/tensorflow-core-api/build.sh b/tensorflow-core/tensorflow-core-api/build.sh index 37bb6d8482b..365d4da9895 100755 --- a/tensorflow-core/tensorflow-core-api/build.sh +++ b/tensorflow-core/tensorflow-core-api/build.sh @@ -89,7 +89,7 @@ $BAZEL_BIN/java_op_generator \ GEN_RESOURCE_DIR=src/gen/resources/org/tensorflow/op mkdir -p $GEN_RESOURCE_DIR -# Generate Java operator wrappers +# Export op defs $BAZEL_BIN/java_op_exporter \ --api_dirs=$BAZEL_SRCS/external/org_tensorflow/tensorflow/core/api_def/base_api,src/bazel/api_def \ $TENSORFLOW_LIB > $GEN_RESOURCE_DIR/ops.pb diff --git a/tensorflow-core/tensorflow-core-api/pom.xml b/tensorflow-core/tensorflow-core-api/pom.xml index e2c8e06725a..5f4829f95c8 100644 --- a/tensorflow-core/tensorflow-core-api/pom.xml +++ b/tensorflow-core/tensorflow-core-api/pom.xml @@ -328,6 +328,41 @@ + + + org.codehaus.mojo + exec-maven-plugin + 3.0.0 + + + generate-ops + + java + + + + + + + org.tensorflow + tensorflow-core-generator + ${project.version} + + + + false + true + org.tensorflow.op.generator.OpGenerator + + ${project.basedir}/src/gen/java + ${project.basedir}/src/gen/resources/org/tensorflow/op/ops.pb + + + maven-jar-plugin 3.1.0 diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/registry/TensorTypeRegistry.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/registry/TensorTypeRegistry.java index a30138e0386..1b061bebef9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/registry/TensorTypeRegistry.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/registry/TensorTypeRegistry.java @@ -91,6 +91,7 @@ private static void register(Class type) { static { // TODO (karllessard) scan and registered automatically all annotated tensors types + // TODO use in generator? register(TBool.class); register(TFloat64.class); register(TFloat32.class); diff --git a/tensorflow-core/tensorflow-core-generator/pom.xml b/tensorflow-core/tensorflow-core-generator/pom.xml index b108d079dc7..614ba621b43 100644 --- a/tensorflow-core/tensorflow-core-generator/pom.xml +++ b/tensorflow-core/tensorflow-core-generator/pom.xml @@ -10,8 +10,8 @@ tensorflow-core-generator jar - TensorFlow Core Annotation Processor - Annotation processor for TensorFlow Java client + TensorFlow Core Generators + Code generators for TensorFlow Java client org.tensorflow.core.generator @@ -28,11 +28,22 @@ javapoet 1.12.1 + + com.google.protobuf + protobuf-java + ${protobuf.version} + com.github.javaparser javaparser-core 3.15.12 + + + org.commonmark + commonmark + 0.17.1 + diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java new file mode 100644 index 00000000000..d2218c7d6e0 --- /dev/null +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java @@ -0,0 +1,73 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow; + +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.ParameterizedTypeName; +import com.squareup.javapoet.TypeName; + +public class Names { + + public static final String TensorflowPackage = "org.tensorflow"; + public static final String OpPackage = TensorflowPackage + ".op"; + public static final String TypesPackage = TensorflowPackage + ".types"; + + public static final ClassName Operator = ClassName.get(OpPackage + ".annotation", "Operator"); + public static final ClassName Endpoint = ClassName.get(OpPackage + ".annotation", "Endpoint"); + + public static final ClassName TType = ClassName.get(TypesPackage + ".family", "TType"); + public static final ClassName TString = ClassName.get(TypesPackage, "TString"); + public static final ClassName TBool = ClassName.get(TypesPackage, "TBool"); + + public static final ClassName TNumber = ClassName.get(TypesPackage + ".family", "TNumber"); + + public static final ClassName TFloating = ClassName.get(TypesPackage + ".family", "TFloating"); + public static final ClassName TBfloat16 = ClassName.get(TypesPackage, "TBfloat16"); + public static final ClassName TFloat16 = ClassName.get(TypesPackage, "TFloat16"); + public static final ClassName TFloat32 = ClassName.get(TypesPackage, "TFloat32"); + public static final ClassName TFloat64 = ClassName.get(TypesPackage, "TFloat64"); + + public static final ClassName TIntegral = ClassName.get(TypesPackage + ".family", "TIntegral"); + public static final ClassName TUint8 = ClassName.get(TypesPackage, "TUint8"); + public static final ClassName TInt32 = ClassName.get(TypesPackage, "TInt32"); + public static final ClassName TInt64 = ClassName.get(TypesPackage, "TInt64"); + + public static final TypeName Op = ClassName.get(OpPackage, "Op"); + public static final ClassName RawOp = ClassName.get(OpPackage, "RawOp"); + public static final ClassName Operation = ClassName.get(TensorflowPackage, "Operation"); + public static final ClassName Operands = ClassName.get(OpPackage, "Operands"); + public static final ClassName OperationBuilder = ClassName.get(TensorflowPackage, "OperationBuilder"); + public static final TypeName IterableOp = ParameterizedTypeName.get(ClassName.get(Iterable.class), Op); + + public static final ClassName Operand = ClassName.get(TensorflowPackage, "Operand"); + public static final ClassName Output = ClassName.get(TensorflowPackage, "Output"); + + public static final ClassName Shape = ClassName.get(TensorflowPackage + ".ndarray", "Shape"); + public static final ClassName Tensor = ClassName.get(TensorflowPackage, "Tensor"); + public static final ClassName ConcreteFunction = ClassName.get(TensorflowPackage, "ConcreteFunction"); + + public static final ClassName Scope = ClassName.get(OpPackage, "Scope"); + public static final TypeName DeviceSpec = ClassName.get(TensorflowPackage, "DeviceSpec"); + public static final ClassName Ops = ClassName.get(OpPackage, "Ops"); + + public static final TypeName ExecutionEnvironment = + ClassName.get(TensorflowPackage, "ExecutionEnvironment"); + public static final TypeName EagerSession = ClassName.get(TensorflowPackage, "EagerSession"); + + public static final TypeName String = ClassName.get(String.class); + +} diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java new file mode 100644 index 00000000000..20e5edf4989 --- /dev/null +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java @@ -0,0 +1,785 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.op.generator; + +import static org.tensorflow.op.generator.GeneratorUtils.javaizeMemberName; +import static org.tensorflow.op.generator.GeneratorUtils.parseDocumentation; + +import com.squareup.javapoet.AnnotationSpec; +import com.squareup.javapoet.ArrayTypeName; +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.CodeBlock; +import com.squareup.javapoet.FieldSpec; +import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; +import com.squareup.javapoet.ParameterizedTypeName; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.TypeSpec; +import com.squareup.javapoet.TypeSpec.Builder; +import com.squareup.javapoet.TypeVariableName; +import com.squareup.javapoet.WildcardTypeName; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import javax.lang.model.element.Modifier; +import org.tensorflow.Names; +import org.tensorflow.proto.framework.ApiDef; +import org.tensorflow.proto.framework.ApiDef.Endpoint; +import org.tensorflow.proto.framework.ApiDef.Visibility; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.proto.framework.OpDef; +import org.tensorflow.proto.framework.OpDef.ArgDef; +import org.tensorflow.proto.framework.OpDef.AttrDef; + +/** + * A generator to generate a op class + */ +final class ClassGenerator { + + /** + * Return true if we can generate the operation class for {@code op}. + */ + static boolean canGenerateOp(OpDef op, ApiDef apiDef) { + return apiDef.getVisibility() != Visibility.SKIP + && !op.getAttrList().stream().anyMatch(x -> x.getType().contains("func")) + && !op.getName().startsWith("_"); //TODO do I want this? Some interesting ops like _XlaCompile + } + + enum RenderMode { + DEFAULT, LIST_OPERAND, OPERAND; + } + + /** + * The in-progress class builder for the top level op class. + */ + private final TypeSpec.Builder builder; + + /** + * The op to build. + */ + private final OpDef op; + + /** + * The api definition for the current op. + */ + private final ApiDef apiDef; + + /** + * A type resolver for the current op. + */ + private final TypeResolver resolver; + + /** + * The full package of the class. + */ + private final String fullPackage; + + /** + * The base package for this op generation run. + */ + private final String basePackage; + + /** + * The group of this op. + */ + private final String group; + + /** + * The class name for this op. + */ + private final String className; + + /** + * The endpoint being generated in this class. + */ + private final Endpoint endpoint; + + /** + * The generated options class, or null if it doesn't have one or {@link #buildOptionsClass()} has not been ran. + */ + private TypeSpec optionsClass = null; + + /** + * What type of op this is. + */ + private RenderMode mode = RenderMode.DEFAULT; + + /** + * The required attributes of this op. + */ + private final List requiredAttributes = new ArrayList<>(); + + /** + * The optional attributes of this op. + */ + private final List optionalAttributes = new ArrayList<>(); + + /** + * The class's type parameters, initialized in {@link #buildClass()}. + */ + private final Set typeParams = new LinkedHashSet<>(); + + /** + * The api defs for the arguments. + */ + private final Map argApis = new HashMap<>(); + + /** + * The api defs for the attributes. + */ + private final Map attrApis = new HashMap<>(); + + ClassGenerator(Builder builder, OpDef op, ApiDef apiDef, + String basePackage, String fullPackage, String group, String className, Endpoint endpoint) { + + this.builder = builder; + this.op = op; + this.apiDef = apiDef; + this.resolver = new TypeResolver(op); + this.basePackage = basePackage; + this.fullPackage = fullPackage; + this.group = group; + this.className = className; + this.endpoint = endpoint; + + op.getAttrList().forEach(attr -> { + if (attr.hasDefaultValue() && !attr.getType().contains("type")) { + optionalAttributes.add(attr); + } else { + requiredAttributes.add(attr); + } + }); + + for (AttrDef attr : op.getAttrList()) { + ApiDef.Attr api = apiDef.getAttrList().stream().filter(x -> x.getName().equals(attr.getName())).findFirst().get(); + attrApis.put(attr, api); + } + + for (ArgDef arg : op.getInputArgList()) { + ApiDef.Arg api = apiDef.getInArgList().stream().filter(x -> x.getName().equals(arg.getName())).findFirst().get(); + argApis.put(arg, api); + } + for (ArgDef arg : op.getOutputArgList()) { + ApiDef.Arg api = apiDef.getOutArgList().stream().filter(x -> x.getName().equals(arg.getName())).findFirst().get(); + argApis.put(arg, api); + } + } + + /** + * Get the Java variable name for an argument. + */ + private String getJavaName(ArgDef arg) { + String name = arg.getName(); + String rename = argApis.get(arg).getRenameTo(); + if (!rename.isEmpty()) { + return javaizeMemberName(rename); + } else { + return javaizeMemberName(name); + } + } + + /** + * Get the Java variable name for an attribute. + */ + private String getJavaName(AttrDef arg) { + String name = arg.getName(); + String rename = attrApis.get(arg).getRenameTo(); + if (!rename.isEmpty()) { + return javaizeMemberName(rename); + } else { + return javaizeMemberName(name); + } + } + + /** + * Get the fully qualified name of the class being generated. + */ + private String fullClassName() { + return fullPackage + "." + className; + } + + /** + * Build the class. + */ + void buildClass() { + builder.addModifiers(Modifier.PUBLIC, Modifier.FINAL); + builder.superclass(Names.RawOp); + + // add class javadocs + String summary = parseDocumentation(apiDef.getSummary()); + if (!summary.isEmpty()) { + builder.addJavadoc("$L", summary + "\n"); + } else { + builder.addJavadoc("The $L operation\n", apiDef.getGraphOpName()); + } + String desc = parseDocumentation(apiDef.getDescription()); + if (!desc.isEmpty()) { + builder.addJavadoc("$L", desc + "\n"); + } + + // add superinterface and set mode + if (op.getOutputArgCount() == 1) { + ArgDef output = op.getOutputArg(0); + ResolvedType rType = resolver.typeOf(output); + TypeName type = rType.unwrapArg(); + boolean iterable = rType.iterable; + TypeName operandTypeParam = + type instanceof WildcardTypeName ? Names.TType : type; + TypeName operandType = ParameterizedTypeName.get(Names.Operand, operandTypeParam); + + if (iterable) { + mode = RenderMode.LIST_OPERAND; + builder.addSuperinterface(ParameterizedTypeName.get(ClassName.get(Iterable.class), operandType)); + } else { + mode = RenderMode.OPERAND; + builder.addSuperinterface(operandType); + } + } + + // add and store type variables + Set seenGenerics = new HashSet<>(); + for (ArgDef output : op.getOutputArgList()) { + ResolvedType type = resolver.typeOf(output); + for (TypeVariableName typeVar : type.findGenerics()) { + if (seenGenerics.add(typeVar.name)) { + typeParams.add(typeVar); + builder.addTypeVariable(typeVar); + builder.addJavadoc( + "\n@param <$L> data type for {@code $L} output\n", typeVar.name, output.getName()); + } + } + } + + // add deprecated if necessary + if (endpoint.getDeprecated()) { + builder.addAnnotation(Deprecated.class); + Endpoint first = apiDef.getEndpoint(0); + String explanation; + if (!first.getDeprecated()) { + explanation = "use {@link " + basePackage + "." + first.getName() + "} instead"; + } else { + explanation = op.getDeprecation().getExplanation(); + } + builder.addJavadoc("\n@deprecated $L", explanation); + } + + // add the Operator annotation + if (apiDef.getVisibility() != Visibility.HIDDEN) { + AnnotationSpec.Builder annotation = AnnotationSpec.builder(Names.Operator); + if (!group.equals("core")) { + annotation.addMember("group", "$S", group); + } + builder.addAnnotation(annotation.build()); + } + + if (!optionalAttributes.isEmpty()) { + buildOptionsClass(); + } + + buildFactoryMethods(); + buildGettersAndSetters(); + if (mode != RenderMode.DEFAULT) { + buildInterfaceImpl(); + } + + // add op name field + builder + .addField(FieldSpec.builder(TypeResolver.STRING, "OP_NAME", Modifier.PUBLIC, Modifier.STATIC, Modifier.FINAL) + .addJavadoc("$L", "The name of this op, as known by TensorFlow core engine") + .initializer("$S", op.getName()) + .build()); + + // add output fields + if (op.getOutputArgCount() > 0) { + for (ArgDef output : op.getOutputArgList()) { + builder + .addField(resolver.typeOf(output).listIfIterable().javaType, getJavaName(output), Modifier.PRIVATE); + } + } + + buildConstructor(); + } + + /** + * Add a nested class for Options + */ + private void buildOptionsClass() { + + if (optionalAttributes.isEmpty()) { + return; + } + + TypeSpec.Builder optionsBuilder = TypeSpec.classBuilder("Options").addModifiers(Modifier.PUBLIC, Modifier.STATIC); + optionsBuilder.addJavadoc("$L", "Optional attributes for {@link " + fullClassName() + "}"); + + ClassName optionsClassName = ClassName.get(fullPackage, className, "Options"); + + for (AttrDef attr : optionalAttributes) { + ResolvedType type = resolver.typeOf(attr); + + String name = getJavaName(attr); + + ApiDef.Attr apiAttr = attrApis.get(attr); + + String description = + apiAttr.getDescription().isEmpty() + ? String.format("the %s option", name) + : apiAttr.getDescription(); + + // add the setter method, adding one with varargs and one with List if the attribute is + // iterable + if (type.iterable) { + optionsBuilder.addMethod( + MethodSpec.methodBuilder(name) + .returns(optionsClassName) + .addModifiers(Modifier.PUBLIC) + .addParameter(type.classIfGeneric().listIfIterable().javaType, name) + .addJavadoc("Sets the $L option.\n", name) + .addJavadoc("\n@param $L $L", name, parseDocumentation(description)) + .addJavadoc("\n@return this Options instance.") + .addCode("this.$L = $L;\nreturn this;\n", name, name) + .build()); + + optionsBuilder.addMethod( + MethodSpec.methodBuilder(name) + .returns(optionsClassName) + .addModifiers(Modifier.PUBLIC) + .addParameter(type.classIfGeneric().arrayIfIterable().javaType, name) + .addJavadoc("Sets the $L option.\n", name) + .addJavadoc("\n@param $L $L", name, parseDocumentation(description)) + .addJavadoc("\n@return this Options instance.") + .addCode("this.$L = $T.asList($L);\nreturn this;\n", name, Arrays.class, name) + .varargs() + .build()); + + } else { + optionsBuilder.addMethod( + MethodSpec.methodBuilder(name) + .returns(optionsClassName) + .addModifiers(Modifier.PUBLIC) + .addParameter(type.classIfGeneric().javaType, name) + .addJavadoc("Sets the $L option.\n", name) + .addJavadoc("\n@param $L $L", name, parseDocumentation(description)) + .addJavadoc("\n@return this Options instance.") + .addCode("this.$L = $L;\nreturn this;\n", name, name) + .build()); + } + + // add the field + optionsBuilder.addField(type.classIfGeneric().listIfIterable().javaType, name, Modifier.PRIVATE); + } + + // add a private constructor + optionsBuilder.addMethod(MethodSpec.constructorBuilder().addModifiers(Modifier.PRIVATE).build()); + + optionsClass = optionsBuilder.build(); + builder.addType(optionsClass); + } + + /** + * Write statements to set an attribute in an OperationBuilder. Meant to be used in {@link #buildFactoryMethods()} + * + * @param body the body to write to + * @param attr the attribute to set + * @param type the type of the attribute, or null to get it ourselves + * @param optional whether the attribute is optional + */ + private void writeSetAttr(CodeBlock.Builder body, AttrDef attr, ResolvedType type, boolean optional) { + String varName = optional ? "opts." + getJavaName(attr) : getJavaName(attr); + if (type == null) { + type = resolver.typeOf(attr); + } + + if (type.jniType.equals(ClassName.get(DataType.class))) { + body.addStatement("opBuilder.setAttr($S, $T.$L($L))", attr.getName(), + Names.Operands, type.iterable ? "toDataTypes" : "toDataType", varName); + } else { + if (type.iterable) { + String arrayName = javaizeMemberName(attr.getName()) + "Array"; + body.addStatement("$T[] $L = new $T[$L.size()]", type.jniType, arrayName, type.jniType, varName); + + body.beginControlFlow("for (int i = 0 ; i < $L.length ; i++)", arrayName); + + body.addStatement("$L[i] = $L.get(i)", arrayName, varName); + + body.endControlFlow(); + + body.addStatement("opBuilder.setAttr($S, $L)", attr.getName(), arrayName); + } else { + body.addStatement("opBuilder.setAttr($S, $L)", attr.getName(), varName); + } + } + } + + /** + * Add the {@code create} factory methods. + */ + private void buildFactoryMethods() { + MethodSpec.Builder factoryBuilder = MethodSpec.methodBuilder("create") + .addModifiers(Modifier.PUBLIC, Modifier.STATIC); + + // the main creator will inherit any class type params + TypeName returnType = ClassName.get(fullPackage, className); + if (!typeParams.isEmpty()) { + returnType = ParameterizedTypeName.get((ClassName) returnType, typeParams.toArray(new TypeName[0])); + } + factoryBuilder.returns(returnType); + + factoryBuilder.addAnnotation( + AnnotationSpec.builder(Names.Endpoint).addMember("describeByClass", "true") + .build()); + + factoryBuilder.addJavadoc("Factory method to create a class wrapping a new $L operation.\n", op.getName()); + + // we're going to build the body as add arguments + CodeBlock.Builder body = CodeBlock.builder(); + + Map paramTags = new LinkedHashMap<>(); + + factoryBuilder + .addParameter(ParameterSpec.builder(Names.Scope, "scope").build()); + paramTags.put("scope", CodeBlock.of("current scope")); + + Set typeVars = new LinkedHashSet<>(typeParams); + + body.addStatement("$T opBuilder = scope.env().opBuilder($S, scope.makeOpName($S))", + Names.OperationBuilder, op.getName(), + className); + + // add the inputs as parameters, and add them to the op builder + for (ArgDef input : op.getInputArgList()) { + ApiDef.Arg argDef = argApis.get(input); + ResolvedType type = resolver.typeOf(input); + String name = getJavaName(input); + + ParameterSpec.Builder param = ParameterSpec.builder(type.iterableIfIterable().javaType, name); + String description = + argDef.getDescription().isEmpty() + ? String.format("the %s value", name) + : argDef.getDescription(); + paramTags.put(name, CodeBlock.of("$L", parseDocumentation(description))); + factoryBuilder.addParameter(param.build()); + + typeVars.addAll(type.findGenerics()); + + if (type.iterable) { + body.addStatement("opBuilder.addInputList($T.asOutputs($L))", Names.Operands, name); + } else { + body.addStatement("opBuilder.addInput($L.asOutput())", name); + } + + } + + body.addStatement("opBuilder = scope.apply(opBuilder)"); + + // add the required attribute params, and build the default type maps for use in the secondary factory + Map defaultTypes = new HashMap<>(); + Map defaultTypeVars = new HashMap<>(); + for (AttrDef attr : requiredAttributes) { + if (resolver.partOfInput(attr.getName())) { + continue; + } + + ResolvedType type = resolver.typeOf(attr); + ApiDef.Attr apiAttr = attrApis.get(attr); + + ParameterSpec.Builder builder = ParameterSpec + .builder(type.classIfGeneric().listIfIterable().javaType, getJavaName(attr)); + + String javaName = getJavaName(attr); + String description = + apiAttr.getDescription().isEmpty() + ? String.format("the value of the %s property", javaName) + : apiAttr.getDescription(); + paramTags.put(javaName, CodeBlock.of("$L", parseDocumentation(description))); + + typeVars.addAll(type.findGenerics()); + + factoryBuilder.addParameter(builder.build()); + + // we only add defaults for type variable arguments + if (attr.hasDefaultValue() && type.shouldWrapInClass()) { + TypeName defaultType = TypeResolver.forDataType(attr.getDefaultValue().getType()); + if (!(defaultType instanceof WildcardTypeName) && defaultType != Names.TType) { + defaultTypes.put(attr, defaultType); + defaultTypeVars.put(((TypeVariableName) type.javaType).name, defaultType); + } + } + + writeSetAttr(body, attr, type, false); + } + + // add optional attributes + if (optionsClass != null) { + factoryBuilder.addParameter( + ParameterSpec.builder(ArrayTypeName.of(ClassName.get(fullPackage, className, "Options")), "options").build()); + paramTags.put("options", CodeBlock.of("$L", "carries optional attribute values")); + factoryBuilder.varargs(); + + body.beginControlFlow("if (options != null)"); + + body.beginControlFlow("for (Options opts : options)"); + for (AttrDef attr : optionalAttributes) { + String name = getJavaName(attr); + body.beginControlFlow("if (opts.$L != null)", name); + + writeSetAttr(body, attr, null, true); + + body.endControlFlow(); + } + body.endControlFlow(); + + body.endControlFlow(); + + } + + body.addStatement("return new $L(opBuilder.build())", typeParams.isEmpty() ? className : (className + "<>")); + + factoryBuilder.addCode(body.build()); + paramTags.forEach( + (param, doc) -> { + String description = doc.toString(); + if (description.isEmpty() || description.equals("\n")) { + factoryBuilder.addJavadoc( + "\n@param $L the $L property", param, param); + } else { + factoryBuilder.addJavadoc("\n@param $L $L", param, doc); + } + }); + for (TypeVariableName typeVar : typeVars) { + factoryBuilder.addJavadoc( + "\n@param <" + + typeVar.name + + "> data type for {@code " + + op.getName() + + "} output and operands"); + } + + factoryBuilder.addTypeVariables(typeVars); + factoryBuilder.addJavadoc("\n@return a new instance of $L\n", className); + MethodSpec method = factoryBuilder.build(); + builder.addMethod(method); + + if (!defaultTypes.isEmpty()) { + buildSecondaryFactory(defaultTypes, defaultTypeVars, method, paramTags); + } + } + + /** + * Add a secondary factory method with the provided default type maps + */ + private void buildSecondaryFactory(Map defaultTypes, Map defaultTypeVars, + MethodSpec mainFactory, Map paramTags) { + MethodSpec.Builder factoryBuilder = MethodSpec.methodBuilder(mainFactory.name) + .addModifiers(mainFactory.modifiers) + .returns(ParameterizedTypeName.get(ClassName.get(fullPackage, className), typeParams.stream() + .map(x -> defaultTypeVars.getOrDefault(x.name, x)).toArray(TypeName[]::new))); + factoryBuilder.addAnnotations(mainFactory.annotations); + + factoryBuilder + .addJavadoc("Factory method to create a class wrapping a new $L operation, with the default output types.\n", + op.getName()); + + CodeBlock.Builder body = CodeBlock.builder(); + body.add("return create("); + + Set typeVars = new LinkedHashSet<>(); + + // we want to add all of the main factory's parameters except for those with defaults + // we just pass them through + boolean first = true; + for (ParameterSpec param : mainFactory.parameters) { + if (!first) { + body.add(", "); + } + + AttrDef attr = op.getAttrList().stream().filter(x -> getJavaName(x).equals(param.name)).findFirst() + .orElse(null); + if (attr != null && resolver.typeOf(attr).shouldWrapInClass() && defaultTypes.containsKey(attr)) { + body.add("$T.class", defaultTypes.get(attr)); + } else { + factoryBuilder.addParameter(param); + factoryBuilder.addJavadoc("\n@param $L $L", param.name, paramTags.get(param.name)); + typeVars.addAll(new ResolvedType(param.type).findGenerics()); + body.add("$L", param.name); + } + first = false; + } + + body.add(");"); + + for (TypeVariableName typeVar : typeVars) { + factoryBuilder.addJavadoc( + "\n@param <" + + typeVar.name + + "> data type for {@code " + + op.getName() + + "} output and operands"); + } + + factoryBuilder.addJavadoc("\n@return a new instance of $L, with default output types", className); + factoryBuilder.addCode(body.build()); + factoryBuilder.addTypeVariables(typeVars); + + builder.addMethod(factoryBuilder.build()); + } + + /** + * Add getters for the outputs and setters/Options creators for the optional attributes. + * + *

Needs to be called after {@link #buildOptionsClass()} + */ + private void buildGettersAndSetters() { + if (optionsClass != null) { + optionsClass.methodSpecs.stream() + .filter(x -> !x.isConstructor()) + .forEach( + method -> { + String argName = method.parameters.get(0).name; + builder.addMethod( + MethodSpec.methodBuilder(method.name) + .addParameter(method.parameters.get(0)) + .addJavadoc(method.javadoc) + .returns(ClassName.get(fullPackage, className, "Options")) + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .addCode("return new Options().$L($L);", method.name, argName) + .build()); + }); + } + + for (ArgDef output : op.getOutputArgList()) { + String name = getJavaName(output); + ApiDef.Arg argDef = argApis.get(output); + builder.addMethod( + MethodSpec.methodBuilder(name) + .addModifiers(Modifier.PUBLIC) + .returns(resolver.typeOf(output).listIfIterable().javaType) + .addJavadoc("Gets $L.\n", name) + .addJavadoc("$L", parseDocumentation(argDef.getDescription())) + .addJavadoc("\n@return $L.", name) + .addCode("return $L;", name) + .build()); + } + + } + + /** + * Add interface implementation methods if this is a single output or list output op. + */ + private void buildInterfaceImpl() { + ArgDef output = op.getOutputArg(0); + TypeName type = resolver.typeOf(output).unwrapArg(); + + boolean uncheckedCast = type instanceof WildcardTypeName; + TypeName outputTType = uncheckedCast ? Names.TType : type; + + if (mode == RenderMode.OPERAND) { + TypeName outputType = ParameterizedTypeName.get(Names.Output, outputTType); + MethodSpec.Builder asOutput = MethodSpec.methodBuilder("asOutput") + .addModifiers(Modifier.PUBLIC) + .returns(outputType) + .addAnnotation(Override.class); + + if (uncheckedCast) { + asOutput.addAnnotation( + AnnotationSpec.builder(SuppressWarnings.class).addMember("value", "$S", "unchecked").build()); + asOutput.addCode("return ($T) $L;", outputType, getJavaName(output)); + } else { + asOutput.addCode("return $L;", getJavaName(output)); + } + + builder.addMethod(asOutput + .build()); + } else { + TypeName operandType = ParameterizedTypeName.get(Names.Operand, outputTType); + TypeName returnType = ParameterizedTypeName.get(ClassName.get(Iterator.class), operandType); + + MethodSpec.Builder iterator = MethodSpec.methodBuilder("iterator") + .addModifiers(Modifier.PUBLIC) + .returns(returnType) + .addAnnotation(Override.class) + .addAnnotation( + AnnotationSpec.builder(SuppressWarnings.class).addMember("value", "{$S, $S}", "rawtypes", "unchecked") + .build()); + + iterator.addCode("return ($T) $L.iterator();", Iterator.class, getJavaName(output)); + + builder.addMethod(iterator.build()); + } + } + + /** + * Add a constructor to get the outputs from an operation + */ + private void buildConstructor() { + MethodSpec.Builder ctor = MethodSpec.constructorBuilder().addModifiers(Modifier.PRIVATE); + + ctor.addParameter(Names.Operation, "operation"); + + for (ArgDef output : op.getOutputArgList()) { + ResolvedType type = resolver.typeOf(output); + if (type.iterable || type.unwrapArg() instanceof WildcardTypeName) { + ctor.addAnnotation( + AnnotationSpec.builder(SuppressWarnings.class).addMember("value", "$S", "unchecked").build()); + break; + } + } + CodeBlock.Builder body = CodeBlock.builder(); + body.addStatement("super(operation)"); + + if (op.getOutputArgCount() > 0) { + body.addStatement("int outputIdx = 0"); + for (ArgDef output : op.getOutputArgList()) { + ResolvedType type = resolver.typeOf(output); + boolean iterable = type.iterable; + if (iterable) { + String lengthVar = getJavaName(output) + "Length"; + + body.addStatement("int $L = operation.outputListLength($S)", lengthVar, output.getName()); + + if (type.unwrapArg() instanceof WildcardTypeName) { + body.addStatement("$L = $T.asList(operation.outputList(outputIdx, $L))", + getJavaName(output), + Arrays.class, + lengthVar); + } else { + body.addStatement("$L = $T.asList(($T) operation.outputList(outputIdx, $L))", + getJavaName(output), + Arrays.class, + ArrayTypeName.of(type.javaType), + lengthVar); + } + body.addStatement("outputIdx += $L", lengthVar); + + } else { + body.addStatement("$L = operation.output(outputIdx++)", getJavaName(output)); + } + } + } + + ctor.addCode(body.build()); + builder.addMethod(ctor.build()); + } + +} diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java new file mode 100644 index 00000000000..80d6698fe36 --- /dev/null +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java @@ -0,0 +1,80 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.op.generator; + +import org.commonmark.node.Node; +import org.commonmark.parser.Parser; +import org.tensorflow.op.generator.javadoc.JavaDocRenderer; +import org.tensorflow.proto.framework.OpDef.ArgDef; + +/** + * Utilities for op generation + */ +final class GeneratorUtils { + + private static final Parser parser = Parser.builder().build(); + + /** + * Convert a Python style name to a Java style name. + * + * Does snake_case -> camelCase and handles keywords. + * + * Not valid for class names, meant for fields and methods. + * + * Generally you should use {@link ClassGenerator#getJavaName(ArgDef)}. + */ + static String javaizeMemberName(String name) { + StringBuilder result = new StringBuilder(); + boolean capNext = Character.isUpperCase(name.charAt(0)); + for (char c : name.toCharArray()) { + if (c == '_') { + capNext = true; + } else if (capNext) { + result.append(Character.toUpperCase(c)); + capNext = false; + } else { + result.append(c); + } + } + name = result.toString(); + switch (name) { + case "size": + return "sizeOutput"; + case "if": + return "ifOutput"; + case "while": + return "whileOutput"; + case "for": + return "forOutput"; + case "case": + return "caseOutput"; + default: + return name; + } + } + + /** + * Convert markdown descriptions to JavaDocs. + */ + static String parseDocumentation(String docs) { + Node document = parser.parse(docs); + JavaDocRenderer renderer = JavaDocRenderer.builder().build(); + return renderer.render(document).trim(); + } + + +} diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java new file mode 100644 index 00000000000..da4f63405cc --- /dev/null +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java @@ -0,0 +1,263 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ +package org.tensorflow.op.generator; + +import com.google.protobuf.InvalidProtocolBufferException; +import com.squareup.javapoet.JavaFile; +import com.squareup.javapoet.TypeSpec; +import org.tensorflow.proto.framework.ApiDef; +import org.tensorflow.proto.framework.OpDef; +import org.tensorflow.proto.framework.OpList; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.FileVisitResult; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.SimpleFileVisitor; +import java.nio.file.StandardOpenOption; +import java.nio.file.attribute.BasicFileAttributes; +import java.util.LinkedHashMap; +import java.util.Map; + +public final class OpGenerator { + + private static final String LICENSE = + "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" + + "\n" + + "Licensed under the Apache License, Version 2.0 (the \"License\");\n" + + "you may not use this file except in compliance with the License.\n" + + "You may obtain a copy of the License at\n" + + "\n" + + " http://www.apache.org/licenses/LICENSE-2.0\n" + + "\n" + + "Unless required by applicable law or agreed to in writing, software\n" + + "distributed under the License is distributed on an \"AS IS\" BASIS,\n" + + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + + "See the License for the specific language governing permissions and\n" + + "limitations under the License.\n" + + "=======================================================================*/" + + "\n"; + + private static final String HELP_TEXT = "Args should be: [base_package]"; + + private static final int API_DEF_FIELD_NUMBER = 100; + + /** + * Args should be {@code [base_package]}. + * + *

{@code base_package} is {@code org.tensorflow.op} by default. + * + *

Will delete everything in {@code outputDir}. + */ + public static void main(String[] args) { + + if (args[0].equals("--help")) { + System.out.println(HELP_TEXT); + return; + } + + if (args.length < 2) { + System.err.println(HELP_TEXT); + System.exit(1); + } + + File outputDir = new File(args[0]); + File inputFile = new File(args[1]); + + if (!inputFile.exists()) { + System.err.println("Op def file " + inputFile + " does not exist."); + System.exit(1); + } + + if (!inputFile.isFile()) { + System.err.println("Op def file " + inputFile + " is not a file."); + System.exit(1); + } + + if (!inputFile.canRead()) { + System.err.println("Can't read Op def file " + inputFile + "."); + System.exit(1); + } + + String packageName = "org.tensorflow.op"; + + if (args.length > 2) { + packageName = args[2]; + } + + File basePackage = new File(outputDir, packageName.replace('.', '/')); + + if (basePackage.exists()) { + try { + Files.walkFileTree( + basePackage.toPath(), + new SimpleFileVisitor() { + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) + throws IOException { + Files.delete(file); + return FileVisitResult.CONTINUE; + } + + @Override + public FileVisitResult postVisitDirectory(Path dir, IOException exc) + throws IOException { + Files.delete(dir); + return FileVisitResult.CONTINUE; + } + }); + } catch (IOException ignored) { + + } + } + + generate(outputDir, packageName, inputFile); + } + + /** Build the list of ops and api defs, then call {@link #generate(File, String, Map)} */ + private static void generate(File outputDir, String packageName, File opDefs) { + OpList opList = null; + try { + opList = OpList.parseFrom(new FileInputStream(opDefs)); + } catch (FileNotFoundException e) { + // already checked in main, shouldn't happen here + System.err.println("Op def file " + opDefs + " not found."); + System.exit(1); + } catch (IOException e) { + throw new RuntimeException( + "Error parsing op def file " + + opDefs + + ", was it generated by op_export_main.cc (in tensorflow-core-api)?", + e); + } + Map defs = new LinkedHashMap<>(opList.getOpCount()); + + for (OpDef op : opList.getOpList()) { + try { + if (!op.getUnknownFields().hasField(API_DEF_FIELD_NUMBER)) { + throw new RuntimeException( + "No attached ApiDef for op " + + op.getName() + + ", was " + + opDefs + + " generated by op_export_main.cc (in tensorflow-core-api)? The op's ApiDef should be" + + " attached as unknown field " + + API_DEF_FIELD_NUMBER + + "."); + } + ApiDef api = + ApiDef.parseFrom( + op.getUnknownFields() + .getField(API_DEF_FIELD_NUMBER) + .getLengthDelimitedList() + .get(0)); + defs.put(op, api); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException( + "Could not parse attached ApiDef for op " + + op.getName() + + ", was " + + opDefs + + " generated by op_export_main.cc (in tensorflow-core-api)?", + e); + } + } + + generate(outputDir, packageName, defs); + } + + /** Generate all the ops that pass {@link ClassGenerator#canGenerateOp(OpDef, ApiDef)}. */ + private static void generate(File outputDir, String basePackage, Map ops) { + ops.entrySet().stream() + .filter(e -> ClassGenerator.canGenerateOp(e.getKey(), e.getValue())) + .forEach( + (entry) -> { + entry + .getValue() + .getEndpointList() + .forEach( + (endpoint) -> { + String name; + String pack; + + int pos = endpoint.getName().lastIndexOf('.'); + if (pos > -1) { + pack = endpoint.getName().substring(0, pos); + name = endpoint.getName().substring(pos + 1); + } else { + pack = "core"; + name = endpoint.getName(); + } + + TypeSpec.Builder cls = TypeSpec.classBuilder(name); + try { + new ClassGenerator( + cls, + entry.getKey(), + entry.getValue(), + basePackage, + basePackage + "." + pack, + pack, + name, + endpoint) + .buildClass(); + } catch (Exception e) { + throw new IllegalStateException( + "Failed to generate class for op " + entry.getKey().getName(), e); + } + TypeSpec spec = cls.build(); + + JavaFile file = + JavaFile.builder(basePackage + "." + pack, spec) + .indent(" ") + .skipJavaLangImports(true) + .build(); + + File outputFile = + new File( + outputDir, + basePackage.replace('.', '/') + + '/' + + pack.replace('.', '/') + + '/' + + spec.name + + ".java"); + outputFile.getParentFile().mkdirs(); + try { + StringBuilder builder = new StringBuilder(); + builder.append(LICENSE + '\n'); + builder.append("// This class has been generated, DO NOT EDIT!\n\n"); + file.writeTo(builder); + + Files.write( + outputFile.toPath(), + builder.toString().getBytes(StandardCharsets.UTF_8), + StandardOpenOption.WRITE, + StandardOpenOption.CREATE, + StandardOpenOption.TRUNCATE_EXISTING); + } catch (IOException ioException) { + throw new IllegalStateException( + "Failed to write file " + outputFile, ioException); + } + }); + }); + } +} diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ResolvedType.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ResolvedType.java new file mode 100644 index 00000000000..65d4547e263 --- /dev/null +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ResolvedType.java @@ -0,0 +1,223 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ +package org.tensorflow.op.generator; + +import com.squareup.javapoet.ArrayTypeName; +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.ParameterizedTypeName; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.TypeVariableName; +import com.squareup.javapoet.WildcardTypeName; +import org.tensorflow.Names; + +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.StringJoiner; + +/** Holds type information for inputs, outputs, or attributes, and provides utilities. */ +final class ResolvedType { + + /** The java level type. */ + final TypeName javaType; + + /** The type for jni/attribute setting use. */ + final TypeName jniType; + + /** + * Whether this type should be made iterable when used. + * + *

See {@link #arrayIfIterable()}, {@link #listIfIterable()}, {@link #iterableIfIterable()}. + */ + final boolean iterable; + + ResolvedType(TypeName javaType, TypeName jniType, boolean iterable) { + if (javaType == null) { + throw new NullPointerException("Can't create with a null javaType"); + } + if (jniType == null) { + throw new NullPointerException("Can't create with a null jniType"); + } + + this.javaType = javaType; + this.jniType = jniType; + this.iterable = iterable; + } + + ResolvedType(TypeName javaType, TypeName jniType) { + this(javaType, jniType, false); + } + + ResolvedType(TypeName type, boolean iterable) { + if (type == null) { + throw new NullPointerException("Can't create with a null type"); + } + + if (type.isPrimitive()) { + this.javaType = type.box(); + jniType = type; + } else { + this.javaType = type; + this.jniType = type; + } + this.iterable = iterable; + } + + ResolvedType(TypeName type) { + this(type, false); + } + + ResolvedType(Class javaType, Class jniType, boolean iterable) { + this(TypeName.get(javaType), TypeName.get(jniType), iterable); + } + + ResolvedType(Class javaType, Class jniType) { + this(TypeName.get(javaType), TypeName.get(jniType), false); + } + + ResolvedType(Class type, boolean iterable) { + this(TypeName.get(type), iterable); + } + + ResolvedType(Class type) { + this(type, false); + } + + /** Returns a copy of this type with the specified {@code iterable} value. */ + ResolvedType withIterable(boolean iterable) { + return new ResolvedType(javaType, jniType, iterable); + } + + /** Get the unboxed version of {@code javaType} if it is a boxed primitive. */ + TypeName unboxed() { + if (javaType.isBoxedPrimitive()) { + return javaType.unbox(); + } else { + return javaType; + } + } + + /** Return a copy, wrapping {@code javaType} in an array if this type is iterable. */ + ResolvedType arrayIfIterable() { + TypeName newJType; + if (iterable) { + newJType = ArrayTypeName.of(javaType); + } else { + newJType = javaType; + } + return new ResolvedType(newJType, jniType, iterable); + } + + /** Return a copy, wrapping {@code javaType} in {@link Iterable} if this type is iterable. */ + ResolvedType iterableIfIterable() { + TypeName newJType; + if (iterable) { + newJType = ParameterizedTypeName.get(ClassName.get(Iterable.class), javaType); + } else { + newJType = javaType; + } + return new ResolvedType(newJType, jniType, iterable); + } + + /** Return a copy, wrapping {@code javaType} in {@link List} if this type is iterable. */ + ResolvedType listIfIterable() { + TypeName newJType; + if (iterable) { + newJType = ParameterizedTypeName.get(ClassName.get(List.class), javaType); + } else { + newJType = javaType; + } + return new ResolvedType(newJType, jniType, iterable); + } + + /** True if wrapping will be done by {@link #classIfGeneric()} */ + boolean shouldWrapInClass() { + return javaType instanceof TypeVariableName || javaType instanceof WildcardTypeName; + } + + /** + * Return a copy, wrapping {@code javaType} in {@link Class} if it is a single type variable or a + * wildcard. + */ + ResolvedType classIfGeneric() { + TypeName newJType; + if (javaType instanceof TypeVariableName || javaType instanceof WildcardTypeName) { + newJType = ParameterizedTypeName.get(ClassName.get(Class.class), javaType); + } else { + newJType = javaType; + } + return new ResolvedType(newJType, jniType, iterable); + } + + /** Recursively get all type variable names in {@code javaType}. */ + Set findGenerics() { + if (javaType instanceof TypeVariableName) { + return Collections.singleton((TypeVariableName) javaType); + } else if (javaType instanceof ParameterizedTypeName) { + Set names = new LinkedHashSet<>(); + for (TypeName t : ((ParameterizedTypeName) javaType).typeArguments) { + names.addAll(new ResolvedType(t).findGenerics()); + } + return names; + } + return Collections.emptySet(); + } + + /** + * Return the type argument if {@code javaType} is {@code Operand} or {@code Output}, or return + * {@code javaType}. + */ + TypeName unwrapArg() { + if (javaType instanceof ParameterizedTypeName) { + ParameterizedTypeName pType = (ParameterizedTypeName) javaType; + if (pType.rawType.equals(Names.Operand) || pType.rawType.equals(Names.Output)) { + return pType.typeArguments.get(0); + } + } + return javaType; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ResolvedType that = (ResolvedType) o; + return iterable == that.iterable + && Objects.equals(javaType, that.javaType) + && Objects.equals(jniType, that.jniType); + } + + @Override + public int hashCode() { + return Objects.hash(javaType, jniType, iterable); + } + + @Override + public String toString() { + return new StringJoiner(", ", ResolvedType.class.getSimpleName() + "(", ")") + .add("javaType=" + javaType) + .add("jniType=" + jniType) + .add("iterable=" + iterable) + .toString(); + } +} diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/TypeResolver.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/TypeResolver.java new file mode 100644 index 00000000000..f24813598ce --- /dev/null +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/TypeResolver.java @@ -0,0 +1,313 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ +package org.tensorflow.op.generator; + +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.ParameterizedTypeName; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.TypeVariableName; +import com.squareup.javapoet.WildcardTypeName; +import org.tensorflow.Names; +import org.tensorflow.proto.framework.AttrValue; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.proto.framework.OpDef; +import org.tensorflow.proto.framework.OpDef.ArgDef; +import org.tensorflow.proto.framework.OpDef.AttrDef; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** + * A utility class to handle type calculations for a {@link ClassGenerator}. Should be one to one + * with {@link ClassGenerator} instances. + */ +final class TypeResolver { + + static TypeName WILDCARD = WildcardTypeName.subtypeOf(TypeName.OBJECT); + static TypeName STRING = TypeName.get(java.lang.String.class); + /** Data types that are real numbers. */ + private static final Set realNumberTypes = new HashSet<>(); + + static { + realNumberTypes.add(DataType.DT_FLOAT); + realNumberTypes.add(DataType.DT_DOUBLE); + realNumberTypes.add(DataType.DT_INT32); + realNumberTypes.add(DataType.DT_UINT8); + realNumberTypes.add(DataType.DT_INT16); + realNumberTypes.add(DataType.DT_INT8); + realNumberTypes.add(DataType.DT_INT64); + realNumberTypes.add(DataType.DT_QINT8); + realNumberTypes.add(DataType.DT_QUINT8); + realNumberTypes.add(DataType.DT_QINT32); + realNumberTypes.add(DataType.DT_BFLOAT16); + realNumberTypes.add(DataType.DT_QINT16); + realNumberTypes.add(DataType.DT_QUINT16); + realNumberTypes.add(DataType.DT_UINT16); + realNumberTypes.add(DataType.DT_HALF); + realNumberTypes.add(DataType.DT_UINT32); + realNumberTypes.add(DataType.DT_UINT64); + } + + /** The op def to get types for. */ + private final OpDef op; + /** The processed argument types. */ + private final Map argTypes = new HashMap<>(); + /** Known types. Not simply a cache. */ + private final Map known = new HashMap<>(); + /** + * Attributes that were reached while getting the types of inputs. + * + *

These are excluded from factory methods. + */ + private final Set reachedFromInput = new HashSet<>(); + private char nextGenericLetter = 'T'; + + TypeResolver(OpDef op) { + this.op = op; + calculateArgTypes(); + } + + /** Get the {@code TType} type for a datatype, or {@code ? extends TType} if there isn't one. */ + static TypeName forDataType(DataType dataType) { + switch (dataType) { + case DT_STRING: + return Names.TString; + case DT_BOOL: + return Names.TBool; + case DT_BFLOAT16: + return Names.TBfloat16; + case DT_HALF: + return Names.TFloat16; + case DT_FLOAT: + return Names.TFloat32; + case DT_DOUBLE: + return Names.TFloat64; + case DT_UINT8: + return Names.TUint8; + case DT_INT32: + return Names.TInt32; + case DT_INT64: + return Names.TInt64; + default: + return WildcardTypeName.subtypeOf(Names.TType); + } + } + + /** + * Calculate the types of the arguments. + * + *

Removes extraneous generics and wraps in Operand/Output. + */ + private void calculateArgTypes() { + Map genericCounts = new HashMap<>(); + + ArrayList args = new ArrayList<>(op.getInputArgCount() + op.getOutputArgCount()); + Set outputs = new HashSet<>(op.getOutputArgList()); + args.addAll(op.getInputArgList()); + args.addAll(op.getOutputArgList()); + + for (ArgDef arg : args) { + ResolvedType type = calculateTypeOf(arg); + argTypes.put(arg, type); + if (type.javaType instanceof TypeVariableName) { + String name = ((TypeVariableName) type.javaType).name; + + // can't remove output types, as the class will depend on them. + int incr = outputs.contains(arg) ? 2 : 1; + + genericCounts.put(name, genericCounts.getOrDefault(name, 0) + incr); + } + } + + for (Map.Entry entry : argTypes.entrySet()) { + if (entry.getValue().javaType instanceof TypeVariableName) { + TypeVariableName type = (TypeVariableName) entry.getValue().javaType; + if (genericCounts.get(type.name) <= 1 && type.bounds.size() == 1) { + entry.setValue( + new ResolvedType( + WildcardTypeName.subtypeOf(type.bounds.get(0)), entry.getValue().iterable)); + } + } + ClassName baseClass; + if (outputs.contains(entry.getKey())) { + baseClass = Names.Output; + } else { + baseClass = Names.Operand; + } + + entry.setValue( + new ResolvedType( + ParameterizedTypeName.get(baseClass, entry.getValue().javaType), + entry.getValue().iterable)); + } + } + + /** + * Returns true if the attribute with name {@code attrName} was reached while calculating input + * types. + */ + boolean partOfInput(String attrName) { + return reachedFromInput.contains(attrName); + } + + /** Get a new generic letter. */ + private TypeVariableName nextGeneric() { + char letter = nextGenericLetter++; + if (nextGenericLetter > 'Z') { + nextGenericLetter = 'A'; + } + return TypeVariableName.get(String.valueOf(letter)); + } + + /** Returns true if the attribute is a real number type or is limited to real number types. */ + private boolean isRealNumberTyped(AttrValue value) { + if (!value.hasList()) { + return realNumberTypes.contains(value.getType()); + } + for (DataType d : value.getList().getTypeList()) { + if (!realNumberTypes.contains(d)) { + return false; + } + } + return true; + } + + /** Get the family {@code TType} of an attribute, i.e. {@code TNumber} */ + private TypeName typeFamily(AttrDef attr) { + if (isRealNumberTyped(attr.getAllowedValues())) { + return Names.TNumber; + } + + return Names.TType; + } + + /** Get the type of an attribute. */ + ResolvedType typeOf(AttrDef attr) { + return typeOf(attr, false); + } + + /** + * Get the type of an attribute + * + * @param fromInput whether we're calculating input types and should add this attr to {@link + * #reachedFromInput} + */ + private ResolvedType typeOf(AttrDef attr, boolean fromInput) { + if (known.containsKey(attr.getName())) { + return known.get(attr.getName()); + } + + boolean iterable = false; + String typeName = attr.getType(); + if (typeName.startsWith("list(")) { + iterable = true; + typeName = typeName.substring(5, typeName.length() - 1); + } + + ResolvedType types; + + switch (typeName) { + case "string": + types = new ResolvedType(STRING); + break; + case "int": + types = new ResolvedType(TypeName.LONG); + break; + case "float": + types = new ResolvedType(TypeName.FLOAT); + break; + case "bool": + types = new ResolvedType(TypeName.BOOLEAN); + break; + case "shape": + types = new ResolvedType(Names.Shape); + break; + case "tensor": + types = new ResolvedType(Names.Tensor); + break; + case "type": + TypeName family = typeFamily(attr); + TypeName type = + iterable ? WildcardTypeName.subtypeOf(family) : nextGeneric().withBounds(family); + types = new ResolvedType(type, TypeName.get(DataType.class)); + break; + case "func": + types = new ResolvedType(Names.ConcreteFunction); + break; + default: + throw new IllegalArgumentException("No Java type for " + typeName); + } + + types = types.withIterable(iterable); + known.put(attr.getName(), types); + if (fromInput) { + reachedFromInput.add(attr.getName()); + } + return types; + } + + /** Get the type of an argument (calculated in the constructor) */ + ResolvedType typeOf(ArgDef arg) { + return argTypes.get(arg); + } + + /** + * Calculate the type of an argument. Should only be used when creating {@link #argTypes}. + * + *

Will add to {@link #reachedFromInput} if {@code arg} is an input of the op. + */ + private ResolvedType calculateTypeOf(ArgDef arg) { + + boolean isInput = op.getInputArgList().contains(arg); + + ResolvedType type = new ResolvedType(WILDCARD); + + if (arg.getType() != DataType.DT_INVALID) { + type = new ResolvedType(forDataType(arg.getType())); + } else if (!arg.getTypeAttr().isEmpty()) { + String typeAttr = arg.getTypeAttr(); + if (known.containsKey(typeAttr)) { + type = known.get(typeAttr); + } else { + AttrDef attr = + op.getAttrList().stream().filter(x -> x.getName().equals(typeAttr)).findFirst().get(); + type = typeOf(attr, isInput); + } + } else if (!arg.getTypeListAttr().isEmpty()) { + type = type.withIterable(true); + if (isInput) { + reachedFromInput.add(arg.getTypeListAttr()); + } + } else { + throw new IllegalArgumentException( + "Can't resolve type of argument " + arg.getName() + " in operation " + op.getName()); + } + + if (!arg.getNumberAttr().isEmpty()) { + type = type.withIterable(true); + if (isInput) { + reachedFromInput.add(arg.getNumberAttr()); + } + known.put(arg.getNumberAttr(), new ResolvedType(TypeName.LONG)); + } + + return type; + } +} diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/javadoc/CoreJavaDocNodeRenderer.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/javadoc/CoreJavaDocNodeRenderer.java new file mode 100644 index 00000000000..c5c7866af5c --- /dev/null +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/javadoc/CoreJavaDocNodeRenderer.java @@ -0,0 +1,536 @@ +package org.tensorflow.op.generator.javadoc; + +import com.google.common.base.CaseFormat; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; +import org.commonmark.node.AbstractVisitor; +import org.commonmark.node.BlockQuote; +import org.commonmark.node.BulletList; +import org.commonmark.node.Code; +import org.commonmark.node.Document; +import org.commonmark.node.Emphasis; +import org.commonmark.node.FencedCodeBlock; +import org.commonmark.node.HardLineBreak; +import org.commonmark.node.Heading; +import org.commonmark.node.HtmlBlock; +import org.commonmark.node.HtmlInline; +import org.commonmark.node.Image; +import org.commonmark.node.IndentedCodeBlock; +import org.commonmark.node.Link; +import org.commonmark.node.ListBlock; +import org.commonmark.node.ListItem; +import org.commonmark.node.Node; +import org.commonmark.node.OrderedList; +import org.commonmark.node.Paragraph; +import org.commonmark.node.SoftLineBreak; +import org.commonmark.node.StrongEmphasis; +import org.commonmark.node.Text; +import org.commonmark.node.ThematicBreak; +import org.commonmark.renderer.NodeRenderer; + +/** + * The node renderer that renders all the core nodes (comes last in the order of node renderers). + */ +public class CoreJavaDocNodeRenderer extends AbstractVisitor implements NodeRenderer { + + private static final String[] html5Tags = { + "", + "", + "

", + "", + "
", + "