Merge pull request #23 from ktonal/develop
v0.2.1
antoinedaurat authored Jun 14, 2021
2 parents d7a373b + 2b1a6e1 commit 8285dbb
Showing 12 changed files with 214 additions and 251 deletions.
2 changes: 1 addition & 1 deletion mimikit/__init__.py
@@ -1,4 +1,4 @@
__version__ = '0.2.0'
__version__ = '0.2.1'

from . import audios
from . import connectors
3 changes: 0 additions & 3 deletions mimikit/_make_notebook.py

This file was deleted.

49 changes: 29 additions & 20 deletions mimikit/data/create.py
@@ -1,6 +1,7 @@
import h5py
import pandas as pd
from multiprocessing import cpu_count, Pool
from concurrent.futures import ThreadPoolExecutor
import os
import warnings

@@ -59,12 +60,13 @@ def file_to_h5(abs_path, extract_func=None, output_path=None, mode="w"):
created_file, infos : str, dict
the name of the created .h5 file and a ``dict`` with the keys ``"dtype"`` and ``"shape"``
"""
print("making .h5 for %s" % abs_path)
# print("making .h5 for %s" % abs_path)
if output_path is None:
output_path = os.path.splitext(abs_path)[0] + ".h5"
else:
if output_path[-3:] != ".h5":
output_path += ".h5"
# print("!!!", abs_path)
rv = extract_func(abs_path)
info = {}
if os.path.exists(output_path) and mode == 'w':
@@ -84,9 +86,11 @@ def file_to_h5(abs_path, extract_func=None, output_path=None, mode="w"):
f.close()
pd.DataFrame(regions).to_hdf(output_path, name + "_regions", "r+")
f = h5py.File(output_path, "r+")
del data
f.attrs["features"] = list(rv.keys())
f.flush()
f.close()
del rv
return output_path, info


@@ -95,7 +99,8 @@ def file_to_h5(abs_path, extract_func=None, output_path=None, mode="w"):
def _make_db_for_each_file(file_walker,
extract_func=None,
destination="",
n_cores=cpu_count()):
n_cores=cpu_count(),
method='mp'):
"""
apply ``extract_func`` to the files found by ``file_walker``
@@ -120,8 +125,12 @@ def _make_db_for_each_file(file_walker,
args = [(file, extract_func, os.path.join(destination, file.strip('.').strip('/')))
for n, file in enumerate(file_walker)]
if len(args) > 1:
with Pool(min(n_cores, len(args))) as p:
tmp_dbs_infos = p.starmap(file_to_h5, args)
if method != 'future':
with Pool(min(n_cores, len(args))) as p:
tmp_dbs_infos = p.starmap(file_to_h5, args)
else:
with ThreadPoolExecutor(max_workers=len(args)) as executor:
tmp_dbs_infos = [info for info in executor.map(file_to_h5, *zip(*args))]
else:
tmp_dbs_infos = [file_to_h5(*arg) for arg in args]
return tmp_dbs_infos
@@ -156,7 +165,7 @@ def _aggregate_db_infos(infos):
f" feature {str(f)} returned different shapes")
# collect the regions for the files (this is different from the possible segmentation regions!)
regions = Regions.from_duration([s[0] for s in shapes])
ds_shape = (regions.last_stop, *dims)
ds_shape = (regions.stop.values[-1], *dims)
regions.index = paths
ds_definitions[f] = {"shape": ds_shape, "dtype": dtype, "files_regions": regions}
return ds_definitions
@@ -202,21 +211,21 @@ def _aggregate_dbs(target, tmp_dbs_infos, mode="w"):

# copy the data
intra_regions = {}
for source, key, indices in args:
with h5py.File(source, "r") as src:
data = src[key][()]
attrs = {k: v for k, v in src[key].attrs.items()}
if key + "_regions" in src:
if key not in intra_regions:
intra_regions[key] = []
# concat the regions :
regions = pd.read_hdf(source, key=key + "_regions", mode="r")
regions.loc[:, ("start", "stop")] += indices[0].start
regions = regions.reset_index(drop=False)
# remove ".tmp_" from the source's name
regions.loc[:, "name"] = "".join(source.split(".tmp_"))
intra_regions[key] += [regions]
with h5py.File(target, "r+") as trgt:
with h5py.File(target, "r+") as trgt:
for source, key, indices in args:
with h5py.File(source, "r") as src:
data = src[key][()]
attrs = {k: v for k, v in src[key].attrs.items()}
if key + "_regions" in src:
if key not in intra_regions:
intra_regions[key] = []
# concat the regions :
regions = pd.read_hdf(source, key=key + "_regions", mode="r")
regions.loc[:, ("start", "stop")] += indices[0].start
regions = regions.reset_index(drop=False)
# remove ".tmp_" from the source's name
regions.loc[:, "name"] = "".join(source.split(".tmp_"))
intra_regions[key] += [regions]
trgt[key][indices] = data
trgt[key].attrs.update(attrs)
# "<feature>_regions" will store the segmenting slices for this feature
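
A minimal standalone sketch of the dispatch pattern introduced above in create.py, switching between a process Pool and a ThreadPoolExecutor; work, run_all and the argument tuples are placeholders, not mimikit code:

from multiprocessing import Pool, cpu_count
from concurrent.futures import ThreadPoolExecutor


def work(path, suffix):
    # stand-in for file_to_h5(abs_path, extract_func, output_path)
    return path + suffix


def run_all(args, method="mp", n_cores=cpu_count()):
    if len(args) <= 1:
        return [work(*a) for a in args]
    if method != "future":
        # process pool: starmap unpacks each argument tuple
        with Pool(min(n_cores, len(args))) as p:
            return p.starmap(work, args)
    # thread pool: executor.map takes one iterable per positional argument,
    # hence the *zip(*args) transposition of the argument tuples
    with ThreadPoolExecutor(max_workers=len(args)) as executor:
        return list(executor.map(work, *zip(*args)))


if __name__ == "__main__":
    print(run_all([("a", ".h5"), ("b", ".h5")], method="future"))  # ['a.h5', 'b.h5']
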
16 changes: 11 additions & 5 deletions mimikit/data/database.py
@@ -67,20 +67,27 @@ def __init__(self, h5_file: str, ds_name: str, keep_open=False):
has_regions = self.name + "_regions" in f
self.files = Regions(pd.read_hdf(h5_file, self.name + "_files", mode="r")) if has_files else None
self.regions = Regions(pd.read_hdf(h5_file, self.name + "_regions", mode="r")) if has_regions else None
self._f = h5py.File(h5_file, "r+") if keep_open else None
        # handle to the file when keeping it open. To support torch's DataLoader, we have to open the file
        # at the first getitem request
self._f = None
self.keep_open = keep_open

def __len__(self):
return self.N

def __getitem__(self, item):
if self._f is not None:
if self.keep_open:
if self._f is None:
self._f = h5py.File(self.h5_file, "r+")
return self._f[self.name][item]
with h5py.File(self.h5_file, "r") as f:
rv = f[self.name][item]
return rv

def __setitem__(self, item, value):
if self._f is not None:
if self.keep_open:
if self._f is None:
self._f = h5py.File(self.h5_file, "r+")
self._f[self.name][item] = value
with h5py.File(self.h5_file, "r+") as f:
f[self.name][item] = value
@@ -173,7 +180,6 @@ def _load(cls, path, schema={}):
"""
default extract_func for Database.build. Roughly equivalent to :
``{feat_name: feat.load(path) for feat_name, feat in features_dict.items()}``
"""
out = {}
for f_name, f in schema.items():
@@ -226,7 +232,7 @@ def create(cls, db_name, sources=tuple(), schema={}):
for f_name, f in schema.items():
            # give features the chance to update themselves in light of their whole set
if getattr(type(f), "post_create", Feature.post_create) != Feature.post_create:
rv = f.post_create(db, f_name)
rv = f.after_create(db, f_name)
if isinstance(rv, np.ndarray):
rv = (rv, getattr(db, f_name).regions)
elif isinstance(rv, Regions):
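
A minimal sketch of the lazy-open pattern described in the comment above: the h5py handle is created on the first __getitem__ rather than in __init__, so each torch DataLoader worker opens its own file after forking. The class and attribute names below are illustrative, not the actual mimikit.data.database code:

import h5py


class LazyH5Dataset:
    def __init__(self, h5_file, ds_name, keep_open=False):
        self.h5_file = h5_file
        self.name = ds_name
        self.keep_open = keep_open
        self._f = None  # never opened in __init__

    def __getitem__(self, item):
        if self.keep_open:
            if self._f is None:
                # first access: this happens inside the worker process,
                # so the handle is never shared across forks
                self._f = h5py.File(self.h5_file, "r")
            return self._f[self.name][item]
        # one-shot access: open, read, close
        with h5py.File(self.h5_file, "r") as f:
            return f[self.name][item]
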
93 changes: 3 additions & 90 deletions mimikit/data/regions.py
@@ -14,12 +14,7 @@ def ssts(item, axis=0):

class Regions(pd.DataFrame):
"""
subclass of ``pandas.DataFrame`` to store a slice structure. This is handy when concatenating arrays
Instances of ``Regions`` created by the constructors mentioned below will automatically have
three columns : "start", "stop", "duration".
:no-inherited-members:
subclass of ``pandas.DataFrame`` to construct and store sequences of slices.
"""

@property
@@ -31,70 +26,27 @@ def _validate(obj):
if 'start' not in obj.columns or "stop" not in obj.columns:
raise ValueError("Must have 'start' and 'stop' columns.")

# PROPERTIES

@property
def first_start(self):
return self.start.min()

@property
def last_stop(self):
return self.stop.max()

@property
def span(self):
return self.last_stop - self.first_start

@property
def cumdur(self):
return np.cumsum(self["duration"].values)

@property
def all_indices(self):
return np.array([i for ev in self.events for i in range(ev.start, ev.stop)])

def slices(self, time_axis=0):
"""
This is the back-end core of a score. This method efficiently returns the indexing objects
necessary to communicate with n-dimensional data.
@return: an array of slices where slice_i corresponds to the row/Event_i in the DataFrame
"""
return ssts(self, time_axis)

@property
def durations_(self):
return self.stop.values - self.start.values

@property
def events(self):
return self.itertuples(name="Event", index=True)

# UPDATING METHOD

def make_contiguous(self):
self.reset_index(drop=True, inplace=True)
cumdur = self.cumdur
cumdur = np.cumsum(self["duration"].values)
self.loc[:, "start"] = np.r_[0, cumdur[:-1]] if cumdur[0] != 0 else cumdur[:-1]
self.loc[:, "stop"] = cumdur
return self

# Sanity Checks

def is_consistent(self):
return (self["duration"].values == (self.stop.values - self.start.values)).all()

def is_contiguous(self):
return (self.start.values[1:] == self.stop.values[:-1]).all()

@staticmethod
def from_start_stop(starts, stops, durations):
return Regions((starts, stops, durations), index=["start", "stop", "duration"]).T

@staticmethod
def from_stop(stop):
"""
integers in `stops` correspond to the prev[stop] and next[start] values.
integers in `stop` correspond to the prev[stop] and next[start] values.
        `stops` must contain the last index! It can begin with 0, but doesn't have to...
"""
stop = np.asarray(stop)
@@ -117,45 +69,6 @@ def from_data(sequence, time_axis=1):
duration = np.array([x.shape[time_axis] for x in sequence])
return Regions.from_duration(duration)

@staticmethod
def from_frame_definition(total_duration, frame_length, stride=1, butlasts=0):
starts = np.arange(total_duration - frame_length - butlasts + 1, step=stride)
durations = frame_length + np.zeros_like(starts, dtype=np.int)
stops = starts + durations
return Regions.from_start_stop(starts, stops, durations)

def to_labels(self):
return np.hstack([np.ones((tp.duration,), dtype=np.int) * tp.Index
for tp in self.itertuples()])


@pd.api.extensions.register_dataframe_accessor("soft_q")
class SoftQueryAccessor:
def __init__(self, pandas_obj):
self._df = pandas_obj

def or_(self, **kwargs):
series = False
for col_name, func in kwargs.items():
series = series | func(self._df[col_name])
return series

def and_(self, **kwargs):
series = True
for col_name, func in kwargs.items():
series = series & func(self._df[col_name])
return series


@pd.api.extensions.register_dataframe_accessor("hard_q")
class HardQueryAccessor:
def __init__(self, pandas_obj):
self._df = pandas_obj

def or_(self, **kwargs):
series = self._df.soft_q.or_(**kwargs)
return self._df[series]

def and_(self, **kwargs):
series = self._df.soft_q.and_(**kwargs)
return self._df[series]
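
A standalone sketch of the start/stop bookkeeping that Regions.make_contiguous and the from_duration constructor rely on: cumulative durations yield contiguous slices, and the last stop gives the total concatenated length (as used for ds_shape in create.py above). regions_from_durations is a hypothetical helper, not part of mimikit:

import numpy as np
import pandas as pd


def regions_from_durations(durations):
    # cumulative durations -> contiguous (start, stop) pairs
    durations = np.asarray(durations)
    cumdur = np.cumsum(durations)
    starts = np.r_[0, cumdur[:-1]]
    return pd.DataFrame({"start": starts, "stop": cumdur, "duration": durations})


df = regions_from_durations([100, 250, 50])
# rows: (0, 100, 100), (100, 350, 250), (350, 400, 50)
assert df["stop"].values[-1] == 400  # total length of the concatenated data
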