@@ -41,6 +41,11 @@ type Server struct {
4141 quit chan struct {}
4242 wg sync.WaitGroup
4343 logger * zap.Logger
44+
45+ // Add connection tracking
46+ connections map [minisql.ConnectionID ]* minisql.Connection
47+ nextConnID minisql.ConnectionID
48+ connMu sync.RWMutex
4449}
4550
4651func main () {
@@ -87,22 +92,23 @@ func main() {
8792 signal .Notify (sigChan , syscall .SIGINT , syscall .SIGTERM )
8893
8994 aServer := & Server {
90- database : aDatabase ,
91- quit : make (chan struct {}),
92- logger : logger ,
95+ database : aDatabase ,
96+ quit : make (chan struct {}),
97+ connections : make (map [minisql.ConnectionID ]* minisql.Connection ),
98+ logger : logger ,
9399 }
94100
95101 listener , err := net .Listen ("tcp" , fmt .Sprintf (":%d" , portFlag ))
96102 if err != nil {
97103 panic (err )
98104 }
99105 defer listener .Close ()
100- logger .Info ("Listening on port" , zap .Int ("port" , portFlag ))
106+ logger .Info ("listening on port" , zap .Int ("port" , portFlag ))
101107
102108 aServer .listener = listener
103109 aServer .wg .Add (1 )
104110
105- go aServer .serve ()
111+ go aServer .serve (ctx )
106112
107113 <- sigChan
108114
@@ -115,7 +121,7 @@ func main() {
115121 os .Exit (0 )
116122}
117123
118- func (s * Server ) serve () {
124+ func (s * Server ) serve (ctx context. Context ) {
119125 defer s .wg .Done ()
120126
121127 for {
@@ -129,10 +135,28 @@ func (s *Server) serve() {
129135 }
130136 } else {
131137 s .wg .Add (1 )
132- go func () {
138+ go func (tcpConn net. Conn ) {
133139 defer s .wg .Done ()
134- s .handleConnection (context .Background (), conn )
135- }()
140+
141+ // Create connection context
142+ s .connMu .Lock ()
143+ s .nextConnID ++
144+ aConnection := s .database .NewConnection (s .nextConnID , tcpConn )
145+ s .connections [aConnection .ID ] = aConnection
146+ s .connMu .Unlock ()
147+
148+ s .logger .Debug ("new connection" , zap .String ("id" , fmt .Sprint (aConnection .ID )))
149+
150+ // Handle connection messages
151+ s .handleConnection (ctx , aConnection )
152+
153+ // Cleanup on disconnect
154+ s .connMu .Lock ()
155+ delete (s .connections , aConnection .ID )
156+ s .connMu .Unlock ()
157+
158+ s .logger .Debug ("connection closed" , zap .String ("id" , fmt .Sprint (aConnection .ID )))
159+ }(conn )
136160 }
137161 }
138162}
@@ -143,8 +167,9 @@ func (s *Server) stop() {
143167 s .wg .Wait ()
144168}
145169
146- func (s * Server ) handleConnection (ctx context.Context , conn net. Conn ) {
170+ func (s * Server ) handleConnection (ctx context.Context , conn * minisql. Connection ) {
147171 defer conn .Close ()
172+ defer conn .Cleanup (ctx )
148173
149174 buf := make ([]byte , 2048 )
150175
@@ -154,13 +179,13 @@ ReadLoop:
154179 case <- s .quit :
155180 return
156181 default :
157- conn .SetDeadline (time .Now ().Add (200 * time .Millisecond ))
158- n , err := conn .Read (buf )
182+ conn .TcpConn (). SetDeadline (time .Now ().Add (200 * time .Millisecond ))
183+ n , err := conn .TcpConn (). Read (buf )
159184 if err != nil {
160185 if opErr , ok := err .(* net.OpError ); ok && opErr .Timeout () {
161186 continue ReadLoop
162187 } else if err != io .EOF {
163- log . Println ("read error" , err )
188+ s . logger . Error ("read error" , zap . Error ( err ) )
164189 return
165190 }
166191 }
@@ -169,14 +194,14 @@ ReadLoop:
169194 }
170195
171196 if err := s .handleMessage (ctx , conn , buf [:n ]); err != nil {
172- log . Println ( "Error: " , err )
197+ s . logger . Error ( "error handling message " , zap . Error ( err ) )
173198 return
174199 }
175200 }
176201 }
177202}
178203
179- func (s * Server ) handleMessage (ctx context.Context , conn net. Conn , msg []byte ) error {
204+ func (s * Server ) handleMessage (ctx context.Context , conn * minisql. Connection , msg []byte ) error {
180205 s .logger .Debug ("Received message" , zap .String ("message" , string (msg )))
181206
182207 var req protocol.Request
@@ -211,7 +236,7 @@ func (s *Server) handleMessage(ctx context.Context, conn net.Conn, msg []byte) e
211236 return nil
212237}
213238
214- func (s * Server ) handleSQL (ctx context.Context , conn net. Conn , sql string ) error {
239+ func (s * Server ) handleSQL (ctx context.Context , conn * minisql. Connection , sql string ) error {
215240 stmts , err := s .database .PrepareStatements (ctx , sql )
216241 if err != nil {
217242 return s .sendResponse (conn , protocol.Response {
@@ -221,13 +246,16 @@ func (s *Server) handleSQL(ctx context.Context, conn net.Conn, sql string) error
221246 }
222247
223248 for _ , stmt := range stmts {
224- results , err := s . database . ExecuteInTransaction (ctx , stmt )
249+ results , err := conn . ExecuteStatements (ctx , stmt )
225250 if err != nil {
226251 return s .sendResponse (conn , protocol.Response {
227252 Success : false ,
228253 Error : err .Error (),
229254 })
230255 }
256+ if len (results ) == 0 {
257+ continue
258+ }
231259 aResult := results [0 ]
232260
233261 aResponse := protocol.Response {
@@ -271,16 +299,16 @@ func (s *Server) handleSQL(ctx context.Context, conn net.Conn, sql string) error
271299 return nil
272300}
273301
274- func (s * Server ) sendResponse (conn net. Conn , resp protocol.Response ) error {
302+ func (s * Server ) sendResponse (conn * minisql. Connection , resp protocol.Response ) error {
275303 jsonData , err := json .Marshal (resp )
276304 if err != nil {
277305 return fmt .Errorf ("error marshalling response: %v" , err )
278306 }
279- _ , err = conn .Write (jsonData )
307+ _ , err = conn .TcpConn (). Write (jsonData )
280308 if err != nil {
281309 return fmt .Errorf ("error writing response: %v" , err )
282310 }
283- _ , err = conn .Write ([]byte ("\n " ))
311+ _ , err = conn .TcpConn (). Write ([]byte ("\n " ))
284312 if err != nil {
285313 return fmt .Errorf ("error writing response newline: %v" , err )
286314 }
0 commit comments