Skip to content

Commit 2f660c1

Browse files
Enable instance segmentation to include bodies in arena
1 parent 51cc017 commit 2f660c1

1 file changed

Lines changed: 54 additions & 7 deletions

File tree

robosuite/models/tasks/task.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,41 @@
1+
import xml.etree.ElementTree as ET
12
from copy import deepcopy
23

4+
import mujoco
5+
36
from robosuite.models.objects import MujocoObject
47
from robosuite.models.robots import RobotModel
58
from robosuite.models.world import MujocoWorldBase
69
from robosuite.utils.mjcf_utils import get_ids
710

811

12+
def get_subtree_geom_ids_by_group(model: mujoco.MjModel, body_id: int, group: int) -> list[int]:
13+
"""Get all geoms belonging to a subtree starting at a given body, filtered by group.
14+
15+
Args:
16+
model: MuJoCo model.
17+
body_id: ID of body where subtree starts.
18+
group: Group ID to filter geoms.
19+
20+
Returns:
21+
A list containing all subtree geom ids in the specified group.
22+
23+
Adapted from https://github.com/kevinzakka/mink/blob/main/mink/utils.py
24+
"""
25+
26+
def gather_geoms(body_id: int) -> list[int]:
27+
geoms: list[int] = []
28+
geom_start = model.body_geomadr[body_id]
29+
geom_end = geom_start + model.body_geomnum[body_id]
30+
geoms.extend(geom_id for geom_id in range(geom_start, geom_end) if model.geom_group[geom_id] == group)
31+
children = [i for i in range(model.nbody) if model.body_parentid[i] == body_id]
32+
for child_id in children:
33+
geoms.extend(gather_geoms(child_id))
34+
return geoms
35+
36+
return gather_geoms(body_id)
37+
38+
939
class Task(MujocoWorldBase):
1040
"""
1141
Creates MJCF model for a task performed.
@@ -106,15 +136,32 @@ def generate_id_mappings(self, sim):
106136
for robot in self.mujoco_robots:
107137
models += [robot] + robot.models
108138

139+
worldbody = self.mujoco_arena.root.find("worldbody")
140+
exclude_bodies = ["table"]
141+
top_level_bodies = [
142+
body.attrib.get("name")
143+
for body in worldbody.findall("body")
144+
if body.attrib.get("name") not in exclude_bodies
145+
]
146+
models.extend(top_level_bodies)
147+
109148
# Parse all mujoco models from robots and objects
110149
for model in models:
111-
# Grab model class name and visual IDs
112-
cls = str(type(model)).split("'")[1].split(".")[-1]
113-
inst = model.name
114-
id_groups = [
115-
get_ids(sim=sim, elements=model.visual_geoms + model.contact_geoms, element_type="geom"),
116-
get_ids(sim=sim, elements=model.sites, element_type="site"),
117-
]
150+
if isinstance(model, str):
151+
body_name = model
152+
visual_group_number = 1
153+
body_id = sim.model.body_name2id(body_name)
154+
inst, cls = body_name, body_name
155+
geom_ids = get_subtree_geom_ids_by_group(sim.model, body_id, visual_group_number)
156+
id_groups = [geom_ids, []]
157+
else:
158+
# Grab model class name and visual IDs
159+
cls = str(type(model)).split("'")[1].split(".")[-1]
160+
inst = model.name
161+
id_groups = [
162+
get_ids(sim=sim, elements=model.visual_geoms + model.contact_geoms, element_type="geom"),
163+
get_ids(sim=sim, elements=model.sites, element_type="site"),
164+
]
118165
group_types = ("geom", "site")
119166
ids_to_instances = (self._geom_ids_to_instances, self._site_ids_to_instances)
120167
ids_to_classes = (self._geom_ids_to_classes, self._site_ids_to_classes)

0 commit comments

Comments
 (0)