Needs to be called after {@link #buildOptionsClass()}
+ */
+ private void buildGettersAndSetters() {
+ if (optionsClass != null) {
+ optionsClass.methodSpecs.stream()
+ .filter(x -> !x.isConstructor())
+ .forEach(
+ method -> {
+ String argName = method.parameters.get(0).name;
+ builder.addMethod(
+ MethodSpec.methodBuilder(method.name)
+ .addParameter(method.parameters.get(0))
+ .addJavadoc(method.javadoc)
+ .returns(ClassName.get(fullPackage, className, "Options"))
+ .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
+ .addCode("return new Options().$L($L);", method.name, argName)
+ .build());
+ });
+ }
+
+ for (ArgDef output : op.getOutputArgList()) {
+ String name = getJavaName(output);
+ ApiDef.Arg argDef = argApis.get(output);
+ builder.addMethod(
+ MethodSpec.methodBuilder(name)
+ .addModifiers(Modifier.PUBLIC)
+ .returns(resolver.typeOf(output).listIfIterable().javaType)
+ .addJavadoc("Gets $L.\n", name)
+ .addJavadoc("$L", parseDocumentation(argDef.getDescription()))
+ .addJavadoc("\n@return $L.", name)
+ .addCode("return $L;", name)
+ .build());
+ }
+
+ }
+
+ /**
+ * Add interface implementation methods if this is a single output or list output op.
+ */
+ private void buildInterfaceImpl() {
+ ArgDef output = op.getOutputArg(0);
+ TypeName type = resolver.typeOf(output).unwrapArg();
+
+ boolean uncheckedCast = type instanceof WildcardTypeName;
+ TypeName outputTType = uncheckedCast ? Names.TType : type;
+
+ if (mode == RenderMode.OPERAND) {
+ TypeName outputType = ParameterizedTypeName.get(Names.Output, outputTType);
+ MethodSpec.Builder asOutput = MethodSpec.methodBuilder("asOutput")
+ .addModifiers(Modifier.PUBLIC)
+ .returns(outputType)
+ .addAnnotation(Override.class);
+
+ if (uncheckedCast) {
+ asOutput.addAnnotation(
+ AnnotationSpec.builder(SuppressWarnings.class).addMember("value", "$S", "unchecked").build());
+ asOutput.addCode("return ($T) $L;", outputType, getJavaName(output));
+ } else {
+ asOutput.addCode("return $L;", getJavaName(output));
+ }
+
+ builder.addMethod(asOutput
+ .build());
+ } else {
+ TypeName operandType = ParameterizedTypeName.get(Names.Operand, outputTType);
+ TypeName returnType = ParameterizedTypeName.get(ClassName.get(Iterator.class), operandType);
+
+ MethodSpec.Builder iterator = MethodSpec.methodBuilder("iterator")
+ .addModifiers(Modifier.PUBLIC)
+ .returns(returnType)
+ .addAnnotation(Override.class)
+ .addAnnotation(
+ AnnotationSpec.builder(SuppressWarnings.class).addMember("value", "{$S, $S}", "rawtypes", "unchecked")
+ .build());
+
+ iterator.addCode("return ($T) $L.iterator();", Iterator.class, getJavaName(output));
+
+ builder.addMethod(iterator.build());
+ }
+ }
+
+ /**
+ * Add a constructor to get the outputs from an operation
+ */
+ private void buildConstructor() {
+ MethodSpec.Builder ctor = MethodSpec.constructorBuilder().addModifiers(Modifier.PRIVATE);
+
+ ctor.addParameter(Names.Operation, "operation");
+
+ for (ArgDef output : op.getOutputArgList()) {
+ ResolvedType type = resolver.typeOf(output);
+ if (type.iterable || type.unwrapArg() instanceof WildcardTypeName) {
+ ctor.addAnnotation(
+ AnnotationSpec.builder(SuppressWarnings.class).addMember("value", "$S", "unchecked").build());
+ break;
+ }
+ }
+ CodeBlock.Builder body = CodeBlock.builder();
+ body.addStatement("super(operation)");
+
+ if (op.getOutputArgCount() > 0) {
+ body.addStatement("int outputIdx = 0");
+ for (ArgDef output : op.getOutputArgList()) {
+ ResolvedType type = resolver.typeOf(output);
+ boolean iterable = type.iterable;
+ if (iterable) {
+ String lengthVar = getJavaName(output) + "Length";
+
+ body.addStatement("int $L = operation.outputListLength($S)", lengthVar, output.getName());
+
+ if (type.unwrapArg() instanceof WildcardTypeName) {
+ body.addStatement("$L = $T.asList(operation.outputList(outputIdx, $L))",
+ getJavaName(output),
+ Arrays.class,
+ lengthVar);
+ } else {
+ body.addStatement("$L = $T.asList(($T) operation.outputList(outputIdx, $L))",
+ getJavaName(output),
+ Arrays.class,
+ ArrayTypeName.of(type.javaType),
+ lengthVar);
+ }
+ body.addStatement("outputIdx += $L", lengthVar);
+
+ } else {
+ body.addStatement("$L = operation.output(outputIdx++)", getJavaName(output));
+ }
+ }
+ }
+
+ ctor.addCode(body.build());
+ builder.addMethod(ctor.build());
+ }
+
+}
diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java
new file mode 100644
index 00000000000..80d6698fe36
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/GeneratorUtils.java
@@ -0,0 +1,80 @@
+/*
+ Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================
+ */
+package org.tensorflow.op.generator;
+
+import org.commonmark.node.Node;
+import org.commonmark.parser.Parser;
+import org.tensorflow.op.generator.javadoc.JavaDocRenderer;
+import org.tensorflow.proto.framework.OpDef.ArgDef;
+
+/**
+ * Utilities for op generation
+ */
+final class GeneratorUtils {
+
+ private static final Parser parser = Parser.builder().build();
+
+ /**
+ * Convert a Python style name to a Java style name.
+ *
+ * Does snake_case -> camelCase and handles keywords.
+ *
+ * Not valid for class names, meant for fields and methods.
+ *
+ * Generally you should use {@link ClassGenerator#getJavaName(ArgDef)}.
+ */
+ static String javaizeMemberName(String name) {
+ StringBuilder result = new StringBuilder();
+ boolean capNext = Character.isUpperCase(name.charAt(0));
+ for (char c : name.toCharArray()) {
+ if (c == '_') {
+ capNext = true;
+ } else if (capNext) {
+ result.append(Character.toUpperCase(c));
+ capNext = false;
+ } else {
+ result.append(c);
+ }
+ }
+ name = result.toString();
+ switch (name) {
+ case "size":
+ return "sizeOutput";
+ case "if":
+ return "ifOutput";
+ case "while":
+ return "whileOutput";
+ case "for":
+ return "forOutput";
+ case "case":
+ return "caseOutput";
+ default:
+ return name;
+ }
+ }
+
+ /**
+ * Convert markdown descriptions to JavaDocs.
+ */
+ static String parseDocumentation(String docs) {
+ Node document = parser.parse(docs);
+ JavaDocRenderer renderer = JavaDocRenderer.builder().build();
+ return renderer.render(document).trim();
+ }
+
+
+}
diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java
new file mode 100644
index 00000000000..da4f63405cc
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/OpGenerator.java
@@ -0,0 +1,263 @@
+/*
+ Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================
+*/
+package org.tensorflow.op.generator;
+
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.squareup.javapoet.JavaFile;
+import com.squareup.javapoet.TypeSpec;
+import org.tensorflow.proto.framework.ApiDef;
+import org.tensorflow.proto.framework.OpDef;
+import org.tensorflow.proto.framework.OpList;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.FileVisitResult;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.SimpleFileVisitor;
+import java.nio.file.StandardOpenOption;
+import java.nio.file.attribute.BasicFileAttributes;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+public final class OpGenerator {
+
+ private static final String LICENSE =
+ "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
+ + "\n"
+ + "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
+ + "you may not use this file except in compliance with the License.\n"
+ + "You may obtain a copy of the License at\n"
+ + "\n"
+ + " http://www.apache.org/licenses/LICENSE-2.0\n"
+ + "\n"
+ + "Unless required by applicable law or agreed to in writing, software\n"
+ + "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
+ + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
+ + "See the License for the specific language governing permissions and\n"
+ + "limitations under the License.\n"
+ + "=======================================================================*/"
+ + "\n";
+
+ private static final String HELP_TEXT = "Args should be: {@code base_package} is {@code org.tensorflow.op} by default.
+ *
+ * Will delete everything in {@code outputDir}.
+ */
+ public static void main(String[] args) {
+
+ if (args[0].equals("--help")) {
+ System.out.println(HELP_TEXT);
+ return;
+ }
+
+ if (args.length < 2) {
+ System.err.println(HELP_TEXT);
+ System.exit(1);
+ }
+
+ File outputDir = new File(args[0]);
+ File inputFile = new File(args[1]);
+
+ if (!inputFile.exists()) {
+ System.err.println("Op def file " + inputFile + " does not exist.");
+ System.exit(1);
+ }
+
+ if (!inputFile.isFile()) {
+ System.err.println("Op def file " + inputFile + " is not a file.");
+ System.exit(1);
+ }
+
+ if (!inputFile.canRead()) {
+ System.err.println("Can't read Op def file " + inputFile + ".");
+ System.exit(1);
+ }
+
+ String packageName = "org.tensorflow.op";
+
+ if (args.length > 2) {
+ packageName = args[2];
+ }
+
+ File basePackage = new File(outputDir, packageName.replace('.', '/'));
+
+ if (basePackage.exists()) {
+ try {
+ Files.walkFileTree(
+ basePackage.toPath(),
+ new SimpleFileVisitor See {@link #arrayIfIterable()}, {@link #listIfIterable()}, {@link #iterableIfIterable()}.
+ */
+ final boolean iterable;
+
+ ResolvedType(TypeName javaType, TypeName jniType, boolean iterable) {
+ if (javaType == null) {
+ throw new NullPointerException("Can't create with a null javaType");
+ }
+ if (jniType == null) {
+ throw new NullPointerException("Can't create with a null jniType");
+ }
+
+ this.javaType = javaType;
+ this.jniType = jniType;
+ this.iterable = iterable;
+ }
+
+ ResolvedType(TypeName javaType, TypeName jniType) {
+ this(javaType, jniType, false);
+ }
+
+ ResolvedType(TypeName type, boolean iterable) {
+ if (type == null) {
+ throw new NullPointerException("Can't create with a null type");
+ }
+
+ if (type.isPrimitive()) {
+ this.javaType = type.box();
+ jniType = type;
+ } else {
+ this.javaType = type;
+ this.jniType = type;
+ }
+ this.iterable = iterable;
+ }
+
+ ResolvedType(TypeName type) {
+ this(type, false);
+ }
+
+ ResolvedType(Class> javaType, Class> jniType, boolean iterable) {
+ this(TypeName.get(javaType), TypeName.get(jniType), iterable);
+ }
+
+ ResolvedType(Class> javaType, Class> jniType) {
+ this(TypeName.get(javaType), TypeName.get(jniType), false);
+ }
+
+ ResolvedType(Class> type, boolean iterable) {
+ this(TypeName.get(type), iterable);
+ }
+
+ ResolvedType(Class> type) {
+ this(type, false);
+ }
+
+ /** Returns a copy of this type with the specified {@code iterable} value. */
+ ResolvedType withIterable(boolean iterable) {
+ return new ResolvedType(javaType, jniType, iterable);
+ }
+
+ /** Get the unboxed version of {@code javaType} if it is a boxed primitive. */
+ TypeName unboxed() {
+ if (javaType.isBoxedPrimitive()) {
+ return javaType.unbox();
+ } else {
+ return javaType;
+ }
+ }
+
+ /** Return a copy, wrapping {@code javaType} in an array if this type is iterable. */
+ ResolvedType arrayIfIterable() {
+ TypeName newJType;
+ if (iterable) {
+ newJType = ArrayTypeName.of(javaType);
+ } else {
+ newJType = javaType;
+ }
+ return new ResolvedType(newJType, jniType, iterable);
+ }
+
+ /** Return a copy, wrapping {@code javaType} in {@link Iterable} if this type is iterable. */
+ ResolvedType iterableIfIterable() {
+ TypeName newJType;
+ if (iterable) {
+ newJType = ParameterizedTypeName.get(ClassName.get(Iterable.class), javaType);
+ } else {
+ newJType = javaType;
+ }
+ return new ResolvedType(newJType, jniType, iterable);
+ }
+
+ /** Return a copy, wrapping {@code javaType} in {@link List} if this type is iterable. */
+ ResolvedType listIfIterable() {
+ TypeName newJType;
+ if (iterable) {
+ newJType = ParameterizedTypeName.get(ClassName.get(List.class), javaType);
+ } else {
+ newJType = javaType;
+ }
+ return new ResolvedType(newJType, jniType, iterable);
+ }
+
+ /** True if wrapping will be done by {@link #classIfGeneric()} */
+ boolean shouldWrapInClass() {
+ return javaType instanceof TypeVariableName || javaType instanceof WildcardTypeName;
+ }
+
+ /**
+ * Return a copy, wrapping {@code javaType} in {@link Class} if it is a single type variable or a
+ * wildcard.
+ */
+ ResolvedType classIfGeneric() {
+ TypeName newJType;
+ if (javaType instanceof TypeVariableName || javaType instanceof WildcardTypeName) {
+ newJType = ParameterizedTypeName.get(ClassName.get(Class.class), javaType);
+ } else {
+ newJType = javaType;
+ }
+ return new ResolvedType(newJType, jniType, iterable);
+ }
+
+ /** Recursively get all type variable names in {@code javaType}. */
+ Set These are excluded from factory methods.
+ */
+ private final Set Removes extraneous generics and wraps in Operand/Output.
+ */
+ private void calculateArgTypes() {
+ Map Will add to {@link #reachedFromInput} if {@code arg} is an input of the op.
+ */
+ private ResolvedType calculateTypeOf(ArgDef arg) {
+
+ boolean isInput = op.getInputArgList().contains(arg);
+
+ ResolvedType type = new ResolvedType(WILDCARD);
+
+ if (arg.getType() != DataType.DT_INVALID) {
+ type = new ResolvedType(forDataType(arg.getType()));
+ } else if (!arg.getTypeAttr().isEmpty()) {
+ String typeAttr = arg.getTypeAttr();
+ if (known.containsKey(typeAttr)) {
+ type = known.get(typeAttr);
+ } else {
+ AttrDef attr =
+ op.getAttrList().stream().filter(x -> x.getName().equals(typeAttr)).findFirst().get();
+ type = typeOf(attr, isInput);
+ }
+ } else if (!arg.getTypeListAttr().isEmpty()) {
+ type = type.withIterable(true);
+ if (isInput) {
+ reachedFromInput.add(arg.getTypeListAttr());
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Can't resolve type of argument " + arg.getName() + " in operation " + op.getName());
+ }
+
+ if (!arg.getNumberAttr().isEmpty()) {
+ type = type.withIterable(true);
+ if (isInput) {
+ reachedFromInput.add(arg.getNumberAttr());
+ }
+ known.put(arg.getNumberAttr(), new ResolvedType(TypeName.LONG));
+ }
+
+ return type;
+ }
+}
diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/javadoc/CoreJavaDocNodeRenderer.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/javadoc/CoreJavaDocNodeRenderer.java
new file mode 100644
index 00000000000..c5c7866af5c
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/javadoc/CoreJavaDocNodeRenderer.java
@@ -0,0 +1,536 @@
+package org.tensorflow.op.generator.javadoc;
+
+import com.google.common.base.CaseFormat;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Set;
+import org.commonmark.node.AbstractVisitor;
+import org.commonmark.node.BlockQuote;
+import org.commonmark.node.BulletList;
+import org.commonmark.node.Code;
+import org.commonmark.node.Document;
+import org.commonmark.node.Emphasis;
+import org.commonmark.node.FencedCodeBlock;
+import org.commonmark.node.HardLineBreak;
+import org.commonmark.node.Heading;
+import org.commonmark.node.HtmlBlock;
+import org.commonmark.node.HtmlInline;
+import org.commonmark.node.Image;
+import org.commonmark.node.IndentedCodeBlock;
+import org.commonmark.node.Link;
+import org.commonmark.node.ListBlock;
+import org.commonmark.node.ListItem;
+import org.commonmark.node.Node;
+import org.commonmark.node.OrderedList;
+import org.commonmark.node.Paragraph;
+import org.commonmark.node.SoftLineBreak;
+import org.commonmark.node.StrongEmphasis;
+import org.commonmark.node.Text;
+import org.commonmark.node.ThematicBreak;
+import org.commonmark.renderer.NodeRenderer;
+
+/**
+ * The node renderer that renders all the core nodes (comes last in the order of node renderers).
+ */
+public class CoreJavaDocNodeRenderer extends AbstractVisitor implements NodeRenderer {
+
+ private static final String[] html5Tags = {
+ "",
+ "",
+ "",
+ "",
+ "