diff --git a/ext/mysql2/client.c b/ext/mysql2/client.c
index 25a35029..c825f5e2 100644
--- a/ext/mysql2/client.c
+++ b/ext/mysql2/client.c
@@ -19,7 +19,7 @@ extern VALUE mMysql2, cMysql2Error, cMysql2TimeoutError;
static VALUE sym_id, sym_version, sym_header_version, sym_async, sym_symbolize_keys, sym_as, sym_array, sym_stream;
static VALUE sym_no_good_index_used, sym_no_index_used, sym_query_was_slow;
static ID intern_brackets, intern_merge, intern_merge_bang, intern_new_with_args,
- intern_current_query_options, intern_read_timeout;
+ intern_current_query_options, intern_read_timeout, intern_values;
#define REQUIRE_INITIALIZED(wrapper) \
if (!wrapper->initialized) { \
@@ -221,6 +221,7 @@ static void rb_mysql_client_mark(void * wrapper) {
if (w) {
rb_gc_mark_movable(w->encoding);
rb_gc_mark_movable(w->active_fiber);
+ rb_gc_mark_movable(w->prepared_statements);
}
}
@@ -353,6 +354,14 @@ static VALUE invalidate_fd(int clientfd)
}
#endif /* _WIN32 */
+static int decr_mysql2_stmt_hash(VALUE key, VALUE val, VALUE arg)
+{
+ mysql_client_wrapper *wrapper = (mysql_client_wrapper *)arg;
+ VALUE stmt = rb_ivar_get(wrapper->prepared_statements, key);
+ // rb_funcall(stmt, rb_intern("close"), 0);
+ return 0;
+}
+
static void *nogvl_close(void *ptr) {
mysql_client_wrapper *wrapper = ptr;
@@ -388,6 +397,8 @@ void decr_mysql2_client(mysql_client_wrapper *wrapper)
}
#endif
+ // rb_hash_foreach(wrapper->prepared_statements, decr_mysql2_stmt_hash, (VALUE)wrapper);
+
nogvl_close(wrapper);
xfree(wrapper->client);
xfree(wrapper);
@@ -404,6 +415,7 @@ static VALUE allocate(VALUE klass) {
#endif
wrapper->encoding = Qnil;
wrapper->active_fiber = Qnil;
+ wrapper->prepared_statements = rb_hash_new();
wrapper->automatic_close = 1;
wrapper->server_version = 0;
wrapper->reconnect_enabled = 0;
@@ -1535,10 +1547,25 @@ static VALUE initialize_ext(VALUE self) {
* Create a new prepared statement.
*/
static VALUE rb_mysql_client_prepare_statement(VALUE self, VALUE sql) {
+ VALUE stmt;
GET_CLIENT(self);
REQUIRE_CONNECTED(wrapper);
- return rb_mysql_stmt_new(self, sql);
+ stmt = rb_mysql_stmt_new(self, sql);
+
+ return stmt;
+}
+
+/* call-seq:
+ * client.prepared_statements
+ *
+ * Returns an array of prepared statement objects.
+ */
+static VALUE rb_mysql_client_prepared_statements_read(VALUE self) {
+ unsigned long retVal;
+ GET_CLIENT(self);
+
+ return rb_funcall(wrapper->prepared_statements, intern_values, 0);
}
void init_mysql2_client() {
@@ -1588,6 +1615,7 @@ void init_mysql2_client() {
rb_define_method(cMysql2Client, "last_id", rb_mysql_client_last_id, 0);
rb_define_method(cMysql2Client, "affected_rows", rb_mysql_client_affected_rows, 0);
rb_define_method(cMysql2Client, "prepare", rb_mysql_client_prepare_statement, 1);
+ rb_define_method(cMysql2Client, "prepared_statements", rb_mysql_client_prepared_statements_read, 0);
rb_define_method(cMysql2Client, "thread_id", rb_mysql_client_thread_id, 0);
rb_define_method(cMysql2Client, "ping", rb_mysql_client_ping, 0);
rb_define_method(cMysql2Client, "select_db", rb_mysql_client_select_db, 1);
@@ -1641,6 +1669,7 @@ void init_mysql2_client() {
intern_new_with_args = rb_intern("new_with_args");
intern_current_query_options = rb_intern("@current_query_options");
intern_read_timeout = rb_intern("@read_timeout");
+ intern_values = rb_intern("values");
#ifdef CLIENT_LONG_PASSWORD
rb_const_set(cMysql2Client, rb_intern("LONG_PASSWORD"),
diff --git a/ext/mysql2/client.h b/ext/mysql2/client.h
index 6a8227bd..67bd35a5 100644
--- a/ext/mysql2/client.h
+++ b/ext/mysql2/client.h
@@ -4,6 +4,7 @@
typedef struct {
VALUE encoding;
VALUE active_fiber; /* rb_fiber_current() or Qnil */
+ VALUE prepared_statements;
long server_version;
int reconnect_enabled;
unsigned int connect_timeout;
diff --git a/ext/mysql2/statement.c b/ext/mysql2/statement.c
index fa3b660c..c71224ff 100644
--- a/ext/mysql2/statement.c
+++ b/ext/mysql2/statement.c
@@ -75,7 +75,15 @@ void decr_mysql2_stmt(mysql_stmt_wrapper *stmt_wrapper) {
stmt_wrapper->refcount--;
if (stmt_wrapper->refcount == 0) {
+ // If the GC get to client first it will be nil, and this cleanup won't matter
+ if (stmt_wrapper->client_wrapper && stmt_wrapper->client_wrapper->refcount > 0) {
+ // Remove the reference to this statement handle from the Client object.
+ rb_hash_delete(stmt_wrapper->client_wrapper->prepared_statements,
+ ULL2NUM((unsigned long long)stmt_wrapper));
+ }
+
nogvl_stmt_close(stmt_wrapper);
+ decr_mysql2_client(stmt_wrapper->client_wrapper);
xfree(stmt_wrapper);
}
}
@@ -140,10 +148,18 @@ VALUE rb_mysql_stmt_new(VALUE rb_client, VALUE sql) {
rb_stmt = Data_Make_Struct(cMysql2Statement, mysql_stmt_wrapper, rb_mysql_stmt_mark, rb_mysql_stmt_free, stmt_wrapper);
#endif
{
- stmt_wrapper->client = rb_client;
stmt_wrapper->refcount = 1;
stmt_wrapper->closed = 0;
stmt_wrapper->stmt = NULL;
+
+ /* Keep a handle to the Client to ensure it doesn't get garbage collected first */
+ stmt_wrapper->client = rb_client;
+ if (rb_client != Qnil) {
+ stmt_wrapper->client_wrapper = DATA_PTR(rb_client);
+ stmt_wrapper->client_wrapper->refcount++;
+ } else {
+ stmt_wrapper->client_wrapper = NULL;
+ }
}
// instantiate stmt
@@ -178,6 +194,18 @@ VALUE rb_mysql_stmt_new(VALUE rb_client, VALUE sql) {
}
}
+ // Stash a reference to this statement handle into the Client to prevent
+ // premature garbage collection.
+ //
+ // A statement can either be free explicitly or when the client object is
+ // torn down. Freeing a statement handle at any other time causes protocol
+ // traffic that might happen while the connection state is set for another
+ // operation.
+ {
+ GET_CLIENT(rb_client);
+ rb_hash_aset(wrapper->prepared_statements, ULL2NUM((unsigned long long)stmt_wrapper), rb_stmt);
+ }
+
return rb_stmt;
}
@@ -609,7 +637,9 @@ static VALUE rb_mysql_stmt_close(VALUE self) {
RAW_GET_STATEMENT(self);
if (!stmt_wrapper->closed) {
+ GET_CLIENT(stmt_wrapper->client);
stmt_wrapper->closed = 1;
+ rb_hash_delete(wrapper->prepared_statements, ULL2NUM((unsigned long long)stmt_wrapper));
rb_thread_call_without_gvl(nogvl_stmt_close, stmt_wrapper, RUBY_UBF_IO, 0);
}
diff --git a/ext/mysql2/statement.h b/ext/mysql2/statement.h
index e4851067..78a2ef4d 100644
--- a/ext/mysql2/statement.h
+++ b/ext/mysql2/statement.h
@@ -2,10 +2,11 @@
#define MYSQL2_STATEMENT_H
typedef struct {
+ int closed;
+ int refcount;
VALUE client;
+ mysql_client_wrapper *client_wrapper;
MYSQL_STMT *stmt;
- int refcount;
- int closed;
} mysql_stmt_wrapper;
void init_mysql2_statement(void);