Skip to content

Commit

Permalink
Merge pull request #216 from actiontech/fix-issue2309-1
Browse files Browse the repository at this point in the history
Fix issue2309 1
  • Loading branch information
LordofAvernus authored Mar 22, 2024
2 parents 2cf3439 + 0a7d553 commit e86623c
Show file tree
Hide file tree
Showing 11 changed files with 230 additions and 605 deletions.
1 change: 0 additions & 1 deletion internal/apiserver/cmd/server/dms.pid

This file was deleted.

601 changes: 0 additions & 601 deletions internal/apiserver/cmd/server/logs/dms.log

This file was deleted.

12 changes: 11 additions & 1 deletion internal/apiserver/service/dms_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,17 @@ func (a *DMSController) GenAccessToken(c echo.Context) error {
if nil != err {
return NewErrResp(c, err, apiError.BadRequestErr)
}
reply := &dmsV1.GenAccessTokenReply{}

// get current user id
currentUid, err := jwt.GetUserUidStrFromContext(c)
if err != nil {
return NewErrResp(c, err, apiError.DMSServiceErr)
}

reply, err := a.DMS.GenAccessToken(c.Request().Context(), currentUid, req)
if nil != err {
return NewErrResp(c, err, apiError.DMSServiceErr)
}
return NewOkRespWithReply(c, reply)
}

Expand Down
5 changes: 3 additions & 2 deletions internal/apiserver/service/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,7 @@ func (s *APIServer) installMiddleware() error {
if strings.HasSuffix(c.Request().RequestURI, dmsV1.SessionRouterGroup) ||
strings.HasPrefix(c.Request().RequestURI, "/v1/dms/oauth2" /* TODO 使用统一方法skip */) ||
strings.HasPrefix(c.Request().RequestURI, "/v1/dms/personalization/logo") ||
strings.HasPrefix(c.Request().RequestURI, "/v1/dms/configurations/license" /* TODO 使用统一方法skip */) ||
!strings.HasPrefix(c.Request().RequestURI, dmsV1.CurrentGroupVersion) {
strings.HasPrefix(c.Request().RequestURI, "/v1/dms/configurations/license" /* TODO 使用统一方法skip */) {
logger.Debugf("skipper url jwt check: %v", c.Request().RequestURI)
return true
}
Expand All @@ -245,6 +244,8 @@ func (s *APIServer) installMiddleware() error {

s.echo.Use(dmsMiddleware.LicenseAdapter(s.DMSController.DMS.LicenseUsecase))

s.echo.Use(s.DMSController.DMS.AuthAccessTokenUseCase.CheckLatestAccessToken())

s.echo.Use(middleware.ProxyWithConfig(middleware.ProxyConfig{
Skipper: s.DMSController.DMS.DmsProxyUsecase.GetEchoProxySkipper(),
Balancer: s.DMSController.DMS.DmsProxyUsecase.GetEchoProxyBalancer(),
Expand Down
67 changes: 67 additions & 0 deletions internal/dms/biz/access_token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package biz

import (
"fmt"
"net/http"

jwtPkg "github.com/actiontech/dms/pkg/dms-common/api/jwt"
utilLog "github.com/actiontech/dms/pkg/dms-common/pkg/log"
"github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4"
)

const AccessTokenLogin = "access_token_login"

type AuthAccessTokenUsecase struct {
userUsecase *UserUsecase
log *utilLog.Helper
}

func NewAuthAccessTokenUsecase(log utilLog.Logger, usecase *UserUsecase) *AuthAccessTokenUsecase {
au := &AuthAccessTokenUsecase{
userUsecase: usecase,
log: utilLog.NewHelper(log, utilLog.WithMessageKey("biz.accesstoken")),
}
return au
}

func (au *AuthAccessTokenUsecase) CheckLatestAccessToken() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
user := c.Get("user")
// 获取token为空,代表该请求不需要校验token,例如:/v1/dms/oauth2
if user == nil {
return next(c)
}
token, ok := user.(*jwt.Token)
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "failed to convert user from jwt token")
}

claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "failed to convert token claims to jwt")
}

// 如果不存在JWTLoginType字段,代表是账号密码登录获取的token或者是扫描任务的凭证,不进行校验
loginType, ok := claims[jwtPkg.JWTLoginType]
if !ok {
return next(c)
}
if loginType != AccessTokenLogin {
return echo.NewHTTPError(http.StatusUnauthorized, "access token login type is error")
}
uidStr := fmt.Sprintf("%v", claims[jwtPkg.JWTUserId])
accessTokenInfo, err := au.userUsecase.repo.GetAccessTokenByUser(c.Request().Context(), uidStr)
if err != nil {
return err
}

if accessTokenInfo.Token != token.Raw {
return echo.NewHTTPError(http.StatusUnauthorized, "access token is not latest")
}

return next(c)
}
}
}
32 changes: 32 additions & 0 deletions internal/dms/biz/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"strconv"
"time"

pkgConst "github.com/actiontech/dms/internal/dms/pkg/constant"
Expand Down Expand Up @@ -83,6 +84,13 @@ type User struct {
Deleted bool
}

type AccessTokenInfo struct {
UID string
UserID uint
Token string
ExpiredTime time.Time
}

func initUsers() []*User {
return []*User{
{
Expand Down Expand Up @@ -162,6 +170,8 @@ type UserRepo interface {
GetUserGroupsByUser(ctx context.Context, userUid string) ([]*UserGroup, error)
GetOpPermissionsByUser(ctx context.Context, userUid string) ([]*OpPermission, error)
GetUserByThirdPartyUserID(ctx context.Context, thirdPartyUserUID string) (*User, error)
SaveAccessToken(ctx context.Context, accessTokenInfo *AccessTokenInfo) error
GetAccessTokenByUser(ctx context.Context, UserUid string) (*AccessTokenInfo, error)
}

type UserUsecase struct {
Expand Down Expand Up @@ -769,3 +779,25 @@ func (d *UserUsecase) GetBizUserWithNameByUids(ctx context.Context, uids []strin
}
return ret
}

func (d *UserUsecase) SaveAccessToken(ctx context.Context, userId string, token string, expiredTime time.Time) error {
userIdInt, err := strconv.Atoi(userId)
if err != nil {
return err
}
uid, err := pkgRand.GenStrUid()
if err != nil {
return err
}

tokenInfo := &AccessTokenInfo{UID: uid, UserID: uint(userIdInt), Token: token, ExpiredTime: expiredTime}
return d.repo.SaveAccessToken(ctx, tokenInfo)
}

func (d *UserUsecase) GetAccessTokenByUser(ctx context.Context, UserUid string) (*AccessTokenInfo, error) {
accessTokenInfo, err := d.repo.GetAccessTokenByUser(ctx, UserUid)
if err != nil {
return nil, err
}
return accessTokenInfo, nil
}
3 changes: 3 additions & 0 deletions internal/dms/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type DMSService struct {
ClusterUsecase *biz.ClusterUsecase
DataExportWorkflowUsecase *biz.DataExportWorkflowUsecase
DataMaskingUsecase *biz.DataMaskingUsecase
AuthAccessTokenUseCase *biz.AuthAccessTokenUsecase
log *utilLog.Helper
shutdownCallback func() error
}
Expand Down Expand Up @@ -111,6 +112,7 @@ func NewAndInitDMSService(logger utilLog.Logger, opts *conf.DMSOptions) (*DMSSer
workflowRepo := storage.NewWorkflowRepo(logger, st)
DataExportWorkflowUsecase := biz.NewDataExportWorkflowUsecase(logger, tx, workflowRepo, dataExportTaskRepo, dbServiceRepo, opPermissionVerifyUsecase, projectUsecase, dmsProxyTargetRepo, clusterUsecase, webhookConfigurationUsecase, userUsecase, fmt.Sprintf("%s:%d", opts.ReportHost, opts.APIServiceOpts.Port))
dataMasking, err := maskingBiz.NewDataMaskingUseCase(logger)
authAccessTokenUsecase := biz.NewAuthAccessTokenUsecase(logger, userUsecase)
if err != nil {
return nil, fmt.Errorf("failed to new data masking use case: %v", err)
}
Expand Down Expand Up @@ -147,6 +149,7 @@ func NewAndInitDMSService(logger utilLog.Logger, opts *conf.DMSOptions) (*DMSSer
ClusterUsecase: clusterUsecase,
DataExportWorkflowUsecase: DataExportWorkflowUsecase,
DataMaskingUsecase: dataMaskingUsecase,
AuthAccessTokenUseCase: authAccessTokenUsecase,
log: utilLog.NewHelper(logger, utilLog.WithMessageKey("dms.service")),
shutdownCallback: func() error {
if err := st.Close(); nil != err {
Expand Down
42 changes: 42 additions & 0 deletions internal/dms/service/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package service
import (
"context"
"fmt"
"strconv"
"strings"
"time"

dmsV1 "github.com/actiontech/dms/api/dms/service/v1"
"github.com/actiontech/dms/internal/dms/biz"
pkgConst "github.com/actiontech/dms/internal/dms/pkg/constant"

dmsCommonV1 "github.com/actiontech/dms/pkg/dms-common/api/dms/v1"
jwtPkg "github.com/actiontech/dms/pkg/dms-common/api/jwt"
"github.com/golang-jwt/jwt/v4"
)

func (d *DMSService) VerifyUserLogin(ctx context.Context, req *dmsV1.VerifyUserLoginReq) (reply *dmsV1.VerifyUserLoginReply, err error) {
Expand Down Expand Up @@ -505,6 +509,19 @@ func (d *DMSService) GetUser(ctx context.Context, req *dmsCommonV1.GetUserReq) (
}
dmsCommonUser.UserBindProjects = userBindProjects

// 获取用户access token
tokenInfo, err := d.UserUsecase.GetAccessTokenByUser(ctx, u.UID)
if err != nil {
return nil, fmt.Errorf("failed to get user access token: %v", err)
}
accessToken := dmsCommonV1.AccessTokenInfo{}
accessToken.AccessToken = tokenInfo.Token
accessToken.ExpiredTime = tokenInfo.ExpiredTime.Format("2006-01-02T15:04:05-07:00")
if tokenInfo.ExpiredTime.Before(time.Now()) {
accessToken.IsExpired = true
}
dmsCommonUser.AccessTokenInfo = accessToken

reply = &dmsCommonV1.GetUserReply{
Data: dmsCommonUser,
}
Expand All @@ -513,6 +530,31 @@ func (d *DMSService) GetUser(ctx context.Context, req *dmsCommonV1.GetUserReq) (
return reply, nil
}

func (d *DMSService) GenAccessToken(ctx context.Context, currentUserUid string, req *dmsCommonV1.GenAccessToken) (reply *dmsCommonV1.GenAccessTokenReply, err error) {
days, err := strconv.ParseUint(req.ExpirationDays, 10, 64)
if err != nil {
return nil, err
}

expiredTime := time.Now().Add(time.Duration(days) * 24 * time.Hour)
token, err := jwtPkg.GenJwtTokenWithExpirationTime(jwt.NewNumericDate(expiredTime), jwtPkg.WithUserId(currentUserUid), jwtPkg.WithAccessTokenMark(biz.AccessTokenLogin))
if err != nil {
return nil, fmt.Errorf("gen access token failed: %v", err)
}
if err := d.UserUsecase.SaveAccessToken(ctx, currentUserUid, token, expiredTime); err != nil {
return nil, fmt.Errorf("save access token failed: %v", err)
}

reply = &dmsCommonV1.GenAccessTokenReply{
Data: &dmsCommonV1.AccessTokenInfo{
AccessToken: token,
ExpiredTime: expiredTime.Format("2006-01-02T15:04:05-07:00"),
},
}

return reply, nil
}

func convertBizOpPermission(opPermissionUid string) (apiOpPermissionTyp dmsCommonV1.OpPermissionType, err error) {
switch opPermissionUid {
case pkgConst.UIDOfOpPermissionCreateWorkflow:
Expand Down
10 changes: 10 additions & 0 deletions internal/dms/storage/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ var AutoMigrateList = []interface{}{
WorkflowStep{},
DataExportTask{},
DataExportTaskRecord{},
UserAccessToken{},
}

type Model struct {
Expand Down Expand Up @@ -140,6 +141,15 @@ type OpPermission struct {
RangeType string `json:"range_type" gorm:"size:255;column:range_type"`
}

type UserAccessToken struct {
Model
Token string `json:"token" gorm:"size:255"`
ExpiredTime time.Time `json:"expired_time" example:"2018-10-21T16:40:23+08:00"`
UserID uint `json:"user_id" gorm:"size:32;index:user_id,unique"`

User *User `json:"user" gorm:"foreignkey:user_id"`
}

type DMSConfig struct {
Model
NeedInitOpPermissions bool `json:"need_init_op_permissions" gorm:"column:need_init_op_permissions"`
Expand Down
42 changes: 42 additions & 0 deletions internal/dms/storage/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package storage

import (
"context"
"errors"
"fmt"

"github.com/actiontech/dms/internal/dms/biz"
Expand All @@ -12,6 +13,7 @@ import (
utilLog "github.com/actiontech/dms/pkg/dms-common/pkg/log"

"gorm.io/gorm"
"gorm.io/gorm/clause"
)

var _ biz.UserRepo = (*UserRepo)(nil)
Expand Down Expand Up @@ -331,3 +333,43 @@ func (d *UserRepo) GetUserByThirdPartyUserID(ctx context.Context, thirdPartyUser
}
return ret, nil
}

func (d *UserRepo) SaveAccessToken(ctx context.Context, tokenInfo *biz.AccessTokenInfo) error {
userAccessToekn := &model.UserAccessToken{
Model: model.Model{
UID: tokenInfo.UID,
},
UserID: tokenInfo.UserID,
Token: tokenInfo.Token,
ExpiredTime: tokenInfo.ExpiredTime,
}

tx := d.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "user_id"}},
DoUpdates: clause.Assignments(map[string]interface{}{"token": tokenInfo.Token, "expired_time": tokenInfo.ExpiredTime}),
}).Create(userAccessToekn)

if tx.Error != nil {
return fmt.Errorf("failed to save access token: %v", tx.Error)
}

return nil
}

func (d *UserRepo) GetAccessTokenByUser(ctx context.Context, userUid string) (*biz.AccessTokenInfo, error) {
var userToken *model.UserAccessToken
if err := transaction(d.log, ctx, d.db, func(tx *gorm.DB) error {
if err := tx.First(&userToken, "user_id = ?", userUid).Error; err != nil {
// 未找到记录返回空,不影响获取用户信息的功能
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
return fmt.Errorf("failed to get user access token: %v", err)
}
return nil
}); err != nil {
return nil, err
}

return &biz.AccessTokenInfo{Token: userToken.Token, ExpiredTime: userToken.ExpiredTime}, nil
}
20 changes: 20 additions & 0 deletions pkg/dms-common/api/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
JWTUsername = "name"
JWTExpiredTime = "exp"
JWTAuditPlanName = "apn"
JWTLoginType = "loginType"
)

func GenJwtToken(customClaims ...CustomClaimFunc) (tokenStr string, err error) {
Expand All @@ -31,6 +32,19 @@ func GenJwtToken(customClaims ...CustomClaimFunc) (tokenStr string, err error) {
JWTExpiredTime: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
}

return genJwtToken(mapClaims, customClaims...)
}

func GenJwtTokenWithExpirationTime(expiredTime *jwt.NumericDate, customClaims ...CustomClaimFunc) (tokenStr string, err error) {
var mapClaims = jwt.MapClaims{
"iss": "actiontech dms",
JWTExpiredTime: expiredTime,
}

return genJwtToken(mapClaims, customClaims...)
}

func genJwtToken(mapClaims jwt.MapClaims, customClaims ...CustomClaimFunc) (tokenStr string, err error) {
for _, claimFunc := range customClaims {
claimFunc(mapClaims)
}
Expand Down Expand Up @@ -69,6 +83,12 @@ func WithExpiredTime(duration time.Duration) CustomClaimFunc {
}
}

func WithAccessTokenMark(loginType string) CustomClaimFunc {
return func(claims jwt.MapClaims) {
claims[JWTLoginType] = loginType
}
}

func ParseUidFromJwtTokenStr(tokenStr string) (uid string, err error) {
token, err := parseJwtTokenStr(tokenStr)
if err != nil {
Expand Down

0 comments on commit e86623c

Please sign in to comment.