Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ public abstract class AbstractDatadogSparkListener extends SparkListener {
private final HashMap<Long, AgentSpan> stageSpans = new HashMap<>();

private final HashMap<Integer, Integer> stageToJob = new HashMap<>();
private final HashMap<Integer, Long> jobToSqlExecution = new HashMap<>();
private final HashMap<Long, Properties> stageProperties = new HashMap<>();

private final SparkAggregatedTaskMetrics applicationMetrics = new SparkAggregatedTaskMetrics();
Expand All @@ -139,6 +140,9 @@ public abstract class AbstractDatadogSparkListener extends SparkListener {
private boolean lastJobFailed = false;
private String lastJobFailedMessage;
private String lastJobFailedStackTrace;
private boolean lastSqlFailed = false;
private String lastSqlFailedMessage;
private String lastSqlFailedStackTrace;
private int jobCount = 0;
private int currentExecutorCount = 0;
private int maxExecutorCount = 0;
Expand Down Expand Up @@ -356,6 +360,11 @@ public synchronized void finishApplication(
applicationSpan.setTag(DDTags.ERROR_TYPE, "Spark Application Failed");
applicationSpan.setTag(DDTags.ERROR_MSG, lastJobFailedMessage);
applicationSpan.setTag(DDTags.ERROR_STACK, lastJobFailedStackTrace);
} else if (lastSqlFailed) {
applicationSpan.setError(true);
applicationSpan.setTag(DDTags.ERROR_TYPE, "Spark SQL Failed");
applicationSpan.setTag(DDTags.ERROR_MSG, lastSqlFailedMessage);
applicationSpan.setTag(DDTags.ERROR_STACK, lastSqlFailedStackTrace);
}

applicationMetrics.setSpanMetrics(applicationSpan);
Expand Down Expand Up @@ -513,6 +522,9 @@ public synchronized void onJobStart(SparkListenerJobStart jobStart) {
for (int stageId : getSparkJobStageIds(jobStart)) {
stageToJob.put(stageId, jobStart.jobId());
}
if (sqlExecutionId != null) {
jobToSqlExecution.put(jobStart.jobId(), sqlExecutionId);
}
jobSpans.put(jobStart.jobId(), jobSpan);
notifyOl(x -> openLineageSparkListener.onJobStart(x), jobStart);
}
Expand All @@ -524,6 +536,8 @@ public synchronized void onJobEnd(SparkListenerJobEnd jobEnd) {
return;
}

Long sqlExecutionId = jobToSqlExecution.remove(jobEnd.jobId());

if (jobEnd.jobResult() instanceof JobFailed) {
JobFailed jobFailed = (JobFailed) jobEnd.jobResult();
Exception exception = jobFailed.exception();
Expand All @@ -536,6 +550,18 @@ public synchronized void onJobEnd(SparkListenerJobEnd jobEnd) {
jobSpan.setTag(DDTags.ERROR_STACK, errorStackTrace);
jobSpan.setTag(DDTags.ERROR_TYPE, "Spark Job Failed");

// On Spark 3.4+, onSQLExecutionEnd may overwrite this with the authoritative
// errorMessage() field — that's intentional, as the SQL-level error is more precise.
if (sqlExecutionId != null) {
AgentSpan sqlSpan = sqlSpans.get(sqlExecutionId);
if (sqlSpan != null) {
sqlSpan.setError(true);
sqlSpan.setErrorMessage(errorMessage);
sqlSpan.setTag(DDTags.ERROR_STACK, errorStackTrace);
sqlSpan.setTag(DDTags.ERROR_TYPE, "Spark SQL Failed");
}
}

// Only propagate the error to the application if it is not a cancellation
if (errorMessage != null && !errorMessage.toLowerCase().contains("cancelled")) {
lastJobFailed = true;
Expand Down Expand Up @@ -842,6 +868,9 @@ private <T extends SparkListenerEvent> void notifyOl(Consumer<T> ol, T event) {
private static final MethodHandle adaptiveExecutionIdMethod;
private static final MethodHandle adaptiveSparkPlanMethod;

// Spark 3.4+ added errorMessage() to SparkListenerSQLExecutionEnd (SPARK-41827)
private static final MethodHandle sqlEndErrorMessageMethod;

@SuppressForbidden // Using reflection to avoid splitting the instrumentation once more
private static Class<?> findAdaptiveExecutionUpdateClass() throws ClassNotFoundException {
return Class.forName(
Expand Down Expand Up @@ -869,6 +898,18 @@ private static Class<?> findAdaptiveExecutionUpdateClass() throws ClassNotFoundE
adaptiveExecutionUpdateClass = executionUpdateClass;
adaptiveExecutionIdMethod = executionIdMethod;
adaptiveSparkPlanMethod = sparkPlanMethod;

MethodHandle errorMessageMethod = null;
try {
MethodHandles.Lookup lookup = MethodHandles.lookup();
errorMessageMethod =
lookup.findVirtual(
SparkListenerSQLExecutionEnd.class,
"errorMessage",
MethodType.methodType(scala.Option.class));
} catch (NoSuchMethodException | IllegalAccessException ignored) {
}
sqlEndErrorMessageMethod = errorMessageMethod;
}

private synchronized void updateAdaptiveSQLPlan(SparkListenerEvent event) {
Expand All @@ -891,6 +932,13 @@ private synchronized void onSQLExecutionStart(SparkListenerSQLExecutionStart sql

private synchronized void onSQLExecutionEnd(SparkListenerSQLExecutionEnd sqlEnd) {
AgentSpan span = sqlSpans.remove(sqlEnd.executionId());

// AnalysisException case: no job ran, so the span was never created
if (span == null && sqlQueries.containsKey(sqlEnd.executionId())) {
span = getOrCreateSqlSpan(sqlEnd.executionId(), null, null);
sqlSpans.remove(sqlEnd.executionId());
}

SparkAggregatedTaskMetrics metrics = sqlMetrics.remove(sqlEnd.executionId());
sqlQueries.remove(sqlEnd.executionId());
sqlPlans.remove(sqlEnd.executionId());
Expand All @@ -899,12 +947,44 @@ private synchronized void onSQLExecutionEnd(SparkListenerSQLExecutionEnd sqlEnd)
if (metrics != null) {
metrics.setSpanMetrics(span);
}

String errorMsg = getSqlEndErrorMessage(sqlEnd);
if (errorMsg != null) {
String sqlErrorMessage = getErrorMessageWithoutStackTrace(errorMsg);
span.setError(true);
span.setErrorMessage(sqlErrorMessage);
span.setTag(DDTags.ERROR_STACK, errorMsg);
span.setTag(DDTags.ERROR_TYPE, "Spark SQL Failed");

if (sqlErrorMessage == null || !sqlErrorMessage.toLowerCase().contains("cancelled")) {
lastSqlFailed = true;
lastSqlFailedMessage = sqlErrorMessage;
lastSqlFailedStackTrace = errorMsg;
}
} else {
lastSqlFailed = false;
}

notifyOl(x -> openLineageSparkListener.onOtherEvent(x), sqlEnd);

span.finish(sqlEnd.time() * 1000);
}
}

private static String getSqlEndErrorMessage(SparkListenerSQLExecutionEnd sqlEnd) {
  // The handle is null on Spark < 3.4, where SparkListenerSQLExecutionEnd
  // has no errorMessage() accessor (added by SPARK-41827).
  if (sqlEndErrorMessageMethod == null) {
    return null;
  }
  try {
    scala.Option<?> maybeError = (scala.Option<?>) sqlEndErrorMessageMethod.invoke(sqlEnd);
    return maybeError.isDefined() ? (String) maybeError.get() : null;
  } catch (Throwable ignored) {
    // Best effort: a reflective failure is treated the same as "no SQL error reported"
    return null;
  }
}

private synchronized void onStreamingQueryStartedEvent(
StreamingQueryListener.QueryStartedEvent event) {
if (streamingQueries.size() > MAX_COLLECTION_SIZE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ import org.apache.spark.scheduler.StageInfo
import org.apache.spark.scheduler.TaskInfo
import org.apache.spark.scheduler.TaskLocality
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart
import org.apache.spark.storage.RDDInfo
import scala.Option
import scala.collection.immutable.HashMap
Expand Down Expand Up @@ -652,6 +655,109 @@ abstract class AbstractSparkListenerTest extends InstrumentationSpecification {
.getOption("spark.openlineage.circuitBreaker.timeoutInSeconds") == Option.apply("120")
}

protected sqlExecutionStartEvent(Long executionId, Long time, String description = "SELECT * FROM test") {
  // Build a minimal, child-less physical plan for the synthetic SQL start event
  def noChildren = JavaConverters.asScalaBuffer([]).toSeq()
  def noMetadata = new scala.collection.immutable.HashMap<String, String>()
  def planInfo = new SparkPlanInfo("TestPlan", "TestPlan", noChildren, noMetadata, noChildren)

  new SparkListenerSQLExecutionStart(
    executionId,
    description,
    "details",
    "physical plan",
    planInfo,
    time
  )
}

protected sqlExecutionEndEvent(Long executionId, Long time) {
  // Two-arg constructor: matches the pre-3.4 event shape, which carries no errorMessage
  new SparkListenerSQLExecutionEnd(executionId, time)
}

// Regression test: a SQL execution that starts and ends without any job in between
// (e.g. an analysis failure before job submission) must still produce a spark.sql
// span — the listener creates it lazily at onSQLExecutionEnd when no job ran.
def "test SQL span created when no job runs"() {
  setup:
  def listener = getTestDatadogSparkListener()
  listener.onApplicationStart(applicationStartEvent(1000L))
  // SQL start and end with no job events in between
  listener.onOtherEvent(sqlExecutionStartEvent(1L, 2000L, "SELECT * FROM missing_table"))
  listener.onOtherEvent(sqlExecutionEndEvent(1L, 3000L))
  listener.onApplicationEnd(new SparkListenerApplicationEnd(4000L))

  expect:
  assertTraces(1) {
    trace(2) {
      span {
        operationName "spark.application"
        resourceName "spark.application"
        spanType "spark"
        parent()
      }
      span {
        operationName "spark.sql"
        resourceName "SELECT * FROM missing_table"
        spanType "spark"
        // The two-arg end event carries no errorMessage, so the span is not errored
        errored false
        childOf(span(0))
      }
    }
  }
}

// Verifies that when a job belonging to a SQL execution fails, the error is
// propagated from the spark.job span up to the parent spark.sql span (tagged
// "Spark SQL Failed") and ultimately to the spark.application span.
def "test job failure propagates error to SQL span"() {
  setup:
  def listener = getTestDatadogSparkListener()
  listener.onApplicationStart(applicationStartEvent(1000L))

  // SQL execution with a job that fails; the job is linked to the SQL execution
  // via the spark.sql.execution.id property on the job start event
  listener.onOtherEvent(sqlExecutionStartEvent(1L, 2000L, "SELECT * FROM bad_table"))
  listener.onJobStart(jobStartEventWithSql(1, 2100L, [1], 1L))
  listener.onJobEnd(jobFailedEvent(1, 2500L, "Table not found"))
  listener.onOtherEvent(sqlExecutionEndEvent(1L, 3000L))

  listener.onApplicationEnd(new SparkListenerApplicationEnd(4000L))

  expect:
  assertTraces(1) {
    trace(3) {
      span {
        operationName "spark.application"
        resourceName "spark.application"
        spanType "spark"
        // The job failure is not a cancellation, so it errors the application span
        errored true
        parent()
      }
      span {
        operationName "spark.sql"
        resourceName "SELECT * FROM bad_table"
        spanType "spark"
        errored true
        // Error type set by the job-failure propagation path in onJobEnd
        assert span.tags["error.type"] == "Spark SQL Failed"
        childOf(span(0))
      }
      span {
        operationName "spark.job"
        spanType "spark"
        errored true
        childOf(span(1))
      }
    }
  }
}

protected jobStartEventWithSql(Integer jobId, Long time, ArrayList<Integer> stageIds, Long sqlExecutionId) {
  // Tag the job properties with the SQL execution id, mirroring what Spark does
  // when a job is spawned on behalf of a SQL query
  def properties = new Properties()
  properties.setProperty("spark.sql.execution.id", sqlExecutionId.toString())

  def stageInfos = stageIds.collect { createStageInfo(it) }

  new SparkListenerJobStart(
    jobId,
    time,
    JavaConverters.asScalaBuffer(stageInfos).toSeq(),
    properties
  )
}

protected validateRelativeError(double value, double expected, double relativeAccuracy) {
double relativeError = Math.abs(value - expected) / expected
assert relativeError < relativeAccuracy
Expand Down
Loading