From 86a6a74ef19b27bb953cd045ea1dfef99537e023 Mon Sep 17 00:00:00 2001
From: ZulunZhu <812135011@qq.com>
Date: Thu, 4 Apr 2024 15:50:08 +0800
Subject: [PATCH] add the degree output

---
 examples/run_grid_search.py         | 49 ++++++++++++++++++++++++++---
 examples/scripts/grid_search_all.sh | 21 ++++++++-------
 examples/trainer/fullbatch.py       |  2 +-
 3 files changed, 56 insertions(+), 16 deletions(-)

diff --git a/examples/run_grid_search.py b/examples/run_grid_search.py
index a4518ff..4f993e1 100644
--- a/examples/run_grid_search.py
+++ b/examples/run_grid_search.py
@@ -4,6 +4,8 @@ import logging
 from itertools import product
 import os
+import json
+import argparse
 import pandas as pd
 from trainer import DatasetLoader, ModelLoader
 from pathlib import Path
 
@@ -72,7 +74,9 @@ def update_results_csv(result_path, model_name, dataset_name, new_result):
     print(f"Updated CSV file at {result_path} with new results for model '{model_name}' and dataset '{dataset_name}'.")
 
 
-def search_hyperparameters(parser, method, model):
+def search_hyperparameters(args, method, model):
+    setattr(args, 'grid_search_flag', True)
+    setattr(args, 'degree_flag', False)
     # Ensure method and model are valid
     if method in method_model_configs and model in method_model_configs[method]:
         # Get hyperparameters for the specific method and model
@@ -96,8 +100,30 @@ def search_hyperparameters(parser, method, model):
     else:
         print("Invalid method or model configuration.")
 
 
+def get_degree_accuracy(args, method, model):
+    setattr(args, 'grid_search_flag', False)
+    setattr(args, 'degree_flag', True)
+    args.logpath = setup_logpath(
+        folder_args=(args.data, args.model, args.flag),
+        quiet=args.quiet)
+    # Ensure method and model are valid
+    if method in method_model_configs and model in method_model_configs[method]:
+        optimal_path = os.path.join(args.logpath, args.conv+ "_optimal/config.json")
+        hyperparams = method_model_configs[method][model]
+        with open(optimal_path, 'r') as file:
+            config_dict = json.load(file)
+
+        # Convert the dictionary to a namespace object
+        for param in hyperparams:
+            setattr(args, param, int(config_dict[param])) if isinstance(config_dict[param], int) else setattr(args, param, float(config_dict[param]))
+        # args = argparse.Namespace(**config_dict)
+        print('args:',args.data)
+        main(args)
+    else:
+        print("Invalid method or model configuration.")
+
 def main(args):
     # ========== Run configuration
     args.logpath = setup_logpath(
         folder_args=(args.data, args.model, args.flag),
@@ -129,12 +155,17 @@ def main(args):
     logger.log(logging.LRES, f"[res]: {res_logger}")
     current_accuracy = float(res_logger.get_str("f1micro_test", 0).split(":")[1])
     print("result:", current_accuracy)
+    f1micro_high = float(res_logger.get_str("f1micro_high", 0).split(":")[1])
+    f1micro_low = float(res_logger.get_str("f1micro_low", 0).split(":")[1])
+    print("f1micro_high:", f1micro_high, " f1micro_low", f1micro_low)
+
     res_logger.save()
     clear_logger(logger)
 
     global best_accuracy
     table_path = Path('../log/sum_table.csv')
-    if current_accuracy>best_accuracy:
+    degree_table_path = Path('../log/degree_table.csv')
+    if current_accuracy>best_accuracy and args.grid_search_flag:
         setattr(args, 'optimal_accuracy', current_accuracy)
         optimal_path = os.path.join(args.logpath, args.conv+ "_optimal")
         if not os.path.exists(optimal_path):
@@ -143,7 +174,10 @@ def main(args):
         best_accuracy = current_accuracy
         update_results_csv(table_path, args.model + "-" + args.conv+ "-" +args.theta, args.data, best_accuracy)
-
+    if args.degree_flag:
+        #Just output the degree accuracy using the optimal config
+        update_results_csv(degree_table_path, args.model + "-" + args.conv+ "-" +args.theta, args.data+'-'+'high', f1micro_high)
+        update_results_csv(degree_table_path, args.model + "-" + args.conv+ "-" +args.theta, args.data+'-'+'low', f1micro_low)
 
 
 if __name__ == '__main__':
     parser = setup_argparse()
@@ -151,8 +185,13 @@ def main(args):
     # parser.add_argument()
     args = setup_args(parser)
 
-    # Example usage
-    search_hyperparameters(args, args.model, args.conv)
+    # Example usage for grid search
+    # search_hyperparameters(args, args.model, args.conv)
+
+    # Get the degree accuracy when you have an optimal config
+    get_degree_accuracy(args, args.model, args.conv)
+
+
 
 
 
diff --git a/examples/scripts/grid_search_all.sh b/examples/scripts/grid_search_all.sh
index 6f05a59..6e57e1d 100644
--- a/examples/scripts/grid_search_all.sh
+++ b/examples/scripts/grid_search_all.sh
@@ -1,5 +1,6 @@
 # Define the list of datasets
-# datasets=("Cora" "CiteSeer" "PubMed")
+# datasets=("Cora" "CiteSeer" "PubMed" "Texas" "Squirrel" "Chameleon")
+
 datasets=("texas" "squirrel" "chameleon")
 
 # Loop through each dataset and run the Python script with the parameters
@@ -8,15 +9,15 @@ for graph in "${datasets[@]}"; do
     #fixed-linear
     python run_grid_search.py --model IterGNN --conv FixLinSumAdj --data $graph --theta khop
     #fixed-Impulse
-    python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta khop
-    #fixed-Monomial
-    python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta mono
-    #fixed-ppr
-    python run_grid_search.py --model PostMLP --conv FixSumAdj --data $graph --theta appr
-    #fixed-heart kernel
-    python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta hk
-    #fixed-guassian
-    python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta gaussian
+    # python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta khop
+    # #fixed-Monomial
+    # python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta mono
+    # #fixed-ppr
+    # python run_grid_search.py --model PostMLP --conv FixSumAdj --data $graph --theta appr
+    # #fixed-heart kernel
+    # python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta hk
+    # #fixed-guassian
+    # python run_grid_search.py --model PreDecMLP --conv FixSumAdj --data $graph --theta gaussian
 
     #var-Linear-unfinished
     # python run_grid_search.py --model IterGNN --conv VarSumAdj --data $graph --theta khop
@@ -28,7 +29,7 @@ for graph in "${datasets[@]}"; do
     python run_grid_search.py --model PostMLP --conv ChebBase --data $graph
     #var-Chebyshev2
     python run_grid_search.py --model PostMLP --conv ChebConv2 --data $graph
     #var-Bernstein
     python run_grid_search.py --model PostMLP --conv BernConv --data $graph
 
     #bank-linear
diff --git a/examples/trainer/fullbatch.py b/examples/trainer/fullbatch.py
index 0dfdd4f..83684c8 100755
--- a/examples/trainer/fullbatch.py
+++ b/examples/trainer/fullbatch.py
@@ -135,7 +135,7 @@ def run(self) -> ResLogger:
         res_test = self.test()
         res_run.merge(res_test)
 
-        self.test_deg()
+        res_run.merge(self.test_deg())
 
         return self.res_logger.merge(res_run)
 