diff --git a/cmd/wisdom/main.go b/cmd/wisdom/main.go
index 6dfae5e..5f21b24 100644
--- a/cmd/wisdom/main.go
+++ b/cmd/wisdom/main.go
@@ -16,6 +16,7 @@ import (
 	"github.com/openshift/wisdom/pkg/api"
 	"github.com/openshift/wisdom/pkg/model"
+	hf "github.com/openshift/wisdom/pkg/model/huggingface"
 	"github.com/openshift/wisdom/pkg/model/ibm"
 	"github.com/openshift/wisdom/pkg/model/openai"
 	"github.com/openshift/wisdom/pkg/server"
@@ -213,8 +214,8 @@ func newInferCommand() *cobra.Command {
 			log.Debugf("Using provider/model %s/%s for prompt:\n%s\n", o.provider, o.modelId, o.prompt)
 			response, err := model.InvokeModel(input, m)
 			if err != nil {
-				if response != nil && response.Output != "" {
-					log.Debugf("Response(Error):\n%s", response.Output)
+				if response.Error != "" {
+					log.Debugf("Response(Error):\n%s", response.Error)
 				}
 				return fmt.Errorf("error invoking the LLM: %v", err)
 			}
@@ -243,9 +244,11 @@ func initModels(config api.Config) map[string]api.Model {
 		switch m.Provider {
 		case "ibm":
 			models[m.Provider+"/"+m.ModelId] = ibm.NewIBMModel(m.ModelId, m.URL, m.UserId, m.APIKey)
-
 		case "openai":
 			models[m.Provider+"/"+m.ModelId] = openai.NewOpenAIModel(m.ModelId, m.URL, m.APIKey)
+		case "huggingface":
+			models[m.Provider+"/"+m.ModelId] = hf.NewHFModel(m.ModelId, m.URL, m.APIKey)
+
 		default:
 			log.Errorf("unknown provider: %s", m.Provider)
 		}
diff --git a/pkg/api/types.go b/pkg/api/types.go
index 19d0098..e5b1180 100644
--- a/pkg/api/types.go
+++ b/pkg/api/types.go
@@ -9,8 +9,8 @@ type Filter struct {
 	ResponseFilterChain []ResponseFilter
 }

-type InputFilter func(input *ModelInput) (*ModelInput, error)
-type ResponseFilter func(response *ModelResponse) (*ModelResponse, error)
+type InputFilter func(input ModelInput) (ModelInput, error)
+type ResponseFilter func(response ModelResponse) (ModelResponse, error)

 func NewFilter(inputFilters []InputFilter, responseFilters []ResponseFilter) Filter {
 	filter := Filter{
@@ -20,7 +20,7 @@ func NewFilter(inputFilters []InputFilter, responseFilters []ResponseFilter) Fil
 	return filter
 }

-func (f *Filter) FilterInput(input *ModelInput) (*ModelInput, error) {
+func (f Filter) FilterInput(input ModelInput) (ModelInput, error) {
 	output := input
 	var err error
 	for _, filter := range f.InputFilterChain {
@@ -32,7 +32,7 @@ func (f *Filter) FilterInput(input *ModelInput) (*ModelInput, error) {
 	return output, err
 }

-func (f Filter) FilterResponse(response *ModelResponse) (*ModelResponse, error) {
+func (f Filter) FilterResponse(response ModelResponse) (ModelResponse, error) {
 	output := response
 	var err error
 	for _, filter := range f.ResponseFilterChain {
@@ -45,10 +45,8 @@ func (f Filter) FilterResponse(response *ModelResponse) (*ModelResponse, error)
 }

 type Model interface {
-	Invoke(ModelInput) (*ModelResponse, error)
+	Invoke(ModelInput) (ModelResponse, error)
 	GetFilter() Filter
-	//FilterInput(*ModelInput) (*ModelInput, error)
-	//FilterResponse(*ModelResponse) (*ModelResponse, error)
 }

 // ModelInput represents the payload for the prompt_request endpoint.
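A note on the types.go change above: filters now pass ModelInput/ModelResponse by value instead of by pointer, so each filter receives and returns its own copy, and a failing filter can still hand back its last good transformation. The standalone sketch below (not part of this diff; the ModelResponse stub and trimOutput filter are hypothetical stand-ins mirroring only the fields this PR's filters touch) shows how a chain behaves under the new signatures:

package main

import (
	"fmt"
	"strings"
)

// ModelResponse is a minimal stand-in for api.ModelResponse, carrying only
// the fields the filters in this PR use.
type ModelResponse struct {
	Output string
	Error  string
}

// ResponseFilter matches the new value-based signature in pkg/api/types.go.
type ResponseFilter func(ModelResponse) (ModelResponse, error)

// trimOutput is a hypothetical filter: it transforms its copy of the
// response and, like MarkdownStripper and YamlLinter, returns the response
// alongside any error so callers keep the last filtered value.
func trimOutput(r ModelResponse) (ModelResponse, error) {
	r.Output = strings.TrimSpace(r.Output)
	if r.Output == "" {
		return r, fmt.Errorf("response output is empty")
	}
	return r, nil
}

func main() {
	// Chain application, mirroring how Filter.FilterResponse walks its
	// ResponseFilterChain: stop at the first error, keep the last good value.
	chain := []ResponseFilter{trimOutput}
	resp := ModelResponse{Output: "  kind: Pod  "}
	for _, f := range chain {
		var err error
		if resp, err = f(resp); err != nil {
			fmt.Println("filter error:", err)
			break
		}
	}
	fmt.Printf("filtered output: %q\n", resp.Output)
}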
diff --git a/pkg/filters/markdown/markdown.go b/pkg/filters/markdown/markdown.go
index e8010a4..b3c5358 100644
--- a/pkg/filters/markdown/markdown.go
+++ b/pkg/filters/markdown/markdown.go
@@ -14,7 +14,7 @@ var (
 	markdownRegex = regexp.MustCompile("(?s)`{3}.*?\n(.*)`{3}")
 )

-func MarkdownStripper(response *api.ModelResponse) (*api.ModelResponse, error) {
+func MarkdownStripper(response api.ModelResponse) (api.ModelResponse, error) {

 	if response.Output == "" {
 		return response, fmt.Errorf("response output is empty")
diff --git a/pkg/filters/yaml/yaml.go b/pkg/filters/yaml/yaml.go
index 2ef05a4..ab1fc3c 100644
--- a/pkg/filters/yaml/yaml.go
+++ b/pkg/filters/yaml/yaml.go
@@ -10,7 +10,7 @@ import (
 	"github.com/openshift/wisdom/pkg/api"
 )

-func YamlLinter(response *api.ModelResponse) (*api.ModelResponse, error) {
+func YamlLinter(response api.ModelResponse) (api.ModelResponse, error) {
 	if err := isValidYAML(response.Output); err != nil {
 		return response, fmt.Errorf("response output is not valid YAML: %s", err)
 	}
diff --git a/pkg/model/huggingface/model_huggingface.go b/pkg/model/huggingface/model_huggingface.go
new file mode 100644
index 0000000..6b3685b
--- /dev/null
+++ b/pkg/model/huggingface/model_huggingface.go
@@ -0,0 +1,71 @@
+package huggingface
+
+import (
+	"context"
+	"fmt"
+
+	huggingface "github.com/hupe1980/go-huggingface"
+
+	"github.com/openshift/wisdom/pkg/api"
+)
+
+type HFModel struct {
+	modelId string
+	url     string
+	apiKey  string
+	filter  api.Filter
+}
+
+func NewHFModel(modelId, url, apiKey string) *HFModel {
+	//filter := api.NewFilter(nil, []api.ResponseFilter{markdown.MarkdownStripper, yaml.YamlLinter})
+	filter := api.Filter{}
+
+	return &HFModel{
+		modelId: modelId,
+		url:     url,
+		apiKey:  apiKey,
+		filter:  filter,
+	}
+}
+
+func (m *HFModel) GetFilter() api.Filter {
+	return m.filter
+}
+
+func (m *HFModel) Invoke(input api.ModelInput) (api.ModelResponse, error) {
+	if input.APIKey == "" && m.apiKey == "" {
+		return api.ModelResponse{}, fmt.Errorf("api key is required, none provided")
+	}
+
+	// Prefer a per-request API key over the one the model was configured with.
+	apiKey := m.apiKey
+	if input.APIKey != "" {
+		apiKey = input.APIKey
+	}
+	client := huggingface.NewInferenceClient(apiKey)
+	client.SetModel(m.modelId)
+
+	req := &huggingface.TextGenerationRequest{
+		Inputs: input.Prompt,
+		Model:  input.ModelId,
+	}
+
+	// Generation limits passed to the Hugging Face inference API.
+	maxNewTokens := 100
+	req.Parameters.MaxNewTokens = &maxNewTokens
+	maxTime := 30.0
+	req.Parameters.MaxTime = &maxTime
+	numReturnSequences := 2
+	req.Parameters.NumReturnSequences = &numReturnSequences
+
+	resp, err := client.TextGeneration(context.Background(), req)
+	if err != nil {
+		return api.ModelResponse{}, fmt.Errorf("error making api request: %v", err)
+	}
+
+	response := api.ModelResponse{}
+	response.Input = input.Prompt
+	response.Output = resp.GeneratedText
+
+	return response, nil
+}
diff --git a/pkg/model/ibm/model_ibm.go b/pkg/model/ibm/model_ibm.go
index b791f2c..90dc9b0 100644
--- a/pkg/model/ibm/model_ibm.go
+++ b/pkg/model/ibm/model_ibm.go
@@ -11,11 +11,6 @@ import (
 	"github.com/openshift/wisdom/pkg/filters/yaml"
 )

-const (
-	PROVIDER_ID = "ibm"
-	MODEL_ID    = "L3Byb2plY3RzL2czYmNfc3RhY2tfc3RnMl9lcG9jaDNfanVsXzMx"
-)
-
 type IBMModelRequestPayload struct {
 	Prompt  string `json:"prompt"`
 	ModelID string `json:"model_id"`
@@ -56,13 +51,13 @@ func (m *IBMModel) GetFilter() api.Filter {
 	return m.filter
 }

-func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
+func (m *IBMModel) Invoke(input api.ModelInput) (api.ModelResponse, error) {
 	if input.UserId == "" && m.userId == "" {
-		return nil, fmt.Errorf("user email address is required, none provided")
+		return api.ModelResponse{}, fmt.Errorf("user email address is required, none provided")
 	}

 	if input.APIKey == "" && m.apiKey == "" {
-		return nil, fmt.Errorf("api key is required, none provided")
+		return api.ModelResponse{}, fmt.Errorf("api key is required, none provided")
 	}

 	apiKey, userId := m.apiKey, m.userId
@@ -84,14 +79,14 @@ func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
 	jsonPayload, err := json.Marshal(payload)
 	if err != nil {
 		//fmt.Println("Error encoding JSON:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}

 	apiURL := m.url + "/api/v1/jobs"
 	req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonPayload))
 	if err != nil {
 		//fmt.Println("Error creating HTTP request:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}

 	// Set the "Content-Type" header to "application/json"
@@ -106,13 +101,13 @@ func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
 	resp, err := client.Do(req)
 	if err != nil {
 		//fmt.Println("Error making API request:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}
 	defer resp.Body.Close()

 	// Check the response status code
 	if resp.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("API request failed with status: %v", resp.Status)
+		return api.ModelResponse{}, fmt.Errorf("API request failed with status: %v", resp.Status)
 	}

 	// Parse the JSON response into the APIResponse struct
@@ -120,7 +115,7 @@ func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
 	err = json.NewDecoder(resp.Body).Decode(&apiResp)
 	if err != nil {
 		//fmt.Println("Error decoding API response:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}
 	response := api.ModelResponse{}
 	response.Input = input.Prompt
@@ -129,9 +124,9 @@ func (m *IBMModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
 	//output := apiResp.AllTokens[len(apiResp.InputTokens):]
 	response.RequestID = apiResp.JobID

-	return &response, err
+	return response, err
 }

-func (m *IBMModel) FilterInput(input *api.ModelInput) (*api.ModelInput, error) {
+func (m *IBMModel) FilterInput(input api.ModelInput) (api.ModelInput, error) {
 	return m.filter.FilterInput(input)
 }
diff --git a/pkg/model/invoker.go b/pkg/model/invoker.go
index a2a59a9..6b71db7 100644
--- a/pkg/model/invoker.go
+++ b/pkg/model/invoker.go
@@ -1,17 +1,22 @@
 package model

 import (
+	"fmt"
+
 	log "github.com/sirupsen/logrus"

 	"github.com/openshift/wisdom/pkg/api"
 )

-func InvokeModel(input api.ModelInput, model api.Model) (*api.ModelResponse, error) {
-	response, err := model.Invoke(input)
-	if response == nil {
-		response = &api.ModelResponse{}
+func InvokeModel(input api.ModelInput, model api.Model) (api.ModelResponse, error) {
+	log.Debugf("model input:\n%#v", input)
+	input, err := model.GetFilter().FilterInput(input)
+	if err != nil {
+		return api.ModelResponse{}, fmt.Errorf("error filtering input: %s", err)
 	}
-	log.Debugf("model response: %#v", response)
+	log.Debugf("model filtered input:\n%#v", input)
+	response, err := model.Invoke(input)
+	log.Debugf("model response:\n%#v\nerror: %v", response, err)
 	if err != nil {
 		response.Error = err.Error()
 		return response, err
@@ -21,5 +26,6 @@ func InvokeModel(input api.ModelInput, model api.Model) (*api.ModelResponse, err
 	if err != nil {
 		response.Error = err.Error()
 	}
+	log.Debugf("model filtered output:\n%#v", output)
 	return output, err
 }
diff --git a/pkg/model/openai/model_openai.go b/pkg/model/openai/model_openai.go
index 2e1930f..1750dbc 100644
--- a/pkg/model/openai/model_openai.go
+++ b/pkg/model/openai/model_openai.go
@@ -9,11 +9,6 @@ import (
 	"github.com/openshift/wisdom/pkg/api"
 )

-const (
-	PROVIDER_ID = "openai"
-	MODEL_ID    = "gpt-3.5-turbo"
-)
-
 // OpenAI
 type OpenAIMessage struct {
 	Role    string `json:"role"`
@@ -55,10 +50,10 @@ func (m *OpenAIModel) GetFilter() api.Filter {
 	return m.filter
 }

-func (m *OpenAIModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
+func (m *OpenAIModel) Invoke(input api.ModelInput) (api.ModelResponse, error) {

 	if input.APIKey == "" && m.apiKey == "" {
-		return nil, fmt.Errorf("api key is required, none provided")
+		return api.ModelResponse{}, fmt.Errorf("api key is required, none provided")
 	}

 	apiKey := m.apiKey
@@ -75,14 +70,14 @@ func (m *OpenAIModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
 	jsonPayload, err := json.Marshal(payload)
 	if err != nil {
 		//fmt.Println("Error encoding JSON:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}

 	apiURL := m.url + "/v1/chat/completions"
 	req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonPayload))
 	if err != nil {
 		//fmt.Println("Error creating HTTP request:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}

 	// Set the "Content-Type" header to "application/json"
@@ -97,13 +92,13 @@ func (m *OpenAIModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
 	resp, err := client.Do(req)
 	if err != nil {
 		//fmt.Println("Error making API request:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}
 	defer resp.Body.Close()

 	// Check the response status code
 	if resp.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("API request failed with status: %s", resp.Status)
+		return api.ModelResponse{}, fmt.Errorf("API request failed with status: %s", resp.Status)
 	}

 	// Parse the JSON response into the APIResponse struct
@@ -111,14 +106,14 @@ func (m *OpenAIModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
 	err = json.NewDecoder(resp.Body).Decode(&apiResp)
 	if err != nil {
 		fmt.Println("Error decoding API response:", err)
-		return nil, err
+		return api.ModelResponse{}, err
 	}
 	if len(apiResp.Choices) == 0 {
-		return nil, fmt.Errorf("model returned no valid responses: %v", apiResp)
+		return api.ModelResponse{}, fmt.Errorf("model returned no valid responses: %v", apiResp)
 	}
 	response := api.ModelResponse{}
 	response.Input = input.Prompt
 	response.Output = apiResp.Choices[0].Message.Content
 	response.RawOutput = apiResp.Choices[0].Message.Content
-	return &response, err
+	return response, err
 }
diff --git a/pkg/server/inference_handler.go b/pkg/server/inference_handler.go
index 5245d6c..fef3c18 100644
--- a/pkg/server/inference_handler.go
+++ b/pkg/server/inference_handler.go
@@ -49,19 +49,22 @@ func (h *Handler) InferHandler(w http.ResponseWriter, r *http.Request) {
 	log.Debugf("Using provider/model %s/%s for prompt:\n%s\n", payload.Provider, payload.ModelId, payload.Prompt)

 	response, err := model.InvokeModel(payload, m)
+	if err != nil {
+		log.Errorf("failed to invoke model: %v", err)
+		http.Error(w, "Failed to invoke model", http.StatusInternalServerError)
+		return
+	}

 	buf := bytes.Buffer{}
-	if response != nil {
-		err = json.NewEncoder(&buf).Encode(response)
-		if err != nil {
-			w.WriteHeader(http.StatusExpectationFailed)
-			w.Write([]byte(err.Error()))
-			return
-		}
+	err = json.NewEncoder(&buf).Encode(response)
+	if err != nil {
+		w.WriteHeader(http.StatusExpectationFailed)
+		w.Write([]byte(err.Error()))
+		return
 	}

 	w.Header().Set("Content-Type", "text/json")
-	if err != nil || (response != nil && response.Error != "") {
+	if err != nil || (response.Error != "") {
 		log.Debugf("model invocation returning error: %v", err)
 		w.WriteHeader(http.StatusExpectationFailed)
 	} else {
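For reviewers, a quick end-to-end sketch of how the new provider wiring is meant to be exercised once this lands. This is not part of the diff: the model id "gpt2" and the "hf_xxx" API key are hypothetical placeholders, and the map key scheme and call shapes follow initModels and model.InvokeModel above.

package main

import (
	"fmt"

	"github.com/openshift/wisdom/pkg/api"
	"github.com/openshift/wisdom/pkg/model"
	hf "github.com/openshift/wisdom/pkg/model/huggingface"
)

func main() {
	// Register a Hugging Face-backed model under the "provider/modelId"
	// key scheme used by initModels in cmd/wisdom/main.go.
	models := map[string]api.Model{
		"huggingface/gpt2": hf.NewHFModel("gpt2", "", "hf_xxx"),
	}

	input := api.ModelInput{Prompt: "Generate a Kubernetes Deployment manifest."}

	// InvokeModel now runs the input filter chain, invokes the model, and
	// filters the response, passing values rather than pointers throughout,
	// so callers check response.Error instead of a nil response.
	resp, err := model.InvokeModel(input, models["huggingface/gpt2"])
	if err != nil {
		fmt.Println("invoke error:", err, "response error:", resp.Error)
		return
	}
	fmt.Println(resp.Output)
}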