diff --git a/web/fiber/middleware/blacklist.go b/web/fiber/middleware/blacklist.go index d8ebbe8..c219557 100644 --- a/web/fiber/middleware/blacklist.go +++ b/web/fiber/middleware/blacklist.go @@ -2,37 +2,31 @@ package middleware import ( "fmt" - "net/http" "github.com/gofiber/fiber/v2" "github.com/fufuok/pkg/common" "github.com/fufuok/pkg/config" - "github.com/fufuok/pkg/logger/sampler" "github.com/fufuok/pkg/web/fiber/proxy" - "github.com/fufuok/pkg/web/fiber/response" ) // CheckBlacklist 接口黑名单检查 func CheckBlacklist(asAPI bool) fiber.Handler { errMsg := fmt.Sprintf("[ERROR] 非法访问(%s): ", config.AppName) return func(c *fiber.Ctx) error { - if len(config.Blacklist) > 0 { - clientIP := proxy.GetClientIP(c) - if _, ok := common.LookupIPNetsString(clientIP, config.Blacklist); ok { - msg := errMsg + clientIP - sampler.Info(). - Str("cip", c.IP()).Str("x_forwarded_for", c.Get(fiber.HeaderXForwardedFor)). - Str(proxy.HeaderXProxyClientIP, c.Get(proxy.HeaderXProxyClientIP)). - Str("method", c.Method()).Str("uri", c.OriginalURL()).Str("client_ip", clientIP). - Msg(msg) - if asAPI { - return response.APIException(c, http.StatusForbidden, msg, nil) - } else { - return response.TxtException(c, http.StatusForbidden, msg) - } - } + if BlacklistChecker(c) { + return responseForbidden(c, errMsg, asAPI) } return c.Next() } } + +// BlacklistChecker 是否存在于黑名单, true 是黑名单 (黑名单为空时: 放过, false) +func BlacklistChecker(c *fiber.Ctx) bool { + clientIP := proxy.GetClientIP(c) + if len(config.Blacklist) > 0 { + _, ok := common.LookupIPNetsString(clientIP, config.Blacklist) + return ok + } + return false +} diff --git a/web/fiber/middleware/whitelist.go b/web/fiber/middleware/whitelist.go index 38aaba3..3989075 100644 --- a/web/fiber/middleware/whitelist.go +++ b/web/fiber/middleware/whitelist.go @@ -13,26 +13,62 @@ import ( "github.com/fufuok/pkg/web/fiber/response" ) +type ForbiddenChecker = func(*fiber.Ctx) bool + // CheckWhitelist 接口白名单检查 func CheckWhitelist(asAPI bool) fiber.Handler { errMsg := fmt.Sprintf("[ERROR] 非法来访(%s): ", config.AppName) return func(c *fiber.Ctx) error { - if len(config.Whitelist) > 0 { - clientIP := proxy.GetClientIP(c) - if _, ok := common.LookupIPNetsString(clientIP, config.Whitelist); !ok { - msg := errMsg + clientIP - sampler.Info(). - Str("cip", c.IP()).Str("x_forwarded_for", c.Get(fiber.HeaderXForwardedFor)). - Str(proxy.HeaderXProxyClientIP, c.Get(proxy.HeaderXProxyClientIP)). - Str("method", c.Method()).Str("uri", c.OriginalURL()).Str("client_ip", clientIP). - Msg(msg) - if asAPI { - return response.APIException(c, http.StatusForbidden, msg, nil) - } else { - return response.TxtException(c, http.StatusForbidden, msg) - } - } + if !WhitelistChecker(c) { + return responseForbidden(c, errMsg, asAPI) + } + return c.Next() + } +} + +// CheckWhitelistOr 校验接口白名单或自定义检查器 +func CheckWhitelistOr(checker ForbiddenChecker, asAPI bool) fiber.Handler { + errMsg := fmt.Sprintf("[ERROR] 禁止来访(%s): ", config.AppName) + return func(c *fiber.Ctx) error { + if !WhitelistChecker(c) && !checker(c) { + return responseForbidden(c, errMsg, asAPI) + } + return c.Next() + } +} + +// CheckWhitelistAnd 同时校验接口白名单和自定义检查器 +func CheckWhitelistAnd(checker ForbiddenChecker, asAPI bool) fiber.Handler { + errMsg := fmt.Sprintf("[ERROR] 禁止访问(%s): ", config.AppName) + return func(c *fiber.Ctx) error { + if !WhitelistChecker(c) || !checker(c) { + return responseForbidden(c, errMsg, asAPI) } return c.Next() } } + +// WhitelistChecker 是否通过了白名单检查, true 是白名单 (白名单为空时: 通过, true) +func WhitelistChecker(c *fiber.Ctx) bool { + clientIP := proxy.GetClientIP(c) + if len(config.Whitelist) > 0 { + _, ok := common.LookupIPNetsString(clientIP, config.Whitelist) + return ok + } + return true +} + +func responseForbidden(c *fiber.Ctx, msg string, asAPI bool) error { + clientIP := proxy.GetClientIP(c) + msg += clientIP + sampler.Info(). + Str("cip", c.IP()).Str("x_forwarded_for", c.Get(fiber.HeaderXForwardedFor)). + Str(proxy.HeaderXProxyClientIP, c.Get(proxy.HeaderXProxyClientIP)). + Str("method", c.Method()).Str("uri", c.OriginalURL()).Str("client_ip", clientIP). + Msg(msg) + + if asAPI { + return response.APIException(c, http.StatusForbidden, msg, nil) + } + return response.TxtException(c, http.StatusForbidden, msg) +} diff --git a/web/gin/middleware/blacklist.go b/web/gin/middleware/blacklist.go index 94d6c66..4a91525 100644 --- a/web/gin/middleware/blacklist.go +++ b/web/gin/middleware/blacklist.go @@ -2,37 +2,31 @@ package middleware import ( "fmt" - "net/http" "github.com/gin-gonic/gin" "github.com/fufuok/pkg/common" "github.com/fufuok/pkg/config" - "github.com/fufuok/pkg/logger/sampler" - "github.com/fufuok/pkg/web/gin/response" ) // CheckBlacklist 接口黑名单检查 func CheckBlacklist(asAPI bool) gin.HandlerFunc { errMsg := fmt.Sprintf("[ERROR] 非法访问(%s): ", config.AppName) return func(c *gin.Context) { - if len(config.Blacklist) > 0 { - clientIP := c.ClientIP() - if _, ok := common.LookupIPNetsString(clientIP, config.Blacklist); ok { - msg := errMsg + clientIP - sampler.Info(). - Str("cip", clientIP).Str("x_forwarded_for", c.GetHeader("X-Forwarded-For")). - Str("method", c.Request.Method).Str("uri", c.Request.RequestURI). - Msg(msg) - if asAPI { - response.APIException(c, http.StatusForbidden, msg, nil) - } else { - response.TxtException(c, http.StatusForbidden, msg) - } - return - } + if BlacklistChecker(c) { + responseForbidden(c, errMsg, asAPI) + return } - c.Next() } } + +// BlacklistChecker 是否存在于黑名单, true 是黑名单 (黑名单为空时: 放过, false) +func BlacklistChecker(c *gin.Context) bool { + clientIP := c.ClientIP() + if len(config.Blacklist) > 0 { + _, ok := common.LookupIPNetsString(clientIP, config.Blacklist) + return ok + } + return false +} diff --git a/web/gin/middleware/whitelist.go b/web/gin/middleware/whitelist.go index 361eebd..3583a17 100644 --- a/web/gin/middleware/whitelist.go +++ b/web/gin/middleware/whitelist.go @@ -12,27 +12,65 @@ import ( "github.com/fufuok/pkg/web/gin/response" ) +type ForbiddenChecker = func(*gin.Context) bool + // CheckWhitelist 接口白名单检查 func CheckWhitelist(asAPI bool) gin.HandlerFunc { errMsg := fmt.Sprintf("[ERROR] 非法来访(%s): ", config.AppName) return func(c *gin.Context) { - if len(config.Whitelist) > 0 { - clientIP := c.ClientIP() - if _, ok := common.LookupIPNetsString(clientIP, config.Whitelist); !ok { - msg := errMsg + clientIP - sampler.Info(). - Str("cip", clientIP).Str("x_forwarded_for", c.GetHeader("X-Forwarded-For")). - Str("method", c.Request.Method).Str("uri", c.Request.RequestURI). - Msg(msg) - if asAPI { - response.APIException(c, http.StatusForbidden, msg, nil) - } else { - response.TxtException(c, http.StatusForbidden, msg) - } - return - } + if !WhitelistChecker(c) { + responseForbidden(c, errMsg, asAPI) + return } + c.Next() + } +} +// CheckWhitelistOr 校验接口白名单或自定义检查器 +func CheckWhitelistOr(checker ForbiddenChecker, asAPI bool) gin.HandlerFunc { + errMsg := fmt.Sprintf("[ERROR] 禁止来访(%s): ", config.AppName) + return func(c *gin.Context) { + if !WhitelistChecker(c) && !checker(c) { + responseForbidden(c, errMsg, asAPI) + return + } c.Next() } } + +// CheckWhitelistAnd 同时校验接口白名单和自定义检查器 +func CheckWhitelistAnd(checker ForbiddenChecker, asAPI bool) gin.HandlerFunc { + errMsg := fmt.Sprintf("[ERROR] 禁止访问(%s): ", config.AppName) + return func(c *gin.Context) { + if !WhitelistChecker(c) || !checker(c) { + responseForbidden(c, errMsg, asAPI) + return + } + c.Next() + } +} + +// WhitelistChecker 是否通过了白名单检查, true 是白名单 (白名单为空时: 通过, true) +func WhitelistChecker(c *gin.Context) bool { + clientIP := c.ClientIP() + if len(config.Whitelist) > 0 { + _, ok := common.LookupIPNetsString(clientIP, config.Whitelist) + return ok + } + return true +} + +func responseForbidden(c *gin.Context, msg string, asAPI bool) { + clientIP := c.ClientIP() + msg += clientIP + sampler.Info(). + Str("cip", clientIP).Str("x_forwarded_for", c.GetHeader("X-Forwarded-For")). + Str("method", c.Request.Method).Str("uri", c.Request.RequestURI). + Msg(msg) + + if asAPI { + response.APIException(c, http.StatusForbidden, msg, nil) + } else { + response.TxtException(c, http.StatusForbidden, msg) + } +}