Skip to content

Commit

Permalink
new_models_added
Browse files Browse the repository at this point in the history
  • Loading branch information
danielaruizl1 committed Dec 31, 2024
1 parent 8db978f commit db3386b
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 14 deletions.
4 changes: 2 additions & 2 deletions INSTALLATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ The **Pytorch-Wildlife** library allows users to directly load the MegadetectorV

## Prerequisites

1. Python 3.8
1. Python 3.10
2. NVIDIA GPU for CUDA support (Optional, the code and demo also supports cpu calculation).
3. `conda` or `mamba` for python environment management and specific version of `opencv`.
4. If you are using CUDA. [CudaToolkit 12.1](https://developer.nvidia.com/cuda-12-1-0-download-archive) is required.
Expand All @@ -32,7 +32,7 @@ The **Pytorch-Wildlife** library allows users to directly load the MegadetectorV
### Create environment
If you have `conda` or `mamba` installed, you can create a new environment with the following commands (switch `conda` to `mamba` for `mamba` users):
```bash
conda create -n pytorch-wildlife python=3.8 -y
conda create -n pytorch-wildlife python=3.10 -y
conda activate pytorch-wildlife
```
NOTE: For Windows users, please use the Anaconda Prompt if you are using Anaconda. Otherwise, please use PowerShell for the conda environment and the rest of the set up.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,24 @@ def __init__(self, weights=None, device="cpu", pretrained=True, version='yolov9c
pretrained (bool, optional): Whether to load the pretrained model. Default is True.
version (str, optional): Version of the model to load. Default is 'yolov9c'.
"""

if version == 'yolov9c':
self.IMAGE_SIZE = 640
url = "https://zenodo.org/records/13357337/files/MDV6b-yolov9c.pt?download=1"
self.IMAGE_SIZE = 640

if version == 'yolov9c':
url = "https://zenodo.org/records/14567879/files/MDV6b-yolov9c.pt?download=1"
self.MODEL_NAME = "MDV6b-yolov9c.pt"
elif version == 'yolov9e':
url = "https://zenodo.org/records/14567879/files/MDV6-yolov9e.pt?download=1"
self.MODEL_NAME = "MDV6-yolov9e.pt"
elif version == 'yolov10n':
url = "https://zenodo.org/records/14567879/files/MDV6-yolov10n.pt?download=1"
self.MODEL_NAME = "MDV6-yolov10n.pt"
elif version == 'yolov10x':
url = "https://zenodo.org/records/14567879/files/MDV6-yolov10x.pt?download=1"
self.MODEL_NAME = "MDV6-yolov10x.pt"
elif version =='rtdetrl':
self.IMAGE_SIZE = 640
url = None
url = "https://zenodo.org/records/14567879/files/MDV6b-rtdetrl.pt?download=1"
self.MODEL_NAME = "MDV6b-rtdetrl.pt"
else:
print('Select a valid model version: yolov9c or rtdetrl')
print('Select a valid model version: yolov9c, yolov9e, yolov10n, yolov10x or rtdetrl')

super(MegaDetectorV6, self).__init__(weights=weights, device=device, url=url)
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ def _load_model(self, weights=None, device="cpu", url=None):
if weights:
self.predictor.setup_model(weights)
elif url:
if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", "MDV6b-yolov9c.pt")):
if not os.path.exists(os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)):
os.makedirs(os.path.join(torch.hub.get_dir(), "checkpoints"), exist_ok=True)
weights = wget.download(url, out=os.path.join(torch.hub.get_dir(), "checkpoints"))
else:
weights = os.path.join(torch.hub.get_dir(), "checkpoints", "MDV6b-yolov9c.pt")
weights = os.path.join(torch.hub.get_dir(), "checkpoints", self.MODEL_NAME)
self.predictor.setup_model(weights)
else:
raise Exception("Need weights for inference.")
Expand Down
2 changes: 1 addition & 1 deletion demo/image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#%%
# Initializing the MegaDetectorV6 model for image detection
detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, version="yolov9c")
detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, version="yolov10x")

# Uncomment the following line to use MegaDetectorV5 instead of MegaDetectorV6
#detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, version="a")
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ ultralytics-yolov5
chardet
wget
ultralytics
setuptools==59.5.0
setuptools
scikit-learn
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
'chardet',
'wget',
'ultralytics',
'setuptools==59.5.0',
'setuptools',
'scikit-learn',
],
classifiers=[
Expand Down

0 comments on commit db3386b

Please sign in to comment.