Skip to content

Commit

Permalink
Refactoring final: Reimplement auth header caching properly
Browse files Browse the repository at this point in the history
  • Loading branch information
radito3 committed Mar 11, 2024
1 parent b53a519 commit 8188bd5
Show file tree
Hide file tree
Showing 17 changed files with 401 additions and 467 deletions.
2 changes: 1 addition & 1 deletion internal/executors/executor.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package executors

type Executor interface {
Execute(ctx ExecutorContext) *ExecutorResult
Execute(Context) *ExecutorResult
}
9 changes: 6 additions & 3 deletions internal/executors/executor_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@ var (
}
)

func NewExecutorContext(input map[string]string, store map[string]string) ExecutorContext {
return ExecutorContext{
func NewExecutorContext(input map[string]string, store map[string]string) Context {
if store == nil {
store = make(map[string]string)
}
return &ExecutorContext{
input: input,
store: store, // copy?
store: store,
}
}

Expand Down
5 changes: 1 addition & 4 deletions internal/executors/executor_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,14 @@ import pb "github.com/SAP/remote-work-processor/build/proto/generated"

type ExecutorResult struct {
Output map[string]string
Store map[string]string
Status pb.TaskExecutionResponseMessage_TaskState
Error string
}

type ExecutorResultOption func(*ExecutorResult)

func NewExecutorResult(opts ...ExecutorResultOption) *ExecutorResult {
r := &ExecutorResult{
Store: make(map[string]string),
}
r := &ExecutorResult{}

for _, opt := range opts {
opt(r)
Expand Down
120 changes: 6 additions & 114 deletions internal/executors/http/authorization_header.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package http

import (
"github.com/SAP/remote-work-processor/internal/utils"
"regexp"
"strconv"

"github.com/SAP/remote-work-processor/internal/executors"
"regexp"
)

const (
Expand All @@ -15,119 +12,14 @@ const (

var iasTokenUrlRegex = regexp.MustCompile(IasTokenUrlPattern)

type AuthorizationHeader interface {
GetName() string
GetValue() string
HasValue() bool
}

type CacheableAuthorizationHeader interface {
AuthorizationHeader
GetCachingKey() string
GetCacheableValue() (string, error)
ApplyCachedToken(token string) (CacheableAuthorizationHeader, error)
}

type AuthorizationHeaderView string

type CacheableAuthorizationHeaderView struct {
AuthorizationHeaderView
header *oAuthorizationHeader
}

type CachedToken struct {
Token string `json:"token,omitempty"`
Timestamp string `json:"timestamp,omitempty"`
}

func NewCacheableAuthorizationHeaderView(value string, header *oAuthorizationHeader) CacheableAuthorizationHeaderView {
return CacheableAuthorizationHeaderView{
AuthorizationHeaderView: AuthorizationHeaderView(value),
header: header,
}
}

func (h CacheableAuthorizationHeaderView) GetCachingKey() string {
//return h.header.cachingKey
return ""
}

func (h CacheableAuthorizationHeaderView) GetCacheableValue() (string, error) {
token := h.header.token
if token == nil {
return "", nil
}

t, err := utils.ToJson(token)
if err != nil {
return "", err
}

cached := CachedToken{
Token: t,
Timestamp: strconv.FormatInt(token.issuedAt, 10),
}

value, err := utils.ToJson(cached)
if err != nil {
return "", err
}
return value, nil
}

func (h CacheableAuthorizationHeaderView) ApplyCachedToken(token string) (CacheableAuthorizationHeader, error) {
if token == "" {
return h, nil
}

cached := &CachedToken{}
err := utils.FromJson(token, cached)
if err != nil {
return nil, err
}

if cached.Token == "" || cached.Timestamp == "" {
return h, nil
}

// TODO: try direct deserialization of a timestamp instead of first to string and then manual parsing
issuedAt, err := strconv.ParseInt(cached.Timestamp, 10, 64)
if err != nil {
return nil, err
}

err = h.header.setToken(cached.Token, issuedAt)
return h, err
}

func EmptyAuthorizationHeader() AuthorizationHeaderView {
return ""
}

func NewAuthorizationHeaderView(value string) AuthorizationHeaderView {
return AuthorizationHeaderView(value)
}

func (h AuthorizationHeaderView) GetName() string {
return AuthorizationHeaderName
}

func (h AuthorizationHeaderView) GetValue() string {
return string(h)
}

func (h AuthorizationHeaderView) HasValue() bool {
return h != ""
}

// Currently only Basic and Bearer token authentication is supported.
// OAuth 2.0 will be added later

func CreateAuthorizationHeader(params *HttpRequestParameters) (AuthorizationHeader, error) {
func CreateAuthorizationHeader(params *HttpRequestParameters) (string, error) {
authHeader := params.GetAuthorizationHeader()

if authHeader != "" {
return AuthorizationHeaderView(authHeader), nil
return authHeader, nil
}

user := params.GetUser()
Expand All @@ -138,18 +30,18 @@ func CreateAuthorizationHeader(params *HttpRequestParameters) (AuthorizationHead
if user != "" && iasTokenUrlRegex.Match([]byte(tokenUrl)) {
return NewIasAuthorizationHeader(tokenUrl, user, params.GetCertificateAuthentication().GetClientCertificate()).Generate()
}
return NewOAuthHeaderGenerator(params).Generate()
return NewOAuthHeaderGenerator(params).GenerateWithCacheAside()
}

if user != "" {
return NewBasicAuthorizationHeader(user, pass).Generate()
}

if noAuthorizationRequired(params) {
return EmptyAuthorizationHeader(), nil
return "", nil
}

return nil, executors.NewNonRetryableError("Input values for the authentication-related keys " +
return "", executors.NewNonRetryableError("Input values for the authentication-related keys " +
"(user, password & authorizationHeader) are not combined properly.")
}

Expand Down
10 changes: 5 additions & 5 deletions internal/executors/http/basic_authorization_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@ import (

type basicAuthorizationHeader struct {
username string
password string
password []byte
}

func NewBasicAuthorizationHeader(u string, p string) AuthorizationHeaderGenerator {
return &basicAuthorizationHeader{
username: u,
password: p,
password: []byte(p),
}
}

func (h *basicAuthorizationHeader) Generate() (AuthorizationHeader, error) {
str := fmt.Sprintf("%s:%s", h.username, h.password)
func (h *basicAuthorizationHeader) Generate() (string, error) {
str := fmt.Sprintf("%s:%s", h.username, string(h.password))
encoded := base64.StdEncoding.EncodeToString([]byte(str))

return NewAuthorizationHeaderView(fmt.Sprintf("Basic %s", encoded)), nil
return fmt.Sprintf("Basic %s", encoded), nil
}
12 changes: 6 additions & 6 deletions internal/executors/http/csrf_token_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/SAP/remote-work-processor/internal/functional"
)

const CSRF_VERB = "fetch"
const CsrfVerb = "fetch"

var csrfTokenHeaders = []string{"X-Csrf-Token", "X-Xsrf-Token"}

Expand All @@ -19,7 +19,7 @@ type csrfTokenFetcher struct {
succeedOnTimeout bool
}

func NewCsrfTokenFetcher(p *HttpRequestParameters, authHeader AuthorizationHeader) TokenFetcher {
func NewCsrfTokenFetcher(p *HttpRequestParameters, authHeader string) TokenFetcher {
return &csrfTokenFetcher{
HttpExecutor: NewDefaultHttpRequestExecutor(),
csrfUrl: p.csrfUrl,
Expand All @@ -44,14 +44,14 @@ func (f *csrfTokenFetcher) Fetch() (string, error) {
return "", fmt.Errorf("no csrf header present in response from %s", f.csrfUrl)
}

func createCsrfHeaders(authHeader AuthorizationHeader) HttpHeaders {
func createCsrfHeaders(authHeader string) HttpHeaders {
csrfHeaders := make(map[string]string)
for _, headerKey := range csrfTokenHeaders {
csrfHeaders[headerKey] = CSRF_VERB
csrfHeaders[headerKey] = CsrfVerb
}

if authHeader.HasValue() {
csrfHeaders[authHeader.GetName()] = authHeader.GetValue()
if authHeader != "" {
csrfHeaders[AuthorizationHeaderName] = authHeader
}
return csrfHeaders
}
Expand Down
7 changes: 6 additions & 1 deletion internal/executors/http/generator.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package http

type AuthorizationHeaderGenerator interface {
Generate() (AuthorizationHeader, error)
Generate() (string, error)
}

type CacheableAuthorizationHeaderGenerator interface {
AuthorizationHeaderGenerator
GenerateWithCacheAside() (string, error)
}
19 changes: 7 additions & 12 deletions internal/executors/http/http_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func NewDefaultHttpRequestExecutor() *HttpRequestExecutor {
return &HttpRequestExecutor{}
}

func (e *HttpRequestExecutor) Execute(ctx executors.ExecutorContext) *executors.ExecutorResult {
func (e *HttpRequestExecutor) Execute(ctx executors.Context) *executors.ExecutorResult {
log.Println("Executing HttpRequest command...")
params, err := NewHttpRequestParametersFromContext(ctx)
if err != nil {
Expand Down Expand Up @@ -74,10 +74,6 @@ func (e *HttpRequestExecutor) ExecuteWithParameters(p *HttpRequestParameters) (*
return nil, err
}

// TODO: get cached token from server request message store
// apply to *http.Request if present and do not request new auth header
// otherwise, request it, set in store (add it to ExecutionResponse) and return in message to server

authHeader, err := CreateAuthorizationHeader(p)
if err != nil {
return nil, err
Expand All @@ -91,7 +87,7 @@ func (e *HttpRequestExecutor) ExecuteWithParameters(p *HttpRequestParameters) (*
return execute(client, p, authHeader)
}

func obtainCsrf(p *HttpRequestParameters, authHeader AuthorizationHeader) error {
func obtainCsrf(p *HttpRequestParameters, authHeader string) error {
fetcher := NewCsrfTokenFetcher(p, authHeader)
token, err := fetcher.Fetch()
if err != nil {
Expand All @@ -102,7 +98,7 @@ func obtainCsrf(p *HttpRequestParameters, authHeader AuthorizationHeader) error
return nil
}

func execute(c *http.Client, p *HttpRequestParameters, authHeader AuthorizationHeader) (*HttpResponse, error) {
func execute(c *http.Client, p *HttpRequestParameters, authHeader string) (*HttpResponse, error) {
req, timeCh, err := createRequest(p.method, p.url, p.headers, p.body, authHeader)
if err != nil {
return nil, executors.NewNonRetryableError("could not create http request: %v", err).WithCause(err)
Expand Down Expand Up @@ -150,8 +146,7 @@ func requestTimedOut(err error) bool {
return errors.As(err, &e) && e.Timeout()
}

func createRequest(method string, url string, headers map[string]string, body string,
authHeader AuthorizationHeader) (*http.Request, <-chan int64, error) {
func createRequest(method string, url string, headers map[string]string, body, authHeader string) (*http.Request, <-chan int64, error) {
timeCh := make(chan int64, 1)

req, err := http.NewRequest(method, url, bytes.NewBuffer([]byte(body)))
Expand All @@ -173,13 +168,13 @@ func createRequest(method string, url string, headers map[string]string, body st
return req.WithContext(httptrace.WithClientTrace(req.Context(), trace)), timeCh, nil
}

func addHeaders(req *http.Request, headers map[string]string, authHeader AuthorizationHeader) {
func addHeaders(req *http.Request, headers map[string]string, authHeader string) {
for k, v := range headers {
req.Header.Add(k, v)
}

if authHeader.HasValue() {
req.Header.Set(authHeader.GetName(), authHeader.GetValue())
if authHeader != "" {
req.Header.Set(AuthorizationHeaderName, authHeader)
}
}

Expand Down
Loading

0 comments on commit 8188bd5

Please sign in to comment.