Skip to content

Commit

Permalink
Merge pull request #17 from gdmnl/zulun
Browse files Browse the repository at this point in the history
add the degree output
  • Loading branch information
ZulunZhu authored Apr 4, 2024
2 parents e31f417 + 86a6a74 commit 969fd8e
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 17 deletions.
49 changes: 44 additions & 5 deletions examples/run_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
from itertools import product
import os
import json
import argparse
import pandas as pd
from trainer import DatasetLoader, ModelLoader
from pathlib import Path
Expand Down Expand Up @@ -72,7 +74,9 @@ def update_results_csv(result_path, model_name, dataset_name, new_result):

print(f"Updated CSV file at {result_path} with new results for model '{model_name}' and dataset '{dataset_name}'.")

def search_hyperparameters(parser, method, model):
def search_hyperparameters(args, method, model):
setattr(args, 'grid_search_flag', True)
setattr(args, 'degree_flag', False)
# Ensure method and model are valid
if method in method_model_configs and model in method_model_configs[method]:
# Get hyperparameters for the specific method and model
Expand All @@ -96,8 +100,30 @@ def search_hyperparameters(parser, method, model):
else:
print("Invalid method or model configuration.")

def get_degree_accuracy(args, method, model):
setattr(args, 'grid_search_flag', False)
setattr(args, 'degree_flag', True)
args.logpath = setup_logpath(
folder_args=(args.data, args.model, args.flag),
quiet=args.quiet)
# Ensure method and model are valid
if method in method_model_configs and model in method_model_configs[method]:
optimal_path = os.path.join(args.logpath, args.conv+ "_optimal/config.json")
hyperparams = method_model_configs[method][model]
with open(optimal_path, 'r') as file:
config_dict = json.load(file)

# Convert the dictionary to a namespace object
for param in hyperparams:
setattr(args, param, int(config_dict[param])) if isinstance(config_dict[param], int) else setattr(args, param, float(config_dict[param]))
# args = argparse.Namespace(**config_dict)
print('args:',args.data)
main(args)


else:
print("Invalid method or model configuration.")

def main(args):
# ========== Run configuration
args.logpath = setup_logpath(
Expand Down Expand Up @@ -129,12 +155,17 @@ def main(args):
logger.log(logging.LRES, f"[res]: {res_logger}")
current_accuracy = float(res_logger.get_str("f1micro_test", 0).split(":")[1])
print("result:", current_accuracy)
f1micro_high = float(res_logger.get_str("f1micro_high", 0).split(":")[1])
f1micro_low = float(res_logger.get_str("f1micro_low", 0).split(":")[1])
print("f1micro_high:", f1micro_high, " f1micro_low", f1micro_low)


res_logger.save()
clear_logger(logger)
global best_accuracy
table_path = Path('../log/sum_table.csv')
if current_accuracy>best_accuracy:
degree_table_path = Path('../log/degree_table.csv')
if current_accuracy>best_accuracy and args.grid_search_flag:
setattr(args, 'optimal_accuracy', current_accuracy)
optimal_path = os.path.join(args.logpath, args.conv+ "_optimal")
if not os.path.exists(optimal_path):
Expand All @@ -143,16 +174,24 @@ def main(args):
best_accuracy = current_accuracy
update_results_csv(table_path, args.model + "-" + args.conv+ "-" +args.theta, args.data, best_accuracy)


if args.degree_flag:
#Just output the degree accuracy using the optimal config
update_results_csv(degree_table_path, args.model + "-" + args.conv+ "-" +args.theta, args.data+'-'+'high', f1micro_high)
update_results_csv(degree_table_path, args.model + "-" + args.conv+ "-" +args.theta, args.data+'-'+'low', f1micro_low)

if __name__ == '__main__':
parser = setup_argparse()
# Experiment-specific arguments
# parser.add_argument()
args = setup_args(parser)

# Example usage
search_hyperparameters(args, args.model, args.conv)
# Example usage for grid search
# search_hyperparameters(args, args.model, args.conv)

# Get the degree accuracy when you have an optimal config
get_degree_accuracy(args, args.model, args.conv)





Expand Down
23 changes: 12 additions & 11 deletions examples/scripts/grid_search_all.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Define the list of datasets
# datasets=("Cora" "CiteSeer" "PubMed")
# datasets=("Cora" "CiteSeer" "PubMed" "Texas" "Squirrel" "Chameleon")

datasets=("texas" "squirrel" "chameleon")

# Loop through each dataset and run the Python script with the parameters
Expand All @@ -8,15 +9,15 @@ for graph in "${datasets[@]}"; do
#fixed-linear
python run_grid_search.py --model IterGNN --conv FixLinSumAdj --data $graph --theta khop
#fixed-Impulse
python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta khop
#fixed-Monomial
python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta mono
#fixed-ppr
python run_grid_search.py --model PostMLP --conv FixSumAdj --data $graph --theta appr
#fixed-heart kernel
python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta hk
#fixed-guassian
python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta gaussian
# python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta khop
# #fixed-Monomial
# python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta mono
# #fixed-ppr
# python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta appr
# #fixed-heart kernel
# python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta hk
# #fixed-guassian
# python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta gaussian

#var-Linear-unfinished
# python run_grid_search.py --model IterGNN --conv VarSumAdj --data $graph --theta khop
Expand All @@ -28,7 +29,7 @@ for graph in "${datasets[@]}"; do
python run_grid_search.py --model PostMLP --conv ChebBase --data $graph
#var-Chebyshev2
python run_grid_search.py --model PostMLP --conv ChebConv2 --data $graph
#var-Bernstein
var-Bernstein
python run_grid_search.py --model PostMLP --conv BernConv --data $graph

#bank-linear
Expand Down
2 changes: 1 addition & 1 deletion examples/trainer/fullbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def run(self) -> ResLogger:
res_test = self.test()
res_run.merge(res_test)

self.test_deg()
res_run.merge(self.test_deg())

return self.res_logger.merge(res_run)

Expand Down

0 comments on commit 969fd8e

Please sign in to comment.