
add support for huggingface inference
bparees committed Sep 6, 2023
1 parent a3feb62 commit 2ad70cd
Showing 9 changed files with 124 additions and 54 deletions.
9 changes: 6 additions & 3 deletions cmd/wisdom/main.go
@@ -16,6 +16,7 @@ import (
 
 	"github.com/openshift/wisdom/pkg/api"
 	"github.com/openshift/wisdom/pkg/model"
+	hf "github.com/openshift/wisdom/pkg/model/huggingface"
 	"github.com/openshift/wisdom/pkg/model/ibm"
 	"github.com/openshift/wisdom/pkg/model/openai"
 	"github.com/openshift/wisdom/pkg/server"
@@ -213,8 +214,8 @@ func newInferCommand() *cobra.Command {
 			log.Debugf("Using provider/model %s/%s for prompt:\n%s\n", o.provider, o.modelId, o.prompt)
 			response, err := model.InvokeModel(input, m)
 			if err != nil {
-				if response != nil && response.Output != "" {
-					log.Debugf("Response(Error):\n%s", response.Output)
+				if response.Error != "" {
+					log.Debugf("Response(Error):\n%s", response.Error)
 				}
 				return fmt.Errorf("error invoking the LLM: %v", err)
 			}
@@ -243,9 +244,11 @@ func initModels(config api.Config) map[string]api.Model {
 		switch m.Provider {
 		case "ibm":
 			models[m.Provider+"/"+m.ModelId] = ibm.NewIBMModel(m.ModelId, m.URL, m.UserId, m.APIKey)
-
 		case "openai":
 			models[m.Provider+"/"+m.ModelId] = openai.NewOpenAIModel(m.ModelId, m.URL, m.APIKey)
+		case "huggingface":
+			models[m.Provider+"/"+m.ModelId] = hf.NewHFModel(m.ModelId, m.URL, m.APIKey)
+
 		default:
 			log.Errorf("unknown provider: %s", m.Provider)
 		}
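Note (sketch, not part of this commit): initModels keys the registry by "<provider>/<modelId>", so a configured Hugging Face model is looked up and invoked roughly as below; the helper and the model id are hypothetical:

```go
// Hypothetical caller; assumes the api and model packages imported in main.go.
// "tiiuae/falcon-7b-instruct" is a made-up example model id.
func invokeHF(models map[string]api.Model, prompt string) (api.ModelResponse, error) {
	m, ok := models["huggingface/tiiuae/falcon-7b-instruct"]
	if !ok {
		return api.ModelResponse{}, fmt.Errorf("model not configured")
	}
	return model.InvokeModel(api.ModelInput{Prompt: prompt}, m)
}
```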
12 changes: 5 additions & 7 deletions pkg/api/types.go
@@ -9,8 +9,8 @@ type Filter struct {
 	ResponseFilterChain []ResponseFilter
 }
 
-type InputFilter func(input *ModelInput) (*ModelInput, error)
-type ResponseFilter func(response *ModelResponse) (*ModelResponse, error)
+type InputFilter func(input ModelInput) (ModelInput, error)
+type ResponseFilter func(response ModelResponse) (ModelResponse, error)
 
 func NewFilter(inputFilters []InputFilter, responseFilters []ResponseFilter) Filter {
 	filter := Filter{
@@ -20,7 +20,7 @@ func NewFilter(inputFilters []InputFilter, responseFilters []ResponseFilter) Filter {
 	return filter
 }
 
-func (f *Filter) FilterInput(input *ModelInput) (*ModelInput, error) {
+func (f Filter) FilterInput(input ModelInput) (ModelInput, error) {
 	output := input
 	var err error
 	for _, filter := range f.InputFilterChain {
@@ -32,7 +32,7 @@ func (f *Filter) FilterInput(input *ModelInput) (*ModelInput, error) {
 	return output, err
 }
 
-func (f Filter) FilterResponse(response *ModelResponse) (*ModelResponse, error) {
+func (f Filter) FilterResponse(response ModelResponse) (ModelResponse, error) {
 	output := response
 	var err error
 	for _, filter := range f.ResponseFilterChain {
@@ -45,10 +45,8 @@ func (f Filter) FilterResponse(response *ModelResponse) (*ModelResponse, error)
 }
 
 type Model interface {
-	Invoke(ModelInput) (*ModelResponse, error)
+	Invoke(ModelInput) (ModelResponse, error)
 	GetFilter() Filter
-	//FilterInput(*ModelInput) (*ModelInput, error)
-	//FilterResponse(*ModelResponse) (*ModelResponse, error)
 }
 
 // ModelInput represents the payload for the prompt_request endpoint.
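Note (sketch, not part of this commit): with filters now taking ModelInput/ModelResponse by value, a filter mutates its local copy and returns it. A minimal ResponseFilter under the new signature; TrimOutput is illustrative only and assumes a strings import:

```go
// Illustrative ResponseFilter using the new value-based signature.
func TrimOutput(response api.ModelResponse) (api.ModelResponse, error) {
	response.Output = strings.TrimSpace(response.Output) // mutates the local copy only
	return response, nil
}

// Chained the same way the built-in filters are:
// filter := api.NewFilter(nil, []api.ResponseFilter{TrimOutput, yaml.YamlLinter})
```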
2 changes: 1 addition & 1 deletion pkg/filters/markdown/markdown.go
@@ -14,7 +14,7 @@ var (
 	markdownRegex = regexp.MustCompile("(?s)`{3}.*?\n(.*)`{3}")
 )
 
-func MarkdownStripper(response *api.ModelResponse) (*api.ModelResponse, error) {
+func MarkdownStripper(response api.ModelResponse) (api.ModelResponse, error) {
 
 	if response.Output == "" {
 		return response, fmt.Errorf("response output is empty")
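Note (sketch, not part of this commit): only the signature changes in this hunk. For context, the capture group in markdownRegex grabs the body of the first fenced block, as this snippet demonstrates:

```go
// Demonstrates the capture group of markdownRegex; the stripper's
// replacement logic lives outside this hunk.
re := regexp.MustCompile("(?s)`{3}.*?\n(.*)`{3}")
in := "```yaml\nkind: Pod\n```"
fmt.Printf("%q\n", re.FindStringSubmatch(in)[1]) // "kind: Pod\n"
```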
2 changes: 1 addition & 1 deletion pkg/filters/yaml/yaml.go
@@ -10,7 +10,7 @@ import (
 	"github.com/openshift/wisdom/pkg/api"
 )
 
-func YamlLinter(response *api.ModelResponse) (*api.ModelResponse, error) {
+func YamlLinter(response api.ModelResponse) (api.ModelResponse, error) {
 	if err := isValidYAML(response.Output); err != nil {
 		return response, fmt.Errorf("response output is not valid YAML: %s", err)
 	}
70 changes: 70 additions & 0 deletions pkg/model/huggingface/model_huggingface.go
@@ -0,0 +1,70 @@
+package huggingface
+
+import (
+	"context"
+	"fmt"
+
+	huggingface "github.com/hupe1980/go-huggingface"
+
+	"github.com/openshift/wisdom/pkg/api"
+)
+
+type HFModel struct {
+	modelId string
+	url     string
+	apiKey  string
+	filter  api.Filter
+}
+
+func NewHFModel(modelId, url, apiKey string) *HFModel {
+	//filter := api.NewFilter(nil, []api.ResponseFilter{markdown.MarkdownStripper, yaml.YamlLinter})
+	filter := api.Filter{}
+
+	return &HFModel{
+		modelId: modelId,
+		url:     url,
+		apiKey:  apiKey,
+		filter:  filter,
+	}
+}
+
+func (m *HFModel) GetFilter() api.Filter {
+	return m.filter
+}
+
+func (m *HFModel) Invoke(input api.ModelInput) (api.ModelResponse, error) {
+
+	if input.APIKey == "" && m.apiKey == "" {
+		return api.ModelResponse{}, fmt.Errorf("api key is required, none provided")
+	}
+
+	apiKey := m.apiKey
+	if input.APIKey != "" {
+		apiKey = input.APIKey
+	}
+	client := huggingface.NewInferenceClient(apiKey)
+	client.SetModel(m.modelId)
+
+	req := &huggingface.TextGenerationRequest{
+		Inputs: input.Prompt,
+		Model:  input.ModelId,
+	}
+
+	a := 100
+	req.Parameters.MaxNewTokens = &a
+	b := 30.0
+	req.Parameters.MaxTime = &b
+	c := 2
+	req.Parameters.NumReturnSequences = &c
+
+	resp, err := client.TextGeneration(context.Background(), req)
+	if err != nil {
+		return api.ModelResponse{}, fmt.Errorf("error making api request: %v", err)
+	}
+
+	response := api.ModelResponse{}
+	response.Input = input.Prompt
+	response.Output = resp.GeneratedText
+
+	return response, err
+}
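Note (sketch, not part of this commit): a hedged usage example for the new type; the model id and env var are assumptions. Worth noting that Invoke pins the client to m.modelId via SetModel while also sending input.ModelId on the request, and NumReturnSequences is set to 2 even though only resp.GeneratedText is read:

```go
// Exercising HFModel directly; hf is the import alias used in main.go.
m := hf.NewHFModel("tiiuae/falcon-7b-instruct", "", os.Getenv("HF_API_KEY"))
resp, err := m.Invoke(api.ModelInput{Prompt: "Generate a minimal Pod manifest."})
if err != nil {
	log.Fatalf("inference failed: %v", err)
}
fmt.Println(resp.Output)
```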
25 changes: 10 additions & 15 deletions pkg/model/ibm/model_ibm.go
@@ -11,11 +11,6 @@ import (
 	"github.com/openshift/wisdom/pkg/filters/yaml"
 )
 
-const (
-	PROVIDER_ID = "ibm"
-	MODEL_ID    = "L3Byb2plY3RzL2czYmNfc3RhY2tfc3RnMl9lcG9jaDNfanVsXzMx"
-)
-
 type IBMModelRequestPayload struct {
 	Prompt  string `json:"prompt"`
 	ModelID string `json:"model_id"`
@@ -56,13 +51,13 @@ func (m *IBMModel) GetFilter() api.Filter {
 	return m.filter
 }
 
-func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
+func (m *IBMModel) Invoke(input api.ModelInput) (api.ModelResponse, error) {
 
 	if input.UserId == "" && m.userId == "" {
-		return nil, fmt.Errorf("user email address is required, none provided")
+		return api.ModelResponse{}, fmt.Errorf("user email address is required, none provided")
 	}
 	if input.APIKey == "" && m.apiKey == "" {
-		return nil, fmt.Errorf("api key is required, none provided")
+		return api.ModelResponse{}, fmt.Errorf("api key is required, none provided")
 	}
 
 	apiKey, userId := m.apiKey, m.userId
@@ -84,14 +79,14 @@ func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
 	jsonPayload, err := json.Marshal(payload)
 	if err != nil {
 		//fmt.Println("Error encoding JSON:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}
 
 	apiURL := m.url + "/api/v1/jobs"
 	req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonPayload))
 	if err != nil {
 		//fmt.Println("Error creating HTTP request:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}
 
 	// Set the "Content-Type" header to "application/json"
@@ -106,21 +101,21 @@ func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
 	resp, err := client.Do(req)
 	if err != nil {
 		//fmt.Println("Error making API request:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}
 	defer resp.Body.Close()
 
 	// Check the response status code
 	if resp.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("API request failed with status: %v", resp.Status)
+		return api.ModelResponse{}, fmt.Errorf("API request failed with status: %v", resp.Status)
 	}
 
 	// Parse the JSON response into the APIResponse struct
 	var apiResp IBMModelResponsePayload
 	err = json.NewDecoder(resp.Body).Decode(&apiResp)
 	if err != nil {
 		//fmt.Println("Error decoding API response:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}
 	response := api.ModelResponse{}
 	response.Input = input.Prompt
@@ -129,9 +124,9 @@ func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
 	//output := apiResp.AllTokens[len(apiResp.InputTokens):]
 	response.RequestID = apiResp.JobID
 
-	return &response, err
+	return response, err
 }
 
-func (m *IBMModel) FilterInput(input *api.ModelInput) (*api.ModelInput, error) {
+func (m *IBMModel) FilterInput(input api.ModelInput) (api.ModelInput, error) {
 	return m.filter.FilterInput(input)
 }
16 changes: 11 additions & 5 deletions pkg/model/invoker.go
@@ -1,17 +1,22 @@
 package model
 
 import (
+	"fmt"
+
 	log "github.com/sirupsen/logrus"
 
 	"github.com/openshift/wisdom/pkg/api"
 )
 
-func InvokeModel(input api.ModelInput, model api.Model) (*api.ModelResponse, error) {
-	response, err := model.Invoke(input)
-	if response == nil {
-		response = &api.ModelResponse{}
+func InvokeModel(input api.ModelInput, model api.Model) (api.ModelResponse, error) {
+	log.Debugf("model input:\n%#v", input)
+	input, err := model.GetFilter().FilterInput(input)
+	if err != nil {
+		return api.ModelResponse{}, fmt.Errorf("error filtering input: %s", err)
 	}
-	log.Debugf("model response: %#v", response)
+	log.Debugf("model filtered input:\n%#v", input)
+	response, err := model.Invoke(input)
+	log.Debugf("model response:\n%#v\nerror: %v", response, err)
 	if err != nil {
 		response.Error = err.Error()
 		return response, err
@@ -21,5 +26,6 @@ func InvokeModel(input api.ModelInput, model api.Model) (*api.ModelResponse, error) {
 	if err != nil {
 		response.Error = err.Error()
 	}
+	log.Debugf("model filtered output:\n%#v", output)
 	return output, err
 }
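Note (sketch, not part of this commit): InvokeModel now runs a fixed pipeline, FilterInput, then Invoke, then the response filtering that produces output in the lines elided above, and stamps response.Error whenever a step fails. A minimal fake model satisfies the api.Model contract:

```go
// echoModel implements api.Model with an empty filter chain; illustration only.
type echoModel struct{}

func (echoModel) Invoke(in api.ModelInput) (api.ModelResponse, error) {
	return api.ModelResponse{Input: in.Prompt, Output: in.Prompt}, nil
}

func (echoModel) GetFilter() api.Filter { return api.Filter{} }

// resp, err := model.InvokeModel(api.ModelInput{Prompt: "ping"}, echoModel{})
// => resp.Output == "ping", err == nil
```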
23 changes: 9 additions & 14 deletions pkg/model/openai/model_openai.go
@@ -9,11 +9,6 @@ import (
 	"github.com/openshift/wisdom/pkg/api"
 )
 
-const (
-	PROVIDER_ID = "openai"
-	MODEL_ID    = "gpt-3.5-turbo"
-)
-
 // OpenAI
 type OpenAIMessage struct {
 	Role    string `json:"role"`
@@ -55,10 +50,10 @@ func (m *OpenAIModel) GetFilter() api.Filter {
 	return m.filter
 }
 
-func (m *OpenAIModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
+func (m *OpenAIModel) Invoke(input api.ModelInput) (api.ModelResponse, error) {
 
 	if input.APIKey == "" && m.apiKey == "" {
-		return nil, fmt.Errorf("api key is required, none provided")
+		return api.ModelResponse{}, fmt.Errorf("api key is required, none provided")
 	}
 
 	apiKey := m.apiKey
@@ -75,14 +70,14 @@ func (m *OpenAIModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
 	jsonPayload, err := json.Marshal(payload)
 	if err != nil {
 		//fmt.Println("Error encoding JSON:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}
 
 	apiURL := m.url + "/v1/chat/completions"
 	req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonPayload))
 	if err != nil {
 		//fmt.Println("Error creating HTTP request:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}
 
 	// Set the "Content-Type" header to "application/json"
@@ -97,28 +92,28 @@ func (m *OpenAIModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
 	resp, err := client.Do(req)
 	if err != nil {
 		//fmt.Println("Error making API request:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}
 	defer resp.Body.Close()
 
 	// Check the response status code
 	if resp.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("API request failed with status: %s", resp.Status)
+		return api.ModelResponse{}, fmt.Errorf("API request failed with status: %s", resp.Status)
 	}
 
 	// Parse the JSON response into the APIResponse struct
 	var apiResp OpenAIModelResponsePayload
 	err = json.NewDecoder(resp.Body).Decode(&apiResp)
 	if err != nil {
 		fmt.Println("Error decoding API response:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}
 	if len(apiResp.Choices) == 0 {
-		return nil, fmt.Errorf("model returned no valid responses: %v", apiResp)
+		return api.ModelResponse{}, fmt.Errorf("model returned no valid responses: %v", apiResp)
 	}
 	response := api.ModelResponse{}
 	response.Input = input.Prompt
 	response.Output = apiResp.Choices[0].Message.Content
 	response.RawOutput = apiResp.Choices[0].Message.Content
-	return &response, err
+	return response, err
 }
19 changes: 11 additions & 8 deletions pkg/server/inference_handler.go
@@ -49,19 +49,22 @@ func (h *Handler) InferHandler(w http.ResponseWriter, r *http.Request) {
 	log.Debugf("Using provider/model %s/%s for prompt:\n%s\n", payload.Provider, payload.ModelId, payload.Prompt)
 
 	response, err := model.InvokeModel(payload, m)
+	if err != nil {
+		log.Errorf("failed to invoke model: %v", err)
+		http.Error(w, "Failed to invoke model", http.StatusInternalServerError)
+		return
+	}
 
 	buf := bytes.Buffer{}
-	if response != nil {
-		err = json.NewEncoder(&buf).Encode(response)
-		if err != nil {
-			w.WriteHeader(http.StatusExpectationFailed)
-			w.Write([]byte(err.Error()))
-			return
-		}
+	err = json.NewEncoder(&buf).Encode(response)
+	if err != nil {
+		w.WriteHeader(http.StatusExpectationFailed)
+		w.Write([]byte(err.Error()))
+		return
 	}
 
 	w.Header().Set("Content-Type", "text/json")
-	if err != nil || (response != nil && response.Error != "") {
+	if err != nil || (response.Error != "") {
 		log.Debugf("model invocation returning error: %v", err)
 		w.WriteHeader(http.StatusExpectationFailed)
 	} else {
