Skip to content

Commit 13aa686

Browse files
committed
Add ability to support Aplpha sketches
1 parent bf2457b commit 13aa686

File tree

7 files changed

+182
-30
lines changed

7 files changed

+182
-30
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5605,6 +5605,12 @@
56055605
],
56065606
"sqlState" : "428EK"
56075607
},
5608+
"THETA_INVALID_FAMILY" : {
5609+
"message" : [
5610+
"Invalid call to <function>; the `family` parameter must be one of: <validFamilies>. Got: <value>."
5611+
],
5612+
"sqlState" : "22546"
5613+
},
56085614
"THETA_INVALID_INPUT_SKETCH_BUFFER" : {
56095615
"message" : [
56105616
"Invalid call to <function>; only valid Theta sketch buffers are supported as inputs (such as those produced by the `theta_sketch_agg` function)."

python/pyspark/sql/functions/builtin.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25941,10 +25941,12 @@ def hll_union(
2594125941
def theta_sketch_agg(
2594225942
col: "ColumnOrName",
2594325943
lgNomEntries: Optional[Union[int, Column]] = None,
25944+
family: Optional[str] = None,
2594425945
) -> Column:
2594525946
"""
2594625947
Aggregate function: returns the compact binary representation of the Datasketches
25947-
ThetaSketch with the values in the input column configured with lgNomEntries nominal entries.
25948+
ThetaSketch with the values in the input column configured with lgNomEntries nominal entries
25949+
and the specified sketch family.
2594825950

2594925951
.. versionadded:: 4.1.0
2595025952

@@ -25954,6 +25956,8 @@ def theta_sketch_agg(
2595425956
lgNomEntries : :class:`~pyspark.sql.Column` or int, optional
2595525957
The log-base-2 of nominal entries, where nominal entries is the size of the sketch
2595625958
(must be between 4 and 26, defaults to 12)
25959+
family : str, optional
25960+
The sketch family: 'QUICKSELECT' or 'ALPHA' (defaults to 'QUICKSELECT').
2595725961

2595825962
Returns
2595925963
-------
@@ -25986,12 +25990,23 @@ def theta_sketch_agg(
2598625990
+--------------------------------------------------+
2598725991
| 3|
2598825992
+--------------------------------------------------+
25993+
25994+
>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value", 15, "ALPHA"))).show()
25995+
+-------------------------------------------------------+
25996+
|theta_sketch_estimate(theta_sketch_agg(value, 15, AL..|
25997+
+-------------------------------------------------------+
25998+
| 3|
25999+
+-------------------------------------------------------+
2598926000
"""
2599026001
fn = "theta_sketch_agg"
25991-
if lgNomEntries is None:
26002+
if lgNomEntries is None and family is None:
2599226003
return _invoke_function_over_columns(fn, col)
25993-
else:
26004+
elif family is None:
2599426005
return _invoke_function_over_columns(fn, col, lit(lgNomEntries))
26006+
else:
26007+
if lgNomEntries is None:
26008+
lgNomEntries = 12 # default value
26009+
return _invoke_function_over_columns(fn, col, lit(lgNomEntries), lit(family))
2599526010

2599626011

2599726012
@_try_remote_functions

sql/api/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,17 @@ object functions {
11981198
def theta_sketch_agg(e: Column, lgNomEntries: Column): Column =
11991199
Column.fn("theta_sketch_agg", e, lgNomEntries)
12001200

1201+
/**
1202+
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
1203+
* built with the values in the input column and configured with the `lgNomEntries` nominal
1204+
* entries and `family`.
1205+
*
1206+
* @group agg_funcs
1207+
* @since 4.1.0
1208+
*/
1209+
def theta_sketch_agg(e: Column, lgNomEntries: Column, family: Column): Column =
1210+
Column.fn("theta_sketch_agg", e, lgNomEntries, family)
1211+
12011212
/**
12021213
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
12031214
* built with the values in the input column and configured with the `lgNomEntries` nominal
@@ -1242,6 +1253,47 @@ object functions {
12421253
def theta_sketch_agg(columnName: String): Column =
12431254
theta_sketch_agg(Column(columnName))
12441255

1256+
/**
1257+
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
1258+
* built with the values in the input column, configured with `lgNomEntries` and `family`.
1259+
*
1260+
* @group agg_funcs
1261+
* @since 4.1.0
1262+
*/
1263+
def theta_sketch_agg(e: Column, lgNomEntries: Int, family: String): Column =
1264+
Column.fn("theta_sketch_agg", e, lit(lgNomEntries), lit(family))
1265+
1266+
/**
1267+
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
1268+
* built with the values in the input column, configured with `lgNomEntries` and `family`.
1269+
*
1270+
* @group agg_funcs
1271+
* @since 4.1.0
1272+
*/
1273+
def theta_sketch_agg(columnName: String, lgNomEntries: Int, family: String): Column =
1274+
theta_sketch_agg(Column(columnName), lgNomEntries, family)
1275+
1276+
/**
1277+
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
1278+
* built with the values in the input column, configured with the specified `family` and default
1279+
* lgNomEntries.
1280+
*
1281+
* @group agg_funcs
1282+
* @since 4.1.0
1283+
*/
1284+
def theta_sketch_agg(e: Column, family: String): Column =
1285+
theta_sketch_agg(e, 12, family)
1286+
1287+
/**
1288+
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
1289+
* built with the values in the input column, configured with specified `family`.
1290+
*
1291+
* @group agg_funcs
1292+
* @since 4.1.0
1293+
*/
1294+
def theta_sketch_agg(columnName: String, family: String): Column =
1295+
theta_sketch_agg(columnName, 12, family)
1296+
12451297
/**
12461298
* Aggregate function: returns the compact binary representation of the Datasketches
12471299
* ThetaSketch, generated by the union of Datasketches ThetaSketch instances in the input column

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/thetasketchesAggregates.scala

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

20+
import org.apache.datasketches.common.Family
2021
import org.apache.datasketches.memory.Memory
2122
import org.apache.datasketches.theta.{CompactSketch, Intersection, SetOperation, Sketch, Union, UpdateSketch, UpdateSketchBuilder}
2223

2324
import org.apache.spark.SparkUnsupportedOperationException
2425
import org.apache.spark.sql.catalyst.InternalRow
2526
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
26-
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
2727
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
2828
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, ThetaSketchUtils}
2929
import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -59,10 +59,12 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
5959
*
6060
* See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information.
6161
*
62-
* @param left
62+
* @param child
6363
* child expression against which unique counting will occur
64-
* @param right
64+
* @param lgNomEntriesExpr
6565
* the log-base-2 of nomEntries decides the number of buckets for the sketch
66+
* @param familyExpr
67+
* the family of the sketch (QUICKSELECT or ALPHA)
6668
* @param mutableAggBufferOffset
6769
* offset for mutable aggregation buffer
6870
* @param inputAggBufferOffset
@@ -71,46 +73,66 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
7173
// scalastyle:off line.size.limit
7274
@ExpressionDescription(
7375
usage = """
74-
_FUNC_(expr, lgNomEntries) - Returns the ThetaSketch compact binary representation.
76+
_FUNC_(expr, lgNomEntries, family) - Returns the ThetaSketch compact binary representation.
7577
`lgNomEntries` (optional) is the log-base-2 of nominal entries, with nominal entries deciding
76-
the number buckets or slots for the ThetaSketch. """,
78+
the number buckets or slots for the ThetaSketch.
79+
`family` (optional) is the sketch family, either 'QUICKSELECT' or 'ALPHA' (defaults to 'QUICKSELECT').
80+
Note: You can pass family as the second parameter to use default lgNomEntries with a specific family.""",
7781
examples = """
7882
Examples:
83+
> SELECT theta_sketch_estimate(_FUNC_(col)) FROM VALUES (1), (1), (2), (2), (3) tab(col);
84+
3
7985
> SELECT theta_sketch_estimate(_FUNC_(col, 12)) FROM VALUES (1), (1), (2), (2), (3) tab(col);
8086
3
87+
> SELECT theta_sketch_estimate(_FUNC_(col, 'ALPHA')) FROM VALUES (1), (1), (2), (2), (3) tab(col);
88+
3
89+
> SELECT theta_sketch_estimate(_FUNC_(col, 15, 'ALPHA')) FROM VALUES (1), (1), (2), (2), (3) tab(col);
90+
3
8191
""",
8292
group = "agg_funcs",
8393
since = "4.1.0")
8494
// scalastyle:on line.size.limit
8595
case class ThetaSketchAgg(
86-
left: Expression,
87-
right: Expression,
96+
child: Expression,
97+
lgNomEntriesExpr: Expression,
98+
familyExpr: Expression,
8899
override val mutableAggBufferOffset: Int,
89100
override val inputAggBufferOffset: Int)
90101
extends TypedImperativeAggregate[ThetaSketchState]
91-
with BinaryLike[Expression]
92102
with ExpectsInputTypes {
93103

94104
// ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.
95105

96-
lazy val lgNomEntries: Int = {
97-
val lgNomEntriesInput = right.eval().asInstanceOf[Int]
106+
private lazy val lgNomEntries: Int = {
107+
val lgNomEntriesInput = lgNomEntriesExpr.eval().asInstanceOf[Int]
98108
ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName)
99109
lgNomEntriesInput
100110
}
101111

102-
// Constructors
112+
private lazy val family: Family =
113+
ThetaSketchUtils.parseFamily(familyExpr.eval().asInstanceOf[UTF8String].toString, prettyName)
103114

115+
// Constructors
104116
def this(child: Expression) = {
105-
this(child, Literal(ThetaSketchUtils.DEFAULT_LG_NOM_LONGS), 0, 0)
117+
this(child,
118+
Literal(ThetaSketchUtils.DEFAULT_LG_NOM_LONGS),
119+
Literal(UTF8String.fromString(ThetaSketchUtils.DEFAULT_FAMILY)),
120+
0, 0)
106121
}
107122

108123
def this(child: Expression, lgNomEntries: Expression) = {
109-
this(child, lgNomEntries, 0, 0)
124+
this(child,
125+
lgNomEntries,
126+
Literal(UTF8String.fromString(ThetaSketchUtils.DEFAULT_FAMILY)),
127+
0, 0)
128+
}
129+
130+
def this(child: Expression, lgNomEntries: Expression, family: Expression) = {
131+
this(child, lgNomEntries, family, 0, 0)
110132
}
111133

112134
def this(child: Expression, lgNomEntries: Int) = {
113-
this(child, Literal(lgNomEntries), 0, 0)
135+
this(child, Literal(lgNomEntries))
114136
}
115137

116138
// Copy constructors required by ImperativeAggregate
@@ -122,15 +144,16 @@ case class ThetaSketchAgg(
122144
copy(inputAggBufferOffset = newInputAggBufferOffset)
123145

124146
override protected def withNewChildrenInternal(
125-
newLeft: Expression,
126-
newRight: Expression): ThetaSketchAgg =
127-
copy(left = newLeft, right = newRight)
147+
newChildren: IndexedSeq[Expression]): ThetaSketchAgg =
148+
copy(child = newChildren(0), lgNomEntriesExpr = newChildren(1), familyExpr = newChildren(2))
149+
150+
override def children: Seq[Expression] = Seq(child, lgNomEntriesExpr, familyExpr)
128151

129152
// Overrides for TypedImperativeAggregate
130153

131154
override def prettyName: String = "theta_sketch_agg"
132155

133-
override def inputTypes: Seq[AbstractDataType] =
156+
override def inputTypes: Seq[AbstractDataType] = {
134157
Seq(
135158
TypeCollection(
136159
ArrayType(IntegerType),
@@ -141,21 +164,24 @@ case class ThetaSketchAgg(
141164
IntegerType,
142165
LongType,
143166
StringTypeWithCollation(supportsTrimCollation = true)),
144-
IntegerType)
167+
IntegerType,
168+
StringType)
169+
}
145170

146171
override def dataType: DataType = BinaryType
147172

148173
override def nullable: Boolean = false
149174

150175
/**
151-
* Instantiate an UpdateSketch instance using the lgNomEntries param.
176+
* Instantiate an UpdateSketch instance using the lgNomEntries and family params.
152177
*
153178
* @return
154179
* an UpdateSketch instance wrapped with UpdatableSketchBuffer
155180
*/
156181
override def createAggregationBuffer(): ThetaSketchState = {
157182
val builder = new UpdateSketchBuilder
158183
builder.setLogNominalEntries(lgNomEntries)
184+
builder.setFamily(family)
159185
UpdatableSketchBuffer(builder.build)
160186
}
161187

@@ -176,7 +202,7 @@ case class ThetaSketchAgg(
176202
*/
177203
override def update(updateBuffer: ThetaSketchState, input: InternalRow): ThetaSketchState = {
178204
// Return early for null values.
179-
val v = left.eval(input)
205+
val v = child.eval(input)
180206
if (v == null) return updateBuffer
181207

182208
// Initialized buffer should be UpdatableSketchBuffer, else error out.
@@ -186,7 +212,7 @@ case class ThetaSketchAgg(
186212
}
187213

188214
// Handle the different data types for sketch updates.
189-
left.dataType match {
215+
child.dataType match {
190216
case ArrayType(IntegerType, _) =>
191217
val arr = v.asInstanceOf[ArrayData].toIntArray()
192218
sketch.update(arr)
@@ -213,7 +239,7 @@ case class ThetaSketchAgg(
213239
case _ =>
214240
throw new SparkUnsupportedOperationException(
215241
errorClass = "_LEGACY_ERROR_TEMP_3121",
216-
messageParameters = Map("dataType" -> left.dataType.toString))
242+
messageParameters = Map("dataType" -> child.dataType.toString))
217243
}
218244

219245
UpdatableSketchBuffer(sketch)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ThetaSketchUtils.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717

1818
package org.apache.spark.sql.catalyst.util
1919

20-
import org.apache.datasketches.common.SketchesArgumentException
20+
import java.util.Locale
21+
22+
import org.apache.datasketches.common.{Family, SketchesArgumentException}
2123
import org.apache.datasketches.memory.{Memory, MemoryBoundsException}
2224
import org.apache.datasketches.theta.CompactSketch
2325

2426
import org.apache.spark.sql.errors.QueryExecutionErrors
2527

28+
2629
object ThetaSketchUtils {
2730
/*
2831
* Bounds copied from DataSketches' ThetaUtil. These define the valid range for lgNomEntries,
@@ -36,6 +39,11 @@ object ThetaSketchUtils {
3639
final val MAX_LG_NOM_LONGS = 26
3740
final val DEFAULT_LG_NOM_LONGS = 12
3841

42+
// Family constants for ThetaSketch
43+
final val FAMILY_QUICKSELECT = "QUICKSELECT"
44+
final val FAMILY_ALPHA = "ALPHA"
45+
final val DEFAULT_FAMILY = FAMILY_QUICKSELECT
46+
3947
/**
4048
* Validates the lgNomLongs parameter for Theta sketch size. Throws a Spark SQL exception if the
4149
* value is out of bounds.
@@ -53,6 +61,26 @@ object ThetaSketchUtils {
5361
}
5462
}
5563

64+
/**
65+
* Converts a family string to DataSketches Family enum.
66+
* Throws a Spark SQL exception if the family name is invalid.
67+
*
68+
* @param familyName The family name string
69+
* @param prettyName The display name of the function/expression for error messages
70+
* @return The corresponding DataSketches Family enum value
71+
*/
72+
def parseFamily(familyName: String, prettyName: String): Family = {
73+
familyName.toUpperCase(Locale.ROOT) match {
74+
case FAMILY_QUICKSELECT => Family.QUICKSELECT
75+
case FAMILY_ALPHA => Family.ALPHA
76+
case _ =>
77+
throw QueryExecutionErrors.thetaInvalidFamily(
78+
function = prettyName,
79+
value = familyName,
80+
validFamilies = Seq(FAMILY_QUICKSELECT, FAMILY_ALPHA))
81+
}
82+
}
83+
5684
/**
5785
* Wraps a byte array into a DataSketches CompactSketch object.
5886
* This method safely deserializes a compact Theta sketch from its binary representation,

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3136,4 +3136,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
31363136
"max" -> toSQLValue(max, IntegerType),
31373137
"value" -> toSQLValue(value, IntegerType)))
31383138
}
3139+
3140+
def thetaInvalidFamily(function: String, value: String, validFamilies: Seq[String]): Throwable = {
3141+
new SparkRuntimeException(
3142+
errorClass = "THETA_INVALID_FAMILY",
3143+
messageParameters = Map(
3144+
"function" -> toSQLId(function),
3145+
"value" -> toSQLValue(value, StringType),
3146+
"validFamilies" -> validFamilies.map(f => toSQLId(f)).mkString(", ")))
3147+
}
31393148
}

0 commit comments

Comments
 (0)