From 4792f3837113829d4124bc80123c95dd22b3ba0e Mon Sep 17 00:00:00 2001 From: Antonio Pitasi Date: Tue, 14 Jan 2025 18:36:58 +0100 Subject: [PATCH] feat(prophet): pricepred model handler This is a vary bare future handler for the price prediction model we have today, and that essentially is an experiment to showcase the tech architecture (in fact, the date for the prediction in the request is hardcoded). We lack some essential features such as a way of configuring the prediction endpoint (again, hardcoded for now). This is also our first tentative of accepting a bytes array coming from solidity, mapping it into a json, and then doing the other way around for the response. --- prophet/handlers/pricepred/pricepred.go | 168 +++++++++++++++++++ prophet/handlers/pricepred/pricepred_test.go | 93 ++++++++++ warden/app/app.go | 3 + 3 files changed, 264 insertions(+) create mode 100644 prophet/handlers/pricepred/pricepred.go create mode 100644 prophet/handlers/pricepred/pricepred_test.go diff --git a/prophet/handlers/pricepred/pricepred.go b/prophet/handlers/pricepred/pricepred.go new file mode 100644 index 000000000..f8a777b78 --- /dev/null +++ b/prophet/handlers/pricepred/pricepred.go @@ -0,0 +1,168 @@ +// Package pricepred provides a handler for the price prediction AI model. +package pricepred + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log" + "math/big" + "net/http" + "time" + + "github.com/ethereum/go-ethereum/accounts/abi" +) + +var URL = "https://prediction.devnet.wardenprotocol.org/task/inference/solve" + +var client = http.Client{ + Timeout: 3 * time.Second, +} + +// PricePredictorSolidity is a handler for the price prediction AI model, +// wrapping input and output in Solidity ABI types. +type PricePredictorSolidity struct{} + +func (s PricePredictorSolidity) Execute(ctx context.Context, input []byte) ([]byte, error) { + tokens, err := decodeInput(input) + if err != nil { + return nil, err + } + + req := Request{ + SolverInput: RequestSolverInput{ + Tokens: tokens, + TargetDate: "2022-01-01", + AdversaryMode: false, + }, + FalsePositiveRate: 0.01, + } + + res, err := Predict(ctx, req) + if err != nil { + return nil, err + } + + encodedRes, err := encodeOutput(req, res) + if err != nil { + return nil, err + } + + return encodedRes, nil +} + +func decodeInput(input []byte) ([]string, error) { + typ, err := abi.NewType("string[]", "string[]", nil) + if err != nil { + return nil, err + } + args := abi.Arguments{ + {Type: typ}, + } + + unpackArgs, err := args.Unpack(input) + if err != nil { + return nil, err + } + + tokens, ok := unpackArgs[0].([]string) + if !ok { + return nil, fmt.Errorf("failed to unpack input") + } + + return tokens, nil +} + +func encodeOutput(req Request, res Response) ([]byte, error) { + typ, err := abi.NewType("uint256[]", "", nil) + if err != nil { + log.Fatal(err) + } + args := abi.Arguments{ + { + Type: typ, + Name: "SolverOutput", + Indexed: false, + }, + } + + tokenPreds := make([]*big.Int, len(req.SolverInput.Tokens)) + for i, v := range req.SolverInput.Tokens { + decimals := big.NewFloat(1e16) + pred := big.NewFloat(res.SolverOutput[v]) + tokenPreds[i], _ = big.NewFloat(0).Mul(pred, decimals).Int(nil) + } + + enc, err := args.Pack(tokenPreds) + if err != nil { + return nil, err + } + + return enc, nil +} + +func (s PricePredictorSolidity) Verify(ctx context.Context, input []byte, output []byte) error { + // todo: verify output + return nil +} + +type RequestSolverInput struct { + Tokens []string `json:"tokens"` + TargetDate string `json:"target_date"` + AdversaryMode bool `json:"adversaryMode"` +} + +type Request struct { + SolverInput RequestSolverInput `json:"solverInput"` + FalsePositiveRate float64 `json:"falsePositiveRate"` +} + +type Response struct { + SolverOutput map[string]float64 `json:"solverOutput"` + SolverReceipt struct { + BloomFilter []byte `json:"bloomFilter"` + CountItems int `json:"countItems"` + } `json:"solverReceipt"` +} + +func Predict(ctx context.Context, req Request) (Response, error) { + reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + reqBody, err := json.Marshal(req) + if err != nil { + return Response{}, err + } + + httpReq, err := http.NewRequestWithContext(reqCtx, "POST", URL, bytes.NewReader(reqBody)) + if err != nil { + return Response{}, err + } + + httpReq.Header.Set("Accept", "application/json") + httpReq.Header.Set("Content-Type", "application/json") + + res, err := client.Do(httpReq) + if err != nil { + return Response{}, err + } + defer res.Body.Close() + response, err := io.ReadAll(res.Body) + if err != nil { + return Response{}, err + } + + if res.StatusCode != http.StatusOK { + return Response{}, fmt.Errorf("unexpected status code: %d. Server returned error: %s", res.StatusCode, response) + } + + var resResp Response + err = json.Unmarshal(response, &resResp) + if err != nil { + return Response{}, err + } + + return resResp, nil +} diff --git a/prophet/handlers/pricepred/pricepred_test.go b/prophet/handlers/pricepred/pricepred_test.go new file mode 100644 index 000000000..d7861762f --- /dev/null +++ b/prophet/handlers/pricepred/pricepred_test.go @@ -0,0 +1,93 @@ +package pricepred + +import ( + "context" + "encoding/base64" + "encoding/hex" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDecodeInput(t *testing.T) { + cases := []struct { + name string + input string + expected []string + }{ + { + name: "single element list", + input: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANFVEgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==", + expected: []string{"ETH"}, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + bz, err := base64.StdEncoding.DecodeString(c.input) + require.NoError(t, err) + actual, err := decodeInput(bz) + require.NoError(t, err) + require.Equal(t, c.expected, actual) + }) + } +} + +func TestEncodeOutput(t *testing.T) { + cases := []struct { + name string + request []string + expected string + }{ + { + name: "single element list", + request: []string{"bitcoin", "tether", "uniswap"}, + expected: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000018d6e7f01865e90000000000000000000000000000000000000000000000000000002387ffdba3cf9c000000000000000000000000000000000000000000000000026b6b97cf726620", + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + req := Request{ + SolverInput: RequestSolverInput{ + Tokens: c.request, + TargetDate: "2022-01-01", + AdversaryMode: false, + }, + FalsePositiveRate: 0.01, + } + res := Response{ + SolverOutput: map[string]float64{ + "uniswap": 17.435131034851075, + "tether": 1.000115715622902, + "bitcoin": 45820.74676003456, + }, + } + + actual, err := encodeOutput(req, res) + require.NoError(t, err) + + require.Equal(t, c.expected, hex.EncodeToString(actual)) + }) + } +} + +func TestPredict(t *testing.T) { + //t.Skip("this test relies on external HTTP call") + + ctx := context.Background() + res, err := Predict(ctx, Request{ + SolverInput: RequestSolverInput{ + Tokens: []string{"bitcoin", "tether", "uniswap"}, + TargetDate: "2022-01-01", + AdversaryMode: false, + }, + FalsePositiveRate: 0.01, + }) + require.NoError(t, err) + require.Len(t, res.SolverOutput, 3) + + for token, pred := range res.SolverOutput { + require.NotZero(t, pred, "prediction for %t is zero", token) + } +} diff --git a/warden/app/app.go b/warden/app/app.go index e72b314cd..e909340ab 100644 --- a/warden/app/app.go +++ b/warden/app/app.go @@ -89,6 +89,7 @@ import ( "github.com/warden-protocol/wardenprotocol/prophet" "github.com/warden-protocol/wardenprotocol/prophet/handlers/echo" + "github.com/warden-protocol/wardenprotocol/prophet/handlers/pricepred" ) const ( @@ -246,6 +247,8 @@ func New( baseAppOptions ...func(*baseapp.BaseApp), ) (*App, error) { prophet.Register("echo", echo.Handler{}) + prophet.Register("pricepred", pricepred.PricePredictorSolidity{}) + prophetP, err := prophet.New() if err != nil { panic(fmt.Errorf("failed to create prophet: %w", err))