Skip to content

Commit 3acc001

Browse files
committed
Use annotation instead of field reflection to store op types
Signed-off-by: Ryan Nett <[email protected]>
1 parent 2f22b57 commit 3acc001

File tree

6 files changed

+69
-26
lines changed

6 files changed

+69
-26
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
import static org.tensorflow.internal.c_api.global.tensorflow.TF_Version;
2424

2525
import com.google.protobuf.InvalidProtocolBufferException;
26-
import java.lang.reflect.Field;
27-
import java.lang.reflect.Modifier;
2826
import java.util.Collections;
2927
import java.util.IdentityHashMap;
3028
import java.util.Set;
@@ -43,6 +41,7 @@
4341
import org.tensorflow.op.RawGradientAdapter;
4442
import org.tensorflow.op.RawOp;
4543
import org.tensorflow.op.TypedGradientAdapter;
44+
import org.tensorflow.op.annotation.GeneratedOpMetadata;
4645
import org.tensorflow.op.math.Add;
4746
import org.tensorflow.proto.framework.OpList;
4847

@@ -190,41 +189,36 @@ public static synchronized boolean registerCustomGradient(
190189

191190
/**
192191
* Register a custom gradient function for ops of {@code opClass} type. The actual op type is
193-
* detected from the class's {@code OP_NAME} field.
192+
* detected from the class's {@link GeneratedOpMetadata} annotation. As such, it only works on
193+
* generated op classes.
194194
*
195195
* @param opClass the class of op to register the gradient for.
196196
* @param gradient the gradient function to use
197197
* @return {@code true} if the gradient was registered, {@code false} if there was already a
198198
* gradient registered for this op
199-
* @throws IllegalArgumentException if {@code opClass} does not have a static {@code OP_NAME}
199+
* @throws IllegalArgumentException if {@code opClass} does not have a {@link GeneratedOpMetadata}
200200
* field.
201201
*/
202202
public static synchronized <T extends RawOp> boolean registerCustomGradient(
203203
Class<T> opClass, CustomGradient<T> gradient) {
204-
try {
205-
Field nameField = opClass.getDeclaredField("OP_NAME");
206-
207-
if (!Modifier.isStatic(nameField.getModifiers())) {
208-
throw new IllegalArgumentException(
209-
"Class " + opClass + " has an OP_NAME field, but it is not static.");
210-
}
211-
212-
String opType = (String) nameField.get(null);
204+
GeneratedOpMetadata metadata = opClass.getAnnotation(GeneratedOpMetadata.class);
213205

214-
if (hasGradient(opType)) {
215-
return false;
216-
}
217-
218-
GradFunc g = new TypedGradientAdapter<>(gradient, opClass);
219-
GradOpRegistry.Global().Register(opType, g);
220-
gradientFuncs.add(g);
221-
} catch (IllegalAccessException | NoSuchFieldException e) {
206+
if (metadata == null) {
222207
throw new IllegalArgumentException(
223-
"Could not get OP_NAME field for class "
208+
"Class "
224209
+ opClass
225-
+ ", ensure it is a generated op class",
226-
e);
210+
+ " does not have a GeneratedOpMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug.");
211+
}
212+
213+
String opType = metadata.opType();
214+
215+
if (hasGradient(opType)) {
216+
return false;
227217
}
218+
219+
GradFunc g = new TypedGradientAdapter<>(gradient, opClass);
220+
GradOpRegistry.Global().Register(opType, g);
221+
gradientFuncs.add(g);
228222
return true;
229223
}
230224
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
=======================================================================
16+
17+
*/
18+
package org.tensorflow.op.annotation;
19+
20+
import java.lang.annotation.ElementType;
21+
import java.lang.annotation.Retention;
22+
import java.lang.annotation.RetentionPolicy;
23+
import java.lang.annotation.Target;
24+
import java.util.function.Function;
25+
import org.tensorflow.Operation;
26+
import org.tensorflow.op.RawOp;
27+
28+
/**
29+
* An annotation that should only be used by codegeneration. Used to provide some metadata about the
30+
* op.
31+
*
32+
* <p><b>DO NOT USE MANUALLY</b>
33+
*/
34+
@Target(ElementType.TYPE)
35+
@Retention(RetentionPolicy.RUNTIME)
36+
public @interface GeneratedOpMetadata {
37+
String opType();
38+
}

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ public class Names {
2929

3030
public static final ClassName Operator = ClassName.get(OpPackage + ".annotation", "Operator");
3131
public static final ClassName Endpoint = ClassName.get(OpPackage + ".annotation", "Endpoint");
32+
public static final ClassName GeneratedOpMetadata =
33+
ClassName.get(OpPackage + ".annotation", "GeneratedOpMetadata");
3234

3335
public static final ClassName TType = ClassName.get(TypesPackage + ".family", "TType");
3436
public static final ClassName TString = ClassName.get(TypesPackage, "TString");

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/op/generator/ClassGenerator.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ void buildClass() {
220220
if (!isStateSelector) {
221221
builder.addModifiers(Modifier.FINAL);
222222
builder.superclass(Names.RawOp);
223+
addMetadataAnnotation();
223224
}
224225

225226
if (isStateSubclass) {
@@ -1031,4 +1032,12 @@ private void buildInputsClass() {
10311032
inputsBuilder.addTypeVariables(typeVars);
10321033
this.builder.addType(inputsBuilder.build());
10331034
}
1035+
1036+
/** Adds the GeneratedOpMetadata annotation */
1037+
private void addMetadataAnnotation() {
1038+
builder.addAnnotation(
1039+
AnnotationSpec.builder(Names.GeneratedOpMetadata)
1040+
.addMember("opType", "$L", className + ".OP_NAME")
1041+
.build());
1042+
}
10341043
}

tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import org.tensorflow.DeviceSpec;
1818
import org.tensorflow.EagerSession;
1919
import org.tensorflow.ExecutionEnvironment;
20-
import org.tensorflow.op.OpScope;
2120
import org.tensorflow.op.Op;
21+
import org.tensorflow.op.OpScope;
2222
import org.tensorflow.op.Ops;
2323
import org.tensorflow.op.Scope;
2424

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
import org.tensorflow.Operand;
2626
import org.tensorflow.Operation;
2727
import org.tensorflow.Output;
28-
import org.tensorflow.op.OpScope;
2928
import org.tensorflow.op.Op;
29+
import org.tensorflow.op.OpScope;
3030
import org.tensorflow.op.Ops;
3131
import org.tensorflow.op.Scope;
3232
import org.tensorflow.op.core.NoOp;

0 commit comments

Comments
 (0)