diff --git a/dd-java-agent/instrumentation/spark/spark-common/src/main/java/datadog/trace/instrumentation/spark/AbstractDatadogSparkListener.java b/dd-java-agent/instrumentation/spark/spark-common/src/main/java/datadog/trace/instrumentation/spark/AbstractDatadogSparkListener.java
index 7fe4acf402e..df483513598 100644
--- a/dd-java-agent/instrumentation/spark/spark-common/src/main/java/datadog/trace/instrumentation/spark/AbstractDatadogSparkListener.java
+++ b/dd-java-agent/instrumentation/spark/spark-common/src/main/java/datadog/trace/instrumentation/spark/AbstractDatadogSparkListener.java
@@ -114,6 +114,7 @@ public abstract class AbstractDatadogSparkListener extends SparkListener {
   private final HashMap stageSpans = new HashMap<>();
   private final HashMap stageToJob = new HashMap<>();
+  private final HashMap<Integer, Long> jobToSqlExecution = new HashMap<>();
   private final HashMap stageProperties = new HashMap<>();
 
   private final SparkAggregatedTaskMetrics applicationMetrics = new SparkAggregatedTaskMetrics();
@@ -139,6 +140,9 @@ public abstract class AbstractDatadogSparkListener extends SparkListener {
   private boolean lastJobFailed = false;
   private String lastJobFailedMessage;
   private String lastJobFailedStackTrace;
+  private boolean lastSqlFailed = false;
+  private String lastSqlFailedMessage;
+  private String lastSqlFailedStackTrace;
   private int jobCount = 0;
   private int currentExecutorCount = 0;
   private int maxExecutorCount = 0;
@@ -356,6 +360,11 @@ public synchronized void finishApplication(
       applicationSpan.setTag(DDTags.ERROR_TYPE, "Spark Application Failed");
       applicationSpan.setTag(DDTags.ERROR_MSG, lastJobFailedMessage);
       applicationSpan.setTag(DDTags.ERROR_STACK, lastJobFailedStackTrace);
+    } else if (lastSqlFailed) {
+      applicationSpan.setError(true);
+      applicationSpan.setTag(DDTags.ERROR_TYPE, "Spark SQL Failed");
+      applicationSpan.setTag(DDTags.ERROR_MSG, lastSqlFailedMessage);
+      applicationSpan.setTag(DDTags.ERROR_STACK, lastSqlFailedStackTrace);
     }
 
     applicationMetrics.setSpanMetrics(applicationSpan);
@@ -513,6 +522,9 @@ public synchronized void onJobStart(SparkListenerJobStart jobStart) {
     for (int stageId : getSparkJobStageIds(jobStart)) {
       stageToJob.put(stageId, jobStart.jobId());
     }
+    if (sqlExecutionId != null) {
+      jobToSqlExecution.put(jobStart.jobId(), sqlExecutionId);
+    }
     jobSpans.put(jobStart.jobId(), jobSpan);
     notifyOl(x -> openLineageSparkListener.onJobStart(x), jobStart);
   }
@@ -524,6 +536,8 @@ public synchronized void onJobEnd(SparkListenerJobEnd jobEnd) {
       return;
     }
 
+    Long sqlExecutionId = jobToSqlExecution.remove(jobEnd.jobId());
+
     if (jobEnd.jobResult() instanceof JobFailed) {
       JobFailed jobFailed = (JobFailed) jobEnd.jobResult();
       Exception exception = jobFailed.exception();
@@ -536,6 +550,18 @@ public synchronized void onJobEnd(SparkListenerJobEnd jobEnd) {
       jobSpan.setTag(DDTags.ERROR_STACK, errorStackTrace);
       jobSpan.setTag(DDTags.ERROR_TYPE, "Spark Job Failed");
 
+      // On Spark 3.4+, onSQLExecutionEnd may overwrite this with the authoritative
+      // errorMessage() field; this is intentional, as the SQL-level error is more precise.
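+      // The SQL span is only looked up here, not removed: it stays in sqlSpans
+      // until onSQLExecutionEnd, which owns its lifecycle and finish time.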
+      if (sqlExecutionId != null) {
+        AgentSpan sqlSpan = sqlSpans.get(sqlExecutionId);
+        if (sqlSpan != null) {
+          sqlSpan.setError(true);
+          sqlSpan.setErrorMessage(errorMessage);
+          sqlSpan.setTag(DDTags.ERROR_STACK, errorStackTrace);
+          sqlSpan.setTag(DDTags.ERROR_TYPE, "Spark SQL Failed");
+        }
+      }
+
       // Only propagate the error to the application if it is not a cancellation
       if (errorMessage != null && !errorMessage.toLowerCase().contains("cancelled")) {
         lastJobFailed = true;
@@ -842,6 +868,9 @@ private <T> void notifyOl(Consumer<T> ol, T event) {
   private static final MethodHandle adaptiveExecutionIdMethod;
   private static final MethodHandle adaptiveSparkPlanMethod;
 
+  // Spark 3.4+ added errorMessage() to SparkListenerSQLExecutionEnd (SPARK-41827)
+  private static final MethodHandle sqlEndErrorMessageMethod;
+
   @SuppressForbidden // Using reflection to avoid splitting the instrumentation once more
   private static Class<?> findAdaptiveExecutionUpdateClass() throws ClassNotFoundException {
     return Class.forName(
@@ -869,6 +898,18 @@ private static Class<?> findAdaptiveExecutionUpdateClass() throws ClassNotFoundE
     adaptiveExecutionUpdateClass = executionUpdateClass;
     adaptiveExecutionIdMethod = executionIdMethod;
     adaptiveSparkPlanMethod = sparkPlanMethod;
+
+    MethodHandle errorMessageMethod = null;
+    try {
+      MethodHandles.Lookup lookup = MethodHandles.lookup();
+      errorMessageMethod =
+          lookup.findVirtual(
+              SparkListenerSQLExecutionEnd.class,
+              "errorMessage",
+              MethodType.methodType(scala.Option.class));
+    } catch (NoSuchMethodException | IllegalAccessException ignored) {
+    }
+    sqlEndErrorMessageMethod = errorMessageMethod;
   }
 
   private synchronized void updateAdaptiveSQLPlan(SparkListenerEvent event) {
@@ -891,6 +932,13 @@ private synchronized void onSQLExecutionStart(SparkListenerSQLExecutionStart sql
 
   private synchronized void onSQLExecutionEnd(SparkListenerSQLExecutionEnd sqlEnd) {
     AgentSpan span = sqlSpans.remove(sqlEnd.executionId());
+
+    // AnalysisException case: no job ran, so the span was never created
+    if (span == null && sqlQueries.containsKey(sqlEnd.executionId())) {
+      span = getOrCreateSqlSpan(sqlEnd.executionId(), null, null);
+      sqlSpans.remove(sqlEnd.executionId());
+    }
+
     SparkAggregatedTaskMetrics metrics = sqlMetrics.remove(sqlEnd.executionId());
     sqlQueries.remove(sqlEnd.executionId());
     sqlPlans.remove(sqlEnd.executionId());
@@ -899,12 +947,44 @@ private synchronized void onSQLExecutionEnd(SparkListenerSQLExecutionEnd sqlEnd)
       if (metrics != null) {
         metrics.setSpanMetrics(span);
       }
+
+      String errorMsg = getSqlEndErrorMessage(sqlEnd);
+      if (errorMsg != null) {
+        String sqlErrorMessage = getErrorMessageWithoutStackTrace(errorMsg);
+        span.setError(true);
+        span.setErrorMessage(sqlErrorMessage);
+        span.setTag(DDTags.ERROR_STACK, errorMsg);
+        span.setTag(DDTags.ERROR_TYPE, "Spark SQL Failed");
+
+        if (sqlErrorMessage == null || !sqlErrorMessage.toLowerCase().contains("cancelled")) {
+          lastSqlFailed = true;
+          lastSqlFailedMessage = sqlErrorMessage;
+          lastSqlFailedStackTrace = errorMsg;
+        }
+      } else {
+        lastSqlFailed = false;
+      }
+
       notifyOl(x -> openLineageSparkListener.onOtherEvent(x), sqlEnd);
       span.finish(sqlEnd.time() * 1000);
     }
   }
 
+  private static String getSqlEndErrorMessage(SparkListenerSQLExecutionEnd sqlEnd) {
+    if (sqlEndErrorMessageMethod == null) {
+      return null;
+    }
+    try {
+      scala.Option errorMessage = (scala.Option) sqlEndErrorMessageMethod.invoke(sqlEnd);
+      if (errorMessage.isDefined()) {
+        return (String) errorMessage.get();
+      }
+    } catch (Throwable ignored) {
+    }
+    return null;
+  }
+
   private synchronized void onStreamingQueryStartedEvent(
       StreamingQueryListener.QueryStartedEvent event) {
     if (streamingQueries.size() > MAX_COLLECTION_SIZE) {
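The static initializer above resolves errorMessage() once per class load and treats a
failed lookup as "running on Spark older than 3.4". For reference, a self-contained
sketch of the same probe idiom, outside the patch: the wrapper class
SqlEndErrorMessageProbe and its errorMessageOf helper are hypothetical names, and
scala-library is assumed to be on the classpath.

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

final class SqlEndErrorMessageProbe {
  // Resolved once; stays null on Spark versions where errorMessage() does not exist.
  private static final MethodHandle ERROR_MESSAGE;

  static {
    MethodHandle handle = null;
    try {
      handle =
          MethodHandles.lookup()
              .findVirtual(
                  Class.forName(
                      "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd"),
                  "errorMessage",
                  MethodType.methodType(scala.Option.class));
    } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException ignored) {
      // Feature absent: callers will see null from errorMessageOf().
    }
    ERROR_MESSAGE = handle;
  }

  /** Returns the SQL-level error message, or null when absent or unsupported. */
  static String errorMessageOf(Object sqlEnd) {
    if (ERROR_MESSAGE == null) {
      return null;
    }
    try {
      scala.Option<?> maybe = (scala.Option<?>) ERROR_MESSAGE.invoke(sqlEnd);
      return maybe.isDefined() ? (String) maybe.get() : null;
    } catch (Throwable ignored) {
      return null;
    }
  }
}

Caching the MethodHandle in a static final field keeps the per-event cost to a single
invoke, which matters for a listener that observes every Spark event.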
diff --git a/dd-java-agent/instrumentation/spark/spark-common/src/testFixtures/groovy/datadog/trace/instrumentation/spark/AbstractSparkListenerTest.groovy b/dd-java-agent/instrumentation/spark/spark-common/src/testFixtures/groovy/datadog/trace/instrumentation/spark/AbstractSparkListenerTest.groovy
index 16bd10d6ec1..e1f501df2c3 100644
--- a/dd-java-agent/instrumentation/spark/spark-common/src/testFixtures/groovy/datadog/trace/instrumentation/spark/AbstractSparkListenerTest.groovy
+++ b/dd-java-agent/instrumentation/spark/spark-common/src/testFixtures/groovy/datadog/trace/instrumentation/spark/AbstractSparkListenerTest.groovy
@@ -24,6 +24,9 @@ import org.apache.spark.scheduler.StageInfo
 import org.apache.spark.scheduler.TaskInfo
 import org.apache.spark.scheduler.TaskLocality
 import org.apache.spark.scheduler.cluster.ExecutorInfo
+import org.apache.spark.sql.execution.SparkPlanInfo
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart
 import org.apache.spark.storage.RDDInfo
 import scala.Option
 import scala.collection.immutable.HashMap
@@ -652,6 +655,109 @@ abstract class AbstractSparkListenerTest extends InstrumentationSpecification {
       .getOption("spark.openlineage.circuitBreaker.timeoutInSeconds") == Option.apply("120")
   }
 
+  protected sqlExecutionStartEvent(Long executionId, Long time, String description = "SELECT * FROM test") {
+    def emptySeq = JavaConverters.asScalaBuffer([]).toSeq()
+    def emptyMap = new scala.collection.immutable.HashMap()
+    def sparkPlanInfo = new SparkPlanInfo("TestPlan", "TestPlan", emptySeq, emptyMap, emptySeq)
+
+    return new SparkListenerSQLExecutionStart(
+      executionId,
+      description,
+      "details",
+      "physical plan",
+      sparkPlanInfo,
+      time
+      )
+  }
+
+  protected sqlExecutionEndEvent(Long executionId, Long time) {
+    return new SparkListenerSQLExecutionEnd(executionId, time)
+  }
+
+  def "test SQL span created when no job runs"() {
+    setup:
+    def listener = getTestDatadogSparkListener()
+    listener.onApplicationStart(applicationStartEvent(1000L))
+    listener.onOtherEvent(sqlExecutionStartEvent(1L, 2000L, "SELECT * FROM missing_table"))
+    listener.onOtherEvent(sqlExecutionEndEvent(1L, 3000L))
+    listener.onApplicationEnd(new SparkListenerApplicationEnd(4000L))
+
+    expect:
+    assertTraces(1) {
+      trace(2) {
+        span {
+          operationName "spark.application"
+          resourceName "spark.application"
+          spanType "spark"
+          parent()
+        }
+        span {
+          operationName "spark.sql"
+          resourceName "SELECT * FROM missing_table"
+          spanType "spark"
+          errored false
+          childOf(span(0))
+        }
+      }
+    }
+  }
+
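+  // Exercises the new onJobEnd path: the job below carries
+  // spark.sql.execution.id = 1, so its failure should be copied onto the
+  // still-open spark.sql span and propagated up to the application span.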
"SELECT * FROM bad_table" + spanType "spark" + errored true + assert span.tags["error.type"] == "Spark SQL Failed" + childOf(span(0)) + } + span { + operationName "spark.job" + spanType "spark" + errored true + childOf(span(1)) + } + } + } + } + + protected jobStartEventWithSql(Integer jobId, Long time, ArrayList stageIds, Long sqlExecutionId) { + def stageInfos = stageIds.collect { stageId -> + createStageInfo(stageId) + } + def props = new Properties() + props.setProperty("spark.sql.execution.id", sqlExecutionId.toString()) + + return new SparkListenerJobStart( + jobId, + time, + JavaConverters.asScalaBuffer(stageInfos).toSeq(), + props + ) + } + protected validateRelativeError(double value, double expected, double relativeAccuracy) { double relativeError = Math.abs(value - expected) / expected assert relativeError < relativeAccuracy