|
| 1 | +#' Generate embeddings using Snowflake |
| 2 | +#' |
| 3 | +#' Uses the [Cortex API `EMBED`](https://docs.snowflake.com/en/release-notes/2025/other/2025-04-14-cortex-offers-embed-rest-api) |
| 4 | +#' functions to generate embeddings. |
| 5 | +#' |
| 6 | +#' See [complete documentation](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-rest-api#label-cortex-llm-embed-function). |
| 7 | +#' |
| 8 | +#' @section Authentication: |
| 9 | +#' |
| 10 | +#' - a *Programmatic Access Token* (PAT) defined via the SNOWFLAKE_PAT environment variable. |
| 11 | +#' - A static OAuth token defined via the SNOWFLAKE_TOKEN environment variable. |
| 12 | +#' - Key-pair authentication credentials defined via the SNOWFLAKE_USER and SNOWFLAKE_PRIVATE_KEY (which can be a PEM-encoded private key or a path to one) environment variables. |
| 13 | +#' - Posit Workbench-managed Snowflake credentials for the corresponding account. |
| 14 | +#' - Viewer-based credentials on Posit Connect. Requires the connectcreds package. |
| 15 | +#' |
| 16 | +#' @inheritParams embed_ollama |
| 17 | +#' @inheritParams ellmer::chat_snowflake |
| 18 | +#' @export |
| 19 | +embed_snowflake <- function( |
| 20 | + x, |
| 21 | + account = snowflake_account(), |
| 22 | + credentials = NULL, |
| 23 | + model = "snowflake-arctic-embed-m-v1.5", |
| 24 | + api_args = list(), |
| 25 | + batch_size = 512L |
| 26 | +) { |
| 27 | + if (is.data.frame(x)) { |
| 28 | + x[["embedding"]] <- Recall( |
| 29 | + x[["text"]], |
| 30 | + account = account, |
| 31 | + credentials = credentials, |
| 32 | + model = model, |
| 33 | + api_args = api_args, |
| 34 | + batch_size = batch_size |
| 35 | + ) |
| 36 | + return(x) |
| 37 | + } |
| 38 | + |
| 39 | + text <- x |
| 40 | + check_character(text) |
| 41 | + if (!length(text)) { |
| 42 | + # ideally we'd return a 0-row matrix, but currently the correct |
| 43 | + # embedding_size is not convenient to access in this context |
| 44 | + return(NULL) |
| 45 | + } |
| 46 | + check_string(model, allow_empty = FALSE) |
| 47 | + if (!is.list(api_args)) { |
| 48 | + cli::cli_abort("`api_args` must be a list.") |
| 49 | + } |
| 50 | + |
| 51 | + auth_headers <- function() { |
| 52 | + if (is.null(credentials)) { |
| 53 | + default_snowflake_credentials(account)() |
| 54 | + } else if (is.function(credentials)) { |
| 55 | + credentials() |
| 56 | + } else { |
| 57 | + credentials |
| 58 | + } |
| 59 | + } |
| 60 | + |
| 61 | + headers <- auth_headers() |
| 62 | + |
| 63 | + base_req <- account |> |
| 64 | + snowflake_url() |> |
| 65 | + httr2::request() |> |
| 66 | + embed_req_retry() |> |
| 67 | + httr2::req_url_path_append("api/v2/cortex/inference:embed") |> |
| 68 | + httr2::req_headers_redacted(!!!headers) |> |
| 69 | + httr2::req_headers( |
| 70 | + "Content-Type" = "application/json", |
| 71 | + "Accept" = "application/json" |
| 72 | + ) |> |
| 73 | + httr2::req_user_agent(ragnar_user_agent()) |> |
| 74 | + httr2::req_error(body = function(resp) { |
| 75 | + tryCatch({ |
| 76 | + json <- httr2::resp_body_json(resp, check_type = FALSE) |
| 77 | + json$message |
| 78 | + }, error = function(e) { |
| 79 | + "Unknown error" |
| 80 | + }) |
| 81 | + }) |
| 82 | + |
| 83 | + out <- vector("list", length(text)) |
| 84 | + base_body <- rlang::list2(model = model, !!!api_args) |
| 85 | + |
| 86 | + for (indices in chunk_list(seq_along(text), batch_size)) { |
| 87 | + body <- base_body |
| 88 | + body$text <- as.list(text[indices]) |
| 89 | + |
| 90 | + resp <- base_req |> |
| 91 | + httr2::req_body_json(body) |> |
| 92 | + httr2::req_perform() |> |
| 93 | + httr2::resp_body_json() |
| 94 | + |
| 95 | + out[indices] <- lapply(resp$data, \(x) x$embedding) |
| 96 | + } |
| 97 | + |
| 98 | + matrix(unlist(out), nrow = length(text), byrow = TRUE) |
| 99 | +} |
| 100 | + |
| 101 | +# Snowflake utilities copied from ellmer ------------ |
| 102 | +# Handling Snowflake credentials and authentication in workbench + multiple other scnearios. |
| 103 | + |
| 104 | +snowflake_account <- function() { |
| 105 | + val <- Sys.getenv("SNOWFLAKE_ACCOUNT") |
| 106 | + if (!identical(val, "")) { |
| 107 | + val |
| 108 | + } else { |
| 109 | + cli::cli_abort("SNOWFLAKE_ACCOUNT environment variable is not set.") |
| 110 | + } |
| 111 | +} |
| 112 | + |
| 113 | +snowflake_url <- function(account) { |
| 114 | + paste0("https://", account, ".snowflakecomputing.com") |
| 115 | +} |
| 116 | + |
| 117 | +default_snowflake_credentials <- function(account = snowflake_account()) { |
| 118 | + # Detect viewer-based credentials from Posit Connect. |
| 119 | + url <- snowflake_url(account) |
| 120 | + if (is_installed("connectcreds") && connectcreds::has_viewer_token(url)) { |
| 121 | + return(function() { |
| 122 | + token <- connectcreds::connect_viewer_token(url) |
| 123 | + list( |
| 124 | + Authorization = paste("Bearer", token$access_token), |
| 125 | + `X-Snowflake-Authorization-Token-Type` = "OAUTH" |
| 126 | + ) |
| 127 | + }) |
| 128 | + } |
| 129 | + |
| 130 | + token <- Sys.getenv("SNOWFLAKE_TOKEN") |
| 131 | + if (nchar(token) != 0) { |
| 132 | + return(function() { |
| 133 | + list( |
| 134 | + Authorization = paste("Bearer", token), |
| 135 | + # See: https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/authentication#using-oauth |
| 136 | + `X-Snowflake-Authorization-Token-Type` = "OAUTH" |
| 137 | + ) |
| 138 | + }) |
| 139 | + } |
| 140 | + |
| 141 | + token <- Sys.getenv("SNOWFLAKE_PAT") |
| 142 | + if (nchar(token) != 0) { |
| 143 | + return(function() { |
| 144 | + list( |
| 145 | + Authorization = paste("Bearer", token), |
| 146 | + # See https://docs.snowflake.com/en/user-guide/programmatic-access-tokens |
| 147 | + `X-Snowflake-Authorization-Token-Type` = "PROGRAMMATIC_ACCESS_TOKEN" |
| 148 | + ) |
| 149 | + }) |
| 150 | + } |
| 151 | + |
| 152 | + # Support for Snowflake key-pair authentication. |
| 153 | + # See: https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/authentication#generate-a-jwt-token |
| 154 | + user <- Sys.getenv("SNOWFLAKE_USER") |
| 155 | + private_key <- Sys.getenv("SNOWFLAKE_PRIVATE_KEY") |
| 156 | + if (nchar(user) != 0 && nchar(private_key) != 0) { |
| 157 | + check_installed(c("jose", "openssl"), "for key-pair authentication") |
| 158 | + key <- openssl::read_key(private_key) |
| 159 | + return(function() { |
| 160 | + token <- snowflake_keypair_token(account, user, key) |
| 161 | + list( |
| 162 | + Authorization = paste("Bearer", token), |
| 163 | + `X-Snowflake-Authorization-Token-Type` = "KEYPAIR_JWT" |
| 164 | + ) |
| 165 | + }) |
| 166 | + } |
| 167 | + |
| 168 | + # Check for Workbench-managed credentials. |
| 169 | + sf_home <- Sys.getenv("SNOWFLAKE_HOME") |
| 170 | + if (grepl("posit-workbench", sf_home, fixed = TRUE)) { |
| 171 | + token <- workbench_snowflake_token(account, sf_home) |
| 172 | + if (!is.null(token)) { |
| 173 | + return(function() { |
| 174 | + # Ensure we get an up-to-date token. |
| 175 | + token <- workbench_snowflake_token(account, sf_home) |
| 176 | + list( |
| 177 | + Authorization = paste("Bearer", token), |
| 178 | + `X-Snowflake-Authorization-Token-Type` = "OAUTH" |
| 179 | + ) |
| 180 | + }) |
| 181 | + } |
| 182 | + } |
| 183 | + |
| 184 | + cli::cli_abort("No Snowflake credentials are available.") |
| 185 | +} |
| 186 | + |
| 187 | +snowflake_keypair_token <- function( |
| 188 | + account, |
| 189 | + user, |
| 190 | + key, |
| 191 | + cache = snowflake_keypair_cache(account, key), |
| 192 | + lifetime = 600L, |
| 193 | + reauth = FALSE |
| 194 | +) { |
| 195 | + # Producing a signed JWT is a fairly expensive operation (in the order of |
| 196 | + # ~10ms), but adding a cache speeds this up approximately 500x. |
| 197 | + creds <- cache$get() |
| 198 | + if (reauth || is.null(creds) || creds$expiry < Sys.time()) { |
| 199 | + cache$clear() |
| 200 | + expiry <- Sys.time() + lifetime |
| 201 | + # We can't use openssl::fingerprint() here because it uses a different |
| 202 | + # algorithm. |
| 203 | + fp <- openssl::base64_encode( |
| 204 | + openssl::sha256(openssl::write_der(key$pubkey)) |
| 205 | + ) |
| 206 | + if (grepl(".+\\.privatelink$", account)) { |
| 207 | + # account identifier is everything up to the first period |
| 208 | + account <- gsub("^([^.]*).+", "\\1", account) |
| 209 | + } |
| 210 | + sub <- toupper(paste0(account, ".", user)) |
| 211 | + iss <- paste0(sub, ".SHA256:", fp) |
| 212 | + # Note: Snowflake employs a malformed issuer claim, so we have to inject it |
| 213 | + # manually after jose's validation phase. |
| 214 | + claim <- jose::jwt_claim("dummy", sub, exp = as.integer(expiry)) |
| 215 | + claim$iss <- iss |
| 216 | + creds <- list(expiry = expiry, token = jose::jwt_encode_sig(claim, key)) |
| 217 | + cache$set(creds) |
| 218 | + } |
| 219 | + creds$token |
| 220 | +} |
| 221 | + |
| 222 | +snowflake_keypair_cache <- function(account, key) { |
| 223 | + credentials_cache(key = hash(c("sf", account, openssl::fingerprint(key)))) |
| 224 | +} |
| 225 | + |
| 226 | +snowflake_credentials_exist <- function(...) { |
| 227 | + tryCatch( |
| 228 | + is_list(default_snowflake_credentials(...)), |
| 229 | + error = function(e) FALSE |
| 230 | + ) |
| 231 | +} |
| 232 | + |
| 233 | +# Reads Posit Workbench-managed Snowflake credentials from a |
| 234 | +# $SNOWFLAKE_HOME/connections.toml file, as used by the Snowflake Connector for |
| 235 | +# Python implementation. The file will look as follows: |
| 236 | +# |
| 237 | +# [workbench] |
| 238 | +# account = "account-id" |
| 239 | +# token = "token" |
| 240 | +# authenticator = "oauth" |
| 241 | +workbench_snowflake_token <- function(account, sf_home) { |
| 242 | + cfg <- readLines(file.path(sf_home, "connections.toml")) |
| 243 | + # We don't attempt a full parse of the TOML syntax, instead relying on the |
| 244 | + # fact that this file will always contain only one section. |
| 245 | + if (!any(grepl(account, cfg, fixed = TRUE))) { |
| 246 | + # The configuration doesn't actually apply to this account. |
| 247 | + return(NULL) |
| 248 | + } |
| 249 | + line <- grepl("token = ", cfg, fixed = TRUE) |
| 250 | + token <- gsub("token = ", "", cfg[line]) |
| 251 | + if (nchar(token) == 0) { |
| 252 | + return(NULL) |
| 253 | + } |
| 254 | + # Drop enclosing quotes. |
| 255 | + gsub("\"", "", token) |
| 256 | +} |
0 commit comments