Commit 551ee05: initial commit
Panneerselvam-N committed Aug 1, 2021 (0 parents)
Showing 34 changed files with 2,582 additions and 0 deletions.
96 changes: 96 additions & 0 deletions README.md
@@ -0,0 +1,96 @@
# PP_YOLO TensorFlow
### TensorFlow implementation of PP-YOLOv1

<p align="center"><img src='sample_outputs/output_4.jpeg'/></p>



## Requirements Installation

### Conda
```bash
# TensorFlow CPU
conda env create -f conda-cpu.yml
conda activate yolov4-cpu

# TensorFlow GPU
conda env create -f conda-gpu.yml
conda activate yolov4-gpu
```

### Pip
```bash
# TensorFlow CPU
pip install -r requirements.txt

# TensorFlow GPU
pip install -r requirements-gpu.txt
```

## Custom Data Training
### Step - 1 : (Setup the "core/config.py" file)
* Modify the path of the .names file (line 14)
* Modify the number of classes (line 15)
* Modify the path of the train.txt file (line 29)
* Modify other parameters like batch size, learning rate, etc. according to your requirements (optional); an illustrative sketch of these edits is shown below
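
The block below is only an illustrative sketch of the kind of edits Step 1 describes: cfg.YOLO.NUM_CLASSES is the name used in core/blocks.py, but every other key name, path, and value here is an assumption, so match them to your actual config file.

```python
# Illustrative sketch only: the real config file may use different key names.
from easydict import EasyDict as edict

cfg = edict()

cfg.YOLO = edict()
cfg.YOLO.CLASSES = './data/classes/custom.names'   # line 14: path to your .names file
cfg.YOLO.NUM_CLASSES = 3                           # line 15: number of classes listed in that file

cfg.TRAIN = edict()
cfg.TRAIN.ANNOT_PATH = './train.txt'               # line 29: path to the train.txt created below
cfg.TRAIN.BATCH_SIZE = 4                           # optional: batch size, learning rate, etc.
cfg.TRAIN.LR_INIT = 1e-3
```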

Use the following code to create the train.txt file. First copy all annotation files and images to 'data/dataset', then run the code below.

```python
# create train.txt: one image path per line
import glob

files = glob.glob('data/dataset/*.jpg')
with open('train.txt', 'w') as f:
    f.write('\n'.join(files))
```
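
If your dataset mixes image extensions, a slightly more general variant of the same idea (a sketch, not code from this repo):

```python
# Sketch: collect several image extensions before writing train.txt.
import glob

patterns = ['data/dataset/*.jpg', 'data/dataset/*.jpeg', 'data/dataset/*.png']
files = sorted(f for p in patterns for f in glob.glob(p))

with open('train.txt', 'w') as f:
    f.write('\n'.join(files))
```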

### Step - 2 : (Model training)
Run the following command to start training:
```bash
python train.py
```
Note : If training is interrupted by a network or other issue, run the following command to resume. If you hit a NaN loss, use a lower learning rate.
```bash
python train.py --const_lr True --resume 'checkpoints/pp_yolo'
```

### Step - 3 : (Model conversion)
Run the following command for model conversion; it takes the saved weights and converts them to the TensorFlow SavedModel format.

```bash
python convert.py --weights './checkpoints/pp_yolo' --save './saved_model' --size 416
```
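
Optionally, a quick sanity check of the exported model (a sketch; the dummy input is illustrative, and the exact output layout depends on filter_boxes in utils/data_process.py):

```python
# Load the exported SavedModel and push a blank 416x416 image through it.
import numpy as np
import tensorflow as tf

model = tf.keras.models.load_model('./saved_model')
dummy = np.zeros((1, 416, 416, 3), dtype=np.float32)  # batch of one blank image
pred = model(dummy)
print(pred.shape)  # boxes and class scores concatenated on the last axis
```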
### Step - 4 : (Detection)
Run the following command for images:
```bash
python detect_img.py --model ./checkpoints/saved_model --image './source/test.jpeg'

```
Run the following command for video:
```bash
python detect_vid.py --model ./checkpoints/saved_model --video ./source/vid.mp4 --output './output/result.avi'

```

Note : Outputs are stored in the detection folder by default; use --output to change the path.

### To Do List
* [x] Core Architecture
* [x] CoordConv
* [x] SPP(Spatial Pyramid Pooling)
* [ ] Deformable Conv
* [ ] Drop Block
* [x] Detection(Infer)
* [ ] Model Evaluation

Note : This is not an optimized implementation; use the official PaddlePaddle framework for better results.
### References
* PP-YOLO: An Effective and Efficient Implementation of Object Detector [PP-Yolo v1](https://arxiv.org/abs/2007.12099)

* Paddle Detection [Paddle implementation](https://github.com/PaddlePaddle/PaddleDetection)

This project is inspired by the following YOLOv4 implementation:
* [YOLOv4](https://github.com/theAIGuysCode/tensorflow-yolov4-tflite)


53 changes: 53 additions & 0 deletions convert.py
@@ -0,0 +1,53 @@
import os
import shutil
import tensorflow as tf
from utils.data_process import decode_tf, filter_boxes
from utils.config import cfg
import numpy as np
from utils import utils
from core.fpn import fpn
from core.resnet_50 import resnet_50
from core.head import head
import argparse

def save_tf():
    STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config()

    # Build the graph: ResNet-50 backbone -> FPN neck -> detection head
    resnet = tf.keras.applications.ResNet50(include_top=False, weights='imagenet',
                                            input_shape=(args.size, args.size, 3))
    c3, c4, c5 = resnet_50(resnet)
    neck_output = fpn(c3, c4, c5)
    head_output = head(neck_output)
    feature_maps = head_output
    bbox_tensors = []
    prob_tensors = []

    input_size = 416  # unused; the size is taken from args
    # Decode each scale's raw output (strides 8, 16 and 32) into boxes and class probabilities
    for i, fm in enumerate(feature_maps):
        if i == 0:
            output_tensors = decode_tf(fm, args.size // 8, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
        elif i == 1:
            output_tensors = decode_tf(fm, args.size // 16, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
        else:
            output_tensors = decode_tf(fm, args.size // 32, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
        bbox_tensors.append(output_tensors[0])
        prob_tensors.append(output_tensors[1])
    pred_bbox = tf.concat(bbox_tensors, axis=1)
    pred_prob = tf.concat(prob_tensors, axis=1)

    # Filter low-score boxes, load the trained weights and export a SavedModel
    boxes, pred_conf = filter_boxes(pred_bbox, pred_prob, score_threshold=0.2,
                                    input_shape=tf.constant([args.size, args.size]))
    pred = tf.concat([boxes, pred_conf], axis=-1)
    model = tf.keras.Model(resnet.input, pred)
    print('loading weights..')
    model.load_weights(args.weights)
    # model.summary()
    print('saving model..')
    model.save(args.save)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--size', help='Size of input image', type=int, default=416)
    parser.add_argument('-w', '--weights', help='path to weights', type=str, default='./checkpoints/pp_yolo')
    parser.add_argument('-m', '--save', help='path to save model', type=str, default='./checkpoints/saved_model')

    args = parser.parse_args()
    save_tf()
Binary file added core/__pycache__/blocks.cpython-37.pyc
Binary file added core/__pycache__/fpn.cpython-37.pyc
Binary file added core/__pycache__/head.cpython-37.pyc
Binary file added core/__pycache__/resnet_50.cpython-37.pyc
80 changes: 80 additions & 0 deletions core/blocks.py
@@ -0,0 +1,80 @@
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPool2D
from utils.config import cfg

def convblock(input_tensors, bn=True, coorconv=True):

    conv = Conv2D(filters=input_tensors.shape[-1]*2, kernel_size=(3,3), padding='same')(input_tensors)

    if bn: conv = BatchNormalization()(conv)

    conv = tf.nn.relu(conv)
    if coorconv: conv = coordconv(conv)  # the coorconv flag gates the CoordConv layer
    conv = Conv2D(filters=input_tensors.shape[-1], kernel_size=(1,1))(conv)
    if bn: conv = BatchNormalization()(conv)
    conv = tf.nn.relu(conv)

    return conv

def coordconv(feature_map):
    # Append two coordinate channels (normalized x and y positions) to the feature map
    batch_size = tf.shape(feature_map)[0]
    x_shape = tf.shape(feature_map)[1]
    y_shape = tf.shape(feature_map)[2]

    x_ones = tf.ones((batch_size, x_shape), dtype=tf.float32)
    x_ones = tf.expand_dims(x_ones, axis=-1)
    x_range = tf.tile(tf.expand_dims(tf.range(y_shape, dtype=tf.float32), axis=0), [batch_size, 1])
    x_range = tf.expand_dims(x_range, 1)
    x_channel = tf.matmul(x_ones, x_range)
    x_channel = tf.expand_dims(x_channel, axis=-1)

    y_ones = tf.ones((batch_size, y_shape), dtype=tf.float32)
    y_ones = tf.expand_dims(y_ones, axis=1)
    y_range = tf.tile(tf.expand_dims(tf.range(x_shape, dtype=tf.float32), axis=0), [batch_size, 1])
    y_range = tf.expand_dims(y_range, -1)
    y_channel = tf.matmul(y_range, y_ones)
    y_channel = tf.expand_dims(y_channel, axis=-1)

    x_shape = tf.cast(x_shape, dtype=tf.float32)
    y_shape = tf.cast(y_shape, dtype=tf.float32)

    # Scale coordinates to the range [-1, 1]
    x_channel = tf.cast(x_channel, dtype=tf.float32) / (y_shape - 1)
    y_channel = tf.cast(y_channel, dtype=tf.float32) / (x_shape - 1)

    x_channel = x_channel * 2 - 1
    y_channel = y_channel * 2 - 1

    output_tensors = tf.concat([feature_map, x_channel, y_channel], axis=-1)

    return output_tensors

def upsampling(features):
    channels = features.shape[-1]
    conv = coordconv(features)
    conv = Conv2D(filters=channels // 2, kernel_size=(1,1))(conv)  # filters must be an int
    output = tf.image.resize(conv, size=(conv.shape[1]*2, conv.shape[2]*2))
    return output

def sppblock(input_tensors):
    # SPP: concatenate stride-1 max pooling at pool sizes 1, 5 and 9 with the input
    pooling_1 = MaxPool2D(pool_size=(1,1), strides=(1,1))(input_tensors)
    pooling_2 = MaxPool2D(pool_size=(5,5), padding='same', strides=(1,1))(input_tensors)
    pooling_3 = MaxPool2D(pool_size=(9,9), padding='same', strides=(1,1))(input_tensors)
    # pooling_4 = MaxPool2D(pool_size=(13,13), padding='same', strides=(1,1))(input_tensors)

    output = tf.concat([input_tensors, pooling_1, pooling_2, pooling_3], axis=-1)

    return output

def conv_head(features):
    # Detection head: 3 anchors x (num_classes + 5) output channels per location
    channel = features.shape[-1]
    num_classes = cfg.YOLO.NUM_CLASSES
    num_filters = 3 * (num_classes + 5)
    conv = coordconv(features)
    conv = Conv2D(filters=channel*2, kernel_size=(3,3), padding='same')(conv)
    conv = tf.nn.relu(conv)
    conv = Conv2D(filters=num_filters, kernel_size=(1,1))(conv)

    return conv
