Skip to content

Commit

Permalink
fix: fix some bugs due to the lack of support in Python 3.6.10.
Browse files Browse the repository at this point in the history
  • Loading branch information
fu_jun committed Dec 23, 2024
1 parent 56a612e commit b0a0fa0
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ We train our DANet-101 with only fine annotated data and submit our test results
```

3. Dataset
- Download the [Cityscapes](https://www.cityscapes-dataset.com/) dataset and convert the dataset to [19 categories](https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py).
- Download the [Cityscapes](https://www.cityscapes-dataset.com/) dataset.
- Please put dataset in folder `./datasets`

4. Evaluation for DANet
Expand Down
54 changes: 47 additions & 7 deletions encoding/datasets/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __getitem__(self, index):
return img, os.path.basename(self.images[index])

mask = Image.open(self.masks[index])

# synchrosized transform
if self.mode == 'train':
img, mask = self._sync_transform(img, mask)
Expand All @@ -64,8 +63,49 @@ def __getitem__(self, index):

def _mask_transform(self, mask):
target = np.array(mask).astype('int32')
target[target == 255] = -1
return torch.from_numpy(target).long()
mapping_20 = {
0: 255,
1: 255,
2: 255,
3: 255,
4: 255,
5: 255,
6: 255,
7: 0,
8: 1,
9: 255,
10: 255,
11: 2,
12: 3,
13: 4,
14: 255,
15: 255,
16: 255,
17: 5,
18: 255,
19: 6,
20: 7,
21: 8,
22: 9,
23: 10,
24: 11,
25: 12,
26: 13,
27: 14,
28: 15,
29: 255,
30: 255,
31: 16,
32: 17,
33: 18,
-1: 255,
}

label_mask = np.zeros_like(target)
for k in mapping_20:
label_mask[target == k] = mapping_20[k]
label_mask[label_mask == 255] = -1
return torch.from_numpy(label_mask).long()

def __len__(self):
return len(self.images)
Expand All @@ -81,9 +121,9 @@ def get_path_pairs(folder,split_f):
mask_paths = []
with open(split_f, 'r') as lines:
for line in tqdm(lines):
ll_str = re.split('\t', line)
imgpath = os.path.join(folder,ll_str[0].rstrip())
maskpath = os.path.join(folder,ll_str[1].rstrip())
ll_str = line.rstrip()
imgpath = os.path.join(folder,'leftImg8bit/val', ll_str+'_leftImg8bit.png')
maskpath = os.path.join(folder,'gtFine/val', ll_str+'_gtFine_labelIds.png')
if os.path.isfile(maskpath):
img_paths.append(imgpath)
mask_paths.append(maskpath)
Expand All @@ -94,7 +134,7 @@ def get_path_pairs(folder,split_f):
split_f = os.path.join(folder, 'train_fine.txt')
img_paths, mask_paths = get_path_pairs(folder, split_f)
elif split == 'val':
split_f = os.path.join(folder, 'val_fine.txt')
split_f = os.path.join(folder, 'val.txt')
img_paths, mask_paths = get_path_pairs(folder, split_f)
elif split == 'test':
split_f = os.path.join(folder, 'test.txt')
Expand Down
11 changes: 6 additions & 5 deletions experiments/segmentation/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from encoding.nn import SegmentationLosses, SyncBatchNorm
from encoding.parallel import DataParallelModel, DataParallelCriterion
from encoding.datasets import get_dataset, test_batchify_fn
from encoding.models import get_model, get_segmentation_model, MultiEvalModule
from encoding.models import get_segmentation_model, MultiEvalModule
#from model_mapping import rename_weight_for_head

class Options():
Expand Down Expand Up @@ -113,7 +113,7 @@ def test(args):
testset = get_dataset(args.dataset, split='val', mode='test',
transform=input_transform)
else:
testset = get_dataset(args.dataset, split='test', mode='test',
testset = get_dataset(args.dataset, split='val', mode='testval',
transform=input_transform)
# dataloader
loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
Expand All @@ -124,9 +124,10 @@ def test(args):
# model
pretrained = args.resume is None and args.verify is None
if args.model_zoo is not None:
model = get_model(args.model_zoo, pretrained=pretrained)
model.base_size = args.base_size
model.crop_size = args.crop_size
pass
# model = get_model(args.model_zoo, pretrained=pretrained)
# model.base_size = args.base_size
# model.crop_size = args.crop_size
else:
model = get_segmentation_model(args.model, dataset=args.dataset,
backbone=args.backbone, aux = args.aux,
Expand Down

0 comments on commit b0a0fa0

Please sign in to comment.