-
Notifications
You must be signed in to change notification settings - Fork 2
/
spark_app.py
50 lines (46 loc) · 1.64 KB
/
spark_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import sys
from flask import Flask, request, jsonify
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession
from pyspark.ml.clustering import KMeans
from flask_cors import CORS
# Module-level Flask application; the route handlers below register on it.
app = Flask(__name__)
# Wrap the app with flask_cors. No options are passed here, so the
# extension's default CORS policy applies to the whole application.
CORS(app)
@app.route('/uberdata')
def get_uber_data():
    """Cluster the Uber pickup coordinates with k-means and count rows per cluster.

    Query parameters:
        cluster_count: number of clusters (k) to fit. Must be an integer
            greater than 1.

    Returns:
        On success, a JSON array whose elements are JSON-encoded strings (the
        output of ``DataFrame.toJSON``), one per cluster center, each with the
        keys ``Longitude``, ``Latitude`` and ``count``. ``count`` is the number
        of held-out rows assigned to that center.
        When ``cluster_count`` is missing, not an integer, or smaller than 2,
        a JSON error object with HTTP status 400.
    """
    # Validate before touching Spark so a bad request is rejected cheaply with
    # a 400 instead of surfacing as an unhandled TypeError/ValueError (500).
    cluster_count = request.args.get('cluster_count', type=int)
    if cluster_count is None or cluster_count < 2:
        return jsonify(error="cluster_count must be an integer greater than 1"), 400

    spark = SparkSession.builder.appName("Uber Dataset").getOrCreate()

    dataset = spark.read.csv('uberdata.csv', inferSchema=True, header="True")
    assembler = VectorAssembler(inputCols=["Lat", "Lon"], outputCol="features")
    dataset = assembler.transform(dataset)

    # Only the 30% split is used below; the 70% split is discarded.
    # NOTE(review): the model is fitted on the full dataset rather than on the
    # 70% split, exactly as before -- confirm whether that is intended.
    _, testdata = dataset.randomSplit([0.7, 0.3], seed=5043)
    model = KMeans().setK(cluster_count).fit(dataset)
    predictions = (
        model.transform(testdata)
        .withColumnRenamed("prediction", "cluster_id")
        .select("cluster_id")
    )

    # Feature vectors are assembled as [Lat, Lon], so every center comes back
    # in that order. Name the columns accordingly; the previous query labelled
    # the first component "Longitude" and the second "Latitude".
    center_rows = [
        center.tolist() + [cluster_id]
        for cluster_id, center in enumerate(model.clusterCenters())
    ]
    centers = spark.createDataFrame(
        center_rows, ["Latitude", "Longitude", "cluster_id"]
    )

    # Join through the DataFrame API instead of fixed-name temp views: the
    # views lived on the shared session, so concurrent requests could replace
    # each other's "data_table"/"centers" before the query ran. The earlier
    # cache() call is dropped as well; it was never unpersisted and the frame
    # is consumed by a single action.
    data = (
        predictions.join(centers, on="cluster_id")
        .groupBy("Longitude", "Latitude")
        .count()
    )
    return jsonify(data.toJSON().collect())
def get_port() -> int:
    """Return the TCP port the server should listen on.

    Reads the ``PORT`` environment variable and falls back to 8080 when it is
    unset or empty. The result is always an ``int``, so callers get one
    consistent type regardless of where the value came from.

    Raises:
        ValueError: if ``PORT`` is set to something that is not an integer.
    """
    raw_port = os.getenv("PORT")
    if not raw_port:
        return 8080
    return int(raw_port)
# Script entry point: start Flask's built-in server on all interfaces, using
# the port chosen by get_port().
if __name__ == "__main__":
    listen_port = get_port()
    app.run(host="0.0.0.0", port=listen_port)