convert directly to spark dataframe from download #113

Merged: merged 2 commits on Jul 14, 2023
42 changes: 34 additions & 8 deletions cleanlab_studio/internal/api/api.py
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import Any, Callable, List, Optional, Tuple, Dict
+from typing import Callable, List, Optional, Tuple, Union, Any
 from cleanlab_studio.errors import APIError
 
 import requests
@@ -9,6 +9,13 @@
 import numpy as np
 import numpy.typing as npt
 
+try:
+    import pyspark.sql
+
+    pyspark_exists = True
+except ImportError:
+    pyspark_exists = False
+
 from cleanlab_studio.internal.types import JSONDict
 from cleanlab_studio.version import __version__

@@ -155,25 +162,44 @@ def get_label_column_of_project(api_key: str, project_id: str) -> str:
     return label_column
 
 
-def download_cleanlab_columns(api_key: str, cleanset_id: str, all: bool = False) -> pd.DataFrame:
+def download_cleanlab_columns(
+    api_key: str,
+    cleanset_id: str,
+    all: bool = True,
+    to_spark: bool = False,
+) -> Any:
     """
     Download all rows from specified Cleanlab columns
 
     :param api_key:
     :param cleanset_id:
     :param all: whether to download all Cleanlab columns or just the clean_label column
-    :return: return (rows, id_column)
+    :return: return a dataframe, either pandas or spark. Type is Any because don't want to require spark installed
     """
     res = requests.get(
-        cli_base_url + f"/cleansets/{cleanset_id}/columns?all={all}",
+        cli_base_url + f"/cleansets/{cleanset_id}/columns",
+        params=dict(to_spark=to_spark, all=all),
         headers=_construct_headers(api_key),
     )
     handle_api_error(res)
-    cleanset_json: str = res.json()["cleanset_json"]
-    cleanset_df: pd.DataFrame = pd.read_json(cleanset_json, orient="table")
     id_col = get_id_column(api_key, cleanset_id)
-    cleanset_df.rename(columns={"id": id_col}, inplace=True)
-    return cleanset_df
+    cleanset_json: str = res.json()["cleanset_json"]
+    if to_spark:
+        if not pyspark_exists:
+            raise ImportError(
+                "pyspark is not installed. Please install pyspark to download cleanlab columns as a pyspark DataFrame."
+            )
+        from pyspark.sql import SparkSession
+
+        spark = SparkSession.builder.getOrCreate()
+        rdd = spark.sparkContext.parallelize([cleanset_json])
+        cleanset_pyspark: pyspark.sql.DataFrame = spark.read.json(rdd)
+        cleanset_pyspark = cleanset_pyspark.withColumnRenamed("id", id_col)
+        return cleanset_pyspark
+
+    cleanset_pd: pd.DataFrame = pd.read_json(cleanset_json, orient="table")
+    cleanset_pd.rename(columns={"id": id_col}, inplace=True)
+    return cleanset_pd
 
 
 def download_numpy(api_key: str, cleanset_id: str, name: str) -> npt.NDArray[np.float_]:
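A minimal usage sketch of the new API-level `to_spark` path (not part of the diff; the API key and cleanset ID below are placeholders, and the Spark branch assumes pyspark is installed):

```python
from cleanlab_studio.internal.api import api

# Placeholder values for illustration only.
api_key = "YOUR_API_KEY"
cleanset_id = "YOUR_CLEANSET_ID"

# New Spark path: returns a pyspark.sql.DataFrame built from the cleanset JSON,
# or raises ImportError if pyspark is not available.
spark_df = api.download_cleanlab_columns(api_key, cleanset_id, all=True, to_spark=True)
spark_df.show(5)

# Default path is unchanged and still returns a pandas DataFrame.
pandas_df = api.download_cleanlab_columns(api_key, cleanset_id)
print(pandas_df.head())
```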
50 changes: 29 additions & 21 deletions cleanlab_studio/studio/studio.py
@@ -1,22 +1,19 @@
-from typing import Any, List, Literal, Optional
+from typing import Any, List, Literal, Optional, Union
 
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
 
-try:
-    import pyspark.sql
-
-    pyspark_exists = True
-except ImportError:
-    pyspark_exists = False
-
 from . import clean, upload
 from cleanlab_studio.internal.api import api
 from cleanlab_studio.internal.util import init_dataset_source, check_none, check_not_none
 from cleanlab_studio.internal.settings import CleanlabSettings
 from cleanlab_studio.internal.types import FieldSchemaDict
 
+pyspark_exists = api.pyspark_exists
+if pyspark_exists:
+    import pyspark.sql
+
 
 class Studio:
     _api_key: str
@@ -61,36 +58,46 @@ def download_cleanlab_columns(
         self,
         cleanset_id: str,
         include_action: bool = False,
-    ) -> pd.DataFrame:
-        rows_df: pd.DataFrame = api.download_cleanlab_columns(self._api_key, cleanset_id, all=True)
+        to_spark: bool = False,
+    ) -> Any:
+        """
+        Returns either a pandas or pyspark DataFrame
+        Type Any because don't want to rely on pyspark being installed
+        """
+        rows_df = api.download_cleanlab_columns(
+            self._api_key, cleanset_id, all=True, to_spark=to_spark
+        )
         if not include_action:
-            rows_df.drop("action", inplace=True, axis=1)
+            if to_spark:
+                rows_df = rows_df.drop("action")
+            else:
+                rows_df.drop("action", inplace=True, axis=1)
         return rows_df
 
     def apply_corrections(self, cleanset_id: str, dataset: Any, keep_excluded: bool = False) -> Any:
         project_id = api.get_project_of_cleanset(self._api_key, cleanset_id)
         label_column = api.get_label_column_of_project(self._api_key, project_id)
         id_col = api.get_id_column(self._api_key, cleanset_id)
-        cl_cols = self.download_cleanlab_columns(cleanset_id, include_action=True)
         if pyspark_exists and isinstance(dataset, pyspark.sql.DataFrame):
             from pyspark.sql.functions import udf
 
-            spark = dataset.sparkSession
-            cl_cols_df = spark.createDataFrame(cl_cols)
-            corrected_ds = dataset.alias("corrected_ds")
-            if id_col not in corrected_ds.columns:
+            cl_cols = self.download_cleanlab_columns(
+                cleanset_id, include_action=True, to_spark=True
+            )
+            corrected_ds_spark = dataset.alias("corrected_ds")
+            if id_col not in corrected_ds_spark.columns:
                 from pyspark.sql.functions import (
                     row_number,
                     monotonically_increasing_id,
                 )
                 from pyspark.sql.window import Window
 
-                corrected_ds = corrected_ds.withColumn(
+                corrected_ds_spark = corrected_ds_spark.withColumn(
                     id_col,
                     row_number().over(Window.orderBy(monotonically_increasing_id())) - 1,
                 )
-            both = cl_cols_df.select([id_col, "action", "clean_label"]).join(
-                corrected_ds.select([id_col, label_column]),
+            both = cl_cols.select([id_col, "action", "clean_label"]).join(
+                corrected_ds_spark.select([id_col, label_column]),
                 on=id_col,
                 how="left",
             )
@@ -107,12 +114,13 @@ def apply_corrections(self, cleanset_id: str, dataset: Any, keep_excluded: bool
                 [id_col, "action", "__cleanlab_final_label"]
             ).withColumnRenamed("__cleanlab_final_label", label_column)
             return (
-                corrected_ds.drop(label_column)
+                corrected_ds_spark.drop(label_column)
                 .join(new_labels, on=id_col, how="right")
                 .where(new_labels["action"] != "exclude")
                 .drop("action")
             )
         elif isinstance(dataset, pd.DataFrame):
+            cl_cols = self.download_cleanlab_columns(cleanset_id, include_action=True)
             joined_ds: pd.DataFrame
             if id_col in dataset.columns:
                 joined_ds = dataset.join(cl_cols.set_index(id_col), on=id_col)
@@ -123,7 +131,7 @@ def apply_corrections(self, cleanset_id: str, dataset: Any, keep_excluded: bool
                 dataset[label_column].to_numpy(),
             )
 
-            corrected_ds = dataset.copy()
+            corrected_ds: pd.DataFrame = dataset.copy()
             corrected_ds[label_column] = joined_ds["__cleanlab_final_label"]
             if not keep_excluded:
                 corrected_ds = corrected_ds.loc[(joined_ds["action"] != "exclude").fillna(True)]
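And a sketch of the same feature at the Studio level (the API key, cleanset ID, and dataset path are placeholders; assumes pyspark is installed so the Spark branch of apply_corrections is taken):

```python
from pyspark.sql import SparkSession
from cleanlab_studio import Studio

# Placeholder values for illustration only.
studio = Studio("YOUR_API_KEY")
cleanset_id = "YOUR_CLEANSET_ID"

# Cleanlab columns downloaded directly as a pyspark DataFrame.
cl_cols = studio.download_cleanlab_columns(cleanset_id, to_spark=True)
cl_cols.printSchema()

# apply_corrections on a Spark dataset now fetches the cleanset with
# to_spark=True internally, avoiding the pandas round-trip.
spark = SparkSession.builder.getOrCreate()
dataset = spark.read.csv("your_dataset.csv", header=True)
corrected = studio.apply_corrections(cleanset_id, dataset)
corrected.show(5)
```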