Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Marigold committed Nov 5, 2024
1 parent b21d8fd commit c8c2a20
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def run(dest_dir: str) -> None:
log.info("cherry_blossom.start")

# read dataset from meadow
# Read dataset from meadow.
ds_meadow = paths.load_dataset("cherry_blossom")
tb = ds_meadow.read_table("cherry_blossom")

Expand Down
39 changes: 20 additions & 19 deletions lib/repack/owid/repack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pandas as pd
import pyarrow


def repack_frame(
Expand Down Expand Up @@ -59,10 +60,10 @@ def repack_frame(


def repack_series(s: pd.Series) -> pd.Series:
if s.dtype.name in ("Int64", "int64", "UInt64", "uint64"):
if s.dtype.name.replace("[pyarrow]", "") in ("Int64", "int64", "UInt64", "uint64"):
return shrink_integer(s)

if s.dtype.name in ("object", "string", "float64", "Float64"):
if s.dtype.name.replace("[pyarrow]", "") in ("object", "string", "float64", "Float64"):
for strategy in [to_int, to_float, to_category]:
try:
return strategy(s)
Expand All @@ -74,9 +75,9 @@ def repack_series(s: pd.Series) -> pd.Series:

def to_int(s: pd.Series) -> pd.Series:
# values could be integers or strings
v = s.astype("float64").astype("Int64")
v = s.astype("float64").astype("int64[pyarrow]")

if not series_eq(v, s, cast=float):
if not series_eq(v, s, cast="float64[pyarrow]"):
raise ValueError()

# it's an integer, now pack it smaller
Expand All @@ -85,26 +86,25 @@ def to_int(s: pd.Series) -> pd.Series:

def shrink_integer(s: pd.Series) -> pd.Series:
"""
Take an Int64 series and make it as small as possible.
Take an int64[pyarrow] series and make it as small as possible.
"""
assert s.dtype.name in ("Int64", "int64", "UInt64", "uint64")
assert s.dtype.name.replace("[pyarrow]", "") in ("Int64", "int64", "UInt64", "uint64"), s.dtype

if s.isnull().all():
# shrink all NaNs to Int8
return s.astype("Int8")
elif s.isnull().any():
if s.min() < 0:
series = ["Int32", "Int16", "Int8"]
else:
series = ["UInt32", "UInt16", "UInt8"]
return s.astype("int8[pyarrow]")
else:
if s.min() < 0:
series = ["int32", "int16", "int8"]
series = ["int32[pyarrow]", "int16[pyarrow]", "int8[pyarrow]"]
else:
series = ["uint32", "uint16", "uint8"]
series = ["uint32[pyarrow]", "uint16[pyarrow]", "uint8[pyarrow]"]

for dtype in series:
v = s.astype(dtype)
try:
v = s.astype(dtype)
except pyarrow.lib.ArrowInvalid:
break

if not (v == s).all():
break

Expand All @@ -114,11 +114,11 @@ def shrink_integer(s: pd.Series) -> pd.Series:


def to_float(s: pd.Series) -> pd.Series:
    """Cast a series to the smallest pyarrow float dtype that preserves its values.

    Tries float32 first, then float64. A candidate dtype is accepted only if
    the round-tripped values still compare equal to the original (checked via
    `series_eq` with a float64[pyarrow] comparison cast).

    :param s: series to shrink; values must be losslessly castable to float.
    :return: the series cast to float32[pyarrow] or float64[pyarrow].
    :raises ValueError: if no candidate float dtype preserves the values.
    """
    # NOTE: the diff paste contained both pre- and post-commit lines interleaved
    # (duplicate `options` assignment and two stacked `if` headers — a syntax
    # error); this is the post-commit (pyarrow-backed) version.
    options = ["float32[pyarrow]", "float64[pyarrow]"]
    for dtype in options:
        v = s.astype(dtype)

        # Accept the narrower dtype only when no precision was lost.
        if series_eq(s, v, "float64[pyarrow]"):
            return v

    raise ValueError()
Expand All @@ -145,9 +145,10 @@ def series_eq(lhs: pd.Series, rhs: pd.Series, cast: Any, rtol: float = 1e-5, ato
return False

# improve performance by calling native astype method
if cast == float:
func = lambda s: s.astype(float) # noqa: E731
if cast == "float64[pyarrow]":
func = lambda s: s.astype(cast) # noqa: E731
else:
raise NotImplementedError()
# NOTE: this would be extremely slow in practice
func = lambda s: s.apply(cast) # noqa: E731

Expand Down

0 comments on commit c8c2a20

Please sign in to comment.