Skip to content

Commit 5d85a4d

Browse files
Add snowflake support (#148)
* add snowflake support * two tweaks for correct results with PAT * actually keep the static oauth and add snowflake_pat support * fix recall * Better surface error messages * update docs * Add NEWS --------- Co-authored-by: Tomasz Kalinowski <[email protected]> Co-authored-by: Tomasz Kalinowski <[email protected]>
1 parent ea011f0 commit 5d85a4d

File tree

6 files changed

+364
-1
lines changed

6 files changed

+364
-1
lines changed

DESCRIPTION

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ Suggests:
6060
shiny,
6161
stringr,
6262
testthat (>= 3.0.0),
63-
tibble
63+
tibble,
64+
jose,
65+
openssl
6466
VignetteBuilder:
6567
knitr
6668
Config/Needs/website: tidyverse/tidytemplate, rmarkdown

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ export(embed_google_vertex)
1212
export(embed_lm_studio)
1313
export(embed_ollama)
1414
export(embed_openai)
15+
export(embed_snowflake)
1516
export(markdown_chunk)
1617
export(markdown_frame)
1718
export(markdown_segment)

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
- New `embed_azure_openai()` helper for generating embeddings from
44
Azure AI Foundry (#144).
55

6+
- New `embed_snowflake()` helper for generating embeddings with the
7+
Snowflake Cortex Embedding API (#148).
8+
69
- `ragnar_retrieve()` (and the corresponding ellmer retrieve tool) now
710
accept a vector of queries (#150).
811

R/embed-snowflake.R

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
#' Generate embeddings using Snowflake
2+
#'
3+
#' Uses the [Cortex API `EMBED`](https://docs.snowflake.com/en/release-notes/2025/other/2025-04-14-cortex-offers-embed-rest-api)
4+
#' functions to generate embeddings.
5+
#'
6+
#' See [complete documentation](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-rest-api#label-cortex-llm-embed-function).
7+
#'
8+
#' @section Authentication:
9+
#'
10+
#' - a *Programmatic Access Token* (PAT) defined via the SNOWFLAKE_PAT environment variable.
11+
#' - A static OAuth token defined via the SNOWFLAKE_TOKEN environment variable.
12+
#' - Key-pair authentication credentials defined via the SNOWFLAKE_USER and SNOWFLAKE_PRIVATE_KEY (which can be a PEM-encoded private key or a path to one) environment variables.
13+
#' - Posit Workbench-managed Snowflake credentials for the corresponding account.
14+
#' - Viewer-based credentials on Posit Connect. Requires the connectcreds package.
15+
#'
16+
#' @inheritParams embed_ollama
17+
#' @inheritParams ellmer::chat_snowflake
18+
#' @export
19+
embed_snowflake <- function(
20+
x,
21+
account = snowflake_account(),
22+
credentials = NULL,
23+
model = "snowflake-arctic-embed-m-v1.5",
24+
api_args = list(),
25+
batch_size = 512L
26+
) {
27+
if (is.data.frame(x)) {
28+
x[["embedding"]] <- Recall(
29+
x[["text"]],
30+
account = account,
31+
credentials = credentials,
32+
model = model,
33+
api_args = api_args,
34+
batch_size = batch_size
35+
)
36+
return(x)
37+
}
38+
39+
text <- x
40+
check_character(text)
41+
if (!length(text)) {
42+
# ideally we'd return a 0-row matrix, but currently the correct
43+
# embedding_size is not convenient to access in this context
44+
return(NULL)
45+
}
46+
check_string(model, allow_empty = FALSE)
47+
if (!is.list(api_args)) {
48+
cli::cli_abort("`api_args` must be a list.")
49+
}
50+
51+
auth_headers <- function() {
52+
if (is.null(credentials)) {
53+
default_snowflake_credentials(account)()
54+
} else if (is.function(credentials)) {
55+
credentials()
56+
} else {
57+
credentials
58+
}
59+
}
60+
61+
headers <- auth_headers()
62+
63+
base_req <- account |>
64+
snowflake_url() |>
65+
httr2::request() |>
66+
embed_req_retry() |>
67+
httr2::req_url_path_append("api/v2/cortex/inference:embed") |>
68+
httr2::req_headers_redacted(!!!headers) |>
69+
httr2::req_headers(
70+
"Content-Type" = "application/json",
71+
"Accept" = "application/json"
72+
) |>
73+
httr2::req_user_agent(ragnar_user_agent()) |>
74+
httr2::req_error(body = function(resp) {
75+
tryCatch({
76+
json <- httr2::resp_body_json(resp, check_type = FALSE)
77+
json$message
78+
}, error = function(e) {
79+
"Unknown error"
80+
})
81+
})
82+
83+
out <- vector("list", length(text))
84+
base_body <- rlang::list2(model = model, !!!api_args)
85+
86+
for (indices in chunk_list(seq_along(text), batch_size)) {
87+
body <- base_body
88+
body$text <- as.list(text[indices])
89+
90+
resp <- base_req |>
91+
httr2::req_body_json(body) |>
92+
httr2::req_perform() |>
93+
httr2::resp_body_json()
94+
95+
out[indices] <- lapply(resp$data, \(x) x$embedding)
96+
}
97+
98+
matrix(unlist(out), nrow = length(text), byrow = TRUE)
99+
}
100+
101+
# Snowflake utilities copied from ellmer ------------
102+
# Handling Snowflake credentials and authentication in workbench + multiple other scnearios.
103+
104+
snowflake_account <- function() {
105+
val <- Sys.getenv("SNOWFLAKE_ACCOUNT")
106+
if (!identical(val, "")) {
107+
val
108+
} else {
109+
cli::cli_abort("SNOWFLAKE_ACCOUNT environment variable is not set.")
110+
}
111+
}
112+
113+
snowflake_url <- function(account) {
114+
paste0("https://", account, ".snowflakecomputing.com")
115+
}
116+
117+
default_snowflake_credentials <- function(account = snowflake_account()) {
118+
# Detect viewer-based credentials from Posit Connect.
119+
url <- snowflake_url(account)
120+
if (is_installed("connectcreds") && connectcreds::has_viewer_token(url)) {
121+
return(function() {
122+
token <- connectcreds::connect_viewer_token(url)
123+
list(
124+
Authorization = paste("Bearer", token$access_token),
125+
`X-Snowflake-Authorization-Token-Type` = "OAUTH"
126+
)
127+
})
128+
}
129+
130+
token <- Sys.getenv("SNOWFLAKE_TOKEN")
131+
if (nchar(token) != 0) {
132+
return(function() {
133+
list(
134+
Authorization = paste("Bearer", token),
135+
# See: https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/authentication#using-oauth
136+
`X-Snowflake-Authorization-Token-Type` = "OAUTH"
137+
)
138+
})
139+
}
140+
141+
token <- Sys.getenv("SNOWFLAKE_PAT")
142+
if (nchar(token) != 0) {
143+
return(function() {
144+
list(
145+
Authorization = paste("Bearer", token),
146+
# See https://docs.snowflake.com/en/user-guide/programmatic-access-tokens
147+
`X-Snowflake-Authorization-Token-Type` = "PROGRAMMATIC_ACCESS_TOKEN"
148+
)
149+
})
150+
}
151+
152+
# Support for Snowflake key-pair authentication.
153+
# See: https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/authentication#generate-a-jwt-token
154+
user <- Sys.getenv("SNOWFLAKE_USER")
155+
private_key <- Sys.getenv("SNOWFLAKE_PRIVATE_KEY")
156+
if (nchar(user) != 0 && nchar(private_key) != 0) {
157+
check_installed(c("jose", "openssl"), "for key-pair authentication")
158+
key <- openssl::read_key(private_key)
159+
return(function() {
160+
token <- snowflake_keypair_token(account, user, key)
161+
list(
162+
Authorization = paste("Bearer", token),
163+
`X-Snowflake-Authorization-Token-Type` = "KEYPAIR_JWT"
164+
)
165+
})
166+
}
167+
168+
# Check for Workbench-managed credentials.
169+
sf_home <- Sys.getenv("SNOWFLAKE_HOME")
170+
if (grepl("posit-workbench", sf_home, fixed = TRUE)) {
171+
token <- workbench_snowflake_token(account, sf_home)
172+
if (!is.null(token)) {
173+
return(function() {
174+
# Ensure we get an up-to-date token.
175+
token <- workbench_snowflake_token(account, sf_home)
176+
list(
177+
Authorization = paste("Bearer", token),
178+
`X-Snowflake-Authorization-Token-Type` = "OAUTH"
179+
)
180+
})
181+
}
182+
}
183+
184+
cli::cli_abort("No Snowflake credentials are available.")
185+
}
186+
187+
snowflake_keypair_token <- function(
188+
account,
189+
user,
190+
key,
191+
cache = snowflake_keypair_cache(account, key),
192+
lifetime = 600L,
193+
reauth = FALSE
194+
) {
195+
# Producing a signed JWT is a fairly expensive operation (in the order of
196+
# ~10ms), but adding a cache speeds this up approximately 500x.
197+
creds <- cache$get()
198+
if (reauth || is.null(creds) || creds$expiry < Sys.time()) {
199+
cache$clear()
200+
expiry <- Sys.time() + lifetime
201+
# We can't use openssl::fingerprint() here because it uses a different
202+
# algorithm.
203+
fp <- openssl::base64_encode(
204+
openssl::sha256(openssl::write_der(key$pubkey))
205+
)
206+
if (grepl(".+\\.privatelink$", account)) {
207+
# account identifier is everything up to the first period
208+
account <- gsub("^([^.]*).+", "\\1", account)
209+
}
210+
sub <- toupper(paste0(account, ".", user))
211+
iss <- paste0(sub, ".SHA256:", fp)
212+
# Note: Snowflake employs a malformed issuer claim, so we have to inject it
213+
# manually after jose's validation phase.
214+
claim <- jose::jwt_claim("dummy", sub, exp = as.integer(expiry))
215+
claim$iss <- iss
216+
creds <- list(expiry = expiry, token = jose::jwt_encode_sig(claim, key))
217+
cache$set(creds)
218+
}
219+
creds$token
220+
}
221+
222+
snowflake_keypair_cache <- function(account, key) {
223+
credentials_cache(key = hash(c("sf", account, openssl::fingerprint(key))))
224+
}
225+
226+
snowflake_credentials_exist <- function(...) {
227+
tryCatch(
228+
is_list(default_snowflake_credentials(...)),
229+
error = function(e) FALSE
230+
)
231+
}
232+
233+
# Reads Posit Workbench-managed Snowflake credentials from a
234+
# $SNOWFLAKE_HOME/connections.toml file, as used by the Snowflake Connector for
235+
# Python implementation. The file will look as follows:
236+
#
237+
# [workbench]
238+
# account = "account-id"
239+
# token = "token"
240+
# authenticator = "oauth"
241+
workbench_snowflake_token <- function(account, sf_home) {
242+
cfg <- readLines(file.path(sf_home, "connections.toml"))
243+
# We don't attempt a full parse of the TOML syntax, instead relying on the
244+
# fact that this file will always contain only one section.
245+
if (!any(grepl(account, cfg, fixed = TRUE))) {
246+
# The configuration doesn't actually apply to this account.
247+
return(NULL)
248+
}
249+
line <- grepl("token = ", cfg, fixed = TRUE)
250+
token <- gsub("token = ", "", cfg[line])
251+
if (nchar(token) == 0) {
252+
return(NULL)
253+
}
254+
# Drop enclosing quotes.
255+
gsub("\"", "", token)
256+
}

man/embed_snowflake.Rd

Lines changed: 61 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
test_that("snowflake embeddings works", {
2+
account <- Sys.getenv("SNOWFLAKE_ACCOUNT")
3+
testthat::skip_if(account == "")
4+
testthat::skip_if(!snowflake_credentials_exist(account))
5+
6+
model <- "snowflake-arctic-embed-m-v1.5"
7+
text <- c("hello world", "another hello world")
8+
9+
embs_single_1 <- embed_snowflake(
10+
text[1],
11+
account = account,
12+
model = model
13+
)
14+
15+
embs_single_2 <- embed_snowflake(
16+
text[2],
17+
account = account,
18+
model = model
19+
)
20+
21+
embs_batch <- embed_snowflake(
22+
text,
23+
account = account,
24+
model = model,
25+
batch_size = 1L
26+
)
27+
28+
expect_equal(embs_single_1[1, ], embs_batch[1, ])
29+
expect_equal(embs_single_2[1, ], embs_batch[2, ])
30+
})
31+
32+
test_that("embed_snowflake returns NULL for empty inputs", {
33+
expect_null(
34+
embed_snowflake(
35+
character(),
36+
account = "placeholder",
37+
credentials = function() list()
38+
)
39+
)
40+
})

0 commit comments

Comments
 (0)