Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/query timeout option #517

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
12 changes: 6 additions & 6 deletions R/Connection.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ setMethod(
#' @export
setMethod(
"dbSendQuery", c("OdbcConnection", "character"),
function(conn, statement, params = NULL, ..., immediate = FALSE) {
res <- OdbcResult(connection = conn, statement = statement, params = params, immediate = immediate)
function(conn, statement, params = NULL, ..., immediate = FALSE, query_timeout = 0) {
res <- OdbcResult(connection = conn, statement = statement, params = params, immediate = immediate, query_timeout = query_timeout)
res
})

Expand All @@ -222,8 +222,8 @@ setMethod(
#' @export
setMethod(
"dbSendStatement", c("OdbcConnection", "character"),
function(conn, statement, params = NULL, ..., immediate = FALSE) {
res <- OdbcResult(connection = conn, statement = statement, params = params, immediate = immediate)
function(conn, statement, params = NULL, ..., immediate = FALSE, query_timeout = 0) {
res <- OdbcResult(connection = conn, statement = statement, params = params, immediate = immediate, query_timeout = query_timeout)
res
})

Expand Down Expand Up @@ -351,8 +351,8 @@ setMethod(
#' @inheritParams DBI::dbFetch
#' @export
setMethod("dbGetQuery", signature("OdbcConnection", "character"),
function(conn, statement, n = -1, params = NULL, ...) {
rs <- dbSendQuery(conn, statement, params = params, ...)
function(conn, statement, n = -1, params = NULL, query_timeout = 0, ...) {
rs <- dbSendQuery(conn, statement, params = params, query_timeout = query_timeout, ...)
on.exit(dbClearResult(rs))

df <- dbFetch(rs, n = n, ...)
Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ result_completed <- function(r) {
.Call(`_odbc_result_completed`, r)
}

new_result <- function(p, sql, immediate) {
.Call(`_odbc_new_result`, p, sql, immediate)
new_result <- function(p, sql, immediate, query_timeout) {
.Call(`_odbc_new_result`, p, sql, immediate, query_timeout)
}

result_fetch <- function(r, n_max = -1L) {
Expand Down
4 changes: 2 additions & 2 deletions R/Result.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ NULL
#' @docType methods
NULL

OdbcResult <- function(connection, statement, params = NULL, immediate = FALSE) {
OdbcResult <- function(connection, statement, params = NULL, immediate = FALSE, query_timeout = 0L) {
shrektan marked this conversation as resolved.
Show resolved Hide resolved
if (nzchar(connection@encoding)) {
statement <- enc2iconv(statement, connection@encoding)
}
ptr <- new_result(connection@ptr, statement, immediate)
ptr <- new_result(connection@ptr, statement, immediate, query_timeout)
res <- new("OdbcResult", connection = connection, statement = statement, ptr = ptr)

if (!is.null(params)) {
Expand Down
16 changes: 16 additions & 0 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,22 @@ first_100 <- dbFetch(result, n = 100)
rest <- dbFetch(result)
```

`dbGetQuery()`, `dbSendStatement()` and `dbSendQuery()` provide a `query_timeout` option (internally sets `SQL_ATTR_QUERY_TIMEOUT`).

`query_timeout`: The number in seconds before query timeout (can be used to stop long running queries). Default is 0 indicating no timeout.

``` r
# long running query stops after 60 seconds
# provided statement is a placeholder

dbGetQuery(con, "SELECT flight, tailnum, origin FROM flights ORDER BY origin", query_timeout = 60)

dbSendStatement(con, "SELECT flight, tailnum, origin FROM flights ORDER BY origin", query_timeout = 60)

dbSendQuery(con, "SELECT flight, tailnum, origin FROM flights ORDER BY origin", query_timeout = 60)

```

## Benchmarks

The *odbc* package is often much faster than the existing
Expand Down
9 changes: 5 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,16 @@ BEGIN_RCPP
END_RCPP
}
// new_result
result_ptr new_result(connection_ptr const& p, std::string const& sql, const bool immediate);
RcppExport SEXP _odbc_new_result(SEXP pSEXP, SEXP sqlSEXP, SEXP immediateSEXP) {
result_ptr new_result(connection_ptr const& p, std::string const& sql, const bool immediate, long query_timeout);
RcppExport SEXP _odbc_new_result(SEXP pSEXP, SEXP sqlSEXP, SEXP immediateSEXP, SEXP query_timeoutSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< connection_ptr const& >::type p(pSEXP);
Rcpp::traits::input_parameter< std::string const& >::type sql(sqlSEXP);
Rcpp::traits::input_parameter< const bool >::type immediate(immediateSEXP);
rcpp_result_gen = Rcpp::wrap(new_result(p, sql, immediate));
Rcpp::traits::input_parameter< long >::type query_timeout(query_timeoutSEXP);
rcpp_result_gen = Rcpp::wrap(new_result(p, sql, immediate, query_timeout));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -341,7 +342,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_odbc_result_release", (DL_FUNC) &_odbc_result_release, 1},
{"_odbc_result_active", (DL_FUNC) &_odbc_result_active, 1},
{"_odbc_result_completed", (DL_FUNC) &_odbc_result_completed, 1},
{"_odbc_new_result", (DL_FUNC) &_odbc_new_result, 3},
{"_odbc_new_result", (DL_FUNC) &_odbc_new_result, 4},
{"_odbc_result_fetch", (DL_FUNC) &_odbc_result_fetch, 2},
{"_odbc_result_column_info", (DL_FUNC) &_odbc_result_column_info, 1},
{"_odbc_result_bind", (DL_FUNC) &_odbc_result_bind, 3},
Expand Down
13 changes: 7 additions & 6 deletions src/odbc_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,21 @@
namespace odbc {

odbc_result::odbc_result(
std::shared_ptr<odbc_connection> c, std::string sql, bool immediate)
std::shared_ptr<odbc_connection> c, std::string sql, bool immediate, long query_timeout)
: c_(c),
sql_(sql),
rows_fetched_(0),
num_columns_(0),
complete_(0),
bound_(false),
output_encoder_(Iconv(c_->encoding(), "UTF-8")) {
output_encoder_(Iconv(c_->encoding(), "UTF-8")),
query_timeout_(query_timeout){
shrektan marked this conversation as resolved.
Show resolved Hide resolved

if (immediate) {
s_ = std::make_shared<nanodbc::statement>();
bound_ = true;
r_ = std::make_shared<nanodbc::result>(
s_->execute_direct(*c_->connection(), sql_));
s_->execute_direct(*c_->connection(), sql_, query_timeout_));
num_columns_ = r_->columns();
c_->set_current_result(this);
} else {
Expand All @@ -42,12 +43,12 @@ std::shared_ptr<nanodbc::result> odbc_result::result() const {
return std::shared_ptr<nanodbc::result>(r_);
}
void odbc_result::prepare() {
s_ = std::make_shared<nanodbc::statement>(*c_->connection(), sql_);
s_ = std::make_shared<nanodbc::statement>(*c_->connection(), sql_, query_timeout_);
}
void odbc_result::execute() {
if (!r_) {
try {
r_ = std::make_shared<nanodbc::result>(s_->execute());
r_ = std::make_shared<nanodbc::result>(s_->execute(1L, query_timeout_));
num_columns_ = r_->columns();
} catch (const nanodbc::database_error& e) {
c_->set_current_result(nullptr);
Expand Down Expand Up @@ -151,7 +152,7 @@ void odbc_result::bind_list(
for (short col = 0; col < ncols; ++col) {
bind_columns(*s_, types[col], x, col, start, size);
}
r_ = std::make_shared<nanodbc::result>(nanodbc::execute(*s_, size));
r_ = std::make_shared<nanodbc::result>(s_->execute(size, query_timeout_));
num_columns_ = r_->columns();
start += batch_rows;

Expand Down
3 changes: 2 additions & 1 deletion src/odbc_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class odbc_error : public Rcpp::exception {
class odbc_result {
public:
odbc_result(
std::shared_ptr<odbc_connection> c, std::string sql, bool immediate);
std::shared_ptr<odbc_connection> c, std::string sql, bool immediate, long query_timeout);
std::shared_ptr<odbc_connection> connection() const;
std::shared_ptr<nanodbc::statement> statement() const;
std::shared_ptr<nanodbc::result> result() const;
Expand Down Expand Up @@ -62,6 +62,7 @@ class odbc_result {
bool complete_;
bool bound_;
Iconv output_encoder_;
long query_timeout_;

std::map<short, std::vector<std::string>> strings_;
std::map<short, std::vector<std::vector<uint8_t>>> raws_;
Expand Down
4 changes: 2 additions & 2 deletions src/result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ bool result_completed(result_ptr const& r) { return r->complete(); }

// [[Rcpp::export]]
result_ptr new_result(
connection_ptr const& p, std::string const& sql, const bool immediate) {
return result_ptr(new odbc::odbc_result(*p, sql, immediate));
connection_ptr const& p, std::string const& sql, const bool immediate, long query_timeout) {
return result_ptr(new odbc::odbc_result(*p, sql, immediate, query_timeout));
}

// [[Rcpp::export]]
Expand Down
78 changes: 78 additions & 0 deletions tests/testthat/test-SQLServer.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,82 @@ test_that("SQLServer", {
res <- dbGetQuery(con, "SELECT CAST(? AS date)", params = as.Date("2019-01-01"))
expect_equal(res[[1]], as.Date("2019-01-01"))
})

test_that("query_timeout - dbGetQuery - short query", {
con <- DBItest:::connect(DBItest:::get_default_context())
res <- dbGetQuery(con, "WaitFor Delay '00:00:02'; SELECT 'HELLO' as world", query_timeout = 3)
expect_equal(res[[1]], "HELLO")
})

test_that("query_timeout - dbGetQuery - short query with parameters", {
con <- DBItest:::connect(DBItest:::get_default_context())
par <- data.frame('HELLO')
res <- dbGetQuery(con, "WaitFor Delay '00:00:02'; SELECT ? as world", params = par, query_timeout = 3)
expect_equal(res[[1]], "HELLO")
})

test_that("query_timeout - dbGetQuery - long running query", {
con <- DBItest:::connect(DBItest:::get_default_context())
expect_error(dbGetQuery(con, "WaitFor Delay '00:00:02'; SELECT 'HELLO' as world", query_timeout = 1), "timeout expired")
})

test_that("query_timeout - dbGetQuery - long running query with parameters", {
con <- DBItest:::connect(DBItest:::get_default_context())
par <- data.frame('HELLO')
expect_error(dbGetQuery(con, "WaitFor Delay '00:00:02'; SELECT ? as world", params = par, query_timeout = 1), "timeout expired")
})

test_that("query_timeout - dbSendQuery - short query", {
con <- DBItest:::connect(DBItest:::get_default_context())
res <- dbSendQuery(con, "WaitFor Delay '00:00:02'; SELECT 'HELLO' as world", query_timeout = 3)
result <- dbFetch(res)
expect_equal(result[[1]], "HELLO")
})

test_that("query_timeout - dbSendQuery - short query with parameters", {
con <- DBItest:::connect(DBItest:::get_default_context())
par <- data.frame('HELLO')
res <- dbSendQuery(con, "WaitFor Delay '00:00:02'; SELECT ? as world", params = par, query_timeout = 3)
result <- dbFetch(res)
expect_equal(result[[1]], "HELLO")
})

test_that("query_timeout - dbSendQuery - long running query", {
con <- DBItest:::connect(DBItest:::get_default_context())
expect_error(dbSendQuery(con, "WaitFor Delay '00:00:02'; SELECT 'HELLO' as world", query_timeout = 1), "timeout expired")
})

test_that("query_timeout - dbSendQuery - long running query with parameters", {
con <- DBItest:::connect(DBItest:::get_default_context())
par <- data.frame('HELLO')
expect_error(dbSendQuery(con, "WaitFor Delay '00:00:02'; SELECT ? as world", params = par, query_timeout = 1), "timeout expired")
})

test_that("query_timeout - dbSendStatement - short query", {
con <- DBItest:::connect(DBItest:::get_default_context())
res <- dbSendStatement(con, "WaitFor Delay '00:00:02';", query_timeout = 3)
result <- dbGetRowsAffected(res)
# if the test reaches this line of code, the query is not stopped (as expected)
expect_equal(result, 0)
})

test_that("query_timeout - dbSendStatement - short query with parameters", {
con <- DBItest:::connect(DBItest:::get_default_context())
par <- data.frame('HELLO')
res <- dbSendStatement(con, "WaitFor Delay '00:00:02'; SELECT ? as world", params = par, query_timeout = 3)
result <- dbGetRowsAffected(res)
# if the test reaches this line of code, the query is not stopped (as expected)
expect_equal(result, 0)
})

test_that("query_timeout - dbSendStatement - long running query", {
con <- DBItest:::connect(DBItest:::get_default_context())
expect_error(dbSendStatement(con, "WaitFor Delay '00:00:02';", query_timeout = 1), "timeout expired")
})

test_that("query_timeout - dbSendStatement - long running query with parameters", {
con <- DBItest:::connect(DBItest:::get_default_context())
par <- data.frame('HELLO')
expect_error(dbSendStatement(con, "WaitFor Delay '00:00:02'; SELECT ? as world", params = par, query_timeout = 1), "timeout expired")
})
})