diff --git a/.Rbuildignore b/.Rbuildignore index 65b1be8c..f32e6e3e 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -14,4 +14,5 @@ _cache/ ^[\.]?air\.toml$ ^\.vscode$ ^data-raw$ +^_dev$ ^revdep$ diff --git a/.gitignore b/.gitignore index 0160fa21..b46d1001 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ docs inst/doc /.quarto/ +_dev/ diff --git a/DESCRIPTION b/DESCRIPTION index 28ca9b3b..d838b5cf 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: ellmer Title: Chat with Large Language Models -Version: 0.2.1.9000 +Version: 0.2.1.9001 Authors@R: c( person("Hadley", "Wickham", , "hadley@posit.co", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4757-117X")), @@ -8,6 +8,8 @@ Authors@R: c( person("Aaron", "Jacobs", role = "aut"), person("Garrick", "Aden-Buie", , "garrick@posit.co", role = "aut", comment = c(ORCID = "0000-0002-7111-0077")), + person("Barret", "Schloerke", , "barret@posit.co", role = "aut", + comment = c(ORCID = "0000-0001-9986-114X")), person("Posit Software, PBC", role = c("cph", "fnd"), comment = c(ROR = "03wc8by49")) ) @@ -58,6 +60,7 @@ RoxygenNote: 7.3.2 Collate: 'utils-S7.R' 'types.R' + 'ellmer-package.R' 'tools-def.R' 'content.R' 'provider.R' @@ -70,9 +73,9 @@ Collate: 'content-image.R' 'content-pdf.R' 'turns.R' + 'content-replay.R' 'content-tools.R' 'deprecated.R' - 'ellmer-package.R' 'httr2.R' 'import-standalone-obj-type.R' 'import-standalone-purrr.R' diff --git a/NAMESPACE b/NAMESPACE index 992a498c..6f0c42cc 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -56,6 +56,8 @@ export(content_pdf_file) export(content_pdf_url) export(contents_html) export(contents_markdown) +export(contents_record) +export(contents_replay) export(contents_text) export(create_tool_def) export(google_upload) diff --git a/NEWS.md b/NEWS.md index c0dbee9b..6d53d64e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,9 +1,14 @@ # ellmer (development version) +## New features * `models_github()` lists models for `chat_github()` (#561). * `chat_snowflake()` now works with tool calling (#557, @atheriel). +* Added `contents_record()`, `contents_replay()`, and `contents_replay_class()` + to record and replay `Turn` related information from a `Chat` instance (#502). + For example, these methods can be used for bookmarking within `{shinychat}`. + # ellmer 0.2.1 * When you save a `Chat` object to disk, API keys are automatically redacted. diff --git a/R/content-replay.R b/R/content-replay.R new file mode 100644 index 00000000..881a2859 --- /dev/null +++ b/R/content-replay.R @@ -0,0 +1,231 @@ +#' @include utils-S7.R +#' @include turns.R +#' @include tools-def.R +#' @include content.R +NULL + +#' Save and restore content +#' +#' @description +#' These generic functions can be use to convert [Turn]/[Content] objects +#' into easily serializable representations. +#' +#' * `contents_record()` accept a [Turn] or [Content] and return a simple list. +#' * `contents_replay()` will accept a simple list (from `contents_record()`) +#' and return a [Turn] or [Content] object. +#' +#' @param content A [Turn] or [Content] object to serialize. +#' @param obj A basic list to desierialize. +#' @param chat A [Chat] object to be used for context. +#' @param ... Not used. +#' +#' @examplesIf has_credentials("openai") +#' chat <- chat_openai(model = "gpt-4.1-nano") +#' chat$chat("Where is the capital of France?") +#' +#' # Serialize to a simple list +#' turn_recorded <- contents_record(chat$get_turns(), chat = chat) +#' str(turn_recorded) +#' +#' # Deserialize back to S7 objects +#' turn_replayed <- contents_replay(turn_recorded, chat = chat) +#' turn_replayed +#' @export +#' @rdname contents_record +contents_record <- + new_generic( + "contents_record", + "content", + function(content, ..., chat) { + check_chat(chat, call = caller_env()) + + recorded <- S7::S7_dispatch() + + if (!is_recorded_object(recorded)) { + cli::cli_abort( + "Expected the recorded object to be a list with at least names 'version', 'class', and 'props'." + ) + } + + if ( + !is.character(recorded$class) || + length(recorded$class) != 1 + ) { + cli::cli_abort( + "Expected the recorded object to have a single $class name, containing `::` if the class is from a package." + ) + } + + if (!grepl("ellmer::", recorded$class, fixed = TRUE)) { + cli::cli_abort( + "Only S7 classes from the `ellmer` package are currently supported. Received: {.val {recorded$class}}." + ) + } + + recorded + } + ) + +method(contents_record, S7::S7_object) <- function(content, ..., chat) { + class_name <- class(content)[1] + + # Remove read-only props + cls_props <- S7::S7_class(content)@properties + prop_names <- names(cls_props)[!map_lgl(cls_props, prop_is_read_only)] + + recorded_props <- setNames( + lapply(prop_names, function(prop_name) { + prop_value <- S7::prop(prop_name, object = content) + if (S7_inherits(prop_value)) { + # Recursive record for S7 objects + contents_record(prop_value, chat = chat) + } else if (is_list_of_s7_objects(prop_value)) { + # Make record of each item in list + lapply(prop_value, contents_record, chat = chat) + } else { + prop_value + } + }), + prop_names + ) + + # Remove non-serializable properties + recorded_props <- Filter(function(x) !is.function(x), recorded_props) + + list( + version = 1, + class = class_name, + props = recorded_props + ) +} + + +#' @rdname contents_record +#' @export +# Holy "Holy Trait" dispatching, Batman! +contents_replay <- function(obj, ..., chat) { + check_chat(chat, call = caller_env()) + + # Find any reason to not believe `obj` is a recorded object. + # If not a recorded object, return it as is. + # If it is a recorded s7 object, dispatch on the discovered class. + + if (!is_recorded_object(obj)) { + cli::cli_abort( + "Expected the object to be a list with at least names 'version', 'class', and 'props'." + ) + } + + class_name <- obj$class + if (!(is.character(class_name) && length(class_name) == 1)) { + cli::cli_abort( + "Expected the replay object's `'class'` value to be a single character." + ) + } + + cls_name <- strsplit(class_name, "::")[[1]][2] + if (!grepl("ellmer::", class_name, fixed = TRUE)) { + cli::cli_abort( + "Only S7 classes from the `ellmer` package are currently supported." + ) + } + + cls <- pkg_env("ellmer")[[cls_name]] + + if (is.null(cls)) { + cli::cli_abort("Unable to find the S7 class: {.val {class_name}}.") + } + + if (!S7_inherits(cls)) { + cli::cli_abort( + "The object returned for {.val {class_name}} is not an S7 class." + ) + } + + # Manually retrieve the handler for the class as we dispatch on the class itself, + # not on an instance + # An error will be thrown if a method is not found, + # however we have a fallback for the `S7::S7_object` (the root base class) + handler <- S7::method(contents_replay_class, cls) + handler(cls, obj, chat = chat) +} + +contents_replay_class <- new_generic( + "contents_replay_class", + "cls", + function(cls, obj, ..., chat) { + S7::S7_dispatch() + } +) + + +method(contents_replay_class, S7::S7_object) <- function( + cls, + obj, + ..., + chat +) { + stopifnot(obj$version == 1) + + obj_props <- map(obj$props, function(prop_value) { + if (is_list_of_recorded_objects(prop_value)) { + # If the prop is a list of recorded objects, replay each one + map(prop_value, contents_replay, chat = chat) + } else if (is_recorded_object(prop_value)) { + # If the prop is a recorded object, replay it + contents_replay(prop_value, chat = chat) + } else { + prop_value + } + }) + + class_name <- obj$class[1] + cls_name <- strsplit(class_name, "::")[[1]][2] + # While this seems like a bit of extra work, the tracebacks are accurate + # vs referencing an unrelated parameter name in the traceback + exec(cls_name, !!!obj_props, .env = ns_env("ellmer")) +} + +method(contents_replay_class, ToolDef) <- function( + cls, + obj, + ..., + chat +) { + if (obj$version != 1) { + cli::cli_abort( + "Unsupported version {.val {obj$version}}." + ) + } + + tools <- chat$get_tools() + matched_tool <- tools[[obj$props$name]] + + if (!is.null(matched_tool)) { + return(matched_tool) + } + + # If no tool is found, return placeholder tool containing the metadata + ret <- contents_replay_class( + super(cls, S7::S7_object), + obj, + chat = chat + ) + ret +} + +prop_is_read_only <- function(prop) { + is.function(prop$getter) && !is.function(prop$setter) +} + +is_recorded_object <- function(x) { + is.list(x) && all(c("version", "class", "props") %in% names(x)) +} + +is_list_of_s7_objects <- function(x) { + is.list(x) && all(map_lgl(x, S7_inherits)) +} + +is_list_of_recorded_objects <- function(x) { + is.list(x) && all(map_lgl(x, is_recorded_object)) +} diff --git a/R/tools-def.R b/R/tools-def.R index dd88f02c..5e746ff7 100644 --- a/R/tools-def.R +++ b/R/tools-def.R @@ -1,5 +1,6 @@ #' @include utils-S7.R #' @include types.R +#' @include ellmer-package.R NULL #' Define a tool @@ -263,7 +264,7 @@ tool_reject <- function( ) { check_string(reason) - rlang::abort( + abort( paste("Tool call rejected.", reason), class = "ellmer_tool_reject" ) diff --git a/_pkgdown.yml b/_pkgdown.yml index 58a81321..4a4747d2 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -59,6 +59,7 @@ reference: - title: Utilities contents: - contents_text + - contents_record - params - title: Deprecated functions diff --git a/man/contents_record.Rd b/man/contents_record.Rd new file mode 100644 index 00000000..f8c9d3d5 --- /dev/null +++ b/man/contents_record.Rd @@ -0,0 +1,43 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/content-replay.R +\name{contents_record} +\alias{contents_record} +\alias{contents_replay} +\title{Save and restore content} +\usage{ +contents_record(content, ..., chat) + +contents_replay(obj, ..., chat) +} +\arguments{ +\item{content}{A \link{Turn} or \link{Content} object to serialize.} + +\item{...}{Not used.} + +\item{chat}{A \link{Chat} object to be used for context.} + +\item{obj}{A basic list to desierialize.} +} +\description{ +These generic functions can be use to convert \link{Turn}/\link{Content} objects +into easily serializable representations. +\itemize{ +\item \code{contents_record()} accept a \link{Turn} or \link{Content} and return a simple list. +\item \code{contents_replay()} will accept a simple list (from \code{contents_record()}) +and return a \link{Turn} or \link{Content} object. +} +} +\examples{ +\dontshow{if (has_credentials("openai")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +chat <- chat_openai(model = "gpt-4.1-nano") +chat$chat("Where is the capital of France?") + +# Serialize to a simple list +turn_recorded <- contents_record(chat$get_turns(), chat = chat) +str(turn_recorded) + +# Deserialize back to S7 objects +turn_replayed <- contents_replay(turn_recorded, chat = chat) +turn_replayed +\dontshow{\}) # examplesIf} +} diff --git a/man/ellmer-package.Rd b/man/ellmer-package.Rd index 93f5fc50..f609548d 100644 --- a/man/ellmer-package.Rd +++ b/man/ellmer-package.Rd @@ -27,6 +27,7 @@ Authors: \item Joe Cheng \item Aaron Jacobs \item Garrick Aden-Buie \email{garrick@posit.co} (\href{https://orcid.org/0000-0002-7111-0077}{ORCID}) + \item Barret Schloerke \email{barret@posit.co} (\href{https://orcid.org/0000-0001-9986-114X}{ORCID}) } Other contributors: diff --git a/tests/testthat/_snaps/content-replay.md b/tests/testthat/_snaps/content-replay.md new file mode 100644 index 00000000..4846271c --- /dev/null +++ b/tests/testthat/_snaps/content-replay.md @@ -0,0 +1,25 @@ +# non-ellmer classes are recorded/replayed by default + + Code + contents_record(LocalClass("testname"), chat = chat) + Condition + Error in `contents_record()`: + ! Only S7 classes from the `ellmer` package are currently supported. Received: "LocalClass". + +--- + + Code + contents_replay(list(version = 1, class = "testpkg::LocalClass", props = list( + name = "testname")), chat = chat) + Condition + Error in `contents_replay()`: + ! Only S7 classes from the `ellmer` package are currently supported. + +# unknown classes cause errors + + Code + contents_replay(recorded, chat = chat) + Condition + Error in `contents_replay()`: + ! Unable to find the S7 class: "ellmer::Turn2". + diff --git a/tests/testthat/helper-content-replay.R b/tests/testthat/helper-content-replay.R new file mode 100644 index 00000000..9f7acfb9 --- /dev/null +++ b/tests/testthat/helper-content-replay.R @@ -0,0 +1,43 @@ +test_record_replay <- function( + x, + ..., + chat = chat_openai_test(), + env = rlang::caller_env() +) { + rlang::check_dots_empty() + + shiny__to_json <- rlang::ns_env("shiny")[["toJSON"]] + shiny__safe_from_json <- rlang::ns_env("shiny")[["safeFromJSON"]] + + # Simulate the full bookmarking experience: + # * Record the object to something serializable + # * Serialize the object to JSON via shiny; "bookmark" + # * Unserialize the object from JSON via shiny; "restore" + # * Replay the unserialized object to the original object + # * Check that the replayed object has the same class as the original object + # * Check that the replayed object has the same properties as the original object + + obj <- contents_record(x, chat = chat) + + # obj_packed <- jsonlite:::pack(obj) + + # Work around Shiny's terrible JSON serialization + # Use `as.character()` to remove the JSON class so that it is double serialized :-/ + marshalled <- list( + "my_chat" = as.character(jsonlite::serializeJSON(obj)) + ) + + # Bookmark + serialized <- shiny__to_json(marshalled) + unserialized <- shiny__safe_from_json(serialized) + + # obj_unpacked <- jsonlite:::unpack(unserialized) + unmarshalled <- jsonlite::unserializeJSON(unserialized$my_chat) + + replayed <- contents_replay(unmarshalled, chat = chat, env = env) + + expect_s3_class(replayed, class(x)[1]) + expect_equal(S7::props(replayed), S7::props(x)) + + invisible(replayed) +} diff --git a/tests/testthat/test-content-replay.R b/tests/testthat/test-content-replay.R new file mode 100644 index 00000000..b72a872e --- /dev/null +++ b/tests/testthat/test-content-replay.R @@ -0,0 +1,258 @@ +# ------------------------------------------------------------------------- + +test_that("can round trip of Turn record/replay", { + test_record_replay(Turn("user")) + + test_record_replay(Turn( + "user", + list( + ContentText("hello world"), + ContentText("hello world2") + ) + )) +}) + + +test_that("can round trip of Content record/replay", { + test_record_replay(Content()) +}) + +test_that("can round trip of ContentText record/replay", { + test_record_replay(ContentText("hello world")) +}) + +test_that("can round trip of ContentImageInline record/replay", { + test_record_replay( + ContentImageInline("image/png", "abcd123") + ) +}) + +test_that("can round trip of ContentImageRemote record/replay", { + test_record_replay( + ContentImageRemote("https://example.com/image.jpg", detail = "") + ) +}) + +test_that("can round trip of ContentJson record/replay", { + test_record_replay( + ContentJson(list(a = 1:2, b = "apple")) + ) +}) + +test_that("can round trip of ContentSql record/replay", { + test_record_replay( + ContentSql("SELECT * FROM mtcars") + ) +}) + +test_that("can round trip of ContentSuggestions record/replay", { + test_record_replay( + ContentSuggestions( + c( + "What is the total quantity sold for each product last quarter?", + "What is the average discount percentage for orders from the United States?", + "What is the average price of products in the 'electronics' category?" + ) + ) + ) +}) + +test_that("can round trip of ContentThinking record/replay", { + test_record_replay( + ContentThinking("A **thought**.") + ) +}) + +test_that("can round trip of ContentTool record/replay", { + chat <- chat_openai_test() + tool_rnorm <- tool( + rnorm, + "Drawn numbers from a random normal distribution", + n = type_integer("The number of observations. Must be a positive integer."), + mean = type_number("The mean value of the distribution."), + sd = type_number( + "The standard deviation of the distribution. Must be a non-negative number." + ) + ) + chat$register_tool(tool_rnorm) + + test_record_replay( + ContentToolRequest("ID", "tool_name", list(a = 1:2, b = "apple")), + chat = chat + ) +}) + +test_that("can round trip of ToolDef record/replay", { + chat <- chat_openai_test() + tool_rnorm <- tool( + # Use `rnorm` to avoid loading the package... this causes the name to not be auto found + rnorm, + "Drawn numbers from a random normal distribution", + n = type_integer("The number of observations. Must be a positive integer."), + mean = type_number("The mean value of the distribution."), + sd = type_number( + "The standard deviation of the distribution. Must be a non-negative number." + ) + ) + chat$register_tool(tool_rnorm) + + test_record_replay(tool_rnorm, chat = chat) + + test_record_replay( + ContentToolRequest( + "ID", + "tool_name", + list(a = 1:2, b = "apple"), + tool = tool_rnorm + ), + chat = chat + ) + + recorded_tool <- contents_record(tool_rnorm, chat = chat) + chat_empty <- chat_openai_test() + replayed_tool <- contents_replay(recorded_tool, chat = chat_empty) + + tool_rnorm_empty <- ToolDef( + # rnorm, + name = "rnorm", + description = "Drawn numbers from a random normal distribution", + arguments = type_object( + n = type_integer( + "The number of observations. Must be a positive integer." + ), + mean = type_number("The mean value of the distribution."), + sd = type_number( + "The standard deviation of the distribution. Must be a non-negative number." + ) + ), + ) + + expect_equal( + replayed_tool, + tool_rnorm_empty + ) +}) + +test_that("can round trip of ContentToolResult record/replay", { + test_record_replay( + ContentToolResult( + value = "VALUE", + error = NULL, + extra = list(extra = 1:2, b = "apple"), + request = NULL + ) + ) + + chat <- chat_openai_test() + tool_rnorm <- tool( + stats::rnorm, + "Drawn numbers from a random normal distribution", + n = type_integer("The number of observations. Must be a positive integer."), + mean = type_number("The mean value of the distribution."), + sd = type_number( + "The standard deviation of the distribution. Must be a non-negative number." + ) + ) + chat$register_tool(tool_rnorm) + + replayed <- + test_record_replay( + ContentToolResult( + value = "VALUE", + error = try(stop("boom"), silent = TRUE), + extra = list(extra = 1:2, b = "apple"), + request = ContentToolRequest( + "ID", + "tool_name", + list(a = 1:2, b = "apple"), + tool = tool_rnorm + ) + ), + chat = chat + ) + + tryCatch( + signalCondition(replayed@error), # re-throw error + error = function(e) { + expect_equal( + e$message, + "boom" + ) + } + ) +}) + +test_that("can round trip of ContentUploaded record/replay", { + test_record_replay(ContentUploaded("https://example.com/image.jpg")) +}) + +test_that("can round trip of ContentPDF record/replay", { + test_record_replay(ContentPDF(type = "TYPE", data = "DATA")) +}) + +test_that("non-ellmer classes are not recorded/replayed by default", { + chat <- chat_openai_test() + + LocalClass <- S7::new_class( + "LocalClass", + properties = list( + name = prop_string() + ), + # Make sure to unset the package being used! + # Within testing, it sets the package to "ellmer" + package = NULL + ) + + expect_snapshot( + contents_record(LocalClass("testname"), chat = chat), + error = TRUE + ) + expect_snapshot( + contents_replay( + list( + version = 1, + class = "testpkg::LocalClass", + props = list(name = "testname") + ), + chat = chat + ), + error = TRUE + ) +}) + +test_that("unknown classes cause errors", { + chat <- chat_openai_test() + recorded <- contents_record(Turn("user"), chat = chat) + recorded$class <- "ellmer::Turn2" + + expect_error( + contents_replay(recorded, chat = chat), + "Unable to find the S7 class" + ) + + expect_snapshot(contents_replay(recorded, chat = chat), error = TRUE) +}) + +test_that("replay classes are S7 classes", { + OtherName <- S7::new_class( + "LocalClass", + properties = list( + name = prop_string() + ), + # Make sure to unset the package being used! + # Within testing, it sets the package to "ellmer" + package = NULL + ) + LocalClass <- function(name) { + OtherName(name = name) + } + + chat <- chat_openai_test() + recorded <- contents_record(LocalClass("testname"), chat = chat) + expect_error( + contents_replay(recorded, chat = chat), + "is not an S7 class" + ) + + expect_snapshot(contents_replay(recorded, chat = chat), error = TRUE) +})