From 6d13844d6546a5c29e7e121450af1130eb189e7f Mon Sep 17 00:00:00 2001 From: Grant G Date: Wed, 8 Jan 2025 16:44:02 -0800 Subject: [PATCH] fix: chat infinite loop and missing lines (#392) Fixes #390 --- crates/q_cli/src/cli/chat/mod.rs | 105 ++++++++++++++---- crates/q_cli/src/cli/chat/parse.rs | 31 ++++-- .../src/cli/chat/{terminal.rs => stdio.rs} | 0 3 files changed, 100 insertions(+), 36 deletions(-) rename crates/q_cli/src/cli/chat/{terminal.rs => stdio.rs} (100%) diff --git a/crates/q_cli/src/cli/chat/mod.rs b/crates/q_cli/src/cli/chat/mod.rs index ad27e19694..f66dcf3b1b 100644 --- a/crates/q_cli/src/cli/chat/mod.rs +++ b/crates/q_cli/src/cli/chat/mod.rs @@ -1,7 +1,7 @@ mod api; mod parse; mod prompt; -mod terminal; +mod stdio; use std::io::{ IsTerminal, @@ -16,13 +16,13 @@ use color_eyre::owo_colors::OwoColorize; use crossterm::style::{ Attribute, Color, - Print, }; use crossterm::{ cursor, execute, queue, style, + terminal, }; use eyre::{ Result, @@ -40,7 +40,7 @@ use spinners::{ Spinner, Spinners, }; -use terminal::StdioOutput; +use stdio::StdioOutput; use winnow::Partial; use winnow::stream::Offset; @@ -79,7 +79,8 @@ pub async fn chat(mut input: String) -> Result { } let mut output = StdioOutput::new(is_interactive); - let result = try_chat(&mut output, input, is_interactive).await; + let client = StreamingClient::new().await?; + let result = try_chat(&mut output, input, is_interactive, &client).await; if is_interactive { queue!(output, style::SetAttribute(Attribute::Reset), style::ResetColor).ok(); @@ -89,9 +90,13 @@ pub async fn chat(mut input: String) -> Result { result.map(|_| ExitCode::SUCCESS) } -async fn try_chat(output: &mut W, mut input: String, interactive: bool) -> Result<()> { +async fn try_chat( + output: &mut W, + mut input: String, + interactive: bool, + client: &StreamingClient, +) -> Result<()> { let mut rl = if interactive { Some(rl()?) } else { None }; - let client = StreamingClient::new().await?; let mut rx = None; let mut conversation_id: Option = None; let mut message_id = None; @@ -158,8 +163,8 @@ You can include additional context by adding the following to your prompt: let mut offset = 0; let mut ended = false; - let columns = crossterm::terminal::window_size()?.columns.into(); - let mut state = ParseState::new(columns); + let terminal_width = terminal::window_size().map(|s| s.columns.into()).ok(); + let mut state = ParseState::new(terminal_width); loop { if let Some(response) = rx.recv().await { @@ -229,11 +234,11 @@ You can include additional context by adding the following to your prompt: buf.push('\n'); } - if !buf.is_empty() && interactive { + if !buf.is_empty() && interactive && spinner.is_some() { drop(spinner.take()); queue!( output, - crossterm::terminal::Clear(crossterm::terminal::ClearType::CurrentLine), + terminal::Clear(terminal::ClearType::CurrentLine), cursor::MoveToColumn(0), cursor::Show )?; @@ -259,32 +264,26 @@ You can include additional context by adding the following to your prompt: } if ended { + if let (Some(conversation_id), Some(message_id)) = (&conversation_id, &message_id) { + fig_telemetry::send_chat_added_message(conversation_id.to_owned(), message_id.to_owned()).await; + } + if interactive { - queue!( - output, - style::ResetColor, - style::SetAttribute(Attribute::Reset), - Print("\n") - )?; + queue!(output, style::ResetColor, style::SetAttribute(Attribute::Reset))?; for (i, citation) in &state.citations { queue!( output, + style::Print("\n"), style::SetForegroundColor(Color::Blue), - style::Print(format!("{i} ")), + style::Print(format!("[^{i}]: ")), style::SetForegroundColor(Color::DarkGrey), style::Print(format!("{citation}\n")), style::SetForegroundColor(Color::Reset) )?; } - if !state.citations.is_empty() { - execute!(output, Print("\n"))?; - } - } - - if let (Some(conversation_id), Some(message_id)) = (&conversation_id, &message_id) { - fig_telemetry::send_chat_added_message(conversation_id.to_owned(), message_id.to_owned()).await; + execute!(output, style::Print("\n"))?; } break; @@ -313,6 +312,64 @@ You can include additional context by adding the following to your prompt: }, } } + } else { + break Ok(()); } } } + +#[cfg(test)] +mod test { + use fig_api_client::model::ChatResponseStream; + + use super::*; + + fn mock_client(s: impl IntoIterator) -> StreamingClient { + StreamingClient::mock( + s.into_iter() + .map(|s| ChatResponseStream::AssistantResponseEvent { content: s.into() }) + .collect(), + ) + } + + #[tokio::test] + async fn try_chat_non_interactive() { + let client = mock_client(["Hello,", " World", "!"]); + let mut output = Vec::new(); + try_chat(&mut output, "test".into(), false, &client).await.unwrap(); + + let mut expected = Vec::new(); + execute!( + expected, + style::Print("Hello, World!"), + style::ResetColor, + style::SetAttribute(Attribute::Reset), + style::Print("\n") + ) + .unwrap(); + + assert_eq!(expected, output); + } + + #[tokio::test] + async fn try_chat_non_interactive_citation() { + let client = mock_client(["Citation [[1]](https://aws.com)"]); + let mut output = Vec::new(); + try_chat(&mut output, "test".into(), false, &client).await.unwrap(); + + let mut expected = Vec::new(); + execute!( + expected, + style::Print("Citation "), + style::SetForegroundColor(Color::Blue), + style::Print("[^1]"), + style::ResetColor, + style::ResetColor, + style::SetAttribute(Attribute::Reset), + style::Print("\n") + ) + .unwrap(); + + assert_eq!(expected, output); + } +} diff --git a/crates/q_cli/src/cli/chat/parse.rs b/crates/q_cli/src/cli/chat/parse.rs index 1c7bb82f58..57a803c37a 100644 --- a/crates/q_cli/src/cli/chat/parse.rs +++ b/crates/q_cli/src/cli/chat/parse.rs @@ -54,6 +54,8 @@ const BLOCKQUOTE_COLOR: Color = Color::DarkGrey; const URL_TEXT_COLOR: Color = Color::Blue; const URL_LINK_COLOR: Color = Color::DarkGrey; +const DEFAULT_RULE_WIDTH: usize = 40; + #[derive(Debug, thiserror::Error)] pub enum Error<'a> { #[error(transparent)] @@ -82,7 +84,7 @@ impl<'a> ParserError> for Error<'a> { #[derive(Debug)] pub struct ParseState { - pub terminal_width: usize, + pub terminal_width: Option, pub column: usize, pub in_codeblock: bool, pub bold: bool, @@ -94,7 +96,7 @@ pub struct ParseState { } impl ParseState { - pub fn new(terminal_width: usize) -> Self { + pub fn new(terminal_width: Option) -> Self { Self { terminal_width, column: 0, @@ -270,7 +272,8 @@ fn horizontal_rule<'a, 'b>( state.column = 0; state.set_newline = true; - queue(&mut o, style::Print(format!("{}\n", "━".repeat(state.terminal_width)))) + let rule_width = state.terminal_width.unwrap_or(DEFAULT_RULE_WIDTH); + queue(&mut o, style::Print(format!("{}\n", "━".repeat(rule_width)))) } } @@ -394,7 +397,7 @@ fn citation<'a, 'b>( queue_newline_or_advance(&mut o, state, num.width() + 1)?; queue(&mut o, style::SetForegroundColor(URL_TEXT_COLOR))?; - queue(&mut o, style::Print(format!("[{num}]")))?; + queue(&mut o, style::Print(format!("[^{num}]")))?; queue(&mut o, style::ResetColor) } } @@ -499,13 +502,17 @@ fn queue_newline_or_advance<'a, 'b>( state: &'b mut ParseState, width: usize, ) -> Result<(), ErrMode>> { - if state.column > 0 && state.column + width > state.terminal_width { - state.column = width; - queue(&mut o, style::Print('\n'))?; - } else { - state.column += width; + if let Some(terminal_width) = state.terminal_width { + if state.column > 0 && state.column + width > terminal_width { + state.column = width; + queue(&mut o, style::Print('\n'))?; + return Ok(()); + } } + // else + state.column += width; + Ok(()) } @@ -630,7 +637,7 @@ mod tests { input.push(' '); input.push(' '); - let mut state = ParseState::new(256); + let mut state = ParseState::new(Some(80)); let mut presult = vec![]; let mut offset = 0; @@ -686,7 +693,7 @@ mod tests { ]); validate!(citation_1, "[[1]](google.com)", [ style::SetForegroundColor(URL_TEXT_COLOR), - style::Print("[1]"), + style::Print("[^1]"), style::ResetColor, ]); validate!(bold_1, "**hello**", [ @@ -709,7 +716,7 @@ mod tests { validate!(ampersand_1, "&", [style::Print('&')]); validate!(quote_1, """, [style::Print('"')]); validate!(fallback_1, "+ % @ . ? ", [style::Print("+ % @ . ?")]); - validate!(horizontal_rule_1, "---", [style::Print("━".repeat(256))]); + validate!(horizontal_rule_1, "---", [style::Print("━".repeat(80))]); validate!(heading_1, "# Hello World", [ style::SetForegroundColor(HEADING_COLOR), style::SetAttribute(Attribute::Bold), diff --git a/crates/q_cli/src/cli/chat/terminal.rs b/crates/q_cli/src/cli/chat/stdio.rs similarity index 100% rename from crates/q_cli/src/cli/chat/terminal.rs rename to crates/q_cli/src/cli/chat/stdio.rs