Skip to content

Commit f1238e4

Browse files
kyleconroyclaude
andcommitted
feat(mysql): Use forked driver to get prepared statement metadata
Updates the MySQL analyzer to use the sqlc-dev/mysql forked driver which exposes column and parameter metadata from COM_STMT_PREPARE responses. This provides more accurate type information directly from MySQL. The forked driver adds a StmtMetadata interface with ColumnMetadata() and ParamMetadata() methods that return type info including DatabaseTypeName, Nullable, Unsigned, and Length fields. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent e52cd4e commit f1238e4

File tree

3 files changed

+86
-112
lines changed

3 files changed

+86
-112
lines changed

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,5 @@ require (
6464
google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect
6565
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
6666
)
67+
68+
replace github.com/go-sql-driver/mysql => github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
2626
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
2727
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
2828
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
29-
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
30-
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
3129
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
3230
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
3331
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
@@ -159,6 +157,8 @@ github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4
159157
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
160158
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
161159
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
160+
github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2 h1:kmCAKKtOgK6EXXQX9oPdEASIhgor7TCpWxD8NtcqVcU=
161+
github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2/go.mod h1:TrDMWzjNTKvJeK2GC8uspG+PWyPLiY9QKvwdWpAdlZE=
162162
github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU=
163163
github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
164164
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

internal/engine/dolphin/analyzer/analyze.go

Lines changed: 82 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ package analyzer
33
import (
44
"context"
55
"database/sql"
6+
"database/sql/driver"
67
"fmt"
78
"hash/fnv"
89
"io"
910
"strings"
1011
"sync"
1112

12-
_ "github.com/go-sql-driver/mysql"
13+
"github.com/go-sql-driver/mysql"
1314

1415
core "github.com/sqlc-dev/sqlc/internal/analysis"
1516
"github.com/sqlc-dev/sqlc/internal/config"
@@ -139,90 +140,102 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat
139140
}
140141
}
141142

142-
// Count parameters in the query
143-
paramCount := countParameters(query)
144-
145-
// Try to prepare the statement first to validate syntax
146-
stmt, err := a.conn.PrepareContext(ctx, query)
143+
// Get metadata directly from prepared statement via driver connection
144+
result, err := a.getStatementMetadata(ctx, n, query, ps)
147145
if err != nil {
148-
return nil, a.extractSqlErr(n, err)
146+
return nil, err
149147
}
150-
stmt.Close()
151148

149+
return result, nil
150+
}
151+
152+
// getStatementMetadata uses the MySQL driver's prepared statement metadata API
153+
// to get column and parameter type information without executing the query
154+
func (a *Analyzer) getStatementMetadata(ctx context.Context, n ast.Node, query string, ps *named.ParamSet) (*core.Analysis, error) {
152155
var result core.Analysis
153156

154-
// For SELECT queries, execute with default parameter values to get column metadata
155-
if isSelectQuery(query) {
156-
cols, err := a.getColumnMetadata(ctx, query, paramCount)
157-
if err == nil {
158-
result.Columns = cols
159-
}
160-
// If we fail to get column metadata, fall through to return empty columns
161-
// and let the catalog-based inference handle it
157+
// Get a raw connection to access driver-level prepared statement
158+
conn, err := a.conn.Conn(ctx)
159+
if err != nil {
160+
return nil, a.extractSqlErr(n, fmt.Errorf("failed to get connection: %w", err))
162161
}
162+
defer conn.Close()
163163

164-
// Build parameter info
165-
for i := 1; i <= paramCount; i++ {
166-
name := ""
167-
if ps != nil {
168-
name, _ = ps.NameFor(i)
164+
err = conn.Raw(func(driverConn any) error {
165+
// Get the driver connection that supports PrepareContext
166+
preparer, ok := driverConn.(driver.ConnPrepareContext)
167+
if !ok {
168+
return fmt.Errorf("driver connection does not support PrepareContext")
169169
}
170-
result.Params = append(result.Params, &core.Parameter{
171-
Number: int32(i),
172-
Column: &core.Column{
173-
Name: name,
174-
DataType: "any",
175-
NotNull: false,
176-
},
177-
})
178-
}
179-
180-
return &result, nil
181-
}
182170

183-
// isSelectQuery checks if a query is a SELECT statement
184-
func isSelectQuery(query string) bool {
185-
trimmed := strings.TrimSpace(strings.ToUpper(query))
186-
return strings.HasPrefix(trimmed, "SELECT") ||
187-
strings.HasPrefix(trimmed, "WITH") // CTEs
188-
}
171+
// Prepare the statement - this sends COM_STMT_PREPARE to MySQL
172+
// and receives column and parameter metadata
173+
stmt, err := preparer.PrepareContext(ctx, query)
174+
if err != nil {
175+
return err
176+
}
177+
defer stmt.Close()
178+
179+
// Access the metadata via the StmtMetadata interface from our forked driver
180+
meta, ok := stmt.(mysql.StmtMetadata)
181+
if !ok {
182+
// Fallback: just use param count from NumInput
183+
paramCount := stmt.NumInput()
184+
for i := 1; i <= paramCount; i++ {
185+
name := ""
186+
if ps != nil {
187+
name, _ = ps.NameFor(i)
188+
}
189+
result.Params = append(result.Params, &core.Parameter{
190+
Number: int32(i),
191+
Column: &core.Column{
192+
Name: name,
193+
DataType: "any",
194+
NotNull: false,
195+
},
196+
})
197+
}
198+
return nil
199+
}
189200

190-
// getColumnMetadata executes the query with default values to retrieve column information
191-
func (a *Analyzer) getColumnMetadata(ctx context.Context, query string, paramCount int) ([]*core.Column, error) {
192-
// Generate default parameter values (use 1 for all - works for most types)
193-
args := make([]any, paramCount)
194-
for i := range args {
195-
args[i] = 1
196-
}
201+
// Get column metadata
202+
for _, col := range meta.ColumnMetadata() {
203+
result.Columns = append(result.Columns, &core.Column{
204+
Name: col.Name,
205+
DataType: strings.ToLower(col.DatabaseTypeName),
206+
NotNull: !col.Nullable,
207+
Unsigned: col.Unsigned,
208+
Length: int32(col.Length),
209+
})
210+
}
197211

198-
// Wrap query to avoid fetching data: SELECT * FROM (query) AS _sqlc_wrapper LIMIT 0
199-
// This ensures we get column metadata without executing the actual query
200-
wrappedQuery := fmt.Sprintf("SELECT * FROM (%s) AS _sqlc_wrapper LIMIT 0", query)
212+
// Get parameter metadata
213+
paramMeta := meta.ParamMetadata()
214+
for i, param := range paramMeta {
215+
name := ""
216+
if ps != nil {
217+
name, _ = ps.NameFor(i + 1)
218+
}
219+
result.Params = append(result.Params, &core.Parameter{
220+
Number: int32(i + 1),
221+
Column: &core.Column{
222+
Name: name,
223+
DataType: strings.ToLower(param.DatabaseTypeName),
224+
NotNull: !param.Nullable,
225+
Unsigned: param.Unsigned,
226+
Length: int32(param.Length),
227+
},
228+
})
229+
}
201230

202-
rows, err := a.conn.QueryContext(ctx, wrappedQuery, args...)
203-
if err != nil {
204-
// If wrapped query fails, try direct query with LIMIT 0
205-
// Some queries may not support being wrapped (e.g., queries with UNION at the end)
206-
return nil, err
207-
}
208-
defer rows.Close()
231+
return nil
232+
})
209233

210-
colTypes, err := rows.ColumnTypes()
211234
if err != nil {
212-
return nil, err
213-
}
214-
215-
var columns []*core.Column
216-
for _, col := range colTypes {
217-
nullable, _ := col.Nullable()
218-
columns = append(columns, &core.Column{
219-
Name: col.Name(),
220-
DataType: strings.ToLower(col.DatabaseTypeName()),
221-
NotNull: !nullable,
222-
})
235+
return nil, a.extractSqlErr(n, err)
223236
}
224237

225-
return columns, nil
238+
return &result, nil
226239
}
227240

228241
// replaceDatabase replaces the database name in a MySQL DSN
@@ -253,47 +266,6 @@ func replaceDatabase(dsn string, newDB string) string {
253266
return dsn[:slashIdx+1] + newDB + dsn[slashIdx+paramIdx:]
254267
}
255268

256-
// countParameters counts the number of ? placeholders in a query
257-
func countParameters(query string) int {
258-
count := 0
259-
inString := false
260-
stringChar := byte(0)
261-
escaped := false
262-
263-
for i := 0; i < len(query); i++ {
264-
c := query[i]
265-
266-
if escaped {
267-
escaped = false
268-
continue
269-
}
270-
271-
if c == '\\' {
272-
escaped = true
273-
continue
274-
}
275-
276-
if inString {
277-
if c == stringChar {
278-
inString = false
279-
}
280-
continue
281-
}
282-
283-
if c == '\'' || c == '"' || c == '`' {
284-
inString = true
285-
stringChar = c
286-
continue
287-
}
288-
289-
if c == '?' {
290-
count++
291-
}
292-
}
293-
294-
return count
295-
}
296-
297269
func (a *Analyzer) extractSqlErr(n ast.Node, err error) error {
298270
if err == nil {
299271
return nil

0 commit comments

Comments
 (0)