diff --git a/config/sample.cfg.yaml b/config/sample.cfg.yaml index fd98de5..512afd9 100644 --- a/config/sample.cfg.yaml +++ b/config/sample.cfg.yaml @@ -13,6 +13,6 @@ userId: $USERID apiKey: $APIKEY - provider: openai - modelId: chatgpt-3.5-turbo + modelId: gpt-3.5-turbo url: https://api.openai.com apiKey: $APIKEY diff --git a/pkg/api/types.go b/pkg/api/types.go index a8a0ba7..ec5e0c0 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -8,7 +8,7 @@ type Model interface { type ModelInput struct { UserId string `json:"userid"` APIKey string `json:"apikey"` - ModelId string `json:"modelid"` + ModelId string `json:"modelId"` Provider string `json:"provider"` Prompt string `json:"prompt"` Context string `json:"context"` @@ -18,10 +18,11 @@ type ModelInput struct { type ModelResponse struct { Input string `json:"input_tokens"` Status string `json:"status"` - RequestID string `json:"request_id"` - ConversationID string `json:"conversation_id"` + RequestID string `json:"requestId"` + ConversationID string `json:"conversationId"` Output string `json:"output"` RawOutput string `json:"raw_output"` + ErrorMessage string `json:"error"` } // FeedbackPayload represents the payload for the feedback endpoint. diff --git a/pkg/filters/filters.go b/pkg/filters/filters.go index 7425d61..6d95d4a 100644 --- a/pkg/filters/filters.go +++ b/pkg/filters/filters.go @@ -1,9 +1,9 @@ package filters import ( - "gopkg.in/yaml.v2" - "github.com/openshift/wisdom/pkg/api" + "github.com/openshift/wisdom/pkg/filters/markdown" + "github.com/openshift/wisdom/pkg/filters/yaml" ) type Filter struct { @@ -15,7 +15,7 @@ type ResponseFilter func(response *api.ModelResponse) (*api.ModelResponse, error func NewFilter() Filter { filter := Filter{} - //filter.responseFilterChain = append(filter.responseFilterChain, YamlLinter) + filter.responseFilterChain = append(filter.responseFilterChain, markdown.MarkdownStripper, yaml.YamlLinter) return filter } @@ -42,13 +42,3 @@ func (f *Filter) FilterResponse(response *api.ModelResponse) (*api.ModelResponse } return output, err } - -func YamlLinter(response *api.ModelResponse) (*api.ModelResponse, error) { - return response, isValidYAML(response.Output) -} - -func isValidYAML(yamlString string) error { - var data interface{} - err := yaml.Unmarshal([]byte(yamlString), &data) - return err -} diff --git a/pkg/filters/markdown/markdown.go b/pkg/filters/markdown/markdown.go new file mode 100644 index 0000000..dc41e95 --- /dev/null +++ b/pkg/filters/markdown/markdown.go @@ -0,0 +1,41 @@ +package markdown + +import ( + "fmt" + "regexp" + + //gomarkdown "github.com/gomarkdown/markdown" + log "github.com/sirupsen/logrus" + + "github.com/openshift/wisdom/pkg/api" +) + +var ( + markdownRegex = regexp.MustCompile("(?s)`{3}.*?\n(.*)`{3}") +) + +func MarkdownStripper(response *api.ModelResponse) (*api.ModelResponse, error) { + + if response.Output == "" { + return response, fmt.Errorf("response output is empty") + } + log.Debugf("Stripping markdown from response:\n %s\n", response.Output) + + //response.Output = markdownRegex.ReplaceAllString(response.Output, "") + matches := markdownRegex.FindStringSubmatch(response.Output) + response.Output = matches[1] + /* + node := gomarkdown.Parse([]byte(response.Output), nil) + + fmt.Printf("%#v\n", node.GetChildren()[0].GetChildren()[0]) + //fmt.Printf("%#v\n", node.GetChildren()[0].GetChildren()[1]) + //fmt.Printf("%s\n", node.GetChildren()[0].GetChildren()[1].AsLeaf().Literal) + + response.Output = string(node.GetChildren()[0].GetChildren()[0].AsLeaf().Literal) + */ + + //response.Output = stripmd.Strip(response.Output) + + log.Debugf("Stripped markdown from response:\n %s\n", response.Output) + return response, nil +} diff --git a/pkg/filters/yaml/yaml.go b/pkg/filters/yaml/yaml.go new file mode 100644 index 0000000..805b0e0 --- /dev/null +++ b/pkg/filters/yaml/yaml.go @@ -0,0 +1,27 @@ +package yaml + +import ( + "fmt" + + "gopkg.in/yaml.v2" + + log "github.com/sirupsen/logrus" + + "github.com/openshift/wisdom/pkg/api" +) + +func YamlLinter(response *api.ModelResponse) (*api.ModelResponse, error) { + if err := isValidYAML(response.Output); err != nil { + return response, fmt.Errorf("response output is not valid YAML: %s", err) + } + return response, nil +} + +func isValidYAML(yamlString string) error { + var data interface{} + + log.Debugf("Validating YAML:\n%s\n", yamlString) + err := yaml.Unmarshal([]byte(yamlString), &data) + + return err +} diff --git a/pkg/model/invoker.go b/pkg/model/invoker.go index 17f5374..9a258dd 100644 --- a/pkg/model/invoker.go +++ b/pkg/model/invoker.go @@ -8,9 +8,16 @@ import ( func InvokeModel(input api.ModelInput, model api.Model, filter filters.Filter) (*api.ModelResponse, error) { response, err := model.Invoke(input) + if response == nil { + response = &api.ModelResponse{} + } if err != nil { + response.ErrorMessage = err.Error() return response, err } output, err := filter.FilterResponse(response) + if err != nil { + response.ErrorMessage = err.Error() + } return output, err } diff --git a/pkg/model/openai/model_openai.go b/pkg/model/openai/model_openai.go index 8d2e779..8324af3 100644 --- a/pkg/model/openai/model_openai.go +++ b/pkg/model/openai/model_openai.go @@ -111,5 +111,6 @@ func (m *OpenAIModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) { response := api.ModelResponse{} response.Input = input.Prompt response.Output = apiResp.Choices[0].Message.Content + response.RawOutput = apiResp.Choices[0].Message.Content return &response, err } diff --git a/pkg/server/inference_handler.go b/pkg/server/inference_handler.go index 83c62cd..a7c0f8f 100644 --- a/pkg/server/inference_handler.go +++ b/pkg/server/inference_handler.go @@ -36,21 +36,22 @@ func (h *Handler) InferHandler(w http.ResponseWriter, r *http.Request) { return } response, err := model.InvokeModel(payload, m, h.Filter) - if err != nil { - w.WriteHeader(http.StatusExpectationFailed) - w.Write([]byte(err.Error())) - return - } buf := bytes.Buffer{} - err = json.NewEncoder(&buf).Encode(response) - if err != nil { - w.WriteHeader(http.StatusExpectationFailed) - w.Write([]byte(err.Error())) - return + if response != nil { + err = json.NewEncoder(&buf).Encode(response) + if err != nil { + w.WriteHeader(http.StatusExpectationFailed) + w.Write([]byte(err.Error())) + return + } } w.Header().Set("Content-Type", "text/json") - w.WriteHeader(http.StatusOK) + if err != nil || (response != nil && response.ErrorMessage != "") { + w.WriteHeader(http.StatusExpectationFailed) + } else { + w.WriteHeader(http.StatusOK) + } w.Write(buf.Bytes()) }