SparkConnectorMySQL.scala
/*
* Copyright 2016-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: MIT-0
*/
import collection.JavaConverters._
import java.sql.{Connection, DriverManager, ResultSet}
import java.util
import java.util.Optional
import org.apache.spark.internal.Logging
import scala.collection.mutable.ListBuffer
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.{EqualTo, Filter}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
 * A simple Spark DataSource V2 connector with read and write support. The connector connects
 * to a local MySQL database and reads from and writes to its employee table.
 */
class JdbcSourceV2 extends DataSourceV2 with ReadSupport with WriteSupport {
override def createReader(options: DataSourceOptions): JdbcDataSourceReader =
new JdbcDataSourceReader(
options.get("url").get(),
options.get("user").get(),
options.get("password").get(),
options.get("table").get()
)
override def createWriter(jobId: String, schema: StructType, mode: SaveMode,
options: DataSourceOptions): Optional[DataSourceWriter] = {
Optional.of(new MysqlDataSourceWriter(
options.get("url").get(),
options.get("user").get(),
options.get("password").get(),
options.get("table").get(),
schema)
)
}
}
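/**
 * Reader side of the connector: prunes columns, pushes EqualTo filters down to MySQL,
 * and splits the employee table scan into two id-range partitions.
 */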
class JdbcDataSourceReader(url: String,
user: String,
password: String,
table: String)
extends DataSourceReader with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
  // Assume a fixed schema for the on-premises employee table.
var requiredSchema = StructType(Seq(
StructField("id", IntegerType),
StructField("emp_name", StringType),
StructField("dep_name", StringType),
StructField("salary", DecimalType(7, 2)),
StructField("age", DecimalType(3, 0))
))
var filters = Array.empty[Filter]
var wheres = Array.empty[String]
def readSchema: StructType = requiredSchema
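  // Split the scan into two fixed id ranges; each range becomes one Spark input partition.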
override def planInputPartitions(): util.List[InputPartition[InternalRow]] = {
val columns = requiredSchema.fields.map(_.name)
Seq((1, 6), (7, 100)).map { case (minId, maxId) =>
val partition = s"id BETWEEN $minId AND $maxId"
new JdbcInputPartition(url, user, password, table, columns, wheres, partition)
.asInstanceOf[InputPartition[InternalRow]]
}.toList.asJava
}
def pruneColumns(requiredSchema: StructType): Unit = {
this.requiredSchema = requiredSchema
}
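  // Push only EqualTo filters down to MySQL as WHERE clauses; any other filter is
  // returned to Spark for evaluation after the scan.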
def pushFilters(filters: Array[Filter]): Array[Filter] = {
val supported = ListBuffer.empty[Filter]
val unsupported = ListBuffer.empty[Filter]
val wheres = ListBuffer.empty[String]
filters.foreach {
case filter: EqualTo =>
supported += filter
wheres += s"${filter.attribute} = '${filter.value}'"
case filter => unsupported += filter
}
this.filters = supported.toArray
this.wheres = wheres.toArray
unsupported.toArray
}
def pushedFilters: Array[Filter] = filters
}
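/**
 * One id-range slice of the table. Spark serializes it to an executor, which calls
 * createPartitionReader to start reading the slice.
 */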
class JdbcInputPartition(url: String,
user: String,
password: String,
table: String,
columns: Seq[String],
wheres: Seq[String],
partition: String)
extends InputPartition[InternalRow] {
def createPartitionReader(): JdbcDataReader =
new JdbcDataReader(url, user, password, table, columns, wheres, partition)
}
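/**
 * Reads one partition over JDBC with a forward-only, read-only cursor and converts
 * each row into Spark's internal representation.
 */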
class JdbcDataReader(
url: String,
user: String,
password: String,
table: String,
columns: Seq[String],
wheres: Seq[String],
partition: String)
extends InputPartitionReader[InternalRow] {
// scalastyle:off
Class.forName("com.mysql.jdbc.Driver")
// scalastyle:on
private var conn: Connection = null
private var rs: ResultSet = null
def next(): Boolean = {
if (rs == null) {
conn = DriverManager.getConnection(url, user, password)
val sqlBuilder = new StringBuilder()
sqlBuilder ++= s"SELECT ${columns.mkString(", ")} FROM $table WHERE $partition"
if (wheres.nonEmpty) {
sqlBuilder ++= " AND " + wheres.mkString(" AND ")
}
val sql = sqlBuilder.toString
val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY,
ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(1000)
rs = stmt.executeQuery()
}
rs.next()
}
def get(): InternalRow = {
val values = columns.map {
case "id" => rs.getInt("id")
case "emp_name" => UTF8String.fromString(rs.getString("emp_name"))
case "dep_name" => UTF8String.fromString(rs.getString("dep_name"))
case "salary" => Decimal(rs.getBigDecimal("salary"))
case "age" => Decimal(rs.getBigDecimal("age"))
}
InternalRow.fromSeq(values)
}
  def close(): Unit = {
    // The connection is only opened on the first call to next(), so it may still be null here.
    if (conn != null) {
      conn.close()
    }
  }
}
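/**
 * Driver-side writer: hands out per-partition writer factories and logs whether the
 * overall job committed or aborted.
 */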
class MysqlDataSourceWriter(url: String,
user: String,
password: String,
table: String,
rddSchema: StructType) extends DataSourceWriter with Logging {
override def createWriterFactory(): DataWriterFactory[InternalRow] = {
new MysqlDataWriterFactory(url, user, password, table, rddSchema)
}
  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    // Called when every partition's write task has committed successfully.
    log.info("Write succeeded")
  }
  override def abort(messages: Array[WriterCommitMessage]): Unit = {
    // Called when the write job fails on any partition.
    log.error("Write failed")
  }
}
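/**
 * Serializable factory shipped to executors; creates one JdbcDataWriter per task.
 */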
class MysqlDataWriterFactory(url: String,
user: String,
password: String,
table: String,
rddSchema: StructType)
extends DataWriterFactory[InternalRow] {
override def createDataWriter(partitionId: Int, taskId: Long, epochId: Long):
DataWriter[InternalRow] = {
new JdbcDataWriter(url, user, password, table, rddSchema)
}
}
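/**
 * Executor-side writer: buffers rows into a JDBC batch and executes the batch when the
 * task commits.
 */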
class JdbcDataWriter(url: String,
user: String,
password: String,
table: String,
rddSchema: StructType) extends DataWriter[InternalRow] {
var connection = DriverManager.getConnection(url, user, password)
val statement = getInsertStatement(table, rddSchema)
val numFields = rddSchema.fields.length
val preparedStatement = connection.prepareStatement(statement)
override def write(record: InternalRow): Unit = {
var i = 0
while (i < numFields) {
rddSchema.fields(i).dataType match {
case IntegerType => preparedStatement.setInt(i + 1, record.getInt(i))
case StringType => preparedStatement.setString(i + 1, record.getString(i))
case dt: DecimalType =>
preparedStatement.setBigDecimal(i + 1,
record.getDecimal(i, dt.precision, dt.scale).toJavaBigDecimal)
}
i = i + 1
}
preparedStatement.addBatch()
}
  override def commit(): WriterCommitMessage = {
    preparedStatement.executeBatch()
    closeConnection()
    WriteSucceeded
  }
  override def abort(): Unit = {
    closeConnection()
  }
private def closeConnection(): Unit = {
if (connection != null) {
connection.close()
connection = null
}
}
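  // Build a parameterized INSERT statement with one placeholder per field in the schema.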
private def getInsertStatement(table: String,
rddSchema: StructType): String = {
val columns = rddSchema.fields.map(x => x.name).mkString(",")
s"INSERT INTO $table ($columns) " +
s"VALUES(${List.fill(rddSchema.fields.length)("?").mkString(",")})"
}
}
object WriteSucceeded extends WriterCommitMessage
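/**
 * A minimal usage sketch. The JDBC URL, credentials, and table names below are
 * placeholders for a local MySQL instance that has the employee table described above.
 */
object JdbcSourceV2Example {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("JdbcSourceV2Example")
      .getOrCreate()

    // Read through the connector; the format name is the connector's fully qualified
    // class name (here just "JdbcSourceV2" because it lives in the default package).
    val employees = spark.read
      .format("JdbcSourceV2")
      .option("url", "jdbc:mysql://localhost:3306/company")
      .option("user", "root")
      .option("password", "password")
      .option("table", "employee")
      .load()

    // This EqualTo predicate is pushed down to MySQL by pushFilters.
    val engineers = employees.filter(employees("dep_name") === "Engineering")

    // Write back through the connector; rows are batched per partition by JdbcDataWriter.
    engineers.write
      .format("JdbcSourceV2")
      .option("url", "jdbc:mysql://localhost:3306/company")
      .option("user", "root")
      .option("password", "password")
      .option("table", "employee_backup")
      .mode(SaveMode.Append)
      .save()

    spark.stop()
  }
}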