Skip to content

Commit 9f28fb9

Browse files
Enable scaling for mujoco object (#748)
1 parent 5794d93 commit 9f28fb9

2 files changed

Lines changed: 34 additions & 6 deletions

File tree

robosuite/models/objects/objects.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,13 @@ class MujocoObject(MujocoModel):
5353
5454
"""
5555

56-
def __init__(self, obj_type="all", duplicate_collision_geoms=True):
56+
def __init__(self, obj_type="all", duplicate_collision_geoms=True, scale=None):
5757
super().__init__()
5858
self.asset = ET.Element("asset")
5959
assert obj_type in GEOM_GROUPS, "object type must be one in {}, got: {} instead.".format(GEOM_GROUPS, obj_type)
6060
self.obj_type = obj_type
6161
self.duplicate_collision_geoms = duplicate_collision_geoms
62-
62+
self._scale = scale
6363
# Attributes that should be filled in within the subclass
6464
self._name = None
6565
self._obj = None
@@ -73,6 +73,33 @@ def __init__(self, obj_type="all", duplicate_collision_geoms=True):
7373
self._contact_geoms = None
7474
self._visual_geoms = None
7575

76+
if self._scale is not None:
77+
self.set_scale(self._scale)
78+
79+
def set_scale(self, scale, obj=None):
80+
"""
81+
Scales each geom, mesh, site, and body.
82+
Called during initialization but can also be used externally
83+
84+
Args:
85+
scale (float or list of floats): Scale factor (1 or 3 dims)
86+
obj (ET.Element) Root object to apply. Defaults to root object of model
87+
"""
88+
if obj is None:
89+
obj = self._obj
90+
91+
self._scale = scale
92+
93+
# Use the centralized scaling utility function
94+
scale_mjcf_model(
95+
obj=obj,
96+
asset_root=self.asset,
97+
worldbody=None, # because we don't have a worldbody in MujocoObject
98+
scale=scale,
99+
get_elements_func=get_elements,
100+
scale_slide_joints=False, # MujocoObject doesn't handle slide joints
101+
)
102+
76103
def merge_assets(self, other):
77104
"""
78105
Merges @other's assets in a custom logic.

robosuite/utils/mjcf_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,7 +1063,7 @@ def scale_site_element(element, scale_array):
10631063
element.set("size", s_size)
10641064

10651065

1066-
def scale_mjcf_model(obj, asset_root, worldbody, scale, get_elements_func, scale_slide_joints=True):
1066+
def scale_mjcf_model(obj, asset_root, scale, get_elements_func, worldbody=None, scale_slide_joints=True):
10671067
"""
10681068
Scales all elements (geoms, meshes, bodies, joints, sites) in an MJCF model.
10691069
@@ -1101,9 +1101,10 @@ def scale_mjcf_model(obj, asset_root, worldbody, scale, get_elements_func, scale
11011101
scale_joint_element(elem, scale_array, scale_slide_joints)
11021102

11031103
# Scale sites
1104-
site_pairs = get_elements_func(worldbody, "site")
1105-
for (_, elem) in site_pairs:
1106-
scale_site_element(elem, scale_array)
1104+
if worldbody is not None:
1105+
site_pairs = get_elements_func(worldbody, "site")
1106+
for (_, elem) in site_pairs:
1107+
scale_site_element(elem, scale_array)
11071108

11081109
return scale_array
11091110

0 commit comments

Comments
 (0)