Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [DRAFT - DO NOT MERGE] df.ai.gen - allow PySpark users to easily leverage OpenAI prompting #2267

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys
import os, json, subprocess, unittest

# Python 2/3 compatibility shim: on Python 3 there is no ``basestring``, so
# alias it to ``str``. The numeric ``sys.version_info`` tuple is compared
# instead of the ``sys.version`` string, which would order lexicographically.
if sys.version_info[0] >= 3:
    basestring = str

import pyspark
from pyspark import SparkContext
from pyspark import sql
from pyspark.ml.param.shared import *
from pyspark.rdd import RDD

from pyspark.sql import SparkSession, SQLContext

from synapse.ml.core.init_spark import *
# Module-level Spark handles, created as a side effect of importing this module.
# `init_spark` is presumably provided by the `synapse.ml.core.init_spark`
# star import above -- TODO confirm.
spark = init_spark()
# NOTE(review): despite its name, `sc` is a SQLContext built on the session's
# SparkContext, not a SparkContext itself.
sc = SQLContext(spark.sparkContext)

class AIFunctions:
    """Helper bound to one DataFrame that runs OpenAI prompts over its rows.

    The helper stores the connection settings supplied through ``setup`` and
    uses them in ``gen`` to configure a JVM-side ``OpenAIPrompt`` transformer.
    """

    def __init__(self, df):
        """Bind the helper to ``df`` with no connection settings configured.

        Args:
            df: The PySpark DataFrame that ``gen`` will transform.
        """
        self.df = df
        self.subscriptionKey = None
        self.deploymentName = None
        self.customServiceName = None

    def setup(self, subscriptionKey=None, deploymentName=None, customServiceName=None):
        """Store the connection settings used by later ``gen`` calls.

        Every call overwrites all three settings; an argument that is omitted
        resets the corresponding setting to ``None``.

        Args:
            subscriptionKey: Value later passed to ``setSubscriptionKey``.
            deploymentName: Value later passed to ``setDeploymentName``.
            customServiceName: Value later passed to ``setCustomServiceName``.
        """
        self.subscriptionKey = subscriptionKey
        self.deploymentName = deploymentName
        self.customServiceName = customServiceName

    def gen(self, template, outputCol=None, **options):
        """Apply an ``OpenAIPrompt`` transformer to the bound DataFrame.

        Args:
            template: Prompt template passed to ``setPromptTemplate``.
            outputCol: Name of the output column. When ``None`` the setter is
                not called, so the transformer keeps whatever default it
                defines.
            **options: Accepted for forward compatibility; currently not
                forwarded to the transformer.

        Returns:
            A PySpark DataFrame wrapping the transformer's output.
        """
        # Imported locally so the block does not depend on which pyspark
        # names the module happens to import at top level.
        from pyspark.sql import DataFrame

        jvm = SparkContext.getOrCreate()._jvm
        prompt = jvm.com.microsoft.azure.synapse.ml.services.openai.OpenAIPrompt()

        # Only forward values that were actually provided: pushing None
        # through py4j would hand a null to the JVM setter.
        optional_settings = (
            ("setSubscriptionKey", self.subscriptionKey),
            ("setDeploymentName", self.deploymentName),
            ("setCustomServiceName", self.customServiceName),
            ("setOutputCol", outputCol),
        )
        for setter_name, value in optional_settings:
            if value is not None:
                prompt = getattr(prompt, setter_name)(value)
        prompt = prompt.setPromptTemplate(template)

        java_results = prompt.transform(self.df._jdf)

        # Wrap the JVM DataFrame directly instead of round-tripping through a
        # fixed-name temp view, which would overwrite any existing view with
        # that name and tie the result to a module-global session.
        session = getattr(self.df, "sparkSession", None)
        if session is None:
            # Older PySpark releases expose only the SQLContext on a DataFrame.
            session = self.df.sql_ctx
        return DataFrame(java_results, session)

def get_AI_functions(df):
    """Return the ``AIFunctions`` helper cached on ``df``, creating it on first use.

    The helper is stored on the DataFrame object as ``_ai_instance`` so that
    repeated lookups on the same object return the same helper.
    """
    try:
        return df._ai_instance
    except AttributeError:
        # Nothing cached on this object yet; fall through and create it.
        pass
    df._ai_instance = AIFunctions(df)
    return df._ai_instance

# Expose the helper lookup as a read-only ``ai`` property on PySpark DataFrames.
pyspark.sql.DataFrame.ai = property(get_AI_functions)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from DataFrameAIExtensions import *
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
setUrl(v + urlPath.stripPrefix("/"))
}

override def getUrl: String = this.getOrDefault(url)
override def getUrl: String = "https://synapseml-openai-2.openai.azure.com/openai/deployments/gpt-4/chat/completions"

def setDefaultInternalEndpoint(v: String): this.type = setDefault(
url, v + s"/cognitive/${this.internalServiceType}/" + urlPath.stripPrefix("/"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

# Prepare training and test data.
import unittest
import os, json, subprocess, unittest

from synapse.ml.io.http import *
from pyspark.sql.types import *
from synapse.ml.services.openai import *

from pyspark.sql import SparkSession, SQLContext

from synapse.ml.core.init_spark import *
# Spark handles shared by the tests below, created when the module is imported.
# `init_spark` is presumably provided by the `synapse.ml.core.init_spark`
# star import above -- TODO confirm.
spark = init_spark()
# NOTE(review): despite its name, `sc` is a SQLContext built on the session's
# SparkContext, not a SparkContext itself.
sc = SQLContext(spark.sparkContext)

class DataFrameAIExtentionsTest(unittest.TestCase):
    """Tests for the ``df.ai.setup`` / ``df.ai.gen`` DataFrame extension.

    Each test builds a small DataFrame, configures the extension with an API
    key read through the ``az`` command line, runs a prompt template over the
    rows and checks that every row produced a non-null ``outParsed`` value.
    """

    @staticmethod
    def _fetch_openai_api_key():
        """Return the OpenAI API key read from the key vault via ``az``.

        The key is deliberately never printed or logged.
        """
        secretJson = subprocess.check_output(
            "az keyvault secret show --vault-name mmlspark-build-keys --name openai-api-key-2",
            shell=True,
        )
        return json.loads(secretJson)["value"]

    def _assert_all_rows_generated(self, data, schema, template):
        """Run ``template`` over ``data`` and assert every row has output."""
        df = spark.createDataFrame(data, schema)

        df.ai.setup(
            subscriptionKey=self._fetch_openai_api_key(),
            deploymentName="gpt-35-turbo",
            customServiceName="synapseml-openai-2",
        )

        results = df.ai.gen(template, outputCol="outParsed")
        results.select("outParsed").show(truncate=False)
        # Index the column off the DataFrame itself so the check does not
        # rely on `col` being available through one of the star imports.
        nonNullCount = results.filter(results["outParsed"].isNotNull()).count()
        self.assertEqual(nonNullCount, len(data))

    def test_gen(self):
        schema = StructType([
            StructField("text", StringType(), True),
            StructField("category", StringType(), True),
        ])

        data = [
            ("apple", "fruits"),
            ("mercedes", "cars"),
            ("cake", "dishes"),
        ]

        self._assert_all_rows_generated(
            data,
            schema,
            "Complete this comma separated list of 5 {category}: {text}, ",
        )

    def test_gen_2(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("address", StringType(), True),
        ])

        data = [
            ("Anne F.", "123 First Street, 98053"),
            ("George K.", "345 Washington Avenue, London"),
        ]

        self._assert_all_rows_generated(
            data,
            schema,
            "Generate the likely country of {name}, given that they are from {address}. "
            "It is imperitive that your response contains the country only, no elaborations.",
        )

if __name__ == "__main__":
    # unittest.main() runs the tests and then exits the process by default,
    # so there is no return value worth binding.
    unittest.main()
Loading