Skip to content

Commit

Permalink
refactor robinhood client to use generics
Browse files Browse the repository at this point in the history
  • Loading branch information
vitoordaz committed Jan 7, 2024
1 parent 031f028 commit cf5bfa0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 54 deletions.
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
Expand Down
78 changes: 24 additions & 54 deletions internal/robinhood/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,68 +65,19 @@ func (dc *defaultClient) GetToken(ctx context.Context, username, password, mfa s
}

func (dc *defaultClient) GetMarket(ctx context.Context, id string) (*Market, error) {
result, err := dc.get(ctx, nil, getDetailURL(EndpointMarket, id), &Market{})
if err != nil {
return nil, err
}
return result.(*Market), nil
return doGet[Market](ctx, dc.c, nil, getDetailURL(EndpointMarket, id))
}

func (dc *defaultClient) GetOrders(ctx context.Context, auth *ResponseToken, cursor string) (*ResponseOrders, error) {
result, err := dc.list(ctx, auth, EndpointOrders, cursor, &ResponseOrders{})
if err != nil {
return nil, err
}
return result.(*ResponseOrders), nil
return doList[ResponseOrders](ctx, dc.c, auth, EndpointOrders, cursor)
}

func (dc *defaultClient) GetPositions(
ctx context.Context,
auth *ResponseToken,
cursor string,
) (*ResponsePositions, error) {
result, err := dc.list(ctx, auth, EndpointPositions, cursor, &ResponsePositions{})
if err != nil {
return nil, err
}
return result.(*ResponsePositions), nil
func (dc *defaultClient) GetPositions(ctx context.Context, auth *ResponseToken, cursor string) (*ResponsePositions, error) {
return doList[ResponsePositions](ctx, dc.c, auth, EndpointPositions, cursor)
}

func (dc *defaultClient) GetInstrument(ctx context.Context, id string) (*Instrument, error) {
result, err := dc.get(ctx, nil, getDetailURL(EndpointInstrument, id), &Instrument{})
if err != nil {
return nil, err
}
return result.(*Instrument), nil
}

func (dc *defaultClient) get(
ctx context.Context,
auth *ResponseToken,
url string,
result interface{},
) (interface{}, error) {
r := dc.c.R().SetContext(ctx).SetResult(result)
if auth != nil {
r = r.SetAuthScheme(auth.TokenType).SetAuthToken(auth.AccessToken)
}
resp, err := r.Get(url)
if err != nil {
return nil, err
}
if resp.IsError() {
return nil, resp.Error().(error)
}
return resp.Result(), nil
}

func (dc *defaultClient) list(
ctx context.Context,
auth *ResponseToken,
endpoint, cursor string,
result interface{},
) (interface{}, error) {
return dc.get(ctx, auth, getListURL(endpoint, cursor), result)
return doGet[Instrument](ctx, dc.c, nil, getDetailURL(EndpointInstrument, id))
}

func getDetailURL(prefix, id string) string {
Expand All @@ -147,3 +98,22 @@ func getListURL(endpoint, cursor string) string {
func isURL(v string) bool {
return strings.HasPrefix(v, "https://") || strings.HasPrefix(v, "http://")
}

func doGet[T any](ctx context.Context, client *resty.Client, auth *ResponseToken, url string) (*T, error) {
r := client.R().SetContext(ctx).SetResult(new(T))
if auth != nil {
r = r.SetAuthScheme(auth.TokenType).SetAuthToken(auth.AccessToken)
}
resp, err := r.Get(url)
if err != nil {
return nil, err
}
if resp.IsError() {
return nil, resp.Error().(error)
}
return resp.Result().(*T), nil
}

func doList[T any](ctx context.Context, client *resty.Client, auth *ResponseToken, endpoint, cursor string) (*T, error) {
return doGet[T](ctx, client, auth, getListURL(endpoint, cursor))
}

0 comments on commit cf5bfa0

Please sign in to comment.