3 changes: 2 additions & 1 deletion .gitignore
@@ -10,6 +10,7 @@
*.logs
*.iml
*.db

.idea/
.ijwb/
**/local_warehouse/
@@ -126,4 +127,4 @@ MODULE.bazel*


# mill build output
out/**
out/**
@@ -48,12 +48,13 @@ class AwsApiImpl(conf: Map[String, String]) extends Api(conf) {
* a fully functional Chronon serving stack in Aws
* @return
*/
override def externalRegistry: ExternalSourceRegistry = ???
@transient lazy val registry: ExternalSourceRegistry = new ExternalSourceRegistry()
override def externalRegistry: ExternalSourceRegistry = registry

/** The logResponse method is currently unimplemented. We'll need to implement this prior to bringing up the
* fully functional serving stack in Aws which includes logging feature responses to a stream for OOC
*/
override def logResponse(resp: LoggableResponse): Unit = ???
override def logResponse(resp: LoggableResponse): Unit = ()

override def genMetricsKvStore(tableBaseName: String): KVStore = {
new DynamoDBKVStoreImpl(ddbClient)
@@ -57,6 +57,10 @@ object DynamoDBKVStoreConstants {
val defaultWriteCapacityUnits = 10L
}

/** Exception thrown when attempting to create a DynamoDB table that already exists */
case class TableAlreadyExistsException(tableName: String)
extends RuntimeException(s"DynamoDB table '$tableName' already exists")

class DynamoDBKVStoreImpl(dynamoDbClient: DynamoDbClient) extends KVStore {
import DynamoDBKVStoreConstants._

@@ -104,7 +108,9 @@ class DynamoDBKVStoreImpl(dynamoDbClient: DynamoDbClient) extends KVStore {
logger.info(s"Table created successfully! Details: \n${tableDescription.toString}")
metricsContext.increment("create.successes")
} catch {
case _: ResourceInUseException => logger.info(s"Table: $dataset already exists")
case _: ResourceInUseException =>
logger.info(s"Table: $dataset already exists")
throw TableAlreadyExistsException(dataset)
case e: Exception =>
logger.error(s"Error creating Dynamodb table: $dataset", e)
metricsContext.increment("create.failures")
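
A hedged sketch of handling the new TableAlreadyExistsException at a call site, assuming the KVStore create(dataset) entry point; the kvStore value, dataset name, and logger below are placeholders and not part of this diff:

// Treat an already-provisioned table as a benign condition instead of a hard failure
try {
  kvStore.create(dataset)
} catch {
  case TableAlreadyExistsException(existing) =>
    logger.info(s"DynamoDB table already exists, reusing it: $existing")
}
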
10 changes: 10 additions & 0 deletions flink/package.mill
@@ -77,4 +77,14 @@ trait FlinkModule extends Cross.Module[String] with build.BaseModule {
mvn"com.fasterxml.jackson.core:jackson-annotations:2.15.2"
)
}

override def assemblyRules = super.assemblyRules ++ Seq(
mill.scalalib.Assembly.Rule.Relocate("org.apache.flink.kinesis.shaded.**", "shaded.@1")
)

override def upstreamAssemblyClasspath = Task {
super.upstreamAssemblyClasspath() ++
Task.traverse(Seq(build.flink_connectors))(_.localClasspath)().flatten
}

}
8 changes: 7 additions & 1 deletion flink/src/main/scala/ai/chronon/flink/FlinkJob.scala
@@ -37,6 +37,11 @@ abstract class BaseFlinkJob {

def groupByServingInfoParsed: GroupByServingInfoParsed

/** Run the streaming job without tiling (direct processing).
* Events are processed and written directly to KV store without windowing/buffering.
*/
def runGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse]

/** Run the streaming job with tiling enabled (default mode).
* This is the main execution method that should be implemented by subclasses.
*/
@@ -223,7 +228,8 @@ object FlinkJob {
FlinkJob.runWriteInternalManifestJob(env, jobArgs.streamingManifestPath(), maybeParentJobId.get, groupByName)
}

val jobDatastream = flinkJob.runTiledGroupByJob(env)
// val jobDatastream = flinkJob.runTiledGroupByJob(env)
val jobDatastream = flinkJob.runGroupByJob(env)

jobDatastream
.addSink(new MetricsSink(groupByName))
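
The switch above from runTiledGroupByJob to runGroupByJob is hard-coded; a hedged illustration of making the mode configurable (the tilingEnabled flag is hypothetical and not part of this PR):

val jobDatastream =
  if (tilingEnabled) flinkJob.runTiledGroupByJob(env) // tiled mode: windowed/buffered writes
  else flinkJob.runGroupByJob(env) // untiled mode: events written directly to the KV store
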
@@ -92,6 +92,16 @@ class ChainedGroupByJob(eventSrc: FlinkSource[ProjectedEvent],
}
}

/** Untiled mode is not yet implemented for ChainedGroupByJob (JoinSource GroupBys).
* Use runTiledGroupByJob instead.
*/
override def runGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = {
throw new NotImplementedError(
s"Untiled mode is not implemented for ChainedGroupByJob (JoinSource GroupBys). " +
s"GroupBy: $groupByName uses JoinSource and only supports tiled mode."
)
}

/** Build the tiled version of the Flink GroupBy job that chains features using a JoinSource.
* The operators are structured as follows:
* - Source: Read from Kafka topic into ProjectedEvent stream
108 changes: 108 additions & 0 deletions flink/src/main/scala/ai/chronon/flink/deser/LoginsSerDe.scala
@@ -0,0 +1,108 @@
package ai.chronon.flink.deser

import ai.chronon.api.{StructField => ZStructField, StructType => ZStructType, LongType => ZLongType, DoubleType => ZDoubleType, StringType => ZStringType}
import ai.chronon.online.TopicInfo
import ai.chronon.online.serde.{Mutation, SerDe}

import scala.util.Try

/** SerDe for logins events whose messages are simple JSON objects with fields:
* - ts: Long (event time in millis)
* - event_id: String
* - user_id: String (or numeric convertible to String)
* - login_method: String
* - device_type: String
* - ip_address: String
*
* If you use Avro on the wire, consider wrapping an AvroSerDe with your Avro schema instead.
*/
class LoginsSerDe(topicInfo: TopicInfo) extends SerDe {

private val zSchema: ZStructType = ZStructType(
"logins_event",
Array(
ZStructField("ts", ZLongType),
ZStructField("event_id", ZStringType),
ZStructField("user_id", ZStringType),
ZStructField("login_method", ZStringType),
ZStructField("device_type", ZStringType),
ZStructField("ip_address", ZStringType)
)
)

override def schema: ZStructType = zSchema

override def fromBytes(bytes: Array[Byte]): Mutation = {
val json = new String(bytes, java.nio.charset.StandardCharsets.UTF_8).trim
val parsed = parseFlatJson(json)
val row: Array[Any] = Array[Any](
parsed.getOrElse("ts", null).asInstanceOf[java.lang.Long],
toStringOrNull(parsed.get("event_id")),
toStringOrNull(parsed.get("user_id")),
toStringOrNull(parsed.get("login_method")),
toStringOrNull(parsed.get("device_type")),
toStringOrNull(parsed.get("ip_address"))
)
Mutation(schema, null, row)
}

private def toStringOrNull(v: Option[Any]): String = v match {
case Some(s: String) => s
case Some(n: java.lang.Number) => String.valueOf(n)
case Some(other) => String.valueOf(other)
case None => null
}

/** Minimal, dependency-free flat JSON parser for simple key-value objects.
* Accepts numbers, strings and booleans; strings must be quoted.
* Not suitable for nested objects or arrays.
*/
private def parseFlatJson(input: String): Map[String, Any] = {
// strip outer braces
val s = input.dropWhile(_ != '{').drop(1).reverse.dropWhile(_ != '}').drop(1).reverse
if (s.trim.isEmpty) return Map.empty
val parts = splitTopLevel(s)
parts.flatMap { kv =>
val idx = kv.indexOf(":")
if (idx <= 0) None
else {
val key = unquote(kv.substring(0, idx).trim)
val raw = kv.substring(idx + 1).trim
Some(key -> parseValue(raw))
}
}.toMap
}

private def splitTopLevel(s: String): Seq[String] = {
val buf = new StringBuilder
var inString = false
var esc = false
val out = scala.collection.mutable.ArrayBuffer.empty[String]
s.foreach { ch =>
if (esc) { buf.append(ch); esc = false }
else ch match {
case '\\' if inString => buf.append(ch); esc = true
case '"' => inString = !inString; buf.append(ch)
case ',' if !inString => out += buf.result(); buf.clear()
case c => buf.append(c)
}
}
val last = buf.result().trim
if (last.nonEmpty) out += last
out.toSeq
}

private def unquote(s: String): String = {
val t = s.trim
if (t.startsWith("\"") && t.endsWith("\"")) t.substring(1, t.length - 1) else t
}

private def parseValue(raw: String): Any = {
if (raw == "null") null
else if (raw == "true" || raw == "false") java.lang.Boolean.valueOf(raw)
else if (raw.startsWith("\"") && raw.endsWith("\"")) unquote(raw)
else Try(java.lang.Long.valueOf(raw)).orElse(Try(java.lang.Double.valueOf(raw))).getOrElse(raw)
}
}
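
A minimal usage sketch of the new SerDe; the topicInfo value and the sample payload are placeholders, and only the field names come from the schema above:

val serde = new LoginsSerDe(topicInfo) // topicInfo: TopicInfo assumed to be in scope
val payload =
  """{"ts": 1718000000000, "event_id": "e1", "user_id": "42", "login_method": "password", "device_type": "ios", "ip_address": "10.0.0.1"}"""
val mutation = serde.fromBytes(payload.getBytes(java.nio.charset.StandardCharsets.UTF_8))
// The returned Mutation carries the parsed row in schema order:
// Array(1718000000000L, "e1", "42", "password", "ios", "10.0.0.1")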


@@ -12,6 +12,8 @@ object FlinkSourceProvider {
new KafkaFlinkSource(props, deserializationSchema, topicInfo)
case "pubsub" =>
loadPubsubSource(props, deserializationSchema, topicInfo)
case "kinesis" =>
loadKinesisSource(props, deserializationSchema, topicInfo)
case _ =>
throw new IllegalArgumentException(s"Unsupported message bus: ${topicInfo.messageBus}")
}
@@ -28,4 +30,16 @@ object FlinkSourceProvider {
val onlineImpl = constructor.newInstance(props, deserializationSchema, topicInfo)
onlineImpl.asInstanceOf[FlinkSource[T]]
}

// The Kinesis source is loaded via reflection so the Flink module doesn't depend on the Kinesis connector
// module, which would pull AWS deps into contexts such as running on GCP
private def loadKinesisSource[T](props: Map[String, String],
deserializationSchema: DeserializationSchema[T],
topicInfo: TopicInfo): FlinkSource[T] = {
val cl = Thread.currentThread().getContextClassLoader // Use Flink's classloader
val cls = cl.loadClass("ai.chronon.flink_connectors.kinesis.KinesisFlinkSource")
val constructor = cls.getConstructors.apply(0)
val onlineImpl = constructor.newInstance(props, deserializationSchema, topicInfo)
onlineImpl.asInstanceOf[FlinkSource[T]]
}
}
1 change: 1 addition & 0 deletions flink_connectors/package.mill
@@ -23,6 +23,7 @@ trait FlinkConnectorsModule extends Cross.Module[String] with build.BaseModule {
.exclude("com.fasterxml.jackson.core" -> "jackson-databind")
.exclude("com.fasterxml.jackson.core" -> "jackson-annotations"),
mvn"io.netty:netty-codec-http2:4.1.124.Final",
mvn"org.apache.flink:flink-connector-kinesis:4.2.0-1.17",
)

object test extends build.BaseTestModule {
@@ -0,0 +1,76 @@
package ai.chronon.flink_connectors.kinesis

import ai.chronon.flink.FlinkUtils
import ai.chronon.online.TopicInfo
import org.apache.flink.kinesis.shaded.org.apache.flink.connector.aws.config.AWSConfigConstants
import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants

import java.util.Properties

object KinesisConfig {
case class ConsumerConfig(properties: Properties, parallelism: Int)

object Keys {
val AwsRegion = "AWS_REGION"
val AwsDefaultRegion = "AWS_DEFAULT_REGION"
val AwsAccessKeyId = "AWS_ACCESS_KEY_ID"
val AwsSecretAccessKey = "AWS_SECRET_ACCESS_KEY"
val KinesisEndpoint = "KINESIS_ENDPOINT"
val TaskParallelism = "tasks"
val InitialPosition = "initial_position"
val EnableEfo = "enable_efo"
val EfoConsumerName = "efo_consumer_name"
}

object Defaults {
val Parallelism = 1
val InitialPosition: String = ConsumerConfigConstants.InitialPosition.LATEST.toString
}

def buildConsumerConfig(props: Map[String, String], topicInfo: TopicInfo): ConsumerConfig = {
val lookup = new PropertyLookup(props, topicInfo)
val properties = new Properties()

val region = lookup.requiredOneOf(Keys.AwsRegion, Keys.AwsDefaultRegion)
val accessKeyId = lookup.required(Keys.AwsAccessKeyId)
val secretAccessKey = lookup.required(Keys.AwsSecretAccessKey)

properties.setProperty(AWSConfigConstants.AWS_REGION, region)
properties.setProperty(AWSConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC")
properties.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, accessKeyId)
properties.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, secretAccessKey)


val initialPosition = lookup.optional(Keys.InitialPosition).getOrElse(Defaults.InitialPosition)
properties.setProperty(ConsumerConfigConstants.STREAM_INITIAL_POSITION, initialPosition)


val endpoint = lookup.optional(Keys.KinesisEndpoint)
val publisherType = lookup.optional(Keys.EnableEfo).map { enabledFlag =>
if (enabledFlag.toBoolean) ConsumerConfigConstants.RecordPublisherType.EFO.toString
else ConsumerConfigConstants.RecordPublisherType.POLLING.toString
}
val efoConsumerName = lookup.optional(Keys.EfoConsumerName)
val parallelism = lookup.optional(Keys.TaskParallelism).map(_.toInt).getOrElse(Defaults.Parallelism)

endpoint.foreach(properties.setProperty(AWSConfigConstants.AWS_ENDPOINT, _))
publisherType.foreach(properties.setProperty(ConsumerConfigConstants.RECORD_PUBLISHER_TYPE, _))
efoConsumerName.foreach(properties.setProperty(ConsumerConfigConstants.EFO_CONSUMER_NAME, _))

ConsumerConfig(properties, parallelism)
}

private final class PropertyLookup(props: Map[String, String], topicInfo: TopicInfo) {
def optional(key: String): Option[String] =
FlinkUtils.getProperty(key, props, topicInfo)

def required(key: String): String =
optional(key).getOrElse(missing(key))

def requiredOneOf(primary: String, fallback: String): String =
optional(primary).orElse(optional(fallback)).getOrElse(missing(s"$primary or $fallback"))

private def missing(name: String): Nothing =
throw new IllegalArgumentException(s"Missing required property: $name")
}
}
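
A minimal sketch of building a consumer config from job properties; the key/value pairs and the in-scope topicInfo are placeholders, and only the Keys names come from the object above:

val props = Map(
  KinesisConfig.Keys.AwsRegion -> "us-west-2",
  KinesisConfig.Keys.AwsAccessKeyId -> "my-access-key-id", // placeholder credential
  KinesisConfig.Keys.AwsSecretAccessKey -> "my-secret-access-key", // placeholder credential
  KinesisConfig.Keys.TaskParallelism -> "2",
  KinesisConfig.Keys.InitialPosition -> "TRIM_HORIZON"
)
val consumerConfig = KinesisConfig.buildConsumerConfig(props, topicInfo) // topicInfo: TopicInfo assumed in scope
// consumerConfig.properties feeds the Kinesis consumer; consumerConfig.parallelism sets the source parallelism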
@@ -0,0 +1,38 @@
package ai.chronon.flink_connectors.kinesis

import org.apache.flink.api.common.serialization.DeserializationSchema
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema
import org.apache.flink.util.Collector

import java.util

class KinesisDeserializationSchemaWrapper[T](deserializationSchema: DeserializationSchema[T])
extends KinesisDeserializationSchema[T] {

override def open(context: DeserializationSchema.InitializationContext): Unit = {
deserializationSchema.open(context)
}

override def deserialize(
recordValue: Array[Byte],
partitionKey: String,
seqNum: String,
approxArrivalTimestamp: Long,
stream: String,
shardId: String
): T = {
val results = new util.ArrayList[T]()
val collector = new Collector[T] {
override def collect(record: T): Unit = results.add(record)
override def close(): Unit = {}
}

deserializationSchema.deserialize(recordValue, collector)

if (!results.isEmpty) results.get(0) else null.asInstanceOf[T]
}

override def getProducedType: TypeInformation[T] = deserializationSchema.getProducedType
}
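
A minimal sketch of plugging the wrapper into the upstream Kinesis consumer; the stream name and the in-scope consumerProperties are placeholders, and SimpleStringSchema is just an example payload schema:

import org.apache.flink.api.common.serialization.SimpleStringSchema
import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer

val wrapped = new KinesisDeserializationSchemaWrapper[String](new SimpleStringSchema())
// consumerProperties: java.util.Properties assumed in scope, e.g. KinesisConfig.buildConsumerConfig(...).properties
val consumer = new FlinkKinesisConsumer[String]("logins-stream", wrapped, consumerProperties)

Note that deserialize keeps only the first record emitted by the wrapped schema's Collector, so schemas that emit multiple records per Kinesis record would be truncated.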
