Fix JNI wrapper memory leak #2178

Draft · wants to merge 5 commits into main · Changes from 3 commits
4 changes: 3 additions & 1 deletion src/Runtime/jni/jnidummy.c
@@ -14,12 +14,14 @@
//===----------------------------------------------------------------------===//

#include "com_ibm_onnxmlir_OMModel.h"
#include "com_ibm_onnxmlir_OMTensor.h"

/* Dummy routine to force the link editor to embed code in libjniruntime.a
into libmodel.so */
void __dummy_do_not_call__(JNIEnv *env, jclass cls, jobject obj) {
Java_com_ibm_onnxmlir_OMModel_main_1graph_1jni(NULL, NULL, NULL);
Java_com_ibm_onnxmlir_OMModel_query_1entry_1points(NULL, NULL);
Java_com_ibm_onnxmlir_OMModel_query_1entry_1points_1jni(NULL, NULL);
Java_com_ibm_onnxmlir_OMModel_input_1signature_1jni(NULL, NULL, NULL);
Java_com_ibm_onnxmlir_OMModel_output_1signature_1jni(NULL, NULL, NULL);
Java_com_ibm_onnxmlir_OMTensor_free_1data_1jni(NULL, NULL, NULL);
}
49 changes: 34 additions & 15 deletions src/Runtime/jni/jniwrapper.c
@@ -175,7 +175,7 @@ const char *jnistr[] = {
"com/ibm/onnxmlir/OMTensor", /* 3 CLS_COM_IBM_ONNXMLIR_OMTENSOR */
"com/ibm/onnxmlir/OMTensorList", /* 4 CLS_COM_IBM_ONNXMLIR_OMTENSORLIST*/
"<init>", /* 5 CTOR_INIT */
"(Ljava/nio/ByteBuffer;[J[JI)V", /* 6 CTOR_OMTENSOR */
"(Ljava/nio/ByteBuffer;[J[JIZ)V", /* 6 CTOR_OMTENSOR */
"([Lcom/ibm/onnxmlir/OMTensor;)V", /* 7 CTOR_OMTENSORLIST */
"()Ljava/nio/ByteBuffer;", /* 8 SIG_GET_DATA */
"(Ljava/nio/ByteBuffer;)V", /* 9 SIG_SET_DATA */
@@ -598,9 +598,11 @@ jobject omtl_native_to_java(
* If jni_owning is true, we take ownership by setting owner flag
* to false. This means that when we call omTensorListDestroy
* the data buffer will not be freed since it has been given to
* the Java direct byte buffer and the Java GC will be responsible
* for freeing the data buffer. This way we avoid copying the data
* buffer.
* the Java direct byte buffer. However, the direct byte buffer
* created by NewDirectByteBuffer is not subject to Java GC so we
* pass clean=true to the OMTensor constructor to initialize a cleaner
* to manually free the data buffer. This way we avoid copying the
* data buffer.
*
* If jni_owning is false, it means the data buffer is not freeable
* due to one of the two following cases:
@@ -609,23 +611,18 @@ jobject omtl_native_to_java(
* responsible for freeing it
* - the data buffer is static
*
* Either way, since the data buffer will be given to Java and is
* subject to GC, we must make a copy of the data buffer.
* Either way, since the data buffer given to Java is not subject
* to GC, we can simply pass it through with clean=false and no
* cleaner will be initialized.
*/
void *jbytebuffer_data = jni_data;
if (jni_owning) {
LIB_CALL(omTensorSetOwning(jni_omts[i], (int64_t)0), 1, env,
japi->jecpt_cls, "");
LOG_PRINTF(LOG_DEBUG, "omt[%d]:%p data %p ownership taken", i,
jni_omts[i], jni_data);
} else {
LIB_VAR_CALL(jbytebuffer_data, malloc(jni_bufferSize),
jbytebuffer_data != NULL, env, japi->jecpt_cls, "jbytebuffer_data=%p",
jbytebuffer_data);
memcpy(jbytebuffer_data, jni_data, jni_bufferSize);
LOG_PRINTF(LOG_DEBUG, "omt[%d]:%p data %p copied into %p", i, jni_omts[i],
jni_data, jbytebuffer_data);
}

JNI_TYPE_VAR_CALL(env, jobject, jomt_data,
(*env)->NewDirectByteBuffer(env, jbytebuffer_data, jomt_bufferSize),
jomt_data != NULL, japi->jecpt_cls, "omt[%d]:jomt_data=%p", i,
@@ -652,7 +649,8 @@ jobject omtl_native_to_java(
/* Create the OMTensor Java object */
JNI_TYPE_VAR_CALL(env, jobject, jobj_omt,
(*env)->NewObject(env, japi->jomt_cls, japi->jomt_constructor,
jomt_data, jomt_shape, jomt_strides, jomt_dataType),
jomt_data, jomt_shape, jomt_strides, jomt_dataType,
(jboolean)jni_owning),
jobj_omt != NULL, japi->jecpt_cls, "omt[%d]:jobj_omt=%p", i, jobj_omt);

/* Set the OMTensor object in the object array */
@@ -732,7 +730,8 @@ static inline char *conv(const char *str, size_t (*fptr)(char *)) {
#endif

JNIEXPORT jobjectArray JNICALL
Java_com_ibm_onnxmlir_OMModel_query_1entry_1points(JNIEnv *env, jclass cls) {
Java_com_ibm_onnxmlir_OMModel_query_1entry_1points_1jni(
JNIEnv *env, jclass cls) {

log_init();

@@ -910,3 +909,23 @@ JNIEXPORT jstring JNICALL Java_com_ibm_onnxmlir_OMModel_output_1signature_1jni(

return java_osig;
}

JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_OMTensor_free_1data_1jni(
JNIEnv *env, jclass cls, jobject java_data) {

log_init();

/* Find and initialize Java Exception class */
JNI_TYPE_VAR_CALL(env, jclass, jecpt_cls,
(*env)->FindClass(env, jnistr[CLS_JAVA_LANG_EXCEPTION]),
jecpt_cls != NULL, NULL, "Class java/lang/Exception not found");

/* Get native data buffer backing direct byte buffer */
JNI_TYPE_VAR_CALL(env, void *, jni_data,
(*env)->GetDirectBufferAddress(env, java_data), jni_data != NULL,
jecpt_cls, "jni_data=%p", jni_data);

free(jni_data);

Contributor:

I ran into this during testing of my older finalize()-based PR, but jni_data points to the aligned allocation, not to the start of the allocation.

Collaborator Author:

I see what you are saying, but I'm a bit surprised that this would happen. Basically, when you call NewDirectByteBuffer, you give it the pointer to the memory block. Now when you call GetDirectBufferAddress, is it going to give you a different (aligned) pointer? I mean, I gave it the pointer and now I want it back; why would it change behind my back?!

Collaborator Author:

OK, maybe you mean that onnx-mlir might give you an aligned pointer that's not the start of the allocation? In that case we need to fiddle a bit, like the offset thing you did. Itchy.

Collaborator Author (@gongsu832, Apr 19, 2023):

OK, I looked at the conversation on your finalize PR, and you said:

I thought about doing that but I wanted to avoid a limitation with java - all of the integral based arguments are signed. With an offset, which can be signed, I can keep the semantics more or less the same. With a Cleaner based approach I can create another DirectByteBuffer holding the actual allocation variable which the cleaner can use to clean up - no offset calculation required.

So I think instead of passing in the offset, we could just pass in the aligned pointer as jlong, right? As you said later, it doesn't really matter if it's signed or not. 64-bit is 64-bit. It only matters how you interpret it. Doing that would save us from having to keep another flag and doing the math.


return NULL;
}
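
For illustration only, here is a minimal Java-side sketch of the alternative floated in the thread above: have the JNI wrapper hand Java the allocation start as a jlong and let the cleaner free that exact address, so no offset arithmetic is needed. The class name, the freeNative hook, and the allocPtr parameter are hypothetical and not part of this PR.

import java.lang.ref.Cleaner;
import java.nio.ByteBuffer;

final class NativeTensorData implements AutoCloseable {
    private static final Cleaner CLEANER = Cleaner.create();

    // Hypothetical native hook; this PR's real entry point is
    // OMTensor.free_data_jni(ByteBuffer), which frees the address
    // backing the direct byte buffer.
    private static native void freeNative(long allocPtr);

    private final ByteBuffer data;              // wraps the aligned data pointer
    private final Cleaner.Cleanable cleanable;

    NativeTensorData(ByteBuffer data, long allocPtr) {
        this.data = data;
        // The cleaning action captures only the primitive address, never
        // `this`, so registering it does not keep the object reachable.
        this.cleanable = (allocPtr != 0)
                ? CLEANER.register(this, () -> freeNative(allocPtr))
                : null;
    }

    ByteBuffer data() {
        return data;
    }

    @Override
    public void close() {
        // Deterministic release; Cleanable.clean() runs the action at most once.
        if (cleanable != null) cleanable.clean();
    }
}

Capturing only the primitive pointer in the lambda also sidesteps the usual Cleaner pitfall of the cleaning action holding a reference to the very object it is supposed to clean.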
4 changes: 2 additions & 2 deletions src/Runtime/jni/src/com/ibm/onnxmlir/OMModel.java
@@ -168,7 +168,7 @@ public class OMModel {
}

private static native OMTensorList main_graph_jni(OMTensorList list);
private static native String[] query_entry_points();
private static native String[] query_entry_points_jni();
private static native String input_signature_jni(String entry_point);
private static native String output_signature_jni(String entry_point);

@@ -188,7 +188,7 @@ public static OMTensorList mainGraph(OMTensorList list) {
* @return String array of entry point names
*/
public static String[] queryEntryPoints() {
return query_entry_points();
return query_entry_points_jni();
}

/**
50 changes: 48 additions & 2 deletions src/Runtime/jni/src/com/ibm/onnxmlir/OMTensor.java
@@ -2,6 +2,7 @@

package com.ibm.onnxmlir;

import java.lang.ref.Cleaner;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
@@ -15,7 +16,7 @@
* shape, strides, data type, etc. associated with a tensor
* input/output.
*/
public class OMTensor {
public class OMTensor implements AutoCloseable {

/* We can use enum but that creates another class
* which complicates things for JNI.
@@ -92,6 +93,43 @@ public class OMTensor {
*/
private int _rank;

/* The native buffer backing a ByteBuffer that the JNI wrapper creates
* with NewDirectByteBuffer is not freed by the Java GC. So we need to
* register a cleaner to manually free the output tensor data buffer
* created by the model runtime.
*
* Note the cleaner API requires Java 9+. For Java 8, we can
* use finalize() but be aware that it has been deprecated since
* Java 9 (hence the use of cleaner).
*
* _cleanable is registered by the internal JNI-wrapper-only
* OMTensor constructor to signal the need for manual freeing.
*/
private static native Void free_data_jni(ByteBuffer data);

/* Create a Cleaner object and use it to register our freeData
* routine.
*/
private static final Cleaner _cleaner = Cleaner.create();
private final Cleaner.Cleanable _cleanable;

/* When an OMTensor becomes phantom reachable, the JVM runs the
* cleaning action we registered (built by freeData), which calls
* the native free_data_jni to do the manual freeing.
*/
private static Runnable freeData(ByteBuffer data) {
return () -> {
free_data_jni(data);
};
}

/* Users can explicitly call close to run freeData. This is required
* for implementing AutoCloseable. */
@Override
public void close() {
if (_cleanable != null) _cleanable.clean();
}

/**
* Constructor
*
@@ -107,6 +145,7 @@ public OMTensor(byte[] data, long[] shape, boolean flag) {
else
setByteData(data);
putShape(shape);
_cleanable = null;
}

/**
@@ -132,6 +171,7 @@ public OMTensor(byte[] data, long[] shape) {
public OMTensor(short[] data, long[] shape) {
setShortData(data);
putShape(shape);
_cleanable = null;
}

/**
@@ -145,6 +185,7 @@ public OMTensor(short[] data, long[] shape) {
public OMTensor(int[] data, long[] shape) {
setIntData(data);
putShape(shape);
_cleanable = null;
}

/**
@@ -158,6 +199,7 @@ public OMTensor(int[] data, long[] shape) {
public OMTensor(long[] data, long[] shape) {
setLongData(data);
putShape(shape);
_cleanable = null;
}

/**
@@ -171,6 +213,7 @@ public OMTensor(long[] data, long[] shape) {
public OMTensor(float[] data, long[] shape) {
setFloatData(data);
putShape(shape);
_cleanable = null;
}

/**
@@ -184,6 +227,7 @@ public OMTensor(float[] data, long[] shape) {
public OMTensor(double[] data, long[] shape) {
setDoubleData(data);
putShape(shape);
_cleanable = null;
}


@@ -588,6 +632,7 @@ protected OMTensor(ByteBuffer data, long[] shape, ByteOrder endian, int dataType
_data = data.order(endian);
_dataType = dataType;
putShape(shape);
_cleanable = null;
}

/**
@@ -598,7 +643,7 @@ protected OMTensor(ByteBuffer data, long[] shape, ByteOrder endian, int dataType
* @param strides data stride
* @param dataType data type
*/
protected OMTensor(ByteBuffer data, long[] shape, long[] strides, int dataType) {
protected OMTensor(ByteBuffer data, long[] shape, long[] strides, int dataType, boolean clean) {
if (shape.length != strides.length)
throw new IllegalArgumentException(
"shape.length (" + shape.length + ") != stride.length (" + strides.length + ")");
@@ -610,6 +655,7 @@ protected OMTensor(ByteBuffer data, long[] shape, long[] strides, int dataType)
_rank = shape.length;
_shape = shape;
_strides = strides;
_cleanable = clean ? _cleaner.register(this, freeData(_data)) : null;
}

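Because OMTensor now implements AutoCloseable, a caller can release a JNI-produced output buffer deterministically with try-with-resources instead of waiting for the tensor to become phantom reachable. A rough usage sketch follows, assuming the generated model jar is on the classpath; the OMTensorList.getOmtArray() and OMTensor.getFloatData() accessors are assumed here and do not appear in this diff.

import com.ibm.onnxmlir.OMModel;
import com.ibm.onnxmlir.OMTensor;
import com.ibm.onnxmlir.OMTensorList;

public class RunModelExample {
    public static void main(String[] args) {
        // Input tensor backed by a Java array; its _cleanable stays null.
        OMTensor input = new OMTensor(new float[] {1f, 2f, 3f, 4f}, new long[] {2, 2});
        OMTensorList outputs = OMModel.mainGraph(new OMTensorList(new OMTensor[] {input}));

        // Output tensors built by the JNI wrapper register a cleaner;
        // close() frees the native buffer right away instead of waiting for GC.
        try (OMTensor out = outputs.getOmtArray()[0]) {
            System.out.println(out.getFloatData()[0]);
        }
    }
}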
3 changes: 2 additions & 1 deletion test/backend/inference_backend.py
@@ -1181,7 +1181,8 @@ def JniExecutionSession(jar_name, inputs):
)

dtype = {
"b1": np.bool,
# np.bool is deprecated, use np.bool_ instead
"b1": np.bool_,
"i1": np.int8,
"u1": np.uint8,
"i2": np.int16,