Skip to content

Commit 136ac6a

Browse files
committed
fix site scaling to get all sites from world body
1 parent 4062cb4 commit 136ac6a

3 files changed

Lines changed: 4 additions & 24 deletions

File tree

robosuite/models/arenas/arena.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def set_scale(self, scale: Union[float, List[float]], obj_name: str):
155155
scale_mjcf_model(
156156
obj=obj,
157157
asset_root=self.asset,
158+
worldbody=self.worldbody,
158159
scale=scale,
159160
get_elements_func=get_elements,
160161
scale_slide_joints=False, # Arena doesn't handle slide joints

robosuite/models/objects/objects.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -97,28 +97,6 @@ def get_obj(self):
9797
assert self._obj is not None, "Object XML tree has not been generated yet!"
9898
return self._obj
9999

100-
def set_scale(self, scale, obj=None):
101-
"""
102-
Scales each geom, mesh, site, body, and joint ranges (for slide joints).
103-
Called during initialization but can also be used externally.
104-
Args:
105-
scale (float or list of floats): Scale factor (1 or 3 dims)
106-
obj (ET.Element): Root object to apply scaling to. Defaults to root object of model.
107-
"""
108-
if obj is None:
109-
obj = self._obj
110-
111-
self._scale = scale
112-
113-
# Use the centralized scaling utility function
114-
scale_mjcf_model(
115-
obj=obj,
116-
asset_root=self.asset,
117-
scale=scale,
118-
get_elements_func=get_elements,
119-
scale_slide_joints=True,
120-
)
121-
122100
def exclude_from_prefixing(self, inp):
123101
"""
124102
A function that should take in either an ET.Element or its attribute (str) and return either True or False,
@@ -506,6 +484,7 @@ def set_scale(self, scale, obj=None):
506484
scale_mjcf_model(
507485
obj=obj,
508486
asset_root=self.asset,
487+
worldbody=self.worldbody,
509488
scale=scale,
510489
get_elements_func=get_elements,
511490
scale_slide_joints=False, # MujocoXMLObject doesn't handle slide joints

robosuite/utils/mjcf_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,7 +1060,7 @@ def scale_site_element(element, scale_array):
10601060
element.set("size", s_size)
10611061

10621062

1063-
def scale_mjcf_model(obj, asset_root, scale, get_elements_func, scale_slide_joints=True):
1063+
def scale_mjcf_model(obj, asset_root, worldbody, scale, get_elements_func, scale_slide_joints=True):
10641064
"""
10651065
Scales all elements (geoms, meshes, bodies, joints, sites) in an MJCF model.
10661066
@@ -1098,7 +1098,7 @@ def scale_mjcf_model(obj, asset_root, scale, get_elements_func, scale_slide_join
10981098
scale_joint_element(elem, scale_array, scale_slide_joints)
10991099

11001100
# Scale sites
1101-
site_pairs = get_elements_func(obj, "site")
1101+
site_pairs = get_elements_func(worldbody, "site")
11021102
for (_, elem) in site_pairs:
11031103
scale_site_element(elem, scale_array)
11041104

0 commit comments

Comments
 (0)