Skip to content

Commit 9edf9ea

Browse files
committed
refactor(api): extract middleware logic into a separate package
1 parent 26ed81e commit 9edf9ea

File tree

5 files changed

+152
-123
lines changed

5 files changed

+152
-123
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package middleware
2+
3+
import (
4+
"errors"
5+
"net/http"
6+
"strings"
7+
8+
"github.com/gin-gonic/gin"
9+
"go.opentelemetry.io/otel/attribute"
10+
"go.opentelemetry.io/otel/trace"
11+
"gorm.io/gorm"
12+
13+
"github.com/memodb-io/Acontext/internal/config"
14+
"github.com/memodb-io/Acontext/internal/modules/model"
15+
"github.com/memodb-io/Acontext/internal/modules/serializer"
16+
"github.com/memodb-io/Acontext/internal/pkg/utils/secrets"
17+
"github.com/memodb-io/Acontext/internal/pkg/utils/tokens"
18+
)
19+
20+
// ProjectAuth returns a middleware that authenticates requests using project bearer tokens.
21+
// It validates the token, looks up the project in the database, and sets the project in the context.
22+
// It also sets the project_id attribute on the current span for telemetry filtering.
23+
func ProjectAuth(cfg *config.Config, db *gorm.DB) gin.HandlerFunc {
24+
return func(c *gin.Context) {
25+
auth := c.GetHeader("Authorization")
26+
if !strings.HasPrefix(auth, "Bearer ") {
27+
c.AbortWithStatusJSON(http.StatusUnauthorized, serializer.AuthErr("Unauthorized"))
28+
return
29+
}
30+
raw := strings.TrimPrefix(auth, "Bearer ")
31+
32+
secret, ok := tokens.ParseToken(raw, cfg.Root.ProjectBearerTokenPrefix)
33+
if !ok {
34+
c.AbortWithStatusJSON(http.StatusUnauthorized, serializer.AuthErr("Unauthorized"))
35+
return
36+
}
37+
38+
lookup := tokens.HMAC256Hex(cfg.Root.SecretPepper, secret)
39+
40+
var project model.Project
41+
if err := db.WithContext(c.Request.Context()).Where(&model.Project{SecretKeyHMAC: lookup}).First(&project).Error; err != nil {
42+
if errors.Is(err, gorm.ErrRecordNotFound) {
43+
c.AbortWithStatusJSON(http.StatusUnauthorized, serializer.AuthErr("Unauthorized"))
44+
return
45+
}
46+
c.AbortWithStatusJSON(http.StatusInternalServerError, serializer.DBErr("", err))
47+
return
48+
}
49+
50+
pass, err := secrets.VerifySecret(secret, cfg.Root.SecretPepper, project.SecretKeyHashPHC)
51+
if err != nil || !pass {
52+
c.AbortWithStatusJSON(http.StatusUnauthorized, serializer.AuthErr("Unauthorized"))
53+
return
54+
}
55+
56+
// Set project_id attribute on the current span for telemetry filtering
57+
span := trace.SpanFromContext(c.Request.Context())
58+
if span.SpanContext().IsValid() {
59+
span.SetAttributes(attribute.String("project_id", project.ID.String()))
60+
}
61+
62+
c.Set("project", &project)
63+
c.Next()
64+
}
65+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package middleware
2+
3+
import (
4+
"strings"
5+
"time"
6+
7+
"github.com/gin-gonic/gin"
8+
"go.uber.org/zap"
9+
)
10+
11+
// ZapLogger returns a middleware that logs HTTP requests using zap logger.
12+
// It logs API paths (/api/*) at info level and other paths at debug level.
13+
func ZapLogger(log *zap.Logger) gin.HandlerFunc {
14+
return func(c *gin.Context) {
15+
start := time.Now()
16+
c.Next()
17+
dur := time.Since(start)
18+
19+
// Use debug level for all paths except /api/*
20+
path := c.Request.URL.Path
21+
isAPIPath := strings.HasPrefix(path, "/api/")
22+
23+
if isAPIPath {
24+
log.Sugar().Infow("HTTP",
25+
"method", c.Request.Method,
26+
"path", path,
27+
"status", c.Writer.Status(),
28+
"latency", dur.String(),
29+
"clientIP", c.ClientIP(),
30+
)
31+
} else {
32+
log.Sugar().Debugw("HTTP",
33+
"method", c.Request.Method,
34+
"path", path,
35+
"status", c.Writer.Status(),
36+
"latency", dur.String(),
37+
"clientIP", c.ClientIP(),
38+
)
39+
}
40+
}
41+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package middleware
2+
3+
import (
4+
"strings"
5+
6+
"github.com/gin-gonic/gin"
7+
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
8+
"go.opentelemetry.io/otel/trace"
9+
)
10+
11+
// OtelTracing returns a middleware for OpenTelemetry instrumentation.
12+
// It only traces requests that match /api/ paths to reduce overhead.
13+
func OtelTracing(serviceName string) gin.HandlerFunc {
14+
otelMiddleware := otelgin.Middleware(serviceName)
15+
16+
return func(c *gin.Context) {
17+
// Only instrument requests that start with /api/
18+
path := c.Request.URL.Path
19+
if strings.HasPrefix(path, "/api/") {
20+
otelMiddleware(c)
21+
} else {
22+
// Skip OpenTelemetry instrumentation for non-API paths
23+
c.Next()
24+
}
25+
}
26+
}
27+
28+
// TraceID returns a middleware that adds trace ID to response headers.
29+
// This is useful for correlating logs and traces in distributed systems.
30+
func TraceID() gin.HandlerFunc {
31+
return func(c *gin.Context) {
32+
// Get current span from context
33+
span := trace.SpanFromContext(c.Request.Context())
34+
if span.SpanContext().IsValid() {
35+
// Add trace ID to response header
36+
traceID := span.SpanContext().TraceID().String()
37+
c.Header("X-Trace-Id", traceID)
38+
}
39+
c.Next()
40+
}
41+
}

src/server/api/go/internal/router/router.go

Lines changed: 5 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,21 @@
11
package router
22

33
import (
4-
"errors"
54
"net/http"
6-
"strings"
7-
"time"
85

96
"github.com/gin-gonic/gin"
10-
"go.opentelemetry.io/otel/attribute"
11-
"go.opentelemetry.io/otel/trace"
127
"go.uber.org/zap"
138
"gorm.io/gorm"
149

1510
_ "github.com/memodb-io/Acontext/docs"
1611
"github.com/memodb-io/Acontext/internal/config"
12+
"github.com/memodb-io/Acontext/internal/middleware"
1713
"github.com/memodb-io/Acontext/internal/modules/handler"
18-
"github.com/memodb-io/Acontext/internal/modules/model"
1914
"github.com/memodb-io/Acontext/internal/modules/serializer"
20-
"github.com/memodb-io/Acontext/internal/pkg/utils/secrets"
21-
"github.com/memodb-io/Acontext/internal/pkg/utils/tokens"
22-
"github.com/memodb-io/Acontext/internal/telemetry"
2315
swaggerFiles "github.com/swaggo/files"
2416
ginSwagger "github.com/swaggo/gin-swagger"
2517
)
2618

27-
// zapLoggerMiddleware
28-
func zapLoggerMiddleware(log *zap.Logger) gin.HandlerFunc {
29-
return func(c *gin.Context) {
30-
start := time.Now()
31-
c.Next()
32-
dur := time.Since(start)
33-
34-
// Use debug level for all paths except /api/*
35-
path := c.Request.URL.Path
36-
isAPIPath := strings.HasPrefix(path, "/api/")
37-
38-
if isAPIPath {
39-
log.Sugar().Infow("HTTP",
40-
"method", c.Request.Method,
41-
"path", path,
42-
"status", c.Writer.Status(),
43-
"latency", dur.String(),
44-
"clientIP", c.ClientIP(),
45-
)
46-
} else {
47-
log.Sugar().Debugw("HTTP",
48-
"method", c.Request.Method,
49-
"path", path,
50-
"status", c.Writer.Status(),
51-
"latency", dur.String(),
52-
"clientIP", c.ClientIP(),
53-
)
54-
}
55-
}
56-
}
57-
58-
// projectAuthMiddleware
59-
func projectAuthMiddleware(cfg *config.Config, db *gorm.DB) gin.HandlerFunc {
60-
return func(c *gin.Context) {
61-
auth := c.GetHeader("Authorization")
62-
if !strings.HasPrefix(auth, "Bearer ") {
63-
c.AbortWithStatusJSON(http.StatusUnauthorized, serializer.AuthErr("Unauthorized"))
64-
return
65-
}
66-
raw := strings.TrimPrefix(auth, "Bearer ")
67-
68-
secret, ok := tokens.ParseToken(raw, cfg.Root.ProjectBearerTokenPrefix)
69-
if !ok {
70-
c.AbortWithStatusJSON(http.StatusUnauthorized, serializer.AuthErr("Unauthorized"))
71-
return
72-
}
73-
74-
lookup := tokens.HMAC256Hex(cfg.Root.SecretPepper, secret)
75-
76-
var project model.Project
77-
if err := db.WithContext(c.Request.Context()).Where(&model.Project{SecretKeyHMAC: lookup}).First(&project).Error; err != nil {
78-
if errors.Is(err, gorm.ErrRecordNotFound) {
79-
c.AbortWithStatusJSON(http.StatusUnauthorized, serializer.AuthErr("Unauthorized"))
80-
return
81-
}
82-
c.AbortWithStatusJSON(http.StatusInternalServerError, serializer.DBErr("", err))
83-
return
84-
}
85-
86-
pass, err := secrets.VerifySecret(secret, cfg.Root.SecretPepper, project.SecretKeyHashPHC)
87-
if err != nil || !pass {
88-
c.AbortWithStatusJSON(http.StatusUnauthorized, serializer.AuthErr("Unauthorized"))
89-
return
90-
}
91-
92-
// Set project_id attribute on the current span for telemetry filtering
93-
span := trace.SpanFromContext(c.Request.Context())
94-
if span.SpanContext().IsValid() {
95-
span.SetAttributes(attribute.String("project_id", project.ID.String()))
96-
}
97-
98-
c.Set("project", &project)
99-
c.Next()
100-
}
101-
}
102-
10319
type RouterDeps struct {
10420
Config *config.Config
10521
DB *gorm.DB
@@ -122,12 +38,12 @@ func NewRouter(d RouterDeps) *gin.Engine {
12238

12339
// Add OpenTelemetry middleware if enabled (using configuration system)
12440
if d.Config.Telemetry.Enabled && d.Config.Telemetry.OtlpEndpoint != "" {
125-
r.Use(telemetry.GinMiddleware(d.Config.App.Name))
41+
r.Use(middleware.OtelTracing(d.Config.App.Name))
12642
// Add trace ID to response header
127-
r.Use(telemetry.TraceIDMiddleware())
43+
r.Use(middleware.TraceID())
12844
}
12945

130-
r.Use(zapLoggerMiddleware(d.Log))
46+
r.Use(middleware.ZapLogger(d.Log))
13147

13248
// health
13349
r.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, serializer.Response{Msg: "ok"}) })
@@ -140,7 +56,7 @@ func NewRouter(d RouterDeps) *gin.Engine {
14056

14157
v1 := r.Group("/api/v1")
14258
{
143-
v1.Use(projectAuthMiddleware(d.Config, d.DB))
59+
v1.Use(middleware.ProjectAuth(d.Config, d.DB))
14460

14561
// ping endpoint
14662
v1.GET("/ping", func(c *gin.Context) { c.JSON(http.StatusOK, serializer.Response{Msg: "pong"}) })

src/server/api/go/internal/telemetry/otel.go

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,14 @@ import (
66
"strings"
77
"time"
88

9-
"github.com/gin-gonic/gin"
109
"github.com/memodb-io/Acontext/internal/config"
11-
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
1210
"go.opentelemetry.io/otel"
1311
"go.opentelemetry.io/otel/attribute"
1412
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
1513
"go.opentelemetry.io/otel/propagation"
1614
"go.opentelemetry.io/otel/sdk/resource"
1715
sdktrace "go.opentelemetry.io/otel/sdk/trace"
1816
semconv "go.opentelemetry.io/otel/semconv/v1.24.0"
19-
"go.opentelemetry.io/otel/trace"
2017
)
2118

2219
var (
@@ -105,34 +102,3 @@ func Shutdown(ctx context.Context) error {
105102
}
106103
return nil
107104
}
108-
109-
// GinMiddleware returns Gin middleware for OpenTelemetry instrumentation
110-
// Only traces requests that match /api/ paths
111-
func GinMiddleware(serviceName string) gin.HandlerFunc {
112-
otelMiddleware := otelgin.Middleware(serviceName)
113-
114-
return func(c *gin.Context) {
115-
// Only instrument requests that start with /api/
116-
path := c.Request.URL.Path
117-
if strings.HasPrefix(path, "/api/") {
118-
otelMiddleware(c)
119-
} else {
120-
// Skip OpenTelemetry instrumentation for non-API paths
121-
c.Next()
122-
}
123-
}
124-
}
125-
126-
// TraceIDMiddleware returns a Gin middleware that adds trace ID to response headers
127-
func TraceIDMiddleware() gin.HandlerFunc {
128-
return func(c *gin.Context) {
129-
// Get current span from context
130-
span := trace.SpanFromContext(c.Request.Context())
131-
if span.SpanContext().IsValid() {
132-
// Add trace ID to response header
133-
traceID := span.SpanContext().TraceID().String()
134-
c.Header("X-Trace-Id", traceID)
135-
}
136-
c.Next()
137-
}
138-
}

0 commit comments

Comments
 (0)