Skip to content

Commit

Permalink
Prevent errgroup from canceling all providers (#1056)
Browse files Browse the repository at this point in the history
* avoid returning from errgroup

* cleanup

* cleanup

* update metrics to track missing prices

* add telemetry testing

* cleanup
  • Loading branch information
stevenlanders authored Sep 20, 2023
1 parent 0a28ee5 commit fc1390c
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 23 deletions.
60 changes: 42 additions & 18 deletions oracle/price-feeder/oracle/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,32 @@ func (o *Oracle) GetPrices() sdk.DecCoins {
return prices
}

// sendProviderFailureMetric function is overridden by unit tests
var sendProviderFailureMetric = telemetry.IncrCounterWithLabels

// safeMapContains handles a nil check if the map is nil
func safeMapContains[V any](m map[string]V, key string) bool {
if m == nil {
return false
}
_, ok := m[key]
return ok
}

// reportPriceErrMetrics sends metrics to telemetry for missing prices
func reportPriceErrMetrics[V any](providerName string, priceType string, prices map[string]V, expected []types.CurrencyPair) {
for _, pair := range expected {
if !safeMapContains(prices, pair.String()) {
sendProviderFailureMetric([]string{"failure", "provider"}, 1, []metrics.Label{
{Name: "type", Value: priceType},
{Name: "reason", Value: "error"},
{Name: "provider", Value: providerName},
{Name: "base", Value: pair.Base},
})
}
}
}

// SetPrices retrieves all the prices and candles from our set of providers as
// determined in the config. If candles are available, uses TVWAP in order
// to determine prices. If candles are not available, uses the most recent prices
Expand Down Expand Up @@ -220,42 +246,33 @@ func (o *Oracle) SetPrices(ctx context.Context) error {
prices := make(map[string]provider.TickerPrice, 0)
candles := make(map[string][]provider.CandlePrice, 0)
ch := make(chan struct{})
errCh := make(chan error, 1)

go func() {
defer close(ch)
prices, err = priceProvider.GetTickerPrices(currencyPairs...)
if err != nil {
telemetry.IncrCounterWithLabels([]string{"failure", "provider"}, 1, []metrics.Label{
{Name: "type", Value: "ticker"},
{Name: "reason", Value: "error"},
{Name: "provider", Value: providerName},
})
errCh <- err
o.logger.Debug().Err(err).Msg("failed to get ticker prices from provider")
}
reportPriceErrMetrics(providerName, "ticker", prices, currencyPairs)

candles, err = priceProvider.GetCandlePrices(currencyPairs...)
if err != nil {
telemetry.IncrCounterWithLabels([]string{"failure", "provider"}, 1, []metrics.Label{
{Name: "type", Value: "candle"},
{Name: "reason", Value: "error"},
{Name: "provider", Value: providerName},
})
errCh <- err
o.logger.Debug().Err(err).Msg("failed to get candle prices from provider")
}
reportPriceErrMetrics(providerName, "candle", prices, currencyPairs)
}()

select {
case <-ch:
break
case err := <-errCh:
return err
case <-time.After(o.providerTimeout):
telemetry.IncrCounterWithLabels([]string{"failure", "provider"}, 1, []metrics.Label{
{Name: "reason", Value: "timeout"},
{Name: "provider", Value: providerName},
})
return fmt.Errorf("provider timed out: %s", providerName)
o.logger.Error().Msgf("provider timed out: %s", providerName)
// returning nil to avoid canceling other providers that might succeed
return nil
}

// flatten and collect prices based on the base currency per provider
Expand All @@ -266,7 +283,13 @@ func (o *Oracle) SetPrices(ctx context.Context) error {
success := SetProviderTickerPricesAndCandles(providerName, providerPrices, providerCandles, prices, candles, pair)
if !success {
mtx.Unlock()
return fmt.Errorf("failed to find any exchange rates in provider responses")
telemetry.IncrCounterWithLabels([]string{"failure", "provider"}, 1, []metrics.Label{
{Name: "reason", Value: "set-prices"},
{Name: "provider", Value: providerName},
})
o.logger.Error().Msgf("failed to set prices for provider %s", providerName)
// returning nil to avoid canceling other providers that might succeed
return nil
}
}

Expand All @@ -276,7 +299,8 @@ func (o *Oracle) SetPrices(ctx context.Context) error {
}

if err := g.Wait(); err != nil {
o.logger.Debug().Err(err).Msg("failed to get ticker prices from provider")
// this should not be possible because there are no errors returned from the tasks
o.logger.Error().Err(err).Msg("set-prices errgroup returned an error")
}

computedPrices, err := GetComputedPrices(
Expand Down
114 changes: 109 additions & 5 deletions oracle/price-feeder/oracle/oracle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@ package oracle
import (
"context"
"fmt"
"sync"
"testing"
"time"

"github.com/armon/go-metrics"
sdkclient "github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/crypto/keys/ed25519"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"golang.org/x/exp/slices"

"github.com/sei-protocol/sei-chain/oracle/price-feeder/config"
"github.com/sei-protocol/sei-chain/oracle/price-feeder/oracle/client"
Expand All @@ -20,6 +23,74 @@ import (
oracletypes "github.com/sei-protocol/sei-chain/x/oracle/types"
)

type mockTelemetry struct {
mx sync.Mutex
recorded []mockMetric
}

type mockMetric struct {
keys []string
val float32
labels []metrics.Label
}

func resetMockTelemetry() *mockTelemetry {
res := &mockTelemetry{
mx: sync.Mutex{},
}
sendProviderFailureMetric = res.IncrCounterWithLabels
return res
}

func (r mockMetric) containsLabel(expected metrics.Label) bool {
for _, l := range r.labels {
if l.Name == expected.Name && l.Value == expected.Value {
return true
}
}
return false
}

func (r mockMetric) labelsEqual(expected []metrics.Label) bool {
if len(expected) != len(r.labels) {
return false
}
for _, l := range expected {
if !r.containsLabel(l) {
return false
}
}
return true
}

func (mt *mockTelemetry) IncrCounterWithLabels(keys []string, val float32, labels []metrics.Label) {
mt.mx.Lock()
defer mt.mx.Unlock()
mt.recorded = append(mt.recorded, mockMetric{keys, val, labels})
}

func (mt *mockTelemetry) Len() int {
return len(mt.recorded)
}

func (mt *mockTelemetry) AssertProviderError(t *testing.T, provider, base, reason, priceType string) {
mt.AssertContains(t, []string{"failure", "provider"}, 1, []metrics.Label{
{Name: "provider", Value: provider},
{Name: "base", Value: base},
{Name: "reason", Value: reason},
{Name: "type", Value: priceType},
})
}

func (mt *mockTelemetry) AssertContains(t *testing.T, keys []string, val float32, labels []metrics.Label) {
for _, r := range mt.recorded {
if r.val == val && slices.Equal(keys, r.keys) && r.labelsEqual(labels) {
return
}
}
require.Fail(t, fmt.Sprintf("no matching metric found: keys=%v, val=%v, labels=%v", keys, val, labels))
}

type mockProvider struct {
prices map[string]provider.TickerPrice
}
Expand Down Expand Up @@ -160,7 +231,6 @@ func (ots *OracleTestSuite) TestPrices() {
Whitelist: denomList(denoms...),
},
}

// Use a mock provider with exchange rates that are not specified in
// configuration.
ots.oracle.priceProviders = map[string]provider.Provider{
Expand All @@ -181,10 +251,22 @@ func (ots *OracleTestSuite) TestPrices() {
},
},
}

telemetryMock := resetMockTelemetry()
ots.Require().Error(ots.oracle.SetPrices(context.TODO()))
ots.Require().Empty(ots.oracle.GetPrices())

ots.Require().Equal(10, telemetryMock.Len())
telemetryMock.AssertProviderError(ots.T(), config.ProviderBinance, "UMEE", "error", "ticker")
telemetryMock.AssertProviderError(ots.T(), config.ProviderKraken, "UMEE", "error", "ticker")
telemetryMock.AssertProviderError(ots.T(), config.ProviderOkx, "XBT", "error", "ticker")
telemetryMock.AssertProviderError(ots.T(), config.ProviderHuobi, "USDC", "error", "ticker")
telemetryMock.AssertProviderError(ots.T(), config.ProviderCoinbase, "USDT", "error", "ticker")
telemetryMock.AssertProviderError(ots.T(), config.ProviderBinance, "UMEE", "error", "candle")
telemetryMock.AssertProviderError(ots.T(), config.ProviderKraken, "UMEE", "error", "candle")
telemetryMock.AssertProviderError(ots.T(), config.ProviderOkx, "XBT", "error", "candle")
telemetryMock.AssertProviderError(ots.T(), config.ProviderHuobi, "USDC", "error", "candle")
telemetryMock.AssertProviderError(ots.T(), config.ProviderCoinbase, "USDT", "error", "candle")

// use a mock provider without a conversion rate for these stablecoins
ots.oracle.priceProviders = map[string]provider.Provider{
config.ProviderBinance: mockProvider{
Expand All @@ -204,8 +286,15 @@ func (ots *OracleTestSuite) TestPrices() {
},
},
}

telemetryMock = resetMockTelemetry()
ots.Require().Error(ots.oracle.SetPrices(context.TODO()))
ots.Require().Equal(6, telemetryMock.Len())
telemetryMock.AssertProviderError(ots.T(), config.ProviderOkx, "XBT", "error", "ticker")
telemetryMock.AssertProviderError(ots.T(), config.ProviderHuobi, "USDC", "error", "ticker")
telemetryMock.AssertProviderError(ots.T(), config.ProviderCoinbase, "USDT", "error", "ticker")
telemetryMock.AssertProviderError(ots.T(), config.ProviderOkx, "XBT", "error", "candle")
telemetryMock.AssertProviderError(ots.T(), config.ProviderHuobi, "USDC", "error", "candle")
telemetryMock.AssertProviderError(ots.T(), config.ProviderCoinbase, "USDT", "error", "candle")

prices := ots.oracle.GetPrices()
ots.Require().Len(prices, 0)
Expand Down Expand Up @@ -254,7 +343,9 @@ func (ots *OracleTestSuite) TestPrices() {
},
}

telemetryMock = resetMockTelemetry()
ots.Require().NoError(ots.oracle.SetPrices(context.TODO()))
ots.Require().Equal(0, telemetryMock.Len())

prices = ots.oracle.GetPrices()
ots.Require().Len(prices, 4)
Expand Down Expand Up @@ -307,7 +398,12 @@ func (ots *OracleTestSuite) TestPrices() {
},
}

telemetryMock = resetMockTelemetry()
ots.Require().NoError(ots.oracle.SetPrices(context.TODO()))
ots.Require().Equal(2, telemetryMock.Len())
telemetryMock.AssertProviderError(ots.T(), config.ProviderBinance, "UMEE", "error", "ticker")
telemetryMock.AssertProviderError(ots.T(), config.ProviderBinance, "UMEE", "error", "candle")

prices = ots.oracle.GetPrices()
ots.Require().Len(prices, 4)
ots.Require().Equal(sdk.MustNewDecFromStr("3.70"), prices.AmountOf("uumee"))
Expand Down Expand Up @@ -358,8 +454,12 @@ func (ots *OracleTestSuite) TestPrices() {
},
},
}

telemetryMock = resetMockTelemetry()
ots.Require().NoError(ots.oracle.SetPrices(context.TODO()))
ots.Require().Equal(2, telemetryMock.Len())
telemetryMock.AssertProviderError(ots.T(), config.ProviderBinance, "UMEE", "error", "ticker")
telemetryMock.AssertProviderError(ots.T(), config.ProviderBinance, "UMEE", "error", "candle")

prices = ots.oracle.GetPrices()
ots.Require().Len(prices, 4)
ots.Require().Equal(sdk.MustNewDecFromStr("3.71"), prices.AmountOf("uumee"))
Expand Down Expand Up @@ -403,8 +503,12 @@ func (ots *OracleTestSuite) TestPrices() {
},
config.ProviderOkx: failingProvider{},
}

telemetryMock = resetMockTelemetry()
ots.Require().NoError(ots.oracle.SetPrices(context.TODO()))
ots.Require().Equal(4, telemetryMock.Len())
telemetryMock.AssertProviderError(ots.T(), config.ProviderOkx, "XBT", "error", "ticker")
telemetryMock.AssertProviderError(ots.T(), config.ProviderOkx, "XBT", "error", "candle")

prices = ots.oracle.GetPrices()
ots.Require().Len(prices, 3)
ots.Require().Equal(sdk.MustNewDecFromStr("3.71"), prices.AmountOf("uumee"))
Expand Down

0 comments on commit fc1390c

Please sign in to comment.