Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pkg-py/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### New features

* New `QueryChat.app()` method enables quicker/easier chatting with a dataset. (#104)

* Enabled bookmarking by default in both `.app()` and `.server()` methods. In latter case, you'll need to also specify the `bookmark_store` (either in `shiny.App()` or `shiny.express.app_opts()`) for it to take effect. (#104)

* The current SQL query and title can now be programmatically set through the `.sql()` and `.title()` methods of `QueryChat()`. (#98, #101)

* Added a `.generate_greeting()` method to help you create a greeting message for your querychat bot. (#87)
Expand Down
20 changes: 20 additions & 0 deletions pkg-py/src/querychat/_icons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Literal

from shiny import ui

ICON_NAMES = Literal["arrow-counterclockwise", "funnel-fill", "terminal-fill", "table"]


def bs_icon(name: ICON_NAMES) -> ui.HTML:
"""Get Bootstrap icon SVG by name."""
if name not in BS_ICONS:
raise ValueError(f"Unknown Bootstrap icon: {name}")
return ui.HTML(BS_ICONS[name])


BS_ICONS = {
"arrow-counterclockwise": '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" class="bi bi-arrow-counterclockwise" style="height:1em;width:1em;fill:currentColor;vertical-align:-0.125em;" aria-hidden="true" role="img"><path fill-rule="evenodd" d="M8 3a5 5 0 1 1-4.546 2.914.5.5 0 0 0-.908-.417A6 6 0 1 0 8 2v1z"></path><path d="M8 4.466V.534a.25.25 0 0 0-.41-.192L5.23 2.308a.25.25 0 0 0 0 .384l2.36 1.966A.25.25 0 0 0 8 4.466z"></path></svg>',
"funnel-fill": '<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-funnel-fill" viewBox="0 0 16 16"><path d="M1.5 1.5A.5.5 0 0 1 2 1h12a.5.5 0 0 1 .5.5v2a.5.5 0 0 1-.128.334L10 8.692V13.5a.5.5 0 0 1-.342.474l-3 1A.5.5 0 0 1 6 14.5V8.692L1.628 3.834A.5.5 0 0 1 1.5 3.5z"/></svg>',
"terminal-fill": '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" class="bi bi-terminal-fill " style="height:1em;width:1em;fill:currentColor;vertical-align:-0.125em;" aria-hidden="true" role="img" ><path d="M0 3a2 2 0 0 1 2-2h12a2 2 0 0 1 2 2v10a2 2 0 0 1-2 2H2a2 2 0 0 1-2-2V3zm9.5 5.5h-3a.5.5 0 0 0 0 1h3a.5.5 0 0 0 0-1zm-6.354-.354a.5.5 0 1 0 .708.708l2-2a.5.5 0 0 0 0-.708l-2-2a.5.5 0 1 0-.708.708L4.793 6.5 3.146 8.146z"></path></svg>',
"table": '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" class="bi bi-table " style="height:1em;width:1em;fill:currentColor;vertical-align:-0.125em;" aria-hidden="true" role="img" ><path d="M0 2a2 2 0 0 1 2-2h12a2 2 0 0 1 2 2v12a2 2 0 0 1-2 2H2a2 2 0 0 1-2-2V2zm15 2h-4v3h4V4zm0 4h-4v3h4V8zm0 4h-4v3h3a1 1 0 0 0 1-1v-2zm-5 3v-3H6v3h4zm-5 0v-3H1v2a1 1 0 0 0 1 1h3zm-4-4h4V8H1v3zm0-4h4V4H1v3zm5-3v3h4V4H6zm4 4H6v3h4V8z"></path></svg>',
}
181 changes: 164 additions & 17 deletions pkg-py/src/querychat/_querychat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
import chatlas
import chevron
import sqlalchemy
from shiny import ui
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from shiny.express._stub_session import ExpressStubSession
from shiny.session import get_current_session
from shinychat import output_markdown_stream

from ._icons import bs_icon
from ._querychat_module import ModServerResult, mod_server, mod_ui
from .datasource import DataFrameSource, DataSource, SQLAlchemySource

Expand Down Expand Up @@ -133,6 +136,99 @@ def __init__(
# Populated when ._server() gets called (in an active session)
self._server_values: ModServerResult | None = None

def app(
self, *, bookmark_store: Literal["url", "server", "disable"] = "url"
) -> App:
"""
Quickly chat with a dataset.

Creates a Shiny app with a chat sidebar and data table view -- providing a
quick-and-easy way to start chatting with your data.

Parameters
----------
bookmark_store
The bookmarking store to use for the Shiny app. Options are:
- `"url"`: Store bookmarks in the URL (default).
- `"server"`: Store bookmarks on the server.
- `"disable"`: Disable bookmarking.

Returns
-------
:
A Shiny App object that can be run with `app.run()` or served with `shiny run`.

"""
enable_bookmarking = bookmark_store != "disable"
table_name = self.data_source.table_name

def app_ui(request):
return ui.page_sidebar(
self.sidebar(),
ui.card(
ui.card_header(
ui.div(
ui.div(
bs_icon("terminal-fill"),
ui.output_text("query_title", inline=True),
class_="d-flex align-items-center gap-2",
),
ui.output_ui("ui_reset", inline=True),
class_="hstack gap-3",
),
),
ui.output_ui("sql_output"),
fill=False,
style="max-height: 33%;",
),
ui.card(
ui.card_header(bs_icon("table"), " Data"),
ui.output_data_frame("dt"),
),
title=ui.span("querychat with ", ui.code(table_name)),
class_="bslib-page-dashboard",
fillable=True,
)

def app_server(input: Inputs, output: Outputs, session: Session):
self._server(enable_bookmarking=enable_bookmarking)

@render.text
def query_title():
return self.title() or "SQL Query"

@render.ui
def ui_reset():
req(self.sql())
return ui.input_action_button(
"reset_query",
"Reset Query",
class_="btn btn-outline-danger btn-sm lh-1 ms-auto",
)

@reactive.effect
@reactive.event(input.reset_query)
def _():
self.sql("")
self.title(None)

@render.data_frame
def dt():
return self.df()

@render.ui
def sql_output():
sql = self.sql() or f"SELECT * FROM {table_name}"
sql_code = f"```sql\n{sql}\n```"
return output_markdown_stream(
"sql_code",
content=sql_code,
auto_scroll=False,
width="100%",
)

return App(app_ui, app_server, bookmark_store=bookmark_store)

def sidebar(
self,
*,
Expand Down Expand Up @@ -183,7 +279,7 @@ def ui(self, **kwargs):
"""
return mod_ui(self.id, **kwargs)

def _server(self):
def _server(self, *, enable_bookmarking: bool = False) -> None:
"""
Initialize the server module.

Expand Down Expand Up @@ -211,6 +307,7 @@ def _server(self):
system_prompt=self.system_prompt,
greeting=self.greeting,
client=self.client,
enable_bookmarking=enable_bookmarking,
)

return
Expand Down Expand Up @@ -434,7 +531,7 @@ def set_client(self, client: str | chatlas.Chat) -> None:


class QueryChat(QueryChatBase):
def server(self):
def server(self, *, enable_bookmarking: bool = False) -> None:
"""
Initialize Shiny server logic.

Expand All @@ -443,29 +540,48 @@ def server(self):
Express mode, you can use `querychat.express.QueryChat` instead
of `querychat.QueryChat`, which calls `.server()` automatically.

Parameters
----------
enable_bookmarking
Whether to enable bookmarking for the querychat module.

Examples
--------
```python
from shiny import App, render, ui
from seaborn import load_dataset
from querychat import QueryChat

qc = QueryChat(my_dataframe, "my_data")
titanic = load_dataset("titanic")

app_ui = ui.page_fluid(
qc.sidebar(),
ui.output_data_frame("data_table"),
)
qc = QueryChat(titanic, "titanic")


def app_ui(request):
return ui.page_sidebar(
qc.sidebar(),
ui.card(
ui.card_header(ui.output_text("title")),
ui.output_data_frame("data_table"),
),
title="Titanic QueryChat App",
fillable=True,
)


def server(input, output, session):
qc.server()
qc.server(enable_bookmarking=True)

@render.data_frame
def data_table():
return qc.df()

@render.text
def title():
return qc.title() or "My Data"

app = App(app_ui, server)

app = App(app_ui, server, bookmark_store="url")
```

Returns
Expand All @@ -474,7 +590,7 @@ def data_table():
None

"""
return self._server()
return self._server(enable_bookmarking=enable_bookmarking)


class QueryChatExpress(QueryChatBase):
Expand All @@ -488,17 +604,33 @@ class QueryChatExpress(QueryChatBase):
Examples
--------
```python
from shiny.express import render, ui
from querychat.express import QueryChat
from seaborn import load_dataset
from shiny.express import app_opts, render, ui

qc = QueryChat(my_dataframe, "my_data")
titanic = load_dataset("titanic")

qc = QueryChat(titanic, "titanic")
qc.sidebar()

with ui.card(fill=True):
with ui.card_header():

@render.text
def title():
return qc.title() or "Titanic Dataset"

@render.data_frame
def data_table():
return qc.df()

@render.data_frame
def data_table():
return qc.df()

ui.page_opts(
title="Titanic QueryChat App",
fillable=True,
)

app_opts(bookmark_store="url")
```

"""
Expand All @@ -514,6 +646,7 @@ def __init__(
data_description: Optional[str | Path] = None,
extra_instructions: Optional[str | Path] = None,
prompt_template: Optional[str | Path] = None,
enable_bookmarking: Literal["auto", True, False] = "auto",
):
super().__init__(
data_source,
Expand All @@ -525,7 +658,21 @@ def __init__(
extra_instructions=extra_instructions,
prompt_template=prompt_template,
)
self._server()

# If the Express session has a bookmark store set, automatically enable
# querychat's bookmarking
enable: bool
if enable_bookmarking == "auto":
session = get_current_session()
if session and isinstance(session, ExpressStubSession):
store = session.app_opts.get("bookmark_store", "disable")
enable = store != "disable"
else:
enable = False
else:
enable = enable_bookmarking

self._server(enable_bookmarking=enable)


def normalize_data_source(
Expand Down
28 changes: 28 additions & 0 deletions pkg-py/src/querychat/_querychat_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import chatlas
import pandas as pd
from shiny import Inputs, Outputs, Session
from shiny.bookmark import BookmarkState, RestoreState

from .datasource import DataSource

Expand Down Expand Up @@ -60,10 +61,12 @@ def mod_server(
system_prompt: str,
greeting: str | None,
client: chatlas.Chat,
enable_bookmarking: bool,
):
# Reactive values to store state
sql = ReactiveString("")
title = ReactiveStringOrNone(None)
has_greeted = reactive.value[bool](False) # noqa: FBT003

# Set up the chat object for this session
chat = copy.deepcopy(client)
Expand Down Expand Up @@ -99,6 +102,9 @@ async def _(user_input: str):

@reactive.effect
async def greet_on_startup():
if has_greeted():
return

if greeting:
await chat_ui.append_message(greeting)
elif greeting is None:
Expand All @@ -108,6 +114,8 @@ async def greet_on_startup():
)
await chat_ui.append_message_stream(stream)

has_greeted.set(True)

# Handle update button clicks
@reactive.effect
@reactive.event(input.chat_update)
Expand All @@ -125,4 +133,24 @@ def _():
if new_title is not None:
title.set(new_title)

if enable_bookmarking:
chat_ui.enable_bookmarking(client)

@session.bookmark.on_bookmark
def _on_bookmark(x: BookmarkState) -> None:
vals = x.values # noqa: PD011
vals["querychat_sql"] = sql.get()
vals["querychat_title"] = title.get()
vals["querychat_has_greeted"] = has_greeted.get()

@session.bookmark.on_restore
def _on_restore(x: RestoreState) -> None:
vals = x.values # noqa: PD011
if "querychat_sql" in vals:
sql.set(vals["querychat_sql"])
if "querychat_title" in vals:
title.set(vals["querychat_title"])
if "querychat_has_greeted" in vals:
has_greeted.set(vals["querychat_has_greeted"])

return ModServerResult(df=filtered_df, sql=sql, title=title, client=chat)
Loading