@@ -52,12 +52,18 @@ func newErrClientCopyFailed(desc string) error {
52
52
return psqlerr .WithSeverity (psqlerr .WithCode (err , codes .Uncategorized ), psqlerr .LevelError )
53
53
}
54
54
55
+ type Session struct {
56
+ * Server
57
+ Statements StatementCache
58
+ Portals PortalCache
59
+ }
60
+
55
61
// consumeCommands consumes incoming commands sent over the Postgres wire connection.
56
62
// Commands consumed from the connection are returned through a go channel.
57
63
// Responses for the given message type are written back to the client.
58
64
// This method keeps consuming messages until the client issues a close message
59
65
// or the connection is terminated.
60
- func (srv * Server ) consumeCommands (ctx context.Context , conn net.Conn , reader * buffer.Reader , writer * buffer.Writer ) error {
66
+ func (srv * Session ) consumeCommands (ctx context.Context , conn net.Conn , reader * buffer.Reader , writer * buffer.Writer ) error {
61
67
srv .logger .Debug ("ready for query... starting to consume commands" )
62
68
63
69
// TODO: Include a value to identify unique connections
@@ -77,7 +83,7 @@ func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *b
77
83
}
78
84
}
79
85
80
- func (srv * Server ) consumeSingleCommand (ctx context.Context , reader * buffer.Reader , writer * buffer.Writer , conn net.Conn ) error {
86
+ func (srv * Session ) consumeSingleCommand (ctx context.Context , reader * buffer.Reader , writer * buffer.Writer , conn net.Conn ) error {
81
87
t , length , err := reader .ReadTypedMsg ()
82
88
if err == io .EOF {
83
89
return nil
@@ -141,7 +147,7 @@ func handleMessageSizeExceeded(reader *buffer.Reader, writer *buffer.Writer, exc
141
147
// message type and reader buffer containing the actual message. The type
142
148
// indecates a action executed by the client.
143
149
// https://www.postgresql.org/docs/14/protocol-message-formats.html
144
- func (srv * Server ) handleCommand (ctx context.Context , conn net.Conn , t types.ClientMessage , reader * buffer.Reader , writer * buffer.Writer ) error {
150
+ func (srv * Session ) handleCommand (ctx context.Context , conn net.Conn , t types.ClientMessage , reader * buffer.Reader , writer * buffer.Writer ) error {
145
151
ctx , cancel := context .WithCancel (ctx )
146
152
defer cancel ()
147
153
@@ -236,7 +242,7 @@ func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.Cli
236
242
}
237
243
}
238
244
239
- func (srv * Server ) handleSimpleQuery (ctx context.Context , reader * buffer.Reader , writer * buffer.Writer ) error {
245
+ func (srv * Session ) handleSimpleQuery (ctx context.Context , reader * buffer.Reader , writer * buffer.Writer ) error {
240
246
if srv .parse == nil {
241
247
return ErrorCode (writer , NewErrUnimplementedMessageType (types .ClientSimpleQuery ))
242
248
}
@@ -287,7 +293,7 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader,
287
293
return readyForQuery (writer , types .ServerIdle )
288
294
}
289
295
290
- func (srv * Server ) handleParse (ctx context.Context , reader * buffer.Reader , writer * buffer.Writer ) error {
296
+ func (srv * Session ) handleParse (ctx context.Context , reader * buffer.Reader , writer * buffer.Writer ) error {
291
297
if srv .parse == nil || srv .Statements == nil {
292
298
return ErrorCode (writer , NewErrUnimplementedMessageType (types .ClientParse ))
293
299
}
@@ -337,7 +343,7 @@ func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, write
337
343
return writer .End ()
338
344
}
339
345
340
- func (srv * Server ) handleDescribe (ctx context.Context , reader * buffer.Reader , writer * buffer.Writer ) error {
346
+ func (srv * Session ) handleDescribe (ctx context.Context , reader * buffer.Reader , writer * buffer.Writer ) error {
341
347
d , err := reader .GetBytes (1 )
342
348
if err != nil {
343
349
return err
@@ -385,7 +391,7 @@ func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, wr
385
391
}
386
392
387
393
// https://www.postgresql.org/docs/15/protocol-message-formats.html
388
- func (srv * Server ) writeParameterDescription (writer * buffer.Writer , parameters []oid.Oid ) error {
394
+ func (srv * Session ) writeParameterDescription (writer * buffer.Writer , parameters []oid.Oid ) error {
389
395
writer .Start (types .ServerParameterDescription )
390
396
writer .AddInt16 (int16 (len (parameters )))
391
397
@@ -400,7 +406,7 @@ func (srv *Server) writeParameterDescription(writer *buffer.Writer, parameters [
400
406
// back to the writer buffer. Information about the returned columns is written
401
407
// to the client.
402
408
// https://www.postgresql.org/docs/15/protocol-message-formats.html
403
- func (srv * Server ) writeColumnDescription (ctx context.Context , writer * buffer.Writer , formats []FormatCode , columns Columns ) error {
409
+ func (srv * Session ) writeColumnDescription (ctx context.Context , writer * buffer.Writer , formats []FormatCode , columns Columns ) error {
404
410
if len (columns ) == 0 {
405
411
writer .Start (types .ServerNoData )
406
412
return writer .End ()
@@ -409,7 +415,7 @@ func (srv *Server) writeColumnDescription(ctx context.Context, writer *buffer.Wr
409
415
return columns .Define (ctx , writer , formats )
410
416
}
411
417
412
- func (srv * Server ) handleBind (ctx context.Context , reader * buffer.Reader , writer * buffer.Writer ) error {
418
+ func (srv * Session ) handleBind (ctx context.Context , reader * buffer.Reader , writer * buffer.Writer ) error {
413
419
name , err := reader .GetString ()
414
420
if err != nil {
415
421
return err
@@ -451,7 +457,7 @@ func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer
451
457
// readParameters attempts to read all incoming parameters from the given
452
458
// reader. The parameters are parsed and returned.
453
459
// https://www.postgresql.org/docs/14/protocol-message-formats.html
454
- func (srv * Server ) readParameters (ctx context.Context , reader * buffer.Reader ) ([]Parameter , error ) {
460
+ func (srv * Session ) readParameters (ctx context.Context , reader * buffer.Reader ) ([]Parameter , error ) {
455
461
// NOTE: read the total amount of parameter format length that will be send
456
462
// by the client. This can be zero to indicate that there are no parameters
457
463
// or that the parameters all use the default format (text); or one, in
@@ -516,7 +522,7 @@ func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([
516
522
return parameters , nil
517
523
}
518
524
519
- func (srv * Server ) readColumnTypes (reader * buffer.Reader ) ([]FormatCode , error ) {
525
+ func (srv * Session ) readColumnTypes (reader * buffer.Reader ) ([]FormatCode , error ) {
520
526
length , err := reader .GetUint16 ()
521
527
if err != nil {
522
528
return nil , err
@@ -537,7 +543,7 @@ func (srv *Server) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error)
537
543
return columns , nil
538
544
}
539
545
540
- func (srv * Server ) handleExecute (ctx context.Context , reader * buffer.Reader , writer * buffer.Writer ) error {
546
+ func (srv * Session ) handleExecute (ctx context.Context , reader * buffer.Reader , writer * buffer.Writer ) error {
541
547
if srv .Statements == nil {
542
548
return ErrorCode (writer , NewErrUnimplementedMessageType (types .ClientExecute ))
543
549
}
@@ -565,7 +571,7 @@ func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, wri
565
571
return nil
566
572
}
567
573
568
- func (srv * Server ) handleConnTerminate (ctx context.Context ) error {
574
+ func (srv * Session ) handleConnTerminate (ctx context.Context ) error {
569
575
if srv .TerminateConn == nil {
570
576
return nil
571
577
}
0 commit comments