From 482df4060c2ec6950c331484cf580a50ef315b8d Mon Sep 17 00:00:00 2001 From: bparees Date: Thu, 17 Aug 2023 17:53:08 -0400 Subject: [PATCH] implement github auth + bearer token validation --- cmd/wisdom/main.go | 44 ++++++++- pkg/api/types.go | 23 ++++- pkg/server/auth_handler.go | 160 +++++++++++++++++++++++++++++++- pkg/server/handler.go | 10 +- pkg/server/inference_handler.go | 19 ++++ 5 files changed, 248 insertions(+), 8 deletions(-) diff --git a/cmd/wisdom/main.go b/cmd/wisdom/main.go index ff8a985..419857e 100644 --- a/cmd/wisdom/main.go +++ b/cmd/wisdom/main.go @@ -1,14 +1,17 @@ package main import ( + "encoding/base64" "fmt" "net/http" "os" "reflect" "github.com/gorilla/mux" + "github.com/gorilla/sessions" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "golang.org/x/oauth2" "gopkg.in/yaml.v2" "github.com/openshift/wisdom/pkg/api" @@ -91,13 +94,48 @@ func newStartServerCommand() *cobra.Command { DefaultProvider: config.DefaultProvider, DefaultModel: config.DefaultModelId, Models: models, + ClientID: config.ServerConfig.ClientID, + ClientSecret: config.ServerConfig.ClientSecret, } - h.BearerTokens = make(map[string]bool) - for _, t := range config.ServerConfig.BearerTokens { - h.BearerTokens[t] = true + + tokenKey, err := base64.StdEncoding.DecodeString(config.ServerConfig.TokenEncryptionKey) + if err != nil { + return err + } + h.TokenEncryptionKey = tokenKey + + h.AuthConfig = oauth2.Config{ + ClientID: h.ClientID, + ClientSecret: h.ClientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: "https://github.com/login/oauth/authorize", + TokenURL: "https://github.com/login/oauth/access_token", + }, + RedirectURL: "https://localhost:8443/githubcallback", + Scopes: []string{"user:email"}, + } + + authKey, err := base64.StdEncoding.DecodeString(config.ServerConfig.SessionAuthKey) + if err != nil { + return err + } + encKey, err := base64.StdEncoding.DecodeString(config.ServerConfig.SessionEncryptionKey) + if err != nil { + return err + } + + h.CookieStore = sessions.NewCookieStore(authKey, encKey) + h.CookieStore.Options = &sessions.Options{ + Path: "/", + MaxAge: 0, + HttpOnly: false, } r.HandleFunc("/infer", h.InferHandler).Methods("POST") + //r.HandleFunc("/feedback", h.FeedbackHandler).Methods("POST") + r.HandleFunc("/login", h.HandleLogin) + r.HandleFunc("/githubcallback", h.HandleGithubCallback) + r.HandleFunc("/apitoken", h.HandleApiToken) log.Infof("Default model provider: %s\n", h.DefaultProvider) log.Infof("Default model: %s\n", h.DefaultModel) diff --git a/pkg/api/types.go b/pkg/api/types.go index ec5e0c0..08fb3dc 100644 --- a/pkg/api/types.go +++ b/pkg/api/types.go @@ -1,5 +1,9 @@ package api +import ( + "github.com/golang-jwt/jwt/v4" +) + type Model interface { Invoke(ModelInput) (*ModelResponse, error) } @@ -25,6 +29,15 @@ type ModelResponse struct { ErrorMessage string `json:"error"` } +type Claims struct { + Username string `json:"username"` + jwt.RegisteredClaims +} + +type APIToken struct { + Token string `json:"token"` +} + // FeedbackPayload represents the payload for the feedback endpoint. type FeedbackPayload struct { RequestID string `json:"requestId"` @@ -45,9 +58,13 @@ type ModelConfig struct { } type ServerConfig struct { - TLSCertFile string `yaml:"tlsCertFile"` - TLSKeyFile string `yaml:"tlsKeyFile"` - BearerTokens []string `yaml:"bearerTokens"` + TLSCertFile string `yaml:"tlsCertFile"` + TLSKeyFile string `yaml:"tlsKeyFile"` + ClientID string `yaml:"clientId"` + ClientSecret string `yaml:"clientSecret"` + SessionAuthKey string `yaml:"sessionAuthKey"` + SessionEncryptionKey string `yaml:"sessionEncryptionKey"` + TokenEncryptionKey string `yaml:"tokenEncryptionKey"` } type Config struct { diff --git a/pkg/server/auth_handler.go b/pkg/server/auth_handler.go index d4b2471..578d657 100644 --- a/pkg/server/auth_handler.go +++ b/pkg/server/auth_handler.go @@ -1,10 +1,152 @@ package server import ( + "bytes" + "context" + "encoding/json" + "fmt" "net/http" "strings" + + "github.com/golang-jwt/jwt/v4" + "github.com/openshift/wisdom/pkg/api" + + log "github.com/sirupsen/logrus" + "golang.org/x/oauth2" ) +func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) { + url := h.AuthConfig.AuthCodeURL("", oauth2.AccessTypeOffline) + http.Redirect(w, r, url, http.StatusFound) +} + +func (h *Handler) HandleGithubCallback(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + + token, err := h.AuthConfig.Exchange(context.Background(), code) + if err != nil { + http.Error(w, "Failed to exchange token", http.StatusInternalServerError) + return + } + + client := h.AuthConfig.Client(context.Background(), token) + // Make an authenticated API request to get the user's information + resp, err := client.Get("https://api.github.com/user") + if err != nil { + http.Error(w, "Failed to get user information", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + var user map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { + http.Error(w, "Failed to decode response", http.StatusInternalServerError) + return + } + + userID := user["login"] // Extract the user's GitHub ID + + // Create a session for the user + session, err := h.CookieStore.Get(r, "wisdom-session") + if err != nil { + log.Errorf("failed to create session: %v\n", err) + http.Error(w, "Failed to create session", http.StatusInternalServerError) + return + } + + // Store the OAuth token in the session + session.Values["username"] = userID + err = session.Save(r, w) + + if err != nil { + log.Errorf("failed to save session: %v\n", err) + http.Error(w, "Failed to save session", http.StatusInternalServerError) + return + } + + log.Debugf("Stored session for user %s\n", userID) + + http.Redirect(w, r, "/apitoken", http.StatusSeeOther) + fmt.Fprintf(w, "Successfully authenticated with GitHub!\n") + fmt.Fprintf(w, "GitHub ID: %s\n", userID) + //w.WriteHeader(http.StatusOK) + +} + +func (h *Handler) HandleApiToken(w http.ResponseWriter, r *http.Request) { + + session, err := h.CookieStore.Get(r, "wisdom-session") + if err != nil { + url := r.URL + url.Path = "/login" + http.Redirect(w, r, url.String(), http.StatusFound) + return + } + + username, ok := session.Values["username"].(string) + if !ok { + url := r.URL + url.Path = "/login" + http.Redirect(w, r, url.String(), http.StatusFound) + } + + if username != "bparees" { + url := r.URL + url.Path = "/login" + http.Redirect(w, r, url.String(), http.StatusFound) + } + + claims := &api.Claims{ + Username: username, + RegisteredClaims: jwt.RegisteredClaims{ + // In JWT, the expiry time is expressed as unix milliseconds + //ExpiresAt: jwt.NewNumericDate(expirationTime), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + tokenString, err := token.SignedString(h.TokenEncryptionKey) + + apiToken := api.APIToken{ + Token: tokenString, + } + buf := bytes.Buffer{} + err = json.NewEncoder(&buf).Encode(apiToken) + + w.Header().Set("Content-Type", "text/json") + w.WriteHeader(http.StatusOK) + w.Write(buf.Bytes()) +} + +func (h *Handler) isAuthorized(w http.ResponseWriter, r *http.Request) bool { + // Retrieve the session for the user + session, err := h.CookieStore.Get(r, "wisdom-session") + if err != nil { + url := r.URL + url.Path = "/login" + http.Redirect(w, r, url.String(), http.StatusFound) + //http.Error(w, "Failed to get session", http.StatusUnauthorized) + return false + } + + username, ok := session.Values["username"].(string) + if !ok { + url := r.URL + url.Path = "/login" + http.Redirect(w, r, url.String(), http.StatusFound) + //http.Error(w, "Session does not contain a valid username", http.StatusUnauthorized) + return false + } + + // Check if the username is in the allow list + if username != "bparees" { + http.Error(w, "User is not authorized", http.StatusUnauthorized) + return false + } + fmt.Fprint(w, "Authenticated resource access!") + return true +} + func (h *Handler) hasValidBearerToken(r *http.Request) bool { authHeader := r.Header.Get("Authorization") if authHeader == "" { @@ -17,5 +159,21 @@ func (h *Handler) hasValidBearerToken(r *http.Request) bool { token := strings.TrimPrefix(authHeader, "Bearer ") - return h.BearerTokens[token] + claims := &api.Claims{} + tkn, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { + return h.TokenEncryptionKey, nil + }) + if err != nil { + if err == jwt.ErrSignatureInvalid { + return false + } + return false + } + if !tkn.Valid { + return false + } + if claims.Username != "bparees" { + return false + } + return true } diff --git a/pkg/server/handler.go b/pkg/server/handler.go index b6be653..4d439e6 100644 --- a/pkg/server/handler.go +++ b/pkg/server/handler.go @@ -1,8 +1,10 @@ package server import ( + "github.com/gorilla/sessions" "github.com/openshift/wisdom/pkg/api" "github.com/openshift/wisdom/pkg/filters" + "golang.org/x/oauth2" ) type Handler struct { @@ -10,5 +12,11 @@ type Handler struct { DefaultModel string DefaultProvider string Models map[string]api.Model - BearerTokens map[string]bool + ClientID string + ClientSecret string + //SessionAuthKey string + //SessionEncryptionKey string + AuthConfig oauth2.Config + CookieStore *sessions.CookieStore + TokenEncryptionKey []byte } diff --git a/pkg/server/inference_handler.go b/pkg/server/inference_handler.go index a7c0f8f..2b7efd7 100644 --- a/pkg/server/inference_handler.go +++ b/pkg/server/inference_handler.go @@ -55,3 +55,22 @@ func (h *Handler) InferHandler(w http.ResponseWriter, r *http.Request) { } w.Write(buf.Bytes()) } + +/* +func (h *Handler) FeedbackHandler(w http.ResponseWriter, r *http.Request) { + var payload api.FeedbackPayload + err := json.NewDecoder(r.Body).Decode(&payload) + if err != nil { + http.Error(w, "Invalid request payload", http.StatusBadRequest) + return + } + + // Handle the feedback request here + + response := "Feedback received." + + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(response)) +} +*/