[Databricks] Supporting OAuth & Serverless compute #127

Status: Open · wants to merge 7 commits into base: main
5 changes: 5 additions & 0 deletions NEWS.md
@@ -7,6 +7,11 @@
* No longer install 'rpy2' by default. It will prompt the user for installation
the first time `spark_apply()` is called (#125)

* Adding support for Databricks serverless interactive compute (#127)

* Extended authentication method support for Databricks by deferring to SDK
(#127)

# pysparklyr 0.1.5

### Improvements
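Taken together with the connection changes further down, the new options might be exercised from R roughly as in the sketch below. This is only a hedged illustration based on the arguments handled in this diff (`serverless`, `profile`, `cluster_id`); the workspace URL and cluster id are placeholders, not values from the PR.

library(sparklyr)
library(pysparklyr)

# Serverless interactive compute (new in this PR); assumes OAuth can be
# completed via the Databricks CLI / SDK for this workspace.
sc <- spark_connect(
  master = "https://my-workspace.cloud.databricks.com",  # placeholder host
  method = "databricks_connect",
  serverless = TRUE
)

# Classic cluster, authenticating through a named Databricks SDK profile.
sc <- spark_connect(
  method = "databricks_connect",
  cluster_id = "0123-456789-abcdefgh",  # placeholder cluster id
  profile = "DEFAULT"
)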
91 changes: 46 additions & 45 deletions R/databricks-utils.R
@@ -27,6 +27,10 @@ databricks_host <- function(host = NULL, fail = TRUE) {
}

databricks_token <- function(token = NULL, fail = FALSE) {
# if token provided, return
# otherwise, search for token:
# DATABRICKS_TOKEN > CONNECT_DATABRICKS_TOKEN > .rs.api.getDatabricksToken

if (!is.null(token)) {
return(set_names(token, "argument"))
}
@@ -53,7 +57,6 @@ databricks_token <- function(token = NULL, fail = FALSE) {
paste0(
"No authentication token was identified: \n",
" - No 'DATABRICKS_TOKEN' environment variable found \n",
" - No Databricks OAuth token found \n",
" - Not passed as a function argument"
),
"Please add your Token to 'DATABRICKS_TOKEN' inside your .Renviron file."
@@ -66,15 +69,13 @@
}

databricks_dbr_version_name <- function(cluster_id,
host = NULL,
token = NULL,
client,
silent = FALSE) {
bullets <- NULL
version <- NULL
cluster_info <- databricks_dbr_info(
cluster_id = cluster_id,
host = host,
token = token,
client = client,
silent = silent
)
cluster_name <- substr(cluster_info$cluster_name, 1, 100)
@@ -96,8 +97,7 @@ databricks_extract_version <- function(x) {
}

databricks_dbr_info <- function(cluster_id,
host = NULL,
token = NULL,
client,
silent = FALSE) {
cli_div(theme = cli_colors())

@@ -109,10 +109,10 @@
)
}

out <- databricks_cluster_get(cluster_id, host, token)
out <- databricks_cluster_get(cluster_id, client)
if (inherits(out, "try-error")) {
sanitized <- sanitize_host(host, silent)
out <- databricks_cluster_get(cluster_id, sanitized, token)
# sanitized <- sanitize_host(host, silent)
out <- databricks_cluster_get(cluster_id, client)
}

if (inherits(out, "try-error")) {
@@ -159,30 +159,17 @@
out
}

databricks_dbr_version <- function(cluster_id,
host = NULL,
token = NULL) {
databricks_dbr_version <- function(cluster_id, client) {
vn <- databricks_dbr_version_name(
cluster_id = cluster_id,
host = host,
token = token
client = client
)
vn$version
}

databricks_cluster_get <- function(cluster_id,
host = NULL,
token = NULL) {
databricks_cluster_get <- function(cluster_id, client) {
try(
paste0(
host,
"/api/2.0/clusters/get"
) %>%
request() %>%
req_auth_bearer_token(token) %>%
req_body_json(list(cluster_id = cluster_id)) %>%
req_perform() %>%
resp_body_json(),
client$clusters$get(cluster_id = cluster_id)$as_dict(),
silent = TRUE
)
}
@@ -227,25 +214,39 @@ databricks_dbr_error <- function(error) {
)
}

sanitize_host <- function(url, silent = FALSE) {
parsed_url <- url_parse(url)
new_url <- url_parse("http://localhost")
if (is.null(parsed_url$scheme)) {
new_url$scheme <- "https"
if (!is.null(parsed_url$path) && is.null(parsed_url$hostname)) {
new_url$hostname <- parsed_url$path
}
} else {
new_url$scheme <- parsed_url$scheme
new_url$hostname <- parsed_url$hostname
# from httr2
is_hosted_session <- function() {
if (nzchar(Sys.getenv("COLAB_RELEASE_TAG"))) {
return(TRUE)
}
ret <- url_build(new_url)
if (ret != url && !silent) {
cli_div(theme = cli_colors())
cli_alert_warning(
"{.header Changing host URL to:} {.emph {ret}}"
Sys.getenv("RSTUDIO_PROGRAM_MODE") == "server" &&
!grepl("localhost", Sys.getenv("RSTUDIO_HTTP_REFERER"), fixed = TRUE)
}

databricks_desktop_login <- function(host = NULL, profile = NULL) {

# host takes priority over profile
if (!is.null(host)) {
method <- "--host"
value <- host
} else if (!is.null(profile)) {
method <- "--profile"
value <- profile
} else {
# TODO: use an rlang error?
stop("must specify `host` or `profile`; neither was set")
}

cli_path <- Sys.getenv("DATABRICKS_CLI_PATH", "databricks")
if (!is_hosted_session() && nchar(Sys.which(cli_path)) != 0) {
# When on desktop, try using the Databricks CLI for auth.
output <- suppressWarnings(
system2(
cli_path,
c("auth", "login", method, value),
stdout = TRUE,
stderr = TRUE
)
)
cli_end()
}
ret
}
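For context on the helper added above (a sketch, not part of the diff): outside a hosted session, and with the Databricks CLI available, databricks_desktop_login() shells out to the CLI to complete OAuth. A hypothetical call would look like this, with the workspace URL as a placeholder:

# Not exported by the package, hence the ':::'; assumes the 'databricks' CLI
# is on the PATH or pointed to by DATABRICKS_CLI_PATH.
pysparklyr:::databricks_desktop_login(
  host = "https://my-workspace.cloud.databricks.com"
)

# Roughly equivalent to running in a terminal:
#   databricks auth login --host https://my-workspace.cloud.databricks.com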
1 change: 1 addition & 0 deletions R/deploy.R
Expand Up @@ -52,6 +52,7 @@ deploy_databricks <- function(

cluster_id <- cluster_id %||% Sys.getenv("DATABRICKS_CLUSTER_ID")

# TODO: this needs to be adjusted to use client, might need to refactor?
if (is.null(version) && !is.null(cluster_id)) {
version <- databricks_dbr_version(
cluster_id = cluster_id,
3 changes: 2 additions & 1 deletion R/python-install.R
Expand Up @@ -217,7 +217,8 @@ install_environment <- function(
"PyArrow",
"grpcio",
"google-api-python-client",
"grpcio_status"
"grpcio_status",
"databricks-sdk"
)

if (add_torch && install_ml) {
104 changes: 71 additions & 33 deletions R/sparklyr-spark-connect.R
@@ -63,32 +63,19 @@ spark_connect_method.spark_method_databricks_connect <- function(
...) {
args <- list(...)
cluster_id <- args$cluster_id
serverless <- args$serverless %||% FALSE
profile <- args$profile %||% NULL
token <- args$token
envname <- args$envname
host_sanitize <- args$host_sanitize %||% TRUE
silent <- args$silent %||% FALSE

method <- method[[1]]

token <- databricks_token(token, fail = FALSE)

Review comment (Collaborator):
Based on your comment on line 137, I think we should remove this line, and have token populated only when the user passes it as an argument in the spark_connect() call.

Reply from author @zacdav-db, Oct 15, 2024:
My thinking for leaving this was that users explicitly setting the DATABRICKS_TOKEN and DATABRICKS_HOST vars should have those respected, since they were set explicitly. The Databricks Python SDK won't detect those when it's invoked from R.

The databricks_token function also looks for CONNECT_DATABRICKS_TOKEN, so I think it's probably important to leave that intact.

I was expecting the hierarchy to be:

  1. Explicit token
  2. DATABRICKS_TOKEN
  3. CONNECT_DATABRICKS_TOKEN
  4. .rs.api.getDatabricksToken(host)
  5. Python SDK explicit setting of profile
  6. Python SDK detection of DEFAULT profile

Where 1-4 are handled by databricks_token
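
A minimal sketch of that intended resolution order (hypothetical helper, not code from this PR; steps 5 and 6 are left to the Databricks Python SDK and only appear as comments):

# Illustrative only: mirrors the 6-step hierarchy described above.
# Steps 1-4 correspond to what databricks_token() already handles.
resolve_databricks_token <- function(token = NULL, host = NULL) {
  if (!is.null(token)) return(token)                 # 1. explicit argument
  env_token <- Sys.getenv("DATABRICKS_TOKEN")
  if (nzchar(env_token)) return(env_token)           # 2. DATABRICKS_TOKEN
  env_token <- Sys.getenv("CONNECT_DATABRICKS_TOKEN")
  if (nzchar(env_token)) return(env_token)           # 3. CONNECT_DATABRICKS_TOKEN
  if (exists(".rs.api.getDatabricksToken")) {        # 4. Workbench-managed credentials
    return(get(".rs.api.getDatabricksToken")(host))
  }
  # 5-6. nothing found in R: return NULL and let the Databricks SDK resolve
  # auth itself (explicit profile first, then the DEFAULT profile).
  NULL
}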

cluster_id <- cluster_id %||% Sys.getenv("DATABRICKS_CLUSTER_ID")
master <- databricks_host(master, fail = FALSE)
if (host_sanitize && master != "") {
master <- sanitize_host(master, silent)
}

cluster_info <- NULL
if (cluster_id != "" && master != "" && token != "") {
cluster_info <- databricks_dbr_version_name(
cluster_id = cluster_id,
host = master,
token = token,
silent = silent
)
if (is.null(version)) {
version <- cluster_info$version
}
}

# load python env
envname <- use_envname(
backend = "databricks",
version = version,
@@ -102,34 +89,80 @@
return(invisible)
}

db <- import_check("databricks.connect", envname, silent)
# load python libs
dbc <- import_check("databricks.connect", envname, silent)
db_sdk <- import_check("databricks.sdk", envname, silent = TRUE)

# SDK behaviour
# https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html#default-authentication-flow

conf_args <- list()

# the profile as specified - which has a default of 'DEFAULT'
# otherwise, if a token is found, propagate to SDK config

# TODO: emit messages about connection here?
# specific vars take priority; profile is only used when no env vars are set
if (token != "" && master != "") {
conf_args$host <- master
conf_args$token <- token
conf_args$auth_type <- "pat"
databricks_desktop_login(host = master)
} else if (!is.null(profile)) {
conf_args$profile <- profile
databricks_desktop_login(profile = profile)
}

# serverless config related settings
if (serverless) {
conf_args$serverless_compute_id <- "auto"
} else {
conf_args$cluster_id <- cluster_id
}

sdk_config <- db_sdk$core$Config(!!!conf_args)

# create workspace client
sdk_client <- db_sdk$WorkspaceClient(config = sdk_config)

# if serverless is TRUE, cluster_id is overruled (set to NULL)
cluster_info <- NULL
if (!serverless) {
if (cluster_id != "" && master != "" && token != "") {
cluster_info <- databricks_dbr_version_name(
cluster_id = cluster_id,
client = sdk_client,
silent = silent
)
if (is.null(version)) {
version <- cluster_info$version
}
}
} else {
cluster_id <- NULL
}

if (!is.null(cluster_info)) {
msg <- "{.header Connecting to} {.emph '{cluster_info$name}'}"
msg_done <- "{.header Connected to:} {.emph '{cluster_info$name}'}"
master_label <- glue("{cluster_info$name} ({cluster_id})")
} else {
} else if (!serverless) {
msg <- "{.header Connecting to} {.emph '{cluster_id}'}"
msg_done <- "{.header Connected to:} '{.emph '{cluster_id}'}'"
master_label <- glue("Databricks Connect - Cluster: {cluster_id}")
} else if (serverless) {
msg <- "{.header Connecting to} {.emph serverless}"
msg_done <- "{.header Connected to:} '{.emph serverless}'"
master_label <- glue("Databricks Connect - Cluster: serverless")
}

if (!silent) {
cli_div(theme = cli_colors())
cli_progress_step(msg, msg_done)
}

remote_args <- list()
if (master != "") remote_args$host <- master
if (token != "") remote_args$token <- token
if (cluster_id != "") remote_args$cluster_id <- cluster_id

databricks_session <- function(...) {
user_agent <- build_user_agent()
db$DatabricksSession$builder$remote(...)$userAgent(user_agent)
}

conn <- exec(databricks_session, !!!remote_args)
user_agent <- build_user_agent()
conn <- dbc$DatabricksSession$builder$sdkConfig(sdk_config)$userAgent(user_agent)

if (!silent) {
cli_progress_done()
@@ -141,6 +174,7 @@
master_label = master_label,
con_class = "connect_databricks",
cluster_id = cluster_id,
serverless = serverless,
method = method,
config = config
)
@@ -151,6 +185,7 @@ initialize_connection <- function(
master_label,
con_class,
cluster_id = NULL,
serverless = NULL,
method = NULL,
config = NULL) {
warnings <- import("warnings")
@@ -173,12 +208,15 @@
"ignore",
message = "Index.format is deprecated and will be removed in a future version"
)

session <- conn$getOrCreate()
get_version <- try(session$version, silent = TRUE)
if (inherits(get_version, "try-error")) databricks_dbr_error(get_version)
session$conf$set("spark.sql.session.localRelationCacheThreshold", 1048576L)
session$conf$set("spark.sql.execution.arrow.pyspark.enabled", "true")
session$conf$set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")
if (!serverless) {
session$conf$set("spark.sql.session.localRelationCacheThreshold", 1048576L)
session$conf$set("spark.sql.execution.arrow.pyspark.enabled", "true")
session$conf$set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")
}

# do we need this `spark_context` object?
spark_context <- list(spark_context = session)