Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types._
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.{CompletionIterator, TaskInterruptListener}

/**
* Data corresponding to one partition of a JDBCRDD.
Expand Down Expand Up @@ -346,6 +346,27 @@ class JDBCRDD(
stmt.setFetchSize(dialect.getFetchSize(options))
stmt.setQueryTimeout(options.queryTimeout)

// JDBC socket reads (e.g., from executeQuery() / ResultSet.next()) are not interruptible via
// Thread.interrupt(). Register the listener immediately before executeQuery() so we close the
// partition connection on kill and unblock the native read. We capture conn in a local val
// (after connection setup) so the listener closes the same reference the task thread uses;
// we only close the connection (not rs/stmt) to avoid races with the completion listener.
// Tradeoff: interrupts during getConnection / sessionInitStatement / prepareStatement are not
// covered here; those steps are usually short compared to the main query + fetch loop.
val connForInterrupt = conn
context.addTaskInterruptListener(new TaskInterruptListener {
override def onTaskInterrupted(context: TaskContext, reason: String): Unit = {
try {
if (connForInterrupt != null && !connForInterrupt.isClosed) {
connForInterrupt.close()
}
} catch {
case NonFatal(e) =>
logWarning("Exception closing JDBC connection on task interrupt", e)
}
}
})

rs = SQLMetrics.withTimingNs(queryExecutionTimeMetric) {
try {
stmt.executeQuery()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDiale
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{NextIterator, TaskInterruptListener}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.NextIterator

/**
* Util functions for JDBC tables.
Expand Down Expand Up @@ -798,6 +798,26 @@ object JdbcUtils extends Logging with SQLConfHelper {
val outMetrics = TaskContext.get().taskMetrics().outputMetrics

val conn = dialect.createConnectionFactory(options)(-1)

// JDBC socket reads (e.g. inside executeBatch) ignore Thread.interrupt() and can block
// indefinitely; closing the connection makes them fail promptly instead. The listener is
// registered immediately after the connection is opened, so no synchronization or atomic
// reference is needed. An interrupt that arrives during connection setup may miss the
// listener, but the finally block still closes the connection. Once the listener is
// registered, closing the connection makes subsequent JDBC calls throw SQLException,
// and the task unwinds.
Option(TaskContext.get()).foreach { tc =>
tc.addTaskInterruptListener(new TaskInterruptListener {
override def onTaskInterrupted(context: TaskContext, reason: String): Unit = {
try {
conn.close()
} catch {
case NonFatal(e) =>
logWarning("Exception closing JDBC connection on task interrupt", e)
}
}
})
}

var committed = false

var finalIsolationLevel = Connection.TRANSACTION_NONE
Expand Down Expand Up @@ -859,6 +879,10 @@ object JdbcUtils extends Logging with SQLConfHelper {
rowCount += 1
totalRowCount += 1
if (rowCount % batchSize == 0) {
// executeBatch() is the main point where this task can block in a native socket read.
// The TaskInterruptListener (registered after opening the connection earlier in this
// method) closes conn to unblock it. Per JDBC 4.0 section 9.6, methods invoked on a
// closed Connection throw SQLException, which major drivers honor. A kill arriving
// mid-batch may drop the in-flight batch; that is still better than hanging forever.
stmt.executeBatch()
rowCount = 0
}
Expand Down Expand Up @@ -899,7 +923,16 @@ object JdbcUtils extends Logging with SQLConfHelper {
// let the exception through unless rollback() or close() want to
// tell the user about another problem.
if (supportsTransactions) {
conn.rollback()
// The connection may already be closed by the task interrupt listener; rollback
// is best-effort in that case.
try {
if (!conn.isClosed) {
conn.rollback()
}
} catch {
case NonFatal(e) =>
logWarning("Exception rolling back transaction on task failure", e)
}
} else {
outMetrics.setRecordsWritten(totalRowCount)
}
Expand Down
Loading