-
Notifications
You must be signed in to change notification settings - Fork 0
/
6_DeepLabv3+_Road_CE_Loss.py
122 lines (82 loc) · 1.89 KB
/
6_DeepLabv3+_Road_CE_Loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
import glob as glob
import albumentations as A
import requests
import zipfile
import time
import pathlib
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.utils import Sequence
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, AveragePooling2D, Conv2DTranspose, BatchNormalization, Activation, Dropout, Upsampling2D, Concatenate
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from matplotlib.ticker
import MultipleLocator, FormatStrFormatter
from dataclasses import dataclass
def system_config():
@dataclass(frozen = True)
class DatasetConfig:
@dataclass(frozen = True)
class TrainingConfig:
@dataclass(frozen = True)
class InferenceConfig:
def convolution_block():
def DilatedSpatialPyramidPooling():
def deeplabv3plus():
model = deeplabv3plus()
model.summary()
def download_file()
def unzip()
save_name
class CustomSegDataLoader():
def __init__():
def __len__():
def transforms():
def resize():
def reset_array():
def __getitem__():
id2color = {}
id2color_display = {}
def rgb_to_onehot():
def num_to_rgb():
def image_overlay():
def display_image_and_mask():
def create_datasets():
train_ds, valid_ds =
for i, (images,masks) in enumerate(train_ds):
if i ==3:
break
image, mask = images[0], masks[0]
display_image_and_mask()
def mean_iou():
return mean_iou
model.compile()
if not os.path.exists(TrainingConfig.CHECKPOINT_DIR):
os.makedirs()
num_versions
version_dir
os.makedirs
model_checkpoint_callback
history = model.fit()
def plot_results():
plt.close()
train_acc
valid_acc
train_iou
valid_iou
plot_results
plot_results
train_loss
valid_loss
max_loss
plot_results
trained_model
evaluate
print
print
def inference():
inference()