This repository has been archived by the owner on Mar 5, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathproxy.go
614 lines (554 loc) · 15.6 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
package main
import (
"bytes"
"crypto/tls"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"time"
// debuggin' -- used for runtime profiling/debugging
"net/http"
_ "net/http/pprof"
"github.com/voutilad/bolt-proxy/backend"
"github.com/voutilad/bolt-proxy/bolt"
"github.com/voutilad/bolt-proxy/health"
"github.com/gobwas/ws"
)
const (
	// MAX_IDLE_MINS is the idle timeout, in minutes, that the client-side
	// (handleBoltConn) and server-side (handleTx) event loops wait for the
	// next message before giving up. A basic fixed value for now.
	MAX_IDLE_MINS int = 30
	// MAX_BYTES is the maximum number of payload bytes logMessage prints
	// per message in debug mode.
	MAX_BYTES int = 32
)

// Package-level loggers, assigned in main before any connection is served.
// info and debug write to stdout and warn to stderr; debug output is
// discarded unless debug mode is enabled.
var (
	debug *log.Logger
	info  *log.Logger
	warn  *log.Logger
)
// logMessage is a crude logging routine for helping debug bolt Messages. It
// tries not to clutter output too much with large messages, while still
// delineating who logged the message (e.g. "C->P", "P<-S").
//
// Output goes to the debug logger. At most MAX_BYTES of the payload are
// shown; when the payload is longer, a "...+N bytes" suffix reports how many
// bytes were left out. HELLO payloads are cut to their first 4 bytes so that
// credentials are never logged. Payloads shorter than the fixed-size slices
// printed here (4 bytes for HELLO, the trailing 2 bytes otherwise) are
// logged as-is instead of causing an out-of-range panic, and a nil message
// is reported rather than dereferenced.
func logMessage(who string, msg *bolt.Message) {
	if msg == nil {
		debug.Printf("[%s] <nil message>\n", who)
		return
	}
	data := msg.Data

	// Truncate the preview to MAX_BYTES; only annotate when something was
	// actually cut off.
	end := len(data)
	suffix := ""
	if end > MAX_BYTES {
		end = MAX_BYTES
		suffix = fmt.Sprintf("...+%d bytes", len(data)-MAX_BYTES)
	}

	switch msg.T {
	case bolt.HelloMsg:
		// make sure we don't print the secrets in a Hello!
		head := 4
		if len(data) < head {
			head = len(data)
		}
		debug.Printf("[%s] <%s>: %#v\n\n", who, msg.T, data[:head])
	case bolt.BeginMsg, bolt.FailureMsg:
		debug.Printf("[%s] <%s>: %#v\n%s\n", who, msg.T, data[:end], data)
	default:
		// Also show the final two bytes, which the truncated preview would
		// otherwise hide (presumably the chunk end marker -- TODO confirm).
		tail := 2
		if len(data) < tail {
			tail = len(data)
		}
		debug.Printf("[%s] <%s>: %#v%s, last2:%#v\n", who, msg.T, data[:end], suffix, data[len(data)-tail:])
	}
}
// logMessages logs every message in the slice, in order, via logMessage,
// tagging each entry with the same who label.
func logMessages(who string, messages []*bolt.Message) {
	for i := range messages {
		logMessage(who, messages[i])
	}
}
// handleTx is the primary transaction server-side event handler. It collects
// Messages from the backend Bolt server and writes them to the given client
// until one of the following happens:
//
//   - the server sends a GOODBYE message,
//   - the server's message channel is closed (potential server hangup),
//   - writing a message to the client fails,
//   - a halt is requested, or
//   - no server message arrives within MAX_IDLE_MINS minutes.
//
// Since this should be running async to process server Messages as they
// arrive, two channels are provided for signaling:
//
// ack: used by this handler to signal that it has completed and stopped
// execution, basically a way to confirm the requested halt
//
// halt: used by an external routine to request this handler to cleanly
// stop execution
//
// A failed write to the client is logged and ends this handler rather than
// panicking: a panic in this goroutine would take down the entire proxy and
// every other client connected to it.
func handleTx(client, server bolt.BoltConn, ack chan<- bool, halt <-chan bool) {
	idleTimeout := time.Duration(MAX_IDLE_MINS) * time.Minute

	// One reusable timer instead of time.After per message: every
	// time.After call allocates a timer that (before Go 1.23) stays live
	// until it fires, i.e. for the full idle timeout.
	idle := time.NewTimer(idleTimeout)
	defer idle.Stop()

	finished := false
	for !finished {
		// Restart the idle countdown for this wait, draining a stale
		// expiry in case the timer fired while we were busy.
		if !idle.Stop() {
			select {
			case <-idle.C:
			default:
			}
		}
		idle.Reset(idleTimeout)

		select {
		case msg, ok := <-server.R():
			if !ok {
				debug.Println("potential server hangup")
				finished = true
				break
			}
			logMessage("P<-S", msg)
			if err := client.WriteMessage(msg); err != nil {
				warn.Printf("failed writing server message to client %s: %s\n", client, err)
				finished = true
				break
			}
			logMessage("C<-P", msg)
			// if we know the server side is saying goodbye,
			// we abort the loop
			if msg.T == bolt.GoodbyeMsg {
				finished = true
			}
		case <-halt:
			finished = true
		case <-idle.C:
			warn.Println("timeout reading server!")
			finished = true
		}
	}

	// Confirm the stop without ever blocking; ack is expected to be a
	// buffered channel owned by the caller.
	select {
	case ack <- true:
		debug.Println("tx handler stop ACK sent")
	default:
		warn.Println("couldn't put value in ack channel?!")
	}
}
// handleClient identifies whether a new connection is a valid Bolt or
// Bolt-over-WebSocket connection based on its handshake, or an HTTP health
// check (which is answered and then closed).
//
// Valid Bolt clients are wrapped in a bolt.BoltConn and passed off to
// handleBoltConn. The connection is always closed when this returns.
//
// Every handshake problem is logged and drops only this one connection. The
// bytes parsed here are entirely client-controlled, so nothing in this
// function may terminate the process (log.Fatal) or panic (e.g. by slicing
// past the read buffer with a client-supplied length).
func handleClient(conn net.Conn, b *backend.Backend) {
	defer func() {
		debug.Printf("closing client connection from %s\n",
			conn.RemoteAddr())
		conn.Close()
	}()

	// XXX why 1024? I've observed long user-agents that make this
	// pass the 512 mark easily, so let's be safe and go a full 1kb
	buf := make([]byte, 1024)

	// The first 4 bytes identify the protocol. ReadFull guarantees we
	// really have all 4 even if they arrive in separate segments.
	if _, err := io.ReadFull(conn, buf[:4]); err != nil {
		warn.Println("bad connection from", conn.RemoteAddr())
		return
	}

	switch {
	case bytes.Equal(buf[:4], []byte{0x60, 0x60, 0xb0, 0x17}):
		// First case: we have a direct bolt client connection. The magic
		// preamble is followed by 4 proposed versions of 4 bytes each.
		handshake := buf[:16]
		if _, err := io.ReadFull(conn, handshake); err != nil {
			warn.Println("error peeking at connection from", conn.RemoteAddr())
			return
		}

		// Make sure we try to use the best version based on the
		// backend server
		serverVersion := b.Version().Bytes()
		clientVersion, err := bolt.ValidateHandshake(handshake, serverVersion)
		if err != nil {
			warn.Printf("bad bolt handshake from client %s: %s\n",
				conn.RemoteAddr(), err)
			return
		}
		if _, err = conn.Write(clientVersion); err != nil {
			warn.Printf("failed to send bolt version to client %s: %s\n",
				conn.RemoteAddr(), err)
			return
		}

		// regular bolt
		handleBoltConn(bolt.NewDirectConn(conn), clientVersion, b)

	case bytes.Equal(buf[:4], []byte{0x47, 0x45, 0x54, 0x20}):
		// Second case: "GET ", an HTTP connection that might just be a
		// WebSocket upgrade OR a health check.

		// Read the rest of the request
		n, err := conn.Read(buf[4:])
		if err != nil {
			warn.Printf("failed reading rest of GET request: %s\n", err)
			return
		}
		request := buf[:n+4]

		// Health check, maybe? If so, handle and bail.
		if health.IsHealthCheck(request) {
			if err := health.HandleHealthCheck(conn, request); err != nil {
				warn.Println(err)
			}
			return
		}

		// Build something implementing the io.ReadWriter interface
		// to pass to the upgrader routine
		iobuf := bytes.NewBuffer(request)
		if _, err := ws.Upgrade(iobuf); err != nil {
			warn.Printf("failed to upgrade websocket client %s: %s\n",
				conn.RemoteAddr(), err)
			return
		}

		// Relay the upgrade response
		if _, err := io.Copy(conn, iobuf); err != nil {
			warn.Printf("failed to copy upgrade to client %s\n",
				conn.RemoteAddr())
			return
		}

		// After upgrade, we should get a WebSocket message with header
		header, err := ws.ReadHeader(conn)
		if err != nil {
			warn.Printf("failed to read ws header from client %s: %s\n",
				conn.RemoteAddr(), err)
			return
		}

		// The payload must hold the full 20-byte Bolt handshake (4 bytes
		// magic + 16 bytes versions) and fit in buf. The length comes
		// from the client, so validate it before slicing.
		if header.Length < 20 || header.Length > int64(len(buf)) {
			warn.Printf("unexpected ws handshake length %d from client %s\n",
				header.Length, conn.RemoteAddr())
			return
		}
		payload := buf[:header.Length]
		if _, err := io.ReadFull(conn, payload); err != nil {
			warn.Printf("failed to read payload from client %s\n",
				conn.RemoteAddr())
			return
		}
		if header.Masked {
			ws.Cipher(payload, header.Mask, 0)
		}

		// We expect we can now do the initial Bolt handshake
		magic, handshake := payload[:4], payload[4:20] // blaze it
		if valid, err := bolt.ValidateMagic(magic); !valid {
			warn.Printf("bad bolt magic from client %s: %v\n",
				conn.RemoteAddr(), err)
			return
		}

		// negotiate client & server side bolt versions
		serverVersion := b.Version().Bytes()
		clientVersion, err := bolt.ValidateHandshake(handshake, serverVersion)
		if err != nil {
			warn.Printf("bad bolt handshake from client %s: %s\n",
				conn.RemoteAddr(), err)
			return
		}

		// Complete Bolt handshake via WebSocket frame
		frame := ws.NewBinaryFrame(clientVersion)
		if err := ws.WriteFrame(conn, frame); err != nil {
			warn.Printf("failed to send bolt version to client %s: %s\n",
				conn.RemoteAddr(), err)
			return
		}

		// Let there be Bolt-via-WebSockets!
		handleBoltConn(bolt.NewWsConn(conn), clientVersion, b)

	default:
		// not bolt, not http...something else?
		info.Printf("client %s is speaking gibberish: %#v\n",
			conn.RemoteAddr(), buf[:4])
	}
}
// handleBoltConn is the primary transaction client-side event handler. It
// collects Messages from the Bolt client and finds ways to switch them to
// the proper backend.
//
// It first intercepts the client's HELLO and authenticates it against the
// backend (b.Authenticate), which yields a pool of host connections. It then
// answers the client with a SUCCESS and runs the event loop:
//
//   - each BEGIN, or RUN outside an explicit transaction, picks a host from
//     the routing table of the requested (or default) database, using the
//     readers or writers list depending on the message's mode;
//   - any previous handleTx goroutine is stopped, and a new one is started
//     to relay that host's responses back to the client;
//   - client messages are forwarded to the current host. With no host
//     selected yet, only RESET and GOODBYE are handled locally.
//
// Failures (bad client data, backend lookups, failed writes) end only this
// client's session; they never exit or panic the process. On every return
// path the current tx handler is asked to halt so it does not outlive the
// client.
//
// TODO: this logic should be split out between the authentication and the
// event loop. For now, this does both.
func handleBoltConn(client bolt.BoltConn, clientVersion []byte, b *backend.Backend) {
	// Intercept HELLO message for authentication and hold onto it
	// for use in backend authentication
	var hello *bolt.Message
	select {
	case msg, ok := <-client.R():
		if !ok || msg == nil {
			warn.Println("failed to read expected Hello from client")
			return
		}
		hello = msg
	case <-time.After(30 * time.Second):
		warn.Println("timed out waiting for client to auth")
		return
	}

	logMessage("C->P", hello)
	if hello.T != bolt.HelloMsg {
		debug.Println("expected HelloMsg, got:", hello.T)
		return
	}

	// get backend connection
	pool, err := b.Authenticate(hello)
	if err != nil {
		warn.Println(err)
		return
	}

	// TODO: this seems odd...move parser and version stuff to bolt pkg
	v, _ := backend.ParseVersion(clientVersion)
	info.Printf("authenticated client %s speaking %s to %d host(s)\n",
		client, v, len(pool))
	defer func() {
		info.Printf("goodbye to client %s\n", client)
	}()

	// TODO: Replace hardcoded Success message with dynamic one
	success := bolt.Message{
		T: bolt.SuccessMsg,
		Data: []byte{
			0x0, 0x2b, 0xb1, 0x70,
			0xa2,
			0x86, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
			0x8b, 0x4e, 0x65, 0x6f, 0x34, 0x6a, 0x2f, 0x34, 0x2e,
			0x32, 0x2e, 0x30,
			0x8d, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64,
			0x86, 0x62, 0x6f, 0x6c, 0x74, 0x2d, 0x34,
			0x00, 0x00}}
	logMessage("P->C", &success)
	if err = client.WriteMessage(&success); err != nil {
		warn.Printf("failed to send auth success to client %s: %s\n", client, err)
		return
	}

	// Time to begin the client-side event loop!
	startingTx := false
	manualTx := false
	halt := make(chan bool, 1)
	ack := make(chan bool, 1)
	var server bolt.BoltConn

	// Whatever the exit path, ask the current tx handler (if any) to stop
	// so it does not keep relaying to a client that is gone. The closure
	// reads halt at exit time, i.e. the channel of the latest handler.
	// Non-blocking: if a halt is already queued there is nothing to add.
	defer func() {
		select {
		case halt <- true:
		default:
		}
	}()

	// One reusable idle timer instead of a fresh time.After per message,
	// which (before Go 1.23) keeps every timer live for the full timeout.
	idleTimeout := time.Duration(MAX_IDLE_MINS) * time.Minute
	idle := time.NewTimer(idleTimeout)
	defer idle.Stop()

	for {
		// Restart the idle countdown for this wait, draining a stale
		// expiry in case the timer fired while we were busy.
		if !idle.Stop() {
			select {
			case <-idle.C:
			default:
			}
		}
		idle.Reset(idleTimeout)

		var msg *bolt.Message
		select {
		case m, ok := <-client.R():
			if !ok {
				debug.Println("potential client hangup")
				select {
				case halt <- true:
					debug.Println("client hungup, asking tx to halt")
				default:
					warn.Println("failed to send halt message to tx handler")
				}
				return
			}
			msg = m
		case <-idle.C:
			warn.Println("client idle timeout")
			return
		}

		if msg == nil {
			// happens during websocket timeout? Drop this client rather
			// than panicking the whole proxy.
			warn.Println("received nil message from client", client)
			return
		}
		logMessage("C->P", msg)

		// Inspect the client's message to discern transaction state
		// We need to figure out if a transaction is starting and
		// what kind of transaction (manual, auto, etc.) it might be.
		switch msg.T {
		case bolt.BeginMsg:
			startingTx = true
			manualTx = true
		case bolt.RunMsg:
			if !manualTx {
				startingTx = true
			}
		case bolt.CommitMsg, bolt.RollbackMsg:
			manualTx = false
			startingTx = false
		}

		// XXX: This is a mess, but if we're starting a new transaction
		// we need to find a new connection to switch to
		if startingTx {
			// Begin/Run payloads are parsed past a fixed 4-byte prefix
			// (presumably chunk size + struct marker + signature, as in
			// the hard-coded SUCCESS above -- TODO confirm). Reject
			// anything shorter instead of slicing out of range.
			if len(msg.Data) < 4 {
				warn.Printf("truncated %s message from client %s\n", msg.T, client)
				return
			}

			mode, _ := bolt.ValidateMode(msg.Data)
			cluster, err := b.ClusterInfo()
			if err != nil {
				warn.Printf("error getting cluster info: %s\n", err)
				return
			}
			db := cluster.DefaultDb

			// get the db name, if any. otherwise, use default
			var (
				m map[string]interface{}
				n int
			)
			switch msg.T {
			case bolt.BeginMsg:
				m, _, err = bolt.ParseMap(msg.Data[4:])
				if err != nil {
					warn.Println(err)
					return
				}
			case bolt.RunMsg:
				pos := 4
				// query
				_, n, err = bolt.ParseString(msg.Data[pos:])
				if err != nil {
					warn.Println(err)
					return
				}
				pos += n
				// query params
				_, n, err = bolt.ParseMap(msg.Data[pos:])
				if err != nil {
					warn.Println(err)
					return
				}
				pos += n
				// metadata..like the db name!
				m, _, err = bolt.ParseMap(msg.Data[pos:])
				if err != nil {
					warn.Println(err)
					return
				}
			default:
				// startingTx is only ever set by Begin or Run above.
				panic("shouldn't be starting a tx without a Begin or Run message")
			}

			// Extract the db name, if any. The value is client-supplied,
			// so a non-string is a protocol error, not a reason to crash.
			if val, found := m["db"]; found {
				name, ok := val.(string)
				if !ok {
					warn.Printf("client %s sent non-string db name: %#v\n", client, val)
					return
				}
				db = name
			} else {
				debug.Printf("using default db of %s\n", db)
			}

			// Just choose the first one for now...something simple
			rt, err := b.RoutingTable(db)
			if err != nil {
				warn.Printf("error getting routing table for %s: %s\n", db, err)
				return
			}
			var hosts []string
			if mode == bolt.ReadMode {
				hosts = rt.Readers
			} else {
				hosts = rt.Writers
			}
			if len(hosts) < 1 {
				warn.Println("empty hosts lists for database", db)
				// TODO: return FailureMsg???
				return
			}
			host := hosts[0]

			// Are we already using a host? If so try to stop the
			// current tx handler before we create a new one
			if server != nil {
				select {
				case halt <- true:
					debug.Println("...asking current tx handler to halt")
					select {
					case <-ack:
						debug.Println("tx handler ack'd stop")
					case <-time.After(5 * time.Second):
						warn.Println("!!! timeout waiting for ack from tx handler")
					}
				default:
					// this shouldn't happen! halt is recreated (empty)
					// for every tx handler, so a send can only fail on a
					// logic error here.
					panic("couldn't send halt to tx handler!")
				}
			}

			// Grab our host from our local pool
			var ok bool
			server, ok = pool[host]
			if !ok {
				warn.Println("no established connection for host", host)
				return
			}
			debug.Printf("grabbed conn for %s-access to db %s on host %s\n", mode, db, host)

			// TODO: refactor channel handling...probably have handleTx() return new ones
			// instead of reusing the same ones. If we don't create new ones, there could
			// be lingering halt/ack messages. :-(
			halt = make(chan bool, 1)
			ack = make(chan bool, 1)

			// kick off a new tx handler routine
			go handleTx(client, server, ack, halt)
			startingTx = false
		}

		// TODO: this connected/not-connected handling looks messy
		if server != nil {
			if err := server.WriteMessage(msg); err != nil {
				// TODO: figure out best way to handle failed writes
				warn.Printf("failed to write message to server for client %s: %s\n", client, err)
				return
			}
			logMessage("P->S", msg)
		} else {
			// we have no connection since there's no tx...
			// handle only specific, simple messages
			switch msg.T {
			case bolt.ResetMsg:
				// XXX: Neo4j Desktop does this when defining a
				// remote dbms connection.
				// simply send empty success message
				err := client.WriteMessage(&bolt.Message{
					T: bolt.SuccessMsg,
					Data: []byte{
						0x00, 0x03,
						0xb1, 0x70,
						0xa0,
						0x00, 0x00,
					},
				})
				if err != nil {
					warn.Printf("failed to answer reset from client %s: %s\n", client, err)
					return
				}
			case bolt.GoodbyeMsg:
				// bye!
				return
			}
		}
	}
}
// Fallback settings used by main when neither the corresponding environment
// variable nor command-line flag provides a value.
const (
	// DEFAULT_BIND is the host:port to listen on (BOLT_PROXY_BIND, -bind).
	DEFAULT_BIND string = "localhost:8888"
	// DEFAULT_URI is the bolt URI of the remote Neo4j (BOLT_PROXY_URI, -uri).
	DEFAULT_URI string = "bolt://localhost:7687"
	// DEFAULT_USER is the Neo4j username (BOLT_PROXY_USER, -user).
	DEFAULT_USER string = "neo4j"
)
// main configures and runs the proxy.
//
// Settings are read from environment variables first (BOLT_PROXY_BIND,
// BOLT_PROXY_URI, BOLT_PROXY_USER, BOLT_PROXY_PASSWORD, BOLT_PROXY_CERT,
// BOLT_PROXY_KEY, and BOLT_PROXY_DEBUG, whose mere presence enables debug
// logging) and may be overridden by command-line flags. main then connects
// to the backend, listens on the bind address (with TLS when both a
// certificate and key file are given) and serves every accepted connection
// with handleClient in its own goroutine.
func main() {
	var (
		debugMode          bool
		bindOn             string
		proxyTo            string
		username, password string
		certFile, keyFile  string
	)

	bindOn, found := os.LookupEnv("BOLT_PROXY_BIND")
	if !found {
		bindOn = DEFAULT_BIND
	}
	proxyTo, found = os.LookupEnv("BOLT_PROXY_URI")
	if !found {
		proxyTo = DEFAULT_URI
	}
	username, found = os.LookupEnv("BOLT_PROXY_USER")
	if !found {
		username = DEFAULT_USER
	}
	_, debugMode = os.LookupEnv("BOLT_PROXY_DEBUG")
	password = os.Getenv("BOLT_PROXY_PASSWORD")
	certFile = os.Getenv("BOLT_PROXY_CERT")
	keyFile = os.Getenv("BOLT_PROXY_KEY")

	// to keep it easy, let the defaults be populated by the env vars
	flag.StringVar(&bindOn, "bind", bindOn, "host:port to bind to")
	flag.StringVar(&proxyTo, "uri", proxyTo, "bolt uri for remote Neo4j")
	flag.StringVar(&username, "user", username, "Neo4j username")
	flag.StringVar(&password, "pass", password, "Neo4j password")
	flag.StringVar(&certFile, "cert", certFile, "x509 certificate")
	flag.StringVar(&keyFile, "key", keyFile, "x509 private key")
	flag.BoolVar(&debugMode, "debug", debugMode, "enable debug logging")
	flag.Parse()

	// We log to stdout because our parents raised us right
	info = log.New(os.Stdout, "INFO ", log.Ldate|log.Ltime|log.Lmsgprefix)
	if debugMode {
		debug = log.New(os.Stdout, "DEBUG ", log.Ldate|log.Ltime|log.Lmsgprefix)
	} else {
		debug = log.New(ioutil.Discard, "DEBUG ", 0)
	}
	warn = log.New(os.Stderr, "WARN ", log.Ldate|log.Ltime|log.Lmsgprefix)

	// ---------- pprof debugger
	// NOTE(review): the net/http/pprof handlers are served on
	// localhost:6060 unconditionally, not only in debug mode.
	go func() {
		info.Println(http.ListenAndServe("localhost:6060", nil))
	}()

	// ---------- BACK END
	// (named be so it does not shadow the backend package)
	info.Println("starting bolt-proxy backend")
	be, err := backend.NewBackend(debug, username, password, proxyTo)
	if err != nil {
		warn.Fatal(err)
	}
	info.Println("connected to backend", proxyTo)
	info.Printf("found backend version %s\n", be.Version())

	// ---------- FRONT END
	info.Println("starting bolt-proxy frontend")
	var listener net.Listener
	if certFile == "" || keyFile == "" {
		// non-tls
		listener, err = net.Listen("tcp", bindOn)
		if err != nil {
			warn.Fatal(err)
		}
		info.Printf("listening on %s\n", bindOn)
	} else {
		// tls
		cert, err := tls.LoadX509KeyPair(certFile, keyFile)
		if err != nil {
			warn.Fatal(err)
		}
		config := &tls.Config{Certificates: []tls.Certificate{cert}}
		listener, err = tls.Listen("tcp", bindOn, config)
		if err != nil {
			warn.Fatal(err)
		}
		info.Printf("listening for TLS connections on %s\n", bindOn)
	}

	// ---------- Event Loop
	// Back off after Accept errors (e.g. running out of file descriptors)
	// instead of spinning and flooding the log: the delay doubles from 5ms
	// up to 1s and resets after the next successful Accept.
	var delay time.Duration
	for {
		conn, err := listener.Accept()
		if err != nil {
			warn.Printf("error: %v\n", err)
			if delay == 0 {
				delay = 5 * time.Millisecond
			} else {
				delay *= 2
			}
			if delay > time.Second {
				delay = time.Second
			}
			time.Sleep(delay)
			continue
		}
		delay = 0
		go handleClient(conn, be)
	}
}