1717
1818package org .apache .spark .sql .catalyst .expressions .aggregate
1919
20+ import org .apache .datasketches .common .Family
2021import org .apache .datasketches .memory .Memory
2122import org .apache .datasketches .theta .{CompactSketch , Intersection , SetOperation , Sketch , Union , UpdateSketch , UpdateSketchBuilder }
2223
2324import org .apache .spark .SparkUnsupportedOperationException
2425import org .apache .spark .sql .catalyst .InternalRow
2526import org .apache .spark .sql .catalyst .expressions .{ExpectsInputTypes , Expression , ExpressionDescription , Literal }
26- import org .apache .spark .sql .catalyst .expressions .aggregate .TypedImperativeAggregate
2727import org .apache .spark .sql .catalyst .trees .{BinaryLike , UnaryLike }
2828import org .apache .spark .sql .catalyst .util .{ArrayData , CollationFactory , ThetaSketchUtils }
2929import org .apache .spark .sql .errors .QueryExecutionErrors
@@ -59,10 +59,12 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
5959 *
6060 * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html ]] for more information.
6161 *
62- * @param left
62+ * @param child
6363 * child expression against which unique counting will occur
64- * @param right
64+ * @param lgNomEntriesExpr
6565 * the log-base-2 of nomEntries decides the number of buckets for the sketch
66+ * @param familyExpr
67+ * the family of the sketch (QUICKSELECT or ALPHA)
6668 * @param mutableAggBufferOffset
6769 * offset for mutable aggregation buffer
6870 * @param inputAggBufferOffset
@@ -71,46 +73,66 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
7173// scalastyle:off line.size.limit
7274@ ExpressionDescription (
7375 usage = """
74- _FUNC_(expr, lgNomEntries) - Returns the ThetaSketch compact binary representation.
76+ _FUNC_(expr, lgNomEntries, family ) - Returns the ThetaSketch compact binary representation.
7577 `lgNomEntries` (optional) is the log-base-2 of nominal entries, with nominal entries deciding
76- the number buckets or slots for the ThetaSketch. """ ,
78+ the number buckets or slots for the ThetaSketch.
79+ `family` (optional) is the sketch family, either 'QUICKSELECT' or 'ALPHA' (defaults to 'QUICKSELECT').
80+ Note: You can pass family as the second parameter to use default lgNomEntries with a specific family.""" ,
7781 examples = """
7882 Examples:
83+ > SELECT theta_sketch_estimate(_FUNC_(col)) FROM VALUES (1), (1), (2), (2), (3) tab(col);
84+ 3
7985 > SELECT theta_sketch_estimate(_FUNC_(col, 12)) FROM VALUES (1), (1), (2), (2), (3) tab(col);
8086 3
87+ > SELECT theta_sketch_estimate(_FUNC_(col, 'ALPHA')) FROM VALUES (1), (1), (2), (2), (3) tab(col);
88+ 3
89+ > SELECT theta_sketch_estimate(_FUNC_(col, 15, 'ALPHA')) FROM VALUES (1), (1), (2), (2), (3) tab(col);
90+ 3
8191 """ ,
8292 group = " agg_funcs" ,
8393 since = " 4.1.0" )
8494// scalastyle:on line.size.limit
8595case class ThetaSketchAgg (
86- left : Expression ,
87- right : Expression ,
96+ child : Expression ,
97+ lgNomEntriesExpr : Expression ,
98+ familyExpr : Expression ,
8899 override val mutableAggBufferOffset : Int ,
89100 override val inputAggBufferOffset : Int )
90101 extends TypedImperativeAggregate [ThetaSketchState ]
91- with BinaryLike [Expression ]
92102 with ExpectsInputTypes {
93103
94104 // ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.
95105
96- lazy val lgNomEntries : Int = {
97- val lgNomEntriesInput = right .eval().asInstanceOf [Int ]
106+ private lazy val lgNomEntries : Int = {
107+ val lgNomEntriesInput = lgNomEntriesExpr .eval().asInstanceOf [Int ]
98108 ThetaSketchUtils .checkLgNomLongs(lgNomEntriesInput, prettyName)
99109 lgNomEntriesInput
100110 }
101111
102- // Constructors
112+ private lazy val family : Family =
113+ ThetaSketchUtils .parseFamily(familyExpr.eval().asInstanceOf [UTF8String ].toString, prettyName)
103114
115+ // Constructors
104116 def this (child : Expression ) = {
105- this (child, Literal (ThetaSketchUtils .DEFAULT_LG_NOM_LONGS ), 0 , 0 )
117+ this (child,
118+ Literal (ThetaSketchUtils .DEFAULT_LG_NOM_LONGS ),
119+ Literal (UTF8String .fromString(ThetaSketchUtils .DEFAULT_FAMILY )),
120+ 0 , 0 )
106121 }
107122
108123 def this (child : Expression , lgNomEntries : Expression ) = {
109- this (child, lgNomEntries, 0 , 0 )
124+ this (child,
125+ lgNomEntries,
126+ Literal (UTF8String .fromString(ThetaSketchUtils .DEFAULT_FAMILY )),
127+ 0 , 0 )
128+ }
129+
130+ def this (child : Expression , lgNomEntries : Expression , family : Expression ) = {
131+ this (child, lgNomEntries, family, 0 , 0 )
110132 }
111133
112134 def this (child : Expression , lgNomEntries : Int ) = {
113- this (child, Literal (lgNomEntries), 0 , 0 )
135+ this (child, Literal (lgNomEntries))
114136 }
115137
116138 // Copy constructors required by ImperativeAggregate
@@ -122,15 +144,16 @@ case class ThetaSketchAgg(
122144 copy(inputAggBufferOffset = newInputAggBufferOffset)
123145
124146 override protected def withNewChildrenInternal (
125- newLeft : Expression ,
126- newRight : Expression ): ThetaSketchAgg =
127- copy(left = newLeft, right = newRight)
147+ newChildren : IndexedSeq [Expression ]): ThetaSketchAgg =
148+ copy(child = newChildren(0 ), lgNomEntriesExpr = newChildren(1 ), familyExpr = newChildren(2 ))
149+
150+ override def children : Seq [Expression ] = Seq (child, lgNomEntriesExpr, familyExpr)
128151
129152 // Overrides for TypedImperativeAggregate
130153
131154 override def prettyName : String = " theta_sketch_agg"
132155
133- override def inputTypes : Seq [AbstractDataType ] =
156+ override def inputTypes : Seq [AbstractDataType ] = {
134157 Seq (
135158 TypeCollection (
136159 ArrayType (IntegerType ),
@@ -141,21 +164,24 @@ case class ThetaSketchAgg(
141164 IntegerType ,
142165 LongType ,
143166 StringTypeWithCollation (supportsTrimCollation = true )),
144- IntegerType )
167+ IntegerType ,
168+ StringType )
169+ }
145170
146171 override def dataType : DataType = BinaryType
147172
148173 override def nullable : Boolean = false
149174
150175 /**
151- * Instantiate an UpdateSketch instance using the lgNomEntries param .
176+ * Instantiate an UpdateSketch instance using the lgNomEntries and family params .
152177 *
153178 * @return
154179 * an UpdateSketch instance wrapped with UpdatableSketchBuffer
155180 */
156181 override def createAggregationBuffer (): ThetaSketchState = {
157182 val builder = new UpdateSketchBuilder
158183 builder.setLogNominalEntries(lgNomEntries)
184+ builder.setFamily(family)
159185 UpdatableSketchBuffer (builder.build)
160186 }
161187
@@ -176,7 +202,7 @@ case class ThetaSketchAgg(
176202 */
177203 override def update (updateBuffer : ThetaSketchState , input : InternalRow ): ThetaSketchState = {
178204 // Return early for null values.
179- val v = left .eval(input)
205+ val v = child .eval(input)
180206 if (v == null ) return updateBuffer
181207
182208 // Initialized buffer should be UpdatableSketchBuffer, else error out.
@@ -186,7 +212,7 @@ case class ThetaSketchAgg(
186212 }
187213
188214 // Handle the different data types for sketch updates.
189- left .dataType match {
215+ child .dataType match {
190216 case ArrayType (IntegerType , _) =>
191217 val arr = v.asInstanceOf [ArrayData ].toIntArray()
192218 sketch.update(arr)
@@ -213,7 +239,7 @@ case class ThetaSketchAgg(
213239 case _ =>
214240 throw new SparkUnsupportedOperationException (
215241 errorClass = " _LEGACY_ERROR_TEMP_3121" ,
216- messageParameters = Map (" dataType" -> left .dataType.toString))
242+ messageParameters = Map (" dataType" -> child .dataType.toString))
217243 }
218244
219245 UpdatableSketchBuffer (sketch)
0 commit comments