forked from MingjunZhong/seq2point-nilm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_main.py
34 lines (25 loc) · 2.21 KB
/
test_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
from remove_space import remove_space
from seq2point_test import Tester
# Allows a model to be tested from the terminal.
# You need to input your test data directory (or pass --test_directory).
# NOTE(review): this default is an absolute path on one developer's machine;
# override it with --test_directory when running anywhere else.
test_directory = "/Users/NunoAlberto/Desktop/Computer Science/3rd Year/tb2/Individual Project/seq2point-nilm/reddMicrowave/microwave_test_.csv"


def main(argv=None):
    """Parse command-line options and test the saved model for one appliance.

    The saved model is loaded from
    ``saved_models/<appliance>_<algorithm>_model.h5``. Results are logged to
    ``saved_models/<appliance>_<algorithm>_<network_type>.log``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. normal command-line use.
    """
    parser = argparse.ArgumentParser(description="Test a trained neural network for energy disaggregation. ")
    parser.add_argument("--appliance_name", type=remove_space, default="kettle", help="The name of the appliance to perform disaggregation with. Default is kettle. Available are: kettle, fridge, dishwasher, microwave. ")
    parser.add_argument("--batch_size", type=int, default=1000, help="The batch size to use when testing the network. Default is 1000. ")
    parser.add_argument("--crop", type=int, default=10000, help="The number of rows of the dataset to take test data from. Default is 10000. ")
    parser.add_argument("--algorithm", type=remove_space, default="seq2point", help="The algorithm name of the saved model to test; used to build the model and log file names. Default is seq2point. ")
    parser.add_argument("--network_type", type=remove_space, default="", help="The seq2point architecture to use. Only use if you do not want to use the standard architecture. Available are: default, dropout, reduced, and reduced_dropout. ")
    parser.add_argument("--input_window_length", type=int, default=599, help="Number of input data points to network. Default is 599. ")
    parser.add_argument("--test_directory", type=str, default=test_directory, help="The path to the test data file. ")
    arguments = parser.parse_args(argv)

    # You need to provide the trained model; its file name is derived from
    # the appliance and algorithm names.
    saved_model_dir = "saved_models/" + arguments.appliance_name + "_" + arguments.algorithm + "_model.h5"
    # The logs including results will be recorded to this log file.
    # With the default empty --network_type the name ends in "_.log".
    log_file_dir = "saved_models/" + arguments.appliance_name + "_" + arguments.algorithm + "_" + arguments.network_type + ".log"

    tester = Tester(arguments.appliance_name, arguments.algorithm, arguments.crop,
                    arguments.batch_size, arguments.network_type,
                    arguments.test_directory, saved_model_dir, log_file_dir,
                    arguments.input_window_length
                    )
    tester.test_model()


# Only run the test when executed as a script, not when imported.
if __name__ == "__main__":
    main()