Skip to content

Commit

Permalink
fix: chat infinite loop and missing lines (#392)
Browse files Browse the repository at this point in the history
Fixes #390
  • Loading branch information
grant0417 authored Jan 9, 2025
1 parent dc3af8c commit 6d13844
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 36 deletions.
105 changes: 81 additions & 24 deletions crates/q_cli/src/cli/chat/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod api;
mod parse;
mod prompt;
mod terminal;
mod stdio;

use std::io::{
IsTerminal,
Expand All @@ -16,13 +16,13 @@ use color_eyre::owo_colors::OwoColorize;
use crossterm::style::{
Attribute,
Color,
Print,
};
use crossterm::{
cursor,
execute,
queue,
style,
terminal,
};
use eyre::{
Result,
Expand All @@ -40,7 +40,7 @@ use spinners::{
Spinner,
Spinners,
};
use terminal::StdioOutput;
use stdio::StdioOutput;
use winnow::Partial;
use winnow::stream::Offset;

Expand Down Expand Up @@ -79,7 +79,8 @@ pub async fn chat(mut input: String) -> Result<ExitCode> {
}

let mut output = StdioOutput::new(is_interactive);
let result = try_chat(&mut output, input, is_interactive).await;
let client = StreamingClient::new().await?;
let result = try_chat(&mut output, input, is_interactive, &client).await;

if is_interactive {
queue!(output, style::SetAttribute(Attribute::Reset), style::ResetColor).ok();
Expand All @@ -89,9 +90,13 @@ pub async fn chat(mut input: String) -> Result<ExitCode> {
result.map(|_| ExitCode::SUCCESS)
}

async fn try_chat<W: Write>(output: &mut W, mut input: String, interactive: bool) -> Result<()> {
async fn try_chat<W: Write>(
output: &mut W,
mut input: String,
interactive: bool,
client: &StreamingClient,
) -> Result<()> {
let mut rl = if interactive { Some(rl()?) } else { None };
let client = StreamingClient::new().await?;
let mut rx = None;
let mut conversation_id: Option<String> = None;
let mut message_id = None;
Expand Down Expand Up @@ -158,8 +163,8 @@ You can include additional context by adding the following to your prompt:
let mut offset = 0;
let mut ended = false;

let columns = crossterm::terminal::window_size()?.columns.into();
let mut state = ParseState::new(columns);
let terminal_width = terminal::window_size().map(|s| s.columns.into()).ok();
let mut state = ParseState::new(terminal_width);

loop {
if let Some(response) = rx.recv().await {
Expand Down Expand Up @@ -229,11 +234,11 @@ You can include additional context by adding the following to your prompt:
buf.push('\n');
}

if !buf.is_empty() && interactive {
if !buf.is_empty() && interactive && spinner.is_some() {
drop(spinner.take());
queue!(
output,
crossterm::terminal::Clear(crossterm::terminal::ClearType::CurrentLine),
terminal::Clear(terminal::ClearType::CurrentLine),
cursor::MoveToColumn(0),
cursor::Show
)?;
Expand All @@ -259,32 +264,26 @@ You can include additional context by adding the following to your prompt:
}

if ended {
if let (Some(conversation_id), Some(message_id)) = (&conversation_id, &message_id) {
fig_telemetry::send_chat_added_message(conversation_id.to_owned(), message_id.to_owned()).await;
}

if interactive {
queue!(
output,
style::ResetColor,
style::SetAttribute(Attribute::Reset),
Print("\n")
)?;
queue!(output, style::ResetColor, style::SetAttribute(Attribute::Reset))?;

for (i, citation) in &state.citations {
queue!(
output,
style::Print("\n"),
style::SetForegroundColor(Color::Blue),
style::Print(format!("{i} ")),
style::Print(format!("[^{i}]: ")),
style::SetForegroundColor(Color::DarkGrey),
style::Print(format!("{citation}\n")),
style::SetForegroundColor(Color::Reset)
)?;
}

if !state.citations.is_empty() {
execute!(output, Print("\n"))?;
}
}

if let (Some(conversation_id), Some(message_id)) = (&conversation_id, &message_id) {
fig_telemetry::send_chat_added_message(conversation_id.to_owned(), message_id.to_owned()).await;
execute!(output, style::Print("\n"))?;
}

break;
Expand Down Expand Up @@ -313,6 +312,64 @@ You can include additional context by adding the following to your prompt:
},
}
}
} else {
break Ok(());
}
}
}

#[cfg(test)]
mod test {
use fig_api_client::model::ChatResponseStream;

use super::*;

fn mock_client(s: impl IntoIterator<Item = &'static str>) -> StreamingClient {
StreamingClient::mock(
s.into_iter()
.map(|s| ChatResponseStream::AssistantResponseEvent { content: s.into() })
.collect(),
)
}

#[tokio::test]
async fn try_chat_non_interactive() {
let client = mock_client(["Hello,", " World", "!"]);
let mut output = Vec::new();
try_chat(&mut output, "test".into(), false, &client).await.unwrap();

let mut expected = Vec::new();
execute!(
expected,
style::Print("Hello, World!"),
style::ResetColor,
style::SetAttribute(Attribute::Reset),
style::Print("\n")
)
.unwrap();

assert_eq!(expected, output);
}

#[tokio::test]
async fn try_chat_non_interactive_citation() {
let client = mock_client(["Citation [[1]](https://aws.com)"]);
let mut output = Vec::new();
try_chat(&mut output, "test".into(), false, &client).await.unwrap();

let mut expected = Vec::new();
execute!(
expected,
style::Print("Citation "),
style::SetForegroundColor(Color::Blue),
style::Print("[^1]"),
style::ResetColor,
style::ResetColor,
style::SetAttribute(Attribute::Reset),
style::Print("\n")
)
.unwrap();

assert_eq!(expected, output);
}
}
31 changes: 19 additions & 12 deletions crates/q_cli/src/cli/chat/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ const BLOCKQUOTE_COLOR: Color = Color::DarkGrey;
const URL_TEXT_COLOR: Color = Color::Blue;
const URL_LINK_COLOR: Color = Color::DarkGrey;

const DEFAULT_RULE_WIDTH: usize = 40;

#[derive(Debug, thiserror::Error)]
pub enum Error<'a> {
#[error(transparent)]
Expand Down Expand Up @@ -82,7 +84,7 @@ impl<'a> ParserError<Partial<&'a str>> for Error<'a> {

#[derive(Debug)]
pub struct ParseState {
pub terminal_width: usize,
pub terminal_width: Option<usize>,
pub column: usize,
pub in_codeblock: bool,
pub bold: bool,
Expand All @@ -94,7 +96,7 @@ pub struct ParseState {
}

impl ParseState {
pub fn new(terminal_width: usize) -> Self {
pub fn new(terminal_width: Option<usize>) -> Self {
Self {
terminal_width,
column: 0,
Expand Down Expand Up @@ -270,7 +272,8 @@ fn horizontal_rule<'a, 'b>(
state.column = 0;
state.set_newline = true;

queue(&mut o, style::Print(format!("{}\n", "━".repeat(state.terminal_width))))
let rule_width = state.terminal_width.unwrap_or(DEFAULT_RULE_WIDTH);
queue(&mut o, style::Print(format!("{}\n", "━".repeat(rule_width))))
}
}

Expand Down Expand Up @@ -394,7 +397,7 @@ fn citation<'a, 'b>(

queue_newline_or_advance(&mut o, state, num.width() + 1)?;
queue(&mut o, style::SetForegroundColor(URL_TEXT_COLOR))?;
queue(&mut o, style::Print(format!("[{num}]")))?;
queue(&mut o, style::Print(format!("[^{num}]")))?;
queue(&mut o, style::ResetColor)
}
}
Expand Down Expand Up @@ -499,13 +502,17 @@ fn queue_newline_or_advance<'a, 'b>(
state: &'b mut ParseState,
width: usize,
) -> Result<(), ErrMode<Error<'a>>> {
if state.column > 0 && state.column + width > state.terminal_width {
state.column = width;
queue(&mut o, style::Print('\n'))?;
} else {
state.column += width;
if let Some(terminal_width) = state.terminal_width {
if state.column > 0 && state.column + width > terminal_width {
state.column = width;
queue(&mut o, style::Print('\n'))?;
return Ok(());
}
}

// else
state.column += width;

Ok(())
}

Expand Down Expand Up @@ -630,7 +637,7 @@ mod tests {
input.push(' ');
input.push(' ');

let mut state = ParseState::new(256);
let mut state = ParseState::new(Some(80));
let mut presult = vec![];
let mut offset = 0;

Expand Down Expand Up @@ -686,7 +693,7 @@ mod tests {
]);
validate!(citation_1, "[[1]](google.com)", [
style::SetForegroundColor(URL_TEXT_COLOR),
style::Print("[1]"),
style::Print("[^1]"),
style::ResetColor,
]);
validate!(bold_1, "**hello**", [
Expand All @@ -709,7 +716,7 @@ mod tests {
validate!(ampersand_1, "&amp;", [style::Print('&')]);
validate!(quote_1, "&quot;", [style::Print('"')]);
validate!(fallback_1, "+ % @ . ? ", [style::Print("+ % @ . ?")]);
validate!(horizontal_rule_1, "---", [style::Print("━".repeat(256))]);
validate!(horizontal_rule_1, "---", [style::Print("━".repeat(80))]);
validate!(heading_1, "# Hello World", [
style::SetForegroundColor(HEADING_COLOR),
style::SetAttribute(Attribute::Bold),
Expand Down
File renamed without changes.

0 comments on commit 6d13844

Please sign in to comment.