Skip to content

Commit 59c4c8b

Browse files
committed
fix: isolate statements and portals
1 parent bb11a1b commit 59c4c8b

File tree

6 files changed

+48
-28
lines changed

6 files changed

+48
-28
lines changed

cache.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ type Statement struct {
1515
columns Columns
1616
}
1717

18+
func DefaultStatementCacheFn() StatementCache {
19+
return &DefaultStatementCache{}
20+
}
21+
1822
type DefaultStatementCache struct {
1923
statements map[string]*Statement
2024
mu sync.RWMutex
@@ -63,6 +67,10 @@ type Portal struct {
6367
formats []FormatCode
6468
}
6569

70+
func DefaultPortalCacheFn() PortalCache {
71+
return &DefaultPortalCache{}
72+
}
73+
6674
type DefaultPortalCache struct {
6775
portals map[string]*Portal
6876
mu sync.RWMutex

command.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,18 @@ func newErrClientCopyFailed(desc string) error {
5252
return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.Uncategorized), psqlerr.LevelError)
5353
}
5454

55+
type Session struct {
56+
*Server
57+
Statements StatementCache
58+
Portals PortalCache
59+
}
60+
5561
// consumeCommands consumes incoming commands sent over the Postgres wire connection.
5662
// Commands consumed from the connection are returned through a go channel.
5763
// Responses for the given message type are written back to the client.
5864
// This method keeps consuming messages until the client issues a close message
5965
// or the connection is terminated.
60-
func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) error {
66+
func (srv *Session) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) error {
6167
srv.logger.Debug("ready for query... starting to consume commands")
6268

6369
// TODO: Include a value to identify unique connections
@@ -77,7 +83,7 @@ func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *b
7783
}
7884
}
7985

80-
func (srv *Server) consumeSingleCommand(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer, conn net.Conn) error {
86+
func (srv *Session) consumeSingleCommand(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer, conn net.Conn) error {
8187
t, length, err := reader.ReadTypedMsg()
8288
if err == io.EOF {
8389
return nil
@@ -141,7 +147,7 @@ func handleMessageSizeExceeded(reader *buffer.Reader, writer *buffer.Writer, exc
141147
// message type and reader buffer containing the actual message. The type
142148
// indecates a action executed by the client.
143149
// https://www.postgresql.org/docs/14/protocol-message-formats.html
144-
func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error {
150+
func (srv *Session) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error {
145151
ctx, cancel := context.WithCancel(ctx)
146152
defer cancel()
147153

@@ -236,7 +242,7 @@ func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.Cli
236242
}
237243
}
238244

239-
func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
245+
func (srv *Session) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
240246
if srv.parse == nil {
241247
return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientSimpleQuery))
242248
}
@@ -287,7 +293,7 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader,
287293
return readyForQuery(writer, types.ServerIdle)
288294
}
289295

290-
func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
296+
func (srv *Session) handleParse(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
291297
if srv.parse == nil || srv.Statements == nil {
292298
return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientParse))
293299
}
@@ -337,7 +343,7 @@ func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, write
337343
return writer.End()
338344
}
339345

340-
func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
346+
func (srv *Session) handleDescribe(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
341347
d, err := reader.GetBytes(1)
342348
if err != nil {
343349
return err
@@ -385,7 +391,7 @@ func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, wr
385391
}
386392

387393
// https://www.postgresql.org/docs/15/protocol-message-formats.html
388-
func (srv *Server) writeParameterDescription(writer *buffer.Writer, parameters []oid.Oid) error {
394+
func (srv *Session) writeParameterDescription(writer *buffer.Writer, parameters []oid.Oid) error {
389395
writer.Start(types.ServerParameterDescription)
390396
writer.AddInt16(int16(len(parameters)))
391397

@@ -400,7 +406,7 @@ func (srv *Server) writeParameterDescription(writer *buffer.Writer, parameters [
400406
// back to the writer buffer. Information about the returned columns is written
401407
// to the client.
402408
// https://www.postgresql.org/docs/15/protocol-message-formats.html
403-
func (srv *Server) writeColumnDescription(ctx context.Context, writer *buffer.Writer, formats []FormatCode, columns Columns) error {
409+
func (srv *Session) writeColumnDescription(ctx context.Context, writer *buffer.Writer, formats []FormatCode, columns Columns) error {
404410
if len(columns) == 0 {
405411
writer.Start(types.ServerNoData)
406412
return writer.End()
@@ -409,7 +415,7 @@ func (srv *Server) writeColumnDescription(ctx context.Context, writer *buffer.Wr
409415
return columns.Define(ctx, writer, formats)
410416
}
411417

412-
func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
418+
func (srv *Session) handleBind(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
413419
name, err := reader.GetString()
414420
if err != nil {
415421
return err
@@ -451,7 +457,7 @@ func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer
451457
// readParameters attempts to read all incoming parameters from the given
452458
// reader. The parameters are parsed and returned.
453459
// https://www.postgresql.org/docs/14/protocol-message-formats.html
454-
func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([]Parameter, error) {
460+
func (srv *Session) readParameters(ctx context.Context, reader *buffer.Reader) ([]Parameter, error) {
455461
// NOTE: read the total amount of parameter format length that will be send
456462
// by the client. This can be zero to indicate that there are no parameters
457463
// or that the parameters all use the default format (text); or one, in
@@ -516,7 +522,7 @@ func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([
516522
return parameters, nil
517523
}
518524

519-
func (srv *Server) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error) {
525+
func (srv *Session) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error) {
520526
length, err := reader.GetUint16()
521527
if err != nil {
522528
return nil, err
@@ -537,7 +543,7 @@ func (srv *Server) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error)
537543
return columns, nil
538544
}
539545

540-
func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
546+
func (srv *Session) handleExecute(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
541547
if srv.Statements == nil {
542548
return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientExecute))
543549
}
@@ -565,7 +571,7 @@ func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, wri
565571
return nil
566572
}
567573

568-
func (srv *Server) handleConnTerminate(ctx context.Context) error {
574+
func (srv *Session) handleConnTerminate(ctx context.Context) error {
569575
if srv.TerminateConn == nil {
570576
return nil
571577
}

examples/session/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
)
1111

1212
func main() {
13-
srv, err := wire.NewServer(handler, wire.Session(session))
13+
srv, err := wire.NewServer(handler, wire.SessionMiddleware(session))
1414
if err != nil {
1515
panic(err)
1616
}

options.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,18 +103,18 @@ type OptionFn func(*Server) error
103103

104104
// Statements sets the statement cache used to cache statements for later use. By
105105
// default [DefaultStatementCache] is used.
106-
func Statements(cache StatementCache) OptionFn {
106+
func Statements(handler func() StatementCache) OptionFn {
107107
return func(srv *Server) error {
108-
srv.Statements = cache
108+
srv.Statements = handler
109109
return nil
110110
}
111111
}
112112

113113
// Portals sets the portals cache used to cache statements for later use. By
114114
// default [DefaultPortalCache] is used.
115-
func Portals(cache PortalCache) OptionFn {
115+
func Portals(handler func() PortalCache) OptionFn {
116116
return func(srv *Server) error {
117-
srv.Portals = cache
117+
srv.Portals = handler
118118
return nil
119119
}
120120
}
@@ -199,10 +199,10 @@ func ExtendTypes(fn func(*pgtype.Map)) OptionFn {
199199
}
200200
}
201201

202-
// Session sets the given session handler within the underlying server. The
202+
// SessionMiddleware sets the given session handler within the underlying server. The
203203
// session handler is called when a new connection is opened and authenticated
204204
// allowing for additional metadata to be wrapped around the connection context.
205-
func Session(fn SessionHandler) OptionFn {
205+
func SessionMiddleware(fn SessionHandler) OptionFn {
206206
return func(srv *Server) error {
207207
if srv.Session == nil {
208208
srv.Session = fn

options_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,15 @@ func TestSessionHandler(t *testing.T) {
5656

5757
tests := map[string]test{
5858
"single": {
59-
Session(func(ctx context.Context) (context.Context, error) {
59+
SessionMiddleware(func(ctx context.Context) (context.Context, error) {
6060
return context.WithValue(ctx, mock, value), nil
6161
}),
6262
},
6363
"nested": {
64-
Session(func(ctx context.Context) (context.Context, error) {
64+
SessionMiddleware(func(ctx context.Context) (context.Context, error) {
6565
return ctx, nil
6666
}),
67-
Session(func(ctx context.Context) (context.Context, error) {
67+
SessionMiddleware(func(ctx context.Context) (context.Context, error) {
6868
return context.WithValue(ctx, mock, value), nil
6969
}),
7070
},

wire.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ func NewServer(parse ParseFn, options ...OptionFn) (*Server, error) {
3535
logger: slog.Default(),
3636
closer: make(chan struct{}),
3737
types: pgtype.NewMap(),
38-
Statements: &DefaultStatementCache{},
39-
Portals: &DefaultPortalCache{},
38+
Statements: DefaultStatementCacheFn,
39+
Portals: DefaultPortalCacheFn,
4040
Session: func(ctx context.Context) (context.Context, error) { return ctx, nil },
4141
}
4242

@@ -62,8 +62,8 @@ type Server struct {
6262
TLSConfig *tls.Config
6363
parse ParseFn
6464
Session SessionHandler
65-
Statements StatementCache
66-
Portals PortalCache
65+
Statements func() StatementCache
66+
Portals func() PortalCache
6767
CloseConn CloseFn
6868
TerminateConn CloseFn
6969
Version string
@@ -162,7 +162,13 @@ func (srv *Server) serve(ctx context.Context, conn net.Conn) error {
162162
return err
163163
}
164164

165-
return srv.consumeCommands(ctx, conn, reader, writer)
165+
session := &Session{
166+
Server: srv,
167+
Statements: srv.Statements(),
168+
Portals: srv.Portals(),
169+
}
170+
171+
return session.consumeCommands(ctx, conn, reader, writer)
166172
}
167173

168174
// Close gracefully closes the underlaying Postgres server.

0 commit comments

Comments
 (0)