diff --git a/examples/example_04.py b/examples/example_04.py index 2c0c6d1..28904d7 100644 --- a/examples/example_04.py +++ b/examples/example_04.py @@ -14,23 +14,23 @@ Dataset.uci_characters(), Dataset.uci_movement_libras(), ] - + +# Select the desired features to be extracted from the trajectories featurizer = featurizers.UniversalFeaturizer() for dataset in datasets: print(f"\nDataset: {dataset.name}\n") - # Split the dataset into train and test + # Split the dataset into train and test and filter out short trajectories train, test = dataset.filter( - lambda traj, _: len(traj) >= 5 and traj.r.delta.norm.sum() > 0 + lambda traj, label: len(traj) >= 5 + and traj.r.delta.norm.sum() > 0 + and dataset.label_counts[label] > 5 ).split( train_size=0.7, random_state=SEED, ) - # Select the desired features to be extracted from the trajectories - featurizer = featurizers.UniversalFeaturizer() - # Define the model model = XGBoostModel(featurizer=featurizer)