Add cnn model #813
Draft: sevmag wants to merge 31 commits into graphnet-team:main from sevmag:add_cnn_model.
Commits (31)
All commits by sevmag:

- a143af5 Adding Images
- 983427a fix init
- 6866c10 fix logic for detector in IC86 Image
- cc2dd55 Fixing bugs TheoCNN
- 16660c9 more fixes for cnn
- ebef5bf Fixing batching for images
- 95051cf Adjusting imports
- b04de2f Fixing gitignore & mapping_table
- a50ff96 fixing image num_nodes
- 380a8d1 adding_counts to summary features
- 9ae74b1 change mapping to faster version
- d3a5d4a Faster Mapping & unit tests
- 1bcfd83 Rename classes & more unit tests
- 194cce1 Adding LCSC model
- e52b215 Changing annotations and docstrings
- 58e2898 Adding example script
- 7ebc592 adding cnn example
- ad608e2 Adjust docstring
- 7c038a8 Fixing comments in example
- 1e38718 Add more to docstring in LCSC
- 62d510e adjust docstrings theos cnn
- 008961f docstring clean ups
- 7d02215 add info to docstring
- ddcb8db add shape property
- 43dc067 renaming the module to IceCubeDNN
- 8e5ff86 removing unecessary typing annotations in docstring
- 696d053 cleaned up the __init__ of LCSC and defined parsing private functions
- 77b7354 make IceCubeDNN configurable
- 7a0eabd adjusted docstring and assertion logic of LCSC to clarify detector usage
- 9ea5286 fix spelling mistake and clarify comment
- c2c3873 add shape property to imagedefinition
Files changed (five binary files not shown).
New file (343 lines added): example script for training a CNN model.

```python
"""Example of training a CNN Model."""

import os
from typing import Any, Dict, List, Optional

from pytorch_lightning.loggers import WandbLogger
import torch
from torch.optim.adam import Adam

from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR
from graphnet.data.constants import TRUTH
from graphnet.models import StandardModel
from graphnet.models.cnn import LCSC
from graphnet.models.data_representation import PercentileClusters
from graphnet.models.task.reconstruction import EnergyReconstruction
from graphnet.training.callbacks import PiecewiseLinearLR
from graphnet.training.loss_functions import LogCoshLoss
from graphnet.utilities.argparse import ArgumentParser
from graphnet.utilities.logging import Logger
from graphnet.data.dataset import SQLiteDataset
from graphnet.data.dataset import ParquetDataset
from graphnet.models.detector import ORCA150
from torch_geometric.data import Batch
from graphnet.models.data_representation.images import ExamplePrometheusImage

# Constants
features = ["sensor_id", "sensor_string_id", "t"]
truth = TRUTH.PROMETHEUS


def main(
    path: str,
    pulsemap: str,
    target: str,
    truth_table: str,
    gpus: Optional[List[int]],
    max_epochs: int,
    early_stopping_patience: int,
    batch_size: int,
    num_workers: int,
    wandb: bool = False,
) -> None:
    """Run example."""
    # Construct Logger
    logger = Logger()

    # Initialise Weights & Biases (W&B) run
    if wandb:
        # Make sure W&B output directory exists
        wandb_dir = "./wandb/"
        os.makedirs(wandb_dir, exist_ok=True)
        wandb_logger = WandbLogger(
            project="example-script",
            entity="graphnet-team",
            save_dir=wandb_dir,
            log_model=True,
        )

    logger.info(f"features: {features}")
    logger.info(f"truth: {truth}")

    # Configuration
    config: Dict[str, Any] = {
        "path": path,
        "pulsemap": pulsemap,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "target": target,
        "early_stopping_patience": early_stopping_patience,
        "fit": {
            "gpus": gpus,
            "max_epochs": max_epochs,
        },
        "dataset_reference": (
            SQLiteDataset if path.endswith(".db") else ParquetDataset
        ),
    }

    archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_cnn_model")
    run_name = "lcsc_{}_example".format(config["target"])
    if wandb:
        # Log configuration to W&B
        wandb_logger.experiment.config.update(config)

    # First we need to define how the image is constructed.
    # This is done using an ImageDefinition.

    # An ImageDefinition combines two components:

    # 1. A pixel definition, which defines how the pixel data is
    #    represented. Since an image always has fixed dimensions,
    #    the pixel definition is also responsible for representing
    #    the data in a way that achieves these fixed dimensions.
    #    Typically, this means that light pulses arriving at the
    #    same optical module are aggregated into a fixed-dimensional
    #    vector. A pixel definition works exactly like a node
    #    definition in the graph scenario.

    # 2. A pixel mapping, which defines where each pixel is located
    #    in the final image. This is highly detector-specific, as it
    #    depends on the geometry of the detector.

    # An ImageDefinition can be used to create multiple images of
    # a single event. In the example of IceCube, you can e.g.
    # create three images: one for the so-called main array,
    # one for the upper deep core, and one for the lower deep core.
    # Essentially, these are just different areas in the detector.

    # Here we use the PercentileClusters pixel definition, which
    # aggregates the light pulses that arrive at the same optical
    # module with percentiles.
    pixel_definition = PercentileClusters(
        cluster_on=["sensor_id", "sensor_string_id"],
        percentiles=[10, 50, 90],
        add_counts=True,
        input_feature_names=features,
    )

    # The final image definition used here is the ExamplePrometheusImage,
    # which is a detector-specific pixel mapping. It maps optical
    # modules into the image using the sensor_string_id and sensor_id
    # (number of the optical module).
    # The detector class standardizes the input features so that they
    # are in an ML-friendly range. For the mapping of the optical
    # modules to the image it is essential not to change the values of
    # sensor_id and sensor_string_id. Therefore, we need to make sure
    # that these features are not standardized, which is done via the
    # `replace_with_identity` argument of the detector.
    image_definition = ExamplePrometheusImage(
        detector=ORCA150(
            replace_with_identity=[
                "sensor_id",
                "sensor_string_id",
            ],
        ),
        node_definition=pixel_definition,
        input_feature_names=features,
        string_label="sensor_string_id",
        dom_number_label="sensor_id",
    )

    # Use SQLiteDataset to load in data.
    # The input here depends on the dataset being used,
    # in this case the Prometheus dataset.
    dataset = SQLiteDataset(
        path=config["path"],
        pulsemaps=config["pulsemap"],
        truth_table=truth_table,
        features=features,
        truth=truth,
        data_representation=image_definition,
    )

    # Create the training and validation dataloaders.
    training_dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        collate_fn=Batch.from_data_list,
    )

    validation_dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        collate_fn=Batch.from_data_list,
    )

    # Building model

    # Define the architecture of the backbone; in this example
    # the LCSC architecture from Alexander Harnisch is used.
    backbone = LCSC(
        num_input_features=image_definition.nb_outputs,
        out_put_dim=2,
        input_norm=True,
        num_conv_layers=5,
        conv_filters=[5, 10, 20, 40, 60],
        kernel_size=3,
        image_size=(8, 9, 22),  # dimensions of the example image
        pooling_type=[
            "Avg",
            None,
            "Avg",
            None,
            "Avg",
        ],
        pooling_kernel_size=[
            [1, 1, 2],
            None,
            [2, 2, 2],
            None,
            [2, 2, 2],
        ],
        pooling_stride=[
            [1, 1, 2],
            None,
            [2, 2, 2],
            None,
            [2, 2, 2],
        ],
        num_fc_neurons=50,
        norm_list=True,
        norm_type="Batch",
    )
    # Define the task.
    # Here an energy reconstruction, with a LogCoshLoss function.
    # The target and prediction are transformed using the log10 function.
    # At inference time, the prediction is transformed back to the
    # original scale using 10^x.
    task = EnergyReconstruction(
        hidden_size=backbone.nb_outputs,
        target_labels=config["target"],
        loss_function=LogCoshLoss(),
        transform_prediction_and_target=lambda x: torch.log10(x),
        transform_inference=lambda x: torch.pow(10, x),
    )
    # Define the full model, which includes the backbone and task(s),
    # along with typical machine-learning options such as
    # learning-rate optimizers and schedulers.
    model = StandardModel(
        data_representation=image_definition,
        backbone=backbone,
        tasks=[task],
        optimizer_class=Adam,
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
        scheduler_class=PiecewiseLinearLR,
        scheduler_kwargs={
            "milestones": [
                0,
                len(training_dataloader) / 2,
                len(training_dataloader) * config["fit"]["max_epochs"],
            ],
            "factors": [1e-2, 1, 1e-02],
        },
        scheduler_config={
            "interval": "step",
        },
    )

    # Training model
    model.fit(
        training_dataloader,
        validation_dataloader,
        early_stopping_patience=config["early_stopping_patience"],
        logger=wandb_logger if wandb else None,
        **config["fit"],
    )

    # Get predictions
    additional_attributes = model.target_labels
    assert isinstance(additional_attributes, list)  # mypy

    results = model.predict_as_dataframe(
        validation_dataloader,
        additional_attributes=additional_attributes + ["event_no"],
        gpus=config["fit"]["gpus"],
    )

    # Save predictions and model to file
    db_name = path.split("/")[-1].split(".")[0]
    path = os.path.join(archive, db_name, run_name)
    logger.info(f"Writing results to {path}")
    os.makedirs(path, exist_ok=True)

    # Save results as .csv
    results.to_csv(f"{path}/cnn_results.csv")

    # Save model config and state dict - version-safe save method.
    # This method of saving models is the safest way.
    model.save_state_dict(f"{path}/cnn_state_dict.pth")
    model.save_config(f"{path}/cnn_model_config.yml")


if __name__ == "__main__":

    # Parse command-line arguments
    parser = ArgumentParser(
        description="""
Train CNN model without the use of config files.
"""
    )

    parser.add_argument(
        "--path",
        help="Path to dataset file (default: %(default)s)",
        default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db",
    )

    parser.add_argument(
        "--pulsemap",
        help="Name of pulsemap to use (default: %(default)s)",
        default="total",
    )

    parser.add_argument(
        "--target",
        help=(
            "Name of feature to use as regression target (default: "
            "%(default)s)"
        ),
        default="total_energy",
    )

    parser.add_argument(
        "--truth-table",
        help="Name of truth table to be used (default: %(default)s)",
        default="mc_truth",
    )

    parser.with_standard_arguments(
        "gpus",
        ("max-epochs", 1),
        "early-stopping-patience",
        ("batch-size", 16),
        ("num-workers", 2),
    )

    parser.add_argument(
        "--wandb",
        action="store_true",
        help="If True, Weights & Biases are used to track the experiment.",
    )

    args, unknown = parser.parse_known_args()

    main(
        args.path,
        args.pulsemap,
        args.target,
        args.truth_table,
        args.gpus,
        args.max_epochs,
        args.early_stopping_patience,
        args.batch_size,
        args.num_workers,
        args.wandb,
    )
```
New file (5 lines added):
```python
"""CNN-specific modules, for performing the main learnable operations."""

from .cnn import CNN
from .icecube_dnn import IceCubeDNN
from .lcsc import LCSC
```
Review comment:
Would be neat to add a property to the ImageDefinition that contains the resulting image dimension, e.g. `ImageDefinition.shape`.
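The suggested property could look roughly as follows. This is a minimal illustration only, using a stand-in class rather than graphnet's real `ImageDefinition`; the attribute names are hypothetical:

```python
from typing import Tuple


class ImageDefinitionSketch:
    """Stand-in for an ImageDefinition, to illustrate a `shape` property."""

    def __init__(
        self, image_size: Tuple[int, int, int], num_features: int
    ) -> None:
        # Spatial dimensions of the produced image, e.g. (8, 9, 22)
        # as in the LCSC example above.
        self._image_size = image_size
        # Number of per-pixel features produced by the pixel definition.
        self._num_features = num_features

    @property
    def shape(self) -> Tuple[int, ...]:
        """Return the full dimensions of the resulting image."""
        return (self._num_features, *self._image_size)


definition = ImageDefinitionSketch(image_size=(8, 9, 22), num_features=10)
print(definition.shape)  # (10, 8, 9, 22)
```

Commit c2c3873 ("add shape property to imagedefinition") indicates this suggestion was addressed in the PR.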