Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JNI wrapper memory leak #2178

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ find_package(Java COMPONENTS Development)
find_package(JNI)

if (Java_Development_FOUND AND JNI_FOUND)
message(STATUS "Java version : ${Java_VERSION}")
set(ONNX_MLIR_ENABLE_JNI TRUE)
include(UseJava)
else()
Expand Down
4 changes: 3 additions & 1 deletion src/Runtime/jni/jnidummy.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
//===----------------------------------------------------------------------===//

#include "com_ibm_onnxmlir_OMModel.h"
#include "com_ibm_onnxmlir_OMTensor.h"

/* Dummy routine to force the link editor to embed code in libjniruntime.a
into libmodel.so */
void __dummy_do_not_call__(JNIEnv *env, jclass cls, jobject obj) {
Java_com_ibm_onnxmlir_OMModel_main_1graph_1jni(NULL, NULL, NULL);
Java_com_ibm_onnxmlir_OMModel_query_1entry_1points(NULL, NULL);
Java_com_ibm_onnxmlir_OMModel_query_1entry_1points_1jni(NULL, NULL);
Java_com_ibm_onnxmlir_OMModel_input_1signature_1jni(NULL, NULL, NULL);
Java_com_ibm_onnxmlir_OMModel_output_1signature_1jni(NULL, NULL, NULL);
Java_com_ibm_onnxmlir_OMTensor_free_1data_1jni(NULL, NULL, 0L);
}
48 changes: 31 additions & 17 deletions src/Runtime/jni/jniwrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "jnilog.h"

extern OMTensorList *run_main_graph(OMTensorList *);
extern void *omTensorGetAllocatedPtr(const OMTensor *tensor);

/* Declare type var, make call and assign to var, check condition.
* It's assumed that a Java exception has already been thrown so
Expand Down Expand Up @@ -175,7 +176,7 @@ const char *jnistr[] = {
"com/ibm/onnxmlir/OMTensor", /* 3 CLS_COM_IBM_ONNXMLIR_OMTENSOR */
"com/ibm/onnxmlir/OMTensorList", /* 4 CLS_COM_IBM_ONNXMLIR_OMTENSORLIST*/
"<init>", /* 5 CTOR_INIT */
"(Ljava/nio/ByteBuffer;[J[JI)V", /* 6 CTOR_OMTENSOR */
"(Ljava/nio/ByteBuffer;[J[JIJ)V", /* 6 CTOR_OMTENSOR */
"([Lcom/ibm/onnxmlir/OMTensor;)V", /* 7 CTOR_OMTENSORLIST */
"()Ljava/nio/ByteBuffer;", /* 8 SIG_GET_DATA */
"(Ljava/nio/ByteBuffer;)V", /* 9 SIG_SET_DATA */
Expand Down Expand Up @@ -497,12 +498,13 @@ OMTensorList *omtl_java_to_native(
free(jobj_omts);

/* Create OMTensorList to be constructed and passed to the model
* shared library. Note that we do own the pointers to the native
* OMTensor structs, jni_omts.
* shared library. Note that omTensorListCreate now copies the input
* tensor list so we must free jni_omts to avoid memory leak.
*/
LIB_TYPE_VAR_CALL(OMTensorList *, jni_omtl,
omTensorListCreate(jni_omts, (int64_t)jomtl_omtn), jni_omtl != NULL, env,
japi->jecpt_cls, "jni_omtl=%p", jni_omtl);
free(jni_omts);

return jni_omtl;
}
Expand Down Expand Up @@ -562,6 +564,9 @@ jobject omtl_native_to_java(

LIB_TYPE_VAR_CALL(void *, jni_data, omTensorGetDataPtr(jni_omts[i]),
jni_data != NULL, env, japi->jecpt_cls, "omt[%d]:data=%p", i, jni_data);
LIB_TYPE_VAR_CALL(void *, jni_alloc, omTensorGetAllocatedPtr(jni_omts[i]),
jni_alloc != NULL, env, japi->jecpt_cls, "omt[%d]:alloc=%p", i,
jni_alloc);
LIB_TYPE_VAR_CALL(const int64_t *, jni_shape, omTensorGetShape(jni_omts[i]),
jni_shape != NULL, env, japi->jecpt_cls, "omt[%d]:shape=%p", i,
jni_shape);
Expand Down Expand Up @@ -598,9 +603,11 @@ jobject omtl_native_to_java(
* If jni_owning is true, we take ownership by setting owner flag
* to false. This means that when we call omTensorListDestroy
* the data buffer will not be freed since it has been given to
* the Java direct byte buffer and the Java GC will be responsible
* for freeing the data buffer. This way we avoid copying the data
* buffer.
* the Java direct byte buffer. However, the direct byte buffer
* created by NewDirectByteBuffer is not subject to Java GC so we
* pass clean=true to the OMTensor construct to initializ a cleaner
* to manually free the data buffer. This way we avoid copying the
* data buffer.
*
* If jni_owning is false, it means the data buffer is not freeable
* due to one of the two following cases:
Expand All @@ -609,23 +616,18 @@ jobject omtl_native_to_java(
* responsible for freeing it
* - the data buffer is static
*
* Either way, since the data buffer will be given to Java and is
* subject to GC, we must make a copy of the data buffer.
* Either way, since the data buffer given to Java is not subject
* to GC, we can simply pass it through with clean=false and no
* cleaner will be initialized.
*/
void *jbytebuffer_data = jni_data;
if (jni_owning) {
LIB_CALL(omTensorSetOwning(jni_omts[i], (int64_t)0), 1, env,
japi->jecpt_cls, "");
LOG_PRINTF(LOG_DEBUG, "omt[%d]:%p data %p ownership taken", i,
jni_omts[i], jni_data);
} else {
LIB_VAR_CALL(jbytebuffer_data, malloc(jni_bufferSize),
jbytebuffer_data != NULL, env, japi->jecpt_cls, "jbytebuffer_data=%p",
jbytebuffer_data);
memcpy(jbytebuffer_data, jni_data, jni_bufferSize);
LOG_PRINTF(LOG_DEBUG, "omt[%d]:%p data %p copied into %p", i, jni_omts[i],
jni_data, jbytebuffer_data);
}

JNI_TYPE_VAR_CALL(env, jobject, jomt_data,
(*env)->NewDirectByteBuffer(env, jbytebuffer_data, jomt_bufferSize),
jomt_data != NULL, japi->jecpt_cls, "omt[%d]:jomt_data=%p", i,
Expand All @@ -652,7 +654,8 @@ jobject omtl_native_to_java(
/* Create the OMTensor Java object */
JNI_TYPE_VAR_CALL(env, jobject, jobj_omt,
(*env)->NewObject(env, japi->jomt_cls, japi->jomt_constructor,
jomt_data, jomt_shape, jomt_strides, jomt_dataType),
jomt_data, jomt_shape, jomt_strides, jomt_dataType,
jni_owning ? (jlong)jni_alloc : (jlong)0),
jobj_omt != NULL, japi->jecpt_cls, "omt[%d]:jobj_omt=%p", i, jobj_omt);

/* Set the OMTensor object in the object array */
Expand Down Expand Up @@ -732,7 +735,8 @@ static inline char *conv(const char *str, size_t (*fptr)(char *)) {
#endif

JNIEXPORT jobjectArray JNICALL
Java_com_ibm_onnxmlir_OMModel_query_1entry_1points(JNIEnv *env, jclass cls) {
Java_com_ibm_onnxmlir_OMModel_query_1entry_1points_1jni(
JNIEnv *env, jclass cls) {

log_init();

Expand Down Expand Up @@ -910,3 +914,13 @@ JNIEXPORT jstring JNICALL Java_com_ibm_onnxmlir_OMModel_output_1signature_1jni(

return java_osig;
}

JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_OMTensor_free_1data_1jni(
JNIEnv *env, jclass cls, jlong alloc) {

log_init();

free((void *)alloc);

return NULL;
}
4 changes: 2 additions & 2 deletions src/Runtime/jni/src/com/ibm/onnxmlir/OMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public class OMModel {
}

private static native OMTensorList main_graph_jni(OMTensorList list);
private static native String[] query_entry_points();
private static native String[] query_entry_points_jni();
private static native String input_signature_jni(String entry_point);
private static native String output_signature_jni(String entry_point);

Expand All @@ -188,7 +188,7 @@ public static OMTensorList mainGraph(OMTensorList list) {
* @return String array of entry point names
*/
public static String[] queryEntryPoints() {
return query_entry_points();
return query_entry_points_jni();
}

/**
Expand Down
50 changes: 48 additions & 2 deletions src/Runtime/jni/src/com/ibm/onnxmlir/OMTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

package com.ibm.onnxmlir;

import java.lang.ref.Cleaner;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
Expand All @@ -15,7 +16,7 @@
* shape, strides, data type, etc. associated with a tensor
* input/output.
*/
public class OMTensor {
public class OMTensor implements AutoCloseable {

/* We can use enum but that creates another class
* which complicates things for JNI.
Expand Down Expand Up @@ -92,6 +93,43 @@ public class OMTensor {
*/
private int _rank;

/* ByteBuffer created by jniwrapper using NewDirectByteBuffer
* apparently doesn't get GC-ed. So we need to register a
* cleaner to manually free the output tensor data buffer
* created by the model runtime.
*
* Note the cleaner API requires Java 9+. For Java 8, we can
* use finalize() but be aware that it has been deprecated since
* Java 9 (hence the use of cleaner).
*
* _cleanable is registered by the internal JNI-wrapper-only
* OMTensor constructor to signal the need for manual freeing.
*/
private static native Void free_data_jni(long alloc);

/* Create a Cleaner object and use it to register our freeData
* routine.
*/
private static final Cleaner _cleaner = Cleaner.create();
private final Cleaner.Cleanable _cleanable;

/* When OMTensor becomes phantom reachable, JVM will call the
* cleaner we registered and run freeData. We call the native
* free_data_jni to do the manual freeing.
*/
private static Runnable freeData(long alloc) {
return() -> {
free_data_jni(alloc);
};
}

/* User can explicit call close to run freeData. This is required
* for implementing AutoCloseable.
*/
public void close() {
if (_cleanable != null) _cleanable.clean();
}

/**
* Constructor
*
Expand All @@ -107,6 +145,7 @@ public OMTensor(byte[] data, long[] shape, boolean flag) {
else
setByteData(data);
putShape(shape);
_cleanable = null;
}

/**
Expand All @@ -132,6 +171,7 @@ public OMTensor(byte[] data, long[] shape) {
public OMTensor(short[] data, long[] shape) {
setShortData(data);
putShape(shape);
_cleanable = null;
}

/**
Expand All @@ -145,6 +185,7 @@ public OMTensor(short[] data, long[] shape) {
public OMTensor(int[] data, long[] shape) {
setIntData(data);
putShape(shape);
_cleanable = null;
}

/**
Expand All @@ -158,6 +199,7 @@ public OMTensor(int[] data, long[] shape) {
public OMTensor(long[] data, long[] shape) {
setLongData(data);
putShape(shape);
_cleanable = null;
}

/**
Expand All @@ -171,6 +213,7 @@ public OMTensor(long[] data, long[] shape) {
public OMTensor(float[] data, long[] shape) {
setFloatData(data);
putShape(shape);
_cleanable = null;
}

/**
Expand All @@ -184,6 +227,7 @@ public OMTensor(float[] data, long[] shape) {
public OMTensor(double[] data, long[] shape) {
setDoubleData(data);
putShape(shape);
_cleanable = null;
}


Expand Down Expand Up @@ -588,6 +632,7 @@ protected OMTensor(ByteBuffer data, long[] shape, ByteOrder endian, int dataType
_data = data.order(endian);
_dataType = dataType;
putShape(shape);
_cleanable = null;
}

/**
Expand All @@ -598,7 +643,7 @@ protected OMTensor(ByteBuffer data, long[] shape, ByteOrder endian, int dataType
* @param strides data stride
* @param dataType data type
*/
protected OMTensor(ByteBuffer data, long[] shape, long[] strides, int dataType) {
protected OMTensor(ByteBuffer data, long[] shape, long[] strides, int dataType, long alloc) {
if (shape.length != strides.length)
throw new IllegalArgumentException(
"shape.length (" + shape.length + ") != stride.length (" + strides.length + ")");
Expand All @@ -610,6 +655,7 @@ protected OMTensor(ByteBuffer data, long[] shape, long[] strides, int dataType)
_rank = shape.length;
_shape = shape;
_strides = strides;
_cleanable = alloc == 0 ? null : _cleaner.register(this, freeData(alloc));
}

/**
Expand Down
10 changes: 10 additions & 0 deletions src/Runtime/jni/src/com/ibm/onnxmlir/OMTensorList.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,14 @@ public OMTensor[] getOmtArray() {
public OMTensor getOmtByIndex(int index) {
return _omts[index];
}

/**
* OMTensor cleaner
*
* Explicitly close OMTensor to free native data buffer (if any)
*/
public void close() {
for (int i = 0; i < _omts.length; i++)
_omts[i].close();
}
}
3 changes: 2 additions & 1 deletion test/backend/inference_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,7 +1181,8 @@ def JniExecutionSession(jar_name, inputs):
)

dtype = {
"b1": np.bool,
# np.bool is deprecated, use np.bool_ instead
"b1": np.bool_,
"i1": np.int8,
"u1": np.uint8,
"i2": np.int16,
Expand Down