diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala index 555ccbae9..3fda7ed12 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala @@ -512,7 +512,8 @@ class FetcherTest extends TestCase { endDs: String, namespace: String, consistencyCheck: Boolean, - dropDsOnWrite: Boolean): Unit = { + dropDsOnWrite: Boolean, + misTypeKeys: Boolean = true): Unit = { implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true) val tableUtils = TableUtils(spark) @@ -564,7 +565,7 @@ class FetcherTest extends TestCase { val lagMs = -100000 val laggedRequests = buildRequests(lagMs) val laggedResponseDf = - FetcherTestUtil.joinResponses(spark, laggedRequests, mockApi, samplePercent = 5, logToHive = true)._2 + FetcherTestUtil.joinResponses(spark, laggedRequests, mockApi, samplePercent = 5, logToHive = true, misTypeKeys = misTypeKeys)._2 val correctedLaggedResponse = laggedResponseDf .withColumn("ts_lagged", laggedResponseDf.col("ts_millis") + lagMs) .withColumn("ts_millis", col("ts_lagged")) @@ -595,13 +596,13 @@ class FetcherTest extends TestCase { |""".stripMargin) } // benchmark - FetcherTestUtil.joinResponses(spark, requests, mockApi, runCount = 10, useJavaFetcher = true) - FetcherTestUtil.joinResponses(spark, requests, mockApi, runCount = 10) + FetcherTestUtil.joinResponses(spark, requests, mockApi, runCount = 10, useJavaFetcher = true, misTypeKeys = misTypeKeys) + FetcherTestUtil.joinResponses(spark, requests, mockApi, runCount = 10, misTypeKeys = misTypeKeys) // comparison val columns = endDsExpected.schema.fields.map(_.name) val responseRows: Seq[Row] = - FetcherTestUtil.joinResponses(spark, requests, mockApi, useJavaFetcher = true, debug = true)._1.map { res => + FetcherTestUtil.joinResponses(spark, requests, mockApi, useJavaFetcher = true, debug = true, misTypeKeys = misTypeKeys)._1.map { res => val all: Map[String, AnyRef] = res.request.keys ++ res.values.get ++ @@ -657,18 +658,18 @@ class FetcherTest extends TestCase { ) joinConf.setDerivations(derivations.toJava) - compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true) + compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true, misTypeKeys = false) } def testTemporalFetchJoinDerivationRenameOnly(): Unit = { - val namespace = "derivation_fetch" + val namespace = "derivation_fetch_rename_only" val joinConf = generateMutationData(namespace) val derivations = Seq(Builders.Derivation(name = "*", expression = "*"), Builders.Derivation(name = "listing_id_renamed", expression = "listing_id") ) joinConf.setDerivations(derivations.toJava) - compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true) + compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true, misTypeKeys = false) } @@ -724,7 +725,8 @@ object FetcherTestUtil { runCount: Int = 1, samplePercent: Double = -1, logToHive: Boolean = false, - debug: Boolean = false)(implicit ec: ExecutionContext): (List[Response], DataFrame) = { + debug: Boolean = false, + misTypeKeys: Boolean = true)(implicit ec: ExecutionContext): (List[Response], DataFrame) = { val chunkSize = 100 @transient lazy val fetcher = mockApi.buildFetcher(debug) @transient lazy val javaFetcher = mockApi.buildJavaFetcher() @@ -736,7 +738,15 @@ object FetcherTestUtil { val result = requests.iterator .grouped(chunkSize) .map { oldReqs => - val r = oldReqs + // deliberately mis-type a few keys + val r = if (misTypeKeys) { + oldReqs + .map(r => + r.copy(keys = r.keys.mapValues { v => + if (v.isInstanceOf[java.lang.Long]) v.toString else v + }.toMap)) + } else oldReqs + val responses = if (useJavaFetcher) { // Converting to java request and using the toScalaRequest functionality to test conversion val convertedJavaRequests = r.map(new JavaRequest(_)).toJava @@ -753,7 +763,11 @@ object FetcherTestUtil { fetcher.fetchJoin(r) } - System.currentTimeMillis() -> responses + // fix mis-typed keys in the request + val fixedResponses = if (misTypeKeys) { + responses.map(resps => resps.zip(oldReqs).map { case (resp, req) => resp.copy(request = req) }) + } else responses + System.currentTimeMillis() -> fixedResponses } .flatMap { case (start, future) =>