diff --git a/api/routers/route.go b/api/routers/route.go index 8488a75..fe4f340 100644 --- a/api/routers/route.go +++ b/api/routers/route.go @@ -3,14 +3,16 @@ package routers import ( "github.com/gin-gonic/gin" "github.com/jweny/pocassist/api/middleware/jwt" - "github.com/jweny/pocassist/api/routers/v1" + "github.com/jweny/pocassist/api/routers/v1/auth" + "github.com/jweny/pocassist/api/routers/v1/plugin" + "github.com/jweny/pocassist/api/routers/v1/vulnerability" + "github.com/jweny/pocassist/api/routers/v1/webapp" "github.com/jweny/pocassist/pkg/conf" "net/http" ) func Setup() { gin.SetMode(conf.GlobalConfig.ServerConfig.RunMode) - } @@ -24,58 +26,64 @@ func InitRouter(port string) { }) // api - router.POST("/api/v1/user/login", v1.GetAuth) + router.POST("/api/v1/user/login", auth.Login) pluginRoutes := router.Group("/api/v1/poc") pluginRoutes.Use(jwt.JWT()) { // all - pluginRoutes.GET("/", v1.GetPlugins) + pluginRoutes.GET("/", plugin.Get) // 增 - pluginRoutes.POST("/", v1.CreatePlugin) + pluginRoutes.POST("/", plugin.Add) // 改 - pluginRoutes.PUT("/:id/", v1.UpdatePlugin) + pluginRoutes.PUT("/:id/", plugin.Update) // 详情 - pluginRoutes.GET("/:id/", v1.GetPlugin) + pluginRoutes.GET("/:id/", plugin.Detail) // 删 - pluginRoutes.DELETE("/:id/", v1.DeletePlugin) - // 运行 - pluginRoutes.POST("/run/", v1.RunPlugin) + pluginRoutes.DELETE("/:id/", plugin.Delete) + // 测试单个poc + pluginRoutes.POST("/run/", plugin.Test) + //// 批量测试poc + //pluginRoutes.POST("/runs", plugin.RunPlugins) } vulRoutes := router.Group("/api/v1/vul") vulRoutes.Use(jwt.JWT()) { // basic - vulRoutes.GET("/basic/", v1.GetBasic) + vulRoutes.GET("/basic/", vulnerability.Basic) // all - vulRoutes.GET("/", v1.GetVuls) + vulRoutes.GET("/", vulnerability.Get) // 增 - vulRoutes.POST("/", v1.CreateVul) + vulRoutes.POST("/", vulnerability.Create) // 改 - vulRoutes.PUT("/:id/", v1.UpdateVul) + vulRoutes.PUT("/:id/", vulnerability.Update) // 详情 - vulRoutes.GET("/:id/", v1.GetVul) + vulRoutes.GET("/:id/", vulnerability.Detail) // 删 - vulRoutes.DELETE("/:id/", v1.DeleteVul) + vulRoutes.DELETE("/:id/", vulnerability.Delete) } appRoutes := router.Group("/api/v1/product") appRoutes.Use(jwt.JWT()) { // all - appRoutes.GET("/", v1.GetWebApps) + appRoutes.GET("/", webapp.Get) // 增 - appRoutes.POST("/", v1.CreateWebApp) + appRoutes.POST("/", webapp.Create) } - - userRoutes := router.Group("/api/v1/user") userRoutes.Use(jwt.JWT()) { - userRoutes.POST("/self/resetpwd/", v1.SelfResetPassword) - userRoutes.GET("/info", v1.SelfGetInfo) - userRoutes.GET("/logout", v1.SelfLogout) + userRoutes.POST("/self/resetpwd/", auth.Reset) + userRoutes.GET("/info", auth.Self) + userRoutes.GET("/logout", auth.Logout) + } + + // todo scan add jwt + scanRoutes := router.Group("/api/vi/scan") + { + scanRoutes.POST("") } router.Run(":" + port) diff --git a/api/routers/v1/auth/auth.go b/api/routers/v1/auth/auth.go index c044fda..efb466a 100644 --- a/api/routers/v1/auth/auth.go +++ b/api/routers/v1/auth/auth.go @@ -1,4 +1,4 @@ -package v1 +package auth import ( "github.com/gin-gonic/gin" @@ -17,7 +17,7 @@ type ResetPwd struct { NewPassword string `json:"newpassword"` } -func GetAuth(c *gin.Context) { +func Login(c *gin.Context) { login := auth{} err := c.BindJSON(&login) if err != nil { @@ -48,7 +48,7 @@ func GetAuth(c *gin.Context) { } } -func SelfResetPassword(c *gin.Context) { +func Reset(c *gin.Context) { resetPwd := ResetPwd{} err := c.BindJSON(&resetPwd) if err != nil { @@ -71,7 +71,7 @@ func SelfResetPassword(c *gin.Context) { } } -func SelfGetInfo(c *gin.Context) { +func Self(c *gin.Context) { token := c.Request.Header.Get("Authorization") claims, err := util.ParseToken(token) if err != nil || claims == nil { @@ -84,7 +84,7 @@ func SelfGetInfo(c *gin.Context) { return } -func SelfLogout(c *gin.Context) { +func Logout(c *gin.Context) { // 后端伪登出 todo:优化jwt c.JSON(msg.SuccessResp("登出成功")) return diff --git a/api/routers/v1/plugin/plugin.go b/api/routers/v1/plugin/plugin.go index 03baaf9..75459be 100644 --- a/api/routers/v1/plugin/plugin.go +++ b/api/routers/v1/plugin/plugin.go @@ -1,7 +1,6 @@ -package v1 +package plugin import ( - "bufio" "github.com/astaxie/beego/validation" "github.com/gin-gonic/gin" "github.com/jweny/pocassist/api/msg" @@ -11,16 +10,9 @@ import ( "github.com/jweny/pocassist/poc/rule" "github.com/unknwon/com" "gorm.io/datatypes" - "log" ) -const ( - TargetUrl = "url" - TargetUrlFile = "file" - TargetUrlRaw = "raw" -) - -type PluginSerializer struct { +type Serializer struct { // 返回给前端的字段 DespName string `json:"desp_name"` Id int `gorm:"primary_key" json:"id"` @@ -31,23 +23,15 @@ type PluginSerializer struct { Description int `gorm:"column:description" json:"description"` } -type RunSinglePluginSerializer struct { +type RunSerializer struct { // 运行单个 Target string `json:"target"` Affects string `gorm:"column:affects" json:"affects"` JsonPoc datatypes.JSON `gorm:"column:json_poc" json:"json_poc"` } -type RunPluginsSerializer struct { - // 批量运行 - Target string `json:"target"` - TargetType string `json:"target_type"` - RunType string `json:"run_type"` - VulIdList []string `json:"vul_id_list"` -} - //获取单个plugin -func GetPlugin(c *gin.Context) { +func Detail(c *gin.Context) { id := com.StrTo(c.Param("id")).MustInt() var data interface {} valid := validation.Validation{} @@ -68,7 +52,7 @@ func GetPlugin(c *gin.Context) { } //获取多个pluign -func GetPlugins(c *gin.Context) { +func Get(c *gin.Context) { data := make(map[string]interface{}) field := db.PluginSearchField{Search: "", EnableField:-1, AffectsField:"",} valid := validation.Validation{} @@ -91,7 +75,7 @@ func GetPlugins(c *gin.Context) { if ! valid.HasErrors() { plugins := db.GetPlugins(page, pageSize, &field) - var pluginRespData []PluginSerializer + var pluginRespData []Serializer for _, plugin := range plugins { var despName string @@ -101,7 +85,7 @@ func GetPlugins(c *gin.Context) { despName = "" } - pluginRespData = append(pluginRespData, PluginSerializer{ + pluginRespData = append(pluginRespData, Serializer{ DespName: despName, Id: plugin.Id, VulId: plugin.VulId, @@ -123,7 +107,7 @@ func GetPlugins(c *gin.Context) { } //新增 -func CreatePlugin(c *gin.Context) { +func Add(c *gin.Context) { plugin := db.Plugin{} err := c.BindJSON(&plugin) if err != nil { @@ -141,7 +125,7 @@ func CreatePlugin(c *gin.Context) { } //修改 -func UpdatePlugin(c *gin.Context) { +func Update(c *gin.Context) { plugin := db.Plugin{} err := c.BindJSON(&plugin) if err != nil { @@ -168,7 +152,7 @@ func UpdatePlugin(c *gin.Context) { } //删除 -func DeletePlugin(c *gin.Context) { +func Delete(c *gin.Context) { id := com.StrTo(c.Param("id")).MustInt() valid := validation.Validation{} @@ -190,8 +174,8 @@ func DeletePlugin(c *gin.Context) { } //运行单个plugin 不是从数据库提取数据,表单传数据 -func RunPlugin(c *gin.Context) { - run := RunSinglePluginSerializer{} +func Test(c *gin.Context) { + run := RunSerializer{} err := c.BindJSON(&run) if err != nil { c.JSON(msg.ErrResp("参数校验不通过")) @@ -225,64 +209,4 @@ func RunPlugin(c *gin.Context) { c.JSON(msg.ErrResp("检测目标、规则类型均不可为空")) return } -} - -//批量运行plugin 从数据库提取数据,表单传数据 -//前端向后端传 "vul_id_list":["poc_db_1","poc_db_2"] -func RunPlugins(c *gin.Context) { - runs := RunPluginsSerializer{} - err := c.BindJSON(&runs) - if err != nil { - c.JSON(msg.ErrResp("参数校验不通过")) - return - } - plugins, err := rule.LoadDbPlugin(runs.RunType, runs.VulIdList) - - switch runs.TargetType { - case TargetUrl: - url := runs.TargetType - oreq, err := util.GenOriginalReq(url) - if err != nil { - logging.GlobalLogger.Error("[original request gen err ]", err) - c.JSON(msg.ErrResp("原始请求生成失败")) - return - } - rule.RunPlugins(oreq, plugins) - case TargetUrlFile: - //获取文件 - file, header, err := c.Request.FormFile("file") - if err != nil { - logging.GlobalLogger.Error("[original request gen err ]", err) - c.JSON(msg.ErrResp("url文件上传失败")) - return - } - log.Print(header.Filename) - //content, err := ioutil.ReadAll(file) - var targets []string - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - val := scanner.Text() - if val == "" { - continue - } - targets = append(targets, val) - } - - for _, url := range targets { - oreq, err := util.GenOriginalReq(url) - if err != nil { - logging.GlobalLogger.Error("[original request gen err ]", err) - } - logging.GlobalLogger.Info("[start check url ]", url) - rule.RunPlugins(oreq, plugins) - } - case TargetUrlRaw: - //请求报文 - } -} - - - - - +} \ No newline at end of file diff --git a/api/routers/v1/scan/scan.go b/api/routers/v1/scan/scan.go new file mode 100644 index 0000000..bae03ed --- /dev/null +++ b/api/routers/v1/scan/scan.go @@ -0,0 +1,77 @@ +package scan + +import ( + "bufio" + "github.com/gin-gonic/gin" + "github.com/jweny/pocassist/api/msg" + "github.com/jweny/pocassist/pkg/logging" + "github.com/jweny/pocassist/pkg/util" + "github.com/jweny/pocassist/poc/rule" + "log" +) + +type UrlItem struct { + // 批量运行 + Target string `json:"target"` + LoadType string `json:"run_type"` // multi or all + VulIdList []string `json:"vul_id_list"` //前端向后端传 "vul_id_list":["poc_db_1","poc_db_2"] +} + +// 单个url +func Url(c *gin.Context) { + item := UrlItem{} + err := c.BindJSON(&item) + if err != nil { + c.JSON(msg.ErrResp("参数校验不通过")) + return + } + plugins, err := rule.LoadDbPlugin(item.LoadType, item.VulIdList) + oreq, err := util.GenOriginalReq(item.Target) + if err != nil { + logging.GlobalLogger.Error("[original request gen err ]", err) + c.JSON(msg.ErrResp("原始请求生成失败")) + return + } + // todo 加载config + ch := make(chan util.ScanResult, 100) + rule.RunPlugins(oreq, plugins) +} + +// 加载文件 批量扫描 +func File(c *gin.Context) { + //获取文件 + file, header, err := c.Request.FormFile("file") + if err != nil { + logging.GlobalLogger.Error("[original request gen err ]", err) + c.JSON(msg.ErrResp("url文件上传失败")) + return + } + log.Print(header.Filename) + var targets []string + scanner := bufio.NewScanner(file) + for scanner.Scan() { + val := scanner.Text() + if val == "" { + continue + } + targets = append(targets, val) + } + + item := UrlItem{} + err = c.BindJSON(&item) + if err != nil { + c.JSON(msg.ErrResp("参数校验不通过")) + return + } + + for _, url := range targets { + plugins, err := rule.LoadDbPlugin(item.LoadType, item.VulIdList) + oreq, err := util.GenOriginalReq(url) + if err != nil { + logging.GlobalLogger.Error("[original request gen err ]", err) + c.JSON(msg.ErrResp("原始请求生成失败")) + return + } + rule.RunPlugins(oreq, plugins) + } +} diff --git a/api/routers/v1/vulnerability/vulnerability.go b/api/routers/v1/vulnerability/vulnerability.go index 358a197..4898b09 100644 --- a/api/routers/v1/vulnerability/vulnerability.go +++ b/api/routers/v1/vulnerability/vulnerability.go @@ -1,4 +1,4 @@ -package v1 +package vulnerability import ( "github.com/astaxie/beego/validation" @@ -25,41 +25,8 @@ type VulSerializer struct { Webapp int `gorm:"column:webapp" json:"webapp"` } -//获取 webapp -func GetWebApps(c *gin.Context) { - data := make(map[string]interface{}) - // 分页 - page, _ := com.StrTo(c.Query("page")).Int() - pageSize, _ := com.StrTo(c.Query("pagesize")).Int() - - apps := db.GetWebApps(page, pageSize) - data["data"] = apps - total := db.GetWebAppsTotal() - data["total"] = total - c.JSON(msg.SuccessResp(data)) - return -} - -//新增 -func CreateWebApp(c *gin.Context) { - app := db.Webapp{} - err := c.BindJSON(&app) - if err != nil { - c.JSON(msg.ErrResp("参数校验不通过")) - return - } - if db.ExistWebappByName(app.Name){ - c.JSON(msg.ErrResp("漏洞名称已存在")) - return - } else { - db.AddWebapp(app) - c.JSON(msg.SuccessResp(app)) - return - } -} - //获取单个描述 -func GetVul(c *gin.Context) { +func Detail(c *gin.Context) { id := com.StrTo(c.Param("id")).MustInt() var data interface {} valid := validation.Validation{} @@ -80,7 +47,7 @@ func GetVul(c *gin.Context) { } //获取多个描述 -func GetVuls(c *gin.Context) { +func Get(c *gin.Context) { data := make(map[string]interface{}) field := db.VulnerabilitySearchField{ Search:"", @@ -139,7 +106,7 @@ func GetVuls(c *gin.Context) { } //新增 -func CreateVul(c *gin.Context) { +func Create(c *gin.Context) { vul := db.Vulnerability{} err := c.BindJSON(&vul) if err != nil { @@ -157,7 +124,7 @@ func CreateVul(c *gin.Context) { } //修改 -func UpdateVul(c *gin.Context) { +func Update(c *gin.Context) { vul := db.Vulnerability{} err := c.BindJSON(&vul) if err != nil { @@ -185,7 +152,7 @@ func UpdateVul(c *gin.Context) { } //删除 -func DeleteVul(c *gin.Context) { +func Delete(c *gin.Context) { id := com.StrTo(c.Param("id")).MustInt() valid := validation.Validation{} @@ -206,28 +173,28 @@ func DeleteVul(c *gin.Context) { } } -type Basic struct { +type BasicObj struct { Name string `json:"name"` Label string `json:"label"` } // 前端需要的基础信息 -func GetBasic(c *gin.Context) { - var LanguageChoice []Basic +func Basic(c *gin.Context) { + var LanguageChoice []BasicObj for _, v := range []string{"Any","ASP","JAVA","Python","NodeJS","PHP","Ruby","ASPX"} { - LanguageChoice = append(LanguageChoice, Basic{Name: v, Label:v}) + LanguageChoice = append(LanguageChoice, BasicObj{Name: v, Label:v}) } - var AffectChoice []Basic + var AffectChoice []BasicObj for _, v := range []string{"server","text","directory","url","appendparam","replaceparam","script"} { - AffectChoice = append(AffectChoice, Basic{Name: v, Label:v}) + AffectChoice = append(AffectChoice, BasicObj{Name: v, Label:v}) } - var LevelChoice []Basic + var LevelChoice []BasicObj for _, v := range []string{"high","middle","low","info",} { - LevelChoice = append(LevelChoice, Basic{Name: v, Label:v}) + LevelChoice = append(LevelChoice, BasicObj{Name: v, Label:v}) } - var TypeChoice []Basic + var TypeChoice []BasicObj for _, v := range []string{"SQL 注入","命令执行","信息泄漏","其他类型","发现备份文件","未知","目录穿越","未授权","ShellCode","任意文件下载","任意文件读取","反序列化","任意文件写入","弱口令","权限提升","目录遍历","JAVA反序列化","代码执行","嵌入恶意代码","拒绝服务","文件上传","远程文件包含","跨站请求伪造","跨站脚本XSS","XPath注入","缓冲区溢出","XML注入","服务器端请求伪造","Cookie验证错误","解析错误","本地文件包含","配置错误"} { - TypeChoice = append(TypeChoice, Basic{Name: v, Label:v}) + TypeChoice = append(TypeChoice, BasicObj{Name: v, Label:v}) } data := make(map[string]interface{}) diff --git a/api/routers/v1/webapp/webapp.go b/api/routers/v1/webapp/webapp.go new file mode 100644 index 0000000..337162f --- /dev/null +++ b/api/routers/v1/webapp/webapp.go @@ -0,0 +1,41 @@ +package webapp + +import ( + "github.com/gin-gonic/gin" + "github.com/jweny/pocassist/api/msg" + "github.com/jweny/pocassist/pkg/db" + "github.com/unknwon/com" +) + +//获取 webapp +func Get(c *gin.Context) { + data := make(map[string]interface{}) + // 分页 + page, _ := com.StrTo(c.Query("page")).Int() + pageSize, _ := com.StrTo(c.Query("pagesize")).Int() + + apps := db.GetWebApps(page, pageSize) + data["data"] = apps + total := db.GetWebAppsTotal() + data["total"] = total + c.JSON(msg.SuccessResp(data)) + return +} + +//新增 +func Create(c *gin.Context) { + app := db.Webapp{} + err := c.BindJSON(&app) + if err != nil { + c.JSON(msg.ErrResp("参数校验不通过")) + return + } + if db.ExistWebappByName(app.Name){ + c.JSON(msg.ErrResp("漏洞名称已存在")) + return + } else { + db.AddWebapp(app) + c.JSON(msg.SuccessResp(app)) + return + } +} \ No newline at end of file diff --git a/cmd/run.go b/cmd/run.go index fac9fce..aecf8d3 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -1,15 +1,20 @@ package cmd import ( + "fmt" + "github.com/fsnotify/fsnotify" "github.com/jweny/pocassist/api/routers" conf2 "github.com/jweny/pocassist/pkg/conf" "github.com/jweny/pocassist/pkg/db" "github.com/jweny/pocassist/pkg/logging" "github.com/jweny/pocassist/pkg/util" "github.com/jweny/pocassist/poc/rule" + "github.com/spf13/viper" "github.com/urfave/cli/v2" "log" "os" + "path" + "path/filepath" "sort" ) @@ -21,6 +26,18 @@ var ( condition string ) +func init() { + welcome := ` + _ _ + _ __ ___ ___ __ _ ___ ___(_)___| |_ +| '_ \ / _ \ / __/ _' / __/ __| / __| __| +| |_) | (_) | (_| (_| \__ \__ \ \__ \ |_ +| .__/ \___/ \___\__,_|___/___/_|___/\__| +|_| +` + fmt.Println(welcome) +} + func InitAll() { // config 必须最先加载 conf2.Setup() @@ -31,6 +48,25 @@ func InitAll() { rule.Setup() } +// 使用viper 对配置热加载 +func HotConf() { + dir, err := filepath.Abs(filepath.Dir(os.Args[0])) + if err != nil { + log.Fatalf("conf.Setup, fail to get current path: %v", err) + } + // 配置文件路径 当前文件夹 + config.yaml + configFile := path.Join(dir, "config.yaml") + viper.SetConfigType("yaml") + viper.SetConfigFile(configFile) + // watch 监控配置文件变化 + viper.WatchConfig() + viper.OnConfigChange(func(e fsnotify.Event) { + // 配置文件发生变更之后会调用的回调函数 + log.Println("Config file changed:", e.Name) + InitAll() + }) +} + func RunApp() { app := cli.NewApp() app.Name = "pocassist" diff --git a/cmd/server.go b/cmd/server.go index c38f63d..43a4b5e 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -24,6 +24,7 @@ var subCommandServer = cli.Command{ func RunServer(c *cli.Context) error { InitAll() + HotConf() port := c.String("port") routers.InitRouter(port) return nil diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..e262a6d --- /dev/null +++ b/config.yaml @@ -0,0 +1,50 @@ +# webserver配置 +serverConfig: + # 配置jwt秘钥 + jwt_secret: "pocassist" + # gin的运行模式 "release" 或者 "debug" + run_mode: "release" + # 运行日志的文件名,日志将保存在二进制所在目录 + log_name : "debug.log" + +# HTTP配置 +httpConfig: + # 扫描时使用的代理:格式为 IP:PORT,example: 如 burpsuite,可填写 127.0.0.1:8080 + proxy: "" + # 读取 http 响应超时时间,不建议设置太小,否则可能影响到盲注的判断 + http_timeout: 10 + # 建立 tcp 连接的超时时间 + dail_timeout: 5 + # udp 超时时间 + udp_timeout: 5 + # 每秒最大请求数 + max_qps: 100 + # 单个请求最大允许的跳转次数 + max_redirect: 5 + headers: + # 默认 UA + user_agent: "Mozilla/5.0 (Windows NT 10.0; rv:78.0) Gecko/20100101 Firefox/78.0" + +# 数据库配置 +dbConfig: + # sqlite配置:sqlite数据库文件的路径 + sqlite : "pocassist.db" + # mysql配置 + mysql: + host: "127.0.0.1" + password: "" + port: "3306" + user: "root" + database: "pocassist" + # 数据库连接超时时间 + timeout: "3s" + +# 插件配置 +pluginsConfig: + # 并发量:同时运行的插件数量 + parallel: 8 + +# 反连平台配置: 目前使用 ceye.io +reverse: + api_key: "" + domain: "" diff --git a/pkg/conf/config.go b/pkg/conf/config.go index 9cf844d..0e1403b 100644 --- a/pkg/conf/config.go +++ b/pkg/conf/config.go @@ -2,7 +2,6 @@ package conf import ( "bytes" - "github.com/fsnotify/fsnotify" "github.com/spf13/viper" "log" "os" @@ -79,7 +78,6 @@ func Setup() { // 没有,生成默认yaml WriteYamlConfig(configFile) } - // watch配置 ReadYamlConfig(configFile) } @@ -101,12 +99,6 @@ func ReadYamlConfig(configFile string) { if err != nil { log.Fatalf("conf.Setup, fail to verify 'config.yaml', check format: %v", err) } - // watch 监控配置文件变化 - viper.WatchConfig() - viper.OnConfigChange(func(e fsnotify.Event) { - // 配置文件发生变更之后会调用的回调函数 - log.Println("Config file changed:", e.Name) - }) } func WriteYamlConfig(configFile string) { diff --git a/pkg/db/auth.go b/pkg/db/auth.go index 3b2934e..1f13854 100644 --- a/pkg/db/auth.go +++ b/pkg/db/auth.go @@ -1,5 +1,7 @@ package db +// auths 表 + type Auth struct { Id int `gorm:"primary_key" json:"id"` Username string `json:"username"` diff --git a/pkg/db/conn.go b/pkg/db/conn.go index 364a4ed..0052788 100644 --- a/pkg/db/conn.go +++ b/pkg/db/conn.go @@ -15,7 +15,6 @@ var GlobalDB *gorm.DB func Setup() { var err error dbConfig := conf.GlobalConfig.DbConfig - if conf.GlobalConfig.DbConfig.Sqlite == "" { // 配置mysql数据源 if dbConfig.Mysql.User == "" || @@ -57,6 +56,10 @@ func Setup() { if err != nil { log.Fatalf("db.Setup err: %v", err) } - GlobalDB.Logger = logger.Default.LogMode(logger.Silent) + + if conf.GlobalConfig.ServerConfig.RunMode == "release" { + // release下 + GlobalDB.Logger = logger.Default.LogMode(logger.Silent) + } } diff --git a/pkg/db/plugin.go b/pkg/db/plugin.go index 794f004..2f1e850 100644 --- a/pkg/db/plugin.go +++ b/pkg/db/plugin.go @@ -4,7 +4,8 @@ import ( "gorm.io/datatypes" ) -// 数据库 plugins 表 +// plugins 表 + type Plugin struct { Id int `gorm:"primary_key" json:"id"` VulId string `gorm:"column:vul_id" json:"vul_id"` diff --git a/pkg/db/scan.go b/pkg/db/scan.go new file mode 100644 index 0000000..091e4a0 --- /dev/null +++ b/pkg/db/scan.go @@ -0,0 +1,56 @@ +package db + +import ( + "gorm.io/datatypes" + "gorm.io/gorm" +) + +// task表 + +type Task struct { + gorm.Model + Id int `gorm:"primary_key" json:"id"` + Remarks string `gorm:"column:remarks" json:"remarks"` + Target string `gorm:"type:longtext" json:"Target"` + Operator string `gorm:"type:string" json:"operator"` +} + +// result表 + +type Result struct { + gorm.Model + Id int `gorm:"primary_key" json:"id"` + Detail datatypes.JSON `gorm:"column:detail" json:"detail"` + TaskId Task `gorm:"foreignkey:Desc"` +} + +type ResultSearchField struct { + Search string +} + +//func GetResultTotal(field *ResultSearchField) (total int64){ +// db := GlobalDB.Model(&Result{}) +// if field.Search != ""{ +// db = db.Where( +// GlobalDB.Where("remarks like ?", "%"+field.Search+"%"). +// Or("Target like ?", "%"+field.Search+"%")) +// } +// db.Count(&total) +// return +//} +// +//func GetResult(page int, pageSize int, field *TaskSearchField) (tasks []Task) { +// +// db := GlobalDB.Model(&Task{}) +// +// if field.Search != ""{ +// db = db.Where( +// GlobalDB.Where("remarks like ?", "%"+field.Search+"%"). +// Or("Target like ?", "%"+field.Search+"%")) +// } +// // 分页 +// if page > 0 && pageSize > 0 { +// db = db.Offset((page - 1) * pageSize).Limit(pageSize).Find(&tasks) +// } +// return +//} diff --git a/pkg/db/vulnerability.go b/pkg/db/vulnerability.go index 4e656e8..75d9dba 100644 --- a/pkg/db/vulnerability.go +++ b/pkg/db/vulnerability.go @@ -1,11 +1,6 @@ package db -type Webapp struct { - Id int `gorm:"primary_key" json:"id"` - Name string `gorm:"column:name" json:"name"` - Provider string `gorm:"column:provider" json:"provider"` - Remarks string `gorm:"column:remarks" json:"remarks"` -} +// Vulnerabilities表 type Vulnerability struct { Id int `gorm:"primary_key" json:"id"` @@ -116,20 +111,6 @@ func GetWebAppsTotal() (total int64) { return } -func GetWebApps(page int, pageSize int) (apps []Webapp) { - // 分页 - db := GlobalDB.Model(&Webapp{}) - if page > 0 && pageSize > 0 { - db = db.Offset((page - 1) * pageSize).Limit(pageSize).Find(&apps) - } - return -} - -func AddWebapp(app Webapp) bool { - GlobalDB.Create(&app) - return true -} - func ExistWebappByName(name string) bool { var app Webapp GlobalDB.Model(&Webapp{}).Where("name = ?", name).First(&app) diff --git a/pkg/db/webapp.go b/pkg/db/webapp.go new file mode 100644 index 0000000..f200d15 --- /dev/null +++ b/pkg/db/webapp.go @@ -0,0 +1,22 @@ +package db + +type Webapp struct { + Id int `gorm:"primary_key" json:"id"` + Name string `gorm:"column:name" json:"name"` + Provider string `gorm:"column:provider" json:"provider"` + Remarks string `gorm:"column:remarks" json:"remarks"` +} + +func GetWebApps(page int, pageSize int) (apps []Webapp) { + // 分页 + db := GlobalDB.Model(&Webapp{}) + if page > 0 && pageSize > 0 { + db = db.Offset((page - 1) * pageSize).Limit(pageSize).Find(&apps) + } + return +} + +func AddWebapp(app Webapp) bool { + GlobalDB.Create(&app) + return true +} diff --git a/pkg/util/request.go b/pkg/util/request.go index 3e22429..d233141 100644 --- a/pkg/util/request.go +++ b/pkg/util/request.go @@ -14,7 +14,9 @@ import ( ) type clientDoer interface { + // 不跟随重定向 DoTimeout(req *fasthttp.Request, resp *fasthttp.Response, t time.Duration) error + // 跟随重定向 DoRedirects(req *fasthttp.Request, resp *fasthttp.Response, maxRedirectsCount int) error } @@ -174,11 +176,13 @@ func DoFasthttpRequest(req *fasthttp.Request, redirect bool) (*proto.Response, e var err error if redirect { + // 跟随重定向 最大跳转数从conf中加载 + maxRedirects := conf.GlobalConfig.HttpConfig.MaxRedirect + err = fasthttpClient.DoRedirects(req, resp, maxRedirects) + } else { + // 不跟随重定向 timeout := conf.GlobalConfig.HttpConfig.HttpTimeout err = fasthttpClient.DoTimeout(req, resp, time.Duration(timeout)*time.Second) - } else { - // 不接受跳转 - err = fasthttpClient.DoRedirects(req, resp, 0) } if err != nil { return nil, err diff --git a/pkg/util/util.go b/pkg/util/util.go index 2e774d6..2a8b442 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -9,7 +9,7 @@ import ( ) func Setup() { - // fast http client 初始化 + // fasthttp client 初始化 DownProxy := conf2.GlobalConfig.HttpConfig.Proxy client := &fasthttp.Client{ // If InsecureSkipVerify is true, TLS accepts any certificate diff --git a/poc/rule/plugin.go b/poc/rule/plugin.go index 9064581..bcb4aa1 100644 --- a/poc/rule/plugin.go +++ b/poc/rule/plugin.go @@ -34,12 +34,55 @@ func SplitToArray(conditions string) []string { } // 从数据库 中加载 poc +func LoadDbPlugin(lodeType string, array []string) ([]Plugin, error) { + // 数据库数据 + var dbPluginList []db.Plugin + // plugin对象 + var plugins []Plugin + switch lodeType { + case LoadMulti: + // 多个 + tx := db.GlobalDB.Where("vul_id IN ? AND enable = ?", array, 1).Find(&dbPluginList) + if tx.Error != nil { + logging.GlobalLogger.Error("[db select err ]", tx.Error) + return nil, tx.Error + } + default: + // 默认执行全部启用规则 + tx := db.GlobalDB.Where("enable = ?", 1).Find(&dbPluginList) + if tx.Error != nil { + logging.GlobalLogger.Error("[db select err ]", tx.Error) + return nil, tx.Error + } + } + + logging.GlobalLogger.Info("[dbPluginList load number ]", len(dbPluginList)) + + for _, v := range dbPluginList { + poc, err := ParseJsonPoc(v.JsonPoc) + if err != nil { + logging.GlobalLogger.Error("[plugins plugin load err ]", v.VulId) + continue + } + plugin := Plugin{ + VulId: v.VulId, + Affects: v.Affects, + JsonPoc: poc, + Enable: v.Enable, + } + plugins = append(plugins, plugin) + } + return plugins, nil + +} + +// 从数据库 中加载 poc +// todo delete func LoadDbPlugins(loadType string, conditions string) ([]db.Plugin, error) { var plugin db.Plugin var plugins []db.Plugin logging.GlobalLogger.Debug("[loadPoc type ]", loadType) logging.GlobalLogger.Debug("[conditions is ]", conditions) - // todo 命令行里传json_str过来 switch loadType { case LoadSingle: // 漏洞编号 @@ -66,8 +109,8 @@ func LoadDbPlugins(loadType string, conditions string) ([]db.Plugin, error) { } case LoadMulti: - plugins := SplitToArray(conditions) - tx := db.GlobalDB.Where("vul_id IN ? AND enable = ?", plugins, 1).Find(&plugins) + vulList := SplitToArray(conditions) + tx := db.GlobalDB.Where("vul_id IN ? AND enable = ?", vulList, 1).Find(&plugins) if tx.Error != nil { logging.GlobalLogger.Error("[db select err ]", tx.Error) return nil, tx.Error @@ -112,7 +155,7 @@ func LoadPlugins(loadType string, conditions string) ([]Plugin, error) { } // 批量执行plugin -func RunPlugins(oreq *http.Request, rules []Plugin){ +func RunPlugins(oreq *http.Request, plugins []Plugin){ // 并发限制 var wg sync.WaitGroup parallel := conf.GlobalConfig.PluginsConfig.Parallel @@ -123,8 +166,8 @@ func RunPlugins(oreq *http.Request, rules []Plugin){ }) defer p.Release() - for i := range rules { - item := &ScanItem{oreq, &rules[i]} + for i := range plugins { + item := &ScanItem{oreq, &plugins[i]} wg.Add(1) p.Invoke(item) }