Skip to content

Commit c7f5fbe

Browse files
committed
feat: added csrf feature
1 parent c6125a0 commit c7f5fbe

File tree

7 files changed

+38
-14
lines changed

7 files changed

+38
-14
lines changed

src/controllers/root.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package controller
22

33
import (
4+
"net/http"
5+
46
"github.com/Wong801/gin-api/src/api"
57
"github.com/Wong801/gin-api/src/db"
68
service "github.com/Wong801/gin-api/src/services"
79
"github.com/gin-gonic/gin"
10+
csrf "github.com/utrack/gin-csrf"
811
)
912

1013
type RootController struct {
@@ -41,3 +44,12 @@ func (rc RootController) Ping() func(c *gin.Context) {
4144
c.Next()
4245
}
4346
}
47+
48+
func (rc RootController) GetToken() func(c *gin.Context) {
49+
return func(c *gin.Context) {
50+
c.Set("status", http.StatusOK)
51+
c.Set("data", csrf.GetToken(c))
52+
53+
c.Next()
54+
}
55+
}

src/controllers/user.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
model "github.com/Wong801/gin-api/src/models"
88
service "github.com/Wong801/gin-api/src/services"
99
"github.com/gin-gonic/gin"
10-
csrf "github.com/utrack/gin-csrf"
1110
)
1211

1312
type UserController struct {
@@ -94,18 +93,25 @@ func (uc UserController) Login() gin.HandlerFunc {
9493
return
9594
}
9695
c.SetCookie("jwt", token.Jwt, token.MaxAge, "/", token.Domain, token.Secure, token.HttpOnly)
97-
c.SetCookie("X-CSRF-TOKEN", csrf.GetToken(c), token.MaxAge, "/", token.Domain, false, false)
9896
c.Set("data", map[string]string{
9997
"message": "Login Success",
10098
})
10199
c.Next()
102100
}
103101
}
104102

103+
func (uc UserController) CheckLogin() gin.HandlerFunc {
104+
return func(c *gin.Context) {
105+
c.Set("data", map[string]string{
106+
"message": "success",
107+
})
108+
c.Next()
109+
}
110+
}
111+
105112
func (uc UserController) Logout() gin.HandlerFunc {
106113
return func(c *gin.Context) {
107114
c.SetCookie("jwt", "", -1, "/", "", true, true)
108-
c.SetCookie("X-CSRF-TOKEN", "", -1, "/", "", false, false)
109115
c.Next()
110116
}
111117
}

src/models/user.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,9 @@ type UserChangePassword struct {
3636
func (u User) GetBase() *UserBase {
3737
return &u.UserBase
3838
}
39+
40+
func (ub UserBase) GetUser() *User {
41+
return &User{
42+
UserBase: ub,
43+
}
44+
}

src/routes/main.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package route
22

33
import (
4-
"errors"
54
"net/http"
65

76
"github.com/Wong801/gin-api/src/config"
@@ -28,8 +27,10 @@ func InitRoutes() handler {
2827
r.router.Use(csrf.Middleware(csrf.Options{
2928
Secret: config.GetEnv("CSRF_SECRET", "secret"),
3029
ErrorFunc: func(c *gin.Context) {
31-
c.Set("status", http.StatusBadRequest)
32-
c.Set("error", errors.New("CSRF token mismatch"))
30+
c.AbortWithStatusJSON(http.StatusBadRequest, &entity.HttpResponse{
31+
Success: false,
32+
Data: "CSRF token mismatch",
33+
})
3334
},
3435
}))
3536

src/routes/root.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ func (r handler) addRoot(rg *gin.RouterGroup) {
1010

1111
rg.GET("/stats", rc.GetStats())
1212
rg.GET("/ping", rc.Ping())
13+
rg.GET("/csrf", rc.GetToken())
1314
}

src/routes/user.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ func (r handler) addUsers(rg *gin.RouterGroup) {
1717

1818
userRoute.Use(m.Authenticate())
1919

20+
userRoute.POST("/check-login", userController.CheckLogin())
2021
userRoute.POST("/logout", userController.Logout())
2122
userRoute.PUT("/profile", userController.UpdateProfile())
2223
userRoute.PATCH("/change-password", userController.ChangePassword())

src/services/user.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,19 +74,16 @@ func (us UserService) GetUser() (int, *model.UserBase, error) {
7474
return http.StatusNotFound, nil, err
7575
}
7676

77-
return http.StatusFound, user.GetBase(), nil
77+
return http.StatusOK, user.GetBase(), nil
7878
}
7979

8080
func (us UserService) UpdateUser(id int, u *model.UserBase) (int, *model.UserBase, error) {
81-
var user model.User
8281
db.Open(us.DB)
8382

84-
us.DB.Database.Where("id = ?", id)
85-
86-
if err := us.DB.Database.Save(&user).Error; err != nil {
87-
return http.StatusInternalServerError, nil, err
83+
if err := us.DB.Database.Save(u.GetUser()).Error; err != nil {
84+
return http.StatusBadRequest, nil, err
8885
}
89-
return http.StatusOK, user.GetBase(), nil
86+
return http.StatusOK, u.GetUser().GetBase(), nil
9087
}
9188

9289
func (us UserService) Register(u *model.User) (int, error) {
@@ -109,7 +106,7 @@ func (us UserService) Login(u *model.UserLogin) (int, *entity.Token, error) {
109106
db.Open(us.DB)
110107

111108
if err := us.DB.Database.First(&user, "email = ? OR username = ?", u.Email, u.Username).Error; err != nil {
112-
return http.StatusNotFound, nil, errors.New("user not found")
109+
return http.StatusBadRequest, nil, errors.New("incorrect username or password")
113110
}
114111

115112
correctPassword := checkPasswordHash(u.Password, user.Password)

0 commit comments

Comments
 (0)