From 5476fe3764b8fa2a0d064f7534ae6cf0fe20f16a Mon Sep 17 00:00:00 2001 From: UltralyticsAssistant Date: Mon, 26 Aug 2024 15:49:12 +0200 Subject: [PATCH] Refactor code for speed and clarity --- README.md | 2 +- gcp/wave_pytorch_gcp.py | 5 ++++- train.py | 9 ++++++++- train_tf.py | 1 - utils/utils.py | 2 ++ 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d92adc7..d21e084 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Welcome to the [Ultralytics WAVE repository](https://github.com/ultralytics/wave Here, we introduce **WA**veform **V**ector **E**xploitation (WAVE), a novel approach that uses Deep Learning to readout and reconstruct signals from particle physics detectors. This repository contains our open-source codebase and aims to foster collaboration and innovation in this exciting intersection of ML and physics. -[![Ultralytics Actions](https://github.com/ultralytics/wave/actions/workflows/format.yml/badge.svg)](https://github.com/ultralytics/wave/actions/workflows/format.yml) Discord Ultralytics Forums +[![Ultralytics Actions](https://github.com/ultralytics/wave/actions/workflows/format.yml/badge.svg)](https://github.com/ultralytics/wave/actions/workflows/format.yml) Discord Ultralytics Forums Ultralytics Reddit ## 🚀 Project Objectives diff --git a/gcp/wave_pytorch_gcp.py b/gcp/wave_pytorch_gcp.py index 5bd7ecc..20cb574 100644 --- a/gcp/wave_pytorch_gcp.py +++ b/gcp/wave_pytorch_gcp.py @@ -4,7 +4,6 @@ import scipy.io import torch - from utils import * # set printoptions @@ -117,6 +116,8 @@ def runexample(H, model, str, lr=0.001, amsgrad=False): class LinearAct(torch.nn.Module): + """Applies a linear transformation followed by Tanh activation to the input tensor.""" + def __init__(self, nx, ny): """Initializes the LinearAct module with input and output dimensions and defines a linear transformation followed by a Tanh activation. @@ -131,6 +132,8 @@ def forward(self, x): class WAVE(torch.nn.Module): + """A neural network model for waveform data processing with multiple linear and activation layers.""" + def __init__(self, n): # n = [512, 108, 23, 5, 1] """Initializes the WAVE model with specified linear layers and activation functions.""" super(WAVE, self).__init__() diff --git a/train.py b/train.py index 947eda4..4e655c1 100644 --- a/train.py +++ b/train.py @@ -3,7 +3,6 @@ import scipy.io import torch.nn as nn - from utils.torch_utils import * from utils.utils import * @@ -122,6 +121,8 @@ def train(H, model, str, lr=0.001): # 400 5.1498e-05 0.023752 12.484 0.15728 # var 0 class WAVE(torch.nn.Module): + """A neural network model for waveform data regression with three fully connected layers.""" + def __init__(self, n=(512, 64, 8, 2)): """Initializes the WAVE model architecture with specified layer sizes.""" super(WAVE, self).__init__() @@ -139,6 +140,8 @@ def forward(self, x): # x.shape = [bs, 512] # https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/02-intermediate # 121 0.47059 0.0306 14.184 0.1608 class WAVE4(nn.Module): + """Implements a convolutional neural network for waveform data processing with configurable output layers.""" + def __init__(self, n_out=2): """Initializes the WAVE4 model with specified output layers and configurations for convolutional layers.""" super(WAVE4, self).__init__() @@ -170,6 +173,8 @@ def forward(self, x): # x.shape = [bs, 512] # 65 4.22e-05 0.021527 11.883 0.14406 class WAVE3(nn.Module): + """Implements a convolutional neural network for feature extraction and classification from waveform data.""" + def __init__(self, n_out=2): """Initializes the WAVE3 class with neural network layers for feature extraction and classification in a sequential manner. @@ -215,6 +220,8 @@ def forward(self, x): # x.shape = [bs, 512] # 121 2.6941e-05 0.021642 11.923 0.14201 # var 1 class WAVE2(nn.Module): + """Implements the WAVE2 model for processing input tensors through convolutional layers for feature extraction.""" + def __init__(self, n_out=2): """Initializes the WAVE2 model architecture components.""" super(WAVE2, self).__init__() diff --git a/train_tf.py b/train_tf.py index dcfc5e2..625f955 100644 --- a/train_tf.py +++ b/train_tf.py @@ -5,7 +5,6 @@ import scipy.io import tensorflow as tf from plotly.offline import plot - from utils.utils import * tf.enable_eager_execution() diff --git a/utils/utils.py b/utils/utils.py index c259529..46db4ec 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -72,6 +72,8 @@ def model_info(model): class patienceStopper(object): + """Implements early stopping mechanism for training models based on validation loss and patience criteria.""" + def __init__(self, patience=10, verbose=True, epochs=1000, printerval=10): """Initialize a patience stopper with given parameters for early stopping in training.""" self.patience = patience