Skip to content

Commit

Permalink
re-enable mistype keys and make it optional
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaohui sun authored and xiaohui sun committed Feb 12, 2025
1 parent cd9d1a5 commit 72a33db
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,8 @@ class FetcherTest extends TestCase {
endDs: String,
namespace: String,
consistencyCheck: Boolean,
dropDsOnWrite: Boolean): Unit = {
dropDsOnWrite: Boolean,
misTypeKeys: Boolean = true): Unit = {
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
val tableUtils = TableUtils(spark)
Expand Down Expand Up @@ -564,7 +565,7 @@ class FetcherTest extends TestCase {
val lagMs = -100000
val laggedRequests = buildRequests(lagMs)
val laggedResponseDf =
FetcherTestUtil.joinResponses(spark, laggedRequests, mockApi, samplePercent = 5, logToHive = true)._2
FetcherTestUtil.joinResponses(spark, laggedRequests, mockApi, samplePercent = 5, logToHive = true, misTypeKeys = misTypeKeys)._2
val correctedLaggedResponse = laggedResponseDf
.withColumn("ts_lagged", laggedResponseDf.col("ts_millis") + lagMs)
.withColumn("ts_millis", col("ts_lagged"))
Expand Down Expand Up @@ -595,13 +596,13 @@ class FetcherTest extends TestCase {
|""".stripMargin)
}
// benchmark
FetcherTestUtil.joinResponses(spark, requests, mockApi, runCount = 10, useJavaFetcher = true)
FetcherTestUtil.joinResponses(spark, requests, mockApi, runCount = 10)
FetcherTestUtil.joinResponses(spark, requests, mockApi, runCount = 10, useJavaFetcher = true, misTypeKeys = misTypeKeys)
FetcherTestUtil.joinResponses(spark, requests, mockApi, runCount = 10, misTypeKeys = misTypeKeys)

// comparison
val columns = endDsExpected.schema.fields.map(_.name)
val responseRows: Seq[Row] =
FetcherTestUtil.joinResponses(spark, requests, mockApi, useJavaFetcher = true, debug = true)._1.map { res =>
FetcherTestUtil.joinResponses(spark, requests, mockApi, useJavaFetcher = true, debug = true, misTypeKeys = misTypeKeys)._1.map { res =>
val all: Map[String, AnyRef] =
res.request.keys ++
res.values.get ++
Expand Down Expand Up @@ -657,18 +658,18 @@ class FetcherTest extends TestCase {
)
joinConf.setDerivations(derivations.toJava)

compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true, misTypeKeys = false)
}

def testTemporalFetchJoinDerivationRenameOnly(): Unit = {
val namespace = "derivation_fetch"
val namespace = "derivation_fetch_rename_only"
val joinConf = generateMutationData(namespace)
val derivations = Seq(Builders.Derivation(name = "*", expression = "*"),
Builders.Derivation(name = "listing_id_renamed", expression = "listing_id")
)
joinConf.setDerivations(derivations.toJava)

compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true, misTypeKeys = false)
}


Expand Down Expand Up @@ -724,7 +725,8 @@ object FetcherTestUtil {
runCount: Int = 1,
samplePercent: Double = -1,
logToHive: Boolean = false,
debug: Boolean = false)(implicit ec: ExecutionContext): (List[Response], DataFrame) = {
debug: Boolean = false,
misTypeKeys: Boolean = true)(implicit ec: ExecutionContext): (List[Response], DataFrame) = {
val chunkSize = 100
@transient lazy val fetcher = mockApi.buildFetcher(debug)
@transient lazy val javaFetcher = mockApi.buildJavaFetcher()
Expand All @@ -736,7 +738,15 @@ object FetcherTestUtil {
val result = requests.iterator
.grouped(chunkSize)
.map { oldReqs =>
val r = oldReqs
// deliberately mis-type keys: send every java.lang.Long key value as its String form
val r = if (misTypeKeys) {
oldReqs
.map(r =>
r.copy(keys = r.keys.mapValues { v =>
if (v.isInstanceOf[java.lang.Long]) v.toString else v
}.toMap))
} else oldReqs

val responses = if (useJavaFetcher) {
// Converting to java request and using the toScalaRequest functionality to test conversion
val convertedJavaRequests = r.map(new JavaRequest(_)).toJava
Expand All @@ -753,7 +763,11 @@ object FetcherTestUtil {
fetcher.fetchJoin(r)
}

System.currentTimeMillis() -> responses
// restore the original request (with its original key types) on each response by zipping with oldReqs
val fixedResponses = if (misTypeKeys) {
responses.map(resps => resps.zip(oldReqs).map { case (resp, req) => resp.copy(request = req) })
} else responses
System.currentTimeMillis() -> fixedResponses
}
.flatMap {
case (start, future) =>
Expand Down

0 comments on commit 72a33db

Please sign in to comment.