Skip to content

Commit cca1753

Browse files
Merge pull request #255 from paxtonfitzpatrick/master
updates to example data, DataGeometry persistence, tests, requirements + other improvements
2 parents 390c4f0 + 4d1426c commit cca1753

23 files changed

+231
-246
lines changed

.travis.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,25 @@ python:
44
- '3.6'
55
- '3.7'
66
- '3.8'
7+
- '3.9'
78
services:
89
- xvfb
910
install:
10-
- wget http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh
11-
- bash miniconda.sh -b -p $HOME/miniconda
12-
- export PATH="$HOME/miniconda/bin:$PATH"
11+
- wget https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-$(uname -m).sh -O mambaforge.sh
12+
- bash mambaforge.sh -b -p $HOME/mambaforge
13+
- export PATH="$HOME/mambaforge/bin:$PATH"
1314
- hash -r
1415
- conda config --set always_yes yes --set changeps1 no
15-
- conda update -q conda
16-
- conda info -a
17-
- conda create -q -n testenv python=$TRAVIS_PYTHON_VERSION pip pytest numpy pandas
18-
scipy matplotlib seaborn scikit-learn numba
16+
- mamba info -a
17+
- mamba create -y -q -n testenv python=$TRAVIS_PYTHON_VERSION pip pytest numpy pandas
18+
scipy matplotlib seaborn scikit-learn umap-learn requests
1919
- source activate testenv
20-
- pip install -r requirements.txt
21-
- python setup.py install
20+
- pip install .
2221
- cp tests/matplotlibrc .
2322
before_script:
23+
- if [[ -d "$HOME/hypertools_data" ]]; then rm -rf "$HOME/hypertools_data"; fi
2424
- if [ ${TRAVIS_PYTHON_VERSION} == '3.6' ]; then pip install flake8; flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics; fi
25-
script: py.test
25+
script: pytest -sv .
2626
notifications:
2727
slack:
2828
rooms:

hypertools/_externals/ppca.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
from __future__ import division
2-
from __future__ import print_function
31

4-
from builtins import object
52
import os
63

74
import numpy as np

hypertools/_externals/srm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333

3434
# Authors: Po-Hsuan Chen (Princeton Neuroscience Institute) and Javier Turek
3535
# (Intel Labs), 2015
36-
from __future__ import division
37-
3836
import logging
3937

4038
import numpy as np

hypertools/_shared/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,9 @@ class HypertoolsBackendError(HypertoolsError):
66
def __init__(self, message):
77
super().__init__(message)
88
self.message = message
9+
10+
11+
class HypertoolsIOError(HypertoolsError, OSError):
12+
def __init__(self, message):
13+
super().__init__(message)
14+
self.message = message

hypertools/_shared/helpers.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,9 @@
55
"""
66

77
##PACKAGES##
8-
from __future__ import division
9-
from __future__ import print_function
108
import functools
119
import sys
1210
import numpy as np
13-
import six
1411
import copy
1512
from scipy.interpolate import PchipInterpolator as pchip
1613
import seaborn as sns
@@ -168,11 +165,10 @@ def get_type(data):
168165
"""
169166
Checks what the data type is and returns it as a string label
170167
"""
171-
import six
172168
from ..datageometry import DataGeometry
173169

174170
if isinstance(data, list):
175-
if isinstance(data[0], (six.string_types, six.text_type, six.binary_type)):
171+
if isinstance(data[0], (str, bytes)):
176172
return 'list_str'
177173
elif isinstance(data[0], (int, float)):
178174
return 'list_num'
@@ -183,13 +179,13 @@ def get_type(data):
183179
'Numpy Array, Pandas DataFrame, String, List of strings'
184180
', List of numbers')
185181
elif isinstance(data, np.ndarray):
186-
if isinstance(data[0][0], (six.string_types, six.text_type, six.binary_type)):
182+
if isinstance(data[0][0], (str, bytes)):
187183
return 'arr_str'
188184
else:
189185
return 'arr_num'
190186
elif isinstance(data, pd.DataFrame):
191187
return 'df'
192-
elif isinstance(data, (six.string_types, six.text_type, six.binary_type)):
188+
elif isinstance(data, (str, bytes)):
193189
return 'str'
194190
elif isinstance(data, DataGeometry):
195191
return 'geo'
@@ -211,19 +207,19 @@ def check_geo(geo):
211207
geo = copy.copy(geo)
212208

213209
def fix_item(item):
214-
if isinstance(item, six.binary_type):
210+
if isinstance(item, bytes):
215211
return item.decode()
216212
return item
217213

218214
def fix_list(lst):
219215
return [fix_item(i) for i in lst]
220-
if isinstance(geo.reduce, six.binary_type):
216+
if isinstance(geo.reduce, bytes):
221217
geo.reduce = geo.reduce.decode()
222218
for key in geo.kwargs.keys():
223219
if geo.kwargs[key] is not None:
224220
if isinstance(geo.kwargs[key], (list, np.ndarray)):
225221
geo.kwargs[key] = fix_list(geo.kwargs[key])
226-
elif isinstance(geo.kwargs[key], six.binary_type):
222+
elif isinstance(geo.kwargs[key], bytes):
227223
geo.kwargs[key] = fix_item(geo.kwargs[key])
228224
return geo
229225

@@ -232,7 +228,6 @@ def get_dtype(data):
232228
"""
233229
Checks what the data type is and returns it as a string label
234230
"""
235-
import six
236231
from ..datageometry import DataGeometry
237232

238233
if isinstance(data, list):
@@ -241,7 +236,7 @@ def get_dtype(data):
241236
return 'arr'
242237
elif isinstance(data, pd.DataFrame):
243238
return 'df'
244-
elif isinstance(data, (six.string_types, six.text_type, six.binary_type)):
239+
elif isinstance(data, (str, bytes)):
245240
return 'str'
246241
elif isinstance(data, DataGeometry):
247242
return 'geo'

hypertools/datageometry.py

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from __future__ import unicode_literals
2-
from builtins import object
31
import copy
4-
import deepdish as dd
5-
import numpy as np
2+
import pickle
3+
import warnings
4+
5+
import pandas as pd
6+
67
from .tools.normalize import normalize as normalizer
78
from .tools.reduce import reduce as reducer
89
from .tools.align import align as aligner
@@ -188,7 +189,7 @@ def plot(self, data=None, **kwargs):
188189
new_kwargs.update({key : kwargs[key]})
189190
return plotter(d, **new_kwargs)
190191

191-
def save(self, fname, compression='blosc'):
192+
def save(self, fname, compression=None):
192193
"""
193194
Save method for the data geometry object
194195
@@ -198,41 +199,46 @@ def save(self, fname, compression='blosc'):
198199
199200
Parameters
200201
----------
201-
202202
fname : str
203203
A name for the file. If the file extension (.geo) is not specified,
204204
it will be appended.
205-
206-
compression : str
207-
The kind of compression to use. See the deepdish documentation for
208-
options: http://deepdish.readthedocs.io/en/latest/api_io.html#deepdish.io.save
209-
210205
"""
211-
if hasattr(self, 'dtype'):
212-
if 'list' in self.dtype:
213-
data = np.array(self.data)
214-
elif 'df' in self.dtype:
215-
data = {k: np.array(v).astype('str') for k, v in self.data.to_dict('list').items()}
216-
else:
217-
data = self.data
218-
219-
# put geo vars into a dict
220-
geo = {
221-
'data' : data,
222-
'xform_data' : np.array(self.xform_data),
223-
'reduce' : self.reduce,
224-
'align' : self.align,
225-
'normalize' : self.normalize,
226-
'semantic' : self.semantic,
227-
'corpus' : np.array(self.corpus) if isinstance(self.corpus, list) else self.corpus,
228-
'kwargs' : self.kwargs,
229-
'version' : self.version,
230-
'dtype' : self.dtype
231-
}
232-
233-
# if extension wasn't included, add it
234-
if fname[-4:]!='.geo':
235-
fname+='.geo'
236-
237-
# save
238-
dd.io.save(fname, geo, compression=compression)
206+
if compression is not None:
207+
warnings.warn("Hypertools has switched from deepdish to pickle "
208+
"for saving DataGeometry objects. 'compression' "
209+
"argument has no effect and will be removed in a "
210+
"future version",
211+
FutureWarning)
212+
213+
# automatically add extension if not present
214+
if not fname.endswith('.geo'):
215+
fname += '.geo'
216+
217+
# can't save/restore matplotlib objects across sessions
218+
curr_fig = self.fig
219+
curr_ax = self.ax
220+
curr_line_ani = self.line_ani
221+
222+
curr_data = self.data
223+
# convert pandas DataFrames to dicts of
224+
# {column_name: list(column_values)} to fix I/O compatibility
225+
# issues across certain pandas versions. Expected self.data
226+
# format is restored by hypertools.load
227+
if isinstance(curr_data, pd.DataFrame):
228+
data_out_fmt = curr_data.to_dict('list')
229+
else:
230+
data_out_fmt = curr_data
231+
232+
try:
233+
self.fig = self.ax = self.line_ani = None
234+
self.data = data_out_fmt
235+
# save
236+
with open(fname, 'wb') as f:
237+
pickle.dump(self, f)
238+
finally:
239+
# make sure we don't mutate attribute values whether or not
240+
# save was successful
241+
self.fig = curr_fig
242+
self.ax = curr_ax
243+
self.line_ani = curr_line_ani
244+
self.data = curr_data

hypertools/plot/draw.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
#!/usr/bin/env python
22

3-
from __future__ import division
4-
from builtins import str
5-
from builtins import range
63
import matplotlib
74
import matplotlib.pyplot as plt
85
from mpl_toolkits.mplot3d import proj3d
@@ -330,7 +327,7 @@ def update_lines_spin(num, data_lines, lines, cube_scale, rotations=2,
330327
return lines
331328

332329
def dispatch_animate(x, ani_params):
333-
if x[0].shape[1] is 3:
330+
if x[0].shape[1] == 3:
334331
return animate_plot3D(x, **ani_params)
335332

336333
def animate_plot3D(x, tail_duration=2, rotations=2, zoom=1, chemtrails=False,
@@ -380,7 +377,7 @@ def animate_plot3D(x, tail_duration=2, rotations=2, zoom=1, chemtrails=False,
380377
plt.ioff()
381378

382379
if animate in [True, 'parallel', 'spin']:
383-
assert x[0].shape[1] is 3, "Animations are currently only supported for 3d plots."
380+
assert x[0].shape[1] == 3, "Animations are currently only supported for 3d plots."
384381

385382
# animation params
386383
ani_params = dict(tail_duration=tail_duration,
@@ -400,7 +397,7 @@ def animate_plot3D(x, tail_duration=2, rotations=2, zoom=1, chemtrails=False,
400397
fig, ax, data = dispatch_static(x, ax)
401398

402399
# if 3d, plot the cube
403-
if x[0].shape[1] is 3:
400+
if x[0].shape[1] == 3:
404401

405402
# set cube scale
406403
cube_scale = 1
@@ -416,7 +413,7 @@ def animate_plot3D(x, tail_duration=2, rotations=2, zoom=1, chemtrails=False,
416413
# initialize the view
417414
ax.view_init(elev=elev, azim=azim)
418415

419-
elif x[0].shape[1] is 2:
416+
elif x[0].shape[1] == 2:
420417

421418
# plot square
422419
plot_square(ax)

hypertools/plot/plot.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
from __future__ import division
32
import warnings
43
import matplotlib.animation as animation
54
import matplotlib.pyplot as plt
@@ -311,7 +310,7 @@ def plot(x, fmt='-', marker=None, markers=None, linestyle=None, linestyles=None,
311310
if cluster is not None:
312311
if hue is not None:
313312
warnings.warn('cluster overrides hue, ignoring hue.')
314-
if isinstance(cluster, (six.string_types, six.binary_type)):
313+
if isinstance(cluster, (str, bytes)):
315314
model = cluster
316315
params = default_params(model)
317316
elif isinstance(cluster, dict):
@@ -374,7 +373,7 @@ def plot(x, fmt='-', marker=None, markers=None, linestyle=None, linestyles=None,
374373
mpl_kwargs['label'] = legend
375374

376375
# interpolate if its a line plot
377-
if fmt is None or isinstance(fmt, six.string_types):
376+
if fmt is None or isinstance(fmt, str):
378377
if is_line(fmt):
379378
if xform[0].shape[0] > 1:
380379
xform = interp_array_list(xform, interp_val=frame_rate*duration/(xform[0].shape[0] - 1))
@@ -386,8 +385,8 @@ def plot(x, fmt='-', marker=None, markers=None, linestyle=None, linestyles=None,
386385

387386
# handle explore flag
388387
if explore:
389-
assert xform[0].shape[1] is 3, "Explore mode is currently only supported for 3D plots."
390-
mpl_kwargs['picker']=True
388+
assert xform[0].shape[1] == 3, "Explore mode is currently only supported for 3D plots."
389+
mpl_kwargs['picker'] = True
391390

392391
# center
393392
xform = center(xform)

hypertools/tools/align.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#!/usr/bin/env python
22

3-
from __future__ import division
4-
from builtins import range
53
from .._externals.srm import SRM
64
from .procrustes import procrustes
75
import numpy as np
@@ -82,7 +80,7 @@ def align(data, align='hyper', normalize=None, ndims=None, method=None,
8280
if format_data:
8381
data = formatter(data, ppca=True)
8482

85-
if len(data) is 1:
83+
if len(data) == 1:
8684
warnings.warn('Data in list of length 1 can not be aligned. '
8785
'Skipping the alignment.')
8886

hypertools/tools/cluster.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python
22
import warnings
33
from sklearn.cluster import KMeans, MiniBatchKMeans, AgglomerativeClustering, Birch, FeatureAgglomeration, SpectralClustering
4-
import six
54
import numpy as np
65
from .._shared.helpers import *
76
from .format_data import format_data as formatter
@@ -64,7 +63,7 @@ def cluster(x, cluster='KMeans', n_clusters=3, ndims=None, format_data=True):
6463

6564
if cluster == None:
6665
return x
67-
elif (isinstance(cluster, six.string_types) and cluster=='HDBSCAN') or \
66+
elif (isinstance(cluster, str) and cluster=='HDBSCAN') or \
6867
(isinstance(cluster, dict) and cluster['model']=='HDBSCAN'):
6968
if not _has_hdbscan:
7069
raise ImportError('HDBSCAN is not installed. Please install hdbscan>=0.8.11')
@@ -76,7 +75,7 @@ def cluster(x, cluster='KMeans', n_clusters=3, ndims=None, format_data=True):
7675
x = formatter(x, ppca=True)
7776

7877
# if reduce is a string, find the corresponding model
79-
if isinstance(cluster, six.string_types):
78+
if isinstance(cluster, str):
8079
model = models[cluster]
8180
if cluster != 'HDBSCAN':
8281
model_params = {
@@ -86,7 +85,7 @@ def cluster(x, cluster='KMeans', n_clusters=3, ndims=None, format_data=True):
8685
model_params = {}
8786
# if its a dict, use custom params
8887
elif type(cluster) is dict:
89-
if isinstance(cluster['model'], six.string_types):
88+
if isinstance(cluster['model'], str):
9089
model = models[cluster['model']]
9190
model_params = cluster['params']
9291

0 commit comments

Comments
 (0)