Skip to content

Commit

Permalink
implement github auth + bearer token validation
Browse files Browse the repository at this point in the history
  • Loading branch information
bparees committed Aug 17, 2023
1 parent 2b75bfd commit 482df40
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 8 deletions.
44 changes: 41 additions & 3 deletions cmd/wisdom/main.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package main

import (
"encoding/base64"
"fmt"
"net/http"
"os"
"reflect"

"github.com/gorilla/mux"
"github.com/gorilla/sessions"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/oauth2"
"gopkg.in/yaml.v2"

"github.com/openshift/wisdom/pkg/api"
Expand Down Expand Up @@ -91,13 +94,48 @@ func newStartServerCommand() *cobra.Command {
DefaultProvider: config.DefaultProvider,
DefaultModel: config.DefaultModelId,
Models: models,
ClientID: config.ServerConfig.ClientID,
ClientSecret: config.ServerConfig.ClientSecret,
}
h.BearerTokens = make(map[string]bool)
for _, t := range config.ServerConfig.BearerTokens {
h.BearerTokens[t] = true

tokenKey, err := base64.StdEncoding.DecodeString(config.ServerConfig.TokenEncryptionKey)
if err != nil {
return err
}
h.TokenEncryptionKey = tokenKey

h.AuthConfig = oauth2.Config{
ClientID: h.ClientID,
ClientSecret: h.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
},
RedirectURL: "https://localhost:8443/githubcallback",
Scopes: []string{"user:email"},
}

authKey, err := base64.StdEncoding.DecodeString(config.ServerConfig.SessionAuthKey)
if err != nil {
return err
}
encKey, err := base64.StdEncoding.DecodeString(config.ServerConfig.SessionEncryptionKey)
if err != nil {
return err
}

h.CookieStore = sessions.NewCookieStore(authKey, encKey)
h.CookieStore.Options = &sessions.Options{
Path: "/",
MaxAge: 0,
HttpOnly: false,
}

r.HandleFunc("/infer", h.InferHandler).Methods("POST")
//r.HandleFunc("/feedback", h.FeedbackHandler).Methods("POST")
r.HandleFunc("/login", h.HandleLogin)
r.HandleFunc("/githubcallback", h.HandleGithubCallback)
r.HandleFunc("/apitoken", h.HandleApiToken)

log.Infof("Default model provider: %s\n", h.DefaultProvider)
log.Infof("Default model: %s\n", h.DefaultModel)
Expand Down
23 changes: 20 additions & 3 deletions pkg/api/types.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package api

import (
"github.com/golang-jwt/jwt/v4"
)

type Model interface {
Invoke(ModelInput) (*ModelResponse, error)
}
Expand All @@ -25,6 +29,15 @@ type ModelResponse struct {
ErrorMessage string `json:"error"`
}

type Claims struct {
Username string `json:"username"`
jwt.RegisteredClaims
}

type APIToken struct {
Token string `json:"token"`
}

// FeedbackPayload represents the payload for the feedback endpoint.
type FeedbackPayload struct {
RequestID string `json:"requestId"`
Expand All @@ -45,9 +58,13 @@ type ModelConfig struct {
}

type ServerConfig struct {
TLSCertFile string `yaml:"tlsCertFile"`
TLSKeyFile string `yaml:"tlsKeyFile"`
BearerTokens []string `yaml:"bearerTokens"`
TLSCertFile string `yaml:"tlsCertFile"`
TLSKeyFile string `yaml:"tlsKeyFile"`
ClientID string `yaml:"clientId"`
ClientSecret string `yaml:"clientSecret"`
SessionAuthKey string `yaml:"sessionAuthKey"`
SessionEncryptionKey string `yaml:"sessionEncryptionKey"`
TokenEncryptionKey string `yaml:"tokenEncryptionKey"`
}

type Config struct {
Expand Down
160 changes: 159 additions & 1 deletion pkg/server/auth_handler.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,152 @@
package server

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"

"github.com/golang-jwt/jwt/v4"
"github.com/openshift/wisdom/pkg/api"

log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
)

func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) {
url := h.AuthConfig.AuthCodeURL("", oauth2.AccessTypeOffline)
http.Redirect(w, r, url, http.StatusFound)
}

func (h *Handler) HandleGithubCallback(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")

token, err := h.AuthConfig.Exchange(context.Background(), code)
if err != nil {
http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
return
}

client := h.AuthConfig.Client(context.Background(), token)
// Make an authenticated API request to get the user's information
resp, err := client.Get("https://api.github.com/user")
if err != nil {
http.Error(w, "Failed to get user information", http.StatusInternalServerError)
return
}
defer resp.Body.Close()

var user map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
http.Error(w, "Failed to decode response", http.StatusInternalServerError)
return
}

userID := user["login"] // Extract the user's GitHub ID

// Create a session for the user
session, err := h.CookieStore.Get(r, "wisdom-session")
if err != nil {
log.Errorf("failed to create session: %v\n", err)
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}

// Store the OAuth token in the session
session.Values["username"] = userID
err = session.Save(r, w)

if err != nil {
log.Errorf("failed to save session: %v\n", err)
http.Error(w, "Failed to save session", http.StatusInternalServerError)
return
}

log.Debugf("Stored session for user %s\n", userID)

http.Redirect(w, r, "/apitoken", http.StatusSeeOther)
fmt.Fprintf(w, "Successfully authenticated with GitHub!\n")
fmt.Fprintf(w, "GitHub ID: %s\n", userID)
//w.WriteHeader(http.StatusOK)

}

func (h *Handler) HandleApiToken(w http.ResponseWriter, r *http.Request) {

session, err := h.CookieStore.Get(r, "wisdom-session")
if err != nil {
url := r.URL
url.Path = "/login"
http.Redirect(w, r, url.String(), http.StatusFound)
return
}

username, ok := session.Values["username"].(string)
if !ok {
url := r.URL
url.Path = "/login"
http.Redirect(w, r, url.String(), http.StatusFound)
}

if username != "bparees" {
url := r.URL
url.Path = "/login"
http.Redirect(w, r, url.String(), http.StatusFound)
}

claims := &api.Claims{
Username: username,
RegisteredClaims: jwt.RegisteredClaims{
// In JWT, the expiry time is expressed as unix milliseconds
//ExpiresAt: jwt.NewNumericDate(expirationTime),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

tokenString, err := token.SignedString(h.TokenEncryptionKey)

apiToken := api.APIToken{
Token: tokenString,
}
buf := bytes.Buffer{}
err = json.NewEncoder(&buf).Encode(apiToken)

w.Header().Set("Content-Type", "text/json")
w.WriteHeader(http.StatusOK)
w.Write(buf.Bytes())
}

func (h *Handler) isAuthorized(w http.ResponseWriter, r *http.Request) bool {
// Retrieve the session for the user
session, err := h.CookieStore.Get(r, "wisdom-session")
if err != nil {
url := r.URL
url.Path = "/login"
http.Redirect(w, r, url.String(), http.StatusFound)
//http.Error(w, "Failed to get session", http.StatusUnauthorized)
return false
}

username, ok := session.Values["username"].(string)
if !ok {
url := r.URL
url.Path = "/login"
http.Redirect(w, r, url.String(), http.StatusFound)
//http.Error(w, "Session does not contain a valid username", http.StatusUnauthorized)
return false
}

// Check if the username is in the allow list
if username != "bparees" {
http.Error(w, "User is not authorized", http.StatusUnauthorized)
return false
}
fmt.Fprint(w, "Authenticated resource access!")
return true
}

func (h *Handler) hasValidBearerToken(r *http.Request) bool {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
Expand All @@ -17,5 +159,21 @@ func (h *Handler) hasValidBearerToken(r *http.Request) bool {

token := strings.TrimPrefix(authHeader, "Bearer ")

return h.BearerTokens[token]
claims := &api.Claims{}
tkn, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
return h.TokenEncryptionKey, nil
})
if err != nil {
if err == jwt.ErrSignatureInvalid {
return false
}
return false
}
if !tkn.Valid {
return false
}
if claims.Username != "bparees" {
return false
}
return true
}
10 changes: 9 additions & 1 deletion pkg/server/handler.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
package server

import (
"github.com/gorilla/sessions"
"github.com/openshift/wisdom/pkg/api"
"github.com/openshift/wisdom/pkg/filters"
"golang.org/x/oauth2"
)

type Handler struct {
Filter filters.Filter
DefaultModel string
DefaultProvider string
Models map[string]api.Model
BearerTokens map[string]bool
ClientID string
ClientSecret string
//SessionAuthKey string
//SessionEncryptionKey string
AuthConfig oauth2.Config
CookieStore *sessions.CookieStore
TokenEncryptionKey []byte
}
19 changes: 19 additions & 0 deletions pkg/server/inference_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,22 @@ func (h *Handler) InferHandler(w http.ResponseWriter, r *http.Request) {
}
w.Write(buf.Bytes())
}

/*
func (h *Handler) FeedbackHandler(w http.ResponseWriter, r *http.Request) {
var payload api.FeedbackPayload
err := json.NewDecoder(r.Body).Decode(&payload)
if err != nil {
http.Error(w, "Invalid request payload", http.StatusBadRequest)
return
}
// Handle the feedback request here
response := "Feedback received."
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(response))
}
*/

0 comments on commit 482df40

Please sign in to comment.