diff --git a/.gitignore b/.gitignore index 70bc562daf..c42394c6ed 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ *.logs *.iml *.db + .idea/ .ijwb/ **/local_warehouse/ @@ -127,3 +128,4 @@ MODULE.bazel* # mill build output out/** +agents.md diff --git a/api/src/main/scala/ai/chronon/api/Extensions.scala b/api/src/main/scala/ai/chronon/api/Extensions.scala index 3651411819..855b89ecf1 100644 --- a/api/src/main/scala/ai/chronon/api/Extensions.scala +++ b/api/src/main/scala/ai/chronon/api/Extensions.scala @@ -596,6 +596,16 @@ object Extensions { groupBy.sources.toScala .find(_.topic != null) + def commonConfValue(key: String): Option[String] = { + for { + metaData <- Option(groupBy.metaData) + execInfo <- Option(metaData.executionInfo) + conf <- Option(execInfo.conf) + common <- Option(conf.common) + value <- Option(common.get(key)).map(_.trim.toLowerCase).filter(_.nonEmpty) + } yield value + } + // de-duplicate all columns necessary for aggregation in a deterministic order // so we use distinct instead of toSet here def aggregationInputs: Array[String] = diff --git a/cloud_aws/src/main/scala/ai/chronon/integrations/aws/AwsApiImpl.scala b/cloud_aws/src/main/scala/ai/chronon/integrations/aws/AwsApiImpl.scala index a1e893a00a..9179cebc2d 100644 --- a/cloud_aws/src/main/scala/ai/chronon/integrations/aws/AwsApiImpl.scala +++ b/cloud_aws/src/main/scala/ai/chronon/integrations/aws/AwsApiImpl.scala @@ -36,7 +36,7 @@ class AwsApiImpl(conf: Map[String, String]) extends Api(conf) { } override def genKvStore: KVStore = { - new DynamoDBKVStoreImpl(ddbClient) + new DynamoDBKVStoreImpl(ddbClient, conf) } /** The stream decoder method in the AwsApi is currently unimplemented. This needs to be implemented before @@ -48,15 +48,16 @@ class AwsApiImpl(conf: Map[String, String]) extends Api(conf) { * a fully functional Chronon serving stack in Aws * @return */ - override def externalRegistry: ExternalSourceRegistry = ??? + @transient lazy val registry: ExternalSourceRegistry = new ExternalSourceRegistry() + override def externalRegistry: ExternalSourceRegistry = registry /** The logResponse method is currently unimplemented. We'll need to implement this prior to bringing up the * fully functional serving stack in Aws which includes logging feature responses to a stream for OOC */ - override def logResponse(resp: LoggableResponse): Unit = ??? + override def logResponse(resp: LoggableResponse): Unit = () override def genMetricsKvStore(tableBaseName: String): KVStore = { - new DynamoDBKVStoreImpl(ddbClient) + new DynamoDBKVStoreImpl(ddbClient, conf) } override def genEnhancedStatsKvStore(tableBaseName: String): KVStore = ??? 
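A minimal usage sketch for the commonConfValue helper added to Extensions.scala above, assuming the Extensions ops are imported for the GroupBy in hand. The key string and the parquet fallback mirror how GroupByUpload consumes this helper later in this PR; the wrapper name uploadFormat is illustrative only.

// Illustrative only: read a per-GroupBy setting from metaData.executionInfo.conf.common.
// commonConfValue trims and lower-cases the value, so callers can compare against "ion" directly.
import ai.chronon.api.Extensions._

def uploadFormat(groupBy: ai.chronon.api.GroupBy): String =
  groupBy.commonConfValue("spark.chronon.table_write.upload.format").getOrElse("parquet")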
diff --git a/cloud_aws/src/main/scala/ai/chronon/integrations/aws/DynamoDBKVStoreImpl.scala b/cloud_aws/src/main/scala/ai/chronon/integrations/aws/DynamoDBKVStoreImpl.scala index 964d00e170..7e75fe61b4 100644 --- a/cloud_aws/src/main/scala/ai/chronon/integrations/aws/DynamoDBKVStoreImpl.scala +++ b/cloud_aws/src/main/scala/ai/chronon/integrations/aws/DynamoDBKVStoreImpl.scala @@ -4,6 +4,7 @@ import ai.chronon.api.Constants import ai.chronon.api.Constants.{ContinuationKey, ListLimit} import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.online.KVStore +import ai.chronon.spark.{IonPathConfig, IonWriter} import ai.chronon.online.KVStore.GetResponse import ai.chronon.online.KVStore.ListRequest import ai.chronon.online.KVStore.ListResponse @@ -29,6 +30,14 @@ import software.amazon.awssdk.services.dynamodb.model.ResourceNotFoundException import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType import software.amazon.awssdk.services.dynamodb.model.ScanRequest import software.amazon.awssdk.services.dynamodb.model.ScanResponse +import software.amazon.awssdk.services.dynamodb.model.ImportTableRequest +import software.amazon.awssdk.services.dynamodb.model.InputFormat +import software.amazon.awssdk.services.dynamodb.model.InputCompressionType +import software.amazon.awssdk.services.dynamodb.model.S3BucketSource +import software.amazon.awssdk.services.dynamodb.model.TableCreationParameters +import software.amazon.awssdk.services.dynamodb.model.BillingMode +import software.amazon.awssdk.services.dynamodb.model.DescribeImportRequest +import software.amazon.awssdk.services.dynamodb.model.ImportStatus import java.time.Instant import java.util @@ -57,7 +66,11 @@ object DynamoDBKVStoreConstants { val defaultWriteCapacityUnits = 10L } -class DynamoDBKVStoreImpl(dynamoDbClient: DynamoDbClient) extends KVStore { +/** Exception thrown when attempting to create a DynamoDB table that already exists */ +case class TableAlreadyExistsException(tableName: String) + extends RuntimeException(s"DynamoDB table '$tableName' already exists") + +class DynamoDBKVStoreImpl(dynamoDbClient: DynamoDbClient, conf: Map[String, String] = Map.empty) extends KVStore { import DynamoDBKVStoreConstants._ protected val metricsContext: Metrics.Context = Metrics.Context(Metrics.Environment.KVStore).withSuffix("dynamodb") @@ -104,7 +117,9 @@ class DynamoDBKVStoreImpl(dynamoDbClient: DynamoDbClient) extends KVStore { logger.info(s"Table created successfully! Details: \n${tableDescription.toString}") metricsContext.increment("create.successes") } catch { - case _: ResourceInUseException => logger.info(s"Table: $dataset already exists") + case _: ResourceInUseException => + logger.info(s"Table: $dataset already exists") + throw TableAlreadyExistsException(dataset) case e: Exception => logger.error(s"Error creating Dynamodb table: $dataset", e) metricsContext.increment("create.failures") @@ -218,10 +233,102 @@ class DynamoDBKVStoreImpl(dynamoDbClient: DynamoDbClient) extends KVStore { Future.sequence(futureResponses) } - /** Implementation of bulkPut is currently a TODO for the DynamoDB store. This involves transforming the underlying - * Parquet data to Amazon's Ion format + swapping out old table for new (as bulkLoad only writes to new tables) + /** Bulk loads data from S3 Ion files into DynamoDB using the ImportTable API. + * + * The Ion files are expected to have been written by IonWriter during GroupByUpload. 
+ * The S3 location is determined by IonWriter.resolvePartitionPath using: + * - Root path from config: spark.chronon.table_write.upload.location + * - Dataset name: sourceOfflineTable (e.g., namespace.groupby_v1__upload) + * - Partition column and value: {partitionColumn}={partition} (defaults to ds) + * + * Full path: {rootPath}/{sourceOfflineTable}/{partitionColumn}={partition}/ */ - override def bulkPut(sourceOfflineTable: String, destinationOnlineDataSet: String, partition: String): Unit = ??? + override def bulkPut(sourceOfflineTable: String, destinationOnlineDataSet: String, partition: String): Unit = { + val rootPath = conf.get(IonPathConfig.UploadLocationKey) + val partitionColumn = conf.getOrElse(IonPathConfig.PartitionColumnKey, IonPathConfig.DefaultPartitionColumn) + + // Use shared IonWriter path resolution to ensure consistency between producer and consumer + val path = IonWriter.resolvePartitionPath(sourceOfflineTable, partitionColumn, partition, rootPath) + val s3Source = toS3BucketSource(path) + logger.info(s"Starting DynamoDB import for table: $destinationOnlineDataSet from S3: $s3Source") + + val tableParams = TableCreationParameters.builder() + .tableName(destinationOnlineDataSet) + .keySchema( + KeySchemaElement.builder().attributeName(partitionKeyColumn).keyType(KeyType.HASH).build() + ) + .attributeDefinitions( + AttributeDefinition.builder().attributeName(partitionKeyColumn).attributeType(ScalarAttributeType.B).build() + ) + .billingMode(BillingMode.PAY_PER_REQUEST) + .build() + + val importRequest = ImportTableRequest.builder() + .s3BucketSource(s3Source) + .inputFormat(InputFormat.ION) + .inputCompressionType(InputCompressionType.NONE) + .tableCreationParameters(tableParams) + .build() + + try { + val startTs = System.currentTimeMillis() + val importResponse = dynamoDbClient.importTable(importRequest) + val importArn = importResponse.importTableDescription().importArn() + + logger.info(s"DynamoDB import initiated with ARN: $importArn for table: $destinationOnlineDataSet") + + // Wait for import to complete + waitForImportCompletion(importArn, destinationOnlineDataSet) + + val duration = System.currentTimeMillis() - startTs + logger.info(s"DynamoDB import completed for table: $destinationOnlineDataSet in ${duration}ms") + metricsContext.increment("bulkPut.successes") + metricsContext.distribution("bulkPut.latency", duration) + } catch { + case e: Exception => + logger.error(s"Failed to import data to DynamoDB table: $destinationOnlineDataSet", e) + metricsContext.increment("bulkPut.failures") + throw e + } + } + + /** Converts a Hadoop Path to an S3BucketSource for DynamoDB ImportTable. */ + private def toS3BucketSource(path: org.apache.hadoop.fs.Path): S3BucketSource = { + val uri = path.toUri + S3BucketSource.builder() + .s3Bucket(uri.getHost) + .s3KeyPrefix(uri.getPath.stripPrefix("/") + "/") + .build() + }
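A worked example of how the shared path resolution above lines up between the Spark producer and this import path; the bucket, table and date values are hypothetical, and resolvePartitionPath itself is added to IonWriter later in this PR.

// Hypothetical values for illustration only.
val path = IonWriter.resolvePartitionPath(
  dataSetName = "my_team.logins_v1__upload",        // sourceOfflineTable
  partitionColumn = "ds",
  partitionValue = "2025-10-17",
  rootPath = Some("s3://my-chronon-bucket/uploads")  // spark.chronon.table_write.upload.location
)
// path                   => s3://my-chronon-bucket/uploads/my_team.logins_v1__upload/ds=2025-10-17
// toS3BucketSource(path) => bucket "my-chronon-bucket", key prefix "uploads/my_team.logins_v1__upload/ds=2025-10-17/"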
+ + /** Waits for a DynamoDB import to complete by polling the import status. */ + private def waitForImportCompletion(importArn: String, tableName: String): Unit = { + val maxWaitTimeMs = 30 * 60 * 1000L // 30 minutes + val pollIntervalMs = 10 * 1000L // 10 seconds + val startTime = System.currentTimeMillis() + + var status: ImportStatus = ImportStatus.IN_PROGRESS + while (status == ImportStatus.IN_PROGRESS && (System.currentTimeMillis() - startTime) < maxWaitTimeMs) { + Thread.sleep(pollIntervalMs) + + val describeRequest = DescribeImportRequest.builder().importArn(importArn).build() + val describeResponse = dynamoDbClient.describeImport(describeRequest) + status = describeResponse.importTableDescription().importStatus() + + logger.info(s"DynamoDB import status for $tableName: $status") + } + + status match { + case ImportStatus.COMPLETED => + logger.info(s"DynamoDB import completed successfully for table: $tableName") + case ImportStatus.FAILED | ImportStatus.CANCELLED => + throw new RuntimeException(s"DynamoDB import failed with status: $status for table: $tableName") + case ImportStatus.IN_PROGRESS => + throw new RuntimeException(s"DynamoDB import timed out after ${maxWaitTimeMs}ms for table: $tableName") + case _ => + logger.warn(s"Unknown import status: $status for table: $tableName") + } + } private def getCapacityUnits(props: Map[String, Any], key: String, defaultValue: Long): Long = { props.get(key) match { diff --git a/flink/package.mill b/flink/package.mill index 04abd4ac14..e1985038f0 100644 --- a/flink/package.mill +++ b/flink/package.mill @@ -77,4 +77,14 @@ trait FlinkModule extends Cross.Module[String] with build.BaseModule { mvn"com.fasterxml.jackson.core:jackson-annotations:2.15.2" ) } + + override def assemblyRules = super.assemblyRules ++ Seq( + mill.scalalib.Assembly.Rule.Relocate("org.apache.flink.kinesis.shaded.**", "shaded.@1") + ) + + override def upstreamAssemblyClasspath = Task { + super.upstreamAssemblyClasspath() ++ + Task.traverse(Seq(build.flink_connectors))(_.localClasspath)().flatten + } + } \ No newline at end of file diff --git a/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala b/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala index e4304e1f3f..ba722c02cc 100644 --- a/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala +++ b/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala @@ -37,6 +37,11 @@ abstract class BaseFlinkJob { def groupByServingInfoParsed: GroupByServingInfoParsed + /** Run the streaming job without tiling (direct processing). + * Events are processed and written directly to KV store without windowing/buffering. + */ + def runGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] + /** Run the streaming job with tiling enabled (default mode). * This is the main execution method that should be implemented by subclasses.
*/ @@ -223,7 +228,8 @@ object FlinkJob { FlinkJob.runWriteInternalManifestJob(env, jobArgs.streamingManifestPath(), maybeParentJobId.get, groupByName) } - val jobDatastream = flinkJob.runTiledGroupByJob(env) + // val jobDatastream = flinkJob.runTiledGroupByJob(env) + val jobDatastream = flinkJob.runGroupByJob(env) jobDatastream .addSink(new MetricsSink(groupByName)) diff --git a/flink/src/main/scala/ai/chronon/flink/chaining/ChainedGroupByJob.scala b/flink/src/main/scala/ai/chronon/flink/chaining/ChainedGroupByJob.scala index 65fc0a0a08..388a279816 100644 --- a/flink/src/main/scala/ai/chronon/flink/chaining/ChainedGroupByJob.scala +++ b/flink/src/main/scala/ai/chronon/flink/chaining/ChainedGroupByJob.scala @@ -92,6 +92,16 @@ class ChainedGroupByJob(eventSrc: FlinkSource[ProjectedEvent], } } + /** Untiled mode is not yet implemented for ChainedGroupByJob (JoinSource GroupBys). + * Use runTiledGroupByJob instead. + */ + override def runGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = { + throw new NotImplementedError( + s"Untiled mode is not implemented for ChainedGroupByJob (JoinSource GroupBys). " + + s"GroupBy: $groupByName uses JoinSource and only supports tiled mode." + ) + } + /** Build the tiled version of the Flink GroupBy job that chains features using a JoinSource. * The operators are structured as follows: * - Source: Read from Kafka topic into ProjectedEvent stream diff --git a/flink/src/main/scala/ai/chronon/flink/deser/LoginsSerDe.scala b/flink/src/main/scala/ai/chronon/flink/deser/LoginsSerDe.scala new file mode 100644 index 0000000000..732064d714 --- /dev/null +++ b/flink/src/main/scala/ai/chronon/flink/deser/LoginsSerDe.scala @@ -0,0 +1,108 @@ +package ai.chronon.flink.deser + +import ai.chronon.api.{StructField => ZStructField, StructType => ZStructType, LongType => ZLongType, DoubleType => ZDoubleType, StringType => ZStringType} +import ai.chronon.online.TopicInfo +import ai.chronon.online.serde.{Mutation, SerDe} + +import scala.util.Try + +/** SerDe for the logins events when messages are simple JSON objects with fields: + * - ts: Long (event time in millis) + * - event_id: String + * - user_id: String (or numeric convertible to String) + * - login_method: String + * - device_type: String + * - ip_address: String + * + * If you use Avro on the wire, consider wrapping an AvroSerDe with your Avro schema instead. 
+ */ +class LoginsSerDe(topicInfo: TopicInfo) extends SerDe { + + private val zSchema: ZStructType = ZStructType( + "logins_event", + Array( + ZStructField("ts", ZLongType), + ZStructField("event_id", ZStringType), + ZStructField("user_id", ZStringType), + ZStructField("login_method", ZStringType), + ZStructField("device_type", ZStringType), + ZStructField("ip_address", ZStringType) + ) + ) + + override def schema: ZStructType = zSchema + + override def fromBytes(bytes: Array[Byte]): Mutation = { + val json = new String(bytes, java.nio.charset.StandardCharsets.UTF_8).trim + val parsed = parseFlatJson(json) + val row: Array[Any] = Array[Any]( + parsed.getOrElse("ts", null).asInstanceOf[java.lang.Long], + toStringOrNull(parsed.get("event_id")), + toStringOrNull(parsed.get("user_id")), + toStringOrNull(parsed.get("login_method")), + toStringOrNull(parsed.get("device_type")), + toStringOrNull(parsed.get("ip_address")) + ) + Mutation(schema, null, row) + } + + private def toStringOrNull(v: Option[Any]): String = v match { + case Some(s: String) => s + case Some(n: java.lang.Number) => String.valueOf(n) + case Some(other) => String.valueOf(other) + case None => null + } + + /** Minimal, dependency-free flat JSON parser for simple key-value objects. + * Accepts numbers, strings and booleans; strings must be quoted. + * Not suitable for nested objects or arrays. + */ + private def parseFlatJson(input: String): Map[String, Any] = { + // strip outer braces + val s = input.dropWhile(_ != '{').drop(1).reverse.dropWhile(_ != '}').drop(1).reverse + if (s.trim.isEmpty) return Map.empty + val parts = splitTopLevel(s) + parts.flatMap { kv => + val idx = kv.indexOf(":") + if (idx <= 0) None + else { + val key = unquote(kv.substring(0, idx).trim) + val raw = kv.substring(idx + 1).trim + Some(key -> parseValue(raw)) + } + }.toMap + } + + private def splitTopLevel(s: String): Seq[String] = { + val buf = new StringBuilder + var inString = false + var esc = false + val out = scala.collection.mutable.ArrayBuffer.empty[String] + s.foreach { ch => + if (esc) { buf.append(ch); esc = false } + else ch match { + case '\\' if inString => buf.append(ch); esc = true + case '"' => inString = !inString; buf.append(ch) + case ',' if !inString => out += buf.result(); buf.clear() + case c => buf.append(c) + } + } + val last = buf.result().trim + if (last.nonEmpty) out += last + out.toSeq + } + + private def unquote(s: String): String = { + val t = s.trim + if (t.startsWith("\"") && t.endsWith("\"")) t.substring(1, t.length - 1) else t + } + + private def parseValue(raw: String): Any = { + if (raw == "null") null + else if (raw == "true" || raw == "false") java.lang.Boolean.valueOf(raw) + else if (raw.startsWith("\"") && raw.endsWith("\"")) unquote(raw) + else Try(java.lang.Long.valueOf(raw)).orElse(Try(java.lang.Double.valueOf(raw))).getOrElse(raw) + } +} + + diff --git a/flink/src/main/scala/ai/chronon/flink/source/FlinkSourceProvider.scala b/flink/src/main/scala/ai/chronon/flink/source/FlinkSourceProvider.scala index e16b89a478..de6fb949b6 100644 --- a/flink/src/main/scala/ai/chronon/flink/source/FlinkSourceProvider.scala +++ b/flink/src/main/scala/ai/chronon/flink/source/FlinkSourceProvider.scala @@ -12,6 +12,8 @@ object FlinkSourceProvider { new KafkaFlinkSource(props, deserializationSchema, topicInfo) case "pubsub" => loadPubsubSource(props, deserializationSchema, topicInfo) + case "kinesis" => + loadKinesisSource(props, deserializationSchema, topicInfo) case _ => throw new IllegalArgumentException(s"Unsupported message 
bus: ${topicInfo.messageBus}") } @@ -28,4 +30,16 @@ object FlinkSourceProvider { val onlineImpl = constructor.newInstance(props, deserializationSchema, topicInfo) onlineImpl.asInstanceOf[FlinkSource[T]] } + + // Kinesis source is loaded via reflection as we don't want the Flink module to depend on the Kinesis connector + // module as we don't want to pull in AWS deps in contexts such as running in GCP + private def loadKinesisSource[T](props: Map[String, String], + deserializationSchema: DeserializationSchema[T], + topicInfo: TopicInfo): FlinkSource[T] = { + val cl = Thread.currentThread().getContextClassLoader // Use Flink's classloader + val cls = cl.loadClass("ai.chronon.flink_connectors.kinesis.KinesisFlinkSource") + val constructor = cls.getConstructors.apply(0) + val onlineImpl = constructor.newInstance(props, deserializationSchema, topicInfo) + onlineImpl.asInstanceOf[FlinkSource[T]] + } } diff --git a/flink_connectors/package.mill b/flink_connectors/package.mill index 4e9da6cdd8..d312350acb 100644 --- a/flink_connectors/package.mill +++ b/flink_connectors/package.mill @@ -22,7 +22,8 @@ trait FlinkConnectorsModule extends Cross.Module[String] with build.BaseModule { .exclude("com.fasterxml.jackson.core" -> "jackson-core") .exclude("com.fasterxml.jackson.core" -> "jackson-databind") .exclude("com.fasterxml.jackson.core" -> "jackson-annotations"), - mvn"io.netty:netty-codec-http2:4.1.129.Final", + mvn"io.netty:netty-codec-http2:4.1.124.Final", + mvn"org.apache.flink:flink-connector-kinesis:4.2.0-1.17", ) object test extends build.BaseTestModule { diff --git a/flink_connectors/src/main/scala/ai/chronon/flink_connectors/kinesis/KinesisConfig.scala b/flink_connectors/src/main/scala/ai/chronon/flink_connectors/kinesis/KinesisConfig.scala new file mode 100644 index 0000000000..54001bd619 --- /dev/null +++ b/flink_connectors/src/main/scala/ai/chronon/flink_connectors/kinesis/KinesisConfig.scala @@ -0,0 +1,76 @@ +package ai.chronon.flink_connectors.kinesis + +import ai.chronon.flink.FlinkUtils +import ai.chronon.online.TopicInfo +import org.apache.flink.kinesis.shaded.org.apache.flink.connector.aws.config.AWSConfigConstants +import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants + +import java.util.Properties + +object KinesisConfig { + case class ConsumerConfig(properties: Properties, parallelism: Int) + + object Keys { + val AwsRegion = "AWS_REGION" + val AwsDefaultRegion = "AWS_DEFAULT_REGION" + val AwsAccessKeyId = "AWS_ACCESS_KEY_ID" + val AwsSecretAccessKey = "AWS_SECRET_ACCESS_KEY" + val KinesisEndpoint = "KINESIS_ENDPOINT" + val TaskParallelism = "tasks" + val InitialPosition = "initial_position" + val EnableEfo = "enable_efo" + val EfoConsumerName = "efo_consumer_name" + } + + object Defaults { + val Parallelism = 1 + val InitialPosition: String = ConsumerConfigConstants.InitialPosition.LATEST.toString + } + + def buildConsumerConfig(props: Map[String, String], topicInfo: TopicInfo): ConsumerConfig = { + val lookup = new PropertyLookup(props, topicInfo) + val properties = new Properties() + + val region = lookup.requiredOneOf(Keys.AwsRegion, Keys.AwsDefaultRegion) + val accessKeyId = lookup.required(Keys.AwsAccessKeyId) + val secretAccessKey = lookup.required(Keys.AwsSecretAccessKey) + + properties.setProperty(AWSConfigConstants.AWS_REGION, region) + properties.setProperty(AWSConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC") + properties.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, accessKeyId) + 
properties.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, secretAccessKey) + + + val initialPosition = lookup.optional(Keys.InitialPosition).getOrElse(Defaults.InitialPosition) + properties.setProperty(ConsumerConfigConstants.STREAM_INITIAL_POSITION, initialPosition) + + + val endpoint = lookup.optional(Keys.KinesisEndpoint) + val publisherType = lookup.optional(Keys.EnableEfo).map { enabledFlag => + if (enabledFlag.toBoolean) ConsumerConfigConstants.RecordPublisherType.EFO.toString + else ConsumerConfigConstants.RecordPublisherType.POLLING.toString + } + val efoConsumerName = lookup.optional(Keys.EfoConsumerName) + val parallelism = lookup.optional(Keys.TaskParallelism).map(_.toInt).getOrElse(Defaults.Parallelism) + + endpoint.foreach(properties.setProperty(AWSConfigConstants.AWS_ENDPOINT, _)) + publisherType.foreach(properties.setProperty(ConsumerConfigConstants.RECORD_PUBLISHER_TYPE, _)) + efoConsumerName.foreach(properties.setProperty(ConsumerConfigConstants.EFO_CONSUMER_NAME, _)) + + ConsumerConfig(properties, parallelism) + } + + private final class PropertyLookup(props: Map[String, String], topicInfo: TopicInfo) { + def optional(key: String): Option[String] = + FlinkUtils.getProperty(key, props, topicInfo) + + def required(key: String): String = + optional(key).getOrElse(missing(key)) + + def requiredOneOf(primary: String, fallback: String): String = + optional(primary).orElse(optional(fallback)).getOrElse(missing(s"$primary or $fallback")) + + private def missing(name: String): Nothing = + throw new IllegalArgumentException(s"Missing required property: $name") + } +} diff --git a/flink_connectors/src/main/scala/ai/chronon/flink_connectors/kinesis/KinesisDeserializationSchemaWrapper.scala b/flink_connectors/src/main/scala/ai/chronon/flink_connectors/kinesis/KinesisDeserializationSchemaWrapper.scala new file mode 100644 index 0000000000..b2a408f9f1 --- /dev/null +++ b/flink_connectors/src/main/scala/ai/chronon/flink_connectors/kinesis/KinesisDeserializationSchemaWrapper.scala @@ -0,0 +1,38 @@ +package ai.chronon.flink_connectors.kinesis + +import org.apache.flink.api.common.serialization.DeserializationSchema +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema +import org.apache.flink.util.Collector + +import java.util + +class KinesisDeserializationSchemaWrapper[T](deserializationSchema: DeserializationSchema[T]) + extends KinesisDeserializationSchema[T] { + + override def open(context: DeserializationSchema.InitializationContext): Unit = { + deserializationSchema.open(context) + } + + override def deserialize( + recordValue: Array[Byte], + partitionKey: String, + seqNum: String, + approxArrivalTimestamp: Long, + stream: String, + shardId: String + ): T = { + val results = new util.ArrayList[T]() + val collector = new Collector[T] { + override def collect(record: T): Unit = results.add(record) + override def close(): Unit = {} + } + + deserializationSchema.deserialize(recordValue, collector) + + if (!results.isEmpty) results.get(0) else null.asInstanceOf[T] + } + + override def getProducedType: TypeInformation[T] = deserializationSchema.getProducedType +} + diff --git a/flink_connectors/src/main/scala/ai/chronon/flink_connectors/kinesis/KinesisFlinkSource.scala b/flink_connectors/src/main/scala/ai/chronon/flink_connectors/kinesis/KinesisFlinkSource.scala new file mode 100644 index 0000000000..6b0e8490f2 --- /dev/null +++ 
b/flink_connectors/src/main/scala/ai/chronon/flink_connectors/kinesis/KinesisFlinkSource.scala @@ -0,0 +1,66 @@ +package ai.chronon.flink_connectors.kinesis + +import ai.chronon.flink.source.FlinkSource +import ai.chronon.online.TopicInfo +import org.apache.flink.api.common.eventtime.WatermarkStrategy +import org.apache.flink.api.common.serialization.DeserializationSchema +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment +import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer + +/** Chronon Flink source that reads events from AWS Kinesis. Can be configured on the topic as: + * kinesis://stream-name/tasks=20/ + * + * Config params such as the AWS region, access key ID, and secret access key are read from the + * online properties (configured in teams.py or env variables passed via -Z flags). + * + * Kinesis differs from Kafka in a few aspects: + * 1. Shard-based parallelism - we can derive parallelism based on the number of shards, but + * allow the user to override it via the 'tasks' property. + * 2. Kinesis streams maintain their own checkpointing via shard iterators, so job restarts will + * resume from the last processed position. + * + * Required properties: + * - AWS_REGION (or AWS_DEFAULT_REGION): The AWS region where the Kinesis stream exists + * - AWS_ACCESS_KEY_ID: AWS access key for authentication + * - AWS_SECRET_ACCESS_KEY: AWS secret access key for authentication + * + * Optional properties: + * - tasks: Override the default parallelism + * - KINESIS_ENDPOINT: Custom Kinesis endpoint (useful for local testing) + * - initial_position: Starting position (LATEST, TRIM_HORIZON, or AT_TIMESTAMP) + */ +class KinesisFlinkSource[T](props: Map[String, String], + deserializationSchema: DeserializationSchema[T], + topicInfo: TopicInfo) + extends FlinkSource[T] { + + private val config = KinesisConfig.buildConsumerConfig(props, topicInfo) + + implicit val parallelism: Int = config.parallelism + + override def getDataStream(topic: String, groupByName: String)(env: StreamExecutionEnvironment, + parallelism: Int): SingleOutputStreamOperator[T] = { + + // Wrap the deserialization schema to handle Collector-based deserialization + // This is needed because FlinkKinesisConsumer doesn't support DeserializationSchema + // that uses the Collector API (which allows producing multiple records per input) + val wrappedSchema = new KinesisDeserializationSchemaWrapper[T](deserializationSchema) + + val kinesisConsumer = new FlinkKinesisConsumer[T]( + topicInfo.name, + wrappedSchema, + config.properties + ) + + // skip watermarks at the source as we derive them post Spark expr eval + val noWatermarks: WatermarkStrategy[T] = WatermarkStrategy.noWatermarks() + + env + .addSource(kinesisConsumer, s"Kinesis source: $groupByName - ${topicInfo.name}") + .setParallelism(parallelism) + .uid(s"kinesis-source-$groupByName") + .assignTimestampsAndWatermarks(noWatermarks) + } +} + diff --git a/flink_connectors/src/test/scala/ai/chronon/flink_connectors/kinesis/KinesisConfigSpec.scala b/flink_connectors/src/test/scala/ai/chronon/flink_connectors/kinesis/KinesisConfigSpec.scala new file mode 100644 index 0000000000..10fe25d8a3 --- /dev/null +++ b/flink_connectors/src/test/scala/ai/chronon/flink_connectors/kinesis/KinesisConfigSpec.scala @@ -0,0 +1,76 @@ +package ai.chronon.flink_connectors.kinesis + +import ai.chronon.flink_connectors.kinesis.KinesisConfig.{ConsumerConfig, Defaults, Keys} +import 
ai.chronon.online.TopicInfo +import org.apache.flink.kinesis.shaded.org.apache.flink.connector.aws.config.AWSConfigConstants +import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class KinesisConfigSpec extends AnyFlatSpec with Matchers { + + behavior of "KinesisConfig.buildConsumerConfig" + + it should "require a region or default region" in { + val props = Map( + Keys.AwsAccessKeyId -> "access", + Keys.AwsSecretAccessKey -> "secret" + ) + + val topicInfo = TopicInfo("test-stream", "kinesis", Map.empty) + + an[IllegalArgumentException] should be thrownBy { + KinesisConfig.buildConsumerConfig(props, topicInfo) + } + } + + it should "default optional values when not provided" in { + val props = Map( + Keys.AwsRegion -> "us-west-2", + Keys.AwsAccessKeyId -> "access", + Keys.AwsSecretAccessKey -> "secret" + ) + + val topicInfo = TopicInfo("test-stream", "kinesis", Map.empty) + + val kinesisConfig = KinesisConfig.buildConsumerConfig(props, topicInfo) + + + kinesisConfig.parallelism shouldBe Defaults.Parallelism + kinesisConfig.properties.getProperty(AWSConfigConstants.AWS_REGION) shouldBe "us-west-2" + kinesisConfig.properties.getProperty(AWSConfigConstants.AWS_CREDENTIALS_PROVIDER) shouldBe "BASIC" + kinesisConfig.properties.getProperty(ConsumerConfigConstants.STREAM_INITIAL_POSITION) shouldBe Defaults.InitialPosition + kinesisConfig.properties.containsKey(AWSConfigConstants.AWS_ENDPOINT) shouldBe false + kinesisConfig.properties.containsKey(ConsumerConfigConstants.RECORD_PUBLISHER_TYPE) shouldBe false + kinesisConfig.properties.containsKey(ConsumerConfigConstants.EFO_CONSUMER_NAME) shouldBe false + } + + it should "apply overrides and optional fields from props and topic params" in { + val props = Map( + Keys.AwsAccessKeyId -> "access", + Keys.AwsSecretAccessKey -> "secret", + Keys.EnableEfo -> "true", + Keys.TaskParallelism -> "7" + ) + + val topicParams = Map( + Keys.AwsDefaultRegion -> "us-west-1", + Keys.InitialPosition -> ConsumerConfigConstants.InitialPosition.TRIM_HORIZON.toString, + Keys.KinesisEndpoint -> "http://localhost:4566", + Keys.EfoConsumerName -> "consumer" + ) + + val topicInfo = TopicInfo("test-stream", "kinesis", topicParams) + + val kinesisConfig = KinesisConfig.buildConsumerConfig(props, topicInfo) + + kinesisConfig.parallelism shouldBe 7 + kinesisConfig.properties.getProperty(AWSConfigConstants.AWS_REGION) shouldBe "us-west-1" + kinesisConfig.properties.getProperty(ConsumerConfigConstants.STREAM_INITIAL_POSITION) shouldBe ConsumerConfigConstants.InitialPosition.TRIM_HORIZON.toString + kinesisConfig.properties.getProperty(AWSConfigConstants.AWS_ENDPOINT) shouldBe "http://localhost:4566" + kinesisConfig.properties.getProperty(ConsumerConfigConstants.RECORD_PUBLISHER_TYPE) shouldBe ConsumerConfigConstants.RecordPublisherType.EFO.toString + kinesisConfig.properties.getProperty(ConsumerConfigConstants.EFO_CONSUMER_NAME) shouldBe "consumer" + } +} + + diff --git a/python/package.mill b/python/package.mill index d3080a275c..eb02f11f75 100644 --- a/python/package.mill +++ b/python/package.mill @@ -114,6 +114,11 @@ object `package` extends PythonModule with RuffModule with PublishModule { // Install in editable mode with generated sources. 
Can mess ruffCheck import order (due to gen_thrift priority change) def installEditable() = Task.Command { + // Install into virtual env if it exists, otherwise use the system python + val repoVenvPython = moduleDir / os.up / ".venv" / "bin" / "python" + val pythonExec = if (os.exists(repoVenvPython)) repoVenvPython.toString else "python3" + println(s"Using python executable: $pythonExec") + // Generate sources in Task.dest as usual val generatedPaths = generatedSources() @@ -136,7 +141,7 @@ object `package` extends PythonModule with RuffModule with PublishModule { // Install in editable mode println(s"Working directory: $pythonDir") - os.proc("python", "-m", "pip", "install", "-e", ".").call(cwd = pythonDir) + os.proc(pythonExec, "-m", "pip", "install", "-e", ".").call(cwd = pythonDir) println("Package installed in editable mode with generated sources") } } diff --git a/spark/package.mill b/spark/package.mill index 7eabb21af9..22409e0102 100644 --- a/spark/package.mill +++ b/spark/package.mill @@ -20,6 +20,7 @@ trait SparkModule extends Cross.Module[String] with build.BaseModule { mvn"io.netty:netty-all:4.1.129.Final", mvn"org.rogach::scallop:5.1.0", mvn"org.apache.avro:avro:1.11.4", + mvn"com.amazon.ion:ion-java:1.9.6", mvn"io.delta::delta-spark:3.2.0", // .exclude("org.apache.parquet" -> "parquet-avro"), // Avoid version conflict with Spark's Parquet, // mvn"org.apache.parquet:parquet-avro:1.15.2".forceVersion(), @@ -49,4 +50,4 @@ trait SparkModule extends Cross.Module[String] with build.BaseModule { "-Dspark.sql.hive.convertMetastoreParquet=false" ) } -} \ No newline at end of file +} diff --git a/spark/src/main/scala/ai/chronon/spark/GroupByUpload.scala b/spark/src/main/scala/ai/chronon/spark/GroupByUpload.scala index 01946cc918..1071bd9a51 100644 --- a/spark/src/main/scala/ai/chronon/spark/GroupByUpload.scala +++ b/spark/src/main/scala/ai/chronon/spark/GroupByUpload.scala @@ -43,9 +43,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.functions.col -import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.functions.not +import org.apache.spark.sql.functions.{col, lit, not, to_date} import org.apache.spark.sql.types import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -274,24 +272,46 @@ object GroupByUpload { val metaRdd = tableUtils.sparkSession.sparkContext.parallelize(metaRows.toSeq) val metaDf = tableUtils.sparkSession.createDataFrame(metaRdd, kvDf.schema) - kvDf - .union(metaDf) - .withColumn("ds", lit(endDs)) - .save(groupByConf.metaData.uploadTable, groupByConf.metaData.tableProps, partitionColumns = List("ds")) - - val kvDfReloaded = tableUtils - .loadTable(groupByConf.metaData.uploadTable) - .where(not(col("key_json").eqNullSafe(Constants.GroupByServingInfoKey))) - - val metricRow = - kvDfReloaded.selectExpr("sum(bit_length(key_bytes))/8", "sum(bit_length(value_bytes))/8", "count(*)").collect() - - if (metricRow.length > 0 && metricRow(0).getLong(2) > 0) { - context.gauge(Metrics.Name.KeyBytes, metricRow(0).getDouble(0).toLong) - context.gauge(Metrics.Name.ValueBytes, metricRow(0).getDouble(1).toLong) - context.gauge(Metrics.Name.RowCount, metricRow(0).getLong(2)) + val uploadFormat = + groupByConf + .commonConfValue(IonPathConfig.uploadFormatKey) + .getOrElse("parquet") + val partitionCol = groupByConf + .commonConfValue(IonPathConfig.PartitionColumnKey) + .getOrElse(IonPathConfig.DefaultPartitionColumn) + val uploadDf = 
kvDf.union(metaDf).withColumn(partitionCol, lit(endDs)) + + logger.info(s"GroupBy upload with upload format: $uploadFormat") + + if (uploadFormat == "ion") { + val rootPath = + groupByConf + .commonConfValue(IonPathConfig.UploadLocationKey) + val ionDf = uploadDf.withColumn(partitionCol, to_date(col(partitionCol))) + IonWriter.write( + ionDf, groupByConf.metaData.uploadTable, partitionCol, endDs, rootPath + ) } else { - throw new RuntimeException("GroupBy upload resulted in zero rows.") + uploadDf.save(groupByConf.metaData.uploadTable, + groupByConf.metaData.tableProps, + partitionColumns = List(partitionCol)) + + val kvDfReloaded = tableUtils + .loadTable(groupByConf.metaData.uploadTable) + .where(not(col("key_json").eqNullSafe(Constants.GroupByServingInfoKey))) + + val metricRow = + kvDfReloaded + .selectExpr("sum(bit_length(key_bytes))/8", "sum(bit_length(value_bytes))/8", "count(*)") + .collect() + + if (metricRow.length > 0 && metricRow(0).getLong(2) > 0) { + context.gauge(Metrics.Name.KeyBytes, metricRow(0).getDouble(0).toLong) + context.gauge(Metrics.Name.ValueBytes, metricRow(0).getDouble(1).toLong) + context.gauge(Metrics.Name.RowCount, metricRow(0).getLong(2)) + } else { + throw new RuntimeException("GroupBy upload resulted in zero rows.") + } } val jobDuration = (System.currentTimeMillis() - startTs) / 1000 diff --git a/spark/src/main/scala/ai/chronon/spark/IonWriter.scala b/spark/src/main/scala/ai/chronon/spark/IonWriter.scala new file mode 100644 index 0000000000..94f380a2b9 --- /dev/null +++ b/spark/src/main/scala/ai/chronon/spark/IonWriter.scala @@ -0,0 +1,149 @@ +package ai.chronon.spark + +import com.amazon.ion.IonType +import com.amazon.ion.system.IonBinaryWriterBuilder +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.TaskContext +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.util.SerializableConfiguration +import org.slf4j.LoggerFactory + +import java.math.BigDecimal +import java.sql.Date +import java.time.{LocalDate, ZoneOffset} +import java.util.UUID +import scala.util.control.NonFatal + +/** Configuration keys used for Ion upload paths. 
*/ +object IonPathConfig { + val uploadFormatKey = "spark.chronon.table_write.upload.format" + val UploadLocationKey = "spark.chronon.table_write.upload.location" + val PartitionColumnKey = "spark.chronon.partition.column" + val DefaultPartitionColumn = "ds" +} + +object IonWriter { + private val logger = LoggerFactory.getLogger(getClass) + + def write(df: DataFrame, + dataSetName: String, + partitionColumn: String, + partitionValue: String, + rootPath: Option[String] = None): Seq[String] = { + val serializableConf = new SerializableConfiguration(df.sparkSession.sparkContext.hadoopConfiguration) + val schema = df.schema + + val partitionPath = resolvePartitionPath(dataSetName, partitionColumn, partitionValue, rootPath) + + val requiredColumns = Seq("key_bytes", "value_bytes", partitionColumn) + val missingColumns = requiredColumns.filterNot(schema.fieldNames.contains) + if (missingColumns.nonEmpty) { + throw new IllegalArgumentException( + s"DataFrame schema for Ion upload is missing required column(s): ${missingColumns.mkString(", ")}") + } + + val keyIdx = schema.fieldIndex("key_bytes") + val valueIdx = schema.fieldIndex("value_bytes") + val tsIdx = schema.fieldIndex(partitionColumn) + + val written = df.rdd.mapPartitionsWithIndex( (partitionId, iter) => + if (!iter.hasNext) Iterator.empty + else { + val unique = UUID.randomUUID().toString + val filePath = new Path(partitionPath, s"part-$partitionId-$unique.ion") + val fs = FileSystem.get(filePath.toUri, serializableConf.value) + fs.mkdirs(partitionPath) + val out = fs.create(filePath, true) + val writer = IonBinaryWriterBuilder.standard().build(out) + + var rowCount = 0L + var keyBytesTotal = 0L + var valueBytesTotal = 0L + + try { + iter.foreach { row => + writer.stepIn(IonType.STRUCT) + writer.setFieldName("Item") + writer.stepIn(IonType.STRUCT) + if (!row.isNullAt(keyIdx)) { + val bytes = row.getAs[Array[Byte]](keyIdx) + writer.setFieldName("keyBytes") + writer.writeBlob(bytes) + keyBytesTotal += bytes.length + } + if (!row.isNullAt(valueIdx)) { + val bytes = row.getAs[Array[Byte]](valueIdx) + writer.setFieldName("valueBytes") + writer.writeBlob(bytes) + valueBytesTotal += bytes.length + } + if (!row.isNullAt(tsIdx)) { + writer.setFieldName("ts") + val millis = toMillis(row.get(tsIdx)) + writer.writeDecimal(millis) + } + writer.stepOut() + writer.stepOut() + rowCount += 1 + } + writer.finish() + } catch { + case NonFatal(e) => + logger.error(s"Failed writing Ion file at $filePath", e) + throw e + } finally { + writer.close() + out.close() + } + Iterator.single((filePath.toString, rowCount, keyBytesTotal, valueBytesTotal)) + } + ).collect() + + val totalRows = written.map(_._2).sum + if (totalRows == 0L) { + throw new RuntimeException("Ion upload produced zero rows.") + } + + val totalKeyBytes = written.map(_._3).sum + val totalValueBytes = written.map(_._4).sum + logger.info( + s"Wrote Ion files for partition $partitionValue at $partitionPath rows=$totalRows key_bytes=$totalKeyBytes value_bytes=$totalValueBytes" + ) + written.map(_._1) + } + + + def resolvePartitionPath(dataSetName: String, + partitionColumn: String, + partitionValue: String, + rootPath: Option[String]): Path = { + val root = validateRootPath(rootPath) + new Path(new Path(root, dataSetName), s"$partitionColumn=$partitionValue") + }
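A hedged sketch of the producer-side call, mirroring how the GroupByUpload "ion" branch above invokes write; the wrapper name uploadIonFiles and the table, date and bucket values are placeholders.

import org.apache.spark.sql.DataFrame

// Placeholder wrapper for illustration; ionDf is the kvDf-union-metaDf frame with the
// partition column cast to a date, as produced in GroupByUpload.
def uploadIonFiles(ionDf: DataFrame): Seq[String] =
  IonWriter.write(
    df = ionDf,
    dataSetName = "my_team.logins_v1__upload",        // groupByConf.metaData.uploadTable
    partitionColumn = "ds",
    partitionValue = "2025-10-17",
    rootPath = Some("s3://my-chronon-bucket/uploads")
  )
// Returns one part-<sparkPartitionId>-<uuid>.ion path per non-empty Spark partition, all under
//   s3://my-chronon-bucket/uploads/my_team.logins_v1__upload/ds=2025-10-17/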
+ + /** Validates and normalizes rootPath. Must be an s3://, s3a://, s3n:// or file:// URI. */ + def validateRootPath(rootPath: Option[String]): String = { + val trimmed = rootPath.map(_.trim.stripSuffix("/")).filter(_.nonEmpty).getOrElse { + throw new IllegalArgumentException( + s"Location path is required. Set '${IonPathConfig.UploadLocationKey}' in configuration.") + } + + if (!trimmed.matches("^(s3|s3a|s3n|file):/{1,3}.*")) { + throw new IllegalArgumentException( + s"Root path must start with s3://, s3a://, s3n:// or file:// but got: $trimmed") + } + trimmed + } + + def toMillis(value: Any): BigDecimal = { + value match { + case null => throw new IllegalArgumentException("Partition column value is null; cannot write Ion timestamp") + case date: Date => + // java.sql.Date.toInstant throws UnsupportedOperationException; convert via LocalDate at UTC midnight + BigDecimal.valueOf(date.toLocalDate.atStartOfDay(ZoneOffset.UTC).toInstant.toEpochMilli) + case localDate: LocalDate => + BigDecimal.valueOf(localDate.atStartOfDay(ZoneOffset.UTC).toInstant.toEpochMilli) + case other => + throw new IllegalArgumentException(s"Unsupported partition type: ${other.getClass.getName}") + } + } +} diff --git a/spark/src/test/scala/ai/chronon/spark/upload/IonWriterTest.scala b/spark/src/test/scala/ai/chronon/spark/upload/IonWriterTest.scala new file mode 100644 index 0000000000..ba2aa9b269 --- /dev/null +++ b/spark/src/test/scala/ai/chronon/spark/upload/IonWriterTest.scala @@ -0,0 +1,129 @@ +package ai.chronon.spark.upload + +import ai.chronon.spark.IonWriter +import ai.chronon.spark.utils.SparkTestBase +import com.amazon.ion.system.IonSystemBuilder +import com.amazon.ion.{IonBlob, IonDecimal, IonStruct, IonText} +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import org.scalatest.matchers.should.Matchers +import org.scalatest.flatspec.AnyFlatSpec +import java.io.{File, FileInputStream} +import java.net.URI +import java.nio.file.{Files, Paths} +import java.time.Instant +import java.time.LocalDate +import scala.jdk.CollectionConverters._ + +class IonWriterTest extends SparkTestBase with Matchers { + + private val tmpDir = Files.createTempDirectory("ion-writer-test").toFile + + override protected def sparkConfs: Map[String, String] = Map( + "spark.sql.warehouse.dir" -> new File(tmpDir, "warehouse").getAbsolutePath + ) + + behavior of "IonWriter" + + it should "write ion files with expected rows and fields" in { + val partitionValue = "2025-10-17" + val tsValue = "2025-10-17T00:00:00Z" + val tsValueMillis = Instant.parse(tsValue).toEpochMilli + val rootPath = Some(tmpDir.toURI.toString) + val dataSetName = "ion-output" + + val schema = StructType( + Seq( + StructField("key_bytes", BinaryType, nullable = true), + StructField("value_bytes", BinaryType, nullable = true), + StructField("key_json", StringType, nullable = true), + StructField("value_json", StringType, nullable = true), + StructField("ds", DateType, nullable = false) + ) + ) + + val rows = Seq( + Row("k1".getBytes("UTF-8"), "v1-bytes".getBytes("UTF-8"), "k1-json", """{"v":"one"}""", LocalDate.parse(partitionValue)), + Row("k2".getBytes("UTF-8"), "v2-bytes".getBytes("UTF-8"), "k2-json", """{"v":"two"}""", LocalDate.parse(partitionValue)) + ) + + val df = spark.createDataFrame(spark.sparkContext.parallelize(rows, numSlices = 2), schema) + val paths = IonWriter.write(df, dataSetName, "ds", partitionValue, rootPath) + + paths should not be empty + all(paths) should include(s"ds=$partitionValue") + + val ion = IonSystemBuilder.standard().build() + + val parsed = + paths.flatMap { p => + val path = + if (p.startsWith("file:")) Paths.get(new URI(p)) // handle fully-qualified file URIs + else Paths.get(p) // hadoop Path.toString() returns a filesystem path
without a scheme + Files.exists(path) shouldBe true + val datagram = ion.getLoader.load(new FileInputStream(path.toFile)) + datagram.iterator().asScala.map { value => + val struct = value.asInstanceOf[IonStruct].get("Item").asInstanceOf[IonStruct] + val keyBytes = Option(struct.get("keyBytes")).map(_.asInstanceOf[IonBlob].getBytes) + val valueBytes = Option(struct.get("valueBytes")).map(_.asInstanceOf[IonBlob].getBytes) + val ts = Option(struct.get("ts")).map(_.asInstanceOf[IonDecimal]) + (keyBytes, valueBytes, ts) + } + } + + parsed.size shouldBe rows.size + parsed.map(_._1.get.toSeq).toSet should contain("k1".getBytes("UTF-8").toSeq) + parsed.map(_._2.get.toSeq).toSet should contain("v2-bytes".getBytes("UTF-8").toSeq) + parsed.flatMap(_._3).foreach(_.bigDecimalValue().longValueExact() shouldBe tsValueMillis) + } + + it should "honor upload bucket when provided" in { + val partitionValue = "2025-10-18" + val dataSetName = "ion-output-bucket" + val rootPath = Some(new File(tmpDir, "bucket-root").toURI.toString) + + val schema = StructType( + Seq( + StructField("key_bytes", BinaryType, nullable = true), + StructField("value_bytes", BinaryType, nullable = true), + StructField("key_json", StringType, nullable = true), + StructField("value_json", StringType, nullable = true), + StructField("ds", DateType, nullable = false) + ) + ) + + val rows = Seq( + Row("k3".getBytes("UTF-8"), "v3-bytes".getBytes("UTF-8"), "k3-json", """{"v":"three"}""", LocalDate.parse(partitionValue)) + ) + + val df = spark.createDataFrame(spark.sparkContext.parallelize(rows, numSlices = 1), schema) + + val paths = IonWriter.write(df, dataSetName, "ds", partitionValue, rootPath) + + paths should not be empty + all(paths) should include(dataSetName) + all(paths) should include(s"ds=$partitionValue") + } + + it should "validate root path with valid schemes" in { + IonWriter.validateRootPath(Some("s3://my-bucket/path")) shouldBe "s3://my-bucket/path" + IonWriter.validateRootPath(Some("s3a://my-bucket")) shouldBe "s3a://my-bucket" + IonWriter.validateRootPath(Some("file:///tmp/local")) shouldBe "file:///tmp/local" + IonWriter.validateRootPath(Some(" s3://trimmed/ ")) shouldBe "s3://trimmed" + } + + it should "reject invalid root paths" in { + an[IllegalArgumentException] should be thrownBy IonWriter.validateRootPath(None) + an[IllegalArgumentException] should be thrownBy IonWriter.validateRootPath(Some("")) + an[IllegalArgumentException] should be thrownBy IonWriter.validateRootPath(Some(" ")) + an[IllegalArgumentException] should be thrownBy IonWriter.validateRootPath(Some("no-scheme-bucket")) + } + + it should "resolve partition path correctly" in { + val bucketUri = new File(tmpDir, "bucket-resolve").toURI.toString + val path = IonWriter.resolvePartitionPath("my-dataset", "ds", "2025-01-15", Some(bucketUri)) + path.toString should include("my-dataset") + path.toString should include("ds=2025-01-15") + } +}
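To tie the pieces together, a hedged end-to-end configuration sketch. The key names come from IonPathConfig; the bucket value and the exact mechanism for supplying the map on each side (teams.py, -Z flags, the Api constructor conf) are deployment-specific assumptions.

// Illustrative only: the settings that route a GroupBy upload through the Ion path
// and let the DynamoDB bulkPut find the same files.
val ionUploadConf = Map(
  "spark.chronon.table_write.upload.format" -> "ion",                              // IonPathConfig.uploadFormatKey
  "spark.chronon.table_write.upload.location" -> "s3://my-chronon-bucket/uploads", // IonPathConfig.UploadLocationKey
  "spark.chronon.partition.column" -> "ds"                                         // IonPathConfig.PartitionColumnKey (default "ds")
)
// GroupByUpload reads these via groupByConf.commonConfValue(...) and writes Ion files to
//   s3://my-chronon-bucket/uploads/<uploadTable>/ds=<endDs>/
// AwsApiImpl passes its conf map into DynamoDBKVStoreImpl, whose bulkPut resolves the same
// path and triggers a DynamoDB ImportTable job against it.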