Skip to content

Commit

Permalink
Bug/fix minor bugs (#69)
Browse files: browse the repository at this point in the history
  • Loading branch information
ErikBavenstrand committed Jun 30, 2023
2 parents 74315f4 + 588a016 commit b1607d2
Show file tree
Hide file tree
Showing 16 changed files with 377 additions and 721 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ jobs:
include:
- { python: "3.10", os: "ubuntu-latest", session: "pre-commit" }
- { python: "3.10", os: "ubuntu-latest", session: "safety" }
- { python: "3.10", os: "ubuntu-latest", session: "mypy" }
- { python: "3.9", os: "ubuntu-latest", session: "mypy" }
- { python: "3.8", os: "ubuntu-latest", session: "mypy" }
- { python: "3.10", os: "ubuntu-latest", session: "pyright" }
- { python: "3.9", os: "ubuntu-latest", session: "pyright" }
- { python: "3.8", os: "ubuntu-latest", session: "pyright" }
- { python: "3.10", os: "ubuntu-latest", session: "tests" }
- { python: "3.9", os: "ubuntu-latest", session: "tests" }
- { python: "3.8", os: "ubuntu-latest", session: "tests" }
Expand Down
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
.mypy_cache/
/.coverage
/.coverage.*
/.nox/
Expand Down
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"config",
"cache",
"pipeline",
"data cleaning",
"data splitting",
"feature selection",
"semantic versioning",
Expand Down
6 changes: 3 additions & 3 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ Request features on the [Issue Tracker].

You need Python 3.8+ and the following tools:

- [Poetry]
- [Nox]
- [nox-poetry]
- [poetry](https://python-poetry.org/)
- [nox](https://nox.thea.codes/)
- [nox-poetry](https://nox-poetry.readthedocs.io/)

Install the package with development requirements:

Expand Down
6 changes: 6 additions & 0 deletions mleko/dataset/convert/csv_to_vaex_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import repeat
from pathlib import Path

import pyarrow as pa
import vaex
from pyarrow import csv as arrow_csv
from tqdm.auto import tqdm
Expand All @@ -14,6 +15,7 @@
from mleko.utils.custom_logger import CustomLogger
from mleko.utils.decorators import auto_repr
from mleko.utils.file_helpers import clear_directory
from mleko.utils.vaex_helpers import get_column

from .base_converter import BaseConverter

Expand Down Expand Up @@ -222,6 +224,10 @@ def _convert_csv_file_to_arrow(
),
).drop(drop_columns)

for column_name in df_chunk.get_column_names():
if get_column(df_chunk, column_name).dtype in (pa.date32(), pa.date64()):
df_chunk[column_name] = get_column(df_chunk, column_name).astype("datetime64[s]")

output_path = output_directory / f"df_chunk_{file_path.stem}.{dataframe_suffix}"
df_chunk.export(output_path, chunk_size=100_000, parallel=False)
df_chunk.close()
Expand Down
42 changes: 21 additions & 21 deletions mleko/dataset/ingest/s3_ingester.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
aws_profile_name: str | None = None,
aws_region_name: str = "eu-west-1",
num_workers: int = 64,
manifest_file_name: str = "manifest",
manifest_file_name: str | None = "manifest",
check_s3_timestamps: bool = True,
) -> None:
"""Initializes the S3 bucket client, configures the destination directory, and sets client-related parameters.
Expand Down Expand Up @@ -111,27 +111,27 @@ def fetch_data(self, force_recompute: bool = False) -> list[Path]:
raise Exception(
"Files in S3 are from muliples dates. This might mean the data is corrupted/duplicated."
)

manifest_file_key = next(
entry["Key"]
for entry in resp["Contents"]
if "Key" in entry and entry["Key"].endswith(self._manifest_file_name)
)

if not force_recompute and manifest_file_key:
self._s3_client.download_file(
Bucket=self._s3_bucket_name,
Key=manifest_file_key,
Filename=str(self._destination_directory / self._manifest_file_name),
if self._manifest_file_name is not None:
manifest_file_key = next(
entry["Key"]
for entry in resp["Contents"]
if "Key" in entry and entry["Key"].endswith(self._manifest_file_name)
)
with open(self._destination_directory / self._manifest_file_name) as f:
manifest: dict[str, Any] = json.load(f)
if self._is_local_dataset_fresh(manifest):
logger.info(
"\033[32mCache Hit\033[0m: Local dataset is up to date with S3 bucket contents, "
"skipping download."
)
return self._get_local_filenames(["gz", "csv", "zip"])

if not force_recompute and manifest_file_key:
self._s3_client.download_file(
Bucket=self._s3_bucket_name,
Key=manifest_file_key,
Filename=str(self._destination_directory / self._manifest_file_name),
)
with open(self._destination_directory / self._manifest_file_name) as f:
manifest: dict[str, Any] = json.load(f)
if self._is_local_dataset_fresh(manifest):
logger.info(
"\033[32mCache Hit\033[0m: Local dataset is up to date with S3 bucket contents, "
"skipping download."
)
return self._get_local_filenames(["gz", "csv", "zip"])

if force_recompute:
logger.info(
Expand Down
2 changes: 1 addition & 1 deletion mleko/dataset/split/expression_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def split(
Returns:
A tuple containing the split dataframes.
"""
return self._cached_execute( # type: ignore
return self._cached_execute(
lambda_func=lambda: self._split(dataframe),
cache_keys=[
self._expression,
Expand Down
13 changes: 7 additions & 6 deletions mleko/dataset/split/random_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def split(
Returns:
A tuple containing the split dataframes.
"""
return self._cached_execute( # type: ignore
return self._cached_execute(
lambda_func=lambda: self._split(dataframe),
cache_keys=[
self._idx2_size,
Expand All @@ -118,9 +118,10 @@ def _split(self, dataframe: vaex.DataFrame) -> tuple[vaex.DataFrame, vaex.DataFr
A tuple containing the split dataframes.
"""
index_name = "index"
dataframe[index_name] = vaex.vrange(0, dataframe.shape[0])
index = get_column(dataframe, index_name)
target = get_column(dataframe, self._stratify).to_numpy() if self._stratify else None
df = dataframe.copy()
df[index_name] = vaex.vrange(0, df.shape[0])
index = get_column(df, index_name)
target = get_column(df, self._stratify).to_numpy() if self._stratify else None

if self._shuffle:
logger.info("Shuffling data before splitting.")
Expand All @@ -137,8 +138,8 @@ def _split(self, dataframe: vaex.DataFrame) -> tuple[vaex.DataFrame, vaex.DataFr
stratify=target,
)

df1 = get_filtered_df(dataframe, index.isin(idx1)).extract()
df2 = get_filtered_df(dataframe, index.isin(idx2)).extract()
df1 = get_filtered_df(df, index.isin(idx1)).extract()
df2 = get_filtered_df(df, index.isin(idx2)).extract()
logger.info(f"Split dataframe into two dataframes with shapes {df1.shape} and {df2.shape}.")
df1.delete_virtual_column(index_name)
df2.delete_virtual_column(index_name)
Expand Down
31 changes: 10 additions & 21 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,8 @@
from pathlib import Path
from textwrap import dedent

import nox


try:
from nox_poetry import Session, session
except ImportError:
message = f"""\
Nox failed to import the 'nox-poetry' package.
Please install it using the following command:
{sys.executable} -m pip install nox-poetry"""
raise SystemExit(dedent(message)) from None
import nox # type: ignore
from nox_poetry import Session, session # type: ignore


package = "mleko"
Expand All @@ -27,7 +16,7 @@
nox.options.sessions = (
"pre-commit",
"safety",
"mypy",
"pyright",
"tests",
"typeguard",
"docs-build",
Expand Down Expand Up @@ -141,21 +130,21 @@ def safety(session: Session) -> None:


@session(python=python_versions)
def mypy(session: Session) -> None:
"""Type-check using mypy."""
def pyright(session: Session) -> None:
"""Type-check using pyright."""
args = session.posargs or ["mleko", "docs/conf.py"]
session.install(".")
session.install("mypy", "pytest")
session.run("mypy", *args)
session.install("pyright", "pytest")
session.run("pyright", *args)
if not session.posargs:
session.run("mypy", f"--python-executable={sys.executable}", "noxfile.py")
session.run("pyright", f"--pythonpath={sys.executable}", "noxfile.py")


@session(python=python_versions)
def tests(session: Session) -> None:
"""Run the test suite."""
session.install(".")
session.install("coverage[toml]", "pytest", "pytest-mock", "pygments", "moto", "mypy_boto3_s3")
session.install("coverage[toml]", "pytest", "pytest-mock", "pygments", "moto")
try:
session.run(
"coverage", "run", "--parallel", "--concurrency=multiprocessing,thread", "-m", "pytest", *session.posargs
Expand Down Expand Up @@ -183,7 +172,7 @@ def coverage(session: Session) -> None:
def typeguard(session: Session) -> None:
"""Runtime type checking using Typeguard."""
session.install(".")
session.install("pytest", "typeguard", "pygments", "moto", "mypy_boto3_s3")
session.install("pytest", "typeguard", "pygments", "moto")
session.run("pytest", f"--typeguard-packages={package}", *session.posargs)


Expand Down
Loading

0 comments on commit b1607d2

Please sign in to comment.