Skip to content

Commit

Permalink
feat(prophet): pricepred model handler (#1168)
Browse files Browse the repository at this point in the history
This is a vary bare future handler for the price prediction model we
have today, and that essentially is an experiment to showcase the tech
architecture (in fact, the date for the prediction in the request is
hardcoded).

We lack some essential features such as a way of configuring the
prediction endpoint (again, hardcoded for now).

This is also our first tentative of accepting a bytes array coming from
solidity, mapping it into a json, and then doing the other way around
for the response.
  • Loading branch information
Pitasi authored Jan 21, 2025
1 parent 1c62973 commit ab2c337
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 0 deletions.
168 changes: 168 additions & 0 deletions prophet/handlers/pricepred/pricepred.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Package pricepred provides a handler for the price prediction AI model.
package pricepred

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"math/big"
"net/http"
"time"

"github.com/ethereum/go-ethereum/accounts/abi"
)

var URL = "https://prediction.devnet.wardenprotocol.org/task/inference/solve"

var client = http.Client{
Timeout: 3 * time.Second,
}

// PricePredictorSolidity is a handler for the price prediction AI model,
// wrapping input and output in Solidity ABI types.
type PricePredictorSolidity struct{}

func (s PricePredictorSolidity) Execute(ctx context.Context, input []byte) ([]byte, error) {
tokens, err := decodeInput(input)
if err != nil {
return nil, err
}

req := Request{
SolverInput: RequestSolverInput{
Tokens: tokens,
TargetDate: "2022-01-01",
AdversaryMode: false,
},
FalsePositiveRate: 0.01,
}

res, err := Predict(ctx, req)
if err != nil {
return nil, err
}

encodedRes, err := encodeOutput(req, res)
if err != nil {
return nil, err
}

return encodedRes, nil
}

func decodeInput(input []byte) ([]string, error) {
typ, err := abi.NewType("string[]", "string[]", nil)
if err != nil {
return nil, err
}
args := abi.Arguments{
{Type: typ},
}

unpackArgs, err := args.Unpack(input)
if err != nil {
return nil, err
}

tokens, ok := unpackArgs[0].([]string)
if !ok {
return nil, fmt.Errorf("failed to unpack input")
}

return tokens, nil
}

func encodeOutput(req Request, res Response) ([]byte, error) {
typ, err := abi.NewType("uint256[]", "", nil)
if err != nil {
log.Fatal(err)
}
args := abi.Arguments{
{
Type: typ,
Name: "SolverOutput",
Indexed: false,
},
}

tokenPreds := make([]*big.Int, len(req.SolverInput.Tokens))
for i, v := range req.SolverInput.Tokens {
decimals := big.NewFloat(1e16)
pred := big.NewFloat(res.SolverOutput[v])
tokenPreds[i], _ = big.NewFloat(0).Mul(pred, decimals).Int(nil)
}

enc, err := args.Pack(tokenPreds)
if err != nil {
return nil, err
}

return enc, nil
}

func (s PricePredictorSolidity) Verify(ctx context.Context, input []byte, output []byte) error {
// todo: verify output
return nil
}

type RequestSolverInput struct {
Tokens []string `json:"tokens"`
TargetDate string `json:"target_date"`
AdversaryMode bool `json:"adversaryMode"`
}

type Request struct {
SolverInput RequestSolverInput `json:"solverInput"`
FalsePositiveRate float64 `json:"falsePositiveRate"`
}

type Response struct {
SolverOutput map[string]float64 `json:"solverOutput"`
SolverReceipt struct {
BloomFilter []byte `json:"bloomFilter"`
CountItems int `json:"countItems"`
} `json:"solverReceipt"`
}

func Predict(ctx context.Context, req Request) (Response, error) {
reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()

reqBody, err := json.Marshal(req)
if err != nil {
return Response{}, err
}

httpReq, err := http.NewRequestWithContext(reqCtx, "POST", URL, bytes.NewReader(reqBody))
if err != nil {
return Response{}, err
}

httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")

res, err := client.Do(httpReq)
if err != nil {
return Response{}, err
}
defer res.Body.Close()
response, err := io.ReadAll(res.Body)
if err != nil {
return Response{}, err
}

if res.StatusCode != http.StatusOK {
return Response{}, fmt.Errorf("unexpected status code: %d. Server returned error: %s", res.StatusCode, response)
}

var resResp Response
err = json.Unmarshal(response, &resResp)
if err != nil {
return Response{}, err
}

return resResp, nil
}
93 changes: 93 additions & 0 deletions prophet/handlers/pricepred/pricepred_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package pricepred

import (
"context"
"encoding/base64"
"encoding/hex"
"testing"

"github.com/stretchr/testify/require"
)

func TestDecodeInput(t *testing.T) {
cases := []struct {
name string
input string
expected []string
}{
{
name: "single element list",
input: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANFVEgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==",
expected: []string{"ETH"},
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
bz, err := base64.StdEncoding.DecodeString(c.input)
require.NoError(t, err)
actual, err := decodeInput(bz)
require.NoError(t, err)
require.Equal(t, c.expected, actual)
})
}
}

func TestEncodeOutput(t *testing.T) {
cases := []struct {
name string
request []string
expected string
}{
{
name: "single element list",
request: []string{"bitcoin", "tether", "uniswap"},
expected: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000018d6e7f01865e90000000000000000000000000000000000000000000000000000002387ffdba3cf9c000000000000000000000000000000000000000000000000026b6b97cf726620",
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
req := Request{
SolverInput: RequestSolverInput{
Tokens: c.request,
TargetDate: "2022-01-01",
AdversaryMode: false,
},
FalsePositiveRate: 0.01,
}
res := Response{
SolverOutput: map[string]float64{
"uniswap": 17.435131034851075,
"tether": 1.000115715622902,
"bitcoin": 45820.74676003456,
},
}

actual, err := encodeOutput(req, res)
require.NoError(t, err)

require.Equal(t, c.expected, hex.EncodeToString(actual))
})
}
}

func TestPredict(t *testing.T) {
//t.Skip("this test relies on external HTTP call")

ctx := context.Background()
res, err := Predict(ctx, Request{
SolverInput: RequestSolverInput{
Tokens: []string{"bitcoin", "tether", "uniswap"},
TargetDate: "2022-01-01",
AdversaryMode: false,
},
FalsePositiveRate: 0.01,
})
require.NoError(t, err)
require.Len(t, res.SolverOutput, 3)

for token, pred := range res.SolverOutput {
require.NotZero(t, pred, "prediction for %t is zero", token)
}
}
3 changes: 3 additions & 0 deletions warden/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ import (

"github.com/warden-protocol/wardenprotocol/prophet"
"github.com/warden-protocol/wardenprotocol/prophet/handlers/echo"
"github.com/warden-protocol/wardenprotocol/prophet/handlers/pricepred"
)

const (
Expand Down Expand Up @@ -246,6 +247,8 @@ func New(
baseAppOptions ...func(*baseapp.BaseApp),
) (*App, error) {
prophet.Register("echo", echo.Handler{})
prophet.Register("pricepred", pricepred.PricePredictorSolidity{})

prophetP, err := prophet.New()
if err != nil {
panic(fmt.Errorf("failed to create prophet: %w", err))
Expand Down

0 comments on commit ab2c337

Please sign in to comment.