Commit e58545e
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] committed Nov 13, 2023
1 parent fa28c3c commit e58545e
Showing 19 changed files with 3,409 additions and 2,853 deletions.
2,802 changes: 1,401 additions & 1,401 deletions src/OCR_EDA.ipynb
(large diff not rendered)

371 changes: 275 additions & 96 deletions src/augmentation.py
(large diff not rendered)

2 changes: 1 addition & 1 deletion src/bbox_expander.py
@@ -61,7 +61,7 @@ def expand_bbox(points, ratio, width, height):


 def do_expansion(anno_path, ratio, suffix):
-    ufo = json.load(open(anno_path, "r"))
+    ufo = json.load(open(anno_path))

     for image_name in tqdm(ufo["images"]):
         image = ufo["images"][image_name]
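The only code change here is cosmetic: "r" is open()'s default mode, so dropping it leaves behavior identical. A minimal sketch of the equivalent call (load_ufo is a hypothetical helper name, not from the repo; the with block is an addition over the diff's bare open() and simply guarantees the handle is closed):

import json


def load_ufo(anno_path):
    # open(anno_path) is equivalent to open(anno_path, "r"):
    # text-read mode is the default.
    with open(anno_path) as f:
        return json.load(f)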
2 changes: 1 addition & 1 deletion src/bbox_shaker.py
@@ -80,7 +80,7 @@ def shake_bbox(points, width, height, ratio=0.125):


 def do_shaking(anno_path, ratio, suffix):
-    ufo = json.load(open(anno_path, "r"))
+    ufo = json.load(open(anno_path))

     for image_name in tqdm(ufo["images"]):
         image = ufo["images"][image_name]
8 changes: 4 additions & 4 deletions src/bbox_transformer.py
@@ -44,10 +44,10 @@ def rotate_bbox(bbox, theta, anchor=None):
 def calc_error_from_rect(bbox):
     """Calculate the difference between the vertices orientation and default orientation
-    Default orientation is
-    x1y1 : left-top,
-    x2y2 : right-top,
-    x3y3 : right-bot,
+    Default orientation is
+    x1y1 : left-top,
+    x2y2 : right-top,
+    x3y3 : right-bot,
     x4y4 : left-bot.
     """
     x_min, y_min = np.min(bbox, axis=0)

(The removed and added lines above look identical because they differ only in trailing whitespace, which the auto-format hook strips.)
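The docstring pins down the vertex convention calc_error_from_rect assumes: x1y1 left-top, x2y2 right-top, x3y3 right-bot, x4y4 left-bot. A hedged sketch of that idea, assuming the error is the summed per-vertex L2 distance from the axis-aligned rectangle spanned by the bbox (rect_error_sketch is a hypothetical name; the repo's actual aggregation may differ):

import numpy as np


def rect_error_sketch(bbox):
    # bbox: (4, 2) array of vertices, expected order LT, RT, RB, LB.
    bbox = np.asarray(bbox, dtype=np.float32)
    x_min, y_min = np.min(bbox, axis=0)  # same reduction as in the diff above
    x_max, y_max = np.max(bbox, axis=0)
    # Default-orientation rectangle: x1y1 LT, x2y2 RT, x3y3 RB, x4y4 LB.
    rect = np.array(
        [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]],
        dtype=np.float32,
    )
    # Assumed metric: total L2 deviation of each vertex from its default corner.
    return np.linalg.norm(bbox - rect, axis=1).sum()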
109 changes: 64 additions & 45 deletions src/convert_mlt.py
@@ -2,29 +2,24 @@
 import os
 import os.path as osp
 from glob import glob
-from PIL import Image

 import numpy as np
+from PIL import Image
+from torch.utils.data import ConcatDataset, DataLoader, Dataset
 from tqdm import tqdm

-from torch.utils.data import DataLoader, ConcatDataset, Dataset
-
-
-SRC_DATASET_DIR = '/data/datasets/ICDAR17_MLT'  # FIXME
-DST_DATASET_DIR = '/data/datasets/ICDAR17_Korean'  # FIXME
+SRC_DATASET_DIR = "/data/datasets/ICDAR17_MLT"  # FIXME
+DST_DATASET_DIR = "/data/datasets/ICDAR17_Korean"  # FIXME

 NUM_WORKERS = 32  # FIXME

-IMAGE_EXTENSIONS = {'.gif', '.jpg', '.png'}
+IMAGE_EXTENSIONS = {".gif", ".jpg", ".png"}

-LANGUAGE_MAP = {
-    'Korean': 'ko',
-    'Latin': 'en',
-    'Symbols': None
-}
+LANGUAGE_MAP = {"Korean": "ko", "Latin": "en", "Symbols": None}


 def get_language_token(x):
-    return LANGUAGE_MAP.get(x, 'others')
+    return LANGUAGE_MAP.get(x, "others")


 def maybe_mkdir(x):
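The collapsed LANGUAGE_MAP behaves exactly as before: known scripts map to their tokens, "Symbols" maps to None, and the .get default turns every unmapped script into "others". A quick illustration, run in the context of this module (the "Arabic" input is just an arbitrary unmapped example):

assert get_language_token("Korean") == "ko"
assert get_language_token("Latin") == "en"
assert get_language_token("Symbols") is None
assert get_language_token("Arabic") == "others"  # any script outside the map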
@@ -34,25 +29,31 @@ def maybe_mkdir(x):

 class MLT17Dataset(Dataset):
     def __init__(self, image_dir, label_dir, copy_images_to=None):
-        image_paths = {x for x in glob(osp.join(image_dir, '*')) if osp.splitext(x)[1] in
-                       IMAGE_EXTENSIONS}
-        label_paths = set(glob(osp.join(label_dir, '*.txt')))
+        image_paths = {
+            x
+            for x in glob(osp.join(image_dir, "*"))
+            if osp.splitext(x)[1] in IMAGE_EXTENSIONS
+        }
+        label_paths = set(glob(osp.join(label_dir, "*.txt")))
         assert len(image_paths) == len(label_paths)

         sample_ids, samples_info = list(), dict()
         for image_path in image_paths:
             sample_id = osp.splitext(osp.basename(image_path))[0]

-            label_path = osp.join(label_dir, 'gt_{}.txt'.format(sample_id))
+            label_path = osp.join(label_dir, f"gt_{sample_id}.txt")
             assert label_path in label_paths

             words_info, extra_info = self.parse_label_file(label_path)
-            if 'ko' not in extra_info['languages'] or extra_info['languages'].difference({'ko', 'en'}):
+            if "ko" not in extra_info["languages"] or extra_info[
+                "languages"
+            ].difference({"ko", "en"}):
                 continue

             sample_ids.append(sample_id)
-            samples_info[sample_id] = dict(image_path=image_path, label_path=label_path,
-                                           words_info=words_info)
+            samples_info[sample_id] = dict(
+                image_path=image_path, label_path=label_path, words_info=words_info
+            )

         self.sample_ids, self.samples_info = sample_ids, samples_info
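The reflowed filter keeps its meaning: a sample survives only when its languages include Korean and contain nothing outside {"ko", "en"}. A stand-alone check of that predicate (keep is a hypothetical helper that mirrors the negation of the continue condition above):

def keep(languages):
    # Negation of: "ko" not in languages or languages.difference({"ko", "en"})
    return "ko" in languages and not languages.difference({"ko", "en"})


assert keep({"ko"})
assert keep({"ko", "en"})
assert not keep({"en"})            # no Korean text
assert not keep({"ko", "others"})  # mixed with an unsupported language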

Expand All @@ -64,18 +65,26 @@ def __len__(self):
def __getitem__(self, idx):
sample_info = self.samples_info[self.sample_ids[idx]]

image_fname = osp.basename(sample_info['image_path'])
image = Image.open(sample_info['image_path'])
image_fname = osp.basename(sample_info["image_path"])
image = Image.open(sample_info["image_path"])
img_w, img_h = image.size

if self.copy_images_to:
maybe_mkdir(self.copy_images_to)
image.save(osp.join(self.copy_images_to, osp.basename(sample_info['image_path'])))
image.save(
osp.join(self.copy_images_to, osp.basename(sample_info["image_path"]))
)

license_tag = dict(usability=True, public=True, commercial=True, type='CC-BY-SA',
holder=None)
sample_info_ufo = dict(img_h=img_h, img_w=img_w, words=sample_info['words_info'], tags=None,
license_tag=license_tag)
license_tag = dict(
usability=True, public=True, commercial=True, type="CC-BY-SA", holder=None
)
sample_info_ufo = dict(
img_h=img_h,
img_w=img_w,
words=sample_info["words_info"],
tags=None,
license_tag=license_tag,
)

return image_fname, sample_info_ufo

Expand All @@ -86,52 +95,62 @@ def rearrange_points(points):
points = np.roll(points, -start_idx, axis=0).tolist()
return points

with open(label_path, encoding='utf-8') as f:
with open(label_path, encoding="utf-8") as f:
lines = f.readlines()

words_info, languages = dict(), set()
for word_idx, line in enumerate(lines):
items = line.strip().split(',', 9)
items = line.strip().split(",", 9)
language, transcription = items[8], items[9]
points = np.array(items[:8], dtype=np.float32).reshape(4, 2).tolist()
points = rearrange_points(points)

illegibility = transcription == '###'
orientation = 'Horizontal'
illegibility = transcription == "###"
orientation = "Horizontal"
language = get_language_token(language)
words_info[word_idx] = dict(
points=points, transcription=transcription, language=[language],
illegibility=illegibility, orientation=orientation, word_tags=None
points=points,
transcription=transcription,
language=[language],
illegibility=illegibility,
orientation=orientation,
word_tags=None,
)
languages.add(language)

return words_info, dict(languages=languages)


def main():
dst_image_dir = osp.join(DST_DATASET_DIR, 'images')
dst_image_dir = osp.join(DST_DATASET_DIR, "images")
# dst_image_dir = None

mlt_train = MLT17Dataset(osp.join(SRC_DATASET_DIR, 'raw/ch8_training_images'),
osp.join(SRC_DATASET_DIR, 'raw/ch8_training_gt'),
copy_images_to=dst_image_dir)
mlt_valid = MLT17Dataset(osp.join(SRC_DATASET_DIR, 'raw/ch8_validation_images'),
osp.join(SRC_DATASET_DIR, 'raw/ch8_validation_gt'),
copy_images_to=dst_image_dir)
mlt_train = MLT17Dataset(
osp.join(SRC_DATASET_DIR, "raw/ch8_training_images"),
osp.join(SRC_DATASET_DIR, "raw/ch8_training_gt"),
copy_images_to=dst_image_dir,
)
mlt_valid = MLT17Dataset(
osp.join(SRC_DATASET_DIR, "raw/ch8_validation_images"),
osp.join(SRC_DATASET_DIR, "raw/ch8_validation_gt"),
copy_images_to=dst_image_dir,
)
mlt_merged = ConcatDataset([mlt_train, mlt_valid])

anno = dict(images=dict())
with tqdm(total=len(mlt_merged)) as pbar:
for batch in DataLoader(mlt_merged, num_workers=NUM_WORKERS, collate_fn=lambda x: x):
for batch in DataLoader(
mlt_merged, num_workers=NUM_WORKERS, collate_fn=lambda x: x
):
image_fname, sample_info = batch[0]
anno['images'][image_fname] = sample_info
anno["images"][image_fname] = sample_info
pbar.update(1)

ufo_dir = osp.join(DST_DATASET_DIR, 'ufo')
ufo_dir = osp.join(DST_DATASET_DIR, "ufo")
maybe_mkdir(ufo_dir)
with open(osp.join(ufo_dir, 'train.json'), 'w') as f:
with open(osp.join(ufo_dir, "train.json"), "w") as f:
json.dump(anno, f, indent=4)


if __name__ == '__main__':
if __name__ == "__main__":
main()
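None of the reformatting changes semantics; in particular, collate_fn=lambda x: x keeps DataLoader acting as a parallel map: each batch is the raw list of dataset items (default batch_size=1), so batch[0] is the (image_fname, sample_info_ufo) tuple from __getitem__. A toy, self-contained sketch of the pattern (ToyDataset is illustrative, not part of the repo):

from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __len__(self):
        return 3

    def __getitem__(self, idx):
        return f"img_{idx}.jpg", {"img_w": 100, "img_h": 50}


# Identity collate: each "batch" is a plain Python list of samples, not tensors.
for batch in DataLoader(ToyDataset(), num_workers=0, collate_fn=lambda x: x):
    image_fname, sample_info = batch[0]
    print(image_fname, sample_info)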