Skip to content

Commit

Permalink
Fix dataset versioning
Browse files Browse the repository at this point in the history
Co-authored-by: Gustavo Viera López <[email protected]>
  • Loading branch information
jmorgadov and gvieralopez committed Jun 7, 2023
1 parent 22c5e72 commit eb36bf6
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 14 deletions.
9 changes: 4 additions & 5 deletions pactus/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class Dataset(Data):
Dataset version.
"""

_last_tag = None
_last_tag: Union[str, None] = None

def __init__(
self,
Expand All @@ -277,7 +277,7 @@ def __init__(
self.version = version
self.trajs = trajs
self.labels = labels
super().__init__(trajs, labels)
super().__init__(trajs, labels, dataset_name=name)

def __len__(self):
return len(self.trajs)
Expand All @@ -286,13 +286,12 @@ def __len__(self):
def _from_json(name: str, data: dict) -> Dataset:
assert "trajs" in data, "trajs not found in dataset"
assert "labels" in data, "labels not found in dataset"
assert "version" in data, "version not found in dataset"
trajs = [JSONSerializer.from_json(traj) for traj in data["trajs"]]
return Dataset(
name=name,
trajs=trajs,
labels=data["labels"],
version=data["version"],
version=data.get("version", 0),
)

@staticmethod
Expand All @@ -315,7 +314,7 @@ def _from_url(name: str, force: bool = False) -> Dataset:
if not force and yupi_data_file.exists():
with open(yupi_data_file, "r", encoding="utf-8") as yupi_fd:
data = json.load(yupi_fd)
if "trajs" in data and "labels" in data and "version" in data:
if "trajs" in data and "labels" in data:
return Dataset._from_json(name, data)
logging.warning("Invalid dataset file, downloading again")

Expand Down
2 changes: 1 addition & 1 deletion pactus/models/evaluation_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, evals: List[Evaluation]):

self.evals_by_dataset: Dict[str, List[Evaluation]] = {}
for evaluation in self.evals:
ds_name = evaluation.dataset.name
ds_name = evaluation.dataset_name
self.evals_by_dataset[ds_name] = self.evals_by_dataset.get(ds_name, []) + [
evaluation
]
Expand Down
17 changes: 9 additions & 8 deletions pactus/models/lstm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
from sklearn.preprocessing import LabelEncoder
from tensorflow import keras
from tensorflow.keras import layers
from yupi import Trajectory

from pactus import Dataset
Expand Down Expand Up @@ -65,25 +64,27 @@ def _get_model(self, input_shape, n_classes):
max_len, traj_dim = input_shape
model = keras.Sequential()
model.add(
layers.Masking(
mask_value=self.masking_value, input_shape=(max_len, traj_dim)
keras.layers.Masking(
mask_value=self.masking_value,
input_shape=(max_len, traj_dim),
)
)
for units_val in self.units:
model.add(
layers.LSTM(
keras.layers.LSTM(
units_val,
input_shape=(max_len, traj_dim),
return_sequences=True,
)
)
model.add(
layers.Bidirectional(
layers.LSTM(32, input_shape=(max_len, traj_dim)), merge_mode="ave"
keras.layers.Bidirectional(
keras.layers.LSTM(32, input_shape=(max_len, traj_dim)),
merge_mode="ave",
)
)
model.add(layers.Dense(15, activation="relu"))
model.add(layers.Dense(n_classes, activation="softmax"))
model.add(keras.layers.Dense(15, activation="relu"))
model.add(keras.layers.Dense(n_classes, activation="softmax"))
model.compile(**self.compile_args)
return model

Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ push = false
"pactus/__init__.py" = [
'__version__ = "{version}"',
]
"docs/source/conf.py" = [
'release = "{version}"',
]

[build-system]
requires = ["setuptools>=61.0.0",
Expand All @@ -78,6 +81,9 @@ build-backend = "setuptools.build_meta"

[tool.mypy]
python_version = "3.8"
exclude = [
"docs",
]

[[tool.mypy.overrides]]
module = [
Expand Down

0 comments on commit eb36bf6

Please sign in to comment.