Skip to content

Commit

Permalink
[WIP] update files
Browse files Browse the repository at this point in the history
  • Loading branch information
shunk031 committed Mar 24, 2024
1 parent 8b4c70e commit d203bcc
Show file tree
Hide file tree
Showing 3 changed files with 719 additions and 599 deletions.
62 changes: 60 additions & 2 deletions MSCOCO.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def compress_rle(
height: int,
width: int,
) -> CompressedRLE:
breakpoint()
if iscrowd:
rle = cocomask.frPyObjects(segmentation, h=height, w=width)
else:
Expand Down Expand Up @@ -498,6 +499,52 @@ def from_dict(
)


@dataclass
class StuffAnnotationData(InstancesAnnotationData):
@classmethod
def from_dict(
cls,
json_dict: JsonDict,
images: Dict[ImageId, ImageData],
decode_rle: bool,
) -> "StuffAnnotationData":
segmentation = json_dict["segmentation"]
image_id = json_dict["image_id"]
image_data = images[image_id]
iscrowd = bool(json_dict["iscrowd"])

segmentation_mask = (
cls.rle_segmentation_to_mask(
segmentation=segmentation,
iscrowd=iscrowd,
height=image_data.height,
width=image_data.width,
)
if decode_rle
else cls.compress_rle(
segmentation=segmentation,
iscrowd=iscrowd,
height=image_data.height,
width=image_data.width,
)
)
return cls(
#
# for AnnotationData
#
annotation_id=json_dict["id"],
image_id=image_id,
#
# for InstancesAnnotationData
#
segmentation=segmentation_mask,
area=json_dict["area"],
iscrowd=iscrowd,
bbox=json_dict["bbox"],
category_id=json_dict["category_id"],
)


class LicenseDict(TypedDict):
license_id: LicenseId
name: str
Expand Down Expand Up @@ -926,10 +973,21 @@ def split_generators(self, file_paths: Dict[str, Any]) -> List[ds.SplitGenerator
]

def load_data(
self, ann_dicts: List[JsonDict], tqdm_desc: str = "Load stuff data", **kwargs
self,
ann_dicts: List[JsonDict],
images: Dict[ImageId, ImageData],
decode_rle: bool,
tqdm_desc: str = "Load stuff data",
):
annotations = defaultdict(list)
breakpoint()
ann_dicts = sorted(ann_dicts, key=lambda d: d["image_id"])

for ann_dict in tqdm(ann_dicts, desc=tqdm_desc):
ann_data = StuffAnnotationData.from_dict(
ann_dict, images=images, decode_rle=decode_rle
)
annotations[ann_data.image_id].append(ann_data)
return annotations

def generate_examples(
self,
Expand Down
Loading

0 comments on commit d203bcc

Please sign in to comment.