 */
 package org.apache.spark.sql.crossdata.execution

+import java.io._
+
 import com.stratio.common.utils.components.logger.impl.Slf4jLoggerComponent
 import com.stratio.crossdata.common.profiler.PerformanceLogger
 import com.stratio.crossdata.connector.NativeScan
 import com.stratio.crossdata.metrics.{MetricsGlossary, MetricsRegister}
+import org.apache.spark.SparkEnv
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, GetMapValue, GetStructField, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, GetMapValue, GetStructField, Literal, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.crossdata.execution.command.XDExplainCommand
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.crossdata.serializers.CustomStreamedRow
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.{DeserializeToObjectExec, InputAdapter, LocalLimitExec, MapPartitionsExec, SerializeFromObjectExec, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.types.IntegerType

+import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Success, Try}

 case class XDPlan(@transient xdQueryExecution: XDQueryExecution,
@@ -31,6 +38,8 @@ case class XDPlan(@transient xdQueryExecution: XDQueryExecution,

   private lazy val nativeQueryExecutor: Option[NativeScan] = findNativeQueryExecutor(analyzedPlan, isNativeQueriesEnabled)

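+  // Feature flag gating the incremental-limit execution path added below; it defaults to false
+  // (see the getBoolean call), in which case executeCollect keeps the original child.executeCollect() behaviour.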
+  private lazy val applyXDLimitRule = xdQueryExecution.sparkSession.sparkContext.conf.getBoolean("spark.sql.crossdata.limitRule", false)
+
   private lazy val usablePlan: QueryPlan[_] =
     if (nativeQueryExecutor.exists(x => supportedPlan(x, analyzedPlan))) {
       analyzedPlan
@@ -81,6 +90,147 @@ case class XDPlan(@transient xdQueryExecution: XDQueryExecution,
     }
   }

+  /**
+   * Decodes the byte arrays back into UnsafeRows and puts them into a buffer.
+   *
+   * NOTE: Most of this code is copied from [[SparkPlan]].decodeUnsafeRows(bytes).
+   */
+  private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = {
+    val nFields = schema.length
+
+    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+    val bis = new ByteArrayInputStream(bytes)
+    val ins = new DataInputStream(codec.compressedInputStream(bis))
+
+    new Iterator[InternalRow] {
+      private var sizeOfNextRow = ins.readInt()
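+      // A negative length is the end-of-stream sentinel written by getByteArrayRdd below.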
+      override def hasNext: Boolean = sizeOfNextRow >= 0
+      override def next(): InternalRow = {
+        val bs = new Array[Byte](sizeOfNextRow)
+        ins.readFully(bs)
+        val row = new UnsafeRow(nFields)
+        row.pointTo(bs, sizeOfNextRow)
+        sizeOfNextRow = ins.readInt()
+        row
+      }
+    }
+  }
+
+  /**
+   * Packs the UnsafeRows into byte arrays for faster serialization.
+   * The byte arrays are in the following format:
+   * [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
+   *
+   * UnsafeRow is highly compressible (at least 8 bytes for any column), so the byte array is also
+   * compressed.
+   *
+   * NOTE: Most of this code is copied from [[SparkPlan]].getByteArrayRdd(n).
+   *
+   * @param executedPlan usable executed plan.
+   * @param limit global limit; -1 means no limit.
+   * @return an RDD with one compressed byte array per partition.
+   */
+  private def getByteArrayRdd(executedPlan: SparkPlan, limit: Int = -1): RDD[Array[Byte]] = {
+    executedPlan.execute().mapPartitionsInternal { iter =>
+      var count = 0
+      val buffer = new Array[Byte](4 << 10) // 4K
+      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+      val bos = new ByteArrayOutputStream()
+      val out = new DataOutputStream(codec.compressedOutputStream(bos))
+      while (iter.hasNext && (limit < 0 || count < limit)) {
+        val row = iter.next().asInstanceOf[UnsafeRow]
+        out.writeInt(row.getSizeInBytes)
+        row.writeToStream(out, buffer)
+        count += 1
+      }
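+      // A trailing -1 marks the end of this partition's rows for decodeUnsafeRows.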
+      out.writeInt(-1)
+      out.flush()
+      out.close()
+      Iterator(bos.toByteArray)
+    }
+  }
+
+  /**
+   * Fetches data partition by partition from a specific plan until the number of rows reaches a specific limit.
+   *
+   * NOTE: Most of this code is copied from [[SparkPlan.executeTake]](n).
+   *
+   * @param executedPlan usable executed plan.
+   * @param limit global limit.
+   * @return executed plan result.
+   */
+  private def incrementalExecute(executedPlan: SparkPlan, limit: Int): Array[InternalRow] = {
+    if (limit == 0) {
+      return new Array[InternalRow](0)
+    }
+
+    import org.json4s._
+    import org.json4s.jackson.JsonMethods._
+    import org.json4s.jackson.Serialization._
+    implicit val formats: Formats = DefaultFormats
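+    // json4s is only needed for the truncation case below, where the JSON payload carried inside a
+    // collected row is re-parsed, cut down to the remaining limit and re-serialized.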
+
+    val buf = new ArrayBuffer[InternalRow]
+    val childRDD = getByteArrayRdd(executedPlan, limit)
+    val totalParts = childRDD.partitions.length
+    var partsScanned = 0
+    var numOfRows = 0
+    while (numOfRows < limit && partsScanned < totalParts) {
+      // The number of partitions to try in this iteration. It is ok for this number to be
+      // greater than totalParts because we actually cap it at totalParts in runJob.
+      var numPartsToTry = 1L
+      if (partsScanned > 0) {
+        // If we didn't find any rows after the previous iteration, quadruple and retry.
+        // Otherwise, interpolate the number of partitions we need to try, but overestimate
+        // it by 50%. We also cap the estimation in the end.
+        val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2)
+        if (buf.isEmpty) {
+          numPartsToTry = partsScanned * limitScaleUpFactor
+        } else {
+          // the left side of max is >=1 whenever partsScanned >= 2
+          numPartsToTry = Math.max((1.5 * limit * partsScanned / buf.size).toInt - partsScanned, 1)
+          numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
+        }
+      }
+
+      val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
+      val sc = sqlContext.sparkContext
+
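+      // On the first pass the already-built childRDD is reused; afterwards the partitions are re-encoded
+      // with the remaining row budget so executors can stop writing rows early.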
+      val jobRDD = if (numOfRows == 0) childRDD else getByteArrayRdd(executedPlan, limit - numOfRows)
+
+      val res = sc.runJob(jobRDD, (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty[Byte], p)
+
+      res.foreach { pres =>
+        val iter = if (numOfRows < limit) {
+          decodeUnsafeRows(pres)
+        } else {
+          Iterator.empty
+        }
+
+        while (iter.hasNext && numOfRows < limit) {
+          val row = iter.next()
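+          // Each decoded row appears to be a batch wrapper: field 0 holds the number of wrapped elements
+          // and field 1 a JSON string with the serialized CustomStreamedRow values (see the else branch).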
+          val numElements = row.getInt(0)
+
+          if (numOfRows + numElements <= limit) {
+            numOfRows += numElements
+            buf += row
+          } else {
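+            // Taking the whole batch would overshoot the limit: rebuild a smaller batch row that keeps
+            // only the remaining rows, re-serializing the truncated JSON payload through the
+            // SerializeFromObjectExec serializer found in the plan.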
+            collectFirst {
+              case SerializeFromObjectExec(serializer, _) =>
+                val remainingRows = limit - numOfRows
+                val projection = UnsafeProjection.create(serializer)
+                projection.initialize(0)
+                numOfRows += remainingRows
+                buf += projection(InternalRow.apply((remainingRows, write((parse(row.getString(1)).extract[List[CustomStreamedRow]]).take(remainingRows)))))
+            }
+          }
+        }
+      }
+      partsScanned += p.size
+    }
+    buf.toArray
+  }
+
   override def executeCollect(): Array[InternalRow] = {
     nativeQueryExecutor match {
       case Some(nqe) =>
@@ -109,7 +259,13 @@ case class XDPlan(@transient xdQueryExecution: XDQueryExecution,
       case None =>
         logPerformance("[XDPlan][executeSpark]") {
           MetricsRegister.countExecution(MetricsGlossary.Counter.`current_queries_spark_total`, discountWhenFinished = true) {
-            Try(child.executeCollect())
+            child match {
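+              // This shape appears to correspond to a whole-stage-codegen'd plan with a streamed
+              // map-partitions stage over a local limit; only then, and only when
+              // spark.sql.crossdata.limitRule is enabled, is the partition-by-partition execution used.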
+              case WholeStageCodegenExec(SerializeFromObjectExec(_, InputAdapter(MapPartitionsExec(_, _, DeserializeToObjectExec(_, _, WholeStageCodegenExec(LocalLimitExec(limit, _))))))) if applyXDLimitRule =>
+                logger.debug("Resolving query with incremental execution")
+                Try(incrementalExecute(child, limit))
+              case _ =>
+                Try(child.executeCollect())
+            }
           }
         } match {
           case Success(r) => r