Skip to content

Commit

Permalink
Refactor code for speed and clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
UltralyticsAssistant committed Aug 26, 2024
1 parent c12011a commit 5476fe3
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Welcome to the [Ultralytics WAVE repository](https://github.com/ultralytics/wave

Here, we introduce **WA**veform **V**ector **E**xploitation (WAVE), a novel approach that uses Deep Learning to readout and reconstruct signals from particle physics detectors. This repository contains our open-source codebase and aims to foster collaboration and innovation in this exciting intersection of ML and physics.

[![Ultralytics Actions](https://github.com/ultralytics/wave/actions/workflows/format.yml/badge.svg)](https://github.com/ultralytics/wave/actions/workflows/format.yml) <a href="https://ultralytics.com/discord"><img alt="Discord" src="https://img.shields.io/discord/1089800235347353640?logo=discord&logoColor=white&label=Discord&color=blue"></a> <a href="https://community.ultralytics.com"><img alt="Ultralytics Forums" src="https://img.shields.io/discourse/users?server=https%3A%2F%2Fcommunity.ultralytics.com&logo=discourse&label=Forums&color=blue"></a>
[![Ultralytics Actions](https://github.com/ultralytics/wave/actions/workflows/format.yml/badge.svg)](https://github.com/ultralytics/wave/actions/workflows/format.yml) <a href="https://ultralytics.com/discord"><img alt="Discord" src="https://img.shields.io/discord/1089800235347353640?logo=discord&logoColor=white&label=Discord&color=blue"></a> <a href="https://community.ultralytics.com"><img alt="Ultralytics Forums" src="https://img.shields.io/discourse/users?server=https%3A%2F%2Fcommunity.ultralytics.com&logo=discourse&label=Forums&color=blue"></a> <a href="https://reddit.com/r/ultralytics"><img alt="Ultralytics Reddit" src="https://img.shields.io/reddit/subreddit-subscribers/ultralytics?style=flat&logo=reddit&logoColor=white&label=Reddit&color=blue"></a>

## 🚀 Project Objectives

Expand Down
5 changes: 4 additions & 1 deletion gcp/wave_pytorch_gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import scipy.io
import torch

from utils import *

# set printoptions
Expand Down Expand Up @@ -117,6 +116,8 @@ def runexample(H, model, str, lr=0.001, amsgrad=False):


class LinearAct(torch.nn.Module):
"""Applies a linear transformation followed by Tanh activation to the input tensor."""

def __init__(self, nx, ny):
"""Initializes the LinearAct module with input and output dimensions and defines a linear transformation
followed by a Tanh activation.
Expand All @@ -131,6 +132,8 @@ def forward(self, x):


class WAVE(torch.nn.Module):
"""A neural network model for waveform data processing with multiple linear and activation layers."""

def __init__(self, n): # n = [512, 108, 23, 5, 1]
"""Initializes the WAVE model with specified linear layers and activation functions."""
super(WAVE, self).__init__()
Expand Down
9 changes: 8 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import scipy.io
import torch.nn as nn

from utils.torch_utils import *
from utils.utils import *

Expand Down Expand Up @@ -122,6 +121,8 @@ def train(H, model, str, lr=0.001):

# 400 5.1498e-05 0.023752 12.484 0.15728 # var 0
class WAVE(torch.nn.Module):
"""A neural network model for waveform data regression with three fully connected layers."""

def __init__(self, n=(512, 64, 8, 2)):
"""Initializes the WAVE model architecture with specified layer sizes."""
super(WAVE, self).__init__()
Expand All @@ -139,6 +140,8 @@ def forward(self, x): # x.shape = [bs, 512]
# https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/02-intermediate
# 121 0.47059 0.0306 14.184 0.1608
class WAVE4(nn.Module):
"""Implements a convolutional neural network for waveform data processing with configurable output layers."""

def __init__(self, n_out=2):
"""Initializes the WAVE4 model with specified output layers and configurations for convolutional layers."""
super(WAVE4, self).__init__()
Expand Down Expand Up @@ -170,6 +173,8 @@ def forward(self, x): # x.shape = [bs, 512]

# 65 4.22e-05 0.021527 11.883 0.14406
class WAVE3(nn.Module):
"""Implements a convolutional neural network for feature extraction and classification from waveform data."""

def __init__(self, n_out=2):
"""Initializes the WAVE3 class with neural network layers for feature extraction and classification in a
sequential manner.
Expand Down Expand Up @@ -215,6 +220,8 @@ def forward(self, x): # x.shape = [bs, 512]

# 121 2.6941e-05 0.021642 11.923 0.14201 # var 1
class WAVE2(nn.Module):
"""Implements the WAVE2 model for processing input tensors through convolutional layers for feature extraction."""

def __init__(self, n_out=2):
"""Initializes the WAVE2 model architecture components."""
super(WAVE2, self).__init__()
Expand Down
1 change: 0 additions & 1 deletion train_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import scipy.io
import tensorflow as tf
from plotly.offline import plot

from utils.utils import *

tf.enable_eager_execution()
Expand Down
2 changes: 2 additions & 0 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def model_info(model):


class patienceStopper(object):
"""Implements early stopping mechanism for training models based on validation loss and patience criteria."""

def __init__(self, patience=10, verbose=True, epochs=1000, printerval=10):
"""Initialize a patience stopper with given parameters for early stopping in training."""
self.patience = patience
Expand Down

0 comments on commit 5476fe3

Please sign in to comment.