Skip to content

Commit

Permalink
🎉 Get the new explain raw data function working
Browse files Browse the repository at this point in the history
  • Loading branch information
larsyencken committed Sep 18, 2024
1 parent 15b474d commit 45483e5
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 32 deletions.
64 changes: 32 additions & 32 deletions apps/wizard/app_pages/datainsight_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,19 @@
from typing import cast
from urllib import parse

import pandas as pd
import requests
import streamlit as st
from owid.catalog.charts import Chart

from apps.utils.gpt import OpenAIWrapper
from apps.wizard.app_pages.insights import (
fetch_chart_data,
fetch_data,
get_grapher_thumbnail,
get_thumbnail_url,
list_charts,
)
from etl.db import get_connection


Expand All @@ -23,52 +31,39 @@ class DataError(Exception):
page_icon="🪄",
)
st.title(":material/lightbulb: Data insighter")
st.markdown(f"Generate data insights from a chart view, using the `{MODEL}` model.")


# FUNCTIONS
def get_thumbnail_url(grapher_url: str) -> str:
    """
    Turn https://ourworldindata.org/grapher/life-expectancy?country=~CHN
    Into https://ourworldindata.org/grapher/thumbnail/life-expectancy.png?country=~CHN

    The query string is preserved, so the thumbnail reflects the selected
    countries/years of the chart view being linked to.
    """
    assert grapher_url.startswith("https://ourworldindata.org/grapher/")
    parts = parse.urlparse(grapher_url)

    # Same scheme and host; only the path changes: /grapher/<slug> -> /grapher/thumbnail/<slug>.png
    return f"{parts.scheme}://{parts.netloc}/grapher/thumbnail/{Path(parts.path).name}.png?{parts.query}"


def get_grapher_thumbnail(grapher_url: str) -> str:
    """Download a chart's thumbnail PNG and return it as a base64 data URI."""
    url = get_thumbnail_url(grapher_url)
    data = requests.get(url).content
    return f"data:image/png;base64,{base64.b64encode(data).decode('utf8')}"


def get_trajectory_prompt(base_prompt: str, slug: str) -> str:
chart = Chart(slug)
df = chart.get_data()
st.warning(f"Chart has {len(df)} rows and {len(df.columns)} columns")
if len(df.columns) > 3:
raise DataError("This chart has more than 3 columns, which is not supported.")

(value_col,) = df.columns.difference(["entities", "years"])
df_s = df.round(1).query("years >= 2000").pivot(index="entities", columns="years", values=value_col).to_csv()
date_col = "years" if "years" in df.columns else "dates"

# shrink it
df = df.round(1)
if "years" in df.columns:
st.warning("NOTE: We are only looking at data from the year 2000 onwards")
df = df.query("years >= 2000")

if len(df.columns) == 3:
# shrink more via a pivot
(value_col,) = df.columns.difference(["entities", date_col])
df = df.pivot(index="entities", columns=date_col, values=value_col)

df_s = df.to_csv()

title = chart.config["title"]
subtitle = chart.config["subtitle"]

return f"{base_prompt}\n\n---\n\n## {title}\n\n{subtitle}\n\n{df_s}"


def list_charts(conn) -> list[str]:
    """Return the slug of every chart whose config is marked as published."""
    with conn.cursor() as cur:
        cur.execute("SELECT slug FROM chart_configs WHERE JSON_EXTRACT(config, '$.isPublished')")
        # each row is a 1-tuple (slug,)
        return [slug for (slug,) in cur.fetchall()]


(tab1, tab2) = st.tabs(["Insight from chart", "Explain raw data"])

with tab1:
st.markdown(
f"Generate data insights from a chart view, using the `{MODEL}` model. Choose what to describe by selecting the chart and the countries and years you care about, then paste the link in here."
)
# PROMPT
default_prompt = """This is a chart from Our World In Data.
Expand Down Expand Up @@ -163,18 +158,22 @@ def list_charts(conn) -> list[str]:
response = cast(str, st.write_stream(stream))

with tab2:
st.markdown(
f"Generate insights from the raw data underlying a chart, using the `{MODEL}` model. In this case, ChatGPT is looking at all countries and all time periods at once."
)
conn = get_connection()
default_prompt = """This is an indicator published by Our World In Data.
Explain the core insights present in this data, in plain, educational language.
"""
all_charts = list_charts(conn)
slug = st.multiselect(
slugs = st.multiselect(
label="Grapher slug",
options=all_charts,
options=[None] + all_charts,
help="Introduce the URL to a Grapher URL. Query parameters work!",
key="tab2_url",
)
slug = None if len(slugs) == 0 else slugs[0]

with st.expander("Edit the prompt"):
prompt = st.text_area(
Expand All @@ -193,6 +192,7 @@ def list_charts(conn) -> list[str]:
# Opena AI (do first to catch possible errors in ENV)
api = OpenAIWrapper()

df = fetch_chart_data(conn, slug)
prompt_with_data = get_trajectory_prompt(prompt, slug) # type: ignore

# Prepare messages for Insighter
Expand Down
160 changes: 160 additions & 0 deletions apps/wizard/app_pages/insights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import base64
import datetime as dt
import json
from pathlib import Path
from typing import Literal
from urllib import parse

import pandas as pd
import requests
from dateutil.parser import parse as date_parse

from etl.db import get_connection


def get_thumbnail_url(grapher_url: str) -> str:
    """
    Turn https://ourworldindata.org/grapher/life-expectancy?country=~CHN
    Into https://ourworldindata.org/grapher/thumbnail/life-expectancy.png?country=~CHN

    The query string is preserved, so the thumbnail reflects the selected
    countries/years of the chart view being linked to.

    Raises:
        ValueError: if the URL is not an ourworldindata.org Grapher URL.
    """
    # Validate with an explicit exception: `assert` is stripped under `python -O`.
    if not grapher_url.startswith("https://ourworldindata.org/grapher/"):
        raise ValueError(f"not a Grapher URL: {grapher_url}")

    parts = parse.urlparse(grapher_url)

    # Same scheme and host; only the path changes: /grapher/<slug> -> /grapher/thumbnail/<slug>.png
    return f"{parts.scheme}://{parts.netloc}/grapher/thumbnail/{Path(parts.path).name}.png?{parts.query}"


def get_grapher_thumbnail(grapher_url: str) -> str:
    """Download a chart's thumbnail PNG and return it as a base64 data URI."""
    png_bytes = requests.get(get_thumbnail_url(grapher_url)).content
    encoded = base64.b64encode(png_bytes).decode('utf8')
    return f"data:image/png;base64,{encoded}"


def fetch_chart_data(conn, slug: str) -> pd.DataFrame:
    """Load the data behind the published chart with the given slug.

    The chart config is read from the database; the indicator data and
    metadata come from the public API.
    """
    return fetch_data(conn, fetch_config(conn, slug))


def list_charts(conn) -> list[str]:
    """Return the slugs of all published charts, sorted alphabetically."""
    query = """
        SELECT DISTINCT slug
        FROM chart_configs
        WHERE
            JSON_EXTRACT(full, '$.isPublished')
            AND slug IS NOT NULL
        ORDER BY slug
        """
    with conn.cursor() as cur:
        cur.execute(query)
        rows = cur.fetchall()
    # each row is a 1-tuple (slug,)
    return [row[0] for row in rows]


def fetch_config(conn, slug: str) -> dict:
    """Fetch the full config of the published chart with the given slug.

    Raises:
        ValueError: if no published chart has that slug.
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            SELECT full
            FROM chart_configs
            WHERE
                slug = %s
                AND JSON_EXTRACT(full, '$.isPublished')
            """,
            (slug,),
        )
        row = cur.fetchone()

    # Check the row before subscripting: the old `cur.fetchone()[0]` raised a
    # TypeError for a missing chart, so the ValueError below was unreachable.
    if row is None:
        raise ValueError(f"No published chart with slug {slug}")

    return json.loads(row[0])


def fetch_data(conn, config: dict) -> pd.DataFrame:
    """Fetch and merge the data for every indicator the chart config plots.

    NOTE(review): `conn` is unused here — data and metadata come from the
    public API via `_fetch_dimension`.
    """
    variable_ids = {dim["variableId"] for dim in config["dimensions"]}
    bundle = {variable_id: _fetch_dimension(variable_id) for variable_id in variable_ids}
    return _bundle_to_frame(config, bundle)


def _indicator_to_frame(indicator: dict) -> pd.DataFrame:
    """Convert one indicator's API payload into a tidy data frame.

    The result has an `entities` column holding entity names, a `years`
    column (converted to dates for daily indicators) and one value column
    named after the indicator's short name.
    """
    metadata = indicator["metadata"]

    # the data payload maps directly onto a frame
    frame = pd.DataFrame.from_dict(indicator["data"])

    # replace entity ids with human-readable entity names
    entity_table = pd.DataFrame.from_records(metadata["dimensions"]["entities"]["values"])
    names_by_id = entity_table.set_index("id").name.to_dict()
    frame["entities"] = frame["entities"].apply(names_by_id.__getitem__)

    # name the value column after the indicator instead of the generic "values"
    value_name = metadata.get("shortName", f'_{metadata["id"]}')
    frame = frame.rename(columns={"values": value_name})

    # daily indicators encode time as day offsets; turn those into dates
    if _detect_time_col_type(metadata) == "dates":
        frame["years"] = _convert_years_to_dates(metadata, frame["years"])

    # identifier columns first, then value columns in sorted order
    ordered = ["entities", "years"] + sorted(frame.columns.difference(["entities", "years"]))
    return frame[ordered]


def _detect_time_col_type(metadata) -> Literal["dates", "years"]:
if metadata.get("display", {}).get("yearIsDay"):
return "dates"

return "years"


def _convert_years_to_dates(metadata, years):
    """Convert day offsets into dates, anchored at the indicator's `zeroDay`."""
    zero_day = date_parse(metadata["display"]["zeroDay"])
    return years.apply(lambda offset: zero_day + dt.timedelta(days=offset))


def _fetch_dimension(id: int) -> dict:
data = requests.get(f"https://api.ourworldindata.org/v1/indicators/{id}.data.json").json()
metadata = requests.get(f"https://api.ourworldindata.org/v1/indicators/{id}.metadata.json").json()
return {"data": data, "metadata": metadata}


def _bundle_to_frame(config, bundle) -> pd.DataFrame:
    """Merge per-indicator frames into one chart-level frame.

    `bundle` maps variable id -> {"data": ..., "metadata": ...} as returned by
    `_fetch_dimension`. The merged frame carries the chart slug, URL, config
    and per-column metadata in `df.attrs`.
    """
    # combine all the indicators into a single data frame and one metadata dict
    metadata = {}
    df = None
    for dim in bundle.values():
        to_merge = _indicator_to_frame(dim)
        # exactly one non-key column per indicator; unpacking enforces that
        (value_col,) = to_merge.columns.difference(["entities", "years"])
        metadata[value_col] = dim["metadata"].copy()

        if df is None:
            df = to_merge
        else:
            # outer merge keeps rows where indicators cover different entities/years
            df = pd.merge(df, to_merge, how="outer", on=["entities", "years"])

    assert df is not None

    # save some useful metadata onto the frame
    assert config
    slug = config["slug"]
    df.attrs["slug"] = slug
    df.attrs["url"] = f"https://ourworldindata.org/grapher/{slug}"
    df.attrs["metadata"] = metadata
    df.attrs["config"] = config

    # if there is only one indicator, we can use the slug as the column name
    if len(df.columns) == 3:
        assert config
        (value_col,) = df.columns.difference(["entities", "years"])
        short_name = slug.replace("-", "_")
        df = df.rename(columns={value_col: short_name})
        # keep the metadata dict keyed consistently with the renamed column
        df.attrs["metadata"][short_name] = df.attrs["metadata"].pop(value_col)
        df.attrs["value_col"] = short_name

    # we kept using "years" until now to keep the code paths the same, but they could
    # be dates
    # NOTE(review): datetime values stringify as "YYYY-MM-DD HH:MM:SS", which this
    # end-anchored pattern would not match — confirm the rename actually triggers
    # for yearIsDay charts.
    if df["years"].astype(str).str.match(r"^\d{4}-\d{2}-\d{2}$").all():
        df = df.rename(columns={"years": "dates"})

    return df

0 comments on commit 45483e5

Please sign in to comment.