diff --git a/app.go b/app.go index dfadcac1272..fea370202e6 100644 --- a/app.go +++ b/app.go @@ -15,6 +15,8 @@ import ( "errors" "fmt" "io" + "maps" + "mime" "net" "net/http" "net/http/httputil" @@ -706,6 +708,200 @@ func (app *App) Name(name string) Router { return app } +// Summary assigns a short summary to the most recently added route. +func (app *App) Summary(sum string) Router { + app.mutex.Lock() + app.latestRoute.Summary = sum + app.mutex.Unlock() + return app +} + +// Description assigns a description to the most recently added route. +func (app *App) Description(desc string) Router { + app.mutex.Lock() + app.latestRoute.Description = desc + app.mutex.Unlock() + return app +} + +// Consumes assigns a request media type to the most recently added route. +func (app *App) Consumes(typ string) Router { + if typ != "" { + if _, _, err := mime.ParseMediaType(typ); err != nil || !strings.Contains(typ, "/") { + panic("invalid media type: " + typ) + } + } + app.mutex.Lock() + app.latestRoute.Consumes = typ + app.mutex.Unlock() + return app +} + +// Produces assigns a response media type to the most recently added route. +func (app *App) Produces(typ string) Router { + if typ != "" { + if _, _, err := mime.ParseMediaType(typ); err != nil || !strings.Contains(typ, "/") { + panic("invalid media type: " + typ) + } + } + app.mutex.Lock() + app.latestRoute.Produces = typ + app.mutex.Unlock() + return app +} + +// RequestBody documents the request payload for the most recently added route. +func (app *App) RequestBody(description string, required bool, mediaTypes ...string) Router { + sanitized := sanitizeRequiredMediaTypes(mediaTypes) + + app.mutex.Lock() + app.latestRoute.RequestBody = &RouteRequestBody{ + Description: description, + Required: required, + MediaTypes: append([]string(nil), sanitized...), + } + if len(sanitized) > 0 { + app.latestRoute.Consumes = sanitized[0] + } + app.mutex.Unlock() + + return app +} + +// Parameter documents an input parameter for the most recently added route. +func (app *App) Parameter(name, in string, required bool, schema map[string]any, description string) Router { + if strings.TrimSpace(name) == "" { + panic("parameter name is required") + } + + location := strings.ToLower(strings.TrimSpace(in)) + switch location { + case "path", "query", "header", "cookie": + default: + panic("invalid parameter location: " + in) + } + + if schema == nil { + schema = map[string]any{"type": "string"} + } + + schemaCopy := make(map[string]any, len(schema)) + maps.Copy(schemaCopy, schema) + if _, ok := schemaCopy["type"]; !ok { + schemaCopy["type"] = "string" + } + + if location == "path" { + required = true + } + + param := RouteParameter{ + Name: name, + In: location, + Required: required, + Description: description, + Schema: schemaCopy, + } + + app.mutex.Lock() + app.latestRoute.Parameters = append(app.latestRoute.Parameters, param) + app.mutex.Unlock() + + return app +} + +// Response documents an HTTP response for the most recently added route. +func (app *App) Response(status int, description string, mediaTypes ...string) Router { + if status != 0 && (status < 100 || status > 599) { + panic("invalid status code") + } + + sanitized := sanitizeMediaTypes(mediaTypes) + + if description == "" { + if status == 0 { + description = "Default response" + } else if text := http.StatusText(status); text != "" { + description = text + } else { + description = "Status " + strconv.Itoa(status) + } + } + + key := "default" + if status > 0 { + key = strconv.Itoa(status) + } + + resp := RouteResponse{Description: description} + if len(sanitized) > 0 { + resp.MediaTypes = append([]string(nil), sanitized...) + } + + app.mutex.Lock() + if app.latestRoute.Responses == nil { + app.latestRoute.Responses = make(map[string]RouteResponse) + } + app.latestRoute.Responses[key] = resp + if status == StatusOK && len(resp.MediaTypes) > 0 { + app.latestRoute.Produces = resp.MediaTypes[0] + } + app.mutex.Unlock() + + return app +} + +func sanitizeMediaTypes(mediaTypes []string) []string { + if len(mediaTypes) == 0 { + return nil + } + + seen := make(map[string]struct{}, len(mediaTypes)) + sanitized := make([]string, 0, len(mediaTypes)) + for _, typ := range mediaTypes { + trimmed := strings.TrimSpace(typ) + if trimmed == "" { + continue + } + if _, _, err := mime.ParseMediaType(trimmed); err != nil || !strings.Contains(trimmed, "/") { + panic("invalid media type: " + trimmed) + } + if _, ok := seen[trimmed]; ok { + continue + } + seen[trimmed] = struct{}{} + sanitized = append(sanitized, trimmed) + } + if len(sanitized) == 0 { + return nil + } + return sanitized +} + +func sanitizeRequiredMediaTypes(mediaTypes []string) []string { + sanitized := sanitizeMediaTypes(mediaTypes) + if len(sanitized) == 0 { + panic("at least one media type must be provided") + } + return sanitized +} + +// Tags assigns tags to the most recently added route. +func (app *App) Tags(tags ...string) Router { + app.mutex.Lock() + app.latestRoute.Tags = tags + app.mutex.Unlock() + return app +} + +// Deprecated marks the most recently added route as deprecated. +func (app *App) Deprecated() Router { + app.mutex.Lock() + app.latestRoute.Deprecated = true + app.mutex.Unlock() + return app +} + // GetRoute Get route by name func (app *App) GetRoute(name string) Route { for _, routes := range app.stack { diff --git a/docs/middleware/openapi.md b/docs/middleware/openapi.md new file mode 100644 index 00000000000..f34771c3ca6 --- /dev/null +++ b/docs/middleware/openapi.md @@ -0,0 +1,142 @@ +--- +id: openapi +--- + +# OpenAPI + +OpenAPI middleware for [Fiber](https://github.com/gofiber/fiber) that generates an OpenAPI specification based on the routes registered in your application. + +## Signatures + +```go +func New(config ...Config) fiber.Handler +``` + +## Examples + +Import the middleware package that is part of the Fiber web framework + +```go +import ( + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/openapi" +) +``` + +After you initiate your Fiber app, you can use the following possibilities: + +```go +// Initialize default config. Register the middleware *after* all routes +// so that the spec includes every handler. +app.Use(openapi.New()) + +// Or extend your config for customization +app.Use(openapi.New(openapi.Config{ + Title: "My API", + Version: "1.0.0", + ServerURL: "https://example.com", +})) + +// Customize metadata for specific operations +app.Use(openapi.New(openapi.Config{ + Operations: map[string]openapi.Operation{ + "GET /users": { + Summary: "List users", + Description: "Returns all users", + Produces: fiber.MIMEApplicationJSON, + }, + }, +})) + +// Routes may optionally document themselves using Summary, Description, +// RequestBody, Parameter, Response, Tags, Deprecated, Produces and Consumes. +app.Post("/users", createUser). + Summary("Create user"). + Description("Creates a new user"). + RequestBody("User payload", true, fiber.MIMEApplicationJSON). + Parameter("trace-id", "header", true, nil, "Tracing identifier"). + Response(fiber.StatusCreated, "Created", fiber.MIMEApplicationJSON). + Tags("users", "admin"). + Produces(fiber.MIMEApplicationJSON) + +// If not specified, routes default to an empty summary and description, no tags, +// not deprecated, and a "text/plain" request and response media type. +// Consumes and Produces will panic if provided an invalid media type. +``` + +Each documented route automatically includes a `200` response with the description `OK` to satisfy the minimum OpenAPI requirements. Additional responses can be declared via the `Response` helper or the middleware configuration. + +`CONNECT` routes are ignored because the OpenAPI specification does not define a `connect` operation. + +## Config + +| Property | Type | Description | Default | +|:------------|:------------------------|:----------------------------------------------------------------|:------------------:| +| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` | +| Title | `string` | Title is the title for the generated OpenAPI specification. | `"Fiber API"` | +| Version | `string` | Version is the version for the generated OpenAPI specification. | `"1.0.0"` | +| Description | `string` | Description is the description for the generated specification. | `""` | +| ServerURL | `string` | ServerURL is the server URL used in the generated specification.| `""` | +| Path | `string` | Path is the route where the specification will be served. | `"/openapi.json"` | +| Operations | `map[string]Operation` | Per-route metadata keyed by `METHOD /path`. | `nil` | + +When the middleware is attached to a group or mounted under a prefixed `Use`, the configured `Path` is resolved relative to that +prefix. For example, `app.Group("/v1").Use(openapi.New())` serves the specification at `/v1/openapi.json`, while a global `app.U +se(openapi.New())` only intercepts `/openapi.json` and will not affect other endpoints ending in `openapi.json`. + +## Default Config + +```go +var ConfigDefault = Config{ + Next: nil, + Operations: nil, + Title: "Fiber API", + Version: "1.0.0", + Description: "", + ServerURL: "", + Path: "/openapi.json", +} +``` + +### Operation + +```go +type Operation struct { + RequestBody *RequestBody + Responses map[string]Response + Parameters []Parameter + Tags []string + + ID string + Summary string + Description string + Consumes string + Produces string + Deprecated bool +} + +type Parameter struct { + Schema map[string]any + Name string + In string + Description string + Required bool +} + +type Media struct { + Schema map[string]any +} + +type Response struct { + Content map[string]Media + Description string +} + +type RequestBody struct { + Content map[string]Media + Description string + Required bool +} +``` + +Refer to the type definitions above when customizing OpenAPI operations in your configuration. diff --git a/docs/whats_new.md b/docs/whats_new.md index 581f1a59234..c9251ff52ac 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -1288,6 +1288,10 @@ Deprecated fields `Duration`, `Store`, and `Key` have been removed in v3. Use `E Monitor middleware is migrated to the [Contrib package](https://github.com/gofiber/contrib/tree/main/monitor) with [PR #1172](https://github.com/gofiber/contrib/pull/1172). +### OpenAPI + +Introduces an `openapi` middleware that inspects registered routes and serves a generated OpenAPI 3.0 specification. Each operation includes a summary and default `200` response. Routes may attach descriptions, parameters, request bodies, and custom responses—alongside request/response media types—directly or configure them globally. + ### Proxy The proxy middleware has been updated to improve consistency with Go naming conventions. The `TlsConfig` field in the configuration struct has been renamed to `TLSConfig`. Additionally, the `WithTlsConfig` method has been removed; you should now configure TLS directly via the `TLSConfig` property within the `Config` struct. diff --git a/group.go b/group.go index f85674bfb8c..6bb1f172e47 100644 --- a/group.go +++ b/group.go @@ -45,6 +45,60 @@ func (grp *Group) Name(name string) Router { return grp } +// Summary assigns a short summary to the most recently added route in the group. +func (grp *Group) Summary(sum string) Router { + grp.app.Summary(sum) + return grp +} + +// Description assigns a description to the most recently added route in the group. +func (grp *Group) Description(desc string) Router { + grp.app.Description(desc) + return grp +} + +// Consumes assigns a request media type to the most recently added route in the group. +func (grp *Group) Consumes(typ string) Router { + grp.app.Consumes(typ) + return grp +} + +// Produces assigns a response media type to the most recently added route in the group. +func (grp *Group) Produces(typ string) Router { + grp.app.Produces(typ) + return grp +} + +// RequestBody documents the request payload for the most recently added route in the group. +func (grp *Group) RequestBody(description string, required bool, mediaTypes ...string) Router { + grp.app.RequestBody(description, required, mediaTypes...) + return grp +} + +// Parameter documents an input parameter for the most recently added route in the group. +func (grp *Group) Parameter(name, in string, required bool, schema map[string]any, description string) Router { + grp.app.Parameter(name, in, required, schema, description) + return grp +} + +// Response documents an HTTP response for the most recently added route in the group. +func (grp *Group) Response(status int, description string, mediaTypes ...string) Router { + grp.app.Response(status, description, mediaTypes...) + return grp +} + +// Tags assigns tags to the most recently added route in the group. +func (grp *Group) Tags(tags ...string) Router { + grp.app.Tags(tags...) + return grp +} + +// Deprecated marks the most recently added route in the group as deprecated. +func (grp *Group) Deprecated() Router { + grp.app.Deprecated() + return grp +} + // Use registers a middleware route that will match requests // with the provided prefix (which is optional and defaults to "/"). // Also, you can pass another app instance as a sub-router along a routing path. diff --git a/group_test.go b/group_test.go new file mode 100644 index 00000000000..fd6a63bf286 --- /dev/null +++ b/group_test.go @@ -0,0 +1,99 @@ +package fiber + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_Group_OpenAPI_Helpers(t *testing.T) { + t.Parallel() + + t.Run("Summary", func(t *testing.T) { + t.Parallel() + app := New() + grp := app.Group("/api") + grp.Get("/users", testEmptyHandler).Summary("sum") + route := app.stack[app.methodInt(MethodGet)][0] + require.Equal(t, "sum", route.Summary) + }) + + t.Run("Description", func(t *testing.T) { + t.Parallel() + app := New() + grp := app.Group("/api") + grp.Get("/users", testEmptyHandler).Description("desc") + route := app.stack[app.methodInt(MethodGet)][0] + require.Equal(t, "desc", route.Description) + }) + + t.Run("Consumes", func(t *testing.T) { + t.Parallel() + app := New() + grp := app.Group("/api") + grp.Get("/users", testEmptyHandler).Consumes(MIMEApplicationJSON) + route := app.stack[app.methodInt(MethodGet)][0] + //nolint:testifylint // MIMEApplicationJSON is a plain string, JSONEq not required + require.Equal(t, MIMEApplicationJSON, route.Consumes) + }) + + t.Run("Produces", func(t *testing.T) { + t.Parallel() + app := New() + grp := app.Group("/api") + grp.Get("/users", testEmptyHandler).Produces(MIMEApplicationXML) + route := app.stack[app.methodInt(MethodGet)][0] + //nolint:testifylint // MIMEApplicationXML is a plain string, JSONEq not required + require.Equal(t, MIMEApplicationXML, route.Produces) + }) + + t.Run("RequestBody", func(t *testing.T) { + t.Parallel() + app := New() + grp := app.Group("/api") + grp.Post("/users", testEmptyHandler).RequestBody("User", true, MIMEApplicationJSON) + route := app.stack[app.methodInt(MethodPost)][0] + require.NotNil(t, route.RequestBody) + require.Equal(t, []string{MIMEApplicationJSON}, route.RequestBody.MediaTypes) + }) + + t.Run("Parameter", func(t *testing.T) { + t.Parallel() + app := New() + grp := app.Group("/api") + grp.Get("/users/:id", testEmptyHandler).Parameter("id", "path", false, map[string]any{"type": "integer"}, "identifier") + route := app.stack[app.methodInt(MethodGet)][0] + require.Len(t, route.Parameters, 1) + require.Equal(t, "id", route.Parameters[0].Name) + require.True(t, route.Parameters[0].Required) + require.Equal(t, "integer", route.Parameters[0].Schema["type"]) + }) + + t.Run("Response", func(t *testing.T) { + t.Parallel() + app := New() + grp := app.Group("/api") + grp.Get("/users", testEmptyHandler).Response(StatusCreated, "Created", MIMEApplicationJSON) + route := app.stack[app.methodInt(MethodGet)][0] + require.Contains(t, route.Responses, "201") + require.Equal(t, []string{MIMEApplicationJSON}, route.Responses["201"].MediaTypes) + }) + + t.Run("Tags", func(t *testing.T) { + t.Parallel() + app := New() + grp := app.Group("/api") + grp.Get("/users", testEmptyHandler).Tags("foo", "bar") + route := app.stack[app.methodInt(MethodGet)][0] + require.Equal(t, []string{"foo", "bar"}, route.Tags) + }) + + t.Run("Deprecated", func(t *testing.T) { + t.Parallel() + app := New() + grp := app.Group("/api") + grp.Get("/users", testEmptyHandler).Deprecated() + route := app.stack[app.methodInt(MethodGet)][0] + require.True(t, route.Deprecated) + }) +} diff --git a/middleware/openapi/config.go b/middleware/openapi/config.go new file mode 100644 index 00000000000..e790c8396df --- /dev/null +++ b/middleware/openapi/config.go @@ -0,0 +1,130 @@ +package openapi + +import ( + "github.com/gofiber/fiber/v3" +) + +// Config defines the config for middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c fiber.Ctx) bool + + // Operations allows providing per-route metadata keyed by + // "METHOD /path" (e.g. "GET /users"). + // + // Optional. Default: nil + Operations map[string]Operation + + // Title is the title for the generated OpenAPI specification. + // + // Optional. Default: "Fiber API" + Title string + + // Version is the version for the generated OpenAPI specification. + // + // Optional. Default: "1.0.0" + Version string + + // Description is the description for the generated OpenAPI specification. + // + // Optional. Default: "" + Description string + + // ServerURL is the server URL used in the generated specification. + // + // Optional. Default: "" + ServerURL string + + // Path is the route where the specification will be served. + // + // Optional. Default: "/openapi.json" + Path string +} + +// ConfigDefault is the default config. +var ConfigDefault = Config{ + Next: nil, + Operations: nil, + Title: "Fiber API", + Version: "1.0.0", + Description: "", + ServerURL: "", + Path: "/openapi.json", +} + +func configDefault(config ...Config) Config { + if len(config) < 1 { + return ConfigDefault + } + + cfg := config[0] + + if cfg.Next == nil { + cfg.Next = ConfigDefault.Next + } + if cfg.Title == "" { + cfg.Title = ConfigDefault.Title + } + if cfg.Version == "" { + cfg.Version = ConfigDefault.Version + } + if cfg.Description == "" { + cfg.Description = ConfigDefault.Description + } + if cfg.ServerURL == "" { + cfg.ServerURL = ConfigDefault.ServerURL + } + if cfg.Path == "" { + cfg.Path = ConfigDefault.Path + } + if cfg.Operations == nil { + cfg.Operations = ConfigDefault.Operations + } + + return cfg +} + +// Operation configures metadata for a single route in the generated spec. +type Operation struct { + RequestBody *RequestBody + Responses map[string]Response + Parameters []Parameter + Tags []string + + ID string + Summary string + Description string + Consumes string + Produces string + Deprecated bool +} + +// Parameter describes a single OpenAPI parameter. +type Parameter struct { + Schema map[string]any + + Name string + In string + Description string + Required bool +} + +// Media describes the schema payload for a request or response media type. +type Media struct { + Schema map[string]any +} + +// Response describes an OpenAPI response object. +type Response struct { + Content map[string]Media + Description string +} + +// RequestBody describes the request body configuration for an operation. +type RequestBody struct { + Content map[string]Media + Description string + Required bool +} diff --git a/middleware/openapi/openapi.go b/middleware/openapi/openapi.go new file mode 100644 index 00000000000..388b53763d3 --- /dev/null +++ b/middleware/openapi/openapi.go @@ -0,0 +1,438 @@ +package openapi + +import ( + "encoding/json" + "fmt" + "maps" + "strings" + "sync" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/utils/v2" +) + +// New creates a new middleware handler that serves the generated OpenAPI specification. +func New(config ...Config) fiber.Handler { + cfg := configDefault(config...) + + var ( + data []byte + once sync.Once + genErr error + ) + + return func(c fiber.Ctx) error { + if cfg.Next != nil && cfg.Next(c) { + return c.Next() + } + + targetPath := resolvedSpecPath(c, cfg.Path) + if c.Path() != targetPath { + return c.Next() + } + + once.Do(func() { + spec := generateSpec(c.App(), cfg) + data, genErr = json.Marshal(spec) + if genErr != nil { + genErr = fmt.Errorf("openapi: marshal spec: %w", genErr) + } + }) + if genErr != nil { + return genErr + } + c.Set(fiber.HeaderContentType, fiber.MIMEApplicationJSONCharsetUTF8) + return c.Status(fiber.StatusOK).Send(data) + } +} + +func resolvedSpecPath(c fiber.Ctx, cfgPath string) string { + path := cfgPath + if path == "" { + path = ConfigDefault.Path + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + + route := c.Route() + if route == nil { + return path + } + + prefix := route.Path + if idx := strings.Index(prefix, "*"); idx >= 0 { + prefix = prefix[:idx] + } + if prefix == "/" || prefix == "" { + return path + } + if strings.HasSuffix(prefix, "/") { + prefix = strings.TrimSuffix(prefix, "/") + } + if prefix == "" { + return path + } + + return prefix + path +} + +type openAPISpec struct { + Paths map[string]map[string]operation `json:"paths"` + Servers []openAPIServer `json:"servers,omitempty"` + Info openAPIInfo `json:"info"` + OpenAPI string `json:"openapi"` +} + +type openAPIInfo struct { + Title string `json:"title"` + Version string `json:"version"` + Description string `json:"description,omitempty"` +} + +type openAPIServer struct { + URL string `json:"url"` +} + +type operation struct { + Responses map[string]response `json:"responses"` + RequestBody *requestBody `json:"requestBody,omitempty"` //nolint:tagliatelle + Parameters []parameter `json:"parameters,omitempty"` + Tags []string `json:"tags,omitempty"` + + OperationID string `json:"operationId,omitempty"` //nolint:tagliatelle + Summary string `json:"summary"` + Description string `json:"description"` + Deprecated bool `json:"deprecated,omitempty"` +} + +type response struct { + Content map[string]map[string]any `json:"content,omitempty"` + Description string `json:"description"` +} + +type parameter struct { + Schema map[string]any `json:"schema,omitempty"` + Description string `json:"description,omitempty"` + Name string `json:"name"` + In string `json:"in"` + Required bool `json:"required"` +} + +type requestBody struct { + Content map[string]map[string]any `json:"content"` + Description string `json:"description,omitempty"` + Required bool `json:"required,omitempty"` +} + +func generateSpec(app *fiber.App, cfg Config) openAPISpec { + paths := make(map[string]map[string]operation) + stack := app.Stack() + + for _, routes := range stack { + for _, r := range routes { + if r.Method == fiber.MethodConnect { + continue + } + + path := r.Path + params := make([]parameter, 0, len(r.Params)) + paramIndex := make(map[string]int, len(r.Params)) + if len(r.Params) > 0 { + for _, p := range r.Params { + path = strings.Replace(path, ":"+p, "{"+p+"}", 1) + param := parameter{ + Name: p, + In: "path", + Required: true, + Schema: map[string]any{"type": "string"}, + } + params = append(params, param) + paramIndex[param.In+":"+param.Name] = len(params) - 1 + } + } + + methodLower := utils.ToLower(r.Method) + if paths[path] == nil { + paths[path] = make(map[string]operation) + } + + key := r.Method + " " + r.Path + meta := cfg.Operations[key] + + params = mergeRouteParameters(params, paramIndex, r.Parameters) + params = mergeConfigParameters(params, paramIndex, meta.Parameters) + + summary := meta.Summary + if summary == "" { + summary = r.Summary + } + if summary == "" { + summary = r.Method + " " + r.Path + } + description := meta.Description + if description == "" { + description = r.Description + } + + respType := meta.Produces + if respType == "" { + respType = r.Produces + } + + responses := mergeResponses(r.Responses, meta.Responses) + if responses == nil { + responses = make(map[string]response) + } + defaultResp, exists := responses["200"] + if defaultResp.Description == "" { + defaultResp.Description = "OK" + } + if !exists && respType != "" { + defaultResp.Content = map[string]map[string]any{ + respType: {}, + } + } + responses["200"] = defaultResp + + reqBody := buildRequestBody(r.RequestBody, meta.RequestBody) + if reqBody == nil { + reqType := meta.Consumes + if reqType == "" { + reqType = r.Consumes + } + if shouldIncludeRequestBody(reqType, meta, r) { + reqBody = &requestBody{Content: map[string]map[string]any{reqType: {}}} + } + } + + opID := meta.ID + if opID == "" { + opID = r.Name + } + + tags := meta.Tags + if len(tags) == 0 { + tags = r.Tags + } + + deprecated := meta.Deprecated || r.Deprecated + + paths[path][methodLower] = operation{ + OperationID: opID, + Summary: summary, + Description: description, + Tags: tags, + Deprecated: deprecated, + Parameters: params, + RequestBody: reqBody, + Responses: responses, + } + } + } + + spec := openAPISpec{ + OpenAPI: "3.0.0", + Info: openAPIInfo{ + Title: cfg.Title, + Version: cfg.Version, + Description: cfg.Description, + }, + Paths: paths, + } + if cfg.ServerURL != "" { + spec.Servers = []openAPIServer{{URL: cfg.ServerURL}} + } + return spec +} + +func mergeRouteParameters(params []parameter, index map[string]int, extras []fiber.RouteParameter) []parameter { + if len(extras) == 0 { + return params + } + for _, extra := range extras { + if strings.TrimSpace(extra.Name) == "" { + continue + } + location := strings.ToLower(strings.TrimSpace(extra.In)) + if location == "" { + location = "query" + } + param := parameter{ + Name: extra.Name, + In: location, + Description: extra.Description, + Required: extra.Required, + Schema: copyAnyMap(extra.Schema), + } + if param.Schema == nil { + param.Schema = map[string]any{"type": "string"} + } + if param.In == "path" { + param.Required = true + } + params = appendOrReplaceParameter(params, index, param) + } + return params +} + +func mergeConfigParameters(params []parameter, index map[string]int, extras []Parameter) []parameter { + if len(extras) == 0 { + return params + } + for _, extra := range extras { + if strings.TrimSpace(extra.Name) == "" { + continue + } + location := strings.ToLower(strings.TrimSpace(extra.In)) + if location == "" { + location = "query" + } + param := parameter{ + Name: extra.Name, + In: location, + Description: extra.Description, + Required: extra.Required, + Schema: copyAnyMap(extra.Schema), + } + if param.Schema == nil { + param.Schema = map[string]any{"type": "string"} + } + if param.In == "path" { + param.Required = true + } + params = appendOrReplaceParameter(params, index, param) + } + return params +} + +func appendOrReplaceParameter(params []parameter, index map[string]int, p parameter) []parameter { + if p.Name == "" || p.In == "" { + return params + } + key := p.In + ":" + p.Name + if idx, ok := index[key]; ok { + params[idx] = p + return params + } + index[key] = len(params) + return append(params, p) +} + +func copyAnyMap(src map[string]any) map[string]any { + if len(src) == 0 { + return nil + } + dst := make(map[string]any, len(src)) + maps.Copy(dst, src) + return dst +} + +func mergeResponses(routeResponses map[string]fiber.RouteResponse, cfgResponses map[string]Response) map[string]response { + var merged map[string]response + if len(routeResponses) > 0 { + merged = make(map[string]response, len(routeResponses)) + for code, resp := range routeResponses { + merged[code] = response{ + Description: resp.Description, + Content: mediaTypesToContent(resp.MediaTypes), + } + } + } + if len(cfgResponses) > 0 { + if merged == nil { + merged = make(map[string]response, len(cfgResponses)) + } + for code, resp := range cfgResponses { + merged[code] = response{ + Description: resp.Description, + Content: convertMediaContent(resp.Content), + } + } + } + return merged +} + +func convertMediaContent(content map[string]Media) map[string]map[string]any { + if len(content) == 0 { + return nil + } + converted := make(map[string]map[string]any, len(content)) + for mediaType, media := range content { + entry := map[string]any{} + if schema := copyAnyMap(media.Schema); len(schema) > 0 { + entry["schema"] = schema + } + converted[mediaType] = entry + } + return converted +} + +func mediaTypesToContent(mediaTypes []string) map[string]map[string]any { + if len(mediaTypes) == 0 { + return nil + } + content := make(map[string]map[string]any, len(mediaTypes)) + for _, mediaType := range mediaTypes { + if mediaType == "" { + continue + } + content[mediaType] = map[string]any{} + } + if len(content) == 0 { + return nil + } + return content +} + +func buildRequestBody(routeBody *fiber.RouteRequestBody, cfgBody *RequestBody) *requestBody { + var merged *requestBody + if routeBody != nil { + merged = &requestBody{ + Description: routeBody.Description, + Required: routeBody.Required, + Content: mediaTypesToContent(routeBody.MediaTypes), + } + } + if cfgBody != nil { + cfgReq := &requestBody{ + Description: cfgBody.Description, + Required: cfgBody.Required, + Content: convertMediaContent(cfgBody.Content), + } + if merged == nil { + merged = cfgReq + } else { + if cfgReq.Description != "" { + merged.Description = cfgReq.Description + } + merged.Required = cfgReq.Required + if len(cfgReq.Content) > 0 { + if merged.Content == nil { + merged.Content = cfgReq.Content + } else { + maps.Copy(merged.Content, cfgReq.Content) + } + } + } + } + return merged +} + +func shouldIncludeRequestBody(reqType string, meta Operation, route *fiber.Route) bool { + if reqType == "" || route == nil { + return false + } + if meta.Consumes != "" { + return true + } + if route.Consumes != fiber.MIMETextPlain { + return true + } + switch route.Method { + case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace: + return false + default: + return true + } +} diff --git a/middleware/openapi/openapi_test.go b/middleware/openapi/openapi_test.go new file mode 100644 index 00000000000..32eefc6cef8 --- /dev/null +++ b/middleware/openapi/openapi_test.go @@ -0,0 +1,631 @@ +package openapi + +import ( + "encoding/json" + "io" + "net/http/httptest" + "os" + "strings" + "testing" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" +) + +func Test_OpenAPI_Generate(t *testing.T) { + app := fiber.New() + + app.Get("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + app.Post("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusCreated) }) + + app.Use(New()) + + req := httptest.NewRequest(fiber.MethodGet, "/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + var spec struct { + Paths map[string]map[string]any `json:"paths"` + } + require.NoError(t, json.NewDecoder(resp.Body).Decode(&spec)) + require.Contains(t, spec.Paths, "/users") + operations := spec.Paths["/users"] + require.Contains(t, operations, "get") + require.Contains(t, operations, "post") + getOp := requireMap(t, operations["get"]) + require.Contains(t, getOp, "responses") + responses := requireMap(t, getOp["responses"]) + require.Contains(t, responses, "200") +} + +func Test_OpenAPI_JSONEquality(t *testing.T) { + app := fiber.New() + + app.Get("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }). + Name("listUsers").Produces(fiber.MIMEApplicationJSON) + + app.Use(New()) + + req := httptest.NewRequest(fiber.MethodGet, "/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + rootOps := map[string]operation{} + for _, m := range app.Config().RequestMethods { + if m == fiber.MethodConnect { + continue + } + lower := strings.ToLower(m) + upper := strings.ToUpper(m) + op := operation{ + Summary: upper + " /", + Description: "", + Responses: map[string]response{ + "200": {Description: "OK", Content: map[string]map[string]any{fiber.MIMETextPlain: {}}}, + }, + } + switch m { + case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace: + default: + op.RequestBody = &requestBody{Content: map[string]map[string]any{fiber.MIMETextPlain: {}}} + } + rootOps[lower] = op + } + expected := openAPISpec{ + OpenAPI: "3.0.0", + Info: openAPIInfo{Title: "Fiber API", Version: "1.0.0"}, + Paths: map[string]map[string]operation{ + "/": rootOps, + "/users": { + "get": { + OperationID: "listUsers", + Summary: "GET /users", + Description: "", + Responses: map[string]response{ + "200": {Description: "OK", Content: map[string]map[string]any{fiber.MIMEApplicationJSON: {}}}, + }, + }, + }, + }, + } + exp, err := json.Marshal(expected) + require.NoError(t, err) + require.JSONEq(t, string(exp), string(body)) +} + +func Test_OpenAPI_RawJSON(t *testing.T) { + app := fiber.New() + + app.Get("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }). + Name("listUsers").Produces(fiber.MIMEApplicationJSON) + + app.Use(New()) + + req := httptest.NewRequest(fiber.MethodGet, "/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + rootOps := map[string]operation{} + for _, m := range app.Config().RequestMethods { + if m == fiber.MethodConnect { + continue + } + lower := strings.ToLower(m) + upper := strings.ToUpper(m) + op := operation{ + Summary: upper + " /", + Description: "", + Responses: map[string]response{ + "200": {Description: "OK", Content: map[string]map[string]any{fiber.MIMETextPlain: {}}}, + }, + } + switch m { + case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace: + default: + op.RequestBody = &requestBody{Content: map[string]map[string]any{fiber.MIMETextPlain: {}}} + } + rootOps[lower] = op + } + expected := openAPISpec{ + OpenAPI: "3.0.0", + Info: openAPIInfo{Title: "Fiber API", Version: "1.0.0"}, + Paths: map[string]map[string]operation{ + "/": rootOps, + "/users": { + "get": { + OperationID: "listUsers", + Summary: "GET /users", + Description: "", + Responses: map[string]response{ + "200": {Description: "OK", Content: map[string]map[string]any{fiber.MIMEApplicationJSON: {}}}, + }, + }, + }, + }, + } + exp, err := json.Marshal(expected) + require.NoError(t, err) + require.JSONEq(t, string(exp), string(body)) +} + +func Test_OpenAPI_RawJSONFile(t *testing.T) { + app := fiber.New() + + app.Get("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }). + Name("listUsers").Produces(fiber.MIMEApplicationJSON) + + app.Use(New()) + + req := httptest.NewRequest(fiber.MethodGet, "/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + expected, err := os.ReadFile("testdata/openapi.json") + require.NoError(t, err) + + require.JSONEq(t, string(expected), string(body)) +} + +func Test_OpenAPI_OperationConfig(t *testing.T) { + app := fiber.New() + app.Get("/users", func(c fiber.Ctx) error { return c.JSON(fiber.Map{"hello": "world"}) }) + + app.Use(New(Config{ + Operations: map[string]Operation{ + "GET /users": { + ID: "listUsersCustom", + Summary: "List users", + Description: "Returns all users", + Tags: []string{"users"}, + Deprecated: true, + Consumes: fiber.MIMEApplicationJSON, + Produces: fiber.MIMEApplicationJSON, + Parameters: []Parameter{{ + Name: "limit", + In: "query", + Required: true, + Description: "Maximum items", + Schema: map[string]any{"type": "integer"}, + }}, + RequestBody: &RequestBody{ + Description: "Custom payload", + Required: true, + Content: map[string]Media{ + fiber.MIMEApplicationJSON: {Schema: map[string]any{"type": "object"}}, + }, + }, + Responses: map[string]Response{ + "201": {Description: "Created", Content: map[string]Media{ + fiber.MIMEApplicationJSON: {Schema: map[string]any{"type": "object"}}, + }}, + }, + }, + }, + })) + + req := httptest.NewRequest(fiber.MethodGet, "/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + var spec openAPISpec + require.NoError(t, json.NewDecoder(resp.Body).Decode(&spec)) + + op := spec.Paths["/users"]["get"] + require.Equal(t, "listUsersCustom", op.OperationID) + require.Equal(t, "List users", op.Summary) + require.Equal(t, "Returns all users", op.Description) + require.ElementsMatch(t, []string{"users"}, op.Tags) + require.True(t, op.Deprecated) + require.Contains(t, op.Responses["200"].Content, fiber.MIMEApplicationJSON) + require.Contains(t, op.Responses, "201") + require.Contains(t, op.Responses["201"].Content, fiber.MIMEApplicationJSON) + require.NotNil(t, op.RequestBody) + require.Equal(t, "Custom payload", op.RequestBody.Description) + require.Contains(t, op.RequestBody.Content, fiber.MIMEApplicationJSON) + require.True(t, op.RequestBody.Required) + require.Len(t, op.Parameters, 1) + require.Equal(t, "limit", op.Parameters[0].Name) + require.Equal(t, "integer", op.Parameters[0].Schema["type"]) +} + +func Test_OpenAPI_RouteMetadata(t *testing.T) { + app := fiber.New() + app.Get("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }). + Summary("List users").Description("User list").Produces(fiber.MIMEApplicationJSON). + Parameter("trace-id", "header", true, nil, "Tracing identifier"). + Tags("users", "read").Deprecated() + + app.Use(New()) + + req := httptest.NewRequest(fiber.MethodGet, "/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + var spec openAPISpec + require.NoError(t, json.NewDecoder(resp.Body).Decode(&spec)) + + op := spec.Paths["/users"]["get"] + require.Equal(t, "List users", op.Summary) + require.Equal(t, "User list", op.Description) + require.Contains(t, op.Responses["200"].Content, fiber.MIMEApplicationJSON) + require.ElementsMatch(t, []string{"users", "read"}, op.Tags) + require.True(t, op.Deprecated) + require.Len(t, op.Parameters, 1) + require.Equal(t, "trace-id", op.Parameters[0].Name) + require.Equal(t, "header", op.Parameters[0].In) + require.Equal(t, "Tracing identifier", op.Parameters[0].Description) +} + +func Test_OpenAPI_RouteRequestBodyAndResponses(t *testing.T) { + app := fiber.New() + + app.Post("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusCreated) }). + RequestBody("Create user", true, fiber.MIMEApplicationJSON). + Response(fiber.StatusCreated, "Created", fiber.MIMEApplicationJSON) + + app.Use(New()) + + req := httptest.NewRequest(fiber.MethodGet, "/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + var spec openAPISpec + require.NoError(t, json.NewDecoder(resp.Body).Decode(&spec)) + + op := spec.Paths["/users"]["post"] + require.NotNil(t, op.RequestBody) + require.Equal(t, "Create user", op.RequestBody.Description) + require.True(t, op.RequestBody.Required) + require.Contains(t, op.RequestBody.Content, fiber.MIMEApplicationJSON) + require.Contains(t, op.Responses, "201") + require.Equal(t, "Created", op.Responses["201"].Description) + require.Contains(t, op.Responses["201"].Content, fiber.MIMEApplicationJSON) + require.Contains(t, op.Responses, "200") + require.Equal(t, "OK", op.Responses["200"].Description) +} + +// getPaths is a helper that mounts the middleware, performs the request and +// decodes the resulting OpenAPI specification paths. +func getPaths(t *testing.T, app *fiber.App) map[string]map[string]any { + t.Helper() + + app.Use(New()) + + req := httptest.NewRequest(fiber.MethodGet, "/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + var spec struct { + Paths map[string]map[string]any `json:"paths"` + } + require.NoError(t, json.NewDecoder(resp.Body).Decode(&spec)) + return spec.Paths +} + +func Test_OpenAPI_Methods(t *testing.T) { + handler := func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) } + + tests := []struct { + register func(*fiber.App) + method string + }{ + {func(a *fiber.App) { a.Get("/method", handler) }, fiber.MethodGet}, + {func(a *fiber.App) { a.Post("/method", handler) }, fiber.MethodPost}, + {func(a *fiber.App) { a.Put("/method", handler) }, fiber.MethodPut}, + {func(a *fiber.App) { a.Patch("/method", handler) }, fiber.MethodPatch}, + {func(a *fiber.App) { a.Delete("/method", handler) }, fiber.MethodDelete}, + {func(a *fiber.App) { a.Head("/method", handler) }, fiber.MethodHead}, + {func(a *fiber.App) { a.Options("/method", handler) }, fiber.MethodOptions}, + {func(a *fiber.App) { a.Trace("/method", handler) }, fiber.MethodTrace}, + } + + for _, tt := range tests { + t.Run(tt.method, func(t *testing.T) { + app := fiber.New() + tt.register(app) + + paths := getPaths(t, app) + require.Contains(t, paths, "/method") + ops := paths["/method"] + require.Contains(t, ops, strings.ToLower(tt.method)) + }) + } +} + +func Test_OpenAPI_DifferentHandlers(t *testing.T) { + app := fiber.New() + + app.Get("/string", func(c fiber.Ctx) error { return c.SendString("a") }) + app.Get("/json", func(c fiber.Ctx) error { return c.JSON(fiber.Map{"hello": "world"}) }) + + paths := getPaths(t, app) + + require.Contains(t, paths, "/string") + require.Contains(t, paths["/string"], "get") + require.Contains(t, paths, "/json") + require.Contains(t, paths["/json"], "get") +} + +func Test_OpenAPI_Params(t *testing.T) { + app := fiber.New() + + app.Get("/users/:id", func(c fiber.Ctx) error { return c.SendString(c.Params("id")) }). + Parameter("id", "path", true, map[string]any{"type": "integer"}, "identifier") + + paths := getPaths(t, app) + require.Contains(t, paths, "/users/{id}") + require.Contains(t, paths["/users/{id}"], "get") + op := requireMap(t, paths["/users/{id}"]["get"]) + params := requireSlice(t, op["parameters"]) + require.Len(t, params, 1) + p0 := requireMap(t, params[0]) + require.Equal(t, "id", p0["name"]) + require.Equal(t, "path", p0["in"]) + require.Equal(t, "identifier", p0["description"]) + schema := requireMap(t, p0["schema"]) + require.Equal(t, "integer", schema["type"]) +} + +func Test_OpenAPI_Groups(t *testing.T) { + app := fiber.New() + + api := app.Group("/api") + api.Get("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + api.Post("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusCreated) }) + + paths := getPaths(t, app) + + require.Contains(t, paths, "/api/users") + ops := paths["/api/users"] + require.Contains(t, ops, "get") + require.Contains(t, ops, "post") +} + +func Test_OpenAPI_Groups_Metadata(t *testing.T) { + app := fiber.New() + + api := app.Group("/api") + api.Get("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }). + Summary("List users").Description("Group users").Produces(fiber.MIMEApplicationJSON). + Tags("users").Deprecated() + + paths := getPaths(t, app) + + require.Contains(t, paths, "/api/users") + op := requireMap(t, paths["/api/users"]["get"]) + require.Equal(t, "List users", op["summary"]) + require.Equal(t, "Group users", op["description"]) + require.ElementsMatch(t, []any{"users"}, requireSlice(t, op["tags"])) + require.Equal(t, true, op["deprecated"]) + resp := requireMap(t, op["responses"]) + cont := requireMap(t, requireMap(t, resp["200"])["content"]) + require.Contains(t, cont, fiber.MIMEApplicationJSON) +} + +func Test_OpenAPI_NoRoutes(t *testing.T) { + app := fiber.New() + + paths := getPaths(t, app) + + require.Len(t, paths, 1) + require.Contains(t, paths, "/") +} + +func Test_OpenAPI_RootOnly(t *testing.T) { + app := fiber.New() + + app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + paths := getPaths(t, app) + + require.Contains(t, paths, "/") + require.Contains(t, paths["/"], "get") +} + +func Test_OpenAPI_GroupMiddleware(t *testing.T) { + app := fiber.New() + + api := app.Group("/api/v2") + api.Get("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + api.Use(New()) + + req := httptest.NewRequest(fiber.MethodGet, "/api/v2/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + var spec openAPISpec + require.NoError(t, json.NewDecoder(resp.Body).Decode(&spec)) + require.Contains(t, spec.Paths, "/api/v2/users") +} + +func Test_OpenAPI_DoesNotInterceptSimilarPaths(t *testing.T) { + app := fiber.New() + + app.Use(New()) + app.Get("/reports/openapi.json", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusAccepted) }) + + req := httptest.NewRequest(fiber.MethodGet, "/reports/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusAccepted, resp.StatusCode) +} + +func Test_OpenAPI_RootAndGroupSpecs(t *testing.T) { + app := fiber.New() + + app.Use(New(Config{Title: "root"})) + + v1 := app.Group("/v1") + v1.Use(New(Config{Title: "group"})) + + req := httptest.NewRequest(fiber.MethodGet, "/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + var spec openAPISpec + require.NoError(t, json.NewDecoder(resp.Body).Decode(&spec)) + require.Equal(t, "root", spec.Info.Title) + + req = httptest.NewRequest(fiber.MethodGet, "/v1/openapi.json", nil) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + require.NoError(t, json.NewDecoder(resp.Body).Decode(&spec)) + require.Equal(t, "group", spec.Info.Title) +} + +func Test_OpenAPI_ConfigValues(t *testing.T) { + app := fiber.New() + + app.Get("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + cfg := Config{ + Title: "Custom API", + Version: "2.1.0", + Description: "My description", + ServerURL: "https://example.com", + Path: "/spec.json", + } + app.Use(New(cfg)) + + req := httptest.NewRequest(fiber.MethodGet, "/spec.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + var spec openAPISpec + require.NoError(t, json.NewDecoder(resp.Body).Decode(&spec)) + require.Equal(t, cfg.Title, spec.Info.Title) + require.Equal(t, cfg.Version, spec.Info.Version) + require.Equal(t, cfg.Description, spec.Info.Description) + require.Len(t, spec.Servers, 1) + require.Equal(t, cfg.ServerURL, spec.Servers[0].URL) +} + +func Test_OpenAPI_Next(t *testing.T) { + app := fiber.New() + + app.Use(New(Config{Next: func(fiber.Ctx) bool { return true }})) + + req := httptest.NewRequest(fiber.MethodGet, "/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusNotFound, resp.StatusCode) +} + +func Test_OpenAPI_ConnectIgnored(t *testing.T) { + app := fiber.New() + + app.Connect("/conn", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + paths := getPaths(t, app) + require.NotContains(t, paths, "/conn") +} + +func Test_OpenAPI_MultipleParams(t *testing.T) { + app := fiber.New() + + app.Get("/users/:uid/books/:bid", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + paths := getPaths(t, app) + require.Contains(t, paths, "/users/{uid}/books/{bid}") + op := requireMap(t, paths["/users/{uid}/books/{bid}"]["get"]) + params := requireSlice(t, op["parameters"]) + require.Len(t, params, 2) + p0 := requireMap(t, params[0]) + p1 := requireMap(t, params[1]) + require.Equal(t, "uid", p0["name"]) + require.Equal(t, "path", p0["in"]) + require.Equal(t, "bid", p1["name"]) + require.Equal(t, "path", p1["in"]) +} + +func Test_OpenAPI_ConsumesProduces(t *testing.T) { + app := fiber.New() + + app.Post("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusCreated) }). + Consumes(fiber.MIMEApplicationJSON). + Produces(fiber.MIMEApplicationXML) + + paths := getPaths(t, app) + + op := requireMap(t, paths["/users"]["post"]) + rb := requireMap(t, op["requestBody"]) + reqContent := requireMap(t, rb["content"]) + require.Contains(t, reqContent, fiber.MIMEApplicationJSON) + + resp := requireMap(t, requireMap(t, op["responses"])["200"]) + cont := requireMap(t, resp["content"]) + require.Contains(t, cont, fiber.MIMEApplicationXML) +} + +func Test_OpenAPI_NoRequestBodyForGET(t *testing.T) { + app := fiber.New() + + app.Get("/users", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + paths := getPaths(t, app) + op := requireMap(t, paths["/users"]["get"]) + require.NotContains(t, op, "requestBody") +} + +func Test_OpenAPI_Cache(t *testing.T) { + app := fiber.New() + + app.Get("/first", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + app.Use(New()) + + req := httptest.NewRequest(fiber.MethodGet, "/openapi.json", nil) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + var spec openAPISpec + require.NoError(t, json.NewDecoder(resp.Body).Decode(&spec)) + require.Contains(t, spec.Paths, "/first") + + app.Get("/second", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) + + req = httptest.NewRequest(fiber.MethodGet, "/openapi.json", nil) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + require.NoError(t, json.NewDecoder(resp.Body).Decode(&spec)) + require.NotContains(t, spec.Paths, "/second") +} + +func requireMap(t *testing.T, value any) map[string]any { + t.Helper() + m, ok := value.(map[string]any) + require.True(t, ok) + return m +} + +func requireSlice(t *testing.T, value any) []any { + t.Helper() + s, ok := value.([]any) + require.True(t, ok) + return s +} diff --git a/middleware/openapi/testdata/openapi.json b/middleware/openapi/testdata/openapi.json new file mode 100644 index 00000000000..6a549b76ea5 --- /dev/null +++ b/middleware/openapi/testdata/openapi.json @@ -0,0 +1 @@ +{"openapi":"3.0.0","info":{"title":"Fiber API","version":"1.0.0"},"paths":{"/":{"delete":{"summary":"DELETE /","description":"","requestBody":{"content":{"text/plain":{}}},"responses":{"200":{"description":"OK","content":{"text/plain":{}}}}},"get":{"summary":"GET /","description":"","responses":{"200":{"description":"OK","content":{"text/plain":{}}}}},"head":{"summary":"HEAD /","description":"","responses":{"200":{"description":"OK","content":{"text/plain":{}}}}},"options":{"summary":"OPTIONS /","description":"","responses":{"200":{"description":"OK","content":{"text/plain":{}}}}},"patch":{"summary":"PATCH /","description":"","requestBody":{"content":{"text/plain":{}}},"responses":{"200":{"description":"OK","content":{"text/plain":{}}}}},"post":{"summary":"POST /","description":"","requestBody":{"content":{"text/plain":{}}},"responses":{"200":{"description":"OK","content":{"text/plain":{}}}}},"put":{"summary":"PUT /","description":"","requestBody":{"content":{"text/plain":{}}},"responses":{"200":{"description":"OK","content":{"text/plain":{}}}}},"trace":{"summary":"TRACE /","description":"","responses":{"200":{"description":"OK","content":{"text/plain":{}}}}}},"/users":{"get":{"operationId":"listUsers","summary":"GET /users","description":"","responses":{"200":{"description":"OK","content":{"application/json":{}}}}}}}} \ No newline at end of file diff --git a/router.go b/router.go index 6ef36ef093e..4776e7ae6e8 100644 --- a/router.go +++ b/router.go @@ -7,6 +7,7 @@ package fiber import ( "bytes" "fmt" + "maps" "slices" "sync/atomic" @@ -36,6 +37,30 @@ type Router interface { Route(path string) Register Name(name string) Router + // Summary sets a short summary for the most recently registered route. + Summary(sum string) Router + // Description sets a human-readable description for the most recently + // registered route. + Description(desc string) Router + // Consumes sets the request media type for the most recently + // registered route. + Consumes(typ string) Router + // Produces sets the response media type for the most recently + // registered route. + Produces(typ string) Router + // RequestBody documents the request body for the most recently + // registered route. + RequestBody(description string, required bool, mediaTypes ...string) Router + // Parameter documents an input parameter for the most recently + // registered route. + Parameter(name, in string, required bool, schema map[string]any, description string) Router + // Response documents an HTTP response for the most recently + // registered route. + Response(status int, description string, mediaTypes ...string) Router + // Tags sets the tags for the most recently registered route. + Tags(tags ...string) Router + // Deprecated marks the most recently registered route as deprecated. + Deprecated() Router } // Route is a struct that holds all metadata for each registered handler. @@ -43,16 +68,27 @@ type Route struct { // ### important: always keep in sync with the copy method "app.copyRoute" and all creations of Route struct ### group *Group // Group instance. used for routes in groups + routeParser routeParser // Parameter parser + + Handlers []Handler `json:"-"` // Ctx handlers + Parameters []RouteParameter `json:"parameters"` + Responses map[string]RouteResponse `json:"responses"` + RequestBody *RouteRequestBody `json:"requestBody"` //nolint:tagliatelle + Tags []string `json:"tags"` + Params []string `json:"params"` // Case-sensitive param keys + path string // Prettified path // Public fields Method string `json:"method"` // HTTP method Name string `json:"name"` // Route's name //nolint:revive // Having both a Path (uppercase) and a path (lowercase) is fine - Path string `json:"path"` // Original registered route path - Params []string `json:"params"` // Case-sensitive param keys - Handlers []Handler `json:"-"` // Ctx handlers - routeParser routeParser // Parameter parser + Path string `json:"path"` // Original registered route path + Summary string `json:"summary"` + Description string `json:"description"` + Consumes string `json:"consumes"` + Produces string `json:"produces"` + Deprecated bool `json:"deprecated"` // Data for routing use bool // USE matches path prefixes mount bool // Indicated a mounted app on a specific route @@ -60,6 +96,28 @@ type Route struct { root bool // Path equals '/' } +// RouteParameter describes an input captured by a route. +type RouteParameter struct { + Schema map[string]any `json:"schema"` + Description string `json:"description"` + Name string `json:"name"` + In string `json:"in"` + Required bool `json:"required"` +} + +// RouteResponse describes a response emitted by a route. +type RouteResponse struct { + MediaTypes []string `json:"mediaTypes"` //nolint:tagliatelle + Description string `json:"description"` +} + +// RouteRequestBody describes the request payload accepted by a route. +type RouteRequestBody struct { + MediaTypes []string `json:"mediaTypes"` //nolint:tagliatelle + Description string `json:"description"` + Required bool `json:"required"` +} + func (r *Route) match(detectionPath, path string, params *[maxParams]string) bool { // root detectionPath check if r.root && len(detectionPath) == 1 && detectionPath[0] == '/' { @@ -373,12 +431,71 @@ func (*App) copyRoute(route *Route) *Route { routeParser: route.routeParser, // Public data - Path: route.Path, - Params: route.Params, - Name: route.Name, - Method: route.Method, - Handlers: route.Handlers, + Path: route.Path, + Params: route.Params, + Name: route.Name, + Method: route.Method, + Handlers: route.Handlers, + Summary: route.Summary, + Description: route.Description, + Consumes: route.Consumes, + Produces: route.Produces, + RequestBody: cloneRouteRequestBody(route.RequestBody), + Parameters: cloneRouteParameters(route.Parameters), + Responses: cloneRouteResponses(route.Responses), + Tags: route.Tags, + Deprecated: route.Deprecated, + } +} + +func cloneRouteRequestBody(body *RouteRequestBody) *RouteRequestBody { + if body == nil { + return nil + } + clone := &RouteRequestBody{ + Description: body.Description, + Required: body.Required, + } + if len(body.MediaTypes) > 0 { + clone.MediaTypes = append([]string(nil), body.MediaTypes...) + } + return clone +} + +func cloneRouteParameters(params []RouteParameter) []RouteParameter { + if len(params) == 0 { + return nil + } + cloned := make([]RouteParameter, len(params)) + for i, p := range params { + cloned[i] = RouteParameter{ + Name: p.Name, + In: p.In, + Required: p.Required, + Description: p.Description, + } + if len(p.Schema) > 0 { + schemaCopy := make(map[string]any, len(p.Schema)) + maps.Copy(schemaCopy, p.Schema) + cloned[i].Schema = schemaCopy + } + } + return cloned +} + +func cloneRouteResponses(responses map[string]RouteResponse) map[string]RouteResponse { + if len(responses) == 0 { + return nil + } + cloned := make(map[string]RouteResponse, len(responses)) + for code, resp := range responses { + copyResp := RouteResponse{Description: resp.Description} + if len(resp.MediaTypes) > 0 { + copyResp.MediaTypes = append([]string(nil), resp.MediaTypes...) + } + cloned[code] = copyResp } + return cloned } func (app *App) normalizePath(path string) string { @@ -521,9 +638,13 @@ func (app *App) register(methods []string, pathRaw string, group *Group, handler Params: parsedRaw.params, group: group, - Path: pathRaw, - Method: method, - Handlers: handlers, + Path: pathRaw, + Method: method, + Handlers: handlers, + Summary: "", + Description: "", + Consumes: MIMETextPlain, + Produces: MIMETextPlain, } // Increment global handler count diff --git a/router_test.go b/router_test.go index 74591563ddd..58a0a1b4b28 100644 --- a/router_test.go +++ b/router_test.go @@ -1421,3 +1421,134 @@ func Test_AddRoute_MergeHandlers(t *testing.T) { require.Len(t, app.stack[app.methodInt(MethodGet)], 1) require.Len(t, app.stack[app.methodInt(MethodGet)][0].Handlers, 2) } + +func Test_Route_InvalidMediaType(t *testing.T) { + t.Run("produces", func(t *testing.T) { + app := New() + require.Panics(t, func() { + app.Get("/", testEmptyHandler).Produces("invalid") + }) + }) + t.Run("consumes", func(t *testing.T) { + app := New() + require.Panics(t, func() { + app.Get("/", testEmptyHandler).Consumes("invalid") + }) + }) + t.Run("request body", func(t *testing.T) { + app := New() + require.Panics(t, func() { + app.Post("/", testEmptyHandler).RequestBody("payload", true, "invalid") + }) + }) + t.Run("request body missing type", func(t *testing.T) { + app := New() + require.Panics(t, func() { + app.Post("/", testEmptyHandler).RequestBody("payload", true) + }) + }) + t.Run("response", func(t *testing.T) { + app := New() + require.Panics(t, func() { + app.Get("/", testEmptyHandler).Response(StatusOK, "", "invalid") + }) + }) + t.Run("parameter", func(t *testing.T) { + app := New() + require.Panics(t, func() { + app.Get("/", testEmptyHandler).Parameter("foo", "body", true, nil, "") + }) + }) +} + +func Test_App_Produces(t *testing.T) { + t.Parallel() + app := New() + app.Get("/", testEmptyHandler).Produces(MIMEApplicationJSON) + route := app.stack[app.methodInt(MethodGet)][0] + //nolint:testifylint // MIMEApplicationJSON is a plain string, JSONEq not required + require.Equal(t, MIMEApplicationJSON, route.Produces) +} + +func Test_App_RequestBody(t *testing.T) { + t.Parallel() + app := New() + app.Post("/users", testEmptyHandler). + RequestBody("User payload", true, MIMEApplicationJSON, MIMEApplicationXML) + + route := app.stack[app.methodInt(MethodPost)][0] + require.NotNil(t, route.RequestBody) + require.Equal(t, "User payload", route.RequestBody.Description) + require.True(t, route.RequestBody.Required) + require.Equal(t, []string{MIMEApplicationJSON, MIMEApplicationXML}, route.RequestBody.MediaTypes) + //nolint:testifylint // MIMEApplicationJSON is a plain string, JSONEq not required + require.Equal(t, MIMEApplicationJSON, route.Consumes) +} + +func Test_App_Parameter(t *testing.T) { + t.Parallel() + app := New() + app.Get("/:id", testEmptyHandler). + Parameter("id", "path", false, map[string]any{"type": "integer"}, "identifier"). + Parameter("filter", "query", true, nil, "Filter results") + + route := app.stack[app.methodInt(MethodGet)][0] + require.Len(t, route.Parameters, 2) + + pathParam := route.Parameters[0] + require.Equal(t, "id", pathParam.Name) + require.Equal(t, "path", pathParam.In) + require.True(t, pathParam.Required) + require.Equal(t, "integer", pathParam.Schema["type"]) + require.Equal(t, "identifier", pathParam.Description) + + queryParam := route.Parameters[1] + require.Equal(t, "filter", queryParam.Name) + require.Equal(t, "query", queryParam.In) + require.True(t, queryParam.Required) + require.Equal(t, "string", queryParam.Schema["type"]) + require.Equal(t, "Filter results", queryParam.Description) +} + +func Test_App_Response(t *testing.T) { + t.Parallel() + app := New() + app.Get("/", testEmptyHandler). + Response(StatusOK, "OK", MIMEApplicationJSON). + Response(StatusCreated, "Created", MIMEApplicationJSON). + Response(0, "Default fallback") + + route := app.stack[app.methodInt(MethodGet)][0] + //nolint:testifylint // MIMEApplicationJSON is a plain string, JSONEq not required + require.Equal(t, MIMEApplicationJSON, route.Produces) + require.Len(t, route.Responses, 3) + + okResp, ok := route.Responses["200"] + require.True(t, ok) + require.Equal(t, "OK", okResp.Description) + require.Equal(t, []string{MIMEApplicationJSON}, okResp.MediaTypes) + + created, ok := route.Responses["201"] + require.True(t, ok) + require.Equal(t, "Created", created.Description) + + defResp, ok := route.Responses["default"] + require.True(t, ok) + require.Equal(t, "Default fallback", defResp.Description) +} + +func Test_App_Response_InvalidStatus(t *testing.T) { + t.Parallel() + app := New() + require.Panics(t, func() { + app.Get("/", testEmptyHandler).Response(42, "invalid") + }) +} + +func Test_App_Deprecated(t *testing.T) { + t.Parallel() + app := New() + app.Get("/", testEmptyHandler).Deprecated() + route := app.stack[app.methodInt(MethodGet)][0] + require.True(t, route.Deprecated) +}