-
Notifications
You must be signed in to change notification settings - Fork 905
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
235 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,6 @@ dependencies: | |
- opencv | ||
- pip: | ||
- tensorflow==2.0.0 | ||
- lxml | ||
- tqdm | ||
- -e . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
aeroplane | ||
bicycle | ||
bird | ||
boat | ||
bottle | ||
bus | ||
car | ||
cat | ||
chair | ||
cow | ||
diningtable | ||
dog | ||
horse | ||
motorbike | ||
person | ||
pottedplant | ||
sheep | ||
sofa | ||
train | ||
tvmonitor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Training Instruction | ||
|
||
## VOC 2012 Dataset from Scratch | ||
|
||
Full instruction on how to train using VOC 2012 from scratch | ||
|
||
Requirement: | ||
1. Able to detect image using pretrained darknet model | ||
2. Many Gigabytes of Disk Space | ||
3. High Speed Internet Connection Prefered | ||
4. GPU Prefered | ||
|
||
|
||
### 1. Download Dataset | ||
|
||
You can read the full description of dataset [here](http://host.robots.ox.ac.uk/pascal/VOC/) | ||
```bash | ||
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar -O ./data/voc2012_raw.tar | ||
mkdir -p ./data/voc2012_raw | ||
tar -xf ./data/voc2012_raw.tar -C ./data/voc2012_raw | ||
ls ./data/voc2012_raw/VOCdevkit/VOC2012 # Explore the dataset | ||
``` | ||
|
||
### 2. Transform Dataset | ||
|
||
```bash | ||
python tools/voc2012.py \ | ||
--data_dir './data/voc2012_raw/VOCdevkit/VOC2012' | ||
--split train \ | ||
--output_file ./data/voc2012_train.tfrecord | ||
|
||
python tools/voc2012.py \ | ||
--data_dir './data/voc2012_raw/VOCdevkit/VOC2012' | ||
--split val \ | ||
--output_file ./data/voc2012_val.tfrecord | ||
``` | ||
|
||
### 3. Training | ||
|
||
You can adjust the parameters based on your setup | ||
|
||
```bash | ||
python train.py \ | ||
--dataset ./data/voc2012_train.tfrecord \ | ||
--val_dataset ./data/voc2012_val.tfrecord \ | ||
--classes ./data/voc2012.names \ | ||
--num_classes 20 \ | ||
--mode fit --transfer none \ | ||
--batch_size 16 \ | ||
--epochs 3 \ | ||
--weights ./checkpoints/yolov3_voc.tf | ||
``` | ||
|
||
I have tested this works 100% with correct loss and converging over time | ||
Each epoch takes around 10 minutes on single AWS p2.xlarge (Nvidia K80 GPU) Instance. | ||
|
||
### 4. Inference | ||
|
||
```bash | ||
python detect.py \ | ||
--classes ./data/voc2012.names \ | ||
--num_classes 20 \ | ||
--weights ./checkpoints/yolov3_voc.tf | ||
``` | ||
|
||
You should see some detect objects in the standard output and the visualization at `output.jpg`. | ||
this is just a proof of concept, so it won't be as good as pretrained models | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from setuptools import setup | ||
|
||
setup(name='yolov3_tf2', | ||
version='0.1', | ||
url='https://github.com/zzh8829/yolov3-tf2', | ||
author='Zihao Zhang', | ||
author_email='[email protected]', | ||
packages=['yolov3_tf2']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import time | ||
import os | ||
import hashlib | ||
|
||
from absl import app, flags, logging | ||
from absl.flags import FLAGS | ||
import tensorflow as tf | ||
import lxml.etree | ||
import tqdm | ||
|
||
flags.DEFINE_string('data_dir', './data/voc2012_raw/VOCdevkit/VOC2012/', | ||
'path to raw PASCAL VOC dataset') | ||
flags.DEFINE_enum('split', 'train', [ | ||
'train', 'val'], 'specify train or val spit') | ||
flags.DEFINE_string('output_file', './data/voc2012.tfrecord', 'outpot dataset') | ||
flags.DEFINE_string('classes', './data/voc2012.names', 'classes file') | ||
|
||
|
||
def build_example(annotation, class_map): | ||
img_path = os.path.join( | ||
FLAGS.data_dir, 'JPEGImages', annotation['filename']) | ||
img_raw = open(img_path, 'rb').read() | ||
key = hashlib.sha256(img_raw).hexdigest() | ||
|
||
width = int(annotation['size']['width']) | ||
height = int(annotation['size']['height']) | ||
|
||
xmin = [] | ||
ymin = [] | ||
xmax = [] | ||
ymax = [] | ||
classes = [] | ||
classes_text = [] | ||
truncated = [] | ||
views = [] | ||
difficult_obj = [] | ||
if 'object' in annotation: | ||
for obj in annotation['object']: | ||
difficult = bool(int(obj['difficult'])) | ||
difficult_obj.append(int(difficult)) | ||
|
||
xmin.append(float(obj['bndbox']['xmin']) / width) | ||
ymin.append(float(obj['bndbox']['ymin']) / height) | ||
xmax.append(float(obj['bndbox']['xmax']) / width) | ||
ymax.append(float(obj['bndbox']['ymax']) / height) | ||
classes_text.append(obj['name'].encode('utf8')) | ||
classes.append(class_map[obj['name']]) | ||
truncated.append(int(obj['truncated'])) | ||
views.append(obj['pose'].encode('utf8')) | ||
|
||
example = tf.train.Example(features=tf.train.Features(feature={ | ||
'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])), | ||
'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])), | ||
'image/filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[ | ||
annotation['filename'].encode('utf8')])), | ||
'image/source_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[ | ||
annotation['filename'].encode('utf8')])), | ||
'image/key/sha256': tf.train.Feature(bytes_list=tf.train.BytesList(value=[key.encode('utf8')])), | ||
'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), | ||
'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=['jpeg'.encode('utf8')])), | ||
'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=xmin)), | ||
'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=xmax)), | ||
'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=ymin)), | ||
'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=ymax)), | ||
'image/object/class/text': tf.train.Feature(bytes_list=tf.train.BytesList(value=classes_text)), | ||
'image/object/class/label': tf.train.Feature(int64_list=tf.train.Int64List(value=classes)), | ||
'image/object/difficult': tf.train.Feature(int64_list=tf.train.Int64List(value=difficult_obj)), | ||
'image/object/truncated': tf.train.Feature(int64_list=tf.train.Int64List(value=truncated)), | ||
'image/object/view': tf.train.Feature(bytes_list=tf.train.BytesList(value=views)), | ||
})) | ||
return example | ||
|
||
|
||
def parse_xml(xml): | ||
if not len(xml): | ||
return {xml.tag: xml.text} | ||
result = {} | ||
for child in xml: | ||
child_result = parse_xml(child) | ||
if child.tag != 'object': | ||
result[child.tag] = child_result[child.tag] | ||
else: | ||
if child.tag not in result: | ||
result[child.tag] = [] | ||
result[child.tag].append(child_result[child.tag]) | ||
return {xml.tag: result} | ||
|
||
|
||
def main(_argv): | ||
class_map = {name: idx for idx, name in enumerate( | ||
open(FLAGS.classes).read().splitlines())} | ||
logging.info("Class mapping loaded: %s", class_map) | ||
|
||
writer = tf.io.TFRecordWriter(FLAGS.output_file) | ||
image_list = open(os.path.join( | ||
FLAGS.data_dir, 'ImageSets', 'Main', 'aeroplane_%s.txt' % FLAGS.split)).read().splitlines() | ||
logging.info("Image list loaded: %d", len(image_list)) | ||
for image in tqdm.tqdm(image_list): | ||
name, _ = image.split() | ||
annotation_xml = os.path.join( | ||
FLAGS.data_dir, 'Annotations', name + '.xml') | ||
annotation_xml = lxml.etree.fromstring(open(annotation_xml).read()) | ||
annotation = parse_xml(annotation_xml)['annotation'] | ||
tf_example = build_example(annotation, class_map) | ||
writer.write(tf_example.SerializeToString()) | ||
writer.close() | ||
logging.info("Done") | ||
|
||
|
||
if __name__ == '__main__': | ||
app.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters