diff --git a/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogMetricsSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogMetricsSpec.scala
new file mode 100644
index 000000000..aff190ecb
--- /dev/null
+++ b/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogMetricsSpec.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.datastax.spark.connector.datasource
+
+import scala.collection.mutable
+import com.datastax.spark.connector._
+import com.datastax.spark.connector.cluster.DefaultCluster
+import com.datastax.spark.connector.cql.CassandraConnector
+import org.scalatest.BeforeAndAfterEach
+import com.datastax.spark.connector.datasource.CassandraCatalog
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import com.datastax.spark.connector.cql.CassandraConnector
+import org.apache.spark.sql.SparkSession
+
+
+class CassandraCatalogMetricsSpec extends SparkCassandraITFlatSpecBase with DefaultCluster with BeforeAndAfterEach {
+
+  override lazy val conn = CassandraConnector(defaultConf)
+
+  override lazy val spark = SparkSession.builder()
+    .config(sparkConf
+      // Enable Codahale/Dropwizard metrics
+      .set("spark.metrics.conf.executor.source.cassandra-connector.class", "org.apache.spark.metrics.CassandraConnectorSource")
+      .set("spark.metrics.conf.driver.source.cassandra-connector.class", "org.apache.spark.metrics.CassandraConnectorSource")
+      .set("spark.sql.sources.useV1SourceList", "")
+      .set("spark.sql.defaultCatalog", "cassandra")
+      .set("spark.sql.catalog.cassandra", classOf[CassandraCatalog].getCanonicalName)
+    )
+    .withExtensions(new CassandraSparkExtensions).getOrCreate().newSession()
+
+  override def beforeClass {
+    conn.withSessionDo { session =>
+      session.execute(s"CREATE KEYSPACE IF NOT EXISTS $ks WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }")
+      session.execute(s"CREATE TABLE IF NOT EXISTS $ks.leftjoin (key INT, x INT, PRIMARY KEY (key))")
+      for (i <- 1 to 1000 * 10) {
+        session.execute(s"INSERT INTO $ks.leftjoin (key, x) values ($i, $i)")
+      }
+    }
+  }
+
+  var readRowCount: Long = 0
+  var readByteCount: Long = 0
+
+  it should "update Codahale read metrics for SELECT queries" in {
+    val df = spark.sql(s"SELECT x FROM $ks.leftjoin LIMIT 2")
+    val metricsRDD = df.queryExecution.toRdd.mapPartitions { iter =>
+      val tc = org.apache.spark.TaskContext.get()
+      val source = org.apache.spark.metrics.MetricsUpdater.getSource(tc)
+      Iterator((source.get.readRowMeter.getCount, source.get.readByteMeter.getCount))
+    }
+
+    val metrics = metricsRDD.collect()
+    readRowCount = metrics.map(_._1).sum - readRowCount
+    readByteCount = metrics.map(_._2).sum - readByteCount
+
+    assert(readRowCount > 0)
+    assert(readByteCount == readRowCount * 4) // 4 bytes per INT result
+  }
+
+  it should "update Codahale read metrics for COUNT queries" in {
+    val df = spark.sql(s"SELECT COUNT(*) FROM $ks.leftjoin")
+    val metricsRDD = df.queryExecution.toRdd.mapPartitions { iter =>
+      val tc = org.apache.spark.TaskContext.get()
+      val source = org.apache.spark.metrics.MetricsUpdater.getSource(tc)
+      Iterator((source.get.readRowMeter.getCount, source.get.readByteMeter.getCount))
+    }
+
+    val metrics = metricsRDD.collect()
+    readRowCount = metrics.map(_._1).sum - readRowCount
+    readByteCount = metrics.map(_._2).sum - readByteCount
+
+    assert(readRowCount > 0)
+    assert(readByteCount == readRowCount * 8) // 8 bytes per COUNT result
+  }
+}
diff --git a/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogTableReadSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogTableReadSpec.scala
index 5cfed4bb7..8b51cbe1e 100644
--- a/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogTableReadSpec.scala
+++ b/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogTableReadSpec.scala
@@ -78,17 +78,17 @@ class CassandraCatalogTableReadSpec extends CassandraCatalogSpecBase {
   it should "handle count pushdowns" in {
     setupBasicTable()
     val request = spark.sql(s"""SELECT COUNT(*) from $defaultKs.$testTable""")
-    val reader = request
+    var factory = request
       .queryExecution
       .executedPlan
       .collectFirst {
-        case batchScanExec: BatchScanExec=> batchScanExec.readerFactory.createReader(EmptyInputPartition)
+        case batchScanExec: BatchScanExec=> batchScanExec.readerFactory
         case adaptiveSparkPlanExec: AdaptiveSparkPlanExec =>
           adaptiveSparkPlanExec.executedPlan.collectLeaves().collectFirst{
-            case batchScanExec: BatchScanExec=> batchScanExec.readerFactory.createReader(EmptyInputPartition)
+            case batchScanExec: BatchScanExec=> batchScanExec.readerFactory
           }.get
       }
-    reader.get.isInstanceOf[CassandraCountPartitionReader] should be (true)
+    factory.get.asInstanceOf[CassandraScanPartitionReaderFactory].isCountQuery should be (true)
     request.collect()(0).get(0) should be (101)
   }
 
diff --git a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraInJoinReaderFactory.scala b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraInJoinReaderFactory.scala
index ce67c2dd5..2dc663de5 100644
--- a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraInJoinReaderFactory.scala
+++ b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraInJoinReaderFactory.scala
@@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.sources.In
 import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.metrics.InputMetricsUpdater
+import org.apache.spark.TaskContext
 
 import scala.util.{Failure, Success}
 
@@ -80,6 +82,7 @@ abstract class CassandraBaseInJoinReader(
   protected val maybeRateLimit = JoinHelper.maybeRateLimit(readConf)
   protected val requestsPerSecondRateLimiter = JoinHelper.requestsPerSecondRateLimiter(readConf)
 
+  protected val metricsUpdater = InputMetricsUpdater(TaskContext.get(), readConf)
   protected def pairWithRight(left: CassandraRow): SettableFuture[Iterator[(CassandraRow, InternalRow)]] = {
     val resultFuture = SettableFuture.create[Iterator[(CassandraRow, InternalRow)]]
     val leftSide = Iterator.continually(left)
@@ -87,9 +90,10 @@ abstract class CassandraBaseInJoinReader(
     queryExecutor.executeAsync(bsb.bind(left).executeAs(readConf.executeAs)).onComplete {
       case Success(rs) =>
         val resultSet = new PrefetchingResultSetIterator(rs)
+        val iteratorWithMetrics = resultSet.map(metricsUpdater.updateMetrics)
         /* This is a much less than ideal place to actually rate limit, we are buffering these futures
         this means we will most likely exceed our threshold*/
-        val throttledIterator = resultSet.map(maybeRateLimit)
+        val throttledIterator = iteratorWithMetrics.map(maybeRateLimit)
         val rightSide = throttledIterator.map(rowReader.read(_, rowMetadata))
         resultFuture.set(leftSide.zip(rightSide))
       case Failure(throwable) =>
@@ -121,6 +125,7 @@ abstract class CassandraBaseInJoinReader(
   override def get(): InternalRow = currentRow
 
   override def close(): Unit = {
+    metricsUpdater.finish()
     session.close()
   }
 }
diff --git a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraScanPartitionReaderFactory.scala b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraScanPartitionReaderFactory.scala
index 104137727..8244e2f1a 100644
--- a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraScanPartitionReaderFactory.scala
+++ b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraScanPartitionReaderFactory.scala
@@ -30,6 +30,8 @@ import com.datastax.spark.connector.util.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.metrics.InputMetricsUpdater
+import org.apache.spark.TaskContext
 
 case class CassandraScanPartitionReaderFactory(
   connector: CassandraConnector,
@@ -38,10 +40,12 @@ case class CassandraScanPartitionReaderFactory(
   readConf: ReadConf,
   queryParts: CqlQueryParts) extends PartitionReaderFactory {
 
+  def isCountQuery: Boolean = queryParts.selectedColumnRefs.contains(RowCountRef)
+
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     val cassandraPartition = partition.asInstanceOf[CassandraPartition[Any, _ <: Token[Any]]]
 
-    if (queryParts.selectedColumnRefs.contains(RowCountRef)) {
+    if (isCountQuery) {
       //Count Pushdown
       CassandraCountPartitionReader(
         connector,
@@ -79,6 +83,8 @@ abstract class CassandraPartitionReaderBase
   protected val rowIterator = getIterator()
   protected var lastRow: InternalRow = InternalRow()
 
+  protected val metricsUpdater = InputMetricsUpdater(TaskContext.get(), readConf)
+
   override def next(): Boolean = {
     if (rowIterator.hasNext) {
       lastRow = rowIterator.next()
@@ -91,6 +97,7 @@ abstract class CassandraPartitionReaderBase
   override def get(): InternalRow = lastRow
 
   override def close(): Unit = {
+    metricsUpdater.finish()
     scanner.close()
   }
 
@@ -125,7 +132,8 @@ abstract class CassandraPartitionReaderBase
     tokenRanges.iterator.flatMap { range =>
       val scanResult = ScanHelper.fetchTokenRange(scanner, tableDef, queryParts, range, readConf.consistencyLevel, readConf.fetchSizeInRows)
       val meta = scanResult.metadata
-      scanResult.rows.map(rowReader.read(_, meta))
+      val iteratorWithMetrics = scanResult.rows.map(metricsUpdater.updateMetrics)
+      iteratorWithMetrics.map(rowReader.read(_, meta))
     }
   }
 
diff --git a/connector/src/main/scala/com/datastax/spark/connector/datasource/CasssandraDriverDataWriterFactory.scala b/connector/src/main/scala/com/datastax/spark/connector/datasource/CasssandraDriverDataWriterFactory.scala
index 8fb63b69c..7b4ed59fb 100644
--- a/connector/src/main/scala/com/datastax/spark/connector/datasource/CasssandraDriverDataWriterFactory.scala
+++ b/connector/src/main/scala/com/datastax/spark/connector/datasource/CasssandraDriverDataWriterFactory.scala
@@ -25,6 +25,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory
 import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, WriterCommitMessage}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.metrics.OutputMetricsUpdater
+import org.apache.spark.TaskContext
 
 case class CassandraDriverDataWriterFactory(
   connector: CassandraConnector,
@@ -54,22 +56,31 @@ case class CassandraDriverDataWriter(
 
   private val columns = SomeColumns(inputSchema.fieldNames.map(name => ColumnName(name)): _*)
 
-  private val writer =
+  private val metricsUpdater = OutputMetricsUpdater(TaskContext.get(), writeConf)
+
+  private val asyncWriter =
     TableWriter(connector, tableDef, columns, writeConf, false)(unsafeRowWriterFactory)
       .getAsyncWriter()
 
+  private val writer = asyncWriter.copy(
+    successHandler = Some(metricsUpdater.batchFinished(success = true, _, _, _)),
+    failureHandler = Some(metricsUpdater.batchFinished(success = false, _, _, _)))
+
   override def write(record: InternalRow): Unit = writer.write(record)
 
   override def commit(): WriterCommitMessage = {
+    metricsUpdater.finish()
     writer.close()
     CassandraCommitMessage()
   }
 
   override def abort(): Unit = {
+    metricsUpdater.finish()
     writer.close()
   }
 
   override def close(): Unit = {
+    metricsUpdater.finish()
     //Our proxy Session Handler handles double closes by ignoring them so this is fine
     writer.close()
   }
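
A quick way to exercise the new metrics end-to-end is sketched below; it is illustrative only and not part of the patch. The application name, contact point, and keyspace/table are placeholders, while the CassandraConnectorSource class, the metrics/catalog configuration keys, and the CassandraCatalog class name are taken from the diff above (they mirror the setup in CassandraCatalogMetricsSpec).

import org.apache.spark.sql.SparkSession

object ConnectorMetricsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("connector-metrics-example")                      // placeholder name
      .config("spark.cassandra.connection.host", "127.0.0.1")    // placeholder contact point
      // Register the connector's Codahale/Dropwizard source on driver and executors,
      // the same settings the new integration test uses.
      .config("spark.metrics.conf.driver.source.cassandra-connector.class",
        "org.apache.spark.metrics.CassandraConnectorSource")
      .config("spark.metrics.conf.executor.source.cassandra-connector.class",
        "org.apache.spark.metrics.CassandraConnectorSource")
      // Route SQL through the V2 catalog so the patched readers/writers are exercised.
      .config("spark.sql.sources.useV1SourceList", "")
      .config("spark.sql.defaultCatalog", "cassandra")
      .config("spark.sql.catalog.cassandra",
        "com.datastax.spark.connector.datasource.CassandraCatalog")
      .getOrCreate()

    // Any V2 scan now updates readRowMeter / readByteMeter on the executors;
    // writes feed the output meters through OutputMetricsUpdater.
    spark.sql("SELECT COUNT(*) FROM ks.leftjoin").show()          // placeholder keyspace/table
  }
}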