diff --git a/src/layoutparser/elements/base.py b/src/layoutparser/elements/base.py index c78cd71..deffc4d 100644 --- a/src/layoutparser/elements/base.py +++ b/src/layoutparser/elements/base.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Dict, Dict, Any from abc import ABC, abstractmethod from copy import copy +from typing import Any, Dict, List + class BaseLayoutElement: def set(self, inplace=False, **kwargs): @@ -49,13 +50,11 @@ class BaseCoordElement(ABC, BaseLayoutElement): @abstractmethod def _name(self) -> str: """The name of the class""" - pass @property @abstractmethod def _features(self) -> List[str]: """A list of features names used for initializing the class object""" - pass ####################################################################### ######################### Layout Properties ######################### @@ -108,8 +107,6 @@ def condition_on(self, other): The BaseCoordElement object of the original element in the absolute coordinate system. """ - pass - @abstractmethod def relative_to(self, other): """ @@ -127,8 +124,6 @@ def relative_to(self, other): The BaseCoordElement object of the original element in the relative coordinate system. """ - pass - @abstractmethod def is_in(self, other, soft_margin={}, center=False): """ @@ -147,8 +142,6 @@ def is_in(self, other, soft_margin={}, center=False): :obj:`bool`: Returns `True` if the current element is in the other element and `False` if not. """ - pass - ####################################################################### ################# Shape Operations (intersect, union) ################ ####################################################################### @@ -186,8 +179,6 @@ def pad(self, left=0, right=0, top=0, bottom=0, safe_mode=True): :obj:`BaseCoordElement`: The padded BaseCoordElement object. """ - pass - @abstractmethod def shift(self, shift_distance=0): """ @@ -202,8 +193,6 @@ def shift(self, shift_distance=0): :obj:`BaseCoordElement`: The shifted BaseCoordElement of the same shape-specific class. """ - pass - @abstractmethod def scale(self, scale_factor=1): """ @@ -217,8 +206,6 @@ def scale(self, scale_factor=1): :obj:`BaseCoordElement`: The scaled BaseCoordElement of the same shape-specific class. """ - pass - ####################################################################### ################################# MISC ################################ ####################################################################### @@ -235,8 +222,6 @@ def crop_image(self, image): :obj:`Numpy array`: The array of the cropped image. """ - pass - ####################################################################### ########################## Import and Export ########################## ####################################################################### @@ -252,9 +237,7 @@ def to_dict(self) -> Dict[str, Any]: """ data = { - key: getattr(self, key) - for key in self._features - if getattr(self, key) is not None + key: getattr(self, key) for key in self._features if getattr(self, key) is not None } data["block_type"] = self._name return data @@ -267,9 +250,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "BaseCoordElement": data (:obj:`dict`): The dictionary representation of the object """ - assert ( - cls._name == data["block_type"] - ), f"Incompatible block types {data['block_type']}" + assert cls._name == data["block_type"], f"Incompatible block types {data['block_type']}" return cls(**{f: data[f] for f in cls._features}) - diff --git a/src/layoutparser/elements/layout.py b/src/layoutparser/elements/layout.py index ec421e4..1551bc9 100644 --- a/src/layoutparser/elements/layout.py +++ b/src/layoutparser/elements/layout.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union, Dict, Dict, Any, Optional -from collections.abc import MutableSequence, Iterable +from collections.abc import Iterable, MutableSequence from copy import copy +from typing import Any, Dict, List, Optional import pandas as pd -from .base import BaseCoordElement, BaseLayoutElement -from .layout_elements import ( +from layoutparser.elements.base import BaseCoordElement, BaseLayoutElement +from layoutparser.elements.layout_elements import ( + ALL_BASECOORD_ELEMENTS, + BASECOORD_ELEMENT_INDEXMAP, + BASECOORD_ELEMENT_NAMEMAP, Interval, - Rectangle, Quadrilateral, + Rectangle, TextBlock, - ALL_BASECOORD_ELEMENTS, - BASECOORD_ELEMENT_NAMEMAP, - BASECOORD_ELEMENT_INDEXMAP, ) @@ -47,20 +47,20 @@ class Layout(MutableSequence): """ def __init__(self, blocks: Optional[List] = None, *, page_data: Dict = None): - if not ( - (blocks is None) - or (isinstance(blocks, Iterable) and blocks.__class__.__name__ != "Layout") + blocks is None + or isinstance(blocks, Iterable) + and blocks.__class__.__name__ != "Layout" ): - if blocks.__class__.__name__ == "Layout": - error_msg = f"Please check the input: it should be lp.Layout([layout]) instead of lp.Layout(layout)" + error_msg = "Please check the input: it should be lp.Layout([layout]) instead of lp.Layout(layout)" + else: error_msg = f"Blocks should be a list of layout elements or empty (None), instead got {blocks}.\n" + raise ValueError(error_msg) - if isinstance(blocks, tuple): - blocks = list(blocks) # <- more robust handling for tuple-like inputs + blocks = list(blocks) # <- more robust handling for tuple-like inputs self._blocks = blocks if blocks is not None else [] self.page_data = page_data or {} @@ -82,8 +82,7 @@ def __len__(self): return len(self._blocks) def __iter__(self): - for ele in self._blocks: - yield ele + yield from self._blocks def __repr__(self): info_str = ", ".join([f"{key}={val}" for key, val in vars(self).items()]) @@ -98,9 +97,7 @@ def __eq__(self, other): def __add__(self, other): if isinstance(other, Layout): if self.page_data == other.page_data: - return self.__class__( - self._blocks + other._blocks, page_data=self.page_data - ) + return self.__class__(self._blocks + other._blocks, page_data=self.page_data) elif self.page_data == {} or other.page_data == {}: return self.__class__( self._blocks + other._blocks, @@ -113,9 +110,7 @@ def __add__(self, other): elif isinstance(other, list): return self.__class__(self._blocks + other, page_data=self.page_data) else: - raise ValueError( - f"Invalid input type for other {other.__class__.__name__}." - ) + raise ValueError(f"Invalid input type for other {other.__class__.__name__}.") def insert(self, key, value): self._blocks.insert(key, value) @@ -124,14 +119,10 @@ def copy(self): return self.__class__(copy(self._blocks), page_data=self.page_data) def relative_to(self, other): - return self.__class__( - [ele.relative_to(other) for ele in self], page_data=self.page_data - ) + return self.__class__([ele.relative_to(other) for ele in self], page_data=self.page_data) def condition_on(self, other): - return self.__class__( - [ele.condition_on(other) for ele in self], page_data=self.page_data - ) + return self.__class__([ele.condition_on(other) for ele in self], page_data=self.page_data) def is_in(self, other, soft_margin={}, center=False): return self.__class__( @@ -234,9 +225,7 @@ def scale(self, scale_factor): :obj:`Layout`: A new layout object with all the elements scaled in the specified values. """ - return self.__class__( - [ele.scale(scale_factor) for ele in self], page_data=self.page_data - ) + return self.__class__([ele.scale(scale_factor) for ele in self], page_data=self.page_data) def crop_image(self, image): return [ele.crop_image(image) for ele in self] @@ -297,9 +286,7 @@ def get_homogeneous_blocks(self) -> List[BaseLayoutElement]: else: block = ele - max_coord_level = max( - max_coord_level, BASECOORD_ELEMENT_INDEXMAP[block._name] - ) + max_coord_level = max(max_coord_level, BASECOORD_ELEMENT_INDEXMAP[block._name]) target_coord_name = ALL_BASECOORD_ELEMENTS[max_coord_level]._name if has_textblock: @@ -338,11 +325,5 @@ def to_dataframe(self, enforce_same_type=False) -> pd.DataFrame: pd.DataFrame: The dataframe representation of layout object """ - if enforce_same_type: - blocks = self.get_homogeneous_blocks() - else: - blocks = self - - df = pd.DataFrame([ele.to_dict() for ele in blocks]) - - return df + blocks = self.get_homogeneous_blocks() if enforce_same_type else self + return pd.DataFrame([ele.to_dict() for ele in blocks]) diff --git a/src/layoutparser/elements/layout_elements.py b/src/layoutparser/elements/layout_elements.py index 5bbe5b6..8014ccd 100644 --- a/src/layoutparser/elements/layout_elements.py +++ b/src/layoutparser/elements/layout_elements.py @@ -12,28 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union, Dict, Dict, Any, Optional, Tuple +import functools +import warnings from collections.abc import Iterable from copy import copy from inspect import getmembers, isfunction -import warnings -import functools +from typing import Any, Dict, List, Optional, Union import numpy as np import pandas as pd -from PIL import Image from cv2 import getPerspectiveTransform as _getPerspectiveTransform from cv2 import warpPerspective as _warpPerspective +from PIL import Image -from .base import BaseCoordElement, BaseLayoutElement -from .utils import ( +from layoutparser.elements.base import BaseCoordElement, BaseLayoutElement +from layoutparser.elements.errors import InvalidShapeError, NotSupportedShapeError +from layoutparser.elements.utils import ( cvt_coordinates_to_points, cvt_points_to_coordinates, perspective_transformation, - vertice_in_polygon, polygon_area, + vertice_in_polygon, ) -from .errors import NotSupportedShapeError, InvalidShapeError def mixin_textblock_meta(func): @@ -53,19 +53,16 @@ def inherit_docstrings(cls=None, *, base_class=None): # Refer to https://stackoverflow.com/a/17393254 if cls is None: return functools.partial(inherit_docstrings, base_class=base_class) - for name, func in getmembers(cls, isfunction): if func.__doc__: continue - if base_class == None: + if base_class is None: for parent in cls.__mro__[1:]: if hasattr(parent, name): func.__doc__ = getattr(parent, name).__doc__ break - else: - if hasattr(base_class, name): - func.__doc__ = getattr(base_class, name).__doc__ - + elif hasattr(base_class, name): + func.__doc__ = getattr(base_class, name).__doc__ return cls @@ -103,14 +100,11 @@ class Interval(BaseCoordElement): _features = ["start", "end", "axis", "canvas_height", "canvas_width"] def __init__(self, start, end, axis, canvas_height=None, canvas_width=None): - - assert start <= end, f"Invalid input for start and end. Start must <= end." + assert start <= end, "Invalid input for start and end. Start must <= end." self.start = start self.end = end - assert axis in ["x", "y"], f"Invalid axis {axis}. Axis must be in 'x' or 'y'" self.axis = axis - self.canvas_height = canvas_height or 0 self.canvas_width = canvas_width or 0 @@ -123,11 +117,7 @@ def height(self): Returns: :obj:`numeric`: Output the numeric value of the height. """ - - if self.axis == "x": - return self.canvas_height - else: - return self.end - self.start + return self.canvas_height if self.axis == "x" else self.end - self.start @property def width(self): @@ -138,11 +128,7 @@ def width(self): Returns: :obj:`numeric`: Output the numeric value of the width. """ - - if self.axis == "y": - return self.canvas_width - else: - return self.end - self.start + return self.canvas_width if self.axis == "y" else self.end - self.start @property def coordinates(self): @@ -154,13 +140,11 @@ def coordinates(self): :obj:`Tuple(numeric)`: Output the numeric values of the coordinates in a Tuple of size four. """ - - if self.axis == "x": - coords = (self.start, 0, self.end, self.canvas_height) - else: - coords = (0, self.start, self.canvas_width, self.end) - - return coords + return ( + (self.start, 0, self.end, self.canvas_height) + if self.axis == "x" + else (0, self.start, self.canvas_width, self.end) + ) @property def points(self): @@ -223,90 +207,63 @@ def put_on_canvas(self, canvas): @support_textblock def condition_on(self, other): - if isinstance(other, Interval): - if other.axis == self.axis: - d = other.start - # Reset the canvas size in the absolute coordinates - return self.__class__(self.start + d, self.end + d, self.axis) - else: + if other.axis != self.axis: return copy(self) - + d = other.start + return self.__class__(self.start + d, self.end + d, self.axis) elif isinstance(other, Rectangle): - return self.put_on_canvas(other).to_rectangle().condition_on(other) - elif isinstance(other, Quadrilateral): - return self.put_on_canvas(other).to_quadrilateral().condition_on(other) - else: raise Exception(f"Invalid input type {other.__class__} for other") @support_textblock def relative_to(self, other): - if isinstance(other, Interval): - if other.axis == self.axis: - d = other.start - # Reset the canvas size in the absolute coordinates - return self.__class__(self.start - d, self.end - d, self.axis) - else: + if other.axis != self.axis: return copy(self) - + d = other.start + return self.__class__(self.start - d, self.end - d, self.axis) elif isinstance(other, Rectangle): - return self.put_on_canvas(other).to_rectangle().relative_to(other) - elif isinstance(other, Quadrilateral): - return self.put_on_canvas(other).to_quadrilateral().relative_to(other) - else: raise Exception(f"Invalid input type {other.__class__} for other") @support_textblock def is_in(self, other, soft_margin={}, center=False): - other = other.pad(**soft_margin) - if isinstance(other, Interval): if self.axis != other.axis: return False else: - if not center: - return other.start <= self.start <= self.end <= other.end - else: - return other.start <= self.center <= other.end + return ( + other.start <= self.center <= other.end + if center + else other.start <= self.start <= self.end <= other.end + ) - elif isinstance(other, Rectangle) or isinstance(other, Quadrilateral): + elif isinstance(other, (Rectangle, Quadrilateral)): x_1, y_1, x_2, y_2 = other.coordinates - if center: - if self.axis == "x": - return x_1 <= self.center <= x_2 - else: - return y_1 <= self.center <= y_2 - else: - if self.axis == "x": - return x_1 <= self.start <= self.end <= x_2 - else: - return y_1 <= self.start <= self.end <= y_2 + return ( + x_1 <= self.center <= x_2 if self.axis == "x" else y_1 <= self.center <= y_2 + ) + if self.axis == "x": + return x_1 <= self.start <= self.end <= x_2 + else: + return y_1 <= self.start <= self.end <= y_2 else: raise Exception(f"Invalid input type {other.__class__} for other") @support_textblock def intersect(self, other: BaseCoordElement, strict: bool = True): - """""" - if isinstance(other, Interval): - if self.axis != other.axis: - if self.axis == "x" and other.axis == "y": - return Rectangle(self.start, other.start, self.end, other.end) - else: - return Rectangle(other.start, self.start, other.end, self.end) - else: + if self.axis == other.axis: return self.__class__( max(self.start, other.start), min(self.end, other.end), @@ -315,35 +272,38 @@ def intersect(self, other: BaseCoordElement, strict: bool = True): self.canvas_width, ) + if self.axis == "x" and other.axis == "y": + return Rectangle(self.start, other.start, self.end, other.end) + else: + return Rectangle(other.start, self.start, other.end, self.end) elif isinstance(other, Rectangle): x_1, y_1, x_2, y_2 = other.coordinates if self.axis == "x": return Rectangle(max(x_1, self.start), y_1, min(x_2, self.end), y_2) elif self.axis == "y": return Rectangle(x_1, max(y_1, self.start), x_2, min(y_2, self.end)) - elif isinstance(other, Quadrilateral): if strict: raise NotSupportedShapeError( "The intersection between an Interval and a Quadrilateral might generate Polygon shapes that are not supported in the current version of layoutparser. You can pass `strict=False` in the input that converts the Quadrilateral to Rectangle to avoid this Exception." ) - else: - warnings.warn( - f"With `strict=False`, the other of shape {other.__class__} will be converted to {Rectangle} for obtaining the intersection" - ) - return self.intersect(other.to_rectangle()) + warnings.warn( + f"With `strict=False`, the other of shape {other.__class__} will be converted to {Rectangle} for obtaining the intersection" + ) + + return self.intersect(other.to_rectangle()) else: raise Exception(f"Invalid input type {other.__class__} for other") @support_textblock def union(self, other: BaseCoordElement, strict: bool = True): - """""" if isinstance(other, Interval): if self.axis != other.axis: raise InvalidShapeError( - f"Unioning two intervals of different axes is not allowed." + "Unioning two intervals of different axes is not allowed." ) + else: return self.__class__( min(self.start, other.start), @@ -359,18 +319,17 @@ def union(self, other: BaseCoordElement, strict: bool = True): return Rectangle(min(x_1, self.start), y_1, max(x_2, self.end), y_2) elif self.axis == "y": return Rectangle(x_1, min(y_1, self.start), x_2, max(y_2, self.end)) - elif isinstance(other, Quadrilateral): if strict: raise NotSupportedShapeError( "The intersection between an Interval and a Quadrilateral might generate Polygon shapes that are not supported in the current version of layoutparser. You can pass `strict=False` in the input that converts the Quadrilateral to Rectangle to avoid this Exception." ) - else: - warnings.warn( - f"With `strict=False`, the other of shape {other.__class__} will be converted to {Rectangle} for obtaining the intersection" - ) - return self.union(other.to_rectangle()) + warnings.warn( + f"With `strict=False`, the other of shape {other.__class__} will be converted to {Rectangle} for obtaining the intersection" + ) + + return self.union(other.to_rectangle()) else: raise Exception(f"Invalid input type {other.__class__} for other") @@ -408,9 +367,7 @@ def shift(self, shift_distance): """ if isinstance(shift_distance, Iterable): - shift_distance = ( - shift_distance[0] if self.axis == "x" else shift_distance[1] - ) + shift_distance = shift_distance[0] if self.axis == "x" else shift_distance[1] warnings.warn( f"Input shift for multiple axes. Only use the distance for the {self.axis} axis" ) @@ -560,23 +517,13 @@ def area(self): @support_textblock def condition_on(self, other): - if isinstance(other, Interval): - if other.axis == "x": - dx, dy = other.start, 0 - else: - dx, dy = 0, other.start - - return self.__class__( - self.x_1 + dx, self.y_1 + dy, self.x_2 + dx, self.y_2 + dy - ) + dx, dy = (other.start, 0) if other.axis == "x" else (0, other.start) + return self.__class__(self.x_1 + dx, self.y_1 + dy, self.x_2 + dx, self.y_2 + dy) elif isinstance(other, Rectangle): dx, dy, _, _ = other.coordinates - - return self.__class__( - self.x_1 + dx, self.y_1 + dy, self.x_2 + dx, self.y_2 + dy - ) + return self.__class__(self.x_1 + dx, self.y_1 + dy, self.x_2 + dx, self.y_2 + dy) elif isinstance(other, Quadrilateral): transformed_points = perspective_transformation( @@ -584,28 +531,18 @@ def condition_on(self, other): ) return other.__class__(transformed_points, self.height, self.width) - else: raise Exception(f"Invalid input type {other.__class__} for other") @support_textblock def relative_to(self, other): if isinstance(other, Interval): - if other.axis == "x": - dx, dy = other.start, 0 - else: - dx, dy = 0, other.start - - return self.__class__( - self.x_1 - dx, self.y_1 - dy, self.x_2 - dx, self.y_2 - dy - ) + dx, dy = (other.start, 0) if other.axis == "x" else (0, other.start) + return self.__class__(self.x_1 - dx, self.y_1 - dy, self.x_2 - dx, self.y_2 - dy) elif isinstance(other, Rectangle): dx, dy, _, _ = other.coordinates - - return self.__class__( - self.x_1 - dx, self.y_1 - dy, self.x_2 - dx, self.y_2 - dy - ) + return self.__class__(self.x_1 - dx, self.y_1 - dy, self.x_2 - dx, self.y_2 - dy) elif isinstance(other, Quadrilateral): transformed_points = perspective_transformation( @@ -613,7 +550,6 @@ def relative_to(self, other): ) return other.__class__(transformed_points, self.height, self.width) - else: raise Exception(f"Invalid input type {other.__class__} for other") @@ -636,9 +572,7 @@ def is_in(self, other, soft_margin={}, center=False): elif isinstance(other, Rectangle): x_interval = other.to_interval(axis="x") y_interval = other.to_interval(axis="y") - return self.is_in(x_interval, center=center) and self.is_in( - y_interval, center=center - ) + return self.is_in(x_interval, center=center) and self.is_in(y_interval, center=center) elif isinstance(other, Quadrilateral): @@ -658,11 +592,8 @@ def is_in(self, other, soft_margin={}, center=False): @support_textblock def intersect(self, other: BaseCoordElement, strict: bool = True): - """""" - if isinstance(other, Interval): return other.intersect(self) - elif isinstance(other, Rectangle): return self.__class__( @@ -677,21 +608,19 @@ def intersect(self, other: BaseCoordElement, strict: bool = True): raise NotSupportedShapeError( "The intersection between a Rectangle and a Quadrilateral might generate Polygon shapes that are not supported in the current version of layoutparser. You can pass `strict=False` in the input that converts the Quadrilateral to Rectangle to avoid this Exception." ) - else: - warnings.warn( - f"With `strict=False`, the other of shape {other.__class__} will be converted to {Rectangle} for obtaining the intersection" - ) - return self.intersect(other.to_rectangle()) + warnings.warn( + f"With `strict=False`, the other of shape {other.__class__} will be converted to {Rectangle} for obtaining the intersection" + ) + + return self.intersect(other.to_rectangle()) else: raise Exception(f"Invalid input type {other.__class__} for other") @support_textblock def union(self, other: BaseCoordElement, strict: bool = True): - """""" if isinstance(other, Interval): return other.intersect(self) - elif isinstance(other, Rectangle): return self.__class__( min(self.x_1, other.x_1), @@ -705,12 +634,12 @@ def union(self, other: BaseCoordElement, strict: bool = True): raise NotSupportedShapeError( "The intersection between an Interval and a Quadrilateral might generate Polygon shapes that are not supported in the current version of layoutparser. You can pass `strict=False` in the input that converts the Quadrilateral to Rectangle to avoid this Exception." ) - else: - warnings.warn( - f"With `strict=False`, the other of shape {other.__class__} will be converted to {Rectangle} for obtaining the intersection" - ) - return self.union(other.to_rectangle()) + warnings.warn( + f"With `strict=False`, the other of shape {other.__class__} will be converted to {Rectangle} for obtaining the intersection" + ) + + return self.union(other.to_rectangle()) else: raise Exception(f"Invalid input type {other.__class__} for other") @@ -766,11 +695,7 @@ def crop_image(self, image): return image[int(y_1) : int(y_2), int(x_1) : int(x_2)] def to_interval(self, axis, **kwargs): - if axis == "x": - start, end = self.x_1, self.x_2 - else: - start, end = self.y_1, self.y_2 - + start, end = (self.x_1, self.x_2) if axis == "x" else (self.y_1, self.y_2) return Interval(start, end, axis=axis, **kwargs) def to_quadrilateral(self): @@ -809,9 +734,7 @@ class Quadrilateral(BaseCoordElement): _name = "quadrilateral" _features = ["points", "height", "width"] - def __init__( - self, points: Union[np.ndarray, List, List[List]], height=None, width=None - ): + def __init__(self, points: Union[np.ndarray, List, List[List]], height=None, width=None): if isinstance(points, np.ndarray): if points.shape != (4, 2): @@ -822,9 +745,7 @@ def __init__( elif len(points) == 4 and isinstance(points[0], list): points = np.array(points) else: - raise ValueError( - f"Invalid number of points element {len(points)}. Should be 8." - ) + raise ValueError(f"Invalid number of points element {len(points)}. Should be 8.") else: raise ValueError( f"Invalid input type for points {type(points)}." @@ -997,9 +918,7 @@ def is_in(self, other, soft_margin={}, center=False): elif isinstance(other, Rectangle): x_interval = other.to_interval(axis="x") y_interval = other.to_interval(axis="y") - return self.is_in(x_interval, center=center) and self.is_in( - y_interval, center=center - ) + return self.is_in(x_interval, center=center) and self.is_in(y_interval, center=center) elif isinstance(other, Quadrilateral): @@ -1019,46 +938,47 @@ def is_in(self, other, soft_margin={}, center=False): @support_textblock def intersect(self, other: BaseCoordElement, strict: bool = True): - """""" - if strict: raise NotSupportedShapeError( "The intersection between a Quadrilateral and other objects might generate Polygon shapes that are not supported in the current version of layoutparser. You can pass `strict=False` in the input that converts the Quadrilateral to Rectangle to avoid this Exception." ) + + if isinstance(other, (Interval, Rectangle)): + warnings.warn( + f"With `strict=False`, the current Quadrilateral object will be converted to {Rectangle} for obtaining the intersection" + ) + + return other.intersect(self.to_rectangle()) + elif isinstance(other, Quadrilateral): + warnings.warn( + f"With `strict=False`, both input Quadrilateral objects will be converted to {Rectangle} for obtaining the intersection" + ) + + return self.to_rectangle().intersect(other.to_rectangle()) else: - if isinstance(other, Interval) or isinstance(other, Rectangle): - warnings.warn( - f"With `strict=False`, the current Quadrilateral object will be converted to {Rectangle} for obtaining the intersection" - ) - return other.intersect(self.to_rectangle()) - elif isinstance(other, Quadrilateral): - warnings.warn( - f"With `strict=False`, both input Quadrilateral objects will be converted to {Rectangle} for obtaining the intersection" - ) - return self.to_rectangle().intersect(other.to_rectangle()) - else: - raise Exception(f"Invalid input type {other.__class__} for other") + raise Exception(f"Invalid input type {other.__class__} for other") @support_textblock def union(self, other: BaseCoordElement, strict: bool = True): - """""" if strict: raise NotSupportedShapeError( "The intersection between a Quadrilateral and other objects might generate Polygon shapes that are not supported in the current version of layoutparser. You can pass `strict=False` in the input that converts the Quadrilateral to Rectangle to avoid this Exception." ) + + if isinstance(other, (Interval, Rectangle)): + warnings.warn( + f"With `strict=False`, the current Quadrilateral object will be converted to {Rectangle} for obtaining the intersection" + ) + + return other.union(self.to_rectangle()) + elif isinstance(other, Quadrilateral): + warnings.warn( + f"With `strict=False`, both input Quadrilateral objects will be converted to {Rectangle} for obtaining the intersection" + ) + + return self.to_rectangle().union(other.to_rectangle()) else: - if isinstance(other, Interval) or isinstance(other, Rectangle): - warnings.warn( - f"With `strict=False`, the current Quadrilateral object will be converted to {Rectangle} for obtaining the intersection" - ) - return other.union(self.to_rectangle()) - elif isinstance(other, Quadrilateral): - warnings.warn( - f"With `strict=False`, both input Quadrilateral objects will be converted to {Rectangle} for obtaining the intersection" - ) - return self.to_rectangle().union(other.to_rectangle()) - else: - raise Exception(f"Invalid input type {other.__class__} for other") + raise Exception(f"Invalid input type {other.__class__} for other") def pad(self, left=0, right=0, top=0, bottom=0, safe_mode=True): @@ -1117,13 +1037,8 @@ def crop_image(self, image): ) def to_interval(self, axis, **kwargs): - x_1, y_1, x_2, y_2 = self.coordinates - if axis == "x": - start, end = x_1, x_2 - else: - start, end = y_1, y_2 - + start, end = (x_1, x_2) if axis == "x" else (y_1, y_2) return Interval(start, end, axis=axis, **kwargs) def to_rectangle(self): @@ -1164,9 +1079,7 @@ def to_dict(self) -> Dict[str, Any]: ALL_BASECOORD_ELEMENTS = [Interval, Rectangle, Quadrilateral] BASECOORD_ELEMENT_NAMEMAP = {ele._name: ele for ele in ALL_BASECOORD_ELEMENTS} -BASECOORD_ELEMENT_INDEXMAP = { - ele._name: idx for idx, ele in enumerate(ALL_BASECOORD_ELEMENTS) -} +BASECOORD_ELEMENT_INDEXMAP = {ele._name: idx for idx, ele in enumerate(ALL_BASECOORD_ELEMENTS)} @inherit_docstrings(base_class=BaseCoordElement) @@ -1195,9 +1108,7 @@ class TextBlock(BaseLayoutElement): _name = "textblock" _features = ["text", "id", "type", "parent", "next", "score"] - def __init__( - self, block, text=None, id=None, type=None, parent=None, next=None, score=None - ): + def __init__(self, block, text=None, id=None, type=None, parent=None, next=None, score=None): assert isinstance(block, BaseCoordElement) self.block = block @@ -1298,12 +1209,10 @@ def crop_image(self, image): def to_interval(self, axis: Optional[str] = None, **kwargs): if isinstance(self.block, Interval): return self - else: - if not axis: - raise ValueError( - f"Please provide valid `axis` values {'x' or 'y'} as the input" - ) - return self.set(block=self.block.to_interval(axis=axis, **kwargs)) + if not axis: + raise ValueError(f"Please provide valid `axis` values {'x' or 'y'} as the input") + + return self.set(block=self.block.to_interval(axis=axis, **kwargs)) def to_rectangle(self): if isinstance(self.block, Rectangle): @@ -1349,4 +1258,4 @@ def from_dict(cls, data: Dict[str, Any]) -> "TextBlock": block = BASECOORD_ELEMENT_NAMEMAP[data["block_type"]].from_dict(data) - return cls(block, **{f: data.get(f, None) for f in cls._features}) \ No newline at end of file + return cls(block, **{f: data.get(f, None) for f in cls._features}) diff --git a/src/layoutparser/io/basic.py b/src/layoutparser/io/basic.py index 0b2df36..5ec0bb3 100644 --- a/src/layoutparser/io/basic.py +++ b/src/layoutparser/io/basic.py @@ -14,15 +14,15 @@ import ast import json -from typing import List, Union, Dict, Dict, Any +from typing import Dict, List, Union import pandas as pd -from ..elements import ( +from layoutparser.elements import ( + BASECOORD_ELEMENT_NAMEMAP, BaseLayoutElement, - TextBlock, Layout, - BASECOORD_ELEMENT_NAMEMAP, + TextBlock, ) @@ -38,7 +38,7 @@ def load_json(filename: str) -> Union[BaseLayoutElement, Layout]: Based on the JSON file format, it will automatically parse the type of the data and load it accordingly. """ - with open(filename, "r") as fp: + with open(filename) as fp: res = json.load(fp) return load_dict(res) @@ -70,23 +70,22 @@ def load_dict(data: Union[Dict, List[Dict]]) -> Union[BaseLayoutElement, Layout] if "page_data" in data: # It is a layout instance return Layout(load_dict(data["blocks"])._blocks, page_data=data["page_data"]) - else: - - if data["block_type"] not in BASECOORD_ELEMENT_NAMEMAP: - raise ValueError(f"Invalid block_type {data['block_type']}") + if data["block_type"] not in BASECOORD_ELEMENT_NAMEMAP: + raise ValueError(f"Invalid block_type {data['block_type']}") - # Check if it is a textblock - is_textblock = any(ele in data for ele in TextBlock._features) - if is_textblock: - return TextBlock.from_dict(data) - else: - return BASECOORD_ELEMENT_NAMEMAP[data["block_type"]].from_dict(data) + # Check if it is a textblock + is_textblock = any(ele in data for ele in TextBlock._features) + return ( + TextBlock.from_dict(data) + if is_textblock + else BASECOORD_ELEMENT_NAMEMAP[data["block_type"]].from_dict(data) + ) elif isinstance(data, list): return Layout([load_dict(ele) for ele in data]) else: - raise ValueError(f"Invalid input JSON structure.") + raise ValueError("Invalid input JSON structure.") def load_csv(filename: str, block_type: str = None) -> Layout: @@ -126,23 +125,14 @@ def load_dataframe(df: pd.DataFrame, block_type: str = None) -> Layout: The parsed Layout object from the CSV file. """ df = df.copy() - if "points" in df.columns: - if df["points"].dtype == object: - df["points"] = df["points"].map( - lambda x: ast.literal_eval(x) if not pd.isna(x) else x - ) + if "points" in df.columns and df["points"].dtype == object: + df["points"] = df["points"].map(lambda x: x if pd.isna(x) else ast.literal_eval(x)) if block_type is None: if "block_type" not in df.columns: - raise ValueError( - "`block_type` not specified both in dataframe and arguments" - ) + raise ValueError("`block_type` not specified both in dataframe and arguments") else: df["block_type"] = block_type - - if any(col in TextBlock._features for col in df.columns): - # Automatically setting index for textblock - if "id" not in df.columns: - df["id"] = df.index - + if any(col in TextBlock._features for col in df.columns) and "id" not in df.columns: + df["id"] = df.index return load_dict(df.apply(lambda x: x.dropna().to_dict(), axis=1).to_list()) diff --git a/src/layoutparser/models/base_layoutmodel.py b/src/layoutparser/models/base_layoutmodel.py index 495f34b..bfd2ab3 100644 --- a/src/layoutparser/models/base_layoutmodel.py +++ b/src/layoutparser/models/base_layoutmodel.py @@ -12,11 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union, Dict from abc import ABC, abstractmethod +from typing import Dict, Optional, Tuple, Union + +from layoutparser.file_utils import requires_backends +from layoutparser.models.model_config import ( + LayoutModelConfig, + add_identifier_for_config, + is_lp_layout_model_config_any_format, + layout_model_config_parser, +) -from .model_config import LayoutModelConfig, add_identifier_for_config, layout_model_config_parser, is_lp_layout_model_config_any_format -from ..file_utils import requires_backends class BaseLayoutModel(ABC): @@ -25,7 +31,6 @@ class BaseLayoutModel(ABC): @abstractmethod def DEPENDENCIES(self): """DEPENDENCIES lists all necessary dependencies for the class.""" - pass @property @abstractmethod @@ -41,30 +46,33 @@ def MODEL_CATALOG(self) -> Dict[str, Dict[str, str]]: def detect(self, image: Union["np.ndarray", "Image.Image"]): pass - @abstractmethod def image_loader(self, image: Union["np.ndarray", "Image.Image"]): """It will process the input images appropriately to the target format.""" - pass - - def _parse_config(self, config_path:str, identifier:str) -> Union[LayoutModelConfig, str]: - - if is_lp_layout_model_config_any_format(config_path): - config_path = add_identifier_for_config(config_path, identifier) - for dataset_name in self.MODEL_CATALOG: - if dataset_name in config_path: - default_model_arch = list(self.MODEL_CATALOG[dataset_name].keys())[0] - # Use the first model_name for the dataset as the default_model_arch - return layout_model_config_parser(config_path, self.DETECTOR_NAME, default_model_arch) - raise ValueError(f"The config {config_path} is not a valid config for {self.__class__}, " - f"possibly because there aren't models trained for the specified dataset.") - else: + + def _parse_config(self, config_path: str, identifier: str) -> Union[LayoutModelConfig, str]: + + if not is_lp_layout_model_config_any_format(config_path): return config_path + config_path = add_identifier_for_config(config_path, identifier) + for dataset_name in self.MODEL_CATALOG: + if dataset_name in config_path: + default_model_arch = list(self.MODEL_CATALOG[dataset_name].keys())[0] + # Use the first model_name for the dataset as the default_model_arch + return layout_model_config_parser( + config_path, self.DETECTOR_NAME, default_model_arch + ) + raise ValueError( + f"The config {config_path} is not a valid config for {self.__class__}, " + f"possibly because there aren't models trained for the specified dataset." + ) - def config_parser(self, config_path:str, model_path: Optional[str], allow_empty_path=False) -> Tuple[str, str]: + def config_parser( + self, config_path: str, model_path: Optional[str], allow_empty_path=False + ) -> Tuple[str, str]: config_path = self._parse_config(config_path, "config") - + if isinstance(config_path, str) and model_path is None: if not allow_empty_path: raise ValueError( @@ -85,4 +93,4 @@ def config_parser(self, config_path:str, model_path: Optional[str], allow_empty_ def __new__(cls, *args, **kwargs): requires_backends(cls, cls.DEPENDENCIES) - return super().__new__(cls) \ No newline at end of file + return super().__new__(cls)