diff --git a/notebooks/01 Feature Extraction and Selection.ipynb b/notebooks/01 Feature Extraction and Selection.ipynb
index 129d5cab..f74e06f9 100644
--- a/notebooks/01 Feature Extraction and Selection.ipynb
+++ b/notebooks/01 Feature Extraction and Selection.ipynb
@@ -157,7 +157,18 @@
"source": [
"Using the hypothesis tests implemented in `tsfresh` (see [here](https://tsfresh.readthedocs.io/en/latest/text/feature_filtering.html) for more information) it is now possible to select only the relevant features out of this large dataset.\n",
"\n",
- "`tsfresh` will do a hypothesis test for each of the features to check, if it is relevant for your given target."
+ "`tsfresh` runs a hypothesis test for each feature to check whether it is relevant for your given target.\n",
+ "\n",
+ "To avoid leaking information from the test set into the feature selection, we perform the selection on the train set only."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "X_full_train, X_full_test, y_train, y_test = train_test_split(X, y, test_size=.4, random_state=42)"
]
},
{
@@ -166,7 +177,7 @@
"metadata": {},
"outputs": [],
"source": [
- "X_filtered = select_features(X, y)"
+ "X_filtered_train = select_features(X_full_train, y_train)"
]
},
{
@@ -177,7 +188,7 @@
},
"outputs": [],
"source": [
- "X_filtered.head()"
+ "X_filtered_train.head()"
]
},
{
@@ -186,7 +197,7 @@
"source": [
"
\n",
"\n",
- "Currently, 669 non-NaN features survive the feature selection given this target.\n",
+ "Currently, 423 non-NaN features survive the feature selection given this target.\n",
"Again, this number will vary depending on your data, your target and the `tsfresh` version.\n",
" \n",
"
"
@@ -214,8 +225,7 @@
"metadata": {},
"outputs": [],
"source": [
- "X_full_train, X_full_test, y_train, y_test = train_test_split(X, y, test_size=.4)\n",
- "X_filtered_train, X_filtered_test = X_full_train[X_filtered.columns], X_full_test[X_filtered.columns]"
+ "X_filtered_test = X_full_test[X_filtered_train.columns]"
]
},
{
@@ -246,7 +256,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Compared to using all features (`classifier_full`), using only the relevant features (`classifier_filtered`) achieves better classification performance with less data."
+ "Compared to using all features (`classifier_full`), using only the relevant features (`classifier_filtered`) achieves similar or better classification performance while using far fewer features."
]
},
{
@@ -284,17 +294,8 @@
"metadata": {},
"outputs": [],
"source": [
- "X_filtered_2 = extract_relevant_features(df, y, column_id='id', column_sort='time',\n",
- " default_fc_parameters=extraction_settings)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "(X_filtered.columns == X_filtered_2.columns).all()"
+ "extract_relevant_features(df, y, column_id='id', column_sort='time',\n",
+ " default_fc_parameters=extraction_settings)"
]
}
],
@@ -314,7 +315,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.2"
+ "version": "3.10.13"
}
},
"nbformat": 4,
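
Review note: the pattern this change introduces (split first, select on the train split, then reuse the selected columns on the test split) can be sketched in isolation as below. This is a minimal illustration and not code from the notebook: `select_on_train_only` and `toy_selector` are hypothetical names, and `toy_selector` is a simple correlation filter standing in for `select_features` so that the sketch runs with only pandas installed.

```python
import pandas as pd


def select_on_train_only(X_train, X_test, y_train, selector):
    """Run the feature selection on the train split and reuse its columns on the test split.

    `selector` is any callable (X, y) -> DataFrame holding a subset of X's columns.
    (Hypothetical helper for illustration; it is not part of the notebook or of tsfresh.)
    """
    X_train_selected = selector(X_train, y_train)
    relevant_columns = X_train_selected.columns
    # The test labels are never passed to the selector, so they cannot
    # influence which features are kept.
    return X_train_selected, X_test[relevant_columns]


def toy_selector(X, y, threshold=0.5):
    """Hypothetical stand-in for a real selector: keep columns whose absolute
    Pearson correlation with y exceeds `threshold`."""
    correlations = X.corrwith(y).abs()
    return X.loc[:, correlations > threshold]


# Tiny made-up data set: "signal" tracks the target, "noise" does not.
X_train = pd.DataFrame({"signal": [0.0, 1.0, 2.0, 3.0],
                        "noise": [1.0, -1.0, -1.0, 1.0]})
y_train = pd.Series([0.0, 1.0, 2.0, 3.0])
X_test = pd.DataFrame({"signal": [4.0, 5.0], "noise": [0.0, 1.0]})

X_train_sel, X_test_sel = select_on_train_only(X_train, X_test, y_train, toy_selector)
print(list(X_train_sel.columns), list(X_test_sel.columns))
```

In the notebook, the same shape should apply with `select_features` passed as the selector and the frames returned by `train_test_split` as inputs.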