From d08c65ef7f30b3efa90719d2bfbc565644189d63 Mon Sep 17 00:00:00 2001 From: Fufu Date: Thu, 7 Mar 2024 10:49:26 +0800 Subject: [PATCH] refactor: blacklist and whitelist middleware --- common/helper.go | 38 +++++++++++++++++ common/helper_test.go | 17 ++++++++ web/fiber/middleware/blacklist.go | 30 ++++++-------- web/fiber/middleware/whitelist.go | 66 +++++++++++++++++++++++------- web/gin/middleware/blacklist.go | 32 ++++++--------- web/gin/middleware/whitelist.go | 68 ++++++++++++++++++++++++------- 6 files changed, 184 insertions(+), 67 deletions(-) create mode 100644 common/helper_test.go diff --git a/common/helper.go b/common/helper.go index 6db37f4..358e795 100644 --- a/common/helper.go +++ b/common/helper.go @@ -2,6 +2,9 @@ package common import ( "net" + "strconv" + + "github.com/fufuok/utils/xhash" ) // LookupIPNetsString 从 IP 段集合中查询并返回对应数值 @@ -23,3 +26,38 @@ func LookupIPNets(ip net.IP, ipNets map[*net.IPNet]int64) (int64, bool) { } return 0, false } + +// GenSign 使用时间戳和密钥生成简单签名字符串 +// 算法: md5(ts+key) +// 结果: ts+sign +func GenSign(ts int64, key string) string { + tss := strconv.FormatInt(ts, 10) + return GenSignString(tss, key) +} + +// GenSignString 字符串类型的时间戳生成签名 +func GenSignString(ts, key string) string { + if len(ts) != 10 || key == "" { + return "" + } + sign := xhash.MD5Hex(ts + key) + return ts + sign +} + +// VerifySign 校验签名 +func VerifySign(key, sign string) bool { + if key == "" || len(sign) != 42 { + return false + } + return sign == GenSignString(sign[:10], key) +} + +// VerifySignTTL 校验签名及签名有效期(当前时间 **秒 范围内有效) +func VerifySignTTL(key, sign string, second int64) bool { + if ok := VerifySign(key, sign); !ok { + return false + } + ts, _ := strconv.ParseInt(sign[:10], 10, 64) + now := GTimestamp() + return ts >= now-second && ts <= now+second +} diff --git a/common/helper_test.go b/common/helper_test.go new file mode 100644 index 0000000..04368a2 --- /dev/null +++ b/common/helper_test.go @@ -0,0 +1,17 @@ +package common + +import ( + "testing" + "time" + + "github.com/fufuok/utils" + "github.com/fufuok/utils/assert" +) + +func TestGenSign(t *testing.T) { + key := utils.RandString(18) + ts := time.Now().Unix() + sign := GenSign(ts, key) + t.Log("sign:", sign) + assert.True(t, VerifySignTTL(key, sign, 1)) +} diff --git a/web/fiber/middleware/blacklist.go b/web/fiber/middleware/blacklist.go index d8ebbe8..c219557 100644 --- a/web/fiber/middleware/blacklist.go +++ b/web/fiber/middleware/blacklist.go @@ -2,37 +2,31 @@ package middleware import ( "fmt" - "net/http" "github.com/gofiber/fiber/v2" "github.com/fufuok/pkg/common" "github.com/fufuok/pkg/config" - "github.com/fufuok/pkg/logger/sampler" "github.com/fufuok/pkg/web/fiber/proxy" - "github.com/fufuok/pkg/web/fiber/response" ) // CheckBlacklist 接口黑名单检查 func CheckBlacklist(asAPI bool) fiber.Handler { errMsg := fmt.Sprintf("[ERROR] 非法访问(%s): ", config.AppName) return func(c *fiber.Ctx) error { - if len(config.Blacklist) > 0 { - clientIP := proxy.GetClientIP(c) - if _, ok := common.LookupIPNetsString(clientIP, config.Blacklist); ok { - msg := errMsg + clientIP - sampler.Info(). - Str("cip", c.IP()).Str("x_forwarded_for", c.Get(fiber.HeaderXForwardedFor)). - Str(proxy.HeaderXProxyClientIP, c.Get(proxy.HeaderXProxyClientIP)). - Str("method", c.Method()).Str("uri", c.OriginalURL()).Str("client_ip", clientIP). - Msg(msg) - if asAPI { - return response.APIException(c, http.StatusForbidden, msg, nil) - } else { - return response.TxtException(c, http.StatusForbidden, msg) - } - } + if BlacklistChecker(c) { + return responseForbidden(c, errMsg, asAPI) } return c.Next() } } + +// BlacklistChecker 是否存在于黑名单, true 是黑名单 (黑名单为空时: 放过, false) +func BlacklistChecker(c *fiber.Ctx) bool { + clientIP := proxy.GetClientIP(c) + if len(config.Blacklist) > 0 { + _, ok := common.LookupIPNetsString(clientIP, config.Blacklist) + return ok + } + return false +} diff --git a/web/fiber/middleware/whitelist.go b/web/fiber/middleware/whitelist.go index 38aaba3..95d6fef 100644 --- a/web/fiber/middleware/whitelist.go +++ b/web/fiber/middleware/whitelist.go @@ -13,26 +13,62 @@ import ( "github.com/fufuok/pkg/web/fiber/response" ) +type SignChecker = func(*fiber.Ctx) bool + // CheckWhitelist 接口白名单检查 func CheckWhitelist(asAPI bool) fiber.Handler { errMsg := fmt.Sprintf("[ERROR] 非法来访(%s): ", config.AppName) return func(c *fiber.Ctx) error { - if len(config.Whitelist) > 0 { - clientIP := proxy.GetClientIP(c) - if _, ok := common.LookupIPNetsString(clientIP, config.Whitelist); !ok { - msg := errMsg + clientIP - sampler.Info(). - Str("cip", c.IP()).Str("x_forwarded_for", c.Get(fiber.HeaderXForwardedFor)). - Str(proxy.HeaderXProxyClientIP, c.Get(proxy.HeaderXProxyClientIP)). - Str("method", c.Method()).Str("uri", c.OriginalURL()).Str("client_ip", clientIP). - Msg(msg) - if asAPI { - return response.APIException(c, http.StatusForbidden, msg, nil) - } else { - return response.TxtException(c, http.StatusForbidden, msg) - } - } + if !WhitelistChecker(c) { + return responseForbidden(c, errMsg, asAPI) + } + return c.Next() + } +} + +// CheckWhitelistOrSign 检查接口白名单或签名 +func CheckWhitelistOrSign(signChecker SignChecker, asAPI bool) fiber.Handler { + errMsg := fmt.Sprintf("[ERROR] 无效签名或非法来访(%s): ", config.AppName) + return func(c *fiber.Ctx) error { + if !WhitelistChecker(c) && !signChecker(c) { + return responseForbidden(c, errMsg, asAPI) + } + return c.Next() + } +} + +// CheckWhitelistAndSign 同时检查接口白名单和签名 +func CheckWhitelistAndSign(signChecker SignChecker, asAPI bool) fiber.Handler { + errMsg := fmt.Sprintf("[ERROR] 无效签名或非法来访(%s): ", config.AppName) + return func(c *fiber.Ctx) error { + if !WhitelistChecker(c) || !signChecker(c) { + return responseForbidden(c, errMsg, asAPI) } return c.Next() } } + +// WhitelistChecker 是否通过了白名单检查, true 是白名单 (白名单为空时: 通过, true) +func WhitelistChecker(c *fiber.Ctx) bool { + clientIP := proxy.GetClientIP(c) + if len(config.Whitelist) > 0 { + _, ok := common.LookupIPNetsString(clientIP, config.Whitelist) + return ok + } + return true +} + +func responseForbidden(c *fiber.Ctx, msg string, asAPI bool) error { + clientIP := proxy.GetClientIP(c) + msg += clientIP + sampler.Info(). + Str("cip", c.IP()).Str("x_forwarded_for", c.Get(fiber.HeaderXForwardedFor)). + Str(proxy.HeaderXProxyClientIP, c.Get(proxy.HeaderXProxyClientIP)). + Str("method", c.Method()).Str("uri", c.OriginalURL()).Str("client_ip", clientIP). + Msg(msg) + + if asAPI { + return response.APIException(c, http.StatusForbidden, msg, nil) + } + return response.TxtException(c, http.StatusForbidden, msg) +} diff --git a/web/gin/middleware/blacklist.go b/web/gin/middleware/blacklist.go index 94d6c66..4a91525 100644 --- a/web/gin/middleware/blacklist.go +++ b/web/gin/middleware/blacklist.go @@ -2,37 +2,31 @@ package middleware import ( "fmt" - "net/http" "github.com/gin-gonic/gin" "github.com/fufuok/pkg/common" "github.com/fufuok/pkg/config" - "github.com/fufuok/pkg/logger/sampler" - "github.com/fufuok/pkg/web/gin/response" ) // CheckBlacklist 接口黑名单检查 func CheckBlacklist(asAPI bool) gin.HandlerFunc { errMsg := fmt.Sprintf("[ERROR] 非法访问(%s): ", config.AppName) return func(c *gin.Context) { - if len(config.Blacklist) > 0 { - clientIP := c.ClientIP() - if _, ok := common.LookupIPNetsString(clientIP, config.Blacklist); ok { - msg := errMsg + clientIP - sampler.Info(). - Str("cip", clientIP).Str("x_forwarded_for", c.GetHeader("X-Forwarded-For")). - Str("method", c.Request.Method).Str("uri", c.Request.RequestURI). - Msg(msg) - if asAPI { - response.APIException(c, http.StatusForbidden, msg, nil) - } else { - response.TxtException(c, http.StatusForbidden, msg) - } - return - } + if BlacklistChecker(c) { + responseForbidden(c, errMsg, asAPI) + return } - c.Next() } } + +// BlacklistChecker 是否存在于黑名单, true 是黑名单 (黑名单为空时: 放过, false) +func BlacklistChecker(c *gin.Context) bool { + clientIP := c.ClientIP() + if len(config.Blacklist) > 0 { + _, ok := common.LookupIPNetsString(clientIP, config.Blacklist) + return ok + } + return false +} diff --git a/web/gin/middleware/whitelist.go b/web/gin/middleware/whitelist.go index 361eebd..14c295f 100644 --- a/web/gin/middleware/whitelist.go +++ b/web/gin/middleware/whitelist.go @@ -12,27 +12,65 @@ import ( "github.com/fufuok/pkg/web/gin/response" ) +type SignChecker = func(*gin.Context) bool + // CheckWhitelist 接口白名单检查 func CheckWhitelist(asAPI bool) gin.HandlerFunc { errMsg := fmt.Sprintf("[ERROR] 非法来访(%s): ", config.AppName) return func(c *gin.Context) { - if len(config.Whitelist) > 0 { - clientIP := c.ClientIP() - if _, ok := common.LookupIPNetsString(clientIP, config.Whitelist); !ok { - msg := errMsg + clientIP - sampler.Info(). - Str("cip", clientIP).Str("x_forwarded_for", c.GetHeader("X-Forwarded-For")). - Str("method", c.Request.Method).Str("uri", c.Request.RequestURI). - Msg(msg) - if asAPI { - response.APIException(c, http.StatusForbidden, msg, nil) - } else { - response.TxtException(c, http.StatusForbidden, msg) - } - return - } + if !WhitelistChecker(c) { + responseForbidden(c, errMsg, asAPI) + return } + c.Next() + } +} +// CheckWhitelistOrSign 检查接口白名单或签名 +func CheckWhitelistOrSign(signChecker SignChecker, asAPI bool) gin.HandlerFunc { + errMsg := fmt.Sprintf("[ERROR] 无效令牌或非法来访(%s): ", config.AppName) + return func(c *gin.Context) { + if !WhitelistChecker(c) && !signChecker(c) { + responseForbidden(c, errMsg, asAPI) + return + } c.Next() } } + +// CheckWhitelistAndSign 同时检查接口白名单和签名 +func CheckWhitelistAndSign(signChecker SignChecker, asAPI bool) gin.HandlerFunc { + errMsg := fmt.Sprintf("[ERROR] 无效令牌或非法来访(%s): ", config.AppName) + return func(c *gin.Context) { + if !WhitelistChecker(c) || !signChecker(c) { + responseForbidden(c, errMsg, asAPI) + return + } + c.Next() + } +} + +// WhitelistChecker 是否通过了白名单检查, true 是白名单 (白名单为空时: 通过, true) +func WhitelistChecker(c *gin.Context) bool { + clientIP := c.ClientIP() + if len(config.Whitelist) > 0 { + _, ok := common.LookupIPNetsString(clientIP, config.Whitelist) + return ok + } + return true +} + +func responseForbidden(c *gin.Context, msg string, asAPI bool) { + clientIP := c.ClientIP() + msg += clientIP + sampler.Info(). + Str("cip", clientIP).Str("x_forwarded_for", c.GetHeader("X-Forwarded-For")). + Str("method", c.Request.Method).Str("uri", c.Request.RequestURI). + Msg(msg) + + if asAPI { + response.APIException(c, http.StatusForbidden, msg, nil) + } else { + response.TxtException(c, http.StatusForbidden, msg) + } +}