Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor and cleanup #147

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 4 additions & 24 deletions src/layoutparser/elements/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Dict, Dict, Any
from abc import ABC, abstractmethod
from copy import copy
from typing import Any, Dict, List


class BaseLayoutElement:
def set(self, inplace=False, **kwargs):
Expand Down Expand Up @@ -49,13 +50,11 @@ class BaseCoordElement(ABC, BaseLayoutElement):
@abstractmethod
def _name(self) -> str:
"""The name of the class"""
pass

@property
@abstractmethod
def _features(self) -> List[str]:
"""A list of features names used for initializing the class object"""
pass

#######################################################################
######################### Layout Properties #########################
Expand Down Expand Up @@ -108,8 +107,6 @@ def condition_on(self, other):
The BaseCoordElement object of the original element in the absolute coordinate system.
"""

pass

@abstractmethod
def relative_to(self, other):
"""
Expand All @@ -127,8 +124,6 @@ def relative_to(self, other):
The BaseCoordElement object of the original element in the relative coordinate system.
"""

pass

@abstractmethod
def is_in(self, other, soft_margin={}, center=False):
"""
Expand All @@ -147,8 +142,6 @@ def is_in(self, other, soft_margin={}, center=False):
:obj:`bool`: Returns `True` if the current element is in the other element and `False` if not.
"""

pass

#######################################################################
################# Shape Operations (intersect, union) ################
#######################################################################
Expand Down Expand Up @@ -186,8 +179,6 @@ def pad(self, left=0, right=0, top=0, bottom=0, safe_mode=True):
:obj:`BaseCoordElement`: The padded BaseCoordElement object.
"""

pass

@abstractmethod
def shift(self, shift_distance=0):
"""
Expand All @@ -202,8 +193,6 @@ def shift(self, shift_distance=0):
:obj:`BaseCoordElement`: The shifted BaseCoordElement of the same shape-specific class.
"""

pass

@abstractmethod
def scale(self, scale_factor=1):
"""
Expand All @@ -217,8 +206,6 @@ def scale(self, scale_factor=1):
:obj:`BaseCoordElement`: The scaled BaseCoordElement of the same shape-specific class.
"""

pass

#######################################################################
################################# MISC ################################
#######################################################################
Expand All @@ -235,8 +222,6 @@ def crop_image(self, image):
:obj:`Numpy array`: The array of the cropped image.
"""

pass

#######################################################################
########################## Import and Export ##########################
#######################################################################
Expand All @@ -252,9 +237,7 @@ def to_dict(self) -> Dict[str, Any]:
"""

data = {
key: getattr(self, key)
for key in self._features
if getattr(self, key) is not None
key: getattr(self, key) for key in self._features if getattr(self, key) is not None
}
data["block_type"] = self._name
return data
Expand All @@ -267,9 +250,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "BaseCoordElement":
data (:obj:`dict`): The dictionary representation of the object
"""

assert (
cls._name == data["block_type"]
), f"Incompatible block types {data['block_type']}"
assert cls._name == data["block_type"], f"Incompatible block types {data['block_type']}"

return cls(**{f: data[f] for f in cls._features})

67 changes: 24 additions & 43 deletions src/layoutparser/elements/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union, Dict, Dict, Any, Optional
from collections.abc import MutableSequence, Iterable
from collections.abc import Iterable, MutableSequence
from copy import copy
from typing import Any, Dict, List, Optional

import pandas as pd

from .base import BaseCoordElement, BaseLayoutElement
from .layout_elements import (
from layoutparser.elements.base import BaseCoordElement, BaseLayoutElement
from layoutparser.elements.layout_elements import (
ALL_BASECOORD_ELEMENTS,
BASECOORD_ELEMENT_INDEXMAP,
BASECOORD_ELEMENT_NAMEMAP,
Interval,
Rectangle,
Quadrilateral,
Rectangle,
TextBlock,
ALL_BASECOORD_ELEMENTS,
BASECOORD_ELEMENT_NAMEMAP,
BASECOORD_ELEMENT_INDEXMAP,
)


Expand All @@ -47,20 +47,20 @@ class Layout(MutableSequence):
"""

def __init__(self, blocks: Optional[List] = None, *, page_data: Dict = None):

if not (
(blocks is None)
or (isinstance(blocks, Iterable) and blocks.__class__.__name__ != "Layout")
blocks is None
or isinstance(blocks, Iterable)
and blocks.__class__.__name__ != "Layout"
):

if blocks.__class__.__name__ == "Layout":
error_msg = f"Please check the input: it should be lp.Layout([layout]) instead of lp.Layout(layout)"
error_msg = "Please check the input: it should be lp.Layout([layout]) instead of lp.Layout(layout)"

else:
error_msg = f"Blocks should be a list of layout elements or empty (None), instead got {blocks}.\n"

raise ValueError(error_msg)

if isinstance(blocks, tuple):
blocks = list(blocks) # <- more robust handling for tuple-like inputs
blocks = list(blocks) # <- more robust handling for tuple-like inputs

self._blocks = blocks if blocks is not None else []
self.page_data = page_data or {}
Expand All @@ -82,8 +82,7 @@ def __len__(self):
return len(self._blocks)

def __iter__(self):
for ele in self._blocks:
yield ele
yield from self._blocks

def __repr__(self):
info_str = ", ".join([f"{key}={val}" for key, val in vars(self).items()])
Expand All @@ -98,9 +97,7 @@ def __eq__(self, other):
def __add__(self, other):
if isinstance(other, Layout):
if self.page_data == other.page_data:
return self.__class__(
self._blocks + other._blocks, page_data=self.page_data
)
return self.__class__(self._blocks + other._blocks, page_data=self.page_data)
elif self.page_data == {} or other.page_data == {}:
return self.__class__(
self._blocks + other._blocks,
Expand All @@ -113,9 +110,7 @@ def __add__(self, other):
elif isinstance(other, list):
return self.__class__(self._blocks + other, page_data=self.page_data)
else:
raise ValueError(
f"Invalid input type for other {other.__class__.__name__}."
)
raise ValueError(f"Invalid input type for other {other.__class__.__name__}.")

def insert(self, key, value):
self._blocks.insert(key, value)
Expand All @@ -124,14 +119,10 @@ def copy(self):
return self.__class__(copy(self._blocks), page_data=self.page_data)

def relative_to(self, other):
return self.__class__(
[ele.relative_to(other) for ele in self], page_data=self.page_data
)
return self.__class__([ele.relative_to(other) for ele in self], page_data=self.page_data)

def condition_on(self, other):
return self.__class__(
[ele.condition_on(other) for ele in self], page_data=self.page_data
)
return self.__class__([ele.condition_on(other) for ele in self], page_data=self.page_data)

def is_in(self, other, soft_margin={}, center=False):
return self.__class__(
Expand Down Expand Up @@ -234,9 +225,7 @@ def scale(self, scale_factor):
:obj:`Layout`:
A new layout object with all the elements scaled in the specified values.
"""
return self.__class__(
[ele.scale(scale_factor) for ele in self], page_data=self.page_data
)
return self.__class__([ele.scale(scale_factor) for ele in self], page_data=self.page_data)

def crop_image(self, image):
return [ele.crop_image(image) for ele in self]
Expand Down Expand Up @@ -297,9 +286,7 @@ def get_homogeneous_blocks(self) -> List[BaseLayoutElement]:
else:
block = ele

max_coord_level = max(
max_coord_level, BASECOORD_ELEMENT_INDEXMAP[block._name]
)
max_coord_level = max(max_coord_level, BASECOORD_ELEMENT_INDEXMAP[block._name])
target_coord_name = ALL_BASECOORD_ELEMENTS[max_coord_level]._name

if has_textblock:
Expand Down Expand Up @@ -338,11 +325,5 @@ def to_dataframe(self, enforce_same_type=False) -> pd.DataFrame:
pd.DataFrame:
The dataframe representation of layout object
"""
if enforce_same_type:
blocks = self.get_homogeneous_blocks()
else:
blocks = self

df = pd.DataFrame([ele.to_dict() for ele in blocks])

return df
blocks = self.get_homogeneous_blocks() if enforce_same_type else self
return pd.DataFrame([ele.to_dict() for ele in blocks])
Loading