Skip to content

Commit d98b58a

Browse files
saumyasinhasaumyasinhacalebrob6adamjstewart
authored
Add DeepGlobe dataset for land cover (torchgeo#578)
* add class for Deep Globe Land Cover dataset * add Lightning data module implementation for deepglobe land cover * fix formatting errors * fix urls, formats and add link for paper * add tests for deepglobe dataset and datamodule * fix a test case and a few more formatting error * add data.py and modify error match for data download * modify draw_semantic_segmentation_masks for cases when mask is a subset of all class labels * fix mypy error * add to docs for documentation * add deepglobe to the dataset lists csv * fix error in building docs * Update datamodules.rst * Update datasets.rst * Update data.py * Update utils.py * change file permissions of non_geo_datasets.csv * Add versionadded * Update torchgeo/datasets/deepglobelandcover.py Co-authored-by: Adam J. Stewart <[email protected]> * Change end of line sequence * Update tests/data/deepglobelandcover/data.py Co-authored-by: Adam J. Stewart <[email protected]> * exist_ok * Update tests/datasets/test_deepglobelandcover.py Co-authored-by: Adam J. Stewart <[email protected]> * Remove datamodule tests * Remove split monkeypatch * Running black * Add val percent to test conf * Sort filelist so indices are the same across platforms * Simplified the file and mask fns * Re-adding datamodule tests for coverage * Add sub-configs to test val_split_pct in the datamodule * Lets try it * Update tests/conf/deepglobelandcover_0.yaml Co-authored-by: Adam J. Stewart <[email protected]> * nulllllllll * ingore_zeros -> ignore_index Co-authored-by: saumyasinha <[email protected]> Co-authored-by: Caleb Robinson <[email protected]> Co-authored-by: Adam J. Stewart <[email protected]>
1 parent bcd18c7 commit d98b58a

File tree

14 files changed

+594
-2
lines changed

14 files changed

+594
-2
lines changed

docs/api/datamodules.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ COWC
2929

3030
.. autoclass:: COWCCountingDataModule
3131

32+
Deep Globe Land Cover Challenge
33+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
34+
35+
.. autoclass:: DeepGlobeLandCoverDataModule
36+
3237
ETCI2021 Flood Detection
3338
^^^^^^^^^^^^^^^^^^^^^^^^
3439

docs/api/datasets.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,11 @@ Kenya Crop Type
164164

165165
.. autoclass:: CV4AKenyaCropType
166166

167+
Deep Globe Land Cover Challenge
168+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
169+
170+
.. autoclass:: DeepGlobeLandCover
171+
167172
DFC2022
168173
^^^^^^^
169174

docs/api/non_geo_datasets.csv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
44
`BigEarthNet`_,C,Sentinel-1/2,"590,326",19--43,120x120,10,"SAR, MSI"
55
`COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","388,435",2,256x256,0.15,RGB
66
`Kenya Crop Type`_,S,Sentinel-2,"4,688",7,"3,035x2,016",10,MSI
7+
`Deep Globe Land Cover Challenge`_,S,DigitalGlobe +Vivid,803,7,"2,448x2,448",0.5,RGB
78
`DFC2022`_,S,Aerial,,15,"2,000x2,000",0.5,RGB
89
`ETCI2021 Flood Detection`_,S,Sentinel-1,"66,810",2,256x256,5--20,SAR
910
`EuroSAT`_,C,Sentinel-2,"27,000",10,64x64,10,MSI
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
experiment:
2+
task: "deepglobelandcover"
3+
module:
4+
loss: "ce"
5+
segmentation_model: "unet"
6+
encoder_name: "resnet18"
7+
encoder_weights: null
8+
learning_rate: 1e-3
9+
learning_rate_schedule_patience: 6
10+
verbose: false
11+
in_channels: 3
12+
num_classes: 7
13+
num_filters: 1
14+
ignore_index: null
15+
datamodule:
16+
root_dir: "tests/data/deepglobelandcover"
17+
val_split_pct: 0.0
18+
batch_size: 1
19+
num_workers: 0
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
experiment:
2+
task: "deepglobelandcover"
3+
module:
4+
loss: "ce"
5+
segmentation_model: "unet"
6+
encoder_name: "resnet18"
7+
encoder_weights: null
8+
learning_rate: 1e-3
9+
learning_rate_schedule_patience: 6
10+
verbose: false
11+
in_channels: 3
12+
num_classes: 7
13+
num_filters: 1
14+
ignore_index: null
15+
datamodule:
16+
root_dir: "tests/data/deepglobelandcover"
17+
val_split_pct: 0.5
18+
batch_size: 1
19+
num_workers: 0
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Microsoft Corporation. All rights reserved.
4+
# Licensed under the MIT License.
5+
6+
import os
7+
import shutil
8+
9+
import numpy as np
10+
from PIL import Image
11+
from torchvision.datasets.utils import calculate_md5
12+
13+
14+
def generate_test_data(root: str, n_samples: int = 3) -> str:
15+
"""Create test data archive for DeepGlobeLandCover dataset.
16+
17+
Args:
18+
root: path to store test data
19+
n_samples: number of samples.
20+
21+
Returns:
22+
md5 hash of created archive
23+
"""
24+
dtype = np.uint8
25+
size = 2
26+
27+
folder_path = os.path.join(root, "data")
28+
29+
train_img_dir = os.path.join(folder_path, "data", "training_data", "images")
30+
train_mask_dir = os.path.join(folder_path, "data", "training_data", "masks")
31+
test_img_dir = os.path.join(folder_path, "data", "test_data", "images")
32+
test_mask_dir = os.path.join(folder_path, "data", "test_data", "masks")
33+
34+
os.makedirs(train_img_dir, exist_ok=True)
35+
os.makedirs(train_mask_dir, exist_ok=True)
36+
os.makedirs(test_img_dir, exist_ok=True)
37+
os.makedirs(test_mask_dir, exist_ok=True)
38+
39+
train_ids = [1, 2, 3]
40+
test_ids = [8, 9, 10]
41+
42+
for i in range(n_samples):
43+
train_id = train_ids[i]
44+
test_id = test_ids[i]
45+
46+
dtype_max = np.iinfo(dtype).max
47+
train_arr = np.random.randint(dtype_max, size=(size, size, 3), dtype=dtype)
48+
train_img = Image.fromarray(train_arr)
49+
train_img.save(os.path.join(train_img_dir, str(train_id) + "_sat.jpg"))
50+
51+
test_arr = np.random.randint(dtype_max, size=(size, size, 3), dtype=dtype)
52+
test_img = Image.fromarray(test_arr)
53+
test_img.save(os.path.join(test_img_dir, str(test_id) + "_sat.jpg"))
54+
55+
train_mask_arr = np.full((size, size, 3), (0, 255, 255), dtype=dtype)
56+
train_mask_img = Image.fromarray(train_mask_arr)
57+
train_mask_img.save(os.path.join(train_mask_dir, str(train_id) + "_mask.png"))
58+
59+
test_mask_arr = np.full((size, size, 3), (255, 0, 255), dtype=dtype)
60+
test_mask_img = Image.fromarray(test_mask_arr)
61+
test_mask_img.save(os.path.join(test_mask_dir, str(test_id) + "_mask.png"))
62+
63+
# Create archive
64+
shutil.make_archive(folder_path, "zip", folder_path)
65+
shutil.rmtree(folder_path)
66+
return calculate_md5(f"{folder_path}.zip")
67+
68+
69+
if __name__ == "__main__":
70+
md5_hash = generate_test_data(os.getcwd(), 3)
71+
print(md5_hash + "\n")
5.61 KB
Binary file not shown.
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
import os
5+
import shutil
6+
from pathlib import Path
7+
8+
import matplotlib.pyplot as plt
9+
import pytest
10+
import torch
11+
import torch.nn as nn
12+
from _pytest.fixtures import SubRequest
13+
from _pytest.monkeypatch import MonkeyPatch
14+
15+
from torchgeo.datasets import DeepGlobeLandCover
16+
17+
18+
class TestDeepGlobeLandCover:
19+
@pytest.fixture(params=["train", "test"])
20+
def dataset(
21+
self, monkeypatch: MonkeyPatch, request: SubRequest
22+
) -> DeepGlobeLandCover:
23+
md5 = "2cbd68d36b1485f09f32d874dde7c5c5"
24+
monkeypatch.setattr(DeepGlobeLandCover, "md5", md5)
25+
root = os.path.join("tests", "data", "deepglobelandcover")
26+
split = request.param
27+
transforms = nn.Identity()
28+
return DeepGlobeLandCover(root, split, transforms, checksum=True)
29+
30+
def test_getitem(self, dataset: DeepGlobeLandCover) -> None:
31+
x = dataset[0]
32+
assert isinstance(x, dict)
33+
assert isinstance(x["image"], torch.Tensor)
34+
assert isinstance(x["mask"], torch.Tensor)
35+
36+
def test_len(self, dataset: DeepGlobeLandCover) -> None:
37+
assert len(dataset) == 3
38+
39+
def test_extract(self, tmp_path: Path) -> None:
40+
root = os.path.join("tests", "data", "deepglobelandcover")
41+
filename = "data.zip"
42+
shutil.copyfile(
43+
os.path.join(root, filename), os.path.join(str(tmp_path), filename)
44+
)
45+
DeepGlobeLandCover(root=str(tmp_path))
46+
47+
def test_corrupted(self, tmp_path: Path) -> None:
48+
with open(os.path.join(tmp_path, "data.zip"), "w") as f:
49+
f.write("bad")
50+
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
51+
DeepGlobeLandCover(root=str(tmp_path), checksum=True)
52+
53+
def test_invalid_split(self) -> None:
54+
with pytest.raises(AssertionError):
55+
DeepGlobeLandCover(split="foo")
56+
57+
def test_not_downloaded(self, tmp_path: Path) -> None:
58+
with pytest.raises(
59+
RuntimeError,
60+
match="Dataset not found in `root`, either"
61+
+ " specify a different `root` directory or manually download"
62+
+ " the dataset to this directory.",
63+
):
64+
DeepGlobeLandCover(str(tmp_path))
65+
66+
def test_plot(self, dataset: DeepGlobeLandCover) -> None:
67+
x = dataset[0].copy()
68+
dataset.plot(x, suptitle="Test")
69+
plt.close()
70+
dataset.plot(x, show_titles=False)
71+
plt.close()
72+
x["prediction"] = x["mask"].clone()
73+
dataset.plot(x)
74+
plt.close()

tests/trainers/test_segmentation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from torchgeo.datamodules import (
1515
ChesapeakeCVPRDataModule,
16+
DeepGlobeLandCoverDataModule,
1617
ETCI2021DataModule,
1718
InriaAerialImageLabelingDataModule,
1819
LandCoverAIDataModule,
@@ -35,6 +36,8 @@ class TestSemanticSegmentationTask:
3536
"name,classname",
3637
[
3738
("chesapeake_cvpr_5", ChesapeakeCVPRDataModule),
39+
("deepglobelandcover_0", DeepGlobeLandCoverDataModule),
40+
("deepglobelandcover_5", DeepGlobeLandCoverDataModule),
3841
("etci2021", ETCI2021DataModule),
3942
("inria", InriaAerialImageLabelingDataModule),
4043
("landcoverai", LandCoverAIDataModule),

torchgeo/datamodules/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .chesapeake import ChesapeakeCVPRDataModule
88
from .cowc import COWCCountingDataModule
99
from .cyclone import CycloneDataModule
10+
from .deepglobelandcover import DeepGlobeLandCoverDataModule
1011
from .etci2021 import ETCI2021DataModule
1112
from .eurosat import EuroSATDataModule
1213
from .fair1m import FAIR1MDataModule
@@ -32,6 +33,7 @@
3233
# VisionDataset
3334
"BigEarthNetDataModule",
3435
"COWCCountingDataModule",
36+
"DeepGlobeLandCoverDataModule",
3537
"ETCI2021DataModule",
3638
"EuroSATDataModule",
3739
"FAIR1MDataModule",

0 commit comments

Comments
 (0)