Skip to content

Commit 2c03fb0

Browse files
committed
feat(env): allow nested space types
1 parent 3b21945 commit 2c03fb0

2 files changed

Lines changed: 139 additions & 0 deletions

File tree

python/rcs/envs/space_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@ def value(t, path=""):
200200
{key: value(get_args(t)[1], f"{path}/{key}") for key in child_dict_keys_to_unfold[unfold_key]}
201201
)
202202

203+
if not hasattr(t, "__metadata__"):
204+
return gym.spaces.Dict(
205+
{name: value(sub_t, path) for name, sub_t in get_type_hints(t, include_extras=True).items()}
206+
)
207+
203208
if len(t.__metadata__) == 2 and callable(t.__metadata__[0]):
204209
# space can be parametrized and is a function
205210
assert params is not None, "No params given."

python/tests/test_envs.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,55 @@ class SimpleNestedSpace(RCSpaceType):
3939
]
4040

4141

42+
class SimpleTypeNestedSpace(RCSpaceType):
43+
robots_joints: dict[
44+
Annotated[str, "robots"],
45+
SimpleSpace,
46+
]
47+
48+
49+
class CameraSpace(RCSpaceType):
50+
data: Annotated[
51+
np.ndarray,
52+
# needs to be filled with values downstream
53+
lambda height, width, color_dim=3, dtype=np.uint8, low=0, high=255: gym.spaces.Box(
54+
low=low,
55+
high=high,
56+
shape=(height, width, color_dim),
57+
dtype=dtype,
58+
),
59+
"frame",
60+
]
61+
intrinsics: Annotated[
62+
np.ndarray,
63+
gym.spaces.Box(
64+
low=-np.inf,
65+
high=np.inf,
66+
shape=(3, 4),
67+
dtype=np.float64,
68+
),
69+
]
70+
extrinsics: Annotated[
71+
np.ndarray,
72+
gym.spaces.Box(
73+
low=-np.inf,
74+
high=np.inf,
75+
shape=(4, 4),
76+
dtype=np.float64,
77+
),
78+
]
79+
80+
81+
class AdvancedTypeNestedSpace(RCSpaceType):
82+
frames: dict[
83+
Annotated[str, "camera_names"],
84+
dict[
85+
Annotated[str, "camera_type"], # "rgb" or "depth"
86+
CameraSpace,
87+
],
88+
]
89+
90+
4291
class AdvancedNestedSpace(RCSpaceType):
4392
frames: dict[
4493
Annotated[str, "cams"],
@@ -105,6 +154,91 @@ def test_simple_nested_space(self):
105154
}
106155
)
107156

157+
def test_simple_type_nested_space(self):
158+
assert get_space(SimpleTypeNestedSpace, child_dict_keys_to_unfold={"robots": ["robot1"]}) == gym.spaces.Dict(
159+
{
160+
"robots_joints": gym.spaces.Dict(
161+
{
162+
"robot1": gym.spaces.Dict(
163+
{
164+
"my_int": gym.spaces.Discrete(1),
165+
"my_float": gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
166+
}
167+
),
168+
}
169+
),
170+
}
171+
)
172+
173+
def test_advanced_type_nested_space(self):
174+
assert get_space(
175+
AdvancedTypeNestedSpace,
176+
child_dict_keys_to_unfold={"camera_names": ["cam1"], "camera_type": ["depth", "rgb"]},
177+
params={
178+
"/cam1/rgb/frame": {
179+
"height": 480,
180+
"width": 640,
181+
"dtype": np.uint8,
182+
"low": 0,
183+
"high": 255,
184+
"color_dim": 3,
185+
},
186+
"/cam1/depth/frame": {
187+
"height": 480,
188+
"width": 640,
189+
"dtype": np.uint16,
190+
"low": 0,
191+
"high": 65535,
192+
"color_dim": 1,
193+
},
194+
},
195+
) == gym.spaces.Dict(
196+
{
197+
"frames": gym.spaces.Dict(
198+
{
199+
"cam1": gym.spaces.Dict(
200+
{
201+
"depth": gym.spaces.Dict(
202+
{
203+
"data": gym.spaces.Box(low=0, high=65535, shape=(480, 640, 1), dtype=np.uint16),
204+
"intrinsics": gym.spaces.Box(
205+
low=-np.inf,
206+
high=np.inf,
207+
shape=(3, 4),
208+
dtype=np.float64,
209+
),
210+
"extrinsics": gym.spaces.Box(
211+
low=-np.inf,
212+
high=np.inf,
213+
shape=(4, 4),
214+
dtype=np.float64,
215+
),
216+
}
217+
),
218+
"rgb": gym.spaces.Dict(
219+
{
220+
"data": gym.spaces.Box(low=0, high=255, shape=(480, 640, 3), dtype=np.uint16),
221+
"intrinsics": gym.spaces.Box(
222+
low=-np.inf,
223+
high=np.inf,
224+
shape=(3, 4),
225+
dtype=np.float64,
226+
),
227+
"extrinsics": gym.spaces.Box(
228+
low=-np.inf,
229+
high=np.inf,
230+
shape=(4, 4),
231+
dtype=np.float64,
232+
),
233+
}
234+
),
235+
}
236+
),
237+
}
238+
),
239+
}
240+
)
241+
108242
def test_advanced_nested_space(self):
109243

110244
assert get_space(

0 commit comments

Comments
 (0)