Skip to content

Commit 8cd9031

Browse files
committed
Code quality
1 parent 4e74b28 commit 8cd9031

File tree

8 files changed

+210
-276
lines changed

8 files changed

+210
-276
lines changed

src/get_data.py

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import requests
44
import pandas as pd
55
from path import Path
6-
from src.parameters import *
6+
from parameters import *
77

88

9-
def downloadData(data_path='/input/speech_commands/'):
9+
def downloadData(data_path="/input/speech_commands/"):
1010
"""
1111
Downloads the Google Speech Commands dataset (version 0.01)
1212
:param data_path: Path to download dataset
@@ -15,10 +15,10 @@ def downloadData(data_path='/input/speech_commands/'):
1515

1616
dataset_path = Path(os.path.abspath(__file__)).parent.parent + data_path
1717

18-
datasets = ['train', 'test']
18+
datasets = ["train", "test"]
1919
urls = [
20-
'http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
21-
'http://download.tensorflow.org/data/speech_commands_test_set_v0.01.tar.gz'
20+
"http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz",
21+
"http://download.tensorflow.org/data/speech_commands_test_set_v0.01.tar.gz",
2222
]
2323

2424
for dataset, url in zip(datasets, urls):
@@ -27,7 +27,7 @@ def downloadData(data_path='/input/speech_commands/'):
2727
# Check if we need to extract the dataset
2828
if not os.path.isdir(dataset_directory):
2929
os.makedirs(dataset_directory)
30-
file_name = dataset_path + dataset + '.tar.gz'
30+
file_name = dataset_path + dataset + ".tar.gz"
3131

3232
# Check if the dataset has been downloaded, else download it
3333
if os.path.isfile(file_name):
@@ -36,7 +36,7 @@ def downloadData(data_path='/input/speech_commands/'):
3636
print("Downloading '{}' into '{}' file".format(url, file_name))
3737

3838
data_request = requests.get(url)
39-
with open(file_name, 'wb') as file:
39+
with open(file_name, "wb") as file:
4040
file.write(data_request.content)
4141

4242
# Extract downloaded file
@@ -54,7 +54,7 @@ def downloadData(data_path='/input/speech_commands/'):
5454
print("Input data setup successful.")
5555

5656

57-
def getDataDict(data_path='/input/speech_commands/'):
57+
def getDataDict(data_path="/input/speech_commands/"):
5858
"""
5959
Creates a dictionary with train, validation, dev and test file names and labels.
6060
:param data_path: Path to the downloaded dataset
@@ -64,24 +64,24 @@ def getDataDict(data_path='/input/speech_commands/'):
6464
data_path = Path(os.path.abspath(__file__)).parent.parent + data_path
6565

6666
# Get the validation files
67-
validation_files = open(data_path + 'train/validation_list.txt').read().splitlines()
68-
validation_files = [data_path + 'train/' + file_name for file_name in validation_files]
67+
validation_files = open(data_path + "train/validation_list.txt").read().splitlines()
68+
validation_files = [data_path + "train/" + file_name for file_name in validation_files]
6969

7070
# Get the dev files
71-
dev_files = open(data_path + 'train/testing_list.txt').read().splitlines()
72-
dev_files = [data_path + 'train/' + file_name for file_name in dev_files]
71+
dev_files = open(data_path + "train/testing_list.txt").read().splitlines()
72+
dev_files = [data_path + "train/" + file_name for file_name in dev_files]
7373

7474
# Find train_files as all_files minus validation_files and dev_files
7575
all_files = []
76-
for root, dirs, files in os.walk(data_path + 'train/'):
77-
all_files += [root + '/' + file_name for file_name in files if file_name.endswith('.wav')]
76+
for root, dirs, files in os.walk(data_path + "train/"):
77+
all_files += [root + "/" + file_name for file_name in files if file_name.endswith(".wav")]
7878

7979
train_files = list(set(all_files) - set(validation_files) - set(dev_files))
8080

8181
# Get the test files
8282
test_files = list()
83-
for root, dirs, files in os.walk(data_path + 'test/'):
84-
test_files += [root + '/' + file_name for file_name in files if file_name.endswith('.wav')]
83+
for root, dirs, files in os.walk(data_path + "test/"):
84+
test_files += [root + "/" + file_name for file_name in files if file_name.endswith(".wav")]
8585

8686
# Get labels
8787
validation_file_labels = [getLabel(wav) for wav in validation_files]
@@ -90,17 +90,12 @@ def getDataDict(data_path='/input/speech_commands/'):
9090
test_file_labels = [getLabel(wav) for wav in test_files]
9191

9292
# Create dictionaries containing (file, labels)
93-
trainData = {'files': train_files, 'labels': train_file_labels}
94-
valData = {'files': validation_files, 'labels': validation_file_labels}
95-
devData = {'files': dev_files, 'labels': dev_file_labels}
96-
testData = {'files': test_files, 'labels': test_file_labels}
97-
98-
dataDict = {
99-
'train': trainData,
100-
'val': valData,
101-
'dev': devData,
102-
'test': testData
103-
}
93+
trainData = {"files": train_files, "labels": train_file_labels}
94+
valData = {"files": validation_files, "labels": validation_file_labels}
95+
devData = {"files": dev_files, "labels": dev_file_labels}
96+
testData = {"files": test_files, "labels": test_file_labels}
97+
98+
dataDict = {"train": trainData, "val": valData, "dev": devData, "test": testData}
10499

105100
return dataDict
106101

@@ -112,8 +107,8 @@ def getLabel(file_name):
112107
:return: Class label
113108
"""
114109

115-
category = file_name.split('/')[-2]
116-
label = categories.get(category, categories['_background_noise_'])
110+
category = file_name.split("/")[-2]
111+
label = categories.get(category, categories["_background_noise_"])
117112

118113
return label
119114

@@ -127,9 +122,9 @@ def getDataframe(data, include_unknown=False):
127122
"""
128123

129124
df = pd.DataFrame(data)
130-
df['category'] = df.apply(lambda row: inv_categories[row['labels']], axis=1)
125+
df["category"] = df.apply(lambda row: inv_categories[row["labels"]], axis=1)
131126

132127
if not include_unknown:
133-
df = df.loc[df['category'] != '_background_noise_', :]
128+
df = df.loc[df["category"] != "_background_noise_", :]
134129

135130
return df

src/main.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import os
2-
from src.model_train import model_train, marvin_kws_model
3-
from src.model_test import marvin_model_test
2+
from model_train import model_train, marvin_kws_model
3+
from model_test import marvin_model_test
44

55

66
def main():
7-
trained = os.path.isfile('../models/marvin_kws_svm.pickle') \
8-
and os.path.isfile('../models/marvin_kws_pca.pickle')
7+
trained = os.path.isfile("../models/marvin_kws_svm.pickle") and os.path.isfile("../models/marvin_kws_pca.pickle")
98

109
if not trained:
1110
print("Training model")
@@ -16,5 +15,5 @@ def main():
1615
marvin_model_test()
1716

1817

19-
if __name__ == '__main__':
18+
if __name__ == "__main__":
2019
main()

src/model_test.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import tensorflow as tf
44
from tensorflow.keras.models import Model, load_model
55

6-
from src.parameters import *
7-
from src.utils import OC_Statistics
8-
from src.utils import getDataset
9-
from src.get_data import downloadData, getDataDict, getDataframe
6+
from parameters import *
7+
from utils import OC_Statistics
8+
from utils import getDataset
9+
from get_data import downloadData, getDataDict, getDataframe
1010

1111

1212
def marvin_model_test():
@@ -16,43 +16,38 @@ def marvin_model_test():
1616
"""
1717

1818
# Download data
19-
downloadData(data_path='/input/speech_commands/')
19+
downloadData(data_path="/input/speech_commands/")
2020

2121
# Get dictionary with files and labels
22-
dataDict = getDataDict(data_path='/input/speech_commands/')
22+
dataDict = getDataDict(data_path="/input/speech_commands/")
2323

2424
# Obtain the evaluation dataframe by concatenating the dev and test datasets
25-
devDF = getDataframe(dataDict['dev'], include_unknown=True)
26-
testDF = getDataframe(dataDict['test'], include_unknown=True)
25+
devDF = getDataframe(dataDict["dev"], include_unknown=True)
26+
testDF = getDataframe(dataDict["test"], include_unknown=True)
2727

2828
evalDF = pd.concat([devDF, testDF], ignore_index=True)
2929

3030
print("Test files: {}".format(evalDF.shape[0]))
3131

3232
# Label each row as Marvin (1) or other (-1)
33-
evalDF['class'] = evalDF.apply(lambda row: 1 if row['category'] == 'marvin' else -1, axis=1)
34-
evalDF.drop('category', axis=1)
35-
test_true_labels = evalDF['class'].tolist()
33+
evalDF["class"] = evalDF.apply(lambda row: 1 if row["category"] == "marvin" else -1, axis=1)
34+
evalDF.drop("category", axis=1)
35+
test_true_labels = evalDF["class"].tolist()
3636

37-
eval_data, _ = getDataset(
38-
df=evalDF,
39-
batch_size=BATCH_SIZE,
40-
cache_file='kws_val_cache',
41-
shuffle=False
42-
)
37+
eval_data, _ = getDataset(df=evalDF, batch_size=BATCH_SIZE, cache_file="kws_val_cache", shuffle=False)
4338

4439
# Load trained model
45-
model = load_model('../models/marvin_kws.h5')
40+
model = load_model("../models/marvin_kws.h5")
4641

47-
layer_name = 'features256'
42+
layer_name = "features256"
4843
feature_extractor = Model(inputs=model.input, outputs=model.get_layer(layer_name).output)
4944

5045
# Load trained PCA object
51-
with open('../models/marvin_kws_pca.pickle', "rb") as file:
46+
with open("../models/marvin_kws_pca.pickle", "rb") as file:
5247
pca = pickle.load(file)
5348

5449
# Load trained SVM
55-
with open('../models/marvin_kws_svm.pickle', "rb") as file:
50+
with open("../models/marvin_kws_svm.pickle", "rb") as file:
5651
marvin_svm = pickle.load(file)
5752

5853
# Extract the feature embeddings and evaluate using SVM
@@ -61,4 +56,4 @@ def marvin_model_test():
6156
X_test_scaled = pca.transform(X_test)
6257
test_pred_labels = marvin_svm.predict(X_test_scaled)
6358

64-
OC_Statistics(test_pred_labels, test_true_labels, 'marvin_cm_without_noise')
59+
OC_Statistics(test_pred_labels, test_true_labels, "marvin_cm_without_noise")

0 commit comments

Comments
 (0)