@@ -6,14 +6,15 @@ import (
6
6
"bytes"
7
7
"crypto/sha1"
8
8
"encoding/base64"
9
+ "errors"
10
+ "fmt"
9
11
"io"
10
12
"net/http"
11
13
"net/textproto"
12
14
"net/url"
15
+ "strconv"
13
16
"strings"
14
17
15
- "golang.org/x/xerrors"
16
-
17
18
"nhooyr.io/websocket/internal/errd"
18
19
)
19
20
@@ -85,7 +86,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
85
86
86
87
hj , ok := w .(http.Hijacker )
87
88
if ! ok {
88
- err = xerrors .New ("http.ResponseWriter does not implement http.Hijacker" )
89
+ err = errors .New ("http.ResponseWriter does not implement http.Hijacker" )
89
90
http .Error (w , http .StatusText (http .StatusNotImplemented ), http .StatusNotImplemented )
90
91
return nil , err
91
92
}
@@ -110,7 +111,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
110
111
111
112
netConn , brw , err := hj .Hijack ()
112
113
if err != nil {
113
- err = xerrors .Errorf ("failed to hijack connection: %w" , err )
114
+ err = fmt .Errorf ("failed to hijack connection: %w" , err )
114
115
http .Error (w , http .StatusText (http .StatusInternalServerError ), http .StatusInternalServerError )
115
116
return nil , err
116
117
}
@@ -133,32 +134,32 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
133
134
134
135
func verifyClientRequest (w http.ResponseWriter , r * http.Request ) (errCode int , _ error ) {
135
136
if ! r .ProtoAtLeast (1 , 1 ) {
136
- return http .StatusUpgradeRequired , xerrors .Errorf ("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q" , r .Proto )
137
+ return http .StatusUpgradeRequired , fmt .Errorf ("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q" , r .Proto )
137
138
}
138
139
139
140
if ! headerContainsToken (r .Header , "Connection" , "Upgrade" ) {
140
141
w .Header ().Set ("Connection" , "Upgrade" )
141
142
w .Header ().Set ("Upgrade" , "websocket" )
142
- return http .StatusUpgradeRequired , xerrors .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , r .Header .Get ("Connection" ))
143
+ return http .StatusUpgradeRequired , fmt .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , r .Header .Get ("Connection" ))
143
144
}
144
145
145
146
if ! headerContainsToken (r .Header , "Upgrade" , "websocket" ) {
146
147
w .Header ().Set ("Connection" , "Upgrade" )
147
148
w .Header ().Set ("Upgrade" , "websocket" )
148
- return http .StatusUpgradeRequired , xerrors .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , r .Header .Get ("Upgrade" ))
149
+ return http .StatusUpgradeRequired , fmt .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , r .Header .Get ("Upgrade" ))
149
150
}
150
151
151
152
if r .Method != "GET" {
152
- return http .StatusMethodNotAllowed , xerrors .Errorf ("WebSocket protocol violation: handshake request method is not GET but %q" , r .Method )
153
+ return http .StatusMethodNotAllowed , fmt .Errorf ("WebSocket protocol violation: handshake request method is not GET but %q" , r .Method )
153
154
}
154
155
155
156
if r .Header .Get ("Sec-WebSocket-Version" ) != "13" {
156
157
w .Header ().Set ("Sec-WebSocket-Version" , "13" )
157
- return http .StatusBadRequest , xerrors .Errorf ("unsupported WebSocket protocol version (only 13 is supported): %q" , r .Header .Get ("Sec-WebSocket-Version" ))
158
+ return http .StatusBadRequest , fmt .Errorf ("unsupported WebSocket protocol version (only 13 is supported): %q" , r .Header .Get ("Sec-WebSocket-Version" ))
158
159
}
159
160
160
161
if r .Header .Get ("Sec-WebSocket-Key" ) == "" {
161
- return http .StatusBadRequest , xerrors .New ("WebSocket protocol violation: missing Sec-WebSocket-Key" )
162
+ return http .StatusBadRequest , errors .New ("WebSocket protocol violation: missing Sec-WebSocket-Key" )
162
163
}
163
164
164
165
return 0 , nil
@@ -169,10 +170,10 @@ func authenticateOrigin(r *http.Request) error {
169
170
if origin != "" {
170
171
u , err := url .Parse (origin )
171
172
if err != nil {
172
- return xerrors .Errorf ("failed to parse Origin header %q: %w" , origin , err )
173
+ return fmt .Errorf ("failed to parse Origin header %q: %w" , origin , err )
173
174
}
174
175
if ! strings .EqualFold (u .Host , r .Host ) {
175
- return xerrors .Errorf ("request Origin %q is not authorized for Host %q" , origin , r .Host )
176
+ return fmt .Errorf ("request Origin %q is not authorized for Host %q" , origin , r .Host )
176
177
}
177
178
}
178
179
return nil
@@ -208,6 +209,7 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM
208
209
209
210
func acceptDeflate (w http.ResponseWriter , ext websocketExtension , mode CompressionMode ) (* compressionOptions , error ) {
210
211
copts := mode .opts ()
212
+ copts .serverMaxWindowBits = 8
211
213
212
214
for _ , p := range ext .params {
213
215
switch p {
@@ -219,11 +221,31 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
219
221
continue
220
222
}
221
223
222
- if strings .HasPrefix (p , "client_max_window_bits" ) || strings .HasPrefix (p , "server_max_window_bits" ) {
224
+ if strings .HasPrefix (p , "client_max_window_bits" ) {
225
+ continue
226
+
227
+ // bits, ok := parseExtensionParameter(p, 15)
228
+ // if !ok || bits < 8 || bits > 16 {
229
+ // err := fmt.Errorf("invalid client_max_window_bits: %q", p)
230
+ // http.Error(w, err.Error(), http.StatusBadRequest)
231
+ // return nil, err
232
+ // }
233
+ // copts.clientMaxWindowBits = bits
234
+ // continue
235
+ }
236
+
237
+ if false && strings .HasPrefix (p , "server_max_window_bits" ) {
238
+ // We always send back 8 but make sure to validate.
239
+ bits , ok := parseExtensionParameter (p , 0 )
240
+ if ! ok || bits < 8 || bits > 16 {
241
+ err := fmt .Errorf ("invalid server_max_window_bits: %q" , p )
242
+ http .Error (w , err .Error (), http .StatusBadRequest )
243
+ return nil , err
244
+ }
223
245
continue
224
246
}
225
247
226
- err := xerrors .Errorf ("unsupported permessage-deflate parameter: %q" , p )
248
+ err := fmt .Errorf ("unsupported permessage-deflate parameter: %q" , p )
227
249
http .Error (w , err .Error (), http .StatusBadRequest )
228
250
return nil , err
229
251
}
@@ -233,6 +255,21 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
233
255
return copts , nil
234
256
}
235
257
258
+ // parseExtensionParameter parses the value in the extension parameter p.
259
+ // It falls back to defaultVal if there is no value.
260
+ // If defaultVal == 0, then ok == false if there is no value.
261
+ func parseExtensionParameter (p string , defaultVal int ) (int , bool ) {
262
+ ps := strings .Split (p , "=" )
263
+ if len (ps ) == 1 {
264
+ if defaultVal > 0 {
265
+ return defaultVal , true
266
+ }
267
+ return 0 , false
268
+ }
269
+ i , e := strconv .Atoi (strings .Trim (ps [1 ], `"` ))
270
+ return i , e == nil
271
+ }
272
+
236
273
func acceptWebkitDeflate (w http.ResponseWriter , ext websocketExtension , mode CompressionMode ) (* compressionOptions , error ) {
237
274
copts := mode .opts ()
238
275
// The peer must explicitly request it.
@@ -253,7 +290,7 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com
253
290
//
254
291
// Either way, we're only implementing this for webkit which never sends the max_window_bits
255
292
// parameter so we don't need to worry about it.
256
- err := xerrors .Errorf ("unsupported x-webkit-deflate-frame parameter: %q" , p )
293
+ err := fmt .Errorf ("unsupported x-webkit-deflate-frame parameter: %q" , p )
257
294
http .Error (w , err .Error (), http .StatusBadRequest )
258
295
return nil , err
259
296
}
0 commit comments