Skip to content

Commit b505c79

Browse files
committed
add oak base
1 parent 9b25180 commit b505c79

File tree

2 files changed

+224
-0
lines changed

2 files changed

+224
-0
lines changed

oikit/oak_base.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import os
2+
import json
3+
4+
CATEGORIES = [
5+
'pincer',
6+
'hammer',
7+
'power_drill',
8+
'can',
9+
'screwdriver',
10+
'squeeze_tube',
11+
'cup',
12+
'wrench',
13+
'game_controller',
14+
'camera',
15+
'headphones',
16+
'mouse',
17+
'frying_pan',
18+
'bowl',
19+
'trigger_sprayer',
20+
'mug',
21+
'binoculars',
22+
'lotion_bottle',
23+
'flashlight',
24+
'eyeglasses',
25+
'lightbulb',
26+
'marker',
27+
'toothbrush',
28+
'bottle',
29+
'cylinder_bottle',
30+
'wineglass',
31+
'teapot',
32+
'scissor',
33+
'knife',
34+
]
35+
36+
ATTRIBUTE_PHRASES = [
37+
"contain_sth",
38+
"cover_sth",
39+
"pump_out_sth",
40+
"cut_sth",
41+
"stab_sth",
42+
"flow_out_sth",
43+
"flow_in_sth",
44+
"secure_sth",
45+
"tighten_sth",
46+
"loosen_sth",
47+
"control_sth",
48+
"clamp_sth",
49+
"brush_sth",
50+
"trigger_sth",
51+
"observe_sth",
52+
"illuminate_sth",
53+
"point_to_sth",
54+
"shear_sth",
55+
"attach_to_sth",
56+
"connect_to_sth",
57+
"knock_sth",
58+
"spray_sth",
59+
"draw_sth",
60+
"no_function",
61+
"held_by_hand",
62+
"pulled_by_hand",
63+
"pressed/unpressed_by_hand",
64+
"screwed/unscrewed_by_hand",
65+
"plugged/unplugged_by_hand",
66+
"squeezed/unsqueezed_by_hand",
67+
]
68+
69+
70+
class ObjectAffordanceKnowledge:
71+
72+
def __init__(self, category, obj_id, n_parts, obj_dir, part_files):
73+
self.category = category
74+
self.obj_id = obj_id
75+
self.n_parts = n_parts
76+
self.obj_dir = obj_dir
77+
78+
self.part_names = []
79+
self.part_name_to_segs = {}
80+
self.part_name_to_attrs = {}
81+
self.part_attr_to_names = {}
82+
83+
for pf in part_files:
84+
assert pf.endswith(".ply") and pf.startswith("part_"), f"part file {pf} is not valid"
85+
pif = os.path.join(obj_dir, pf[:-4] + ".json")
86+
assert os.path.exists(pif), f"part info file {pif} does not exist"
87+
with open(pif, "r") as f:
88+
part_info = json.load(f)
89+
part_name = part_info["name"] # str
90+
part_attrs = part_info["attr"] # list of str
91+
92+
self.part_names.append(part_name)
93+
self.part_name_to_segs[part_name] = os.path.join(obj_dir, pf)
94+
self.part_name_to_attrs[part_name] = part_attrs
95+
for attr in part_attrs:
96+
if attr not in self.part_attr_to_names:
97+
self.part_attr_to_names[attr] = []
98+
self.part_attr_to_names[attr].append(part_name)
99+
100+
def get_part_name_by_attribute(self, attribute):
101+
if attribute == "attach_to":
102+
part_name_list = []
103+
for attr in self.part_attr_to_names.keys():
104+
if attr.startswith("attach_to"):
105+
part_name_list.extend(self.part_attr_to_names[attr])
106+
elif attribute == "connect_to":
107+
part_name_list = []
108+
for attr in self.part_attr_to_names.keys():
109+
if attr.startswith("connect_to"):
110+
part_name_list.extend(self.part_attr_to_names[attr])
111+
else:
112+
part_name_list = self.part_attr_to_names[attribute]
113+
114+
return part_name_list
115+
116+
def get_part_attribute_by_name(self, name):
117+
return self.part_name_to_attrs[name]
118+
119+
def __repr__(self):
120+
return f"cate:{self.category}--id:{self.obj_id}"
121+
122+
123+
class OakBase:
124+
125+
def __init__(self):
126+
self._data_dir = os.path.join(os.environ["OAKINK_DIR"], "OakBase")
127+
self.categories = {}
128+
self.attributes = {}
129+
for cate in os.listdir(self._data_dir):
130+
cate_dir = os.path.join(self._data_dir, cate)
131+
if not os.path.isdir(cate_dir):
132+
continue
133+
134+
if cate not in self.categories:
135+
self.categories[cate] = []
136+
137+
for obj_id in os.listdir(cate_dir):
138+
obj_dir = os.path.join(cate_dir, obj_id)
139+
if not os.path.isdir(obj_dir):
140+
continue
141+
142+
part_files = [pf for pf in os.listdir(obj_dir) if pf.endswith(".ply")]
143+
oak = ObjectAffordanceKnowledge(category=cate,
144+
obj_id=obj_id,
145+
n_parts=len(part_files),
146+
obj_dir=obj_dir,
147+
part_files=part_files)
148+
self.categories[cate].append(oak)
149+
150+
attrs = list(oak.part_attr_to_names.keys())
151+
for attr in attrs:
152+
if attr not in self.attributes:
153+
self.attributes[attr] = []
154+
self.attributes[attr].append(oak)
155+
156+
def get_objs_by_category(self, category):
157+
return self.categories[category]
158+
159+
def get_objs_by_attribute(self, attribute):
160+
if attribute == "attach_to":
161+
obj_list = []
162+
for attr in self.attributes.keys():
163+
if attr.startswith("attach_to"):
164+
obj_list.extend(self.attributes[attr])
165+
elif attribute == "connect_to":
166+
obj_list = []
167+
for attr in self.attributes.keys():
168+
if attr.startswith("connect_to"):
169+
obj_list.extend(self.attributes[attr])
170+
else:
171+
obj_list = self.attributes[attribute]
172+
173+
return obj_list

scripts/demo_oak_base.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import argparse
2+
import os
3+
from typing import List
4+
5+
import open3d as o3d
6+
7+
from oikit.oak_base import OakBase
8+
from oikit.oak_base import ObjectAffordanceKnowledge as OAK
9+
10+
11+
def main(arg):
12+
oakbase = OakBase()
13+
14+
# get all categories
15+
all_cates: List[str] = oakbase.categories.keys()
16+
17+
# get all objects in a category
18+
cate_objs: List[OAK] = oakbase.get_objs_by_category("teapot")
19+
print(f"Category: teapot has {len(cate_objs)} instances")
20+
21+
# get all objects that contain a specific attribute
22+
attr_objs: List[OAK] = oakbase.get_objs_by_attribute("observe_sth")
23+
24+
test_obj: OAK = attr_objs[0]
25+
# get all the attributes that the object has:
26+
all_attrs_of_obj: List[str] = test_obj.part_attr_to_names.keys()
27+
print(f"Object: {test_obj} has attributes: {list(all_attrs_of_obj)}")
28+
29+
# get the parts of the object that contain the specific attribute
30+
part_names: List[str] = test_obj.get_part_name_by_attribute("observe_sth")
31+
32+
test_part_name = part_names[0]
33+
# get all the attributes that the part has:
34+
all_attrs_of_part: List[str] = test_obj.get_part_attribute_by_name(test_part_name)
35+
print(f"Part: {test_part_name} has attributes: {all_attrs_of_part}")
36+
37+
# get the path of the part's segmentation point cloud
38+
part_seg_path: str = test_obj.part_name_to_segs[test_part_name]
39+
40+
# visualize the part's segmentation point cloud
41+
obj_pc = o3d.io.read_point_cloud(part_seg_path)
42+
obj_pc.paint_uniform_color([0.4, 0.8, 0.95])
43+
o3d.visualization.draw_geometries([obj_pc])
44+
45+
46+
if __name__ == "__main__":
47+
parser = argparse.ArgumentParser(description="Demo OakBase")
48+
parser.add_argument("--data_dir", type=str, default="data", help="environment variable 'OAKINK_DIR'")
49+
arg = parser.parse_args()
50+
os.environ["OAKINK_DIR"] = arg.data_dir
51+
main(arg)

0 commit comments

Comments
 (0)