Skip to content

Commit

Permalink
Add Flight SQL support for prepared statement APIs without parameter …
Browse files Browse the repository at this point in the history
…binding
  • Loading branch information
geoffxy committed Jan 31, 2025
1 parent 01401a3 commit 8ce9613
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 29 deletions.
142 changes: 113 additions & 29 deletions cpp/server/brad_server_simple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,120 @@ BradFlightSqlServer::GetFlightInfoStatement(
const ServerCallContext& context, const StatementQuery& command,
const FlightDescriptor& descriptor) {
const std::string& query = command.query;
const std::string& transaction_id = command.transaction_id;
return GetFlightInfoImpl(query, transaction_id, descriptor);
}

arrow::Result<std::unique_ptr<FlightDataStream>>
BradFlightSqlServer::DoGetStatement(const ServerCallContext& context,
const StatementQueryTicket& command) {
ARROW_ASSIGN_OR_RAISE(auto pair,
DecodeTransactionQuery(command.statement_handle));
const std::string& autoincrement_id = pair.first;
const std::string transaction_id = pair.second;

const std::string& autoincrement_id = std::to_string(++autoincrement_id_);
const std::string& query_ticket =
GetQueryTicket(autoincrement_id, command.transaction_id);
const std::string& query_ticket = transaction_id + ':' + autoincrement_id;

std::shared_ptr<BradStatement> result;
const bool found = query_data_.erase_fn(query_ticket, [&result](auto& qr) {
result = qr;
return true;
});

if (!found) {
return arrow::Status::Invalid("Invalid ticket.");
}

std::shared_ptr<BradStatementBatchReader> reader;
ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(result));

return std::make_unique<RecordBatchStream>(reader);
}

arrow::Result<arrow::flight::sql::ActionCreatePreparedStatementResult>
BradFlightSqlServer::CreatePreparedStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::ActionCreatePreparedStatementRequest& request) {
const auto id = std::to_string(++autoincrement_id_);
const PreparedStatementContext statement_context{request.query,
request.transaction_id};
prepared_statements_.insert(id, statement_context);
// std::cerr << "Registered prepared statement " << id << " " << request.query
// << std::endl;
return arrow::flight::sql::ActionCreatePreparedStatementResult{nullptr,
nullptr, id};
}

arrow::Status BradFlightSqlServer::ClosePreparedStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::ActionClosePreparedStatementRequest& request) {
// std::cerr << "ClosePreparedStatement called "
// << request.prepared_statement_handle << std::endl;
const bool erased =
prepared_statements_.erase(request.prepared_statement_handle);
if (!erased) {
return arrow::Status::Invalid("Invalid prepared statement handle.");
}
return arrow::Status();
}

arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>>
BradFlightSqlServer::GetFlightInfoPreparedStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementQuery& command,
const arrow::flight::FlightDescriptor& descriptor) {
// std::cerr << "GetFlightInfoPreparedStatement called "
// << command.prepared_statement_handle << std::endl;
const PreparedStatementContext* statement_ctx = nullptr;
prepared_statements_.find_fn(
command.prepared_statement_handle,
[&statement_ctx](const auto& ps_ctx) { statement_ctx = &ps_ctx; });
if (statement_ctx == nullptr) {
return arrow::Status::Invalid("Invalid prepared statement handle.");
}

const std::string& query = statement_ctx->query;
const std::string& transaction_id = statement_ctx->transaction_id;
return GetFlightInfoImpl(query, transaction_id, descriptor);
}

// Currently unimplemented.

arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>>
BradFlightSqlServer::DoGetPreparedStatement(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementQuery& command) {
std::cerr << "DoGetPreparedStatement called "
<< command.prepared_statement_handle << std::endl;
return arrow::Result<std::unique_ptr<arrow::flight::FlightDataStream>>();
}

arrow::Status BradFlightSqlServer::DoPutPreparedStatementQuery(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementQuery& command,
arrow::flight::FlightMessageReader* reader,
arrow::flight::FlightMetadataWriter* writer) {
std::cerr << "DoPutPreparedStatementQuery called "
<< command.prepared_statement_handle << std::endl;
return arrow::Status();
}

arrow::Result<int64_t> BradFlightSqlServer::DoPutPreparedStatementUpdate(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementUpdate& command,
arrow::flight::FlightMessageReader* reader) {
std::cerr << "DoPutPreparedStatementUpdate called "
<< command.prepared_statement_handle << std::endl;
return arrow::Result<int64_t>();
}

arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>>
BradFlightSqlServer::GetFlightInfoImpl(const std::string& query,
const std::string& transaction_id,
const FlightDescriptor& descriptor) {
const std::string autoincrement_id = std::to_string(++autoincrement_id_);
const std::string query_ticket =
GetQueryTicket(autoincrement_id, transaction_id);
ARROW_ASSIGN_OR_RAISE(auto ticket, EncodeTransactionQuery(query_ticket));

std::shared_ptr<arrow::Schema> result_schema;
Expand Down Expand Up @@ -240,30 +350,4 @@ BradFlightSqlServer::GetFlightInfoStatement(
return std::make_unique<FlightInfo>(result);
}

arrow::Result<std::unique_ptr<FlightDataStream>>
BradFlightSqlServer::DoGetStatement(const ServerCallContext& context,
const StatementQueryTicket& command) {
ARROW_ASSIGN_OR_RAISE(auto pair,
DecodeTransactionQuery(command.statement_handle));
const std::string& autoincrement_id = pair.first;
const std::string transaction_id = pair.second;

const std::string& query_ticket = transaction_id + ':' + autoincrement_id;

std::shared_ptr<BradStatement> result;
const bool found = query_data_.erase_fn(query_ticket, [&result](auto& qr) {
result = qr;
return true;
});

if (!found) {
return arrow::Status::Invalid("Invalid ticket.");
}

std::shared_ptr<BradStatementBatchReader> reader;
ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(result));

return std::make_unique<RecordBatchStream>(reader);
}

} // namespace brad
26 changes: 26 additions & 0 deletions cpp/server/brad_server_simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,37 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase {
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementQuery& command) override;

// Currently unimplemented.

// Bind params.
arrow::Status DoPutPreparedStatementQuery(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementQuery& command,
arrow::flight::FlightMessageReader* reader,
arrow::flight::FlightMetadataWriter* writer) override;

// Update the prepared statement.
arrow::Result<int64_t> DoPutPreparedStatementUpdate(
const arrow::flight::ServerCallContext& context,
const arrow::flight::sql::PreparedStatementUpdate& command,
arrow::flight::FlightMessageReader* reader) override;

private:
arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> GetFlightInfoImpl(
const std::string& query, const std::string& transaction_id,
const arrow::flight::FlightDescriptor& descriptor);

struct PreparedStatementContext {
std::string query;
std::string transaction_id;
};

PythonRunQueryFn handle_query_;

libcuckoo::cuckoohash_map<std::string, std::shared_ptr<BradStatement>>
query_data_;
libcuckoo::cuckoohash_map<std::string, PreparedStatementContext>
prepared_statements_;

std::atomic<uint64_t> autoincrement_id_;
};
Expand Down

0 comments on commit 8ce9613

Please sign in to comment.