diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index 6567c0008b..c27813648e 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -3098,3 +3098,158 @@ func (suite *BulkIngestTests) TestBulkIngestViaExecuteQuery() { suite.Require().Len(ingestedData, 1) suite.Equal(int64(2), ingestedData[0].NumRows()) } + +// ---- IsUpdate Prepared Statement Tests -------------------- + +type IsUpdateTestServer struct { + flightsql.BaseServer + + getFlightInfoPreparedCalled bool + doPutPreparedStatementUpdCalled bool + doGetPreparedStatementCalled bool + doGetPreparedStatementHandle []byte +} + +func (srv *IsUpdateTestServer) CreatePreparedStatement(_ context.Context, req flightsql.ActionCreatePreparedStatementRequest) (flightsql.ActionCreatePreparedStatementResult, error) { + result := flightsql.ActionCreatePreparedStatementResult{ + Handle: []byte(req.GetQuery()), + } + switch req.GetQuery() { + case "UPDATE t SET x = 1": + t := true + result.IsUpdate = &t + case "SELECT 1": + f := false + result.IsUpdate = &f + } + // default: IsUpdate remains nil + return result, nil +} + +func (srv *IsUpdateTestServer) ClosePreparedStatement(_ context.Context, _ flightsql.ActionClosePreparedStatementRequest) error { + return nil +} + +func (srv *IsUpdateTestServer) GetFlightInfoPreparedStatement(_ context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + srv.getFlightInfoPreparedCalled = true + + query := &flightproto.CommandPreparedStatementQuery{ + PreparedStatementHandle: cmd.GetPreparedStatementHandle(), + } + var ticket anypb.Any + if err := ticket.MarshalFrom(query); err != nil { + return nil, err + } + tkt, err := proto.Marshal(&ticket) + if err != nil { + return nil, err + } + return &flight.FlightInfo{ + FlightDescriptor: desc, + Endpoint: []*flight.FlightEndpoint{ + {Ticket: &flight.Ticket{Ticket: tkt}}, + }, + TotalRecords: 0, + TotalBytes: -1, + }, nil +} + +func (srv *IsUpdateTestServer) DoGetPreparedStatement(_ context.Context, cmd flightsql.PreparedStatementQuery) (*arrow.Schema, <-chan flight.StreamChunk, error) { + srv.doGetPreparedStatementCalled = true + srv.doGetPreparedStatementHandle = cmd.GetPreparedStatementHandle() + sc := arrow.NewSchema([]arrow.Field{}, nil) + ch := make(chan flight.StreamChunk) + close(ch) + return sc, ch, nil +} + +func (srv *IsUpdateTestServer) DoPutPreparedStatementUpdate(_ context.Context, _ flightsql.PreparedStatementUpdate, _ flight.MessageReader) (int64, error) { + srv.doPutPreparedStatementUpdCalled = true + return 42, nil +} + +type IsUpdateTests struct { + ServerBasedTests + + srv *IsUpdateTestServer +} + +func (suite *IsUpdateTests) SetupSuite() { + suite.srv = &IsUpdateTestServer{} + suite.srv.Alloc = memory.DefaultAllocator + suite.DoSetupSuite(suite.srv, nil, nil) +} + +func (suite *IsUpdateTests) SetupTest() { + suite.srv.getFlightInfoPreparedCalled = false + suite.srv.doPutPreparedStatementUpdCalled = false + suite.srv.doGetPreparedStatementCalled = false + suite.srv.doGetPreparedStatementHandle = nil + suite.ServerBasedTests.SetupTest() +} + +// TestIsUpdateNil verifies that when IsUpdate is nil (not set by server), +// the driver calls GetFlightInfoPreparedStatement (Execute path). +func (suite *IsUpdateTests) TestIsUpdateNil() { + stmt, err := suite.cnxn.NewStatement() + suite.Require().NoError(err) + defer validation.CheckedClose(suite.T(), stmt) + + suite.Require().NoError(stmt.SetSqlQuery("SELECT *")) + suite.Require().NoError(stmt.Prepare(context.Background())) + + rdr, _, err := stmt.ExecuteQuery(context.Background()) + suite.Require().NoError(err) + defer rdr.Release() + + suite.True(suite.srv.getFlightInfoPreparedCalled, "Expected GetFlightInfoPreparedStatement to be called when IsUpdate is nil") + suite.False(suite.srv.doPutPreparedStatementUpdCalled, "Expected DoPutPreparedStatementUpdate NOT to be called when IsUpdate is nil") + suite.True(suite.srv.doGetPreparedStatementCalled, "Expected DoGetPreparedStatement to be called when IsUpdate is nil") + suite.Equal([]byte("SELECT *"), suite.srv.doGetPreparedStatementHandle, "Expected DoGetPreparedStatement to receive the correct prepared statement handle") +} + +// TestIsUpdateFalse verifies that when IsUpdate is explicitly false, +// the driver calls GetFlightInfoPreparedStatement (Execute path). +func (suite *IsUpdateTests) TestIsUpdateFalse() { + stmt, err := suite.cnxn.NewStatement() + suite.Require().NoError(err) + defer validation.CheckedClose(suite.T(), stmt) + + suite.Require().NoError(stmt.SetSqlQuery("SELECT 1")) + suite.Require().NoError(stmt.Prepare(context.Background())) + + rdr, _, err := stmt.ExecuteQuery(context.Background()) + suite.Require().NoError(err) + defer rdr.Release() + + suite.True(suite.srv.getFlightInfoPreparedCalled, "Expected GetFlightInfoPreparedStatement to be called when IsUpdate is false") + suite.False(suite.srv.doPutPreparedStatementUpdCalled, "Expected DoPutPreparedStatementUpdate NOT to be called when IsUpdate is false") + suite.True(suite.srv.doGetPreparedStatementCalled, "Expected DoGetPreparedStatement to be called when IsUpdate is false") + suite.Equal([]byte("SELECT 1"), suite.srv.doGetPreparedStatementHandle, "Expected DoGetPreparedStatement to receive the correct prepared statement handle") +} + +// TestIsUpdateTrue verifies that when IsUpdate is explicitly true, +// the driver calls DoPutPreparedStatementUpdate (ExecuteUpdate path) +// instead of GetFlightInfoPreparedStatement, even when ExecuteQuery is called. +func (suite *IsUpdateTests) TestIsUpdateTrue() { + stmt, err := suite.cnxn.NewStatement() + suite.Require().NoError(err) + defer validation.CheckedClose(suite.T(), stmt) + + suite.Require().NoError(stmt.SetSqlQuery("UPDATE t SET x = 1")) + suite.Require().NoError(stmt.Prepare(context.Background())) + + rdr, nrec, err := stmt.ExecuteQuery(context.Background()) + suite.Require().NoError(err) + defer rdr.Release() + + suite.EqualValues(42, nrec, "Expected ExecuteQuery to return the number of rows affected by DoPutPreparedStatementUpdate") + suite.False(rdr.Next(), "Expected empty record reader when IsUpdate is true") + suite.False(suite.srv.getFlightInfoPreparedCalled, "Expected GetFlightInfoPreparedStatement NOT to be called when IsUpdate is true") + suite.True(suite.srv.doPutPreparedStatementUpdCalled, "Expected DoPutPreparedStatementUpdate to be called when IsUpdate is true") + suite.False(suite.srv.doGetPreparedStatementCalled, "Expected DoGetPreparedStatement NOT to be called when IsUpdate is true") +} + +func TestIsUpdate(t *testing.T) { + suite.Run(t, &IsUpdateTests{}) +} diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go index 85f750a0c9..3e12bf52fd 100644 --- a/go/adbc/driver/flightsql/flightsql_statement.go +++ b/go/adbc/driver/flightsql/flightsql_statement.go @@ -514,6 +514,14 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n var header, trailer metadata.MD opts := append([]grpc.CallOption{}, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts) if s.prepared != nil { + if isUpdate := s.prepared.IsUpdate(); isUpdate != nil && *isUpdate { + nrec, err = s.prepared.ExecuteUpdate(ctx, opts...) + if err != nil { + return nil, -1, adbcFromFlightStatusWithDetails(err, header, trailer, "ExecuteUpdate") + } + rdr, err = array.NewRecordReader(arrow.NewSchema(nil, nil), nil) + return + } info, err = s.prepared.Execute(ctx, opts...) } else { info, err = s.query.execute(ctx, s.cnxn, opts...)