Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Adding responseFormat parameter in OpenAI Chat Completion Request #2322

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {

// list of shared text parameters. In method getOptionalParams, we will iterate over these parameters
// to compute the optional parameters. Since this list never changes, we can create it once and reuse it.
private val sharedTextParams = Seq(
private[openai] val sharedTextParams: Seq[ServiceParam[_]] = Seq(
maxTokens,
temperature,
topP,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser}
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
Expand All @@ -16,10 +17,80 @@ import spray.json._

import scala.language.existentials

// Enumeration of the response formats accepted by the OpenAI chat-completion API.
// NOTE: the constructor field is spelled "paylodName" (sic). It is kept unchanged because
// existing callers reference it directly; a correctly spelled `payloadName` alias is
// provided for new code.
object OpenAIResponseFormat extends Enumeration {
  case class ResponseFormat(paylodName: String, prompt: String) extends super.Val(paylodName) {
    // Correctly spelled accessor for the wire-format name; prefer this in new code.
    def payloadName: String = paylodName
  }

  val TEXT: ResponseFormat = ResponseFormat("text", "Output must be in text format")
  val JSON: ResponseFormat = ResponseFormat("json_object", "Output must be in JSON format")

  // All wire-format names ("text", "json_object"), used to validate user-supplied formats.
  def asStringSet: Set[String] =
    OpenAIResponseFormat.values.map(_.asInstanceOf[OpenAIResponseFormat.ResponseFormat].paylodName)
}


/**
 * Extends HasOpenAITextParams with the OpenAI `response_format` parameter, which
 * controls whether the model must answer in plain text or as a JSON object.
 */
trait HasOpenAITextParamsExtended extends HasOpenAITextParams {

  // Carried to the service under the wire name "response_format".
  val responseFormat: ServiceParam[Map[String, String]] = new ServiceParam[Map[String, String]](
    this,
    "responseFormat",
    "Response format for the completion. Can be 'json_object' or 'text'.",
    isRequired = false) {
    override val payloadName: String = "response_format"
  }

  def getResponseFormat: Map[String, String] = getScalarParam(responseFormat)

  /**
   * Sets the response format from a map of the form Map("type" -> "&lt;format&gt;").
   * Validation is case-insensitive and the stored value is lowercased, since the
   * OpenAI API only accepts lowercase format names.
   *
   * @throws IllegalArgumentException if the map is malformed or the format is unsupported
   */
  def setResponseFormat(value: Map[String, String]): this.type = {
    val allowedFormat = OpenAIResponseFormat.asStringSet

    // Validate that value is a properly formatted Map('type' -> '<format>')
    if (value == null || value.size != 1 || !value.contains("type") || value("type").isEmpty) {
      throw new IllegalArgumentException("Response format map must be of the form Map('type' -> '<format>')"
        + " where <format> is one of " + allowedFormat.mkString(", "))
    }

    // Validate (case-insensitively) that the format is one of the allowed formats.
    val normalizedType = value("type").toLowerCase
    if (!allowedFormat.contains(normalizedType)) {
      throw new IllegalArgumentException("Response format must be valid for OpenAI API. " +
        "Currently supported formats are " +
        allowedFormat.mkString(", "))
    }
    // Store the lowercased form so mixed-case input (e.g. "JSON_OBJECT") is sent to the API
    // correctly; previously the raw value was stored even though validation lowercased it.
    setScalarParam(responseFormat, Map("type" -> normalizedType))
  }

  /** String convenience overload; a null or empty value is a no-op. */
  def setResponseFormat(value: String): this.type = {
    if (value == null || value.isEmpty) {
      this
    } else {
      setResponseFormat(Map("type" -> value.toLowerCase))
    }
  }

  /** Enum overload; values are valid by construction, so no validation is needed. */
  def setResponseFormat(value: OpenAIResponseFormat.ResponseFormat): this.type = {
    setScalarParam(responseFormat, Map("type" -> value.paylodName))
  }

  // Override the shared parameter list so getOptionalParams also picks up responseFormat.
  override private[openai] val sharedTextParams: Seq[ServiceParam[_]] = Seq(
    maxTokens,
    temperature,
    topP,
    user,
    n,
    echo,
    stop,
    cacheLevel,
    presencePenalty,
    frequencyPenalty,
    bestOf,
    logProbs,
    responseFormat
  )
}

// Companion object providing Spark ML persistence (read/load) for OpenAIChatCompletion.
object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasMessagesInput with HasCognitiveServiceInput
with HasOpenAITextParamsExtended with HasMessagesInput with HasCognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

Expand Down Expand Up @@ -54,13 +125,27 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(

override def responseDataType: DataType = ChatCompletionResponse.schema

// This method is used to convert the messages into a format that the OpenAI API can understand. The
// StringEntity returned by this method is in JSON format.
// It is marked as package private so that it can be tested.
// NOTE: the stale pre-change lines left over from the diff (old `private[this]` signature and the
// old `val fullPayload` using plain mappedMessages) have been removed; they duplicated definitions.
private[openai] def getStringEntity(messages: Seq[Row], optionalParams: Map[String, Any]): StringEntity = {
  // Keep only the message fields that are actually present (non-null) in each Row.
  val mappedMessages: Seq[Map[String, String]] = messages.map { m =>
    Seq("role", "content", "name").map(n =>
      n -> Option(m.getAs[String](n))
    ).toMap.filter(_._2.isDefined).mapValues(_.get)
  }

  val mappedMessagesWithAdditionalMessages = {
    // if response format is JSON, then openAI expects the prompt to contain instruction to return JSON
    // for this reason, we add an additional message to the system prompt to ensure that the response is in JSON
    if (isSet(responseFormat) && getResponseFormat("type") == OpenAIResponseFormat.JSON.paylodName) {
      mappedMessages :+ Map("role" -> "system",
        "content" -> OpenAIResponseFormat.JSON.prompt)
    } else {
      mappedMessages
    }
  }
  val fullPayload = optionalParams.updated("messages", mappedMessagesWithAdditionalMessages)
  new StringEntity(fullPayload.toJson.compactPrint, ContentType.APPLICATION_JSON)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
object OpenAIPrompt extends ComplexParamsReadable[OpenAIPrompt]

class OpenAIPrompt(override val uid: String) extends Transformer
with HasOpenAITextParams with HasMessagesInput
with HasOpenAITextParamsExtended with HasMessagesInput
with HasErrorCol with HasOutputCol
with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
Expand Down Expand Up @@ -163,6 +163,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer
})
completion match {
case chatCompletion: OpenAIChatCompletion =>
if (isSet(responseFormat)) {
chatCompletion.setResponseFormat(getResponseFormat)
}
val messageColName = getMessagesCol
val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
val completionNamed = chatCompletion.setMessagesCol(messageColName)
Expand All @@ -183,6 +186,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}

case completion: OpenAICompletion =>
if (isSet(responseFormat)) {
throw new IllegalArgumentException("responseFormat is not supported for completion models")
}
val promptColName = df.withDerivativeCol("prompt")
val dfTemplated = df.withColumn(promptColName, promptCol)
val completionNamed = completion.setPromptCol(promptColName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.core.test.base.Flaky
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
import org.apache.commons.io.IOUtils
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.{DataFrame, Row}
import org.scalactic.Equality
Expand Down Expand Up @@ -151,6 +152,103 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion]
assert(Option(results.apply(2).getAs[Row]("out")).isEmpty)
}

test("getOptionalParam should include responseFormat"){
  val completion = new OpenAIChatCompletion()
    .setDeploymentName(deploymentNameGpt4)

  // Asserts that the payload carries response_format -> Map("type" -> <expected>).
  def validateResponseFormat(params: Map[String, Any], responseFormat: String): Unit = {
    // Use the local `completion`; `this.completion` wrongly referenced a suite member.
    val responseFormatPayloadName = completion.responseFormat.payloadName
    assert(params.contains(responseFormatPayloadName))
    val responseFormatMap = params(responseFormatPayloadName).asInstanceOf[Map[String, String]]
    assert(responseFormatMap.contains("type"))
    assert(responseFormatMap("type") == responseFormat)
  }

  val messages: Seq[Row] = Seq(
    OpenAIMessage("user", "Whats your favorite color")
  ).toDF("role", "content", "name").collect()

  // No response format set: payload must not contain response_format.
  val optionalParams: Map[String, Any] = completion.getOptionalParams(messages.head)
  assert(!optionalParams.contains("response_format"))

  // Setting an empty string is a no-op: still no response_format in the payload.
  completion.setResponseFormat("")
  val optionalParams0: Map[String, Any] = completion.getOptionalParams(messages.head)
  assert(!optionalParams0.contains("response_format"))

  completion.setResponseFormat("json_object")
  val optionalParams1: Map[String, Any] = completion.getOptionalParams(messages.head)
  validateResponseFormat(optionalParams1, "json_object")

  completion.setResponseFormat("text")
  val optionalParams2: Map[String, Any] = completion.getOptionalParams(messages.head)
  validateResponseFormat(optionalParams2, "text")

  completion.setResponseFormat(Map("type" -> "json_object"))
  val optionalParams3: Map[String, Any] = completion.getOptionalParams(messages.head)
  validateResponseFormat(optionalParams3, "json_object")

  completion.setResponseFormat(OpenAIResponseFormat.TEXT)
  val optionalParams4: Map[String, Any] = completion.getOptionalParams(messages.head)
  validateResponseFormat(optionalParams4, "text")
}

test("setResponseFormat should throw exception if invalid format"){
  val completion = new OpenAIChatCompletion()
    .setDeploymentName(deploymentNameGpt4)

  // String overload: unknown format name must be rejected.
  assertThrows[IllegalArgumentException](completion.setResponseFormat("invalid_format"))

  // Map overload: unknown format value must be rejected.
  assertThrows[IllegalArgumentException](completion.setResponseFormat(Map("type" -> "invalid_format")))

  // Map overload: a map without the mandatory "type" key must be rejected.
  assertThrows[IllegalArgumentException](completion.setResponseFormat(Map("invalid_key" -> "json_object")))
}

test("getStringEntity should add system message if response format is json_object") {
  val chatCompletion = new OpenAIChatCompletion()
    .setDeploymentName(deploymentNameGpt4)
    .setMessagesCol("messages")
    .setResponseFormat("json_object")
    .setTemperature(0.7)

  val messages: Seq[Row] = Seq(
    OpenAIMessage("user", "Whats your favorite color")
  ).toDF("role", "content", "name").collect()

  val optionalParams: Map[String, Any] = chatCompletion.getOptionalParams(messages.head)

  // With json_object format, the serialized payload must embed a system message
  // carrying the JSON-format instruction.
  val expectedExcerpt =
    s"""{"role":"system","content":"${OpenAIResponseFormat.JSON.prompt}"}"""
  val entity = chatCompletion.getStringEntity(messages, optionalParams)
  val payloadJson = IOUtils.toString(entity.getContent, "UTF-8")
  assert(payloadJson.contains(expectedExcerpt))
}

test("getStringEntity should not add system message if response format is not json_object") {
  val chatCompletion = new OpenAIChatCompletion()
    .setDeploymentName(deploymentNameGpt4)
    .setMessagesCol("messages")
    .setResponseFormat("text")
    .setTemperature(0.7)

  val messages: Seq[Row] = Seq(
    OpenAIMessage("user", "Whats your favorite color")
  ).toDF("role", "content", "name").collect()

  val optionalParams: Map[String, Any] = chatCompletion.getOptionalParams(messages.head)

  // With "text" format, neither format-instruction system message may be injected.
  val textExcerpt =
    s"""{"role":"system","content":"${OpenAIResponseFormat.TEXT.prompt}"}"""
  val jsonExcerpt =
    s"""{"role":"system","content":"${OpenAIResponseFormat.JSON.prompt}"}"""
  val entity = chatCompletion.getStringEntity(messages, optionalParams)
  val payloadJson = IOUtils.toString(entity.getContent, "UTF-8")
  assert(!payloadJson.contains(textExcerpt))
  assert(!payloadJson.contains(jsonExcerpt))
}

ignore("Custom EndPoint") {
lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "")
lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ trait OpenAIAPIKey {
// Azure OpenAI deployment/model identifiers referenced by the test suites.
// NOTE(review): presumably these match deployments provisioned in the test
// Azure subscription — confirm before changing.
lazy val deploymentName: String = "gpt-35-turbo"
lazy val modelName: String = "gpt-35-turbo"
lazy val deploymentNameGpt4: String = "gpt-4"
// GPT-4o deployment, used by tests exercising the responseFormat parameter.
lazy val deploymentNameGpt4o: String = "gpt-4o"
lazy val modelNameGpt4: String = "gpt-4"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,30 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.foreach(r => assert(r.get(0) != null))
}

test("Basic Usage JSON - Gpt 4o with responseFormat") {
  val promptGpt4o: OpenAIPrompt = new OpenAIPrompt()
    .setSubscriptionKey(openAIAPIKey)
    .setDeploymentName(deploymentNameGpt4o)
    .setCustomServiceName(openAIServiceName)
    .setOutputCol("outParsed")
    .setTemperature(0)
    // Set once via the enum overload; the second, redundant string-based
    // setResponseFormat("json_object") call was removed.
    .setResponseFormat(OpenAIResponseFormat.JSON)

  promptGpt4o.setPromptTemplate(
    """Split a word into prefix and postfix
      |Cherry: {{"prefix": "Che", "suffix": "rry"}}
      |{text}:
      |""".stripMargin)
    .setPostProcessing("json")
    .setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING"))
    .transform(df)
    .select("outParsed")
    .where(col("outParsed").isNotNull)
    .collect()
    .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}

ignore("Custom EndPoint") {
lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "")
lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "")
Expand Down
Loading