Skip to content

Commit 8f8243c

Browse files
committed
feat: implement BEGIN, COMMIT, ROLLBACK
1 parent 1cda2e6 commit 8f8243c

26 files changed

+881
-245
lines changed

.mockery.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ packages:
1212
github.com/RichardKnop/minisql/internal/core/minisql:
1313
interfaces:
1414
Parser:
15+
PageSaver:
1516
TxPager:

cmd/minisql/main.go

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ type Server struct {
4141
quit chan struct{}
4242
wg sync.WaitGroup
4343
logger *zap.Logger
44+
45+
// Add connection tracking
46+
connections map[minisql.ConnectionID]*minisql.Connection
47+
nextConnID minisql.ConnectionID
48+
connMu sync.RWMutex
4449
}
4550

4651
func main() {
@@ -87,22 +92,23 @@ func main() {
8792
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
8893

8994
aServer := &Server{
90-
database: aDatabase,
91-
quit: make(chan struct{}),
92-
logger: logger,
95+
database: aDatabase,
96+
quit: make(chan struct{}),
97+
connections: make(map[minisql.ConnectionID]*minisql.Connection),
98+
logger: logger,
9399
}
94100

95101
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", portFlag))
96102
if err != nil {
97103
panic(err)
98104
}
99105
defer listener.Close()
100-
logger.Info("Listening on port", zap.Int("port", portFlag))
106+
logger.Info("listening on port", zap.Int("port", portFlag))
101107

102108
aServer.listener = listener
103109
aServer.wg.Add(1)
104110

105-
go aServer.serve()
111+
go aServer.serve(ctx)
106112

107113
<-sigChan
108114

@@ -115,7 +121,7 @@ func main() {
115121
os.Exit(0)
116122
}
117123

118-
func (s *Server) serve() {
124+
func (s *Server) serve(ctx context.Context) {
119125
defer s.wg.Done()
120126

121127
for {
@@ -129,10 +135,28 @@ func (s *Server) serve() {
129135
}
130136
} else {
131137
s.wg.Add(1)
132-
go func() {
138+
go func(tcpConn net.Conn) {
133139
defer s.wg.Done()
134-
s.handleConnection(context.Background(), conn)
135-
}()
140+
141+
// Create connection context
142+
s.connMu.Lock()
143+
s.nextConnID++
144+
aConnection := s.database.NewConnection(s.nextConnID, tcpConn)
145+
s.connections[aConnection.ID] = aConnection
146+
s.connMu.Unlock()
147+
148+
s.logger.Debug("new connection", zap.String("id", fmt.Sprint(aConnection.ID)))
149+
150+
// Handle connection messages
151+
s.handleConnection(ctx, aConnection)
152+
153+
// Cleanup on disconnect
154+
s.connMu.Lock()
155+
delete(s.connections, aConnection.ID)
156+
s.connMu.Unlock()
157+
158+
s.logger.Debug("connection closed", zap.String("id", fmt.Sprint(aConnection.ID)))
159+
}(conn)
136160
}
137161
}
138162
}
@@ -143,8 +167,9 @@ func (s *Server) stop() {
143167
s.wg.Wait()
144168
}
145169

146-
func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
170+
func (s *Server) handleConnection(ctx context.Context, conn *minisql.Connection) {
147171
defer conn.Close()
172+
defer conn.Cleanup(ctx)
148173

149174
buf := make([]byte, 2048)
150175

@@ -154,13 +179,13 @@ ReadLoop:
154179
case <-s.quit:
155180
return
156181
default:
157-
conn.SetDeadline(time.Now().Add(200 * time.Millisecond))
158-
n, err := conn.Read(buf)
182+
conn.TcpConn().SetDeadline(time.Now().Add(200 * time.Millisecond))
183+
n, err := conn.TcpConn().Read(buf)
159184
if err != nil {
160185
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
161186
continue ReadLoop
162187
} else if err != io.EOF {
163-
log.Println("read error", err)
188+
s.logger.Error("read error", zap.Error(err))
164189
return
165190
}
166191
}
@@ -169,14 +194,14 @@ ReadLoop:
169194
}
170195

171196
if err := s.handleMessage(ctx, conn, buf[:n]); err != nil {
172-
log.Println("Error:", err)
197+
s.logger.Error("error handling message", zap.Error(err))
173198
return
174199
}
175200
}
176201
}
177202
}
178203

179-
func (s *Server) handleMessage(ctx context.Context, conn net.Conn, msg []byte) error {
204+
func (s *Server) handleMessage(ctx context.Context, conn *minisql.Connection, msg []byte) error {
180205
s.logger.Debug("Received message", zap.String("message", string(msg)))
181206

182207
var req protocol.Request
@@ -211,7 +236,7 @@ func (s *Server) handleMessage(ctx context.Context, conn net.Conn, msg []byte) e
211236
return nil
212237
}
213238

214-
func (s *Server) handleSQL(ctx context.Context, conn net.Conn, sql string) error {
239+
func (s *Server) handleSQL(ctx context.Context, conn *minisql.Connection, sql string) error {
215240
stmts, err := s.database.PrepareStatements(ctx, sql)
216241
if err != nil {
217242
return s.sendResponse(conn, protocol.Response{
@@ -221,13 +246,16 @@ func (s *Server) handleSQL(ctx context.Context, conn net.Conn, sql string) error
221246
}
222247

223248
for _, stmt := range stmts {
224-
results, err := s.database.ExecuteInTransaction(ctx, stmt)
249+
results, err := conn.ExecuteStatements(ctx, stmt)
225250
if err != nil {
226251
return s.sendResponse(conn, protocol.Response{
227252
Success: false,
228253
Error: err.Error(),
229254
})
230255
}
256+
if len(results) == 0 {
257+
continue
258+
}
231259
aResult := results[0]
232260

233261
aResponse := protocol.Response{
@@ -271,16 +299,16 @@ func (s *Server) handleSQL(ctx context.Context, conn net.Conn, sql string) error
271299
return nil
272300
}
273301

274-
func (s *Server) sendResponse(conn net.Conn, resp protocol.Response) error {
302+
func (s *Server) sendResponse(conn *minisql.Connection, resp protocol.Response) error {
275303
jsonData, err := json.Marshal(resp)
276304
if err != nil {
277305
return fmt.Errorf("error marshalling response: %v", err)
278306
}
279-
_, err = conn.Write(jsonData)
307+
_, err = conn.TcpConn().Write(jsonData)
280308
if err != nil {
281309
return fmt.Errorf("error writing response: %v", err)
282310
}
283-
_, err = conn.Write([]byte("\n"))
311+
_, err = conn.TcpConn().Write([]byte("\n"))
284312
if err != nil {
285313
return fmt.Errorf("error writing response newline: %v", err)
286314
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package minisql
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"sync"
8+
9+
"go.uber.org/zap"
10+
)
11+
12+
type ConnectionID uint64
13+
14+
type Connection struct {
15+
ID ConnectionID
16+
tcpConn net.Conn
17+
db *Database
18+
transaction *Transaction
19+
// TODO - do we need a lock here?
20+
// Within a single connection, operations are sequential, but multiple goroutines
21+
// could be accessing the connection concurrently. Let's revisit this later.
22+
mu sync.RWMutex
23+
}
24+
25+
func (d *Database) NewConnection(id ConnectionID, tcpConn net.Conn) *Connection {
26+
return &Connection{
27+
ID: id,
28+
db: d,
29+
tcpConn: tcpConn,
30+
}
31+
}
32+
33+
func (c *Connection) Close() error {
34+
return c.tcpConn.Close()
35+
}
36+
37+
func (c *Connection) SetTransaction(tx *Transaction) {
38+
c.mu.Lock()
39+
defer c.mu.Unlock()
40+
c.transaction = tx
41+
}
42+
43+
func (c *Connection) TcpConn() net.Conn {
44+
return c.tcpConn
45+
}
46+
47+
func (c *Connection) HasActiveTransaction() bool {
48+
c.mu.RLock()
49+
defer c.mu.RUnlock()
50+
return c.transaction != nil
51+
}
52+
53+
func (c *Connection) TransactionContext(ctx context.Context) context.Context {
54+
c.mu.RLock()
55+
defer c.mu.RUnlock()
56+
if c.transaction != nil {
57+
return WithTransaction(ctx, c.transaction)
58+
}
59+
return ctx
60+
}
61+
62+
// Clean up any active transaction on disconnect
63+
func (c *Connection) Cleanup(ctx context.Context) {
64+
if !c.HasActiveTransaction() {
65+
return
66+
}
67+
68+
c.db.logger.Warn("connection closed with active transaction, rolling back",
69+
zap.Uint64("id", uint64(c.ID)))
70+
71+
c.mu.Lock()
72+
defer c.mu.Unlock()
73+
74+
c.db.txManager.RollbackTransaction(ctx, c.transaction)
75+
c.transaction = nil
76+
}
77+
78+
func (c *Connection) ExecuteStatements(ctx context.Context, statements ...Statement) ([]StatementResult, error) {
79+
var results []StatementResult
80+
81+
for _, stmt := range statements {
82+
if c.HasActiveTransaction() {
83+
switch stmt.Kind {
84+
case BeginTransaction:
85+
return results, fmt.Errorf("transaction already active on this connection")
86+
case CommitTransaction:
87+
if err := c.db.txManager.CommitTransaction(ctx, c.transaction, c.db.saver); err != nil {
88+
return results, err
89+
}
90+
c.SetTransaction(nil)
91+
92+
results = append(results, StatementResult{})
93+
case RollbackTransaction:
94+
c.db.txManager.RollbackTransaction(ctx, c.transaction)
95+
c.SetTransaction(nil)
96+
results = append(results, StatementResult{})
97+
default:
98+
aResult, err := c.db.executeStatement(c.TransactionContext(ctx), stmt)
99+
if err != nil {
100+
return nil, err
101+
}
102+
results = append(results, aResult)
103+
}
104+
} else {
105+
// BEGIN, COMMIT, ROLLBACK outside transaction
106+
switch stmt.Kind {
107+
case BeginTransaction, CommitTransaction, RollbackTransaction:
108+
if stmt.Kind != BeginTransaction {
109+
return results, fmt.Errorf("no active transaction on this connection")
110+
}
111+
c.SetTransaction(c.db.txManager.BeginTransaction(ctx))
112+
results = append(results, StatementResult{})
113+
continue
114+
}
115+
// Everything wrap in a single statement transaction
116+
if err := c.db.txManager.ExecuteInTransaction(ctx, func(ctx context.Context) error {
117+
aResult, err := c.db.executeStatement(ctx, stmt)
118+
if err != nil {
119+
return err
120+
}
121+
results = append(results, aResult)
122+
return nil
123+
}, c.db.saver); err != nil {
124+
return results, err
125+
}
126+
}
127+
}
128+
129+
return results, nil
130+
}

0 commit comments

Comments
 (0)