Skip to content

Commit

Permalink
Basic markdown stripping and yaml linting
Browse files Browse the repository at this point in the history
  • Loading branch information
bparees committed Aug 16, 2023
1 parent 0d6c34c commit 253abe5
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 28 deletions.
2 changes: 1 addition & 1 deletion config/sample.cfg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
userId: $USERID
apiKey: $APIKEY
- provider: openai
modelId: chatgpt-3.5-turbo
modelId: gpt-3.5-turbo
url: https://api.openai.com
apiKey: $APIKEY
7 changes: 4 additions & 3 deletions pkg/api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type Model interface {
type ModelInput struct {
UserId string `json:"userid"`
APIKey string `json:"apikey"`
ModelId string `json:"modelid"`
ModelId string `json:"modelId"`
Provider string `json:"provider"`
Prompt string `json:"prompt"`
Context string `json:"context"`
Expand All @@ -18,10 +18,11 @@ type ModelInput struct {
type ModelResponse struct {
Input string `json:"input_tokens"`
Status string `json:"status"`
RequestID string `json:"request_id"`
ConversationID string `json:"conversation_id"`
RequestID string `json:"requestId"`
ConversationID string `json:"conversationId"`
Output string `json:"output"`
RawOutput string `json:"raw_output"`
ErrorMessage string `json:"error"`
}

// FeedbackPayload represents the payload for the feedback endpoint.
Expand Down
16 changes: 3 additions & 13 deletions pkg/filters/filters.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package filters

import (
"gopkg.in/yaml.v2"

"github.com/openshift/wisdom/pkg/api"
"github.com/openshift/wisdom/pkg/filters/markdown"
"github.com/openshift/wisdom/pkg/filters/yaml"
)

type Filter struct {
Expand All @@ -15,7 +15,7 @@ type ResponseFilter func(response *api.ModelResponse) (*api.ModelResponse, error

func NewFilter() Filter {
filter := Filter{}
//filter.responseFilterChain = append(filter.responseFilterChain, YamlLinter)
filter.responseFilterChain = append(filter.responseFilterChain, markdown.MarkdownStripper, yaml.YamlLinter)
return filter
}

Expand All @@ -42,13 +42,3 @@ func (f *Filter) FilterResponse(response *api.ModelResponse) (*api.ModelResponse
}
return output, err
}

func YamlLinter(response *api.ModelResponse) (*api.ModelResponse, error) {
return response, isValidYAML(response.Output)
}

func isValidYAML(yamlString string) error {
var data interface{}
err := yaml.Unmarshal([]byte(yamlString), &data)
return err
}
41 changes: 41 additions & 0 deletions pkg/filters/markdown/markdown.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package markdown

import (
"fmt"
"regexp"

//gomarkdown "github.com/gomarkdown/markdown"
log "github.com/sirupsen/logrus"

"github.com/openshift/wisdom/pkg/api"
)

var (
markdownRegex = regexp.MustCompile("(?s)`{3}.*?\n(.*)`{3}")
)

func MarkdownStripper(response *api.ModelResponse) (*api.ModelResponse, error) {

if response.Output == "" {
return response, fmt.Errorf("response output is empty")
}
log.Debugf("Stripping markdown from response:\n %s\n", response.Output)

//response.Output = markdownRegex.ReplaceAllString(response.Output, "")
matches := markdownRegex.FindStringSubmatch(response.Output)
response.Output = matches[1]
/*
node := gomarkdown.Parse([]byte(response.Output), nil)
fmt.Printf("%#v\n", node.GetChildren()[0].GetChildren()[0])
//fmt.Printf("%#v\n", node.GetChildren()[0].GetChildren()[1])
//fmt.Printf("%s\n", node.GetChildren()[0].GetChildren()[1].AsLeaf().Literal)
response.Output = string(node.GetChildren()[0].GetChildren()[0].AsLeaf().Literal)
*/

//response.Output = stripmd.Strip(response.Output)

log.Debugf("Stripped markdown from response:\n %s\n", response.Output)
return response, nil
}
27 changes: 27 additions & 0 deletions pkg/filters/yaml/yaml.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package yaml

import (
"fmt"

"gopkg.in/yaml.v2"

log "github.com/sirupsen/logrus"

"github.com/openshift/wisdom/pkg/api"
)

func YamlLinter(response *api.ModelResponse) (*api.ModelResponse, error) {
if err := isValidYAML(response.Output); err != nil {
return response, fmt.Errorf("response output is not valid YAML: %s", err)
}
return response, nil
}

func isValidYAML(yamlString string) error {
var data interface{}

log.Debugf("Validating YAML:\n%s\n", yamlString)
err := yaml.Unmarshal([]byte(yamlString), &data)

return err
}
7 changes: 7 additions & 0 deletions pkg/model/invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,16 @@ import (
func InvokeModel(input api.ModelInput, model api.Model, filter filters.Filter) (*api.ModelResponse, error) {

response, err := model.Invoke(input)
if response == nil {
response = &api.ModelResponse{}
}
if err != nil {
response.ErrorMessage = err.Error()
return response, err
}
output, err := filter.FilterResponse(response)
if err != nil {
response.ErrorMessage = err.Error()
}
return output, err
}
1 change: 1 addition & 0 deletions pkg/model/openai/model_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,6 @@ func (m *OpenAIModel) Invoke(input api.ModelInput) (*api.ModelResponse, error) {
response := api.ModelResponse{}
response.Input = input.Prompt
response.Output = apiResp.Choices[0].Message.Content
response.RawOutput = apiResp.Choices[0].Message.Content
return &response, err
}
23 changes: 12 additions & 11 deletions pkg/server/inference_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,22 @@ func (h *Handler) InferHandler(w http.ResponseWriter, r *http.Request) {
return
}
response, err := model.InvokeModel(payload, m, h.Filter)
if err != nil {
w.WriteHeader(http.StatusExpectationFailed)
w.Write([]byte(err.Error()))
return
}

buf := bytes.Buffer{}
err = json.NewEncoder(&buf).Encode(response)
if err != nil {
w.WriteHeader(http.StatusExpectationFailed)
w.Write([]byte(err.Error()))
return
if response != nil {
err = json.NewEncoder(&buf).Encode(response)
if err != nil {
w.WriteHeader(http.StatusExpectationFailed)
w.Write([]byte(err.Error()))
return
}
}

w.Header().Set("Content-Type", "text/json")
w.WriteHeader(http.StatusOK)
if err != nil || (response != nil && response.ErrorMessage != "") {
w.WriteHeader(http.StatusExpectationFailed)
} else {
w.WriteHeader(http.StatusOK)
}
w.Write(buf.Bytes())
}

0 comments on commit 253abe5

Please sign in to comment.