Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(prophet): pricepred model handler #1168

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions prophet/handlers/pricepred/pricepred.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Package pricepred provides a handler for the price prediction AI model.
package pricepred

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"math/big"
"net/http"
"time"

"github.com/ethereum/go-ethereum/accounts/abi"
)

var URL = "https://prediction.devnet.wardenprotocol.org/task/inference/solve"

var client = http.Client{
Timeout: 3 * time.Second,
}
Comment on lines +18 to +22
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Make URL configurable and review timeout duration.

  1. The prediction service URL is hardcoded to a devnet endpoint, which limits flexibility across different environments.
  2. The 3-second timeout might be too short for some network conditions.

Consider:

  1. Making the URL configurable through environment variables or configuration
  2. Increasing the timeout or making it configurable
  3. Adding retries for transient failures
-var URL = "https://prediction.devnet.wardenprotocol.org/task/inference/solve"
+var (
+    URL = getEnvOrDefault(
+        "PREDICTION_SERVICE_URL",
+        "https://prediction.devnet.wardenprotocol.org/task/inference/solve",
+    )
+    timeout = getEnvDurationOrDefault("PREDICTION_SERVICE_TIMEOUT", 10*time.Second)
+)
 
 var client = http.Client{
-    Timeout: 3 * time.Second,
+    Timeout: timeout,
 }
+
+func getEnvOrDefault(key, defaultValue string) string {
+    if value := os.Getenv(key); value != "" {
+        return value
+    }
+    return defaultValue
+}
+
+func getEnvDurationOrDefault(key string, defaultValue time.Duration) time.Duration {
+    if value := os.Getenv(key); value != "" {
+        if duration, err := time.ParseDuration(value); err == nil {
+            return duration
+        }
+    }
+    return defaultValue
+}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
var URL = "https://prediction.devnet.wardenprotocol.org/task/inference/solve"
var client = http.Client{
Timeout: 3 * time.Second,
}
var (
URL = getEnvOrDefault(
"PREDICTION_SERVICE_URL",
"https://prediction.devnet.wardenprotocol.org/task/inference/solve",
)
timeout = getEnvDurationOrDefault("PREDICTION_SERVICE_TIMEOUT", 10*time.Second)
)
var client = http.Client{
Timeout: timeout,
}
func getEnvOrDefault(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvDurationOrDefault(key string, defaultValue time.Duration) time.Duration {
if value := os.Getenv(key); value != "" {
if duration, err := time.ParseDuration(value); err == nil {
return duration
}
}
return defaultValue
}


// PricePredictorSolidity is a handler for the price prediction AI model,
// wrapping input and output in Solidity ABI types.
type PricePredictorSolidity struct{}

func (s PricePredictorSolidity) Execute(ctx context.Context, input []byte) ([]byte, error) {
tokens, err := decodeInput(input)
if err != nil {
return nil, err
}

req := Request{
SolverInput: RequestSolverInput{
Tokens: tokens,
TargetDate: "2022-01-01",
AdversaryMode: false,
},
FalsePositiveRate: 0.01,
}

res, err := Predict(ctx, req)
if err != nil {
return nil, err
}

encodedRes, err := encodeOutput(req, res)
if err != nil {
return nil, err
}

return encodedRes, nil
}
Comment on lines +28 to +54
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove hardcoded date and add input validation.

The Execute method has several issues:

  1. Uses a hardcoded date from 2022
  2. Lacks input validation
  3. Missing error handling for empty token list

Consider:

 func (s PricePredictorSolidity) Execute(ctx context.Context, input []byte) ([]byte, error) {
     tokens, err := decodeInput(input)
     if err != nil {
         return nil, err
     }
 
+    if len(tokens) == 0 {
+        return nil, fmt.Errorf("empty token list")
+    }
+
+    futureDate := time.Now().AddDate(1, 0, 0).Format("2006-01-02")
+
     req := Request{
         SolverInput: RequestSolverInput{
             Tokens:        tokens,
-            TargetDate:    "2022-01-01",
+            TargetDate:    futureDate,
             AdversaryMode: false,
         },
         FalsePositiveRate: 0.01,
     }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func (s PricePredictorSolidity) Execute(ctx context.Context, input []byte) ([]byte, error) {
tokens, err := decodeInput(input)
if err != nil {
return nil, err
}
req := Request{
SolverInput: RequestSolverInput{
Tokens: tokens,
TargetDate: "2022-01-01",
AdversaryMode: false,
},
FalsePositiveRate: 0.01,
}
res, err := Predict(ctx, req)
if err != nil {
return nil, err
}
encodedRes, err := encodeOutput(req, res)
if err != nil {
return nil, err
}
return encodedRes, nil
}
func (s PricePredictorSolidity) Execute(ctx context.Context, input []byte) ([]byte, error) {
tokens, err := decodeInput(input)
if err != nil {
return nil, err
}
if len(tokens) == 0 {
return nil, fmt.Errorf("empty token list")
}
futureDate := time.Now().AddDate(1, 0, 0).Format("2006-01-02")
req := Request{
SolverInput: RequestSolverInput{
Tokens: tokens,
TargetDate: futureDate,
AdversaryMode: false,
},
FalsePositiveRate: 0.01,
}
res, err := Predict(ctx, req)
if err != nil {
return nil, err
}
encodedRes, err := encodeOutput(req, res)
if err != nil {
return nil, err
}
return encodedRes, nil
}


func decodeInput(input []byte) ([]string, error) {
typ, err := abi.NewType("string[]", "string[]", nil)
if err != nil {
return nil, err
}
args := abi.Arguments{
{Type: typ},
}

unpackArgs, err := args.Unpack(input)
if err != nil {
return nil, err
}

tokens, ok := unpackArgs[0].([]string)
if !ok {
return nil, fmt.Errorf("failed to unpack input")
}

return tokens, nil
}
Comment on lines +56 to +76
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve error handling and add token validation.

The function needs better error handling and input validation:

  1. Type assertion error message is not descriptive enough
  2. Missing validation for token strings
  3. No maximum limit on number of tokens

Consider:

 func decodeInput(input []byte) ([]string, error) {
+    const maxTokens = 100 // Prevent DoS attacks
+
     typ, err := abi.NewType("string[]", "string[]", nil)
     if err != nil {
         return nil, err
     }
     args := abi.Arguments{
         {Type: typ},
     }
 
     unpackArgs, err := args.Unpack(input)
     if err != nil {
         return nil, err
     }
 
     tokens, ok := unpackArgs[0].([]string)
     if !ok {
-        return nil, fmt.Errorf("failed to unpack input")
+        return nil, fmt.Errorf("expected []string, got %T", unpackArgs[0])
+    }
+
+    if len(tokens) > maxTokens {
+        return nil, fmt.Errorf("too many tokens: %d (max: %d)", len(tokens), maxTokens)
+    }
+
+    for i, token := range tokens {
+        if !isValidToken(token) {
+            return nil, fmt.Errorf("invalid token at index %d: %s", i, token)
+        
     }
 
     return tokens, nil
 }
+
+func isValidToken(token string) bool {
+    // Add validation logic for token strings
+    return len(token) > 0 && len(token) <= 10 && !strings.ContainsAny(token, " \t\n")
+}

Committable suggestion skipped: line range outside the PR's diff.


func encodeOutput(req Request, res Response) ([]byte, error) {
typ, err := abi.NewType("uint256[]", "", nil)
if err != nil {
log.Fatal(err)
}
args := abi.Arguments{
{
Type: typ,
Name: "SolverOutput",
Indexed: false,
},
}

tokenPreds := make([]*big.Int, len(req.SolverInput.Tokens))
for i, v := range req.SolverInput.Tokens {
decimals := big.NewFloat(1e16)
pred := big.NewFloat(res.SolverOutput[v])
tokenPreds[i], _ = big.NewFloat(0).Mul(pred, decimals).Int(nil)
}

enc, err := args.Pack(tokenPreds)
if err != nil {
return nil, err
}

return enc, nil
}
Comment on lines +78 to +104
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix error handling and make decimal places configurable.

Critical issues in the function:

  1. Uses log.Fatal which is inappropriate for a library function
  2. Hardcoded decimal places (1e16)
  3. Missing validation of response data
  4. No error handling for big.Float operations

Consider:

+const (
+    defaultDecimals = 16
+    maxDecimals     = 18
+)
+
 func encodeOutput(req Request, res Response) ([]byte, error) {
     typ, err := abi.NewType("uint256[]", "", nil)
     if err != nil {
-        log.Fatal(err)
+        return nil, fmt.Errorf("failed to create ABI type: %w", err)
     }
     args := abi.Arguments{
         {
             Type:    typ,
             Name:    "SolverOutput",
             Indexed: false,
         },
     }
 
     tokenPreds := make([]*big.Int, len(req.SolverInput.Tokens))
     for i, v := range req.SolverInput.Tokens {
-        decimals := big.NewFloat(1e16)
+        pred, ok := res.SolverOutput[v]
+        if !ok {
+            return nil, fmt.Errorf("missing prediction for token: %s", v)
+        }
+        if pred < 0 {
+            return nil, fmt.Errorf("negative prediction for token %s: %f", v, pred)
+        }
+
+        decimals := big.NewFloat(math.Pow10(defaultDecimals))
         pred := big.NewFloat(res.SolverOutput[v])
-        tokenPreds[i], _ = big.NewFloat(0).Mul(pred, decimals).Int(nil)
+        result, accuracy := big.NewFloat(0).Mul(pred, decimals).Int(nil)
+        if accuracy != big.Exact {
+            return nil, fmt.Errorf("loss of precision for token %s", v)
+        }
+        tokenPreds[i] = result
     }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func encodeOutput(req Request, res Response) ([]byte, error) {
typ, err := abi.NewType("uint256[]", "", nil)
if err != nil {
log.Fatal(err)
}
args := abi.Arguments{
{
Type: typ,
Name: "SolverOutput",
Indexed: false,
},
}
tokenPreds := make([]*big.Int, len(req.SolverInput.Tokens))
for i, v := range req.SolverInput.Tokens {
decimals := big.NewFloat(1e16)
pred := big.NewFloat(res.SolverOutput[v])
tokenPreds[i], _ = big.NewFloat(0).Mul(pred, decimals).Int(nil)
}
enc, err := args.Pack(tokenPreds)
if err != nil {
return nil, err
}
return enc, nil
}
const (
defaultDecimals = 16
maxDecimals = 18
)
func encodeOutput(req Request, res Response) ([]byte, error) {
typ, err := abi.NewType("uint256[]", "", nil)
if err != nil {
return nil, fmt.Errorf("failed to create ABI type: %w", err)
}
args := abi.Arguments{
{
Type: typ,
Name: "SolverOutput",
Indexed: false,
},
}
tokenPreds := make([]*big.Int, len(req.SolverInput.Tokens))
for i, v := range req.SolverInput.Tokens {
pred, ok := res.SolverOutput[v]
if !ok {
return nil, fmt.Errorf("missing prediction for token: %s", v)
}
if pred < 0 {
return nil, fmt.Errorf("negative prediction for token %s: %f", v, pred)
}
decimals := big.NewFloat(math.Pow10(defaultDecimals))
pred := big.NewFloat(res.SolverOutput[v])
result, accuracy := big.NewFloat(0).Mul(pred, decimals).Int(nil)
if accuracy != big.Exact {
return nil, fmt.Errorf("loss of precision for token %s", v)
}
tokenPreds[i] = result
}
enc, err := args.Pack(tokenPreds)
if err != nil {
return nil, err
}
return enc, nil
}


func (s PricePredictorSolidity) Verify(ctx context.Context, input []byte, output []byte) error {
// todo: verify output
return nil
}
Comment on lines +106 to +109
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Implement the Verify method.

The Verify method is a critical security feature that is currently not implemented. This could lead to acceptance of invalid or malicious responses.

The method should verify:

  1. Response signature if available
  2. Response format and data types
  3. Prediction values are within reasonable bounds
  4. All requested tokens are present in the response

Would you like me to generate an implementation for the Verify method?


type RequestSolverInput struct {
Tokens []string `json:"tokens"`
TargetDate string `json:"target_date"`
AdversaryMode bool `json:"adversaryMode"`
}

type Request struct {
SolverInput RequestSolverInput `json:"solverInput"`
FalsePositiveRate float64 `json:"falsePositiveRate"`
}

type Response struct {
SolverOutput map[string]float64 `json:"solverOutput"`
SolverReceipt struct {
BloomFilter []byte `json:"bloomFilter"`
CountItems int `json:"countItems"`
} `json:"solverReceipt"`
}

func Predict(ctx context.Context, req Request) (Response, error) {
reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()

reqBody, err := json.Marshal(req)
if err != nil {
return Response{}, err
}

httpReq, err := http.NewRequestWithContext(reqCtx, "POST", URL, bytes.NewReader(reqBody))
if err != nil {
return Response{}, err
}

httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")

res, err := client.Do(httpReq)
if err != nil {
return Response{}, err
}
defer res.Body.Close()
response, err := io.ReadAll(res.Body)
if err != nil {
return Response{}, err
}

if res.StatusCode != http.StatusOK {
return Response{}, fmt.Errorf("unexpected status code: %d. Server returned error: %s", res.StatusCode, response)
}

var resResp Response
err = json.Unmarshal(response, &resResp)
if err != nil {
return Response{}, err
}

return resResp, nil
}
Comment on lines +130 to +168
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance error handling and add retry mechanism.

The function has several issues:

  1. Duplicate timeout configuration
  2. No retry mechanism for transient failures
  3. Basic error handling
  4. Response body not limited in size

Consider:

+const (
+    maxResponseSize = 1 << 20 // 1MB
+    maxRetries = 3
+)
+
 func Predict(ctx context.Context, req Request) (Response, error) {
-    reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
-    defer cancel()
+    var resp Response
+    var lastErr error
 
     reqBody, err := json.Marshal(req)
     if err != nil {
         return Response{}, err
     }
 
-    httpReq, err := http.NewRequestWithContext(reqCtx, "POST", URL, bytes.NewReader(reqBody))
-    if err != nil {
-        return Response{}, err
-    }
+    for attempt := 0; attempt <= maxRetries; attempt++ {
+        if attempt > 0 {
+            // Exponential backoff
+            backoff := time.Duration(math.Pow(2, float64(attempt-1))) * time.Second
+            timer := time.NewTimer(backoff)
+            select {
+            case <-ctx.Done():
+                timer.Stop()
+                return Response{}, ctx.Err()
+            case <-timer.C:
+            }
+        }
 
-    httpReq.Header.Set("Accept", "application/json")
-    httpReq.Header.Set("Content-Type", "application/json")
+        httpReq, err := http.NewRequestWithContext(ctx, "POST", URL, bytes.NewReader(reqBody))
+        if err != nil {
+            return Response{}, err
+        }
 
-    res, err := client.Do(httpReq)
-    if err != nil {
-        return Response{}, err
-    }
-    defer res.Body.Close()
-    response, err := io.ReadAll(res.Body)
-    if err != nil {
-        return Response{}, err
-    }
+        httpReq.Header.Set("Accept", "application/json")
+        httpReq.Header.Set("Content-Type", "application/json")
 
-    if res.StatusCode != http.StatusOK {
-        return Response{}, fmt.Errorf("unexpected status code: %d. Server returned error: %s", res.StatusCode, response)
-    }
+        res, err := client.Do(httpReq)
+        if err != nil {
+            lastErr = err
+            continue
+        }
+        defer res.Body.Close()
 
-    var resResp Response
-    err = json.Unmarshal(response, &resResp)
-    if err != nil {
-        return Response{}, err
+        // Limit response size
+        limitedReader := io.LimitReader(res.Body, maxResponseSize)
+        response, err := io.ReadAll(limitedReader)
+        if err != nil {
+            lastErr = err
+            continue
+        }
+
+        if res.StatusCode != http.StatusOK {
+            lastErr = fmt.Errorf("unexpected status code: %d. Server returned error: %s", res.StatusCode, response)
+            // Don't retry on 4xx errors
+            if res.StatusCode/100 == 4 {
+                break
+            }
+            continue
+        }
+
+        if err := json.Unmarshal(response, &resp); err != nil {
+            return Response{}, fmt.Errorf("failed to unmarshal response: %w", err)
+        }
+
+        return resp, nil
     }
 
-    return resResp, nil
+    return Response{}, fmt.Errorf("max retries exceeded: %w", lastErr)
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func Predict(ctx context.Context, req Request) (Response, error) {
reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
reqBody, err := json.Marshal(req)
if err != nil {
return Response{}, err
}
httpReq, err := http.NewRequestWithContext(reqCtx, "POST", URL, bytes.NewReader(reqBody))
if err != nil {
return Response{}, err
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
res, err := client.Do(httpReq)
if err != nil {
return Response{}, err
}
defer res.Body.Close()
response, err := io.ReadAll(res.Body)
if err != nil {
return Response{}, err
}
if res.StatusCode != http.StatusOK {
return Response{}, fmt.Errorf("unexpected status code: %d. Server returned error: %s", res.StatusCode, response)
}
var resResp Response
err = json.Unmarshal(response, &resResp)
if err != nil {
return Response{}, err
}
return resResp, nil
}
const (
maxResponseSize = 1 << 20 // 1MB
maxRetries = 3
)
func Predict(ctx context.Context, req Request) (Response, error) {
var resp Response
var lastErr error
reqBody, err := json.Marshal(req)
if err != nil {
return Response{}, err
}
for attempt := 0; attempt <= maxRetries; attempt++ {
if attempt > 0 {
// Exponential backoff
backoff := time.Duration(math.Pow(2, float64(attempt-1))) * time.Second
timer := time.NewTimer(backoff)
select {
case <-ctx.Done():
timer.Stop()
return Response{}, ctx.Err()
case <-timer.C:
}
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", URL, bytes.NewReader(reqBody))
if err != nil {
return Response{}, err
}
httpReq.Header.Set("Accept", "application/json")
httpReq.Header.Set("Content-Type", "application/json")
res, err := client.Do(httpReq)
if err != nil {
lastErr = err
continue
}
defer res.Body.Close()
// Limit response size
limitedReader := io.LimitReader(res.Body, maxResponseSize)
response, err := io.ReadAll(limitedReader)
if err != nil {
lastErr = err
continue
}
if res.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("unexpected status code: %d. Server returned error: %s", res.StatusCode, response)
// Don't retry on 4xx errors
if res.StatusCode/100 == 4 {
break
}
continue
}
if err := json.Unmarshal(response, &resp); err != nil {
return Response{}, fmt.Errorf("failed to unmarshal response: %w", err)
}
return resp, nil
}
return Response{}, fmt.Errorf("max retries exceeded: %w", lastErr)
}

93 changes: 93 additions & 0 deletions prophet/handlers/pricepred/pricepred_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package pricepred

import (
"context"
"encoding/base64"
"encoding/hex"
"testing"

"github.com/stretchr/testify/require"
)

func TestDecodeInput(t *testing.T) {
cases := []struct {
name string
input string
expected []string
}{
{
name: "single element list",
input: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANFVEgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==",
expected: []string{"ETH"},
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
bz, err := base64.StdEncoding.DecodeString(c.input)
require.NoError(t, err)
actual, err := decodeInput(bz)
require.NoError(t, err)
require.Equal(t, c.expected, actual)
})
}
}
Comment on lines +12 to +34
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance test coverage with additional test cases.

The current test suite only covers a single token. Consider adding test cases for:

  • Multiple tokens
  • Empty input
  • Invalid base64 input
  • Invalid token format
 func TestDecodeInput(t *testing.T) {
 	cases := []struct {
 		name     string
 		input    string
 		expected []string
 	}{
 		{
 			name:     "single element list",
 			input:    "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANFVEgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==",
 			expected: []string{"ETH"},
 		},
+		{
+			name:     "multiple tokens",
+			input:    "...", // Add base64 encoded input for multiple tokens
+			expected: []string{"BTC", "ETH", "USDT"},
+		},
+		{
+			name:     "empty input",
+			input:    "",
+			expected: nil,
+		},
+		{
+			name:  "invalid base64",
+			input: "invalid-base64",
+			expected: nil,
+		},
 	}

Committable suggestion skipped: line range outside the PR's diff.


func TestEncodeOutput(t *testing.T) {
cases := []struct {
name string
request []string
expected string
}{
{
name: "single element list",
request: []string{"bitcoin", "tether", "uniswap"},
expected: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000018d6e7f01865e90000000000000000000000000000000000000000000000000000002387ffdba3cf9c000000000000000000000000000000000000000000000000026b6b97cf726620",
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
req := Request{
SolverInput: RequestSolverInput{
Tokens: c.request,
TargetDate: "2022-01-01",
AdversaryMode: false,
},
FalsePositiveRate: 0.01,
}
res := Response{
SolverOutput: map[string]float64{
"uniswap": 17.435131034851075,
"tether": 1.000115715622902,
"bitcoin": 45820.74676003456,
},
}

actual, err := encodeOutput(req, res)
require.NoError(t, err)

require.Equal(t, c.expected, hex.EncodeToString(actual))
})
}
}
Comment on lines +36 to +73
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Update test with dynamic date and add error scenarios.

The test has two issues:

  1. Uses a hardcoded date from 2022, which is in the past
  2. Missing test cases for error scenarios

Consider:

  1. Using a dynamic future date
  2. Adding test cases for:
    • Empty token list
    • Invalid token names
    • Zero/negative predictions
    • Missing predictions for requested tokens
 func TestEncodeOutput(t *testing.T) {
+	futureDate := time.Now().AddDate(1, 0, 0).Format("2006-01-02")
 	cases := []struct {
 		name     string
 		request  []string
+		date     string
 		expected string
+		wantErr  bool
 	}{
 		{
 			name:     "single element list",
 			request:  []string{"bitcoin", "tether", "uniswap"},
+			date:     futureDate,
 			expected: "00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000018d6e7f01865e90000000000000000000000000000000000000000000000000000002387ffdba3cf9c000000000000000000000000000000000000000000000000026b6b97cf726620",
+			wantErr:  false,
 		},
+		{
+			name:     "empty token list",
+			request:  []string{},
+			date:     futureDate,
+			expected: "",
+			wantErr:  true,
+		},
 	}

Committable suggestion skipped: line range outside the PR's diff.


func TestPredict(t *testing.T) {
//t.Skip("this test relies on external HTTP call")

ctx := context.Background()
res, err := Predict(ctx, Request{
SolverInput: RequestSolverInput{
Tokens: []string{"bitcoin", "tether", "uniswap"},
TargetDate: "2022-01-01",
AdversaryMode: false,
},
FalsePositiveRate: 0.01,
})
require.NoError(t, err)
require.Len(t, res.SolverOutput, 3)

for token, pred := range res.SolverOutput {
require.NotZero(t, pred, "prediction for %t is zero", token)
}
}
Comment on lines +75 to +93
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Mock external HTTP calls and update the test date.

The test has several issues:

  1. Makes real HTTP calls which can make tests flaky and slow
  2. Uses a hardcoded date from 2022
  3. Has a commented-out skip line instead of actively managing the test execution

Consider:

  1. Using a mock HTTP client
  2. Using a dynamic future date
  3. Properly managing test execution with environment variables

Here's a suggested implementation:

+type mockHTTPClient struct {
+    DoFunc func(req *http.Request) (*http.Response, error)
+}
+
+func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
+    return m.DoFunc(req)
+}
+
 func TestPredict(t *testing.T) {
-    //t.Skip("this test relies on external HTTP call")
+    if testing.Short() {
+        t.Skip("skipping test in short mode")
+    }
+
+    futureDate := time.Now().AddDate(1, 0, 0).Format("2006-01-02")
+    originalClient := client
+    defer func() { client = originalClient }()
+
+    mockResp := &http.Response{
+        StatusCode: http.StatusOK,
+        Body: io.NopCloser(strings.NewReader(`{
+            "solverOutput": {
+                "bitcoin": 50000,
+                "tether": 1,
+                "uniswap": 20
+            }
+        }`)),
+    }
+
+    client = &mockHTTPClient{
+        DoFunc: func(req *http.Request) (*http.Response, error) {
+            return mockResp, nil
+        },
+    }

     ctx := context.Background()
     res, err := Predict(ctx, Request{
         SolverInput: RequestSolverInput{
             Tokens:        []string{"bitcoin", "tether", "uniswap"},
-            TargetDate:    "2022-01-01",
+            TargetDate:    futureDate,
             AdversaryMode: false,
         },
         FalsePositiveRate: 0.01,
     })

Committable suggestion skipped: line range outside the PR's diff.

3 changes: 3 additions & 0 deletions warden/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ import (

"github.com/warden-protocol/wardenprotocol/prophet"
"github.com/warden-protocol/wardenprotocol/prophet/handlers/echo"
"github.com/warden-protocol/wardenprotocol/prophet/handlers/pricepred"
)

const (
Expand Down Expand Up @@ -246,6 +247,8 @@ func New(
baseAppOptions ...func(*baseapp.BaseApp),
) (*App, error) {
prophet.Register("echo", echo.Handler{})
prophet.Register("pricepred", pricepred.PricePredictorSolidity{})

prophetP, err := prophet.New()
if err != nil {
panic(fmt.Errorf("failed to create prophet: %w", err))
Expand Down
Loading