diff --git a/go.sum b/go.sum index d2c522c..d10de7b 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,7 @@ github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/internal/robinhood/client.go b/internal/robinhood/client.go index de9410a..38f7fd5 100644 --- a/internal/robinhood/client.go +++ b/internal/robinhood/client.go @@ -65,68 +65,19 @@ func (dc *defaultClient) GetToken(ctx context.Context, username, password, mfa s } func (dc *defaultClient) GetMarket(ctx context.Context, id string) (*Market, error) { - result, err := dc.get(ctx, nil, getDetailURL(EndpointMarket, id), &Market{}) - if err != nil { - return nil, err - } - return result.(*Market), nil + return doGet[Market](ctx, dc.c, nil, getDetailURL(EndpointMarket, id)) } func (dc *defaultClient) GetOrders(ctx context.Context, auth *ResponseToken, cursor string) (*ResponseOrders, error) { - result, err := dc.list(ctx, auth, EndpointOrders, cursor, &ResponseOrders{}) - if err != nil { - return nil, err - } - return result.(*ResponseOrders), nil + return doList[ResponseOrders](ctx, dc.c, auth, EndpointOrders, cursor) } -func (dc *defaultClient) GetPositions( - ctx context.Context, - auth *ResponseToken, - cursor string, -) (*ResponsePositions, error) { - result, err := dc.list(ctx, auth, EndpointPositions, cursor, &ResponsePositions{}) - if err != nil { - return nil, err - } - return result.(*ResponsePositions), nil +func (dc *defaultClient) GetPositions(ctx context.Context, auth *ResponseToken, cursor string) (*ResponsePositions, error) { + return doList[ResponsePositions](ctx, dc.c, auth, EndpointPositions, cursor) } func (dc *defaultClient) GetInstrument(ctx context.Context, id string) (*Instrument, error) { - result, err := dc.get(ctx, nil, getDetailURL(EndpointInstrument, id), &Instrument{}) - if err != nil { - return nil, err - } - return result.(*Instrument), nil -} - -func (dc *defaultClient) get( - ctx context.Context, - auth *ResponseToken, - url string, - result interface{}, -) (interface{}, error) { - r := dc.c.R().SetContext(ctx).SetResult(result) - if auth != nil { - r = r.SetAuthScheme(auth.TokenType).SetAuthToken(auth.AccessToken) - } - resp, err := r.Get(url) - if err != nil { - return nil, err - } - if resp.IsError() { - return nil, resp.Error().(error) - } - return resp.Result(), nil -} - -func (dc *defaultClient) list( - ctx context.Context, - auth *ResponseToken, - endpoint, cursor string, - result interface{}, -) (interface{}, error) { - return dc.get(ctx, auth, getListURL(endpoint, cursor), result) + return doGet[Instrument](ctx, dc.c, nil, getDetailURL(EndpointInstrument, id)) } func getDetailURL(prefix, id string) string { @@ -147,3 +98,22 @@ func getListURL(endpoint, cursor string) string { func isURL(v string) bool { return strings.HasPrefix(v, "https://") || strings.HasPrefix(v, "http://") } + +func doGet[T any](ctx context.Context, client *resty.Client, auth *ResponseToken, url string) (*T, error) { + r := client.R().SetContext(ctx).SetResult(new(T)) + if auth != nil { + r = r.SetAuthScheme(auth.TokenType).SetAuthToken(auth.AccessToken) + } + resp, err := r.Get(url) + if err != nil { + return nil, err + } + if resp.IsError() { + return nil, resp.Error().(error) + } + return resp.Result().(*T), nil +} + +func doList[T any](ctx context.Context, client *resty.Client, auth *ResponseToken, endpoint, cursor string) (*T, error) { + return doGet[T](ctx, client, auth, getListURL(endpoint, cursor)) +}