-
Notifications
You must be signed in to change notification settings - Fork 113
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(prophet): pricepred model handler (#1168)
This is a vary bare future handler for the price prediction model we have today, and that essentially is an experiment to showcase the tech architecture (in fact, the date for the prediction in the request is hardcoded). We lack some essential features such as a way of configuring the prediction endpoint (again, hardcoded for now). This is also our first tentative of accepting a bytes array coming from solidity, mapping it into a json, and then doing the other way around for the response.
- Loading branch information
Showing
3 changed files
with
264 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
// Package pricepred provides a handler for the price prediction AI model. | ||
package pricepred | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"log" | ||
"math/big" | ||
"net/http" | ||
"time" | ||
|
||
"github.com/ethereum/go-ethereum/accounts/abi" | ||
) | ||
|
||
var URL = "https://prediction.devnet.wardenprotocol.org/task/inference/solve" | ||
|
||
var client = http.Client{ | ||
Timeout: 3 * time.Second, | ||
} | ||
|
||
// PricePredictorSolidity is a handler for the price prediction AI model, | ||
// wrapping input and output in Solidity ABI types. | ||
type PricePredictorSolidity struct{} | ||
|
||
func (s PricePredictorSolidity) Execute(ctx context.Context, input []byte) ([]byte, error) { | ||
tokens, err := decodeInput(input) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
req := Request{ | ||
SolverInput: RequestSolverInput{ | ||
Tokens: tokens, | ||
TargetDate: "2022-01-01", | ||
AdversaryMode: false, | ||
}, | ||
FalsePositiveRate: 0.01, | ||
} | ||
|
||
res, err := Predict(ctx, req) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
encodedRes, err := encodeOutput(req, res) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return encodedRes, nil | ||
} | ||
|
||
func decodeInput(input []byte) ([]string, error) { | ||
typ, err := abi.NewType("string[]", "string[]", nil) | ||
if err != nil { | ||
return nil, err | ||
} | ||
args := abi.Arguments{ | ||
{Type: typ}, | ||
} | ||
|
||
unpackArgs, err := args.Unpack(input) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
tokens, ok := unpackArgs[0].([]string) | ||
if !ok { | ||
return nil, fmt.Errorf("failed to unpack input") | ||
} | ||
|
||
return tokens, nil | ||
} | ||
|
||
func encodeOutput(req Request, res Response) ([]byte, error) { | ||
typ, err := abi.NewType("uint256[]", "", nil) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
args := abi.Arguments{ | ||
{ | ||
Type: typ, | ||
Name: "SolverOutput", | ||
Indexed: false, | ||
}, | ||
} | ||
|
||
tokenPreds := make([]*big.Int, len(req.SolverInput.Tokens)) | ||
for i, v := range req.SolverInput.Tokens { | ||
decimals := big.NewFloat(1e16) | ||
pred := big.NewFloat(res.SolverOutput[v]) | ||
tokenPreds[i], _ = big.NewFloat(0).Mul(pred, decimals).Int(nil) | ||
} | ||
|
||
enc, err := args.Pack(tokenPreds) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return enc, nil | ||
} | ||
|
||
func (s PricePredictorSolidity) Verify(ctx context.Context, input []byte, output []byte) error { | ||
// todo: verify output | ||
return nil | ||
} | ||
|
||
type RequestSolverInput struct { | ||
Tokens []string `json:"tokens"` | ||
TargetDate string `json:"target_date"` | ||
AdversaryMode bool `json:"adversaryMode"` | ||
} | ||
|
||
type Request struct { | ||
SolverInput RequestSolverInput `json:"solverInput"` | ||
FalsePositiveRate float64 `json:"falsePositiveRate"` | ||
} | ||
|
||
type Response struct { | ||
SolverOutput map[string]float64 `json:"solverOutput"` | ||
SolverReceipt struct { | ||
BloomFilter []byte `json:"bloomFilter"` | ||
CountItems int `json:"countItems"` | ||
} `json:"solverReceipt"` | ||
} | ||
|
||
func Predict(ctx context.Context, req Request) (Response, error) { | ||
reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second) | ||
defer cancel() | ||
|
||
reqBody, err := json.Marshal(req) | ||
if err != nil { | ||
return Response{}, err | ||
} | ||
|
||
httpReq, err := http.NewRequestWithContext(reqCtx, "POST", URL, bytes.NewReader(reqBody)) | ||
if err != nil { | ||
return Response{}, err | ||
} | ||
|
||
httpReq.Header.Set("Accept", "application/json") | ||
httpReq.Header.Set("Content-Type", "application/json") | ||
|
||
res, err := client.Do(httpReq) | ||
if err != nil { | ||
return Response{}, err | ||
} | ||
defer res.Body.Close() | ||
response, err := io.ReadAll(res.Body) | ||
if err != nil { | ||
return Response{}, err | ||
} | ||
|
||
if res.StatusCode != http.StatusOK { | ||
return Response{}, fmt.Errorf("unexpected status code: %d. Server returned error: %s", res.StatusCode, response) | ||
} | ||
|
||
var resResp Response | ||
err = json.Unmarshal(response, &resResp) | ||
if err != nil { | ||
return Response{}, err | ||
} | ||
|
||
return resResp, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package pricepred | ||
|
||
import ( | ||
"context" | ||
"encoding/base64" | ||
"encoding/hex" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestDecodeInput(t *testing.T) { | ||
cases := []struct { | ||
name string | ||
input string | ||
expected []string | ||
}{ | ||
{ | ||
name: "single element list", | ||
input: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANFVEgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==", | ||
expected: []string{"ETH"}, | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
t.Run(c.name, func(t *testing.T) { | ||
bz, err := base64.StdEncoding.DecodeString(c.input) | ||
require.NoError(t, err) | ||
actual, err := decodeInput(bz) | ||
require.NoError(t, err) | ||
require.Equal(t, c.expected, actual) | ||
}) | ||
} | ||
} | ||
|
||
func TestEncodeOutput(t *testing.T) { | ||
cases := []struct { | ||
name string | ||
request []string | ||
expected string | ||
}{ | ||
{ | ||
name: "single element list", | ||
request: []string{"bitcoin", "tether", "uniswap"}, | ||
expected: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000018d6e7f01865e90000000000000000000000000000000000000000000000000000002387ffdba3cf9c000000000000000000000000000000000000000000000000026b6b97cf726620", | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
t.Run(c.name, func(t *testing.T) { | ||
req := Request{ | ||
SolverInput: RequestSolverInput{ | ||
Tokens: c.request, | ||
TargetDate: "2022-01-01", | ||
AdversaryMode: false, | ||
}, | ||
FalsePositiveRate: 0.01, | ||
} | ||
res := Response{ | ||
SolverOutput: map[string]float64{ | ||
"uniswap": 17.435131034851075, | ||
"tether": 1.000115715622902, | ||
"bitcoin": 45820.74676003456, | ||
}, | ||
} | ||
|
||
actual, err := encodeOutput(req, res) | ||
require.NoError(t, err) | ||
|
||
require.Equal(t, c.expected, hex.EncodeToString(actual)) | ||
}) | ||
} | ||
} | ||
|
||
func TestPredict(t *testing.T) { | ||
//t.Skip("this test relies on external HTTP call") | ||
|
||
ctx := context.Background() | ||
res, err := Predict(ctx, Request{ | ||
SolverInput: RequestSolverInput{ | ||
Tokens: []string{"bitcoin", "tether", "uniswap"}, | ||
TargetDate: "2022-01-01", | ||
AdversaryMode: false, | ||
}, | ||
FalsePositiveRate: 0.01, | ||
}) | ||
require.NoError(t, err) | ||
require.Len(t, res.SolverOutput, 3) | ||
|
||
for token, pred := range res.SolverOutput { | ||
require.NotZero(t, pred, "prediction for %t is zero", token) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters