diff --git a/api/routers/v1/auth.go b/api/routers/v1/auth/auth.go similarity index 100% rename from api/routers/v1/auth.go rename to api/routers/v1/auth/auth.go diff --git a/api/routers/v1/plugin.go b/api/routers/v1/plugin/plugin.go similarity index 71% rename from api/routers/v1/plugin.go rename to api/routers/v1/plugin/plugin.go index e2fc8b1..03baaf9 100644 --- a/api/routers/v1/plugin.go +++ b/api/routers/v1/plugin/plugin.go @@ -1,6 +1,7 @@ package v1 import ( + "bufio" "github.com/astaxie/beego/validation" "github.com/gin-gonic/gin" "github.com/jweny/pocassist/api/msg" @@ -10,6 +11,13 @@ import ( "github.com/jweny/pocassist/poc/rule" "github.com/unknwon/com" "gorm.io/datatypes" + "log" +) + +const ( + TargetUrl = "url" + TargetUrlFile = "file" + TargetUrlRaw = "raw" ) type PluginSerializer struct { @@ -23,13 +31,21 @@ type PluginSerializer struct { Description int `gorm:"column:description" json:"description"` } -type RunPluginSerializer struct { - // 运行 - Target string `json:"target"` +type RunSinglePluginSerializer struct { + // 运行单个 + Target string `json:"target"` Affects string `gorm:"column:affects" json:"affects"` JsonPoc datatypes.JSON `gorm:"column:json_poc" json:"json_poc"` } +type RunPluginsSerializer struct { + // 批量运行 + Target string `json:"target"` + TargetType string `json:"target_type"` + RunType string `json:"run_type"` + VulIdList []string `json:"vul_id_list"` +} + //获取单个plugin func GetPlugin(c *gin.Context) { id := com.StrTo(c.Param("id")).MustInt() @@ -137,11 +153,12 @@ func UpdatePlugin(c *gin.Context) { valid.Required(plugin.Affects, "Affects").Message("Affects不能为空") if ! valid.HasErrors() { - if db.ExistPluginByID(plugin.Id){ + if db.ExistPluginByVulId(plugin.VulId){ + c.JSON(msg.ErrResp("漏洞编号已存在")) + return + } else { db.EditPlugin(plugin.Id, plugin) c.JSON(msg.SuccessResp(plugin)) - } else { - c.JSON(msg.ErrResp("record not found")) return } } else { @@ -172,9 +189,9 @@ func DeletePlugin(c *gin.Context) { } } -//运行 +//运行单个plugin 不是从数据库提取数据,表单传数据 func RunPlugin(c *gin.Context) { - run := RunPluginSerializer{} + run := RunSinglePluginSerializer{} err := c.BindJSON(&run) if err != nil { c.JSON(msg.ErrResp("参数校验不通过")) @@ -195,7 +212,7 @@ func RunPlugin(c *gin.Context) { Affects: run.Affects, JsonPoc: poc, } - item := &rule.ScanItem{Req: oreq, Vul: ¤tPlugin} + item := &rule.ScanItem{Req: oreq, Plugin: ¤tPlugin} result, err := rule.RunPoc(item) if err != nil { c.JSON(msg.ErrResp("规则运行失败:" + err.Error())) @@ -210,6 +227,62 @@ func RunPlugin(c *gin.Context) { } } +//批量运行plugin 从数据库提取数据,表单传数据 +//前端向后端传 "vul_id_list":["poc_db_1","poc_db_2"] +func RunPlugins(c *gin.Context) { + runs := RunPluginsSerializer{} + err := c.BindJSON(&runs) + if err != nil { + c.JSON(msg.ErrResp("参数校验不通过")) + return + } + plugins, err := rule.LoadDbPlugin(runs.RunType, runs.VulIdList) + + switch runs.TargetType { + case TargetUrl: + url := runs.TargetType + oreq, err := util.GenOriginalReq(url) + if err != nil { + logging.GlobalLogger.Error("[original request gen err ]", err) + c.JSON(msg.ErrResp("原始请求生成失败")) + return + } + rule.RunPlugins(oreq, plugins) + case TargetUrlFile: + //获取文件 + file, header, err := c.Request.FormFile("file") + if err != nil { + logging.GlobalLogger.Error("[original request gen err ]", err) + c.JSON(msg.ErrResp("url文件上传失败")) + return + } + log.Print(header.Filename) + //content, err := ioutil.ReadAll(file) + var targets []string + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + val := scanner.Text() + if val == "" { + continue + } + targets = append(targets, val) + } + + for _, url := range targets { + oreq, err := util.GenOriginalReq(url) + if err != nil { + logging.GlobalLogger.Error("[original request gen err ]", err) + } + logging.GlobalLogger.Info("[start check url ]", url) + rule.RunPlugins(oreq, plugins) + } + case TargetUrlRaw: + //请求报文 + } +} + + diff --git a/api/routers/v1/vulnerabilities.go b/api/routers/v1/vulnerability/vulnerability.go similarity index 100% rename from api/routers/v1/vulnerabilities.go rename to api/routers/v1/vulnerability/vulnerability.go diff --git a/cmd/run.go b/cmd/run.go index 43000dd..fac9fce 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -8,6 +8,7 @@ import ( "github.com/jweny/pocassist/pkg/util" "github.com/jweny/pocassist/poc/rule" "github.com/urfave/cli/v2" + "log" "os" "sort" ) @@ -18,15 +19,13 @@ var ( rawFile string loadPoc string condition string - debug bool - dbname string ) func InitAll() { // config 必须最先加载 conf2.Setup() - logging.Setup(debug) - db.Setup(dbname) + logging.Setup() + db.Setup() routers.Setup() util.Setup() rule.Setup() @@ -37,21 +36,6 @@ func RunApp() { app.Name = "pocassist" app.Usage = "New POC Framework Without Writing Code" app.Version = "0.3.0" - // 全局flag - app.Flags = []cli.Flag{ - &cli.BoolFlag{ - Name: "debug", - Aliases: []string{"d"}, - Destination: &debug, - Value: false, - Usage: "enable debug log"}, - &cli.StringFlag{ - Name: "database", - Aliases: []string{"b"}, - Destination: &dbname, - Value: "sqlite", - Usage: "kind of database, default: sqlite"}, - } // 子命令 app.Commands = []*cli.Command{ @@ -64,7 +48,7 @@ func RunApp() { err := app.Run(os.Args) if err != nil { - logging.GlobalLogger.Error("[app run err ]", err) + log.Fatalf("cli.RunApp err: %v", err) return } } diff --git a/config.yaml b/config.yaml.example similarity index 100% rename from config.yaml rename to config.yaml.example diff --git a/pkg/conf/config.go b/pkg/conf/config.go index 4f1a6cd..9cf844d 100644 --- a/pkg/conf/config.go +++ b/pkg/conf/config.go @@ -1,6 +1,8 @@ package conf import ( + "bytes" + "github.com/fsnotify/fsnotify" "github.com/spf13/viper" "log" "os" @@ -64,24 +66,59 @@ var GlobalConfig *Config // 加载配置 func Setup() { - // 加载config - var err error dir, err := filepath.Abs(filepath.Dir(os.Args[0])) if err != nil { - log.Fatalf("config.Setup, fail to get current path: %v", err) + log.Fatalf("conf.Setup, fail to get current path: %v", err) } + // 配置文件路径 当前文件夹 + config.yaml configFile := path.Join(dir, "config.yaml") - viper.SetConfigFile(configFile) + + // 检测配置文件是否存在 + _ , err = os.Lstat(configFile) + if err != nil { + // 没有,生成默认yaml + WriteYamlConfig(configFile) + } + // watch配置 + ReadYamlConfig(configFile) + +} + +func ReadYamlConfig(configFile string) { + // 加载config viper.SetConfigType("yaml") + viper.SetConfigFile(configFile) - err = viper.ReadInConfig() + err := viper.ReadInConfig() if err != nil { - log.Fatalf("config.Setup, fail to read 'config.yaml': %v", err) + log.Fatalf("conf.Setup, fail to read 'config.yaml': %v", err) } err = viper.Unmarshal(&GlobalConfig) if err != nil { - log.Fatalf("config.Setup, fail to parse 'config.yaml': %v", err) + log.Fatalf("conf.Setup, fail to parse 'config.yaml', check format: %v", err) + } + err = verifiyConfig() + if err != nil { + log.Fatalf("conf.Setup, fail to verify 'config.yaml', check format: %v", err) } + // watch 监控配置文件变化 + viper.WatchConfig() + viper.OnConfigChange(func(e fsnotify.Event) { + // 配置文件发生变更之后会调用的回调函数 + log.Println("Config file changed:", e.Name) + }) } - +func WriteYamlConfig(configFile string) { + // 生成默认config + viper.SetConfigType("yaml") + err := viper.ReadConfig(bytes.NewBuffer(defaultYamlByte)) + if err != nil { + log.Fatalf("conf.Setup, fail to read default config bytes: %v", err) + } + // 写文件 + err = viper.SafeWriteConfigAs(configFile) + if err != nil { + log.Fatalf("conf.Setup, fail to write 'config.yaml': %v", err) + } +} \ No newline at end of file diff --git a/pkg/conf/default.go b/pkg/conf/default.go new file mode 100644 index 0000000..3bbc577 --- /dev/null +++ b/pkg/conf/default.go @@ -0,0 +1,84 @@ +package conf + +import ( + "encoding/json" + "errors" +) + +var defaultYamlByte = []byte(` +# webserver配置 +serverConfig: + # 配置jwt秘钥 + jwt_secret: "pocassist" + # gin的运行模式 "release" 或者 "debug" + run_mode: "release" + # 运行日志的文件名,日志将保存在二进制所在目录 + log_name : "debug.log" + +# HTTP配置 +httpConfig: + # 扫描时使用的代理:格式为 IP:PORT,example: 如 burpsuite,可填写 127.0.0.1:8080 + proxy: "" + # 读取 http 响应超时时间,不建议设置太小,否则可能影响到盲注的判断 + http_timeout: 10 + # 建立 tcp 连接的超时时间 + dail_timeout: 5 + # udp 超时时间 + udp_timeout: 5 + # 每秒最大请求数 + max_qps: 100 + # 单个请求最大允许的跳转次数 + max_redirect: 5 + headers: + # 默认 UA + user_agent: "Mozilla/5.0 (Windows NT 10.0; rv:78.0) Gecko/20100101 Firefox/78.0" + +# 数据库配置 +dbConfig: + # sqlite配置:sqlite数据库文件的路径 + sqlite : "pocassist.db" + # mysql配置 + mysql: + host: "127.0.0.1" + password: "" + port: "3306" + user: "root" + database: "pocassist" + # 数据库连接超时时间 + timeout: "3s" + +# 插件配置 +pluginsConfig: + # 并发量:同时运行的插件数量 + parallel: 8 + +# 反连平台配置: 目前使用 ceye.io +reverse: + api_key: "" + domain: "" +`) + +var runMode = []string{"debug","release"} + +func ArrayToString (array []string) string { + str, _ := json.Marshal(array) + return string(str) +} + +func StrInArray (str string, array []string) error { + for _, element := range array{ + if str == element{ + return nil + } + } + return errors.New(str + "must in" + ArrayToString(array)) +} + +func verifiyConfig() error { + var err error + err = StrInArray(GlobalConfig.ServerConfig.RunMode, runMode) + if err != nil { + return err + } + return nil +} diff --git a/pkg/db/conn.go b/pkg/db/conn.go index 8f1bbfc..364a4ed 100644 --- a/pkg/db/conn.go +++ b/pkg/db/conn.go @@ -12,13 +12,11 @@ import ( var GlobalDB *gorm.DB -func Setup(dbname string) { - if dbname != "mysql" && dbname != "sqlite" { - log.Fatalf("db.Setup err: unsupported database kind. only 'sqlite' or 'mysql'") - } +func Setup() { var err error dbConfig := conf.GlobalConfig.DbConfig - if dbname == "mysql" { + + if conf.GlobalConfig.DbConfig.Sqlite == "" { // 配置mysql数据源 if dbConfig.Mysql.User == "" || dbConfig.Mysql.Password == "" || @@ -39,8 +37,7 @@ func Setup(dbname string) { if err != nil { log.Fatalf("db.Setup err: %v", err) } - } - if dbname == "sqlite" { + } else { // 配置sqlite数据源 if dbConfig.Sqlite == "" { log.Fatalf("db.Setup err: config.yaml sqlite config not set") @@ -50,6 +47,7 @@ func Setup(dbname string) { log.Fatalf("db.Setup err: %v", err) } } + if GlobalDB == nil { log.Fatalf("db.Setup err: db connect failed") } diff --git a/pkg/logging/log.go b/pkg/logging/log.go index 1e1c326..8865363 100644 --- a/pkg/logging/log.go +++ b/pkg/logging/log.go @@ -14,7 +14,13 @@ import ( var GlobalLogger *logrus.Logger -func Setup(debug bool) { +func Setup() { + var debug = false + + if conf2.GlobalConfig.ServerConfig.RunMode == "debug" { + debug = true + } + logName := conf2.GlobalConfig.ServerConfig.LogName if logName == "" { logName = "debug.log" diff --git a/poc/rule/controller.go b/poc/rule/controller.go index 18ecfab..d27d38b 100644 --- a/poc/rule/controller.go +++ b/poc/rule/controller.go @@ -27,7 +27,7 @@ var ControllerPool = sync.Pool{ } type PocController struct { - vulId string + pluginId string originalReq *http.Request // 原始请求 --> 初始条件 poc *Poc // 加载的poc --> 初始条件 NewReq *proto.Request // 生成的新请求 @@ -73,7 +73,7 @@ func (controller *PocController) Reset() { controller.poc = nil controller.celEnv = nil controller.NewReq = nil - controller.vulId = "" + controller.pluginId = "" return } diff --git a/poc/rule/handle.go b/poc/rule/handle.go index 763b332..cef78d5 100644 --- a/poc/rule/handle.go +++ b/poc/rule/handle.go @@ -30,8 +30,8 @@ func LimitWait() { // 限制并发 type ScanItem struct { - Req *http.Request // 原始请求 - Vul *Plugin // vul from db + Req *http.Request // 原始请求 + Plugin *Plugin // 检测插件 } var Handles map[string][]HandlerFunc @@ -48,15 +48,15 @@ func ExecExpressionHandle(controller *PocController) error { return err } if result { - logging.GlobalLogger.Info("[=== find vul===]\n", - " [vul_id] ", controller.vulId, - " [vul_name] ", controller.poc.Name) + logging.GlobalLogger.Info("[=== find vulnerability===]\n", + " [plugin_id] ", controller.pluginId, + " [plugin_name] ", controller.poc.Name) controller.Abort() } logging.GlobalLogger.Info("[=== not vul===]\n", - " [vul_id] ", controller.vulId, - " [vul_name] ", controller.poc.Name) + " [plugin_id] ", controller.pluginId, + " [plugin_name] ", controller.poc.Name) return nil } @@ -90,13 +90,13 @@ func ExecScriptHandle(controller *PocController) error { result, err := scanFunc(args) if err != nil { - logging.GlobalLogger.Error("[script scan failed ]", controller.vulId, " err:", err) + logging.GlobalLogger.Error("[script scan failed ]", controller.pluginId, " err:", err) return err } logging.GlobalLogger.Info("[script scan finished ]", - " [vul_id] ", controller.vulId, + " [plugin_id] ", controller.pluginId, " [script_func] ", scanFunc, - " [vul_result] ", result) + " [result] ", result) return nil } diff --git a/poc/rule/model.go b/poc/rule/model.go index 50e90ea..183b7bc 100644 --- a/poc/rule/model.go +++ b/poc/rule/model.go @@ -38,8 +38,8 @@ type Poc struct { } type Plugin struct { - VulId string `gorm:"column:vul_id"` // 漏洞编号 - Affects string `gorm:"column:affects"` // 影响类型 dir/server/param/url/content - JsonPoc *Poc `gorm:"column:json_poc"` // json规则 - Enable bool `gorm:"column:enable"` // 是否启用 + VulId string `gorm:"column:vul_id"` // 漏洞编号 + Affects string `gorm:"column:affects"` // 影响类型 dir/server/param/url/content + JsonPoc *Poc `gorm:"column:json_poc"` // json规则 + Enable bool `gorm:"column:enable"` // 是否启用 } \ No newline at end of file diff --git a/poc/rule/plugin.go b/poc/rule/plugin.go index f3da933..9064581 100644 --- a/poc/rule/plugin.go +++ b/poc/rule/plugin.go @@ -2,6 +2,7 @@ package rule import ( "encoding/json" + "github.com/jweny/pocassist/pkg/conf" "github.com/jweny/pocassist/pkg/db" "github.com/jweny/pocassist/pkg/logging" "github.com/panjf2000/ants/v2" @@ -114,9 +115,9 @@ func LoadPlugins(loadType string, conditions string) ([]Plugin, error) { func RunPlugins(oreq *http.Request, rules []Plugin){ // 并发限制 var wg sync.WaitGroup - //parallel := conf.GlobalConfig.PluginsConfig.Parallel + parallel := conf.GlobalConfig.PluginsConfig.Parallel - p, _ := ants.NewPoolWithFunc(10, func(item interface{}) { + p, _ := ants.NewPoolWithFunc(parallel, func(item interface{}) { RunPoc(item) wg.Done() }) diff --git a/poc/rule/poc.go b/poc/rule/poc.go index 3aef9cc..e0a8819 100644 --- a/poc/rule/poc.go +++ b/poc/rule/poc.go @@ -12,27 +12,27 @@ import ( // 执行单个poc func RunPoc(inter interface{}) (*util.ScanResult, error) { - item := inter.(*ScanItem) - originalReq := item.Req - vul := item.Vul + scanItem := inter.(*ScanItem) + originalReq := scanItem.Req + plugin := scanItem.Plugin - if originalReq == nil || vul == nil { - return nil, errors.New("no request or no vul") + if originalReq == nil || plugin == nil { + return nil, errors.New("no request or no plugin") } var data []byte if originalReq.Body != nil && originalReq.Body != http.NoBody { data, err := ioutil.ReadAll(originalReq.Body) if err != nil { - logging.GlobalLogger.Error("[plugin originalReq data read err ]", vul.VulId) + logging.GlobalLogger.Error("[plugin originalReq data read err ]", plugin.VulId) return nil, err } originalReq.Body = ioutil.NopCloser(bytes.NewBuffer(data)) } - handles := getHandles(vul.Affects) - logging.GlobalLogger.Debug("[plugin running ]" , vul.VulId, " [affects] ", vul.Affects, " [name] ", vul.JsonPoc.Name) + handles := getHandles(plugin.Affects) + logging.GlobalLogger.Debug("[plugin running ]" , plugin.VulId, " [affects] ", plugin.Affects, " [name] ", plugin.JsonPoc.Name) // 影响为参数类型 - if vul.Affects == AffectReplaceParameter || vul.Affects == AffectAppendParameter { + if plugin.Affects == AffectReplaceParameter || plugin.Affects == AffectAppendParameter { var originalGetParamFields url.Values var replaceHandler ReplaceHandler var err error @@ -52,29 +52,29 @@ func RunPoc(inter interface{}) (*util.ScanResult, error) { replaceHandler = &ReplacePost{} } - env, err := GenCelEnv(vul.JsonPoc) + env, err := GenCelEnv(plugin.JsonPoc) if err != nil { - logging.GlobalLogger.Error("[plugin cel env gen err ]" , vul.VulId) + logging.GlobalLogger.Error("[plugin cel env gen err ]" , plugin.VulId) return nil, err } newReq, err := InitNewReq(originalReq) if err != nil { - logging.GlobalLogger.Error("[plugin new request init err ]" , vul.VulId) + logging.GlobalLogger.Error("[plugin new request init err ]" , plugin.VulId) return nil, err } - varMap, err := ParsePocSet(vul.JsonPoc, env, newReq) + varMap, err := ParsePocSet(plugin.JsonPoc, env, newReq) if err != nil { util.RequestPut(newReq) - logging.GlobalLogger.Error("[plugin poc set parse err ]", vul.VulId, err) + logging.GlobalLogger.Error("[plugin poc set parse err ]", plugin.VulId, err) return nil, err } for field := range originalGetParamFields { - for _, value := range vul.JsonPoc.Params { + for _, value := range plugin.JsonPoc.Params { // 限速 LimitWait() logging.GlobalLogger.Debug("[current param]", value) - controller := InitPocController(originalReq, vul.JsonPoc, vul.Affects, data) + controller := InitPocController(originalReq, plugin.JsonPoc, plugin.Affects, data) controller.celEnv = env controller.varMap = varMap controller.Handles = handles @@ -90,8 +90,8 @@ func RunPoc(inter interface{}) (*util.ScanResult, error) { controller.Reset() util.RequestPut(newReq) logging.GlobalLogger.Info("[plugin result ]\n", - " [vul_id] ", vul.VulId, - " [vul_name] ", vul.JsonPoc.Name, + " [plugin_id] ", plugin.VulId, + " [plugin_name] ", plugin.JsonPoc.Name, " [param] ", value) return util.VulnerableHttpResult(controller.originalReq.URL.String(),"", controller.respList), nil @@ -103,29 +103,29 @@ func RunPoc(inter interface{}) (*util.ScanResult, error) { } else { // 其他类型 - env, err := GenCelEnv(vul.JsonPoc) + env, err := GenCelEnv(plugin.JsonPoc) if err != nil { - logging.GlobalLogger.Error("[plugin cel env gen err ]" , vul.VulId) + logging.GlobalLogger.Error("[plugin cel env gen err ]" , plugin.VulId) return nil, err } newReq, err := InitNewReq(originalReq) if err != nil { - logging.GlobalLogger.Error("[plugin new request init err ]" , vul.VulId) + logging.GlobalLogger.Error("[plugin new request init err ]" , plugin.VulId) return nil, err } - varMap, err := ParsePocSet(vul.JsonPoc, env, newReq) + varMap, err := ParsePocSet(plugin.JsonPoc, env, newReq) if err != nil { util.RequestPut(newReq) - logging.GlobalLogger.Error("plugin poc set parse err ]", vul.VulId, err) + logging.GlobalLogger.Error("plugin poc set parse err ]", plugin.VulId, err) return nil, err } // 限速 LimitWait() - controller := InitPocController(originalReq, vul.JsonPoc, vul.Affects, data) + controller := InitPocController(originalReq, plugin.JsonPoc, plugin.Affects, data) controller.celEnv = env controller.varMap = varMap controller.Handles = handles - controller.vulId = vul.VulId + controller.pluginId = plugin.VulId err = controller.Next() if err != nil { @@ -135,9 +135,9 @@ func RunPoc(inter interface{}) (*util.ScanResult, error) { if controller.IsAborted() { result := util.VulnerableHttpResult(controller.originalReq.URL.String(),"", controller.respList) logging.GlobalLogger.Info("[plugin scan result ]\n", - " [vul_id] ", vul.VulId, - " [vul_name] ", vul.JsonPoc.Name, - " [vul_result] ", result) + " [plugin_id] ", plugin.VulId, + " [plugin_name] ", plugin.JsonPoc.Name, + " [plugin_result] ", result) return result, nil } controller.Reset()