-
Notifications
You must be signed in to change notification settings - Fork 2
/
proxy.go
129 lines (109 loc) · 2.43 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package parapet
import (
"net"
"net/http"
"strings"
)
// TrustCIDRs returns a Conditional that reports whether a request's direct
// peer (the host part of r.RemoteAddr) falls inside one of the networks
// listed in s. Entries of s that cannot be parsed are ignored; when none
// remain, the returned Conditional rejects every request.
func TrustCIDRs(s []string) Conditional {
	networks := parseCIDRs(s)
	if len(networks) == 0 {
		return func(*http.Request) bool { return false }
	}

	return func(r *http.Request) bool {
		peer := net.ParseIP(parseHost(r.RemoteAddr))
		if peer == nil {
			// RemoteAddr did not contain a usable IP address.
			return false
		}
		for _, n := range networks {
			if n.Contains(peer) {
				return true
			}
		}
		return false
	}
}
// Trusted returns a Conditional that reports true for every request,
// regardless of where it came from.
func Trusted() Conditional {
	return func(*http.Request) bool { return true }
}
// Request headers that proxy reads and rewrites. The values are spelled in
// the canonical form produced by http.CanonicalHeaderKey.
const (
	// headerXRealIP carries the single client address chosen for the request.
	headerXRealIP = "X-Real-Ip"
	// headerXForwardedFor carries a comma-separated list of addresses.
	headerXForwardedFor = "X-Forwarded-For"
	// headerXForwardedProto carries the scheme, "http" or "https".
	headerXForwardedProto = "X-Forwarded-Proto"
)
// proxy normalizes the X-Forwarded-For, X-Real-Ip and X-Forwarded-Proto
// request headers and then passes the request to Handler. *proxy implements
// http.Handler via its ServeHTTP method.
//
//nolint:govet
type proxy struct {
	// Trust decides, per request, whether the direct peer is trusted. For a
	// trusted peer, forwarding headers it sent are kept and only missing ones
	// are filled in; otherwise they are overwritten from the connection.
	// A nil Trust trusts no one.
	Trust func(r *http.Request) bool
	// ComputeFullForwardedFor, for a trusted peer, appends that peer's IP
	// to X-Forwarded-For.
	ComputeFullForwardedFor bool
	// Handler receives the request after its headers have been rewritten.
	Handler http.Handler
}
// ServeHTTP hands the request to trust when a Trust predicate is configured
// and accepts it; in every other case, including a nil Trust, it uses
// distrust.
func (m *proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if m.Trust != nil && m.Trust(r) {
		m.trust(w, r)
		return
	}
	m.distrust(w, r)
}
// trust handles a request whose direct peer passed the Trust check.
// Forwarding headers the peer sent are kept; only missing ones are filled in:
//
//   - If ComputeFullForwardedFor is set, the peer's own IP is appended to
//     X-Forwarded-For. Repeated X-Forwarded-For header lines are first merged
//     into one comma-separated value so that no earlier entry is dropped.
//   - X-Real-Ip defaults to the first X-Forwarded-For entry, or to the peer's
//     own IP when X-Forwarded-For is empty.
//   - X-Forwarded-Proto defaults to "https" for TLS connections, else "http".
//
// The request is then passed to m.Handler.
func (m *proxy) trust(w http.ResponseWriter, r *http.Request) {
	remoteIP := parseHost(r.RemoteAddr)

	// TODO: handle compute full forwarded for from server
	// Skip when the peer address could not be parsed rather than appending an
	// empty entry.
	if m.ComputeFullForwardedFor && remoteIP != "" {
		// Header.Get returns only the first line of a repeated header; using it
		// with Set would silently discard every later X-Forwarded-For line.
		prior := strings.Join(r.Header.Values(headerXForwardedFor), ", ")
		if prior == "" {
			r.Header.Set(headerXForwardedFor, remoteIP)
		} else {
			r.Header.Set(headerXForwardedFor, prior+", "+remoteIP)
		}
	}

	if r.Header.Get(headerXRealIP) == "" {
		realIP := firstHost(r.Header.Get(headerXForwardedFor))
		if realIP == "" {
			// No X-Forwarded-For to take it from: the peer itself is the client.
			realIP = remoteIP
		}
		r.Header.Set(headerXRealIP, realIP)
	}

	if r.Header.Get(headerXForwardedProto) == "" {
		if r.TLS == nil {
			r.Header.Set(headerXForwardedProto, "http")
		} else {
			r.Header.Set(headerXForwardedProto, "https")
		}
	}

	m.Handler.ServeHTTP(w, r)
}
// distrust handles a request whose direct peer is not trusted. Whatever
// X-Forwarded-For, X-Real-Ip and X-Forwarded-Proto values the client sent
// are replaced with values derived from the connection itself. The request is
// then passed to m.Handler.
func (m *proxy) distrust(w http.ResponseWriter, r *http.Request) {
	scheme := "http"
	if r.TLS != nil {
		scheme = "https"
	}
	peer := parseHost(r.RemoteAddr)

	h := r.Header
	h.Set(headerXForwardedFor, peer)
	h.Set(headerXRealIP, peer)
	h.Set(headerXForwardedProto, scheme)

	m.Handler.ServeHTTP(w, r)
}
// parseHost extracts the host part of an address such as
// http.Request.RemoteAddr. "host:port" and "[ipv6]:port" yield the host.
// An address without a port is returned unchanged (minus any surrounding
// brackets) when it is a literal IP address, so a port-less RemoteAddr is not
// lost. Anything else yields "".
func parseHost(s string) string {
	host, _, err := net.SplitHostPort(s)
	if err == nil {
		return host
	}

	// SplitHostPort rejects addresses without a port; accept a bare IP literal.
	if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") {
		s = s[1 : len(s)-1]
	}
	if net.ParseIP(s) != nil {
		return s
	}
	return ""
}
// firstHost returns the first entry of a comma-separated list such as an
// X-Forwarded-For value, with surrounding whitespace removed. An empty list
// yields "".
func firstHost(s string) string {
	if i := strings.IndexByte(s, ','); i >= 0 {
		s = s[:i]
	}
	// Lists are commonly written "a, b"; tolerate stray spaces around entries
	// so callers never see "1.2.3.4 " as an address.
	return strings.TrimSpace(s)
}
// parseCIDRs converts xs into networks, preserving order. Each entry may be a
// CIDR range ("10.0.0.0/8", "fd00::/8") or a bare IP address, which becomes a
// single-host network (/32 for IPv4, /128 for IPv6). Surrounding whitespace is
// ignored, so lists split from "a, b" work. Entries that are neither a CIDR
// nor an IP are skipped.
func parseCIDRs(xs []string) []*net.IPNet {
	var rs []*net.IPNet
	for _, x := range xs {
		x = strings.TrimSpace(x)
		if _, n, err := net.ParseCIDR(x); err == nil {
			rs = append(rs, n)
			continue
		}

		ip := net.ParseIP(x)
		if ip == nil {
			continue
		}
		if ip4 := ip.To4(); ip4 != nil {
			rs = append(rs, &net.IPNet{IP: ip4, Mask: net.CIDRMask(32, 32)})
		} else {
			rs = append(rs, &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)})
		}
	}
	return rs
}