🎉 Switch to pyarrow dtypes
Marigold committed Nov 5, 2024
1 parent 07061d1 commit 9ec8b6c
Showing 14 changed files with 141 additions and 30 deletions.
4 changes: 2 additions & 2 deletions etl/data_helpers/geo.py
@@ -601,7 +601,7 @@ def _add_population_to_dataframe(

     # Load population data.
     if ds_population is not None:
-        population = ds_population.read_table("population")
+        population = ds_population.read_table("population", safe_types=False)
     else:
         population = _load_population()
     population = population.rename(
@@ -1363,7 +1363,7 @@ def make_table_population_daily(ds_population: Dataset, year_min: int, year_max:
     Uses linear interpolation.
     """
     # Load population table
-    population = ds_population.read_table("population")
+    population = ds_population.read_table("population", safe_types=False)
     # Filter only years of interest
     population = population[(population["year"] >= year_min) & (population["year"] <= year_max)]
     # Create date column
2 changes: 1 addition & 1 deletion etl/data_helpers/population.py
@@ -76,7 +76,7 @@ def add_population(
         log.warning(f"Dataset {ds_un_wpp_path} is silently being loaded.")
     # Load granular population dataset
     ds_un_wpp = Dataset(ds_un_wpp_path)
-    pop = ds_un_wpp.read_table("population_granular")  # type: ignore
+    pop = ds_un_wpp.read_table("population_granular", safe_types=False)  # type: ignore
     # Keep only variant='medium'
     pop = pop[pop["variant"] == "medium"].drop(columns=["variant"])
     # Keep only metric='population'
2 changes: 1 addition & 1 deletion etl/grapher_io.py
@@ -484,7 +484,7 @@ def variable_data_table_from_catalog(
     tbs = []
     for (ds_path, table_name), variables in to_read.items():
         try:
-            tb = Dataset(DATA_DIR / ds_path).read_table(table_name)
+            tb = Dataset(DATA_DIR / ds_path).read_table(table_name, safe_types=False)
         except FileNotFoundError as e:
             raise FileNotFoundError(f"Dataset {ds_path} not found in local catalog.") from e

@@ -14,7 +14,7 @@ def run(dest_dir: str) -> None:

     # read dataset from meadow
     ds_meadow = paths.load_dataset("cherry_blossom")
-    tb = ds_meadow["cherry_blossom"].reset_index()
+    tb = ds_meadow.read_table("cherry_blossom")

     # Calculate a 20,40 and 50 year average
     tb = calculate_multiple_year_average(tb)
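The rewrite above leans on read_table's default reset_index=True (introduced in lib/catalog/owid/catalog/datasets.py later in this commit), so the manual .reset_index() disappears. A sketch under the assumption of a hypothetical meadow dataset path:

```python
from owid.catalog import Dataset

# Hypothetical path, standing in for paths.load_dataset("cherry_blossom").
ds_meadow = Dataset("data/meadow/biodiversity/latest/cherry_blossom")

# Old: dictionary access keeps the primary key, so the index had to be reset by hand.
tb_old = ds_meadow["cherry_blossom"].reset_index()

# New: read_table resets the index by default (and applies safe types unless disabled).
tb_new = ds_meadow.read_table("cherry_blossom")
```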
18 changes: 10 additions & 8 deletions etl/steps/data/garden/covid/latest/cases_deaths.py
@@ -175,15 +175,15 @@ def discard_rows(tb: Table):
     print("Discarding rows…")
     # For all rows where new_cases or new_deaths is negative, we keep the cumulative value but set
     # the daily change to NA. This also sets the 7-day rolling average to NA for the next 7 days.
-    tb.loc[tb["new_cases"] < 0, "new_cases"] = np.nan
-    tb.loc[tb["new_deaths"] < 0, "new_deaths"] = np.nan
+    tb.loc[tb["new_cases"] < 0, "new_cases"] = pd.NA
+    tb.loc[tb["new_deaths"] < 0, "new_deaths"] = pd.NA

     # Custom data corrections
     for ldc in LARGE_DATA_CORRECTIONS:
-        tb.loc[(tb["country"] == ldc[0]) & (tb["date"].astype(str) == ldc[1]), f"new_{ldc[2]}"] = np.nan
+        tb.loc[(tb["country"] == ldc[0]) & (tb["date"].astype(str) == ldc[1]), f"new_{ldc[2]}"] = pd.NA

     for ldc in LARGE_DATA_CORRECTIONS_SINCE:
-        tb.loc[(tb["country"] == ldc[0]) & (tb["date"].astype(str) >= ldc[1]), f"new_{ldc[2]}"] = np.nan
+        tb.loc[(tb["country"] == ldc[0]) & (tb["date"].astype(str) >= ldc[1]), f"new_{ldc[2]}"] = pd.NA

     # Sort (legacy)
     tb = tb.sort_values(["country", "date"])
@@ -216,8 +216,8 @@ def add_period_aggregates(tb: Table, prefix: str, periods: int):
     )

     # Set NaNs where the original data was NaN
-    tb.loc[tb["new_cases"].isnull(), cases_colname] = np.nan
-    tb.loc[tb["new_deaths"].isnull(), deaths_colname] = np.nan
+    tb.loc[tb["new_cases"].isnull(), cases_colname] = pd.NA
+    tb.loc[tb["new_deaths"].isnull(), deaths_colname] = pd.NA

     return tb

@@ -247,7 +247,7 @@ def add_doubling_days(tb: Table) -> Table:
     for col, spec in DOUBLING_DAYS_SPEC.items():
         value_col = spec["value_col"]
         periods = spec["periods"]
-        tb.loc[tb[value_col] == 0, value_col] = np.nan
+        tb.loc[tb[value_col] == 0, value_col] = pd.NA
         tb[col] = (
             tb.groupby("country", as_index=False)[value_col]
             .pct_change(periods=periods, fill_method=None)
@@ -336,6 +336,8 @@ def _apply_row_cfr_100(row):
         return pd.NA

     tb["cfr"] = 100 * tb["total_deaths"] / tb["total_cases"]
+    # 0/0 returns np.nan and not pd.NA which would be more natural for Float64
+    tb["cfr"] = tb["cfr"].mask(np.isnan(tb["cfr"]), pd.NA)
     tb["cfr_100_cases"] = tb.apply(_apply_row_cfr_100, axis=1)
     tb["cfr_100_cases"] = tb["cfr_100_cases"].copy_metadata(tb["cfr"])

@@ -348,7 +350,7 @@ def _apply_row_cfr_100(row):
     tb.loc[
         (tb["cfr_short_term"] < 0) | (tb["cfr_short_term"] > 10) | (tb["date"].astype(str) < "2020-09-01"),
         "cfr_short_term",
-    ] = np.nan
+    ] = pd.NA

     # Replace inf
     cols = [
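The cfr lines above are the subtle part of this change: with the nullable Float64 dtype, a 0/0 division stores a float NaN in the values rather than a masked pd.NA, which is why the extra mask(np.isnan(...), pd.NA) step is added. A minimal sketch of that normalisation, using made-up numbers rather than the real COVID columns:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for total_deaths / total_cases on a nullable-dtype table.
deaths = pd.Series([0.0, 5.0], dtype="Float64")
cases = pd.Series([0.0, 100.0], dtype="Float64")

cfr = 100 * deaths / cases  # 0/0 leaves a float NaN inside the Float64 result

# Mirror the fix above: turn any float NaN into a proper pd.NA.
cfr = cfr.mask(np.isnan(cfr), pd.NA)

assert cfr.isna().iloc[0]   # the 0/0 entry is now genuinely missing
assert cfr.iloc[1] == 5.0   # 100 * 5 / 100
```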
@@ -164,8 +164,7 @@ def run(dest_dir: str) -> None:
     tb = pr.merge(tb, tb_strategy, on=["country", "year"], how="outer", short_name=paths.short_name)

     # Fill nan in type_of_strategy with Not applicable
-    tb["type_of_strategy"] = tb["type_of_strategy"].astype(str)
-    tb.loc[tb["type_of_strategy"] == "nan", "type_of_strategy"] = "Not applicable"
+    tb["type_of_strategy"] = tb["type_of_strategy"].astype("string").fillna("Not applicable")

     tb = tb.format(["country", "year"])

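For context on the one-liner above: with nullable string/categorical dtypes, missing entries are pd.NA and would stringify as "<NA>" rather than "nan", so the old cast-to-str-and-compare pattern stops matching and a plain fillna does the job. A small sketch with invented strategy labels (not values from the dataset):

```python
import pandas as pd

# Hypothetical labels; the real column comes from the merged garden table.
type_of_strategy = pd.Series(["national strategy", pd.NA], dtype="string")

# New pattern: stay in the nullable string dtype and fill the genuine NAs.
filled = type_of_strategy.astype("string").fillna("Not applicable")

assert filled.tolist() == ["national strategy", "Not applicable"]
```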
2 changes: 1 addition & 1 deletion etl/steps/data/grapher/un/2024-08-27/un_sdg.py
@@ -51,7 +51,7 @@ def run(dest_dir: str) -> None:

         log.info("un_sdg.process", table_name=var)

-        tb = ds_garden.read_table(var)
+        tb = ds_garden.read_table(var, safe_types=False)

         tb = create_table(tb)

13 changes: 9 additions & 4 deletions lib/catalog/owid/catalog/datasets.py
@@ -11,13 +11,15 @@
 from os import environ
 from os.path import join
 from pathlib import Path
-from typing import Any, Dict, Iterator, List, Literal, Optional, Union
+from typing import Any, Dict, Iterator, List, Literal, Optional, Union, cast

 import numpy as np
 import pandas as pd
 import yaml
 from _hashlib import HASH

+from owid.repack import to_safe_types
+
 from . import tables, utils
 from .meta import SOURCE_EXISTS_OPTIONS, DatasetMeta, TableMeta
 from .processing_log import disable_processing_log
@@ -153,27 +155,30 @@ def add(
         table_filename = join(self.path, table.metadata.checked_name + f".{format}")
         table.to(table_filename, repack=repack)

-    def read_table(self, name: str, reset_index: bool = True) -> tables.Table:
+    def read_table(self, name: str, reset_index: bool = True, safe_types: bool = True) -> tables.Table:
         """Read dataset's table from disk. Alternative to ds[table_name], but
         with more options to optimize the reading.
         :param reset_index: If true, don't set primary keys of the table. This can make loading
             large datasets with multi-indexes much faster.
+        :param safe_types: If true, convert numeric columns to Float64 and Int64 and categorical
+            columns to string[python]. This can significantly increase memory usage.
         """
         stem = self.path / Path(name)

         for format in SUPPORTED_FORMATS:
             path = stem.with_suffix(f".{format}")
             if path.exists():
                 t = tables.Table.read(path, primary_key=[] if reset_index else None)
                 # dataset metadata might have been updated, refresh it
                 t.metadata.dataset = self.metadata
+                if safe_types:
+                    t = cast(tables.Table, to_safe_types(t))
                 return t

         raise KeyError(f"Table `{name}` not found, available tables: {', '.join(self.table_names)}")

     def __getitem__(self, name: str) -> tables.Table:
-        return self.read_table(name, reset_index=False)
+        return self.read_table(name, reset_index=False, safe_types=False)

     def __contains__(self, name: str) -> bool:
         return any((Path(self.path) / name).with_suffix(f".{format}").exists() for format in SUPPORTED_FORMATS)
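A hedged sketch of how the new flag is meant to be used from a step; the dataset path and table name here are hypothetical, while the behaviour follows the read_table signature and the __getitem__ change above:

```python
from owid.catalog import Dataset

# Hypothetical local catalog path and table name, for illustration only.
ds = Dataset("data/garden/demography/latest/population")

# Default: numeric columns are converted to Float64/Int64 and categoricals to
# string[python] ("safe" types), at the cost of extra memory.
tb = ds.read_table("population")

# The ETL call sites earlier in this commit opt out to keep dtypes as stored on disk.
tb_raw = ds.read_table("population", safe_types=False)

# Dictionary-style access keeps the primary key and also skips the conversion,
# matching the updated __getitem__.
tb_indexed = ds["population"]
```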
3 changes: 2 additions & 1 deletion lib/catalog/pyproject.toml
@@ -17,12 +17,12 @@ dependencies = [
     "Unidecode>=1.3.4",
     "PyYAML>=6.0.1",
     "structlog>=21.5.0",
-    "owid-repack>=0.1.1",
     "dynamic-yaml>=1.3.5",
     "mistune>=3.0.1",
     "dataclasses-json==0.5.8",
     "rdata==0.9",
     "owid-datautils",
+    "owid-repack",
 ]

 [tool.uv]
@@ -38,6 +38,7 @@ dev-dependencies = [

 [tool.uv.sources]
 owid-datautils = { path = "../datautils", editable = true }
+owid-repack = { path = "../repack", editable = true }

 [tool.ruff]
 extend = "../../pyproject.toml"
23 changes: 16 additions & 7 deletions lib/catalog/uv.lock

Some generated files are not rendered by default.

25 changes: 25 additions & 0 deletions lib/repack/owid/repack/__init__.py
@@ -152,3 +152,28 @@ def series_eq(lhs: pd.Series, rhs: pd.Series, cast: Any, rtol: float = 1e-5, ato
         func = lambda s: s.apply(cast)  # noqa: E731

     return np.allclose(func(lhs), func(rhs), rtol=rtol, atol=atol, equal_nan=True)
+
+
+def _safe_dtype(dtype: Any) -> str:
+    """Determine the appropriate dtype string based on pandas dtype."""
+    if pd.api.types.is_integer_dtype(dtype):
+        return "Int64"
+    elif pd.api.types.is_float_dtype(dtype):
+        return "Float64"
+    elif isinstance(dtype, pd.CategoricalDtype):
+        return "string[python]"
+    else:
+        return dtype
+
+
+def to_safe_types(t: pd.DataFrame) -> pd.DataFrame:
+    """Convert numeric columns to Float64 and Int64 and categorical
+    columns to string[python]. This can significantly increase memory usage."""
+    t = t.astype({col: _safe_dtype(t[col].dtype) for col in t.columns})
+
+    if isinstance(t.index, pd.MultiIndex):
+        t.index = t.index.set_levels([level.astype(_safe_dtype(level.dtype)) for level in t.index.levels])
+    else:
+        t.index = t.index.astype(_safe_dtype(t.index.dtype))
+
+    return t
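A quick illustration of what to_safe_types does to a small frame (column names invented; the tests added below cover the same behaviour, including MultiIndex and NaN handling):

```python
import pandas as pd
from owid.repack import to_safe_types

df = pd.DataFrame(
    {
        "year": [2019, 2020],  # int64 -> Int64
        "value": [1.5, 2.5],  # float64 -> Float64
        "region": pd.Categorical(["Europe", "Asia"]),  # category -> string[python]
    }
).set_index("year")

safe = to_safe_types(df)

assert safe.index.dtype == "Int64"
assert safe["value"].dtype == "Float64"
assert safe["region"].dtype == "string[python]"
```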
70 changes: 70 additions & 0 deletions lib/repack/tests/test_repack.py
@@ -253,3 +253,73 @@ def test_repack_string_type():

     v = repack.repack_series(s)
     assert v.dtype == "category"
+
+
+def test_to_safe_types():
+    # Create a DataFrame with various dtypes
+    df = pd.DataFrame(
+        {
+            "int_col": [1, 2, 3],
+            "float_col": [1.1, 2.2, 3.3],
+            "cat_col": pd.Categorical(["a", "b", "c"]),
+            "object_col": ["x", "y", "z"],
+        }
+    )
+
+    # Set an index with integer dtype
+    df.set_index("int_col", inplace=True)
+
+    # Apply the to_safe_types function
+    df_safe = repack.to_safe_types(df)
+
+    # Check that the dtypes have been converted appropriately
+    assert df_safe.index.dtype == "Int64"
+    assert df_safe["float_col"].dtype == "Float64"
+    assert df_safe["cat_col"].dtype == "string[python]"
+    # 'object_col' should remain unchanged
+    assert df_safe["object_col"].dtype == "object"
+
+
+def test_to_safe_types_multiindex():
+    # Create a DataFrame with MultiIndex
+    df = pd.DataFrame(
+        {
+            "int_col": [1, 2, 3],
+            "cat_col": pd.Categorical(["a", "b", "c"]),
+            "float_col": [1.1, 2.2, 3.3],
+        }
+    )
+    df.set_index(["int_col", "cat_col"], inplace=True)
+
+    # Apply the to_safe_types function
+    df_safe = repack.to_safe_types(df)
+
+    # Check index levels
+    assert df_safe.index.levels[0].dtype == "Int64"  # type: ignore
+    assert df_safe.index.levels[1].dtype == "string[python]"  # type: ignore
+    # Check column dtype
+    assert df_safe["float_col"].dtype == "Float64"
+
+
+def test_to_safe_types_with_nan():
+    # Create a DataFrame with NaN values
+    df = pd.DataFrame(
+        {
+            "int_col": [1, 2, 3],
+            "float_col": [1.1, np.nan, 3.3],
+            "cat_col": pd.Categorical(["a", None, "c"]),
+        }
+    )
+    df.set_index("float_col", inplace=True)
+
+    # Apply the to_safe_types function
+    df_safe = repack.to_safe_types(df)
+
+    # Check that NaN values are handled correctly
+    assert df_safe.index.dtype == "Float64"
+    assert df_safe["int_col"].dtype == "Int64"
+    assert df_safe["cat_col"].dtype == "string[python]"
+
+    # Ensure that the NA value in 'cat_col' remains pd.NA and not the string "NA"
+    assert pd.isna(df_safe["cat_col"].iloc[1])
+    assert df_safe["cat_col"].iloc[1] is pd.NA
2 changes: 1 addition & 1 deletion lib/repack/uv.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default.
