Skip to content

Commit

Permalink
Merge pull request #185 from gganapavarapu/master
Browse files Browse the repository at this point in the history
Allowing custom dataset urls in data load helper classes
  • Loading branch information
vijay-arya authored Jul 31, 2023
2 parents 8e26427 + 34a2906 commit 298ce21
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 9 deletions.
11 changes: 9 additions & 2 deletions aix360/datasets/climate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ class ClimateDataset:
"""

def __init__(self):
def __init__(
self,
url: str = None,
):
self.data_folder = os.path.realpath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../data", "climate_data"
Expand All @@ -32,7 +35,11 @@ def __init__(self):
self.data_file = os.path.realpath(
os.path.join(self.data_folder, "jena_climate_2009_2016.csv")
)
climate_data_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/jena_climate_2009_2016.csv.zip"
climate_data_url = (
url
if url is not None
else "https://storage.googleapis.com/tensorflow/tf-keras-datasets/jena_climate_2009_2016.csv.zip"
)

self.input_length = 500
# download data
Expand Down
11 changes: 9 additions & 2 deletions aix360/datasets/diabetes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ class DiabetesDataset:
"""

def __init__(self):
def __init__(
self,
url: str = None,
):
self.data_folder = os.path.realpath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../data", "diabetes_data"
Expand All @@ -27,7 +30,11 @@ def __init__(self):
self.data_file = os.path.realpath(
os.path.join(self.data_folder, "diabetes.csv")
)
diabetes_url = "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt"
diabetes_url = (
url
if url is not None
else "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt"
)

if not os.path.exists(self.data_file):
response = requests.get(diabetes_url)
Expand Down
10 changes: 7 additions & 3 deletions aix360/datasets/ford_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class FordDataset:
"""

def __init__(self, category_a: bool = True):
def __init__(self, url: str = None, category_a: bool = True):
self.data_folder = os.path.realpath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../data", "ford_data"
Expand All @@ -41,8 +41,12 @@ def __init__(self, category_a: bool = True):
)

self.category = "A" if category_a else "B"
ford_data_url = "http://timeseriesclassification.com/ClassificationDownloads/Ford{}.zip".format(
self.category
ford_data_url = (
url
if url is not None
else "https://timeseriesclassification.com/aeon-toolkit/Ford{}.zip".format(
self.category
)
)

self.input_length = 500
Expand Down
11 changes: 9 additions & 2 deletions aix360/datasets/sunspots_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ class SunspotDataset:
"""

def __init__(self):
def __init__(
self,
url: str = None,
):
self.data_folder = os.path.realpath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../data", "sunspots_data"
Expand All @@ -32,7 +35,11 @@ def __init__(self):
self.data_file = os.path.realpath(
os.path.join(self.data_folder, "sunspots.csv")
)
sunspots_url = "https://raw.githubusercontent.com/PacktPublishing/Practical-Time-Series-Analysis/master/Data%20Files/monthly-sunspot-number-zurich-17.csv"
sunspots_url = (
url
if url is not None
else "https://raw.githubusercontent.com/PacktPublishing/Practical-Time-Series-Analysis/master/Data%20Files/monthly-sunspot-number-zurich-17.csv"
)

if not os.path.exists(self.data_file):
response = requests.get(sunspots_url)
Expand Down
4 changes: 4 additions & 0 deletions tests/tslime/test_tslime.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,7 @@ def test_tslime(self):
self.assertIn("surrogate_prediction", explanation)

self.assertEqual(explanation["history_weights"].shape[0], relevant_history)


if __name__ == "__main__":
unittest.main()
4 changes: 4 additions & 0 deletions tests/tssaliency/test_tssaliency.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,7 @@ def test_tssaliency(self):
self.assertIn("base_value_prediction", explanation)

self.assertEqual(explanation["saliency"].shape, test_window.shape)


if __name__ == "__main__":
unittest.main()

0 comments on commit 298ce21

Please sign in to comment.