
Commit 2041f45
Update example 4
Co-authored-by: Gustavo Viera López <[email protected]>
jmorgadov and gvieralopez committed Jun 8, 2023
1 parent 6372cd1 commit 2041f45
Showing 2 changed files with 10 additions and 28 deletions.
examples/example_04.py (24 changes: 3 additions & 21 deletions)
@@ -1,8 +1,4 @@
-from typing import Tuple
-
-
 from pactus import Dataset, featurizers
-from pactus.dataset.dataset import Data
 from pactus.models import XGBoostModel
 
 SEED = 0  # Random seed for reproducibility
@@ -19,29 +15,15 @@
     Dataset.uci_movement_libras(),
 ]
 
 
-def dataset_splitter(ds: Data) -> Tuple[Data, Data]:
-    if ds.dataset_name == "geolife":
-        use_classes = {"car", "taxi-bus", "walk", "bike", "subway", "train"}
-        return (
-            ds.filter(lambda traj, _: len(traj) > 10 and traj.dt < 8)
-            .map(lambda _, lbl: (_, "taxi-bus" if lbl in ("bus", "taxi") else lbl))
-            .filter(lambda _, lbl: lbl in use_classes)
-            .split(train_size=0.7, random_state=SEED)
-        )
-    if ds.dataset_name == "mnist_stroke":
-        ds = ds.take(10_000)
-    return ds.filter(
-        lambda traj, _: len(traj) >= 5 and traj.r.delta.norm.sum() > 0
-    ).split(train_size=0.7, random_state=SEED)
-
-featurizer = featurizers.UniversalFeaturizer()
 
 for dataset in datasets:
     print(f"\nDataset: {dataset.name}\n")
 
     # Split the dataset into train and test
-    train, test = dataset_splitter(dataset)
+    train, test = dataset.filter(
+        lambda traj, _: len(traj) >= 5 and traj.r.delta.norm.sum() > 0
+    ).split(train_size=0.7, random_state=SEED)
 
     # Select the desired features to be extracted from the trajectories
     featurizer = featurizers.UniversalFeaturizer()
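With dataset_splitter removed, every dataset now goes through the same filter-and-split pipeline shown above. A minimal sketch of running the updated example end to end; the split call matches the diff, while the XGBoostModel construction and training calls follow the general pactus model pattern and are illustrative assumptions:

from pactus import Dataset, featurizers
from pactus.models import XGBoostModel

SEED = 0  # Random seed for reproducibility

dataset = Dataset.uci_movement_libras()

# Keep trajectories with at least 5 points and a positive total path
# length, then split 70/30 with a fixed seed, as in the new example code.
train, test = dataset.filter(
    lambda traj, _: len(traj) >= 5 and traj.r.delta.norm.sum() > 0
).split(train_size=0.7, random_state=SEED)

featurizer = featurizers.UniversalFeaturizer()
model = XGBoostModel(featurizer=featurizer)  # assumed constructor signature
model.train(train, cross_validation=5)       # assumed training call
evaluation = model.evaluate(test)            # assumed evaluation call
evaluation.show()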
pactus/models/transformer_model.py (14 changes: 7 additions & 7 deletions)
@@ -58,7 +58,7 @@ def __init__(
         self.mask_value = mask_value
         self.encoder: Union[LabelEncoder, None] = None
         self.labels: Union[List[Any], None] = None
-        self.dataset: Union[Dataset, None] = None
+        self.original_data: Union[Data, None] = None
         self.set_summary(
             head_size=self.head_size,
             num_heads=self.num_heads,
@@ -77,7 +77,7 @@ def __init__(
     def train(
         self,
         data: Data,
-        dataset: Dataset,
+        original_data: Data,
         cross_validation: int = 0,
         epochs: int = 10,
         validation_split: float = 0.2,
@@ -93,7 +93,7 @@
         )
         self.encoder = None
         self.labels = data.labels
-        self.dataset = dataset
+        self.original_data = original_data
         x_train, y_train = self._get_input_data(data)
         n_classes = len(data.classes)
         input_shape = x_train.shape[1:]
@@ -170,15 +170,15 @@ def _get_model(
     ) -> keras.Model:
         model = build_model(
             n_classes,
-            input_shape,
+            input_shape=input_shape,
             head_size=self.head_size,
             num_heads=self.num_heads,
             ff_dim=self.ff_dim,
             num_transformer_blocks=self.num_transformer_blocks,
             mlp_units=self.mlp_units,
             mlp_dropout=self.mlp_dropout,
             dropout=self.dropout,
-            mask=mask,
+            # mask=mask,  # FIXME: using mask is causing input shapes issues
         )
         model.compile(
             loss=self.loss,
@@ -212,10 +212,10 @@ def _encode_labels(self, data: Data) -> np.ndarray:

     def _extract_raw_data(self, data: Data) -> np.ndarray:
         """Extracts the raw data from the yupi trajectories"""
-        assert self.dataset is not None, "Dataset must be set"
+        assert self.original_data is not None, "Original data must be set"
 
         trajs = data.trajs
-        max_len = np.max([len(traj) for traj in self.dataset.trajs])
+        max_len = np.max([len(traj) for traj in self.original_data.trajs])
         if self.max_traj_len > 0:
             max_len = self.max_traj_len
         raw_data = [np.hstack((traj.r, np.reshape(traj.t, (-1, 1)))) for traj in trajs]
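Since train now takes original_data (a Data) instead of dataset (a Dataset), callers pass the full, unsplit data alongside the training split; per the _extract_raw_data change above, it is used to pad every trajectory to the longest one in the whole collection rather than only the training split. A hedged sketch of the updated call; the TransformerModel construction and the dataset chosen here are assumptions for illustration:

from pactus import Dataset
from pactus.models import TransformerModel

SEED = 0  # Random seed for reproducibility
dataset = Dataset.uci_movement_libras()
train, test = dataset.split(train_size=0.7, random_state=SEED)

model = TransformerModel()  # assumed default construction
# Pass the unsplit data as `original_data` (formerly `dataset`) so the
# padded input length covers the longest trajectory in the full dataset.
model.train(train, original_data=dataset, epochs=10)
evaluation = model.evaluate(test)  # assumed evaluation call
evaluation.show()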

