diff --git a/README.md b/README.md index cd86ffc..63bf555 100644 --- a/README.md +++ b/README.md @@ -1,84 +1,118 @@ [![pages-build-deployment](https://github.com/a-rouxel/simca/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/a-rouxel/simca/actions/workflows/pages/pages-build-deployment) - # SIMCA: Coded Aperture Snapshot Spectral Imaging (CASSI) Simulator -SIMCA is a Python/Qt application designed to perform optical simulations for Coded Aperture Snapshot Spectral Imaging (CASSI). +SIMCA is a Python/Qt application designed to perform optical simulations for Coded Aperture Snapshot Spectral Imaging (CASSI). Go check the documentation page [here](https://a-rouxel.github.io/simca/) ## Installation -To install SIMCA, follow the steps below: +To perform again our experiments on SIMCA for the Optical Sensing Congress 2024, follow the steps below: 1. Clone the repository from Github: ```bash -git clone https://github.com/a-rouxel/simca.git +git clone -b optica-sensing-congress https://github.com/a-rouxel/simca.git cd simca ``` 2. Create a dedicated Python environment using Miniconda. If you don't have Miniconda installed, you can find the instructions [here](https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html). ```bash -# Create a new Python environment -conda create -n simca-env python=3.9 - -# Activate the environment -conda activate simca-env +# Create a new Python environment with the required packages +conda env create -f environment.yml ``` -3. Install the necessary Python packages that SIMCA relies on. These are listed in the `requirements.txt` file in the repository. +3. Activate the environment ```bash -# Install necessary Python packages with pip -pip install -r requirements.txt +conda activate simca ``` ## Download datasets -4. Download the standard datasets from this [link](https://cloud.laas.fr/index.php/s/LQjWVsZgeq27Wz6/download), then unzip and paste the `datasets` folder in the root directory of SIMCA. - -## Quick Start with GUI (option 1) - -5. Start the application: - +4. Download the standard datasets from this [link (soon)](https://placeholder), then unzip and paste the `datasets_reconstruction` folder in the root directory of SIMCA. ```bash -# run the app -python main.py +|--simca + |--MST + |--simca + |--utils + |--datasets_reconstruction + |--mst_datasets + |--cave_1024_28_test + |--scene2.mat + : + |--scene191.mat + |--cave_1024_28_train + |--scene1.mat + : + |--scene205.mat ``` -## Quick Start from script (option 2) - -5. Run the example script : +## Download the checkpoints +5. If you want to use our saved checkpoints to run the architecture on the test dataset, download the checkpoints from this [link (soon)](https://placeholder), then unzip and paste the `saved_checkpoints` folder in the root directory of SIMCA. ```bash -# run the script -python simple_script.py +|--simca + |--MST + |--simca + |--utils + |--datasets_reconstruction + |--mst_datasets + |--cave_1024_28_test + |--scene2.mat + : + |--scene191.mat + |--cave_1024_28_train + |--scene1.mat + : + |--scene205.mat + |--saved_checkpoints ``` -## Main Features - -SIMCA includes four main features: - -- **Scene Analysis**: This module is used to analyze input multi- or hyper-spectral input scenes. It includes data slices, spectrum analysis, and ground truth labeling. +## Train the framework from scratch -- **Optical Design**: This module is used to compare the performances of various optical systems. +If you wish to train the framework again : +1. Run the training_simca_reconstruction.py script, this corresponds to the reconstruction with random masks: +```bash +python training_simca_reconstruction.py +``` +The checkpoints will be saved in the ```checkpoints``` folder. -- **Coded Aperture Patterns Generation**: This module is used to generate spectral and/or spatial filtering, based on the optical design. +2. Run the training_simca_reconstruction_full scripts, this corresponds to the reconstruction either with fine-tuned float or binary masks: +```bash +python training_simca_reconstruction_full_binary.py +python training_simca_reconstruction_full_float.py +``` +The checkpoints will be saved in the ```checkpoints_full_binary``` and ```checkpoints_full_float``` folders . -- **Acquisition Coded Images**: This module is used to encode and acquire images. +## Testing the framework -For more detailed information about each feature and further instructions, please visit our [documentation website](https://a-rouxel.github.io/simca/). +If you wish to test the framework: -## Testing the package +0. If you want to use other checkpoints than the ones provided, change the value of the following variables with the path of your checkpoints: +```bash +test_simca_reconstruction.py > reconstruction_checkpoint +test_simca_reconstruction_full_binary.py > reconstruction_checkpoint, full_model_checkpoint +test_simca_reconstruction_full_float.py > reconstruction_checkpoint, full_model_checkpoint +``` +```reconstruction_checkpoint``` is the path to the checkpoint generated in the ```checkpoints``` folder. +```full_model_checkpoint``` is the path to the checkpoint generated in the ```checkpoints_full_binary``` and ```checkpoints_full_float``` folders respectively. -If you wish to run tests on the simca package functionnalities: +1. Run the test scripts: -1. Run the tests.py script: +```bash +python test_simca_reconstruction.py +python test_simca_reconstruction_full_binary.py +python test_simca_reconstruction_full_float.py +``` +The results will be saved in the ```results``` folder. Afterwards, you can also run the ```summarize_results.py``` script to average results over the runs and per scene. +2. (Optional) Run the visualization script: +With this script you will be able to compare the reconstruction spectra of a few points in a scene. ```bash -python tests.py +python show_spectrum.py ``` ## Building Documentation @@ -115,4 +149,4 @@ SIMCA is licensed under the [GNU General Public License](https://www.gnu.org/lic ## Contact -For any questions or feedback, please contact us at arouxel@laas.fr +For any questions or feedback, please contact us at lpaillet@laas.fr diff --git a/training_simca_reconstruction.py b/training_simca_reconstruction.py index 2d0f5ee..9df79c6 100755 --- a/training_simca_reconstruction.py +++ b/training_simca_reconstruction.py @@ -4,15 +4,12 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger import torch -import datetime data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28_train" # Folder where the train dataset is datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=11) -datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') - name = "training_simca_reconstruction" model_name = "dauhst_9" @@ -35,7 +32,7 @@ checkpoint_callback = ModelCheckpoint( monitor='val_loss', # Metric to monitor dirpath='checkpoints/', # Directory path for saving checkpoints - filename=f'best-checkpoint_{model_name}_{datetime_}', # Checkpoint file name + filename=f'best-checkpoint_{model_name}', # Checkpoint file name save_top_k=1, # Save the top k models mode='min', # 'min' for metrics where lower is better, 'max' for vice versa save_last=True # Additionally, save the last checkpoint to a file named 'last.ckpt' diff --git a/training_simca_reconstruction_full_binary.py b/training_simca_reconstruction_full_binary.py index faff09b..d6e91f9 100755 --- a/training_simca_reconstruction_full_binary.py +++ b/training_simca_reconstruction_full_binary.py @@ -4,19 +4,15 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger import torch -import datetime - data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28_train" # Folder where the train dataset is datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=5, augment=True) -datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') - name = "training_simca_reconstruction_full_binary" model_name = "dauhst_9" -reconstruction_checkpoint = "./saved_checkpoints/best-checkpoint-recons-only.ckpt" +reconstruction_checkpoint = f"./checkpoints/best-checkpoint_{model_name}.ckpt" mask_model = "learned_mask" @@ -39,7 +35,7 @@ checkpoint_callback = ModelCheckpoint( monitor='val_ssim_loss', # Metric to monitor dirpath='checkpoints_full_binary/', # Directory path for saving checkpoints - filename=f'best-checkpoint_{model_name}_{datetime_}', # Checkpoint file name + filename=f'best-checkpoint_full_binary_{model_name}', # Checkpoint file name save_top_k=1, # Save the top k models mode='max', # 'min' for metrics where lower is better, 'max' for vice versa save_last=True # Additionally, save the last checkpoint to a file named 'last.ckpt' diff --git a/training_simca_reconstruction_full_float.py b/training_simca_reconstruction_full_float.py index 03e3403..c37d799 100755 --- a/training_simca_reconstruction_full_float.py +++ b/training_simca_reconstruction_full_float.py @@ -4,19 +4,16 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger import torch -import datetime data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28_train" # Folder where the train dataset is datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=5, augment=True) -datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') - name = "training_simca_reconstruction_full_float" model_name = "dauhst_9" -reconstruction_checkpoint = "./saved_checkpoints/best-checkpoint-recons-only.ckpt" +reconstruction_checkpoint = f"./checkpoints/best-checkpoint_{model_name}.ckpt" mask_model = "learned_mask_float" @@ -39,7 +36,7 @@ checkpoint_callback = ModelCheckpoint( monitor='val_ssim_loss', # Metric to monitor dirpath='checkpoints_full_float/', # Directory path for saving checkpoints - filename=f'best-checkpoint_{model_name}_{datetime_}', # Checkpoint file name + filename=f'best-checkpoint_full_float_{model_name}', # Checkpoint file name save_top_k=1, # Save the top k models mode='max', # 'min' for metrics where lower is better, 'max' for vice versa save_last=True # Additionally, save the last checkpoint to a file named 'last.ckpt'