From 402fe66816b95a2b42b0b70f03031c9a00334d44 Mon Sep 17 00:00:00 2001 From: Harsh Singhvi Date: Tue, 21 Nov 2023 00:41:16 +0530 Subject: [PATCH] - Api Billing - added user entity - user specific access --- controllers/controllers.go | 304 ++++++++++++++++++++++++++++++++++--- database/database.go | 78 ++++++---- main.go | 24 ++- middlewares/middlewares.go | 53 ++++--- models/auth.go | 10 +- models/billing.go | 24 +++ models/roles/roles.go | 9 ++ models/todo.go | 2 +- models/user.go | 10 ++ utils/error.go | 1 + 10 files changed, 443 insertions(+), 72 deletions(-) create mode 100644 models/billing.go create mode 100644 models/user.go diff --git a/controllers/controllers.go b/controllers/controllers.go index 15d1efd..0482f43 100644 --- a/controllers/controllers.go +++ b/controllers/controllers.go @@ -2,14 +2,16 @@ package controllers import ( "fmt" - "github.com/gin-gonic/gin" - guuid "github.com/google/uuid" "harshsinghvi/golang-postgres-kubernetes/database" "harshsinghvi/golang-postgres-kubernetes/models" "harshsinghvi/golang-postgres-kubernetes/models/roles" "harshsinghvi/golang-postgres-kubernetes/utils" + "log" "net/http" "time" + + "github.com/gin-gonic/gin" + guuid "github.com/google/uuid" ) func GetAllTodos(c *gin.Context) { @@ -144,19 +146,35 @@ func DeleteTodo(c *gin.Context) { } func CreateNewToken(c *gin.Context) { + var userId string + userId = c.Param("id") + + if userId == "" { + userIdFromToken, _ := c.Get("user_id") + userId = userIdFromToken.(string) + } + + if userId != "admin" { + if count, _ := database.Connection.Model(&models.User{}).Where("id = ?", userId).Count(); count == 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "status": http.StatusBadRequest, + "message": "user_id, not found", + }) + return + } + } + id := guuid.New().String() token := utils.GenerateToken(id) - var accessToken models.AccessToken - c.BindJSON(&accessToken) - - insertError := database.Connection.Insert(&models.AccessToken{ + accessToken := models.AccessToken{ ID: id, Token: token, - Email: accessToken.Email, + UserID: userId, Roles: []string{roles.Read, roles.Write}, Expiry: time.Now().AddDate(0, 0, 10), CreatedAt: time.Now(), - }) + } + insertError := database.Connection.Insert(&accessToken) if insertError != nil { utils.InternalServerError(c, "Error while inserting new token into db, Reason:", insertError) @@ -166,12 +184,20 @@ func CreateNewToken(c *gin.Context) { c.JSON(http.StatusCreated, gin.H{ "status": http.StatusCreated, "message": "Token created Successfully", + "data": accessToken, "token": token, }) } func GetTokens(c *gin.Context) { - email := c.Param("email") + var userId string + userId = c.Param("id") + + if userId == "" { + userIdFromToken, _ := c.Get("user_id") + userId = userIdFromToken.(string) + } + var pag models.Pagination var err error var accessTokens []models.AccessToken @@ -180,8 +206,8 @@ func GetTokens(c *gin.Context) { querry := database.Connection.Model(&accessTokens).Order("created_at DESC") - if email != "admin" { - querry = querry.Where("email = ?", email) + if userId != "admin" { + querry = querry.Where("user_id = ?", userId) } if pag.TotalRecords, err = querry.Count(); err != nil { @@ -193,41 +219,68 @@ func GetTokens(c *gin.Context) { querry = querry.Limit(10).Offset(10 * (pag.CurrentPage)) } - if err := querry.Select(); err != nil { + if err = querry.Select(); err != nil { utils.InternalServerError(c, "Error while getting Tokens, Reason:", err) return } - c.JSON(http.StatusOK, gin.H{ "status": http.StatusOK, - "message": fmt.Sprintf("All Tokens by %s", email), + "message": fmt.Sprintf("All Tokens by %s", userId), "data": accessTokens, "pagination": pag.Validate(), }) } func UpdateToken(c *gin.Context) { - id := c.Param("id") + + var userId string + tokenId := c.Param("token-id") + userId = c.Param("id") + + if userId == "" { + userIdFromToken, _ := c.Get("user_id") + userId = userIdFromToken.(string) + } + var accessToken models.AccessToken + c.Bind(&accessToken) if accessToken.Roles == nil { c.JSON(http.StatusBadRequest, gin.H{ "status": http.StatusBadRequest, - "message": "Token not Udpated include data in req body", + "message": "Token roles not include to update data in req body", }) } + if c.Param("id") == "" { + for _, role := range accessToken.Roles { + if role == roles.Admin { + c.JSON(http.StatusUnauthorized, gin.H{ + "status": http.StatusUnauthorized, + "message": "invalid role admin", + }) + return + } + } + } + querry := database.Connection.Model(&models.AccessToken{}).Set("roles = ?", accessToken.Roles).Set("updated_at = ?", time.Now()) + querry = querry.Where("id = ?", tokenId) - res, err := querry.Where("id = ?", id).Update() + if userId != "admin" { + querry = querry.Where("user_id = ?", userId) + } + + res, err := querry.Update() if err != nil { utils.InternalServerError(c, "Error while editing token, Reason:", err) } + if res.RowsAffected() == 0 { c.JSON(http.StatusNotFound, gin.H{ "status": http.StatusNotFound, - "message": "Token not found", + "message": "Token/user not found or unauthorised request", }) return } @@ -237,3 +290,218 @@ func UpdateToken(c *gin.Context) { "message": "Token Edited Successfully", }) } + +func CreateNewUser(c *gin.Context) { + userId := guuid.New().String() + tokenId := guuid.New().String() + token := utils.GenerateToken(tokenId) + + user := models.User{} + + c.Bind(&user) + + count, err := database.Connection.Model(&models.User{}).Where("email = ?", user.Email).Count() + log.Println(count) + + if err != nil { + utils.InternalServerError(c, "Error while getting tokens, Reason:", err) + return + } + + if count != 0 { + + c.JSON(http.StatusBadRequest, gin.H{ + "status": http.StatusBadRequest, + "message": "Email already exists", + }) + return + } + + user = models.User{ + ID: userId, + Email: user.Email, + CreatedAt: time.Now(), + } + + accessToken := models.AccessToken{ + ID: tokenId, + Token: token, + UserID: userId, + Roles: []string{roles.Read, roles.Write}, + Expiry: time.Now().AddDate(0, 0, 10), + CreatedAt: time.Now(), + } + + if insertError := database.Connection.Insert(&user); insertError != nil { + utils.InternalServerError(c, "Error while inserting new user into db, Reason:", insertError) + return + } + + if insertError := database.Connection.Insert(&accessToken); insertError != nil { + utils.InternalServerError(c, "Error while inserting new token into db, Reason:", insertError) + return + } + + c.JSON(http.StatusCreated, gin.H{ + "status": http.StatusCreated, + "message": "User And Token created Successfully", + "user": user, + "token": accessToken, + }) +} + +func GetUserID(c *gin.Context) { + userId, _ := c.Get("user_id") + var accessTokens []models.AccessToken + database.Connection.Model(&accessTokens).Where("user_id = ?", userId.(string)).Order("created_at DESC").Select() + + c.JSON(http.StatusOK, gin.H{ + "status": http.StatusOK, + "user_id": userId, + "access_tokens": accessTokens, + }) +} + +func CreateBill(c *gin.Context) { + var userId string + userId = c.Param("id") + + if userId == "" { + userIdFromToken, _ := c.Get("user_id") + userId = userIdFromToken.(string) + } + + var err error + + count, err := database.Connection.Model(&models.User{}).Where("id = ?", userId).Count() + if err != nil { + utils.InternalServerError(c, "Error counting users for bill, Reason:", err) + return + } + if count == 0 { + c.JSON(http.StatusBadRequest, gin.H{ + "status": http.StatusBadRequest, + "message": "user not found", + }) + c.Abort() + return + } + + var bill = models.Bill{ + ID: guuid.New().String(), + APIUsage: 0, + BillValue: 0, + Sattled: false, + UserID: userId, + CreatedAt: time.Now(), + } + + if insertError := database.Connection.Insert(&bill); insertError != nil { + utils.InternalServerError(c, "Error inserting bill, Reason:", insertError) + return + } + + var accessTokens []models.AccessToken + + if err = database.Connection.Model(&accessTokens).Where("user_id = ?", userId).Select(); err != nil { + log.Printf("no Token found for the given user %s", err) + c.JSON(http.StatusBadRequest, gin.H{ + "status": http.StatusBadRequest, + "message": "no Token found for the given user", + }) + c.Abort() + return + } + + var accessLogs []models.AccessLog + + var tokensStr string + + for index, accessToken := range accessTokens { + if index == 0 { + tokensStr = fmt.Sprintf("'%s'", accessToken.Token) + } + tokensStr = fmt.Sprintf("%s,'%s'", tokensStr, accessToken.Token) + } + + querry := database.Connection.Model(&accessLogs) + querry = querry.Set("bill_id = ?", bill.ID) + querry = querry.Set("billed = true") + + querry = querry.Where(fmt.Sprintf("token in (%s)", tokensStr)) + querry = querry.Where("status_code between 100 and 499") + querry = querry.Where("billed = false") + + res, updateErr := querry.Update() + + if updateErr != nil { + log.Printf("Error While fetching access logs %s", updateErr) + c.JSON(http.StatusBadRequest, gin.H{ + "status": http.StatusBadRequest, + "message": "Error While fetching access logs", + }) + c.Abort() + return + } + + // if res.RowsAffected() == 0 { + // TODO: Delete bill and return error + // } + + bill.CalculateBillValue(res.RowsAffected()) + + querry = database.Connection.Model(&bill).WherePK() + res, updateErr = querry.Update() + + if updateErr != nil || res.RowsAffected() == 0 { + log.Printf("Error While updating bill %s", err) + c.JSON(http.StatusBadRequest, gin.H{ + "status": http.StatusBadRequest, + "message": "Error While updating bill", + }) + c.Abort() + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": http.StatusOK, + "message": "Billing Done", + "data": bill, + }) +} + +func GetBills(c *gin.Context) { + var userId string + userId = c.Param("id") + + if userId == "" { + userIdFromToken, _ := c.Get("user_id") + userId = userIdFromToken.(string) + } + + var err error + var total float32 = 0 + var bills []models.Bill + + err = database.Connection.Model(&bills).Where("user_id = ?", userId).Order("created_at DESC").Select() + + if err != nil { + log.Printf("Error while getting bills, Reason: %s", err) + c.JSON(http.StatusBadRequest, gin.H{ + "status": http.StatusBadRequest, + "message": "Error While getting bills or not bills found", + }) + } + + for _, bill := range bills { + if !bill.Sattled { + total += bill.BillValue + } + } + c.JSON(http.StatusOK, gin.H{ + "status": http.StatusOK, + "message": "All Bills", + "data": bills, + "total": total, + }) +} diff --git a/database/database.go b/database/database.go index 7af4c37..0988898 100644 --- a/database/database.go +++ b/database/database.go @@ -1,8 +1,10 @@ package database import ( + "fmt" "github.com/go-pg/pg/v9" orm "github.com/go-pg/pg/v9/orm" + guuid "github.com/google/uuid" "harshsinghvi/golang-postgres-kubernetes/models" "harshsinghvi/golang-postgres-kubernetes/models/roles" "harshsinghvi/golang-postgres-kubernetes/utils" @@ -60,41 +62,39 @@ func Connect() *pg.DB { return Connection } - -func CreateTables() error { +func createTablesAndIndexes(tableName string, model interface{}, indexFields string) error { opts := &orm.CreateTableOptions{ IfNotExists: true, } - if createError := Connection.CreateTable(&models.Todo{}, opts); createError != nil { - log.Printf("Error while creating todo table, Reason: %v\n", createError) - return createError - } - if _, err := Connection.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS go_index_todos ON todos(completed, created_at);`); err != nil { - log.Println(err.Error()) - return err - } - if createError := Connection.CreateTable(&models.AccessToken{}, opts); createError != nil { - log.Printf("Error while creating access_tokens table, Reason: %v\n", createError) - return createError - } - if _, err := Connection.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS go_index_access_tokens ON access_tokens(created_at, token, email);`); err != nil { - log.Println(err.Error()) - return err - } - if createError := Connection.CreateTable(&models.AccessLog{}, opts); createError != nil { - log.Printf("Error while creating access_logs table, Reason: %v\n", createError) + + createIndexQuerry := fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS go_index_%s ON %s(%s);", tableName, tableName, indexFields) + + if createError := Connection.CreateTable(model, opts); createError != nil { + log.Printf("Error while creating %s table, Reason: %v\n", tableName, createError) return createError } - if _, err := Connection.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS go_index_access_logs ON access_logs(token, path, method, response_time, status_code, server_hostname, created_at);`); err != nil { - log.Println(err.Error()) + + if _, err := Connection.Exec(createIndexQuerry); err != nil { + log.Printf("Error while index of %s table, Reason: %v\n", tableName, err) return err } - log.Printf("Todo table and indexes created") - checkAndCreateAdminToken() + log.Printf("INFO: %s table and its indexes created", tableName) return nil } +func CreateTables() { + createTablesAndIndexes("todos", &models.Todo{}, "completed, created_at") + createTablesAndIndexes("access_tokens", &models.AccessToken{}, "created_at, token, expiry, user_id") + // TODO: fails + createTablesAndIndexes("access_logs", &models.AccessLog{}, "token, path, method, response_time, status_code, server_hostname, created_at, bill_id, billed") + createTablesAndIndexes("users", &models.User{}, "email, created_at") + createTablesAndIndexes("bills", &models.Bill{}, "sattled, user_id, created_at") + checkAndCreateAdminUser() + checkAndCreateAdminToken() + log.Printf("all tables and indexes created") +} + func checkAndCreateAdminToken() { var accessToken models.AccessToken querry := Connection.Model(&accessToken).Where("id = ?", "admin") @@ -107,12 +107,12 @@ func checkAndCreateAdminToken() { } id := "admin" - token := utils.GenerateToken(id) + token := utils.GenerateToken(guuid.New().String()) insertError := Connection.Insert(&models.AccessToken{ ID: id, Token: token, - Email: id, + UserID: id, Expiry: time.Now().AddDate(99, 0, 00), CreatedAt: time.Now(), Roles: []string{roles.Admin}, @@ -124,3 +124,29 @@ func checkAndCreateAdminToken() { log.Printf("Admin Token created") } + +func checkAndCreateAdminUser() { + var user models.User + querry := Connection.Model(&user).Where("id = ?", "admin") + count, err := querry.Count() + if err != nil { + log.Println("Error in getting users count") + } + if count != 0 { + return + } + + id := "admin" + + insertError := Connection.Insert(&models.User{ + ID: id, + Email: id, + CreatedAt: time.Now(), + }) + + if insertError != nil { + log.Printf("Error while inserting new user into db, Reason: %v\n", insertError) + } + + log.Printf("Admin user created") +} diff --git a/main.go b/main.go index cbad8db..d5e3501 100644 --- a/main.go +++ b/main.go @@ -52,10 +52,28 @@ func main() { v2 := api.Group("/v2") { - v2.POST("/token", middlewares.AuthMiddleware([]string{roles.Admin}), controllers.CreateNewToken) - v2.GET("/token/:email", middlewares.AuthMiddleware([]string{roles.Admin}), controllers.GetTokens) - v2.PUT("/token/:id", middlewares.AuthMiddleware([]string{roles.Admin}), controllers.UpdateToken) + v2.POST("/user", controllers.CreateNewUser) + // Users Endpoints + v2.GET("/user", middlewares.AuthMiddleware([]string{roles.Any}), controllers.GetUserID) + v2.POST("/user/token", middlewares.AuthMiddleware([]string{roles.Write}), controllers.CreateNewToken) + v2.GET("/user/token", middlewares.AuthMiddleware([]string{roles.Read}), controllers.GetTokens) + v2.PUT("/user/token/:token-id", middlewares.AuthMiddleware([]string{roles.Write}), controllers.UpdateToken) + v2.POST("/user/bill", middlewares.AuthMiddleware([]string{roles.Any}), controllers.CreateBill) + v2.GET("/user/bill", middlewares.AuthMiddleware([]string{roles.Any}), controllers.GetBills) + + // Admin Endpoints + v2.POST("/user/:id/token", middlewares.AuthMiddleware([]string{roles.Admin}), controllers.CreateNewToken) + v2.GET("/user/:id/token", middlewares.AuthMiddleware([]string{roles.Admin}), controllers.GetTokens) + v2.PUT("/user/:user-id/token/:token-id", middlewares.AuthMiddleware([]string{roles.Admin}), controllers.UpdateToken) + v2.POST("/user/:id/bill", middlewares.AuthMiddleware([]string{roles.Admin}), controllers.CreateBill) + v2.GET("/user/:id/bill", middlewares.AuthMiddleware([]string{roles.Admin}), controllers.GetBills) + + // TODO Soft delete + // Delete Token + // delete user + + // Business Logic v2.GET("/todo/", middlewares.AuthMiddleware([]string{roles.Admin, roles.Read}), controllers.GetAllTodos) v2.GET("/todo/:id", middlewares.AuthMiddleware([]string{roles.Admin, roles.Read, roles.ReadOne}), controllers.GetSingleTodo) v2.POST("/todo/", middlewares.AuthMiddleware([]string{roles.Admin, roles.Write, roles.WriteNewOnly}), controllers.CreateTodo) diff --git a/middlewares/middlewares.go b/middlewares/middlewares.go index 3f54c2f..4e28cce 100644 --- a/middlewares/middlewares.go +++ b/middlewares/middlewares.go @@ -20,11 +20,13 @@ func AuthMiddleware(requiredRoles []string) gin.HandlerFunc { var err error token := c.GetHeader("token") reqId := guuid.New().String() + c.Set("requestId", reqId) c.Writer.Header().Set("X-Request-Id", reqId) if token == "" { utils.UnauthorizedResponse(c) + logReqToDb(reqId, accessToken.Token, c, reqStart) return } @@ -33,48 +35,57 @@ func AuthMiddleware(requiredRoles []string) gin.HandlerFunc { if count, err = querry.Count(); err != nil { utils.InternalServerError(c, "Error while getting tokens, Reason:", err) c.Abort() + logReqToDb(reqId, accessToken.Token, c, reqStart) return } if count == 0 { utils.UnauthorizedResponse(c) + logReqToDb(reqId, accessToken.Token, c, reqStart) return } if err = querry.Select(); err != nil { utils.InternalServerError(c, "Error while getting all todos, Reason:", err) - c.Abort() + logReqToDb(reqId, accessToken.Token, c, reqStart) return } if time.Until(accessToken.Expiry).Seconds() <= 0 || !roles.CheckRoles(requiredRoles, accessToken.Roles) { utils.UnauthorizedResponse(c) + logReqToDb(reqId, accessToken.Token, c, reqStart) return } + c.Set("token", token) + c.Set("user_id", accessToken.UserID) c.Next() + logReqToDb(reqId, accessToken.Token, c, reqStart) + } +} - var hostname string - if hostname, err = os.Hostname(); err != nil { - log.Printf("Error loading system hostname %v\n", err) - } +func logReqToDb(reqId string, accessToken string, c *gin.Context, reqStart time.Time) { + var err error + var hostname string + if hostname, err = os.Hostname(); err != nil { + log.Printf("Error loading system hostname %v\n", err) + } - insertError := database.Connection.Insert(&models.AccessLog{ - ID: reqId, - Token: accessToken.Token, - Path: c.Request.URL.Path, - ServerHostname: hostname, - ResponseSize: c.Writer.Size(), - StatusCode: c.Writer.Status(), - ClientIP: c.ClientIP(), - Method: c.Request.Method, - ResponseTime: time.Since(reqStart).Milliseconds(), - CreatedAt: time.Now(), - }) - if insertError != nil { - log.Println("Error loging request in db.") - return - } + insertError := database.Connection.Insert(&models.AccessLog{ + ID: reqId, + Token: accessToken, + Path: c.Request.URL.Path, + ServerHostname: hostname, + ResponseSize: c.Writer.Size(), + StatusCode: c.Writer.Status(), + ClientIP: c.ClientIP(), + Method: c.Request.Method, + ResponseTime: time.Since(reqStart).Milliseconds(), + CreatedAt: time.Now(), + }) + if insertError != nil { + log.Println("Error loging request in db.") + return } } diff --git a/models/auth.go b/models/auth.go index f2a602b..b81691c 100644 --- a/models/auth.go +++ b/models/auth.go @@ -3,18 +3,19 @@ package models import "time" type AccessToken struct { - ID string `json:"id"` - Email string `json:"email"` + ID string `json:"id"` + // Email string `json:"email"` Token string `json:"token"` Roles []string `json:"roles"` // read, read-one, write, write-new-only, write-update-only Expiry time.Time `json:"expiry"` + UserID string `json:"user_id"` UpdatedAt time.Time `json:"updated_at"` CreatedAt time.Time `json:"created_at"` } type AccessLog struct { ID string `json:"id"` - Token string `json:"token"` + Token string `json:"token"` // TODO: Change this to TokenID Path string `json:"path"` ClientIP string `json:"client_ip"` Method string `json:"method"` @@ -22,5 +23,8 @@ type AccessLog struct { ResponseSize int `json:"response_size"` StatusCode int `json:"status_code"` ServerHostname string `json:"server_hostname"` + BillID string `json:"bill_id"` + Billed bool `json:"billed" pg:",use_zero"` CreatedAt time.Time `json:"created_at"` + // UpdatedAt time.Time `json:"updated_at"` // TODO: Latter } diff --git a/models/billing.go b/models/billing.go new file mode 100644 index 0000000..d54ac3c --- /dev/null +++ b/models/billing.go @@ -0,0 +1,24 @@ +package models + +import "time" + +const RatePerApiCall float32 = 0.1 +const Currency = "INR" + +type Bill struct { + ID string `json:"id"` + APIUsage int `json:"api_usage"` + BillValue float32 `json:"bill_value"` + Sattled bool `json:"sattled"` + UserID string `json:"user_id"` + Currency string `json:"currency"` + CreatedAt time.Time `json:"created_at"` +} + +func (bill *Bill) CalculateBillValue(usage int) { + bill.APIUsage = usage + bill.BillValue = RatePerApiCall * float32(usage) + if bill.Currency == "" { + bill.Currency = Currency + } +} diff --git a/models/roles/roles.go b/models/roles/roles.go index 3f18d5b..b9f4741 100644 --- a/models/roles/roles.go +++ b/models/roles/roles.go @@ -1,6 +1,7 @@ package roles const ( + Any = "any" Admin = "admin" Read = "read" ReadOne = "read_one" @@ -11,7 +12,15 @@ const ( func CheckRoles(requiredRoles []string, grantedRoles []string) bool { for _, requiredRole := range requiredRoles { + + if requiredRole == Any && len(grantedRoles) != 0 { + return true + } + for _, grantedRole := range grantedRoles { + if grantedRole == Admin { + return true + } if grantedRole == requiredRole { return true } diff --git a/models/todo.go b/models/todo.go index f70b307..9ab38dc 100644 --- a/models/todo.go +++ b/models/todo.go @@ -5,7 +5,7 @@ import "time" type Todo struct { ID string `json:"id"` Text string `json:"text"` - Completed bool `json:"completed"` + Completed bool `json:"completed" pg:",use_zero"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } diff --git a/models/user.go b/models/user.go new file mode 100644 index 0000000..d59925e --- /dev/null +++ b/models/user.go @@ -0,0 +1,10 @@ +package models + +import "time" + +type User struct { + ID string `json:"id"` + Email string `json:"email"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/utils/error.go b/utils/error.go index a040a44..70602ed 100644 --- a/utils/error.go +++ b/utils/error.go @@ -12,6 +12,7 @@ func InternalServerError(c *gin.Context, msg string, err error) { "status": http.StatusInternalServerError, "message": "Something went wrong", }) + c.Abort() } func UnauthorizedResponse(c *gin.Context) {