Skip to content

Commit 5127413

Browse files
committed
Cleanup and more review changes
Signed-off-by: Ryan Nett <[email protected]>
1 parent 3acc001 commit 5127413

File tree

16 files changed

+225
-685
lines changed

16 files changed

+225
-685
lines changed

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java

Lines changed: 39 additions & 618 deletions
Large diffs are not rendered by default.

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GradientAdapterHelpers.java renamed to tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/BaseGradientAdapter.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,14 @@
1919
import java.util.ArrayList;
2020
import java.util.List;
2121
import org.bytedeco.javacpp.PointerScope;
22+
import org.tensorflow.internal.c_api.GradFunc;
2223
import org.tensorflow.internal.c_api.NativeOutput;
2324
import org.tensorflow.internal.c_api.NativeOutputVector;
2425
import org.tensorflow.internal.c_api.Node;
2526
import org.tensorflow.internal.c_api.TF_Operation;
26-
import org.tensorflow.op.RawGradientAdapter;
27-
import org.tensorflow.op.TypedGradientAdapter;
2827

29-
/** Helpers for {@link TypedGradientAdapter} and {@link RawGradientAdapter}. */
30-
public class GradientAdapterHelpers {
28+
/** Helper base class for custom gradient adapters <b>INTERNAL USE ONLY</b> */
29+
public abstract class BaseGradientAdapter extends GradFunc {
3130

3231
/**
3332
* Convert an array of native outputs to a list of {@link Output}s.
@@ -36,7 +35,7 @@ public class GradientAdapterHelpers {
3635
* @param nativeOutputs the native outputs to convert
3736
* @return a list of Outputs
3837
*/
39-
public static List<Output<?>> fromNativeOutputs(Graph g, NativeOutputVector nativeOutputs) {
38+
protected static List<Output<?>> fromNativeOutputs(Graph g, NativeOutputVector nativeOutputs) {
4039
List<Output<?>> gradInputs = new ArrayList<>((int) nativeOutputs.size());
4140
for (int i = 0; i < nativeOutputs.size(); i++) {
4241
NativeOutput output = nativeOutputs.get(i);
@@ -51,7 +50,7 @@ public static List<Output<?>> fromNativeOutputs(Graph g, NativeOutputVector nati
5150
* @param outputs the outputs to put
5251
* @param nativeOutputs the native array to put the outputs into
5352
*/
54-
public static void putToNativeOutputs(
53+
protected static void putToNativeOutputs(
5554
List<Operand<?>> outputs, NativeOutputVector nativeOutputs) {
5655
nativeOutputs.resize(outputs.size());
5756
for (int i = 0; i < outputs.size(); i++) {
@@ -68,14 +67,14 @@ public static void putToNativeOutputs(
6867
* @param node the native node
6968
* @return a graph operation with the underlying native node
7069
*/
71-
public static GraphOperation getGraphOp(Graph g, Node node) {
70+
protected static GraphOperation getGraphOp(Graph g, Node node) {
7271
try (PointerScope scope = new PointerScope();
7372
Graph.Reference ref = g.ref()) {
7473
return new GraphOperation(g, new TF_Operation(node));
7574
}
7675
}
7776

78-
public static void useDangerousLockedBuilders(Graph g, boolean dangerous) {
77+
protected static void useDangerousLockedBuilders(Graph g, boolean dangerous) {
7978
g.setDangerousGradientBuilder(dangerous);
8079
}
8180
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,8 @@
3838
import org.tensorflow.internal.c_api.TF_Status;
3939
import org.tensorflow.op.CustomGradient;
4040
import org.tensorflow.op.RawCustomGradient;
41-
import org.tensorflow.op.RawGradientAdapter;
42-
import org.tensorflow.op.RawOp;
43-
import org.tensorflow.op.TypedGradientAdapter;
41+
import org.tensorflow.op.RawOpInputs;
42+
import org.tensorflow.op.annotation.GeneratedOpInputsMetadata;
4443
import org.tensorflow.op.annotation.GeneratedOpMetadata;
4544
import org.tensorflow.op.math.Add;
4645
import org.tensorflow.proto.framework.OpList;
@@ -181,7 +180,7 @@ public static synchronized boolean registerCustomGradient(
181180
if (hasGradient(opType)) {
182181
return false;
183182
}
184-
GradFunc g = new RawGradientAdapter(gradient);
183+
GradFunc g = RawCustomGradient.adapter(gradient);
185184
GradOpRegistry.Global().Register(opType, g);
186185
gradientFuncs.add(g);
187186
return true;
@@ -199,24 +198,33 @@ public static synchronized boolean registerCustomGradient(
199198
* @throws IllegalArgumentException if {@code opClass} does not have a {@link GeneratedOpMetadata}
200199
* field.
201200
*/
202-
public static synchronized <T extends RawOp> boolean registerCustomGradient(
201+
public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGradient(
203202
Class<T> opClass, CustomGradient<T> gradient) {
204-
GeneratedOpMetadata metadata = opClass.getAnnotation(GeneratedOpMetadata.class);
203+
GeneratedOpInputsMetadata metadata = opClass.getAnnotation(GeneratedOpInputsMetadata.class);
205204

206205
if (metadata == null) {
207206
throw new IllegalArgumentException(
208-
"Class "
207+
"Inputs Class "
209208
+ opClass
209+
+ " does not have a GeneratedOpInputsMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug.");
210+
}
211+
GeneratedOpMetadata outputMetadata =
212+
metadata.outputsClass().getAnnotation(GeneratedOpMetadata.class);
213+
214+
if (outputMetadata == null) {
215+
throw new IllegalArgumentException(
216+
"Op Class "
217+
+ metadata.outputsClass()
210218
+ " does not have a GeneratedOpMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug.");
211219
}
212220

213-
String opType = metadata.opType();
221+
String opType = outputMetadata.opType();
214222

215223
if (hasGradient(opType)) {
216224
return false;
217225
}
218226

219-
GradFunc g = new TypedGradientAdapter<>(gradient, opClass);
227+
GradFunc g = CustomGradient.adapter(gradient, opClass);
220228
GradOpRegistry.Global().Register(opType, g);
221229
gradientFuncs.add(g);
222230
return true;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
"tensorflow_adapters.h",
5858
"tensorflow/c/eager/c_api.h",
5959
"tensorflow/c/eager/c_api_experimental.h",
60-
"tensorflow/cc/framework/scope.h",
60+
"tensorflow/cc/framework/scope.h",
6161
"tensorflow/cc/framework/grad_op_registry.h",
6262
"tensorflow/core/platform/status.h",
6363
"tensorflow/core/graph/graph.h",
@@ -395,15 +395,11 @@ public void map(InfoMap infoMap) {
395395
new Info("TF_ImportGraphDefOptions")
396396
.pointerTypes("TF_ImportGraphDefOptions")
397397
.base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions"))
398-
.put(
399-
new Info(
400-
"TF_WhileParams")
401-
.purify())
398+
.put(new Info("TF_WhileParams").purify())
402399
.put(new Info("TF_Operation").purify())
403400
.put(
404401
new Info("TF_Operation::node")
405402
.javaText("public native @MemberGetter @ByRef Node node();"))
406-
407403
.put(
408404
new Info("TFE_Context")
409405
.pointerTypes("TFE_Context")

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@
2020
import org.tensorflow.Operand;
2121
import org.tensorflow.Output;
2222
import org.tensorflow.TensorFlow;
23+
import org.tensorflow.internal.c_api.GradFunc;
2324

2425
/**
2526
* A custom gradient for ops of type {@link T}. Should be registered using {@link
2627
* TensorFlow#registerCustomGradient(Class, CustomGradient)}.
2728
*
2829
* @param <T> the type of op this gradient is for.
2930
*/
31+
@SuppressWarnings("rawtypes")
3032
@FunctionalInterface
31-
public interface CustomGradient<T extends RawOp> {
33+
public interface CustomGradient<T extends RawOpInputs> {
3234

3335
/**
3436
* Calculate the gradients for {@code op}.
@@ -39,4 +41,15 @@ public interface CustomGradient<T extends RawOp> {
3941
* @return the gradients of the op's inputs.
4042
*/
4143
List<Operand<?>> call(Ops tf, T op, List<Output<?>> gradInputs);
44+
45+
/**
46+
* Create an adapter for the custom gradient so that it can be used by native code.
47+
*
48+
* <p>You should not be calling this yourself, use {@link TensorFlow#registerCustomGradient(Class,
49+
* CustomGradient)}.
50+
*/
51+
public static <T extends RawOpInputs<?>> GradFunc adapter(
52+
CustomGradient<T> gradient, Class<T> opClass) {
53+
return new TypedGradientAdapter<T>(gradient, opClass);
54+
}
4255
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.tensorflow.Operand;
2222
import org.tensorflow.Output;
2323
import org.tensorflow.TensorFlow;
24+
import org.tensorflow.internal.c_api.GradFunc;
2425

2526
/**
2627
* A custom gradient for an op of unspecified type. Should be registered using {@link
@@ -42,4 +43,14 @@ public interface RawCustomGradient {
4243
* @return the gradients of the op's inputs.
4344
*/
4445
List<Operand<?>> call(Ops tf, GraphOperation op, List<Output<?>> gradInputs);
46+
47+
/**
48+
* Create an adapter for the custom gradient so that it can be used by native code.
49+
*
50+
* <p>You should not be calling this yourself, use {@link
51+
* TensorFlow#registerCustomGradient(String, RawCustomGradient)}.
52+
*/
53+
public static GradFunc adapter(RawCustomGradient gradient) {
54+
return new RawGradientAdapter(gradient);
55+
}
4556
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawGradientAdapter.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,32 @@
1212
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
15-
==============================================================================
15+
=======================================================================
16+
1617
*/
1718
package org.tensorflow.op;
1819

1920
import static org.tensorflow.internal.c_api.global.tensorflow.StatusFromTF_Status;
2021

2122
import java.util.List;
2223
import org.bytedeco.javacpp.PointerScope;
23-
import org.tensorflow.GradientAdapterHelpers;
24+
import org.tensorflow.BaseGradientAdapter;
2425
import org.tensorflow.Graph;
2526
import org.tensorflow.GraphOperation;
2627
import org.tensorflow.Operand;
2728
import org.tensorflow.Output;
28-
import org.tensorflow.internal.c_api.GradFunc;
2929
import org.tensorflow.internal.c_api.NativeOperation;
3030
import org.tensorflow.internal.c_api.NativeOutputVector;
3131
import org.tensorflow.internal.c_api.NativeStatus;
3232
import org.tensorflow.internal.c_api.TF_Scope;
3333
import org.tensorflow.internal.c_api.TF_Status;
3434

3535
/** A native adapter for {@link RawCustomGradient}. */
36-
public final class RawGradientAdapter extends GradFunc {
36+
final class RawGradientAdapter extends BaseGradientAdapter {
3737

3838
private final RawCustomGradient gradient;
3939

40-
public RawGradientAdapter(RawCustomGradient gradient) {
40+
RawGradientAdapter(RawCustomGradient gradient) {
4141
super();
4242
this.gradient = gradient;
4343
}
@@ -57,15 +57,15 @@ public NativeStatus call(
5757
Scope nativeScope = new GradientScope(scope, g);
5858
Ops tf = new Ops(nativeScope);
5959

60-
List<Output<?>> gradInputs = GradientAdapterHelpers.fromNativeOutputs(g, grad_inputs);
60+
List<Output<?>> gradInputs = BaseGradientAdapter.fromNativeOutputs(g, grad_inputs);
6161

62-
GraphOperation operation = GradientAdapterHelpers.getGraphOp(g, op.node());
62+
GraphOperation operation = BaseGradientAdapter.getGraphOp(g, op.node());
6363

64-
GradientAdapterHelpers.useDangerousLockedBuilders(g, true);
64+
BaseGradientAdapter.useDangerousLockedBuilders(g, true);
6565
List<Operand<?>> gradOutputs = gradient.call(tf, operation, gradInputs);
66-
GradientAdapterHelpers.useDangerousLockedBuilders(g, false);
66+
BaseGradientAdapter.useDangerousLockedBuilders(g, false);
6767

68-
GradientAdapterHelpers.putToNativeOutputs(gradOutputs, grad_outputs);
68+
BaseGradientAdapter.putToNativeOutputs(gradOutputs, grad_outputs);
6969
} catch (Throwable t) {
7070
t.printStackTrace();
7171
throw t;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawOp.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,19 @@ public final String toString() {
5959
* Constructor.
6060
*
6161
* @param operation the underlying operation
62+
* @param requiredType the type that the underlying operation must be
6263
*/
63-
protected RawOp(Operation operation) {
64+
protected RawOp(Operation operation, String requiredType) {
65+
if (!requiredType.equals(operation.type())) {
66+
throw new IllegalArgumentException(
67+
"Can't create a "
68+
+ this.getClass()
69+
+ " from an operation with type \""
70+
+ operation.type()
71+
+ "\", operation must have type \""
72+
+ requiredType
73+
+ "\".");
74+
}
6475
this.operation = operation;
6576
}
6677
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/TypedGradientAdapter.java

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
15-
==============================================================================
15+
=======================================================================
16+
1617
*/
1718
package org.tensorflow.op;
1819

@@ -22,31 +23,29 @@
2223
import java.lang.reflect.InvocationTargetException;
2324
import java.util.List;
2425
import org.bytedeco.javacpp.PointerScope;
25-
import org.tensorflow.GradientAdapterHelpers;
26+
import org.tensorflow.BaseGradientAdapter;
2627
import org.tensorflow.Graph;
2728
import org.tensorflow.Operand;
2829
import org.tensorflow.Output;
29-
import org.tensorflow.internal.c_api.GradFunc;
3030
import org.tensorflow.internal.c_api.NativeOperation;
3131
import org.tensorflow.internal.c_api.NativeOutputVector;
3232
import org.tensorflow.internal.c_api.NativeStatus;
3333
import org.tensorflow.internal.c_api.TF_Scope;
3434
import org.tensorflow.internal.c_api.TF_Status;
3535

3636
/** A native adapter for {@link CustomGradient}. */
37-
public final class TypedGradientAdapter<T extends RawOp> extends GradFunc {
37+
final class TypedGradientAdapter<T extends RawOpInputs<?>> extends BaseGradientAdapter {
3838

3939
private final CustomGradient<T> gradient;
40-
private final Class<T> opClass;
40+
private final Class<T> opInputClass;
4141
private final Constructor<T> ctor;
4242

43-
public TypedGradientAdapter(CustomGradient<T> gradient, Class<T> opClass) {
43+
TypedGradientAdapter(CustomGradient<T> gradient, Class<T> opInputClass) {
4444
super();
4545
this.gradient = gradient;
46-
this.opClass = opClass;
46+
this.opInputClass = opInputClass;
4747
//noinspection unchecked
48-
this.ctor = (Constructor<T>) opClass.getDeclaredConstructors()[0];
49-
ctor.setAccessible(true);
48+
this.ctor = (Constructor<T>) this.opInputClass.getDeclaredConstructors()[0];
5049
}
5150

5251
@Override
@@ -65,18 +64,19 @@ public NativeStatus call(
6564

6665
Ops tf = new Ops(nativeScope);
6766

68-
List<Output<?>> gradInputs = GradientAdapterHelpers.fromNativeOutputs(g, grad_inputs);
67+
List<Output<?>> gradInputs = BaseGradientAdapter.fromNativeOutputs(g, grad_inputs);
6968

70-
T rawOp = ctor.newInstance(GradientAdapterHelpers.getGraphOp(g, op.node()));
69+
T rawOp = ctor.newInstance(BaseGradientAdapter.getGraphOp(g, op.node()));
7170

72-
GradientAdapterHelpers.useDangerousLockedBuilders(g, true);
71+
BaseGradientAdapter.useDangerousLockedBuilders(g, true);
7372
List<Operand<?>> gradOutputs = gradient.call(tf, rawOp, gradInputs);
74-
GradientAdapterHelpers.useDangerousLockedBuilders(g, false);
73+
BaseGradientAdapter.useDangerousLockedBuilders(g, false);
7574

76-
GradientAdapterHelpers.putToNativeOutputs(gradOutputs, grad_outputs);
75+
BaseGradientAdapter.putToNativeOutputs(gradOutputs, grad_outputs);
7776

7877
} catch (InvocationTargetException | InstantiationException | IllegalAccessException e) {
79-
RuntimeException re = new RuntimeException("Could not instantiate Op class " + opClass, e);
78+
RuntimeException re =
79+
new RuntimeException("Could not instantiate Op class " + opInputClass, e);
8080
re.printStackTrace();
8181
throw re;
8282
} catch (Throwable t) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
=======================================================================
16+
17+
*/
18+
package org.tensorflow.op.annotation;
19+
20+
import java.lang.annotation.ElementType;
21+
import java.lang.annotation.Retention;
22+
import java.lang.annotation.RetentionPolicy;
23+
import java.lang.annotation.Target;
24+
import org.tensorflow.op.RawOp;
25+
26+
/**
27+
* An annotation that should only be used by codegeneration. Used to provide some metadata about the
28+
* op.
29+
*
30+
* <p><b>DO NOT USE MANUALLY</b>
31+
*/
32+
@Target(ElementType.TYPE)
33+
@Retention(RetentionPolicy.RUNTIME)
34+
public @interface GeneratedOpInputsMetadata {
35+
Class<? extends RawOp> outputsClass();
36+
}

0 commit comments

Comments
 (0)