This repository has been archived by the owner on Apr 29, 2024. It is now read-only.

bump to 0.2.0 (#101)
Bumps the version to v0.2.0.
wdika committed Sep 12, 2022
1 parent c0fdae4 commit 076b5e7
Showing 22 changed files with 643 additions and 105 deletions.
34 changes: 0 additions & 34 deletions .circleci/config.yml

This file was deleted.

7 changes: 5 additions & 2 deletions CITATION.cff
@@ -4,10 +4,13 @@ authors:
 - family-names: Karkalousos
   given-names: Dimitrios
   orcid: https://orcid.org/0000-0001-5983-0322
+- family-names: Zhang
+  given-names: Chaoping
+  orcid: https://orcid.org/0000-0002-6004-983X
 - family-names: Caan
   given-names: Matthan
   orcid: https://orcid.org/0000-0002-5162-8880
 title: "MRI Data Consistency"
 url: "https://github.com/wdika/mridc"
-version: 0.1.1
-date-released: 2022-25-05
+version: 0.2.0
+date-released: 2022-12-09
2 changes: 1 addition & 1 deletion Dockerfile
@@ -29,7 +29,7 @@ COPY . .
 
 # start building the final container
 FROM mridc-deps as mridc
-ARG MRIDC_VERSION=0.1.1
+ARG MRIDC_VERSION=0.2.0
 
 # Check that MRIDC_VERSION is set. Build will fail without this. Expose MRIDC and base container
 # version information as runtime environment variable for introspection purposes
4 changes: 2 additions & 2 deletions README.md
@@ -1,8 +1,8 @@
 # Data Consistency for Magnetic Resonance Imaging
 
 [![CodeQL](https://github.com/wdika/mridc/actions/workflows/codeql-analysis.yml/badge.svg)](https://github.com/wdika/mridc/actions/workflows/codeql-analysis.yml)
-[![CircleCI](https://circleci.com/gh/wdika/mridc/tree/main.svg?style=svg)](https://circleci.com/gh/wdika/mridc/tree/main)
 [![codecov](https://codecov.io/gh/wdika/mridc/branch/main/graph/badge.svg?token=KPPQ33DOTF)](https://codecov.io/gh/wdika/mridc)
+[![Tox](https://github.com/wdika/mridc/actions/workflows/tox.yml/badge.svg)](https://github.com/wdika/mridc/actions/workflows/tox.yml)
 <a href="https://github.com/psf/black"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
 
 ---
@@ -36,7 +36,7 @@ The following models are implemented for quantitative imaging:
 1.[quantitative Cascades of Independently Recurrent Inference Machines (qCIRIM)](https://iopscience.iop.org/article/10.1088/1361-6560/ac6cc2),
 2.[quantitative End-to-End Variational Network (qE2EVN)](https://link.springer.com/chapter/10.1007/978-3-030-59713-9_7),
 3.[quantitative Independently Recurrent Inference Machines (qIRIM)](http://arxiv.org/abs/2012.07819),
-4.[quantitative Recurrent Inference Machines (qRIM)](https://www.sciencedirect.com/science/article/abs/pii/S1361841518306078?via%3Dihub),
+4.[quantitative Recurrent Inference Machines (qRIM)](https://www.sciencedirect.com/science/article/abs/pii/S1361841518306078?via%3Dihub).
 
 _Note: Currently only the above models are implemented. More models can be added by extending the reconstruction models
 for quantitative imaging. If you wish to extend the toolbox, please open an issue._
6 changes: 6 additions & 0 deletions codecov.yml
@@ -13,6 +13,12 @@ coverage:
 # --------------
 # which folders/files to ignore
 ignore:
+  - mridc/launch.py
+  - mridc/core/utils/*
+  - mridc/utils/arguments.py
+  - mridc/utils/distributed.py
+  - mridc/utils/export_utils.py
+  - mridc/utils/decorators/*
   - projects/*
   - setup.py
 
8 changes: 8 additions & 0 deletions docs/source/mridc.collections.common.parts.rst
@@ -36,6 +36,14 @@ mridc.collections.common.parts.rnn\_utils module
    :undoc-members:
    :show-inheritance:
 
+mridc.collections.common.parts.training\_utils module
+-----------------------------------------------------
+
+.. automodule:: mridc.collections.common.parts.training_utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 mridc.collections.common.parts.utils module
 -------------------------------------------
 
8 changes: 8 additions & 0 deletions docs/source/mridc.utils.rst
@@ -45,6 +45,14 @@ mridc.utils.config\_utils module
    :undoc-members:
    :show-inheritance:
 
+mridc.utils.debug\_hook module
+------------------------------
+
+.. automodule:: mridc.utils.debug_hook
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 mridc.utils.distributed module
 ------------------------------
 
97 changes: 95 additions & 2 deletions mridc/collections/common/data/dataset.py
@@ -4,12 +4,13 @@
 # Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/data/dataset.py
 
 from abc import ABC
-from typing import Any, List
+from typing import Any, Dict, List
 
 import numpy as np
 import torch.utils.data as pt_data
 from torch.utils.data import Dataset, IterableDataset
 
-__all__ = ["ConcatDataset"]
+__all__ = ["ConcatDataset", "ConcatMapDataset"]
 
 
 class ConcatDataset(pt_data.IterableDataset, ABC):
@@ -145,3 +146,95 @@ def random_generator(datasets, **kwargs):
 
     while True:
         yield np.random.choice(np.arange(num), p=p)
+
+
+class ConcatMapDataset(Dataset):
+    """
+    A dataset that accepts multiple datasets as arguments and samples from them according to the specified
+    sampling technique.
+    Parameters
+    ----------
+    datasets: A list of datasets to sample from.
+    sampling_technique: Sampling technique to choose which dataset to draw a sample from. Defaults to
+        'temperature'. Currently supports 'temperature', 'random', and 'round-robin'.
+    sampling_temperature: Temperature used to flatten the sampling distribution when sampling_technique =
+        'temperature'. Defaults to 5.
+    sampling_probabilities: Probability values for sampling. Only used when sampling_technique = 'random'.
+    consumed_samples: Number of already-consumed samples; seeds the random state so shuffling is resumable.
+        Defaults to 0.
+    """
+
+    def __init__(
+        self,
+        datasets: List[Any],
+        sampling_technique: str = "temperature",
+        sampling_temperature: int = 5,
+        sampling_probabilities: List[float] = None,
+        consumed_samples: int = 0,
+    ):
+        super().__init__()
+        self.datasets = datasets
+        self.sampling_kwargs: Dict = {}
+        self.size = 0
+        self.sampling_technique = sampling_technique
+        self.sampling_temperature = sampling_temperature
+        self.sampling_probabilities = sampling_probabilities
+        self.consumed_samples = consumed_samples
+        self.np_rng = np.random.RandomState(consumed_samples)
+        for dataset in datasets:
+            self.size += len(dataset)
+        # int64 cursors: uint8 would silently wrap after 256 samples per dataset.
+        self.dataset_index = np.zeros(len(self.datasets), dtype=np.int64)
+        self.permuted_dataset_indices = []
+        for dataset in self.datasets:
+            permuted_indices = np.arange(len(dataset))
+            self.np_rng.shuffle(permuted_indices)
+            self.permuted_dataset_indices.append(permuted_indices)
+        if self.sampling_technique == "temperature":
+            lengths = [len(dataset) for dataset in datasets]
+            p = np.array(lengths) / np.sum(lengths)
+            p = np.power(p, 1 / self.sampling_temperature)
+            p = p / np.sum(p)
+            self.p = p
+        elif self.sampling_technique == "random":
+            if not self.sampling_probabilities:
+                raise ValueError(
+                    "Random generator expects a 'sampling_probabilities' - a list of probability values "
+                    "corresponding to each dataset."
+                )
+            if len(self.sampling_probabilities) != len(self.datasets):
+                raise ValueError(
+                    "Length of probabilities list must be equal to the number of datasets. "  # type: ignore
+                    f"Found {len(sampling_probabilities)} probs and {len(self.datasets)} datasets."  # type: ignore
+                )
+            p = np.array(self.sampling_probabilities)
+            self.p = p / np.sum(p)
+
+    def __len__(self):
+        return self.size
+
+    def _get_dataset_index(self, idx):
+        """Returns the index of the dataset to sample from."""
+        if self.sampling_technique in ["temperature", "random"]:
+            return self.np_rng.choice(np.arange(len(self.datasets)), p=self.p)
+        if self.sampling_technique == "round-robin":
+            return idx % len(self.datasets)
+        raise ValueError(f"Unsupported sampling technique: {self.sampling_technique}")
+
+    def __getitem__(self, idx):
+        # Get the dataset we want to sample from
+        dataset_index = self._get_dataset_index(idx)
+
+        # Get the index of the sample we want to fetch from the dataset
+        sample_idx = self.dataset_index[dataset_index]
+
+        # If the sample idx has reached the dataset size, wrap around to 0 (>= avoids an off-by-one IndexError).
+        if sample_idx >= len(self.datasets[dataset_index]):
+            sample_idx = 0
+            self.dataset_index[dataset_index] = 0
+
+        # Sample index -> shuffled sample index
+        shuffled_sample_idx = self.permuted_dataset_indices[dataset_index][sample_idx]
+
+        sample = self.datasets[dataset_index][shuffled_sample_idx]
+        self.dataset_index[dataset_index] += 1
+
+        return sample
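For context, a minimal usage sketch of the new class (illustrative, not part of the commit); the `TensorDataset` stand-ins and batch size are assumptions:

```python
# Illustrative sketch: combine two toy map-style datasets with the
# default temperature-based sampling (T=5).
import torch
from torch.utils.data import DataLoader, TensorDataset

ds_a = TensorDataset(torch.arange(10, dtype=torch.float32))
ds_b = TensorDataset(torch.arange(100, 200, dtype=torch.float32))

combined = ConcatMapDataset(datasets=[ds_a, ds_b])
loader = DataLoader(combined, batch_size=4)
first_batch = next(iter(loader))  # four samples drawn across both datasets
```

With temperature sampling, the length-proportional distribution [10/110, 100/110] is flattened by the exponent 1/5, so the smaller dataset is drawn roughly 39% of the time rather than its 9% share of the data.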
2 changes: 1 addition & 1 deletion mridc/collections/common/parts/patch_utils.py
@@ -7,4 +7,4 @@
 
 # Library version globals
 TORCH_VERSION = None
-TORCH_VERSION_MIN = version.Version("1.9.0")
+TORCH_VERSION_MIN = version.Version("1.8.0")
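A hedged sketch of how a minimum-version gate like this is typically consumed; the actual check lives outside this hunk, so `torch_meets_minimum` is an illustrative name, not mridc's API:

```python
# Illustrative sketch: compare the running torch version against the minimum.
import torch
from packaging import version

TORCH_VERSION_MIN = version.Version("1.8.0")

def torch_meets_minimum() -> bool:
    # Strip any local build suffix such as "+cu113" before comparing.
    return version.Version(torch.__version__.split("+")[0]) >= TORCH_VERSION_MIN
```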
28 changes: 28 additions & 0 deletions mridc/collections/common/parts/training_utils.py
@@ -0,0 +1,28 @@
+# encoding: utf-8
+__author__ = "Dimitrios Karkalousos"
+
+# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/parts/training_utils.py
+
+from contextlib import nullcontext
+
+import torch
+
+__all__ = ["avoid_bfloat16_autocast_context", "avoid_float16_autocast_context"]
+
+
+def avoid_bfloat16_autocast_context():
+    """If the current autocast context is bfloat16, cast it to float32."""
+    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16:
+        return torch.cuda.amp.autocast(dtype=torch.float32)
+    else:
+        return nullcontext()
+
+
+def avoid_float16_autocast_context():
+    """If the current autocast context is float16, cast it to bfloat16 if supported, otherwise to float32."""
+    if not torch.is_autocast_enabled() or torch.get_autocast_gpu_dtype() != torch.float16:
+        return nullcontext()
+    if torch.cuda.is_bf16_supported():
+        return torch.cuda.amp.autocast(dtype=torch.bfloat16)
+    else:
+        return torch.cuda.amp.autocast(dtype=torch.float32)
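A minimal usage sketch for the new helpers (illustrative; assumes a CUDA device):

```python
# Illustrative sketch: escape a float16 autocast region for a sensitive op.
import torch

x = torch.randn(64, 64, device="cuda")

with torch.cuda.amp.autocast(dtype=torch.float16):
    y = x @ x  # runs in float16 under autocast
    with avoid_float16_autocast_context():
        z = x @ x  # runs in bfloat16 if supported, otherwise float32
```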
19 changes: 12 additions & 7 deletions mridc/core/classes/export.py
@@ -5,6 +5,7 @@
 
 from abc import ABC
 from os.path import exists
+from typing import List, Union
 
 import torch
 from torch.onnx import TrainingMode
@@ -47,9 +48,10 @@ def export(
         do_constant_folding=True,
         onnx_opset_version=None,
         training=TrainingMode.EVAL,
-        check_trace: bool = False,
+        check_trace: Union[bool, List[torch.Tensor]] = False,
         dynamic_axes=None,
         check_tolerance=0.01,
+        export_modules_as_functions: bool = False,
     ):
         """
         Export the module to a file.
@@ -65,6 +67,7 @@
         check_trace: If True, check the trace of the exported model.
         dynamic_axes: A dictionary of input names and dynamic axes.
         check_tolerance: The tolerance for the check_trace.
+        export_modules_as_functions: If True, export modules as functions.
         """
         all_out = []
         all_descr = []
@@ -81,6 +84,7 @@
                 check_trace=check_trace,
                 dynamic_axes=dynamic_axes,
                 check_tolerance=check_tolerance,
+                export_modules_as_functions=export_modules_as_functions,
             )
             # Propagate input example (default scenario, may need to be overridden)
             if input_example is not None:
@@ -101,6 +105,7 @@ def _export(
         check_trace: bool = False,
         dynamic_axes=None,
         check_tolerance=0.01,
+        export_modules_as_functions: bool = False,
     ):
         """
         Helper to export the module to a file.
@@ -116,15 +121,13 @@
         check_trace: If True, check the trace of the exported model.
         dynamic_axes: A dictionary of input names and dynamic axes.
         check_tolerance: The tolerance for the check_trace.
+        export_modules_as_functions: If True, export modules as functions.
         """
         my_args = locals().copy()
         my_args.pop("self")
 
-        exportables = []
-        for m in self.modules():  # type: ignore
-            if isinstance(m, Exportable):
-                exportables.append(m)
-        qual_name = self.__module__ + "." + self.__class__.__qualname__
+        exportables = [m for m in self.modules() if isinstance(m, Exportable)]  # type: ignore
+        qual_name = f"{self.__module__}.{self.__class__.__qualname__}"
         format = get_export_format(output)
         output_descr = f"{qual_name} exported to {format}"
 
@@ -191,10 +194,12 @@
                 do_constant_folding=do_constant_folding,
                 dynamic_axes=dynamic_axes,
                 opset_version=onnx_opset_version,
+                export_modules_as_functions=export_modules_as_functions,
             )
 
             if check_trace:
-                verify_runtime(output, input_list, input_dict, input_names, output_names, output_example)
+                check_trace_input = [input_example] if isinstance(check_trace, bool) else check_trace
+                verify_runtime(self, output, check_trace_input, input_names)
 
         else:
             raise ValueError(f"Encountered unknown export format {format}.")
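A hedged sketch of the extended API: `check_trace` may now be a list of example tensors to verify the exported runtime against, rather than a bare bool. `model`, the tensor shape, and the `input_example` keyword are assumptions, since only fragments of the `export()` signature are visible in this diff:

```python
# Illustrative sketch: export to ONNX and verify the runtime on two inputs.
import torch

sample = torch.randn(1, 32, 320, 320)  # placeholder shape

model.export(
    "model.onnx",
    input_example=sample,
    check_trace=[sample, torch.randn(1, 32, 320, 320)],
)
```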
4 changes: 2 additions & 2 deletions mridc/core/classes/modelPT.py
@@ -539,8 +539,8 @@ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = N
                 for i in range(len(scheduler_config["name"]))
             ]
 
-            self._scheduler = _schedulers
-            self._optimizer = [self._optimizer] * len(scheduler_config["name"])
+            self._scheduler = _schedulers  # type: ignore
+            self._optimizer = [self._optimizer] * len(scheduler_config["name"])  # type: ignore
         else:
             # Try to instantiate scheduler for optimizer
             self._scheduler = mridc.core.optim.lr_scheduler.prepare_lr_scheduler(  # type: ignore
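To illustrate the pairing above (all names are assumptions): with N configured scheduler names, the single optimizer object is repeated N times so schedulers and optimizers line up as parallel lists, which is what the new `# type: ignore` comments accommodate:

```python
# Illustrative sketch: one optimizer object shared across N scheduler slots.
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)

scheduler_names = ["CosineAnnealing", "WarmupAnnealing"]  # placeholder names
optimizers = [optimizer] * len(scheduler_names)
assert optimizers[0] is optimizers[1]  # same object, referenced per scheduler
```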
11 changes: 11 additions & 0 deletions mridc/core/conf/schedulers.py
@@ -81,6 +81,16 @@ class NoamAnnealingParams(WarmupSchedulerParams):
     min_lr: float = 0.0
 
 
+@dataclass
+class NoamHoldAnnealingParams(WarmupHoldSchedulerParams):
+    """
+    Noam Hold Annealing parameter config.
+    It is not derived from Config as it is not a MRIDC object (and in particular it doesn't need a name).
+    """
+
+    decay_rate: float = 0.5
+
+
 @dataclass
 class WarmupAnnealingParams(WarmupSchedulerParams):
     """Warmup Annealing parameter config"""
@@ -205,6 +215,7 @@ def get_scheduler_config(name: str, **kwargs: Optional[Dict[str, Any]]) -> parti
     "SquareRootConstantSchedulerParams": SquareRootConstantSchedulerParams,
     "CosineAnnealingParams": CosineAnnealingParams,
     "NoamAnnealingParams": NoamAnnealingParams,
+    "NoamHoldAnnealingParams": NoamHoldAnnealingParams,
     "WarmupAnnealingParams": WarmupAnnealingParams,
     "PolynomialDecayAnnealingParams": PolynomialDecayAnnealingParams,
     "PolynomialHoldDecayAnnealingParams": PolynomialHoldDecayAnnealingParams,
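A hedged sketch of retrieving the newly registered config; per the (truncated) signature above, `get_scheduler_config` appears to return a partial over the params dataclass, so the exact call pattern is an assumption:

```python
# Illustrative sketch: fetch and instantiate the new scheduler config.
scheduler_cfg = get_scheduler_config("NoamHoldAnnealingParams", decay_rate=0.5)
params = scheduler_cfg()  # -> NoamHoldAnnealingParams(decay_rate=0.5)
```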
[Diffs for the remaining changed files are not shown.]
